1
0
Fork 0
tinygrab/tinygrad/runtime/ops_webgpu.py

162 lines
5.1 KiB
Python

from wgpu.utils.device import get_default_device
from tinygrad.device import Compiled, Allocator
from tinygrad.codegen.kernel import LinearizerOptions
from tinygrad.renderer.cstyle import WGSLRenderer
import wgpu
# Module-level WebGPU device handle, created once at import time via wgpu's
# get_default_device() and used by the program and allocator classes in this file.
wgpu_device = get_default_device()
class WebGPUProgram:
    """
    A compute shader that can be dispatched on the module-level WebGPU device.

    Attributes:
        name (str): Entry point name used when the compute pipeline is created.
        lib (bytes): Shader source passed to ``create_shader_module``.
        prg: Shader module returned by ``wgpu_device.create_shader_module``.
    """

    def __init__(self, name: str, lib: bytes):
        """
        Build the shader module for ``lib`` and remember it with its name.

        Args:
            name (str): Entry point name of the program.
            lib (bytes): Shader source for the program.
        """
        # NOTE: this is the compiler
        shader_module = wgpu_device.create_shader_module(code=lib)
        self.name = name
        self.lib = lib
        self.prg = shader_module

    def __call__(self, *bufs, global_size, local_size, vals=(), wait=False):
        """
        Bind ``bufs`` as storage buffers and submit one compute dispatch.

        Args:
            *bufs: Buffers bound to bind group 0, one binding index per buffer
                in positional order. Each must expose a ``size`` attribute.
            global_size (tuple): Unpacked into ``dispatch_workgroups`` (x, y, z).
            local_size (tuple): Accepted but not read by this method.
            vals (tuple, optional): Accepted but not read by this method.
            wait (bool, optional): Accepted but not read by this method; the
                command buffer is submitted and the call returns.

        Raises:
            AssertionError: If more than 8 buffers are passed.
        """
        assert len(bufs) <= 8, "WEBGPU only supports 8 buffers"

        # One layout entry and one resource entry per buffer, sharing the index.
        layout_entries = []
        resource_entries = []
        for slot, buf in enumerate(bufs):
            layout_entries.append(
                {
                    "binding": slot,
                    "visibility": wgpu.ShaderStage.COMPUTE,
                    "buffer": {"type": wgpu.BufferBindingType.storage},
                }
            )
            resource_entries.append(
                {
                    "binding": slot,
                    "resource": {"buffer": buf, "offset": 0, "size": buf.size},
                }
            )

        group_layout = wgpu_device.create_bind_group_layout(entries=layout_entries)
        pipe_layout = wgpu_device.create_pipeline_layout(bind_group_layouts=[group_layout])
        group = wgpu_device.create_bind_group(layout=group_layout, entries=resource_entries)
        pipeline = wgpu_device.create_compute_pipeline(
            layout=pipe_layout,
            compute={"module": self.prg, "entry_point": self.name},
        )

        # Record a single compute pass and hand it to the device queue.
        encoder = wgpu_device.create_command_encoder()
        cpass = encoder.begin_compute_pass()
        cpass.set_pipeline(pipeline)
        cpass.set_bind_group(0, group, [], 0, 999999)  # last 2 not used
        cpass.dispatch_workgroups(*global_size)  # x y z
        cpass.end()
        wgpu_device.queue.submit([encoder.finish()])
class WebGpuAllocator(Allocator):
    """
    Allocator that creates and transfers buffers on the module-level WebGPU device.

    Subclasses ``Allocator`` and implements ``_alloc``, ``copyin`` and
    ``copyout`` on top of ``wgpu_device``.
    """

    def _alloc(self, size: int):
        """
        Create a device buffer of ``size`` bytes.

        The buffer is created with STORAGE, COPY_DST and COPY_SRC usage so it
        can be bound to a shader and used on either side of a copy.

        Args:
            size (int): Size passed to ``wgpu_device.create_buffer``.

        Returns:
            The buffer object returned by ``wgpu_device.create_buffer``.
        """
        usage = (
            wgpu.BufferUsage.STORAGE
            | wgpu.BufferUsage.COPY_DST
            | wgpu.BufferUsage.COPY_SRC
        )
        return wgpu_device.create_buffer(size=size, usage=usage)

    def copyin(self, dest, src: memoryview):
        """
        Write host data into a device buffer.

        Args:
            dest: Device buffer to write into, starting at offset 0.
            src (memoryview): Host data handed to ``queue.write_buffer``.
        """
        wgpu_device.queue.write_buffer(dest, 0, src)

    def copyout(self, dest: memoryview, src):
        """
        Read a device buffer back into host memory.

        Args:
            dest (memoryview): Host buffer that receives the data through a
                full-slice assignment.
            src: Device buffer read with ``queue.read_buffer`` from offset 0.

        Note:
            The read result is copied into ``dest``; the TODO below is about
            removing that extra copy, not about removing this method.
        """
        dest[:] = wgpu_device.queue.read_buffer(src, 0)  # TODO: remove this copy
class WebGpuDevice(Compiled):
    """
    ``Compiled`` device that wires together the WebGPU allocator, the WGSL
    renderer and the WebGPU program runner defined in this module.
    """

    def __init__(self, device: str):
        """
        Configure the parent ``Compiled`` device for WebGPU.

        Args:
            device (str): Device identifier. It is accepted but not read here;
                the linearizer options always use the literal ``"WEBGPU"``.
        """
        allocator = WebGpuAllocator()
        linearizer_opts = LinearizerOptions(
            device="WEBGPU",
            supports_float4=False,
            local_max=[256, 256, 64],
            global_max=[65535, 65535, 65535],
        )
        # The fourth argument is an identity callable: presumably the compile
        # step, since the shader module is built inside WebGPUProgram.
        # TODO confirm against Compiled.__init__.
        super().__init__(
            allocator,
            linearizer_opts,
            WGSLRenderer,
            lambda x: x,
            WebGPUProgram,
        )