lila/modules/socket/src/main/RemoteSocket.scala

328 lines
12 KiB
Scala

package lila.socket
import akka.actor.{ ActorSystem, CoordinatedShutdown }
import cats.data.NonEmptyList
import chess.{ Centis, Color }
import io.lettuce.core._
import io.lettuce.core.pubsub.{ StatefulRedisPubSubConnection => PubSub }
import java.util.concurrent.atomic.AtomicReference
import java.util.concurrent.ConcurrentHashMap
import play.api.libs.json._
import scala.concurrent.duration._
import scala.concurrent.{ Future, Promise }
import scala.util.chaining._
import Socket.Sri
import lila.common.{ Bus, Lilakka }
import lila.hub.actorApi.Announce
import lila.hub.actorApi.relation.{ Follow, UnFollow }
import lila.hub.actorApi.round.Mlat
import lila.hub.actorApi.security.CloseAccount
import lila.hub.actorApi.socket.remote.{ TellSriIn, TellSriOut, TellUserIn }
import lila.hub.actorApi.socket.{ ApiUserIsOnline, SendTo, SendTos }
/** Exchanges text messages with the remote websocket server (referred to as lila-ws in the log
  * messages below) over redis pub/sub channels.
  *
  * On construction it subscribes to a set of [[lila.common.Bus]] channels and forwards the
  * matching events to the `site-out` redis channel, and it registers two shutdown hooks.
  */
final class RemoteSocket(
    redisClient: RedisClient,
    notification: lila.hub.actors.Notification,
    shutdown: CoordinatedShutdown
)(implicit
    ec: scala.concurrent.ExecutionContext,
    system: ActorSystem
) {

  import RemoteSocket._, Protocol._

  // Set from a Future in the shutdown hook and read by the senders on whatever thread calls them:
  // it has to be volatile for the write to be reliably visible to those readers.
  @volatile private var stopping = false

  private type UserIds = Set[String]

  // Pending request/response exchanges, keyed by request id.
  private val requests = new ConcurrentHashMap[Int, Promise[String]](32)

  /** Sends a request carrying a fresh id and returns a future of the decoded response.
    *
    * @param sendReq publishes the request; it receives the generated request id
    * @param readRes decodes the raw response string
    * @return a future completed when an `In.ReqResponse` with the same id reaches `baseHandler`
    */
  def request[R](sendReq: Int => Unit, readRes: String => R): Fu[R] = {
    val id      = lila.common.ThreadLocalRandom.nextPositiveInt()
    val promise = Promise[String]()
    // Register the promise BEFORE sending: otherwise a response arriving between the send and
    // the registration would find no pending entry and the future would never complete.
    requests.put(id, promise)
    try sendReq(id)
    catch {
      case scala.util.control.NonFatal(e) =>
        // nothing was sent, so no response will ever come: don't leave the entry behind
        requests.remove(id)
        throw e
    }
    promise.future map readRes
  }

  /** User ids currently considered online; starts with (and is reset on `WsBoot` to) "lichess". */
  val onlineUserIds: AtomicReference[Set[String]] = new AtomicReference(Set("lichess"))

  /** Handler for the incoming messages produced by `In.baseReader`. */
  val baseHandler: Handler = {
    case In.ConnectUser(userId) =>
      onlineUserIds.getAndUpdate(_ + userId).unit
    case In.DisconnectUsers(userIds) =>
      onlineUserIds.getAndUpdate(_ -- userIds).unit
    case In.NotifiedBatch(userIds) => notification ! lila.hub.actorApi.notify.NotifiedBatch(userIds)
    case In.Lags(lags) =>
      lags foreach (UserLagCache.put _).tupled
      // this shouldn't be necessary... ensure that users are known to be online
      onlineUserIds.getAndUpdate((x: UserIds) => x ++ lags.keys).unit
    case In.TellSri(sri, userId, typ, msg) =>
      Bus.publish(TellSriIn(sri.value, userId, msg), s"remoteSocketIn:$typ")
    case In.TellUser(userId, typ, msg) =>
      Bus.publish(TellUserIn(userId, msg), s"remoteSocketIn:$typ")
    case In.ReqResponse(reqId, response) =>
      // complete the pending request, if any, and drop it from the map in one atomic step
      requests
        .computeIfPresent(
          reqId,
          (_: Int, promise: Promise[String]) => {
            promise success response
            null // remove from promises
          }
        )
        .unit
    case In.Ping(id) => send(Out.pong(id))
    case In.WsBoot =>
      logger.warn("Remote socket boot")
      onlineUserIds set Set("lichess")
  }

  // Forward the Bus events listed below to the "site-out" channel.
  Bus.subscribeFun(
    "socketUsers",
    "announce",
    "mlat",
    "sendToFlag",
    "remoteSocketOut",
    "accountClose",
    "shadowban",
    "impersonate",
    "relation",
    "onlineApiUsers"
  ) {
    case SendTos(userIds, payload) =>
      // only address users that are known to be online
      val connectedUsers = userIds intersect onlineUserIds.get
      if (connectedUsers.nonEmpty) send(Out.tellUsers(connectedUsers, payload))
    case SendTo(userId, payload) if onlineUserIds.get.contains(userId) =>
      send(Out.tellUser(userId, payload))
    case Announce(_, _, json) =>
      send(Out.tellAll(Json.obj("t" -> "announce", "d" -> json)))
    case Mlat(micros) =>
      send(Out.mlat(micros))
    case Socket.SendToFlag(flag, payload) =>
      send(Out.tellFlag(flag, payload))
    case TellSriOut(sri, payload) =>
      send(Out.tellSri(Sri(sri), payload))
    case CloseAccount(userId) =>
      send(Out.disconnectUser(userId))
    case lila.hub.actorApi.mod.Shadowban(userId, v) =>
      send(Out.setTroll(userId, v))
    case lila.hub.actorApi.mod.Impersonate(userId, modId) =>
      send(Out.impersonate(userId, modId))
    case ApiUserIsOnline(userId, value) =>
      send(Out.apiUserOnline(userId, value))
      if (value) onlineUserIds.getAndUpdate(_ + userId).unit
    case Follow(u1, u2)   => send(Out.follow(u1, u2))
    case UnFollow(u1, u2) => send(Out.unfollow(u1, u2))
  }

  /** Publishes every message to a single channel; messages are dropped once `stopping` is set. */
  final class StoppableSender(conn: PubSub[String, String], channel: Channel) extends Sender {
    def apply(msg: String): Unit               = if (!stopping) conn.async.publish(channel, msg).unit
    def sticky(_id: String, msg: String): Unit = apply(msg)
  }

  /** Spreads messages over the sub-channels `channel:0` .. `channel:(parallelism - 1)`;
    * messages are dropped once `stopping` is set.
    */
  final class RoundRobinSender(conn: PubSub[String, String], channel: Channel, parallelism: Int)
      extends Sender {
    def apply(msg: String): Unit = publish(subChannelOf(msg), msg)
    // use the ID to select the channel, not the entire message
    def sticky(id: String, msg: String): Unit = publish(subChannelOf(id), msg)

    // `(h % p).abs` rather than `h.abs % p`: both agree for every hash except Int.MinValue, whose
    // `.abs` is still negative and would select a negative sub-channel index.
    private def subChannelOf(key: String): Int = (key.hashCode % parallelism).abs

    private def publish(subChannel: Int, msg: String) =
      if (!stopping) conn.async.publish(s"$channel:$subChannel", msg).unit
  }

  /** Creates a sender on a new pub/sub connection: round-robin over sub-channels when
    * `parallelism > 1`, single channel otherwise.
    */
  def makeSender(channel: Channel, parallelism: Int = 1): Sender =
    if (parallelism > 1) new RoundRobinSender(redisClient.connectPubSub(), channel, parallelism)
    else new StoppableSender(redisClient.connectPubSub(), channel)

  private val send: Send = makeSender("site-out").apply _

  /** Subscribes to `channel`, decodes each message with `reader` and passes it to `handler`.
    * Messages that the reader cannot decode, or that the handler is not defined for, are logged.
    *
    * @return a future completed once the subscription is acknowledged
    */
  def subscribe(channel: Channel, reader: In.Reader)(handler: Handler): Funit =
    connectAndSubscribe(channel) { message =>
      reader(RawMsg(message)) collect handler match {
        case Some(_) => // processed
        case None    => logger.warn(s"Unhandled $channel $message")
      }
    }

  /** Subscribes to `channel` and then to its numbered sub-channels `channel:N`. */
  def subscribeRoundRobin(channel: Channel, reader: In.Reader, parallelism: Int)(
      handler: Handler
  ): Funit =
    // subscribe to main channel
    subscribe(channel, reader)(handler) >> {
      // and subscribe to subchannels
      // NOTE(review): `0 to parallelism` is inclusive, i.e. parallelism + 1 sub-channels, whereas
      // RoundRobinSender only ever targets 0 until parallelism. Kept as is because the remote
      // publisher is not visible from here - confirm whether the last one is ever used.
      (0 to parallelism)
        .map { index =>
          subscribe(s"$channel:$index", reader)(handler)
        }
        .sequenceFu
        .void
    }

  // Opens a dedicated pub/sub connection, routes every received message to `f`, and returns a
  // future that completes when the SUBSCRIBE command completes, or fails if that command fails.
  private def connectAndSubscribe(channel: Channel)(f: String => Unit): Funit = {
    val conn = redisClient.connectPubSub()
    conn.addListener(new pubsub.RedisPubSubAdapter[String, String] {
      override def message(_channel: String, message: String): Unit = f(message)
    })
    val subPromise = Promise[Unit]()
    // whenComplete (instead of thenRun) so that a failed subscription fails the returned future
    // rather than leaving it pending forever. The error argument is null on success (Java API).
    conn.async.subscribe(channel).whenComplete { (_: Void, error: Throwable) =>
      Option(error) match {
        case Some(e) => subPromise.failure(e)
        case None    => subPromise.success(())
      }
    }
    subPromise.future
  }

  // Shutdown, PhaseBeforeServiceUnbind: send a stop request and wait at most one second for
  // the answer; failures are logged and ignored.
  Lilakka.shutdown(shutdown, _.PhaseBeforeServiceUnbind, "Telling lila-ws we're stopping") { () =>
    request[Unit](
      id => send(Protocol.Out.stop(id)),
      res => logger.info(s"lila-ws says: $res")
    ).withTimeout(1 second)
      .addFailureEffect(e => logger.error("lila-ws stop", e))
      .nevermind
  }

  // Shutdown, PhaseServiceUnbind: stop publishing, then shut the redis client down.
  Lilakka.shutdown(shutdown, _.PhaseServiceUnbind, "Stopping the socket redis pool") { () =>
    Future {
      stopping = true
      redisClient.shutdown()
    }
  }
}
/** Types and text protocol shared by the users of [[RemoteSocket]]. */
object RemoteSocket {

  private val logger = lila log "socket"

  /** Publishes one already-encoded message. */
  type Send = String => Unit

  trait Sender {

    /** Publishes `msg`. */
    def apply(msg: String): Unit

    /** Publishes `msg`; implementations may use `_id` to pick the channel. */
    def sticky(_id: String, msg: String): Unit
  }

  /** Encoding of outgoing messages (`Out`) and decoding of incoming ones (`In`).
    *
    * A message is a path, optionally followed by a space and its arguments.
    */
  object Protocol {

    /** An incoming message split into its path and its (still encoded) arguments. */
    final class RawMsg(val path: Path, val args: Args) {

      /** Splits the arguments on spaces into at most `nb` parts and applies `f` to them.
        * Returns None when `f` is not defined for the resulting array.
        */
      def get(nb: Int)(f: PartialFunction[Array[String], Option[In]]): Option[In] =
        f.applyOrElse(args.split(" ", nb), (_: Array[String]) => None)

      /** All the space-separated arguments. */
      def all = args split ' '
    }

    /** Splits `msg` at its first space into path and arguments; the arguments may be absent. */
    def RawMsg(msg: String): RawMsg = {
      val parts = msg.split(" ", 2)
      new RawMsg(parts(0), ~parts.lift(1))
    }

    /** A decoded incoming message. */
    trait In

    object In {

      /** Decodes a raw message; None when the message is not recognised or is malformed. */
      type Reader = RawMsg => Option[In]

      case object WsBoot                                                                          extends In
      case class ConnectUser(userId: String)                                                      extends In
      case class DisconnectUsers(userId: Iterable[String])                                        extends In
      case class ConnectSris(cons: Iterable[(Sri, Option[String])])                               extends In
      case class DisconnectSris(sris: Iterable[Sri])                                              extends In
      case class NotifiedBatch(userIds: Iterable[String])                                         extends In
      case class Lag(userId: String, lag: Centis)                                                 extends In
      case class Lags(lags: Map[String, Centis])                                                  extends In
      case class TellSri(sri: Sri, userId: Option[String], typ: String, msg: JsObject)            extends In
      case class TellUser(userId: String, typ: String, msg: JsObject)                             extends In
      case class ReqResponse(reqId: Int, response: String)                                        extends In
      case class Ping(id: String)                                                                 extends In

      /** Reader for the message paths handled by `RemoteSocket.baseHandler`. */
      val baseReader: Reader = raw =>
        raw.path match {
          case "connect/user"     => ConnectUser(raw.args).some
          case "disconnect/users" => DisconnectUsers(commas(raw.args)).some
          case "connect/sris" =>
            // comma-separated entries, each one "sri" or "sri userId"
            ConnectSris {
              commas(raw.args) map (_ split ' ') map { s =>
                (Sri(s(0)), s lift 1)
              }
            }.some
          case "disconnect/sris" => DisconnectSris(commas(raw.args) map Sri.apply).some
          case "notified/batch"  => NotifiedBatch(commas(raw.args)).some
          case "lag" =>
            // "userId lag"; None when the lag is missing or not an integer
            raw.all pipe { s =>
              s lift 1 flatMap (_.toIntOption) map Centis.apply map { Lag(s(0), _) }
            }
          case "lags" =>
            // comma-separated "userId:lag" pairs; malformed pairs are skipped
            Lags(commas(raw.args).flatMap {
              _ split ':' match {
                case Array(user, l) =>
                  l.toIntOption map { lag =>
                    user -> Centis(lag)
                  }
                case _ => None
              }
            }.toMap).some
          case "tell/sri" => raw.get(3)(tellSriMapper)
          case "tell/user" =>
            raw.get(2) { case Array(user, payload) =>
              for {
                obj <- parseObject(payload)
                typ <- obj str "t"
              } yield TellUser(user, typ, obj)
            }
          case "req/response" =>
            raw.get(2) { case Array(reqId, response) =>
              reqId.toIntOption map { ReqResponse(_, response) }
            }
          case "ping" => Ping(raw.args).some
          case "boot" => WsBoot.some
          case _      => none
        }

      /** Decodes the three arguments of a `tell/sri` message: sri, optional user id, JSON payload.
        * Yields None when the payload is not a JSON object with a string field "t".
        */
      def tellSriMapper: PartialFunction[Array[String], Option[TellSri]] = { case Array(sri, user, payload) =>
        for {
          obj <- parseObject(payload)
          typ <- obj str "t"
        } yield TellSri(Sri(sri), optional(user), typ, obj)
      }

      // Json.parse throws on malformed input; turn that into None so that a corrupt payload is
      // reported as an undecodable message instead of throwing out of the reader.
      private def parseObject(payload: String): Option[JsObject] =
        scala.util.Try(Json.parse(payload)).toOption.flatMap(_.asOpt[JsObject])

      /** Decodes a comma-separated list; "-" stands for the empty list. */
      def commas(str: String): Array[String] = if (str == "-") Array.empty else str split ','

      /** Decodes a boolean: "+" is true, anything else is false. */
      def boolean(str: String): Boolean = str == "+"

      /** Decodes an optional value; "-" stands for None. */
      def optional(str: String): Option[String] = if (str == "-") None else Some(str)
    }

    /** Encoders for outgoing messages. */
    object Out {
      // a single user is sent with the same "tell/users" path as a list of users
      def tellUser(userId: String, payload: JsObject) =
        s"tell/users $userId ${Json stringify payload}"
      def tellUsers(userIds: Set[String], payload: JsObject) =
        s"tell/users ${commas(userIds)} ${Json stringify payload}"
      def tellAll(payload: JsObject) =
        s"tell/all ${Json stringify payload}"
      def tellFlag(flag: String, payload: JsObject) =
        s"tell/flag $flag ${Json stringify payload}"
      def tellSri(sri: Sri, payload: JsValue) =
        s"tell/sri $sri ${Json stringify payload}"
      def tellSris(sris: Iterable[Sri], payload: JsValue) =
        s"tell/sris ${commas(sris)} ${Json stringify payload}"
      // integer division by 100, then by 10.0: the value divided by 1000, truncated to one decimal
      def mlat(micros: Int) =
        s"mlat ${((micros / 100) / 10d)}"
      def disconnectUser(userId: String) =
        s"disconnect/user $userId"
      def setTroll(userId: String, v: Boolean) =
        s"mod/troll/set $userId ${boolean(v)}"
      def impersonate(userId: String, by: Option[String]) =
        s"mod/impersonate $userId ${optional(by)}"
      def follow(u1: String, u2: String)       = s"rel/follow $u1 $u2"
      def unfollow(u1: String, u2: String)     = s"rel/unfollow $u1 $u2"
      def apiUserOnline(u: String, v: Boolean) = s"api/online $u ${boolean(v)}"
      def boot                                 = "boot"
      def pong(id: String)                     = s"pong $id"
      def stop(reqId: Int)                     = s"lila/stop $reqId"

      /** Encodes a list as comma-separated values; the empty list is "-". */
      def commas(strs: Iterable[Any]): String = if (strs.isEmpty) "-" else strs mkString ","

      /** Encodes a boolean as "+" (true) or "-" (false). */
      def boolean(v: Boolean): String = if (v) "+" else "-"

      /** Encodes an optional value; None is "-". */
      def optional(str: Option[String]) = str getOrElse "-"

      /** Encodes a color as "w" or "b". */
      def color(c: Color): String = c.fold("w", "b")

      /** Encodes an optional color as "w", "b", or "-" when absent. */
      def color(c: Option[Color]): String = optional(c.map(_.fold("w", "b")))
    }
  }

  type Channel = String
  type Path    = String
  type Args    = String

  /** Processes decoded incoming messages; may be undefined for some of them. */
  type Handler = PartialFunction[Protocol.In, Unit]
}