1
0
Fork 0

back to 6.54GB for stable diffusion (#2288)

* back to 6.54GB for stable diffusion

* cleanups

* only outputs, not inputs

* err, restore hack for world
pull/2285/head
George Hotz 2023-11-13 16:50:04 -08:00 committed by GitHub
parent 960535dfb8
commit 6960bcded0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 26 additions and 7 deletions

View File

@ -6,4 +6,5 @@ check_untyped_defs = True
explicit_package_bases = True
warn_unreachable = True
warn_redundant_casts = True
warn_unused_ignores = True
# NOTE: had to comment this out to make mypy pass on both CI and OSX
#warn_unused_ignores = True

View File

@ -1,3 +1,4 @@
from __future__ import annotations
from typing import Callable, List, Tuple, Any, Dict, cast, Union, Optional
import functools, itertools
from tinygrad.helpers import DEBUG, DType, merge_dicts
@ -6,6 +7,7 @@ from tinygrad.tensor import Tensor
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import Variable
from dataclasses import dataclass
from weakref import ref, WeakKeyDictionary
JIT_SUPPORTED_DEVICE = ["GPU", "CLANG", "METAL", "CUDA", "HIP", "WEBGPU", "LLVM"]
@ -66,19 +68,35 @@ class TinyJit:
self.cnt += 1
return self.ret
class PlaceHolder:
def __init__(self, buf:RawBuffer): self.size, self.dtype, self._device, self.ref, self.buftype, self.bufid = buf.size, buf.dtype, getattr(buf, '_device', None), ref(buf), type(buf), id(buf._buf)
def to_tuple(self): return (self.size, self.dtype, self._device, self.buftype, self.bufid)
def __hash__(self): return hash(self.to_tuple())
def __eq__(self, x): return isinstance(x, PlaceHolder) and self.to_tuple() == x.to_tuple()
def alloc_if_needed(self, buffer_cache: Dict[PlaceHolder, RawBuffer]) -> RawBuffer:
ret = self.ref()
if ret: return ret
if self not in buffer_cache: buffer_cache[self] = self.buftype(self.size, self.dtype, **({'device':self._device} if self._device is not None else dict()))
return buffer_cache[self]
class _CacheCollector:
def __init__(self):
self.cache: Optional[List[JitItem]] = None
self.cache: Optional[List[Tuple[ASTRunner, List[Union[RawBuffer, PlaceHolder]]]]] = None
def start(self, var_vals:Optional[Dict[Variable, int]]=None):
self.cache = []
self.placeholders: WeakKeyDictionary[RawBuffer, PlaceHolder] = WeakKeyDictionary()
self.var_vals = var_vals if var_vals is not None else {}
def add(self, prg, rawbufs, var_vals):
if self.cache is None: return
for k,v in var_vals.items(): assert k in self.var_vals and self.var_vals[k] == v, f"var_vals {k} mismatch {v} != {self.var_vals.get(k)}"
self.cache.append(JitItem(prg, rawbufs))
self.placeholders[rawbufs[0]] = PlaceHolder(rawbufs[0])
self.cache.append((prg, [self.placeholders.get(x, x) if isinstance(x, RawBuffer) else x for x in rawbufs]))
def finish(self) -> List[JitItem]:
if self.cache is None: return []
ret = self.cache
self.cache = None
return ret
buffer_cache: Dict[PlaceHolder, RawBuffer] = {}
saved_cache, self.cache = self.cache, None
return [JitItem(prg, [x.alloc_if_needed(buffer_cache) if isinstance(x, PlaceHolder) else x for x in pl]) for prg, pl in saved_cache]
CacheCollector = _CacheCollector()

View File

@ -16,7 +16,7 @@ class RawShmBuffer(RawBufferMapped):
fd = _posixshmem.shm_open(device, os.O_RDWR, 0o600)
# TODO: these flags are somewhat platform specific, but python doesn't expose the ones we need
shm = mmap.mmap(fd, size * dtype.itemsize, flags=mmap.MAP_SHARED | 0x2000 | 0x008000)
shm.madvise(mmap.MADV_HUGEPAGE)
shm.madvise(mmap.MADV_HUGEPAGE) # type: ignore # not on OSX
os.close(fd)
super().__init__(size, dtype, shm)