From teeny (#2426)
* changes from teenygrad work
* support not supporting ImageDType/PtrDType
* fixups from teeny
parent 9ae83fba04
commit 8ff2e13550
@@ -5,7 +5,7 @@ from tinygrad.jit import TinyJit
 from tinygrad.nn.state import get_state_dict
 import json

-EXPORT_SUPPORTED_DEVICE = ["WEBGPU", "CLANG", "CUDA", "GPU", "METAL"]
+EXPORT_SUPPORTED_DEVICE = ["WEBGPU", "CLANG", "CUDA", "GPU"]

 def compile_net(run:TinyJit, special_names:Dict[int,str]) -> Tuple[Dict[str,str],List[Tuple[str,List[str],List[int]]],Dict[str,Tuple[int,DType,int]],Dict[str,Tensor]]:
   functions, bufs, bufs_to_save, statements, bufnum = {}, {}, {}, [], 0
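Note: the diff drops METAL from the exporter's device list. How the list is consumed isn't shown in this hunk; presumably it gates the export entry point, along the lines of this sketch (the guard function is an assumption, only EXPORT_SUPPORTED_DEVICE comes from the diff):

    from tinygrad.ops import Device

    def check_export_device():
      # hypothetical guard: refuse to export models on unsupported backends
      assert Device.DEFAULT in EXPORT_SUPPORTED_DEVICE, f"can't export on {Device.DEFAULT}"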
@@ -29,7 +29,7 @@ def helper_test(nm, gen, train, max_memory_allowed, max_kernels_allowed, all_jitted
   assert GlobalCounters.mem_used/1e9 < max_memory_allowed, f"{nm} used more than {max_memory_allowed:.2f} GB"
   assert not kernels_used or kernels_used <= max_kernels_allowed, f"{nm} used more than {max_kernels_allowed} kernels"
   if all_jitted:
-    assert kernels_used > 0 and kernels_used == GlobalCounters.kernel_count, f"only {kernels_used} out of {GlobalCounters.kernel_count} were jitted"
+    assert kernels_used > 0 and kernels_used == GlobalCounters.kernel_count or (kernels_used == 1 and getattr(Device[Device.DEFAULT], "graph", None)), f"only {kernels_used} out of {GlobalCounters.kernel_count} were jitted"

 class TestRealWorld(unittest.TestCase):
   def setUp(self):
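Note: the relaxed assertion covers backends with a graph executor. When a device exposes a non-None graph attribute, the JIT batches every kernel into a single graph launch, so a fully-jitted run leaves one cache entry instead of one per kernel. A minimal sketch of the capability check (pattern taken from the diff):

    from tinygrad.ops import Device

    # a backend advertises graph execution via a non-None "graph" attribute
    has_graph = getattr(Device[Device.DEFAULT], "graph", None) is not None
    # fully jitted: one cache entry per kernel without a graph, exactly one with it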
@@ -1,17 +1,16 @@
 import unittest
 import numpy as np
-from tinygrad.helpers import CI, DTYPES_DICT, getenv, DType, DEBUG, ImageDType, PtrDType
+from tinygrad.helpers import CI, DTYPES_DICT, getenv, DType, DEBUG, ImageDType, PtrDType, OSX
 from tinygrad.ops import Device
 from tinygrad.tensor import Tensor, dtypes
 from typing import Any, List
-from extra.utils import OSX, temp

 def is_dtype_supported(dtype: DType):
   # for GPU, cl_khr_fp16 isn't supported (except now we don't need it!)
   # for LLVM, it segfaults because it can't link to the casting function
   if dtype == dtypes.half: return not (CI and Device.DEFAULT in ["GPU", "LLVM"]) and Device.DEFAULT != "WEBGPU" and getenv("CUDACPU") != 1
   if dtype == dtypes.bfloat16: return False  # numpy doesn't support bf16, tested separately in TestBFloat16DType
-  if dtype == dtypes.float64: return Device.DEFAULT not in ["WEBGPU", "METAL"] and not OSX
+  if dtype == dtypes.float64: return Device.DEFAULT not in ["WEBGPU", "METAL"] and not (OSX and Device.DEFAULT == "GPU")
   if dtype in [dtypes.int8, dtypes.uint8]: return Device.DEFAULT not in ["WEBGPU"]
   if dtype in [dtypes.int16, dtypes.uint16]: return Device.DEFAULT not in ["WEBGPU", "TORCH"]
   if dtype == dtypes.uint32: return Device.DEFAULT not in ["TORCH"]
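Note: is_dtype_supported is the gate the dtype tests below use to skip combinations a backend can't run. A typical call site, following the pattern of this file (the exact skip message is illustrative):

    import unittest
    from tinygrad.tensor import dtypes

    if not is_dtype_supported(dtypes.float64):
      raise unittest.SkipTest("float64 not supported on this device")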
@@ -113,6 +112,7 @@ class TestBFloat16DType(unittest.TestCase):
     assert tuple(back.numpy().tolist()) == (9984., -1, -1000, -9984, 20)

   def test_bf16_disk_write_read(self):
+    from extra.utils import temp
     t = Tensor([10000, -1, -1000, -10000, 20]).cast(dtypes.float32)
     t.to(f"disk:{temp('f32')}").realize()

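Note: with OSX now imported from tinygrad.helpers, extra.utils is only needed for temp, so the import moves into the one test that uses it. temp presumably just returns a scratch path in the system temp directory; a sketch under that assumption:

    import os, tempfile

    def temp(x: str) -> str:
      # assumed behavior of extra.utils.temp: a writable path for test artifacts
      return os.path.join(tempfile.gettempdir(), x)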
@@ -173,17 +173,20 @@ class TestBoolDtype(TestDType): DTYPE = dtypes.bool

 class TestEqStrDType(unittest.TestCase):
   def test_image_ne(self):
+    if ImageDType is None: raise unittest.SkipTest("no ImageDType support")
     assert dtypes.float == dtypes.float32, "float doesn't match?"
     assert dtypes.imagef((1,2,4)) != dtypes.imageh((1,2,4)), "different image dtype doesn't match"
     assert dtypes.imageh((1,2,4)) != dtypes.imageh((1,4,2)), "different shape doesn't match"
     assert dtypes.imageh((1,2,4)) == dtypes.imageh((1,2,4)), "same shape matches"
     assert isinstance(dtypes.imageh((1,2,4)), ImageDType)
   def test_ptr_ne(self):
+    if PtrDType is None: raise unittest.SkipTest("no PtrDType support")
     # TODO: is this the wrong behavior?
     assert PtrDType(dtypes.float32) == dtypes.float32
     #assert PtrDType(dtypes.float32) == PtrDType(dtypes.float32)
     #assert PtrDType(dtypes.float32) != dtypes.float32
   def test_strs(self):
+    if PtrDType is None: raise unittest.SkipTest("no PtrDType support")
     self.assertEqual(str(dtypes.imagef((1,2,4))), "dtypes.imagef((1, 2, 4))")
     self.assertEqual(str(PtrDType(dtypes.float32)), "ptr.dtypes.float")

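Note: the SkipTest guards are what "support not supporting ImageDType/PtrDType" means in practice: the suite should still pass against a build (e.g. teenygrad) that ships without image or pointer dtypes. That only works if the names degrade to None rather than failing at import, presumably via a guarded import like this sketch (the try/except is an assumption, not part of the diff):

    try:
      from tinygrad.helpers import ImageDType, PtrDType
    except ImportError:
      ImageDType, PtrDType = None, None  # build without image/ptr dtype support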
@@ -1,6 +1,7 @@
 #!/usr/bin/env python
 import unittest
 import numpy as np
+from tinygrad.ops import Device
 from tinygrad.tensor import Tensor
 from tinygrad.jit import TinyJit

@@ -25,7 +26,7 @@ class TestJit(unittest.TestCase):
     np.testing.assert_allclose(c.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
     np.testing.assert_allclose(d.numpy(), a.numpy()-b.numpy(), atol=1e-4, rtol=1e-5)
     np.testing.assert_allclose(e.numpy(), a.numpy()*b.numpy(), atol=1e-4, rtol=1e-5)
-    assert len(f.jit_cache) == 3
+    assert len(f.jit_cache) == 3 or (len(f.jit_cache) == 1 and getattr(Device[Device.DEFAULT], "graph", None))

   def test_nothing_jitted(self):
     @TinyJit
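Note: judging from the c/d/e assertions, the jitted function under test realizes three elementwise kernels, roughly like this reconstruction (not copied from the file):

    @TinyJit
    def f(a, b):
      c = (a+b).realize()
      d = (a-b).realize()
      e = (a*b).realize()
      return c, d, e
    # len(f.jit_cache) == 3 per-kernel, or 1 when a graph executor batches them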
@@ -4,7 +4,7 @@ import math
 import numpy as np
 import unittest
 from tinygrad.tensor import Tensor
-from tinygrad.helpers import getenv, IMAGE, DEBUG, CI, dtypes, Context, NOOPT
+from tinygrad.helpers import getenv, IMAGE, DEBUG, CI, dtypes
 from tinygrad.ops import Device

 if CI:
@@ -271,6 +271,7 @@ class TestOps(unittest.TestCase):
     helper_test_op([(45,65), (45,65)], lambda x,y: x/y, Tensor.div)
     helper_test_op([(), ()], lambda x,y: x/y, Tensor.div)
+    helper_test_op(None, lambda x,y: x/y, Tensor.div, forward_only=True, vals=[[5],[1]])
   def test_div_int(self):
     helper_test_op(None, lambda x: (x/2).to(torch.int), lambda x: x/2, forward_only=True, vals=[[3]])
   def test_div_const(self):
     helper_test_op([(45,65)], lambda x: x/255, lambda x: x/255)
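Note: the added vals=[[5],[1]] case pins down true-division semantics on integer inputs, forward-only since integers carry no gradient. The torch reference it is compared against returns a float:

    import torch
    print(torch.tensor([5]) / torch.tensor([1]))  # tensor([5.]), a float result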
@@ -700,7 +701,7 @@ class TestOps(unittest.TestCase):
     helper_test_op([(1,)], lambda x: torch.reshape(x, []), lambda x: x.reshape([]))
     helper_test_op([()], lambda x: torch.reshape(x, [1]), lambda x: x.reshape([1]))

-    with self.assertRaises(AssertionError):
+    with self.assertRaises(ValueError):
       x = Tensor.ones((4,3,6,6))
       x.reshape([])

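Note: this hunk and the ones below migrate user-facing shape errors from bare assert to ValueError, raised in tinygrad/shape/view.py (see those hunks), so they survive python -O and are catchable as ordinary value errors:

    from tinygrad.tensor import Tensor

    try:
      Tensor.ones((4,3,6,6)).reshape([])
    except ValueError as e:
      print(e)  # size mismatched, can't reshape ... (message from the view.py hunk)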
@@ -785,11 +786,6 @@ class TestOps(unittest.TestCase):
       lambda x,w: torch.nn.functional.conv2d(x,w).relu(),
       lambda x,w: Tensor.conv2d(x,w).relu(), atol=1e-4, grad_rtol=1e-5)

-  def test_simple_conv2d_noopt(self):
-    # useful with IMAGE enabled
-    with Context(NOOPT=1):
-      self.test_simple_conv2d()
-
   @unittest.skipIf(IMAGE>0, "no conv3d on images")
   def test_simple_conv3d(self):
     helper_test_op([(1,4,9,9,9), (4,4,3,3,3)],
@@ -1166,7 +1162,7 @@ class TestOps(unittest.TestCase):
     helper_test_op([(4, 6, 3)], lambda x: x.repeat(*repeats), lambda x: x.repeat(repeats))
     helper_test_op([()], lambda x: x.repeat(*repeats), lambda x: x.repeat(repeats))

-    with self.assertRaises(AssertionError):
+    with self.assertRaises(ValueError):
       x.repeat((2, 4))

     np.testing.assert_allclose(x.repeat((2, 0, 4)).numpy(), Tensor.zeros(8, 0, 12).numpy())
@@ -55,7 +55,7 @@ class TestSymbolicJit(unittest.TestCase):
       symbolic = jf(a.reshape(3, vi), b.reshape(vi, 5)).numpy()
       expected = f(a, b).numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
-    assert len(jf.jit_cache) == 2 or getattr(Device[Device.DEFAULT], "graph", None)
+    assert len(jf.jit_cache) == 2 or (len(jf.jit_cache) == 1 and getattr(Device[Device.DEFAULT], "graph", None))

   def test_attention(self):
     def f(q, k, v): return Tensor.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).realize()
@@ -68,7 +68,7 @@ class TestSymbolicJit(unittest.TestCase):
       symbolic = jf(q, k.reshape(2, vi, 4, 8), v.reshape(2, vi, 4, 8)).reshape(2, 4, 1, 8).numpy()
       expected = f(q, k, v).numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
-    assert len(jf.jit_cache) == 6 or getattr(Device[Device.DEFAULT], "graph", None)
+    assert len(jf.jit_cache) == 6 or (len(jf.jit_cache) == 1 and getattr(Device[Device.DEFAULT], "graph", None))

   def test_cat_dim0(self):
     def f(a, b): return a.cat(b, dim=0).realize()
@@ -115,7 +115,7 @@ class TestSymbolicReshape(unittest.TestCase):

   def test_reshape_into_symbols_bad_shape(self):
     vi = Variable("i", 1, 10).bind(4)
-    with self.assertRaises(AssertionError):
+    with self.assertRaises(ValueError):
       t = Tensor.rand(4, 6).reshape(vi, 6).reshape(1, 77)  # reshape to a different size new shape through symbolic shape
     with self.assertRaises(AssertionError):
       t = Tensor.rand(3, 4).reshape(3, (vi+1))  # reshape into non-Variable Node
@@ -97,6 +97,7 @@ class Profiling(contextlib.ContextDecorator):

 # **** tinygrad now supports dtypes! *****

+# TODO: migrate this from NamedTuple -> dataclass
 class DType(NamedTuple):
   priority: int  # this determines when things get upcasted
   itemsize: int
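Note: the new TODO marks DType for migration off NamedTuple. A frozen dataclass is the usual target since it keeps hashability and immutability while allowing methods and defaults; a sketch of where that would land (not the actual migration):

    from dataclasses import dataclass

    @dataclass(frozen=True)
    class DType:
      priority: int  # determines when things get upcasted
      itemsize: int  # bytes per element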
@@ -121,6 +121,7 @@ class LazyBuffer:
   def base(self): return self._base if self._base is not None else self

   def is_unrealized_const(self): return not self.realized and self.base.op.op == LoadOps.CONST
+  def is_unrealized_contiguous_const(self): return self.is_unrealized_const() and self.st.contiguous

   @property
   def realized(self): return self.base._realized
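Note: the new helper names a compound test that previously sat inline in Tensor._to_float (see the tensor.py hunk below): still lazy, created by LoadOps.CONST, and contiguous, i.e. a constant that can be folded into a kernel instead of materialized. Equivalent expansion, straight from the two lines in the diff:

    # lb.is_unrealized_contiguous_const()
    #   == (not lb.realized) and lb.base.op.op == LoadOps.CONST and lb.st.contiguous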
@@ -70,7 +70,7 @@ class View:

   @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
   def expand(self, new_shape: Tuple[sint, ...]) -> View:
-    assert len(new_shape) == len(self.shape)
+    if len(new_shape) != len(self.shape): raise ValueError(f"expand arg {new_shape=} must have same number of dimensions as shape {self.shape=}")
     if 0 in self.shape:
       assert all((s == x == 0) or (s > 0 and (x % s) == 0) for s,x in zip(self.shape, new_shape)), f"can't expand {self.shape} into {new_shape}"
       return View.create(new_shape)
@@ -106,7 +106,7 @@ class View:
     # check for the same size
     if all_int(self.shape):
       assert all(isinstance(s, (int, Variable)) for s in new_shape), f"{self.shape=} -> {new_shape=} contains non (int, Variable) dim"
-      assert prod(self.shape) == prod([s if isinstance(s, int) else cast(Variable,s).val for s in new_shape]), f"size mismatched, can't reshape {self.shape=} -> {new_shape=}"
+      if prod(self.shape) != prod([s if isinstance(s, int) else cast(Variable,s).val for s in new_shape]): raise ValueError(f"size mismatched, can't reshape {self.shape=} -> {new_shape=}")

     # after the asserts, it's okay to check contiguous
     if self.contiguous: return View.create(new_shape)
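Note: the rule this enforces is that a reshape must preserve the element count, with bound Variables counted at their current value. In plain terms:

    from math import prod

    assert prod((4, 6)) == prod((3, 8))  # 24 == 24, reshape allowed
    # prod((4, 6)) != prod((1, 77))      # 24 != 77, now raises ValueError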
@@ -71,9 +71,9 @@ class Tensor:
         data = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_np(data.dtype), device, data.item())
       else:
         data = LazyBuffer.fromCPU(data.astype(dtype.np) if dtype is not None and dtype.np is not None else data)
-    else: raise RuntimeError(f"can't create Tensor from {data} with type {type(data)}")

     # data is a LazyBuffer, but it might be on the wrong device
+    if not isinstance(data, LazyBuffer): raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}")
     self.lazydata = data if data.device == device else data.copy_to_device(device)

   def __repr__(self):
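Note: moving the RuntimeError below the conversion branches turns it into a catch-all, so anything not converted to a LazyBuffer above (and not already one) is rejected, now with repr formatting via {data!r}. An assumed repro:

    from tinygrad.tensor import Tensor

    Tensor(object())  # RuntimeError: can't create Tensor from <object object at ...> with type <class 'object'>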
@@ -673,8 +673,8 @@ class Tensor:
     return (x, y)

   def _to_float(self, x:Union[Tensor, float]):
-    return x.lazydata.base.op.arg if isinstance(x, Tensor) and x.lazydata.is_unrealized_const() and not x.requires_grad \
-      and x.lazydata.st.contiguous and self._broadcasted(x)[0].shape == self.shape else x
+    return x.lazydata.base.op.arg if isinstance(x, Tensor) and x.lazydata.is_unrealized_contiguous_const() \
+      and not x.requires_grad and self._broadcasted(x)[0].shape == self.shape else x

   def add(self, x:Union[Tensor, float], reverse=False) -> Tensor:
     x = self._to_float(x)
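Note: after the refactor, _to_float is the one place a constant operand gets unwrapped back to a python scalar so elementwise ops can fold it into the kernel rather than allocate a buffer. A sketch of the two paths (behavior inferred from the diff):

    from tinygrad.tensor import Tensor

    t = Tensor.ones(4)
    t.add(2.0)          # plain float passes through _to_float unchanged
    t.add(Tensor(2.0))  # unrealized contiguous const broadcastable to t.shape:
                        # unwrapped to 2.0 via lazydata.base.op.arg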