2023-08-05 12:07:04 -06:00
|
|
|
from typing import Dict, List, Optional, NamedTuple, Tuple, Union
|
2023-08-05 09:53:25 -06:00
|
|
|
import math
|
2023-08-05 12:07:04 -06:00
|
|
|
from tinygrad.codegen.linearizer import UOps, UOp, MemOp, ConstOp
|
|
|
|
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
|
2023-08-05 09:53:25 -06:00
|
|
|
from tinygrad.helpers import ImageDType, dtypes, getenv, prod, DType
|
2023-07-08 16:54:58 -06:00
|
|
|
from tinygrad.shape.symbolic import DivNode, AndNode, render_python, NumNode, Variable
|
2023-03-20 00:43:49 -06:00
|
|
|
|
|
|
|
# OpenCL-flavored symbolic renderers: C-style division truncates toward zero
# (unlike Python's floor division) and logical AND is spelled "&&".
def _cl_div(self, ops, ctx):
  return f"({self.a.render(ops, ctx)}/{self.b})"

def _cl_and(self, ops, ctx):
  parts = sorted(x.render(ops, ctx) for x in self.nodes)
  return f"({'&&'.join(parts)})"

render_cl = render_python.copy()
render_cl[DivNode] = _cl_div
render_cl[AndNode] = _cl_and
|
|
|
|
|
|
|
|
class CStyleLanguage(NamedTuple):
  """Syntax configuration for one C-style kernel dialect (OpenCL/Metal/CUDA-like).

  Each field is a source-text fragment substituted while rendering; the
  render_* methods turn linearizer constructs into source strings for the
  configured dialect.
  """
  size_prefix: str = "int"               # type used to declare loop/index variables
  generic_var_prefix: str = ""           # if non-empty, temps are declared with this prefix instead of a concrete type
  kernel_prefix: str = ""                # qualifier before the kernel entry point (e.g. "__kernel")
  buffer_prefix: str = ""                # qualifier prepended to global buffer pointer types
  buffer_suffix: str = ""                # qualifier appended after buffer parameters
  smem_prefix: str = ""                  # qualifier for shared/local-memory declarations
  barrier: str = ""                      # statement emitted for a workgroup barrier
  gid: List[str] = []                    # per-axis global thread-id expressions (empty => no hardware global ids)
  lid: List[str] = []                    # per-axis local thread-id expressions (empty => no hardware local ids)
  global_max: List[int] = []             # per-axis global size limits (empty = unbounded)
  local_max: List[int] = []              # per-axis local size limits (empty = unbounded)
  extra_args: List[str] = []             # extra parameters appended to every kernel signature
  float4: Optional[str] = None           # vector constructor name; None means vector casts are unsupported
  half_prekernel: Optional[str] = None   # preamble required when float16 buffers are present
  uses_vload: bool = False               # use vload_half/vstore_half for float16 buffer access
  external_local_bufs: bool = False      # local buffers are declared outside the kernel body (prekernel)
  uses_ptr_arithmetic: bool = False      # render accesses as *(buf+idx) instead of buf[idx]
  # NOTE(review): mutable list/dict defaults on a NamedTuple are shared between
  # instances; they appear to be treated as read-only here.
  code_for_op: Dict = {
    UnaryOps.EXP2: lambda x: f"exp2({x})",
    UnaryOps.LOG2: lambda x: f"log2({x})",
    UnaryOps.SIN: lambda x: f"sin({x})",
    UnaryOps.SQRT: lambda x: f"sqrt({x})",
    BinaryOps.ADD: lambda a,b: f"({a}+{b})", BinaryOps.SUB: lambda a,b: f"({a}-{b})",
    BinaryOps.MUL: lambda a,b: f"({a}*{b})", BinaryOps.DIV: lambda a,b: f"({a}/{b})",
    BinaryOps.MAX: lambda a,b: f"max({a},{b})",
    BinaryOps.CMPEQ: lambda a,b: f"({a}=={b})", TernaryOps.MULACC: lambda a,b,c: f"(({a}*{b})+{c})",
    TernaryOps.WHERE: lambda a,b,c: f"({a}!=0?{b}:{c})"
  }

  def render_cast(self, x:List[str], var_dtype:DType) -> str:
    """Return a str expression casting the rendered scalars `x` to vector `var_dtype`.

    Requires `self.float4` (the vector constructor name) to be set; only
    4-wide and 2-wide float vectors are supported.
    """
    assert len(x) == var_dtype.sz, f"cast is wrong size {len(x)} != {var_dtype.sz}"
    assert self.float4 is not None, "cast is not supported on this platform"
    if var_dtype == dtypes._float4: return f"{self.float4}({','.join(x)})"
    # reuse the float4 constructor name with the width swapped (e.g. "(float2)")
    if var_dtype == dtypes._float2: return f"{self.float4.replace('float4', 'float2')}({','.join(x)})"
    raise NotImplementedError(f"no cast for {var_dtype}")

  def render_const(self, x:Union[float,int], var_dtype) -> str:
    """Return a str expression of the constant `x` with the given type.

    NaN/inf become the C macros NAN/INFINITY; floats get an "f" suffix.
    Vector dtypes are broadcast via render_cast.
    """
    if math.isnan(x): val = "NAN"
    elif math.isinf(x): val = ("-" if x < 0 else "") + "INFINITY"
    else: val = f"{x}" + ("f" if isinstance(x, float) else "")
    return self.render_cast([val]*var_dtype.sz, var_dtype) if var_dtype.sz > 1 else val

  def render_load(self, output_dtype, buf_name, buf_dtype, idx, local=False) -> str:
    """Return a str expression of the loaded value with the output type.

    `idx` is a symbolic node (or, for images, a pair of nodes) rendered with
    render_cl; `local` selects the shared-memory address-space prefix.
    """
    if isinstance(buf_dtype, ImageDType):
      assert output_dtype == dtypes._float4, "images must be float4"
      # image reads go through the sampler `smp` emitted by render_kernel
      return f"read_imagef({buf_name}, smp, (int2)({idx[0].render(render_cl)}, {idx[1].render(render_cl)}))"
    if self.uses_vload and buf_dtype == dtypes.float16:
      return f"vload_half{'' if output_dtype.sz == 1 else str(output_dtype.sz)}(0, {buf_name}+{idx.render(render_cl, strip_parens=True)})"
    if output_dtype.sz > 1:
      # vector load: reinterpret the pointer as an N-wide vector type
      return f"({output_dtype.name})(*(({self.smem_prefix if local else self.buffer_prefix}{buf_dtype.name}{output_dtype.sz}*)({buf_name}+{idx.render(render_cl, strip_parens=True)})))"
    return f"*({buf_name}+{idx.render(render_cl, strip_parens=True)})" if self.uses_ptr_arithmetic else f"{buf_name}[{idx.render(render_cl)}]"

  def render_local(self, name:str, size:int):
    """Return the declaration of a shared/local float buffer `name` of `size` elements."""
    return self.smem_prefix + f"float {name}[{size}];"

  def render_for(self, expr: str, _min:int, _max:int) -> str:
    """Return the opening of a for-loop over `expr` in [_min, _max] (inclusive)."""
    return f"for (int {expr} = {_min}; {expr} <= {_max}; ++{expr}) {{"

  def render_conditional(self, cond: str, x:str, y:str) -> str:
    """Return a ternary expression selecting `x` when `cond` is true, else `y`."""
    return f"({cond})?({x}):{y}"

  def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,DType]], global_size:List[int], local_size:List[int], prekernel:List[str]) -> Tuple[str,List[int],List[int]]:
    """Assemble the full kernel source string.

    Returns (source, global_size reversed, local_size reversed) — launch
    dimensions are reversed for the runtime's axis ordering.
    """
    # image buffers need an OpenCL sampler declared once at the top
    tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,dtype in bufs) else ""
    # buffer 0 is the output (write_only / non-const); all others are inputs
    buftypes = [(name,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if dtype.name.startswith('image') else
                 ("const " if i > 0 else "")+self.buffer_prefix+dtype.name+"*"+self.buffer_suffix) for i,(name,dtype) in enumerate(bufs)]
    prg = ''.join([f"{self.kernel_prefix} void {function_name}(",] +
      [', '.join([f'{t} {name}' for name,t in buftypes] + self.extra_args)] +
      [") {\n" + tmp] + ['\n'.join(kernel), "\n}"])
    # half-precision support may require a preamble (e.g. pragmas/typedefs)
    if self.half_prekernel and any(dtype == dtypes.float16 for _,dtype in bufs): prg = ''.join([f"{self.half_prekernel}", "\n", prg])
    return prg, global_size[::-1], local_size[::-1]

  def render_store(self, buf_name:str, buf_dtype:DType, var_name:str, var_dtype:DType, idx, local=False) -> str:
    """Return a str statement that stores `var_name` into the buffer at `idx`.

    Mirrors render_load: image writes, vstore_half for fp16, vector stores via
    pointer reinterpretation, and plain scalar stores.
    """
    if isinstance(buf_dtype, ImageDType):
      assert var_dtype == dtypes._float4, "images must be float4"
      return f"write_imagef({buf_name}, (int2)({idx[0].render(render_cl)}, {idx[1].render(render_cl)}), {var_name});"
    if self.uses_vload and buf_dtype == dtypes.float16:
      return f"vstore_half{'' if var_dtype.sz == 1 else str(var_dtype.sz)}({var_name}, 0, {buf_name}+{idx.render(render_cl, strip_parens=True)});"
    if var_dtype.sz > 1:
      return f"*(({self.smem_prefix if local else self.buffer_prefix}{buf_dtype.name}{var_dtype.sz}*)({buf_name}+{idx.render(render_cl, strip_parens=True)})) = ({buf_dtype.name}{var_dtype.sz}){var_name};"
    return f"*({buf_name}+{idx.render(render_cl, strip_parens=True)}) = {var_name};" if self.uses_ptr_arithmetic else f"{buf_name}[{idx.render(render_cl)}] = {var_name};"
|
2023-03-20 00:43:49 -06:00
|
|
|
|
2023-07-12 13:52:06 -06:00
|
|
|
def add_gl_dimension(prefix: str, args, i:int, var, local_size:List[int], xid:List[str]):
  """Render the opening of loop dimension `i` as a hardware thread-id binding.

  Side effect: appends/updates this dimension's size (var.max+1) in
  `local_size` (the caller passes either the global or local size list).
  Returns the "{ ... }" opening string that binds `var.expr` to an id
  expression from `xid`; a plain "{" when `var` is a constant (NumNode).
  """
  # for M1 tensor core stuff, support > 3 dims
  if i >= 2 and len(args[0]) > len(xid):
    # do this on the x dim for warps: fold all extra dimensions into the last
    # hardware id, then recover this dim's component with div/mod
    if len(local_size) == 2: local_size.append(1)
    local_size[-1] *= var.max+1
    # symbolic index ranging over the product of all folded dims (args[0][2:])
    lidx = Variable(xid[0], 0, prod(x.max+1 for x in args[0][2:])-1)
    lidx = (lidx//((lidx.max+1)//local_size[-1]))%(var.max+1)
    # the extracted component must have exactly this dimension's range
    assert lidx.max == var.max and lidx.min == var.min
    return "{" if isinstance(var, NumNode) else f"{{ {prefix} {var.expr} = {lidx.render(render_cl)}; /* {var.max+1} */"
  local_size.append(var.max+1)
  # normal path: dimensions map onto hardware ids in reverse order
  return "{" if isinstance(var, NumNode) else f"{{ {prefix} {var.expr} = {xid[min(len(xid), len(args[0]))-1-i]}; /* {var.max+1} */"
|
2023-07-09 10:06:00 -06:00
|
|
|
|
2023-08-05 12:07:04 -06:00
|
|
|
def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tuple[str, List[int], List[int]]:
  """Render a linearized uop list into C-style kernel source using `lang`.

  Returns (source, global_size, local_size) via lang.render_kernel (which
  reverses the size lists for launch ordering).
  """
  global_size: List[int] = []
  local_size: List[int] = []
  kernel,prekernel = [],[]   # body lines / lines emitted before the kernel
  pend_close = None          # deferred close for the local-loop "if (lid==0)" hack
  bufs = []                  # (name, dtype) pairs collected from DEFINE_GLOBAL
  depth = 0                  # current indentation depth of emitted lines
  def kk(s): kernel.append(" "*depth+s)  # emit one indented kernel line

  for uop,newvar,vin,args in uops:
    if uop == UOps.LOOP:
      # args[0] is the list of loop variables, args[1] the scope ("global"/"local"/...)
      for i,var in enumerate(args[0]):
        if args[1] == "global" and lang.gid:
          # bind to hardware global ids instead of a real loop
          kk(add_gl_dimension(lang.size_prefix, args, i, var, global_size, lang.gid))
        elif args[1] == "local" and lang.lid:
          kk(add_gl_dimension(lang.size_prefix, args, i, var, local_size, lang.lid))
        else:
          if getenv("NOUNROLL") and not isinstance(var, NumNode): kk("#pragma unroll(1)") # prevent loop unrolling
          # constant loop vars need only an opening brace, no for-loop
          kk("{" if isinstance(var, NumNode) else lang.render_for(var.expr, var.min, var.max))
      depth += 1
    elif uop == UOps.BARRIER:
      kk(lang.barrier)
    elif uop == UOps.ENDLOOP:
      if args[1] == "local" and len(lang.lid):
        # TODO: this is a bit of a hack. the local loop isn't real on the GPU
        # guard the rest of the kernel so only thread 0 runs it; close later
        kk(f"if ({Variable.sum(args[0]).render(render_cl)} == 0) {{")
        pend_close = "}"*(len(args[0])+1) + f" /* {args[1]} */"
      else:
        if args[1] == "global" and pend_close:
          # flush the deferred local-loop close before closing the global scope
          depth -= 1
          kk(pend_close)
          pend_close = None
        depth -= 1
        kk("}"*len(args[0]) + f" /* {args[1]} */")
    elif uop == UOps.WMMA:
      # Metal simdgroup matrix multiply-accumulate (M1 tensor cores)
      # ((lidx2*32)+(lidx3*4)+(lidx4*16)+(lidx5*8)+(lidx6*2))
      kk("{ simdgroup_float8x8 a,b,c;")
      kk(f"a.thread_elements()[0] = {vin[0].render()}; a.thread_elements()[1] = {vin[1].render()};")
      kk(f"b.thread_elements()[0] = {vin[2].render()}; b.thread_elements()[1] = {vin[3].render()};")
      kk(f"c.thread_elements()[0] = {vin[4].render()}; c.thread_elements()[1] = {vin[5].render()};")
      kk("simdgroup_multiply_accumulate(c, a, b, c);")
      kk(f"{vin[4].render()} = c.thread_elements()[0]; {vin[5].render()} = c.thread_elements()[1]; }}")
    elif uop == UOps.ALU:
      assert newvar is not None
      # in-place ALU (newvar in vin) reuses the variable without redeclaring it
      kk(f"{lang.generic_var_prefix if newvar not in vin else ''}{newvar.render(newvar not in vin and lang.generic_var_prefix == '')} = {lang.code_for_op[args](*[x.render() for x in vin])};")
    elif uop == UOps.LOAD:
      assert newvar is not None and isinstance(args, (MemOp, ConstOp))
      # valids are handled here
      if args.valid.max == 0:
        # never valid: load the invalid_value constant instead of reading memory
        val = lang.render_const(args.invalid_value, newvar.dtype)
      elif isinstance(args, ConstOp):
        val = lang.render_const(args.value, newvar.dtype)
      else:
        val = lang.render_load(newvar.dtype, args.name, args.memory_dtype, args.idx, args.local)
      # conditionally valid: select between the load and the invalid_value
      if args.valid.min == 0 and args.valid.max == 1: val = lang.render_conditional(args.valid.render(render_cl), val, lang.render_const(args.invalid_value, newvar.dtype))
      kk(f"{lang.generic_var_prefix}{newvar.render(lang.generic_var_prefix == '')} = {val};")
    elif uop == UOps.STORE:
      assert args.valid.min == 1 and isinstance(args, MemOp), "store must be valid and to memory"
      # TODO: instead of dtypes.float, a base type
      kk(lang.render_store(args.name, args.memory_dtype, vin[0].render(), vin[0].dtype if vin[0].offset is None else dtypes.float, args.idx, args.local))
    elif uop == UOps.CAST and newvar is not None and newvar.dtype.sz > 1:
      # vectorizing cast: pack the scalar inputs into a vector variable
      kk(f"{newvar.render(True)} = {lang.render_cast([x.render() for x in vin], newvar.dtype)};")
    elif uop == UOps.DEFINE_LOCAL:
      # shared-memory buffer declaration; some backends need it outside the body
      if lang.external_local_bufs:
        prekernel.append(lang.render_local(args[0], args[1]))
      else:
        kk(lang.render_local(args[0], args[1]))
    elif uop == UOps.DEFINE_GLOBAL:
      bufs.append(args)
    else:
      raise RuntimeError(f"failed to render {uop}")

  return lang.render_kernel(function_name, kernel, bufs, global_size, local_size, prekernel)
|