248 lines
9.6 KiB
Scala
248 lines
9.6 KiB
Scala
package lila.socket
|
|
|
|
import chess.Centis
|
|
import io.lettuce.core._
|
|
import io.lettuce.core.pubsub.StatefulRedisPubSubConnection
|
|
import java.util.concurrent.atomic.AtomicReference
|
|
import java.util.concurrent.ConcurrentHashMap
|
|
import ornicar.scalalib.Zero
|
|
import play.api.libs.json._
|
|
import scala.concurrent.duration._
|
|
import scala.concurrent.{ Promise, Future }
|
|
|
|
import lila.common.{ Bus, Chronometer }
|
|
import lila.hub.actorApi.relation.ReloadOnlineFriends
|
|
import lila.hub.actorApi.round.{ MoveEvent, Mlat }
|
|
import lila.hub.actorApi.security.CloseAccount
|
|
import lila.hub.actorApi.socket.remote.{ TellSriIn, TellSriOut }
|
|
import lila.hub.actorApi.socket.{ SendTo, SendTos }
|
|
import lila.hub.actorApi.{ Deploy, Announce }
|
|
import lila.hub.{ TrouperMap, Trouper }
|
|
import Socket.{ SocketVersion, GetVersion, Sri, SendToFlag }
|
|
|
|
final class RemoteSocket(
|
|
redisClient: RedisClient,
|
|
notificationActor: akka.actor.ActorSelection,
|
|
bus: lila.common.Bus,
|
|
lifecycle: play.api.inject.ApplicationLifecycle
|
|
) {
|
|
|
|
import RemoteSocket._, Protocol._
|
|
|
|
type UserIds = Set[String]
|
|
|
|
private val requests = new ConcurrentHashMap[Int, Promise[String]](32)
|
|
|
|
def request[R](sendReq: Int => Unit, readRes: String => R): Fu[R] = {
|
|
val id = Math.abs(scala.util.Random.nextInt)
|
|
sendReq(id)
|
|
val promise = Promise[String]
|
|
requests.put(id, promise)
|
|
promise.future map readRes
|
|
}
|
|
|
|
val onlineUserIds: AtomicReference[Set[String]] = new AtomicReference(Set("lichess"))
|
|
|
|
private val watchedGameIds = collection.mutable.Set.empty[String]
|
|
|
|
val baseHandler: Handler = {
|
|
case In.ConnectUser(userId) =>
|
|
bus.publish(lila.hub.actorApi.socket.remote.ConnectUser(userId), 'userActive)
|
|
onlineUserIds.getAndUpdate((x: UserIds) => x + userId)
|
|
case In.DisconnectUsers(userIds) =>
|
|
onlineUserIds.getAndUpdate((x: UserIds) => x -- userIds)
|
|
case In.Watch(gameId) => watchedGameIds += gameId
|
|
case In.Unwatch(gameId) => watchedGameIds -= gameId
|
|
case In.NotifiedBatch(userIds) => notificationActor ! lila.hub.actorApi.notify.NotifiedBatch(userIds)
|
|
case In.FriendsBatch(userIds) => userIds foreach { userId =>
|
|
bus.publish(ReloadOnlineFriends(userId), 'reloadOnlineFriends)
|
|
}
|
|
case In.Lags(lags) =>
|
|
lags foreach (UserLagCache.put _).tupled
|
|
// this shouldn't be necessary... ensure that users are known to be online
|
|
onlineUserIds.getAndUpdate((x: UserIds) => x ++ lags.keys)
|
|
case In.TellSri(sri, userId, typ, msg) =>
|
|
bus.publish(TellSriIn(sri.value, userId, msg), Symbol(s"remoteSocketIn:$typ"))
|
|
case In.WsBoot =>
|
|
logger.warn("Remote socket boot")
|
|
onlineUserIds set Set("lichess")
|
|
watchedGameIds.clear
|
|
case In.ReqResponse(reqId, response) =>
|
|
requests.computeIfPresent(reqId, (_: Int, promise: Promise[String]) => {
|
|
promise success response
|
|
null // remove from promises
|
|
})
|
|
}
|
|
|
|
bus.subscribeFun('moveEvent, 'socketUsers, 'deploy, 'announce, 'mlat, 'sendToFlag, 'remoteSocketOut, 'accountClose, 'shadowban, 'impersonate) {
|
|
case MoveEvent(gameId, fen, move) =>
|
|
if (watchedGameIds(gameId)) send(Out.move(gameId, move, fen))
|
|
case SendTos(userIds, payload) =>
|
|
val connectedUsers = userIds intersect onlineUserIds.get
|
|
if (connectedUsers.nonEmpty) send(Out.tellUsers(connectedUsers, payload))
|
|
case SendTo(userId, payload) if onlineUserIds.get.contains(userId) =>
|
|
send(Out.tellUser(userId, payload))
|
|
case Announce(_, _, json) =>
|
|
send(Out.tellAll(Json.obj("t" -> "announce", "d" -> json)))
|
|
case Mlat(micros) =>
|
|
send(Out.mlat(micros))
|
|
case Socket.SendToFlag(flag, payload) =>
|
|
send(Out.tellFlag(flag, payload))
|
|
case TellSriOut(sri, payload) =>
|
|
send(Out.tellSri(Sri(sri), payload))
|
|
case CloseAccount(userId) =>
|
|
send(Out.disconnectUser(userId))
|
|
case lila.hub.actorApi.mod.Shadowban(userId, v) =>
|
|
send(Out.setTroll(userId, v))
|
|
case lila.hub.actorApi.mod.Impersonate(userId, modId) =>
|
|
send(Out.impersonate(userId, modId))
|
|
}
|
|
|
|
def makeSender(channel: Channel): Sender = new Sender(redisClient.connectPubSub(), channel)
|
|
|
|
private val send: Send = makeSender("site-out").apply _
|
|
|
|
def subscribe(channel: Channel, reader: In.Reader)(handler: Handler): Future[Unit] = {
|
|
val conn = redisClient.connectPubSub()
|
|
conn.addListener(new pubsub.RedisPubSubAdapter[String, String] {
|
|
override def message(_channel: String, message: String): Unit =
|
|
reader(RawMsg(message)) collect handler match {
|
|
case Some(_) => // processed
|
|
case None => logger.warn(s"Unhandled $channel $message")
|
|
}
|
|
})
|
|
val subPromise = Promise[Unit]
|
|
conn.async.subscribe(channel).thenRun {
|
|
new Runnable { def run = subPromise.success(()) }
|
|
}
|
|
subPromise.future
|
|
}
|
|
|
|
lifecycle.addStopHook { () =>
|
|
logger.info("Stopping the Redis pool...")
|
|
Future {
|
|
redisClient.shutdown()
|
|
}
|
|
}
|
|
}
|
|
|
|
object RemoteSocket {
|
|
|
|
private val logger = lila log "socket"
|
|
|
|
type Send = String => Unit
|
|
|
|
final class Sender(conn: StatefulRedisPubSubConnection[String, String], channel: Channel) {
|
|
|
|
def apply(msg: String): Unit = conn.async.publish(channel, msg)
|
|
}
|
|
|
|
object Protocol {
|
|
|
|
final class RawMsg(val path: Path, val args: Args) {
|
|
def get(nb: Int)(f: PartialFunction[Array[String], Option[In]]): Option[In] =
|
|
f.applyOrElse(args.split(" ", nb), (_: Array[String]) => None)
|
|
def all = args split ' '
|
|
}
|
|
def RawMsg(msg: String): RawMsg = {
|
|
val parts = msg.split(" ", 2)
|
|
new RawMsg(parts(0), ~parts.lift(1))
|
|
}
|
|
|
|
trait In
|
|
object In {
|
|
|
|
type Reader = RawMsg => Option[In]
|
|
|
|
case object WsBoot extends In
|
|
case class ConnectUser(userId: String) extends In
|
|
case class DisconnectUsers(userId: Iterable[String]) extends In
|
|
case class ConnectSris(cons: Iterable[(Sri, Option[String])]) extends In
|
|
case class DisconnectSris(sris: Iterable[Sri]) extends In
|
|
case class Watch(gameId: String) extends In
|
|
case class Unwatch(gameId: String) extends In
|
|
case class NotifiedBatch(userIds: Iterable[String]) extends In
|
|
case class Lag(userId: String, lag: Centis) extends In
|
|
case class Lags(lags: Map[String, Centis]) extends In
|
|
case class FriendsBatch(userIds: Iterable[String]) extends In
|
|
case class TellSri(sri: Sri, userId: Option[String], typ: String, msg: JsObject) extends In
|
|
case class ReqResponse(reqId: Int, response: String) extends In
|
|
|
|
val baseReader: Reader = raw => raw.path match {
|
|
case "connect/user" => ConnectUser(raw.args).some
|
|
case "disconnect/users" => DisconnectUsers(commas(raw.args)).some
|
|
case "connect/sris" => ConnectSris {
|
|
commas(raw.args) map (_ split ' ') map { s =>
|
|
(Sri(s(0)), s lift 1)
|
|
}
|
|
}.some
|
|
case "disconnect/sris" => DisconnectSris(commas(raw.args) map Sri.apply).some
|
|
case "watch" => Watch(raw.args).some
|
|
case "unwatch" => Unwatch(raw.args).some
|
|
case "notified/batch" => NotifiedBatch(commas(raw.args)).some
|
|
case "lag" => raw.all |> { s => s lift 1 flatMap parseIntOption map Centis.apply map { Lag(s(0), _) } }
|
|
case "lags" => Lags(commas(raw.args).flatMap {
|
|
_ split ':' match {
|
|
case Array(user, l) => parseIntOption(l) map { lag => user -> Centis(lag) }
|
|
case _ => None
|
|
}
|
|
}.toMap).some
|
|
case "friends/batch" => FriendsBatch(commas(raw.args)).some
|
|
case "tell/sri" => raw.get(3)(tellSriMapper)
|
|
case "req/response" => raw.get(2) {
|
|
case Array(reqId, response) => parseIntOption(reqId) map { ReqResponse(_, response) }
|
|
}
|
|
case "boot" => WsBoot.some
|
|
case _ => none
|
|
}
|
|
|
|
def tellSriMapper: PartialFunction[Array[String], Option[TellSri]] = {
|
|
case Array(sri, user, payload) => for {
|
|
obj <- Json.parse(payload).asOpt[JsObject]
|
|
typ <- obj str "t"
|
|
} yield TellSri(Sri(sri), optional(user), typ, obj)
|
|
}
|
|
|
|
def commas(str: String): Array[String] = if (str == "-") Array.empty else str split ','
|
|
def boolean(str: String): Boolean = str == "+"
|
|
def optional(str: String): Option[String] = if (str == "-") None else Some(str)
|
|
}
|
|
|
|
object Out {
|
|
def move(gameId: String, move: String, fen: String) =
|
|
s"move $gameId $move $fen"
|
|
def tellUser(userId: String, payload: JsObject) =
|
|
s"tell/users $userId ${Json stringify payload}"
|
|
def tellUsers(userIds: Set[String], payload: JsObject) =
|
|
s"tell/users ${commas(userIds)} ${Json stringify payload}"
|
|
def tellAll(payload: JsObject) =
|
|
s"tell/all ${Json stringify payload}"
|
|
def tellFlag(flag: String, payload: JsObject) =
|
|
s"tell/flag $flag ${Json stringify payload}"
|
|
def tellSri(sri: Sri, payload: JsValue) =
|
|
s"tell/sri $sri ${Json stringify payload}"
|
|
def tellSris(sris: Iterable[Sri], payload: JsValue) =
|
|
s"tell/sris ${commas(sris)} ${Json stringify payload}"
|
|
def mlat(micros: Int) =
|
|
s"mlat ${((micros / 100) / 10d)}"
|
|
def disconnectUser(userId: String) =
|
|
s"disconnect/user $userId"
|
|
def setTroll(userId: String, v: Boolean) =
|
|
s"mod/troll/set $userId ${boolean(v)}"
|
|
def impersonate(userId: String, by: Option[String]) =
|
|
s"mod/impersonate $userId ${optional(by)}"
|
|
def boot = "boot"
|
|
|
|
def commas(strs: Iterable[Any]): String = if (strs.isEmpty) "-" else strs mkString ","
|
|
def boolean(v: Boolean): String = if (v) "+" else "-"
|
|
def color(c: chess.Color): String = c.fold("w", "b")
|
|
def optional(str: Option[String]) = str getOrElse "-"
|
|
}
|
|
}
|
|
|
|
type Channel = String
|
|
type Path = String
|
|
type Args = String
|
|
type Handler = PartialFunction[Protocol.In, Unit]
|
|
}
|