# branch: master
# test_pickle.py
# 5450 bytes (raw)
import unittest, pickle, types
import numpy as np
from tinygrad import Tensor, TinyJit, Variable, dtypes
from tinygrad.helpers import GlobalCounters, ContextVar, Context
from tinygrad.ops import PatternMatcher, UPat, UOp, Ops

class TestPickle(unittest.TestCase):
  """Round-trip tinygrad objects (code objects, pattern matchers, tensors, UOps, JITs, context vars,
  schedules and renderers) through pickle and check the unpickled copy still behaves like the original."""
  def test_pickle_code_object(self):
    # the stdlib pickle cannot serialize code objects on its own; this passing presumably relies on a
    # reducer registered by tinygrad on import -- NOTE(review): confirm where that registration lives
    y = lambda x: x*2  # noqa: E731
    code_str = pickle.dumps(y.__code__)
    fxn = types.FunctionType(pickle.loads(code_str), globals())
    self.assertEqual(fxn(2), 4)

  def test_pickle_pattern_matcher(self):
    # a matcher whose rule is a lambda must rewrite identically after a round trip
    pm = PatternMatcher([(UPat.cvar('x'), lambda x: x*2)])
    sink = UOp.const(dtypes.int, 2)
    tt = pm.rewrite(sink)
    pm_str = pickle.dumps(pm)
    pm2 = pickle.loads(pm_str)
    self.assertEqual(pm2.rewrite(sink).key, tt.key)

  def test_pickle_main_pattern_matcher(self):
    # serializing alone is not enough: the production matcher must also load back with all its rules
    from tinygrad.codegen.devectorizer import sym
    sym2 = pickle.loads(pickle.dumps(sym))
    self.assertIsInstance(sym2, PatternMatcher)
    self.assertEqual(len(sym2.patterns), len(sym.patterns))

  def test_pickle_realized_tensor(self):
    print("** init")
    t = Tensor.rand(10, 10).realize()
    st = pickle.dumps(t)
    t_values = t.numpy()
    del t # free buffers
    print("** post pickle")
    init = GlobalCounters.kernel_count
    t2:Tensor = pickle.loads(st)
    np.testing.assert_equal(t_values, t2.numpy())
    # expect at most one COPY kernel
    self.assertLessEqual(GlobalCounters.kernel_count-init, 1)

  def test_pickle_realized_tensor_alt(self):
    # same as above, but the tensor already lives on CPU, so loading it must launch no kernel at all
    print("** init")
    t = Tensor.rand(10, 10).to("CPU").realize()
    st = pickle.dumps(t)
    t_values = t.numpy()
    del t # free buffers
    print("** post pickle")
    init = GlobalCounters.kernel_count
    t2:Tensor = pickle.loads(st)
    np.testing.assert_equal(t_values, t2.numpy())
    self.assertEqual(GlobalCounters.kernel_count-init, 0)

  def test_pickle_realized_tensor_alt2(self):
    # a realized tensor must come back realized, even after every local reference to the original is dropped
    print("** init")
    t = Tensor.rand(10, 10).to("CPU").realize()
    tensor_uop = t.lazydata
    self.assertTrue(tensor_uop.is_realized, f"expected {tensor_uop} to be realized")
    t_values = t.numpy()
    # pickle
    st = pickle.dumps(t)
    # free buffers
    del t
    del tensor_uop
    print("** post pickle")
    t2:Tensor = pickle.loads(st)
    self.assertTrue(t2.lazydata.is_realized, f"expected {t2.lazydata} to be realized")
    np.testing.assert_equal(t_values, t2.numpy())

  # NOTE: currently Buffer exists on the uop, not tensor
  def test_pickle_buffer_uop(self):
    t = Tensor.arange(4).realize()
    a = t.lazydata
    self.assertIs(a.op, Ops.BUFFER)
    self.assertIsNotNone(buffer:=a.realized)
    s = pickle.dumps(a)
    # free buffers
    # NOTE(review): `t` is still alive here and still references the UOp, so the buffer is not actually freed
    del a
    del buffer
    a2:UOp = pickle.loads(s)
    self.assertListEqual(a2.realized.as_buffer().cast("I").tolist(), [0, 1, 2, 3])

  def test_pickle_unrealized_tensor(self):
    # a lazy (never realized) tensor graph round-trips and computes the same values
    t = Tensor.ones(10, 10)
    st = pickle.dumps(t)
    t2:Tensor = pickle.loads(st)
    np.testing.assert_equal(t.numpy(), t2.numpy())

  def test_pickle_variable(self):
    # symbolic shapes: i is bound to 10, so each row sums 10 elements of 1+1 == 20
    v = Variable("i", 1, 20).bind(10)
    t1 = Tensor.ones(10, v).contiguous()
    t2 = Tensor.ones(10, v).contiguous()
    ret = (t1+t2).sum(1)
    st = pickle.dumps(ret)
    del ret
    vt2 = pickle.loads(st)
    np.testing.assert_equal(vt2.numpy(), 20)

  def test_pickle_buffer_view(self):
    # a slice of a realized buffer pickles as a view (it still has a `base`) and keeps its contents
    t = Tensor.arange(10, device="CPU").contiguous().realize()
    vt = t[3:5].contiguous().realize()
    self.assertTrue(hasattr(vt.lazydata.buffer, 'base'))
    ref_value = vt.tolist()
    st = pickle.dumps(vt)
    del t, vt
    vt2 = pickle.loads(st)
    self.assertTrue(hasattr(vt2.lazydata.buffer, 'base'))
    self.assertEqual(ref_value, vt2.tolist())

  def test_pickle_numpy(self):
    # a tensor constructed from a numpy array round-trips
    t = Tensor(np.array([1,2,3,4.]), dtype=dtypes.float32)
    st = pickle.dumps(t)
    t2:Tensor = pickle.loads(st)
    np.testing.assert_equal(t.numpy(), t2.numpy())

  def test_pickle_jit(self):
    # the JIT is called three times before pickling; the unpickled copy must run on fresh inputs:
    # sum of 100 ones + 1 (from b) + 1 == 102 everywhere
    @TinyJit
    def add(a, b): return a.sum()+b+1
    for _ in range(3): add(Tensor.rand(10, 10), Tensor.rand(10, 10))
    st = pickle.dumps(add)
    del add

    add_fxn = pickle.loads(st)
    x = Tensor.ones(10, 10).contiguous().realize()
    y = Tensor.ones(10, 10).contiguous().realize()
    print("post jit")
    out = add_fxn(x, y)
    np.testing.assert_equal(out.numpy(), 102)

  def test_pickle_context_var(self):
    # the value active at dump time (1, inside the Context) is what gets restored, not the default 0
    v = ContextVar("test_var", 0)
    with Context(test_var=1):
      vs = pickle.dumps(v)
    v2 = pickle.loads(vs)
    self.assertEqual(v2.value, 1)

  def test_pickle_schedule(self):
    # a schedule round-trips with an equal kernel AST
    a = Tensor([1,2])
    out = a + 2
    sched = out.schedule()
    pk = pickle.dumps(sched)
    sched_pk = pickle.loads(pk)
    self.assertEqual(sched_pk[-1].ast, sched[-1].ast)

  def test_pickle_renderer(self):
    # the default device's renderer must load back as the same renderer class
    from tinygrad.device import Device
    renderer = Device.default.renderer
    renderer2 = pickle.loads(pickle.dumps(renderer))
    self.assertIs(type(renderer2), type(renderer))

class TestPickleJIT(unittest.TestCase):
  """Tests against the pickled bytes of a TinyJit that has been run on large random inputs."""
  @classmethod
  def setUpClass(cls):
    def add(a, b): return a.sum()+b+1
    jit_add = TinyJit(add)
    # call the JIT three times before serializing it, using large inputs
    for _ in range(3):
      jit_add(Tensor.rand(1000, 1000), Tensor.rand(1000, 1000))
    cls.st = pickle.dumps(jit_add)
    del jit_add

  def test_inspect(self):
    import io

    class Placeholder:
      # stand-in for every global the payload references; constructing it just reports which one it is
      def __init__(self, *args, **kwargs):
        print(self.module, self.name)

    class InspectUnpickler(pickle.Unpickler):
      def find_class(self, module, name):
        # never import the real object: return a fresh Placeholder subclass tagged with the requested location
        tags = {"name": name, "module": module}
        return type("SpecializedFakeClass", (Placeholder,), tags)

    payload = io.BytesIO(self.st)
    InspectUnpickler(payload).load()

  @unittest.skip("we are still saving intermediate buffers")
  def test_size(self):
    # confirm no intermediate buffers are saved
    limit = 1_000_000
    self.assertLess(len(self.st), limit)

# allow running this file directly, e.g. `python test_pickle.py`
if __name__ == "__main__": unittest.main()