From 14c6ae3bcdc1560e91d0443ede592bf0ae876674 Mon Sep 17 00:00:00 2001 From: Jakob Odersky Date: Mon, 7 Aug 2017 19:26:54 -0700 Subject: Exception-based error handling --- .../xyz/driver/pdsuicommon/http/Directives.scala | 92 +++++++++++----------- .../services/rest/RestTrialService.scala | 2 +- 2 files changed, 45 insertions(+), 49 deletions(-) diff --git a/src/main/scala/xyz/driver/pdsuicommon/http/Directives.scala b/src/main/scala/xyz/driver/pdsuicommon/http/Directives.scala index 7a6266f..a9d7b38 100644 --- a/src/main/scala/xyz/driver/pdsuicommon/http/Directives.scala +++ b/src/main/scala/xyz/driver/pdsuicommon/http/Directives.scala @@ -1,11 +1,8 @@ package xyz.driver.pdsuicommon.http -import akka.http.scaladsl.marshalling._ -import akka.http.scaladsl.server.Directive1 +import akka.http.scaladsl.server._ import akka.http.scaladsl.server.Directives._ -import akka.http.scaladsl.server.{PathMatcher1, PathMatchers} import akka.http.scaladsl.model._ -import xyz.driver.core.rest.AuthorizedServiceRequestContext import xyz.driver.core.rest.ContextHeaders import xyz.driver.entities.users.UserInfo import xyz.driver.pdsuicommon.auth._ @@ -15,9 +12,10 @@ import xyz.driver.pdsuicommon.error.ErrorsResponse.ResponseError import xyz.driver.pdsuicommon.parsers._ import xyz.driver.pdsuicommon.db.{Pagination, Sorting, SearchFilterExpr} import xyz.driver.pdsuicommon.domain._ +import xyz.driver.pdsuicommon.serialization.PlayJsonSupport._ +import xyz.driver.core.rest.AuthProvider import scala.util._ -import scala.concurrent._ trait Directives { @@ -25,13 +23,20 @@ trait Directives { case (size, number) => Pagination(size, number) } - def sorted(validDimensions: Set[String]): Directive1[Sorting] = parameterSeq.flatMap { params => + def sorted(validDimensions: Set[String] = Set.empty): Directive1[Sorting] = parameterSeq.flatMap { params => SortingParser.parse(validDimensions, params) match { case Success(sorting) => provide(sorting) case Failure(ex) => failWith(ex) } } + val dimensioned: Directive1[Dimensions] = parameterSeq.flatMap { params => + DimensionsParser.tryParse(params) match { + case Success(dims) => provide(dims) + case Failure(ex) => failWith(ex) + } + } + val searchFiltered: Directive1[SearchFilterExpr] = parameterSeq.flatMap { params => SearchFilterParser.parse(params) match { case Success(sorting) => provide(sorting) @@ -48,54 +53,45 @@ trait Directives { def UuidIdInPath[T]: PathMatcher1[UuidId[T]] = PathMatchers.JavaUUID.map((id) => UuidId(id)) - @annotation.implicitNotFound("An ApiExtractor of ${Reply} to ${Api} is required to complete service replies.") - trait ApiExtractor[Reply, Api] extends PartialFunction[Reply, Api] - object ApiExtractor { - // Note: make sure the Reply here is the most common response - // type. The specific entity type should be handled in the partial - // function. E.g. `apply[GetByIdReply, Api]{case - // GetByIdReply.Entity => Api}` - def apply[Reply, Api](pf: PartialFunction[Reply, Api]): ApiExtractor[Reply, Api] = new ApiExtractor[Reply, Api] { - override def isDefinedAt(x: Reply) = pf.isDefinedAt(x) - override def apply(x: Reply) = pf.apply(x) - } + def failFast[A](reply: A): A = reply match { + case err: NotFoundError => throw new NotFoundException(err.getMessage) + case err: AuthenticationError => throw new AuthenticationException(err.getMessage) + case err: AuthorizationError => throw new AuthorizationException(err.getMessage) + case err: DomainError => throw new DomainException(err.getMessage) + case other => other } - implicit def replyMarshaller[Reply, Api]( - implicit ctx: AuthenticatedRequestContext, - apiExtractor: ApiExtractor[Reply, Api], - apiMarshaller: ToEntityMarshaller[Api], - errorMarshaller: ToEntityMarshaller[ErrorsResponse] - ): ToResponseMarshaller[Reply] = { - + def domainExceptionHandler(req: RequestId) = { def errorResponse(err: DomainError) = - ErrorsResponse(Seq(ResponseError(None, err.getMessage, ErrorCode.Unspecified)), ctx.requestId) - - Marshaller[Reply, HttpResponse] { (executionContext: ExecutionContext) => (reply: Reply) => - implicit val ec = executionContext - reply match { - case apiReply if apiExtractor.isDefinedAt(apiReply) => - Marshaller.fromToEntityMarshaller[Api](StatusCodes.OK).apply(apiExtractor(apiReply)) - case err: NotFoundError => - Marshaller.fromToEntityMarshaller[ErrorsResponse](StatusCodes.Unauthorized).apply(errorResponse(err)) - case err: AuthorizationError => - Marshaller.fromToEntityMarshaller[ErrorsResponse](StatusCodes.Forbidden).apply(errorResponse(err)) - case err: DomainError => - Marshaller.fromToEntityMarshaller[ErrorsResponse](StatusCodes.BadRequest).apply(errorResponse(err)) - case other => - val msg = s"Got unexpected response type in completion directive: ${other.getClass.getSimpleName}" - val res = ErrorsResponse(Seq(ResponseError(None, msg, ErrorCode.Unspecified)), ctx.requestId) - Marshaller.fromToEntityMarshaller[ErrorsResponse](StatusCodes.InternalServerError).apply(res) - } + ErrorsResponse(Seq(ResponseError(None, err.getMessage, ErrorCode.Unspecified)), req) + ExceptionHandler { + case err: AuthenticationError => complete(StatusCodes.Unauthorized -> errorResponse(err)) + case err: AuthorizationError => complete(StatusCodes.Forbidden -> errorResponse(err)) + case err: NotFoundError => complete(StatusCodes.NotFound -> errorResponse(err)) + case err: DomainError => complete(StatusCodes.BadRequest -> errorResponse(err)) } } - implicit class PdsContext(core: AuthorizedServiceRequestContext[UserInfo]) { - def authenticated = new AuthenticatedRequestContext( - core.authenticatedUser, - RequestId(), - core.contextHeaders(ContextHeaders.AuthenticationTokenHeader) - ) + val tracked: Directive1[RequestId] = optionalHeaderValueByName(ContextHeaders.TrackingIdHeader) flatMap { + case Some(id) => provide(RequestId(id)) + case None => provide(RequestId()) + } + + val handleDomainExceptions: Directive0 = tracked.flatMap { + case id => + handleExceptions(domainExceptionHandler(id)) + } + + implicit class AuthProviderWrapper(provider: AuthProvider[UserInfo]) { + val authenticate: Directive1[AuthenticatedRequestContext] = (provider.authorize() & tracked) tflatMap { + case (core, requestId) => + provide( + new AuthenticatedRequestContext( + core.authenticatedUser, + requestId, + core.contextHeaders(ContextHeaders.AuthenticationTokenHeader) + )) + } } } diff --git a/src/main/scala/xyz/driver/pdsuidomain/services/rest/RestTrialService.scala b/src/main/scala/xyz/driver/pdsuidomain/services/rest/RestTrialService.scala index a68cb52..f826b98 100644 --- a/src/main/scala/xyz/driver/pdsuidomain/services/rest/RestTrialService.scala +++ b/src/main/scala/xyz/driver/pdsuidomain/services/rest/RestTrialService.scala @@ -46,7 +46,7 @@ class RestTrialService(transport: ServiceTransport, baseUri: Uri)(implicit prote } def getPdfSource(trialId: StringId[Trial])( - implicit requestContext: AuthenticatedRequestContext): Future[Source[ByteString, NotUsed]] = { + implicit requestContext: AuthenticatedRequestContext): Future[Source[ByteString, NotUsed]] = { val request = HttpRequest(HttpMethods.GET, endpointUri(baseUri, s"/v1/trial/${trialId}/source")) for { response <- transport.sendRequestGetResponse(requestContext)(request) -- cgit v1.2.3