1
0
Fork 0

dtypes.float.vec(sz) (#2386)

* replace all _dtypen with dtype.vec(n)

fix: print works

* conceptul refactor of cstyle render_load logic

* linearizer GEP is explicit that its dtype is the scalar version of localtype

* vectorized global_store and load don't need a conditional
pull/2396/head
qazal 2023-11-22 20:43:14 -05:00 committed by GitHub
parent cbb8486779
commit 0eda545946
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 33 additions and 39 deletions

View File

@ -7,7 +7,7 @@ import functools
import math
from collections import defaultdict
_type_to_letter = {dtypes.float32: 'f', dtypes.bool: 'p', dtypes.int32: 'i', dtypes.int64: 'a', dtypes.uint32: 'u', dtypes.uint64: 'b', dtypes._float4: 'x', dtypes.uint8: 'uc', dtypes.float16: 'h',
_type_to_letter = {dtypes.float32: 'f', dtypes.bool: 'p', dtypes.int32: 'i', dtypes.int64: 'a', dtypes.uint32: 'u', dtypes.uint64: 'b', dtypes.float.vec(4): 'x', dtypes.uint8: 'uc', dtypes.float16: 'h',
dtypes.int8: 'c', dtypes.uint16: 'us', dtypes.float64: 'd'}
class Register(NamedTuple):
@ -17,7 +17,7 @@ class Register(NamedTuple):
off:Optional[int] = None
def __repr__(self): return self.nm if self.off is None else f"{self.nm}:{self.off}"
def subregs(self):
if self.dtype == dtypes._float4:
if self.dtype == dtypes.float.vec(4):
return [Register(self.nm, dtypes.float, False, off=off) for off in range(4)]
return []
@ -40,7 +40,7 @@ class AssemblyLanguage:
def type_to_letter(self,x): return _type_to_letter[x[0]].upper() if x[1] else _type_to_letter[x[0]]
def newreg(self, tok, dtype=dtypes.float32, scalar=False) -> Register:
self.tor[tok] = ret = Register(f"%{self.type_to_letter((dtype, scalar))}{self.cnts[(dtype, scalar)]}", dtype, scalar)
if dtype == dtypes._float4:
if dtype == dtypes.float.vec(4):
for off in range(4):
self.tor[tok] = Register(ret.nm, dtypes.float, ret.scalar, off)
self.cnts[(dtype, scalar)] += 1

View File

@ -62,7 +62,7 @@ class RDNACodegen(AssemblyCodegen):
return rtor[x]
for uop, out, vin, arg in asm:
if uop == UOps.DEFINE_REGISTER:
if arg[0][0] in [dtypes.uint32, dtypes.uint64, dtypes.int64, dtypes.int32, dtypes.float32, dtypes._float4]:
if arg[0][0] in [dtypes.uint32, dtypes.uint64, dtypes.int64, dtypes.int32, dtypes.float32, dtypes.float.vec(4)]:
for i in range(arg[2]):
# TODO: Re-use gaps created by this to avoid wasting registers
align = int(arg[0][0].itemsize / 4)
@ -76,7 +76,7 @@ class RDNACodegen(AssemblyCodegen):
v_cnt += align
rtor[Register(f"%{arg[1]}{i}", *arg[0])] = reg_name
if arg[0][0] == dtypes._float4:
if arg[0][0] == dtypes.float.vec(4):
for off in range(4):
reg_name = f"s{s_cnt-align+off}" if arg[0][1] else f"v{v_cnt-align+off}"
rtor[Register(f"%{arg[1]}{i}", dtypes.float, False, off=off)] = reg_name
@ -109,7 +109,7 @@ class RDNACodegen(AssemblyCodegen):
elif uop == UOps.CONST:
if arg == float('inf'): arg = "0x7f800000"
elif arg == float('-inf'): arg = "0xff800000"
if out.dtype == dtypes._float4:
if out.dtype == dtypes.float.vec(4):
for off in range(4):
ins.append(f"{'s_' if out.scalar else 'v_'}mov_b32 {reg_out(Register(out.nm, dtypes.float, False, off=off))}, {arg}")
else:
@ -122,8 +122,8 @@ class RDNACodegen(AssemblyCodegen):
if arg == TernaryOps.MULACC and out == vin[2]:
alu_arg = "fmac"
vin = vin[0:2]
if out.dtype == dtypes._float4:
for rr in zip(*[x.subregs() if x.dtype == dtypes._float4 else [x,x,x,x] for x in [out]+vin]):
if out.dtype == dtypes.float.vec(4):
for rr in zip(*[x.subregs() if x.dtype == dtypes.float.vec(4) else [x,x,x,x] for x in [out]+vin]):
ins.append(f"{'s_' if rr[0].scalar else 'v_'}{alu_arg}_{dtype_to_rdnatype[rr[0].dtype]} {reg_out(rr[0])}, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in rr[1:])}")
else:
ins.append(f"{'s_' if out.scalar else 'v_'}{alu_arg}_{dtype_to_rdnatype[out.dtype] if arg != UnaryOps.NOOP else 'b32'}{'_i24' if arg == BinaryOps.MUL and out.dtype != dtypes.float32 and not out.scalar else ''} {reg_out(out)}, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}")
@ -132,11 +132,11 @@ class RDNACodegen(AssemblyCodegen):
# swap arg order
ins.append(f's_load_b32 {reg_out(out)}, {reg_in(vin[0])}, {reg_in(vin[1])} offset:{arg[0]}')
else:
ins.append(f'global_load_{"b128" if out.dtype == dtypes._float4 else "b32"} {reg_out(out)}, {reg_in(vin[1])}, {reg_in(vin[0])} offset:{arg[0]}')
ins.append(f'global_load_{"b128" if out.dtype == dtypes.float.vec(4) else "b32"} {reg_out(out)}, {reg_in(vin[1])}, {reg_in(vin[0])} offset:{arg[0]}')
pend_regs.add(out)
for r in out.subregs(): pend_regs.add(r)
elif uop == UOps.STORE:
ins.append(f'global_store_{"b128" if vin[1].dtype == dtypes._float4 else "b32"} {reg_in(vin[2])}, {reg_in(vin[1])}, {reg_in(vin[0])} offset:{arg[0]}')
ins.append(f'global_store_{"b128" if vin[1].dtype == dtypes.float.vec(4) else "b32"} {reg_in(vin[2])}, {reg_in(vin[1])}, {reg_in(vin[0])} offset:{arg[0]}')
elif uop == UOps.LABEL:
ins.append(f"{arg}:")
elif uop == UOps.COND_BRANCH:

View File

@ -124,8 +124,8 @@ class TestFloat4(unittest.TestCase):
@staticmethod
def count_float4(k):
return (len([uop for uop in k.uops if uop.uop == UOps.LOAD and uop.dtype == dtypes._float4]),
len([uop for uop in k.uops if uop.uop == UOps.STORE and len(uop.vin) == 3 and uop.vin[2].dtype == dtypes._float4]))
return (len([uop for uop in k.uops if uop.uop == UOps.LOAD and uop.dtype == dtypes.float.vec(4)]),
len([uop for uop in k.uops if uop.uop == UOps.STORE and len(uop.vin) == 3 and uop.vin[2].dtype == dtypes.float.vec(4)]))
# TODO: express opts below as auto opts

View File

@ -75,7 +75,7 @@ class Linearizer(Kernel):
(g_idx, g_valid), amt, dim = self.sts[i].expr_idxs(fake_idxs), 1, None
else:
g_idx, g_valid = self.sts[i].expr_idxs(fake_idxs)
localtype = dtypes.float32 if amt == 1 else dtypes._float4 if amt == 4 else dtypes._float2
localtype = dtypes.float32 if amt == 1 else dtypes.float.vec(amt)
e_idxs, e_valids = g_idx.expand(expand_vars), g_valid.expand(expand_vars)
@ -98,7 +98,7 @@ class Linearizer(Kernel):
assert buf_uop is not None, f"buffer {i} wasn't UOped"
if isinstance(buf.dtype, ImageDType):
idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
rendered_idx = self.uop(UOps.CAST, dtypes._int2, (idx[0].render(self.render_ops, self), idx[1].render(self.render_ops, self)))
rendered_idx = self.uop(UOps.CAST, dtypes.int.vec(2), (idx[0].render(self.render_ops, self), idx[1].render(self.render_ops, self)))
else:
rendered_idx = idx.render(self.render_ops, self)
@ -107,7 +107,7 @@ class Linearizer(Kernel):
self.load_cache[key] = self.uop(UOps.LOAD, localtype, (buf_uop, rendered_idx, valid_rendered, self.const(invalid_value, localtype)) + ((barrier,) if barrier else ()))
else:
self.load_cache[key] = self.uop(UOps.LOAD, localtype, (buf_uop, rendered_idx) + ((barrier,) if barrier else ()))
ret.append(self.uop(UOps.GEP, dtypes.float32, (self.load_cache[key],), rep_idx[dim]) if dim is not None else self.load_cache[key])
ret.append(self.uop(UOps.GEP, localtype.scalar(), (self.load_cache[key],), rep_idx[dim]) if dim is not None else self.load_cache[key])
return ret
def global_store(self, i:int, idxs:List[Node], store:List[UOp]) -> List[UOp]:
@ -132,7 +132,7 @@ class Linearizer(Kernel):
idx, valid = self.sts[i].expr_idxs(k)
assert idx.render() == ((idx//amt)*amt).render(), "float4 stores are always aligned"
assert valid.min == 1, "stores are always valid"
store_offset_new[k] = self.uop(UOps.CAST, dtypes._float4 if amt == 4 else dtypes._float2, tuple(out_tokens))
store_offset_new[k] = self.uop(UOps.CAST, dtypes.float.vec(amt), tuple(out_tokens))
store_offset = store_offset_new
stores = []
@ -140,7 +140,7 @@ class Linearizer(Kernel):
idx, valid = self.sts[i].expr_idxs(idx)
if isinstance(buf.dtype, ImageDType):
idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
rendered_idx = self.uop(UOps.CAST, dtypes._int2, tuple(x.render(self.render_ops, self) for x in idx))
rendered_idx = self.uop(UOps.CAST, dtypes.int.vec(2), tuple(x.render(self.render_ops, self) for x in idx))
else:
rendered_idx = idx.render(self.render_ops, self)
stores.append(self.uop(UOps.STORE, None, (buf_uop, rendered_idx, var)))
@ -290,10 +290,10 @@ class Linearizer(Kernel):
if self.opts.device != "HIP":
ops = tuple(op1+op2+op3)
else:
ops = (self.uop(UOps.CAST, dtypes._half16, tuple(op1)),
self.uop(UOps.CAST, dtypes._half16, tuple(op2)),
self.uop(UOps.CAST, dtypes._float8, tuple(op3)))
ret = self.uop(UOps.WMMA, dtypes._float2 if wmma_sz[2] == 2 else dtypes._float8, ops, (self.opts.device, self.tensor_core.dtype_in, self.tensor_core.dtype_out,))
ops = (self.uop(UOps.CAST, dtypes.half.vec(16), tuple(op1)),
self.uop(UOps.CAST, dtypes.half.vec(16), tuple(op2)),
self.uop(UOps.CAST, dtypes.float.vec(8), tuple(op3)))
ret = self.uop(UOps.WMMA, dtypes.float.vec(2) if wmma_sz[2] == 2 else dtypes.float.vec(8), ops, (self.opts.device, self.tensor_core.dtype_in, self.tensor_core.dtype_out,))
for z in range(cast(DType, ret.dtype).sz):
acc[i+z] = self.uop(UOps.PHI, dtypes.float, (op3[z], self.uop(UOps.GEP, dtypes.float, (ret,), z)) + loop_ctx)
i += wmma_sz[2]

View File

@ -94,7 +94,11 @@ class DType(NamedTuple):
name: str
np: Optional[type] # TODO: someday this will be removed with the "remove numpy" project
sz: int = 1
def __repr__(self): return f"dtypes.{INVERSE_DTYPES_DICT[self]}"
def __repr__(self): return f"dtypes.{INVERSE_DTYPES_DICT[self]}" if self.sz == 1 else f"dtypes._{INVERSE_DTYPES_DICT[self.scalar()]}{self.sz}"
def vec(self, sz:int):
assert sz > 1 and self.sz == 1, f"can't vectorize {self} with size {sz}"
return DType(self.priority, self.itemsize*sz, self.name+str(sz), None, sz)
def scalar(self): return DTYPES_DICT[self.name[:-1]] if self.sz > 1 else self
# dependent typing?
class ImageDType(DType):
@ -117,7 +121,7 @@ class dtypes:
@staticmethod # static methds on top, or bool in the type info will refer to dtypes.bool
def is_int(x: DType)-> bool: return x in (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
@staticmethod
def is_float(x: DType) -> bool: return x in (dtypes.float16, dtypes.float32, dtypes.float64, dtypes._half4, dtypes._float2, dtypes._float4)
def is_float(x: DType) -> bool: return x in (dtypes.float16, dtypes.float32, dtypes.float64, dtypes.half.vec(4), dtypes.float.vec(2), dtypes.float.vec(4))
@staticmethod
def is_unsigned(x: DType) -> bool: return x in (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
@staticmethod
@ -134,6 +138,7 @@ class dtypes:
int8: Final[DType] = DType(1, 1, "char", np.int8)
int16: Final[DType] = DType(3, 2, "short", np.int16)
int32: Final[DType] = DType(5, 4, "int", np.int32)
int = int32
int64: Final[DType] = DType(7, 8, "long", np.int64)
uint8: Final[DType] = DType(2, 1, "unsigned char", np.uint8)
uint16: Final[DType] = DType(4, 2, "unsigned short", np.uint16)
@ -144,12 +149,6 @@ class dtypes:
bfloat16: Final[DType] = DType(9, 2, "__bf16", None)
# NOTE: these are internal dtypes, should probably check for that
_int2: Final[DType] = DType(2, 4*2, "int2", None, 2)
_half4: Final[DType] = DType(0, 2*4, "half4", None, 4)
_half16: Final[DType] = DType(0, 2*16, "half16", None, 16)
_float2: Final[DType] = DType(4, 4*2, "float2", None, 2)
_float4: Final[DType] = DType(4, 4*4, "float4", None, 4)
_float8: Final[DType] = DType(4, 4*8, "float8", None, 8)
_arg_int32: Final[DType] = DType(2, 4, "_arg_int32", None)
# NOTE: these are image dtypes

View File

@ -46,12 +46,7 @@ class CStyleLanguage(NamedTuple):
if len(x) == 1: return f"({var_dtype.name})({x[0]})"
assert len(x) == var_dtype.sz, f"cast is wrong size {len(x)} != {var_dtype.sz}"
assert self.float4 is not None, "cast is not supported on this platform"
if var_dtype == dtypes._half16: return f"{{{','.join(f'(half){x}' for x in x)}}}"
if var_dtype == dtypes._float8: return f"{{{','.join(x)}}}"
if var_dtype == dtypes._float4: return f"{self.float4}({','.join(x)})"
if var_dtype == dtypes._float2: return f"{self.float4.replace('float4', 'float2')}({','.join(x)})"
if var_dtype == dtypes._int2: return f"{self.float4.replace('float4', 'int2')}({','.join(x)})"
raise NotImplementedError(f"no cast for {var_dtype}")
return f"{self.float4.replace('float4', var_dtype.name)}({','.join(f'(half){x}' if var_dtype.scalar() == dtypes.half else x for x in x)})"
# returns a str expression of the const with the given type
def render_const(self, x:Union[float,int], var_dtype) -> str:
@ -63,7 +58,7 @@ class CStyleLanguage(NamedTuple):
# returns a str expression of the loaded value with the output type
def render_load(self, output_dtype, buf_name, buf_dtype, idx, local=False) -> str:
if isinstance(buf_dtype, ImageDType):
assert output_dtype == dtypes._float4, f"images must be float4, getting {output_dtype}"
assert output_dtype == dtypes.float.vec(4), f"images must be float4, getting {output_dtype}"
return f"read_imagef({buf_name}, smp, {idx})"
if self.uses_vload and buf_dtype == dtypes.float16:
return f"vload_half{'' if output_dtype.sz == 1 else str(output_dtype.sz)}(0, {buf_name}+{idx})"
@ -100,7 +95,7 @@ class CStyleLanguage(NamedTuple):
# returns a str statement that does the store
def render_store(self, buf_name:str, buf_dtype:DType, var_name:str, var_dtype:DType, idx:str, local=False) -> str:
if isinstance(buf_dtype, ImageDType):
assert var_dtype == dtypes._float4, "images must be float4"
assert var_dtype == dtypes.float.vec(4), "images must be float4"
return f"write_imagef({buf_name}, {idx}, {var_name});"
if self.uses_vload and buf_dtype == dtypes.float16 and var_dtype != dtypes.float16:
return f"vstore_half{'' if var_dtype.sz == 1 else str(var_dtype.sz)}({var_name}, 0, {buf_name}+{idx});"
@ -143,7 +138,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tu
kk("}")
elif uop == UOps.WMMA:
if args[0] == "METAL":
assert dtype == dtypes._float2, "output dtype of METAL TC is _float2"
assert dtype == dtypes.float.vec(2), "output dtype of METAL TC is _float2"
# ((lidx2*32)+(lidx3*4)+(lidx4*16)+(lidx5*8)+(lidx6*2))
output = ssa(u, 'wmma')
kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {output};")
@ -154,7 +149,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tu
kk("simdgroup_multiply_accumulate(c, a, b, c);")
kk(f"{output}.x = c.thread_elements()[0]; {output}.y = c.thread_elements()[1]; }}")
elif args[0] == "HIP":
assert dtype == dtypes._float8, "output dtype of HIP TC is _float8"
assert dtype == dtypes.float.vec(8), "output dtype of HIP TC is _float8"
kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u, 'wmma')} = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32({r[vin[0]]}, {r[vin[1]]}, {r[vin[2]]});")
else:
raise NotImplementedError(f"WMMA not implemented for {args}")