
MEM -> LOAD (#2492)

* MEM -> LOAD

* keep legacy working
George Hotz 2023-11-28 16:46:37 -08:00 committed by GitHub
parent a739c6646e
commit ab5d14d4ba
14 changed files with 29 additions and 26 deletions

View File

@@ -6,6 +6,9 @@ from tinygrad.shape.view import View
from tinygrad.shape.symbolic import Variable
inf, nan = float('inf'), float('nan')
# HACK: it used to be called MEM
setattr(BufferOps, "MEM", BufferOps.LOAD)
# kernel unpacker
from tinygrad.codegen.linearizer import Linearizer
def ast_str_to_ast(ast_str:str) -> LazyOp: return eval(ast_str)
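
The setattr above is the "keep legacy working" half of this commit: MEM is no longer a member of the enum, so old serialized AST strings that still say BufferOps.MEM (and get re-hydrated through eval in ast_str_to_ast) resolve to the new LOAD member. A minimal standalone sketch of the same aliasing pattern, on a toy enum rather than tinygrad's real one:

from enum import Enum, auto

class BufferOps(Enum):  # toy stand-in for the renamed enum
  LOAD = auto(); CONST = auto(); FROM_UNDERLYING = auto()

# MEM is not an enum member anymore, so EnumMeta allows attaching it as a plain
# class attribute that points at the LOAD member
setattr(BufferOps, "MEM", BufferOps.LOAD)

assert BufferOps.MEM is BufferOps.LOAD          # the legacy name resolves to LOAD
assert eval("BufferOps.MEM") is BufferOps.LOAD  # eval'd legacy AST strings keep working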

View File

@@ -27,29 +27,29 @@ def helper_test_lin(lin: Linearizer, opts, failed_platforms):
@unittest.skipIf(CI and Device.DEFAULT=="CUDA", "failed on CUDA CI")
class TestLinearizerFailures(unittest.TestCase):
def test_failure_1(self):
ast = LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.MEM, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(32, 16, 16), strides=(16, 1, 0), offset=0, mask=None, contiguous=False),)))),), arg=(32, 16, 1)), LazyOp(op=BufferOps.MEM, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(32, 16, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=False),))))), arg=None), LazyOp(op=BufferOps.MEM, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(32, 16, 1), strides=(16, 1, 0), offset=0, mask=None, contiguous=True),))))), arg=None)
ast = LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(32, 16, 16), strides=(16, 1, 0), offset=0, mask=None, contiguous=False),)))),), arg=(32, 16, 1)), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(32, 16, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=False),))))), arg=None), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(32, 16, 1), strides=(16, 1, 0), offset=0, mask=None, contiguous=True),))))), arg=None)
helper_test_lin(Linearizer(ast), [], failed_platforms=["CLANG"])
def test_failure_2(self):
ast = LazyOp(op=ReduceOps.MAX, src=(LazyOp(op=BufferOps.MEM, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(32, 2, 111, 27), strides=(6160, 3080, 28, 1), offset=0, mask=((0, 32), (0, 2), (0, 110), (0, 27)), contiguous=False), View(shape=(32, 2, 37, 9, 2, 2), strides=(5994, 2997, 81, 3, 27, 1), offset=0, mask=None, contiguous=False))))),), arg=(32, 2, 37, 9, 1, 1))
ast = LazyOp(op=ReduceOps.MAX, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(32, 2, 111, 27), strides=(6160, 3080, 28, 1), offset=0, mask=((0, 32), (0, 2), (0, 110), (0, 27)), contiguous=False), View(shape=(32, 2, 37, 9, 2, 2), strides=(5994, 2997, 81, 3, 27, 1), offset=0, mask=None, contiguous=False))))),), arg=(32, 2, 37, 9, 1, 1))
opts = [Opt(op=OptOps.LOCAL, axis=0, amt=32)]
helper_test_lin(Linearizer(ast), opts, failed_platforms=["CPU", "TORCH"])
def test_failure_3(self):
ast = LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.MEM, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(32, 8, 16, 16), strides=(2048, 256, 16, 1), offset=0, mask=None, contiguous=True),)))),), arg=(32, 8, 16, 1))
ast = LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(32, 8, 16, 16), strides=(2048, 256, 16, 1), offset=0, mask=None, contiguous=True),)))),), arg=(32, 8, 16, 1))
opts = [Opt(op=OptOps.GROUP, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=2), Opt(op=OptOps.UNROLL, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=2), Opt(op=OptOps.LOCAL, axis=0, amt=2), Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.LOCAL, axis=0, amt=32)]
# METAL: AssertionError: Error Domain=AGXMetalG13X Code=3 "Threadgroup memory size (65536) exceeds the maximum threadgroup memory allowed (32768)" UserInfo={NSLocalizedDescription=Threadgroup memory size (65536) exceeds the maximum threadgroup memory allowed (32768)}
helper_test_lin(Linearizer(ast), opts, failed_platforms=["METAL", "GPU", "CUDA"])
def test_failure_4(self):
ast = LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.MEM, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1, 4, 1, 12, 2, 29), strides=(0, 0, 0, 2, 0, 216, 1, 8), offset=0, mask=((0, 1), (0, 1), (0, 1), (0, 4), (0, 1), (0, 11), (0, 2), (0, 27)), contiguous=False), View(shape=(1, 1, 1, 4, 22, 84), strides=(0, 0, 0, 696, 58, 1), offset=0, mask=((0, 1), (0, 1), (0, 1), (0, 4), (0, 12), (0, 58)), contiguous=False), View(shape=(1, 1, 1, 4, 2, 11, 3, 28), strides=(0, 0, 0, 1848, 924, 84, 28, 1), offset=0, mask=None, contiguous=True))))),), arg=(1, 1, 1, 4, 1, 11, 1, 28))
ast = LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1, 4, 1, 12, 2, 29), strides=(0, 0, 0, 2, 0, 216, 1, 8), offset=0, mask=((0, 1), (0, 1), (0, 1), (0, 4), (0, 1), (0, 11), (0, 2), (0, 27)), contiguous=False), View(shape=(1, 1, 1, 4, 22, 84), strides=(0, 0, 0, 696, 58, 1), offset=0, mask=((0, 1), (0, 1), (0, 1), (0, 4), (0, 12), (0, 58)), contiguous=False), View(shape=(1, 1, 1, 4, 2, 11, 3, 28), strides=(0, 0, 0, 1848, 924, 84, 28, 1), offset=0, mask=None, contiguous=True))))),), arg=(1, 1, 1, 4, 1, 11, 1, 28))
opts = [Opt(op=OptOps.LOCAL, axis=2, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=2), Opt(op=OptOps.UPCAST, axis=0, amt=0), Opt(op=OptOps.LOCAL, axis=2, amt=2), Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.UPCAST, axis=2, amt=0), Opt(op=OptOps.UNROLL, axis=0, amt=0), Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.NOLOCALS, axis=None, amt=None)]
# related to OptOps.NOLOCALS
# IndexError: list index out of range
helper_test_lin(Linearizer(ast), opts, failed_platforms=["METAL"])
def test_failure_5(self):
ast = LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.1464405059814453, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 1, 4, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1.0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 1, 4, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None), LazyOp(op=BufferOps.MEM, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 1, 4, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None), LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.1464405059814453, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 1, 4, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1.0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 1, 4, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None), LazyOp(op=BufferOps.MEM, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 1, 4, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None)), arg=None),), arg=(1, 1, 1, 1, 1, 1, 1, 1))
ast = LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.1464405059814453, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 1, 4, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1.0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 1, 4, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 1, 4, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None), LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.1464405059814453, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 1, 4, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1.0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 1, 4, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 1, 4, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None)), arg=None),), arg=(1, 1, 1, 1, 1, 1, 1, 1))
opts = [Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.UNROLL, axis=0, amt=0)]
# EXEC_ERROR, it has no global_size
helper_test_lin(Linearizer(ast), opts, failed_platforms=[])
@@ -62,14 +62,14 @@ class TestLinearizerFailures(unittest.TestCase):
@unittest.skipIf(Device.DEFAULT=="LLVM", "Segmentation fault")
def test_failure_7(self):
ast = LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.MEM, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(512, 32, 6, 8, 4, 6, 8, 4), strides=(2048, 64, 6291456, 8, 0, 1048576, 1, 0), offset=0, mask=((0, 512), (0, 32), (0, 6), (0, 8), (0, 1), (0, 6), (0, 8), (0, 1)), contiguous=False), View(shape=(512, 32, 6, 35, 6, 35), strides=(1179648, 36864, 6144, 192, 32, 1), offset=0, mask=((0, 512), (0, 32), (0, 6), (0, 32), (0, 6), (0, 32)), contiguous=False), View(shape=(512, 32, 238, 238), strides=(1411200, 44100, 210, 1), offset=0, mask=((0, 512), (0, 32), (0, 210), (0, 210)), contiguous=False), View(shape=(512, 32, 7, 34, 7, 34), strides=(1812608, 56644, 8092, 238, 34, 1), offset=0, mask=None, contiguous=True))))),), arg=(512, 32, 1, 34, 1, 34))
ast = LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(512, 32, 6, 8, 4, 6, 8, 4), strides=(2048, 64, 6291456, 8, 0, 1048576, 1, 0), offset=0, mask=((0, 512), (0, 32), (0, 6), (0, 8), (0, 1), (0, 6), (0, 8), (0, 1)), contiguous=False), View(shape=(512, 32, 6, 35, 6, 35), strides=(1179648, 36864, 6144, 192, 32, 1), offset=0, mask=((0, 512), (0, 32), (0, 6), (0, 32), (0, 6), (0, 32)), contiguous=False), View(shape=(512, 32, 238, 238), strides=(1411200, 44100, 210, 1), offset=0, mask=((0, 512), (0, 32), (0, 210), (0, 210)), contiguous=False), View(shape=(512, 32, 7, 34, 7, 34), strides=(1812608, 56644, 8092, 238, 34, 1), offset=0, mask=None, contiguous=True))))),), arg=(512, 32, 1, 34, 1, 34))
opts = [Opt(op=OptOps.UPCAST, axis=0, amt=4)]
# test/test_linearizer_failures.py Fatal Python error: Segmentation fault
helper_test_lin(Linearizer(ast), opts, failed_platforms=["LLVM"])
@unittest.skipIf(Device.DEFAULT=="LLVM" and not OSX, "Segmentation fault on ubuntu")
def test_failure_8(self):
ast = LazyOp(op=UnaryOps.SQRT, src=(LazyOp(op=BinaryOps.DIV, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1.0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=True),)))), LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.MEM, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 4096), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=BufferOps.MEM, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 4096), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),))))), arg=None), LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.MEM, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 4096), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=BufferOps.MEM, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 4096), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),))))), arg=None)), arg=None),), arg=(1, 1, 1)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.000244140625, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=True),))))), arg=None), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1e-06, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=True),))))), arg=None)), arg=None),), arg=None)
ast = LazyOp(op=UnaryOps.SQRT, src=(LazyOp(op=BinaryOps.DIV, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1.0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=True),)))), LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 4096), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 4096), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),))))), arg=None), LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 4096), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 4096), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),))))), arg=None)), arg=None),), arg=(1, 1, 1)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.000244140625, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=True),))))), arg=None), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1e-06, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=True),))))), arg=None)), arg=None),), arg=None)
opts = [Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.UNROLL, axis=0, amt=4)]
# fatal error: bracket nesting level exceeded maximum of 256
# note: use -fbracket-depth=N to increase maximum nesting level

View File

@@ -6,8 +6,8 @@ from tinygrad.helpers import dtypes
class TestFlopCounter(unittest.TestCase):
def setUp(self):
self.buf0 = LazyOp(BufferOps.MEM, (), MemBuffer(1, dtypes.float32, ShapeTracker.from_shape((4,))))
self.buf1 = LazyOp(BufferOps.MEM, (), MemBuffer(2, dtypes.float32, ShapeTracker.from_shape((4,))))
self.buf0 = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float32, ShapeTracker.from_shape((4,))))
self.buf1 = LazyOp(BufferOps.LOAD, (), MemBuffer(2, dtypes.float32, ShapeTracker.from_shape((4,))))
def test_flops_add(self):
op0 = LazyOp(BinaryOps.ADD, (self.buf0,self.buf1,), None)

View File

@@ -336,8 +336,8 @@ class Kernel:
mul_op = self.reduceop.src[0].src[0] if has_cast else self.reduceop.src[0]
if not(isinstance(mul_op, LazyOp) and mul_op.op == BinaryOps.MUL): continue
if not(isinstance(mul_op.src[0], LazyOp) and mul_op.src[0].op == BufferOps.MEM and mul_op.src[0].arg.dtype == tc.dtype_in): continue
if not(isinstance(mul_op.src[1], LazyOp) and mul_op.src[1].op == BufferOps.MEM and mul_op.src[1].arg.dtype == tc.dtype_in): continue
if not(isinstance(mul_op.src[0], LazyOp) and mul_op.src[0].op == BufferOps.LOAD and mul_op.src[0].arg.dtype == tc.dtype_in): continue
if not(isinstance(mul_op.src[1], LazyOp) and mul_op.src[1].op == BufferOps.LOAD and mul_op.src[1].arg.dtype == tc.dtype_in): continue
buf0, buf1 = self.bufs.index(cast(MemBuffer, mul_op.src[0].arg)), self.bufs.index(cast(MemBuffer, mul_op.src[1].arg))
buf0_strides, buf1_strides = self.sts[buf0].real_strides(), self.sts[buf1].real_strides()
axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides[:self.first_reduce]) if s == 0 and self.full_shape[i]%tc.dims[0] == 0]
@@ -486,7 +486,7 @@ class Kernel:
if self.opts.has_local and getenv("MV",1) != 0 and (MV_BLOCKSIZE > 1 or MV_THREADS_PER_ROW > 1 or MV_ROWS_PER_THREAD > 1) and \
self.reduceop and self.reduceop.op == ReduceOps.SUM and len(self.full_shape) >= 2 and self.opts.has_shared and \
isinstance(self.reduceop.src[0], LazyOp) and self.reduceop.src[0].op == BinaryOps.MUL and \
self.reduceop.src[0].src[0].op == BufferOps.MEM and self.reduceop.src[0].src[1].op == BufferOps.MEM:
self.reduceop.src[0].src[0].op == BufferOps.LOAD and self.reduceop.src[0].src[1].op == BufferOps.LOAD:
buf0 = self.bufs.index(self.reduceop.src[0].src[0].arg)
buf1 = self.bufs.index(self.reduceop.src[0].src[1].arg)
buf0_strides = self.sts[buf0].real_strides()
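
Both heuristics in this file, the tensor-core matcher above and the matvec path here, now key on BufferOps.LOAD leaves when they look for a reduce over a product of two buffer loads. A stripped-down sketch of just that shape test, assuming the ops live in tinygrad.ops; the real checks also tolerate a CAST between the reduce and the MUL and inspect dtypes, strides, and device capabilities:

from tinygrad.ops import LazyOp, BufferOps, BinaryOps, ReduceOps  # assumed import path

def is_sum_of_mul_of_loads(reduceop: LazyOp) -> bool:
  # the AST shape the MV/tensor-core paths match: SUM(MUL(LOAD, LOAD))
  if reduceop.op != ReduceOps.SUM or not isinstance(reduceop.src[0], LazyOp): return False
  mul = reduceop.src[0]
  return mul.op == BinaryOps.MUL and all(isinstance(s, LazyOp) and s.op == BufferOps.LOAD for s in mul.src)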

View File

@@ -345,7 +345,7 @@ class Linearizer(Kernel):
loaded_buffers[self.bufs[-1]] = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, barrier=barrier)
# there's no AST here (and there's no shape for the reduce LazyOp)
self.ast_parse(LazyOp(self.reduceop.op, (LazyOp(BufferOps.MEM, (), self.bufs[-1]),)), acc, self.acc_offsets(-1), loaded_buffers, do_reduce=True, loop_ctx=loop_ctx)
self.ast_parse(LazyOp(self.reduceop.op, (LazyOp(BufferOps.LOAD, (), self.bufs[-1]),)), acc, self.acc_offsets(-1), loaded_buffers, do_reduce=True, loop_ctx=loop_ctx)
# end the late reduce loop
self.load_cache.clear()

View File

@@ -108,7 +108,7 @@ def fix_schedule_for_images(schedule:List[ScheduleItem]):
if DEBUG >= 1: print(f"{i:3d}: rewrite output, output shape {prod(si.out.shape)}, image dtype {si.out.dtype} prod {prod(si.out.dtype.shape)}")
si.out.dtype = dtypes.float32
for b in si.ast.get_lazyops():
if b.op != BufferOps.MEM: continue
if b.op != BufferOps.LOAD: continue
if isinstance(si.inputs[b.arg.idx-1].dtype, ImageDType) and not any(b.arg.st.shape[x]%4 == 0 for x in b.arg.st.unit_stride_axes()):
if DEBUG >= 1: print(f"{i:3d}: rewrite input, image dtype {si.inputs[b.arg.idx-1].dtype}, {b.arg.st.views}")
if si.inputs[b.arg.idx-1].realized:
@@ -132,9 +132,9 @@ def fix_schedule_for_images(schedule:List[ScheduleItem]):
# fix input dtypes to match what they actually are
replacements = {}
for b in si.ast.get_lazyops():
if b.op != BufferOps.MEM: continue
if b.op != BufferOps.LOAD: continue
if b.arg.dtype != inputs[b.arg.idx-1].dtype:
replacements[b] = LazyOp(BufferOps.MEM, (), MemBuffer(b.arg.idx, inputs[b.arg.idx-1].dtype, b.arg.st))
replacements[b] = LazyOp(BufferOps.LOAD, (), MemBuffer(b.arg.idx, inputs[b.arg.idx-1].dtype, b.arg.st))
if replacements: ast = ast.map_buffers(replacements)
# fix the ops to create the output dtype

View File

@@ -73,7 +73,7 @@ def log_schedule_item(si: ScheduleItem):
# get inputs for shapetrackers
input_to_st = defaultdict(list)
for lo in si.ast.get_lazyops():
if lo.op != BufferOps.MEM: continue
if lo.op != BufferOps.LOAD: continue
input_to_st[si.inputs[lo.arg.idx-1]].append(lo.arg.st)
# add them to the graph, potentially with a movement op separating them

View File

@@ -67,7 +67,7 @@ def _replace_bufferops(op:LazyOp) -> Tuple[LazyOp, List[LazyBuffer]]:
for x in op.buffers:
st = x.st.simplify().unbind()
if x.base in base_bufs:
replacements[x] = LazyOp(BufferOps.MEM, (), MemBuffer(base_bufs.index(x.base)+1, x.dtype, st))
replacements[x] = LazyOp(BufferOps.LOAD, (), MemBuffer(base_bufs.index(x.base)+1, x.dtype, st))
elif not x.realized and x.base.op.op == LoadOps.CONST:
replacements[x] = LazyOp(BufferOps.CONST, (), ConstBuffer(float(x.base.op.arg), x.dtype, st))
else:

View File

@@ -14,7 +14,7 @@ class UnaryOps(Enum): NOOP = auto(); EXP2 = auto(); LOG2 = auto(); CAST = auto()
class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto() # noqa: E702
class TernaryOps(Enum): MULACC = auto(); WHERE = auto() # noqa: E702
class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
class BufferOps(Enum): MEM = auto(); CONST = auto(); FROM_UNDERLYING = auto() # noqa: E702
class BufferOps(Enum): LOAD = auto(); CONST = auto(); FROM_UNDERLYING = auto() # noqa: E702
# Ops below this line are not allowed in ASTs
class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto(); AS_STRIDED = auto() # noqa: E702
class LoadOps(Enum): EMPTY = auto(); RAND = auto(); CONST = auto(); FROM = auto(); CONTIGUOUS = auto(); CUSTOM = auto() # noqa: E702
@@ -95,7 +95,7 @@ class FlopCounter:
return ret
InterpretedFlopCounter: Dict[Op, Callable] = {
BufferOps.MEM: lambda arg: FlopCounter(arg.st.shape, arg.dtype, 0, {arg.idx: arg.dtype.itemsize*arg.st.size()}), BufferOps.CONST: lambda arg: FlopCounter(arg.st.shape, arg.dtype, 0, {}),
BufferOps.LOAD: lambda arg: FlopCounter(arg.st.shape, arg.dtype, 0, {arg.idx: arg.dtype.itemsize*arg.st.size()}), BufferOps.CONST: lambda arg: FlopCounter(arg.st.shape, arg.dtype, 0, {}),
UnaryOps.CAST: lambda self,arg: FlopCounter(self.shape, arg[0], self.consume_flops(), self.mem), # cast uses no flops
**{op:lambda self: FlopCounter(self.shape, self.dtype, self.consume_flops() + prod(self.shape), self.mem) for op in UnaryOps if op != UnaryOps.CAST},
**{op:lambda self,y: FlopCounter(self.shape, max(self.dtype, y.dtype), self.consume_flops() + y.consume_flops() + prod(self.shape), {**self.mem, **y.mem}) for op in BinaryOps},
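
In the flop counter the renamed entry keeps the same accounting: a LOAD contributes zero FLOPs and charges dtype.itemsize * st.size() bytes of memory traffic against its buffer idx, while CONST charges nothing. A small sketch of that arithmetic for a float32 buffer of shape (32, 16, 16), assuming these names are importable from tinygrad.ops and tinygrad.shape.shapetracker:

from tinygrad.ops import InterpretedFlopCounter, BufferOps, MemBuffer  # assumed import path
from tinygrad.helpers import dtypes
from tinygrad.shape.shapetracker import ShapeTracker  # assumed import path

st = ShapeTracker.from_shape((32, 16, 16))
fc = InterpretedFlopCounter[BufferOps.LOAD](MemBuffer(1, dtypes.float32, st))
# zero arithmetic; 32*16*16 elements * 4 bytes = 32768 bytes attributed to buffer idx 1
assert fc.mem == {1: dtypes.float32.itemsize * st.size()}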

View File

@@ -18,7 +18,7 @@ def run_schedule(schedule:List[ScheduleItem], disable_logging=False):
assert all(x.realized for x in si.inputs), "can't run schedule, some inputs aren't realized"
if si.ast.op in LoadOps:
# confirm the LoadOps are contiguous and in order
for i,s in enumerate(si.ast.src): assert isinstance(s, LazyOp) and s.op == BufferOps.MEM and s.arg.idx == i+1 and s.arg.st.contiguous, f"bad LoadOps src {i}: {s}"
for i,s in enumerate(si.ast.src): assert isinstance(s, LazyOp) and s.op == BufferOps.LOAD and s.arg.idx == i+1 and s.arg.st.contiguous, f"bad LoadOps src {i}: {s}"
LOAD_OPS_DISPATCHER[cast(LoadOps, si.ast.op)](si.out, *si.inputs)
else:
assert all(si.out.device == x.device for x in si.inputs), f"all devices must be the same, {si.out.device} != {[x.device for x in si.inputs]} {print_tree(si.ast) or ''}"
@@ -31,7 +31,7 @@ def run_schedule(schedule:List[ScheduleItem], disable_logging=False):
for i,a in enumerate(si.inputs):
# TODO: if this is contiguous it's fine
if a.realized == si.out.output_buffer:
if any(not x.arg.st.contiguous for x in si.ast.get_lazyops() if x.op == BufferOps.MEM and x.arg.idx == i+1):
if any(not x.arg.st.contiguous for x in si.ast.get_lazyops() if x.op == BufferOps.LOAD and x.arg.idx == i+1):
si.out.output_buffer = None
break
# we don't have an output buffer, we have to create it, and create to max size if it has symbolic shape

View File

@@ -32,7 +32,7 @@ def einsum_mulacc(einsum, get_strides, expand):
return mulacc
numpy_fxn_for_op: Dict[Op, Callable] = {
BufferOps.MEM: lambda x: x.toCPU(), BufferOps.CONST: lambda val, dtype: np.array(val, dtype=dtype.np), BufferOps.FROM_UNDERLYING: RawNumpyBuffer.fromCPU,
BufferOps.LOAD: lambda x: x.toCPU(), BufferOps.CONST: lambda val, dtype: np.array(val, dtype=dtype.np), BufferOps.FROM_UNDERLYING: RawNumpyBuffer.fromCPU,
UnaryOps.NOOP: lambda x: np.require(x, requirements='C'), UnaryOps.EXP2: np.exp2, UnaryOps.LOG2: np.log2, UnaryOps.SIN: np.sin,
UnaryOps.CAST: lambda x,y: x.view(y[0].np) if y[1] else x.astype(y[0].np, copy=False), UnaryOps.NEG: lambda x: np.logical_not(x) if x.dtype == np.bool_ else np.negative(x),
BinaryOps.MAX: np.maximum, BinaryOps.CMPLT: lambda x,y: (x<y).astype(output_type(x,y)), BinaryOps.ADD: lambda x, y: np.add(*match_types(x, y)),

View File

@@ -55,5 +55,5 @@ class RawDiskBuffer(RawBufferMapped):
self.readinto(instance._buffer())
return instance
disk_fxn_for_op: Dict[Op, Callable] = { BufferOps.MEM: lambda x: x, UnaryOps.NOOP: lambda x: x, UnaryOps.CAST: RawDiskBuffer.cast, MovementOps.AS_STRIDED: RawDiskBuffer.as_strided }
disk_fxn_for_op: Dict[Op, Callable] = { BufferOps.LOAD: lambda x: x, UnaryOps.NOOP: lambda x: x, UnaryOps.CAST: RawDiskBuffer.cast, MovementOps.AS_STRIDED: RawDiskBuffer.as_strided }
DiskDevice = Interpreted(RawDiskBuffer, disk_fxn_for_op)

View File

@@ -49,7 +49,7 @@ class RawHIPBuffer(RawBufferCopyInOut, RawBufferTransfer):
@diskcache
def compile_hip(prg) -> bytes:
prog = hip.hiprtcCreateProgram(prg, "<null>", [], [])
arch = "gfx1100" if MOCKHIP else hip.hipGetDeviceProperties(HIP.default_device).gcnArchName
arch = "gfx1100" if MOCKHIP else hip.hipGetDeviceProperties(HIP.default_device).gcnArchName
hip.hiprtcCompileProgram(prog, [f'--offload-arch={arch}'])
return hip.hiprtcGetCode(prog)

View File

@@ -36,7 +36,7 @@ torch_fxn_for_op: Dict[Op, Callable] = {
# TODO: torch.tensor should work here. it doesn't due to "overflow" in uint8
#BufferOps.CONST: lambda val, dtype: torch.tensor(val, device=device, dtype=inverse_type_map[dtype]),
BufferOps.CONST: lambda val, dtype: torch.from_numpy(np.array(val, dtype=dtype.np)).to(device),
BufferOps.MEM: lambda x: x._get_buf(), BufferOps.FROM_UNDERLYING: lambda x: RawTorchBuffer(prod(x.shape), type_map[x.dtype], x),
BufferOps.LOAD: lambda x: x._get_buf(), BufferOps.FROM_UNDERLYING: lambda x: RawTorchBuffer(prod(x.shape), type_map[x.dtype], x),
UnaryOps.NOOP: lambda x: x.contiguous(), UnaryOps.SQRT: lambda x: x.sqrt(), UnaryOps.EXP2: lambda x: x.exp2(), UnaryOps.LOG2: lambda x: x.log2(), UnaryOps.SIN: torch.sin,
UnaryOps.CAST: lambda x,y: (x.view if y[1] else x.type)(next(k for k,v in type_map.items() if v==y[0])), UnaryOps.NEG: lambda x: torch.logical_not(x) if x.dtype is torch.bool else torch.neg(x),
BinaryOps.MAX: torch.maximum, BinaryOps.CMPLT: lambda x,y: (x<y).type(torch.promote_types(x.dtype, y.dtype)),