diff options
Diffstat (limited to 'src/main/scala/xyz/driver')
-rw-r--r-- | src/main/scala/xyz/driver/core/app/DriverApp.scala | 33 | ||||
-rw-r--r-- | src/main/scala/xyz/driver/core/app/module.scala | 5 | ||||
-rw-r--r-- | src/main/scala/xyz/driver/core/rest/DriverRoute.scala | 57 | ||||
-rw-r--r-- | src/main/scala/xyz/driver/core/rest/package.scala | 25 |
4 files changed, 75 insertions, 45 deletions
diff --git a/src/main/scala/xyz/driver/core/app/DriverApp.scala b/src/main/scala/xyz/driver/core/app/DriverApp.scala index d95e254..a593893 100644 --- a/src/main/scala/xyz/driver/core/app/DriverApp.scala +++ b/src/main/scala/xyz/driver/core/app/DriverApp.scala @@ -2,7 +2,6 @@ package xyz.driver.core.app import akka.actor.ActorSystem import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport -import akka.http.scaladsl.model.StatusCodes._ import akka.http.scaladsl.model._ import akka.http.scaladsl.model.headers._ import akka.http.scaladsl.server.Directives._ @@ -42,7 +41,6 @@ class DriverApp( port: Int = 8080, tracer: Tracer = NoTracer)(implicit actorSystem: ActorSystem, executionContext: ExecutionContext) { self => - import DriverApp._ implicit private lazy val materializer: ActorMaterializer = ActorMaterializer()(actorSystem) private lazy val http: HttpExt = Http()(actorSystem) @@ -73,8 +71,9 @@ class DriverApp( val swaggerRoute = swaggerService.routes ~ swaggerService.swaggerUI val versionRt = versionRoute(version, gitHash, time.currentTime()) val basicRoutes = new DriverRoute { - override def log: Logger = self.log - override def route: Route = versionRt ~ healthRoute ~ swaggerRoute + override def log: Logger = self.log + override def config: Config = xyz.driver.core.config.loadDefaultConfig + override def route: Route = versionRt ~ healthRoute ~ swaggerRoute } val combinedRoute = modules.map(_.route).foldLeft(basicRoutes.routeWithDefaults)(_ ~ _) @@ -221,29 +220,3 @@ class DriverApp( }) } } - -object DriverApp { - implicit def rejectionHandler: RejectionHandler = - RejectionHandler - .newBuilder() - .handleAll[MethodRejection] { rejections => - val methods = rejections map (_.supported) - lazy val names = methods map (_.name) mkString ", " - - options { - respondWithCorsHeaders { - respondWithCorsAllowedMethodHeaders(methods) { - complete(s"Supported methods: $names.") - } - } - } ~ - complete(MethodNotAllowed -> s"HTTP method not allowed, supported methods: $names!") - } - .handleAll[Rejection] { rejections => - respondWithCorsHeaders { - reject(rejections: _*) - } - } - .result() - .seal -} diff --git a/src/main/scala/xyz/driver/core/app/module.scala b/src/main/scala/xyz/driver/core/app/module.scala index 7be38eb..0a255fb 100644 --- a/src/main/scala/xyz/driver/core/app/module.scala +++ b/src/main/scala/xyz/driver/core/app/module.scala @@ -30,8 +30,9 @@ class EmptyModule extends Module { class SimpleModule(override val name: String, theRoute: Route, routeType: Type) extends Module { private val driverRoute: DriverRoute = new DriverRoute { - override def route: Route = theRoute - override val log: Logger = xyz.driver.core.logging.NoLogger + override def route: Route = theRoute + override val config: Config = xyz.driver.core.config.loadDefaultConfig + override val log: Logger = xyz.driver.core.logging.NoLogger } override def route: Route = driverRoute.routeWithDefaults diff --git a/src/main/scala/xyz/driver/core/rest/DriverRoute.scala b/src/main/scala/xyz/driver/core/rest/DriverRoute.scala index 5f961b6..5647818 100644 --- a/src/main/scala/xyz/driver/core/rest/DriverRoute.scala +++ b/src/main/scala/xyz/driver/core/rest/DriverRoute.scala @@ -7,6 +7,7 @@ import akka.http.scaladsl.model.StatusCodes import akka.http.scaladsl.model.headers._ import akka.http.scaladsl.server.Directives._ import akka.http.scaladsl.server.{Directive0, ExceptionHandler, RequestContext, Route} +import com.typesafe.config.Config import com.typesafe.scalalogging.Logger import org.slf4j.MDC import xyz.driver.core.rest @@ -16,17 +17,69 @@ import scala.compat.Platform.ConcurrentModificationException trait DriverRoute { def log: Logger + def config: Config def route: Route def routeWithDefaults: Route = { - (defaultResponseHeaders & handleExceptions(ExceptionHandler(exceptionHandler)))(route) + (defaultResponseHeaders & handleExceptions(ExceptionHandler(exceptionHandler))) { + route ~ defaultOptionsRoute + } + } + + protected lazy val allowedCorsDomainSuffixes: Set[HttpOrigin] = { + import scala.collection.JavaConverters._ + config + .getConfigList("application.cors.allowedOrigins") + .asScala + .map { c => + HttpOrigin(c.getString("scheme"), Host(c.getString("hostSuffix"))) + }(scala.collection.breakOut) + } + + protected lazy val defaultCorsAllowedMethods: Set[HttpMethod] = { + import scala.collection.JavaConverters._ + config.getStringList("application.cors.allowedMethods").asScala.toSet.flatMap(HttpMethods.getForKey) + } + + protected lazy val defaultCorsAllowedOrigin: Origin = + Origin(allowedCorsDomainSuffixes.to[collection.immutable.Seq]) + + protected def corsAllowedOriginHeader(origin: Option[Origin]): HttpHeader = { + val allowedOrigin = + origin + .filter { requestOrigin => + allowedCorsDomainSuffixes.exists { allowedOriginSuffix => + requestOrigin.origins.exists(o => + o.scheme == allowedOriginSuffix.scheme && + o.host.host.address.endsWith(allowedOriginSuffix.host.host.address())) + } + } + .getOrElse(defaultCorsAllowedOrigin) + + `Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigin.origins: _*)) + } + + protected def respondWithAllCorsHeaders: Directive0 = { + respondWithCorsAllowedHeaders tflatMap { _ => + respondWithCorsAllowedMethodHeaders(defaultCorsAllowedMethods) tflatMap { _ => + optionalHeaderValueByType[Origin](()) flatMap { origin => + respondWithHeader(corsAllowedOriginHeader(origin)) + } + } + } + } + + protected def defaultOptionsRoute: Route = options { + respondWithAllCorsHeaders { + complete("OK") + } } protected def defaultResponseHeaders: Directive0 = { extractRequest flatMap { request => val tracingHeader = RawHeader(ContextHeaders.TrackingIdHeader, rest.extractTrackingId(request)) - respondWithHeader(tracingHeader) & respondWithCorsHeaders + respondWithHeader(tracingHeader) & respondWithAllCorsHeaders } } diff --git a/src/main/scala/xyz/driver/core/rest/package.scala b/src/main/scala/xyz/driver/core/rest/package.scala index 88f78d9..5fd9417 100644 --- a/src/main/scala/xyz/driver/core/rest/package.scala +++ b/src/main/scala/xyz/driver/core/rest/package.scala @@ -110,22 +110,25 @@ object `package` { } } - def respondWithCorsHeaders: Directive0 = { - optionalHeaderValueByType[Origin](()) flatMap { originHeader => - respondWithHeaders( - List[HttpHeader]( - allowOrigin(originHeader), - `Access-Control-Allow-Headers`(AllowedHeaders: _*), - `Access-Control-Expose-Headers`(AllowedHeaders: _*) - )) + def respondWithCorsAllowedHeaders: Directive0 = { + respondWithHeaders( + List[HttpHeader]( + `Access-Control-Allow-Headers`(AllowedHeaders: _*), + `Access-Control-Expose-Headers`(AllowedHeaders: _*) + )) + } + + def respondWithCorsAllowedOriginHeaders(origin: Origin): Directive0 = { + respondWithHeader { + `Access-Control-Allow-Origin`(HttpOriginRange(origin.origins: _*)) } } - def respondWithCorsAllowedMethodHeaders(methods: scala.collection.immutable.Seq[HttpMethod]): Directive0 = { + def respondWithCorsAllowedMethodHeaders(methods: Set[HttpMethod]): Directive0 = { respondWithHeaders( List[HttpHeader]( - Allow(methods), - `Access-Control-Allow-Methods`(methods) + Allow(methods.to[collection.immutable.Seq]), + `Access-Control-Allow-Methods`(methods.to[collection.immutable.Seq]) )) } |