From 8c60e3f6b22e4ee94c5cf7a0cb1f36e1266269de Mon Sep 17 00:00:00 2001 From: Jakob Odersky Date: Fri, 4 Aug 2017 16:32:40 -0700 Subject: Auth directives and formatting --- .../xyz/driver/pdsuicommon/http/Directives.scala | 63 +++++++++++----------- .../pdsuicommon/parsers/DimensionsParser.scala | 5 +- .../pdsuicommon/parsers/ListRequestParser.scala | 4 +- .../pdsuicommon/parsers/PagiationParser.scala | 11 ++-- .../pdsuicommon/parsers/SearchFilterParser.scala | 5 +- .../driver/pdsuicommon/parsers/SortingParser.scala | 5 +- 6 files changed, 50 insertions(+), 43 deletions(-) diff --git a/src/main/scala/xyz/driver/pdsuicommon/http/Directives.scala b/src/main/scala/xyz/driver/pdsuicommon/http/Directives.scala index d5c6365..0fe1f73 100644 --- a/src/main/scala/xyz/driver/pdsuicommon/http/Directives.scala +++ b/src/main/scala/xyz/driver/pdsuicommon/http/Directives.scala @@ -4,13 +4,19 @@ import akka.http.scaladsl.marshalling._ import akka.http.scaladsl.server.Directive1 import akka.http.scaladsl.server.Directives._ import akka.http.scaladsl.server.Route +import akka.http.scaladsl.model._ +import xyz.driver.core.rest.AuthorizedServiceRequestContext +import xyz.driver.core.rest.ContextHeaders +import xyz.driver.entities.users.UserInfo import xyz.driver.pdsuicommon.auth._ import xyz.driver.pdsuicommon.error._ import xyz.driver.pdsuicommon.error.DomainError._ import xyz.driver.pdsuicommon.error.ErrorsResponse.ResponseError import xyz.driver.pdsuicommon.parsers._ import xyz.driver.pdsuicommon.db.{Pagination, Sorting, SearchFilterExpr} + import scala.util._ +import scala.concurrent._ trait Directives { @@ -32,7 +38,7 @@ trait Directives { } } - @annotation.implicitNotFound("An ApiExtractor is required to complete service replies.") + @annotation.implicitNotFound("An ApiExtractor of ${Reply} to ${Api} is required to complete service replies.") trait ApiExtractor[Reply, Api] extends PartialFunction[Reply, Api] object ApiExtractor { // Note: make sure the Reply here is the most common response @@ -45,45 +51,42 @@ trait Directives { } } - def completeService[Reply, Api](reply: => Reply)(implicit requestId: RequestId, - apiExtractor: ApiExtractor[Reply, Api], - apiMarshaller: ToEntityMarshaller[Api], - errorMarshaller: ToEntityMarshaller[ErrorsResponse]): Route = { + implicit def replyMarshaller[Reply, Api]( + implicit ctx: AuthenticatedRequestContext, + apiExtractor: ApiExtractor[Reply, Api], + apiMarshaller: ToEntityMarshaller[Api], + errorMarshaller: ToEntityMarshaller[ErrorsResponse] + ): ToResponseMarshaller[Reply] = { def errorResponse(err: DomainError) = - ErrorsResponse(Seq(ResponseError(None, err.getMessage, ErrorCode.Unspecified)), requestId) + ErrorsResponse(Seq(ResponseError(None, err.getMessage, ErrorCode.Unspecified)), ctx.requestId) - // TODO: rather than completing the bad requests here, we should - // consider throwing a corresponding exception and then handling - // it in an error handler - reply match { - case apiReply if apiExtractor.isDefinedAt(apiReply) => - complete(apiExtractor(reply)) - case err: NotFoundError => - complete(401 -> errorResponse(err)) - case err: AuthenticationError => - complete(401 -> errorResponse(err)) - case err: AuthorizationError => - complete(403 -> errorResponse(err)) - case err: DomainError => - complete(400 -> errorResponse(err)) - case other => - val msg = s"Got unexpected response type in completion directive: ${other.getClass.getSimpleName}" - val res = ErrorsResponse(Seq(ResponseError(None, msg, ErrorCode.Unspecified)), requestId) - complete(500 -> res) + Marshaller[Reply, HttpResponse] { (executionContext: ExecutionContext) => (reply: Reply) => + implicit val ec = executionContext + reply match { + case apiReply if apiExtractor.isDefinedAt(apiReply) => + Marshaller.fromToEntityMarshaller[Api](StatusCodes.OK).apply(apiExtractor(apiReply)) + case err: NotFoundError => + Marshaller.fromToEntityMarshaller[ErrorsResponse](StatusCodes.Unauthorized).apply(errorResponse(err)) + case err: AuthorizationError => + Marshaller.fromToEntityMarshaller[ErrorsResponse](StatusCodes.Forbidden).apply(errorResponse(err)) + case err: DomainError => + Marshaller.fromToEntityMarshaller[ErrorsResponse](StatusCodes.BadRequest).apply(errorResponse(err)) + case other => + val msg = s"Got unexpected response type in completion directive: ${other.getClass.getSimpleName}" + val res = ErrorsResponse(Seq(ResponseError(None, msg, ErrorCode.Unspecified)), ctx.requestId) + Marshaller.fromToEntityMarshaller[ErrorsResponse](StatusCodes.InternalServerError).apply(res) + } } } - import xyz.driver.core.rest.AuthorizedServiceRequestContext - import xyz.driver.core.rest.ContextHeaders - import xyz.driver.entities.users.UserInfo - - implicit def authContext(core: AuthorizedServiceRequestContext[UserInfo]): AuthenticatedRequestContext = - new AuthenticatedRequestContext( + implicit class PdsContext(core: AuthorizedServiceRequestContext[UserInfo]) { + def authenticated = new AuthenticatedRequestContext( core.authenticatedUser, RequestId(), core.contextHeaders(ContextHeaders.AuthenticationTokenHeader) ) + } } diff --git a/src/main/scala/xyz/driver/pdsuicommon/parsers/DimensionsParser.scala b/src/main/scala/xyz/driver/pdsuicommon/parsers/DimensionsParser.scala index 29f2363..17c09ed 100644 --- a/src/main/scala/xyz/driver/pdsuicommon/parsers/DimensionsParser.scala +++ b/src/main/scala/xyz/driver/pdsuicommon/parsers/DimensionsParser.scala @@ -10,8 +10,9 @@ object DimensionsParser { @deprecated("play-akka transition", "0") def tryParse(query: Map[String, Seq[String]]): Try[Dimensions] = - tryParse(query.toSeq.flatMap{ case (key, values) => - values.map(value => key -> value) + tryParse(query.toSeq.flatMap { + case (key, values) => + values.map(value => key -> value) }) def tryParse(query: Seq[(String, String)]): Try[Dimensions] = { diff --git a/src/main/scala/xyz/driver/pdsuicommon/parsers/ListRequestParser.scala b/src/main/scala/xyz/driver/pdsuicommon/parsers/ListRequestParser.scala index 0356784..c3146ce 100644 --- a/src/main/scala/xyz/driver/pdsuicommon/parsers/ListRequestParser.scala +++ b/src/main/scala/xyz/driver/pdsuicommon/parsers/ListRequestParser.scala @@ -12,8 +12,8 @@ class ListRequestParser(validSortingFields: Set[String]) { def tryParse(request: Request[AnyContent]): Try[ListRequestParameters] = { for { queryFilters <- SearchFilterParser.parse(request.queryString) - sorting <- SortingParser.parse(validSortingFields, request.queryString) - pagination <- PaginationParser.parse(request.queryString) + sorting <- SortingParser.parse(validSortingFields, request.queryString) + pagination <- PaginationParser.parse(request.queryString) } yield ListRequestParameters(queryFilters, sorting, pagination) } diff --git a/src/main/scala/xyz/driver/pdsuicommon/parsers/PagiationParser.scala b/src/main/scala/xyz/driver/pdsuicommon/parsers/PagiationParser.scala index 3381542..3988668 100644 --- a/src/main/scala/xyz/driver/pdsuicommon/parsers/PagiationParser.scala +++ b/src/main/scala/xyz/driver/pdsuicommon/parsers/PagiationParser.scala @@ -7,16 +7,17 @@ object PaginationParser { @deprecated("play-akka transition", "0") def parse(query: Map[String, Seq[String]]): Try[Pagination] = - parse(query.toSeq.flatMap{ case (key, values) => - values.map(value => key -> value) + parse(query.toSeq.flatMap { + case (key, values) => + values.map(value => key -> value) }) def parse(query: Seq[(String, String)]): Try[Pagination] = { val IntString = """\d+""".r - def validate(field: String) = query.collectFirst{case (`field`, size) => size} match { + def validate(field: String) = query.collectFirst { case (`field`, size) => size } match { case Some(IntString(x)) => x.toInt - case Some(str) => throw new ParseQueryArgException((field, s"must be an integer (found $str)")) - case None => throw new ParseQueryArgException((field, "must be defined")) + case Some(str) => throw new ParseQueryArgException((field, s"must be an integer (found $str)")) + case None => throw new ParseQueryArgException((field, "must be defined")) } Try { diff --git a/src/main/scala/xyz/driver/pdsuicommon/parsers/SearchFilterParser.scala b/src/main/scala/xyz/driver/pdsuicommon/parsers/SearchFilterParser.scala index 061f2ef..768e5f5 100644 --- a/src/main/scala/xyz/driver/pdsuicommon/parsers/SearchFilterParser.scala +++ b/src/main/scala/xyz/driver/pdsuicommon/parsers/SearchFilterParser.scala @@ -110,8 +110,9 @@ object SearchFilterParser { @deprecated("play-akka transition", "0") def parse(query: Map[String, Seq[String]]): Try[SearchFilterExpr] = - parse(query.toSeq.flatMap{ case (key, values) => - values.map(value => key -> value) + parse(query.toSeq.flatMap { + case (key, values) => + values.map(value => key -> value) }) def parse(query: Seq[(String, String)]): Try[SearchFilterExpr] = Try { diff --git a/src/main/scala/xyz/driver/pdsuicommon/parsers/SortingParser.scala b/src/main/scala/xyz/driver/pdsuicommon/parsers/SortingParser.scala index cc6ade3..c1c332f 100644 --- a/src/main/scala/xyz/driver/pdsuicommon/parsers/SortingParser.scala +++ b/src/main/scala/xyz/driver/pdsuicommon/parsers/SortingParser.scala @@ -32,8 +32,9 @@ object SortingParser { @deprecated("play-akka transition", "0") def parse(validDimensions: Set[String], query: Map[String, Seq[String]]): Try[Sorting] = - parse(validDimensions, query.toSeq.flatMap{ case (key, values) => - values.map(value => key -> value) + parse(validDimensions, query.toSeq.flatMap { + case (key, values) => + values.map(value => key -> value) }) def parse(validDimensions: Set[String], query: Seq[(String, String)]): Try[Sorting] = Try { -- cgit v1.2.3