From 4d1197099ce4e721c18bf4cacbb2e1980e4210b5 Mon Sep 17 00:00:00 2001 From: Jakob Odersky Date: Wed, 12 Sep 2018 16:40:57 -0700 Subject: Move REST functionality to separate project --- .../src/test/scala/xyz/driver/core/AuthTest.scala | 165 +++++++ .../scala/xyz/driver/core/GeneratorsTest.scala | 264 +++++++++++ .../src/test/scala/xyz/driver/core/JsonTest.scala | 521 +++++++++++++++++++++ .../src/test/scala/xyz/driver/core/TestTypes.scala | 14 + .../scala/xyz/driver/core/rest/DriverAppTest.scala | 89 ++++ .../xyz/driver/core/rest/DriverRouteTest.scala | 121 +++++ .../xyz/driver/core/rest/PatchDirectivesTest.scala | 101 ++++ .../test/scala/xyz/driver/core/rest/RestTest.scala | 151 ++++++ 8 files changed, 1426 insertions(+) create mode 100644 core-rest/src/test/scala/xyz/driver/core/AuthTest.scala create mode 100644 core-rest/src/test/scala/xyz/driver/core/GeneratorsTest.scala create mode 100644 core-rest/src/test/scala/xyz/driver/core/JsonTest.scala create mode 100644 core-rest/src/test/scala/xyz/driver/core/TestTypes.scala create mode 100644 core-rest/src/test/scala/xyz/driver/core/rest/DriverAppTest.scala create mode 100644 core-rest/src/test/scala/xyz/driver/core/rest/DriverRouteTest.scala create mode 100644 core-rest/src/test/scala/xyz/driver/core/rest/PatchDirectivesTest.scala create mode 100644 core-rest/src/test/scala/xyz/driver/core/rest/RestTest.scala (limited to 'core-rest/src/test') diff --git a/core-rest/src/test/scala/xyz/driver/core/AuthTest.scala b/core-rest/src/test/scala/xyz/driver/core/AuthTest.scala new file mode 100644 index 0000000..2e772fb --- /dev/null +++ b/core-rest/src/test/scala/xyz/driver/core/AuthTest.scala @@ -0,0 +1,165 @@ +package xyz.driver.core + +import akka.http.scaladsl.model.headers.{ + HttpChallenges, + OAuth2BearerToken, + RawHeader, + Authorization => AkkaAuthorization +} +import akka.http.scaladsl.server.Directives._ +import akka.http.scaladsl.server._ +import akka.http.scaladsl.testkit.ScalatestRouteTest +import org.scalatest.{FlatSpec, Matchers} +import pdi.jwt.{Jwt, JwtAlgorithm} +import xyz.driver.core.auth._ +import xyz.driver.core.domain.Email +import xyz.driver.core.logging._ +import xyz.driver.core.rest._ +import xyz.driver.core.rest.auth._ +import xyz.driver.core.time.Time + +import scala.concurrent.Future +import scalaz.OptionT + +class AuthTest extends FlatSpec with Matchers with ScalatestRouteTest { + + case object TestRoleAllowedPermission extends Permission + case object TestRoleAllowedByTokenPermission extends Permission + case object TestRoleNotAllowedPermission extends Permission + + val TestRole = Role(Id("1"), Name("testRole")) + + val (publicKey, privateKey) = { + import java.security.KeyPairGenerator + + val keygen = KeyPairGenerator.getInstance("RSA") + keygen.initialize(2048) + + val keyPair = keygen.generateKeyPair() + (keyPair.getPublic, keyPair.getPrivate) + } + + val basicAuthorization: Authorization[User] = new Authorization[User] { + + override def userHasPermissions(user: User, permissions: Seq[Permission])( + implicit ctx: ServiceRequestContext): Future[AuthorizationResult] = { + val authorized = permissions.map(p => p -> (p === TestRoleAllowedPermission)).toMap + Future.successful(AuthorizationResult(authorized, ctx.permissionsToken)) + } + } + + val tokenIssuer = "users" + val tokenAuthorization = new CachedTokenAuthorization[User](publicKey, tokenIssuer) + + val authorization = new ChainedAuthorization[User](tokenAuthorization, basicAuthorization) + + val authStatusService = new AuthProvider[User](authorization, NoLogger) { + override def authenticatedUser(implicit ctx: ServiceRequestContext): OptionT[Future, User] = + OptionT.optionT[Future] { + if (ctx.contextHeaders.keySet.contains(AuthProvider.AuthenticationTokenHeader)) { + Future.successful( + Some( + AuthTokenUserInfo( + Id[User]("1"), + Email("foo", "bar"), + emailVerified = true, + audience = "driver", + roles = Set(TestRole), + expirationTime = Time(1000000L) + ))) + } else { + Future.successful(Option.empty[User]) + } + } + } + + import authStatusService._ + + "'authorize' directive" should "throw error if auth token is not in the request" in { + + Get("/naive/attempt") ~> + authorize(TestRoleAllowedPermission) { user => + complete("Never going to be here") + } ~> + check { + // handled shouldBe false + rejections should contain( + AuthenticationFailedRejection( + AuthenticationFailedRejection.CredentialsMissing, + HttpChallenges.oAuth2(authStatusService.realm))) + } + } + + it should "throw error if authorized user does not have the requested permission" in { + + val referenceAuthToken = AuthToken("I am a test role's token") + val referenceAuthHeader = AkkaAuthorization(OAuth2BearerToken(referenceAuthToken.value)) + + Post("/administration/attempt").addHeader( + referenceAuthHeader + ) ~> + authorize(TestRoleNotAllowedPermission) { user => + complete("Never going to get here") + } ~> + check { + handled shouldBe false + rejections should contain(AuthorizationFailedRejection) + } + } + + it should "pass and retrieve the token to client code, if token is in request and user has permission" in { + val referenceAuthToken = AuthToken("I am token") + val referenceAuthHeader = AkkaAuthorization(OAuth2BearerToken(referenceAuthToken.value)) + + Get("/valid/attempt/?a=2&b=5").addHeader( + referenceAuthHeader + ) ~> + authorize(TestRoleAllowedPermission) { ctx => + complete(s"Alright, user ${ctx.authenticatedUser.id} is authorized") + } ~> + check { + handled shouldBe true + responseAs[String] shouldBe "Alright, user 1 is authorized" + } + } + + it should "authenticate correctly even without the 'Bearer' prefix on the Authorization header" in { + val referenceAuthToken = AuthToken("unprefixed_token") + + Get("/valid/attempt/?a=2&b=5").addHeader( + RawHeader(ContextHeaders.AuthenticationTokenHeader, referenceAuthToken.value) + ) ~> + authorize(TestRoleAllowedPermission) { ctx => + complete(s"Alright, user ${ctx.authenticatedUser.id} is authorized") + } ~> + check { + handled shouldBe true + responseAs[String] shouldBe "Alright, user 1 is authorized" + } + } + + it should "authorize permission found in permissions token" in { + import spray.json._ + + val claim = JsObject( + Map( + "iss" -> JsString(tokenIssuer), + "sub" -> JsString("1"), + "permissions" -> JsObject(Map(TestRoleAllowedByTokenPermission.toString -> JsBoolean(true))) + )).prettyPrint + val permissionsToken = PermissionsToken(Jwt.encode(claim, privateKey, JwtAlgorithm.RS256)) + val referenceAuthToken = AuthToken("I am token") + val referenceAuthHeader = AkkaAuthorization(OAuth2BearerToken(referenceAuthToken.value)) + + Get("/alic/attempt/?a=2&b=5") + .addHeader(referenceAuthHeader) + .addHeader(RawHeader(AuthProvider.PermissionsTokenHeader, permissionsToken.value)) ~> + authorize(TestRoleAllowedByTokenPermission) { ctx => + complete(s"Alright, user ${ctx.authenticatedUser.id} is authorized by permissions token") + } ~> + check { + handled shouldBe true + responseAs[String] shouldBe "Alright, user 1 is authorized by permissions token" + } + } +} diff --git a/core-rest/src/test/scala/xyz/driver/core/GeneratorsTest.scala b/core-rest/src/test/scala/xyz/driver/core/GeneratorsTest.scala new file mode 100644 index 0000000..7e740a4 --- /dev/null +++ b/core-rest/src/test/scala/xyz/driver/core/GeneratorsTest.scala @@ -0,0 +1,264 @@ +package xyz.driver.core + +import org.scalatest.{Assertions, FlatSpec, Matchers} + +import scala.collection.immutable.IndexedSeq + +class GeneratorsTest extends FlatSpec with Matchers with Assertions { + import generators._ + + "Generators" should "be able to generate com.drivergrp.core.Id identifiers" in { + + val generatedId1 = nextId[String]() + val generatedId2 = nextId[String]() + val generatedId3 = nextId[Long]() + + generatedId1.length should be >= 0 + generatedId2.length should be >= 0 + generatedId3.length should be >= 0 + generatedId1 should not be generatedId2 + generatedId2 should !==(generatedId3) + } + + it should "be able to generate com.drivergrp.core.Id identifiers with max value" in { + + val generatedLimitedId1 = nextId[String](5) + val generatedLimitedId2 = nextId[String](4) + val generatedLimitedId3 = nextId[Long](3) + + generatedLimitedId1.length should be >= 0 + generatedLimitedId1.length should be < 6 + generatedLimitedId2.length should be >= 0 + generatedLimitedId2.length should be < 5 + generatedLimitedId3.length should be >= 0 + generatedLimitedId3.length should be < 4 + generatedLimitedId1 should not be generatedLimitedId2 + generatedLimitedId2 should !==(generatedLimitedId3) + } + + it should "be able to generate com.drivergrp.core.Name names" in { + + Seq.fill(10)(nextName[String]()).distinct.size should be > 1 + nextName[String]().value.length should be >= 0 + + val fixedLengthName = nextName[String](10) + fixedLengthName.length should be <= 10 + assert(!fixedLengthName.value.exists(_.isControl)) + } + + it should "be able to generate com.drivergrp.core.NonEmptyName with non empty strings" in { + + assert(nextNonEmptyName[String]().value.value.nonEmpty) + } + + it should "be able to generate proper UUIDs" in { + + nextUuid() should not be nextUuid() + nextUuid().toString.length should be(36) + } + + it should "be able to generate new Revisions" in { + + nextRevision[String]() should not be nextRevision[String]() + nextRevision[String]().id.length should be > 0 + } + + it should "be able to generate strings" in { + + nextString() should not be nextString() + nextString().length should be >= 0 + + val fixedLengthString = nextString(20) + fixedLengthString.length should be <= 20 + assert(!fixedLengthString.exists(_.isControl)) + } + + it should "be able to generate strings non-empty strings whic are non empty" in { + + assert(nextNonEmptyString().value.nonEmpty) + } + + it should "be able to generate options which are sometimes have values and sometimes not" in { + + val generatedOption = nextOption("2") + + generatedOption should not contain "1" + assert(generatedOption === Some("2") || generatedOption === None) + } + + it should "be able to generate a pair of two generated values" in { + + val constantPair = nextPair("foo", 1L) + constantPair._1 should be("foo") + constantPair._2 should be(1L) + + val generatedPair = nextPair(nextId[Int](), nextName[Int]()) + + generatedPair._1.length should be > 0 + generatedPair._2.length should be > 0 + + nextPair(nextId[Int](), nextName[Int]()) should not be + nextPair(nextId[Int](), nextName[Int]()) + } + + it should "be able to generate a triad of two generated values" in { + + val constantTriad = nextTriad("foo", "bar", 1L) + constantTriad._1 should be("foo") + constantTriad._2 should be("bar") + constantTriad._3 should be(1L) + + val generatedTriad = nextTriad(nextId[Int](), nextName[Int](), nextBigDecimal()) + + generatedTriad._1.length should be > 0 + generatedTriad._2.length should be > 0 + generatedTriad._3 should be >= BigDecimal(0.00) + + nextTriad(nextId[Int](), nextName[Int](), nextBigDecimal()) should not be + nextTriad(nextId[Int](), nextName[Int](), nextBigDecimal()) + } + + it should "be able to generate a time value" in { + + val generatedTime = nextTime() + val currentTime = System.currentTimeMillis() + + generatedTime.millis should be >= 0L + generatedTime.millis should be <= currentTime + } + + it should "be able to generate a time range value" in { + + val generatedTimeRange = nextTimeRange() + val currentTime = System.currentTimeMillis() + + generatedTimeRange.start.millis should be >= 0L + generatedTimeRange.start.millis should be <= currentTime + generatedTimeRange.end.millis should be >= 0L + generatedTimeRange.end.millis should be <= currentTime + generatedTimeRange.start.millis should be <= generatedTimeRange.end.millis + } + + it should "be able to generate a BigDecimal value" in { + + val defaultGeneratedBigDecimal = nextBigDecimal() + + defaultGeneratedBigDecimal should be >= BigDecimal(0.00) + defaultGeneratedBigDecimal should be <= BigDecimal(1000000.00) + defaultGeneratedBigDecimal.precision should be(2) + + val unitIntervalBigDecimal = nextBigDecimal(1.00, 8) + + unitIntervalBigDecimal should be >= BigDecimal(0.00) + unitIntervalBigDecimal should be <= BigDecimal(1.00) + unitIntervalBigDecimal.precision should be(8) + } + + it should "be able to generate a specific value from a set of values" in { + + val possibleOptions = Set(1, 3, 5, 123, 0, 9) + + val pick1 = generators.oneOf(possibleOptions) + val pick2 = generators.oneOf(possibleOptions) + val pick3 = generators.oneOf(possibleOptions) + + possibleOptions should contain(pick1) + possibleOptions should contain(pick2) + possibleOptions should contain(pick3) + + val pick4 = generators.oneOf(1, 3, 5, 123, 0, 9) + val pick5 = generators.oneOf(1, 3, 5, 123, 0, 9) + val pick6 = generators.oneOf(1, 3, 5, 123, 0, 9) + + possibleOptions should contain(pick4) + possibleOptions should contain(pick5) + possibleOptions should contain(pick6) + + Set(pick1, pick2, pick3, pick4, pick5, pick6).size should be >= 1 + } + + it should "be able to generate a specific value from an enumeratum enum" in { + + import enumeratum._ + sealed trait TestEnumValue extends EnumEntry + object TestEnum extends Enum[TestEnumValue] { + case object Value1 extends TestEnumValue + case object Value2 extends TestEnumValue + case object Value3 extends TestEnumValue + case object Value4 extends TestEnumValue + val values: IndexedSeq[TestEnumValue] = findValues + } + + val picks = (1 to 100).map(_ => generators.oneOf(TestEnum)) + + TestEnum.values should contain allElementsOf picks + picks.toSet.size should be >= 1 + } + + it should "be able to generate array with values generated by generators" in { + + val arrayOfTimes = arrayOf(nextTime(), 16) + arrayOfTimes.length should be <= 16 + + val arrayOfBigDecimals = arrayOf(nextBigDecimal(), 8) + arrayOfBigDecimals.length should be <= 8 + } + + it should "be able to generate seq with values generated by generators" in { + + val seqOfTimes = seqOf(nextTime(), 16) + seqOfTimes.size should be <= 16 + + val seqOfBigDecimals = seqOf(nextBigDecimal(), 8) + seqOfBigDecimals.size should be <= 8 + } + + it should "be able to generate vector with values generated by generators" in { + + val vectorOfTimes = vectorOf(nextTime(), 16) + vectorOfTimes.size should be <= 16 + + val vectorOfStrings = seqOf(nextString(), 8) + vectorOfStrings.size should be <= 8 + } + + it should "be able to generate list with values generated by generators" in { + + val listOfTimes = listOf(nextTime(), 16) + listOfTimes.size should be <= 16 + + val listOfBigDecimals = seqOf(nextBigDecimal(), 8) + listOfBigDecimals.size should be <= 8 + } + + it should "be able to generate set with values generated by generators" in { + + val setOfTimes = vectorOf(nextTime(), 16) + setOfTimes.size should be <= 16 + + val setOfBigDecimals = seqOf(nextBigDecimal(), 8) + setOfBigDecimals.size should be <= 8 + } + + it should "be able to generate maps with keys and values generated by generators" in { + + val generatedConstantMap = mapOf("key", 123, 10) + generatedConstantMap.size should be <= 1 + assert(generatedConstantMap.keys.forall(_ == "key")) + assert(generatedConstantMap.values.forall(_ == 123)) + + val generatedMap = mapOf(nextString(10), nextBigDecimal(), 10) + assert(generatedMap.keys.forall(_.length <= 10)) + assert(generatedMap.values.forall(_ >= BigDecimal(0.00))) + } + + it should "compose deeply" in { + + val generatedNestedMap = mapOf(nextString(10), nextPair(nextBigDecimal(), nextOption(123)), 10) + + generatedNestedMap.size should be <= 10 + generatedNestedMap.keySet.size should be <= 10 + generatedNestedMap.values.size should be <= 10 + assert(generatedNestedMap.values.forall(value => !value._2.exists(_ != 123))) + } +} diff --git a/core-rest/src/test/scala/xyz/driver/core/JsonTest.scala b/core-rest/src/test/scala/xyz/driver/core/JsonTest.scala new file mode 100644 index 0000000..fd693f9 --- /dev/null +++ b/core-rest/src/test/scala/xyz/driver/core/JsonTest.scala @@ -0,0 +1,521 @@ +package xyz.driver.core + +import java.net.InetAddress +import java.time.{Instant, LocalDate} + +import akka.http.scaladsl.model.Uri +import akka.http.scaladsl.server.PathMatcher +import akka.http.scaladsl.server.PathMatcher.Matched +import com.neovisionaries.i18n.{CountryCode, CurrencyCode} +import enumeratum._ +import eu.timepit.refined.collection.NonEmpty +import eu.timepit.refined.numeric.Positive +import eu.timepit.refined.refineMV +import org.scalatest.{Inspectors, Matchers, WordSpec} +import spray.json._ +import xyz.driver.core.TestTypes.CustomGADT +import xyz.driver.core.auth.AuthCredentials +import xyz.driver.core.domain.{Email, PhoneNumber} +import xyz.driver.core.json._ +import xyz.driver.core.json.enumeratum.HasJsonFormat +import xyz.driver.core.tagging._ +import xyz.driver.core.time.provider.SystemTimeProvider +import xyz.driver.core.time.{Time, TimeOfDay} + +import scala.collection.immutable.IndexedSeq +import scala.language.postfixOps + +class JsonTest extends WordSpec with Matchers with Inspectors { + import DefaultJsonProtocol._ + + "Json format for Id" should { + "read and write correct JSON" in { + + val referenceId = Id[String]("1312-34A") + + val writtenJson = json.idFormat.write(referenceId) + writtenJson.prettyPrint should be("\"1312-34A\"") + + val parsedId = json.idFormat.read(writtenJson) + parsedId should be(referenceId) + } + } + + "Json format for @@" should { + "read and write correct JSON" in { + trait Irrelevant + val reference = Id[JsonTest]("SomeID").tagged[Irrelevant] + + val format = json.taggedFormat[Id[JsonTest], Irrelevant] + + val writtenJson = format.write(reference) + writtenJson shouldBe JsString("SomeID") + + val parsedId: Id[JsonTest] @@ Irrelevant = format.read(writtenJson) + parsedId shouldBe reference + } + + "read and write correct JSON when there's an implicit conversion defined" in { + val input = " some string " + + JsString(input).convertTo[String @@ Trimmed] shouldBe input.trim() + + val trimmed: String @@ Trimmed = input + trimmed.toJson shouldBe JsString(trimmed) + } + } + + "Json format for Name" should { + "read and write correct JSON" in { + + val referenceName = Name[String]("Homer") + + val writtenJson = json.nameFormat.write(referenceName) + writtenJson.prettyPrint should be("\"Homer\"") + + val parsedName = json.nameFormat.read(writtenJson) + parsedName should be(referenceName) + } + + "read and write correct JSON for Name @@ Trimmed" in { + trait Irrelevant + JsString(" some name ").convertTo[Name[Irrelevant] @@ Trimmed] shouldBe Name[Irrelevant]("some name") + + val trimmed: Name[Irrelevant] @@ Trimmed = Name(" some name ") + trimmed.toJson shouldBe JsString("some name") + } + } + + "Json format for NonEmptyName" should { + "read and write correct JSON" in { + + val jsonFormat = json.nonEmptyNameFormat[String] + + val referenceNonEmptyName = NonEmptyName[String](refineMV[NonEmpty]("Homer")) + + val writtenJson = jsonFormat.write(referenceNonEmptyName) + writtenJson.prettyPrint should be("\"Homer\"") + + val parsedNonEmptyName = jsonFormat.read(writtenJson) + parsedNonEmptyName should be(referenceNonEmptyName) + } + } + + "Json format for Time" should { + "read and write correct JSON" in { + + val referenceTime = new SystemTimeProvider().currentTime() + + val writtenJson = json.timeFormat.write(referenceTime) + writtenJson.prettyPrint should be("{\n \"timestamp\": " + referenceTime.millis + "\n}") + + val parsedTime = json.timeFormat.read(writtenJson) + parsedTime should be(referenceTime) + } + + "read from inputs compatible with Instant" in { + val referenceTime = new SystemTimeProvider().currentTime() + + val jsons = Seq(JsNumber(referenceTime.millis), JsString(Instant.ofEpochMilli(referenceTime.millis).toString)) + + forAll(jsons) { json => + json.convertTo[Time] shouldBe referenceTime + } + } + } + + "Json format for TimeOfDay" should { + "read and write correct JSON" in { + val utcTimeZone = java.util.TimeZone.getTimeZone("UTC") + val referenceTimeOfDay = TimeOfDay.parseTimeString(utcTimeZone)("08:00:00") + val writtenJson = json.timeOfDayFormat.write(referenceTimeOfDay) + writtenJson should be("""{"localTime":"08:00:00","timeZone":"UTC"}""".parseJson) + val parsed = json.timeOfDayFormat.read(writtenJson) + parsed should be(referenceTimeOfDay) + } + } + + "Json format for Date" should { + "read and write correct JSON" in { + import date._ + + val referenceDate = Date(1941, Month.DECEMBER, 7) + + val writtenJson = json.dateFormat.write(referenceDate) + writtenJson.prettyPrint should be("\"1941-12-07\"") + + val parsedDate = json.dateFormat.read(writtenJson) + parsedDate should be(referenceDate) + } + } + + "Json format for java.time.Instant" should { + + val isoString = "2018-08-08T08:08:08.888Z" + val instant = Instant.parse(isoString) + + "read correct JSON when value is an epoch milli number" in { + JsNumber(instant.toEpochMilli).convertTo[Instant] shouldBe instant + } + + "read correct JSON when value is an ISO timestamp string" in { + JsString(isoString).convertTo[Instant] shouldBe instant + } + + "read correct JSON when value is an object with nested 'timestamp'/millis field" in { + val json = JsObject( + "timestamp" -> JsNumber(instant.toEpochMilli) + ) + + json.convertTo[Instant] shouldBe instant + } + + "write correct JSON" in { + instant.toJson shouldBe JsString(isoString) + } + } + + "Path matcher for Instant" should { + + val isoString = "2018-08-08T08:08:08.888Z" + val instant = Instant.parse(isoString) + + val matcher = PathMatcher("foo") / InstantInPath / + + "read instant from millis" in { + matcher(Uri.Path("foo") / ("+" + instant.toEpochMilli) / "bar") shouldBe Matched(Uri.Path("bar"), Tuple1(instant)) + } + + "read instant from ISO timestamp string" in { + matcher(Uri.Path("foo") / isoString / "bar") shouldBe Matched(Uri.Path("bar"), Tuple1(instant)) + } + } + + "Json format for java.time.LocalDate" should { + + "read and write correct JSON" in { + val dateString = "2018-08-08" + val date = LocalDate.parse(dateString) + + date.toJson shouldBe JsString(dateString) + JsString(dateString).convertTo[LocalDate] shouldBe date + } + } + + "Json format for Revision" should { + "read and write correct JSON" in { + + val referenceRevision = Revision[String]("037e2ec0-8901-44ac-8e53-6d39f6479db4") + + val writtenJson = json.revisionFormat.write(referenceRevision) + writtenJson.prettyPrint should be("\"" + referenceRevision.id + "\"") + + val parsedRevision = json.revisionFormat.read(writtenJson) + parsedRevision should be(referenceRevision) + } + } + + "Json format for Email" should { + "read and write correct JSON" in { + + val referenceEmail = Email("test", "drivergrp.com") + + val writtenJson = json.emailFormat.write(referenceEmail) + writtenJson should be("\"test@drivergrp.com\"".parseJson) + + val parsedEmail = json.emailFormat.read(writtenJson) + parsedEmail should be(referenceEmail) + } + } + + "Json format for PhoneNumber" should { + "read and write correct JSON" in { + + val referencePhoneNumber = PhoneNumber("1", "4243039608") + + val writtenJson = json.phoneNumberFormat.write(referencePhoneNumber) + writtenJson should be("""{"countryCode":"1","number":"4243039608"}""".parseJson) + + val parsedPhoneNumber = json.phoneNumberFormat.read(writtenJson) + parsedPhoneNumber should be(referencePhoneNumber) + } + + "reject an invalid phone number" in { + val phoneJson = """{"countryCode":"1","number":"111-111-1113"}""".parseJson + + intercept[DeserializationException] { + json.phoneNumberFormat.read(phoneJson) + }.getMessage shouldBe "Invalid phone number" + } + + "parse phone number from string" in { + JsString("+14243039608").convertTo[PhoneNumber] shouldBe PhoneNumber("1", "4243039608") + } + } + + "Path matcher for PhoneNumber" should { + "read valid phone number" in { + val string = "+14243039608x23" + val phone = PhoneNumber("1", "4243039608", Some("23")) + + val matcher = PathMatcher("foo") / PhoneInPath + + matcher(Uri.Path("foo") / string / "bar") shouldBe Matched(Uri.Path./("bar"), Tuple1(phone)) + } + } + + "Json format for ADT mappings" should { + "read and write correct JSON" in { + + sealed trait EnumVal + case object Val1 extends EnumVal + case object Val2 extends EnumVal + case object Val3 extends EnumVal + + val format = new EnumJsonFormat[EnumVal]("a" -> Val1, "b" -> Val2, "c" -> Val3) + + val referenceEnumValue1 = Val2 + val referenceEnumValue2 = Val3 + + val writtenJson1 = format.write(referenceEnumValue1) + writtenJson1.prettyPrint should be("\"b\"") + + val writtenJson2 = format.write(referenceEnumValue2) + writtenJson2.prettyPrint should be("\"c\"") + + val parsedEnumValue1 = format.read(writtenJson1) + val parsedEnumValue2 = format.read(writtenJson2) + + parsedEnumValue1 should be(referenceEnumValue1) + parsedEnumValue2 should be(referenceEnumValue2) + } + } + + "Json format for Enums (external)" should { + "read and write correct JSON" in { + + sealed trait MyEnum extends EnumEntry + object MyEnum extends Enum[MyEnum] { + case object Val1 extends MyEnum + case object `Val 2` extends MyEnum + case object `Val/3` extends MyEnum + + val values: IndexedSeq[MyEnum] = findValues + } + + val format = new enumeratum.EnumJsonFormat(MyEnum) + + val referenceEnumValue1 = MyEnum.`Val 2` + val referenceEnumValue2 = MyEnum.`Val/3` + + val writtenJson1 = format.write(referenceEnumValue1) + writtenJson1 shouldBe JsString("Val 2") + + val writtenJson2 = format.write(referenceEnumValue2) + writtenJson2 shouldBe JsString("Val/3") + + val parsedEnumValue1 = format.read(writtenJson1) + val parsedEnumValue2 = format.read(writtenJson2) + + parsedEnumValue1 shouldBe referenceEnumValue1 + parsedEnumValue2 shouldBe referenceEnumValue2 + + intercept[DeserializationException] { + format.read(JsString("Val4")) + }.getMessage shouldBe "Unexpected value Val4. Expected one of: [Val1, Val 2, Val/3]" + } + } + + "Json format for Enums (automatic)" should { + "read and write correct JSON and not require import" in { + + sealed trait MyEnum extends EnumEntry + object MyEnum extends Enum[MyEnum] with HasJsonFormat[MyEnum] { + case object Val1 extends MyEnum + case object `Val 2` extends MyEnum + case object `Val/3` extends MyEnum + + val values: IndexedSeq[MyEnum] = findValues + } + + val referenceEnumValue1: MyEnum = MyEnum.`Val 2` + val referenceEnumValue2: MyEnum = MyEnum.`Val/3` + + val writtenJson1 = referenceEnumValue1.toJson + writtenJson1 shouldBe JsString("Val 2") + + val writtenJson2 = referenceEnumValue2.toJson + writtenJson2 shouldBe JsString("Val/3") + + import spray.json._ + + val parsedEnumValue1 = writtenJson1.prettyPrint.parseJson.convertTo[MyEnum] + val parsedEnumValue2 = writtenJson2.prettyPrint.parseJson.convertTo[MyEnum] + + parsedEnumValue1 should be(referenceEnumValue1) + parsedEnumValue2 should be(referenceEnumValue2) + + intercept[DeserializationException] { + JsString("Val4").convertTo[MyEnum] + }.getMessage shouldBe "Unexpected value Val4. Expected one of: [Val1, Val 2, Val/3]" + } + } + + // Should be defined outside of case to have a TypeTag + case class CustomWrapperClass(value: Int) + + "Json format for Value classes" should { + "read and write correct JSON" in { + + val format = new ValueClassFormat[CustomWrapperClass](v => BigDecimal(v.value), d => CustomWrapperClass(d.toInt)) + + val referenceValue1 = CustomWrapperClass(-2) + val referenceValue2 = CustomWrapperClass(10) + + val writtenJson1 = format.write(referenceValue1) + writtenJson1.prettyPrint should be("-2") + + val writtenJson2 = format.write(referenceValue2) + writtenJson2.prettyPrint should be("10") + + val parsedValue1 = format.read(writtenJson1) + val parsedValue2 = format.read(writtenJson2) + + parsedValue1 should be(referenceValue1) + parsedValue2 should be(referenceValue2) + } + } + + "Json format for classes GADT" should { + "read and write correct JSON" in { + + import CustomGADT._ + import DefaultJsonProtocol._ + implicit val case1Format = jsonFormat1(GadtCase1) + implicit val case2Format = jsonFormat1(GadtCase2) + implicit val case3Format = jsonFormat1(GadtCase3) + + val format = GadtJsonFormat.create[CustomGADT]("gadtTypeField") { + case _: CustomGADT.GadtCase1 => "case1" + case _: CustomGADT.GadtCase2 => "case2" + case _: CustomGADT.GadtCase3 => "case3" + } { + case "case1" => case1Format + case "case2" => case2Format + case "case3" => case3Format + } + + val referenceValue1 = CustomGADT.GadtCase1("4") + val referenceValue2 = CustomGADT.GadtCase2("Hi!") + + val writtenJson1 = format.write(referenceValue1) + writtenJson1 should be("{\n \"field\": \"4\",\n\"gadtTypeField\": \"case1\"\n}".parseJson) + + val writtenJson2 = format.write(referenceValue2) + writtenJson2 should be("{\"field\":\"Hi!\",\"gadtTypeField\":\"case2\"}".parseJson) + + val parsedValue1 = format.read(writtenJson1) + val parsedValue2 = format.read(writtenJson2) + + parsedValue1 should be(referenceValue1) + parsedValue2 should be(referenceValue2) + } + } + + "Json format for a Refined value" should { + "read and write correct JSON" in { + + val jsonFormat = json.refinedJsonFormat[Int, Positive] + + val referenceRefinedNumber = refineMV[Positive](42) + + val writtenJson = jsonFormat.write(referenceRefinedNumber) + writtenJson should be("42".parseJson) + + val parsedRefinedNumber = jsonFormat.read(writtenJson) + parsedRefinedNumber should be(referenceRefinedNumber) + } + } + + "InetAddress format" should { + "read and write correct JSON" in { + val address = InetAddress.getByName("127.0.0.1") + val json = inetAddressFormat.write(address) + + json shouldBe JsString("127.0.0.1") + + val parsed = inetAddressFormat.read(json) + parsed shouldBe address + } + + "throw a DeserializationException for an invalid IP Address" in { + assertThrows[DeserializationException] { + val invalidAddress = JsString("foobar:") + inetAddressFormat.read(invalidAddress) + } + } + } + + "AuthCredentials format" should { + "read and write correct JSON" in { + val email = Email("someone", "noehere.com") + val phoneId = PhoneNumber.parse("1 207 8675309") + val password = "nopassword" + + phoneId.isDefined should be(true) // test this real quick + + val emailAuth = AuthCredentials(email.toString, password) + val pnAuth = AuthCredentials(phoneId.get.toString, password) + + val emailWritten = authCredentialsFormat.write(emailAuth) + emailWritten should be("""{"identifier":"someone@noehere.com","password":"nopassword"}""".parseJson) + + val phoneWritten = authCredentialsFormat.write(pnAuth) + phoneWritten should be("""{"identifier":"+1 2078675309","password":"nopassword"}""".parseJson) + + val identifierEmailParsed = + authCredentialsFormat.read("""{"identifier":"someone@nowhere.com","password":"nopassword"}""".parseJson) + var written = authCredentialsFormat.write(identifierEmailParsed) + written should be("{\"identifier\":\"someone@nowhere.com\",\"password\":\"nopassword\"}".parseJson) + + val emailEmailParsed = + authCredentialsFormat.read("""{"email":"someone@nowhere.com","password":"nopassword"}""".parseJson) + written = authCredentialsFormat.write(emailEmailParsed) + written should be("{\"identifier\":\"someone@nowhere.com\",\"password\":\"nopassword\"}".parseJson) + + } + } + + "CountryCode format" should { + "read and write correct JSON" in { + val samples = Seq( + "US" -> CountryCode.US, + "CN" -> CountryCode.CN, + "AT" -> CountryCode.AT + ) + + forAll(samples) { + case (serialized, enumValue) => + countryCodeFormat.write(enumValue) shouldBe JsString(serialized) + countryCodeFormat.read(JsString(serialized)) shouldBe enumValue + } + } + } + + "CurrencyCode format" should { + "read and write correct JSON" in { + val samples = Seq( + "USD" -> CurrencyCode.USD, + "CNY" -> CurrencyCode.CNY, + "EUR" -> CurrencyCode.EUR + ) + + forAll(samples) { + case (serialized, enumValue) => + currencyCodeFormat.write(enumValue) shouldBe JsString(serialized) + currencyCodeFormat.read(JsString(serialized)) shouldBe enumValue + } + } + } + +} diff --git a/core-rest/src/test/scala/xyz/driver/core/TestTypes.scala b/core-rest/src/test/scala/xyz/driver/core/TestTypes.scala new file mode 100644 index 0000000..bb25deb --- /dev/null +++ b/core-rest/src/test/scala/xyz/driver/core/TestTypes.scala @@ -0,0 +1,14 @@ +package xyz.driver.core + +object TestTypes { + + sealed trait CustomGADT { + val field: String + } + + object CustomGADT { + final case class GadtCase1(field: String) extends CustomGADT + final case class GadtCase2(field: String) extends CustomGADT + final case class GadtCase3(field: String) extends CustomGADT + } +} diff --git a/core-rest/src/test/scala/xyz/driver/core/rest/DriverAppTest.scala b/core-rest/src/test/scala/xyz/driver/core/rest/DriverAppTest.scala new file mode 100644 index 0000000..324c8d8 --- /dev/null +++ b/core-rest/src/test/scala/xyz/driver/core/rest/DriverAppTest.scala @@ -0,0 +1,89 @@ +package xyz.driver.core.rest + +import akka.http.scaladsl.model.headers._ +import akka.http.scaladsl.model.{HttpMethod, StatusCodes} +import akka.http.scaladsl.server.{Directives, Route} +import akka.http.scaladsl.testkit.ScalatestRouteTest +import com.typesafe.config.ConfigFactory +import org.scalatest.{AsyncFlatSpec, Matchers} +import xyz.driver.core.app.{DriverApp, SimpleModule} + +class DriverAppTest extends AsyncFlatSpec with ScalatestRouteTest with Matchers with Directives { + val config = ConfigFactory.parseString(""" + |application { + | cors { + | allowedOrigins: ["example.com"] + | } + |} + """.stripMargin).withFallback(ConfigFactory.load) + + val origin = Origin(HttpOrigin("https", Host("example.com"))) + val allowedOrigins = Set(HttpOrigin("https", Host("example.com"))) + val allowedMethods: collection.immutable.Seq[HttpMethod] = { + import akka.http.scaladsl.model.HttpMethods._ + collection.immutable.Seq(GET, PUT, POST, PATCH, DELETE, OPTIONS, TRACE) + } + + import scala.reflect.runtime.universe.typeOf + class TestApp(testRoute: Route) + extends DriverApp( + appName = "test-app", + version = "0.0.1", + gitHash = "deadb33f", + modules = Seq(new SimpleModule("test-module", theRoute = testRoute, routeType = typeOf[DriverApp])), + config = config, + log = xyz.driver.core.logging.NoLogger + ) + + it should "respond with the correct CORS headers for the swagger OPTIONS route" in { + val route = new TestApp(get(complete(StatusCodes.OK))) + Options(s"/api-docs/swagger.json").withHeaders(origin) ~> route.appRoute ~> check { + status shouldBe StatusCodes.OK + headers should contain(`Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigins.toSeq: _*))) + header[`Access-Control-Allow-Methods`].get.methods should contain theSameElementsAs allowedMethods + } + } + + it should "respond with the correct CORS headers for the test route" in { + val route = new TestApp(get(complete(StatusCodes.OK))) + Get(s"/api/v1/test").withHeaders(origin) ~> route.appRoute ~> check { + status shouldBe StatusCodes.OK + headers should contain(`Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigins.toSeq: _*))) + } + } + + it should "respond with the correct CORS headers for a concatenated route" in { + val route = new TestApp(get(complete(StatusCodes.OK)) ~ post(complete(StatusCodes.OK))) + Post(s"/api/v1/test").withHeaders(origin) ~> route.appRoute ~> check { + status shouldBe StatusCodes.OK + headers should contain(`Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigins.toSeq: _*))) + } + } + + it should "allow subdomains of allowed origin suffixes" in { + val route = new TestApp(get(complete(StatusCodes.OK))) + Get(s"/api/v1/test") + .withHeaders(Origin(HttpOrigin("https", Host("foo.example.com")))) ~> route.appRoute ~> check { + status shouldBe StatusCodes.OK + headers should contain(`Access-Control-Allow-Origin`(HttpOrigin("https", Host("foo.example.com")))) + } + } + + it should "respond with default domains for invalid origins" in { + val route = new TestApp(get(complete(StatusCodes.OK))) + Get(s"/api/v1/test") + .withHeaders(Origin(HttpOrigin("https", Host("invalid.foo.bar.com")))) ~> route.appRoute ~> check { + status shouldBe StatusCodes.OK + headers should contain(`Access-Control-Allow-Origin`(HttpOriginRange.*)) + } + } + + it should "respond with Pragma and Cache-Control (no-cache) headers" in { + val route = new TestApp(get(complete(StatusCodes.OK))) + Get(s"/api/v1/test") ~> route.appRoute ~> check { + status shouldBe StatusCodes.OK + header("Pragma").map(_.value()) should contain("no-cache") + header[`Cache-Control`].map(_.value()) should contain("no-cache") + } + } +} diff --git a/core-rest/src/test/scala/xyz/driver/core/rest/DriverRouteTest.scala b/core-rest/src/test/scala/xyz/driver/core/rest/DriverRouteTest.scala new file mode 100644 index 0000000..cc0019a --- /dev/null +++ b/core-rest/src/test/scala/xyz/driver/core/rest/DriverRouteTest.scala @@ -0,0 +1,121 @@ +package xyz.driver.core.rest + +import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport +import akka.http.scaladsl.model.StatusCodes +import akka.http.scaladsl.model.headers.Connection +import akka.http.scaladsl.server.Directives.{complete => akkaComplete} +import akka.http.scaladsl.server.{Directives, RejectionHandler, Route} +import akka.http.scaladsl.testkit.ScalatestRouteTest +import com.typesafe.scalalogging.Logger +import org.scalatest.{AsyncFlatSpec, Matchers} +import xyz.driver.core.json.serviceExceptionFormat +import xyz.driver.core.logging.NoLogger +import xyz.driver.core.rest.errors._ + +import scala.concurrent.Future + +class DriverRouteTest + extends AsyncFlatSpec with ScalatestRouteTest with SprayJsonSupport with Matchers with Directives { + class TestRoute(override val route: Route) extends DriverRoute { + override def log: Logger = NoLogger + } + + "DriverRoute" should "respond with 200 OK for a basic route" in { + val route = new TestRoute(akkaComplete(StatusCodes.OK)) + + Get("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check { + handled shouldBe true + status shouldBe StatusCodes.OK + } + } + + it should "respond with a 401 for an InvalidInputException" in { + val route = new TestRoute(akkaComplete(Future.failed[String](InvalidInputException()))) + + Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check { + handled shouldBe true + status shouldBe StatusCodes.BadRequest + responseAs[ServiceException] shouldBe InvalidInputException() + } + } + + it should "respond with a 403 for InvalidActionException" in { + val route = new TestRoute(akkaComplete(Future.failed[String](InvalidActionException()))) + + Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check { + handled shouldBe true + status shouldBe StatusCodes.Forbidden + responseAs[ServiceException] shouldBe InvalidActionException() + } + } + + it should "respond with a 404 for ResourceNotFoundException" in { + val route = new TestRoute(akkaComplete(Future.failed[String](ResourceNotFoundException()))) + + Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check { + handled shouldBe true + status shouldBe StatusCodes.NotFound + responseAs[ServiceException] shouldBe ResourceNotFoundException() + } + } + + it should "respond with a 500 for ExternalServiceException" in { + val error = ExternalServiceException("GET /api/v1/users/", "Permission denied", None) + val route = new TestRoute(akkaComplete(Future.failed[String](error))) + + Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check { + handled shouldBe true + status shouldBe StatusCodes.InternalServerError + responseAs[ServiceException] shouldBe error + } + } + + it should "allow pass-through of external service exceptions" in { + val innerError = InvalidInputException() + val error = ExternalServiceException("GET /api/v1/users/", "Permission denied", Some(innerError)) + val future = Future.failed[String](error) + val route = new TestRoute(akkaComplete(future.passThroughExternalServiceException)) + + Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check { + handled shouldBe true + status shouldBe StatusCodes.BadRequest + responseAs[ServiceException] shouldBe innerError + } + } + + it should "respond with a 503 for ExternalServiceTimeoutException" in { + val error = ExternalServiceTimeoutException("GET /api/v1/users/") + val route = new TestRoute(akkaComplete(Future.failed[String](error))) + + Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check { + handled shouldBe true + status shouldBe StatusCodes.GatewayTimeout + responseAs[ServiceException] shouldBe error + } + } + + it should "respond with a 500 for DatabaseException" in { + val route = new TestRoute(akkaComplete(Future.failed[String](DatabaseException()))) + + Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check { + handled shouldBe true + status shouldBe StatusCodes.InternalServerError + responseAs[ServiceException] shouldBe DatabaseException() + } + } + + it should "add a `Connection: close` header to avoid clashing with envoy's timeouts" in { + val rejectionHandler = RejectionHandler.newBuilder().handleNotFound(complete(StatusCodes.NotFound)).result() + val route = new TestRoute(handleRejections(rejectionHandler)((get & path("foo"))(complete("OK")))) + + Get("/foo") ~> route.routeWithDefaults ~> check { + status shouldBe StatusCodes.OK + headers should contain(Connection("close")) + } + + Get("/bar") ~> route.routeWithDefaults ~> check { + status shouldBe StatusCodes.NotFound + headers should contain(Connection("close")) + } + } +} diff --git a/core-rest/src/test/scala/xyz/driver/core/rest/PatchDirectivesTest.scala b/core-rest/src/test/scala/xyz/driver/core/rest/PatchDirectivesTest.scala new file mode 100644 index 0000000..987717d --- /dev/null +++ b/core-rest/src/test/scala/xyz/driver/core/rest/PatchDirectivesTest.scala @@ -0,0 +1,101 @@ +package xyz.driver.core.rest + +import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport +import akka.http.scaladsl.model._ +import akka.http.scaladsl.model.headers.`Content-Type` +import akka.http.scaladsl.server.{Directives, Route} +import akka.http.scaladsl.testkit.ScalatestRouteTest +import org.scalatest.{FlatSpec, Matchers} +import spray.json._ +import xyz.driver.core.{Id, Name} +import xyz.driver.core.json._ + +import scala.concurrent.Future + +class PatchDirectivesTest + extends FlatSpec with Matchers with ScalatestRouteTest with SprayJsonSupport with DefaultJsonProtocol + with Directives with PatchDirectives { + case class Bar(name: Name[Bar], size: Int) + case class Foo(id: Id[Foo], name: Name[Foo], rank: Int, bar: Option[Bar]) + implicit val barFormat: RootJsonFormat[Bar] = jsonFormat2(Bar) + implicit val fooFormat: RootJsonFormat[Foo] = jsonFormat4(Foo) + + val testFoo: Foo = Foo(Id("1"), Name(s"Foo"), 1, Some(Bar(Name("Bar"), 10))) + + def route(retrieve: => Future[Option[Foo]]): Route = + Route.seal(path("api" / "v1" / "foos" / IdInPath[Foo]) { fooId => + entity(as[Patchable[Foo]]) { fooPatchable => + mergePatch(fooPatchable, retrieve) { updatedFoo => + complete(updatedFoo) + } + } + }) + + val MergePatchContentType = ContentType(`application/merge-patch+json`) + val ContentTypeHeader = `Content-Type`(MergePatchContentType) + def jsonEntity(json: String, contentType: ContentType.NonBinary = MergePatchContentType): RequestEntity = + HttpEntity(contentType, json) + + "PatchSupport" should "allow partial updates to an existing object" in { + val fooRetrieve = Future.successful(Some(testFoo)) + + Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""")) ~> route(fooRetrieve) ~> check { + handled shouldBe true + responseAs[Foo] shouldBe testFoo.copy(rank = 4) + } + } + + it should "merge deeply nested objects" in { + val fooRetrieve = Future.successful(Some(testFoo)) + + Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4, "bar": {"name": "My Bar"}}""")) ~> route(fooRetrieve) ~> check { + handled shouldBe true + responseAs[Foo] shouldBe testFoo.copy(rank = 4, bar = Some(Bar(Name("My Bar"), 10))) + } + } + + it should "return a 404 if the object is not found" in { + val fooRetrieve = Future.successful(None) + + Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""")) ~> route(fooRetrieve) ~> check { + handled shouldBe true + status shouldBe StatusCodes.NotFound + } + } + + it should "handle nulls on optional values correctly" in { + val fooRetrieve = Future.successful(Some(testFoo)) + + Patch("/api/v1/foos/1", jsonEntity("""{"bar": null}""")) ~> route(fooRetrieve) ~> check { + handled shouldBe true + responseAs[Foo] shouldBe testFoo.copy(bar = None) + } + } + + it should "handle optional values correctly when old value is null" in { + val fooRetrieve = Future.successful(Some(testFoo.copy(bar = None))) + + Patch("/api/v1/foos/1", jsonEntity("""{"bar": {"name": "My Bar","size":10}}""")) ~> route(fooRetrieve) ~> check { + handled shouldBe true + responseAs[Foo] shouldBe testFoo.copy(bar = Some(Bar(Name("My Bar"), 10))) + } + } + + it should "return a 400 for nulls on non-optional values" in { + val fooRetrieve = Future.successful(Some(testFoo)) + + Patch("/api/v1/foos/1", jsonEntity("""{"rank": null}""")) ~> route(fooRetrieve) ~> check { + handled shouldBe true + status shouldBe StatusCodes.BadRequest + } + } + + it should "return a 415 for incorrect Content-Type" in { + val fooRetrieve = Future.successful(Some(testFoo)) + + Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""", ContentTypes.`application/json`)) ~> route(fooRetrieve) ~> check { + status shouldBe StatusCodes.UnsupportedMediaType + responseAs[String] should include("application/merge-patch+json") + } + } +} diff --git a/core-rest/src/test/scala/xyz/driver/core/rest/RestTest.scala b/core-rest/src/test/scala/xyz/driver/core/rest/RestTest.scala new file mode 100644 index 0000000..19e4ed1 --- /dev/null +++ b/core-rest/src/test/scala/xyz/driver/core/rest/RestTest.scala @@ -0,0 +1,151 @@ +package xyz.driver.core.rest + +import akka.http.scaladsl.model.StatusCodes +import akka.http.scaladsl.server.{Directives, Route, ValidationRejection} +import akka.http.scaladsl.testkit.ScalatestRouteTest +import akka.util.ByteString +import org.scalatest.{Matchers, WordSpec} +import xyz.driver.core.rest + +import scala.concurrent.Future +import scala.util.Random + +class RestTest extends WordSpec with Matchers with ScalatestRouteTest with Directives { + "`escapeScriptTags` function" should { + "escape script tags properly" in { + val dirtyString = " + complete(StatusCodes.OK -> s"${paginated.pageNumber},${paginated.pageSize}") + } + "accept a pagination" in { + Get("/?pageNumber=2&pageSize=42") ~> route ~> check { + assert(status == StatusCodes.OK) + assert(entityAs[String] == "2,42") + } + } + "provide a default pagination" in { + Get("/") ~> route ~> check { + assert(status == StatusCodes.OK) + assert(entityAs[String] == "1,100") + } + } + "provide default values for a partial pagination" in { + Get("/?pageSize=2") ~> route ~> check { + assert(status == StatusCodes.OK) + assert(entityAs[String] == "1,2") + } + } + "reject an invalid pagination" in { + Get("/?pageNumber=-1") ~> route ~> check { + assert(rejection.isInstanceOf[ValidationRejection]) + } + } + } + + "optional paginated directive" should { + val route: Route = rest.optionalPagination { paginated => + complete(StatusCodes.OK -> paginated.map(p => s"${p.pageNumber},${p.pageSize}").getOrElse("no pagination")) + } + "accept a pagination" in { + Get("/?pageNumber=2&pageSize=42") ~> route ~> check { + assert(status == StatusCodes.OK) + assert(entityAs[String] == "2,42") + } + } + "without pagination" in { + Get("/") ~> route ~> check { + assert(status == StatusCodes.OK) + assert(entityAs[String] == "no pagination") + } + } + "reject an invalid pagination" in { + Get("/?pageNumber=1") ~> route ~> check { + assert(rejection.isInstanceOf[ValidationRejection]) + } + } + } + + "completeWithPagination directive" when { + import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport._ + import spray.json.DefaultJsonProtocol._ + + val data = Seq.fill(103)(Random.alphanumeric.take(10).mkString) + val route: Route = + parameter('empty.as[Boolean] ? false) { isEmpty => + completeWithPagination[String] { + case Some(pagination) if isEmpty => + Future.successful(ListResponse(Seq(), 0, Some(pagination))) + case Some(pagination) => + val filtered = data.slice(pagination.offset, pagination.offset + pagination.pageSize) + Future.successful(ListResponse(filtered, data.size, Some(pagination))) + case None if isEmpty => Future.successful(ListResponse(Seq(), 0, None)) + case None => Future.successful(ListResponse(data, data.size, None)) + } + } + + "pagination is defined" should { + "return a response with pagination headers" in { + Get("/?pageNumber=2&pageSize=10") ~> route ~> check { + responseAs[Seq[String]] shouldBe data.slice(10, 20) + header(ContextHeaders.ResourceCount).map(_.value) should contain("103") + header(ContextHeaders.PageCount).map(_.value) should contain("11") + } + } + + "disallow pageSize <= 0" in { + Get("/?pageNumber=2&pageSize=0") ~> route ~> check { + rejection shouldBe a[ValidationRejection] + } + + Get("/?pageNumber=2&pageSize=-1") ~> route ~> check { + rejection shouldBe a[ValidationRejection] + } + } + + "disallow pageNumber <= 0" in { + Get("/?pageNumber=0&pageSize=10") ~> route ~> check { + rejection shouldBe a[ValidationRejection] + } + + Get("/?pageNumber=-1&pageSize=10") ~> route ~> check { + rejection shouldBe a[ValidationRejection] + } + } + + "return PageCount == 0 if returning an empty list" in { + Get("/?empty=true&pageNumber=2&pageSize=10") ~> route ~> check { + responseAs[Seq[String]] shouldBe empty + header(ContextHeaders.ResourceCount).map(_.value) should contain("0") + header(ContextHeaders.PageCount).map(_.value) should contain("0") + } + } + } + + "pagination is not defined" should { + "return a response with pagination headers and PageCount == 1" in { + Get("/") ~> route ~> check { + responseAs[Seq[String]] shouldBe data + header(ContextHeaders.ResourceCount).map(_.value) should contain("103") + header(ContextHeaders.PageCount).map(_.value) should contain("1") + } + } + + "return PageCount == 0 if returning an empty list" in { + Get("/?empty=true") ~> route ~> check { + responseAs[Seq[String]] shouldBe empty + header(ContextHeaders.ResourceCount).map(_.value) should contain("0") + header(ContextHeaders.PageCount).map(_.value) should contain("0") + } + } + } + } +} -- cgit v1.2.3