1
0
Fork 0

sum_combine_num

pull/555/head
George Hotz 2023-02-11 14:48:31 -08:00
parent a4f5f2ff8b
commit 7a7046f264
2 changed files with 9 additions and 2 deletions

View File

@ -73,8 +73,11 @@ class TestSymbolic(unittest.TestCase):
def test_mod_factor(self):
# this is technically wrong, if b is 0 the output will be negative
self.helper_test_variable(Variable.sum([Variable.num(-29), Variable("a", 0, 10), Variable("b", 0, 10)*28]) % 28, -1, 9, "((-1+a)%28)")
self.helper_test_variable(Variable.sum([Variable.num(-29), Variable("a", 0, 100), Variable("b", 0, 10)*28]) % 28, -1, 27, "((-1+a)%28)")
self.helper_test_variable(Variable.sum([Variable.num(-29), Variable("a", 0, 10), Variable("b", 0, 10)*28]) % 28, -1, 9, "((a+-1)%28)")
self.helper_test_variable(Variable.sum([Variable.num(-29), Variable("a", 0, 100), Variable("b", 0, 10)*28]) % 28, -1, 27, "((a+-1)%28)")
def test_sum_combine_num(self):
self.helper_test_variable(Variable.sum([Variable.num(-29), Variable("a", 0, 10), Variable.num(23)]), -6, 4, "(a+-6)")
def test_div_factor(self):
# TODO: this isn't right

View File

@ -65,6 +65,10 @@ class Node:
@staticmethod
def sum(nodes:List[Node]) -> Node:
nodes, num_nodes = partition(nodes, lambda x: not isinstance(x, NumNode))
num_sum = sum([x.b for x in num_nodes])
if num_sum != 0: nodes.append(NumNode(num_sum))
if any([isinstance(x, SumNode) for x in nodes]):
nodes, sum_nodes = partition(nodes, lambda x: not isinstance(x, SumNode))
for x in sum_nodes: nodes += x.nodes