# branch: master
# test_tar.py
# 5224 bytes
import unittest, tarfile, io, os, pathlib, tempfile
import numpy as np
from tinygrad import Tensor
from tinygrad.nn.state import tar_extract

class TestTarExtractFile(unittest.TestCase):
  """Exercise `tar_extract` on tar archives that are passed as a filesystem path."""

  def setUp(self):
    # TemporaryDirectory + addCleanup removes the whole tree even when setUp raises partway through
    # (unittest skips tearDown in that case) or when a test leaves unexpected files behind.
    tmp_dir = tempfile.TemporaryDirectory()
    self.addCleanup(tmp_dir.cleanup)
    self.test_dir = tmp_dir.name

    # archive member name -> expected raw content
    self.test_files = {
      'file1.txt': b'Hello, World!',
      'file2.bin': b'\x00\x01\x02\x03\x04',
      'empty_file.txt': b''
    }

    # Write every fixture to disk and add it to the archive under its bare filename.
    self.tar_path = os.path.join(self.test_dir, 'test.tar')
    with tarfile.open(self.tar_path, 'w') as tar:
      for filename, content in self.test_files.items():
        file_path = os.path.join(self.test_dir, filename)
        with open(file_path, 'wb') as f:
          f.write(content)
        tar.add(file_path, arcname=filename)

    # Create invalid tar file
    self.invalid_tar_path = os.path.join(self.test_dir, 'invalid.tar')
    with open(self.invalid_tar_path, 'wb') as f:
      f.write(b'This is not a valid tar file')

  def test_tar_extract_returns_dict(self):
    result = tar_extract(self.tar_path)
    self.assertIsInstance(result, dict)

  def test_tar_extract_correct_keys(self):
    result = tar_extract(self.tar_path)
    self.assertEqual(set(result), set(self.test_files))

  def test_tar_extract_content_size(self):
    result = tar_extract(self.tar_path)
    for filename, content in self.test_files.items():
      # subTest reports every mismatching file instead of stopping at the first one
      with self.subTest(filename=filename):
        self.assertEqual(len(result[filename]), len(content))

  def test_tar_extract_content_values(self):
    result = tar_extract(self.tar_path)
    for filename, content in self.test_files.items():
      with self.subTest(filename=filename):
        np.testing.assert_array_equal(result[filename].numpy(), np.frombuffer(content, dtype=np.uint8))

  def test_tar_extract_empty_file(self):
    result = tar_extract(self.tar_path)
    self.assertEqual(len(result['empty_file.txt']), 0)

  def test_tar_extract_non_existent_file(self):
    # Look inside the fresh temp dir so the path is missing regardless of the current working directory.
    missing_path = os.path.join(self.test_dir, 'non_existent_file.tar')
    with self.assertRaises(FileNotFoundError):
      tar_extract(missing_path)

  def test_tar_extract_invalid_file(self):
    with self.assertRaises(tarfile.ReadError):
      tar_extract(self.invalid_tar_path)

class TestTarExtractPAX(unittest.TestCase):
  """Exercise `tar_extract` on in-memory archives written in PAX format.

  Subclasses override `tar_format`, `max_link_len` and `test_files` to rerun the same tests with other tar formats.
  """
  tar_format = tarfile.PAX_FORMAT
  # a symlink entry is only generated for member names shorter than this
  max_link_len = 1_000_000
  # archive member name -> expected raw content
  test_files = {
    'a/file1.txt': b'Hello, World!',
    'a/b/file2.bin': b'\x00\x01\x02\x03\x04',
    'empty_file.txt': b'',
    '512file': b'a' * 512,
    'long_file': b'some data' * 100,
    'very' * 15 + '/' + 'very' * 15 + '_long_filename.txt': b'Hello, World!!',
    'very' * 200 + '_long_filename.txt': b'Hello, World!!!',
  }

  def create_tar_tensor(self):
    # Build the archive in memory: directory entries first, then every file, each followed by a symlink
    # pointing at it when its name is short enough. Returns the archive bytes wrapped in a Tensor.
    fobj = io.BytesIO()
    # sorted() keeps the archive layout identical between runs (set iteration order depends on string
    # hashing) and puts a parent directory before its children.
    test_dirs = sorted({os.path.dirname(name) for name in self.test_files} - {''})
    with tarfile.open(fileobj=fobj, mode='w', format=self.tar_format) as tar:
      for dirname in test_dirs:
        dir_info = tarfile.TarInfo(name=dirname)
        dir_info.type = tarfile.DIRTYPE
        tar.addfile(dir_info)

      for filename, content in self.test_files.items():
        file_info = tarfile.TarInfo(name=filename)
        file_info.size = len(content)
        tar.addfile(file_info, io.BytesIO(content))

        if len(filename) < self.max_link_len:
          link_info = tarfile.TarInfo(name=filename + '.lnk')
          link_info.type = tarfile.SYMTYPE
          link_info.linkname = filename
          tar.addfile(link_info)
    return Tensor(fobj.getvalue())

  def test_tar_extract_returns_dict(self):
    result = tar_extract(self.create_tar_tensor())
    self.assertIsInstance(result, dict)

  def test_tar_extract_correct_keys(self):
    # only the regular files are expected as keys, not the directory or symlink entries
    result = tar_extract(self.create_tar_tensor())
    self.assertEqual(set(result), set(self.test_files))

  def test_tar_extract_content_size(self):
    result = tar_extract(self.create_tar_tensor())
    for filename, content in self.test_files.items():
      # subTest reports every mismatching file instead of stopping at the first one
      with self.subTest(filename=filename):
        self.assertEqual(len(result[filename]), len(content))

  def test_tar_extract_content_values(self):
    result = tar_extract(self.create_tar_tensor())
    for filename, content in self.test_files.items():
      with self.subTest(filename=filename):
        np.testing.assert_array_equal(result[filename].numpy(), np.frombuffer(content, dtype=np.uint8))

  def test_tar_extract_empty_file(self):
    result = tar_extract(self.create_tar_tensor())
    self.assertEqual(len(result['empty_file.txt']), 0)

  def test_tar_extract_non_existent_file(self):
    with self.assertRaises(FileNotFoundError):
      tar_extract(Tensor(pathlib.Path('non_existent_file.tar')))

  def test_tar_extract_invalid_file(self):
    with self.assertRaises(tarfile.ReadError):
      tar_extract(Tensor(b'This is not a valid tar file'))

  def test_tar_extract_invalid_file_long(self):
    with self.assertRaises(tarfile.ReadError):
      tar_extract(Tensor(b'This is not a valid tar file'*100))

class TestTarExtractUSTAR(TestTarExtractPAX):
  """Rerun the inherited tar_extract tests with archives written in USTAR format."""
  tar_format = tarfile.USTAR_FORMAT
  # a symlink entry is only generated for member names shorter than this
  max_link_len = 100
  # reuse the parent's fixtures, keeping only those whose name is shorter than 256 characters
  test_files = dict((name, data) for name, data in TestTarExtractPAX.test_files.items() if len(name) < 256)

class TestTarExtractGNU(TestTarExtractPAX):
  """Rerun the inherited tar_extract tests with archives written in GNU tar format."""
  # fixtures and the symlink length limit are inherited unchanged; only the archive format differs
  tar_format = tarfile.GNU_FORMAT

# Run every test case in this module when the file is executed directly as a script.
if __name__ == '__main__':
  unittest.main()