From 2cc1d970c699a9757a921e9025f1eb5cdcd43c1e Mon Sep 17 00:00:00 2001
From: George Hotz
Date: Mon, 7 Nov 2022 21:12:08 -0800
Subject: [PATCH] updates from the chonker branch

---
 extra/gemm/.gitignore      |   2 +
 test/test_ops.py           |   5 --
 test/test_speed_v_torch.py | 180 +++++++++++++++++++++++--------------
 tinygrad/helpers.py        |   1 +
 tinygrad/tensor.py         |   5 +-
 5 files changed, 121 insertions(+), 72 deletions(-)
 create mode 100644 extra/gemm/.gitignore

diff --git a/extra/gemm/.gitignore b/extra/gemm/.gitignore
new file mode 100644
index 000000000..7bdea85fd
--- /dev/null
+++ b/extra/gemm/.gitignore
@@ -0,0 +1,2 @@
+*.s
+*.ll
diff --git a/test/test_ops.py b/test/test_ops.py
index fb25d5e9b..086cfb307 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -249,11 +249,6 @@ class TestOps(unittest.TestCase):
       lambda x,w: torch.nn.functional.conv2d(x,w).relu(),
       lambda x,w: Tensor.conv2d(x,w).relu(), atol=1e-4, grad_rtol=1e-5)
 
-  def test_simple_conv2d_forward(self):
-    helper_test_op([(2,10,9,9), (10,10,3,3)],
-      lambda x,w: torch.nn.functional.conv2d(x,w).relu(),
-      lambda x,w: Tensor.conv2d(x,w).relu(), atol=1e-4, grad_rtol=1e-5, forward_only=True)
-
   # expect reduce nodes == 3
   def test_simple_conv2d_nhwc(self):
     # weights (from tf): filter_height x filter_width x in_channels x out_channels
diff --git a/test/test_speed_v_torch.py b/test/test_speed_v_torch.py
index 73102b69e..9aba18fe7 100644
--- a/test/test_speed_v_torch.py
+++ b/test/test_speed_v_torch.py
@@ -6,86 +6,134 @@ import time
 import numpy as np
 from tinygrad.tensor import Tensor
 from tinygrad.nn import Conv2d
+from termcolor import colored
+from tinygrad.llops.ops_gpu import CL
 
-IN_CHANS = [int(x) for x in os.getenv("IN_CHANS", "1,16,64").split(",")]
+IN_CHANS = [int(x) for x in os.getenv("IN_CHANS", "4,16,64").split(",")]
+
+def colorize_float(x):
+  ret = f"{x:7.2f}x"
+  if x < 0.8:
+    return colored(ret, 'green')
+  elif x > 1.5:
+    return colored(ret, 'red')
+  else:
+    return colored(ret, 'yellow')
+
+CNT = 8
+def test_speed(f1, *args):
+  ets = []
+  ret = None
+  for _ in range(CNT):
+    del ret
+    st = time.monotonic()
+    ret = f1(*args)
+    if ret.device in ["GPU", "OPENCL"]:
+      CL.cl_queue.finish()
+    et = (time.monotonic() - st) * 1000
+    ets.append(et)
+  return ret.numpy(), np.min(ets)
+
+def test_generic_square(name, N, f1, f2):
+  torch.manual_seed(0)
+  torch_a = torch.rand(N, N) - 0.5
+  torch_b = torch.rand(N, N) - 0.5
+  tiny_a = Tensor(torch_a.cpu().numpy())
+  tiny_b = Tensor(torch_b.cpu().numpy())
+
+  with torch.no_grad():
+    val_torch, et_torch = test_speed(f1, torch_a, torch_b)
+    val_tinygrad, et_tinygrad = test_speed(lambda *args: f2(*args).realize(), tiny_a, tiny_b)
+
+  print(f"{name:30s} {N:4d}x{N:4d} {et_torch:7.2f} ms in torch, {et_tinygrad:7.2f} ms in tinygrad, {colorize_float(et_tinygrad/et_torch)} slower", val_torch.sum(), val_tinygrad.sum())
+  np.testing.assert_allclose(val_tinygrad, val_torch, atol=1e-4, rtol=1e-3)
 
-CNT = 5
 class TestSpeed(unittest.TestCase):
+  def test_sum(self):
+    def f(a, b): return a.sum()
+    test_generic_square('sum', 4096, f, f)
+
+  def test_permute(self):
+    # this is a 64MB tensor, M1 L1 cache is 128kB
+    # to fit easily in L1, rotations should be 128x128 chunks. 128x128 is also the AMX size
+    def f1(a, b): return a.permute(1,0).contiguous()
+    # NOTE: this isn't being constant folded
+    def f2(a, b): return a.permute(1,0) + 0
+    test_generic_square('permute', 4096, f1, f2)
+
+  def test_neg(self):
+    def f(a, b): return -a
+    test_generic_square('neg', 4096, f, f)
+
+  def test_exp(self):
+    def f(a, b): return a.exp()
+    test_generic_square('exp', 2048, f, f)
+
+  def test_relu(self):
+    def f(a, b): return a.relu()
+    test_generic_square('relu', 4096, f, f)
+
+  def test_max(self):
+    def f(a, b): return a.max()
+    test_generic_square('max', 4096, f, f)
+
+  def test_mul_sum(self):
+    def f(a, b): return (a*b).sum()
+    test_generic_square('mul_sum', 4096, f, f)
+
+  def test_add(self):
+    for N in [1024, 4096]:
+      def f(a, b): return a + b
+      test_generic_square('add', N, f, f)
+
+  def test_add_sq(self):
+    def f(a, b): return a*a + b*b
+    test_generic_square('add_sq', 4096, f, f)
+
   def test_gemm(self):
-    N = 1024
-    torch.manual_seed(0)
-    torch_a = torch.rand(N, N)
-    torch_b = torch.rand(N, N)
-    tiny_a = Tensor(torch_a.cpu().numpy())
-    tiny_b = Tensor(torch_b.cpu().numpy())
+    def f(a, b): return a @ b
+    test_generic_square('gemm', 512, f, f)
 
-    ets_torch = []
-    for _ in range(CNT):
-      torch_a += 1
-      st = time.monotonic()
-      torch_c = torch_a @ torch_b
-      et_torch = (time.monotonic() - st) * 1000
-      ets_torch.append(et_torch)
+  def test_gemm_unrolled(self):
+    N = 512
+    def f1(a, b): return a@b.T
+    def f2(a, b): return (a.reshape(N, 1, N).expand(N, N, N) * b.reshape(1, N, N).expand(N, N, N)).sum(axis=2)
+    test_generic_square('gemm_unrolled', N, f1, f2)
 
-    ets_tinygrad = []
-    for _ in range(CNT):
-      tiny_a += 1
-      tiny_a.realize()
-      st = time.monotonic()
-      tiny_c = tiny_a @ tiny_b
-      tiny_c.realize()
-      et_tinygrad = (time.monotonic() - st) * 1000
-      ets_tinygrad.append(et_tinygrad)
+  def test_gemm_unrolled_permute_r(self):
+    N = 512
+    def f1(a, b): return a@b
+    def f2(a, b): return (a.reshape(N, 1, N).expand(N, N, N) * b.permute(1,0).reshape(1, N, N).expand(N, N, N)).sum(axis=2)
+    test_generic_square('gemm_unrolled_permute_r', N, f1, f2)
 
-    val_torch = torch_c.numpy().sum()
-    val_tinygrad = tiny_c.numpy().sum()
-
-    et_torch = np.median(ets_torch)
-    et_tinygrad = np.median(ets_tinygrad)
-    print(f"{N}x{N} {et_torch:7.2f} ms in torch, {et_tinygrad:7.2f} ms in tinygrad, {et_tinygrad/et_torch:7.2f}x slower", val_torch, val_tinygrad)
-    relative_error = abs((val_tinygrad-val_torch)/val_torch)
-    assert relative_error < 0.01
+  def test_gemm_unrolled_permute_lr(self):
+    N = 512
+    def f1(a, b): return a.T@b
+    def f2(a, b): return (a.permute(1,0).reshape(N, 1, N).expand(N, N, N) * b.permute(1,0).reshape(1, N, N).expand(N, N, N)).sum(axis=2)
+    test_generic_square('gemm_unrolled_permute_lr', N, f1, f2)
 
   def test_conv2d(self):
     torch.manual_seed(0)
     for bs in [32]:
       for in_chans in IN_CHANS:
-        for out_chans in [64]:
-          device = 'cuda' if torch.cuda.is_available() else 'cpu'
-          img_size = 64 if device == 'cuda' else 32
-          src = torch.rand(bs, in_chans, img_size, img_size)
-          dat = src.clone().to(device)
-          src_conv = torch.nn.Conv2d(in_chans, out_chans, 3, bias=None)
-          conv = src_conv.to(device)
+        for out_chans in [32]:
+          img_size = 34
+          torch_dat = torch.rand(bs, in_chans, img_size, img_size)
+          torch_conv = torch.nn.Conv2d(in_chans, out_chans, 3, bias=None)
+
+          tiny_dat = Tensor(torch_dat.cpu().numpy())
+          tiny_conv = Conv2d(in_chans, out_chans, 3, bias=None)
+          tiny_conv.weight = Tensor(torch_conv.weight.detach().cpu().numpy())
+
+          def f1(): return torch_conv(torch_dat)
+          def f2(): return tiny_conv(tiny_dat).realize()
+
           with torch.no_grad():
-            val_torch = conv(dat).cpu().numpy().sum()
-            ets_torch = []
-            for _ in range(CNT):
-              dat += 1
-              st = time.monotonic()
-              val_torch = conv(dat).cpu().numpy().sum()
-              et_torch = (time.monotonic() - st) * 1000
-              ets_torch.append(et_torch)
+            val_torch, et_torch = test_speed(f1)
+            val_tinygrad, et_tinygrad = test_speed(f2)
 
-          Tensor.no_grad = False
-          dat = Tensor(src.numpy())
-          conv = Conv2d(in_chans, out_chans, 3, bias=None)
-          conv.weight = Tensor(src_conv.weight.detach().cpu().numpy())
-          val_tinygrad = conv(dat).numpy().sum()
-          ets_tinygrad = []
-          for _ in range(CNT):
-            dat += 1
-            dat.realize()
-            st = time.monotonic()
-            val_tinygrad = conv(dat).numpy().sum()
-            et_tinygrad = (time.monotonic() - st) * 1000
-            ets_tinygrad.append(et_tinygrad)
-
-          et_torch = np.median(ets_torch)
-          et_tinygrad = np.median(ets_tinygrad)
-          print(f"bs:{bs:3d} chans:{in_chans:3d} -> {out_chans:3d} {et_torch:7.2f} ms in torch({device}), {et_tinygrad:7.2f} ms in tinygrad, {et_tinygrad/et_torch:7.2f}x slower", val_torch, val_tinygrad)
-          relative_error = abs((val_tinygrad-val_torch)/val_torch)
-          assert relative_error < 0.01
+          print(f"bs:{bs:3d} chans:{in_chans:3d} -> {out_chans:3d} {et_torch:7.2f} ms in torch, {et_tinygrad:7.2f} ms in tinygrad, {colorize_float(et_tinygrad/et_torch)} slower", val_torch.sum(), val_tinygrad.sum())
+          np.testing.assert_allclose(val_tinygrad, val_torch, atol=1e-4)
 
 if __name__ == '__main__':
   unittest.main()
diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py
index b92be7323..91c8da66c 100644
--- a/tinygrad/helpers.py
+++ b/tinygrad/helpers.py
@@ -5,6 +5,7 @@ def dedup(x): return list(dict.fromkeys(x)) # retains list order
 def prod(x): return math.prod(x)
 def argfix(*x): return tuple() if len(x) == 0 else tuple(x[0]) if isinstance(x[0], tuple) or isinstance(x[0], list) else tuple(x)
 def argsort(x): return sorted(range(len(x)), key=x.__getitem__) # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
+def all_same(items): return all(x == items[0] for x in items) if len(items) > 0 else True
 def reduce_shape(shape, axis): return tuple(1 if i in axis else shape[i] for i in range(len(shape)))
 
 def shape_to_axis(old_shape, new_shape):
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 56559aff4..b3dff6a02 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -87,6 +87,9 @@ class Tensor:
 
   # TODO: remove use of numpy here
 
+  @classmethod
+  def zeros_like(cls, tensor, **kwargs): return cls.zeros(*tensor.shape, **kwargs)
+
   @classmethod
   def zeros(cls, *shape, **kwargs): return cls(np.zeros(shape, dtype=np.float32), **kwargs)
 
@@ -301,7 +304,7 @@ class Tensor:
 
   # TODO: fix the kwargs problem, then remove these (or not, since they now fix tuples)
   def reshape(self, shape, *args): return self._reshape(shape=argfix(shape, *args))
-  def expand(self, shape, *args): return self._expand(shape=argfix(shape, *args))
+  def expand(self, shape, *args): return self._expand(shape=tuple(x if x != -1 else s for s,x in zip(self.shape, argfix(shape, *args))))
   def permute(self, order, *args): return self._permute(order=argfix(order, *args))
 
   def linear(self, weight:Tensor, bias:Optional[Tensor]=None):
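
Reference note: the new gemm_unrolled* tests express a matmul as a broadcasted multiply followed by a reduce over the shared axis; the reshape/expand/permute/sum(axis=2) chains above compute exactly that. A minimal plain-numpy sketch of the same identity (not part of the patch, for verification only):

    import numpy as np
    N = 4
    a = np.random.rand(N, N).astype(np.float32)
    b = np.random.rand(N, N).astype(np.float32)
    # [i,0,k] * [0,j,k] broadcasts to [i,j,k]; summing over k gives (a @ b)[i,j]
    unrolled = (a.reshape(N, 1, N) * b.T.reshape(1, N, N)).sum(axis=2)
    np.testing.assert_allclose(unrolled, a @ b, rtol=1e-5)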
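
The Tensor.expand change above adds torch-style -1 handling (a -1 entry keeps that dimension's existing size), which is what lets the unrolled tests write expand shapes without repeating known dims. A small usage sketch, assuming the Tensor API in this tree:

    import numpy as np
    from tinygrad.tensor import Tensor
    a = Tensor(np.zeros((4, 1), dtype=np.float32))
    # -1 keeps the existing size of that dim, so (4, 1) expands to (4, 3)
    assert a.expand(-1, 3).shape == (4, 3)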