# tinygrab/tinygrad/device.py
from __future__ import annotations
import numpy as np
from collections import defaultdict
from typing import TYPE_CHECKING, Union, Any, List, Optional, Dict, Callable
import importlib, inspect, functools, pathlib, time, re, ctypes
from tinygrad.helpers import ansilen, DEBUG, getenv, GlobalCounters, colored, BEAM, NOOPT, all_int, to_function_name, DType, from_mv, dtypes, flat_mv, ImageDType, round_up
from tinygrad.shape.symbolic import Variable, sym_infer, sint
from tinygrad.ops import LazyOp, TernaryOps, get_lazyop_info, ReduceOps, BufferOps, BinaryOps, UnaryOps, Op
if TYPE_CHECKING:
  from tinygrad.codegen.linearizer import Linearizer
  from tinygrad.codegen.kernel import LinearizerOptions

# **************** Device ****************

class _Device:
  def __init__(self) -> None: self._buffers: List[str] = [x.stem[len("ops_"):].upper() for x in (pathlib.Path(__file__).parent/"runtime").iterdir() if x.stem.startswith("ops_")]
  def canonicalize(self, device:Optional[str]) -> str: return (device.split(":", 1)[0].upper() + ((":"+device.split(":", 1)[1]) if ':' in device else '')).replace(":0", "") if device is not None else self.DEFAULT
  @functools.lru_cache(maxsize=None)  # this class is a singleton, pylint: disable=method-cache-max-size-none
  def __getitem__(self, ix:str) -> Union[Interpreted, Compiled]:
    x = ix.split(":")[0].upper()
    ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "device") and x in self._buffers][0]
    if isinstance(ret, type): ret = ret(ix)
    return ret
  @functools.cached_property
  def DEFAULT(self) -> str:
    device_from_env: Optional[str] = functools.reduce(lambda val, ele: ele if getenv(ele) == 1 else val, self._buffers, None)  # type: ignore
    if device_from_env: return device_from_env
    for device in ["METAL", "CUDA", "GPU"]:
      try:
        if self[device]: return device
      except Exception: pass
    return "CPU"
Device = _Device()
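
# Usage sketch (assuming at least one backend imports cleanly on this machine):
#   Device.canonicalize("gpu:0")  # -> "GPU" (upper-cased, trailing ":0" dropped)
#   Device[Device.DEFAULT]        # -> the cached Interpreted/Compiled instance for the default device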

# **************** base Runner + helpers ****************

class JITRunner:
  def __init__(self):
    self.op_estimate, self.mem_estimate = 0, 0
  def exec(self, rawbufs:List[Buffer], var_vals:Optional[Dict[Variable, int]]=None) -> Optional[float]:
    var_vals = var_vals if var_vals is not None else {}
    from tinygrad.jit import CacheCollector
    et = self(rawbufs, var_vals)
    CacheCollector.add(self, rawbufs, var_vals)
    return et
  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
    raise NotImplementedError("override this")
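
# Subclass sketch (hypothetical): concrete runners only override __call__; exec()
# wraps it so every launch is also recorded by the JIT CacheCollector.
#   class MyRunner(JITRunner):
#     def __call__(self, rawbufs, var_vals, wait=False, jit=False):
#       ...  # launch the work, then return elapsed seconds (or None if untimed)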

def update_stats(name:str, op_estimate:sint, mem_estimate:sint, var_vals: Optional[Dict[Variable, int]], et: Optional[float], buf_count, jit=False, num_kernels=1, lra: Optional[Dict]=None):
  if var_vals is None: var_vals = {}
  op_estimate, mem_estimate = sym_infer(op_estimate, var_vals), sym_infer(mem_estimate, var_vals)
  GlobalCounters.kernel_count += num_kernels
  GlobalCounters.global_ops += op_estimate
  GlobalCounters.global_mem += mem_estimate
  if et is not None: GlobalCounters.time_sum_s += et
  if DEBUG >= 2:
    print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', ('magenta' if num_kernels == 1 else 'CYAN') if jit else None)} {name+' '*(37-ansilen(name))} arg {buf_count:3d} sz {str(lra.get('global_size', '') if lra else ''):18s} {str(lra.get('local_size', '') if lra else ''):12s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
          (str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
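
# At DEBUG=2 the line above prints, per launch: the kernel counter, kernel name,
# arg count, global/local launch sizes, this kernel's MFLOPs plus cumulative GFLOPs,
# memory in use, and (when timed) time, GFLOPS, and GB/s.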

# **************** Buffer / Allocator ****************

class Buffer:
  def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None):
    assert isinstance(dtype, DType)
    self.device, self.size, self.dtype = device, size, dtype
    self.allocator = Device[self.device].allocator
    # TODO: image hack shouldn't be here. where should it be?
    if isinstance(dtype, ImageDType) and hasattr(self.allocator, "_cast_image"):
      assert opaque is None
      row_pitch_items = round_up(dtype.shape[1], 256) * 4
      self.size = row_pitch_items * dtype.shape[0]  # adjust the size to include the image padding
      self._real_buf = self.allocator.alloc(self.size * dtype.itemsize)
      self._buf = self.allocator._cast_image(self._real_buf, dtype, row_pitch_items * dtype.itemsize)
    else:
      self._buf = opaque if opaque is not None else self.allocator.alloc(size * dtype.itemsize)
    # TODO: mem_used for all devices
    if self.device == Device.DEFAULT: GlobalCounters.mem_used += self.size * self.dtype.itemsize
  def __del__(self):
    if self.device == Device.DEFAULT: GlobalCounters.mem_used -= self.size * self.dtype.itemsize
    if isinstance(self.dtype, ImageDType):
      self.allocator._free(self._buf)
      self.allocator.free(self._real_buf, self.size * self.dtype.itemsize)
    else:
      self.allocator.free(self._buf, self.size * self.dtype.itemsize)
  def __repr__(self): return f"<buf device:{self.device} size:{self.size}>"
  def copyin(self, mv:memoryview):
    mv = flat_mv(mv)
    assert len(mv) == self.size*self.dtype.itemsize, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
    self.allocator.copyin(self._buf, mv)
    return self
  @staticmethod
  def fromCPU(device:str, x:np.ndarray): return Buffer(device, x.size, dtypes.from_np(x.dtype)).copyin(x.data)
  def toCPU(self) -> np.ndarray:
    # zero copy with as_buffer
    if hasattr(self.allocator, 'as_buffer'): return np.frombuffer(self.allocator.as_buffer(self._buf), dtype=np.dtype(self.dtype.np, metadata={"backing": self._buf}))
    ret = np.empty(self.size, self.dtype.np)
    if self.size > 0: self.allocator.copyout(flat_mv(ret.data), self._buf)
    return ret
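
# Round-trip sketch (assuming the default device is usable):
#   a = Buffer.fromCPU(Device.DEFAULT, np.arange(4, dtype=np.float32))
#   assert (a.toCPU() == np.arange(4, dtype=np.float32)).all()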

class _BufferCopy(JITRunner):
  # TODO: make wait work
  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False):
    dest, src = rawbufs
    assert dest.size == src.size and dest.dtype == src.dtype, "buffer copy size/dtype mismatch"
    if DEBUG >= 2: print(f"*** copy {dest.device} <- {src.device} size {dest.size:<16d} dtype {dest.dtype}")
    if hasattr(dest.allocator, 'transfer') and type(dest.allocator) is type(src.allocator):
      # fast path, used on HIP between GPUs
      dest.allocator.transfer(dest._buf, src._buf, dest.size*dest.dtype.itemsize)
      return
    if getenv("FROM_BUFFER") and hasattr(dest.allocator, 'from_buffer') and hasattr(dest.allocator, 'transfer') and hasattr(src.allocator, 'as_buffer'):
      # fast path, used on Metal in OS X Sonoma
      # NOTE: this is *only* faster if the pages from disk are already loaded into memory
      fb = dest.allocator.from_buffer(src.allocator.as_buffer(src._buf))
      if fb:
        dest.allocator.transfer(dest._buf, fb, dest.size*dest.dtype.itemsize)
        return
    if hasattr(dest.allocator, 'as_buffer'):
      # fast(ish) path, uses readinto in diskbuffers
      src.allocator.copyout(dest.allocator.as_buffer(dest._buf), src._buf)
    elif hasattr(src.allocator, 'as_buffer'):
      dest.allocator.copyin(dest._buf, src.allocator.as_buffer(src._buf))
    else:
      # slow path, allocates a CPU buffer
      dest.copyin(src.toCPU().data)
BufferCopy = _BufferCopy()
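
# BufferCopy picks the fastest route available at call time (direct transfer,
# buffer wrapping, or a bounce through host memory), e.g.:
#   BufferCopy([dst, src], {})  # dst, src: Buffers with matching size and dtype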

# TODO: size, dest, src are the same type. can we enforce this?
class Allocator:
  def alloc(self, size:int):
    assert size > 0, f"alloc size must be positive, got {size}"
    return self._alloc(size)
  def _alloc(self, size:int): raise NotImplementedError("need alloc")
  def free(self, opaque, size:int): self._free(opaque)  # if you are returning a Python object, you don't need a free
  def _free(self, opaque): pass
  def copyin(self, dest, src:memoryview): raise NotImplementedError("need copyin")
  def copyout(self, dest:memoryview, src): raise NotImplementedError("need copyout")
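
# Backend sketch (hypothetical): a device allocator overrides the underscored
# hooks; `free` can stay inherited since the base class forwards it to _free.
#   class MyDeviceAllocator(Allocator):
#     def _alloc(self, size): ...        # return an opaque handle to device memory
#     def _free(self, opaque): ...       # release it (no-op if the handle is GC'd)
#     def copyin(self, dest, src): ...   # host memoryview -> device handle
#     def copyout(self, dest, src): ...  # device handle -> host memoryview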

class LRUAllocator(Allocator):  # pylint: disable=abstract-method
  def __init__(self): self.cache: Dict[int, Any] = defaultdict(list)
  def alloc(self, size:int):
    if len(c := self.cache[size]): return c.pop()
    try:
      return super().alloc(size)
    except MemoryError:
      self.free_cache()
      return super().alloc(size)
  def free_cache(self):
    for opaques in self.cache.values():
      for opaque in opaques: self._free(opaque)
      opaques.clear()
  def free(self, opaque:Any, size:int):
    if getenv("LRU", 1): self.cache[size].append(opaque)
    else: self._free(opaque)

class _MallocAllocator(LRUAllocator):
  def _alloc(self, size:int): return (ctypes.c_uint8 * size)()
  def as_buffer(self, src) -> memoryview: return flat_mv(memoryview(src))
  def copyin(self, dest, src:memoryview): ctypes.memmove(dest, from_mv(src), len(src))
  def copyout(self, dest:memoryview, src): ctypes.memmove(from_mv(dest), src, len(dest))
MallocAllocator = _MallocAllocator()
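
# Reuse sketch: with LRU=1 (the default) a freed block lands in the per-size
# cache and the next same-size alloc hands back the identical object.
#   buf = MallocAllocator.alloc(16)
#   MallocAllocator.free(buf, 16)
#   assert MallocAllocator.alloc(16) is buf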

# **************** for Interpreted Devices ****************

class InterpretedASTRunner(JITRunner):
  def __init__(self, ast:LazyOp, fxn:Callable):
    super().__init__()
    self.fxn = fxn
    info = get_lazyop_info(ast)
    self.op_estimate, self.mem_estimate = info.flops, info.mem_estimate
  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> float:
    st = time.perf_counter()
    rawbufs[0]._buf = self.fxn([x._buf for x in rawbufs], var_vals)
    et = time.perf_counter() - st
    update_stats(f"<interpreted {rawbufs[0].size}>", self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit)
    return et

class Interpreted:
  def __init__(self, allocator: Allocator, fxn_for_op:Dict[Op, Callable]):
    self.allocator, self.fxn_for_op = allocator, fxn_for_op
    self.synchronize, self.codegen, self.graph = lambda: None, None, None
  @functools.lru_cache(None)  # pylint: disable=method-cache-max-size-none
  def get_runner(self, ast:LazyOp) -> InterpretedASTRunner: return _get_interpreted_fxn(self.fxn_for_op, ast)
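
# An Interpreted backend is fully described by an allocator plus an Op -> callable
# table; the numpy-based CPU backend, for example, maps each Op to a numpy function.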

def _get_interpreted_fxn(fxn_for_op:Dict[Op, Callable], ast:LazyOp) -> InterpretedASTRunner:
  if DEBUG >= 3:
    from tinygrad.graph import print_tree
    print_tree(ast)
  tglob: Dict[str, Any] = {"Variable": Variable}
  @functools.lru_cache(None)
  def gstr(x:Any, nm=None) -> str:
    if ('Variable' in (str_arg := repr(x)) or 'NumNode' in str_arg):
      str_arg = re.sub(r'Variable\(.*?\)', lambda m: f'var_vals[{str(m.group(0))}]', str_arg)
      # TODO: (Variable - Variable) might create NumNode. can we remove it?
      return re.sub(r'NumNode\((.*?)\)', r'\1', str_arg)
    ret = str(nm).replace(".", "_") if nm else f"m{len(tglob):04d}"
    tglob[ret] = x
    return ret
  lines: List[str] = []
  @functools.lru_cache(None)
  def _interpret_ast(ast:LazyOp) -> str:
    # TODO: shortcutted store won't work with strides
    if ast.op == BufferOps.STORE: return _interpret_ast(ast.src[0])
    if TernaryOps.MULACC in fxn_for_op and ast.op == ReduceOps.SUM and isinstance(ast.src[0], LazyOp) and ast.src[0].op == BinaryOps.MUL:
      ast = LazyOp(TernaryOps.MULACC, ast.src[0].src, ast.arg)
    if ast.op in BufferOps:
      if ast.op == BufferOps.CONST: tmp = f"{gstr(fxn_for_op[ast.op], ast.op)}({gstr(ast.arg.val)}, {gstr(ast.arg.dtype)})"
      else: tmp = f"{gstr(fxn_for_op[UnaryOps.CAST], UnaryOps.CAST)}(inputs[{ast.arg.idx}], ({gstr(ast.arg.dtype)}, True))"
      for mop,arg in ast.arg.st.to_movement_ops(): tmp = f"{gstr(fxn_for_op[mop], mop)}({tmp}, {gstr(arg)})"
    else:
      tmp = f"{gstr(fxn_for_op[ast.op], ast.op)}({', '.join([_interpret_ast(src) for src in ast.src] + ([gstr(ast.arg)] if ast.arg else []))})"
    ret = f"a{len(lines)}"
    lines.append(f" {ret} = {tmp}")
    return ret
  ret = _interpret_ast(ast)
  src = '\n'.join(['def run(inputs, var_vals):'] + lines + [f" return {ret}"])
  if DEBUG >= 4: print(functools.reduce(lambda x,y: (x.replace(y[0], str(y[1])) if y[0][0:2] == "m0" else x), tglob.items(), src))
  exec(compile(src, "<ast>", "exec"), tglob)  # pylint: disable=exec-used
  return InterpretedASTRunner(ast, tglob['run'])
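
# The generated Python source (printable with DEBUG=4) has roughly this shape,
# one assignment per AST node:
#   def run(inputs, var_vals):
#    a0 = <fxn>(inputs[0], (<dtype>, True))
#    ...
#    return aN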

# **************** for Compiled Devices ****************

class CompiledASTRunner(JITRunner):
  def __init__(self, ast:Optional[LazyOp], name:str, prg:str, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, runtime_args:Optional[dict]=None):
    super().__init__()
    if DEBUG >= 4: print(prg)
    if global_size is not None: global_size = global_size + [1]*(3-len(global_size))
    if local_size is not None: local_size = local_size + [1]*(3-len(local_size))
    self.name, self.display_name, self.prg, self.global_size, self.local_size, self.runtime_args = \
      to_function_name(name), name, prg, global_size, local_size, runtime_args if runtime_args is not None else {}
    self.vars: List[Variable] = []
    if ast:
      info = get_lazyop_info(ast)
      self.op_estimate, self.mem_estimate = info.flops, info.mem_estimate
      from tinygrad.lazy import vars_from_ast
      self.vars = vars_from_ast(ast)
      assert all(v._val is None for v in self.vars), f"ASTRunner contains bound Variable {self.vars}"
  def build(self, compiler, runtime):
    self.lib = compiler.__wrapped__(self.prg) if getenv("DISABLE_COMPILER_CACHE") else compiler(self.prg)
    self.clprg = runtime(self.name, self.lib)
    return self
  def launch_dims(self, var_vals):
    global_size = [sym_infer(sz, var_vals) for sz in self.global_size] if self.global_size is not None else self.global_size
    local_size = [sym_infer(sz, var_vals) for sz in self.local_size] if self.local_size is not None else self.local_size
    return global_size, local_size
  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
    global_size, local_size = self.launch_dims(var_vals)
    if global_size is not None and local_size is None and all_int(self.global_size):  # type: ignore[arg-type]
      # TODO: this is copied from get_program
      from tinygrad.features.search import optimize_local_size
      local_size = self.local_size = optimize_local_size(self.clprg, global_size, rawbufs)
      global_size = self.global_size = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)]
    lra = self.runtime_args.copy()
    if global_size: lra['global_size'] = global_size
    if local_size and 'local_size' not in lra: lra['local_size'] = local_size
    et = self.clprg(*[x._buf for x in rawbufs], *[var_vals[k] for k in self.vars], **lra, wait=wait or DEBUG>=2)
    update_stats(self.display_name, self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit, lra=lra)
    return et

class Compiled:
  def __init__(self, allocator:Allocator, linearizer_opts:LinearizerOptions, renderer, compiler, runtime, graph=None):
    self.allocator, self.linearizer_opts, self.renderer, self.compiler, self.runtime, self.graph = allocator, linearizer_opts, renderer, compiler, runtime, graph
  def synchronize(self): pass  # override this in your device
  def to_program(self, k:Linearizer) -> CompiledASTRunner:
    k.linearize()
    src, runtime_args = self.renderer(to_function_name(k.name), k.uops)
    return CompiledASTRunner(k.ast, k.name, src, k.global_size, k.local_size, runtime_args).build(self.compiler, self.runtime)
  def get_linearizer(self, ast:LazyOp) -> Linearizer:
    if DEBUG >= 3:
      from tinygrad.graph import print_tree
      print_tree(ast)
    from tinygrad.codegen.linearizer import Linearizer
    k = Linearizer(ast, self.linearizer_opts)
    if not NOOPT:
      if not (used_tensor_cores:=k.apply_tensor_cores(getenv("TC", 1))): k.hand_coded_optimizations()
      if BEAM >= 1:
        lins = [(("tc" if used_tensor_cores else "hc"), k)]
        if used_tensor_cores:
          lins.append(("hc", Linearizer(ast, self.linearizer_opts)))
          lins[-1][1].hand_coded_optimizations()
        kb = Linearizer(ast, self.linearizer_opts)
        from tinygrad.features.search import beam_search, time_linearizer, bufs_from_lin
        # TODO: this shouldn't use Device.DEFAULT, it should get the device from the LinearizerOptions
        test_rawbuffers = bufs_from_lin(kb)  # allocate scratch buffers for optimization
        lins.append((f"beam{BEAM.value}", beam_search(kb, test_rawbuffers, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))))
        timed = sorted([(nm, tk, time_linearizer(tk, test_rawbuffers, allow_test_size=False, clear_l2=True)) for nm, tk in lins], key=lambda x: x[2])
        if DEBUG >= 1: print(" < ".join(f"{nm:6s} : {lin.colored_shape(30, dense=True)} : {tm*1e6:8.2f} us" for nm, lin, tm in timed))
        k = timed[0][1]
    return k
  @functools.lru_cache(None)  # pylint: disable=method-cache-max-size-none
  def get_runner(self, ast:LazyOp) -> CompiledASTRunner: return self.to_program(self.get_linearizer(ast))
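
# End-to-end sketch for a Compiled backend: get_runner is lru_cached per AST, so
# each unique kernel is linearized, rendered, and compiled once, then relaunched.
#   runner = Device[Device.DEFAULT].get_runner(ast)  # ast: a LazyOp (Compiled device assumed)
#   runner.exec(rawbufs, var_vals)                   # launch + record for the JIT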