428 lines
15 KiB
Python
428 lines
15 KiB
Python
from typing import (
|
|
Tuple,
|
|
List,
|
|
NamedTuple,
|
|
Any,
|
|
Dict,
|
|
Optional,
|
|
Union,
|
|
DefaultDict,
|
|
cast,
|
|
)
|
|
from tinygrad.codegen.linearizer import UOps, MemOp, UOp
|
|
from tinygrad.ops import BinaryOps, UnaryOps
|
|
from tinygrad.helpers import DType, dtypes, DEBUG
|
|
from tinygrad.shape.symbolic import (
|
|
Variable,
|
|
NumNode,
|
|
MulNode,
|
|
DivNode,
|
|
ModNode,
|
|
LtNode,
|
|
SumNode,
|
|
AndNode,
|
|
)
|
|
import functools
|
|
import math
|
|
from collections import defaultdict
|
|
|
|
# Letter code embedded in generated register names, keyed by dtype.
# AssemblyLanguage.type_to_letter upper-cases the code for scalar registers.
_type_to_letter = {
    # 32/64-bit float, predicate and integer kinds
    dtypes.float32: "f",
    dtypes.bool: "p",
    dtypes.int32: "i",
    dtypes.int64: "a",
    dtypes.uint32: "u",
    dtypes.uint64: "b",
    # 4-wide float vector
    dtypes.float.vec(4): "x",
    # narrow and double-width kinds
    dtypes.uint8: "uc",
    dtypes.float16: "h",
    dtypes.int8: "c",
    dtypes.uint16: "us",
    dtypes.float64: "d",
}
|
|
|
|
|
|
class Register(NamedTuple):
    """A named virtual register, optionally one element of a vector register.

    Fields:
      nm     -- register name as it is printed
      dtype  -- dtype held by the register
      scalar -- scalar/vector flag recorded for the register
      off    -- element offset inside a vector register, or None when the
                value refers to the whole register
    """

    nm: str
    dtype: DType
    scalar: bool
    off: Optional[int] = None

    def __repr__(self):
        # A whole register prints as its bare name; an element prints "name:off".
        if self.off is None:
            return self.nm
        return f"{self.nm}:{self.off}"

    def subregs(self):
        """Return the four float element registers of a float4 register.

        Any other dtype has no sub-registers and yields an empty list.
        The elements are always created with scalar=False.
        """
        is_float4 = self.dtype == dtypes.float.vec(4)
        return (
            [Register(self.nm, dtypes.float, False, off=lane) for lane in range(4)]
            if is_float4
            else []
        )
|
|
|
|
|
|
class AssemblyInstruction(NamedTuple):
    """One instruction of the intermediate, register-based program.

    Fields:
      op  -- the UOps kind of the instruction
      out -- destination register, or None when nothing is written
      vin -- inputs: registers and/or immediate int/float values
      arg -- op-specific payload (defaults to None)
    """

    op: UOps
    out: Optional[Register]
    vin: List[Union[Register, int, float]]
    arg: Any = None
|
|
|
|
|
|
# warp size of 32, s registers are shared across the warp, v are 32-wide vectors
|
|
class AssemblyLanguage:
    """Register allocation and index-rendering helpers for assembly backends.

    Feature flags (read by this class and by uops_to_asmstyle):
      supports_load3 -- addr_w_offset returns (buffer register, index register,
                        offset) instead of a single computed address register
      sin_is_sin2pi  -- SIN inputs are pre-multiplied by 1/(2*pi)
      no_div         -- DIV is emitted as RECIP followed by MUL
    """

    supports_load3: bool = False
    sin_is_sin2pi: bool = False
    no_div: bool = False
    # TODO: these should be global vars
    # NOTE: class-level mutable containers, shared unless a subclass or
    # instance rebinds them.
    cnts: DefaultDict[Tuple[DType, bool], int] = defaultdict(int)  # next index per (dtype, scalar)
    tor: Dict[Any, Register] = {}  # token -> register
    ins: List[AssemblyInstruction] = []  # emitted instructions, in order

    def type_to_letter(self, x):
        """Return the name letter(s) for a (dtype, scalar) pair; upper-case if scalar."""
        letter = _type_to_letter[x[0]]
        return letter.upper() if x[1] else letter

    def newreg(self, tok, dtype=dtypes.float32, scalar=False) -> Register:
        """Allocate a fresh register, record it under `tok` in `tor`, and return it."""
        bank = (dtype, scalar)
        name = f"%{self.type_to_letter(bank)}{self.cnts[bank]}"
        ret = Register(name, dtype, scalar)
        self.tor[tok] = ret
        if dtype == dtypes.float.vec(4):
            # NOTE(review): every iteration rebinds the same key, so after the
            # loop tor[tok] is the off=3 float element rather than `ret`.
            # This looks unintended; confirm before relying on tor[tok] for vec4.
            for off in range(4):
                self.tor[tok] = Register(ret.nm, dtypes.float, ret.scalar, off)
        self.cnts[bank] += 1
        return ret

    def render_numnode(self, b) -> Register:
        """Return a scalar int32 register holding constant `b`, emitting its LOAD once."""
        key = ("num", b)
        if key in self.tor:
            return self.tor[key]
        const_reg = self.newreg(key, scalar=True, dtype=dtypes.int32)
        self.ins.append(AssemblyInstruction(UOps.LOAD, const_reg, [], b))
        return self.tor[key]

    def render_alu(
        self, op, a: Register, b: Union[Register, int, float], dtype=dtypes.int32
    ) -> Register:
        """Emit `op(a, b)` into a new register, reusing a previous identical result.

        Results are cached in `tor` under (op, a, b). The result is scalar only
        when `a` is scalar and `b` is either an immediate or a scalar register.
        """
        key = (op, a, b)
        if key in self.tor:
            return self.tor[key]
        # An immediate `b` is passed through as-is; it is not rendered to a register.
        result_is_scalar = a.scalar and (not isinstance(b, Register) or b.scalar)
        out = self.newreg(key, dtype=dtype, scalar=result_is_scalar)
        self.ins.append(AssemblyInstruction(UOps.ALU, out, [a, b], op))
        return self.tor[key]

    def render_cast(self, a: Register, new_dtype: DType) -> Register:
        """Return `a` converted to `new_dtype`; a no-op when the dtype already matches."""
        if a.dtype == new_dtype:
            return a
        key = (a, new_dtype)
        if key in self.tor:
            return self.tor[key]
        casted = self.newreg(key, dtype=new_dtype)
        self.ins.append(AssemblyInstruction(UOps.CAST, casted, [a]))
        return self.tor[key]

    # Renderers for symbolic index nodes, looked up by node type.
    # Each callable receives (node, ops, ctx) and returns the Register holding
    # the node's value; `ctx` is the AssemblyLanguage doing the rendering.
    render_ops: Any = {
        Variable: lambda node, ops, ctx: ctx.tor[node],
        NumNode: lambda node, ops, ctx: ctx.render_numnode(node.b),
        MulNode: lambda node, ops, ctx: ctx.render_alu(
            BinaryOps.MUL, node.a.render(ops, ctx), node.b
        ),
        DivNode: lambda node, ops, ctx: ctx.render_alu(
            BinaryOps.DIV, node.a.render(ops, ctx), node.b
        ),
        ModNode: lambda node, ops, ctx: ctx.render_alu(
            BinaryOps.MOD, node.a.render(ops, ctx), node.b
        ),
        LtNode: lambda node, ops, ctx: ctx.render_alu(
            BinaryOps.CMPLT, node.a.render(ops, ctx), node.b, dtype=dtypes.bool
        ),
        # Sum: left fold of ADD over the rendered terms.
        SumNode: lambda node, ops, ctx: functools.reduce(
            lambda acc, term: ctx.render_alu(BinaryOps.ADD, acc, term.render(ops, ctx)),
            node.nodes[1:],
            node.nodes[0].render(ops, ctx),
        ),
        # And: left fold of MUL over the rendered terms, typed as bool.
        AndNode: lambda node, ops, ctx: functools.reduce(
            lambda acc, term: ctx.render_alu(
                BinaryOps.MUL, acc, term.render(ops, ctx), dtype=dtypes.bool
            ),
            node.nodes[1:],
            node.nodes[0].render(ops, ctx),
        ),
    }

    def addr_w_offset(self, args):
        """Render the address of a MemOp and split off a constant byte offset.

        Returns (buffer register, index register, offset) when supports_load3
        is set, otherwise (address register, None, offset) where the address
        register is the uint64 sum of the index and the buffer register.
        """
        assert isinstance(args, MemOp)
        byte_idx = args.idx * args.memory_dtype.itemsize
        off = 0  # TODO: should this be None?
        if isinstance(byte_idx, SumNode):
            consts = [term.b for term in byte_idx.nodes if isinstance(term, NumNode)]
            # Peel the first constant term off only when it is below 4096 and
            # the remaining index cannot go negative.
            # TODO: different for each GPU?
            if consts and consts[0] < 4096 and (byte_idx - consts[0]).min >= 0:
                byte_idx -= consts[0]
                off = cast(int, consts[0])
        reg = byte_idx.render(self.render_ops, self)

        if not self.supports_load3:
            # Fold the buffer base into one 64-bit address register.
            reg = self.render_alu(
                BinaryOps.ADD,
                self.render_cast(reg, dtypes.uint64),
                self.tor[args.name],
                dtype=dtypes.uint64,
            )
            return reg, None, off

        if reg.scalar:
            # Copy a scalar index into a freshly allocated non-scalar register.
            vec_reg = self.newreg((reg.nm, "vec"), dtype=reg.dtype)
            self.ins.append(AssemblyInstruction(UOps.ALU, vec_reg, [reg], UnaryOps.NOOP))
            reg = vec_reg
        return self.tor[args.name], reg, off
|
|
|
|
|
|
def uops_to_asmstyle(lang, function_name: str, uops: List[UOp]):
    """Lower linearizer uops into AssemblyInstructions stored in `lang.ins`.

    `lang` supplies the register state and helpers (see AssemblyLanguage); its
    `ins`, `tor` and `cnts` containers are cleared first and then refilled.
    `function_name` is accepted but not used by this function.

    Returns (global_size, local_size): the `max + 1` extents of the variables
    seen in "global" and "local" LOOP uops, in the order encountered.
    """
    # TODO: Do not use clear()
    lang.ins.clear()
    lang.tor.clear()
    lang.cnts.clear()

    def emit(op, out, vin, arg=None):
        # Append a single instruction to the program being built.
        lang.ins.append(AssemblyInstruction(op, out, vin, arg))

    def mem_arg(memop, off):
        # Payload shared by LOAD and STORE: (offset, memory space, memory dtype).
        # The dtype slot is None when the memory dtype is dtypes.float.
        space = "shared" if memop.local else "global"
        mem_dtype = memop.memory_dtype if memop.memory_dtype != dtypes.float else None
        return (off, space, mem_dtype)

    # Collect the global buffers; only the names (keys) are used below.
    buf_to_dtype = {}
    for kind_, _, _, payload, _ in uops:
        if kind_ == UOps.DEFINE_GLOBAL:
            buf_to_dtype[payload[0]] = payload[1]

    global_size, local_size = [], []
    skipload_branch = 0

    # One scalar uint64 register per global buffer.
    lang.ins += [
        AssemblyInstruction(
            UOps.SPECIAL, lang.newreg(name, dtype=dtypes.uint64, scalar=True), [], name
        )
        for name in buf_to_dtype
    ]

    for u in uops:
        uop, dtype, vin, args, _ = u

        if uop == UOps.DEFINE_LOCAL:
            emit(UOps.DEFINE_LOCAL, None, [], args)
            local_ptr = lang.newreg(args[0], dtype=dtypes.uint64)
            emit(UOps.ALU, local_ptr, [args[0]], UnaryOps.NOOP)

        elif uop == UOps.LOOP:
            loop_kind = args[1]
            if loop_kind == "global" or loop_kind == "local":
                # Hardware-indexed loops: each variable reads a gid/lid special,
                # numbered in reverse of its position in the loop.
                if loop_kind == "global":
                    sizes, prefix = global_size, "gid"
                else:
                    sizes, prefix = local_size, "lid"
                for i, var in enumerate(args[0]):
                    sizes.append(var.max + 1)
                    emit(
                        UOps.SPECIAL,
                        lang.newreg(var, dtype=dtypes.int32),
                        [],
                        f"{prefix}{len(args[0]) - 1 - i}",
                    )
            else:
                # Software loops: zero a scalar counter and drop a label to jump back to.
                for var in args[0]:
                    if isinstance(var, NumNode):  # TODO: why is this coming through?
                        continue
                    emit(UOps.LOAD, lang.newreg(var, dtype=dtypes.int32, scalar=True), [], 0)
                    emit(UOps.LABEL, None, [], "$loop_" + var.expr)

        elif uop == UOps.ENDLOOP:
            loop_kind = args[1]
            if loop_kind not in ["global", "local", "global+local"]:
                # Close software loops innermost-first: increment, compare, branch back.
                for var in reversed(args[0]):
                    if isinstance(var, NumNode):  # TODO: why is this coming through?
                        continue
                    counter = lang.tor[var]
                    emit(UOps.ALU, counter, [counter, 1], BinaryOps.ADD)
                    pred = lang.render_alu(
                        BinaryOps.CMPLT, lang.tor[var], var.max + 1, dtypes.bool
                    )
                    emit(UOps.COND_BRANCH, None, [pred], ("$loop_" + var.expr, True))
            elif loop_kind == "global+local" or loop_kind == "local":
                # A plain "global" ENDLOOP emits nothing.
                prefix = "gid" if loop_kind == "global+local" else "lid"
                for i, var in enumerate(reversed(args[0])):
                    emit(UOps.ENDLOOP, None, [lang.tor[var]], (var.max + 1, f"{prefix}{i}"))

        elif uop == UOps.CAST:
            # TODO: we should reconsider outputting CAST in the linearizer. these are needless copies
            out = lang.newreg(u, dtype)
            # Copy each input into the matching element of the new register;
            # registers without sub-registers produce no instructions here.
            for i, sub in enumerate(out.subregs()):
                emit(UOps.ALU, sub, [lang.tor[vin[i]]], UnaryOps.NOOP)

        elif uop == UOps.ALU:
            # this is the only thing that can violate SSA
            out = lang.tor[u] if u in lang.tor else lang.newreg(u, dtype)
            if args == BinaryOps.CMPLT:
                # Compare into a bool register, then CAST that into the output.
                pred_reg = lang.newreg((u, "pred"), dtype=dtypes.bool)
                emit(UOps.ALU, pred_reg, [lang.tor[x] for x in vin], args)
                emit(UOps.CAST, out, [pred_reg], args)
            elif args == BinaryOps.DIV and lang.no_div:
                # a / b is emitted as a * recip(b).
                tmp = lang.newreg((u, "rcp"))
                emit(UOps.ALU, tmp, [lang.tor[vin[1]]], UnaryOps.RECIP)
                emit(UOps.ALU, out, [lang.tor[vin[0]], tmp], BinaryOps.MUL)
            elif args == UnaryOps.SIN and lang.sin_is_sin2pi:
                # Scale the input by 1/(2*pi) before applying SIN.
                tmp = lang.newreg((u, "2pi"))
                emit(UOps.ALU, tmp, [lang.tor[vin[0]], 1 / (math.pi * 2)], BinaryOps.MUL)
                emit(UOps.ALU, out, [tmp], args)
            else:
                emit(UOps.ALU, out, [lang.tor[x] for x in vin], args)

        elif uop == UOps.DEFINE_ACC:
            acc = lang.newreg(u, dtype=dtype)
            emit(UOps.LOAD, acc, [], args)

        elif uop == UOps.SPECIAL:
            # Alias the uop to the register already recorded under `args`.
            lang.tor[u] = lang.tor[args]

        elif uop == UOps.CONST:
            emit(UOps.LOAD, lang.newreg(u, dtype=dtype), [], args)

        elif uop == UOps.LOAD:
            idx, treg, off = lang.addr_w_offset(args)
            dest_is_scalar = idx.scalar and (not isinstance(treg, Register) or treg.scalar)
            reg = lang.newreg(u, dtype=dtype, scalar=dest_is_scalar)
            can_be_invalid = args.valid.min == 0
            can_be_valid = args.valid.max == 1
            if can_be_invalid:
                # Preload 0 so a skipped load still leaves a defined value.
                emit(UOps.LOAD, reg, [], 0)
                if can_be_valid:
                    pred = args.valid.render(lang.render_ops, lang)
                    # The False flag is consumed by the backend (presumably
                    # "branch when the predicate is false"; confirm there).
                    emit(
                        UOps.COND_BRANCH,
                        None,
                        [pred],
                        (f"$skipload_{skipload_branch}", False),
                    )
            if can_be_valid:
                # NOTE: you can't compute the index in here, because it assumes it's all available later
                emit(
                    UOps.LOAD,
                    reg,
                    [idx] + ([treg] if treg is not None else []),
                    mem_arg(args, off),
                )
            if can_be_invalid and can_be_valid:
                emit(UOps.LABEL, None, [], f"$skipload_{skipload_branch}")
                skipload_branch += 1

        elif uop == UOps.STORE:
            if args is None:
                # Register-to-register store: copy vin[1] into vin[0]'s register.
                emit(UOps.ALU, lang.tor[vin[0]], [lang.tor[vin[1]]], UnaryOps.NOOP)
            else:
                idx, treg, off = lang.addr_w_offset(args)
                emit(
                    UOps.STORE,
                    None,
                    [idx, lang.tor[vin[0]]] + ([treg] if treg is not None else []),
                    mem_arg(args, off),
                )

    if DEBUG >= 4:
        for tins in lang.ins:
            print(tins)
    return global_size, local_size
|