diff options
Diffstat (limited to 'src/main/scala/xyz/driver/core/rest/DriverRoute.scala')
-rw-r--r-- | src/main/scala/xyz/driver/core/rest/DriverRoute.scala | 84 |
1 files changed, 71 insertions, 13 deletions
diff --git a/src/main/scala/xyz/driver/core/rest/DriverRoute.scala b/src/main/scala/xyz/driver/core/rest/DriverRoute.scala index 4c483c6..58a4143 100644 --- a/src/main/scala/xyz/driver/core/rest/DriverRoute.scala +++ b/src/main/scala/xyz/driver/core/rest/DriverRoute.scala @@ -6,7 +6,8 @@ import akka.http.scaladsl.model._ import akka.http.scaladsl.model.StatusCodes import akka.http.scaladsl.model.headers._ import akka.http.scaladsl.server.Directives._ -import akka.http.scaladsl.server.{Directive0, ExceptionHandler, RequestContext, Route} +import akka.http.scaladsl.server._ +import com.typesafe.config.Config import com.typesafe.scalalogging.Logger import org.slf4j.MDC import xyz.driver.core.rest @@ -16,28 +17,85 @@ import scala.compat.Platform.ConcurrentModificationException trait DriverRoute { def log: Logger + def config: Config def route: Route def routeWithDefaults: Route = { - (defaultResponseHeaders & handleExceptions(ExceptionHandler(exceptionHandler)))(route) + (defaultResponseHeaders & handleRejections(rejectionHandler) & handleExceptions(ExceptionHandler(exceptionHandler))) { + route ~ defaultOptionsRoute + } + } + + protected lazy val allowedCorsDomainSuffixes: Set[HttpOrigin] = { + import scala.collection.JavaConverters._ + config + .getConfigList("application.cors.allowedOrigins") + .asScala + .map { c => + HttpOrigin(c.getString("scheme"), Host(c.getString("hostSuffix"))) + }(scala.collection.breakOut) + } + + protected lazy val defaultCorsAllowedMethods: Set[HttpMethod] = { + import scala.collection.JavaConverters._ + config.getStringList("application.cors.allowedMethods").asScala.toSet.flatMap(HttpMethods.getForKey) + } + + protected lazy val defaultCorsAllowedOrigin: Origin = { + Origin(allowedCorsDomainSuffixes.to[collection.immutable.Seq]) + } + + protected def corsAllowedOriginHeader(origin: Option[Origin]): HttpHeader = { + val allowedOrigin = + origin + .filter { requestOrigin => + allowedCorsDomainSuffixes.exists { allowedOriginSuffix => + requestOrigin.origins.exists(o => + o.scheme == allowedOriginSuffix.scheme && + o.host.host.address.endsWith(allowedOriginSuffix.host.host.address())) + } + } + .getOrElse(defaultCorsAllowedOrigin) + + `Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigin.origins: _*)) + } + + protected def respondWithAllCorsHeaders: Directive0 = { + respondWithCorsAllowedHeaders tflatMap { _ => + respondWithCorsAllowedMethodHeaders(defaultCorsAllowedMethods) tflatMap { _ => + optionalHeaderValueByType[Origin](()) flatMap { origin => + respondWithHeader(corsAllowedOriginHeader(origin)) + } + } + } + } + + protected def defaultOptionsRoute: Route = options { + respondWithAllCorsHeaders { + complete("OK") + } } protected def defaultResponseHeaders: Directive0 = { - (extractRequest & optionalHeaderValueByType[Origin](())) tflatMap { - case (request, originHeader) => - val tracingHeader = RawHeader(ContextHeaders.TrackingIdHeader, rest.extractTrackingId(request)) - val responseHeaders = List[HttpHeader]( - tracingHeader, - allowOrigin(originHeader), - `Access-Control-Allow-Headers`(AllowedHeaders: _*), - `Access-Control-Expose-Headers`(AllowedHeaders: _*) - ) - - respondWithHeaders(responseHeaders) + extractRequest flatMap { request => + val tracingHeader = RawHeader(ContextHeaders.TrackingIdHeader, rest.extractTrackingId(request)) + respondWithHeader(tracingHeader) & respondWithAllCorsHeaders } } + protected def rejectionHandler: RejectionHandler = + RejectionHandler + .newBuilder() + .handle { + case rejection => + respondWithAllCorsHeaders { + RejectionHandler.default(collection.immutable.Seq(rejection)).get + } + } + .result() + .seal + /** * Override me for custom exception handling * |