Update triton to work in master (#517)
* Update triton to work in master * Move mem_estimate out of runner
pull/519/head
parent
5e37f084db
commit
45e847d284
|
@ -16,9 +16,6 @@ from tinygrad.shape import ShapeTracker
|
|||
from tinygrad.helpers import prod
|
||||
from tinygrad.ast import ASTKernel
|
||||
|
||||
from tinygrad.shape import View, ZeroView
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
|
||||
stream = cuda.Stream()
|
||||
|
||||
class TritonASTKernel(ASTKernel):
|
||||
|
@ -31,22 +28,12 @@ class TritonASTKernel(ASTKernel):
|
|||
}
|
||||
start_for_op = {ReduceOps.SUM: "0.0", ReduceOps.MAX: "float('-inf')"}
|
||||
|
||||
# TODO: move to shapetracker
|
||||
def compute_buf_index_symbolic(self, st, buf_index, offset=0):
|
||||
view = View(self.shapes[buf_index], self.strides[buf_index], self.offsets[buf_index] + offset)
|
||||
idx = view.expr_idxs([f"idx{i}" for i in range(self.shape_len)])
|
||||
valid = Variable.num(1)
|
||||
for v in st.views[0:-1][::-1]:
|
||||
if isinstance(v, ZeroView): valid = v.expr_node(valid, idx)
|
||||
else: idx = v.expr_node(idx)
|
||||
return idx, valid
|
||||
|
||||
def ast_parse(self, x:Union[TritonBuffer, LazyOp], acc:str, do_reduce=False) -> str:
|
||||
if not isinstance(x, LazyOp):
|
||||
# this is a load
|
||||
buf_index = self.bufs.index(x)
|
||||
if buf_index not in self.loaded:
|
||||
idx, valid = self.compute_buf_index_symbolic(self.bufs[buf_index].st, buf_index)
|
||||
idx, valid = self.sts[buf_index].expr_idxs()
|
||||
valid_expr = str(valid).replace("&&", "*1*")
|
||||
self.kernel.append(self.kernel_prefix + f" val{buf_index} = tl.where({valid_expr}, tl.load(data{buf_index} + {idx}, mask={valid_expr}), 0.0)")
|
||||
self.loaded.add(buf_index)
|
||||
|
@ -70,7 +57,7 @@ class TritonASTKernel(ASTKernel):
|
|||
self.kernel = ["@triton.jit"]
|
||||
self.kernel.append("def fxn("+','.join(f"data{i}" for i in range(len(self.bufs)))+"):")
|
||||
|
||||
self.output_shape = list(self.shapes[0][:self.first_reduce])
|
||||
self.output_shape = list(self.sts[0].shape[:self.first_reduce])
|
||||
|
||||
# copied from ops_gpu
|
||||
# TODO CUDA only supports a grid of (2^31-1, 65535, 65535), that results in invalid kernel launches for some shapes, so flatten the grid for now.
|
||||
|
@ -85,8 +72,8 @@ class TritonASTKernel(ASTKernel):
|
|||
elif len(self.output_shape) == 0: self.output_shape = [1]
|
||||
|
||||
if self.reduceop:
|
||||
full_shape = [x for x in self.shapes if x != self.shapes[0]]
|
||||
full_shape = self.shapes[0] if len(full_shape) == 0 else full_shape[0]
|
||||
full_shape = [st.shape for st in self.sts if st.shape != self.sts[0].shape]
|
||||
full_shape = self.sts[0].shape if len(full_shape) == 0 else full_shape[0]
|
||||
self.kernel += [f" acc = {TritonASTKernel.start_for_op[self.reduceop.op]}"]
|
||||
self.kernel += [(" "*(i-self.first_reduce)+f" for idx{i} in range(0, {full_shape[i]}):") for i in range(self.first_reduce, self.shape_len)]
|
||||
self.kernel_prefix = " "*(self.shape_len - self.first_reduce)
|
||||
|
@ -96,7 +83,7 @@ class TritonASTKernel(ASTKernel):
|
|||
code = self.ast_parse(self.ast, "acc")
|
||||
|
||||
# store
|
||||
idx, valid = self.compute_buf_index_symbolic(self.bufs[0].st, 0)
|
||||
idx, valid = self.sts[0].expr_idxs()
|
||||
self.kernel.append(f" tl.store(data0 + {idx}, {code})")
|
||||
|
||||
# Torch inductor seems to write out files too!
|
||||
|
@ -108,8 +95,7 @@ class TritonASTKernel(ASTKernel):
|
|||
codeObject = compile(kernel, fn, "exec")
|
||||
exec(codeObject, globals())
|
||||
program = globals()['fxn']
|
||||
|
||||
mem_estimate = sum(prod(x) for x in self.shapes)
|
||||
mem_estimate = sum(prod(x._base_shape) for x in self.bufs)
|
||||
def runner(*bufs):
|
||||
GlobalCounters.global_ops += self.info.flops
|
||||
GlobalCounters.global_mem += mem_estimate
|
||||
|
@ -120,10 +106,10 @@ class TritonASTKernel(ASTKernel):
|
|||
class TritonBuffer(ExplicitExecAST):
|
||||
def __init__(self, shape:Union[ShapeTracker, Tuple[int, ...]], hostbuf:Optional[TritonBuffer]=None, backing:Optional[np.ndarray]=None, force_create=False):
|
||||
super().__init__(shape, hostbuf)
|
||||
if hostbuf is not None and hostbuf._buf is None: hostbuf.torch
|
||||
self._buf : Optional[TritonBuffer] = hostbuf._buf if hostbuf is not None else None
|
||||
self._buf : Optional[TritonWrapper] = hostbuf._buf if hostbuf is not None else None
|
||||
self._base_shape : Tuple[int, ...] = hostbuf._base_shape if hostbuf is not None else self.shape
|
||||
self._backing : Optional[np.ndarray] = hostbuf._backing if hostbuf is not None else backing
|
||||
if force_create: self.torch
|
||||
|
||||
@property
|
||||
def torch(self):
|
||||
|
|
Loading…
Reference in New Issue