1
0
Fork 0

fix usage of arm64 regs according to CC (#1570)

This commit is contained in:
nimlgen 2023-08-19 07:40:32 +03:00 committed by GitHub
parent 68ebbd2954
commit faa521bcab
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -20,8 +20,8 @@ def specialize_to_arm64(fn_nm, asm):
var_size = 16
prev_uop:Optional[UOps] = None
ins = []
x_regs = ['x' + str(i) for i in reversed(range(29)) if i not in (10,11,12,13,14,15,16,17,18,19,20)]
s_regs = ['s' + str(i) for i in reversed(range(3,30))]
x_regs = ['x' + str(i) for i in reversed(range(12))]
s_regs = ['s' + str(i) for i in reversed(range(3,32)) if i <= 7 or i >= 16]
type_to_reg = {dtypes.double: "d", dtypes.half: 'h', dtypes.float32: 's', dtypes.bool: 'w', dtypes.int8:'w', dtypes.int32: 'w', dtypes.int64: 'x', dtypes.uint8:'w', dtypes.uint32: 'w', dtypes.uint64: 'x'}
alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", BinaryOps.DIV: "div", BinaryOps.MAX: "max",
BinaryOps.MOD: "", BinaryOps.CMPLT: "subs",
@ -58,12 +58,12 @@ def specialize_to_arm64(fn_nm, asm):
if len(available_regs) == 0:
# ARM needs the stack 16-byte aligned
var_size += 16
available_regs.append('s0' if dtypes.is_float(out[1]) else 'x11')
available_regs.append('s0' if dtypes.is_float(out[1]) else 'x12')
mem_vars[v.nm] = var_size
rtor[v.nm] = available_regs.pop()
temp_floats = ['s0', 's1', 's2']
temp_ints = ['x11', 'x12', 'x13']
temp_ints = ['x12', 'x13', 'x16']
for i, (uop, out, vin, arg) in enumerate(asm):
# Clear regs out of interval
for var, reg in list(rtor.items()):
@ -83,7 +83,7 @@ def specialize_to_arm64(fn_nm, asm):
if arg.startswith('data'):
# data 8 to n into the stack
if int(arg[4:]) >= 8:
ins.append(f"ldr x15, [x19, #{(int(arg[4:]) - 8) * 8}]")
ins.append(f"ldr x15, [x17, #{(int(arg[4:]) - 8) * 8}]")
ins.append(f"mov {rtor[out.nm]}, x15")
else:
ins.append(f"mov {rtor[out.nm]}, #0")
@ -161,7 +161,7 @@ def specialize_to_arm64(fn_nm, asm):
if out is not None and out.nm in mem_vars:
ins.append(f"mov x15, {mem_vars[out.nm]}")
ins.append(f"str {rtor[out.nm]}, [sp, x15]")
return "\n".join([f"//varsize {var_size}",".arch armv8-a",".text", f".global {get_name(fn_nm)}",".p2align 2", f"{get_name(fn_nm)}:", "mov x19, sp"] + [f"sub sp, sp, #{offset}" for offset in compute_offsets(var_size)]+ ins + [f"add sp, sp, #{offset}" for offset in compute_offsets(var_size)] +["ret", "\n"])
return "\n".join([f"//varsize {var_size}",".arch armv8-a",".text", f".global {get_name(fn_nm)}",".p2align 2", f"{get_name(fn_nm)}:", "mov x17, sp"] + [f"sub sp, sp, #{offset}" for offset in compute_offsets(var_size)]+ ins + [f"add sp, sp, #{offset}" for offset in compute_offsets(var_size)] +["ret", "\n"])
def uops_to_arm64_asm(fn_nm:str, uops:List[UOp]) -> Tuple[str, List[int], List[int], bool]:
lang = ARM64Language()