From 0e1d65445524b9819f701b67e11bebd03121964c Mon Sep 17 00:00:00 2001 From: Zach Smith Date: Thu, 11 May 2017 16:13:10 -0700 Subject: Permissions token refactors --- src/main/scala/xyz/driver/core/rest.scala | 166 +++++++++++++++----------- src/test/scala/xyz/driver/core/AuthTest.scala | 25 ++-- 2 files changed, 110 insertions(+), 81 deletions(-) (limited to 'src') diff --git a/src/main/scala/xyz/driver/core/rest.scala b/src/main/scala/xyz/driver/core/rest.scala index bacb120..3f61246 100644 --- a/src/main/scala/xyz/driver/core/rest.scala +++ b/src/main/scala/xyz/driver/core/rest.scala @@ -27,13 +27,13 @@ import scalaz.{ListT, OptionT} package rest { object `package` { - import akka.http.scaladsl.server.{RequestContext => _, _} + import akka.http.scaladsl.server._ import Directives._ - def serviceContext: Directive1[RequestContext] = extract(ctx => extractServiceContext(ctx.request)) + def serviceContext: Directive1[ServiceRequestContext] = extract(ctx => extractServiceContext(ctx.request)) - def extractServiceContext(request: HttpRequest): RequestContext = - new RequestContext(extractTrackingId(request), extractContextHeaders(request)) + def extractServiceContext(request: HttpRequest): ServiceRequestContext = + new ServiceRequestContext(extractTrackingId(request), extractContextHeaders(request)) def extractTrackingId(request: HttpRequest): String = { request.headers @@ -92,38 +92,39 @@ package rest { } } - class RequestContext(val trackingId: String = generators.nextUuid().toString, - val contextHeaders: Map[String, String] = Map.empty[String, String]) { + class ServiceRequestContext(val trackingId: String = generators.nextUuid().toString, + val contextHeaders: Map[String, String] = Map.empty[String, String]) { def authToken: Option[AuthToken] = contextHeaders.get(AuthProvider.AuthenticationTokenHeader).map(AuthToken.apply) def permissionsToken: Option[PermissionsToken] = contextHeaders.get(AuthProvider.PermissionsTokenHeader).map(PermissionsToken.apply) - def withAuthenticatedUser[U <: User](authToken: AuthToken, user: U): AuthorizedRequestContext[U] = - new AuthorizedRequestContext(trackingId, - contextHeaders.updated(AuthProvider.AuthenticationTokenHeader, authToken.value), - user) + def withAuthenticatedUser[U <: User](authToken: AuthToken, user: U): AuthorizedServiceRequestContext[U] = + new AuthorizedServiceRequestContext( + trackingId, + contextHeaders.updated(AuthProvider.AuthenticationTokenHeader, authToken.value), + user) override def hashCode(): Int = Seq[Any](trackingId, contextHeaders).foldLeft(31)((result, obj) => 31 * result + obj.hashCode()) override def equals(obj: Any): Boolean = obj match { - case ctx: RequestContext => trackingId == ctx.trackingId && contextHeaders == ctx.contextHeaders - case _ => false + case ctx: ServiceRequestContext => trackingId === ctx.trackingId && contextHeaders === ctx.contextHeaders + case _ => false } override def toString: String = s"RequestContext($trackingId, $contextHeaders)" } - class AuthorizedRequestContext[U <: User](override val trackingId: String = generators.nextUuid().toString, - override val contextHeaders: Map[String, String] = - Map.empty[String, String], - val authenticatedUser: U) - extends RequestContext { + class AuthorizedServiceRequestContext[U <: User](override val trackingId: String = generators.nextUuid().toString, + override val contextHeaders: Map[String, String] = + Map.empty[String, String], + val authenticatedUser: U) + extends ServiceRequestContext { - def withPermissionsToken(permissionsToken: PermissionsToken): AuthorizedRequestContext[U] = - new AuthorizedRequestContext[U]( + def withPermissionsToken(permissionsToken: PermissionsToken): AuthorizedServiceRequestContext[U] = + new AuthorizedServiceRequestContext[U]( trackingId, contextHeaders.updated(AuthProvider.PermissionsTokenHeader, permissionsToken.value), authenticatedUser) @@ -131,8 +132,8 @@ package rest { override def hashCode(): Int = 31 * super.hashCode() + authenticatedUser.hashCode() override def equals(obj: Any): Boolean = obj match { - case ctx: AuthorizedRequestContext[U] => super.equals(ctx) && ctx.authenticatedUser == authenticatedUser - case _ => false + case ctx: AuthorizedServiceRequestContext[U] => super.equals(ctx) && ctx.authenticatedUser == authenticatedUser + case _ => false } override def toString: String = s"AuthenticatedRequestContext($trackingId, $contextHeaders, $authenticatedUser)" @@ -152,24 +153,79 @@ package rest { val SetPermissionsTokenHeader = "set-permissions" } + final case class AuthorizationResult(authorized: Boolean, token: Option[PermissionsToken]) + object AuthorizationResult { + val unauthorized: AuthorizationResult = AuthorizationResult(authorized = false, None) + } + trait Authorization[U <: User] { def userHasPermissions(permissions: Seq[Permission])( - implicit ctx: AuthorizedRequestContext[U]): OptionT[Future, - (Map[Permission, Boolean], PermissionsToken)] + implicit ctx: AuthorizedServiceRequestContext[U]): Future[AuthorizationResult] } - class AlwaysAllowAuthorization[U <: User] extends Authorization[U] { + class AlwaysAllowAuthorization[U <: User](implicit execution: ExecutionContext) extends Authorization[U] { + override def userHasPermissions(permissions: Seq[Permission])( + implicit ctx: AuthorizedServiceRequestContext[U]): Future[AuthorizationResult] = + Future.successful(AuthorizationResult(authorized = true, ctx.permissionsToken)) + } + + class CachedTokenAuthorization[U <: User](publicKey: PublicKey, issuer: String) extends Authorization[U] { + override def userHasPermissions(permissions: Seq[Permission])( + implicit ctx: AuthorizedServiceRequestContext[U]): Future[AuthorizationResult] = { + import spray.json._ + + val result = for { + token <- ctx.permissionsToken + jwt <- Jwt.decode(token.value, publicKey, Seq(JwtAlgorithm.RS256)).toOption + jwtJson = jwt.parseJson.asJsObject + + // Ensure jwt is for the currently authenticated user and the correct issuer, otherwise return None + _ <- jwtJson.fields.get("sub").contains(JsString(ctx.authenticatedUser.id.value)).option(()) + _ <- jwtJson.fields.get("iss").contains(JsString(issuer)).option(()) + + permissionsMap <- jwtJson.fields.get("permissions").collect { + case JsObject(fields) => + fields.collect { + case (key, JsBoolean(value)) => key -> value + } + } + + authorized = permissions.forall(p => permissionsMap.get(p.toString).contains(true)) + } yield AuthorizationResult(authorized, Some(token)) + + Future.successful(result.getOrElse(AuthorizationResult.unauthorized)) + } + } + + class ChainedAuthorization[U <: User](authorizations: Authorization[U]*)(implicit execution: ExecutionContext) + extends Authorization[U] { + override def userHasPermissions(permissions: Seq[Permission])( - implicit ctx: AuthorizedRequestContext[U]): OptionT[Future, - (Map[Permission, Boolean], PermissionsToken)] = - OptionT.optionT(Future.successful(Option((permissions.map(_ -> true).toMap, PermissionsToken(""))))) + implicit ctx: AuthorizedServiceRequestContext[U]): Future[AuthorizationResult] = { + def callAuthorizations( + remainingAuthorizations: List[Authorization[U]] = authorizations.toList): Future[AuthorizationResult] = { + remainingAuthorizations match { + case auth :: Nil => auth.userHasPermissions(permissions) + case auth :: rest => + auth + .userHasPermissions(permissions) + .flatMap( + result => + if (result.authorized) Future.successful(result) + else callAuthorizations(rest)) + case Nil => Future.successful(AuthorizationResult.unauthorized) + } + } + + callAuthorizations() + } + } - abstract class AuthProvider[U <: User](val authorization: Authorization[U], - val permissionsTokenPublicKey: PublicKey, - log: Logger)(implicit execution: ExecutionContext) { + abstract class AuthProvider[U <: User](val authorization: Authorization[U], log: Logger)( + implicit execution: ExecutionContext) { - import akka.http.scaladsl.server.{RequestContext => _, _} + import akka.http.scaladsl.server._ import Directives._ /** @@ -179,20 +235,21 @@ package rest { * @param ctx set of request values which can be relevant to authenticate user * @return authenticated user */ - def authenticatedUser(implicit ctx: RequestContext): OptionT[Future, U] + def authenticatedUser(implicit ctx: ServiceRequestContext): OptionT[Future, U] /** * Verifies if request is authenticated and authorized to have `permissions` */ - def authorize(permissions: Permission*): Directive1[AuthorizedRequestContext[U]] = { + def authorize(permissions: Permission*): Directive1[AuthorizedServiceRequestContext[U]] = { serviceContext flatMap { ctx => onComplete { (for { authToken <- OptionT.optionT(Future.successful(ctx.authToken)) user <- authenticatedUser(ctx) authCtx = ctx.withAuthenticatedUser(authToken, user) - (authorizationResult, permissionsToken) <- userHasPermission(user, permissions)(authCtx) - } yield (authCtx.withPermissionsToken(permissionsToken), authorizationResult)).run + authorizationResult <- authorization.userHasPermissions(permissions)(authCtx).toOptionT + cachedPermissionsAuthCtx = authorizationResult.token.fold(authCtx)(authCtx.withPermissionsToken) + } yield (cachedPermissionsAuthCtx, authorizationResult.authorized)).run } flatMap { case Success(Some((authCtx, true))) => provide(authCtx) case Success(Some((authCtx, false))) => @@ -211,37 +268,6 @@ package rest { } } } - - protected def userHasPermission(user: U, permissions: Seq[Permission])( - ctx: AuthorizedRequestContext[U]): OptionT[Future, (Boolean, PermissionsToken)] = { - import spray.json._ - - def authorizedByToken: OptionT[Future, (Boolean, PermissionsToken)] = { - OptionT.optionT(Future.successful(for { - token <- ctx.permissionsToken - jwt <- Jwt.decode(token.value, permissionsTokenPublicKey, Seq(JwtAlgorithm.RS256)).toOption - jwtJson = jwt.parseJson.asJsObject - - // Ensure jwt is for the currently authenticated user, otherwise return None to call permissions service - _ <- jwtJson.fields.get("sub").contains(JsString(user.id.value)).option(()) - - permissionsMap <- jwtJson.fields.get("permissions").map(_.asJsObject.fields) - - // Ensure all permissions are in the token, otherwise return none to call permissions service - _ <- permissions.forall(p => permissionsMap.contains(p.toString)).option(()) - - authorized = permissions.forall(p => permissionsMap.get(p.toString).contains(JsBoolean(true))) - } yield (authorized, token))) - } - - def authorizedByService: OptionT[Future, (Boolean, PermissionsToken)] = - authorization.userHasPermissions(permissions)(ctx).map { - case (permissionMap, token) => - (permissions.forall(p => permissionMap.getOrElse(p, false)), token) - } - - authorizedByToken.orElse(authorizedByService) - } } trait Service @@ -294,9 +320,9 @@ package rest { trait ServiceTransport { - def sendRequestGetResponse(context: RequestContext)(requestStub: HttpRequest): Future[HttpResponse] + def sendRequestGetResponse(context: ServiceRequestContext)(requestStub: HttpRequest): Future[HttpResponse] - def sendRequest(context: RequestContext)(requestStub: HttpRequest): Future[Unmarshal[ResponseEntity]] + def sendRequest(context: ServiceRequestContext)(requestStub: HttpRequest): Future[Unmarshal[ResponseEntity]] } trait ServiceDiscovery { @@ -313,7 +339,7 @@ package rest { protected implicit val materializer = ActorMaterializer()(actorSystem) protected implicit val execution = executionContext - def sendRequestGetResponse(context: RequestContext)(requestStub: HttpRequest): Future[HttpResponse] = { + def sendRequestGetResponse(context: ServiceRequestContext)(requestStub: HttpRequest): Future[HttpResponse] = { val requestTime = time.currentTime() @@ -341,7 +367,7 @@ package rest { response } - def sendRequest(context: RequestContext)(requestStub: HttpRequest): Future[Unmarshal[ResponseEntity]] = { + def sendRequest(context: ServiceRequestContext)(requestStub: HttpRequest): Future[Unmarshal[ResponseEntity]] = { sendRequestGetResponse(context)(requestStub) map { response => if (response.status == StatusCodes.NotFound) { diff --git a/src/test/scala/xyz/driver/core/AuthTest.scala b/src/test/scala/xyz/driver/core/AuthTest.scala index 9c86577..8de0e87 100644 --- a/src/test/scala/xyz/driver/core/AuthTest.scala +++ b/src/test/scala/xyz/driver/core/AuthTest.scala @@ -3,14 +3,14 @@ package xyz.driver.core import akka.http.scaladsl.model.headers.{HttpChallenges, RawHeader} import akka.http.scaladsl.server.AuthenticationFailedRejection.CredentialsRejected import akka.http.scaladsl.server.Directives._ -import akka.http.scaladsl.server.{RequestContext => _, _} +import akka.http.scaladsl.server._ import akka.http.scaladsl.testkit.ScalatestRouteTest import org.scalatest.mock.MockitoSugar import org.scalatest.{FlatSpec, Matchers} import pdi.jwt.{Jwt, JwtAlgorithm} import xyz.driver.core.auth._ import xyz.driver.core.logging._ -import xyz.driver.core.rest.{AuthProvider, AuthorizedRequestContext, Authorization, RequestContext} +import xyz.driver.core.rest._ import scala.concurrent.Future import scalaz.OptionT @@ -33,19 +33,22 @@ class AuthTest extends FlatSpec with Matchers with MockitoSugar with ScalatestRo (keyPair.getPublic, keyPair.getPrivate) } - val authorization: Authorization[User] = new Authorization[User] { + val basicAuthorization: Authorization[User] = new Authorization[User] { override def userHasPermissions(permissions: Seq[Permission])( - implicit ctx: AuthorizedRequestContext[User]): OptionT[Future, - (Map[Permission, Boolean], PermissionsToken)] = { - val permissionsMap = permissions.map(p => p -> (p === TestRoleAllowedPermission)).toMap - val token = PermissionsToken("TODO") - OptionT.optionT(Future.successful(Option((permissionsMap, token)))) + implicit ctx: AuthorizedServiceRequestContext[User]): Future[AuthorizationResult] = { + val authorized = permissions.forall(_ === TestRoleAllowedPermission) + Future.successful(AuthorizationResult(authorized, ctx.permissionsToken)) } } - val authStatusService = new AuthProvider[User](authorization, publicKey, NoLogger) { - override def authenticatedUser(implicit ctx: RequestContext): OptionT[Future, User] = + val tokenIssuer = "users" + val tokenAuthorization = new CachedTokenAuthorization[User](publicKey, tokenIssuer) + + val authorization = new ChainedAuthorization[User](tokenAuthorization, basicAuthorization) + + val authStatusService = new AuthProvider[User](authorization, NoLogger) { + override def authenticatedUser(implicit ctx: ServiceRequestContext): OptionT[Future, User] = OptionT.optionT[Future] { if (ctx.contextHeaders.keySet.contains(AuthProvider.AuthenticationTokenHeader)) { Future.successful(Some(BasicUser(Id[User]("1"), Set(TestRole)))) @@ -109,7 +112,7 @@ class AuthTest extends FlatSpec with Matchers with MockitoSugar with ScalatestRo val claim = JsObject( Map( - "iss" -> JsString("users"), + "iss" -> JsString(tokenIssuer), "sub" -> JsString("1"), "permissions" -> JsObject(Map(TestRoleAllowedByTokenPermission.toString -> JsBoolean(true))) )).prettyPrint -- cgit v1.2.3