
assume all generic exec ast have ProcessingOp

pull/548/head
George Hotz 2023-02-09 13:03:48 -06:00
parent 78795e3507
commit e6f19d4ce2
2 changed files with 7 additions and 17 deletions


@@ -4,7 +4,7 @@ import sys, weakref
from weakref import WeakValueDictionary
from tinygrad.helpers import ConvArgs, prod
from tinygrad.shape import ShapeTracker
-from tinygrad.ops import DeviceBuffer, UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps, LoadOps, OpType, LazyOp, get_buffers, map_buffers, DEBUG
+from tinygrad.ops import DeviceBuffer, UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps, LoadOps, OpType, LazyOp, get_buffers, map_buffers, DEBUG, GenericExecAST
from tinygrad.graph import log_op
from tinygrad.helpers import getenv
@@ -132,7 +132,7 @@ class LazyBuffer:
else:
# movement ops aren't an AST, just run them
real_src = src.realize(self.device)
-self.realized = real_src.movement_op(self.op.op, self.op.arg) # movement_op stays
+self.realized = real_src.movement_op(self.op.op, self.op.arg)
ast = LazyOp(self.op.op, (real_src, ))
elif self.optype == ProcessingOps: ast = self.op # no ast modifications for ProcessingOps
elif self.optype == ReduceOps: ast = _ast_reduceops(self)
@@ -319,7 +319,7 @@ class LazyBuffer:
x = x.slice(((0, x.shape[0]), (0, x.shape[1]), (-C.py, x.shape[2]+C.py_), (-C.px, x.shape[3]+C.px_)))
C = C._replace(px=0, px_=0, py=0, py_=0)
-if NOCONV or not getattr(x.dbuffer, "processing_op", False):
+if NOCONV or not issubclass(x.dbuffer, GenericExecAST):
# universal conv, just mul and reduce
x = x.movement_op(MovementOps.STRIDED, (
(C.bs, C.groups*C.cin*x.shape[2]*x.shape[3]), (C.groups, C.cin*x.shape[2]*x.shape[3]),
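Aside on the hunk above (not part of the diff): the conv path now keys off the class hierarchy rather than duck typing, matching the commit title. A minimal self-contained sketch of the check, with stand-in class names and a hypothetical helper:

# sketch only: DeviceBuffer/GenericExecAST here are stand-ins, not tinygrad's classes
class DeviceBuffer: pass
class GenericExecAST(DeviceBuffer): pass       # backends assumed to handle ProcessingOps
class SomeOtherBuffer(DeviceBuffer):           # hypothetical backend exposing a processing_op method
  def processing_op(self, op, arg): ...

def uses_universal_conv(dbuffer_cls, NOCONV=False):   # hypothetical helper mirroring the new check
  # old check: NOCONV or not getattr(dbuffer_cls, "processing_op", False)
  # new check: NOCONV or not issubclass(dbuffer_cls, GenericExecAST)
  return NOCONV or not issubclass(dbuffer_cls, GenericExecAST)

print(uses_universal_conv(GenericExecAST))    # False: handled via exec_ast / ProcessingOps
print(uses_universal_conv(SomeOtherBuffer))   # True: falls back to the mul-and-reduce conv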


@@ -66,20 +66,10 @@ class GenericExecAST(DeviceBuffer): # pylint: disable=abstract-method
@classmethod
def exec_ast(cls, ast:LazyOp, output_buffer:Optional[GenericExecAST]=None, preprocess=lambda x: x):
srcs = [cls.exec_ast(x, preprocess=preprocess) if isinstance(x, LazyOp) else preprocess(x) for x in ast.src]
-if ast.op in UnaryOps:
-  ret = type(srcs[0])(srcs[0].fxn_for_op[ast.op](srcs[0].buf))
-elif ast.op in BinaryOps:
-  assert srcs[0].shape == srcs[1].shape, f"BinaryOps shape mismatch {srcs[0].shape} != {srcs[1].shape}"
-  ret = type(srcs[0])(srcs[0].fxn_for_op[ast.op](srcs[0].buf, srcs[1].buf))
-elif ast.op in ReduceOps:
-  assert all(r == n or n == 1 for r,n in zip(srcs[0].shape, ast.arg)), f"ReduceOps can't reduce {srcs[0].shape} -> {ast.arg}"
-  ret = type(srcs[0])(srcs[0].fxn_for_op[ast.op](srcs[0].buf, ast.arg))
-elif ast.op in MovementOps:
-  ret = srcs[0].movement_op(ast.op, ast.arg)
-elif ast.op in ProcessingOps:
-  ret = type(srcs[0])(srcs[0].fxn_for_op[ast.op](srcs[0].buf, srcs[1].buf, ast.arg))
-else:
-  raise TypeError("unknown op")
+if ast.op in BinaryOps: assert srcs[0].shape == srcs[1].shape, f"BinaryOps shape mismatch {srcs[0].shape} != {srcs[1].shape}"
+if ast.op in ReduceOps: assert all(r == n or n == 1 for r,n in zip(srcs[0].shape, ast.arg)), f"ReduceOps can't reduce {srcs[0].shape} -> {ast.arg}"
+if ast.op in MovementOps: ret = srcs[0].movement_op(ast.op, ast.arg)
+else: ret = type(srcs[0])(srcs[0].fxn_for_op[ast.op](*([x.buf for x in srcs] + ([ast.arg] if ast.arg else []))))
if output_buffer is not None:
assert output_buffer.shape == ret.shape
output_buffer.buf = ret.buf
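Aside on the new dispatch (not part of the diff): the single fxn_for_op call at the end relies on each op family receiving its raw source buffers in order, followed by the optional ast.arg. A rough standalone sketch of that calling convention, using a hypothetical helper name:

# sketch of the call built by the last added line:
#   fxn_for_op[ast.op](*([x.buf for x in srcs] + ([ast.arg] if ast.arg else [])))
# UnaryOps      -> fxn(buf0)
# BinaryOps     -> fxn(buf0, buf1)
# ReduceOps     -> fxn(buf0, new_shape)        # target shape arrives as ast.arg
# ProcessingOps -> fxn(buf0, buf1, conv_args)  # ConvArgs arrives as ast.arg
def unified_call(fxn_for_op, op, srcs, arg=None):   # hypothetical helper, not in the codebase
  bufs = [x.buf for x in srcs]                      # raw payloads of the realized sources
  return fxn_for_op[op](*(bufs + ([arg] if arg else [])))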