lila/app/controllers/LilaSocket.scala
2017-10-09 17:48:56 -04:00

75 lines
2.4 KiB
Scala

package controllers
import play.api.libs.iteratee._
import play.api.mvc._
import play.api.mvc.WebSocket.FrameFormatter
import lila.api.Context
import lila.app._
import lila.common.{ HTTPRequest, IpAddress }
trait LilaSocket { self: LilaController =>
private type Pipe[A] = (Iteratee[A, _], Enumerator[A])
private val notFoundResponse = NotFound(jsonError("socket resource not found"))
protected def SocketEither[A: FrameFormatter](f: Context => Fu[Either[Result, Pipe[A]]]) =
WebSocket.tryAccept[A] { req =>
SocketCSRF(req) {
reqToCtx(req) flatMap f
}
}
protected def Socket[A: FrameFormatter](f: Context => Fu[Pipe[A]]) =
SocketEither[A] { ctx =>
f(ctx) map scala.util.Right.apply
}
protected def SocketOption[A: FrameFormatter](f: Context => Fu[Option[Pipe[A]]]) =
SocketEither[A] { ctx =>
f(ctx).map(_ toRight notFoundResponse)
}
protected def SocketOptionLimited[A: FrameFormatter](limiter: lila.memo.RateLimit[IpAddress], name: String)(f: Context => Fu[Option[Pipe[A]]]) =
rateLimitedSocket[A](limiter, name) { ctx =>
f(ctx).map(_ toRight notFoundResponse)
}
private type AcceptType[A] = Context => Fu[Either[Result, Pipe[A]]]
private val rateLimitLogger = lila.log("ratelimit")
private def rateLimitedSocket[A: FrameFormatter](limiter: lila.memo.RateLimit[IpAddress], name: String)(f: AcceptType[A]): WebSocket[A, A] =
WebSocket[A, A] { req =>
SocketCSRF(req) {
reqToCtx(req) flatMap { ctx =>
val ip = HTTPRequest lastRemoteAddress req
def userInfo = {
val sri = get("sri", req) | "none"
val username = ctx.usernameOrAnon
s"user:$username sri:$sri"
}
f(ctx).map { resultOrSocket =>
resultOrSocket.right.map {
case (readIn, writeOut) => (e: Enumerator[A], i: Iteratee[A, Unit]) => {
writeOut |>> i
e &> Enumeratee.mapInput { in =>
if (limiter(ip, 1)(true)) in
else {
rateLimitLogger.info(s"socket:$name socket close $ip $userInfo $in")
Input.EOF
}
} |>> readIn
()
}
}
}
}
}
}
private def SocketCSRF[A](req: RequestHeader)(f: => Fu[Either[Result, A]]): Fu[Either[Result, A]] =
if (csrfCheck(req)) f else csrfForbiddenResult map Left.apply
}