LazyBuffer.get_variable_buffers() (#1391)
* LazyBudder.get_variable_buffers() * remove left_only, add ProdNode * no vars for OpNode.b * do not change symbolic vars, remove ProdNodepull/1407/head
parent
8889821547
commit
18d0a93f09
|
@ -3,6 +3,7 @@ import numpy as np
|
|||
import unittest
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
|
||||
class TestLazyBuffer(unittest.TestCase):
|
||||
def test_fromcpu_buffer_sharing(self):
|
||||
|
@ -43,5 +44,29 @@ class TestLazyBuffer(unittest.TestCase):
|
|||
z = Tensor([1, np.e]).numpy()
|
||||
np.testing.assert_allclose(y, z)
|
||||
|
||||
class TestVariableBuffer(unittest.TestCase):
|
||||
def test_get_variable_buffers_no_variable(self):
|
||||
t = Tensor.rand(2, 3)
|
||||
assert t.lazydata.get_variable_buffers() == {}
|
||||
|
||||
def test_get_variable_buffers_one_variable(self):
|
||||
v = Variable("v", 1, 10)
|
||||
t = Tensor.rand(2, 3).reshape(v, 3)
|
||||
buffers = t.lazydata.get_variable_buffers()
|
||||
assert len(buffers) == 1 and buffers[v].realize().realized.toCPU() == 2
|
||||
v = Variable("v", 1, 10)
|
||||
t = Tensor.rand(2, 3).reshape(2, v)
|
||||
buffers = t.lazydata.get_variable_buffers()
|
||||
assert len(buffers) == 1 and buffers[v].realize().realized.toCPU() == 3
|
||||
|
||||
def test_get_variable_buffers_cat(self):
|
||||
v1 = Variable("v1", 1, 10)
|
||||
v2 = Variable("v2", 1, 10)
|
||||
t1 = Tensor.rand(2, 3).reshape(v1, 3)
|
||||
t2 = Tensor.rand(6, 3).reshape(v2, 3)
|
||||
t = t1.cat(t2)
|
||||
buffers = t.lazydata.get_variable_buffers()
|
||||
assert len(buffers) == 2 and buffers[v1].realize().realized.toCPU() == 2 and buffers[v2].realize().realized.toCPU() == 6
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
#!/usr/bin/env python
|
||||
import unittest
|
||||
from tinygrad.shape.symbolic import MulNode, SumNode, Variable, NumNode, Node
|
||||
from tinygrad.shape.symbolic import MulNode, SumNode, Variable, NumNode, sym_vars
|
||||
|
||||
class TestSymbolic(unittest.TestCase):
|
||||
def helper_test_variable(self, v, n, m, s):
|
||||
|
@ -240,6 +240,36 @@ class TestSymbolicNumeric(unittest.TestCase):
|
|||
def test_times_2_plus_3_div_4(self): self.helper_test_numeric(lambda x: (x*2 + 3)//4)
|
||||
def test_times_2_plus_3_div_4_mod_4(self): self.helper_test_numeric(lambda x: ((x*2 + 3)//4)%4)
|
||||
|
||||
class TestSymbolicVars(unittest.TestCase):
|
||||
def test_simple(self):
|
||||
z = NumNode(0)
|
||||
a = Variable("a", 0, 10)
|
||||
b = Variable("b", 0, 10)
|
||||
c = Variable("c", 0, 10)
|
||||
assert z.vars() == z.vars() == []
|
||||
assert a.vars() == a.vars() == [a]
|
||||
m = MulNode(a, 3)
|
||||
assert m.vars() == [a]
|
||||
s = SumNode([a, b, c])
|
||||
assert s.vars() == [a, b, c]
|
||||
|
||||
def test_compound(self):
|
||||
a = Variable("a", 0, 10)
|
||||
b = Variable("b", 0, 10)
|
||||
c = Variable("c", 0, 10)
|
||||
# TODO: update this after we support symbolic * symbolic
|
||||
assert (a + b * c).vars() == [a, b]
|
||||
assert (a % 3 + b // 5).vars() == [a, b]
|
||||
assert (a + b + c - a).vars() == [b, c]
|
||||
|
||||
def test_sym_vars(self):
|
||||
a = Variable("a", 0, 10)
|
||||
b = Variable("b", 0, 10)
|
||||
assert sym_vars(1) == []
|
||||
assert sym_vars(a) == [a]
|
||||
assert sym_vars(a+b) == [a, b]
|
||||
assert sym_vars(a*3) == [a]
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
|
|
|
@ -9,6 +9,7 @@ from tinygrad.helpers import GRAPH, DEBUG, prod, getenv, DType, dtypes, flatten,
|
|||
from tinygrad.runtime.ops_cpu import RawNumpyBuffer
|
||||
from tinygrad.runtime.ops_disk import RawDiskBuffer
|
||||
from tinygrad.shape.shapetracker import MovementOps, ShapeTracker, View, get_contraction
|
||||
from tinygrad.shape.symbolic import Variable, sym_vars
|
||||
from tinygrad.ops import Compiled, Interpreted, UnaryOps, BinaryOps, TernaryOps, ReduceOps, LoadOps, OpType, LazyOp
|
||||
from tinygrad.runtime.lib import RawBufferMapped, RawConst, RawBuffer
|
||||
|
||||
|
@ -270,6 +271,7 @@ class LazyBuffer:
|
|||
def buffers(self) -> Tuple[LazyBuffer, ...]: return (self,)
|
||||
def map_buffers(self, real_srcs: Dict[Any, Any]): return real_srcs.get(self, self)
|
||||
def get_lazyops(self) -> List[Any]: return []
|
||||
def get_variable_buffers(self) -> Dict[Variable, LazyBuffer]: return {v:LazyBuffer.loadop(LoadOps.FROM, (1,), dtypes.int32, self.device, src=LazyBuffer.fromCPU(np.array([v.val], dtype=np.int32))) for s in self.shape for v in sym_vars(s)}
|
||||
def replace_with_movement_ops(self: LazyBuffer, ops:List[Tuple[MovementOps, Any]]) -> LazyBuffer:
|
||||
y = self
|
||||
for op, arg in ops: y = MOVEMENT_OPS_DISPATCHER[op](y, arg)
|
||||
|
|
|
@ -9,6 +9,7 @@ from typing import List, Dict, Callable, Tuple, Type, Union, Optional, Any
|
|||
# symbolic matches the Python behavior, but the code output is agnostic, and will never have negative numbers in div or mod
|
||||
|
||||
def is_sym_int(x: Any) -> bool: return isinstance(x, int) or isinstance(x, Node)
|
||||
def sym_vars(x: Union[Node, int]) -> List[Variable]: return [] if isinstance(x, int) else x.vars()
|
||||
|
||||
class Node:
|
||||
b: Union[Node, int]
|
||||
|
|
Loading…
Reference in New Issue