1
0
Fork 0

fix folding of float4 add/mul (#1060)

pull/1051/head
Rayan Hatout 2023-06-27 04:59:29 +01:00 committed by GitHub
parent a98e361da0
commit 23648538fa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 5 additions and 1 deletion

View File

@@ -7,6 +7,7 @@ from tinygrad.helpers import dedup, colored, ImageDType, DEBUG, prod, dtypes, mn
from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, UnaryOps
from tinygrad.lazy import LazyBuffer
from tinygrad.ops import MovementOps, ReduceOps, BinaryOps, FusedOps
from tinygrad.runtime.lib import RawConst
from tinygrad.shape.shapetracker import ShapeTracker, strides_for_shape
from tinygrad.shape.symbolic import Variable
@@ -343,8 +344,11 @@ class Linearizer:
x = LazyOp(FusedOps.MULACC, x.src[0].src, x.arg)
if x.op == ReduceOps.SUM and x.src[0].__class__ is LazyOp and x.src[0].op == UnaryOps.CAST and x.src[0].src[0].__class__ is LazyOp and x.src[0].src[0].op == BinaryOps.MUL:
x = LazyOp(FusedOps.MULACC, x.src[0].src[0].src, x.arg)
if x.op in {BinaryOps.ADD, BinaryOps.MUL}:
# Reorder sources to put constants first so get_grouped_maybe_float4 can fold the op
srcs = sorted(x.src, key=lambda x: (x.realized.__class__ != RawConst) if x.__class__ == LazyBuffer else 0)
x.src = tuple(srcs)
values = [self.ast_parse(v, acc, loaded_buffers, ssa) for v in x.src]
# TODO: fold float4 into a single uop when possible.
if x.op.__class__ in {ReduceOps, FusedOps}:
ret = [(idx, self.uop(UOps.ALU, val[-1], list(val), {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, FusedOps.MULACC:FusedOps.MULACC}[x.op])) for idx, val in get_grouped_maybe_float4(*values, acc, grouping_allowed=self.supports_float4_alu)]
else: