
Non fp32 math (#2264)

* `global_load` and `global_store` using buffer dtype

* `UOps.PHI` in all dtypes

* `UOps.ALU` in all dtypes

* `UOps.CONST` & `UOps.DEFINE_ACC` in all dtypes

* -- end of implementation -- (a condensed sketch of the dtype rule follows below)
+ tiny lint changes
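
For a sense of what the bullets above amount to, here is a hypothetical, self-contained toy of the dtype rule (the real implementation is the `uop` changes in the linearizer hunk further down): ALU uops upcast their inputs to the widest input dtype and inherit that dtype instead of hard-coding float32. The `DType`/`UOp`/`alu` names here are illustrative stand-ins, not tinygrad's API.

```
from dataclasses import dataclass
from typing import Optional

# toy stand-ins for tinygrad's DType/UOp, for illustration only
@dataclass(frozen=True)
class DType:
    priority: int
    name: str

BOOL, HALF, FLOAT32 = DType(0, "bool"), DType(1, "half"), DType(2, "float")

@dataclass(frozen=True)
class UOp:
    op: str
    dtype: DType
    vin: tuple = ()

def cast(x: UOp, dt: DType) -> UOp:
    return x if x.dtype == dt else UOp("CAST", dt, (x,))

def alu(op: str, *vin: UOp, dtype: Optional[DType] = None) -> UOp:
    # upcast every input to the widest input dtype...
    upcast = max((x.dtype for x in vin), key=lambda d: d.priority)
    if op == "WHERE":  # ...except WHERE's first input, which is always the bool condition
        vin = (vin[0],) + tuple(cast(x, upcast) for x in vin[1:])
    else:
        vin = tuple(cast(x, upcast) for x in vin)
    # comparison ops pass an explicit bool dtype; everything else inherits the upcast
    return UOp(op, dtype or upcast, vin)

assert alu("ADD", UOp("CONST", HALF), UOp("CONST", FLOAT32)).dtype is FLOAT32
assert alu("CMPLT", UOp("CONST", HALF), UOp("CONST", HALF), dtype=BOOL).dtype is BOOL
```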

* these tests require the fp16 extension

You can run them locally to confirm they're green (the GPT2 test is broken in master for Mac; see [this](https://discord.com/channels/1068976834382925865/1069001075828469790/1177993277958533261)):

`GPU=1 python3 -m pytest test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_dequantizelinear_e4m3fn_float16_cpu test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_max_float16_cpu test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_min_float16_cpu test/models/test_real_world.py::TestRealWorld::test_llama test/models/test_real_world.py::TestRealWorld::test_gpt2 test/models/test_whisper.py test/test_specific_conv.py::TestSpecific::test_big_vec_mul`

skip the new test_linearizer_failures on CI GPU because of the missing fp16 extension

This passes on a real GPU since the extension is available:
`GPU=1 python3 -m pytest test/test_linearizer_failures.py::TestLinearizerFailures::test_failure_8`

see CI logs [here](https://github.com/tinygrad/tinygrad/actions/runs/6996590597/job/19032641427#step:14:644)

* these tests fail in CI due to segfaults and CPU crashes

To confirm they're green locally, you can run the following commands:

1. For the tests skipped in test_ops.py (note: CLANG is very slow):

`for var in GPU CUDA CLANG; do export $var=1; for test in test/test_ops.py::TestOps::test_slice_fancy_indexing_no_dim_collapse test/test_ops.py::TestOps::test_slice_fancy_indexing_dim_collapse_int test/test_ops.py::TestOps::test_slice_fancy_indexing_dim_inject_none test/test_ops.py::TestOps::test_slice_fancy_indexing_dim_inject_and_collapse; do python3 -m pytest $test; done; unset $var; done`

2. For the ONNX tests skipped in CLANG:

```
CLANG=1 python3 -m pytest test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_ai_onnx_ml_array_feature_extractor_cpu \
test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_gather_elements_0_cpu \
test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1_expanded_cpu \
test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_sce_mean_weight_ii_3d_cpu \
test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_gather_elements_1_cpu \
test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_sce_NCd1_mean_weight_negative_ii_cpu \
test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1_weight_expanded_cpu \
test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2d3_none_no_weight_negative_ii_expanded_cpu \
test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1_ii_expanded_cpu \
test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_sce_mean_weight_ii_4d_cpu \
test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_sce_mean_weight_ii_3d_log_prob_cpu \
test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_gather_elements_negative_indices_cpu \
test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_sce_NCd1d2d3d4d5_mean_weight_log_prob_cpu \
test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_sce_NCd1_mean_weight_negative_ii_log_prob_cpu \
test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2_no_weight_reduction_mean_ii_expanded_cpu \
test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_sce_NCd1d2d3d4d5_mean_weight_cpu \
test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2d3d4d5_mean_weight_expanded_cpu \
test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1_mean_weight_negative_ii_expanded_cpu \
test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_sce_mean_weight_ii_4d_log_prob_cpu \
test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2_with_weight_reduction_mean_expanded_cpu \
test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1_weight_ii_expanded_cpu \
test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2_with_weight_reduction_sum_ii_expanded_cpu \
test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2_with_weight_reduction_sum_expanded_cpu \
test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2_expanded_cpu \
test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2_reduction_sum_expanded_cpu \
test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2d3d4d5_none_no_weight_expanded_cpu \
test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2d3_sum_weight_high_ii_expanded_cpu \
test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2_reduction_mean_expanded_cpu \
test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2_with_weight_expanded_cpu
```

3. The LLVM test I skipped here is already [skipped in master for all backends](https://github.com/tinygrad/tinygrad/blob/master/test/external/external_test_onnx_backend.py#L186); I just made it more specific:

`LLVM=1 python3 -m pytest test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_dequantizelinear_e4m3fn_float16_cpu`

* Revert "these tests fail in CI due to segfaults and CPU crashes"

This reverts commit 15db570143.

* merge with cleanup-vectorized-hip-renders

* barely working HIP P1, ALU ops need a refactor?

* manage the fact that in HIP, [half2 is actually an unsigned int vec](f921880387/hip/include/hip/amd_detail/amd_hip_fp16.h (L59)) while half is a totally different __half that [wraps an unsigned int element](f921880387/hip/include/hip/amd_detail/amd_hip_fp16.h (L50)) which can't be accessed directly [because it's private](f921880387/hip/include/hip/amd_detail/amd_hip_fp16.h (L86)). So if you just write:

```
half2 val0 = /* ... */;
half val1 = /* ... */;
```
then you can't do:
```
val0.x + val1 // error: use of overloaded operator '+' is ambiguous (with operand types 'unsigned short' and 'half' (aka '__half'))
```
The workaround is the explicit half/unsigned short operator overloads added to the HIP prelude later in this diff.
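
A small numpy analogy (not HIP code) for why the ambiguity exists: indexing the vector type hands you the raw unsigned 16-bit storage rather than a half, so you end up mixing uint16 with float16:

```
import numpy as np

val0 = np.array([1.5, 2.5], dtype=np.float16).view(np.uint16)  # storage view, like HIP's half2
val1 = np.float16(0.5)
print(val0[0])                          # 15872 -- the bit pattern of 1.5, not the value
print(val0.view(np.float16)[0] + val1)  # 2.0 once reinterpreted back to half
```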

* update the sign definition to avoid division by zero in all dtypes
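
For intuition, a minimal numpy sketch (illustrative, not the tinygrad code path) of what the old `x / (|x| + 1e-10)` does outside fp32: the epsilon underflows to zero in fp16 (and truncates to zero for integer dtypes), so the division hits 0/0 at x == 0. Casting to float first, as the tensor.py hunk at the end of this diff does, keeps the epsilon alive:

```
import numpy as np

x = np.array([-2.0, 0.0, 3.0], dtype=np.float16)
eps = np.float16(1e-10)                # underflows to 0.0 in fp16
with np.errstate(invalid="ignore"):
    naive = x / (np.abs(x) + eps)      # 0/0 at x == 0 -> nan
x32 = x.astype(np.float32)
fixed = (x32 / (np.abs(x32) + 1e-12)).astype(np.float16)
print(naive)  # [-1. nan  1.]
print(fixed)  # [-1.  0.  1.]
```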

* diff cleanup p1: why were these in the diff anyway?

* less hacky HIP, enable CIFAR fp16 benchmark, test ops for HIP in CI!

add ALU ops overloads for HIP

this will make HIP max work

handle mod

Revert "handle mod"

This reverts commit 370fd4b3fbe99b6ae8cc293d005b106628205933.

update max to use hmax

add HIP GEP render logic

enable CIFAR fp16 benchmark

test ops for HIP

back to store as float because this only works for float4 grouping right now

test_ops for hip!!

always sign

* back to the sign we had before, because we can't do a backward pass on a Less node

* remove old hacks

compiling test_ops for HIP in CI takes ~9 mins, so we're not doing it for now

new HIP ALUs

* reduce accs done right
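
A standalone sketch of what "done right" means here, mirroring the `get_reduce_acc` helper this commit adds to the linearizer (string dtypes are an illustrative simplification): the accumulator's initial value now depends on the buffer dtype instead of always being the float identity.

```
import math

def reduce_acc_init(op: str, dtype: str):
    # identity element for the reduction, chosen per dtype family
    if op == "SUM":
        return 0.0 if dtype in ("half", "float", "double") else 0
    if op == "MAX":
        if dtype in ("half", "float", "double"): return -math.inf
        if dtype.startswith(("int", "uint")): return -(2**31)
        return False  # bool: the MAX identity is False
    raise ValueError(f"unknown reduce op {op}")

assert reduce_acc_init("SUM", "half") == 0.0
assert reduce_acc_init("MAX", "int32") == -(2**31)
```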

* refactor to function

* no device hacks

hacks p2

the other way

* LLVM ALU ops

half, float and double are all float

update max
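
As one concrete case, a minimal llvmlite sketch (assuming llvmlite is installed; this mirrors the new `code_for_op` entry in the diff below) of the bool handling: NEG on an i1 becomes logical not, an xor with the constant 1, instead of integer negation:

```
from llvmlite import ir

mod = ir.Module(name="neg_demo")
fn = ir.Function(mod, ir.FunctionType(ir.IntType(1), [ir.IntType(1)]), name="neg_bool")
builder = ir.IRBuilder(fn.append_basic_block("entry"))
x, = fn.args
# NEG on a bool is logical not: xor with 1, as in the updated code_for_op
builder.ret(builder.xor(x, ir.Constant(ir.IntType(1), 1)))
print(mod)  # emits the i1 xor in the generated IR
```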

* update test_uops: CMPLT is always a bool in the real linearizer, and assertAlmostEqual is the wrong check when ret is bool

* clean up wrong LLVM code

* dummy change for the CUDA install glitch

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
pull/2593/head
qazal 2023-12-03 16:45:49 -05:00 committed by GitHub
parent 1ac958a058
commit 4380ccb169
13 changed files with 74 additions and 46 deletions


@ -128,8 +128,8 @@ jobs:
run: STEPS=10 python3 examples/hlb_cifar10.py | tee train_cifar.txt
- name: Run 10 CIFAR training steps w winograd
run: WINO=1 STEPS=10 python3 examples/hlb_cifar10.py | tee train_cifar_wino.txt
#- name: Run 10 CIFAR training steps w WINO/HALF/HIP
# run: HALF=1 HIP=1 WINO=1 STEPS=10 python3 examples/hlb_cifar10.py | tee train_cifar_wino_half_hip.txt
- name: Run 10 CIFAR training steps w WINO/HALF/HIP
run: HALF=1 HIP=1 WINO=1 STEPS=10 python3 examples/hlb_cifar10.py | tee train_cifar_wino_half_hip.txt
- uses: actions/upload-artifact@v3
with:
name: Speed (AMD)


@ -29,15 +29,15 @@ def get_max(var):
def remove_single_scalar_curly_braces(ptx_code):
return '\n'.join([re.sub(r'\{\s*(%\w+)\s*\}', r'\1', line) for line in ptx_code.split('\n')])
def render_const(args):
return (('-' if args<0 else '') + 'tl.where(1,float("inf"),0)') if math.isinf(args) else ('tl.where(1,float("nan"),0)' if math.isnan(args) else str(args))
def render_const(args,dtype:DType):
return (('-' if args<0 else '') + 'tl.where(1,float("inf"),0)') if math.isinf(args) else ('tl.where(1,float("nan"),0)' if math.isnan(args) else f"{int(args)}" if dtypes.is_int(dtype) else str(args))
def render_cast(x:str, dtype:DType):
return f"{x}.to({triton_dtypes[dtype]})"
def define_scalar(local_size, dtype, args):
if len(local_size) > 0: return f"tl.full(({','.join([str(next_power_of_2(x)) for x in local_size])},),{render_const(args)}, dtype={triton_dtypes[dtype]})"
return render_const(args)
if len(local_size) > 0: return f"tl.full(({','.join([str(next_power_of_2(x)) for x in local_size])},),{render_const(args,dtype)}, dtype={triton_dtypes[dtype]})"
return render_const(args,dtype)
def uops_to_triton(function_name:str, uops:List[UOp]):
local_size: List[int] = []


@ -177,8 +177,8 @@ if Device.DEFAULT in ['GPU', 'METAL']:
backend_test.exclude('test_mish_expanded_cpu') # weird inaccuracy
backend_test.exclude('test_eyelike_with_dtype_cpu') # backend does not support dtype: Double
# Segfaults in CI
if Device.DEFAULT in ['LLVM', 'CUDA'] and CI:
# Segfaults in CI, GPU requires cl_khr_fp16
if Device.DEFAULT in ['LLVM', 'CUDA', 'GPU'] and CI:
backend_test.exclude('test_max_float16_cpu')
backend_test.exclude('test_min_float16_cpu')


@ -51,7 +51,7 @@ class TestRealWorld(unittest.TestCase):
helper_test("test_sd", lambda: (Tensor.randn(1, 4, 64, 64),Tensor.randn(1, 77, 768)), test, 18.0, 953)
@unittest.skipIf(Device.DEFAULT == "LLVM", "LLVM segmentation fault")
@unittest.skipIf(Device.DEFAULT == "LLVM" and CI, "too long on CI LLVM")
@unittest.skipIf(Device.DEFAULT in ["LLVM", "GPU"] and CI, "too long on CI LLVM, GPU requires cl_khr_fp16")
def test_llama(self):
Tensor.default_type = dtypes.float16
@ -63,7 +63,7 @@ class TestRealWorld(unittest.TestCase):
# TODO: test first token vs rest properly, also memory test is broken with CacheCollector
helper_test("test_llama", lambda: (Tensor([[1,2,3,4]]),), test, 0.22 if CI else 13.5, 181 if CI else 685, all_jitted=True)
@unittest.skipIf(Device.DEFAULT == "LLVM" and CI, "too long on CI LLVM")
@unittest.skipIf(Device.DEFAULT in ["LLVM", "GPU"] and CI, "too long on CI LLVM, GPU requires cl_khr_fp16")
def test_gpt2(self):
Tensor.default_type = dtypes.float16
@ -102,4 +102,4 @@ class TestRealWorld(unittest.TestCase):
#Device.DEFAULT = old_default
if __name__ == '__main__':
unittest.main()


@ -15,7 +15,7 @@ TRANSCRIPTION_2 = "a slightly longer audio file so that we can test batch transc
TEST_FILE_3_URL = 'https://homepage.ntu.edu.tw/~karchung/miniconversations/mc45.mp3'
TRANSCRIPTION_3 = "Just lie back and relax. Is the level of pressure about right? Yes, it's fine, and I'd like conditioner please. Sure. I'm going to start the second lathering now. Would you like some Q-tips? How'd you like it cut? I'd like my bangs and the back trimmed, and I'd like the rest thinned out a bit and layered. Where would you like the part? On the left, right about here. Here, have a look. What do you think? It's fine. Here's a thousand anti-dollars. It's 30-ant extra for the rants. Here's your change and receipt. Thank you, and please come again. So how do you like it? It could have been worse, but you'll notice that I didn't ask her for her card. Hmm, yeah. Maybe you can try that place over there next time."
@unittest.skipIf(CI and Device.DEFAULT in ["LLVM", "CLANG", "CPU"], "Not working on LLVM, slow on others")
@unittest.skipIf(CI and Device.DEFAULT in ["LLVM", "CLANG", "CPU", "GPU"], "Not working on LLVM, slow on others. GPU requires cl_khr_fp16")
class TestWhisper(unittest.TestCase):
@classmethod
def setUpClass(cls):


@ -22,7 +22,6 @@ class TestHIPCompilationRDNA(unittest.TestCase):
output = model(input)
output.numpy()
@unittest.expectedFailure
def test_compile_hip_speedyresnet_hf(self):
Tensor.default_type = dtypes.float16
@ -34,4 +33,4 @@ class TestHIPCompilationRDNA(unittest.TestCase):
output.numpy()
if __name__ == "__main__":
unittest.main()


@ -79,7 +79,7 @@ class TestLinearizerFailures(unittest.TestCase):
ast = helper_add_store(ast)
helper_test_lin(Linearizer(ast), opts, failed_platforms=["LLVM"])
@unittest.skipIf(Device.DEFAULT=="LLVM" and not OSX, "Segmentation fault on ubuntu")
@unittest.skipIf((Device.DEFAULT=="LLVM" and not OSX) or (Device.DEFAULT == "GPU" and CI), "Segmentation fault on ubuntu, GPU requires cl_khr_fp16")
def test_failure_8(self):
ast = LazyOp(op=UnaryOps.SQRT, src=(LazyOp(op=BinaryOps.DIV, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1.0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=True),)))), LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 4096), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 4096), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),))))), arg=None), LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 4096), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 4096), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),))))), arg=None)), arg=None),), arg=(1, 1, 1)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.000244140625, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=True),))))), arg=None), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1e-06, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=True),))))), arg=None)), arg=None),), arg=None)
opts = [Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.UNROLL, axis=0, amt=4)]


@ -20,7 +20,7 @@ class TestSpecific(unittest.TestCase):
w = Tensor.randn(2048, 512)
(x @ w).reshape(1, 128, 4).contiguous().realize()
@unittest.skipIf(Device.DEFAULT in ["LLVM", "WEBGPU", "CUDA"], "Broken on LLVM, WEBGPU and CUDA")
@unittest.skipIf(Device.DEFAULT in ["LLVM", "WEBGPU", "GPU", "CUDA"], "Broken on LLVM and webgpu, GPU requires cl_khr_fp16")
def test_big_vec_mul(self):
# from LLaMA
# 0 buffer<4096, dtypes.float> [View((1024, 1, 1, 4), (4, 0, 0, 1), 0, None)]
@ -52,4 +52,4 @@ class TestSpecific(unittest.TestCase):
x.conv2d(w, stride=2, padding=1).permute(0,2,3,1).reshape(18, 18*384//4, 4).contiguous().realize()
if __name__ == '__main__':
unittest.main()


@ -14,7 +14,7 @@ def _uops_to_prg(uops):
runtime_args=runtime_args).build(Device[Device.DEFAULT].compiler, Device[Device.DEFAULT].runtime)
def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], arg:Any=None) -> UOp:
uops.append(UOp(uop, dtype, tuple(vin), arg))
uops.append(UOp(uop, dtype if arg != BinaryOps.CMPLT else dtypes.bool, tuple(vin), arg))
return uops[-1]
def _test_single_value(vals, op, dtype):
@ -43,7 +43,7 @@ def _test_single_value_const(vals, op, dtype):
class TestUOps(unittest.TestCase):
def _equal(self, v1, v2):
if not (math.isnan(v1) and math.isnan(v2)): self.assertAlmostEqual(v1, v2, places=5)
if not (math.isnan(v1) and math.isnan(v2)): self.assertAlmostEqual(v1, v2, places=5) if v1.dtype != np.bool_ else self.assertEqual(v1, v2)
def _test_uop_fxn(self, bop, fxn, dt=dtypes.float32):
for f in [_test_single_value, _test_single_value_const]:
@ -78,7 +78,7 @@ class TestFloatUOps(TestUOps):
def test_mul(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: a*b)
def test_div(self): self._test_bop_fxn(BinaryOps.DIV, lambda a,b: a/b if b != 0 else a*float('inf'))
def test_max(self): self._test_bop_fxn(BinaryOps.MAX, lambda a,b: max(a,b))
def test_cmplt(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: float(a<b))
def test_cmplt(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: a<b)
# MOD isn't tested on floats
def test_mulacc(self): self._test_top_fxn(TernaryOps.MULACC, lambda a,b,c: (a*b)+c)


@ -46,6 +46,11 @@ class Linearizer(Kernel):
# NOTE: the consts have to be cached for deduping of downstream uops to work
def const(self, b:Union[int,float], dtype=dtypes.int32, insert_before=None) -> UOp: return self.uop(UOps.CONST, dtype, tuple(), b, insert_before=insert_before)
def cast(self, val: UOp, dtype) -> UOp: return self.uop(UOps.CAST, dtype, (val,)) if val.dtype != dtype else val
def get_reduce_acc(self, op, dtype:DType):
if op == ReduceOps.SUM: return 0.0 if dtypes.is_float(dtype) else 0
elif op == ReduceOps.MAX: return -math.inf if dtypes.is_float(dtype) else -2**31 if dtypes.is_int(dtype) else False
render_ops: Any = { Variable: lambda self, ops, ctx: ctx.loop_uops[self.expr], NumNode: lambda self, ops, ctx: ctx.const(self.b),
MulNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MUL),
@ -74,7 +79,8 @@ class Linearizer(Kernel):
(g_idx, g_valid), amt, dim = self.sts[i].expr_idxs(fake_idxs), 1, None
else:
g_idx, g_valid = self.sts[i].expr_idxs(fake_idxs)
localtype = dtypes.float32 if amt == 1 else dtypes.float.vec(amt)
localtype = buf.dtype if amt == 1 else buf.dtype.vec(amt)
if isinstance(buf.dtype, ImageDType): localtype = dtypes.float if amt == 1 else dtypes.float.vec(amt)
e_idxs, e_valids = g_idx.expand(expand_vars), g_valid.expand(expand_vars)
@ -237,7 +243,7 @@ class Linearizer(Kernel):
fake_reduce_idxs = [x*0 for x in reduce_idxs]
# define accumulator
acc = self.global_load(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)])
acc = self.global_load(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, self.get_reduce_acc(self.reduceop.op, self.bufs[0].dtype))
if self.tensor_core:
def calc_tc_idxs(local_size: int, aliases: List[List[int]]):
@ -343,7 +349,7 @@ class Linearizer(Kernel):
# NOTE: this structure is the same as the reduce op above
# define late accumulator
acc = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)])
acc = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, self.get_reduce_acc(self.reduceop.op, self.bufs[-1].dtype))
# late reduce loop
loop_ctx = render_loop(end_local_idxs)
@ -455,8 +461,14 @@ class Linearizer(Kernel):
self.applied_opts_cache = self.applied_opts[:]
return self
def uop(self, uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], arg:Any=None, cachable=True, insert_before=None, simplify=True) -> UOp:
def uop(self, uop:UOps, dtype:Optional[DType]=None, vin:Tuple[UOp, ...]=tuple(), arg:Any=None, cachable=True, insert_before=None, simplify=True) -> UOp:
key = (uop, dtype, vin, arg)
if uop == UOps.PHI and vin[1].dtype != dtype: vin = (vin[0], self.cast(vin[1], dtype)) + vin[1:]
if uop == UOps.ALU: # upcast vins to the same dtype
upcast_dtype = dtypes.float if arg == TernaryOps.MULACC else max(cast(DType, x.dtype) for x in vin) # MULACC is only supported in float
if arg == TernaryOps.WHERE: vin = (vin[0],) + tuple(self.cast(x, upcast_dtype) for x in vin[1:]) # the first arg is always bool
else: vin = tuple(self.cast(x, upcast_dtype) for x in vin)
dtype = dtype or upcast_dtype # some ops like BinaryOps.CMPLT return bool
if simplify:
if uop == UOps.PHI and len(vin) == 2: return vin[1] # a phi without loops is a noop
if uop == UOps.GEP and vin[0].uop == UOps.CONST: return self.const(vin[0].arg, dtype, insert_before)
@ -502,11 +514,11 @@ class Linearizer(Kernel):
ret: List[UOp] = []
input_acc = acc[:]
for val, off in zip(zip(*values), cast(List[int], offs)):
acc[off] = self.uop(UOps.ALU, dtypes.float32, val+(acc[off],), ops[x.op])
acc[off] = self.uop(UOps.ALU, vin=val+(acc[off],), arg=ops[x.op])
ret.append(acc[off])
for off in range(len(acc)):
if input_acc[off] != acc[off]:
acc[off] = self.uop(UOps.PHI, dtypes.float32, (input_acc[off], acc[off]) + tuple(loop_ctx))
acc[off] = self.uop(UOps.PHI, input_acc[off].dtype, (input_acc[off], acc[off]) + tuple(loop_ctx))
else:
ret = [self.uop(UOps.ALU, dtypes.float32, val, x.op) for val in zip(*values)]
ret = [self.uop(UOps.ALU, dtype=dtypes.bool if x.op == BinaryOps.CMPLT else None, vin=val, arg=x.op) for val in zip(*values)]
return ret


@ -47,18 +47,18 @@ class CStyleLanguage(NamedTuple):
return f"{self.float4.replace('float4', var_dtype.name)}({','.join(x)})"
# returns a str expression of the const with the given type
def render_const(self, x:Union[float,int], var_dtype) -> str:
def render_const(self, x:Union[float,int,bool], var_dtype) -> str:
if math.isnan(x): val = "NAN"
elif math.isinf(x): val = ("-" if x < 0 else "") + "INFINITY"
else: val = f"{float(x)}f" if dtypes.is_float(var_dtype) else f"{int(x)}"
return self.render_cast([val]*var_dtype.sz, var_dtype) if var_dtype.sz > 1 else val
else: val = f"{float(x)}f" if dtypes.is_float(var_dtype) else f"{int(x)}" if dtypes.is_int(var_dtype) else f"{bool(x)}".lower()
return self.render_cast([val]*var_dtype.sz, var_dtype) if var_dtype.sz > 1 or var_dtype not in [dtypes.float, dtypes.int, dtypes.bool] else val
# returns a str expression of the loaded value with the output type
def render_load(self, output_dtype, buf_name, buf_dtype, idx, local=False) -> str:
if isinstance(buf_dtype, ImageDType):
assert output_dtype == dtypes.float.vec(4), f"images must be float4, getting {output_dtype}"
return f"read_imagef({buf_name}, smp, {idx})"
if self.uses_vload and buf_dtype == dtypes.float16:
if self.uses_vload and buf_dtype.scalar() == dtypes.float16 and output_dtype.scalar() != dtypes.float16:
return f"vload_half{'' if output_dtype.sz == 1 else str(output_dtype.sz)}(0, {buf_name}+{idx})"
if output_dtype.sz > 1:
out_val = f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{buf_dtype.name}{output_dtype.sz}*)({buf_name}+{idx}))"
@ -95,7 +95,7 @@ class CStyleLanguage(NamedTuple):
if isinstance(buf_dtype, ImageDType):
assert var_dtype == dtypes.float.vec(4), "images must be float4"
return f"write_imagef({buf_name}, {idx}, {var_name});"
if self.uses_vload and buf_dtype == dtypes.float16 and var_dtype != dtypes.float16:
if self.uses_vload and buf_dtype.scalar() == dtypes.float16 and var_dtype.scalar() != dtypes.float16:
return f"vstore_half{'' if var_dtype.sz == 1 else str(var_dtype.sz)}({var_name}, 0, {buf_name}+{idx});"
if var_dtype.sz > 1:
return f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{buf_dtype.name}{var_dtype.sz}*)({buf_name}+{idx})) = ({buf_dtype.name}{var_dtype.sz}){var_name};"
@ -156,8 +156,6 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tu
# remove parens if ALU types are the same. TODO: can do more here
if vin[0].uop == UOps.ALU and vin[0].arg == args and args in {BinaryOps.ADD, BinaryOps.SUB, BinaryOps.MUL}:
val = lang.code_for_op[args](strip_parens(r[vin[0]]), *[r[x] for x in vin[1:]], dtype)
elif args == BinaryOps.MAX:
val = lang.code_for_op[args](*[lang.render_cast([r[x]], dtype) if x.dtype != dtype else r[x] for x in vin] + [dtype])
else:
val = lang.code_for_op[args](*[r[x] for x in vin] + [dtype])
assert child_count[u] != 0, f"childless ALU op found {u}"
@ -292,10 +290,31 @@ __device__ float4 vload_half4(size_t offset, const half *p) { return make_float4
__device__ void vstore_half(float data, size_t offset, half *p) { *(p + offset) = (half)data; }
__device__ void vstore_half2(float2 data, size_t offset, half *p) { *(p + offset*2) = (half)data.x; *(p + offset*2 + 1) = (half)data.y; }
__device__ void vstore_half4(float4 data, size_t offset, half *p) { *(p + offset*4) = (half)data.x; *(p + offset*4 + 1) = (half)data.y; *(p + offset*4 + 2) = (half)data.z; *(p + offset*4 + 3) = (half)data.w; }
__device__ half exp2(half x) { return hexp2(x); }
__device__ half log2(half x) { return hlog2(x); }
__device__ half sin(half x) { return hsin(x); }
__device__ half sqrt(half x) { return hsqrt(x); }
__device__ half hmax(half a, half b) { return __hgt(a, b) ? a : b; }
__device__ half operator%(const half &a, const half &b) { return __hsub(a, __hmul(b, __float2half(floorf(__half2float(a) / __half2float(b))))); }
__device__ bool operator!=(const half &a, const int &b) { return (float)a != b; }
// HACKS for ALU ops on half and result of half2 GEP
__device__ half operator+(const half &a, const unsigned short &b) { return __hadd(a, (half)(b)); }
__device__ half operator-(const half &a, const unsigned short &b) { return __hsub(a, (half)(b)); }
__device__ half operator*(const half &a, const unsigned short &b) { return __hmul(a, (half)(b)); }
__device__ half operator/(const half &a, const unsigned short &b) { return __hdiv(a, (half)(b)); }
__device__ bool operator<(const half &a, const unsigned short &b) { return __hlt(a, (half)(b)); }
// now the other way
__device__ half operator+(const unsigned short &a, const half &b) { return __hadd((half)(a), b); }
__device__ half operator-(const unsigned short &a, const half &b) { return __hsub((half)(a), b); }
__device__ half operator*(const unsigned short &a, const half &b) { return __hmul((half)(a), b); }
__device__ half operator/(const unsigned short &a, const half &b) { return __hdiv((half)(a), b); }
__device__ bool operator<(const unsigned short &a, const half &b) { return __hlt((half)(a), b); }
"""
gid = [f'blockIdx.{chr(120+i)}' for i in range(3)]
lid = [f'threadIdx.{chr(120+i)}' for i in range(3)]
xid = [f'(blockIdx.{chr(120+i)}*blockDim.{chr(120+i)}+threadIdx.{chr(120+i)})' for i in range(3)]
code_for_op = {**CStyleLanguage().code_for_op, BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})" if dtype != dtypes.half else f"hmax({a},{b})", TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}!=0?{b}:{c})" if dtype != dtypes.half else f"(half)({a}!=0?{b}:{c})"}
HIPRenderer = functools.partial(uops_to_cstyle, HIPLanguage())
# TODO: how much of this can be merged with above?
@ -338,9 +357,6 @@ class WGSLLanguage(CStyleLanguage):
if self.type_map[var_dtype]: return f"{self.type_map[var_dtype]}({x[0]})"
raise NotImplementedError(f"no cast for {var_dtype}")
def render_load(self, output_dtype, buf_name, buf_dtype, idx, local=False) -> str:
return f"f32({super().render_load(output_dtype, buf_name, buf_dtype, idx, local)})"
def render_store(self, buf_name:str, buf_dtype:DType, var_name:str, var_dtype:DType, idx, local=False) -> str:
return f"{buf_name}[{idx}] = {self.render_cast([var_name], buf_dtype) if var_dtype != buf_dtype else var_name};"
WGSLRenderer = functools.partial(uops_to_cstyle, WGSLLanguage())


@ -6,8 +6,9 @@ from tinygrad.ops import Op, UnaryOps, BinaryOps, TernaryOps
LLVM_FAST_MATH_FLAGS = ('nsz', 'arcp', 'contract', 'afn', 'reassoc') # All from fast math, but nnan and ninf
def is_bool(t:ir.Type): return isinstance(t, ir.IntType) and t.width == 1
code_for_op: Final[Dict[Op, Callable]] = {
UnaryOps.NEG: lambda builder,x: builder.neg(x) if isinstance(x.type, ir.IntType) else builder.fneg(x, flags=LLVM_FAST_MATH_FLAGS) if isinstance(x.type, ir.FloatType) else builder.not_(x),
UnaryOps.NEG: lambda builder,x: builder.xor(x, ir.Constant(ir.IntType(1), 1)) if is_bool(x.type) else builder.neg(x) if isinstance(x.type, ir.IntType) else builder.fneg(x, flags=LLVM_FAST_MATH_FLAGS),
UnaryOps.EXP2: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.exp2', [ir.FloatType()]), [x], fastmath=LLVM_FAST_MATH_FLAGS),
UnaryOps.LOG2: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.log2', [ir.FloatType()]), [x], fastmath=LLVM_FAST_MATH_FLAGS),
UnaryOps.SIN: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.sin', [ir.FloatType()]), [x], fastmath=LLVM_FAST_MATH_FLAGS),
@ -16,12 +17,12 @@ code_for_op: Final[Dict[Op, Callable]] = {
BinaryOps.SUB: lambda builder,x,y: builder.sub(x,y) if isinstance(x.type, ir.IntType) else builder.fsub(x,y, flags=LLVM_FAST_MATH_FLAGS),
BinaryOps.MUL: lambda builder,x,y: builder.mul(x,y) if isinstance(x.type, ir.IntType) else builder.fmul(x,y, flags=LLVM_FAST_MATH_FLAGS),
BinaryOps.DIV: lambda builder,x,y: builder.sdiv(x,y) if isinstance(x.type, ir.IntType) else builder.fdiv(x,y, flags=LLVM_FAST_MATH_FLAGS),
# TODO: this should be casted
BinaryOps.CMPLT: lambda builder,x,y: builder.zext(builder.icmp_signed("<", x, y),ir.IntType(32)) if isinstance(x.type, ir.IntType) else builder.uitofp(builder.fcmp_ordered("<", x, y, flags=LLVM_FAST_MATH_FLAGS), ir.FloatType()),
BinaryOps.MAX: lambda builder,x,y: builder.select(builder.fcmp_unordered(">", x, y, flags=LLVM_FAST_MATH_FLAGS), x, y, flags=LLVM_FAST_MATH_FLAGS) if isinstance(x.type, ir.FloatType) else builder.select(builder.icmp_signed(">", x, y), x, y),
BinaryOps.MOD: lambda builder,x,y: builder.srem(x,y) if isinstance(x.type, ir.IntType) else builder.frem(x,y) if isinstance(x.type, ir.FloatType) else builder.urem(x,y),
BinaryOps.CMPLT: lambda builder,x,y: builder.icmp_unsigned("<", x, y) if is_bool(x.type) else builder.icmp_signed("<", x, y) if isinstance(x.type, ir.IntType) else builder.fcmp_unordered("<", x, y, flags=LLVM_FAST_MATH_FLAGS),
BinaryOps.MAX: lambda builder,x,y: builder.select(builder.icmp_unsigned(">", x, y) if is_bool(x.type) else builder.icmp_signed(">", x, y) if isinstance(x.type, ir.IntType) else builder.fcmp_unordered(">", x, y, flags=LLVM_FAST_MATH_FLAGS), x, y),
BinaryOps.MOD: lambda builder,x,y: builder.urem(x,y) if is_bool(x.type) else builder.srem(x,y) if isinstance(x.type, ir.IntType) else builder.frem(x,y),
TernaryOps.MULACC: lambda builder,x,y,z: builder.fadd(builder.fmul(x,y, flags=LLVM_FAST_MATH_FLAGS), z, flags=LLVM_FAST_MATH_FLAGS),
TernaryOps.WHERE: lambda builder,x,y,z: builder.select(builder.fcmp_unordered("!=", x, ir.Constant(ir.FloatType(), 0), flags=LLVM_FAST_MATH_FLAGS) if isinstance(x.type, ir.FloatType) else builder.trunc(x, ir.IntType(1)), y, z),
TernaryOps.WHERE: lambda builder,x,y,z: builder.select(builder.trunc(x, ir.IntType(1)) if isinstance(x.type, ir.IntType) else builder.fcmp_unordered("!=", x, ir.Constant(ir.FloatType(), 0), flags=LLVM_FAST_MATH_FLAGS), y, z),
}
dtype_to_llvm_dtype = {dtypes.float64:ir.DoubleType(), dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64), dtypes.int32: ir.IntType(32), dtypes._arg_int32: ir.IntType(32), dtypes.int16:ir.IntType(16), dtypes.uint16:ir.IntType(16), dtypes.uint32:ir.IntType(32), dtypes.uint64:ir.IntType(64)}
@ -98,7 +99,7 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Dict]:
phis = []
for rp in reduce_phis:
incoming = lvars[rp]
lvars[rp] = bb[-1].phi(ir.FloatType())
lvars[rp] = bb[-1].phi(dtype_to_llvm_dtype[rp.dtype])
lvars[rp].add_incoming(incoming, bb[-2]._block)
phis.append((rp, lvars[rp]))
@ -146,7 +147,7 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Dict]:
with bb[-1].if_then(bb[-1].trunc(lvars[vin[3]], ir.IntType(1))): store_op()
else: store_op()
if uop == UOps.ALU:
lvars[u] = cast(bb, code_for_op[args](bb[-1], *[cast(bb, lvars[x], x.dtype, dtypes.float) for x in vin]), dtypes.float, dtype)
lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in vin])
if uop == UOps.CAST: lvars[u] = cast(bb, lvars[vin[0]], vin[0].dtype, dtype)
bb[-1].ret_void()


@ -634,7 +634,7 @@ class Tensor:
def square(self): return self*self
def clip(self, min_, max_): return self.maximum(min_).minimum(max_)
def abs(self): return self.relu() + (-self).relu()
def sign(self): return self / (self.abs() + 1e-10)
def sign(self): return ((self.float()) / (self.float().abs() + 1e-12)).cast(self.dtype)
def reciprocal(self): return 1.0/self
# ***** activation functions (unary) *****