# branch: master
# test_rewrite_tracked_childen.py
# 2590 bytes
import unittest
from tinygrad import Tensor
from tinygrad.ops import PatternMatcher, Ops, UPat, graph_rewrite, RewriteContext, UOp, merge_views
from tinygrad.engine.schedule import sym

class TestRewriteTrackedChildren(unittest.TestCase):
  """Tests for children tracking around graph_rewrite.

  A UOp's "children" are the UOps that consume it as a source; they are exposed through
  ``ctx.children`` (when ``graph_rewrite(..., track_children=True)``) and ``UOp.get_children_map()``.
  """

  def test_children_in_context(self):
    """ctx.children keeps the pre-rewrite children until ctx.update_children() refreshes it."""
    # every check below runs inside the SINK callback; record each call so the test cannot
    # pass silently if the SINK rule never fires
    calls = []
    def check_children(ctx:RewriteContext, sink:UOp):
      calls.append(sink)
      add = sink.src[0]
      view_w_child = add.src[0].src[0]
      assert view_w_child.op is Ops.VIEW
      # before the refresh, the VIEW's children are still the original nodes with args 2 and 3
      assert {x.arg for x in ctx.children[view_w_child]} == {2, 3}
      ctx.update_children()
      # after the refresh, the rewritten arg=4 node has taken the place of the arg=2 node
      assert {x.arg for x in ctx.children[view_w_child]} == {3, 4}
      # this is the 3: its only child is the ADD
      assert len(ctx.children[add.src[1]]) == 1
      assert next(iter(ctx.children[add.src[1]])).op is Ops.ADD
      # this is the 4: its only child is the ADD
      assert len(ctx.children[add.src[0]]) == 1
      assert next(iter(ctx.children[add.src[0]])).op is Ops.ADD
    rewrite = PatternMatcher([
      (UPat(Ops.CONST, arg=2, name="x"), lambda x: x.replace(arg=4)),
      (UPat(Ops.SINK, name="sink"), check_children),
    ])
    a = Tensor(2)
    b = Tensor(3)
    c = a + b
    graph_rewrite(c.lazydata.sink(), rewrite, track_children=True)
    self.assertTrue(calls, "SINK pattern never matched, children were not checked")

  def test_simple_child(self):
    """get_children_map() on the rewritten graph reflects the rewrite."""
    rewrite = PatternMatcher([
      (UPat(Ops.CONST, arg=2, name="x"), lambda x: x.replace(arg=4)),
    ])
    a = Tensor(2)
    b = Tensor(3)
    c = a + b
    sink = c.lazydata
    # Tensor(2) and Tensor(3) share this VIEW as a source
    view_w_child = a.lazydata.src[0]
    self.assertSetEqual({x.arg for x in sink.get_children_map()[view_w_child]}, {2, 3})
    # children can either be added to or removed from the map with graph_rewrite
    # added to is easy to detect, just hook the UOp constructor
    # when are children removed?
    #  * if a rewrite rule returns a UOp, the matched node is removed from the graph
    sink = graph_rewrite(sink, rewrite)
    self.assertSetEqual({x.arg for x in sink.get_children_map()[view_w_child]}, {3, 4})

  @unittest.expectedFailure
  def test_child_after_parent_update(self):
    """Known failure: reading a REDUCE_AXIS node's children after ctx.update_children().

    NOTE(review): the exact failure mode is not visible from this file; presumably the lookup
    ``ctx.children[r]`` fails. Keep the body unchanged so it keeps failing as expected.
    """
    def print_children(ctx, r):
      ctx.update_children()
      print(ctx.children[r])
    extra = PatternMatcher([(UPat(Ops.REDUCE_AXIS, name="r"), print_children)])
    a = Tensor.empty(3, 3)
    r = (a+0).sum()
    graph_rewrite(r.lazydata, merge_views+sym+extra, track_children=True)

# running this module directly (rather than via a test collector) invokes the unittest runner
if __name__ == '__main__': unittest.main()