# branch: master
# fuzz_linearizer.py
# 14777 bytesRaw
import random, traceback, ctypes, argparse, os
from typing import Any
import numpy as np
from collections import defaultdict
from extra.optimization.helpers import load_worlds, ast_str_to_lin, kern_str_to_lin

# We need to insert ioctl before opening devices, so this runs ahead of the main tinygrad imports below.
# os.getenv returns a *string* when the variable is set, so comparing it with the int 0 was true even for
# VALIDATE_HCQ=0; parse the value as an integer instead.
# NOTE(review): presumably the same interpretation getenv("VALIDATE_HCQ") applies further down — confirm.
if int(os.getenv("VALIDATE_HCQ", "0")) != 0:
  # Each backend pair is best effort: a missing driver or device must not stop the fuzzer from starting.
  try:
    import extra.nv_gpu_driver.nv_ioctl
    from tinygrad import Device
    _, _ = Device["NV"], Device["CUDA"]
  except Exception: pass

  try:
    import extra.qcom_gpu_driver.opencl_ioctl
    from tinygrad import Device
    _, _ = Device["QCOM"], Device["GPU"]
  except Exception: pass

from tinygrad import Tensor, Device, dtypes
from tinygrad.tensor import _to_np_dtype
from tinygrad.codegen.kernel import Kernel
from tinygrad.codegen.kernel import Opt, OptOps
from tinygrad.engine.search import get_kernel_actions, bufs_from_lin
from tinygrad.engine.realize import CompiledRunner
from tinygrad.helpers import getenv, from_mv, prod, colored, Context, DEBUG, Timing
from tinygrad.ops import UOp, Ops
from tinygrad.device import is_dtype_supported

def on_linearizer_will_run():
  """Hook called right before a kernel is executed; the default does nothing."""
  return None

def on_linearizer_did_run():
  """Hook called after a kernel has executed; the default captures no state and returns None."""
  return None

def compare_states(x, y):
  """Compare two captured run states; the default accepts any pair with an empty message."""
  ok, message = True, ""
  return (ok, message)

# Device used to cross-check the default backend. It stays None unless VALIDATE_HCQ selects one below, so
# the `validate_device is not None` test in fuzz_linearizer always has a bound name to look at.
validate_device = None

if getenv("VALIDATE_HCQ"):
  # Swap the no-op hooks above for the ioctl-based ones of the matching driver.
  if Device.DEFAULT == "NV":
    print("VALIDATE_HCQ: Comparing NV to CUDA")
    import extra.nv_gpu_driver.nv_ioctl
    validate_device = Device["CUDA"]
    on_linearizer_will_run = extra.nv_gpu_driver.nv_ioctl.before_launch
    on_linearizer_did_run = extra.nv_gpu_driver.nv_ioctl.collect_last_launch_state
    compare_states = extra.nv_gpu_driver.nv_ioctl.compare_launch_state
  elif Device.DEFAULT == "QCOM":
    print("VALIDATE_HCQ: Comparing QCOM to GPU")
    import extra.qcom_gpu_driver.opencl_ioctl
    validate_device = Device["GPU"]
    on_linearizer_will_run = extra.qcom_gpu_driver.opencl_ioctl.before_launch
    on_linearizer_did_run = extra.qcom_gpu_driver.opencl_ioctl.collect_last_launch_state
    compare_states = extra.qcom_gpu_driver.opencl_ioctl.compare_launch_state
  else:
    # No comparison backend for this default device: the no-op hooks stay in place.
    print(colored("VALIDATE_HCQ options is ignored", 'red'))

def tuplize_uops(uops:list[UOp]) -> tuple:
  """Turn a uop list into a tuple of (op, dtype, source positions, arg) records.

  Every source of a uop is replaced by its position in `uops` (found with list.index), so the
  result describes the list by structure; fuzz_linearizer uses it as a dictionary key.
  """
  records = []
  for uop in uops:
    src_positions = tuple(uops.index(parent) for parent in uop.src)
    records.append((uop.op, uop.dtype, src_positions, uop.arg))
  return tuple(records)

# Handle for the default backend (Device.DEFAULT), looked up once when the module is imported.
device = Device[Device.DEFAULT]

def get_fuzz_rawbufs(lin):
  """Allocate the buffers for `lin` and fill every input buffer with random numpy data.

  Buffer 0 is treated as the output: it is replaced by a zero-filled buffer that is RED_AREA_SIZE
  elements larger than requested, so out-of-bounds writes can be detected. Returns the buffer list.
  """
  # Extra elements appended to the output buffer to detect out-of-bounds writes.
  RED_AREA_SIZE = 1024

  def _random_input(dtype, count):
    # The value range depends on the dtype family; the checks run in the same order as before.
    if dtypes.is_unsigned(dtype): return np.random.randint(0, 100, size=count, dtype=_to_np_dtype(dtype))
    if dtypes.is_int(dtype): return np.random.randint(-100, 100, size=count, dtype=_to_np_dtype(dtype))
    if dtype == dtypes.bool: return np.random.choice([True, False], size=count)
    # half gets a narrower range than every other remaining dtype
    bound = 1 if dtype == dtypes.half else 10
    return np.random.uniform(-bound, bound, size=count).astype(dtype=_to_np_dtype(dtype))

  buffers = bufs_from_lin(lin)

  # setting output  # TODO: multi-output kernel
  output = buffers[0]
  buffers[0] = get_fuzz_rawbuf_like(output, zero=True, size=output.size+RED_AREA_SIZE)

  # setting inputs
  with Context(DEBUG=0):
    for inp in buffers[1:]:
      host_data = _random_input(inp.dtype, inp.size)
      staged = Tensor(host_data, device=lin.opts.device).realize()
      inp.copyin(staged.lazydata.base.realized.as_buffer())
  return buffers

def get_fuzz_rawbuf_like(old_rawbuf, zero=False, copy=False, size=None, force_device=None):
  """Allocate a new buffer of the same class and dtype as `old_rawbuf`.

  Args:
    old_rawbuf: buffer to imitate; its class is instantiated as (device, size, dtype) and allocated.
    zero: fill the new buffer with zero bytes (ignored when `copy` is set).
    copy: copy the contents of `old_rawbuf` into the new buffer; takes precedence over `zero`.
    size: element count of the new buffer; defaults to `old_rawbuf.size`.
    force_device: device for the new buffer; defaults to `old_rawbuf.device`.

  Returns:
    The newly allocated buffer. With neither `copy` nor `zero` nothing is written to it.
  """
  new_size = old_rawbuf.size if size is None else size
  rawbuf = type(old_rawbuf)(force_device or old_rawbuf.device, new_size, old_rawbuf.dtype).allocate()
  if copy:
    with Context(DEBUG=0): rawbuf.copyin(old_rawbuf.as_buffer())
  elif zero:
    # bytearray(n) is already initialised with null bytes, so no explicit memset is required.
    with Context(DEBUG=0): rawbuf.copyin(memoryview(bytearray(rawbuf.size * rawbuf.dtype.itemsize)))
  return rawbuf

def run_linearizer(lin: Kernel, rawbufs=None, var_vals=None) -> tuple[str, Any]: # (error msg, run state)
  """Compile `lin` and execute it once on `rawbufs`.

  Args:
    lin: kernel to compile and run.
    rawbufs: buffers passed to the program; allocated with bufs_from_lin when None.
    var_vals: variable bindings; when None every variable is bound to its lower bound.

  Returns:
    (status, run_state). status is "COMPILE_ERROR", "EXEC_ERROR" or "PASS"; both error cases print
    the traceback. run_state is the value of on_linearizer_did_run() when VALIDATE_HCQ is set, and
    None otherwise (including on errors).
  """
  if rawbufs is None: rawbufs = bufs_from_lin(lin)
  # Use the same accessors compare_linearizer relies on (ast.variables() / vmin) for the default binding.
  if var_vals is None: var_vals = {v: v.vmin for v in lin.ast.variables()}

  # TODO: images needs required_optimization
  try:
    prg = CompiledRunner(lin.to_program())
  except KeyboardInterrupt: raise
  except Exception:
    traceback.print_exc()
    return "COMPILE_ERROR", None

  if getenv("VALIDATE_HCQ"): on_linearizer_will_run()
  try:
    prg(rawbufs, var_vals, wait=True)
  except KeyboardInterrupt: raise
  except Exception:
    traceback.print_exc()
    return "EXEC_ERROR", None

  # The run state is only collected when validation is enabled.
  run_state = on_linearizer_did_run() if getenv("VALIDATE_HCQ") else None

  return "PASS", run_state

def compare_linearizer(lin: Kernel, rawbufs=None, var_vals=None, ground_truth=None, rtol=1e-2, atol=1e-2):
  """Run `lin` and compare its output buffer with a baseline result.

  Args:
    lin: kernel under test.
    rawbufs: buffers to run on; created with get_fuzz_rawbufs when None. The list is modified in
      place: entry 0 (the output) is replaced by a fresh zero-filled buffer before each run.
    var_vals: variable bindings; drawn at random within [vmin, vmax] of each variable when None.
    ground_truth: expected output as a numpy array; when None (and no bfloat16 buffer is involved)
      it is produced by running a kernel built from the same ast with only required_optimizations().
    rtol, atol: tolerances passed to the numerical comparison.

  Returns:
    (msg, rawbufs, var_vals, ground_truth, run_state). msg is "PASS", "RAWBUFS_ERROR",
    "BASELINE_ERROR", "COMPARE_ERROR" or the non-PASS status reported by run_linearizer.
  """
  # TODO: for bfloat16 it compiles linearizer, but it does not run because numpy cannot generate bf16 buffer.
  has_bf16 = any(b.dtype.base == dtypes.bfloat16 for b in lin.membufs)

  # TODO: raise specific fuzzing errors instead of str, and propagate the error message
  try:
    if rawbufs is None:
      rawbufs = get_fuzz_rawbufs(lin)
    else:
      rawbufs[0] = get_fuzz_rawbuf_like(rawbufs[0], zero=True) # get a new output buffer
  except KeyboardInterrupt: raise
  except BaseException:
    # Print the cause like run_linearizer does, instead of reducing it to a bare status string.
    traceback.print_exc()
    return ("RAWBUFS_ERROR", rawbufs, var_vals, ground_truth, None)

  if var_vals is None:
    # TODO: handle symbolic max case
    var_vals = {v: random.randint(v.vmin, v.vmax) for v in lin.ast.variables()}

  if ground_truth is None and not has_bf16:
    # Baseline: the same ast with only the required optimizations applied.
    unoptimized = Kernel(lin.ast)
    unoptimized.required_optimizations()
    if run_linearizer(unoptimized, rawbufs, var_vals)[0] != "PASS":
      return ("BASELINE_ERROR", rawbufs, var_vals, ground_truth, None)
    # copy() detaches the baseline from the buffer view before the output buffer is replaced below
    ground_truth = np.frombuffer(rawbufs[0].as_buffer(), _to_np_dtype(rawbufs[0].dtype)).copy()

  rawbufs[0] = get_fuzz_rawbuf_like(rawbufs[0], zero=True) # get a new output buffer
  run_msg, run_state = run_linearizer(lin, rawbufs, var_vals)
  if run_msg != "PASS": return (run_msg, rawbufs, var_vals, ground_truth, run_state)

  try:
    if not has_bf16:
      result = np.frombuffer(rawbufs[0].as_buffer(), _to_np_dtype(rawbufs[0].dtype))
      np.testing.assert_allclose(result, ground_truth, rtol=rtol, atol=atol)
  except KeyboardInterrupt: raise
  except AssertionError as e:
    if DEBUG >= 2:
      print(f"COMPARE_ERROR details: {e}")
      if getenv("DEBUG_VALUES") > 0:
        # assert_allclose treats NaN == NaN as equal by default; np.isclose needs equal_nan=True to
        # agree, otherwise matching NaNs would be listed as mismatches.
        mismatch_indices = np.where(~np.isclose(result, ground_truth, rtol=rtol, atol=atol, equal_nan=True))
        mismatched_result = result[mismatch_indices]
        mismatched_ground_truth = ground_truth[mismatch_indices]
        for i, idx in enumerate(mismatch_indices[0]):
          print(f"mismatch at {idx=}: result={mismatched_result[i]} <> ground_truth={mismatched_ground_truth[i]}")
    return ("COMPARE_ERROR", rawbufs, var_vals, ground_truth, run_state)

  return ("PASS", rawbufs, var_vals, ground_truth, run_state)

def fuzz_linearizer(lin: Kernel, rtol=1e-2, atol=1e-2, opts_list=None):
  """Apply optimization actions to `lin` and check every resulting kernel against a baseline.

  Without `opts_list`, actions come from get_kernel_actions and are explored for DEPTH rounds: one
  random action per kernel, or every action when FUZZ_ALL_ACTIONS is set. With `opts_list`, each
  entry is a list of opts applied to a copy of `lin` and a single round is run.

  Args:
    lin: starting kernel.
    rtol, atol: tolerances forwarded to compare_linearizer.
    opts_list: optional list of opt lists to replay instead of searching.

  Returns:
    defaultdict mapping a failure label to a list of (ast, applied_opts) tuples. Entries under
    "HCQ_COMPARE_FAILURE" are (err_msg, ast, applied_opts, state1, state2) instead. The dict is
    empty when the kernel is skipped or nothing failed.
  """
  # Re-seed both generators so a run is reproducible for a given SEED.
  SEED = getenv("SEED", 42)
  random.seed(SEED)
  np.random.seed(SEED)
  print(lin.ast)
  print(lin.colored_shape())
  seen_uops = {}
  last_lins = [lin]
  failures:defaultdict[str, list[tuple[tuple[UOp, ...], list[Opt]]]] = defaultdict(list)
  rawbufs, var_vals, ground_truth, validate_rawbufs = None, None, None, None

  FUZZ_ALL_ACTIONS = getenv("FUZZ_ALL_ACTIONS", 0)
  FUZZ_MAX_SIZE = getenv("FUZZ_MAX_SIZE", 0)
  FUZZ_IGNORE_SIMPLE_OPS = getenv("FUZZ_IGNORE_SIMPLE_OPS", 1)

  if FUZZ_MAX_SIZE > 0 and prod(lin.full_shape) > FUZZ_MAX_SIZE:
    print("skipping large kernel")
    return failures
  if FUZZ_IGNORE_SIMPLE_OPS and _is_simple(lin):
    print("skipping simple kernel")
    return failures

  test_depth = 1 if opts_list is not None else getenv("DEPTH", 1 if FUZZ_ALL_ACTIONS else 10)
  for depth in range(test_depth):
    next_lins = []
    for lin in last_lins:
      if opts_list is None: actions = get_kernel_actions(lin, include_0=False)
      else:
        actions = {}
        for oi,opts in enumerate(opts_list):
          lin2 = lin.copy()
          for o in opts: lin2.apply_opt(o)
          actions[oi] = lin2

      if not actions: continue
      if depth == 0 and getenv("FUZZ_REQUIRE_TC", 0):
        # Keep only kernels whose first applied opt is TC. The key has to come from actions.items():
        # the previous comprehension used an unbound `i`. Kernels with no applied opts are not TC.
        tc_acts = {i: k for i, k in actions.items() if k.applied_opts and k.applied_opts[0].op == OptOps.TC}
        if len(tc_acts) == 0: return failures
        else: actions = tc_acts

      test_lins = list(actions.values())
      if FUZZ_ALL_ACTIONS: print(f"testing {lin.applied_opts=} with {len(actions)} actions")
      elif opts_list is None: test_lins = [random.choice(test_lins)]

      for test_lin in test_lins:
        if not FUZZ_ALL_ACTIONS and test_lin.applied_opts: print(f"applied opts: {test_lin.applied_opts}")

        # stop if kernel uops repeat
        try: tuops = tuplize_uops(test_lin.linearize().uops)
        except KeyboardInterrupt: raise
        except BaseException as e:
          print(test_lin.ast)
          print(test_lin.applied_opts)
          print(e)
          failures["LINEARIZE_ERROR"].append((test_lin.ast, test_lin.applied_opts))
          continue

        if tuops in seen_uops: continue
        seen_uops[tuops] = tuple(test_lin.applied_opts)

        if not FUZZ_ALL_ACTIONS: print(test_lin.colored_shape())

        # rawbufs / var_vals / ground_truth are threaded through so later kernels reuse the same inputs
        (msg, rawbufs, var_vals, ground_truth, state1) = compare_linearizer(test_lin, rawbufs, var_vals, ground_truth, rtol=rtol, atol=atol)
        if state1 is not None and validate_device is not None:
          # Re-run the same kernel on the validation device and compare the captured launch states.
          validate_lin = test_lin.copy()
          validate_lin.opts = validate_device.renderer
          if validate_rawbufs is None:
            validate_rawbufs = [get_fuzz_rawbuf_like(x, copy=True, force_device=validate_device.device) for x in rawbufs]
          (_msg, _, _, _, state2) = compare_linearizer(validate_lin, validate_rawbufs, var_vals, ground_truth, rtol=rtol, atol=atol)

          if _msg != "PASS": failures[f"VALIDATE_DEV_{_msg}"].append((validate_lin.ast, validate_lin.applied_opts))

          ok, err_msg = compare_states(state1, state2)
          if not ok: failures["HCQ_COMPARE_FAILURE"].append((err_msg, test_lin.ast, test_lin.applied_opts, state1, state2))

        if msg != "PASS":
          print(test_lin.ast)
          print(test_lin.applied_opts)
          print(msg)
          failures[msg].append((test_lin.ast, test_lin.applied_opts))
          continue

        # only kernels that passed are expanded further in the next round
        next_lins.append(test_lin)

    last_lins = next_lins
    if FUZZ_ALL_ACTIONS: print(f"depth={depth} total_lins={len(last_lins)} {failures=}")
  return failures

def _is_simple(lin: Kernel) -> bool:
  """Return True when the kernel ast has at most one source whose first source is a CAST of a LOAD."""
  children = lin.ast.src
  if len(children) > 1: return False
  first_src = children[0].src[0]
  # the LOAD check is only evaluated once the CAST check has succeeded
  return first_src.op is Ops.CAST and first_src.src[0].op is Ops.LOAD

# Command-line entry point: load kernel asts from one of several sources, fuzz each of them and print
# a summary of the failures.
if __name__ == "__main__":
  parser = argparse.ArgumentParser(description="Run a fuzz testing on one or more kernels", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  parser.add_argument("--ast", type=str, default=None, help="the ast for the kernel to be optimized")
  parser.add_argument("--file", type=str, default=None, help="a file containing asts to be optimized, one per line")
  parser.add_argument("--beamreplay", type=str, default=None, help="replay asts and opts got from beam with CAPTURE_BEAM")
  parser.add_argument("--logfile", type=str, default=None, help="a file containing a tuple of ast and applied_opts, one per line")
  parser.add_argument("--expected-failures", type=int, default=0, help="the number of expected failed kernels")
  parser.add_argument("--rtol", type=float, default=1e-2, help="relative tolerance for numerical comparison")
  parser.add_argument("--atol", type=float, default=1e-2, help="absolute tolerance for numerical comparison")
  args = parser.parse_args()

  # opts_list stays None unless a beam replay supplies the opts to apply; when set it is index-aligned
  # with ast_strs.
  opts_list = None
  if args.ast is not None:
    print("loaded AST from CLI")
    ast_strs = [args.ast]
  elif args.file is not None:
    print(f"loading ASTs from file '{args.file}'")
    with open(args.file, 'r') as file:
      ast_strs = file.readlines()
  elif args.beamreplay is not None:
    print(f"loading BEAM replay from file '{args.beamreplay}'")
    with open(args.beamreplay, 'r') as file: fdata = file.readlines()
    # every remaining line is "<ast> :: <opts>"; comment lines and blank lines are skipped
    entries = [x.split(' :: ') for x in fdata if x.strip() and not x.startswith("#")]

    # dedup ast_strs and opts_list: collect every recorded opts list under its ast
    dct = defaultdict(list)
    # SECURITY NOTE: eval() executes whatever the replay file contains; only replay trusted files.
    for entry in entries: dct[entry[0]].append(eval(entry[1]))
    # ast_strs must be rebuilt from the same keys as opts_list, otherwise opts_list[i] below would not
    # belong to ast_strs[i] as soon as the file contains a duplicate ast.
    ast_strs = list(dct.keys())
    opts_list = [dct[c] for c in ast_strs]
  elif args.logfile is not None:
    print(f"loading ASTs from LOGKERNS file '{args.logfile}'")
    with open(args.logfile, 'r') as file:
      kern_strs = file.readlines()
      test_lins = [kern_str_to_lin(kern_str) for kern_str in kern_strs]
      ast_strs = [f"{lin.ast}" for lin in test_lins]
  else:
    print("loading ASTs from world")
    ast_strs = load_worlds(filter_reduce=False, filter_novariable=False)

  print(f"{len(ast_strs)=}")
  tested = 0
  failed_ids = []
  failures = defaultdict(list)
  seen_ast_strs = set()

  # FUZZ_NTH selects a single ast by index; -1 means "no selection"
  fuzz_nth = getenv("FUZZ_NTH", -1)

  try:
    for i, ast in enumerate(ast_strs[:getenv("FUZZ_N", len(ast_strs))]):
      if fuzz_nth != -1 and i != fuzz_nth: continue
      if getenv("FUZZ_IMAGEONLY") and "dtypes.image" not in ast: continue
      if "dtypes.image" in ast and Device.DEFAULT not in {"GPU", "QCOM"}: continue  # IMAGE is only for GPU
      if ast in seen_ast_strs: continue
      seen_ast_strs.add(ast)

      lin = ast_str_to_lin(ast)
      if not all(is_dtype_supported(buf.dtype) for buf in lin.bufs):
        print("skipping kernel due to not supported dtype")
        continue

      with Timing(f"tested ast {i}: "):
        tested += 1
        fuzz_failures = fuzz_linearizer(lin, rtol=args.rtol, atol=args.atol, opts_list=(opts_list[i] if opts_list else None))
        if fuzz_failures: failed_ids.append(i)
        for k, v in fuzz_failures.items():
          # guard keeps empty lists from creating keys in the defaultdict
          if v: failures[k].extend(v)
  except KeyboardInterrupt: print(colored("STOPPING...", 'red'))

  for msg, errors in failures.items():
    for i, payload in enumerate(errors):
      print(f"{msg} {i} kernel: {payload}") # easier to use with output with verify_kernel.py

  print(f"{tested=}")
  if failures:
    print(f"{failed_ids=}")
    for msg, errors in failures.items():
      print(f"{msg}: {len(errors)}")
    if len(failed_ids) == args.expected_failures:
      print(colored(f"{len(failed_ids)} failed as expected", "yellow"))
  if len(failed_ids) != args.expected_failures:
    print(colored(f"failed on {len(failed_ids)} kernels, expected {args.expected_failures}", "red"))
    # TODO: fix this
    # raise RuntimeError(f"failed on {len(failed_ids)} kernels, expected {args.expected_failures}")
  else:
    print(colored("all passed", "green"))