tinygrab/test/test_jit.py
#!/usr/bin/env python
import unittest
import numpy as np
from test.helpers import assert_jit_cache_len
from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit
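
# These tests exercise TinyJit behaviour: jit cache sizes, rejection of non-tensor,
# mismatched, and duplicate inputs, jitted methods and kwargs, random number
# regeneration inside the jit, and reuse of output buffers across calls.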
class TestJit(unittest.TestCase):
    def test_simple_jit(self):
        @TinyJit
        def add(a, b):
            return (a + b).realize()

        for _ in range(5):
            a = Tensor.randn(10, 10)
            b = Tensor.randn(10, 10)
            c = add(a, b)
            np.testing.assert_allclose(
                c.numpy(), a.numpy() + b.numpy(), atol=1e-4, rtol=1e-5
            )
        assert_jit_cache_len(add, 1)
    def test_jit_multiple_outputs(self):
        @TinyJit
        def f(a, b):
            return (a + b).realize(), (a - b).realize(), (a * b).realize()

        for _ in range(5):
            a = Tensor.randn(10, 10)
            b = Tensor.randn(10, 10)
            c, d, e = f(a, b)
            np.testing.assert_allclose(
                c.numpy(), a.numpy() + b.numpy(), atol=1e-4, rtol=1e-5
            )
            np.testing.assert_allclose(
                d.numpy(), a.numpy() - b.numpy(), atol=1e-4, rtol=1e-5
            )
            np.testing.assert_allclose(
                e.numpy(), a.numpy() * b.numpy(), atol=1e-4, rtol=1e-5
            )
        assert_jit_cache_len(f, 3)
    def test_nothing_jitted(self):
        @TinyJit
        def add(a, b):
            return a + b

        with self.assertRaises(AssertionError):
            for _ in range(5):
                a = Tensor.randn(10, 10)
                b = Tensor.randn(10, 10)
                add(a, b)
    def test_jit_shape_mismatch(self):
        @TinyJit
        def add(a, b):
            return (a + b).realize()

        for _ in range(5):
            a = Tensor.randn(10, 10)
            b = Tensor.randn(10, 10)
            add(a, b)
        bad = Tensor.randn(20, 20)
        with self.assertRaises(AssertionError):
            add(a, bad)
    def test_jit_shape_views_mismatch(self):
        @TinyJit
        def add(a):
            return (a + 1).realize()

        with self.assertRaises(AssertionError):
            for i in range(1, 5):
                # a has an offset that the kernel doesn't know about
                a = Tensor.randn(10, 10).realize()[:, i : i + 2]
                add(a)
    def test_jit_duplicate_fail(self):
        # the jit doesn't support duplicate arguments
        @TinyJit
        def add(a, b):
            return (a + b).realize()

        a = Tensor.randn(10, 10)
        with self.assertRaises(AssertionError):
            add(a, a)
    def test_kwargs_jit(self):
        @TinyJit
        def add_kwargs(first, second):
            return (first + second).realize()

        for _ in range(5):
            a = Tensor.randn(10, 10)
            b = Tensor.randn(10, 10)
            c = add_kwargs(first=a, second=b)
            np.testing.assert_allclose(
                c.numpy(), a.numpy() + b.numpy(), atol=1e-4, rtol=1e-5
            )
        assert_jit_cache_len(add_kwargs, 1)
    def test_array_jit(self):
        @TinyJit
        def add_array(a, arr):
            return (a + arr[0]).realize()

        for i in range(5):
            a = Tensor.randn(10, 10)
            b = Tensor.randn(10, 10)
            a.realize(), b.realize()
            c = add_array(a, [b])
            if i >= 2:
                # should fail once jitted since jit can't handle arrays
                np.testing.assert_allclose(
                    np.any(np.not_equal(c.numpy(), a.numpy() + b.numpy())),
                    True,
                    atol=1e-4,
                    rtol=1e-5,
                )
            else:
                np.testing.assert_allclose(
                    c.numpy(), a.numpy() + b.numpy(), atol=1e-4, rtol=1e-5
                )
        assert_jit_cache_len(add_array, 1)
    def test_method_jit(self):
        class Fun:
            def __init__(self):
                self.a = Tensor.randn(10, 10)

            @TinyJit
            def __call__(self, b: Tensor) -> Tensor:
                return (self.a + b).realize()

        fun = Fun()
        for _ in range(5):
            b = Tensor.randn(10, 10)
            c = fun(b)
            np.testing.assert_allclose(
                c.numpy(), fun.a.numpy() + b.numpy(), atol=1e-4, rtol=1e-5
            )
        assert_jit_cache_len(fun.__call__.func.__self__, 1)
    def test_jit_size1_input(self):
        @TinyJit
        def f(a, b):
            return (a + b).realize()

        a = Tensor([1, 2, 3])
        for i in range(5):
            np.testing.assert_allclose(
                f(a, Tensor([i])).numpy(), (a + i).numpy(), atol=1e-4, rtol=1e-5
            )
        assert_jit_cache_len(f, 1)
    def test_jit_output_non_tensor_fail(self):
        @TinyJit
        def f(a, b, i):
            return (a + b).realize(), i

        output1, output2 = [], []
        expect1, expect2 = [], []
        for i in range(5):
            a = Tensor.randn(10, 10)
            b = Tensor.randn(10, 10)
            o1, o2 = f(a, b, i)
            output1.append(o1.numpy().copy())
            output2.append(o2)
            expect1.append(a.numpy().copy() + b.numpy().copy())
            expect2.append(i)
        np.testing.assert_allclose(output1, expect1, atol=1e-4, rtol=1e-5)
        # the jit only works with Tensor outputs
        assert output2 != expect2
        assert_jit_cache_len(f, 1)
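
    # rand inside a jitted function should produce fresh values on every replay and
    # be reproducible for a fixed seed; the test below checks both properties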
    def test_jit_random_regen(self):
        def f(a, b):
            rn = Tensor.randn(*a.shape)
            return ((a + b) * rn).realize()

        # realize these before resetting the random seed
        a = Tensor.randn(10, 10).realize()
        b = Tensor.randn(10, 10).realize()

        Tensor._seed = 1234
        jf = TinyJit(f)
        res = set()
        for _ in range(5):
            o1 = jf(a, b)
            res.add(o1.numpy()[0][0])
        assert len(res) == 5, "All values should be different, rand works in jit."

        Tensor._seed = 1234
        jf2 = TinyJit(f)
        res2 = set()
        for _ in range(5):
            o1 = jf2(a, b)
            res2.add(o1.numpy()[0][0])
        assert len(res2) == 5, "All values should be different, rand works in jit."
        assert res == res2, "Jit rand is not reproducible with the same seed"

        Tensor._seed = 3421
        jf3 = TinyJit(f)
        res3 = set()
        for _ in range(5):
            o1 = jf3(a, b)
            res3.add(o1.numpy()[0][0])
        assert len(res3) == 5, "All values should be different, rand works in jit."
        assert res3 != res2, "Jit rand is diff with diff seeds"
    def test_jit_realization_and_sampling(self):
        w = Tensor.eye(5)

        @TinyJit
        def foo(x):
            return w.dot(x).realize()

        arg = [
            Tensor([1, 2, 3, 4, 5]),
            Tensor([1, 3, 3, 4, 6]),
            Tensor([1, 2, 5, 4, 7]),
            Tensor([0, 2, 3, 1, 0]),
        ]
        Y = [foo(e).numpy() for e in arg]
        foo(Tensor([7, 7, 7, 7, 7]))
        want = [
            [1.0, 2.0, 3.0, 4.0, 5.0],
            [1.0, 3.0, 3.0, 4.0, 6.0],
            [1.0, 2.0, 5.0, 4.0, 7.0],
            [0.0, 2.0, 3.0, 1.0, 0.0],
        ]
        np.testing.assert_allclose(want, Y)
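
    # the jit captures buffers at trace time: "good" assigns cache_v + 1 - 1, which
    # forces a copy kernel into the existing cache buffer, so the captured kernels see
    # updates; "bad" assigns cache_v directly, swapping the underlying buffer, so the
    # captured kernels keep reading the stale one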
    def test_jitted_read_assign(self):
        class Cache:
            def __init__(self):
                self.good_cache = Tensor.zeros(1)
                self.bad_cache = Tensor.zeros(1)
                self.good_jitted = TinyJit(self.good)
                self.bad_jitted = TinyJit(self.bad)

            def good(self, y, cache_v=None):
                if cache_v is not None:
                    self.good_cache.assign(cache_v + 1 - 1).realize()
                return (self.good_cache + y).realize()  # need + y to provide inputs to JIT

            def bad(self, y, cache_v=None):
                if cache_v is not None:
                    self.bad_cache.assign(cache_v).realize()
                return (self.bad_cache + y).realize()

        cache = Cache()
        np.testing.assert_equal([0], cache.good_cache.numpy())
        np.testing.assert_equal([0], cache.bad_cache.numpy())

        zero = Tensor([0])
        one = Tensor([1])
        two = Tensor([2])

        # save [1] in the caches
        cache.good(zero, one)
        cache.bad(zero, one)
        np.testing.assert_equal([1], cache.good_cache.numpy())
        np.testing.assert_equal([1], cache.bad_cache.numpy())

        for i in range(5):
            cache.good_jitted(zero)
            cache.bad_jitted(zero)

        # verify the jitted calls read 1 from the cache
        np.testing.assert_equal([1], cache.good_jitted(zero).numpy())
        np.testing.assert_equal([1], cache.bad_jitted(zero).numpy())

        # save [2] in the caches
        cache.good(zero, two)
        cache.bad(zero, two)
        np.testing.assert_equal([2], cache.good_cache.numpy())
        np.testing.assert_equal([2], cache.bad_cache.numpy())
        # verify the jitted calls read 2 from the cache
        np.testing.assert_equal([2], cache.good_jitted(zero).numpy())
        # but the bad_jitted doesn't!
        np.testing.assert_equal([1], cache.bad_jitted(zero).numpy())

        assert_jit_cache_len(cache.good_jitted, 1)
        assert_jit_cache_len(cache.bad_jitted, 1)
    def test_jit_buffer_behavior(self):
        @TinyJit
        def foo(x) -> Tensor:
            return x.sum().realize()

        result_1 = foo(Tensor([1] * 2))
        result_2 = foo(Tensor([2] * 2))
        result_3 = foo(Tensor([3] * 2))

        # the jitted calls share one output buffer, so reading result_2 after the
        # third call returns the last value written ([6]), not its own sum ([4])
        np.testing.assert_allclose(result_1.numpy(), [2], atol=1e-4, rtol=1e-5)
        np.testing.assert_allclose(result_2.numpy(), [6], atol=1e-4, rtol=1e-5)
        np.testing.assert_allclose(result_3.numpy(), [6], atol=1e-4, rtol=1e-5)
if __name__ == "__main__":
    unittest.main()