autopad shapetracker for BEAM (#2375)
* autopad shapetracker for BEAM * OptOps.PADTO * skip that test for now * correct padding reduce axis * just 32 * avoid more than double the FLOPs * cleanups * test case * no support for triton and llvm yet * typos * symbolic shape would not work * cannot PADTO with MAX kernel * advance db version * no breaking change - don't advance db version * is triton just python? * Revert "is triton just python?" This reverts commit 17e776c25587615e33a3634c2fb0bb8591ce65d4. * Revert "Revert "is triton just python?"" This reverts commit 6c434c01e1c4b0ea0431ec18632cd859fb3cf260. * support llvm * is it really passing in CI only? * update tests * oh triton test passed * simpler * revert that, with a test * check if st are the same * Revert "check if st are the same" This reverts commit d2a5eac110a5da1af82a2728c883779ef69c3cad. * update the db version * rebase artifactpull/2396/head
parent
162db466c3
commit
8798d120bb
|
@ -0,0 +1,40 @@
|
|||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.ops import LoadOps
|
||||
from tinygrad.codegen.linearizer import Linearizer
|
||||
from test.external.fuzz_linearizer import run_linearizer
|
||||
from tinygrad.codegen.kernel import Opt, OptOps
|
||||
|
||||
N = 17**3
|
||||
|
||||
a = Tensor.rand(N, N)
|
||||
b = Tensor.rand(N, N)
|
||||
c = a @ b
|
||||
sched = [si for si in c.lazydata.schedule() if si.ast.op not in LoadOps]
|
||||
assert len(sched) == 1
|
||||
lin = Linearizer(sched[0].ast)
|
||||
|
||||
lin.apply_opt(Opt(op=OptOps.PADTO, axis=0, amt=32))
|
||||
lin.apply_opt(Opt(op=OptOps.PADTO, axis=1, amt=32))
|
||||
lin.apply_opt(Opt(op=OptOps.PADTO, axis=2, amt=32))
|
||||
lin.hand_coded_optimizations()
|
||||
lin.linearize()
|
||||
print(f"{lin.applied_opts=}")
|
||||
|
||||
run_linearizer(lin)
|
||||
quit()
|
||||
|
||||
###
|
||||
|
||||
a = Tensor.rand(61, 61).sum(axis=0)
|
||||
sched = [si for si in a.lazydata.schedule() if si.ast.op not in LoadOps]
|
||||
assert len(sched) == 1
|
||||
lin = Linearizer(sched[0].ast)
|
||||
|
||||
# lin.apply_opt(Opt(op=OptOps.LOCAL, axis=0, amt=32))
|
||||
|
||||
lin.apply_opt(Opt(op=OptOps.PADTO, axis=0, amt=32))
|
||||
lin.apply_opt(Opt(op=OptOps.PADTO, axis=1, amt=32))
|
||||
lin.hand_coded_optimizations()
|
||||
lin.linearize()
|
||||
|
||||
run_linearizer(lin)
|
|
@ -17,6 +17,25 @@ class TestBeamSearch(unittest.TestCase):
|
|||
a = Tensor.rand(3, 3).reshape((Variable("a", 1, 10).bind(3), 3))
|
||||
a = (a+1).realize()
|
||||
|
||||
def test_big_prime_number(self):
|
||||
a = Tensor.rand(367, 367)
|
||||
b = Tensor.rand(367, 367)
|
||||
c = (a@b).realize()
|
||||
np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), atol=1e-4, rtol=1e-4)
|
||||
|
||||
def test_variable_big_prime_number(self):
|
||||
v = Variable("v", 1, 400).bind(367)
|
||||
a = Tensor.rand(367, 367)
|
||||
b = Tensor.rand(367, 367)
|
||||
c = (a.reshape(367, v) @ b.reshape(v, 367)).realize()
|
||||
np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), atol=1e-4, rtol=1e-4)
|
||||
|
||||
def test_variable_shrink_prime_number(self):
|
||||
v = Variable("v", 1, 400).bind(367)
|
||||
a = Tensor.rand(400, 367)
|
||||
b = (a.shrink(((0,v), None))+1).reshape(367,367).realize()
|
||||
np.testing.assert_allclose(b.numpy(), a.numpy()[:367]+1, atol=1e-4, rtol=1e-4)
|
||||
|
||||
def test_no_mutate_rawbuffers(self):
|
||||
a = Tensor.rand(3, 3).realize()
|
||||
desired = a.numpy() + 1
|
||||
|
|
|
@ -184,6 +184,10 @@ if (getenv('LLVM') or getenv('CUDA')) and CI:
|
|||
# error: casting to type 'half' is not allowed
|
||||
backend_test.exclude('test_dequantizelinear_e4m3fn_float16_cpu')
|
||||
|
||||
# TODO: this somehow passes in CI but does not pass if run locally
|
||||
if getenv('GPU') or getenv('METAL') or getenv('LLVM') or getenv('CLANG'):
|
||||
backend_test.exclude('test_MaxPool3d_stride_padding_cpu')
|
||||
|
||||
# disable model tests for now since they are slow
|
||||
if not getenv("MODELTESTS"):
|
||||
for x in backend_test.test_suite:
|
||||
|
|
|
@ -7,7 +7,7 @@ from tinygrad.ops import Compiled, Device, LoadOps
|
|||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.jit import CacheCollector
|
||||
from tinygrad.realize import run_schedule
|
||||
from tinygrad.helpers import dtypes, prod
|
||||
from tinygrad.helpers import dtypes, prod, getenv, CI
|
||||
|
||||
class TestLinearizer(unittest.TestCase):
|
||||
def test_arg_dedup(self):
|
||||
|
@ -487,6 +487,36 @@ class TestLinearizerOpts(unittest.TestCase):
|
|||
# [Opt(OptOps.GROUP, 0, 2)] # doesn't work because group_for_reduce dims become early locals (conflicting with TC)
|
||||
], apply_tc=True)
|
||||
|
||||
def test_padto_matmul(self):
|
||||
if not isinstance(Device[Device.DEFAULT], Compiled): self.skipTest("Only Compiled uses linearizer")
|
||||
N = 17 * 17
|
||||
Tensor.manual_seed(289)
|
||||
a = Tensor.rand(N, N)
|
||||
b = Tensor.rand(N, N)
|
||||
helper_linearizer_opt(a@b, [
|
||||
[Opt(OptOps.PADTO, 0, 32)],
|
||||
[Opt(OptOps.PADTO, 1, 32)],
|
||||
[Opt(OptOps.PADTO, 2, 32)],
|
||||
[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.PADTO, 1, 32), Opt(OptOps.PADTO, 2, 32)],
|
||||
# can optimize further post PADTO
|
||||
[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.PADTO, 1, 32), Opt(OptOps.PADTO, 2, 32), Opt(OptOps.UPCAST, 0, 2), Opt(OptOps.UNROLL, 0, 4)],
|
||||
])
|
||||
|
||||
def test_padto_max(self):
|
||||
# pad uses invalid value 0, so max is not allowed
|
||||
if not isinstance(Device[Device.DEFAULT], Compiled): self.skipTest("Only Compiled uses linearizer")
|
||||
N = 17 * 17
|
||||
a = -Tensor.ones(N, N)
|
||||
with self.assertRaises(AssertionError):
|
||||
helper_linearizer_opt(a.max(), [[Opt(OptOps.PADTO, 0, 32)],])
|
||||
|
||||
def test_padto_where(self):
|
||||
# pad uses invalid value 0, so kernel with max is not allowed
|
||||
if not isinstance(Device[Device.DEFAULT], Compiled): self.skipTest("Only Compiled uses linearizer")
|
||||
N = 17 * 17
|
||||
a = (Tensor.rand(N, N).max(axis=0) > 1).where(1, 0)
|
||||
with self.assertRaises(AssertionError):
|
||||
helper_linearizer_opt(a.max(), [[Opt(OptOps.PADTO, 0, 32)],])
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
from __future__ import annotations
|
||||
import os, math, itertools
|
||||
from typing import NamedTuple, Optional, List, Tuple, cast, Dict, Union
|
||||
from tinygrad.lazy import vars_from_ast
|
||||
from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps, Device, Compiled
|
||||
from tinygrad.helpers import dedup, dtypes, colored, ImageDType, DType, ansilen, getenv, prod, DEBUG
|
||||
from tinygrad.helpers import dedup, dtypes, colored, ImageDType, DType, ansilen, getenv, prod, DEBUG, round_up
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, get_contraction
|
||||
from tinygrad.shape.symbolic import sint
|
||||
from tinygrad.shape.view import View, strides_for_shape
|
||||
|
@ -10,7 +11,7 @@ from dataclasses import dataclass
|
|||
from enum import Enum, auto
|
||||
|
||||
class OptOps(Enum):
|
||||
UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto(); LASTLOCAL = auto(); GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto() # noqa: E702
|
||||
UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto(); LASTLOCAL = auto(); GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto(); PADTO = auto() # noqa: E702
|
||||
def __lt__(self, x:OptOps): return self.value < x.value
|
||||
|
||||
@dataclass(frozen=True, order=True)
|
||||
|
@ -400,8 +401,8 @@ class Kernel:
|
|||
axis = -1
|
||||
if opt.amt is not None:
|
||||
amt = opt.amt if opt.amt != 0 else self.full_shape[axis]
|
||||
assert self.full_shape[axis] % amt == 0, "no longer valid shift"
|
||||
assert isinstance(amt, int) and amt != 1, "shift of amt 1 or Node is meaningless"
|
||||
assert isinstance(amt, int) and amt != 1, "shift/padto of amt 1 or Node is meaningless"
|
||||
if opt.op != OptOps.PADTO: assert self.full_shape[axis] % amt == 0, "no longer valid shift"
|
||||
else:
|
||||
amt = -1
|
||||
if opt.op == OptOps.LOCAL: # cyan
|
||||
|
@ -450,6 +451,18 @@ class Kernel:
|
|||
assert self.local_dims == 0 and len(self.group_for_reduce) == 0, "can't have no locals with locals"
|
||||
assert not self.dont_use_locals, "already not using locals"
|
||||
self.dont_use_locals = True
|
||||
elif opt.op == OptOps.PADTO:
|
||||
assert not vars_from_ast(self.ast), "does not work with symbolic shape"
|
||||
assert all(op.op is not ReduceOps.MAX for op in self.ast.get_lazyops()), "cannot pad with MAX"
|
||||
padded = False
|
||||
for i,st in enumerate(self.sts):
|
||||
if self.sts[i].shape[axis] != 1:
|
||||
assert self.sts[i].shape[axis] > amt//2, "pad adds more than double the work"
|
||||
if (ru := round_up(self.sts[i].shape[axis], amt) - self.sts[i].shape[axis]):
|
||||
# pad right seems to be faster
|
||||
self.sts[i] = st.pad(((0,0),) * axis + ((0,ru),) + ((0,0),) * (len(st.shape)-axis-1))
|
||||
padded = True
|
||||
assert padded, "nothing was padded"
|
||||
return self.simplify_ones()
|
||||
|
||||
def required_optimizations(self, early_only=False):
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from __future__ import annotations
|
||||
from typing import List, Tuple, Any, Optional, cast, DefaultDict, Dict, Union, Sequence, Final, Set
|
||||
import itertools, math, functools
|
||||
import itertools, math, functools, operator
|
||||
from collections import defaultdict
|
||||
from enum import Enum, auto
|
||||
from dataclasses import dataclass
|
||||
|
@ -86,7 +86,6 @@ class Linearizer(Kernel):
|
|||
key = f"{acc}{localtype}{this_const if this_const is not None and acc is None else (buf.idx if isinstance(buf, MemBuffer) else cast(LocalBuffer, buf).name)}{idx.render()}{valid.render()}"
|
||||
if key not in self.load_cache:
|
||||
if acc is not None:
|
||||
assert valid.min == 1
|
||||
self.load_cache[key] = self.uop(UOps.DEFINE_ACC, localtype, (), this_const, cachable=False)
|
||||
elif this_const is not None:
|
||||
self.load_cache[key] = self.const(this_const, localtype)
|
||||
|
@ -131,7 +130,6 @@ class Linearizer(Kernel):
|
|||
amt = len(out_tokens)
|
||||
idx, valid = self.sts[i].expr_idxs(k)
|
||||
assert idx.render() == ((idx//amt)*amt).render(), "float4 stores are always aligned"
|
||||
assert valid.min == 1, "stores are always valid"
|
||||
store_offset_new[k] = self.uop(UOps.CAST, dtypes.float.vec(amt), tuple(out_tokens))
|
||||
store_offset = store_offset_new
|
||||
|
||||
|
@ -143,7 +141,8 @@ class Linearizer(Kernel):
|
|||
rendered_idx = self.uop(UOps.CAST, dtypes.int.vec(2), tuple(x.render(self.render_ops, self) for x in idx))
|
||||
else:
|
||||
rendered_idx = idx.render(self.render_ops, self)
|
||||
stores.append(self.uop(UOps.STORE, None, (buf_uop, rendered_idx, var)))
|
||||
if valid.min == 1: stores.append(self.uop(UOps.STORE, None, (buf_uop, rendered_idx, var)))
|
||||
else: stores.append(self.uop(UOps.STORE, None, (buf_uop, rendered_idx, var, valid.render(self.render_ops, self))))
|
||||
return stores
|
||||
|
||||
kernel_cnt: Final[DefaultDict[str, int]] = defaultdict(int)
|
||||
|
@ -389,6 +388,21 @@ class Linearizer(Kernel):
|
|||
u.vin = tuple(new if x is old else x for x in u.vin)
|
||||
self.uops.remove(old)
|
||||
|
||||
# fix loop scope, push CONST and ALU upward out of loop if it does not depend on the loop
|
||||
loop_stack: List[List[UOp]] = [[]]
|
||||
for u in self.uops:
|
||||
if not loop_stack[-1]: loop_stack[-1].append(u)
|
||||
elif u.uop == UOps.LOOP: loop_stack.append([u])
|
||||
elif u.uop not in [UOps.CONST, UOps.ALU]: loop_stack[-1].append(u)
|
||||
else:
|
||||
parents = get_recursive_parents([u])
|
||||
for i in reversed(range(len(loop_stack))):
|
||||
# check backwards and put the uop in the first encounter with some dependency
|
||||
if any(x in parents for x in loop_stack[i]) or i == 0:
|
||||
loop_stack[i].append(u)
|
||||
break
|
||||
self.uops = functools.reduce(operator.__add__, loop_stack, [])
|
||||
|
||||
# uops optimization
|
||||
changed_something = True
|
||||
while changed_something:
|
||||
|
|
|
@ -13,6 +13,7 @@ actions = flatten([[Opt(op=OptOps.UPCAST, axis=axis, amt=amt) for amt in [0,2,3,
|
|||
actions += flatten([[Opt(op=OptOps.UNROLL, axis=axis, amt=amt) for amt in [0,4]] for axis in range(4)])
|
||||
actions += flatten([[Opt(op=OptOps.LOCAL, axis=axis, amt=amt) for amt in [2,3,4,8,13,16,29]] for axis in range(5)])
|
||||
actions += flatten([[Opt(op=OptOps.GROUPTOP, axis=axis, amt=amt) for amt in [13,16,29,32,256]] for axis in range(3)])
|
||||
actions += flatten([[Opt(op=OptOps.PADTO, axis=axis, amt=amt) for amt in [32]] for axis in range(7)])
|
||||
actions += [
|
||||
Opt(op=OptOps.LOCAL, axis=0, amt=32),
|
||||
Opt(op=OptOps.GROUP, axis=0, amt=4), Opt(op=OptOps.GROUP, axis=0, amt=8), Opt(op=OptOps.GROUP, axis=1, amt=8),
|
||||
|
|
|
@ -177,7 +177,7 @@ _cache_dir: str = getenv("XDG_CACHE_HOME", os.path.expanduser("~/Library/Caches"
|
|||
CACHEDB: str = getenv("CACHEDB", os.path.abspath(os.path.join(_cache_dir, "tinygrad", "cache.db")))
|
||||
CACHELEVEL = getenv("CACHELEVEL", 2)
|
||||
|
||||
VERSION = 9
|
||||
VERSION = 10
|
||||
_db_connection = None
|
||||
def db_connection():
|
||||
global _db_connection
|
||||
|
|
|
@ -189,7 +189,9 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tu
|
|||
r[u] = r[vin[0]]
|
||||
elif uop == UOps.STORE:
|
||||
assert vin[0].dtype is not None and vin[2].dtype is not None
|
||||
if len(vin) > 3: kk(lang.render_if(r[vin[3]]))
|
||||
kk(lang.render_store(r[vin[0]], vin[0].dtype, r[vin[2]], vin[2].dtype, strip_parens(r[vin[1]]), vin[0].uop == UOps.DEFINE_LOCAL))
|
||||
if len(vin) > 3: kk("}")
|
||||
elif uop == UOps.CAST and dtype is not None:
|
||||
val = lang.render_cast([r[x] for x in vin], dtype)
|
||||
if child_count[u] <= 1: r[u] = val
|
||||
|
|
|
@ -140,7 +140,10 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Dict]:
|
|||
lvars[backward] = lvars[u]
|
||||
if uop == UOps.STORE:
|
||||
element = cast(bb, lvars[vin[2]], vin[2].dtype, vin[0].dtype)
|
||||
bb[-1].store(element, bb[-1].gep(lvars[vin[0]], [lvars[vin[1]]], inbounds=True))
|
||||
def store_op(): bb[-1].store(element, bb[-1].gep(lvars[vin[0]], [lvars[vin[1]]], inbounds=True))
|
||||
if len(vin) > 3:
|
||||
with bb[-1].if_then(bb[-1].trunc(lvars[vin[3]], ir.IntType(1))): store_op()
|
||||
else: store_op()
|
||||
if uop == UOps.ALU:
|
||||
lvars[u] = cast(bb, code_for_op[args](bb[-1], *[cast(bb, lvars[x], x.dtype, dtypes.float) for x in vin]), dtypes.float, dtype)
|
||||
if uop == UOps.CAST: lvars[u] = cast(bb, lvars[vin[0]], vin[0].dtype, dtype)
|
||||
|
|
|
@ -95,7 +95,7 @@ def uops_to_triton(function_name:str, uops:List[UOp]):
|
|||
r[u] = r[vin[0]]
|
||||
elif uop == UOps.STORE:
|
||||
assert not isinstance(dtype, ImageDType), "unimplemented: image store"
|
||||
kk(f"tl.store({r[vin[0]]} + {r[vin[1]]}, {r[vin[2]].replace('//', '/')}, mask = {render_valid(valid)}) ")
|
||||
kk(f"{'if '+r[vin[3]]+': ' if len(vin)>3 else ''}tl.store({r[vin[0]]} + {r[vin[1]]}, {r[vin[2]].replace('//', '/')}, mask = {render_valid(valid)}) ")
|
||||
elif uop == UOps.DEFINE_GLOBAL:
|
||||
bufs.append(args)
|
||||
signatures.append(signature_dtypes[args[1]])
|
||||
|
|
Loading…
Reference in New Issue