fix used resources in metal graph (#2604)
parent
fde44aed76
commit
19a0a839db
|
@ -27,7 +27,7 @@ class MetalGraph:
|
|||
if self.icb is None: raise GraphException("create indirect command buffer failed, does your system support this?")
|
||||
|
||||
if len(var_vals): self.int_buf = self.device.allocator.alloc(len(var_vals)*dtypes.int32.itemsize)
|
||||
read_resources, write_resources = [self.int_buf] if len(var_vals) else [], []
|
||||
all_resources = [self.int_buf] if len(var_vals) else []
|
||||
for j,ji in enumerate(self.jit_cache):
|
||||
prg: CompiledASTRunner = cast(CompiledASTRunner, ji.prg)
|
||||
descriptor = Metal.MTLComputePipelineDescriptor.new()
|
||||
|
@ -39,8 +39,7 @@ class MetalGraph:
|
|||
for i,b in enumerate(ji.rawbufs):
|
||||
if b is not None:
|
||||
icb_command.setKernelBuffer_offset_atIndex_(b._buf, 0, i)
|
||||
if i == 0: write_resources.append(b._buf)
|
||||
else: read_resources.append(b._buf)
|
||||
all_resources.append(b._buf)
|
||||
var_vals_keys = list(var_vals.keys())
|
||||
for i,v in enumerate(prg.vars):
|
||||
icb_command.setKernelBuffer_offset_atIndex_(self.int_buf, var_vals_keys.index(v)*4, len(ji.rawbufs)+i)
|
||||
|
@ -48,14 +47,14 @@ class MetalGraph:
|
|||
global_size, local_size = prg.launch_dims(var_vals)
|
||||
icb_command.concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
|
||||
icb_command.setBarrier()
|
||||
self.read_resources, self.write_resources = dedup(read_resources), dedup(write_resources)
|
||||
self.all_resources = dedup(all_resources)
|
||||
self.command_buffer: Any = None
|
||||
if len(var_vals): self.int_buf_view = np.frombuffer(self.int_buf.contents().as_buffer(self.int_buf.length()), np.int32)
|
||||
|
||||
def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
|
||||
# NOTE: you at least can't update the ints if this is running
|
||||
if self.command_buffer is not None and self.command_buffer in self.device.mtl_buffers_in_flight: self.command_buffer.waitUntilCompleted()
|
||||
all_read_resources = self.read_resources + [x._buf for x in input_rawbuffers]
|
||||
all_resources = self.all_resources + [x._buf for x in input_rawbuffers]
|
||||
for (j,i),input_idx in self.input_replace.items():
|
||||
self.icb.indirectComputeCommandAtIndex_(j).setKernelBuffer_offset_atIndex_(input_rawbuffers[input_idx]._buf, 0, i)
|
||||
for j in self.jc_idx_with_updatable_launch_dims:
|
||||
|
@ -64,8 +63,7 @@ class MetalGraph:
|
|||
if len(var_vals): self.int_buf_view[:] = list(var_vals.values())
|
||||
command_buffer = self.device.mtl_queue.commandBuffer()
|
||||
encoder = command_buffer.computeCommandEncoder()
|
||||
encoder.useResources_count_usage_(all_read_resources, len(all_read_resources), Metal.MTLResourceUsageRead)
|
||||
encoder.useResources_count_usage_(self.write_resources, len(self.write_resources), Metal.MTLResourceUsageWrite)
|
||||
encoder.useResources_count_usage_(all_resources, len(all_resources), Metal.MTLResourceUsageRead | Metal.MTLResourceUsageWrite)
|
||||
encoder.executeCommandsInBuffer_withRange_(self.icb, Metal.MTLIndirectCommandBufferExecutionRangeMake(0,len(self.jit_cache)))
|
||||
encoder.endEncoding()
|
||||
command_buffer.commit()
|
||||
|
|
Loading…
Reference in New Issue