
From teeny (#2426)

* changes from teenygrad work

* support not supporting ImageDType/PtrDType

* fixups from teeny
pull/2428/head
George Hotz 2023-11-24 12:50:56 -08:00 committed by GitHub
parent 9ae83fba04
commit 8ff2e13550
11 changed files with 24 additions and 22 deletions
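
The "support not supporting ImageDType/PtrDType" bullet means the dtype tests now have to run against reduced builds (teenygrad-style) that don't provide these classes. A minimal sketch of the guard pattern the tests lean on, assuming such a build exposes the names as None rather than omitting them:

import unittest
from tinygrad.tensor import dtypes
# assumption: a reduced build exports these names as None instead of dropping them
try:
  from tinygrad.helpers import ImageDType, PtrDType
except ImportError:
  ImageDType, PtrDType = None, None

class TestOptionalDTypes(unittest.TestCase):
  def test_image_eq(self):
    if ImageDType is None: raise unittest.SkipTest("no ImageDType support")
    assert dtypes.imageh((1,2,4)) == dtypes.imageh((1,2,4)), "same shape matches"
  def test_ptr_str(self):
    if PtrDType is None: raise unittest.SkipTest("no PtrDType support")
    assert str(PtrDType(dtypes.float32)) == "ptr.dtypes.float"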

View File

@@ -5,7 +5,7 @@ from tinygrad.jit import TinyJit
from tinygrad.nn.state import get_state_dict
import json
EXPORT_SUPPORTED_DEVICE = ["WEBGPU", "CLANG", "CUDA", "GPU", "METAL"]
EXPORT_SUPPORTED_DEVICE = ["WEBGPU", "CLANG", "CUDA", "GPU"]
def compile_net(run:TinyJit, special_names:Dict[int,str]) -> Tuple[Dict[str,str],List[Tuple[str,List[str],List[int]]],Dict[str,Tuple[int,DType,int]],Dict[str,Tensor]]:
functions, bufs, bufs_to_save, statements, bufnum = {}, {}, {}, [], 0
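
Reading each changed pair as old line then new line (deletions come first in the diff), this hunk drops METAL from EXPORT_SUPPORTED_DEVICE. A small usage sketch of gating on that list before compiling; the guard function is hypothetical, not part of the file:

from tinygrad.ops import Device

EXPORT_SUPPORTED_DEVICE = ["WEBGPU", "CLANG", "CUDA", "GPU"]  # the list after this change

def check_exportable() -> None:
  # hypothetical early-exit guard: fail before compile_net instead of deep inside it
  if Device.DEFAULT not in EXPORT_SUPPORTED_DEVICE:
    raise RuntimeError(f"export is not supported on device {Device.DEFAULT}")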

View File

@@ -29,7 +29,7 @@ def helper_test(nm, gen, train, max_memory_allowed, max_kernels_allowed, all_jit
assert GlobalCounters.mem_used/1e9 < max_memory_allowed, f"{nm} used more than {max_memory_allowed:.2f} GB"
assert not kernels_used or kernels_used <= max_kernels_allowed, f"{nm} used more than {max_kernels_allowed} kernels"
if all_jitted:
assert kernels_used > 0 and kernels_used == GlobalCounters.kernel_count, f"only {kernels_used} out of {GlobalCounters.kernel_count} were jitted"
assert kernels_used > 0 and kernels_used == GlobalCounters.kernel_count or (kernels_used == 1 and getattr(Device[Device.DEFAULT], "graph", None)), f"only {kernels_used} out of {GlobalCounters.kernel_count} were jitted"
class TestRealWorld(unittest.TestCase):
def setUp(self):
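
The relaxed assertion accounts for graph-capable backends: when Device[Device.DEFAULT] exposes a graph attribute, the captured JIT replays as a single graph launch, so one kernel in the cache can still mean everything was jitted. The same predicate restated as a standalone helper (the function name is illustrative):

from tinygrad.ops import Device

def fully_jitted(kernels_used: int, kernel_count: int) -> bool:
  # graph backends collapse the captured JIT into one launch, so a count of 1 is acceptable there
  has_graph = getattr(Device[Device.DEFAULT], "graph", None) is not None
  return (kernels_used > 0 and kernels_used == kernel_count) or (kernels_used == 1 and has_graph)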

View File

@@ -1,17 +1,16 @@
import unittest
import numpy as np
from tinygrad.helpers import CI, DTYPES_DICT, getenv, DType, DEBUG, ImageDType, PtrDType
from tinygrad.helpers import CI, DTYPES_DICT, getenv, DType, DEBUG, ImageDType, PtrDType, OSX
from tinygrad.ops import Device
from tinygrad.tensor import Tensor, dtypes
from typing import Any, List
from extra.utils import OSX, temp
def is_dtype_supported(dtype: DType):
# for GPU, cl_khr_fp16 isn't supported (except now we don't need it!)
# for LLVM, it segfaults because it can't link to the casting function
if dtype == dtypes.half: return not (CI and Device.DEFAULT in ["GPU", "LLVM"]) and Device.DEFAULT != "WEBGPU" and getenv("CUDACPU") != 1
if dtype == dtypes.bfloat16: return False # numpy doesn't support bf16, tested separately in TestBFloat16DType
if dtype == dtypes.float64: return Device.DEFAULT not in ["WEBGPU", "METAL"] and not OSX
if dtype == dtypes.float64: return Device.DEFAULT not in ["WEBGPU", "METAL"] and (not OSX and Device.DEFAULT == "GPU")
if dtype in [dtypes.int8, dtypes.uint8]: return Device.DEFAULT not in ["WEBGPU"]
if dtype in [dtypes.int16, dtypes.uint16]: return Device.DEFAULT not in ["WEBGPU", "TORCH"]
if dtype == dtypes.uint32: return Device.DEFAULT not in ["TORCH"]
@@ -113,6 +112,7 @@ class TestBFloat16DType(unittest.TestCase):
assert tuple(back.numpy().tolist()) == (9984., -1, -1000, -9984, 20)
def test_bf16_disk_write_read(self):
from extra.utils import temp
t = Tensor([10000, -1, -1000, -10000, 20]).cast(dtypes.float32)
t.to(f"disk:{temp('f32')}").realize()
@@ -173,17 +173,20 @@ class TestBoolDtype(TestDType): DTYPE = dtypes.bool
class TestEqStrDType(unittest.TestCase):
def test_image_ne(self):
if ImageDType is None: raise unittest.SkipTest("no ImageDType support")
assert dtypes.float == dtypes.float32, "float doesn't match?"
assert dtypes.imagef((1,2,4)) != dtypes.imageh((1,2,4)), "different image dtype doesn't match"
assert dtypes.imageh((1,2,4)) != dtypes.imageh((1,4,2)), "different shape doesn't match"
assert dtypes.imageh((1,2,4)) == dtypes.imageh((1,2,4)), "same shape matches"
assert isinstance(dtypes.imageh((1,2,4)), ImageDType)
def test_ptr_ne(self):
if PtrDType is None: raise unittest.SkipTest("no PtrDType support")
# TODO: is this the wrong behavior?
assert PtrDType(dtypes.float32) == dtypes.float32
#assert PtrDType(dtypes.float32) == PtrDType(dtypes.float32)
#assert PtrDType(dtypes.float32) != dtypes.float32
def test_strs(self):
if PtrDType is None: raise unittest.SkipTest("no PtrDType support")
self.assertEqual(str(dtypes.imagef((1,2,4))), "dtypes.imagef((1, 2, 4))")
self.assertEqual(str(PtrDType(dtypes.float32)), "ptr.dtypes.float")
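
With OSX now imported from tinygrad.helpers (the extra.utils import above goes away, and temp is pulled in locally by the disk read/write test), the float64 rule can be read on its own. A sketch that copies the post-change condition verbatim; the helper name is illustrative:

from tinygrad.helpers import OSX
from tinygrad.ops import Device

def float64_supported() -> bool:
  # condition copied from the updated is_dtype_supported above
  return Device.DEFAULT not in ["WEBGPU", "METAL"] and (not OSX and Device.DEFAULT == "GPU")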

View File

@@ -1,6 +1,7 @@
#!/usr/bin/env python
import unittest
import numpy as np
from tinygrad.ops import Device
from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit
@@ -25,7 +26,7 @@ class TestJit(unittest.TestCase):
np.testing.assert_allclose(c.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
np.testing.assert_allclose(d.numpy(), a.numpy()-b.numpy(), atol=1e-4, rtol=1e-5)
np.testing.assert_allclose(e.numpy(), a.numpy()*b.numpy(), atol=1e-4, rtol=1e-5)
assert len(f.jit_cache) == 3
assert len(f.jit_cache) == 3 or (len(f.jit_cache) == 1 and getattr(Device[Device.DEFAULT], "graph", None))
def test_nothing_jitted(self):
@TinyJit
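
For context, the jitted function in this test produces three independent element-wise outputs, which is why jit_cache normally holds 3 entries, while a graph backend may hold a single graph entry instead. A minimal sketch of that setup; the shapes and call count are illustrative:

import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit

@TinyJit
def f(a, b):
  return (a+b).realize(), (a-b).realize(), (a*b).realize()

for _ in range(3):  # call a few times so the JIT captures, then replays
  a = Tensor.randn(10, 10).realize()
  b = Tensor.randn(10, 10).realize()
  c, d, e = f(a, b)
np.testing.assert_allclose(c.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)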

View File

@@ -4,7 +4,7 @@ import math
import numpy as np
import unittest
from tinygrad.tensor import Tensor
from tinygrad.helpers import getenv, IMAGE, DEBUG, CI, dtypes, Context, NOOPT
from tinygrad.helpers import getenv, IMAGE, DEBUG, CI, dtypes
from tinygrad.ops import Device
if CI:
@@ -271,6 +271,7 @@ class TestOps(unittest.TestCase):
helper_test_op([(45,65), (45,65)], lambda x,y: x/y, Tensor.div)
helper_test_op([(), ()], lambda x,y: x/y, Tensor.div)
helper_test_op(None, lambda x,y: x/y, Tensor.div, forward_only=True, vals=[[5],[1]])
def test_div_int(self):
helper_test_op(None, lambda x: (x/2).to(torch.int), lambda x: x/2, forward_only=True, vals=[[3]])
def test_div_const(self):
helper_test_op([(45,65)], lambda x: x/255, lambda x: x/255)
@@ -700,7 +701,7 @@ class TestOps(unittest.TestCase):
helper_test_op([(1,)], lambda x: torch.reshape(x, []), lambda x: x.reshape([]))
helper_test_op([()], lambda x: torch.reshape(x, [1]), lambda x: x.reshape([1]))
with self.assertRaises(AssertionError):
with self.assertRaises(ValueError):
x = Tensor.ones((4,3,6,6))
x.reshape([])
@@ -785,11 +786,6 @@ class TestOps(unittest.TestCase):
lambda x,w: torch.nn.functional.conv2d(x,w).relu(),
lambda x,w: Tensor.conv2d(x,w).relu(), atol=1e-4, grad_rtol=1e-5)
def test_simple_conv2d_noopt(self):
# useful with IMAGE enabled
with Context(NOOPT=1):
self.test_simple_conv2d()
@unittest.skipIf(IMAGE>0, "no conv3d on images")
def test_simple_conv3d(self):
helper_test_op([(1,4,9,9,9), (4,4,3,3,3)],
@@ -1166,7 +1162,7 @@ class TestOps(unittest.TestCase):
helper_test_op([(4, 6, 3)], lambda x: x.repeat(*repeats), lambda x: x.repeat(repeats))
helper_test_op([()], lambda x: x.repeat(*repeats), lambda x: x.repeat(repeats))
with self.assertRaises(AssertionError):
with self.assertRaises(ValueError):
x.repeat((2, 4))
np.testing.assert_allclose(x.repeat((2, 0, 4)).numpy(), Tensor.zeros(8, 0, 12).numpy())
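
These tests switch from expecting AssertionError to ValueError for invalid reshape/repeat arguments, matching the ValueError now raised in the view code later in this commit, and the NOOPT conv2d variant is dropped along with the Context/NOOPT import. A condensed sketch of the new error contract; the test names are illustrative:

import unittest
from tinygrad.tensor import Tensor

class TestShapeErrors(unittest.TestCase):
  def test_bad_reshape_is_valueerror(self):
    x = Tensor.ones((4,3,6,6))
    with self.assertRaises(ValueError):
      x.reshape([])    # element-count mismatch now surfaces as ValueError
  def test_bad_repeat_is_valueerror(self):
    x = Tensor.ones((4,6,3))
    with self.assertRaises(ValueError):
      x.repeat((2,4))  # too few repeat dims is rejected the same way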

View File

@@ -55,7 +55,7 @@ class TestSymbolicJit(unittest.TestCase):
symbolic = jf(a.reshape(3, vi), b.reshape(vi, 5)).numpy()
expected = f(a, b).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
assert len(jf.jit_cache) == 2 or getattr(Device[Device.DEFAULT], "graph", None)
assert len(jf.jit_cache) == 2 or (len(jf.jit_cache) == 1 and getattr(Device[Device.DEFAULT], "graph", None))
def test_attention(self):
def f(q, k, v): return Tensor.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).realize()
@@ -68,7 +68,7 @@ class TestSymbolicJit(unittest.TestCase):
symbolic = jf(q, k.reshape(2, vi, 4, 8), v.reshape(2, vi, 4, 8)).reshape(2, 4, 1, 8).numpy()
expected = f(q, k, v).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
assert len(jf.jit_cache) == 6 or getattr(Device[Device.DEFAULT], "graph", None)
assert len(jf.jit_cache) == 6 or (len(jf.jit_cache) == 1 and getattr(Device[Device.DEFAULT], "graph", None))
def test_cat_dim0(self):
def f(a, b): return a.cat(b, dim=0).realize()
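
The cache-length assertions here tighten the same way as in test_real_world and test_jit: a graph backend may legitimately report a single cached graph entry. For context, a minimal sketch of the symbolic-JIT pattern these tests drive, assuming Variable comes from tinygrad.shape.symbolic; the shapes and loop are illustrative:

from tinygrad.jit import TinyJit
from tinygrad.shape.symbolic import Variable
from tinygrad.tensor import Tensor

@TinyJit
def jf(a, b): return (a@b).realize()

for i in range(2, 6):
  vi = Variable("i", 1, 10).bind(i)            # bind the symbolic dim to this iteration's size
  a, b = Tensor.rand(3, i), Tensor.rand(i, 5)
  out = jf(a.reshape(3, vi), b.reshape(vi, 5)).numpy()  # one captured kernel replayed for every i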

View File

@@ -115,7 +115,7 @@ class TestSymbolicReshape(unittest.TestCase):
def test_reshape_into_symbols_bad_shape(self):
vi = Variable("i", 1, 10).bind(4)
with self.assertRaises(AssertionError):
with self.assertRaises(ValueError):
t = Tensor.rand(4, 6).reshape(vi, 6).reshape(1, 77) # reshape to a different size new shape through symbolic shape
with self.assertRaises(AssertionError):
t = Tensor.rand(3, 4).reshape(3, (vi+1)) # reshape into non-Variable Node
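
Same contract change on the symbolic side: reshaping through a bound Variable to a different element count is now a ValueError. A small sketch, again assuming Variable comes from tinygrad.shape.symbolic:

from tinygrad.shape.symbolic import Variable
from tinygrad.tensor import Tensor

vi = Variable("i", 1, 10).bind(4)
t = Tensor.rand(4, 6).reshape(vi, 6)  # fine: 24 elements either way
try:
  t.reshape(1, 77)                    # 77 != 24, rejected after this change
except ValueError as e:
  print("rejected:", e)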

View File

@@ -97,6 +97,7 @@ class Profiling(contextlib.ContextDecorator):
# **** tinygrad now supports dtypes! *****
# TODO: migrate this from NamedTuple -> dataclass
class DType(NamedTuple):
priority: int # this determines when things get upcasted
itemsize: int
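
The only change here is the migration TODO on DType; it is still a NamedTuple, so instances stay lightweight, hashable tuples. A quick illustration:

from tinygrad.helpers import dtypes

dt = dtypes.float32
print(dt.priority, dt.itemsize)  # upcast priority and byte size (4 for float32)
print(isinstance(dt, tuple))     # True for as long as DType remains a NamedTuple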

View File

@@ -121,6 +121,7 @@ class LazyBuffer:
def base(self): return self._base if self._base is not None else self
def is_unrealized_const(self): return not self.realized and self.base.op.op == LoadOps.CONST
def is_unrealized_contiguous_const(self): return self.is_unrealized_const() and self.st.contiguous
@property
def realized(self): return self.base._realized
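
LazyBuffer gains is_unrealized_contiguous_const, which the reworked _to_float in tensor.py (last file below) keys off. As I read the predicate, a plain scalar constant should satisfy it; treat the expected output as an assumption:

from tinygrad.tensor import Tensor

t = Tensor(3.0)                                     # scalar constant, never realized
print(t.lazydata.is_unrealized_contiguous_const())  # expected True: a CONST load behind a contiguous view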

View File

@@ -70,7 +70,7 @@ class View:
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
def expand(self, new_shape: Tuple[sint, ...]) -> View:
assert len(new_shape) == len(self.shape)
if len(new_shape) != len(self.shape): raise ValueError(f"expand arg {new_shape=} must have same number of dimensions as shape {self.shape=}")
if 0 in self.shape:
assert all((s == x == 0) or (s > 0 and (x % s) == 0) for s,x in zip(self.shape, new_shape)), f"can't expand {self.shape} into {new_shape}"
return View.create(new_shape)
@@ -106,7 +106,7 @@ class View:
# check for the same size
if all_int(self.shape):
assert all(isinstance(s, (int, Variable)) for s in new_shape), f"{self.shape=} -> {new_shape=} contains non (int, Variable) dim"
assert prod(self.shape) == prod([s if isinstance(s, int) else cast(Variable,s).val for s in new_shape]), f"size mismatched, can't reshape {self.shape=} -> {new_shape=}"
if prod(self.shape) != prod([s if isinstance(s, int) else cast(Variable,s).val for s in new_shape]): raise ValueError(f"size mismatched, can't reshape {self.shape=} -> {new_shape=}")
# after the asserts, it's okay to check contiguous
if self.contiguous: return View.create(new_shape)
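
Both changed checks move from assert to ValueError, so they survive python -O and give callers a real exception to catch. A small sketch, assuming View is importable from tinygrad.shape.view as in this file:

from tinygrad.shape.view import View

v = View.create((3, 1))
try:
  v.expand((3, 1, 4))  # wrong number of dimensions -> ValueError after this change
except ValueError as e:
  print("rejected:", e)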

View File

@@ -71,9 +71,9 @@ class Tensor:
data = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_np(data.dtype), device, data.item())
else:
data = LazyBuffer.fromCPU(data.astype(dtype.np) if dtype is not None and dtype.np is not None else data)
else: raise RuntimeError(f"can't create Tensor from {data} with type {type(data)}")
# data is a LazyBuffer, but it might be on the wrong device
if not isinstance(data, LazyBuffer): raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}")
self.lazydata = data if data.device == device else data.copy_to_device(device)
def __repr__(self):
@@ -673,8 +673,8 @@ class Tensor:
return (x, y)
def _to_float(self, x:Union[Tensor, float]):
return x.lazydata.base.op.arg if isinstance(x, Tensor) and x.lazydata.is_unrealized_const() and not x.requires_grad \
and x.lazydata.st.contiguous and self._broadcasted(x)[0].shape == self.shape else x
return x.lazydata.base.op.arg if isinstance(x, Tensor) and x.lazydata.is_unrealized_contiguous_const() \
and not x.requires_grad and self._broadcasted(x)[0].shape == self.shape else x
def add(self, x:Union[Tensor, float], reverse=False) -> Tensor:
x = self._to_float(x)
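
The constructor now rejects anything that didn't become a LazyBuffer with an explicit RuntimeError, and _to_float goes through the new is_unrealized_contiguous_const helper. The practical effect, as I read it, is that a scalar constant Tensor on the right-hand side folds to a plain float and takes the same path as a python literal; the check below is just a numeric sanity probe:

import numpy as np
from tinygrad.tensor import Tensor

x = Tensor.rand(4, 4)
y = x + Tensor(2.0)  # RHS is an unrealized contiguous const, so _to_float can fold it
z = x + 2.0          # plain python float path
np.testing.assert_allclose(y.numpy(), z.numpy())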