From 5d6ef236eb3b5a4fad6383b87e1572685ddf2bc5 Mon Sep 17 00:00:00 2001 From: Zach Smith Date: Thu, 22 Feb 2018 17:03:26 -0800 Subject: Move cors directives to DriverApp and get rid of custom rejection handler --- src/main/scala/xyz/driver/core/app/DriverApp.scala | 57 ++++++++++++++- .../scala/xyz/driver/core/rest/DriverRoute.scala | 67 +---------------- .../scala/xyz/driver/core/rest/DriverAppTest.scala | 84 ++++++++++++++++++++++ .../xyz/driver/core/rest/DriverRouteTest.scala | 47 +----------- 4 files changed, 143 insertions(+), 112 deletions(-) create mode 100644 src/test/scala/xyz/driver/core/rest/DriverAppTest.scala diff --git a/src/main/scala/xyz/driver/core/app/DriverApp.scala b/src/main/scala/xyz/driver/core/app/DriverApp.scala index 1ded4dd..ca3dd54 100644 --- a/src/main/scala/xyz/driver/core/app/DriverApp.scala +++ b/src/main/scala/xyz/driver/core/app/DriverApp.scala @@ -62,9 +62,59 @@ class DriverApp( } } + protected lazy val allowedCorsDomainSuffixes: Set[HttpOrigin] = { + import scala.collection.JavaConverters._ + config + .getConfigList("application.cors.allowedOrigins") + .asScala + .map { c => + HttpOrigin(c.getString("scheme"), Host(c.getString("hostSuffix"))) + }(scala.collection.breakOut) + } + + protected lazy val defaultCorsAllowedMethods: Set[HttpMethod] = { + import scala.collection.JavaConverters._ + config.getStringList("application.cors.allowedMethods").asScala.toSet.flatMap(HttpMethods.getForKey) + } + + protected lazy val defaultCorsAllowedOrigin: Origin = { + Origin(allowedCorsDomainSuffixes.to[collection.immutable.Seq]) + } + + protected def corsAllowedOriginHeader(origin: Option[Origin]): HttpHeader = { + val allowedOrigin = + origin + .filter { requestOrigin => + allowedCorsDomainSuffixes.exists { allowedOriginSuffix => + requestOrigin.origins.exists(o => + o.scheme == allowedOriginSuffix.scheme && + o.host.host.address.endsWith(allowedOriginSuffix.host.host.address())) + } + } + .getOrElse(defaultCorsAllowedOrigin) + + `Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigin.origins: _*)) + } + + protected def respondWithAllCorsHeaders: Directive0 = { + respondWithCorsAllowedHeaders tflatMap { _ => + respondWithCorsAllowedMethodHeaders(defaultCorsAllowedMethods) tflatMap { _ => + optionalHeaderValueByType[Origin](()) flatMap { origin => + respondWithHeader(corsAllowedOriginHeader(origin)) + } + } + } + } + private def extractHeader(request: HttpRequest)(headerName: String): Option[String] = request.headers.find(_.name().toLowerCase === headerName).map(_.value()) + protected def defaultOptionsRoute: Route = options { + respondWithAllCorsHeaders { + complete("OK") + } + } + def appRoute: Route = { val serviceTypes = modules.flatMap(_.routeTypes) val swaggerService = new Swagger(baseUrl, Scheme.forValue(scheme) :: Nil, version, serviceTypes, config, log) @@ -75,7 +125,8 @@ class DriverApp( override def config: Config = self.config override def route: Route = versionRt ~ healthRoute ~ swaggerRoute } - val combinedRoute = modules.map(_.route).foldLeft(basicRoutes.routeWithDefaults)(_ ~ _) + val combinedRoute = + Route.seal(modules.map(_.route).foldLeft(basicRoutes.routeWithDefaults)(_ ~ _) ~ defaultOptionsRoute) (extractHost & extractClientIP & trace(tracer)) { case (origin, ip) => @@ -97,7 +148,9 @@ class DriverApp( .addHeader(RawHeader(ContextHeaders.TrackingIdHeader, trackingId)) .addHeader(RawHeader(ContextHeaders.StacktraceHeader, updatedStacktrace))) - combinedRoute(contextWithTrackingId) + respondWithAllCorsHeaders { + combinedRoute + }(contextWithTrackingId) } } diff --git a/src/main/scala/xyz/driver/core/rest/DriverRoute.scala b/src/main/scala/xyz/driver/core/rest/DriverRoute.scala index 32d996a..15da808 100644 --- a/src/main/scala/xyz/driver/core/rest/DriverRoute.scala +++ b/src/main/scala/xyz/driver/core/rest/DriverRoute.scala @@ -22,79 +22,18 @@ trait DriverRoute { def route: Route def routeWithDefaults: Route = { - (defaultResponseHeaders & handleRejections(rejectionHandler) & handleExceptions(ExceptionHandler(exceptionHandler))) { - route ~ defaultOptionsRoute - } - } - - protected lazy val allowedCorsDomainSuffixes: Set[HttpOrigin] = { - import scala.collection.JavaConverters._ - config - .getConfigList("application.cors.allowedOrigins") - .asScala - .map { c => - HttpOrigin(c.getString("scheme"), Host(c.getString("hostSuffix"))) - }(scala.collection.breakOut) - } - - protected lazy val defaultCorsAllowedMethods: Set[HttpMethod] = { - import scala.collection.JavaConverters._ - config.getStringList("application.cors.allowedMethods").asScala.toSet.flatMap(HttpMethods.getForKey) - } - - protected lazy val defaultCorsAllowedOrigin: Origin = { - Origin(allowedCorsDomainSuffixes.to[collection.immutable.Seq]) - } - - protected def corsAllowedOriginHeader(origin: Option[Origin]): HttpHeader = { - val allowedOrigin = - origin - .filter { requestOrigin => - allowedCorsDomainSuffixes.exists { allowedOriginSuffix => - requestOrigin.origins.exists(o => - o.scheme == allowedOriginSuffix.scheme && - o.host.host.address.endsWith(allowedOriginSuffix.host.host.address())) - } - } - .getOrElse(defaultCorsAllowedOrigin) - - `Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigin.origins: _*)) - } - - protected def respondWithAllCorsHeaders: Directive0 = { - respondWithCorsAllowedHeaders tflatMap { _ => - respondWithCorsAllowedMethodHeaders(defaultCorsAllowedMethods) tflatMap { _ => - optionalHeaderValueByType[Origin](()) flatMap { origin => - respondWithHeader(corsAllowedOriginHeader(origin)) - } - } - } - } - - protected def defaultOptionsRoute: Route = options { - respondWithAllCorsHeaders { - complete("OK") + (defaultResponseHeaders & handleExceptions(ExceptionHandler(exceptionHandler))) { + route } } protected def defaultResponseHeaders: Directive0 = { extractRequest flatMap { request => val tracingHeader = RawHeader(ContextHeaders.TrackingIdHeader, rest.extractTrackingId(request)) - respondWithHeader(tracingHeader) & respondWithAllCorsHeaders + respondWithHeader(tracingHeader) } } - protected def rejectionHandler: RejectionHandler = - RejectionHandler - .newBuilder() - .handle { - case rejection => - respondWithAllCorsHeaders { - RejectionHandler.default(collection.immutable.Seq(rejection)).get - } - } - .result() - /** * Override me for custom exception handling * diff --git a/src/test/scala/xyz/driver/core/rest/DriverAppTest.scala b/src/test/scala/xyz/driver/core/rest/DriverAppTest.scala new file mode 100644 index 0000000..eda6a8c --- /dev/null +++ b/src/test/scala/xyz/driver/core/rest/DriverAppTest.scala @@ -0,0 +1,84 @@ +package xyz.driver.core.rest + +import akka.http.scaladsl.model.{HttpMethod, StatusCodes} +import akka.http.scaladsl.model.headers._ +import akka.http.scaladsl.server.{Directives, Route} +import akka.http.scaladsl.testkit.ScalatestRouteTest +import com.typesafe.config.ConfigFactory +import org.scalatest.{AsyncFlatSpec, Matchers} +import xyz.driver.core.app.{DriverApp, SimpleModule} + +class DriverAppTest extends AsyncFlatSpec with ScalatestRouteTest with Matchers with Directives { + val config = ConfigFactory.parseString(""" + |application { + | cors { + | allowedMethods: ["GET", "PUT", "POST", "PATCH", "DELETE", "OPTIONS"] + | allowedOrigins: [{scheme: https, hostSuffix: example.com}] + | } + |} + """.stripMargin).withFallback(ConfigFactory.load) + + val allowedOrigins = Set(HttpOrigin("https", Host("example.com"))) + val allowedMethods: collection.immutable.Seq[HttpMethod] = { + import akka.http.scaladsl.model.HttpMethods._ + collection.immutable.Seq(GET, PUT, POST, PATCH, DELETE, OPTIONS) + } + + import scala.reflect.runtime.universe.typeOf + class TestApp(testRoute: Route) + extends DriverApp( + appName = "test-app", + version = "0.0.1", + gitHash = "deadb33f", + modules = Seq(new SimpleModule("test-module", theRoute = testRoute, routeType = typeOf[DriverApp])), + config = config, + log = xyz.driver.core.logging.NoLogger + ) + + it should "respond with the correct CORS headers for the swagger OPTIONS route" in { + val route = new TestApp(get(complete(StatusCodes.OK))) + Options(s"/api-docs/swagger.json") ~> route.appRoute ~> check { + status shouldBe StatusCodes.OK + headers should contain(`Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigins.toSeq: _*))) + header[`Access-Control-Allow-Methods`].get.methods should contain theSameElementsAs allowedMethods + } + } + + it should "respond with the correct CORS headers for the test route" in { + val route = new TestApp(get(complete(StatusCodes.OK))) + Get(s"/api/v1/test") ~> route.appRoute ~> check { + status shouldBe StatusCodes.OK + headers should contain(`Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigins.toSeq: _*))) + header[`Access-Control-Allow-Methods`].get.methods should contain theSameElementsAs allowedMethods + } + } + + it should "respond with the correct CORS headers for a concatenated route" in { + val route = new TestApp(get(complete(StatusCodes.OK)) ~ post(complete(StatusCodes.OK))) + Post(s"/api/v1/test") ~> route.appRoute ~> check { + status shouldBe StatusCodes.OK + headers should contain(`Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigins.toSeq: _*))) + header[`Access-Control-Allow-Methods`].get.methods should contain theSameElementsAs allowedMethods + } + } + + it should "allow subdomains of allowed origin suffixes" in { + val route = new TestApp(get(complete(StatusCodes.OK))) + Get(s"/api/v1/test") + .withHeaders(Origin(HttpOrigin("https", Host("foo.example.com")))) ~> route.appRoute ~> check { + status shouldBe StatusCodes.OK + headers should contain(`Access-Control-Allow-Origin`(HttpOrigin("https", Host("foo.example.com")))) + header[`Access-Control-Allow-Methods`].get.methods should contain theSameElementsAs allowedMethods + } + } + + it should "respond with default domains for invalid origins" in { + val route = new TestApp(get(complete(StatusCodes.OK))) + Get(s"/api/v1/test") + .withHeaders(Origin(HttpOrigin("https", Host("invalid.foo.bar.com")))) ~> route.appRoute ~> check { + status shouldBe StatusCodes.OK + headers should contain(`Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigins.toSeq: _*))) + header[`Access-Control-Allow-Methods`].get.methods should contain theSameElementsAs allowedMethods + } + } +} diff --git a/src/test/scala/xyz/driver/core/rest/DriverRouteTest.scala b/src/test/scala/xyz/driver/core/rest/DriverRouteTest.scala index 60056b7..ce3daa8 100644 --- a/src/test/scala/xyz/driver/core/rest/DriverRouteTest.scala +++ b/src/test/scala/xyz/driver/core/rest/DriverRouteTest.scala @@ -1,7 +1,6 @@ package xyz.driver.core.rest -import akka.http.scaladsl.model.{HttpMethod, StatusCodes} -import akka.http.scaladsl.model.headers._ +import akka.http.scaladsl.model.StatusCodes import akka.http.scaladsl.server.Directives.{complete => akkaComplete} import akka.http.scaladsl.server.{Directives, Route} import akka.http.scaladsl.testkit.ScalatestRouteTest @@ -27,12 +26,6 @@ class DriverRouteTest extends AsyncFlatSpec with ScalatestRouteTest with Matcher """.stripMargin) } - val allowedOrigins = Set(HttpOrigin("https", Host("example.com"))) - val allowedMethods: collection.immutable.Seq[HttpMethod] = { - import akka.http.scaladsl.model.HttpMethods._ - collection.immutable.Seq(GET, PUT, POST, PATCH, DELETE, OPTIONS) - } - "DriverRoute" should "respond with 200 OK for a basic route" in { val route = new TestRoute(akkaComplete(StatusCodes.OK)) @@ -103,42 +96,4 @@ class DriverRouteTest extends AsyncFlatSpec with ScalatestRouteTest with Matcher responseAs[String] shouldBe "Database access error" } } - - it should "respond with the correct CORS headers for the swagger OPTIONS route" in { - val route = new TestRoute(get(akkaComplete(StatusCodes.OK))) - Options(s"/api-docs/swagger.json") ~> route.routeWithDefaults ~> check { - status shouldBe StatusCodes.OK - headers should contain(`Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigins.toSeq: _*))) - header[`Access-Control-Allow-Methods`].get.methods should contain theSameElementsAs allowedMethods - } - } - - it should "respond with the correct CORS headers for the test route" in { - val route = new TestRoute(get(akkaComplete(StatusCodes.OK))) - Options(s"/api/v1/test") ~> route.routeWithDefaults ~> check { - status shouldBe StatusCodes.OK - headers should contain(`Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigins.toSeq: _*))) - header[`Access-Control-Allow-Methods`].get.methods should contain theSameElementsAs allowedMethods - } - } - - it should "allow subdomains of allowed origin suffixes" in { - val route = new TestRoute(get(akkaComplete(StatusCodes.OK))) - Options(s"/api/v1/test") - .withHeaders(Origin(HttpOrigin("https", Host("foo.example.com")))) ~> route.routeWithDefaults ~> check { - status shouldBe StatusCodes.OK - headers should contain(`Access-Control-Allow-Origin`(HttpOrigin("https", Host("foo.example.com")))) - header[`Access-Control-Allow-Methods`].get.methods should contain theSameElementsAs allowedMethods - } - } - - it should "respond with default domains for invalid origins" in { - val route = new TestRoute(get(akkaComplete(StatusCodes.OK))) - Options(s"/api/v1/test") - .withHeaders(Origin(HttpOrigin("https", Host("invalid.foo.bar.com")))) ~> route.routeWithDefaults ~> check { - status shouldBe StatusCodes.OK - headers should contain(`Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigins.toSeq: _*))) - header[`Access-Control-Allow-Methods`].get.methods should contain theSameElementsAs allowedMethods - } - } } -- cgit v1.2.3