aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/xyz/driver/core/app/DriverApp.scala
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/xyz/driver/core/app/DriverApp.scala')
-rw-r--r--src/main/scala/xyz/driver/core/app/DriverApp.scala67
1 files changed, 12 insertions, 55 deletions
diff --git a/src/main/scala/xyz/driver/core/app/DriverApp.scala b/src/main/scala/xyz/driver/core/app/DriverApp.scala
index 9214cea..50e471c 100644
--- a/src/main/scala/xyz/driver/core/app/DriverApp.scala
+++ b/src/main/scala/xyz/driver/core/app/DriverApp.scala
@@ -11,7 +11,6 @@ import akka.http.scaladsl.{Http, HttpExt}
import akka.stream.ActorMaterializer
import com.typesafe.config.Config
import com.typesafe.scalalogging.Logger
-import io.swagger.models.Scheme
import kamon.Kamon
import kamon.statsd.StatsDReporter
import kamon.system.SystemMetrics
@@ -24,11 +23,13 @@ import xyz.driver.core.time.provider.{SystemTimeProvider, TimeProvider}
import xyz.driver.tracing.TracingDirectives._
import xyz.driver.tracing._
+import scala.collection.JavaConverters._
import scala.concurrent.duration._
import scala.concurrent.{Await, ExecutionContext}
import scala.util.Try
import scalaz.Scalaz.stringInstance
import scalaz.syntax.equal._
+import xyz.driver.core.rest.directives.CorsDirectives
class DriverApp(
appName: String,
@@ -69,62 +70,15 @@ class DriverApp(
}
}
- protected lazy val allowedCorsDomainSuffixes: Set[HttpOrigin] = {
- import scala.collection.JavaConverters._
- config
- .getConfigList("application.cors.allowedOrigins")
- .asScala
- .map { c =>
- HttpOrigin(c.getString("scheme"), Host(c.getString("hostSuffix")))
- }(scala.collection.breakOut)
- }
-
- protected lazy val defaultCorsAllowedMethods: Set[HttpMethod] = {
- import scala.collection.JavaConverters._
- config.getStringList("application.cors.allowedMethods").asScala.toSet.flatMap(HttpMethods.getForKey)
- }
-
- protected lazy val defaultCorsAllowedOrigin: Origin = {
- Origin(allowedCorsDomainSuffixes.to[collection.immutable.Seq])
- }
-
- protected def corsAllowedOriginHeader(origin: Option[Origin]): HttpHeader = {
- val allowedOrigin =
- origin
- .filter { requestOrigin =>
- allowedCorsDomainSuffixes.exists { allowedOriginSuffix =>
- requestOrigin.origins.exists(o =>
- o.scheme == allowedOriginSuffix.scheme &&
- o.host.host.address.endsWith(allowedOriginSuffix.host.host.address()))
- }
- }
- .getOrElse(defaultCorsAllowedOrigin)
-
- `Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigin.origins: _*))
- }
-
- protected def respondWithAllCorsHeaders: Directive0 = {
- respondWithCorsAllowedHeaders tflatMap { _ =>
- respondWithCorsAllowedMethodHeaders(defaultCorsAllowedMethods) tflatMap { _ =>
- optionalHeaderValueByType[Origin](()) flatMap { origin =>
- respondWithHeader(corsAllowedOriginHeader(origin))
- }
- }
- }
- }
-
private def extractHeader(request: HttpRequest)(headerName: String): Option[String] =
request.headers.find(_.name().toLowerCase === headerName).map(_.value())
- protected def defaultOptionsRoute: Route = options {
- respondWithAllCorsHeaders {
- complete("OK")
- }
- }
-
def appRoute: Route = {
- val serviceTypes = modules.flatMap(_.routeTypes)
- val swaggerService = new Swagger(baseUrl, Scheme.forValue(scheme) :: Nil, version, serviceTypes, config, log)
+ val serviceTypes = modules.flatMap(_.routeTypes)
+ val serviceClasses = serviceTypes.map { tpe =>
+ scala.reflect.runtime.currentMirror.runtimeClass(tpe.typeSymbol.asClass)
+ }.toSet
+ val swaggerService = new Swagger(baseUrl, scheme :: Nil, version, serviceClasses, config, log)
val swaggerRoute = new DriverRoute {
def log: Logger = self.log
def route: Route = handleExceptions(ExceptionHandler(exceptionHandler)) {
@@ -139,7 +93,6 @@ class DriverApp(
val combinedRoute =
Route.seal(
modules.map(_.route).foldLeft(basicRoutes.routeWithDefaults)(_ ~ _) ~ swaggerRoute.route ~ defaultOptionsRoute)
-
(extractHost & extractClientIP & trace(tracer) & handleRejections(authenticationRejectionHandler)) {
case (origin, ip) =>
ctx =>
@@ -173,7 +126,11 @@ class DriverApp(
r
}
- respondWithAllCorsHeaders(logResponses(combinedRoute))(contextWithTrackingId)
+ val innerRoute = CorsDirectives.cors(
+ config.getStringList("application.cors.allowedOrigins").asScala.toSet,
+ AllowedHeaders
+ )(logResponses(combinedRoute))
+ innerRoute(contextWithTrackingId)
}
}