aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorzachdriver <zach@driver.xyz>2018-02-22 17:47:18 -0800
committerGitHub <noreply@github.com>2018-02-22 17:47:18 -0800
commit4de3c543151fb230903d8302ced30c916683a3af (patch)
tree6b0254ea1d5430415412a3be92696f21bb4a4670
parentd48406367be713fe8fd24f8dea15545817906cd4 (diff)
parent5d6ef236eb3b5a4fad6383b87e1572685ddf2bc5 (diff)
downloaddriver-core-4de3c543151fb230903d8302ced30c916683a3af.tar.gz
driver-core-4de3c543151fb230903d8302ced30c916683a3af.tar.bz2
driver-core-4de3c543151fb230903d8302ced30c916683a3af.zip
Merge pull request #124 from drivergroup/zsmith/remove-sealv1.8.3
Fix CORS by moving directives to DriverApp and removing custom RejectionHandler
-rw-r--r--src/main/scala/xyz/driver/core/app/DriverApp.scala57
-rw-r--r--src/main/scala/xyz/driver/core/rest/DriverRoute.scala68
-rw-r--r--src/test/scala/xyz/driver/core/rest/DriverAppTest.scala84
-rw-r--r--src/test/scala/xyz/driver/core/rest/DriverRouteTest.scala47
4 files changed, 143 insertions, 113 deletions
diff --git a/src/main/scala/xyz/driver/core/app/DriverApp.scala b/src/main/scala/xyz/driver/core/app/DriverApp.scala
index 1ded4dd..ca3dd54 100644
--- a/src/main/scala/xyz/driver/core/app/DriverApp.scala
+++ b/src/main/scala/xyz/driver/core/app/DriverApp.scala
@@ -62,9 +62,59 @@ class DriverApp(
}
}
+ protected lazy val allowedCorsDomainSuffixes: Set[HttpOrigin] = {
+ import scala.collection.JavaConverters._
+ config
+ .getConfigList("application.cors.allowedOrigins")
+ .asScala
+ .map { c =>
+ HttpOrigin(c.getString("scheme"), Host(c.getString("hostSuffix")))
+ }(scala.collection.breakOut)
+ }
+
+ protected lazy val defaultCorsAllowedMethods: Set[HttpMethod] = {
+ import scala.collection.JavaConverters._
+ config.getStringList("application.cors.allowedMethods").asScala.toSet.flatMap(HttpMethods.getForKey)
+ }
+
+ protected lazy val defaultCorsAllowedOrigin: Origin = {
+ Origin(allowedCorsDomainSuffixes.to[collection.immutable.Seq])
+ }
+
+ protected def corsAllowedOriginHeader(origin: Option[Origin]): HttpHeader = {
+ val allowedOrigin =
+ origin
+ .filter { requestOrigin =>
+ allowedCorsDomainSuffixes.exists { allowedOriginSuffix =>
+ requestOrigin.origins.exists(o =>
+ o.scheme == allowedOriginSuffix.scheme &&
+ o.host.host.address.endsWith(allowedOriginSuffix.host.host.address()))
+ }
+ }
+ .getOrElse(defaultCorsAllowedOrigin)
+
+ `Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigin.origins: _*))
+ }
+
+ protected def respondWithAllCorsHeaders: Directive0 = {
+ respondWithCorsAllowedHeaders tflatMap { _ =>
+ respondWithCorsAllowedMethodHeaders(defaultCorsAllowedMethods) tflatMap { _ =>
+ optionalHeaderValueByType[Origin](()) flatMap { origin =>
+ respondWithHeader(corsAllowedOriginHeader(origin))
+ }
+ }
+ }
+ }
+
private def extractHeader(request: HttpRequest)(headerName: String): Option[String] =
request.headers.find(_.name().toLowerCase === headerName).map(_.value())
+ protected def defaultOptionsRoute: Route = options {
+ respondWithAllCorsHeaders {
+ complete("OK")
+ }
+ }
+
def appRoute: Route = {
val serviceTypes = modules.flatMap(_.routeTypes)
val swaggerService = new Swagger(baseUrl, Scheme.forValue(scheme) :: Nil, version, serviceTypes, config, log)
@@ -75,7 +125,8 @@ class DriverApp(
override def config: Config = self.config
override def route: Route = versionRt ~ healthRoute ~ swaggerRoute
}
- val combinedRoute = modules.map(_.route).foldLeft(basicRoutes.routeWithDefaults)(_ ~ _)
+ val combinedRoute =
+ Route.seal(modules.map(_.route).foldLeft(basicRoutes.routeWithDefaults)(_ ~ _) ~ defaultOptionsRoute)
(extractHost & extractClientIP & trace(tracer)) {
case (origin, ip) =>
@@ -97,7 +148,9 @@ class DriverApp(
.addHeader(RawHeader(ContextHeaders.TrackingIdHeader, trackingId))
.addHeader(RawHeader(ContextHeaders.StacktraceHeader, updatedStacktrace)))
- combinedRoute(contextWithTrackingId)
+ respondWithAllCorsHeaders {
+ combinedRoute
+ }(contextWithTrackingId)
}
}
diff --git a/src/main/scala/xyz/driver/core/rest/DriverRoute.scala b/src/main/scala/xyz/driver/core/rest/DriverRoute.scala
index 58a4143..15da808 100644
--- a/src/main/scala/xyz/driver/core/rest/DriverRoute.scala
+++ b/src/main/scala/xyz/driver/core/rest/DriverRoute.scala
@@ -22,80 +22,18 @@ trait DriverRoute {
def route: Route
def routeWithDefaults: Route = {
- (defaultResponseHeaders & handleRejections(rejectionHandler) & handleExceptions(ExceptionHandler(exceptionHandler))) {
- route ~ defaultOptionsRoute
- }
- }
-
- protected lazy val allowedCorsDomainSuffixes: Set[HttpOrigin] = {
- import scala.collection.JavaConverters._
- config
- .getConfigList("application.cors.allowedOrigins")
- .asScala
- .map { c =>
- HttpOrigin(c.getString("scheme"), Host(c.getString("hostSuffix")))
- }(scala.collection.breakOut)
- }
-
- protected lazy val defaultCorsAllowedMethods: Set[HttpMethod] = {
- import scala.collection.JavaConverters._
- config.getStringList("application.cors.allowedMethods").asScala.toSet.flatMap(HttpMethods.getForKey)
- }
-
- protected lazy val defaultCorsAllowedOrigin: Origin = {
- Origin(allowedCorsDomainSuffixes.to[collection.immutable.Seq])
- }
-
- protected def corsAllowedOriginHeader(origin: Option[Origin]): HttpHeader = {
- val allowedOrigin =
- origin
- .filter { requestOrigin =>
- allowedCorsDomainSuffixes.exists { allowedOriginSuffix =>
- requestOrigin.origins.exists(o =>
- o.scheme == allowedOriginSuffix.scheme &&
- o.host.host.address.endsWith(allowedOriginSuffix.host.host.address()))
- }
- }
- .getOrElse(defaultCorsAllowedOrigin)
-
- `Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigin.origins: _*))
- }
-
- protected def respondWithAllCorsHeaders: Directive0 = {
- respondWithCorsAllowedHeaders tflatMap { _ =>
- respondWithCorsAllowedMethodHeaders(defaultCorsAllowedMethods) tflatMap { _ =>
- optionalHeaderValueByType[Origin](()) flatMap { origin =>
- respondWithHeader(corsAllowedOriginHeader(origin))
- }
- }
- }
- }
-
- protected def defaultOptionsRoute: Route = options {
- respondWithAllCorsHeaders {
- complete("OK")
+ (defaultResponseHeaders & handleExceptions(ExceptionHandler(exceptionHandler))) {
+ route
}
}
protected def defaultResponseHeaders: Directive0 = {
extractRequest flatMap { request =>
val tracingHeader = RawHeader(ContextHeaders.TrackingIdHeader, rest.extractTrackingId(request))
- respondWithHeader(tracingHeader) & respondWithAllCorsHeaders
+ respondWithHeader(tracingHeader)
}
}
- protected def rejectionHandler: RejectionHandler =
- RejectionHandler
- .newBuilder()
- .handle {
- case rejection =>
- respondWithAllCorsHeaders {
- RejectionHandler.default(collection.immutable.Seq(rejection)).get
- }
- }
- .result()
- .seal
-
/**
* Override me for custom exception handling
*
diff --git a/src/test/scala/xyz/driver/core/rest/DriverAppTest.scala b/src/test/scala/xyz/driver/core/rest/DriverAppTest.scala
new file mode 100644
index 0000000..eda6a8c
--- /dev/null
+++ b/src/test/scala/xyz/driver/core/rest/DriverAppTest.scala
@@ -0,0 +1,84 @@
+package xyz.driver.core.rest
+
+import akka.http.scaladsl.model.{HttpMethod, StatusCodes}
+import akka.http.scaladsl.model.headers._
+import akka.http.scaladsl.server.{Directives, Route}
+import akka.http.scaladsl.testkit.ScalatestRouteTest
+import com.typesafe.config.ConfigFactory
+import org.scalatest.{AsyncFlatSpec, Matchers}
+import xyz.driver.core.app.{DriverApp, SimpleModule}
+
+class DriverAppTest extends AsyncFlatSpec with ScalatestRouteTest with Matchers with Directives {
+ val config = ConfigFactory.parseString("""
+ |application {
+ | cors {
+ | allowedMethods: ["GET", "PUT", "POST", "PATCH", "DELETE", "OPTIONS"]
+ | allowedOrigins: [{scheme: https, hostSuffix: example.com}]
+ | }
+ |}
+ """.stripMargin).withFallback(ConfigFactory.load)
+
+ val allowedOrigins = Set(HttpOrigin("https", Host("example.com")))
+ val allowedMethods: collection.immutable.Seq[HttpMethod] = {
+ import akka.http.scaladsl.model.HttpMethods._
+ collection.immutable.Seq(GET, PUT, POST, PATCH, DELETE, OPTIONS)
+ }
+
+ import scala.reflect.runtime.universe.typeOf
+ class TestApp(testRoute: Route)
+ extends DriverApp(
+ appName = "test-app",
+ version = "0.0.1",
+ gitHash = "deadb33f",
+ modules = Seq(new SimpleModule("test-module", theRoute = testRoute, routeType = typeOf[DriverApp])),
+ config = config,
+ log = xyz.driver.core.logging.NoLogger
+ )
+
+ it should "respond with the correct CORS headers for the swagger OPTIONS route" in {
+ val route = new TestApp(get(complete(StatusCodes.OK)))
+ Options(s"/api-docs/swagger.json") ~> route.appRoute ~> check {
+ status shouldBe StatusCodes.OK
+ headers should contain(`Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigins.toSeq: _*)))
+ header[`Access-Control-Allow-Methods`].get.methods should contain theSameElementsAs allowedMethods
+ }
+ }
+
+ it should "respond with the correct CORS headers for the test route" in {
+ val route = new TestApp(get(complete(StatusCodes.OK)))
+ Get(s"/api/v1/test") ~> route.appRoute ~> check {
+ status shouldBe StatusCodes.OK
+ headers should contain(`Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigins.toSeq: _*)))
+ header[`Access-Control-Allow-Methods`].get.methods should contain theSameElementsAs allowedMethods
+ }
+ }
+
+ it should "respond with the correct CORS headers for a concatenated route" in {
+ val route = new TestApp(get(complete(StatusCodes.OK)) ~ post(complete(StatusCodes.OK)))
+ Post(s"/api/v1/test") ~> route.appRoute ~> check {
+ status shouldBe StatusCodes.OK
+ headers should contain(`Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigins.toSeq: _*)))
+ header[`Access-Control-Allow-Methods`].get.methods should contain theSameElementsAs allowedMethods
+ }
+ }
+
+ it should "allow subdomains of allowed origin suffixes" in {
+ val route = new TestApp(get(complete(StatusCodes.OK)))
+ Get(s"/api/v1/test")
+ .withHeaders(Origin(HttpOrigin("https", Host("foo.example.com")))) ~> route.appRoute ~> check {
+ status shouldBe StatusCodes.OK
+ headers should contain(`Access-Control-Allow-Origin`(HttpOrigin("https", Host("foo.example.com"))))
+ header[`Access-Control-Allow-Methods`].get.methods should contain theSameElementsAs allowedMethods
+ }
+ }
+
+ it should "respond with default domains for invalid origins" in {
+ val route = new TestApp(get(complete(StatusCodes.OK)))
+ Get(s"/api/v1/test")
+ .withHeaders(Origin(HttpOrigin("https", Host("invalid.foo.bar.com")))) ~> route.appRoute ~> check {
+ status shouldBe StatusCodes.OK
+ headers should contain(`Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigins.toSeq: _*)))
+ header[`Access-Control-Allow-Methods`].get.methods should contain theSameElementsAs allowedMethods
+ }
+ }
+}
diff --git a/src/test/scala/xyz/driver/core/rest/DriverRouteTest.scala b/src/test/scala/xyz/driver/core/rest/DriverRouteTest.scala
index 60056b7..ce3daa8 100644
--- a/src/test/scala/xyz/driver/core/rest/DriverRouteTest.scala
+++ b/src/test/scala/xyz/driver/core/rest/DriverRouteTest.scala
@@ -1,7 +1,6 @@
package xyz.driver.core.rest
-import akka.http.scaladsl.model.{HttpMethod, StatusCodes}
-import akka.http.scaladsl.model.headers._
+import akka.http.scaladsl.model.StatusCodes
import akka.http.scaladsl.server.Directives.{complete => akkaComplete}
import akka.http.scaladsl.server.{Directives, Route}
import akka.http.scaladsl.testkit.ScalatestRouteTest
@@ -27,12 +26,6 @@ class DriverRouteTest extends AsyncFlatSpec with ScalatestRouteTest with Matcher
""".stripMargin)
}
- val allowedOrigins = Set(HttpOrigin("https", Host("example.com")))
- val allowedMethods: collection.immutable.Seq[HttpMethod] = {
- import akka.http.scaladsl.model.HttpMethods._
- collection.immutable.Seq(GET, PUT, POST, PATCH, DELETE, OPTIONS)
- }
-
"DriverRoute" should "respond with 200 OK for a basic route" in {
val route = new TestRoute(akkaComplete(StatusCodes.OK))
@@ -103,42 +96,4 @@ class DriverRouteTest extends AsyncFlatSpec with ScalatestRouteTest with Matcher
responseAs[String] shouldBe "Database access error"
}
}
-
- it should "respond with the correct CORS headers for the swagger OPTIONS route" in {
- val route = new TestRoute(get(akkaComplete(StatusCodes.OK)))
- Options(s"/api-docs/swagger.json") ~> route.routeWithDefaults ~> check {
- status shouldBe StatusCodes.OK
- headers should contain(`Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigins.toSeq: _*)))
- header[`Access-Control-Allow-Methods`].get.methods should contain theSameElementsAs allowedMethods
- }
- }
-
- it should "respond with the correct CORS headers for the test route" in {
- val route = new TestRoute(get(akkaComplete(StatusCodes.OK)))
- Options(s"/api/v1/test") ~> route.routeWithDefaults ~> check {
- status shouldBe StatusCodes.OK
- headers should contain(`Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigins.toSeq: _*)))
- header[`Access-Control-Allow-Methods`].get.methods should contain theSameElementsAs allowedMethods
- }
- }
-
- it should "allow subdomains of allowed origin suffixes" in {
- val route = new TestRoute(get(akkaComplete(StatusCodes.OK)))
- Options(s"/api/v1/test")
- .withHeaders(Origin(HttpOrigin("https", Host("foo.example.com")))) ~> route.routeWithDefaults ~> check {
- status shouldBe StatusCodes.OK
- headers should contain(`Access-Control-Allow-Origin`(HttpOrigin("https", Host("foo.example.com"))))
- header[`Access-Control-Allow-Methods`].get.methods should contain theSameElementsAs allowedMethods
- }
- }
-
- it should "respond with default domains for invalid origins" in {
- val route = new TestRoute(get(akkaComplete(StatusCodes.OK)))
- Options(s"/api/v1/test")
- .withHeaders(Origin(HttpOrigin("https", Host("invalid.foo.bar.com")))) ~> route.routeWithDefaults ~> check {
- status shouldBe StatusCodes.OK
- headers should contain(`Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigins.toSeq: _*)))
- header[`Access-Control-Allow-Methods`].get.methods should contain theSameElementsAs allowedMethods
- }
- }
}