perf: remove cast and revert back to isinstance (#1694)
Co-authored-by: Roelof van Dijk <roelof.van.dijk@vitestro.com>pull/1699/head
parent
8b354b3f73
commit
328cf2e86a
|
@ -150,7 +150,7 @@ class LazyBuffer:
|
|||
for x in self.op.buffers: x.realize()
|
||||
|
||||
# HACK: image shape can be wrong, hot cast it back to a normal float
|
||||
if self.dtype.__class__ is ImageDType and self.optype != MovementOps and (prod(self.shape) != prod(cast(ImageDType, self.dtype).shape) or not any(self.shape[x]%4 == 0 for x in self.st.unit_stride_axes())):
|
||||
if isinstance(self.dtype, ImageDType) and self.optype != MovementOps and (prod(self.shape) != prod(self.dtype.shape) or not any(self.shape[x]%4 == 0 for x in self.st.unit_stride_axes())):
|
||||
if self.op.op == MovementOps.RESHAPE:
|
||||
# put CAST before the final RESHAPE
|
||||
self.op = LazyOp(MovementOps.RESHAPE, (LazyOp(UnaryOps.CAST, self.op.src, (dtypes.float32, False)),), self.op.arg)
|
||||
|
@ -315,7 +315,8 @@ def _push_movement_ops(srcs:Tuple[LazyBuffer, ...]) -> Tuple[LazyBuffer, ...]:
|
|||
while not bx.realized and bx.optype is MovementOps and bx.op.op is not MovementOps.EXPAND and (SHUFFLE_PAD_OPS or bx.op.op is not MovementOps.PAD) and len(bx.children) <= 1:
|
||||
assert isinstance(bx.op.op, MovementOps)
|
||||
mops.append((bx.op.op, bx.op.arg))
|
||||
bx = cast(LazyBuffer, bx.op.src[0])
|
||||
assert isinstance(bx.op.src[0], LazyBuffer)
|
||||
bx = bx.op.src[0]
|
||||
# NOTE: can't push pads past anything where f(0, 0) != 0 or f(0) != 0
|
||||
if mops and not bx.realized and bx.optype is BinaryOps and len(bx.children) <= 1 and (all(x[0] is not MovementOps.PAD for x in mops) or all(x.op not in UNSAFE_PAD_OPS for x in bx.op.get_lazyops())):
|
||||
new_srcs.append(bx.op.replace_with_movement_ops(mops[::-1]))
|
||||
|
@ -325,7 +326,7 @@ def _push_movement_ops(srcs:Tuple[LazyBuffer, ...]) -> Tuple[LazyBuffer, ...]:
|
|||
|
||||
def _realize_contiguous(buffer: LazyBuffer) -> None:
|
||||
realized = buffer.op.src[0].realize().realized
|
||||
if buffer.op.src[0].st.contiguous and realized.__class__ is not RawConst and cast(RawBuffer, realized).size == prod(buffer.shape):
|
||||
if buffer.op.src[0].st.contiguous and realized.__class__ is not RawConst and realized is not None and realized.size == prod(buffer.shape):
|
||||
# no need to run an AST, this is already contiguous
|
||||
buffer.realized = realized
|
||||
else:
|
||||
|
|
|
@ -96,12 +96,12 @@ class Interpreted:
|
|||
self.codegen = None
|
||||
|
||||
def exec_ast(self, ast:LazyOp, output=None, context=None, **kwargs):
|
||||
if TernaryOps.MULACC in self.fxn_for_op and ast.op == ReduceOps.SUM and ast.src[0].__class__ is LazyOp and ast.src[0].op == BinaryOps.MUL:
|
||||
ast = LazyOp(TernaryOps.MULACC, cast(LazyOp, ast.src[0]).src, ast.arg)
|
||||
if TernaryOps.MULACC in self.fxn_for_op and ast.op == ReduceOps.SUM and isinstance(ast.src[0], LazyOp) and ast.src[0].op == BinaryOps.MUL:
|
||||
ast = LazyOp(TernaryOps.MULACC, ast.src[0].src, ast.arg)
|
||||
created_context = context is None
|
||||
if context is None: context = dict()
|
||||
if not created_context and ast in context: return context[ast]
|
||||
srcs = [self.exec_ast(cast(LazyOp, x), context=context, **kwargs) if x.__class__ is LazyOp else self.from_lazybuffer(x) for x in ast.src]
|
||||
srcs = [self.exec_ast(x, context=context, **kwargs) if isinstance(x, LazyOp) else self.from_lazybuffer(x) for x in ast.src]
|
||||
if DEBUG >= 3: st = time.perf_counter()
|
||||
ret = self.from_underlying(self.fxn_for_op[ast.op](*([self.to_underlying(x) for x in srcs] + ([ast.arg] if ast.arg is not None else []))))
|
||||
if output is not None and ret.dtype != output.dtype and UnaryOps.CAST in self.fxn_for_op: ret = self.from_underlying(self.fxn_for_op[UnaryOps.CAST](self.to_underlying(ret), (output.dtype, False))) # Do manual casting of ret if it does not match the required output dtype.
|
||||
|
|
Loading…
Reference in New Issue