aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/xyz/driver/core/app/DriverApp.scala
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/xyz/driver/core/app/DriverApp.scala')
-rw-r--r--src/main/scala/xyz/driver/core/app/DriverApp.scala57
1 files changed, 55 insertions, 2 deletions
diff --git a/src/main/scala/xyz/driver/core/app/DriverApp.scala b/src/main/scala/xyz/driver/core/app/DriverApp.scala
index 1ded4dd..ca3dd54 100644
--- a/src/main/scala/xyz/driver/core/app/DriverApp.scala
+++ b/src/main/scala/xyz/driver/core/app/DriverApp.scala
@@ -62,9 +62,59 @@ class DriverApp(
}
}
+ protected lazy val allowedCorsDomainSuffixes: Set[HttpOrigin] = {
+ import scala.collection.JavaConverters._
+ config
+ .getConfigList("application.cors.allowedOrigins")
+ .asScala
+ .map { c =>
+ HttpOrigin(c.getString("scheme"), Host(c.getString("hostSuffix")))
+ }(scala.collection.breakOut)
+ }
+
+ protected lazy val defaultCorsAllowedMethods: Set[HttpMethod] = {
+ import scala.collection.JavaConverters._
+ config.getStringList("application.cors.allowedMethods").asScala.toSet.flatMap(HttpMethods.getForKey)
+ }
+
+ protected lazy val defaultCorsAllowedOrigin: Origin = {
+ Origin(allowedCorsDomainSuffixes.to[collection.immutable.Seq])
+ }
+
+ protected def corsAllowedOriginHeader(origin: Option[Origin]): HttpHeader = {
+ val allowedOrigin =
+ origin
+ .filter { requestOrigin =>
+ allowedCorsDomainSuffixes.exists { allowedOriginSuffix =>
+ requestOrigin.origins.exists(o =>
+ o.scheme == allowedOriginSuffix.scheme &&
+ o.host.host.address.endsWith(allowedOriginSuffix.host.host.address()))
+ }
+ }
+ .getOrElse(defaultCorsAllowedOrigin)
+
+ `Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigin.origins: _*))
+ }
+
+ protected def respondWithAllCorsHeaders: Directive0 = {
+ respondWithCorsAllowedHeaders tflatMap { _ =>
+ respondWithCorsAllowedMethodHeaders(defaultCorsAllowedMethods) tflatMap { _ =>
+ optionalHeaderValueByType[Origin](()) flatMap { origin =>
+ respondWithHeader(corsAllowedOriginHeader(origin))
+ }
+ }
+ }
+ }
+
private def extractHeader(request: HttpRequest)(headerName: String): Option[String] =
request.headers.find(_.name().toLowerCase === headerName).map(_.value())
+ protected def defaultOptionsRoute: Route = options {
+ respondWithAllCorsHeaders {
+ complete("OK")
+ }
+ }
+
def appRoute: Route = {
val serviceTypes = modules.flatMap(_.routeTypes)
val swaggerService = new Swagger(baseUrl, Scheme.forValue(scheme) :: Nil, version, serviceTypes, config, log)
@@ -75,7 +125,8 @@ class DriverApp(
override def config: Config = self.config
override def route: Route = versionRt ~ healthRoute ~ swaggerRoute
}
- val combinedRoute = modules.map(_.route).foldLeft(basicRoutes.routeWithDefaults)(_ ~ _)
+ val combinedRoute =
+ Route.seal(modules.map(_.route).foldLeft(basicRoutes.routeWithDefaults)(_ ~ _) ~ defaultOptionsRoute)
(extractHost & extractClientIP & trace(tracer)) {
case (origin, ip) =>
@@ -97,7 +148,9 @@ class DriverApp(
.addHeader(RawHeader(ContextHeaders.TrackingIdHeader, trackingId))
.addHeader(RawHeader(ContextHeaders.StacktraceHeader, updatedStacktrace)))
- combinedRoute(contextWithTrackingId)
+ respondWithAllCorsHeaders {
+ combinedRoute
+ }(contextWithTrackingId)
}
}