* deleting lines
* remove insert dims
* if statement is never hit
* bug fixes
pull/1584/head
George Hotz 2023-08-20 08:12:16 -07:00 committed by GitHub
parent 4fbce972d7
commit 739f327d2d
7 changed files with 26 additions and 33 deletions

.gitignore

@@ -36,3 +36,6 @@ package.json
package-lock.json
temp
*.csv
+.coverage
+coverage.xml
+htmlcov


@@ -47,7 +47,7 @@ class TestRealWorld(unittest.TestCase):
derandomize_model(model)
@TinyJit
def test(t, t2): return model(t, 801, t2).realize()
-helper_test("test_sd", lambda: (Tensor.randn(1, 4, 64, 64),Tensor.randn(1, 77, 768)), test, 14.5, 967)
+helper_test("test_sd", lambda: (Tensor.randn(1, 4, 64, 64),Tensor.randn(1, 77, 768)), test, 18.0, 967)
@unittest.skipUnless(Device.DEFAULT in JIT_SUPPORTED_DEVICE, "needs JIT")
def test_llama(self):


@@ -92,7 +92,7 @@ class TestHalfDtype(unittest.TestCase):
def test_half_upcast_ops(self): _test_ops(a_dtype=dtypes.float16, b_dtype=dtypes.float32, target_dtype=dtypes.float32)
def test_upcast_to_half_ops(self): _test_ops(a_dtype=dtypes.int8, b_dtype=dtypes.float16, target_dtype=dtypes.float16)
-@unittest.skipIf(Device.DEFAULT in ["WEBGPU", "METAL"], "float64 is not supported by some backends")
+@unittest.skipIf(Device.DEFAULT in ["WEBGPU", "METAL"] or OSX, "float64 is not supported by some backends")
class TestDoubleDtype(unittest.TestCase):
def test_float64_to_np(self): _test_to_np(Tensor([1,2,3,4], dtype=dtypes.double), np.double, [1,2,3,4])
def test_casts_to_float64(self): _test_casts_to([1,2,3,4], source_dtypes=[dtypes.float32, dtypes.int32, dtypes.uint8], target_dtype=dtypes.float64)


@@ -149,4 +149,7 @@ class TestSymbolicJit(unittest.TestCase):
a = Tensor.rand(3, 7).reshape(3, vi)
bad = Tensor.rand(4, 7).reshape(4, vi)
with self.assertRaises(AssertionError):
-add(a, bad)
+add(a, bad)
+if __name__ == '__main__':
+  unittest.main()


@@ -110,4 +110,7 @@ class TestSymbolicOps(unittest.TestCase):
b = Tensor.rand(3, j)
symbolic = f(a.reshape(vi, 3), b.reshape(3, vj)).reshape(i, j).cpu().numpy()
expected = f(a, b).cpu().numpy()
-np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
+np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
+if __name__ == '__main__':
+  unittest.main()


@@ -116,8 +116,8 @@ def hand_coded_optimizations(k:Linearizer):
buf1 = k.bufs.index(k.reduceop.src[0].src[0].src[1])
buf0_strides = k.sts[buf0].real_strides()
buf1_strides = k.sts[buf1].real_strides()
-axis_buf0 = [(i,k.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides) if s == 0 and k.full_shape[i]%16 == 0]
-axis_buf1 = [(i,k.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides) if s == 0 and k.full_shape[i]%16 == 0]
+axis_buf0 = [(i,k.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides) if s == 0 and k.full_shape[i]%16 == 0 and i < k.first_reduce]
+axis_buf1 = [(i,k.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides) if s == 0 and k.full_shape[i]%16 == 0 and i < k.first_reduce]
if len(axis_buf0) and len(axis_buf1) and k.full_shape[k.first_reduce]%8 == 0 and (k.shape_len-k.first_reduce) == 1:
if DEBUG >= 3: print("HIP TENSOR CORES", axis_buf0, axis_buf1)
k.use_tensor_cores = getenv("TC", 1) == 1 # TC=2 will do the shape ops without the WMMA
@@ -175,8 +175,8 @@ def hand_coded_optimizations(k:Linearizer):
buf1 = k.bufs.index(k.reduceop.src[0].src[1])
buf0_strides = k.sts[buf0].real_strides()
buf1_strides = k.sts[buf1].real_strides()
-axis_buf0 = [(i,k.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides) if s == 0 and k.full_shape[i]%8 == 0]
-axis_buf1 = [(i,k.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides) if s == 0 and k.full_shape[i]%8 == 0]
+axis_buf0 = [(i,k.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides) if s == 0 and k.full_shape[i]%8 == 0 and i < k.first_reduce]
+axis_buf1 = [(i,k.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides) if s == 0 and k.full_shape[i]%8 == 0 and i < k.first_reduce]
if len(axis_buf0) and len(axis_buf1) and k.full_shape[k.first_reduce]%8 == 0 and (k.shape_len-k.first_reduce) == 1:
if DEBUG >= 3: print("METAL TENSOR CORES", axis_buf0, axis_buf1)
k.use_tensor_cores = getenv("TC", 1) == 1 # TC=2 will do the shape ops without the WMMA
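
A hedged note on the i < k.first_reduce guard added in both hunks: it limits the stride-0 axis search to dimensions in front of the first reduce axis, so a reduce dimension that happens to be broadcast in one buffer can no longer be picked as a tensor-core axis. A toy illustration of the filter with made-up shapes and strides (plain Python, not the real Linearizer objects):

full_shape   = (32, 16, 16)  # hypothetical kernel shape; axes >= first_reduce are reduce axes
first_reduce = 2
buf0_strides = (0, 1, 0)     # stride 0 means the buffer is broadcast along that axis
old = [i for i,s in enumerate(buf0_strides) if s == 0 and full_shape[i]%16 == 0]
new = [i for i,s in enumerate(buf0_strides) if s == 0 and full_shape[i]%16 == 0 and i < first_reduce]
print(old, new)  # [0, 2] vs [0]: the reduce axis 2 is no longer a candidate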


@@ -222,9 +222,6 @@ class Tensor:
self.grad = Tensor(1, device=self.device, requires_grad=False)
for t0 in reversed(self.deepwalk()):
-if not t0.requires_grad:
-del t0._ctx # TODO: does it help to delete this here ever?
-continue
assert (t0.grad is not None)
grads = t0._ctx.backward(t0.grad.lazydata)
grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None
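
Why the removed branch was dead code, hedged: in tinygrad a tensor only gets a _ctx when its op decides a gradient is required, and deepwalk() only yields tensors that carry a _ctx, so every t0 it returns already requires grad and the continue could never fire. A simplified sketch of that walk (names follow tinygrad, but this is not the exact source):

def deepwalk(self):
  def _walk(node, visited, nodes):
    visited.add(node)
    if getattr(node, "_ctx", None) is not None:  # tensors without a _ctx (no grad needed) are skipped entirely
      for parent in node._ctx.parents:
        if parent not in visited: _walk(parent, visited, nodes)
      nodes.append(node)                         # everything appended here requires grad
    return nodes
  return _walk(self, set(), [])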
@@ -251,12 +248,6 @@ class Tensor:
# ***** movement hlops *****
-# NOTE: using slice is discouraged and things should migrate to pad and shrink
-def slice(self, arg:Sequence[Optional[Tuple[int, int]]], value:float=0) -> Tensor:
-arg_ = tuple([a if a is not None else (0,s) for s,a in zip(self.shape, arg)])
-padding = tuple([(max(0, -p[0]), max(0, p[1]-self.shape[i])) for i,p in enumerate(arg_)])
-return self.pad(padding, value=value).shrink(tuple([(p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg_)]))
# - Negative indices are taken relative to the end of the sequence, so X[-2] returns the 2nd-to-last element
# - A slice i:j returns the elements with indices in [i, j)
# - If omitted, i and j will default to 0 and N, respectively, where N is the length of the sequence
@@ -357,6 +348,12 @@ class Tensor:
ret = ret.permute(order=order)
return ret
+# NOTE: using slice is discouraged and things should migrate to pad and shrink
+def slice(self, arg:Sequence[Optional[Tuple[int, int]]], value:float=0) -> Tensor:
+arg_ = tuple([a if a is not None else (0,s) for s,a in zip(self.shape, arg)])
+padding = tuple([(max(0, -p[0]), max(0, p[1]-self.shape[i])) for i,p in enumerate(arg_)])
+return self.pad(padding, value=value).shrink(tuple([(p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg_)]))
def gather(self: Tensor, idx: Tensor, dim: int):
assert idx.ndim == self.ndim, "self.ndim must equal idx.ndim"
assert all(s >= i for s,i in zip(self.shape, idx.shape)), "all dim of idx.shape must be smaller than self.shape"
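
The relocated slice() pads wherever the requested range runs past the tensor's shape and shrinks everywhere else. A small worked example, assuming a plain tinygrad install (values picked purely for illustration):

from tinygrad.tensor import Tensor

t = Tensor.ones(2, 3)
s = t.slice(((-1, 3), (0, 3)))  # row range (-1, 3): one pad row above and one below the 2-row tensor
print(s.shape)    # (4, 3)
print(s.numpy())  # first and last rows hold the pad value 0.0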
@@ -459,7 +456,7 @@ class Tensor:
# ***** processing ops *****
-def _pool(self, k_:Tuple[int, ...], stride:Union[Tuple[int, ...], int]=1, dilation:Union[Tuple[int, ...], int]=1, _insert_dims=tuple()) -> Tensor:
+def _pool(self, k_:Tuple[int, ...], stride:Union[Tuple[int, ...], int]=1, dilation:Union[Tuple[int, ...], int]=1) -> Tensor:
assert len(self.shape) >= len(k_), f"can't pool {self.shape} with {k_}"
s_, d_ = make_pair(stride, len(k_)), make_pair(dilation, len(k_))
assert len(k_) == len(s_) and len(k_) == len(d_), f"stride/dilation mismatch kernel:{k_} stride:{s_} dilation:{d_}"
@@ -467,10 +464,7 @@ class Tensor:
if any(k > s for k,s in zip(k_, s_)) or any(d != 1 for d in d_):
o_ = [(i - d * (k-1) - 1)//s + 1 for i,d,k,s in zip(i_, d_, k_, s_)]
e_ = [ceil(k*(i+d) / i) for k,i,d in zip(k_, i_, d_)] # expands such that we don't need padding
-xup = self.reshape(*prefix, *([1]*len(_insert_dims)), *flatten((1,i) for i in i_)).expand(*prefix, *_insert_dims, *flatten((e,i) for e,i in zip(e_, i_))).reshape(*prefix, *_insert_dims, *[e*i for e,i in zip(e_, i_)])
-# NOTE: _insert_dims is required because reduces can't be merged (yet)
-prefix += _insert_dims
-slc_prefix += [(0,x) for x in _insert_dims]
+xup = self.reshape(*prefix, *flatten((1,i) for i in i_)).expand(*prefix, *flatten((e,i) for e,i in zip(e_, i_))).reshape(*prefix, *[e*i for e,i in zip(e_, i_)])
# slide by dilation
xup = xup.slice(slc_prefix + [(0,k*(i+d)) for k,i,d in zip(k_, i_, d_)])
xup = xup.reshape(*prefix, *flatten((k,i+d) for k,i,d in zip(k_, i_, d_)))
@@ -483,11 +477,7 @@ class Tensor:
# TODO: once the shapetracker can optimize well, remove this alternative implementation. or not if the CPU implementation doesn't use ShapeTracker
o_ = [(i+(s-k))//s for i,s,k in zip(i_, s_, k_)]
xup = self.slice(slc_prefix + [(0,o*s) for o,s in zip(o_, s_)])
-xup = xup.reshape(*prefix, *([1]*len(_insert_dims)), *flatten(((o, s) for o,s in zip(o_, s_))))
-if len(_insert_dims):
-xup = xup.expand(*prefix, *_insert_dims, *flatten(((o, s) for o,s in zip(o_, s_))))
-prefix += _insert_dims
-slc_prefix += [(0,x) for x in _insert_dims]
+xup = xup.reshape(*prefix, *flatten(((o, s) for o,s in zip(o_, s_))))
xup = xup.slice(slc_prefix + flatten(((0,o), (0,k)) for o,k in zip(o_, k_)))
return xup.permute(*range(len(prefix)), *[len(prefix)+i*2 for i in range(len(k_))], *[len(prefix)+i*2+1 for i in range(len(k_))])
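
With _insert_dims gone, _pool again takes just the kernel size, stride, and dilation and returns the pooling windows as trailing axes, which the pooling hlops then reduce over. A minimal sketch of that call pattern (shapes made up, assuming a plain tinygrad install):

from tinygrad.tensor import Tensor

t = Tensor.randn(1, 1, 6, 6)
win = t._pool((2, 2), stride=2)  # -> (1, 1, 3, 3, 2, 2): the output grid plus the 2x2 windows
out = win.mean(axis=(-2, -1))    # reduce each window, like avg_pool2d(kernel_size=(2,2))
print(out.shape)                 # (1, 1, 3, 3)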
@@ -518,12 +508,6 @@ class Tensor:
rcout, oyx = cout//groups, x.shape[2:-len(HW)]
x = x.reshape(bs, groups, cin, 1, *oyx, *HW).expand(bs, groups, cin, rcout, *oyx, *HW).permute(0,1,3,*[4+i for i in range(len(oyx))],2,*[4+len(oyx)+i for i in range(len(HW))])
-# expand the channels with the pool
-# TODO: this reduces the number of kernels, but it's slower!
-#x = self.pad2d(padding_)._pool((H,W), stride, dilation, _insert_dims=(cout//groups,)) # (bs, groups*cin, rcout, oy, ox, H, W)
-#rcout, oy, ox = x.shape[2:5]
-#x = x.reshape(bs, groups, cin, rcout, oy, ox, H, W).permute(0,1,3,4,5,2,6,7)
# conv! broadcasted to (bs, groups, rcout, *oyx, cin, *HW)
ret = (x * weight.reshape(1, groups, rcout, *[1] * len(oyx), cin, *HW)).sum([-1-i for i in range(1+len(oyx))], keepdim=True).reshape(bs, cout, *oyx)
return ret if bias is None else ret.add(bias.reshape(1, -1, *[1] * len(HW)))
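
The deleted block was a commented-out alternative conv path built on _pool(..., _insert_dims=...); with _insert_dims removed from _pool above, the dead comment goes with it. For reference, a minimal sketch of exercising the conv path that remains (shapes made up, assuming a plain tinygrad install):

from tinygrad.tensor import Tensor

x = Tensor.randn(1, 4, 8, 8)  # (bs, cin, H, W)
w = Tensor.randn(8, 4, 3, 3)  # (cout, cin//groups, kH, kW)
b = Tensor.randn(8)
y = x.conv2d(w, b, stride=1, padding=1)
print(y.shape)                # (1, 8, 8, 8)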