aboutsummaryrefslogtreecommitdiff
path: root/core-rest/src/test
diff options
context:
space:
mode:
Diffstat (limited to 'core-rest/src/test')
-rw-r--r--core-rest/src/test/scala/xyz/driver/core/AuthTest.scala165
-rw-r--r--core-rest/src/test/scala/xyz/driver/core/GeneratorsTest.scala264
-rw-r--r--core-rest/src/test/scala/xyz/driver/core/JsonTest.scala521
-rw-r--r--core-rest/src/test/scala/xyz/driver/core/TestTypes.scala14
-rw-r--r--core-rest/src/test/scala/xyz/driver/core/rest/DriverAppTest.scala89
-rw-r--r--core-rest/src/test/scala/xyz/driver/core/rest/DriverRouteTest.scala121
-rw-r--r--core-rest/src/test/scala/xyz/driver/core/rest/PatchDirectivesTest.scala101
-rw-r--r--core-rest/src/test/scala/xyz/driver/core/rest/RestTest.scala151
8 files changed, 1426 insertions, 0 deletions
diff --git a/core-rest/src/test/scala/xyz/driver/core/AuthTest.scala b/core-rest/src/test/scala/xyz/driver/core/AuthTest.scala
new file mode 100644
index 0000000..2e772fb
--- /dev/null
+++ b/core-rest/src/test/scala/xyz/driver/core/AuthTest.scala
@@ -0,0 +1,165 @@
+package xyz.driver.core
+
+import akka.http.scaladsl.model.headers.{
+ HttpChallenges,
+ OAuth2BearerToken,
+ RawHeader,
+ Authorization => AkkaAuthorization
+}
+import akka.http.scaladsl.server.Directives._
+import akka.http.scaladsl.server._
+import akka.http.scaladsl.testkit.ScalatestRouteTest
+import org.scalatest.{FlatSpec, Matchers}
+import pdi.jwt.{Jwt, JwtAlgorithm}
+import xyz.driver.core.auth._
+import xyz.driver.core.domain.Email
+import xyz.driver.core.logging._
+import xyz.driver.core.rest._
+import xyz.driver.core.rest.auth._
+import xyz.driver.core.time.Time
+
+import scala.concurrent.Future
+import scalaz.OptionT
+
+class AuthTest extends FlatSpec with Matchers with ScalatestRouteTest {
+
+ case object TestRoleAllowedPermission extends Permission
+ case object TestRoleAllowedByTokenPermission extends Permission
+ case object TestRoleNotAllowedPermission extends Permission
+
+ val TestRole = Role(Id("1"), Name("testRole"))
+
+ val (publicKey, privateKey) = {
+ import java.security.KeyPairGenerator
+
+ val keygen = KeyPairGenerator.getInstance("RSA")
+ keygen.initialize(2048)
+
+ val keyPair = keygen.generateKeyPair()
+ (keyPair.getPublic, keyPair.getPrivate)
+ }
+
+ val basicAuthorization: Authorization[User] = new Authorization[User] {
+
+ override def userHasPermissions(user: User, permissions: Seq[Permission])(
+ implicit ctx: ServiceRequestContext): Future[AuthorizationResult] = {
+ val authorized = permissions.map(p => p -> (p === TestRoleAllowedPermission)).toMap
+ Future.successful(AuthorizationResult(authorized, ctx.permissionsToken))
+ }
+ }
+
+ val tokenIssuer = "users"
+ val tokenAuthorization = new CachedTokenAuthorization[User](publicKey, tokenIssuer)
+
+ val authorization = new ChainedAuthorization[User](tokenAuthorization, basicAuthorization)
+
+ val authStatusService = new AuthProvider[User](authorization, NoLogger) {
+ override def authenticatedUser(implicit ctx: ServiceRequestContext): OptionT[Future, User] =
+ OptionT.optionT[Future] {
+ if (ctx.contextHeaders.keySet.contains(AuthProvider.AuthenticationTokenHeader)) {
+ Future.successful(
+ Some(
+ AuthTokenUserInfo(
+ Id[User]("1"),
+ Email("foo", "bar"),
+ emailVerified = true,
+ audience = "driver",
+ roles = Set(TestRole),
+ expirationTime = Time(1000000L)
+ )))
+ } else {
+ Future.successful(Option.empty[User])
+ }
+ }
+ }
+
+ import authStatusService._
+
+ "'authorize' directive" should "throw error if auth token is not in the request" in {
+
+ Get("/naive/attempt") ~>
+ authorize(TestRoleAllowedPermission) { user =>
+ complete("Never going to be here")
+ } ~>
+ check {
+ // handled shouldBe false
+ rejections should contain(
+ AuthenticationFailedRejection(
+ AuthenticationFailedRejection.CredentialsMissing,
+ HttpChallenges.oAuth2(authStatusService.realm)))
+ }
+ }
+
+ it should "throw error if authorized user does not have the requested permission" in {
+
+ val referenceAuthToken = AuthToken("I am a test role's token")
+ val referenceAuthHeader = AkkaAuthorization(OAuth2BearerToken(referenceAuthToken.value))
+
+ Post("/administration/attempt").addHeader(
+ referenceAuthHeader
+ ) ~>
+ authorize(TestRoleNotAllowedPermission) { user =>
+ complete("Never going to get here")
+ } ~>
+ check {
+ handled shouldBe false
+ rejections should contain(AuthorizationFailedRejection)
+ }
+ }
+
+ it should "pass and retrieve the token to client code, if token is in request and user has permission" in {
+ val referenceAuthToken = AuthToken("I am token")
+ val referenceAuthHeader = AkkaAuthorization(OAuth2BearerToken(referenceAuthToken.value))
+
+ Get("/valid/attempt/?a=2&b=5").addHeader(
+ referenceAuthHeader
+ ) ~>
+ authorize(TestRoleAllowedPermission) { ctx =>
+ complete(s"Alright, user ${ctx.authenticatedUser.id} is authorized")
+ } ~>
+ check {
+ handled shouldBe true
+ responseAs[String] shouldBe "Alright, user 1 is authorized"
+ }
+ }
+
+ it should "authenticate correctly even without the 'Bearer' prefix on the Authorization header" in {
+ val referenceAuthToken = AuthToken("unprefixed_token")
+
+ Get("/valid/attempt/?a=2&b=5").addHeader(
+ RawHeader(ContextHeaders.AuthenticationTokenHeader, referenceAuthToken.value)
+ ) ~>
+ authorize(TestRoleAllowedPermission) { ctx =>
+ complete(s"Alright, user ${ctx.authenticatedUser.id} is authorized")
+ } ~>
+ check {
+ handled shouldBe true
+ responseAs[String] shouldBe "Alright, user 1 is authorized"
+ }
+ }
+
+ it should "authorize permission found in permissions token" in {
+ import spray.json._
+
+ val claim = JsObject(
+ Map(
+ "iss" -> JsString(tokenIssuer),
+ "sub" -> JsString("1"),
+ "permissions" -> JsObject(Map(TestRoleAllowedByTokenPermission.toString -> JsBoolean(true)))
+ )).prettyPrint
+ val permissionsToken = PermissionsToken(Jwt.encode(claim, privateKey, JwtAlgorithm.RS256))
+ val referenceAuthToken = AuthToken("I am token")
+ val referenceAuthHeader = AkkaAuthorization(OAuth2BearerToken(referenceAuthToken.value))
+
+ Get("/alic/attempt/?a=2&b=5")
+ .addHeader(referenceAuthHeader)
+ .addHeader(RawHeader(AuthProvider.PermissionsTokenHeader, permissionsToken.value)) ~>
+ authorize(TestRoleAllowedByTokenPermission) { ctx =>
+ complete(s"Alright, user ${ctx.authenticatedUser.id} is authorized by permissions token")
+ } ~>
+ check {
+ handled shouldBe true
+ responseAs[String] shouldBe "Alright, user 1 is authorized by permissions token"
+ }
+ }
+}
diff --git a/core-rest/src/test/scala/xyz/driver/core/GeneratorsTest.scala b/core-rest/src/test/scala/xyz/driver/core/GeneratorsTest.scala
new file mode 100644
index 0000000..7e740a4
--- /dev/null
+++ b/core-rest/src/test/scala/xyz/driver/core/GeneratorsTest.scala
@@ -0,0 +1,264 @@
+package xyz.driver.core
+
+import org.scalatest.{Assertions, FlatSpec, Matchers}
+
+import scala.collection.immutable.IndexedSeq
+
+class GeneratorsTest extends FlatSpec with Matchers with Assertions {
+ import generators._
+
+ "Generators" should "be able to generate com.drivergrp.core.Id identifiers" in {
+
+ val generatedId1 = nextId[String]()
+ val generatedId2 = nextId[String]()
+ val generatedId3 = nextId[Long]()
+
+ generatedId1.length should be >= 0
+ generatedId2.length should be >= 0
+ generatedId3.length should be >= 0
+ generatedId1 should not be generatedId2
+ generatedId2 should !==(generatedId3)
+ }
+
+ it should "be able to generate com.drivergrp.core.Id identifiers with max value" in {
+
+ val generatedLimitedId1 = nextId[String](5)
+ val generatedLimitedId2 = nextId[String](4)
+ val generatedLimitedId3 = nextId[Long](3)
+
+ generatedLimitedId1.length should be >= 0
+ generatedLimitedId1.length should be < 6
+ generatedLimitedId2.length should be >= 0
+ generatedLimitedId2.length should be < 5
+ generatedLimitedId3.length should be >= 0
+ generatedLimitedId3.length should be < 4
+ generatedLimitedId1 should not be generatedLimitedId2
+ generatedLimitedId2 should !==(generatedLimitedId3)
+ }
+
+ it should "be able to generate com.drivergrp.core.Name names" in {
+
+ Seq.fill(10)(nextName[String]()).distinct.size should be > 1
+ nextName[String]().value.length should be >= 0
+
+ val fixedLengthName = nextName[String](10)
+ fixedLengthName.length should be <= 10
+ assert(!fixedLengthName.value.exists(_.isControl))
+ }
+
+ it should "be able to generate com.drivergrp.core.NonEmptyName with non empty strings" in {
+
+ assert(nextNonEmptyName[String]().value.value.nonEmpty)
+ }
+
+ it should "be able to generate proper UUIDs" in {
+
+ nextUuid() should not be nextUuid()
+ nextUuid().toString.length should be(36)
+ }
+
+ it should "be able to generate new Revisions" in {
+
+ nextRevision[String]() should not be nextRevision[String]()
+ nextRevision[String]().id.length should be > 0
+ }
+
+ it should "be able to generate strings" in {
+
+ nextString() should not be nextString()
+ nextString().length should be >= 0
+
+ val fixedLengthString = nextString(20)
+ fixedLengthString.length should be <= 20
+ assert(!fixedLengthString.exists(_.isControl))
+ }
+
+ it should "be able to generate strings non-empty strings whic are non empty" in {
+
+ assert(nextNonEmptyString().value.nonEmpty)
+ }
+
+ it should "be able to generate options which are sometimes have values and sometimes not" in {
+
+ val generatedOption = nextOption("2")
+
+ generatedOption should not contain "1"
+ assert(generatedOption === Some("2") || generatedOption === None)
+ }
+
+ it should "be able to generate a pair of two generated values" in {
+
+ val constantPair = nextPair("foo", 1L)
+ constantPair._1 should be("foo")
+ constantPair._2 should be(1L)
+
+ val generatedPair = nextPair(nextId[Int](), nextName[Int]())
+
+ generatedPair._1.length should be > 0
+ generatedPair._2.length should be > 0
+
+ nextPair(nextId[Int](), nextName[Int]()) should not be
+ nextPair(nextId[Int](), nextName[Int]())
+ }
+
+ it should "be able to generate a triad of two generated values" in {
+
+ val constantTriad = nextTriad("foo", "bar", 1L)
+ constantTriad._1 should be("foo")
+ constantTriad._2 should be("bar")
+ constantTriad._3 should be(1L)
+
+ val generatedTriad = nextTriad(nextId[Int](), nextName[Int](), nextBigDecimal())
+
+ generatedTriad._1.length should be > 0
+ generatedTriad._2.length should be > 0
+ generatedTriad._3 should be >= BigDecimal(0.00)
+
+ nextTriad(nextId[Int](), nextName[Int](), nextBigDecimal()) should not be
+ nextTriad(nextId[Int](), nextName[Int](), nextBigDecimal())
+ }
+
+ it should "be able to generate a time value" in {
+
+ val generatedTime = nextTime()
+ val currentTime = System.currentTimeMillis()
+
+ generatedTime.millis should be >= 0L
+ generatedTime.millis should be <= currentTime
+ }
+
+ it should "be able to generate a time range value" in {
+
+ val generatedTimeRange = nextTimeRange()
+ val currentTime = System.currentTimeMillis()
+
+ generatedTimeRange.start.millis should be >= 0L
+ generatedTimeRange.start.millis should be <= currentTime
+ generatedTimeRange.end.millis should be >= 0L
+ generatedTimeRange.end.millis should be <= currentTime
+ generatedTimeRange.start.millis should be <= generatedTimeRange.end.millis
+ }
+
+ it should "be able to generate a BigDecimal value" in {
+
+ val defaultGeneratedBigDecimal = nextBigDecimal()
+
+ defaultGeneratedBigDecimal should be >= BigDecimal(0.00)
+ defaultGeneratedBigDecimal should be <= BigDecimal(1000000.00)
+ defaultGeneratedBigDecimal.precision should be(2)
+
+ val unitIntervalBigDecimal = nextBigDecimal(1.00, 8)
+
+ unitIntervalBigDecimal should be >= BigDecimal(0.00)
+ unitIntervalBigDecimal should be <= BigDecimal(1.00)
+ unitIntervalBigDecimal.precision should be(8)
+ }
+
+ it should "be able to generate a specific value from a set of values" in {
+
+ val possibleOptions = Set(1, 3, 5, 123, 0, 9)
+
+ val pick1 = generators.oneOf(possibleOptions)
+ val pick2 = generators.oneOf(possibleOptions)
+ val pick3 = generators.oneOf(possibleOptions)
+
+ possibleOptions should contain(pick1)
+ possibleOptions should contain(pick2)
+ possibleOptions should contain(pick3)
+
+ val pick4 = generators.oneOf(1, 3, 5, 123, 0, 9)
+ val pick5 = generators.oneOf(1, 3, 5, 123, 0, 9)
+ val pick6 = generators.oneOf(1, 3, 5, 123, 0, 9)
+
+ possibleOptions should contain(pick4)
+ possibleOptions should contain(pick5)
+ possibleOptions should contain(pick6)
+
+ Set(pick1, pick2, pick3, pick4, pick5, pick6).size should be >= 1
+ }
+
+ it should "be able to generate a specific value from an enumeratum enum" in {
+
+ import enumeratum._
+ sealed trait TestEnumValue extends EnumEntry
+ object TestEnum extends Enum[TestEnumValue] {
+ case object Value1 extends TestEnumValue
+ case object Value2 extends TestEnumValue
+ case object Value3 extends TestEnumValue
+ case object Value4 extends TestEnumValue
+ val values: IndexedSeq[TestEnumValue] = findValues
+ }
+
+ val picks = (1 to 100).map(_ => generators.oneOf(TestEnum))
+
+ TestEnum.values should contain allElementsOf picks
+ picks.toSet.size should be >= 1
+ }
+
+ it should "be able to generate array with values generated by generators" in {
+
+ val arrayOfTimes = arrayOf(nextTime(), 16)
+ arrayOfTimes.length should be <= 16
+
+ val arrayOfBigDecimals = arrayOf(nextBigDecimal(), 8)
+ arrayOfBigDecimals.length should be <= 8
+ }
+
+ it should "be able to generate seq with values generated by generators" in {
+
+ val seqOfTimes = seqOf(nextTime(), 16)
+ seqOfTimes.size should be <= 16
+
+ val seqOfBigDecimals = seqOf(nextBigDecimal(), 8)
+ seqOfBigDecimals.size should be <= 8
+ }
+
+ it should "be able to generate vector with values generated by generators" in {
+
+ val vectorOfTimes = vectorOf(nextTime(), 16)
+ vectorOfTimes.size should be <= 16
+
+ val vectorOfStrings = seqOf(nextString(), 8)
+ vectorOfStrings.size should be <= 8
+ }
+
+ it should "be able to generate list with values generated by generators" in {
+
+ val listOfTimes = listOf(nextTime(), 16)
+ listOfTimes.size should be <= 16
+
+ val listOfBigDecimals = seqOf(nextBigDecimal(), 8)
+ listOfBigDecimals.size should be <= 8
+ }
+
+ it should "be able to generate set with values generated by generators" in {
+
+ val setOfTimes = vectorOf(nextTime(), 16)
+ setOfTimes.size should be <= 16
+
+ val setOfBigDecimals = seqOf(nextBigDecimal(), 8)
+ setOfBigDecimals.size should be <= 8
+ }
+
+ it should "be able to generate maps with keys and values generated by generators" in {
+
+ val generatedConstantMap = mapOf("key", 123, 10)
+ generatedConstantMap.size should be <= 1
+ assert(generatedConstantMap.keys.forall(_ == "key"))
+ assert(generatedConstantMap.values.forall(_ == 123))
+
+ val generatedMap = mapOf(nextString(10), nextBigDecimal(), 10)
+ assert(generatedMap.keys.forall(_.length <= 10))
+ assert(generatedMap.values.forall(_ >= BigDecimal(0.00)))
+ }
+
+ it should "compose deeply" in {
+
+ val generatedNestedMap = mapOf(nextString(10), nextPair(nextBigDecimal(), nextOption(123)), 10)
+
+ generatedNestedMap.size should be <= 10
+ generatedNestedMap.keySet.size should be <= 10
+ generatedNestedMap.values.size should be <= 10
+ assert(generatedNestedMap.values.forall(value => !value._2.exists(_ != 123)))
+ }
+}
diff --git a/core-rest/src/test/scala/xyz/driver/core/JsonTest.scala b/core-rest/src/test/scala/xyz/driver/core/JsonTest.scala
new file mode 100644
index 0000000..fd693f9
--- /dev/null
+++ b/core-rest/src/test/scala/xyz/driver/core/JsonTest.scala
@@ -0,0 +1,521 @@
+package xyz.driver.core
+
+import java.net.InetAddress
+import java.time.{Instant, LocalDate}
+
+import akka.http.scaladsl.model.Uri
+import akka.http.scaladsl.server.PathMatcher
+import akka.http.scaladsl.server.PathMatcher.Matched
+import com.neovisionaries.i18n.{CountryCode, CurrencyCode}
+import enumeratum._
+import eu.timepit.refined.collection.NonEmpty
+import eu.timepit.refined.numeric.Positive
+import eu.timepit.refined.refineMV
+import org.scalatest.{Inspectors, Matchers, WordSpec}
+import spray.json._
+import xyz.driver.core.TestTypes.CustomGADT
+import xyz.driver.core.auth.AuthCredentials
+import xyz.driver.core.domain.{Email, PhoneNumber}
+import xyz.driver.core.json._
+import xyz.driver.core.json.enumeratum.HasJsonFormat
+import xyz.driver.core.tagging._
+import xyz.driver.core.time.provider.SystemTimeProvider
+import xyz.driver.core.time.{Time, TimeOfDay}
+
+import scala.collection.immutable.IndexedSeq
+import scala.language.postfixOps
+
+class JsonTest extends WordSpec with Matchers with Inspectors {
+ import DefaultJsonProtocol._
+
+ "Json format for Id" should {
+ "read and write correct JSON" in {
+
+ val referenceId = Id[String]("1312-34A")
+
+ val writtenJson = json.idFormat.write(referenceId)
+ writtenJson.prettyPrint should be("\"1312-34A\"")
+
+ val parsedId = json.idFormat.read(writtenJson)
+ parsedId should be(referenceId)
+ }
+ }
+
+ "Json format for @@" should {
+ "read and write correct JSON" in {
+ trait Irrelevant
+ val reference = Id[JsonTest]("SomeID").tagged[Irrelevant]
+
+ val format = json.taggedFormat[Id[JsonTest], Irrelevant]
+
+ val writtenJson = format.write(reference)
+ writtenJson shouldBe JsString("SomeID")
+
+ val parsedId: Id[JsonTest] @@ Irrelevant = format.read(writtenJson)
+ parsedId shouldBe reference
+ }
+
+ "read and write correct JSON when there's an implicit conversion defined" in {
+ val input = " some string "
+
+ JsString(input).convertTo[String @@ Trimmed] shouldBe input.trim()
+
+ val trimmed: String @@ Trimmed = input
+ trimmed.toJson shouldBe JsString(trimmed)
+ }
+ }
+
+ "Json format for Name" should {
+ "read and write correct JSON" in {
+
+ val referenceName = Name[String]("Homer")
+
+ val writtenJson = json.nameFormat.write(referenceName)
+ writtenJson.prettyPrint should be("\"Homer\"")
+
+ val parsedName = json.nameFormat.read(writtenJson)
+ parsedName should be(referenceName)
+ }
+
+ "read and write correct JSON for Name @@ Trimmed" in {
+ trait Irrelevant
+ JsString(" some name ").convertTo[Name[Irrelevant] @@ Trimmed] shouldBe Name[Irrelevant]("some name")
+
+ val trimmed: Name[Irrelevant] @@ Trimmed = Name(" some name ")
+ trimmed.toJson shouldBe JsString("some name")
+ }
+ }
+
+ "Json format for NonEmptyName" should {
+ "read and write correct JSON" in {
+
+ val jsonFormat = json.nonEmptyNameFormat[String]
+
+ val referenceNonEmptyName = NonEmptyName[String](refineMV[NonEmpty]("Homer"))
+
+ val writtenJson = jsonFormat.write(referenceNonEmptyName)
+ writtenJson.prettyPrint should be("\"Homer\"")
+
+ val parsedNonEmptyName = jsonFormat.read(writtenJson)
+ parsedNonEmptyName should be(referenceNonEmptyName)
+ }
+ }
+
+ "Json format for Time" should {
+ "read and write correct JSON" in {
+
+ val referenceTime = new SystemTimeProvider().currentTime()
+
+ val writtenJson = json.timeFormat.write(referenceTime)
+ writtenJson.prettyPrint should be("{\n \"timestamp\": " + referenceTime.millis + "\n}")
+
+ val parsedTime = json.timeFormat.read(writtenJson)
+ parsedTime should be(referenceTime)
+ }
+
+ "read from inputs compatible with Instant" in {
+ val referenceTime = new SystemTimeProvider().currentTime()
+
+ val jsons = Seq(JsNumber(referenceTime.millis), JsString(Instant.ofEpochMilli(referenceTime.millis).toString))
+
+ forAll(jsons) { json =>
+ json.convertTo[Time] shouldBe referenceTime
+ }
+ }
+ }
+
+ "Json format for TimeOfDay" should {
+ "read and write correct JSON" in {
+ val utcTimeZone = java.util.TimeZone.getTimeZone("UTC")
+ val referenceTimeOfDay = TimeOfDay.parseTimeString(utcTimeZone)("08:00:00")
+ val writtenJson = json.timeOfDayFormat.write(referenceTimeOfDay)
+ writtenJson should be("""{"localTime":"08:00:00","timeZone":"UTC"}""".parseJson)
+ val parsed = json.timeOfDayFormat.read(writtenJson)
+ parsed should be(referenceTimeOfDay)
+ }
+ }
+
+ "Json format for Date" should {
+ "read and write correct JSON" in {
+ import date._
+
+ val referenceDate = Date(1941, Month.DECEMBER, 7)
+
+ val writtenJson = json.dateFormat.write(referenceDate)
+ writtenJson.prettyPrint should be("\"1941-12-07\"")
+
+ val parsedDate = json.dateFormat.read(writtenJson)
+ parsedDate should be(referenceDate)
+ }
+ }
+
+ "Json format for java.time.Instant" should {
+
+ val isoString = "2018-08-08T08:08:08.888Z"
+ val instant = Instant.parse(isoString)
+
+ "read correct JSON when value is an epoch milli number" in {
+ JsNumber(instant.toEpochMilli).convertTo[Instant] shouldBe instant
+ }
+
+ "read correct JSON when value is an ISO timestamp string" in {
+ JsString(isoString).convertTo[Instant] shouldBe instant
+ }
+
+ "read correct JSON when value is an object with nested 'timestamp'/millis field" in {
+ val json = JsObject(
+ "timestamp" -> JsNumber(instant.toEpochMilli)
+ )
+
+ json.convertTo[Instant] shouldBe instant
+ }
+
+ "write correct JSON" in {
+ instant.toJson shouldBe JsString(isoString)
+ }
+ }
+
+ "Path matcher for Instant" should {
+
+ val isoString = "2018-08-08T08:08:08.888Z"
+ val instant = Instant.parse(isoString)
+
+ val matcher = PathMatcher("foo") / InstantInPath /
+
+ "read instant from millis" in {
+ matcher(Uri.Path("foo") / ("+" + instant.toEpochMilli) / "bar") shouldBe Matched(Uri.Path("bar"), Tuple1(instant))
+ }
+
+ "read instant from ISO timestamp string" in {
+ matcher(Uri.Path("foo") / isoString / "bar") shouldBe Matched(Uri.Path("bar"), Tuple1(instant))
+ }
+ }
+
+ "Json format for java.time.LocalDate" should {
+
+ "read and write correct JSON" in {
+ val dateString = "2018-08-08"
+ val date = LocalDate.parse(dateString)
+
+ date.toJson shouldBe JsString(dateString)
+ JsString(dateString).convertTo[LocalDate] shouldBe date
+ }
+ }
+
+ "Json format for Revision" should {
+ "read and write correct JSON" in {
+
+ val referenceRevision = Revision[String]("037e2ec0-8901-44ac-8e53-6d39f6479db4")
+
+ val writtenJson = json.revisionFormat.write(referenceRevision)
+ writtenJson.prettyPrint should be("\"" + referenceRevision.id + "\"")
+
+ val parsedRevision = json.revisionFormat.read(writtenJson)
+ parsedRevision should be(referenceRevision)
+ }
+ }
+
+ "Json format for Email" should {
+ "read and write correct JSON" in {
+
+ val referenceEmail = Email("test", "drivergrp.com")
+
+ val writtenJson = json.emailFormat.write(referenceEmail)
+ writtenJson should be("\"test@drivergrp.com\"".parseJson)
+
+ val parsedEmail = json.emailFormat.read(writtenJson)
+ parsedEmail should be(referenceEmail)
+ }
+ }
+
+ "Json format for PhoneNumber" should {
+ "read and write correct JSON" in {
+
+ val referencePhoneNumber = PhoneNumber("1", "4243039608")
+
+ val writtenJson = json.phoneNumberFormat.write(referencePhoneNumber)
+ writtenJson should be("""{"countryCode":"1","number":"4243039608"}""".parseJson)
+
+ val parsedPhoneNumber = json.phoneNumberFormat.read(writtenJson)
+ parsedPhoneNumber should be(referencePhoneNumber)
+ }
+
+ "reject an invalid phone number" in {
+ val phoneJson = """{"countryCode":"1","number":"111-111-1113"}""".parseJson
+
+ intercept[DeserializationException] {
+ json.phoneNumberFormat.read(phoneJson)
+ }.getMessage shouldBe "Invalid phone number"
+ }
+
+ "parse phone number from string" in {
+ JsString("+14243039608").convertTo[PhoneNumber] shouldBe PhoneNumber("1", "4243039608")
+ }
+ }
+
+ "Path matcher for PhoneNumber" should {
+ "read valid phone number" in {
+ val string = "+14243039608x23"
+ val phone = PhoneNumber("1", "4243039608", Some("23"))
+
+ val matcher = PathMatcher("foo") / PhoneInPath
+
+ matcher(Uri.Path("foo") / string / "bar") shouldBe Matched(Uri.Path./("bar"), Tuple1(phone))
+ }
+ }
+
+ "Json format for ADT mappings" should {
+ "read and write correct JSON" in {
+
+ sealed trait EnumVal
+ case object Val1 extends EnumVal
+ case object Val2 extends EnumVal
+ case object Val3 extends EnumVal
+
+ val format = new EnumJsonFormat[EnumVal]("a" -> Val1, "b" -> Val2, "c" -> Val3)
+
+ val referenceEnumValue1 = Val2
+ val referenceEnumValue2 = Val3
+
+ val writtenJson1 = format.write(referenceEnumValue1)
+ writtenJson1.prettyPrint should be("\"b\"")
+
+ val writtenJson2 = format.write(referenceEnumValue2)
+ writtenJson2.prettyPrint should be("\"c\"")
+
+ val parsedEnumValue1 = format.read(writtenJson1)
+ val parsedEnumValue2 = format.read(writtenJson2)
+
+ parsedEnumValue1 should be(referenceEnumValue1)
+ parsedEnumValue2 should be(referenceEnumValue2)
+ }
+ }
+
+ "Json format for Enums (external)" should {
+ "read and write correct JSON" in {
+
+ sealed trait MyEnum extends EnumEntry
+ object MyEnum extends Enum[MyEnum] {
+ case object Val1 extends MyEnum
+ case object `Val 2` extends MyEnum
+ case object `Val/3` extends MyEnum
+
+ val values: IndexedSeq[MyEnum] = findValues
+ }
+
+ val format = new enumeratum.EnumJsonFormat(MyEnum)
+
+ val referenceEnumValue1 = MyEnum.`Val 2`
+ val referenceEnumValue2 = MyEnum.`Val/3`
+
+ val writtenJson1 = format.write(referenceEnumValue1)
+ writtenJson1 shouldBe JsString("Val 2")
+
+ val writtenJson2 = format.write(referenceEnumValue2)
+ writtenJson2 shouldBe JsString("Val/3")
+
+ val parsedEnumValue1 = format.read(writtenJson1)
+ val parsedEnumValue2 = format.read(writtenJson2)
+
+ parsedEnumValue1 shouldBe referenceEnumValue1
+ parsedEnumValue2 shouldBe referenceEnumValue2
+
+ intercept[DeserializationException] {
+ format.read(JsString("Val4"))
+ }.getMessage shouldBe "Unexpected value Val4. Expected one of: [Val1, Val 2, Val/3]"
+ }
+ }
+
+ "Json format for Enums (automatic)" should {
+ "read and write correct JSON and not require import" in {
+
+ sealed trait MyEnum extends EnumEntry
+ object MyEnum extends Enum[MyEnum] with HasJsonFormat[MyEnum] {
+ case object Val1 extends MyEnum
+ case object `Val 2` extends MyEnum
+ case object `Val/3` extends MyEnum
+
+ val values: IndexedSeq[MyEnum] = findValues
+ }
+
+ val referenceEnumValue1: MyEnum = MyEnum.`Val 2`
+ val referenceEnumValue2: MyEnum = MyEnum.`Val/3`
+
+ val writtenJson1 = referenceEnumValue1.toJson
+ writtenJson1 shouldBe JsString("Val 2")
+
+ val writtenJson2 = referenceEnumValue2.toJson
+ writtenJson2 shouldBe JsString("Val/3")
+
+ import spray.json._
+
+ val parsedEnumValue1 = writtenJson1.prettyPrint.parseJson.convertTo[MyEnum]
+ val parsedEnumValue2 = writtenJson2.prettyPrint.parseJson.convertTo[MyEnum]
+
+ parsedEnumValue1 should be(referenceEnumValue1)
+ parsedEnumValue2 should be(referenceEnumValue2)
+
+ intercept[DeserializationException] {
+ JsString("Val4").convertTo[MyEnum]
+ }.getMessage shouldBe "Unexpected value Val4. Expected one of: [Val1, Val 2, Val/3]"
+ }
+ }
+
+ // Should be defined outside of case to have a TypeTag
+ case class CustomWrapperClass(value: Int)
+
+ "Json format for Value classes" should {
+ "read and write correct JSON" in {
+
+ val format = new ValueClassFormat[CustomWrapperClass](v => BigDecimal(v.value), d => CustomWrapperClass(d.toInt))
+
+ val referenceValue1 = CustomWrapperClass(-2)
+ val referenceValue2 = CustomWrapperClass(10)
+
+ val writtenJson1 = format.write(referenceValue1)
+ writtenJson1.prettyPrint should be("-2")
+
+ val writtenJson2 = format.write(referenceValue2)
+ writtenJson2.prettyPrint should be("10")
+
+ val parsedValue1 = format.read(writtenJson1)
+ val parsedValue2 = format.read(writtenJson2)
+
+ parsedValue1 should be(referenceValue1)
+ parsedValue2 should be(referenceValue2)
+ }
+ }
+
+ "Json format for classes GADT" should {
+ "read and write correct JSON" in {
+
+ import CustomGADT._
+ import DefaultJsonProtocol._
+ implicit val case1Format = jsonFormat1(GadtCase1)
+ implicit val case2Format = jsonFormat1(GadtCase2)
+ implicit val case3Format = jsonFormat1(GadtCase3)
+
+ val format = GadtJsonFormat.create[CustomGADT]("gadtTypeField") {
+ case _: CustomGADT.GadtCase1 => "case1"
+ case _: CustomGADT.GadtCase2 => "case2"
+ case _: CustomGADT.GadtCase3 => "case3"
+ } {
+ case "case1" => case1Format
+ case "case2" => case2Format
+ case "case3" => case3Format
+ }
+
+ val referenceValue1 = CustomGADT.GadtCase1("4")
+ val referenceValue2 = CustomGADT.GadtCase2("Hi!")
+
+ val writtenJson1 = format.write(referenceValue1)
+ writtenJson1 should be("{\n \"field\": \"4\",\n\"gadtTypeField\": \"case1\"\n}".parseJson)
+
+ val writtenJson2 = format.write(referenceValue2)
+ writtenJson2 should be("{\"field\":\"Hi!\",\"gadtTypeField\":\"case2\"}".parseJson)
+
+ val parsedValue1 = format.read(writtenJson1)
+ val parsedValue2 = format.read(writtenJson2)
+
+ parsedValue1 should be(referenceValue1)
+ parsedValue2 should be(referenceValue2)
+ }
+ }
+
+ "Json format for a Refined value" should {
+ "read and write correct JSON" in {
+
+ val jsonFormat = json.refinedJsonFormat[Int, Positive]
+
+ val referenceRefinedNumber = refineMV[Positive](42)
+
+ val writtenJson = jsonFormat.write(referenceRefinedNumber)
+ writtenJson should be("42".parseJson)
+
+ val parsedRefinedNumber = jsonFormat.read(writtenJson)
+ parsedRefinedNumber should be(referenceRefinedNumber)
+ }
+ }
+
+ "InetAddress format" should {
+ "read and write correct JSON" in {
+ val address = InetAddress.getByName("127.0.0.1")
+ val json = inetAddressFormat.write(address)
+
+ json shouldBe JsString("127.0.0.1")
+
+ val parsed = inetAddressFormat.read(json)
+ parsed shouldBe address
+ }
+
+ "throw a DeserializationException for an invalid IP Address" in {
+ assertThrows[DeserializationException] {
+ val invalidAddress = JsString("foobar:")
+ inetAddressFormat.read(invalidAddress)
+ }
+ }
+ }
+
+ "AuthCredentials format" should {
+ "read and write correct JSON" in {
+ val email = Email("someone", "noehere.com")
+ val phoneId = PhoneNumber.parse("1 207 8675309")
+ val password = "nopassword"
+
+ phoneId.isDefined should be(true) // test this real quick
+
+ val emailAuth = AuthCredentials(email.toString, password)
+ val pnAuth = AuthCredentials(phoneId.get.toString, password)
+
+ val emailWritten = authCredentialsFormat.write(emailAuth)
+ emailWritten should be("""{"identifier":"someone@noehere.com","password":"nopassword"}""".parseJson)
+
+ val phoneWritten = authCredentialsFormat.write(pnAuth)
+ phoneWritten should be("""{"identifier":"+1 2078675309","password":"nopassword"}""".parseJson)
+
+ val identifierEmailParsed =
+ authCredentialsFormat.read("""{"identifier":"someone@nowhere.com","password":"nopassword"}""".parseJson)
+ var written = authCredentialsFormat.write(identifierEmailParsed)
+ written should be("{\"identifier\":\"someone@nowhere.com\",\"password\":\"nopassword\"}".parseJson)
+
+ val emailEmailParsed =
+ authCredentialsFormat.read("""{"email":"someone@nowhere.com","password":"nopassword"}""".parseJson)
+ written = authCredentialsFormat.write(emailEmailParsed)
+ written should be("{\"identifier\":\"someone@nowhere.com\",\"password\":\"nopassword\"}".parseJson)
+
+ }
+ }
+
+ "CountryCode format" should {
+ "read and write correct JSON" in {
+ val samples = Seq(
+ "US" -> CountryCode.US,
+ "CN" -> CountryCode.CN,
+ "AT" -> CountryCode.AT
+ )
+
+ forAll(samples) {
+ case (serialized, enumValue) =>
+ countryCodeFormat.write(enumValue) shouldBe JsString(serialized)
+ countryCodeFormat.read(JsString(serialized)) shouldBe enumValue
+ }
+ }
+ }
+
+ "CurrencyCode format" should {
+ "read and write correct JSON" in {
+ val samples = Seq(
+ "USD" -> CurrencyCode.USD,
+ "CNY" -> CurrencyCode.CNY,
+ "EUR" -> CurrencyCode.EUR
+ )
+
+ forAll(samples) {
+ case (serialized, enumValue) =>
+ currencyCodeFormat.write(enumValue) shouldBe JsString(serialized)
+ currencyCodeFormat.read(JsString(serialized)) shouldBe enumValue
+ }
+ }
+ }
+
+}
diff --git a/core-rest/src/test/scala/xyz/driver/core/TestTypes.scala b/core-rest/src/test/scala/xyz/driver/core/TestTypes.scala
new file mode 100644
index 0000000..bb25deb
--- /dev/null
+++ b/core-rest/src/test/scala/xyz/driver/core/TestTypes.scala
@@ -0,0 +1,14 @@
+package xyz.driver.core
+
+object TestTypes {
+
+ sealed trait CustomGADT {
+ val field: String
+ }
+
+ object CustomGADT {
+ final case class GadtCase1(field: String) extends CustomGADT
+ final case class GadtCase2(field: String) extends CustomGADT
+ final case class GadtCase3(field: String) extends CustomGADT
+ }
+}
diff --git a/core-rest/src/test/scala/xyz/driver/core/rest/DriverAppTest.scala b/core-rest/src/test/scala/xyz/driver/core/rest/DriverAppTest.scala
new file mode 100644
index 0000000..324c8d8
--- /dev/null
+++ b/core-rest/src/test/scala/xyz/driver/core/rest/DriverAppTest.scala
@@ -0,0 +1,89 @@
+package xyz.driver.core.rest
+
+import akka.http.scaladsl.model.headers._
+import akka.http.scaladsl.model.{HttpMethod, StatusCodes}
+import akka.http.scaladsl.server.{Directives, Route}
+import akka.http.scaladsl.testkit.ScalatestRouteTest
+import com.typesafe.config.ConfigFactory
+import org.scalatest.{AsyncFlatSpec, Matchers}
+import xyz.driver.core.app.{DriverApp, SimpleModule}
+
+class DriverAppTest extends AsyncFlatSpec with ScalatestRouteTest with Matchers with Directives {
+ val config = ConfigFactory.parseString("""
+ |application {
+ | cors {
+ | allowedOrigins: ["example.com"]
+ | }
+ |}
+ """.stripMargin).withFallback(ConfigFactory.load)
+
+ val origin = Origin(HttpOrigin("https", Host("example.com")))
+ val allowedOrigins = Set(HttpOrigin("https", Host("example.com")))
+ val allowedMethods: collection.immutable.Seq[HttpMethod] = {
+ import akka.http.scaladsl.model.HttpMethods._
+ collection.immutable.Seq(GET, PUT, POST, PATCH, DELETE, OPTIONS, TRACE)
+ }
+
+ import scala.reflect.runtime.universe.typeOf
+ class TestApp(testRoute: Route)
+ extends DriverApp(
+ appName = "test-app",
+ version = "0.0.1",
+ gitHash = "deadb33f",
+ modules = Seq(new SimpleModule("test-module", theRoute = testRoute, routeType = typeOf[DriverApp])),
+ config = config,
+ log = xyz.driver.core.logging.NoLogger
+ )
+
+ it should "respond with the correct CORS headers for the swagger OPTIONS route" in {
+ val route = new TestApp(get(complete(StatusCodes.OK)))
+ Options(s"/api-docs/swagger.json").withHeaders(origin) ~> route.appRoute ~> check {
+ status shouldBe StatusCodes.OK
+ headers should contain(`Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigins.toSeq: _*)))
+ header[`Access-Control-Allow-Methods`].get.methods should contain theSameElementsAs allowedMethods
+ }
+ }
+
+ it should "respond with the correct CORS headers for the test route" in {
+ val route = new TestApp(get(complete(StatusCodes.OK)))
+ Get(s"/api/v1/test").withHeaders(origin) ~> route.appRoute ~> check {
+ status shouldBe StatusCodes.OK
+ headers should contain(`Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigins.toSeq: _*)))
+ }
+ }
+
+ it should "respond with the correct CORS headers for a concatenated route" in {
+ val route = new TestApp(get(complete(StatusCodes.OK)) ~ post(complete(StatusCodes.OK)))
+ Post(s"/api/v1/test").withHeaders(origin) ~> route.appRoute ~> check {
+ status shouldBe StatusCodes.OK
+ headers should contain(`Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigins.toSeq: _*)))
+ }
+ }
+
+ it should "allow subdomains of allowed origin suffixes" in {
+ val route = new TestApp(get(complete(StatusCodes.OK)))
+ Get(s"/api/v1/test")
+ .withHeaders(Origin(HttpOrigin("https", Host("foo.example.com")))) ~> route.appRoute ~> check {
+ status shouldBe StatusCodes.OK
+ headers should contain(`Access-Control-Allow-Origin`(HttpOrigin("https", Host("foo.example.com"))))
+ }
+ }
+
+ it should "respond with default domains for invalid origins" in {
+ val route = new TestApp(get(complete(StatusCodes.OK)))
+ Get(s"/api/v1/test")
+ .withHeaders(Origin(HttpOrigin("https", Host("invalid.foo.bar.com")))) ~> route.appRoute ~> check {
+ status shouldBe StatusCodes.OK
+ headers should contain(`Access-Control-Allow-Origin`(HttpOriginRange.*))
+ }
+ }
+
+ it should "respond with Pragma and Cache-Control (no-cache) headers" in {
+ val route = new TestApp(get(complete(StatusCodes.OK)))
+ Get(s"/api/v1/test") ~> route.appRoute ~> check {
+ status shouldBe StatusCodes.OK
+ header("Pragma").map(_.value()) should contain("no-cache")
+ header[`Cache-Control`].map(_.value()) should contain("no-cache")
+ }
+ }
+}
diff --git a/core-rest/src/test/scala/xyz/driver/core/rest/DriverRouteTest.scala b/core-rest/src/test/scala/xyz/driver/core/rest/DriverRouteTest.scala
new file mode 100644
index 0000000..cc0019a
--- /dev/null
+++ b/core-rest/src/test/scala/xyz/driver/core/rest/DriverRouteTest.scala
@@ -0,0 +1,121 @@
+package xyz.driver.core.rest
+
+import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport
+import akka.http.scaladsl.model.StatusCodes
+import akka.http.scaladsl.model.headers.Connection
+import akka.http.scaladsl.server.Directives.{complete => akkaComplete}
+import akka.http.scaladsl.server.{Directives, RejectionHandler, Route}
+import akka.http.scaladsl.testkit.ScalatestRouteTest
+import com.typesafe.scalalogging.Logger
+import org.scalatest.{AsyncFlatSpec, Matchers}
+import xyz.driver.core.json.serviceExceptionFormat
+import xyz.driver.core.logging.NoLogger
+import xyz.driver.core.rest.errors._
+
+import scala.concurrent.Future
+
+class DriverRouteTest
+ extends AsyncFlatSpec with ScalatestRouteTest with SprayJsonSupport with Matchers with Directives {
+ class TestRoute(override val route: Route) extends DriverRoute {
+ override def log: Logger = NoLogger
+ }
+
+ "DriverRoute" should "respond with 200 OK for a basic route" in {
+ val route = new TestRoute(akkaComplete(StatusCodes.OK))
+
+ Get("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check {
+ handled shouldBe true
+ status shouldBe StatusCodes.OK
+ }
+ }
+
+ it should "respond with a 401 for an InvalidInputException" in {
+ val route = new TestRoute(akkaComplete(Future.failed[String](InvalidInputException())))
+
+ Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check {
+ handled shouldBe true
+ status shouldBe StatusCodes.BadRequest
+ responseAs[ServiceException] shouldBe InvalidInputException()
+ }
+ }
+
+ it should "respond with a 403 for InvalidActionException" in {
+ val route = new TestRoute(akkaComplete(Future.failed[String](InvalidActionException())))
+
+ Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check {
+ handled shouldBe true
+ status shouldBe StatusCodes.Forbidden
+ responseAs[ServiceException] shouldBe InvalidActionException()
+ }
+ }
+
+ it should "respond with a 404 for ResourceNotFoundException" in {
+ val route = new TestRoute(akkaComplete(Future.failed[String](ResourceNotFoundException())))
+
+ Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check {
+ handled shouldBe true
+ status shouldBe StatusCodes.NotFound
+ responseAs[ServiceException] shouldBe ResourceNotFoundException()
+ }
+ }
+
+ it should "respond with a 500 for ExternalServiceException" in {
+ val error = ExternalServiceException("GET /api/v1/users/", "Permission denied", None)
+ val route = new TestRoute(akkaComplete(Future.failed[String](error)))
+
+ Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check {
+ handled shouldBe true
+ status shouldBe StatusCodes.InternalServerError
+ responseAs[ServiceException] shouldBe error
+ }
+ }
+
+ it should "allow pass-through of external service exceptions" in {
+ val innerError = InvalidInputException()
+ val error = ExternalServiceException("GET /api/v1/users/", "Permission denied", Some(innerError))
+ val future = Future.failed[String](error)
+ val route = new TestRoute(akkaComplete(future.passThroughExternalServiceException))
+
+ Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check {
+ handled shouldBe true
+ status shouldBe StatusCodes.BadRequest
+ responseAs[ServiceException] shouldBe innerError
+ }
+ }
+
+ it should "respond with a 503 for ExternalServiceTimeoutException" in {
+ val error = ExternalServiceTimeoutException("GET /api/v1/users/")
+ val route = new TestRoute(akkaComplete(Future.failed[String](error)))
+
+ Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check {
+ handled shouldBe true
+ status shouldBe StatusCodes.GatewayTimeout
+ responseAs[ServiceException] shouldBe error
+ }
+ }
+
+ it should "respond with a 500 for DatabaseException" in {
+ val route = new TestRoute(akkaComplete(Future.failed[String](DatabaseException())))
+
+ Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check {
+ handled shouldBe true
+ status shouldBe StatusCodes.InternalServerError
+ responseAs[ServiceException] shouldBe DatabaseException()
+ }
+ }
+
+ it should "add a `Connection: close` header to avoid clashing with envoy's timeouts" in {
+ val rejectionHandler = RejectionHandler.newBuilder().handleNotFound(complete(StatusCodes.NotFound)).result()
+ val route = new TestRoute(handleRejections(rejectionHandler)((get & path("foo"))(complete("OK"))))
+
+ Get("/foo") ~> route.routeWithDefaults ~> check {
+ status shouldBe StatusCodes.OK
+ headers should contain(Connection("close"))
+ }
+
+ Get("/bar") ~> route.routeWithDefaults ~> check {
+ status shouldBe StatusCodes.NotFound
+ headers should contain(Connection("close"))
+ }
+ }
+}
diff --git a/core-rest/src/test/scala/xyz/driver/core/rest/PatchDirectivesTest.scala b/core-rest/src/test/scala/xyz/driver/core/rest/PatchDirectivesTest.scala
new file mode 100644
index 0000000..987717d
--- /dev/null
+++ b/core-rest/src/test/scala/xyz/driver/core/rest/PatchDirectivesTest.scala
@@ -0,0 +1,101 @@
+package xyz.driver.core.rest
+
+import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport
+import akka.http.scaladsl.model._
+import akka.http.scaladsl.model.headers.`Content-Type`
+import akka.http.scaladsl.server.{Directives, Route}
+import akka.http.scaladsl.testkit.ScalatestRouteTest
+import org.scalatest.{FlatSpec, Matchers}
+import spray.json._
+import xyz.driver.core.{Id, Name}
+import xyz.driver.core.json._
+
+import scala.concurrent.Future
+
+class PatchDirectivesTest
+ extends FlatSpec with Matchers with ScalatestRouteTest with SprayJsonSupport with DefaultJsonProtocol
+ with Directives with PatchDirectives {
+ case class Bar(name: Name[Bar], size: Int)
+ case class Foo(id: Id[Foo], name: Name[Foo], rank: Int, bar: Option[Bar])
+ implicit val barFormat: RootJsonFormat[Bar] = jsonFormat2(Bar)
+ implicit val fooFormat: RootJsonFormat[Foo] = jsonFormat4(Foo)
+
+ val testFoo: Foo = Foo(Id("1"), Name(s"Foo"), 1, Some(Bar(Name("Bar"), 10)))
+
+ def route(retrieve: => Future[Option[Foo]]): Route =
+ Route.seal(path("api" / "v1" / "foos" / IdInPath[Foo]) { fooId =>
+ entity(as[Patchable[Foo]]) { fooPatchable =>
+ mergePatch(fooPatchable, retrieve) { updatedFoo =>
+ complete(updatedFoo)
+ }
+ }
+ })
+
+ val MergePatchContentType = ContentType(`application/merge-patch+json`)
+ val ContentTypeHeader = `Content-Type`(MergePatchContentType)
+ def jsonEntity(json: String, contentType: ContentType.NonBinary = MergePatchContentType): RequestEntity =
+ HttpEntity(contentType, json)
+
+ "PatchSupport" should "allow partial updates to an existing object" in {
+ val fooRetrieve = Future.successful(Some(testFoo))
+
+ Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""")) ~> route(fooRetrieve) ~> check {
+ handled shouldBe true
+ responseAs[Foo] shouldBe testFoo.copy(rank = 4)
+ }
+ }
+
+ it should "merge deeply nested objects" in {
+ val fooRetrieve = Future.successful(Some(testFoo))
+
+ Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4, "bar": {"name": "My Bar"}}""")) ~> route(fooRetrieve) ~> check {
+ handled shouldBe true
+ responseAs[Foo] shouldBe testFoo.copy(rank = 4, bar = Some(Bar(Name("My Bar"), 10)))
+ }
+ }
+
+ it should "return a 404 if the object is not found" in {
+ val fooRetrieve = Future.successful(None)
+
+ Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""")) ~> route(fooRetrieve) ~> check {
+ handled shouldBe true
+ status shouldBe StatusCodes.NotFound
+ }
+ }
+
+ it should "handle nulls on optional values correctly" in {
+ val fooRetrieve = Future.successful(Some(testFoo))
+
+ Patch("/api/v1/foos/1", jsonEntity("""{"bar": null}""")) ~> route(fooRetrieve) ~> check {
+ handled shouldBe true
+ responseAs[Foo] shouldBe testFoo.copy(bar = None)
+ }
+ }
+
+ it should "handle optional values correctly when old value is null" in {
+ val fooRetrieve = Future.successful(Some(testFoo.copy(bar = None)))
+
+ Patch("/api/v1/foos/1", jsonEntity("""{"bar": {"name": "My Bar","size":10}}""")) ~> route(fooRetrieve) ~> check {
+ handled shouldBe true
+ responseAs[Foo] shouldBe testFoo.copy(bar = Some(Bar(Name("My Bar"), 10)))
+ }
+ }
+
+ it should "return a 400 for nulls on non-optional values" in {
+ val fooRetrieve = Future.successful(Some(testFoo))
+
+ Patch("/api/v1/foos/1", jsonEntity("""{"rank": null}""")) ~> route(fooRetrieve) ~> check {
+ handled shouldBe true
+ status shouldBe StatusCodes.BadRequest
+ }
+ }
+
+ it should "return a 415 for incorrect Content-Type" in {
+ val fooRetrieve = Future.successful(Some(testFoo))
+
+ Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""", ContentTypes.`application/json`)) ~> route(fooRetrieve) ~> check {
+ status shouldBe StatusCodes.UnsupportedMediaType
+ responseAs[String] should include("application/merge-patch+json")
+ }
+ }
+}
diff --git a/core-rest/src/test/scala/xyz/driver/core/rest/RestTest.scala b/core-rest/src/test/scala/xyz/driver/core/rest/RestTest.scala
new file mode 100644
index 0000000..19e4ed1
--- /dev/null
+++ b/core-rest/src/test/scala/xyz/driver/core/rest/RestTest.scala
@@ -0,0 +1,151 @@
+package xyz.driver.core.rest
+
+import akka.http.scaladsl.model.StatusCodes
+import akka.http.scaladsl.server.{Directives, Route, ValidationRejection}
+import akka.http.scaladsl.testkit.ScalatestRouteTest
+import akka.util.ByteString
+import org.scalatest.{Matchers, WordSpec}
+import xyz.driver.core.rest
+
+import scala.concurrent.Future
+import scala.util.Random
+
+class RestTest extends WordSpec with Matchers with ScalatestRouteTest with Directives {
+ "`escapeScriptTags` function" should {
+ "escape script tags properly" in {
+ val dirtyString = "</sc----</sc----</sc"
+ val cleanString = "--------------------"
+
+ (escapeScriptTags(ByteString(dirtyString)).utf8String) should be(dirtyString.replace("</sc", "< /sc"))
+
+ (escapeScriptTags(ByteString(cleanString)).utf8String) should be(cleanString)
+ }
+ }
+
+ "paginated directive" should {
+ val route: Route = rest.paginated { paginated =>
+ complete(StatusCodes.OK -> s"${paginated.pageNumber},${paginated.pageSize}")
+ }
+ "accept a pagination" in {
+ Get("/?pageNumber=2&pageSize=42") ~> route ~> check {
+ assert(status == StatusCodes.OK)
+ assert(entityAs[String] == "2,42")
+ }
+ }
+ "provide a default pagination" in {
+ Get("/") ~> route ~> check {
+ assert(status == StatusCodes.OK)
+ assert(entityAs[String] == "1,100")
+ }
+ }
+ "provide default values for a partial pagination" in {
+ Get("/?pageSize=2") ~> route ~> check {
+ assert(status == StatusCodes.OK)
+ assert(entityAs[String] == "1,2")
+ }
+ }
+ "reject an invalid pagination" in {
+ Get("/?pageNumber=-1") ~> route ~> check {
+ assert(rejection.isInstanceOf[ValidationRejection])
+ }
+ }
+ }
+
+ "optional paginated directive" should {
+ val route: Route = rest.optionalPagination { paginated =>
+ complete(StatusCodes.OK -> paginated.map(p => s"${p.pageNumber},${p.pageSize}").getOrElse("no pagination"))
+ }
+ "accept a pagination" in {
+ Get("/?pageNumber=2&pageSize=42") ~> route ~> check {
+ assert(status == StatusCodes.OK)
+ assert(entityAs[String] == "2,42")
+ }
+ }
+ "without pagination" in {
+ Get("/") ~> route ~> check {
+ assert(status == StatusCodes.OK)
+ assert(entityAs[String] == "no pagination")
+ }
+ }
+ "reject an invalid pagination" in {
+ Get("/?pageNumber=1") ~> route ~> check {
+ assert(rejection.isInstanceOf[ValidationRejection])
+ }
+ }
+ }
+
+ "completeWithPagination directive" when {
+ import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport._
+ import spray.json.DefaultJsonProtocol._
+
+ val data = Seq.fill(103)(Random.alphanumeric.take(10).mkString)
+ val route: Route =
+ parameter('empty.as[Boolean] ? false) { isEmpty =>
+ completeWithPagination[String] {
+ case Some(pagination) if isEmpty =>
+ Future.successful(ListResponse(Seq(), 0, Some(pagination)))
+ case Some(pagination) =>
+ val filtered = data.slice(pagination.offset, pagination.offset + pagination.pageSize)
+ Future.successful(ListResponse(filtered, data.size, Some(pagination)))
+ case None if isEmpty => Future.successful(ListResponse(Seq(), 0, None))
+ case None => Future.successful(ListResponse(data, data.size, None))
+ }
+ }
+
+ "pagination is defined" should {
+ "return a response with pagination headers" in {
+ Get("/?pageNumber=2&pageSize=10") ~> route ~> check {
+ responseAs[Seq[String]] shouldBe data.slice(10, 20)
+ header(ContextHeaders.ResourceCount).map(_.value) should contain("103")
+ header(ContextHeaders.PageCount).map(_.value) should contain("11")
+ }
+ }
+
+ "disallow pageSize <= 0" in {
+ Get("/?pageNumber=2&pageSize=0") ~> route ~> check {
+ rejection shouldBe a[ValidationRejection]
+ }
+
+ Get("/?pageNumber=2&pageSize=-1") ~> route ~> check {
+ rejection shouldBe a[ValidationRejection]
+ }
+ }
+
+ "disallow pageNumber <= 0" in {
+ Get("/?pageNumber=0&pageSize=10") ~> route ~> check {
+ rejection shouldBe a[ValidationRejection]
+ }
+
+ Get("/?pageNumber=-1&pageSize=10") ~> route ~> check {
+ rejection shouldBe a[ValidationRejection]
+ }
+ }
+
+ "return PageCount == 0 if returning an empty list" in {
+ Get("/?empty=true&pageNumber=2&pageSize=10") ~> route ~> check {
+ responseAs[Seq[String]] shouldBe empty
+ header(ContextHeaders.ResourceCount).map(_.value) should contain("0")
+ header(ContextHeaders.PageCount).map(_.value) should contain("0")
+ }
+ }
+ }
+
+ "pagination is not defined" should {
+ "return a response with pagination headers and PageCount == 1" in {
+ Get("/") ~> route ~> check {
+ responseAs[Seq[String]] shouldBe data
+ header(ContextHeaders.ResourceCount).map(_.value) should contain("103")
+ header(ContextHeaders.PageCount).map(_.value) should contain("1")
+ }
+ }
+
+ "return PageCount == 0 if returning an empty list" in {
+ Get("/?empty=true") ~> route ~> check {
+ responseAs[Seq[String]] shouldBe empty
+ header(ContextHeaders.ResourceCount).map(_.value) should contain("0")
+ header(ContextHeaders.PageCount).map(_.value) should contain("0")
+ }
+ }
+ }
+ }
+}