jit doesn't use named tensors (#2393)
* jit doesn't use named tensors
* move to compile2
* remove broken single root junk
* explicit float32
* skip slow test
parent 80e4ad8bf5
commit 8656eebb42
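The heart of the change, as a standalone sketch (toy stand-in values, not tinygrad's real classes): the JIT's input_replace map used to send a (kernel index, argument slot) pair to a tensor name looked up in a buffer dict; it now sends it to an integer index into an ordered buffer list.

# Hedged sketch of the change in tinygrad/jit.py and tinygrad/ops.py.
# None marks the argument slots where JIT inputs go.
jit_cache = [["w0", None, "b0"], [None, "b1"]]

# old style: input_replace values were tensor names into a dict of buffers
input_replace_named = {(0, 1): "img", (1, 0): "desire"}
bufs_by_name = {"img": "IMG_BUF", "desire": "DESIRE_BUF"}
for (j, i), name in input_replace_named.items(): jit_cache[j][i] = bufs_by_name[name]

# new style: input_replace values are integer indices into an ordered list
input_replace = {(0, 1): 0, (1, 0): 1}
bufs = ["IMG_BUF", "DESIRE_BUF"]
for (j, i), idx in input_replace.items(): jit_cache[j][i] = bufs[idx]
assert jit_cache == [["w0", "IMG_BUF", "b0"], ["DESIRE_BUF", "b1"]]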
@@ -178,17 +178,14 @@ jobs:
     - if: ${{ matrix.task == 'openpilot' }}
       name: Test openpilot model compile and size
       run: |
-        DEBUG=2 ALLOWED_KERNEL_COUNT=207 VALIDTEST=1 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python openpilot/compile.py
+        DEBUG=2 ALLOWED_KERNEL_COUNT=207 VALIDTEST=1 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python openpilot/compile2.py
         python -c 'import os; assert os.path.getsize("/tmp/output.thneed") < 100_000_000'
     - if: ${{ matrix.task == 'openpilot' }}
       name: Test openpilot model correctness (float32)
-      run: DEBUGCL=1 GPU=1 IMAGE=2 python openpilot/compile.py
-    - if: ${{ matrix.task == 'openpilot' }}
-      name: Test openpilot model correctness (float32, new compiler)
-      run: DEBUGCL=1 FLOAT16=0 python3 openpilot/compile2.py
+      run: FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python openpilot/compile2.py
     - if: ${{ matrix.task == 'openpilot' }}
       name: Test openpilot alt model correctness (float32)
-      run: DEBUGCL=1 GPU=1 IMAGE=2 python openpilot/compile.py https://github.com/commaai/openpilot/raw/3799fe46b3a629e491d4b8498b8ae83e4c88c304/selfdrive/modeld/models/supercombo.onnx
+      run: FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python openpilot/compile2.py https://github.com/commaai/openpilot/raw/3799fe46b3a629e491d4b8498b8ae83e4c88c304/selfdrive/modeld/models/supercombo.onnx
     - if: ${{ matrix.task == 'openpilot' }}
       name: Test tensor core ops
       run: GPU=1 TC=2 python -m pytest -n=auto test/test_ops.py
@@ -1,150 +0,0 @@
-#!/usr/bin/env python3
-import os, time, io, pathlib, sys, traceback, re
-sys.path.insert(0, str(pathlib.Path(__file__).parents[1]))
-
-if os.getenv("OPT", None) is None:
-  os.environ['OPT'] = '99'
-if os.getenv("GPU", None) is None:
-  os.environ['GPU'] = '1'
-if os.getenv("IMAGE", None) is None:
-  os.environ['IMAGE'] = '2'
-
-from tinygrad.helpers import getenv, dtypes
-ALLOWED_KERNEL_COUNT = getenv("ALLOWED_KERNEL_COUNT", 0)
-DEBUGCL = getenv("DEBUGCL", 0)
-
-import onnx
-import numpy as np
-
-import tinygrad.graph as graph
-from tinygrad.helpers import GlobalCounters
-from tinygrad.jit import TinyJit, CacheCollector
-
-import pyopencl as cl
-from tinygrad.runtime.ops_gpu import CL
-from extra.utils import fetch
-from extra.onnx import get_run_onnx
-from tinygrad.tensor import Tensor
-
-OPENPILOT_MODEL = "https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx"
-
-np.random.seed(1337)
-def get_random_input_tensors(input_shapes):
-  # this 16 is a random scale factor
-  inputs = {k:Tensor.randn(*shp, requires_grad=False)*8 for k,shp in input_shapes.items()}
-  np_inputs = {k:v.realize().numpy() for k,v in inputs.items()}
-  return inputs, np_inputs
-
-@TinyJit
-def model_exec(run_onnx, using_graph, **inputs):
-  ret = next(iter(run_onnx(inputs).values())).cast(dtypes.float32)
-  GlobalCounters.reset()
-  CacheCollector.start() # don't cache pre-realize
-  if using_graph: graph.GRAPH = True
-  print("realizing")
-  return ret.realize()
-
-def compile(dat, output_fn):
-  Tensor.manual_seed(1337)
-  Tensor.no_grad = True
-  using_graph = graph.GRAPH
-  if getenv("GRAPH") < 3: graph.GRAPH = False
-
-  onnx_model = onnx.load(io.BytesIO(dat))
-  run_onnx = get_run_onnx(onnx_model)
-  input_shapes = {inp.name:tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input}
-
-  inputs, np_inputs = get_random_input_tensors(input_shapes)
-  # run twice to trigger the JIT
-  for i in range(2): tinygrad_out = model_exec(run_onnx, i == 1 and using_graph, **inputs)
-  graph.GRAPH = False
-  print("kernel count:", len(model_exec.jit_cache))
-  assert len(model_exec.jit_cache) <= ALLOWED_KERNEL_COUNT or ALLOWED_KERNEL_COUNT == 0, "too many kernels!"
-
-  # pull out inputs and put them in the jit cache
-  input_rawbuffers = {k:inputs[k].lazydata.realized for k in inputs.keys()}
-  for (j,i),idx in model_exec.input_replace.items(): model_exec.jit_cache[j].rawbufs[i] = input_rawbuffers[idx]
-
-  # transform to CL.CACHE
-  used_ops = 0
-  cl_cache = []
-  for ji in model_exec.jit_cache:
-    prg = ji.prg
-    # pass these to thneed
-    setattr(prg.clprg, 'op_estimate', prg.op_estimate)
-    setattr(prg.clprg, 'prg', prg.prg)
-
-    if getenv("VALIDTEST") == 1:
-      src = re.search(r"=.*\?.*?read_image", prg.prg)
-      if src is not None: raise Exception("Openpilot has valid checks!")
-
-    global_size = prg.global_size + [1]*(3-len(prg.global_size))
-    local_size = prg.local_size + [1]*(3-len(prg.local_size))
-    cl_cache.append((prg.clprg, [[int(g*l) for g,l in zip(global_size, local_size)], local_size, *[x._buf for x in ji.rawbufs]]))
-    used_ops += prg.op_estimate
-
-  from extra.thneed import Thneed
-  t = Thneed(cl_cache, {k:v._buf for k,v in input_rawbuffers.items()})
-
-  # save thneed (before run)
-  t.save(output_fn)
-
-  print(f"buffers to save: {len(t.buffers_to_save)}, inputs: {list(t.inputs.keys())}, outputs: {t.outputs}")
-  runtime = t.run()
-  print(f"network using {used_ops/1e9:.2f} GOPS with runtime {runtime*1e3:.2f} ms that's {used_ops/runtime*1e-9:.2f} GFLOPS")
-
-  # confirm thneed found the right output
-  thneed_out = np.empty((t.outputs[0].size//4,), dtype=np.float32).reshape(tinygrad_out.shape)
-  cl.enqueue_copy(CL.cl_queue[0], thneed_out, t.outputs[0], is_blocking=True)
-  np.testing.assert_allclose(thneed_out, tinygrad_out.numpy())
-
-  # testing is float32 only (fix this)
-  FLOAT16 = getenv("FLOAT16", 0)
-  if FLOAT16 == 0:
-    try:
-      from test.models.test_onnx import run_onnx_torch
-      torch_out = run_onnx_torch(onnx_model, np_inputs).numpy()
-      print(thneed_out, torch_out, "mse", np.sum((thneed_out-torch_out)**2), "max err", np.max(np.abs((thneed_out-torch_out))))
-      np.testing.assert_allclose(torch_out, thneed_out, atol=1e-4, rtol=1e-2)
-
-      # test loading/run thneed
-      _, new_np_inputs = get_random_input_tensors(input_shapes)
-      new_torch_out = run_onnx_torch(onnx_model, new_np_inputs).numpy()
-
-      # try old thneed with a different input
-      for k,v in t.inputs.items():
-        cl.enqueue_copy(CL.cl_queue[0], v, new_np_inputs[k], is_blocking=True)
-
-      t.run()
-      old_thneed_out = np.empty((t.outputs[0].size//4,), dtype=np.float32).reshape(tinygrad_out.shape)
-      cl.enqueue_copy(CL.cl_queue[0], old_thneed_out, t.outputs[0], is_blocking=True)
-
-      # compare thneed (rerun) with torch
-      np.testing.assert_allclose(new_torch_out, old_thneed_out, atol=1e-4, rtol=1e-2)
-
-      # load thneed and try that
-      _, new_np_inputs = get_random_input_tensors(input_shapes)
-      new_torch_out = run_onnx_torch(onnx_model, new_np_inputs).numpy()
-      nt = Thneed()
-      nt.load(output_fn)
-
-      # inputs
-      for k,v in nt.inputs.items():
-        cl.enqueue_copy(CL.cl_queue[0], v, new_np_inputs[k], is_blocking=True)
-
-      nt.run()
-      new_thneed_out = np.empty((nt.outputs[0].size//4,), dtype=np.float32).reshape(tinygrad_out.shape)
-      cl.enqueue_copy(CL.cl_queue[0], new_thneed_out, nt.outputs[0], is_blocking=True)
-
-      # compare torch to thneed
-      np.testing.assert_allclose(new_torch_out, new_thneed_out, atol=1e-4, rtol=1e-2)
-      print("thneed self-test passed!")
-    except ModuleNotFoundError as e:
-      print(f"TEST NOT HAPPENING {e}")
-
-
-# UNSAFE_FLOAT4=1 DEBUGCL=1 FLOAT16=1 python3 openpilot/compile.py
-# 22.59 ms
-if __name__ == "__main__":
-  dat = fetch(OPENPILOT_MODEL if len(sys.argv) == 1 else sys.argv[1])
-  compile(dat, sys.argv[2] if len(sys.argv) >= 3 else "/tmp/output.thneed")
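For context on the launch-size lines that appear in both the deleted script above and compile2.py below: OpenCL launches want three dimensions, and the cl_cache entries carry the total thread count per dimension, so global is padded to length 3 and multiplied elementwise by local. A minimal sketch of that arithmetic, with made-up sizes:

# Minimal sketch of the launch-size massaging (sizes here are invented).
def thneed_launch_dims(global_size, local_size):
  # pad both to 3 dims, then hand thneed the total threads per dim
  global_size = global_size + [1]*(3-len(global_size))
  local_size = local_size + [1]*(3-len(local_size))
  return [int(g*l) for g, l in zip(global_size, local_size)], local_size

assert thneed_launch_dims([64, 8], [16, 8]) == ([1024, 64, 1], [16, 8, 1])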
@@ -1,12 +1,11 @@
 #!/usr/bin/env python3
-import os, sys, io, pathlib
+import os, sys, io, pathlib, re
 sys.path.insert(0, str(pathlib.Path(__file__).parents[1]))

 if "FLOAT16" not in os.environ: os.environ["FLOAT16"] = "1"
 if "IMAGE" not in os.environ: os.environ["IMAGE"] = "2"
 if "NOLOCALS" not in os.environ: os.environ["NOLOCALS"] = "1"
 if "OPT" not in os.environ: os.environ["OPT"] = "99"
-os.environ["PREREALIZE"] = "0"

 OPENPILOT_MODEL = "https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx"

@@ -55,6 +54,9 @@ def get_schedule(onnx_data) -> Tuple[List[ScheduleItem], List[ScheduleItem]]:
 def schedule_to_thneed(schedule, output_fn):
   from extra.thneed import Thneed

+  print("kernel count:", len(schedule))
+  assert len(schedule) <= getenv("ALLOWED_KERNEL_COUNT", 0) or getenv("ALLOWED_KERNEL_COUNT", 0) == 0, "too many kernels!"
+
   # transform to CL.CACHE
   used_ops = 0
   cl_cache = []
@@ -66,6 +68,10 @@ def schedule_to_thneed(schedule, output_fn):
     setattr(prg.clprg, 'op_estimate', prg.op_estimate)
     setattr(prg.clprg, 'prg', prg.prg)

+    if getenv("VALIDTEST") == 1:
+      src = re.search(r"=.*\?.*?read_image", prg.prg)
+      if src is not None: raise Exception("Openpilot has valid checks!")
+
     global_size = prg.global_size + [1]*(3-len(prg.global_size))
     local_size = prg.local_size + [1]*(3-len(prg.local_size))
     cl_cache.append((prg.clprg, [[int(g*l) for g,l in zip(global_size, local_size)], local_size, *[x.realized._buf for x in args]]))
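The VALIDTEST check greps the generated OpenCL source for a ternary select feeding a read_image call, i.e. an image load guarded by a validity condition. An illustration with invented kernel fragments (the real strings come from prg.prg):

import re

# invented kernel fragments for illustration only
guarded   = "float4 v = (x >= 0) ? read_imagef(img, smp, pos) : (float4)(0.0f);"
unguarded = "float4 v = read_imagef(img, smp, pos);"

pattern = r"=.*\?.*?read_image"   # an '=' then a '?' before a read_image call
assert re.search(pattern, guarded) is not None   # would trip the VALIDTEST error
assert re.search(pattern, unguarded) is None     # a plain image load passes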
@@ -1,2 +1,2 @@
 #!/bin/bash
-NOLOCALS=1 FLOAT16=1 DEBUGCL=1 IMAGE=2 GPU=1 python3 openpilot/compile.py
+NOLOCALS=1 FLOAT16=1 DEBUGCL=1 IMAGE=2 GPU=1 python3 openpilot/compile2.py
@@ -496,6 +496,7 @@ class TestLinearizerOpts(unittest.TestCase):

   def test_padto_matmul(self):
     if not isinstance(Device[Device.DEFAULT], Compiled): self.skipTest("Only Compiled uses linearizer")
+    if Device.DEFAULT == "CUDA": self.skipTest("super slow on CUDA/triton")
     N = 17 * 17
     Tensor.manual_seed(289)
     a = Tensor.rand(N, N)
@@ -3,8 +3,6 @@ from tinygrad.helpers import ImageDType, prod, IMAGE, getenv, dtypes, DEBUG, fla
 # *** image Tensor function replacements ***

-from tinygrad.lazy import get_single_root
-
 def image_dot(self, w):
   # NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
   n1, n2 = len(self.shape), len(w.shape)

@@ -60,7 +58,6 @@ def image_conv2d(self, weight, bias=None, groups=1, stride=1, dilation=1, paddin
   # contiguous creates the image, and early realize static weights (TODO: test for the static weight)
   if IMAGE >= 2: x,w = x.cast(base_image_type((bs*iy, ix*groups*cin//4, 4))), w.cast(base_image_type((cout//4, H*W*cin, 4)))
   x, w = x.contiguous(), w.contiguous()
-  if getenv("PREREALIZE", 1) and get_single_root(w.lazydata).realized: w.realize()

   # expand out
   rcin_hi, rcin_lo = cin//4 if cin >= 4 else 1, 4 if cin >= 4 else 1
@@ -7,7 +7,7 @@ from collections import defaultdict
 from typing import Dict, List
 from tinygrad.ops import ScheduleItem, UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, BufferOps, TernaryOps, Op, OpType, LazyOp
 from tinygrad.helpers import GRAPH, GRAPHPATH, DEBUG, GlobalCounters, getenv, dedup
-from tinygrad.codegen.linearizer import UOps
+from tinygrad.codegen.linearizer import UOps, UOp
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.symbolic import NumNode

@@ -107,7 +107,7 @@ def _tree(lazydata, prefix=""):
 def print_tree(lazydata:LazyOp): print("\n".join([f"{str(i).rjust(3)} {s}" for i,s in enumerate(_tree(lazydata))]))

-def graph_uops(uops):
+def graph_uops(uops:List[UOp]):
   colors = {UOps.ALU: "#ffffc0", UOps.LOAD: "#ffc0c0", UOps.STORE: "#c0ffc0", UOps.SPECIAL: "#c0c0ff", UOps.CONST: "#e0e0e0",
             UOps.DEFINE_GLOBAL: "#ffe0b0", UOps.DEFINE_LOCAL: "#ffe0d0", UOps.DEFINE_ACC: "#f0ffe0",
             UOps.LOOP: "#c8a0e0", UOps.PHI: "#e0ffc0", UOps.BARRIER: "#ff8080", UOps.IF: "#c8b0c0"}

@@ -20,12 +20,12 @@ class TinyJit(Generic[ReturnType]):
     self.cnt: int = 0
     self.ret: Optional[ReturnType] = None
     self.expected_vals: Optional[Tuple[Variable, ...]] = None
-    self.expected_sts_dtype: Optional[Tuple[Tuple[ShapeTracker, DType], ...]] = None
+    self.expected_name_sts_dtype: Optional[Tuple[Tuple[Union[int, str], ShapeTracker, DType], ...]] = None

   @property
   def jit_cache(self) -> List[JitItem]: return self.jit_fxn.jit_cache if self.jit_fxn else []
   @property
-  def input_replace(self) -> Dict[Tuple[int, int], Union[int, str]]: return self.jit_fxn.input_replace if self.jit_fxn else {}
+  def input_replace(self) -> Dict[Tuple[int, int], int]: return self.jit_fxn.input_replace if self.jit_fxn else {}

   # add support for instance methods
   def __get__(self, obj, objtype): return functools.partial(self.__call__, obj)
@@ -33,11 +33,11 @@ class TinyJit(Generic[ReturnType]):
   def __call__(self, *args, **kwargs) -> ReturnType:
     # all inputs are realized
     input_tensors: Dict[Union[int, str], Tensor] = {cast(Union[int, str], k):v.realize() for k,v in itertools.chain(enumerate(args), kwargs.items()) if v.__class__ is Tensor}
-    expected_sts_dtype = tuple([(v.lazydata.st.unbind(), v.dtype) for v in input_tensors.values()])
+    expected_name_sts_dtype = tuple([(k, v.lazydata.st.unbind(), v.dtype) for k,v in input_tensors.items()])

     # get rawbuffers
-    input_rawbuffers: Dict[Union[int, str], RawBuffer] = {k:cast(RawBuffer, v.lazydata.realized) for k,v in input_tensors.items()}
-    assert len(set(input_rawbuffers.values())) == len(input_rawbuffers), "duplicate inputs to JIT"
+    input_rawbuffers: List[RawBuffer] = [cast(RawBuffer, v.lazydata.realized) for v in input_tensors.values()]
+    assert len(set(input_rawbuffers)) == len(input_rawbuffers), "duplicate inputs to JIT"

     # get variables: they can either be in Tensors or passed in as arguments, and all must be bound. these are all global
     var_vals: Dict[Variable, int] = merge_dicts([arg.lazydata.st.var_vals for arg in input_tensors.values()] + [dict(x.unbind() for x in itertools.chain(args, kwargs.values()) if isinstance(x, Variable))])
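How __call__ orders inputs after this change, sketched with invented values: positional args get integer keys and kwargs keep their names, but only the resulting order feeds the rawbuffer list; the names survive solely inside the expected_name_sts_dtype signature check.

import itertools

# invented stand-in values; real code chains enumerate(args) with kwargs.items()
args, kwargs = ("bufA", "bufB"), {"desire": "bufC"}
items = list(itertools.chain(enumerate(args), kwargs.items()))
assert [k for k, _ in items] == [0, 1, "desire"]           # keys: signature check only
assert [v for _, v in items] == ["bufA", "bufB", "bufC"]   # order: positional buffer list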
@@ -45,11 +45,11 @@ class TinyJit(Generic[ReturnType]):

     if self.cnt >= 2:
       assert self.expected_vals == expected_vals, "mismatch of var_vals"
-      assert self.expected_sts_dtype == expected_sts_dtype, f"mismatch of sts, expected {self.expected_sts_dtype} got {expected_sts_dtype}"
+      assert self.expected_name_sts_dtype == expected_name_sts_dtype, f"mismatch of sts, expected {self.expected_name_sts_dtype} got {expected_name_sts_dtype}"
       assert self.jit_fxn, "didn't get jitted?"
       self.jit_fxn(input_rawbuffers, var_vals, DEBUG>=2)
     elif self.cnt == 1:
-      self.expected_vals, self.expected_sts_dtype = expected_vals, expected_sts_dtype
+      self.expected_vals, self.expected_name_sts_dtype = expected_vals, expected_name_sts_dtype

       CacheCollector.start(var_vals)
       self.ret = self.fxn(*args, **kwargs)
@@ -76,7 +76,6 @@ def _replace_bufferops(op:LazyOp) -> Tuple[LazyOp, List[LazyBuffer]]:

 # **** lazy operations ****

-def get_single_root(root:LazyBuffer) -> LazyBuffer: return get_single_root(root.op.src[0]) if getattr(root, 'op', None) and len(root.op.src) == 1 and isinstance(root.op.src[0], LazyBuffer) else root
 def get_movementroot(root:LazyBuffer, allow_contiguous=False) -> LazyBuffer: return get_movementroot(cast(LazyBuffer, root.op.src[0]), allow_contiguous) if not root.realized and (root.optype == MovementOps or (root.op.op == LoadOps.CONTIGUOUS and allow_contiguous and root.op.src[0].st.contiguous)) else root
 def get_movementroot_contiguous(x:LazyBuffer) -> LazyBuffer: return get_movementroot_contiguous(cast(LazyBuffer, x.op.src[0])) if not x.realized and x.op.op == LoadOps.CONTIGUOUS else (get_movementroot(x, True) if x.optype == MovementOps and x.st.contiguous else x)

@@ -37,7 +37,7 @@ class MemBuffer:

 @dataclass(frozen=True)
 class ConstBuffer:
-  val: Any
+  val: Union[int, float]
  dtype: DType
   st: ShapeTracker

@@ -152,22 +152,22 @@ class JitItem:
   rawbufs: List[Optional[RawBuffer]]

 class BatchExecutor:
-  def __init__(self, jit_cache: List[JitItem], input_rawbuffers: Dict[Union[int, str], RawBuffer], var_vals: Dict[Variable, int]):
+  def __init__(self, jit_cache: List[JitItem], input_rawbuffers: List[RawBuffer], var_vals: Dict[Variable, int]):
     self.jit_cache: List[JitItem] = jit_cache
-    self.input_replace: Dict[Tuple[int, int], Union[int, str]] = {}
+    self.input_replace: Dict[Tuple[int, int], int] = {}
     self.op_estimate, self.mem_estimate = NumNode(0), NumNode(0)
     for j,ji in enumerate(jit_cache):
       if isinstance(ji.prg, ASTRunner): # TODO: this is just for world and needs to be refactored
         self.op_estimate += ji.prg.op_estimate
         self.mem_estimate += ji.prg.mem_estimate
       for i,a in enumerate(ji.rawbufs):
-        if a in [v for v in input_rawbuffers.values()]:
-          self.input_replace[(j,i)] = [k for k,v in input_rawbuffers.items() if v == a][0]
-    assert set(self.input_replace.values()) == set(input_rawbuffers.keys()), "some input tensors not found"
+        if a in input_rawbuffers:
+          self.input_replace[(j,i)] = input_rawbuffers.index(a)
+    assert len(set(self.input_replace.values())) == len(input_rawbuffers), "some input tensors not found"
     self.clear_jit_inputs()

-  def __call__(self, input_rawbuffers: Dict[Union[int, str], RawBuffer], var_vals: Dict[Variable, int], wait=False):
-    for (j,i),input_name in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_name]
+  def __call__(self, input_rawbuffers: List[RawBuffer], var_vals: Dict[Variable, int], wait=False):
+    for (j,i),input_idx in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_idx]
     for ji in self.jit_cache: ji.prg(ji.rawbufs, var_vals, jit=True)
     self.clear_jit_inputs()

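A self-contained sketch of the positional input_replace protocol above, with toy stand-ins for JitItem and RawBuffer: scan the cache once to record where each example input landed as (kernel, slot) -> input index, then patch fresh buffers in by index on every call.

from typing import Dict, List, Tuple

class ToyExecutor:
  def __init__(self, jit_cache: List[List[str]], input_bufs: List[str]):
    self.jit_cache = jit_cache
    self.input_replace: Dict[Tuple[int, int], int] = {}
    # one-time scan: remember which cache slots held the example inputs
    for j, rawbufs in enumerate(jit_cache):
      for i, a in enumerate(rawbufs):
        if a in input_bufs: self.input_replace[(j, i)] = input_bufs.index(a)
    assert len(set(self.input_replace.values())) == len(input_bufs), "some input tensors not found"

  def __call__(self, input_bufs: List[str]):
    # patch this call's buffers into the recorded slots, then "run" the cache
    for (j, i), idx in self.input_replace.items(): self.jit_cache[j][i] = input_bufs[idx]
    return self.jit_cache

ex = ToyExecutor([["w0", "in0"], ["in1", "w1"]], ["in0", "in1"])
assert ex(["newA", "newB"]) == [["w0", "newA"], ["newB", "w1"]]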
@@ -1,7 +1,7 @@
 # pip3 install pyobjc-framework-Metal pyobjc-framework-Cocoa pyobjc-framework-libdispatch
 import os, subprocess, pathlib, ctypes, tempfile
 import Metal, libdispatch
-from typing import List, Any, Tuple, Dict, Union, Set, cast
+from typing import List, Any, Tuple, Dict, Set, cast
 from tinygrad.codegen.kernel import LinearizerOptions
 from tinygrad.helpers import prod, getenv, DEBUG, DType, dtypes, diskcache, dedup
 from tinygrad.ops import Compiled, BatchExecutor, JitItem, CompiledASTRunner, update_stats
@@ -85,7 +85,7 @@ class MetalProgram:
     METAL.mtl_buffers_in_flight.append(command_buffer)

 class MetalBatchExecutor(BatchExecutor):
-  def __init__(self, jit_cache: List[JitItem], input_rawbuffers: Dict[Union[int, str], RawBuffer], var_vals: Dict[Variable, int]):
+  def __init__(self, jit_cache: List[JitItem], input_rawbuffers: List[RawBuffer], var_vals: Dict[Variable, int]):
     super().__init__(jit_cache, input_rawbuffers, var_vals)

     # create metal batch exec
@@ -127,12 +127,12 @@ class MetalBatchExecutor(BatchExecutor):
     self.command_buffer: Any = None
     self.int_buf_view = self.int_buf.buffer_view() # TODO: this is metal syncing when it doesn't need to

-  def __call__(self, input_rawbuffers: Dict[Union[int, str], RawBuffer], var_vals: Dict[Variable, int], wait=False):
+  def __call__(self, input_rawbuffers: List[RawBuffer], var_vals: Dict[Variable, int], wait=False):
     # NOTE: you at least can't update the ints if this is running
     if self.command_buffer is not None and self.command_buffer in METAL.mtl_buffers_in_flight: self.command_buffer.waitUntilCompleted()
-    all_read_resources = self.read_resources + [x._buf for x in input_rawbuffers.values()]
-    for (j,i),input_name in self.input_replace.items():
-      self.icb.indirectComputeCommandAtIndex_(j).setKernelBuffer_offset_atIndex_(input_rawbuffers[input_name]._buf, 0, i)
+    all_read_resources = self.read_resources + [x._buf for x in input_rawbuffers]
+    for (j,i),input_idx in self.input_replace.items():
+      self.icb.indirectComputeCommandAtIndex_(j).setKernelBuffer_offset_atIndex_(input_rawbuffers[input_idx]._buf, 0, i)
     for j in self.input_has_variable_dims:
       global_size, local_size = cast(CompiledASTRunner, self.jit_cache[j].prg).launch_dims(var_vals)
       self.icb.indirectComputeCommandAtIndex_(j).concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))