Fix toyota_eps_factor.py script (#2647)
* fix toyota_eps_factor.py script, don't use samples where wheel is touched * caps mistake * Hope this is correct, don't use CS.steeringPressed * revert sorting now that we don't use CSalbatross
parent
ee43eb552b
commit
460e4dc3b0
|
@ -2,22 +2,24 @@
|
|||
import sys
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from sklearn import linear_model # pylint: disable=import-error
|
||||
from sklearn import linear_model # pylint: disable=import-error
|
||||
from selfdrive.car.toyota.values import STEER_THRESHOLD
|
||||
|
||||
from tools.lib.route import Route
|
||||
from tools.lib.logreader import MultiLogIterator
|
||||
|
||||
MIN_SAMPLES = 30 * 100
|
||||
|
||||
MIN_SAMPLES = 30*100
|
||||
|
||||
def to_signed(n, bits):
|
||||
if n >= (1 << max((bits - 1), 0)):
|
||||
n = n - (1 << max(bits, 0))
|
||||
return n
|
||||
|
||||
def get_eps_factor(lr, plot=False):
|
||||
|
||||
def get_eps_factor(lr, plot=False):
|
||||
engaged = False
|
||||
steering_pressed = False
|
||||
torque_cmd, eps_torque = None, None
|
||||
cmds, eps = [], []
|
||||
|
||||
|
@ -31,8 +33,9 @@ def get_eps_factor(lr, plot=False):
|
|||
torque_cmd = to_signed((m.dat[1] << 8) | m.dat[2], 16)
|
||||
elif m.address == 0x260 and m.src == 0:
|
||||
eps_torque = to_signed((m.dat[5] << 8) | m.dat[6], 16)
|
||||
steering_pressed = abs(to_signed((m.dat[1] << 8) | m.dat[2], 16)) > STEER_THRESHOLD
|
||||
|
||||
if engaged and torque_cmd is not None and eps_torque is not None:
|
||||
if engaged and torque_cmd is not None and eps_torque is not None and not steering_pressed:
|
||||
cmds.append(torque_cmd)
|
||||
eps.append(eps_torque)
|
||||
else:
|
||||
|
@ -45,14 +48,15 @@ def get_eps_factor(lr, plot=False):
|
|||
|
||||
lm = linear_model.LinearRegression(fit_intercept=False)
|
||||
lm.fit(np.array(cmds).reshape(-1, 1), eps)
|
||||
scale_factor = 1./lm.coef_[0]
|
||||
scale_factor = 1. / lm.coef_[0]
|
||||
|
||||
if plot:
|
||||
plt.plot(np.array(eps)*scale_factor)
|
||||
plt.plot(np.array(eps) * scale_factor)
|
||||
plt.plot(cmds)
|
||||
plt.show()
|
||||
return scale_factor
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
r = Route(sys.argv[1])
|
||||
lr = MultiLogIterator(r.log_paths(), wraparound=False)
|
||||
|
|
Loading…
Reference in New Issue