From edbfe3d11eefe10f6d45752d1132e7349e1c6750 Mon Sep 17 00:00:00 2001 From: Zach Smith Date: Thu, 28 Sep 2017 10:28:55 -0700 Subject: Add DriverRoute trait and clean up DriverApp --- src/main/scala/xyz/driver/core/app/DriverApp.scala | 156 ++++++--------------- src/main/scala/xyz/driver/core/app/module.scala | 32 +++-- .../scala/xyz/driver/core/rest/DriverRoute.scala | 84 +++++++++++ .../xyz/driver/core/rest/errors/APIError.scala | 16 +++ 4 files changed, 164 insertions(+), 124 deletions(-) create mode 100644 src/main/scala/xyz/driver/core/rest/DriverRoute.scala create mode 100644 src/main/scala/xyz/driver/core/rest/errors/APIError.scala (limited to 'src/main/scala/xyz') diff --git a/src/main/scala/xyz/driver/core/app/DriverApp.scala b/src/main/scala/xyz/driver/core/app/DriverApp.scala index 901d6e2..cb3a38e 100644 --- a/src/main/scala/xyz/driver/core/app/DriverApp.scala +++ b/src/main/scala/xyz/driver/core/app/DriverApp.scala @@ -1,10 +1,8 @@ package xyz.driver.core.app -import java.sql.SQLException - import akka.actor.ActorSystem import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport -import akka.http.scaladsl.model.StatusCodes.{BadRequest, Conflict, InternalServerError, MethodNotAllowed} +import akka.http.scaladsl.model.StatusCodes.MethodNotAllowed import akka.http.scaladsl.model._ import akka.http.scaladsl.model.headers._ import akka.http.scaladsl.server.Directives._ @@ -27,9 +25,8 @@ import xyz.driver.core.time.provider.{SystemTimeProvider, TimeProvider} import xyz.driver.tracing.TracingDirectives.trace import xyz.driver.tracing.{NoTracer, Tracer} -import scala.compat.Platform.ConcurrentModificationException import scala.concurrent.duration._ -import scala.concurrent.{Await, ExecutionContext, Future} +import scala.concurrent.{Await, ExecutionContext} import scala.reflect.runtime.universe._ import scala.util.Try import scala.util.control.NonFatal @@ -74,69 +71,58 @@ class DriverApp(appName: String, private def extractHeader(request: HttpRequest)(headerName: String): Option[String] = request.headers.find(_.name().toLowerCase === headerName).map(_.value()) - protected def bindHttp(modules: Seq[Module]): Unit = { + protected def appRoute: Route = { val serviceTypes = modules.flatMap(_.routeTypes) val swaggerService = swaggerOverride(serviceTypes) val swaggerRoutes = swaggerService.routes ~ swaggerService.swaggerUI val versionRt = versionRoute(version, gitHash, time.currentTime()) - val _ = Future { - http.bindAndHandle( - route2HandlerFlow(extractHost { origin => - trace(tracer) { - extractClientIP { ip => - optionalHeaderValueByType[Origin](()) { - originHeader => - { - ctx => - val trackingId = rest.extractTrackingId(ctx.request) - MDC.put("trackingId", trackingId) - - val updatedStacktrace = - (rest.extractStacktrace(ctx.request) ++ Array(appName)).mkString("->") - MDC.put("stack", updatedStacktrace) - - storeRequestContextToMdc(ctx.request, origin, ip) - - def requestLogging: Future[Unit] = Future { - log.info( - s"""Received request {"method":"${ctx.request.method.value}","url": "${ctx.request.uri}"}""") - } - - val contextWithTrackingId = - ctx.withRequest( - ctx.request - .addHeader(RawHeader(ContextHeaders.TrackingIdHeader, trackingId)) - .addHeader(RawHeader(ContextHeaders.StacktraceHeader, updatedStacktrace))) - - handleExceptions(ExceptionHandler(exceptionHandler))({ - c => - requestLogging.flatMap { _ => - val trackingHeader = RawHeader(ContextHeaders.TrackingIdHeader, trackingId) - - val responseHeaders = List[HttpHeader]( - trackingHeader, - allowOrigin(originHeader), - `Access-Control-Allow-Headers`(rest.AllowedHeaders: _*), - `Access-Control-Expose-Headers`(rest.AllowedHeaders: _*) - ) - - respondWithHeaders(responseHeaders) { - modules.map(_.route).foldLeft(versionRt ~ healthRoute ~ swaggerRoutes)(_ ~ _) - }(c) - } - })(contextWithTrackingId) - } - } - } + extractHost { origin => + extractClientIP { ip => + optionalHeaderValueByType[Origin](()) { originHeader => + trace(tracer) { ctx => + val trackingId = rest.extractTrackingId(ctx.request) + MDC.put("trackingId", trackingId) + + val updatedStacktrace = + (rest.extractStacktrace(ctx.request) ++ Array(appName)).mkString("->") + MDC.put("stack", updatedStacktrace) + + storeRequestContextToMdc(ctx.request, origin, ip) + + val trackingHeader = RawHeader(ContextHeaders.TrackingIdHeader, trackingId) + + val responseHeaders = List[HttpHeader]( + trackingHeader, + allowOrigin(originHeader), + `Access-Control-Allow-Headers`(rest.AllowedHeaders: _*), + `Access-Control-Expose-Headers`(rest.AllowedHeaders: _*) + ) + + log.info(s"""Received request {"method":"${ctx.request.method.value}","url": "${ctx.request.uri}"}""") + + val contextWithTrackingId = + ctx.withRequest( + ctx.request + .addHeader(RawHeader(ContextHeaders.TrackingIdHeader, trackingId)) + .addHeader(RawHeader(ContextHeaders.StacktraceHeader, updatedStacktrace))) + + respondWithHeaders(responseHeaders) { + modules + .flatMap(_.routes) + .map(_.routeWithDefaults) + .foldLeft(versionRt ~ healthRoute ~ swaggerRoutes)(_ ~ _) + }(contextWithTrackingId) } - }), - interface, - port - )(materializer) + } + } } } + protected def bindHttp(modules: Seq[Module]): Unit = { + val _ = http.bindAndHandle(route2HandlerFlow(appRoute), interface, port)(materializer) + } + private def storeRequestContextToMdc(request: HttpRequest, origin: String, ip: RemoteAddress): Unit = { MDC.put("origin", origin) @@ -181,58 +167,6 @@ class DriverApp(appName: String, } } - /** - * Override me for custom exception handling - * - * @return Exception handling route for exception type - */ - protected def exceptionHandler: PartialFunction[Throwable, Route] = { - - case is: IllegalStateException => - ctx => - log.warn(s"Request is not allowed to ${ctx.request.method} ${ctx.request.uri}", is) - errorResponse(ctx, BadRequest, message = is.getMessage, is)(ctx) - - case cm: ConcurrentModificationException => - ctx => - log.warn(s"Concurrent modification of the resource ${ctx.request.method} ${ctx.request.uri}", cm) - errorResponse(ctx, Conflict, "Resource was changed concurrently, try requesting a newer version", cm)(ctx) - - case se: SQLException => - ctx => - log.warn(s"Database exception for the resource ${ctx.request.method} ${ctx.request.uri}", se) - errorResponse(ctx, InternalServerError, "Data access error", se)(ctx) - - case t: Throwable => - ctx => - log.warn(s"Request to ${ctx.request.method} ${ctx.request.uri} could not be handled normally", t) - errorResponse(ctx, InternalServerError, t.getMessage, t)(ctx) - } - - protected def errorResponse[T <: Throwable](ctx: RequestContext, - statusCode: StatusCode, - message: String, - exception: T): Route = { - - val trackingId = rest.extractTrackingId(ctx.request) - val tracingHeader = RawHeader(ContextHeaders.TrackingIdHeader, rest.extractTrackingId(ctx.request)) - - MDC.put("trackingId", trackingId) - - optionalHeaderValueByType[Origin](()) { originHeader => - val responseHeaders = List[HttpHeader]( - tracingHeader, - allowOrigin(originHeader), - `Access-Control-Allow-Headers`(rest.AllowedHeaders: _*), - `Access-Control-Expose-Headers`(rest.AllowedHeaders: _*) - ) - - respondWithHeaders(responseHeaders) { - complete(HttpResponse(statusCode, entity = message)) - } - } - } - protected def versionRoute(version: String, gitHash: String, startupTime: Time): Route = { import spray.json._ import DefaultJsonProtocol._ diff --git a/src/main/scala/xyz/driver/core/app/module.scala b/src/main/scala/xyz/driver/core/app/module.scala index c6f979f..6baa457 100644 --- a/src/main/scala/xyz/driver/core/app/module.scala +++ b/src/main/scala/xyz/driver/core/app/module.scala @@ -3,13 +3,14 @@ package xyz.driver.core.app import akka.http.scaladsl.model.StatusCodes import akka.http.scaladsl.server.Directives.complete import akka.http.scaladsl.server.{Route, RouteConcatenation} -import xyz.driver.core.rest.{NoServiceDiscovery, SavingUsedServiceDiscovery, ServiceDiscovery} +import com.typesafe.scalalogging.Logger +import xyz.driver.core.rest.{DriverRoute, NoServiceDiscovery, SavingUsedServiceDiscovery, ServiceDiscovery} import scala.reflect.runtime.universe._ trait Module { val name: String - def route: Route + def routes: Seq[DriverRoute] def routeTypes: Seq[Type] val serviceDiscovery: ServiceDiscovery with SavingUsedServiceDiscovery = new NoServiceDiscovery() @@ -21,13 +22,22 @@ trait Module { class EmptyModule extends Module { override val name: String = "Nothing" - override def route: Route = complete(StatusCodes.OK) + override def routes: Seq[DriverRoute] = + Seq(new DriverRoute { + override def route: Route = complete(StatusCodes.OK) + override val log: Logger = xyz.driver.core.logging.NoLogger + }) override def routeTypes: Seq[Type] = Seq.empty[Type] } -class SimpleModule(override val name: String, override val route: Route, routeType: Type) extends Module { - def routeTypes: Seq[Type] = Seq(routeType) +class SimpleModule(override val name: String, route: Route, routeType: Type) extends Module { self => + override def routes: Seq[DriverRoute] = + Seq(new DriverRoute { + override def route: Route = self.route + override val log: Logger = xyz.driver.core.logging.NoLogger + }) + override def routeTypes: Seq[Type] = Seq(routeType) } /** @@ -39,12 +49,8 @@ class SimpleModule(override val name: String, override val route: Route, routeTy * @param modules modules to compose into a single one */ class CompositeModule(override val name: String, modules: Seq[Module]) extends Module with RouteConcatenation { - - override def route: Route = RouteConcatenation.concat(modules.map(_.route): _*) - - override def routeTypes: Seq[Type] = modules.flatMap(_.routeTypes) - - override def activate(): Unit = modules.foreach(_.activate()) - - override def deactivate(): Unit = modules.reverse.foreach(_.deactivate()) + override def routes: Seq[DriverRoute] = modules.flatMap(_.routes) + override def routeTypes: Seq[Type] = modules.flatMap(_.routeTypes) + override def activate(): Unit = modules.foreach(_.activate()) + override def deactivate(): Unit = modules.reverse.foreach(_.deactivate()) } diff --git a/src/main/scala/xyz/driver/core/rest/DriverRoute.scala b/src/main/scala/xyz/driver/core/rest/DriverRoute.scala new file mode 100644 index 0000000..20cc556 --- /dev/null +++ b/src/main/scala/xyz/driver/core/rest/DriverRoute.scala @@ -0,0 +1,84 @@ +package xyz.driver.core.rest + +import java.sql.SQLException + +import akka.http.scaladsl.model._ +import akka.http.scaladsl.model.StatusCodes.{BadRequest, Conflict, InternalServerError} +import akka.http.scaladsl.model.headers._ +import akka.http.scaladsl.server.Directives._ +import akka.http.scaladsl.server.{ExceptionHandler, RequestContext, Route} +import com.typesafe.scalalogging.Logger +import org.slf4j.MDC +import xyz.driver.core.rest +import xyz.driver.core.rest.errors.APIError + +import scala.compat.Platform.ConcurrentModificationException + +trait DriverRoute { + val log: Logger + + def route: Route + + def routeWithDefaults: Route = handleExceptions(ExceptionHandler(exceptionHandler)) { + route + } + + /** + * Override me for custom exception handling + * + * @return Exception handling route for exception type + */ + protected def exceptionHandler: PartialFunction[Throwable, Route] = { + case api: APIError if api.isPatientSensitive => + ctx => + log.info("PHI Sensitive error") + errorResponse(ctx, InternalServerError, "Server error", api)(ctx) + + case api: APIError => + ctx => + log.info("API Error") + errorResponse(ctx, api.statusCode, api.message, api)(ctx) + + case is: IllegalStateException => + ctx => + log.warn(s"Request is not allowed to ${ctx.request.method} ${ctx.request.uri}", is) + errorResponse(ctx, BadRequest, message = is.getMessage, is)(ctx) + + case cm: ConcurrentModificationException => + ctx => + log.warn(s"Concurrent modification of the resource ${ctx.request.method} ${ctx.request.uri}", cm) + errorResponse(ctx, Conflict, "Resource was changed concurrently, try requesting a newer version", cm)(ctx) + + case se: SQLException => + ctx => + log.warn(s"Database exception for the resource ${ctx.request.method} ${ctx.request.uri}", se) + errorResponse(ctx, InternalServerError, "Data access error", se)(ctx) + + case t: Throwable => + ctx => + log.warn(s"Request to ${ctx.request.method} ${ctx.request.uri} could not be handled normally", t) + errorResponse(ctx, InternalServerError, t.getMessage, t)(ctx) + } + + protected def errorResponse[T <: Throwable](ctx: RequestContext, + statusCode: StatusCode, + message: String, + exception: T): Route = { + + val trackingId = rest.extractTrackingId(ctx.request) + val tracingHeader = RawHeader(ContextHeaders.TrackingIdHeader, rest.extractTrackingId(ctx.request)) + + MDC.put("trackingId", trackingId) + + optionalHeaderValueByType[Origin](()) { originHeader => + val responseHeaders = List[HttpHeader](tracingHeader, + allowOrigin(originHeader), + `Access-Control-Allow-Headers`(AllowedHeaders: _*), + `Access-Control-Expose-Headers`(AllowedHeaders: _*)) + + respondWithHeaders(responseHeaders) { + complete(HttpResponse(statusCode, entity = message)) + } + } + } +} diff --git a/src/main/scala/xyz/driver/core/rest/errors/APIError.scala b/src/main/scala/xyz/driver/core/rest/errors/APIError.scala new file mode 100644 index 0000000..f2bfae1 --- /dev/null +++ b/src/main/scala/xyz/driver/core/rest/errors/APIError.scala @@ -0,0 +1,16 @@ +package xyz.driver.core.rest.errors + +import akka.http.scaladsl.model.{StatusCode, StatusCodes} + +abstract class APIError extends Throwable { + def isPatientSensitive: Boolean = false + + def statusCode: StatusCode + def message: String +} + +final case class InvalidInputError(override val message: String = "Invalid input", + override val isPatientSensitive: Boolean = false) + extends APIError { + override def statusCode: StatusCode = StatusCodes.BadRequest +} -- cgit v1.2.3