metal indirect command buffers (#2285)
* metal indirect command buffers * sub 1ms gpt * metal batch exec is good * remove whitespace * input_replace * fix ci * useResources * very simple cacheallocator * update_stats * fix CI * minor * remove that from jitpull/2291/head
parent
d86ea188dd
commit
b1f7f29525
|
@ -42,9 +42,9 @@ def jit_model(model, *args) -> Tuple[TinyJit,Dict[int,str]]:
|
|||
|
||||
# hack to put the inputs back
|
||||
for (j,i),idx in run.input_replace.items():
|
||||
realized_input = args[idx[0]].lazydata.realized
|
||||
realized_input = args[idx].lazydata.realized
|
||||
run.jit_cache[j].rawbufs[i] = realized_input
|
||||
special_names[id(realized_input)] = f'input{idx[0]}'
|
||||
special_names[id(realized_input)] = f'input{idx}'
|
||||
|
||||
# TODO: fetch this from the jit in self.input_replace and self.ret (hint: use get_parameters on self.ret)
|
||||
for i, output in enumerate(the_output):
|
||||
|
|
|
@ -63,7 +63,7 @@ def compile(dat, output_fn):
|
|||
|
||||
# pull out inputs and put them in the jit cache
|
||||
input_rawbuffers = {k:inputs[k].lazydata.realized for k in inputs.keys()}
|
||||
for (j,i),(idx,_,_) in model_exec.input_replace.items(): model_exec.jit_cache[j].rawbufs[i] = input_rawbuffers[idx]
|
||||
for (j,i),idx in model_exec.input_replace.items(): model_exec.jit_cache[j].rawbufs[i] = input_rawbuffers[idx]
|
||||
|
||||
# transform to CL.CACHE
|
||||
used_ops = 0
|
||||
|
|
|
@ -30,7 +30,6 @@ def run():
|
|||
|
||||
# reset jit
|
||||
allreduce_jit.cnt = 0
|
||||
allreduce_jit.input_replace = {}
|
||||
|
||||
# test uneven chunk sizes
|
||||
for _ in range(3):
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
from __future__ import annotations
|
||||
from typing import Callable, List, Tuple, Any, Dict, cast, Union, Optional
|
||||
import functools, itertools
|
||||
from tinygrad.helpers import DEBUG, DType, merge_dicts
|
||||
from tinygrad.helpers import DEBUG, DType, merge_dicts, GlobalCounters, getenv, colored
|
||||
from tinygrad.ops import RawBuffer, Device, ASTRunner
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
from tinygrad.shape.symbolic import Variable, NumNode, sym_infer
|
||||
from dataclasses import dataclass
|
||||
from weakref import ref, WeakKeyDictionary
|
||||
|
||||
|
@ -16,13 +16,54 @@ class JitItem:
|
|||
prg: ASTRunner
|
||||
rawbufs: List[Optional[RawBuffer]]
|
||||
|
||||
class BatchExecutor:
|
||||
def __init__(self, jit_cache: List[JitItem], input_rawbuffers: Dict[Union[int, str], RawBuffer], var_vals: Dict[Variable, int]):
|
||||
self.jit_cache: List[JitItem] = jit_cache
|
||||
self.input_replace: Dict[Tuple[int, int], Union[int, str]] = {}
|
||||
self.op_estimate, self.mem_estimate = NumNode(0), NumNode(0)
|
||||
for j,ji in enumerate(jit_cache):
|
||||
if isinstance(ji.prg, ASTRunner): # TODO: this is just for world and needs to be refactored
|
||||
self.op_estimate += ji.prg.op_estimate
|
||||
self.mem_estimate += ji.prg.mem_estimate
|
||||
for i,a in enumerate(ji.rawbufs):
|
||||
if a in [v for v in input_rawbuffers.values()]:
|
||||
self.input_replace[(j,i)] = [k for k,v in input_rawbuffers.items() if v == a][0]
|
||||
assert set(self.input_replace.values()) == set(input_rawbuffers.keys()), "some input tensors not found"
|
||||
self.clear_jit_inputs()
|
||||
|
||||
def __call__(self, input_rawbuffers: Dict[Union[int, str], RawBuffer], var_vals: Dict[Variable, int], wait=False):
|
||||
for (j,i),input_name in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_name]
|
||||
for ji in self.jit_cache: ji.prg(cast(List[RawBuffer], ji.rawbufs), {v:var_vals[v] for v in getattr(ji.prg,"vars",[])}, jit=True)
|
||||
self.clear_jit_inputs()
|
||||
|
||||
def update_stats(self, var_vals: Dict[Variable, int], et: Optional[float]):
|
||||
# TODO: this is mostly copied from ASTRunner
|
||||
op_estimate = sym_infer(self.op_estimate, var_vals)
|
||||
mem_estimate = sym_infer(self.mem_estimate, var_vals)
|
||||
if DEBUG >= 2:
|
||||
print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'CYAN')} kernels:{len(self.jit_cache):4d} inputs:{len(self.input_replace):3d} {' '.join([f'{k.expr}={v}' for k,v in var_vals.items()])[:50]:50s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
|
||||
(str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
|
||||
GlobalCounters.kernel_count += len(self.jit_cache)
|
||||
GlobalCounters.global_ops += sym_infer(self.op_estimate, var_vals)
|
||||
GlobalCounters.global_mem += sym_infer(self.mem_estimate, var_vals)
|
||||
if et is not None: GlobalCounters.time_sum_s += et
|
||||
|
||||
def clear_jit_inputs(self):
|
||||
for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None
|
||||
|
||||
class TinyJit:
|
||||
def __init__(self, fxn:Callable):
|
||||
self.fxn: Callable = fxn
|
||||
self.jit_fxn: Optional[BatchExecutor] = None
|
||||
self.cnt: int = 0
|
||||
self.jit_cache: List[JitItem] = []
|
||||
self.ret: Any = None
|
||||
self.input_replace: Dict[Tuple[int, int], Tuple[Union[int, str], ShapeTracker, DType]] = {} # (kernel_number, buffer_number) -> (input_name, expected_shapetracker, expected_type)
|
||||
self.expected_vals: Optional[Tuple[Variable, ...]] = None
|
||||
self.expected_sts_dtype: Optional[Tuple[Tuple[ShapeTracker, DType], ...]] = None
|
||||
|
||||
@property
|
||||
def jit_cache(self) -> List[JitItem]: return self.jit_fxn.jit_cache if self.jit_fxn else []
|
||||
@property
|
||||
def input_replace(self) -> Dict[Tuple[int, int], Union[int, str]]: return self.jit_fxn.input_replace if self.jit_fxn else {}
|
||||
|
||||
# add support for instance methods
|
||||
def __get__(self, obj, objtype): return functools.partial(self.__call__, obj)
|
||||
|
@ -32,39 +73,36 @@ class TinyJit:
|
|||
|
||||
# all inputs are realized
|
||||
input_tensors: Dict[Union[int, str], Tensor] = {cast(Union[int, str], k):v.realize() for k,v in itertools.chain(enumerate(args), kwargs.items()) if v.__class__ is Tensor}
|
||||
expected_sts_dtype = tuple([(v.lazydata.st.unbind(), v.dtype) for v in input_tensors.values()])
|
||||
|
||||
# get rawbuffers
|
||||
input_rawbuffers: Dict[Union[int, str], Tuple[RawBuffer, ShapeTracker]] = {k:(cast(RawBuffer, v.lazydata.realized), v.lazydata.st) for k,v in input_tensors.items()}
|
||||
input_rawbuffers: Dict[Union[int, str], RawBuffer] = {k:cast(RawBuffer, v.lazydata.realized) for k,v in input_tensors.items()}
|
||||
assert len(input_rawbuffers) != 0, "no inputs to JIT"
|
||||
assert len(set(input_rawbuffers.values())) == len(input_rawbuffers), "duplicate inputs to JIT"
|
||||
|
||||
# get variables: they can either be in Tensors or passed in as arguments, and all must be bound. these are all global
|
||||
var_vals: Dict[Variable, int] = merge_dicts([arg.lazydata.st.var_vals for arg in input_tensors.values()] + [dict(x.unbind() for x in itertools.chain(args, kwargs.values()) if isinstance(x, Variable))])
|
||||
expected_vals = tuple(var_vals.keys())
|
||||
|
||||
if self.cnt >= 2:
|
||||
# check validity and assign the inputs
|
||||
for (j,i),(input_name, expected_st, expected_type) in self.input_replace.items():
|
||||
assert input_rawbuffers[input_name][0].dtype == expected_type, f"type mismatch in JIT, {input_rawbuffers[input_name][0].dtype} != {expected_type}"
|
||||
assert input_rawbuffers[input_name][1].unbind() == expected_st, f"ShapeTracker mismatch in JIT, {input_rawbuffers[input_name][1].unbind()} != {expected_st}"
|
||||
self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_name][0]
|
||||
for ji in self.jit_cache: ji.prg(cast(List[RawBuffer], ji.rawbufs), {v:var_vals[v] for v in getattr(ji.prg,"vars",[])}, jit=True)
|
||||
assert self.expected_vals == expected_vals, "mismatch of var_vals"
|
||||
assert self.expected_sts_dtype == expected_sts_dtype, "mismatch of sts"
|
||||
assert self.jit_fxn, "didn't get jitted?"
|
||||
self.jit_fxn(input_rawbuffers, var_vals, DEBUG>=2)
|
||||
elif self.cnt == 1:
|
||||
self.expected_vals, self.expected_sts_dtype = expected_vals, expected_sts_dtype
|
||||
|
||||
CacheCollector.start(var_vals)
|
||||
self.ret = self.fxn(*args, **kwargs)
|
||||
self.jit_cache = CacheCollector.finish()
|
||||
assert len(self.jit_cache) != 0, "didn't JIT anything!"
|
||||
if DEBUG >= 1: print(f"JIT captured {len(self.jit_cache)} kernels with {len(input_rawbuffers)} inputs")
|
||||
# get the inputs for replacement
|
||||
for j,ji in enumerate(self.jit_cache):
|
||||
for i,a in enumerate(ji.rawbufs):
|
||||
if a in [v[0] for v in input_rawbuffers.values()]:
|
||||
self.input_replace[(j,i)] = [(k, v[1].unbind(), v[0].dtype) for k,v in input_rawbuffers.items() if v[0] == a][0]
|
||||
assert set([x[0] for x in self.input_replace.values()]) == set(input_rawbuffers.keys()), "some input tensors not found"
|
||||
jit_cache = CacheCollector.finish()
|
||||
assert len(jit_cache) != 0, "didn't JIT anything!"
|
||||
if DEBUG >= 1: print(f"JIT captured {len(jit_cache)} kernels with {len(input_rawbuffers)} inputs")
|
||||
|
||||
alt_batch_exec = Device[Device.DEFAULT].batch_executor
|
||||
self.jit_fxn = (BatchExecutor if alt_batch_exec is None or getenv("JIT") == 2 else alt_batch_exec)(jit_cache, input_rawbuffers, var_vals)
|
||||
elif self.cnt == 0:
|
||||
self.ret = self.fxn(*args, **kwargs)
|
||||
|
||||
# clear the inputs
|
||||
for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None
|
||||
self.cnt += 1
|
||||
return self.ret
|
||||
|
||||
|
|
|
@ -112,6 +112,7 @@ class Interpreted:
|
|||
def __init__(self, buffer, fxn_for_op: Dict[Op, Callable], from_underlying=None):
|
||||
self.buffer, self.fxn_for_op, self.from_underlying = buffer, fxn_for_op, from_underlying
|
||||
self.synchronize = lambda: None
|
||||
self.batch_executor = None
|
||||
self.codegen = None
|
||||
self.method_cache: Dict[LazyOp, Callable] = {}
|
||||
|
||||
|
@ -232,8 +233,8 @@ class ASTRunner:
|
|||
return et
|
||||
|
||||
class Compiled:
|
||||
def __init__(self, buffer: Type[RawBuffer], linearizer_opts, renderer, compiler, runtime, synchronize=lambda: None):
|
||||
self.buffer, self.linearizer_opts, self.renderer, self.compiler, self.runtime, self.synchronize = buffer, linearizer_opts, renderer, compiler, runtime, synchronize
|
||||
def __init__(self, buffer: Type[RawBuffer], linearizer_opts, renderer, compiler, runtime, synchronize=lambda: None, batch_executor=None):
|
||||
self.buffer, self.linearizer_opts, self.renderer, self.compiler, self.runtime, self.synchronize, self.batch_executor = buffer, linearizer_opts, renderer, compiler, runtime, synchronize, batch_executor
|
||||
self.method_cache: Dict[LazyOp, ASTRunner] = {}
|
||||
|
||||
def to_program(self, k):
|
||||
|
|
|
@ -38,7 +38,7 @@ class RawBufferCopyIn(RawBuffer):
|
|||
class RawBufferMapped(RawBufferCopyIn):
|
||||
def _buffer(self) -> memoryview: raise NotImplementedError("must be implemented")
|
||||
# NOTE: this metadata prevents the backing buffer from being freed. hack can be removed with PEP688
|
||||
def buffer_view(self) -> np.ndarray: return np.frombuffer(self._buffer(), dtype=np.dtype(self.dtype.np, metadata={"backing": self}), count=self.size)
|
||||
def buffer_view(self) -> np.ndarray: return np.frombuffer(self._buffer(), dtype=np.dtype(self.dtype.np, metadata={"backing": self}), count=self.size) # type: ignore
|
||||
def toCPU(self) -> np.ndarray: return self.buffer_view().copy() # Need a copy, since jit will write to the same buffer.
|
||||
def _copyin(self, x:np.ndarray) -> None: np.copyto(self.buffer_view(), x.reshape(-1))
|
||||
|
||||
|
@ -83,8 +83,8 @@ class LRUAllocator:
|
|||
|
||||
def _alloc_buffer(self, size, dtype, device, **kwargs):
|
||||
self.ensure_has_free_space(size*dtype.itemsize, device)
|
||||
while True:
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
newbuf = self._do_alloc(max(1, size), dtype, device, **kwargs)
|
||||
break
|
||||
except Exception:
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
# pip3 install pyobjc-framework-Metal pyobjc-framework-Cocoa pyobjc-framework-libdispatch
|
||||
import os, subprocess, pathlib, ctypes, tempfile
|
||||
import Metal, Cocoa, libdispatch
|
||||
from typing import List, Any, Tuple
|
||||
from typing import List, Any, Tuple, Dict, Union, Set
|
||||
from tinygrad.codegen.kernel import LinearizerOptions
|
||||
from tinygrad.helpers import prod, getenv, DEBUG, DType, dtypes, diskcache
|
||||
from tinygrad.helpers import prod, getenv, DEBUG, DType, dtypes, diskcache, dedup, CI
|
||||
from tinygrad.ops import Compiled
|
||||
from tinygrad.renderer.metal import MetalRenderer
|
||||
from tinygrad.runtime.lib import RawBufferMapped, LRUAllocator
|
||||
from tinygrad.runtime.lib import RawBufferMapped, RawBuffer, LRUAllocator
|
||||
|
||||
class MetalAllocator(LRUAllocator):
|
||||
def _do_alloc(self, size, dtype, device, **kwargs): return METAL.device.newBufferWithLength_options_(size*dtype.itemsize, Metal.MTLResourceStorageModeShared)
|
||||
|
@ -44,7 +44,7 @@ def compile_metal(prg, use_xcode=bool(getenv("METAL_XCODE"))) -> bytes:
|
|||
# NOTE: if you run llvm-dis on "air" you can see the llvm bytecode
|
||||
air = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metal', '-x', 'metal', '-c', '-', '-o', '-'], input=prg.encode('utf-8'))
|
||||
return subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metallib', '-', '-o', '-'], input=air)
|
||||
options = Metal.MTLCompileOptions.alloc().init()
|
||||
options = Metal.MTLCompileOptions.new()
|
||||
library = unwrap(METAL.device.newLibraryWithSource_options_error_(prg, options, None))
|
||||
# TODO: avoid file write here?
|
||||
with tempfile.NamedTemporaryFile(delete=True) as output_file:
|
||||
|
@ -80,4 +80,75 @@ class MetalProgram:
|
|||
return command_buffer.GPUEndTime() - command_buffer.GPUStartTime()
|
||||
METAL.mtl_buffers_in_flight.append(command_buffer)
|
||||
|
||||
MetalBuffer = Compiled(RawMetalBuffer, LinearizerOptions(device="METAL"), MetalRenderer, compile_metal, MetalProgram, METAL.synchronize)
|
||||
from tinygrad.jit import BatchExecutor, JitItem
|
||||
from tinygrad.shape.symbolic import Variable, Node
|
||||
class MetalBatchExecutor(BatchExecutor):
|
||||
def __init__(self, jit_cache: List[JitItem], input_rawbuffers: Dict[Union[int, str], RawBuffer], var_vals: Dict[Variable, int]):
|
||||
super().__init__(jit_cache, input_rawbuffers, var_vals)
|
||||
|
||||
# create metal batch exec
|
||||
icb_descriptor = Metal.MTLIndirectCommandBufferDescriptor.new()
|
||||
icb_descriptor.setCommandTypes_(Metal.MTLIndirectCommandType(Metal.MTLIndirectCommandTypeConcurrentDispatch))
|
||||
icb_descriptor.setInheritBuffers_(False)
|
||||
icb_descriptor.setInheritPipelineState_(False)
|
||||
icb_descriptor.setMaxKernelBufferBindCount_(31)
|
||||
self.icb = METAL.device.newIndirectCommandBufferWithDescriptor_maxCommandCount_options_(icb_descriptor, len(self.jit_cache), Metal.MTLResourceOptions(0))
|
||||
assert self.icb is not None, "create indirect command buffer failed, does your system support this?"
|
||||
|
||||
self.int_buf = RawMetalBuffer(len(var_vals), dtypes.int32)
|
||||
self.input_has_variable_dims: Set[int] = set()
|
||||
read_resources, write_resources = [], []
|
||||
for j,ji in enumerate(self.jit_cache):
|
||||
descriptor = Metal.MTLComputePipelineDescriptor.new()
|
||||
descriptor.setComputeFunction_(ji.prg.clprg.fxn)
|
||||
descriptor.setSupportIndirectCommandBuffers_(True)
|
||||
pipeline_state = unwrap(METAL.device.newComputePipelineStateWithDescriptor_options_reflection_error_(descriptor, Metal.MTLPipelineOption(0), None, None))
|
||||
icb_command = self.icb.indirectComputeCommandAtIndex_(j)
|
||||
icb_command.setComputePipelineState_(pipeline_state)
|
||||
for i,b in enumerate(ji.rawbufs):
|
||||
if b is not None:
|
||||
icb_command.setKernelBuffer_offset_atIndex_(b._buf, 0, i)
|
||||
if i == 0: write_resources.append(b._buf)
|
||||
else: read_resources.append(b._buf)
|
||||
var_vals_keys = list(var_vals.keys())
|
||||
for i,v in enumerate(getattr(ji.prg,"vars",[])):
|
||||
icb_command.setKernelBuffer_offset_atIndex_(self.int_buf._buf, var_vals_keys.index(v)*4, len(ji.rawbufs)+i)
|
||||
global_size, local_size = ji.prg.launch_dims(var_vals)
|
||||
assert ji.prg.global_size and ji.prg.local_size, "need global and local size to JIT"
|
||||
if any(isinstance(x, Node) for x in ji.prg.global_size) or any(isinstance(x, Node) for x in ji.prg.local_size):
|
||||
self.input_has_variable_dims.add(j)
|
||||
else:
|
||||
icb_command.concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
|
||||
icb_command.setBarrier()
|
||||
self.read_resources, self.write_resources = dedup(read_resources), dedup(write_resources)
|
||||
self.command_buffer: Any = None
|
||||
self.int_buf_view = self.int_buf.buffer_view() # TODO: this is metal syncing when it doesn't need to
|
||||
|
||||
def __call__(self, input_rawbuffers: Dict[Union[int, str], RawBuffer], var_vals: Dict[Variable, int], wait=False):
|
||||
# NOTE: you at least can't update the ints if this is running
|
||||
if self.command_buffer is not None and self.command_buffer in METAL.mtl_buffers_in_flight: self.command_buffer.waitUntilCompleted()
|
||||
all_read_resources = self.read_resources + [x._buf for x in input_rawbuffers.values()]
|
||||
for (j,i),input_name in self.input_replace.items():
|
||||
self.icb.indirectComputeCommandAtIndex_(j).setKernelBuffer_offset_atIndex_(input_rawbuffers[input_name]._buf, 0, i)
|
||||
for j in self.input_has_variable_dims:
|
||||
global_size, local_size = self.jit_cache[j].prg.launch_dims(var_vals)
|
||||
self.icb.indirectComputeCommandAtIndex_(j).concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
|
||||
self.int_buf_view[:] = list(var_vals.values())
|
||||
command_buffer = METAL.mtl_queue.commandBuffer()
|
||||
encoder = command_buffer.computeCommandEncoder()
|
||||
encoder.executeCommandsInBuffer_withRange_(self.icb, Metal.MTLIndirectCommandBufferExecutionRangeMake(0,len(self.jit_cache)))
|
||||
encoder.useResources_count_usage_(all_read_resources, len(all_read_resources), Metal.MTLResourceUsageRead)
|
||||
encoder.useResources_count_usage_(self.write_resources, len(self.write_resources), Metal.MTLResourceUsageWrite)
|
||||
encoder.endEncoding()
|
||||
command_buffer.commit()
|
||||
self.command_buffer = command_buffer
|
||||
if wait:
|
||||
command_buffer.waitUntilCompleted()
|
||||
et = command_buffer.GPUEndTime() - command_buffer.GPUStartTime()
|
||||
else:
|
||||
METAL.mtl_buffers_in_flight.append(command_buffer)
|
||||
et = None
|
||||
super().update_stats(var_vals, et)
|
||||
return et
|
||||
|
||||
MetalBuffer = Compiled(RawMetalBuffer, LinearizerOptions(device="METAL"), MetalRenderer, compile_metal, MetalProgram, METAL.synchronize, batch_executor=MetalBatchExecutor if not CI else None)
|
||||
|
|
Loading…
Reference in New Issue