# (file-listing header from scrape — "207 lines, 11 KiB, Python" — kept as a comment, not code)
from typing import Final, Dict, Callable, ClassVar, List, Optional, NamedTuple, DefaultDict, Tuple, Set, Union
|
|
import math, collections
|
|
from tinygrad.codegen.linearizer import Linearizer, UOps, UOp, LocalBuffer
|
|
from tinygrad.ops import ASTRunner, Op, UnaryOps, BinaryOps, FusedOps
|
|
from tinygrad.helpers import ImageDType, dtypes, colored, getenv, prod
|
|
from tinygrad.runtime.lib import RawConst
|
|
from tinygrad.shape.symbolic import DivNode, AndNode, render_python, NumNode, Variable
|
|
from tinygrad.lazy import LazyBuffer
|
|
|
|
# div is different in cl than python; AndNode operands are sorted so the rendering is stable
def _render_div_cl(self, ops, ctx): return f"({self.a.render(ops, ctx)}/{self.b})"
def _render_and_cl(self, ops, ctx): return f"({'&&'.join(sorted([x.render(ops,ctx) for x in self.nodes]))})"
render_cl = {**render_python, DivNode: _render_div_cl, AndNode: _render_and_cl}
class CStyleLanguage(NamedTuple):
  """Description of one C-dialect target (OpenCL / Metal / CUDA style):
  the syntax fragments that differ between them, plus helpers that render
  loads, stores, casts and constants as source-code strings."""
  kernel_prefix: str = ""
  buffer_prefix: str = ""
  buffer_suffix: str = ""
  smem_prefix: str = ""
  barrier: str = ""
  gid: List[str] = []
  lid: List[str] = []
  extra_args: List[str] = []
  float4: Optional[str] = None
  half_prekernel: Optional[str] = None
  uses_vload: bool = False

  def render_cast(self, x:List[str], var_dtype) -> str:
    """Return a str expression of the casted xs with the given type."""
    assert len(x) == var_dtype.sz, f"cast is wrong size {len(x)} != {var_dtype.sz}"
    assert self.float4 is not None, "cast is not supported on this platform"
    elems = ','.join(x)
    if var_dtype == dtypes._float4: return f"{self.float4}({elems})"
    if var_dtype == dtypes._float2: return f"{self.float4.replace('float4', 'float2')}({elems})"
    raise NotImplementedError(f"no cast for {var_dtype}")

  def render_const(self, x:Union[float,int], var_dtype) -> str:
    """Return a str expression of the const with the given type."""
    if math.isnan(x):
      val = "NAN"
    elif math.isinf(x):
      val = "INFINITY" if x > 0 else "-INFINITY"
    else:
      # floats get a trailing 'f' suffix, ints are emitted as-is
      val = f"{x}" if dtypes.is_int(var_dtype) else f"{x}f"
    # vector dtypes broadcast the scalar through a cast
    if var_dtype.sz > 1: return self.render_cast([val]*var_dtype.sz, var_dtype)
    return val

  def render_load(self, output_dtype, buf_name, buf_dtype, idx, local=False) -> str:
    """Return a str expression of the loaded value with the output type."""
    # images go through the sampler and always load float4
    if isinstance(buf_dtype, ImageDType):
      assert output_dtype == dtypes._float4, "images must be float4"
      return f"read_imagef({buf_name}, smp, (int2)({idx[0].render(render_cl)}, {idx[1].render(render_cl)}))"
    # half buffers on platforms that require vload_half
    if self.uses_vload and buf_dtype == dtypes.float16:
      width = '' if output_dtype.sz == 1 else str(output_dtype.sz)
      return f"vload_half{width}(0, {buf_name}+{idx.render(render_cl, strip_parens=True)})"
    # vectorized load via a pointer cast
    if output_dtype.sz > 1:
      space = self.smem_prefix if local else self.buffer_prefix
      return f"({output_dtype.name})(*(({space}{buf_dtype.name}{output_dtype.sz}*)({buf_name}+{idx.render(render_cl, strip_parens=True)})))"
    return f"{buf_name}[{idx.render(render_cl)}]"

  def render_store(self, buf_name, buf_dtype, var_name, var_dtype, idx, local=False) -> str:
    """Return a str statement that does the store."""
    if isinstance(buf_dtype, ImageDType):
      assert var_dtype == dtypes._float4, "images must be float4"
      return f"write_imagef({buf_name}, (int2)({idx[0].render(render_cl)}, {idx[1].render(render_cl)}), {var_name});"
    if self.uses_vload and buf_dtype == dtypes.float16:
      width = '' if var_dtype.sz == 1 else str(var_dtype.sz)
      return f"vstore_half{width}({var_name}, 0, {buf_name}+{idx.render(render_cl, strip_parens=True)});"
    if var_dtype.sz > 1:
      space = self.smem_prefix if local else self.buffer_prefix
      return f"*(({space}{buf_dtype.name}{var_dtype.sz}*)({buf_name}+{idx.render(render_cl, strip_parens=True)})) = ({buf_dtype.name}{var_dtype.sz}){var_name};"
    return f"{buf_name}[{idx.render(render_cl)}] = {var_name};"
# map of each Op to a function that renders its C source expression
code_for_op: Final[Dict[Op, Callable]] = {
  UnaryOps.EXP2: lambda x: f"exp2({x})",
  UnaryOps.LOG2: lambda x: f"log2({x})",
  UnaryOps.SIN: lambda x: f"sin({x})",
  UnaryOps.SQRT: lambda x: f"sqrt({x})",
  BinaryOps.ADD: lambda a,b: f"({a}+{b})",
  BinaryOps.SUB: lambda a,b: f"({a}-{b})",
  BinaryOps.MUL: lambda a,b: f"({a}*{b})",
  BinaryOps.DIV: lambda a,b: f"({a}/{b})",
  BinaryOps.MAX: lambda a,b: f"max({a},{b})",
  BinaryOps.CMPEQ: lambda a,b: f"({a}=={b})",
  FusedOps.MULACC: lambda a,b,c: f"(({a}*{b})+{c})",
}
def add_gl_dimension(args, i, var, local_size, xid):
  """Open a scope binding loop variable `var` to a hardware index from `xid`,
  appending the dimension's extent to `local_size` (which may actually be the
  global-size list, depending on the caller).

  For M1 tensor core stuff: when there are more loop dims than hardware ids
  (only checked for i >= 2), the extra dims are packed into the x id."""
  extent = var.max+1
  if i >= 2 and len(args[0]) > len(xid):
    # do this on the x dim for warps: fold this dim into the last local size
    if len(local_size) == 2: local_size.append(1)
    local_size[-1] *= extent
    # recover this dim's index from the packed x id by divide-and-modulo
    lidx = Variable(xid[0], 0, prod(x.max+1 for x in args[0][2:])-1)
    lidx = (lidx//((lidx.max+1)//local_size[-1]))%extent
    assert lidx.max == var.max and lidx.min == var.min
    return f"{{ int {var.expr} = {lidx.render(render_cl)}; /* {extent} */"
  local_size.append(extent)
  return f"{{ int {var.expr} = {xid[min(len(xid), len(args[0]))-1-i]}; /* {extent} */"
def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lang:CStyleLanguage) -> Tuple[str, List[int], List[int]]:
  """Render a linearized uop list into C-style kernel source for `lang`.

  Returns (program source, global_size, local_size). The kernel is emitted
  under the placeholder name KERNEL_NAME_PLACEHOLDER; the caller substitutes
  the real function name afterwards.
  """
  prekernel: Set[str] = set()   # lines emitted once before the kernel body (e.g. the image sampler)
  kernel = []                   # accumulated body lines
  global_size = []
  local_size = []
  pend_close = None             # closing braces deferred by the local-loop hack in ENDLOOP

  # LocalBuffers keep their own name; other buffers become positional dataN kernel args
  bufnames = [b.name if isinstance(b, LocalBuffer) else f"data{i}" for i,b in enumerate(bufs)]

  depth = 0
  # append one source line at the current nesting depth
  def kk(s): kernel.append(" "*depth+s)

  for uop,newvar,vin,args in uops:
    if uop == UOps.LOOP:
      # args[0] is the list of loop variables, args[1] the scope ("global"/"local"/other)
      for i,var in enumerate(args[0]):
        if isinstance(var, NumNode):
          if args[1] == "global" and lang.gid: global_size.append(1)
          if args[1] == "local" and lang.lid: local_size.append(1)
          # one number, not an index
          kk("{")
        else:
          if args[1] == "global" and lang.gid:
            kk(add_gl_dimension(args, i, var, global_size, lang.gid))
          elif args[1] == "local" and lang.lid:
            kk(add_gl_dimension(args, i, var, local_size, lang.lid))
          else:
            # plain C for-loop when there is no hardware index for this scope
            if getenv("NOUNROLL"): kk("#pragma unroll(1)") # prevent loop unrolling
            kk(f"for (int {var.expr} = {var.min}; {var.expr} <= {var.max}; ++{var.expr}) {{")
      depth += 1
    elif uop == UOps.BARRIER:
      kk(lang.barrier)
    elif uop == UOps.ENDLOOP:
      if args[1] == "local" and len(lang.lid):
        # TODO: this is a bit of a hack. the local loop isn't real on the GPU
        # guard the code after the "loop" so only thread 0 of the local ids runs it
        kk(f"if ({Variable.sum(args[0]).render(render_cl)} == 0) {{")
        pend_close = "}"*(len(args[0])+1) + f" /* {args[1]} */"
      else:
        if args[1] == "global" and pend_close:
          depth -= 1
          kk(pend_close)
          pend_close = None
        depth -= 1
        kk("}"*len(args[0]) + f" /* {args[1]} */")
    elif uop == UOps.WMMA:
      # Metal simdgroup 8x8 matmul-accumulate; the six vin are the a/b/c element pairs
      # ((lidx2*32)+(lidx3*4)+(lidx4*16)+(lidx5*8)+(lidx6*2))
      kk("{ simdgroup_float8x8 a,b,c;")
      kk(f"a.thread_elements()[0] = {vin[0].render()}; a.thread_elements()[1] = {vin[1].render()};")
      kk(f"b.thread_elements()[0] = {vin[2].render()}; b.thread_elements()[1] = {vin[3].render()};")
      kk(f"c.thread_elements()[0] = {vin[4].render()}; c.thread_elements()[1] = {vin[5].render()};")
      kk("simdgroup_multiply_accumulate(c, a, b, c);")
      kk(f"{vin[4].render()} = c.thread_elements()[0]; {vin[5].render()} = c.thread_elements()[1]; }}")
    elif uop == UOps.CONST:
      assert newvar is not None
      kk(f"{newvar.render(True)} = {lang.render_const(args, newvar.dtype)};")
    elif uop == UOps.ALU:
      assert newvar is not None
      # NOTE(review): render's bool arg appears to control whether a declaration is
      # emitted — when newvar is also an input this is an in-place update. confirm in Token.render
      kk(f"{newvar.render(newvar not in vin)} = {code_for_op[args](*[x.render() for x in vin])};")
    elif uop == UOps.LOAD:
      assert newvar is not None
      # valids are handled here
      if args.valid.max == 0:
        # never valid: the load folds to constant zero
        val = lang.render_const(0.0, newvar.dtype)
      elif isinstance(bufs[args.i].realized, RawConst):
        # constant buffer: inline its value instead of loading
        val = lang.render_const(bufs[args.i].realized._buf, newvar.dtype)
      else:
        val = lang.render_load(newvar.dtype, bufnames[args.i], bufs[args.i].dtype, args.idx, isinstance(bufs[args.i], LocalBuffer))
      # conditionally valid: guard the load with a ternary, zero when invalid
      if args.valid.min == 0 and args.valid.max == 1: val = f"({args.valid.render(render_cl)}) ? ({val}) : {lang.render_const(0.0, newvar.dtype)}"
      kk(f"{newvar.render(True)} = {val};")
    elif uop == UOps.STORE:
      assert args.valid.min == 1, "store must be valid"
      # TODO: instead of dtypes.float, a base type
      kk(lang.render_store(bufnames[args.i], bufs[args.i].dtype, vin[0].render(), vin[0].dtype if vin[0].offset is None else dtypes.float, args.idx, isinstance(bufs[args.i], LocalBuffer)))
    elif uop == UOps.CAST and newvar is not None and newvar.dtype.sz > 1:
      kk(f"{newvar.render(True)} = {lang.render_cast([x.render() for x in vin], newvar.dtype)};")
    elif uop == UOps.DEFINE_LOCAL:
      kk(lang.smem_prefix + f"float {args[0]}[{args[1]}];")
    else:
      raise RuntimeError(f"failed to render {uop}")

  # image buffers need the sampler declared once before the body
  if any(isinstance(x.dtype, ImageDType) for x in bufs): prekernel.add("const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n")
  # argument declarations: buffer 0 is the output (write_only / non-const), the rest are read-only inputs;
  # LocalBuffers and realized constants are not kernel arguments
  buftypes = [(i,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if x.dtype.name.startswith('image') else
               ("const " if i > 0 else "")+lang.buffer_prefix+x.dtype.name+"*"+lang.buffer_suffix) for i,x in enumerate(bufs)
               if not isinstance(x, LocalBuffer) and not isinstance(x.realized, RawConst)]
  prg = ''.join([f"{lang.kernel_prefix} void KERNEL_NAME_PLACEHOLDER(",] +
    [', '.join([f'{t} {bufnames[i]}' for i,t in buftypes] + lang.extra_args)] +
    [") {\n"] + list(prekernel) + ['\n'.join(kernel), "\n}"])

  # platforms that need a preamble to enable half support get it prepended here
  if lang.half_prekernel and any(x.dtype == dtypes.float16 for x in bufs): prg = ''.join([f"{lang.half_prekernel}", "\n", prg])
  return prg, global_size, local_size
class CStyleCodegen(Linearizer):
  """Linearizer that renders its uops to C-style source and returns an ASTRunner."""
  lang: ClassVar[CStyleLanguage] = CStyleLanguage()
  supports_constant_folding: bool = True
  supports_float4: bool = True
  supports_float4_alu: bool = True

  # for renaming: per-base-name counters, plus a cache keyed on the program text
  kernel_cnt: Final[DefaultDict[str, int]] = collections.defaultdict(int)
  kernel_name_cache: Final[Dict[str, Tuple[str, str]]] = {}

  def codegen(self):
    """Lower this kernel and wrap the generated source in an ASTRunner."""
    self.process()
    self.hand_coded_optimizations()
    self.limit_global_dims(len(self.lang.gid)) # NOTE: this is optional now
    self.linearize()

    prg, global_size, local_size = uops_to_cstyle(self.uops, self.bufs, self.lang)

    # painfully name the function something unique
    cached = CStyleCodegen.kernel_name_cache.get(prg)
    if cached is not None:
      function_name, display_name = cached
    else:
      CStyleCodegen.kernel_cnt[self.function_name] += 1
      count = CStyleCodegen.kernel_cnt[self.function_name]
      # first occurrence keeps the bare name, later ones get an nK suffix
      suffix = f"n{count-1}" if count > 1 else ""
      function_name = self.function_name + suffix
      display_name = self.display_name + colored(suffix, 'BLACK')
      CStyleCodegen.kernel_name_cache[prg] = (function_name, display_name)

    return ASTRunner(function_name, prg.replace("KERNEL_NAME_PLACEHOLDER", function_name),
                     global_size[::-1], local_size[::-1],
                     op_estimate=self.info.flops, mem_estimate=self.mem_estimate, display_name=display_name)