stable diffusion < 324ms (#2129)
* stable diffusion < 324ms
* revert swap action
* fix tests due to more sum splitting
* REDUCEOP_SPLIT_THRESHOLD env var
pull/1878/head^2
parent
cea2bc7964
commit
4444e6d4b3
|
@ -260,17 +260,20 @@ class OptimizedKernel(Kernel):
|
|||
|
||||
# final optional global upcast
|
||||
if s1_exists:
|
||||
s1_div = [upc for upc in [4,3,2,1] if self.full_shape[s1]%upc == 0][0]
|
||||
s1_div = [upc for upc in [5,4,3,2,1] if self.full_shape[s1]%upc == 0][0]
|
||||
if s1_div != 1: fix(self.apply_opt(Opt(OptOps.UPCAST, s1, s1_div)), s1)
|
||||
if s0_exists:
|
||||
s0_div = [upc for upc in [4,3,2,1] if self.full_shape[s0]%upc == 0][0]
|
||||
s0_div = [upc for upc in [5,4,3,2,1] if self.full_shape[s0]%upc == 0][0]
|
||||
if s0_div != 1: fix(self.apply_opt(Opt(OptOps.UPCAST, s0, s0_div)), s0)
|
||||
|
||||
# very late (optional) upcast to run group at the same time. only if actually using real tensor cores, otherwise local isn't a simdgroup
|
||||
self.use_tensor_cores = use_tensor_cores == 1 # TC=2 will do the shape ops without the WMMA
|
||||
if self.use_tensor_cores and s0_exists and self.full_shape[s0] % 2 == 0:
|
||||
self.apply_opt(Opt(OptOps.LASTLOCAL, s0, 2))
|
||||
self.exclude_local_upcast += 1
|
||||
if self.use_tensor_cores and s0_exists:
|
||||
for upc in [4,2]:
|
||||
if self.full_shape[s0] % upc == 0:
|
||||
self.apply_opt(Opt(OptOps.LASTLOCAL, s0, upc))
|
||||
self.exclude_local_upcast += 1
|
||||
break
|
||||
|
||||
# alias buffer
|
||||
alias_pattern = [0]*(self.global_dims+self.exclude_local_upcast) + [2]*(self.local_dims-self.exclude_local_upcast) + [0]*(self.shape_len-self.upcasted-self.first_reduce) + [1,1] + [3]*(self.upcasted-2)
|
||||
|
|
|
@ -4,7 +4,7 @@ from typing import Callable, Optional, Tuple, Union, List, Dict, Any, cast, Mapp
|
|||
from weakref import ref, WeakSet, WeakValueDictionary
|
||||
|
||||
import numpy as np
|
||||
from tinygrad.helpers import prod, getenv, DType, dtypes, flatten, dedup, merge_dicts
|
||||
from tinygrad.helpers import prod, getenv, DType, dtypes, flatten, dedup, merge_dicts, all_int
|
||||
from tinygrad.ops import ScheduleItem, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp, MemBuffer, ConstBuffer, BufferOps
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, get_contraction
|
||||
from tinygrad.shape.symbolic import Variable, sint
|
||||
|
@ -242,7 +242,7 @@ class LazyBuffer:
|
|||
return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), ReduceOps, LazyOp(op, srcs, unbound_new_shape), self.dtype)
|
||||
|
||||
def r(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer:
|
||||
if any(not isinstance(s, int) for s in self.shape) or prod(self.shape) // prod(new_shape) < 32768: return self._reduce_op(op, new_shape) # The amount of work should be big enough to take the benefit of "2 kernels" approach.
|
||||
if not all_int(self.shape) or prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768): return self._reduce_op(op, new_shape) # The amount of work should be big enough to take the benefit of "2 kernels" approach.
|
||||
heuristic, divisor, dim_to_split = max(((divisor := math.gcd(256, old))/(stride or math.inf), divisor, i) for i, (old, new, stride) in enumerate(zip(self.shape, new_shape, self.st.real_strides())) if old != new) # type: ignore
|
||||
if divisor < 16 or heuristic < 0.1: return self._reduce_op(op, new_shape) # Choose largest divisor (>=16) to split on, penalize large strides.
|
||||
def splitted_shape(dim_aft_div): return self.shape[:dim_to_split] + (self.shape[dim_to_split]//divisor,) + dim_aft_div + self.shape[dim_to_split+1:]
|
||||
|
|
|
@ -52,7 +52,7 @@ class MetalBatchExecutor(BasicBatchExecutor):
|
|||
METAL.mtl_buffers_in_flight.append(command_buffer)
|
||||
def exec(self, jit_cache: List[Tuple[Any, Any, Any]], updatable_entries):
|
||||
if self.use_basic_executor: return super().exec(jit_cache, updatable_entries) # No graph is created switch to basic executor.
|
||||
for i in range((len(jit_cache)+7)//8): self.__do_exec(jit_cache[8*i:8*(i+1)]) # Run in batches with size 8.
|
||||
for i in range((len(jit_cache)+127)//128): self.__do_exec(jit_cache[128*i:128*(i+1)]) # Run in batches with size 128.
|
||||
super().recalc_stat(jit_cache)
|
||||
|
||||
def unwrap(x):
|
||||
|
|
Loading…
Reference in New Issue