branch: master
mockgpu.py
3496 bytesRaw
import ctypes, ctypes.util, time, os, builtins, fcntl
from tinygrad.runtime.support.hcq import HWInterface
from test.mockgpu.nv.nvdriver import NVDriver
from test.mockgpu.amd.amddriver import AMDDriver
# Timestamp captured once, when this module is imported.
start = time.perf_counter()

# *** ioctl lib ***
# Handle to the C library; its mmap is used for file descriptors that no mock driver tracks.
libc = ctypes.CDLL(ctypes.util.find_library("c"))
# Declare mmap's C signature (addr, length, prot, flags, fd, offset) so ctypes converts the arguments
# correctly and hands back the mapped address as a pointer-sized value rather than a truncated int.
libc.mmap.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_long]
libc.mmap.restype = ctypes.c_void_p

# Mock drivers, consulted in this order when looking up tracked files and tracked address ranges.
drivers = [AMDDriver(), NVDriver()]
# Maps a virtual fd number to the descriptor object returned by a driver's open().
tracked_fds = {}

# Keep a handle on the real builtin so wrappers can still build genuine memoryviews.
orignal_memoryview = builtins.memoryview

class TrackedMemoryView:
  """Thin wrapper around a memoryview that reports element reads and writes to callbacks.

  `rcb(view, index)` is invoked before every read and `wcb(view, index)` after every write,
  where `view` is the wrapped memoryview.
  """

  def __init__(self, data, rcb, wcb):
    self.mv = orignal_memoryview(data)
    self.rcb = rcb
    self.wcb = wcb

  def __getitem__(self, index):
    # Notify first, then read, so the callback can observe the access before the value is fetched.
    self.rcb(self.mv, index)
    return self.mv[index]

  def __setitem__(self, index, value):
    # Store first, then notify, so the callback sees the freshly written data.
    self.mv[index] = value
    self.wcb(self.mv, index)

  def cast(self, new_type, **kwargs):
    """Re-cast the wrapped view in place and return this same wrapper."""
    recast = self.mv.cast(new_type, **kwargs)
    self.mv = recast
    return self

  @property
  def nbytes(self):
    return self.mv.nbytes

  def __len__(self):
    return len(self.mv)

  def __repr__(self):
    return repr(self.mv)

def _memoryview(cls, mem):
  """Build a memoryview for `mem`, wrapping it in TrackedMemoryView when its address is tracked.

  Only ints and ctypes arrays are checked against the drivers' tracked address ranges; anything
  else, and any address outside every range, gets a plain memoryview.
  """
  if not isinstance(mem, (int, ctypes.Array)): return orignal_memoryview(mem)
  addr = mem if not isinstance(mem, ctypes.Array) else ctypes.addressof(mem)
  # First range (drivers in order, then each driver's ranges in order) that contains the address wins.
  callbacks = next(((rcb, wcb) for d in drivers for st, en, rcb, wcb in d.tracked_addresses if st <= addr <= en), None)
  if callbacks is not None: return TrackedMemoryView(mem, *callbacks)
  return orignal_memoryview(mem)
# Replace the builtin with a type whose construction is routed through _memoryview.
builtins.memoryview = type("memoryview", (), {'__new__': _memoryview}) # type: ignore

def _open(path, flags):
  """Open `path`, preferring a mock driver's tracked file over the real filesystem.

  Returns the fd number (virtual for tracked files, real otherwise), or None when the path is
  neither tracked nor present on disk. Virtual descriptors are registered in `tracked_fds`.
  """
  for driver in drivers:
    tracked = next((f for f in driver.tracked_files if path == f.path), None)
    if tracked is None: continue
    virtfd = driver.open(path, flags, 0o777, tracked)
    tracked_fds[virtfd.fd] = virtfd
    return virtfd.fd
  if not os.path.exists(path): return None
  return os.open(path, flags, 0o777)

class MockHWInterface(HWInterface):
  """HWInterface replacement that serves tracked fds from the mock drivers and everything else from the OS.

  Each operation checks whether `self.fd` is registered in `tracked_fds`; if so the call is forwarded
  to the driver-provided descriptor object, otherwise it falls through to the real system call.
  """

  def __init__(self, path:str="", flags:int=os.O_RDONLY, fd:int|None=None):
    """Wrap an existing `fd`, or open `path` with `flags` when no fd is given.

    `self.fd` ends up as None when `path` is neither a tracked file nor an existing path.
    """
    self.path = path
    # Compare against None so a caller-supplied fd of 0 is kept instead of being treated as "no fd".
    self.fd = fd if fd is not None else _open(path, flags)

  def __del__(self):
    """Release the descriptor: tracked fds are closed through their driver object, others via os.close."""
    # fd may be missing (constructor failed early) or None (nothing was opened); nothing to release then.
    fd = getattr(self, "fd", None)
    if fd is None: return
    if fd in tracked_fds:
      tracked_fds[fd].close(fd)
      tracked_fds.pop(fd)
    else: os.close(fd)

  def ioctl(self, request, arg):
    """Issue an ioctl; tracked descriptors receive the address of `arg`, real ones receive `arg` itself."""
    if self.fd in tracked_fds:
      return tracked_fds[self.fd].ioctl(self.fd, request, ctypes.addressof(arg))
    return fcntl.ioctl(self.fd, request, arg)

  def mmap(self, start, sz, prot, flags, offset):
    """Map the descriptor, through the driver object when tracked and through libc's mmap otherwise."""
    if self.fd in tracked_fds:
      return tracked_fds[self.fd].mmap(start, sz, prot, flags, self.fd, offset)
    return libc.mmap(start, sz, prot, flags, self.fd, offset)

  def read(self, size=None, binary=False, offset=None):
    """Read text from the descriptor.

    Binary reads are not supported and raise NotImplementedError. `offset` is accepted but not used.
    """
    if binary: raise NotImplementedError()
    if self.fd in tracked_fds:
      return tracked_fds[self.fd].read_contents(size)
    # binary has been rejected above, so the file is always opened in text mode.
    with open(self.fd, "r", closefd=False) as file:
      # Rewind once the position has reached the end so a repeated read starts from the beginning.
      if file.tell() >= os.fstat(self.fd).st_size: file.seek(0)
      return file.read(size)

  def listdir(self):
    """List directory entries, from the driver object when tracked and from `self.path` on disk otherwise."""
    if self.fd in tracked_fds:
      return tracked_fds[self.fd].list_contents()
    return os.listdir(self.path)

  def write(self, content, binary=False, offset=None): raise NotImplementedError()

  def seek(self, offset):
    """Move the descriptor's position by delegating to the driver object or to os.lseek."""
    if self.fd in tracked_fds:
      tracked_fds[self.fd].seek(offset)
    else:
      # NOTE(review): real fds are moved relative to the current position (SEEK_CUR) - confirm this
      # matches the semantics of the tracked descriptor's seek() and of the base class.
      os.lseek(self.fd, offset, os.SEEK_CUR)

  @staticmethod
  def exists(path):
    """Return True if `path` is a tracked file of any driver or exists on the real filesystem.

    Nothing is opened for the check, so no real or virtual descriptor is left behind.
    """
    return any(path == x.path for d in drivers for x in d.tracked_files) or os.path.exists(path)

  @staticmethod
  def readlink(path): raise NotImplementedError()

  @staticmethod
  def eventfd(initval, flags=None): raise NotImplementedError()