# Source: mcts_search.py (branch: master, 6715 bytes)
from __future__ import annotations
from typing import List, Optional, Dict, cast
import numpy as np
np.set_printoptions(suppress=True)
import math, functools, time, random, statistics
from tinygrad.helpers import DEBUG, getenv, CACHELEVEL, diskcache_get, diskcache_put, colored, Profiling
from tinygrad.codegen.kernel import Kernel
from tinygrad.device import Buffer, Device, CompileError
from tinygrad.engine.search import _ensure_buffer_alloc, get_kernel_actions, _time_program

class MCTSNode:
  """A single node of the search tree, wrapping one candidate kernel.

  Fields (all plain attributes, mutated by the helpers in this file):
    kernel:           the kernel this node stands for
    parents:          list holding the parent passed at construction (empty when there is none)
    children:         None until the node is expanded, afterwards a list of child nodes
    removed_children: children that were detached from `children` by remove_node
    t:                lowest time backpropagated through this node (starts at inf)
    tm:               time recorded for this node itself (starts at inf)
    n:                accumulated visit weight (starts at 0)
    i:                index of the search iteration that explored the node (-1 = never)
  """
  def __init__(self, kernel:Kernel, parent=None):
    # tree links
    self.parents: List[MCTSNode] = [] if parent is None else [parent]
    self.children: Optional[List[MCTSNode]] = None
    self.removed_children: List[MCTSNode] = []
    # payload
    self.kernel: Kernel = kernel
    # statistics, all in their "nothing measured yet" state
    self.t = math.inf
    self.tm = math.inf
    self.n = 0
    self.i = -1

def expand_node(node:MCTSNode):
  """Create the children of an unexpanded node.

  One child MCTSNode is built for every kernel returned by
  get_kernel_actions(node.kernel, include_0=False), each with `node` as its parent.
  The node must not have been expanded before (children must still be None).
  """
  assert node.children is None
  actions = get_kernel_actions(node.kernel, include_0=False)
  node.children = [MCTSNode(acted_kernel, node) for acted_kernel in actions.values()]

def remove_node(node:MCTSNode):
  """Detach `node` from the live children of each of its parents.

  The node is moved from parent.children to parent.removed_children, so it is
  no longer sampled but is still reachable for later inspection.
  A node without parents is left untouched.
  """
  for owner in node.parents:
    live = owner.children
    assert live is not None
    live.remove(node)
    owner.removed_children.append(node)

C = math.sqrt(2)  # weight of the exploration term in the UCB score
TEMP = 0.5        # softmax temperature used when sampling among explored children
def _sample_tree(node:MCTSNode, best_tm:float) -> MCTSNode:
  """Walk down from `node` and return the node the search should look at next.

  Rules applied at every level:
    * a node with no children (None or empty list) is returned as is
    * if any child has never been visited (n == 0), one of them is returned at random
    * otherwise each visited child gets the score
        -child.t/best_tm + C*sqrt(log(node.n)/child.n)
      children whose score is not finite are skipped, and the walk continues into a
      child drawn from a softmax (temperature TEMP) over the remaining scores
    * if every visited child was skipped, `node` itself is returned

  Args:
    node: subtree root to sample from.
    best_tm: best time found so far, used to normalise child.t; may be inf.
  """
  if node.children is None or len(node.children) == 0: return node
  unexplored_children = []
  explored_children = []
  ucb_explored_children: List[float] = []
  for child in node.children:
    if child.n == 0: unexplored_children.append(child)
    else:
      ucb = -child.t/best_tm + C*math.sqrt(math.log(node.n)/child.n)
      # isfinite (rather than "not isinf") also drops NaN: inf/inf happens when a child
      # failed (t == inf) before any candidate succeeded (best_tm == inf), and a NaN
      # score would poison the softmax and make np.random.choice raise
      if math.isfinite(ucb):
        explored_children.append(child)
        ucb_explored_children.append(ucb)
  if len(unexplored_children): return random.choice(unexplored_children)
  if not len(explored_children): return node
  # safe softmax: subtract the max before exponentiating to avoid overflow
  ucb_exp = np.exp((np.array(ucb_explored_children)-max(ucb_explored_children))/TEMP)
  return _sample_tree(explored_children[np.random.choice(len(ucb_exp), p=ucb_exp/np.sum(ucb_exp))], best_tm)

# this will expand/remove sometimes
def sample_tree(root:MCTSNode, best_tm:float) -> Optional[MCTSNode]:
  """Pick the next node to evaluate, expanding and pruning the tree on the way.

  The root is expanded on first use. While the root still has children, a node is
  sampled with _sample_tree; nodes that turn out to have no children are removed
  and sampling restarts. A never-visited node is returned directly, a visited one
  is expanded if needed and a random child of it is returned.

  Returns None once the root has no children left.
  """
  if root.children is None:
    expand_node(root)
  while root.children:
    # tree traversal
    picked = _sample_tree(root, best_tm)

    # dead end: already expanded and nothing below it
    if picked.children is not None and not picked.children:
      remove_node(picked)
      continue

    # first visit: evaluate the node itself
    if picked.n == 0:
      return picked

    # node expansion
    if picked.children is None:
      expand_node(picked)
    assert picked.children is not None
    if not picked.children:
      remove_node(picked)
      continue
    return random.choice(picked.children)
  return None

def backprop(bnode:MCTSNode, tm, strength=1.0):
  """Propagate a measured time from `bnode` up through all of its ancestors.

  Each visited node keeps the smallest time seen in `t` and has `strength` added
  to its visit weight `n`. The strength is split evenly between the parents
  before recursing into them.
  """
  if tm < bnode.t:
    bnode.t = tm
  bnode.n += strength
  ancestors = bnode.parents
  if not ancestors:
    return
  share = strength/len(ancestors)
  for ancestor in ancestors:
    backprop(ancestor, tm, share)

# number of graphs written so far; used to give each MCTSGRAPH dump a distinct file name
graph_mcts_cnt = 0
def mcts_search(lin:Kernel, rawbufs:List[Buffer], amt:int) -> Kernel:
  """Search for a fast variant of `lin` with Monte Carlo tree search.

  At most `amt` nodes are sampled from a tree rooted at `lin`. Each new node is
  lowered, compiled and timed; duplicates (same optimized AST key or same compiled
  binary) are removed from the tree and reuse the time stored on the node first seen.
  Every result is backpropagated into the tree to steer later sampling.

  Args:
    lin: starting kernel; also the result when no candidate measures a finite time.
    rawbufs: buffers passed to the timing runs (after _ensure_buffer_alloc).
    amt: maximum number of search iterations; also part of the disk-cache key.

  Returns:
    The kernel with the lowest measured time. On a disk-cache hit, a copy of `lin`
    with the cached opts (beyond those already applied to `lin`) applied instead.

  Environment:
    IGNORE_MCTS_CACHE skips the cache lookup, MCTSGRAPH dumps the explored tree via
    networkx and the `dot` tool, DEBUG>=2 prints a progress line per iteration.
  """
  global graph_mcts_cnt
  # TODO: copied from BEAM
  key = {"ast": lin.ast.key, "amt": amt, "device": lin.opts.device, "suffix": lin.opts.suffix}
  if not getenv("IGNORE_MCTS_CACHE") and CACHELEVEL >= 1 and (val:=diskcache_get("mcts_search", key)) is not None:
    # cache hit: replay only the opts that lin does not already have applied
    ret = lin.copy()
    for o in val[len(lin.applied_opts):]: ret.apply_opt(o)
    return ret

  rawbufs = _ensure_buffer_alloc(rawbufs)
  # every symbolic variable is pinned to the midpoint of its [vmin, vmax] range for timing
  var_vals = {k:(k.vmax+k.vmin)//2 for k in lin.ast.variables()}
  dev = Device[lin.opts.device]
  root = MCTSNode(lin)

  st = time.perf_counter()
  best, best_idx, best_tm = lin, 0, math.inf
  # first node seen for each compiled binary / optimized AST key, used to detect duplicates
  seen_libs: Dict[bytes, MCTSNode] = {}
  seen_asts: Dict[bytes, MCTSNode] = {}
  # wall-clock seconds spent compiling and timing, only used for the DEBUG progress line
  compile_time, runtime_time = 0.0, 0.0
  for i in range(amt):
    node = sample_tree(root, best_tm)  # sample and expand
    if node is None: break  # finished the whole tree
    node.i = i  # when was node explored

    opt_ast = node.kernel.get_optimized_ast()
    if (sibling_node:=seen_asts.get(opt_ast.key, None)) is not None:
      # early check for same optimized AST hit
      remove_node(node)
      tm = sibling_node.t
    else:
      seen_asts[opt_ast.key] = node

      # lowering (50% of the time)
      p = node.kernel.to_program(name_override="test")

      # rollout
      tm1 = time.perf_counter()
      try:
        lib = dev.compiler.compile(p.src)
      except CompileError:
        # NOTE: many of these "compiler errors" are caused by bad code output from the lowerer
        lib = None
      tm2 = time.perf_counter()
      if lib is None:
        # a kernel that does not compile is scored as infinitely slow
        tm = math.inf
      else:
        if (sibling_node:=seen_libs.get(lib, None)) is not None:
          # NOTE: these should all be caught by the AST check, need to canonicalize
          # remove this node, it's a duplicate
          remove_node(node)
          tm = sibling_node.t
        else:
          seen_libs[lib] = node
          # median of 3 timed runs, scaled by 1e6 (shown as "us" in the progress line);
          # early_stop is derived from 5x the current best, converted back by /1e6
          try: tm = statistics.median(_time_program(p, lib, var_vals, rawbufs, cnt=3, early_stop=best_tm*5/1e6))*1e6
          except RuntimeError: tm = math.inf
          node.tm = tm
      tm3 = time.perf_counter()
      compile_time += tm2-tm1
      runtime_time += tm3-tm2

      # mock rollout
      #node.tm = tm = random.random() + 0.1

    if tm < best_tm: best, best_idx, best_tm = node.kernel, i, tm
    et = time.perf_counter() - st
    if DEBUG>=2: print(f"\r{et:7.2f}s {colored(f'{compile_time*100/et:3.0f}%', 'cyan')} {colored(f'{runtime_time*100/et:3.0f}%', 'red')}: {tm:12.2f} us     best: {best_tm:12.2f} us @ {best_idx+1:4d}      {i+1:4d}/{amt:4d}  {int(round((i+1)/et)):4d}/s     {node.kernel.colored_shape()}\033[K", end="")  # noqa: E501

    # backprop
    backprop(node, tm)
  # the progress line is rewritten in place with \r, so finish it with a newline
  if DEBUG>=2: print()

  if getenv("MCTSGRAPH"):
    # optional debugging aid: dump the explored tree as a .dot file and render it to .svg
    import networkx as nx
    import os
    GRAPHPATH = "/tmp/net"
    def save_graph(G, fn, opt=""):
      # writes {fn}.dot and shells out to the `dot` executable to produce {fn}.svg
      print("saving", G, f"to {fn}.svg")
      nx.drawing.nx_pydot.write_dot(G, f'{fn}.dot')
      os.system(f'dot {opt} -Tsvg {fn}.dot -o {fn}.svg')

    G = nx.DiGraph()
    def add_node(node:MCTSNode):
      # recursively adds every visited node (n != 0), including removed children
      if node.n == 0: return
      for parent in node.parents: G.add_edge(parent, node)
      gopts = node.kernel.applied_opts
      # label with the last applied opt; the root has none
      edge_lbl = f"{str(gopts[-1].op)[7:]} {gopts[-1].axis} {gopts[-1].arg}" if len(gopts) else "ROOT"
      G.add_node(node, label=f"{node.i+1}\n{node.tm:.2f} us\n{edge_lbl}\nt {node.t:.2f}\nn {node.n}",
                 fillcolor="#80ff8080" if node.tm == best_tm else "#ffff8080", style='filled' if node.t == best_tm else '')
      if node.children is not None:
        for child in node.children+node.removed_children: add_node(child)
    add_node(root)
    save_graph(G, f"{GRAPHPATH}.{graph_mcts_cnt}.mcts", '-Grankdir=LR')
    graph_mcts_cnt += 1

  # store the winning opts so the next call with the same key can skip the search
  if CACHELEVEL >= 1: diskcache_put("mcts_search", key, best.applied_opts)
  return best