From e6f19d4ce2a072d3dc3ba95c80c21c552f1753c8 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Thu, 9 Feb 2023 13:03:48 -0600 Subject: [PATCH] assume all generic exec ast have ProcessingOp --- tinygrad/lazy.py | 6 +++--- tinygrad/ops.py | 18 ++++-------------- 2 files changed, 7 insertions(+), 17 deletions(-) diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py index 3aacb446c..abdf50b1e 100644 --- a/tinygrad/lazy.py +++ b/tinygrad/lazy.py @@ -4,7 +4,7 @@ import sys, weakref from weakref import WeakValueDictionary from tinygrad.helpers import ConvArgs, prod from tinygrad.shape import ShapeTracker -from tinygrad.ops import DeviceBuffer, UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps, LoadOps, OpType, LazyOp, get_buffers, map_buffers, DEBUG +from tinygrad.ops import DeviceBuffer, UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps, LoadOps, OpType, LazyOp, get_buffers, map_buffers, DEBUG, GenericExecAST from tinygrad.graph import log_op from tinygrad.helpers import getenv @@ -132,7 +132,7 @@ class LazyBuffer: else: # movement ops aren't an AST, just run them real_src = src.realize(self.device) - self.realized = real_src.movement_op(self.op.op, self.op.arg) # movement_op stays + self.realized = real_src.movement_op(self.op.op, self.op.arg) ast = LazyOp(self.op.op, (real_src, )) elif self.optype == ProcessingOps: ast = self.op # no ast modifications for ProcessingOps elif self.optype == ReduceOps: ast = _ast_reduceops(self) @@ -319,7 +319,7 @@ class LazyBuffer: x = x.slice(((0, x.shape[0]), (0, x.shape[1]), (-C.py, x.shape[2]+C.py_), (-C.px, x.shape[3]+C.px_))) C = C._replace(px=0, px_=0, py=0, py_=0) - if NOCONV or not getattr(x.dbuffer, "processing_op", False): + if NOCONV or not issubclass(x.dbuffer, GenericExecAST): # universal conv, just mul and reduce x = x.movement_op(MovementOps.STRIDED, ( (C.bs, C.groups*C.cin*x.shape[2]*x.shape[3]), (C.groups, C.cin*x.shape[2]*x.shape[3]), diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 8cdeff839..51727e1a7 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -66,20 +66,10 @@ class GenericExecAST(DeviceBuffer): # pylint: disable=abstract-method @classmethod def exec_ast(cls, ast:LazyOp, output_buffer:Optional[GenericExecAST]=None, preprocess=lambda x: x): srcs = [cls.exec_ast(x, preprocess=preprocess) if isinstance(x, LazyOp) else preprocess(x) for x in ast.src] - if ast.op in UnaryOps: - ret = type(srcs[0])(srcs[0].fxn_for_op[ast.op](srcs[0].buf)) - elif ast.op in BinaryOps: - assert srcs[0].shape == srcs[1].shape, f"BinaryOps shape mismatch {srcs[0].shape} != {srcs[1].shape}" - ret = type(srcs[0])(srcs[0].fxn_for_op[ast.op](srcs[0].buf, srcs[1].buf)) - elif ast.op in ReduceOps: - assert all(r == n or n == 1 for r,n in zip(srcs[0].shape, ast.arg)), f"ReduceOps can't reduce {srcs[0].shape} -> {ast.arg}" - ret = type(srcs[0])(srcs[0].fxn_for_op[ast.op](srcs[0].buf, ast.arg)) - elif ast.op in MovementOps: - ret = srcs[0].movement_op(ast.op, ast.arg) - elif ast.op in ProcessingOps: - ret = type(srcs[0])(srcs[0].fxn_for_op[ast.op](srcs[0].buf, srcs[1].buf, ast.arg)) - else: - raise TypeError("unknown op") + if ast.op in BinaryOps: assert srcs[0].shape == srcs[1].shape, f"BinaryOps shape mismatch {srcs[0].shape} != {srcs[1].shape}" + if ast.op in ReduceOps: assert all(r == n or n == 1 for r,n in zip(srcs[0].shape, ast.arg)), f"ReduceOps can't reduce {srcs[0].shape} -> {ast.arg}" + if ast.op in MovementOps: ret = srcs[0].movement_op(ast.op, ast.arg) + else: ret = type(srcs[0])(srcs[0].fxn_for_op[ast.op](*([x.buf for x in srcs] + ([ast.arg] if ast.arg else [])))) if output_buffer is not None: assert output_buffer.shape == ret.shape output_buffer.buf = ret.buf