1
0
Fork 0
tinygrab/tinygrad/shape/symbolic.py

245 lines
10 KiB
Python

from __future__ import annotations
from abc import abstractmethod
import functools
from math import gcd
from tinygrad.helpers import partition
from typing import List, Dict, Callable, Tuple, Type, Union, Optional
# NOTE: Python has different behavior for negative mod and floor div than c
# symbolic matches the Python behavior, but the code output is agnostic, and will never have negative numbers in div or mod
class Node:
b: int
min: int
max: int
def render(self, ops=None, ctx=None, strip_parens=False) -> str:
if ops is None: ops = render_python
assert self.__class__ in (Variable, NumNode) or self.min != self.max
ret = ops[type(self)](self, ops, ctx)
if strip_parens and ret[0] == '(' and ret[-1] == ')': ret = ret[1:-1]
return ret
def vars(self): return []
@functools.cached_property
def key(self) -> str: return self.render(ctx="DEBUG")
@functools.cached_property
def hash(self) -> int: return hash(self.key)
def __repr__(self): return "<"+self.key+">"
def __hash__(self): return self.hash
def __eq__(self, other:object) -> bool:
if not isinstance(other, Node): return NotImplemented
return self.key == other.key
def __neg__(self): return self*-1
def __add__(self, b:Union[Node, int]): return Variable.sum([self, b if isinstance(b, Node) else Variable.num(b)])
def __sub__(self, b:Union[Node, int]): return self+-b
def __ge__(self, b:int): return (-self) < (-b+1)
def __lt__(self, b:int):
lhs = self
if isinstance(lhs, SumNode):
muls, others = partition(lhs.nodes, lambda x: isinstance(x, MulNode) and x.b > 0 and x.max >= b)
if len(muls):
# NOTE: gcd in python 3.8 takes exactly 2 args
mul_gcd = muls[0].b
for x in muls[1:]: mul_gcd = gcd(mul_gcd, x.b)
if b%mul_gcd == 0:
all_others = Variable.sum(others)
#print(mul_gcd, muls, all_others)
if all_others.min >= 0 and all_others.max < mul_gcd:
# TODO: should we divide both by mul_gcd here?
lhs = Variable.sum(muls)
return create_node(LtNode(lhs, b))
def __mul__(self, b:int):
if b == 0: return NumNode(0)
if b == 1: return self
return create_node(MulNode(self, b))
# *** complex ops ***
def __floordiv__(self, b:int, factoring_allowed=True):
assert b != 0
if b < 0: return (self//-b)*-1
if b == 1: return self
# the numerator of div is not allowed to be negative
if self.min < 0:
offset = self.min//b
# factor out an "offset" to make the numerator positive. don't allowing factoring again
return (self + -offset*b).__floordiv__(b, factoring_allowed=False) + offset
return create_node(DivNode(self, b))
def __mod__(self, b:int):
assert b > 0
if b == 1: return NumNode(0)
if self.min >= 0 and self.max < b: return self
if self.min < 0: return (self - ((self.min//b)*b)) % b
return create_node(ModNode(self, b))
@staticmethod
def num(num:int) -> NumNode: return NumNode(num)
@staticmethod
def factorize(nodes:List[Node]):
mul_groups: Dict[Node, int] = {}
for x in nodes:
a,b = (x.a,x.b) if isinstance(x, MulNode) else (x,1)
mul_groups[a] = mul_groups.get(a, 0) + b
return [MulNode(a, b_sum) if b_sum != 1 else a for a, b_sum in mul_groups.items() if b_sum != 0]
@staticmethod
def sum(nodes:List[Node]) -> Node:
nodes = [x for x in nodes if x.max or x.min]
if not nodes: return NumNode(0)
if len(nodes) == 1: return nodes[0]
new_nodes: List[Node] = []
num_node_sum = 0
# flatten all sumnodes and gather numnodes
for node in nodes:
if node.__class__ not in (NumNode, SumNode): new_nodes.append(node)
elif node.__class__ is NumNode: num_node_sum += node.b
elif isinstance(node, SumNode): # mypy wants the isinstance
for sub_node in node.flat_components:
if sub_node.__class__ is NumNode: num_node_sum += sub_node.b
else: new_nodes.append(sub_node)
if len(new_nodes) > 1 and len(set([x.a if isinstance(x, MulNode) else x for x in new_nodes])) < len(new_nodes):
new_nodes = Node.factorize(new_nodes)
if num_node_sum: new_nodes.append(NumNode(num_node_sum))
return create_rednode(SumNode, new_nodes) if len(new_nodes) > 1 else new_nodes[0] if len(new_nodes) == 1 else NumNode(0)
@staticmethod
def ands(nodes:List[Node]) -> Node:
if not nodes: return NumNode(1)
if len(nodes) == 1: return nodes[0]
if any(x.min == x.max == 0 for x in nodes): return NumNode(0)
# filter 1s
nodes = [x for x in nodes if x.min != x.max]
return create_rednode(AndNode, nodes) if len(nodes) > 1 else (nodes[0] if len(nodes) == 1 else NumNode(1))
# 4 basic node types
class Variable(Node):
def __new__(cls, expr:Optional[str], nmin:int, nmax:int):
assert nmin >= 0 and nmin <= nmax
if nmin == nmax: return NumNode(nmin)
return super().__new__(cls)
def __init__(self, expr:Optional[str], nmin:int, nmax:int):
self.expr, self.min, self.max = expr, nmin, nmax
def vars(self): return [self]
class NumNode(Node):
def __init__(self, num:int):
self.b, self.min, self.max = num, num, num
def create_node(ret:Node):
assert ret.min <= ret.max, f"min greater than max! {ret.min} {ret.max} when creating {type(ret)} {ret}"
if ret.min == ret.max: return NumNode(ret.min)
return ret
class OpNode(Node):
def __init__(self, a:Node, b:int):
self.a, self.b = a, b
self.min, self.max = self.get_bounds()
def vars(self): return self.a.vars()
@abstractmethod
def get_bounds(self) -> Tuple[int, int]: pass
class LtNode(OpNode):
def __mul__(self, b: int): return (self.a*b) < (self.b*b)
def __floordiv__(self, b: int, _=False): return (self.a//b) < (self.b//b)
def get_bounds(self) -> Tuple[int, int]: return int(self.a.max < self.b), int(self.a.min < self.b)
class MulNode(OpNode):
def __mul__(self, b: int): return self.a*(self.b*b) # two muls in one mul
def __floordiv__(self, b: int, factoring_allowed=False): # NOTE: mod negative isn't handled right
if self.b % b == 0: return self.a*(self.b//b)
if b % self.b == 0 and self.b > 0: return self.a//(b//self.b)
return Node.__floordiv__(self, b, factoring_allowed)
def __mod__(self, b: int):
a = (self.a * (self.b%b))
return Node.__mod__(a, b)
def get_bounds(self) -> Tuple[int, int]:
return (self.a.min*self.b, self.a.max*self.b) if self.b >= 0 else (self.a.max*self.b, self.a.min*self.b)
class DivNode(OpNode):
def __floordiv__(self, b: int, _=False): return self.a//(self.b*b) # two divs is one div
def get_bounds(self) -> Tuple[int, int]:
assert self.a.min >= 0
return self.a.min//self.b, self.a.max//self.b
class ModNode(OpNode):
def __floordiv__(self, b: int, factoring_allowed=True):
if (self.b % b == 0): return (self.a//b) % (self.b//b) # put the div inside mod
return Node.__floordiv__(self, b, factoring_allowed)
def get_bounds(self) -> Tuple[int, int]:
assert self.a.min >= 0
return (0, self.b-1) if self.a.max - self.a.min >= self.b or (self.a.min != self.a.max and self.a.min%self.b >= self.a.max%self.b) else (self.a.min%self.b, self.a.max%self.b)
class RedNode(Node):
def __init__(self, nodes:List[Node]): self.nodes = nodes
def vars(self): return functools.reduce(lambda l,x: l+x.vars(), self.nodes, [])
class SumNode(RedNode):
def __mul__(self, b: int): return Node.sum([x*b for x in self.nodes]) # distribute mul into sum
def __floordiv__(self, b: int, factoring_allowed=True):
if b == 1: return self
if not factoring_allowed: return Node.__floordiv__(self, b, factoring_allowed)
factors: List[Node] = []
nofactor_mul: List[Node] = []
nofactor_nonmul: List[Node] = []
for x in self.flat_components:
if x.__class__ is NumNode and x.b%b == 0: factors.append(x)
elif x.__class__ is MulNode: factors.append(x) if x.b%b == 0 else nofactor_mul.append(x)
else: nofactor_nonmul.append(x)
if factors: # factor out largest possible gcd
factor_term = [x.a * x.b//b if isinstance(x, MulNode) else NumNode(x.b//b) for x in factors]
if nofactor_mul and not nofactor_nonmul:
gcds = [gcd(x.b, b) for x in nofactor_mul]
if (t := min(gcds)) > 1 and all(x.b%t == 0 for x in nofactor_mul):
nofactor_term = [Node.sum([x.a * x.b//t for x in nofactor_mul if isinstance(x, MulNode)])//(b//t)] # mypy wants the isinstance
else:
nofactor_term = [Node.sum(nofactor_mul)//b] if nofactor_mul else []
else:
nofactor_term = [Node.sum(nofactor_mul+nofactor_nonmul)//b] if nofactor_mul + nofactor_nonmul else []
return Node.sum(factor_term + nofactor_term)
for m in nofactor_mul:
if m.b > 1 and b%m.b == 0: return (self//m.b)//(b//m.b)
return Node.__floordiv__(self, b, factoring_allowed)
def __mod__(self, b: int):
new_nodes: List[Node] = []
for x in self.nodes:
if x.__class__ is NumNode: new_nodes.append(Variable.num(x.b%b))
elif isinstance(x, MulNode): new_nodes.append(x.a * (x.b%b))
else: new_nodes.append(x)
return Node.__mod__(Node.sum(new_nodes), b)
@property
def flat_components(self): # recursively expand sumnode components
new_nodes = []
for x in self.nodes: new_nodes += (x.flat_components if isinstance(x, SumNode) else [x])
return new_nodes
class AndNode(RedNode):
def __mul__(self, b: int): Variable.ands([x*b for x in self.nodes])
def __floordiv__(self, b: int, _=True): return Variable.ands([x//b for x in self.nodes])
def create_rednode(typ:Type[RedNode], nodes:List[Node]):
ret = typ(nodes)
if typ == SumNode: ret.min, ret.max = (sum([x.min for x in nodes]), sum([x.max for x in nodes]))
elif typ == AndNode: ret.min, ret.max = (min([x.min for x in nodes]), max([x.max for x in nodes]))
return create_node(ret)
render_python: Dict[Type, Callable] = {
Variable: lambda self,ops,ctx: f"{self.expr}[{self.min}-{self.max}]" if ctx == "DEBUG" else f"{self.expr}",
NumNode: lambda self,ops,ctx: f"{self.b}",
MulNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}*{self.b})",
DivNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}//{self.b})",
ModNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}%{self.b})",
LtNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}<{self.b})",
SumNode: lambda self,ops,ctx: f"({'+'.join(sorted([x.render(ops,ctx) for x in self.nodes]))})",
AndNode: lambda self,ops,ctx: f"({' and '.join(sorted([x.render(ops,ctx) for x in self.nodes]))})"
}