diff options
Diffstat (limited to 'src/main/scala/xyz/driver/core/app/DriverApp.scala')
-rw-r--r-- | src/main/scala/xyz/driver/core/app/DriverApp.scala | 67 |
1 files changed, 12 insertions, 55 deletions
diff --git a/src/main/scala/xyz/driver/core/app/DriverApp.scala b/src/main/scala/xyz/driver/core/app/DriverApp.scala index 9214cea..50e471c 100644 --- a/src/main/scala/xyz/driver/core/app/DriverApp.scala +++ b/src/main/scala/xyz/driver/core/app/DriverApp.scala @@ -11,7 +11,6 @@ import akka.http.scaladsl.{Http, HttpExt} import akka.stream.ActorMaterializer import com.typesafe.config.Config import com.typesafe.scalalogging.Logger -import io.swagger.models.Scheme import kamon.Kamon import kamon.statsd.StatsDReporter import kamon.system.SystemMetrics @@ -24,11 +23,13 @@ import xyz.driver.core.time.provider.{SystemTimeProvider, TimeProvider} import xyz.driver.tracing.TracingDirectives._ import xyz.driver.tracing._ +import scala.collection.JavaConverters._ import scala.concurrent.duration._ import scala.concurrent.{Await, ExecutionContext} import scala.util.Try import scalaz.Scalaz.stringInstance import scalaz.syntax.equal._ +import xyz.driver.core.rest.directives.CorsDirectives class DriverApp( appName: String, @@ -69,62 +70,15 @@ class DriverApp( } } - protected lazy val allowedCorsDomainSuffixes: Set[HttpOrigin] = { - import scala.collection.JavaConverters._ - config - .getConfigList("application.cors.allowedOrigins") - .asScala - .map { c => - HttpOrigin(c.getString("scheme"), Host(c.getString("hostSuffix"))) - }(scala.collection.breakOut) - } - - protected lazy val defaultCorsAllowedMethods: Set[HttpMethod] = { - import scala.collection.JavaConverters._ - config.getStringList("application.cors.allowedMethods").asScala.toSet.flatMap(HttpMethods.getForKey) - } - - protected lazy val defaultCorsAllowedOrigin: Origin = { - Origin(allowedCorsDomainSuffixes.to[collection.immutable.Seq]) - } - - protected def corsAllowedOriginHeader(origin: Option[Origin]): HttpHeader = { - val allowedOrigin = - origin - .filter { requestOrigin => - allowedCorsDomainSuffixes.exists { allowedOriginSuffix => - requestOrigin.origins.exists(o => - o.scheme == allowedOriginSuffix.scheme && - o.host.host.address.endsWith(allowedOriginSuffix.host.host.address())) - } - } - .getOrElse(defaultCorsAllowedOrigin) - - `Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigin.origins: _*)) - } - - protected def respondWithAllCorsHeaders: Directive0 = { - respondWithCorsAllowedHeaders tflatMap { _ => - respondWithCorsAllowedMethodHeaders(defaultCorsAllowedMethods) tflatMap { _ => - optionalHeaderValueByType[Origin](()) flatMap { origin => - respondWithHeader(corsAllowedOriginHeader(origin)) - } - } - } - } - private def extractHeader(request: HttpRequest)(headerName: String): Option[String] = request.headers.find(_.name().toLowerCase === headerName).map(_.value()) - protected def defaultOptionsRoute: Route = options { - respondWithAllCorsHeaders { - complete("OK") - } - } - def appRoute: Route = { - val serviceTypes = modules.flatMap(_.routeTypes) - val swaggerService = new Swagger(baseUrl, Scheme.forValue(scheme) :: Nil, version, serviceTypes, config, log) + val serviceTypes = modules.flatMap(_.routeTypes) + val serviceClasses = serviceTypes.map { tpe => + scala.reflect.runtime.currentMirror.runtimeClass(tpe.typeSymbol.asClass) + }.toSet + val swaggerService = new Swagger(baseUrl, scheme :: Nil, version, serviceClasses, config, log) val swaggerRoute = new DriverRoute { def log: Logger = self.log def route: Route = handleExceptions(ExceptionHandler(exceptionHandler)) { @@ -139,7 +93,6 @@ class DriverApp( val combinedRoute = Route.seal( modules.map(_.route).foldLeft(basicRoutes.routeWithDefaults)(_ ~ _) ~ swaggerRoute.route ~ defaultOptionsRoute) - (extractHost & extractClientIP & trace(tracer) & handleRejections(authenticationRejectionHandler)) { case (origin, ip) => ctx => @@ -173,7 +126,11 @@ class DriverApp( r } - respondWithAllCorsHeaders(logResponses(combinedRoute))(contextWithTrackingId) + val innerRoute = CorsDirectives.cors( + config.getStringList("application.cors.allowedOrigins").asScala.toSet, + AllowedHeaders + )(logResponses(combinedRoute)) + innerRoute(contextWithTrackingId) } } |