# tinygrab/tinygrad/codegen/cstyle.py
from typing import Final, Dict, Callable, ClassVar, List, Optional, NamedTuple, DefaultDict, Tuple, Set, Union
import math, collections
from tinygrad.codegen.linearizer import Linearizer, UOps, UOp, LocalBuffer
from tinygrad.ops import ASTRunner, Op, UnaryOps, BinaryOps, FusedOps
from tinygrad.helpers import ImageDType, dtypes, colored, getenv, prod
from tinygrad.runtime.lib import RawConst
from tinygrad.shape.symbolic import DivNode, AndNode, render_python, NumNode, Variable
from tinygrad.lazy import LazyBuffer
# div renders differently in OpenCL than in Python (C-style "/" instead of floor division "//")
render_cl = render_python.copy()
render_cl[DivNode] = lambda self,ops,ctx: f"({self.a.render(ops, ctx)}/{self.b})"
render_cl[AndNode] = lambda self,ops,ctx: f"({'&&'.join(sorted([x.render(ops,ctx) for x in self.nodes]))})"
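# illustrative: (Variable("idx", 0, 9)//4).render(render_cl) gives "(idx/4)" rather than Python's floor-division form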
class CStyleLanguage(NamedTuple):
  kernel_prefix: str = ""
  buffer_prefix: str = ""
  buffer_suffix: str = ""
  smem_prefix: str = ""
  barrier: str = ""
  gid: List[str] = []
  lid: List[str] = []
  extra_args: List[str] = []
  float4: Optional[str] = None
  half_prekernel: Optional[str] = None
  uses_vload: bool = False
  # returns a str expression of the casted xs with the given type
  def render_cast(self, x:List[str], var_dtype) -> str:
    assert len(x) == var_dtype.sz, f"cast is wrong size {len(x)} != {var_dtype.sz}"
    assert self.float4 is not None, "cast is not supported on this platform"
    if var_dtype == dtypes._float4: return f"{self.float4}({','.join(x)})"
    if var_dtype == dtypes._float2: return f"{self.float4.replace('float4', 'float2')}({','.join(x)})"
    raise NotImplementedError(f"no cast for {var_dtype}")
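  # illustrative: assuming float4="(float4)", render_cast(["a","b","c","d"], dtypes._float4) gives "(float4)(a,b,c,d)"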
  # returns a str expression of the const with the given type
  def render_const(self, x:Union[float,int], var_dtype) -> str:
    if math.isnan(x): val = "NAN"
    elif math.isinf(x): val = ("-" if x < 0 else "") + "INFINITY"
    else: val = f"{x}" + ("" if dtypes.is_int(var_dtype) else "f")
    return self.render_cast([val]*var_dtype.sz, var_dtype) if var_dtype.sz > 1 else val
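  # illustrative: render_const(1.5, dtypes.float32) gives "1.5f"; render_const(float("inf"), dtypes.float32) gives "INFINITY"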
  # returns a str expression of the loaded value with the output type
  def render_load(self, output_dtype, buf_name, buf_dtype, idx, local=False) -> str:
    if isinstance(buf_dtype, ImageDType):
      assert output_dtype == dtypes._float4, "images must be float4"
      return f"read_imagef({buf_name}, smp, (int2)({idx[0].render(render_cl)}, {idx[1].render(render_cl)}))"
    if self.uses_vload and buf_dtype == dtypes.float16:
      return f"vload_half{'' if output_dtype.sz == 1 else str(output_dtype.sz)}(0, {buf_name}+{idx.render(render_cl, strip_parens=True)})"
    if output_dtype.sz > 1:
      return f"({output_dtype.name})(*(({self.smem_prefix if local else self.buffer_prefix}{buf_dtype.name}{output_dtype.sz}*)({buf_name}+{idx.render(render_cl, strip_parens=True)})))"
    return f"{buf_name}[{idx.render(render_cl)}]"
  # returns a str statement that does the store
  def render_store(self, buf_name, buf_dtype, var_name, var_dtype, idx, local=False) -> str:
    if isinstance(buf_dtype, ImageDType):
      assert var_dtype == dtypes._float4, "images must be float4"
      return f"write_imagef({buf_name}, (int2)({idx[0].render(render_cl)}, {idx[1].render(render_cl)}), {var_name});"
    if self.uses_vload and buf_dtype == dtypes.float16:
      return f"vstore_half{'' if var_dtype.sz == 1 else str(var_dtype.sz)}({var_name}, 0, {buf_name}+{idx.render(render_cl, strip_parens=True)});"
    if var_dtype.sz > 1:
      return f"*(({self.smem_prefix if local else self.buffer_prefix}{buf_dtype.name}{var_dtype.sz}*)({buf_name}+{idx.render(render_cl, strip_parens=True)})) = ({buf_dtype.name}{var_dtype.sz}){var_name};"
    return f"{buf_name}[{idx.render(render_cl)}] = {var_name};"
code_for_op: Final[Dict[Op, Callable]] = {
  UnaryOps.EXP2: lambda x: f"exp2({x})",
  UnaryOps.LOG2: lambda x: f"log2({x})",
  UnaryOps.SIN: lambda x: f"sin({x})",
  UnaryOps.SQRT: lambda x: f"sqrt({x})",
  BinaryOps.ADD: lambda a,b: f"({a}+{b})", BinaryOps.SUB: lambda a,b: f"({a}-{b})",
  BinaryOps.MUL: lambda a,b: f"({a}*{b})", BinaryOps.DIV: lambda a,b: f"({a}/{b})",
  BinaryOps.MAX: lambda a,b: f"max({a},{b})",
  BinaryOps.CMPEQ: lambda a,b: f"({a}=={b})", FusedOps.MULACC: lambda a,b,c: f"(({a}*{b})+{c})"
}
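# illustrative: code_for_op[FusedOps.MULACC]("a", "b", "c") gives "((a*b)+c)"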
def add_gl_dimension(args, i, var, local_size, xid):
  # for M1 tensor core stuff, support > 3 dims
  if i >= 2 and len(args[0]) > len(xid):
    # do this on the x dim for warps
    if len(local_size) == 2: local_size.append(1)
    local_size[-1] *= var.max+1
    lidx = Variable(xid[0], 0, prod(x.max+1 for x in args[0][2:])-1)
    lidx = (lidx//((lidx.max+1)//local_size[-1]))%(var.max+1)
    assert lidx.max == var.max and lidx.min == var.min
    return f"{{ int {var.expr} = {lidx.render(render_cl)}; /* {var.max+1} */"
  local_size.append(var.max+1)
  return f"{{ int {var.expr} = {xid[min(len(xid), len(args[0]))-1-i]}; /* {var.max+1} */"
def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lang:CStyleLanguage) -> Tuple[str, List[int], List[int]]:
  prekernel: Set[str] = set()
  kernel = []
  global_size = []
  local_size = []
  pend_close = None
  bufnames = [b.name if isinstance(b, LocalBuffer) else f"data{i}" for i,b in enumerate(bufs)]
  depth = 0
  def kk(s): kernel.append(" "*depth+s)
  for uop,newvar,vin,args in uops:
    if uop == UOps.LOOP:
      for i,var in enumerate(args[0]):
        if isinstance(var, NumNode):
          if args[1] == "global" and lang.gid: global_size.append(1)
          if args[1] == "local" and lang.lid: local_size.append(1)
          # one number, not an index
          kk("{")
        else:
          if args[1] == "global" and lang.gid:
            kk(add_gl_dimension(args, i, var, global_size, lang.gid))
          elif args[1] == "local" and lang.lid:
            kk(add_gl_dimension(args, i, var, local_size, lang.lid))
          else:
            if getenv("NOUNROLL"): kk("#pragma unroll(1)")  # prevent loop unrolling
            kk(f"for (int {var.expr} = {var.min}; {var.expr} <= {var.max}; ++{var.expr}) {{")
      depth += 1
    elif uop == UOps.BARRIER:
      kk(lang.barrier)
    elif uop == UOps.ENDLOOP:
      if args[1] == "local" and len(lang.lid):
        # TODO: this is a bit of a hack. the local loop isn't real on the GPU
        kk(f"if ({Variable.sum(args[0]).render(render_cl)} == 0) {{")
        pend_close = "}"*(len(args[0])+1) + f" /* {args[1]} */"
      else:
        if args[1] == "global" and pend_close:
          depth -= 1
          kk(pend_close)
          pend_close = None
        depth -= 1
        kk("}"*len(args[0]) + f" /* {args[1]} */")
    elif uop == UOps.WMMA:
      # ((lidx2*32)+(lidx3*4)+(lidx4*16)+(lidx5*8)+(lidx6*2))
      kk("{ simdgroup_float8x8 a,b,c;")
      kk(f"a.thread_elements()[0] = {vin[0].render()}; a.thread_elements()[1] = {vin[1].render()};")
      kk(f"b.thread_elements()[0] = {vin[2].render()}; b.thread_elements()[1] = {vin[3].render()};")
      kk(f"c.thread_elements()[0] = {vin[4].render()}; c.thread_elements()[1] = {vin[5].render()};")
      kk("simdgroup_multiply_accumulate(c, a, b, c);")
      kk(f"{vin[4].render()} = c.thread_elements()[0]; {vin[5].render()} = c.thread_elements()[1]; }}")
    elif uop == UOps.CONST:
      assert newvar is not None
      kk(f"{newvar.render(True)} = {lang.render_const(args, newvar.dtype)};")
    elif uop == UOps.ALU:
      assert newvar is not None
      kk(f"{newvar.render(newvar not in vin)} = {code_for_op[args](*[x.render() for x in vin])};")
    elif uop == UOps.LOAD:
      assert newvar is not None
      # valids are handled here
      if args.valid.max == 0:
        val = lang.render_const(0.0, newvar.dtype)
      elif isinstance(bufs[args.i].realized, RawConst):
        val = lang.render_const(bufs[args.i].realized._buf, newvar.dtype)
      else:
        val = lang.render_load(newvar.dtype, bufnames[args.i], bufs[args.i].dtype, args.idx, isinstance(bufs[args.i], LocalBuffer))
      if args.valid.min == 0 and args.valid.max == 1: val = f"({args.valid.render(render_cl)}) ? ({val}) : {lang.render_const(0.0, newvar.dtype)}"
      kk(f"{newvar.render(True)} = {val};")
    elif uop == UOps.STORE:
      assert args.valid.min == 1, "store must be valid"
      # TODO: instead of dtypes.float, a base type
      kk(lang.render_store(bufnames[args.i], bufs[args.i].dtype, vin[0].render(), vin[0].dtype if vin[0].offset is None else dtypes.float, args.idx, isinstance(bufs[args.i], LocalBuffer)))
    elif uop == UOps.CAST and newvar is not None and newvar.dtype.sz > 1:
      kk(f"{newvar.render(True)} = {lang.render_cast([x.render() for x in vin], newvar.dtype)};")
    elif uop == UOps.DEFINE_LOCAL:
      kk(lang.smem_prefix + f"float {args[0]}[{args[1]}];")
    else:
      raise RuntimeError(f"failed to render {uop}")
  if any(isinstance(x.dtype, ImageDType) for x in bufs): prekernel.add("const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n")
  buftypes = [(i,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if x.dtype.name.startswith('image') else
              ("const " if i > 0 else "")+lang.buffer_prefix+x.dtype.name+"*"+lang.buffer_suffix) for i,x in enumerate(bufs)
              if not isinstance(x, LocalBuffer) and not isinstance(x.realized, RawConst)]
  prg = ''.join([f"{lang.kernel_prefix} void KERNEL_NAME_PLACEHOLDER(",] +
                [', '.join([f'{t} {bufnames[i]}' for i,t in buftypes] + lang.extra_args)] +
                [") {\n"] + list(prekernel) + ['\n'.join(kernel), "\n}"])
  if lang.half_prekernel and any(x.dtype == dtypes.float16 for x in bufs): prg = ''.join([f"{lang.half_prekernel}", "\n", prg])
  return prg, global_size, local_size
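# illustrative shape of the assembled program, assuming OpenCL-style prefixes:
#   __kernel void KERNEL_NAME_PLACEHOLDER(__global float* data0, const __global float* data1) {
#     ...rendered kernel lines...
#   }
# KERNEL_NAME_PLACEHOLDER is replaced with a unique function name in CStyleCodegen.codegen below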
class CStyleCodegen(Linearizer):
  lang: ClassVar[CStyleLanguage] = CStyleLanguage()
  supports_constant_folding: bool = True
  supports_float4: bool = True
  supports_float4_alu: bool = True
  # for renaming
  kernel_cnt: Final[DefaultDict[str, int]] = collections.defaultdict(int)
  kernel_name_cache: Final[Dict[str, Tuple[str, str]]] = {}
  def codegen(self):
    self.process()
    self.hand_coded_optimizations()
    self.limit_global_dims(len(self.lang.gid))  # NOTE: this is optional now
    self.linearize()
    prg, global_size, local_size = uops_to_cstyle(self.uops, self.bufs, self.lang)
    # painfully name the function something unique
    if prg in CStyleCodegen.kernel_name_cache: function_name, display_name = CStyleCodegen.kernel_name_cache[prg]
    else:
      CStyleCodegen.kernel_cnt[self.function_name] += 1
      suffix = f"{'n'+str(CStyleCodegen.kernel_cnt[self.function_name]-1)}" if CStyleCodegen.kernel_cnt[self.function_name] > 1 else ""
      CStyleCodegen.kernel_name_cache[prg] = function_name, display_name = self.function_name+suffix, self.display_name+colored(suffix, 'BLACK')
    return ASTRunner(function_name, prg.replace("KERNEL_NAME_PLACEHOLDER", function_name),
                     global_size[::-1], local_size[::-1],
                     op_estimate=self.info.flops, mem_estimate=self.mem_estimate, display_name=display_name)
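# A minimal sketch of how a backend might specialize this codegen for OpenCL.
# The field values below are illustrative assumptions, not the canonical OpenCL backend definition.
class ExampleCLCodegen(CStyleCodegen):
  lang = CStyleLanguage(
    kernel_prefix="__kernel", buffer_prefix="__global ", smem_prefix="__local ",
    barrier="barrier(CLK_LOCAL_MEM_FENCE);", float4="(float4)", uses_vload=True,
    gid=[f"get_group_id({i})" for i in range(3)], lid=[f"get_local_id({i})" for i in range(3)])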