diff --git a/test/test_symbolic.py b/test/test_symbolic.py index 53158d668..727fdec03 100644 --- a/test/test_symbolic.py +++ b/test/test_symbolic.py @@ -73,8 +73,11 @@ class TestSymbolic(unittest.TestCase): def test_mod_factor(self): # this is technically wrong, if b is 0 the output will be negative - self.helper_test_variable(Variable.sum([Variable.num(-29), Variable("a", 0, 10), Variable("b", 0, 10)*28]) % 28, -1, 9, "((-1+a)%28)") - self.helper_test_variable(Variable.sum([Variable.num(-29), Variable("a", 0, 100), Variable("b", 0, 10)*28]) % 28, -1, 27, "((-1+a)%28)") + self.helper_test_variable(Variable.sum([Variable.num(-29), Variable("a", 0, 10), Variable("b", 0, 10)*28]) % 28, -1, 9, "((a+-1)%28)") + self.helper_test_variable(Variable.sum([Variable.num(-29), Variable("a", 0, 100), Variable("b", 0, 10)*28]) % 28, -1, 27, "((a+-1)%28)") + + def test_sum_combine_num(self): + self.helper_test_variable(Variable.sum([Variable.num(-29), Variable("a", 0, 10), Variable.num(23)]), -6, 4, "(a+-6)") def test_div_factor(self): # TODO: this isn't right diff --git a/tinygrad/shape/symbolic.py b/tinygrad/shape/symbolic.py index 942333465..1964fd7d7 100644 --- a/tinygrad/shape/symbolic.py +++ b/tinygrad/shape/symbolic.py @@ -65,6 +65,10 @@ class Node: @staticmethod def sum(nodes:List[Node]) -> Node: + nodes, num_nodes = partition(nodes, lambda x: not isinstance(x, NumNode)) + num_sum = sum([x.b for x in num_nodes]) + if num_sum != 0: nodes.append(NumNode(num_sum)) + if any([isinstance(x, SumNode) for x in nodes]): nodes, sum_nodes = partition(nodes, lambda x: not isinstance(x, SumNode)) for x in sum_nodes: nodes += x.nodes