dtypes.float.vec(sz) (#2386)
* replace all _dtypen with dtype.vec(n) fix: print works * conceptual refactor of cstyle render_load logic * linearizer GEP is explicit that its dtype is the scalar version of localtype * vectorized global_store and load don't need a conditional pull/2396/head
parent
cbb8486779
commit
0eda545946
|
@ -7,7 +7,7 @@ import functools
|
|||
import math
|
||||
from collections import defaultdict
|
||||
|
||||
_type_to_letter = {dtypes.float32: 'f', dtypes.bool: 'p', dtypes.int32: 'i', dtypes.int64: 'a', dtypes.uint32: 'u', dtypes.uint64: 'b', dtypes._float4: 'x', dtypes.uint8: 'uc', dtypes.float16: 'h',
|
||||
_type_to_letter = {dtypes.float32: 'f', dtypes.bool: 'p', dtypes.int32: 'i', dtypes.int64: 'a', dtypes.uint32: 'u', dtypes.uint64: 'b', dtypes.float.vec(4): 'x', dtypes.uint8: 'uc', dtypes.float16: 'h',
|
||||
dtypes.int8: 'c', dtypes.uint16: 'us', dtypes.float64: 'd'}
|
||||
|
||||
class Register(NamedTuple):
|
||||
|
@ -17,7 +17,7 @@ class Register(NamedTuple):
|
|||
off:Optional[int] = None
|
||||
def __repr__(self): return self.nm if self.off is None else f"{self.nm}:{self.off}"
|
||||
def subregs(self):
|
||||
if self.dtype == dtypes._float4:
|
||||
if self.dtype == dtypes.float.vec(4):
|
||||
return [Register(self.nm, dtypes.float, False, off=off) for off in range(4)]
|
||||
return []
|
||||
|
||||
|
@ -40,7 +40,7 @@ class AssemblyLanguage:
|
|||
def type_to_letter(self,x): return _type_to_letter[x[0]].upper() if x[1] else _type_to_letter[x[0]]
|
||||
def newreg(self, tok, dtype=dtypes.float32, scalar=False) -> Register:
|
||||
self.tor[tok] = ret = Register(f"%{self.type_to_letter((dtype, scalar))}{self.cnts[(dtype, scalar)]}", dtype, scalar)
|
||||
if dtype == dtypes._float4:
|
||||
if dtype == dtypes.float.vec(4):
|
||||
for off in range(4):
|
||||
self.tor[tok] = Register(ret.nm, dtypes.float, ret.scalar, off)
|
||||
self.cnts[(dtype, scalar)] += 1
|
||||
|
|
|
@ -62,7 +62,7 @@ class RDNACodegen(AssemblyCodegen):
|
|||
return rtor[x]
|
||||
for uop, out, vin, arg in asm:
|
||||
if uop == UOps.DEFINE_REGISTER:
|
||||
if arg[0][0] in [dtypes.uint32, dtypes.uint64, dtypes.int64, dtypes.int32, dtypes.float32, dtypes._float4]:
|
||||
if arg[0][0] in [dtypes.uint32, dtypes.uint64, dtypes.int64, dtypes.int32, dtypes.float32, dtypes.float.vec(4)]:
|
||||
for i in range(arg[2]):
|
||||
# TODO: Re-use gaps created by this to avoid wasting registers
|
||||
align = int(arg[0][0].itemsize / 4)
|
||||
|
@ -76,7 +76,7 @@ class RDNACodegen(AssemblyCodegen):
|
|||
v_cnt += align
|
||||
rtor[Register(f"%{arg[1]}{i}", *arg[0])] = reg_name
|
||||
|
||||
if arg[0][0] == dtypes._float4:
|
||||
if arg[0][0] == dtypes.float.vec(4):
|
||||
for off in range(4):
|
||||
reg_name = f"s{s_cnt-align+off}" if arg[0][1] else f"v{v_cnt-align+off}"
|
||||
rtor[Register(f"%{arg[1]}{i}", dtypes.float, False, off=off)] = reg_name
|
||||
|
@ -109,7 +109,7 @@ class RDNACodegen(AssemblyCodegen):
|
|||
elif uop == UOps.CONST:
|
||||
if arg == float('inf'): arg = "0x7f800000"
|
||||
elif arg == float('-inf'): arg = "0xff800000"
|
||||
if out.dtype == dtypes._float4:
|
||||
if out.dtype == dtypes.float.vec(4):
|
||||
for off in range(4):
|
||||
ins.append(f"{'s_' if out.scalar else 'v_'}mov_b32 {reg_out(Register(out.nm, dtypes.float, False, off=off))}, {arg}")
|
||||
else:
|
||||
|
@ -122,8 +122,8 @@ class RDNACodegen(AssemblyCodegen):
|
|||
if arg == TernaryOps.MULACC and out == vin[2]:
|
||||
alu_arg = "fmac"
|
||||
vin = vin[0:2]
|
||||
if out.dtype == dtypes._float4:
|
||||
for rr in zip(*[x.subregs() if x.dtype == dtypes._float4 else [x,x,x,x] for x in [out]+vin]):
|
||||
if out.dtype == dtypes.float.vec(4):
|
||||
for rr in zip(*[x.subregs() if x.dtype == dtypes.float.vec(4) else [x,x,x,x] for x in [out]+vin]):
|
||||
ins.append(f"{'s_' if rr[0].scalar else 'v_'}{alu_arg}_{dtype_to_rdnatype[rr[0].dtype]} {reg_out(rr[0])}, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in rr[1:])}")
|
||||
else:
|
||||
ins.append(f"{'s_' if out.scalar else 'v_'}{alu_arg}_{dtype_to_rdnatype[out.dtype] if arg != UnaryOps.NOOP else 'b32'}{'_i24' if arg == BinaryOps.MUL and out.dtype != dtypes.float32 and not out.scalar else ''} {reg_out(out)}, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}")
|
||||
|
@ -132,11 +132,11 @@ class RDNACodegen(AssemblyCodegen):
|
|||
# swap arg order
|
||||
ins.append(f's_load_b32 {reg_out(out)}, {reg_in(vin[0])}, {reg_in(vin[1])} offset:{arg[0]}')
|
||||
else:
|
||||
ins.append(f'global_load_{"b128" if out.dtype == dtypes._float4 else "b32"} {reg_out(out)}, {reg_in(vin[1])}, {reg_in(vin[0])} offset:{arg[0]}')
|
||||
ins.append(f'global_load_{"b128" if out.dtype == dtypes.float.vec(4) else "b32"} {reg_out(out)}, {reg_in(vin[1])}, {reg_in(vin[0])} offset:{arg[0]}')
|
||||
pend_regs.add(out)
|
||||
for r in out.subregs(): pend_regs.add(r)
|
||||
elif uop == UOps.STORE:
|
||||
ins.append(f'global_store_{"b128" if vin[1].dtype == dtypes._float4 else "b32"} {reg_in(vin[2])}, {reg_in(vin[1])}, {reg_in(vin[0])} offset:{arg[0]}')
|
||||
ins.append(f'global_store_{"b128" if vin[1].dtype == dtypes.float.vec(4) else "b32"} {reg_in(vin[2])}, {reg_in(vin[1])}, {reg_in(vin[0])} offset:{arg[0]}')
|
||||
elif uop == UOps.LABEL:
|
||||
ins.append(f"{arg}:")
|
||||
elif uop == UOps.COND_BRANCH:
|
||||
|
|
|
@ -124,8 +124,8 @@ class TestFloat4(unittest.TestCase):
|
|||
|
||||
@staticmethod
|
||||
def count_float4(k):
|
||||
return (len([uop for uop in k.uops if uop.uop == UOps.LOAD and uop.dtype == dtypes._float4]),
|
||||
len([uop for uop in k.uops if uop.uop == UOps.STORE and len(uop.vin) == 3 and uop.vin[2].dtype == dtypes._float4]))
|
||||
return (len([uop for uop in k.uops if uop.uop == UOps.LOAD and uop.dtype == dtypes.float.vec(4)]),
|
||||
len([uop for uop in k.uops if uop.uop == UOps.STORE and len(uop.vin) == 3 and uop.vin[2].dtype == dtypes.float.vec(4)]))
|
||||
|
||||
# TODO: express opts below as auto opts
|
||||
|
||||
|
|
|
@ -75,7 +75,7 @@ class Linearizer(Kernel):
|
|||
(g_idx, g_valid), amt, dim = self.sts[i].expr_idxs(fake_idxs), 1, None
|
||||
else:
|
||||
g_idx, g_valid = self.sts[i].expr_idxs(fake_idxs)
|
||||
localtype = dtypes.float32 if amt == 1 else dtypes._float4 if amt == 4 else dtypes._float2
|
||||
localtype = dtypes.float32 if amt == 1 else dtypes.float.vec(amt)
|
||||
|
||||
e_idxs, e_valids = g_idx.expand(expand_vars), g_valid.expand(expand_vars)
|
||||
|
||||
|
@ -98,7 +98,7 @@ class Linearizer(Kernel):
|
|||
assert buf_uop is not None, f"buffer {i} wasn't UOped"
|
||||
if isinstance(buf.dtype, ImageDType):
|
||||
idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
|
||||
rendered_idx = self.uop(UOps.CAST, dtypes._int2, (idx[0].render(self.render_ops, self), idx[1].render(self.render_ops, self)))
|
||||
rendered_idx = self.uop(UOps.CAST, dtypes.int.vec(2), (idx[0].render(self.render_ops, self), idx[1].render(self.render_ops, self)))
|
||||
else:
|
||||
rendered_idx = idx.render(self.render_ops, self)
|
||||
|
||||
|
@ -107,7 +107,7 @@ class Linearizer(Kernel):
|
|||
self.load_cache[key] = self.uop(UOps.LOAD, localtype, (buf_uop, rendered_idx, valid_rendered, self.const(invalid_value, localtype)) + ((barrier,) if barrier else ()))
|
||||
else:
|
||||
self.load_cache[key] = self.uop(UOps.LOAD, localtype, (buf_uop, rendered_idx) + ((barrier,) if barrier else ()))
|
||||
ret.append(self.uop(UOps.GEP, dtypes.float32, (self.load_cache[key],), rep_idx[dim]) if dim is not None else self.load_cache[key])
|
||||
ret.append(self.uop(UOps.GEP, localtype.scalar(), (self.load_cache[key],), rep_idx[dim]) if dim is not None else self.load_cache[key])
|
||||
return ret
|
||||
|
||||
def global_store(self, i:int, idxs:List[Node], store:List[UOp]) -> List[UOp]:
|
||||
|
@ -132,7 +132,7 @@ class Linearizer(Kernel):
|
|||
idx, valid = self.sts[i].expr_idxs(k)
|
||||
assert idx.render() == ((idx//amt)*amt).render(), "float4 stores are always aligned"
|
||||
assert valid.min == 1, "stores are always valid"
|
||||
store_offset_new[k] = self.uop(UOps.CAST, dtypes._float4 if amt == 4 else dtypes._float2, tuple(out_tokens))
|
||||
store_offset_new[k] = self.uop(UOps.CAST, dtypes.float.vec(amt), tuple(out_tokens))
|
||||
store_offset = store_offset_new
|
||||
|
||||
stores = []
|
||||
|
@ -140,7 +140,7 @@ class Linearizer(Kernel):
|
|||
idx, valid = self.sts[i].expr_idxs(idx)
|
||||
if isinstance(buf.dtype, ImageDType):
|
||||
idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
|
||||
rendered_idx = self.uop(UOps.CAST, dtypes._int2, tuple(x.render(self.render_ops, self) for x in idx))
|
||||
rendered_idx = self.uop(UOps.CAST, dtypes.int.vec(2), tuple(x.render(self.render_ops, self) for x in idx))
|
||||
else:
|
||||
rendered_idx = idx.render(self.render_ops, self)
|
||||
stores.append(self.uop(UOps.STORE, None, (buf_uop, rendered_idx, var)))
|
||||
|
@ -290,10 +290,10 @@ class Linearizer(Kernel):
|
|||
if self.opts.device != "HIP":
|
||||
ops = tuple(op1+op2+op3)
|
||||
else:
|
||||
ops = (self.uop(UOps.CAST, dtypes._half16, tuple(op1)),
|
||||
self.uop(UOps.CAST, dtypes._half16, tuple(op2)),
|
||||
self.uop(UOps.CAST, dtypes._float8, tuple(op3)))
|
||||
ret = self.uop(UOps.WMMA, dtypes._float2 if wmma_sz[2] == 2 else dtypes._float8, ops, (self.opts.device, self.tensor_core.dtype_in, self.tensor_core.dtype_out,))
|
||||
ops = (self.uop(UOps.CAST, dtypes.half.vec(16), tuple(op1)),
|
||||
self.uop(UOps.CAST, dtypes.half.vec(16), tuple(op2)),
|
||||
self.uop(UOps.CAST, dtypes.float.vec(8), tuple(op3)))
|
||||
ret = self.uop(UOps.WMMA, dtypes.float.vec(2) if wmma_sz[2] == 2 else dtypes.float.vec(8), ops, (self.opts.device, self.tensor_core.dtype_in, self.tensor_core.dtype_out,))
|
||||
for z in range(cast(DType, ret.dtype).sz):
|
||||
acc[i+z] = self.uop(UOps.PHI, dtypes.float, (op3[z], self.uop(UOps.GEP, dtypes.float, (ret,), z)) + loop_ctx)
|
||||
i += wmma_sz[2]
|
||||
|
|
|
@ -94,7 +94,11 @@ class DType(NamedTuple):
|
|||
name: str
|
||||
np: Optional[type] # TODO: someday this will be removed with the "remove numpy" project
|
||||
sz: int = 1
|
||||
def __repr__(self): return f"dtypes.{INVERSE_DTYPES_DICT[self]}"
|
||||
def __repr__(self): return f"dtypes.{INVERSE_DTYPES_DICT[self]}" if self.sz == 1 else f"dtypes._{INVERSE_DTYPES_DICT[self.scalar()]}{self.sz}"
|
||||
def vec(self, sz:int):
|
||||
assert sz > 1 and self.sz == 1, f"can't vectorize {self} with size {sz}"
|
||||
return DType(self.priority, self.itemsize*sz, self.name+str(sz), None, sz)
|
||||
def scalar(self): return DTYPES_DICT[self.name[:-1]] if self.sz > 1 else self
|
||||
|
||||
# dependent typing?
|
||||
class ImageDType(DType):
|
||||
|
@ -117,7 +121,7 @@ class dtypes:
|
|||
@staticmethod # static methods on top, or bool in the type info will refer to dtypes.bool
|
||||
def is_int(x: DType)-> bool: return x in (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
|
||||
@staticmethod
|
||||
def is_float(x: DType) -> bool: return x in (dtypes.float16, dtypes.float32, dtypes.float64, dtypes._half4, dtypes._float2, dtypes._float4)
|
||||
def is_float(x: DType) -> bool: return x in (dtypes.float16, dtypes.float32, dtypes.float64, dtypes.half.vec(4), dtypes.float.vec(2), dtypes.float.vec(4))
|
||||
@staticmethod
|
||||
def is_unsigned(x: DType) -> bool: return x in (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
|
||||
@staticmethod
|
||||
|
@ -134,6 +138,7 @@ class dtypes:
|
|||
int8: Final[DType] = DType(1, 1, "char", np.int8)
|
||||
int16: Final[DType] = DType(3, 2, "short", np.int16)
|
||||
int32: Final[DType] = DType(5, 4, "int", np.int32)
|
||||
int = int32
|
||||
int64: Final[DType] = DType(7, 8, "long", np.int64)
|
||||
uint8: Final[DType] = DType(2, 1, "unsigned char", np.uint8)
|
||||
uint16: Final[DType] = DType(4, 2, "unsigned short", np.uint16)
|
||||
|
@ -144,12 +149,6 @@ class dtypes:
|
|||
bfloat16: Final[DType] = DType(9, 2, "__bf16", None)
|
||||
|
||||
# NOTE: these are internal dtypes, should probably check for that
|
||||
_int2: Final[DType] = DType(2, 4*2, "int2", None, 2)
|
||||
_half4: Final[DType] = DType(0, 2*4, "half4", None, 4)
|
||||
_half16: Final[DType] = DType(0, 2*16, "half16", None, 16)
|
||||
_float2: Final[DType] = DType(4, 4*2, "float2", None, 2)
|
||||
_float4: Final[DType] = DType(4, 4*4, "float4", None, 4)
|
||||
_float8: Final[DType] = DType(4, 4*8, "float8", None, 8)
|
||||
_arg_int32: Final[DType] = DType(2, 4, "_arg_int32", None)
|
||||
|
||||
# NOTE: these are image dtypes
|
||||
|
|
|
@ -46,12 +46,7 @@ class CStyleLanguage(NamedTuple):
|
|||
if len(x) == 1: return f"({var_dtype.name})({x[0]})"
|
||||
assert len(x) == var_dtype.sz, f"cast is wrong size {len(x)} != {var_dtype.sz}"
|
||||
assert self.float4 is not None, "cast is not supported on this platform"
|
||||
if var_dtype == dtypes._half16: return f"{{{','.join(f'(half){x}' for x in x)}}}"
|
||||
if var_dtype == dtypes._float8: return f"{{{','.join(x)}}}"
|
||||
if var_dtype == dtypes._float4: return f"{self.float4}({','.join(x)})"
|
||||
if var_dtype == dtypes._float2: return f"{self.float4.replace('float4', 'float2')}({','.join(x)})"
|
||||
if var_dtype == dtypes._int2: return f"{self.float4.replace('float4', 'int2')}({','.join(x)})"
|
||||
raise NotImplementedError(f"no cast for {var_dtype}")
|
||||
return f"{self.float4.replace('float4', var_dtype.name)}({','.join(f'(half){x}' if var_dtype.scalar() == dtypes.half else x for x in x)})"
|
||||
|
||||
# returns a str expression of the const with the given type
|
||||
def render_const(self, x:Union[float,int], var_dtype) -> str:
|
||||
|
@ -63,7 +58,7 @@ class CStyleLanguage(NamedTuple):
|
|||
# returns a str expression of the loaded value with the output type
|
||||
def render_load(self, output_dtype, buf_name, buf_dtype, idx, local=False) -> str:
|
||||
if isinstance(buf_dtype, ImageDType):
|
||||
assert output_dtype == dtypes._float4, f"images must be float4, getting {output_dtype}"
|
||||
assert output_dtype == dtypes.float.vec(4), f"images must be float4, getting {output_dtype}"
|
||||
return f"read_imagef({buf_name}, smp, {idx})"
|
||||
if self.uses_vload and buf_dtype == dtypes.float16:
|
||||
return f"vload_half{'' if output_dtype.sz == 1 else str(output_dtype.sz)}(0, {buf_name}+{idx})"
|
||||
|
@ -100,7 +95,7 @@ class CStyleLanguage(NamedTuple):
|
|||
# returns a str statement that does the store
|
||||
def render_store(self, buf_name:str, buf_dtype:DType, var_name:str, var_dtype:DType, idx:str, local=False) -> str:
|
||||
if isinstance(buf_dtype, ImageDType):
|
||||
assert var_dtype == dtypes._float4, "images must be float4"
|
||||
assert var_dtype == dtypes.float.vec(4), "images must be float4"
|
||||
return f"write_imagef({buf_name}, {idx}, {var_name});"
|
||||
if self.uses_vload and buf_dtype == dtypes.float16 and var_dtype != dtypes.float16:
|
||||
return f"vstore_half{'' if var_dtype.sz == 1 else str(var_dtype.sz)}({var_name}, 0, {buf_name}+{idx});"
|
||||
|
@ -143,7 +138,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tu
|
|||
kk("}")
|
||||
elif uop == UOps.WMMA:
|
||||
if args[0] == "METAL":
|
||||
assert dtype == dtypes._float2, "output dtype of METAL TC is _float2"
|
||||
assert dtype == dtypes.float.vec(2), "output dtype of METAL TC is _float2"
|
||||
# ((lidx2*32)+(lidx3*4)+(lidx4*16)+(lidx5*8)+(lidx6*2))
|
||||
output = ssa(u, 'wmma')
|
||||
kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {output};")
|
||||
|
@ -154,7 +149,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tu
|
|||
kk("simdgroup_multiply_accumulate(c, a, b, c);")
|
||||
kk(f"{output}.x = c.thread_elements()[0]; {output}.y = c.thread_elements()[1]; }}")
|
||||
elif args[0] == "HIP":
|
||||
assert dtype == dtypes._float8, "output dtype of HIP TC is _float8"
|
||||
assert dtype == dtypes.float.vec(8), "output dtype of HIP TC is _float8"
|
||||
kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u, 'wmma')} = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32({r[vin[0]]}, {r[vin[1]]}, {r[vin[2]]});")
|
||||
else:
|
||||
raise NotImplementedError(f"WMMA not implemented for {args}")
|
||||
|
|
Loading…
Reference in New Issue