
Node.vars() returns a set and properly dedup (#2356)

* dedup RedNode.vars()

* vars returns a set

* fix more vars

* unused import

* update to_movement_ops

* comment
pull/2358/head^2
chenyu 2023-11-18 17:44:52 -05:00 committed by GitHub
parent 0443cbfbb9
commit d7d078c7f9
8 changed files with 30 additions and 24 deletions
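
In short: Node.vars() (and the vars() helpers built on it) now return Set[Variable] instead of List[Variable], so duplicates collapse automatically and call sites combine results with set union instead of list concatenation plus a dedup pass. A minimal sketch of what the change means for callers, using the same Variable API as the tests in this diff (the sketch itself is not part of the diff):

    from tinygrad.shape.symbolic import Variable

    a = Variable("a", 0, 10)
    b = Variable("b", 0, 10)

    # before this commit: (a * a + b).vars() -> [a, a, b]  (a list, with duplicates)
    # after: a deduplicated set
    assert (a * a + b).vars() == {a, b}

    # combining the variables of two expressions is now a set union
    assert (a * 3).vars() | (b + 1).vars() == {a, b}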


@@ -71,8 +71,8 @@ def st_equivalent(st1: ShapeTracker, st2: ShapeTracker):
   # always invalid
   if valid1 == 0 and valid2 == 0: return True
-  var1 = set(idx1.vars() + valid1.vars())
-  var2 = set(idx2.vars() + valid2.vars())
+  var1 = idx1.vars() | valid1.vars()
+  var2 = idx2.vars() | valid2.vars()
   # Maybe there are cases that vars are different yet the sts are the same?
   if var1 != var2: return False
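
This call-site change follows directly from the new return type: with lists, the deduplicated union was set(list1 + list2); with sets, + is not defined and | does the same job in one step. A plain-set sketch (x and y stand in for the idx/valid vars() results):

    x, y = {"a", "b"}, {"b", "c"}
    assert x | y == {"a", "b", "c"}   # what the new code computes
    # set(x + y) would now raise TypeError: unsupported operand type(s) for +: 'set' and 'set'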


@@ -267,20 +267,25 @@ class TestSymbolicVars(unittest.TestCase):
     a = Variable("a", 0, 10)
     b = Variable("b", 0, 10)
     c = Variable("c", 0, 10)
-    assert z.vars() == z.vars() == []
-    assert a.vars() == a.vars() == [a]
+    assert z.vars() == z.vars() == set()
+    assert a.vars() == a.vars() == {a}
     m = MulNode(a, 3)
-    assert m.vars() == [a]
+    assert m.vars() == {a}
     s = SumNode([a, b, c])
-    assert s.vars() == [a, b, c]
+    assert s.vars() == {a, b, c}

   def test_compound(self):
     a = Variable("a", 0, 10)
     b = Variable("b", 0, 10)
     c = Variable("c", 0, 10)
-    assert (a + b * c).vars() == [a, b, c]
-    assert (a % 3 + b // 5).vars() == [a, b]
-    assert (a + b + c - a).vars() == [b, c]
+    assert (a + b * c).vars() == {a, b, c}
+    assert (a % 3 + b // 5).vars() == {a, b}
+    assert (a + b + c - a).vars() == {b, c}
+
+  def test_dedup(self):
+    a = Variable("a", 0, 10)
+    assert (a * a).vars() == {a}
+    assert (a//4 + a//6).vars() == {a}

 class TestSymbolicMinMax(unittest.TestCase):
   def test_min_max_known(self):
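
Besides exercising the new dedup, switching the expected values from lists to sets makes these assertions order-insensitive; the old list comparisons silently encoded the traversal order of the node tree. A trivial illustration:

    assert [1, 2] != [2, 1]   # list equality is order-sensitive
    assert {1, 2} == {2, 1}   # set equality is not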


@@ -162,7 +162,8 @@ def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node) -> Tuple[Tup
   if valid.min == 0 and isinstance(idxy, SumNode):
     nodes = valid.nodes if isinstance(valid, AndNode) else [valid]
     val_dict: Dict[Node, Any] = {}
-    idxy_flat_var = [(i, i.vars()[0]) for i in idxy.flat_components if not isinstance(i, NumNode)]
+    # TODO: is this correct? should it check there's only one variable from each component?
+    idxy_flat_var = [(i, list(i.vars())[0]) for i in idxy.flat_components if not isinstance(i, NumNode)]
     for node in nodes:
       assert isinstance(node, LtNode)
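
The new TODO flags a real subtlety: set iteration order is arbitrary, so list(i.vars())[0] is only well-defined if each non-constant flat component contains exactly one variable. A hypothetical stricter helper (not in the diff) that asserts that assumption instead of silently picking an element:

    def single_var(node):
      # hypothetical guard: fail loudly if a component mentions more than one variable
      vs = node.vars()
      assert len(vs) == 1, f"expected exactly one variable, got {vs}"
      return next(iter(vs))

    # idxy_flat_var = [(i, single_var(i)) for i in idxy.flat_components if not isinstance(i, NumNode)]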


@@ -80,7 +80,7 @@ def get_single_root(root:LazyBuffer) -> LazyBuffer: return get_single_root(root.
 def get_movementroot(root:LazyBuffer, allow_contiguous=False) -> LazyBuffer: return get_movementroot(cast(LazyBuffer, root.op.src[0]), allow_contiguous) if not root.realized and (root.optype == MovementOps or (root.op.op == LoadOps.CONTIGUOUS and allow_contiguous and root.op.src[0].st.contiguous)) else root
 def get_movementroot_contiguous(x:LazyBuffer) -> LazyBuffer: return get_movementroot_contiguous(cast(LazyBuffer, x.op.src[0])) if not x.realized and x.op.op == LoadOps.CONTIGUOUS else (get_movementroot(x, True) if x.optype == MovementOps and x.st.contiguous else x)
-def vars_from_ast(ast:LazyOp) -> List[Variable]: return dedup(functools.reduce(operator.add, [x.arg.st.vars() for x in ast.get_lazyops() if x.op in BufferOps], []))
+def vars_from_ast(ast:LazyOp) -> Set[Variable]: return functools.reduce(operator.or_, [x.arg.st.vars() for x in ast.get_lazyops() if x.op in BufferOps], set())
 lazycache: WeakValueDictionary = WeakValueDictionary()
 def create_lazybuffer(device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dtype:DType, base:Optional[LazyBuffer]=None):
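
The rewritten vars_from_ast is a fold of set union over the per-BufferOp variable sets, with set() as the initial value so an AST with no BufferOps yields an empty set rather than crashing reduce. The same shape with plain sets:

    import functools, operator

    per_op_vars = [{"a"}, {"a", "b"}, set()]
    assert functools.reduce(operator.or_, per_op_vars, set()) == {"a", "b"}
    assert functools.reduce(operator.or_, [], set()) == set()   # empty input is fine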


@@ -180,7 +180,7 @@ class BatchExecutor:
 class ASTRunner:
   def __init__(self, ast:Optional[LazyOp]):
     if ast is None:
-      self.op_estimate, self.mem_estimate, self.vars = 0, 0, []
+      self.op_estimate, self.mem_estimate, self.vars = 0, 0, set()
     else:
       info = get_lazyop_info(ast)
       self.op_estimate, self.mem_estimate = info.flops, info.mem_estimate


@@ -2,9 +2,9 @@
 from __future__ import annotations
 import functools, operator
 from dataclasses import dataclass
-from typing import Tuple, List, Optional, Dict, cast
+from typing import Tuple, List, Optional, Dict, Set, cast
 from tinygrad.ops import MovementOps
-from tinygrad.helpers import prod, DEBUG, dedup, merge_dicts
+from tinygrad.helpers import prod, DEBUG, merge_dicts
 from tinygrad.shape.symbolic import Variable, MulNode, Node, SumNode, NumNode, sint
 from tinygrad.shape.view import View
@@ -79,7 +79,7 @@ class ShapeTracker:
   def size(self): return 0 if (0 in self.shape) else self.expr_idxs()[0].max+1
-  def vars(self) -> List[Variable]: return dedup(functools.reduce(operator.add, [v.vars() for v in self.views], []))
+  def vars(self) -> Set[Variable]: return functools.reduce(operator.or_, [v.vars() for v in self.views], set())
   @property
   def var_vals(self) -> Dict[Variable, int]: return merge_dicts([dict([v.unbind()]) for v in self.vars()])
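
var_vals is unchanged but now iterates a set: each bound variable unbinds to a singleton {unbound_var: val} dict and merge_dicts combines them, which is order-independent as long as each variable appears once, so the arbitrary set iteration order is harmless. A sketch with plain dicts (assuming merge_dicts simply unions non-conflicting dicts):

    from tinygrad.helpers import merge_dicts

    assert merge_dicts([{"i": 3}, {"j": 5}]) == {"i": 3, "j": 5}
    assert merge_dicts([{"j": 5}, {"i": 3}]) == {"i": 3, "j": 5}   # order doesn't matter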


@@ -3,7 +3,7 @@ import functools
 from math import gcd
 from itertools import product
 from tinygrad.helpers import partition
-from typing import List, Dict, Callable, Tuple, Type, Union, Optional, Any, Iterator
+from typing import List, Dict, Callable, Tuple, Type, Union, Optional, Any, Iterator, Set
 # NOTE: Python has different behavior for negative mod and floor div than c
 # symbolic matches the Python behavior, but the code output is agnostic, and will never have negative numbers in div or mod
@@ -18,7 +18,7 @@ class Node:
     if ops is None: ops = render_python
     assert self.__class__ in (Variable, NumNode) or self.min != self.max
     return ops[type(self)](self, ops, ctx)
-  def vars(self): return []
+  def vars(self) -> Set[Variable]: return set()
   def expand_idx(self) -> VariableOrNum: return next((v for v in self.vars() if v.expr is None), NumNode(0))
 # expand a Node into List[Node] that enumerates the underlying Variables from min to max
@@ -149,7 +149,7 @@ class Variable(Node):
   def unbind(self) -> Tuple[Variable, int]:
     assert self.val is not None, f"cannot unbind {self}"
     return Variable(self.expr, self.min, self.max), self.val
-  def vars(self): return [self]
+  def vars(self): return {self}
   def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return var_vals[self] if self in var_vals else self
 class NumNode(Node):
@@ -173,7 +173,7 @@ class OpNode(Node):
   def __init__(self, a:Node, b:Union[Node, int]):
     self.a, self.b = a, b
     self.min, self.max = self.get_bounds()
-  def vars(self): return self.a.vars() + (self.b.vars() if isinstance(self.b, Node) else [])
+  def vars(self): return self.a.vars() | (self.b.vars() if isinstance(self.b, Node) else set())
   def get_bounds(self) -> Tuple[int, int]: raise NotImplementedError("must be implemented")
 class LtNode(OpNode):
@@ -221,7 +221,7 @@ class ModNode(OpNode):
 class RedNode(Node):
   def __init__(self, nodes:List[Node]): self.nodes = nodes
-  def vars(self): return functools.reduce(lambda l,x: l+x.vars(), self.nodes, [])
+  def vars(self) -> Set[Variable]: return functools.reduce(lambda l,x: l | x.vars(), self.nodes, set())
 class SumNode(RedNode):
   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
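
Taken together, the recursion now reads: the Node base contributes set(), Variable contributes {self}, OpNode unions its two operands, and RedNode folds union over its children, so duplicates collapse at every level of the tree. An end-to-end check, using the same API as the tests above:

    from tinygrad.shape.symbolic import Variable

    a, b = Variable("a", 0, 10), Variable("b", 0, 10)
    expr = a * 4 + a % 3 + b   # `a` appears in two components of the sum
    assert expr.vars() == {a, b}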


@@ -2,8 +2,8 @@ from __future__ import annotations
 import functools, operator
 from dataclasses import dataclass
 from typing import Tuple, List, Optional, Dict, cast
-from tinygrad.helpers import prod, all_int, dedup
-from tinygrad.shape.symbolic import Node, NumNode, Variable, VariableOrNum, sint
+from tinygrad.helpers import prod, all_int
+from tinygrad.shape.symbolic import Node, NumNode, Variable, VariableOrNum, Set, sint
 @functools.lru_cache(maxsize=None)
 def filter_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> Tuple[int, ...]:
@@ -30,9 +30,9 @@ class View:
     contiguous = offset == 0 and mask is None and all(s1 == s2 for s1,s2 in zip(strides, strides_for_shape(shape)))
     return View(shape, strides, offset, mask, contiguous)
-  def vars(self) -> List[Variable]:
+  def vars(self) -> Set[Variable]:
     flatten_mask = tuple(x for m in self.mask for x in m) if self.mask is not None else tuple()
-    return dedup(functools.reduce(operator.add, [x.vars() for x in self.shape+self.strides+(self.offset,)+flatten_mask if isinstance(x, Node)], []))
+    return functools.reduce(operator.or_, [x.vars() for x in self.shape+self.strides+(self.offset,)+flatten_mask if isinstance(x, Node)], set())
   def unbind(self) -> View:
     unbound_vars:Dict[VariableOrNum,Node] = {v: v.unbind()[0] for v in self.vars() if v.val is not None}