From 5db157103a1982f00ca72e8f6b925344debce36e Mon Sep 17 00:00:00 2001 From: Zach Smith Date: Wed, 25 Oct 2017 13:34:46 -0700 Subject: Move all CORS headers to DriverRoute from DriverApp --- src/main/scala/xyz/driver/core/app/DriverApp.scala | 52 ++++++++-------------- .../scala/xyz/driver/core/rest/DriverRoute.scala | 21 +++++++-- 2 files changed, 36 insertions(+), 37 deletions(-) (limited to 'src/main/scala/xyz') diff --git a/src/main/scala/xyz/driver/core/app/DriverApp.scala b/src/main/scala/xyz/driver/core/app/DriverApp.scala index 4110c37..2bb61c0 100644 --- a/src/main/scala/xyz/driver/core/app/DriverApp.scala +++ b/src/main/scala/xyz/driver/core/app/DriverApp.scala @@ -73,47 +73,31 @@ class DriverApp(appName: String, protected def appRoute: Route = { val serviceTypes = modules.flatMap(_.routeTypes) val swaggerService = swaggerOverride(serviceTypes) - val swaggerRoutes = swaggerService.routes ~ swaggerService.swaggerUI + val swaggerRoute = swaggerService.routes ~ swaggerService.swaggerUI val versionRt = versionRoute(version, gitHash, time.currentTime()) + val combinedRoute = modules.map(_.route).foldLeft(versionRt ~ healthRoute ~ swaggerRoute)(_ ~ _) - extractHost { origin => - extractClientIP { ip => - optionalHeaderValueByType[Origin](()) { originHeader => - trace(tracer) { ctx => - val trackingId = rest.extractTrackingId(ctx.request) - MDC.put("trackingId", trackingId) + (extractHost & extractClientIP & trace(tracer)) { + case (origin, ip) => + ctx => + val trackingId = rest.extractTrackingId(ctx.request) + MDC.put("trackingId", trackingId) - val updatedStacktrace = - (rest.extractStacktrace(ctx.request) ++ Array(appName)).mkString("->") - MDC.put("stack", updatedStacktrace) + val updatedStacktrace = + (rest.extractStacktrace(ctx.request) ++ Array(appName)).mkString("->") + MDC.put("stack", updatedStacktrace) - storeRequestContextToMdc(ctx.request, origin, ip) + storeRequestContextToMdc(ctx.request, origin, ip) - val trackingHeader = RawHeader(ContextHeaders.TrackingIdHeader, trackingId) + log.info(s"""Received request {"method":"${ctx.request.method.value}","url": "${ctx.request.uri}"}""") - val responseHeaders = List[HttpHeader]( - trackingHeader, - allowOrigin(originHeader), - `Access-Control-Allow-Headers`(rest.AllowedHeaders: _*), - `Access-Control-Expose-Headers`(rest.AllowedHeaders: _*) - ) - - log.info(s"""Received request {"method":"${ctx.request.method.value}","url": "${ctx.request.uri}"}""") + val contextWithTrackingId = + ctx.withRequest( + ctx.request + .addHeader(RawHeader(ContextHeaders.TrackingIdHeader, trackingId)) + .addHeader(RawHeader(ContextHeaders.StacktraceHeader, updatedStacktrace))) - val contextWithTrackingId = - ctx.withRequest( - ctx.request - .addHeader(RawHeader(ContextHeaders.TrackingIdHeader, trackingId)) - .addHeader(RawHeader(ContextHeaders.StacktraceHeader, updatedStacktrace))) - - respondWithHeaders(responseHeaders) { - modules - .map(_.route) - .foldLeft(versionRt ~ healthRoute ~ swaggerRoutes)(_ ~ _) - }(contextWithTrackingId) - } - } - } + combinedRoute(contextWithTrackingId) } } diff --git a/src/main/scala/xyz/driver/core/rest/DriverRoute.scala b/src/main/scala/xyz/driver/core/rest/DriverRoute.scala index 79dff8b..d42bd75 100644 --- a/src/main/scala/xyz/driver/core/rest/DriverRoute.scala +++ b/src/main/scala/xyz/driver/core/rest/DriverRoute.scala @@ -6,7 +6,7 @@ import akka.http.scaladsl.model._ import akka.http.scaladsl.model.StatusCodes import akka.http.scaladsl.model.headers._ import akka.http.scaladsl.server.Directives._ -import akka.http.scaladsl.server.{ExceptionHandler, RequestContext, Route} +import akka.http.scaladsl.server.{Directive0, ExceptionHandler, RequestContext, Route} import com.typesafe.scalalogging.Logger import org.slf4j.MDC import xyz.driver.core.rest @@ -19,8 +19,23 @@ trait DriverRoute { def route: Route - def routeWithDefaults: Route = handleExceptions(ExceptionHandler(exceptionHandler)) { - route + def routeWithDefaults: Route = { + (defaultResponseHeaders & handleExceptions(ExceptionHandler(exceptionHandler)))(route) + } + + protected def defaultResponseHeaders: Directive0 = { + (extractRequest & optionalHeaderValueByType[Origin](())) tflatMap { + case (request, originHeader) => + val tracingHeader = RawHeader(ContextHeaders.TrackingIdHeader, rest.extractTrackingId(request)) + val responseHeaders = List[HttpHeader]( + tracingHeader, + allowOrigin(originHeader), + `Access-Control-Allow-Headers`(AllowedHeaders: _*), + `Access-Control-Expose-Headers`(AllowedHeaders: _*) + ) + + respondWithHeaders(responseHeaders) + } } /** -- cgit v1.2.3