From 0e1d65445524b9819f701b67e11bebd03121964c Mon Sep 17 00:00:00 2001 From: Zach Smith Date: Thu, 11 May 2017 16:13:10 -0700 Subject: Permissions token refactors --- src/main/scala/xyz/driver/core/rest.scala | 166 +++++++++++++++++------------- 1 file changed, 96 insertions(+), 70 deletions(-) (limited to 'src/main/scala/xyz/driver/core/rest.scala') diff --git a/src/main/scala/xyz/driver/core/rest.scala b/src/main/scala/xyz/driver/core/rest.scala index bacb120..3f61246 100644 --- a/src/main/scala/xyz/driver/core/rest.scala +++ b/src/main/scala/xyz/driver/core/rest.scala @@ -27,13 +27,13 @@ import scalaz.{ListT, OptionT} package rest { object `package` { - import akka.http.scaladsl.server.{RequestContext => _, _} + import akka.http.scaladsl.server._ import Directives._ - def serviceContext: Directive1[RequestContext] = extract(ctx => extractServiceContext(ctx.request)) + def serviceContext: Directive1[ServiceRequestContext] = extract(ctx => extractServiceContext(ctx.request)) - def extractServiceContext(request: HttpRequest): RequestContext = - new RequestContext(extractTrackingId(request), extractContextHeaders(request)) + def extractServiceContext(request: HttpRequest): ServiceRequestContext = + new ServiceRequestContext(extractTrackingId(request), extractContextHeaders(request)) def extractTrackingId(request: HttpRequest): String = { request.headers @@ -92,38 +92,39 @@ package rest { } } - class RequestContext(val trackingId: String = generators.nextUuid().toString, - val contextHeaders: Map[String, String] = Map.empty[String, String]) { + class ServiceRequestContext(val trackingId: String = generators.nextUuid().toString, + val contextHeaders: Map[String, String] = Map.empty[String, String]) { def authToken: Option[AuthToken] = contextHeaders.get(AuthProvider.AuthenticationTokenHeader).map(AuthToken.apply) def permissionsToken: Option[PermissionsToken] = contextHeaders.get(AuthProvider.PermissionsTokenHeader).map(PermissionsToken.apply) - def withAuthenticatedUser[U <: User](authToken: AuthToken, user: U): AuthorizedRequestContext[U] = - new AuthorizedRequestContext(trackingId, - contextHeaders.updated(AuthProvider.AuthenticationTokenHeader, authToken.value), - user) + def withAuthenticatedUser[U <: User](authToken: AuthToken, user: U): AuthorizedServiceRequestContext[U] = + new AuthorizedServiceRequestContext( + trackingId, + contextHeaders.updated(AuthProvider.AuthenticationTokenHeader, authToken.value), + user) override def hashCode(): Int = Seq[Any](trackingId, contextHeaders).foldLeft(31)((result, obj) => 31 * result + obj.hashCode()) override def equals(obj: Any): Boolean = obj match { - case ctx: RequestContext => trackingId == ctx.trackingId && contextHeaders == ctx.contextHeaders - case _ => false + case ctx: ServiceRequestContext => trackingId === ctx.trackingId && contextHeaders === ctx.contextHeaders + case _ => false } override def toString: String = s"RequestContext($trackingId, $contextHeaders)" } - class AuthorizedRequestContext[U <: User](override val trackingId: String = generators.nextUuid().toString, - override val contextHeaders: Map[String, String] = - Map.empty[String, String], - val authenticatedUser: U) - extends RequestContext { + class AuthorizedServiceRequestContext[U <: User](override val trackingId: String = generators.nextUuid().toString, + override val contextHeaders: Map[String, String] = + Map.empty[String, String], + val authenticatedUser: U) + extends ServiceRequestContext { - def withPermissionsToken(permissionsToken: PermissionsToken): AuthorizedRequestContext[U] = - new AuthorizedRequestContext[U]( + def withPermissionsToken(permissionsToken: PermissionsToken): AuthorizedServiceRequestContext[U] = + new AuthorizedServiceRequestContext[U]( trackingId, contextHeaders.updated(AuthProvider.PermissionsTokenHeader, permissionsToken.value), authenticatedUser) @@ -131,8 +132,8 @@ package rest { override def hashCode(): Int = 31 * super.hashCode() + authenticatedUser.hashCode() override def equals(obj: Any): Boolean = obj match { - case ctx: AuthorizedRequestContext[U] => super.equals(ctx) && ctx.authenticatedUser == authenticatedUser - case _ => false + case ctx: AuthorizedServiceRequestContext[U] => super.equals(ctx) && ctx.authenticatedUser == authenticatedUser + case _ => false } override def toString: String = s"AuthenticatedRequestContext($trackingId, $contextHeaders, $authenticatedUser)" @@ -152,24 +153,79 @@ package rest { val SetPermissionsTokenHeader = "set-permissions" } + final case class AuthorizationResult(authorized: Boolean, token: Option[PermissionsToken]) + object AuthorizationResult { + val unauthorized: AuthorizationResult = AuthorizationResult(authorized = false, None) + } + trait Authorization[U <: User] { def userHasPermissions(permissions: Seq[Permission])( - implicit ctx: AuthorizedRequestContext[U]): OptionT[Future, - (Map[Permission, Boolean], PermissionsToken)] + implicit ctx: AuthorizedServiceRequestContext[U]): Future[AuthorizationResult] } - class AlwaysAllowAuthorization[U <: User] extends Authorization[U] { + class AlwaysAllowAuthorization[U <: User](implicit execution: ExecutionContext) extends Authorization[U] { + override def userHasPermissions(permissions: Seq[Permission])( + implicit ctx: AuthorizedServiceRequestContext[U]): Future[AuthorizationResult] = + Future.successful(AuthorizationResult(authorized = true, ctx.permissionsToken)) + } + + class CachedTokenAuthorization[U <: User](publicKey: PublicKey, issuer: String) extends Authorization[U] { + override def userHasPermissions(permissions: Seq[Permission])( + implicit ctx: AuthorizedServiceRequestContext[U]): Future[AuthorizationResult] = { + import spray.json._ + + val result = for { + token <- ctx.permissionsToken + jwt <- Jwt.decode(token.value, publicKey, Seq(JwtAlgorithm.RS256)).toOption + jwtJson = jwt.parseJson.asJsObject + + // Ensure jwt is for the currently authenticated user and the correct issuer, otherwise return None + _ <- jwtJson.fields.get("sub").contains(JsString(ctx.authenticatedUser.id.value)).option(()) + _ <- jwtJson.fields.get("iss").contains(JsString(issuer)).option(()) + + permissionsMap <- jwtJson.fields.get("permissions").collect { + case JsObject(fields) => + fields.collect { + case (key, JsBoolean(value)) => key -> value + } + } + + authorized = permissions.forall(p => permissionsMap.get(p.toString).contains(true)) + } yield AuthorizationResult(authorized, Some(token)) + + Future.successful(result.getOrElse(AuthorizationResult.unauthorized)) + } + } + + class ChainedAuthorization[U <: User](authorizations: Authorization[U]*)(implicit execution: ExecutionContext) + extends Authorization[U] { + override def userHasPermissions(permissions: Seq[Permission])( - implicit ctx: AuthorizedRequestContext[U]): OptionT[Future, - (Map[Permission, Boolean], PermissionsToken)] = - OptionT.optionT(Future.successful(Option((permissions.map(_ -> true).toMap, PermissionsToken(""))))) + implicit ctx: AuthorizedServiceRequestContext[U]): Future[AuthorizationResult] = { + def callAuthorizations( + remainingAuthorizations: List[Authorization[U]] = authorizations.toList): Future[AuthorizationResult] = { + remainingAuthorizations match { + case auth :: Nil => auth.userHasPermissions(permissions) + case auth :: rest => + auth + .userHasPermissions(permissions) + .flatMap( + result => + if (result.authorized) Future.successful(result) + else callAuthorizations(rest)) + case Nil => Future.successful(AuthorizationResult.unauthorized) + } + } + + callAuthorizations() + } + } - abstract class AuthProvider[U <: User](val authorization: Authorization[U], - val permissionsTokenPublicKey: PublicKey, - log: Logger)(implicit execution: ExecutionContext) { + abstract class AuthProvider[U <: User](val authorization: Authorization[U], log: Logger)( + implicit execution: ExecutionContext) { - import akka.http.scaladsl.server.{RequestContext => _, _} + import akka.http.scaladsl.server._ import Directives._ /** @@ -179,20 +235,21 @@ package rest { * @param ctx set of request values which can be relevant to authenticate user * @return authenticated user */ - def authenticatedUser(implicit ctx: RequestContext): OptionT[Future, U] + def authenticatedUser(implicit ctx: ServiceRequestContext): OptionT[Future, U] /** * Verifies if request is authenticated and authorized to have `permissions` */ - def authorize(permissions: Permission*): Directive1[AuthorizedRequestContext[U]] = { + def authorize(permissions: Permission*): Directive1[AuthorizedServiceRequestContext[U]] = { serviceContext flatMap { ctx => onComplete { (for { authToken <- OptionT.optionT(Future.successful(ctx.authToken)) user <- authenticatedUser(ctx) authCtx = ctx.withAuthenticatedUser(authToken, user) - (authorizationResult, permissionsToken) <- userHasPermission(user, permissions)(authCtx) - } yield (authCtx.withPermissionsToken(permissionsToken), authorizationResult)).run + authorizationResult <- authorization.userHasPermissions(permissions)(authCtx).toOptionT + cachedPermissionsAuthCtx = authorizationResult.token.fold(authCtx)(authCtx.withPermissionsToken) + } yield (cachedPermissionsAuthCtx, authorizationResult.authorized)).run } flatMap { case Success(Some((authCtx, true))) => provide(authCtx) case Success(Some((authCtx, false))) => @@ -211,37 +268,6 @@ package rest { } } } - - protected def userHasPermission(user: U, permissions: Seq[Permission])( - ctx: AuthorizedRequestContext[U]): OptionT[Future, (Boolean, PermissionsToken)] = { - import spray.json._ - - def authorizedByToken: OptionT[Future, (Boolean, PermissionsToken)] = { - OptionT.optionT(Future.successful(for { - token <- ctx.permissionsToken - jwt <- Jwt.decode(token.value, permissionsTokenPublicKey, Seq(JwtAlgorithm.RS256)).toOption - jwtJson = jwt.parseJson.asJsObject - - // Ensure jwt is for the currently authenticated user, otherwise return None to call permissions service - _ <- jwtJson.fields.get("sub").contains(JsString(user.id.value)).option(()) - - permissionsMap <- jwtJson.fields.get("permissions").map(_.asJsObject.fields) - - // Ensure all permissions are in the token, otherwise return none to call permissions service - _ <- permissions.forall(p => permissionsMap.contains(p.toString)).option(()) - - authorized = permissions.forall(p => permissionsMap.get(p.toString).contains(JsBoolean(true))) - } yield (authorized, token))) - } - - def authorizedByService: OptionT[Future, (Boolean, PermissionsToken)] = - authorization.userHasPermissions(permissions)(ctx).map { - case (permissionMap, token) => - (permissions.forall(p => permissionMap.getOrElse(p, false)), token) - } - - authorizedByToken.orElse(authorizedByService) - } } trait Service @@ -294,9 +320,9 @@ package rest { trait ServiceTransport { - def sendRequestGetResponse(context: RequestContext)(requestStub: HttpRequest): Future[HttpResponse] + def sendRequestGetResponse(context: ServiceRequestContext)(requestStub: HttpRequest): Future[HttpResponse] - def sendRequest(context: RequestContext)(requestStub: HttpRequest): Future[Unmarshal[ResponseEntity]] + def sendRequest(context: ServiceRequestContext)(requestStub: HttpRequest): Future[Unmarshal[ResponseEntity]] } trait ServiceDiscovery { @@ -313,7 +339,7 @@ package rest { protected implicit val materializer = ActorMaterializer()(actorSystem) protected implicit val execution = executionContext - def sendRequestGetResponse(context: RequestContext)(requestStub: HttpRequest): Future[HttpResponse] = { + def sendRequestGetResponse(context: ServiceRequestContext)(requestStub: HttpRequest): Future[HttpResponse] = { val requestTime = time.currentTime() @@ -341,7 +367,7 @@ package rest { response } - def sendRequest(context: RequestContext)(requestStub: HttpRequest): Future[Unmarshal[ResponseEntity]] = { + def sendRequest(context: ServiceRequestContext)(requestStub: HttpRequest): Future[Unmarshal[ResponseEntity]] = { sendRequestGetResponse(context)(requestStub) map { response => if (response.status == StatusCodes.NotFound) { -- cgit v1.2.3