# branch: master
# test_winograd.py
# 2928 bytes (raw)
import unittest
from tinygrad import Tensor, GlobalCounters, dtypes
from tinygrad.ops import Ops
from tinygrad.helpers import Timing, CI, Profiling, WINO, DEBUG, getenv
from tinygrad.codegen.kernel import Kernel

class TestWinograd(unittest.TestCase):
  """Regression tests for conv2d with the WINO flag enabled (Winograd path).

  Uses unittest assertion methods rather than bare `assert` so the checks still
  run under `python -O` and report the compared values on failure.
  """

  def setUp(self):
    # WINO is process-global state: enable it for every test and remember the
    # previous value so tearDown can restore it (test_counters also flips it).
    self.old = WINO.value
    WINO.value = 1

  def tearDown(self):
    WINO.value = self.old

  def test_speed(self):
    """Time conv construction, scheduling and linearization; bound ShapeTracker complexity."""
    x = Tensor.empty(1,4,9,9)
    w = Tensor.empty(4,4,3,3)

    with Timing("running conv: "):
      out = Tensor.conv2d(x, w)

    with Timing("scheduling: "):
      sched = out.schedule()

    for i,s in enumerate(sched):
      # only schedule items whose ast is a SINK are handed to Kernel here
      if s.ast.op is not Ops.SINK: continue
      ops = s.ast.toposort
      with Timing(f"linearize {i} with {len(ops):4d} ops: "):
        k = Kernel(s.ast)
        k.hand_coded_optimizations()
        k.linearize()
      # just the current value to prevent regression
      self.assertLessEqual(len(k.sts), 256)
      if DEBUG >= 2:
        # default=0 keeps the diagnostic from raising if a kernel has no shapetrackers
        max_views = max((len(st.views) for st in k.sts), default=0)
        print(f"{len(k.sts):4d} shapetrackers with max {max_views} views")
      for st in k.sts:
        self.assertLessEqual(len(st.views), 2, "too many views in winograd")
        if DEBUG >= 3:
          print(f"{len(st.views):3d} views")
          for v in st.views: print(v)

  def test_profile(self):
    """Profile realizing a small conv2d (profiling output is disabled on CI)."""
    x,w = Tensor.rand(1,4,9,9).realize(), Tensor.rand(4,4,3,3).realize()
    with Profiling(enabled=not CI, sort='time'):
      out = Tensor.conv2d(x,w).realize()
    out.numpy()

  def test_four_kernels(self):
    """With pre-realized inputs, this conv2d must run in exactly four kernels."""
    x,w = Tensor.rand(1,4,9,9).realize(), Tensor.rand(4,4,3,3).realize()
    # reset after realizing the inputs so only the conv's kernels are counted
    GlobalCounters.reset()
    out = Tensor.conv2d(x,w).realize()
    self.assertEqual(GlobalCounters.kernel_count, 4)
    out.numpy()

  @unittest.skipIf(getenv("PTX"), "winograd uses too much in PTX")
  def test_counters(self):
    """Compare global op and memory counters of the WINO=1 conv against WINO=0."""
    IC, OC, X, Y = 4,4,9,9
    # larger configuration for manual experiments: OC, IC, X, Y = 512, 256, 8, 8
    x,w = Tensor.rand(1,IC,Y,X).realize(), Tensor.rand(OC,IC,3,3).realize()
    GlobalCounters.reset()
    Tensor.conv2d(x,w).realize()
    ops_wino, mem_wino = GlobalCounters.global_ops, GlobalCounters.global_mem
    # disable winograd for the baseline run; tearDown restores the original value
    WINO.value = 0
    GlobalCounters.reset()
    Tensor.conv2d(x,w).realize()
    ops_normal, mem_normal = GlobalCounters.global_ops, GlobalCounters.global_mem

    ops_ratio, mem_ratio = ops_wino/ops_normal, mem_wino/mem_normal
    print(f"ops: normal {ops_normal:9d} wino {ops_wino:9d} ratio {ops_ratio:.2f}")
    print(f"mem: normal {mem_normal:9d} wino {mem_wino:9d} ratio {mem_ratio:.2f}")
    # upper bounds on the winograd run's overhead relative to the WINO=0 run
    self.assertLess(ops_ratio, 2.6)  # TODO: there's issues with factorization now
    self.assertLess(mem_ratio, 10)

  def test_dtype(self):
    """conv2d keeps the input dtype: default float, and half for half inputs."""
    IC, OC, X, Y = 4,4,9,9
    x,w = Tensor.empty(1,IC,Y,X), Tensor.empty(OC,IC,3,3)
    self.assertEqual(Tensor.conv2d(x,w).dtype, dtypes.default_float)

    x,w = Tensor.empty(1,IC,Y,X,dtype=dtypes.half), Tensor.empty(OC,IC,3,3,dtype=dtypes.half)
    self.assertEqual(Tensor.conv2d(x,w).dtype, dtypes.half)

if __name__ == "__main__":
  # Running this file directly executes its tests with verbose per-test output.
  unittest.main(verbosity=2)