Add type hints, small cleanups (#21080)

* improve tools.lib.kbhit and tools.sim.lib.keyboard_ctrl

* unpack more efficiently

* minor improvements

* agnos.py match spec better

* manual_ctrl test missing queue arg

* fix incorrect type annotation

* queues are generic

* varname reuse resulting in incorrect type inference

* bytes().hex() rather than bytes.hex(bytes())

* a bit of type hinting stuff
pull/21120/head
Josh Smith 2021-06-03 06:21:04 -04:00 committed by GitHub
parent 8220056252
commit 77321dbac4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 130 additions and 122 deletions

View File

@ -3,6 +3,7 @@ import gc
import os
import time
import multiprocessing
from typing import Optional
from common.clock import sec_since_boot # pylint: disable=no-name-in-module, import-error
from selfdrive.hardware import PC, TICI
@ -31,49 +32,49 @@ class Priority:
CTRL_HIGH = 53
def set_realtime_priority(level):
def set_realtime_priority(level: int) -> None:
if not PC:
os.sched_setscheduler(0, os.SCHED_FIFO, os.sched_param(level))
os.sched_setscheduler(0, os.SCHED_FIFO, os.sched_param(level)) # type: ignore[attr-defined]
def set_core_affinity(core):
def set_core_affinity(core: int) -> None:
if not PC:
os.sched_setaffinity(0, [core,])
def config_realtime_process(core, priority):
def config_realtime_process(core: int, priority: int) -> None:
gc.disable()
set_realtime_priority(priority)
set_core_affinity(core)
class Ratekeeper():
def __init__(self, rate, print_delay_threshold=0.):
class Ratekeeper:
def __init__(self, rate: int, print_delay_threshold: Optional[float] = 0.0) -> None:
"""Rate in Hz for ratekeeping. print_delay_threshold must be nonnegative."""
self._interval = 1. / rate
self._next_frame_time = sec_since_boot() + self._interval
self._print_delay_threshold = print_delay_threshold
self._frame = 0
self._remaining = 0
self._remaining = 0.0
self._process_name = multiprocessing.current_process().name
@property
def frame(self):
def frame(self) -> int:
return self._frame
@property
def remaining(self):
def remaining(self) -> float:
return self._remaining
# Maintain loop rate by calling this at the end of each loop
def keep_time(self):
def keep_time(self) -> bool:
lagged = self.monitor_time()
if self._remaining > 0:
time.sleep(self._remaining)
return lagged
# this only monitor the cumulative lag, but does not enforce a rate
def monitor_time(self):
def monitor_time(self) -> bool:
lagged = False
remaining = self._next_frame_time - sec_since_boot()
self._next_frame_time += self._interval

View File

@ -1,6 +1,7 @@
import time
from collections import defaultdict
from functools import partial
from typing import Optional
import cereal.messaging as messaging
from selfdrive.swaglog import cloudlog
@ -8,7 +9,7 @@ from selfdrive.boardd.boardd import can_list_to_can_capnp
from panda.python.uds import CanClient, IsoTpMessage, FUNCTIONAL_ADDRS, get_rx_addr_for_tx_addr
class IsoTpParallelQuery():
class IsoTpParallelQuery:
def __init__(self, sendcan, logcan, bus, addrs, request, response, response_offset=0x8, functional_addr=False, debug=False):
self.sendcan = sendcan
self.logcan = logcan
@ -103,7 +104,7 @@ class IsoTpParallelQuery():
break
for tx_addr, msg in msgs.items():
dat = msg.recv()
dat: Optional[bytes] = msg.recv()
if not dat:
continue
@ -121,7 +122,7 @@ class IsoTpParallelQuery():
request_done[tx_addr] = True
else:
request_done[tx_addr] = True
cloudlog.warning(f"iso-tp query bad response: 0x{bytes.hex(dat)}")
cloudlog.warning(f"iso-tp query bad response: 0x{dat.hex()}")
if time.time() - start_time > timeout:
break

View File

@ -5,7 +5,7 @@ from selfdrive.version import version
import sentry_sdk
from sentry_sdk.integrations.threading import ThreadingIntegration
def capture_exception(*args, **kwargs):
def capture_exception(*args, **kwargs) -> None:
cloudlog.error("crash", exc_info=kwargs.get('exc_info', 1))
try:
@ -14,14 +14,14 @@ def capture_exception(*args, **kwargs):
except Exception:
cloudlog.exception("sentry exception")
def bind_user(**kwargs):
def bind_user(**kwargs) -> None:
sentry_sdk.set_user(kwargs)
def bind_extra(**kwargs):
def bind_extra(**kwargs) -> None:
for k, v in kwargs.items():
sentry_sdk.set_tag(k, v)
def init():
def init() -> None:
sentry_sdk.init("https://a8dc76b5bfb34908a601d67e2aa8bcf9@o33823.ingest.sentry.io/77924",
default_integrations=False, integrations=[ThreadingIntegration(propagate_hub=True)],
release=version)

View File

@ -6,12 +6,13 @@ import requests
import struct
import subprocess
import os
from typing import Generator
from common.spinner import Spinner
class StreamingDecompressor:
def __init__(self, url):
def __init__(self, url: str) -> None:
self.buf = b""
self.req = requests.get(url, stream=True, headers={'Accept-Encoding': None})
@ -20,7 +21,7 @@ class StreamingDecompressor:
self.eof = False
self.sha256 = hashlib.sha256()
def read(self, length):
def read(self, length: int) -> bytes:
while len(self.buf) < length:
self.req.raise_for_status()
@ -38,8 +39,9 @@ class StreamingDecompressor:
self.sha256.update(result)
return result
def unsparsify(f):
SPARSE_CHUNK_FMT = struct.Struct('H2xI4x')
def unsparsify(f: StreamingDecompressor) -> Generator[bytes, None, None]:
# https://source.android.com/devices/bootloader/images#sparse-format
magic = struct.unpack("I", f.read(4))[0]
assert(magic == 0xed26ff3a)
@ -48,20 +50,16 @@ def unsparsify(f):
minor = struct.unpack("H", f.read(2))[0]
assert(major == 1 and minor == 0)
# Header sizes
_ = struct.unpack("H", f.read(2))[0]
_ = struct.unpack("H", f.read(2))[0]
f.read(2) # file header size
f.read(2) # chunk header size
block_sz = struct.unpack("I", f.read(4))[0]
_ = struct.unpack("I", f.read(4))[0]
f.read(4) # total blocks
num_chunks = struct.unpack("I", f.read(4))[0]
_ = struct.unpack("I", f.read(4))[0]
f.read(4) # crc checksum
for _ in range(num_chunks):
chunk_type = struct.unpack("H", f.read(2))[0]
_ = struct.unpack("H", f.read(2))[0]
out_blocks = struct.unpack("I", f.read(4))[0]
_ = struct.unpack("I", f.read(4))[0]
chunk_type, out_blocks = SPARSE_CHUNK_FMT.unpack(f.read(12))
if chunk_type == 0xcac1: # Raw
# TODO: yield in smaller chunks. Yielding only block_sz is too slow. Largest observed data chunk is 252 MB.

View File

@ -1,11 +1,13 @@
#!/usr/bin/env python3
import zmq
from typing import NoReturn
import cereal.messaging as messaging
from common.logging_extra import SwagLogFileFormatter
from selfdrive.swaglog import get_file_handler
def main():
def main() -> NoReturn:
log_handler = get_file_handler()
log_handler.setFormatter(SwagLogFileFormatter(None))
log_level = 20 # logging.INFO

View File

@ -10,7 +10,7 @@ from selfdrive.swaglog import cloudlog
PANDA_FW_FN = os.path.join(PANDA_BASEDIR, "board", "obj", "panda.bin.signed")
def get_expected_signature():
def get_expected_signature() -> bytes:
try:
return Panda.get_signature_from_firmware(PANDA_FW_FN)
except Exception:
@ -18,7 +18,7 @@ def get_expected_signature():
return b""
def update_panda():
def update_panda() -> None:
panda = None
panda_dfu = None
@ -81,7 +81,7 @@ def update_panda():
panda.reset()
def main():
def main() -> None:
update_panda()
os.chdir(os.path.join(BASEDIR, "selfdrive/boardd"))

View File

@ -1,6 +1,8 @@
#!/usr/bin/env python3
import os
import time
from typing import NoReturn
from common.realtime import set_core_affinity, set_realtime_priority
# RT shield - ensure CPU 3 always remains available for RT processes
@ -8,7 +10,7 @@ from common.realtime import set_core_affinity, set_realtime_priority
# get scheduled onto CPU 3, but it's always preemptible by realtime
# openpilot processes
def main():
def main() -> NoReturn:
set_core_affinity(int(os.getenv("CORE", "3")))
set_realtime_priority(1)
@ -17,4 +19,3 @@ def main():
if __name__ == "__main__":
main()

View File

@ -150,7 +150,7 @@ class Plant():
# lead car
self.distance_lead, self.distance_lead_prev = distance_lead , distance_lead
self.rk = Ratekeeper(rate, print_delay_threshold=100)
self.rk = Ratekeeper(rate, print_delay_threshold=100.0)
self.ts = 1./rate
self.cp = get_car_can_parser()

View File

@ -4,40 +4,45 @@ import termios
import atexit
from select import select
STDIN_FD = sys.stdin.fileno()
class KBHit:
def __init__(self):
'''Creates a KBHit object that you can call to do various keyboard things.
def __init__(self) -> None:
''' Creates a KBHit object that you can call to do various keyboard things.
'''
self.set_kbhit_terminal()
def set_kbhit_terminal(self):
def set_kbhit_terminal(self) -> None:
''' Save old terminal settings for closure, remove ICANON & ECHO flags.
'''
# Save the terminal settings
self.fd = sys.stdin.fileno()
self.new_term = termios.tcgetattr(self.fd)
self.old_term = termios.tcgetattr(self.fd)
self.old_term = termios.tcgetattr(STDIN_FD)
self.new_term = self.old_term.copy()
# New terminal setting unbuffered
self.new_term[3] = (self.new_term[3] & ~termios.ICANON & ~termios.ECHO)
termios.tcsetattr(self.fd, termios.TCSAFLUSH, self.new_term)
self.new_term[3] &= ~(termios.ICANON | termios.ECHO) # type: ignore
termios.tcsetattr(STDIN_FD, termios.TCSAFLUSH, self.new_term)
# Support normal-terminal reset at exit
atexit.register(self.set_normal_term)
def set_normal_term(self):
''' Resets to normal terminal. On Windows this is a no-op.
def set_normal_term(self) -> None:
''' Resets to normal terminal. On Windows this is a no-op.
'''
termios.tcsetattr(self.fd, termios.TCSAFLUSH, self.old_term)
termios.tcsetattr(STDIN_FD, termios.TCSAFLUSH, self.old_term)
def getch(self):
@staticmethod
def getch() -> str:
''' Returns a keyboard character after kbhit() has been called.
Should not be called in the same program as getarrow().
'''
return sys.stdin.read(1)
def getarrow(self):
@staticmethod
def getarrow() -> int:
''' Returns an arrow-key code after kbhit() has been called. Codes are
0 : up
1 : right
@ -49,13 +54,13 @@ class KBHit:
c = sys.stdin.read(3)[2]
vals = [65, 67, 66, 68]
return vals.index(ord(c.decode('utf-8')))
return vals.index(ord(c))
def kbhit(self):
@staticmethod
def kbhit():
''' Returns True if keyboard character was hit, False otherwise.
'''
dr, _, _ = select([sys.stdin], [], [], 0)
return dr != []
return select([sys.stdin], [], [], 0)[0] != []
# Test
@ -69,7 +74,7 @@ if __name__ == "__main__":
if kb.kbhit():
c = kb.getch()
if ord(c) == 27: # ESC
if c == '\x1b': # ESC
break
print(c)

View File

@ -9,7 +9,7 @@ class _FrameReaderDict(dict):
if cache_paths is None:
cache_paths = {}
if not isinstance(cache_paths, dict):
cache_paths = {k: v for k, v in enumerate(cache_paths)}
cache_paths = dict(enumerate(cache_paths))
self._camera_paths = camera_paths
self._cache_paths = cache_paths

View File

@ -63,7 +63,7 @@ if __name__ == "__main__":
msg = messaging.recv_sock(s)
#msg = messaging.recv_one_or_none(s)
if msg is not None:
x[i] = np.append(x[i], getattr(msg, 'logMonoTime') / float(1e9))
x[i] = np.append(x[i], getattr(msg, 'logMonoTime') / 1e9)
x[i] = np.delete(x[i], 0)
y[i] = np.append(y[i], recursive_getattr(msg, subs_name[i]))
y[i] = np.delete(y[i], 0)

View File

@ -248,9 +248,9 @@ def bridge(q):
# 3. Send current carstate to op via can
cruise_button = 0
throttle_out = steer_out = brake_out = 0
throttle_out = steer_out = brake_out = 0.0
throttle_op = steer_op = brake_op = 0
throttle_manual = steer_manual = brake_manual = 0
throttle_manual = steer_manual = brake_manual = 0.0
# --------------Step 1-------------------------------
if not q.empty():
@ -259,24 +259,24 @@ def bridge(q):
if m[0] == "steer":
steer_manual = float(m[1])
is_openpilot_engaged = False
if m[0] == "throttle":
elif m[0] == "throttle":
throttle_manual = float(m[1])
is_openpilot_engaged = False
if m[0] == "brake":
elif m[0] == "brake":
brake_manual = float(m[1])
is_openpilot_engaged = False
if m[0] == "reverse":
elif m[0] == "reverse":
#in_reverse = not in_reverse
cruise_button = CruiseButtons.CANCEL
is_openpilot_engaged = False
if m[0] == "cruise":
elif m[0] == "cruise":
if m[1] == "down":
cruise_button = CruiseButtons.DECEL_SET
is_openpilot_engaged = True
if m[1] == "up":
elif m[1] == "up":
cruise_button = CruiseButtons.RES_ACCEL
is_openpilot_engaged = True
if m[1] == "cancel":
elif m[1] == "cancel":
cruise_button = CruiseButtons.CANCEL
is_openpilot_engaged = False

View File

@ -3,7 +3,7 @@ import termios
import time
from termios import (BRKINT, CS8, CSIZE, ECHO, ICANON, ICRNL, IEXTEN, INPCK,
ISTRIP, IXON, PARENB, VMIN, VTIME)
from typing import Any
from typing import NoReturn
# Indexes for termios list.
IFLAG = 0
@ -14,55 +14,56 @@ ISPEED = 4
OSPEED = 5
CC = 6
def getch():
fd = sys.stdin.fileno()
old_settings = termios.tcgetattr(fd)
STDIN_FD = sys.stdin.fileno()
def getch() -> str:
old_settings = termios.tcgetattr(STDIN_FD)
try:
# set
mode = termios.tcgetattr(fd)
mode[IFLAG] = mode[IFLAG] & ~(BRKINT | ICRNL | INPCK | ISTRIP | IXON)
#mode[OFLAG] = mode[OFLAG] & ~(OPOST)
mode[CFLAG] = mode[CFLAG] & ~(CSIZE | PARENB)
mode[CFLAG] = mode[CFLAG] | CS8
mode[LFLAG] = mode[LFLAG] & ~(ECHO | ICANON | IEXTEN)
mode = old_settings.copy()
mode[IFLAG] &= ~(BRKINT | ICRNL | INPCK | ISTRIP | IXON)
#mode[OFLAG] &= ~(OPOST)
mode[CFLAG] &= ~(CSIZE | PARENB)
mode[CFLAG] |= CS8
mode[LFLAG] &= ~(ECHO | ICANON | IEXTEN)
mode[CC][VMIN] = 1
mode[CC][VTIME] = 0
termios.tcsetattr(fd, termios.TCSAFLUSH, mode)
termios.tcsetattr(STDIN_FD, termios.TCSAFLUSH, mode)
ch = sys.stdin.read(1)
finally:
termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
termios.tcsetattr(STDIN_FD, termios.TCSADRAIN, old_settings)
return ch
def keyboard_poll_thread(q):
def keyboard_poll_thread(q: 'Queue[str]') -> NoReturn:
while True:
c = getch()
# print("got %s" % c)
if c == '1':
q.put(str("cruise_up"))
if c == '2':
q.put(str("cruise_down"))
if c == '3':
q.put(str("cruise_cancel"))
if c == 'w':
q.put(str("throttle_%f" % 1.0))
if c == 'a':
q.put(str("steer_%f" % 0.15))
if c == 's':
q.put(str("brake_%f" % 1.0))
if c == 'd':
q.put(str("steer_%f" % -0.15))
if c == 'q':
q.put("cruise_up")
elif c == '2':
q.put("cruise_down")
elif c == '3':
q.put("cruise_cancel")
elif c == 'w':
q.put("throttle_%f" % 1.0)
elif c == 'a':
q.put("steer_%f" % 0.15)
elif c == 's':
q.put("brake_%f" % 1.0)
elif c == 'd':
q.put("steer_%f" % -0.15)
elif c == 'q':
exit(0)
def test(q):
while 1:
print("hello")
time.sleep(1.0)
def test(q: 'Queue[str]') -> NoReturn:
while True:
print([q.get_nowait() for _ in range(q.qsize())] or None)
time.sleep(0.25)
if __name__ == '__main__':
from multiprocessing import Process, Queue
q : Any = Queue()
q: Queue[str] = Queue()
p = Process(target=test, args=(q,))
p.daemon = True
p.start()

View File

@ -4,6 +4,7 @@ import array
import os
import struct
from fcntl import ioctl
from typing import NoReturn
# Iterate over the joystick devices.
print('Available devices:')
@ -90,7 +91,7 @@ button_names = {
axis_map = []
button_map = []
def wheel_poll_thread(q):
def wheel_poll_thread(q: 'Queue[str]') -> NoReturn:
# Open the joystick device.
fn = '/dev/input/js0'
print('Opening %s...' % fn)
@ -116,8 +117,8 @@ def wheel_poll_thread(q):
buf = array.array('B', [0] * 0x40)
ioctl(jsdev, 0x80406a32, buf) # JSIOCGAXMAP
for axis in buf[:num_axes]:
axis_name = axis_names.get(axis, 'unknown(0x%02x)' % axis)
for _axis in buf[:num_axes]:
axis_name = axis_names.get(_axis, 'unknown(0x%02x)' % _axis)
axis_map.append(axis_name)
axis_states[axis_name] = 0.0
@ -143,7 +144,7 @@ def wheel_poll_thread(q):
while True:
evbuf = jsdev.read(8)
_, value, mtype, number = struct.unpack('IhBB', evbuf)
value, mtype, number = struct.unpack('4xhBB', evbuf)
# print(mtype, number, value)
if mtype & 0x02: # wheel & paddles
axis = axis_map[number]
@ -152,38 +153,36 @@ def wheel_poll_thread(q):
fvalue = value / 32767.0
axis_states[axis] = fvalue
normalized = (1 - fvalue) * 50
q.put(str("throttle_%f" % normalized))
q.put("throttle_%f" % normalized)
if axis == "rz": # brake
elif axis == "rz": # brake
fvalue = value / 32767.0
axis_states[axis] = fvalue
normalized = (1 - fvalue) * 50
q.put(str("brake_%f" % normalized))
q.put("brake_%f" % normalized)
if axis == "x": # steer angle
elif axis == "x": # steer angle
fvalue = value / 32767.0
axis_states[axis] = fvalue
normalized = fvalue
q.put(str("steer_%f" % normalized))
q.put("steer_%f" % normalized)
if mtype & 0x01: # buttons
if number in [0, 19]: # X
if value == 1: # press down
q.put(str("cruise_down"))
elif mtype & 0x01: # buttons
if value == 1: # press down
if number in [0, 19]: # X
q.put("cruise_down")
if number in [3, 18]: # triangle
if value == 1: # press down
q.put(str("cruise_up"))
elif number in [3, 18]: # triangle
q.put("cruise_up")
if number in [1, 6]: # square
if value == 1: # press down
q.put(str("cruise_cancel"))
elif number in [1, 6]: # square
q.put("cruise_cancel")
if number in [10, 21]: # R3
if value == 1: # press down
q.put(str("reverse_switch"))
elif number in [10, 21]: # R3
q.put("reverse_switch")
if __name__ == '__main__':
from multiprocessing import Process
p = Process(target=wheel_poll_thread)
from multiprocessing import Process, Queue
q: Queue[str] = Queue()
p = Process(target=wheel_poll_thread, args=(q,))
p.start()