model: mse err from 0.02-> 0.000056 (#23891)

* mse err from 0.028070712 -> 5.8073703e-05

* build with weights fixup

* need thneed lib also

* don't break for binaries

* static analysis says i need init

* check the bias

* load_dlc_weights

* nicer scons

* tested scons

* fix static

* pylint issue

* new ref

* a few more asserts

Co-authored-by: Harald Schafer <harald.the.engineer@gmail.com>
pull/23899/head
George Hotz 2022-03-02 20:52:17 -08:00 committed by GitHub
parent 77fd64ee30
commit 8d6f49aecf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 188 additions and 3 deletions

View File

@ -424,6 +424,7 @@ selfdrive/modeld/transforms/transform.cc
selfdrive/modeld/transforms/transform.h
selfdrive/modeld/transforms/transform.cl
selfdrive/modeld/thneed/*.py
selfdrive/modeld/thneed/thneed.*
selfdrive/modeld/thneed/serialize.cc
selfdrive/modeld/thneed/compile.cc

View File

@ -67,14 +67,22 @@ common_model = lenv.Object(common_src)
if use_thneed and arch in ("aarch64", "larch64"):
fn = File("#models/supercombo").abspath
compiler = lenv.Program('thneed/compile', ["thneed/compile.cc"]+common_model, LIBS=libs)
cmd = f"cd {Dir('.').abspath} && {compiler[0].abspath} {fn}.dlc {fn}.thneed --binary"
cmd = f"cd {Dir('.').abspath} && {compiler[0].abspath} {fn}.dlc {fn}_badweights.thneed --binary"
lib_paths = ':'.join(Dir(p).abspath for p in lenv["LIBPATH"])
kernel_path = os.path.join(Dir('.').abspath, "thneed", "kernels")
cenv = Environment(ENV={'LD_LIBRARY_PATH': f"{lib_paths}:{lenv['ENV']['LD_LIBRARY_PATH']}", 'KERNEL_PATH': kernel_path})
kernels = [os.path.join(kernel_path, x) for x in os.listdir(kernel_path) if x.endswith(".cl")]
cenv.Command(fn + ".thneed", [fn + ".dlc", kernels, compiler], cmd)
cenv.Command(fn + "_badweights.thneed", [fn + ".dlc", kernels, compiler], cmd)
from selfdrive.modeld.thneed.weights_fixup import weights_fixup
def weights_fixup_action(target, source, env):
weights_fixup(target[0].abspath, source[0].abspath, source[1].abspath)
env = Environment(BUILDERS = {'WeightFixup' : Builder(action = weights_fixup_action)})
env.WeightFixup(target=fn + ".thneed", source=[fn+"_badweights.thneed", fn+".dlc"])
lenv.Program('_dmonitoringmodeld', [
"dmonitoringmodeld.cc",

View File

@ -0,0 +1,31 @@
import struct, json
def load_thneed(fn):
with open(fn, "rb") as f:
json_len = struct.unpack("I", f.read(4))[0]
jdat = json.loads(f.read(json_len).decode('latin_1'))
weights = f.read()
ptr = 0
for o in jdat['objects']:
if o['needs_load']:
nptr = ptr + o['size']
o['data'] = weights[ptr:nptr]
ptr = nptr
for o in jdat['binaries']:
nptr = ptr + o['length']
o['data'] = weights[ptr:nptr]
ptr = nptr
return jdat
def save_thneed(jdat, fn):
new_weights = []
for o in jdat['objects'] + jdat['binaries']:
if 'data' in o:
new_weights.append(o['data'])
del o['data']
new_weights = b''.join(new_weights)
with open(fn, "wb") as f:
j = json.dumps(jdat, ensure_ascii=False).encode('latin_1')
f.write(struct.pack("I", len(j)))
f.write(j)
f.write(new_weights)

View File

@ -0,0 +1,145 @@
#!/usr/bin/env python3
import os
import struct
import zipfile
import numpy as np
from tqdm import tqdm
from common.basedir import BASEDIR
from selfdrive.modeld.thneed.lib import load_thneed, save_thneed
# this is junk code, but it doesn't have deps
def load_dlc_weights(fn):
archive = zipfile.ZipFile(fn, 'r')
dlc_params = archive.read("model.params")
def extract(rdat):
idx = rdat.find(b"\x00\x00\x00\x09\x04\x00\x00\x00")
rdat = rdat[idx+8:]
ll = struct.unpack("I", rdat[0:4])[0]
buf = np.frombuffer(rdat[4:4+ll*4], dtype=np.float32)
rdat = rdat[4+ll*4:]
dims = struct.unpack("I", rdat[0:4])[0]
buf = buf.reshape(struct.unpack("I"*dims, rdat[4:4+dims*4]))
if len(buf.shape) == 4:
buf = np.transpose(buf, (3,2,0,1))
return buf
def parse(tdat):
ll = struct.unpack("I", tdat[0:4])[0] + 4
return (None, [extract(tdat[0:]), extract(tdat[ll:])])
ptr = 0x20
def r4():
nonlocal ptr
ret = struct.unpack("I", dlc_params[ptr:ptr+4])[0]
ptr += 4
return ret
ranges = []
cnt = r4()
for _ in range(cnt):
o = r4() + ptr
# the header is 0xC
plen, is_4, is_2 = struct.unpack("III", dlc_params[o:o+0xC])
assert is_4 == 4 and is_2 == 2
ranges.append((o+0xC, o+plen+0xC))
ranges = sorted(ranges, reverse=True)
return [parse(dlc_params[s:e]) for s,e in ranges]
# this won't run on device without onnx
def load_onnx_weights(fn):
import onnx
from onnx import numpy_helper
model = onnx.load(fn)
graph = model.graph # pylint: disable=maybe-no-member
init = {x.name:x for x in graph.initializer}
onnx_layers = []
for node in graph.node:
#print(node.name, node.op_type, node.input, node.output)
vals = []
for inp in node.input:
if inp in init:
vals.append(numpy_helper.to_array(init[inp]))
if len(vals) > 0:
onnx_layers.append((node.name, vals))
return onnx_layers
def weights_fixup(target, source_thneed, dlc):
#onnx_layers = load_onnx_weights(os.path.join(BASEDIR, "models/supercombo.onnx"))
onnx_layers = load_dlc_weights(dlc)
jdat = load_thneed(source_thneed)
bufs = {}
for o in jdat['objects']:
bufs[o['id']] = o
thneed_layers = []
for k in jdat['kernels']:
#print(k['name'])
vals = []
for a in k['args']:
if a in bufs:
o = bufs[a]
if o['needs_load'] or ('buffer_id' in o and bufs[o['buffer_id']]['needs_load']):
#print(" ", o['arg_type'])
vals.append(o)
if len(vals) > 0:
thneed_layers.append((k['name'], vals))
assert len(thneed_layers) == len(onnx_layers)
# fix up weights
for tl, ol in tqdm(zip(thneed_layers, onnx_layers), total=len(thneed_layers)):
#print(tl[0], ol[0])
assert len(tl[1]) == len(ol[1])
for o, onnx_weight in zip(tl[1], ol[1]):
if o['arg_type'] == "image2d_t":
obuf = bufs[o['buffer_id']]
saved_weights = np.frombuffer(obuf['data'], dtype=np.float16).reshape(o['height'], o['row_pitch']//2)
if len(onnx_weight.shape) == 4:
# convolution
oc,ic,ch,cw = onnx_weight.shape
if 'depthwise' in tl[0]:
assert ic == 1
weights = np.transpose(onnx_weight.reshape(oc//4,4,ch,cw), (0,2,3,1)).reshape(o['height'], o['width']*4)
else:
weights = np.transpose(onnx_weight.reshape(oc//4,4,ic//4,4,ch,cw), (0,4,2,5,1,3)).reshape(o['height'], o['width']*4)
else:
# fc_Wtx
weights = onnx_weight
new_weights = np.zeros((o['height'], o['row_pitch']//2), dtype=np.float32)
new_weights[:, :weights.shape[1]] = weights
# weights shouldn't be too far off
err = np.mean((saved_weights.astype(np.float32) - new_weights)**2)
assert err < 1e-3
rerr = np.mean(np.abs((saved_weights.astype(np.float32) - new_weights)/(new_weights+1e-12)))
assert rerr < 0.5
# fix should improve things
fixed_err = np.mean((new_weights.astype(np.float16).astype(np.float32) - new_weights)**2)
assert (err/fixed_err) >= 1
#print(" ", o['size'], onnx_weight.shape, o['row_pitch'], o['width'], o['height'], "err %.2fx better" % (err/fixed_err))
obuf['data'] = new_weights.astype(np.float16).tobytes()
elif o['arg_type'] == "float*":
# unconverted floats are correct
new_weights = np.zeros(o['size']//4, dtype=np.float32)
new_weights[:onnx_weight.shape[0]] = onnx_weight
assert new_weights.tobytes() == o['data']
#print(" ", o['size'], onnx_weight.shape)
save_thneed(jdat, target)
if __name__ == "__main__":
weights_fixup(os.path.join(BASEDIR, "models/supercombo_fixed.thneed"),
os.path.join(BASEDIR, "models/supercombo.thneed"),
os.path.join(BASEDIR, "models/supercombo.dlc"))

View File

@ -1 +1 @@
19720e79b1c5136a882efd689651d9044e2e2007
15821a7f867f6b497a17e8a36c9d42ad548acacd