|
|
|
@ -1,10 +1,11 @@
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
from typing import List, Tuple, Any, Optional, cast, DefaultDict, NamedTuple, Dict, Union, Sequence, Final, Set
|
|
|
|
|
from typing import List, Tuple, Any, Optional, cast, DefaultDict, Dict, Union, Sequence, Final, Set
|
|
|
|
|
import itertools, math, functools
|
|
|
|
|
from collections import defaultdict
|
|
|
|
|
from enum import Enum, auto
|
|
|
|
|
from dataclasses import dataclass
|
|
|
|
|
|
|
|
|
|
from tinygrad.helpers import colored, ImageDType, DEBUG, dtypes, DType, prod, PtrDType, all_same
|
|
|
|
|
from tinygrad.helpers import colored, ImageDType, DEBUG, dtypes, DType, prod, PtrDType, all_same, getenv
|
|
|
|
|
from tinygrad.ops import LazyOp, UnaryOps, ConstBuffer, MemBuffer, BufferOps
|
|
|
|
|
from tinygrad.ops import ReduceOps, BinaryOps, TernaryOps
|
|
|
|
|
from tinygrad.shape.shapetracker import ShapeTracker
|
|
|
|
@ -20,7 +21,8 @@ class UOps(Enum):
|
|
|
|
|
LOAD = auto(); STORE = auto(); CONST = auto(); BARRIER = auto(); PHI = auto() # noqa: E702
|
|
|
|
|
ALU = auto(); WMMA = auto(); CAST = auto(); GEP = auto() # noqa: E702
|
|
|
|
|
|
|
|
|
|
class UOp(NamedTuple):
|
|
|
|
|
@dataclass
|
|
|
|
|
class UOp:
|
|
|
|
|
uop: UOps
|
|
|
|
|
dtype: Optional[DType]
|
|
|
|
|
vin: Tuple[UOp, ...]
|
|
|
|
@ -201,10 +203,12 @@ class Linearizer(Kernel):
|
|
|
|
|
upcast_idxs = [Variable(None, 0, s-1) for s in self.output_shape[self.shape_len-self.upcasted:]]
|
|
|
|
|
|
|
|
|
|
# global and local loops
|
|
|
|
|
def render_loop(xx:List[Variable]):
|
|
|
|
|
self.loop_uops.update({x.expr:self.uop(UOps.LOOP, dtypes.int32, (
|
|
|
|
|
def render_loop(xx:List[Variable]) -> Tuple[UOp, ...]:
|
|
|
|
|
new_loops = {x.expr:self.uop(UOps.LOOP, dtypes.int32, (
|
|
|
|
|
self.const(x.min) if isinstance(x.min, int) else cast(Node, x.min).render(self.render_ops, self),
|
|
|
|
|
self.const(x.max+1) if isinstance(x.max, int) else cast(Node, x.max+1).render(self.render_ops, self)), cachable=False) for x in xx if not isinstance(x, NumNode) and x.expr is not None})
|
|
|
|
|
self.const(x.max+1) if isinstance(x.max, int) else cast(Node, x.max+1).render(self.render_ops, self)), cachable=False) for x in xx if not isinstance(x, NumNode) and x.expr is not None}
|
|
|
|
|
self.loop_uops.update(new_loops)
|
|
|
|
|
return tuple(new_loops.values())
|
|
|
|
|
def end_loop(xx:List[Variable]):
|
|
|
|
|
for x in xx[::-1]:
|
|
|
|
|
if not isinstance(x, NumNode) and x.expr is not None:
|
|
|
|
@ -261,7 +265,7 @@ class Linearizer(Kernel):
|
|
|
|
|
upcast_idxs[n] = replace_acc_idxs[len(self.tensor_core.threads)+n] # replace upcasts
|
|
|
|
|
|
|
|
|
|
# reduce loop
|
|
|
|
|
render_loop(reduce_idxs)
|
|
|
|
|
loop_ctx = render_loop(reduce_idxs)
|
|
|
|
|
|
|
|
|
|
# barrier for fast GEMM
|
|
|
|
|
if self.tensor_core: self.uop(UOps.BARRIER, None, (), cachable=False)
|
|
|
|
@ -292,6 +296,7 @@ class Linearizer(Kernel):
|
|
|
|
|
for y in range(by):
|
|
|
|
|
for x in range(bx):
|
|
|
|
|
for j in range(acc_reds):
|
|
|
|
|
# TODO: make this a proper op with PHI node
|
|
|
|
|
self.uop(UOps.WMMA, None, tuple(locals_to_store[0][2][(x+(j*bx))*wmma_sz[0]:(x+(j*bx)+1)*wmma_sz[0]]+locals_to_store[1][2][(y+(j*by))*wmma_sz[1]:(y+(j*by)+1)*wmma_sz[1]]+acc[i:i+wmma_sz[2]]), (self.opts.device, self.tensor_core.dtype_in, self.tensor_core.dtype_out,))
|
|
|
|
|
i += wmma_sz[2]
|
|
|
|
|
else:
|
|
|
|
@ -304,7 +309,7 @@ class Linearizer(Kernel):
|
|
|
|
|
loaded_buffers.update({b:self.global_load(self.bufs.index(self.local_alias[i]) if i in self.local_alias else i, global_idxs+local_idxs+reduce_idxs+full_upcast_idxs) for i,b in enumerate(self.bufs[1:], start=1) if b in self.earlybufs})
|
|
|
|
|
|
|
|
|
|
# run early AST (with reduce)
|
|
|
|
|
self.ast_parse(self.reduceop, acc, self.acc_offsets(self.full_buf_index), loaded_buffers, do_reduce=True)
|
|
|
|
|
self.ast_parse(self.reduceop, acc, self.acc_offsets(self.full_buf_index), loaded_buffers, do_reduce=True, loop_ctx=loop_ctx)
|
|
|
|
|
|
|
|
|
|
# end the reduce loop
|
|
|
|
|
end_loop(reduce_idxs)
|
|
|
|
@ -342,13 +347,13 @@ class Linearizer(Kernel):
|
|
|
|
|
acc = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)])
|
|
|
|
|
|
|
|
|
|
# late reduce loop
|
|
|
|
|
render_loop(end_local_idxs)
|
|
|
|
|
loop_ctx = render_loop(end_local_idxs)
|
|
|
|
|
|
|
|
|
|
# load localbufs
|
|
|
|
|
loaded_buffers[self.bufs[-1]] = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs)
|
|
|
|
|
|
|
|
|
|
# there's no AST here (and there's no shape for the reduce LazyOp)
|
|
|
|
|
self.ast_parse(LazyOp(self.reduceop.op, (self.bufs[-1],)), acc, self.acc_offsets(-1), loaded_buffers, do_reduce=True) # type: ignore
|
|
|
|
|
self.ast_parse(LazyOp(self.reduceop.op, (self.bufs[-1],)), acc, self.acc_offsets(-1), loaded_buffers, do_reduce=True, loop_ctx=loop_ctx) # type: ignore
|
|
|
|
|
|
|
|
|
|
# end the late reduce loop
|
|
|
|
|
end_loop(end_local_idxs)
|
|
|
|
@ -379,6 +384,11 @@ class Linearizer(Kernel):
|
|
|
|
|
if DEBUG >= 4: print(f"reduced UOp count from {len(self.uops)} to {len(nu)}")
|
|
|
|
|
self.uops = nu
|
|
|
|
|
|
|
|
|
|
# maybe graph the uops
|
|
|
|
|
if getenv("GRAPHUOPS"):
|
|
|
|
|
from tinygrad.graph import graph_uops
|
|
|
|
|
graph_uops(self.uops)
|
|
|
|
|
|
|
|
|
|
# restore backups
|
|
|
|
|
self.sts, self.group_for_reduce, self.upcasted = sts_backup, gfr_backup, upc_backup
|
|
|
|
|
|
|
|
|
@ -409,7 +419,7 @@ class Linearizer(Kernel):
|
|
|
|
|
if cachable: self.saved_exprs[key] = self.uops[-1]
|
|
|
|
|
return self.uops[-1]
|
|
|
|
|
|
|
|
|
|
def ast_parse(self, x, acc, offs, loaded_buffers, do_reduce=False) -> List[UOp]:
|
|
|
|
|
def ast_parse(self, x, acc, offs, loaded_buffers, do_reduce=False, loop_ctx=tuple()) -> List[UOp]:
|
|
|
|
|
if x.__class__ is not LazyOp: return loaded_buffers[x] # for LOCAL_BUFFER
|
|
|
|
|
if x.op in BufferOps: return loaded_buffers[x.arg]
|
|
|
|
|
if x.op in [UnaryOps.NOOP, UnaryOps.CAST]: return self.ast_parse(x.src[0], acc, offs, loaded_buffers) # cast isn't an ALU op
|
|
|
|
@ -421,15 +431,17 @@ class Linearizer(Kernel):
|
|
|
|
|
x = LazyOp(TernaryOps.MULACC, x.src[0].src, x.arg)
|
|
|
|
|
if x.op == ReduceOps.SUM and x.src[0].__class__ is LazyOp and x.src[0].op == UnaryOps.CAST and x.src[0].src[0].__class__ is LazyOp and x.src[0].src[0].op == BinaryOps.MUL:
|
|
|
|
|
x = LazyOp(TernaryOps.MULACC, x.src[0].src[0].src, x.arg)
|
|
|
|
|
values = [self.ast_parse(v, acc, offs, loaded_buffers) for v in x.src]
|
|
|
|
|
values = [self.ast_parse(v, acc, offs, loaded_buffers, loop_ctx=loop_ctx) for v in x.src]
|
|
|
|
|
ops = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, TernaryOps.MULACC:TernaryOps.MULACC}
|
|
|
|
|
if x.op in ops:
|
|
|
|
|
ret = []
|
|
|
|
|
input_acc = acc[:]
|
|
|
|
|
for idx, val, off in zip([[i] for i in range(len(values[0]))], zip(*values), offs):
|
|
|
|
|
new_val = self.uop(UOps.ALU, dtypes.float32, val+(acc[off],), ops[x.op])
|
|
|
|
|
# NOTE: we could apply the phi node to only the last change, but this breaks CLANG with nested max(x,y)
|
|
|
|
|
acc[off] = self.uop(UOps.PHI, dtypes.float32, (acc[off], new_val))
|
|
|
|
|
acc[off] = self.uop(UOps.ALU, dtypes.float32, val+(acc[off],), ops[x.op])
|
|
|
|
|
ret.append((idx, acc[off]))
|
|
|
|
|
for off in range(len(acc)):
|
|
|
|
|
if input_acc[off] != acc[off]:
|
|
|
|
|
acc[off] = self.uop(UOps.PHI, dtypes.float32, (input_acc[off], acc[off]) + tuple(loop_ctx))
|
|
|
|
|
else:
|
|
|
|
|
ret = [(idx, self.uop(UOps.ALU, dtypes.float32, val, x.op)) for idx, val in zip([[i] for i in range(len(values[0]))], zip(*values))]
|
|
|
|
|
ordered_ret: List[Optional[UOp]] = [None]*len(values[0])
|
|
|
|
|