# branch: master
# external_test_llama3_ff.py
# 346 bytes (raw)
#!/usr/bin/env python3
from tinygrad import Tensor, TinyJit, nn
from extra.models.llama import FeedForward

if __name__ == "__main__":
  # Driver: build a FeedForward(4096, 14336) block, cast its weights to half,
  # wrap it in TinyJit and call it five times on random inputs.
  ff = FeedForward(4096, 14336)

  # Swap each parameter for its .half() version in place and realize it right away.
  for param in nn.state.get_parameters(ff):
    half_param = param.half()
    param.replace(half_param).realize()

  jitted_ff = TinyJit(ff)

  # NOTE(review): repeated calls presumably take the JIT past its initial
  # capture runs so later iterations exercise the replayed path -- confirm
  # against TinyJit's documentation.
  for i in range(5):
    print(f"*** run {i}")
    # Fresh random input of shape (1, 4096) on every iteration.
    sample = Tensor.rand(1, 4096)
    jitted_ff(sample)