branch: master
test_attention.py
774 bytesRaw
import unittest
from tinygrad import Tensor, dtypes

# TODO: test_scheduler, but just in unit
class TestAttention(unittest.TestCase):
  """Schedule-level checks for Tensor.scaled_dot_product_attention."""

  def test_half_qkv_buffers(self):
    """Half-precision q/k/v must keep the middle attention kernels (sched[1:4]) in half.

    Builds q, k, v as realized half tensors, schedules attention on them, and checks:
    - the schedule has exactly 5 kernels;
    - every buffer touched by kernels 1..3 (named `softmax_inputs` below) has dtype half.
    """
    BS, seqlen, dim = 10, 4, 100
    # .contiguous().realize() turns each ones-tensor into its own realized buffer before
    # attention is scheduled (presumably so q/k/v appear as real inputs, not folded constants)
    q, k, v = (Tensor.ones(BS, seqlen, dim, dtype=dtypes.half).contiguous().realize() for _ in range(3))
    attn = q.scaled_dot_product_attention(k, v)
    sched = attn.schedule()
    # attention has 5 kernels now
    self.assertEqual(len(sched), 5)
    softmax_inputs = sched[1:4]
    for i, si in enumerate(softmax_inputs, start=1):
      # unittest assertions instead of a bare `assert` (which `python -O` strips),
      # and subTest so a failure names the offending kernel index
      with self.subTest(kernel=i):
        for b in si.bufs:
          self.assertEqual(b.dtype, dtypes.half, f"non half {si.bufs=}")

if __name__ == '__main__':
  # run this module's TestCase classes when executed directly as a script
  unittest.main()