
jit doesn't use named tensors (#2393)

* jit doesn't use named tensors

* move to compile2

* remove broken single root junk

* explicit float32

* skip slow test
George Hotz authored 2023-11-23 00:13:18 -08:00, committed by GitHub
parent 80e4ad8bf5
commit 8656eebb42
11 changed files with 36 additions and 186 deletions
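
In short: TinyJit previously keyed captured inputs by tensor name (Dict[Union[int, str], RawBuffer]); after this commit it tracks them as a flat List[RawBuffer] indexed by position. A minimal usage sketch of the JIT lifecycle this diff touches (toy function, assumes a Compiled backend such as GPU=1):

from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit

@TinyJit
def mul(a: Tensor, b: Tensor) -> Tensor:
  return (a * b).realize()

x, y = Tensor.randn(4, 4), Tensor.randn(4, 4)
for _ in range(3):
  out = mul(x, y)  # call 0 runs eagerly, call 1 captures the kernels, call 2 replays them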

View File

@@ -178,17 +178,14 @@ jobs:
     - if: ${{ matrix.task == 'openpilot' }}
       name: Test openpilot model compile and size
       run: |
-        DEBUG=2 ALLOWED_KERNEL_COUNT=207 VALIDTEST=1 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python openpilot/compile.py
+        DEBUG=2 ALLOWED_KERNEL_COUNT=207 VALIDTEST=1 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python openpilot/compile2.py
         python -c 'import os; assert os.path.getsize("/tmp/output.thneed") < 100_000_000'
     - if: ${{ matrix.task == 'openpilot' }}
       name: Test openpilot model correctness (float32)
-      run: DEBUGCL=1 GPU=1 IMAGE=2 python openpilot/compile.py
-    - if: ${{ matrix.task == 'openpilot' }}
-      name: Test openpilot model correctness (float32, new compiler)
-      run: DEBUGCL=1 FLOAT16=0 python3 openpilot/compile2.py
+      run: FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python openpilot/compile2.py
     - if: ${{ matrix.task == 'openpilot' }}
      name: Test openpilot alt model correctness (float32)
-      run: DEBUGCL=1 GPU=1 IMAGE=2 python openpilot/compile.py https://github.com/commaai/openpilot/raw/3799fe46b3a629e491d4b8498b8ae83e4c88c304/selfdrive/modeld/models/supercombo.onnx
+      run: FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python openpilot/compile2.py https://github.com/commaai/openpilot/raw/3799fe46b3a629e491d4b8498b8ae83e4c88c304/selfdrive/modeld/models/supercombo.onnx
     - if: ${{ matrix.task == 'openpilot' }}
       name: Test tensor core ops
       run: GPU=1 TC=2 python -m pytest -n=auto test/test_ops.py
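
The compile-and-size gate above can be reproduced locally. A hedged sketch using only the env vars and paths from the workflow (run from the repo root):

import os, subprocess

# same environment the CI step sets
env = dict(os.environ, DEBUG="2", ALLOWED_KERNEL_COUNT="207", VALIDTEST="1",
           FLOAT16="1", DEBUGCL="1", GPU="1", IMAGE="2")
subprocess.run(["python", "openpilot/compile2.py"], env=env, check=True)  # writes /tmp/output.thneed
assert os.path.getsize("/tmp/output.thneed") < 100_000_000, "thneed file too big"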

View File

@@ -1,150 +0,0 @@
-#!/usr/bin/env python3
-import os, time, io, pathlib, sys, traceback, re
-sys.path.insert(0, str(pathlib.Path(__file__).parents[1]))
-
-if os.getenv("OPT", None) is None:
-  os.environ['OPT'] = '99'
-if os.getenv("GPU", None) is None:
-  os.environ['GPU'] = '1'
-if os.getenv("IMAGE", None) is None:
-  os.environ['IMAGE'] = '2'
-
-from tinygrad.helpers import getenv, dtypes
-ALLOWED_KERNEL_COUNT = getenv("ALLOWED_KERNEL_COUNT", 0)
-DEBUGCL = getenv("DEBUGCL", 0)
-
-import onnx
-import numpy as np
-import tinygrad.graph as graph
-from tinygrad.helpers import GlobalCounters
-from tinygrad.jit import TinyJit, CacheCollector
-
-import pyopencl as cl
-from tinygrad.runtime.ops_gpu import CL
-from extra.utils import fetch
-from extra.onnx import get_run_onnx
-from tinygrad.tensor import Tensor
-
-OPENPILOT_MODEL = "https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx"
-
-np.random.seed(1337)
-def get_random_input_tensors(input_shapes):
-  # this 16 is a random scale factor
-  inputs = {k:Tensor.randn(*shp, requires_grad=False)*8 for k,shp in input_shapes.items()}
-  np_inputs = {k:v.realize().numpy() for k,v in inputs.items()}
-  return inputs, np_inputs
-
-@TinyJit
-def model_exec(run_onnx, using_graph, **inputs):
-  ret = next(iter(run_onnx(inputs).values())).cast(dtypes.float32)
-  GlobalCounters.reset()
-  CacheCollector.start()  # don't cache pre-realize
-  if using_graph: graph.GRAPH = True
-  print("realizing")
-  return ret.realize()
-
-def compile(dat, output_fn):
-  Tensor.manual_seed(1337)
-  Tensor.no_grad = True
-  using_graph = graph.GRAPH
-  if getenv("GRAPH") < 3: graph.GRAPH = False
-
-  onnx_model = onnx.load(io.BytesIO(dat))
-  run_onnx = get_run_onnx(onnx_model)
-  input_shapes = {inp.name:tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input}
-
-  inputs, np_inputs = get_random_input_tensors(input_shapes)
-
-  # run twice to trigger the JIT
-  for i in range(2): tinygrad_out = model_exec(run_onnx, i == 1 and using_graph, **inputs)
-  graph.GRAPH = False
-
-  print("kernel count:", len(model_exec.jit_cache))
-  assert len(model_exec.jit_cache) <= ALLOWED_KERNEL_COUNT or ALLOWED_KERNEL_COUNT == 0, "too many kernels!"
-
-  # pull out inputs and put them in the jit cache
-  input_rawbuffers = {k:inputs[k].lazydata.realized for k in inputs.keys()}
-  for (j,i),idx in model_exec.input_replace.items(): model_exec.jit_cache[j].rawbufs[i] = input_rawbuffers[idx]
-
-  # transform to CL.CACHE
-  used_ops = 0
-  cl_cache = []
-  for ji in model_exec.jit_cache:
-    prg = ji.prg
-    # pass these to thneed
-    setattr(prg.clprg, 'op_estimate', prg.op_estimate)
-    setattr(prg.clprg, 'prg', prg.prg)
-
-    if getenv("VALIDTEST") == 1:
-      src = re.search(r"=.*\?.*?read_image", prg.prg)
-      if src is not None: raise Exception("Openpilot has valid checks!")
-
-    global_size = prg.global_size + [1]*(3-len(prg.global_size))
-    local_size = prg.local_size + [1]*(3-len(prg.local_size))
-    cl_cache.append((prg.clprg, [[int(g*l) for g,l in zip(global_size, local_size)], local_size, *[x._buf for x in ji.rawbufs]]))
-    used_ops += prg.op_estimate
-
-  from extra.thneed import Thneed
-  t = Thneed(cl_cache, {k:v._buf for k,v in input_rawbuffers.items()})
-
-  # save thneed (before run)
-  t.save(output_fn)
-
-  print(f"buffers to save: {len(t.buffers_to_save)}, inputs: {list(t.inputs.keys())}, outputs: {t.outputs}")
-  runtime = t.run()
-  print(f"network using {used_ops/1e9:.2f} GOPS with runtime {runtime*1e3:.2f} ms that's {used_ops/runtime*1e-9:.2f} GFLOPS")
-
-  # confirm thneed found the right output
-  thneed_out = np.empty((t.outputs[0].size//4,), dtype=np.float32).reshape(tinygrad_out.shape)
-  cl.enqueue_copy(CL.cl_queue[0], thneed_out, t.outputs[0], is_blocking=True)
-  np.testing.assert_allclose(thneed_out, tinygrad_out.numpy())
-
-  # testing is float32 only (fix this)
-  FLOAT16 = getenv("FLOAT16", 0)
-  if FLOAT16 == 0:
-    try:
-      from test.models.test_onnx import run_onnx_torch
-      torch_out = run_onnx_torch(onnx_model, np_inputs).numpy()
-      print(thneed_out, torch_out, "mse", np.sum((thneed_out-torch_out)**2), "max err", np.max(np.abs((thneed_out-torch_out))))
-      np.testing.assert_allclose(torch_out, thneed_out, atol=1e-4, rtol=1e-2)
-
-      # test loading/run thneed
-      _, new_np_inputs = get_random_input_tensors(input_shapes)
-      new_torch_out = run_onnx_torch(onnx_model, new_np_inputs).numpy()
-
-      # try old thneed with a different input
-      for k,v in t.inputs.items():
-        cl.enqueue_copy(CL.cl_queue[0], v, new_np_inputs[k], is_blocking=True)
-      t.run()
-      old_thneed_out = np.empty((t.outputs[0].size//4,), dtype=np.float32).reshape(tinygrad_out.shape)
-      cl.enqueue_copy(CL.cl_queue[0], old_thneed_out, t.outputs[0], is_blocking=True)
-
-      # compare thneed (rerun) with torch
-      np.testing.assert_allclose(new_torch_out, old_thneed_out, atol=1e-4, rtol=1e-2)
-
-      # load thneed and try that
-      _, new_np_inputs = get_random_input_tensors(input_shapes)
-      new_torch_out = run_onnx_torch(onnx_model, new_np_inputs).numpy()
-      nt = Thneed()
-      nt.load(output_fn)
-
-      # inputs
-      for k,v in nt.inputs.items():
-        cl.enqueue_copy(CL.cl_queue[0], v, new_np_inputs[k], is_blocking=True)
-
-      nt.run()
-      new_thneed_out = np.empty((nt.outputs[0].size//4,), dtype=np.float32).reshape(tinygrad_out.shape)
-      cl.enqueue_copy(CL.cl_queue[0], new_thneed_out, nt.outputs[0], is_blocking=True)
-
-      # compare torch to thneed
-      np.testing.assert_allclose(new_torch_out, new_thneed_out, atol=1e-4, rtol=1e-2)
-      print("thneed self-test passed!")
-    except ModuleNotFoundError as e:
-      print(f"TEST NOT HAPPENING {e}")
-
-# UNSAFE_FLOAT4=1 DEBUGCL=1 FLOAT16=1 python3 openpilot/compile.py
-# 22.59 ms
-if __name__ == "__main__":
-  dat = fetch(OPENPILOT_MODEL if len(sys.argv) == 1 else sys.argv[1])
-  compile(dat, sys.argv[2] if len(sys.argv) >= 3 else "/tmp/output.thneed")

View File

@@ -1,12 +1,11 @@
 #!/usr/bin/env python3
-import os, sys, io, pathlib
+import os, sys, io, pathlib, re
 sys.path.insert(0, str(pathlib.Path(__file__).parents[1]))
 if "FLOAT16" not in os.environ: os.environ["FLOAT16"] = "1"
 if "IMAGE" not in os.environ: os.environ["IMAGE"] = "2"
 if "NOLOCALS" not in os.environ: os.environ["NOLOCALS"] = "1"
 if "OPT" not in os.environ: os.environ["OPT"] = "99"
-os.environ["PREREALIZE"] = "0"
 OPENPILOT_MODEL = "https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx"
@@ -55,6 +54,9 @@ def get_schedule(onnx_data) -> Tuple[List[ScheduleItem], List[ScheduleItem]]:
 def schedule_to_thneed(schedule, output_fn):
   from extra.thneed import Thneed
+  print("kernel count:", len(schedule))
+  assert len(schedule) <= getenv("ALLOWED_KERNEL_COUNT", 0) or getenv("ALLOWED_KERNEL_COUNT", 0) == 0, "too many kernels!"
   # transform to CL.CACHE
   used_ops = 0
   cl_cache = []
@@ -66,6 +68,10 @@ def schedule_to_thneed(schedule, output_fn):
     setattr(prg.clprg, 'op_estimate', prg.op_estimate)
     setattr(prg.clprg, 'prg', prg.prg)
+    if getenv("VALIDTEST") == 1:
+      src = re.search(r"=.*\?.*?read_image", prg.prg)
+      if src is not None: raise Exception("Openpilot has valid checks!")
     global_size = prg.global_size + [1]*(3-len(prg.global_size))
     local_size = prg.local_size + [1]*(3-len(prg.local_size))
     cl_cache.append((prg.clprg, [[int(g*l) for g,l in zip(global_size, local_size)], local_size, *[x.realized._buf for x in args]]))
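
The VALIDTEST block added above scans each generated OpenCL kernel string for a ternary guard in front of an image load. A toy illustration of what the regex flags (both kernel strings here are made up, not real codegen output):

import re

plain   = "float4 r = read_imagef(data, smp, (int2)(gidx, gidy));"
guarded = "float4 r = (gidx < 128) ? read_imagef(data, smp, (int2)(gidx, gidy)) : (float4)(0.0f);"

assert re.search(r"=.*\?.*?read_image", plain) is None        # no valid check, kernel passes
assert re.search(r"=.*\?.*?read_image", guarded) is not None  # valid check found, compile2.py raises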

View File

@@ -1,2 +1,2 @@
 #!/bin/bash
-NOLOCALS=1 FLOAT16=1 DEBUGCL=1 IMAGE=2 GPU=1 python3 openpilot/compile.py
+NOLOCALS=1 FLOAT16=1 DEBUGCL=1 IMAGE=2 GPU=1 python3 openpilot/compile2.py

View File

@@ -496,6 +496,7 @@ class TestLinearizerOpts(unittest.TestCase):
   def test_padto_matmul(self):
     if not isinstance(Device[Device.DEFAULT], Compiled): self.skipTest("Only Compiled uses linearizer")
+    if Device.DEFAULT == "CUDA": self.skipTest("super slow on CUDA/triton")
     N = 17 * 17
     Tensor.manual_seed(289)
     a = Tensor.rand(N, N)

View File

@@ -3,8 +3,6 @@ from tinygrad.helpers import ImageDType, prod, IMAGE, getenv, dtypes, DEBUG, flatten
 # *** image Tensor function replacements ***
-from tinygrad.lazy import get_single_root
-
 def image_dot(self, w):
   # NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
   n1, n2 = len(self.shape), len(w.shape)
@@ -60,7 +58,6 @@ def image_conv2d(self, weight, bias=None, groups=1, stride=1, dilation=1, padding=0):
   # contiguous creates the image, and early realize static weights (TODO: test for the static weight)
   if IMAGE >= 2: x,w = x.cast(base_image_type((bs*iy, ix*groups*cin//4, 4))), w.cast(base_image_type((cout//4, H*W*cin, 4)))
   x, w = x.contiguous(), w.contiguous()
-  if getenv("PREREALIZE", 1) and get_single_root(w.lazydata).realized: w.realize()
   # expand out
   rcin_hi, rcin_lo = cin//4 if cin >= 4 else 1, 4 if cin >= 4 else 1
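
The NOTE in image_dot above states the identity this file is built on: an mxk @ kxn matmul is a 1x1 conv2d over a (1,k,m,1) input with (n,k,1,1) weights. A small standalone sketch checking that identity with plain Tensor ops (arbitrary shapes; this exercises the algebra, not the image kernels themselves):

import numpy as np
from tinygrad.tensor import Tensor

m, k, n = 4, 8, 3
x, w = Tensor.rand(m, k), Tensor.rand(k, n)

cx = x.transpose().reshape(1, k, m, 1)         # (batch=1, cin=k, H=m, W=1)
cw = w.transpose().reshape(n, k, 1, 1)         # (cout=n, cin=k, 1, 1)
out = cx.conv2d(cw).reshape(n, m).transpose()  # (1,n,m,1) -> (m,n)

np.testing.assert_allclose(out.numpy(), (x @ w).numpy(), atol=1e-5)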

View File

@@ -7,7 +7,7 @@ from collections import defaultdict
 from typing import Dict, List
 from tinygrad.ops import ScheduleItem, UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, BufferOps, TernaryOps, Op, OpType, LazyOp
 from tinygrad.helpers import GRAPH, GRAPHPATH, DEBUG, GlobalCounters, getenv, dedup
-from tinygrad.codegen.linearizer import UOps
+from tinygrad.codegen.linearizer import UOps, UOp
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.symbolic import NumNode
@@ -107,7 +107,7 @@ def _tree(lazydata, prefix=""):
 def print_tree(lazydata:LazyOp): print("\n".join([f"{str(i).rjust(3)} {s}" for i,s in enumerate(_tree(lazydata))]))
-def graph_uops(uops):
+def graph_uops(uops:List[UOp]):
   colors = {UOps.ALU: "#ffffc0", UOps.LOAD: "#ffc0c0", UOps.STORE: "#c0ffc0", UOps.SPECIAL: "#c0c0ff", UOps.CONST: "#e0e0e0",
             UOps.DEFINE_GLOBAL: "#ffe0b0", UOps.DEFINE_LOCAL: "#ffe0d0", UOps.DEFINE_ACC: "#f0ffe0",
             UOps.LOOP: "#c8a0e0", UOps.PHI: "#e0ffc0", UOps.BARRIER: "#ff8080", UOps.IF: "#c8b0c0"}

View File

@@ -20,12 +20,12 @@ class TinyJit(Generic[ReturnType]):
     self.cnt: int = 0
     self.ret: Optional[ReturnType] = None
     self.expected_vals: Optional[Tuple[Variable, ...]] = None
-    self.expected_sts_dtype: Optional[Tuple[Tuple[ShapeTracker, DType], ...]] = None
+    self.expected_name_sts_dtype: Optional[Tuple[Tuple[Union[int, str], ShapeTracker, DType], ...]] = None
   @property
   def jit_cache(self) -> List[JitItem]: return self.jit_fxn.jit_cache if self.jit_fxn else []
   @property
-  def input_replace(self) -> Dict[Tuple[int, int], Union[int, str]]: return self.jit_fxn.input_replace if self.jit_fxn else {}
+  def input_replace(self) -> Dict[Tuple[int, int], int]: return self.jit_fxn.input_replace if self.jit_fxn else {}
   # add support for instance methods
   def __get__(self, obj, objtype): return functools.partial(self.__call__, obj)
@@ -33,11 +33,11 @@
   def __call__(self, *args, **kwargs) -> ReturnType:
     # all inputs are realized
     input_tensors: Dict[Union[int, str], Tensor] = {cast(Union[int, str], k):v.realize() for k,v in itertools.chain(enumerate(args), kwargs.items()) if v.__class__ is Tensor}
-    expected_sts_dtype = tuple([(v.lazydata.st.unbind(), v.dtype) for v in input_tensors.values()])
+    expected_name_sts_dtype = tuple([(k, v.lazydata.st.unbind(), v.dtype) for k,v in input_tensors.items()])
     # get rawbuffers
-    input_rawbuffers: Dict[Union[int, str], RawBuffer] = {k:cast(RawBuffer, v.lazydata.realized) for k,v in input_tensors.items()}
-    assert len(set(input_rawbuffers.values())) == len(input_rawbuffers), "duplicate inputs to JIT"
+    input_rawbuffers: List[RawBuffer] = [cast(RawBuffer, v.lazydata.realized) for v in input_tensors.values()]
+    assert len(set(input_rawbuffers)) == len(input_rawbuffers), "duplicate inputs to JIT"
     # get variables: they can either be in Tensors or passed in as arguments, and all must be bound. these are all global
     var_vals: Dict[Variable, int] = merge_dicts([arg.lazydata.st.var_vals for arg in input_tensors.values()] + [dict(x.unbind() for x in itertools.chain(args, kwargs.values()) if isinstance(x, Variable))])
@@ -45,11 +45,11 @@
     if self.cnt >= 2:
       assert self.expected_vals == expected_vals, "mismatch of var_vals"
-      assert self.expected_sts_dtype == expected_sts_dtype, f"mismatch of sts, expected {self.expected_sts_dtype} got {expected_sts_dtype}"
+      assert self.expected_name_sts_dtype == expected_name_sts_dtype, f"mismatch of sts, expected {self.expected_name_sts_dtype} got {expected_name_sts_dtype}"
       assert self.jit_fxn, "didn't get jitted?"
       self.jit_fxn(input_rawbuffers, var_vals, DEBUG>=2)
     elif self.cnt == 1:
-      self.expected_vals, self.expected_sts_dtype = expected_vals, expected_sts_dtype
+      self.expected_vals, self.expected_name_sts_dtype = expected_vals, expected_name_sts_dtype
       CacheCollector.start(var_vals)
       self.ret = self.fxn(*args, **kwargs)
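
Continuing the mul sketch from the top of this page: since the JIT now records a (name, ShapeTracker, dtype) triple per input, replaying with differently-shaped tensors should trip the new assert (hedged sketch; message text per the diff above):

a2, b2 = Tensor.randn(8, 8), Tensor.randn(8, 8)
try:
  mul(a2, b2)  # captured with (4,4) inputs, so expected_name_sts_dtype no longer matches
except AssertionError as e:
  print("JIT rejected mismatched inputs:", e)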

View File

@@ -76,7 +76,6 @@ def _replace_bufferops(op:LazyOp) -> Tuple[LazyOp, List[LazyBuffer]]:
 # **** lazy operations ****
-def get_single_root(root:LazyBuffer) -> LazyBuffer: return get_single_root(root.op.src[0]) if getattr(root, 'op', None) and len(root.op.src) == 1 and isinstance(root.op.src[0], LazyBuffer) else root
 def get_movementroot(root:LazyBuffer, allow_contiguous=False) -> LazyBuffer: return get_movementroot(cast(LazyBuffer, root.op.src[0]), allow_contiguous) if not root.realized and (root.optype == MovementOps or (root.op.op == LoadOps.CONTIGUOUS and allow_contiguous and root.op.src[0].st.contiguous)) else root
 def get_movementroot_contiguous(x:LazyBuffer) -> LazyBuffer: return get_movementroot_contiguous(cast(LazyBuffer, x.op.src[0])) if not x.realized and x.op.op == LoadOps.CONTIGUOUS else (get_movementroot(x, True) if x.optype == MovementOps and x.st.contiguous else x)

View File

@@ -37,7 +37,7 @@ class MemBuffer:
 @dataclass(frozen=True)
 class ConstBuffer:
-  val: Any
+  val: Union[int, float]
   dtype: DType
   st: ShapeTracker
@@ -152,22 +152,22 @@ class JitItem:
   rawbufs: List[Optional[RawBuffer]]
 class BatchExecutor:
-  def __init__(self, jit_cache: List[JitItem], input_rawbuffers: Dict[Union[int, str], RawBuffer], var_vals: Dict[Variable, int]):
+  def __init__(self, jit_cache: List[JitItem], input_rawbuffers: List[RawBuffer], var_vals: Dict[Variable, int]):
     self.jit_cache: List[JitItem] = jit_cache
-    self.input_replace: Dict[Tuple[int, int], Union[int, str]] = {}
+    self.input_replace: Dict[Tuple[int, int], int] = {}
     self.op_estimate, self.mem_estimate = NumNode(0), NumNode(0)
     for j,ji in enumerate(jit_cache):
       if isinstance(ji.prg, ASTRunner):  # TODO: this is just for world and needs to be refactored
         self.op_estimate += ji.prg.op_estimate
         self.mem_estimate += ji.prg.mem_estimate
       for i,a in enumerate(ji.rawbufs):
-        if a in [v for v in input_rawbuffers.values()]:
-          self.input_replace[(j,i)] = [k for k,v in input_rawbuffers.items() if v == a][0]
-    assert set(self.input_replace.values()) == set(input_rawbuffers.keys()), "some input tensors not found"
+        if a in input_rawbuffers:
+          self.input_replace[(j,i)] = input_rawbuffers.index(a)
+    assert len(set(self.input_replace.values())) == len(input_rawbuffers), "some input tensors not found"
     self.clear_jit_inputs()
-  def __call__(self, input_rawbuffers: Dict[Union[int, str], RawBuffer], var_vals: Dict[Variable, int], wait=False):
-    for (j,i),input_name in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_name]
+  def __call__(self, input_rawbuffers: List[RawBuffer], var_vals: Dict[Variable, int], wait=False):
+    for (j,i),input_idx in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_idx]
     for ji in self.jit_cache: ji.prg(ji.rawbufs, var_vals, jit=True)
     self.clear_jit_inputs()
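
The input_replace construction above is the heart of the change: integer positions into a flat buffer list replace the old name keys. A self-contained toy version of that scan (plain objects stand in for RawBuffers):

from typing import Dict, List, Tuple

def build_input_replace(cache_bufs: List[List[object]], inputs: List[object]) -> Dict[Tuple[int, int], int]:
  input_replace: Dict[Tuple[int, int], int] = {}
  for j, rawbufs in enumerate(cache_bufs):   # j: kernel's position in the jit cache
    for i, a in enumerate(rawbufs):          # i: buffer's position in that kernel's args
      if a in inputs: input_replace[(j, i)] = inputs.index(a)
  assert len(set(input_replace.values())) == len(inputs), "some input tensors not found"
  return input_replace

bufA, bufB = object(), object()
print(build_input_replace([[object(), bufA], [bufB, bufA]], [bufA, bufB]))
# -> {(0, 1): 0, (1, 0): 1, (1, 1): 0}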

View File

@@ -1,7 +1,7 @@
 # pip3 install pyobjc-framework-Metal pyobjc-framework-Cocoa pyobjc-framework-libdispatch
 import os, subprocess, pathlib, ctypes, tempfile
 import Metal, libdispatch
-from typing import List, Any, Tuple, Dict, Union, Set, cast
+from typing import List, Any, Tuple, Dict, Set, cast
 from tinygrad.codegen.kernel import LinearizerOptions
 from tinygrad.helpers import prod, getenv, DEBUG, DType, dtypes, diskcache, dedup
 from tinygrad.ops import Compiled, BatchExecutor, JitItem, CompiledASTRunner, update_stats
@@ -85,7 +85,7 @@ class MetalProgram:
     METAL.mtl_buffers_in_flight.append(command_buffer)
 class MetalBatchExecutor(BatchExecutor):
-  def __init__(self, jit_cache: List[JitItem], input_rawbuffers: Dict[Union[int, str], RawBuffer], var_vals: Dict[Variable, int]):
+  def __init__(self, jit_cache: List[JitItem], input_rawbuffers: List[RawBuffer], var_vals: Dict[Variable, int]):
     super().__init__(jit_cache, input_rawbuffers, var_vals)
     # create metal batch exec
@@ -127,12 +127,12 @@ class MetalBatchExecutor(BatchExecutor):
     self.command_buffer: Any = None
     self.int_buf_view = self.int_buf.buffer_view()  # TODO: this is metal syncing when it doesn't need to
-  def __call__(self, input_rawbuffers: Dict[Union[int, str], RawBuffer], var_vals: Dict[Variable, int], wait=False):
+  def __call__(self, input_rawbuffers: List[RawBuffer], var_vals: Dict[Variable, int], wait=False):
     # NOTE: you at least can't update the ints if this is running
     if self.command_buffer is not None and self.command_buffer in METAL.mtl_buffers_in_flight: self.command_buffer.waitUntilCompleted()
-    all_read_resources = self.read_resources + [x._buf for x in input_rawbuffers.values()]
-    for (j,i),input_name in self.input_replace.items():
-      self.icb.indirectComputeCommandAtIndex_(j).setKernelBuffer_offset_atIndex_(input_rawbuffers[input_name]._buf, 0, i)
+    all_read_resources = self.read_resources + [x._buf for x in input_rawbuffers]
+    for (j,i),input_idx in self.input_replace.items():
+      self.icb.indirectComputeCommandAtIndex_(j).setKernelBuffer_offset_atIndex_(input_rawbuffers[input_idx]._buf, 0, i)
     for j in self.input_has_variable_dims:
       global_size, local_size = cast(CompiledASTRunner, self.jit_cache[j].prg).launch_dims(var_vals)
      self.icb.indirectComputeCommandAtIndex_(j).concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))