2023-11-16 10:52:13 -07:00
|
|
|
from wgpu.utils.device import get_default_device
|
2023-11-30 18:07:16 -07:00
|
|
|
from tinygrad.device import Compiled, Allocator
|
2023-08-31 15:42:09 -06:00
|
|
|
from tinygrad.codegen.kernel import LinearizerOptions
|
2023-12-02 17:29:56 -07:00
|
|
|
from tinygrad.renderer.cstyle import WGSLRenderer
|
2023-11-12 12:04:20 -07:00
|
|
|
import wgpu
|
2023-07-12 13:52:06 -06:00
|
|
|
|
2023-08-17 11:33:32 -06:00
|
|
|
# Module-level WebGPU device handle, created once when this module is imported.
# The program and allocator classes below issue all their device/queue calls through it.
wgpu_device = get_default_device()
|
2023-07-12 13:52:06 -06:00
|
|
|
|
2023-12-04 22:01:04 -07:00
|
|
|
|
2023-07-12 13:52:06 -06:00
|
|
|
class WebGPUProgram:
    """A compute shader compiled for the module-level ``wgpu_device``.

    Attributes:
        name: Entry-point name used when the compute pipeline is created.
        lib: Shader code exactly as it was handed to the constructor.
        prg: Shader module returned by ``wgpu_device.create_shader_module``.
    """

    def __init__(self, name: str, lib: bytes):
        """Build the shader module and remember the program's name and code.

        Args:
            name: Entry point of the compute shader inside ``lib``.
            lib: Shader code, forwarded as ``code=`` to
                ``create_shader_module``.
        """
        # NOTE: this is the compiler
        shader_module = wgpu_device.create_shader_module(code=lib)
        self.name = name
        self.lib = lib
        self.prg = shader_module

    def __call__(self, *bufs, global_size, local_size, vals=(), wait=False):
        """Encode one compute dispatch of this program and submit it.

        Every buffer in ``bufs`` is bound as a storage buffer in bind group 0;
        its binding index is its position in ``bufs``.

        Args:
            *bufs: Device buffers to bind. Each must expose a ``size``
                attribute, which is used as the bound range (offset 0).
            global_size: Iterable unpacked into ``dispatch_workgroups``
                (x, y, z).
            local_size: Accepted but not read by this method.
            vals: Accepted but not read by this method.
            wait: Accepted but not read by this method; no explicit wait is
                performed after the queue submission.

        Raises:
            AssertionError: If more than 8 buffers are passed.
        """
        assert len(bufs) <= 8, "WEBGPU only supports 8 buffers"

        # Build the layout entry and the resource entry for each buffer slot.
        layout_entries = []
        group_entries = []
        for slot, buf in enumerate(bufs):
            layout_entries.append(
                {
                    "binding": slot,
                    "visibility": wgpu.ShaderStage.COMPUTE,
                    "buffer": {"type": wgpu.BufferBindingType.storage},
                }
            )
            group_entries.append(
                {
                    "binding": slot,
                    "resource": {"buffer": buf, "offset": 0, "size": buf.size},
                }
            )

        # Device objects are created in the same order as before: bind group
        # layout, pipeline layout, bind group, then the compute pipeline.
        group_layout = wgpu_device.create_bind_group_layout(entries=layout_entries)
        layout_for_pipeline = wgpu_device.create_pipeline_layout(
            bind_group_layouts=[group_layout]
        )
        group = wgpu_device.create_bind_group(layout=group_layout, entries=group_entries)
        pipeline = wgpu_device.create_compute_pipeline(
            layout=layout_for_pipeline,
            compute={"module": self.prg, "entry_point": self.name},
        )

        # Record a single compute pass and hand it to the queue.
        encoder = wgpu_device.create_command_encoder()
        cpass = encoder.begin_compute_pass()
        cpass.set_pipeline(pipeline)
        cpass.set_bind_group(0, group, [], 0, 999999)  # last 2 not used
        cpass.dispatch_workgroups(*global_size)  # x y z
        cpass.end()
        wgpu_device.queue.submit([encoder.finish()])
|
|
|
|
|
2023-08-17 11:33:32 -06:00
|
|
|
|
2023-11-30 18:07:16 -07:00
|
|
|
class WebGpuAllocator(Allocator):
    """Allocator whose buffers are created on the module-level ``wgpu_device``.

    Transfers in both directions go through ``wgpu_device.queue``.
    """

    def _alloc(self, size: int):
        """Create a device buffer.

        Args:
            size: Size forwarded to ``wgpu_device.create_buffer``.

        Returns:
            The buffer returned by ``wgpu_device.create_buffer``, created with
            ``STORAGE | COPY_DST | COPY_SRC`` usage so it can be bound to a
            shader and be both the target and the source of copies.
        """
        usage = (
            wgpu.BufferUsage.STORAGE
            | wgpu.BufferUsage.COPY_DST
            | wgpu.BufferUsage.COPY_SRC
        )
        return wgpu_device.create_buffer(size=size, usage=usage)

    def copyin(self, dest, src: memoryview):
        """Upload host data into a device buffer.

        Args:
            dest: Device buffer to write into, starting at offset 0.
            src: Host-side data to upload.
        """
        wgpu_device.queue.write_buffer(dest, 0, src)

    def copyout(self, dest: memoryview, src):
        """Download a device buffer into host memory.

        Args:
            dest: Host-side writable buffer; its whole contents are replaced
                via slice assignment.
            src: Device buffer to read from, starting at offset 0.

        Note:
            ``read_buffer`` returns a new host-side object which is then
            copied into ``dest``; the TODO below tracks removing that extra
            copy.
        """
        dest[:] = wgpu_device.queue.read_buffer(src, 0)  # TODO: remove this copy
|
|
|
|
|
2023-07-12 13:52:06 -06:00
|
|
|
|
2023-11-30 18:07:16 -07:00
|
|
|
class WebGpuDevice(Compiled):
    """``Compiled`` device wired up with the WebGPU allocator, renderer and runtime."""

    def __init__(self, device: str):
        """Configure the parent ``Compiled`` device for WebGPU.

        Args:
            device: Device identifier. It is accepted but not read here; the
                linearizer options always use the literal ``"WEBGPU"``.

        Note:
            The positional arguments handed to ``Compiled.__init__`` are, in
            order: a fresh ``WebGpuAllocator``, the ``LinearizerOptions``,
            ``WGSLRenderer``, an identity function, and ``WebGPUProgram``.
            The identity function presumably fills the compile step, so the
            rendered source would reach ``WebGPUProgram`` unchanged - confirm
            against ``Compiled``.
        """
        allocator = WebGpuAllocator()
        linearizer_opts = LinearizerOptions(
            device="WEBGPU",
            supports_float4=False,
            local_max=[256, 256, 64],
            global_max=[65535, 65535, 65535],
        )
        super().__init__(
            allocator,
            linearizer_opts,
            WGSLRenderer,
            lambda x: x,
            WebGPUProgram,
        )
|