291 lines
14 KiB
Python
291 lines
14 KiB
Python
# ShapeTracker allows movement operations to a buffer that don't require a copy to be made.
|
|
from __future__ import annotations
|
|
from enum import Enum, auto
|
|
import functools
|
|
from typing import Dict, Tuple, Union, List, Optional, Callable, cast, NamedTuple
|
|
from tinygrad.helpers import prod, DEBUG
|
|
from tinygrad.shape.symbolic import Variable, MulNode, NumNode, Node, SumNode
|
|
|
|
# these ops live here
|
|
class MovementOps(Enum):
  """View-only movement operations a ShapeTracker can apply.

  Members are numbered 1..6 by auto() in declaration order.
  """
  RESHAPE = auto()
  PERMUTE = auto()
  EXPAND = auto()
  PAD = auto()
  SHRINK = auto()
  STRIDE = auto()
|
|
|
|
@functools.lru_cache(maxsize=None)
def to_shape_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> Tuple[Tuple[int, int], ...]:
  """Collapse adjacent dimensions that can be addressed as one into (size, stride) pairs.

  A dimension is folded into the previous pair when any of these holds:
    * its stride is nonzero and the previous stride equals size*stride (row-major adjacency),
    * the previous pair has size 1, or
    * both strides are 0.
  """
  assert len(shape) == len(strides)
  merged: List[Tuple[int, int]] = []
  for dim, st in zip(shape, strides):
    if merged:
      prev_dim, prev_st = merged[-1]
      if (st != 0 and prev_st == dim*st) or prev_dim == 1 or (st == 0 and prev_st == 0):
        merged[-1] = (prev_dim * dim, st)
        continue
    merged.append((dim, st))
  return tuple(merged)
|
|
|
|
@functools.lru_cache(maxsize=None)
def is_contiguous(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> bool:
  """Return True when every non-size-1 dim has the row-major stride that strides_for_shape computes for `shape`."""
  expected = strides_for_shape(shape)
  for dim, actual, want in zip(shape, strides, expected):
    # size-1 dims are never stepped over, so their stride is irrelevant
    if actual != want and dim != 1: return False
  return True
|
|
|
|
@functools.lru_cache(maxsize=None)
def filter_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> Tuple[int, ...]:
  """Normalize strides so every size-1 dimension has stride 0; other strides pass through unchanged."""
  return tuple(0 if dim == 1 else st for dim, st in zip(shape, strides))
|
|
|
|
class ViewInternal(NamedTuple):
  """Immutable field layout backing View (View.__new__ fills every field)."""
  shape:Tuple[int, ...]
  # per-dimension strides; View.__new__ gives size-1 dims stride 0
  strides:Tuple[int, ...]
  offset:int
  # one half-open (start, end) valid range per dimension, or None when every element is valid
  mask:Optional[Tuple[Tuple[int, int], ...]]
  # True only for offset 0, row-major strides and no mask
  contiguous:bool
  # (size, stride) pairs after merging adjacent dims that can be addressed together
  shape_strides:Tuple[Tuple[int, int], ...]
|
|
|
|
@functools.lru_cache(maxsize=None)
class View(ViewInternal):
  """A strided view (shape, strides, offset, optional mask) over a flat buffer.

  The class is wrapped in lru_cache, so constructing a View with the same arguments
  returns the same cached instance. As a consequence, the module-level name `View`
  is the cache wrapper, not the class object itself.
  """
  def __new__(cls, shape, strides=None, offset=0, mask=None):
    strides_from_shape = strides_for_shape(shape)
    # falsy strides (None or ()) fall back to row-major strides;
    # explicit strides get their size-1 dims normalized to stride 0
    strides = strides_from_shape if not strides else filter_strides(shape, strides)
    contiguous = offset == 0 and is_contiguous(shape, strides) and mask is None
    return super().__new__(cls, shape, strides, offset, mask, contiguous, to_shape_strides(shape, strides))
  def __init__(self, shape, strides=None, offset=0, mask=None, contiguous=False, shape_strides=()): super().__init__()

  def expr_node_mask(self, idx, valid=None) -> Node:
    """Return the validity expression for flat index `idx`.

    The result ANDs `valid` (if given) with start <= coord < end for every masked dim.
    The per-dim coordinate is recovered from the flat index as (idx // acc) % size.
    Without a mask this is just Variable.ands of `valid` (or of nothing).
    """
    expr = [valid] if valid is not None else []
    if self.mask is not None:
      acc = 1
      # walk dims from innermost to outermost so `acc` is the product of the sizes after each dim
      for ns,(x,y) in reversed(list(zip(self.shape, self.mask))):
        base = ((idx//acc) % ns)
        expr += [base >= x, base < y]
        acc *= ns
    return Variable.ands(expr)

  # generate an expression if you have a single idx variable
  def expr_node(self, idx=None) -> Node:
    """Return the buffer-index expression for a flat (row-major) index `idx`.

    When `idx` is omitted, a Variable spanning the whole view is used. Variable bounds are
    inclusive, so the largest flat index is prod(shape)-1. The previous upper bound of
    prod(shape) over-estimated the range by one element.
    """
    if idx is None: idx = Variable('idx', 0, prod(self.shape)-1)
    ret: List[Node] = [Variable.num(self.offset)] if self.offset else []
    acc = 1
    for d,s in reversed(self.shape_strides):
      ret.append(((idx//acc)%d)*s)
      acc *= d
    return Variable.sum(ret)

  # generate an expression if you have a variable or expression for each index
  def expr_idxs(self, idxs) -> Node:
    """Return offset + sum(idx_i * stride_i), skipping size-1 and stride-0 dims, which contribute nothing."""
    assert len(idxs) == len(self.shape), f"need an idx for all dimensions {idxs} vs {self.shape}"
    return Variable.sum([Variable.num(self.offset)] + [idx*st for idx,sh,st in zip(idxs, self.shape, self.strides) if sh != 1 and st != 0])
|
|
|
|
@functools.lru_cache(maxsize=None)
def idxs_to_idx(shape:Tuple[int, ...], idxs) -> Node:
  """Combine one index expression per dimension into a single row-major flat index expression."""
  assert len(idxs) == len(shape), "need an idx for all dimensions"
  terms = []
  stride = 1
  # innermost dim first: each index is scaled by the product of the sizes of the dims after it
  for tidx, dim in zip(reversed(idxs), reversed(shape)):
    terms.append(tidx * stride)
    stride *= dim
  return Variable.sum(terms)
|
|
|
|
@functools.lru_cache(maxsize=None)
def strides_for_shape(shape:Tuple[int, ...]) -> Tuple[int, ...]:
  """Return row-major (C-order) strides for `shape`, with every size-1 dimension given stride 0."""
  running = 1
  reversed_strides = []
  for dim in reversed(shape):
    reversed_strides.append(running if dim != 1 else 0)
    running *= dim
  return tuple(reversed(reversed_strides))
|
|
|
|
@functools.lru_cache(maxsize=None)
def view_from_shape(shape:Tuple[int, ...]) -> View:
  """Build the default View for `shape`: row-major strides, offset 0, no mask."""
  assert all(isinstance(x, int) for x in shape)
  row_major = strides_for_shape(shape)
  return View(tuple(shape), row_major)
|
|
|
|
@functools.lru_cache(maxsize=None)
def merge_views(vm2:View, vm1:View) -> Optional[View]:
  """Try to fold outer view `vm1` on top of inner view `vm2` into a single View.

  Returns None when vm2 is masked (not supported yet) or when some dimension of the
  composed index is not a plain stride.
  """
  if vm2.mask: return None  # this isn't supported yet
  stacked = ShapeTracker(vm1.shape, [vm2, vm1])
  combined_strides = stacked.real_strides()
  if None in combined_strides: return None
  return View(vm1.shape, combined_strides, stacked.real_offset(), vm1.mask)
|
|
|
|
@functools.lru_cache(maxsize=None)
def _reshape(view: View, new_shape:Tuple[int, ...]) -> Tuple[View, bool]:
  """Reshape `view` to `new_shape`.

  Returns (result_view, needs_extra_view). When the flag is True, the reshape could not
  be folded into `view`, and the caller must push result_view on top of the existing
  view instead of replacing it.
  """
  old_shape, old_mask, old_strides, old_offset = view.shape, view.mask, view.strides, view.offset
  # fast path: the reshape only inserts and/or removes size-1 dims, so every other dim keeps its stride and mask
  # NOTE: this is optional, but removes most calls to (expensive!) merge_views (with mask, not optional)
  if tuple(d for d in old_shape if d != 1) == tuple(d for d in new_shape if d != 1):
    kept_strides = iter([st for d, st in zip(old_shape, old_strides) if d != 1])
    moved_strides: Tuple[int, ...] = tuple(0 if d == 1 else next(kept_strides) for d in new_shape)
    moved_mask = None
    if old_mask:
      if any(d == 1 and m != (0, 1) for d, m in zip(old_shape, old_mask)):
        # a size-1 dim whose mask isn't (0, 1) is treated as masking out the whole view
        moved_mask = ((0, 0),) * len(new_shape)
      else:
        kept_mask = iter([m for d, m in zip(old_shape, old_mask) if d != 1])
        moved_mask = tuple((0, 1) if d == 1 else next(kept_mask) for d in new_shape)
    return View(new_shape, moved_strides, old_offset, moved_mask), False

  fresh_view = View(new_shape, strides_for_shape(new_shape))
  # NOTE: if it's contiguous it can't have an offset
  if view.contiguous: return fresh_view, False
  merged = merge_views(view, fresh_view)
  if merged is not None: return merged, False
  if DEBUG >= 4: print(f"WARNING: creating new view with reshape {view} -> {new_shape}")
  return fresh_view, True
|
|
|
|
@functools.lru_cache(maxsize=None)
def get_pad_args(shape:Tuple[int,...], arg:Tuple[Tuple[int, int], ...]):
  """Translate per-dim (before, after) padding into the two tuples the pad path needs.

  Returns:
    * resize bounds (-before, size+after) for each dim, and
    * the valid-region mask (before, size+before), which covers the original data in padded coordinates.
  """
  resize_bounds, valid_mask = [], []
  for size, (before, after) in zip(shape, arg):
    resize_bounds.append((-before, size + after))
    valid_mask.append((before, size + before))
  return tuple(resize_bounds), tuple(valid_mask)
|
|
|
|
@functools.lru_cache(maxsize=None)
def get_unsafe_resize_offset(strides, arg):
  """Return the offset change from moving each dim's origin to its window start: sum(stride * start)."""
  total = 0
  for st, bounds in zip(strides, arg):
    total += st * bounds[0]
  return total
|
|
|
|
class ShapeTracker:
  """A stack of Views describing how a logical shape maps onto a flat buffer without copying.

  views[-1] is the outermost view; its shape is the tracker's shape. Movement ops rewrite
  views[-1] in place, except for a reshape that cannot be folded, which appends a new view.
  Index expressions are built from views[-1] and then passed through the earlier views,
  outermost first (see _expr_idx).
  """
  __slots__ = "views"
  def __init__(self, shape:Union[ShapeTracker, Tuple[int, ...]], views:Optional[List[View]]=None):
    # explicit `views` wins (stored as given); otherwise copy another tracker's view list,
    # or start from a single default view of `shape`
    self.views: List[View] = views if views is not None else ([*cast(ShapeTracker, shape).views] if shape.__class__ is ShapeTracker else [view_from_shape(shape)])
  def __repr__(self): return f"ShapeTracker(shape={self.views[-1].shape}, views={self.views})"
  def copy(self) -> ShapeTracker: return ShapeTracker(self.views[-1].shape, [*self.views])

  @property
  def contiguous(self) -> bool: return len(self.views) == 1 and self.views[0].contiguous

  @property
  def shape(self) -> Tuple[int, ...]: return self.views[-1].shape

  @property
  def key(self) -> Tuple[View, ...]: return tuple(self.views)

  # this is the real size (ish): product of the outermost dims that have a nonzero stride,
  # so broadcast (stride-0) dims are not counted
  def size(self): return prod([s for s,st in zip(self.views[-1].shape, self.views[-1].strides) if st != 0])

  # TODO: this can be shared code between simplify and merge_views
  def real_offset(self) -> int:
    """Return the buffer offset of element 0, found by evaluating the index expression at idx == 0."""
    real_offset, mask = self.expr_node(Variable('zero', 0, 0))
    assert real_offset.__class__ is NumNode, f"how is the offset not a number? {real_offset} {mask}"
    return real_offset.b

  # these are multiview strides, value is None if it's not a simple strided dimension
  # NOTE: if a stride is not always valid, it will be None
  def real_strides(self, ignore_valid=False) -> Tuple[Optional[int], ...]:
    """Return the effective stride of each outer dim through all views.

    Each entry is one of:
      * the stride, when the dim enters the composed index only as idx*stride (or idx, meaning 1);
      * 0, when the dim does not appear in the index at all;
      * None otherwise. This includes dims that appear in the validity expression, unless ignore_valid is set.
    """
    if len(self.views) == 1 and self.views[-1].mask is None: return self.views[-1].strides
    idxs = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)]
    idx, valid = self.expr_idxs(idxs)
    ret: List[Optional[int]] = [None] * len(self.views[-1].shape)
    for this_dim in (idx.nodes if isinstance(idx, SumNode) else [idx]):
      if isinstance(this_dim, MulNode) and isinstance(this_dim.a, Variable):
        ret[idxs.index(this_dim.a)] = this_dim.b
      elif isinstance(this_dim, Variable):
        ret[idxs.index(this_dim)] = 1
    idx_vars, valid_vars = idx.vars(), valid.vars()
    for i,tidx in enumerate(idxs):
      if tidx in valid_vars and not ignore_valid: ret[i] = None
      elif tidx not in idx_vars: ret[i] = 0
    return tuple(ret)
  def unit_stride_axes(self, ignore_valid=False) -> List[int]: return [i for i,st in enumerate(self.real_strides(ignore_valid)) if st == 1]

  def _expr_idx(self, idx, valid):
    """Push (idx, valid) through every view below the outermost one, from views[-2] down to views[0]."""
    for v in reversed(self.views[0:-1]):
      valid = v.expr_node_mask(idx, valid)
      idx = v.expr_node(idx)
    return idx, valid

  def simplify(self):
    """Repeatedly merge the two outermost views for as long as merge_views can combine them."""
    while len(self.views) >= 2:
      new_view = merge_views(self.views[-2], self.views[-1])
      if new_view is None: break
      if DEBUG >= 4: print(f"st simplify : {self.views[-2]} + {self.views[-1]} = {new_view}")
      self.views = self.views[:-2] + [new_view]

  def expr_idxs(self, idxs=None):
    """Return (index_expr, valid_expr) for one index expression per dim (default: idx0..idxN-1 spanning the shape)."""
    if idxs is None: idxs = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)]
    idx = self.views[-1].expr_idxs(tuple(idxs))
    valid = self.views[-1].expr_node_mask(idxs_to_idx(self.views[-1].shape, tuple(idxs)))
    return self._expr_idx(idx, valid)

  def expr_node(self, idx='idx'):
    """Return (index_expr, valid_expr) for a single flat index; a str becomes a Variable covering the whole shape."""
    if idx.__class__ is str: idx = Variable(idx, 0, prod(self.shape)-1)
    return self._expr_idx(self.views[-1].expr_node(idx), self.views[-1].expr_node_mask(idx))

  def needs_valid(self) -> bool:
    """True when any view carries a mask, i.e. some indices may be invalid."""
    return any(v.mask is not None for v in self.views)

  # *** under this line are the movement ops ***

  def __unsafe_resize(self, arg: Tuple[Tuple[int, int], ...], mask=None):
    """Replace the outermost view with the window arg[i] = (start, end) of each dim.

    The offset moves by start*stride; starts may be negative (padding). Any existing mask
    is shifted into the new coordinates and clamped to the window, then intersected with
    `mask` if one is given. No bounds checking is done here.
    """
    offset = get_unsafe_resize_offset(self.views[-1].strides, arg)
    if self.views[-1].mask:
      # move the old mask
      nmask = tuple([(max(mx-ax, 0), min(my-ax, ay-ax)) for (mx,my),(ax,ay) in zip(self.views[-1].mask, arg)])
      # merge the masks if we have two
      mask = tuple([(max(mx1, mx2), min(my1, my2)) for (mx1, my1), (mx2, my2) in zip(nmask, mask)]) if mask is not None else nmask
    self.views[-1] = View(tuple([y-x for x,y in arg]), self.views[-1].strides, self.views[-1].offset+offset, mask)

  def pad(self, arg: Tuple[Tuple[int, int], ...]):
    """Grow each dim by (before, after) elements; the added region is masked out as invalid."""
    assert all((b>=0 and e>=0) for b,e in arg) and len(arg) == len(self.shape)
    if any(b or e for b, e in arg):
      zvarg, mask = get_pad_args(self.shape, arg)
      self.__unsafe_resize(zvarg, mask=mask)
    return self

  def shrink(self, arg: Tuple[Tuple[int, int], ...]):
    """Keep only the half-open window [start, end) of each dim."""
    # start > end would otherwise silently produce a negative dimension size in __unsafe_resize
    assert all((0 <= b <= e <= s) for s,(b,e) in zip(self.shape,arg)) and len(arg) == len(self.shape), f"invalid shrink {arg} for {self.shape}"
    self.__unsafe_resize(arg)
    return self

  def expand(self, new_shape: Tuple[int, ...]) -> ShapeTracker:
    """Broadcast size-1 (stride-0) dims to the sizes in new_shape; all other dims must already match."""
    assert len(new_shape) == len(self.views[-1].shape)
    assert all(isinstance(x, int) and (s == x or (s == 1 and st == 0)) for s,x,st in zip(self.shape, new_shape, self.views[-1].strides)), f"can't expand {self.shape} into {new_shape}"
    # NOTE: can the mask ever be (0,0)?
    mask = tuple([(((0,0) if m != (0,1) else (0,ns)) if s != ns else m) for m,s,ns in zip(self.views[-1].mask, self.shape, new_shape)]) if self.views[-1].mask else None
    self.views[-1] = View(new_shape, self.views[-1].strides, self.views[-1].offset, mask)
    return self

  def reshape(self, new_shape: Tuple[int, ...]):
    """Reshape to new_shape (same element count).

    The outermost view is replaced in place when the reshape can be folded into it;
    otherwise a new view is appended.
    """
    if self.views[-1].shape == new_shape: return self
    assert all(isinstance(x, int) and x > 0 for x in new_shape), f"shape must be ints and can't contain 0 or negative numbers {new_shape}"
    assert prod(self.shape) == prod(new_shape), f"can't reshape {self.shape} -> {new_shape}"
    new_view, extra = _reshape(self.views[-1], new_shape)
    if extra: self.views.append(new_view)
    else: self.views[-1] = new_view
    return self

  def permute(self, axis: Tuple[int, ...]):
    """Reorder dims so that new dim i is old dim axis[i]; strides and mask move with their dims."""
    assert all(isinstance(x, int) and x >= 0 and x < len(self.shape) for x in axis), f"invalid permute {axis} for {self.shape}"
    assert len(set(axis)) == len(axis) and len(axis) == len(self.shape), f"can't permute {self.shape} with {axis}"
    self.views[-1] = View(tuple([self.views[-1].shape[a] for a in axis]), tuple([self.views[-1].strides[a] for a in axis]), self.views[-1].offset, tuple([self.views[-1].mask[a] for a in axis]) if self.views[-1].mask is not None else None)
    return self

  # except for the negative case, you can build this from the others. invertible in the negative case
  def stride(self, mul: Tuple[int, ...]):
    """Take every |m|-th element along each dim; a negative m also reverses that dim."""
    assert all(isinstance(x, int) and x != 0 for x in mul), f"invalid stride {mul} for {self.shape}"
    strides = tuple([z*m for z,m in zip(self.views[-1].strides, mul)])
    new_shape = tuple([(s+(abs(m)-1))//abs(m) for s,m in zip(self.views[-1].shape, mul)])
    # reversed dims start reading from their last element
    offset = sum([(s-1)*z for s,z,m in zip(self.views[-1].shape, self.views[-1].strides, mul) if m < 0])
    mask = tuple([(((mx if m > 0 else s-my)+(abs(m)-1))//abs(m), ((my if m > 0 else s-mx)+(abs(m)-1))//abs(m)) for (mx,my),s,m in zip(self.views[-1].mask, self.views[-1].shape, mul)]) if self.views[-1].mask is not None else None
    self.views[-1] = View(new_shape, strides, self.views[-1].offset + offset, mask)
    return self

  # *** entry point for external ***

  def movement_op(self, op: MovementOps, arg:Union[Tuple[int, ...], Tuple[Tuple[int, int], ...]]) -> ShapeTracker:
    """Apply movement op `op` with `arg` in place via the module-level `dispatch` table; returns self."""
    assert isinstance(arg, tuple) and (len(arg) == len(self.shape) or op == MovementOps.RESHAPE), f"arg {arg} for {op} doesn't match dim of shape {self.shape}"
    dispatch[op](self, arg)
    return self
|
|
|
|
# maps each MovementOps member to the ShapeTracker method implementing it (looked up by ShapeTracker.movement_op)
dispatch: Dict[MovementOps, Callable] = {
  MovementOps.RESHAPE: ShapeTracker.reshape,
  MovementOps.EXPAND: ShapeTracker.expand,
  MovementOps.PAD: ShapeTracker.pad,
  MovementOps.SHRINK: ShapeTracker.shrink,
  MovementOps.PERMUTE: ShapeTracker.permute,
  MovementOps.STRIDE: ShapeTracker.stride,
}
|
|
|
|
def get_contraction(old_shape:Tuple[int, ...], new_shape:Tuple[int, ...]) -> Optional[List[List[int]]]:
  """Return the axes of old_shape to combine for each dim of new_shape, or None if impossible.

  Group i lists consecutive old axes whose sizes multiply to new_shape[i]. A 1 that exists
  only in new_shape gets an empty group. Trailing old 1s are appended to the last group.
  Callers are expected to pass shapes with equal element counts.

  Two cases are now handled explicitly:
    * an empty new_shape returns [] when every old dim is 1, and None otherwise
      (this used to raise IndexError);
    * a non-1 old dim meeting a trailing 1 in new_shape returns None (this used to loop forever).
  """
  if not new_shape: return [] if all(s == 1 for s in old_shape) else None
  # Pre-allocate all groups.
  axis_groups: List[List[int]] = [[] for _ in range(len(new_shape))]
  last = len(new_shape) - 1
  # Index for new_shape and axis_groups, and the running product of the current group's old dims.
  i, group_size = 0, 1
  for old_shape_i, old_dim in enumerate(old_shape):
    # 1s existing in new_shape only lead to empty axis groups: skip past them
    while new_shape[i] == 1 and old_dim != 1:
      if i == last: return None  # a non-1 old dim can never fit into the trailing 1s
      i += 1
    if new_shape[i] % old_dim != 0 or group_size * old_dim > new_shape[i]:
      return None
    axis_groups[i].append(old_shape_i)
    group_size *= old_dim
    # Move to next axes group if total size of all dimensions match (the last group keeps absorbing 1s).
    if group_size == new_shape[i] and i < last:
      i, group_size = i + 1, 1
  return axis_groups
|