diff options
Diffstat (limited to 'src/main/scala/xyz/driver/core/app/DriverApp.scala')
-rw-r--r-- | src/main/scala/xyz/driver/core/app/DriverApp.scala | 167 |
1 files changed, 38 insertions, 129 deletions
diff --git a/src/main/scala/xyz/driver/core/app/DriverApp.scala b/src/main/scala/xyz/driver/core/app/DriverApp.scala index 901d6e2..751bef7 100644 --- a/src/main/scala/xyz/driver/core/app/DriverApp.scala +++ b/src/main/scala/xyz/driver/core/app/DriverApp.scala @@ -1,35 +1,31 @@ package xyz.driver.core.app -import java.sql.SQLException - import akka.actor.ActorSystem import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport -import akka.http.scaladsl.model.StatusCodes.{BadRequest, Conflict, InternalServerError, MethodNotAllowed} +import akka.http.scaladsl.model.StatusCodes._ import akka.http.scaladsl.model._ import akka.http.scaladsl.model.headers._ import akka.http.scaladsl.server.Directives._ -import akka.http.scaladsl.server.RouteResult.route2HandlerFlow +import akka.http.scaladsl.server.RouteResult._ import akka.http.scaladsl.server._ import akka.http.scaladsl.{Http, HttpExt} import akka.stream.ActorMaterializer -import com.github.swagger.akka.SwaggerHttpService.{logger, toJavaTypeSet} +import com.github.swagger.akka.SwaggerHttpService._ import com.typesafe.config.Config import com.typesafe.scalalogging.Logger import io.swagger.models.Scheme import io.swagger.util.Json import org.slf4j.{LoggerFactory, MDC} import xyz.driver.core -import xyz.driver.core.rest -import xyz.driver.core.rest.{ContextHeaders, Swagger} +import xyz.driver.core.rest._ import xyz.driver.core.stats.SystemStats import xyz.driver.core.time.Time import xyz.driver.core.time.provider.{SystemTimeProvider, TimeProvider} -import xyz.driver.tracing.TracingDirectives.trace -import xyz.driver.tracing.{NoTracer, Tracer} +import xyz.driver.tracing.TracingDirectives._ +import xyz.driver.tracing._ -import scala.compat.Platform.ConcurrentModificationException import scala.concurrent.duration._ -import scala.concurrent.{Await, ExecutionContext, Future} +import scala.concurrent.{Await, ExecutionContext} import scala.reflect.runtime.universe._ import scala.util.Try import scala.util.control.NonFatal @@ -64,8 +60,7 @@ class DriverApp(appName: String, def stop(): Unit = { http.shutdownAllConnectionPools().onComplete { _ => Await.result(tracer.close(), 15.seconds) // flush out any remaining traces from the buffer - val _ = actorSystem.terminate() - val terminated = Await.result(actorSystem.whenTerminated, 30.seconds) + val terminated = Await.result(actorSystem.terminate(), 30.seconds) val addressTerminated = if (terminated.addressTerminated) "is" else "is not" Console.print(s"${this.getClass.getName} App $addressTerminated stopped ") } @@ -74,69 +69,41 @@ class DriverApp(appName: String, private def extractHeader(request: HttpRequest)(headerName: String): Option[String] = request.headers.find(_.name().toLowerCase === headerName).map(_.value()) - protected def bindHttp(modules: Seq[Module]): Unit = { + protected def appRoute: Route = { val serviceTypes = modules.flatMap(_.routeTypes) val swaggerService = swaggerOverride(serviceTypes) - val swaggerRoutes = swaggerService.routes ~ swaggerService.swaggerUI + val swaggerRoute = swaggerService.routes ~ swaggerService.swaggerUI val versionRt = versionRoute(version, gitHash, time.currentTime()) + val combinedRoute = modules.map(_.route).foldLeft(versionRt ~ healthRoute ~ swaggerRoute)(_ ~ _) - val _ = Future { - http.bindAndHandle( - route2HandlerFlow(extractHost { origin => - trace(tracer) { - extractClientIP { ip => - optionalHeaderValueByType[Origin](()) { - originHeader => - { - ctx => - val trackingId = rest.extractTrackingId(ctx.request) - MDC.put("trackingId", trackingId) - - val updatedStacktrace = - (rest.extractStacktrace(ctx.request) ++ Array(appName)).mkString("->") - MDC.put("stack", updatedStacktrace) - - storeRequestContextToMdc(ctx.request, origin, ip) - - def requestLogging: Future[Unit] = Future { - log.info( - s"""Received request {"method":"${ctx.request.method.value}","url": "${ctx.request.uri}"}""") - } - - val contextWithTrackingId = - ctx.withRequest( - ctx.request - .addHeader(RawHeader(ContextHeaders.TrackingIdHeader, trackingId)) - .addHeader(RawHeader(ContextHeaders.StacktraceHeader, updatedStacktrace))) - - handleExceptions(ExceptionHandler(exceptionHandler))({ - c => - requestLogging.flatMap { _ => - val trackingHeader = RawHeader(ContextHeaders.TrackingIdHeader, trackingId) - - val responseHeaders = List[HttpHeader]( - trackingHeader, - allowOrigin(originHeader), - `Access-Control-Allow-Headers`(rest.AllowedHeaders: _*), - `Access-Control-Expose-Headers`(rest.AllowedHeaders: _*) - ) - - respondWithHeaders(responseHeaders) { - modules.map(_.route).foldLeft(versionRt ~ healthRoute ~ swaggerRoutes)(_ ~ _) - }(c) - } - })(contextWithTrackingId) - } - } - } - } - }), - interface, - port - )(materializer) + (extractHost & extractClientIP & trace(tracer)) { + case (origin, ip) => + ctx => + val trackingId = extractTrackingId(ctx.request) + MDC.put("trackingId", trackingId) + + val updatedStacktrace = + (extractStacktrace(ctx.request) ++ Array(appName)).mkString("->") + MDC.put("stack", updatedStacktrace) + + storeRequestContextToMdc(ctx.request, origin, ip) + + log.info(s"""Received request {"method":"${ctx.request.method.value}","url": "${ctx.request.uri}"}""") + + val contextWithTrackingId = + ctx.withRequest( + ctx.request + .addHeader(RawHeader(ContextHeaders.TrackingIdHeader, trackingId)) + .addHeader(RawHeader(ContextHeaders.StacktraceHeader, updatedStacktrace))) + + combinedRoute(contextWithTrackingId) } } + protected def bindHttp(modules: Seq[Module]): Unit = { + val _ = http.bindAndHandle(route2HandlerFlow(appRoute), interface, port)(materializer) + } + private def storeRequestContextToMdc(request: HttpRequest, origin: String, ip: RemoteAddress): Unit = { MDC.put("origin", origin) @@ -181,58 +148,6 @@ class DriverApp(appName: String, } } - /** - * Override me for custom exception handling - * - * @return Exception handling route for exception type - */ - protected def exceptionHandler: PartialFunction[Throwable, Route] = { - - case is: IllegalStateException => - ctx => - log.warn(s"Request is not allowed to ${ctx.request.method} ${ctx.request.uri}", is) - errorResponse(ctx, BadRequest, message = is.getMessage, is)(ctx) - - case cm: ConcurrentModificationException => - ctx => - log.warn(s"Concurrent modification of the resource ${ctx.request.method} ${ctx.request.uri}", cm) - errorResponse(ctx, Conflict, "Resource was changed concurrently, try requesting a newer version", cm)(ctx) - - case se: SQLException => - ctx => - log.warn(s"Database exception for the resource ${ctx.request.method} ${ctx.request.uri}", se) - errorResponse(ctx, InternalServerError, "Data access error", se)(ctx) - - case t: Throwable => - ctx => - log.warn(s"Request to ${ctx.request.method} ${ctx.request.uri} could not be handled normally", t) - errorResponse(ctx, InternalServerError, t.getMessage, t)(ctx) - } - - protected def errorResponse[T <: Throwable](ctx: RequestContext, - statusCode: StatusCode, - message: String, - exception: T): Route = { - - val trackingId = rest.extractTrackingId(ctx.request) - val tracingHeader = RawHeader(ContextHeaders.TrackingIdHeader, rest.extractTrackingId(ctx.request)) - - MDC.put("trackingId", trackingId) - - optionalHeaderValueByType[Origin](()) { originHeader => - val responseHeaders = List[HttpHeader]( - tracingHeader, - allowOrigin(originHeader), - `Access-Control-Allow-Headers`(rest.AllowedHeaders: _*), - `Access-Control-Expose-Headers`(rest.AllowedHeaders: _*) - ) - - respondWithHeaders(responseHeaders) { - complete(HttpResponse(statusCode, entity = message)) - } - } - } - protected def versionRoute(version: String, gitHash: String, startupTime: Time): Route = { import spray.json._ import DefaultJsonProtocol._ @@ -334,11 +249,6 @@ class DriverApp(appName: String, } object DriverApp { - - private def allowOrigin(originHeader: Option[Origin]) = - `Access-Control-Allow-Origin`( - originHeader.fold[HttpOriginRange](HttpOriginRange.*)(h => HttpOriginRange(h.origins: _*))) - implicit def rejectionHandler: RejectionHandler = RejectionHandler .newBuilder() @@ -352,8 +262,8 @@ object DriverApp { Allow(methods), `Access-Control-Allow-Methods`(methods), allowOrigin(originHeader), - `Access-Control-Allow-Headers`(rest.AllowedHeaders: _*), - `Access-Control-Expose-Headers`(rest.AllowedHeaders: _*) + `Access-Control-Allow-Headers`(AllowedHeaders: _*), + `Access-Control-Expose-Headers`(AllowedHeaders: _*) )) { complete(s"Supported methods: $names.") } @@ -362,5 +272,4 @@ object DriverApp { complete(MethodNotAllowed -> s"HTTP method not allowed, supported methods: $names!") } .result() - } |