1
0
Fork 0

for issue #1555, int64 and int8 in CI=1 ARM64=1 CLANG=1 (#1572)

* fixed for int8,int64, added dtype broadcasting test, passing all CI,ARM64,CLANG tests

* remove shifts
This commit is contained in:
corranr 2023-08-19 05:40:13 +01:00 committed by GitHub
parent ae39cf84ab
commit 68ebbd2954
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 2 additions and 2 deletions

View file

@ -48,6 +48,7 @@ def _test_ops(a_dtype:DType, b_dtype:DType, target_dtype:DType):
_assert_eq(Tensor([1,2,3,4], dtype=a_dtype)+Tensor([1,2,3,4], dtype=b_dtype), target_dtype, [2,4,6,8])
_assert_eq(Tensor([1,2,3,4], dtype=a_dtype)*Tensor([1,2,3,4], dtype=b_dtype), target_dtype, [1,4,9,16])
_assert_eq(Tensor([[1,2],[3,4]], dtype=a_dtype)@Tensor.eye(2, dtype=b_dtype), target_dtype, [[1,2],[3,4]])
_assert_eq(Tensor([1,1,1,1], dtype=a_dtype)+Tensor.ones((4,4), dtype=b_dtype), target_dtype, 2*Tensor.ones(4,4).numpy())
class TestBFloat16DType(unittest.TestCase):
def test_bf16_to_float(self):

View file

@ -139,12 +139,11 @@ def specialize_to_arm64(fn_nm, asm):
ins.append(f"ldr{'sb' if arg[2] is not None and arg[2] in (dtypes.int8, dtypes.uint8, dtypes.bool) else ''} {reg_in}, [x15]")
if arg[2] is not None: ins.append(f"{'fcvt' if arg[2] in [dtypes.half, dtypes.double] else 'scvtf'} {rtor[out.nm]}, {reg_in}")
elif uop == UOps.STORE:
shifts = {dtypes.int64: "#3", dtypes.half: "#1", dtypes.int8:"#2", dtypes.uint8: "#2", dtypes.bool: "#2"}
#NOTE: if need casting load var in s/h0 or x/w12 temp regs
reg_out = (type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[vin[1].nm])
if arg[2] is not None: ins.append(f"fcvt{'zs' if arg[2] not in [dtypes.half, dtypes.double] else '' } {reg_out}, {rtor[vin[1].nm]}")
ins.append(f"mov x15, #{arg[0]}")
ins.append(f"str {reg_out}, [{rtor[vin[0].nm]}, x15, lsl {shifts[arg[2]] if arg[2] is not None and arg[2] in shifts else '#0'}]")
ins.append(f"str {reg_out}, [{rtor[vin[0].nm]}, x15, lsl #0]")
elif uop == UOps.COND_BRANCH:
#TODO: this is a hack it shouldn't always be a cmp before a cond branch?
if prev_uop == UOps.LOAD: