Arm (#1421)
* testing new memops
* better debugging
* testing padded conv
* branching with load
* refactoring a bit
* first try
* fixing bugs
* fixing some
* eq
* eq2
* do not use x's
* working
* fixing imm
* getting things working
* refactor
* pow not working
* working except one
* refactor: one store mem
* refactor: global load
* refactor: imm
* refactor: cleaning
* fixing big offsets
* refactor with ci
* try ci
* typo
* another typo
* ubuntu default
* forgot git
* do i need git?
* missing packages
* adding python-dev
* with cache?
* buildx action
* buildx name issue?
* maybe now?
* python3
* newline warning
* maybe now
* i actually need this
* ci should work now
* improved caching
* fixing cache
* maybe now it will cache
* this
* testing cache
* trying again
* load
* missing platform
* caching gha
* testing cache
* full testing
* typo
* now?
* why
* adding checkout back
* bad formatting
* fixing convention issues
* supporting python
* adding CI flag
* testing all
* better comments
* adding debugging
* takes 12x longer
* does it output progress now?
* ignore models for speed
* fixing merge
* excluding conv_transpose2d
* only 2 tests because it is too slow
* another approach
* let's see
* faster duh
* my bad
* T_T
* typo
* sup
* with output?
* comment test
* comment test
* comment test
* :?
* no comment
* with cache
* back to normal
* testing that ci works
* back to passing
* trying again
* does it create another entry
* does it create another entry?
* build local
* hey
* Revert "excluding conv_transpose2d"
This reverts commit cc7348de03.
* does it cache if done before?
* does it cache?
* done
* adding test ops
* bad formatting
* no need for this
* working static mem
* sum 1d
* add ndim
* better reg import
* fix stack
* back to np
* working except for softmax
* 5 failing
* no progress
* remove keystone
* remove keystone
* testops passing
* cleanups
* more cleanup
* typo
* ci
* ci2
* cond import
* ci3
* ci4
* ci4
* ci5
* ci5
* ci6
* alignment
* test all
* correct test
* err read_unmapped
* passing test
* ignore for speed
* ignore for speed
* ci7
* cleanup
* remove docker
* fixing merge
* fixing bugs
* add skipload for const ops
* comments
* First merge to master: Renderer
* fix emulation
* passing all tests arm64
* cleaning
* fix handcoded binary
* cleaning
* fix errs
* fix runtime arg binary
* clean git diff
* fix and clean
* fixing metal test
* cleaning
* fix metal test
* ci ~8 min
* fix pylint and clang
* cache the files in ops_clang
---------
Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
pull/1523/head^2
parent
a89142e46f
commit
93a36c3659
|
@ -267,3 +267,28 @@ jobs:
|
||||||
- name: Run pytest (cuda)
|
- name: Run pytest (cuda)
|
||||||
if: matrix.backend=='cuda'
|
if: matrix.backend=='cuda'
|
||||||
run: python -m pytest -n=auto test/ -k 'not (half or test_efficientnet_safetensors) and not (test_conv2d and test_tensor.py)' -m 'not exclude_cuda' --ignore=test/external --ignore=test/models
|
run: python -m pytest -n=auto test/ -k 'not (half or test_efficientnet_safetensors) and not (test_conv2d and test_tensor.py)' -m 'not exclude_cuda' --ignore=test/external --ignore=test/models
|
||||||
|
|
||||||
|
testunicorn:
|
||||||
|
name: ARM64 unicorn Test
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
timeout-minutes: 20
|
||||||
|
steps:
|
||||||
|
- name: Checkout Code
|
||||||
|
uses: actions/checkout@v3
|
||||||
|
- name: Set up Python 3.8
|
||||||
|
uses: actions/setup-python@v4
|
||||||
|
with:
|
||||||
|
python-version: 3.8
|
||||||
|
- name: Cache pip
|
||||||
|
uses: actions/cache@v3
|
||||||
|
with:
|
||||||
|
path: '~/.cache/pip'
|
||||||
|
key: unicorn
|
||||||
|
- name: Install cross-assembler
|
||||||
|
run: |
|
||||||
|
sudo apt-get update -y && \
|
||||||
|
sudo apt-get install -y --no-install-recommends gcc-aarch64-linux-gnu
|
||||||
|
- name: Install dependencies
|
||||||
|
run: pip install -e '.[testing,arm]' --extra-index-url https://download.pytorch.org/whl/cpu
|
||||||
|
- name: Test arm
|
||||||
|
run: CI=1 ARM64=1 CLANG=1 python -m pytest -n=auto test/ -k 'not (test_nn.py and (test_conv_transpose2d or test_conv2d))' --ignore=test/models --ignore=test/test_speed_v_torch.py --ignore=test/test_net_speed.py --ignore=test/test_specific_conv.py --ignore=test/unit/test_disk_tensor.py
|
|
@ -1,7 +1,7 @@
|
||||||
from typing import Tuple, List, NamedTuple, Any, Dict, Optional, Union, DefaultDict
|
from typing import Tuple, List, NamedTuple, Any, Dict, Optional, Union, DefaultDict
|
||||||
from tinygrad.codegen.linearizer import Linearizer, UOps, Token
|
from tinygrad.codegen.linearizer import Linearizer, UOps, Token, ConstOp, MemOp, UOp
|
||||||
from tinygrad.ops import ASTRunner, BinaryOps, UnaryOps
|
from tinygrad.ops import BinaryOps, UnaryOps
|
||||||
from tinygrad.helpers import DType, dtypes, DEBUG
|
from tinygrad.helpers import DType, dtypes, DEBUG, getenv
|
||||||
from tinygrad.shape.symbolic import Variable, NumNode, MulNode, DivNode, ModNode, LtNode, SumNode, AndNode
|
from tinygrad.shape.symbolic import Variable, NumNode, MulNode, DivNode, ModNode, LtNode, SumNode, AndNode
|
||||||
import functools
|
import functools
|
||||||
import math
|
import math
|
||||||
|
@ -20,6 +20,7 @@ class Register(NamedTuple):
|
||||||
if self.dtype == dtypes._float4:
|
if self.dtype == dtypes._float4:
|
||||||
return [Register(self.nm, dtypes.float, False, off=off) for off in range(4)]
|
return [Register(self.nm, dtypes.float, False, off=off) for off in range(4)]
|
||||||
return []
|
return []
|
||||||
|
|
||||||
class AssemblyInstruction(NamedTuple):
|
class AssemblyInstruction(NamedTuple):
|
||||||
op: UOps
|
op: UOps
|
||||||
out: Optional[Register]
|
out: Optional[Register]
|
||||||
|
@ -27,156 +28,161 @@ class AssemblyInstruction(NamedTuple):
|
||||||
arg: Any = None
|
arg: Any = None
|
||||||
|
|
||||||
# warp size of 32, s registers are shared across the warp, v are 32-wide vectors
|
# warp size of 32, s registers are shared across the warp, v are 32-wide vectors
|
||||||
class AssemblyCodegen(Linearizer):
|
class AssemblyLanguage(NamedTuple):
|
||||||
supports_load3: bool = False
|
supports_load3: bool = False
|
||||||
sin_is_sin2pi: bool = False
|
sin_is_sin2pi: bool = False
|
||||||
no_div: bool = False
|
no_div: bool = False
|
||||||
|
#TODO: these should be global vars
|
||||||
|
cnts:DefaultDict[Tuple[DType, bool], int] = defaultdict(int)
|
||||||
|
tor: Dict[Any, Register] = {}
|
||||||
|
ins = []
|
||||||
|
|
||||||
def specialize(self, asm:List[AssemblyInstruction]) -> Tuple[str, str]:
|
def newreg(self, tok, dtype=dtypes.float32, scalar=False):
|
||||||
raise NotImplementedError("must be implemented")
|
if isinstance(tok, Token): dtype = tok.dtype # this
|
||||||
|
self.tor[tok] = ret = Register(f"%{type_to_letter((dtype, scalar))}{self.cnts[(dtype, scalar)]}", dtype, scalar)
|
||||||
|
if dtype == dtypes._float4:
|
||||||
|
for off in range(4):
|
||||||
|
self.tor[Token(tok.name, tok.dtype, off)] = Register(ret.nm, dtypes.float, ret.scalar, off)
|
||||||
|
self.cnts[(dtype, scalar)] += 1
|
||||||
|
return ret
|
||||||
|
|
||||||
# s registers are the addresses and non local indexes
|
def render_numnode(self, b):
|
||||||
def codegen(self):
|
key = ("num", b)
|
||||||
self.process()
|
if key not in self.tor: self.ins.append(AssemblyInstruction(UOps.LOAD, self.newreg(key, scalar=True, dtype=dtypes.int32), [], b))
|
||||||
self.hand_coded_optimizations()
|
return self.tor[key]
|
||||||
self.limit_global_dims(3) # all GPU asms have 3 (for now)
|
|
||||||
self.linearize()
|
|
||||||
|
|
||||||
cnts:DefaultDict[Tuple[DType, bool], int] = defaultdict(int)
|
def render_alu(self, op, a:Register, b:Union[Register, int, float], dtype=dtypes.int32) -> Register:
|
||||||
tor: Dict[Any, Register] = {}
|
key = (op, a, b)
|
||||||
def newreg(tok, dtype=dtypes.float32, scalar=False):
|
if key not in self.tor:
|
||||||
nonlocal cnts, tor
|
#if not isinstance(b, Register): b = render_numnode(b)
|
||||||
if isinstance(tok, Token): dtype = tok.dtype # this
|
self.ins.append(AssemblyInstruction(UOps.ALU, self.newreg(key, dtype=dtype, scalar=a.scalar and (not isinstance(b, Register) or b.scalar)), [a, b], op))
|
||||||
tor[tok] = ret = Register(f"%{type_to_letter((dtype, scalar))}{cnts[(dtype, scalar)]}", dtype, scalar)
|
return self.tor[key]
|
||||||
if dtype == dtypes._float4:
|
|
||||||
for off in range(4):
|
|
||||||
tor[Token(tok.name, tok.dtype, off)] = Register(ret.nm, dtypes.float, ret.scalar, off)
|
|
||||||
cnts[(dtype, scalar)] += 1
|
|
||||||
return ret
|
|
||||||
|
|
||||||
def render_numnode(b):
|
def render_cast(self, a:Register, new_dtype:DType) -> Register:
|
||||||
key = ("num", b)
|
if a.dtype == new_dtype: return a
|
||||||
if key not in tor: ins.append(AssemblyInstruction(UOps.CONST, newreg(key, scalar=True, dtype=dtypes.int32), [], b))
|
key = (a, new_dtype)
|
||||||
return tor[key]
|
if key not in self.tor:
|
||||||
|
self.ins.append(AssemblyInstruction(UOps.CAST, self.newreg(key, dtype=new_dtype), [a]))
|
||||||
|
return self.tor[key]
|
||||||
|
|
||||||
def render_alu(op, a:Register, b:Union[Register, int, float], dtype=dtypes.int32) -> Register:
|
render_ops = { Variable: lambda self, ops, ctx: ctx.tor[self], NumNode: lambda self, ops, ctx: ctx.render_numnode(self.b),
|
||||||
key = (op, a, b)
|
MulNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.MUL, self.a.render(ops, ctx), self.b),
|
||||||
if key not in tor:
|
DivNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.DIV, self.a.render(ops, ctx), self.b),
|
||||||
#if not isinstance(b, Register): b = render_numnode(b)
|
ModNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.MOD, self.a.render(ops, ctx), self.b),
|
||||||
ins.append(AssemblyInstruction(UOps.ALU, newreg(key, dtype=dtype, scalar=a.scalar and (not isinstance(b, Register) or b.scalar)), [a, b], op))
|
LtNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.CMPLT, self.a.render(ops, ctx), self.b, dtype=dtypes.bool),
|
||||||
return tor[key]
|
SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.render_alu(BinaryOps.ADD, a, b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx)),
|
||||||
|
AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.render_alu(BinaryOps.MUL, a, b.render(ops,ctx), dtype=dtypes.bool), self.nodes[1:], self.nodes[0].render(ops,ctx)) }
|
||||||
|
|
||||||
def render_cast(a:Register, new_dtype:DType) -> Register:
|
def addr_w_offset(self, args):
|
||||||
if a.dtype == new_dtype: return a
|
assert isinstance(args, MemOp)
|
||||||
key = (a, new_dtype)
|
idx = args.idx*args.memory_dtype.itemsize
|
||||||
if key not in tor:
|
off = 0 # TODO: should this be None?
|
||||||
ins.append(AssemblyInstruction(UOps.CAST, newreg(key, dtype=new_dtype), [a]))
|
if isinstance(idx, SumNode):
|
||||||
return tor[key]
|
nums = [n.b for n in idx.nodes if isinstance(n, NumNode)]
|
||||||
|
if len(nums) > 0 and nums[0] < 4096 and (idx-nums[0]).min >= 0: # TODO: different for each GPU?
|
||||||
|
idx -= nums[0]
|
||||||
|
off = nums[0]
|
||||||
|
reg = idx.render(self.render_ops, self)
|
||||||
|
if self.supports_load3:
|
||||||
|
if reg.scalar:
|
||||||
|
new_reg = self.newreg((reg.nm, 'vec'), dtype=reg.dtype)
|
||||||
|
self.ins.append(AssemblyInstruction(UOps.ALU, new_reg, [reg], UnaryOps.NOOP))
|
||||||
|
reg = new_reg
|
||||||
|
return self.tor[args.name], reg, off
|
||||||
|
reg = self.render_alu(BinaryOps.ADD, self.render_cast(reg, dtypes.uint64), self.tor[args.name], dtype=dtypes.uint64)
|
||||||
|
return reg, None, off
|
||||||
|
|
||||||
render_ops = { Variable: lambda self, ops, ctx: tor[self], NumNode: lambda self, ops, ctx: render_numnode(self.b),
|
def uops_to_asmstyle(lang, function_name:str, uops:List[UOp]):
|
||||||
MulNode: lambda self, ops, ctx: render_alu(BinaryOps.MUL, self.a.render(ops, ctx), self.b),
|
#TODO: Do not use clear()
|
||||||
DivNode: lambda self, ops, ctx: render_alu(BinaryOps.DIV, self.a.render(ops, ctx), self.b),
|
lang.ins.clear()
|
||||||
ModNode: lambda self, ops, ctx: render_alu(BinaryOps.MOD, self.a.render(ops, ctx), self.b),
|
lang.tor.clear()
|
||||||
LtNode: lambda self, ops, ctx: render_alu(BinaryOps.CMPLT, self.a.render(ops, ctx), self.b, dtype=dtypes.bool),
|
buf_to_dtype = {args[0]:args[1] for uop,_,_,args in uops if uop == UOps.DEFINE_GLOBAL}
|
||||||
SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: render_alu(BinaryOps.ADD, a, b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx)),
|
buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())}
|
||||||
AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: render_alu(BinaryOps.MUL, a, b.render(ops,ctx), dtype=dtypes.bool), self.nodes[1:], self.nodes[0].render(ops,ctx)) }
|
global_size, local_size = [], []
|
||||||
|
skipload_branch = 0
|
||||||
|
lang.ins += [AssemblyInstruction(UOps.SPECIAL, lang.newreg(buf, dtype=dtypes.uint64, scalar=True), [], buf) for buf in buf_to_dtype]
|
||||||
|
for uop,newvar,vin,args in uops:
|
||||||
|
if uop == UOps.DEFINE_LOCAL:
|
||||||
|
lang.ins.append(AssemblyInstruction(UOps.DEFINE_LOCAL, None, [], args))
|
||||||
|
lang.ins.append(AssemblyInstruction(UOps.ALU, lang.newreg(args[0], dtype=dtypes.uint64), [args[0]], UnaryOps.NOOP))
|
||||||
|
elif uop == UOps.LOOP:
|
||||||
|
if args[1] == "global":
|
||||||
|
for i,var in enumerate(args[0]):
|
||||||
|
global_size.append(var.max+1)
|
||||||
|
lang.ins.append(AssemblyInstruction(UOps.SPECIAL, lang.newreg(var, dtype=dtypes.int32), [], f"gid{len(args[0])-1-i}"))
|
||||||
|
elif args[1] == "local":
|
||||||
|
for i,var in enumerate(args[0]):
|
||||||
|
local_size.append(var.max+1)
|
||||||
|
lang.ins.append(AssemblyInstruction(UOps.SPECIAL, lang.newreg(var, dtype=dtypes.int32), [], f"lid{len(args[0])-1-i}"))
|
||||||
|
else:
|
||||||
|
for var in args[0]:
|
||||||
|
if not isinstance(var, NumNode): # TODO: why is this coming through?
|
||||||
|
lang.ins.append(AssemblyInstruction(UOps.LOAD, lang.newreg(var, dtype=dtypes.int32, scalar=True), [], 0)) #FIXME: what should valid be here?
|
||||||
|
lang.ins.append(AssemblyInstruction(UOps.LABEL, None, [], "$loop_"+var.expr))
|
||||||
|
elif uop == UOps.ENDLOOP:
|
||||||
|
if args[1] not in ["global", "local", "global+local"]:
|
||||||
|
for var in reversed(args[0]):
|
||||||
|
if not isinstance(var, NumNode): # TODO: why is this coming through?
|
||||||
|
lang.ins.append(AssemblyInstruction(UOps.ALU, lang.tor[var], [lang.tor[var], 1], BinaryOps.ADD))
|
||||||
|
pred = lang.render_alu(BinaryOps.CMPLT, lang.tor[var], var.max+1, dtypes.bool)
|
||||||
|
lang.ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], ("$loop_"+var.expr, True)))
|
||||||
|
elif args[1] == "global+local":
|
||||||
|
for i, var in enumerate(reversed(args[0])):
|
||||||
|
lang.ins.append(AssemblyInstruction(UOps.ENDLOOP, None, [lang.tor[var]], (var.max+1, f"gid{i}")))
|
||||||
|
|
||||||
def addr_w_offset(args):
|
elif uop == UOps.CAST and newvar is not None:
|
||||||
idx = args.idx*self.bufs[args.i].dtype.itemsize
|
# TODO: we should reconsider outputting CAST in the linearizer. these are needless copies
|
||||||
off = 0 # TODO: should this be None?
|
out = lang.newreg(newvar)
|
||||||
if isinstance(idx, SumNode):
|
for i,sr in enumerate(out.subregs()):
|
||||||
nums = [n.b for n in idx.nodes if isinstance(n, NumNode)]
|
lang.ins.append(AssemblyInstruction(UOps.ALU, sr, [lang.tor[vin[i]]], UnaryOps.NOOP))
|
||||||
if len(nums) > 0 and nums[0] < 4096 and (idx-nums[0]).min >= 0: # TODO: different for each GPU?
|
elif uop == UOps.ALU and newvar is not None:
|
||||||
idx -= nums[0]
|
out = lang.newreg(newvar) if newvar not in lang.tor else lang.tor[newvar]
|
||||||
off = nums[0]
|
# this is the only thing that can violate SSA
|
||||||
reg = idx.render(render_ops)
|
if args in [BinaryOps.CMPLT]:
|
||||||
if self.supports_load3:
|
pred_reg = lang.newreg((newvar, 'pred'), dtype=dtypes.bool)
|
||||||
if reg.scalar:
|
lang.ins.append(AssemblyInstruction(UOps.ALU, pred_reg, [lang.tor[x] for x in vin], args))
|
||||||
new_reg = newreg((reg.nm, 'vec'), dtype=reg.dtype)
|
lang.ins.append(AssemblyInstruction(UOps.CAST, out, [pred_reg], args))
|
||||||
ins.append(AssemblyInstruction(UOps.ALU, new_reg, [reg], UnaryOps.NOOP))
|
elif args == BinaryOps.DIV and lang.no_div:
|
||||||
reg = new_reg
|
tmp = lang.newreg((newvar, "rcp"))
|
||||||
return tor[f"buf{args.i}"], reg, off
|
lang.ins.append(AssemblyInstruction(UOps.ALU, tmp, [lang.tor[vin[1]]], UnaryOps.RECIP))
|
||||||
reg = render_alu(BinaryOps.ADD, render_cast(reg, dtypes.uint64), tor[f"buf{args.i}"], dtype=dtypes.uint64)
|
lang.ins.append(AssemblyInstruction(UOps.ALU, out, [lang.tor[vin[0]], tmp], BinaryOps.MUL))
|
||||||
return reg, None, off
|
elif args == UnaryOps.SIN and lang.sin_is_sin2pi:
|
||||||
|
tmp = lang.newreg((newvar, "2pi"))
|
||||||
ins = []
|
lang.ins.append(AssemblyInstruction(UOps.ALU, tmp, [lang.tor[vin[0]], 1/(math.pi*2)], BinaryOps.MUL))
|
||||||
ins += [AssemblyInstruction(UOps.SPECIAL, newreg(f"buf{i}", dtype=dtypes.uint64, scalar=True), [], f"buf{i}") for i in range(len(self.bufs))]
|
lang.ins.append(AssemblyInstruction(UOps.ALU, out, [tmp], args))
|
||||||
global_size, local_size = [], []
|
else:
|
||||||
skipload_branch = 0
|
lang.ins.append(AssemblyInstruction(UOps.ALU, out, [lang.tor[x] for x in vin], args))
|
||||||
for uop,newvar,vin,args in self.uops:
|
elif uop == UOps.LOAD and newvar is not None:
|
||||||
if uop == UOps.CONST and newvar is not None:
|
if isinstance(args, ConstOp):
|
||||||
ins.append(AssemblyInstruction(UOps.CONST, newreg(newvar, dtype=newvar.dtype), [], args))
|
|
||||||
elif uop == UOps.DEFINE_LOCAL:
|
|
||||||
ins.append(AssemblyInstruction(UOps.DEFINE_LOCAL, None, [], args))
|
|
||||||
ins.append(AssemblyInstruction(UOps.ALU, newreg("buf-1", dtype=dtypes.uint64), [args[0]], UnaryOps.NOOP))
|
|
||||||
elif uop == UOps.LOOP:
|
|
||||||
if args[1] == "global":
|
|
||||||
for i,var in enumerate(args[0]):
|
|
||||||
global_size.append(var.max+1)
|
|
||||||
ins.append(AssemblyInstruction(UOps.SPECIAL, newreg(var, dtype=dtypes.int32), [], f"gid{len(args[0])-1-i}"))
|
|
||||||
elif args[1] == "local":
|
|
||||||
for i,var in enumerate(args[0]):
|
|
||||||
local_size.append(var.max+1)
|
|
||||||
ins.append(AssemblyInstruction(UOps.SPECIAL, newreg(var, dtype=dtypes.int32), [], f"lid{len(args[0])-1-i}"))
|
|
||||||
else:
|
|
||||||
for var in args[0]:
|
|
||||||
if not isinstance(var, NumNode): # TODO: why is this coming through?
|
|
||||||
ins.append(AssemblyInstruction(UOps.CONST, newreg(var, dtype=dtypes.int32, scalar=True), [], 0))
|
|
||||||
ins.append(AssemblyInstruction(UOps.LABEL, None, [], "$loop_"+var.expr))
|
|
||||||
elif uop == UOps.ENDLOOP:
|
|
||||||
if args[1] not in ["global", "local", "global+local"]:
|
|
||||||
for var in reversed(args[0]):
|
|
||||||
if not isinstance(var, NumNode): # TODO: why is this coming through?
|
|
||||||
ins.append(AssemblyInstruction(UOps.ALU, tor[var], [tor[var], 1], BinaryOps.ADD))
|
|
||||||
pred = render_alu(BinaryOps.CMPLT, tor[var], var.max+1, dtypes.bool)
|
|
||||||
ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], ("$loop_"+var.expr, True)))
|
|
||||||
elif uop == UOps.CAST and newvar is not None:
|
|
||||||
# TODO: we should reconsider outputting CAST in the linearizer. these are needless copies
|
|
||||||
out = newreg(newvar)
|
|
||||||
for i,sr in enumerate(out.subregs()):
|
|
||||||
ins.append(AssemblyInstruction(UOps.ALU, sr, [tor[vin[i]]], UnaryOps.NOOP))
|
|
||||||
elif uop == UOps.ALU and newvar is not None:
|
|
||||||
out = newreg(newvar) if newvar not in tor else tor[newvar]
|
|
||||||
# this is the only thing that can violate SSA
|
|
||||||
if args in [BinaryOps.CMPLT]:
|
|
||||||
pred_reg = newreg((newvar, 'pred'), dtype=dtypes.bool)
|
|
||||||
ins.append(AssemblyInstruction(UOps.ALU, pred_reg, [tor[x] for x in vin], args))
|
|
||||||
ins.append(AssemblyInstruction(UOps.CAST, out, [pred_reg], args))
|
|
||||||
elif args == BinaryOps.DIV and self.no_div:
|
|
||||||
tmp = newreg((newvar, "rcp"))
|
|
||||||
ins.append(AssemblyInstruction(UOps.ALU, tmp, [tor[vin[1]]], UnaryOps.RECIP))
|
|
||||||
ins.append(AssemblyInstruction(UOps.ALU, out, [tor[vin[0]], tmp], BinaryOps.MUL))
|
|
||||||
elif args == UnaryOps.SIN and self.sin_is_sin2pi:
|
|
||||||
tmp = newreg((newvar, "2pi"))
|
|
||||||
ins.append(AssemblyInstruction(UOps.ALU, tmp, [tor[vin[0]], 1/(math.pi*2)], BinaryOps.MUL))
|
|
||||||
ins.append(AssemblyInstruction(UOps.ALU, out, [tmp], args))
|
|
||||||
else:
|
|
||||||
ins.append(AssemblyInstruction(UOps.ALU, out, [tor[x] for x in vin], args))
|
|
||||||
elif uop == UOps.LOAD and newvar is not None:
|
|
||||||
idx, treg, off = addr_w_offset(args)
|
|
||||||
reg = newreg(newvar, dtype=newvar.dtype, scalar=(idx.scalar and (not isinstance(treg, Register) or treg.scalar))) # and not dtypes.is_float(newvar.dtype)))
|
|
||||||
if args.valid.min == 0:
|
|
||||||
ins.append(AssemblyInstruction(UOps.CONST, reg, [], 0))
|
|
||||||
if args.valid.max == 1:
|
|
||||||
pred = args.valid.render(render_ops)
|
|
||||||
ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], (f"$skipload_{skipload_branch}", False)))
|
|
||||||
if args.valid.max == 1:
|
|
||||||
# NOTE: you can't compute the index in here, because it assumes it's all available later
|
|
||||||
ins.append(AssemblyInstruction(UOps.LOAD, reg, [idx] + ([treg] if treg is not None else []), (off, 'global' if args.i != -1 else 'shared')))
|
|
||||||
if args.valid.min == 0 and args.valid.max == 1:
|
if args.valid.min == 0 and args.valid.max == 1:
|
||||||
ins.append(AssemblyInstruction(UOps.LABEL, None, [], f"$skipload_{skipload_branch}"))
|
lang.ins.append(AssemblyInstruction(UOps.LOAD, lang.newreg(newvar, dtype=newvar.dtype), [], args.invalid_value))
|
||||||
|
pred = args.valid.render(lang.render_ops, lang)
|
||||||
|
lang.ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], (f"$skipload_{skipload_branch}", False)))
|
||||||
|
lang.ins.append(AssemblyInstruction(UOps.LOAD, lang.newreg(newvar, dtype=newvar.dtype), [], args.value))
|
||||||
|
lang.ins.append(AssemblyInstruction(UOps.LABEL, None, [], f"$skipload_{skipload_branch}"))
|
||||||
skipload_branch += 1
|
skipload_branch += 1
|
||||||
elif uop == UOps.STORE:
|
else:
|
||||||
idx, treg, off = addr_w_offset(args)
|
lang.ins.append(AssemblyInstruction(UOps.LOAD, lang.newreg(newvar, dtype=newvar.dtype), [], args.value if args.valid.min == 1 else args.invalid_value))
|
||||||
ins.append(AssemblyInstruction(UOps.STORE, None, [idx, tor[vin[0]]] + ([treg] if treg is not None else []), (off, 'global' if args.i != -1 else 'shared')))
|
else:
|
||||||
|
idx, treg, off = lang.addr_w_offset(args)
|
||||||
|
reg = lang.newreg(newvar, dtype=newvar.dtype, scalar=(idx.scalar and (not isinstance(treg, Register) or treg.scalar))) # and not dtypes.is_float(newvar.dtype)))
|
||||||
|
if args.valid.min == 0:
|
||||||
|
lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [], 0))
|
||||||
|
if args.valid.max == 1:
|
||||||
|
pred = args.valid.render(lang.render_ops, lang)
|
||||||
|
lang.ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], (f"$skipload_{skipload_branch}", False)))
|
||||||
|
if args.valid.max == 1:
|
||||||
|
# NOTE: you can't compute the index in here, because it assumes it's all available later
|
||||||
|
lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [idx] + ([treg] if treg is not None else []), (off, 'global' if buf_index[args.name] != -1 else 'shared', args.memory_dtype if buf_to_dtype[args.name] != dtypes.float else None)))
|
||||||
|
if args.valid.min == 0 and args.valid.max == 1:
|
||||||
|
lang.ins.append(AssemblyInstruction(UOps.LABEL, None, [], f"$skipload_{skipload_branch}"))
|
||||||
|
skipload_branch += 1
|
||||||
|
elif uop == UOps.STORE:
|
||||||
|
idx, treg, off = lang.addr_w_offset(args)
|
||||||
|
lang.ins.append(AssemblyInstruction(UOps.STORE, None, [idx, lang.tor[vin[0]]] + ([treg] if treg is not None else []), (off, 'global' if buf_index[args.name] != -1 else 'shared', args.memory_dtype if buf_to_dtype['data0'] != dtypes.float else None)))
|
||||||
|
# define registers
|
||||||
|
lang.ins = [AssemblyInstruction(UOps.DEFINE_REGISTER, None, [], (dtype, type_to_letter(dtype), c)) for dtype,c in lang.cnts.items()] + lang.ins
|
||||||
|
|
||||||
# define registers
|
if DEBUG >= 4:
|
||||||
ins = [AssemblyInstruction(UOps.DEFINE_REGISTER, None, [], (dtype, type_to_letter(dtype), c)) for dtype,c in cnts.items()] + ins
|
for tins in lang.ins: print(tins)
|
||||||
|
return global_size, local_size
|
||||||
if DEBUG >= 4:
|
|
||||||
for tins in ins: print(tins)
|
|
||||||
name, asm = self.specialize(ins)
|
|
||||||
|
|
||||||
return ASTRunner(name, asm,
|
|
||||||
global_size[::-1], local_size[::-1],
|
|
||||||
op_estimate=self.info.flops, mem_estimate=self.mem_estimate, display_name=self.display_name, runtime_args={"binary": True})
|
|
|
@ -0,0 +1,171 @@
|
||||||
|
import struct
|
||||||
|
from platform import system
|
||||||
|
from extra.assembly.assembly import uops_to_asmstyle, AssemblyLanguage, Register
|
||||||
|
from typing import Tuple, Set, Dict, List
|
||||||
|
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps
|
||||||
|
from tinygrad.codegen.linearizer import UOps, ConstOp, UOp
|
||||||
|
from tinygrad.helpers import dtypes, CI
|
||||||
|
|
||||||
|
def float_to_hex(x):
  """Return the IEEE-754 single-precision encoding of ``x`` as 8 uppercase hex digits.

  The most significant byte comes first, so ``1.0`` becomes ``"3F800000"``:
  the first four characters are the high 16 bits and the last four the low
  16 bits.

  Raises whatever ``struct.pack`` raises when ``x`` cannot be packed as a
  32-bit float.
  """
  # Pack big-endian explicitly instead of reversing a native-order pack, so the
  # digit order does not depend on the byte order of the machine running this.
  return struct.pack(">f", x).hex().upper()
|
||||||
|
def compute_offsets(total):
  """Split ``total`` into a list of steps, each at most 4096.

  For a non-negative integer ``total`` the steps add up to ``total``: as many
  full 4096 steps as fit come first, followed by one smaller step carrying the
  remainder. The trailing step is omitted when the remainder is zero, so
  ``0`` yields an empty list.
  """
  full_steps, leftover = divmod(total, 4096)
  steps = [4096] * full_steps
  if leftover:
    steps.append(leftover)
  return steps
|
||||||
|
|
||||||
|
#NOTE: Darwin needs names to start with a "_"
|
||||||
|
def get_name(name):
  """Return ``name`` with a leading "_" when ``platform.system()`` reports Darwin, unchanged otherwise."""
  prefix = '_' if system() == 'Darwin' else ''
  return prefix + name
|
||||||
|
|
||||||
|
# ARM64 flavour of AssemblyLanguage: the body is empty, so every field and method is inherited unchanged.
class ARM64Language(AssemblyLanguage): pass
|
||||||
|
|
||||||
|
def specialize_to_arm64(fn_nm, asm):
|
||||||
|
var_size = 16
|
||||||
|
prev_uop = None
|
||||||
|
ins = []
|
||||||
|
x_regs = ['x' + str(i) for i in reversed(range(29)) if i not in (10,11,12,13,14,15,16,17,18,19,20)]
|
||||||
|
s_regs = ['s' + str(i) for i in reversed(range(3,30))]
|
||||||
|
type_to_reg = {dtypes.half: 'h', dtypes.float32: 's', dtypes.bool: 'w', dtypes.int8:'w', dtypes.int32: 'w', dtypes.int64: 'x', dtypes.uint8:'w', dtypes.uint32: 'w', dtypes.uint64: 'x'}
|
||||||
|
alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", BinaryOps.DIV: "div", BinaryOps.MAX: "max",
|
||||||
|
BinaryOps.MOD: "", BinaryOps.CMPLT: "subs",
|
||||||
|
UnaryOps.SIN:'bl ' + get_name('sinf'), UnaryOps.LOG2: 'bl ' + get_name("log2f"), UnaryOps.EXP2: 'bl ' + get_name("exp2f"), UnaryOps.SQRT: 'bl ' + get_name("sqrtf"),
|
||||||
|
TernaryOps.MULACC: "madd", TernaryOps.WHERE: "fcsel"}
|
||||||
|
|
||||||
|
def mov_imm(value, reg):
|
||||||
|
# Manually move value into reg if value can't fit
|
||||||
|
if value.__class__ is not float and abs(value) > abs(65535):
|
||||||
|
ins.append(f"movz w15, #{value & 0xffff}")
|
||||||
|
ins.append(f"movk w15, #{(value >> 16) & 0xffff}, lsl #16")
|
||||||
|
ins.append(f"sxtw {reg}, w15")
|
||||||
|
elif reg[0] == 's':
|
||||||
|
ins.append(f"movz x15, 0x{float_to_hex(value)[4:]}")
|
||||||
|
ins.append(f"movk x15, 0x{float_to_hex(value)[:4]}, lsl #16")
|
||||||
|
ins.append(f"str x15, [sp, 16]")
|
||||||
|
ins.append(f"ldr {reg}, [sp, 16]")
|
||||||
|
else:
|
||||||
|
ins.append(f"mov {reg}, #{value}")
|
||||||
|
|
||||||
|
# Get variables intervals
|
||||||
|
live_range:Dict[str, str] = {}
|
||||||
|
for i, (uop, out, vin, arg) in enumerate(asm):
|
||||||
|
for var in ([v for v in [out] + vin if v is not None and v.__class__ is not int]):
|
||||||
|
live_range[var.nm] = [i,i] if var.nm not in live_range else [live_range[var.nm][0], i]
|
||||||
|
|
||||||
|
mem_vars:Dict[str, str] = {}
|
||||||
|
rtor:Dict[str, str] = {}
|
||||||
|
def allocate_regs(vars):
|
||||||
|
nonlocal var_size
|
||||||
|
for v in [v for v in vars if v is not None and v.__class__ is not int and v.nm not in rtor]:
|
||||||
|
available_regs = s_regs if dtypes.is_float(v[1]) else x_regs
|
||||||
|
#NOTE: Very simple spill, everything that don't fit in regs goes to mem
|
||||||
|
if len(available_regs) == 0:
|
||||||
|
# ARM needs the stack 16-byte aligned
|
||||||
|
var_size += 16
|
||||||
|
available_regs.append('s0' if dtypes.is_float(out[1]) else 'x11')
|
||||||
|
mem_vars[v.nm] = var_size
|
||||||
|
rtor[v.nm] = available_regs.pop()
|
||||||
|
|
||||||
|
temp_floats = ['s0', 's1', 's2']
|
||||||
|
temp_ints = ['x11', 'x12', 'x13']
|
||||||
|
for i, (uop, out, vin, arg) in enumerate(asm):
|
||||||
|
# Clear regs out of interval
|
||||||
|
for var, reg in list(rtor.items()):
|
||||||
|
available_regs = s_regs if reg[0] == 's' else x_regs
|
||||||
|
if var[1] not in 'B' and var not in mem_vars and i > live_range[var][1]:
|
||||||
|
available_regs.append(rtor.pop(var))
|
||||||
|
# Assign a registers to the variables using live ranges.
|
||||||
|
allocate_regs([out] + vin)
|
||||||
|
# Assign temp regs to vin and load them before direct use
|
||||||
|
for i, v in enumerate([v for v in vin if v.__class__ is not int and v.nm in mem_vars]):
|
||||||
|
rtor[v.nm] = temp_floats[i] if dtypes.is_float(v[1]) else temp_ints[i]
|
||||||
|
# ARM64 addressing constraints https://devblogs.microsoft.com/oldnewthing/20220728-00/?p=106912
|
||||||
|
ins.append(f"mov x15, {mem_vars[v.nm]}")
|
||||||
|
ins.append(f"ldr {rtor[v.nm]}, [sp, x15]")
|
||||||
|
|
||||||
|
if uop == UOps.SPECIAL:
|
||||||
|
if arg.startswith('data'):
|
||||||
|
# data 8 to n into the stack
|
||||||
|
if int(arg[4:]) >= 8:
|
||||||
|
ins.append(f"ldr x15, [x19, #{(int(arg[4:]) - 8) * 8}]")
|
||||||
|
ins.append(f"mov {rtor[out.nm]}, x15")
|
||||||
|
else:
|
||||||
|
ins.append(f"mov {rtor[out.nm]}, #0")
|
||||||
|
ins.append(f"loop_{arg}:")
|
||||||
|
elif uop == UOps.CAST:
|
||||||
|
if arg == BinaryOps.CMPLT:
|
||||||
|
mov_imm(0.0, 's0')
|
||||||
|
mov_imm(1.0, 's1')
|
||||||
|
ins.append(f"fcsel {rtor[out.nm]}, s1, s0, lt")
|
||||||
|
else:
|
||||||
|
ins.append(f"sxtw {rtor[out.nm]}, w{rtor[vin[0].nm][1:]}")
|
||||||
|
elif uop == UOps.ALU:
|
||||||
|
if len(vin)==2 and vin[1].__class__ is int: mov_imm(vin[1], 'x15')
|
||||||
|
if arg == BinaryOps.MUL and out.dtype == dtypes.bool:
|
||||||
|
ins.append(f"ands {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}")
|
||||||
|
elif arg == TernaryOps.WHERE:
|
||||||
|
ins.append(f"fcmp {rtor[vin[0].nm]}, #0.0")
|
||||||
|
ins.append(f"{alu[arg]} {rtor[out.nm]}, {rtor[vin[1].nm]}, {rtor[vin[2].nm]}, ne")
|
||||||
|
elif arg in [UnaryOps.LOG2, UnaryOps.SIN, UnaryOps.EXP2, UnaryOps.SQRT]:
|
||||||
|
#NOTE: Not a real instruction, use to emulate a ext call in unicorn
|
||||||
|
if CI: ins.append(f"{alu[arg]} {rtor[out.nm]} {rtor[vin[0].nm]}")
|
||||||
|
else:
|
||||||
|
save_regs = [k for k in rtor.keys() if k != out.nm and k not in mem_vars]
|
||||||
|
ins.append(f"sub sp, sp, #{(len(save_regs))*16}")
|
||||||
|
# Save the registers before they are cleared by func call
|
||||||
|
for i,k in enumerate(save_regs,1):
|
||||||
|
ins.append(f"str {rtor[k]}, [sp, #{16*i}]")
|
||||||
|
ins.append("stp x29, x30, [sp, #0]!")
|
||||||
|
ins.append("mov x29, sp")
|
||||||
|
ins.append(f"fmov s0, {rtor[vin[0].nm]}")
|
||||||
|
ins.append(alu[arg])
|
||||||
|
ins.append(f"fmov {rtor[out.nm]}, s0")
|
||||||
|
ins.append("mov sp, x29")
|
||||||
|
ins.append("ldp x29, x30, [sp], #0")
|
||||||
|
for i,k in enumerate(save_regs,1):
|
||||||
|
ins.append(f"ldr {rtor[k]}, [sp, #{16*i}]")
|
||||||
|
ins.append(f"add sp, sp, #{len(save_regs)*16}")
|
||||||
|
elif arg == BinaryOps.CMPLT:
|
||||||
|
ins.append(f"{alu[arg]} {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}" if not dtypes.is_float(vin[0][1]) else f"fcmp {rtor[vin[0].nm]}, {rtor[vin[1].nm]}")
|
||||||
|
elif arg == BinaryOps.MOD:
|
||||||
|
ins.append(f"udiv x14, {rtor[vin[0].nm]}, x15")
|
||||||
|
ins.append(f"msub {rtor[out.nm]}, x14, x15, {rtor[vin[0].nm]}")
|
||||||
|
else:
|
||||||
|
ins.append(f"{'f' if dtypes.is_float(vin[0][1]) else 's' if arg == BinaryOps.DIV else ''}{alu[arg]} {', '.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}")
|
||||||
|
elif uop == UOps.LOAD:
|
||||||
|
if arg.__class__ in (int, float):
|
||||||
|
mov_imm(arg, rtor[out.nm])
|
||||||
|
else:
|
||||||
|
#NOTE: if casting is needed, load the var into the s/h0 or x/w12 temp regs
|
||||||
|
reg_in = type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[out.nm]
|
||||||
|
mov_imm(arg[0], "x15")
|
||||||
|
ins.append(f"add x15, {rtor[vin[0].nm]}, x15")
|
||||||
|
ins.append(f"ldr{'sb' if arg[2] is not None and arg[2] in (dtypes.int8, dtypes.uint8, dtypes.bool) else ''} {reg_in}, [x15]")
|
||||||
|
if arg[2] is not None: ins.append(f"{'fcvt' if arg[2] == dtypes.half else 'scvtf'} {rtor[out.nm]}, {reg_in}")
|
||||||
|
elif uop == UOps.STORE:
|
||||||
|
shifts = {dtypes.int64: "#3", dtypes.half: "#1", dtypes.int8:"#2", dtypes.uint8: "#2", dtypes.bool: "#2"}
|
||||||
|
#NOTE: if need casting load var in s/h0 or x/w12 temp regs
|
||||||
|
reg_out = (type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[vin[1].nm])
|
||||||
|
if arg[2] is not None: ins.append(f"fcvt{'zs' if arg[2] != dtypes.half else '' } {reg_out}, {rtor[vin[1].nm]}")
|
||||||
|
ins.append(f"mov x15, #{arg[0]}")
|
||||||
|
ins.append(f"str {reg_out}, [{rtor[vin[0].nm]}, x15, lsl {shifts[arg[2]] if arg[2] is not None and arg[2] in shifts else '#0'}]")
|
||||||
|
elif uop == UOps.COND_BRANCH:
|
||||||
|
#TODO: this is a hack it shouldn't always be a cmp before a cond branch?
|
||||||
|
if prev_uop == UOps.LOAD:
|
||||||
|
ins.append(f"cmp {rtor[vin[0].nm]}, #0")
|
||||||
|
ins.append(f"b.{'lt' if arg[1] else 'ge'} {arg[0][1:]}")
|
||||||
|
elif uop == UOps.LABEL:
|
||||||
|
ins.append(f"{arg[1:]}:")
|
||||||
|
elif uop == UOps.ENDLOOP:
|
||||||
|
mov_imm(arg[0], "x15")
|
||||||
|
ins.append(f"add {rtor[vin[0].nm]}, {rtor[vin[0].nm]}, #1")
|
||||||
|
ins.append(f"cmp {rtor[vin[0].nm]}, x15")
|
||||||
|
ins.append(f"b.lt loop_{arg[1]}")
|
||||||
|
|
||||||
|
prev_uop=uop
|
||||||
|
# store regs into memory if needed
|
||||||
|
if out is not None and out.nm in mem_vars:
|
||||||
|
ins.append(f"mov x15, {mem_vars[out.nm]}")
|
||||||
|
ins.append(f"str {rtor[out.nm]}, [sp, x15]")
|
||||||
|
return "\n".join([f"//varsize {var_size}",".arch armv8-a",".text", f".global {get_name(fn_nm)}",".p2align 2", f"{get_name(fn_nm)}:", "mov x19, sp"] + [f"sub sp, sp, #{offset}" for offset in compute_offsets(var_size)]+ ins + [f"add sp, sp, #{offset}" for offset in compute_offsets(var_size)] +["ret", "\n"])
|
||||||
|
|
||||||
|
def uops_to_arm64_asm(fn_nm:str, uops:List[UOp]) -> Tuple[str, List[int], List[int], bool]:
  """Render a list of uops as ARM64 assembly source.

  The uops are first lowered through `uops_to_asmstyle` into the instruction
  list held by a fresh `ARM64Language`; that list is then specialized into
  the final assembly text by `specialize_to_arm64`.

  Returns a 4-tuple: (assembly source, reversed global size,
  reversed local size, True).
  """
  arm64_lang = ARM64Language()
  sizes = uops_to_asmstyle(arm64_lang, fn_nm, uops)
  global_size, local_size = sizes
  # Build the assembly text before reversing the sizes, as the instruction
  # list is only complete once uops_to_asmstyle has run.
  asm_src = specialize_to_arm64(fn_nm, arm64_lang.ins)
  return asm_src, global_size[::-1], local_size[::-1], True
|
1
setup.py
1
setup.py
|
@ -24,6 +24,7 @@ setup(name='tinygrad',
|
||||||
extras_require={
|
extras_require={
|
||||||
'llvm': ["llvmlite"],
|
'llvm': ["llvmlite"],
|
||||||
'cuda': ["pycuda"],
|
'cuda': ["pycuda"],
|
||||||
|
'arm': ["unicorn"],
|
||||||
'triton': ["triton>=2.0.0.dev20221202"],
|
'triton': ["triton>=2.0.0.dev20221202"],
|
||||||
'webgpu': ["wgpu"],
|
'webgpu': ["wgpu"],
|
||||||
'metal': ["pyobjc-framework-Metal", "pyobjc-framework-Cocoa", "pyobjc-framework-libdispatch"],
|
'metal': ["pyobjc-framework-Metal", "pyobjc-framework-Cocoa", "pyobjc-framework-libdispatch"],
|
||||||
|
|
|
@ -11,7 +11,7 @@ from tinygrad.helpers import dtypes, prod
|
||||||
from tinygrad.runtime.lib import RawBuffer
|
from tinygrad.runtime.lib import RawBuffer
|
||||||
|
|
||||||
class FakeProgram:
|
class FakeProgram:
|
||||||
def __init__(self, name:str, prg:str): pass
|
def __init__(self, name:str, prg:str, binary:bool): pass
|
||||||
def __call__(self, global_size, local_size, *bufs, wait=False): pass
|
def __call__(self, global_size, local_size, *bufs, wait=False): pass
|
||||||
|
|
||||||
class RawFakeBuffer(RawBuffer):
|
class RawFakeBuffer(RawBuffer):
|
||||||
|
|
|
@ -1,14 +1,15 @@
|
||||||
import unittest, math
|
import unittest, math
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from tinygrad.helpers import dtypes
|
from tinygrad.helpers import dtypes, getenv
|
||||||
from tinygrad.tensor import Device
|
from tinygrad.tensor import Device
|
||||||
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ASTRunner, Compiled
|
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ASTRunner, Compiled
|
||||||
from tinygrad.codegen.linearizer import UOps, Token, ConstOp, MemOp
|
from tinygrad.codegen.linearizer import UOps, Token, ConstOp, MemOp
|
||||||
from tinygrad.shape.symbolic import Variable
|
from tinygrad.shape.symbolic import Variable
|
||||||
|
|
||||||
def _uops_to_prg(uops):
|
def _uops_to_prg(uops):
|
||||||
src, global_size, local_size = Device[Device.DEFAULT].renderer("test", uops)
|
ret = Device[Device.DEFAULT].renderer("test", uops)
|
||||||
return ASTRunner("test", src, global_size, local_size).build(Device[Device.DEFAULT].runtime)
|
src, global_size, local_size, binary = ret if len(ret) == 4 else ret + (False,)
|
||||||
|
return ASTRunner("test", src, global_size, local_size, runtime_args={"binary": binary}).build(Device[Device.DEFAULT].runtime)
|
||||||
|
|
||||||
def _test_single_value(tc, tt, vals, op):
|
def _test_single_value(tc, tt, vals, op):
|
||||||
uops = [
|
uops = [
|
||||||
|
|
|
@ -156,10 +156,12 @@ class Compiled:
|
||||||
|
|
||||||
def to_program(self, k):
|
def to_program(self, k):
|
||||||
k.linearize()
|
k.linearize()
|
||||||
src, global_size, local_size = self.renderer(k.function_name, k.uops)
|
ret = self.renderer(k.function_name, k.uops)
|
||||||
|
src, global_size, local_size, binary = ret if len(ret) == 4 else ret + (False,)
|
||||||
|
#TODO: I need to find a better way to select ARM64
|
||||||
return ASTRunner(k.function_name, src, global_size, local_size,
|
return ASTRunner(k.function_name, src, global_size, local_size,
|
||||||
op_estimate=k.info.flops, mem_estimate=k.mem_estimate,
|
op_estimate=k.info.flops, mem_estimate=k.mem_estimate,
|
||||||
display_name=k.display_name).build(self.runtime)
|
display_name=k.display_name, runtime_args={"binary": binary}).build(self.runtime)
|
||||||
|
|
||||||
def exec_ast(self, ast:LazyOp, output, **kwargs):
|
def exec_ast(self, ast:LazyOp, output, **kwargs):
|
||||||
# all movementops do nothing in a Compiled buffer!
|
# all movementops do nothing in a Compiled buffer!
|
||||||
|
|
|
@ -1,8 +1,15 @@
|
||||||
import os, time, ctypes, hashlib, subprocess, platform, tempfile, functools
|
import os, time, ctypes, hashlib, subprocess, platform, tempfile, functools
|
||||||
|
from functools import partial, reduce
|
||||||
from tinygrad.ops import Compiled
|
from tinygrad.ops import Compiled
|
||||||
|
from tinygrad.helpers import fromimport, getenv, DEBUG, CI
|
||||||
from tinygrad.runtime.lib import RawMallocBuffer
|
from tinygrad.runtime.lib import RawMallocBuffer
|
||||||
from tinygrad.codegen.linearizer import LinearizerOptions
|
from tinygrad.codegen.linearizer import LinearizerOptions
|
||||||
from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage
|
from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage
|
||||||
|
import struct
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
ARM64 = getenv('ARM64', False)
|
||||||
|
if CI and ARM64: from unicorn import Uc, UC_ARCH_ARM64, UC_MODE_ARM, UC_HOOK_CODE, arm64_const
|
||||||
|
|
||||||
args = {
|
args = {
|
||||||
'Windows': {'cflags':'', 'ext':'dll', 'exp':'__declspec(dllexport)'},
|
'Windows': {'cflags':'', 'ext':'dll', 'exp':'__declspec(dllexport)'},
|
||||||
|
@ -11,24 +18,64 @@ args = {
|
||||||
}[platform.system()]
|
}[platform.system()]
|
||||||
|
|
||||||
CLANG_PROGRAM_HEADER = '#include <math.h>\n#define max(x,y) ((x>y)?x:y)\n#define int64 long\n#define half __fp16\n#define uchar unsigned char\n#define bool uchar\n'
|
CLANG_PROGRAM_HEADER = '#include <math.h>\n#define max(x,y) ((x>y)?x:y)\n#define int64 long\n#define half __fp16\n#define uchar unsigned char\n#define bool uchar\n'
|
||||||
|
ADDRESS = 0x10000
|
||||||
|
|
||||||
|
# Unicorn doesn't support external calls
|
||||||
|
def align(addr):
  """Round `addr` up to the nearest multiple of 4096."""
  boundary = 4096
  # -boundary has the same bit pattern as ~(boundary - 1) for Python ints
  return (addr + boundary - 1) & -boundary
|
||||||
|
mock_lm = {"sinf": np.sin, "sqrtf": np.sqrt, "exp2f": np.exp2, "log2f": np.log2}
|
||||||
|
def emulate_ext_calls(fn, uc, address, size, user_data):
  """Code hook that stands in for an external math call inside the emulator.

  `fn` is indexed as (function name, destination register, source register);
  the register names have their first character dropped to get the S-register
  number. `address`, `size` and `user_data` are accepted but unused here.
  """
  # Read the raw 32-bit register contents and reinterpret them as a float.
  src_reg = getattr(arm64_const, f'UC_ARM64_REG_S{fn[2][1:]}')
  raw_in = uc.reg_read(src_reg)
  s_in = struct.unpack('f', struct.pack('I', raw_in))[0]
  # Apply the mocked function, then write the float's bit pattern back.
  dst_reg = getattr(arm64_const, f'UC_ARM64_REG_S{fn[1][1:]}')
  result = mock_lm[fn[0]](s_in)
  raw_out = struct.unpack('I', struct.pack('f', result))[0]
  uc.reg_write(dst_reg, raw_out) # type: ignore
|
||||||
|
|
||||||
class ClangProgram:
|
class ClangProgram:
|
||||||
def __init__(self, name:str, prg:str):
|
def __init__(self, name:str, prg:str, binary:bool=False):
|
||||||
prg = CLANG_PROGRAM_HEADER + prg
|
|
||||||
# TODO: is there a way to not write this to disk?
|
# TODO: is there a way to not write this to disk?
|
||||||
# A: it seems there isn't https://stackoverflow.com/questions/28053328/ctypes-cdll-load-library-from-memory-rather-than-file
|
# A: it seems there isn't https://stackoverflow.com/questions/28053328/ctypes-cdll-load-library-from-memory-rather-than-file
|
||||||
# because ctypes.CDLL() calls dlopen (POSIX) or LoadLibrary (Windows) which require a file
|
# because ctypes.CDLL() calls dlopen (POSIX) or LoadLibrary (Windows) which require a file
|
||||||
fn = f"{tempfile.gettempdir()}/clang_{hashlib.md5(prg.encode('utf-8')).hexdigest()}.{args['ext']}"
|
fn = f"{tempfile.gettempdir()}/clang_{hashlib.md5(prg.encode('utf-8')).hexdigest()}.{args['ext']}"
|
||||||
if not os.path.exists(fn):
|
if not os.path.exists(fn):
|
||||||
tmp = f"{fn}.{os.getpid()}.tmp"
|
tmp = f"{fn}.{os.getpid()}.tmp"
|
||||||
subprocess.check_output(args=('clang -shared -O2 -Wall -Werror -x c '+args['cflags']+' - -o '+tmp).split(), input=prg.encode('utf-8'))
|
if not binary:
|
||||||
os.rename(tmp, fn)
|
prg = CLANG_PROGRAM_HEADER + prg
|
||||||
|
subprocess.check_output(args=('clang -shared -O2 -Wall -Werror -x c '+args['cflags']+' - -o '+tmp).split(), input=prg.encode('utf-8'))
|
||||||
|
os.rename(tmp, fn)
|
||||||
|
else:
|
||||||
|
if DEBUG >= 5: print(prg)
|
||||||
|
if CI and ARM64:
|
||||||
|
prg = prg.split('\n') # type: ignore
|
||||||
|
self.varsize = align(int(prg[0].split(" ")[1]))
|
||||||
|
self.ext_calls = {(i*4+ADDRESS):ins.split(" ")[1:] for i, ins in enumerate(filter(lambda ins: ins[:4] != 'loop', prg[6:-3])) if ins[:2] == 'bl'}
|
||||||
|
prg = "\n".join(['nop' if ins[:2] == 'bl' else ins for ins in prg[6:-3]] + ['\n'])
|
||||||
|
subprocess.check_output(args=('aarch64-linux-gnu-as -o '+tmp).split(), input=prg.encode('utf-8'))
|
||||||
|
subprocess.check_output(args=('aarch64-linux-gnu-objcopy -O binary --only-section=.text '+tmp+' '+fn+'.bin').split())
|
||||||
|
self.prg = open(fn + '.bin', 'rb').read()
|
||||||
|
return
|
||||||
|
subprocess.check_output(args=('as -o' + tmp).split(), input=prg.encode('utf-8'))
|
||||||
|
subprocess.check_output(args=('clang -lm -shared '+tmp+' -o'+fn).split())
|
||||||
self.lib = ctypes.CDLL(fn)
|
self.lib = ctypes.CDLL(fn)
|
||||||
self.fxn = self.lib[name]
|
self.fxn = self.lib[name]
|
||||||
|
|
||||||
def __call__(self, global_size, local_size, *args, wait=False):
|
def __call__(self, global_size, local_size, *args, wait=False):
|
||||||
if wait: st = time.monotonic()
|
if wait: st = time.monotonic()
|
||||||
self.fxn(*[x._buf for x in args])
|
if CI and ARM64:
|
||||||
|
mu = Uc(UC_ARCH_ARM64, UC_MODE_ARM)
|
||||||
|
total_mem = align(reduce(lambda total, arg: total + arg.size * arg.dtype.itemsize, args, len(self.prg)+self.varsize))
|
||||||
|
mu.mem_map(ADDRESS, total_mem)
|
||||||
|
for k, fn in self.ext_calls.items(): mu.hook_add(UC_HOOK_CODE, partial(emulate_ext_calls, fn), begin=k, end=k)
|
||||||
|
mu.mem_write(ADDRESS, self.prg + b''.join(bytes(arg._buf) for arg in args))
|
||||||
|
addr = ADDRESS + len(self.prg)
|
||||||
|
for i, arg in enumerate(args):
|
||||||
|
if i<=7:
|
||||||
|
mu.reg_write(getattr(arm64_const, f'UC_ARM64_REG_X{i}'), addr)
|
||||||
|
else:
|
||||||
|
# NOTE: On ARM, args beyond the first 8 are placed on the stack; this also accounts for the stack red zone.
|
||||||
|
mu.mem_write(ADDRESS + total_mem - (len(args[8:])+2)*8 + 8*(i-8), addr.to_bytes(8, 'little'))
|
||||||
|
addr += arg.size * arg.dtype.itemsize
|
||||||
|
mu.reg_write(arm64_const.UC_ARM64_REG_SP, ADDRESS + total_mem - (len(args[8:])+2)*8)
|
||||||
|
mu.emu_start(ADDRESS, ADDRESS + len(self.prg))
|
||||||
|
args[0]._buf = mu.mem_read(mu.reg_read(arm64_const.UC_ARM64_REG_X0), args[0].size * args[0].dtype.itemsize)
|
||||||
|
else:
|
||||||
|
self.fxn(*[x._buf for x in args])
|
||||||
if wait: return time.monotonic()-st
|
if wait: return time.monotonic()-st
|
||||||
|
|
||||||
renderer = functools.partial(uops_to_cstyle, CStyleLanguage(kernel_prefix=args['exp'], buffer_suffix=" restrict"))
|
renderer = fromimport("extra.assembly.assembly_arm64", "uops_to_arm64_asm") if ARM64 else functools.partial(uops_to_cstyle, CStyleLanguage(kernel_prefix=args['exp'], buffer_suffix=" restrict"))
|
||||||
ClangBuffer = Compiled(RawMallocBuffer, LinearizerOptions(supports_float4=False, has_local=False), renderer, ClangProgram)
|
ClangBuffer = Compiled(RawMallocBuffer, LinearizerOptions(supports_float4=False, has_local=False), renderer, ClangProgram)
|
||||||
|
|
|
@ -37,7 +37,7 @@ def unwrap(x):
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
class MetalProgram:
|
class MetalProgram:
|
||||||
def __init__(self, name:str, prg:str):
|
def __init__(self, name:str, prg:str, binary:bool=False):
|
||||||
if METAL_XCODE:
|
if METAL_XCODE:
|
||||||
air = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metal', '-x', 'metal', '-c', '-', '-o', '-'], input=prg.encode('utf-8'))
|
air = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metal', '-x', 'metal', '-c', '-', '-o', '-'], input=prg.encode('utf-8'))
|
||||||
# NOTE: if you run llvm-dis on "air" you can see the llvm bytecode
|
# NOTE: if you run llvm-dis on "air" you can see the llvm bytecode
|
||||||
|
|
Loading…
Reference in New Issue