nopenpilot/selfdrive/athena/athenad.py

732 lines
22 KiB
Python
Executable File

#!/usr/bin/env python3
import base64
import hashlib
import io
import json
import os
import queue
import random
import select
import socket
import subprocess
import sys
import tempfile
import threading
import time
from collections import namedtuple
from datetime import datetime
from functools import partial
from typing import Any, Dict
import requests
from jsonrpc import JSONRPCResponseManager, dispatcher
from websocket import (ABNF, WebSocketException, WebSocketTimeoutException,
create_connection)
import cereal.messaging as messaging
from cereal import log
from cereal.services import service_list
from common.api import Api
from common.basedir import PERSIST
from common.file_helpers import CallbackReader
from common.params import Params
from common.realtime import sec_since_boot
from selfdrive.hardware import HARDWARE, PC, TICI
from selfdrive.loggerd.config import ROOT
from selfdrive.loggerd.xattr_cache import getxattr, setxattr
from selfdrive.statsd import STATS_DIR
from selfdrive.swaglog import SWAGLOG_DIR, cloudlog
from selfdrive.version import get_commit, get_origin, get_short_branch, get_version
# Websocket endpoint main() connects to; can be overridden via the environment.
ATHENA_HOST = os.getenv('ATHENA_HOST', 'wss://athena.comma.ai')
# Number of jsonrpc_handler worker threads started by handle_long_poll.
HANDLER_THREADS = int(os.getenv('HANDLER_THREADS', "4"))
# Local TCP ports startLocalProxy is willing to connect to.
LOCAL_PORT_WHITELIST = {8022}
# Extended attribute on each swaglog file storing the unix time it was last sent.
LOG_ATTR_NAME = 'user.upload'
# Value stored in LOG_ATTR_NAME once the backend acknowledged a log
# (largest signed 32-bit unix time, so the "sent over an hour ago" check never fires).
LOG_ATTR_VALUE_MAX_UNIX_TIME = int.to_bytes(2147483647, 4, sys.byteorder)
# ws_recv ends the connection when no ping was seen for this many seconds.
RECONNECT_TIMEOUT_S = 70
RETRY_DELAY = 10 # seconds
MAX_RETRY_COUNT = 30 # Try for at most 5 minutes if upload fails immediately
MAX_AGE = 31 * 24 * 3600 # seconds
# Slice length ws_send uses when splitting an outgoing message into frames.
WS_FRAME_SIZE = 4096
NetworkType = log.DeviceState.NetworkType
# Trivial RPC method that returns its argument unchanged.
dispatcher["echo"] = lambda s: s
# Messages received by ws_recv, consumed by the jsonrpc_handler workers.
recv_queue: Any = queue.Queue()
# Outgoing messages; ws_send drains this queue first.
send_queue: Any = queue.Queue()
# UploadItems waiting to be processed by upload_handler.
upload_queue: Any = queue.Queue()
# Outgoing messages (logs, stats) that ws_send only sends when send_queue is empty.
low_priority_send_queue: Any = queue.Queue()
# Responses to requests this daemon issued itself, consumed by log_handler.
log_recv_queue: Any = queue.Queue()
# Ids of queued uploads that must be skipped when they reach upload_handler.
cancelled_uploads: Any = set()
# The defaults apply to the last four fields: retry_count=0, current=False,
# progress=0, allow_cellular=False.
UploadItem = namedtuple('UploadItem', ['path', 'url', 'headers', 'created_at', 'id', 'retry_count', 'current', 'progress', 'allow_cellular'], defaults=(0, False, 0, False))
# Maps an upload_handler thread ident to the item it is processing (or None).
cur_upload_items: Dict[int, Any] = {}
class AbortTransferException(Exception):
  """Raised by the upload progress callback to abort a transfer in flight."""
class UploadQueueCache():
  """Persists the pending upload queue in the "AthenadUploadQueue" param."""
  params = Params()

  @staticmethod
  def initialize(upload_queue):
    """Refill `upload_queue` from the persisted JSON list, if there is one.

    Any failure is logged and otherwise ignored; items put before the
    failure stay in the queue.
    """
    try:
      cached = UploadQueueCache.params.get("AthenadUploadQueue")
      if cached is None:
        return
      for fields in json.loads(cached):
        upload_queue.put(UploadItem(**fields))
    except Exception:
      cloudlog.exception("athena.UploadQueueCache.initialize.exception")

  @staticmethod
  def cache(upload_queue):
    """Persist every queued item that has not been cancelled.

    Any failure is logged and otherwise ignored.
    """
    try:
      pending = []
      for entry in upload_queue.queue:
        if entry.id in cancelled_uploads:
          continue
        pending.append(entry._asdict())
      UploadQueueCache.params.put("AthenadUploadQueue", json.dumps(pending))
    except Exception:
      cloudlog.exception("athena.UploadQueueCache.cache.exception")
def handle_long_poll(ws):
  """Run all worker threads for one websocket connection.

  Blocks until a worker sets the shared end event (or the process is
  interrupted), then joins every thread before returning or re-raising.
  """
  end_event = threading.Event()

  workers = [
    threading.Thread(target=ws_recv, args=(ws, end_event), name='ws_recv'),
    threading.Thread(target=ws_send, args=(ws, end_event), name='ws_send'),
    threading.Thread(target=upload_handler, args=(end_event,), name='upload_handler'),
    threading.Thread(target=log_handler, args=(end_event,), name='log_handler'),
    threading.Thread(target=stat_handler, args=(end_event,), name='stat_handler'),
  ]
  workers.extend(
    threading.Thread(target=jsonrpc_handler, args=(end_event,), name=f'worker_{x}')
    for x in range(HANDLER_THREADS)
  )

  for worker in workers:
    worker.start()

  try:
    # Poll so that KeyboardInterrupt/SystemExit are noticed promptly
    while not end_event.is_set():
      time.sleep(0.1)
  except (KeyboardInterrupt, SystemExit):
    # Tell the workers to stop before propagating the interruption
    end_event.set()
    raise
  finally:
    for worker in workers:
      cloudlog.debug(f"athena.joining {worker.name}")
      worker.join()
def jsonrpc_handler(end_event):
  """Worker loop that dispatches messages taken from recv_queue.

  Requests (messages containing "method") are executed through the JSON-RPC
  dispatcher and answered on send_queue. Responses to our own requests are
  forwarded to log_recv_queue. Anything else is reported back as an error.
  """
  # startLocalProxy needs the connection's end event as its first argument
  dispatcher["startLocalProxy"] = partial(startLocalProxy, end_event)

  while not end_event.is_set():
    try:
      message = recv_queue.get(timeout=1)
      if "method" in message:
        cloudlog.debug(f"athena.jsonrpc_handler.call_method {message}")
        reply = JSONRPCResponseManager.handle(message, dispatcher)
        send_queue.put_nowait(reply.json)
      elif "id" in message and ("result" in message or "error" in message):
        log_recv_queue.put_nowait(message)
      else:
        raise Exception("not a valid request or response")
    except queue.Empty:
      # nothing arrived within the timeout; re-check end_event
      pass
    except Exception as err:
      cloudlog.exception("athena jsonrpc handler failed")
      send_queue.put_nowait(json.dumps({"error": str(err)}))
def retry_upload(tid: int, end_event: threading.Event, increase_count: bool = True) -> None:
  """Put the current upload of thread `tid` back on the queue, then back off.

  Does nothing once the item has reached MAX_RETRY_COUNT. Otherwise the
  item is re-queued with progress reset (and retry_count bumped unless
  `increase_count` is False), the cache is refreshed, and the caller is
  held for up to RETRY_DELAY seconds or until `end_event` is set.
  """
  current = cur_upload_items[tid]
  if current.retry_count >= MAX_RETRY_COUNT:
    return

  bump = 1 if increase_count else 0
  requeued = current._replace(
    retry_count=current.retry_count + bump,
    progress=0,
    current=False
  )
  upload_queue.put_nowait(requeued)
  UploadQueueCache.cache(upload_queue)

  cur_upload_items[tid] = None

  # Sleep in one second steps so shutdown is noticed quickly
  for _ in range(RETRY_DELAY):
    time.sleep(1)
    if end_event.is_set():
      break
def upload_handler(end_event: threading.Event) -> None:
  """Take items off upload_queue and upload them until `end_event` is set.

  The item being processed is published in cur_upload_items under this
  thread's ident. Cancelled and expired items are dropped. Items that may
  not use a metered connection are re-queued without counting a retry.
  """
  sm = messaging.SubMaster(['deviceState'])
  tid = threading.get_ident()

  while not end_event.is_set():
    cur_upload_items[tid] = None

    try:
      cur_upload_items[tid] = upload_queue.get(timeout=1)._replace(current=True)

      # Drop items that were cancelled while still queued
      upload_id = cur_upload_items[tid].id
      if upload_id in cancelled_uploads:
        cancelled_uploads.remove(upload_id)
        continue

      # Drop items older than MAX_AGE (created_at is in milliseconds)
      age_s = (datetime.now() - datetime.fromtimestamp(cur_upload_items[tid].created_at / 1000)).total_seconds()
      if age_s > MAX_AGE:
        cloudlog.event("athena.upload_handler.expired", item=cur_upload_items[tid], error=True)
        continue

      # Postpone when the connection is metered and the item forbids that
      sm.update(0)
      metered = sm['deviceState'].networkMetered
      network_type = sm['deviceState'].networkType.raw
      if metered and (not cur_upload_items[tid].allow_cellular):
        retry_upload(tid, end_event, False)
        continue

      try:
        def cb(sz, cur):
          # Abort the transfer if the connection became metered mid-upload
          sm.update(0)
          metered = sm['deviceState'].networkMetered
          if metered and (not cur_upload_items[tid].allow_cellular):
            raise AbortTransferException

          cur_upload_items[tid] = cur_upload_items[tid]._replace(progress=cur / sz if sz else 1)

        fn = cur_upload_items[tid].path
        try:
          sz = os.path.getsize(fn)
        except OSError:
          sz = -1  # size is only used for logging

        log_fields = dict(fn=fn, sz=sz, network_type=network_type, metered=metered)
        cloudlog.event("athena.upload_handler.upload_start", **log_fields)
        response = _do_upload(cur_upload_items[tid], cb)

        # These four status codes end the upload; any other is retried
        if response.status_code in (200, 201, 403, 412):
          cloudlog.event("athena.upload_handler.success", **log_fields)
        else:
          cloudlog.event("athena.upload_handler.retry", status_code=response.status_code, **log_fields)
          retry_upload(tid, end_event)

        UploadQueueCache.cache(upload_queue)
      except (requests.exceptions.Timeout, requests.exceptions.ConnectionError, requests.exceptions.SSLError):
        cloudlog.event("athena.upload_handler.timeout", **log_fields)
        retry_upload(tid, end_event)
      except AbortTransferException:
        # aborted because of metering: retry without counting it
        cloudlog.event("athena.upload_handler.abort", **log_fields)
        retry_upload(tid, end_event, False)

    except queue.Empty:
      pass
    except Exception:
      cloudlog.exception("athena.upload_handler.exception")
def _do_upload(upload_item, callback=None):
  """PUT the file at `upload_item.path` to `upload_item.url`.

  When `callback` is given, the file is wrapped in a CallbackReader built
  from the callback and the file size. Returns the requests response.
  """
  with open(upload_item.path, "rb") as f:
    size = os.fstat(f.fileno()).st_size
    body = CallbackReader(f, callback, size) if callback else f
    request_headers = {**upload_item.headers, 'Content-Length': str(size)}
    return requests.put(upload_item.url,
                        data=body,
                        headers=request_headers,
                        timeout=30)
# security: user should be able to request any message from their car
@dispatcher.add_method
def getMessage(service=None, timeout=1000):
  """Receive one message from `service` and return it as a dict.

  Raises Exception("invalid service") when the name is missing or not in
  service_list, and TimeoutError when nothing was received.
  """
  if service is None or service not in service_list:
    raise Exception("invalid service")

  sock = messaging.sub_sock(service, timeout=timeout)
  msg = messaging.recv_one(sock)
  if msg is None:
    raise TimeoutError

  return msg.to_dict()
@dispatcher.add_method
def getVersion():
  """Return version, git remote, branch and commit of the running software."""
  info = {}
  info["version"] = get_version()
  info["remote"] = get_origin()
  info["branch"] = get_short_branch()
  info["commit"] = get_commit()
  return info
@dispatcher.add_method
def setNavDestination(latitude=0, longitude=0, place_name=None, place_details=None):
  """Store the destination as JSON in the "NavDestination" param."""
  payload = json.dumps({
    "latitude": latitude,
    "longitude": longitude,
    "place_name": place_name,
    "place_details": place_details,
  })
  Params().put("NavDestination", payload)

  return {"success": 1}
def scan_dir(path, prefix):
  """Return the files below `path` whose ROOT-relative path starts with `prefix`.

  Only directories that can still lead to a match are descended into
  (glob and friends would traverse the entire tree).
  """
  matches = []
  with os.scandir(path) as entries:
    for entry in entries:
      rel_path = os.path.relpath(entry.path, ROOT)

      if not entry.is_dir(follow_symlinks=False):
        if rel_path.startswith(prefix):
          matches.append(rel_path)
        continue

      # compare directories with a trailing slash
      dir_path = os.path.join(rel_path, '')
      # if prefix is a partial dir name, the current dir will start with prefix
      # if prefix is a partial file name, prefix will start with the dir name
      if dir_path.startswith(prefix) or prefix.startswith(dir_path):
        matches.extend(scan_dir(entry.path, prefix))
  return matches
@dispatcher.add_method
def listDataDirectory(prefix=''):
  """List the files under ROOT whose relative path starts with `prefix`."""
  listing = scan_dir(ROOT, prefix)
  return listing
@dispatcher.add_method
def reboot():
  """Schedule a reboot, unless deviceState is unavailable or reports `started`.

  The reboot is issued from a background thread after a two second delay,
  so this call returns before it happens.
  """
  state_sock = messaging.sub_sock("deviceState", timeout=1000)
  msg = messaging.recv_one(state_sock)
  if ret_unavailable := (msg is None or msg.deviceState.started):
    raise Exception("Reboot unavailable")

  def do_reboot():
    time.sleep(2)
    HARDWARE.reboot()

  rebooter = threading.Thread(target=do_reboot)
  rebooter.start()

  return {"success": 1}
@dispatcher.add_method
def uploadFileToUrl(fn, url, headers):
  """Queue a single upload by delegating to uploadFilesToUrls."""
  request = {"fn": fn, "url": url, "headers": headers}
  return uploadFilesToUrls([request])
@dispatcher.add_method
def uploadFilesToUrls(files_data):
  """Queue uploads for the given list of {"fn", "url", ...} dicts.

  An entry is rejected when its "fn" is empty, absolute or contains "..",
  when it has no "url", or when the file does not exist under ROOT.
  Returns {"enqueued": count, "items": [...]} plus "failed" (the rejected
  file names) when there are any.
  """
  enqueued = []
  rejected = []
  for file in files_data:
    fn = file.get('fn', '')
    if len(fn) == 0 or fn[0] == '/' or '..' in fn or 'url' not in file:
      rejected.append(fn)
      continue

    path = os.path.join(ROOT, fn)
    if not os.path.exists(path):
      rejected.append(fn)
      continue

    draft = UploadItem(
      path=path,
      url=file['url'],
      headers=file.get('headers', {}),
      created_at=int(time.time() * 1000),
      id=None,
      allow_cellular=file.get('allow_cellular', False),
    )
    # the id is the SHA-1 of the item as rendered before an id is assigned
    digest = hashlib.sha1(str(draft).encode()).hexdigest()
    queued = draft._replace(id=digest)
    upload_queue.put_nowait(queued)
    enqueued.append(queued._asdict())

  UploadQueueCache.cache(upload_queue)

  resp = {"enqueued": len(enqueued), "items": enqueued}
  if rejected:
    resp["failed"] = rejected

  return resp
@dispatcher.add_method
def listUploadQueue():
  """Return queued and in-progress uploads as dicts, minus cancelled ones."""
  candidates = list(upload_queue.queue) + list(cur_upload_items.values())
  listing = []
  for item in candidates:
    # idle handler threads are represented by None
    if item is None or item.id in cancelled_uploads:
      continue
    listing.append(item._asdict())
  return listing
@dispatcher.add_method
def cancelUpload(upload_id):
  """Mark one or more queued uploads (a single id or a list) as cancelled.

  Returns 404 when none of the ids is in the upload queue, otherwise
  {"success": 1}.
  """
  requested = upload_id if isinstance(upload_id, list) else [upload_id]

  queued_ids = {item.id for item in list(upload_queue.queue)}
  to_cancel = queued_ids.intersection(requested)
  if not to_cancel:
    return 404

  cancelled_uploads.update(to_cancel)
  return {"success": 1}
@dispatcher.add_method
def primeActivated(activated):
  """Acknowledge the notification; `activated` is not used."""
  result = {"success": 1}
  return result
@dispatcher.add_method
def setBandwithLimit(upload_speed_kbps, download_speed_kbps):
  """Apply upload/download limits through HARDWARE (TICI only).

  Returns {"success": 1} on success, otherwise a dict with "success": 0
  and an "error" message (plus stdout/stderr when the command failed).
  """
  if not TICI:
    return {"success": 0, "error": "only supported on comma three"}

  try:
    HARDWARE.set_bandwidth_limit(upload_speed_kbps, download_speed_kbps)
  except subprocess.CalledProcessError as e:
    return {"success": 0, "error": "failed to set limit", "stdout": e.stdout, "stderr": e.stderr}
  return {"success": 1}
def startLocalProxy(global_end_event, remote_ws_uri, local_port):
  """Bridge a whitelisted local TCP port to a remote websocket.

  Opens a websocket to `remote_ws_uri` (authenticated with a token for this
  device), connects to 127.0.0.1:`local_port`, and starts two threads that
  relay data in both directions. Returns {"success": 1} once the threads
  are running.

  Any failure is logged and re-raised. Sockets opened before the failure
  are closed first, so a failed request does not leave connections open.
  """
  # Resources opened so far and not yet handed over to the proxy threads
  opened = []
  try:
    if local_port not in LOCAL_PORT_WHITELIST:
      raise Exception("Requested local port not whitelisted")

    cloudlog.debug("athena.startLocalProxy.starting")

    dongle_id = Params().get("DongleId").decode('utf8')
    identity_token = Api(dongle_id).get_token()
    ws = create_connection(remote_ws_uri,
                           cookie="jwt=" + identity_token,
                           enable_multithread=True)
    opened.append(ws)

    # ssock/csock let ws_proxy_recv wake up ws_proxy_send when it exits
    ssock, csock = socket.socketpair()
    opened.extend((ssock, csock))
    local_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    opened.append(local_sock)
    local_sock.connect(('127.0.0.1', local_port))
    local_sock.setblocking(0)

    proxy_end_event = threading.Event()
    threads = [
      threading.Thread(target=ws_proxy_recv, args=(ws, local_sock, ssock, proxy_end_event, global_end_event)),
      threading.Thread(target=ws_proxy_send, args=(ws, local_sock, csock, proxy_end_event))
    ]
    for thread in threads:
      thread.start()
    # The proxy threads own the sockets from here on
    opened.clear()

    cloudlog.debug("athena.startLocalProxy.started")
    return {"success": 1}
  except Exception:
    cloudlog.exception("athenad.startLocalProxy.exception")
    for resource in opened:
      try:
        resource.close()
      except Exception:
        pass  # best-effort cleanup; the original error is what matters
    raise
@dispatcher.add_method
def getPublicKey():
  """Return the contents of the device's id_rsa.pub, or None if it is missing."""
  key_path = PERSIST + '/comma/id_rsa.pub'
  if os.path.isfile(key_path):
    with open(key_path) as f:
      return f.read()
  return None
@dispatcher.add_method
def getSshAuthorizedKeys():
  """Return the "GithubSshKeys" param, or '' when it is unset or empty."""
  keys = Params().get("GithubSshKeys", encoding='utf8')
  return keys or ''
@dispatcher.add_method
def getSimInfo():
  """Return whatever HARDWARE.get_sim_info() reports."""
  sim_info = HARDWARE.get_sim_info()
  return sim_info
@dispatcher.add_method
def getNetworkType():
  """Return whatever HARDWARE.get_network_type() reports."""
  network_type = HARDWARE.get_network_type()
  return network_type
@dispatcher.add_method
def getNetworkMetered():
  """Return HARDWARE's metered verdict for the current network type."""
  return HARDWARE.get_network_metered(HARDWARE.get_network_type())
@dispatcher.add_method
def getNetworks():
  """Return whatever HARDWARE.get_networks() reports."""
  networks = HARDWARE.get_networks()
  return networks
@dispatcher.add_method
def takeSnapshot():
  """Take a camera snapshot and return both frames as base64 JPEG strings.

  Returns {'jpegBack': ..., 'jpegFront': ...}; a frame that is None stays
  None. Raises when snapshot() returns None.
  """
  from selfdrive.camerad.snapshot.snapshot import jpeg_write, snapshot
  ret = snapshot()
  if ret is None:
    raise Exception("not available while camerad is started")

  def b64jpeg(x):
    if x is None:
      return None
    f = io.BytesIO()
    jpeg_write(f, x)
    return base64.b64encode(f.getvalue()).decode("utf-8")

  back, front = b64jpeg(ret[0]), b64jpeg(ret[1])
  return {'jpegBack': back,
          'jpegFront': front}
def get_logs_to_send_sorted():
  """Return the names of swaglog files that still need sending, sorted.

  A file qualifies when it has no usable send timestamp in its xattr or was
  sent more than an hour ago. The last name in sort order is left out.
  """
  # TODO: scan once then use inotify to detect file creation/deletion
  now = int(time.time())
  pending = []
  for name in os.listdir(SWAGLOG_DIR):
    path = os.path.join(SWAGLOG_DIR, name)
    try:
      sent_at = int.from_bytes(getxattr(path, LOG_ATTR_NAME), sys.byteorder)
    except (ValueError, TypeError):
      sent_at = 0
    # assume send failed and we lost the response if sent more than one hour ago
    if not sent_at or now - sent_at > 3600:
      pending.append(name)
  # excluding most recent (active) log file
  pending.sort()
  return pending[:-1]
def log_handler(end_event):
  """Forward swaglog files to the backend, one at a time (no-op on PC).

  The list of pending files is rescanned at most every 10 seconds. Each
  pass sends the newest pending file as a "forwardLogs" request on the low
  priority queue, then processes responses from log_recv_queue. Files
  acknowledged with success are marked so they are not sent again.
  """
  if PC:
    return

  pending = []
  last_scan = 0
  while not end_event.is_set():
    try:
      now_boot = sec_since_boot()
      if now_boot - last_scan > 10:
        pending = get_logs_to_send_sorted()
        last_scan = now_boot

      # send one log
      curr_log = None
      if len(pending) > 0:
        log_entry = pending.pop()  # newest log file
        cloudlog.debug(f"athena.log_handler.forward_request {log_entry}")
        try:
          sent_at = int(time.time())
          log_path = os.path.join(SWAGLOG_DIR, log_entry)
          # record the send time before sending
          setxattr(log_path, LOG_ATTR_NAME, int.to_bytes(sent_at, 4, sys.byteorder))
          with open(log_path) as f:
            request = {
              "method": "forwardLogs",
              "params": {
                "logs": f.read()
              },
              "jsonrpc": "2.0",
              "id": log_entry
            }
            low_priority_send_queue.put_nowait(json.dumps(request))
            curr_log = log_entry
        except OSError:
          pass  # file could be deleted by log rotation

      # wait for response up to ~100 seconds
      # always read queue at least once to process any old responses that arrive
      for _ in range(100):
        if end_event.is_set():
          break
        try:
          log_resp = json.loads(log_recv_queue.get(timeout=1))
          resp_id = log_resp.get("id")
          acked = "result" in log_resp and log_resp["result"].get("success")
          cloudlog.debug(f"athena.log_handler.forward_response {resp_id} {acked}")
          if resp_id and acked:
            acked_path = os.path.join(SWAGLOG_DIR, resp_id)
            try:
              setxattr(acked_path, LOG_ATTR_NAME, LOG_ATTR_VALUE_MAX_UNIX_TIME)
            except OSError:
              pass  # file could be deleted by log rotation
          if curr_log == resp_id:
            # got the answer for the log sent in this pass
            break
        except queue.Empty:
          # nothing outstanding from this pass: no need to keep waiting
          if curr_log is None:
            break

    except Exception:
      cloudlog.exception("athena.log_handler.exception")
def stat_handler(end_event):
  """Forward stat files from STATS_DIR to the backend until `end_event` is set.

  At most once every 10 seconds the directory is listed and, if it holds a
  file whose name does not start with the tempfile prefix, the first such
  file is sent as a "storeStats" request on the low priority queue and then
  deleted. Errors are logged and the loop keeps running.
  """
  # Kept outside the loop: resetting it on every iteration would make the
  # 10 second throttle below a no-op and list the directory every 0.1 s.
  last_scan = 0
  while not end_event.is_set():
    curr_scan = sec_since_boot()
    try:
      if curr_scan - last_scan > 10:
        temp_prefix = tempfile.gettempprefix()
        stat_filenames = [name for name in os.listdir(STATS_DIR) if not name.startswith(temp_prefix)]
        if stat_filenames:
          stat_path = os.path.join(STATS_DIR, stat_filenames[0])
          with open(stat_path) as f:
            jsonrpc = {
              "method": "storeStats",
              "params": {
                "stats": f.read()
              },
              "jsonrpc": "2.0",
              "id": stat_filenames[0]
            }
            low_priority_send_queue.put_nowait(json.dumps(jsonrpc))
          # queued for sending; remove it so it is not picked up again
          os.remove(stat_path)
        last_scan = curr_scan
    except Exception:
      cloudlog.exception("athena.stat_handler.exception")
    time.sleep(0.1)
def ws_proxy_recv(ws, local_sock, ssock, end_event, global_end_event):
  """Relay websocket data to `local_sock` until either end event is set.

  Websocket timeouts are ignored; any other error ends the loop. On exit
  `ssock` and `local_sock` are closed and `end_event` is set.
  """
  while True:
    if end_event.is_set() or global_end_event.is_set():
      break
    try:
      payload = ws.recv()
      local_sock.sendall(payload)
    except WebSocketTimeoutException:
      continue
    except Exception:
      cloudlog.exception("athenad.ws_proxy_recv.exception")
      break

  cloudlog.debug("athena.ws_proxy_recv closing sockets")
  # closing ssock is what signals ws_proxy_send through its end of the pair
  ssock.close()
  local_sock.close()
  cloudlog.debug("athena.ws_proxy_recv done closing sockets")

  end_event.set()
def ws_proxy_send(ws, local_sock, signal_sock, end_event):
  """Relay data from `local_sock` to the websocket until `end_event` is set.

  Stops when `signal_sock` becomes readable first, when `local_sock`
  returns no data, or on any error. Closes `signal_sock` on exit.
  """
  watched = (local_sock, signal_sock)
  while not end_event.is_set():
    try:
      readable, _, _ = select.select(watched, (), ())
      if not readable:
        continue

      if readable[0].fileno() == signal_sock.fileno():
        # got end signal from ws_proxy_recv
        end_event.set()
        break

      chunk = local_sock.recv(4096)
      if not chunk:
        # local_sock is dead
        end_event.set()
        break

      ws.send(chunk, ABNF.OPCODE_BINARY)
    except Exception:
      cloudlog.exception("athenad.ws_proxy_send.exception")
      end_event.set()

  cloudlog.debug("athena.ws_proxy_send closing sockets")
  signal_sock.close()
  cloudlog.debug("athena.ws_proxy_send done closing sockets")
def ws_recv(ws, end_event):
  """Receive websocket frames until `end_event` is set.

  Text frames (decoded as UTF-8) and binary frames go to recv_queue. Ping
  frames refresh the "LastAthenaPingTime" param. If a receive times out
  and the last ping is older than RECONNECT_TIMEOUT_S, or any other error
  occurs, `end_event` is set.
  """
  def now_ns():
    return int(sec_since_boot() * 1e9)

  last_ping = now_ns()
  while not end_event.is_set():
    try:
      opcode, data = ws.recv_data(control_frame=True)
      if opcode == ABNF.OPCODE_TEXT:
        recv_queue.put_nowait(data.decode("utf-8"))
      elif opcode == ABNF.OPCODE_BINARY:
        recv_queue.put_nowait(data)
      elif opcode == ABNF.OPCODE_PING:
        last_ping = now_ns()
        Params().put("LastAthenaPingTime", str(last_ping))
    except WebSocketTimeoutException:
      ns_since_last_ping = now_ns() - last_ping
      if ns_since_last_ping > RECONNECT_TIMEOUT_S * 1e9:
        cloudlog.exception("athenad.ws_recv.timeout")
        end_event.set()
    except Exception:
      cloudlog.exception("athenad.ws_recv.exception")
      end_event.set()
def ws_send(ws, end_event):
  """Send queued messages over the websocket until `end_event` is set.

  send_queue is always tried first; low_priority_send_queue is only read
  (waiting up to a second) when it is empty. Each message is split into
  frames of at most WS_FRAME_SIZE elements. Any send error sets
  `end_event`.
  """
  while not end_event.is_set():
    try:
      try:
        data = send_queue.get_nowait()
      except queue.Empty:
        data = low_priority_send_queue.get(timeout=1)

      total = len(data)
      for start in range(0, total, WS_FRAME_SIZE):
        stop = start + WS_FRAME_SIZE
        # first frame carries the text opcode, the rest are continuations
        opcode = ABNF.OPCODE_TEXT if start == 0 else ABNF.OPCODE_CONT
        is_final = stop >= total
        ws.send_frame(ABNF.create_frame(data[start:stop], opcode, is_final))
    except queue.Empty:
      pass
    except Exception:
      cloudlog.exception("athenad.ws_send.exception")
      end_event.set()
def backoff(retries):
  """Return a random delay in seconds for reconnect attempt number `retries`.

  The delay is drawn uniformly from [0, min(128, 2 ** retries)), so it is
  always 0 for `retries` <= 0 and never reaches 128.
  """
  # Clamp the exponent rather than the result: 2 ** 7 already hits the 128
  # cap, so there is no need to build an arbitrarily large power after many
  # failed attempts. A negative count is treated as 0 instead of making
  # randrange(0, 0) raise ValueError.
  exponent = max(0, min(retries, 7))
  return random.randrange(0, int(2 ** exponent))
def main():
  """Connect to the athena websocket and serve it, reconnecting forever.

  Each iteration opens the websocket and runs handle_long_poll until the
  connection ends. Failures update the "PrimeRedirected" and
  "LastAthenaPingTime" params, and the loop sleeps for backoff(conn_retries)
  before the next attempt. KeyboardInterrupt/SystemExit end the loop.
  """
  params = Params()
  dongle_id = params.get("DongleId", encoding='utf-8')
  UploadQueueCache.initialize(upload_queue)

  ws_uri = ATHENA_HOST + "/ws/v2/" + dongle_id
  api = Api(dongle_id)

  conn_retries = 0
  while True:
    try:
      cloudlog.event("athenad.main.connecting_ws", ws_uri=ws_uri)
      ws = create_connection(ws_uri,
                             cookie="jwt=" + api.get_token(),
                             enable_multithread=True,
                             timeout=30.0)
      cloudlog.event("athenad.main.connected_ws", ws_uri=ws_uri)
      params.delete("PrimeRedirected")

      # connected: start over with the backoff and forget stale upload state
      conn_retries = 0
      cur_upload_items.clear()

      handle_long_poll(ws)
    except (KeyboardInterrupt, SystemExit):
      break
    except (ConnectionError, TimeoutError, WebSocketException):
      conn_retries += 1
      params.delete("PrimeRedirected")
      params.delete("LastAthenaPingTime")
    except socket.timeout:
      # NOTE(review): socket.timeout is an alias of TimeoutError since
      # Python 3.10, in which case the clause above catches it first and
      # this branch never runs; confirm the interpreter version in use.
      try:
        # a 302 to this host marks the connection as redirected
        probe = requests.get("http://api.commadotai.com/v1/me", allow_redirects=False,
                             headers={"User-Agent": f"openpilot-{get_version()}"}, timeout=15.0)
        if probe.status_code == 302 and probe.headers['Location'].startswith("http://u.web2go.com"):
          params.put_bool("PrimeRedirected", True)
      except Exception:
        cloudlog.exception("athenad.socket_timeout.exception")
      params.delete("LastAthenaPingTime")
    except Exception:
      cloudlog.exception("athenad.main.exception")

      conn_retries += 1
      params.delete("PrimeRedirected")
      params.delete("LastAthenaPingTime")

    time.sleep(backoff(conn_retries))
# Run the daemon only when executed as a script, not when imported
if __name__ == "__main__":
  main()