branch: master
benchmark_onnx.py
1225 bytes · Raw
import sys, onnx, time
from tinygrad import TinyJit, Device, GlobalCounters, fetch, getenv
from tinygrad.frontend.onnx import OnnxRunner
from extra.onnx_helpers import get_example_inputs, validate

def load_onnx_model(onnx_file):
  """Load an ONNX model and wrap its execution in a TinyJit-compiled callable.

  Returns a pair (jitted_fn, graph_inputs):
    - jitted_fn takes the model inputs as keyword arguments, moves each one to
      Device.DEFAULT, runs the OnnxRunner, and returns only the first value of
      the runner's output mapping.
    - graph_inputs is the runner's `graph_inputs` attribute, unchanged.
  """
  runner = OnnxRunner(onnx.load(onnx_file))

  def _first_output(**inputs):
    # inputs are placed on the default device before the graph runs
    on_device = {name: val.to(Device.DEFAULT) for name, val in inputs.items()}
    results = runner(on_device)
    # only the first output value is handed back to the JIT
    return next(iter(results.values()))

  return TinyJit(_first_output, prune=True), runner.graph_inputs

if __name__ == "__main__":
  # Benchmark an ONNX model with tinygrad: warm up the JIT, time repeated runs,
  # and optionally (ORT=1) check results with extra.onnx_helpers.validate.
  # CNT=<n> overrides the number of timed runs (default 20).
  if len(sys.argv) < 2:
    sys.exit(f"usage: {sys.argv[0]} <onnx model path or url>")
  onnx_file = fetch(sys.argv[1])
  run_onnx_jit, input_specs = load_onnx_model(onnx_file)
  print("loaded model")

  # Untimed warm-up calls so TinyJit can capture before timing starts
  # (presumably it needs >= 2 calls to capture -- TODO confirm). Forcing each
  # result with .numpy() keeps pending warm-up work out of the first timed run.
  for i in range(3):
    new_inputs = get_example_inputs(input_specs)
    GlobalCounters.reset()
    print(f"run {i}")
    run_onnx_jit(**new_inputs).numpy()

  # timed runs
  for _ in range(getenv("CNT", 20)):
    new_inputs = get_example_inputs(input_specs)
    GlobalCounters.reset()
    st = time.perf_counter()
    out = run_onnx_jit(**new_inputs)
    mt = time.perf_counter()  # the jitted call has returned ("enqueue" time)
    out.numpy()               # pull result to host; presumably blocks until the run completes
    et = time.perf_counter()
    print(f"enqueue {(mt-st)*1e3:6.2f} ms -- total run {(et-st)*1e3:6.2f} ms")

  if getenv("ORT"):
    # validates using the inputs of the most recent run
    validate(onnx_file, new_inputs, rtol=1e-3, atol=1e-3)
    print("model validated")