diff --git a/mypy.ini b/mypy.ini index 070800245..8a838bd23 100644 --- a/mypy.ini +++ b/mypy.ini @@ -6,4 +6,5 @@ check_untyped_defs = True explicit_package_bases = True warn_unreachable = True warn_redundant_casts = True -warn_unused_ignores = True +# NOTE: had to comment this out to make mypy pass on both CI and OSX +#warn_unused_ignores = True diff --git a/tinygrad/jit.py b/tinygrad/jit.py index f0b470d4f..32ae3746e 100644 --- a/tinygrad/jit.py +++ b/tinygrad/jit.py @@ -1,3 +1,4 @@ +from __future__ import annotations from typing import Callable, List, Tuple, Any, Dict, cast, Union, Optional import functools, itertools from tinygrad.helpers import DEBUG, DType, merge_dicts @@ -6,6 +7,7 @@ from tinygrad.tensor import Tensor from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.symbolic import Variable from dataclasses import dataclass +from weakref import ref, WeakKeyDictionary JIT_SUPPORTED_DEVICE = ["GPU", "CLANG", "METAL", "CUDA", "HIP", "WEBGPU", "LLVM"] @@ -66,19 +68,35 @@ class TinyJit: self.cnt += 1 return self.ret +class PlaceHolder: + def __init__(self, buf:RawBuffer): self.size, self.dtype, self._device, self.ref, self.buftype, self.bufid = buf.size, buf.dtype, getattr(buf, '_device', None), ref(buf), type(buf), id(buf._buf) + def to_tuple(self): return (self.size, self.dtype, self._device, self.buftype, self.bufid) + def __hash__(self): return hash(self.to_tuple()) + def __eq__(self, x): return isinstance(x, PlaceHolder) and self.to_tuple() == x.to_tuple() + def alloc_if_needed(self, buffer_cache: Dict[PlaceHolder, RawBuffer]) -> RawBuffer: + ret = self.ref() + if ret: return ret + if self not in buffer_cache: buffer_cache[self] = self.buftype(self.size, self.dtype, **({'device':self._device} if self._device is not None else dict())) + return buffer_cache[self] + class _CacheCollector: def __init__(self): - self.cache: Optional[List[JitItem]] = None + self.cache: Optional[List[Tuple[ASTRunner, List[Union[RawBuffer, PlaceHolder]]]]] = None + def start(self, var_vals:Optional[Dict[Variable, int]]=None): self.cache = [] + self.placeholders: WeakKeyDictionary[RawBuffer, PlaceHolder] = WeakKeyDictionary() self.var_vals = var_vals if var_vals is not None else {} + def add(self, prg, rawbufs, var_vals): if self.cache is None: return for k,v in var_vals.items(): assert k in self.var_vals and self.var_vals[k] == v, f"var_vals {k} mismatch {v} != {self.var_vals.get(k)}" - self.cache.append(JitItem(prg, rawbufs)) + self.placeholders[rawbufs[0]] = PlaceHolder(rawbufs[0]) + self.cache.append((prg, [self.placeholders.get(x, x) if isinstance(x, RawBuffer) else x for x in rawbufs])) + def finish(self) -> List[JitItem]: if self.cache is None: return [] - ret = self.cache - self.cache = None - return ret + buffer_cache: Dict[PlaceHolder, RawBuffer] = {} + saved_cache, self.cache = self.cache, None + return [JitItem(prg, [x.alloc_if_needed(buffer_cache) if isinstance(x, PlaceHolder) else x for x in pl]) for prg, pl in saved_cache] CacheCollector = _CacheCollector() diff --git a/tinygrad/runtime/ops_shm.py b/tinygrad/runtime/ops_shm.py index 9cbf3af5c..0ebdfe904 100644 --- a/tinygrad/runtime/ops_shm.py +++ b/tinygrad/runtime/ops_shm.py @@ -16,7 +16,7 @@ class RawShmBuffer(RawBufferMapped): fd = _posixshmem.shm_open(device, os.O_RDWR, 0o600) # TODO: these flags are somewhat platform specific, but python doesn't expose the ones we need shm = mmap.mmap(fd, size * dtype.itemsize, flags=mmap.MAP_SHARED | 0x2000 | 0x008000) - shm.madvise(mmap.MADV_HUGEPAGE) + shm.madvise(mmap.MADV_HUGEPAGE) # type: ignore # not on OSX os.close(fd) super().__init__(size, dtype, shm)