1
0
Fork 0
tinygrab/extra/assembly/assembly_ptx.py

217 lines
8.2 KiB
Python

from typing import List
import struct
from tinygrad.codegen.assembly import uops_to_asmstyle, AssemblyLanguage
from tinygrad.codegen.linearizer import UOps, UOp
from tinygrad.helpers import dtypes
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps
from tinygrad.runtime.ops_cuda import arch
# Lookup table from a tinygrad dtype to the PTX type suffix that is spliced
# into rendered instructions (the "f32" in "cvt.f32.s32") and into the
# ".reg" declarations emitted by specialize_to_ptx.
dtype_to_nvtype = {
    dtypes.float32: "f32",
    dtypes.float16: "f16",
    dtypes.int64: "s64",
    dtypes.int32: "s32",
    dtypes.int8: "s8",
    # Booleans are rendered with the "pred" type.
    dtypes.bool: "pred",
    dtypes.uint64: "u64",
    dtypes.uint32: "u32",
    dtypes.uint16: "u16",
    dtypes.uint8: "u8",
    # String key rather than a dtype: the STORE path in specialize_to_ptx
    # looks this up when the stored element type is float16.
    "bits16": "b16",
    dtypes.float64: "f64",
}
def float_to_hex(x):
    """Return the IEEE-754 single-precision bit pattern of *x* as hex text.

    The result is eight upper-case hexadecimal digits, most significant byte
    first, e.g. ``1.0 -> "3F800000"``. Callers prefix it with ``0f`` to form
    a PTX float literal.

    Raises:
        struct.error / OverflowError: if *x* cannot be packed as a 32-bit
            float (propagated from ``struct.pack``).
    """
    # Pack explicitly big-endian instead of packing in native order and
    # reversing: the old form only produced the right digits on
    # little-endian hosts.
    return struct.pack(">f", x).hex().upper()
def ptx_needs_cast(dest_dtype, src_dtype):
    """Tell whether a value of *src_dtype* must be converted for *dest_dtype*.

    A conversion is required when the two dtypes sit on opposite sides of the
    int/float divide, or when both are floats of different ``itemsize``.
    """
    # int -> float
    if dtypes.is_float(dest_dtype) and dtypes.is_int(src_dtype):
        return True
    # float -> int
    if dtypes.is_int(dest_dtype) and dtypes.is_float(src_dtype):
        return True
    # float -> float of another width
    return (
        dtypes.is_float(src_dtype)
        and dtypes.is_float(dest_dtype)
        and dest_dtype.itemsize != src_dtype.itemsize
    )
def render_cast(ins, inp, out):
    """Append to *ins* the PTX instruction that converts *inp* into *out*.

    Args:
        ins: list of PTX instruction strings; one entry is appended.
        inp: source register (its ``dtype`` selects the conversion).
        out: destination register (its ``dtype`` selects the conversion).

    Three cases are handled: bool -> number (``selp``), anything -> bool
    (``mov.pred`` / ``setp.ne``), and every other pair (``cvt``).
    """
    src_dtype = inp.dtype
    dst_dtype = out.dtype

    # bool -> int/float: pick between the two literals using the predicate.
    if src_dtype == dtypes.bool and (
        dtypes.is_float(dst_dtype) or dtypes.is_int(dst_dtype)
    ):
        if dtypes.is_float(dst_dtype):
            choices = "0f3F800000, 0f00000000"
        else:
            choices = "1, 0"
        ins.append(f"selp.{dtype_to_nvtype[dst_dtype]} {out}, {choices}, {inp};")
        return

    # anything -> bool: copy a predicate, or compare the value against zero.
    if dst_dtype == dtypes.bool:
        if src_dtype == dtypes.bool:
            ins.append(f"mov.pred {out}, {inp};")
        else:
            zero = "0f00000000" if dtypes.is_float(src_dtype) else "0"
            ins.append(f"setp.ne.{dtype_to_nvtype[src_dtype]} {out}, {zero}, {inp};")
        return

    # numeric -> numeric: choose the rounding modifier for cvt.
    if dtypes.is_int(dst_dtype) and dtypes.is_float(src_dtype):
        round_mod = ".rzi"
    elif dtypes.is_float(dst_dtype) and (
        dtypes.is_int(src_dtype)
        or (dtypes.is_float(src_dtype) and src_dtype.itemsize > dst_dtype.itemsize)
    ):
        round_mod = ".rz"
    else:
        round_mod = ""
    ins.append(
        f"cvt{round_mod}.{dtype_to_nvtype[dst_dtype]}.{dtype_to_nvtype[src_dtype]} {out}, {inp};"
    )
# PTX ISA reference: https://docs.nvidia.com/cuda/parallel-thread-execution/#
class PTXLanguage(AssemblyLanguage):
    """AssemblyLanguage subclass used when lowering uops to NVIDIA PTX."""

    # Class-level flag, enabled for PTX. Presumably consulted by the shared
    # assembly lowering (uops_to_asmstyle) to decide whether constants may be
    # folded -- TODO confirm against AssemblyLanguage.
    supports_constant_folding: bool = True
def specialize_to_ptx(lang, function_name):
    """Render the instruction stream held by *lang* as PTX source text.

    Args:
        lang: language object providing ``ins`` (an iterable of
            ``(uop, out, vin, arg)`` tuples), ``cnts`` (mapping whose keys
            start with a dtype and whose values are register counts),
            ``newreg(...)`` and ``type_to_letter(...)``.
        function_name: name used for the ``.visible .entry`` declaration.

    Returns:
        The complete PTX module as one newline-joined string: header,
        ``.reg`` declarations, the translated instructions, ``ret;`` and the
        closing brace.

    Raises:
        KeyError: if an ALU op or dtype has no entry in the lookup tables.
    """
    param_cnt = 0
    ins = []
    # PTX mnemonic for every ALU op this backend knows how to render.
    alu = {
        BinaryOps.ADD: "add",
        BinaryOps.SUB: "sub",
        BinaryOps.MUL: "mul",
        BinaryOps.DIV: "div",
        BinaryOps.MAX: "max",
        BinaryOps.MOD: "rem",
        BinaryOps.CMPLT: "setp.lt",
        UnaryOps.SQRT: "sqrt.approx",
        UnaryOps.NOOP: "mov",
        UnaryOps.NEG: "neg",
        UnaryOps.SIN: "sin.approx",
        UnaryOps.LOG2: "lg2.approx",
        UnaryOps.EXP2: "ex2.approx.ftz",
        TernaryOps.MULACC: "fma.rn",
        TernaryOps.WHERE: "selp",
    }
    for uop, out, vin, arg in lang.ins:
        if uop == UOps.ENDLOOP:
            ins.append("bar.sync 0;")
        elif uop == UOps.DEFINE_LOCAL:
            # arg is (name, count); the byte array is sized at 4 bytes per count.
            ins.append(f".shared .align 4 .b8 {arg[0]}[{arg[1]*4}];")
        elif uop == UOps.SPECIAL:
            if arg.startswith("data"):
                # Every "data*" special becomes one .u64 kernel parameter.
                param_cnt += 1
                ins.append(f"ld.param.u64 {out}, [{arg}];")
                # TODO: we sometimes want this to be local, nvcc converts to global most of the time, not sure when we would need to?
                # ins.append(f"cvta.to.global.u64 {out}, {out};")
            elif arg.startswith("gid"):
                # The digit after the 3-letter prefix selects the x/y/z axis.
                ins.append(f"mov.u32 {out}, %ctaid.{'xyz'[int(arg[3:])]};")
            elif arg.startswith("lid"):
                ins.append(f"mov.u32 {out}, %tid.{'xyz'[int(arg[3:])]};")
        elif uop == UOps.ALU:
            if arg == BinaryOps.MUL and out.dtype == dtypes.bool:
                # Multiplying predicates is a logical AND.
                ins.append(f"and.pred {out}, {', '.join(str(x) for x in vin)};")
            else:
                # Comparisons are typed by their operands, not by the bool result.
                otype = vin[0].dtype if arg in [BinaryOps.CMPLT] else out.dtype
                if arg == TernaryOps.WHERE:
                    if vin[0].dtype == dtypes.bool:
                        reg = vin[0]
                    else:
                        # selp needs a predicate: compare the condition with zero.
                        reg = lang.newreg((vin[0], "bool"), dtypes.bool)
                        zero = "0f00000000" if dtypes.is_float(vin[0].dtype) else "0"
                        ins.append(
                            f"setp.ne.{dtype_to_nvtype[vin[0].dtype]} {reg}, {zero}, {vin[0]};"
                        )
                    # selp takes the two values first and the predicate last.
                    vin = vin[1:] + [reg]
                # ".lo" is only defined for integer multiplies; emitting it for
                # f16/f64 (as the old "!= float32" test did) is invalid PTX.
                lo_mod = (
                    ".lo"
                    if arg == BinaryOps.MUL and not dtypes.is_float(out.dtype)
                    else ""
                )
                rn_mod = (
                    ".rn"
                    if arg == BinaryOps.DIV and out.dtype == dtypes.float32
                    else ""
                )
                ins.append(
                    f"{alu[arg]}{lo_mod}{rn_mod}.{dtype_to_nvtype[otype]} {out}, {', '.join(str(x) for x in vin)};"
                )
        elif uop == UOps.LOAD:
            if arg.__class__ in (int, float):
                # Immediate constant: floats are written as an exact 0f hex literal.
                value = (
                    "0f" + float_to_hex(arg)
                    if dtypes.is_float(out.dtype)
                    else int(arg)
                )
                ins.append(f"mov.{dtype_to_nvtype[out.dtype]} {out}, {value};")
            else:
                # arg is (offset, state space, memory dtype); offset may be None.
                offset = f"+{arg[0]}" if arg[0] is not None else ""
                addr = f"[{vin[0]}{offset}]"
                if arg[2] is not None and (
                    arg[2] == dtypes.bool or arg[2] != out.dtype
                ):
                    # Load into a temporary of the memory type, then convert.
                    if arg[2] == dtypes.bool == out.dtype:
                        dt = ("u16", dtypes.uint16)
                    elif arg[2] == dtypes.bool:
                        dt = ("u8", dtypes.uint8)
                    elif arg[2] == dtypes.half:
                        dt = ("b16", dtypes.float16)
                    else:
                        dt = (dtype_to_nvtype[arg[2]], arg[2])
                    reg = lang.newreg((out, dt[0]), dtype=dt[1])
                    ins.append(f"ld.{arg[1]}.{dt[0]} {reg}, {addr};")
                    render_cast(ins, reg, out)
                else:
                    mem_dtype = dtypes.float if arg[2] is None else arg[2]
                    ins.append(
                        f"ld.{arg[1]}.{dtype_to_nvtype[mem_dtype]} {out}, {addr};"
                    )
        elif uop == UOps.STORE:
            # arg is (offset, state space, memory dtype); a None dtype means float.
            offset = f"+{arg[0]}" if arg[0] is not None else ""
            addr = f"[{vin[0]}{offset}]"
            mem_dtype = dtypes.float if arg[2] is None else arg[2]
            if ptx_needs_cast(mem_dtype, vin[1].dtype) or arg[2] == dtypes.bool:
                if arg[2] == dtypes.bool != vin[1].dtype:
                    # Non-bool value stored as bool: reduce it to a predicate first.
                    prereg = lang.newreg((vin[1], "bool"), dtype=dtypes.bool)
                    render_cast(ins, vin[1], prereg)
                else:
                    prereg = vin[1]
                reg = lang.newreg(
                    (prereg, dtypes.uint16 if arg[2] == dtypes.bool else arg[2]),
                    dtype=dtypes.uint16 if arg[2] == dtypes.bool else mem_dtype,
                )
                render_cast(ins, prereg, reg)
                # float16 is stored as raw 16 bits, bool as a single byte.
                if arg[2] == dtypes.float16:
                    st_key = "bits16"
                elif arg[2] == dtypes.bool:
                    st_key = dtypes.uint8
                else:
                    st_key = mem_dtype
                ins.append(f"st.{arg[1]}.{dtype_to_nvtype[st_key]} {addr}, {reg};")
            else:
                ins.append(
                    f"st.{arg[1]}.{dtype_to_nvtype[mem_dtype]} {addr}, {vin[1]};"
                )
        elif uop == UOps.CAST:
            render_cast(ins, vin[0], out)
        elif uop == UOps.LABEL:
            ins.append(f"{arg}:")
        elif uop == UOps.COND_BRANCH:
            # arg is (label, branch-when-true); "!" negates the guard predicate.
            ins.append(f"@{'!' if not arg[1] else ''}{vin[0]} bra {arg[0]};")
    ins_prefix = [
        ".version 7.8",
        ".target " + arch(),
        ".address_size 64",
        f".visible .entry {function_name}({', '.join(f'.param .u64 data{i}' for i in range(param_cnt))}) {{",
    ]
    # One ".reg" declaration per register class that was allocated.
    for reg_key, count in lang.cnts.items():
        ins_prefix.append(
            f".reg .{dtype_to_nvtype[reg_key[0]]} %{lang.type_to_letter(reg_key)}<{count}>;"
        )
    return "\n".join(ins_prefix + ins + ["ret;", "}"])
def uops_to_ptx_asm(function_name: str, uops: List[UOp]):
    """Lower *uops* to PTX source for a kernel named *function_name*.

    Returns a 4-tuple: the PTX text, the global size reversed, the local size
    reversed, and the constant ``True``.
    """
    ptx_lang = PTXLanguage()
    # First pass: shared lowering fills ptx_lang and reports the launch sizes.
    global_size, local_size = uops_to_asmstyle(ptx_lang, function_name, uops)
    # Second pass: turn the collected instructions into PTX text.
    ptx_source = specialize_to_ptx(ptx_lang, function_name)
    reversed_global = global_size[::-1]
    reversed_local = local_size[::-1]
    return ptx_source, reversed_global, reversed_local, True