* testing new memops

* better debugging

* testing padded conv

* branching with load

* refactoring a bit

* first try

* fixing bugs

* fixing some

* eq

* eq2

* do not use x's

* working

* fixing imm

* getting things working

* refactor

* pow not working

* working except one

* refactor: one store mem

* refactor: global load

* refactor: imm

* refactor: cleaning

* fixing big offsets

* refactor with ci

* try ci

* typo

* another typo

* ubuntu default

* forgot git

* do i need git?

* missing packages

* adding python-dev

* with cache?

* buildx action

* buildx name issue?

* maybe now?

* python3

* newline warning

* maybe now

* i actually need this

* ci should work now

* improved caching

* fixing cache

* maybe now it will cache

* this

* testing cache

* trying again

* load

* missing platform

* caching gha

* testing cache

* full testing

* typo

* now?

* why

* adding checkout back

* bad formatting

* fixing convention issues

* supporting python

* adding CI flag

* testing all

* better comments

* adding debugging

* takes 12x longer

* does it output progress now?

* ignore models for speed

* fixing merge

* excluding conv_transpose2d

* only 2 tests cuz it's too slow

* another approach

* let's see

* faster duh

* my bad

* T_T

* typo

* sup

* with output?

* comment test

* comment test

* comment test

* :?

* no comment

* with cache

* back to normal

* testing that ci works

* back to passing

* trying again

* does it create another entry

* does it create another entry?

* build local

* hey

* Revert "excluding conv_transpose2d"

This reverts commit cc7348de03.

* does it cache if done before?

* does it cache?

* done

* adding test ops

* bad formatting

* no need for this

* working static mem

* sum 1d

* add ndim

* better reg import

* fix stack

* back to np

* working except for softmax

* 5 failing

* no progress

* remove keystone

* remove keystone

* testops passing

* cleanups

* more cleanup

* typo

* ci

* ci2

* cond import

* ci3

* ci4

* ci4

* ci5

* ci5

* ci6

* alignment

* test all

* correct test

* err read_unmapped

* passing test

* ignore for speed

* ignore for speed

* ci7

* cleanup

* remove docker

* fixing merge

* fixing bugs

* add skipload for const ops

* comments

* First merge to master: Renderer

* fix emulation

* passing all tests arm64

* cleaning

* fix handcoded binary

* cleaning

* fix errs

* fix runtime arg binary

* clean git diff

* fix and clean

* fixing metal test

* cleaning

* fix metal test

* ci ~8 min

* fix pylint and clang

* cache the files in ops_clang

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
pull/1523/head^2
Steven Anderson 2023-08-14 22:29:30 -04:00 committed by GitHub
parent a89142e46f
commit 93a36c3659
9 changed files with 409 additions and 156 deletions


@@ -267,3 +267,28 @@ jobs:
- name: Run pytest (cuda)
if: matrix.backend=='cuda'
run: python -m pytest -n=auto test/ -k 'not (half or test_efficientnet_safetensors) and not (test_conv2d and test_tensor.py)' -m 'not exclude_cuda' --ignore=test/external --ignore=test/models
testunicorn:
name: ARM64 unicorn Test
runs-on: ubuntu-latest
timeout-minutes: 20
steps:
- name: Checkout Code
uses: actions/checkout@v3
- name: Set up Python 3.8
uses: actions/setup-python@v4
with:
python-version: 3.8
- name: Cache pip
uses: actions/cache@v3
with:
path: '~/.cache/pip'
key: unicorn
- name: Install cross-assembler
run: |
sudo apt-get update -y && \
sudo apt-get install -y --no-install-recommends gcc-aarch64-linux-gnu
- name: Install dependencies
run: pip install -e '.[testing,arm]' --extra-index-url https://download.pytorch.org/whl/cpu
- name: Test arm
run: CI=1 ARM64=1 CLANG=1 python -m pytest -n=auto test/ -k 'not (test_nn.py and (test_conv_transpose2d or test_conv2d))' --ignore=test/models --ignore=test/test_speed_v_torch.py --ignore=test/test_net_speed.py --ignore=test/test_specific_conv.py --ignore=test/unit/test_disk_tensor.py
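The flags in that last command route execution through the new path: CLANG=1 selects the Clang backend, ARM64=1 swaps its renderer for the ARM64 assembler, and CI=1 makes ops_clang emulate the cross-assembled binary with unicorn instead of running it natively (see the ops_clang changes below). A minimal local smoke test under the same assumptions, after pip install -e '.[testing,arm]':

# run as: CI=1 ARM64=1 CLANG=1 python smoke.py  (script name is illustrative)
from tinygrad.tensor import Tensor

# one elementwise kernel: rendered to ARM64 asm, assembled with aarch64-linux-gnu-as,
# then executed under the unicorn emulator because CI=1
print((Tensor([1.0, 2.0]) + Tensor([3.0, 4.0])).numpy())  # [4. 6.]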


@@ -1,7 +1,7 @@
from typing import Tuple, List, NamedTuple, Any, Dict, Optional, Union, DefaultDict
from tinygrad.codegen.linearizer import Linearizer, UOps, Token
from tinygrad.ops import ASTRunner, BinaryOps, UnaryOps
from tinygrad.helpers import DType, dtypes, DEBUG
from tinygrad.codegen.linearizer import Linearizer, UOps, Token, ConstOp, MemOp, UOp
from tinygrad.ops import BinaryOps, UnaryOps
from tinygrad.helpers import DType, dtypes, DEBUG, getenv
from tinygrad.shape.symbolic import Variable, NumNode, MulNode, DivNode, ModNode, LtNode, SumNode, AndNode
import functools
import math
@@ -20,6 +20,7 @@ class Register(NamedTuple):
if self.dtype == dtypes._float4:
return [Register(self.nm, dtypes.float, False, off=off) for off in range(4)]
return []
class AssemblyInstruction(NamedTuple):
op: UOps
out: Optional[Register]
@@ -27,156 +28,161 @@ class AssemblyInstruction(NamedTuple):
arg: Any = None
# warp size of 32, s registers are shared across the warp, v are 32-wide vectors
class AssemblyCodegen(Linearizer):
class AssemblyLanguage(NamedTuple):
supports_load3: bool = False
sin_is_sin2pi: bool = False
no_div: bool = False
#TODO: these should be global vars
cnts:DefaultDict[Tuple[DType, bool], int] = defaultdict(int)
tor: Dict[Any, Register] = {}
ins = []
def specialize(self, asm:List[AssemblyInstruction]) -> Tuple[str, str]:
raise NotImplementedError("must be implemented")
def newreg(self, tok, dtype=dtypes.float32, scalar=False):
if isinstance(tok, Token): dtype = tok.dtype # this
self.tor[tok] = ret = Register(f"%{type_to_letter((dtype, scalar))}{self.cnts[(dtype, scalar)]}", dtype, scalar)
if dtype == dtypes._float4:
for off in range(4):
self.tor[Token(tok.name, tok.dtype, off)] = Register(ret.nm, dtypes.float, ret.scalar, off)
self.cnts[(dtype, scalar)] += 1
return ret
# s registers are the addresses and non local indexes
def codegen(self):
self.process()
self.hand_coded_optimizations()
self.limit_global_dims(3) # all GPU asms have 3 (for now)
self.linearize()
def render_numnode(self, b):
key = ("num", b)
if key not in self.tor: self.ins.append(AssemblyInstruction(UOps.LOAD, self.newreg(key, scalar=True, dtype=dtypes.int32), [], b))
return self.tor[key]
cnts:DefaultDict[Tuple[DType, bool], int] = defaultdict(int)
tor: Dict[Any, Register] = {}
def newreg(tok, dtype=dtypes.float32, scalar=False):
nonlocal cnts, tor
if isinstance(tok, Token): dtype = tok.dtype # this
tor[tok] = ret = Register(f"%{type_to_letter((dtype, scalar))}{cnts[(dtype, scalar)]}", dtype, scalar)
if dtype == dtypes._float4:
for off in range(4):
tor[Token(tok.name, tok.dtype, off)] = Register(ret.nm, dtypes.float, ret.scalar, off)
cnts[(dtype, scalar)] += 1
return ret
def render_alu(self, op, a:Register, b:Union[Register, int, float], dtype=dtypes.int32) -> Register:
key = (op, a, b)
if key not in self.tor:
#if not isinstance(b, Register): b = render_numnode(b)
self.ins.append(AssemblyInstruction(UOps.ALU, self.newreg(key, dtype=dtype, scalar=a.scalar and (not isinstance(b, Register) or b.scalar)), [a, b], op))
return self.tor[key]
def render_numnode(b):
key = ("num", b)
if key not in tor: ins.append(AssemblyInstruction(UOps.CONST, newreg(key, scalar=True, dtype=dtypes.int32), [], b))
return tor[key]
def render_cast(self, a:Register, new_dtype:DType) -> Register:
if a.dtype == new_dtype: return a
key = (a, new_dtype)
if key not in self.tor:
self.ins.append(AssemblyInstruction(UOps.CAST, self.newreg(key, dtype=new_dtype), [a]))
return self.tor[key]
def render_alu(op, a:Register, b:Union[Register, int, float], dtype=dtypes.int32) -> Register:
key = (op, a, b)
if key not in tor:
#if not isinstance(b, Register): b = render_numnode(b)
ins.append(AssemblyInstruction(UOps.ALU, newreg(key, dtype=dtype, scalar=a.scalar and (not isinstance(b, Register) or b.scalar)), [a, b], op))
return tor[key]
render_ops = { Variable: lambda self, ops, ctx: ctx.tor[self], NumNode: lambda self, ops, ctx: ctx.render_numnode(self.b),
MulNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.MUL, self.a.render(ops, ctx), self.b),
DivNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.DIV, self.a.render(ops, ctx), self.b),
ModNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.MOD, self.a.render(ops, ctx), self.b),
LtNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.CMPLT, self.a.render(ops, ctx), self.b, dtype=dtypes.bool),
SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.render_alu(BinaryOps.ADD, a, b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx)),
AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.render_alu(BinaryOps.MUL, a, b.render(ops,ctx), dtype=dtypes.bool), self.nodes[1:], self.nodes[0].render(ops,ctx)) }
def render_cast(a:Register, new_dtype:DType) -> Register:
if a.dtype == new_dtype: return a
key = (a, new_dtype)
if key not in tor:
ins.append(AssemblyInstruction(UOps.CAST, newreg(key, dtype=new_dtype), [a]))
return tor[key]
def addr_w_offset(self, args):
assert isinstance(args, MemOp)
idx = args.idx*args.memory_dtype.itemsize
off = 0 # TODO: should this be None?
if isinstance(idx, SumNode):
nums = [n.b for n in idx.nodes if isinstance(n, NumNode)]
if len(nums) > 0 and nums[0] < 4096 and (idx-nums[0]).min >= 0: # TODO: different for each GPU?
idx -= nums[0]
off = nums[0]
reg = idx.render(self.render_ops, self)
if self.supports_load3:
if reg.scalar:
new_reg = self.newreg((reg.nm, 'vec'), dtype=reg.dtype)
self.ins.append(AssemblyInstruction(UOps.ALU, new_reg, [reg], UnaryOps.NOOP))
reg = new_reg
return self.tor[args.name], reg, off
reg = self.render_alu(BinaryOps.ADD, self.render_cast(reg, dtypes.uint64), self.tor[args.name], dtype=dtypes.uint64)
return reg, None, off
render_ops = { Variable: lambda self, ops, ctx: tor[self], NumNode: lambda self, ops, ctx: render_numnode(self.b),
MulNode: lambda self, ops, ctx: render_alu(BinaryOps.MUL, self.a.render(ops, ctx), self.b),
DivNode: lambda self, ops, ctx: render_alu(BinaryOps.DIV, self.a.render(ops, ctx), self.b),
ModNode: lambda self, ops, ctx: render_alu(BinaryOps.MOD, self.a.render(ops, ctx), self.b),
LtNode: lambda self, ops, ctx: render_alu(BinaryOps.CMPLT, self.a.render(ops, ctx), self.b, dtype=dtypes.bool),
SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: render_alu(BinaryOps.ADD, a, b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx)),
AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: render_alu(BinaryOps.MUL, a, b.render(ops,ctx), dtype=dtypes.bool), self.nodes[1:], self.nodes[0].render(ops,ctx)) }
def uops_to_asmstyle(lang, function_name:str, uops:List[UOp]):
#TODO: Do not use clear()
lang.ins.clear()
lang.tor.clear()
buf_to_dtype = {args[0]:args[1] for uop,_,_,args in uops if uop == UOps.DEFINE_GLOBAL}
buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())}
global_size, local_size = [], []
skipload_branch = 0
lang.ins += [AssemblyInstruction(UOps.SPECIAL, lang.newreg(buf, dtype=dtypes.uint64, scalar=True), [], buf) for buf in buf_to_dtype]
for uop,newvar,vin,args in uops:
if uop == UOps.DEFINE_LOCAL:
lang.ins.append(AssemblyInstruction(UOps.DEFINE_LOCAL, None, [], args))
lang.ins.append(AssemblyInstruction(UOps.ALU, lang.newreg(args[0], dtype=dtypes.uint64), [args[0]], UnaryOps.NOOP))
elif uop == UOps.LOOP:
if args[1] == "global":
for i,var in enumerate(args[0]):
global_size.append(var.max+1)
lang.ins.append(AssemblyInstruction(UOps.SPECIAL, lang.newreg(var, dtype=dtypes.int32), [], f"gid{len(args[0])-1-i}"))
elif args[1] == "local":
for i,var in enumerate(args[0]):
local_size.append(var.max+1)
lang.ins.append(AssemblyInstruction(UOps.SPECIAL, lang.newreg(var, dtype=dtypes.int32), [], f"lid{len(args[0])-1-i}"))
else:
for var in args[0]:
if not isinstance(var, NumNode): # TODO: why is this coming through?
lang.ins.append(AssemblyInstruction(UOps.LOAD, lang.newreg(var, dtype=dtypes.int32, scalar=True), [], 0)) #FIXME: what should valid be here?
lang.ins.append(AssemblyInstruction(UOps.LABEL, None, [], "$loop_"+var.expr))
elif uop == UOps.ENDLOOP:
if args[1] not in ["global", "local", "global+local"]:
for var in reversed(args[0]):
if not isinstance(var, NumNode): # TODO: why is this coming through?
lang.ins.append(AssemblyInstruction(UOps.ALU, lang.tor[var], [lang.tor[var], 1], BinaryOps.ADD))
pred = lang.render_alu(BinaryOps.CMPLT, lang.tor[var], var.max+1, dtypes.bool)
lang.ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], ("$loop_"+var.expr, True)))
elif args[1] == "global+local":
for i, var in enumerate(reversed(args[0])):
lang.ins.append(AssemblyInstruction(UOps.ENDLOOP, None, [lang.tor[var]], (var.max+1, f"gid{i}")))
def addr_w_offset(args):
idx = args.idx*self.bufs[args.i].dtype.itemsize
off = 0 # TODO: should this be None?
if isinstance(idx, SumNode):
nums = [n.b for n in idx.nodes if isinstance(n, NumNode)]
if len(nums) > 0 and nums[0] < 4096 and (idx-nums[0]).min >= 0: # TODO: different for each GPU?
idx -= nums[0]
off = nums[0]
reg = idx.render(render_ops)
if self.supports_load3:
if reg.scalar:
new_reg = newreg((reg.nm, 'vec'), dtype=reg.dtype)
ins.append(AssemblyInstruction(UOps.ALU, new_reg, [reg], UnaryOps.NOOP))
reg = new_reg
return tor[f"buf{args.i}"], reg, off
reg = render_alu(BinaryOps.ADD, render_cast(reg, dtypes.uint64), tor[f"buf{args.i}"], dtype=dtypes.uint64)
return reg, None, off
ins = []
ins += [AssemblyInstruction(UOps.SPECIAL, newreg(f"buf{i}", dtype=dtypes.uint64, scalar=True), [], f"buf{i}") for i in range(len(self.bufs))]
global_size, local_size = [], []
skipload_branch = 0
for uop,newvar,vin,args in self.uops:
if uop == UOps.CONST and newvar is not None:
ins.append(AssemblyInstruction(UOps.CONST, newreg(newvar, dtype=newvar.dtype), [], args))
elif uop == UOps.DEFINE_LOCAL:
ins.append(AssemblyInstruction(UOps.DEFINE_LOCAL, None, [], args))
ins.append(AssemblyInstruction(UOps.ALU, newreg("buf-1", dtype=dtypes.uint64), [args[0]], UnaryOps.NOOP))
elif uop == UOps.LOOP:
if args[1] == "global":
for i,var in enumerate(args[0]):
global_size.append(var.max+1)
ins.append(AssemblyInstruction(UOps.SPECIAL, newreg(var, dtype=dtypes.int32), [], f"gid{len(args[0])-1-i}"))
elif args[1] == "local":
for i,var in enumerate(args[0]):
local_size.append(var.max+1)
ins.append(AssemblyInstruction(UOps.SPECIAL, newreg(var, dtype=dtypes.int32), [], f"lid{len(args[0])-1-i}"))
else:
for var in args[0]:
if not isinstance(var, NumNode): # TODO: why is this coming through?
ins.append(AssemblyInstruction(UOps.CONST, newreg(var, dtype=dtypes.int32, scalar=True), [], 0))
ins.append(AssemblyInstruction(UOps.LABEL, None, [], "$loop_"+var.expr))
elif uop == UOps.ENDLOOP:
if args[1] not in ["global", "local", "global+local"]:
for var in reversed(args[0]):
if not isinstance(var, NumNode): # TODO: why is this coming through?
ins.append(AssemblyInstruction(UOps.ALU, tor[var], [tor[var], 1], BinaryOps.ADD))
pred = render_alu(BinaryOps.CMPLT, tor[var], var.max+1, dtypes.bool)
ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], ("$loop_"+var.expr, True)))
elif uop == UOps.CAST and newvar is not None:
# TODO: we should reconsider outputting CAST in the linearizer. these are needless copies
out = newreg(newvar)
for i,sr in enumerate(out.subregs()):
ins.append(AssemblyInstruction(UOps.ALU, sr, [tor[vin[i]]], UnaryOps.NOOP))
elif uop == UOps.ALU and newvar is not None:
out = newreg(newvar) if newvar not in tor else tor[newvar]
# this is the only thing that can violate SSA
if args in [BinaryOps.CMPLT]:
pred_reg = newreg((newvar, 'pred'), dtype=dtypes.bool)
ins.append(AssemblyInstruction(UOps.ALU, pred_reg, [tor[x] for x in vin], args))
ins.append(AssemblyInstruction(UOps.CAST, out, [pred_reg], args))
elif args == BinaryOps.DIV and self.no_div:
tmp = newreg((newvar, "rcp"))
ins.append(AssemblyInstruction(UOps.ALU, tmp, [tor[vin[1]]], UnaryOps.RECIP))
ins.append(AssemblyInstruction(UOps.ALU, out, [tor[vin[0]], tmp], BinaryOps.MUL))
elif args == UnaryOps.SIN and self.sin_is_sin2pi:
tmp = newreg((newvar, "2pi"))
ins.append(AssemblyInstruction(UOps.ALU, tmp, [tor[vin[0]], 1/(math.pi*2)], BinaryOps.MUL))
ins.append(AssemblyInstruction(UOps.ALU, out, [tmp], args))
else:
ins.append(AssemblyInstruction(UOps.ALU, out, [tor[x] for x in vin], args))
elif uop == UOps.LOAD and newvar is not None:
idx, treg, off = addr_w_offset(args)
reg = newreg(newvar, dtype=newvar.dtype, scalar=(idx.scalar and (not isinstance(treg, Register) or treg.scalar))) # and not dtypes.is_float(newvar.dtype)))
if args.valid.min == 0:
ins.append(AssemblyInstruction(UOps.CONST, reg, [], 0))
if args.valid.max == 1:
pred = args.valid.render(render_ops)
ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], (f"$skipload_{skipload_branch}", False)))
if args.valid.max == 1:
# NOTE: you can't compute the index in here, because it assumes it's all available later
ins.append(AssemblyInstruction(UOps.LOAD, reg, [idx] + ([treg] if treg is not None else []), (off, 'global' if args.i != -1 else 'shared')))
elif uop == UOps.CAST and newvar is not None:
# TODO: we should reconsider outputting CAST in the linearizer. these are needless copies
out = lang.newreg(newvar)
for i,sr in enumerate(out.subregs()):
lang.ins.append(AssemblyInstruction(UOps.ALU, sr, [lang.tor[vin[i]]], UnaryOps.NOOP))
elif uop == UOps.ALU and newvar is not None:
out = lang.newreg(newvar) if newvar not in lang.tor else lang.tor[newvar]
# this is the only thing that can violate SSA
if args in [BinaryOps.CMPLT]:
pred_reg = lang.newreg((newvar, 'pred'), dtype=dtypes.bool)
lang.ins.append(AssemblyInstruction(UOps.ALU, pred_reg, [lang.tor[x] for x in vin], args))
lang.ins.append(AssemblyInstruction(UOps.CAST, out, [pred_reg], args))
elif args == BinaryOps.DIV and lang.no_div:
tmp = lang.newreg((newvar, "rcp"))
lang.ins.append(AssemblyInstruction(UOps.ALU, tmp, [lang.tor[vin[1]]], UnaryOps.RECIP))
lang.ins.append(AssemblyInstruction(UOps.ALU, out, [lang.tor[vin[0]], tmp], BinaryOps.MUL))
elif args == UnaryOps.SIN and lang.sin_is_sin2pi:
tmp = lang.newreg((newvar, "2pi"))
lang.ins.append(AssemblyInstruction(UOps.ALU, tmp, [lang.tor[vin[0]], 1/(math.pi*2)], BinaryOps.MUL))
lang.ins.append(AssemblyInstruction(UOps.ALU, out, [tmp], args))
else:
lang.ins.append(AssemblyInstruction(UOps.ALU, out, [lang.tor[x] for x in vin], args))
elif uop == UOps.LOAD and newvar is not None:
if isinstance(args, ConstOp):
if args.valid.min == 0 and args.valid.max == 1:
ins.append(AssemblyInstruction(UOps.LABEL, None, [], f"$skipload_{skipload_branch}"))
lang.ins.append(AssemblyInstruction(UOps.LOAD, lang.newreg(newvar, dtype=newvar.dtype), [], args.invalid_value))
pred = args.valid.render(lang.render_ops, lang)
lang.ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], (f"$skipload_{skipload_branch}", False)))
lang.ins.append(AssemblyInstruction(UOps.LOAD, lang.newreg(newvar, dtype=newvar.dtype), [], args.value))
lang.ins.append(AssemblyInstruction(UOps.LABEL, None, [], f"$skipload_{skipload_branch}"))
skipload_branch += 1
elif uop == UOps.STORE:
idx, treg, off = addr_w_offset(args)
ins.append(AssemblyInstruction(UOps.STORE, None, [idx, tor[vin[0]]] + ([treg] if treg is not None else []), (off, 'global' if args.i != -1 else 'shared')))
else:
lang.ins.append(AssemblyInstruction(UOps.LOAD, lang.newreg(newvar, dtype=newvar.dtype), [], args.value if args.valid.min == 1 else args.invalid_value))
else:
idx, treg, off = lang.addr_w_offset(args)
reg = lang.newreg(newvar, dtype=newvar.dtype, scalar=(idx.scalar and (not isinstance(treg, Register) or treg.scalar))) # and not dtypes.is_float(newvar.dtype)))
if args.valid.min == 0:
lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [], 0))
if args.valid.max == 1:
pred = args.valid.render(lang.render_ops, lang)
lang.ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], (f"$skipload_{skipload_branch}", False)))
if args.valid.max == 1:
# NOTE: you can't compute the index in here, because it assumes it's all available later
lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [idx] + ([treg] if treg is not None else []), (off, 'global' if buf_index[args.name] != -1 else 'shared', args.memory_dtype if buf_to_dtype[args.name] != dtypes.float else None)))
if args.valid.min == 0 and args.valid.max == 1:
lang.ins.append(AssemblyInstruction(UOps.LABEL, None, [], f"$skipload_{skipload_branch}"))
skipload_branch += 1
elif uop == UOps.STORE:
idx, treg, off = lang.addr_w_offset(args)
lang.ins.append(AssemblyInstruction(UOps.STORE, None, [idx, lang.tor[vin[0]]] + ([treg] if treg is not None else []), (off, 'global' if buf_index[args.name] != -1 else 'shared', args.memory_dtype if buf_to_dtype['data0'] != dtypes.float else None)))
# define registers
lang.ins = [AssemblyInstruction(UOps.DEFINE_REGISTER, None, [], (dtype, type_to_letter(dtype), c)) for dtype,c in lang.cnts.items()] + lang.ins
# define registers
ins = [AssemblyInstruction(UOps.DEFINE_REGISTER, None, [], (dtype, type_to_letter(dtype), c)) for dtype,c in cnts.items()] + ins
if DEBUG >= 4:
for tins in ins: print(tins)
name, asm = self.specialize(ins)
return ASTRunner(name, asm,
global_size[::-1], local_size[::-1],
op_estimate=self.info.flops, mem_estimate=self.mem_estimate, display_name=self.display_name, runtime_args={"binary": True})
if DEBUG >= 4:
for tins in lang.ins: print(tins)
return global_size, local_size
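The net effect of the changes in this file: the old AssemblyCodegen.codegen method, which built an ASTRunner itself, is split into an AssemblyLanguage state holder plus a free function uops_to_asmstyle(lang, function_name, uops) that fills lang.ins and returns (global_size, local_size); specialization and runner construction now happen in each backend. A minimal sketch of how a backend plugs in, modeled on uops_to_arm64_asm below (MyLang and specialize_to_my_isa are hypothetical names, not part of this diff):

from extra.assembly.assembly import uops_to_asmstyle, AssemblyLanguage

class MyLang(AssemblyLanguage): pass

def uops_to_my_asm(fn_nm, uops):
    lang = MyLang()
    global_size, local_size = uops_to_asmstyle(lang, fn_nm, uops)  # fills lang.ins in place
    src = specialize_to_my_isa(fn_nm, lang.ins)                    # hypothetical ISA-specific pass
    return src, global_size[::-1], local_size[::-1], True          # True marks the output as binary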


@@ -0,0 +1,171 @@
import struct
from platform import system
from extra.assembly.assembly import uops_to_asmstyle, AssemblyLanguage, Register
from typing import Tuple, Set, Dict, List
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps
from tinygrad.codegen.linearizer import UOps, ConstOp, UOp
from tinygrad.helpers import dtypes, CI
def float_to_hex(x): return "%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
def compute_offsets(total):
quotient, remainder = divmod(total, 4096)
return [4096]*quotient + [remainder] if remainder else [4096]*quotient
#NOTE: Darwin needs names to start with a "_"
def get_name(name): return ('_' if system() == 'Darwin' else '') + name
class ARM64Language(AssemblyLanguage): pass
def specialize_to_arm64(fn_nm, asm):
var_size = 16
prev_uop = None
ins = []
x_regs = ['x' + str(i) for i in reversed(range(29)) if i not in (10,11,12,13,14,15,16,17,18,19,20)]
s_regs = ['s' + str(i) for i in reversed(range(3,30))]
type_to_reg = {dtypes.half: 'h', dtypes.float32: 's', dtypes.bool: 'w', dtypes.int8:'w', dtypes.int32: 'w', dtypes.int64: 'x', dtypes.uint8:'w', dtypes.uint32: 'w', dtypes.uint64: 'x'}
alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", BinaryOps.DIV: "div", BinaryOps.MAX: "max",
BinaryOps.MOD: "", BinaryOps.CMPLT: "subs",
UnaryOps.SIN:'bl ' + get_name('sinf'), UnaryOps.LOG2: 'bl ' + get_name("log2f"), UnaryOps.EXP2: 'bl ' + get_name("exp2f"), UnaryOps.SQRT: 'bl ' + get_name("sqrtf"),
TernaryOps.MULACC: "madd", TernaryOps.WHERE: "fcsel"}
def mov_imm(value, reg):
# Manually move value into reg if value can't fit
if value.__class__ is not float and abs(value) > abs(65535):
ins.append(f"movz w15, #{value & 0xffff}")
ins.append(f"movk w15, #{(value >> 16) & 0xffff}, lsl #16")
ins.append(f"sxtw {reg}, w15")
elif reg[0] == 's':
ins.append(f"movz x15, 0x{float_to_hex(value)[4:]}")
ins.append(f"movk x15, 0x{float_to_hex(value)[:4]}, lsl #16")
ins.append(f"str x15, [sp, 16]")
ins.append(f"ldr {reg}, [sp, 16]")
else:
ins.append(f"mov {reg}, #{value}")
# Get the variables' live intervals
live_range:Dict[str, str] = {}
for i, (uop, out, vin, arg) in enumerate(asm):
for var in ([v for v in [out] + vin if v is not None and v.__class__ is not int]):
live_range[var.nm] = [i,i] if var.nm not in live_range else [live_range[var.nm][0], i]
mem_vars:Dict[str, str] = {}
rtor:Dict[str, str] = {}
def allocate_regs(vars):
nonlocal var_size
for v in [v for v in vars if v is not None and v.__class__ is not int and v.nm not in rtor]:
available_regs = s_regs if dtypes.is_float(v[1]) else x_regs
#NOTE: Very simple spill, everything that doesn't fit in regs goes to mem
if len(available_regs) == 0:
# ARM needs the stack 16-byte aligned
var_size += 16
available_regs.append('s0' if dtypes.is_float(out[1]) else 'x11')
mem_vars[v.nm] = var_size
rtor[v.nm] = available_regs.pop()
temp_floats = ['s0', 's1', 's2']
temp_ints = ['x11', 'x12', 'x13']
for i, (uop, out, vin, arg) in enumerate(asm):
# Clear regs out of interval
for var, reg in list(rtor.items()):
available_regs = s_regs if reg[0] == 's' else x_regs
if var[1] not in 'B' and var not in mem_vars and i > live_range[var][1]:
available_regs.append(rtor.pop(var))
# Assign registers to the variables using live ranges.
allocate_regs([out] + vin)
# Assign temp regs to vin and load them before direct use
for i, v in enumerate([v for v in vin if v.__class__ is not int and v.nm in mem_vars]):
rtor[v.nm] = temp_floats[i] if dtypes.is_float(v[1]) else temp_ints[i]
# ARM64 addressing constraints https://devblogs.microsoft.com/oldnewthing/20220728-00/?p=106912
ins.append(f"mov x15, {mem_vars[v.nm]}")
ins.append(f"ldr {rtor[v.nm]}, [sp, x15]")
if uop == UOps.SPECIAL:
if arg.startswith('data'):
# data args 8 to n are read from the stack
if int(arg[4:]) >= 8:
ins.append(f"ldr x15, [x19, #{(int(arg[4:]) - 8) * 8}]")
ins.append(f"mov {rtor[out.nm]}, x15")
else:
ins.append(f"mov {rtor[out.nm]}, #0")
ins.append(f"loop_{arg}:")
elif uop == UOps.CAST:
if arg == BinaryOps.CMPLT:
mov_imm(0.0, 's0')
mov_imm(1.0, 's1')
ins.append(f"fcsel {rtor[out.nm]}, s1, s0, lt")
else:
ins.append(f"sxtw {rtor[out.nm]}, w{rtor[vin[0].nm][1:]}")
elif uop == UOps.ALU:
if len(vin)==2 and vin[1].__class__ is int: mov_imm(vin[1], 'x15')
if arg == BinaryOps.MUL and out.dtype == dtypes.bool:
ins.append(f"ands {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}")
elif arg == TernaryOps.WHERE:
ins.append(f"fcmp {rtor[vin[0].nm]}, #0.0")
ins.append(f"{alu[arg]} {rtor[out.nm]}, {rtor[vin[1].nm]}, {rtor[vin[2].nm]}, ne")
elif arg in [UnaryOps.LOG2, UnaryOps.SIN, UnaryOps.EXP2, UnaryOps.SQRT]:
#NOTE: Not a real instruction, used to emulate an ext call in unicorn
if CI: ins.append(f"{alu[arg]} {rtor[out.nm]} {rtor[vin[0].nm]}")
else:
save_regs = [k for k in rtor.keys() if k != out.nm and k not in mem_vars]
ins.append(f"sub sp, sp, #{(len(save_regs))*16}")
# Save the registers before they are cleared by func call
for i,k in enumerate(save_regs,1):
ins.append(f"str {rtor[k]}, [sp, #{16*i}]")
ins.append("stp x29, x30, [sp, #0]!")
ins.append("mov x29, sp")
ins.append(f"fmov s0, {rtor[vin[0].nm]}")
ins.append(alu[arg])
ins.append(f"fmov {rtor[out.nm]}, s0")
ins.append("mov sp, x29")
ins.append("ldp x29, x30, [sp], #0")
for i,k in enumerate(save_regs,1):
ins.append(f"ldr {rtor[k]}, [sp, #{16*i}]")
ins.append(f"add sp, sp, #{len(save_regs)*16}")
elif arg == BinaryOps.CMPLT:
ins.append(f"{alu[arg]} {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}" if not dtypes.is_float(vin[0][1]) else f"fcmp {rtor[vin[0].nm]}, {rtor[vin[1].nm]}")
elif arg == BinaryOps.MOD:
ins.append(f"udiv x14, {rtor[vin[0].nm]}, x15")
ins.append(f"msub {rtor[out.nm]}, x14, x15, {rtor[vin[0].nm]}")
else:
ins.append(f"{'f' if dtypes.is_float(vin[0][1]) else 's' if arg == BinaryOps.DIV else ''}{alu[arg]} {', '.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}")
elif uop == UOps.LOAD:
if arg.__class__ in (int, float):
mov_imm(arg, rtor[out.nm])
else:
#NOTE: if casting is needed, load the var into s/h0 or x/w12 temp regs
reg_in = type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[out.nm]
mov_imm(arg[0], "x15")
ins.append(f"add x15, {rtor[vin[0].nm]}, x15")
ins.append(f"ldr{'sb' if arg[2] is not None and arg[2] in (dtypes.int8, dtypes.uint8, dtypes.bool) else ''} {reg_in}, [x15]")
if arg[2] is not None: ins.append(f"{'fcvt' if arg[2] == dtypes.half else 'scvtf'} {rtor[out.nm]}, {reg_in}")
elif uop == UOps.STORE:
shifts = {dtypes.int64: "#3", dtypes.half: "#1", dtypes.int8:"#2", dtypes.uint8: "#2", dtypes.bool: "#2"}
#NOTE: if casting is needed, load the var into s/h0 or x/w12 temp regs
reg_out = (type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[vin[1].nm])
if arg[2] is not None: ins.append(f"fcvt{'zs' if arg[2] != dtypes.half else '' } {reg_out}, {rtor[vin[1].nm]}")
ins.append(f"mov x15, #{arg[0]}")
ins.append(f"str {reg_out}, [{rtor[vin[0].nm]}, x15, lsl {shifts[arg[2]] if arg[2] is not None and arg[2] in shifts else '#0'}]")
elif uop == UOps.COND_BRANCH:
#TODO: this is a hack; shouldn't there always be a cmp before a cond branch?
if prev_uop == UOps.LOAD:
ins.append(f"cmp {rtor[vin[0].nm]}, #0")
ins.append(f"b.{'lt' if arg[1] else 'ge'} {arg[0][1:]}")
elif uop == UOps.LABEL:
ins.append(f"{arg[1:]}:")
elif uop == UOps.ENDLOOP:
mov_imm(arg[0], "x15")
ins.append(f"add {rtor[vin[0].nm]}, {rtor[vin[0].nm]}, #1")
ins.append(f"cmp {rtor[vin[0].nm]}, x15")
ins.append(f"b.lt loop_{arg[1]}")
prev_uop=uop
# store regs into memory if needed
if out is not None and out.nm in mem_vars:
ins.append(f"mov x15, {mem_vars[out.nm]}")
ins.append(f"str {rtor[out.nm]}, [sp, x15]")
return "\n".join([f"//varsize {var_size}",".arch armv8-a",".text", f".global {get_name(fn_nm)}",".p2align 2", f"{get_name(fn_nm)}:", "mov x19, sp"] + [f"sub sp, sp, #{offset}" for offset in compute_offsets(var_size)]+ ins + [f"add sp, sp, #{offset}" for offset in compute_offsets(var_size)] +["ret", "\n"])
def uops_to_arm64_asm(fn_nm:str, uops:List[UOp]) -> Tuple[str, List[int], List[int], bool]:
lang = ARM64Language()
global_size, local_size = uops_to_asmstyle(lang, fn_nm, uops)
return specialize_to_arm64(fn_nm, lang.ins), global_size[::-1], local_size[::-1], True
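Two helpers at the top of this file deserve a worked example: compute_offsets splits a stack adjustment into 4096-byte chunks (ARM64 add/sub sp immediates cannot encode arbitrarily large values), and float_to_hex returns the IEEE-754 bits of a float as big-endian hex. The values below were checked by hand against the definitions above:

compute_offsets(10000)  # -> [4096, 4096, 1808]  since divmod(10000, 4096) == (2, 1808)
compute_offsets(8192)   # -> [4096, 4096]        no remainder entry when it divides evenly
float_to_hex(1.0)       # -> '3F800000'
float_to_hex(-2.0)      # -> 'C0000000'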


@@ -24,6 +24,7 @@ setup(name='tinygrad',
extras_require={
'llvm': ["llvmlite"],
'cuda': ["pycuda"],
'arm': ["unicorn"],
'triton': ["triton>=2.0.0.dev20221202"],
'webgpu': ["wgpu"],
'metal': ["pyobjc-framework-Metal", "pyobjc-framework-Cocoa", "pyobjc-framework-libdispatch"],


@@ -11,7 +11,7 @@ from tinygrad.helpers import dtypes, prod
from tinygrad.runtime.lib import RawBuffer
class FakeProgram:
def __init__(self, name:str, prg:str): pass
def __init__(self, name:str, prg:str, binary:bool): pass
def __call__(self, global_size, local_size, *bufs, wait=False): pass
class RawFakeBuffer(RawBuffer):


@@ -1,14 +1,15 @@
import unittest, math
import numpy as np
from tinygrad.helpers import dtypes
from tinygrad.helpers import dtypes, getenv
from tinygrad.tensor import Device
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ASTRunner, Compiled
from tinygrad.codegen.linearizer import UOps, Token, ConstOp, MemOp
from tinygrad.shape.symbolic import Variable
def _uops_to_prg(uops):
src, global_size, local_size = Device[Device.DEFAULT].renderer("test", uops)
return ASTRunner("test", src, global_size, local_size).build(Device[Device.DEFAULT].runtime)
ret = Device[Device.DEFAULT].renderer("test", uops)
src, global_size, local_size, binary = ret if len(ret) == 4 else ret + (False,)
return ASTRunner("test", src, global_size, local_size, runtime_args={"binary": binary}).build(Device[Device.DEFAULT].runtime)
def _test_single_value(tc, tt, vals, op):
uops = [


@@ -156,10 +156,12 @@ class Compiled:
def to_program(self, k):
k.linearize()
src, global_size, local_size = self.renderer(k.function_name, k.uops)
ret = self.renderer(k.function_name, k.uops)
src, global_size, local_size, binary = ret if len(ret) == 4 else ret + (False,)
#TODO: I need to find a better way to select ARM64
return ASTRunner(k.function_name, src, global_size, local_size,
op_estimate=k.info.flops, mem_estimate=k.mem_estimate,
display_name=k.display_name).build(self.runtime)
display_name=k.display_name, runtime_args={"binary": binary}).build(self.runtime)
def exec_ast(self, ast:LazyOp, output, **kwargs):
# all movementops do nothing in a Compiled buffer!
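The three-versus-four-tuple shim here mirrors the one in _uops_to_prg above: it keeps old renderers that return (src, global_size, local_size) working while letting the ARM64 renderer append a binary flag. Its behavior in isolation, with illustrative inputs:

def unpack_render(ret):
    # renderers return (src, global_size, local_size), optionally plus a binary flag
    return ret if len(ret) == 4 else ret + (False,)

unpack_render(("void k(){}", [1], [1]))       # -> ('void k(){}', [1], [1], False)  source path
unpack_render((".text ...", [1], [1], True))  # -> ('.text ...', [1], [1], True)    prebuilt binary path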


@@ -1,8 +1,15 @@
import os, time, ctypes, hashlib, subprocess, platform, tempfile, functools
from functools import partial, reduce
from tinygrad.ops import Compiled
from tinygrad.helpers import fromimport, getenv, DEBUG, CI
from tinygrad.runtime.lib import RawMallocBuffer
from tinygrad.codegen.linearizer import LinearizerOptions
from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage
import struct
import numpy as np
ARM64 = getenv('ARM64', False)
if CI and ARM64: from unicorn import Uc, UC_ARCH_ARM64, UC_MODE_ARM, UC_HOOK_CODE, arm64_const
args = {
'Windows': {'cflags':'', 'ext':'dll', 'exp':'__declspec(dllexport)'},
@ -11,24 +18,64 @@ args = {
}[platform.system()]
CLANG_PROGRAM_HEADER = '#include <math.h>\n#define max(x,y) ((x>y)?x:y)\n#define int64 long\n#define half __fp16\n#define uchar unsigned char\n#define bool uchar\n'
ADDRESS = 0x10000
# Unicorn doesn't support external calls
def align(addr): return (addr+4095) & ~(4095)
mock_lm = {"sinf": np.sin, "sqrtf": np.sqrt, "exp2f": np.exp2, "log2f": np.log2}
def emulate_ext_calls(fn, uc, address, size, user_data):
s_in = struct.unpack('f', struct.pack('I', uc.reg_read(getattr(arm64_const, f'UC_ARM64_REG_S{fn[2][1:]}'))))[0]
uc.reg_write(getattr(arm64_const, f'UC_ARM64_REG_S{fn[1][1:]}'), struct.unpack('I', struct.pack('f', mock_lm[fn[0]](s_in)))[0]) # type: ignore
class ClangProgram:
def __init__(self, name:str, prg:str):
prg = CLANG_PROGRAM_HEADER + prg
def __init__(self, name:str, prg:str, binary:bool=False):
# TODO: is there a way to not write this to disk?
# A: it seems there isn't https://stackoverflow.com/questions/28053328/ctypes-cdll-load-library-from-memory-rather-than-file
# because ctypes.CDLL() calls dlopen (POSIX) or LoadLibrary (Windows) which require a file
fn = f"{tempfile.gettempdir()}/clang_{hashlib.md5(prg.encode('utf-8')).hexdigest()}.{args['ext']}"
if not os.path.exists(fn):
tmp = f"{fn}.{os.getpid()}.tmp"
subprocess.check_output(args=('clang -shared -O2 -Wall -Werror -x c '+args['cflags']+' - -o '+tmp).split(), input=prg.encode('utf-8'))
os.rename(tmp, fn)
if not binary:
prg = CLANG_PROGRAM_HEADER + prg
subprocess.check_output(args=('clang -shared -O2 -Wall -Werror -x c '+args['cflags']+' - -o '+tmp).split(), input=prg.encode('utf-8'))
os.rename(tmp, fn)
else:
if DEBUG >= 5: print(prg)
if CI and ARM64:
prg = prg.split('\n') # type: ignore
self.varsize = align(int(prg[0].split(" ")[1]))
self.ext_calls = {(i*4+ADDRESS):ins.split(" ")[1:] for i, ins in enumerate(filter(lambda ins: ins[:4] != 'loop', prg[6:-3])) if ins[:2] == 'bl'}
prg = "\n".join(['nop' if ins[:2] == 'bl' else ins for ins in prg[6:-3]] + ['\n'])
subprocess.check_output(args=('aarch64-linux-gnu-as -o '+tmp).split(), input=prg.encode('utf-8'))
subprocess.check_output(args=('aarch64-linux-gnu-objcopy -O binary --only-section=.text '+tmp+' '+fn+'.bin').split())
self.prg = open(fn + '.bin', 'rb').read()
return
subprocess.check_output(args=('as -o' + tmp).split(), input=prg.encode('utf-8'))
subprocess.check_output(args=('clang -lm -shared '+tmp+' -o'+fn).split())
self.lib = ctypes.CDLL(fn)
self.fxn = self.lib[name]
def __call__(self, global_size, local_size, *args, wait=False):
if wait: st = time.monotonic()
self.fxn(*[x._buf for x in args])
if CI and ARM64:
mu = Uc(UC_ARCH_ARM64, UC_MODE_ARM)
total_mem = align(reduce(lambda total, arg: total + arg.size * arg.dtype.itemsize, args, len(self.prg)+self.varsize))
mu.mem_map(ADDRESS, total_mem)
for k, fn in self.ext_calls.items(): mu.hook_add(UC_HOOK_CODE, partial(emulate_ext_calls, fn), begin=k, end=k)
mu.mem_write(ADDRESS, self.prg + b''.join(bytes(arg._buf) for arg in args))
addr = ADDRESS + len(self.prg)
for i, arg in enumerate(args):
if i<=7:
mu.reg_write(getattr(arm64_const, f'UC_ARM64_REG_X{i}'), addr)
else:
# NOTE: In ARM, args beyond the first 8 are placed on the stack; this also accounts for the stack red zone.
mu.mem_write(ADDRESS + total_mem - (len(args[8:])+2)*8 + 8*(i-8), addr.to_bytes(8, 'little'))
addr += arg.size * arg.dtype.itemsize
mu.reg_write(arm64_const.UC_ARM64_REG_SP, ADDRESS + total_mem - (len(args[8:])+2)*8)
mu.emu_start(ADDRESS, ADDRESS + len(self.prg))
args[0]._buf = mu.mem_read(mu.reg_read(arm64_const.UC_ARM64_REG_X0), args[0].size * args[0].dtype.itemsize)
else:
self.fxn(*[x._buf for x in args])
if wait: return time.monotonic()-st
renderer = functools.partial(uops_to_cstyle, CStyleLanguage(kernel_prefix=args['exp'], buffer_suffix=" restrict"))
renderer = fromimport("extra.assembly.assembly_arm64", "uops_to_arm64_asm") if ARM64 else functools.partial(uops_to_cstyle, CStyleLanguage(kernel_prefix=args['exp'], buffer_suffix=" restrict"))
ClangBuffer = Compiled(RawMallocBuffer, LinearizerOptions(supports_float4=False, has_local=False), renderer, ClangProgram)
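The CI branch of __call__ above boils down to the standard unicorn pattern: map a region, write the raw .text bytes and buffers, point registers at them, then emu_start. A self-contained sketch of that pattern, independent of tinygrad (the instruction bytes are the hand-assembled encoding of add x0, x0, #1):

from unicorn import Uc, UC_ARCH_ARM64, UC_MODE_ARM
from unicorn.arm64_const import UC_ARM64_REG_X0

ADDRESS = 0x10000
code = b'\x00\x04\x00\x91'           # add x0, x0, #1, little-endian
mu = Uc(UC_ARCH_ARM64, UC_MODE_ARM)
mu.mem_map(ADDRESS, 4096)            # mappings must be page-aligned, cf. align() above
mu.mem_write(ADDRESS, code)
mu.reg_write(UC_ARM64_REG_X0, 41)
mu.emu_start(ADDRESS, ADDRESS + len(code))
print(mu.reg_read(UC_ARM64_REG_X0))  # 42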


@@ -37,7 +37,7 @@ def unwrap(x):
return ret
class MetalProgram:
def __init__(self, name:str, prg:str):
def __init__(self, name:str, prg:str, binary:bool=False):
if METAL_XCODE:
air = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metal', '-x', 'metal', '-c', '-', '-o', '-'], input=prg.encode('utf-8'))
# NOTE: if you run llvm-dis on "air" you can see the llvm bytecode