aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorzachdriver <zach@driver.xyz>2018-02-22 11:31:51 -0800
committerGitHub <noreply@github.com>2018-02-22 11:31:51 -0800
commitd48406367be713fe8fd24f8dea15545817906cd4 (patch)
tree0ae98459f1c3d28953922a0ab1170fa83fc5f85b
parentd534ce2309052329506b31b17ae137950757890f (diff)
parent2059634d2fa2c28ddf2b992bc36ab3d96f3c2512 (diff)
downloaddriver-core-d48406367be713fe8fd24f8dea15545817906cd4.tar.gz
driver-core-d48406367be713fe8fd24f8dea15545817906cd4.tar.bz2
driver-core-d48406367be713fe8fd24f8dea15545817906cd4.zip
Merge pull request #104 from drivergroup/zsmith/reject-corsv1.8.2
Respond with correct cors headers for all rejections
-rw-r--r--src/main/resources/reference.conf20
-rw-r--r--src/main/scala/xyz/driver/core/app/DriverApp.scala33
-rw-r--r--src/main/scala/xyz/driver/core/app/module.scala5
-rw-r--r--src/main/scala/xyz/driver/core/rest/DriverRoute.scala84
-rw-r--r--src/main/scala/xyz/driver/core/rest/package.scala24
-rw-r--r--src/test/scala/xyz/driver/core/rest/DriverAppTest.scala58
-rw-r--r--src/test/scala/xyz/driver/core/rest/DriverRouteTest.scala61
7 files changed, 177 insertions, 108 deletions
diff --git a/src/main/resources/reference.conf b/src/main/resources/reference.conf
index 16dcfda..74ad206 100644
--- a/src/main/resources/reference.conf
+++ b/src/main/resources/reference.conf
@@ -2,7 +2,7 @@
# Default settings for driver core. Any settings defined by users of #
# this library will take precedence. See the documentation of the #
# Typesafe Config Library (https://github.com/lightbend/config) for #
-# more information. #
+# more information. #
######################################################################
# This scope is for general settings related to the execution of a
@@ -10,6 +10,24 @@
application {
baseUrl: "localhost:8080"
environment: "local_testing"
+
+ cors {
+ allowedMethods: ["GET", "PUT", "POST", "PATCH", "DELETE", "OPTIONS"]
+ allowedOrigins: [
+ {
+ scheme: http
+ hostSuffix: localhost
+ },
+ {
+ scheme: https
+ hostSuffix: driver.xyz
+ },
+ {
+ scheme: https
+ hostSuffix: driver.network
+ }
+ ]
+ }
}
# Settings about the auto-generated REST API documentation.
diff --git a/src/main/scala/xyz/driver/core/app/DriverApp.scala b/src/main/scala/xyz/driver/core/app/DriverApp.scala
index df80c3d..1ded4dd 100644
--- a/src/main/scala/xyz/driver/core/app/DriverApp.scala
+++ b/src/main/scala/xyz/driver/core/app/DriverApp.scala
@@ -2,7 +2,6 @@ package xyz.driver.core.app
import akka.actor.ActorSystem
import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport
-import akka.http.scaladsl.model.StatusCodes._
import akka.http.scaladsl.model._
import akka.http.scaladsl.model.headers._
import akka.http.scaladsl.server.Directives._
@@ -42,7 +41,6 @@ class DriverApp(
port: Int = 8080,
tracer: Tracer = NoTracer)(implicit actorSystem: ActorSystem, executionContext: ExecutionContext) {
self =>
- import DriverApp._
implicit private lazy val materializer: ActorMaterializer = ActorMaterializer()(actorSystem)
private lazy val http: HttpExt = Http()(actorSystem)
@@ -73,8 +71,9 @@ class DriverApp(
val swaggerRoute = swaggerService.routes ~ swaggerService.swaggerUI
val versionRt = versionRoute(version, gitHash, time.currentTime())
val basicRoutes = new DriverRoute {
- override def log: Logger = self.log
- override def route: Route = versionRt ~ healthRoute ~ swaggerRoute
+ override def log: Logger = self.log
+ override def config: Config = self.config
+ override def route: Route = versionRt ~ healthRoute ~ swaggerRoute
}
val combinedRoute = modules.map(_.route).foldLeft(basicRoutes.routeWithDefaults)(_ ~ _)
@@ -221,29 +220,3 @@ class DriverApp(
})
}
}
-
-object DriverApp {
- implicit def rejectionHandler: RejectionHandler =
- RejectionHandler
- .newBuilder()
- .handleAll[MethodRejection] { rejections =>
- val methods = rejections map (_.supported)
- lazy val names = methods map (_.name) mkString ", "
-
- options { ctx =>
- optionalHeaderValueByType[Origin](()) { originHeader =>
- respondWithHeaders(List[HttpHeader](
- Allow(methods),
- `Access-Control-Allow-Methods`(methods),
- allowOrigin(originHeader),
- `Access-Control-Allow-Headers`(AllowedHeaders: _*),
- `Access-Control-Expose-Headers`(AllowedHeaders: _*)
- )) {
- complete(s"Supported methods: $names.")
- }
- }(ctx)
- } ~
- complete(MethodNotAllowed -> s"HTTP method not allowed, supported methods: $names!")
- }
- .result()
-}
diff --git a/src/main/scala/xyz/driver/core/app/module.scala b/src/main/scala/xyz/driver/core/app/module.scala
index 7be38eb..0a255fb 100644
--- a/src/main/scala/xyz/driver/core/app/module.scala
+++ b/src/main/scala/xyz/driver/core/app/module.scala
@@ -30,8 +30,9 @@ class EmptyModule extends Module {
class SimpleModule(override val name: String, theRoute: Route, routeType: Type) extends Module {
private val driverRoute: DriverRoute = new DriverRoute {
- override def route: Route = theRoute
- override val log: Logger = xyz.driver.core.logging.NoLogger
+ override def route: Route = theRoute
+ override val config: Config = xyz.driver.core.config.loadDefaultConfig
+ override val log: Logger = xyz.driver.core.logging.NoLogger
}
override def route: Route = driverRoute.routeWithDefaults
diff --git a/src/main/scala/xyz/driver/core/rest/DriverRoute.scala b/src/main/scala/xyz/driver/core/rest/DriverRoute.scala
index 4c483c6..58a4143 100644
--- a/src/main/scala/xyz/driver/core/rest/DriverRoute.scala
+++ b/src/main/scala/xyz/driver/core/rest/DriverRoute.scala
@@ -6,7 +6,8 @@ import akka.http.scaladsl.model._
import akka.http.scaladsl.model.StatusCodes
import akka.http.scaladsl.model.headers._
import akka.http.scaladsl.server.Directives._
-import akka.http.scaladsl.server.{Directive0, ExceptionHandler, RequestContext, Route}
+import akka.http.scaladsl.server._
+import com.typesafe.config.Config
import com.typesafe.scalalogging.Logger
import org.slf4j.MDC
import xyz.driver.core.rest
@@ -16,28 +17,85 @@ import scala.compat.Platform.ConcurrentModificationException
trait DriverRoute {
def log: Logger
+ def config: Config
def route: Route
def routeWithDefaults: Route = {
- (defaultResponseHeaders & handleExceptions(ExceptionHandler(exceptionHandler)))(route)
+ (defaultResponseHeaders & handleRejections(rejectionHandler) & handleExceptions(ExceptionHandler(exceptionHandler))) {
+ route ~ defaultOptionsRoute
+ }
+ }
+
+ protected lazy val allowedCorsDomainSuffixes: Set[HttpOrigin] = {
+ import scala.collection.JavaConverters._
+ config
+ .getConfigList("application.cors.allowedOrigins")
+ .asScala
+ .map { c =>
+ HttpOrigin(c.getString("scheme"), Host(c.getString("hostSuffix")))
+ }(scala.collection.breakOut)
+ }
+
+ protected lazy val defaultCorsAllowedMethods: Set[HttpMethod] = {
+ import scala.collection.JavaConverters._
+ config.getStringList("application.cors.allowedMethods").asScala.toSet.flatMap(HttpMethods.getForKey)
+ }
+
+ protected lazy val defaultCorsAllowedOrigin: Origin = {
+ Origin(allowedCorsDomainSuffixes.to[collection.immutable.Seq])
+ }
+
+ protected def corsAllowedOriginHeader(origin: Option[Origin]): HttpHeader = {
+ val allowedOrigin =
+ origin
+ .filter { requestOrigin =>
+ allowedCorsDomainSuffixes.exists { allowedOriginSuffix =>
+ requestOrigin.origins.exists(o =>
+ o.scheme == allowedOriginSuffix.scheme &&
+ o.host.host.address.endsWith(allowedOriginSuffix.host.host.address()))
+ }
+ }
+ .getOrElse(defaultCorsAllowedOrigin)
+
+ `Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigin.origins: _*))
+ }
+
+ protected def respondWithAllCorsHeaders: Directive0 = {
+ respondWithCorsAllowedHeaders tflatMap { _ =>
+ respondWithCorsAllowedMethodHeaders(defaultCorsAllowedMethods) tflatMap { _ =>
+ optionalHeaderValueByType[Origin](()) flatMap { origin =>
+ respondWithHeader(corsAllowedOriginHeader(origin))
+ }
+ }
+ }
+ }
+
+ protected def defaultOptionsRoute: Route = options {
+ respondWithAllCorsHeaders {
+ complete("OK")
+ }
}
protected def defaultResponseHeaders: Directive0 = {
- (extractRequest & optionalHeaderValueByType[Origin](())) tflatMap {
- case (request, originHeader) =>
- val tracingHeader = RawHeader(ContextHeaders.TrackingIdHeader, rest.extractTrackingId(request))
- val responseHeaders = List[HttpHeader](
- tracingHeader,
- allowOrigin(originHeader),
- `Access-Control-Allow-Headers`(AllowedHeaders: _*),
- `Access-Control-Expose-Headers`(AllowedHeaders: _*)
- )
-
- respondWithHeaders(responseHeaders)
+ extractRequest flatMap { request =>
+ val tracingHeader = RawHeader(ContextHeaders.TrackingIdHeader, rest.extractTrackingId(request))
+ respondWithHeader(tracingHeader) & respondWithAllCorsHeaders
}
}
+ protected def rejectionHandler: RejectionHandler =
+ RejectionHandler
+ .newBuilder()
+ .handle {
+ case rejection =>
+ respondWithAllCorsHeaders {
+ RejectionHandler.default(collection.immutable.Seq(rejection)).get
+ }
+ }
+ .result()
+ .seal
+
/**
* Override me for custom exception handling
*
diff --git a/src/main/scala/xyz/driver/core/rest/package.scala b/src/main/scala/xyz/driver/core/rest/package.scala
index 77b916a..5fd9417 100644
--- a/src/main/scala/xyz/driver/core/rest/package.scala
+++ b/src/main/scala/xyz/driver/core/rest/package.scala
@@ -3,7 +3,7 @@ package xyz.driver.core.rest
import java.net.InetAddress
import akka.http.scaladsl.marshalling.{ToEntityMarshaller, ToResponseMarshallable}
-import akka.http.scaladsl.model.headers.{HttpOriginRange, Origin, `Access-Control-Allow-Origin`}
+import akka.http.scaladsl.model.headers._
import akka.http.scaladsl.model._
import akka.http.scaladsl.server.Directives._
import akka.http.scaladsl.server._
@@ -110,6 +110,28 @@ object `package` {
}
}
+ def respondWithCorsAllowedHeaders: Directive0 = {
+ respondWithHeaders(
+ List[HttpHeader](
+ `Access-Control-Allow-Headers`(AllowedHeaders: _*),
+ `Access-Control-Expose-Headers`(AllowedHeaders: _*)
+ ))
+ }
+
+ def respondWithCorsAllowedOriginHeaders(origin: Origin): Directive0 = {
+ respondWithHeader {
+ `Access-Control-Allow-Origin`(HttpOriginRange(origin.origins: _*))
+ }
+ }
+
+ def respondWithCorsAllowedMethodHeaders(methods: Set[HttpMethod]): Directive0 = {
+ respondWithHeaders(
+ List[HttpHeader](
+ Allow(methods.to[collection.immutable.Seq]),
+ `Access-Control-Allow-Methods`(methods.to[collection.immutable.Seq])
+ ))
+ }
+
def extractServiceContext(request: HttpRequest, remoteAddress: RemoteAddress): ServiceRequestContext =
new ServiceRequestContext(
extractTrackingId(request),
diff --git a/src/test/scala/xyz/driver/core/rest/DriverAppTest.scala b/src/test/scala/xyz/driver/core/rest/DriverAppTest.scala
deleted file mode 100644
index f5602be..0000000
--- a/src/test/scala/xyz/driver/core/rest/DriverAppTest.scala
+++ /dev/null
@@ -1,58 +0,0 @@
-package xyz.driver.core.rest
-
-import akka.http.scaladsl.model.headers._
-import akka.http.scaladsl.model.{HttpMethods, StatusCodes}
-import akka.http.scaladsl.server.Directives._
-import akka.http.scaladsl.server.Route
-import akka.http.scaladsl.settings.RoutingSettings
-import akka.http.scaladsl.testkit.ScalatestRouteTest
-import com.typesafe.config.Config
-import com.typesafe.scalalogging.Logger
-import org.scalatest.{FlatSpec, Matchers}
-import xyz.driver.core.app.{DriverApp, Module}
-
-import scala.reflect.runtime.universe._
-
-class DriverAppTest extends FlatSpec with ScalatestRouteTest with Matchers {
- class TestRoute extends DriverRoute {
- override def log: Logger = xyz.driver.core.logging.NoLogger
- override def route: Route = path("api" / "v1" / "test")(post(complete("OK")))
- }
-
- val module: Module = new Module {
- val testRoute = new TestRoute
- override def route: Route = testRoute.routeWithDefaults
- override def routeTypes: Seq[Type] = Seq(typeOf[TestRoute])
- override val name: String = "test-module"
- }
-
- val app: DriverApp = new DriverApp(
- appName = "test-app",
- version = "0.1",
- gitHash = "deadb33f",
- modules = Seq(module)
- )
-
- val config: Config = xyz.driver.core.config.loadDefaultConfig
- val routingSettings: RoutingSettings = RoutingSettings(config)
- val appRoute: Route =
- Route.seal(app.appRoute)(routingSettings = routingSettings, rejectionHandler = DriverApp.rejectionHandler)
-
- "DriverApp" should "respond with the correct CORS headers for the swagger OPTIONS route" in {
- Options(s"/api-docs/swagger.json") ~> appRoute ~> check {
- status shouldBe StatusCodes.OK
- info(response.toString())
- headers should contain(`Access-Control-Allow-Origin`(HttpOriginRange.*))
- headers should contain(`Access-Control-Allow-Methods`(HttpMethods.GET))
- }
- }
-
- it should "respond with the correct CORS headers for the test route" in {
- Options(s"/api/v1/test") ~> appRoute ~> check {
- status shouldBe StatusCodes.OK
- info(response.toString())
- headers should contain(`Access-Control-Allow-Origin`(HttpOriginRange.*))
- headers should contain(`Access-Control-Allow-Methods`(HttpMethods.GET, HttpMethods.POST))
- }
- }
-}
diff --git a/src/test/scala/xyz/driver/core/rest/DriverRouteTest.scala b/src/test/scala/xyz/driver/core/rest/DriverRouteTest.scala
index f402261..60056b7 100644
--- a/src/test/scala/xyz/driver/core/rest/DriverRouteTest.scala
+++ b/src/test/scala/xyz/driver/core/rest/DriverRouteTest.scala
@@ -1,9 +1,11 @@
package xyz.driver.core.rest
-import akka.http.scaladsl.model.StatusCodes
+import akka.http.scaladsl.model.{HttpMethod, StatusCodes}
+import akka.http.scaladsl.model.headers._
import akka.http.scaladsl.server.Directives.{complete => akkaComplete}
-import akka.http.scaladsl.server.Route
+import akka.http.scaladsl.server.{Directives, Route}
import akka.http.scaladsl.testkit.ScalatestRouteTest
+import com.typesafe.config.{Config, ConfigFactory}
import com.typesafe.scalalogging.Logger
import org.scalatest.{AsyncFlatSpec, Matchers}
import xyz.driver.core.logging.NoLogger
@@ -11,9 +13,24 @@ import xyz.driver.core.rest.errors._
import scala.concurrent.Future
-class DriverRouteTest extends AsyncFlatSpec with ScalatestRouteTest with Matchers {
+class DriverRouteTest extends AsyncFlatSpec with ScalatestRouteTest with Matchers with Directives {
class TestRoute(override val route: Route) extends DriverRoute {
override def log: Logger = NoLogger
+ override def config: Config =
+ ConfigFactory.parseString("""
+ |application {
+ | cors {
+ | allowedMethods: ["GET", "PUT", "POST", "PATCH", "DELETE", "OPTIONS"]
+ | allowedOrigins: [{scheme: https, hostSuffix: example.com}]
+ | }
+ |}
+ """.stripMargin)
+ }
+
+ val allowedOrigins = Set(HttpOrigin("https", Host("example.com")))
+ val allowedMethods: collection.immutable.Seq[HttpMethod] = {
+ import akka.http.scaladsl.model.HttpMethods._
+ collection.immutable.Seq(GET, PUT, POST, PATCH, DELETE, OPTIONS)
}
"DriverRoute" should "respond with 200 OK for a basic route" in {
@@ -86,4 +103,42 @@ class DriverRouteTest extends AsyncFlatSpec with ScalatestRouteTest with Matcher
responseAs[String] shouldBe "Database access error"
}
}
+
+ it should "respond with the correct CORS headers for the swagger OPTIONS route" in {
+ val route = new TestRoute(get(akkaComplete(StatusCodes.OK)))
+ Options(s"/api-docs/swagger.json") ~> route.routeWithDefaults ~> check {
+ status shouldBe StatusCodes.OK
+ headers should contain(`Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigins.toSeq: _*)))
+ header[`Access-Control-Allow-Methods`].get.methods should contain theSameElementsAs allowedMethods
+ }
+ }
+
+ it should "respond with the correct CORS headers for the test route" in {
+ val route = new TestRoute(get(akkaComplete(StatusCodes.OK)))
+ Options(s"/api/v1/test") ~> route.routeWithDefaults ~> check {
+ status shouldBe StatusCodes.OK
+ headers should contain(`Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigins.toSeq: _*)))
+ header[`Access-Control-Allow-Methods`].get.methods should contain theSameElementsAs allowedMethods
+ }
+ }
+
+ it should "allow subdomains of allowed origin suffixes" in {
+ val route = new TestRoute(get(akkaComplete(StatusCodes.OK)))
+ Options(s"/api/v1/test")
+ .withHeaders(Origin(HttpOrigin("https", Host("foo.example.com")))) ~> route.routeWithDefaults ~> check {
+ status shouldBe StatusCodes.OK
+ headers should contain(`Access-Control-Allow-Origin`(HttpOrigin("https", Host("foo.example.com"))))
+ header[`Access-Control-Allow-Methods`].get.methods should contain theSameElementsAs allowedMethods
+ }
+ }
+
+ it should "respond with default domains for invalid origins" in {
+ val route = new TestRoute(get(akkaComplete(StatusCodes.OK)))
+ Options(s"/api/v1/test")
+ .withHeaders(Origin(HttpOrigin("https", Host("invalid.foo.bar.com")))) ~> route.routeWithDefaults ~> check {
+ status shouldBe StatusCodes.OK
+ headers should contain(`Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigins.toSeq: _*)))
+ header[`Access-Control-Allow-Methods`].get.methods should contain theSameElementsAs allowedMethods
+ }
+ }
}