296 lines
10 KiB
Python
296 lines
10 KiB
Python
from __future__ import annotations
|
|
import os, subprocess, pathlib, ctypes, tempfile, functools
|
|
import Metal, libdispatch
|
|
from typing import List, Any, Tuple, Optional
|
|
from tinygrad.codegen.kernel import LinearizerOptions
|
|
from tinygrad.helpers import prod, getenv, DEBUG, diskcache, unwrap2
|
|
from tinygrad.device import Compiled, LRUAllocator
|
|
from tinygrad.renderer.cstyle import MetalRenderer
|
|
|
|
|
|
@diskcache
def compile_metal(prg, use_xcode=bool(getenv("METAL_XCODE"))) -> bytes:
    """
    Turn Metal shading language source into the bytes of a Metal library.

    Two back ends are available. By default the source is handed to
    ``MetalDevice.compiler_device`` via ``newLibraryWithSource_options_error_``
    and the resulting library's data contents are returned. When ``use_xcode``
    is true, the source is instead piped through ``xcrun ... metal`` and the
    output of that step is piped through ``xcrun ... metallib``.

    :param prg: Metal shading language source code (a ``str``).
    :param use_xcode: Select the ``xcrun`` tool chain instead of the in-process
                      compiler. The default is ``bool(getenv("METAL_XCODE"))``
                      and is evaluated once, when this function is defined.
    :return: The compiled Metal library as ``bytes``.
    :raises AssertionError: If no ``MetalDevice`` has set
                            ``MetalDevice.compiler_device`` yet.
    :raises subprocess.CalledProcessError: If either ``xcrun`` step exits with
                                           a non-zero status (``use_xcode`` only).
    """
    assert (
        MetalDevice.compiler_device
    ), "metal device creation is required for metal compile"

    if not use_xcode:
        # In-process path: compile with the shared compiler device.
        compile_opts = Metal.MTLCompileOptions.new()
        compile_result = MetalDevice.compiler_device.newLibraryWithSource_options_error_(
            prg, compile_opts, None
        )
        metal_lib = unwrap2(compile_result)
        return metal_lib.libraryDataContents().bytes().tobytes()

    # Xcode path: source -> AIR -> metallib, both steps talking over stdin/stdout.
    xcrun = ["xcrun", "-sdk", "macosx"]
    # NOTE: if you run llvm-dis on "air" you can see the llvm bytecode
    air = subprocess.check_output(
        xcrun + ["metal", "-x", "metal", "-c", "-", "-o", "-"],
        input=prg.encode("utf-8"),
    )
    return subprocess.check_output(xcrun + ["metallib", "-", "-o", "-"], input=air)
|
|
|
|
|
|
class MetalProgram:
    """
    A single compute kernel loaded from compiled Metal library bytes.

    Construction loads the library on the owning device, looks the kernel up
    by name and builds a compute pipeline state for it. Calling the instance
    encodes and commits one dispatch of that kernel.

    Attributes:
        device (MetalDevice): The owning device wrapper; its ``device``,
            ``mtl_queue`` and ``mtl_buffers_in_flight`` attributes are used.
        name (str): The name of the Metal kernel function.
        lib (bytes): The compiled Metal library data.
        library: The library object created from ``lib``.
        fxn: The function object looked up in ``library`` by ``name``.
        pipeline_state: The compute pipeline state built from ``fxn``.
    """

    def __init__(self, device: MetalDevice, name: str, lib: bytes):
        """
        Load ``lib`` on ``device`` and prepare the kernel called ``name``.

        Args:
            device (MetalDevice): The Metal device wrapper to load the library on.
            name (str): The name of the Metal kernel function inside ``lib``.
            lib (bytes): The compiled Metal library data.

        Raises:
            RuntimeError: If ``lib`` contains no function called ``name``.
        """
        self.device, self.name, self.lib = device, name, lib
        if DEBUG >= 6:
            self._dump_disassembly(lib)
        data = libdispatch.dispatch_data_create(lib, len(lib), None, None)
        self.library = unwrap2(self.device.device.newLibraryWithData_error_(data, None))
        self.fxn = self.library.newFunctionWithName_(name)
        if self.fxn is None:
            # The lookup yields nil for an unknown name; fail here with the
            # kernel name instead of handing nil to pipeline creation.
            raise RuntimeError(f"Metal function {name!r} not found in compiled library")
        self.pipeline_state = unwrap2(
            self.device.device.newComputePipelineStateWithFunction_error_(
                self.fxn, None
            )
        )

    @staticmethod
    def _dump_disassembly(lib: bytes) -> None:
        """Write ``lib`` to a temporary file and run the applegpu disassembler script on it."""
        import shlex  # local import: only needed on this debug path

        tool_dir = f"{pathlib.Path(__file__).parents[2]}/disassemblers/applegpu"
        with tempfile.NamedTemporaryFile(delete=True) as shader:
            shader.write(lib)
            shader.flush()
            # Quote both paths so a checkout or temp directory containing spaces
            # or shell metacharacters does not break (or alter) the command.
            # The exit status is deliberately ignored: this is best-effort output.
            os.system(
                f"cd {shlex.quote(tool_dir)} && python3 compiler_explorer.py {shlex.quote(shader.name)}"
            )

    def __call__(
        self,
        *bufs,
        global_size: Tuple[int, int, int],
        local_size: Tuple[int, int, int],
        vals: Tuple[int, ...] = (),
        wait=False,
    ) -> Optional[float]:
        """
        Encode and commit one dispatch of the kernel.

        Args:
            *bufs: Buffers bound at offset 0 to argument indices ``0..len(bufs)-1``.
            global_size (Tuple[int, int, int]): Number of threadgroups to dispatch.
            local_size (Tuple[int, int, int]): Number of threads per threadgroup.
            vals (Tuple[int, ...]): Integers passed as 4-byte ``int32`` values at
                the argument indices following the buffers. Default is an empty tuple.
            wait (bool): Block until the command buffer completes. Default is False.

        Returns:
            Optional[float]: ``GPUEndTime() - GPUStartTime()`` of the command buffer
            when ``wait`` is true. Otherwise ``None``; the command buffer is then
            appended to ``device.mtl_buffers_in_flight``.

        Raises:
            AssertionError: If the product of ``local_size`` exceeds the pipeline's
                ``maxTotalThreadsPerThreadgroup()``.
        """
        assert (
            prod(local_size) <= self.pipeline_state.maxTotalThreadsPerThreadgroup()
        ), f"local size {local_size} bigger than {self.pipeline_state.maxTotalThreadsPerThreadgroup()} with exec width {self.pipeline_state.threadExecutionWidth()} memory length {self.pipeline_state.staticThreadgroupMemoryLength()}"
        command_buffer = self.device.mtl_queue.commandBuffer()
        encoder = command_buffer.computeCommandEncoder()
        encoder.setComputePipelineState_(self.pipeline_state)
        for i, a in enumerate(bufs):
            encoder.setBuffer_offset_atIndex_(a, 0, i)
        # Scalar arguments continue the index sequence right after the buffers.
        for i, a in enumerate(vals, start=len(bufs)):
            encoder.setBytes_length_atIndex_(ctypes.c_int32(a), 4, i)
        encoder.dispatchThreadgroups_threadsPerThreadgroup_(
            Metal.MTLSize(*global_size), Metal.MTLSize(*local_size)
        )
        encoder.endEncoding()
        command_buffer.commit()
        if wait:
            command_buffer.waitUntilCompleted()
            return command_buffer.GPUEndTime() - command_buffer.GPUStartTime()
        # Not waited on here: track it so the device can wait for it later.
        self.device.mtl_buffers_in_flight.append(command_buffer)
        return None
|
|
|
|
|
|
class MetalAllocator(LRUAllocator):
    """
    Buffer allocator backed by a ``MetalDevice``.

    Attributes:
        device (MetalDevice): The device wrapper whose ``device``, ``mtl_queue``,
            ``mtl_buffers_in_flight`` and ``mv_in_metal`` attributes are used.
    """

    def __init__(self, device: MetalDevice):
        """
        Bind the allocator to ``device``.

        Args:
            device (MetalDevice): The Metal device wrapper used for every operation.
        """
        # Stored before the base initializer runs, as in the original ordering.
        self.device: MetalDevice = device
        super().__init__()

    def _alloc(self, size: int) -> Any:
        """
        Create a buffer of ``size`` using ``MTLResourceStorageModeShared``.

        Args:
            size (int): The length passed to ``newBufferWithLength_options_``.

        Returns:
            Any: The newly created buffer.

        Raises:
            MemoryError: If the device returns ``None`` instead of a buffer.
        """
        new_buf = self.device.device.newBufferWithLength_options_(
            size, Metal.MTLResourceStorageModeShared
        )
        if new_buf is not None:
            return new_buf
        raise MemoryError(f"Metal OOM while allocating {size=}")

    def transfer(self, dest: Any, src: Any, sz: int):
        """
        Enqueue a blit copy of ``sz`` from the start of ``src`` to the start of ``dest``.

        The command buffer is committed and appended to
        ``device.mtl_buffers_in_flight``; this method does not wait for it.

        Args:
            dest (Any): The destination buffer.
            src (Any): The source buffer.
            sz (int): The size passed to the blit copy.
        """
        cmd = self.device.mtl_queue.commandBuffer()
        blit = cmd.blitCommandEncoder()
        blit.copyFromBuffer_sourceOffset_toBuffer_destinationOffset_size_(
            src, 0, dest, 0, sz
        )
        blit.endEncoding()
        cmd.commit()
        self.device.mtl_buffers_in_flight.append(cmd)

    def from_buffer(self, src: memoryview) -> Optional[Any]:
        """
        Wrap ``src`` in a Metal buffer without copying its bytes.

        On success ``src`` is also appended to ``device.mv_in_metal`` --
        presumably to keep its memory referenced while Metal may still use it.

        Args:
            src (memoryview): The memory to wrap.

        Returns:
            Optional[Any]: The created buffer, or the falsy value returned by
            Metal if the buffer could not be created.
        """
        wrapped = self.device.device.newBufferWithBytesNoCopy_length_options_deallocator_(
            src, len(src), Metal.MTLResourceStorageModeShared, None
        )
        if not wrapped:
            return wrapped
        self.device.mv_in_metal.append(src)
        return wrapped

    def _free(self, opaque: Any):
        """
        Release a buffer.

        Args:
            opaque (Any): The buffer whose ``release()`` is called.
        """
        opaque.release()

    def as_buffer(self, src: Any) -> memoryview:
        """
        Expose the contents of a Metal buffer as a memory view.

        The device is synchronized first, so previously enqueued command
        buffers have completed before the contents are exposed.

        Args:
            src (Any): The Metal buffer to view.

        Returns:
            memoryview: A view over ``src.length()`` bytes of ``src.contents()``.
        """
        self.device.synchronize()
        raw = src.contents()
        nbytes = src.length()
        return raw.as_buffer(nbytes)

    def copyin(self, dest: Any, src: memoryview):
        """
        Copy ``src`` into the Metal buffer ``dest``.

        Args:
            dest (Any): The destination Metal buffer.
            src (memoryview): The source memory view.
        """
        target = self.as_buffer(dest)
        target[:] = src

    def copyout(self, dest: memoryview, src: Any):
        """
        Copy the contents of the Metal buffer ``src`` into ``dest``.

        Args:
            dest (memoryview): The destination memory view.
            src (Any): The source Metal buffer.
        """
        source = self.as_buffer(src)
        dest[:] = source
|
|
|
|
|
|
class MetalDevice(Compiled):
    """
    Compiled-device front end for Metal.

    Attributes:
        compiler_device (Any): Class-level device shared with ``compile_metal``;
            set to the first instance's device and left unchanged afterwards.
        device (Any): The system default Metal device.
        mtl_queue (Any): Command queue created on ``device``.
        mtl_buffers_in_flight (List[Any]): Committed command buffers that have
            not been waited on yet.
        mv_in_metal (List[memoryview]): Memory views held until the next
            ``synchronize()``.
    """

    compiler_device = None

    def __init__(self, device: str):
        """
        Set up the default Metal device, its command queue and the base class.

        Args:
            device (str): Device name. It is not read by this constructor; the
                system default Metal device is always used.
        """
        mtl_device = Metal.MTLCreateSystemDefaultDevice()
        self.device = mtl_device
        # Only the first instance publishes its device for compilation.
        if MetalDevice.compiler_device is None:
            MetalDevice.compiler_device = mtl_device
        # Queue allowing up to 1024 command buffers.
        self.mtl_queue = mtl_device.newCommandQueueWithMaxCommandBufferCount_(1024)
        self.mtl_buffers_in_flight: List[Any] = []
        self.mv_in_metal: List[memoryview] = []

        # Imported here rather than at module level -- presumably to avoid an
        # import cycle; TODO confirm.
        from tinygrad.features.graph.metal import MetalGraph

        allocator = MetalAllocator(self)
        linearizer_opts = LinearizerOptions(device="METAL")
        make_program = functools.partial(MetalProgram, self)
        make_graph = functools.partial(MetalGraph, self)
        super().__init__(
            allocator,
            linearizer_opts,
            MetalRenderer,
            compile_metal,
            make_program,
            make_graph,
        )

    def synchronize(self):
        """
        Wait for every in-flight command buffer, then drop the tracked state.

        After all buffers in ``mtl_buffers_in_flight`` have completed, both
        ``mv_in_metal`` and ``mtl_buffers_in_flight`` are cleared.
        """
        in_flight = self.mtl_buffers_in_flight
        for pending in in_flight:
            pending.waitUntilCompleted()
        self.mv_in_metal.clear()
        in_flight.clear()
|