1
0
Fork 0

fix multiple accumulators

This commit is contained in:
George Hotz 2023-01-30 16:22:26 -08:00
parent aea55eb196
commit e87410c531

View file

@ -272,9 +272,10 @@ class CLASTKernel(ASTKernel):
if self.group_for_reduce:
lidx, lvalid = self.sts[-1].expr_idxs()
assert str(lvalid) == "1", "local buffer must be valid"
self.kernel.append(("__shared__ " if CUDA else "__local ") + f"{accumulators[0].decltype()} temp[{prod(self.group_for_reduce)}]; // second stage\n")
self.kernel.append(f"int mid_idx = {lidx.cl}; {self.buftokens[-1].tok}[mid_idx] = {accumulators[0].tok};\n")
self.kernel.append(f"int mid_idx = {lidx.cl};")
for i,acc in enumerate(accumulators):
self.kernel.append(("__shared__ " if CUDA else "__local ") + f"{acc.decltype()} {self.buftokens[-1].tok}{i}[{prod(self.group_for_reduce)}]; // second stage\n")
self.kernel.append(f"{self.buftokens[-1].tok}{i}[mid_idx] = {acc.tok};\n")
self.kernel.append("barrier(CLK_LOCAL_MEM_FENCE);\n" if not CUDA else "__syncthreads();\n")
if self.upcast_in_mid_reduce:
@ -284,12 +285,14 @@ class CLASTKernel(ASTKernel):
self.upcast()
self.kernel.append("if (mid_idx == 0) {\n")
accumulators = [Token("output", self.buftokens[0].typ)]
self.kernel.append(f"{accumulators[0].decltype()} {accumulators[0].tok} = 0.0;\n")
if self.upcast_in_mid_reduce:
self.kernel.append(f"for (int mid = 0; mid < {prod(self.group_for_reduce)//4}; mid++) {{ {CLASTKernel.code_for_op[self.reduceop.op].replace('A', accumulators[0].tok).replace('B', 'vload4(0, &temp[mid*4])')}; }}\n")
else:
self.kernel.append(f"for (int mid = 0; mid < {prod(self.group_for_reduce)}; mid++) {{ {CLASTKernel.code_for_op[self.reduceop.op].replace('A', accumulators[0].tok).replace('B', 'temp[mid]')}; }}\n")
new_accumulators = [Token(f"output{i}", self.buftokens[0].typ) for i in range(len(accumulators))]
for i,acc in enumerate(new_accumulators):
self.kernel.append(f"{acc.decltype()} {acc.tok} = 0.0;\n")
if self.upcast_in_mid_reduce:
self.kernel.append(f"for (int mid = 0; mid < {prod(self.group_for_reduce)//4}; mid++) {{ {CLASTKernel.code_for_op[self.reduceop.op].replace('A', acc.tok).replace('B', f'vload4(0, &temp{i}[mid*4])')}; }}\n")
else:
self.kernel.append(f"for (int mid = 0; mid < {prod(self.group_for_reduce)}; mid++) {{ {CLASTKernel.code_for_op[self.reduceop.op].replace('A', acc.tok).replace('B', f'temp{i}[mid]')}; }}\n")
accumulators = new_accumulators
# late ast
self.store(0, self.ast_parse(self.ast, accumulators))