1
0
Fork 0
tinygrab/tinygrad/renderer/cstyle.py

198 lines
11 KiB
Python

from typing import Dict, List, Optional, NamedTuple, Tuple, Union
import math
from tinygrad.codegen.linearizer import UOps, UOp, MemOp, ConstOp
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
from tinygrad.helpers import ImageDType, dtypes, getenv, prod, DType
from tinygrad.shape.symbolic import DivNode, AndNode, render_python, NumNode, Variable, sym_render
# Renderer table for C-like targets: start from a copy of the python table and
# override the two node kinds that are written differently here.
#   DivNode -> "(a/b)"            (div is different in cl than python)
#   AndNode -> "(x&&y&&...)"      (operand strings are sorted before joining)
render_cl = render_python.copy()
render_cl.update({
    DivNode: lambda self,ops,ctx: f"({self.a.render(ops, ctx)}/{self.b})",
    AndNode: lambda self,ops,ctx: f"({'&&'.join(sorted([x.render(ops,ctx) for x in self.nodes]))})",
})
class CStyleLanguage(NamedTuple):
    """Settings and string-rendering helpers for one C-like kernel language.

    The fields are keywords/prefixes that get spliced into generated source and
    flags that choose between alternative renderings. Every method returns a
    source fragment as a string (render_kernel also returns the launch sizes).
    """
    # type keyword used when declaring the index variable of a global/local dimension
    size_prefix: str = "int"
    # prepended to newly declared variables in rendered ALU/LOAD statements
    generic_var_prefix: str = ""
    # text placed before "void" in the kernel signature
    kernel_prefix: str = ""
    # text placed before / after the pointer type of a plain buffer argument
    buffer_prefix: str = ""
    buffer_suffix: str = ""
    # text placed before local buffer declarations and local vector pointer casts
    smem_prefix: str = ""
    # full type text used for arguments whose dtype is dtypes._arg_int32
    arg_int_prefix: str = ""
    # statement emitted for a BARRIER uop
    barrier: str = ""
    # per-dimension expressions for the global / local index
    gid: List[str] = []
    lid: List[str] = []
    # NOTE(review): not read in this file; presumably launch-size limits -- confirm with callers
    global_max: List[int] = []
    local_max: List[int] = []
    # extra parameter declarations appended to the kernel argument list
    extra_args: List[str] = []
    # constructor text for float4 values; None means vector casts are unsupported
    float4: Optional[str] = None
    # text prepended to the program when any buffer is float16
    half_prekernel: Optional[str] = None
    # use vload_half/vstore_half for float16 buffers
    uses_vload: bool = False
    # when set, the caller collects local buffer declarations into prekernel instead of the kernel body
    external_local_bufs: bool = False
    # render scalar accesses as *(buf+idx) rather than buf[idx]
    uses_ptr_arithmetic: bool = False
    # emit a __launch_bounds__ annotation sized by prod(local_size)
    launch_bounds: bool = False
    # maps each op to a function turning operand strings into an expression string
    code_for_op: Dict = {
        UnaryOps.EXP2: lambda x: f"exp2({x})",
        UnaryOps.LOG2: lambda x: f"log2({x})",
        UnaryOps.SIN: lambda x: f"sin({x})",
        UnaryOps.SQRT: lambda x: f"sqrt({x})",
        BinaryOps.ADD: lambda a,b: f"({a}+{b})",
        BinaryOps.SUB: lambda a,b: f"({a}-{b})",
        BinaryOps.MUL: lambda a,b: f"({a}*{b})",
        BinaryOps.DIV: lambda a,b: f"({a}/{b})",
        BinaryOps.MAX: lambda a,b: f"max({a},{b})",
        BinaryOps.MOD: lambda a,b: f"({a}%{b})",
        BinaryOps.CMPLT: lambda a,b: f"({a}<{b})",
        TernaryOps.MULACC: lambda a,b,c: f"(({a}*{b})+{c})",
        TernaryOps.WHERE: lambda a,b,c: f"({a}!=0?{b}:{c})",
    }

    def render_cast(self, x:List[str], var_dtype:DType) -> str:
        """Return an expression that packs the strings in x into a var_dtype vector.

        Only dtypes._float4 and dtypes._float2 are handled; anything else raises
        NotImplementedError. The float2 constructor is derived from self.float4.
        """
        assert len(x) == var_dtype.sz, f"cast is wrong size {len(x)} != {var_dtype.sz}"
        assert self.float4 is not None, "cast is not supported on this platform"
        if var_dtype == dtypes._float4:
            constructor = self.float4
        elif var_dtype == dtypes._float2:
            constructor = self.float4.replace('float4', 'float2')
        else:
            raise NotImplementedError(f"no cast for {var_dtype}")
        return f"{constructor}({','.join(x)})"

    def render_const(self, x:Union[float,int], var_dtype) -> str:
        """Return an expression for the constant x with the given type.

        NaN and infinities become NAN / INFINITY / -INFINITY. A python float with
        a float dtype gets an "f" suffix; everything else is rendered via int(x).
        Vector dtypes repeat the literal and go through render_cast.
        """
        if math.isnan(x):
            literal = "NAN"
        elif math.isinf(x):
            literal = ("-" if x < 0 else "") + "INFINITY"
        elif dtypes.is_float(var_dtype) and isinstance(x, float):
            literal = f"{x}f"
        else:
            literal = f"{int(x)}"
        if var_dtype.sz > 1:
            return self.render_cast([literal]*var_dtype.sz, var_dtype)
        return literal

    def render_load(self, output_dtype, buf_name, buf_dtype, idx, local=False) -> str:
        """Return an expression that loads a value of output_dtype from buf_name at idx.

        For image buffers idx is indexed as a pair of coordinates; otherwise it is
        a single renderable index.
        """
        if isinstance(buf_dtype, ImageDType):
            assert output_dtype == dtypes._float4, "images must be float4"
            coord_x, coord_y = idx[0].render(render_cl), idx[1].render(render_cl)
            return f"read_imagef({buf_name}, smp, (int2)({coord_x}, {coord_y}))"
        if self.uses_vload and buf_dtype == dtypes.float16:
            # scalar loads use vload_half, vector loads vload_half<N>
            width = '' if output_dtype.sz == 1 else str(output_dtype.sz)
            return f"vload_half{width}(0, {buf_name}+{idx.render(render_cl, strip_parens=True)})"
        if output_dtype.sz > 1:
            # reinterpret the element pointer as a vector pointer and dereference it
            space = self.smem_prefix if local else self.buffer_prefix
            pointer_type = f"{space}{buf_dtype.name}{output_dtype.sz}*"
            return f"({output_dtype.name})(*(({pointer_type})({buf_name}+{idx.render(render_cl, strip_parens=True)})))"
        if self.uses_ptr_arithmetic:
            return f"*({buf_name}+{idx.render(render_cl, strip_parens=True)})"
        return f"{buf_name}[{idx.render(render_cl)}]"

    def render_local(self, name:str, size:int):
        """Return the declaration of a local float array called name with size elements."""
        declaration = f"float {name}[{size}];"
        return self.smem_prefix + declaration

    def render_for(self, expr: str, _min:int, _max:Union[int,str]) -> str:
        """Return the opening line of a for loop running expr from _min to _max inclusive."""
        init = f"int {expr} = {_min}"
        test = f"{expr} <= {_max}"
        step = f"++{expr}"
        return f"for ({init}; {test}; {step}) {{"

    def render_conditional(self, cond: str, x:str, y:str) -> str:
        """Return a ternary expression selecting x when cond holds, else y."""
        return f"({cond})?({x}):{y}"

    def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,DType]], global_size:List[int], local_size:List[int], prekernel:List[str]) -> Tuple[str,List[int],List[int]]:
        """Assemble the full kernel source and return it with the reversed global/local sizes.

        bufs is a list of (name, dtype) pairs; the buffer at position 0 is declared
        writable (write_only image / non-const pointer), all later ones read-only.
        prekernel is accepted but not used by this implementation.
        """
        # a sampler is declared at the top of the body only when an image buffer is present
        sampler_decl = ""
        if any(isinstance(dtype, ImageDType) for _,dtype in bufs):
            sampler_decl = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"

        params: List[str] = []
        for position,(name,dtype) in enumerate(bufs):
            is_input = position > 0
            if dtype.name.startswith('image'):
                ctype = f"{'read_only' if is_input else 'write_only'} image2d_t"
            elif dtype == dtypes._arg_int32:
                ctype = self.arg_int_prefix
            else:
                ctype = ("const " if is_input else "")+self.buffer_prefix+dtype.name+"*"+self.buffer_suffix
            params.append(f'{ctype} {name}')

        bounds = f'__launch_bounds__ ({prod(local_size)}, 1) ' if self.launch_bounds else ''
        signature = f"{self.kernel_prefix} void {bounds}{function_name}("
        prg = signature + ', '.join(params + self.extra_args) + ") {\n" + sampler_decl + '\n'.join(kernel) + "\n}"

        if self.half_prekernel and any(dtype == dtypes.float16 for _,dtype in bufs):
            prg = f"{self.half_prekernel}\n{prg}"
        return prg, global_size[::-1], local_size[::-1]

    def render_store(self, buf_name:str, buf_dtype:DType, var_name:str, var_dtype:DType, idx, local=False) -> str:
        """Return a statement that stores var_name (of var_dtype) into buf_name at idx.

        For image buffers idx is indexed as a pair of coordinates; otherwise it is
        a single renderable index.
        """
        if isinstance(buf_dtype, ImageDType):
            assert var_dtype == dtypes._float4, "images must be float4"
            coord_x, coord_y = idx[0].render(render_cl), idx[1].render(render_cl)
            return f"write_imagef({buf_name}, (int2)({coord_x}, {coord_y}), {var_name});"
        if self.uses_vload and buf_dtype == dtypes.float16:
            # scalar stores use vstore_half, vector stores vstore_half<N>
            width = '' if var_dtype.sz == 1 else str(var_dtype.sz)
            return f"vstore_half{width}({var_name}, 0, {buf_name}+{idx.render(render_cl, strip_parens=True)});"
        if var_dtype.sz > 1:
            # write through a vector-typed view of the element pointer
            space = self.smem_prefix if local else self.buffer_prefix
            vector_type = f"{buf_dtype.name}{var_dtype.sz}"
            return f"*(({space}{vector_type}*)({buf_name}+{idx.render(render_cl, strip_parens=True)})) = ({vector_type}){var_name};"
        if self.uses_ptr_arithmetic:
            return f"*({buf_name}+{idx.render(render_cl, strip_parens=True)}) = {var_name};"
        return f"{buf_name}[{idx.render(render_cl)}] = {var_name};"
def add_gl_dimension(prefix: str, args, i:int, var, local_size:List[int], xid:List[str]):
    """Record the extent of one global/local dimension and return its opening line.

    args[0] is the list of loop variables, i the position of var within it, and
    xid the per-dimension index expressions of the target. The list passed as
    local_size is mutated in place: the extent (var.max+1) is either appended or
    folded into the last entry. The returned line opens a scope and, unless var
    is a NumNode, declares var from the matching index expression.
    """
    dims = args[0]
    extent = var.max+1

    # for M1 tensor core stuff, support > 3 dims
    if i >= 2 and len(dims) > len(xid):
        # do this on the x dim for warps
        if len(local_size) == 2:
            local_size.append(1)
        local_size[-1] *= extent
        # var is recovered from xid[0] with a divide and a modulo
        packed = Variable(xid[0], 0, prod(x.max+1 for x in dims[2:])-1)
        lidx = (packed//((packed.max+1)//local_size[-1]))%extent
        assert lidx.max == var.max and lidx.min == var.min
        if isinstance(var, NumNode):
            return "{"
        return f"{{ {prefix} {var.expr} = {lidx.render(render_cl)}; /* {extent} */"

    local_size.append(extent)
    if isinstance(var, NumNode):
        return "{"
    return f"{{ {prefix} {var.expr} = {xid[min(len(xid), len(dims))-1-i]}; /* {extent} */"
def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tuple[str, List[int], List[int]]:
    """Render a list of uops into kernel source using the settings in lang.

    Each uop is unpacked as (uop, newvar, vin, args). Rendered lines are collected
    in `kernel`, DEFINE_GLOBAL args in `bufs`, and the extents of global/local
    dimensions in `global_size`/`local_size` (filled by add_gl_dimension).

    Returns whatever lang.render_kernel returns for the collected pieces.

    Raises:
      NotImplementedError: for a WMMA uop whose args is neither "METAL" nor "HIP".
      RuntimeError: for any uop not handled below. A CAST whose newvar is None or
        whose dtype has sz <= 1 also ends up here.
    """
    global_size: List[int] = []
    local_size: List[int] = []
    kernel,prekernel = [],[]
    # closing braces of a local loop, held back until the enclosing global loop ends
    pend_close = None
    bufs = []
    depth = 0
    # append one line to the kernel body, indented by one space per nesting level
    def kk(s): kernel.append(" "*depth+s)
    for uop,newvar,vin,args in uops:
        if uop == UOps.LOOP:
            # args[0] is the list of loop variables, args[1] the loop kind ("global", "local", or other)
            for i,var in enumerate(args[0]):
                if args[1] == "global" and lang.gid:
                    kk(add_gl_dimension(lang.size_prefix, args, i, var, global_size, lang.gid))
                elif args[1] == "local" and lang.lid:
                    kk(add_gl_dimension(lang.size_prefix, args, i, var, local_size, lang.lid))
                else:
                    if getenv("NOUNROLL") and not isinstance(var, NumNode): kk("#pragma unroll(1)") # prevent loop unrolling
                    kk("{" if isinstance(var, NumNode) else lang.render_for(var.expr, var.min, sym_render(var.max)))
            # depth grows once per LOOP uop, not once per loop variable
            depth += 1
        elif uop == UOps.BARRIER:
            kk(lang.barrier)
        elif uop == UOps.ENDLOOP:
            if args[1] == "local" and lang.lid:
                # TODO: this is a bit of a hack. the local loop isn't real on the GPU
                # open a guard that only passes when the sum of the local variables is 0;
                # its brace and the local loop braces are closed later via pend_close
                kk(f"if ({Variable.sum(args[0]).render(render_cl)} == 0) {{")
                pend_close = "}"*(len(args[0])+1) + f" /* {args[1]} */"
            else:
                if args[1] == "global" and pend_close:
                    # close the deferred local scope before closing the global one
                    depth -= 1
                    kk(pend_close)
                    pend_close = None
                depth -= 1
                kk("}"*len(args[0]) + f" /* {args[1]} */")
        elif uop == UOps.WMMA:
            if args == "METAL":
                # ((lidx2*32)+(lidx3*4)+(lidx4*16)+(lidx5*8)+(lidx6*2))
                # vin[0:2] -> a, vin[2:4] -> b, vin[4:6] -> c; the result is written back to vin[4], vin[5]
                kk("{ simdgroup_float8x8 a,b,c;")
                kk(f"a.thread_elements()[0] = {vin[0].render()}; a.thread_elements()[1] = {vin[1].render()};")
                kk(f"b.thread_elements()[0] = {vin[2].render()}; b.thread_elements()[1] = {vin[3].render()};")
                kk(f"c.thread_elements()[0] = {vin[4].render()}; c.thread_elements()[1] = {vin[5].render()};")
                kk("simdgroup_multiply_accumulate(c, a, b, c);")
                kk(f"{vin[4].render()} = c.thread_elements()[0]; {vin[5].render()} = c.thread_elements()[1]; }}")
            elif args == "HIP":
                # vin[0:8] -> c_frag (accumulator), vin[8:24] -> a_frag, vin[24:40] -> b_frag;
                # the 8 results are written back to vin[0:8]
                kk("{")
                kk(f"half16 a_frag = {{ {','.join(['(half)'+x.render() for x in vin[8:8+16]])} }};")
                kk(f"half16 b_frag = {{ {','.join(['(half)'+x.render() for x in vin[8+16:8+32]])} }};")
                kk(f"float8 c_frag = {{ {','.join([x.render() for x in vin[:8]])} }};")
                kk("c_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag, b_frag, c_frag);")
                for i in range(8): kk(f"{vin[i].render()} = c_frag[{i}];")
                kk("}")
            else:
                raise NotImplementedError(f"WMMA not implemented for {args}")
        elif uop == UOps.ALU:
            assert newvar is not None
            # when newvar is also one of the inputs, the generic prefix is omitted and render is called with False
            kk(f"{lang.generic_var_prefix if newvar not in vin else ''}{newvar.render(newvar not in vin and lang.generic_var_prefix == '')} = {lang.code_for_op[args](*[x.render() for x in vin])};")
        elif uop == UOps.LOAD:
            assert newvar is not None and isinstance(args, (MemOp, ConstOp))
            # valids are handled here
            if isinstance(args, ConstOp):
                val = lang.render_const(args.value, newvar.dtype)
            else:
                val = lang.render_load(newvar.dtype, args.name, args.memory_dtype, args.idx, args.local)
            # a valid that ranges over 0..1 selects between the value and args.invalid_value
            if args.valid.min == 0 and args.valid.max == 1: val = lang.render_conditional(args.valid.render(render_cl), val, lang.render_const(args.invalid_value, newvar.dtype))
            kk(f"{lang.generic_var_prefix}{newvar.render(lang.generic_var_prefix == '')} = {val};")
        elif uop == UOps.STORE:
            assert args.valid.min == 1 and isinstance(args, MemOp), "store must be valid and to memory"
            # TODO: instead of dtypes.float, a base type
            # a source with an offset is stored as dtypes.float rather than with its own dtype
            kk(lang.render_store(args.name, args.memory_dtype, vin[0].render(), vin[0].dtype if vin[0].offset is None else dtypes.float, args.idx, args.local))
        elif uop == UOps.CAST and newvar is not None and newvar.dtype.sz > 1:
            kk(f"{newvar.render(True)} = {lang.render_cast([x.render() for x in vin], newvar.dtype)};")
        elif uop == UOps.DEFINE_LOCAL:
            # args is (name, size); the declaration goes to prekernel or into the kernel body
            if lang.external_local_bufs:
                prekernel.append(lang.render_local(args[0], args[1]))
            else:
                kk(lang.render_local(args[0], args[1]))
        elif uop == UOps.DEFINE_GLOBAL:
            bufs.append(args)
        else:
            raise RuntimeError(f"failed to render {uop}")
    return lang.render_kernel(function_name, kernel, bufs, global_size, local_size, prekernel)