metal indirect command buffers (#2285)

* metal indirect command buffers

* sub 1ms gpt

* metal batch exec is good

* remove whitespace

* input_replace

* fix ci

* useResources

* very simple cacheallocator

* update_stats

* fix CI

* minor

* remove that from jit
pull/2291/head
George Hotz 2023-11-13 17:58:26 -08:00 committed by GitHub
parent d86ea188dd
commit b1f7f29525
7 changed files with 145 additions and 36 deletions
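
In outline: the JIT now captures kernels as JitItem entries and hands the whole list to a BatchExecutor, and a backend can supply its own executor through the new batch_executor argument to Compiled (Metal does, using indirect command buffers). A rough sketch of the interface the tinygrad/jit.py diff below introduces, bodies elided, for orientation only:

class BatchExecutor:
  def __init__(self, jit_cache, input_rawbuffers, var_vals): ...  # find where each input lands in jit_cache, precompute op/mem estimates
  def __call__(self, input_rawbuffers, var_vals, wait=False): ...  # patch the inputs in, run every cached kernel, clear the input slots
  def update_stats(self, var_vals, et): ...                        # DEBUG>=2 printout and GlobalCounters bookkeeping
  def clear_jit_inputs(self): ...                                  # drop references to input buffers between calls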

View File

@@ -42,9 +42,9 @@ def jit_model(model, *args) -> Tuple[TinyJit,Dict[int,str]]:
# hack to put the inputs back
for (j,i),idx in run.input_replace.items():
realized_input = args[idx[0]].lazydata.realized
realized_input = args[idx].lazydata.realized
run.jit_cache[j].rawbufs[i] = realized_input
special_names[id(realized_input)] = f'input{idx[0]}'
special_names[id(realized_input)] = f'input{idx}'
# TODO: fetch this from the jit in self.input_replace and self.ret (hint: use get_parameters on self.ret)
for i, output in enumerate(the_output):

View File

@@ -63,7 +63,7 @@ def compile(dat, output_fn):
# pull out inputs and put them in the jit cache
input_rawbuffers = {k:inputs[k].lazydata.realized for k in inputs.keys()}
for (j,i),(idx,_,_) in model_exec.input_replace.items(): model_exec.jit_cache[j].rawbufs[i] = input_rawbuffers[idx]
for (j,i),idx in model_exec.input_replace.items(): model_exec.jit_cache[j].rawbufs[i] = input_rawbuffers[idx]
# transform to CL.CACHE
used_ops = 0

View File

@@ -30,7 +30,6 @@ def run():
# reset jit
allreduce_jit.cnt = 0
allreduce_jit.input_replace = {}
# test uneven chunk sizes
for _ in range(3):

View File

@@ -1,11 +1,11 @@
from __future__ import annotations
from typing import Callable, List, Tuple, Any, Dict, cast, Union, Optional
import functools, itertools
from tinygrad.helpers import DEBUG, DType, merge_dicts
from tinygrad.helpers import DEBUG, DType, merge_dicts, GlobalCounters, getenv, colored
from tinygrad.ops import RawBuffer, Device, ASTRunner
from tinygrad.tensor import Tensor
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import Variable
from tinygrad.shape.symbolic import Variable, NumNode, sym_infer
from dataclasses import dataclass
from weakref import ref, WeakKeyDictionary
@@ -16,13 +16,54 @@ class JitItem:
prg: ASTRunner
rawbufs: List[Optional[RawBuffer]]
class BatchExecutor:
def __init__(self, jit_cache: List[JitItem], input_rawbuffers: Dict[Union[int, str], RawBuffer], var_vals: Dict[Variable, int]):
self.jit_cache: List[JitItem] = jit_cache
self.input_replace: Dict[Tuple[int, int], Union[int, str]] = {}
self.op_estimate, self.mem_estimate = NumNode(0), NumNode(0)
for j,ji in enumerate(jit_cache):
if isinstance(ji.prg, ASTRunner): # TODO: this is just for world and needs to be refactored
self.op_estimate += ji.prg.op_estimate
self.mem_estimate += ji.prg.mem_estimate
for i,a in enumerate(ji.rawbufs):
if a in [v for v in input_rawbuffers.values()]:
self.input_replace[(j,i)] = [k for k,v in input_rawbuffers.items() if v == a][0]
assert set(self.input_replace.values()) == set(input_rawbuffers.keys()), "some input tensors not found"
self.clear_jit_inputs()
def __call__(self, input_rawbuffers: Dict[Union[int, str], RawBuffer], var_vals: Dict[Variable, int], wait=False):
for (j,i),input_name in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_name]
for ji in self.jit_cache: ji.prg(cast(List[RawBuffer], ji.rawbufs), {v:var_vals[v] for v in getattr(ji.prg,"vars",[])}, jit=True)
self.clear_jit_inputs()
def update_stats(self, var_vals: Dict[Variable, int], et: Optional[float]):
# TODO: this is mostly copied from ASTRunner
op_estimate = sym_infer(self.op_estimate, var_vals)
mem_estimate = sym_infer(self.mem_estimate, var_vals)
if DEBUG >= 2:
print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'CYAN')} kernels:{len(self.jit_cache):4d} inputs:{len(self.input_replace):3d} {' '.join([f'{k.expr}={v}' for k,v in var_vals.items()])[:50]:50s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
(str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
GlobalCounters.kernel_count += len(self.jit_cache)
GlobalCounters.global_ops += sym_infer(self.op_estimate, var_vals)
GlobalCounters.global_mem += sym_infer(self.mem_estimate, var_vals)
if et is not None: GlobalCounters.time_sum_s += et
def clear_jit_inputs(self):
for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None
class TinyJit:
def __init__(self, fxn:Callable):
self.fxn: Callable = fxn
self.jit_fxn: Optional[BatchExecutor] = None
self.cnt: int = 0
self.jit_cache: List[JitItem] = []
self.ret: Any = None
self.input_replace: Dict[Tuple[int, int], Tuple[Union[int, str], ShapeTracker, DType]] = {} # (kernel_number, buffer_number) -> (input_name, expected_shapetracker, expected_type)
self.expected_vals: Optional[Tuple[Variable, ...]] = None
self.expected_sts_dtype: Optional[Tuple[Tuple[ShapeTracker, DType], ...]] = None
@property
def jit_cache(self) -> List[JitItem]: return self.jit_fxn.jit_cache if self.jit_fxn else []
@property
def input_replace(self) -> Dict[Tuple[int, int], Union[int, str]]: return self.jit_fxn.input_replace if self.jit_fxn else {}
# add support for instance methods
def __get__(self, obj, objtype): return functools.partial(self.__call__, obj)
@@ -32,39 +73,36 @@ class TinyJit:
# all inputs are realized
input_tensors: Dict[Union[int, str], Tensor] = {cast(Union[int, str], k):v.realize() for k,v in itertools.chain(enumerate(args), kwargs.items()) if v.__class__ is Tensor}
expected_sts_dtype = tuple([(v.lazydata.st.unbind(), v.dtype) for v in input_tensors.values()])
# get rawbuffers
input_rawbuffers: Dict[Union[int, str], Tuple[RawBuffer, ShapeTracker]] = {k:(cast(RawBuffer, v.lazydata.realized), v.lazydata.st) for k,v in input_tensors.items()}
input_rawbuffers: Dict[Union[int, str], RawBuffer] = {k:cast(RawBuffer, v.lazydata.realized) for k,v in input_tensors.items()}
assert len(input_rawbuffers) != 0, "no inputs to JIT"
assert len(set(input_rawbuffers.values())) == len(input_rawbuffers), "duplicate inputs to JIT"
# get variables: they can either be in Tensors or passed in as arguments, and all must be bound. these are all global
var_vals: Dict[Variable, int] = merge_dicts([arg.lazydata.st.var_vals for arg in input_tensors.values()] + [dict(x.unbind() for x in itertools.chain(args, kwargs.values()) if isinstance(x, Variable))])
expected_vals = tuple(var_vals.keys())
if self.cnt >= 2:
# check validity and assign the inputs
for (j,i),(input_name, expected_st, expected_type) in self.input_replace.items():
assert input_rawbuffers[input_name][0].dtype == expected_type, f"type mismatch in JIT, {input_rawbuffers[input_name][0].dtype} != {expected_type}"
assert input_rawbuffers[input_name][1].unbind() == expected_st, f"ShapeTracker mismatch in JIT, {input_rawbuffers[input_name][1].unbind()} != {expected_st}"
self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_name][0]
for ji in self.jit_cache: ji.prg(cast(List[RawBuffer], ji.rawbufs), {v:var_vals[v] for v in getattr(ji.prg,"vars",[])}, jit=True)
assert self.expected_vals == expected_vals, "mismatch of var_vals"
assert self.expected_sts_dtype == expected_sts_dtype, "mismatch of sts"
assert self.jit_fxn, "didn't get jitted?"
self.jit_fxn(input_rawbuffers, var_vals, DEBUG>=2)
elif self.cnt == 1:
self.expected_vals, self.expected_sts_dtype = expected_vals, expected_sts_dtype
CacheCollector.start(var_vals)
self.ret = self.fxn(*args, **kwargs)
self.jit_cache = CacheCollector.finish()
assert len(self.jit_cache) != 0, "didn't JIT anything!"
if DEBUG >= 1: print(f"JIT captured {len(self.jit_cache)} kernels with {len(input_rawbuffers)} inputs")
# get the inputs for replacement
for j,ji in enumerate(self.jit_cache):
for i,a in enumerate(ji.rawbufs):
if a in [v[0] for v in input_rawbuffers.values()]:
self.input_replace[(j,i)] = [(k, v[1].unbind(), v[0].dtype) for k,v in input_rawbuffers.items() if v[0] == a][0]
assert set([x[0] for x in self.input_replace.values()]) == set(input_rawbuffers.keys()), "some input tensors not found"
jit_cache = CacheCollector.finish()
assert len(jit_cache) != 0, "didn't JIT anything!"
if DEBUG >= 1: print(f"JIT captured {len(jit_cache)} kernels with {len(input_rawbuffers)} inputs")
alt_batch_exec = Device[Device.DEFAULT].batch_executor
self.jit_fxn = (BatchExecutor if alt_batch_exec is None or getenv("JIT") == 2 else alt_batch_exec)(jit_cache, input_rawbuffers, var_vals)
elif self.cnt == 0:
self.ret = self.fxn(*args, **kwargs)
# clear the inputs
for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None
self.cnt += 1
return self.ret
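
The call-count lifecycle is visible above: cnt == 0 runs the function as-is, cnt == 1 records kernels through CacheCollector and builds the executor, and cnt >= 2 replays through jit_fxn. A minimal usage sketch (mul and the shapes are illustrative, not from this diff); on the METAL backend outside CI the third call goes through the MetalBatchExecutor added below, otherwise through the generic BatchExecutor:

from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit

@TinyJit
def mul(a, b): return (a*b).realize()

for _ in range(3):
  out = mul(Tensor.randn(8, 8), Tensor.randn(8, 8))  # call 1: plain run, call 2: capture, call 3+: batched replay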

View File

@@ -112,6 +112,7 @@ class Interpreted:
def __init__(self, buffer, fxn_for_op: Dict[Op, Callable], from_underlying=None):
self.buffer, self.fxn_for_op, self.from_underlying = buffer, fxn_for_op, from_underlying
self.synchronize = lambda: None
self.batch_executor = None
self.codegen = None
self.method_cache: Dict[LazyOp, Callable] = {}
@@ -232,8 +233,8 @@ class ASTRunner:
return et
class Compiled:
def __init__(self, buffer: Type[RawBuffer], linearizer_opts, renderer, compiler, runtime, synchronize=lambda: None):
self.buffer, self.linearizer_opts, self.renderer, self.compiler, self.runtime, self.synchronize = buffer, linearizer_opts, renderer, compiler, runtime, synchronize
def __init__(self, buffer: Type[RawBuffer], linearizer_opts, renderer, compiler, runtime, synchronize=lambda: None, batch_executor=None):
self.buffer, self.linearizer_opts, self.renderer, self.compiler, self.runtime, self.synchronize, self.batch_executor = buffer, linearizer_opts, renderer, compiler, runtime, synchronize, batch_executor
self.method_cache: Dict[LazyOp, ASTRunner] = {}
def to_program(self, k):
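
The practical effect of the new keyword: a backend opts into batched JIT execution by passing an executor class to Compiled, and backends that pass nothing (or None, as Interpreted does above) keep the generic per-kernel BatchExecutor. A hypothetical registration following this pattern (the FOO names are placeholders, only batch_executor= comes from this diff):

FooBuffer = Compiled(RawFooBuffer, LinearizerOptions(device="FOO"), FooRenderer, compile_foo,
                     FooProgram, foo_synchronize, batch_executor=FooBatchExecutor)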

View File

@@ -38,7 +38,7 @@ class RawBufferCopyIn(RawBuffer):
class RawBufferMapped(RawBufferCopyIn):
def _buffer(self) -> memoryview: raise NotImplementedError("must be implemented")
# NOTE: this metadata prevents the backing buffer from being freed. hack can be removed with PEP688
def buffer_view(self) -> np.ndarray: return np.frombuffer(self._buffer(), dtype=np.dtype(self.dtype.np, metadata={"backing": self}), count=self.size)
def buffer_view(self) -> np.ndarray: return np.frombuffer(self._buffer(), dtype=np.dtype(self.dtype.np, metadata={"backing": self}), count=self.size) # type: ignore
def toCPU(self) -> np.ndarray: return self.buffer_view().copy() # Need a copy, since jit will write to the same buffer.
def _copyin(self, x:np.ndarray) -> None: np.copyto(self.buffer_view(), x.reshape(-1))
@@ -83,8 +83,8 @@ class LRUAllocator:
def _alloc_buffer(self, size, dtype, device, **kwargs):
self.ensure_has_free_space(size*dtype.itemsize, device)
while True:
try:
while True:
try:
newbuf = self._do_alloc(max(1, size), dtype, device, **kwargs)
break
except Exception:

View File

@@ -1,12 +1,12 @@
# pip3 install pyobjc-framework-Metal pyobjc-framework-Cocoa pyobjc-framework-libdispatch
import os, subprocess, pathlib, ctypes, tempfile
import Metal, Cocoa, libdispatch
from typing import List, Any, Tuple
from typing import List, Any, Tuple, Dict, Union, Set
from tinygrad.codegen.kernel import LinearizerOptions
from tinygrad.helpers import prod, getenv, DEBUG, DType, dtypes, diskcache
from tinygrad.helpers import prod, getenv, DEBUG, DType, dtypes, diskcache, dedup, CI
from tinygrad.ops import Compiled
from tinygrad.renderer.metal import MetalRenderer
from tinygrad.runtime.lib import RawBufferMapped, LRUAllocator
from tinygrad.runtime.lib import RawBufferMapped, RawBuffer, LRUAllocator
class MetalAllocator(LRUAllocator):
def _do_alloc(self, size, dtype, device, **kwargs): return METAL.device.newBufferWithLength_options_(size*dtype.itemsize, Metal.MTLResourceStorageModeShared)
@@ -44,7 +44,7 @@ def compile_metal(prg, use_xcode=bool(getenv("METAL_XCODE"))) -> bytes:
# NOTE: if you run llvm-dis on "air" you can see the llvm bytecode
air = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metal', '-x', 'metal', '-c', '-', '-o', '-'], input=prg.encode('utf-8'))
return subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metallib', '-', '-o', '-'], input=air)
options = Metal.MTLCompileOptions.alloc().init()
options = Metal.MTLCompileOptions.new()
library = unwrap(METAL.device.newLibraryWithSource_options_error_(prg, options, None))
# TODO: avoid file write here?
with tempfile.NamedTemporaryFile(delete=True) as output_file:
@@ -80,4 +80,75 @@ class MetalProgram:
return command_buffer.GPUEndTime() - command_buffer.GPUStartTime()
METAL.mtl_buffers_in_flight.append(command_buffer)
MetalBuffer = Compiled(RawMetalBuffer, LinearizerOptions(device="METAL"), MetalRenderer, compile_metal, MetalProgram, METAL.synchronize)
from tinygrad.jit import BatchExecutor, JitItem
from tinygrad.shape.symbolic import Variable, Node
class MetalBatchExecutor(BatchExecutor):
def __init__(self, jit_cache: List[JitItem], input_rawbuffers: Dict[Union[int, str], RawBuffer], var_vals: Dict[Variable, int]):
super().__init__(jit_cache, input_rawbuffers, var_vals)
# create metal batch exec
icb_descriptor = Metal.MTLIndirectCommandBufferDescriptor.new()
icb_descriptor.setCommandTypes_(Metal.MTLIndirectCommandType(Metal.MTLIndirectCommandTypeConcurrentDispatch))
icb_descriptor.setInheritBuffers_(False)
icb_descriptor.setInheritPipelineState_(False)
icb_descriptor.setMaxKernelBufferBindCount_(31)
self.icb = METAL.device.newIndirectCommandBufferWithDescriptor_maxCommandCount_options_(icb_descriptor, len(self.jit_cache), Metal.MTLResourceOptions(0))
assert self.icb is not None, "create indirect command buffer failed, does your system support this?"
self.int_buf = RawMetalBuffer(len(var_vals), dtypes.int32)
self.input_has_variable_dims: Set[int] = set()
read_resources, write_resources = [], []
for j,ji in enumerate(self.jit_cache):
descriptor = Metal.MTLComputePipelineDescriptor.new()
descriptor.setComputeFunction_(ji.prg.clprg.fxn)
descriptor.setSupportIndirectCommandBuffers_(True)
pipeline_state = unwrap(METAL.device.newComputePipelineStateWithDescriptor_options_reflection_error_(descriptor, Metal.MTLPipelineOption(0), None, None))
icb_command = self.icb.indirectComputeCommandAtIndex_(j)
icb_command.setComputePipelineState_(pipeline_state)
for i,b in enumerate(ji.rawbufs):
if b is not None:
icb_command.setKernelBuffer_offset_atIndex_(b._buf, 0, i)
if i == 0: write_resources.append(b._buf)
else: read_resources.append(b._buf)
var_vals_keys = list(var_vals.keys())
for i,v in enumerate(getattr(ji.prg,"vars",[])):
icb_command.setKernelBuffer_offset_atIndex_(self.int_buf._buf, var_vals_keys.index(v)*4, len(ji.rawbufs)+i)
global_size, local_size = ji.prg.launch_dims(var_vals)
assert ji.prg.global_size and ji.prg.local_size, "need global and local size to JIT"
if any(isinstance(x, Node) for x in ji.prg.global_size) or any(isinstance(x, Node) for x in ji.prg.local_size):
self.input_has_variable_dims.add(j)
else:
icb_command.concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
icb_command.setBarrier()
self.read_resources, self.write_resources = dedup(read_resources), dedup(write_resources)
self.command_buffer: Any = None
self.int_buf_view = self.int_buf.buffer_view() # TODO: this is metal syncing when it doesn't need to
def __call__(self, input_rawbuffers: Dict[Union[int, str], RawBuffer], var_vals: Dict[Variable, int], wait=False):
# NOTE: you at least can't update the ints if this is running
if self.command_buffer is not None and self.command_buffer in METAL.mtl_buffers_in_flight: self.command_buffer.waitUntilCompleted()
all_read_resources = self.read_resources + [x._buf for x in input_rawbuffers.values()]
for (j,i),input_name in self.input_replace.items():
self.icb.indirectComputeCommandAtIndex_(j).setKernelBuffer_offset_atIndex_(input_rawbuffers[input_name]._buf, 0, i)
for j in self.input_has_variable_dims:
global_size, local_size = self.jit_cache[j].prg.launch_dims(var_vals)
self.icb.indirectComputeCommandAtIndex_(j).concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
self.int_buf_view[:] = list(var_vals.values())
command_buffer = METAL.mtl_queue.commandBuffer()
encoder = command_buffer.computeCommandEncoder()
encoder.executeCommandsInBuffer_withRange_(self.icb, Metal.MTLIndirectCommandBufferExecutionRangeMake(0,len(self.jit_cache)))
encoder.useResources_count_usage_(all_read_resources, len(all_read_resources), Metal.MTLResourceUsageRead)
encoder.useResources_count_usage_(self.write_resources, len(self.write_resources), Metal.MTLResourceUsageWrite)
encoder.endEncoding()
command_buffer.commit()
self.command_buffer = command_buffer
if wait:
command_buffer.waitUntilCompleted()
et = command_buffer.GPUEndTime() - command_buffer.GPUStartTime()
else:
METAL.mtl_buffers_in_flight.append(command_buffer)
et = None
super().update_stats(var_vals, et)
return et
MetalBuffer = Compiled(RawMetalBuffer, LinearizerOptions(device="METAL"), MetalRenderer, compile_metal, MetalProgram, METAL.synchronize, batch_executor=MetalBatchExecutor if not CI else None)
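
A note on the structure above: everything expensive is encoded once in __init__ (pipeline states, buffer bindings, and fixed dispatch sizes are baked into the indirect command buffer), while __call__ only re-binds the input buffers, re-dispatches the kernels whose launch dimensions contain symbolic Variables, refreshes int_buf with the current variable values, and commits a small command buffer that replays the prebuilt ICB with a single encoder; that replay is presumably what the "sub 1ms gpt" bullet in the commit message refers to. The executor is disabled under CI (batch_executor=None above), and setting JIT=2 falls back to the generic BatchExecutor even when a backend provides its own.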