branch: master
assembly_rdna.py
11038 bytesRaw
import yaml
from typing import Tuple, Set, Dict
from tinygrad import dtypes
from tinygrad.codegen.assembly import AssemblyCodegen, Register
from tinygrad.codegen.kernel import Ops
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps
from tinygrad.runtime.ops_gpu import ROCM_LLVM_PATH

# ugh, is this really needed?
# enable_early_exec() returns a callable that takes a single (argv, stdin_bytes) tuple; assemble()
# uses it to pipe the generated assembly through llvm-mc and then ld.lld, and presumably gets the
# process's stdout bytes back -- NOTE(review): confirm against extra.helpers.
from extra.helpers import enable_early_exec
early_exec = enable_early_exec()

# Assembly prologue placed before the `.amdhsa_*` kernel-descriptor directive lines that
# RDNACodegen.assemble() generates: it declares the `code.kd` object symbol in .rodata and opens
# the `.amdhsa_kernel code` block that those directives fill in.
boilerplate_start = """
.global _start
_start:
.rodata
.align 0x10
.global code.kd
.type code.kd,STT_OBJECT
.amdhsa_kernel code"""

# Closes the `.amdhsa_kernel` block and starts the `.text` section at the `code` label; the
# generated instruction lines are appended directly after this string.
code_start = """.end_amdhsa_kernel
.text
code:
"""

def _fmac_regs(tins):
  """Parse `v_fmac_f32 dst, src0, src1` into its three register numbers.

  "v_fmac_f32 v5, v3, v4" -> [5, 3, 4]. Returns None when the line does not have exactly three
  plain numbered v/s register operands (e.g. an inline constant such as 1.0), so the caller leaves
  that instruction unpaired instead of crashing in int().
  """
  ops = [x.strip(',') for x in tins.split(" ")[1:]]
  if len(ops) != 3 or not all(len(x) > 1 and x[0] in "vs" and x[1:].isdigit() for x in ops): return None
  return [int(x[1:]) for x in ops]

def _pair_dual_fmac(ins):
  """Fuse pairs of v_fmac_f32 lines into `v_dual_fmac_f32 ... :: v_dual_fmac_f32 ...`.

  For each v_fmac_f32, candidates are scanned from the END of the program backwards; the first
  unused later v_fmac_f32 whose dst, src0 and src1 numbers each differ in parity from the earlier
  one is hoisted up and emitted fused with it. Instructions are tracked by position, so identical
  instruction text occurring more than once is never dropped. Returns a new list.

  NOTE(review): the hoisted fmac is moved past everything between the two (loads, s_waitcnt,
  labels, branches, writes to its sources) without any dependency check -- confirm this is safe
  for the programs this codegen emits.
  """
  used = set()
  fused = []
  for i, first in enumerate(ins):
    if i in used: continue
    r0 = _fmac_regs(first) if first.startswith("v_fmac_f32") else None
    if r0 is not None:
      for j in range(len(ins)-1, i, -1):
        second = ins[j]
        if j in used or not second.startswith("v_fmac_f32"): continue
        r1 = _fmac_regs(second)
        # pair only when dst, src0 and src1 each land on opposite register parity
        if r1 is None or any(a % 2 == b % 2 for a, b in zip(r0, r1)): continue
        fused.append(first.replace("v_", "v_dual_") + " :: " + second.replace("v_", "v_dual_"))
        used.update((i, j))
        break
    if i not in used: fused.append(first)
  return fused

# https://github.com/RadeonOpenCompute/ROCm_Documentation/blob/master/ROCm_Compiler_SDK/ROCm-Codeobj-format.rst
# https://github.com/ROCm-Developer-Tools/ROCm-ComputeABI-Doc/blob/master/AMDGPU-ABI.md#initial-kernel-register-state
# RDNA3 is actually a SIMD machine!
class RDNACodegen(AssemblyCodegen):
  """Lower AssemblyCodegen's linear instruction list to gfx1100 (RDNA3) assembly.

  specialize() maps virtual registers to real s/v registers and emits instruction text;
  assemble() wraps it in an amdhsa kernel descriptor plus metadata and runs llvm-mc / ld.lld.
  """
  # capability flags, presumably consulted by the AssemblyCodegen base class
  supports_float4: bool = True
  supports_float4_alu: bool = True
  supports_load3: bool = True
  sin_is_sin2pi: bool = True
  no_div: bool = True

  def specialize(self, asm) -> Tuple[str, str]:
    """Translate `asm` into RDNA3 assembly; return ('code', <result of self.assemble>).

    `asm` is iterated as (uop, out, vin, arg) tuples. Each buffer becomes an 8-byte global_buffer
    kernarg; each gid SPECIAL additionally appends a hidden_group_size_{x,y,z} kernarg.
    Raises NotImplementedError for unsupported uops, register dtypes and casts.
    """
    args = [{'.address_space': 'global', '.name': f'buf_{i}', '.offset': i*8, '.size': 8, '.type_name': b.dtype.name+"*", '.value_kind': 'global_buffer'}
            for i,b in enumerate(self.bufs)]
    ins = []

    v_cnt = 3  # v[0:2] is local_xyz
    s_cnt = 5  # s[0:1] is the address, s[2:4] is global_xyz
    local_ids_unpacked = False
    float4 = dtypes.float.vec(4)

    dtype_to_rdnatype = {dtypes.float32: "f32", dtypes.int64: "i64", dtypes.int32: "i32", dtypes.uint64: "u64", dtypes.bool: "i32"}
    alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", TernaryOps.MULACC: "fma",
           BinaryOps.MAX: "max", UnaryOps.RECIP: "rcp",
           UnaryOps.NOOP: "mov", UnaryOps.SIN: "sin", UnaryOps.LOG2: "log", UnaryOps.EXP2: "exp",
           BinaryOps.CMPLT: "cmp_lt"}

    pend_regs:Set[Register] = set()  # registers written by a load that has not been waited on yet
    rtor:Dict[Register, str] = {}    # virtual register -> real register name
    def reg_in(x):
      # reading a register with an outstanding load: wait for all memory ops first
      if x in pend_regs:
        ins.append('s_waitcnt lgkmcnt(0), vmcnt(0)')
        pend_regs.clear()
      return rtor[x]
    def reg_out(x):
      return rtor[x]
    def operand(x):
      # registers are resolved (and waited on); anything else is emitted as a literal
      return reg_in(x) if isinstance(x, Register) else str(x)

    for uop, out, vin, arg in asm:
      if uop == Ops.DEFINE_REGISTER:
        # arg[0] is the register's (dtype, scalar, ...) tuple, arg[1] the name prefix, arg[2] the count
        dtype, scalar = arg[0][0], arg[0][1]
        if dtype in [dtypes.uint32, dtypes.uint64, dtypes.int64, dtypes.int32, dtypes.float32, float4]:
          align = dtype.itemsize // 4  # number of consecutive 32-bit registers per value
          for i in range(arg[2]):
            # TODO: Re-use gaps created by this to avoid wasting registers
            if scalar:
              s_cnt += (-s_cnt) % align  # round up to a multiple of align
              reg_name = f"s[{s_cnt}:{s_cnt + align - 1}]" if align > 1 else f"s{s_cnt}"
              s_cnt += align
            else:
              v_cnt += (-v_cnt) % align  # round up to a multiple of align
              reg_name = f"v[{v_cnt}:{v_cnt + align - 1}]" if align > 1 else f"v{v_cnt}"
              v_cnt += align
            rtor[Register(f"%{arg[1]}{i}", *arg[0])] = reg_name
            if dtype == float4:
              # also name each 32-bit component so per-component (off=) registers resolve
              for off in range(4):
                rtor[Register(f"%{arg[1]}{i}", dtypes.float, False, off=off)] = f"s{s_cnt-align+off}" if scalar else f"v{v_cnt-align+off}"
        elif dtype == dtypes.bool:
          for i in range(arg[2]):
            # `_lo` suffix since we're running wavefront_size=32
            rtor[Register(f"%{arg[1]}{i}", *arg[0])] = "scc" if scalar else "vcc_lo"
        else:
          raise NotImplementedError("DEFINE_REGISTER not implemented for arg: ", arg)
      elif uop == Ops.SPECIAL:
        if arg.startswith('buf'):
          i = int(arg[3:])
          ins.append(f's_load_b64 {reg_out(out)}, s[0:1], {i*8}')
          pend_regs.add(out)
          pend_regs.update(out.subregs())
        elif arg.startswith('gid'):
          dim = int(arg[3])
          ins.append(f'v_mov_b32 {reg_out(out)}, s{2+dim}')
          if not local_ids_unpacked:
            # The code treats v0 as holding the packed local ids (x: bits 0-9, y: 10-19, z: 20-29);
            # the docs lied, y is not in v1. Unpack all three once, before v0 is masked down to x,
            # so the result does not depend on the order the gid SPECIALs arrive in.
            ins.extend(["v_bfe_u32 v2, v0, 20, 10",  # untested
                        "v_bfe_u32 v1, v0, 10, 10",
                        "v_and_b32_e32 v0, 0x3ff, v0"])
            local_ids_unpacked = True
          # get local size (overwrites s{2+dim}, whose group id was copied into out above)
          offset = len(args)*8
          args.append({".offset": offset, ".value_kind": f"hidden_group_size_{'xyz'[dim]}", ".size": 8})
          ins.append(f's_load_b32 s{2+dim}, s[0:1], {offset}')
          ins.append('s_waitcnt vmcnt(0) lgkmcnt(0)')
          pend_regs.clear()
          # global id = group id * local size + local id
          ins.append(f'v_mul_i32_i24 {reg_out(out)}, {reg_out(out)}, s{2+dim}')
          ins.append(f'v_add_nc_u32 {reg_out(out)}, v{dim}, {reg_out(out)}')
      elif uop == Ops.CONST:
        if arg == float('inf'): arg = "0x7f800000"
        elif arg == float('-inf'): arg = "0xff800000"
        prefix = 's_' if out.scalar else 'v_'
        if out.dtype == float4:
          for off in range(4):
            ins.append(f"{prefix}mov_b32 {reg_out(Register(out.nm, dtypes.float, False, off=off))}, {arg}")
        else:
          ins.append(f"{prefix}mov_b32 {reg_out(out)}, {arg}")
      elif uop == Ops.ALU:
        if arg == BinaryOps.CMPLT:
          # out is presumably the bool predicate (scc / vcc_lo, see DEFINE_REGISTER) and is not listed
          # as an operand. Take the compare type from the inputs: comparing float32 bit patterns
          # as i32 orders negative values backwards.
          is_float = any(isinstance(x, Register) and x.dtype == dtypes.float32 for x in vin)
          cmp_type = "f32" if is_float else dtype_to_rdnatype[out.dtype]
          ins.append(f"{'s' if out.scalar else 'v'}_{alu[arg]}_{cmp_type} {', '.join(operand(x) for x in vin)}")
        else:
          alu_arg = alu[arg]
          if arg == TernaryOps.MULACC and out == vin[2]:
            # accumulating into an input: use the two-source form (dst += a*b)
            alu_arg = "fmac"
            vin = vin[0:2]
          if out.dtype == float4:
            # one instruction per component; non-float4 operands (and constants) are broadcast
            lanes = [x.subregs() if isinstance(x, Register) and x.dtype == float4 else [x]*4 for x in [out]+vin]
            for rr in zip(*lanes):
              ins.append(f"{'s_' if rr[0].scalar else 'v_'}{alu_arg}_{dtype_to_rdnatype[rr[0].dtype]} {reg_out(rr[0])}, {', '.join(operand(x) for x in rr[1:])}")
          else:
            rtype = dtype_to_rdnatype[out.dtype] if arg != UnaryOps.NOOP else 'b32'
            mul_suffix = '_i24' if arg == BinaryOps.MUL and out.dtype != dtypes.float32 and not out.scalar else ''
            ins.append(f"{'s_' if out.scalar else 'v_'}{alu_arg}_{rtype}{mul_suffix} {reg_out(out)}, {', '.join(operand(x) for x in vin)}")
      elif uop == Ops.LOAD:
        if out.scalar:
          # swap arg order: (vin[0], vin[1]) here, reversed relative to global_load below
          ins.append(f's_load_b32 {reg_out(out)}, {reg_in(vin[0])}, {reg_in(vin[1])} offset:{arg[0]}')
        else:
          ins.append(f'global_load_{"b128" if out.dtype == float4 else "b32"} {reg_out(out)}, {reg_in(vin[1])}, {reg_in(vin[0])} offset:{arg[0]}')
        # the loaded registers must be waited on before their first read (see reg_in)
        pend_regs.add(out)
        pend_regs.update(out.subregs())
      elif uop == Ops.STORE:
        ins.append(f'global_store_{"b128" if vin[1].dtype == float4 else "b32"} {reg_in(vin[2])}, {reg_in(vin[1])}, {reg_in(vin[0])} offset:{arg[0]}')
      elif uop == Ops.LABEL:
        ins.append(f"{arg}:")
      elif uop == Ops.COND_BRANCH:
        ins.append(f"s_cbranch_scc{'1' if arg[1] else '0'} {arg[0]}")
      elif uop == Ops.CAST:
        if vin[0].dtype == dtypes.bool and out.dtype == dtypes.float32:
          ins.append(f"v_cndmask_b32 {reg_out(out)}, 0.0, 1.0, {reg_in(vin[0])}")
        else:
          # previously bool -> non-float32 casts emitted nothing and left out unwritten
          raise NotImplementedError(f"cast {vin[0].dtype} -> {out.dtype}")
      else:
        raise NotImplementedError(uop)

    ins.extend(['s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)', 's_endpgm', 's_code_end'])

    # dual alu group
    ins = _pair_dual_fmac(ins)

    return 'code', self.assemble(args, ins, v_cnt, s_cnt)

  def assemble(self, args, ins, v_cnt, s_cnt):
    """Assemble `ins` for gfx1100 and link it; returns what early_exec yields for ld.lld.

    args: kernarg metadata entries (dicts with ".offset"/".size"), ins: instruction lines,
    v_cnt / s_cnt: registers in use, emitted as next_free_vgpr/sgpr and vgpr/sgpr_count.
    """
    kernel_desc = {'.amdhsa_group_segment_fixed_size': 0, '.amdhsa_private_segment_fixed_size': 0, '.amdhsa_kernarg_size': 0,
                   '.amdhsa_next_free_vgpr': v_cnt,   # this matters!
                   '.amdhsa_reserve_vcc': 0, '.amdhsa_reserve_xnack_mask': 0,
                   '.amdhsa_next_free_sgpr': s_cnt,
                   '.amdhsa_float_round_mode_32': 0, '.amdhsa_float_round_mode_16_64': 0, '.amdhsa_float_denorm_mode_32': 3, '.amdhsa_float_denorm_mode_16_64': 3, '.amdhsa_dx10_clamp': 1, '.amdhsa_ieee_mode': 1,
                   '.amdhsa_fp16_overflow': 0, '.amdhsa_workgroup_processor_mode': 1, '.amdhsa_memory_ordered': 1, '.amdhsa_forward_progress': 0, '.amdhsa_enable_private_segment': 0,
                   '.amdhsa_system_sgpr_workgroup_id_x': 1, '.amdhsa_system_sgpr_workgroup_id_y': 1, '.amdhsa_system_sgpr_workgroup_id_z': 1,
                   '.amdhsa_system_sgpr_workgroup_info': 0, '.amdhsa_system_vgpr_workitem_id': 2, # is amdhsa_system_vgpr_workitem_id real?
                   '.amdhsa_exception_fp_ieee_invalid_op': 0, '.amdhsa_exception_fp_denorm_src': 0, '.amdhsa_exception_fp_ieee_div_zero': 0, '.amdhsa_exception_fp_ieee_overflow': 0, '.amdhsa_exception_fp_ieee_underflow': 0,
                   '.amdhsa_exception_fp_ieee_inexact': 0, '.amdhsa_exception_int_div_zero': 0, '.amdhsa_user_sgpr_dispatch_ptr': 0, '.amdhsa_user_sgpr_queue_ptr': 0, '.amdhsa_user_sgpr_kernarg_segment_ptr': 1,
                   '.amdhsa_user_sgpr_dispatch_id': 0, '.amdhsa_user_sgpr_private_segment_size': 0, '.amdhsa_wavefront_size32': 1, '.amdhsa_uses_dynamic_stack': 0}

    # args are appended in increasing offset order, so the segment ends after the last one
    # (a kernel with no arguments gets an empty segment instead of an IndexError)
    kernarg_size = args[-1][".offset"] + args[-1][".size"] if args else 0
    metadata = {'amdhsa.kernels': [{'.args': args,
                  '.group_segment_fixed_size': 0, '.kernarg_segment_align': 8, '.kernarg_segment_size': kernarg_size,
                  '.language': 'OpenCL C', '.language_version': [1, 2], '.max_flat_workgroup_size': 256,
                  '.name': 'code', '.private_segment_fixed_size': 0, '.sgpr_count': s_cnt, '.sgpr_spill_count': 0,
                  '.symbol': 'code.kd', '.uses_dynamic_stack': False, '.vgpr_count': v_cnt, '.vgpr_spill_count': 0,
                  '.wavefront_size': 32}],
                'amdhsa.target': 'amdgcn-amd-amdhsa--gfx1100', 'amdhsa.version': [1, 2]}

    code = boilerplate_start + "\n" + '\n'.join("%s %d" % x for x in kernel_desc.items()) + "\n" +  code_start + '\n'.join(ins) + "\n.amdgpu_metadata\n" + yaml.dump(metadata) + ".end_amdgpu_metadata"
    # assemble to an object (source on stdin), then link it into the final code object
    obj = early_exec(([ROCM_LLVM_PATH / "llvm-mc", '--arch=amdgcn', '--mcpu=gfx1100', '--triple=amdgcn-amd-amdhsa', '--filetype=obj', '-'], code.encode("utf-8")))
    asm = early_exec(([ROCM_LLVM_PATH / "ld.lld", "/dev/stdin", "-o", "/dev/stdout", "--pie"], obj))
    return asm