274 lines
11 KiB
Python
274 lines
11 KiB
Python
import struct
|
|
from platform import system
|
|
from typing import Tuple, Dict, List, Optional
|
|
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps
|
|
from tinygrad.codegen.linearizer import UOps, UOp
|
|
from tinygrad.helpers import dtypes, CI
|
|
from tinygrad.codegen.assembly import uops_to_asmstyle, AssemblyLanguage
|
|
|
|
|
|
def float_to_hex(x):
    """Return the IEEE-754 single-precision bits of ``x`` as 8 uppercase hex digits.

    The most significant byte comes first, e.g. ``1.0`` -> ``"3F800000"``.
    Callers slice the result into its high (``[:4]``) and low (``[4:]``)
    16-bit halves.
    """
    # Pack explicitly big-endian so the digit order never depends on the byte
    # order of the machine running the code generator (native packing followed
    # by a reversal is only correct on little-endian hosts).
    return struct.pack(">f", x).hex().upper()
|
|
|
|
|
|
def compute_offsets(total):
    """Split ``total`` into chunks of 4096 followed by any non-zero leftover.

    For example ``4112`` -> ``[4096, 16]`` and ``0`` -> ``[]``. The chunks are
    used as the immediates of successive ``sub sp`` / ``add sp`` instructions
    in ``specialize_to_arm64``.
    """
    full_chunks, leftover = divmod(total, 4096)
    offsets = [4096] * full_chunks
    if leftover:
        offsets.append(leftover)
    return offsets
|
|
|
|
|
|
# NOTE: Darwin needs names to start with a "_"
|
|
def get_name(name):
    """Return ``name`` as a symbol name, prefixed with ``_`` when running on Darwin."""
    prefix = "_" if system() == "Darwin" else ""
    return prefix + name
|
|
|
|
|
|
class ARM64Language(AssemblyLanguage):
    """ARM64 flavour of ``AssemblyLanguage``.

    Adds nothing of its own; ``uops_to_arm64_asm`` instantiates it and hands it
    to ``uops_to_asmstyle`` to collect the instruction list.
    """
|
|
|
|
|
|
def specialize_to_arm64(fn_nm, asm):
    """Lower a list of assembly-style instructions to ARM64 assembly text.

    Args:
        fn_nm: Name of the emitted global symbol (run through ``get_name``).
        asm: Sequence of ``(uop, out, vin, arg)`` instructions. It is walked
            twice, so it must be re-iterable. ``out`` may be ``None``; entries
            of ``vin`` are either plain ``int`` immediates or register tokens
            exposing ``.nm`` (name) and a dtype at index 1.

    Returns:
        The complete assembly source as one newline-joined string: a header,
        the stack reservation, the generated instructions, the stack release
        and ``ret``.

    Register usage visible in this function:
        * ``x0``-``x11`` and ``s3``-``s7`` / ``s16``-``s31`` form the
          allocation pools.
        * ``x15`` (and ``x14``) are scratch for immediates and addresses.
        * ``x17`` holds the value of ``sp`` at function entry.
        * ``s0``-``s2`` and ``x12``/``x13``/``x16`` receive spilled operands.
    """
    var_size = 16
    prev_uop: Optional[UOps] = None
    ins = []
    x_regs = ["x" + str(i) for i in reversed(range(12))]
    s_regs = ["s" + str(i) for i in reversed(range(3, 32)) if i <= 7 or i >= 16]
    type_to_reg = {
        dtypes.double: "d",
        dtypes.half: "h",
        dtypes.float32: "s",
        dtypes.bool: "w",
        dtypes.int8: "w",
        dtypes.int32: "w",
        dtypes.int64: "x",
        dtypes.uint8: "w",
        dtypes.uint32: "w",
        dtypes.uint64: "x",
    }
    alu = {
        BinaryOps.ADD: "add",
        BinaryOps.SUB: "sub",
        BinaryOps.MUL: "mul",
        BinaryOps.DIV: "div",
        BinaryOps.MAX: "max",
        BinaryOps.MOD: "",
        BinaryOps.CMPLT: "subs",
        UnaryOps.NOOP: "mov",
        UnaryOps.NEG: "neg",
        UnaryOps.SIN: "bl " + get_name("sinf"),
        UnaryOps.LOG2: "bl " + get_name("log2f"),
        UnaryOps.EXP2: "bl " + get_name("exp2f"),
        UnaryOps.SQRT: "bl " + get_name("sqrtf"),
        TernaryOps.MULACC: "madd",
        TernaryOps.WHERE: "fcsel",
    }

    def mov_imm(value, reg):
        """Emit the instructions that place the immediate ``value`` in ``reg``."""
        if value.__class__ is not float and abs(value) > 65535:
            # Does not fit a single 16-bit immediate: build the low 32 bits in
            # w15 from two halves, then sign-extend into the destination.
            ins.append(f"movz w15, #{value & 0xffff}")
            ins.append(f"movk w15, #{(value >> 16) & 0xffff}, lsl #16")
            ins.append(f"sxtw {reg}, w15")
        elif reg[0] == "s":
            # Float destination: assemble the bit pattern in x15 and bounce it
            # through the stack slot at [sp, 16] to reach the float register.
            ins.append(f"movz x15, 0x{float_to_hex(value)[4:]}")
            ins.append(f"movk x15, 0x{float_to_hex(value)[:4]}, lsl #16")
            ins.append("str x15, [sp, 16]")
            ins.append(f"ldr {reg}, [sp, 16]")
        else:
            ins.append(f"mov {reg}, #{value}")

    # Live interval of every named variable: [first index seen, last index seen].
    live_range: Dict[str, List[int]] = {}
    for i, (uop, out, vin, arg) in enumerate(asm):
        for var in [v for v in [out] + vin if v is not None and v.__class__ is not int]:
            live_range[var.nm] = (
                [i, i] if var.nm not in live_range else [live_range[var.nm][0], i]
            )

    mem_vars: Dict[str, int] = {}  # spilled variable name -> offset from sp
    rtor: Dict[str, str] = {}  # variable name -> register currently assigned

    def allocate_regs(mvars):
        """Give every not-yet-mapped variable in ``mvars`` a register, spilling if none is free."""
        nonlocal var_size
        for v in [
            v
            for v in mvars
            if v is not None and v.__class__ is not int and v.nm not in rtor
        ]:
            v_is_float = dtypes.is_float(v[1])
            available_regs = s_regs if v_is_float else x_regs
            # NOTE: Very simple spill, everything that doesn't fit in regs goes to mem
            if not available_regs:
                # ARM needs the stack 16-byte aligned
                var_size += 16
                # Pick the spill register from the dtype of the variable being
                # allocated. The previous code consulted the enclosing loop's
                # `out`, which is the wrong register class when the dtypes
                # differ and fails outright when `out` is None.
                available_regs.append("s0" if v_is_float else "x12")
                mem_vars[v.nm] = var_size
            rtor[v.nm] = available_regs.pop()

    temp_floats = ["s0", "s1", "s2"]
    temp_ints = ["x12", "x13", "x16"]
    for i, (uop, out, vin, arg) in enumerate(asm):
        # Release registers of variables whose live interval has ended. Names
        # whose second character is "B" and spilled variables are never
        # released (presumably "B" marks buffer pointers -- confirm against the
        # naming used by uops_to_asmstyle).
        for var, reg in list(rtor.items()):
            available_regs = s_regs if reg[0] == "s" else x_regs
            if var[1] not in "B" and var not in mem_vars and i > live_range[var][1]:
                available_regs.append(rtor.pop(var))
        # Assign registers to the variables of this instruction.
        allocate_regs([out] + vin)
        # Spilled inputs get a temp register and are reloaded before use.
        for slot, v in enumerate(
            [v for v in vin if v.__class__ is not int and v.nm in mem_vars]
        ):
            rtor[v.nm] = temp_floats[slot] if dtypes.is_float(v[1]) else temp_ints[slot]
            # ARM64 addressing constraints https://devblogs.microsoft.com/oldnewthing/20220728-00/?p=106912
            ins.append(f"mov x15, {mem_vars[v.nm]}")
            ins.append(f"ldr {rtor[v.nm]}, [sp, x15]")

        if uop == UOps.SPECIAL:
            if arg.startswith("data"):
                # Arguments numbered 8 and above are read relative to x17,
                # which holds sp as it was at function entry.
                if int(arg[4:]) >= 8:
                    ins.append(f"ldr x15, [x17, #{(int(arg[4:]) - 8) * 8}]")
                    ins.append(f"mov {rtor[out.nm]}, x15")
            else:
                # Any other special is a loop counter: zero it and open the loop.
                ins.append(f"mov {rtor[out.nm]}, #0")
                ins.append(f"loop_{arg}:")
        elif uop == UOps.CAST:
            if arg == BinaryOps.CMPLT:
                # Turn the "lt" condition flag into a 1/0 value of the output type.
                if rtor[out.nm][0] == "s":
                    mov_imm(0.0, "s0")
                    mov_imm(1.0, "s1")
                    ins.append(f"fcsel {rtor[out.nm]}, s1, s0, lt")
                if rtor[out.nm][0] == "x":
                    mov_imm(0, "x14")
                    mov_imm(1, "x15")
                    ins.append(f"csel {rtor[out.nm]}, x15, x14, lt")
            else:
                ins.append(f"sxtw {rtor[out.nm]}, w{rtor[vin[0].nm][1:]}")
        elif uop == UOps.ALU:
            # An int second operand is materialised in x15 and referenced as
            # "x15" in the operand lists below.
            if len(vin) == 2 and vin[1].__class__ is int:
                mov_imm(vin[1], "x15")
            if arg == BinaryOps.MUL and out.dtype == dtypes.bool:
                ins.append(
                    f"ands {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}"
                )
            elif arg == TernaryOps.WHERE:
                ins.append(
                    f"fcmp {rtor[vin[0].nm]}, #0.0"
                    if rtor[vin[0].nm][0] == "s"
                    else f"cmp {rtor[vin[0].nm]}, #0"
                )
                ins.append(
                    f"{alu[arg]} {rtor[out.nm]}, {rtor[vin[1].nm]}, {rtor[vin[2].nm]}, ne"
                )
            elif arg in [UnaryOps.LOG2, UnaryOps.SIN, UnaryOps.EXP2, UnaryOps.SQRT]:
                # NOTE: Not a real instruction, used to emulate an ext call in unicorn
                if CI:
                    ins.append(f"{alu[arg]} {rtor[out.nm]} {rtor[vin[0].nm]}")
                else:
                    save_regs = [
                        k for k in rtor.keys() if k != out.nm and k not in mem_vars
                    ]
                    ins.append(f"sub sp, sp, #{(len(save_regs))*16}")
                    # Save the registers before they are cleared by func call.
                    # NOTE(review): len(save_regs)*16 bytes are reserved but the
                    # slots used are #16..#16*len, so the highest slot sits at
                    # the pre-adjustment sp -- confirm this is intended.
                    for n, k in enumerate(save_regs, 1):
                        ins.append(f"str {rtor[k]}, [sp, #{16*n}]")
                    ins.append("stp x29, x30, [sp, #0]!")
                    ins.append("mov x29, sp")
                    ins.append(f"fmov s0, {rtor[vin[0].nm]}")
                    ins.append(alu[arg])
                    ins.append(f"fmov {rtor[out.nm]}, s0")
                    ins.append("mov sp, x29")
                    ins.append("ldp x29, x30, [sp], #0")
                    for n, k in enumerate(save_regs, 1):
                        ins.append(f"ldr {rtor[k]}, [sp, #{16*n}]")
                    ins.append(f"add sp, sp, #{len(save_regs)*16}")
            elif arg == BinaryOps.CMPLT:
                ins.append(
                    f"{alu[arg]} {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}"
                    if not dtypes.is_float(vin[0][1])
                    else f"fcmp {rtor[vin[0].nm]}, {rtor[vin[1].nm]}"
                )
            elif arg == BinaryOps.MOD:
                # a % b is emitted as a - (a / b) * b using unsigned division.
                rhs = "x15" if vin[1].__class__ is int else rtor[vin[1].nm]
                ins.append(f"udiv x14, {rtor[vin[0].nm]}, {rhs}")
                ins.append(f"msub {rtor[out.nm]}, x14, {rhs}, {rtor[vin[0].nm]}")
            else:
                # Float inputs get an "f" prefix, integer division an "s" prefix.
                prefix = (
                    "f"
                    if dtypes.is_float(vin[0][1])
                    else "s"
                    if arg == BinaryOps.DIV
                    else ""
                )
                operands = ", ".join(
                    "x15" if v.__class__ is int else rtor[v.nm] for v in [out] + vin
                )
                ins.append(f"{prefix}{alu[arg]} {operands}")
        elif uop == UOps.LOAD:
            if arg.__class__ in (int, float):
                # Constant load.
                mov_imm(arg, rtor[out.nm])
            else:
                # NOTE: if need casting load var in s/h0 or x/w12 temp regs
                reg_in = (
                    type_to_reg[arg[2]] + ("0" if dtypes.is_float(arg[2]) else "12")
                    if arg[2] is not None
                    else rtor[out.nm]
                )
                # Address = base register (vin[0]) + immediate arg[0].
                mov_imm(arg[0], "x15")
                ins.append(f"add x15, {rtor[vin[0].nm]}, x15")
                ins.append(
                    f"ldr{'sb' if arg[2] is not None and arg[2] in (dtypes.int8, dtypes.uint8, dtypes.bool) else ''} {reg_in}, [x15]"
                )
                if arg[2] is not None:
                    # Convert from the loaded dtype into the output register.
                    ins.append(
                        f"{'fcvt' if arg[2] in [dtypes.half, dtypes.double] else 'scvtf'} {rtor[out.nm]}, {reg_in}"
                    )
        elif uop == UOps.STORE:
            # NOTE: if need casting load var in s/h0 or x/w12 temp regs
            reg_out = (
                type_to_reg[arg[2]] + ("0" if dtypes.is_float(arg[2]) else "12")
                if arg[2] is not None
                else rtor[vin[1].nm]
            )
            if arg[2] is not None:
                ins.append(
                    f"fcvt{'zs' if arg[2] not in [dtypes.half, dtypes.double] else '' } {reg_out}, {rtor[vin[1].nm]}"
                )
            # Route the offset through mov_imm, as LOAD does, so offsets that
            # need more than 16 bits are built with movz/movk. Small offsets
            # still produce exactly "mov x15, #<offset>".
            mov_imm(arg[0], "x15")
            ins.append(f"str {reg_out}, [{rtor[vin[0].nm]}, x15, lsl #0]")
        elif uop == UOps.COND_BRANCH:
            # TODO: this is a hack it shouldn't always be a cmp before a cond branch?
            if prev_uop == UOps.LOAD:
                ins.append(f"cmp {rtor[vin[0].nm]}, #0")
            # arg is (label, flag); the first character of the label is dropped.
            ins.append(f"b.{'lt' if arg[1] else 'ge'} {arg[0][1:]}")
        elif uop == UOps.LABEL:
            ins.append(f"{arg[1:]}:")
        elif uop == UOps.ENDLOOP:
            # Increment the counter and jump back while it is below arg[0].
            mov_imm(arg[0], "x15")
            ins.append(f"add {rtor[vin[0].nm]}, {rtor[vin[0].nm]}, #1")
            ins.append(f"cmp {rtor[vin[0].nm]}, x15")
            ins.append(f"b.lt loop_{arg[1]}")
        prev_uop = uop
        # Write a spilled output back to its stack slot.
        # NOTE(review): slot offsets start at 32 and the newest one equals
        # var_size, the amount subtracted from sp below -- confirm the slot
        # range is inside the reserved area.
        if out is not None and out.nm in mem_vars:
            ins.append(f"mov x15, {mem_vars[out.nm]}")
            ins.append(f"str {rtor[out.nm]}, [sp, x15]")
    return "\n".join(
        [
            f"//varsize {var_size}",
            ".arch armv8-a",
            ".text",
            f".global {get_name(fn_nm)}",
            ".p2align 2",
            f"{get_name(fn_nm)}:",
            "mov x17, sp",
        ]
        + [f"sub sp, sp, #{offset}" for offset in compute_offsets(var_size)]
        + ins
        + [f"add sp, sp, #{offset}" for offset in compute_offsets(var_size)]
        + ["ret", "\n"]
    )
|
|
|
|
|
|
def uops_to_arm64_asm(
    fn_nm: str, uops: List[UOp]
) -> Tuple[str, List[int], List[int], bool]:
    """Render ``uops`` as ARM64 assembly for the function named ``fn_nm``.

    Returns a 4-tuple: the assembly text, the global size reversed, the local
    size reversed, and the constant ``True``.
    """
    language = ARM64Language()
    # uops_to_asmstyle fills language.ins and reports the two launch sizes.
    sizes = uops_to_asmstyle(language, fn_nm, uops)
    global_size, local_size = sizes
    assembly = specialize_to_arm64(fn_nm, language.ins)
    reversed_global = global_size[::-1]
    reversed_local = local_size[::-1]
    return (assembly, reversed_global, reversed_local, True)
|