fix folding of float4 add/mul (#1060)
parent
a98e361da0
commit
23648538fa
|
@ -7,6 +7,7 @@ from tinygrad.helpers import dedup, colored, ImageDType, DEBUG, prod, dtypes, mn
|
|||
from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, UnaryOps
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
from tinygrad.ops import MovementOps, ReduceOps, BinaryOps, FusedOps
|
||||
from tinygrad.runtime.lib import RawConst
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, strides_for_shape
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
|
||||
|
@ -343,8 +344,11 @@ class Linearizer:
|
|||
x = LazyOp(FusedOps.MULACC, x.src[0].src, x.arg)
|
||||
if x.op == ReduceOps.SUM and x.src[0].__class__ is LazyOp and x.src[0].op == UnaryOps.CAST and x.src[0].src[0].__class__ is LazyOp and x.src[0].src[0].op == BinaryOps.MUL:
|
||||
x = LazyOp(FusedOps.MULACC, x.src[0].src[0].src, x.arg)
|
||||
if x.op in {BinaryOps.ADD, BinaryOps.MUL}:
|
||||
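# Stable-sort sources so non-constant LazyBuffers go last (RawConst-backed buffers and non-LazyBuffer sources keep their relative order in front), so get_grouped_maybe_float4 can fold the op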
|
||||
srcs = sorted(x.src, key=lambda x: (x.realized.__class__ != RawConst) if x.__class__ == LazyBuffer else 0)
|
||||
x.src = tuple(srcs)
|
||||
values = [self.ast_parse(v, acc, loaded_buffers, ssa) for v in x.src]
|
||||
# TODO: fold float4 into a single uop when possible.
|
||||
if x.op.__class__ in {ReduceOps, FusedOps}:
|
||||
ret = [(idx, self.uop(UOps.ALU, val[-1], list(val), {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, FusedOps.MULACC:FusedOps.MULACC}[x.op])) for idx, val in get_grouped_maybe_float4(*values, acc, grouping_allowed=self.supports_float4_alu)]
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue