1
0
Fork 0

back to 6.54GB for stable diffusion (#2288)

* back to 6.54GB for stable diffusion

* cleanups

* only outputs, not inputs

* err, restore hack for world
pull/2285/head
George Hotz 2023-11-13 16:50:04 -08:00 committed by GitHub
parent 960535dfb8
commit 6960bcded0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 26 additions and 7 deletions

View File

@ -6,4 +6,5 @@ check_untyped_defs = True
explicit_package_bases = True explicit_package_bases = True
warn_unreachable = True warn_unreachable = True
warn_redundant_casts = True warn_redundant_casts = True
warn_unused_ignores = True # NOTE: had to comment this out to make mypy pass on both CI and OSX
#warn_unused_ignores = True

View File

@ -1,3 +1,4 @@
from __future__ import annotations
from typing import Callable, List, Tuple, Any, Dict, cast, Union, Optional from typing import Callable, List, Tuple, Any, Dict, cast, Union, Optional
import functools, itertools import functools, itertools
from tinygrad.helpers import DEBUG, DType, merge_dicts from tinygrad.helpers import DEBUG, DType, merge_dicts
@ -6,6 +7,7 @@ from tinygrad.tensor import Tensor
from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import Variable from tinygrad.shape.symbolic import Variable
from dataclasses import dataclass from dataclasses import dataclass
from weakref import ref, WeakKeyDictionary
JIT_SUPPORTED_DEVICE = ["GPU", "CLANG", "METAL", "CUDA", "HIP", "WEBGPU", "LLVM"] JIT_SUPPORTED_DEVICE = ["GPU", "CLANG", "METAL", "CUDA", "HIP", "WEBGPU", "LLVM"]
@ -66,19 +68,35 @@ class TinyJit:
self.cnt += 1 self.cnt += 1
return self.ret return self.ret
class PlaceHolder:
  """Stand-in for a buffer recorded while a JIT cache is being collected.

  Keeps only a weak reference to the buffer, plus the fields needed to build an
  equivalent one later (size, dtype, optional device and the concrete buffer type).
  Equality and hashing use those fields together with id() of the buffer's `_buf`,
  so placeholders that agree on all of them share one entry in the dict handed to
  alloc_if_needed.
  """
  def __init__(self, buf:RawBuffer):
    # `_device` is optional on the buffer; None means "do not pass a device when re-creating"
    self.size, self.dtype, self._device = buf.size, buf.dtype, getattr(buf, '_device', None)
    # weak reference only: the placeholder itself never keeps the buffer alive
    self.ref, self.buftype, self.bufid = ref(buf), type(buf), id(buf._buf)

  def to_tuple(self):
    """Return the fields that define this placeholder's identity (the weakref is excluded)."""
    return (self.size, self.dtype, self._device, self.buftype, self.bufid)

  def __hash__(self): return hash(self.to_tuple())

  def __eq__(self, x): return isinstance(x, PlaceHolder) and self.to_tuple() == x.to_tuple()

  def alloc_if_needed(self, buffer_cache: Dict[PlaceHolder, RawBuffer]) -> RawBuffer:
    """Return the original buffer if it is still alive, otherwise a (shared) replacement.

    buffer_cache maps placeholders to replacement buffers created so far; it is
    updated in place when a new replacement has to be constructed.
    """
    live = self.ref()
    # Compare with None explicitly: a dead weakref returns None, while a live buffer
    # may still be falsy (e.g. its type defines __len__/__bool__) and must be reused.
    if live is not None: return live
    if self not in buffer_cache:
      device_kwargs = {'device': self._device} if self._device is not None else {}
      buffer_cache[self] = self.buftype(self.size, self.dtype, **device_kwargs)
    return buffer_cache[self]
class _CacheCollector:
  """Records (program, buffers) pairs between start() and finish().

  While recording, the first buffer of every added entry is registered as a
  PlaceHolder, and later occurrences of that same buffer object are stored as the
  placeholder instead of the buffer. finish() turns the placeholders back into
  buffers and returns the recording as JitItems.
  """
  def __init__(self):
    # None means "not recording"; add() is a no-op in that state
    self.cache: Optional[List[Tuple[ASTRunner, List[Union[RawBuffer, PlaceHolder]]]]] = None

  def start(self, var_vals:Optional[Dict[Variable, int]]=None):
    """Begin a fresh recording, remembering the variable values it is valid for."""
    self.cache = []
    # weak keys: registering a buffer here does not keep it alive
    self.placeholders: WeakKeyDictionary[RawBuffer, PlaceHolder] = WeakKeyDictionary()
    self.var_vals = {} if var_vals is None else var_vals

  def add(self, prg, rawbufs, var_vals):
    """Append one program invocation to the recording (ignored when not recording)."""
    if self.cache is None: return
    for k,v in var_vals.items(): assert k in self.var_vals and self.var_vals[k] == v, f"var_vals {k} mismatch {v} != {self.var_vals.get(k)}"
    # only rawbufs[0] gets a placeholder; other buffers are swapped only if already registered
    first = rawbufs[0]
    self.placeholders[first] = PlaceHolder(first)
    swapped = [self.placeholders.get(buf, buf) if isinstance(buf, RawBuffer) else buf for buf in rawbufs]
    self.cache.append((prg, swapped))

  def finish(self) -> List[JitItem]:
    """Stop recording and return the collected entries with placeholders resolved to buffers."""
    if self.cache is None: return []
    recorded = self.cache
    self.cache = None
    # shared across all entries so equal placeholders resolve to the same replacement buffer
    replacements: Dict[PlaceHolder, RawBuffer] = {}
    jit_items: List[JitItem] = []
    for prg, bufs in recorded:
      resolved = [buf.alloc_if_needed(replacements) if isinstance(buf, PlaceHolder) else buf for buf in bufs]
      jit_items.append(JitItem(prg, resolved))
    return jit_items
CacheCollector = _CacheCollector() CacheCollector = _CacheCollector()

View File

@ -16,7 +16,7 @@ class RawShmBuffer(RawBufferMapped):
fd = _posixshmem.shm_open(device, os.O_RDWR, 0o600) fd = _posixshmem.shm_open(device, os.O_RDWR, 0o600)
# TODO: these flags are somewhat platform specific, but python doesn't expose the ones we need # TODO: these flags are somewhat platform specific, but python doesn't expose the ones we need
shm = mmap.mmap(fd, size * dtype.itemsize, flags=mmap.MAP_SHARED | 0x2000 | 0x008000) shm = mmap.mmap(fd, size * dtype.itemsize, flags=mmap.MAP_SHARED | 0x2000 | 0x008000)
shm.madvise(mmap.MADV_HUGEPAGE) shm.madvise(mmap.MADV_HUGEPAGE) # type: ignore # not on OSX
os.close(fd) os.close(fd)
super().__init__(size, dtype, shm) super().__init__(size, dtype, shm)