symbolic codegen and exec (#1552)
* symbolic codegen and exec * fix and add test * no sketchy * merge_dicts type * dtypes._arg_int32
pull/1512/head^2
parent
1e1d48b4e6
commit
11dd9b1741
|
@ -188,6 +188,8 @@ jobs:
|
|||
run: DEBUG=2 METAL=1 python -m pytest -n=auto test/test_ops.py
|
||||
- name: Run JIT test
|
||||
run: DEBUG=2 METAL=1 python -m pytest -n=auto test/test_jit.py
|
||||
- name: Run symbolic shapetracker test
|
||||
run: METAL=1 python -m pytest -n=auto test/test_symbolic_shapetracker.py test/test_symbolic_ops.py
|
||||
- name: Check Device.DEFAULT
|
||||
run: WEBGPU=1 python -c "from tinygrad.lazy import Device; assert Device.DEFAULT == 'WEBGPU', Device.DEFAULT"
|
||||
#- name: Run webgpu pytest
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import unittest
|
||||
from tinygrad.helpers import Context, ContextVar
|
||||
from tinygrad.helpers import Context, ContextVar, merge_dicts
|
||||
|
||||
VARIABLE = ContextVar("VARIABLE", 0)
|
||||
|
||||
|
@ -106,5 +106,17 @@ with Context(VARIABLE=1):
|
|||
...
|
||||
assert D.value == 2, f"Expected D to be 2, but was {D.value}. Indicates that Context.__exit__ did not restore to the correct value."
|
||||
|
||||
class TestMergeDicts(unittest.TestCase):
  """Checks for the merge_dicts helper."""

  def test_merge_dicts(self):
    """Compatible dicts merge into their union; a key with two different values raises AssertionError."""
    base = {"a": 1, "b": 2}
    overlapping = {"a": 1, "c": 3}
    empty = {}
    conflicting = {"a": 2, "b": 2}
    union = {"a": 1, "b": 2, "c": 3}

    # "a" appears in both inputs with the same value, so the merge succeeds
    assert merge_dicts([base, overlapping]) == union
    # merging with an empty dict leaves the other dict's contents unchanged
    assert merge_dicts([base, empty]) == base
    assert merge_dicts([base, overlapping, empty]) == union
    # "a" maps to 1 in one input and 2 in the other
    with self.assertRaises(AssertionError):
      merge_dicts([base, conflicting])
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
|
@ -0,0 +1,113 @@
|
|||
import unittest
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
from tinygrad.helpers import getenv, CI
|
||||
from tinygrad.tensor import Tensor, Device
|
||||
import numpy as np
|
||||
|
||||
@unittest.skipIf(getenv("ARM64"), "ARM64 is not supported")
@unittest.skipUnless(Device.DEFAULT in ["GPU", "METAL", "CLANG"], f"{Device.DEFAULT} is not supported")
class TestSymbolicOps(unittest.TestCase):
  """Run ops on tensors reshaped to symbolic shapes and compare with the same ops on the concrete tensors."""

  def test_plus1(self):
    """Elementwise add of a constant with one symbolic dimension."""
    def plus1(t): return (t+1).realize()
    sym_i = Variable("i", 1, 10)
    for n in range(1, 5):
      x = Tensor.rand(3, n)
      got = plus1(x.reshape(3, sym_i)).reshape(3, n).cpu().numpy()
      want = plus1(x).cpu().numpy()
      np.testing.assert_allclose(got, want, atol=1e-6, rtol=1e-6)

  def test_add(self):
    """Elementwise add of two tensors sharing one symbolic dimension."""
    def add(x, y): return (x+y).realize()
    sym_i = Variable("i", 1, 10)
    for n in range(1, 5):
      x = Tensor.rand(3, n)
      y = Tensor.rand(3, n)
      got = add(x.reshape(3, sym_i), y.reshape(3, sym_i)).reshape(3, n).cpu().numpy()
      want = add(x, y).cpu().numpy()
      np.testing.assert_allclose(got, want, atol=1e-6, rtol=1e-6)

  def test_matmul(self):
    """Matmul where the contracted dimension is symbolic."""
    def matmul(x, y): return (x@y).realize()
    sym_i = Variable("i", 1, 10)
    for n in range(1, 5):
      x = Tensor.rand(3, n)
      y = Tensor.rand(n, 5)
      got = matmul(x.reshape(3, sym_i), y.reshape(sym_i, 5)).cpu().numpy()
      want = matmul(x, y).cpu().numpy()
      np.testing.assert_allclose(got, want, atol=1e-6, rtol=1e-6)

  def test_matmul_same_var_different_val(self):
    """The same variable standing for 4 in one operand and 7 in the other is expected to raise AssertionError."""
    def matmul(x, y): return (x@y).realize()
    sym_i = Variable("i", 1, 10)
    x = Tensor.rand(3, 4)
    y = Tensor.rand(7, 5)
    with self.assertRaises(AssertionError):
      matmul(x.reshape(3, sym_i), y.reshape(sym_i, 5)).cpu().numpy()

  @unittest.skipIf(Device.DEFAULT == "CLANG" and CI, "broken on CLANG CI")
  def test_attention(self):
    """scaled_dot_product_attention with a symbolic second dimension on k and v."""
    def attention(q, k, v): return Tensor.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).realize()
    sym_i = Variable("i", 1, 10)
    for n in range(1, 5):
      q = Tensor.rand(2, 1, 4, 8)
      k = Tensor.rand(2, n, 4, 8)
      v = Tensor.rand(2, n, 4, 8)
      got = attention(q, k.reshape(2, sym_i, 4, 8), v.reshape(2, sym_i, 4, 8)).reshape(2, 4, 1, 8).cpu().numpy()
      want = attention(q, k, v).cpu().numpy()
      np.testing.assert_allclose(got, want, atol=1e-6, rtol=1e-6)

  def test_cat_dim0(self):
    """cat along dim 0 of a symbolic-length tensor with a fixed-length one."""
    def cat0(x, y): return x.cat(y, dim=0).realize()
    sym_i = Variable("i", 1, 10)
    for n in range(1, 5):
      x = Tensor.rand(n, 3)
      y = Tensor.rand(2, 3)
      got = cat0(x.reshape(sym_i, 3), y).reshape(n+2, 3).cpu().numpy()
      want = cat0(x, y).cpu().numpy()
      np.testing.assert_allclose(got, want, atol=1e-6, rtol=1e-6)

  def test_cat_dim1(self):
    """cat along dim 1 of a symbolic-length tensor with a fixed-length one."""
    def cat1(x, y): return x.cat(y, dim=1).realize()
    sym_i = Variable("i", 1, 10)
    for n in range(1, 5):
      x = Tensor.rand(3, n)
      y = Tensor.rand(3, 2)
      got = cat1(x.reshape(3, sym_i), y).reshape(3, n+2).cpu().numpy()
      want = cat1(x, y).cpu().numpy()
      np.testing.assert_allclose(got, want, atol=1e-6, rtol=1e-6)

  def test_cat_dim0_two_vars(self):
    """cat along dim 0 where each operand has its own symbolic length."""
    def cat0(x, y): return x.cat(y, dim=0).realize()
    sym_i = Variable("i", 1, 10)
    sym_j = Variable("j", 1, 10)
    for n in range(1, 5):
      for m in range(1, 5):
        x = Tensor.rand(n, 3)
        y = Tensor.rand(m, 3)
        got = cat0(x.reshape(sym_i, 3), y.reshape(sym_j, 3)).reshape(n+m, 3).cpu().numpy()
        want = cat0(x, y).cpu().numpy()
        np.testing.assert_allclose(got, want, atol=1e-6, rtol=1e-6)

  def test_cat_dim1_two_vars(self):
    """cat along dim 1 where each operand has its own symbolic length."""
    def cat1(x, y): return x.cat(y, dim=1).realize()
    sym_i = Variable("i", 1, 10)
    sym_j = Variable("j", 1, 10)
    for n in range(1, 5):
      for m in range(1, 5):
        x = Tensor.rand(3, n)
        y = Tensor.rand(3, m)
        got = cat1(x.reshape(3, sym_i), y.reshape(3, sym_j)).reshape(3, n+m).cpu().numpy()
        want = cat1(x, y).cpu().numpy()
        np.testing.assert_allclose(got, want, atol=1e-6, rtol=1e-6)

  def test_two_vars_plus1(self):
    """Matmul plus a constant where both output dimensions are symbolic."""
    def matmul_plus1(x, y): return (x@y+1).realize()
    sym_i = Variable("i", 1, 10)
    sym_j = Variable("j", 1, 10)
    for n in range(1, 5):
      for m in range(1, 5):
        x = Tensor.rand(n, 3)
        y = Tensor.rand(3, m)
        got = matmul_plus1(x.reshape(sym_i, 3), y.reshape(3, sym_j)).reshape(n, m).cpu().numpy()
        want = matmul_plus1(x, y).cpu().numpy()
        np.testing.assert_allclose(got, want, atol=1e-6, rtol=1e-6)
|
|
@ -26,17 +26,18 @@ class TestSymbolic(unittest.TestCase):
|
|||
i = Variable("i", 1, 5)
|
||||
j = Variable("j", 1, 5)
|
||||
k = Variable("k", 1, 5)
|
||||
t1 = Tensor.rand(3, 4).reshape(i, 4).cat(Tensor.rand(3, 4).reshape(j, 4), dim=0).cat(Tensor.rand(3, 4).reshape(k, 4), dim=0)
|
||||
st = t1.lazydata.st
|
||||
t = Tensor.rand(3, 4).reshape(i, 4).cat(Tensor.rand(3, 4).reshape(j, 4), dim=0).cat(Tensor.rand(3, 4).reshape(k, 4), dim=0)
|
||||
st = t.lazydata.st
|
||||
assert st.shape == (i+j+k, 4)
|
||||
assert st.real_strides() == (4, 1)
|
||||
i = Variable("i", 1, 5)
|
||||
j = Variable("j", 1, 5)
|
||||
k = Variable("k", 1, 5)
|
||||
t1 = Tensor.rand(3, 4).reshape(3, i).cat(Tensor.rand(3, 4).reshape(3, j), dim=1).cat(Tensor.rand(3, 4).reshape(3, k), dim=1)
|
||||
st = t1.lazydata.st
|
||||
t = Tensor.rand(3, 4).reshape(3, i).cat(Tensor.rand(3, 4).reshape(3, j), dim=1).cat(Tensor.rand(3, 4).reshape(3, k), dim=1)
|
||||
st = t.lazydata.st
|
||||
assert st.shape == (3, i+j+k)
|
||||
assert st.real_strides() == (i+j+k, 1)
|
||||
t = Tensor.rand(i, 3).reshape(i, 3).cat(Tensor.rand(3, 3).reshape(i, 3), dim=0).cat(Tensor.rand(3, 3), dim=0)
|
||||
st = t.lazydata.st
|
||||
assert st.shape == (2*i+3, 3)
|
||||
assert st.real_strides() == (3, 1)
|
||||
|
||||
class TestSymbolicReshape(unittest.TestCase):
|
||||
def test_reshape_into_symbols_simple(self):
|
||||
|
|
|
@ -283,6 +283,7 @@ class TestSymbolicSymbolicOps(unittest.TestCase):
|
|||
assert NumNode(0) // (Variable("i", 1, 10)*128) == 0
|
||||
assert NumNode(127) // (Variable("i", 1, 10)*128) == 0
|
||||
assert idx0 // (i*3) == 0
|
||||
assert i // i == 1
|
||||
|
||||
def test_node_mod_node(self):
|
||||
i = Variable("i", 1, 10)
|
||||
|
|
|
@ -9,7 +9,7 @@ from tinygrad.lazy import LazyBuffer
|
|||
from tinygrad.ops import MovementOps, ReduceOps, BinaryOps, TernaryOps
|
||||
from tinygrad.runtime.lib import RawConst, buf_is_kernel_arg
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, strides_for_shape, View
|
||||
from tinygrad.shape.symbolic import Variable, NumNode, Node, SumNode, MulNode
|
||||
from tinygrad.shape.symbolic import Variable, NumNode, Node, SumNode, MulNode, sym_rename
|
||||
VariableOrNum = Union[Variable, NumNode, Node]
|
||||
|
||||
# bottom ones are asm only
|
||||
|
@ -301,6 +301,9 @@ class Linearizer:
|
|||
# add global buffers
|
||||
for buf,name in self.arg_bufs.items():
|
||||
self.uop(UOps.DEFINE_GLOBAL, None, [], (name, buf.dtype))
|
||||
# add variables from symbolic shapes
|
||||
for var in sorted(set(v for buf in self.ast.buffers for v in buf.st.var_vals), key=lambda k: k.key):
|
||||
self.uop(UOps.DEFINE_GLOBAL, None, [], (var.expr, dtypes._arg_int32))
|
||||
|
||||
# add a local buffer for multistage reduce
|
||||
if len(self.group_for_reduce):
|
||||
|
@ -317,7 +320,7 @@ class Linearizer:
|
|||
if DEBUG >= 3: self.printbufs()
|
||||
|
||||
# kernel name (before late upcast)
|
||||
self.function_name = ("r_" if self.reduceop else "E_") + '_'.join([str(x) for x in self.full_shape])
|
||||
self.function_name = ("r_" if self.reduceop else "E_") + '_'.join([str(x) if isinstance(x, int) else sym_rename(x) for x in self.full_shape])
|
||||
self.display_name = ("r_" if self.reduceop else "E_") + colored('_', 'BLACK').join([colored(str(x), c) for x,c in zip(self.full_shape, self.colors())])
|
||||
|
||||
# parse AST
|
||||
|
@ -548,7 +551,7 @@ class Linearizer:
|
|||
assert len(colors) == self.shape_len, "colors size mismatch"
|
||||
return colors
|
||||
|
||||
def colored_shape(self) -> str: return ' '.join(colored(f"{s:4d}", color) for s,color in zip(self.full_shape, self.colors()))
|
||||
def colored_shape(self) -> str: return ' '.join(colored(s, color) for s,color in zip([f"{s:4d}" if isinstance(s, int) else s for s in self.full_shape], self.colors()))
|
||||
def printbufs(self, prefix=""):
|
||||
for i in range(len(self.sts)):
|
||||
print(prefix, f"{i:3d} {str(self.bufs[i].realized) if self.bufs[i].realized is not None else str(self.bufs[i]):47s}", self.sts[i].views)
|
||||
|
|
|
@ -162,7 +162,7 @@ def hand_coded_optimizations(k:Linearizer):
|
|||
# early exit
|
||||
return
|
||||
|
||||
if k.opts.has_local:
|
||||
if k.opts.has_local and all(isinstance(s, int) for s in k.sts[0].shape[:k.first_reduce]):
|
||||
# are we grouping? (requires local shape support)
|
||||
if not k.float4_axis(0) and k.first_reduce <= 2 and k.first_reduce + 1 <= k.shape_len and prod(k.sts[0].shape[:k.first_reduce]) <= 2048:
|
||||
# TODO: use 1024 if it's allowed in a smarter way
|
||||
|
@ -204,8 +204,8 @@ def hand_coded_optimizations(k:Linearizer):
|
|||
while prod(k.sts[0].shape[:k.first_reduce]) >= 1024:
|
||||
xb_choices = []
|
||||
for axis, upcast_amount in itertools.product(range(k.first_reduce), [3,4]): # consider all the non reduce axes, and a 3 or 4 reduce
|
||||
# if we haven't upcasted it, it mods, and some buffer has stride 0 on axis while having no stride 0 in the upcasted axis already
|
||||
if axis not in upcasted_axis and k.full_shape[axis]%upcast_amount == 0 and any(k.sts[buf_index].views[-1].strides[axis] == 0 and not any(x[1] == 0 for x in k.upcasted_axis(buf_index)) for buf_index in range(len(k.sts))):
|
||||
# if we haven't upcasted it, it's not symbolic, it mods, and some buffer has stride 0 on axis while having no stride 0 in the upcasted axis already
|
||||
if axis not in upcasted_axis and isinstance(k.full_shape[axis], int) and k.full_shape[axis]%upcast_amount == 0 and any(k.sts[buf_index].views[-1].strides[axis] == 0 and not any(x[1] == 0 for x in k.upcasted_axis(buf_index)) for buf_index in range(len(k.sts))):
|
||||
xb_choices.append((sum(st.views[-1].strides[axis]>0 for st in k.sts), sum(st.views[-1].strides[axis] for st in k.sts), axis, upcast_amount))
|
||||
if len(xb_choices):
|
||||
xb_choices = sorted(xb_choices)
|
||||
|
@ -219,7 +219,7 @@ def hand_coded_optimizations(k:Linearizer):
|
|||
|
||||
# if last dim is small(ish) and it's a reduce dim, upcast the reduce (loop unrolling). no simplify needed since it's just an upcast. NOTE: careful, this has broken VALIDHACKS
|
||||
if k.first_reduce < (k.shape_len-k.upcasted) and (len(list(k.shape_offsets(k.full_buf_index))) <= 4 or not any(r for _,_,r in k.upcasted_axis(k.full_buf_index))):
|
||||
if (s:=k.full_unupcasted_shape[-1]) <= 32:
|
||||
if (s:=k.full_unupcasted_shape[-1]) <= 32 and isinstance(s, int): # NOTE: cannot loop unroll symbolic axis
|
||||
k.upcast()
|
||||
# if it's small, upcast a second reduce dimension too
|
||||
if k.first_reduce < (k.shape_len-k.upcasted) and s <= 3 and k.full_unupcasted_shape[-1] <= 3: k.upcast()
|
||||
|
|
|
@ -3,7 +3,7 @@ import os, functools, platform, time, re, contextlib
|
|||
from weakref import KeyedRef, ref
|
||||
from _weakref import _remove_dead_weakref # type: ignore
|
||||
import numpy as np
|
||||
from typing import Dict, Tuple, Union, List, NamedTuple, Final, Iterator, ClassVar, Optional, Callable, Any
|
||||
from typing import Dict, Tuple, Union, List, NamedTuple, Final, Iterator, ClassVar, Optional, Callable, Any, Iterable
|
||||
from math import prod # noqa: F401 # pylint:disable=unused-import
|
||||
|
||||
ShapeType = Tuple[int, ...]
|
||||
|
@ -22,6 +22,10 @@ def make_pair(x:Union[int, Tuple[int, ...]], cnt=2) -> Tuple[int, ...]: return (
|
|||
def flatten(l:Iterator): return [item for sublist in l for item in sublist]
|
||||
def mnum(i) -> str: return str(i) if i >= 0 else f"m{-i}"
|
||||
def fromimport(mod, frm): return getattr(__import__(mod, fromlist=[frm]), frm)
|
||||
def merge_dicts(ds:Iterable[Dict]) -> Dict:
  """Merge an iterable of dicts into one new dict.

  A key may appear in several inputs as long as every occurrence maps to an
  equal value. Keys keep the order in which they are first seen, so the result
  is deterministic, and values do not need to be hashable.

  Args:
    ds: the dicts to merge; may be empty.

  Returns:
    A new dict containing every key/value pair from the inputs.

  Raises:
    AssertionError: if the same key maps to two different values.
  """
  merged: Dict = {}
  for d in ds:
    for k,v in d.items():
      if k not in merged:
        merged[k] = v
      # identity is checked first so a value that is not equal to itself (e.g. NaN) still merges with itself
      elif merged[k] is not v and merged[k] != v:
        # raised explicitly rather than via `assert` so the check is not stripped under -O
        raise AssertionError(f"cannot merge, key {k!r} has different values {merged[k]!r} and {v!r}")
  return merged
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def getenv(key, default=0): return type(default)(os.getenv(key, default))
|
||||
|
@ -115,6 +119,7 @@ class dtypes:
|
|||
_half4: Final[DType] = DType(0, 2*4, "half4", None, 4)
|
||||
_float2: Final[DType] = DType(4, 4*2, "float2", None, 2)
|
||||
_float4: Final[DType] = DType(4, 4*4, "float4", None, 4)
|
||||
_arg_int32: Final[DType] = DType(2, 4, "_arg_int32", None)
|
||||
|
||||
# HACK: staticmethods are not callable in 3.8 so we have to compare the class
|
||||
DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not k.startswith('__') and not callable(v) and not v.__class__ == staticmethod}
|
||||
|
|
|
@ -2,8 +2,9 @@ from __future__ import annotations
|
|||
import functools, time
|
||||
from enum import Enum, auto
|
||||
from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, cast
|
||||
from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored, dedup
|
||||
from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored, dedup, merge_dicts
|
||||
from tinygrad.shape.shapetracker import MovementOps
|
||||
from tinygrad.shape.symbolic import Variable, sym_infer
|
||||
from tinygrad.runtime.lib import RawBuffer, RawConst, buf_is_kernel_arg
|
||||
if TYPE_CHECKING:
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
|
@ -131,20 +132,24 @@ class ASTRunner:
|
|||
self.clprg = runtime(self.name, self.prg, **self.runtime_args)
|
||||
return self
|
||||
|
||||
def exec(self, bufs, force_wait=False, optimizing=False) -> Optional[float]:
|
||||
def exec(self, bufs, var_vals:Optional[Dict[Variable, int]]=None, force_wait=False, optimizing=False) -> Optional[float]:
|
||||
rawbufs = dedup([x.realized for x in bufs if buf_is_kernel_arg(x)])
|
||||
if GlobalCounters.cache is not None and not optimizing: GlobalCounters.cache.append((self, rawbufs))
|
||||
return self(rawbufs, force_wait=force_wait)
|
||||
return self(rawbufs, var_vals, force_wait=force_wait)
|
||||
|
||||
def __call__(self, rawbufs:List[RawBuffer], jit=False, force_wait=False) -> Optional[float]:
|
||||
if et := self.clprg((self.global_size + [1]*(3-len(self.global_size))) if self.global_size is not None else None,
|
||||
(self.local_size + [1]*(3-len(self.local_size))) if self.local_size is not None else None,
|
||||
*rawbufs, wait=force_wait or DEBUG>=1): GlobalCounters.time_sum_s += et
|
||||
def __call__(self, rawbufs:List[RawBuffer], var_vals:Optional[Dict[Variable, int]]=None, jit=False, force_wait=False) -> Optional[float]:
|
||||
if var_vals is None: var_vals = {}
|
||||
global_size = [sym_infer(sz, var_vals) for sz in self.global_size] if self.global_size is not None else self.global_size
|
||||
local_size = [sym_infer(sz, var_vals) for sz in self.local_size] if self.local_size is not None else self.local_size
|
||||
if et := self.clprg((global_size + [1]*(3-len(global_size))) if global_size is not None else None,
|
||||
(local_size + [1]*(3-len(local_size))) if local_size is not None else None,
|
||||
*rawbufs, *var_vals.values(), wait=force_wait or DEBUG>=1): GlobalCounters.time_sum_s += et
|
||||
op_estimate = sym_infer(self.op_estimate, var_vals)
|
||||
if DEBUG >= 2:
|
||||
print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'magenta' if jit else None)} {(self.display_name+' '*(33-ansilen(self.display_name))) if self.display_name is not None else self.name:33s} arg {len(rawbufs):3d} sz {str(self.global_size):18s} {str(self.local_size):12s} OPs {int(self.op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
|
||||
(str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({self.op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {self.mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
|
||||
print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'magenta' if jit else None)} {(self.display_name+' '*(33-ansilen(self.display_name))) if self.display_name is not None else self.name:33s} arg {len(rawbufs):3d} sz {str(global_size):18s} {str(local_size):12s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
|
||||
(str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {self.mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
|
||||
GlobalCounters.kernel_count += 1
|
||||
GlobalCounters.global_ops += self.op_estimate
|
||||
GlobalCounters.global_ops += op_estimate
|
||||
GlobalCounters.global_mem += self.mem_estimate
|
||||
if getenv("EARLY_STOPPING") and GlobalCounters.kernel_count == getenv("EARLY_STOPPING"): exit(0)
|
||||
return et
|
||||
|
@ -178,9 +183,11 @@ class Compiled:
|
|||
output.realized = None
|
||||
break
|
||||
|
||||
# we don't have an output buffer, we have to create it
|
||||
# we don't have an output buffer, we have to create it, and create to max size if it has symbolic shape
|
||||
if not output.realized:
|
||||
output.realized = self.buffer(prod(output.shape), output.dtype, **kwargs)
|
||||
output.realized = self.buffer(prod((s if isinstance(s, int) else s.max for s in output.shape)), output.dtype, **kwargs)
|
||||
# update the output var_vals from src
|
||||
output.st.var_vals = dict(sorted(merge_dicts([buf.st.var_vals for buf in ast.buffers]).items(), key=lambda kv:cast(Variable,kv[0]).key))
|
||||
|
||||
from tinygrad.codegen.linearizer import Linearizer
|
||||
k = Linearizer(ast, output, self.linearizer_opts)
|
||||
|
@ -200,5 +207,5 @@ class Compiled:
|
|||
|
||||
if prg.name == getenv("PRINT_PRG", ''): print(prg.prg)
|
||||
|
||||
prg.exec(k.bufs)
|
||||
prg.exec(k.bufs, var_vals=output.st.var_vals)
|
||||
return output.realized
|
||||
|
|
|
@ -3,7 +3,7 @@ import math
|
|||
from tinygrad.codegen.linearizer import UOps, UOp, MemOp, ConstOp
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
|
||||
from tinygrad.helpers import ImageDType, dtypes, getenv, prod, DType
|
||||
from tinygrad.shape.symbolic import DivNode, AndNode, render_python, NumNode, Variable
|
||||
from tinygrad.shape.symbolic import DivNode, AndNode, render_python, NumNode, Variable, sym_render
|
||||
|
||||
# div is different in cl than python
|
||||
render_cl = render_python.copy()
|
||||
|
@ -17,6 +17,7 @@ class CStyleLanguage(NamedTuple):
|
|||
buffer_prefix: str = ""
|
||||
buffer_suffix: str = ""
|
||||
smem_prefix: str = ""
|
||||
arg_int_prefix: str = ""
|
||||
barrier: str = ""
|
||||
gid: List[str] = []
|
||||
lid: List[str] = []
|
||||
|
@ -69,7 +70,7 @@ class CStyleLanguage(NamedTuple):
|
|||
def render_local(self, name:str, size:int):
|
||||
return self.smem_prefix + f"float {name}[{size}];"
|
||||
|
||||
def render_for(self, expr: str, _min:int, _max:int) -> str:
|
||||
def render_for(self, expr: str, _min:int, _max:Union[int,str]) -> str:
|
||||
return f"for (int {expr} = {_min}; {expr} <= {_max}; ++{expr}) {{"
|
||||
|
||||
def render_conditional(self, cond: str, x:str, y:str) -> str:
|
||||
|
@ -78,6 +79,7 @@ class CStyleLanguage(NamedTuple):
|
|||
def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,DType]], global_size:List[int], local_size:List[int], prekernel:List[str]) -> Tuple[str,List[int],List[int]]:
|
||||
tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,dtype in bufs) else ""
|
||||
buftypes = [(name,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if dtype.name.startswith('image') else
|
||||
self.arg_int_prefix if dtype == dtypes._arg_int32 else
|
||||
("const " if i > 0 else "")+self.buffer_prefix+dtype.name+"*"+self.buffer_suffix) for i,(name,dtype) in enumerate(bufs)]
|
||||
prg = ''.join([f"{self.kernel_prefix} void {function_name}(",] +
|
||||
[', '.join([f'{t} {name}' for name,t in buftypes] + self.extra_args)] +
|
||||
|
@ -128,7 +130,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> T
|
|||
kk(add_gl_dimension(lang.size_prefix, args, i, var, local_size, lang.lid))
|
||||
else:
|
||||
if getenv("NOUNROLL") and not isinstance(var, NumNode): kk("#pragma unroll(1)") # prevent loop unrolling
|
||||
kk("{" if isinstance(var, NumNode) else lang.render_for(var.expr, var.min, var.max))
|
||||
kk("{" if isinstance(var, NumNode) else lang.render_for(var.expr, var.min, sym_render(var.max)))
|
||||
depth += 1
|
||||
elif uop == UOps.BARRIER:
|
||||
kk(lang.barrier)
|
||||
|
|
|
@ -38,7 +38,7 @@ class WGSLLanguage(CStyleLanguage):
|
|||
prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3<u32>, @builtin(local_invocation_id) lindex: vec3<u32>) {{\n" + "\n".join(kernel) + "\n}"
|
||||
return prg, global_size[::-1] if len(global_size) else [1], local_size
|
||||
|
||||
def render_for(self, expr:str, _min:int, _max:int) -> str:
|
||||
def render_for(self, expr:str, _min:int, _max:Union[int,str]) -> str:
|
||||
return f"for(var {expr} = {_min}; {expr} <= {_max}; {expr}++) {{"
|
||||
|
||||
def render_conditional(self, cond:str, x:str, y:str) -> str:
|
||||
|
|
|
@ -74,8 +74,8 @@ class ClangProgram:
|
|||
mu.emu_start(ADDRESS, ADDRESS + len(self.prg))
|
||||
args[0]._buf = mu.mem_read(mu.reg_read(arm64_const.UC_ARM64_REG_X0), args[0].size * args[0].dtype.itemsize)
|
||||
else:
|
||||
self.fxn(*[x._buf for x in args])
|
||||
self.fxn(*[x._buf if isinstance(x, RawMallocBuffer) else x for x in args])
|
||||
if wait: return time.monotonic()-st
|
||||
|
||||
renderer = fromimport("tinygrad.codegen.assembly_arm64", "uops_to_arm64_asm") if ARM64 else functools.partial(uops_to_cstyle, CStyleLanguage(kernel_prefix=args['exp'], buffer_suffix=" restrict"))
|
||||
renderer = fromimport("tinygrad.codegen.assembly_arm64", "uops_to_arm64_asm") if ARM64 else functools.partial(uops_to_cstyle, CStyleLanguage(kernel_prefix=args['exp'], buffer_suffix=" restrict", arg_int_prefix="const int"))
|
||||
ClangBuffer = Compiled(RawMallocBuffer, LinearizerOptions(supports_float4=False, has_local=False), renderer, ClangProgram)
|
||||
|
|
|
@ -80,7 +80,7 @@ class CLProgram:
|
|||
def max_work_group_size(): return CL.cl_ctxs[0].devices[0].max_work_group_size
|
||||
|
||||
def __call__(self, global_size, local_size, *bufs, wait=False) -> Optional[float]:
|
||||
cl_bufs = [x._buf if isinstance(x, CLBuffer) else x for x in bufs]
|
||||
cl_bufs = [x._buf if isinstance(x, CLBuffer) else np.int32(x) if isinstance(x, int) else x for x in bufs]
|
||||
e = self.clprgs[cl_bufs[0].device](CL.cl_queue[cl_bufs[0].device], [g*l for g,l in zip(global_size, local_size)] if local_size is not None else global_size, local_size, *cl_bufs, wait_for=[x.event for x in bufs if isinstance(x, CLBuffer) and hasattr(x, "event")])
|
||||
if wait:
|
||||
e.wait()
|
||||
|
@ -91,7 +91,7 @@ class CLProgram:
|
|||
return None
|
||||
|
||||
renderer = functools.partial(uops_to_cstyle, CStyleLanguage(
|
||||
kernel_prefix = "__kernel", buffer_prefix = "__global ", smem_prefix = "__local ",
|
||||
kernel_prefix = "__kernel", buffer_prefix = "__global ", smem_prefix = "__local ", arg_int_prefix = "const int",
|
||||
half_prekernel = "#pragma OPENCL EXTENSION cl_khr_fp16 : enable",
|
||||
barrier = "barrier(CLK_LOCAL_MEM_FENCE);", float4 = "(float4)",
|
||||
gid = [f'get_group_id({i})' for i in range(3)], lid = [f'get_local_id({i})' for i in range(3)], uses_vload=True))
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
# pip3 install pyobjc-framework-Metal pyobjc-framework-Cocoa pyobjc-framework-libdispatch
|
||||
import os, subprocess, pathlib, functools
|
||||
import os, subprocess, pathlib, functools, ctypes
|
||||
import Metal, Cocoa, libdispatch # type: ignore
|
||||
from typing import List, Any
|
||||
from tinygrad.codegen.linearizer import LinearizerOptions
|
||||
|
@ -65,7 +65,10 @@ class MetalProgram:
|
|||
command_buffer = METAL.mtl_queue.commandBuffer()
|
||||
encoder = command_buffer.computeCommandEncoder()
|
||||
encoder.setComputePipelineState_(self.pipeline_state)
|
||||
for i,a in enumerate(bufs): encoder.setBuffer_offset_atIndex_(a._buf, 0, i)
|
||||
for i,a in enumerate(bufs):
|
||||
if isinstance(a, RawMetalBuffer): encoder.setBuffer_offset_atIndex_(a._buf, 0, i)
|
||||
elif isinstance(a, int): encoder.setBytes_length_atIndex_((arg:=ctypes.c_int32(a)), ctypes.sizeof(arg), i)
|
||||
else: raise RuntimeError(f"arg at index {i} has unsupported type {type(a)}")
|
||||
encoder.dispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
|
||||
encoder.endEncoding()
|
||||
command_buffer.commit()
|
||||
|
@ -75,7 +78,7 @@ class MetalProgram:
|
|||
METAL.mtl_buffers_in_flight.append(command_buffer)
|
||||
|
||||
renderer = functools.partial(uops_to_cstyle, CStyleLanguage(
|
||||
kernel_prefix = "#include <metal_stdlib>\nusing namespace metal;\nkernel", buffer_prefix = "device ", smem_prefix = "threadgroup ",
|
||||
kernel_prefix = "#include <metal_stdlib>\nusing namespace metal;\nkernel", buffer_prefix = "device ", smem_prefix = "threadgroup ", arg_int_prefix = "constant int&",
|
||||
barrier = "threadgroup_barrier(mem_flags::mem_threadgroup);", float4 = "float4", uses_ptr_arithmetic=True,
|
||||
gid = [f"gid.{chr(120+i)}" for i in range(3)], lid = [f"lid.{chr(120+i)}" for i in range(3)],
|
||||
extra_args = ['uint3 gid [[threadgroup_position_in_grid]]', 'uint3 lid [[thread_position_in_threadgroup]]']))
|
||||
|
|
|
@ -59,7 +59,7 @@ class View(ViewInternal):
|
|||
# generate an expression if you have a single idx variable
|
||||
def expr_node(self, idx=None) -> Node:
|
||||
if idx is None: idx = Variable('idx', 0, prod(self.shape)-1)
|
||||
ret: List[Node] = [Variable.num(self.offset)] if self.offset else []
|
||||
ret: List[Node] = [Variable.num(self.offset) if isinstance(self.offset, int) else self.offset] if self.offset else []
|
||||
acc = 1
|
||||
for d,s in reversed(self.shape_strides):
|
||||
ret.append(((idx//acc)%d)*s)
|
||||
|
@ -69,7 +69,7 @@ class View(ViewInternal):
|
|||
# generate an expression if you have a variable or expression for each index
|
||||
def expr_idxs(self, idxs) -> Node:
|
||||
assert len(idxs) == len(self.shape), f"need an idx for all dimensions {idxs} vs {self.shape}"
|
||||
return Variable.sum([Variable.num(self.offset)] + [idx*st for idx,sh,st in zip(idxs, self.shape, self.strides) if sh != 1 and st != 0])
|
||||
return Variable.sum([Variable.num(self.offset) if isinstance(self.offset, int) else self.offset] + [idx*st for idx,sh,st in zip(idxs, self.shape, self.strides) if sh != 1 and st != 0])
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def idxs_to_idx(shape:Tuple[int, ...], idxs) -> Node:
|
||||
|
@ -162,7 +162,7 @@ class ShapeTracker:
|
|||
idx, valid = self.expr_idxs(idxs)
|
||||
ret: List[Optional[Union[Node, int]]] = [None] * len(self.views[-1].shape)
|
||||
for this_dim in (idx.nodes if isinstance(idx, SumNode) else [idx]):
|
||||
if isinstance(this_dim, MulNode) and isinstance(this_dim.a, Variable):
|
||||
if isinstance(this_dim, MulNode) and isinstance(this_dim.a, Variable) and this_dim.a in idxs:
|
||||
ret[idxs.index(this_dim.a)] = this_dim.b
|
||||
elif isinstance(this_dim, Variable):
|
||||
ret[idxs.index(this_dim)] = 1
|
||||
|
|
|
@ -65,6 +65,7 @@ class Node:
|
|||
def __rfloordiv__(self, b:int): raise RuntimeError(f"not supported: {b} // {self}")
|
||||
def __floordiv__(self, b:Union[Node,int], factoring_allowed=True):
|
||||
if isinstance(b, Node):
|
||||
if self == b: return NumNode(1)
|
||||
if (b > self).min > 0 and self.min >= 0: return NumNode(0)
|
||||
raise RuntimeError(f"not supported: {self} // {b}")
|
||||
assert b != 0
|
||||
|
@ -262,6 +263,14 @@ def create_rednode(typ:Type[RedNode], nodes:List[Node]):
|
|||
elif typ == AndNode: ret.min, ret.max = (min([x.min for x in nodes]), max([x.max for x in nodes]))
|
||||
return create_node(ret)
|
||||
|
||||
def sym_infer(n:Union[Node,int], var_vals: Dict[Variable, int]) -> int:
  """Evaluate a symbolic expression to a concrete int.

  Args:
    n: a plain int or a symbolic node (NumNode, Variable, MulNode or SumNode).
    var_vals: the concrete value bound to each Variable.

  Returns:
    The value of `n` with every Variable replaced by its entry in `var_vals`.

  Raises:
    KeyError: if a Variable in `n` has no entry in `var_vals`.
    NotImplementedError: for any other node type.
  """
  if isinstance(n, int) or isinstance(n, NumNode): return int(n)
  if isinstance(n, Variable): return var_vals[n]
  if isinstance(n, MulNode):
    lhs, rhs = sym_infer(n.a, var_vals), sym_infer(n.b, var_vals)
    return lhs * rhs
  if isinstance(n, SumNode):
    total = 0
    for term in n.nodes: total += sym_infer(term, var_vals)
    return total
  raise NotImplementedError(n)
|
||||
@functools.lru_cache(maxsize=None)
def sym_rename(s) -> str:
  """Return a short placeholder name ("s0", "s1", ...) for `s`.

  The number is how many entries this function's own lru_cache already holds
  when the call misses, so each distinct argument gets the next index and a
  repeated argument gets its earlier name back from the cache.
  """
  already_named = sym_rename.cache_info().currsize
  return f"s{already_named}"
|
||||
def sym_render(a: Union[Node, int], ops=None, ctx=None) -> str:
  """Render a shape value as text.

  A plain int is converted with str(); anything else is rendered through its
  own `render(ops, ctx)` method.
  """
  if isinstance(a, int): return str(a)
  return a.render(ops, ctx)
|
||||
|
||||
render_python: Dict[Type, Callable] = {
|
||||
|
|
Loading…
Reference in New Issue