From 20894991edd4528979231a012f02336873e85dd6 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Wed, 29 Mar 2023 05:11:02 +0400 Subject: [PATCH] good changes from the M1 Tensor Core project (#730) * good changes * working except llvm * llvm types * nice acc * archprobe * lang.float4 * use self.acc for late acc * fix store bug --- extra/archprobe.py | 63 ++++++++++++ test/external/external_copy_benchmark.py | 26 +++++ test/test_speed_v_torch.py | 33 ++++--- tinygrad/codegen/cstyle.py | 56 +++++------ tinygrad/codegen/linearizer.py | 116 ++++++++++++++++------- tinygrad/codegen/llvmir.py | 4 +- tinygrad/helpers.py | 3 +- tinygrad/ops.py | 13 ++- 8 files changed, 231 insertions(+), 83 deletions(-) create mode 100644 extra/archprobe.py create mode 100644 test/external/external_copy_benchmark.py diff --git a/extra/archprobe.py b/extra/archprobe.py new file mode 100644 index 000000000..994826706 --- /dev/null +++ b/extra/archprobe.py @@ -0,0 +1,63 @@ +# copying the kernels from https://github.com/microsoft/ArchProbe into Python +import numpy as np +from tinygrad.runtime.ops_gpu import CLProgram, CLBuffer +from tinygrad.helpers import dtypes +from tqdm import trange, tqdm +from matplotlib import pyplot as plt + +def reg_count(nthread, ngrp, nreg): + reg_declr = ''.join([f"float reg_data{i} = (float)niter + {i};\n" for i in range(nreg)]) + reg_comp = ''.join([f"reg_data{i} *= {(i-1)%nreg};\n" for i in range(nreg)]) + reg_reduce = ''.join([f"out_buf[{i}] = reg_data{i};\n" for i in range(nreg)]) + prg = f"""__kernel void reg_count( + __global float* out_buf, + __private const int niter + ) {{ + {reg_declr} + int i = 0; + for (; i < niter; ++i) {{ + {reg_comp} + }} + i = i >> 31; + {reg_reduce} + }}""" + out_buf = CLBuffer(1, dtypes.float32) + cl = CLProgram("reg_count", prg, argdtypes=[None, np.int32]) + return min([cl([nthread, ngrp, 1], [nthread, 1, 1], out_buf, 10, wait=True) for _ in range(10)]) + +""" +print("probing registers") +pts = [(nreg, reg_count(1, 1, nreg)) for nreg in trange(1, 257)] # archprobe goes to 512 +plt.plot(*zip(*pts)) +plt.show() +""" + +def buf_cache_hierarchy_pchase(ndata, stride=1): + NCOMP = 16 # 64 byte is under the 128 byte cache line + print("probe", ndata*NCOMP*4) + prg = """__kernel void buf_cache_hierarchy_pchase( + __global int16* src, + __global int* dst, + const int niter + ) { + int idx = 0; + for (int i = 0; i < niter; ++i) { + idx = src[idx].x; + } + *dst = idx; + }""" + idx_buf = np.zeros(ndata*NCOMP, dtype=np.int32) + for i in range(ndata): + idx_buf[i*NCOMP] = (i + stride) % ndata + in_buf = CLBuffer.fromCPU(idx_buf) + out_buf = CLBuffer(1, dtypes.int32) + cl = CLProgram("buf_cache_hierarchy_pchase", prg, argdtypes=[None, None, np.int32]) + return min([cl([1, 1, 1], [1, 1, 1], in_buf, out_buf, ndata*4, wait=True) for _ in range(5)]) + +# 768 kb is real +print("probing cache size") +base = buf_cache_hierarchy_pchase(1, 191) +szs = list(range(128, 1024, 128)) + list(range(1024, 16*1024, 1024)) + list(range(16*1024, int(1.5*1024*1024), 16*1024)) #+ list(range(2*1024*1024, 20*1024*1024, 1024*1024)) +pts = [(ndata, (buf_cache_hierarchy_pchase(ndata//64, 136329190282766681843115968953)-base)/ndata) for ndata in tqdm(szs)] +plt.plot(*zip(*pts)) +plt.show() diff --git a/test/external/external_copy_benchmark.py b/test/external/external_copy_benchmark.py new file mode 100644 index 000000000..3525797a6 --- /dev/null +++ b/test/external/external_copy_benchmark.py @@ -0,0 +1,26 @@ +import unittest +from tinygrad.helpers import prod +from tinygrad.lazy import Device +from tinygrad.tensor import Tensor +from tinygrad.ops import GlobalCounters + +class TestCopy(unittest.TestCase): + def test_add1(self): + pts = [] + for i in range(16384, 16384*256, 16384): + t = Tensor.randn(i).realize() + GlobalCounters.cache = [] + t.assign(t+1).realize() + fxn, args = GlobalCounters.cache[0] + GlobalCounters.reset() + def run(): return fxn(args, force_wait=True) + ct = min([run() for _ in range(10)]) + mb = prod(t.shape)*t.dtype.itemsize*2*1e-6 + print(f"{mb*1e3:.2f} kB, {ct*1e3:.2f} ms, {mb/ct:.2f} MB/s") + pts.append((mb, mb/ct)) + from matplotlib import pyplot as plt + plt.plot([x[0] for x in pts], [x[1] for x in pts]) + plt.show() + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/test/test_speed_v_torch.py b/test/test_speed_v_torch.py index 9753ee330..b0cb4374b 100644 --- a/test/test_speed_v_torch.py +++ b/test/test_speed_v_torch.py @@ -44,26 +44,31 @@ def helper_test_speed(f1, *args): global save_ops, save_mem ets = [] ret = None - for _ in range(CNT): + cache_defeat = np.zeros((2048,2048)) + for i in range(CNT): del ret - args = [(x+1).realize() if isinstance(x, Tensor) else (None if x is None else (x+1)) for x in args] # cache defeats + + # operation cache defeats + args = [(x+1).realize() if isinstance(x, Tensor) else (None if x is None else (x+1)) for x in args] # force syncing [x.numpy() if isinstance(x, Tensor) or str(torch_device) == "cpu" else x.cpu().numpy() for x in args if x is not None] + # clear 32MB global memory cache (CPU and global memory only) + cache_defeat += 1 + + # manual pre sync + if isinstance(args[0], Tensor): Device[args[0].device].synchronize() + else: sync() + GlobalCounters.global_ops = 0 GlobalCounters.global_mem = 0 - if DEBUG >= 4: print("benchmark start") - st = time.monotonic() + st = time.perf_counter() ret = f1(*args) - if isinstance(ret, Tensor): - ret.realize() - Device[ret.device].synchronize() - else: - sync() - et = (time.monotonic() - st) * 1000 - ets.append(et) - if DEBUG >= 4: print("benchmark stop") + if isinstance(ret, Tensor): Device[ret.device].synchronize() + else: sync() + et = (time.perf_counter() - st) * 1000 + if i >= 1: ets.append(et) # not the first run / one used for OPTLOCAL if GlobalCounters.global_ops: save_ops, save_mem = GlobalCounters.global_ops, GlobalCounters.global_mem return ret.cpu().numpy(), np.min(ets) @@ -173,6 +178,10 @@ class TestSpeed(unittest.TestCase): def f(a, b): return a @ b helper_test_generic_square('gemm', 1024, f, f) + def test_gemm_small(self): + def f(a, b): return a @ b + helper_test_generic_square('gemm', 256, f, f) + def test_gemm_unrolled(self): N = 512 def f1(a, b): return a@b.T diff --git a/tinygrad/codegen/cstyle.py b/tinygrad/codegen/cstyle.py index 0d512771b..f20afdbfb 100644 --- a/tinygrad/codegen/cstyle.py +++ b/tinygrad/codegen/cstyle.py @@ -1,8 +1,8 @@ from typing import Final, Dict, Callable, ClassVar, List, Optional, NamedTuple, DefaultDict, Tuple, Set, Union import math, collections -from tinygrad.codegen.linearizer import Linearizer, UOps, UOp, LocalBuffer +from tinygrad.codegen.linearizer import Linearizer, UOps, UOp, LocalBuffer, LocalTypes from tinygrad.ops import ASTRunner, Op, UnaryOps, BinaryOps, FusedOps -from tinygrad.helpers import getenv, all_same, partition, ImageDType, DEBUG, dtypes, colored +from tinygrad.helpers import getenv, partition, ImageDType, DEBUG, dtypes, colored from tinygrad.runtime.lib import RawConst from tinygrad.shape.symbolic import DivNode, AndNode, render_python, NumNode, Variable, Node, SumNode, MulNode from tinygrad.lazy import LazyBuffer @@ -57,10 +57,6 @@ code_for_op: Final[Dict[Op, Callable]] = { } def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lang:CStyleLanguage) -> Tuple[str, List[int], List[int]]: - def group_float4(grp:List[str]) -> str: - if all(g.endswith(e) for g,e in zip(grp, [".x", ".y", ".z", ".w"])) and all_same([g.split(".")[0] for g in grp]): return grp[0].split(".")[0] - else: return f"{lang.float4}({','.join(g for g in grp)})" - prekernel: Set[str] = set() kernel = [] global_size = [] @@ -103,7 +99,7 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan else: kk(f"for (int {var.expr} = {var.min}; {var.expr} <= {var.max}; ++{var.expr}) {{") depth += 1 - if uop == UOps.ENDLOOP: + elif uop == UOps.ENDLOOP: if args[1] == "local" and len(lang.lid): # TODO: this is a bit of a hack. the local loop isn't real on the GPU kk(lang.barrier) @@ -116,18 +112,19 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan pend_close = None depth -= 1 kk("}"*len(args[0]) + f" /* {args[1]} */") - if uop == UOps.CONST: + elif uop == UOps.CONST: + assert newvar is not None if args == -math.inf: - kk(f"float {newvar} = -INFINITY;") + kk(f"{newvar.render(True)} = -INFINITY;") else: - kk(f"float {newvar} = {args}f;") - if uop == UOps.ALU: + kk(f"{newvar.render(True)} = {args}f;") + elif uop == UOps.ALU: + assert newvar is not None if newvar in vin: - kk(f"{newvar} = {code_for_op[args](*vin)};") + kk(f"{newvar.render()} = {code_for_op[args](*[x.render() for x in vin])};") else: - kk(f"float {newvar} = {code_for_op[args](*vin)};") - # TODO: refactor the next 14 lines - if uop == UOps.LOAD: + kk(f"{newvar.render(True)} = {code_for_op[args](*[x.render() for x in vin])};") + elif uop == UOps.LOAD and newvar is not None and newvar.ltype == LocalTypes.float: # TODO: merge with CONST? if bufs[args.i] is not None and isinstance(bufs[args.i].realized, RawConst): # nan? inf? @@ -138,9 +135,10 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan else: val = f"{bufnames[args.i]}[{args.idx.render(render_cl)}]" # NOTE: if min and max are both 0, it should be a CONST in the Linearizer - if args.valid.min == 1: kk(f"float {newvar} = {val};") - else: kk(f"float {newvar} = ({args.valid.render(render_cl)}) ? ({val}) : 0.0f;") - if uop == UOps.LOAD4: + if args.valid.min == 1: kk(f"float {newvar.name} = {val};") + else: kk(f"float {newvar.name} = ({args.valid.render(render_cl)}) ? ({val}) : 0.0f;") + elif uop == UOps.LOAD and newvar is not None and newvar.ltype == LocalTypes.float4: + assert newvar.offset is None, "load can't have an offset" if isinstance(bufs[args.i].dtype, ImageDType): prekernel.add("const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n") idx, idy = to_image_idx(bufs[args.i].dtype.shape, args.idx, args.valid) @@ -148,23 +146,27 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan else: val = f"(({lang.smem_prefix if isinstance(bufs[args.i], LocalBuffer) else lang.buffer_prefix}float4*){bufnames[args.i]})[{(args.idx//4).render(render_cl)}]" # NOTE: if min and max are both 0, it should be a CONST in the Linearizer - if args[2].min == 1: kk(f"float4 {newvar} = {val};") - else: kk(f"float4 {newvar} = ({args.valid.render(render_cl)}) ? ({val}) : {group_float4(['0.0f']*4)};") - if uop == UOps.STORE: + if args[2].min == 1: kk(f"{newvar.render(True)} = {val};") + else: kk(f"{newvar.render(True)} = ({args.valid.render(render_cl)}) ? ({val}) : {lang.float4}(0.0f, 0.0f, 0.0f, 0.0f);") + elif uop == UOps.STORE and (vin[0].ltype == LocalTypes.float or (vin[0].ltype == LocalTypes.float4 and vin[0].offset is not None)): assert args.valid.min == 1, "store must be valid" if lang.uses_vload and bufs[args.i].dtype == dtypes.float16: - kk(f"vstore_half({vin[0]}, {args.idx.render(render_cl)}, {bufnames[args.i]});") + kk(f"vstore_half({vin[0].render()}, {args.idx.render(render_cl)}, {bufnames[args.i]});") else: - kk(f"{bufnames[args.i]}[{args.idx.render(render_cl)}] = {vin[0]};") - if uop == UOps.STORE4: + kk(f"{bufnames[args.i]}[{args.idx.render(render_cl)}] = {vin[0].render()};") + elif uop == UOps.CAST and newvar is not None and newvar.ltype == LocalTypes.float4: + kk(f"{newvar.render(True)} = {lang.float4}({','.join([x.render() for x in vin])});") + elif uop == UOps.STORE and len(vin) != 0 and vin[0].ltype == LocalTypes.float4 and vin[0].offset is None: assert args.valid.min == 1, "store must be valid" if isinstance(bufs[args[0]].dtype, ImageDType): idx, idy = to_image_idx(bufs[args.i].dtype.shape, args[1], args[2]) - kk(f"write_imagef({bufnames[args.i]}, (int2)({idx.render(render_cl)}, {idy.render(render_cl)}), {group_float4(vin)});") + kk(f"write_imagef({bufnames[args.i]}, (int2)({idx.render(render_cl)}, {idy.render(render_cl)}), {vin[0].render()});") else: - kk(f"(({lang.smem_prefix if isinstance(bufs[args.i], LocalBuffer) else lang.buffer_prefix}float4*){bufnames[args.i]})[{(args.idx//4).render(render_cl)}] = {group_float4(vin)};") - if uop == UOps.DEFINE_LOCAL: + kk(f"(({lang.smem_prefix if isinstance(bufs[args.i], LocalBuffer) else lang.buffer_prefix}float4*){bufnames[args.i]})[{(args.idx//4).render(render_cl)}] = {vin[0].render()};") + elif uop == UOps.DEFINE_LOCAL: kk(lang.smem_prefix + f"float {args[0]}[{args[1]}];") + else: + raise RuntimeError(f"failed to render {uop}") buftypes = [(i,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if x.dtype.name.startswith('image') else ("const " if i > 0 else "")+lang.buffer_prefix+x.dtype.name+"*"+lang.buffer_suffix) for i,x in enumerate(bufs) diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py index f97cf9c66..423e3e3f4 100644 --- a/tinygrad/codegen/linearizer.py +++ b/tinygrad/codegen/linearizer.py @@ -1,21 +1,36 @@ -from typing import List, Tuple, Any, Optional, cast, Dict, DefaultDict, NamedTuple +from typing import List, Tuple, Any, Optional, cast, Dict, DefaultDict, NamedTuple, TypeVar import itertools, math from collections import defaultdict from enum import Enum, auto -from tinygrad.helpers import dedup, colored, ImageDType, DEBUG, prod, dtypes, mnum, DType +from tinygrad.helpers import dedup, colored, ImageDType, DEBUG, prod, dtypes, mnum, DType, all_same from tinygrad.ops import LazyOp, get_lazyops, get_buffers, FlopCounter, get_lazyop_info, map_buffers, UnaryOps from tinygrad.lazy import LazyBuffer from tinygrad.ops import MovementOps, ReduceOps, BinaryOps, FusedOps from tinygrad.shape.shapetracker import ShapeTracker, strides_for_shape from tinygrad.shape.symbolic import Variable, SumNode, ModNode -class UOps(Enum): LOOP = auto(); DEFINE_LOCAL = auto(); LOAD = auto(); ALU = auto(); CONST = auto(); ENDLOOP = auto(); STORE = auto(); LOAD4 = auto(); STORE4 = auto() # noqa: E702 +class UOps(Enum): LOOP = auto(); DEFINE_LOCAL = auto(); LOAD = auto(); ALU = auto(); CONST = auto(); ENDLOOP = auto(); STORE = auto(); CAST = auto() # noqa: E702 class LocalBuffer(NamedTuple): dtype: DType = dtypes.float32 realized: None = None +class LocalTypes(Enum): float = auto(); float4 = auto(); half = auto(); half4 = auto(); simdgroup_float8x8 = auto() # noqa: E702 + +class Token(NamedTuple): + name: str + ltype: LocalTypes + offset: Optional[int] = None + def render(self, with_type=False): + if with_type: + assert self.offset is None + return f"{self.ltype.name} {self.name}" + if self.offset is None: return self.name + assert self.ltype == LocalTypes.float4 + return self.name+"."+"xyzw"[int(self.offset)] + def __repr__(self): return f"<{self.name}>" if self.offset is None and self.ltype == LocalTypes.float else f"<{self.name}:{self.ltype.name}:{self.offset}>" + class MemOp(NamedTuple): i: int idx: Variable @@ -23,10 +38,10 @@ class MemOp(NamedTuple): class UOp(NamedTuple): uop: UOps - out: Optional[str] - vin: List[str] + out: Optional[Token] + vin: List[Token] arg: Any - def __repr__(self): return f"{str(self.uop):20s}: {self.out if self.out is not None else '':10s} {str(self.vin):32s} {self.arg}" + def __repr__(self): return f"{str(self.uop):20s}: {str(self.out) if self.out is not None else '':25s} {str(self.vin):32s} {self.arg}" def check_no_mul(test, var): if test == var: return True @@ -106,32 +121,58 @@ class Linearizer: idxy_test, valid_test = self.sts[i].expr_idxs(float4_index+offset, idxs) # float4_index must not be in after divide or in valid. NOTE: this forces it to always be aligned too, maybe not required? ret = check_no_mul(idxy_test, float4_index) and "FLOAT4_INDEX" not in (idxy_test//4).render() and "FLOAT4_INDEX" not in (valid_test//4).render() - if DEBUG >= 4: print(f"fuse buf {i} {ret} :", check_no_mul(idxy_test, float4_index), idxy_test, idxy_test//4, valid_test//4) + if DEBUG >= 5: print(f"fuse buf {i} {ret} :", check_no_mul(idxy_test, float4_index), idxy_test, idxy_test//4, valid_test//4) return ret - def global_buf(self, i, idxs:List[Variable], store=None): - should_upcast = self.supports_float4 and self.can_float4(i) and self.bufs[i].dtype != dtypes.float16 - cache: Dict[int, str] = {} - store_offset: Dict[int, int] = {y:x for x,y in enumerate(self.offsets(i))} # NOTE: for stores, these should be unique + # TODO: this is very similar to load + def acc(self, ssa, i, idxs:List[Variable], name='acc') -> List[Token]: + should_upcast = self.supports_float4 and self.can_float4(i) + cache: Dict[int, Token] = {} def op(offset): if offset in cache: return cache[offset] + will_merge = should_upcast and self.can_merge_float4(i, idxs, offset) + assert self.reduceop is not None + reg = self.uop(UOps.CONST, ssa(name, LocalTypes.float4 if will_merge else LocalTypes.float), [], {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)]) + if will_merge: + for j in range(0, 4): cache[offset+j] = Token(reg.name, LocalTypes.float4, j) + else: + cache[offset] = reg + return cache[offset] + return [op(o) for o in self.offsets(i)] + + def global_load(self, i, idxs:List[Variable]) -> List[Token]: + should_upcast = self.supports_float4 and self.can_float4(i) and self.bufs[i].dtype != dtypes.float16 + cache: Dict[int, Token] = {} + def op(offset): + if offset in cache: return cache[offset] + will_merge = should_upcast and self.can_merge_float4(i, idxs, offset) + assert will_merge or not isinstance(self.bufs[i].dtype, ImageDType), "image must merge float4" + reg = self.uop(UOps.LOAD, Token(f"val{mnum(i)}_{mnum(offset)}", LocalTypes.float4 if will_merge else LocalTypes.float), [], MemOp(i, *self.sts[i].expr_idxs(offset, idxs))) + if will_merge: + for j in range(0, 4): cache[offset+j] = Token(reg.name, LocalTypes.float4, j) + else: + cache[offset] = reg + return cache[offset] + return [op(o) for o in self.offsets(i)] + + def global_store(self, i, idxs:List[Variable], store=List[Token]) -> None: + should_upcast = self.supports_float4 and self.can_float4(i) and self.bufs[i].dtype != dtypes.float16 + store_offset: Dict[int, int] = {y:x for x,y in enumerate(self.offsets(i))} # NOTE: for stores, these should be unique + def op(offset): if offset not in store_offset: return will_merge = should_upcast and self.can_merge_float4(i, idxs, offset) assert will_merge or not isinstance(self.bufs[i].dtype, ImageDType), "image must merge float4" - if store is not None: - offsets = [] - for j in range(0, 4 if will_merge else 1): - offsets.append(store[store_offset[offset+j]]) - del store_offset[offset+j] - self.uop(UOps.STORE4 if will_merge else UOps.STORE, None, offsets, MemOp(i, *self.sts[i].expr_idxs(offset, idxs))) - else: - reg = self.uop(UOps.LOAD4 if will_merge else UOps.LOAD, f"val{mnum(i)}_{mnum(offset)}", [], MemOp(i, *self.sts[i].expr_idxs(offset, idxs))) - if will_merge: - for j in range(0, 4): cache[offset+j] = reg+"."+"xyzw"[j] + if will_merge: + out_tokens = [store[store_offset[offset+j]] for j in range(4)] + if all_same([x.name for x in out_tokens]) and tuple(range(4)) == tuple(x.offset for x in out_tokens): + var = Token(store[store_offset[offset]].name, LocalTypes.float4) else: - cache[offset] = reg - return cache[offset] - return [op(o) for o in self.offsets(i)] + var = self.uop(UOps.CAST, Token(store[store_offset[offset]].name+"_f4", LocalTypes.float4), out_tokens) + else: + var = store[store_offset[offset]] + for j in range(0, 4 if will_merge else 1): del store_offset[offset+j] + self.uop(UOps.STORE, None, [var], MemOp(i, *self.sts[i].expr_idxs(offset, idxs))) + for o in self.offsets(i): op(o) def linearize(self): # uops @@ -157,9 +198,9 @@ class Linearizer: # ssa _ssa:DefaultDict[str,int] = defaultdict(int) - def ssa(name): + def ssa(name, ltype=LocalTypes.float) -> Token: _ssa[name] += 1 - return f"{name}{_ssa[name]-1}" + return Token(f"{name}{_ssa[name]-1}", ltype) # global loop global_idxs = [Variable(f"gidx{i}", 0, self.full_shape[i]-1 if i < self.first_reduce else 0) for i in range(0, self.first_reduce+len(self.group_for_reduce))] @@ -178,14 +219,14 @@ class Linearizer: # reduce op if self.reduceop is not None: # define accumulator - acc = [self.uop(UOps.CONST, ssa('acc'), [], {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)]) for _ in self.offsets(0)] + acc = self.acc(ssa, 0, gl_idxs) # reduce loop reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+len(self.group_for_reduce), self.shape_len-self.upcasted)] self.uop(UOps.LOOP, None, [], (reduce_idxs, "reduce")) # load earlybufs - loaded_buffers.update({b:self.global_buf(i, gl_idxs+reduce_idxs) for i,b in enumerate(self.bufs) if b in self.earlybufs and i != 0}) + loaded_buffers.update({b:self.global_load(i, gl_idxs+reduce_idxs) for i,b in enumerate(self.bufs) if b in self.earlybufs and i != 0}) # run early AST (with reduce) self.ast_parse(self.reduceop, [acc[off] for off in self.acc_offsets(self.full_buf_index)], loaded_buffers, ssa, do_reduce=True) @@ -195,7 +236,7 @@ class Linearizer: # end the local loop, do the local reduce if self.group_for_reduce: - self.global_buf(-1, local_idxs, acc) # store accumulators + self.global_store(-1, local_idxs, acc) # store accumulators self.uop(UOps.ENDLOOP, None, [], (local_idxs, "local")) # this is a barrier on GPUs # if any group_for_reduce items aren't reduces, upcast them here @@ -207,14 +248,14 @@ class Linearizer: # NOTE: this structure is the same as the reduce op above # define late accumulator - acc = [self.uop(UOps.CONST, ssa('lacc'), [], {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)]) for _ in self.offsets(-1)] + acc = self.acc(ssa, -1, local_idxs, 'lacc') # late reduce loop end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce else 0) for i in range(0, self.first_reduce+len(self.group_for_reduce))] self.uop(UOps.LOOP, None, [], (end_local_idxs, "late_reduce")) # load localbufs - loaded_buffers["LOCAL_BUFFER"] = self.global_buf(-1, end_local_idxs) + loaded_buffers["LOCAL_BUFFER"] = self.global_load(-1, end_local_idxs) # there's no AST here (and there's no shape for the reduce LazyOp) self.ast_parse(LazyOp(self.reduceop.op, ("LOCAL_BUFFER",)), [acc[off] for off in self.acc_offsets(-1)], loaded_buffers, ssa, do_reduce=True) @@ -223,29 +264,32 @@ class Linearizer: self.uop(UOps.ENDLOOP, None, [], (end_local_idxs, "late_reduce")) # load latebufs - loaded_buffers.update({b:self.global_buf(i, global_idxs) for i,b in enumerate(self.bufs) if b not in self.earlybufs and i != 0 and not isinstance(b, LocalBuffer)}) + loaded_buffers.update({b:self.global_load(i, global_idxs) for i,b in enumerate(self.bufs) if b not in self.earlybufs and i != 0 and not isinstance(b, LocalBuffer)}) # run late AST val = self.ast_parse(self.ast, acc, loaded_buffers, ssa) # store - self.global_buf(0, global_idxs, val) + self.global_store(0, global_idxs, val) # end the global loop self.uop(UOps.ENDLOOP, None, [], (global_idxs, "global")) - def uop(self, uop:UOps, out:Optional[str], vin:List[str], arg:Any): - self.uops.append(UOp(uop, out, vin, arg)) + _OT = TypeVar("_OT") + def uop(self, uop:UOps, out:_OT, vin:List[Token], arg:Any=None) -> _OT: + self.uops.append(UOp(uop, cast(Optional[Token], out), vin, arg)) if DEBUG >= 4: print(self.uops[-1]) return out - def ast_parse(self, x, acc, loaded_buffers, ssa, do_reduce=False) -> List[str]: + def ast_parse(self, x, acc, loaded_buffers, ssa, do_reduce=False) -> List[Token]: if not isinstance(x, LazyOp): return loaded_buffers[x] if x.op in [UnaryOps.NOOP, UnaryOps.CAST]: return self.ast_parse(x.src[0], acc, loaded_buffers, ssa) # cast isn't an ALU op if x.op in ReduceOps and not do_reduce: return acc # MULACC fusion. TODO: this is copied from Interpreted if x.op == ReduceOps.SUM and isinstance(x.src[0], LazyOp) and x.src[0].op == BinaryOps.MUL: x = LazyOp(FusedOps.MULACC, x.src[0].src, x.arg) + if x.op == ReduceOps.SUM and isinstance(x.src[0], LazyOp) and x.src[0].op == UnaryOps.CAST and isinstance(x.src[0].src[0], LazyOp) and x.src[0].src[0].op == BinaryOps.MUL: + x = LazyOp(FusedOps.MULACC, x.src[0].src[0].src, x.arg) values = [self.ast_parse(v, acc, loaded_buffers, ssa) for v in x.src] if isinstance(x.op, (ReduceOps, FusedOps)): return [self.uop(UOps.ALU, val[0], list(val), {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, FusedOps.MULACC:FusedOps.MULACC}[x.op]) for val in zip(acc, *values)] diff --git a/tinygrad/codegen/llvmir.py b/tinygrad/codegen/llvmir.py index 7f43bfb2f..1493e3033 100644 --- a/tinygrad/codegen/llvmir.py +++ b/tinygrad/codegen/llvmir.py @@ -1,7 +1,7 @@ from typing import Final, Dict, Callable, Any, List, Optional import functools from llvmlite import ir # type: ignore -from tinygrad.codegen.linearizer import Linearizer, UOps, UOp +from tinygrad.codegen.linearizer import Linearizer, UOps, UOp, Token from tinygrad.helpers import dtypes from tinygrad.ops import Op, ASTRunner, UnaryOps, BinaryOps, FusedOps from tinygrad.lazy import LazyBuffer @@ -48,7 +48,7 @@ def uops_to_llvm_ir(uops:List[UOp], bufs:List[LazyBuffer]) -> str: loop_blocks = [] reduce_phis: List = [] # TODO: newvar probably shouldn't be optional - lvars: Dict[Optional[str], Any] = {} # this Any is an llvm type + lvars: Dict[Optional[Token], Any] = {} # this Any is an llvm type render_llvm[Variable] = lambda self,ops,ctx: lvars[self.expr] for uop,newvar,vin,args in uops: diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index e350ebfab..f270273b5 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -49,8 +49,9 @@ class LazyNumpyArray: class dtypes: float16: Final[DType] = DType(0, 2, "half", np.float16) float32: Final[DType] = DType(1, 4, "float", np.float32) + int32: Final[DType] = DType(1, 4, "int", np.int32) @staticmethod - def from_np(x) -> DType: return {np.dtype(np.float16): dtypes.float16, np.dtype(np.float32): dtypes.float32}[np.dtype(x)] + def from_np(x) -> DType: return {np.dtype(np.float16): dtypes.float16, np.dtype(np.float32): dtypes.float32, np.dtype(np.int32): dtypes.int32}[np.dtype(x)] class GlobalCounters: global_ops: ClassVar[int] = 0 diff --git a/tinygrad/ops.py b/tinygrad/ops.py index e79d2567a..5e53b8004 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -1,5 +1,5 @@ from __future__ import annotations -import functools, itertools, operator, random +import functools, itertools, operator, random, time from enum import Enum, auto from typing import Union, Type, NamedTuple, Tuple, Any, List, Optional, Dict, Callable, ClassVar from tinygrad.helpers import prod, DEBUG, getenv, GlobalCounters, DType, colored @@ -49,8 +49,9 @@ class Interpreted: if context is None: context = dict() if not created_context and ast in context: return context[ast] srcs = [self.exec_ast(x, context=context) if isinstance(x, LazyOp) else self.from_lazybuffer(x) for x in ast.src] + if DEBUG >= 3: st = time.perf_counter() ret = self.buffer(self.fxn_for_op[ast.op](*([self.to_underlying(x) for x in srcs] + ([ast.arg] if ast.arg is not None else [])))) - if DEBUG >= 3: print(f"*** {'exec' if created_context else ' '} {GlobalCounters.mem_used/1e9:5.2f} GB op: {ast.op:20s} out({ret.dtype.name}): {str(ret._buf.shape):30s} in({len(srcs)}):", list(set(x._buf.shape for x in srcs)), ast.arg if ast.arg is not None else "") + if DEBUG >= 3: print(f"*** {'exec' if created_context else ' '} {GlobalCounters.mem_used/1e9:5.2f} GB {(time.perf_counter()-st)*1e3:7.2f} ms op: {ast.op:20s} out({ret.dtype.name}): {str(ret._buf.shape):30s} in({len(srcs)}):", list(set(x._buf.shape for x in srcs)), ast.arg if ast.arg is not None else "") if not created_context: context[ast] = ret if output is not None and output.output_buffer is not None: assert output.output_buffer.size == ret.size, output.output_buffer.dtype == ret.dtype @@ -90,9 +91,9 @@ class ASTRunner: if GlobalCounters.cache is not None: GlobalCounters.cache.append((self, rawbufs)) return self(rawbufs) - def __call__(self, rawbufs:List[RawBuffer], jit=False) -> Optional[float]: + def __call__(self, rawbufs:List[RawBuffer], jit=False, force_wait=False) -> Optional[float]: if getenv("OPTLOCAL") and self.global_size is not None and self.local_size is None: self.local_size = self.optimize_local_size(rawbufs, allow_cache=(getenv("OPTLOCAL") >= 2)) - if et := self.clprg(self.global_size, self.local_size, *rawbufs, wait=DEBUG>=2): GlobalCounters.time_sum_s += et + if et := self.clprg(self.global_size, self.local_size, *rawbufs, wait=force_wait or DEBUG>=2): GlobalCounters.time_sum_s += et if DEBUG >= 2: print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'magenta' if jit else None)} {(self.display_name+' '*(26-len(self.name))) if self.display_name is not None else self.name:26s} arg {len(rawbufs):3d} sz {str(self.global_size):18s} {str(self.local_size):12s} OPs {self.op_estimate/1e6:7.1f}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " + (str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({self.op_estimate/(et*1e9):8.2f} GFLOPS, {self.mem_estimate/(et*1e9):6.2f} GB/s)")) @@ -152,10 +153,12 @@ class Compiled: # this is the default now if getenv("ENABLE_METHOD_CACHE", 1): if k.key not in self.method_cache: self.method_cache[k.key] = k.codegen().build(self.runtime) - elif DEBUG >= 4: print(f"method cache hit : {k.key}") + elif DEBUG >= 5: print(f"method cache hit : {k.key}") prg = self.method_cache[k.key] else: prg = k.codegen().build(self.runtime) + if prg.name == getenv("PRINT_PRG", ''): print(prg.prg) + prg.exec(k.bufs) return output.realized