From 23648538fa2113c554c3c7589ca1d9b38aeff545 Mon Sep 17 00:00:00 2001 From: Rayan Hatout Date: Tue, 27 Jun 2023 04:59:29 +0100 Subject: [PATCH] fix folding of float4 add/mul (#1060) --- tinygrad/codegen/linearizer.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py index 082deb192..8a97b9f18 100644 --- a/tinygrad/codegen/linearizer.py +++ b/tinygrad/codegen/linearizer.py @@ -7,6 +7,7 @@ from tinygrad.helpers import dedup, colored, ImageDType, DEBUG, prod, dtypes, mn from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, UnaryOps from tinygrad.lazy import LazyBuffer from tinygrad.ops import MovementOps, ReduceOps, BinaryOps, FusedOps +from tinygrad.runtime.lib import RawConst from tinygrad.shape.shapetracker import ShapeTracker, strides_for_shape from tinygrad.shape.symbolic import Variable @@ -343,8 +344,11 @@ class Linearizer: x = LazyOp(FusedOps.MULACC, x.src[0].src, x.arg) if x.op == ReduceOps.SUM and x.src[0].__class__ is LazyOp and x.src[0].op == UnaryOps.CAST and x.src[0].src[0].__class__ is LazyOp and x.src[0].src[0].op == BinaryOps.MUL: x = LazyOp(FusedOps.MULACC, x.src[0].src[0].src, x.arg) + if x.op in {BinaryOps.ADD, BinaryOps.MUL}: + # Reorder sources to put constants first so get_grouped_maybe_float4 can fold the op + srcs = sorted(x.src, key=lambda x: (x.realized.__class__ != RawConst) if x.__class__ == LazyBuffer else 0) + x.src = tuple(srcs) values = [self.ast_parse(v, acc, loaded_buffers, ssa) for v in x.src] - # TODO: fold float4 into a single uop when possible. if x.op.__class__ in {ReduceOps, FusedOps}: ret = [(idx, self.uop(UOps.ALU, val[-1], list(val), {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, FusedOps.MULACC:FusedOps.MULACC}[x.op])) for idx, val in get_grouped_maybe_float4(*values, acc, grouping_allowed=self.supports_float4_alu)] else: