# test_const_folding.py (branch: master)
import unittest, itertools, math
from typing import Any
from tinygrad import Tensor, Device, dtypes
from tinygrad.dtype import DType
from tinygrad.ops import Ops, UOp
from tinygrad.helpers import CI
from tinygrad.codegen.devectorizer import full_graph_rewrite
import numpy as np
from tinygrad.device import is_dtype_supported

def _check_ast_count(desired_count:int, t:Tensor):
  # NOTE: this has side effect because everything can be scheduled only once
  sink_count = len([si for si in t.schedule() if si.ast.op is Ops.SINK])
  assert sink_count == desired_count, f"{sink_count} != {desired_count}"

class TestUnaryOpsConstFolding(unittest.TestCase):
  def test_all_consts_ops(self):
    # ops on all-const tensors fold away entirely: zero kernels scheduled
    for folded in [Tensor.ones(4).exp(), Tensor.ones(4).sqrt(), Tensor.ones(4) + Tensor.ones(4), Tensor.ones(4) / Tensor.ones(4)]:
      _check_ast_count(0, folded)

  def test_cast(self):
    _check_ast_count(0, Tensor.ones(4).cast(dtypes.int16))
    _check_ast_count(0, Tensor.full(4, fill_value=-1).cast(dtypes.uint16))

  @unittest.expectedFailure  # no two level fold at lazybuffer
  def test_neg_folding(self):
    for folded in [Tensor([1, 2, 3]).mul(-1).neg(), Tensor([1, 2, 3]).neg().mul(-1), Tensor([1, 2, 3]).neg().neg()]:
      _check_ast_count(0, folded)

  def test_neg_realized_no_fold(self):
    # a realized buffer is opaque to const folding, so the neg must emit one kernel
    realized = Tensor.randn(32, 32).clip(0, 1).realize()
    _check_ast_count(1, realized.neg())

class TestBinaryOpsConstFolding(unittest.TestCase):
  """Binary ops with an identity (x+0, x*1, x/1, x**1) or absorbing (x*0, x**0, 1**x)
  constant operand — literal or const tensor, on either side — fold to zero kernels."""
  def test_add_literal_zero(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) + 0)
  def test_add_tensor_zero(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) + Tensor.zeros(4))
  def test_literal_zero_add(self):
    _check_ast_count(0, 0 + Tensor([1.0, 2, 3, 4]))
  def test_tensor_zero_add(self):
    _check_ast_count(0, Tensor.zeros(4) + Tensor([1.0, 2, 3, 4]))

  def test_sub_literal_zero(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) - 0)
  def test_sub_tensor_zero(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) - Tensor.zeros(4))

  def test_mul_literal_zero(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) * 0)
  def test_mul_tensor_zero(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) * Tensor.zeros(4))
  def test_literal_zero_mul(self):
    # BUGFIX: was `0 * Tensor([1.0, 2, 3, 4]) * 0` — the stray trailing `* 0` meant this
    # did not test the single literal-on-the-left fold its name (and its siblings) describe
    _check_ast_count(0, 0 * Tensor([1.0, 2, 3, 4]))
  def test_tensor_zero_mul(self):
    _check_ast_count(0, Tensor.zeros(4) * Tensor([1.0, 2, 3, 4]))

  def test_mul_literal_one(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) * 1)
  def test_mul_tensor_one(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) * Tensor.ones(4))
  def test_literal_one_mul(self):
    _check_ast_count(0, 1 * Tensor([1.0, 2, 3, 4]))
  def test_tensor_one_mul(self):
    _check_ast_count(0, Tensor.ones(4) * Tensor([1.0, 2, 3, 4]))

  def test_bool_tensor_mul_bool(self):
    _check_ast_count(0, Tensor([True, False]) * True)
    _check_ast_count(0, Tensor([True, False]) * False)
  def test_bool_mul_bool_tensor(self):
    _check_ast_count(0, True * Tensor([True, False]))
    _check_ast_count(0, False * Tensor([True, False]))

  def test_div_literal_one(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) / 1)
  def test_div_tensor_one(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) / Tensor.ones(4))

  def test_idiv_literal_one(self):
    _check_ast_count(0, Tensor([1, 2, 3, 4]) // 1)
  def test_idiv_tensor_one(self):
    _check_ast_count(0, Tensor([1, 2, 3, 4]) // Tensor.ones(4, dtype=dtypes.int32))

  def test_pow_literal_zero(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) ** 0)
  def test_pow_tensor_zero(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) ** Tensor.zeros(4))

  def test_pow_literal_one(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) ** 1)
  def test_pow_tensor_one(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) ** Tensor.ones(4))
  def test_literal_one_pow(self):
    _check_ast_count(0, 1 ** Tensor([1.0, 2, 3, 4]))
  def test_tensor_one_pow(self):
    _check_ast_count(0, Tensor.ones(4) ** Tensor([1.0, 2, 3, 4]))

class TestBitcastConstFolding(unittest.TestCase):
  """Bitcasts of constants should fold to a CONST (or VECTORIZE of CONSTs) at rewrite time."""
  def test_scalar_bitcast(self):
    def t(cases: dict[DType, Any]):
      # every (from, to) pairing of the table must fold to the paired expected value;
      # all values in one table share the same underlying bit pattern
      for (from_dt, from_v), (to_dt, to_v) in itertools.product(cases.items(), cases.items()):
        # nan is skipped as a *source* only (it still appears as an expected target value)
        if not math.isnan(from_v):
          r = full_graph_rewrite(UOp.const(from_dt, from_v).bitcast(to_dt).sink()).src[0]
          self.assertEqual(r.op, Ops.CONST, msg:=f"{from_dt} -> {to_dt} ({from_v} -> {to_v})")
          self.assertEqual(r.dtype, to_dt, msg)
          np.testing.assert_equal(r.arg, to_v, msg)

    # all-zero and lowest-bit-set patterns across int/uint/bool
    t({dtypes.int8: 0, dtypes.uint8: 0, dtypes.bool: False})
    t({dtypes.int8: 1, dtypes.uint8: 1, dtypes.bool: True})

    # all-ones bit pattern: signed -1 <-> unsigned max <-> float nan
    t({dtypes.int8:  -1, dtypes.uint8:  2**8-1})
    t({dtypes.int16: -1, dtypes.uint16: 2**16-1, dtypes.float16: float('nan')})
    t({dtypes.int32: -1, dtypes.uint32: 2**32-1, dtypes.float32: float('nan')})
    t({dtypes.int64: -1, dtypes.uint64: 2**64-1, dtypes.float64: float('nan')})

    # sign bit only: signed min <-> unsigned 2**(bits-1)
    t({dtypes.int8:  -2**7,  dtypes.uint8:  2**7})
    t({dtypes.int16: -2**15, dtypes.uint16: 2**15})
    t({dtypes.int32: -2**31, dtypes.uint32: 2**31})
    t({dtypes.int64: -2**63, dtypes.uint64: 2**63})

    # arbitrary payloads: int/uint value and the float sharing its bit pattern
    t({dtypes.int16: 13496, dtypes.uint16: 13496, dtypes.float16: 0.294921875})
    t({dtypes.int32: 1050081145, dtypes.uint32: 1050081145, dtypes.float32: 0.29485681653022766})
    t({dtypes.int64: 4598983288165178391, dtypes.uint64: 4598983288165178391, dtypes.float64: 0.29485681936461233})

  def test_vec_bitcast(self):
    # vector consts fold to a VECTORIZE of per-lane CONSTs rather than a single CONST
    r = full_graph_rewrite(UOp.const(dtypes.int32.vec(3), (-1, -2**31, 75)).bitcast(dtypes.uint32.vec(3)).sink()).src[0]
    self.assertEqual(r.op, Ops.VECTORIZE)
    self.assertEqual(r.dtype, dtypes.uint32.vec(3))
    self.assertEqual(tuple(x.arg for x in r.src), (2**32-1, 2**31, 75))

# folds advance indexing into basic indexing
class TestIndexingConstFolding(unittest.TestCase):
  def test_scalar_index(self):
    base = Tensor.arange(16).float().reshape(1,1,4,4).realize()
    # TODO: fold these
    for indexed in [base[:,:,Tensor(1),:], base[:,:,Tensor(1)+2,:], base[:,:,Tensor(1),Tensor(0)]]:
      _check_ast_count(2, indexed)

  @unittest.expectedFailure
  def test_const_tensor_index(self):
    # TODO: implement const tensor folded indexing
    base = Tensor.arange(16).float().reshape(1,1,4,4).realize()
    _check_ast_count(0, base[:,:,Tensor.ones(2,1),:])
    _check_ast_count(0, base[:,:,Tensor.ones(1,2)+2,:])
    _check_ast_count(0, base[:,:,Tensor.ones(1,1),Tensor.zeros(2,1,2)])

class TestMovedConstFolding(unittest.TestCase):
  """Const folding through movement ops: shrinking a const tensor keeps it foldable,
  padding it (which introduces non-const zeros) generally does not."""
  def test_add_shrunk_zero(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) + Tensor.zeros(6).shrink(((1, 5),)))

  def test_add_padded_zero(self):
    # TODO: it's 1 now, this might be possible to fold
    _check_ast_count(1, Tensor([1.0, 2, 3, 4]) + Tensor.zeros(2).pad(((1, 1),)))

  def test_mul_shrunk_one(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) * Tensor.ones(6).shrink(((1, 5),)))

  def test_mul_padded_one(self):
    # BUGFIX: renamed from test_add_padded_one — the body tests mul, not add,
    # matching the test_mul_shrunk_one naming pattern above
    _check_ast_count(1, Tensor([1.0, 2, 3, 4]) * Tensor.ones(2).pad(((1, 1),)))

  def test_cast_padded(self):
    # NOTE: this is folded due to CAST_BEFORE_VIEW
    if is_dtype_supported(dtypes.int16):
      _check_ast_count(0, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16))
      np.testing.assert_equal(Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16).numpy(), [0, 1, 1, 1, 1, 0])
    if is_dtype_supported(dtypes.uint16):
      _check_ast_count(0, Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16))
      np.testing.assert_equal(Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16).numpy(), [0, 65535, 65535, 65535, 65535, 0])
    # not folded
    if is_dtype_supported(dtypes.int64):
      _check_ast_count(1, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int64))
      np.testing.assert_equal(Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int64).numpy(), [0, 1, 1, 1, 1, 0])

class TestReduceOpsConstFolding(unittest.TestCase):
  """Reductions over const or zero-size tensors should fold to consts (zero kernels)."""
  def test_const_sum(self):
    # sum over a const tensor is just count * value
    _check_ast_count(0, Tensor.ones(4, 5, 6).sum())
    np.testing.assert_equal(Tensor.ones(4, 5, 6).sum().numpy(), 4 * 5 * 6)
    _check_ast_count(0, Tensor.ones(4, 5, 6).sum(axis=0))
    np.testing.assert_equal(Tensor.ones(4, 5, 6).sum(axis=0).numpy(), np.full((5, 6), 4))
    # scalar sum is identity
    _check_ast_count(0, Tensor(4).sum())
    np.testing.assert_equal(Tensor(4).sum().numpy(), 4)

  def test_padded_const_sum(self):
    # padding breaks const-ness, so one kernel is expected
    _check_ast_count(1, Tensor.ones(4).pad(((1, 1),)).sum())
    np.testing.assert_equal(Tensor.ones(4).pad(((1, 1),)).sum().numpy(), 4)

    # NOTE: cannot just count the non-padded area because some Ops f do not have f(0) = 0.
    _check_ast_count(1, Tensor.ones(4).pad(((1, 1),)).exp().sum())
    # exp(1) four times plus exp(0) = 1 twice from the pad region
    np.testing.assert_allclose(Tensor.ones(4).pad(((1, 1),)).exp().sum().numpy(), 4 * math.e + 2)

  def test_bool_zero_max(self):
    # max over a zero-size bool tensor folds to False
    _check_ast_count(0, Tensor.full((1, 2), True).shrink(((0, 1), (0, 0))).max((1, 0)))
    np.testing.assert_equal(Tensor.full((1, 2), True).shrink(((0, 1), (0, 0))).max((1, 0)).numpy(), False)

  def test_zero_size_ops(self):
    for reduceop in [lambda x:x.prod(), lambda x:x.sum()]: # lambda x:x.max() NOTE: numpy gives "reduction operation maximum which has no identity"
      _check_ast_count(0, reduceop(Tensor.empty(1, 0)))
      # folded result must match numpy's reduction identity (prod -> 1, sum -> 0)
      np.testing.assert_equal(reduceop(Tensor.empty(shape:=(1, 0))).numpy(), reduceop(np.empty(shape)))

  def test_zero_size_ops_view(self):
    for reduceop in [lambda x:x.prod(), lambda x:x.sum()]:
      _check_ast_count(0, reduceop(Tensor.empty(1, 0, 4).permute((1, 2, 0)).contiguous()))
      # NOTE(review): this value check uses a fresh (1, 0) tensor rather than the
      # permuted view asserted above — possibly intentional, verify
      np.testing.assert_equal(reduceop(Tensor.empty(shape:=(1, 0))).numpy(), reduceop(np.empty((shape))))

  def test_zero_size_ops_realized(self):
    # realizing a zero-size tensor does not prevent the reduction fold
    for reduceop in [lambda x:x.prod(), lambda x:x.sum()]:
      _check_ast_count(0, reduceop((Tensor.randn(0, 1)+1).realize()))
      np.testing.assert_equal(reduceop((Tensor.randn(shape:=(0, 1))+1).realize()).numpy(), reduceop(np.empty(shape)))

  def test_zero_size_realize_folded(self):
    # non contiguous folded output doesn't realize
    _check_ast_count(0, Tensor.empty(1, 0).sum())
    # contiguous folded const can still schedule
    a = Tensor.empty(1, 0).sum().contiguous()
    _check_ast_count(2, a+2)
    # after scheduling, the contiguous const is backed by a real buffer
    self.assertIs(a.lazydata.base.op, Ops.BUFFER)
    np.testing.assert_equal((Tensor.empty(1, 0).sum().contiguous()+2).numpy(), 2)
    # otherwise we just fuse it
    _check_ast_count(1, (Tensor.empty(1, 0).sum()+2).contiguous())
    np.testing.assert_equal((Tensor.empty(1, 0).sum()+2).numpy(), 2)

  def test_const_prod(self):
    _check_ast_count(0, Tensor.full((2, 3), fill_value=2).prod())
    np.testing.assert_equal(Tensor.full((2, 3), fill_value=2).prod().numpy(), 2**(2*3))
    _check_ast_count(0, Tensor.full((4, 5, 6), fill_value=2).prod(axis=0))
    np.testing.assert_equal(Tensor.full((4, 5, 6), fill_value=2).prod(axis=0).numpy(), np.full((5, 6), 2**4))
    _check_ast_count(0, Tensor(4).prod())
    np.testing.assert_equal(Tensor(4).prod().numpy(), 4)

  def test_const_max(self):
    _check_ast_count(0, Tensor.ones(4, 5, 6).max())
    np.testing.assert_equal(Tensor.ones(4, 5, 6).max().numpy(), 1)
    _check_ast_count(0, Tensor(4).max())
    np.testing.assert_equal(Tensor(4).max().numpy(), 4)

  def test_sum_output_dtype(self):
    # sum output dtype can be different from input
    for dt in dtypes.fields().values():
      if is_dtype_supported(dt):
        t = Tensor.ones(16, dtype=dt).reshape(4, 4)
        # folded sum must agree with the unfolded (contiguous) sum's dtype
        assert t.sum().dtype == t.contiguous().sum().dtype

@unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU CI")
class TestMultiConstFolding(unittest.TestCase):
  def test_multi_const_folding_literal(self):
    devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(4))
    t = Tensor.arange(16).float().realize().to(devices)

    # non const folding case creates one ast on each shard
    for unfolded in [t + 1, 1 + t, t * 2, 2 * t]:
      _check_ast_count(4, unfolded)

    # const folded
    for folded in [t + 0, 0 + t, t * 0, 0 * t, t * 1, 1 * t]:
      _check_ast_count(0, folded)
    np.testing.assert_equal((t + 0).numpy(), np.arange(16))
    np.testing.assert_equal((t * 0).numpy(), [0] * 16)
    np.testing.assert_equal((t * 1).numpy(), np.arange(16))

    for folded in [t ** 0, t ** 1, 1 ** t]:
      _check_ast_count(0, folded)

  # failing because multi calls .contiguous() on every single sharded uop
  @unittest.expectedFailure
  def test_multi_const_folding_tensor(self):
    devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(4))
    t = Tensor.arange(16).float().realize().to(devices)
    zero = Tensor.zeros(16).realize().to(devices)
    one = Tensor.ones(16).realize().to(devices)

    # const folded
    for folded in [t + zero, zero + t, t * zero, zero * t, t * one, one * t]:
      _check_ast_count(0, folded)
    np.testing.assert_equal((t + zero).numpy(), np.arange(16))
    np.testing.assert_equal((t * zero).numpy(), [0] * 16)
    np.testing.assert_equal((t * one).numpy(), np.arange(16))

  @unittest.expectedFailure
  def test_multi_todo_pow(self):
    devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(4))
    t = Tensor.arange(16).float().realize().to(devices)
    zero = Tensor.zeros(16).realize().to(devices)
    one = Tensor.ones(16).realize().to(devices)

    # TODO: fix pow folding
    for folded in [t ** zero, t ** one, one ** t]:
      _check_ast_count(0, folded)

class TestTautologicalCompare(unittest.TestCase):
  # without const folding, these would have triggered -Wtautological-compare in clang
  def test_lt_false(self):
    # bool < False is always false
    np.testing.assert_equal((Tensor([True, False]) < False).numpy(), [False, False])

  def test_true_lt(self):
    # True < bool is always false
    np.testing.assert_equal((True < Tensor([True, False])).numpy(), [False, False])

  def test_truth_table(self):
    # bool < bool: only False < True is true
    for lhs, rhs, expected in [(False, False, False), (False, True, True),
                               (True, False, False), (True, True, False)]:
      np.testing.assert_equal((Tensor(lhs) < Tensor(rhs)).numpy(), expected)

  def test_a_eq_a(self):
    # self eq is always true for int or bool
    ints = Tensor([1, 2, 3])
    np.testing.assert_equal((ints == ints).numpy(), [True, True, True])

    # not true for nan
    floats = Tensor([math.nan, 1.0, 2.0])
    np.testing.assert_equal((floats == floats).numpy(), [False, True, True])

  def test_a_ne_a(self):
    # self not eq is always false for int or bool
    ints = Tensor([1, 2, 3])
    np.testing.assert_equal((ints != ints).numpy(), [False, False, False])

    # not true for nan
    floats = Tensor([math.nan, 1.0, 2.0])
    np.testing.assert_equal((floats != floats).numpy(), [True, False, False])

# allow running this test file directly
if __name__ == '__main__':
  unittest.main()