diff --git a/selfdrive/athena/__init__.py b/selfdrive/athena/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/selfdrive/athena/athenad.py b/selfdrive/athena/athenad.py new file mode 100755 index 000000000..bf3836d07 --- /dev/null +++ b/selfdrive/athena/athenad.py @@ -0,0 +1,322 @@ +#!/usr/bin/env python3.7 +import json +import os +import hashlib +import io +import random +import select +import socket +import time +import threading +import base64 +import requests +import queue +from collections import namedtuple +from functools import partial +from jsonrpc import JSONRPCResponseManager, dispatcher +from websocket import create_connection, WebSocketTimeoutException, ABNF +from selfdrive.loggerd.config import ROOT + +import cereal.messaging as messaging +from common import android +from common.api import Api +from common.params import Params +from cereal.services import service_list +from selfdrive.swaglog import cloudlog + +ATHENA_HOST = os.getenv('ATHENA_HOST', 'wss://athena.comma.ai') +HANDLER_THREADS = os.getenv('HANDLER_THREADS', 4) +LOCAL_PORT_WHITELIST = set([8022]) + +dispatcher["echo"] = lambda s: s +payload_queue = queue.Queue() +response_queue = queue.Queue() +upload_queue = queue.Queue() +cancelled_uploads = set() +UploadItem = namedtuple('UploadItem', ['path', 'url', 'headers', 'created_at', 'id']) + +def handle_long_poll(ws): + end_event = threading.Event() + + threads = [ + threading.Thread(target=ws_recv, args=(ws, end_event)), + threading.Thread(target=ws_send, args=(ws, end_event)), + threading.Thread(target=upload_handler, args=(end_event,)) + ] + [ + threading.Thread(target=jsonrpc_handler, args=(end_event,)) + for x in range(HANDLER_THREADS) + ] + + for thread in threads: + thread.start() + try: + while not end_event.is_set(): + time.sleep(0.1) + except (KeyboardInterrupt, SystemExit): + end_event.set() + raise + finally: + for i, thread in enumerate(threads): + thread.join() + +def jsonrpc_handler(end_event): + dispatcher["startLocalProxy"] = partial(startLocalProxy, end_event) + while not end_event.is_set(): + try: + data = payload_queue.get(timeout=1) + response = JSONRPCResponseManager.handle(data, dispatcher) + response_queue.put_nowait(response) + except queue.Empty: + pass + except Exception as e: + cloudlog.exception("athena jsonrpc handler failed") + response_queue.put_nowait(json.dumps({"error": str(e)})) + +def upload_handler(end_event): + while not end_event.is_set(): + try: + item = upload_queue.get(timeout=1) + if item.id in cancelled_uploads: + cancelled_uploads.remove(item.id) + continue + _do_upload(item) + except queue.Empty: + pass + except Exception: + cloudlog.exception("athena.upload_handler.exception") + +def _do_upload(upload_item): + with open(upload_item.path, "rb") as f: + size = os.fstat(f.fileno()).st_size + return requests.put(upload_item.url, + data=f, + headers={**upload_item.headers, 'Content-Length': str(size)}, + timeout=10) + +# security: user should be able to request any message from their car +@dispatcher.add_method +def getMessage(service=None, timeout=1000): + if service is None or service not in service_list: + raise Exception("invalid service") + + socket = messaging.sub_sock(service, timeout=timeout) + ret = messaging.recv_one(socket) + + if ret is None: + raise TimeoutError + + return ret.to_dict() + +@dispatcher.add_method +def listDataDirectory(): + files = [os.path.relpath(os.path.join(dp, f), ROOT) for dp, dn, fn in os.walk(ROOT) for f in fn] + return files + +@dispatcher.add_method +def reboot(): + thermal_sock = messaging.sub_sock("thermal", timeout=1000) + ret = messaging.recv_one(thermal_sock) + if ret is None or ret.thermal.started: + raise Exception("Reboot unavailable") + + def do_reboot(): + time.sleep(2) + android.reboot() + + threading.Thread(target=do_reboot).start() + + return {"success": 1} + +@dispatcher.add_method +def uploadFileToUrl(fn, url, headers): + if len(fn) == 0 or fn[0] == '/' or '..' in fn: + return 500 + path = os.path.join(ROOT, fn) + if not os.path.exists(path): + return 404 + + item = UploadItem(path=path, url=url, headers=headers, created_at=int(time.time()*1000), id=None) + upload_id = hashlib.sha1(str(item).encode()).hexdigest() + item = item._replace(id=upload_id) + + upload_queue.put_nowait(item) + + return {"enqueued": 1, "item": item._asdict()} + +@dispatcher.add_method +def listUploadQueue(): + return [item._asdict() for item in list(upload_queue.queue)] + +@dispatcher.add_method +def cancelUpload(upload_id): + upload_ids = set(item.id for item in list(upload_queue.queue)) + if upload_id not in upload_ids: + return 404 + + cancelled_uploads.add(upload_id) + return {"success": 1} + +def startLocalProxy(global_end_event, remote_ws_uri, local_port): + try: + if local_port not in LOCAL_PORT_WHITELIST: + raise Exception("Requested local port not whitelisted") + + params = Params() + dongle_id = params.get("DongleId").decode('utf8') + identity_token = Api(dongle_id).get_token() + ws = create_connection(remote_ws_uri, + cookie="jwt=" + identity_token, + enable_multithread=True) + + ssock, csock = socket.socketpair() + local_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + local_sock.connect(('127.0.0.1', local_port)) + local_sock.setblocking(0) + + proxy_end_event = threading.Event() + threads = [ + threading.Thread(target=ws_proxy_recv, args=(ws, local_sock, ssock, proxy_end_event, global_end_event)), + threading.Thread(target=ws_proxy_send, args=(ws, local_sock, csock, proxy_end_event)) + ] + for thread in threads: + thread.start() + + return {"success": 1} + except Exception as e: + cloudlog.exception("athenad.startLocalProxy.exception") + raise e + +@dispatcher.add_method +def getPublicKey(): + if not os.path.isfile('/persist/comma/id_rsa.pub'): + return None + + with open('/persist/comma/id_rsa.pub', 'r') as f: + return f.read() + +@dispatcher.add_method +def getSshAuthorizedKeys(): + return Params().get("GithubSshKeys", encoding='utf8') or '' + +@dispatcher.add_method +def getSimInfo(): + sim_state = android.getprop("gsm.sim.state").split(",") + network_type = android.getprop("gsm.network.type").split(',') + mcc_mnc = android.getprop("gsm.sim.operator.numeric") or None + + sim_id = android.parse_service_call_string(android.service_call(['iphonesubinfo', '11'])) + cell_data_state = android.parse_service_call_unpack(android.service_call(['phone', '46']), ">q") + cell_data_connected = (cell_data_state == 2) + + return { + 'sim_id': sim_id, + 'mcc_mnc': mcc_mnc, + 'network_type': network_type, + 'sim_state': sim_state, + 'data_connected': cell_data_connected + } + +@dispatcher.add_method +def takeSnapshot(): + from selfdrive.camerad.snapshot.snapshot import snapshot, jpeg_write + ret = snapshot() + if ret is not None: + def b64jpeg(x): + if x is not None: + f = io.BytesIO() + jpeg_write(f, x) + return base64.b64encode(f.getvalue()).decode("utf-8") + else: + return None + return {'jpegBack': b64jpeg(ret[0]), + 'jpegFront': b64jpeg(ret[1])} + else: + raise Exception("not available while camerad is started") + +def ws_proxy_recv(ws, local_sock, ssock, end_event, global_end_event): + while not (end_event.is_set() or global_end_event.is_set()): + try: + data = ws.recv() + local_sock.sendall(data) + except WebSocketTimeoutException: + pass + except Exception: + cloudlog.exception("athenad.ws_proxy_recv.exception") + break + + ssock.close() + local_sock.close() + end_event.set() + +def ws_proxy_send(ws, local_sock, signal_sock, end_event): + while not end_event.is_set(): + try: + r, _, _ = select.select((local_sock, signal_sock), (), ()) + if r: + if r[0].fileno() == signal_sock.fileno(): + # got end signal from ws_proxy_recv + end_event.set() + break + data = local_sock.recv(4096) + if not data: + # local_sock is dead + end_event.set() + break + + ws.send(data, ABNF.OPCODE_BINARY) + except Exception: + cloudlog.exception("athenad.ws_proxy_send.exception") + end_event.set() + +def ws_recv(ws, end_event): + while not end_event.is_set(): + try: + data = ws.recv() + payload_queue.put_nowait(data) + except WebSocketTimeoutException: + pass + except Exception: + cloudlog.exception("athenad.ws_recv.exception") + end_event.set() + +def ws_send(ws, end_event): + while not end_event.is_set(): + try: + response = response_queue.get(timeout=1) + ws.send(response.json) + except queue.Empty: + pass + except Exception: + cloudlog.exception("athenad.ws_send.exception") + end_event.set() + +def backoff(retries): + return random.randrange(0, min(128, int(2 ** retries))) + +def main(gctx=None): + params = Params() + dongle_id = params.get("DongleId").decode('utf-8') + ws_uri = ATHENA_HOST + "/ws/v2/" + dongle_id + + api = Api(dongle_id) + + conn_retries = 0 + while 1: + try: + ws = create_connection(ws_uri, + cookie="jwt=" + api.get_token(), + enable_multithread=True) + cloudlog.event("athenad.main.connected_ws", ws_uri=ws_uri) + ws.settimeout(1) + conn_retries = 0 + handle_long_poll(ws) + except (KeyboardInterrupt, SystemExit): + break + except Exception: + cloudlog.exception("athenad.main.exception") + conn_retries += 1 + + time.sleep(backoff(conn_retries)) + +if __name__ == "__main__": + main() diff --git a/selfdrive/athena/manage_athenad.py b/selfdrive/athena/manage_athenad.py new file mode 100755 index 000000000..2954c70b8 --- /dev/null +++ b/selfdrive/athena/manage_athenad.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python3 + +import time +from multiprocessing import Process + +import selfdrive.crash as crash +from common.params import Params +from selfdrive.launcher import launcher +from selfdrive.swaglog import cloudlog +from selfdrive.version import version, dirty + +ATHENA_MGR_PID_PARAM = "AthenadPid" + +def main(): + params = Params() + dongle_id = params.get("DongleId").decode('utf-8') + cloudlog.bind_global(dongle_id=dongle_id, version=version, dirty=dirty, is_eon=True) + crash.bind_user(id=dongle_id) + crash.bind_extra(version=version, dirty=dirty, is_eon=True) + crash.install() + + try: + while 1: + cloudlog.info("starting athena daemon") + proc = Process(name='athenad', target=launcher, args=('selfdrive.athena.athenad',)) + proc.start() + proc.join() + cloudlog.event("athenad exited", exitcode=proc.exitcode) + time.sleep(5) + except: + cloudlog.exception("manage_athenad.exception") + finally: + params.delete(ATHENA_MGR_PID_PARAM) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/selfdrive/athena/test.py b/selfdrive/athena/test.py new file mode 100755 index 000000000..0bedfdeb7 --- /dev/null +++ b/selfdrive/athena/test.py @@ -0,0 +1,193 @@ +#!/usr/bin/env python3 +import json +import os +import requests +import tempfile +import time +import threading +import queue +import unittest + +from multiprocessing import Process +from pathlib import Path +from unittest import mock +from websocket import ABNF +from websocket._exceptions import WebSocketConnectionClosedException + +from selfdrive.athena import athenad +from selfdrive.athena.athenad import dispatcher +from selfdrive.athena.test_helpers import MockWebsocket, MockParams, MockApi, EchoSocket, with_http_server +from cereal import messaging + +class TestAthenadMethods(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.SOCKET_PORT = 45454 + athenad.ROOT = tempfile.mkdtemp() + athenad.Params = MockParams + athenad.Api = MockApi + athenad.LOCAL_PORT_WHITELIST = set([cls.SOCKET_PORT]) + + def test_echo(self): + assert dispatcher["echo"]("bob") == "bob" + + def test_getMessage(self): + with self.assertRaises(TimeoutError) as _: + dispatcher["getMessage"]("controlsState") + + def send_thermal(): + messaging.context = messaging.Context() + pub_sock = messaging.pub_sock("thermal") + start = time.time() + + while time.time() - start < 1: + msg = messaging.new_message() + msg.init('thermal') + pub_sock.send(msg.to_bytes()) + time.sleep(0.01) + + p = Process(target=send_thermal) + p.start() + time.sleep(0.1) + try: + thermal = dispatcher["getMessage"]("thermal") + assert thermal['thermal'] + finally: + p.terminate() + + def test_listDataDirectory(self): + print(dispatcher["listDataDirectory"]()) + + @with_http_server + def test_do_upload(self, host): + fn = os.path.join(athenad.ROOT, 'qlog.bz2') + Path(fn).touch() + + try: + item = athenad.UploadItem(path=fn, url="http://localhost:1238", headers={}, created_at=int(time.time()*1000), id='') + try: + athenad._do_upload(item) + except requests.exceptions.ConnectionError: + pass + + item = athenad.UploadItem(path=fn, url=f"{host}/qlog.bz2", headers={}, created_at=int(time.time()*1000), id='') + resp = athenad._do_upload(item) + self.assertEqual(resp.status_code, 201) + finally: + os.unlink(fn) + + @with_http_server + def test_uploadFileToUrl(self, host): + not_exists_resp = dispatcher["uploadFileToUrl"]("does_not_exist.bz2", "http://localhost:1238", {}) + self.assertEqual(not_exists_resp, 404) + + fn = os.path.join(athenad.ROOT, 'qlog.bz2') + Path(fn).touch() + + try: + resp = dispatcher["uploadFileToUrl"]("qlog.bz2", f"{host}/qlog.bz2", {}) + self.assertEqual(resp['enqueued'], 1) + self.assertDictContainsSubset({"path": fn, "url": f"{host}/qlog.bz2", "headers": {}}, resp['item']) + self.assertIsNotNone(resp['item'].get('id')) + self.assertEqual(athenad.upload_queue.qsize(), 1) + finally: + athenad.upload_queue = queue.Queue() + os.unlink(fn) + + @with_http_server + def test_upload_handler(self, host): + fn = os.path.join(athenad.ROOT, 'qlog.bz2') + Path(fn).touch() + item = athenad.UploadItem(path=fn, url=f"{host}/qlog.bz2", headers={}, created_at=int(time.time()*1000), id='') + + end_event = threading.Event() + thread = threading.Thread(target=athenad.upload_handler, args=(end_event,)) + thread.start() + + athenad.upload_queue.put_nowait(item) + try: + now = time.time() + while time.time() - now < 5: + if athenad.upload_queue.qsize() == 0: + break + self.assertEqual(athenad.upload_queue.qsize(), 0) + finally: + end_event.set() + athenad.upload_queue = queue.Queue() + os.unlink(fn) + + def test_cancelUpload(self): + item = athenad.UploadItem(path="qlog.bz2", url="http://localhost:44444/qlog.bz2", headers={}, created_at=int(time.time()*1000), id='id') + athenad.upload_queue.put_nowait(item) + dispatcher["cancelUpload"](item.id) + + self.assertIn(item.id, athenad.cancelled_uploads) + + end_event = threading.Event() + thread = threading.Thread(target=athenad.upload_handler, args=(end_event,)) + thread.start() + try: + now = time.time() + while time.time() - now < 5: + if athenad.upload_queue.qsize() == 0 and len(athenad.cancelled_uploads) == 0: + break + self.assertEqual(athenad.upload_queue.qsize(), 0) + self.assertEqual(len(athenad.cancelled_uploads), 0) + finally: + end_event.set() + athenad.upload_queue = queue.Queue() + + def test_listUploadQueue(self): + item = athenad.UploadItem(path="qlog.bz2", url="http://localhost:44444/qlog.bz2", headers={}, created_at=int(time.time()*1000), id='id') + athenad.upload_queue.put_nowait(item) + + try: + items = dispatcher["listUploadQueue"]() + self.assertEqual(len(items), 1) + self.assertDictEqual(items[0], item._asdict()) + finally: + athenad.upload_queue = queue.Queue() + + @mock.patch('selfdrive.athena.athenad.create_connection') + def test_startLocalProxy(self, mock_create_connection): + end_event = threading.Event() + + ws_recv = queue.Queue() + ws_send = queue.Queue() + mock_ws = MockWebsocket(ws_recv, ws_send) + mock_create_connection.return_value = mock_ws + + echo_socket = EchoSocket(self.SOCKET_PORT) + socket_thread = threading.Thread(target=echo_socket.run) + socket_thread.start() + + athenad.startLocalProxy(end_event, 'ws://localhost:1234', self.SOCKET_PORT) + + ws_recv.put_nowait(b'ping') + try: + recv = ws_send.get(timeout=5) + assert recv == (b'ping', ABNF.OPCODE_BINARY), recv + finally: + # signal websocket close to athenad.ws_proxy_recv + ws_recv.put_nowait(WebSocketConnectionClosedException()) + socket_thread.join() + + def test_getSshAuthorizedKeys(self): + keys = dispatcher["getSshAuthorizedKeys"]() + self.assertEqual(keys, MockParams().params["GithubSshKeys"].decode('utf-8')) + + def test_jsonrpc_handler(self): + end_event = threading.Event() + thread = threading.Thread(target=athenad.jsonrpc_handler, args=(end_event,)) + thread.daemon = True + thread.start() + athenad.payload_queue.put_nowait(json.dumps({"method": "echo", "params": ["hello"], "jsonrpc": "2.0", "id": 0})) + try: + resp = athenad.response_queue.get(timeout=3) + self.assertDictEqual(resp.data, {'result': 'hello', 'id': 0, 'jsonrpc': '2.0'}) + finally: + end_event.set() + thread.join() + +if __name__ == '__main__': + unittest.main() diff --git a/selfdrive/athena/test_helpers.py b/selfdrive/athena/test_helpers.py new file mode 100644 index 000000000..2335ce89c --- /dev/null +++ b/selfdrive/athena/test_helpers.py @@ -0,0 +1,114 @@ +import http.server +import multiprocessing +import queue +import random +import requests +import socket +import time +from functools import wraps +from multiprocessing import Process + +class EchoSocket(): + def __init__(self, port): + self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.socket.bind(('127.0.0.1', port)) + self.socket.listen(1) + + def run(self): + conn, client_address = self.socket.accept() + conn.settimeout(5.0) + + try: + while True: + data = conn.recv(4096) + if data: + print(f'EchoSocket got {data}') + conn.sendall(data) + else: + break + finally: + conn.shutdown(0) + conn.close() + self.socket.shutdown(0) + self.socket.close() + +class MockApi(): + def __init__(self, dongle_id): + pass + + def get_token(self): + return "fake-token" + +class MockParams(): + def __init__(self): + self.params = { + "DongleId": b"0000000000000000", + "GithubSshKeys": b"ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQC307aE+nuHzTAgaJhzSf5v7ZZQW9gaperjhCmyPyl4PzY7T1mDGenTlVTN7yoVFZ9UfO9oMQqo0n1OwDIiqbIFxqnhrHU0cYfj88rI85m5BEKlNu5RdaVTj1tcbaPpQc5kZEolaI1nDDjzV0lwS7jo5VYDHseiJHlik3HH1SgtdtsuamGR2T80q1SyW+5rHoMOJG73IH2553NnWuikKiuikGHUYBd00K1ilVAK2xSiMWJp55tQfZ0ecr9QjEsJ+J/efL4HqGNXhffxvypCXvbUYAFSddOwXUPo5BTKevpxMtH+2YrkpSjocWA04VnTYFiPG6U4ItKmbLOTFZtPzoez private" + } + + def get(self, k, encoding=None): + ret = self.params.get(k) + if ret is not None and encoding is not None: + ret = ret.decode(encoding) + return ret + +class MockWebsocket(): + def __init__(self, recv_queue, send_queue): + self.recv_queue = recv_queue + self.send_queue = send_queue + + def recv(self): + data = self.recv_queue.get() + if isinstance(data, Exception): + raise data + return data + + def send(self, data, opcode): + self.send_queue.put_nowait((data, opcode)) + +class HTTPRequestHandler(http.server.SimpleHTTPRequestHandler): + def do_PUT(self): + length = int(self.headers['Content-Length']) + self.rfile.read(length) + self.send_response(201, "Created") + self.end_headers() + +def http_server(port_queue, **kwargs): + while 1: + try: + port = random.randrange(40000, 50000) + port_queue.put(port) + http.server.test(**kwargs, port=port) + except OSError as e: + if e.errno == 98: + continue + +def with_http_server(func): + @wraps(func) + def inner(*args, **kwargs): + port_queue = multiprocessing.Queue() + host = '127.0.0.1' + p = Process(target=http_server, + args=(port_queue,), + kwargs={ + 'HandlerClass': HTTPRequestHandler, + 'bind': host}) + p.start() + now = time.time() + port = None + while 1: + if time.time() - now > 5: + raise Exception('HTTP Server did not start') + try: + port = port_queue.get(timeout=0.1) + requests.put(f'http://{host}:{port}/qlog.bz2', data='') + break + except (requests.exceptions.ConnectionError, queue.Empty): + time.sleep(0.1) + + try: + return func(*args, f'http://{host}:{port}', **kwargs) + finally: + p.terminate() + + return inner