
good changes from the M1 Tensor Core project (#730)

* good changes

* working except llvm

* llvm types

* nice acc

* archprobe

* lang.float4

* use self.acc for late acc

* fix store bug
pull/733/head
George Hotz 2023-03-29 05:11:02 +04:00 committed by GitHub
parent 156640e90d
commit 20894991ed
8 changed files with 231 additions and 83 deletions

extra/archprobe.py 100644

@@ -0,0 +1,63 @@
# copying the kernels from https://github.com/microsoft/ArchProbe into Python
import numpy as np
from tinygrad.runtime.ops_gpu import CLProgram, CLBuffer
from tinygrad.helpers import dtypes
from tqdm import trange, tqdm
from matplotlib import pyplot as plt
def reg_count(nthread, ngrp, nreg):
reg_declr = ''.join([f"float reg_data{i} = (float)niter + {i};\n" for i in range(nreg)])
reg_comp = ''.join([f"reg_data{i} *= reg_data{(i-1)%nreg};\n" for i in range(nreg)])
reg_reduce = ''.join([f"out_buf[{i}] = reg_data{i};\n" for i in range(nreg)])
prg = f"""__kernel void reg_count(
__global float* out_buf,
__private const int niter
) {{
{reg_declr}
int i = 0;
for (; i < niter; ++i) {{
{reg_comp}
}}
i = i >> 31;
{reg_reduce}
}}"""
out_buf = CLBuffer(1, dtypes.float32)
cl = CLProgram("reg_count", prg, argdtypes=[None, np.int32])
return min([cl([nthread, ngrp, 1], [nthread, 1, 1], out_buf, 10, wait=True) for _ in range(10)])
"""
print("probing registers")
pts = [(nreg, reg_count(1, 1, nreg)) for nreg in trange(1, 257)] # archprobe goes to 512
plt.plot(*zip(*pts))
plt.show()
"""
def buf_cache_hierarchy_pchase(ndata, stride=1):
NCOMP = 16 # 64 byte is under the 128 byte cache line
print("probe", ndata*NCOMP*4)
prg = """__kernel void buf_cache_hierarchy_pchase(
__global int16* src,
__global int* dst,
const int niter
) {
int idx = 0;
for (int i = 0; i < niter; ++i) {
idx = src[idx].x;
}
*dst = idx;
}"""
idx_buf = np.zeros(ndata*NCOMP, dtype=np.int32)
for i in range(ndata):
idx_buf[i*NCOMP] = (i + stride) % ndata
in_buf = CLBuffer.fromCPU(idx_buf)
out_buf = CLBuffer(1, dtypes.int32)
cl = CLProgram("buf_cache_hierarchy_pchase", prg, argdtypes=[None, None, np.int32])
return min([cl([1, 1, 1], [1, 1, 1], in_buf, out_buf, ndata*4, wait=True) for _ in range(5)])
# 768 kb is real
print("probing cache size")
base = buf_cache_hierarchy_pchase(1, 191)
szs = list(range(128, 1024, 128)) + list(range(1024, 16*1024, 1024)) + list(range(16*1024, int(1.5*1024*1024), 16*1024)) #+ list(range(2*1024*1024, 20*1024*1024, 1024*1024))
pts = [(ndata, (buf_cache_hierarchy_pchase(ndata//64, 191)-base)/ndata) for ndata in tqdm(szs)]
plt.plot(*zip(*pts))
plt.show()
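# Illustration (not part of the original file): the same pointer chase done on the CPU with
# NumPy. Slot i points at (i + stride) % ndata, so every load depends on the previous one and
# cannot be prefetched; per-load latency jumps once ndata*NCOMP*4 bytes stop fitting in a
# cache level, which is what the plot above exposes.
def cpu_pchase_demo(ndata=1024, stride=191, niter=4096):
  idx_buf = np.zeros(ndata, dtype=np.int32)
  for i in range(ndata): idx_buf[i] = (i + stride) % ndata
  idx = 0
  for _ in range(niter): idx = int(idx_buf[idx])  # serialized dependent loads
  return idx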


@@ -0,0 +1,26 @@
import unittest
from tinygrad.helpers import prod
from tinygrad.lazy import Device
from tinygrad.tensor import Tensor
from tinygrad.ops import GlobalCounters
class TestCopy(unittest.TestCase):
def test_add1(self):
pts = []
for i in range(16384, 16384*256, 16384):
t = Tensor.randn(i).realize()
GlobalCounters.cache = []
t.assign(t+1).realize()
fxn, args = GlobalCounters.cache[0]
GlobalCounters.reset()
def run(): return fxn(args, force_wait=True)
ct = min([run() for _ in range(10)])
mb = prod(t.shape)*t.dtype.itemsize*2*1e-6
print(f"{mb*1e3:.2f} kB, {ct*1e3:.2f} ms, {mb/ct:.2f} MB/s")
pts.append((mb, mb/ct))
from matplotlib import pyplot as plt
plt.plot([x[0] for x in pts], [x[1] for x in pts])
plt.show()
if __name__ == '__main__':
unittest.main()
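# Worked example of the bandwidth math above (illustration only): for i = 16384 float32
# elements, the assign reads t and writes the result, so
#   mb = 16384 * 4 bytes * 2 * 1e-6 = 0.131 MB moved per kernel,
# and a hypothetical kernel time of 0.01 ms gives 0.131 MB / 1e-5 s ~= 13107 MB/s (~13.1 GB/s).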


@@ -44,26 +44,31 @@ def helper_test_speed(f1, *args):
global save_ops, save_mem
ets = []
ret = None
for _ in range(CNT):
cache_defeat = np.zeros((2048,2048))
for i in range(CNT):
del ret
args = [(x+1).realize() if isinstance(x, Tensor) else (None if x is None else (x+1)) for x in args] # cache defeats
# operation cache defeats
args = [(x+1).realize() if isinstance(x, Tensor) else (None if x is None else (x+1)) for x in args]
# force syncing
[x.numpy() if isinstance(x, Tensor) or str(torch_device) == "cpu" else x.cpu().numpy() for x in args if x is not None]
# clear 32MB global memory cache (CPU and global memory only)
cache_defeat += 1
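# (illustrative arithmetic, not part of the diff: cache_defeat holds 2048*2048 float64s,
#  i.e. 2048*2048*8 bytes = 32 MiB, so incrementing it once touches 32 MiB and flushes
#  typical last-level caches before the timed section)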
# manual pre sync
if isinstance(args[0], Tensor): Device[args[0].device].synchronize()
else: sync()
GlobalCounters.global_ops = 0
GlobalCounters.global_mem = 0
if DEBUG >= 4: print("benchmark start")
st = time.monotonic()
st = time.perf_counter()
ret = f1(*args)
if isinstance(ret, Tensor):
ret.realize()
Device[ret.device].synchronize()
else:
sync()
et = (time.monotonic() - st) * 1000
ets.append(et)
if DEBUG >= 4: print("benchmark stop")
if isinstance(ret, Tensor): Device[ret.device].synchronize()
else: sync()
et = (time.perf_counter() - st) * 1000
if i >= 1: ets.append(et) # not the first run / one used for OPTLOCAL
if GlobalCounters.global_ops:
save_ops, save_mem = GlobalCounters.global_ops, GlobalCounters.global_mem
return ret.cpu().numpy(), np.min(ets)
@@ -173,6 +178,10 @@ class TestSpeed(unittest.TestCase):
def f(a, b): return a @ b
helper_test_generic_square('gemm', 1024, f, f)
def test_gemm_small(self):
def f(a, b): return a @ b
helper_test_generic_square('gemm', 256, f, f)
def test_gemm_unrolled(self):
N = 512
def f1(a, b): return a@b.T


@@ -1,8 +1,8 @@
from typing import Final, Dict, Callable, ClassVar, List, Optional, NamedTuple, DefaultDict, Tuple, Set, Union
import math, collections
from tinygrad.codegen.linearizer import Linearizer, UOps, UOp, LocalBuffer
from tinygrad.codegen.linearizer import Linearizer, UOps, UOp, LocalBuffer, LocalTypes
from tinygrad.ops import ASTRunner, Op, UnaryOps, BinaryOps, FusedOps
from tinygrad.helpers import getenv, all_same, partition, ImageDType, DEBUG, dtypes, colored
from tinygrad.helpers import getenv, partition, ImageDType, DEBUG, dtypes, colored
from tinygrad.runtime.lib import RawConst
from tinygrad.shape.symbolic import DivNode, AndNode, render_python, NumNode, Variable, Node, SumNode, MulNode
from tinygrad.lazy import LazyBuffer
@@ -57,10 +57,6 @@ code_for_op: Final[Dict[Op, Callable]] = {
}
def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lang:CStyleLanguage) -> Tuple[str, List[int], List[int]]:
def group_float4(grp:List[str]) -> str:
if all(g.endswith(e) for g,e in zip(grp, [".x", ".y", ".z", ".w"])) and all_same([g.split(".")[0] for g in grp]): return grp[0].split(".")[0]
else: return f"{lang.float4}({','.join(g for g in grp)})"
prekernel: Set[str] = set()
kernel = []
global_size = []
@@ -103,7 +99,7 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
else:
kk(f"for (int {var.expr} = {var.min}; {var.expr} <= {var.max}; ++{var.expr}) {{")
depth += 1
if uop == UOps.ENDLOOP:
elif uop == UOps.ENDLOOP:
if args[1] == "local" and len(lang.lid):
# TODO: this is a bit of a hack. the local loop isn't real on the GPU
kk(lang.barrier)
@@ -116,18 +112,19 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
pend_close = None
depth -= 1
kk("}"*len(args[0]) + f" /* {args[1]} */")
if uop == UOps.CONST:
elif uop == UOps.CONST:
assert newvar is not None
if args == -math.inf:
kk(f"float {newvar} = -INFINITY;")
kk(f"{newvar.render(True)} = -INFINITY;")
else:
kk(f"float {newvar} = {args}f;")
if uop == UOps.ALU:
kk(f"{newvar.render(True)} = {args}f;")
elif uop == UOps.ALU:
assert newvar is not None
if newvar in vin:
kk(f"{newvar} = {code_for_op[args](*vin)};")
kk(f"{newvar.render()} = {code_for_op[args](*[x.render() for x in vin])};")
else:
kk(f"float {newvar} = {code_for_op[args](*vin)};")
# TODO: refactor the next 14 lines
if uop == UOps.LOAD:
kk(f"{newvar.render(True)} = {code_for_op[args](*[x.render() for x in vin])};")
elif uop == UOps.LOAD and newvar is not None and newvar.ltype == LocalTypes.float:
# TODO: merge with CONST?
if bufs[args.i] is not None and isinstance(bufs[args.i].realized, RawConst):
# nan? inf?
@@ -138,9 +135,10 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
else:
val = f"{bufnames[args.i]}[{args.idx.render(render_cl)}]"
# NOTE: if min and max are both 0, it should be a CONST in the Linearizer
if args.valid.min == 1: kk(f"float {newvar} = {val};")
else: kk(f"float {newvar} = ({args.valid.render(render_cl)}) ? ({val}) : 0.0f;")
if uop == UOps.LOAD4:
if args.valid.min == 1: kk(f"float {newvar.name} = {val};")
else: kk(f"float {newvar.name} = ({args.valid.render(render_cl)}) ? ({val}) : 0.0f;")
elif uop == UOps.LOAD and newvar is not None and newvar.ltype == LocalTypes.float4:
assert newvar.offset is None, "load can't have an offset"
if isinstance(bufs[args.i].dtype, ImageDType):
prekernel.add("const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n")
idx, idy = to_image_idx(bufs[args.i].dtype.shape, args.idx, args.valid)
@@ -148,23 +146,27 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
else:
val = f"(({lang.smem_prefix if isinstance(bufs[args.i], LocalBuffer) else lang.buffer_prefix}float4*){bufnames[args.i]})[{(args.idx//4).render(render_cl)}]"
# NOTE: if min and max are both 0, it should be a CONST in the Linearizer
if args[2].min == 1: kk(f"float4 {newvar} = {val};")
else: kk(f"float4 {newvar} = ({args.valid.render(render_cl)}) ? ({val}) : {group_float4(['0.0f']*4)};")
if uop == UOps.STORE:
if args[2].min == 1: kk(f"{newvar.render(True)} = {val};")
else: kk(f"{newvar.render(True)} = ({args.valid.render(render_cl)}) ? ({val}) : {lang.float4}(0.0f, 0.0f, 0.0f, 0.0f);")
elif uop == UOps.STORE and (vin[0].ltype == LocalTypes.float or (vin[0].ltype == LocalTypes.float4 and vin[0].offset is not None)):
assert args.valid.min == 1, "store must be valid"
if lang.uses_vload and bufs[args.i].dtype == dtypes.float16:
kk(f"vstore_half({vin[0]}, {args.idx.render(render_cl)}, {bufnames[args.i]});")
kk(f"vstore_half({vin[0].render()}, {args.idx.render(render_cl)}, {bufnames[args.i]});")
else:
kk(f"{bufnames[args.i]}[{args.idx.render(render_cl)}] = {vin[0]};")
if uop == UOps.STORE4:
kk(f"{bufnames[args.i]}[{args.idx.render(render_cl)}] = {vin[0].render()};")
elif uop == UOps.CAST and newvar is not None and newvar.ltype == LocalTypes.float4:
kk(f"{newvar.render(True)} = {lang.float4}({','.join([x.render() for x in vin])});")
elif uop == UOps.STORE and len(vin) != 0 and vin[0].ltype == LocalTypes.float4 and vin[0].offset is None:
assert args.valid.min == 1, "store must be valid"
if isinstance(bufs[args[0]].dtype, ImageDType):
idx, idy = to_image_idx(bufs[args.i].dtype.shape, args[1], args[2])
kk(f"write_imagef({bufnames[args.i]}, (int2)({idx.render(render_cl)}, {idy.render(render_cl)}), {group_float4(vin)});")
kk(f"write_imagef({bufnames[args.i]}, (int2)({idx.render(render_cl)}, {idy.render(render_cl)}), {vin[0].render()});")
else:
kk(f"(({lang.smem_prefix if isinstance(bufs[args.i], LocalBuffer) else lang.buffer_prefix}float4*){bufnames[args.i]})[{(args.idx//4).render(render_cl)}] = {group_float4(vin)};")
if uop == UOps.DEFINE_LOCAL:
kk(f"(({lang.smem_prefix if isinstance(bufs[args.i], LocalBuffer) else lang.buffer_prefix}float4*){bufnames[args.i]})[{(args.idx//4).render(render_cl)}] = {vin[0].render()};")
elif uop == UOps.DEFINE_LOCAL:
kk(lang.smem_prefix + f"float {args[0]}[{args[1]}];")
else:
raise RuntimeError(f"failed to render {uop}")
buftypes = [(i,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if x.dtype.name.startswith('image') else
("const " if i > 0 else "")+lang.buffer_prefix+x.dtype.name+"*"+lang.buffer_suffix) for i,x in enumerate(bufs)


@@ -1,21 +1,36 @@
from typing import List, Tuple, Any, Optional, cast, Dict, DefaultDict, NamedTuple
from typing import List, Tuple, Any, Optional, cast, Dict, DefaultDict, NamedTuple, TypeVar
import itertools, math
from collections import defaultdict
from enum import Enum, auto
from tinygrad.helpers import dedup, colored, ImageDType, DEBUG, prod, dtypes, mnum, DType
from tinygrad.helpers import dedup, colored, ImageDType, DEBUG, prod, dtypes, mnum, DType, all_same
from tinygrad.ops import LazyOp, get_lazyops, get_buffers, FlopCounter, get_lazyop_info, map_buffers, UnaryOps
from tinygrad.lazy import LazyBuffer
from tinygrad.ops import MovementOps, ReduceOps, BinaryOps, FusedOps
from tinygrad.shape.shapetracker import ShapeTracker, strides_for_shape
from tinygrad.shape.symbolic import Variable, SumNode, ModNode
class UOps(Enum): LOOP = auto(); DEFINE_LOCAL = auto(); LOAD = auto(); ALU = auto(); CONST = auto(); ENDLOOP = auto(); STORE = auto(); LOAD4 = auto(); STORE4 = auto() # noqa: E702
class UOps(Enum): LOOP = auto(); DEFINE_LOCAL = auto(); LOAD = auto(); ALU = auto(); CONST = auto(); ENDLOOP = auto(); STORE = auto(); CAST = auto() # noqa: E702
class LocalBuffer(NamedTuple):
dtype: DType = dtypes.float32
realized: None = None
class LocalTypes(Enum): float = auto(); float4 = auto(); half = auto(); half4 = auto(); simdgroup_float8x8 = auto() # noqa: E702
class Token(NamedTuple):
name: str
ltype: LocalTypes
offset: Optional[int] = None
def render(self, with_type=False):
if with_type:
assert self.offset is None
return f"{self.ltype.name} {self.name}"
if self.offset is None: return self.name
assert self.ltype == LocalTypes.float4
return self.name+"."+"xyzw"[int(self.offset)]
def __repr__(self): return f"<{self.name}>" if self.offset is None and self.ltype == LocalTypes.float else f"<{self.name}:{self.ltype.name}:{self.offset}>"
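# Illustrative renders, assuming the Token definition above (not part of the file):
#   Token("acc0", LocalTypes.float).render()       -> "acc0"
#   Token("acc0", LocalTypes.float4).render(True)   -> "float4 acc0"   (declaration form)
#   Token("acc0", LocalTypes.float4, 2).render()    -> "acc0.z"        (offset picks .xyzw)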
class MemOp(NamedTuple):
i: int
idx: Variable
@@ -23,10 +38,10 @@ class MemOp(NamedTuple):
class UOp(NamedTuple):
uop: UOps
out: Optional[str]
vin: List[str]
out: Optional[Token]
vin: List[Token]
arg: Any
def __repr__(self): return f"{str(self.uop):20s}: {self.out if self.out is not None else '':10s} {str(self.vin):32s} {self.arg}"
def __repr__(self): return f"{str(self.uop):20s}: {str(self.out) if self.out is not None else '':25s} {str(self.vin):32s} {self.arg}"
def check_no_mul(test, var):
if test == var: return True
@@ -106,32 +121,58 @@ class Linearizer:
idxy_test, valid_test = self.sts[i].expr_idxs(float4_index+offset, idxs)
# float4_index must not be in after divide or in valid. NOTE: this forces it to always be aligned too, maybe not required?
ret = check_no_mul(idxy_test, float4_index) and "FLOAT4_INDEX" not in (idxy_test//4).render() and "FLOAT4_INDEX" not in (valid_test//4).render()
if DEBUG >= 4: print(f"fuse buf {i} {ret} :", check_no_mul(idxy_test, float4_index), idxy_test, idxy_test//4, valid_test//4)
if DEBUG >= 5: print(f"fuse buf {i} {ret} :", check_no_mul(idxy_test, float4_index), idxy_test, idxy_test//4, valid_test//4)
return ret
def global_buf(self, i, idxs:List[Variable], store=None):
should_upcast = self.supports_float4 and self.can_float4(i) and self.bufs[i].dtype != dtypes.float16
cache: Dict[int, str] = {}
store_offset: Dict[int, int] = {y:x for x,y in enumerate(self.offsets(i))} # NOTE: for stores, these should be unique
# TODO: this is very similar to load
def acc(self, ssa, i, idxs:List[Variable], name='acc') -> List[Token]:
should_upcast = self.supports_float4 and self.can_float4(i)
cache: Dict[int, Token] = {}
def op(offset):
if offset in cache: return cache[offset]
will_merge = should_upcast and self.can_merge_float4(i, idxs, offset)
assert self.reduceop is not None
reg = self.uop(UOps.CONST, ssa(name, LocalTypes.float4 if will_merge else LocalTypes.float), [], {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)])
if will_merge:
for j in range(0, 4): cache[offset+j] = Token(reg.name, LocalTypes.float4, j)
else:
cache[offset] = reg
return cache[offset]
return [op(o) for o in self.offsets(i)]
def global_load(self, i, idxs:List[Variable]) -> List[Token]:
should_upcast = self.supports_float4 and self.can_float4(i) and self.bufs[i].dtype != dtypes.float16
cache: Dict[int, Token] = {}
def op(offset):
if offset in cache: return cache[offset]
will_merge = should_upcast and self.can_merge_float4(i, idxs, offset)
assert will_merge or not isinstance(self.bufs[i].dtype, ImageDType), "image must merge float4"
reg = self.uop(UOps.LOAD, Token(f"val{mnum(i)}_{mnum(offset)}", LocalTypes.float4 if will_merge else LocalTypes.float), [], MemOp(i, *self.sts[i].expr_idxs(offset, idxs)))
if will_merge:
for j in range(0, 4): cache[offset+j] = Token(reg.name, LocalTypes.float4, j)
else:
cache[offset] = reg
return cache[offset]
return [op(o) for o in self.offsets(i)]
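# Illustration (not part of the file): when can_merge_float4 holds for offsets 0..3 of buffer
# index 1, one float4 LOAD named val1_0 serves all four offsets through its components, so the
# cache ends up as {0: <val1_0:float4:0>, 1: <val1_0:float4:1>, 2: <val1_0:float4:2>, 3: <val1_0:float4:3>},
# i.e. val1_0.x/.y/.z/.w, instead of four scalar loads.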
def global_store(self, i, idxs:List[Variable], store:List[Token]) -> None:
should_upcast = self.supports_float4 and self.can_float4(i) and self.bufs[i].dtype != dtypes.float16
store_offset: Dict[int, int] = {y:x for x,y in enumerate(self.offsets(i))} # NOTE: for stores, these should be unique
def op(offset):
if offset not in store_offset: return
will_merge = should_upcast and self.can_merge_float4(i, idxs, offset)
assert will_merge or not isinstance(self.bufs[i].dtype, ImageDType), "image must merge float4"
if store is not None:
offsets = []
for j in range(0, 4 if will_merge else 1):
offsets.append(store[store_offset[offset+j]])
del store_offset[offset+j]
self.uop(UOps.STORE4 if will_merge else UOps.STORE, None, offsets, MemOp(i, *self.sts[i].expr_idxs(offset, idxs)))
else:
reg = self.uop(UOps.LOAD4 if will_merge else UOps.LOAD, f"val{mnum(i)}_{mnum(offset)}", [], MemOp(i, *self.sts[i].expr_idxs(offset, idxs)))
if will_merge:
for j in range(0, 4): cache[offset+j] = reg+"."+"xyzw"[j]
if will_merge:
out_tokens = [store[store_offset[offset+j]] for j in range(4)]
if all_same([x.name for x in out_tokens]) and tuple(range(4)) == tuple(x.offset for x in out_tokens):
var = Token(store[store_offset[offset]].name, LocalTypes.float4)
else:
cache[offset] = reg
return cache[offset]
return [op(o) for o in self.offsets(i)]
var = self.uop(UOps.CAST, Token(store[store_offset[offset]].name+"_f4", LocalTypes.float4), out_tokens)
else:
var = store[store_offset[offset]]
for j in range(0, 4 if will_merge else 1): del store_offset[offset+j]
self.uop(UOps.STORE, None, [var], MemOp(i, *self.sts[i].expr_idxs(offset, idxs)))
for o in self.offsets(i): op(o)
def linearize(self):
# uops
@@ -157,9 +198,9 @@ class Linearizer:
# ssa
_ssa:DefaultDict[str,int] = defaultdict(int)
def ssa(name):
def ssa(name, ltype=LocalTypes.float) -> Token:
_ssa[name] += 1
return f"{name}{_ssa[name]-1}"
return Token(f"{name}{_ssa[name]-1}", ltype)
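# Illustration (not part of the file): each name gets its own counter, so repeated calls yield
# fresh Tokens:
#   ssa('acc')                    -> Token('acc0', LocalTypes.float)
#   ssa('acc', LocalTypes.float4) -> Token('acc1', LocalTypes.float4)
#   ssa('val')                    -> Token('val0', LocalTypes.float)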
# global loop
global_idxs = [Variable(f"gidx{i}", 0, self.full_shape[i]-1 if i < self.first_reduce else 0) for i in range(0, self.first_reduce+len(self.group_for_reduce))]
@@ -178,14 +219,14 @@ class Linearizer:
# reduce op
if self.reduceop is not None:
# define accumulator
acc = [self.uop(UOps.CONST, ssa('acc'), [], {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)]) for _ in self.offsets(0)]
acc = self.acc(ssa, 0, gl_idxs)
# reduce loop
reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+len(self.group_for_reduce), self.shape_len-self.upcasted)]
self.uop(UOps.LOOP, None, [], (reduce_idxs, "reduce"))
# load earlybufs
loaded_buffers.update({b:self.global_buf(i, gl_idxs+reduce_idxs) for i,b in enumerate(self.bufs) if b in self.earlybufs and i != 0})
loaded_buffers.update({b:self.global_load(i, gl_idxs+reduce_idxs) for i,b in enumerate(self.bufs) if b in self.earlybufs and i != 0})
# run early AST (with reduce)
self.ast_parse(self.reduceop, [acc[off] for off in self.acc_offsets(self.full_buf_index)], loaded_buffers, ssa, do_reduce=True)
@@ -195,7 +236,7 @@ class Linearizer:
# end the local loop, do the local reduce
if self.group_for_reduce:
self.global_buf(-1, local_idxs, acc) # store accumulators
self.global_store(-1, local_idxs, acc) # store accumulators
self.uop(UOps.ENDLOOP, None, [], (local_idxs, "local")) # this is a barrier on GPUs
# if any group_for_reduce items aren't reduces, upcast them here
@@ -207,14 +248,14 @@ class Linearizer:
# NOTE: this structure is the same as the reduce op above
# define late accumulator
acc = [self.uop(UOps.CONST, ssa('lacc'), [], {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)]) for _ in self.offsets(-1)]
acc = self.acc(ssa, -1, local_idxs, 'lacc')
# late reduce loop
end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce else 0) for i in range(0, self.first_reduce+len(self.group_for_reduce))]
self.uop(UOps.LOOP, None, [], (end_local_idxs, "late_reduce"))
# load localbufs
loaded_buffers["LOCAL_BUFFER"] = self.global_buf(-1, end_local_idxs)
loaded_buffers["LOCAL_BUFFER"] = self.global_load(-1, end_local_idxs)
# there's no AST here (and there's no shape for the reduce LazyOp)
self.ast_parse(LazyOp(self.reduceop.op, ("LOCAL_BUFFER",)), [acc[off] for off in self.acc_offsets(-1)], loaded_buffers, ssa, do_reduce=True)
@@ -223,29 +264,32 @@ class Linearizer:
self.uop(UOps.ENDLOOP, None, [], (end_local_idxs, "late_reduce"))
# load latebufs
loaded_buffers.update({b:self.global_buf(i, global_idxs) for i,b in enumerate(self.bufs) if b not in self.earlybufs and i != 0 and not isinstance(b, LocalBuffer)})
loaded_buffers.update({b:self.global_load(i, global_idxs) for i,b in enumerate(self.bufs) if b not in self.earlybufs and i != 0 and not isinstance(b, LocalBuffer)})
# run late AST
val = self.ast_parse(self.ast, acc, loaded_buffers, ssa)
# store
self.global_buf(0, global_idxs, val)
self.global_store(0, global_idxs, val)
# end the global loop
self.uop(UOps.ENDLOOP, None, [], (global_idxs, "global"))
def uop(self, uop:UOps, out:Optional[str], vin:List[str], arg:Any):
self.uops.append(UOp(uop, out, vin, arg))
_OT = TypeVar("_OT")
def uop(self, uop:UOps, out:_OT, vin:List[Token], arg:Any=None) -> _OT:
self.uops.append(UOp(uop, cast(Optional[Token], out), vin, arg))
if DEBUG >= 4: print(self.uops[-1])
return out
def ast_parse(self, x, acc, loaded_buffers, ssa, do_reduce=False) -> List[str]:
def ast_parse(self, x, acc, loaded_buffers, ssa, do_reduce=False) -> List[Token]:
if not isinstance(x, LazyOp): return loaded_buffers[x]
if x.op in [UnaryOps.NOOP, UnaryOps.CAST]: return self.ast_parse(x.src[0], acc, loaded_buffers, ssa) # cast isn't an ALU op
if x.op in ReduceOps and not do_reduce: return acc
# MULACC fusion. TODO: this is copied from Interpreted
if x.op == ReduceOps.SUM and isinstance(x.src[0], LazyOp) and x.src[0].op == BinaryOps.MUL:
x = LazyOp(FusedOps.MULACC, x.src[0].src, x.arg)
if x.op == ReduceOps.SUM and isinstance(x.src[0], LazyOp) and x.src[0].op == UnaryOps.CAST and isinstance(x.src[0].src[0], LazyOp) and x.src[0].src[0].op == BinaryOps.MUL:
x = LazyOp(FusedOps.MULACC, x.src[0].src[0].src, x.arg)
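# Illustration of the MULACC rewrites above (a, b are hypothetical source buffers):
#   LazyOp(ReduceOps.SUM, (LazyOp(BinaryOps.MUL, (a, b)),), new_shape)
#     -> LazyOp(FusedOps.MULACC, (a, b), new_shape)
# and, roughly, with an intervening cast:
#   LazyOp(ReduceOps.SUM, (LazyOp(UnaryOps.CAST, (LazyOp(BinaryOps.MUL, (a, b)),)),), new_shape)
#     -> LazyOp(FusedOps.MULACC, (a, b), new_shape)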
values = [self.ast_parse(v, acc, loaded_buffers, ssa) for v in x.src]
if isinstance(x.op, (ReduceOps, FusedOps)):
return [self.uop(UOps.ALU, val[0], list(val), {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, FusedOps.MULACC:FusedOps.MULACC}[x.op]) for val in zip(acc, *values)]


@@ -1,7 +1,7 @@
from typing import Final, Dict, Callable, Any, List, Optional
import functools
from llvmlite import ir # type: ignore
from tinygrad.codegen.linearizer import Linearizer, UOps, UOp
from tinygrad.codegen.linearizer import Linearizer, UOps, UOp, Token
from tinygrad.helpers import dtypes
from tinygrad.ops import Op, ASTRunner, UnaryOps, BinaryOps, FusedOps
from tinygrad.lazy import LazyBuffer
@@ -48,7 +48,7 @@ def uops_to_llvm_ir(uops:List[UOp], bufs:List[LazyBuffer]) -> str:
loop_blocks = []
reduce_phis: List = []
# TODO: newvar probably shouldn't be optional
lvars: Dict[Optional[str], Any] = {} # this Any is an llvm type
lvars: Dict[Optional[Token], Any] = {} # this Any is an llvm type
render_llvm[Variable] = lambda self,ops,ctx: lvars[self.expr]
for uop,newvar,vin,args in uops:


@ -49,8 +49,9 @@ class LazyNumpyArray:
class dtypes:
float16: Final[DType] = DType(0, 2, "half", np.float16)
float32: Final[DType] = DType(1, 4, "float", np.float32)
int32: Final[DType] = DType(1, 4, "int", np.int32)
@staticmethod
def from_np(x) -> DType: return {np.dtype(np.float16): dtypes.float16, np.dtype(np.float32): dtypes.float32}[np.dtype(x)]
def from_np(x) -> DType: return {np.dtype(np.float16): dtypes.float16, np.dtype(np.float32): dtypes.float32, np.dtype(np.int32): dtypes.int32}[np.dtype(x)]
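# Illustration (not part of the file): from_np maps numpy dtypes onto tinygrad DTypes; the new
# int32 entry is what lets CLBuffer.fromCPU accept the int32 index buffer in extra/archprobe.py:
#   dtypes.from_np(np.int32)                       -> dtypes.int32
#   dtypes.from_np(np.zeros(4, np.float16).dtype)  -> dtypes.float16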
class GlobalCounters:
global_ops: ClassVar[int] = 0


@@ -1,5 +1,5 @@
from __future__ import annotations
import functools, itertools, operator, random
import functools, itertools, operator, random, time
from enum import Enum, auto
from typing import Union, Type, NamedTuple, Tuple, Any, List, Optional, Dict, Callable, ClassVar
from tinygrad.helpers import prod, DEBUG, getenv, GlobalCounters, DType, colored
@@ -49,8 +49,9 @@ class Interpreted:
if context is None: context = dict()
if not created_context and ast in context: return context[ast]
srcs = [self.exec_ast(x, context=context) if isinstance(x, LazyOp) else self.from_lazybuffer(x) for x in ast.src]
if DEBUG >= 3: st = time.perf_counter()
ret = self.buffer(self.fxn_for_op[ast.op](*([self.to_underlying(x) for x in srcs] + ([ast.arg] if ast.arg is not None else []))))
if DEBUG >= 3: print(f"*** {'exec' if created_context else ' '} {GlobalCounters.mem_used/1e9:5.2f} GB op: {ast.op:20s} out({ret.dtype.name}): {str(ret._buf.shape):30s} in({len(srcs)}):", list(set(x._buf.shape for x in srcs)), ast.arg if ast.arg is not None else "")
if DEBUG >= 3: print(f"*** {'exec' if created_context else ' '} {GlobalCounters.mem_used/1e9:5.2f} GB {(time.perf_counter()-st)*1e3:7.2f} ms op: {ast.op:20s} out({ret.dtype.name}): {str(ret._buf.shape):30s} in({len(srcs)}):", list(set(x._buf.shape for x in srcs)), ast.arg if ast.arg is not None else "")
if not created_context: context[ast] = ret
if output is not None and output.output_buffer is not None:
assert output.output_buffer.size == ret.size and output.output_buffer.dtype == ret.dtype
@@ -90,9 +91,9 @@ class ASTRunner:
if GlobalCounters.cache is not None: GlobalCounters.cache.append((self, rawbufs))
return self(rawbufs)
def __call__(self, rawbufs:List[RawBuffer], jit=False) -> Optional[float]:
def __call__(self, rawbufs:List[RawBuffer], jit=False, force_wait=False) -> Optional[float]:
if getenv("OPTLOCAL") and self.global_size is not None and self.local_size is None: self.local_size = self.optimize_local_size(rawbufs, allow_cache=(getenv("OPTLOCAL") >= 2))
if et := self.clprg(self.global_size, self.local_size, *rawbufs, wait=DEBUG>=2): GlobalCounters.time_sum_s += et
if et := self.clprg(self.global_size, self.local_size, *rawbufs, wait=force_wait or DEBUG>=2): GlobalCounters.time_sum_s += et
if DEBUG >= 2:
print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'magenta' if jit else None)} {(self.display_name+' '*(26-len(self.name))) if self.display_name is not None else self.name:26s} arg {len(rawbufs):3d} sz {str(self.global_size):18s} {str(self.local_size):12s} OPs {self.op_estimate/1e6:7.1f}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
(str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({self.op_estimate/(et*1e9):8.2f} GFLOPS, {self.mem_estimate/(et*1e9):6.2f} GB/s)"))
@@ -152,10 +153,12 @@ class Compiled:
# this is the default now
if getenv("ENABLE_METHOD_CACHE", 1):
if k.key not in self.method_cache: self.method_cache[k.key] = k.codegen().build(self.runtime)
elif DEBUG >= 4: print(f"method cache hit : {k.key}")
elif DEBUG >= 5: print(f"method cache hit : {k.key}")
prg = self.method_cache[k.key]
else:
prg = k.codegen().build(self.runtime)
if prg.name == getenv("PRINT_PRG", ''): print(prg.prg)
prg.exec(k.bufs)
return output.realized