2023-03-04 00:22:15 -07:00
|
|
|
# pip3 install pyobjc-framework-Metal pyobjc-framework-Cocoa pyobjc-framework-libdispatch
|
2023-11-01 19:44:00 -06:00
|
|
|
import os, subprocess, pathlib, ctypes, tempfile
|
2023-11-12 12:04:20 -07:00
|
|
|
import Metal, Cocoa, libdispatch
|
2023-11-15 12:13:38 -07:00
|
|
|
from typing import List, Any, Tuple, Dict, Union, Set, cast
|
2023-08-31 15:42:09 -06:00
|
|
|
from tinygrad.codegen.kernel import LinearizerOptions
|
2023-11-15 12:37:28 -07:00
|
|
|
from tinygrad.helpers import prod, getenv, DEBUG, DType, dtypes, diskcache, dedup
|
2023-11-15 14:34:52 -07:00
|
|
|
from tinygrad.ops import Compiled, BatchExecutor, JitItem, CompiledASTRunner, update_stats
|
2023-10-10 08:46:41 -06:00
|
|
|
from tinygrad.renderer.metal import MetalRenderer
|
2023-11-13 18:58:26 -07:00
|
|
|
from tinygrad.runtime.lib import RawBufferMapped, RawBuffer, LRUAllocator
|
2023-11-14 09:08:51 -07:00
|
|
|
from tinygrad.shape.symbolic import Variable, Node
|
2023-03-01 19:57:29 -07:00
|
|
|
|
2023-08-17 11:33:32 -06:00
|
|
|
class MetalAllocator(LRUAllocator):
  """Allocator that creates and frees Metal buffers on the module-level METAL device."""
  def _do_alloc(self, size, dtype, device, **kwargs):
    """Allocate a shared-storage Metal buffer of size*dtype.itemsize bytes.

    Raises AssertionError if the request is larger than the device's maxBufferLength() or if Metal returns a falsy buffer.
    """
    buf_len, max_buf_len = size*dtype.itemsize, METAL.device.maxBufferLength()
    # a request of exactly max_buf_len does not "exceed" the limit, so only strictly larger requests are rejected
    assert buf_len <= max_buf_len, f"Buffer length of {buf_len/1e9:5.2f} GB exceeds Metal's max buffer length of {max_buf_len/1e9:5.2f} GB."
    buf = METAL.device.newBufferWithLength_options_(buf_len, Metal.MTLResourceStorageModeShared)
    assert buf, f"Metal buffer allocation failed with {buf}."
    return buf
  def _do_free(self, buf): buf.release()
  # the cache key is (device, byte length): buffers of the same length could be reused, no matter what dtype
  def _cached_bufkey(self, size, dtype, device): return (device, size*dtype.itemsize)
|
|
|
|
|
2023-03-01 19:57:29 -07:00
|
|
|
class _METAL:
|
2023-03-12 12:01:25 -06:00
|
|
|
def __init__(self):
|
2023-08-01 13:10:20 -06:00
|
|
|
self.mtl_buffers_in_flight: List[Any] = []
|
2023-03-12 12:01:25 -06:00
|
|
|
self.device = Metal.MTLCreateSystemDefaultDevice()
|
2023-11-15 12:37:28 -07:00
|
|
|
self.supports_icb = (self.device.supportsFamily_(Metal.MTLGPUFamilyMac2) or self.device.supportsFamily_(Metal.MTLGPUFamilyApple3) or self.device.supportsFamily_(Metal.MTLGPUFamilyCommon2)) and self.device.argumentBuffersSupport() is Metal.MTLArgumentBuffersTier2
|
2023-08-24 16:42:00 -06:00
|
|
|
self.mtl_queue = self.device.newCommandQueueWithMaxCommandBufferCount_(1024)
|
2023-08-17 11:33:32 -06:00
|
|
|
self.allocator = MetalAllocator(self.device.dedicatedMemorySize() or self.device.sharedMemorySize())
|
2023-08-01 13:10:20 -06:00
|
|
|
# TODO: is there a better way to do this?
|
2023-03-24 11:24:27 -06:00
|
|
|
def synchronize(self):
|
2023-08-01 13:10:20 -06:00
|
|
|
for cbuf in self.mtl_buffers_in_flight: cbuf.waitUntilCompleted()
|
|
|
|
self.mtl_buffers_in_flight.clear()
|
2023-03-01 19:57:29 -07:00
|
|
|
# module-level singleton; constructing it runs _METAL.__init__, which creates the Metal device, command queue and allocator at import time
METAL = _METAL()
|
|
|
|
|
2023-03-10 22:57:05 -07:00
|
|
|
class RawMetalBuffer(RawBufferMapped):
  """Raw buffer whose storage comes from METAL.allocator."""
  def __init__(self, size:int, dtype:DType):
    # doubles are rejected before any allocation is attempted
    assert dtype != dtypes.double, f"METAL does not support {dtype.name}"
    super().__init__(size, dtype, allocator=METAL.allocator)

  def _buffer(self):
    """Wait for all in-flight command buffers, then return the buffer contents viewed over the buffer's full length."""
    METAL.synchronize()
    mtl_buf = self._buf
    return mtl_buf.contents().as_buffer(mtl_buf.length())
|
2023-03-01 19:57:29 -07:00
|
|
|
|
2023-03-04 00:14:40 -07:00
|
|
|
def unwrap(x):
  """Split a (result, error) pair and return the result.

  Raises AssertionError carrying str(error) when error is not None.
  """
  ret, err = x
  # raised explicitly instead of via `assert`, so the check is not stripped when Python runs with -O
  if err is not None: raise AssertionError(str(err))
  return ret
|
|
|
|
|
2023-11-02 00:01:32 -06:00
|
|
|
@diskcache
def compile_metal(prg, use_xcode=bool(getenv("METAL_XCODE"))) -> bytes:
  """Compile Metal shader source `prg` and return the serialized library as bytes.

  With use_xcode (default taken from METAL_XCODE when this function is defined), the source is piped through
  `xcrun metal` and then `xcrun metallib`; subprocess.check_output raises CalledProcessError if either exits non-zero.
  Otherwise the device compiles the source in-process and the library is serialized through a temporary file;
  errors reported by Metal surface through unwrap.
  """
  if use_xcode:
    # NOTE: if you run llvm-dis on "air" you can see the llvm bytecode
    air = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metal', '-x', 'metal', '-c', '-', '-o', '-'], input=prg.encode('utf-8'))
    return subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metallib', '-', '-o', '-'], input=air)
  options = Metal.MTLCompileOptions.new()
  library = unwrap(METAL.device.newLibraryWithSource_options_error_(prg, options, None))
  # TODO: avoid file write here?
  with tempfile.NamedTemporaryFile(delete=True) as output_file:
    # build the URL from the filesystem path rather than interpolating into "file://...", so a temp directory
    # containing spaces or other characters that are not valid in a URL string still yields a usable URL
    unwrap(library.serializeToURL_error_(Cocoa.NSURL.fileURLWithPath_(output_file.name), None))
    # read back before the with-block exits, because the temporary file is deleted on close
    return pathlib.Path(output_file.name).read_bytes()
|
|
|
|
|
2023-03-01 19:57:29 -07:00
|
|
|
class MetalProgram:
  """A function loaded from serialized Metal library bytes, together with the compute pipeline state used to launch it."""
  def __init__(self, name:str, lib:bytes):
    """Load `lib` as a Metal library, look up the function `name` and build its compute pipeline state."""
    dispatch_data = libdispatch.dispatch_data_create(lib, len(lib), None, None)
    self.library = unwrap(METAL.device.newLibraryWithData_error_(dispatch_data, None))
    self.fxn = self.library.newFunctionWithName_(name)
    if DEBUG >= 6:
      # write the library bytes to a temp file and run the applegpu compiler_explorer.py script on it
      with tempfile.NamedTemporaryFile(delete=True) as shader:
        shader.write(lib)
        shader.flush()
        os.system(f"cd {pathlib.Path(__file__).parents[2]}/disassemblers/applegpu && python3 compiler_explorer.py {shader.name}")
    self.pipeline_state = unwrap(METAL.device.newComputePipelineStateWithFunction_error_(self.fxn, None))

  def __call__(self, *bufs, global_size:Tuple[int,int,int], local_size:Tuple[int,int,int], wait=False):
    """Encode and commit one dispatch of this program.

    Each positional argument is bound at its own index: a RawMetalBuffer as a buffer at offset 0, an int as 4 inline
    bytes (c_int32). Any other type raises RuntimeError. With wait=True the call blocks and returns
    GPUEndTime() - GPUStartTime(); otherwise the command buffer is added to METAL.mtl_buffers_in_flight and None is returned.
    """
    assert prod(local_size) <= self.pipeline_state.maxTotalThreadsPerThreadgroup(), f"local size {local_size} bigger than {self.pipeline_state.maxTotalThreadsPerThreadgroup()} with exec width {self.pipeline_state.threadExecutionWidth()} memory length {self.pipeline_state.staticThreadgroupMemoryLength()}"
    cmd_buf = METAL.mtl_queue.commandBuffer()
    enc = cmd_buf.computeCommandEncoder()
    enc.setComputePipelineState_(self.pipeline_state)
    for idx, value in enumerate(bufs):
      if isinstance(value, RawMetalBuffer):
        enc.setBuffer_offset_atIndex_(value._buf, 0, idx)
      elif isinstance(value, int):
        packed = ctypes.c_int32(value)
        enc.setBytes_length_atIndex_(packed, ctypes.sizeof(packed), idx)
      else:
        raise RuntimeError(f"arg at index {idx} has unsupported type {type(value)}")
    enc.dispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
    enc.endEncoding()
    cmd_buf.commit()
    if not wait:
      # asynchronous launch: remember the command buffer so METAL.synchronize() can wait on it later
      METAL.mtl_buffers_in_flight.append(cmd_buf)
      return None
    cmd_buf.waitUntilCompleted()
    return cmd_buf.GPUEndTime() - cmd_buf.GPUStartTime()
|
2023-03-01 19:57:29 -07:00
|
|
|
|
2023-11-13 18:58:26 -07:00
|
|
|
class MetalBatchExecutor(BatchExecutor):
  """Batch executor that encodes a whole JIT cache into one Metal indirect command buffer (ICB) and replays it per call."""
  def __init__(self, jit_cache: List[JitItem], input_rawbuffers: Dict[Union[int, str], RawBuffer], var_vals: Dict[Variable, int]):
    """Build the ICB: one indirect compute command per JIT item, with its pipeline state, buffers and variable slots bound.

    Items whose global or local size contains a symbolic Node are recorded in input_has_variable_dims and get their
    dispatch dimensions encoded in __call__ instead of here.
    NOTE(review): self.jit_cache is presumably populated by BatchExecutor.__init__ - confirm against the base class.
    """
    super().__init__(jit_cache, input_rawbuffers, var_vals)

    # create metal batch exec
    icb_descriptor = Metal.MTLIndirectCommandBufferDescriptor.new()
    icb_descriptor.setCommandTypes_(Metal.MTLIndirectCommandType(Metal.MTLIndirectCommandTypeConcurrentDispatch))
    icb_descriptor.setInheritBuffers_(False)
    icb_descriptor.setInheritPipelineState_(False)
    icb_descriptor.setMaxKernelBufferBindCount_(31)
    self.icb = METAL.device.newIndirectCommandBufferWithDescriptor_maxCommandCount_options_(icb_descriptor, len(self.jit_cache), Metal.MTLResourceOptions(0))
    assert self.icb is not None, "create indirect command buffer failed, does your system support this?"

    # one int32 slot per entry of var_vals; kernels bind into it at byte offset (slot index)*4
    self.int_buf = RawMetalBuffer(len(var_vals), dtypes.int32)
    self.input_has_variable_dims: Set[int] = set()
    read_resources, write_resources = [], []
    # slot order is the iteration order of var_vals; it is the same for every kernel, so build the list once
    var_vals_keys = list(var_vals.keys())
    for j,ji in enumerate(self.jit_cache):
      prg: CompiledASTRunner = cast(CompiledASTRunner, ji.prg)
      descriptor = Metal.MTLComputePipelineDescriptor.new()
      descriptor.setComputeFunction_(prg.clprg.fxn)
      descriptor.setSupportIndirectCommandBuffers_(True)
      pipeline_state = unwrap(METAL.device.newComputePipelineStateWithDescriptor_options_reflection_error_(descriptor, Metal.MTLPipelineOption(0), None, None))
      icb_command = self.icb.indirectComputeCommandAtIndex_(j)
      icb_command.setComputePipelineState_(pipeline_state)
      for i,b in enumerate(ji.rawbufs):
        if b is not None:
          icb_command.setKernelBuffer_offset_atIndex_(b._buf, 0, i)
          # the buffer at index 0 is tracked as a write resource, all others as read resources
          if i == 0: write_resources.append(b._buf)
          else: read_resources.append(b._buf)
      # variables are bound after the raw buffers, each pointing at its own int32 slot in int_buf
      for i,v in enumerate(prg.vars):
        icb_command.setKernelBuffer_offset_atIndex_(self.int_buf._buf, var_vals_keys.index(v)*4, len(ji.rawbufs)+i)
      global_size, local_size = prg.launch_dims(var_vals)
      assert prg.global_size and prg.local_size, "need global and local size to JIT"
      if any(isinstance(x, Node) for x in prg.global_size) or any(isinstance(x, Node) for x in prg.local_size):
        self.input_has_variable_dims.add(j)
      else:
        icb_command.concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
      icb_command.setBarrier()
    self.read_resources, self.write_resources = dedup(read_resources), dedup(write_resources)
    self.command_buffer: Any = None
    self.int_buf_view = self.int_buf.buffer_view() # TODO: this is metal syncing when it doesn't need to

  def __call__(self, input_rawbuffers: Dict[Union[int, str], RawBuffer], var_vals: Dict[Variable, int], wait=False):
    """Rebind inputs, refresh variable-dependent dispatch sizes and variable values, then execute the ICB.

    Returns GPUEndTime() - GPUStartTime() of the command buffer when wait is true, otherwise None.
    NOTE(review): self.input_replace, self.op_estimate and self.mem_estimate are presumably set by the base class - confirm.
    """
    # NOTE: you at least can't update the ints if this is running
    if self.command_buffer is not None and self.command_buffer in METAL.mtl_buffers_in_flight: self.command_buffer.waitUntilCompleted()
    all_read_resources = self.read_resources + [x._buf for x in input_rawbuffers.values()]
    for (j,i),input_name in self.input_replace.items():
      self.icb.indirectComputeCommandAtIndex_(j).setKernelBuffer_offset_atIndex_(input_rawbuffers[input_name]._buf, 0, i)
    for j in self.input_has_variable_dims:
      global_size, local_size = cast(CompiledASTRunner, self.jit_cache[j].prg).launch_dims(var_vals)
      self.icb.indirectComputeCommandAtIndex_(j).concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
    self.int_buf_view[:] = list(var_vals.values())
    command_buffer = METAL.mtl_queue.commandBuffer()
    encoder = command_buffer.computeCommandEncoder()
    encoder.executeCommandsInBuffer_withRange_(self.icb, Metal.MTLIndirectCommandBufferExecutionRangeMake(0,len(self.jit_cache)))
    encoder.useResources_count_usage_(all_read_resources, len(all_read_resources), Metal.MTLResourceUsageRead)
    encoder.useResources_count_usage_(self.write_resources, len(self.write_resources), Metal.MTLResourceUsageWrite)
    encoder.endEncoding()
    command_buffer.commit()
    # remembered so the next call can wait on it before overwriting the variable values
    self.command_buffer = command_buffer
    if wait:
      command_buffer.waitUntilCompleted()
      et = command_buffer.GPUEndTime() - command_buffer.GPUStartTime()
    else:
      METAL.mtl_buffers_in_flight.append(command_buffer)
      et = None
    update_stats(f"<batched {len(self.jit_cache)}>", self.op_estimate, self.mem_estimate, var_vals, et, buf_count=len(input_rawbuffers), jit=True, num_kernels=len(self.jit_cache))
    return et
|
|
|
|
|
2023-11-15 12:37:28 -07:00
|
|
|
# Compiled backend object for METAL; MetalBatchExecutor is passed only when METAL.supports_icb is truthy, otherwise the generic BatchExecutor
MetalBuffer = Compiled(RawMetalBuffer, LinearizerOptions(device="METAL"), MetalRenderer, compile_metal, MetalProgram, METAL.synchronize, batch_executor=MetalBatchExecutor if METAL.supports_icb else BatchExecutor)
|