From 5d6ef236eb3b5a4fad6383b87e1572685ddf2bc5 Mon Sep 17 00:00:00 2001 From: Zach Smith Date: Thu, 22 Feb 2018 17:03:26 -0800 Subject: Move cors directives to DriverApp and get rid of custom rejection handler --- src/main/scala/xyz/driver/core/app/DriverApp.scala | 57 +++++++++++++++++++++- 1 file changed, 55 insertions(+), 2 deletions(-) (limited to 'src/main/scala/xyz/driver/core/app') diff --git a/src/main/scala/xyz/driver/core/app/DriverApp.scala b/src/main/scala/xyz/driver/core/app/DriverApp.scala index 1ded4dd..ca3dd54 100644 --- a/src/main/scala/xyz/driver/core/app/DriverApp.scala +++ b/src/main/scala/xyz/driver/core/app/DriverApp.scala @@ -62,9 +62,59 @@ class DriverApp( } } + protected lazy val allowedCorsDomainSuffixes: Set[HttpOrigin] = { + import scala.collection.JavaConverters._ + config + .getConfigList("application.cors.allowedOrigins") + .asScala + .map { c => + HttpOrigin(c.getString("scheme"), Host(c.getString("hostSuffix"))) + }(scala.collection.breakOut) + } + + protected lazy val defaultCorsAllowedMethods: Set[HttpMethod] = { + import scala.collection.JavaConverters._ + config.getStringList("application.cors.allowedMethods").asScala.toSet.flatMap(HttpMethods.getForKey) + } + + protected lazy val defaultCorsAllowedOrigin: Origin = { + Origin(allowedCorsDomainSuffixes.to[collection.immutable.Seq]) + } + + protected def corsAllowedOriginHeader(origin: Option[Origin]): HttpHeader = { + val allowedOrigin = + origin + .filter { requestOrigin => + allowedCorsDomainSuffixes.exists { allowedOriginSuffix => + requestOrigin.origins.exists(o => + o.scheme == allowedOriginSuffix.scheme && + o.host.host.address.endsWith(allowedOriginSuffix.host.host.address())) + } + } + .getOrElse(defaultCorsAllowedOrigin) + + `Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigin.origins: _*)) + } + + protected def respondWithAllCorsHeaders: Directive0 = { + respondWithCorsAllowedHeaders tflatMap { _ => + respondWithCorsAllowedMethodHeaders(defaultCorsAllowedMethods) tflatMap { _ => + optionalHeaderValueByType[Origin](()) flatMap { origin => + respondWithHeader(corsAllowedOriginHeader(origin)) + } + } + } + } + private def extractHeader(request: HttpRequest)(headerName: String): Option[String] = request.headers.find(_.name().toLowerCase === headerName).map(_.value()) + protected def defaultOptionsRoute: Route = options { + respondWithAllCorsHeaders { + complete("OK") + } + } + def appRoute: Route = { val serviceTypes = modules.flatMap(_.routeTypes) val swaggerService = new Swagger(baseUrl, Scheme.forValue(scheme) :: Nil, version, serviceTypes, config, log) @@ -75,7 +125,8 @@ class DriverApp( override def config: Config = self.config override def route: Route = versionRt ~ healthRoute ~ swaggerRoute } - val combinedRoute = modules.map(_.route).foldLeft(basicRoutes.routeWithDefaults)(_ ~ _) + val combinedRoute = + Route.seal(modules.map(_.route).foldLeft(basicRoutes.routeWithDefaults)(_ ~ _) ~ defaultOptionsRoute) (extractHost & extractClientIP & trace(tracer)) { case (origin, ip) => @@ -97,7 +148,9 @@ class DriverApp( .addHeader(RawHeader(ContextHeaders.TrackingIdHeader, trackingId)) .addHeader(RawHeader(ContextHeaders.StacktraceHeader, updatedStacktrace))) - combinedRoute(contextWithTrackingId) + respondWithAllCorsHeaders { + combinedRoute + }(contextWithTrackingId) } } -- cgit v1.2.3