From e3268b87bc9446e69b59ed5f3990f42c8a00d918 Mon Sep 17 00:00:00 2001 From: vlad Date: Wed, 2 Nov 2016 13:59:36 -0700 Subject: DIR-135 Directive for more effortless context extraction --- src/main/scala/xyz/driver/core/auth.scala | 12 ++++++------ src/main/scala/xyz/driver/core/rest.scala | 16 +++++++++------- src/test/scala/xyz/driver/core/AuthTest.scala | 19 ++++++++++++------- 3 files changed, 27 insertions(+), 20 deletions(-) diff --git a/src/main/scala/xyz/driver/core/auth.scala b/src/main/scala/xyz/driver/core/auth.scala index 3dd21d9..e4d726b 100644 --- a/src/main/scala/xyz/driver/core/auth.scala +++ b/src/main/scala/xyz/driver/core/auth.scala @@ -2,6 +2,7 @@ package xyz.driver.core import akka.http.scaladsl.model.headers.HttpChallenges import akka.http.scaladsl.server.AuthenticationFailedRejection.CredentialsRejected +import xyz.driver.core.rest.ServiceRequestContext import scala.concurrent.Future import scala.util.{Failure, Success} @@ -73,7 +74,8 @@ object auth { final case class PasswordHash(value: String) object AuthService { - val AuthenticationTokenHeader = "WWW-Authenticate" + val AuthenticationTokenHeader = rest.ContextHeaders.AuthenticationTokenHeader + val SetAuthenticationTokenHeader = "set-authorization" } trait AuthService[U <: User] { @@ -81,13 +83,11 @@ object auth { import akka.http.scaladsl.server._ import Directives._ - protected def authStatus(authToken: AuthToken): OptionT[Future, U] + protected def authStatus(context: ServiceRequestContext): OptionT[Future, U] def authorize(permissions: Permission*): Directive1[U] = { - headerValueByName(AuthService.AuthenticationTokenHeader).flatMap { tokenValue => - val token = AuthToken(tokenValue) - - onComplete(authStatus(token).run).flatMap { + rest.serviceContext flatMap { ctx => + onComplete(authStatus(ctx).run).flatMap { case Success(Some(user)) => if (permissions.forall(user.permissions.contains)) provide(user) else { diff --git a/src/main/scala/xyz/driver/core/rest.scala b/src/main/scala/xyz/driver/core/rest.scala index c52d9e0..18dbcf7 100644 --- a/src/main/scala/xyz/driver/core/rest.scala +++ b/src/main/scala/xyz/driver/core/rest.scala @@ -4,13 +4,11 @@ import akka.actor.ActorSystem import akka.http.scaladsl.Http import akka.http.scaladsl.model._ import akka.http.scaladsl.model.headers.RawHeader -import akka.http.scaladsl.server.RequestContext import akka.http.scaladsl.unmarshalling.Unmarshal import akka.stream.ActorMaterializer import com.github.swagger.akka.model._ import com.github.swagger.akka.{HasActorSystem, SwaggerHttpService} import com.typesafe.config.Config -import xyz.driver.core.auth.AuthService import xyz.driver.core.logging.Logger import xyz.driver.core.stats.Stats import xyz.driver.core.time.TimeRange @@ -23,15 +21,19 @@ import scalaz.Scalaz.{Id => _, _} object rest { object ContextHeaders { - val AuthenticationTokenHeader = AuthService.AuthenticationTokenHeader + val AuthenticationTokenHeader = "WWW-Authenticate" val TrackingIdHeader = "l5d-ctx-trace" // https://linkerd.io/doc/0.7.4/linkerd/protocol-http/ } final case class ServiceRequestContext(trackingId: String, contextHeaders: Map[String, String]) - def serviceContext(ctx: RequestContext): ServiceRequestContext = { + import akka.http.scaladsl.server._ + import Directives._ + + def serviceContext: Directive1[ServiceRequestContext] = extract(ctx => extractServiceContext(ctx)) + + def extractServiceContext(ctx: RequestContext): ServiceRequestContext = ServiceRequestContext(extractTrackingId(ctx), extractContextHeaders(ctx)) - } def extractTrackingId(ctx: RequestContext): String = { ctx.request.headers @@ -74,13 +76,13 @@ object rest { .withHeaders(RawHeader(ContextHeaders.TrackingIdHeader, context.trackingId)) .withHeaders(context.contextHeaders.toSeq.map { h => RawHeader(h._1, h._2): HttpHeader }: _*) - log.audit(s"Sending to ${request.uri} request $request") + log.audit(s"Sending to ${request.uri} request $request with tracking id ${context.trackingId}") val responseEntity = Http()(actorSystem).singleRequest(request)(materializer) map { response => if(response.status == StatusCodes.NotFound) { Unmarshal(HttpEntity.Empty: ResponseEntity) } else if(response.status.isFailure()) { - throw new Exception("Http status is failure " + response.status) + throw new Exception(s"Http status is failure ${response.status}") } else { Unmarshal(response.entity) } diff --git a/src/test/scala/xyz/driver/core/AuthTest.scala b/src/test/scala/xyz/driver/core/AuthTest.scala index ca7e019..e5e991b 100644 --- a/src/test/scala/xyz/driver/core/AuthTest.scala +++ b/src/test/scala/xyz/driver/core/AuthTest.scala @@ -8,6 +8,7 @@ import akka.http.scaladsl.server.AuthenticationFailedRejection.CredentialsReject import org.scalatest.mock.MockitoSugar import org.scalatest.{FlatSpec, Matchers} import xyz.driver.core.auth._ +import xyz.driver.core.rest.ServiceRequestContext import scala.concurrent.Future import scalaz.OptionT @@ -15,11 +16,15 @@ import scalaz.OptionT class AuthTest extends FlatSpec with Matchers with MockitoSugar with ScalatestRouteTest { val authStatusService: AuthService[User] = new AuthService[User] { - override def authStatus(authToken: AuthToken): OptionT[Future, User] = OptionT.optionT[Future] { - Future.successful(Some(new User { - override def id: Id[User] = Id[User](1L) - override def roles: Set[Role] = Set(PathologistRole) - }: User)) + override def authStatus(context: ServiceRequestContext): OptionT[Future, User] = OptionT.optionT[Future] { + if (context.contextHeaders.keySet.contains(AuthService.AuthenticationTokenHeader)) { + Future.successful(Some(new User { + override def id: Id[User] = Id[User](1L) + override def roles: Set[Role] = Set(PathologistRole) + }: User)) + } else { + Future.successful(Option.empty[User]) + } } } @@ -33,8 +38,8 @@ class AuthTest extends FlatSpec with Matchers with MockitoSugar with ScalatestRo complete("Never going to be here") } ~> check { - handled shouldBe false - rejections should contain(MissingHeaderRejection("WWW-Authenticate")) + // handled shouldBe false + rejections should contain(ValidationRejection("Wasn't able to find authenticated user for the token provided")) } } -- cgit v1.2.3