From 48cb3aa145a883cb7a3bb6d6c8edd23af7dda486 Mon Sep 17 00:00:00 2001 From: Jakob Odersky Date: Wed, 16 Aug 2017 19:02:59 -0700 Subject: Add rejection handling --- .../xyz/driver/pdsuicommon/http/Directives.scala | 49 +++++++++++++++------- 1 file changed, 34 insertions(+), 15 deletions(-) (limited to 'src') diff --git a/src/main/scala/xyz/driver/pdsuicommon/http/Directives.scala b/src/main/scala/xyz/driver/pdsuicommon/http/Directives.scala index feb224a..3f81b8d 100644 --- a/src/main/scala/xyz/driver/pdsuicommon/http/Directives.scala +++ b/src/main/scala/xyz/driver/pdsuicommon/http/Directives.scala @@ -14,6 +14,7 @@ import xyz.driver.pdsuicommon.db.{Pagination, Sorting, SearchFilterExpr} import xyz.driver.pdsuicommon.domain._ import xyz.driver.pdsuicommon.serialization.PlayJsonSupport._ import xyz.driver.core.rest.AuthProvider +import scala.util.control._ import scala.util._ @@ -22,28 +23,32 @@ trait Directives { val paginated: Directive1[Pagination] = parameterSeq.flatMap { params => PaginationParser.parse(params) match { case Success(pagination) => provide(pagination) - case Failure(ex) => failWith(ex) + case Failure(ex) => + reject(new ValidationRejection("invalid pagination parameter", Some(ex))) } } def sorted(validDimensions: Set[String] = Set.empty): Directive1[Sorting] = parameterSeq.flatMap { params => SortingParser.parse(validDimensions, params) match { case Success(sorting) => provide(sorting) - case Failure(ex) => failWith(ex) + case Failure(ex) => + reject(new ValidationRejection("invalid sorting parameter", Some(ex))) } } val dimensioned: Directive1[Dimensions] = parameterSeq.flatMap { params => DimensionsParser.tryParse(params) match { case Success(dims) => provide(dims) - case Failure(ex) => failWith(ex) + case Failure(ex) => + reject(new ValidationRejection("invalid dimension parameter", Some(ex))) } } val searchFiltered: Directive1[SearchFilterExpr] = parameterSeq.flatMap { params => SearchFilterParser.parse(params) match { case Success(sorting) => provide(sorting) - case Failure(ex) => failWith(ex) + case Failure(ex) => + reject(new ValidationRejection("invalid filter parameter", Some(ex))) } } @@ -64,14 +69,29 @@ trait Directives { case other => other } - def domainExceptionHandler(req: RequestId) = { - def errorResponse(err: DomainError) = - ErrorsResponse(Seq(ResponseError(None, err.getMessage, ErrorCode.Unspecified)), req) + def domainExceptionHandler(req: RequestId): ExceptionHandler = { + def errorResponse(ex: Throwable) = + ErrorsResponse(Seq(ResponseError(None, ex.getMessage, ErrorCode.Unspecified)), req) ExceptionHandler { - case err: AuthenticationError => complete(StatusCodes.Unauthorized -> errorResponse(err)) - case err: AuthorizationError => complete(StatusCodes.Forbidden -> errorResponse(err)) - case err: NotFoundError => complete(StatusCodes.NotFound -> errorResponse(err)) - case err: DomainError => complete(StatusCodes.BadRequest -> errorResponse(err)) + case ex: AuthenticationException => complete(StatusCodes.Unauthorized -> errorResponse(ex)) + case ex: AuthorizationException => complete(StatusCodes.Forbidden -> errorResponse(ex)) + case ex: NotFoundException => complete(StatusCodes.NotFound -> errorResponse(ex)) + case ex: DomainException => complete(StatusCodes.BadRequest -> errorResponse(ex)) + case NonFatal(ex) => complete(StatusCodes.InternalServerError -> errorResponse(ex)) + } + } + + def domainRejectionHandler(req: RequestId): RejectionHandler = { + def wrapContent(message: String) = { + import play.api.libs.json._ + val err = ErrorsResponse(Seq(ResponseError(None, message, ErrorCode.Unspecified)), req) + val text = Json.stringify(implicitly[Writes[ErrorsResponse]].writes(err)) + HttpEntity(ContentTypes.`application/json`, text) + } + RejectionHandler.default.mapRejectionResponse { + case res @ HttpResponse(_, _, ent: HttpEntity.Strict, _) => + res.copy(entity = wrapContent(ent.data.utf8String)) + case x => x // pass through all other types of responses } } @@ -80,13 +100,12 @@ trait Directives { case None => provide(RequestId()) } - val handleDomainExceptions: Directive0 = tracked.flatMap { - case id => - handleExceptions(domainExceptionHandler(id)) + val domainResponse: Directive0 = tracked.flatMap { id => + handleExceptions(domainExceptionHandler(id)) & handleRejections(domainRejectionHandler(id)) } implicit class AuthProviderWrapper(provider: AuthProvider[AuthUserInfo]) { - val authenticate: Directive1[AuthenticatedRequestContext] = (provider.authorize() & tracked) tflatMap { + val authenticated: Directive1[AuthenticatedRequestContext] = (provider.authorize() & tracked) tflatMap { case (core, requestId) => provide( new AuthenticatedRequestContext( -- cgit v1.2.3