
less phi, proper phi (#2241)

* less phi, proper phi

* disable flaky whisper test
George Hotz 2023-11-08 16:13:43 -08:00 committed by GitHub
parent 4c44d1344b
commit 38b7f5a7fd
6 changed files with 37 additions and 22 deletions


@@ -230,8 +230,8 @@ jobs:
       run: METAL=1 python -m pytest -n=auto test/test_symbolic_shapetracker.py test/test_symbolic_ops.py test/test_symbolic_jit.py
     - name: Run ONNX
       run: METAL=1 python -m pytest -n=auto test/external/external_test_onnx_backend.py
-    - name: Run whisper test
-      run: METAL=1 python -m pytest test/models/test_whisper.py
+    #- name: Run whisper test
+    #  run: METAL=1 python -m pytest test/models/test_whisper.py
     - name: Check Device.DEFAULT (WEBGPU) and print some source
       run: |
         WEBGPU=1 python -c "from tinygrad.ops import Device; assert Device.DEFAULT == 'WEBGPU', Device.DEFAULT"


@@ -137,6 +137,9 @@ class TestOps(unittest.TestCase):
   def test_arange_big(self):
     helper_test_op([], lambda: torch.arange(256), lambda: Tensor.arange(256), forward_only=True)
+  def test_sum_collapse(self):
+    helper_test_op([], lambda: torch.ones(256,256).sum(axis=1), lambda: Tensor.ones(256,256).sum(axis=1), forward_only=True)
   def test_where(self):
     helper_test_op(
       [(100,)],
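
The new test_sum_collapse covers a full-row reduce over a constant tensor. A standalone version of what it checks, assuming torch, numpy, and tinygrad are installed:

```python
# standalone equivalent of the new test_sum_collapse (forward only)
import numpy as np, torch
from tinygrad.tensor import Tensor

expected = torch.ones(256, 256).sum(axis=1).numpy()
actual = Tensor.ones(256, 256).sum(axis=1).numpy()
np.testing.assert_allclose(actual, expected)  # every row sums to 256.0
```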


@@ -1,10 +1,11 @@
 from __future__ import annotations
-from typing import List, Tuple, Any, Optional, cast, DefaultDict, NamedTuple, Dict, Union, Sequence, Final, Set
+from typing import List, Tuple, Any, Optional, cast, DefaultDict, Dict, Union, Sequence, Final, Set
 import itertools, math, functools
 from collections import defaultdict
 from enum import Enum, auto
+from dataclasses import dataclass
-from tinygrad.helpers import colored, ImageDType, DEBUG, dtypes, DType, prod, PtrDType, all_same
+from tinygrad.helpers import colored, ImageDType, DEBUG, dtypes, DType, prod, PtrDType, all_same, getenv
 from tinygrad.ops import LazyOp, UnaryOps, ConstBuffer, MemBuffer, BufferOps
 from tinygrad.ops import ReduceOps, BinaryOps
 from tinygrad.shape.shapetracker import ShapeTracker
@@ -20,7 +21,8 @@ class UOps(Enum):
   LOAD = auto(); STORE = auto(); CONST = auto(); BARRIER = auto(); PHI = auto() # noqa: E702
   ALU = auto(); WMMA = auto(); CAST = auto(); GEP = auto() # noqa: E702
-class UOp(NamedTuple):
+@dataclass
+class UOp:
   uop: UOps
   dtype: Optional[DType]
   vin: Tuple[UOp, ...]
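
Because UOp is now a mutable dataclass rather than a NamedTuple, it can no longer be unpacked positionally; the renderer hunks below switch from `uop,dtype,vin,args,_ = u` to attribute access. A minimal sketch of the difference, using a hypothetical trimmed-down stand-in:

```python
from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass
class MiniUOp:  # hypothetical stand-in; the real UOp carries more fields
  uop: str
  dtype: Optional[str]
  vin: Tuple["MiniUOp", ...]
  arg: object = None

u = MiniUOp("CONST", "float", (), 1.0)
# as a NamedTuple this unpacked positionally: uop, dtype, vin, arg, _ = u
# a dataclass is not iterable, so fields are read by name:
uop, dtype, vin, arg = u.uop, u.dtype, u.vin, u.arg
```
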
@@ -201,10 +203,12 @@ class Linearizer(Kernel):
     upcast_idxs = [Variable(None, 0, s-1) for s in self.output_shape[self.shape_len-self.upcasted:]]
     # global and local loops
-    def render_loop(xx:List[Variable]):
-      self.loop_uops.update({x.expr:self.uop(UOps.LOOP, dtypes.int32, (
+    def render_loop(xx:List[Variable]) -> Tuple[UOp, ...]:
+      new_loops = {x.expr:self.uop(UOps.LOOP, dtypes.int32, (
         self.const(x.min) if isinstance(x.min, int) else cast(Node, x.min).render(self.render_ops, self),
-        self.const(x.max+1) if isinstance(x.max, int) else cast(Node, x.max+1).render(self.render_ops, self)), cachable=False) for x in xx if not isinstance(x, NumNode) and x.expr is not None})
+        self.const(x.max+1) if isinstance(x.max, int) else cast(Node, x.max+1).render(self.render_ops, self)), cachable=False) for x in xx if not isinstance(x, NumNode) and x.expr is not None}
+      self.loop_uops.update(new_loops)
+      return tuple(new_loops.values())
     def end_loop(xx:List[Variable]):
       for x in xx[::-1]:
         if not isinstance(x, NumNode) and x.expr is not None:
@@ -261,7 +265,7 @@ class Linearizer(Kernel):
           upcast_idxs[n] = replace_acc_idxs[len(self.tensor_core.threads)+n] # replace upcasts
       # reduce loop
-      render_loop(reduce_idxs)
+      loop_ctx = render_loop(reduce_idxs)
       # barrier for fast GEMM
       if self.tensor_core: self.uop(UOps.BARRIER, None, (), cachable=False)
@@ -292,6 +296,7 @@ class Linearizer(Kernel):
           for y in range(by):
             for x in range(bx):
               for j in range(acc_reds):
+                # TODO: make this a proper op with PHI node
                 self.uop(UOps.WMMA, None, tuple(locals_to_store[0][2][(x+(j*bx))*wmma_sz[0]:(x+(j*bx)+1)*wmma_sz[0]]+locals_to_store[1][2][(y+(j*by))*wmma_sz[1]:(y+(j*by)+1)*wmma_sz[1]]+acc[i:i+wmma_sz[2]]), (self.opts.device, self.tensor_core.dtype_in, self.tensor_core.dtype_out,))
                 i += wmma_sz[2]
         else:
@@ -304,7 +309,7 @@ class Linearizer(Kernel):
       loaded_buffers.update({b:self.global_load(self.bufs.index(self.local_alias[i]) if i in self.local_alias else i, global_idxs+local_idxs+reduce_idxs+full_upcast_idxs) for i,b in enumerate(self.bufs[1:], start=1) if b in self.earlybufs})
       # run early AST (with reduce)
-      self.ast_parse(self.reduceop, acc, self.acc_offsets(self.full_buf_index), loaded_buffers, do_reduce=True)
+      self.ast_parse(self.reduceop, acc, self.acc_offsets(self.full_buf_index), loaded_buffers, do_reduce=True, loop_ctx=loop_ctx)
       # end the reduce loop
       end_loop(reduce_idxs)
@@ -342,13 +347,13 @@ class Linearizer(Kernel):
         acc = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)])
         # late reduce loop
-        render_loop(end_local_idxs)
+        loop_ctx = render_loop(end_local_idxs)
         # load localbufs
         loaded_buffers[self.bufs[-1]] = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs)
         # there's no AST here (and there's no shape for the reduce LazyOp)
-        self.ast_parse(LazyOp(self.reduceop.op, (self.bufs[-1],)), acc, self.acc_offsets(-1), loaded_buffers, do_reduce=True) # type: ignore
+        self.ast_parse(LazyOp(self.reduceop.op, (self.bufs[-1],)), acc, self.acc_offsets(-1), loaded_buffers, do_reduce=True, loop_ctx=loop_ctx) # type: ignore
         # end the late reduce loop
         end_loop(end_local_idxs)
@@ -379,6 +384,11 @@ class Linearizer(Kernel):
     if DEBUG >= 4: print(f"reduced UOp count from {len(self.uops)} to {len(nu)}")
     self.uops = nu
+    # maybe graph the uops
+    if getenv("GRAPHUOPS"):
+      from tinygrad.graph import graph_uops
+      graph_uops(self.uops)
     # restore backups
     self.sts, self.group_for_reduce, self.upcasted = sts_backup, gfr_backup, upc_backup
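
GRAPHUOPS works like tinygrad's other debug env vars: set it and realize any op to get a graph of the linearized UOps via tinygrad.graph.graph_uops. A hedged usage sketch (the script is hypothetical; any kernel works):

```python
# run as: GRAPHUOPS=1 python3 sketch.py
from tinygrad.tensor import Tensor

# realizing an op triggers linearization; with GRAPHUOPS set,
# the resulting UOp list is graphed before rendering
Tensor.ones(256, 256).sum(axis=1).realize()
```
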
@@ -409,7 +419,7 @@ class Linearizer(Kernel):
     if cachable: self.saved_exprs[key] = self.uops[-1]
     return self.uops[-1]
-  def ast_parse(self, x, acc, offs, loaded_buffers, do_reduce=False) -> List[UOp]:
+  def ast_parse(self, x, acc, offs, loaded_buffers, do_reduce=False, loop_ctx=tuple()) -> List[UOp]:
     if x.__class__ is not LazyOp: return loaded_buffers[x] # for LOCAL_BUFFER
     if x.op in BufferOps: return loaded_buffers[x.arg]
     if x.op in [UnaryOps.NOOP, UnaryOps.CAST]: return self.ast_parse(x.src[0], acc, offs, loaded_buffers) # cast isn't an ALU op
@@ -421,15 +431,17 @@ class Linearizer(Kernel):
       x = LazyOp(TernaryOps.MULACC, x.src[0].src, x.arg)
     if x.op == ReduceOps.SUM and x.src[0].__class__ is LazyOp and x.src[0].op == UnaryOps.CAST and x.src[0].src[0].__class__ is LazyOp and x.src[0].src[0].op == BinaryOps.MUL:
       x = LazyOp(TernaryOps.MULACC, x.src[0].src[0].src, x.arg)
-    values = [self.ast_parse(v, acc, offs, loaded_buffers) for v in x.src]
+    values = [self.ast_parse(v, acc, offs, loaded_buffers, loop_ctx=loop_ctx) for v in x.src]
     ops = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, TernaryOps.MULACC:TernaryOps.MULACC}
     if x.op in ops:
       ret = []
+      input_acc = acc[:]
       for idx, val, off in zip([[i] for i in range(len(values[0]))], zip(*values), offs):
-        new_val = self.uop(UOps.ALU, dtypes.float32, val+(acc[off],), ops[x.op])
-        # NOTE: we could apply the phi node to only the last change, but this breaks CLANG with nested max(x,y)
-        acc[off] = self.uop(UOps.PHI, dtypes.float32, (acc[off], new_val))
+        acc[off] = self.uop(UOps.ALU, dtypes.float32, val+(acc[off],), ops[x.op])
         ret.append((idx, acc[off]))
+      for off in range(len(acc)):
+        if input_acc[off] != acc[off]:
+          acc[off] = self.uop(UOps.PHI, dtypes.float32, (input_acc[off], acc[off]) + tuple(loop_ctx))
     else:
       ret = [(idx, self.uop(UOps.ALU, dtypes.float32, val, x.op)) for idx, val in zip([[i] for i in range(len(values[0]))], zip(*values))]
     ordered_ret: List[Optional[UOp]] = [None]*len(values[0])
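
This hunk is the "less phi, proper phi" of the title: instead of wrapping every accumulator update in a PHI, the loop body is now plain ALU ops, and a single PHI per loop-carried accumulator is emitted afterwards with vin of (value on loop entry, final in-loop value, enclosing LOOP uops). A conceptual sketch of the emitted structure for a sum reduction (pseudocode, not actual linearizer output):

```python
# acc0 = CONST 0.0                # accumulator value before the loop
# loop = LOOP  0, N               # returned by render_loop as loop_ctx
# new  = ALU   ADD(load, phi)     # plain update inside the loop body
# phi  = PHI   (acc0, new, loop)  # one PHI per changed accumulator
# consumers after the loop read the PHI, closing the SSA cycle
```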


@@ -127,7 +127,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tu
       child_count[v] += 1
   for u in uops:
-    uop,dtype,vin,args,_ = u
+    uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
     if uop == UOps.LOOP:
       kk(lang.render_for(ssa(u,'ridx'), r[vin[0]], r[vin[1]]))
       depth += 1
@@ -166,7 +166,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tu
       else:
         val = lang.code_for_op[args](*[r[x] for x in vin])
       assert child_count[u] != 0, f"childless ALU op found {u}"
-      if child_count[u] <= 1 or dtypes.is_int(dtype): # fix index rendering issue
+      if (child_count[u] <= 1 or dtypes.is_int(dtype)) and args != BinaryOps.MAX: # fix index rendering issue. fix clang nested max macro issue
        r[u] = val
      else:
        kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u,'alu')} = {val};")


@@ -64,7 +64,7 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Dict]:
   module = ir.Module(name=__file__)
   # extract global buffers
-  buf_to_dtype = {args[0]:args[1] for uop,_,_,args,_ in uops if uop == UOps.DEFINE_GLOBAL}
+  buf_to_dtype = {u.arg[0]:u.arg[1] for u in uops if u.uop == UOps.DEFINE_GLOBAL}
   buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())}
   # create llvm function
@@ -87,7 +87,7 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Dict]:
   if dtype == dtypes._arg_int32: lvars[bufname] = bb[-1].sext(func.args[buf_index[bufname]], ir.IntType(32))
   for u in uops:
-    uop,dtype,vin,args,_ = u
+    uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
     if uop == UOps.LOOP:
       bb.append(ir.IRBuilder(func.append_basic_block(f"loop_body_{len(loop_blocks)}")))
       bb[-2].branch(bb[-1]._block)
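
In the LLVM backend a PHI uop maps naturally onto an IR phi instruction. A minimal self-contained llvmlite sketch of the pattern for a float accumulator carried across a loop (simplified, not tinygrad's actual lowering):

```python
import llvmlite.ir as ir

m = ir.Module(name="phi_sketch")
fn = ir.Function(m, ir.FunctionType(ir.FloatType(), []), name="reduce")
entry, body, done = [fn.append_basic_block(n) for n in ("entry", "body", "done")]
b = ir.IRBuilder(entry)
b.branch(body)

b.position_at_end(body)
acc = b.phi(ir.FloatType(), name="acc")                    # the PHI uop
acc.add_incoming(ir.Constant(ir.FloatType(), 0.0), entry)  # value on loop entry
new_acc = b.fadd(acc, ir.Constant(ir.FloatType(), 1.0))    # in-loop ALU update
acc.add_incoming(new_acc, body)                            # value carried around
b.cbranch(b.fcmp_ordered("<", new_acc, ir.Constant(ir.FloatType(), 8.0)), body, done)

b.position_at_end(done)
b.ret(new_acc)
print(m)  # textual LLVM IR with the phi node
```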


@@ -74,7 +74,7 @@ def uops_to_triton(function_name:str, uops:List[UOp]):
   }
   def int_div(x,y): return f"({x}//{y})" if y != '0' else f"{x}*tl.where({x}==0, float('nan'), float('inf'))"
   for u in uops:
-    uop,dtype,vin,args,_ = u
+    uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
     if uop == UOps.LOOP:
       kk(f"for {ssa(u, 'ridx')} in range({vin[0].arg}, {r[vin[1]]}):")
       depth += 1