remove val expansion (#539)
* remove val expansion * types for all shapetracker functions: * more typing * add all the parens to the test * more types * fix tests * very minor speeduppull/542/head
parent
001cc96e25
commit
aebe75d9a2
|
@ -12,22 +12,22 @@ class DumbShapeTracker:
|
|||
def shape(self):
|
||||
return self.t.shape
|
||||
|
||||
def reshape(self, *new_shape):
|
||||
def reshape(self, new_shape):
|
||||
self.t = self.t.reshape(new_shape)
|
||||
|
||||
def permute(self, *axis):
|
||||
def permute(self, axis):
|
||||
self.t = np.transpose(self.t, axis)
|
||||
|
||||
def expand(self, *new_shape):
|
||||
def expand(self, new_shape):
|
||||
self.t = np.broadcast_to(self.t, new_shape)
|
||||
|
||||
def flip(self, *axis):
|
||||
def flip(self, axis):
|
||||
self.t = np.flip(self.t, axis)
|
||||
|
||||
def shrink(self, *arg):
|
||||
def shrink(self, arg):
|
||||
self.t = self.t[tuple([slice(x[0], x[1]) for x in arg])]
|
||||
|
||||
def stride(self, *arg):
|
||||
def stride(self, arg):
|
||||
self.t = self.t[tuple([slice(None, None, x) for x in arg])]
|
||||
|
||||
def __getitem__(self, val):
|
||||
|
@ -39,7 +39,7 @@ class DumbShapeTracker:
|
|||
class TestZeroViewShapeTracker(unittest.TestCase):
|
||||
def test_pad(self):
|
||||
self.st = ShapeTracker((4, 4))
|
||||
self.st.pad((1, 1), (1, 1))
|
||||
self.st.pad(((1, 1), (1, 1)))
|
||||
assert self.st.shape == (6,6)
|
||||
compareZv = ZeroView((4,4), ((-1,5), (-1,5)))
|
||||
assert self.st.views[1].expr == compareZv.expr
|
||||
|
@ -47,88 +47,88 @@ class TestZeroViewShapeTracker(unittest.TestCase):
|
|||
class TestComplexShapeTracker(unittest.TestCase):
|
||||
def test_add_1s(self):
|
||||
self.st = ShapeTracker((4, 4))
|
||||
self.st.permute(1,0)
|
||||
self.st.reshape(1,4,1,4,1)
|
||||
self.st.permute((1,0))
|
||||
self.st.reshape((1,4,1,4,1))
|
||||
assert not self.st.contiguous
|
||||
self.st.permute(0,3,2,1,4)
|
||||
self.st.permute((0,3,2,1,4))
|
||||
assert self.st.contiguous
|
||||
|
||||
def test_permute_1s_simple(self):
|
||||
self.st = ShapeTracker((1, 16, 9,9))
|
||||
self.st.permute(1,0,2,3)
|
||||
self.st.permute((1,0,2,3))
|
||||
assert self.st.contiguous
|
||||
self.st = ShapeTracker((2, 16, 9,9))
|
||||
self.st.permute(1,0,2,3)
|
||||
self.st.permute((1,0,2,3))
|
||||
assert not self.st.contiguous
|
||||
|
||||
def test_remove_1s_simple(self):
|
||||
self.st = ShapeTracker((1, 16, 1, 1))
|
||||
self.st.reshape(16,)
|
||||
self.st.reshape((16,))
|
||||
assert self.st.contiguous
|
||||
|
||||
def test_remove_1s(self):
|
||||
self.st = ShapeTracker((1, 4, 1, 4, 1))
|
||||
self.st.permute(0,3,2,1,4)
|
||||
self.st.reshape(4,4)
|
||||
self.st.permute((0,3,2,1,4))
|
||||
self.st.reshape((4,4))
|
||||
assert not self.st.contiguous
|
||||
self.st.permute(1,0)
|
||||
self.st.permute((1,0))
|
||||
assert self.st.contiguous
|
||||
|
||||
def test_permute_reshape(self):
|
||||
self.st = ShapeTracker((4, 4))
|
||||
self.st.permute(1,0)
|
||||
self.st.reshape(2, 2, 2, 2)
|
||||
self.st.permute((1,0))
|
||||
self.st.reshape((2, 2, 2, 2))
|
||||
# TODO: should also be tested by test_super_complex
|
||||
assert len(self.st.views) == 1
|
||||
|
||||
def test_factorize_split(self):
|
||||
self.st = ShapeTracker((4, 4))
|
||||
self.st.permute(1,0)
|
||||
self.st.reshape(2, 2, 2, 2)
|
||||
self.st.permute(2,3,0,1)
|
||||
self.st.permute((1,0))
|
||||
self.st.reshape((2, 2, 2, 2))
|
||||
self.st.permute((2,3,0,1))
|
||||
assert self.st.contiguous
|
||||
|
||||
def test_factorize_combine(self):
|
||||
self.st = ShapeTracker((4, 4, 4))
|
||||
self.st.permute(2, 0, 1)
|
||||
self.st.reshape(4, 16)
|
||||
self.st.permute(1, 0)
|
||||
self.st.permute((2, 0, 1))
|
||||
self.st.reshape((4, 16))
|
||||
self.st.permute((1, 0))
|
||||
assert self.st.contiguous
|
||||
|
||||
def test_factorize_combine_add_ones(self):
|
||||
self.st = ShapeTracker((4, 4, 4))
|
||||
self.st.permute(2, 0, 1)
|
||||
self.st.reshape(4, 16, 1, 1)
|
||||
self.st.permute(1, 0, 2, 3)
|
||||
self.st.permute((2, 0, 1))
|
||||
self.st.reshape((4, 16, 1, 1))
|
||||
self.st.permute((1, 0, 2, 3))
|
||||
assert self.st.contiguous
|
||||
|
||||
def test_fancy_factorize(self):
|
||||
self.st = ShapeTracker((32, 3, 3, 1))
|
||||
self.st.strided(*zip((32, 3, 3, 1), (1, 4096, 32, 1)))
|
||||
self.st.reshape(*(8, 4, 3, 3))
|
||||
self.st.strided(tuple(zip((32, 3, 3, 1), (1, 4096, 32, 1))))
|
||||
self.st.reshape((8, 4, 3, 3))
|
||||
assert len(self.st.views) == 1
|
||||
|
||||
def test_super_complex_2_fail(self):
|
||||
self.st = ShapeTracker((4, 4, 4))
|
||||
self.st.permute(2, 0, 1)
|
||||
self.st.reshape(16, 4)
|
||||
self.st.permute((2, 0, 1))
|
||||
self.st.reshape((16, 4))
|
||||
assert len(self.st.views) != 1
|
||||
|
||||
def test_work(self):
|
||||
self.st = ShapeTracker((64, 1024, 4))
|
||||
self.st.reshape(1, 64, 128, 32)
|
||||
self.st.permute(0, 3, 1, 2)
|
||||
self.st.reshape(1, 32, 1, 64, 128)
|
||||
self.st.permute(0, 3, 4, 1, 2)
|
||||
self.st.reshape((1, 64, 128, 32))
|
||||
self.st.permute((0, 3, 1, 2))
|
||||
self.st.reshape((1, 32, 1, 64, 128))
|
||||
self.st.permute((0, 3, 4, 1, 2))
|
||||
assert self.st.contiguous
|
||||
|
||||
def test_work2(self):
|
||||
self.st = ShapeTracker((64, 1024, 4))
|
||||
self.st.reshape(1, 64, 128, 32)
|
||||
self.st.permute(0, 3, 1, 2)
|
||||
self.st.reshape(1, 1, 32, 64, 128)
|
||||
self.st.permute(0, 3, 4, 1, 2)
|
||||
self.st.reshape(64, 1024, 4)
|
||||
self.st.reshape((1, 64, 128, 32))
|
||||
self.st.permute((0, 3, 1, 2))
|
||||
self.st.reshape((1, 1, 32, 64, 128))
|
||||
self.st.permute((0, 3, 4, 1, 2))
|
||||
self.st.reshape((64, 1024, 4))
|
||||
print(self.st.views)
|
||||
assert self.st.contiguous
|
||||
|
||||
|
@ -137,35 +137,35 @@ class TestSingleShapeTracker(unittest.TestCase):
|
|||
self.st = ShapeTracker((7,4))
|
||||
|
||||
def test_reshape(self):
|
||||
self.st.reshape(7,1,4)
|
||||
self.st.reshape((7,1,4))
|
||||
assert self.st.contiguous
|
||||
|
||||
def test_permute(self):
|
||||
self.st.permute(1,0)
|
||||
self.st.permute((1,0))
|
||||
assert not self.st.contiguous
|
||||
|
||||
def test_shrink(self):
|
||||
self.st.shrink((1,2), (0,4))
|
||||
self.st.shrink(((1,2), (0,4)))
|
||||
assert not self.st.contiguous
|
||||
|
||||
def test_double_permute(self):
|
||||
self.st.permute(1,0)
|
||||
self.st.permute(1,0)
|
||||
self.st.permute((1,0))
|
||||
self.st.permute((1,0))
|
||||
assert self.st.contiguous
|
||||
|
||||
def test_reshape_permute(self):
|
||||
self.st.reshape(7,1,4)
|
||||
self.st.permute(0,1,2)
|
||||
self.st.reshape((7,1,4))
|
||||
self.st.permute((0,1,2))
|
||||
assert self.st.contiguous
|
||||
|
||||
def test_reshape_permute_yes(self):
|
||||
self.st.reshape(7,1,4)
|
||||
self.st.permute(0,2,1)
|
||||
self.st.reshape((7,1,4))
|
||||
self.st.permute((0,2,1))
|
||||
assert self.st.contiguous
|
||||
|
||||
def test_reshape_permute_no(self):
|
||||
self.st.reshape(4,7)
|
||||
self.st.permute(1,0)
|
||||
self.st.reshape((4,7))
|
||||
self.st.permute((1,0))
|
||||
assert not self.st.contiguous
|
||||
|
||||
def shapetracker_getitem(st, val):
|
||||
|
@ -191,72 +191,72 @@ class TestShapeTracker(unittest.TestCase):
|
|||
|
||||
def test_simple_split(self):
|
||||
self.test_permute()
|
||||
self.apply(lambda x: x.reshape(prod(self.st.shape)))
|
||||
self.apply(lambda x: x.reshape((prod(self.st.shape), )))
|
||||
|
||||
def test_reshape(self):
|
||||
assert self.st.shape == self.dt.shape
|
||||
new_shape = self.st.shape[::-1]
|
||||
self.apply(lambda x: x.reshape(*new_shape))
|
||||
self.apply(lambda x: x.reshape(new_shape))
|
||||
|
||||
def test_permute(self):
|
||||
assert self.st.shape == self.dt.shape
|
||||
if len(self.st.shape) == 2: self.apply(lambda x: x.permute(1,0))
|
||||
elif len(self.st.shape) == 3: self.apply(lambda x: x.permute(2,0,1))
|
||||
if len(self.st.shape) == 2: self.apply(lambda x: x.permute((1,0)))
|
||||
elif len(self.st.shape) == 3: self.apply(lambda x: x.permute((2,0,1)))
|
||||
|
||||
def test_reshape_with_1(self):
|
||||
assert self.st.shape == self.dt.shape
|
||||
new_shape = [self.st.shape[0], 1, self.st.shape[1]]
|
||||
self.apply(lambda x: x.reshape(*new_shape))
|
||||
new_shape = (self.st.shape[0], 1, self.st.shape[1])
|
||||
self.apply(lambda x: x.reshape(new_shape))
|
||||
|
||||
def test_expand(self):
|
||||
self.test_reshape_with_1()
|
||||
new_shape = list(self.st.shape)
|
||||
new_shape[1] = 2
|
||||
self.apply(lambda x: x.expand(*new_shape))
|
||||
self.apply(lambda x: x.expand(tuple(new_shape)))
|
||||
|
||||
def test_flip_0(self):
|
||||
self.apply(lambda x: x.flip(0))
|
||||
self.apply(lambda x: x.flip((0,)))
|
||||
|
||||
def test_flip_1(self):
|
||||
self.apply(lambda x: x.flip(1))
|
||||
self.apply(lambda x: x.flip((1,)))
|
||||
|
||||
def test_flip_01(self):
|
||||
self.apply(lambda x: x.flip(0,1))
|
||||
self.apply(lambda x: x.flip((0,1)))
|
||||
|
||||
def test_slice_0(self):
|
||||
self.apply(lambda x: x.shrink((1, x.shape[0]), (0, x.shape[1])))
|
||||
self.apply(lambda x: x.shrink(((1, x.shape[0]), (0, x.shape[1]))))
|
||||
|
||||
def test_slice_1(self):
|
||||
self.apply(lambda x: x.shrink((0, x.shape[0]), (1, x.shape[1])))
|
||||
self.apply(lambda x: x.shrink(((0, x.shape[0]), (1, x.shape[1]))))
|
||||
|
||||
def test_slice_1c1(self):
|
||||
self.apply(lambda x: x.shrink((0, 1), (0, 1)))
|
||||
self.apply(lambda x: x.shrink(((0, 1), (0, 1))))
|
||||
|
||||
def test_slice_1c2(self):
|
||||
self.apply(lambda x: x.shrink((1, 2), (1, 2)))
|
||||
self.apply(lambda x: x.shrink(((1, 2), (1, 2))))
|
||||
|
||||
def test_double_permute(self):
|
||||
self.apply(lambda x: x.permute(1, 0))
|
||||
self.apply(lambda x: x.permute(1, 0))
|
||||
self.apply(lambda x: x.permute((1, 0)))
|
||||
self.apply(lambda x: x.permute((1, 0)))
|
||||
|
||||
def test_slice_permute(self):
|
||||
self.apply(lambda x: x.shrink((0, 2), (2, 4)))
|
||||
self.apply(lambda x: x.permute(1, 0))
|
||||
self.apply(lambda x: x.shrink(((0, 2), (2, 4))))
|
||||
self.apply(lambda x: x.permute((1, 0)))
|
||||
|
||||
def test_slice_expand(self):
|
||||
self.apply(lambda x: x.shrink((0, 2), (3, 4)))
|
||||
self.apply(lambda x: x.expand(2, 10))
|
||||
self.apply(lambda x: x.shrink(((0, 2), (3, 4))))
|
||||
self.apply(lambda x: x.expand((2, 10)))
|
||||
|
||||
def test_double_stride(self):
|
||||
self.apply(lambda x: x.stride(1, 2))
|
||||
self.apply(lambda x: x.stride(2, 1))
|
||||
self.apply(lambda x: x.stride((1, 2)))
|
||||
self.apply(lambda x: x.stride((2, 1)))
|
||||
|
||||
def test_stride(self): self.apply(lambda x: x.stride(2,1))
|
||||
def test_stride_int(self): self.apply(lambda x: x.stride(1,2))
|
||||
def test_stride_2(self): self.apply(lambda x: x.stride(2,2))
|
||||
def test_stride_n(self): self.apply(lambda x: x.stride(-2,1))
|
||||
def test_stride_int_n(self): self.apply(lambda x: x.stride(-1,2))
|
||||
def test_stride_2_n(self): self.apply(lambda x: x.stride(-2,-2))
|
||||
def test_stride(self): self.apply(lambda x: x.stride((2,1)))
|
||||
def test_stride_int(self): self.apply(lambda x: x.stride((1,2)))
|
||||
def test_stride_2(self): self.apply(lambda x: x.stride((2,2)))
|
||||
def test_stride_n(self): self.apply(lambda x: x.stride((-2,1)))
|
||||
def test_stride_int_n(self): self.apply(lambda x: x.stride((-1,2)))
|
||||
def test_stride_2_n(self): self.apply(lambda x: x.stride((-2,-2)))
|
||||
|
||||
def test_reshape_then_permute(self):
|
||||
self.test_reshape()
|
||||
|
|
|
@ -130,7 +130,7 @@ class ASTKernel:
|
|||
else:
|
||||
rets[j].append((shapes[j][i], strides[j][i]))
|
||||
|
||||
for i,x in enumerate(rets): self.sts[i].reshape(*[y[0] for y in x])
|
||||
for i,x in enumerate(rets): self.sts[i].reshape(tuple(y[0] for y in x))
|
||||
self.first_reduce = get_first_reduce([x.shape for x in self.sts])
|
||||
|
||||
# this should be aware of the three parts to the shape
|
||||
|
@ -139,8 +139,8 @@ class ASTKernel:
|
|||
# * the size outputted by each kernel
|
||||
def reshape_and_permute(self, new_shape_fxn, axis):
|
||||
for st in self.sts:
|
||||
if new_shape_fxn is not None: st.reshape(*new_shape_fxn(st.shape))
|
||||
if axis is not None: st.permute(*axis)
|
||||
if new_shape_fxn is not None: st.reshape(tuple(new_shape_fxn(st.shape)))
|
||||
if axis is not None: st.permute(tuple(axis))
|
||||
|
||||
# drops the final dimension
|
||||
def upcast(self, allow_float4=True):
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
from __future__ import annotations
|
||||
from typing import Optional, Tuple, Union, List, Dict
|
||||
from copy import copy
|
||||
from typing import Optional, Tuple, Union, List, Dict, Any
|
||||
import sys, weakref
|
||||
from tinygrad.helpers import ConvArgs, get_available_llops, prod
|
||||
from tinygrad.shape import ShapeTracker
|
||||
|
@ -161,7 +160,7 @@ class LazyBuffer:
|
|||
return self
|
||||
reduce = list(enumerate(zip(self.shape, new_shape)))
|
||||
# move the reduce axes to the end
|
||||
x = self.movement_op(MovementOps.PERMUTE, [i for i,(s,n) in reduce if s == n] + [i for i,(s,n) in reduce if s != n])
|
||||
x = self.movement_op(MovementOps.PERMUTE, tuple([i for i,(s,n) in reduce if s == n] + [i for i,(s,n) in reduce if s != n]))
|
||||
new_tmp_shape = tuple([n for _,(s,n) in reduce if s == n] + [n for _,(s,n) in reduce if s != n])
|
||||
# NOTE: this reshape can only move around 1s
|
||||
return LazyBuffer(x.device, new_tmp_shape, ReduceOps, LazyOp(op, (x,), new_tmp_shape)).movement_op(MovementOps.RESHAPE, new_shape)
|
||||
|
@ -169,12 +168,14 @@ class LazyBuffer:
|
|||
# syntactic sugar around PAD and SHRINK
|
||||
# TODO: turn RESHAPE into EXPAND and CONTRACT (current EXPAND should be REPEAT)
|
||||
def slice(self:LazyBuffer, arg):
|
||||
padding = [(max(0, -p[0]), max(0, p[1]-self.shape[i])) for i,p in enumerate(arg)]
|
||||
padding = tuple((max(0, -p[0]), max(0, p[1]-self.shape[i])) for i,p in enumerate(arg))
|
||||
return self.movement_op(MovementOps.PAD, padding).movement_op(MovementOps.SHRINK, tuple((p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg)))
|
||||
|
||||
def movement_op(self:LazyBuffer, op:MovementOps, arg) -> LazyBuffer:
|
||||
def movement_op(self:LazyBuffer, op:MovementOps, arg : Tuple[Any, ...]) -> LazyBuffer:
|
||||
# very instant nop
|
||||
if op == MovementOps.RESHAPE and self.shape == arg: return self
|
||||
|
||||
# TODO: look into why that copy is needed
|
||||
arg = tuple(copy(arg))
|
||||
local_st = ShapeTracker(self.shape).movement_op(op, arg)
|
||||
|
||||
# instant nops
|
||||
|
@ -232,16 +233,16 @@ class LazyBuffer:
|
|||
# hack for non multiples of 4 on C.cin
|
||||
if C.cin % 4 != 0 and not (C.cin == 1 and C.groups%4 == 0):
|
||||
to_add = 4 - (C.cin % 4)
|
||||
w = w.movement_op(MovementOps.PAD, [(0, to_add) if i == 2 else (0, 0) for i in range(len(w.shape))])
|
||||
w = w.movement_op(MovementOps.PAD, tuple((0, to_add) if i == 2 else (0, 0) for i in range(len(w.shape))))
|
||||
x = x.movement_op(MovementOps.RESHAPE, (C.bs, C.groups, C.cin, C.iy, C.ix))
|
||||
x = x.movement_op(MovementOps.PAD, [(0, to_add) if i == 2 else (0, 0) for i in range(len(x.shape))])
|
||||
x = x.movement_op(MovementOps.PAD, tuple((0, to_add) if i == 2 else (0, 0) for i in range(len(x.shape))))
|
||||
C = C._replace(cin = C.cin + to_add)
|
||||
x = x.movement_op(MovementOps.RESHAPE, (C.bs, C.groups*C.cin, C.iy, C.ix))
|
||||
|
||||
# hack for non multiples of 4 on C.rcout
|
||||
if C.rcout % 4 != 0 and not (C.rcout == 1 and C.groups%4 == 0):
|
||||
added_output_channels = 4 - (C.rcout % 4)
|
||||
w = w.movement_op(MovementOps.PAD, [(0, added_output_channels) if i == 1 else (0, 0) for i in range(len(w.shape))])
|
||||
w = w.movement_op(MovementOps.PAD, tuple((0, added_output_channels) if i == 1 else (0, 0) for i in range(len(w.shape))))
|
||||
C = C._replace(rcout = C.rcout + added_output_channels, cout = C.groups * (C.rcout + added_output_channels))
|
||||
|
||||
# packed
|
||||
|
@ -307,7 +308,7 @@ class LazyBuffer:
|
|||
# undo hack for non multiples of 4 on C.rcout
|
||||
if added_output_channels != 0:
|
||||
ret = ret.movement_op(MovementOps.RESHAPE, (C.bs, C.oy, C.ox, C.groups, C.rcout))
|
||||
ret = ret.movement_op(MovementOps.SHRINK, [(0, s-added_output_channels) if i == 4 else (0, s) for i,s in enumerate(ret.shape)])
|
||||
ret = ret.movement_op(MovementOps.SHRINK, tuple((0, s-added_output_channels) if i == 4 else (0, s) for i,s in enumerate(ret.shape)))
|
||||
C = C._replace(rcout = C.rcout - added_output_channels, cout = C.groups * (C.rcout - added_output_channels))
|
||||
|
||||
ret = ret.movement_op(MovementOps.RESHAPE, (C.bs, C.oy, C.ox, C.cout))
|
||||
|
|
|
@ -258,8 +258,8 @@ class CLASTKernel(ASTKernel):
|
|||
# early ast
|
||||
accumulators : List[Token] = [Token("acc%d" % i, self.buftokens[0].typ) for i in range(self.buftokens[0].size())]
|
||||
if self.reduceop:
|
||||
full_shape = [x.shape for x in self.sts if x.shape != self.sts[0].shape]
|
||||
full_shape = self.sts[0].shape if len(full_shape) == 0 else full_shape[0]
|
||||
full_shape_candidates = [x.shape for x in self.sts if x.shape != self.sts[0].shape]
|
||||
full_shape : Tuple[int, ...] = self.sts[0].shape if len(full_shape_candidates) == 0 else full_shape_candidates[0]
|
||||
|
||||
acc_offsets = self.buftokens[self.bufs.index(self.earlybufs[0])].acc_offsets()
|
||||
self.kernel += [f"{accumulator.decltype()} {accumulator.tok} = {CLASTKernel.start_for_op[self.reduceop.op]};\n" for accumulator in accumulators]
|
||||
|
@ -349,7 +349,7 @@ class GPUBuffer(ExplicitExecAST):
|
|||
|
||||
def toCPU(self) -> np.ndarray:
|
||||
cl_buf = self.unary_op(UnaryOps.NOOP) if not self.st.contiguous or prod(self._base_shape) != prod(self.shape) else self
|
||||
cl_buf = cl_buf if isinstance(cl_buf._buf, CLBuffer) else self.movement_op(MovementOps.RESHAPE, list(self.shape)+[1]).unary_op(UnaryOps.NOOP)
|
||||
cl_buf = cl_buf if isinstance(cl_buf._buf, CLBuffer) else self.movement_op(MovementOps.RESHAPE, tuple(list(self.shape)+[1])).unary_op(UnaryOps.NOOP)
|
||||
assert prod(cl_buf._base_shape) == prod(self.shape), f"shape product mismatch {cl_buf._base_shape} vs {self.shape}"
|
||||
data = np.empty(self.shape, dtype=np.float32)
|
||||
cl_buf._buf.copyout(data)
|
||||
|
|
|
@ -249,12 +249,14 @@ class LLVMBuffer(ExplicitExecAST):
|
|||
|
||||
def get_idxs(builder, idx, buf_index):
|
||||
idx_offsets = [0]
|
||||
"""
|
||||
for axis in kernel_output_axis:
|
||||
new_idx_offsets = []
|
||||
for s in range(k.shapes[buf_index][axis]):
|
||||
for i in idx_offsets:
|
||||
new_idx_offsets.append(i + s * k.strides[buf_index][axis])
|
||||
idx_offsets = new_idx_offsets
|
||||
"""
|
||||
return [builder.add(idx, int_const(i)) for i in idx_offsets]
|
||||
|
||||
# *** llvm specific below this line ***
|
||||
|
|
|
@ -19,7 +19,7 @@ class CL:
|
|||
cl_queue : Optional[cl.CommandQueue] = None
|
||||
def __init__(self):
|
||||
if CL.cl_queue is not None: return # already initted
|
||||
devices = sum([x.get_devices(device_type=cl.device_type.GPU) for x in cl.get_platforms()], [])
|
||||
devices : List[cl.Device] = sum([x.get_devices(device_type=cl.device_type.GPU) for x in cl.get_platforms()], [])
|
||||
if len(devices) == 0: # settle for CPU
|
||||
devices = sum([x.get_devices(device_type=cl.device_type.CPU) for x in cl.get_platforms()], [])
|
||||
CL.cl_ctx = cl.Context(devices=[devices[getenv("CL_DEVICE", 0)]])
|
||||
|
|
|
@ -20,8 +20,8 @@ def to_shape_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> List[Tup
|
|||
return ret
|
||||
|
||||
class View:
|
||||
def __init__(self, shape:Union[Tuple[int, ...],List[int]], strides:Union[Tuple[int, ...],List[int]], offset:int=0):
|
||||
self.shape, self.strides, self.offset = tuple(shape), tuple(strides), offset
|
||||
def __init__(self, shape:Tuple[int, ...], strides:Tuple[int, ...], offset:int=0):
|
||||
self.shape, self.strides, self.offset = shape, strides, offset
|
||||
self.shape_strides = to_shape_strides(self.shape, self.strides)
|
||||
|
||||
def __repr__(self): return f"View({self.shape}, {self.strides}, {self.offset})"
|
||||
|
@ -61,6 +61,8 @@ class ZeroView:
|
|||
@property
|
||||
def contiguous(self): return False
|
||||
|
||||
def expr_idxs(self, idxs, offset=0): raise NotImplementedError("ZeroView doesn't have expr_idxs")
|
||||
|
||||
def expr_node(self, valid, idx):
|
||||
expr, acc = [valid] if valid is not None else [], 1
|
||||
for s,ns,(x,y) in list(zip(self.old_shape, self.shape, self.arg))[::-1]:
|
||||
|
@ -95,7 +97,7 @@ class ShapeTracker:
|
|||
def __init__(self, shape:Union[ShapeTracker, Tuple[int, ...]], views:Optional[List[ViewTypes]]=None):
|
||||
self.views : List[ViewTypes] = views if views is not None else (shape.views[:] if isinstance(shape, ShapeTracker) else [view_from_shape(shape)])
|
||||
def __repr__(self): return f"ShapeTracker(shape={self.shape}, views={self.views})"
|
||||
def copy(self): return ShapeTracker(self.shape, self.views[:])
|
||||
def copy(self) -> ShapeTracker: return ShapeTracker(self.shape, self.views[:])
|
||||
|
||||
@property
|
||||
def contiguous(self) -> bool: return len(self.views) == 1 and self.views[-1].contiguous
|
||||
|
@ -132,13 +134,16 @@ class ShapeTracker:
|
|||
else: return f"idx={idx}"
|
||||
|
||||
#def expr(self): return ';'.join([v.expr for v in self.views[::-1] if v.expr != 'idx=idx' and v.expr != 'valid=valid'])
|
||||
def movement_op(self, op, arg): return getattr(self, str(op).split(".")[1].lower())(*arg)
|
||||
def needs_valid(self): return any(isinstance(v, ZeroView) for v in self.views)
|
||||
def movement_op(self, op, arg:Union[Tuple[int, ...], Tuple[Tuple[int, int], ...]]) -> ShapeTracker:
|
||||
return getattr(self, str(op).split(".")[1].lower())(arg)
|
||||
def needs_valid(self) -> bool:
|
||||
return any(isinstance(v, ZeroView) for v in self.views)
|
||||
|
||||
# TODO: do we really need this for conv?
|
||||
# if we replace, confirm the ops taken fold into one view
|
||||
def strided(self, *arg):
|
||||
view = View([x[0] for x in arg], [x[1] for x in arg])
|
||||
def strided(self, arg : Tuple[Tuple[int, int], ...]) -> ShapeTracker:
|
||||
assert isinstance(arg, tuple)
|
||||
view = View(tuple(x[0] for x in arg), tuple(x[1] for x in arg))
|
||||
# TODO: this does not always require a new view if non contiguous
|
||||
if self.views[-1].contiguous:
|
||||
self.views[-1] = view
|
||||
|
@ -146,22 +151,24 @@ class ShapeTracker:
|
|||
self.views.append(view)
|
||||
return self
|
||||
|
||||
def reshape(self, *new_shape):
|
||||
def reshape(self, new_shape : Tuple[int, ...]) -> ShapeTracker:
|
||||
assert isinstance(new_shape, tuple)
|
||||
if self.shape == new_shape: return self
|
||||
assert all(isinstance(x, int) and x != 0 for x in new_shape), f"shape must be ints and can't contain 0 {new_shape}"
|
||||
assert prod(self.shape) == prod(new_shape), f"can't reshape {self.shape} -> {new_shape}"
|
||||
|
||||
# check if this is adding or removing 1s (only)
|
||||
if tuple([x for x in self.shape if x != 1]) == tuple([x for x in new_shape if x != 1]):
|
||||
if tuple(x for x in self.shape if x != 1) == tuple(x for x in new_shape if x != 1):
|
||||
old_strides = [y for x,y in zip(self.shape, self.strides) if x != 1]
|
||||
new_strides = [0 if x == 1 else old_strides.pop(0) for x in new_shape]
|
||||
self.views[-1] = View(new_shape, new_strides, self.offset)
|
||||
new_strides_tuple = tuple(0 if x == 1 else old_strides.pop(0) for x in new_shape)
|
||||
self.views[-1] = View(new_shape, new_strides_tuple, self.offset)
|
||||
return self
|
||||
|
||||
# check if the new dimensions factorize from the old ones
|
||||
# NOTE: if you don't make a copy here, the list is popped in the lrucache
|
||||
min_shape_strides = to_shape_strides(self.shape, self.strides)[:]
|
||||
curr_dim, curr_stride = min_shape_strides.pop(0)
|
||||
new_strides = []
|
||||
new_strides : List[int] = []
|
||||
for s in new_shape:
|
||||
if curr_dim%s == 0:
|
||||
curr_dim //= s
|
||||
|
@ -178,7 +185,7 @@ class ShapeTracker:
|
|||
break # didn't factorize
|
||||
|
||||
if len(new_shape) == len(new_strides):
|
||||
self.views[-1] = View(new_shape, new_strides, self.offset)
|
||||
self.views[-1] = View(new_shape, tuple(new_strides), self.offset)
|
||||
return self
|
||||
|
||||
view = View(new_shape, strides_for_shape(new_shape))
|
||||
|
@ -189,46 +196,52 @@ class ShapeTracker:
|
|||
self.views.append(view)
|
||||
return self
|
||||
|
||||
def permute(self, *axis):
|
||||
def permute(self, axis : Tuple[int, ...]) -> ShapeTracker:
|
||||
assert isinstance(axis, tuple)
|
||||
assert all(isinstance(x, int) and x >= 0 and x < len(self.shape) for x in axis), f"invalid permute {axis} for {self.shape}"
|
||||
assert len(set(axis)) == len(axis) and len(axis) == len(self.shape), f"can't permute {self.shape} with {axis}"
|
||||
self.views[-1] = View([self.shape[a] for a in axis], [self.strides[a] for a in axis], self.offset)
|
||||
self.views[-1] = View(tuple(self.shape[a] for a in axis), tuple(self.strides[a] for a in axis), self.offset)
|
||||
return self
|
||||
|
||||
# TODO: this is a special case of slice with strides, remove it
|
||||
# though it's nice that it can't change size
|
||||
def flip(self, *axis): return self.stride(*[-1 if i in axis else 1 for i in range(len((self.shape)))])
|
||||
def flip(self, axis : Tuple[int, ...]) -> ShapeTracker:
|
||||
return self.stride(tuple(-1 if i in axis else 1 for i in range(len((self.shape)))))
|
||||
|
||||
# *** under this line are not invertible ***
|
||||
|
||||
# TODO: take this functionality out of slice
|
||||
def pad(self, *arg):
|
||||
def pad(self, arg : Tuple[Tuple[int, int], ...]) -> ShapeTracker:
|
||||
assert isinstance(arg, tuple)
|
||||
assert all((b>=0 and e>=0) for b,e in arg) and len(arg) == len(self.shape)
|
||||
return self.shrink(*[(-b,s+e) for s,(b,e) in zip(self.shape, arg)])
|
||||
return self.shrink(tuple((-b,s+e) for s,(b,e) in zip(self.shape, arg)))
|
||||
|
||||
# TODO: take the pad functionality out of shrink
|
||||
def shrink(self, *arg):
|
||||
def shrink(self, arg : Tuple[Tuple[int, int], ...]) -> ShapeTracker:
|
||||
assert isinstance(arg, tuple)
|
||||
assert len(arg) == len(self.shape)
|
||||
offset = sum([self.strides[i]*x for i,(x,_) in enumerate(arg)])
|
||||
zeroview = ZeroView(self.shape, arg)
|
||||
self.views[-1] = View([y-x for x,y in arg], self.strides, self.offset+offset)
|
||||
self.views[-1] = View(tuple(y-x for x,y in arg), self.strides, self.offset+offset)
|
||||
if zeroview.expr != "valid=valid":
|
||||
# if we add a ZeroView, we add another (stock) view also for modding
|
||||
self.views += [zeroview, View(self.shape, strides_for_shape(self.shape))]
|
||||
return self
|
||||
|
||||
def expand(self, *new_shape):
|
||||
def expand(self, new_shape : Tuple[int, ...]) -> ShapeTracker:
|
||||
assert isinstance(new_shape, tuple)
|
||||
assert all(isinstance(x, int) for x in new_shape)
|
||||
assert all(x == y or x == 1 for x,y in zip(self.shape, new_shape)), f"can't expand {self.shape} into {new_shape}"
|
||||
strides = [s if x == y else 0 for s,(x,y) in zip(self.strides, zip(self.shape, new_shape))]
|
||||
strides : Tuple[int, ...] = tuple(s if x == y else 0 for s,(x,y) in zip(self.strides, zip(self.shape, new_shape)))
|
||||
self.views[-1] = View(new_shape, strides, self.offset)
|
||||
return self
|
||||
|
||||
# TODO: combine with slice? this doesn't require a ZeroView, though slice shouldn't always either
|
||||
def stride(self, *mul):
|
||||
def stride(self, mul : Tuple[int, ...]) -> ShapeTracker:
|
||||
assert isinstance(mul, tuple)
|
||||
assert all(isinstance(x, int) for x in mul)
|
||||
strides = [z*m for z,m in zip(self.strides, mul)]
|
||||
new_shape = [(s+(abs(m)-1))//abs(m) for s,m in zip(self.shape, mul)]
|
||||
strides = tuple(z*m for z,m in zip(self.strides, mul))
|
||||
new_shape = tuple((s+(abs(m)-1))//abs(m) for s,m in zip(self.shape, mul))
|
||||
offset = sum([(s-1)*z for s,z,m in zip(self.shape, self.strides, mul) if m < 0])
|
||||
self.views[-1] = View(new_shape, strides, self.offset + offset)
|
||||
return self
|
||||
|
|
|
@ -96,6 +96,7 @@ class NumNode(Node):
|
|||
self.b, self.min, self.max = num, num, num
|
||||
|
||||
class OpNode(Node):
|
||||
op : str
|
||||
def __init__(self, a:Node, b:int):
|
||||
self.a, self.b = a,b
|
||||
self.min, self.max = self.minmax(a,b)
|
||||
|
@ -105,6 +106,7 @@ class OpNode(Node):
|
|||
return f"({self.a}{self.op}{self.b})"
|
||||
|
||||
class RedNode(Node):
|
||||
op : str
|
||||
def __init__(self, nodes:List[Node]):
|
||||
self.nodes = nodes
|
||||
self.min, self.max = self.minmax(nodes)
|
||||
|
|
|
@ -168,12 +168,12 @@ class Tensor:
|
|||
dim = (dim + len(self.shape)) if dim < 0 else dim
|
||||
for y in args:
|
||||
assert len(y.shape) == len(self.shape) and all(y.shape[i] == s for i,s in enumerate(self.shape) if i != dim)
|
||||
args = [self] + list(args)
|
||||
shape_cumsum = [0, *itertools.accumulate(y.shape[dim] for y in args)]
|
||||
slc = [[(0, s) for s in self.shape] for _ in args]
|
||||
catargs = [self] + list(args)
|
||||
shape_cumsum = [0, *itertools.accumulate(y.shape[dim] for y in catargs)]
|
||||
slc = [[(0, s) for s in self.shape] for _ in catargs]
|
||||
for s,k in zip(slc, shape_cumsum):
|
||||
s[dim] = (-k, shape_cumsum[-1]-k)
|
||||
return functools.reduce(Tensor.__iadd__, [arg.slice(arg=s) for arg,s in zip(args, slc)])
|
||||
return functools.reduce(Tensor.__iadd__, [arg.slice(arg=s) for arg,s in zip(catargs, slc)])
|
||||
|
||||
# TODO: make this nicer with syntactic sugar in slice
|
||||
def chunk(self, num, dim):
|
||||
|
|
Loading…
Reference in New Issue