# branch: master
# test_uop_vmin_vmax.py
# 9659 bytes
import unittest, math
from tinygrad.ops import UOp, Ops
from tinygrad.dtype import dtypes

class TestVminVmaxProperties(unittest.TestCase):
  def test_vmin_vmax_constant(self):
    # vmin and vmax for a constant
    uop = UOp.const(dtypes.int32, 42)
    self.assertEqual(uop.vmin, 42)
    self.assertEqual(uop.vmax, 42)

  def test_vmin_vmax_cmpne(self):
    uop = UOp.const(dtypes.int32, 42)
    def test_bool(u, x):
      self.assertEqual(u.vmin, x)
      self.assertEqual(u.vmax, x)
    test_bool(uop != 42, False)
    test_bool(uop != 43, True)
    test_bool(uop != 41, True)

  def test_vmin_vmax_addition_with_variable(self):
    # vmin and vmax for addition with a variable
    x = UOp.variable('x', 10, 20)
    uop = x + 5
    self.assertEqual(uop.vmin, 15)
    self.assertEqual(uop.vmax, 25)

  def test_vmin_vmax_subtraction_with_variable(self):
    x = UOp.variable('x', 10, 20)
    uop = x - 5
    self.assertEqual(uop.vmin, 5)
    self.assertEqual(uop.vmax, 15)
    uop = 5 - x
    self.assertEqual(uop.vmin, -15)
    self.assertEqual(uop.vmax, -5)

  def test_vmin_vmax_and_with_variable(self):
    x = UOp.variable('x', 10, 20)
    uop = x & 5
    self.assertEqual(uop.vmin, 0)
    self.assertEqual(uop.vmax, 5)

    # this can be improved
    uop = x & 15
    self.assertEqual(uop.vmin, 0)
    self.assertEqual(uop.vmax, 15)

    # this can be improved
    uop = x & 32
    self.assertEqual(uop.vmin, 0)
    self.assertEqual(uop.vmax, 20)

  def test_vmin_vmax_multiplication_with_variable(self):
    # vmin and vmax for multiplication with a variable
    x = UOp.variable('x', -3, 4)
    uop = x * 2
    self.assertEqual(uop.vmin, -6)
    self.assertEqual(uop.vmax, 8)

  def test_vmin_vmax_variable_inside_special(self):
    uop = UOp(Ops.SPECIAL, dtypes.int, arg=('gidx0', UOp(Ops.DEFINE_VAR, dtypes.int, arg=('i', 1, 10))))
    self.assertEqual(uop.vmin, 0)
    self.assertEqual(uop.vmax, 10)

  def test_vmin_vmax_multiplication_0_inf(self):
    # vmin and vmax for multiplication with a variable
    x = UOp.const(dtypes.float, 0.0)
    y = UOp.load(UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0), UOp.const(dtypes.int, 0), dtype=dtypes.float)
    uop = x * y
    # TODO: these should be 0, but definitely should not be nan
    self.assertEqual(uop.vmin, -math.inf)
    self.assertEqual(uop.vmax, math.inf)

  def test_vmin_vmax_with_negative_multiplication(self):
    # vmin and vmax when multiplying by a negative number
    x = UOp.variable('x', 2, 5)
    uop = x * -3
    self.assertEqual(uop.vmin, -15)
    self.assertEqual(uop.vmax, -6)

  def test_vmin_vmax_with_negative_multiplication2(self):
    # vmin and vmax when multiplying by a negative number
    x = UOp.variable('x', -2, 5)
    uop = x * -3
    self.assertEqual(uop.vmin, -15)
    self.assertEqual(uop.vmax, 6)

  def test_vmin_vmax_nested_min_max(self):
    # vmin and vmax with nested min/max operations
    x = UOp.variable('x', 0, 10)
    uop = x.maximum(5).minimum(8)
    self.assertEqual(uop.vmin, 5)
    self.assertEqual(uop.vmax, 8)

  def test_vmin_vmax_where(self):
    x = UOp.variable('x', 0, 10)
    y = UOp.variable('y', 1, 11)
    z = UOp.variable('z', 2, 12)
    uop = (x<5).where(y, z)
    self.assertEqual(uop.vmin, 1)
    self.assertEqual(uop.vmax, 12)

  def test_vmin_vmax_shl(self):
    x = UOp.variable('x', 0, 10) << 5
    self.assertEqual(x.vmin, 0)
    self.assertEqual(x.vmax, 10 << 5)

  def test_vmin_vmax_shr(self):
    x = UOp.variable('x', 0, 10) >> 2
    self.assertEqual(x.vmin, 0)
    self.assertEqual(x.vmax, 10 >> 2)

class TestVminVmaxDivMod(unittest.TestCase):
  def test_vmin_vmax_division_positive(self):
    # vmin and vmax for division of a variable by a positive constant
    x = UOp.variable('x', 10, 20)
    uop = x // 2
    self.assertEqual(uop.vmin, 5)
    self.assertEqual(uop.vmax, 10)

  def test_vmin_vmax_division_negative(self):
    # vmin and vmax for division of a variable by a negative constant
    x = UOp.variable('x', 10, 20)
    uop = x // -2
    self.assertEqual(uop.vmin, -10)
    self.assertEqual(uop.vmax, -5)

  def test_vmin_vmax_mod_positive(self):
    # vmin and vmax for modulo of a variable by a positive constant
    x = UOp.variable('x', 10, 20)
    uop = x % 3
    self.assertEqual(uop.vmin, 0)
    self.assertEqual(uop.vmax, 2)

  @unittest.skip("broken")
  def test_vmin_vmax_mod_negative(self):
    # vmin and vmax for modulo of a variable by a negative constant
    x = UOp.variable('x', 10, 20)
    uop = x % -3
    self.assertEqual(uop.vmin, -2)
    self.assertEqual(uop.vmax, 0)

  def test_vmin_vmax_division_with_mixed_range(self):
    # vmin and vmax for division of a variable with a range crossing zero
    x = UOp.variable('x', -10, 10)
    uop = x // 3
    self.assertEqual(uop.vmin, -4)  # -10//3 = -4
    self.assertEqual(uop.vmax, 3)   # 10//3 = 3

  def test_vmin_vmax_mod_with_mixed_range(self):
    # vmin and vmax for modulo of a variable with a range crossing zero
    x = UOp.variable('x', -10, 10)
    uop = x % 4
    self.assertEqual(uop.vmin, -3)
    self.assertEqual(uop.vmax, 3)

class TestVminVmaxVConst(unittest.TestCase):
  def test_vmin_vmax_vconst_single_element(self):
    # vmin and vmax for a single-element vector constant
    uop = UOp.const(dtypes.int32.vec(1), (42,))
    self.assertEqual(uop.vmin, 42)
    self.assertEqual(uop.vmax, 42)

  def test_vmin_vmax_vconst_multiple_elements(self):
    # vmin and vmax for a multi-element vector constant
    uop = UOp.const(dtypes.int32.vec(4), (10, 20, -5, 7))
    self.assertEqual(uop.vmin, -5)
    self.assertEqual(uop.vmax, 20)

  def test_vmin_vmax_vconst_all_equal(self):
    # vmin and vmax for a vector where all elements are equal
    uop = UOp.const(dtypes.int32.vec(3), (7, 7, 7))
    self.assertEqual(uop.vmin, 7)
    self.assertEqual(uop.vmax, 7)

  def test_vmin_vmax_vconst_with_negative_values(self):
    # vmin and vmax for a vector constant containing negative values
    uop = UOp.const(dtypes.int32.vec(4), (-10, -20, -5, -15))
    self.assertEqual(uop.vmin, -20)
    self.assertEqual(uop.vmax, -5)

  def test_vmin_vmax_vconst_with_floats(self):
    # vmin and vmax for a vector constant of float values
    uop = UOp.const(dtypes.float32.vec(3), (1.5, -3.2, 0.0))
    self.assertEqual(uop.vmin, -3.2)
    self.assertEqual(uop.vmax, 1.5)

  def test_vmin_vmax_vconst_with_bools(self):
    # vmin and vmax for a vector constant of bool values
    uop = UOp.const(dtypes.float32.vec(3), (True, False, False))
    # TODO: these return floats, not bool
    self.assertEqual(uop.vmin, 0.0)
    self.assertEqual(uop.vmax, 1.0)

class TestConstFactor(unittest.TestCase):
  def test_const_factor_constant(self):
    # const_factor for a constant
    uop = UOp.const(dtypes.int32, 42)
    self.assertEqual(uop.const_factor(), 42)

  def test_const_factor_addition(self):
    # const_factor for an addition of constants
    uop = UOp.const(dtypes.int32, 30) + UOp.const(dtypes.int32, 12)
    self.assertEqual(uop.const_factor(), 6)  # GCD(30, 12) = 6

  def test_const_factor_multiplication(self):
    # const_factor for a multiplication of constants
    uop = UOp.const(dtypes.int32, 5) * UOp.const(dtypes.int32, 7)
    self.assertEqual(uop.const_factor(), 5)  # For multiplication, it's one of the factors

  def test_const_factor_with_variable(self):
    # const_factor for an expression involving a variable
    x = UOp.variable('x', 10, 20)
    uop = x * 3
    self.assertEqual(uop.const_factor(), 3)

  def test_const_factor_division(self):
    # const_factor for an expression with division
    x = UOp.variable('x', 10, 20)
    uop = x // 4
    self.assertEqual(uop.const_factor(), 1)  # Division reduces the const_factor to 1

  def test_const_factor_multiplication_of_var_and_const(self):
    # const_factor for multiplication of a variable and a constant
    x = UOp.variable('x', 6, 18)
    uop = x * 4
    self.assertEqual(uop.const_factor(), 4)  # Constant factor 4

  @unittest.skip("broken")
  def test_const_factor_multiplication_of_consts_and_vars(self):
    # Multiplying constants and variables
    x = UOp.variable('x', 10, 20)
    uop = (x * 3) * 5
    self.assertEqual(uop.const_factor(), 15)  # Constant multipliers are combined (3 * 5 = 15)

class TestDivides(unittest.TestCase):
  def test_divides_constant_exact(self):
    # Divides a constant by an exact divisor
    uop = UOp.const(dtypes.int32, 42)
    result = uop.divides(7)
    self.assertIsNotNone(result)
    self.assertEqual(result.const_factor(), 6)  # 42 / 7 = 6

  def test_divides_constant_inexact(self):
    # Try to divide a constant by a non-exact divisor
    uop = UOp.const(dtypes.int32, 42)
    result = uop.divides(5)
    self.assertIsNone(result)  # 42 is not divisible by 5

  @unittest.skip("broken")
  def test_divides_variable_and_constant(self):
    # Multiplying a variable by a constant, then dividing by the same constant
    x = UOp.variable('x', 10, 20)
    uop = x * 6
    result = uop.divides(6)
    self.assertIsNotNone(result)
    self.assertEqual(result, x)  # (x * 6) / 6 = x

  def test_divides_complex_expression(self):
    # Dividing a more complex expression
    x = UOp.variable('x', 10, 20)
    uop = (x * 6) + 18
    result = uop.divides(6)
    self.assertIsNotNone(result)
    self.assertEqual(result.const_factor(), 1)  # (x + 3), const_factor is 1

  def test_divides_with_inexact_factors(self):
    # Multiplying by a constant but dividing by a non-exact divisor
    x = UOp.variable('x', 15, 45)
    uop = x * 4
    result = uop.divides(3)
    self.assertIsNone(result)  # Cannot divide by 3, since 4 is not divisible by 3

# run every TestCase in this module when executed directly
if __name__ == '__main__': unittest.main()