
symbolic codegen and exec (#1552)

* symbolic codegen and exec

* fix and add test

* no sketchy

* merge_dicts type

* dtypes._arg_int32
chenyu 2023-08-16 14:43:41 -07:00 committed by GitHub
parent 1e1d48b4e6
commit 11dd9b1741
16 changed files with 201 additions and 43 deletions

View File

@ -188,6 +188,8 @@ jobs:
run: DEBUG=2 METAL=1 python -m pytest -n=auto test/test_ops.py
- name: Run JIT test
run: DEBUG=2 METAL=1 python -m pytest -n=auto test/test_jit.py
- name: Run symbolic shapetracker test
run: METAL=1 python -m pytest -n=auto test/test_symbolic_shapetracker.py test/test_symbolic_ops.py
- name: Check Device.DEFAULT
run: WEBGPU=1 python -c "from tinygrad.lazy import Device; assert Device.DEFAULT == 'WEBGPU', Device.DEFAULT"
#- name: Run webgpu pytest

View File

@ -1,5 +1,5 @@
import unittest
from tinygrad.helpers import Context, ContextVar
from tinygrad.helpers import Context, ContextVar, merge_dicts
VARIABLE = ContextVar("VARIABLE", 0)
@ -106,5 +106,17 @@ with Context(VARIABLE=1):
...
assert D.value == 2, f"Expected D to be 2, but was {D.value}. Indicates that Context.__exit__ did not restore to the correct value."
class TestMergeDicts(unittest.TestCase):
def test_merge_dicts(self):
a = {"a": 1, "b": 2}
b = {"a": 1, "c": 3}
c = {}
d = {"a": 2, "b": 2}
assert merge_dicts([a, b]) == {"a": 1, "b": 2, "c": 3}
assert merge_dicts([a, c]) == a
assert merge_dicts([a, b, c]) == {"a": 1, "b": 2, "c": 3}
with self.assertRaises(AssertionError):
merge_dicts([a, d])
if __name__ == '__main__':
unittest.main()

View File

@ -0,0 +1,113 @@
import unittest
from tinygrad.shape.symbolic import Variable
from tinygrad.helpers import getenv, CI
from tinygrad.tensor import Tensor, Device
import numpy as np
@unittest.skipIf(getenv("ARM64"), "ARM64 is not supported")
@unittest.skipUnless(Device.DEFAULT in ["GPU", "METAL", "CLANG"], f"{Device.DEFAULT} is not supported")
class TestSymbolicOps(unittest.TestCase):
def test_plus1(self):
def f(a): return (a+1).realize()
vi = Variable("i", 1, 10)
for i in range(1, 5):
a = Tensor.rand(3, i)
symbolic = f(a.reshape(3, vi)).reshape(3, i).cpu().numpy()
expected = f(a).cpu().numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
def test_add(self):
def f(a, b): return (a+b).realize()
vi = Variable("i", 1, 10)
for i in range(1, 5):
a = Tensor.rand(3, i)
b = Tensor.rand(3, i)
symbolic = f(a.reshape(3, vi), b.reshape(3, vi)).reshape(3, i).cpu().numpy()
expected = f(a, b).cpu().numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
def test_matmul(self):
def f(a, b): return (a@b).realize()
vi = Variable("i", 1, 10)
for i in range(1, 5):
a = Tensor.rand(3, i)
b = Tensor.rand(i, 5)
symbolic = f(a.reshape(3, vi), b.reshape(vi, 5)).cpu().numpy()
expected = f(a, b).cpu().numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
def test_matmul_same_var_different_val(self):
def f(a, b): return (a@b).realize()
vi = Variable("i", 1, 10)
a = Tensor.rand(3, 4)
b = Tensor.rand(7, 5)
with self.assertRaises(AssertionError):
f(a.reshape(3, vi), b.reshape(vi, 5)).cpu().numpy()
@unittest.skipIf(Device.DEFAULT == "CLANG" and CI, "broken on CLANG CI")
def test_attention(self):
def f(q, k, v): return Tensor.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).realize()
vi = Variable("i", 1, 10)
for i in range(1, 5):
q = Tensor.rand(2, 1, 4, 8)
k = Tensor.rand(2, i, 4, 8)
v = Tensor.rand(2, i, 4, 8)
symbolic = f(q, k.reshape(2, vi, 4, 8), v.reshape(2, vi, 4, 8)).reshape(2, 4, 1, 8).cpu().numpy()
expected = f(q, k, v).cpu().numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
def test_cat_dim0(self):
def f(a, b): return a.cat(b, dim=0).realize()
vi = Variable("i", 1, 10)
for i in range(1, 5):
a = Tensor.rand(i, 3)
b = Tensor.rand(2, 3)
symbolic = f(a.reshape(vi, 3), b).reshape(i+2, 3).cpu().numpy()
expected = f(a, b).cpu().numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
def test_cat_dim1(self):
def f(a, b): return a.cat(b, dim=1).realize()
vi = Variable("i", 1, 10)
for i in range(1, 5):
a = Tensor.rand(3, i)
b = Tensor.rand(3, 2)
symbolic = f(a.reshape(3, vi), b).reshape(3, i+2).cpu().numpy()
expected = f(a, b).cpu().numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
def test_cat_dim0_two_vars(self):
def f(a, b): return a.cat(b, dim=0).realize()
vi = Variable("i", 1, 10)
vj = Variable("j", 1, 10)
for i in range(1, 5):
for j in range(1, 5):
a = Tensor.rand(i, 3)
b = Tensor.rand(j, 3)
symbolic = f(a.reshape(vi, 3), b.reshape(vj, 3)).reshape(i+j, 3).cpu().numpy()
expected = f(a, b).cpu().numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
def test_cat_dim1_two_vars(self):
def f(a, b): return a.cat(b, dim=1).realize()
vi = Variable("i", 1, 10)
vj = Variable("j", 1, 10)
for i in range(1, 5):
for j in range(1, 5):
a = Tensor.rand(3, i)
b = Tensor.rand(3, j)
symbolic = f(a.reshape(3, vi), b.reshape(3, vj)).reshape(3, i+j).cpu().numpy()
expected = f(a, b).cpu().numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
def test_two_vars_plus1(self):
def f(a, b): return (a@b+1).realize()
vi = Variable("i", 1, 10)
vj = Variable("j", 1, 10)
for i in range(1, 5):
for j in range(1, 5):
a = Tensor.rand(i, 3)
b = Tensor.rand(3, j)
symbolic = f(a.reshape(vi, 3), b.reshape(3, vj)).reshape(i, j).cpu().numpy()
expected = f(a, b).cpu().numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
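
All of these tests follow the same pattern: reshape an input to a shape containing a Variable, run the op, then reshape the realized output back to the concrete shape before reading it. A minimal standalone sketch of that pattern (assuming one of the supported backends) looks like:

from tinygrad.tensor import Tensor
from tinygrad.shape.symbolic import Variable

vi = Variable("i", 1, 10)                              # symbolic dimension bounded to [1, 10]
a, b = Tensor.rand(3, 4), Tensor.rand(3, 4)
out = (a.reshape(3, vi) + b.reshape(3, vi)).realize()  # the kernel takes i as a runtime argument instead of baking in 4
print(out.reshape(3, 4).cpu().numpy())                 # reshape back to the concrete shape to read the result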

View File

@ -26,17 +26,18 @@ class TestSymbolic(unittest.TestCase):
i = Variable("i", 1, 5)
j = Variable("j", 1, 5)
k = Variable("k", 1, 5)
t1 = Tensor.rand(3, 4).reshape(i, 4).cat(Tensor.rand(3, 4).reshape(j, 4), dim=0).cat(Tensor.rand(3, 4).reshape(k, 4), dim=0)
st = t1.lazydata.st
t = Tensor.rand(3, 4).reshape(i, 4).cat(Tensor.rand(3, 4).reshape(j, 4), dim=0).cat(Tensor.rand(3, 4).reshape(k, 4), dim=0)
st = t.lazydata.st
assert st.shape == (i+j+k, 4)
assert st.real_strides() == (4, 1)
i = Variable("i", 1, 5)
j = Variable("j", 1, 5)
k = Variable("k", 1, 5)
t1 = Tensor.rand(3, 4).reshape(3, i).cat(Tensor.rand(3, 4).reshape(3, j), dim=1).cat(Tensor.rand(3, 4).reshape(3, k), dim=1)
st = t1.lazydata.st
t = Tensor.rand(3, 4).reshape(3, i).cat(Tensor.rand(3, 4).reshape(3, j), dim=1).cat(Tensor.rand(3, 4).reshape(3, k), dim=1)
st = t.lazydata.st
assert st.shape == (3, i+j+k)
assert st.real_strides() == (i+j+k, 1)
t = Tensor.rand(i, 3).reshape(i, 3).cat(Tensor.rand(3, 3).reshape(i, 3), dim=0).cat(Tensor.rand(3, 3), dim=0)
st = t.lazydata.st
assert st.shape == (2*i+3, 3)
assert st.real_strides() == (3, 1)
class TestSymbolicReshape(unittest.TestCase):
def test_reshape_into_symbols_simple(self):

View File

@ -283,6 +283,7 @@ class TestSymbolicSymbolicOps(unittest.TestCase):
assert NumNode(0) // (Variable("i", 1, 10)*128) == 0
assert NumNode(127) // (Variable("i", 1, 10)*128) == 0
assert idx0 // (i*3) == 0
assert i // i == 1
def test_node_mod_node(self):
i = Variable("i", 1, 10)

View File

@ -9,7 +9,7 @@ from tinygrad.lazy import LazyBuffer
from tinygrad.ops import MovementOps, ReduceOps, BinaryOps, TernaryOps
from tinygrad.runtime.lib import RawConst, buf_is_kernel_arg
from tinygrad.shape.shapetracker import ShapeTracker, strides_for_shape, View
from tinygrad.shape.symbolic import Variable, NumNode, Node, SumNode, MulNode
from tinygrad.shape.symbolic import Variable, NumNode, Node, SumNode, MulNode, sym_rename
VariableOrNum = Union[Variable, NumNode, Node]
# bottom ones are asm only
@ -301,6 +301,9 @@ class Linearizer:
# add global buffers
for buf,name in self.arg_bufs.items():
self.uop(UOps.DEFINE_GLOBAL, None, [], (name, buf.dtype))
# add variables from symbolic shapes
for var in sorted(set(v for buf in self.ast.buffers for v in buf.st.var_vals), key=lambda k: k.key):
self.uop(UOps.DEFINE_GLOBAL, None, [], (var.expr, dtypes._arg_int32))
# add a local buffer for multistage reduce
if len(self.group_for_reduce):
@ -317,7 +320,7 @@ class Linearizer:
if DEBUG >= 3: self.printbufs()
# kernel name (before late upcast)
self.function_name = ("r_" if self.reduceop else "E_") + '_'.join([str(x) for x in self.full_shape])
self.function_name = ("r_" if self.reduceop else "E_") + '_'.join([str(x) if isinstance(x, int) else sym_rename(x) for x in self.full_shape])
self.display_name = ("r_" if self.reduceop else "E_") + colored('_', 'BLACK').join([colored(str(x), c) for x,c in zip(self.full_shape, self.colors())])
# parse AST
@ -548,7 +551,7 @@ class Linearizer:
assert len(colors) == self.shape_len, "colors size mismatch"
return colors
def colored_shape(self) -> str: return ' '.join(colored(f"{s:4d}", color) for s,color in zip(self.full_shape, self.colors()))
def colored_shape(self) -> str: return ' '.join(colored(s, color) for s,color in zip([f"{s:4d}" if isinstance(s, int) else s for s in self.full_shape], self.colors()))
def printbufs(self, prefix=""):
for i in range(len(self.sts)):
print(prefix, f"{i:3d} {str(self.bufs[i].realized) if self.bufs[i].realized is not None else str(self.bufs[i]):47s}", self.sts[i].views)

View File

@ -162,7 +162,7 @@ def hand_coded_optimizations(k:Linearizer):
# early exit
return
if k.opts.has_local:
if k.opts.has_local and all(isinstance(s, int) for s in k.sts[0].shape[:k.first_reduce]):
# are we grouping? (requires local shape support)
if not k.float4_axis(0) and k.first_reduce <= 2 and k.first_reduce + 1 <= k.shape_len and prod(k.sts[0].shape[:k.first_reduce]) <= 2048:
# TODO: use 1024 if it's allowed in a smarter way
@ -204,8 +204,8 @@ def hand_coded_optimizations(k:Linearizer):
while prod(k.sts[0].shape[:k.first_reduce]) >= 1024:
xb_choices = []
for axis, upcast_amount in itertools.product(range(k.first_reduce), [3,4]): # consider all the non reduce axes, and a 3 or 4 reduce
# if we haven't upcasted it, it mods, and some buffer has stride 0 on axis while having no stride 0 in the upcasted axis already
if axis not in upcasted_axis and k.full_shape[axis]%upcast_amount == 0 and any(k.sts[buf_index].views[-1].strides[axis] == 0 and not any(x[1] == 0 for x in k.upcasted_axis(buf_index)) for buf_index in range(len(k.sts))):
# if we haven't upcasted it, it's not symbolic, it mods, and some buffer has stride 0 on axis while having no stride 0 in the upcasted axis already
if axis not in upcasted_axis and isinstance(k.full_shape[axis], int) and k.full_shape[axis]%upcast_amount == 0 and any(k.sts[buf_index].views[-1].strides[axis] == 0 and not any(x[1] == 0 for x in k.upcasted_axis(buf_index)) for buf_index in range(len(k.sts))):
xb_choices.append((sum(st.views[-1].strides[axis]>0 for st in k.sts), sum(st.views[-1].strides[axis] for st in k.sts), axis, upcast_amount))
if len(xb_choices):
xb_choices = sorted(xb_choices)
@ -219,7 +219,7 @@ def hand_coded_optimizations(k:Linearizer):
# if last dim is small(ish) and it's a reduce dim, upcast the reduce (loop unrolling). no simplify needed since it's just an upcast. NOTE: careful, this has broken VALIDHACKS
if k.first_reduce < (k.shape_len-k.upcasted) and (len(list(k.shape_offsets(k.full_buf_index))) <= 4 or not any(r for _,_,r in k.upcasted_axis(k.full_buf_index))):
if (s:=k.full_unupcasted_shape[-1]) <= 32:
if (s:=k.full_unupcasted_shape[-1]) <= 32 and isinstance(s, int): # NOTE: cannot loop unroll symbolic axis
k.upcast()
# if it's small, upcast a second reduce dimension too
if k.first_reduce < (k.shape_len-k.upcasted) and s <= 3 and k.full_unupcasted_shape[-1] <= 3: k.upcast()

View File

@ -3,7 +3,7 @@ import os, functools, platform, time, re, contextlib
from weakref import KeyedRef, ref
from _weakref import _remove_dead_weakref # type: ignore
import numpy as np
from typing import Dict, Tuple, Union, List, NamedTuple, Final, Iterator, ClassVar, Optional, Callable, Any
from typing import Dict, Tuple, Union, List, NamedTuple, Final, Iterator, ClassVar, Optional, Callable, Any, Iterable
from math import prod # noqa: F401 # pylint:disable=unused-import
ShapeType = Tuple[int, ...]
@ -22,6 +22,10 @@ def make_pair(x:Union[int, Tuple[int, ...]], cnt=2) -> Tuple[int, ...]: return (
def flatten(l:Iterator): return [item for sublist in l for item in sublist]
def mnum(i) -> str: return str(i) if i >= 0 else f"m{-i}"
def fromimport(mod, frm): return getattr(__import__(mod, fromlist=[frm]), frm)
def merge_dicts(ds:Iterable[Dict]) -> Dict:
kvs = set([(k,v) for d in ds for k,v in d.items()])
assert len(kvs) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key"
return {k:v for k,v in kvs}
@functools.lru_cache(maxsize=None)
def getenv(key, default=0): return type(default)(os.getenv(key, default))
@ -115,6 +119,7 @@ class dtypes:
_half4: Final[DType] = DType(0, 2*4, "half4", None, 4)
_float2: Final[DType] = DType(4, 4*2, "float2", None, 2)
_float4: Final[DType] = DType(4, 4*4, "float4", None, 4)
_arg_int32: Final[DType] = DType(2, 4, "_arg_int32", None)
# HACK: staticmethods are not callable in 3.8 so we have to compare the class
DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not k.startswith('__') and not callable(v) and not v.__class__ == staticmethod}
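
merge_dicts is what the Compiled path below uses to combine the var_vals of every buffer in the AST into one binding; a minimal sketch of its behaviour:

from tinygrad.helpers import merge_dicts
print(merge_dicts([{"i": 3}, {"j": 4}, {"i": 3}]))   # {'i': 3, 'j': 4} (key order may vary)
try: merge_dicts([{"i": 3}, {"i": 5}])
except AssertionError: print("conflicting values for the same key are rejected")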

View File

@ -2,8 +2,9 @@ from __future__ import annotations
import functools, time
from enum import Enum, auto
from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, cast
from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored, dedup
from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored, dedup, merge_dicts
from tinygrad.shape.shapetracker import MovementOps
from tinygrad.shape.symbolic import Variable, sym_infer
from tinygrad.runtime.lib import RawBuffer, RawConst, buf_is_kernel_arg
if TYPE_CHECKING:
from tinygrad.lazy import LazyBuffer
@ -131,20 +132,24 @@ class ASTRunner:
self.clprg = runtime(self.name, self.prg, **self.runtime_args)
return self
def exec(self, bufs, force_wait=False, optimizing=False) -> Optional[float]:
def exec(self, bufs, var_vals:Optional[Dict[Variable, int]]=None, force_wait=False, optimizing=False) -> Optional[float]:
rawbufs = dedup([x.realized for x in bufs if buf_is_kernel_arg(x)])
if GlobalCounters.cache is not None and not optimizing: GlobalCounters.cache.append((self, rawbufs))
return self(rawbufs, force_wait=force_wait)
return self(rawbufs, var_vals, force_wait=force_wait)
def __call__(self, rawbufs:List[RawBuffer], jit=False, force_wait=False) -> Optional[float]:
if et := self.clprg((self.global_size + [1]*(3-len(self.global_size))) if self.global_size is not None else None,
(self.local_size + [1]*(3-len(self.local_size))) if self.local_size is not None else None,
*rawbufs, wait=force_wait or DEBUG>=1): GlobalCounters.time_sum_s += et
def __call__(self, rawbufs:List[RawBuffer], var_vals:Optional[Dict[Variable, int]]=None, jit=False, force_wait=False) -> Optional[float]:
if var_vals is None: var_vals = {}
global_size = [sym_infer(sz, var_vals) for sz in self.global_size] if self.global_size is not None else self.global_size
local_size = [sym_infer(sz, var_vals) for sz in self.local_size] if self.local_size is not None else self.local_size
if et := self.clprg((global_size + [1]*(3-len(global_size))) if global_size is not None else None,
(local_size + [1]*(3-len(local_size))) if local_size is not None else None,
*rawbufs, *var_vals.values(), wait=force_wait or DEBUG>=1): GlobalCounters.time_sum_s += et
op_estimate = sym_infer(self.op_estimate, var_vals)
if DEBUG >= 2:
print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'magenta' if jit else None)} {(self.display_name+' '*(33-ansilen(self.display_name))) if self.display_name is not None else self.name:33s} arg {len(rawbufs):3d} sz {str(self.global_size):18s} {str(self.local_size):12s} OPs {int(self.op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
(str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({self.op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {self.mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'magenta' if jit else None)} {(self.display_name+' '*(33-ansilen(self.display_name))) if self.display_name is not None else self.name:33s} arg {len(rawbufs):3d} sz {str(global_size):18s} {str(local_size):12s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
(str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {self.mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
GlobalCounters.kernel_count += 1
GlobalCounters.global_ops += self.op_estimate
GlobalCounters.global_ops += op_estimate
GlobalCounters.global_mem += self.mem_estimate
if getenv("EARLY_STOPPING") and GlobalCounters.kernel_count == getenv("EARLY_STOPPING"): exit(0)
return et
@ -178,9 +183,11 @@ class Compiled:
output.realized = None
break
# we don't have an output buffer, we have to create it
# we don't have an output buffer, we have to create it, and create to max size if it has symbolic shape
if not output.realized:
output.realized = self.buffer(prod(output.shape), output.dtype, **kwargs)
output.realized = self.buffer(prod((s if isinstance(s, int) else s.max for s in output.shape)), output.dtype, **kwargs)
# update the output var_vals from src
output.st.var_vals = dict(sorted(merge_dicts([buf.st.var_vals for buf in ast.buffers]).items(), key=lambda kv:cast(Variable,kv[0]).key))
from tinygrad.codegen.linearizer import Linearizer
k = Linearizer(ast, output, self.linearizer_opts)
@ -200,5 +207,5 @@ class Compiled:
if prg.name == getenv("PRINT_PRG", ''): print(prg.prg)
prg.exec(k.bufs)
prg.exec(k.bufs, var_vals=output.st.var_vals)
return output.realized
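
To make the exec path concrete, a small sketch with hypothetical values: the bound Variable values travel in var_vals, symbolic launch dimensions are resolved with sym_infer, and the values themselves are appended after the raw buffers as extra int kernel arguments, in the same sorted-by-key order the linearizer used for its DEFINE_GLOBAL uops:

from tinygrad.shape.symbolic import Variable, sym_infer
vi = Variable("i", 1, 10)
var_vals = {vi: 6}                                      # bound at exec time from the output ShapeTracker
global_size = [vi*4, 8, 1]                              # hypothetical symbolic launch dims
print([sym_infer(sz, var_vals) for sz in global_size])  # [24, 8, 1]
print(list(var_vals.values()))                          # [6], passed as *var_vals.values() after the rawbufs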

View File

@ -3,7 +3,7 @@ import math
from tinygrad.codegen.linearizer import UOps, UOp, MemOp, ConstOp
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
from tinygrad.helpers import ImageDType, dtypes, getenv, prod, DType
from tinygrad.shape.symbolic import DivNode, AndNode, render_python, NumNode, Variable
from tinygrad.shape.symbolic import DivNode, AndNode, render_python, NumNode, Variable, sym_render
# div is different in cl than python
render_cl = render_python.copy()
@ -17,6 +17,7 @@ class CStyleLanguage(NamedTuple):
buffer_prefix: str = ""
buffer_suffix: str = ""
smem_prefix: str = ""
arg_int_prefix: str = ""
barrier: str = ""
gid: List[str] = []
lid: List[str] = []
@ -69,7 +70,7 @@ class CStyleLanguage(NamedTuple):
def render_local(self, name:str, size:int):
return self.smem_prefix + f"float {name}[{size}];"
def render_for(self, expr: str, _min:int, _max:int) -> str:
def render_for(self, expr: str, _min:int, _max:Union[int,str]) -> str:
return f"for (int {expr} = {_min}; {expr} <= {_max}; ++{expr}) {{"
def render_conditional(self, cond: str, x:str, y:str) -> str:
@ -78,6 +79,7 @@ class CStyleLanguage(NamedTuple):
def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,DType]], global_size:List[int], local_size:List[int], prekernel:List[str]) -> Tuple[str,List[int],List[int]]:
tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,dtype in bufs) else ""
buftypes = [(name,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if dtype.name.startswith('image') else
self.arg_int_prefix if dtype == dtypes._arg_int32 else
("const " if i > 0 else "")+self.buffer_prefix+dtype.name+"*"+self.buffer_suffix) for i,(name,dtype) in enumerate(bufs)]
prg = ''.join([f"{self.kernel_prefix} void {function_name}(",] +
[', '.join([f'{t} {name}' for name,t in buftypes] + self.extra_args)] +
@ -128,7 +130,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> T
kk(add_gl_dimension(lang.size_prefix, args, i, var, local_size, lang.lid))
else:
if getenv("NOUNROLL") and not isinstance(var, NumNode): kk("#pragma unroll(1)") # prevent loop unrolling
kk("{" if isinstance(var, NumNode) else lang.render_for(var.expr, var.min, var.max))
kk("{" if isinstance(var, NumNode) else lang.render_for(var.expr, var.min, sym_render(var.max)))
depth += 1
elif uop == UOps.BARRIER:
kk(lang.barrier)
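
A small sketch (hypothetical names, OpenCL-flavoured prefixes) of what the buftypes logic above declares for a kernel taking two float buffers and one bound symbolic variable:

from tinygrad.helpers import dtypes
arg_int_prefix, buffer_prefix = "const int", "__global "
bufs = [("data0", dtypes.float32), ("data1", dtypes.float32), ("i", dtypes._arg_int32)]
decls = []
for idx, (name, dtype) in enumerate(bufs):
  if dtype == dtypes._arg_int32: decls.append(f"{arg_int_prefix} {name}")     # bound Variable arrives as a plain int
  else: decls.append(f"{'const ' if idx > 0 else ''}{buffer_prefix}{dtype.name}* {name}")
print(decls)   # ['__global float* data0', 'const __global float* data1', 'const int i']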

View File

@ -38,7 +38,7 @@ class WGSLLanguage(CStyleLanguage):
prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3<u32>, @builtin(local_invocation_id) lindex: vec3<u32>) {{\n" + "\n".join(kernel) + "\n}"
return prg, global_size[::-1] if len(global_size) else [1], local_size
def render_for(self, expr:str, _min:int, _max:int) -> str:
def render_for(self, expr:str, _min:int, _max:Union[int,str]) -> str:
return f"for(var {expr} = {_min}; {expr} <= {_max}; {expr}++) {{"
def render_conditional(self, cond:str, x:str, y:str) -> str:

View File

@ -74,8 +74,8 @@ class ClangProgram:
mu.emu_start(ADDRESS, ADDRESS + len(self.prg))
args[0]._buf = mu.mem_read(mu.reg_read(arm64_const.UC_ARM64_REG_X0), args[0].size * args[0].dtype.itemsize)
else:
self.fxn(*[x._buf for x in args])
self.fxn(*[x._buf if isinstance(x, RawMallocBuffer) else x for x in args])
if wait: return time.monotonic()-st
renderer = fromimport("tinygrad.codegen.assembly_arm64", "uops_to_arm64_asm") if ARM64 else functools.partial(uops_to_cstyle, CStyleLanguage(kernel_prefix=args['exp'], buffer_suffix=" restrict"))
renderer = fromimport("tinygrad.codegen.assembly_arm64", "uops_to_arm64_asm") if ARM64 else functools.partial(uops_to_cstyle, CStyleLanguage(kernel_prefix=args['exp'], buffer_suffix=" restrict", arg_int_prefix="const int"))
ClangBuffer = Compiled(RawMallocBuffer, LinearizerOptions(supports_float4=False, has_local=False), renderer, ClangProgram)

View File

@ -80,7 +80,7 @@ class CLProgram:
def max_work_group_size(): return CL.cl_ctxs[0].devices[0].max_work_group_size
def __call__(self, global_size, local_size, *bufs, wait=False) -> Optional[float]:
cl_bufs = [x._buf if isinstance(x, CLBuffer) else x for x in bufs]
cl_bufs = [x._buf if isinstance(x, CLBuffer) else np.int32(x) if isinstance(x, int) else x for x in bufs]
e = self.clprgs[cl_bufs[0].device](CL.cl_queue[cl_bufs[0].device], [g*l for g,l in zip(global_size, local_size)] if local_size is not None else global_size, local_size, *cl_bufs, wait_for=[x.event for x in bufs if isinstance(x, CLBuffer) and hasattr(x, "event")])
if wait:
e.wait()
@ -91,7 +91,7 @@ class CLProgram:
return None
renderer = functools.partial(uops_to_cstyle, CStyleLanguage(
kernel_prefix = "__kernel", buffer_prefix = "__global ", smem_prefix = "__local ",
kernel_prefix = "__kernel", buffer_prefix = "__global ", smem_prefix = "__local ", arg_int_prefix = "const int",
half_prekernel = "#pragma OPENCL EXTENSION cl_khr_fp16 : enable",
barrier = "barrier(CLK_LOCAL_MEM_FENCE);", float4 = "(float4)",
gid = [f'get_group_id({i})' for i in range(3)], lid = [f'get_local_id({i})' for i in range(3)], uses_vload=True))

View File

@ -1,5 +1,5 @@
# pip3 install pyobjc-framework-Metal pyobjc-framework-Cocoa pyobjc-framework-libdispatch
import os, subprocess, pathlib, functools
import os, subprocess, pathlib, functools, ctypes
import Metal, Cocoa, libdispatch # type: ignore
from typing import List, Any
from tinygrad.codegen.linearizer import LinearizerOptions
@ -65,7 +65,10 @@ class MetalProgram:
command_buffer = METAL.mtl_queue.commandBuffer()
encoder = command_buffer.computeCommandEncoder()
encoder.setComputePipelineState_(self.pipeline_state)
for i,a in enumerate(bufs): encoder.setBuffer_offset_atIndex_(a._buf, 0, i)
for i,a in enumerate(bufs):
if isinstance(a, RawMetalBuffer): encoder.setBuffer_offset_atIndex_(a._buf, 0, i)
elif isinstance(a, int): encoder.setBytes_length_atIndex_((arg:=ctypes.c_int32(a)), ctypes.sizeof(arg), i)
else: raise RuntimeError(f"arg at index {i} has unsupported type {type(a)}")
encoder.dispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
encoder.endEncoding()
command_buffer.commit()
@ -75,7 +78,7 @@ class MetalProgram:
METAL.mtl_buffers_in_flight.append(command_buffer)
renderer = functools.partial(uops_to_cstyle, CStyleLanguage(
kernel_prefix = "#include <metal_stdlib>\nusing namespace metal;\nkernel", buffer_prefix = "device ", smem_prefix = "threadgroup ",
kernel_prefix = "#include <metal_stdlib>\nusing namespace metal;\nkernel", buffer_prefix = "device ", smem_prefix = "threadgroup ", arg_int_prefix = "constant int&",
barrier = "threadgroup_barrier(mem_flags::mem_threadgroup);", float4 = "float4", uses_ptr_arithmetic=True,
gid = [f"gid.{chr(120+i)}" for i in range(3)], lid = [f"lid.{chr(120+i)}" for i in range(3)],
extra_args = ['uint3 gid [[threadgroup_position_in_grid]]', 'uint3 lid [[thread_position_in_threadgroup]]']))

View File

@ -59,7 +59,7 @@ class View(ViewInternal):
# generate an expression if you have a single idx variable
def expr_node(self, idx=None) -> Node:
if idx is None: idx = Variable('idx', 0, prod(self.shape)-1)
ret: List[Node] = [Variable.num(self.offset)] if self.offset else []
ret: List[Node] = [Variable.num(self.offset) if isinstance(self.offset, int) else self.offset] if self.offset else []
acc = 1
for d,s in reversed(self.shape_strides):
ret.append(((idx//acc)%d)*s)
@ -69,7 +69,7 @@ class View(ViewInternal):
# generate an expression if you have a variable or expression for each index
def expr_idxs(self, idxs) -> Node:
assert len(idxs) == len(self.shape), f"need an idx for all dimensions {idxs} vs {self.shape}"
return Variable.sum([Variable.num(self.offset)] + [idx*st for idx,sh,st in zip(idxs, self.shape, self.strides) if sh != 1 and st != 0])
return Variable.sum([Variable.num(self.offset) if isinstance(self.offset, int) else self.offset] + [idx*st for idx,sh,st in zip(idxs, self.shape, self.strides) if sh != 1 and st != 0])
@functools.lru_cache(maxsize=None)
def idxs_to_idx(shape:Tuple[int, ...], idxs) -> Node:
@ -162,7 +162,7 @@ class ShapeTracker:
idx, valid = self.expr_idxs(idxs)
ret: List[Optional[Union[Node, int]]] = [None] * len(self.views[-1].shape)
for this_dim in (idx.nodes if isinstance(idx, SumNode) else [idx]):
if isinstance(this_dim, MulNode) and isinstance(this_dim.a, Variable):
if isinstance(this_dim, MulNode) and isinstance(this_dim.a, Variable) and this_dim.a in idxs:
ret[idxs.index(this_dim.a)] = this_dim.b
elif isinstance(this_dim, Variable):
ret[idxs.index(this_dim)] = 1

View File

@ -65,6 +65,7 @@ class Node:
def __rfloordiv__(self, b:int): raise RuntimeError(f"not supported: {b} // {self}")
def __floordiv__(self, b:Union[Node,int], factoring_allowed=True):
if isinstance(b, Node):
if self == b: return NumNode(1)
if (b > self).min > 0 and self.min >= 0: return NumNode(0)
raise RuntimeError(f"not supported: {self} // {b}")
assert b != 0
@ -262,6 +263,14 @@ def create_rednode(typ:Type[RedNode], nodes:List[Node]):
elif typ == AndNode: ret.min, ret.max = (min([x.min for x in nodes]), max([x.max for x in nodes]))
return create_node(ret)
def sym_infer(n:Union[Node,int], var_vals: Dict[Variable, int]) -> int:
if isinstance(n, (int, NumNode)): return int(n)
if isinstance(n, Variable): return var_vals[n]
if isinstance(n, MulNode): return sym_infer(n.a, var_vals) * sym_infer(n.b, var_vals)
if isinstance(n, SumNode): return sum(sym_infer(s, var_vals) for s in n.nodes)
raise NotImplementedError(n)
@functools.lru_cache(maxsize=None)
def sym_rename(s) -> str: return f"s{sym_rename.cache_info().currsize}"
def sym_render(a: Union[Node, int], ops=None, ctx=None) -> str: return str(a) if isinstance(a, int) else a.render(ops, ctx)
render_python: Dict[Type, Callable] = {
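
A short sketch of the two value-level helpers added here, with hypothetical variables: sym_infer evaluates a Node against concrete Variable bindings, and sym_render turns either an int or a Node into a string for codegen.

from tinygrad.shape.symbolic import Variable, sym_infer, sym_render
vi, vj = Variable("i", 1, 10), Variable("j", 1, 10)
print(sym_infer(vi + vj*2, {vi: 3, vj: 4}))   # 11: Variable/Num/Mul/Sum nodes are evaluated recursively, anything else raises NotImplementedError
print(sym_render(7), sym_render(vi))          # "7" for plain ints, the rendered node (e.g. "i") otherwise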