
move optimize_local_size (#2221)

* move optimize_local_size

* interpret_ast
pull/2225/head
George Hotz 2023-11-05 21:00:52 -08:00 committed by GitHub
parent c60c3b467a
commit 2f7aab3d13
16 changed files with 67 additions and 76 deletions

View File

@@ -6,7 +6,7 @@ from models.efficientnet import EfficientNet
 from tinygrad.nn.state import get_parameters
 from tinygrad.nn import optim
 from tinygrad.tensor import Tensor
-from tinygrad.ops import GlobalCounters
+from tinygrad.helpers import GlobalCounters
 from tinygrad.helpers import getenv
 from tinygrad.jit import CacheCollector

View File

@@ -8,7 +8,7 @@ np.set_printoptions(linewidth=200)
 from typing import Optional, Dict
 from tinygrad.helpers import Timing, getenv, dtypes, DEBUG
-from tinygrad.ops import GlobalCounters
+from tinygrad.helpers import GlobalCounters
 from tinygrad.ops import Device
 from tinygrad.tensor import Tensor
 from tinygrad.nn import Embedding, Linear

View File

@@ -18,7 +18,7 @@ from tinygrad.nn.state import get_state_dict
 from tinygrad.nn import optim
 from tinygrad.ops import Device
 from tinygrad.tensor import Tensor
-from tinygrad.ops import GlobalCounters
+from tinygrad.helpers import GlobalCounters
 from tinygrad.shape.symbolic import Node
 from extra.lr_scheduler import OneCycleLR
 from tinygrad.jit import TinyJit

View File

@@ -14,7 +14,7 @@ from tinygrad.ops import Device
 from tinygrad.tensor import Tensor
 from tinygrad.nn import Embedding, Linear
 from tinygrad.nn.state import safe_load, torch_load, load_state_dict
-from tinygrad.ops import GlobalCounters
+from tinygrad.helpers import GlobalCounters
 from tinygrad.jit import TinyJit, JIT_SUPPORTED_DEVICE
 from tinygrad.shape.symbolic import Variable, sym_infer

View File

@@ -3,7 +3,7 @@ import os
 import numpy as np
 import time, torch, torch.mps
-from tinygrad.ops import GlobalCounters
+from tinygrad.helpers import GlobalCounters
 from tinygrad.tensor import Tensor
 from tinygrad.jit import TinyJit
 from tinygrad.ops import Device

View File

@@ -4,7 +4,7 @@ from tinygrad.helpers import prod
 from tinygrad.tensor import Tensor
 from tinygrad.lazy import LazyBuffer
 from tinygrad.runtime.ops_gpu import CLBuffer
-from tinygrad.ops import GlobalCounters
+from tinygrad.helpers import GlobalCounters
 def print_objects():
   #gc.collect()

View File

@@ -7,7 +7,7 @@ from collections import defaultdict
 from typing import Union
 from tinygrad.helpers import prod, getenv, DEBUG, dtypes
-from tinygrad.ops import GlobalCounters
+from tinygrad.helpers import GlobalCounters
 from tinygrad.tensor import Tensor
 from tinygrad.lazy import LazyBuffer
 from tinygrad.ops import Device

View File

@@ -17,7 +17,7 @@ import onnx
 import numpy as np
 import tinygrad.graph as graph
-from tinygrad.ops import GlobalCounters
+from tinygrad.helpers import GlobalCounters
 from tinygrad.jit import TinyJit, CacheCollector
 import pyopencl as cl

View File

@@ -2,7 +2,7 @@ import unittest
 from tinygrad.helpers import prod
 from tinygrad.ops import Device
 from tinygrad.tensor import Tensor
-from tinygrad.ops import GlobalCounters
+from tinygrad.helpers import GlobalCounters
 from tinygrad.jit import CacheCollector
 class TestCopy(unittest.TestCase):

View File

@@ -3,7 +3,7 @@ import unittest, gc
 import numpy as np
 from tinygrad.tensor import Tensor
 from tinygrad.nn.state import get_state_dict
-from tinygrad.ops import GlobalCounters
+from tinygrad.helpers import GlobalCounters
 from tinygrad.runtime.lib import RawBuffer, LRUAllocator
 from tinygrad.helpers import dtypes, prod
 from tinygrad.ops import Device

View File

@@ -13,7 +13,7 @@ from tinygrad.tensor import Tensor, Device
 from tinygrad import nn
 from tinygrad.helpers import getenv
 from tinygrad.nn import optim
-from tinygrad.ops import GlobalCounters, MovementOps, ReduceOps
+from tinygrad.helpers import GlobalCounters
 from tinygrad.lazy import PUSH_PERMUTES
 from tinygrad.jit import CacheCollector

View File

@@ -2,7 +2,7 @@
 import unittest
 import numpy as np
 from weakref import ref
-from tinygrad.ops import GlobalCounters
+from tinygrad.helpers import GlobalCounters
 from tinygrad.runtime.lib import RawBuffer, LRUAllocator
 from tinygrad.helpers import dtypes, prod
 from tinygrad.ops import Device

View File

@@ -2,8 +2,7 @@
 import unittest
 import numpy as np
 from tinygrad.tensor import Tensor
-from tinygrad.ops import GlobalCounters, Device
-from tinygrad.graph import nm
+from tinygrad.ops import Device
 from tinygrad.helpers import dtypes

 N = 200 # has to be bigger than the cache to fail
@@ -50,16 +49,6 @@ class TestAssign(unittest.TestCase):
     ba2 = a.lazydata.realized
     # NOTE: don't test that it's assigned
     #assert ba1 == ba2 and ba1 != bb1
-    """
-    if len(GlobalCounters.cache):
-      runner, args = GlobalCounters.cache[0]
-      b0, b1, b2 = args
-      print(nm(b0), id(b0.cl))
-      print(nm(b1), id(b1.cl))
-      print(nm(b2), id(b2.cl))
-    """
     np.testing.assert_allclose(a.numpy(), np.arange(N*N).reshape((N,N)) + np.arange(N*N).reshape((N,N)).transpose(1,0))
     # TODO: is there a way to sneak in a permute such that it returns the wrong answer?

View File

@@ -10,7 +10,7 @@ import time
 import numpy as np
 np.set_printoptions(linewidth=160)
 from tinygrad.ops import Device
-from tinygrad.ops import GlobalCounters
+from tinygrad.helpers import GlobalCounters
 from tinygrad.tensor import Tensor
 from tinygrad.nn import Conv2d
 from tinygrad.helpers import colored, getenv, CI

View File

@@ -1,4 +1,5 @@
-from typing import Dict, List, cast, DefaultDict, Optional, Tuple
+from typing import Dict, List, cast, DefaultDict, Optional, Tuple, Callable
+import itertools, random
 from tinygrad.lazy import vars_from_ast
 from tinygrad.ops import Device, Compiled, MemBuffer
 from tinygrad.helpers import prod, ImageDType, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context
@@ -136,3 +137,15 @@ def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linearizer:
   if CACHELEVEL >= 1: diskcache_put("beam_search", key, beam[0][0].applied_opts)
   if DEBUG >= 3: print(beam[0][0].applied_opts)
   return beam[0][0]
+
+def optimize_local_size(clprg:Callable, global_size:List[int], rawbufs:List[RawBuffer]) -> List[int]:
+  test_rawbuffers = [type(rawbufs[0])(rawbufs[0].size, rawbufs[0].dtype), *rawbufs[1:]] if rawbufs[0] in rawbufs[1:] else rawbufs
+  MAX_WORKGROUP = clprg.max_work_group_size() if hasattr(clprg, 'max_work_group_size') else 1024
+  local_dims = [[x for x in set([sz, 1, 2, 4, 8, 16, 32, 64, 128, 256, MAX_WORKGROUP]) if x<=sz] for sz in global_size]
+  local_sizes = [list(x) for x in itertools.product(*local_dims) if prod(x) <= MAX_WORKGROUP] * 2 # try each valid size twice
+  def try_exec(local_size):
+    try:
+      return clprg(*test_rawbuffers, global_size=[g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)], local_size=local_size, wait=True)
+    except Exception:
+      return float('inf')
+  return min([(try_exec(local_size), local_size) for local_size in random.sample(local_sizes, len(local_sizes))])[1]
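The moved function's strategy is unchanged: enumerate candidate local sizes per dimension (powers of two up to the device's max workgroup size, capped at the global size), keep only combinations whose product fits in one workgroup, time each candidate twice in random order (substituting a scratch output buffer when the output also appears among the inputs), and return the fastest. Below is a minimal, self-contained sketch of that search idea using a toy cost function in place of a real compiled kernel; pick_local_size, time_candidate, and the MAX_WORKGROUP value are illustrative, not tinygrad's API.

import itertools, random
from math import prod

MAX_WORKGROUP = 256  # assumed device limit for this sketch

def pick_local_size(global_size, time_candidate):
  # per-dimension candidates: powers of two (and the dimension itself) that still fit
  dims = [[x for x in {sz, 1, 2, 4, 8, 16, 32, 64, 128, 256} if x <= sz] for sz in global_size]
  # keep combinations that fit in one workgroup; duplicate the list so each is timed twice
  cands = [list(c) for c in itertools.product(*dims) if prod(c) <= MAX_WORKGROUP] * 2
  # time candidates in random order and keep the fastest one
  return min((time_candidate(c), c) for c in random.sample(cands, len(cands)))[1]

# toy cost model that just rewards fuller workgroups
print(pick_local_size([1024, 1024], lambda ls: 1.0 / prod(ls)))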

View File

@@ -1,5 +1,5 @@
 from __future__ import annotations
-import importlib, inspect, functools, pathlib, itertools, random
+import importlib, inspect, functools, pathlib
 from enum import Enum, auto
 from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, Mapping
 from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored, BEAM, NOOPT
@@ -108,48 +108,49 @@ Device = _Device()

 # **************** for Interpreted Buffers ****************

-@functools.lru_cache(None)
-def interpret_ast(device:Interpreted, ast:LazyOp) -> Callable:
-  tglob: Dict[str, Any] = {}
-  lines: List[str] = []
-  f = device.fxn_for_op
-
-  @functools.lru_cache(None)
-  def gstr(x:Any, nm=None) -> str:
-    ret = str(nm).replace(".", "_") if nm else f"m{len(tglob):04d}"
-    tglob[ret] = x
-    return ret
-
-  @functools.lru_cache(None)
-  def _interpret_ast(ast:LazyOp) -> str:
-    if TernaryOps.MULACC in f and ast.op == ReduceOps.SUM and isinstance(ast.src[0], LazyOp) and ast.src[0].op == BinaryOps.MUL:
-      ast = LazyOp(TernaryOps.MULACC, ast.src[0].src, ast.arg)
-    if MovementOps.AS_STRIDED in f and ast.op in BufferOps:
-      tmp = f"{gstr(f[ast.op], ast.op)}({gstr(ast.arg.val)}, {gstr(ast.arg.dtype)})" if ast.op == BufferOps.CONST else f"{gstr(f[ast.op], ast.op)}(inputs[{ast.arg.idx-1}])"
-      for mop,arg in ast.arg.st.to_movement_ops(): tmp = f"{gstr(f[mop], mop)}({tmp}, {gstr(arg)})"
-    else:
-      inp = [_interpret_ast(src) for src in ast.src]
-      tmp = f"{gstr(f[ast.op], ast.op)}({', '.join(inp + ([gstr(ast.arg)] if ast.arg else []))})"
-    ret = f"a{len(lines)}"
-    lines.append(f" {ret} = {tmp}")
-    return ret
-
-  ret = _interpret_ast(ast)
-  src = '\n'.join(['def run(inputs):'] + lines + [f" return {gstr(device.from_underlying, 'from_underlying')}({ret})" if device.from_underlying else f" return {ret}"])
-  if DEBUG >= 4: print(functools.reduce(lambda x,y: (x.replace(y[0], str(y[1])) if y[0][0:2] == "m0" else x), tglob.items(), src))
-  exec(compile(src, "<ast>", "exec"), tglob) # pylint: disable=exec-used
-  return tglob['run']
-
 class Interpreted:
   def __init__(self, buffer, fxn_for_op: Dict[Op, Callable], to_underlying=lambda x: x._buf, from_underlying=None):
     self.buffer, self.fxn_for_op, self.to_underlying, self.from_underlying = buffer, fxn_for_op, to_underlying, from_underlying
     self.synchronize = lambda: None
     self.codegen = None
+    self.method_cache: Dict[LazyOp, Callable] = {}
+
+  def interpret_ast(self:Interpreted, ast:LazyOp) -> Callable:
+    tglob: Dict[str, Any] = {}
+    lines: List[str] = []
+    f = self.fxn_for_op
+
+    @functools.lru_cache(None)
+    def gstr(x:Any, nm=None) -> str:
+      ret = str(nm).replace(".", "_") if nm else f"m{len(tglob):04d}"
+      tglob[ret] = x
+      return ret
+
+    @functools.lru_cache(None)
+    def _interpret_ast(ast:LazyOp) -> str:
+      if TernaryOps.MULACC in f and ast.op == ReduceOps.SUM and isinstance(ast.src[0], LazyOp) and ast.src[0].op == BinaryOps.MUL:
+        ast = LazyOp(TernaryOps.MULACC, ast.src[0].src, ast.arg)
+      if MovementOps.AS_STRIDED in f and ast.op in BufferOps:
+        tmp = f"{gstr(f[ast.op], ast.op)}({gstr(ast.arg.val)}, {gstr(ast.arg.dtype)})" if ast.op == BufferOps.CONST else f"{gstr(f[ast.op], ast.op)}(inputs[{ast.arg.idx-1}])"
+        for mop,arg in ast.arg.st.to_movement_ops(): tmp = f"{gstr(f[mop], mop)}({tmp}, {gstr(arg)})"
+      else:
+        inp = [_interpret_ast(src) for src in ast.src]
+        tmp = f"{gstr(f[ast.op], ast.op)}({', '.join(inp + ([gstr(ast.arg)] if ast.arg else []))})"
+      ret = f"a{len(lines)}"
+      lines.append(f" {ret} = {tmp}")
+      return ret
+
+    ret = _interpret_ast(ast)
+    src = '\n'.join(['def run(inputs):'] + lines + [f" return {gstr(self.from_underlying, 'from_underlying')}({ret})" if self.from_underlying else f" return {ret}"])
+    if DEBUG >= 4: print(functools.reduce(lambda x,y: (x.replace(y[0], str(y[1])) if y[0][0:2] == "m0" else x), tglob.items(), src))
+    exec(compile(src, "<ast>", "exec"), tglob) # pylint: disable=exec-used
+    return tglob['run']
+
   def exec_ast(self, ast:LazyOp, output=None, inputs=None, var_vals=None, context=None, **kwargs):
-    ret = interpret_ast(self, ast)([x.realized for x in inputs] if inputs else None)
+    if ast not in self.method_cache: self.method_cache[ast] = self.interpret_ast(ast)
+    ret = self.method_cache[ast]([x.realized for x in inputs] if inputs else None)
     if output is not None and ret.dtype != output.dtype and UnaryOps.CAST in self.fxn_for_op:
       ret = self.from_underlying(self.fxn_for_op[UnaryOps.CAST](self.to_underlying(ret), (output.dtype, False))) # Do manual casting of ret if it does not match the required output dtype.
     # TODO: is this used?
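For context, interpret_ast compiles a LazyOp tree into the source of a Python run(inputs) function, emitting one temporary assignment per node, then execs that source and returns the resulting callable; the behavioral change above is only where that callable is cached (a per-device method_cache dict keyed by the AST instead of a global functools.lru_cache). Below is a stripped-down, runnable sketch of the same emit-and-exec pattern with a toy op table; FXN_FOR_OP, compile_ast, and the tuple-based AST are illustrative, not tinygrad's types.

from typing import Any, Dict, List

FXN_FOR_OP = {"ADD": lambda a, b: a + b, "MUL": lambda a, b: a * b}  # toy op table

def compile_ast(ast) -> Any:
  tglob: Dict[str, Any] = {"FXN": FXN_FOR_OP}
  lines: List[str] = []
  def walk(node) -> str:
    op, srcs = node
    # leaves read an input; interior nodes apply an op to already-emitted temporaries
    tmp = f"inputs[{srcs}]" if op == "LOAD" else f"FXN[{op!r}]({', '.join(walk(s) for s in srcs)})"
    name = f"a{len(lines)}"
    lines.append(f"  {name} = {tmp}")
    return name
  out = walk(ast)
  src = "\n".join(["def run(inputs):"] + lines + [f"  return {out}"])
  exec(compile(src, "<ast>", "exec"), tglob)  # same exec-into-a-namespace trick as ops.py
  return tglob["run"]

run = compile_ast(("ADD", [("LOAD", 0), ("MUL", [("LOAD", 1), ("LOAD", 1)])]))
print(run([3, 4]))  # 3 + 4*4 = 19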
@@ -188,19 +189,6 @@ class ASTRunner:
     if DEBUG >= 4: print(prg)
     self.name, self.prg, self.global_size, self.local_size, self.op_estimate, self.mem_estimate, self.display_name, self.runtime_args = name, prg, global_size, local_size, op_estimate, mem_estimate, display_name, runtime_args if runtime_args is not None else {}

-  def optimize_local_size(self, global_size:List[int], rawbufs:List[RawBuffer]) -> List[int]:
-    assert self.global_size is not None, "needs a global size to optimize local size"
-    test_rawbuffers = [type(rawbufs[0])(rawbufs[0].size, rawbufs[0].dtype), *rawbufs[1:]] if rawbufs[0] in rawbufs[1:] else rawbufs
-    MAX_WORKGROUP = self.clprg.max_work_group_size() if hasattr(self.clprg, 'max_work_group_size') else 1024
-    local_dims = [[x for x in set([sz, 1, 2, 4, 8, 16, 32, 64, 128, 256, MAX_WORKGROUP]) if x<=sz] for sz in global_size]
-    local_sizes = [list(x) for x in itertools.product(*local_dims) if prod(x) <= MAX_WORKGROUP] * 2 # try each valid size twice
-    def try_exec(local_size):
-      try:
-        return self.clprg([g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)], local_size, *test_rawbuffers, wait=True)
-      except Exception:
-        return float('inf')
-    return min([(try_exec(local_size), local_size) for local_size in random.sample(local_sizes, len(local_sizes))])[1]
-
   def build(self, compiler, runtime):
     self.lib = compiler.__wrapped__(self.prg) if getenv("DISABLE_COMPILER_CACHE") else compiler(self.prg)
     self.clprg = runtime(self.name, self.lib)
@@ -221,7 +209,8 @@ class ASTRunner:
     global_size, local_size = self.launch_dims(var_vals)
     if global_size is not None and local_size is None:
       # TODO: this is copied from get_program
-      local_size = self.local_size = self.optimize_local_size(global_size, rawbufs)
+      from tinygrad.features.search import optimize_local_size
+      local_size = self.local_size = optimize_local_size(self.clprg, global_size, rawbufs)
       global_size = self.global_size = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)]
     lra = self.runtime_args.copy()
     if global_size: lra['global_size'] = global_size
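Worked example of the launch-dims adjustment kept at this call site: once optimize_local_size picks a local size, each global dimension is divided by the matching local dimension using the same g//l-if-divisible expression shown above (the concrete numbers below are illustrative, not taken from the commit):

global_size, local_size = [1024, 512, 1], [16, 16, 1]
grid = [g//l if g % l == 0 else g/l for g, l in zip(global_size, local_size)]
print(grid)  # [64, 32, 1]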