
autopad shapetracker for BEAM (#2375)

* autopad shapetracker for BEAM

* OptOps.PADTO

* skip that test for now

* correct padding reduce axis

* just 32

* avoid more than double the FLOPs

* cleanups

* test case

* no support for triton and llvm yet

* typos

* symbolic shape would not work

* cannot PADTO with MAX kernel

* advance db version

* no breaking change - don't advance db version

* is triton just python?

* Revert "is triton just python?"

This reverts commit 17e776c25587615e33a3634c2fb0bb8591ce65d4.

* Revert "Revert "is triton just python?""

This reverts commit 6c434c01e1c4b0ea0431ec18632cd859fb3cf260.

* support llvm

* is it really passing in CI only?

* update tests

* oh triton test passed

* simpler

* revert that, with a test

* check if st are the same

* Revert "check if st are the same"

This reverts commit d2a5eac110a5da1af82a2728c883779ef69c3cad.

* update the db version

* rebase artifact
chenyu 2023-11-22 21:05:25 -05:00 committed by GitHub
parent 162db466c3
commit 8798d120bb
11 changed files with 138 additions and 12 deletions

extra/autopad.py (new file, mode 100644, 40 lines)

@@ -0,0 +1,40 @@
from tinygrad.tensor import Tensor
from tinygrad.ops import LoadOps
from tinygrad.codegen.linearizer import Linearizer
from test.external.fuzz_linearizer import run_linearizer
from tinygrad.codegen.kernel import Opt, OptOps
N = 17**3
a = Tensor.rand(N, N)
b = Tensor.rand(N, N)
c = a @ b
sched = [si for si in c.lazydata.schedule() if si.ast.op not in LoadOps]
assert len(sched) == 1
lin = Linearizer(sched[0].ast)
lin.apply_opt(Opt(op=OptOps.PADTO, axis=0, amt=32))
lin.apply_opt(Opt(op=OptOps.PADTO, axis=1, amt=32))
lin.apply_opt(Opt(op=OptOps.PADTO, axis=2, amt=32))
lin.hand_coded_optimizations()
lin.linearize()
print(f"{lin.applied_opts=}")
run_linearizer(lin)
quit()
###
a = Tensor.rand(61, 61).sum(axis=0)
sched = [si for si in a.lazydata.schedule() if si.ast.op not in LoadOps]
assert len(sched) == 1
lin = Linearizer(sched[0].ast)
# lin.apply_opt(Opt(op=OptOps.LOCAL, axis=0, amt=32))
lin.apply_opt(Opt(op=OptOps.PADTO, axis=0, amt=32))
lin.apply_opt(Opt(op=OptOps.PADTO, axis=1, amt=32))
lin.hand_coded_optimizations()
lin.linearize()
run_linearizer(lin)


@@ -17,6 +17,25 @@ class TestBeamSearch(unittest.TestCase):
a = Tensor.rand(3, 3).reshape((Variable("a", 1, 10).bind(3), 3))
a = (a+1).realize()
def test_big_prime_number(self):
a = Tensor.rand(367, 367)
b = Tensor.rand(367, 367)
c = (a@b).realize()
np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), atol=1e-4, rtol=1e-4)
def test_variable_big_prime_number(self):
v = Variable("v", 1, 400).bind(367)
a = Tensor.rand(367, 367)
b = Tensor.rand(367, 367)
c = (a.reshape(367, v) @ b.reshape(v, 367)).realize()
np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), atol=1e-4, rtol=1e-4)
def test_variable_shrink_prime_number(self):
v = Variable("v", 1, 400).bind(367)
a = Tensor.rand(400, 367)
b = (a.shrink(((0,v), None))+1).reshape(367,367).realize()
np.testing.assert_allclose(b.numpy(), a.numpy()[:367]+1, atol=1e-4, rtol=1e-4)
def test_no_mutate_rawbuffers(self):
a = Tensor.rand(3, 3).realize()
desired = a.numpy() + 1


@@ -184,6 +184,10 @@ if (getenv('LLVM') or getenv('CUDA')) and CI:
# error: casting to type 'half' is not allowed
backend_test.exclude('test_dequantizelinear_e4m3fn_float16_cpu')
# TODO: this somehow passes in CI but does not pass if run locally
if getenv('GPU') or getenv('METAL') or getenv('LLVM') or getenv('CLANG'):
backend_test.exclude('test_MaxPool3d_stride_padding_cpu')
# disable model tests for now since they are slow
if not getenv("MODELTESTS"):
for x in backend_test.test_suite:


@@ -7,7 +7,7 @@ from tinygrad.ops import Compiled, Device, LoadOps
from tinygrad.tensor import Tensor
from tinygrad.jit import CacheCollector
from tinygrad.realize import run_schedule
from tinygrad.helpers import dtypes, prod
from tinygrad.helpers import dtypes, prod, getenv, CI
class TestLinearizer(unittest.TestCase):
def test_arg_dedup(self):
@@ -487,6 +487,36 @@ class TestLinearizerOpts(unittest.TestCase):
# [Opt(OptOps.GROUP, 0, 2)] # doesn't work because group_for_reduce dims become early locals (conflicting with TC)
], apply_tc=True)
def test_padto_matmul(self):
if not isinstance(Device[Device.DEFAULT], Compiled): self.skipTest("Only Compiled uses linearizer")
N = 17 * 17
Tensor.manual_seed(289)
a = Tensor.rand(N, N)
b = Tensor.rand(N, N)
helper_linearizer_opt(a@b, [
[Opt(OptOps.PADTO, 0, 32)],
[Opt(OptOps.PADTO, 1, 32)],
[Opt(OptOps.PADTO, 2, 32)],
[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.PADTO, 1, 32), Opt(OptOps.PADTO, 2, 32)],
# can optimize further post PADTO
[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.PADTO, 1, 32), Opt(OptOps.PADTO, 2, 32), Opt(OptOps.UPCAST, 0, 2), Opt(OptOps.UNROLL, 0, 4)],
])
def test_padto_max(self):
# pad uses invalid value 0, so max is not allowed
if not isinstance(Device[Device.DEFAULT], Compiled): self.skipTest("Only Compiled uses linearizer")
N = 17 * 17
a = -Tensor.ones(N, N)
with self.assertRaises(AssertionError):
helper_linearizer_opt(a.max(), [[Opt(OptOps.PADTO, 0, 32)],])
def test_padto_where(self):
# pad uses invalid value 0, so kernel with max is not allowed
if not isinstance(Device[Device.DEFAULT], Compiled): self.skipTest("Only Compiled uses linearizer")
N = 17 * 17
a = (Tensor.rand(N, N).max(axis=0) > 1).where(1, 0)
with self.assertRaises(AssertionError):
helper_linearizer_opt(a.max(), [[Opt(OptOps.PADTO, 0, 32)],])
if __name__ == '__main__':
unittest.main()

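The two failing cases above (test_padto_max and test_padto_where) guard a real correctness hazard: padded elements read back as 0, which is harmless for sums and matmuls but can change the result of a MAX reduction. A plain numpy illustration of the hazard the assertion prevents (not tinygrad code, just the arithmetic):

import numpy as np

a = -np.ones(289, dtype=np.float32)   # every element is -1, so the true max is -1
padded = np.pad(a, (0, 320 - 289))    # pad 289 up to the next multiple of 32 with zeros
assert a.max() == -1.0
assert padded.max() == 0.0            # the pad value wins and the answer is wrong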

@@ -1,8 +1,9 @@
from __future__ import annotations
import os, math, itertools
from typing import NamedTuple, Optional, List, Tuple, cast, Dict, Union
from tinygrad.lazy import vars_from_ast
from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps, Device, Compiled
from tinygrad.helpers import dedup, dtypes, colored, ImageDType, DType, ansilen, getenv, prod, DEBUG
from tinygrad.helpers import dedup, dtypes, colored, ImageDType, DType, ansilen, getenv, prod, DEBUG, round_up
from tinygrad.shape.shapetracker import ShapeTracker, get_contraction
from tinygrad.shape.symbolic import sint
from tinygrad.shape.view import View, strides_for_shape
@@ -10,7 +11,7 @@ from dataclasses import dataclass
from enum import Enum, auto
class OptOps(Enum):
UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto(); LASTLOCAL = auto(); GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto() # noqa: E702
UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto(); LASTLOCAL = auto(); GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto(); PADTO = auto() # noqa: E702
def __lt__(self, x:OptOps): return self.value < x.value
@dataclass(frozen=True, order=True)
@@ -400,8 +401,8 @@ class Kernel:
axis = -1
if opt.amt is not None:
amt = opt.amt if opt.amt != 0 else self.full_shape[axis]
assert self.full_shape[axis] % amt == 0, "no longer valid shift"
assert isinstance(amt, int) and amt != 1, "shift of amt 1 or Node is meaningless"
assert isinstance(amt, int) and amt != 1, "shift/padto of amt 1 or Node is meaningless"
if opt.op != OptOps.PADTO: assert self.full_shape[axis] % amt == 0, "no longer valid shift"
else:
amt = -1
if opt.op == OptOps.LOCAL: # cyan
@@ -450,6 +451,18 @@
assert self.local_dims == 0 and len(self.group_for_reduce) == 0, "can't have no locals with locals"
assert not self.dont_use_locals, "already not using locals"
self.dont_use_locals = True
elif opt.op == OptOps.PADTO:
assert not vars_from_ast(self.ast), "does not work with symbolic shape"
assert all(op.op is not ReduceOps.MAX for op in self.ast.get_lazyops()), "cannot pad with MAX"
padded = False
for i,st in enumerate(self.sts):
if self.sts[i].shape[axis] != 1:
assert self.sts[i].shape[axis] > amt//2, "pad adds more than double the work"
if (ru := round_up(self.sts[i].shape[axis], amt) - self.sts[i].shape[axis]):
# pad right seems to be faster
self.sts[i] = st.pad(((0,0),) * axis + ((0,ru),) + ((0,0),) * (len(st.shape)-axis-1))
padded = True
assert padded, "nothing was padded"
return self.simplify_ones()
def required_optimizations(self, early_only=False):

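For reference, a minimal standalone sketch of the arithmetic the PADTO branch above applies to each shapetracker axis (helper names here are illustrative, not tinygrad's API): axes of size 1 are left alone, other axes are rounded up to a multiple of amt with zeros padded on the right, and the opt is rejected if padding would more than double the work.

def round_up(n: int, amt: int) -> int:
  return (n + amt - 1) // amt * amt

def padto_amount(dim: int, amt: int = 32) -> int:
  if dim == 1: return 0                        # broadcasted axis, nothing to pad
  assert dim > amt // 2, "pad adds more than double the work"
  return round_up(dim, amt) - dim              # zeros appended on the right

assert padto_amount(367) == 17    # 367 -> 384, as in test_big_prime_number
assert padto_amount(289) == 31    # 17*17 -> 320, as in test_padto_matmul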

@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import List, Tuple, Any, Optional, cast, DefaultDict, Dict, Union, Sequence, Final, Set
import itertools, math, functools
import itertools, math, functools, operator
from collections import defaultdict
from enum import Enum, auto
from dataclasses import dataclass
@@ -86,7 +86,6 @@ class Linearizer(Kernel):
key = f"{acc}{localtype}{this_const if this_const is not None and acc is None else (buf.idx if isinstance(buf, MemBuffer) else cast(LocalBuffer, buf).name)}{idx.render()}{valid.render()}"
if key not in self.load_cache:
if acc is not None:
assert valid.min == 1
self.load_cache[key] = self.uop(UOps.DEFINE_ACC, localtype, (), this_const, cachable=False)
elif this_const is not None:
self.load_cache[key] = self.const(this_const, localtype)
@@ -131,7 +130,6 @@
amt = len(out_tokens)
idx, valid = self.sts[i].expr_idxs(k)
assert idx.render() == ((idx//amt)*amt).render(), "float4 stores are always aligned"
assert valid.min == 1, "stores are always valid"
store_offset_new[k] = self.uop(UOps.CAST, dtypes.float.vec(amt), tuple(out_tokens))
store_offset = store_offset_new
@@ -143,7 +141,8 @@
rendered_idx = self.uop(UOps.CAST, dtypes.int.vec(2), tuple(x.render(self.render_ops, self) for x in idx))
else:
rendered_idx = idx.render(self.render_ops, self)
stores.append(self.uop(UOps.STORE, None, (buf_uop, rendered_idx, var)))
if valid.min == 1: stores.append(self.uop(UOps.STORE, None, (buf_uop, rendered_idx, var)))
else: stores.append(self.uop(UOps.STORE, None, (buf_uop, rendered_idx, var, valid.render(self.render_ops, self))))
return stores
kernel_cnt: Final[DefaultDict[str, int]] = defaultdict(int)
@@ -389,6 +388,21 @@
u.vin = tuple(new if x is old else x for x in u.vin)
self.uops.remove(old)
# fix loop scope, push CONST and ALU upward out of loop if it does not depend on the loop
loop_stack: List[List[UOp]] = [[]]
for u in self.uops:
if not loop_stack[-1]: loop_stack[-1].append(u)
elif u.uop == UOps.LOOP: loop_stack.append([u])
elif u.uop not in [UOps.CONST, UOps.ALU]: loop_stack[-1].append(u)
else:
parents = get_recursive_parents([u])
for i in reversed(range(len(loop_stack))):
# check backwards and put the uop in the first encounter with some dependency
if any(x in parents for x in loop_stack[i]) or i == 0:
loop_stack[i].append(u)
break
self.uops = functools.reduce(operator.__add__, loop_stack, [])
# uops optimization
changed_something = True
while changed_something:

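The loop-scope pass added above can be read as a small standalone transform; here is a self-contained sketch with the UOp-specific checks abstracted into callables (names are illustrative, not tinygrad's API): a hoistable op is appended to the outermost open scope that already contains one of its dependencies, then the scopes are flattened back into one list.

import functools, operator
from typing import Callable, List, Set, TypeVar

T = TypeVar("T")

def hoist(uops: List[T], is_loop: Callable[[T], bool], is_hoistable: Callable[[T], bool],
          parents_of: Callable[[T], Set[T]]) -> List[T]:
  loop_stack: List[List[T]] = [[]]
  for u in uops:
    if not loop_stack[-1]: loop_stack[-1].append(u)
    elif is_loop(u): loop_stack.append([u])
    elif not is_hoistable(u): loop_stack[-1].append(u)
    else:
      parents = parents_of(u)
      # walk scopes from innermost to outermost, stop at the first one holding a dependency
      for i in reversed(range(len(loop_stack))):
        if any(x in parents for x in loop_stack[i]) or i == 0:
          loop_stack[i].append(u)
          break
  return functools.reduce(operator.__add__, loop_stack, [])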

@@ -13,6 +13,7 @@ actions = flatten([[Opt(op=OptOps.UPCAST, axis=axis, amt=amt) for amt in [0,2,3,
actions += flatten([[Opt(op=OptOps.UNROLL, axis=axis, amt=amt) for amt in [0,4]] for axis in range(4)])
actions += flatten([[Opt(op=OptOps.LOCAL, axis=axis, amt=amt) for amt in [2,3,4,8,13,16,29]] for axis in range(5)])
actions += flatten([[Opt(op=OptOps.GROUPTOP, axis=axis, amt=amt) for amt in [13,16,29,32,256]] for axis in range(3)])
actions += flatten([[Opt(op=OptOps.PADTO, axis=axis, amt=amt) for amt in [32]] for axis in range(7)])
actions += [
Opt(op=OptOps.LOCAL, axis=0, amt=32),
Opt(op=OptOps.GROUP, axis=0, amt=4), Opt(op=OptOps.GROUP, axis=0, amt=8), Opt(op=OptOps.GROUP, axis=1, amt=8),

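With PADTO in the BEAM action space above, kernels whose dimensions are prime (and therefore have no clean axis splits) can now be padded by the search. A usage sketch mirroring test_big_prime_number; running it with a beam width set in the environment (e.g. BEAM=2) is an assumption about the usual tinygrad workflow:

import numpy as np
from tinygrad.tensor import Tensor

a, b = Tensor.rand(367, 367), Tensor.rand(367, 367)   # 367 is prime
c = (a @ b).realize()
np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), atol=1e-4, rtol=1e-4)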

@@ -177,7 +177,7 @@ _cache_dir: str = getenv("XDG_CACHE_HOME", os.path.expanduser("~/Library/Caches"
CACHEDB: str = getenv("CACHEDB", os.path.abspath(os.path.join(_cache_dir, "tinygrad", "cache.db")))
CACHELEVEL = getenv("CACHELEVEL", 2)
VERSION = 9
VERSION = 10
_db_connection = None
def db_connection():
global _db_connection


@@ -189,7 +189,9 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tu
r[u] = r[vin[0]]
elif uop == UOps.STORE:
assert vin[0].dtype is not None and vin[2].dtype is not None
if len(vin) > 3: kk(lang.render_if(r[vin[3]]))
kk(lang.render_store(r[vin[0]], vin[0].dtype, r[vin[2]], vin[2].dtype, strip_parens(r[vin[1]]), vin[0].uop == UOps.DEFINE_LOCAL))
if len(vin) > 3: kk("}")
elif uop == UOps.CAST and dtype is not None:
val = lang.render_cast([r[x] for x in vin], dtype)
if child_count[u] <= 1: r[u] = val


@@ -140,7 +140,10 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Dict]:
lvars[backward] = lvars[u]
if uop == UOps.STORE:
element = cast(bb, lvars[vin[2]], vin[2].dtype, vin[0].dtype)
bb[-1].store(element, bb[-1].gep(lvars[vin[0]], [lvars[vin[1]]], inbounds=True))
def store_op(): bb[-1].store(element, bb[-1].gep(lvars[vin[0]], [lvars[vin[1]]], inbounds=True))
if len(vin) > 3:
with bb[-1].if_then(bb[-1].trunc(lvars[vin[3]], ir.IntType(1))): store_op()
else: store_op()
if uop == UOps.ALU:
lvars[u] = cast(bb, code_for_op[args](bb[-1], *[cast(bb, lvars[x], x.dtype, dtypes.float) for x in vin]), dtypes.float, dtype)
if uop == UOps.CAST: lvars[u] = cast(bb, lvars[vin[0]], vin[0].dtype, dtype)


@@ -95,7 +95,7 @@ def uops_to_triton(function_name:str, uops:List[UOp]):
r[u] = r[vin[0]]
elif uop == UOps.STORE:
assert not isinstance(dtype, ImageDType), "unimplemented: image store"
kk(f"tl.store({r[vin[0]]} + {r[vin[1]]}, {r[vin[2]].replace('//', '/')}, mask = {render_valid(valid)}) ")
kk(f"{'if '+r[vin[3]]+': ' if len(vin)>3 else ''}tl.store({r[vin[0]]} + {r[vin[1]]}, {r[vin[2]].replace('//', '/')}, mask = {render_valid(valid)}) ")
elif uop == UOps.DEFINE_GLOBAL:
bufs.append(args)
signatures.append(signature_dtypes[args[1]])