
stable diffusion < 324ms (#2129)

* stable diffusion < 324ms

* revert swap action

* fix tests due to more sum splitting

* REDUCEOP_SPLIT_THRESHOLD env var
chenyu 2023-10-24 14:56:12 -04:00 committed by GitHub
parent cea2bc7964
commit 4444e6d4b3
3 changed files with 11 additions and 8 deletions
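
The last bullet replaces the hard-coded 32768 split threshold in LazyBuffer.r with an environment knob. A minimal standalone sketch of how such a knob resolves (my own helper below, written to mirror the getenv("REDUCEOP_SPLIT_THRESHOLD", 32768) call visible in the diff; it is not tinygrad's own getenv):

import os

def reduceop_split_threshold(default: int = 32768) -> int:
  # unset -> keep the old 32768 behavior; a larger value makes the two-kernel split rarer
  return int(os.getenv("REDUCEOP_SPLIT_THRESHOLD", default))

print(reduceop_split_threshold())  # 32768 unless the env var overrides it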


@@ -260,17 +260,20 @@ class OptimizedKernel(Kernel):
     # final optional global upcast
     if s1_exists:
-      s1_div = [upc for upc in [4,3,2,1] if self.full_shape[s1]%upc == 0][0]
+      s1_div = [upc for upc in [5,4,3,2,1] if self.full_shape[s1]%upc == 0][0]
       if s1_div != 1: fix(self.apply_opt(Opt(OptOps.UPCAST, s1, s1_div)), s1)
     if s0_exists:
-      s0_div = [upc for upc in [4,3,2,1] if self.full_shape[s0]%upc == 0][0]
+      s0_div = [upc for upc in [5,4,3,2,1] if self.full_shape[s0]%upc == 0][0]
       if s0_div != 1: fix(self.apply_opt(Opt(OptOps.UPCAST, s0, s0_div)), s0)
     # very late (optional) upcast to run group at the same time. only if actually using real tensor cores, otherwise local isn't a simdgroup
     self.use_tensor_cores = use_tensor_cores == 1 # TC=2 will do the shape ops without the WMMA
-    if self.use_tensor_cores and s0_exists and self.full_shape[s0] % 2 == 0:
-      self.apply_opt(Opt(OptOps.LASTLOCAL, s0, 2))
-      self.exclude_local_upcast += 1
+    if self.use_tensor_cores and s0_exists:
+      for upc in [4,2]:
+        if self.full_shape[s0] % upc == 0:
+          self.apply_opt(Opt(OptOps.LASTLOCAL, s0, upc))
+          self.exclude_local_upcast += 1
+          break
     # alias buffer
     alias_pattern = [0]*(self.global_dims+self.exclude_local_upcast) + [2]*(self.local_dims-self.exclude_local_upcast) + [0]*(self.shape_len-self.upcasted-self.first_reduce) + [1,1] + [3]*(self.upcasted-2)
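
The hunk above widens the candidate upcast amounts from [4,3,2,1] to [5,4,3,2,1] and lets the tensor-core LASTLOCAL path try an upcast of 4 before falling back to 2. A small standalone sketch of the divisor pick, with made-up dimension sizes (the real code indexes self.full_shape):

def pick_upcast(dim: int, candidates=(5, 4, 3, 2, 1)) -> int:
  # first (largest) candidate that divides the dimension evenly, falling back to 1
  return [upc for upc in candidates if dim % upc == 0][0]

assert pick_upcast(15) == 5  # the old [4,3,2,1] list would have settled for 3
assert pick_upcast(8) == 4
assert pick_upcast(7) == 1   # prime dimension: no upcast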


@@ -4,7 +4,7 @@ from typing import Callable, Optional, Tuple, Union, List, Dict, Any, cast, Mapp
 from weakref import ref, WeakSet, WeakValueDictionary
 import numpy as np
-from tinygrad.helpers import prod, getenv, DType, dtypes, flatten, dedup, merge_dicts
+from tinygrad.helpers import prod, getenv, DType, dtypes, flatten, dedup, merge_dicts, all_int
 from tinygrad.ops import ScheduleItem, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp, MemBuffer, ConstBuffer, BufferOps
 from tinygrad.shape.shapetracker import ShapeTracker, get_contraction
 from tinygrad.shape.symbolic import Variable, sint
@@ -242,7 +242,7 @@ class LazyBuffer:
     return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), ReduceOps, LazyOp(op, srcs, unbound_new_shape), self.dtype)
   def r(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer:
-    if any(not isinstance(s, int) for s in self.shape) or prod(self.shape) // prod(new_shape) < 32768: return self._reduce_op(op, new_shape) # The amount of work should be big enough to take the benefit of "2 kernels" approach.
+    if not all_int(self.shape) or prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768): return self._reduce_op(op, new_shape) # The amount of work should be big enough to take the benefit of "2 kernels" approach.
     heuristic, divisor, dim_to_split = max(((divisor := math.gcd(256, old))/(stride or math.inf), divisor, i) for i, (old, new, stride) in enumerate(zip(self.shape, new_shape, self.st.real_strides())) if old != new) # type: ignore
     if divisor < 16 or heuristic < 0.1: return self._reduce_op(op, new_shape) # Choose largest divisor (>=16) to split on, penalize large strides.
     def splitted_shape(dim_aft_div): return self.shape[:dim_to_split] + (self.shape[dim_to_split]//divisor,) + dim_aft_div + self.shape[dim_to_split+1:]
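
The two hunks above swap the open-coded isinstance check for the all_int helper and read the split threshold from the environment instead of hard-coding 32768. A rough standalone sketch of the guard, assuming plain int shapes and taking the threshold as an argument (the divisor/stride heuristic that follows in the real code is not reproduced here):

from math import prod

def should_try_split(old_shape, new_shape, threshold=32768) -> bool:
  # symbolic dims or too little work per output -> stay with a single-kernel reduce
  if not all(isinstance(s, int) for s in old_shape): return False
  return prod(old_shape) // prod(new_shape) >= threshold

print(should_try_split((2048, 4096), (1, 1)))  # True: ~8.4M elements collapse to one
print(should_try_split((64, 64), (64, 1)))     # False: only a 64x reduction per output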


@@ -52,7 +52,7 @@ class MetalBatchExecutor(BasicBatchExecutor):
     METAL.mtl_buffers_in_flight.append(command_buffer)
   def exec(self, jit_cache: List[Tuple[Any, Any, Any]], updatable_entries):
     if self.use_basic_executor: return super().exec(jit_cache, updatable_entries) # No graph is created switch to basic executor.
-    for i in range((len(jit_cache)+7)//8): self.__do_exec(jit_cache[8*i:8*(i+1)]) # Run in batches with size 8.
+    for i in range((len(jit_cache)+127)//128): self.__do_exec(jit_cache[128*i:128*(i+1)]) # Run in batches with size 128.
     super().recalc_stat(jit_cache)
 def unwrap(x):
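
The Metal change only bumps the graph batch size: command buffers are now built from chunks of 128 jit_cache entries instead of 8. The ceil-division chunking pattern, shown standalone on a plain list (the items stand in for jit_cache entries):

def batches(items, size=128):
  # (len + size - 1) // size chunks, so a trailing partial batch still runs
  for i in range((len(items) + size - 1) // size):
    yield items[size * i : size * (i + 1)]

assert [len(b) for b in batches(list(range(300)))] == [128, 128, 44]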