From e2af95c2f8a5a00fd76351d37fc6abecb2920acb Mon Sep 17 00:00:00 2001 From: Diogo Date: Sat, 5 Aug 2023 21:23:18 -0400 Subject: [PATCH] moved global_max and local_max to LinearizerOptions also added assert for max bufs (#1446) --- tinygrad/renderer/wgsl.py | 2 -- tinygrad/runtime/ops_webgpu.py | 3 ++- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/tinygrad/renderer/wgsl.py b/tinygrad/renderer/wgsl.py index fc8f6bb42..6eb760403 100644 --- a/tinygrad/renderer/wgsl.py +++ b/tinygrad/renderer/wgsl.py @@ -11,8 +11,6 @@ class WGSLLanguage(CStyleLanguage): gid = [f"i32(gindex.{'xyz'[x]})" for x in range(3)] lid = [f"i32(lindex.{'xyz'[x]})" for x in range(3)] size_prefix = "let" - global_max = [65535, 65535, 65535] - local_max = [256, 256, 64] barrier="workgroupBarrier();" generic_var_prefix = "var " external_local_bufs = True diff --git a/tinygrad/runtime/ops_webgpu.py b/tinygrad/runtime/ops_webgpu.py index 028aaa142..c3cd95d20 100644 --- a/tinygrad/runtime/ops_webgpu.py +++ b/tinygrad/runtime/ops_webgpu.py @@ -14,6 +14,7 @@ device = get_default_device() class WebGPUProgram: def __init__(self, name: str, prg: str): self.name,self.prg = name,device.create_shader_module(code=prg) def __call__(self, global_size, local_size, *bufs, wait=False): + assert len(bufs) <= 8, "WEBGPU only supports 8 buffers" binding_layouts = [{"binding": i, "visibility": wgpu.ShaderStage.COMPUTE, "buffer": {"type": wgpu.BufferBindingType.storage}} for i in range(len(bufs))] bindings = [{"binding": i, "resource": {"buffer": x._buf, "offset": 0, "size": x._buf.size}} for i, x in enumerate(bufs)] bind_group_layout = device.create_bind_group_layout(entries=binding_layouts) @@ -36,4 +37,4 @@ class RawWebGPUBuffer(RawBufferCopyIn): def toCPU(self) -> np.ndarray: return np.frombuffer(device.queue.read_buffer(self._buf, 0), dtype=np.dtype(self.dtype.np, metadata={"backing": self})) # type: ignore renderer = functools.partial(uops_to_cstyle, WGSLLanguage()) -WebGpuBuffer = Compiled(RawWebGPUBuffer, LinearizerOptions(supports_float4=False), renderer, WebGPUProgram) +WebGpuBuffer = Compiled(RawWebGPUBuffer, LinearizerOptions(supports_float4=False, local_max=[256, 256, 64], global_max=[65535, 65535, 65535]), renderer, WebGPUProgram)