1
0
Fork 0

strip whitespace

pull/1064/head
George Hotz 2023-06-27 10:11:43 -07:00
parent 23648538fa
commit c8d87eb8d4
2 changed files with 10 additions and 8 deletions

View File

@ -0,0 +1,2 @@
#!/bin/bash
find tinygrad -type f -name "*.py" -exec sed -i '' 's/ *$//' '{}' ';'

View File

@ -82,7 +82,7 @@ class Node:
if sub_node.__class__ is NumNode: num_node_sum += sub_node.b
else: new_nodes.append(sub_node)
if len(new_nodes) > 1 and len(set([x.a if isinstance(x, MulNode) else x for x in new_nodes])) < len(new_nodes):
if len(new_nodes) > 1 and len(set([x.a if isinstance(x, MulNode) else x for x in new_nodes])) < len(new_nodes):
new_nodes = Node.factorize(new_nodes)
if num_node_sum: new_nodes.append(NumNode(num_node_sum))
return create_rednode(SumNode, new_nodes) if len(new_nodes) > 1 else new_nodes[0] if len(new_nodes) == 1 else NumNode(0)
@ -163,26 +163,26 @@ class SumNode(RedNode):
factors: List[Node] = []
nofactor_mul: List[Node] = []
nofactor_nonmul: List[Node] = []
for x in self.flat_components:
for x in self.flat_components:
if x.__class__ is NumNode and x.b%b == 0: factors.append(x)
elif x.__class__ is MulNode: factors.append(x) if x.b%b == 0 else nofactor_mul.append(x)
elif x.__class__ is MulNode: factors.append(x) if x.b%b == 0 else nofactor_mul.append(x)
else: nofactor_nonmul.append(x)
if factors: # factor out largest possible gcd
factor_term = [x.a * x.b//b if isinstance(x, MulNode) else NumNode(x.b//b) for x in factors]
if nofactor_mul and not nofactor_nonmul:
gcds = [gcd(x.b, b) for x in nofactor_mul]
if (t := min(gcds)) > 1 and all([x.b%t == 0 for x in nofactor_mul]):
if (t := min(gcds)) > 1 and all([x.b%t == 0 for x in nofactor_mul]):
nofactor_term = [Node.sum([x.a * x.b//t for x in nofactor_mul if isinstance(x, MulNode)])//(b//t)] # mypy wants the isinstance
else:
nofactor_term = [Node.sum(nofactor_mul)//b] if nofactor_mul else []
else:
else:
nofactor_term = [Node.sum(nofactor_mul+nofactor_nonmul)//b] if nofactor_mul + nofactor_nonmul else []
return Node.sum(factor_term + nofactor_term)
for m in nofactor_mul:
if m.b > 1 and b%m.b == 0: return (self//m.b)//(b//m.b)
return Node.__floordiv__(self, b, factoring_allowed)
def __mod__(self, b: int):
new_nodes = []
for x in self.nodes:
@ -190,7 +190,7 @@ class SumNode(RedNode):
elif isinstance(x, MulNode): new_nodes.append(x.a * (x.b%b))
else: new_nodes.append(x)
return Node.__mod__(Node.sum(new_nodes), b)
@property
def flat_components(self): # recursively expand sumnode components
new_nodes = []