1
0
Fork 0

fix metal graph with var_vals (#2583)

pull/2590/head
nimlgen 2023-12-03 20:24:36 +03:00 committed by GitHub
parent f180cac8f0
commit 88a5c368d4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 2 deletions

View File

@ -27,7 +27,7 @@ class MetalGraph:
if self.icb is None: raise GraphException("create indirect command buffer failed, does your system support this?")
if len(var_vals): self.int_buf = self.device.allocator.alloc(len(var_vals)*dtypes.int32.itemsize)
read_resources, write_resources = [], []
read_resources, write_resources = [self.int_buf] if len(var_vals) else [], []
for j,ji in enumerate(self.jit_cache):
prg: CompiledASTRunner = cast(CompiledASTRunner, ji.prg)
descriptor = Metal.MTLComputePipelineDescriptor.new()
@ -64,9 +64,9 @@ class MetalGraph:
if len(var_vals): self.int_buf_view[:] = list(var_vals.values())
command_buffer = self.device.mtl_queue.commandBuffer()
encoder = command_buffer.computeCommandEncoder()
encoder.executeCommandsInBuffer_withRange_(self.icb, Metal.MTLIndirectCommandBufferExecutionRangeMake(0,len(self.jit_cache)))
encoder.useResources_count_usage_(all_read_resources, len(all_read_resources), Metal.MTLResourceUsageRead)
encoder.useResources_count_usage_(self.write_resources, len(self.write_resources), Metal.MTLResourceUsageWrite)
encoder.executeCommandsInBuffer_withRange_(self.icb, Metal.MTLIndirectCommandBufferExecutionRangeMake(0,len(self.jit_cache)))
encoder.endEncoding()
command_buffer.commit()
self.command_buffer = command_buffer