selfdrive/athenad

pull/960/head
George Hotz 2020-01-17 10:54:24 -08:00
parent 84560ccd55
commit 341c0da987
5 changed files with 665 additions and 0 deletions

View File

View File

@ -0,0 +1,322 @@
#!/usr/bin/env python3.7
import json
import os
import hashlib
import io
import random
import select
import socket
import time
import threading
import base64
import requests
import queue
from collections import namedtuple
from functools import partial
from jsonrpc import JSONRPCResponseManager, dispatcher
from websocket import create_connection, WebSocketTimeoutException, ABNF
from selfdrive.loggerd.config import ROOT
import cereal.messaging as messaging
from common import android
from common.api import Api
from common.params import Params
from cereal.services import service_list
from selfdrive.swaglog import cloudlog
ATHENA_HOST = os.getenv('ATHENA_HOST', 'wss://athena.comma.ai')
HANDLER_THREADS = os.getenv('HANDLER_THREADS', 4)
LOCAL_PORT_WHITELIST = set([8022])
dispatcher["echo"] = lambda s: s
payload_queue = queue.Queue()
response_queue = queue.Queue()
upload_queue = queue.Queue()
cancelled_uploads = set()
UploadItem = namedtuple('UploadItem', ['path', 'url', 'headers', 'created_at', 'id'])
def handle_long_poll(ws):
end_event = threading.Event()
threads = [
threading.Thread(target=ws_recv, args=(ws, end_event)),
threading.Thread(target=ws_send, args=(ws, end_event)),
threading.Thread(target=upload_handler, args=(end_event,))
] + [
threading.Thread(target=jsonrpc_handler, args=(end_event,))
for x in range(HANDLER_THREADS)
]
for thread in threads:
thread.start()
try:
while not end_event.is_set():
time.sleep(0.1)
except (KeyboardInterrupt, SystemExit):
end_event.set()
raise
finally:
for i, thread in enumerate(threads):
thread.join()
def jsonrpc_handler(end_event):
dispatcher["startLocalProxy"] = partial(startLocalProxy, end_event)
while not end_event.is_set():
try:
data = payload_queue.get(timeout=1)
response = JSONRPCResponseManager.handle(data, dispatcher)
response_queue.put_nowait(response)
except queue.Empty:
pass
except Exception as e:
cloudlog.exception("athena jsonrpc handler failed")
response_queue.put_nowait(json.dumps({"error": str(e)}))
def upload_handler(end_event):
while not end_event.is_set():
try:
item = upload_queue.get(timeout=1)
if item.id in cancelled_uploads:
cancelled_uploads.remove(item.id)
continue
_do_upload(item)
except queue.Empty:
pass
except Exception:
cloudlog.exception("athena.upload_handler.exception")
def _do_upload(upload_item):
with open(upload_item.path, "rb") as f:
size = os.fstat(f.fileno()).st_size
return requests.put(upload_item.url,
data=f,
headers={**upload_item.headers, 'Content-Length': str(size)},
timeout=10)
# security: user should be able to request any message from their car
@dispatcher.add_method
def getMessage(service=None, timeout=1000):
if service is None or service not in service_list:
raise Exception("invalid service")
socket = messaging.sub_sock(service, timeout=timeout)
ret = messaging.recv_one(socket)
if ret is None:
raise TimeoutError
return ret.to_dict()
@dispatcher.add_method
def listDataDirectory():
files = [os.path.relpath(os.path.join(dp, f), ROOT) for dp, dn, fn in os.walk(ROOT) for f in fn]
return files
@dispatcher.add_method
def reboot():
thermal_sock = messaging.sub_sock("thermal", timeout=1000)
ret = messaging.recv_one(thermal_sock)
if ret is None or ret.thermal.started:
raise Exception("Reboot unavailable")
def do_reboot():
time.sleep(2)
android.reboot()
threading.Thread(target=do_reboot).start()
return {"success": 1}
@dispatcher.add_method
def uploadFileToUrl(fn, url, headers):
if len(fn) == 0 or fn[0] == '/' or '..' in fn:
return 500
path = os.path.join(ROOT, fn)
if not os.path.exists(path):
return 404
item = UploadItem(path=path, url=url, headers=headers, created_at=int(time.time()*1000), id=None)
upload_id = hashlib.sha1(str(item).encode()).hexdigest()
item = item._replace(id=upload_id)
upload_queue.put_nowait(item)
return {"enqueued": 1, "item": item._asdict()}
@dispatcher.add_method
def listUploadQueue():
return [item._asdict() for item in list(upload_queue.queue)]
@dispatcher.add_method
def cancelUpload(upload_id):
upload_ids = set(item.id for item in list(upload_queue.queue))
if upload_id not in upload_ids:
return 404
cancelled_uploads.add(upload_id)
return {"success": 1}
def startLocalProxy(global_end_event, remote_ws_uri, local_port):
try:
if local_port not in LOCAL_PORT_WHITELIST:
raise Exception("Requested local port not whitelisted")
params = Params()
dongle_id = params.get("DongleId").decode('utf8')
identity_token = Api(dongle_id).get_token()
ws = create_connection(remote_ws_uri,
cookie="jwt=" + identity_token,
enable_multithread=True)
ssock, csock = socket.socketpair()
local_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
local_sock.connect(('127.0.0.1', local_port))
local_sock.setblocking(0)
proxy_end_event = threading.Event()
threads = [
threading.Thread(target=ws_proxy_recv, args=(ws, local_sock, ssock, proxy_end_event, global_end_event)),
threading.Thread(target=ws_proxy_send, args=(ws, local_sock, csock, proxy_end_event))
]
for thread in threads:
thread.start()
return {"success": 1}
except Exception as e:
cloudlog.exception("athenad.startLocalProxy.exception")
raise e
@dispatcher.add_method
def getPublicKey():
if not os.path.isfile('/persist/comma/id_rsa.pub'):
return None
with open('/persist/comma/id_rsa.pub', 'r') as f:
return f.read()
@dispatcher.add_method
def getSshAuthorizedKeys():
return Params().get("GithubSshKeys", encoding='utf8') or ''
@dispatcher.add_method
def getSimInfo():
sim_state = android.getprop("gsm.sim.state").split(",")
network_type = android.getprop("gsm.network.type").split(',')
mcc_mnc = android.getprop("gsm.sim.operator.numeric") or None
sim_id = android.parse_service_call_string(android.service_call(['iphonesubinfo', '11']))
cell_data_state = android.parse_service_call_unpack(android.service_call(['phone', '46']), ">q")
cell_data_connected = (cell_data_state == 2)
return {
'sim_id': sim_id,
'mcc_mnc': mcc_mnc,
'network_type': network_type,
'sim_state': sim_state,
'data_connected': cell_data_connected
}
@dispatcher.add_method
def takeSnapshot():
from selfdrive.camerad.snapshot.snapshot import snapshot, jpeg_write
ret = snapshot()
if ret is not None:
def b64jpeg(x):
if x is not None:
f = io.BytesIO()
jpeg_write(f, x)
return base64.b64encode(f.getvalue()).decode("utf-8")
else:
return None
return {'jpegBack': b64jpeg(ret[0]),
'jpegFront': b64jpeg(ret[1])}
else:
raise Exception("not available while camerad is started")
def ws_proxy_recv(ws, local_sock, ssock, end_event, global_end_event):
while not (end_event.is_set() or global_end_event.is_set()):
try:
data = ws.recv()
local_sock.sendall(data)
except WebSocketTimeoutException:
pass
except Exception:
cloudlog.exception("athenad.ws_proxy_recv.exception")
break
ssock.close()
local_sock.close()
end_event.set()
def ws_proxy_send(ws, local_sock, signal_sock, end_event):
while not end_event.is_set():
try:
r, _, _ = select.select((local_sock, signal_sock), (), ())
if r:
if r[0].fileno() == signal_sock.fileno():
# got end signal from ws_proxy_recv
end_event.set()
break
data = local_sock.recv(4096)
if not data:
# local_sock is dead
end_event.set()
break
ws.send(data, ABNF.OPCODE_BINARY)
except Exception:
cloudlog.exception("athenad.ws_proxy_send.exception")
end_event.set()
def ws_recv(ws, end_event):
while not end_event.is_set():
try:
data = ws.recv()
payload_queue.put_nowait(data)
except WebSocketTimeoutException:
pass
except Exception:
cloudlog.exception("athenad.ws_recv.exception")
end_event.set()
def ws_send(ws, end_event):
while not end_event.is_set():
try:
response = response_queue.get(timeout=1)
ws.send(response.json)
except queue.Empty:
pass
except Exception:
cloudlog.exception("athenad.ws_send.exception")
end_event.set()
def backoff(retries):
return random.randrange(0, min(128, int(2 ** retries)))
def main(gctx=None):
params = Params()
dongle_id = params.get("DongleId").decode('utf-8')
ws_uri = ATHENA_HOST + "/ws/v2/" + dongle_id
api = Api(dongle_id)
conn_retries = 0
while 1:
try:
ws = create_connection(ws_uri,
cookie="jwt=" + api.get_token(),
enable_multithread=True)
cloudlog.event("athenad.main.connected_ws", ws_uri=ws_uri)
ws.settimeout(1)
conn_retries = 0
handle_long_poll(ws)
except (KeyboardInterrupt, SystemExit):
break
except Exception:
cloudlog.exception("athenad.main.exception")
conn_retries += 1
time.sleep(backoff(conn_retries))
if __name__ == "__main__":
main()

View File

@ -0,0 +1,36 @@
#!/usr/bin/env python3
import time
from multiprocessing import Process
import selfdrive.crash as crash
from common.params import Params
from selfdrive.launcher import launcher
from selfdrive.swaglog import cloudlog
from selfdrive.version import version, dirty
ATHENA_MGR_PID_PARAM = "AthenadPid"
def main():
params = Params()
dongle_id = params.get("DongleId").decode('utf-8')
cloudlog.bind_global(dongle_id=dongle_id, version=version, dirty=dirty, is_eon=True)
crash.bind_user(id=dongle_id)
crash.bind_extra(version=version, dirty=dirty, is_eon=True)
crash.install()
try:
while 1:
cloudlog.info("starting athena daemon")
proc = Process(name='athenad', target=launcher, args=('selfdrive.athena.athenad',))
proc.start()
proc.join()
cloudlog.event("athenad exited", exitcode=proc.exitcode)
time.sleep(5)
except:
cloudlog.exception("manage_athenad.exception")
finally:
params.delete(ATHENA_MGR_PID_PARAM)
if __name__ == '__main__':
main()

View File

@ -0,0 +1,193 @@
#!/usr/bin/env python3
import json
import os
import requests
import tempfile
import time
import threading
import queue
import unittest
from multiprocessing import Process
from pathlib import Path
from unittest import mock
from websocket import ABNF
from websocket._exceptions import WebSocketConnectionClosedException
from selfdrive.athena import athenad
from selfdrive.athena.athenad import dispatcher
from selfdrive.athena.test_helpers import MockWebsocket, MockParams, MockApi, EchoSocket, with_http_server
from cereal import messaging
class TestAthenadMethods(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.SOCKET_PORT = 45454
athenad.ROOT = tempfile.mkdtemp()
athenad.Params = MockParams
athenad.Api = MockApi
athenad.LOCAL_PORT_WHITELIST = set([cls.SOCKET_PORT])
def test_echo(self):
assert dispatcher["echo"]("bob") == "bob"
def test_getMessage(self):
with self.assertRaises(TimeoutError) as _:
dispatcher["getMessage"]("controlsState")
def send_thermal():
messaging.context = messaging.Context()
pub_sock = messaging.pub_sock("thermal")
start = time.time()
while time.time() - start < 1:
msg = messaging.new_message()
msg.init('thermal')
pub_sock.send(msg.to_bytes())
time.sleep(0.01)
p = Process(target=send_thermal)
p.start()
time.sleep(0.1)
try:
thermal = dispatcher["getMessage"]("thermal")
assert thermal['thermal']
finally:
p.terminate()
def test_listDataDirectory(self):
print(dispatcher["listDataDirectory"]())
@with_http_server
def test_do_upload(self, host):
fn = os.path.join(athenad.ROOT, 'qlog.bz2')
Path(fn).touch()
try:
item = athenad.UploadItem(path=fn, url="http://localhost:1238", headers={}, created_at=int(time.time()*1000), id='')
try:
athenad._do_upload(item)
except requests.exceptions.ConnectionError:
pass
item = athenad.UploadItem(path=fn, url=f"{host}/qlog.bz2", headers={}, created_at=int(time.time()*1000), id='')
resp = athenad._do_upload(item)
self.assertEqual(resp.status_code, 201)
finally:
os.unlink(fn)
@with_http_server
def test_uploadFileToUrl(self, host):
not_exists_resp = dispatcher["uploadFileToUrl"]("does_not_exist.bz2", "http://localhost:1238", {})
self.assertEqual(not_exists_resp, 404)
fn = os.path.join(athenad.ROOT, 'qlog.bz2')
Path(fn).touch()
try:
resp = dispatcher["uploadFileToUrl"]("qlog.bz2", f"{host}/qlog.bz2", {})
self.assertEqual(resp['enqueued'], 1)
self.assertDictContainsSubset({"path": fn, "url": f"{host}/qlog.bz2", "headers": {}}, resp['item'])
self.assertIsNotNone(resp['item'].get('id'))
self.assertEqual(athenad.upload_queue.qsize(), 1)
finally:
athenad.upload_queue = queue.Queue()
os.unlink(fn)
@with_http_server
def test_upload_handler(self, host):
fn = os.path.join(athenad.ROOT, 'qlog.bz2')
Path(fn).touch()
item = athenad.UploadItem(path=fn, url=f"{host}/qlog.bz2", headers={}, created_at=int(time.time()*1000), id='')
end_event = threading.Event()
thread = threading.Thread(target=athenad.upload_handler, args=(end_event,))
thread.start()
athenad.upload_queue.put_nowait(item)
try:
now = time.time()
while time.time() - now < 5:
if athenad.upload_queue.qsize() == 0:
break
self.assertEqual(athenad.upload_queue.qsize(), 0)
finally:
end_event.set()
athenad.upload_queue = queue.Queue()
os.unlink(fn)
def test_cancelUpload(self):
item = athenad.UploadItem(path="qlog.bz2", url="http://localhost:44444/qlog.bz2", headers={}, created_at=int(time.time()*1000), id='id')
athenad.upload_queue.put_nowait(item)
dispatcher["cancelUpload"](item.id)
self.assertIn(item.id, athenad.cancelled_uploads)
end_event = threading.Event()
thread = threading.Thread(target=athenad.upload_handler, args=(end_event,))
thread.start()
try:
now = time.time()
while time.time() - now < 5:
if athenad.upload_queue.qsize() == 0 and len(athenad.cancelled_uploads) == 0:
break
self.assertEqual(athenad.upload_queue.qsize(), 0)
self.assertEqual(len(athenad.cancelled_uploads), 0)
finally:
end_event.set()
athenad.upload_queue = queue.Queue()
def test_listUploadQueue(self):
item = athenad.UploadItem(path="qlog.bz2", url="http://localhost:44444/qlog.bz2", headers={}, created_at=int(time.time()*1000), id='id')
athenad.upload_queue.put_nowait(item)
try:
items = dispatcher["listUploadQueue"]()
self.assertEqual(len(items), 1)
self.assertDictEqual(items[0], item._asdict())
finally:
athenad.upload_queue = queue.Queue()
@mock.patch('selfdrive.athena.athenad.create_connection')
def test_startLocalProxy(self, mock_create_connection):
end_event = threading.Event()
ws_recv = queue.Queue()
ws_send = queue.Queue()
mock_ws = MockWebsocket(ws_recv, ws_send)
mock_create_connection.return_value = mock_ws
echo_socket = EchoSocket(self.SOCKET_PORT)
socket_thread = threading.Thread(target=echo_socket.run)
socket_thread.start()
athenad.startLocalProxy(end_event, 'ws://localhost:1234', self.SOCKET_PORT)
ws_recv.put_nowait(b'ping')
try:
recv = ws_send.get(timeout=5)
assert recv == (b'ping', ABNF.OPCODE_BINARY), recv
finally:
# signal websocket close to athenad.ws_proxy_recv
ws_recv.put_nowait(WebSocketConnectionClosedException())
socket_thread.join()
def test_getSshAuthorizedKeys(self):
keys = dispatcher["getSshAuthorizedKeys"]()
self.assertEqual(keys, MockParams().params["GithubSshKeys"].decode('utf-8'))
def test_jsonrpc_handler(self):
end_event = threading.Event()
thread = threading.Thread(target=athenad.jsonrpc_handler, args=(end_event,))
thread.daemon = True
thread.start()
athenad.payload_queue.put_nowait(json.dumps({"method": "echo", "params": ["hello"], "jsonrpc": "2.0", "id": 0}))
try:
resp = athenad.response_queue.get(timeout=3)
self.assertDictEqual(resp.data, {'result': 'hello', 'id': 0, 'jsonrpc': '2.0'})
finally:
end_event.set()
thread.join()
if __name__ == '__main__':
unittest.main()

View File

@ -0,0 +1,114 @@
import http.server
import multiprocessing
import queue
import random
import requests
import socket
import time
from functools import wraps
from multiprocessing import Process
class EchoSocket():
def __init__(self, port):
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.socket.bind(('127.0.0.1', port))
self.socket.listen(1)
def run(self):
conn, client_address = self.socket.accept()
conn.settimeout(5.0)
try:
while True:
data = conn.recv(4096)
if data:
print(f'EchoSocket got {data}')
conn.sendall(data)
else:
break
finally:
conn.shutdown(0)
conn.close()
self.socket.shutdown(0)
self.socket.close()
class MockApi():
def __init__(self, dongle_id):
pass
def get_token(self):
return "fake-token"
class MockParams():
def __init__(self):
self.params = {
"DongleId": b"0000000000000000",
"GithubSshKeys": b"ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQC307aE+nuHzTAgaJhzSf5v7ZZQW9gaperjhCmyPyl4PzY7T1mDGenTlVTN7yoVFZ9UfO9oMQqo0n1OwDIiqbIFxqnhrHU0cYfj88rI85m5BEKlNu5RdaVTj1tcbaPpQc5kZEolaI1nDDjzV0lwS7jo5VYDHseiJHlik3HH1SgtdtsuamGR2T80q1SyW+5rHoMOJG73IH2553NnWuikKiuikGHUYBd00K1ilVAK2xSiMWJp55tQfZ0ecr9QjEsJ+J/efL4HqGNXhffxvypCXvbUYAFSddOwXUPo5BTKevpxMtH+2YrkpSjocWA04VnTYFiPG6U4ItKmbLOTFZtPzoez private"
}
def get(self, k, encoding=None):
ret = self.params.get(k)
if ret is not None and encoding is not None:
ret = ret.decode(encoding)
return ret
class MockWebsocket():
def __init__(self, recv_queue, send_queue):
self.recv_queue = recv_queue
self.send_queue = send_queue
def recv(self):
data = self.recv_queue.get()
if isinstance(data, Exception):
raise data
return data
def send(self, data, opcode):
self.send_queue.put_nowait((data, opcode))
class HTTPRequestHandler(http.server.SimpleHTTPRequestHandler):
def do_PUT(self):
length = int(self.headers['Content-Length'])
self.rfile.read(length)
self.send_response(201, "Created")
self.end_headers()
def http_server(port_queue, **kwargs):
while 1:
try:
port = random.randrange(40000, 50000)
port_queue.put(port)
http.server.test(**kwargs, port=port)
except OSError as e:
if e.errno == 98:
continue
def with_http_server(func):
@wraps(func)
def inner(*args, **kwargs):
port_queue = multiprocessing.Queue()
host = '127.0.0.1'
p = Process(target=http_server,
args=(port_queue,),
kwargs={
'HandlerClass': HTTPRequestHandler,
'bind': host})
p.start()
now = time.time()
port = None
while 1:
if time.time() - now > 5:
raise Exception('HTTP Server did not start')
try:
port = port_queue.get(timeout=0.1)
requests.put(f'http://{host}:{port}/qlog.bz2', data='')
break
except (requests.exceptions.ConnectionError, queue.Empty):
time.sleep(0.1)
try:
return func(*args, f'http://{host}:{port}', **kwargs)
finally:
p.terminate()
return inner