1
0
Fork 0

Remove AS_STRIDED from shapetracker (#2216)

* very close

* remove comment

* negative strides working

* almost everything passes

* calculate offset with list comprehension

* some cleanup

* got disk load working

* review suggestions

* fix after merge

* overlap working

* did it

* clean

* fixed disk load

* lint

* mypy

* removed as_strided

* trying without simplify

* added back simplify

* make sure expanding to smaller shape

* cleanup

* removed comment

* removed env file

* trying whisper test again

* onnx test sqlite issue

* working on test

* finished test

* eliminate unnecessary shrink-then-pad

* don't shrink buffer

* added strides check

* added to ci under linters

* switch issue

* allow symbolic stride

* removed .env

* isinstance

* adjust strides for double expand

* cleanup

* needed to add type hint for mypy

* set pythonpath
pull/2317/head^2
forcefieldsovereign 2023-11-15 12:50:17 -08:00 committed by GitHub
parent b8d460d203
commit b64738e1d6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 143 additions and 25 deletions

View File

@ -53,6 +53,8 @@ jobs:
run: python test/external/fuzz_symbolic.py
- name: Fuzz Test shapetracker
run: PYTHONPATH="." python test/external/fuzz_shapetracker.py
- name: Test shapetracker to_movement_ops
run: PYTHONPATH="." python extra/to_movement_ops.py
- name: Use as an external package
run: |
mkdir $HOME/test_external_dir

View File

@ -0,0 +1,105 @@
import random
from tqdm import tqdm
from extra.optimization.helpers import load_worlds
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.ops import LazyOp, MovementOps, TernaryOps, BinaryOps, UnaryOps, ReduceOps, BufferOps, MemBuffer, ConstBuffer
from tinygrad.helpers import dtypes, prod
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.shape.symbolic import Node, Variable
inf, nan = float('inf'), float('nan')
def get_real_view(shape, strides, offset, mask):
real_shape = tuple(y-x for x,y in mask) if mask else shape
offset = offset + sum(st * (s-1) for s,st in zip(real_shape, strides) if st<0)
real_offset = offset + (sum(x*st for (x,_),st in zip(mask, strides)) if mask else 0)
real_real_shape = [s for s,st in zip(real_shape, strides) if st]
strides = [abs(st) if isinstance(st,int) else st for st in strides if st]
return real_real_shape, strides, real_offset
def get_buffer_size(shape, strides, offset, mask):
real_real_shape, strides, real_offset = get_real_view(shape, strides, offset, mask)
return real_offset + sum((s-1)*st for s, st in zip(real_real_shape,strides)) + 1
def flatten_view(view: View):
real_real_shape, strides, real_offset = get_real_view(view.shape, view.strides, view.offset, view.mask)
def sort_by_strides(shape, strides): return sorted(zip(shape, strides), key=lambda k: (k[1],-k[0]), reverse=True), sorted(range(len(strides)), key=lambda k: (strides[k],-real_real_shape[k]), reverse=True)
ordered_shape_strides, _ = sort_by_strides(real_real_shape, strides)
ordered_shape_strides = [list(s) for s in ordered_shape_strides]
if strides:
i = 0
while i < len(ordered_shape_strides):
if i<len(ordered_shape_strides)-1 and ordered_shape_strides[i][1] == ordered_shape_strides[i+1][0]*ordered_shape_strides[i+1][1]:
ordered_shape_strides[i+1][0] = ordered_shape_strides[i][0]*ordered_shape_strides[i+1][0]
else: i += 1
flat_shape = [shape_stride[0] for shape_stride in ordered_shape_strides]
flat_strides = [shape_stride[1] for shape_stride in ordered_shape_strides]
return (flat_shape, flat_strides, real_offset)
return (real_real_shape, view.strides, real_offset)
def views_equivalent(v1: View, v2: View) -> bool:
return v1 == v2 or flatten_view(v1) == flatten_view(v2)
def st_equivalent(st: ShapeTracker, st_rebuilt: ShapeTracker):
views = list(st.views)
rebuilt_views = list(st_rebuilt.views)
i = 0
while i < len(views):
view, rebuilt_view = views[i], rebuilt_views[i]
if view == rebuilt_view:
i += 1
continue
elif view.shape == rebuilt_view.shape:
i += 1
# hack to skip expands for overlapped strides
else:
rebuilt_views.pop(i)
return True
def test_rebuild(st: ShapeTracker):
rebuilt_st = ShapeTracker.from_shape((get_buffer_size(st.views[0].shape, st.views[0].strides, st.views[0].offset, st.views[0].mask),))
for mop, arg in st.to_movement_ops():
if mop == MovementOps.RESHAPE:
# shapetracker doesn't allow flattening with -1 but required for MovementOps.RESHAPE
if arg == (-1,):
rebuilt_st = rebuilt_st.reshape((prod(rebuilt_st.views[-1].shape),))
else:
rebuilt_st = rebuilt_st.reshape(arg)
elif mop == MovementOps.PERMUTE:
rebuilt_st = rebuilt_st.permute(arg)
elif mop == MovementOps.EXPAND:
if len(arg) != len(rebuilt_st.shape):
rebuilt_st = rebuilt_st.reshape((1,*rebuilt_st.shape))
rebuilt_st = rebuilt_st.expand(arg)
elif mop == MovementOps.PAD:
rebuilt_st = rebuilt_st.pad(arg)
elif mop == MovementOps.SHRINK:
rebuilt_st = rebuilt_st.shrink(arg)
elif mop == MovementOps.STRIDE:
rebuilt_st = rebuilt_st.stride(arg)
else:
raise Exception("invalid mop")
rebuilt_st = rebuilt_st.simplify()
if len(st.views) != len(rebuilt_st.views):
if not set(st.views).issubset(set(rebuilt_st.views)):
assert st_equivalent(st, rebuilt_st)
else:
for v1,v2 in zip(st.views, rebuilt_st.views):
assert views_equivalent(v1, v2), f"{v1} not equivalent to {v2}"
last_v1 = st.views[-1]
last_v2 = rebuilt_st.views[-1]
assert last_v1.shape == last_v2.shape, f"{last_v1.shape} != {last_v2.shape}"
if __name__ == "__main__":
ast_strs = load_worlds(False, False, True)
random.shuffle(ast_strs)
ast_strs = ast_strs[:2000]
def interpret_ast(ast):
if ast.op in BufferOps:
test_rebuild(ast.arg.st)
else:
for src in ast.src: interpret_ast(src)
for ast_str in tqdm(ast_strs):
ast = eval(ast_str)
interpret_ast(ast)

View File

@ -17,7 +17,7 @@ class TernaryOps(Enum): MULACC = auto(); WHERE = auto() # noqa: E702
class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
class BufferOps(Enum): MEM = auto(); CONST = auto() # noqa: E702
# Ops below this line are not allowed in ASTs
class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto(); AS_STRIDED = auto() # noqa: E702
class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto() # noqa: E702
class LoadOps(Enum): EMPTY = auto(); RAND = auto(); CONST = auto(); FROM = auto(); CONTIGUOUS = auto(); CUSTOM = auto() # noqa: E702
Op = Union[UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, TernaryOps, BufferOps]

View File

@ -41,7 +41,6 @@ numpy_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
BinaryOps.DIV: lambda x, y: np.divide(*match_types(x, y)).astype(output_type(x, y), copy=False), UnaryOps.SQRT: np.sqrt,
MovementOps.PERMUTE: lambda x, order: x.transpose(order), MovementOps.PAD: np.pad, MovementOps.EXPAND: np.broadcast_to,
MovementOps.STRIDE: lambda x, arg: x[tuple(slice(None, None, i) for i in arg)],
MovementOps.AS_STRIDED: lambda x, arg: np.ndarray(arg[0], buffer=np.require(x, requirements='C'), dtype=x.dtype, offset=arg[2]*x.dtype.itemsize, strides=tuple(y*x.dtype.itemsize for y in arg[1])),
TernaryOps.MULACC: einsum_mulacc(lambda s,a,b: np.einsum(s, *match_types(a.copy(), b.copy()), optimize=True), lambda x: x.strides, np.broadcast_to),
TernaryOps.WHERE: np.where,
}}

View File

@ -24,18 +24,16 @@ class RawDiskBuffer(RawBufferMapped):
def cast(self, arg:Tuple[DType, bool]): return RawDiskBuffer(self.size, arg[0], buf=self._buf, shape=self.shape, offset=self.offset)
def reshape(self, arg): return RawDiskBuffer(self.size, self.dtype, buf=self._buf, shape=arg, offset=self.offset)
def shrink(self, arg):
assert arg[1:] == tuple([(0,x) for x in self.shape[1:]]), f"can only slice the first dim of disk tensor {arg}"
offset = arg[0][0]*prod(self.shape[1:])*self.dtype.itemsize
size = (arg[0][1]-arg[0][0]) * prod(self.shape[1:])
return RawDiskBuffer(size, self.dtype, buf=self._buf, offset=self.offset+offset, shape=(arg[0][1]-arg[0][0],)+self.shape[1:])
def as_strided(self, arg):
return RawDiskBuffer(prod(arg[0]), self.dtype, buf=self._buf, offset=self.offset+arg[2]*self.dtype.itemsize, shape=arg[0])
assert len(arg)<2 or arg[1:] == tuple([(0,x) for x in self.shape[1:]]), f"can only slice the first dim of disk tensor {arg}"
offset = arg[0][0]*(prod(self.shape[1:]) if len(arg)>1 else 1)*self.dtype.itemsize
size = (arg[0][1]-arg[0][0]) * (prod(self.shape[1:]) if len(arg)>1 else 1)
return RawDiskBuffer(size, self.dtype, buf=self._buf, offset=self.offset+offset, shape=(arg[0][1]-arg[0][0],)+(self.shape[1:] if len(arg)>1 else ()))
def _buffer(self): return memoryview(self._buf[1])[self.offset:self.offset+self.size*self.dtype.itemsize]
def readinto(self, buf):
self._buf[0].seek(self.offset)
self._buf[0].readinto(buf)
disk_fxn_for_op: Dict[Op, Callable] = { BufferOps.MEM: lambda x: x, UnaryOps.NOOP: lambda x: x, UnaryOps.CAST: RawDiskBuffer.cast, MovementOps.AS_STRIDED: RawDiskBuffer.as_strided }
disk_fxn_for_op: Dict[Op, Callable] = { BufferOps.MEM: lambda x: x, UnaryOps.NOOP: lambda x: x, UnaryOps.CAST: RawDiskBuffer.cast, MovementOps.SHRINK: RawDiskBuffer.shrink, MovementOps.RESHAPE: RawDiskBuffer.reshape }
DiskBuffer = Interpreted(RawDiskBuffer, disk_fxn_for_op)

View File

@ -24,5 +24,5 @@ class RawShmBuffer(RawBufferMapped):
def _buffer(self): return memoryview(self._buf)
# TODO: is this wrong?
shm_fxn_for_op: Dict[Op, Callable] = { BufferOps.MEM: lambda x: x, UnaryOps.NOOP: lambda x:x, MovementOps.RESHAPE: lambda x,_:x, MovementOps.AS_STRIDED: lambda x,_:x }
shm_fxn_for_op: Dict[Op, Callable] = { BufferOps.MEM: lambda x: x, UnaryOps.NOOP: lambda x:x, MovementOps.RESHAPE: lambda x,_:x }
ShmBuffer = Interpreted(RawShmBuffer, shm_fxn_for_op)

View File

@ -16,12 +16,6 @@ def match_types(x, y, disallow_bool=False):
if disallow_bool and up == torch.bool: up = torch.float
return x.type(up), y.type(up)
def as_strided(x, arg):
if any(i < 0 for i in arg[1]):
return torch.as_strided(x.contiguous(), arg[0], tuple(abs(i) for i in arg[1]),
arg[2] + sum((s-1)*a if a < 0 else 0 for (s,a) in zip(arg[0], arg[1]))).flip([i for i,a in enumerate(arg[1]) if a < 0])
return torch.as_strided(x.contiguous(), arg[0], arg[1], arg[2])
torch_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
# TODO: torch.tensor should work here
#BufferOps.CONST: lambda val, dtype: torch.tensor(val, device=device, dtype=inverse_type_map[dtype]),
@ -38,7 +32,6 @@ torch_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
TernaryOps.WHERE: lambda x, y, z: torch.where(x != 0, y, z),
MovementOps.STRIDE: lambda x, arg: x[tuple(slice(None, None, abs(i)) for i in arg)].flip([i for i,a in enumerate(arg) if a < 0]),
MovementOps.EXPAND: lambda x, arg: x.expand(arg), MovementOps.PERMUTE: lambda x, arg: x.permute(arg),
MovementOps.AS_STRIDED: as_strided
}}
class RawTorchBuffer(RawBuffer):

View File

@ -97,14 +97,35 @@ class ShapeTracker:
def to_movement_ops(self) -> List[Tuple[MovementOps, Tuple]]:
to_apply:List[Tuple[MovementOps, Tuple]] = []
for v in self.views:
for i, v in enumerate(self.views):
real_shape = tuple(y-x for x,y in v.mask) if v.mask else v.shape
real_offset = v.offset + (sum(x*st for (x,_),st in zip(v.mask, v.strides)) if v.mask else 0)
# first, we apply the offset
# then, we make it the correct shape
# then, we apply permutations
# TODO: don't use as_strided
to_apply.append((MovementOps.AS_STRIDED, (tuple([s if st != 0 else 1 for s,st in zip(real_shape, v.strides)]), v.strides, real_offset)))
offset = v.offset + sum(st*(s-1) for s,st in zip(real_shape, v.strides) if st<0)
real_offset = offset + (sum(x*st for (x,_),st in zip(v.mask, v.strides)) if v.mask else 0)
real_real_shape = [s for s,st in zip(real_shape, v.strides) if st]
strides: List[Node|int] = [abs(st) if isinstance(st,int) else st for st in v.strides if st]
buffer_size = sum((s-1)*st for s,st in zip(real_real_shape,strides)) + 1
if i: buffer_size = prod(self.views[i-1].shape) - real_offset
def sort_by_strides(shape, strides): return sorted(zip(shape, strides), key=lambda k: (k[1],-k[0]), reverse=True), sorted(range(len(strides)), key=lambda k: (strides[k],-real_real_shape[k]), reverse=True)
ordered_shape_strides, order = sort_by_strides(real_real_shape, strides)
to_apply.extend([(MovementOps.RESHAPE, (-1,)), (MovementOps.SHRINK, ((real_offset, real_offset+buffer_size),))])
if strides:
if (ordered_shape_strides[0][0]*ordered_shape_strides[0][1])-buffer_size>0: to_apply.append((MovementOps.PAD, ((0, (ordered_shape_strides[0][0] * ordered_shape_strides[0][1]) - buffer_size),)))
for i, shape_stride in enumerate(ordered_shape_strides):
if i<len(ordered_shape_strides)-1 and shape_stride[1] < ordered_shape_strides[i+1][0]*ordered_shape_strides[i+1][1]:
remaining_buffer = ordered_shape_strides[i-1][1] if i>0 else buffer_size
to_apply.append((MovementOps.EXPAND, (shape_stride[0], *(s[0] for s in ordered_shape_strides[:i]), remaining_buffer)))
to_apply.append((MovementOps.PERMUTE, (*range(1,i+1), 0, i+1)))
to_apply.append((MovementOps.RESHAPE, (*(s[0] for s in ordered_shape_strides[:i]), shape_stride[0]*remaining_buffer)))
to_apply.append((MovementOps.PAD, (*((0,0) for _ in range(i)), (0, shape_stride[0]*shape_stride[1]))))
to_apply.append((MovementOps.RESHAPE, (*(s[0] for s in ordered_shape_strides[:i+1]), remaining_buffer+shape_stride[1])))
ordered_shape_strides[i] = (ordered_shape_strides[i][0], remaining_buffer+shape_stride[1])
else:
to_apply.append((MovementOps.SHRINK, (*((0, s[0]) for s in ordered_shape_strides[:i]), (0, shape_stride[0]*shape_stride[1]))))
to_apply.append((MovementOps.RESHAPE, (*[s[0] for s in ordered_shape_strides[:i+1]], shape_stride[1])))
to_apply.extend([(MovementOps.SHRINK, (*[(0, s[0]) for s in ordered_shape_strides], (0,1))), (MovementOps.RESHAPE, tuple(s[0] for s in ordered_shape_strides))])
if order != list(range(len(order))): to_apply.append((MovementOps.PERMUTE, tuple(order.index(i) for i in range(len(strides)))))
to_apply.append((MovementOps.RESHAPE, tuple(s if st else 1 for s,st in zip(real_shape, v.strides))))
if any(i<0 for i in v.strides): to_apply.append((MovementOps.STRIDE, tuple(-1 if st<0 else 1 for st in v.strides)))
# then, we apply pre expand pads
if v.mask is not None:
pre_expand_pads = tuple((x,s-y) if st != 0 else (0,0) for (x,y),s,st in zip(v.mask, v.shape, v.strides))