
remove val expansion (#539)

* remove val expansion

* types for all shapetracker functions

* more typing

* add all the parens to the test

* more types

* fix tests

* very minor speedup
George Hotz 2023-02-07 15:14:05 -06:00 committed by GitHub
parent 001cc96e25
commit aebe75d9a2
9 changed files with 143 additions and 125 deletions
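The gist of the change: ShapeTracker movement ops now take a single tuple argument instead of unpacking varargs, and every call site passes an explicit tuple. A minimal sketch of the new calling convention, using calls that appear in the test diff below:

from tinygrad.shape import ShapeTracker

st = ShapeTracker((4, 4))
# before this commit the movement ops unpacked varargs:
#   st.permute(1, 0); st.reshape(2, 2, 2, 2)
# after it, each op takes one tuple argument:
st.permute((1, 0))
st.reshape((2, 2, 2, 2))
assert len(st.views) == 1   # permute+reshape still folds into a single view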

View File

@@ -12,22 +12,22 @@ class DumbShapeTracker:
def shape(self):
return self.t.shape
def reshape(self, *new_shape):
def reshape(self, new_shape):
self.t = self.t.reshape(new_shape)
def permute(self, *axis):
def permute(self, axis):
self.t = np.transpose(self.t, axis)
def expand(self, *new_shape):
def expand(self, new_shape):
self.t = np.broadcast_to(self.t, new_shape)
def flip(self, *axis):
def flip(self, axis):
self.t = np.flip(self.t, axis)
def shrink(self, *arg):
def shrink(self, arg):
self.t = self.t[tuple([slice(x[0], x[1]) for x in arg])]
def stride(self, *arg):
def stride(self, arg):
self.t = self.t[tuple([slice(None, None, x) for x in arg])]
def __getitem__(self, val):
@@ -39,7 +39,7 @@ class DumbShapeTracker:
class TestZeroViewShapeTracker(unittest.TestCase):
def test_pad(self):
self.st = ShapeTracker((4, 4))
self.st.pad((1, 1), (1, 1))
self.st.pad(((1, 1), (1, 1)))
assert self.st.shape == (6,6)
compareZv = ZeroView((4,4), ((-1,5), (-1,5)))
assert self.st.views[1].expr == compareZv.expr
@@ -47,88 +47,88 @@ class TestZeroViewShapeTracker(unittest.TestCase):
class TestComplexShapeTracker(unittest.TestCase):
def test_add_1s(self):
self.st = ShapeTracker((4, 4))
self.st.permute(1,0)
self.st.reshape(1,4,1,4,1)
self.st.permute((1,0))
self.st.reshape((1,4,1,4,1))
assert not self.st.contiguous
self.st.permute(0,3,2,1,4)
self.st.permute((0,3,2,1,4))
assert self.st.contiguous
def test_permute_1s_simple(self):
self.st = ShapeTracker((1, 16, 9,9))
self.st.permute(1,0,2,3)
self.st.permute((1,0,2,3))
assert self.st.contiguous
self.st = ShapeTracker((2, 16, 9,9))
self.st.permute(1,0,2,3)
self.st.permute((1,0,2,3))
assert not self.st.contiguous
def test_remove_1s_simple(self):
self.st = ShapeTracker((1, 16, 1, 1))
self.st.reshape(16,)
self.st.reshape((16,))
assert self.st.contiguous
def test_remove_1s(self):
self.st = ShapeTracker((1, 4, 1, 4, 1))
self.st.permute(0,3,2,1,4)
self.st.reshape(4,4)
self.st.permute((0,3,2,1,4))
self.st.reshape((4,4))
assert not self.st.contiguous
self.st.permute(1,0)
self.st.permute((1,0))
assert self.st.contiguous
def test_permute_reshape(self):
self.st = ShapeTracker((4, 4))
self.st.permute(1,0)
self.st.reshape(2, 2, 2, 2)
self.st.permute((1,0))
self.st.reshape((2, 2, 2, 2))
# TODO: should also be tested by test_super_complex
assert len(self.st.views) == 1
def test_factorize_split(self):
self.st = ShapeTracker((4, 4))
self.st.permute(1,0)
self.st.reshape(2, 2, 2, 2)
self.st.permute(2,3,0,1)
self.st.permute((1,0))
self.st.reshape((2, 2, 2, 2))
self.st.permute((2,3,0,1))
assert self.st.contiguous
def test_factorize_combine(self):
self.st = ShapeTracker((4, 4, 4))
self.st.permute(2, 0, 1)
self.st.reshape(4, 16)
self.st.permute(1, 0)
self.st.permute((2, 0, 1))
self.st.reshape((4, 16))
self.st.permute((1, 0))
assert self.st.contiguous
def test_factorize_combine_add_ones(self):
self.st = ShapeTracker((4, 4, 4))
self.st.permute(2, 0, 1)
self.st.reshape(4, 16, 1, 1)
self.st.permute(1, 0, 2, 3)
self.st.permute((2, 0, 1))
self.st.reshape((4, 16, 1, 1))
self.st.permute((1, 0, 2, 3))
assert self.st.contiguous
def test_fancy_factorize(self):
self.st = ShapeTracker((32, 3, 3, 1))
self.st.strided(*zip((32, 3, 3, 1), (1, 4096, 32, 1)))
self.st.reshape(*(8, 4, 3, 3))
self.st.strided(tuple(zip((32, 3, 3, 1), (1, 4096, 32, 1))))
self.st.reshape((8, 4, 3, 3))
assert len(self.st.views) == 1
def test_super_complex_2_fail(self):
self.st = ShapeTracker((4, 4, 4))
self.st.permute(2, 0, 1)
self.st.reshape(16, 4)
self.st.permute((2, 0, 1))
self.st.reshape((16, 4))
assert len(self.st.views) != 1
def test_work(self):
self.st = ShapeTracker((64, 1024, 4))
self.st.reshape(1, 64, 128, 32)
self.st.permute(0, 3, 1, 2)
self.st.reshape(1, 32, 1, 64, 128)
self.st.permute(0, 3, 4, 1, 2)
self.st.reshape((1, 64, 128, 32))
self.st.permute((0, 3, 1, 2))
self.st.reshape((1, 32, 1, 64, 128))
self.st.permute((0, 3, 4, 1, 2))
assert self.st.contiguous
def test_work2(self):
self.st = ShapeTracker((64, 1024, 4))
self.st.reshape(1, 64, 128, 32)
self.st.permute(0, 3, 1, 2)
self.st.reshape(1, 1, 32, 64, 128)
self.st.permute(0, 3, 4, 1, 2)
self.st.reshape(64, 1024, 4)
self.st.reshape((1, 64, 128, 32))
self.st.permute((0, 3, 1, 2))
self.st.reshape((1, 1, 32, 64, 128))
self.st.permute((0, 3, 4, 1, 2))
self.st.reshape((64, 1024, 4))
print(self.st.views)
assert self.st.contiguous
@@ -137,35 +137,35 @@ class TestSingleShapeTracker(unittest.TestCase):
self.st = ShapeTracker((7,4))
def test_reshape(self):
self.st.reshape(7,1,4)
self.st.reshape((7,1,4))
assert self.st.contiguous
def test_permute(self):
self.st.permute(1,0)
self.st.permute((1,0))
assert not self.st.contiguous
def test_shrink(self):
self.st.shrink((1,2), (0,4))
self.st.shrink(((1,2), (0,4)))
assert not self.st.contiguous
def test_double_permute(self):
self.st.permute(1,0)
self.st.permute(1,0)
self.st.permute((1,0))
self.st.permute((1,0))
assert self.st.contiguous
def test_reshape_permute(self):
self.st.reshape(7,1,4)
self.st.permute(0,1,2)
self.st.reshape((7,1,4))
self.st.permute((0,1,2))
assert self.st.contiguous
def test_reshape_permute_yes(self):
self.st.reshape(7,1,4)
self.st.permute(0,2,1)
self.st.reshape((7,1,4))
self.st.permute((0,2,1))
assert self.st.contiguous
def test_reshape_permute_no(self):
self.st.reshape(4,7)
self.st.permute(1,0)
self.st.reshape((4,7))
self.st.permute((1,0))
assert not self.st.contiguous
def shapetracker_getitem(st, val):
@@ -191,72 +191,72 @@ class TestShapeTracker(unittest.TestCase):
def test_simple_split(self):
self.test_permute()
self.apply(lambda x: x.reshape(prod(self.st.shape)))
self.apply(lambda x: x.reshape((prod(self.st.shape), )))
def test_reshape(self):
assert self.st.shape == self.dt.shape
new_shape = self.st.shape[::-1]
self.apply(lambda x: x.reshape(*new_shape))
self.apply(lambda x: x.reshape(new_shape))
def test_permute(self):
assert self.st.shape == self.dt.shape
if len(self.st.shape) == 2: self.apply(lambda x: x.permute(1,0))
elif len(self.st.shape) == 3: self.apply(lambda x: x.permute(2,0,1))
if len(self.st.shape) == 2: self.apply(lambda x: x.permute((1,0)))
elif len(self.st.shape) == 3: self.apply(lambda x: x.permute((2,0,1)))
def test_reshape_with_1(self):
assert self.st.shape == self.dt.shape
new_shape = [self.st.shape[0], 1, self.st.shape[1]]
self.apply(lambda x: x.reshape(*new_shape))
new_shape = (self.st.shape[0], 1, self.st.shape[1])
self.apply(lambda x: x.reshape(new_shape))
def test_expand(self):
self.test_reshape_with_1()
new_shape = list(self.st.shape)
new_shape[1] = 2
self.apply(lambda x: x.expand(*new_shape))
self.apply(lambda x: x.expand(tuple(new_shape)))
def test_flip_0(self):
self.apply(lambda x: x.flip(0))
self.apply(lambda x: x.flip((0,)))
def test_flip_1(self):
self.apply(lambda x: x.flip(1))
self.apply(lambda x: x.flip((1,)))
def test_flip_01(self):
self.apply(lambda x: x.flip(0,1))
self.apply(lambda x: x.flip((0,1)))
def test_slice_0(self):
self.apply(lambda x: x.shrink((1, x.shape[0]), (0, x.shape[1])))
self.apply(lambda x: x.shrink(((1, x.shape[0]), (0, x.shape[1]))))
def test_slice_1(self):
self.apply(lambda x: x.shrink((0, x.shape[0]), (1, x.shape[1])))
self.apply(lambda x: x.shrink(((0, x.shape[0]), (1, x.shape[1]))))
def test_slice_1c1(self):
self.apply(lambda x: x.shrink((0, 1), (0, 1)))
self.apply(lambda x: x.shrink(((0, 1), (0, 1))))
def test_slice_1c2(self):
self.apply(lambda x: x.shrink((1, 2), (1, 2)))
self.apply(lambda x: x.shrink(((1, 2), (1, 2))))
def test_double_permute(self):
self.apply(lambda x: x.permute(1, 0))
self.apply(lambda x: x.permute(1, 0))
self.apply(lambda x: x.permute((1, 0)))
self.apply(lambda x: x.permute((1, 0)))
def test_slice_permute(self):
self.apply(lambda x: x.shrink((0, 2), (2, 4)))
self.apply(lambda x: x.permute(1, 0))
self.apply(lambda x: x.shrink(((0, 2), (2, 4))))
self.apply(lambda x: x.permute((1, 0)))
def test_slice_expand(self):
self.apply(lambda x: x.shrink((0, 2), (3, 4)))
self.apply(lambda x: x.expand(2, 10))
self.apply(lambda x: x.shrink(((0, 2), (3, 4))))
self.apply(lambda x: x.expand((2, 10)))
def test_double_stride(self):
self.apply(lambda x: x.stride(1, 2))
self.apply(lambda x: x.stride(2, 1))
self.apply(lambda x: x.stride((1, 2)))
self.apply(lambda x: x.stride((2, 1)))
def test_stride(self): self.apply(lambda x: x.stride(2,1))
def test_stride_int(self): self.apply(lambda x: x.stride(1,2))
def test_stride_2(self): self.apply(lambda x: x.stride(2,2))
def test_stride_n(self): self.apply(lambda x: x.stride(-2,1))
def test_stride_int_n(self): self.apply(lambda x: x.stride(-1,2))
def test_stride_2_n(self): self.apply(lambda x: x.stride(-2,-2))
def test_stride(self): self.apply(lambda x: x.stride((2,1)))
def test_stride_int(self): self.apply(lambda x: x.stride((1,2)))
def test_stride_2(self): self.apply(lambda x: x.stride((2,2)))
def test_stride_n(self): self.apply(lambda x: x.stride((-2,1)))
def test_stride_int_n(self): self.apply(lambda x: x.stride((-1,2)))
def test_stride_2_n(self): self.apply(lambda x: x.stride((-2,-2)))
def test_reshape_then_permute(self):
self.test_reshape()
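For reference, the numpy calls that DumbShapeTracker wraps accept the same tuple arguments directly; a small standalone sketch (shapes and slice bounds are illustrative, not taken from the test file):

import numpy as np

t = np.arange(16).reshape(4, 4)
t = np.transpose(t, (1, 0))                              # permute((1, 0))
t = t[tuple(slice(b, e) for b, e in ((1, 3), (0, 4)))]   # shrink(((1, 3), (0, 4)))
t = t[tuple(slice(None, None, s) for s in (2, 1))]       # stride((2, 1))
print(t.shape)                                           # (1, 4)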

View File

@@ -130,7 +130,7 @@ class ASTKernel:
else:
rets[j].append((shapes[j][i], strides[j][i]))
for i,x in enumerate(rets): self.sts[i].reshape(*[y[0] for y in x])
for i,x in enumerate(rets): self.sts[i].reshape(tuple(y[0] for y in x))
self.first_reduce = get_first_reduce([x.shape for x in self.sts])
# this should be aware of the three parts to the shape
@@ -139,8 +139,8 @@ class ASTKernel:
# * the size outputted by each kernel
def reshape_and_permute(self, new_shape_fxn, axis):
for st in self.sts:
if new_shape_fxn is not None: st.reshape(*new_shape_fxn(st.shape))
if axis is not None: st.permute(*axis)
if new_shape_fxn is not None: st.reshape(tuple(new_shape_fxn(st.shape)))
if axis is not None: st.permute(tuple(axis))
# drops the final dimension
def upcast(self, allow_float4=True):

View File

@ -1,6 +1,5 @@
from __future__ import annotations
from typing import Optional, Tuple, Union, List, Dict
from copy import copy
from typing import Optional, Tuple, Union, List, Dict, Any
import sys, weakref
from tinygrad.helpers import ConvArgs, get_available_llops, prod
from tinygrad.shape import ShapeTracker
@@ -161,7 +160,7 @@ class LazyBuffer:
return self
reduce = list(enumerate(zip(self.shape, new_shape)))
# move the reduce axes to the end
x = self.movement_op(MovementOps.PERMUTE, [i for i,(s,n) in reduce if s == n] + [i for i,(s,n) in reduce if s != n])
x = self.movement_op(MovementOps.PERMUTE, tuple([i for i,(s,n) in reduce if s == n] + [i for i,(s,n) in reduce if s != n]))
new_tmp_shape = tuple([n for _,(s,n) in reduce if s == n] + [n for _,(s,n) in reduce if s != n])
# NOTE: this reshape can only move around 1s
return LazyBuffer(x.device, new_tmp_shape, ReduceOps, LazyOp(op, (x,), new_tmp_shape)).movement_op(MovementOps.RESHAPE, new_shape)
@@ -169,12 +168,14 @@ class LazyBuffer:
# syntactic sugar around PAD and SHRINK
# TODO: turn RESHAPE into EXPAND and CONTRACT (current EXPAND should be REPEAT)
def slice(self:LazyBuffer, arg):
padding = [(max(0, -p[0]), max(0, p[1]-self.shape[i])) for i,p in enumerate(arg)]
padding = tuple((max(0, -p[0]), max(0, p[1]-self.shape[i])) for i,p in enumerate(arg))
return self.movement_op(MovementOps.PAD, padding).movement_op(MovementOps.SHRINK, tuple((p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg)))
def movement_op(self:LazyBuffer, op:MovementOps, arg) -> LazyBuffer:
def movement_op(self:LazyBuffer, op:MovementOps, arg : Tuple[Any, ...]) -> LazyBuffer:
# very instant nop
if op == MovementOps.RESHAPE and self.shape == arg: return self
# TODO: look into why that copy is needed
arg = tuple(copy(arg))
local_st = ShapeTracker(self.shape).movement_op(op, arg)
# instant nops
@@ -232,16 +233,16 @@ class LazyBuffer:
# hack for non multiples of 4 on C.cin
if C.cin % 4 != 0 and not (C.cin == 1 and C.groups%4 == 0):
to_add = 4 - (C.cin % 4)
w = w.movement_op(MovementOps.PAD, [(0, to_add) if i == 2 else (0, 0) for i in range(len(w.shape))])
w = w.movement_op(MovementOps.PAD, tuple((0, to_add) if i == 2 else (0, 0) for i in range(len(w.shape))))
x = x.movement_op(MovementOps.RESHAPE, (C.bs, C.groups, C.cin, C.iy, C.ix))
x = x.movement_op(MovementOps.PAD, [(0, to_add) if i == 2 else (0, 0) for i in range(len(x.shape))])
x = x.movement_op(MovementOps.PAD, tuple((0, to_add) if i == 2 else (0, 0) for i in range(len(x.shape))))
C = C._replace(cin = C.cin + to_add)
x = x.movement_op(MovementOps.RESHAPE, (C.bs, C.groups*C.cin, C.iy, C.ix))
# hack for non multiples of 4 on C.rcout
if C.rcout % 4 != 0 and not (C.rcout == 1 and C.groups%4 == 0):
added_output_channels = 4 - (C.rcout % 4)
w = w.movement_op(MovementOps.PAD, [(0, added_output_channels) if i == 1 else (0, 0) for i in range(len(w.shape))])
w = w.movement_op(MovementOps.PAD, tuple((0, added_output_channels) if i == 1 else (0, 0) for i in range(len(w.shape))))
C = C._replace(rcout = C.rcout + added_output_channels, cout = C.groups * (C.rcout + added_output_channels))
# packed
@@ -307,7 +308,7 @@ class LazyBuffer:
# undo hack for non multiples of 4 on C.rcout
if added_output_channels != 0:
ret = ret.movement_op(MovementOps.RESHAPE, (C.bs, C.oy, C.ox, C.groups, C.rcout))
ret = ret.movement_op(MovementOps.SHRINK, [(0, s-added_output_channels) if i == 4 else (0, s) for i,s in enumerate(ret.shape)])
ret = ret.movement_op(MovementOps.SHRINK, tuple((0, s-added_output_channels) if i == 4 else (0, s) for i,s in enumerate(ret.shape)))
C = C._replace(rcout = C.rcout - added_output_channels, cout = C.groups * (C.rcout - added_output_channels))
ret = ret.movement_op(MovementOps.RESHAPE, (C.bs, C.oy, C.ox, C.cout))
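The PAD and SHRINK arguments above are now tuples of per-axis (begin, end) pairs rather than lists; evaluated on a hypothetical 5-axis weight shape with C.cin on axis 2 and to_add == 3, the PAD expression produces:

to_add = 3
w_shape = (16, 1, 1, 3, 3)   # hypothetical shape; axis 2 is C.cin
padding = tuple((0, to_add) if i == 2 else (0, 0) for i in range(len(w_shape)))
print(padding)               # ((0, 0), (0, 0), (0, 3), (0, 0), (0, 0))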

View File

@@ -258,8 +258,8 @@ class CLASTKernel(ASTKernel):
# early ast
accumulators : List[Token] = [Token("acc%d" % i, self.buftokens[0].typ) for i in range(self.buftokens[0].size())]
if self.reduceop:
full_shape = [x.shape for x in self.sts if x.shape != self.sts[0].shape]
full_shape = self.sts[0].shape if len(full_shape) == 0 else full_shape[0]
full_shape_candidates = [x.shape for x in self.sts if x.shape != self.sts[0].shape]
full_shape : Tuple[int, ...] = self.sts[0].shape if len(full_shape_candidates) == 0 else full_shape_candidates[0]
acc_offsets = self.buftokens[self.bufs.index(self.earlybufs[0])].acc_offsets()
self.kernel += [f"{accumulator.decltype()} {accumulator.tok} = {CLASTKernel.start_for_op[self.reduceop.op]};\n" for accumulator in accumulators]
@@ -349,7 +349,7 @@ class GPUBuffer(ExplicitExecAST):
def toCPU(self) -> np.ndarray:
cl_buf = self.unary_op(UnaryOps.NOOP) if not self.st.contiguous or prod(self._base_shape) != prod(self.shape) else self
cl_buf = cl_buf if isinstance(cl_buf._buf, CLBuffer) else self.movement_op(MovementOps.RESHAPE, list(self.shape)+[1]).unary_op(UnaryOps.NOOP)
cl_buf = cl_buf if isinstance(cl_buf._buf, CLBuffer) else self.movement_op(MovementOps.RESHAPE, tuple(list(self.shape)+[1])).unary_op(UnaryOps.NOOP)
assert prod(cl_buf._base_shape) == prod(self.shape), f"shape product mismatch {cl_buf._base_shape} vs {self.shape}"
data = np.empty(self.shape, dtype=np.float32)
cl_buf._buf.copyout(data)

View File

@@ -249,12 +249,14 @@ class LLVMBuffer(ExplicitExecAST):
def get_idxs(builder, idx, buf_index):
idx_offsets = [0]
"""
for axis in kernel_output_axis:
new_idx_offsets = []
for s in range(k.shapes[buf_index][axis]):
for i in idx_offsets:
new_idx_offsets.append(i + s * k.strides[buf_index][axis])
idx_offsets = new_idx_offsets
"""
return [builder.add(idx, int_const(i)) for i in idx_offsets]
# *** llvm specific below this line ***

View File

@@ -19,7 +19,7 @@ class CL:
cl_queue : Optional[cl.CommandQueue] = None
def __init__(self):
if CL.cl_queue is not None: return # already initted
devices = sum([x.get_devices(device_type=cl.device_type.GPU) for x in cl.get_platforms()], [])
devices : List[cl.Device] = sum([x.get_devices(device_type=cl.device_type.GPU) for x in cl.get_platforms()], [])
if len(devices) == 0: # settle for CPU
devices = sum([x.get_devices(device_type=cl.device_type.CPU) for x in cl.get_platforms()], [])
CL.cl_ctx = cl.Context(devices=[devices[getenv("CL_DEVICE", 0)]])

View File

@@ -20,8 +20,8 @@ def to_shape_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> List[Tup
return ret
class View:
def __init__(self, shape:Union[Tuple[int, ...],List[int]], strides:Union[Tuple[int, ...],List[int]], offset:int=0):
self.shape, self.strides, self.offset = tuple(shape), tuple(strides), offset
def __init__(self, shape:Tuple[int, ...], strides:Tuple[int, ...], offset:int=0):
self.shape, self.strides, self.offset = shape, strides, offset
self.shape_strides = to_shape_strides(self.shape, self.strides)
def __repr__(self): return f"View({self.shape}, {self.strides}, {self.offset})"
@@ -61,6 +61,8 @@ class ZeroView:
@property
def contiguous(self): return False
def expr_idxs(self, idxs, offset=0): raise NotImplementedError("ZeroView doesn't have expr_idxs")
def expr_node(self, valid, idx):
expr, acc = [valid] if valid is not None else [], 1
for s,ns,(x,y) in list(zip(self.old_shape, self.shape, self.arg))[::-1]:
@@ -95,7 +97,7 @@ class ShapeTracker:
def __init__(self, shape:Union[ShapeTracker, Tuple[int, ...]], views:Optional[List[ViewTypes]]=None):
self.views : List[ViewTypes] = views if views is not None else (shape.views[:] if isinstance(shape, ShapeTracker) else [view_from_shape(shape)])
def __repr__(self): return f"ShapeTracker(shape={self.shape}, views={self.views})"
def copy(self): return ShapeTracker(self.shape, self.views[:])
def copy(self) -> ShapeTracker: return ShapeTracker(self.shape, self.views[:])
@property
def contiguous(self) -> bool: return len(self.views) == 1 and self.views[-1].contiguous
@@ -132,13 +134,16 @@ class ShapeTracker:
else: return f"idx={idx}"
#def expr(self): return ';'.join([v.expr for v in self.views[::-1] if v.expr != 'idx=idx' and v.expr != 'valid=valid'])
def movement_op(self, op, arg): return getattr(self, str(op).split(".")[1].lower())(*arg)
def needs_valid(self): return any(isinstance(v, ZeroView) for v in self.views)
def movement_op(self, op, arg:Union[Tuple[int, ...], Tuple[Tuple[int, int], ...]]) -> ShapeTracker:
return getattr(self, str(op).split(".")[1].lower())(arg)
def needs_valid(self) -> bool:
return any(isinstance(v, ZeroView) for v in self.views)
# TODO: do we really need this for conv?
# if we replace, confirm the ops taken fold into one view
def strided(self, *arg):
view = View([x[0] for x in arg], [x[1] for x in arg])
def strided(self, arg : Tuple[Tuple[int, int], ...]) -> ShapeTracker:
assert isinstance(arg, tuple)
view = View(tuple(x[0] for x in arg), tuple(x[1] for x in arg))
# TODO: this does not always require a new view if non contiguous
if self.views[-1].contiguous:
self.views[-1] = view
@@ -146,22 +151,24 @@ class ShapeTracker:
self.views.append(view)
return self
def reshape(self, *new_shape):
def reshape(self, new_shape : Tuple[int, ...]) -> ShapeTracker:
assert isinstance(new_shape, tuple)
if self.shape == new_shape: return self
assert all(isinstance(x, int) and x != 0 for x in new_shape), f"shape must be ints and can't contain 0 {new_shape}"
assert prod(self.shape) == prod(new_shape), f"can't reshape {self.shape} -> {new_shape}"
# check if this is adding or removing 1s (only)
if tuple([x for x in self.shape if x != 1]) == tuple([x for x in new_shape if x != 1]):
if tuple(x for x in self.shape if x != 1) == tuple(x for x in new_shape if x != 1):
old_strides = [y for x,y in zip(self.shape, self.strides) if x != 1]
new_strides = [0 if x == 1 else old_strides.pop(0) for x in new_shape]
self.views[-1] = View(new_shape, new_strides, self.offset)
new_strides_tuple = tuple(0 if x == 1 else old_strides.pop(0) for x in new_shape)
self.views[-1] = View(new_shape, new_strides_tuple, self.offset)
return self
# check if the new dimensions factorize from the old ones
# NOTE: if you don't make a copy here, the list is popped in the lrucache
min_shape_strides = to_shape_strides(self.shape, self.strides)[:]
curr_dim, curr_stride = min_shape_strides.pop(0)
new_strides = []
new_strides : List[int] = []
for s in new_shape:
if curr_dim%s == 0:
curr_dim //= s
@@ -178,7 +185,7 @@ class ShapeTracker:
break # didn't factorize
if len(new_shape) == len(new_strides):
self.views[-1] = View(new_shape, new_strides, self.offset)
self.views[-1] = View(new_shape, tuple(new_strides), self.offset)
return self
view = View(new_shape, strides_for_shape(new_shape))
@@ -189,46 +196,52 @@ class ShapeTracker:
self.views.append(view)
return self
def permute(self, *axis):
def permute(self, axis : Tuple[int, ...]) -> ShapeTracker:
assert isinstance(axis, tuple)
assert all(isinstance(x, int) and x >= 0 and x < len(self.shape) for x in axis), f"invalid permute {axis} for {self.shape}"
assert len(set(axis)) == len(axis) and len(axis) == len(self.shape), f"can't permute {self.shape} with {axis}"
self.views[-1] = View([self.shape[a] for a in axis], [self.strides[a] for a in axis], self.offset)
self.views[-1] = View(tuple(self.shape[a] for a in axis), tuple(self.strides[a] for a in axis), self.offset)
return self
# TODO: this is a special case of slice with strides, remove it
# though it's nice that it can't change size
def flip(self, *axis): return self.stride(*[-1 if i in axis else 1 for i in range(len((self.shape)))])
def flip(self, axis : Tuple[int, ...]) -> ShapeTracker:
return self.stride(tuple(-1 if i in axis else 1 for i in range(len((self.shape)))))
# *** under this line are not invertible ***
# TODO: take this functionality out of slice
def pad(self, *arg):
def pad(self, arg : Tuple[Tuple[int, int], ...]) -> ShapeTracker:
assert isinstance(arg, tuple)
assert all((b>=0 and e>=0) for b,e in arg) and len(arg) == len(self.shape)
return self.shrink(*[(-b,s+e) for s,(b,e) in zip(self.shape, arg)])
return self.shrink(tuple((-b,s+e) for s,(b,e) in zip(self.shape, arg)))
# TODO: take the pad functionality out of shrink
def shrink(self, *arg):
def shrink(self, arg : Tuple[Tuple[int, int], ...]) -> ShapeTracker:
assert isinstance(arg, tuple)
assert len(arg) == len(self.shape)
offset = sum([self.strides[i]*x for i,(x,_) in enumerate(arg)])
zeroview = ZeroView(self.shape, arg)
self.views[-1] = View([y-x for x,y in arg], self.strides, self.offset+offset)
self.views[-1] = View(tuple(y-x for x,y in arg), self.strides, self.offset+offset)
if zeroview.expr != "valid=valid":
# if we add a ZeroView, we add another (stock) view also for modding
self.views += [zeroview, View(self.shape, strides_for_shape(self.shape))]
return self
def expand(self, *new_shape):
def expand(self, new_shape : Tuple[int, ...]) -> ShapeTracker:
assert isinstance(new_shape, tuple)
assert all(isinstance(x, int) for x in new_shape)
assert all(x == y or x == 1 for x,y in zip(self.shape, new_shape)), f"can't expand {self.shape} into {new_shape}"
strides = [s if x == y else 0 for s,(x,y) in zip(self.strides, zip(self.shape, new_shape))]
strides : Tuple[int, ...] = tuple(s if x == y else 0 for s,(x,y) in zip(self.strides, zip(self.shape, new_shape)))
self.views[-1] = View(new_shape, strides, self.offset)
return self
# TODO: combine with slice? this doesn't require a ZeroView, though slice shouldn't always either
def stride(self, *mul):
def stride(self, mul : Tuple[int, ...]) -> ShapeTracker:
assert isinstance(mul, tuple)
assert all(isinstance(x, int) for x in mul)
strides = [z*m for z,m in zip(self.strides, mul)]
new_shape = [(s+(abs(m)-1))//abs(m) for s,m in zip(self.shape, mul)]
strides = tuple(z*m for z,m in zip(self.strides, mul))
new_shape = tuple((s+(abs(m)-1))//abs(m) for s,m in zip(self.shape, mul))
offset = sum([(s-1)*z for s,z,m in zip(self.shape, self.strides, mul) if m < 0])
self.views[-1] = View(new_shape, strides, self.offset + offset)
return self
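As a worked check of the stride math above (not from the source): a contiguous (4, 4) tracker has strides (4, 1); stride((-2, 1)) multiplies the strides to (-8, 1), shrinks the shape to ((4+1)//2, 4) == (2, 4), and adds an offset of (4-1)*4 == 12 for the negated axis.

from tinygrad.shape import ShapeTracker

st = ShapeTracker((4, 4))     # View((4, 4), (4, 1), 0)
st.stride((-2, 1))
print(st.views[-1])           # View((2, 4), (-8, 1), 12)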

View File

@@ -96,6 +96,7 @@ class NumNode(Node):
self.b, self.min, self.max = num, num, num
class OpNode(Node):
op : str
def __init__(self, a:Node, b:int):
self.a, self.b = a,b
self.min, self.max = self.minmax(a,b)
@@ -105,6 +106,7 @@ class OpNode(Node):
return f"({self.a}{self.op}{self.b})"
class RedNode(Node):
op : str
def __init__(self, nodes:List[Node]):
self.nodes = nodes
self.min, self.max = self.minmax(nodes)

View File

@@ -168,12 +168,12 @@ class Tensor:
dim = (dim + len(self.shape)) if dim < 0 else dim
for y in args:
assert len(y.shape) == len(self.shape) and all(y.shape[i] == s for i,s in enumerate(self.shape) if i != dim)
args = [self] + list(args)
shape_cumsum = [0, *itertools.accumulate(y.shape[dim] for y in args)]
slc = [[(0, s) for s in self.shape] for _ in args]
catargs = [self] + list(args)
shape_cumsum = [0, *itertools.accumulate(y.shape[dim] for y in catargs)]
slc = [[(0, s) for s in self.shape] for _ in catargs]
for s,k in zip(slc, shape_cumsum):
s[dim] = (-k, shape_cumsum[-1]-k)
return functools.reduce(Tensor.__iadd__, [arg.slice(arg=s) for arg,s in zip(args, slc)])
return functools.reduce(Tensor.__iadd__, [arg.slice(arg=s) for arg,s in zip(catargs, slc)])
# TODO: make this nicer with syntactic sugar in slice
def chunk(self, num, dim):
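The cat change above is only a rename (args to catargs) so the input list no longer shadows the *args parameter, but the slice windows it builds are worth a concrete check: each input is sliced out to the full result size along dim, zero-padded outside its own window, and the windows are summed. A numpy emulation of that placement (shapes are illustrative):

import numpy as np

a, b = np.ones((2, 3)), 2 * np.ones((2, 3))
out_a = np.zeros((2, 6)); out_a[:, 0:3] = a   # slice arg ((0, 2), (0, 6))
out_b = np.zeros((2, 6)); out_b[:, 3:6] = b   # slice arg ((0, 2), (-3, 3))
assert (out_a + out_b == np.concatenate([a, b], axis=1)).all()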