back to 6.54GB for stable diffusion (#2288)
* back to 6.54GB for stable diffusion * cleanups * only outputs, not inputs * err, restore hack for worldpull/2285/head
parent
960535dfb8
commit
6960bcded0
3
mypy.ini
3
mypy.ini
|
@ -6,4 +6,5 @@ check_untyped_defs = True
|
|||
explicit_package_bases = True
|
||||
warn_unreachable = True
|
||||
warn_redundant_casts = True
|
||||
warn_unused_ignores = True
|
||||
# NOTE: had to comment this out to make mypy pass on both CI and OSX
|
||||
#warn_unused_ignores = True
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
from __future__ import annotations
|
||||
from typing import Callable, List, Tuple, Any, Dict, cast, Union, Optional
|
||||
import functools, itertools
|
||||
from tinygrad.helpers import DEBUG, DType, merge_dicts
|
||||
|
@ -6,6 +7,7 @@ from tinygrad.tensor import Tensor
|
|||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
from dataclasses import dataclass
|
||||
from weakref import ref, WeakKeyDictionary
|
||||
|
||||
JIT_SUPPORTED_DEVICE = ["GPU", "CLANG", "METAL", "CUDA", "HIP", "WEBGPU", "LLVM"]
|
||||
|
||||
|
@ -66,19 +68,35 @@ class TinyJit:
|
|||
self.cnt += 1
|
||||
return self.ret
|
||||
|
||||
class PlaceHolder:
|
||||
def __init__(self, buf:RawBuffer): self.size, self.dtype, self._device, self.ref, self.buftype, self.bufid = buf.size, buf.dtype, getattr(buf, '_device', None), ref(buf), type(buf), id(buf._buf)
|
||||
def to_tuple(self): return (self.size, self.dtype, self._device, self.buftype, self.bufid)
|
||||
def __hash__(self): return hash(self.to_tuple())
|
||||
def __eq__(self, x): return isinstance(x, PlaceHolder) and self.to_tuple() == x.to_tuple()
|
||||
def alloc_if_needed(self, buffer_cache: Dict[PlaceHolder, RawBuffer]) -> RawBuffer:
|
||||
ret = self.ref()
|
||||
if ret: return ret
|
||||
if self not in buffer_cache: buffer_cache[self] = self.buftype(self.size, self.dtype, **({'device':self._device} if self._device is not None else dict()))
|
||||
return buffer_cache[self]
|
||||
|
||||
class _CacheCollector:
|
||||
def __init__(self):
|
||||
self.cache: Optional[List[JitItem]] = None
|
||||
self.cache: Optional[List[Tuple[ASTRunner, List[Union[RawBuffer, PlaceHolder]]]]] = None
|
||||
|
||||
def start(self, var_vals:Optional[Dict[Variable, int]]=None):
|
||||
self.cache = []
|
||||
self.placeholders: WeakKeyDictionary[RawBuffer, PlaceHolder] = WeakKeyDictionary()
|
||||
self.var_vals = var_vals if var_vals is not None else {}
|
||||
|
||||
def add(self, prg, rawbufs, var_vals):
|
||||
if self.cache is None: return
|
||||
for k,v in var_vals.items(): assert k in self.var_vals and self.var_vals[k] == v, f"var_vals {k} mismatch {v} != {self.var_vals.get(k)}"
|
||||
self.cache.append(JitItem(prg, rawbufs))
|
||||
self.placeholders[rawbufs[0]] = PlaceHolder(rawbufs[0])
|
||||
self.cache.append((prg, [self.placeholders.get(x, x) if isinstance(x, RawBuffer) else x for x in rawbufs]))
|
||||
|
||||
def finish(self) -> List[JitItem]:
|
||||
if self.cache is None: return []
|
||||
ret = self.cache
|
||||
self.cache = None
|
||||
return ret
|
||||
buffer_cache: Dict[PlaceHolder, RawBuffer] = {}
|
||||
saved_cache, self.cache = self.cache, None
|
||||
return [JitItem(prg, [x.alloc_if_needed(buffer_cache) if isinstance(x, PlaceHolder) else x for x in pl]) for prg, pl in saved_cache]
|
||||
CacheCollector = _CacheCollector()
|
||||
|
|
|
@ -16,7 +16,7 @@ class RawShmBuffer(RawBufferMapped):
|
|||
fd = _posixshmem.shm_open(device, os.O_RDWR, 0o600)
|
||||
# TODO: these flags are somewhat platform specific, but python doesn't expose the ones we need
|
||||
shm = mmap.mmap(fd, size * dtype.itemsize, flags=mmap.MAP_SHARED | 0x2000 | 0x008000)
|
||||
shm.madvise(mmap.MADV_HUGEPAGE)
|
||||
shm.madvise(mmap.MADV_HUGEPAGE) # type: ignore # not on OSX
|
||||
os.close(fd)
|
||||
|
||||
super().__init__(size, dtype, shm)
|
||||
|
|
Loading…
Reference in New Issue