aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/xyz/driver/core
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/xyz/driver/core')
-rw-r--r--src/main/scala/xyz/driver/core/app/DriverApp.scala33
-rw-r--r--src/main/scala/xyz/driver/core/app/module.scala5
-rw-r--r--src/main/scala/xyz/driver/core/rest/DriverRoute.scala57
-rw-r--r--src/main/scala/xyz/driver/core/rest/package.scala25
4 files changed, 75 insertions, 45 deletions
diff --git a/src/main/scala/xyz/driver/core/app/DriverApp.scala b/src/main/scala/xyz/driver/core/app/DriverApp.scala
index d95e254..a593893 100644
--- a/src/main/scala/xyz/driver/core/app/DriverApp.scala
+++ b/src/main/scala/xyz/driver/core/app/DriverApp.scala
@@ -2,7 +2,6 @@ package xyz.driver.core.app
import akka.actor.ActorSystem
import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport
-import akka.http.scaladsl.model.StatusCodes._
import akka.http.scaladsl.model._
import akka.http.scaladsl.model.headers._
import akka.http.scaladsl.server.Directives._
@@ -42,7 +41,6 @@ class DriverApp(
port: Int = 8080,
tracer: Tracer = NoTracer)(implicit actorSystem: ActorSystem, executionContext: ExecutionContext) {
self =>
- import DriverApp._
implicit private lazy val materializer: ActorMaterializer = ActorMaterializer()(actorSystem)
private lazy val http: HttpExt = Http()(actorSystem)
@@ -73,8 +71,9 @@ class DriverApp(
val swaggerRoute = swaggerService.routes ~ swaggerService.swaggerUI
val versionRt = versionRoute(version, gitHash, time.currentTime())
val basicRoutes = new DriverRoute {
- override def log: Logger = self.log
- override def route: Route = versionRt ~ healthRoute ~ swaggerRoute
+ override def log: Logger = self.log
+ override def config: Config = xyz.driver.core.config.loadDefaultConfig
+ override def route: Route = versionRt ~ healthRoute ~ swaggerRoute
}
val combinedRoute = modules.map(_.route).foldLeft(basicRoutes.routeWithDefaults)(_ ~ _)
@@ -221,29 +220,3 @@ class DriverApp(
})
}
}
-
-object DriverApp {
- implicit def rejectionHandler: RejectionHandler =
- RejectionHandler
- .newBuilder()
- .handleAll[MethodRejection] { rejections =>
- val methods = rejections map (_.supported)
- lazy val names = methods map (_.name) mkString ", "
-
- options {
- respondWithCorsHeaders {
- respondWithCorsAllowedMethodHeaders(methods) {
- complete(s"Supported methods: $names.")
- }
- }
- } ~
- complete(MethodNotAllowed -> s"HTTP method not allowed, supported methods: $names!")
- }
- .handleAll[Rejection] { rejections =>
- respondWithCorsHeaders {
- reject(rejections: _*)
- }
- }
- .result()
- .seal
-}
diff --git a/src/main/scala/xyz/driver/core/app/module.scala b/src/main/scala/xyz/driver/core/app/module.scala
index 7be38eb..0a255fb 100644
--- a/src/main/scala/xyz/driver/core/app/module.scala
+++ b/src/main/scala/xyz/driver/core/app/module.scala
@@ -30,8 +30,9 @@ class EmptyModule extends Module {
class SimpleModule(override val name: String, theRoute: Route, routeType: Type) extends Module {
private val driverRoute: DriverRoute = new DriverRoute {
- override def route: Route = theRoute
- override val log: Logger = xyz.driver.core.logging.NoLogger
+ override def route: Route = theRoute
+ override val config: Config = xyz.driver.core.config.loadDefaultConfig
+ override val log: Logger = xyz.driver.core.logging.NoLogger
}
override def route: Route = driverRoute.routeWithDefaults
diff --git a/src/main/scala/xyz/driver/core/rest/DriverRoute.scala b/src/main/scala/xyz/driver/core/rest/DriverRoute.scala
index 5f961b6..5647818 100644
--- a/src/main/scala/xyz/driver/core/rest/DriverRoute.scala
+++ b/src/main/scala/xyz/driver/core/rest/DriverRoute.scala
@@ -7,6 +7,7 @@ import akka.http.scaladsl.model.StatusCodes
import akka.http.scaladsl.model.headers._
import akka.http.scaladsl.server.Directives._
import akka.http.scaladsl.server.{Directive0, ExceptionHandler, RequestContext, Route}
+import com.typesafe.config.Config
import com.typesafe.scalalogging.Logger
import org.slf4j.MDC
import xyz.driver.core.rest
@@ -16,17 +17,69 @@ import scala.compat.Platform.ConcurrentModificationException
trait DriverRoute {
def log: Logger
+ def config: Config
def route: Route
def routeWithDefaults: Route = {
- (defaultResponseHeaders & handleExceptions(ExceptionHandler(exceptionHandler)))(route)
+ (defaultResponseHeaders & handleExceptions(ExceptionHandler(exceptionHandler))) {
+ route ~ defaultOptionsRoute
+ }
+ }
+
+ protected lazy val allowedCorsDomainSuffixes: Set[HttpOrigin] = {
+ import scala.collection.JavaConverters._
+ config
+ .getConfigList("application.cors.allowedOrigins")
+ .asScala
+ .map { c =>
+ HttpOrigin(c.getString("scheme"), Host(c.getString("hostSuffix")))
+ }(scala.collection.breakOut)
+ }
+
+ protected lazy val defaultCorsAllowedMethods: Set[HttpMethod] = {
+ import scala.collection.JavaConverters._
+ config.getStringList("application.cors.allowedMethods").asScala.toSet.flatMap(HttpMethods.getForKey)
+ }
+
+ protected lazy val defaultCorsAllowedOrigin: Origin =
+ Origin(allowedCorsDomainSuffixes.to[collection.immutable.Seq])
+
+ protected def corsAllowedOriginHeader(origin: Option[Origin]): HttpHeader = {
+ val allowedOrigin =
+ origin
+ .filter { requestOrigin =>
+ allowedCorsDomainSuffixes.exists { allowedOriginSuffix =>
+ requestOrigin.origins.exists(o =>
+ o.scheme == allowedOriginSuffix.scheme &&
+ o.host.host.address.endsWith(allowedOriginSuffix.host.host.address()))
+ }
+ }
+ .getOrElse(defaultCorsAllowedOrigin)
+
+ `Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigin.origins: _*))
+ }
+
+ protected def respondWithAllCorsHeaders: Directive0 = {
+ respondWithCorsAllowedHeaders tflatMap { _ =>
+ respondWithCorsAllowedMethodHeaders(defaultCorsAllowedMethods) tflatMap { _ =>
+ optionalHeaderValueByType[Origin](()) flatMap { origin =>
+ respondWithHeader(corsAllowedOriginHeader(origin))
+ }
+ }
+ }
+ }
+
+ protected def defaultOptionsRoute: Route = options {
+ respondWithAllCorsHeaders {
+ complete("OK")
+ }
}
protected def defaultResponseHeaders: Directive0 = {
extractRequest flatMap { request =>
val tracingHeader = RawHeader(ContextHeaders.TrackingIdHeader, rest.extractTrackingId(request))
- respondWithHeader(tracingHeader) & respondWithCorsHeaders
+ respondWithHeader(tracingHeader) & respondWithAllCorsHeaders
}
}
diff --git a/src/main/scala/xyz/driver/core/rest/package.scala b/src/main/scala/xyz/driver/core/rest/package.scala
index 88f78d9..5fd9417 100644
--- a/src/main/scala/xyz/driver/core/rest/package.scala
+++ b/src/main/scala/xyz/driver/core/rest/package.scala
@@ -110,22 +110,25 @@ object `package` {
}
}
- def respondWithCorsHeaders: Directive0 = {
- optionalHeaderValueByType[Origin](()) flatMap { originHeader =>
- respondWithHeaders(
- List[HttpHeader](
- allowOrigin(originHeader),
- `Access-Control-Allow-Headers`(AllowedHeaders: _*),
- `Access-Control-Expose-Headers`(AllowedHeaders: _*)
- ))
+ def respondWithCorsAllowedHeaders: Directive0 = {
+ respondWithHeaders(
+ List[HttpHeader](
+ `Access-Control-Allow-Headers`(AllowedHeaders: _*),
+ `Access-Control-Expose-Headers`(AllowedHeaders: _*)
+ ))
+ }
+
+ def respondWithCorsAllowedOriginHeaders(origin: Origin): Directive0 = {
+ respondWithHeader {
+ `Access-Control-Allow-Origin`(HttpOriginRange(origin.origins: _*))
}
}
- def respondWithCorsAllowedMethodHeaders(methods: scala.collection.immutable.Seq[HttpMethod]): Directive0 = {
+ def respondWithCorsAllowedMethodHeaders(methods: Set[HttpMethod]): Directive0 = {
respondWithHeaders(
List[HttpHeader](
- Allow(methods),
- `Access-Control-Allow-Methods`(methods)
+ Allow(methods.to[collection.immutable.Seq]),
+ `Access-Control-Allow-Methods`(methods.to[collection.immutable.Seq])
))
}