#!/usr/bin/env python
"""Tests for TinyJit: caching, argument handling, randomness, and buffer reuse."""
import unittest

import numpy as np

from test.helpers import assert_jit_cache_len
from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit


class TestJit(unittest.TestCase):
    """Exercises the TinyJit decorator's caching and input-tracking behavior."""

    def test_simple_jit(self):
        """A jitted binary op produces correct results and caches one kernel."""

        @TinyJit
        def add(a, b):
            return (a + b).realize()

        for _ in range(5):
            a = Tensor.randn(10, 10)
            b = Tensor.randn(10, 10)
            c = add(a, b)
            np.testing.assert_allclose(
                c.numpy(), a.numpy() + b.numpy(), atol=1e-4, rtol=1e-5
            )
        assert_jit_cache_len(add, 1)

    def test_jit_multiple_outputs(self):
        """A jitted function may return several realized tensors; one cache entry per kernel."""

        @TinyJit
        def f(a, b):
            return (a + b).realize(), (a - b).realize(), (a * b).realize()

        for _ in range(5):
            a = Tensor.randn(10, 10)
            b = Tensor.randn(10, 10)
            c, d, e = f(a, b)
            np.testing.assert_allclose(
                c.numpy(), a.numpy() + b.numpy(), atol=1e-4, rtol=1e-5
            )
            np.testing.assert_allclose(
                d.numpy(), a.numpy() - b.numpy(), atol=1e-4, rtol=1e-5
            )
            np.testing.assert_allclose(
                e.numpy(), a.numpy() * b.numpy(), atol=1e-4, rtol=1e-5
            )
        assert_jit_cache_len(f, 3)

    def test_nothing_jitted(self):
        """A jitted function that never calls realize() must raise."""

        @TinyJit
        def add(a, b):
            return a + b

        with self.assertRaises(AssertionError):
            for _ in range(5):
                a = Tensor.randn(10, 10)
                b = Tensor.randn(10, 10)
                add(a, b)

    def test_jit_shape_mismatch(self):
        """Calling a jitted function with a different input shape must raise."""

        @TinyJit
        def add(a, b):
            return (a + b).realize()

        for _ in range(5):
            a = Tensor.randn(10, 10)
            b = Tensor.randn(10, 10)
            add(a, b)
        bad = Tensor.randn(20, 20)
        with self.assertRaises(AssertionError):
            add(a, bad)

    def test_jit_shape_views_mismatch(self):
        """Inputs with differing views (offsets) must be rejected by the jit."""

        @TinyJit
        def add(a):
            return (a + 1).realize()

        with self.assertRaises(AssertionError):
            for i in range(1, 5):
                # a has an offset that the kernel doesn't know about
                a = Tensor.randn(10, 10).realize()[:, i : i + 2]
                add(a)

    def test_jit_duplicate_fail(self):
        # the jit doesn't support duplicate arguments
        @TinyJit
        def add(a, b):
            return (a + b).realize()

        a = Tensor.randn(10, 10)
        with self.assertRaises(AssertionError):
            add(a, a)

    def test_kwargs_jit(self):
        """Tensor inputs passed by keyword are tracked like positional ones."""

        @TinyJit
        def add_kwargs(first, second):
            return (first + second).realize()

        for _ in range(5):
            a = Tensor.randn(10, 10)
            b = Tensor.randn(10, 10)
            c = add_kwargs(first=a, second=b)
            np.testing.assert_allclose(
                c.numpy(), a.numpy() + b.numpy(), atol=1e-4, rtol=1e-5
            )
        assert_jit_cache_len(add_kwargs, 1)

    def test_array_jit(self):
        """Tensors hidden inside a list are NOT tracked: results go stale after capture."""

        @TinyJit
        def add_array(a, arr):
            return (a + arr[0]).realize()

        for i in range(5):
            a = Tensor.randn(10, 10)
            b = Tensor.randn(10, 10)
            a.realize(), b.realize()
            c = add_array(a, [b])
            if i >= 2:
                # should fail once jitted since jit can't handle arrays
                np.testing.assert_allclose(
                    np.any(np.not_equal(c.numpy(), a.numpy() + b.numpy())),
                    True,
                    atol=1e-4,
                    rtol=1e-5,
                )
            else:
                np.testing.assert_allclose(
                    c.numpy(), a.numpy() + b.numpy(), atol=1e-4, rtol=1e-5
                )
        assert_jit_cache_len(add_array, 1)

    def test_method_jit(self):
        """TinyJit works as a decorator on a bound method (__call__)."""

        class Fun:
            def __init__(self):
                self.a = Tensor.randn(10, 10)

            @TinyJit
            def __call__(self, b: Tensor) -> Tensor:
                return (self.a + b).realize()

        fun = Fun()
        for _ in range(5):
            b = Tensor.randn(10, 10)
            c = fun(b)
            np.testing.assert_allclose(
                c.numpy(), fun.a.numpy() + b.numpy(), atol=1e-4, rtol=1e-5
            )
        # the TinyJit instance lives on the bound method's wrapped function
        assert_jit_cache_len(fun.__call__.func.__self__, 1)

    def test_jit_size1_input(self):
        """Size-1 tensor inputs are real jit inputs, not baked-in constants."""

        @TinyJit
        def f(a, b):
            return (a + b).realize()

        a = Tensor([1, 2, 3])
        for i in range(5):
            np.testing.assert_allclose(
                f(a, Tensor([i])).numpy(), (a + i).numpy(), atol=1e-4, rtol=1e-5
            )
        assert_jit_cache_len(f, 1)

    def test_jit_output_non_tensor_fail(self):
        """Non-Tensor outputs are frozen at capture time and not recomputed."""

        @TinyJit
        def f(a, b, i):
            return (a + b).realize(), i

        output1, output2 = [], []
        expect1, expect2 = [], []
        for i in range(5):
            a = Tensor.randn(10, 10)
            b = Tensor.randn(10, 10)
            o1, o2 = f(a, b, i)
            output1.append(o1.numpy().copy())
            output2.append(o2)
            expect1.append(a.numpy().copy() + b.numpy().copy())
            expect2.append(i)
        np.testing.assert_allclose(output1, expect1, atol=1e-4, rtol=1e-5)
        # the jit only works with Tensor outputs
        assert output2 != expect2
        assert_jit_cache_len(f, 1)

    def test_jit_random_regen(self):
        """Random tensors regenerate on every jitted call, reproducibly per seed."""

        def f(a, b):
            rn = Tensor.randn(*a.shape)
            return ((a + b) * rn).realize()

        a = Tensor.randn(
            10, 10
        ).realize()  # realize these before resetting the random seed
        b = Tensor.randn(10, 10).realize()
        Tensor._seed = 1234
        jf = TinyJit(f)
        res = set()
        for _ in range(5):
            o1 = jf(a, b)
            res.add(o1.numpy()[0][0])
        assert len(res) == 5, "All values should be different, rand works in jit."

        Tensor._seed = 1234
        jf2 = TinyJit(f)
        res2 = set()
        for _ in range(5):
            o1 = jf2(a, b)
            res2.add(o1.numpy()[0][0])
        assert len(res2) == 5, "All values should be different, rand works in jit."
        assert res == res2, "Jit rand is not reproducible with the same seed"

        Tensor._seed = 3421
        jf3 = TinyJit(f)
        res3 = set()
        for _ in range(5):
            o1 = jf3(a, b)
            res3.add(o1.numpy()[0][0])
        assert len(res3) == 5, "All values should be different, rand works in jit."
        assert res3 != res2, "Jit rand is diff with diff seeds"

    def test_jit_realization_and_sampling(self):
        """Earlier jit outputs must stay valid after later calls overwrite buffers."""
        w = Tensor.eye(5)

        @TinyJit
        def foo(x):
            return w.dot(x).realize()

        arg = [
            Tensor([1, 2, 3, 4, 5]),
            Tensor([1, 3, 3, 4, 6]),
            Tensor([1, 2, 5, 4, 7]),
            Tensor([0, 2, 3, 1, 0]),
        ]
        Y = [foo(e).numpy() for e in arg]
        foo(Tensor([7, 7, 7, 7, 7]))
        want = [
            [1.0, 2.0, 3.0, 4.0, 5.0],
            [1.0, 3.0, 3.0, 4.0, 6.0],
            [1.0, 2.0, 5.0, 4.0, 7.0],
            [0.0, 2.0, 3.0, 1.0, 0.0],
        ]
        np.testing.assert_allclose(want, Y)

    def test_jitted_read_assign(self):
        """A jitted closure re-reads an assigned buffer only if the assign kernel differs."""

        class Cache:
            def __init__(self):
                self.good_cache = Tensor.zeros(1)
                self.bad_cache = Tensor.zeros(1)
                self.good_jitted = TinyJit(self.good)
                self.bad_jitted = TinyJit(self.bad)

            def good(self, y, cache_v=None):
                if cache_v is not None:
                    self.good_cache.assign(cache_v + 1 - 1).realize()
                return (
                    self.good_cache + y
                ).realize()  # need + y to provide inputs to JIT

            def bad(self, y, cache_v=None):
                if cache_v is not None:
                    self.bad_cache.assign(cache_v).realize()
                return (self.bad_cache + y).realize()

        cache = Cache()
        np.testing.assert_equal([0], cache.good_cache.numpy())
        np.testing.assert_equal([0], cache.bad_cache.numpy())

        zero = Tensor([0])
        one = Tensor([1])
        two = Tensor([2])

        # save [1] in the caches
        cache.good(zero, one)
        cache.bad(zero, one)
        np.testing.assert_equal([1], cache.good_cache.numpy())
        np.testing.assert_equal([1], cache.bad_cache.numpy())

        for _ in range(5):
            cache.good_jitted(zero)
            cache.bad_jitted(zero)

        # verify the jitted calls read 1 from the cache
        np.testing.assert_equal([1], cache.good_jitted(zero).numpy())
        np.testing.assert_equal([1], cache.bad_jitted(zero).numpy())

        # save [2] in the caches
        cache.good(zero, two)
        cache.bad(zero, two)
        # BUGFIX: compare against the tensor's value, not the Tensor object
        np.testing.assert_equal([2], cache.good_cache.numpy())
        np.testing.assert_equal([2], cache.bad_cache.numpy())

        # verify the jitted calls read 2 from the cache
        np.testing.assert_equal([2], cache.good_jitted(zero).numpy())
        # but the bad_jitted doesn't!
        np.testing.assert_equal([1], cache.bad_jitted(zero).numpy())

        assert_jit_cache_len(cache.good_jitted, 1)
        assert_jit_cache_len(cache.bad_jitted, 1)

    def test_jit_buffer_behavior(self):
        """Jitted outputs alias one buffer, so earlier results see later writes."""

        @TinyJit
        def foo(x) -> Tensor:
            return x.sum().realize()

        result_1 = foo(Tensor([1] * 2))
        result_2 = foo(Tensor([2] * 2))
        result_3 = foo(Tensor([3] * 2))

        # expect the buffer to share underlying buffer
        np.testing.assert_allclose(result_1.numpy(), [2], atol=1e-4, rtol=1e-5)
        np.testing.assert_allclose(result_2.numpy(), [6], atol=1e-4, rtol=1e-5)
        np.testing.assert_allclose(result_3.numpy(), [6], atol=1e-4, rtol=1e-5)


if __name__ == "__main__":
    unittest.main()