1
0
Fork 0

updates from the chonker branch

pull/421/head
George Hotz 2022-11-07 21:12:08 -08:00
parent d878065ece
commit 2cc1d970c6
5 changed files with 121 additions and 72 deletions

2
extra/gemm/.gitignore vendored 100644
View File

@ -0,0 +1,2 @@
*.s
*.ll

View File

@ -249,11 +249,6 @@ class TestOps(unittest.TestCase):
lambda x,w: torch.nn.functional.conv2d(x,w).relu(),
lambda x,w: Tensor.conv2d(x,w).relu(), atol=1e-4, grad_rtol=1e-5)
def test_simple_conv2d_forward(self):
helper_test_op([(2,10,9,9), (10,10,3,3)],
lambda x,w: torch.nn.functional.conv2d(x,w).relu(),
lambda x,w: Tensor.conv2d(x,w).relu(), atol=1e-4, grad_rtol=1e-5, forward_only=True)
# expect reduce nodes == 3
def test_simple_conv2d_nhwc(self):
# weights (from tf): filter_height x filter_width x in_channels x out_channels

View File

@ -6,86 +6,134 @@ import time
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.nn import Conv2d
from termcolor import colored
from tinygrad.llops.ops_gpu import CL
IN_CHANS = [int(x) for x in os.getenv("IN_CHANS", "1,16,64").split(",")]
IN_CHANS = [int(x) for x in os.getenv("IN_CHANS", "4,16,64").split(",")]
def colorize_float(x):
ret = f"{x:7.2f}x"
if x < 0.8:
return colored(ret, 'green')
elif x > 1.5:
return colored(ret, 'red')
else:
return colored(ret, 'yellow')
CNT = 8
def test_speed(f1, *args):
ets = []
ret = None
for _ in range(CNT):
del ret
st = time.monotonic()
ret = f1(*args)
if ret.device in ["GPU", "OPENCL"]:
CL.cl_queue.finish()
et = (time.monotonic() - st) * 1000
ets.append(et)
return ret.numpy(), np.min(ets)
def test_generic_square(name, N, f1, f2):
torch.manual_seed(0)
torch_a = torch.rand(N, N) - 0.5
torch_b = torch.rand(N, N) - 0.5
tiny_a = Tensor(torch_a.cpu().numpy())
tiny_b = Tensor(torch_b.cpu().numpy())
with torch.no_grad():
val_torch, et_torch = test_speed(f1, torch_a, torch_b)
val_tinygrad, et_tinygrad = test_speed(lambda *args: f2(*args).realize(), tiny_a, tiny_b)
print(f"{name:30s} {N:4d}x{N:4d} {et_torch:7.2f} ms in torch, {et_tinygrad:7.2f} ms in tinygrad, {colorize_float(et_tinygrad/et_torch)} slower", val_torch.sum(), val_tinygrad.sum())
np.testing.assert_allclose(val_tinygrad, val_torch, atol=1e-4, rtol=1e-3)
CNT = 5
class TestSpeed(unittest.TestCase):
def test_sum(self):
def f(a, b): return a.sum()
test_generic_square('sum', 4096, f, f)
def test_permute(self):
# this is a 64MB tensor, M1 L1 cache is 128kB
# to fit easily in L1, rotations should be 128x128 chunks. 128x128 is also the AMX size
def f1(a, b): return a.permute(1,0).contiguous()
# NOTE: this isn't being constant folded
def f2(a, b): return a.permute(1,0) + 0
test_generic_square('permute', 4096, f1, f2)
def test_neg(self):
def f(a, b): return -a
test_generic_square('neg', 4096, f, f)
def test_exp(self):
def f(a, b): return a.exp()
test_generic_square('exp', 2048, f, f)
def test_relu(self):
def f(a, b): return a.relu()
test_generic_square('relu', 4096, f, f)
def test_max(self):
def f(a, b): return a.max()
test_generic_square('max', 4096, f, f)
def test_mul_sum(self):
def f(a, b): return (a*b).sum()
test_generic_square('mul_sum', 4096, f, f)
def test_add(self):
for N in [1024, 4096]:
def f(a, b): return a + b
test_generic_square('add', N, f, f)
def test_add_sq(self):
def f(a, b): return a*a + b*b
test_generic_square('add_sq', 4096, f, f)
def test_gemm(self):
N = 1024
torch.manual_seed(0)
torch_a = torch.rand(N, N)
torch_b = torch.rand(N, N)
tiny_a = Tensor(torch_a.cpu().numpy())
tiny_b = Tensor(torch_b.cpu().numpy())
def f(a, b): return a @ b
test_generic_square('gemm', 512, f, f)
ets_torch = []
for _ in range(CNT):
torch_a += 1
st = time.monotonic()
torch_c = torch_a @ torch_b
et_torch = (time.monotonic() - st) * 1000
ets_torch.append(et_torch)
def test_gemm_unrolled(self):
N = 512
def f1(a, b): return a@b.T
def f2(a, b): return (a.reshape(N, 1, N).expand(N, N, N) * b.reshape(1, N, N).expand(N, N, N)).sum(axis=2)
test_generic_square('gemm_unrolled', N, f1, f2)
ets_tinygrad = []
for _ in range(CNT):
tiny_a += 1
tiny_a.realize()
st = time.monotonic()
tiny_c = tiny_a @ tiny_b
tiny_c.realize()
et_tinygrad = (time.monotonic() - st) * 1000
ets_tinygrad.append(et_tinygrad)
def test_gemm_unrolled_permute_r(self):
N = 512
def f1(a, b): return a@b
def f2(a, b): return (a.reshape(N, 1, N).expand(N, N, N) * b.permute(1,0).reshape(1, N, N).expand(N, N, N)).sum(axis=2)
test_generic_square('gemm_unrolled_permute_r', N, f1, f2)
val_torch = torch_c.numpy().sum()
val_tinygrad = tiny_c.numpy().sum()
et_torch = np.median(ets_torch)
et_tinygrad = np.median(ets_tinygrad)
print(f"{N}x{N} {et_torch:7.2f} ms in torch, {et_tinygrad:7.2f} ms in tinygrad, {et_tinygrad/et_torch:7.2f}x slower", val_torch, val_tinygrad)
relative_error = abs((val_tinygrad-val_torch)/val_torch)
assert relative_error < 0.01
def test_gemm_unrolled_permute_lr(self):
N = 512
def f1(a, b): return a.T@b
def f2(a, b): return (a.permute(1,0).reshape(N, 1, N).expand(N, N, N) * b.permute(1,0).reshape(1, N, N).expand(N, N, N)).sum(axis=2)
test_generic_square('gemm_unrolled_permute_lr', N, f1, f2)
def test_conv2d(self):
torch.manual_seed(0)
for bs in [32]:
for in_chans in IN_CHANS:
for out_chans in [64]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
img_size = 64 if device == 'cuda' else 32
src = torch.rand(bs, in_chans, img_size, img_size)
dat = src.clone().to(device)
src_conv = torch.nn.Conv2d(in_chans, out_chans, 3, bias=None)
conv = src_conv.to(device)
for out_chans in [32]:
img_size = 34
torch_dat = torch.rand(bs, in_chans, img_size, img_size)
torch_conv = torch.nn.Conv2d(in_chans, out_chans, 3, bias=None)
tiny_dat = Tensor(torch_dat.cpu().numpy())
tiny_conv = Conv2d(in_chans, out_chans, 3, bias=None)
tiny_conv.weight = Tensor(torch_conv.weight.detach().cpu().numpy())
def f1(): return torch_conv(torch_dat)
def f2(): return tiny_conv(tiny_dat).realize()
with torch.no_grad():
val_torch = conv(dat).cpu().numpy().sum()
ets_torch = []
for _ in range(CNT):
dat += 1
st = time.monotonic()
val_torch = conv(dat).cpu().numpy().sum()
et_torch = (time.monotonic() - st) * 1000
ets_torch.append(et_torch)
val_torch, et_torch = test_speed(f1)
val_tinygrad, et_tinygrad = test_speed(f2)
Tensor.no_grad = False
dat = Tensor(src.numpy())
conv = Conv2d(in_chans, out_chans, 3, bias=None)
conv.weight = Tensor(src_conv.weight.detach().cpu().numpy())
val_tinygrad = conv(dat).numpy().sum()
ets_tinygrad = []
for _ in range(CNT):
dat += 1
dat.realize()
st = time.monotonic()
val_tinygrad = conv(dat).numpy().sum()
et_tinygrad = (time.monotonic() - st) * 1000
ets_tinygrad.append(et_tinygrad)
et_torch = np.median(ets_torch)
et_tinygrad = np.median(ets_tinygrad)
print(f"bs:{bs:3d} chans:{in_chans:3d} -> {out_chans:3d} {et_torch:7.2f} ms in torch({device}), {et_tinygrad:7.2f} ms in tinygrad, {et_tinygrad/et_torch:7.2f}x slower", val_torch, val_tinygrad)
relative_error = abs((val_tinygrad-val_torch)/val_torch)
assert relative_error < 0.01
print(f"bs:{bs:3d} chans:{in_chans:3d} -> {out_chans:3d} {et_torch:7.2f} ms in torch, {et_tinygrad:7.2f} ms in tinygrad, {colorize_float(et_tinygrad/et_torch)} slower", val_torch.sum(), val_tinygrad.sum())
np.testing.assert_allclose(val_tinygrad, val_torch, atol=1e-4)
if __name__ == '__main__':
unittest.main()

View File

@ -5,6 +5,7 @@ def dedup(x): return list(dict.fromkeys(x)) # retains list order
def prod(x): return math.prod(x)
def argfix(*x): return tuple() if len(x) == 0 else tuple(x[0]) if isinstance(x[0], tuple) or isinstance(x[0], list) else tuple(x)
def argsort(x): return sorted(range(len(x)), key=x.__getitem__) # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
def all_same(items): return all(x == items[0] for x in items) if len(items) > 0 else True
def reduce_shape(shape, axis): return tuple(1 if i in axis else shape[i] for i in range(len(shape)))
def shape_to_axis(old_shape, new_shape):

View File

@ -87,6 +87,9 @@ class Tensor:
# TODO: remove use of numpy here
@classmethod
def zeros_like(cls, tensor, **kwargs): return cls.zeros(*tensor.shape, **kwargs)
@classmethod
def zeros(cls, *shape, **kwargs): return cls(np.zeros(shape, dtype=np.float32), **kwargs)
@ -301,7 +304,7 @@ class Tensor:
# TODO: fix the kwargs problem, then remove these (or not, since they now fix tuples)
def reshape(self, shape, *args): return self._reshape(shape=argfix(shape, *args))
def expand(self, shape, *args): return self._expand(shape=argfix(shape, *args))
def expand(self, shape, *args): return self._expand(shape=tuple(x if x != -1 else s for s,x in zip(self.shape, argfix(shape, *args))))
def permute(self, order, *args): return self._permute(order=argfix(order, *args))
def linear(self, weight:Tensor, bias:Optional[Tensor]=None):