
RDNA assembly backend ($1000 bounty) (#787)

* Revert "Revert "ops rdna""

This reverts commit 0400315078.

* Revert "Revert "writing 2""

This reverts commit 325a3bf2cf.

* no dump

* 2x 2

* simple asm

* local size

* sub

* lil work

* support args != 3

* assembler work

* generate that

* ptx assembler

* begin index renderer

* max

* ptx loops

* gemms work

* valid works

* asm working a bit more

* close

* passing all ops tests

* ptx is a codegen only, not a backend

* ptx

* float16 support

* rdna goes here

* install types

* make amd disassemble

* ansilen for pretty print

* fix ptx log2/exp2

* assemblyinstruction

* new asm

* working gemm

* fix cmp

* more passing

* mod

* ptx works again

* rdna3 add works

* log exp

* sin is sin 2pi

* fix types

* progress

* loops work

* rdna xyz

* better addressing

* cleanups

* handle exception in early process

* div support

* rdna float4

* locals work

* fix neg index

* cast

* smaller diff

* yaml

* import only if selected

* fromimport

* types

* this all needs rewriting

* a few more
pull/990/head
George Hotz 2023-06-16 09:33:18 -07:00 committed by GitHub
parent dca084f227
commit ba56ee6020
12 changed files with 272 additions and 53 deletions

View File

@@ -5,7 +5,10 @@ import multiprocessing
def _early_exec_process(qin, qout):
while True:
path, inp = qin.get()
qout.put(subprocess.check_output(path, input=inp))
try:
qout.put(subprocess.check_output(path, input=inp))
except subprocess.CalledProcessError as e:
qout.put(e)
def enable_early_exec():
qin: multiprocessing.Queue = multiprocessing.Queue()
@@ -15,7 +18,9 @@ def enable_early_exec():
p.start()
def early_exec(x):
qin.put(x)
return qout.get()
ret = qout.get()
if isinstance(ret, Exception): raise ret
else: return ret
return early_exec
def proc(itermaker, q) -> None:
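
A note on the hunk above: the helper process used to die on any failing subprocess, leaving the parent blocked forever on qout.get(). Since subprocess.CalledProcessError pickles cleanly, the worker can ship the exception object itself through the result queue and the parent re-raises it at the call site. A minimal standalone sketch of the same round-trip (names here are illustrative, not tinygrad API):

import multiprocessing, subprocess

def _worker(qin, qout):
    while True:
        cmd, inp = qin.get()
        try:
            qout.put(subprocess.check_output(cmd, input=inp))
        except subprocess.CalledProcessError as e:
            qout.put(e)  # the exception object rides the queue like any result

if __name__ == "__main__":
    qin, qout = multiprocessing.Queue(), multiprocessing.Queue()
    multiprocessing.Process(target=_worker, args=(qin, qout), daemon=True).start()
    qin.put((["false"], b""))  # a command that always exits nonzero
    ret = qout.get()
    if isinstance(ret, Exception): raise ret  # surfaces in the parent, as early_exec now does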

View File

@@ -24,8 +24,10 @@ code = open(pathlib.Path(__file__).parent / "prog.s", "r").read()
gen = []
FLOPS = 0
for j in range(4):
for i in range(0, 251, 6):
#MAX_REG = 251
MAX_REG = 32
for j in range(1):
for i in range(0, MAX_REG, 6):
#gen.append(f"v_dual_fmac_f32 v{i+0}, v{i+1}, v{i+2} :: v_dual_fmac_f32 v{i+3}, v{i+4}, v{i+5}")
#FLOPS += 4
gen.append(f"v_dual_dot2acc_f32_f16 v{i+0}, v{i+1}, v{i+2} :: v_dual_dot2acc_f32_f16 v{i+3}, v{i+4}, v{i+5}")
@@ -48,9 +50,10 @@ print(colored("creating CLProgram", "green"))
prg = CLProgram("code", asm, binary=True)
print(colored("running program", "green"))
FLOPS *= 100000*1024*1024 # loop * global_size
G = 256
FLOPS *= 100000*G*G # loop * global_size
for i in range(3):
tm = prg([1024, 1024], [256, 1], buf, wait=True)
tm = prg([G, G], [256, 1], buf, wait=True)
print(f"ran in {tm*1e3:.2f} ms, {FLOPS/(tm*1e9):.2f} GFLOPS")
print(colored("transferring buffer", "green"))
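
For context, the GFLOPS line is plain bookkeeping: FLOPs counted while emitting the unrolled body, multiplied by the 100000-iteration loop and the G*G global items. A sketch of the arithmetic with hypothetical per-instruction numbers (the FLOPS += for the dot2acc path is outside this hunk, so the 4 below is an assumption):

# range(0, 32, 6) emits 6 dual instructions; assume 4 FLOPs counted per emit
flops_per_pass = 6 * 4
LOOP, G = 100000, 256
total_flops = flops_per_pass * LOOP * G * G
tm = 0.010  # stand-in for a measured runtime in seconds
print(f"ran in {tm*1e3:.2f} ms, {total_flops/(tm*1e9):.2f} GFLOPS")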

View File

@@ -19,7 +19,7 @@ setup(name='tinygrad',
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License"
],
install_requires=['numpy', 'requests', 'pillow', 'tqdm', 'networkx', 'pyopencl'],
install_requires=['numpy', 'requests', 'pillow', 'tqdm', 'networkx', 'pyopencl', 'PyYAML'],
python_requires='>=3.8',
extras_require={
'llvm': ["llvmlite"],
@@ -41,6 +41,7 @@ setup(name='tinygrad',
"opencv-python",
"tabulate",
"safetensors",
"types-PyYAML",
],
},
include_package_data=True)

View File

@@ -309,7 +309,7 @@ class TestOps(unittest.TestCase):
def test_sum_full(self):
helper_test_op([(16384)], lambda x: x.sum(), lambda x: x.sum())
def test_sum_small_full(self):
helper_test_op([(45,3)], lambda x: x.sum(), Tensor.sum)
helper_test_op([(45,5)], lambda x: x.sum(), Tensor.sum)
def test_sum_relu(self):
helper_test_op([(3,4,5)], lambda x: x.relu().sum().relu(), lambda x: x.relu().sum().relu())
def test_sum(self):
@@ -877,7 +877,7 @@ class TestOps(unittest.TestCase):
lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(2,2), stride=stride),
lambda x: Tensor.max_pool2d(x, kernel_size=(2,2), stride=stride))
@unittest.skipIf(Device.DEFAULT == "CUDA", "CUDA fails on this")
@unittest.skipIf(Device.DEFAULT in ["CUDA", "PTX"], "CUDA fails on this")
def test_maxpool2d_unit_stride(self):
helper_test_op([(32,2,110,28)],
lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(5,5), stride=1),

View File

@@ -1,5 +1,5 @@
from typing import Tuple, List, NamedTuple, Any, Dict, Optional, Union, DefaultDict
from tinygrad.codegen.linearizer import Linearizer, UOps
from tinygrad.codegen.linearizer import Linearizer, UOps, Token
from tinygrad.ops import ASTRunner, FusedOps, BinaryOps, UnaryOps
from tinygrad.helpers import DType, dtypes, DEBUG
from tinygrad.shape.symbolic import Variable, NumNode, MulNode, DivNode, ModNode, LtNode, SumNode, AndNode
@@ -7,13 +7,19 @@ import functools
import math
from collections import defaultdict
type_to_letter = {dtypes.float32: 'f', dtypes.bool: 'p', dtypes.int32: 'i', dtypes.int64: 'a', dtypes.uint32: 'I', dtypes.uint64: 'A'}
_type_to_letter = {dtypes.float32: 'f', dtypes.bool: 'p', dtypes.int32: 'i', dtypes.int64: 'a', dtypes.uint32: 'u', dtypes.uint64: 'b', dtypes._float4: 'x'}
def type_to_letter(x): return _type_to_letter[x[0]].upper() if x[1] else _type_to_letter[x[0]]
class Register(NamedTuple):
nm:str
dtype:DType
def __repr__(self): return self.nm
scalar:bool
off:Optional[int] = None
def __repr__(self): return self.nm if self.off is None else f"{self.nm}:{self.off}"
def subregs(self):
if self.dtype == dtypes._float4:
return [Register(self.nm, dtypes.float, False, off=off) for off in range(4)]
return []
class AssemblyInstruction(NamedTuple):
op: UOps
out: Optional[Register]
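
Register names now encode both element type and scalar-ness: _type_to_letter picks the letter, type_to_letter upper-cases it for scalar (SGPR-side) values, and a _float4 register fans out into four float subregisters addressed by off. A small illustration of the naming (a sketch against the classes above, not part of the diff):

from tinygrad.helpers import dtypes
from tinygrad.codegen.assembly import Register, type_to_letter

assert type_to_letter((dtypes.int32, False)) == "i"  # vector int32
assert type_to_letter((dtypes.int32, True)) == "I"   # scalar flips to uppercase
v = Register("%x0", dtypes._float4, False)
print([str(r) for r in v.subregs()])  # ['%x0:0', '%x0:1', '%x0:2', '%x0:3']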
@@ -23,6 +29,8 @@ class AssemblyInstruction(NamedTuple):
# warp size of 32, s registers are shared across the warp, v are 32-wide vectors
class AssemblyCodegen(Linearizer):
supports_load3: bool = False
sin_is_sin2pi: bool = False
no_div: bool = False
def specialize(self, asm:List[AssemblyInstruction]) -> Tuple[str, str]:
raise NotImplementedError("must be implemented")
@@ -34,24 +42,28 @@ class AssemblyCodegen(Linearizer):
self.limit_global_dims(3) # all GPU asms have 3 (for now)
self.linearize()
cnts:DefaultDict[DType, int] = defaultdict(int)
cnts:DefaultDict[Tuple[DType, bool], int] = defaultdict(int)
tor: Dict[Any, Register] = {}
def newreg(tok, dtype=dtypes.float32):
def newreg(tok, dtype=dtypes.float32, scalar=False):
nonlocal cnts, tor
tor[tok] = ret = Register(f"%{type_to_letter[dtype]}{cnts[dtype]}", dtype)
cnts[dtype] += 1
if isinstance(tok, Token): dtype = tok.dtype # this
tor[tok] = ret = Register(f"%{type_to_letter((dtype, scalar))}{cnts[(dtype, scalar)]}", dtype, scalar)
if dtype == dtypes._float4:
for off in range(4):
tor[Token(tok.name, tok.dtype, off)] = Register(ret.nm, dtypes.float, ret.scalar, off)
cnts[(dtype, scalar)] += 1
return ret
def render_numnode(b):
key = ("num", b)
if key not in tor: ins.append(AssemblyInstruction(UOps.CONST, newreg(key, dtype=dtypes.int32), [], b))
if key not in tor: ins.append(AssemblyInstruction(UOps.CONST, newreg(key, scalar=True, dtype=dtypes.int32), [], b))
return tor[key]
def render_alu(op, a:Register, b:Union[Register, int, float], dtype=dtypes.int32) -> Register:
key = (op, a, b)
if key not in tor:
#if not isinstance(b, Register): b = render_numnode(b)
ins.append(AssemblyInstruction(UOps.ALU, newreg(key, dtype=dtype), [a, b], op))
if not isinstance(b, Register): b = render_numnode(b)
ins.append(AssemblyInstruction(UOps.ALU, newreg(key, dtype=dtype, scalar=a.scalar and (not isinstance(b, Register) or b.scalar)), [a, b], op))
return tor[key]
def render_cast(a:Register, new_dtype:DType) -> Register:
@@ -72,25 +84,29 @@ class AssemblyCodegen(Linearizer):
def addr_w_offset(args):
idx = args.idx*self.bufs[args.i].dtype.itemsize
off = 0 # TODO: should this be None?
if isinstance(idx, SumNode) and not self.supports_load3:
if isinstance(idx, SumNode):
nums = [n.b for n in idx.nodes if isinstance(n, NumNode)]
if len(nums) > 0:
if len(nums) > 0 and nums[0] < 4096 and (idx-nums[0]).min >= 0: # TODO: different for each GPU?
idx -= nums[0]
off = nums[0]
reg = idx.render(render_ops)
if self.supports_load3:
return tor[f"buf{args.i}"], reg
if reg.scalar:
new_reg = newreg((reg.nm, 'vec'), dtype=reg.dtype)
ins.append(AssemblyInstruction(UOps.ALU, new_reg, [reg], UnaryOps.NOOP))
reg = new_reg
return tor[f"buf{args.i}"], reg, off
else:
reg = render_alu(BinaryOps.ADD, render_cast(reg, dtypes.uint64), tor[f"buf{args.i}"], dtype=dtypes.uint64)
return reg, off
return reg, None, off
ins = []
ins += [AssemblyInstruction(UOps.SPECIAL, newreg(f"buf{i}", dtype=dtypes.uint64), [], f"buf{i}") for i in range(len(self.bufs))]
ins += [AssemblyInstruction(UOps.SPECIAL, newreg(f"buf{i}", dtype=dtypes.uint64, scalar=True), [], f"buf{i}") for i in range(len(self.bufs))]
global_size, local_size = [], []
skipload_branch = 0
for uop,newvar,vin,args in self.uops:
if uop == UOps.CONST and newvar is not None:
ins.append(AssemblyInstruction(UOps.CONST, newreg(newvar), [], args))
ins.append(AssemblyInstruction(UOps.CONST, newreg(newvar, dtype=newvar.dtype), [], args))
elif uop == UOps.DEFINE_LOCAL:
ins.append(AssemblyInstruction(UOps.DEFINE_LOCAL, None, [], args))
ins.append(AssemblyInstruction(UOps.ALU, newreg("buf-1", dtype=dtypes.uint64), [args[0]], UnaryOps.NOOP))
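
addr_w_offset now returns (base, index, offset) and tries to peel a small constant addend out of the address: on targets with three-operand loads, a constant k can ride in the instruction's immediate offset field instead of costing an ALU add, but only when k fits the field (4096 here, flagged by the TODO as possibly GPU-specific) and the remaining index cannot go negative. A standalone sketch of that peeling rule:

# sketch: address (variable_part + k) -> variable_part in a register, "offset:k" on the op
def split_imm_offset(rest_min: int, k: int, imm_limit: int = 4096):
    # mirrors `nums[0] < 4096 and (idx - nums[0]).min >= 0` above
    if 0 <= k < imm_limit and rest_min >= 0:
        return k  # carried by the load/store encoding
    return 0      # too wide (or would underflow): keep it in the index math

assert split_imm_offset(rest_min=0, k=256) == 256
assert split_imm_offset(rest_min=0, k=8192) == 0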
@@ -107,38 +123,48 @@ class AssemblyCodegen(Linearizer):
else:
for var in args[0]:
if not isinstance(var, NumNode): # TODO: why is this coming through?
ins.append(AssemblyInstruction(UOps.CONST, newreg(var, dtype=dtypes.int32), [], 0))
ins.append(AssemblyInstruction(UOps.CONST, newreg(var, dtype=dtypes.int32, scalar=True), [], 0))
ins.append(AssemblyInstruction(UOps.LABEL, None, [], "$loop_"+var.expr))
elif uop == UOps.ENDLOOP:
if args[1] not in ["global", "local"]:
for var in reversed(args[0]):
if not isinstance(var, NumNode): # TODO: why is this coming through?
pred = render_alu(BinaryOps.CMPLT, tor[var], var.max, dtypes.bool)
ins.append(AssemblyInstruction(UOps.ALU, tor[var], [tor[var], 1], BinaryOps.ADD))
pred = render_alu(BinaryOps.CMPLT, tor[var], var.max+1, dtypes.bool)
ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], ("$loop_"+var.expr, True)))
elif uop == UOps.CAST and newvar is not None:
# TODO: we should reconsider outputting CAST in the linearizer. these are needless copies
out = newreg(newvar)
for i,sr in enumerate(out.subregs()):
ins.append(AssemblyInstruction(UOps.ALU, sr, [tor[vin[i]]], UnaryOps.NOOP))
elif uop == UOps.ALU and newvar is not None:
if args == FusedOps.MULACC: vin = [vin[1], vin[2], vin[0]] # TODO: reorder MULACC everywhere
out = newreg(newvar) if newvar not in tor else tor[newvar]
# this is the only thing that can violate SSA
if args in [BinaryOps.CMPEQ, BinaryOps.CMPLT]:
pred_reg = newreg((newvar, 'pred'), dtype=dtypes.bool)
ins.append(AssemblyInstruction(UOps.ALU, pred_reg, [tor[x] for x in vin], args))
ins.append(AssemblyInstruction(UOps.CAST, newreg(newvar), [pred_reg], args))
ins.append(AssemblyInstruction(UOps.CAST, out, [pred_reg], args))
elif args == BinaryOps.POW:
# TODO: add UnaryOps.SQRT
tmp = newreg((newvar, "exp_a"))
tmp2 = newreg((newvar, "exp_a_times_b"))
ins.append(AssemblyInstruction(UOps.ALU, tmp, [tor[vin[0]]], UnaryOps.LOG2))
ins.append(AssemblyInstruction(UOps.ALU, tmp2, [tmp, tor[vin[1]]], BinaryOps.MUL))
ins.append(AssemblyInstruction(UOps.ALU, newreg(newvar), [tmp2], UnaryOps.EXP2))
elif args == UnaryOps.SIN and hasattr(self, 'sin_is_sin2pi'):
ins.append(AssemblyInstruction(UOps.ALU, out, [tmp2], UnaryOps.EXP2))
elif args == BinaryOps.DIV and self.no_div:
tmp = newreg((newvar, "rcp"))
ins.append(AssemblyInstruction(UOps.ALU, tmp, [tor[vin[1]]], UnaryOps.RECIP))
ins.append(AssemblyInstruction(UOps.ALU, out, [tor[vin[0]], tmp], BinaryOps.MUL))
elif args == UnaryOps.SIN and self.sin_is_sin2pi:
tmp = newreg((newvar, "2pi"))
ins.append(AssemblyInstruction(UOps.ALU, tmp, [tor[vin[0]], 1/(math.pi*2)], BinaryOps.MUL))
ins.append(AssemblyInstruction(UOps.ALU, newreg(newvar) if newvar not in tor else tor[newvar], [tmp], args))
ins.append(AssemblyInstruction(UOps.ALU, out, [tmp], args))
else:
ins.append(AssemblyInstruction(UOps.ALU, newreg(newvar) if newvar not in tor else tor[newvar], [tor[x] for x in vin], args))
ins.append(AssemblyInstruction(UOps.ALU, out, [tor[x] for x in vin], args))
elif uop == UOps.LOAD and newvar is not None:
idx, off = addr_w_offset(args)
reg = newreg(newvar)
idx, treg, off = addr_w_offset(args)
reg = newreg(newvar, dtype=newvar.dtype, scalar=(idx.scalar and (not isinstance(treg, Register) or treg.scalar))) # and not dtypes.is_float(newvar.dtype)))
if args.valid.min == 0:
ins.append(AssemblyInstruction(UOps.CONST, reg, [], 0))
if args.valid.max == 1:
@@ -146,16 +172,16 @@ class AssemblyCodegen(Linearizer):
ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], (f"$skipload_{skipload_branch}", False)))
if args.valid.max == 1:
# NOTE: you can't compute the index in here, because it assumes it's all available later
ins.append(AssemblyInstruction(UOps.LOAD, reg, [idx], (off, 'global' if args.i != -1 else 'shared')))
ins.append(AssemblyInstruction(UOps.LOAD, reg, [idx] + ([treg] if treg is not None else []), (off, 'global' if args.i != -1 else 'shared')))
if args.valid.min == 0 and args.valid.max == 1:
ins.append(AssemblyInstruction(UOps.LABEL, None, [], f"$skipload_{skipload_branch}"))
skipload_branch += 1
elif uop == UOps.STORE:
idx, off = addr_w_offset(args)
ins.append(AssemblyInstruction(UOps.STORE, None, [idx, tor[vin[0]]], (off, 'global' if args.i != -1 else 'shared')))
idx, treg, off = addr_w_offset(args)
ins.append(AssemblyInstruction(UOps.STORE, None, [idx, tor[vin[0]]] + ([treg] if treg is not None else []), (off, 'global' if args.i != -1 else 'shared')))
# define registers
ins = [AssemblyInstruction(UOps.DEFINE_REGISTER, None, [], (dtype, type_to_letter[dtype], c)) for dtype,c in cnts.items()] + ins
ins = [AssemblyInstruction(UOps.DEFINE_REGISTER, None, [], (dtype, type_to_letter(dtype), c)) for dtype,c in cnts.items()] + ins
if DEBUG >= 4:
for tins in ins: print(tins)
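
Two of the ALU branches added above are identity rewrites for ops the target lacks natively: with no_div set, a/b lowers to a * rcp(b) through the new UnaryOps.RECIP, and with sin_is_sin2pi set, sin(x) feeds x/(2*pi) to a hardware sine that measures angles in turns. (Note the old `hasattr(self, 'sin_is_sin2pi')` test would be true for every subclass once the class-level default exists, so the check now reads the flag's value.) Both identities, verified numerically:

import math

def div_via_recip(a, b): return a * (1.0 / b)  # DIV -> RECIP + MUL
def sin_via_sin2pi(x):                         # SIN on hardware that takes turns
    hw_sin = lambda turns: math.sin(2 * math.pi * turns)
    return hw_sin(x * (1 / (2 * math.pi)))

assert abs(div_via_recip(7.0, 3.0) - 7.0/3.0) < 1e-12
assert abs(sin_via_sin2pi(1.2345) - math.sin(1.2345)) < 1e-12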

View File

@@ -0,0 +1,181 @@
import yaml
from typing import Tuple, Set, Dict
from tinygrad.helpers import dtypes
from tinygrad.codegen.assembly import AssemblyCodegen, Register
from tinygrad.codegen.linearizer import UOps
from tinygrad.ops import BinaryOps, UnaryOps, FusedOps
from tinygrad.runtime.ops_gpu import ROCM_LLVM_PATH
# ugh, is this really needed?
from extra.helpers import enable_early_exec
early_exec = enable_early_exec()
boilerplate_start = """
.global _start
_start:
.rodata
.align 0x10
.global code.kd
.type code.kd,STT_OBJECT
.amdhsa_kernel code"""
code_start = """.end_amdhsa_kernel
.text
code:
"""
# https://github.com/RadeonOpenCompute/ROCm_Documentation/blob/master/ROCm_Compiler_SDK/ROCm-Codeobj-format.rst
# https://github.com/ROCm-Developer-Tools/ROCm-ComputeABI-Doc/blob/master/AMDGPU-ABI.md#initial-kernel-register-state
# RDNA3 is actually a SIMD machine!
class RDNACodegen(AssemblyCodegen):
supports_float4: bool = True
supports_float4_alu: bool = False
supports_load3: bool = True
sin_is_sin2pi: bool = True
no_div: bool = True
def specialize(self, asm) -> Tuple[str, str]:
args = []
for i,b in enumerate(self.bufs): args.append({'.address_space': 'global', '.name': f'buf_{i}', '.offset': i*8, '.size': 8, '.type_name': b.dtype.name+"*", '.value_kind': 'global_buffer'})
ins = []
v_cnt = 3 # v[0:2] is local_xyz
s_cnt = 5 # s[0:1] is the address, s[2:4] is global_xyz
dtype_to_rdnatype = {dtypes.float32: "f32", dtypes.int64: "i64", dtypes.int32: "i32", dtypes.uint64: "u64", dtypes.bool: "i32"}
alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", FusedOps.MULACC: "fma",
BinaryOps.MAX: "max", UnaryOps.RECIP: "rcp",
UnaryOps.NOOP: "mov", UnaryOps.SIN: "sin", UnaryOps.LOG2: "log", UnaryOps.EXP2: "exp",
BinaryOps.CMPEQ: "cmp_eq", BinaryOps.CMPLT: "cmp_lt"}
pend_regs:Set[Register] = set()
rtor:Dict[Register, str] = {}
def reg_in(x):
nonlocal pend_regs
#print("reg_in", x, rtor[x], pend_regs)
if x in pend_regs:
#print("clear")
ins.append('s_waitcnt lgkmcnt(0), vmcnt(0)')
pend_regs.clear()
return rtor[x]
def reg_out(x):
return rtor[x]
for uop, out, vin, arg in asm:
if uop == UOps.DEFINE_REGISTER:
if arg[0][0] == dtypes.uint64 and arg[0][1]:
# assuming these are scalar
s_cnt += s_cnt%2 # aligned(2)
for i in range(arg[2]):
rtor[Register(f"%{arg[1]}{i}", *arg[0])] = f"s[{s_cnt}:{s_cnt+1}]"
s_cnt += 2
elif arg[0][0] == dtypes._float4 and not arg[0][1]:
v_cnt += (4-v_cnt%4) if v_cnt%4 != 0 else 0
for i in range(arg[2]):
rtor[Register(f"%{arg[1]}{i}", *arg[0])] = f"v[{v_cnt}:{v_cnt+3}]"
for off in range(4): rtor[Register(f"%{arg[1]}{i}", dtypes.float, False, off=off)] = f"v{v_cnt+off}"
v_cnt += 4
elif arg[0][0] in [dtypes.int32, dtypes.float32]:
for i in range(arg[2]):
if arg[0][1]:
rtor[Register(f"%{arg[1]}{i}", *arg[0])] = f"s{s_cnt}"
s_cnt += 1
else:
rtor[Register(f"%{arg[1]}{i}", *arg[0])] = f"v{v_cnt}"
v_cnt += 1
elif arg[0][0] == dtypes.bool and arg[0][1]:
for i in range(arg[2]):
rtor[Register(f"%{arg[1]}{i}", *arg[0])] = "scc" if arg[0][1] else "vcc"
else:
raise NotImplementedError(arg)
elif uop == UOps.SPECIAL:
if arg.startswith('buf'):
i = int(arg[3:])
ins.append(f's_load_b64 {reg_out(out)}, s[0:1], {i*8}')
pend_regs.add(out)
for r in out.subregs(): pend_regs.add(r)
elif arg.startswith('gid'):
ins.append(f'v_mov_b32 {reg_out(out)}, s{2+int(arg[3])}')
# the docs lied, this is actually y
if int(arg[3]) == 2: ins.append("v_bfe_u32 v2, v0, 20, 10") # untested
if int(arg[3]) == 1: ins.append("v_bfe_u32 v1, v0, 10, 10")
elif int(arg[3]) == 0: ins.append("v_and_b32_e32 v0, 0x3ff, v0")
# get local size
offset = len(args)*8
args.append({".offset": offset, ".value_kind": f"hidden_group_size_{'xyz'[int(arg[3])]}", ".size": 8})
ins.append(f's_load_b32 s{2+int(arg[3])}, s[0:1], {offset}')
ins.append('s_waitcnt vmcnt(0) lgkmcnt(0)')
pend_regs.clear()
ins.append(f'v_mul_i32_i24 {reg_out(out)}, {reg_out(out)}, s{2+int(arg[3])}')
ins.append(f'v_add_nc_u32 {reg_out(out)}, v{int(arg[3])}, {reg_out(out)}')
elif uop == UOps.CONST:
if arg == float('inf'): arg = "0x7f800000"
elif arg == float('-inf'): arg = "0xff800000"
if out.dtype == dtypes._float4:
for off in range(4):
ins.append(f"{'s_' if out.scalar else 'v_'}mov_b32 {reg_out(Register(out.nm, dtypes.float, False, off=off))}, {arg}")
else:
ins.append(f"{'s_' if out.scalar else 'v_'}mov_b32 {reg_out(out)}, {arg}")
elif uop == UOps.ALU:
if arg == BinaryOps.CMPLT:
if out.scalar:
ins.append(f"s_{alu[arg]}_{dtype_to_rdnatype[out.dtype]} {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}")
else:
ins.append(f"v_cmp_lt_{dtype_to_rdnatype[out.dtype]} vcc, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}")
else:
alu_arg = alu[arg]
if arg == FusedOps.MULACC and out == vin[2]:
alu_arg = "fmac"
vin = vin[0:2]
if out.dtype == dtypes._float4:
tins = []
for rr in zip(*[x.subregs() if x.dtype == dtypes._float4 else [x,x,x,x] for x in [out]+vin]):
tins.append(f"{'s_' if rr[0].scalar else 'v_'}dual_{alu_arg}_{dtype_to_rdnatype[rr[0].dtype]} {reg_out(rr[0])}, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in rr[1:])}")
ins.append(tins[0] + " :: " + tins[1])
ins.append(tins[2] + " :: " + tins[3])
else:
ins.append(f"{'s_' if out.scalar else 'v_'}{alu_arg}_{dtype_to_rdnatype[out.dtype] if arg != UnaryOps.NOOP else 'b32'}{'_i24' if arg == BinaryOps.MUL and out.dtype != dtypes.float32 and not out.scalar else ''} {reg_out(out)}, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}")
elif uop == UOps.LOAD:
if out.scalar:
# swap arg order
ins.append(f's_load_b32 {reg_out(out)}, {reg_in(vin[0])}, {reg_in(vin[1])} offset:{arg[0]}')
else:
ins.append(f'global_load_{"b128" if out.dtype == dtypes._float4 else "b32"} {reg_out(out)}, {reg_in(vin[1])}, {reg_in(vin[0])} offset:{arg[0]}')
pend_regs.add(out)
for r in out.subregs(): pend_regs.add(r)
elif uop == UOps.STORE:
ins.append(f'global_store_{"b128" if vin[1].dtype == dtypes._float4 else "b32"} {reg_in(vin[2])}, {reg_in(vin[1])}, {reg_in(vin[0])} offset:{arg[0]}')
elif uop == UOps.LABEL:
ins.append(f"{arg}:")
elif uop == UOps.COND_BRANCH:
ins.append(f"s_cbranch_scc{'1' if arg[1] else '0'} {arg[0]}")
else:
raise NotImplementedError(uop)
ins += ['s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)', 's_endpgm', 's_code_end']
return 'code', self.assemble(args, ins, v_cnt, s_cnt)
def assemble(self, args, ins, v_cnt, s_cnt):
kernel_desc = {'.amdhsa_group_segment_fixed_size': 0, '.amdhsa_private_segment_fixed_size': 0, '.amdhsa_kernarg_size': 0,
'.amdhsa_next_free_vgpr': v_cnt, # this matters!
'.amdhsa_reserve_vcc': 0, '.amdhsa_reserve_xnack_mask': 0,
'.amdhsa_next_free_sgpr': s_cnt,
'.amdhsa_float_round_mode_32': 0, '.amdhsa_float_round_mode_16_64': 0, '.amdhsa_float_denorm_mode_32': 3, '.amdhsa_float_denorm_mode_16_64': 3, '.amdhsa_dx10_clamp': 1, '.amdhsa_ieee_mode': 1,
'.amdhsa_fp16_overflow': 0, '.amdhsa_workgroup_processor_mode': 1, '.amdhsa_memory_ordered': 1, '.amdhsa_forward_progress': 0, '.amdhsa_enable_private_segment': 0,
'.amdhsa_system_sgpr_workgroup_id_x': 1, '.amdhsa_system_sgpr_workgroup_id_y': 1, '.amdhsa_system_sgpr_workgroup_id_z': 1,
'.amdhsa_system_sgpr_workgroup_info': 0, '.amdhsa_system_vgpr_workitem_id': 2, # is amdhsa_system_vgpr_workitem_id real?
'.amdhsa_exception_fp_ieee_invalid_op': 0, '.amdhsa_exception_fp_denorm_src': 0, '.amdhsa_exception_fp_ieee_div_zero': 0, '.amdhsa_exception_fp_ieee_overflow': 0, '.amdhsa_exception_fp_ieee_underflow': 0,
'.amdhsa_exception_fp_ieee_inexact': 0, '.amdhsa_exception_int_div_zero': 0, '.amdhsa_user_sgpr_dispatch_ptr': 0, '.amdhsa_user_sgpr_queue_ptr': 0, '.amdhsa_user_sgpr_kernarg_segment_ptr': 1,
'.amdhsa_user_sgpr_dispatch_id': 0, '.amdhsa_user_sgpr_private_segment_size': 0, '.amdhsa_wavefront_size32': 1, '.amdhsa_uses_dynamic_stack': 0}
metadata = {'amdhsa.kernels': [{'.args': args,
'.group_segment_fixed_size': 0, '.kernarg_segment_align': 8, '.kernarg_segment_size': args[-1][".offset"] + args[-1][".size"],
'.language': 'OpenCL C', '.language_version': [1, 2], '.max_flat_workgroup_size': 256,
'.name': 'code', '.private_segment_fixed_size': 0, '.sgpr_count': s_cnt, '.sgpr_spill_count': 0,
'.symbol': 'code.kd', '.uses_dynamic_stack': False, '.vgpr_count': v_cnt, '.vgpr_spill_count': 0,
'.wavefront_size': 32}],
'amdhsa.target': 'amdgcn-amd-amdhsa--gfx1100', 'amdhsa.version': [1, 2]}
code = boilerplate_start + "\n" + '\n'.join("%s %d" % x for x in kernel_desc.items()) + "\n" + code_start + '\n'.join(ins) + "\n.amdgpu_metadata\n" + yaml.dump(metadata) + ".end_amdgpu_metadata"
obj = early_exec(([ROCM_LLVM_PATH / "llvm-mc", '--arch=amdgcn', '--mcpu=gfx1100', '--triple=amdgcn-amd-amdhsa', '--filetype=obj', '-'], code.encode("utf-8")))
asm = early_exec(([ROCM_LLVM_PATH / "ld.lld", "/dev/stdin", "-o", "/dev/stdout", "--pie"], obj))
return asm
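
The trickiest part of the new backend is the hazard handling around reg_in: global and scalar loads complete asynchronously, so any instruction about to read a register with a load still in flight must first drain the memory counters. pend_regs tracks those destination registers (including float4 subregisters), and a single blanket s_waitcnt clears the whole set, which is coarse but safe. A condensed sketch of the mechanism:

# condensed sketch of RDNACodegen's scoreboard around pend_regs/reg_in
pending, ins = set(), []

def load(dst):
    ins.append(f"global_load_b32 {dst}, ...")  # result lands asynchronously
    pending.add(dst)

def reg_in(src):
    if src in pending:
        ins.append("s_waitcnt lgkmcnt(0), vmcnt(0)")  # wait on ALL outstanding memory ops
        pending.clear()  # everything older is now visible
    return src

load("v1"); load("v2")
ins.append(f"v_add_f32 v3, {reg_in('v1')}, {reg_in('v2')}")  # one waitcnt covers both loads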

View File

@@ -172,8 +172,8 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
[', '.join([f'{t} {bufnames[i]}' for i,t in buftypes] + lang.extra_args)] +
[") {\n"] + list(prekernel) + ['\n'.join(kernel), "\n}"])
if lang.half_prekernel: prg =''.join([f"{lang.half_prekernel}", "\n", prg])
if lang.double_prekernel: prg = ''.join([f"{lang.double_prekernel}", "\n", prg])
if lang.half_prekernel and any(x.dtype == dtypes.float16 for x in bufs): prg = ''.join([f"{lang.half_prekernel}", "\n", prg])
if lang.double_prekernel and any(x.dtype == dtypes.float64 for x in bufs): prg = ''.join([f"{lang.double_prekernel}", "\n", prg])
return prg, global_size, local_size
class CStyleCodegen(Linearizer):

View File

@@ -19,6 +19,7 @@ def partition(lst, fxn): return [x for x in lst if fxn(x)], [x for x in lst if n
def make_pair(x:Union[int, Tuple[int, ...]], cnt=2) -> Tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else x
def flatten(l:Iterator): return [item for sublist in l for item in sublist]
def mnum(i) -> str: return str(i) if i >= 0 else f"m{-i}"
def fromimport(mod, frm): return getattr(__import__(mod, fromlist=[frm]), frm)
@functools.lru_cache(maxsize=None)
def getenv(key, default=0): return type(default)(os.getenv(key, default))
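
fromimport is a one-line lazy import: __import__(mod, fromlist=[frm]) returns the leaf module rather than the top-level package, and getattr pulls the requested name out of it. Because it runs at call time instead of module-load time, the assembly codegens below are only imported when their env var actually selects them. Behavior check:

import os
def fromimport(mod, frm): return getattr(__import__(mod, fromlist=[frm]), frm)

join = fromimport("os.path", "join")  # same binding as: from os.path import join
assert join is os.path.join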
@@ -74,7 +75,7 @@ class dtypes:
@staticmethod # static methds on top, or bool in the type info will refer to dtypes.bool
def is_int(x: DType)-> bool: return x in (dtypes.int8, dtypes.uint8, dtypes.int32, dtypes.int64)
@staticmethod
def is_float(x: DType) -> bool: return x in (dtypes.float16, dtypes.float32, dtypes.float64)
def is_float(x: DType) -> bool: return x in (dtypes.float16, dtypes.float32, dtypes.float64, dtypes._half4, dtypes._float4)
@staticmethod
def is_unsigned(x: DType) -> bool: return x in (dtypes.uint8, dtypes.uint32, dtypes.uint64)
@staticmethod
@@ -87,10 +88,10 @@ class dtypes:
float64: Final[DType] = DType(5, 8, "double", np.float64)
int8: Final[DType] = DType(0, 1, "char", np.int8)
int32: Final[DType] = DType(1, 4, "int", np.int32)
int64: Final[DType] = DType(2, 8, "int64", np.int64)
int64: Final[DType] = DType(2, 8, "long", np.int64)
uint8: Final[DType] = DType(0, 1, "uchar", np.uint8)
uint32: Final[DType] = DType(1, 4, "uint", np.uint32)
uint64: Final[DType] = DType(2, 8, "uint64", np.uint64)
uint64: Final[DType] = DType(2, 8, "ulong", np.uint64)
# NOTE: these are internal dtypes, should probably check for that
_half4: Final[DType] = DType(0, 2*4, "half4", None, 4)

View File

@@ -9,7 +9,8 @@ from tinygrad.runtime.lib import RawBuffer, RawConst
# these are the llops your accelerator must implement, along with toCpu
# the Enum class doesn't work with mypy, this is static. sorry it's ugly
# NOTE: MOD, CMPLT don't have to be implemented on vectors, just scalars
class UnaryOps(Enum): NOOP = auto(); EXP2 = auto(); LOG2 = auto(); CAST = auto(); SIN = auto() # noqa: E702
# NOTE: rdna3 only has RECIP and not DIV. DIV and POW are on the chopping block
class UnaryOps(Enum): NOOP = auto(); EXP2 = auto(); LOG2 = auto(); CAST = auto(); SIN = auto(); RECIP = auto() # noqa: E702
class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); POW = auto(); CMPEQ = auto(); MAX = auto(); MOD = auto(); CMPLT = auto() # noqa: E702
class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
class FusedOps(Enum): MULACC = auto() # noqa: E702

View File

@@ -4,11 +4,10 @@ import numpy as np
import pycuda.autoprimaryctx # type: ignore # pylint: disable=unused-import # noqa: F401
import pycuda.driver as cuda # type: ignore
from pycuda.compiler import compile as cuda_compile # type: ignore
from tinygrad.helpers import DEBUG, getenv
from tinygrad.helpers import DEBUG, getenv, fromimport
from tinygrad.ops import Compiled
from tinygrad.runtime.lib import RawBufferCopyInOut
from tinygrad.codegen.cstyle import CStyleCodegen, CStyleLanguage
from tinygrad.codegen.assembly_ptx import PTXCodegen
class RawCUDABuffer(RawBufferCopyInOut):
def __init__(self, size, dtype): super().__init__(size, dtype, cuda.mem_alloc(size * dtype.itemsize))
@@ -60,4 +59,5 @@ class CUDACodegen(CStyleCodegen):
typedef long long int64;
""")
supports_float4_alu = False
CUDABuffer = Compiled(RawCUDABuffer, PTXCodegen if getenv("PTX") else CUDACodegen, CUDAProgram, cuda.Context.synchronize)
CUDABuffer = Compiled(RawCUDABuffer, fromimport("tinygrad.codegen.assembly_ptx", "PTXCodegen") if getenv("PTX") else CUDACodegen, CUDAProgram, cuda.Context.synchronize)
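
The practical effect: users who never set PTX no longer import (or need) assembly_ptx at all, since the codegen is picked once when ops_cuda is first imported. A usage sketch (assumes a working CUDA setup; note getenv is lru_cached, so the variable must be set before the first tinygrad import):

# sketch: opt into the PTX assembly codegen instead of CUDA C
import os
os.environ["PTX"] = "1"  # read once, at tinygrad.runtime.ops_cuda import time
from tinygrad.runtime.ops_cuda import CUDABuffer  # now built around PTXCodegen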

View File

@@ -3,7 +3,7 @@ import pathlib
import numpy as np
import pyopencl as cl # type: ignore
from typing import Optional, List
from tinygrad.helpers import DEBUG, getenv, prod, ImageDType, OSX, dtypes
from tinygrad.helpers import DEBUG, getenv, prod, ImageDType, OSX, dtypes, fromimport
from tinygrad.ops import Compiled
from tinygrad.runtime.lib import RawBufferCopyInOut
from tinygrad.codegen.cstyle import CStyleCodegen, CStyleLanguage
@@ -61,7 +61,7 @@ class CLProgram:
if 'Adreno' in CL.cl_ctx.devices[0].name:
from disassemblers.adreno import disasm
disasm(self.binary())
elif 'gfx1100' in CL.cl_ctx.devices[0].name:
elif CL.cl_ctx.devices[0].name.startswith('gfx'):
asm = early_exec(([ROCM_LLVM_PATH / "llvm-objdump", '-d', '-'], self.binary()))
print('\n'.join([x for x in asm.decode('utf-8').split("\n") if 's_code_end' not in x]))
else:
@@ -87,11 +87,12 @@ class CLProgram:
class CLCodegen(CStyleCodegen):
lang = CStyleLanguage(
kernel_prefix = "#define int64 long\n__kernel", buffer_prefix = "__global ", smem_prefix = "__local ",
kernel_prefix = "__kernel", buffer_prefix = "__global ", smem_prefix = "__local ",
double_prekernel="#ifdef cl_khr_fp64\n#pragma OPENCL EXTENSION cl_khr_fp64 : enable\n#elif defined(cl_amd_fp64)\n#pragma OPENCL EXTENSION cl_amd_fp64 : enable\n#endif",
half_prekernel = "#pragma OPENCL EXTENSION cl_khr_fp16 : enable",
barrier = "barrier(CLK_LOCAL_MEM_FENCE);", float4 = "(float4)",
gid = [f'get_global_id({i})' for i in range(3)], lid = [f'get_local_id({i})' for i in range(3)], uses_vload=True)
supports_float4_alu = True
supports_float4 = True
GPUBuffer = Compiled(CLBuffer, CLCodegen, CLProgram, CL.synchronize)
GPUBuffer = Compiled(CLBuffer, fromimport("tinygrad.codegen.assembly_rdna", "RDNACodegen") if getenv("RDNA") else CLCodegen, CLProgram, CL.synchronize)
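
Symmetrically, the OpenCL backend swaps in the RDNA assembly codegen when RDNA is set, and the disassembly path above now triggers for any gfx* device rather than only gfx1100. An end-to-end sketch (assumes an RDNA3 GPU plus ROCm's llvm-mc/ld.lld at ROCM_LLVM_PATH, and that GPU=1 selects the OpenCL device as tinygrad's default):

# sketch: run a tensor op through the RDNA assembly backend
import os
os.environ["RDNA"] = "1"  # read once when tinygrad.runtime.ops_gpu is imported
os.environ["GPU"] = "1"   # make the OpenCL backend the default device
from tinygrad.tensor import Tensor
print((Tensor([1.0, 2.0]) + Tensor([3.0, 4.0])).numpy())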

View File

@@ -80,7 +80,7 @@ class MetalProgram:
class MetalCodegen(CStyleCodegen):
lang = CStyleLanguage(
kernel_prefix = "#include <metal_stdlib>;\n#define int64 long\nusing namespace metal;\nkernel", buffer_prefix = "device ", smem_prefix = "threadgroup ",
kernel_prefix = "#include <metal_stdlib>\nusing namespace metal;\nkernel", buffer_prefix = "device ", smem_prefix = "threadgroup ",
barrier = "threadgroup_barrier(mem_flags::mem_threadgroup);", float4 = "float4",
gid = [f"gid.{chr(120+i)}" for i in range(3)], lid = [f"lid.{chr(120+i)}" for i in range(3)],
extra_args = ['uint3 gid [[thread_position_in_grid]]', 'uint3 lid [[thread_position_in_threadgroup]]'])