relocate websocket controller code

This commit is contained in:
Thibault Duplessis 2016-12-04 13:59:12 +01:00
parent 47b1b7d08a
commit a7169d9627
2 changed files with 39 additions and 41 deletions

View file

@ -6,7 +6,6 @@ import play.api.http._
import play.api.libs.iteratee.{ Iteratee, Enumerator }
import play.api.libs.json.{ Json, JsValue, JsObject, JsArray, Writes }
import play.api.mvc._, Results._
import play.api.mvc.WebSocket.FrameFormatter
import play.twirl.api.Html
import scalaz.Monoid
@ -52,38 +51,6 @@ private[controllers] trait LilaController
"X-Frame-Options" -> "SAMEORIGIN"
)
protected def Socket[A: FrameFormatter](f: Context => Fu[(Iteratee[A, _], Enumerator[A])]) =
WebSocket.tryAccept[A] { req =>
SocketCSRF(req) {
reqToCtx(req) flatMap f map scala.util.Right.apply
}
}
protected def SocketEither[A: FrameFormatter](f: Context => Fu[Either[Result, (Iteratee[A, _], Enumerator[A])]]) =
WebSocket.tryAccept[A] { req =>
SocketCSRF(req) {
reqToCtx(req) flatMap f
}
}
protected def SocketOption[A: FrameFormatter](f: Context => Fu[Option[(Iteratee[A, _], Enumerator[A])]]) =
WebSocket.tryAccept[A] { req =>
SocketCSRF(req) {
reqToCtx(req) flatMap f map {
case None => Left(NotFound(jsonError("socket resource not found")))
case Some(pair) => Right(pair)
}
}
}
protected def SocketOptionLimited[A: FrameFormatter](consumer: TokenBucket.Consumer, name: String)(f: Context => Fu[Option[(Iteratee[A, _], Enumerator[A])]]) =
rateLimitedSocket[A](consumer, name) { ctx =>
f(ctx) map {
case None => Left(NotFound(jsonError("socket resource not found")))
case Some(pair) => Right(pair)
}
}
protected def Open(f: Context => Fu[Result]): Action[Unit] =
Open(BodyParsers.parse.empty)(f)
@ -357,15 +324,12 @@ private[controllers] trait LilaController
}
}
private val csrfCheck = Env.security.csrfRequestHandler.check _
private val csrfForbiddenResult = Forbidden("Cross origin request forbidden").fuccess
protected val csrfCheck = Env.security.csrfRequestHandler.check _
protected val csrfForbiddenResult = Forbidden("Cross origin request forbidden").fuccess
private def CSRF(req: RequestHeader)(f: => Fu[Result]): Fu[Result] =
if (csrfCheck(req)) f else csrfForbiddenResult
protected def SocketCSRF[A](req: RequestHeader)(f: => Fu[Either[Result, A]]): Fu[Either[Result, A]] =
if (csrfCheck(req)) f else csrfForbiddenResult map Left.apply
protected def XhrOnly(res: => Fu[Result])(implicit ctx: Context) =
if (HTTPRequest isXhr ctx.req) res else notFound

View file

@ -12,9 +12,41 @@ import lila.common.HTTPRequest
trait LilaSocket { self: LilaController =>
def Socket[A: FrameFormatter](f: Context => Fu[(Iteratee[A, _], Enumerator[A])]) =
WebSocket.tryAccept[A] { req =>
SocketCSRF(req) {
reqToCtx(req) flatMap f map scala.util.Right.apply
}
}
def SocketEither[A: FrameFormatter](f: Context => Fu[Either[Result, (Iteratee[A, _], Enumerator[A])]]) =
WebSocket.tryAccept[A] { req =>
SocketCSRF(req) {
reqToCtx(req) flatMap f
}
}
def SocketOption[A: FrameFormatter](f: Context => Fu[Option[(Iteratee[A, _], Enumerator[A])]]) =
WebSocket.tryAccept[A] { req =>
SocketCSRF(req) {
reqToCtx(req) flatMap f map {
case None => Left(NotFound(jsonError("socket resource not found")))
case Some(pair) => Right(pair)
}
}
}
def SocketOptionLimited[A: FrameFormatter](consumer: TokenBucket.Consumer, name: String)(f: Context => Fu[Option[(Iteratee[A, _], Enumerator[A])]]) =
rateLimitedSocket[A](consumer, name) { ctx =>
f(ctx) map {
case None => Left(NotFound(jsonError("socket resource not found")))
case Some(pair) => Right(pair)
}
}
private type AcceptType[A] = Context => Fu[Either[Result, (Iteratee[A, _], Enumerator[A])]]
private val logger = lila.log("ratelimit")
private val rateLimitLogger = lila.log("ratelimit")
def rateLimitedSocket[A: FrameFormatter](consumer: TokenBucket.Consumer, name: String)(f: AcceptType[A]): WebSocket[A, A] =
WebSocket[A, A] { req =>
@ -26,7 +58,6 @@ trait LilaSocket { self: LilaController =>
val username = ctx.usernameOrAnon
s"user:$username sri:$sri"
}
// logger.debug(s"socket:$name socket connect $ip $userInfo")
f(ctx).map { resultOrSocket =>
resultOrSocket.right.map {
case (readIn, writeOut) => (e, i) => {
@ -35,7 +66,7 @@ trait LilaSocket { self: LilaController =>
consumer(ip).map { credit =>
if (credit >= 0) in
else {
logger.info(s"socket:$name socket close $ip $userInfo $in")
rateLimitLogger.info(s"socket:$name socket close $ip $userInfo $in")
Input.EOF
}
}
@ -46,4 +77,7 @@ trait LilaSocket { self: LilaController =>
}
}
}
private def SocketCSRF[A](req: RequestHeader)(f: => Fu[Either[Result, A]]): Fu[Either[Result, A]] =
if (csrfCheck(req)) f else csrfForbiddenResult map Left.apply
}