
move optimize_local_size (#2221)

* move optimize_local_size

* interpret_ast
George Hotz 2023-11-05 21:00:52 -08:00 committed by GitHub
parent c60c3b467a
commit 2f7aab3d13
16 changed files with 67 additions and 76 deletions

View File

@@ -6,7 +6,7 @@ from models.efficientnet import EfficientNet
from tinygrad.nn.state import get_parameters
from tinygrad.nn import optim
from tinygrad.tensor import Tensor
-from tinygrad.ops import GlobalCounters
+from tinygrad.helpers import GlobalCounters
from tinygrad.helpers import getenv
from tinygrad.jit import CacheCollector

View File

@@ -8,7 +8,7 @@ np.set_printoptions(linewidth=200)
from typing import Optional, Dict
from tinygrad.helpers import Timing, getenv, dtypes, DEBUG
-from tinygrad.ops import GlobalCounters
+from tinygrad.helpers import GlobalCounters
from tinygrad.ops import Device
from tinygrad.tensor import Tensor
from tinygrad.nn import Embedding, Linear

View File

@@ -18,7 +18,7 @@ from tinygrad.nn.state import get_state_dict
from tinygrad.nn import optim
from tinygrad.ops import Device
from tinygrad.tensor import Tensor
-from tinygrad.ops import GlobalCounters
+from tinygrad.helpers import GlobalCounters
from tinygrad.shape.symbolic import Node
from extra.lr_scheduler import OneCycleLR
from tinygrad.jit import TinyJit

View File

@@ -14,7 +14,7 @@ from tinygrad.ops import Device
from tinygrad.tensor import Tensor
from tinygrad.nn import Embedding, Linear
from tinygrad.nn.state import safe_load, torch_load, load_state_dict
-from tinygrad.ops import GlobalCounters
+from tinygrad.helpers import GlobalCounters
from tinygrad.jit import TinyJit, JIT_SUPPORTED_DEVICE
from tinygrad.shape.symbolic import Variable, sym_infer

View File

@@ -3,7 +3,7 @@ import os
import numpy as np
import time, torch, torch.mps
-from tinygrad.ops import GlobalCounters
+from tinygrad.helpers import GlobalCounters
from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit
from tinygrad.ops import Device

View File

@@ -4,7 +4,7 @@ from tinygrad.helpers import prod
from tinygrad.tensor import Tensor
from tinygrad.lazy import LazyBuffer
from tinygrad.runtime.ops_gpu import CLBuffer
-from tinygrad.ops import GlobalCounters
+from tinygrad.helpers import GlobalCounters
def print_objects():
#gc.collect()

View File

@@ -7,7 +7,7 @@ from collections import defaultdict
from typing import Union
from tinygrad.helpers import prod, getenv, DEBUG, dtypes
-from tinygrad.ops import GlobalCounters
+from tinygrad.helpers import GlobalCounters
from tinygrad.tensor import Tensor
from tinygrad.lazy import LazyBuffer
from tinygrad.ops import Device

View File

@@ -17,7 +17,7 @@ import onnx
import numpy as np
import tinygrad.graph as graph
-from tinygrad.ops import GlobalCounters
+from tinygrad.helpers import GlobalCounters
from tinygrad.jit import TinyJit, CacheCollector
import pyopencl as cl

View File

@@ -2,7 +2,7 @@ import unittest
from tinygrad.helpers import prod
from tinygrad.ops import Device
from tinygrad.tensor import Tensor
-from tinygrad.ops import GlobalCounters
+from tinygrad.helpers import GlobalCounters
from tinygrad.jit import CacheCollector
class TestCopy(unittest.TestCase):

View File

@@ -3,7 +3,7 @@ import unittest, gc
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.nn.state import get_state_dict
-from tinygrad.ops import GlobalCounters
+from tinygrad.helpers import GlobalCounters
from tinygrad.runtime.lib import RawBuffer, LRUAllocator
from tinygrad.helpers import dtypes, prod
from tinygrad.ops import Device

View File

@@ -13,7 +13,7 @@ from tinygrad.tensor import Tensor, Device
from tinygrad import nn
from tinygrad.helpers import getenv
from tinygrad.nn import optim
-from tinygrad.ops import GlobalCounters, MovementOps, ReduceOps
+from tinygrad.helpers import GlobalCounters
from tinygrad.lazy import PUSH_PERMUTES
from tinygrad.jit import CacheCollector

View File

@@ -2,7 +2,7 @@
import unittest
import numpy as np
from weakref import ref
-from tinygrad.ops import GlobalCounters
+from tinygrad.helpers import GlobalCounters
from tinygrad.runtime.lib import RawBuffer, LRUAllocator
from tinygrad.helpers import dtypes, prod
from tinygrad.ops import Device

View File

@@ -2,8 +2,7 @@
import unittest
import numpy as np
from tinygrad.tensor import Tensor
-from tinygrad.ops import GlobalCounters, Device
-from tinygrad.graph import nm
+from tinygrad.ops import Device
from tinygrad.helpers import dtypes
N = 200 # has to be bigger than the cache to fail
@@ -50,16 +49,6 @@ class TestAssign(unittest.TestCase):
ba2 = a.lazydata.realized
# NOTE: don't test that it's assigned
#assert ba1 == ba2 and ba1 != bb1
"""
if len(GlobalCounters.cache):
runner, args = GlobalCounters.cache[0]
b0, b1, b2 = args
print(nm(b0), id(b0.cl))
print(nm(b1), id(b1.cl))
print(nm(b2), id(b2.cl))
"""
np.testing.assert_allclose(a.numpy(), np.arange(N*N).reshape((N,N)) + np.arange(N*N).reshape((N,N)).transpose(1,0))
# TODO: is there a way to sneak in a permute such that it returns the wrong answer?

View File

@@ -10,7 +10,7 @@ import time
import numpy as np
np.set_printoptions(linewidth=160)
from tinygrad.ops import Device
-from tinygrad.ops import GlobalCounters
+from tinygrad.helpers import GlobalCounters
from tinygrad.tensor import Tensor
from tinygrad.nn import Conv2d
from tinygrad.helpers import colored, getenv, CI

View File

@@ -1,4 +1,5 @@
-from typing import Dict, List, cast, DefaultDict, Optional, Tuple
+from typing import Dict, List, cast, DefaultDict, Optional, Tuple, Callable
+import itertools, random
from tinygrad.lazy import vars_from_ast
from tinygrad.ops import Device, Compiled, MemBuffer
from tinygrad.helpers import prod, ImageDType, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context
@@ -136,3 +137,15 @@ def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linea
  if CACHELEVEL >= 1: diskcache_put("beam_search", key, beam[0][0].applied_opts)
  if DEBUG >= 3: print(beam[0][0].applied_opts)
  return beam[0][0]
+
+def optimize_local_size(clprg:Callable, global_size:List[int], rawbufs:List[RawBuffer]) -> List[int]:
+  test_rawbuffers = [type(rawbufs[0])(rawbufs[0].size, rawbufs[0].dtype), *rawbufs[1:]] if rawbufs[0] in rawbufs[1:] else rawbufs
+  MAX_WORKGROUP = clprg.max_work_group_size() if hasattr(clprg, 'max_work_group_size') else 1024
+  local_dims = [[x for x in set([sz, 1, 2, 4, 8, 16, 32, 64, 128, 256, MAX_WORKGROUP]) if x<=sz] for sz in global_size]
+  local_sizes = [list(x) for x in itertools.product(*local_dims) if prod(x) <= MAX_WORKGROUP] * 2 # try each valid size twice
+  def try_exec(local_size):
+    try:
+      return clprg(*test_rawbuffers, global_size=[g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)], local_size=local_size, wait=True)
+    except Exception:
+      return float('inf')
+  return min([(try_exec(local_size), local_size) for local_size in random.sample(local_sizes, len(local_sizes))])[1]
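
Editor's note: for a concrete feel of the candidate enumeration optimize_local_size performs, here is a small self-contained sketch (not part of the diff; global_size and MAX_WORKGROUP are made-up values). The real function then launches the kernel once per candidate, in random order, and keeps the local size with the lowest measured time.

import itertools
from math import prod  # stand-in for tinygrad.helpers.prod

global_size, MAX_WORKGROUP = [100, 8], 256  # illustrative values only
# per-dimension candidates capped at the dimension size, then filtered so the
# product of the local dims still fits in one workgroup; *2 retries each candidate
local_dims = [[x for x in set([sz, 1, 2, 4, 8, 16, 32, 64, 128, 256, MAX_WORKGROUP]) if x <= sz] for sz in global_size]
local_sizes = [list(x) for x in itertools.product(*local_dims) if prod(x) <= MAX_WORKGROUP] * 2
print(len(local_sizes), local_sizes[:3])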

View File

@@ -1,5 +1,5 @@
from __future__ import annotations
-import importlib, inspect, functools, pathlib, itertools, random
+import importlib, inspect, functools, pathlib
from enum import Enum, auto
from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, Mapping
from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored, BEAM, NOOPT
@@ -108,48 +108,49 @@ Device = _Device()
# **************** for Interpreted Buffers ****************

-@functools.lru_cache(None)
-def interpret_ast(device:Interpreted, ast:LazyOp) -> Callable:
-  tglob: Dict[str, Any] = {}
-  lines: List[str] = []
-  f = device.fxn_for_op
-  @functools.lru_cache(None)
-  def gstr(x:Any, nm=None) -> str:
-    ret = str(nm).replace(".", "_") if nm else f"m{len(tglob):04d}"
-    tglob[ret] = x
-    return ret
-  @functools.lru_cache(None)
-  def _interpret_ast(ast:LazyOp) -> str:
-    if TernaryOps.MULACC in f and ast.op == ReduceOps.SUM and isinstance(ast.src[0], LazyOp) and ast.src[0].op == BinaryOps.MUL:
-      ast = LazyOp(TernaryOps.MULACC, ast.src[0].src, ast.arg)
-    if MovementOps.AS_STRIDED in f and ast.op in BufferOps:
-      tmp = f"{gstr(f[ast.op], ast.op)}({gstr(ast.arg.val)}, {gstr(ast.arg.dtype)})" if ast.op == BufferOps.CONST else f"{gstr(f[ast.op], ast.op)}(inputs[{ast.arg.idx-1}])"
-      for mop,arg in ast.arg.st.to_movement_ops(): tmp = f"{gstr(f[mop], mop)}({tmp}, {gstr(arg)})"
-    else:
-      inp = [_interpret_ast(src) for src in ast.src]
-      tmp = f"{gstr(f[ast.op], ast.op)}({', '.join(inp + ([gstr(ast.arg)] if ast.arg else []))})"
-    ret = f"a{len(lines)}"
-    lines.append(f" {ret} = {tmp}")
-    return ret
-  ret = _interpret_ast(ast)
-  src = '\n'.join(['def run(inputs):'] + lines + [f" return {gstr(device.from_underlying, 'from_underlying')}({ret})" if device.from_underlying else f" return {ret}"])
-  if DEBUG >= 4: print(functools.reduce(lambda x,y: (x.replace(y[0], str(y[1])) if y[0][0:2] == "m0" else x), tglob.items(), src))
-  exec(compile(src, "<ast>", "exec"), tglob) # pylint: disable=exec-used
-  return tglob['run']

class Interpreted:
  def __init__(self, buffer, fxn_for_op: Dict[Op, Callable], to_underlying=lambda x: x._buf, from_underlying=None):
    self.buffer, self.fxn_for_op, self.to_underlying, self.from_underlying = buffer, fxn_for_op, to_underlying, from_underlying
    self.synchronize = lambda: None
    self.codegen = None
+    self.method_cache: Dict[LazyOp, Callable] = {}

+  def interpret_ast(self:Interpreted, ast:LazyOp) -> Callable:
+    tglob: Dict[str, Any] = {}
+    lines: List[str] = []
+    f = self.fxn_for_op
+    @functools.lru_cache(None)
+    def gstr(x:Any, nm=None) -> str:
+      ret = str(nm).replace(".", "_") if nm else f"m{len(tglob):04d}"
+      tglob[ret] = x
+      return ret
+    @functools.lru_cache(None)
+    def _interpret_ast(ast:LazyOp) -> str:
+      if TernaryOps.MULACC in f and ast.op == ReduceOps.SUM and isinstance(ast.src[0], LazyOp) and ast.src[0].op == BinaryOps.MUL:
+        ast = LazyOp(TernaryOps.MULACC, ast.src[0].src, ast.arg)
+      if MovementOps.AS_STRIDED in f and ast.op in BufferOps:
+        tmp = f"{gstr(f[ast.op], ast.op)}({gstr(ast.arg.val)}, {gstr(ast.arg.dtype)})" if ast.op == BufferOps.CONST else f"{gstr(f[ast.op], ast.op)}(inputs[{ast.arg.idx-1}])"
+        for mop,arg in ast.arg.st.to_movement_ops(): tmp = f"{gstr(f[mop], mop)}({tmp}, {gstr(arg)})"
+      else:
+        inp = [_interpret_ast(src) for src in ast.src]
+        tmp = f"{gstr(f[ast.op], ast.op)}({', '.join(inp + ([gstr(ast.arg)] if ast.arg else []))})"
+      ret = f"a{len(lines)}"
+      lines.append(f" {ret} = {tmp}")
+      return ret
+    ret = _interpret_ast(ast)
+    src = '\n'.join(['def run(inputs):'] + lines + [f" return {gstr(self.from_underlying, 'from_underlying')}({ret})" if self.from_underlying else f" return {ret}"])
+    if DEBUG >= 4: print(functools.reduce(lambda x,y: (x.replace(y[0], str(y[1])) if y[0][0:2] == "m0" else x), tglob.items(), src))
+    exec(compile(src, "<ast>", "exec"), tglob) # pylint: disable=exec-used
+    return tglob['run']

  def exec_ast(self, ast:LazyOp, output=None, inputs=None, var_vals=None, context=None, **kwargs):
-    ret = interpret_ast(self, ast)([x.realized for x in inputs] if inputs else None)
+    if ast not in self.method_cache: self.method_cache[ast] = self.interpret_ast(ast)
+    ret = self.method_cache[ast]([x.realized for x in inputs] if inputs else None)
    if output is not None and ret.dtype != output.dtype and UnaryOps.CAST in self.fxn_for_op:
      ret = self.from_underlying(self.fxn_for_op[UnaryOps.CAST](self.to_underlying(ret), (output.dtype, False))) # Do manual casting of ret if it does not match the required output dtype.
    # TODO: is this used?
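
Aside: the interpret_ast codegen above boils down to the standard build-source-then-exec pattern. Here is a minimal standalone sketch of that pattern with made-up names (m0000/m0001 mimic the gstr-generated globals; numpy stands in for a device's fxn_for_op table), not tinygrad's actual generated output:

import numpy as np

tglob = {"m0000": np.multiply, "m0001": np.sum}  # globals the generated code closes over
src = "\n".join([
  "def run(inputs):",
  "  a0 = m0000(inputs[0], inputs[1])",  # one generated line per AST node
  "  a1 = m0001(a0)",
  "  return a1",
])
exec(compile(src, "<ast>", "exec"), tglob)  # defines tglob['run'], as in interpret_ast
print(tglob["run"]([np.arange(4.0), np.arange(4.0)]))  # 0*0 + 1*1 + 2*2 + 3*3 = 14.0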
@@ -188,19 +189,6 @@ class ASTRunner:
    if DEBUG >= 4: print(prg)
    self.name, self.prg, self.global_size, self.local_size, self.op_estimate, self.mem_estimate, self.display_name, self.runtime_args = name, prg, global_size, local_size, op_estimate, mem_estimate, display_name, runtime_args if runtime_args is not None else {}

-  def optimize_local_size(self, global_size:List[int], rawbufs:List[RawBuffer]) -> List[int]:
-    assert self.global_size is not None, "needs a global size to optimize local size"
-    test_rawbuffers = [type(rawbufs[0])(rawbufs[0].size, rawbufs[0].dtype), *rawbufs[1:]] if rawbufs[0] in rawbufs[1:] else rawbufs
-    MAX_WORKGROUP = self.clprg.max_work_group_size() if hasattr(self.clprg, 'max_work_group_size') else 1024
-    local_dims = [[x for x in set([sz, 1, 2, 4, 8, 16, 32, 64, 128, 256, MAX_WORKGROUP]) if x<=sz] for sz in global_size]
-    local_sizes = [list(x) for x in itertools.product(*local_dims) if prod(x) <= MAX_WORKGROUP] * 2 # try each valid size twice
-    def try_exec(local_size):
-      try:
-        return self.clprg([g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)], local_size, *test_rawbuffers, wait=True)
-      except Exception:
-        return float('inf')
-    return min([(try_exec(local_size), local_size) for local_size in random.sample(local_sizes, len(local_sizes))])[1]

  def build(self, compiler, runtime):
    self.lib = compiler.__wrapped__(self.prg) if getenv("DISABLE_COMPILER_CACHE") else compiler(self.prg)
    self.clprg = runtime(self.name, self.lib)
@@ -221,7 +209,8 @@ class ASTRunner:
    global_size, local_size = self.launch_dims(var_vals)
    if global_size is not None and local_size is None:
      # TODO: this is copied from get_program
-      local_size = self.local_size = self.optimize_local_size(global_size, rawbufs)
+      from tinygrad.features.search import optimize_local_size
+      local_size = self.local_size = optimize_local_size(self.clprg, global_size, rawbufs)
      global_size = self.global_size = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)]
    lra = self.runtime_args.copy()
    if global_size: lra['global_size'] = global_size
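
Quick worked check of the global-size rescaling expression used here and in optimize_local_size (g//l if g%l == 0 else g/l): dimensions that divide evenly stay integers, the rest come out as floats. Illustrative values only:

global_size, local_size = [100, 8, 3], [16, 8, 1]
print([g//l if g % l == 0 else g/l for g, l in zip(global_size, local_size)])  # [6.25, 1, 3]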