
Update triton to work in master (#517)

* Update triton to work in master

* Move mem_estimate out of runner
Martin Loretz 2023-02-01 21:58:14 +01:00 committed by GitHub
parent 5e37f084db
commit 45e847d284
1 changed file with 8 additions and 22 deletions
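In short, the change drops TritonASTKernel's hand-rolled compute_buf_index_symbolic and instead asks each buffer's ShapeTracker for its symbolic index and validity expressions via expr_idxs(), which master now provides directly. For readers unfamiliar with what such an expression looks like, here is a minimal, library-free sketch of the core idea: mapping per-dimension loop indices onto a flat buffer offset through strides. The helper name and its output format are illustrative, not tinygrad's actual API.

from typing import Sequence

def flat_index_expr(idxs: Sequence[str], strides: Sequence[int], offset: int = 0) -> str:
  # A strided view addresses element (i0, i1, ...) at offset + sum(i_k * stride_k).
  # Stride-0 terms (broadcast dimensions) contribute nothing, so drop them,
  # roughly what a symbolic simplifier does before the string reaches tl.load().
  terms = [f"{i}*{s}" for i, s in zip(idxs, strides) if s != 0]
  if offset: terms.append(str(offset))
  return " + ".join(terms) if terms else "0"

# A (2, 3) row-major view has strides (3, 1):
print(flat_index_expr(["idx0", "idx1"], [3, 1]))  # -> "idx0*3 + idx1*1"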


@@ -16,9 +16,6 @@ from tinygrad.shape import ShapeTracker
 from tinygrad.helpers import prod
 from tinygrad.ast import ASTKernel
-from tinygrad.shape import View, ZeroView
-from tinygrad.shape.symbolic import Variable
 stream = cuda.Stream()
 class TritonASTKernel(ASTKernel):
@@ -31,22 +28,12 @@
   }
   start_for_op = {ReduceOps.SUM: "0.0", ReduceOps.MAX: "float('-inf')"}
-  # TODO: move to shapetracker
-  def compute_buf_index_symbolic(self, st, buf_index, offset=0):
-    view = View(self.shapes[buf_index], self.strides[buf_index], self.offsets[buf_index] + offset)
-    idx = view.expr_idxs([f"idx{i}" for i in range(self.shape_len)])
-    valid = Variable.num(1)
-    for v in st.views[0:-1][::-1]:
-      if isinstance(v, ZeroView): valid = v.expr_node(valid, idx)
-      else: idx = v.expr_node(idx)
-    return idx, valid
   def ast_parse(self, x:Union[TritonBuffer, LazyOp], acc:str, do_reduce=False) -> str:
     if not isinstance(x, LazyOp):
       # this is a load
       buf_index = self.bufs.index(x)
       if buf_index not in self.loaded:
-        idx, valid = self.compute_buf_index_symbolic(self.bufs[buf_index].st, buf_index)
+        idx, valid = self.sts[buf_index].expr_idxs()
         valid_expr = str(valid).replace("&&", "*1*")
         self.kernel.append(self.kernel_prefix + f" val{buf_index} = tl.where({valid_expr}, tl.load(data{buf_index} + {idx}, mask={valid_expr}), 0.0)")
         self.loaded.add(buf_index)
@@ -70,7 +57,7 @@
     self.kernel = ["@triton.jit"]
     self.kernel.append("def fxn("+','.join(f"data{i}" for i in range(len(self.bufs)))+"):")
-    self.output_shape = list(self.shapes[0][:self.first_reduce])
+    self.output_shape = list(self.sts[0].shape[:self.first_reduce])
     # copied from ops_gpu
     # TODO: CUDA only supports a grid of (2^31-1, 65535, 65535), which results in invalid kernel launches for some shapes, so flatten the grid for now.
@@ -85,8 +72,8 @@
     elif len(self.output_shape) == 0: self.output_shape = [1]
     if self.reduceop:
-      full_shape = [x for x in self.shapes if x != self.shapes[0]]
-      full_shape = self.shapes[0] if len(full_shape) == 0 else full_shape[0]
+      full_shape = [st.shape for st in self.sts if st.shape != self.sts[0].shape]
+      full_shape = self.sts[0].shape if len(full_shape) == 0 else full_shape[0]
       self.kernel += [f" acc = {TritonASTKernel.start_for_op[self.reduceop.op]}"]
       self.kernel += [(" "*(i-self.first_reduce)+f" for idx{i} in range(0, {full_shape[i]}):") for i in range(self.first_reduce, self.shape_len)]
       self.kernel_prefix = " "*(self.shape_len - self.first_reduce)
@@ -96,7 +83,7 @@
     code = self.ast_parse(self.ast, "acc")
     # store
-    idx, valid = self.compute_buf_index_symbolic(self.bufs[0].st, 0)
+    idx, valid = self.sts[0].expr_idxs()
     self.kernel.append(f" tl.store(data0 + {idx}, {code})")
     # Torch inductor seems to write out files too!
@@ -108,8 +95,7 @@
     codeObject = compile(kernel, fn, "exec")
     exec(codeObject, globals())
     program = globals()['fxn']
-    mem_estimate = sum(prod(x) for x in self.shapes)
+    mem_estimate = sum(prod(x._base_shape) for x in self.bufs)
     def runner(*bufs):
       GlobalCounters.global_ops += self.info.flops
       GlobalCounters.global_mem += mem_estimate
@@ -120,10 +106,10 @@
 class TritonBuffer(ExplicitExecAST):
   def __init__(self, shape:Union[ShapeTracker, Tuple[int, ...]], hostbuf:Optional[TritonBuffer]=None, backing:Optional[np.ndarray]=None, force_create=False):
     super().__init__(shape, hostbuf)
     if hostbuf is not None and hostbuf._buf is None: hostbuf.torch
-    self._buf : Optional[TritonBuffer] = hostbuf._buf if hostbuf is not None else None
+    self._buf : Optional[TritonWrapper] = hostbuf._buf if hostbuf is not None else None
     self._base_shape : Tuple[int, ...] = hostbuf._base_shape if hostbuf is not None else self.shape
     self._backing : Optional[np.ndarray] = hostbuf._backing if hostbuf is not None else backing
     if force_create: self.torch
   @property
   def torch(self):
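The second commit hoists mem_estimate out of runner: the per-launch element count depends only on the kernel's buffers, so it can be computed once at codegen time instead of on every call, and it is now derived from each buffer's _base_shape rather than from the view shapes. A worked miniature of that line, with hypothetical base shapes standing in for self.bufs:

from functools import reduce
import operator

def prod(xs): return reduce(operator.mul, xs, 1)  # same role as tinygrad.helpers.prod

# Hypothetical base shapes for a 4x4 elementwise kernel: the output plus two inputs.
base_shapes = [(4, 4), (4, 4), (4, 4)]
mem_estimate = sum(prod(s) for s in base_shapes)
print(mem_estimate)  # 48 elements counted per launch

With the estimate precomputed, each runner invocation only performs the two GlobalCounters increments rather than re-walking the buffer list on every kernel launch.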