From 45e847d28488b9e67449f8a932eb2b636a456947 Mon Sep 17 00:00:00 2001
From: Martin Loretz <20306567+martinloretzzz@users.noreply.github.com>
Date: Wed, 1 Feb 2023 21:58:14 +0100
Subject: [PATCH] Update triton to work in master (#517)

* Update triton to work in master

* Move mem_estimate out of runner
---
 accel/triton/ops_triton.py | 30 ++++++++----------------------
 1 file changed, 8 insertions(+), 22 deletions(-)

diff --git a/accel/triton/ops_triton.py b/accel/triton/ops_triton.py
index d2a5733aa..a78293c48 100644
--- a/accel/triton/ops_triton.py
+++ b/accel/triton/ops_triton.py
@@ -16,9 +16,6 @@ from tinygrad.shape import ShapeTracker
 from tinygrad.helpers import prod
 from tinygrad.ast import ASTKernel
 
-from tinygrad.shape import View, ZeroView
-from tinygrad.shape.symbolic import Variable
-
 stream = cuda.Stream()
 
 class TritonASTKernel(ASTKernel):
@@ -31,22 +28,12 @@ class TritonASTKernel(ASTKernel):
   }
   start_for_op = {ReduceOps.SUM: "0.0", ReduceOps.MAX: "float('-inf')"}
 
-  # TODO: move to shapetracker
-  def compute_buf_index_symbolic(self, st, buf_index, offset=0):
-    view = View(self.shapes[buf_index], self.strides[buf_index], self.offsets[buf_index] + offset)
-    idx = view.expr_idxs([f"idx{i}" for i in range(self.shape_len)])
-    valid = Variable.num(1)
-    for v in st.views[0:-1][::-1]:
-      if isinstance(v, ZeroView): valid = v.expr_node(valid, idx)
-      else: idx = v.expr_node(idx)
-    return idx, valid
-
   def ast_parse(self, x:Union[TritonBuffer, LazyOp], acc:str, do_reduce=False) -> str:
     if not isinstance(x, LazyOp):
       # this is a load
       buf_index = self.bufs.index(x)
       if buf_index not in self.loaded:
-        idx, valid = self.compute_buf_index_symbolic(self.bufs[buf_index].st, buf_index)
+        idx, valid = self.sts[buf_index].expr_idxs()
         valid_expr = str(valid).replace("&&", "*1*")
         self.kernel.append(self.kernel_prefix + f" val{buf_index} = tl.where({valid_expr}, tl.load(data{buf_index} + {idx}, mask={valid_expr}), 0.0)")
         self.loaded.add(buf_index)
@@ -70,7 +57,7 @@ class TritonASTKernel(ASTKernel):
 
     self.kernel = ["@triton.jit"]
     self.kernel.append("def fxn("+','.join(f"data{i}" for i in range(len(self.bufs)))+"):")
-    self.output_shape = list(self.shapes[0][:self.first_reduce])
+    self.output_shape = list(self.sts[0].shape[:self.first_reduce])
 
     # copied from ops_gpu
     # TODO CUDA only supports a grid of (2^31-1, 65535, 65535), that results in invalid kernel launches for some shapes, so flattern the grid for now.
@@ -85,8 +72,8 @@ class TritonASTKernel(ASTKernel):
     elif len(self.output_shape) == 0: self.output_shape = [1]
 
     if self.reduceop:
-      full_shape = [x for x in self.shapes if x != self.shapes[0]]
-      full_shape = self.shapes[0] if len(full_shape) == 0 else full_shape[0]
+      full_shape = [st.shape for st in self.sts if st.shape != self.sts[0].shape]
+      full_shape = self.sts[0].shape if len(full_shape) == 0 else full_shape[0]
       self.kernel += [f" acc = {TritonASTKernel.start_for_op[self.reduceop.op]}"]
       self.kernel += [(" "*(i-self.first_reduce)+f" for idx{i} in range(0, {full_shape[i]}):") for i in range(self.first_reduce, self.shape_len)]
       self.kernel_prefix = " "*(self.shape_len - self.first_reduce)
@@ -96,7 +83,7 @@ class TritonASTKernel(ASTKernel):
     code = self.ast_parse(self.ast, "acc")
 
     # store
-    idx, valid = self.compute_buf_index_symbolic(self.bufs[0].st, 0)
+    idx, valid = self.sts[0].expr_idxs()
     self.kernel.append(f" tl.store(data0 + {idx}, {code})")
 
     # Torch inductor seems to write out files too!
@@ -108,8 +95,7 @@ class TritonASTKernel(ASTKernel):
     codeObject = compile(kernel, fn, "exec")
     exec(codeObject, globals())
     program = globals()['fxn']
-
-    mem_estimate = sum(prod(x) for x in self.shapes)
+    mem_estimate = sum(prod(x._base_shape) for x in self.bufs)
     def runner(*bufs):
      GlobalCounters.global_ops += self.info.flops
      GlobalCounters.global_mem += mem_estimate
@@ -120,10 +106,10 @@ class TritonASTKernel(ASTKernel):
 class TritonBuffer(ExplicitExecAST):
   def __init__(self, shape:Union[ShapeTracker, Tuple[int, ...]], hostbuf:Optional[TritonBuffer]=None, backing:Optional[np.ndarray]=None, force_create=False):
     super().__init__(shape, hostbuf)
-    if hostbuf is not None and hostbuf._buf is None: hostbuf.torch
-    self._buf : Optional[TritonBuffer] = hostbuf._buf if hostbuf is not None else None
+    self._buf : Optional[TritonWrapper] = hostbuf._buf if hostbuf is not None else None
     self._base_shape : Tuple[int, ...] = hostbuf._base_shape if hostbuf is not None else self.shape
     self._backing : Optional[np.ndarray] = hostbuf._backing if hostbuf is not None else backing
+    if force_create: self.torch
 
   @property
   def torch(self):
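
Note (not part of the patch): the core change above is that per-buffer index and validity expressions now come straight from ShapeTracker.expr_idxs() instead of the removed compute_buf_index_symbolic helper, and are spliced into the generated Triton source as strings. The following is a minimal standalone sketch of that string assembly; emit_load and the example idx/valid values are hypothetical stand-ins for illustration, not tinygrad or Triton APIs.

# Minimal sketch, assuming idx/valid arrive as stringified symbolic expressions
# from ShapeTracker.expr_idxs(); emit_load and the example values below are made up.
def emit_load(buf_index: int, idx: str, valid: str, kernel_prefix: str = "") -> str:
  valid_expr = str(valid).replace("&&", "*1*")  # same "&&" rewrite the patch keeps as context
  return kernel_prefix + f" val{buf_index} = tl.where({valid_expr}, tl.load(data{buf_index} + {idx}, mask={valid_expr}), 0.0)"

# Example: one buffer with a flattened 2D index and an always-true validity expression (hypothetical values).
print(emit_load(1, "(idx0*4)+idx1", "1"))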