1
0
Fork 0

LazyBuffer.get_variable_buffers() (#1391)

* LazyBudder.get_variable_buffers()

* remove left_only, add ProdNode

* no vars for OpNode.b

* do not change symbolic vars, remove ProdNode
pull/1407/head
chenyu 2023-08-02 09:01:35 -07:00 committed by GitHub
parent 8889821547
commit 18d0a93f09
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 59 additions and 1 deletions

View File

@ -3,6 +3,7 @@ import numpy as np
import unittest import unittest
from tinygrad.lazy import LazyBuffer from tinygrad.lazy import LazyBuffer
from tinygrad.tensor import Tensor from tinygrad.tensor import Tensor
from tinygrad.shape.symbolic import Variable
class TestLazyBuffer(unittest.TestCase): class TestLazyBuffer(unittest.TestCase):
def test_fromcpu_buffer_sharing(self): def test_fromcpu_buffer_sharing(self):
@ -43,5 +44,29 @@ class TestLazyBuffer(unittest.TestCase):
z = Tensor([1, np.e]).numpy() z = Tensor([1, np.e]).numpy()
np.testing.assert_allclose(y, z) np.testing.assert_allclose(y, z)
class TestVariableBuffer(unittest.TestCase):
def test_get_variable_buffers_no_variable(self):
t = Tensor.rand(2, 3)
assert t.lazydata.get_variable_buffers() == {}
def test_get_variable_buffers_one_variable(self):
v = Variable("v", 1, 10)
t = Tensor.rand(2, 3).reshape(v, 3)
buffers = t.lazydata.get_variable_buffers()
assert len(buffers) == 1 and buffers[v].realize().realized.toCPU() == 2
v = Variable("v", 1, 10)
t = Tensor.rand(2, 3).reshape(2, v)
buffers = t.lazydata.get_variable_buffers()
assert len(buffers) == 1 and buffers[v].realize().realized.toCPU() == 3
def test_get_variable_buffers_cat(self):
v1 = Variable("v1", 1, 10)
v2 = Variable("v2", 1, 10)
t1 = Tensor.rand(2, 3).reshape(v1, 3)
t2 = Tensor.rand(6, 3).reshape(v2, 3)
t = t1.cat(t2)
buffers = t.lazydata.get_variable_buffers()
assert len(buffers) == 2 and buffers[v1].realize().realized.toCPU() == 2 and buffers[v2].realize().realized.toCPU() == 6
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -1,6 +1,6 @@
#!/usr/bin/env python #!/usr/bin/env python
import unittest import unittest
from tinygrad.shape.symbolic import MulNode, SumNode, Variable, NumNode, Node from tinygrad.shape.symbolic import MulNode, SumNode, Variable, NumNode, sym_vars
class TestSymbolic(unittest.TestCase): class TestSymbolic(unittest.TestCase):
def helper_test_variable(self, v, n, m, s): def helper_test_variable(self, v, n, m, s):
@ -240,6 +240,36 @@ class TestSymbolicNumeric(unittest.TestCase):
def test_times_2_plus_3_div_4(self): self.helper_test_numeric(lambda x: (x*2 + 3)//4) def test_times_2_plus_3_div_4(self): self.helper_test_numeric(lambda x: (x*2 + 3)//4)
def test_times_2_plus_3_div_4_mod_4(self): self.helper_test_numeric(lambda x: ((x*2 + 3)//4)%4) def test_times_2_plus_3_div_4_mod_4(self): self.helper_test_numeric(lambda x: ((x*2 + 3)//4)%4)
class TestSymbolicVars(unittest.TestCase):
def test_simple(self):
z = NumNode(0)
a = Variable("a", 0, 10)
b = Variable("b", 0, 10)
c = Variable("c", 0, 10)
assert z.vars() == z.vars() == []
assert a.vars() == a.vars() == [a]
m = MulNode(a, 3)
assert m.vars() == [a]
s = SumNode([a, b, c])
assert s.vars() == [a, b, c]
def test_compound(self):
a = Variable("a", 0, 10)
b = Variable("b", 0, 10)
c = Variable("c", 0, 10)
# TODO: update this after we support symbolic * symbolic
assert (a + b * c).vars() == [a, b]
assert (a % 3 + b // 5).vars() == [a, b]
assert (a + b + c - a).vars() == [b, c]
def test_sym_vars(self):
a = Variable("a", 0, 10)
b = Variable("b", 0, 10)
assert sym_vars(1) == []
assert sym_vars(a) == [a]
assert sym_vars(a+b) == [a, b]
assert sym_vars(a*3) == [a]
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -9,6 +9,7 @@ from tinygrad.helpers import GRAPH, DEBUG, prod, getenv, DType, dtypes, flatten,
from tinygrad.runtime.ops_cpu import RawNumpyBuffer from tinygrad.runtime.ops_cpu import RawNumpyBuffer
from tinygrad.runtime.ops_disk import RawDiskBuffer from tinygrad.runtime.ops_disk import RawDiskBuffer
from tinygrad.shape.shapetracker import MovementOps, ShapeTracker, View, get_contraction from tinygrad.shape.shapetracker import MovementOps, ShapeTracker, View, get_contraction
from tinygrad.shape.symbolic import Variable, sym_vars
from tinygrad.ops import Compiled, Interpreted, UnaryOps, BinaryOps, TernaryOps, ReduceOps, LoadOps, OpType, LazyOp from tinygrad.ops import Compiled, Interpreted, UnaryOps, BinaryOps, TernaryOps, ReduceOps, LoadOps, OpType, LazyOp
from tinygrad.runtime.lib import RawBufferMapped, RawConst, RawBuffer from tinygrad.runtime.lib import RawBufferMapped, RawConst, RawBuffer
@ -270,6 +271,7 @@ class LazyBuffer:
def buffers(self) -> Tuple[LazyBuffer, ...]: return (self,) def buffers(self) -> Tuple[LazyBuffer, ...]: return (self,)
def map_buffers(self, real_srcs: Dict[Any, Any]): return real_srcs.get(self, self) def map_buffers(self, real_srcs: Dict[Any, Any]): return real_srcs.get(self, self)
def get_lazyops(self) -> List[Any]: return [] def get_lazyops(self) -> List[Any]: return []
def get_variable_buffers(self) -> Dict[Variable, LazyBuffer]: return {v:LazyBuffer.loadop(LoadOps.FROM, (1,), dtypes.int32, self.device, src=LazyBuffer.fromCPU(np.array([v.val], dtype=np.int32))) for s in self.shape for v in sym_vars(s)}
def replace_with_movement_ops(self: LazyBuffer, ops:List[Tuple[MovementOps, Any]]) -> LazyBuffer: def replace_with_movement_ops(self: LazyBuffer, ops:List[Tuple[MovementOps, Any]]) -> LazyBuffer:
y = self y = self
for op, arg in ops: y = MOVEMENT_OPS_DISPATCHER[op](y, arg) for op, arg in ops: y = MOVEMENT_OPS_DISPATCHER[op](y, arg)

View File

@ -9,6 +9,7 @@ from typing import List, Dict, Callable, Tuple, Type, Union, Optional, Any
# symbolic matches the Python behavior, but the code output is agnostic, and will never have negative numbers in div or mod # symbolic matches the Python behavior, but the code output is agnostic, and will never have negative numbers in div or mod
def is_sym_int(x: Any) -> bool: return isinstance(x, int) or isinstance(x, Node) def is_sym_int(x: Any) -> bool: return isinstance(x, int) or isinstance(x, Node)
def sym_vars(x: Union[Node, int]) -> List[Variable]: return [] if isinstance(x, int) else x.vars()
class Node: class Node:
b: Union[Node, int] b: Union[Node, int]