2023-03-04 00:22:15 -07:00
|
|
|
# pip3 install pyobjc-framework-Metal pyobjc-framework-Cocoa pyobjc-framework-libdispatch
|
2023-08-16 15:43:41 -06:00
|
|
|
import os, subprocess, pathlib, functools, ctypes
|
2023-03-01 19:57:29 -07:00
|
|
|
import Metal, Cocoa, libdispatch # type: ignore
|
2023-08-01 13:10:20 -06:00
|
|
|
from typing import List, Any
|
2023-08-31 15:42:09 -06:00
|
|
|
from tinygrad.codegen.kernel import LinearizerOptions
|
2023-08-05 12:07:04 -06:00
|
|
|
from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage
|
2023-08-15 19:21:08 -06:00
|
|
|
from tinygrad.helpers import prod, getenv, DEBUG, DType, dtypes
|
2023-03-18 15:40:23 -06:00
|
|
|
from tinygrad.ops import Compiled
|
2023-08-17 11:33:32 -06:00
|
|
|
from tinygrad.runtime.lib import RawBufferMapped, LRUAllocator
|
2023-03-01 19:57:29 -07:00
|
|
|
|
|
|
|
# When truthy, MetalProgram builds its library by piping the source through the
# `xcrun` command-line tools instead of calling newLibraryWithSource on the device.
METAL_XCODE = getenv("METAL_XCODE")
|
|
|
|
|
2023-08-17 11:33:32 -06:00
|
|
|
class MetalAllocator(LRUAllocator):
  """LRUAllocator subclass that creates and releases buffers on the global Metal device."""

  def _do_alloc(self, size, dtype, device, **kwargs):
    """Create a Metal buffer of `size * dtype.itemsize` bytes with shared storage mode.

    `device` and `kwargs` are accepted but not used here.
    """
    nbytes = size * dtype.itemsize
    return METAL.device.newBufferWithLength_options_(nbytes, Metal.MTLResourceStorageModeShared)

  def _do_free(self, buf):
    """Release the given Metal buffer."""
    buf.release()

  def _cached_bufkey(self, size, dtype, device):
    """Return the key `(device, byte_length)` for this allocation request."""
    # Buffers of the same length could be reused, no matter what dtype.
    nbytes = size * dtype.itemsize
    return (device, nbytes)
|
|
|
|
|
2023-03-01 19:57:29 -07:00
|
|
|
class _METAL:
|
2023-03-12 12:01:25 -06:00
|
|
|
def __init__(self):
|
2023-08-01 13:10:20 -06:00
|
|
|
self.mtl_buffers_in_flight: List[Any] = []
|
2023-03-12 12:01:25 -06:00
|
|
|
self.device = Metal.MTLCreateSystemDefaultDevice()
|
2023-08-24 16:42:00 -06:00
|
|
|
self.mtl_queue = self.device.newCommandQueueWithMaxCommandBufferCount_(1024)
|
2023-08-17 11:33:32 -06:00
|
|
|
self.allocator = MetalAllocator(self.device.dedicatedMemorySize() or self.device.sharedMemorySize())
|
2023-08-01 13:10:20 -06:00
|
|
|
# TODO: is there a better way to do this?
|
2023-03-24 11:24:27 -06:00
|
|
|
def synchronize(self):
|
2023-08-01 13:10:20 -06:00
|
|
|
for cbuf in self.mtl_buffers_in_flight: cbuf.waitUntilCompleted()
|
|
|
|
self.mtl_buffers_in_flight.clear()
|
2023-03-01 19:57:29 -07:00
|
|
|
# Shared _METAL instance, created at import time; the allocator, buffer and
# program classes in this module all go through it.
METAL = _METAL()
|
|
|
|
|
2023-03-10 22:57:05 -07:00
|
|
|
class RawMetalBuffer(RawBufferMapped):
  """RawBufferMapped whose storage is allocated through the global Metal allocator."""

  def __init__(self, size:int, dtype:DType):
    """Allocate `size` elements of `dtype`; `dtypes.double` is rejected by the assert below."""
    assert dtype != dtypes.double, f"METAL does not support {dtype.name}"
    super().__init__(size, dtype, allocator=METAL.allocator)

  def _buffer(self):
    """Return the buffer contents as a Python buffer spanning the Metal buffer's full length."""
    # wait for all in-flight command buffers before exposing the memory
    METAL.synchronize()
    mtl_buf = self._buf
    return mtl_buf.contents().as_buffer(mtl_buf.length())
|
2023-03-01 19:57:29 -07:00
|
|
|
|
2023-03-04 00:14:40 -07:00
|
|
|
def unwrap(x):
  """Split a `(result, error)` pair and return the result, failing if an error is set.

  Args:
    x: a two-element iterable `(ret, err)`, as returned by the `..._error_` calls in this file.

  Returns:
    The first element, `ret`.

  Raises:
    AssertionError: if `err` is not None; the message is `str(err)`.
  """
  ret, err = x
  # Raise explicitly instead of using `assert`: assert statements are removed
  # under `python -O`, which would let a failed Metal call pass silently.
  if err is not None: raise AssertionError(str(err))
  return ret
|
|
|
|
|
2023-03-01 19:57:29 -07:00
|
|
|
class MetalProgram:
  """A Metal compute function plus its pipeline state; calling the instance dispatches it."""

  def __init__(self, name:str, prg:str, binary:bool=False):
    """Build a library from the source `prg` and a compute pipeline for its function `name`.

    `binary` is accepted but not read by this constructor.
    Errors reported by the Metal calls surface through `unwrap`.
    """
    device = METAL.device
    if not METAL_XCODE:
      # compile the source in-process with default-initialised compile options
      compile_options = Metal.MTLCompileOptions.alloc().init()
      self.library = unwrap(device.newLibraryWithSource_options_error_(prg, compile_options, None))
    else:
      # command-line toolchain: source -> `air` -> metallib, each step piped via stdin/stdout
      xcrun = ['xcrun', '-sdk', 'macosx']
      air = subprocess.check_output(xcrun + ['metal', '-x', 'metal', '-c', '-', '-o', '-'], input=prg.encode('utf-8'))
      # NOTE: if you run llvm-dis on "air" you can see the llvm bytecode
      lib = subprocess.check_output(xcrun + ['metallib', '-', '-o', '-'], input=air)
      data = libdispatch.dispatch_data_create(lib, len(lib), None, None)
      self.library = unwrap(device.newLibraryWithData_error_(data, None))
    self.fxn = self.library.newFunctionWithName_(name)

    # hacks to disassemble shader
    if DEBUG >= 5:
      archive_desc = Metal.MTLBinaryArchiveDescriptor.alloc().init()
      archive = unwrap(device.newBinaryArchiveWithDescriptor_error_(archive_desc, None))
      pipeline_desc = Metal.MTLComputePipelineDescriptor.alloc().init()
      pipeline_desc.setComputeFunction_(self.fxn)
      unwrap(archive.addComputePipelineFunctionsWithDescriptor_error_(pipeline_desc, None))
      shader_url = Cocoa.NSURL.URLWithString_("file:///tmp/shader.bin")
      unwrap(archive.serializeToURL_error_(shader_url, None))
      # clone https://github.com/dougallj/applegpu.git in tinygrad/disassemblers
      os.system(f"cd {pathlib.Path(__file__).parents[2]}/disassemblers/applegpu && python3 compiler_explorer.py /tmp/shader.bin")

    self.pipeline_state = unwrap(device.newComputePipelineStateWithFunction_error_(self.fxn, None))

  def __call__(self, global_size, local_size, *bufs, wait=False):
    """Encode, commit and optionally wait for one dispatch of this program.

    Args:
      global_size: unpacked into `Metal.MTLSize` as the threadgroup count.
      local_size: unpacked into `Metal.MTLSize` as the threads per threadgroup.
      *bufs: kernel arguments bound by position; each is a RawMetalBuffer or an int
        (ints are passed as a 32-bit C int).
      wait: when true, block until the command buffer completes.

    Returns:
      `GPUEndTime() - GPUStartTime()` of the command buffer when `wait` is true.
      Otherwise None, and the command buffer is appended to `METAL.mtl_buffers_in_flight`.

    Raises:
      AssertionError: if `prod(local_size)` exceeds the pipeline's maximum threads per threadgroup.
      RuntimeError: if an argument is neither a RawMetalBuffer nor an int.
    """
    assert prod(local_size) <= self.pipeline_state.maxTotalThreadsPerThreadgroup(), f"local size {local_size} bigger than {self.pipeline_state.maxTotalThreadsPerThreadgroup()} with exec width {self.pipeline_state.threadExecutionWidth()} memory length {self.pipeline_state.staticThreadgroupMemoryLength()}"
    command_buffer = METAL.mtl_queue.commandBuffer()
    encoder = command_buffer.computeCommandEncoder()
    encoder.setComputePipelineState_(self.pipeline_state)

    for i,a in enumerate(bufs):
      if isinstance(a, RawMetalBuffer):
        encoder.setBuffer_offset_atIndex_(a._buf, 0, i)
        continue
      if isinstance(a, int):
        packed = ctypes.c_int32(a)
        encoder.setBytes_length_atIndex_(packed, ctypes.sizeof(packed), i)
        continue
      raise RuntimeError(f"arg at index {i} has unsupported type {type(a)}")

    threadgroups = Metal.MTLSize(*global_size)
    threads_per_group = Metal.MTLSize(*local_size)
    encoder.dispatchThreadgroups_threadsPerThreadgroup_(threadgroups, threads_per_group)
    encoder.endEncoding()
    command_buffer.commit()

    if not wait:
      # leave it running; METAL.synchronize() waits on it later
      METAL.mtl_buffers_in_flight.append(command_buffer)
      return None
    command_buffer.waitUntilCompleted()
    return command_buffer.GPUEndTime() - command_buffer.GPUStartTime()
|
2023-03-01 19:57:29 -07:00
|
|
|
|
2023-08-05 12:07:04 -06:00
|
|
|
# uops_to_cstyle with its first argument pre-bound to a CStyleLanguage that
# carries the Metal Shading Language spellings used when emitting kernels.
renderer = functools.partial(
  uops_to_cstyle,
  CStyleLanguage(
    kernel_prefix = "#include <metal_stdlib>\nusing namespace metal;\nkernel",
    buffer_prefix = "device ",
    smem_prefix = "threadgroup ",
    arg_int_prefix = "constant int&",
    barrier = "threadgroup_barrier(mem_flags::mem_threadgroup);",
    float4 = "float4",
    uses_ptr_arithmetic=True,
    # chr(120+i) yields 'x', 'y', 'z' -> gid.x/gid.y/gid.z and lid.x/lid.y/lid.z
    gid = [f"gid.{chr(120+i)}" for i in range(3)],
    lid = [f"lid.{chr(120+i)}" for i in range(3)],
    # extra kernel parameters that declare the gid/lid names used above
    extra_args = [
      'uint3 gid [[threadgroup_position_in_grid]]',
      'uint3 lid [[thread_position_in_threadgroup]]',
    ],
  ),
)

# Ties the Metal pieces together: buffer class, default linearizer options,
# the renderer above, the program class and the synchronize callback.
MetalBuffer = Compiled(RawMetalBuffer, LinearizerOptions(), renderer, MetalProgram, METAL.synchronize)
|