1
0
Fork 0

perf: remove cast and revert back to isinstance (#1694)

Co-authored-by: Roelof van Dijk <roelof.van.dijk@vitestro.com>
pull/1699/head
Roelof van Dijk 2023-08-27 20:15:52 +02:00 committed by GitHub
parent 8b354b3f73
commit 328cf2e86a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 6 deletions

View File

@@ -150,7 +150,7 @@ class LazyBuffer:
for x in self.op.buffers: x.realize()
# HACK: image shape can be wrong, hot cast it back to a normal float
if self.dtype.__class__ is ImageDType and self.optype != MovementOps and (prod(self.shape) != prod(cast(ImageDType, self.dtype).shape) or not any(self.shape[x]%4 == 0 for x in self.st.unit_stride_axes())):
if isinstance(self.dtype, ImageDType) and self.optype != MovementOps and (prod(self.shape) != prod(self.dtype.shape) or not any(self.shape[x]%4 == 0 for x in self.st.unit_stride_axes())):
if self.op.op == MovementOps.RESHAPE:
# put CAST before the final RESHAPE
self.op = LazyOp(MovementOps.RESHAPE, (LazyOp(UnaryOps.CAST, self.op.src, (dtypes.float32, False)),), self.op.arg)
@@ -315,7 +315,8 @@ def _push_movement_ops(srcs:Tuple[LazyBuffer, ...]) -> Tuple[LazyBuffer, ...]:
while not bx.realized and bx.optype is MovementOps and bx.op.op is not MovementOps.EXPAND and (SHUFFLE_PAD_OPS or bx.op.op is not MovementOps.PAD) and len(bx.children) <= 1:
assert isinstance(bx.op.op, MovementOps)
mops.append((bx.op.op, bx.op.arg))
bx = cast(LazyBuffer, bx.op.src[0])
assert isinstance(bx.op.src[0], LazyBuffer)
bx = bx.op.src[0]
# NOTE: can't push pads past anything where f(0, 0) != 0 or f(0) != 0
if mops and not bx.realized and bx.optype is BinaryOps and len(bx.children) <= 1 and (all(x[0] is not MovementOps.PAD for x in mops) or all(x.op not in UNSAFE_PAD_OPS for x in bx.op.get_lazyops())):
new_srcs.append(bx.op.replace_with_movement_ops(mops[::-1]))
@@ -325,7 +326,7 @@ def _push_movement_ops(srcs:Tuple[LazyBuffer, ...]) -> Tuple[LazyBuffer, ...]:
def _realize_contiguous(buffer: LazyBuffer) -> None:
realized = buffer.op.src[0].realize().realized
if buffer.op.src[0].st.contiguous and realized.__class__ is not RawConst and cast(RawBuffer, realized).size == prod(buffer.shape):
if buffer.op.src[0].st.contiguous and realized.__class__ is not RawConst and realized is not None and realized.size == prod(buffer.shape):
# no need to run an AST, this is already contiguous
buffer.realized = realized
else:

View File

@@ -96,12 +96,12 @@ class Interpreted:
self.codegen = None
def exec_ast(self, ast:LazyOp, output=None, context=None, **kwargs):
if TernaryOps.MULACC in self.fxn_for_op and ast.op == ReduceOps.SUM and ast.src[0].__class__ is LazyOp and ast.src[0].op == BinaryOps.MUL:
ast = LazyOp(TernaryOps.MULACC, cast(LazyOp, ast.src[0]).src, ast.arg)
if TernaryOps.MULACC in self.fxn_for_op and ast.op == ReduceOps.SUM and isinstance(ast.src[0], LazyOp) and ast.src[0].op == BinaryOps.MUL:
ast = LazyOp(TernaryOps.MULACC, ast.src[0].src, ast.arg)
created_context = context is None
if context is None: context = dict()
if not created_context and ast in context: return context[ast]
srcs = [self.exec_ast(cast(LazyOp, x), context=context, **kwargs) if x.__class__ is LazyOp else self.from_lazybuffer(x) for x in ast.src]
srcs = [self.exec_ast(x, context=context, **kwargs) if isinstance(x, LazyOp) else self.from_lazybuffer(x) for x in ast.src]
if DEBUG >= 3: st = time.perf_counter()
ret = self.from_underlying(self.fxn_for_op[ast.op](*([self.to_underlying(x) for x in srcs] + ([ast.arg] if ast.arg is not None else []))))
if output is not None and ret.dtype != output.dtype and UnaryOps.CAST in self.fxn_for_op: ret = self.from_underlying(self.fxn_for_op[UnaryOps.CAST](self.to_underlying(ret), (output.dtype, False))) # Do manual casting of ret if it does not match the required output dtype.