branch: master
process_replay.py
5004 bytesRaw
#!/usr/bin/env python3
# compare kernels created by HEAD against master
import os, multiprocessing, logging, pickle, sqlite3, difflib, functools, warnings
from typing import Callable, cast
from tinygrad.helpers import VERSION, Context, ContextVar, colored, db_connection, getenv, tqdm
from tinygrad.engine.schedule import create_schedule_with_vars
from tinygrad.codegen.kernel import Kernel, Opt
from tinygrad.renderer import Renderer
from tinygrad.ops import UOp

# *** process replay settings

# internal
PAGE_SIZE = getenv("PAGE_SIZE", 100)
REF = os.getenv("GITHUB_REF_NAME", "")
MAX_DIFF_PCT = getenv("PROCESS_REPLAY_MAX_DIFF_PCT", 20)
TABLE_NAME = f"process_replay_{VERSION}"
os.environ["RUN_PROCESS_REPLAY"] = "0"
os.environ["CAPTURE_PROCESS_REPLAY"] = "0"
early_stop = multiprocessing.Event()
logging.basicConfig(level=logging.INFO, format="%(message)s")
MAX_LINES = 500
def trunc_log(x):
  if len(lines:=repr(x).splitlines()) > MAX_LINES: lines = lines[:MAX_LINES]+[f"WARN: truncated string with {len(lines)} lines"]
  logging.info("\n".join(lines))

# user config
ASSERT_DIFF = int((flag:="[pr]") in os.getenv("COMMIT_MESSAGE", flag) or flag in os.getenv("PR_TITLE", flag))
if not getenv("ASSERT_PROCESS_REPLAY", 1): ASSERT_DIFF = 0
SKIP_PROCESS_REPLAY = (k:="[skip_process_replay]") in os.getenv("COMMIT_MESSAGE", "") or k in os.getenv("PR_TITLE", "")
if REF == "master": SKIP_PROCESS_REPLAY = True
class ProcessReplayWarning(Warning): pass

# *** recreators

def recreate_sched(big_sink:UOp) -> list[UOp]:
  sched, _, __ = create_schedule_with_vars(big_sink)
  return [x.ast for x in sched]
def recreate_kernel(ast:UOp, opts:Renderer, applied_opts:list[Opt], name:str) -> str:
  k = Kernel(ast, opts=opts)
  for opt in applied_opts: k.apply_opt(opt)
  # NOTE: replay with the captured renderer, not the one in master
  return k.opts.render(cast(list,k.to_program(name).uops))

# *** diff a "good" recreation against the generated version

def diff(offset:int, name:str, fxn:Callable) -> None:
  # TODO: add this assert back for schedule
  if ASSERT_DIFF and name != "schedule": warnings.filterwarnings("error", category=ProcessReplayWarning)
  if early_stop.is_set(): return None
  conn = db_connection()
  cur = conn.cursor()
  cur.execute(f"SELECT val FROM '{name}_{TABLE_NAME}' LIMIT ? OFFSET ?", (PAGE_SIZE, offset))
  changed = 0
  for row in cur.fetchall():
    if changed > MAX_DIFF_PCT:
      warnings.warn(f"detected changes in over {MAX_DIFF_PCT}% of {name}s. skipping further diff generation.")
      early_stop.set()
      break
    # try unpickle
    try: args = pickle.loads(row[0])
    except Exception as e:
      changed += 1
      warnings.warn(f"FAILED TO UNPICKLE OBJECTS {e}", ProcessReplayWarning)
      continue
    # try recreate
    try:
      ctx_vars = {k:v.value for k,v in args[-2].items() if k != "DEBUG" and (var:=ContextVar._cache.get(k)) is not None and var.value != v.value}
      with Context(**ctx_vars): good = fxn(*args[:-2])
      if good is None: continue
    except Exception as e:
      changed += 1
      warnings.warn(f"FAILED TO RECREATE KERNEL {e}", ProcessReplayWarning)
      if ctx_vars: logging.info(ctx_vars)
      for x in args[:-2]: trunc_log(x)
      continue
    # diff kernels
    try: assert str(args[-1]) == str(good)
    except AssertionError:
      changed += 1
      if ctx_vars: logging.info(ctx_vars)
      for x in args[:-2]: trunc_log(x)
      changes = list(difflib.unified_diff(str(good).splitlines(), str(args[-1]).splitlines()))
      logging.info("\n".join(colored(line, "red" if line.startswith("-") else "green" if line.startswith("+") else None) for line in changes))
      warnings.warn("PROCESS REPLAY DETECTED CHANGE", ProcessReplayWarning)
  conn.commit()
  cur.close()

# *** generic runner for executing fxn across all rows of a table in parallel

def _pmap(name:str, fxn:Callable, maxtasksperchild:int=16) -> None:
  conn = db_connection()
  cur = conn.cursor()
  try: row_count = cur.execute(f"select count(*) from '{name}_{TABLE_NAME}'").fetchone()[0]
  except sqlite3.OperationalError:
    warnings.warn(f"{name}_{TABLE_NAME} isn't accessible in master, did DB_VERSION change?", ProcessReplayWarning)
    return None
  conn.commit()
  cur.close()
  with multiprocessing.get_context("spawn").Pool(multiprocessing.cpu_count(), maxtasksperchild=maxtasksperchild) as pool:
    inputs = list(range(0, row_count, PAGE_SIZE))
    list(tqdm(pool.imap_unordered(functools.partial(diff, name=name, fxn=fxn), inputs), total=len(inputs)))
    pool.close()
    pool.join()
    pool.terminate()

# *** main loop

if __name__ == "__main__":
  if SKIP_PROCESS_REPLAY:
    logging.info("skipping process replay.")
    exit(0)

  print(f"running process replay with {ASSERT_DIFF=}")
  for name,fxn in [("schedule", recreate_sched), ("kernel", recreate_kernel)]:
    logging.info(f"***** {name} diff")
    try: _pmap(name, fxn)
    except Exception as e:
      if ASSERT_DIFF: raise e
      logging.error(f"{name} diff err {e}")