* fixed for int8,int64, added dtype broadcasting test, passing all CI,ARM64,CLANG tests * remove shifts
This commit is contained in:
parent
ae39cf84ab
commit
68ebbd2954
|
@ -48,6 +48,7 @@ def _test_ops(a_dtype:DType, b_dtype:DType, target_dtype:DType):
|
|||
_assert_eq(Tensor([1,2,3,4], dtype=a_dtype)+Tensor([1,2,3,4], dtype=b_dtype), target_dtype, [2,4,6,8])
|
||||
_assert_eq(Tensor([1,2,3,4], dtype=a_dtype)*Tensor([1,2,3,4], dtype=b_dtype), target_dtype, [1,4,9,16])
|
||||
_assert_eq(Tensor([[1,2],[3,4]], dtype=a_dtype)@Tensor.eye(2, dtype=b_dtype), target_dtype, [[1,2],[3,4]])
|
||||
_assert_eq(Tensor([1,1,1,1], dtype=a_dtype)+Tensor.ones((4,4), dtype=b_dtype), target_dtype, 2*Tensor.ones(4,4).numpy())
|
||||
|
||||
class TestBFloat16DType(unittest.TestCase):
|
||||
def test_bf16_to_float(self):
|
||||
|
|
|
@ -139,12 +139,11 @@ def specialize_to_arm64(fn_nm, asm):
|
|||
ins.append(f"ldr{'sb' if arg[2] is not None and arg[2] in (dtypes.int8, dtypes.uint8, dtypes.bool) else ''} {reg_in}, [x15]")
|
||||
if arg[2] is not None: ins.append(f"{'fcvt' if arg[2] in [dtypes.half, dtypes.double] else 'scvtf'} {rtor[out.nm]}, {reg_in}")
|
||||
elif uop == UOps.STORE:
|
||||
shifts = {dtypes.int64: "#3", dtypes.half: "#1", dtypes.int8:"#2", dtypes.uint8: "#2", dtypes.bool: "#2"}
|
||||
#NOTE: if need casting load var in s/h0 or x/w12 temp regs
|
||||
reg_out = (type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[vin[1].nm])
|
||||
if arg[2] is not None: ins.append(f"fcvt{'zs' if arg[2] not in [dtypes.half, dtypes.double] else '' } {reg_out}, {rtor[vin[1].nm]}")
|
||||
ins.append(f"mov x15, #{arg[0]}")
|
||||
ins.append(f"str {reg_out}, [{rtor[vin[0].nm]}, x15, lsl {shifts[arg[2]] if arg[2] is not None and arg[2] in shifts else '#0'}]")
|
||||
ins.append(f"str {reg_out}, [{rtor[vin[0].nm]}, x15, lsl #0]")
|
||||
elif uop == UOps.COND_BRANCH:
|
||||
#TODO: this is a hack it shouldn't always be a cmp before a cond branch?
|
||||
if prev_uop == UOps.LOAD:
|
||||
|
|
Loading…
Reference in a new issue