diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 8ac1e08db..ce82453c6 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -267,3 +267,28 @@ jobs:
     - name: Run pytest (cuda)
       if: matrix.backend=='cuda'
       run: python -m pytest -n=auto test/ -k 'not (half or test_efficientnet_safetensors) and not (test_conv2d and test_tensor.py)' -m 'not exclude_cuda' --ignore=test/external --ignore=test/models
+
+  testunicorn:
+    name: ARM64 unicorn Test
+    runs-on: ubuntu-latest
+    timeout-minutes: 20
+    steps:
+    - name: Checkout Code
+      uses: actions/checkout@v3
+    - name: Set up Python 3.8
+      uses: actions/setup-python@v4
+      with:
+        python-version: 3.8
+    - name: Cache pip
+      uses: actions/cache@v3
+      with:
+        path: '~/.cache/pip'
+        key: unicorn
+    - name: Install cross-assembler
+      run: |
+        sudo apt-get update -y && \
+        sudo apt-get install -y --no-install-recommends gcc-aarch64-linux-gnu
+    - name: Install dependencies
+      run: pip install -e '.[testing,arm]' --extra-index-url https://download.pytorch.org/whl/cpu
+    - name: Test arm
+      run: CI=1 ARM64=1 CLANG=1 python -m pytest -n=auto test/ -k 'not (test_nn.py and (test_conv_transpose2d or test_conv2d))' --ignore=test/models --ignore=test/test_speed_v_torch.py --ignore=test/test_net_speed.py --ignore=test/test_specific_conv.py --ignore=test/unit/test_disk_tensor.py
\ No newline at end of file
diff --git a/extra/assembly/assembly.py b/extra/assembly/assembly.py
index 978119d03..d2b9fedb0 100644
--- a/extra/assembly/assembly.py
+++ b/extra/assembly/assembly.py
@@ -1,7 +1,7 @@
 from typing import Tuple, List, NamedTuple, Any, Dict, Optional, Union, DefaultDict
-from tinygrad.codegen.linearizer import Linearizer, UOps, Token
-from tinygrad.ops import ASTRunner, BinaryOps, UnaryOps
-from tinygrad.helpers import DType, dtypes, DEBUG
+from tinygrad.codegen.linearizer import Linearizer, UOps, Token, ConstOp, MemOp, UOp
+from tinygrad.ops import BinaryOps, UnaryOps
+from tinygrad.helpers import DType, dtypes, DEBUG, getenv
 from tinygrad.shape.symbolic import Variable, NumNode, MulNode, DivNode, ModNode, LtNode, SumNode, AndNode
 import functools
 import math
@@ -20,6 +20,7 @@ class Register(NamedTuple):
     if self.dtype == dtypes._float4:
       return [Register(self.nm, dtypes.float, False, off=off) for off in range(4)]
     return []
+
 class AssemblyInstruction(NamedTuple):
   op: UOps
   out: Optional[Register]
@@ -27,156 +28,161 @@ class AssemblyInstruction(NamedTuple):
   arg: Any = None
 
 # warp size of 32, s registers are shared across the warp, v are 32-wide vectors
-class AssemblyCodegen(Linearizer):
+class AssemblyLanguage(NamedTuple):
   supports_load3: bool = False
   sin_is_sin2pi: bool = False
   no_div: bool = False
+  #TODO: these should be global vars
+  cnts:DefaultDict[Tuple[DType, bool], int] = defaultdict(int)
+  tor: Dict[Any, Register] = {}
+  ins = []
 
-  def specialize(self, asm:List[AssemblyInstruction]) -> Tuple[str, str]:
-    raise NotImplementedError("must be implemented")
+  def newreg(self, tok, dtype=dtypes.float32, scalar=False):
+    if isinstance(tok, Token): dtype = tok.dtype # this
+    self.tor[tok] = ret = Register(f"%{type_to_letter((dtype, scalar))}{self.cnts[(dtype, scalar)]}", dtype, scalar)
+    if dtype == dtypes._float4:
+      for off in range(4):
+        self.tor[Token(tok.name, tok.dtype, off)] = Register(ret.nm, dtypes.float, ret.scalar, off)
+    self.cnts[(dtype, scalar)] += 1
+    return ret
 
-  # s registers are the addresses and non local indexes
-  def codegen(self):
-    self.process()
-    self.hand_coded_optimizations()
-    self.limit_global_dims(3) # all GPU asms have 3 (for now)
-    self.linearize()
+  def render_numnode(self, b):
+    key = ("num", b)
+    if key not in self.tor: self.ins.append(AssemblyInstruction(UOps.LOAD, self.newreg(key, scalar=True, dtype=dtypes.int32), [], b))
+    return self.tor[key]
 
-    cnts:DefaultDict[Tuple[DType, bool], int] = defaultdict(int)
-    tor: Dict[Any, Register] = {}
-    def newreg(tok, dtype=dtypes.float32, scalar=False):
-      nonlocal cnts, tor
-      if isinstance(tok, Token): dtype = tok.dtype # this
-      tor[tok] = ret = Register(f"%{type_to_letter((dtype, scalar))}{cnts[(dtype, scalar)]}", dtype, scalar)
-      if dtype == dtypes._float4:
-        for off in range(4):
-          tor[Token(tok.name, tok.dtype, off)] = Register(ret.nm, dtypes.float, ret.scalar, off)
-      cnts[(dtype, scalar)] += 1
-      return ret
+  def render_alu(self, op, a:Register, b:Union[Register, int, float], dtype=dtypes.int32) -> Register:
+    key = (op, a, b)
+    if key not in self.tor:
+      #if not isinstance(b, Register): b = render_numnode(b)
+      self.ins.append(AssemblyInstruction(UOps.ALU, self.newreg(key, dtype=dtype, scalar=a.scalar and (not isinstance(b, Register) or b.scalar)), [a, b], op))
+    return self.tor[key]
 
-    def render_numnode(b):
-      key = ("num", b)
-      if key not in tor: ins.append(AssemblyInstruction(UOps.CONST, newreg(key, scalar=True, dtype=dtypes.int32), [], b))
-      return tor[key]
+  def render_cast(self, a:Register, new_dtype:DType) -> Register:
+    if a.dtype == new_dtype: return a
+    key = (a, new_dtype)
+    if key not in self.tor:
+      self.ins.append(AssemblyInstruction(UOps.CAST, self.newreg(key, dtype=new_dtype), [a]))
+    return self.tor[key]
 
-    def render_alu(op, a:Register, b:Union[Register, int, float], dtype=dtypes.int32) -> Register:
-      key = (op, a, b)
-      if key not in tor:
-        #if not isinstance(b, Register): b = render_numnode(b)
-        ins.append(AssemblyInstruction(UOps.ALU, newreg(key, dtype=dtype, scalar=a.scalar and (not isinstance(b, Register) or b.scalar)), [a, b], op))
-      return tor[key]
+  render_ops = { Variable: lambda self, ops, ctx: ctx.tor[self], NumNode: lambda self, ops, ctx: ctx.render_numnode(self.b),
+                 MulNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.MUL, self.a.render(ops, ctx), self.b),
+                 DivNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.DIV, self.a.render(ops, ctx), self.b),
+                 ModNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.MOD, self.a.render(ops, ctx), self.b),
+                 LtNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.CMPLT, self.a.render(ops, ctx), self.b, dtype=dtypes.bool),
+                 SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.render_alu(BinaryOps.ADD, a, b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx)),
+                 AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.render_alu(BinaryOps.MUL, a, b.render(ops,ctx), dtype=dtypes.bool), self.nodes[1:], self.nodes[0].render(ops,ctx)) }
 
-    def render_cast(a:Register, new_dtype:DType) -> Register:
-      if a.dtype == new_dtype: return a
-      key = (a, new_dtype)
-      if key not in tor:
-        ins.append(AssemblyInstruction(UOps.CAST, newreg(key, dtype=new_dtype), [a]))
-      return tor[key]
+  def addr_w_offset(self, args):
+    assert isinstance(args, MemOp)
+    idx = args.idx*args.memory_dtype.itemsize
+    off = 0 # TODO: should this be None?
+    if isinstance(idx, SumNode):
+      nums = [n.b for n in idx.nodes if isinstance(n, NumNode)]
+      if len(nums) > 0 and nums[0] < 4096 and (idx-nums[0]).min >= 0: # TODO: different for each GPU?
+        idx -= nums[0]
+        off = nums[0]
+    reg = idx.render(self.render_ops, self)
+    if self.supports_load3:
+      if reg.scalar:
+        new_reg = self.newreg((reg.nm, 'vec'), dtype=reg.dtype)
+        self.ins.append(AssemblyInstruction(UOps.ALU, new_reg, [reg], UnaryOps.NOOP))
+        reg = new_reg
+      return self.tor[args.name], reg, off
+    reg = self.render_alu(BinaryOps.ADD, self.render_cast(reg, dtypes.uint64), self.tor[args.name], dtype=dtypes.uint64)
+    return reg, None, off
 
-    render_ops = { Variable: lambda self, ops, ctx: tor[self], NumNode: lambda self, ops, ctx: render_numnode(self.b),
-                   MulNode: lambda self, ops, ctx: render_alu(BinaryOps.MUL, self.a.render(ops, ctx), self.b),
-                   DivNode: lambda self, ops, ctx: render_alu(BinaryOps.DIV, self.a.render(ops, ctx), self.b),
-                   ModNode: lambda self, ops, ctx: render_alu(BinaryOps.MOD, self.a.render(ops, ctx), self.b),
-                   LtNode: lambda self, ops, ctx: render_alu(BinaryOps.CMPLT, self.a.render(ops, ctx), self.b, dtype=dtypes.bool),
-                   SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: render_alu(BinaryOps.ADD, a, b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx)),
-                   AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: render_alu(BinaryOps.MUL, a, b.render(ops,ctx), dtype=dtypes.bool), self.nodes[1:], self.nodes[0].render(ops,ctx)) }
+def uops_to_asmstyle(lang, function_name:str, uops:List[UOp]):
+  #TODO: Do not use clear()
+  lang.ins.clear()
+  lang.tor.clear()
+  buf_to_dtype = {args[0]:args[1] for uop,_,_,args in uops if uop == UOps.DEFINE_GLOBAL}
+  buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())}
+  global_size, local_size = [], []
+  skipload_branch = 0
+  lang.ins += [AssemblyInstruction(UOps.SPECIAL, lang.newreg(buf, dtype=dtypes.uint64, scalar=True), [], buf) for buf in buf_to_dtype]
+  for uop,newvar,vin,args in uops:
+    if uop == UOps.DEFINE_LOCAL:
+      lang.ins.append(AssemblyInstruction(UOps.DEFINE_LOCAL, None, [], args))
+      lang.ins.append(AssemblyInstruction(UOps.ALU, lang.newreg(args[0], dtype=dtypes.uint64), [args[0]], UnaryOps.NOOP))
+    elif uop == UOps.LOOP:
+      if args[1] == "global":
+        for i,var in enumerate(args[0]):
+          global_size.append(var.max+1)
+          lang.ins.append(AssemblyInstruction(UOps.SPECIAL, lang.newreg(var, dtype=dtypes.int32), [], f"gid{len(args[0])-1-i}"))
+      elif args[1] == "local":
+        for i,var in enumerate(args[0]):
+          local_size.append(var.max+1)
+          lang.ins.append(AssemblyInstruction(UOps.SPECIAL, lang.newreg(var, dtype=dtypes.int32), [], f"lid{len(args[0])-1-i}"))
+      else:
+        for var in args[0]:
+          if not isinstance(var, NumNode): # TODO: why is this coming through?
+            lang.ins.append(AssemblyInstruction(UOps.LOAD, lang.newreg(var, dtype=dtypes.int32, scalar=True), [], 0)) #FIXME: what should valid be here?
+            lang.ins.append(AssemblyInstruction(UOps.LABEL, None, [], "$loop_"+var.expr))
+    elif uop == UOps.ENDLOOP:
+      if args[1] not in ["global", "local", "global+local"]:
+        for var in reversed(args[0]):
+          if not isinstance(var, NumNode): # TODO: why is this coming through?
+            lang.ins.append(AssemblyInstruction(UOps.ALU, lang.tor[var], [lang.tor[var], 1], BinaryOps.ADD))
+            pred = lang.render_alu(BinaryOps.CMPLT, lang.tor[var], var.max+1, dtypes.bool)
+            lang.ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], ("$loop_"+var.expr, True)))
+      elif args[1] == "global+local":
+        for i, var in enumerate(reversed(args[0])):
+          lang.ins.append(AssemblyInstruction(UOps.ENDLOOP, None, [lang.tor[var]], (var.max+1, f"gid{i}")))
 
-    def addr_w_offset(args):
-      idx = args.idx*self.bufs[args.i].dtype.itemsize
-      off = 0 # TODO: should this be None?
-      if isinstance(idx, SumNode):
-        nums = [n.b for n in idx.nodes if isinstance(n, NumNode)]
-        if len(nums) > 0 and nums[0] < 4096 and (idx-nums[0]).min >= 0: # TODO: different for each GPU?
-          idx -= nums[0]
-          off = nums[0]
-      reg = idx.render(render_ops)
-      if self.supports_load3:
-        if reg.scalar:
-          new_reg = newreg((reg.nm, 'vec'), dtype=reg.dtype)
-          ins.append(AssemblyInstruction(UOps.ALU, new_reg, [reg], UnaryOps.NOOP))
-          reg = new_reg
-        return tor[f"buf{args.i}"], reg, off
-      reg = render_alu(BinaryOps.ADD, render_cast(reg, dtypes.uint64), tor[f"buf{args.i}"], dtype=dtypes.uint64)
-      return reg, None, off
-
-    ins = []
-    ins += [AssemblyInstruction(UOps.SPECIAL, newreg(f"buf{i}", dtype=dtypes.uint64, scalar=True), [], f"buf{i}") for i in range(len(self.bufs))]
-    global_size, local_size = [], []
-    skipload_branch = 0
-    for uop,newvar,vin,args in self.uops:
-      if uop == UOps.CONST and newvar is not None:
-        ins.append(AssemblyInstruction(UOps.CONST, newreg(newvar, dtype=newvar.dtype), [], args))
-      elif uop == UOps.DEFINE_LOCAL:
-        ins.append(AssemblyInstruction(UOps.DEFINE_LOCAL, None, [], args))
-        ins.append(AssemblyInstruction(UOps.ALU, newreg("buf-1", dtype=dtypes.uint64), [args[0]], UnaryOps.NOOP))
-      elif uop == UOps.LOOP:
-        if args[1] == "global":
-          for i,var in enumerate(args[0]):
-            global_size.append(var.max+1)
-            ins.append(AssemblyInstruction(UOps.SPECIAL, newreg(var, dtype=dtypes.int32), [], f"gid{len(args[0])-1-i}"))
-        elif args[1] == "local":
-          for i,var in enumerate(args[0]):
-            local_size.append(var.max+1)
-            ins.append(AssemblyInstruction(UOps.SPECIAL, newreg(var, dtype=dtypes.int32), [], f"lid{len(args[0])-1-i}"))
-        else:
-          for var in args[0]:
-            if not isinstance(var, NumNode): # TODO: why is this coming through?
-              ins.append(AssemblyInstruction(UOps.CONST, newreg(var, dtype=dtypes.int32, scalar=True), [], 0))
-              ins.append(AssemblyInstruction(UOps.LABEL, None, [], "$loop_"+var.expr))
-      elif uop == UOps.ENDLOOP:
-        if args[1] not in ["global", "local", "global+local"]:
-          for var in reversed(args[0]):
-            if not isinstance(var, NumNode): # TODO: why is this coming through?
-              ins.append(AssemblyInstruction(UOps.ALU, tor[var], [tor[var], 1], BinaryOps.ADD))
-              pred = render_alu(BinaryOps.CMPLT, tor[var], var.max+1, dtypes.bool)
-              ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], ("$loop_"+var.expr, True)))
-      elif uop == UOps.CAST and newvar is not None:
-        # TODO: we should reconsider outputting CAST in the linearizer. these are needless copies
-        out = newreg(newvar)
-        for i,sr in enumerate(out.subregs()):
-          ins.append(AssemblyInstruction(UOps.ALU, sr, [tor[vin[i]]], UnaryOps.NOOP))
-      elif uop == UOps.ALU and newvar is not None:
-        out = newreg(newvar) if newvar not in tor else tor[newvar]
-        # this is the only thing that can violate SSA
-        if args in [BinaryOps.CMPLT]:
-          pred_reg = newreg((newvar, 'pred'), dtype=dtypes.bool)
-          ins.append(AssemblyInstruction(UOps.ALU, pred_reg, [tor[x] for x in vin], args))
-          ins.append(AssemblyInstruction(UOps.CAST, out, [pred_reg], args))
-        elif args == BinaryOps.DIV and self.no_div:
-          tmp = newreg((newvar, "rcp"))
-          ins.append(AssemblyInstruction(UOps.ALU, tmp, [tor[vin[1]]], UnaryOps.RECIP))
-          ins.append(AssemblyInstruction(UOps.ALU, out, [tor[vin[0]], tmp], BinaryOps.MUL))
-        elif args == UnaryOps.SIN and self.sin_is_sin2pi:
-          tmp = newreg((newvar, "2pi"))
-          ins.append(AssemblyInstruction(UOps.ALU, tmp, [tor[vin[0]], 1/(math.pi*2)], BinaryOps.MUL))
-          ins.append(AssemblyInstruction(UOps.ALU, out, [tmp], args))
-        else:
-          ins.append(AssemblyInstruction(UOps.ALU, out, [tor[x] for x in vin], args))
-      elif uop == UOps.LOAD and newvar is not None:
-        idx, treg, off = addr_w_offset(args)
-        reg = newreg(newvar, dtype=newvar.dtype, scalar=(idx.scalar and (not isinstance(treg, Register) or treg.scalar))) # and not dtypes.is_float(newvar.dtype)))
-        if args.valid.min == 0:
-          ins.append(AssemblyInstruction(UOps.CONST, reg, [], 0))
-          if args.valid.max == 1:
-            pred = args.valid.render(render_ops)
-            ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], (f"$skipload_{skipload_branch}", False)))
-        if args.valid.max == 1:
-          # NOTE: you can't compute the index in here, because it assumes it's all available later
-          ins.append(AssemblyInstruction(UOps.LOAD, reg, [idx] + ([treg] if treg is not None else []), (off, 'global' if args.i != -1 else 'shared')))
+    elif uop == UOps.CAST and newvar is not None:
+      # TODO: we should reconsider outputting CAST in the linearizer. these are needless copies
+      out = lang.newreg(newvar)
+      for i,sr in enumerate(out.subregs()):
+        lang.ins.append(AssemblyInstruction(UOps.ALU, sr, [lang.tor[vin[i]]], UnaryOps.NOOP))
+    elif uop == UOps.ALU and newvar is not None:
+      out = lang.newreg(newvar) if newvar not in lang.tor else lang.tor[newvar]
+      # this is the only thing that can violate SSA
+      if args in [BinaryOps.CMPLT]:
+        pred_reg = lang.newreg((newvar, 'pred'), dtype=dtypes.bool)
+        lang.ins.append(AssemblyInstruction(UOps.ALU, pred_reg, [lang.tor[x] for x in vin], args))
+        lang.ins.append(AssemblyInstruction(UOps.CAST, out, [pred_reg], args))
+      elif args == BinaryOps.DIV and lang.no_div:
+        tmp = lang.newreg((newvar, "rcp"))
+        lang.ins.append(AssemblyInstruction(UOps.ALU, tmp, [lang.tor[vin[1]]], UnaryOps.RECIP))
+        lang.ins.append(AssemblyInstruction(UOps.ALU, out, [lang.tor[vin[0]], tmp], BinaryOps.MUL))
+      elif args == UnaryOps.SIN and lang.sin_is_sin2pi:
+        tmp = lang.newreg((newvar, "2pi"))
+        lang.ins.append(AssemblyInstruction(UOps.ALU, tmp, [lang.tor[vin[0]], 1/(math.pi*2)], BinaryOps.MUL))
+        lang.ins.append(AssemblyInstruction(UOps.ALU, out, [tmp], args))
+      else:
+        lang.ins.append(AssemblyInstruction(UOps.ALU, out, [lang.tor[x] for x in vin], args))
+    elif uop == UOps.LOAD and newvar is not None:
+      if isinstance(args, ConstOp):
         if args.valid.min == 0 and args.valid.max == 1:
-          ins.append(AssemblyInstruction(UOps.LABEL, None, [], f"$skipload_{skipload_branch}"))
+          lang.ins.append(AssemblyInstruction(UOps.LOAD, lang.newreg(newvar, dtype=newvar.dtype), [], args.invalid_value))
+          pred = args.valid.render(lang.render_ops, lang)
+          lang.ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], (f"$skipload_{skipload_branch}", False)))
+          lang.ins.append(AssemblyInstruction(UOps.LOAD, lang.newreg(newvar, dtype=newvar.dtype), [], args.value))
+          lang.ins.append(AssemblyInstruction(UOps.LABEL, None, [], f"$skipload_{skipload_branch}"))
           skipload_branch += 1
-      elif uop == UOps.STORE:
-        idx, treg, off = addr_w_offset(args)
-        ins.append(AssemblyInstruction(UOps.STORE, None, [idx, tor[vin[0]]] + ([treg] if treg is not None else []), (off, 'global' if args.i != -1 else 'shared')))
+        else:
+          lang.ins.append(AssemblyInstruction(UOps.LOAD, lang.newreg(newvar, dtype=newvar.dtype), [], args.value if args.valid.min == 1 else args.invalid_value))
+      else:
+        idx, treg, off = lang.addr_w_offset(args)
+        reg = lang.newreg(newvar, dtype=newvar.dtype, scalar=(idx.scalar and (not isinstance(treg, Register) or treg.scalar))) # and not dtypes.is_float(newvar.dtype)))
+        if args.valid.min == 0:
+          lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [], 0))
+          if args.valid.max == 1:
+            pred = args.valid.render(lang.render_ops, lang)
+            lang.ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], (f"$skipload_{skipload_branch}", False)))
+        if args.valid.max == 1:
+          # NOTE: you can't compute the index in here, because it assumes it's all available later
+          lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [idx] + ([treg] if treg is not None else []), (off, 'global' if buf_index[args.name] != -1 else 'shared', args.memory_dtype if buf_to_dtype[args.name] != dtypes.float else None)))
+        if args.valid.min == 0 and args.valid.max == 1:
+          lang.ins.append(AssemblyInstruction(UOps.LABEL, None, [], f"$skipload_{skipload_branch}"))
+          skipload_branch += 1
+    elif uop == UOps.STORE:
+      idx, treg, off = lang.addr_w_offset(args)
+      lang.ins.append(AssemblyInstruction(UOps.STORE, None, [idx, lang.tor[vin[0]]] + ([treg] if treg is not None else []), (off, 'global' if buf_index[args.name] != -1 else 'shared', args.memory_dtype if buf_to_dtype['data0'] != dtypes.float else None)))
+  # define registers
+  lang.ins = [AssemblyInstruction(UOps.DEFINE_REGISTER, None, [], (dtype, type_to_letter(dtype), c)) for dtype,c in lang.cnts.items()] + lang.ins
 
-    # define registers
-    ins = [AssemblyInstruction(UOps.DEFINE_REGISTER, None, [], (dtype, type_to_letter(dtype), c)) for dtype,c in cnts.items()] + ins
-
-    if DEBUG >= 4:
-      for tins in ins: print(tins)
-    name, asm = self.specialize(ins)
-
-    return ASTRunner(name, asm,
-      global_size[::-1], local_size[::-1],
-      op_estimate=self.info.flops, mem_estimate=self.mem_estimate, display_name=self.display_name, runtime_args={"binary": True})
+  if DEBUG >= 4:
+    for tins in lang.ins: print(tins)
+  return global_size, local_size
\ No newline at end of file
diff --git a/extra/assembly/assembly_arm64.py b/extra/assembly/assembly_arm64.py
new file mode 100644
index 000000000..ccdeb9abf
--- /dev/null
+++ b/extra/assembly/assembly_arm64.py
@@ -0,0 +1,171 @@
+import struct
+from platform import system
+from extra.assembly.assembly import uops_to_asmstyle, AssemblyLanguage, Register
+from typing import Tuple, Set, Dict, List
+from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps
+from tinygrad.codegen.linearizer import UOps, ConstOp, UOp
+from tinygrad.helpers import dtypes, CI
+
+def float_to_hex(x): return "%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
+def compute_offsets(total):
+  quotient, remainder = divmod(total, 4096)
+  return [4096]*quotient + [remainder] if remainder else [4096]*quotient
+
+#NOTE: Darwin needs names to start with a "_"
+def get_name(name): return ('_' if system() == 'Darwin' else '') + name
+
+class ARM64Language(AssemblyLanguage): pass
+
+def specialize_to_arm64(fn_nm, asm):
+  var_size = 16
+  prev_uop = None
+  ins = []
+  x_regs = ['x' + str(i) for i in reversed(range(29)) if i not in (10,11,12,13,14,15,16,17,18,19,20)]
+  s_regs = ['s' + str(i) for i in reversed(range(3,30))]
+  type_to_reg = {dtypes.half: 'h', dtypes.float32: 's', dtypes.bool: 'w', dtypes.int8:'w', dtypes.int32: 'w', dtypes.int64: 'x', dtypes.uint8:'w', dtypes.uint32: 'w', dtypes.uint64: 'x'}
+  alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", BinaryOps.DIV: "div", BinaryOps.MAX: "max",
+         BinaryOps.MOD: "", BinaryOps.CMPLT: "subs",
+         UnaryOps.SIN:'bl ' + get_name('sinf'), UnaryOps.LOG2: 'bl ' + get_name("log2f"), UnaryOps.EXP2: 'bl ' + get_name("exp2f"), UnaryOps.SQRT: 'bl ' + get_name("sqrtf"),
+         TernaryOps.MULACC: "madd", TernaryOps.WHERE: "fcsel"}
+
+  def mov_imm(value, reg):
+    # Manually move value into reg if value can't fit
+    if value.__class__ is not float and abs(value) > abs(65535):
+      ins.append(f"movz w15, #{value & 0xffff}")
+      ins.append(f"movk w15, #{(value >> 16) & 0xffff}, lsl #16")
+      ins.append(f"sxtw {reg}, w15")
+    elif reg[0] == 's':
+      ins.append(f"movz x15, 0x{float_to_hex(value)[4:]}")
+      ins.append(f"movk x15, 0x{float_to_hex(value)[:4]}, lsl #16")
+      ins.append(f"str x15, [sp, 16]")
+      ins.append(f"ldr {reg}, [sp, 16]")
+    else:
+      ins.append(f"mov {reg}, #{value}")
+
+  # Get variable live intervals
+  live_range:Dict[str, str] = {}
+  for i, (uop, out, vin, arg) in enumerate(asm):
+    for var in ([v for v in [out] + vin if v is not None and v.__class__ is not int]):
+      live_range[var.nm] = [i,i] if var.nm not in live_range else [live_range[var.nm][0], i]
+
+  mem_vars:Dict[str, str] = {}
+  rtor:Dict[str, str] = {}
+  def allocate_regs(vars):
+    nonlocal var_size
+    for v in [v for v in vars if v is not None and v.__class__ is not int and v.nm not in rtor]:
+      available_regs = s_regs if dtypes.is_float(v[1]) else x_regs
+      #NOTE: Very simple spill, everything that doesn't fit in regs goes to mem
+      if len(available_regs) == 0:
+        # ARM needs the stack 16-byte aligned
+        var_size += 16
+        available_regs.append('s0' if dtypes.is_float(out[1]) else 'x11')
+        mem_vars[v.nm] = var_size
+      rtor[v.nm] = available_regs.pop()
+
+  temp_floats = ['s0', 's1', 's2']
+  temp_ints = ['x11', 'x12', 'x13']
+  for i, (uop, out, vin, arg) in enumerate(asm):
+    # Clear regs out of interval
+    for var, reg in list(rtor.items()):
+      available_regs = s_regs if reg[0] == 's' else x_regs
+      if var[1] not in 'B' and var not in mem_vars and i > live_range[var][1]:
+        available_regs.append(rtor.pop(var))
+    # Assign registers to the variables using live ranges.
+    allocate_regs([out] + vin)
+    # Assign temp regs to vin and load them before direct use
+    for i, v in enumerate([v for v in vin if v.__class__ is not int and v.nm in mem_vars]):
+      rtor[v.nm] = temp_floats[i] if dtypes.is_float(v[1]) else temp_ints[i]
+      # ARM64 addressing constraints https://devblogs.microsoft.com/oldnewthing/20220728-00/?p=106912
+      ins.append(f"mov x15, {mem_vars[v.nm]}")
+      ins.append(f"ldr {rtor[v.nm]}, [sp, x15]")
+
+    if uop == UOps.SPECIAL:
+      if arg.startswith('data'):
+        # data 8 to n are read from the stack
+        if int(arg[4:]) >= 8:
+          ins.append(f"ldr x15, [x19, #{(int(arg[4:]) - 8) * 8}]")
+          ins.append(f"mov {rtor[out.nm]}, x15")
+      else:
+        ins.append(f"mov {rtor[out.nm]}, #0")
+        ins.append(f"loop_{arg}:")
+    elif uop == UOps.CAST:
+      if arg == BinaryOps.CMPLT:
+        mov_imm(0.0, 's0')
+        mov_imm(1.0, 's1')
+        ins.append(f"fcsel {rtor[out.nm]}, s1, s0, lt")
+      else:
+        ins.append(f"sxtw {rtor[out.nm]}, w{rtor[vin[0].nm][1:]}")
+    elif uop == UOps.ALU:
+      if len(vin)==2 and vin[1].__class__ is int: mov_imm(vin[1], 'x15')
+      if arg == BinaryOps.MUL and out.dtype == dtypes.bool:
+        ins.append(f"ands {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}")
+      elif arg == TernaryOps.WHERE:
+        ins.append(f"fcmp {rtor[vin[0].nm]}, #0.0")
+        ins.append(f"{alu[arg]} {rtor[out.nm]}, {rtor[vin[1].nm]}, {rtor[vin[2].nm]}, ne")
+      elif arg in [UnaryOps.LOG2, UnaryOps.SIN, UnaryOps.EXP2, UnaryOps.SQRT]:
+        #NOTE: Not a real instruction, used to emulate an ext call in unicorn
+        if CI: ins.append(f"{alu[arg]} {rtor[out.nm]} {rtor[vin[0].nm]}")
+        else:
+          save_regs = [k for k in rtor.keys() if k != out.nm and k not in mem_vars]
+          ins.append(f"sub sp, sp, #{(len(save_regs))*16}")
+          # Save the registers before they are cleared by func call
+          for i,k in enumerate(save_regs,1):
+            ins.append(f"str {rtor[k]}, [sp, #{16*i}]")
+          ins.append("stp x29, x30, [sp, #0]!")
+          ins.append("mov x29, sp")
+          ins.append(f"fmov s0, {rtor[vin[0].nm]}")
+          ins.append(alu[arg])
+          ins.append(f"fmov {rtor[out.nm]}, s0")
+          ins.append("mov sp, x29")
+          ins.append("ldp x29, x30, [sp], #0")
+          for i,k in enumerate(save_regs,1):
+            ins.append(f"ldr {rtor[k]}, [sp, #{16*i}]")
+          ins.append(f"add sp, sp, #{len(save_regs)*16}")
+      elif arg == BinaryOps.CMPLT:
+        ins.append(f"{alu[arg]} {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}" if not dtypes.is_float(vin[0][1]) else f"fcmp {rtor[vin[0].nm]}, {rtor[vin[1].nm]}")
+      elif arg == BinaryOps.MOD:
+        ins.append(f"udiv x14, {rtor[vin[0].nm]}, x15")
+        ins.append(f"msub {rtor[out.nm]}, x14, x15, {rtor[vin[0].nm]}")
+      else:
+        ins.append(f"{'f' if dtypes.is_float(vin[0][1]) else 's' if arg == BinaryOps.DIV else ''}{alu[arg]} {', '.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}")
+    elif uop == UOps.LOAD:
+      if arg.__class__ in (int, float):
+        mov_imm(arg, rtor[out.nm])
+      else:
+        #NOTE: if casting is needed, load the var into the s/h0 or x/w12 temp regs
+        reg_in = type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[out.nm]
+        mov_imm(arg[0], "x15")
+        ins.append(f"add x15, {rtor[vin[0].nm]}, x15")
+        ins.append(f"ldr{'sb' if arg[2] is not None and arg[2] in (dtypes.int8, dtypes.uint8, dtypes.bool) else ''} {reg_in}, [x15]")
+        if arg[2] is not None: ins.append(f"{'fcvt' if arg[2] == dtypes.half else 'scvtf'} {rtor[out.nm]}, {reg_in}")
+    elif uop == UOps.STORE:
+      shifts = {dtypes.int64: "#3", dtypes.half: "#1", dtypes.int8:"#2", dtypes.uint8: "#2", dtypes.bool: "#2"}
+      #NOTE: if casting is needed, load the var into the s/h0 or x/w12 temp regs
+      reg_out = (type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[vin[1].nm])
+      if arg[2] is not None: ins.append(f"fcvt{'zs' if arg[2] != dtypes.half else '' } {reg_out}, {rtor[vin[1].nm]}")
+      ins.append(f"mov x15, #{arg[0]}")
+      ins.append(f"str {reg_out}, [{rtor[vin[0].nm]}, x15, lsl {shifts[arg[2]] if arg[2] is not None and arg[2] in shifts else '#0'}]")
+    elif uop == UOps.COND_BRANCH:
+      #TODO: this is a hack; it shouldn't always be a cmp before a cond branch
+      if prev_uop == UOps.LOAD:
+        ins.append(f"cmp {rtor[vin[0].nm]}, #0")
+      ins.append(f"b.{'lt' if arg[1] else 'ge'} {arg[0][1:]}")
+    elif uop == UOps.LABEL:
+      ins.append(f"{arg[1:]}:")
+    elif uop == UOps.ENDLOOP:
+      mov_imm(arg[0], "x15")
+      ins.append(f"add {rtor[vin[0].nm]}, {rtor[vin[0].nm]}, #1")
+      ins.append(f"cmp {rtor[vin[0].nm]}, x15")
+      ins.append(f"b.lt loop_{arg[1]}")
+
+    prev_uop=uop
+    # store regs into memory if needed
+    if out is not None and out.nm in mem_vars:
+      ins.append(f"mov x15, {mem_vars[out.nm]}")
+      ins.append(f"str {rtor[out.nm]}, [sp, x15]")
+  return "\n".join([f"//varsize {var_size}",".arch armv8-a",".text", f".global {get_name(fn_nm)}",".p2align 2", f"{get_name(fn_nm)}:", "mov x19, sp"] + [f"sub sp, sp, #{offset}" for offset in compute_offsets(var_size)]+ ins + [f"add sp, sp, #{offset}" for offset in compute_offsets(var_size)] +["ret", "\n"])
+
+def uops_to_arm64_asm(fn_nm:str, uops:List[UOp]) -> Tuple[str, List[int], List[int], bool]:
+  lang = ARM64Language()
+  global_size, local_size = uops_to_asmstyle(lang, fn_nm, uops)
+  return specialize_to_arm64(fn_nm, lang.ins), global_size[::-1], local_size[::-1], True
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 826b63fe7..33f1e6574 100644
--- a/setup.py
+++ b/setup.py
@@ -24,6 +24,7 @@ setup(name='tinygrad',
       extras_require={
         'llvm': ["llvmlite"],
         'cuda': ["pycuda"],
+        'arm': ["unicorn"],
        'triton': ["triton>=2.0.0.dev20221202"],
        'webgpu': ["wgpu"],
        'metal': ["pyobjc-framework-Metal", "pyobjc-framework-Cocoa", "pyobjc-framework-libdispatch"],
diff --git a/test/external/external_test_speed_llama.py b/test/external/external_test_speed_llama.py
index 233bfdad5..c009119e0 100644
--- a/test/external/external_test_speed_llama.py
+++ b/test/external/external_test_speed_llama.py
@@ -11,7 +11,7 @@ from tinygrad.helpers import dtypes, prod
 from tinygrad.runtime.lib import RawBuffer
 
 class FakeProgram:
-  def __init__(self, name:str, prg:str): pass
+  def __init__(self, name:str, prg:str, binary:bool): pass
   def __call__(self, global_size, local_size, *bufs, wait=False): pass
 
 class RawFakeBuffer(RawBuffer):
diff --git a/test/test_uops.py b/test/test_uops.py
index 8771d88bd..766ad2e31 100644
--- a/test/test_uops.py
+++ b/test/test_uops.py
@@ -1,14 +1,15 @@
 import unittest, math
 import numpy as np
-from tinygrad.helpers import dtypes
+from tinygrad.helpers import dtypes, getenv
 from tinygrad.tensor import Device
 from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ASTRunner, Compiled
 from tinygrad.codegen.linearizer import UOps, Token, ConstOp, MemOp
 from tinygrad.shape.symbolic import Variable
 
 def _uops_to_prg(uops):
-  src, global_size, local_size = Device[Device.DEFAULT].renderer("test", uops)
-  return ASTRunner("test", src, global_size, local_size).build(Device[Device.DEFAULT].runtime)
+  ret = Device[Device.DEFAULT].renderer("test", uops)
+  src, global_size, local_size, binary = ret if len(ret) == 4 else ret + (False,)
+  return ASTRunner("test", src, global_size, local_size, runtime_args={"binary": binary}).build(Device[Device.DEFAULT].runtime)
 
 def _test_single_value(tc, tt, vals, op):
   uops = [
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index c49a5e3d9..aba53b2ce 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -156,10 +156,12 @@ class Compiled:
 
   def to_program(self, k):
     k.linearize()
-    src, global_size, local_size = self.renderer(k.function_name, k.uops)
+    ret = self.renderer(k.function_name, k.uops)
+    src, global_size, local_size, binary = ret if len(ret) == 4 else ret + (False,)
+    #TODO: I need to find a better way to select ARM64
     return ASTRunner(k.function_name, src, global_size, local_size,
       op_estimate=k.info.flops, mem_estimate=k.mem_estimate,
-      display_name=k.display_name).build(self.runtime)
+      display_name=k.display_name, runtime_args={"binary": binary}).build(self.runtime)
 
   def exec_ast(self, ast:LazyOp, output, **kwargs):
     # all movementops do nothing in a Compiled buffer!
diff --git a/tinygrad/runtime/ops_clang.py b/tinygrad/runtime/ops_clang.py
index b6f7fb90a..447edd501 100644
--- a/tinygrad/runtime/ops_clang.py
+++ b/tinygrad/runtime/ops_clang.py
@@ -1,8 +1,15 @@
 import os, time, ctypes, hashlib, subprocess, platform, tempfile, functools
+from functools import partial, reduce
 from tinygrad.ops import Compiled
+from tinygrad.helpers import fromimport, getenv, DEBUG, CI
 from tinygrad.runtime.lib import RawMallocBuffer
 from tinygrad.codegen.linearizer import LinearizerOptions
 from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage
+import struct
+import numpy as np
+
+ARM64 = getenv('ARM64', False)
+if CI and ARM64: from unicorn import Uc, UC_ARCH_ARM64, UC_MODE_ARM, UC_HOOK_CODE, arm64_const
 
 args = {
   'Windows': {'cflags':'', 'ext':'dll', 'exp':'__declspec(dllexport)'},
@@ -11,24 +18,64 @@ args = {
 }[platform.system()]
 
 CLANG_PROGRAM_HEADER = '#include <math.h>\n#define max(x,y) ((x>y)?x:y)\n#define int64 long\n#define half __fp16\n#define uchar unsigned char\n#define bool uchar\n'
+ADDRESS = 0x10000
+
+# Unicorn doesn't support external calls
+def align(addr): return (addr+4095) & ~(4095)
+mock_lm = {"sinf": np.sin, "sqrtf": np.sqrt, "exp2f": np.exp2, "log2f": np.log2}
+def emulate_ext_calls(fn, uc, address, size, user_data):
+  s_in = struct.unpack('f', struct.pack('I', uc.reg_read(getattr(arm64_const, f'UC_ARM64_REG_S{fn[2][1:]}'))))[0]
+  uc.reg_write(getattr(arm64_const, f'UC_ARM64_REG_S{fn[1][1:]}'), struct.unpack('I', struct.pack('f', mock_lm[fn[0]](s_in)))[0]) # type: ignore
+
 class ClangProgram:
-  def __init__(self, name:str, prg:str):
-    prg = CLANG_PROGRAM_HEADER + prg
+  def __init__(self, name:str, prg:str, binary:bool=False):
     # TODO: is there a way to not write this to disk?
     # A: it seems there isn't https://stackoverflow.com/questions/28053328/ctypes-cdll-load-library-from-memory-rather-than-file
     # because ctypes.CDLL() calls dlopen (POSIX) or LoadLibrary (Windows) which require a file
     fn = f"{tempfile.gettempdir()}/clang_{hashlib.md5(prg.encode('utf-8')).hexdigest()}.{args['ext']}"
     if not os.path.exists(fn):
       tmp = f"{fn}.{os.getpid()}.tmp"
-      subprocess.check_output(args=('clang -shared -O2 -Wall -Werror -x c '+args['cflags']+' - -o '+tmp).split(), input=prg.encode('utf-8'))
-      os.rename(tmp, fn)
+      if not binary:
+        prg = CLANG_PROGRAM_HEADER + prg
+        subprocess.check_output(args=('clang -shared -O2 -Wall -Werror -x c '+args['cflags']+' - -o '+tmp).split(), input=prg.encode('utf-8'))
+        os.rename(tmp, fn)
+      else:
+        if DEBUG >= 5: print(prg)
+        if CI and ARM64:
+          prg = prg.split('\n') # type: ignore
+          self.varsize = align(int(prg[0].split(" ")[1]))
+          self.ext_calls = {(i*4+ADDRESS):ins.split(" ")[1:] for i, ins in enumerate(filter(lambda ins: ins[:4] != 'loop', prg[6:-3])) if ins[:2] == 'bl'}
+          prg = "\n".join(['nop' if ins[:2] == 'bl' else ins for ins in prg[6:-3]] + ['\n'])
+          subprocess.check_output(args=('aarch64-linux-gnu-as -o '+tmp).split(), input=prg.encode('utf-8'))
+          subprocess.check_output(args=('aarch64-linux-gnu-objcopy -O binary --only-section=.text '+tmp+' '+fn+'.bin').split())
+          self.prg = open(fn + '.bin', 'rb').read()
+          return
+        subprocess.check_output(args=('as -o' + tmp).split(), input=prg.encode('utf-8'))
+        subprocess.check_output(args=('clang -lm -shared '+tmp+' -o'+fn).split())
     self.lib = ctypes.CDLL(fn)
     self.fxn = self.lib[name]
-
   def __call__(self, global_size, local_size, *args, wait=False):
     if wait: st = time.monotonic()
-    self.fxn(*[x._buf for x in args])
+    if CI and ARM64:
+      mu = Uc(UC_ARCH_ARM64, UC_MODE_ARM)
+      total_mem = align(reduce(lambda total, arg: total + arg.size * arg.dtype.itemsize, args, len(self.prg)+self.varsize))
+      mu.mem_map(ADDRESS, total_mem)
+      for k, fn in self.ext_calls.items(): mu.hook_add(UC_HOOK_CODE, partial(emulate_ext_calls, fn), begin=k, end=k)
+      mu.mem_write(ADDRESS, self.prg + b''.join(bytes(arg._buf) for arg in args))
+      addr = ADDRESS + len(self.prg)
+      for i, arg in enumerate(args):
+        if i<=7:
+          mu.reg_write(getattr(arm64_const, f'UC_ARM64_REG_X{i}'), addr)
+        else:
+          # NOTE: In ARM, args beyond the first 8 are placed on the stack; this also accounts for the stack red zone.
+          mu.mem_write(ADDRESS + total_mem - (len(args[8:])+2)*8 + 8*(i-8), addr.to_bytes(8, 'little'))
+        addr += arg.size * arg.dtype.itemsize
+      mu.reg_write(arm64_const.UC_ARM64_REG_SP, ADDRESS + total_mem - (len(args[8:])+2)*8)
+      mu.emu_start(ADDRESS, ADDRESS + len(self.prg))
+      args[0]._buf = mu.mem_read(mu.reg_read(arm64_const.UC_ARM64_REG_X0), args[0].size * args[0].dtype.itemsize)
+    else:
+      self.fxn(*[x._buf for x in args])
     if wait: return time.monotonic()-st
 
-renderer = functools.partial(uops_to_cstyle, CStyleLanguage(kernel_prefix=args['exp'], buffer_suffix=" restrict"))
+renderer = fromimport("extra.assembly.assembly_arm64", "uops_to_arm64_asm") if ARM64 else functools.partial(uops_to_cstyle, CStyleLanguage(kernel_prefix=args['exp'], buffer_suffix=" restrict"))
 ClangBuffer = Compiled(RawMallocBuffer, LinearizerOptions(supports_float4=False, has_local=False), renderer, ClangProgram)
diff --git a/tinygrad/runtime/ops_metal.py b/tinygrad/runtime/ops_metal.py
index 918878c0e..4c5e564d1 100644
--- a/tinygrad/runtime/ops_metal.py
+++ b/tinygrad/runtime/ops_metal.py
@@ -37,7 +37,7 @@ def unwrap(x):
   return ret
 
 class MetalProgram:
-  def __init__(self, name:str, prg:str):
+  def __init__(self, name:str, prg:str, binary:bool=False):
    if METAL_XCODE:
      air = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metal', '-x', 'metal', '-c', '-', '-o', '-'], input=prg.encode('utf-8'))
      # NOTE: if you run llvm-dis on "air" you can see the llvm bytecode