1
0
Fork 0
tinygrab/tinygrad/runtime/ops_metal.py

296 lines
10 KiB
Python

from __future__ import annotations
import os, subprocess, pathlib, ctypes, tempfile, functools
import Metal, libdispatch
from typing import List, Any, Tuple, Optional
from tinygrad.codegen.kernel import LinearizerOptions
from tinygrad.helpers import prod, getenv, DEBUG, diskcache, unwrap2
from tinygrad.device import Compiled, LRUAllocator
from tinygrad.renderer.cstyle import MetalRenderer
@diskcache
def compile_metal(prg, use_xcode=bool(getenv("METAL_XCODE"))) -> bytes:
    """
    Compile Metal Shading Language source into serialized Metal library bytes.

    Two routes are supported:

    * default: ``MetalDevice.compiler_device`` compiles the source in-process
      (``newLibraryWithSource_options_error_``) and the resulting library's
      data contents are returned;
    * ``use_xcode``: ``xcrun metal`` turns the source into AIR, then
      ``xcrun metallib`` packages that AIR into a library.

    In both cases a MetalDevice must already have been created, because
    ``MetalDevice.compiler_device`` is asserted to be set before either route runs.

    :param prg: Metal shading language source code.
    :type prg: str
    :param use_xcode: Compile with the Xcode command-line tools instead of the
        Metal driver. The default is evaluated once, when this function is
        defined, from the ``METAL_XCODE`` environment variable.
    :type use_xcode: bool, optional
    :return: The compiled Metal library.
    :rtype: bytes
    """
    assert MetalDevice.compiler_device, "metal device creation is required for metal compile"

    if not use_xcode:
        compiled = unwrap2(
            MetalDevice.compiler_device.newLibraryWithSource_options_error_(
                prg, Metal.MTLCompileOptions.new(), None
            )
        )
        return compiled.libraryDataContents().bytes().tobytes()

    xcrun = ["xcrun", "-sdk", "macosx"]
    # NOTE: if you run llvm-dis on "air" you can see the llvm bytecode
    air = subprocess.check_output(
        xcrun + ["metal", "-x", "metal", "-c", "-", "-o", "-"],
        input=prg.encode("utf-8"),
    )
    return subprocess.check_output(xcrun + ["metallib", "-", "-o", "-"], input=air)
class MetalProgram:
    """
    A compiled Metal compute kernel bound to a MetalDevice.

    Attributes:
        device (MetalDevice): Device the kernel is loaded on and dispatched to.
        name (str): Name of the kernel function inside the library.
        lib (bytes): Serialized Metal library containing the kernel.
        library: Library object created from ``lib`` on ``device.device``.
        fxn: Function object looked up in ``library`` by ``name``.
        pipeline_state: Compute pipeline state built from ``fxn``.
    """

    def __init__(self, device: MetalDevice, name: str, lib: bytes):
        """
        Load ``lib`` onto the device and build a compute pipeline for ``name``.

        Args:
            device (MetalDevice): The Metal device object.
            name (str): The name of the Metal kernel function.
            lib (bytes): The compiled Metal library data.
        """
        self.device, self.name, self.lib = device, name, lib
        if DEBUG >= 6:
            self._disassemble(lib)
        data = libdispatch.dispatch_data_create(lib, len(lib), None, None)
        self.library = unwrap2(self.device.device.newLibraryWithData_error_(data, None))
        self.fxn = self.library.newFunctionWithName_(name)
        self.pipeline_state = unwrap2(
            self.device.device.newComputePipelineStateWithFunction_error_(self.fxn, None)
        )

    @staticmethod
    def _disassemble(lib: bytes) -> None:
        """Best-effort: write ``lib`` to a temp file and run the applegpu compiler_explorer on it."""
        explorer_dir = pathlib.Path(__file__).parents[2] / "disassemblers" / "applegpu"
        with tempfile.NamedTemporaryFile(delete=True) as shader:
            shader.write(lib)
            shader.flush()
            # argv list + cwd= instead of a `cd ... && ...` shell string, so paths containing
            # spaces or shell metacharacters are passed through intact. Like the previous
            # os.system call, a failing disassembler (or a missing directory/interpreter) does
            # not abort kernel loading.
            try:
                subprocess.run(["python3", "compiler_explorer.py", shader.name], cwd=explorer_dir, check=False)
            except OSError as e:
                print(f"applegpu disassembler failed to run: {e}")

    def __call__(
        self,
        *bufs,
        global_size: Tuple[int, int, int],
        local_size: Tuple[int, int, int],
        vals: Tuple[int, ...] = (),
        wait=False,
    ):
        """
        Encode and commit one dispatch of the kernel.

        Args:
            *bufs: Metal buffers bound, in order, to argument indices 0..len(bufs)-1 (offset 0).
            global_size (Tuple[int, int, int]): Number of threadgroups in each dimension
                (passed to ``dispatchThreadgroups``).
            local_size (Tuple[int, int, int]): Number of threads per threadgroup in each
                dimension; their product must not exceed the pipeline's
                ``maxTotalThreadsPerThreadgroup()``.
            vals (Tuple[int, ...]): Scalar arguments bound after the buffers, each passed by
                value as a 4-byte ``ctypes.c_int32``. Default is an empty tuple.
            wait (bool): Block until the command buffer completes. Default is False.

        Returns:
            ``GPUEndTime() - GPUStartTime()`` of the command buffer when ``wait`` is True
            (per Apple's docs a CFTimeInterval, i.e. seconds); otherwise None, and the
            command buffer is appended to ``device.mtl_buffers_in_flight`` so that
            ``MetalDevice.synchronize`` can wait for it later.
        """
        assert (
            prod(local_size) <= self.pipeline_state.maxTotalThreadsPerThreadgroup()
        ), f"local size {local_size} bigger than {self.pipeline_state.maxTotalThreadsPerThreadgroup()} with exec width {self.pipeline_state.threadExecutionWidth()} memory length {self.pipeline_state.staticThreadgroupMemoryLength()}"
        command_buffer = self.device.mtl_queue.commandBuffer()
        encoder = command_buffer.computeCommandEncoder()
        encoder.setComputePipelineState_(self.pipeline_state)
        for i, a in enumerate(bufs):
            encoder.setBuffer_offset_atIndex_(a, 0, i)
        # scalar arguments occupy the argument indices right after the buffers
        for i, a in enumerate(vals, start=len(bufs)):
            encoder.setBytes_length_atIndex_(ctypes.c_int32(a), 4, i)
        encoder.dispatchThreadgroups_threadsPerThreadgroup_(
            Metal.MTLSize(*global_size), Metal.MTLSize(*local_size)
        )
        encoder.endEncoding()
        command_buffer.commit()
        if wait:
            command_buffer.waitUntilCompleted()
            return command_buffer.GPUEndTime() - command_buffer.GPUStartTime()
        self.device.mtl_buffers_in_flight.append(command_buffer)
class MetalAllocator(LRUAllocator):
    """
    LRU-caching allocator whose buffers are shared-storage Metal buffers.

    Attributes:
        device (MetalDevice): Owner of the MTLDevice used for allocation and of the
            command queue used for blits.
    """

    def __init__(self, device: MetalDevice):
        """
        Bind the allocator to ``device`` and initialise the LRU base class.

        Args:
            device (MetalDevice): The Metal device to allocate on.
        """
        self.device: MetalDevice = device
        super().__init__()

    def _alloc(self, size: int) -> Any:
        """
        Create a ``size``-byte buffer with ``MTLResourceStorageModeShared``.

        Raises:
            MemoryError: when Metal returns None for the allocation.
        """
        mtl_buf = self.device.device.newBufferWithLength_options_(size, Metal.MTLResourceStorageModeShared)
        if mtl_buf is not None:
            return mtl_buf
        raise MemoryError(f"Metal OOM while allocating {size=}")

    def transfer(self, dest: Any, src: Any, sz: int):
        """
        Enqueue a GPU blit of the first ``sz`` bytes of ``src`` into ``dest`` (both at offset 0).

        The command buffer is committed but not waited on; it is recorded in
        ``device.mtl_buffers_in_flight`` instead.
        """
        cmd_buf = self.device.mtl_queue.commandBuffer()
        blit = cmd_buf.blitCommandEncoder()
        blit.copyFromBuffer_sourceOffset_toBuffer_destinationOffset_size_(src, 0, dest, 0, sz)
        blit.endEncoding()
        cmd_buf.commit()
        self.device.mtl_buffers_in_flight.append(cmd_buf)

    def from_buffer(self, src: memoryview) -> Optional[Any]:
        """
        Wrap ``src`` in a Metal buffer via ``newBufferWithBytesNoCopy``.

        Returns:
            The Metal buffer, or a falsy value (e.g. None) if Metal refused to wrap ``src``.
        """
        wrapped = self.device.device.newBufferWithBytesNoCopy_length_options_deallocator_(
            src, len(src), Metal.MTLResourceStorageModeShared, None
        )
        if wrapped:
            # keep a Python reference to src; the device drops these in synchronize()
            self.device.mv_in_metal.append(src)
        return wrapped

    def _free(self, opaque: Any):
        """Release the Metal buffer ``opaque``."""
        opaque.release()

    def as_buffer(self, src: Any) -> memoryview:
        """
        Return a memoryview over the whole contents of Metal buffer ``src``.

        All in-flight device work is waited for first (``device.synchronize()``), so the
        view reflects any pending GPU writes.
        """
        self.device.synchronize()
        return src.contents().as_buffer(src.length())

    def copyin(self, dest: Any, src: memoryview):
        """Copy host bytes ``src`` into Metal buffer ``dest`` (slice assignment; lengths must match)."""
        host_view = self.as_buffer(dest)
        host_view[:] = src

    def copyout(self, dest: memoryview, src: Any):
        """Copy the contents of Metal buffer ``src`` into host memoryview ``dest`` (lengths must match)."""
        device_view = self.as_buffer(src)
        dest[:] = device_view
class MetalDevice(Compiled):
    """
    Compiled-backend device for Apple Metal.

    Attributes:
        compiler_device (Any): Class-level MTLDevice used by ``compile_metal``; set to the
            system default device by the first MetalDevice constructed.
        device: The system default MTLDevice for this instance.
        mtl_queue: Command queue created with a maximum of 1024 command buffers.
        mtl_buffers_in_flight (List[Any]): Committed command buffers not yet waited on.
        mv_in_metal (List[memoryview]): Host memoryviews referenced until the next synchronize().
    """

    compiler_device = None

    def __init__(self, device: str):
        """
        Create the Metal device, its command queue and the Compiled backend pieces.

        Args:
            device (str): Device name; not read by this constructor.
        """
        self.device = Metal.MTLCreateSystemDefaultDevice()
        if MetalDevice.compiler_device is None:
            # the first device created is the one used for compilation
            MetalDevice.compiler_device = self.device
        self.mtl_queue = self.device.newCommandQueueWithMaxCommandBufferCount_(1024)
        self.mtl_buffers_in_flight: List[Any] = []
        self.mv_in_metal: List[memoryview] = []

        # imported locally rather than at module top (presumably to avoid an import cycle)
        from tinygrad.features.graph.metal import MetalGraph

        allocator = MetalAllocator(self)
        linearizer_opts = LinearizerOptions(device="METAL")
        program_factory = functools.partial(MetalProgram, self)
        graph_factory = functools.partial(MetalGraph, self)
        super().__init__(allocator, linearizer_opts, MetalRenderer, compile_metal, program_factory, graph_factory)

    def synchronize(self):
        """
        Block until every command buffer in ``mtl_buffers_in_flight`` has completed,
        then forget those command buffers and the retained memoryviews.
        """
        for pending in self.mtl_buffers_in_flight:
            pending.waitUntilCompleted()
        self.mv_in_metal.clear()
        self.mtl_buffers_in_flight.clear()