branch: master
run_models.py
6468 bytesRaw
import onnx, yaml, tempfile, time, collections, pprint, argparse, json
from pathlib import Path
from tinygrad.frontend.onnx import OnnxRunner
from extra.onnx import get_onnx_ops
from extra.onnx_helpers import validate, get_example_inputs

def get_config(root_path: Path) -> dict:
  """Merge every ``*config.json`` file found under ``root_path`` into one flat dict.

  The tree is searched recursively. Files whose top-level JSON value is not an
  object (e.g. a list) are skipped. When the same key appears in several files,
  the file that sorts last by path wins.

  Args:
    root_path: directory to search.

  Returns:
    The merged configuration; empty if no matching file exists.

  Raises:
    json.JSONDecodeError: if a matching file does not contain valid JSON.
  """
  ret = {}
  # sorted: rglob order depends on the filesystem, so key collisions would otherwise resolve arbitrarily
  for path in sorted(root_path.rglob("*config.json")):
    # context manager closes the handle promptly; JSON text is UTF-8 regardless of the platform locale
    with path.open(encoding="utf-8") as f:
      config = json.load(f)
    if isinstance(config, dict):
      ret.update(config)
  return ret

def run_huggingface_validate(onnx_model_path, config, rtol, atol):
  """Validate one ONNX model file using example inputs derived from ``config``.

  The model is loaded and wrapped in an ``OnnxRunner`` only to read its graph
  inputs; ``get_example_inputs`` turns those plus ``config`` into inputs, which
  are handed to ``validate`` together with the model path and the tolerances.
  """
  runner = OnnxRunner(onnx.load(onnx_model_path))
  example_inputs = get_example_inputs(runner.graph_inputs, config)
  validate(onnx_model_path, example_inputs, rtol=rtol, atol=atol)

def get_tolerances(file_name): # -> rtol, atol
  """Pick ``(rtol, atol)`` from substrings of the model file name.

  "fp16" is checked first, so a name containing both "fp16" and a quantization
  marker gets the fp16 tolerances.
  """
  # TODO very high rtol atol
  quantized_markers = ("int8", "uint8", "quantized")
  if "fp16" in file_name:
    tolerances = (9e-2, 9e-2)
  elif any(marker in file_name for marker in quantized_markers):
    tolerances = (4, 4)
  else:
    tolerances = (4e-3, 3e-2)
  return tolerances

def validate_repos(models:dict[str, tuple[Path, Path]]):
  """Validate every model in ``models``, printing progress and per-model timing.

  Args:
    models: maps a model id to ``(root_path, relative_path)``; the ONNX file is
      ``root_path / relative_path`` and configs are collected from ``root_path``.

  Any exception raised while validating a model propagates and stops the run.
  """
  # use the parameter, not the script-level `model_paths` global, so the function also works when imported
  print(f"** Validating {len(models)} models **")
  for model_id, (root_path, relative_path) in models.items():
    print(f"validating model {model_id}")
    model_path = root_path / relative_path
    onnx_file_name = model_path.stem
    config = get_config(root_path)
    rtol, atol = get_tolerances(onnx_file_name)
    # perf_counter is monotonic, so the elapsed time cannot be skewed by wall-clock adjustments
    st = time.perf_counter()
    run_huggingface_validate(model_path, config, rtol, atol)
    et = time.perf_counter() - st
    print(f"passed, took {et:.2f}s")

def retrieve_op_stats(models:dict[str, tuple[Path, Path]]) -> dict:
  """Count ONNX ops across ``models`` and report which ones are unsupported.

  Args:
    models: maps a model id to ``(root_path, relative_path)``; the ONNX file is
      ``root_path / relative_path``.

  Returns:
    A dict with two keys:
      "unsupported_ops": op name -> sorted list of model ids using that op, for
        every op not present in ``get_onnx_ops()``.
      "op_counter": list of ``(op, count)`` pairs, most common first.
  """
  op_counter = collections.Counter()
  unsupported_ops = collections.defaultdict(set)
  supported_ops = get_onnx_ops()
  # use the parameter, not the script-level `model_paths` global, so the function also works when imported
  print(f"** Retrieving stats from {len(models)} models **")
  for model_id, (root_path, relative_path) in models.items():
    print(f"examining {model_id}")
    model_path = root_path / relative_path
    onnx_runner = OnnxRunner(onnx.load(model_path))
    for node in onnx_runner.graph_nodes:
      op_counter[node.op] += 1
      if node.op not in supported_ops:
        unsupported_ops[node.op].add(model_id)
    # drop the runner before the next model is loaded (presumably to keep peak memory down)
    del onnx_runner
  return {
    # sorted: set iteration order is not stable between runs, which made the report hard to diff
    "unsupported_ops": {op: sorted(model_ids) for op, model_ids in unsupported_ops.items()},
    "op_counter": op_counter.most_common(),
  }

def debug_run(model_path, truncate, config, rtol, atol):
  """Validate ``model_path``, optionally cut off after node index ``truncate``.

  Args:
    model_path: ``Path`` to the ONNX file.
    truncate: ``-1`` validates the model unchanged. Otherwise only graph nodes
      ``[0, truncate]`` are kept and the outputs of the last kept node become
      the graph outputs, so an intermediate result can be validated.
    config: merged model configuration used to build example inputs.
    rtol, atol: tolerances forwarded to the validation.

  Raises:
    ValueError: if truncation leaves no nodes in the graph.
  """
  if truncate == -1:
    run_huggingface_validate(model_path, config, rtol, atol)
    return

  model = onnx.load(model_path)
  nodes_up_to_limit = list(model.graph.node)[:truncate + 1]
  if not nodes_up_to_limit:
    raise ValueError(f"truncate={truncate} leaves no nodes in {model_path}")
  new_output_values = [onnx.helper.make_empty_tensor_value_info(output_name) for output_name in nodes_up_to_limit[-1].output]
  model.graph.ClearField("node")
  model.graph.node.extend(nodes_up_to_limit)
  model.graph.ClearField("output")
  model.graph.output.extend(new_output_values)
  # write into a temporary directory rather than a NamedTemporaryFile: reopening an
  # already-open named temporary file by name is not possible on Windows
  with tempfile.TemporaryDirectory() as tmp_dir:
    tmp_model_path = str(Path(tmp_dir) / f"truncated{model_path.suffix}")
    onnx.save(model, tmp_model_path)
    run_huggingface_validate(tmp_model_path, config, rtol, atol)

if __name__ == "__main__":
  # Command-line entry point. Three modes, which can be combined:
  #   --check_ops : print op usage / unsupported ops for the models listed in the YAML
  #   --validate  : validate the models listed in the YAML
  #   --debug     : download a repo (or one model of it) from the hub and validate it, no YAML needed
  parser = argparse.ArgumentParser(description="Huggingface ONNX Model Validator and Ops Checker")
  # optional positional: --debug works without a YAML, as its help text states
  parser.add_argument("input", type=str, nargs="?", default=None,
                      help="Path to the input YAML configuration file containing model information.")
  parser.add_argument("--check_ops", action="store_true", default=False,
                      help="Check support for ONNX operations in models from the YAML file")
  parser.add_argument("--validate", action="store_true", default=False,
                      help="Validate correctness of models from the YAML file")
  parser.add_argument("--debug", type=str, default="",
                      help="""Validates without explicitly needing a YAML or models pre-installed.
                      provide repo id (e.g. "minishlab/potion-base-8M") to validate all onnx models inside the repo
                      provide onnx model path (e.g. "minishlab/potion-base-8M/onnx/model.onnx") to validate only that one model
                      """)
  parser.add_argument("--truncate", type=int, default=-1, help="Truncate the ONNX model so intermediate results can be validated")
  args = parser.parse_args()

  if not (args.check_ops or args.validate or args.debug):
    parser.error("Please provide either --validate, --check_ops, or --debug.")
  if args.truncate != -1 and not args.debug:
    parser.error("--truncate and --debug should be used together for debugging")
  if (args.check_ops or args.validate) and args.input is None:
    parser.error("--validate and --check_ops require the input YAML file")

  if args.check_ops or args.validate:
    with open(args.input, 'r', encoding="utf-8") as f:
      data = yaml.safe_load(f)
    # explicit check instead of `assert`, which is stripped when running with -O
    if any(repo.get("download_path") is None for repo in data["repositories"].values()):
      parser.error("please run `download_models.py` for this yaml")
    # key: "<repo id>/<file>", value: (download root of the repo, file path relative to that root)
    model_paths = {
      model_id + "/" + model["file"]: (Path(repo["download_path"]), Path(model["file"]))
      for model_id, repo in data["repositories"].items()
      for model in repo["files"]
      if model["file"].endswith(".onnx")
    }

    if args.check_ops:
      pprint.pprint(retrieve_op_stats(model_paths))

    if args.validate:
      validate_repos(model_paths)

  if args.debug:
    # imported lazily so the YAML-only modes do not need huggingface_hub
    from huggingface_hub import snapshot_download
    download_dir = Path(__file__).parent / "models"
    path:list[str] = args.debug.split("/")
    if len(path) < 2:
      parser.error('--debug expects a repo id ("owner/name") or an onnx model path ("owner/name/path/to/model.onnx")')
    if len(path) == 2:
      # repo id
      # validates all onnx models inside repo
      repo_id = "/".join(path)
      # "*.onnx_data" (with the wildcard) so external-data files next to the models are fetched too
      root_path = Path(snapshot_download(repo_id=repo_id, allow_patterns=["*.onnx", "*.onnx_data"], cache_dir=download_dir))
      snapshot_download(repo_id=repo_id, allow_patterns=["*config.json"], cache_dir=download_dir)
      config = get_config(root_path)
      for onnx_model in root_path.rglob("*.onnx"):
        rtol, atol = get_tolerances(onnx_model.name)
        print(f"validating {onnx_model.relative_to(root_path)} with truncate={args.truncate}, {rtol=}, {atol=}")
        # NOTE(review): the message above reports args.truncate, but whole-repo runs are never truncated (-1) - confirm intent
        debug_run(onnx_model, -1, config, rtol, atol)
    else:
      # model id
      # only validate the specified onnx model
      onnx_model = path[-1]
      if not onnx_model.endswith(".onnx"):
        parser.error(f"--debug model path must end with .onnx, got {args.debug!r}")
      repo_id, relative_path = "/".join(path[:2]), "/".join(path[2:])
      # also request "<model>.onnx_data" in case the weights live in an external-data sidecar; no-op if it does not exist
      root_path = Path(snapshot_download(repo_id=repo_id, allow_patterns=[relative_path, relative_path + "_data"], cache_dir=download_dir))
      snapshot_download(repo_id=repo_id, allow_patterns=["*config.json"], cache_dir=download_dir)
      config = get_config(root_path)
      rtol, atol = get_tolerances(onnx_model)
      print(f"validating {relative_path} with truncate={args.truncate}, {rtol=}, {atol=}")
      debug_run(root_path / relative_path, args.truncate, config, rtol, atol)