good changes from the M1 Tensor Core project (#730)
* good changes * working except llvm * llvm types * nice acc * archprobe * lang.float4 * use self.acc for late acc * fix store bugpull/733/head
parent
156640e90d
commit
20894991ed
|
@ -0,0 +1,63 @@
|
|||
# copying the kernels from https://github.com/microsoft/ArchProbe into Python
|
||||
import numpy as np
|
||||
from tinygrad.runtime.ops_gpu import CLProgram, CLBuffer
|
||||
from tinygrad.helpers import dtypes
|
||||
from tqdm import trange, tqdm
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
def reg_count(nthread, ngrp, nreg):
|
||||
reg_declr = ''.join([f"float reg_data{i} = (float)niter + {i};\n" for i in range(nreg)])
|
||||
reg_comp = ''.join([f"reg_data{i} *= {(i-1)%nreg};\n" for i in range(nreg)])
|
||||
reg_reduce = ''.join([f"out_buf[{i}] = reg_data{i};\n" for i in range(nreg)])
|
||||
prg = f"""__kernel void reg_count(
|
||||
__global float* out_buf,
|
||||
__private const int niter
|
||||
) {{
|
||||
{reg_declr}
|
||||
int i = 0;
|
||||
for (; i < niter; ++i) {{
|
||||
{reg_comp}
|
||||
}}
|
||||
i = i >> 31;
|
||||
{reg_reduce}
|
||||
}}"""
|
||||
out_buf = CLBuffer(1, dtypes.float32)
|
||||
cl = CLProgram("reg_count", prg, argdtypes=[None, np.int32])
|
||||
return min([cl([nthread, ngrp, 1], [nthread, 1, 1], out_buf, 10, wait=True) for _ in range(10)])
|
||||
|
||||
"""
|
||||
print("probing registers")
|
||||
pts = [(nreg, reg_count(1, 1, nreg)) for nreg in trange(1, 257)] # archprobe goes to 512
|
||||
plt.plot(*zip(*pts))
|
||||
plt.show()
|
||||
"""
|
||||
|
||||
def buf_cache_hierarchy_pchase(ndata, stride=1):
|
||||
NCOMP = 16 # 64 byte is under the 128 byte cache line
|
||||
print("probe", ndata*NCOMP*4)
|
||||
prg = """__kernel void buf_cache_hierarchy_pchase(
|
||||
__global int16* src,
|
||||
__global int* dst,
|
||||
const int niter
|
||||
) {
|
||||
int idx = 0;
|
||||
for (int i = 0; i < niter; ++i) {
|
||||
idx = src[idx].x;
|
||||
}
|
||||
*dst = idx;
|
||||
}"""
|
||||
idx_buf = np.zeros(ndata*NCOMP, dtype=np.int32)
|
||||
for i in range(ndata):
|
||||
idx_buf[i*NCOMP] = (i + stride) % ndata
|
||||
in_buf = CLBuffer.fromCPU(idx_buf)
|
||||
out_buf = CLBuffer(1, dtypes.int32)
|
||||
cl = CLProgram("buf_cache_hierarchy_pchase", prg, argdtypes=[None, None, np.int32])
|
||||
return min([cl([1, 1, 1], [1, 1, 1], in_buf, out_buf, ndata*4, wait=True) for _ in range(5)])
|
||||
|
||||
# 768 kb is real
|
||||
print("probing cache size")
|
||||
base = buf_cache_hierarchy_pchase(1, 191)
|
||||
szs = list(range(128, 1024, 128)) + list(range(1024, 16*1024, 1024)) + list(range(16*1024, int(1.5*1024*1024), 16*1024)) #+ list(range(2*1024*1024, 20*1024*1024, 1024*1024))
|
||||
pts = [(ndata, (buf_cache_hierarchy_pchase(ndata//64, 136329190282766681843115968953)-base)/ndata) for ndata in tqdm(szs)]
|
||||
plt.plot(*zip(*pts))
|
||||
plt.show()
|
|
@ -0,0 +1,26 @@
|
|||
import unittest
|
||||
from tinygrad.helpers import prod
|
||||
from tinygrad.lazy import Device
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.ops import GlobalCounters
|
||||
|
||||
class TestCopy(unittest.TestCase):
|
||||
def test_add1(self):
|
||||
pts = []
|
||||
for i in range(16384, 16384*256, 16384):
|
||||
t = Tensor.randn(i).realize()
|
||||
GlobalCounters.cache = []
|
||||
t.assign(t+1).realize()
|
||||
fxn, args = GlobalCounters.cache[0]
|
||||
GlobalCounters.reset()
|
||||
def run(): return fxn(args, force_wait=True)
|
||||
ct = min([run() for _ in range(10)])
|
||||
mb = prod(t.shape)*t.dtype.itemsize*2*1e-6
|
||||
print(f"{mb*1e3:.2f} kB, {ct*1e3:.2f} ms, {mb/ct:.2f} MB/s")
|
||||
pts.append((mb, mb/ct))
|
||||
from matplotlib import pyplot as plt
|
||||
plt.plot([x[0] for x in pts], [x[1] for x in pts])
|
||||
plt.show()
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
|
@ -44,26 +44,31 @@ def helper_test_speed(f1, *args):
|
|||
global save_ops, save_mem
|
||||
ets = []
|
||||
ret = None
|
||||
for _ in range(CNT):
|
||||
cache_defeat = np.zeros((2048,2048))
|
||||
for i in range(CNT):
|
||||
del ret
|
||||
args = [(x+1).realize() if isinstance(x, Tensor) else (None if x is None else (x+1)) for x in args] # cache defeats
|
||||
|
||||
# operation cache defeats
|
||||
args = [(x+1).realize() if isinstance(x, Tensor) else (None if x is None else (x+1)) for x in args]
|
||||
|
||||
# force syncing
|
||||
[x.numpy() if isinstance(x, Tensor) or str(torch_device) == "cpu" else x.cpu().numpy() for x in args if x is not None]
|
||||
|
||||
# clear 32MB global memory cache (CPU and global memory only)
|
||||
cache_defeat += 1
|
||||
|
||||
# manual pre sync
|
||||
if isinstance(args[0], Tensor): Device[args[0].device].synchronize()
|
||||
else: sync()
|
||||
|
||||
GlobalCounters.global_ops = 0
|
||||
GlobalCounters.global_mem = 0
|
||||
if DEBUG >= 4: print("benchmark start")
|
||||
st = time.monotonic()
|
||||
st = time.perf_counter()
|
||||
ret = f1(*args)
|
||||
if isinstance(ret, Tensor):
|
||||
ret.realize()
|
||||
Device[ret.device].synchronize()
|
||||
else:
|
||||
sync()
|
||||
et = (time.monotonic() - st) * 1000
|
||||
ets.append(et)
|
||||
if DEBUG >= 4: print("benchmark stop")
|
||||
if isinstance(ret, Tensor): Device[ret.device].synchronize()
|
||||
else: sync()
|
||||
et = (time.perf_counter() - st) * 1000
|
||||
if i >= 1: ets.append(et) # not the first run / one used for OPTLOCAL
|
||||
if GlobalCounters.global_ops:
|
||||
save_ops, save_mem = GlobalCounters.global_ops, GlobalCounters.global_mem
|
||||
return ret.cpu().numpy(), np.min(ets)
|
||||
|
@ -173,6 +178,10 @@ class TestSpeed(unittest.TestCase):
|
|||
def f(a, b): return a @ b
|
||||
helper_test_generic_square('gemm', 1024, f, f)
|
||||
|
||||
def test_gemm_small(self):
|
||||
def f(a, b): return a @ b
|
||||
helper_test_generic_square('gemm', 256, f, f)
|
||||
|
||||
def test_gemm_unrolled(self):
|
||||
N = 512
|
||||
def f1(a, b): return a@b.T
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
from typing import Final, Dict, Callable, ClassVar, List, Optional, NamedTuple, DefaultDict, Tuple, Set, Union
|
||||
import math, collections
|
||||
from tinygrad.codegen.linearizer import Linearizer, UOps, UOp, LocalBuffer
|
||||
from tinygrad.codegen.linearizer import Linearizer, UOps, UOp, LocalBuffer, LocalTypes
|
||||
from tinygrad.ops import ASTRunner, Op, UnaryOps, BinaryOps, FusedOps
|
||||
from tinygrad.helpers import getenv, all_same, partition, ImageDType, DEBUG, dtypes, colored
|
||||
from tinygrad.helpers import getenv, partition, ImageDType, DEBUG, dtypes, colored
|
||||
from tinygrad.runtime.lib import RawConst
|
||||
from tinygrad.shape.symbolic import DivNode, AndNode, render_python, NumNode, Variable, Node, SumNode, MulNode
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
|
@ -57,10 +57,6 @@ code_for_op: Final[Dict[Op, Callable]] = {
|
|||
}
|
||||
|
||||
def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lang:CStyleLanguage) -> Tuple[str, List[int], List[int]]:
|
||||
def group_float4(grp:List[str]) -> str:
|
||||
if all(g.endswith(e) for g,e in zip(grp, [".x", ".y", ".z", ".w"])) and all_same([g.split(".")[0] for g in grp]): return grp[0].split(".")[0]
|
||||
else: return f"{lang.float4}({','.join(g for g in grp)})"
|
||||
|
||||
prekernel: Set[str] = set()
|
||||
kernel = []
|
||||
global_size = []
|
||||
|
@ -103,7 +99,7 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
|
|||
else:
|
||||
kk(f"for (int {var.expr} = {var.min}; {var.expr} <= {var.max}; ++{var.expr}) {{")
|
||||
depth += 1
|
||||
if uop == UOps.ENDLOOP:
|
||||
elif uop == UOps.ENDLOOP:
|
||||
if args[1] == "local" and len(lang.lid):
|
||||
# TODO: this is a bit of a hack. the local loop isn't real on the GPU
|
||||
kk(lang.barrier)
|
||||
|
@ -116,18 +112,19 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
|
|||
pend_close = None
|
||||
depth -= 1
|
||||
kk("}"*len(args[0]) + f" /* {args[1]} */")
|
||||
if uop == UOps.CONST:
|
||||
elif uop == UOps.CONST:
|
||||
assert newvar is not None
|
||||
if args == -math.inf:
|
||||
kk(f"float {newvar} = -INFINITY;")
|
||||
kk(f"{newvar.render(True)} = -INFINITY;")
|
||||
else:
|
||||
kk(f"float {newvar} = {args}f;")
|
||||
if uop == UOps.ALU:
|
||||
kk(f"{newvar.render(True)} = {args}f;")
|
||||
elif uop == UOps.ALU:
|
||||
assert newvar is not None
|
||||
if newvar in vin:
|
||||
kk(f"{newvar} = {code_for_op[args](*vin)};")
|
||||
kk(f"{newvar.render()} = {code_for_op[args](*[x.render() for x in vin])};")
|
||||
else:
|
||||
kk(f"float {newvar} = {code_for_op[args](*vin)};")
|
||||
# TODO: refactor the next 14 lines
|
||||
if uop == UOps.LOAD:
|
||||
kk(f"{newvar.render(True)} = {code_for_op[args](*[x.render() for x in vin])};")
|
||||
elif uop == UOps.LOAD and newvar is not None and newvar.ltype == LocalTypes.float:
|
||||
# TODO: merge with CONST?
|
||||
if bufs[args.i] is not None and isinstance(bufs[args.i].realized, RawConst):
|
||||
# nan? inf?
|
||||
|
@ -138,9 +135,10 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
|
|||
else:
|
||||
val = f"{bufnames[args.i]}[{args.idx.render(render_cl)}]"
|
||||
# NOTE: if min and max are both 0, it should be a CONST in the Linearizer
|
||||
if args.valid.min == 1: kk(f"float {newvar} = {val};")
|
||||
else: kk(f"float {newvar} = ({args.valid.render(render_cl)}) ? ({val}) : 0.0f;")
|
||||
if uop == UOps.LOAD4:
|
||||
if args.valid.min == 1: kk(f"float {newvar.name} = {val};")
|
||||
else: kk(f"float {newvar.name} = ({args.valid.render(render_cl)}) ? ({val}) : 0.0f;")
|
||||
elif uop == UOps.LOAD and newvar is not None and newvar.ltype == LocalTypes.float4:
|
||||
assert newvar.offset is None, "load can't have an offset"
|
||||
if isinstance(bufs[args.i].dtype, ImageDType):
|
||||
prekernel.add("const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n")
|
||||
idx, idy = to_image_idx(bufs[args.i].dtype.shape, args.idx, args.valid)
|
||||
|
@ -148,23 +146,27 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
|
|||
else:
|
||||
val = f"(({lang.smem_prefix if isinstance(bufs[args.i], LocalBuffer) else lang.buffer_prefix}float4*){bufnames[args.i]})[{(args.idx//4).render(render_cl)}]"
|
||||
# NOTE: if min and max are both 0, it should be a CONST in the Linearizer
|
||||
if args[2].min == 1: kk(f"float4 {newvar} = {val};")
|
||||
else: kk(f"float4 {newvar} = ({args.valid.render(render_cl)}) ? ({val}) : {group_float4(['0.0f']*4)};")
|
||||
if uop == UOps.STORE:
|
||||
if args[2].min == 1: kk(f"{newvar.render(True)} = {val};")
|
||||
else: kk(f"{newvar.render(True)} = ({args.valid.render(render_cl)}) ? ({val}) : {lang.float4}(0.0f, 0.0f, 0.0f, 0.0f);")
|
||||
elif uop == UOps.STORE and (vin[0].ltype == LocalTypes.float or (vin[0].ltype == LocalTypes.float4 and vin[0].offset is not None)):
|
||||
assert args.valid.min == 1, "store must be valid"
|
||||
if lang.uses_vload and bufs[args.i].dtype == dtypes.float16:
|
||||
kk(f"vstore_half({vin[0]}, {args.idx.render(render_cl)}, {bufnames[args.i]});")
|
||||
kk(f"vstore_half({vin[0].render()}, {args.idx.render(render_cl)}, {bufnames[args.i]});")
|
||||
else:
|
||||
kk(f"{bufnames[args.i]}[{args.idx.render(render_cl)}] = {vin[0]};")
|
||||
if uop == UOps.STORE4:
|
||||
kk(f"{bufnames[args.i]}[{args.idx.render(render_cl)}] = {vin[0].render()};")
|
||||
elif uop == UOps.CAST and newvar is not None and newvar.ltype == LocalTypes.float4:
|
||||
kk(f"{newvar.render(True)} = {lang.float4}({','.join([x.render() for x in vin])});")
|
||||
elif uop == UOps.STORE and len(vin) != 0 and vin[0].ltype == LocalTypes.float4 and vin[0].offset is None:
|
||||
assert args.valid.min == 1, "store must be valid"
|
||||
if isinstance(bufs[args[0]].dtype, ImageDType):
|
||||
idx, idy = to_image_idx(bufs[args.i].dtype.shape, args[1], args[2])
|
||||
kk(f"write_imagef({bufnames[args.i]}, (int2)({idx.render(render_cl)}, {idy.render(render_cl)}), {group_float4(vin)});")
|
||||
kk(f"write_imagef({bufnames[args.i]}, (int2)({idx.render(render_cl)}, {idy.render(render_cl)}), {vin[0].render()});")
|
||||
else:
|
||||
kk(f"(({lang.smem_prefix if isinstance(bufs[args.i], LocalBuffer) else lang.buffer_prefix}float4*){bufnames[args.i]})[{(args.idx//4).render(render_cl)}] = {group_float4(vin)};")
|
||||
if uop == UOps.DEFINE_LOCAL:
|
||||
kk(f"(({lang.smem_prefix if isinstance(bufs[args.i], LocalBuffer) else lang.buffer_prefix}float4*){bufnames[args.i]})[{(args.idx//4).render(render_cl)}] = {vin[0].render()};")
|
||||
elif uop == UOps.DEFINE_LOCAL:
|
||||
kk(lang.smem_prefix + f"float {args[0]}[{args[1]}];")
|
||||
else:
|
||||
raise RuntimeError(f"failed to render {uop}")
|
||||
|
||||
buftypes = [(i,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if x.dtype.name.startswith('image') else
|
||||
("const " if i > 0 else "")+lang.buffer_prefix+x.dtype.name+"*"+lang.buffer_suffix) for i,x in enumerate(bufs)
|
||||
|
|
|
@ -1,21 +1,36 @@
|
|||
from typing import List, Tuple, Any, Optional, cast, Dict, DefaultDict, NamedTuple
|
||||
from typing import List, Tuple, Any, Optional, cast, Dict, DefaultDict, NamedTuple, TypeVar
|
||||
import itertools, math
|
||||
from collections import defaultdict
|
||||
from enum import Enum, auto
|
||||
|
||||
from tinygrad.helpers import dedup, colored, ImageDType, DEBUG, prod, dtypes, mnum, DType
|
||||
from tinygrad.helpers import dedup, colored, ImageDType, DEBUG, prod, dtypes, mnum, DType, all_same
|
||||
from tinygrad.ops import LazyOp, get_lazyops, get_buffers, FlopCounter, get_lazyop_info, map_buffers, UnaryOps
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
from tinygrad.ops import MovementOps, ReduceOps, BinaryOps, FusedOps
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, strides_for_shape
|
||||
from tinygrad.shape.symbolic import Variable, SumNode, ModNode
|
||||
|
||||
class UOps(Enum): LOOP = auto(); DEFINE_LOCAL = auto(); LOAD = auto(); ALU = auto(); CONST = auto(); ENDLOOP = auto(); STORE = auto(); LOAD4 = auto(); STORE4 = auto() # noqa: E702
|
||||
class UOps(Enum): LOOP = auto(); DEFINE_LOCAL = auto(); LOAD = auto(); ALU = auto(); CONST = auto(); ENDLOOP = auto(); STORE = auto(); CAST = auto() # noqa: E702
|
||||
|
||||
class LocalBuffer(NamedTuple):
|
||||
dtype: DType = dtypes.float32
|
||||
realized: None = None
|
||||
|
||||
class LocalTypes(Enum): float = auto(); float4 = auto(); half = auto(); half4 = auto(); simdgroup_float8x8 = auto() # noqa: E702
|
||||
|
||||
class Token(NamedTuple):
|
||||
name: str
|
||||
ltype: LocalTypes
|
||||
offset: Optional[int] = None
|
||||
def render(self, with_type=False):
|
||||
if with_type:
|
||||
assert self.offset is None
|
||||
return f"{self.ltype.name} {self.name}"
|
||||
if self.offset is None: return self.name
|
||||
assert self.ltype == LocalTypes.float4
|
||||
return self.name+"."+"xyzw"[int(self.offset)]
|
||||
def __repr__(self): return f"<{self.name}>" if self.offset is None and self.ltype == LocalTypes.float else f"<{self.name}:{self.ltype.name}:{self.offset}>"
|
||||
|
||||
class MemOp(NamedTuple):
|
||||
i: int
|
||||
idx: Variable
|
||||
|
@ -23,10 +38,10 @@ class MemOp(NamedTuple):
|
|||
|
||||
class UOp(NamedTuple):
|
||||
uop: UOps
|
||||
out: Optional[str]
|
||||
vin: List[str]
|
||||
out: Optional[Token]
|
||||
vin: List[Token]
|
||||
arg: Any
|
||||
def __repr__(self): return f"{str(self.uop):20s}: {self.out if self.out is not None else '':10s} {str(self.vin):32s} {self.arg}"
|
||||
def __repr__(self): return f"{str(self.uop):20s}: {str(self.out) if self.out is not None else '':25s} {str(self.vin):32s} {self.arg}"
|
||||
|
||||
def check_no_mul(test, var):
|
||||
if test == var: return True
|
||||
|
@ -106,32 +121,58 @@ class Linearizer:
|
|||
idxy_test, valid_test = self.sts[i].expr_idxs(float4_index+offset, idxs)
|
||||
# float4_index must not be in after divide or in valid. NOTE: this forces it to always be aligned too, maybe not required?
|
||||
ret = check_no_mul(idxy_test, float4_index) and "FLOAT4_INDEX" not in (idxy_test//4).render() and "FLOAT4_INDEX" not in (valid_test//4).render()
|
||||
if DEBUG >= 4: print(f"fuse buf {i} {ret} :", check_no_mul(idxy_test, float4_index), idxy_test, idxy_test//4, valid_test//4)
|
||||
if DEBUG >= 5: print(f"fuse buf {i} {ret} :", check_no_mul(idxy_test, float4_index), idxy_test, idxy_test//4, valid_test//4)
|
||||
return ret
|
||||
|
||||
def global_buf(self, i, idxs:List[Variable], store=None):
|
||||
should_upcast = self.supports_float4 and self.can_float4(i) and self.bufs[i].dtype != dtypes.float16
|
||||
cache: Dict[int, str] = {}
|
||||
store_offset: Dict[int, int] = {y:x for x,y in enumerate(self.offsets(i))} # NOTE: for stores, these should be unique
|
||||
# TODO: this is very similar to load
|
||||
def acc(self, ssa, i, idxs:List[Variable], name='acc') -> List[Token]:
|
||||
should_upcast = self.supports_float4 and self.can_float4(i)
|
||||
cache: Dict[int, Token] = {}
|
||||
def op(offset):
|
||||
if offset in cache: return cache[offset]
|
||||
will_merge = should_upcast and self.can_merge_float4(i, idxs, offset)
|
||||
assert self.reduceop is not None
|
||||
reg = self.uop(UOps.CONST, ssa(name, LocalTypes.float4 if will_merge else LocalTypes.float), [], {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)])
|
||||
if will_merge:
|
||||
for j in range(0, 4): cache[offset+j] = Token(reg.name, LocalTypes.float4, j)
|
||||
else:
|
||||
cache[offset] = reg
|
||||
return cache[offset]
|
||||
return [op(o) for o in self.offsets(i)]
|
||||
|
||||
def global_load(self, i, idxs:List[Variable]) -> List[Token]:
|
||||
should_upcast = self.supports_float4 and self.can_float4(i) and self.bufs[i].dtype != dtypes.float16
|
||||
cache: Dict[int, Token] = {}
|
||||
def op(offset):
|
||||
if offset in cache: return cache[offset]
|
||||
will_merge = should_upcast and self.can_merge_float4(i, idxs, offset)
|
||||
assert will_merge or not isinstance(self.bufs[i].dtype, ImageDType), "image must merge float4"
|
||||
reg = self.uop(UOps.LOAD, Token(f"val{mnum(i)}_{mnum(offset)}", LocalTypes.float4 if will_merge else LocalTypes.float), [], MemOp(i, *self.sts[i].expr_idxs(offset, idxs)))
|
||||
if will_merge:
|
||||
for j in range(0, 4): cache[offset+j] = Token(reg.name, LocalTypes.float4, j)
|
||||
else:
|
||||
cache[offset] = reg
|
||||
return cache[offset]
|
||||
return [op(o) for o in self.offsets(i)]
|
||||
|
||||
def global_store(self, i, idxs:List[Variable], store=List[Token]) -> None:
|
||||
should_upcast = self.supports_float4 and self.can_float4(i) and self.bufs[i].dtype != dtypes.float16
|
||||
store_offset: Dict[int, int] = {y:x for x,y in enumerate(self.offsets(i))} # NOTE: for stores, these should be unique
|
||||
def op(offset):
|
||||
if offset not in store_offset: return
|
||||
will_merge = should_upcast and self.can_merge_float4(i, idxs, offset)
|
||||
assert will_merge or not isinstance(self.bufs[i].dtype, ImageDType), "image must merge float4"
|
||||
if store is not None:
|
||||
offsets = []
|
||||
for j in range(0, 4 if will_merge else 1):
|
||||
offsets.append(store[store_offset[offset+j]])
|
||||
del store_offset[offset+j]
|
||||
self.uop(UOps.STORE4 if will_merge else UOps.STORE, None, offsets, MemOp(i, *self.sts[i].expr_idxs(offset, idxs)))
|
||||
else:
|
||||
reg = self.uop(UOps.LOAD4 if will_merge else UOps.LOAD, f"val{mnum(i)}_{mnum(offset)}", [], MemOp(i, *self.sts[i].expr_idxs(offset, idxs)))
|
||||
if will_merge:
|
||||
for j in range(0, 4): cache[offset+j] = reg+"."+"xyzw"[j]
|
||||
if will_merge:
|
||||
out_tokens = [store[store_offset[offset+j]] for j in range(4)]
|
||||
if all_same([x.name for x in out_tokens]) and tuple(range(4)) == tuple(x.offset for x in out_tokens):
|
||||
var = Token(store[store_offset[offset]].name, LocalTypes.float4)
|
||||
else:
|
||||
cache[offset] = reg
|
||||
return cache[offset]
|
||||
return [op(o) for o in self.offsets(i)]
|
||||
var = self.uop(UOps.CAST, Token(store[store_offset[offset]].name+"_f4", LocalTypes.float4), out_tokens)
|
||||
else:
|
||||
var = store[store_offset[offset]]
|
||||
for j in range(0, 4 if will_merge else 1): del store_offset[offset+j]
|
||||
self.uop(UOps.STORE, None, [var], MemOp(i, *self.sts[i].expr_idxs(offset, idxs)))
|
||||
for o in self.offsets(i): op(o)
|
||||
|
||||
def linearize(self):
|
||||
# uops
|
||||
|
@ -157,9 +198,9 @@ class Linearizer:
|
|||
|
||||
# ssa
|
||||
_ssa:DefaultDict[str,int] = defaultdict(int)
|
||||
def ssa(name):
|
||||
def ssa(name, ltype=LocalTypes.float) -> Token:
|
||||
_ssa[name] += 1
|
||||
return f"{name}{_ssa[name]-1}"
|
||||
return Token(f"{name}{_ssa[name]-1}", ltype)
|
||||
|
||||
# global loop
|
||||
global_idxs = [Variable(f"gidx{i}", 0, self.full_shape[i]-1 if i < self.first_reduce else 0) for i in range(0, self.first_reduce+len(self.group_for_reduce))]
|
||||
|
@ -178,14 +219,14 @@ class Linearizer:
|
|||
# reduce op
|
||||
if self.reduceop is not None:
|
||||
# define accumulator
|
||||
acc = [self.uop(UOps.CONST, ssa('acc'), [], {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)]) for _ in self.offsets(0)]
|
||||
acc = self.acc(ssa, 0, gl_idxs)
|
||||
|
||||
# reduce loop
|
||||
reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+len(self.group_for_reduce), self.shape_len-self.upcasted)]
|
||||
self.uop(UOps.LOOP, None, [], (reduce_idxs, "reduce"))
|
||||
|
||||
# load earlybufs
|
||||
loaded_buffers.update({b:self.global_buf(i, gl_idxs+reduce_idxs) for i,b in enumerate(self.bufs) if b in self.earlybufs and i != 0})
|
||||
loaded_buffers.update({b:self.global_load(i, gl_idxs+reduce_idxs) for i,b in enumerate(self.bufs) if b in self.earlybufs and i != 0})
|
||||
|
||||
# run early AST (with reduce)
|
||||
self.ast_parse(self.reduceop, [acc[off] for off in self.acc_offsets(self.full_buf_index)], loaded_buffers, ssa, do_reduce=True)
|
||||
|
@ -195,7 +236,7 @@ class Linearizer:
|
|||
|
||||
# end the local loop, do the local reduce
|
||||
if self.group_for_reduce:
|
||||
self.global_buf(-1, local_idxs, acc) # store accumulators
|
||||
self.global_store(-1, local_idxs, acc) # store accumulators
|
||||
self.uop(UOps.ENDLOOP, None, [], (local_idxs, "local")) # this is a barrier on GPUs
|
||||
|
||||
# if any group_for_reduce items aren't reduces, upcast them here
|
||||
|
@ -207,14 +248,14 @@ class Linearizer:
|
|||
# NOTE: this structure is the same as the reduce op above
|
||||
|
||||
# define late accumulator
|
||||
acc = [self.uop(UOps.CONST, ssa('lacc'), [], {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)]) for _ in self.offsets(-1)]
|
||||
acc = self.acc(ssa, -1, local_idxs, 'lacc')
|
||||
|
||||
# late reduce loop
|
||||
end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce else 0) for i in range(0, self.first_reduce+len(self.group_for_reduce))]
|
||||
self.uop(UOps.LOOP, None, [], (end_local_idxs, "late_reduce"))
|
||||
|
||||
# load localbufs
|
||||
loaded_buffers["LOCAL_BUFFER"] = self.global_buf(-1, end_local_idxs)
|
||||
loaded_buffers["LOCAL_BUFFER"] = self.global_load(-1, end_local_idxs)
|
||||
|
||||
# there's no AST here (and there's no shape for the reduce LazyOp)
|
||||
self.ast_parse(LazyOp(self.reduceop.op, ("LOCAL_BUFFER",)), [acc[off] for off in self.acc_offsets(-1)], loaded_buffers, ssa, do_reduce=True)
|
||||
|
@ -223,29 +264,32 @@ class Linearizer:
|
|||
self.uop(UOps.ENDLOOP, None, [], (end_local_idxs, "late_reduce"))
|
||||
|
||||
# load latebufs
|
||||
loaded_buffers.update({b:self.global_buf(i, global_idxs) for i,b in enumerate(self.bufs) if b not in self.earlybufs and i != 0 and not isinstance(b, LocalBuffer)})
|
||||
loaded_buffers.update({b:self.global_load(i, global_idxs) for i,b in enumerate(self.bufs) if b not in self.earlybufs and i != 0 and not isinstance(b, LocalBuffer)})
|
||||
|
||||
# run late AST
|
||||
val = self.ast_parse(self.ast, acc, loaded_buffers, ssa)
|
||||
|
||||
# store
|
||||
self.global_buf(0, global_idxs, val)
|
||||
self.global_store(0, global_idxs, val)
|
||||
|
||||
# end the global loop
|
||||
self.uop(UOps.ENDLOOP, None, [], (global_idxs, "global"))
|
||||
|
||||
def uop(self, uop:UOps, out:Optional[str], vin:List[str], arg:Any):
|
||||
self.uops.append(UOp(uop, out, vin, arg))
|
||||
_OT = TypeVar("_OT")
|
||||
def uop(self, uop:UOps, out:_OT, vin:List[Token], arg:Any=None) -> _OT:
|
||||
self.uops.append(UOp(uop, cast(Optional[Token], out), vin, arg))
|
||||
if DEBUG >= 4: print(self.uops[-1])
|
||||
return out
|
||||
|
||||
def ast_parse(self, x, acc, loaded_buffers, ssa, do_reduce=False) -> List[str]:
|
||||
def ast_parse(self, x, acc, loaded_buffers, ssa, do_reduce=False) -> List[Token]:
|
||||
if not isinstance(x, LazyOp): return loaded_buffers[x]
|
||||
if x.op in [UnaryOps.NOOP, UnaryOps.CAST]: return self.ast_parse(x.src[0], acc, loaded_buffers, ssa) # cast isn't an ALU op
|
||||
if x.op in ReduceOps and not do_reduce: return acc
|
||||
# MULACC fusion. TODO: this is copied from Interpreted
|
||||
if x.op == ReduceOps.SUM and isinstance(x.src[0], LazyOp) and x.src[0].op == BinaryOps.MUL:
|
||||
x = LazyOp(FusedOps.MULACC, x.src[0].src, x.arg)
|
||||
if x.op == ReduceOps.SUM and isinstance(x.src[0], LazyOp) and x.src[0].op == UnaryOps.CAST and isinstance(x.src[0].src[0], LazyOp) and x.src[0].src[0].op == BinaryOps.MUL:
|
||||
x = LazyOp(FusedOps.MULACC, x.src[0].src[0].src, x.arg)
|
||||
values = [self.ast_parse(v, acc, loaded_buffers, ssa) for v in x.src]
|
||||
if isinstance(x.op, (ReduceOps, FusedOps)):
|
||||
return [self.uop(UOps.ALU, val[0], list(val), {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, FusedOps.MULACC:FusedOps.MULACC}[x.op]) for val in zip(acc, *values)]
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from typing import Final, Dict, Callable, Any, List, Optional
|
||||
import functools
|
||||
from llvmlite import ir # type: ignore
|
||||
from tinygrad.codegen.linearizer import Linearizer, UOps, UOp
|
||||
from tinygrad.codegen.linearizer import Linearizer, UOps, UOp, Token
|
||||
from tinygrad.helpers import dtypes
|
||||
from tinygrad.ops import Op, ASTRunner, UnaryOps, BinaryOps, FusedOps
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
|
@ -48,7 +48,7 @@ def uops_to_llvm_ir(uops:List[UOp], bufs:List[LazyBuffer]) -> str:
|
|||
loop_blocks = []
|
||||
reduce_phis: List = []
|
||||
# TODO: newvar probably shouldn't be optional
|
||||
lvars: Dict[Optional[str], Any] = {} # this Any is an llvm type
|
||||
lvars: Dict[Optional[Token], Any] = {} # this Any is an llvm type
|
||||
render_llvm[Variable] = lambda self,ops,ctx: lvars[self.expr]
|
||||
|
||||
for uop,newvar,vin,args in uops:
|
||||
|
|
|
@ -49,8 +49,9 @@ class LazyNumpyArray:
|
|||
class dtypes:
|
||||
float16: Final[DType] = DType(0, 2, "half", np.float16)
|
||||
float32: Final[DType] = DType(1, 4, "float", np.float32)
|
||||
int32: Final[DType] = DType(1, 4, "int", np.int32)
|
||||
@staticmethod
|
||||
def from_np(x) -> DType: return {np.dtype(np.float16): dtypes.float16, np.dtype(np.float32): dtypes.float32}[np.dtype(x)]
|
||||
def from_np(x) -> DType: return {np.dtype(np.float16): dtypes.float16, np.dtype(np.float32): dtypes.float32, np.dtype(np.int32): dtypes.int32}[np.dtype(x)]
|
||||
|
||||
class GlobalCounters:
|
||||
global_ops: ClassVar[int] = 0
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from __future__ import annotations
|
||||
import functools, itertools, operator, random
|
||||
import functools, itertools, operator, random, time
|
||||
from enum import Enum, auto
|
||||
from typing import Union, Type, NamedTuple, Tuple, Any, List, Optional, Dict, Callable, ClassVar
|
||||
from tinygrad.helpers import prod, DEBUG, getenv, GlobalCounters, DType, colored
|
||||
|
@ -49,8 +49,9 @@ class Interpreted:
|
|||
if context is None: context = dict()
|
||||
if not created_context and ast in context: return context[ast]
|
||||
srcs = [self.exec_ast(x, context=context) if isinstance(x, LazyOp) else self.from_lazybuffer(x) for x in ast.src]
|
||||
if DEBUG >= 3: st = time.perf_counter()
|
||||
ret = self.buffer(self.fxn_for_op[ast.op](*([self.to_underlying(x) for x in srcs] + ([ast.arg] if ast.arg is not None else []))))
|
||||
if DEBUG >= 3: print(f"*** {'exec' if created_context else ' '} {GlobalCounters.mem_used/1e9:5.2f} GB op: {ast.op:20s} out({ret.dtype.name}): {str(ret._buf.shape):30s} in({len(srcs)}):", list(set(x._buf.shape for x in srcs)), ast.arg if ast.arg is not None else "")
|
||||
if DEBUG >= 3: print(f"*** {'exec' if created_context else ' '} {GlobalCounters.mem_used/1e9:5.2f} GB {(time.perf_counter()-st)*1e3:7.2f} ms op: {ast.op:20s} out({ret.dtype.name}): {str(ret._buf.shape):30s} in({len(srcs)}):", list(set(x._buf.shape for x in srcs)), ast.arg if ast.arg is not None else "")
|
||||
if not created_context: context[ast] = ret
|
||||
if output is not None and output.output_buffer is not None:
|
||||
assert output.output_buffer.size == ret.size, output.output_buffer.dtype == ret.dtype
|
||||
|
@ -90,9 +91,9 @@ class ASTRunner:
|
|||
if GlobalCounters.cache is not None: GlobalCounters.cache.append((self, rawbufs))
|
||||
return self(rawbufs)
|
||||
|
||||
def __call__(self, rawbufs:List[RawBuffer], jit=False) -> Optional[float]:
|
||||
def __call__(self, rawbufs:List[RawBuffer], jit=False, force_wait=False) -> Optional[float]:
|
||||
if getenv("OPTLOCAL") and self.global_size is not None and self.local_size is None: self.local_size = self.optimize_local_size(rawbufs, allow_cache=(getenv("OPTLOCAL") >= 2))
|
||||
if et := self.clprg(self.global_size, self.local_size, *rawbufs, wait=DEBUG>=2): GlobalCounters.time_sum_s += et
|
||||
if et := self.clprg(self.global_size, self.local_size, *rawbufs, wait=force_wait or DEBUG>=2): GlobalCounters.time_sum_s += et
|
||||
if DEBUG >= 2:
|
||||
print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'magenta' if jit else None)} {(self.display_name+' '*(26-len(self.name))) if self.display_name is not None else self.name:26s} arg {len(rawbufs):3d} sz {str(self.global_size):18s} {str(self.local_size):12s} OPs {self.op_estimate/1e6:7.1f}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
|
||||
(str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({self.op_estimate/(et*1e9):8.2f} GFLOPS, {self.mem_estimate/(et*1e9):6.2f} GB/s)"))
|
||||
|
@ -152,10 +153,12 @@ class Compiled:
|
|||
# this is the default now
|
||||
if getenv("ENABLE_METHOD_CACHE", 1):
|
||||
if k.key not in self.method_cache: self.method_cache[k.key] = k.codegen().build(self.runtime)
|
||||
elif DEBUG >= 4: print(f"method cache hit : {k.key}")
|
||||
elif DEBUG >= 5: print(f"method cache hit : {k.key}")
|
||||
prg = self.method_cache[k.key]
|
||||
else:
|
||||
prg = k.codegen().build(self.runtime)
|
||||
|
||||
if prg.name == getenv("PRINT_PRG", ''): print(prg.prg)
|
||||
|
||||
prg.exec(k.bufs)
|
||||
return output.realized
|
||||
|
|
Loading…
Reference in New Issue