diff --git a/app/controllers/OAuth.scala b/app/controllers/OAuth.scala index 3e540e23fb..8d2b48f27d 100644 --- a/app/controllers/OAuth.scala +++ b/app/controllers/OAuth.scala @@ -121,7 +121,7 @@ final class OAuth(env: Env) extends LilaController(env) { def tokenRevoke = Scoped() { implicit req => _ => HTTPRequest.bearer(req) ?? { token => - env.oAuth.tokenApi.revoke(AccessToken.Id.from(token)) inject NoContent + env.oAuth.tokenApi.revoke(token) inject NoContent } } diff --git a/app/controllers/OAuthToken.scala b/app/controllers/OAuthToken.scala index 4480ff728c..28eb183e1f 100644 --- a/app/controllers/OAuthToken.scala +++ b/app/controllers/OAuthToken.scala @@ -39,7 +39,7 @@ final class OAuthToken(env: Env) extends LilaController(env) { } def delete(id: String) = - Auth { _ => _ => - tokenApi.revoke(AccessToken.Id(id)) inject Redirect(routes.OAuthToken.index).flashSuccess + Auth { _ => me => + tokenApi.revokeById(AccessToken.Id(id), me) inject Redirect(routes.OAuthToken.index).flashSuccess } } diff --git a/modules/oauth/src/main/AccessTokenApi.scala b/modules/oauth/src/main/AccessTokenApi.scala index effbfcfd4b..e5fefeaa47 100644 --- a/modules/oauth/src/main/AccessTokenApi.scala +++ b/modules/oauth/src/main/AccessTokenApi.scala @@ -104,8 +104,15 @@ final class AccessTokenApi(coll: Coll, cacheApi: lila.memo.CacheApi, userRepo: U } yield AccessTokenApi.Client(origin, usedAt, scopes) } - def revoke(id: AccessToken.Id): Funit = - coll.delete.one($id(id)).map(_ => invalidateCached(id)) + def revokeById(id: AccessToken.Id, user: User): Funit = + coll.delete + .one( + $doc( + F.id -> id, + F.userId -> user.id + ) + ) + .map(_ => invalidateCached(id)) def revokeByClientOrigin(clientOrigin: String, user: User): Funit = coll @@ -130,6 +137,11 @@ final class AccessTokenApi(coll: Coll, cacheApi: lila.memo.CacheApi, userRepo: U .map(_ => invalidate.flatMap(_.getAsOpt[AccessToken.Id](F.id)).foreach(invalidateCached)) } + def revoke(bearer: Bearer) = { + val id = AccessToken.Id.from(bearer) + coll.delete.one($id(id)).map(_ => invalidateCached(id)) + } + def get(bearer: Bearer) = accessTokenCache.get(AccessToken.Id.from(bearer)) private val accessTokenCache =