From 5d6ef236eb3b5a4fad6383b87e1572685ddf2bc5 Mon Sep 17 00:00:00 2001 From: Zach Smith Date: Thu, 22 Feb 2018 17:03:26 -0800 Subject: Move cors directives to DriverApp and get rid of custom rejection handler --- src/main/scala/xyz/driver/core/app/DriverApp.scala | 57 +++++++++++++++++- .../scala/xyz/driver/core/rest/DriverRoute.scala | 67 +--------------------- 2 files changed, 58 insertions(+), 66 deletions(-) (limited to 'src/main/scala/xyz/driver/core') diff --git a/src/main/scala/xyz/driver/core/app/DriverApp.scala b/src/main/scala/xyz/driver/core/app/DriverApp.scala index 1ded4dd..ca3dd54 100644 --- a/src/main/scala/xyz/driver/core/app/DriverApp.scala +++ b/src/main/scala/xyz/driver/core/app/DriverApp.scala @@ -62,9 +62,59 @@ class DriverApp( } } + protected lazy val allowedCorsDomainSuffixes: Set[HttpOrigin] = { + import scala.collection.JavaConverters._ + config + .getConfigList("application.cors.allowedOrigins") + .asScala + .map { c => + HttpOrigin(c.getString("scheme"), Host(c.getString("hostSuffix"))) + }(scala.collection.breakOut) + } + + protected lazy val defaultCorsAllowedMethods: Set[HttpMethod] = { + import scala.collection.JavaConverters._ + config.getStringList("application.cors.allowedMethods").asScala.toSet.flatMap(HttpMethods.getForKey) + } + + protected lazy val defaultCorsAllowedOrigin: Origin = { + Origin(allowedCorsDomainSuffixes.to[collection.immutable.Seq]) + } + + protected def corsAllowedOriginHeader(origin: Option[Origin]): HttpHeader = { + val allowedOrigin = + origin + .filter { requestOrigin => + allowedCorsDomainSuffixes.exists { allowedOriginSuffix => + requestOrigin.origins.exists(o => + o.scheme == allowedOriginSuffix.scheme && + o.host.host.address.endsWith(allowedOriginSuffix.host.host.address())) + } + } + .getOrElse(defaultCorsAllowedOrigin) + + `Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigin.origins: _*)) + } + + protected def respondWithAllCorsHeaders: Directive0 = { + respondWithCorsAllowedHeaders tflatMap { _ => + respondWithCorsAllowedMethodHeaders(defaultCorsAllowedMethods) tflatMap { _ => + optionalHeaderValueByType[Origin](()) flatMap { origin => + respondWithHeader(corsAllowedOriginHeader(origin)) + } + } + } + } + private def extractHeader(request: HttpRequest)(headerName: String): Option[String] = request.headers.find(_.name().toLowerCase === headerName).map(_.value()) + protected def defaultOptionsRoute: Route = options { + respondWithAllCorsHeaders { + complete("OK") + } + } + def appRoute: Route = { val serviceTypes = modules.flatMap(_.routeTypes) val swaggerService = new Swagger(baseUrl, Scheme.forValue(scheme) :: Nil, version, serviceTypes, config, log) @@ -75,7 +125,8 @@ class DriverApp( override def config: Config = self.config override def route: Route = versionRt ~ healthRoute ~ swaggerRoute } - val combinedRoute = modules.map(_.route).foldLeft(basicRoutes.routeWithDefaults)(_ ~ _) + val combinedRoute = + Route.seal(modules.map(_.route).foldLeft(basicRoutes.routeWithDefaults)(_ ~ _) ~ defaultOptionsRoute) (extractHost & extractClientIP & trace(tracer)) { case (origin, ip) => @@ -97,7 +148,9 @@ class DriverApp( .addHeader(RawHeader(ContextHeaders.TrackingIdHeader, trackingId)) .addHeader(RawHeader(ContextHeaders.StacktraceHeader, updatedStacktrace))) - combinedRoute(contextWithTrackingId) + respondWithAllCorsHeaders { + combinedRoute + }(contextWithTrackingId) } } diff --git a/src/main/scala/xyz/driver/core/rest/DriverRoute.scala b/src/main/scala/xyz/driver/core/rest/DriverRoute.scala index 32d996a..15da808 100644 --- a/src/main/scala/xyz/driver/core/rest/DriverRoute.scala +++ b/src/main/scala/xyz/driver/core/rest/DriverRoute.scala @@ -22,79 +22,18 @@ trait DriverRoute { def route: Route def routeWithDefaults: Route = { - (defaultResponseHeaders & handleRejections(rejectionHandler) & handleExceptions(ExceptionHandler(exceptionHandler))) { - route ~ defaultOptionsRoute - } - } - - protected lazy val allowedCorsDomainSuffixes: Set[HttpOrigin] = { - import scala.collection.JavaConverters._ - config - .getConfigList("application.cors.allowedOrigins") - .asScala - .map { c => - HttpOrigin(c.getString("scheme"), Host(c.getString("hostSuffix"))) - }(scala.collection.breakOut) - } - - protected lazy val defaultCorsAllowedMethods: Set[HttpMethod] = { - import scala.collection.JavaConverters._ - config.getStringList("application.cors.allowedMethods").asScala.toSet.flatMap(HttpMethods.getForKey) - } - - protected lazy val defaultCorsAllowedOrigin: Origin = { - Origin(allowedCorsDomainSuffixes.to[collection.immutable.Seq]) - } - - protected def corsAllowedOriginHeader(origin: Option[Origin]): HttpHeader = { - val allowedOrigin = - origin - .filter { requestOrigin => - allowedCorsDomainSuffixes.exists { allowedOriginSuffix => - requestOrigin.origins.exists(o => - o.scheme == allowedOriginSuffix.scheme && - o.host.host.address.endsWith(allowedOriginSuffix.host.host.address())) - } - } - .getOrElse(defaultCorsAllowedOrigin) - - `Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigin.origins: _*)) - } - - protected def respondWithAllCorsHeaders: Directive0 = { - respondWithCorsAllowedHeaders tflatMap { _ => - respondWithCorsAllowedMethodHeaders(defaultCorsAllowedMethods) tflatMap { _ => - optionalHeaderValueByType[Origin](()) flatMap { origin => - respondWithHeader(corsAllowedOriginHeader(origin)) - } - } - } - } - - protected def defaultOptionsRoute: Route = options { - respondWithAllCorsHeaders { - complete("OK") + (defaultResponseHeaders & handleExceptions(ExceptionHandler(exceptionHandler))) { + route } } protected def defaultResponseHeaders: Directive0 = { extractRequest flatMap { request => val tracingHeader = RawHeader(ContextHeaders.TrackingIdHeader, rest.extractTrackingId(request)) - respondWithHeader(tracingHeader) & respondWithAllCorsHeaders + respondWithHeader(tracingHeader) } } - protected def rejectionHandler: RejectionHandler = - RejectionHandler - .newBuilder() - .handle { - case rejection => - respondWithAllCorsHeaders { - RejectionHandler.default(collection.immutable.Seq(rejection)).get - } - } - .result() - /** * Override me for custom exception handling * -- cgit v1.2.3