Remove AS_STRIDED from shapetracker (#2216)
* very close * remove comment * negative strides working * almost everything passes * calculate offset with list comprehension * some cleanup * got disk load working * review suggestions * fix after merge * overlap working * did it * clean * fixed disk load * lint * mypy * removed as_strided * trying without simplify * added back simplify * make sure expanding to smaller shape * cleanup * removed comment * removed env file * trying whisper test again * onnx test sqlite issue * working on test * finished test * eliminate unnecessary shrink-then-pad * don't shrink buffer * added strides check * added to ci under linters * switch issue * allow symbolic stride * removed .env * isinstance * adjust strides for double expand * cleanup * needed to add type hint for mypy * set pythonpath
parent
b8d460d203
commit
b64738e1d6
|
@ -53,6 +53,8 @@ jobs:
|
|||
run: python test/external/fuzz_symbolic.py
|
||||
- name: Fuzz Test shapetracker
|
||||
run: PYTHONPATH="." python test/external/fuzz_shapetracker.py
|
||||
- name: Test shapetracker to_movement_ops
|
||||
run: PYTHONPATH="." python extra/to_movement_ops.py
|
||||
- name: Use as an external package
|
||||
run: |
|
||||
mkdir $HOME/test_external_dir
|
||||
|
|
|
@ -0,0 +1,105 @@
|
|||
import random
|
||||
from tqdm import tqdm
|
||||
from extra.optimization.helpers import load_worlds
|
||||
from tinygrad.codegen.linearizer import Linearizer
|
||||
from tinygrad.ops import LazyOp, MovementOps, TernaryOps, BinaryOps, UnaryOps, ReduceOps, BufferOps, MemBuffer, ConstBuffer
|
||||
from tinygrad.helpers import dtypes, prod
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View
|
||||
from tinygrad.shape.symbolic import Node, Variable
|
||||
# names referenced by the repr'd AST strings that eval() reconstructs below
inf, nan = float('inf'), float('nan')
|
||||
|
||||
def get_real_view(shape, strides, offset, mask):
  """Reduce a view to its dense core.

  Trims each dimension to the extent selected by mask, folds negative
  strides and the mask's lower bounds into the base offset, and drops
  zero-stride (broadcast) dimensions.

  Returns (dense_shape, dense_strides, base_offset); integer strides come
  back non-negative, symbolic strides are passed through untouched.
  """
  # extent actually addressed per dimension (the mask trims it)
  trimmed = tuple(hi - lo for lo, hi in mask) if mask else shape
  # a negative stride walks backwards, so shift the base to the lowest element
  base = offset + sum(st * (s - 1) for s, st in zip(trimmed, strides) if st < 0)
  # the mask's lower bounds contribute a further fixed offset
  if mask:
    base += sum(lo * st for (lo, _), st in zip(mask, strides))
  # drop broadcast dimensions; take |stride| only for concrete ints
  dense_shape = [s for s, st in zip(trimmed, strides) if st]
  dense_strides = [abs(st) if isinstance(st, int) else st for st in strides if st]
  return dense_shape, dense_strides, base
|
||||
|
||||
def get_buffer_size(shape, strides, offset, mask):
  """Smallest flat buffer length that can back this view.

  Computed as the dense base offset plus the linear index of the furthest
  element reached, plus one.
  """
  dense_shape, dense_strides, base = get_real_view(shape, strides, offset, mask)
  furthest = sum((s - 1) * st for s, st in zip(dense_shape, dense_strides))
  return base + furthest + 1
|
||||
|
||||
def flatten_view(view: View):
  """Canonicalize a View to (flat_shape, flat_strides, offset).

  Dimensions are sorted by descending stride and adjacent contiguous pairs
  are merged, so two views that address the same elements flatten to the
  same triple and can be compared for equivalence.
  """
  real_real_shape, strides, real_offset = get_real_view(view.shape, view.strides, view.offset, view.mask)
  def sort_by_strides(shape, strides): return sorted(zip(shape, strides), key=lambda k: (k[1],-k[0]), reverse=True), sorted(range(len(strides)), key=lambda k: (strides[k],-real_real_shape[k]), reverse=True)
  ordered_shape_strides, _ = sort_by_strides(real_real_shape, strides)
  ordered_shape_strides = [list(s) for s in ordered_shape_strides]
  if strides:
    i = 0
    while i < len(ordered_shape_strides):
      # dims i and i+1 are contiguous when stride[i] spans exactly one full
      # step of dim i+1 (stride[i] == shape[i+1] * stride[i+1])
      if i < len(ordered_shape_strides)-1 and ordered_shape_strides[i][1] == ordered_shape_strides[i+1][0]*ordered_shape_strides[i+1][1]:
        ordered_shape_strides[i+1][0] = ordered_shape_strides[i][0]*ordered_shape_strides[i+1][0]
        # BUGFIX: remove the absorbed outer dim; without this the merged
        # size is double-counted in flat_shape, and a size-1 dim keeps
        # satisfying the merge condition forever (infinite loop)
        ordered_shape_strides.pop(i)
      else: i += 1
    flat_shape = [shape_stride[0] for shape_stride in ordered_shape_strides]
    flat_strides = [shape_stride[1] for shape_stride in ordered_shape_strides]
    return (flat_shape, flat_strides, real_offset)
  # fully-broadcast view: nothing to merge, keep the original strides
  return (real_real_shape, view.strides, real_offset)
|
||||
|
||||
def views_equivalent(v1: View, v2: View) -> bool:
  """Two views are equivalent when they compare equal outright, or when
  their flattened canonical forms coincide."""
  if v1 == v2:
    return True
  return flatten_view(v1) == flatten_view(v2)
|
||||
|
||||
|
||||
def st_equivalent(st: ShapeTracker, st_rebuilt: ShapeTracker):
  """Loosely match the view lists of two ShapeTrackers.

  A rebuilt view is accepted when it equals the original view or merely
  shares its shape; any other rebuilt view is discarded (hack to skip the
  extra expands emitted for overlapped strides).  Returns True once every
  original view has been matched; an unmatchable rebuilt list runs out of
  entries and raises IndexError instead.
  """
  originals = list(st.views)
  candidates = list(st_rebuilt.views)
  idx = 0
  while idx < len(originals):
    orig_v = originals[idx]
    cand_v = candidates[idx]
    if orig_v == cand_v or orig_v.shape == cand_v.shape:
      idx += 1
    else:
      # hack to skip expands for overlapped strides
      candidates.pop(idx)
  return True
|
||||
|
||||
def test_rebuild(st: ShapeTracker):
  """Replay st.to_movement_ops() on a fresh contiguous buffer and assert the
  rebuilt ShapeTracker matches the original.

  NOTE(review): the backing buffer is sized from the first view only -
  presumably to_movement_ops rebuilds from that base; confirm against
  ShapeTracker.to_movement_ops.
  """
  # flat buffer just large enough to back the first view
  rebuilt_st = ShapeTracker.from_shape((get_buffer_size(st.views[0].shape, st.views[0].strides, st.views[0].offset, st.views[0].mask),))
  for mop, arg in st.to_movement_ops():
    if mop == MovementOps.RESHAPE:
      # shapetracker doesn't allow flattening with -1 but required for MovementOps.RESHAPE
      if arg == (-1,):
        rebuilt_st = rebuilt_st.reshape((prod(rebuilt_st.views[-1].shape),))
      else:
        rebuilt_st = rebuilt_st.reshape(arg)
    elif mop == MovementOps.PERMUTE:
      rebuilt_st = rebuilt_st.permute(arg)
    elif mop == MovementOps.EXPAND:
      # expand needs a leading 1-dim when the target rank is larger
      if len(arg) != len(rebuilt_st.shape):
        rebuilt_st = rebuilt_st.reshape((1,*rebuilt_st.shape))
      rebuilt_st = rebuilt_st.expand(arg)
    elif mop == MovementOps.PAD:
      rebuilt_st = rebuilt_st.pad(arg)
    elif mop == MovementOps.SHRINK:
      rebuilt_st = rebuilt_st.shrink(arg)
    elif mop == MovementOps.STRIDE:
      rebuilt_st = rebuilt_st.stride(arg)
    else:
      raise Exception("invalid mop")
  rebuilt_st = rebuilt_st.simplify()
  if len(st.views) != len(rebuilt_st.views):
    # view counts differ: fall back to the loose per-view comparison
    if not set(st.views).issubset(set(rebuilt_st.views)):
      assert st_equivalent(st, rebuilt_st)
  else:
    for v1,v2 in zip(st.views, rebuilt_st.views):
      assert views_equivalent(v1, v2), f"{v1} not equivalent to {v2}"
  # regardless of how the views matched, the final shape must agree exactly
  last_v1 = st.views[-1]
  last_v2 = rebuilt_st.views[-1]
  assert last_v1.shape == last_v2.shape, f"{last_v1.shape} != {last_v2.shape}"
|
||||
|
||||
if __name__ == "__main__":
  # fuzz shapetracker rebuilds against a random sample of real-world ASTs
  ast_strs = load_worlds(False, False, True)
  random.shuffle(ast_strs)
  ast_strs = ast_strs[:2000]
  def interpret_ast(ast):
    # BufferOps leaves carry the ShapeTracker to exercise; recurse otherwise
    if ast.op in BufferOps:
      test_rebuild(ast.arg.st)
    else:
      for src in ast.src: interpret_ast(src)
  for ast_str in tqdm(ast_strs):
    # NOTE(review): eval of repr'd AST strings from the local dataset -
    # acceptable only because the input is trusted, never for external data
    ast = eval(ast_str)
    interpret_ast(ast)
|
|
@ -17,7 +17,7 @@ class TernaryOps(Enum): MULACC = auto(); WHERE = auto() # noqa: E702
|
|||
class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
|
||||
class BufferOps(Enum): MEM = auto(); CONST = auto() # noqa: E702
|
||||
# Ops below this line are not allowed in ASTs
|
||||
class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto(); AS_STRIDED = auto() # noqa: E702
|
||||
class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto() # noqa: E702
|
||||
class LoadOps(Enum): EMPTY = auto(); RAND = auto(); CONST = auto(); FROM = auto(); CONTIGUOUS = auto(); CUSTOM = auto() # noqa: E702
|
||||
|
||||
Op = Union[UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, TernaryOps, BufferOps]
|
||||
|
|
|
@ -41,7 +41,6 @@ numpy_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
|
|||
BinaryOps.DIV: lambda x, y: np.divide(*match_types(x, y)).astype(output_type(x, y), copy=False), UnaryOps.SQRT: np.sqrt,
|
||||
MovementOps.PERMUTE: lambda x, order: x.transpose(order), MovementOps.PAD: np.pad, MovementOps.EXPAND: np.broadcast_to,
|
||||
MovementOps.STRIDE: lambda x, arg: x[tuple(slice(None, None, i) for i in arg)],
|
||||
MovementOps.AS_STRIDED: lambda x, arg: np.ndarray(arg[0], buffer=np.require(x, requirements='C'), dtype=x.dtype, offset=arg[2]*x.dtype.itemsize, strides=tuple(y*x.dtype.itemsize for y in arg[1])),
|
||||
TernaryOps.MULACC: einsum_mulacc(lambda s,a,b: np.einsum(s, *match_types(a.copy(), b.copy()), optimize=True), lambda x: x.strides, np.broadcast_to),
|
||||
TernaryOps.WHERE: np.where,
|
||||
}}
|
||||
|
|
|
@ -24,18 +24,16 @@ class RawDiskBuffer(RawBufferMapped):
|
|||
def cast(self, arg:Tuple[DType, bool]): return RawDiskBuffer(self.size, arg[0], buf=self._buf, shape=self.shape, offset=self.offset)
|
||||
def reshape(self, arg): return RawDiskBuffer(self.size, self.dtype, buf=self._buf, shape=arg, offset=self.offset)
|
||||
def shrink(self, arg):
|
||||
assert arg[1:] == tuple([(0,x) for x in self.shape[1:]]), f"can only slice the first dim of disk tensor {arg}"
|
||||
offset = arg[0][0]*prod(self.shape[1:])*self.dtype.itemsize
|
||||
size = (arg[0][1]-arg[0][0]) * prod(self.shape[1:])
|
||||
return RawDiskBuffer(size, self.dtype, buf=self._buf, offset=self.offset+offset, shape=(arg[0][1]-arg[0][0],)+self.shape[1:])
|
||||
|
||||
def as_strided(self, arg):
|
||||
return RawDiskBuffer(prod(arg[0]), self.dtype, buf=self._buf, offset=self.offset+arg[2]*self.dtype.itemsize, shape=arg[0])
|
||||
assert len(arg)<2 or arg[1:] == tuple([(0,x) for x in self.shape[1:]]), f"can only slice the first dim of disk tensor {arg}"
|
||||
offset = arg[0][0]*(prod(self.shape[1:]) if len(arg)>1 else 1)*self.dtype.itemsize
|
||||
size = (arg[0][1]-arg[0][0]) * (prod(self.shape[1:]) if len(arg)>1 else 1)
|
||||
return RawDiskBuffer(size, self.dtype, buf=self._buf, offset=self.offset+offset, shape=(arg[0][1]-arg[0][0],)+(self.shape[1:] if len(arg)>1 else ()))
|
||||
|
||||
def _buffer(self): return memoryview(self._buf[1])[self.offset:self.offset+self.size*self.dtype.itemsize]
|
||||
def readinto(self, buf):
|
||||
self._buf[0].seek(self.offset)
|
||||
self._buf[0].readinto(buf)
|
||||
|
||||
disk_fxn_for_op: Dict[Op, Callable] = { BufferOps.MEM: lambda x: x, UnaryOps.NOOP: lambda x: x, UnaryOps.CAST: RawDiskBuffer.cast, MovementOps.AS_STRIDED: RawDiskBuffer.as_strided }
|
||||
disk_fxn_for_op: Dict[Op, Callable] = { BufferOps.MEM: lambda x: x, UnaryOps.NOOP: lambda x: x, UnaryOps.CAST: RawDiskBuffer.cast, MovementOps.SHRINK: RawDiskBuffer.shrink, MovementOps.RESHAPE: RawDiskBuffer.reshape }
|
||||
DiskBuffer = Interpreted(RawDiskBuffer, disk_fxn_for_op)
|
||||
|
||||
|
|
|
@ -24,5 +24,5 @@ class RawShmBuffer(RawBufferMapped):
|
|||
def _buffer(self): return memoryview(self._buf)
|
||||
|
||||
# TODO: is this wrong?
|
||||
shm_fxn_for_op: Dict[Op, Callable] = { BufferOps.MEM: lambda x: x, UnaryOps.NOOP: lambda x:x, MovementOps.RESHAPE: lambda x,_:x, MovementOps.AS_STRIDED: lambda x,_:x }
|
||||
shm_fxn_for_op: Dict[Op, Callable] = { BufferOps.MEM: lambda x: x, UnaryOps.NOOP: lambda x:x, MovementOps.RESHAPE: lambda x,_:x }
|
||||
ShmBuffer = Interpreted(RawShmBuffer, shm_fxn_for_op)
|
||||
|
|
|
@ -16,12 +16,6 @@ def match_types(x, y, disallow_bool=False):
|
|||
if disallow_bool and up == torch.bool: up = torch.float
|
||||
return x.type(up), y.type(up)
|
||||
|
||||
def as_strided(x, arg):
|
||||
if any(i < 0 for i in arg[1]):
|
||||
return torch.as_strided(x.contiguous(), arg[0], tuple(abs(i) for i in arg[1]),
|
||||
arg[2] + sum((s-1)*a if a < 0 else 0 for (s,a) in zip(arg[0], arg[1]))).flip([i for i,a in enumerate(arg[1]) if a < 0])
|
||||
return torch.as_strided(x.contiguous(), arg[0], arg[1], arg[2])
|
||||
|
||||
torch_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
|
||||
# TODO: torch.tensor should work here
|
||||
#BufferOps.CONST: lambda val, dtype: torch.tensor(val, device=device, dtype=inverse_type_map[dtype]),
|
||||
|
@ -38,7 +32,6 @@ torch_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
|
|||
TernaryOps.WHERE: lambda x, y, z: torch.where(x != 0, y, z),
|
||||
MovementOps.STRIDE: lambda x, arg: x[tuple(slice(None, None, abs(i)) for i in arg)].flip([i for i,a in enumerate(arg) if a < 0]),
|
||||
MovementOps.EXPAND: lambda x, arg: x.expand(arg), MovementOps.PERMUTE: lambda x, arg: x.permute(arg),
|
||||
MovementOps.AS_STRIDED: as_strided
|
||||
}}
|
||||
|
||||
class RawTorchBuffer(RawBuffer):
|
||||
|
|
|
@ -97,14 +97,35 @@ class ShapeTracker:
|
|||
|
||||
def to_movement_ops(self) -> List[Tuple[MovementOps, Tuple]]:
|
||||
to_apply:List[Tuple[MovementOps, Tuple]] = []
|
||||
for v in self.views:
|
||||
for i, v in enumerate(self.views):
|
||||
real_shape = tuple(y-x for x,y in v.mask) if v.mask else v.shape
|
||||
real_offset = v.offset + (sum(x*st for (x,_),st in zip(v.mask, v.strides)) if v.mask else 0)
|
||||
# first, we apply the offset
|
||||
# then, we make it the correct shape
|
||||
# then, we apply permutations
|
||||
# TODO: don't use as_strided
|
||||
to_apply.append((MovementOps.AS_STRIDED, (tuple([s if st != 0 else 1 for s,st in zip(real_shape, v.strides)]), v.strides, real_offset)))
|
||||
offset = v.offset + sum(st*(s-1) for s,st in zip(real_shape, v.strides) if st<0)
|
||||
real_offset = offset + (sum(x*st for (x,_),st in zip(v.mask, v.strides)) if v.mask else 0)
|
||||
real_real_shape = [s for s,st in zip(real_shape, v.strides) if st]
|
||||
strides: List[Node|int] = [abs(st) if isinstance(st,int) else st for st in v.strides if st]
|
||||
buffer_size = sum((s-1)*st for s,st in zip(real_real_shape,strides)) + 1
|
||||
if i: buffer_size = prod(self.views[i-1].shape) - real_offset
|
||||
def sort_by_strides(shape, strides): return sorted(zip(shape, strides), key=lambda k: (k[1],-k[0]), reverse=True), sorted(range(len(strides)), key=lambda k: (strides[k],-real_real_shape[k]), reverse=True)
|
||||
ordered_shape_strides, order = sort_by_strides(real_real_shape, strides)
|
||||
to_apply.extend([(MovementOps.RESHAPE, (-1,)), (MovementOps.SHRINK, ((real_offset, real_offset+buffer_size),))])
|
||||
if strides:
|
||||
if (ordered_shape_strides[0][0]*ordered_shape_strides[0][1])-buffer_size>0: to_apply.append((MovementOps.PAD, ((0, (ordered_shape_strides[0][0] * ordered_shape_strides[0][1]) - buffer_size),)))
|
||||
for i, shape_stride in enumerate(ordered_shape_strides):
|
||||
if i<len(ordered_shape_strides)-1 and shape_stride[1] < ordered_shape_strides[i+1][0]*ordered_shape_strides[i+1][1]:
|
||||
remaining_buffer = ordered_shape_strides[i-1][1] if i>0 else buffer_size
|
||||
to_apply.append((MovementOps.EXPAND, (shape_stride[0], *(s[0] for s in ordered_shape_strides[:i]), remaining_buffer)))
|
||||
to_apply.append((MovementOps.PERMUTE, (*range(1,i+1), 0, i+1)))
|
||||
to_apply.append((MovementOps.RESHAPE, (*(s[0] for s in ordered_shape_strides[:i]), shape_stride[0]*remaining_buffer)))
|
||||
to_apply.append((MovementOps.PAD, (*((0,0) for _ in range(i)), (0, shape_stride[0]*shape_stride[1]))))
|
||||
to_apply.append((MovementOps.RESHAPE, (*(s[0] for s in ordered_shape_strides[:i+1]), remaining_buffer+shape_stride[1])))
|
||||
ordered_shape_strides[i] = (ordered_shape_strides[i][0], remaining_buffer+shape_stride[1])
|
||||
else:
|
||||
to_apply.append((MovementOps.SHRINK, (*((0, s[0]) for s in ordered_shape_strides[:i]), (0, shape_stride[0]*shape_stride[1]))))
|
||||
to_apply.append((MovementOps.RESHAPE, (*[s[0] for s in ordered_shape_strides[:i+1]], shape_stride[1])))
|
||||
to_apply.extend([(MovementOps.SHRINK, (*[(0, s[0]) for s in ordered_shape_strides], (0,1))), (MovementOps.RESHAPE, tuple(s[0] for s in ordered_shape_strides))])
|
||||
if order != list(range(len(order))): to_apply.append((MovementOps.PERMUTE, tuple(order.index(i) for i in range(len(strides)))))
|
||||
to_apply.append((MovementOps.RESHAPE, tuple(s if st else 1 for s,st in zip(real_shape, v.strides))))
|
||||
if any(i<0 for i in v.strides): to_apply.append((MovementOps.STRIDE, tuple(-1 if st<0 else 1 for st in v.strides)))
|
||||
# then, we apply pre expand pads
|
||||
if v.mask is not None:
|
||||
pre_expand_pads = tuple((x,s-y) if st != 0 else (0,0) for (x,y),s,st in zip(v.mask, v.shape, v.strides))
|
||||
|
|
Loading…
Reference in New Issue