aboutsummaryrefslogtreecommitdiff
path: root/core-rest/src/test/scala/xyz/driver/core/rest
diff options
context:
space:
mode:
Diffstat (limited to 'core-rest/src/test/scala/xyz/driver/core/rest')
-rw-r--r--core-rest/src/test/scala/xyz/driver/core/rest/DriverAppTest.scala89
-rw-r--r--core-rest/src/test/scala/xyz/driver/core/rest/DriverRouteTest.scala121
-rw-r--r--core-rest/src/test/scala/xyz/driver/core/rest/PatchDirectivesTest.scala101
-rw-r--r--core-rest/src/test/scala/xyz/driver/core/rest/RestTest.scala151
4 files changed, 462 insertions, 0 deletions
diff --git a/core-rest/src/test/scala/xyz/driver/core/rest/DriverAppTest.scala b/core-rest/src/test/scala/xyz/driver/core/rest/DriverAppTest.scala
new file mode 100644
index 0000000..324c8d8
--- /dev/null
+++ b/core-rest/src/test/scala/xyz/driver/core/rest/DriverAppTest.scala
@@ -0,0 +1,89 @@
+package xyz.driver.core.rest
+
+import akka.http.scaladsl.model.headers._
+import akka.http.scaladsl.model.{HttpMethod, StatusCodes}
+import akka.http.scaladsl.server.{Directives, Route}
+import akka.http.scaladsl.testkit.ScalatestRouteTest
+import com.typesafe.config.ConfigFactory
+import org.scalatest.{AsyncFlatSpec, Matchers}
+import xyz.driver.core.app.{DriverApp, SimpleModule}
+
+class DriverAppTest extends AsyncFlatSpec with ScalatestRouteTest with Matchers with Directives {
+ val config = ConfigFactory.parseString("""
+ |application {
+ | cors {
+ | allowedOrigins: ["example.com"]
+ | }
+ |}
+ """.stripMargin).withFallback(ConfigFactory.load)
+
+ val origin = Origin(HttpOrigin("https", Host("example.com")))
+ val allowedOrigins = Set(HttpOrigin("https", Host("example.com")))
+ val allowedMethods: collection.immutable.Seq[HttpMethod] = {
+ import akka.http.scaladsl.model.HttpMethods._
+ collection.immutable.Seq(GET, PUT, POST, PATCH, DELETE, OPTIONS, TRACE)
+ }
+
+ import scala.reflect.runtime.universe.typeOf
+ class TestApp(testRoute: Route)
+ extends DriverApp(
+ appName = "test-app",
+ version = "0.0.1",
+ gitHash = "deadb33f",
+ modules = Seq(new SimpleModule("test-module", theRoute = testRoute, routeType = typeOf[DriverApp])),
+ config = config,
+ log = xyz.driver.core.logging.NoLogger
+ )
+
+ it should "respond with the correct CORS headers for the swagger OPTIONS route" in {
+ val route = new TestApp(get(complete(StatusCodes.OK)))
+ Options(s"/api-docs/swagger.json").withHeaders(origin) ~> route.appRoute ~> check {
+ status shouldBe StatusCodes.OK
+ headers should contain(`Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigins.toSeq: _*)))
+ header[`Access-Control-Allow-Methods`].get.methods should contain theSameElementsAs allowedMethods
+ }
+ }
+
+ it should "respond with the correct CORS headers for the test route" in {
+ val route = new TestApp(get(complete(StatusCodes.OK)))
+ Get(s"/api/v1/test").withHeaders(origin) ~> route.appRoute ~> check {
+ status shouldBe StatusCodes.OK
+ headers should contain(`Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigins.toSeq: _*)))
+ }
+ }
+
+ it should "respond with the correct CORS headers for a concatenated route" in {
+ val route = new TestApp(get(complete(StatusCodes.OK)) ~ post(complete(StatusCodes.OK)))
+ Post(s"/api/v1/test").withHeaders(origin) ~> route.appRoute ~> check {
+ status shouldBe StatusCodes.OK
+ headers should contain(`Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigins.toSeq: _*)))
+ }
+ }
+
+ it should "allow subdomains of allowed origin suffixes" in {
+ val route = new TestApp(get(complete(StatusCodes.OK)))
+ Get(s"/api/v1/test")
+ .withHeaders(Origin(HttpOrigin("https", Host("foo.example.com")))) ~> route.appRoute ~> check {
+ status shouldBe StatusCodes.OK
+ headers should contain(`Access-Control-Allow-Origin`(HttpOrigin("https", Host("foo.example.com"))))
+ }
+ }
+
+ it should "respond with default domains for invalid origins" in {
+ val route = new TestApp(get(complete(StatusCodes.OK)))
+ Get(s"/api/v1/test")
+ .withHeaders(Origin(HttpOrigin("https", Host("invalid.foo.bar.com")))) ~> route.appRoute ~> check {
+ status shouldBe StatusCodes.OK
+ headers should contain(`Access-Control-Allow-Origin`(HttpOriginRange.*))
+ }
+ }
+
+ it should "respond with Pragma and Cache-Control (no-cache) headers" in {
+ val route = new TestApp(get(complete(StatusCodes.OK)))
+ Get(s"/api/v1/test") ~> route.appRoute ~> check {
+ status shouldBe StatusCodes.OK
+ header("Pragma").map(_.value()) should contain("no-cache")
+ header[`Cache-Control`].map(_.value()) should contain("no-cache")
+ }
+ }
+}
diff --git a/core-rest/src/test/scala/xyz/driver/core/rest/DriverRouteTest.scala b/core-rest/src/test/scala/xyz/driver/core/rest/DriverRouteTest.scala
new file mode 100644
index 0000000..cc0019a
--- /dev/null
+++ b/core-rest/src/test/scala/xyz/driver/core/rest/DriverRouteTest.scala
@@ -0,0 +1,121 @@
+package xyz.driver.core.rest
+
+import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport
+import akka.http.scaladsl.model.StatusCodes
+import akka.http.scaladsl.model.headers.Connection
+import akka.http.scaladsl.server.Directives.{complete => akkaComplete}
+import akka.http.scaladsl.server.{Directives, RejectionHandler, Route}
+import akka.http.scaladsl.testkit.ScalatestRouteTest
+import com.typesafe.scalalogging.Logger
+import org.scalatest.{AsyncFlatSpec, Matchers}
+import xyz.driver.core.json.serviceExceptionFormat
+import xyz.driver.core.logging.NoLogger
+import xyz.driver.core.rest.errors._
+
+import scala.concurrent.Future
+
+class DriverRouteTest
+ extends AsyncFlatSpec with ScalatestRouteTest with SprayJsonSupport with Matchers with Directives {
+ class TestRoute(override val route: Route) extends DriverRoute {
+ override def log: Logger = NoLogger
+ }
+
+ "DriverRoute" should "respond with 200 OK for a basic route" in {
+ val route = new TestRoute(akkaComplete(StatusCodes.OK))
+
+ Get("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check {
+ handled shouldBe true
+ status shouldBe StatusCodes.OK
+ }
+ }
+
+ it should "respond with a 401 for an InvalidInputException" in {
+ val route = new TestRoute(akkaComplete(Future.failed[String](InvalidInputException())))
+
+ Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check {
+ handled shouldBe true
+ status shouldBe StatusCodes.BadRequest
+ responseAs[ServiceException] shouldBe InvalidInputException()
+ }
+ }
+
+ it should "respond with a 403 for InvalidActionException" in {
+ val route = new TestRoute(akkaComplete(Future.failed[String](InvalidActionException())))
+
+ Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check {
+ handled shouldBe true
+ status shouldBe StatusCodes.Forbidden
+ responseAs[ServiceException] shouldBe InvalidActionException()
+ }
+ }
+
+ it should "respond with a 404 for ResourceNotFoundException" in {
+ val route = new TestRoute(akkaComplete(Future.failed[String](ResourceNotFoundException())))
+
+ Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check {
+ handled shouldBe true
+ status shouldBe StatusCodes.NotFound
+ responseAs[ServiceException] shouldBe ResourceNotFoundException()
+ }
+ }
+
+ it should "respond with a 500 for ExternalServiceException" in {
+ val error = ExternalServiceException("GET /api/v1/users/", "Permission denied", None)
+ val route = new TestRoute(akkaComplete(Future.failed[String](error)))
+
+ Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check {
+ handled shouldBe true
+ status shouldBe StatusCodes.InternalServerError
+ responseAs[ServiceException] shouldBe error
+ }
+ }
+
+ it should "allow pass-through of external service exceptions" in {
+ val innerError = InvalidInputException()
+ val error = ExternalServiceException("GET /api/v1/users/", "Permission denied", Some(innerError))
+ val future = Future.failed[String](error)
+ val route = new TestRoute(akkaComplete(future.passThroughExternalServiceException))
+
+ Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check {
+ handled shouldBe true
+ status shouldBe StatusCodes.BadRequest
+ responseAs[ServiceException] shouldBe innerError
+ }
+ }
+
+ it should "respond with a 503 for ExternalServiceTimeoutException" in {
+ val error = ExternalServiceTimeoutException("GET /api/v1/users/")
+ val route = new TestRoute(akkaComplete(Future.failed[String](error)))
+
+ Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check {
+ handled shouldBe true
+ status shouldBe StatusCodes.GatewayTimeout
+ responseAs[ServiceException] shouldBe error
+ }
+ }
+
+ it should "respond with a 500 for DatabaseException" in {
+ val route = new TestRoute(akkaComplete(Future.failed[String](DatabaseException())))
+
+ Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check {
+ handled shouldBe true
+ status shouldBe StatusCodes.InternalServerError
+ responseAs[ServiceException] shouldBe DatabaseException()
+ }
+ }
+
+ it should "add a `Connection: close` header to avoid clashing with envoy's timeouts" in {
+ val rejectionHandler = RejectionHandler.newBuilder().handleNotFound(complete(StatusCodes.NotFound)).result()
+ val route = new TestRoute(handleRejections(rejectionHandler)((get & path("foo"))(complete("OK"))))
+
+ Get("/foo") ~> route.routeWithDefaults ~> check {
+ status shouldBe StatusCodes.OK
+ headers should contain(Connection("close"))
+ }
+
+ Get("/bar") ~> route.routeWithDefaults ~> check {
+ status shouldBe StatusCodes.NotFound
+ headers should contain(Connection("close"))
+ }
+ }
+}
diff --git a/core-rest/src/test/scala/xyz/driver/core/rest/PatchDirectivesTest.scala b/core-rest/src/test/scala/xyz/driver/core/rest/PatchDirectivesTest.scala
new file mode 100644
index 0000000..987717d
--- /dev/null
+++ b/core-rest/src/test/scala/xyz/driver/core/rest/PatchDirectivesTest.scala
@@ -0,0 +1,101 @@
+package xyz.driver.core.rest
+
+import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport
+import akka.http.scaladsl.model._
+import akka.http.scaladsl.model.headers.`Content-Type`
+import akka.http.scaladsl.server.{Directives, Route}
+import akka.http.scaladsl.testkit.ScalatestRouteTest
+import org.scalatest.{FlatSpec, Matchers}
+import spray.json._
+import xyz.driver.core.{Id, Name}
+import xyz.driver.core.json._
+
+import scala.concurrent.Future
+
+class PatchDirectivesTest
+ extends FlatSpec with Matchers with ScalatestRouteTest with SprayJsonSupport with DefaultJsonProtocol
+ with Directives with PatchDirectives {
+ case class Bar(name: Name[Bar], size: Int)
+ case class Foo(id: Id[Foo], name: Name[Foo], rank: Int, bar: Option[Bar])
+ implicit val barFormat: RootJsonFormat[Bar] = jsonFormat2(Bar)
+ implicit val fooFormat: RootJsonFormat[Foo] = jsonFormat4(Foo)
+
+ val testFoo: Foo = Foo(Id("1"), Name(s"Foo"), 1, Some(Bar(Name("Bar"), 10)))
+
+ def route(retrieve: => Future[Option[Foo]]): Route =
+ Route.seal(path("api" / "v1" / "foos" / IdInPath[Foo]) { fooId =>
+ entity(as[Patchable[Foo]]) { fooPatchable =>
+ mergePatch(fooPatchable, retrieve) { updatedFoo =>
+ complete(updatedFoo)
+ }
+ }
+ })
+
+ val MergePatchContentType = ContentType(`application/merge-patch+json`)
+ val ContentTypeHeader = `Content-Type`(MergePatchContentType)
+ def jsonEntity(json: String, contentType: ContentType.NonBinary = MergePatchContentType): RequestEntity =
+ HttpEntity(contentType, json)
+
+ "PatchSupport" should "allow partial updates to an existing object" in {
+ val fooRetrieve = Future.successful(Some(testFoo))
+
+ Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""")) ~> route(fooRetrieve) ~> check {
+ handled shouldBe true
+ responseAs[Foo] shouldBe testFoo.copy(rank = 4)
+ }
+ }
+
+ it should "merge deeply nested objects" in {
+ val fooRetrieve = Future.successful(Some(testFoo))
+
+ Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4, "bar": {"name": "My Bar"}}""")) ~> route(fooRetrieve) ~> check {
+ handled shouldBe true
+ responseAs[Foo] shouldBe testFoo.copy(rank = 4, bar = Some(Bar(Name("My Bar"), 10)))
+ }
+ }
+
+ it should "return a 404 if the object is not found" in {
+ val fooRetrieve = Future.successful(None)
+
+ Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""")) ~> route(fooRetrieve) ~> check {
+ handled shouldBe true
+ status shouldBe StatusCodes.NotFound
+ }
+ }
+
+ it should "handle nulls on optional values correctly" in {
+ val fooRetrieve = Future.successful(Some(testFoo))
+
+ Patch("/api/v1/foos/1", jsonEntity("""{"bar": null}""")) ~> route(fooRetrieve) ~> check {
+ handled shouldBe true
+ responseAs[Foo] shouldBe testFoo.copy(bar = None)
+ }
+ }
+
+ it should "handle optional values correctly when old value is null" in {
+ val fooRetrieve = Future.successful(Some(testFoo.copy(bar = None)))
+
+ Patch("/api/v1/foos/1", jsonEntity("""{"bar": {"name": "My Bar","size":10}}""")) ~> route(fooRetrieve) ~> check {
+ handled shouldBe true
+ responseAs[Foo] shouldBe testFoo.copy(bar = Some(Bar(Name("My Bar"), 10)))
+ }
+ }
+
+ it should "return a 400 for nulls on non-optional values" in {
+ val fooRetrieve = Future.successful(Some(testFoo))
+
+ Patch("/api/v1/foos/1", jsonEntity("""{"rank": null}""")) ~> route(fooRetrieve) ~> check {
+ handled shouldBe true
+ status shouldBe StatusCodes.BadRequest
+ }
+ }
+
+ it should "return a 415 for incorrect Content-Type" in {
+ val fooRetrieve = Future.successful(Some(testFoo))
+
+ Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""", ContentTypes.`application/json`)) ~> route(fooRetrieve) ~> check {
+ status shouldBe StatusCodes.UnsupportedMediaType
+ responseAs[String] should include("application/merge-patch+json")
+ }
+ }
+}
diff --git a/core-rest/src/test/scala/xyz/driver/core/rest/RestTest.scala b/core-rest/src/test/scala/xyz/driver/core/rest/RestTest.scala
new file mode 100644
index 0000000..19e4ed1
--- /dev/null
+++ b/core-rest/src/test/scala/xyz/driver/core/rest/RestTest.scala
@@ -0,0 +1,151 @@
+package xyz.driver.core.rest
+
+import akka.http.scaladsl.model.StatusCodes
+import akka.http.scaladsl.server.{Directives, Route, ValidationRejection}
+import akka.http.scaladsl.testkit.ScalatestRouteTest
+import akka.util.ByteString
+import org.scalatest.{Matchers, WordSpec}
+import xyz.driver.core.rest
+
+import scala.concurrent.Future
+import scala.util.Random
+
+class RestTest extends WordSpec with Matchers with ScalatestRouteTest with Directives {
+ "`escapeScriptTags` function" should {
+ "escape script tags properly" in {
+ val dirtyString = "</sc----</sc----</sc"
+ val cleanString = "--------------------"
+
+ (escapeScriptTags(ByteString(dirtyString)).utf8String) should be(dirtyString.replace("</sc", "< /sc"))
+
+ (escapeScriptTags(ByteString(cleanString)).utf8String) should be(cleanString)
+ }
+ }
+
+ "paginated directive" should {
+ val route: Route = rest.paginated { paginated =>
+ complete(StatusCodes.OK -> s"${paginated.pageNumber},${paginated.pageSize}")
+ }
+ "accept a pagination" in {
+ Get("/?pageNumber=2&pageSize=42") ~> route ~> check {
+ assert(status == StatusCodes.OK)
+ assert(entityAs[String] == "2,42")
+ }
+ }
+ "provide a default pagination" in {
+ Get("/") ~> route ~> check {
+ assert(status == StatusCodes.OK)
+ assert(entityAs[String] == "1,100")
+ }
+ }
+ "provide default values for a partial pagination" in {
+ Get("/?pageSize=2") ~> route ~> check {
+ assert(status == StatusCodes.OK)
+ assert(entityAs[String] == "1,2")
+ }
+ }
+ "reject an invalid pagination" in {
+ Get("/?pageNumber=-1") ~> route ~> check {
+ assert(rejection.isInstanceOf[ValidationRejection])
+ }
+ }
+ }
+
+ "optional paginated directive" should {
+ val route: Route = rest.optionalPagination { paginated =>
+ complete(StatusCodes.OK -> paginated.map(p => s"${p.pageNumber},${p.pageSize}").getOrElse("no pagination"))
+ }
+ "accept a pagination" in {
+ Get("/?pageNumber=2&pageSize=42") ~> route ~> check {
+ assert(status == StatusCodes.OK)
+ assert(entityAs[String] == "2,42")
+ }
+ }
+ "without pagination" in {
+ Get("/") ~> route ~> check {
+ assert(status == StatusCodes.OK)
+ assert(entityAs[String] == "no pagination")
+ }
+ }
+ "reject an invalid pagination" in {
+ Get("/?pageNumber=1") ~> route ~> check {
+ assert(rejection.isInstanceOf[ValidationRejection])
+ }
+ }
+ }
+
+ "completeWithPagination directive" when {
+ import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport._
+ import spray.json.DefaultJsonProtocol._
+
+ val data = Seq.fill(103)(Random.alphanumeric.take(10).mkString)
+ val route: Route =
+ parameter('empty.as[Boolean] ? false) { isEmpty =>
+ completeWithPagination[String] {
+ case Some(pagination) if isEmpty =>
+ Future.successful(ListResponse(Seq(), 0, Some(pagination)))
+ case Some(pagination) =>
+ val filtered = data.slice(pagination.offset, pagination.offset + pagination.pageSize)
+ Future.successful(ListResponse(filtered, data.size, Some(pagination)))
+ case None if isEmpty => Future.successful(ListResponse(Seq(), 0, None))
+ case None => Future.successful(ListResponse(data, data.size, None))
+ }
+ }
+
+ "pagination is defined" should {
+ "return a response with pagination headers" in {
+ Get("/?pageNumber=2&pageSize=10") ~> route ~> check {
+ responseAs[Seq[String]] shouldBe data.slice(10, 20)
+ header(ContextHeaders.ResourceCount).map(_.value) should contain("103")
+ header(ContextHeaders.PageCount).map(_.value) should contain("11")
+ }
+ }
+
+ "disallow pageSize <= 0" in {
+ Get("/?pageNumber=2&pageSize=0") ~> route ~> check {
+ rejection shouldBe a[ValidationRejection]
+ }
+
+ Get("/?pageNumber=2&pageSize=-1") ~> route ~> check {
+ rejection shouldBe a[ValidationRejection]
+ }
+ }
+
+ "disallow pageNumber <= 0" in {
+ Get("/?pageNumber=0&pageSize=10") ~> route ~> check {
+ rejection shouldBe a[ValidationRejection]
+ }
+
+ Get("/?pageNumber=-1&pageSize=10") ~> route ~> check {
+ rejection shouldBe a[ValidationRejection]
+ }
+ }
+
+ "return PageCount == 0 if returning an empty list" in {
+ Get("/?empty=true&pageNumber=2&pageSize=10") ~> route ~> check {
+ responseAs[Seq[String]] shouldBe empty
+ header(ContextHeaders.ResourceCount).map(_.value) should contain("0")
+ header(ContextHeaders.PageCount).map(_.value) should contain("0")
+ }
+ }
+ }
+
+ "pagination is not defined" should {
+ "return a response with pagination headers and PageCount == 1" in {
+ Get("/") ~> route ~> check {
+ responseAs[Seq[String]] shouldBe data
+ header(ContextHeaders.ResourceCount).map(_.value) should contain("103")
+ header(ContextHeaders.PageCount).map(_.value) should contain("1")
+ }
+ }
+
+ "return PageCount == 0 if returning an empty list" in {
+ Get("/?empty=true") ~> route ~> check {
+ responseAs[Seq[String]] shouldBe empty
+ header(ContextHeaders.ResourceCount).map(_.value) should contain("0")
+ header(ContextHeaders.PageCount).map(_.value) should contain("0")
+ }
+ }
+ }
+ }
+}