RDNA assembly backend ($1000 bounty) (#787)
* Revert "Revert "ops rdna"". This reverts commit 0400315078.
* Revert "Revert "writing 2"". This reverts commit 325a3bf2cf.
* no dump
* 2x 2
* simple asm
* local size
* sub
* lil work
* support args != 3
* assembler work
* generate that
* ptx assembler
* begin index renderer
* max
* ptx loops
* gemms work
* valid works
* asm working a bit more
* close
* passing all ops tests
* ptx is a codegen only, not a backend
* ptx
* float16 support
* rdna goes here
* install types
* make amd disassemble
* ansilen for pretty print
* fix ptx log2/exp2
* assemblyinstruction
* new asm
* working gemm
* fix cmp
* more passing
* mod
* ptx works again
* rdna3 add works
* log exp
* sin is sin 2pi
* fix types
* progress
* loops work
* rdna xyz
* better addressing
* cleanups
* handle exception in early process
* div support
* rdna float4
* locals work
* fix neg index
* cast
* smaller diff
* yaml
* import only if selected
* fromimport
* types
* this all needs rewriting
* a few more
parent dca084f227
commit ba56ee6020
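Neither assembly backend is on by default: per the CUDABuffer and GPUBuffer hunks below, PTX=1 selects PTXCodegen on CUDA and RDNA=1 selects RDNACodegen on GPU. A minimal sketch of exercising the RDNA path (assumes a gfx1100 OpenCL device and the ROCm LLVM tools; the matmul is just an example workload):

import os
os.environ["RDNA"] = "1"  # set before importing tinygrad; getenv() is lru_cached
from tinygrad.tensor import Tensor
a, b = Tensor.randn(64, 64), Tensor.randn(64, 64)
print((a @ b).numpy())  # lowered through RDNACodegen instead of CLCodegen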
@@ -5,7 +5,10 @@ import multiprocessing
 def _early_exec_process(qin, qout):
   while True:
     path, inp = qin.get()
-    qout.put(subprocess.check_output(path, input=inp))
+    try:
+      qout.put(subprocess.check_output(path, input=inp))
+    except subprocess.CalledProcessError as e:
+      qout.put(e)

 def enable_early_exec():
   qin: multiprocessing.Queue = multiprocessing.Queue()
@@ -15,7 +18,9 @@ def enable_early_exec():
   p.start()
   def early_exec(x):
     qin.put(x)
-    return qout.get()
+    ret = qout.get()
+    if isinstance(ret, Exception): raise ret
+    else: return ret
   return early_exec

 def proc(itermaker, q) -> None:
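This is the helper the new RDNA backend uses to shell out to llvm-mc and ld.lld (see the assembly_rdna.py hunk below); with the change above, a failing subprocess now raises in the caller instead of wedging the queue. A minimal usage sketch (the echo command is illustrative, not from this diff):

from extra.helpers import enable_early_exec
early_exec = enable_early_exec()
try:
  # send (argv, stdin bytes) to the worker; stdout bytes come back
  out = early_exec((["echo", "hello"], b""))
except Exception as e:  # CalledProcessError is re-raised here after this change
  print("subprocess failed:", e)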
@@ -24,8 +24,10 @@ code = open(pathlib.Path(__file__).parent / "prog.s", "r").read()

 gen = []
 FLOPS = 0
-for j in range(4):
-  for i in range(0, 251, 6):
+#MAX_REG = 251
+MAX_REG = 32
+for j in range(1):
+  for i in range(0, MAX_REG, 6):
     #gen.append(f"v_dual_fmac_f32 v{i+0}, v{i+1}, v{i+2} :: v_dual_fmac_f32 v{i+3}, v{i+4}, v{i+5}")
     #FLOPS += 4
     gen.append(f"v_dual_dot2acc_f32_f16 v{i+0}, v{i+1}, v{i+2} :: v_dual_dot2acc_f32_f16 v{i+3}, v{i+4}, v{i+5}")
@@ -48,9 +50,10 @@ print(colored("creating CLProgram", "green"))
 prg = CLProgram("code", asm, binary=True)

 print(colored("running program", "green"))
-FLOPS *= 100000*1024*1024 # loop * global_size
+G = 256
+FLOPS *= 100000*G*G # loop * global_size
 for i in range(3):
-  tm = prg([1024, 1024], [256, 1], buf, wait=True)
+  tm = prg([G, G], [256, 1], buf, wait=True)
   print(f"ran in {tm*1e3:.2f} ms, {FLOPS/(tm*1e9):.2f} GFLOPS")

 print(colored("transferring buffer", "green"))
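For reference, the printed figure is FLOPS/(tm*1e9) with tm in seconds from prg(..., wait=True). A worked sketch of the arithmetic; the per-iteration FLOP count is accumulated where gen is built, so the value here is an assumption, not from this hunk:

G = 256
loop = 100000                  # iterations inside prog.s, per the multiplier above
flops_per_thread_per_iter = 4  # assumed: what the commented-out FLOPS += 4 would give
FLOPS = flops_per_thread_per_iter * loop * G * G
tm = 0.1                       # e.g. a 100 ms run
print(f"ran in {tm*1e3:.2f} ms, {FLOPS/(tm*1e9):.2f} GFLOPS")  # -> 262.14 GFLOPS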
setup.py

@@ -19,7 +19,7 @@ setup(name='tinygrad',
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: MIT License"
      ],
-     install_requires=['numpy', 'requests', 'pillow', 'tqdm', 'networkx', 'pyopencl'],
+     install_requires=['numpy', 'requests', 'pillow', 'tqdm', 'networkx', 'pyopencl', 'PyYAML'],
      python_requires='>=3.8',
      extras_require={
        'llvm': ["llvmlite"],
@@ -41,6 +41,7 @@ setup(name='tinygrad',
          "opencv-python",
          "tabulate",
          "safetensors",
+         "types-PyYAML",
        ],
      },
      include_package_data=True)
@@ -309,7 +309,7 @@ class TestOps(unittest.TestCase):
   def test_sum_full(self):
     helper_test_op([(16384)], lambda x: x.sum(), lambda x: x.sum())
   def test_sum_small_full(self):
-    helper_test_op([(45,3)], lambda x: x.sum(), Tensor.sum)
+    helper_test_op([(45,5)], lambda x: x.sum(), Tensor.sum)
   def test_sum_relu(self):
     helper_test_op([(3,4,5)], lambda x: x.relu().sum().relu(), lambda x: x.relu().sum().relu())
   def test_sum(self):
@@ -877,7 +877,7 @@ class TestOps(unittest.TestCase):
       lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(2,2), stride=stride),
       lambda x: Tensor.max_pool2d(x, kernel_size=(2,2), stride=stride))

-  @unittest.skipIf(Device.DEFAULT == "CUDA", "CUDA fails on this")
+  @unittest.skipIf(Device.DEFAULT in ["CUDA", "PTX"], "CUDA fails on this")
   def test_maxpool2d_unit_stride(self):
     helper_test_op([(32,2,110,28)],
       lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(5,5), stride=1),
@@ -1,5 +1,5 @@
 from typing import Tuple, List, NamedTuple, Any, Dict, Optional, Union, DefaultDict
-from tinygrad.codegen.linearizer import Linearizer, UOps
+from tinygrad.codegen.linearizer import Linearizer, UOps, Token
 from tinygrad.ops import ASTRunner, FusedOps, BinaryOps, UnaryOps
 from tinygrad.helpers import DType, dtypes, DEBUG
 from tinygrad.shape.symbolic import Variable, NumNode, MulNode, DivNode, ModNode, LtNode, SumNode, AndNode
@@ -7,13 +7,19 @@ import functools
 import math
 from collections import defaultdict

-type_to_letter = {dtypes.float32: 'f', dtypes.bool: 'p', dtypes.int32: 'i', dtypes.int64: 'a', dtypes.uint32: 'I', dtypes.uint64: 'A'}
+_type_to_letter = {dtypes.float32: 'f', dtypes.bool: 'p', dtypes.int32: 'i', dtypes.int64: 'a', dtypes.uint32: 'u', dtypes.uint64: 'b', dtypes._float4: 'x'}
+def type_to_letter(x): return _type_to_letter[x[0]].upper() if x[1] else _type_to_letter[x[0]]

 class Register(NamedTuple):
   nm:str
   dtype:DType
-  def __repr__(self): return self.nm
+  scalar:bool
+  off:Optional[int] = None
+  def __repr__(self): return self.nm if self.off is None else f"{self.nm}:{self.off}"
+  def subregs(self):
+    if self.dtype == dtypes._float4:
+      return [Register(self.nm, dtypes.float, False, off=off) for off in range(4)]
+    return []
 class AssemblyInstruction(NamedTuple):
   op: UOps
   out: Optional[Register]
@@ -23,6 +29,8 @@ class AssemblyInstruction(NamedTuple):
 # warp size of 32, s registers are shared across the warp, v are 32-wide vectors
 class AssemblyCodegen(Linearizer):
   supports_load3: bool = False
+  sin_is_sin2pi: bool = False
+  no_div: bool = False

   def specialize(self, asm:List[AssemblyInstruction]) -> Tuple[str, str]:
     raise NotImplementedError("must be implemented")
@@ -34,24 +42,28 @@ class AssemblyCodegen(Linearizer):
     self.limit_global_dims(3) # all GPU asms have 3 (for now)
     self.linearize()

-    cnts:DefaultDict[DType, int] = defaultdict(int)
+    cnts:DefaultDict[Tuple[DType, bool], int] = defaultdict(int)
     tor: Dict[Any, Register] = {}
-    def newreg(tok, dtype=dtypes.float32):
+    def newreg(tok, dtype=dtypes.float32, scalar=False):
       nonlocal cnts, tor
-      tor[tok] = ret = Register(f"%{type_to_letter[dtype]}{cnts[dtype]}", dtype)
-      cnts[dtype] += 1
+      if isinstance(tok, Token): dtype = tok.dtype # this
+      tor[tok] = ret = Register(f"%{type_to_letter((dtype, scalar))}{cnts[(dtype, scalar)]}", dtype, scalar)
+      if dtype == dtypes._float4:
+        for off in range(4):
+          tor[Token(tok.name, tok.dtype, off)] = Register(ret.nm, dtypes.float, ret.scalar, off)
+      cnts[(dtype, scalar)] += 1
       return ret

     def render_numnode(b):
       key = ("num", b)
-      if key not in tor: ins.append(AssemblyInstruction(UOps.CONST, newreg(key, dtype=dtypes.int32), [], b))
+      if key not in tor: ins.append(AssemblyInstruction(UOps.CONST, newreg(key, scalar=True, dtype=dtypes.int32), [], b))
       return tor[key]

     def render_alu(op, a:Register, b:Union[Register, int, float], dtype=dtypes.int32) -> Register:
       key = (op, a, b)
       if key not in tor:
-        #if not isinstance(b, Register): b = render_numnode(b)
-        ins.append(AssemblyInstruction(UOps.ALU, newreg(key, dtype=dtype), [a, b], op))
+        if not isinstance(b, Register): b = render_numnode(b)
+        ins.append(AssemblyInstruction(UOps.ALU, newreg(key, dtype=dtype, scalar=a.scalar and (not isinstance(b, Register) or b.scalar)), [a, b], op))
       return tor[key]

     def render_cast(a:Register, new_dtype:DType) -> Register:
@@ -72,25 +84,29 @@ class AssemblyCodegen(Linearizer):
     def addr_w_offset(args):
       idx = args.idx*self.bufs[args.i].dtype.itemsize
       off = 0 # TODO: should this be None?
-      if isinstance(idx, SumNode) and not self.supports_load3:
+      if isinstance(idx, SumNode):
         nums = [n.b for n in idx.nodes if isinstance(n, NumNode)]
-        if len(nums) > 0:
+        if len(nums) > 0 and nums[0] < 4096 and (idx-nums[0]).min >= 0: # TODO: different for each GPU?
           idx -= nums[0]
           off = nums[0]
       reg = idx.render(render_ops)
       if self.supports_load3:
-        return tor[f"buf{args.i}"], reg
+        if reg.scalar:
+          new_reg = newreg((reg.nm, 'vec'), dtype=reg.dtype)
+          ins.append(AssemblyInstruction(UOps.ALU, new_reg, [reg], UnaryOps.NOOP))
+          reg = new_reg
+        return tor[f"buf{args.i}"], reg, off
       else:
         reg = render_alu(BinaryOps.ADD, render_cast(reg, dtypes.uint64), tor[f"buf{args.i}"], dtype=dtypes.uint64)
-        return reg, off
+        return reg, None, off

     ins = []
-    ins += [AssemblyInstruction(UOps.SPECIAL, newreg(f"buf{i}", dtype=dtypes.uint64), [], f"buf{i}") for i in range(len(self.bufs))]
+    ins += [AssemblyInstruction(UOps.SPECIAL, newreg(f"buf{i}", dtype=dtypes.uint64, scalar=True), [], f"buf{i}") for i in range(len(self.bufs))]
     global_size, local_size = [], []
     skipload_branch = 0
     for uop,newvar,vin,args in self.uops:
       if uop == UOps.CONST and newvar is not None:
-        ins.append(AssemblyInstruction(UOps.CONST, newreg(newvar), [], args))
+        ins.append(AssemblyInstruction(UOps.CONST, newreg(newvar, dtype=newvar.dtype), [], args))
       elif uop == UOps.DEFINE_LOCAL:
         ins.append(AssemblyInstruction(UOps.DEFINE_LOCAL, None, [], args))
         ins.append(AssemblyInstruction(UOps.ALU, newreg("buf-1", dtype=dtypes.uint64), [args[0]], UnaryOps.NOOP))
@@ -107,38 +123,48 @@ class AssemblyCodegen(Linearizer):
         else:
           for var in args[0]:
             if not isinstance(var, NumNode): # TODO: why is this coming through?
-              ins.append(AssemblyInstruction(UOps.CONST, newreg(var, dtype=dtypes.int32), [], 0))
+              ins.append(AssemblyInstruction(UOps.CONST, newreg(var, dtype=dtypes.int32, scalar=True), [], 0))
               ins.append(AssemblyInstruction(UOps.LABEL, None, [], "$loop_"+var.expr))
       elif uop == UOps.ENDLOOP:
         if args[1] not in ["global", "local"]:
           for var in reversed(args[0]):
             if not isinstance(var, NumNode): # TODO: why is this coming through?
-              pred = render_alu(BinaryOps.CMPLT, tor[var], var.max, dtypes.bool)
               ins.append(AssemblyInstruction(UOps.ALU, tor[var], [tor[var], 1], BinaryOps.ADD))
+              pred = render_alu(BinaryOps.CMPLT, tor[var], var.max+1, dtypes.bool)
               ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], ("$loop_"+var.expr, True)))
+      elif uop == UOps.CAST and newvar is not None:
+        # TODO: we should reconsider outputting CAST in the linearizer. these are needless copies
+        out = newreg(newvar)
+        for i,sr in enumerate(out.subregs()):
+          ins.append(AssemblyInstruction(UOps.ALU, sr, [tor[vin[i]]], UnaryOps.NOOP))
       elif uop == UOps.ALU and newvar is not None:
         if args == FusedOps.MULACC: vin = [vin[1], vin[2], vin[0]] # TODO: reorder MULACC everywhere
+        out = newreg(newvar) if newvar not in tor else tor[newvar]
         # this is the only thing that can violate SSA
         if args in [BinaryOps.CMPEQ, BinaryOps.CMPLT]:
           pred_reg = newreg((newvar, 'pred'), dtype=dtypes.bool)
           ins.append(AssemblyInstruction(UOps.ALU, pred_reg, [tor[x] for x in vin], args))
-          ins.append(AssemblyInstruction(UOps.CAST, newreg(newvar), [pred_reg], args))
+          ins.append(AssemblyInstruction(UOps.CAST, out, [pred_reg], args))
         elif args == BinaryOps.POW:
           # TODO: add UnaryOps.SQRT
           tmp = newreg((newvar, "exp_a"))
           tmp2 = newreg((newvar, "exp_a_times_b"))
           ins.append(AssemblyInstruction(UOps.ALU, tmp, [tor[vin[0]]], UnaryOps.LOG2))
           ins.append(AssemblyInstruction(UOps.ALU, tmp2, [tmp, tor[vin[1]]], BinaryOps.MUL))
-          ins.append(AssemblyInstruction(UOps.ALU, newreg(newvar), [tmp2], UnaryOps.EXP2))
-        elif args == UnaryOps.SIN and hasattr(self, 'sin_is_sin2pi'):
+          ins.append(AssemblyInstruction(UOps.ALU, out, [tmp2], UnaryOps.EXP2))
+        elif args == BinaryOps.DIV and self.no_div:
+          tmp = newreg((newvar, "rcp"))
+          ins.append(AssemblyInstruction(UOps.ALU, tmp, [tor[vin[1]]], UnaryOps.RECIP))
+          ins.append(AssemblyInstruction(UOps.ALU, out, [tor[vin[0]], tmp], BinaryOps.MUL))
+        elif args == UnaryOps.SIN and self.sin_is_sin2pi:
           tmp = newreg((newvar, "2pi"))
           ins.append(AssemblyInstruction(UOps.ALU, tmp, [tor[vin[0]], 1/(math.pi*2)], BinaryOps.MUL))
-          ins.append(AssemblyInstruction(UOps.ALU, newreg(newvar) if newvar not in tor else tor[newvar], [tmp], args))
+          ins.append(AssemblyInstruction(UOps.ALU, out, [tmp], args))
         else:
-          ins.append(AssemblyInstruction(UOps.ALU, newreg(newvar) if newvar not in tor else tor[newvar], [tor[x] for x in vin], args))
+          ins.append(AssemblyInstruction(UOps.ALU, out, [tor[x] for x in vin], args))
       elif uop == UOps.LOAD and newvar is not None:
-        idx, off = addr_w_offset(args)
-        reg = newreg(newvar)
+        idx, treg, off = addr_w_offset(args)
+        reg = newreg(newvar, dtype=newvar.dtype, scalar=(idx.scalar and (not isinstance(treg, Register) or treg.scalar))) # and not dtypes.is_float(newvar.dtype)))
        if args.valid.min == 0:
          ins.append(AssemblyInstruction(UOps.CONST, reg, [], 0))
        if args.valid.max == 1:
@@ -146,16 +172,16 @@ class AssemblyCodegen(Linearizer):
          ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], (f"$skipload_{skipload_branch}", False)))
        if args.valid.max == 1:
          # NOTE: you can't compute the index in here, because it assumes it's all available later
-          ins.append(AssemblyInstruction(UOps.LOAD, reg, [idx], (off, 'global' if args.i != -1 else 'shared')))
+          ins.append(AssemblyInstruction(UOps.LOAD, reg, [idx] + ([treg] if treg is not None else []), (off, 'global' if args.i != -1 else 'shared')))
        if args.valid.min == 0 and args.valid.max == 1:
          ins.append(AssemblyInstruction(UOps.LABEL, None, [], f"$skipload_{skipload_branch}"))
          skipload_branch += 1
      elif uop == UOps.STORE:
-        idx, off = addr_w_offset(args)
-        ins.append(AssemblyInstruction(UOps.STORE, None, [idx, tor[vin[0]]], (off, 'global' if args.i != -1 else 'shared')))
+        idx, treg, off = addr_w_offset(args)
+        ins.append(AssemblyInstruction(UOps.STORE, None, [idx, tor[vin[0]]] + ([treg] if treg is not None else []), (off, 'global' if args.i != -1 else 'shared')))

    # define registers
-    ins = [AssemblyInstruction(UOps.DEFINE_REGISTER, None, [], (dtype, type_to_letter[dtype], c)) for dtype,c in cnts.items()] + ins
+    ins = [AssemblyInstruction(UOps.DEFINE_REGISTER, None, [], (dtype, type_to_letter(dtype), c)) for dtype,c in cnts.items()] + ins

    if DEBUG >= 4:
      for tins in ins: print(tins)
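For orientation, a small illustration of the (dtype, scalar) register-naming scheme introduced above; uppercase letters mark scalar registers (this just exercises the table and function from the hunk):

from tinygrad.helpers import dtypes
from tinygrad.codegen.assembly import type_to_letter
print(type_to_letter((dtypes.float32, False)))  # 'f' -> vector float reg, e.g. %f0
print(type_to_letter((dtypes.float32, True)))   # 'F' -> scalar float reg, e.g. %F0
print(type_to_letter((dtypes._float4, False)))  # 'x' -> float4 reg, e.g. %x0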
@@ -0,0 +1,181 @@
+import yaml
+from typing import Tuple, Set, Dict
+from tinygrad.helpers import dtypes
+from tinygrad.codegen.assembly import AssemblyCodegen, Register
+from tinygrad.codegen.linearizer import UOps
+from tinygrad.ops import BinaryOps, UnaryOps, FusedOps
+from tinygrad.runtime.ops_gpu import ROCM_LLVM_PATH
+
+# ugh, is this really needed?
+from extra.helpers import enable_early_exec
+early_exec = enable_early_exec()
+
+boilerplate_start = """
+.global _start
+_start:
+.rodata
+.align 0x10
+.global code.kd
+.type code.kd,STT_OBJECT
+.amdhsa_kernel code"""
+
+code_start = """.end_amdhsa_kernel
+.text
+code:
+"""
+
+# https://github.com/RadeonOpenCompute/ROCm_Documentation/blob/master/ROCm_Compiler_SDK/ROCm-Codeobj-format.rst
+# https://github.com/ROCm-Developer-Tools/ROCm-ComputeABI-Doc/blob/master/AMDGPU-ABI.md#initial-kernel-register-state
+# RDNA3 is actually a SIMD machine!
+class RDNACodegen(AssemblyCodegen):
+  supports_float4: bool = True
+  supports_float4_alu: bool = False
+  supports_load3: bool = True
+  sin_is_sin2pi: bool = True
+  no_div: bool = True
+
+  def specialize(self, asm) -> Tuple[str, str]:
+    args = []
+    for i,b in enumerate(self.bufs): args.append({'.address_space': 'global', '.name': f'buf_{i}', '.offset': i*8, '.size': 8, '.type_name': b.dtype.name+"*", '.value_kind': 'global_buffer'})
+    ins = []
+
+    v_cnt = 3 # v[0:2] is local_xyz
+    s_cnt = 5 # s[0:1] is the address, s[2:4] is global_xyz
+
+    dtype_to_rdnatype = {dtypes.float32: "f32", dtypes.int64: "i64", dtypes.int32: "i32", dtypes.uint64: "u64", dtypes.bool: "i32"}
+    alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", FusedOps.MULACC: "fma",
+           BinaryOps.MAX: "max", UnaryOps.RECIP: "rcp",
+           UnaryOps.NOOP: "mov", UnaryOps.SIN: "sin", UnaryOps.LOG2: "log", UnaryOps.EXP2: "exp",
+           BinaryOps.CMPEQ: "cmp_eq", BinaryOps.CMPLT: "cmp_lt"}
+
+    pend_regs:Set[Register] = set()
+    rtor:Dict[Register, str] = {}
+    def reg_in(x):
+      nonlocal pend_regs
+      #print("reg_in", x, rtor[x], pend_regs)
+      if x in pend_regs:
+        #print("clear")
+        ins.append('s_waitcnt lgkmcnt(0), vmcnt(0)')
+        pend_regs.clear()
+      return rtor[x]
+    def reg_out(x):
+      return rtor[x]
+    for uop, out, vin, arg in asm:
+      if uop == UOps.DEFINE_REGISTER:
+        if arg[0][0] == dtypes.uint64 and arg[0][1]:
+          # assuming these are scalar
+          s_cnt += s_cnt%2 # aligned(2)
+          for i in range(arg[2]):
+            rtor[Register(f"%{arg[1]}{i}", *arg[0])] = f"s[{s_cnt}:{s_cnt+1}]"
+            s_cnt += 2
+        elif arg[0][0] == dtypes._float4 and not arg[0][1]:
+          v_cnt += (4-v_cnt%4) if v_cnt%4 != 0 else 0
+          for i in range(arg[2]):
+            rtor[Register(f"%{arg[1]}{i}", *arg[0])] = f"v[{v_cnt}:{v_cnt+3}]"
+            for off in range(4): rtor[Register(f"%{arg[1]}{i}", dtypes.float, False, off=off)] = f"v{v_cnt+off}"
+            v_cnt += 4
+        elif arg[0][0] in [dtypes.int32, dtypes.float32]:
+          for i in range(arg[2]):
+            if arg[0][1]:
+              rtor[Register(f"%{arg[1]}{i}", *arg[0])] = f"s{s_cnt}"
+              s_cnt += 1
+            else:
+              rtor[Register(f"%{arg[1]}{i}", *arg[0])] = f"v{v_cnt}"
+              v_cnt += 1
+        elif arg[0][0] == dtypes.bool and arg[0][1]:
+          for i in range(arg[2]):
+            rtor[Register(f"%{arg[1]}{i}", *arg[0])] = "scc" if arg[0][1] else "vcc"
+        else:
+          raise NotImplementedError(arg)
+      elif uop == UOps.SPECIAL:
+        if arg.startswith('buf'):
+          i = int(arg[3:])
+          ins.append(f's_load_b64 {reg_out(out)}, s[0:1], {i*8}')
+          pend_regs.add(out)
+          for r in out.subregs(): pend_regs.add(r)
+        elif arg.startswith('gid'):
+          ins.append(f'v_mov_b32 {reg_out(out)}, s{2+int(arg[3])}')
+          # the docs lied, this is actually y
+          if int(arg[3]) == 2: ins.append("v_bfe_u32 v2, v0, 20, 10") # untested
+          if int(arg[3]) == 1: ins.append("v_bfe_u32 v1, v0, 10, 10")
+          elif int(arg[3]) == 0: ins.append("v_and_b32_e32 v0, 0x3ff, v0")
+          # get local size
+          offset = len(args)*8
+          args.append({".offset": offset, ".value_kind": f"hidden_group_size_{'xyz'[int(arg[3])]}", ".size": 8})
+          ins.append(f's_load_b32 s{2+int(arg[3])}, s[0:1], {offset}')
+          ins.append('s_waitcnt vmcnt(0) lgkmcnt(0)')
+          pend_regs.clear()
+          ins.append(f'v_mul_i32_i24 {reg_out(out)}, {reg_out(out)}, s{2+int(arg[3])}')
+          ins.append(f'v_add_nc_u32 {reg_out(out)}, v{int(arg[3])}, {reg_out(out)}')
+      elif uop == UOps.CONST:
+        if arg == float('inf'): arg = "0x7f800000"
+        elif arg == float('-inf'): arg = "0xff800000"
+        if out.dtype == dtypes._float4:
+          for off in range(4):
+            ins.append(f"{'s_' if out.scalar else 'v_'}mov_b32 {reg_out(Register(out.nm, dtypes.float, False, off=off))}, {arg}")
+        else:
+          ins.append(f"{'s_' if out.scalar else 'v_'}mov_b32 {reg_out(out)}, {arg}")
+      elif uop == UOps.ALU:
+        if arg == BinaryOps.CMPLT:
+          if out.scalar:
+            ins.append(f"s_{alu[arg]}_{dtype_to_rdnatype[out.dtype]} {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}")
+          else:
+            ins.append(f"v_cmp_lt_{dtype_to_rdnatype[out.dtype]} vcc, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}")
+        else:
+          alu_arg = alu[arg]
+          if arg == FusedOps.MULACC and out == vin[2]:
+            alu_arg = "fmac"
+            vin = vin[0:2]
+          if out.dtype == dtypes._float4:
+            tins = []
+            for rr in zip(*[x.subregs() if x.dtype == dtypes._float4 else [x,x,x,x] for x in [out]+vin]):
+              tins.append(f"{'s_' if rr[0].scalar else 'v_'}dual_{alu_arg}_{dtype_to_rdnatype[rr[0].dtype]} {reg_out(rr[0])}, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in rr[1:])}")
+            ins.append(tins[0] + " :: " + tins[1])
+            ins.append(tins[2] + " :: " + tins[3])
+          else:
+            ins.append(f"{'s_' if out.scalar else 'v_'}{alu_arg}_{dtype_to_rdnatype[out.dtype] if arg != UnaryOps.NOOP else 'b32'}{'_i24' if arg == BinaryOps.MUL and out.dtype != dtypes.float32 and not out.scalar else ''} {reg_out(out)}, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}")
+      elif uop == UOps.LOAD:
+        if out.scalar:
+          # swap arg order
+          ins.append(f's_load_b32 {reg_out(out)}, {reg_in(vin[0])}, {reg_in(vin[1])} offset:{arg[0]}')
+        else:
+          ins.append(f'global_load_{"b128" if out.dtype == dtypes._float4 else "b32"} {reg_out(out)}, {reg_in(vin[1])}, {reg_in(vin[0])} offset:{arg[0]}')
+        pend_regs.add(out)
+        for r in out.subregs(): pend_regs.add(r)
+      elif uop == UOps.STORE:
+        ins.append(f'global_store_{"b128" if vin[1].dtype == dtypes._float4 else "b32"} {reg_in(vin[2])}, {reg_in(vin[1])}, {reg_in(vin[0])} offset:{arg[0]}')
+      elif uop == UOps.LABEL:
+        ins.append(f"{arg}:")
+      elif uop == UOps.COND_BRANCH:
+        ins.append(f"s_cbranch_scc{'1' if arg[1] else '0'} {arg[0]}")
+      else:
+        raise NotImplementedError(uop)
+
+    ins += ['s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)', 's_endpgm', 's_code_end']
+    return 'code', self.assemble(args, ins, v_cnt, s_cnt)
+
+  def assemble(self, args, ins, v_cnt, s_cnt):
+    kernel_desc = {'.amdhsa_group_segment_fixed_size': 0, '.amdhsa_private_segment_fixed_size': 0, '.amdhsa_kernarg_size': 0,
+                   '.amdhsa_next_free_vgpr': v_cnt, # this matters!
+                   '.amdhsa_reserve_vcc': 0, '.amdhsa_reserve_xnack_mask': 0,
+                   '.amdhsa_next_free_sgpr': s_cnt,
+                   '.amdhsa_float_round_mode_32': 0, '.amdhsa_float_round_mode_16_64': 0, '.amdhsa_float_denorm_mode_32': 3, '.amdhsa_float_denorm_mode_16_64': 3, '.amdhsa_dx10_clamp': 1, '.amdhsa_ieee_mode': 1,
+                   '.amdhsa_fp16_overflow': 0, '.amdhsa_workgroup_processor_mode': 1, '.amdhsa_memory_ordered': 1, '.amdhsa_forward_progress': 0, '.amdhsa_enable_private_segment': 0,
+                   '.amdhsa_system_sgpr_workgroup_id_x': 1, '.amdhsa_system_sgpr_workgroup_id_y': 1, '.amdhsa_system_sgpr_workgroup_id_z': 1,
+                   '.amdhsa_system_sgpr_workgroup_info': 0, '.amdhsa_system_vgpr_workitem_id': 2, # is amdhsa_system_vgpr_workitem_id real?
+                   '.amdhsa_exception_fp_ieee_invalid_op': 0, '.amdhsa_exception_fp_denorm_src': 0, '.amdhsa_exception_fp_ieee_div_zero': 0, '.amdhsa_exception_fp_ieee_overflow': 0, '.amdhsa_exception_fp_ieee_underflow': 0,
+                   '.amdhsa_exception_fp_ieee_inexact': 0, '.amdhsa_exception_int_div_zero': 0, '.amdhsa_user_sgpr_dispatch_ptr': 0, '.amdhsa_user_sgpr_queue_ptr': 0, '.amdhsa_user_sgpr_kernarg_segment_ptr': 1,
+                   '.amdhsa_user_sgpr_dispatch_id': 0, '.amdhsa_user_sgpr_private_segment_size': 0, '.amdhsa_wavefront_size32': 1, '.amdhsa_uses_dynamic_stack': 0}
+
+    metadata = {'amdhsa.kernels': [{'.args': args,
+      '.group_segment_fixed_size': 0, '.kernarg_segment_align': 8, '.kernarg_segment_size': args[-1][".offset"] + args[-1][".size"],
+      '.language': 'OpenCL C', '.language_version': [1, 2], '.max_flat_workgroup_size': 256,
+      '.name': 'code', '.private_segment_fixed_size': 0, '.sgpr_count': s_cnt, '.sgpr_spill_count': 0,
+      '.symbol': 'code.kd', '.uses_dynamic_stack': False, '.vgpr_count': v_cnt, '.vgpr_spill_count': 0,
+      '.wavefront_size': 32}],
+      'amdhsa.target': 'amdgcn-amd-amdhsa--gfx1100', 'amdhsa.version': [1, 2]}
+
+    code = boilerplate_start + "\n" + '\n'.join("%s %d" % x for x in kernel_desc.items()) + "\n" + code_start + '\n'.join(ins) + "\n.amdgpu_metadata\n" + yaml.dump(metadata) + ".end_amdgpu_metadata"
+    obj = early_exec(([ROCM_LLVM_PATH / "llvm-mc", '--arch=amdgcn', '--mcpu=gfx1100', '--triple=amdgcn-amd-amdhsa', '--filetype=obj', '-'], code.encode("utf-8")))
+    asm = early_exec(([ROCM_LLVM_PATH / "ld.lld", "/dev/stdin", "-o", "/dev/stdout", "--pie"], obj))
+    return asm
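On sin_is_sin2pi above: RDNA3's v_sin_f32 takes its input in turns (it computes sin(2πx)), which is why the codegen multiplies the argument by 1/(2π) before emitting the sin. A quick numerical check of the identity being relied on:

import math
x = 1.234
t = x * (1 / (math.pi * 2))  # the extra MUL emitted when sin_is_sin2pi is set
assert abs(math.sin(2 * math.pi * t) - math.sin(x)) < 1e-12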
@@ -172,8 +172,8 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
    [', '.join([f'{t} {bufnames[i]}' for i,t in buftypes] + lang.extra_args)] +
    [") {\n"] + list(prekernel) + ['\n'.join(kernel), "\n}"])

-  if lang.half_prekernel: prg = ''.join([f"{lang.half_prekernel}", "\n", prg])
-  if lang.double_prekernel: prg = ''.join([f"{lang.double_prekernel}", "\n", prg])
+  if lang.half_prekernel and any(x.dtype == dtypes.float16 for x in bufs): prg = ''.join([f"{lang.half_prekernel}", "\n", prg])
+  if lang.double_prekernel and any(x.dtype == dtypes.float64 for x in bufs): prg = ''.join([f"{lang.double_prekernel}", "\n", prg])
  return prg, global_size, local_size

class CStyleCodegen(Linearizer):
@@ -19,6 +19,7 @@ def partition(lst, fxn): return [x for x in lst if fxn(x)], [x for x in lst if n
 def make_pair(x:Union[int, Tuple[int, ...]], cnt=2) -> Tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else x
 def flatten(l:Iterator): return [item for sublist in l for item in sublist]
 def mnum(i) -> str: return str(i) if i >= 0 else f"m{-i}"
+def fromimport(mod, frm): return getattr(__import__(mod, fromlist=[frm]), frm)

 @functools.lru_cache(maxsize=None)
 def getenv(key, default=0): return type(default)(os.getenv(key, default))
@@ -74,7 +75,7 @@ class dtypes:
   @staticmethod # static methds on top, or bool in the type info will refer to dtypes.bool
   def is_int(x: DType)-> bool: return x in (dtypes.int8, dtypes.uint8, dtypes.int32, dtypes.int64)
   @staticmethod
-  def is_float(x: DType) -> bool: return x in (dtypes.float16, dtypes.float32, dtypes.float64)
+  def is_float(x: DType) -> bool: return x in (dtypes.float16, dtypes.float32, dtypes.float64, dtypes._half4, dtypes._float4)
   @staticmethod
   def is_unsigned(x: DType) -> bool: return x in (dtypes.uint8, dtypes.uint32, dtypes.uint64)
   @staticmethod
@@ -87,10 +88,10 @@ class dtypes:
   float64: Final[DType] = DType(5, 8, "double", np.float64)
   int8: Final[DType] = DType(0, 1, "char", np.int8)
   int32: Final[DType] = DType(1, 4, "int", np.int32)
-  int64: Final[DType] = DType(2, 8, "int64", np.int64)
+  int64: Final[DType] = DType(2, 8, "long", np.int64)
   uint8: Final[DType] = DType(0, 1, "uchar", np.uint8)
   uint32: Final[DType] = DType(1, 4, "uint", np.uint32)
-  uint64: Final[DType] = DType(2, 8, "uint64", np.uint64)
+  uint64: Final[DType] = DType(2, 8, "ulong", np.uint64)

   # NOTE: these are internal dtypes, should probably check for that
   _half4: Final[DType] = DType(0, 2*4, "half4", None, 4)
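The new fromimport helper above is a one-line lazy "from mod import frm"; it is used below so the assembly codegens are only imported when PTX=1 or RDNA=1 is set (per the commit note "import only if selected"). A tiny illustration:

from tinygrad.helpers import fromimport
pi = fromimport("math", "pi")  # equivalent to: from math import pi
assert pi == __import__("math").pi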
@@ -9,7 +9,8 @@ from tinygrad.runtime.lib import RawBuffer, RawConst
 # these are the llops your accelerator must implement, along with toCpu
 # the Enum class doesn't work with mypy, this is static. sorry it's ugly
 # NOTE: MOD, CMPLT don't have to be implemented on vectors, just scalars
-class UnaryOps(Enum): NOOP = auto(); EXP2 = auto(); LOG2 = auto(); CAST = auto(); SIN = auto() # noqa: E702
+# NOTE: rdna3 only has RECIP and not DIV. DIV and POW are on the chopping block
+class UnaryOps(Enum): NOOP = auto(); EXP2 = auto(); LOG2 = auto(); CAST = auto(); SIN = auto(); RECIP = auto() # noqa: E702
 class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); POW = auto(); CMPEQ = auto(); MAX = auto(); MOD = auto(); CMPLT = auto() # noqa: E702
 class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
 class FusedOps(Enum): MULACC = auto() # noqa: E702
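A note on the new RECIP op: since RDNA3 has a reciprocal but no float divide, the no_div path in the assembly codegen lowers a/b to a*RECIP(b). A quick sanity check of the identity (the hardware reciprocal is approximate, so real results can differ in the last bits):

a, b = 3.0, 7.0
assert abs(a * (1.0 / b) - a / b) < 1e-15  # a/b == a * recip(b), up to rounding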
@@ -4,11 +4,10 @@ import numpy as np
 import pycuda.autoprimaryctx # type: ignore # pylint: disable=unused-import # noqa: F401
 import pycuda.driver as cuda # type: ignore
 from pycuda.compiler import compile as cuda_compile # type: ignore
-from tinygrad.helpers import DEBUG, getenv
+from tinygrad.helpers import DEBUG, getenv, fromimport
 from tinygrad.ops import Compiled
 from tinygrad.runtime.lib import RawBufferCopyInOut
 from tinygrad.codegen.cstyle import CStyleCodegen, CStyleLanguage
-from tinygrad.codegen.assembly_ptx import PTXCodegen

 class RawCUDABuffer(RawBufferCopyInOut):
   def __init__(self, size, dtype): super().__init__(size, dtype, cuda.mem_alloc(size * dtype.itemsize))
@@ -60,4 +59,5 @@ class CUDACodegen(CStyleCodegen):
   typedef long long int64;
   """)
   supports_float4_alu = False
-CUDABuffer = Compiled(RawCUDABuffer, PTXCodegen if getenv("PTX") else CUDACodegen, CUDAProgram, cuda.Context.synchronize)
+
+CUDABuffer = Compiled(RawCUDABuffer, fromimport("tinygrad.codegen.assembly_ptx", "PTXCodegen") if getenv("PTX") else CUDACodegen, CUDAProgram, cuda.Context.synchronize)
@@ -3,7 +3,7 @@ import pathlib
 import numpy as np
 import pyopencl as cl # type: ignore
 from typing import Optional, List
-from tinygrad.helpers import DEBUG, getenv, prod, ImageDType, OSX, dtypes
+from tinygrad.helpers import DEBUG, getenv, prod, ImageDType, OSX, dtypes, fromimport
 from tinygrad.ops import Compiled
 from tinygrad.runtime.lib import RawBufferCopyInOut
 from tinygrad.codegen.cstyle import CStyleCodegen, CStyleLanguage
@@ -61,7 +61,7 @@ class CLProgram:
     if 'Adreno' in CL.cl_ctx.devices[0].name:
       from disassemblers.adreno import disasm
       disasm(self.binary())
-    elif 'gfx1100' in CL.cl_ctx.devices[0].name:
+    elif CL.cl_ctx.devices[0].name.startswith('gfx'):
       asm = early_exec(([ROCM_LLVM_PATH / "llvm-objdump", '-d', '-'], self.binary()))
       print('\n'.join([x for x in asm.decode('utf-8').split("\n") if 's_code_end' not in x]))
     else:
@@ -87,11 +87,12 @@ class CLProgram:

 class CLCodegen(CStyleCodegen):
   lang = CStyleLanguage(
-    kernel_prefix = "#define int64 long\n__kernel", buffer_prefix = "__global ", smem_prefix = "__local ",
+    kernel_prefix = "__kernel", buffer_prefix = "__global ", smem_prefix = "__local ",
     double_prekernel="#ifdef cl_khr_fp64\n#pragma OPENCL EXTENSION cl_khr_fp64 : enable\n#elif defined(cl_amd_fp64)\n#pragma OPENCL EXTENSION cl_amd_fp64 : enable\n#endif",
     half_prekernel = "#pragma OPENCL EXTENSION cl_khr_fp16 : enable",
     barrier = "barrier(CLK_LOCAL_MEM_FENCE);", float4 = "(float4)",
     gid = [f'get_global_id({i})' for i in range(3)], lid = [f'get_local_id({i})' for i in range(3)], uses_vload=True)
   supports_float4_alu = True
   supports_float4 = True
-GPUBuffer = Compiled(CLBuffer, CLCodegen, CLProgram, CL.synchronize)
+
+GPUBuffer = Compiled(CLBuffer, fromimport("tinygrad.codegen.assembly_rdna", "RDNACodegen") if getenv("RDNA") else CLCodegen, CLProgram, CL.synchronize)
@@ -80,7 +80,7 @@ class MetalProgram:

 class MetalCodegen(CStyleCodegen):
   lang = CStyleLanguage(
-    kernel_prefix = "#include <metal_stdlib>;\n#define int64 long\nusing namespace metal;\nkernel", buffer_prefix = "device ", smem_prefix = "threadgroup ",
+    kernel_prefix = "#include <metal_stdlib>\nusing namespace metal;\nkernel", buffer_prefix = "device ", smem_prefix = "threadgroup ",
     barrier = "threadgroup_barrier(mem_flags::mem_threadgroup);", float4 = "float4",
     gid = [f"gid.{chr(120+i)}" for i in range(3)], lid = [f"lid.{chr(120+i)}" for i in range(3)],
     extra_args = ['uint3 gid [[thread_position_in_grid]]', 'uint3 lid [[thread_position_in_threadgroup]]'])