# ShapeTracker allows movement operations to a buffer that don't require a copy to be made. from __future__ import annotations from enum import Enum, auto import functools from typing import Dict, Tuple, Union, List, Optional, Callable, cast, NamedTuple from tinygrad.helpers import prod, DEBUG from tinygrad.shape.symbolic import Variable, MulNode, NumNode, Node, SumNode # these ops live here class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto() # noqa: E702 @functools.lru_cache(maxsize=None) def to_shape_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> Tuple[Tuple[int, int], ...]: assert len(shape) == len(strides) ret = [(shape[0], strides[0])] if len(shape) > 0 else [] for i in range(1, len(shape)): if (strides[i] != 0 and ret[-1][1] == shape[i]*strides[i]) or ret[-1][0] == 1 or (strides[i] == 0 and ret[-1][1] == 0): ret[-1] = (ret[-1][0] * shape[i], strides[i]) else: ret.append((shape[i], strides[i])) return tuple(ret) @functools.lru_cache(maxsize=None) def is_contiguous(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> bool: return all(s1 == s2 or s == 1 for s,s1,s2 in zip(shape, strides, strides_for_shape(shape))) @functools.lru_cache(maxsize=None) def filter_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> Tuple[int, ...]: return tuple(stride if shp != 1 else 0 for stride, shp in zip(strides, shape)) class ViewInternal(NamedTuple): shape:Tuple[int, ...] strides:Tuple[int, ...] offset:int mask:Optional[Tuple[Tuple[int, int]]] contiguous:bool shape_strides:Tuple[Tuple[int, int], ...] 
@functools.lru_cache(maxsize=None)
class View(ViewInternal):
    """A single strided (and optionally masked) view.

    The class is wrapped in ``functools.lru_cache``, so the module-level name
    ``View`` is the cache wrapper: constructor arguments must be hashable and
    repeated calls with equal arguments return the cached instance.
    """

    def __new__(cls, shape, strides=None, offset=0, mask=None):
        """Build the view, deriving ``contiguous`` and ``shape_strides`` from the arguments.

        Falsy ``strides`` (None or empty) selects the default strides for
        ``shape``; explicit strides get size-1 dimensions zeroed.
        """
        strides_from_shape = strides_for_shape(shape)
        strides = strides_from_shape if not strides else filter_strides(shape, strides)
        contiguous = offset == 0 and is_contiguous(shape, strides) and mask is None
        return super().__new__(cls, shape, strides, offset, mask, contiguous, to_shape_strides(shape, strides))

    def __init__(self, shape, strides=None, offset=0, mask=None, contiguous=False, shape_strides=()):
        # All state is set in __new__; this only accepts the constructor arguments.
        super().__init__()

    def expr_node_mask(self, idx, valid=None) -> Node:
        """Return the conjunction of ``valid`` (if given) and the mask bounds checks for flat index ``idx``."""
        expr = [valid] if valid is not None else []
        if self.mask is not None:
            acc = 1
            # innermost dimension first: recover each per-dimension index from the flat idx
            for ns, (x, y) in reversed(list(zip(self.shape, self.mask))):
                base = ((idx//acc) % ns)
                expr += [base >= x, base < y]
                acc *= ns
        return Variable.ands(expr)

    # generate an expression if you have a single idx variable
    def expr_node(self, idx=None) -> Node:
        """Return the buffer-index expression for a single flat index ``idx``.

        When ``idx`` is None a variable ``idx`` ranging over
        ``[0, prod(shape)-1]`` is created.
        """
        # Upper bound is inclusive (as with every other Variable built in this
        # file), so the last valid flat index is prod(shape)-1.
        if idx is None:
            idx = Variable('idx', 0, prod(self.shape)-1)
        ret: List[Node] = [Variable.num(self.offset)] if self.offset else []
        acc = 1
        for d, s in reversed(self.shape_strides):
            ret.append(((idx//acc) % d)*s)
            acc *= d
        return Variable.sum(ret)

    # generate an expression if you have a variable or expression for each index
    def expr_idxs(self, idxs) -> Node:
        """Return the buffer-index expression given one index per dimension.

        Dimensions of size 1 or stride 0 contribute nothing.
        Raises AssertionError if ``len(idxs) != len(self.shape)``.
        """
        assert len(idxs) == len(self.shape), f"need an idx for all dimensions {idxs} vs {self.shape}"
        terms = [idx*st for idx, sh, st in zip(idxs, self.shape, self.strides) if sh != 1 and st != 0]
        return Variable.sum([Variable.num(self.offset)] + terms)


@functools.lru_cache(maxsize=None)
def idxs_to_idx(shape: Tuple[int, ...], idxs) -> Node:
    """Combine per-dimension indices into one flat row-major index for ``shape``."""
    assert len(idxs) == len(shape), "need an idx for all dimensions"
    acc = 1
    ret = []
    for tidx, d in reversed(list(zip(idxs, shape))):
        ret.append(tidx * acc)
        acc *= d
    return Variable.sum(ret)


@functools.lru_cache(maxsize=None)
def strides_for_shape(shape: Tuple[int, ...]) -> Tuple[int, ...]:
    """Return row-major strides for ``shape``, with 0 for every size-1 dimension."""
    strides = [1] if shape else []
    # build from the innermost dimension outwards; the outermost size is not needed
    for d in shape[::-1][:-1]:
        strides = [d*strides[0]] + strides
    return tuple([st if s != 1 else 0 for st, s in zip(strides, shape)])


@functools.lru_cache(maxsize=None)
def view_from_shape(shape: Tuple[int, ...]) -> View:
    """Return the default view for ``shape`` (all entries must be ints)."""
    assert all(isinstance(x, int) for x in shape)
    return View(tuple(shape), strides_for_shape(shape))


@functools.lru_cache(maxsize=None)
def merge_views(vm2: View, vm1: View) -> Optional[View]:
    """Try to express ``vm1`` applied on top of ``vm2`` as one view; return None if not possible."""
    if vm2.mask:
        return None  # this isn't supported yet
    mst = ShapeTracker(vm1.shape, [vm2, vm1])
    strides = mst.real_strides()
    if None in strides:
        return None
    return View(vm1.shape, strides, mst.real_offset(), vm1.mask)


@functools.lru_cache(maxsize=None)
def _reshape(view: View, new_shape: Tuple[int, ...]) -> Tuple[View, bool]:
    """Reshape ``view`` to ``new_shape``.

    Returns ``(new_view, extra)``. ``extra`` is False when ``new_view``
    replaces ``view`` and True when it has to be stacked on top of it.
    """
    shape, mask, strides, offset = view.shape, view.mask, view.strides, view.offset
    # check if this is adding or removing 1s (only)
    # NOTE: this is optional, but removes most calls to (expensive!) merge_views (with mask, not optional)
    if [x for x in shape if x != 1] == [x for x in new_shape if x != 1]:
        new_strides: List[int] = [y for x, y in zip(shape, strides) if x != 1]
        new_strides_tuple: Tuple[int, ...] = tuple([0 if x == 1 else new_strides.pop(0) for x in new_shape])
        new_mask_tuple = None
        if mask:
            for x, y in zip(shape, mask):
                if x == 1 and y != (0, 1):
                    # a size-1 dimension whose mask is not (0, 1): use an all-(0,0) mask
                    new_mask_tuple = ((0, 0),) * len(new_shape)
                    break
            else:
                new_mask: List[Tuple[int, int]] = [y for x, y in zip(shape, mask) if x != 1]
                new_mask_tuple = tuple([(0, 1) if x == 1 else new_mask.pop(0) for x in new_shape])
        return View(new_shape, new_strides_tuple, offset, new_mask_tuple), False

    new_view = View(new_shape, strides_for_shape(new_shape))
    if view.contiguous:
        return new_view, False  # NOTE: if it's contiguous it can't have an offset
    if (merged_view := merge_views(view, new_view)) is not None:
        return merged_view, False
    if DEBUG >= 4:
        print(f"WARNING: creating new view with reshape {view} -> {new_shape}")
    return new_view, True


@functools.lru_cache(maxsize=None)
def get_pad_args(shape: Tuple[int, ...], arg: Tuple[Tuple[int, int], ...]):
    """Turn pad amounts ``(before, end)`` into a ``(resize_arg, mask)`` pair.

    Per dimension of size ``s`` the resize range is ``(-before, s+end)`` and
    the mask is ``(before, s+before)``.
    """
    return tuple([(-b, s+e) for s, (b, e) in zip(shape, arg)]), tuple([(b, s+b) for s, (b, _) in zip(shape, arg)])


@functools.lru_cache(maxsize=None)
def get_unsafe_resize_offset(strides, arg):
    """Return the sum over dimensions of ``stride * range_start``."""
    return sum([s * x[0] for s, x in zip(strides, arg)])


class ShapeTracker:
    """A stack of views; movement ops modify or extend the last view in place."""
    __slots__ = "views"

    def __init__(self, shape: Union[ShapeTracker, Tuple[int, ...]], views: Optional[List[View]] = None):
        """Create a tracker from explicit ``views``, by copying another tracker's view list, or from a shape."""
        self.views: List[View] = views if views is not None else (
            [*cast(ShapeTracker, shape).views] if shape.__class__ is ShapeTracker else [view_from_shape(shape)])

    def __repr__(self):
        return f"ShapeTracker(shape={self.views[-1].shape}, views={self.views})"

    def copy(self) -> ShapeTracker:
        """Return a new tracker with a shallow copy of the view list."""
        return ShapeTracker(self.views[-1].shape, [*self.views])

    @property
    def contiguous(self) -> bool:
        """True when there is exactly one view and it is contiguous."""
        return len(self.views) == 1 and self.views[0].contiguous

    @property
    def shape(self) -> Tuple[int, ...]:
        """Shape of the last view."""
        return self.views[-1].shape

    @property
    def key(self) -> Tuple[View, ...]:
        """The views as a tuple."""
        return tuple(self.views)

    # this is the real size (ish)
    def size(self):
        """Return the product of the last view's dimensions that have a non-zero stride."""
        return prod([s for s, st in zip(self.views[-1].shape, self.views[-1].strides) if st != 0])

    # these are multiview strides, value is None if it's not a simple strided dimension
    # TODO: this can be shared code between simplify and merge_views
    def real_offset(self) -> int:
        """Return the constant obtained by evaluating the index expression at index 0."""
        real_offset, mask = self.expr_node(Variable('zero', 0, 0))
        assert real_offset.__class__ is NumNode, f"how is the offset not a number? {real_offset} {mask}"
        return real_offset.b

    # NOTE: if a stride is not always valid, it will be None
    def real_strides(self, ignore_valid=False) -> Tuple[Optional[int], ...]:
        """Return one stride per dimension of the last view, through all views.

        An entry is None when the dimension is not a plain ``idx*stride`` term,
        or when its index appears in the validity expression and
        ``ignore_valid`` is False. It is 0 when the index does not appear in
        the index expression at all.
        """
        if len(self.views) == 1 and self.views[-1].mask is None:
            return self.views[-1].strides
        idxs = [Variable(f"idx{i}", 0, s-1) for i, s in enumerate(self.shape)]
        idx, valid = self.expr_idxs(idxs)
        ret: List[Optional[int]] = [None] * len(self.views[-1].shape)
        for this_dim in (idx.nodes if isinstance(idx, SumNode) else [idx]):
            if isinstance(this_dim, MulNode) and isinstance(this_dim.a, Variable):
                ret[idxs.index(this_dim.a)] = this_dim.b
            elif isinstance(this_dim, Variable):
                ret[idxs.index(this_dim)] = 1
        idx_vars, valid_vars = idx.vars(), valid.vars()
        for i, tidx in enumerate(idxs):
            if tidx in valid_vars and not ignore_valid:
                ret[i] = None
            elif tidx not in idx_vars:
                ret[i] = 0
        return tuple(ret)

    def unit_stride_axes(self, ignore_valid=False) -> List[int]:
        """Return the axes whose real stride is exactly 1."""
        return [i for i, st in enumerate(self.real_strides(ignore_valid)) if st == 1]

    def _expr_idx(self, idx, valid):
        """Push ``idx``/``valid`` through every view below the last one, top to bottom."""
        for v in reversed(self.views[0:-1]):
            # the mask check must use idx as seen by this view, before it is rewritten
            valid = v.expr_node_mask(idx, valid)
            idx = v.expr_node(idx)
        return idx, valid

    def simplify(self):
        """Repeatedly merge the top two views in place for as long as merge_views succeeds."""
        if len(self.views) >= 2:
            new_view = merge_views(self.views[-2], self.views[-1])
            if new_view is not None:
                if DEBUG >= 4:
                    print(f"st simplify : {self.views[-2]} + {self.views[-1]} = {new_view}")
                self.views = self.views[:-2] + [new_view]
                self.simplify()

    def expr_idxs(self, idxs=None):
        """Return ``(index_expr, valid_expr)`` given one index per dimension (defaults to fresh variables)."""
        if idxs is None:
            idxs = [Variable(f"idx{i}", 0, s-1) for i, s in enumerate(self.shape)]
        idx = self.views[-1].expr_idxs(tuple(idxs))
        valid = self.views[-1].expr_node_mask(idxs_to_idx(self.views[-1].shape, tuple(idxs)))
        return self._expr_idx(idx, valid)

    def expr_node(self, idx='idx'):
        """Return ``(index_expr, valid_expr)`` for a flat index; a str names a fresh variable over the shape."""
        if idx.__class__ is str:
            idx = Variable(idx, 0, prod(self.shape)-1)
        return self._expr_idx(self.views[-1].expr_node(idx), self.views[-1].expr_node_mask(idx))

    def needs_valid(self) -> bool:
        """True when any view carries a mask."""
        return any(v.mask is not None for v in self.views)

    # *** under this line are the movement ops ***

    def __unsafe_resize(self, arg: Tuple[Tuple[int, int], ...], mask=None):
        """Replace the last view by one covering ``arg`` ranges per dimension, without bounds checks."""
        offset = get_unsafe_resize_offset(self.views[-1].strides, arg)
        if self.views[-1].mask:
            # move the old mask
            nmask = tuple([(max(mx-ax, 0), min(my-ax, ay-ax)) for (mx, my), (ax, ay) in zip(self.views[-1].mask, arg)])
            # merge the masks if we have two
            mask = tuple([(max(mx1, mx2), min(my1, my2)) for (mx1, my1), (mx2, my2) in zip(nmask, mask)]) \
                if mask is not None else nmask
        self.views[-1] = View(tuple([y-x for x, y in arg]), self.views[-1].strides, self.views[-1].offset+offset, mask)

    def pad(self, arg: Tuple[Tuple[int, int], ...]):
        """Pad each dimension by non-negative ``(before, end)`` amounts; returns self."""
        assert all((b >= 0 and e >= 0) for b, e in arg) and len(arg) == len(self.shape)
        if any(b or e for b, e in arg):
            zvarg, mask = get_pad_args(self.shape, arg)
            self.__unsafe_resize(zvarg, mask=mask)
        return self

    def shrink(self, arg: Tuple[Tuple[int, int], ...]):
        """Restrict each dimension to the ``(begin, end)`` range; returns self."""
        assert all((b >= 0 and e <= s) for s, (b, e) in zip(self.shape, arg)) and len(arg) == len(self.shape)
        self.__unsafe_resize(arg)
        return self

    def expand(self, new_shape: Tuple[int, ...]) -> ShapeTracker:
        """Expand size-1, stride-0 dimensions to the sizes in ``new_shape``; returns self."""
        assert len(new_shape) == len(self.views[-1].shape)
        assert all(
            isinstance(x, int) and (s == x or (s == 1 and st == 0))
            for s, x, st in zip(self.shape, new_shape, self.views[-1].strides)
        ), f"can't expand {self.shape} into {new_shape}"
        # NOTE: can the mask ever be (0,0)?
        mask = tuple([(((0, 0) if m != (0, 1) else (0, ns)) if s != ns else m)
                      for m, s, ns in zip(self.views[-1].mask, self.shape, new_shape)]) if self.views[-1].mask else None
        self.views[-1] = View(new_shape, self.views[-1].strides, self.views[-1].offset, mask)
        return self

    def reshape(self, new_shape: Tuple[int, ...]):
        """Reshape to ``new_shape`` (positive ints, same element count); returns self.

        A new view is appended when _reshape reports the reshape cannot
        replace the last view.
        """
        if self.views[-1].shape == new_shape:
            return self
        assert all(isinstance(x, int) and x > 0 for x in new_shape), \
            f"shape must be ints and can't contain 0 or negative numbers {new_shape}"
        assert prod(self.shape) == prod(new_shape), f"can't reshape {self.shape} -> {new_shape}"
        new_view, extra = _reshape(self.views[-1], new_shape)
        if extra:
            self.views.append(new_view)
        else:
            self.views[-1] = new_view
        return self

    def permute(self, axis: Tuple[int, ...]):
        """Reorder the dimensions of the last view according to ``axis``; returns self."""
        assert all(isinstance(x, int) and x >= 0 and x < len(self.shape) for x in axis), \
            f"invalid permute {axis} for {self.shape}"
        assert len(set(axis)) == len(axis) and len(axis) == len(self.shape), f"can't permute {self.shape} with {axis}"
        view = self.views[-1]
        self.views[-1] = View(
            tuple([view.shape[a] for a in axis]),
            tuple([view.strides[a] for a in axis]),
            view.offset,
            tuple([view.mask[a] for a in axis]) if view.mask is not None else None)
        return self

    # except for the negative case, you can build this from the others. invertible in the negative case
    def stride(self, mul: Tuple[int, ...]):
        """Step through each dimension by ``mul`` (non-zero ints; negative reverses); returns self."""
        assert all(isinstance(x, int) and x != 0 for x in mul), f"invalid stride {mul} for {self.shape}"
        view = self.views[-1]
        strides = tuple([z*m for z, m in zip(view.strides, mul)])
        # ceil division of each size by |m|
        new_shape = tuple([(s+(abs(m)-1))//abs(m) for s, m in zip(view.shape, mul)])
        # a reversed dimension starts from its last element
        offset = sum([(s-1)*z for s, z, m in zip(view.shape, view.strides, mul) if m < 0])
        mask = tuple([(((mx if m > 0 else s-my)+(abs(m)-1))//abs(m), ((my if m > 0 else s-mx)+(abs(m)-1))//abs(m))
                      for (mx, my), s, m in zip(view.mask, view.shape, mul)]) if view.mask is not None else None
        self.views[-1] = View(new_shape, strides, view.offset + offset, mask)
        return self

    # *** entry point for external ***

    def movement_op(self, op: MovementOps, arg: Union[Tuple[int, ...], Tuple[Tuple[int, int], ...]]) -> ShapeTracker:
        """Apply the movement op ``op`` with ``arg`` via the ``dispatch`` table; returns self."""
        assert isinstance(arg, tuple) and (len(arg) == len(self.shape) or op == MovementOps.RESHAPE), \
            f"arg {arg} for {op} doesn't match dim of shape {self.shape}"
        dispatch[op](self, arg)
        return self


# maps each movement op to the ShapeTracker method implementing it
dispatch: Dict[MovementOps, Callable] = {
    MovementOps.RESHAPE: ShapeTracker.reshape,
    MovementOps.EXPAND: ShapeTracker.expand,
    MovementOps.PAD: ShapeTracker.pad,
    MovementOps.SHRINK: ShapeTracker.shrink,
    MovementOps.PERMUTE: ShapeTracker.permute,
    MovementOps.STRIDE: ShapeTracker.stride,
}


# returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
def get_contraction(old_shape: Tuple[int, ...], new_shape: Tuple[int, ...]) -> Optional[List[List[int]]]:
    """Group the axes of ``old_shape`` so that each group multiplies to one axis of ``new_shape``.

    Returns one list of old-axis indices per new axis (possibly empty for new
    axes of size 1), or None when ``new_shape`` cannot be formed that way.
    """
    # With no target axes there is no group an old axis could be placed in.
    if not new_shape:
        return [] if not old_shape else None
    # Pre-allocate all groups.
    axis_groups: List[List[int]] = [[] for _ in range(len(new_shape))]
    # Index for new_shape and axis_groups.
    i: int = 0
    old_shape_i: int = 0
    while old_shape_i < len(old_shape):
        # 1s exist in new_shape only will lead to empty axes group creations.
        if new_shape[i] == 1 and old_shape[old_shape_i] != 1:
            if i < len(new_shape) - 1:
                i += 1
            else:
                # Last group has size 1 but a non-1 old axis remains: nothing
                # can advance, so this is not a contraction (avoids looping forever).
                return None
        else:
            group_size = prod([old_shape[x] for x in axis_groups[i]])
            if new_shape[i] % old_shape[old_shape_i] != 0 or group_size * old_shape[old_shape_i] > new_shape[i]:
                return None
            axis_groups[i].append(old_shape_i)
            # Move to next axes group if total size of all dimensions match.
            if prod([old_shape[x] for x in axis_groups[i]]) == new_shape[i]:
                if i < len(new_shape) - 1:
                    i += 1
            old_shape_i += 1
    return axis_groups