fix metal graph with var_vals (#2583)
parent
f180cac8f0
commit
88a5c368d4
|
@ -27,7 +27,7 @@ class MetalGraph:
|
|||
if self.icb is None: raise GraphException("create indirect command buffer failed, does your system support this?")
|
||||
|
||||
if len(var_vals): self.int_buf = self.device.allocator.alloc(len(var_vals)*dtypes.int32.itemsize)
|
||||
read_resources, write_resources = [], []
|
||||
read_resources, write_resources = [self.int_buf] if len(var_vals) else [], []
|
||||
for j,ji in enumerate(self.jit_cache):
|
||||
prg: CompiledASTRunner = cast(CompiledASTRunner, ji.prg)
|
||||
descriptor = Metal.MTLComputePipelineDescriptor.new()
|
||||
|
@ -64,9 +64,9 @@ class MetalGraph:
|
|||
if len(var_vals): self.int_buf_view[:] = list(var_vals.values())
|
||||
command_buffer = self.device.mtl_queue.commandBuffer()
|
||||
encoder = command_buffer.computeCommandEncoder()
|
||||
encoder.executeCommandsInBuffer_withRange_(self.icb, Metal.MTLIndirectCommandBufferExecutionRangeMake(0,len(self.jit_cache)))
|
||||
encoder.useResources_count_usage_(all_read_resources, len(all_read_resources), Metal.MTLResourceUsageRead)
|
||||
encoder.useResources_count_usage_(self.write_resources, len(self.write_resources), Metal.MTLResourceUsageWrite)
|
||||
encoder.executeCommandsInBuffer_withRange_(self.icb, Metal.MTLIndirectCommandBufferExecutionRangeMake(0,len(self.jit_cache)))
|
||||
encoder.endEncoding()
|
||||
command_buffer.commit()
|
||||
self.command_buffer = command_buffer
|
||||
|
|
Loading…
Reference in New Issue