diff --git a/test/test_dtype.py b/test/test_dtype.py index dc09a1966..65fde81ab 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -1,128 +1,94 @@ import unittest import numpy as np -from tinygrad.helpers import getenv +from tinygrad.helpers import getenv, DType, DEBUG from tinygrad.lazy import Device from tinygrad.tensor import Tensor, dtypes + +def _test_to_np(a:Tensor, np_dtype, target): + print(a) + na = a.numpy() + print(na, na.dtype, a.lazydata.realized) + assert na.dtype == np_dtype + np.testing.assert_allclose(na, target) + +def _test_op(fxn, target_dtype:DType, target): + c = fxn() + if DEBUG >= 2: print(c.numpy()) + assert c.dtype == target_dtype + np.testing.assert_allclose(c.numpy(), target) + +def _test_cast(a:Tensor, target_dtype:DType, target): _test_op(lambda: a.cast(target_dtype), target_dtype, target) +def _test_add(a:Tensor, b:Tensor, target_dtype:DType, target): _test_op(lambda: a+b, target_dtype, target) +def _test_mul(a:Tensor, b:Tensor, target_dtype:DType, target): _test_op(lambda: a*b, target_dtype, target) +def _test_matmul(a:Tensor, b:Tensor, target_dtype:DType, target): _test_op(lambda: a@b, target_dtype, target) +def _test_add_upcast(a:Tensor, b:Tensor, target_dtype:DType, target): _test_op(lambda: a+b, target_dtype, target) +def _test_mul_upcast(a:Tensor, b:Tensor, target_dtype:DType, target): _test_op(lambda: a*b, target_dtype, target) +def _test_matmul_upcast(a:Tensor, b:Tensor, target_dtype:DType, target): _test_op(lambda: a@b, target_dtype, target) + # for GPU, cl_khr_fp16 isn't supported (except now we don't need it!) # for LLVM, it segfaults because it can't link to the casting function @unittest.skipIf(getenv("CI", "") != "" and Device.DEFAULT in ["LLVM"], "float16 broken in some CI backends") -class TestDtype(unittest.TestCase): - def _test_to_np(self, a, np_dtype, target): - print(a) - na = a.numpy() - print(na, na.dtype, a.lazydata.realized) - assert na.dtype == np_dtype - np.testing.assert_allclose(na, target) - - def test_half_to_np(self): self._test_to_np(Tensor([1,2,3,4], dtype=dtypes.float16), np.float16, [1,2,3,4]) - def test_int8_to_np(self): self._test_to_np(Tensor([1,2,3,4], dtype=dtypes.int8), np.int8, [1,2,3,4]) - def test_uint8_to_np(self): self._test_to_np(Tensor([1,2,3,4], dtype=dtypes.uint8), np.uint8, [1,2,3,4]) - def test_int64_to_np(self): self._test_to_np(Tensor([1,2,3,4], dtype=dtypes.int64), np.int64, [1,2,3,4]) - - def _test_cast(self, a, target_dtype, target): - print(a) - b = a.cast(target_dtype) - print(b.numpy()) - assert b.dtype == target_dtype - np.testing.assert_allclose(b.numpy(), target) - - def test_float_to_half(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.float16, [1,2,3,4]) - def test_float_to_int8(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.int8, [1,2,3,4]) - def test_float_to_uint8(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.uint8, [1,2,3,4]) - def test_float_to_int64(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.int64, [1,2,3,4]) - - def test_half_to_float(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.float32, [1,2,3,4]) - def test_half_to_int8(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.int8, [1,2,3,4]) - def test_half_to_uint8(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.uint8, [1,2,3,4]) - def test_half_to_int64(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.int64, [1,2,3,4]) - - def test_int8_to_float(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.int8), dtypes.float32, [1,2,3,4]) - def test_int8_to_half(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.int8), dtypes.float16, [1,2,3,4]) - def test_int8_to_uint8(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.int8), dtypes.uint8, [1,2,3,4]) - def test_int8_to_int64(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.int8), dtypes.int64, [1,2,3,4]) - - def test_uint8_to_float(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.uint8), dtypes.float32, [1,2,3,4]) - def test_uint8_to_half(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.uint8), dtypes.float16, [1,2,3,4]) - def test_uint8_to_int8(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.uint8), dtypes.int8, [1,2,3,4]) - def test_uint8_to_int64(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.uint8), dtypes.int64, [1,2,3,4]) - - def _test_add(self, a, b, target_dtype, target): - c = a+b - print(c.numpy()) - assert c.dtype == target_dtype - np.testing.assert_allclose(c.numpy(), target) - - def test_half_add(self): self._test_add(Tensor([1,2,3,4], dtype=dtypes.float16), Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.float16, [2,4,6,8]) - def test_int8_add(self): self._test_add(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.int8), dtypes.int8, [2,4,6,8]) - def test_int64_add(self): self._test_add(Tensor([1,2,3,4], dtype=dtypes.int64),Tensor([1,2,3,4], dtype=dtypes.int64), dtypes.int64, [2,4,6,8]) +class TestHalfDtype(unittest.TestCase): + def test_half_to_np(self): _test_to_np(Tensor([1,2,3,4], dtype=dtypes.float16), np.float16, [1,2,3,4]) - def _test_mul(self, a, b, target_dtype, target): - c = a*b - print(c.numpy()) - assert c.dtype == target_dtype - np.testing.assert_allclose(c.numpy(), target) + def test_half_to_float(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.float32, [1,2,3,4]) + def test_half_to_int8(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.int8, [1,2,3,4]) + def test_half_to_uint8(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.uint8, [1,2,3,4]) + def test_half_to_int64(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.int64, [1,2,3,4]) - def test_half_mul(self): self._test_mul(Tensor([1,2,3,4], dtype=dtypes.float16), Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.float16, [1,4,9,16]) - def test_int8_mul(self): self._test_mul(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.int8), dtypes.int8, [1,4,9,16]) - def test_int64_mul(self): self._test_mul(Tensor([1,2,3,4], dtype=dtypes.int64), Tensor([1,2,3,4], dtype=dtypes.int64), dtypes.int64, [1,4,9,16]) + def test_float_to_half(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.float16, [1,2,3,4]) + def test_int8_to_half(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.int8), dtypes.float16, [1,2,3,4]) + def test_uint8_to_half(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.uint8), dtypes.float16, [1,2,3,4]) + + def test_half_add(self): _test_add(Tensor([1,2,3,4], dtype=dtypes.float16), Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.float16, [2,4,6,8]) + def test_half_mul(self): _test_mul(Tensor([1,2,3,4], dtype=dtypes.float16), Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.float16, [1,4,9,16]) + def test_half_matmul(self): _test_matmul(Tensor([[1,2],[3,4]], dtype=dtypes.float16), Tensor.eye(2, dtype=dtypes.float16), dtypes.float16, [[1,2],[3,4]]) + + def test_half_add_upcast_float(self): _test_add_upcast(Tensor([1,2,3,4], dtype=dtypes.float16), Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.float32, [2,4,6,8]) + def test_int8_add_upcast_half(self): _test_add_upcast(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.float16, [2,4,6,8]) + def test_int8_mul_upcast_half(self): _test_mul_upcast(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.float16, [1,4,9,16]) + def test_half_mul_upcast_float(self): _test_mul_upcast(Tensor([1,2,3,4], dtype=dtypes.float16), Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.float32, [1,4,9,16]) + def test_half_matmul_upcast_float(self): _test_matmul_upcast(Tensor([[1,2],[3,4]], dtype=dtypes.float16), Tensor.eye(2, dtype=dtypes.float32), dtypes.float32, [[1,2],[3,4]]) + def test_int8_matmul_upcast_half(self): _test_matmul_upcast(Tensor([[1,2],[3,4]], dtype=dtypes.int8), Tensor.eye(2, dtype=dtypes.float16), dtypes.float16, [[1,2],[3,4]]) - def _test_matmul(self, a, b, target_dtype, target): - c = a@b - print(c.numpy()) - assert c.dtype == target_dtype - np.testing.assert_allclose(c.numpy(), target) +class TestInt8Dtype(unittest.TestCase): + def test_int8_to_np(self): _test_to_np(Tensor([1,2,3,4], dtype=dtypes.int8), np.int8, [1,2,3,4]) + def test_uint8_to_np(self): _test_to_np(Tensor([1,2,3,4], dtype=dtypes.uint8), np.uint8, [1,2,3,4]) + def test_int64_to_np(self): _test_to_np(Tensor([1,2,3,4], dtype=dtypes.int64), np.int64, [1,2,3,4]) - def test_half_matmul(self): self._test_matmul(Tensor([[1,2],[3,4]], dtype=dtypes.float16), Tensor.eye(2, dtype=dtypes.float16), dtypes.float16, [[1,2],[3,4]]) - def test_int8_matmul(self): self._test_matmul(Tensor([[1,2],[3,4]], dtype=dtypes.int8), Tensor.eye(2, dtype=dtypes.int8), dtypes.int8, [[1,2],[3,4]]) - def test_int64_matmul(self): self._test_matmul(Tensor([[1,2],[3,4]], dtype=dtypes.int64), Tensor.eye(2, dtype=dtypes.int64), dtypes.int64, [[1,2],[3,4]]) + def test_float_to_int8(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.int8, [1,2,3,4]) + def test_float_to_uint8(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.uint8, [1,2,3,4]) + def test_float_to_int64(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.int64, [1,2,3,4]) - def _test_add_upcast(self, a, b, target_dtype, target): - c = a+b - print(c.numpy()) - assert c.dtype == target_dtype - np.testing.assert_allclose(c.numpy(), target) + def test_int8_to_float(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.int8), dtypes.float32, [1,2,3,4]) + def test_int8_to_uint8(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.int8), dtypes.uint8, [1,2,3,4]) + def test_int8_to_int64(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.int8), dtypes.int64, [1,2,3,4]) - def test_half_add_upcast_float(self): self._test_add_upcast(Tensor([1,2,3,4], dtype=dtypes.float16), Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.float32, [2,4,6,8]) - def test_int8_add_upcast_float(self): self._test_add_upcast(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.float32, [2,4,6,8]) - def test_int8_add_upcast_half(self): self._test_add_upcast(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.float16, [2,4,6,8]) - def test_int8_add_upcast_int64(self): self._test_add_upcast(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.int64), dtypes.int64, [2,4,6,8]) + def test_uint8_to_float(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.uint8), dtypes.float32, [1,2,3,4]) + def test_uint8_to_int8(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.uint8), dtypes.int8, [1,2,3,4]) + def test_uint8_to_int64(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.uint8), dtypes.int64, [1,2,3,4]) - def _test_mul_upcast(self, a, b, target_dtype, target): - c = a*b - print(c.numpy()) - assert c.dtype == target_dtype - np.testing.assert_allclose(c.numpy(), target) + def test_int8_add(self): _test_add(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.int8), dtypes.int8, [2,4,6,8]) + def test_int64_add(self): _test_add(Tensor([1,2,3,4], dtype=dtypes.int64),Tensor([1,2,3,4], dtype=dtypes.int64), dtypes.int64, [2,4,6,8]) - def test_half_mul_upcast_float(self): self._test_mul_upcast(Tensor([1,2,3,4], dtype=dtypes.float16), Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.float32, [1,4,9,16]) - def test_int8_mul_upcast_float(self): self._test_mul_upcast(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.float32, [1,4,9,16]) - def test_int8_mul_upcast_half(self): self._test_mul_upcast(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.float16, [1,4,9,16]) - def test_int8_mul_upcast_int64(self): self._test_mul_upcast(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.int64), dtypes.int64, [1,4,9,16]) + def test_int8_mul(self): _test_mul(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.int8), dtypes.int8, [1,4,9,16]) + def test_int64_mul(self): _test_mul(Tensor([1,2,3,4], dtype=dtypes.int64), Tensor([1,2,3,4], dtype=dtypes.int64), dtypes.int64, [1,4,9,16]) - def _test_matmul_upcast(self, a, b, target_dtype, target): - c = a@b - print(c.numpy()) - assert c.dtype == target_dtype - np.testing.assert_allclose(c.numpy(), target) + def test_int8_matmul(self): _test_matmul(Tensor([[1,2],[3,4]], dtype=dtypes.int8), Tensor.eye(2, dtype=dtypes.int8), dtypes.int8, [[1,2],[3,4]]) + def test_int64_matmul(self): _test_matmul(Tensor([[1,2],[3,4]], dtype=dtypes.int64), Tensor.eye(2, dtype=dtypes.int64), dtypes.int64, [[1,2],[3,4]]) - def test_half_matmul_upcast_float(self): self._test_matmul_upcast(Tensor([[1,2],[3,4]], dtype=dtypes.float16), Tensor.eye(2, dtype=dtypes.float32), dtypes.float32, [[1,2],[3,4]]) - def test_int8_matmul_upcast_float(self): self._test_matmul_upcast(Tensor([[1,2],[3,4]], dtype=dtypes.int8), Tensor.eye(2, dtype=dtypes.float32), dtypes.float32, [[1,2],[3,4]]) - def test_int8_matmul_upcast_half(self): self._test_matmul_upcast(Tensor([[1,2],[3,4]], dtype=dtypes.int8), Tensor.eye(2, dtype=dtypes.float16), dtypes.float16, [[1,2],[3,4]]) - def test_int8_matmul_upcast_int64(self): self._test_matmul_upcast(Tensor([[1,2],[3,4]], dtype=dtypes.int8), Tensor.eye(2, dtype=dtypes.int64), dtypes.int64, [[1,2],[3,4]]) + def test_int8_add_upcast_float(self): _test_add_upcast(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.float32, [2,4,6,8]) + def test_int8_mul_upcast_float(self): _test_mul_upcast(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.float32, [1,4,9,16]) + def test_int8_matmul_upcast_float(self): _test_matmul_upcast(Tensor([[1,2],[3,4]], dtype=dtypes.int8), Tensor.eye(2, dtype=dtypes.float32), dtypes.float32, [[1,2],[3,4]]) - def test_int8_to_uint8_negative(self): - a = Tensor([-1, -2, -3, -4], dtype=dtypes.int8) - print(a) - b = a.cast(dtypes.uint8) - print(b.numpy()) - np.testing.assert_allclose(b.numpy(), [255, 254, 253, 252]) + def test_int8_add_upcast_int64(self): _test_add_upcast(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.int64), dtypes.int64, [2,4,6,8]) + def test_int8_mul_upcast_int64(self): _test_mul_upcast(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.int64), dtypes.int64, [1,4,9,16]) + def test_int8_matmul_upcast_int64(self): _test_matmul_upcast(Tensor([[1,2],[3,4]], dtype=dtypes.int8), Tensor.eye(2, dtype=dtypes.int64), dtypes.int64, [[1,2],[3,4]]) - def test_uint8_to_int8_overflow(self): - a = Tensor([255, 254, 253, 252], dtype=dtypes.uint8) - print(a) - b = a.cast(dtypes.int8) - print(b.numpy()) - np.testing.assert_allclose(b.numpy(), [-1, -2, -3, -4]) + def test_int8_to_uint8_negative(self): _test_op(lambda: Tensor([-1, -2, -3, -4], dtype=dtypes.int8).cast(dtypes.uint8), dtypes.uint8, [255, 254, 253, 252]) + + def test_uint8_to_int8_overflow(self): _test_op(lambda: Tensor([255, 254, 253, 252], dtype=dtypes.uint8).cast(dtypes.int8), dtypes.int8, [-1, -2, -3, -4]) if __name__ == '__main__': unittest.main() diff --git a/tinygrad/codegen/llvmir.py b/tinygrad/codegen/llvmir.py index b2372fd02..bb028b220 100644 --- a/tinygrad/codegen/llvmir.py +++ b/tinygrad/codegen/llvmir.py @@ -38,7 +38,7 @@ def uops_to_llvm_ir(uops:List[UOp], bufs:List[LazyBuffer]) -> str: module = ir.Module(name=__file__) # create llvm function - func_dtypes = [{dtypes.float16:ir.HalfType(), dtypes.float32:ir.FloatType(), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1)}[buf.dtype] for buf in bufs] + func_dtypes = [{dtypes.float16:ir.HalfType(), dtypes.float32:ir.FloatType(), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64)}[buf.dtype] for buf in bufs] func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() for x in func_dtypes]), name='exec') # force llvmlite to allow us to add function attribute then add the attribute @@ -88,13 +88,21 @@ def uops_to_llvm_ir(uops:List[UOp], bufs:List[LazyBuffer]) -> str: val = bb[-1].select(valid, bb[-1].load(bb[-1].gep(func.args[args.i], [aug_idx], inbounds=True)), ir.Constant(func_dtypes[args[0]], 0)) else: val = bb[-1].load(bb[-1].gep(func.args[args.i], [idx], inbounds=True)) - if func_dtypes[args.i] != ir.FloatType(): val = bb[-1].fpext(val, ir.FloatType()) + if func_dtypes[args.i] != ir.FloatType(): + if dtypes.is_int(bufs[args.i].dtype): + val = bb[-1].uitofp(val, ir.FloatType()) if dtypes.is_unsigned(bufs[args.i].dtype) else bb[-1].sitofp(val, ir.FloatType()) + else: + val = bb[-1].fpext(val, ir.FloatType()) lvars[newvar] = val if uop == UOps.STORE: assert args.valid.min == 1, "store must be valid" idx = args.idx.render(render_llvm, bb[-1]) element = lvars[vin[0]] - if func_dtypes[0] != ir.FloatType(): element = bb[-1].fptrunc(element, func_dtypes[0]) + if func_dtypes[0] != ir.FloatType(): + if dtypes.is_int(bufs[args.i].dtype): + element = bb[-1].fptoui(element, func_dtypes[0]) if dtypes.is_unsigned(bufs[args.i].dtype) else bb[-1].fptosi(element, func_dtypes[0]) + else: + element = bb[-1].fptrunc(element, func_dtypes[0]) bb[-1].store(element, bb[-1].gep(func.args[args.i], [idx], inbounds=True)) if uop == UOps.ALU: lvars[newvar] = code_for_op[args](bb[-1], *[lvars[x] for x in vin]) diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index ce1145456..dafd4081c 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -75,7 +75,9 @@ class dtypes: int64: Final[DType] = DType(2, 8, "int64", np.int64) uint8: Final[DType] = DType(0, 1, "uchar", np.uint8) @staticmethod - def is_int(x: DType): return x in (dtypes.int8, dtypes.uint8, dtypes.int32, dtypes.int64) + def is_int(x: DType): return x in (dtypes.int8, dtypes.uint8, dtypes.int32, dtypes.int64) + @staticmethod + def is_unsigned(x: DType): return x in (dtypes.uint8) @staticmethod def from_np(x) -> DType: return asdict(dtypes())[np.dtype(x).name]