parent
c60c3b467a
commit
2f7aab3d13
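Three changes run through this diff: `GlobalCounters` moves from `tinygrad.ops` to `tinygrad.helpers` (accounting for the repeated one-line import updates below), `ASTRunner.optimize_local_size` is lifted out of `tinygrad/ops.py` into `tinygrad.features.search` as a free function, and the module-level `interpret_ast` becomes an `Interpreted` method memoized through a new `method_cache` dict instead of `functools.lru_cache`.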
@@ -6,7 +6,7 @@ from models.efficientnet import EfficientNet
 from tinygrad.nn.state import get_parameters
 from tinygrad.nn import optim
 from tinygrad.tensor import Tensor
-from tinygrad.ops import GlobalCounters
+from tinygrad.helpers import GlobalCounters
 from tinygrad.helpers import getenv
 from tinygrad.jit import CacheCollector
@@ -8,7 +8,7 @@ np.set_printoptions(linewidth=200)
 from typing import Optional, Dict

 from tinygrad.helpers import Timing, getenv, dtypes, DEBUG
-from tinygrad.ops import GlobalCounters
+from tinygrad.helpers import GlobalCounters
 from tinygrad.ops import Device
 from tinygrad.tensor import Tensor
 from tinygrad.nn import Embedding, Linear
@@ -18,7 +18,7 @@ from tinygrad.nn.state import get_state_dict
 from tinygrad.nn import optim
 from tinygrad.ops import Device
 from tinygrad.tensor import Tensor
-from tinygrad.ops import GlobalCounters
+from tinygrad.helpers import GlobalCounters
 from tinygrad.shape.symbolic import Node
 from extra.lr_scheduler import OneCycleLR
 from tinygrad.jit import TinyJit
@@ -14,7 +14,7 @@ from tinygrad.ops import Device
 from tinygrad.tensor import Tensor
 from tinygrad.nn import Embedding, Linear
 from tinygrad.nn.state import safe_load, torch_load, load_state_dict
-from tinygrad.ops import GlobalCounters
+from tinygrad.helpers import GlobalCounters
 from tinygrad.jit import TinyJit, JIT_SUPPORTED_DEVICE
 from tinygrad.shape.symbolic import Variable, sym_infer
@@ -3,7 +3,7 @@ import os
 import numpy as np
 import time, torch, torch.mps

-from tinygrad.ops import GlobalCounters
+from tinygrad.helpers import GlobalCounters
 from tinygrad.tensor import Tensor
 from tinygrad.jit import TinyJit
 from tinygrad.ops import Device
@@ -4,7 +4,7 @@ from tinygrad.helpers import prod
 from tinygrad.tensor import Tensor
 from tinygrad.lazy import LazyBuffer
 from tinygrad.runtime.ops_gpu import CLBuffer
-from tinygrad.ops import GlobalCounters
+from tinygrad.helpers import GlobalCounters

 def print_objects():
   #gc.collect()
@@ -7,7 +7,7 @@ from collections import defaultdict
 from typing import Union

 from tinygrad.helpers import prod, getenv, DEBUG, dtypes
-from tinygrad.ops import GlobalCounters
+from tinygrad.helpers import GlobalCounters
 from tinygrad.tensor import Tensor
 from tinygrad.lazy import LazyBuffer
 from tinygrad.ops import Device
@@ -17,7 +17,7 @@ import onnx
 import numpy as np

 import tinygrad.graph as graph
-from tinygrad.ops import GlobalCounters
+from tinygrad.helpers import GlobalCounters
 from tinygrad.jit import TinyJit, CacheCollector

 import pyopencl as cl
@@ -2,7 +2,7 @@ import unittest
 from tinygrad.helpers import prod
 from tinygrad.ops import Device
 from tinygrad.tensor import Tensor
-from tinygrad.ops import GlobalCounters
+from tinygrad.helpers import GlobalCounters
 from tinygrad.jit import CacheCollector

 class TestCopy(unittest.TestCase):
@@ -3,7 +3,7 @@ import unittest, gc
 import numpy as np
 from tinygrad.tensor import Tensor
 from tinygrad.nn.state import get_state_dict
-from tinygrad.ops import GlobalCounters
+from tinygrad.helpers import GlobalCounters
 from tinygrad.runtime.lib import RawBuffer, LRUAllocator
 from tinygrad.helpers import dtypes, prod
 from tinygrad.ops import Device
@@ -13,7 +13,7 @@ from tinygrad.tensor import Tensor, Device
 from tinygrad import nn
 from tinygrad.helpers import getenv
 from tinygrad.nn import optim
-from tinygrad.ops import GlobalCounters, MovementOps, ReduceOps
+from tinygrad.helpers import GlobalCounters
 from tinygrad.lazy import PUSH_PERMUTES
 from tinygrad.jit import CacheCollector
@@ -2,7 +2,7 @@
 import unittest
 import numpy as np
 from weakref import ref
-from tinygrad.ops import GlobalCounters
+from tinygrad.helpers import GlobalCounters
 from tinygrad.runtime.lib import RawBuffer, LRUAllocator
 from tinygrad.helpers import dtypes, prod
 from tinygrad.ops import Device
@@ -2,8 +2,7 @@
 import unittest
 import numpy as np
 from tinygrad.tensor import Tensor
-from tinygrad.ops import GlobalCounters, Device
-from tinygrad.graph import nm
+from tinygrad.ops import Device
 from tinygrad.helpers import dtypes

 N = 200 # has to be bigger than the cache to fail
@@ -50,16 +49,6 @@ class TestAssign(unittest.TestCase):
     ba2 = a.lazydata.realized
     # NOTE: don't test that it's assigned
     #assert ba1 == ba2 and ba1 != bb1
-
-    """
-    if len(GlobalCounters.cache):
-      runner, args = GlobalCounters.cache[0]
-      b0, b1, b2 = args
-      print(nm(b0), id(b0.cl))
-      print(nm(b1), id(b1.cl))
-      print(nm(b2), id(b2.cl))
-    """
-
     np.testing.assert_allclose(a.numpy(), np.arange(N*N).reshape((N,N)) + np.arange(N*N).reshape((N,N)).transpose(1,0))

     # TODO: is there a way to sneak in a permute such that it returns the wrong answer?
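The deleted block above was a commented-out debugging aid that poked at `GlobalCounters.cache` and `nm`; removing it is what lets the `from tinygrad.graph import nm` import go away in the preceding hunk.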
@@ -10,7 +10,7 @@ import time
 import numpy as np
 np.set_printoptions(linewidth=160)
 from tinygrad.ops import Device
-from tinygrad.ops import GlobalCounters
+from tinygrad.helpers import GlobalCounters
 from tinygrad.tensor import Tensor
 from tinygrad.nn import Conv2d
 from tinygrad.helpers import colored, getenv, CI
@@ -1,4 +1,5 @@
-from typing import Dict, List, cast, DefaultDict, Optional, Tuple
+from typing import Dict, List, cast, DefaultDict, Optional, Tuple, Callable
+import itertools, random
 from tinygrad.lazy import vars_from_ast
 from tinygrad.ops import Device, Compiled, MemBuffer
 from tinygrad.helpers import prod, ImageDType, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context
@@ -136,3 +137,15 @@ def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linearizer:
   if CACHELEVEL >= 1: diskcache_put("beam_search", key, beam[0][0].applied_opts)
   if DEBUG >= 3: print(beam[0][0].applied_opts)
   return beam[0][0]
+
+def optimize_local_size(clprg:Callable, global_size:List[int], rawbufs:List[RawBuffer]) -> List[int]:
+  test_rawbuffers = [type(rawbufs[0])(rawbufs[0].size, rawbufs[0].dtype), *rawbufs[1:]] if rawbufs[0] in rawbufs[1:] else rawbufs
+  MAX_WORKGROUP = clprg.max_work_group_size() if hasattr(clprg, 'max_work_group_size') else 1024
+  local_dims = [[x for x in set([sz, 1, 2, 4, 8, 16, 32, 64, 128, 256, MAX_WORKGROUP]) if x<=sz] for sz in global_size]
+  local_sizes = [list(x) for x in itertools.product(*local_dims) if prod(x) <= MAX_WORKGROUP] * 2 # try each valid size twice
+  def try_exec(local_size):
+    try:
+      return clprg(*test_rawbuffers, global_size=[g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)], local_size=local_size, wait=True)
+    except Exception:
+      return float('inf')
+  return min([(try_exec(local_size), local_size) for local_size in random.sample(local_sizes, len(local_sizes))])[1]
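The function moved here brute-forces a good workgroup shape: for each global dimension it gathers the power-of-two candidates (plus the dimension itself) that fit, takes every combination whose product stays within `MAX_WORKGROUP`, times each candidate twice in random order via `try_exec`, and returns the fastest. Below is a minimal, self-contained sketch of that search pattern; `fake_time` is a made-up stand-in for the real on-device `clprg(..., wait=True)` call, and the sizes are arbitrary:

```python
import itertools, random
from math import prod

# Hypothetical stand-in for timing a compiled kernel at a given local size.
# In the diff, the real call executes on-device and returns the wall time.
def fake_time(global_size, local_size):
  return sum(g % l for g, l in zip(global_size, local_size)) + prod(local_size) * 1e-3

MAX_WORKGROUP = 256
global_size = [1024, 512, 1]

# Per-dimension candidates: powers of two (and the dimension itself) that fit.
local_dims = [[x for x in set([sz, 1, 2, 4, 8, 16, 32, 64, 128, 256, MAX_WORKGROUP]) if x <= sz] for sz in global_size]
# Cartesian product filtered by the workgroup product limit, duplicated so each
# candidate is timed twice -- the same noise-damping trick the diff uses.
local_sizes = [list(x) for x in itertools.product(*local_dims) if prod(x) <= MAX_WORKGROUP] * 2

best = min((fake_time(global_size, ls), ls) for ls in random.sample(local_sizes, len(local_sizes)))[1]
print("best local size:", best)
```

Timing each candidate twice and taking the minimum over a shuffled list cheaply damps scheduling noise; in the real version a candidate that fails to launch simply scores `float('inf')` and loses.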
@@ -1,5 +1,5 @@
 from __future__ import annotations
-import importlib, inspect, functools, pathlib, itertools, random
+import importlib, inspect, functools, pathlib
 from enum import Enum, auto
 from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, Mapping
 from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored, BEAM, NOOPT
@@ -108,48 +108,49 @@ Device = _Device()

 # **************** for Interpreted Buffers ****************

-@functools.lru_cache(None)
-def interpret_ast(device:Interpreted, ast:LazyOp) -> Callable:
-  tglob: Dict[str, Any] = {}
-  lines: List[str] = []
-  f = device.fxn_for_op
-
-  @functools.lru_cache(None)
-  def gstr(x:Any, nm=None) -> str:
-    ret = str(nm).replace(".", "_") if nm else f"m{len(tglob):04d}"
-    tglob[ret] = x
-    return ret
-
-  @functools.lru_cache(None)
-  def _interpret_ast(ast:LazyOp) -> str:
-    if TernaryOps.MULACC in f and ast.op == ReduceOps.SUM and isinstance(ast.src[0], LazyOp) and ast.src[0].op == BinaryOps.MUL:
-      ast = LazyOp(TernaryOps.MULACC, ast.src[0].src, ast.arg)
-
-    if MovementOps.AS_STRIDED in f and ast.op in BufferOps:
-      tmp = f"{gstr(f[ast.op], ast.op)}({gstr(ast.arg.val)}, {gstr(ast.arg.dtype)})" if ast.op == BufferOps.CONST else f"{gstr(f[ast.op], ast.op)}(inputs[{ast.arg.idx-1}])"
-      for mop,arg in ast.arg.st.to_movement_ops(): tmp = f"{gstr(f[mop], mop)}({tmp}, {gstr(arg)})"
-    else:
-      inp = [_interpret_ast(src) for src in ast.src]
-      tmp = f"{gstr(f[ast.op], ast.op)}({', '.join(inp + ([gstr(ast.arg)] if ast.arg else []))})"
-
-    ret = f"a{len(lines)}"
-    lines.append(f"  {ret} = {tmp}")
-    return ret
-
-  ret = _interpret_ast(ast)
-  src = '\n'.join(['def run(inputs):'] + lines + [f"  return {gstr(device.from_underlying, 'from_underlying')}({ret})" if device.from_underlying else f"  return {ret}"])
-  if DEBUG >= 4: print(functools.reduce(lambda x,y: (x.replace(y[0], str(y[1])) if y[0][0:2] == "m0" else x), tglob.items(), src))
-  exec(compile(src, "<ast>", "exec"), tglob) # pylint: disable=exec-used
-  return tglob['run']
-
 class Interpreted:
   def __init__(self, buffer, fxn_for_op: Dict[Op, Callable], to_underlying=lambda x: x._buf, from_underlying=None):
     self.buffer, self.fxn_for_op, self.to_underlying, self.from_underlying = buffer, fxn_for_op, to_underlying, from_underlying
     self.synchronize = lambda: None
     self.codegen = None
+    self.method_cache: Dict[LazyOp, Callable] = {}
+
+  def interpret_ast(self:Interpreted, ast:LazyOp) -> Callable:
+    tglob: Dict[str, Any] = {}
+    lines: List[str] = []
+    f = self.fxn_for_op
+
+    @functools.lru_cache(None)
+    def gstr(x:Any, nm=None) -> str:
+      ret = str(nm).replace(".", "_") if nm else f"m{len(tglob):04d}"
+      tglob[ret] = x
+      return ret
+
+    @functools.lru_cache(None)
+    def _interpret_ast(ast:LazyOp) -> str:
+      if TernaryOps.MULACC in f and ast.op == ReduceOps.SUM and isinstance(ast.src[0], LazyOp) and ast.src[0].op == BinaryOps.MUL:
+        ast = LazyOp(TernaryOps.MULACC, ast.src[0].src, ast.arg)
+
+      if MovementOps.AS_STRIDED in f and ast.op in BufferOps:
+        tmp = f"{gstr(f[ast.op], ast.op)}({gstr(ast.arg.val)}, {gstr(ast.arg.dtype)})" if ast.op == BufferOps.CONST else f"{gstr(f[ast.op], ast.op)}(inputs[{ast.arg.idx-1}])"
+        for mop,arg in ast.arg.st.to_movement_ops(): tmp = f"{gstr(f[mop], mop)}({tmp}, {gstr(arg)})"
+      else:
+        inp = [_interpret_ast(src) for src in ast.src]
+        tmp = f"{gstr(f[ast.op], ast.op)}({', '.join(inp + ([gstr(ast.arg)] if ast.arg else []))})"
+
+      ret = f"a{len(lines)}"
+      lines.append(f"  {ret} = {tmp}")
+      return ret
+
+    ret = _interpret_ast(ast)
+    src = '\n'.join(['def run(inputs):'] + lines + [f"  return {gstr(self.from_underlying, 'from_underlying')}({ret})" if self.from_underlying else f"  return {ret}"])
+    if DEBUG >= 4: print(functools.reduce(lambda x,y: (x.replace(y[0], str(y[1])) if y[0][0:2] == "m0" else x), tglob.items(), src))
+    exec(compile(src, "<ast>", "exec"), tglob) # pylint: disable=exec-used
+    return tglob['run']

   def exec_ast(self, ast:LazyOp, output=None, inputs=None, var_vals=None, context=None, **kwargs):
-    ret = interpret_ast(self, ast)([x.realized for x in inputs] if inputs else None)
+    if ast not in self.method_cache: self.method_cache[ast] = self.interpret_ast(ast)
+    ret = self.method_cache[ast]([x.realized for x in inputs] if inputs else None)
     if output is not None and ret.dtype != output.dtype and UnaryOps.CAST in self.fxn_for_op:
       ret = self.from_underlying(self.fxn_for_op[UnaryOps.CAST](self.to_underlying(ret), (output.dtype, False))) # Do manual casting of ret if it does not match the required output dtype.
     # TODO: is this used?
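The memoization strategy changes here: the old module-level `interpret_ast` leaned on `@functools.lru_cache`, while the method version is cached explicitly in `self.method_cache`, one compiled `run` function per distinct AST per `Interpreted` instance. A toy sketch of that pattern follows; the class and its `compile` method are illustrative stand-ins, not tinygrad's API:

```python
from typing import Callable, Dict

class Interpreter:
  """Toy model of the caching pattern in exec_ast; not tinygrad's actual class."""
  def __init__(self):
    self.method_cache: Dict[str, Callable] = {}

  def compile(self, ast: str) -> Callable:
    print(f"compiling {ast!r}")            # happens once per distinct AST
    return eval(f"lambda inputs: {ast}")   # stand-in for exec'ing generated source

  def exec_ast(self, ast: str, inputs):
    # build the runner on first sight of this AST, then reuse it
    if ast not in self.method_cache: self.method_cache[ast] = self.compile(ast)
    return self.method_cache[ast](inputs)

i = Interpreter()
print(i.exec_ast("sum(inputs)", [1, 2, 3]))  # compiles, prints 6
print(i.exec_ast("sum(inputs)", [4, 5]))     # cache hit, prints 9
```

An explicit dict keyed by the AST also ties the cache's lifetime to the device object rather than to a process-wide LRU.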
@@ -188,19 +189,6 @@ class ASTRunner:
     if DEBUG >= 4: print(prg)
     self.name, self.prg, self.global_size, self.local_size, self.op_estimate, self.mem_estimate, self.display_name, self.runtime_args = name, prg, global_size, local_size, op_estimate, mem_estimate, display_name, runtime_args if runtime_args is not None else {}
-
-  def optimize_local_size(self, global_size:List[int], rawbufs:List[RawBuffer]) -> List[int]:
-    assert self.global_size is not None, "needs a global size to optimize local size"
-    test_rawbuffers = [type(rawbufs[0])(rawbufs[0].size, rawbufs[0].dtype), *rawbufs[1:]] if rawbufs[0] in rawbufs[1:] else rawbufs
-    MAX_WORKGROUP = self.clprg.max_work_group_size() if hasattr(self.clprg, 'max_work_group_size') else 1024
-    local_dims = [[x for x in set([sz, 1, 2, 4, 8, 16, 32, 64, 128, 256, MAX_WORKGROUP]) if x<=sz] for sz in global_size]
-    local_sizes = [list(x) for x in itertools.product(*local_dims) if prod(x) <= MAX_WORKGROUP] * 2 # try each valid size twice
-    def try_exec(local_size):
-      try:
-        return self.clprg([g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)], local_size, *test_rawbuffers, wait=True)
-      except Exception:
-        return float('inf')
-    return min([(try_exec(local_size), local_size) for local_size in random.sample(local_sizes, len(local_sizes))])[1]

   def build(self, compiler, runtime):
     self.lib = compiler.__wrapped__(self.prg) if getenv("DISABLE_COMPILER_CACHE") else compiler(self.prg)
     self.clprg = runtime(self.name, self.lib)
@@ -221,7 +209,8 @@ class ASTRunner:
     global_size, local_size = self.launch_dims(var_vals)
     if global_size is not None and local_size is None:
       # TODO: this is copied from get_program
-      local_size = self.local_size = self.optimize_local_size(global_size, rawbufs)
+      from tinygrad.features.search import optimize_local_size
+      local_size = self.local_size = optimize_local_size(self.clprg, global_size, rawbufs)
       global_size = self.global_size = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)]
     lra = self.runtime_args.copy()
     if global_size: lra['global_size'] = global_size
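Two details of the new call site are worth noting: `optimize_local_size` is imported inside the method body, presumably to dodge a circular import (the search module itself imports from `tinygrad.ops`, as its hunk above shows), and the compiled program `self.clprg` is now passed in explicitly, since the free function no longer has access to `ASTRunner` state.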