diff options
Diffstat (limited to 'jvm/src/test/scala/xyz/driver')
16 files changed, 1941 insertions, 0 deletions
diff --git a/jvm/src/test/scala/xyz/driver/core/AuthTest.scala b/jvm/src/test/scala/xyz/driver/core/AuthTest.scala new file mode 100644 index 0000000..a7707aa --- /dev/null +++ b/jvm/src/test/scala/xyz/driver/core/AuthTest.scala @@ -0,0 +1,145 @@ +package xyz.driver.core + +import akka.http.scaladsl.model.headers.{HttpChallenges, RawHeader} +import akka.http.scaladsl.server.AuthenticationFailedRejection.CredentialsRejected +import akka.http.scaladsl.server.Directives._ +import akka.http.scaladsl.server._ +import akka.http.scaladsl.testkit.ScalatestRouteTest +import org.scalatest.{FlatSpec, Matchers} +import pdi.jwt.{Jwt, JwtAlgorithm} +import xyz.driver.core.auth._ +import xyz.driver.core.domain.Email +import xyz.driver.core.logging._ +import xyz.driver.core.rest._ +import xyz.driver.core.rest.auth._ +import xyz.driver.core.time.Time + +import scala.concurrent.Future +import scalaz.OptionT + +class AuthTest extends FlatSpec with Matchers with ScalatestRouteTest { + + case object TestRoleAllowedPermission extends Permission + case object TestRoleAllowedByTokenPermission extends Permission + case object TestRoleNotAllowedPermission extends Permission + + val TestRole = Role(Id("1"), Name("testRole")) + + val (publicKey, privateKey) = { + import java.security.KeyPairGenerator + + val keygen = KeyPairGenerator.getInstance("RSA") + keygen.initialize(2048) + + val keyPair = keygen.generateKeyPair() + (keyPair.getPublic, keyPair.getPrivate) + } + + val basicAuthorization: Authorization[User] = new Authorization[User] { + + override def userHasPermissions(user: User, permissions: Seq[Permission])( + implicit ctx: ServiceRequestContext): Future[AuthorizationResult] = { + val authorized = permissions.map(p => p -> (p === TestRoleAllowedPermission)).toMap + Future.successful(AuthorizationResult(authorized, ctx.permissionsToken)) + } + } + + val tokenIssuer = "users" + val tokenAuthorization = new CachedTokenAuthorization[User](publicKey, tokenIssuer) + + val authorization = new ChainedAuthorization[User](tokenAuthorization, basicAuthorization) + + val authStatusService = new AuthProvider[User](authorization, NoLogger) { + override def authenticatedUser(implicit ctx: ServiceRequestContext): OptionT[Future, User] = + OptionT.optionT[Future] { + if (ctx.contextHeaders.keySet.contains(AuthProvider.AuthenticationTokenHeader)) { + Future.successful( + Some( + AuthTokenUserInfo( + Id[User]("1"), + Email("foo", "bar"), + emailVerified = true, + audience = "driver", + roles = Set(TestRole), + expirationTime = Time(1000000L) + ))) + } else { + Future.successful(Option.empty[User]) + } + } + } + + import authStatusService._ + + "'authorize' directive" should "throw error if auth token is not in the request" in { + + Get("/naive/attempt") ~> + authorize(TestRoleAllowedPermission) { user => + complete("Never going to be here") + } ~> + check { + // handled shouldBe false + val challenge = HttpChallenges.basic("Failed to authenticate user") + rejections should contain(AuthenticationFailedRejection(CredentialsRejected, challenge)) + } + } + + it should "throw error if authorized user does not have the requested permission" in { + + val referenceAuthToken = AuthToken("I am a test role's token") + + Post("/administration/attempt").addHeader( + RawHeader(AuthProvider.AuthenticationTokenHeader, referenceAuthToken.value) + ) ~> + authorize(TestRoleNotAllowedPermission) { user => + complete("Never going to get here") + } ~> + check { + handled shouldBe false + rejections should contain( + AuthenticationFailedRejection( + CredentialsRejected, + HttpChallenges.basic("User does not have the required permissions: TestRoleNotAllowedPermission"))) + } + } + + it should "pass and retrieve the token to client code, if token is in request and user has permission" in { + + val referenceAuthToken = AuthToken("I am token") + + Get("/valid/attempt/?a=2&b=5").addHeader( + RawHeader(AuthProvider.AuthenticationTokenHeader, referenceAuthToken.value) + ) ~> + authorize(TestRoleAllowedPermission) { ctx => + complete(s"Alright, user ${ctx.authenticatedUser.id} is authorized") + } ~> + check { + handled shouldBe true + responseAs[String] shouldBe "Alright, user 1 is authorized" + } + } + + it should "authorize permission found in permissions token" in { + import spray.json._ + + val claim = JsObject( + Map( + "iss" -> JsString(tokenIssuer), + "sub" -> JsString("1"), + "permissions" -> JsObject(Map(TestRoleAllowedByTokenPermission.toString -> JsBoolean(true))) + )).prettyPrint + val permissionsToken = PermissionsToken(Jwt.encode(claim, privateKey, JwtAlgorithm.RS256)) + val referenceAuthToken = AuthToken("I am token") + + Get("/alic/attempt/?a=2&b=5") + .addHeader(RawHeader(AuthProvider.AuthenticationTokenHeader, referenceAuthToken.value)) + .addHeader(RawHeader(AuthProvider.PermissionsTokenHeader, permissionsToken.value)) ~> + authorize(TestRoleAllowedByTokenPermission) { ctx => + complete(s"Alright, user ${ctx.authenticatedUser.id} is authorized by permissions token") + } ~> + check { + handled shouldBe true + responseAs[String] shouldBe "Alright, user 1 is authorized by permissions token" + } + } +} diff --git a/jvm/src/test/scala/xyz/driver/core/BlobStorageTest.scala b/jvm/src/test/scala/xyz/driver/core/BlobStorageTest.scala new file mode 100644 index 0000000..637f9e0 --- /dev/null +++ b/jvm/src/test/scala/xyz/driver/core/BlobStorageTest.scala @@ -0,0 +1,93 @@ +package xyz.driver.core + +import java.nio.file.Files + +import akka.actor.ActorSystem +import akka.stream.ActorMaterializer +import akka.stream.scaladsl._ +import akka.util.ByteString +import org.scalatest._ +import org.scalatest.concurrent.ScalaFutures +import xyz.driver.core.storage.{BlobStorage, FileSystemBlobStorage} + +import scala.concurrent.Future +import scala.concurrent.duration._ + +class BlobStorageTest extends FlatSpec with ScalaFutures { + + implicit val patientce = PatienceConfig(timeout = 100.seconds) + + implicit val system = ActorSystem("blobstorage-test") + implicit val mat = ActorMaterializer() + import system.dispatcher + + def storageBehaviour(storage: BlobStorage) = { + val key = "foo" + val data = "hello world".getBytes + it should "upload data" in { + assert(storage.exists(key).futureValue === false) + assert(storage.uploadContent(key, data).futureValue === key) + assert(storage.exists(key).futureValue === true) + } + it should "download data" in { + val content = storage.content(key).futureValue + assert(content.isDefined) + assert(content.get === data) + } + it should "not download non-existing data" in { + assert(storage.content("bar").futureValue.isEmpty) + } + it should "overwrite an existing key" in { + val newData = "new string".getBytes("utf-8") + assert(storage.uploadContent(key, newData).futureValue === key) + assert(storage.content(key).futureValue.get === newData) + } + it should "upload a file" in { + val tmp = Files.createTempFile("testfile", "txt") + Files.write(tmp, data) + assert(storage.uploadFile(key, tmp).futureValue === key) + Files.delete(tmp) + } + it should "upload content" in { + val text = "foobar" + val src = Source + .single(text) + .map(l => ByteString(l)) + src.runWith(storage.upload(key).futureValue).futureValue + assert(storage.content(key).futureValue.map(_.toSeq) === Some("foobar".getBytes.toSeq)) + } + it should "delete content" in { + assert(storage.exists(key).futureValue) + storage.delete(key).futureValue + assert(!storage.exists(key).futureValue) + } + it should "download content" in { + storage.uploadContent(key, data) futureValue + val srcOpt = storage.download(key).futureValue + assert(srcOpt.isDefined) + val src = srcOpt.get + val content: Future[Array[Byte]] = src.runWith(Sink.fold(Array[Byte]())(_ ++ _)) + assert(content.futureValue === data) + } + it should "list keys" in { + assert(storage.list("").futureValue === Set(key)) + storage.uploadContent("a/a.txt", data).futureValue + storage.uploadContent("a/b", data).futureValue + storage.uploadContent("c/d", data).futureValue + storage.uploadContent("d", data).futureValue + assert(storage.list("").futureValue === Set(key, "a", "c", "d")) + assert(storage.list("a").futureValue === Set("a/a.txt", "a/b")) + assert(storage.list("a").futureValue === Set("a/a.txt", "a/b")) + assert(storage.list("c").futureValue === Set("c/d")) + } + it should "get valid URL" in { + assert(storage.exists(key).futureValue === true) + val fooUrl = storage.url(key).futureValue + assert(fooUrl.isDefined) + } + } + + "File system storage" should behave like storageBehaviour( + new FileSystemBlobStorage(Files.createTempDirectory("test"))) + +} diff --git a/jvm/src/test/scala/xyz/driver/core/CoreTest.scala b/jvm/src/test/scala/xyz/driver/core/CoreTest.scala new file mode 100644 index 0000000..d280d73 --- /dev/null +++ b/jvm/src/test/scala/xyz/driver/core/CoreTest.scala @@ -0,0 +1,84 @@ +package xyz.driver.core + +import java.io.ByteArrayOutputStream + +import org.mockito.Mockito._ +import org.scalatest.mockito.MockitoSugar +import org.scalatest.{FlatSpec, Matchers} + +class CoreTest extends FlatSpec with Matchers with MockitoSugar { + + "'make' function" should "allow initialization for objects" in { + + val createdAndInitializedValue = make(new ByteArrayOutputStream(128)) { baos => + baos.write(Array(1.toByte, 1.toByte, 0.toByte)) + } + + createdAndInitializedValue.toByteArray should be(Array(1.toByte, 1.toByte, 0.toByte)) + } + + "'using' function" should "call close after performing action on resource" in { + + val baos = mock[ByteArrayOutputStream] + + using(baos /* usually new ByteArrayOutputStream(128) */ ) { baos => + baos.write(Array(1.toByte, 1.toByte, 0.toByte)) + } + + verify(baos).close() + } + + "Id" should "have equality and ordering working correctly" in { + + (Id[String]("1234213") === Id[String]("1234213")) should be(true) + (Id[String]("1234213") === Id[String]("213414")) should be(false) + (Id[String]("213414") === Id[String]("1234213")) should be(false) + + Seq(Id[String]("4"), Id[String]("3"), Id[String]("2"), Id[String]("1")).sorted should contain + theSameElementsInOrderAs(Seq(Id[String]("1"), Id[String]("2"), Id[String]("3"), Id[String]("4"))) + } + + it should "have type-safe conversions" in { + final case class X(id: Id[X]) + final case class Y(id: Id[Y]) + final case class Z(id: Id[Z]) + + implicit val xy = Id.Mapper[X, Y] + implicit val yz = Id.Mapper[Y, Z] + + // Test that implicit conversions work correctly + val x = X(Id("0")) + val y = Y(x.id) + val z = Z(y.id) + val y2 = Y(z.id) + val x2 = X(y2.id) + (x2 === x) should be(true) + (y2 === y) should be(true) + + // Test that type inferrence for explicit conversions work correctly + val yid = y.id + val xid = xy(yid) + val zid = yz(yid) + (xid: Id[X]) should be(zid: Id[Z]) + } + + "Name" should "have equality and ordering working correctly" in { + + (Name[String]("foo") === Name[String]("foo")) should be(true) + (Name[String]("foo") === Name[String]("bar")) should be(false) + (Name[String]("bar") === Name[String]("foo")) should be(false) + + Seq(Name[String]("d"), Name[String]("cc"), Name[String]("a"), Name[String]("bbb")).sorted should contain + theSameElementsInOrderAs(Seq(Name[String]("a"), Name[String]("bbb"), Name[String]("cc"), Name[String]("d"))) + } + + "Revision" should "have equality working correctly" in { + + val bla = Revision[String]("85569dab-a3dc-401b-9f95-d6fb4162674b") + val foo = Revision[String]("f54b3558-bdcd-4646-a14b-8beb11f6b7c4") + + (bla === bla) should be(true) + (bla === foo) should be(false) + (foo === bla) should be(false) + } +} diff --git a/jvm/src/test/scala/xyz/driver/core/DateTest.scala b/jvm/src/test/scala/xyz/driver/core/DateTest.scala new file mode 100644 index 0000000..0432040 --- /dev/null +++ b/jvm/src/test/scala/xyz/driver/core/DateTest.scala @@ -0,0 +1,53 @@ +package xyz.driver.core + +import org.scalacheck.{Arbitrary, Gen} +import org.scalatest.prop.Checkers +import org.scalatest.{FlatSpec, Matchers} +import xyz.driver.core.date.Date + +class DateTest extends FlatSpec with Matchers with Checkers { + val dateGenerator = for { + year <- Gen.choose(0, 3000) + month <- Gen.choose(0, 11) + day <- Gen.choose(1, 31) + } yield Date(year, date.Month(month), day) + implicit val arbitraryDate = Arbitrary[Date](dateGenerator) + + "Date" should "correctly convert to and from String" in { + + import xyz.driver.core.generators.nextDate + import date._ + + for (date <- 1 to 100 map (_ => nextDate())) { + Some(date) should be(Date.fromString(date.toString)) + } + } + + it should "have ordering defined correctly" in { + Seq( + Date.fromString("2013-05-10"), + Date.fromString("2020-02-15"), + Date.fromString("2017-03-05"), + Date.fromString("2013-05-12")).sorted should + contain theSameElementsInOrderAs Seq( + Date.fromString("2013-05-10"), + Date.fromString("2013-05-12"), + Date.fromString("2017-03-05"), + Date.fromString("2020-02-15")) + + check { dates: List[Date] => + dates.sorted.sliding(2).filter(_.size == 2).forall { + case Seq(a, b) => + if (a.year == b.year) { + if (a.month == b.month) { + a.day <= b.day + } else { + a.month < b.month + } + } else { + a.year < b.year + } + } + } + } +} diff --git a/jvm/src/test/scala/xyz/driver/core/FileTest.scala b/jvm/src/test/scala/xyz/driver/core/FileTest.scala new file mode 100644 index 0000000..8728089 --- /dev/null +++ b/jvm/src/test/scala/xyz/driver/core/FileTest.scala @@ -0,0 +1,216 @@ +package xyz.driver.core + +import java.io.{File, FileInputStream} +import java.nio.file.Paths + +import org.mockito.Matchers._ +import org.mockito.Mockito._ +import org.scalatest.mockito.MockitoSugar +import org.scalatest.{FlatSpec, Matchers} +import xyz.driver.core.file.{FileSystemStorage, GcsStorage, S3Storage} + +import scala.concurrent.Await +import scala.concurrent.duration._ + +class FileTest extends FlatSpec with Matchers with MockitoSugar { + + "S3 Storage" should "create and download local files and do other operations" in { + import com.amazonaws.services.s3.AmazonS3 + import com.amazonaws.services.s3.model._ + import scala.collection.JavaConverters._ + + val tempDir = System.getProperty("java.io.tmpdir") + val sourceTestFile = generateTestLocalFile(tempDir) + val testFileName = "uploadTestFile" + + val randomFolderName = java.util.UUID.randomUUID().toString + val testDirPath = Paths.get(randomFolderName) + val testFilePath = Paths.get(randomFolderName, testFileName) + + val testBucket = Name[Bucket]("IamBucket") + + val s3PutMock = mock[PutObjectResult] + when(s3PutMock.getETag).thenReturn("IAmEtag") + + val s3ObjectSummaryMock = mock[S3ObjectSummary] + when(s3ObjectSummaryMock.getKey).thenReturn(testFileName) + when(s3ObjectSummaryMock.getETag).thenReturn("IAmEtag") + when(s3ObjectSummaryMock.getLastModified).thenReturn(new java.util.Date()) + + val s3ResultsMock = mock[ListObjectsV2Result] + when(s3ResultsMock.getNextContinuationToken).thenReturn("continuationToken") + when(s3ResultsMock.isTruncated).thenReturn( + false, // before file created it is empty (zero pages) + true, + false, // after file is uploaded it contains this one file (one page) + false) // after file is deleted it is empty (zero pages) again + when(s3ResultsMock.getObjectSummaries).thenReturn( + // before file created it is empty, `getObjectSummaries` is never called + List[S3ObjectSummary](s3ObjectSummaryMock).asJava, // after file is uploaded it contains this one file + List.empty[S3ObjectSummary].asJava + ) // after file is deleted it is empty again + + val s3ObjectMetadataMock = mock[ObjectMetadata] + val amazonS3Mock = mock[AmazonS3] + when(amazonS3Mock.listObjectsV2(any[ListObjectsV2Request]())).thenReturn(s3ResultsMock) + when(amazonS3Mock.putObject(testBucket.value, testFilePath.toString, sourceTestFile)).thenReturn(s3PutMock) + when(amazonS3Mock.getObject(any[GetObjectRequest](), any[File]())).thenReturn(s3ObjectMetadataMock) + when(amazonS3Mock.doesObjectExist(testBucket.value, testFilePath.toString)).thenReturn( + false, // before file is uploaded + true // after file is uploaded + ) + + val s3Storage = new S3Storage(amazonS3Mock, testBucket, scala.concurrent.ExecutionContext.global) + + val filesBefore = Await.result(s3Storage.list(testDirPath).run, 10 seconds) + filesBefore shouldBe empty + + val fileExistsBeforeUpload = Await.result(s3Storage.exists(testFilePath), 10 seconds) + fileExistsBeforeUpload should be(false) + + Await.result(s3Storage.upload(sourceTestFile, testFilePath), 10 seconds) + + val filesAfterUpload = Await.result(s3Storage.list(testDirPath).run, 10 seconds) + filesAfterUpload.size should be(1) + val fileExistsAfterUpload = Await.result(s3Storage.exists(testFilePath), 10 seconds) + fileExistsAfterUpload should be(true) + val uploadedFileLine = filesAfterUpload.head + uploadedFileLine.name should be(Name[File](testFileName)) + uploadedFileLine.location should be(testFilePath) + uploadedFileLine.revision.id.length should be > 0 + uploadedFileLine.lastModificationDate.millis should be > 0L + + val downloadedFile = Await.result(s3Storage.download(testFilePath).run, 10 seconds) + downloadedFile shouldBe defined + downloadedFile.foreach { + _.getAbsolutePath.endsWith(testFilePath.toString) should be(true) + } + + Await.result(s3Storage.delete(testFilePath), 10 seconds) + + val filesAfterRemoval = Await.result(s3Storage.list(testDirPath).run, 10 seconds) + filesAfterRemoval shouldBe empty + } + + "Filesystem files storage" should "create and download local files and do other operations" in { + + val tempDir = System.getProperty("java.io.tmpdir") + val sourceTestFile = generateTestLocalFile(tempDir) + + val randomFolderName = java.util.UUID.randomUUID().toString + val testDirPath = Paths.get(tempDir, randomFolderName) + val testFilePath = Paths.get(tempDir, randomFolderName, "uploadTestFile") + + val fileStorage = new FileSystemStorage(scala.concurrent.ExecutionContext.global) + + val filesBefore = Await.result(fileStorage.list(testDirPath).run, 10 seconds) + filesBefore shouldBe empty + + val fileExistsBeforeUpload = Await.result(fileStorage.exists(testFilePath), 10 seconds) + fileExistsBeforeUpload should be(false) + + Await.result(fileStorage.upload(sourceTestFile, testFilePath), 10 seconds) + + val filesAfterUpload = Await.result(fileStorage.list(testDirPath).run, 10 seconds) + filesAfterUpload.size should be(1) + + val fileExistsAfterUpload = Await.result(fileStorage.exists(testFilePath), 10 seconds) + fileExistsAfterUpload should be(true) + + val uploadedFileLine = filesAfterUpload.head + uploadedFileLine.name should be(Name[File]("uploadTestFile")) + uploadedFileLine.location should be(testFilePath) + uploadedFileLine.revision.id.length should be > 0 + uploadedFileLine.lastModificationDate.millis should be > 0L + + val downloadedFile = Await.result(fileStorage.download(testFilePath).run, 10 seconds) + downloadedFile shouldBe defined + downloadedFile.map(_.getAbsolutePath) should be(Some(testFilePath.toString)) + + Await.result(fileStorage.delete(testFilePath), 10 seconds) + + val filesAfterRemoval = Await.result(fileStorage.list(testDirPath).run, 10 seconds) + filesAfterRemoval shouldBe empty + } + + "Google Cloud Storage" should "upload and download files" in { + import com.google.api.gax.paging.Page + import com.google.cloud.storage.{Blob, Bucket, Storage} + import Bucket.BlobWriteOption + import Storage.BlobListOption + import scala.collection.JavaConverters._ + + val tempDir = System.getProperty("java.io.tmpdir") + val sourceTestFile = generateTestLocalFile(tempDir) + val testFileName = "uploadTestFile" + + val randomFolderName = java.util.UUID.randomUUID().toString + val testDirPath = Paths.get(randomFolderName) + val testFilePath = Paths.get(randomFolderName, testFileName) + + val testBucket = Name[Bucket]("IamBucket") + val gcsMock = mock[Storage] + val pageMock = mock[Page[Blob]] + val bucketMock = mock[Bucket] + val blobMock = mock[Blob] + + when(blobMock.getName).thenReturn(testFileName) + when(blobMock.getGeneration).thenReturn(1000L) + when(blobMock.getUpdateTime).thenReturn(1493422254L) + when(blobMock.getSize).thenReturn(12345L) + when(blobMock.getContent()).thenReturn(Array[Byte](1, 2, 3)) + + val gcsStorage = new GcsStorage(gcsMock, testBucket, scala.concurrent.ExecutionContext.global) + + when(pageMock.iterateAll()).thenReturn( + Iterable[Blob]().asJava, + Iterable[Blob](blobMock).asJava, + Iterable[Blob]().asJava + ) + when(gcsMock.list(testBucket.value, BlobListOption.currentDirectory(), BlobListOption.prefix(s"$testDirPath/"))) + .thenReturn(pageMock) + when(gcsMock.get(testBucket.value, testFilePath.toString)).thenReturn( + null, // before file is uploaded + blobMock // after file is uploaded + ) + + val filesBefore = Await.result(gcsStorage.list(testDirPath).run, 10 seconds) + filesBefore shouldBe empty + + val fileExistsBeforeUpload = Await.result(gcsStorage.exists(testFilePath), 10 seconds) + fileExistsBeforeUpload should be(false) + + when(gcsMock.get(testBucket.value)).thenReturn(bucketMock) + when(gcsMock.get(testBucket.value, testFilePath.toString)).thenReturn(blobMock) + when(bucketMock.create(org.mockito.Matchers.eq(testFileName), any[FileInputStream], any[BlobWriteOption])) + .thenReturn(blobMock) + + Await.result(gcsStorage.upload(sourceTestFile, testFilePath), 10 seconds) + + val filesAfterUpload = Await.result(gcsStorage.list(testDirPath).run, 10 seconds) + filesAfterUpload.size should be(1) + + val fileExistsAfterUpload = Await.result(gcsStorage.exists(testFilePath), 10 seconds) + fileExistsAfterUpload should be(true) + + val downloadedFile = Await.result(gcsStorage.download(testFilePath).run, 10 seconds) + downloadedFile shouldBe defined + downloadedFile.foreach { + _.getAbsolutePath should endWith(testFilePath.toString) + } + + Await.result(gcsStorage.delete(testFilePath), 10 seconds) + + val filesAfterRemoval = Await.result(gcsStorage.list(testDirPath).run, 10 seconds) + filesAfterRemoval shouldBe empty + } + + private def generateTestLocalFile(path: String): File = { + val randomSourceFolderName = java.util.UUID.randomUUID().toString + val sourceTestFile = new File(Paths.get(path, randomSourceFolderName, "uploadTestFile").toString) + sourceTestFile.getParentFile.mkdirs() should be(true) + sourceTestFile.createNewFile() should be(true) + using(new java.io.PrintWriter(sourceTestFile)) { _.append("Test File Contents") } + sourceTestFile + } +} diff --git a/jvm/src/test/scala/xyz/driver/core/GeneratorsTest.scala b/jvm/src/test/scala/xyz/driver/core/GeneratorsTest.scala new file mode 100644 index 0000000..7e740a4 --- /dev/null +++ b/jvm/src/test/scala/xyz/driver/core/GeneratorsTest.scala @@ -0,0 +1,264 @@ +package xyz.driver.core + +import org.scalatest.{Assertions, FlatSpec, Matchers} + +import scala.collection.immutable.IndexedSeq + +class GeneratorsTest extends FlatSpec with Matchers with Assertions { + import generators._ + + "Generators" should "be able to generate com.drivergrp.core.Id identifiers" in { + + val generatedId1 = nextId[String]() + val generatedId2 = nextId[String]() + val generatedId3 = nextId[Long]() + + generatedId1.length should be >= 0 + generatedId2.length should be >= 0 + generatedId3.length should be >= 0 + generatedId1 should not be generatedId2 + generatedId2 should !==(generatedId3) + } + + it should "be able to generate com.drivergrp.core.Id identifiers with max value" in { + + val generatedLimitedId1 = nextId[String](5) + val generatedLimitedId2 = nextId[String](4) + val generatedLimitedId3 = nextId[Long](3) + + generatedLimitedId1.length should be >= 0 + generatedLimitedId1.length should be < 6 + generatedLimitedId2.length should be >= 0 + generatedLimitedId2.length should be < 5 + generatedLimitedId3.length should be >= 0 + generatedLimitedId3.length should be < 4 + generatedLimitedId1 should not be generatedLimitedId2 + generatedLimitedId2 should !==(generatedLimitedId3) + } + + it should "be able to generate com.drivergrp.core.Name names" in { + + Seq.fill(10)(nextName[String]()).distinct.size should be > 1 + nextName[String]().value.length should be >= 0 + + val fixedLengthName = nextName[String](10) + fixedLengthName.length should be <= 10 + assert(!fixedLengthName.value.exists(_.isControl)) + } + + it should "be able to generate com.drivergrp.core.NonEmptyName with non empty strings" in { + + assert(nextNonEmptyName[String]().value.value.nonEmpty) + } + + it should "be able to generate proper UUIDs" in { + + nextUuid() should not be nextUuid() + nextUuid().toString.length should be(36) + } + + it should "be able to generate new Revisions" in { + + nextRevision[String]() should not be nextRevision[String]() + nextRevision[String]().id.length should be > 0 + } + + it should "be able to generate strings" in { + + nextString() should not be nextString() + nextString().length should be >= 0 + + val fixedLengthString = nextString(20) + fixedLengthString.length should be <= 20 + assert(!fixedLengthString.exists(_.isControl)) + } + + it should "be able to generate strings non-empty strings whic are non empty" in { + + assert(nextNonEmptyString().value.nonEmpty) + } + + it should "be able to generate options which are sometimes have values and sometimes not" in { + + val generatedOption = nextOption("2") + + generatedOption should not contain "1" + assert(generatedOption === Some("2") || generatedOption === None) + } + + it should "be able to generate a pair of two generated values" in { + + val constantPair = nextPair("foo", 1L) + constantPair._1 should be("foo") + constantPair._2 should be(1L) + + val generatedPair = nextPair(nextId[Int](), nextName[Int]()) + + generatedPair._1.length should be > 0 + generatedPair._2.length should be > 0 + + nextPair(nextId[Int](), nextName[Int]()) should not be + nextPair(nextId[Int](), nextName[Int]()) + } + + it should "be able to generate a triad of two generated values" in { + + val constantTriad = nextTriad("foo", "bar", 1L) + constantTriad._1 should be("foo") + constantTriad._2 should be("bar") + constantTriad._3 should be(1L) + + val generatedTriad = nextTriad(nextId[Int](), nextName[Int](), nextBigDecimal()) + + generatedTriad._1.length should be > 0 + generatedTriad._2.length should be > 0 + generatedTriad._3 should be >= BigDecimal(0.00) + + nextTriad(nextId[Int](), nextName[Int](), nextBigDecimal()) should not be + nextTriad(nextId[Int](), nextName[Int](), nextBigDecimal()) + } + + it should "be able to generate a time value" in { + + val generatedTime = nextTime() + val currentTime = System.currentTimeMillis() + + generatedTime.millis should be >= 0L + generatedTime.millis should be <= currentTime + } + + it should "be able to generate a time range value" in { + + val generatedTimeRange = nextTimeRange() + val currentTime = System.currentTimeMillis() + + generatedTimeRange.start.millis should be >= 0L + generatedTimeRange.start.millis should be <= currentTime + generatedTimeRange.end.millis should be >= 0L + generatedTimeRange.end.millis should be <= currentTime + generatedTimeRange.start.millis should be <= generatedTimeRange.end.millis + } + + it should "be able to generate a BigDecimal value" in { + + val defaultGeneratedBigDecimal = nextBigDecimal() + + defaultGeneratedBigDecimal should be >= BigDecimal(0.00) + defaultGeneratedBigDecimal should be <= BigDecimal(1000000.00) + defaultGeneratedBigDecimal.precision should be(2) + + val unitIntervalBigDecimal = nextBigDecimal(1.00, 8) + + unitIntervalBigDecimal should be >= BigDecimal(0.00) + unitIntervalBigDecimal should be <= BigDecimal(1.00) + unitIntervalBigDecimal.precision should be(8) + } + + it should "be able to generate a specific value from a set of values" in { + + val possibleOptions = Set(1, 3, 5, 123, 0, 9) + + val pick1 = generators.oneOf(possibleOptions) + val pick2 = generators.oneOf(possibleOptions) + val pick3 = generators.oneOf(possibleOptions) + + possibleOptions should contain(pick1) + possibleOptions should contain(pick2) + possibleOptions should contain(pick3) + + val pick4 = generators.oneOf(1, 3, 5, 123, 0, 9) + val pick5 = generators.oneOf(1, 3, 5, 123, 0, 9) + val pick6 = generators.oneOf(1, 3, 5, 123, 0, 9) + + possibleOptions should contain(pick4) + possibleOptions should contain(pick5) + possibleOptions should contain(pick6) + + Set(pick1, pick2, pick3, pick4, pick5, pick6).size should be >= 1 + } + + it should "be able to generate a specific value from an enumeratum enum" in { + + import enumeratum._ + sealed trait TestEnumValue extends EnumEntry + object TestEnum extends Enum[TestEnumValue] { + case object Value1 extends TestEnumValue + case object Value2 extends TestEnumValue + case object Value3 extends TestEnumValue + case object Value4 extends TestEnumValue + val values: IndexedSeq[TestEnumValue] = findValues + } + + val picks = (1 to 100).map(_ => generators.oneOf(TestEnum)) + + TestEnum.values should contain allElementsOf picks + picks.toSet.size should be >= 1 + } + + it should "be able to generate array with values generated by generators" in { + + val arrayOfTimes = arrayOf(nextTime(), 16) + arrayOfTimes.length should be <= 16 + + val arrayOfBigDecimals = arrayOf(nextBigDecimal(), 8) + arrayOfBigDecimals.length should be <= 8 + } + + it should "be able to generate seq with values generated by generators" in { + + val seqOfTimes = seqOf(nextTime(), 16) + seqOfTimes.size should be <= 16 + + val seqOfBigDecimals = seqOf(nextBigDecimal(), 8) + seqOfBigDecimals.size should be <= 8 + } + + it should "be able to generate vector with values generated by generators" in { + + val vectorOfTimes = vectorOf(nextTime(), 16) + vectorOfTimes.size should be <= 16 + + val vectorOfStrings = seqOf(nextString(), 8) + vectorOfStrings.size should be <= 8 + } + + it should "be able to generate list with values generated by generators" in { + + val listOfTimes = listOf(nextTime(), 16) + listOfTimes.size should be <= 16 + + val listOfBigDecimals = seqOf(nextBigDecimal(), 8) + listOfBigDecimals.size should be <= 8 + } + + it should "be able to generate set with values generated by generators" in { + + val setOfTimes = vectorOf(nextTime(), 16) + setOfTimes.size should be <= 16 + + val setOfBigDecimals = seqOf(nextBigDecimal(), 8) + setOfBigDecimals.size should be <= 8 + } + + it should "be able to generate maps with keys and values generated by generators" in { + + val generatedConstantMap = mapOf("key", 123, 10) + generatedConstantMap.size should be <= 1 + assert(generatedConstantMap.keys.forall(_ == "key")) + assert(generatedConstantMap.values.forall(_ == 123)) + + val generatedMap = mapOf(nextString(10), nextBigDecimal(), 10) + assert(generatedMap.keys.forall(_.length <= 10)) + assert(generatedMap.values.forall(_ >= BigDecimal(0.00))) + } + + it should "compose deeply" in { + + val generatedNestedMap = mapOf(nextString(10), nextPair(nextBigDecimal(), nextOption(123)), 10) + + generatedNestedMap.size should be <= 10 + generatedNestedMap.keySet.size should be <= 10 + generatedNestedMap.values.size should be <= 10 + assert(generatedNestedMap.values.forall(value => !value._2.exists(_ != 123))) + } +} diff --git a/jvm/src/test/scala/xyz/driver/core/JsonTest.scala b/jvm/src/test/scala/xyz/driver/core/JsonTest.scala new file mode 100644 index 0000000..fed2a9d --- /dev/null +++ b/jvm/src/test/scala/xyz/driver/core/JsonTest.scala @@ -0,0 +1,343 @@ +package xyz.driver.core + +import java.net.InetAddress + +import enumeratum._ +import eu.timepit.refined.collection.NonEmpty +import eu.timepit.refined.numeric.Positive +import eu.timepit.refined.refineMV +import org.scalatest.{FlatSpec, Matchers} +import xyz.driver.core.json._ +import xyz.driver.core.time.provider.SystemTimeProvider +import spray.json._ +import xyz.driver.core.TestTypes.CustomGADT +import xyz.driver.core.auth.AuthCredentials +import xyz.driver.core.domain.{Email, PhoneNumber} +import xyz.driver.core.json.enumeratum.HasJsonFormat +import xyz.driver.core.tagging.Taggable +import xyz.driver.core.time.TimeOfDay + +import scala.collection.immutable.IndexedSeq + +class JsonTest extends FlatSpec with Matchers { + import DefaultJsonProtocol._ + + "Json format for Id" should "read and write correct JSON" in { + + val referenceId = Id[String]("1312-34A") + + val writtenJson = json.idFormat.write(referenceId) + writtenJson.prettyPrint should be("\"1312-34A\"") + + val parsedId = json.idFormat.read(writtenJson) + parsedId should be(referenceId) + } + + "Json format for @@" should "read and write correct JSON" in { + trait Irrelevant + val reference = Id[JsonTest]("SomeID").tagged[Irrelevant] + + val format = json.taggedFormat[Id[JsonTest], Irrelevant] + + val writtenJson = format.write(reference) + writtenJson shouldBe JsString("SomeID") + + val parsedId: Id[JsonTest] @@ Irrelevant = format.read(writtenJson) + parsedId shouldBe reference + } + + "Json format for Name" should "read and write correct JSON" in { + + val referenceName = Name[String]("Homer") + + val writtenJson = json.nameFormat.write(referenceName) + writtenJson.prettyPrint should be("\"Homer\"") + + val parsedName = json.nameFormat.read(writtenJson) + parsedName should be(referenceName) + } + + "Json format for NonEmptyName" should "read and write correct JSON" in { + + val jsonFormat = json.nonEmptyNameFormat[String] + + val referenceNonEmptyName = NonEmptyName[String](refineMV[NonEmpty]("Homer")) + + val writtenJson = jsonFormat.write(referenceNonEmptyName) + writtenJson.prettyPrint should be("\"Homer\"") + + val parsedNonEmptyName = jsonFormat.read(writtenJson) + parsedNonEmptyName should be(referenceNonEmptyName) + } + + "Json format for Time" should "read and write correct JSON" in { + + val referenceTime = new SystemTimeProvider().currentTime() + + val writtenJson = json.timeFormat.write(referenceTime) + writtenJson.prettyPrint should be("{\n \"timestamp\": " + referenceTime.millis + "\n}") + + val parsedTime = json.timeFormat.read(writtenJson) + parsedTime should be(referenceTime) + } + + "Json format for TimeOfDay" should "read and write correct JSON" in { + val utcTimeZone = java.util.TimeZone.getTimeZone("UTC") + val referenceTimeOfDay = TimeOfDay.parseTimeString(utcTimeZone)("08:00:00") + val writtenJson = json.timeOfDayFormat.write(referenceTimeOfDay) + writtenJson should be("""{"localTime":"08:00:00","timeZone":"UTC"}""".parseJson) + val parsed = json.timeOfDayFormat.read(writtenJson) + parsed should be(referenceTimeOfDay) + } + + "Json format for Date" should "read and write correct JSON" in { + import date._ + + val referenceDate = Date(1941, Month.DECEMBER, 7) + + val writtenJson = json.dateFormat.write(referenceDate) + writtenJson.prettyPrint should be("\"1941-12-07\"") + + val parsedDate = json.dateFormat.read(writtenJson) + parsedDate should be(referenceDate) + } + + "Json format for Revision" should "read and write correct JSON" in { + + val referenceRevision = Revision[String]("037e2ec0-8901-44ac-8e53-6d39f6479db4") + + val writtenJson = json.revisionFormat.write(referenceRevision) + writtenJson.prettyPrint should be("\"" + referenceRevision.id + "\"") + + val parsedRevision = json.revisionFormat.read(writtenJson) + parsedRevision should be(referenceRevision) + } + + "Json format for Email" should "read and write correct JSON" in { + + val referenceEmail = Email("test", "drivergrp.com") + + val writtenJson = json.emailFormat.write(referenceEmail) + writtenJson should be("\"test@drivergrp.com\"".parseJson) + + val parsedEmail = json.emailFormat.read(writtenJson) + parsedEmail should be(referenceEmail) + } + + "Json format for PhoneNumber" should "read and write correct JSON" in { + + val referencePhoneNumber = PhoneNumber("1", "4243039608") + + val writtenJson = json.phoneNumberFormat.write(referencePhoneNumber) + writtenJson should be("""{"countryCode":"1","number":"4243039608"}""".parseJson) + + val parsedPhoneNumber = json.phoneNumberFormat.read(writtenJson) + parsedPhoneNumber should be(referencePhoneNumber) + } + + "Json format for ADT mappings" should "read and write correct JSON" in { + + sealed trait EnumVal + case object Val1 extends EnumVal + case object Val2 extends EnumVal + case object Val3 extends EnumVal + + val format = new EnumJsonFormat[EnumVal]("a" -> Val1, "b" -> Val2, "c" -> Val3) + + val referenceEnumValue1 = Val2 + val referenceEnumValue2 = Val3 + + val writtenJson1 = format.write(referenceEnumValue1) + writtenJson1.prettyPrint should be("\"b\"") + + val writtenJson2 = format.write(referenceEnumValue2) + writtenJson2.prettyPrint should be("\"c\"") + + val parsedEnumValue1 = format.read(writtenJson1) + val parsedEnumValue2 = format.read(writtenJson2) + + parsedEnumValue1 should be(referenceEnumValue1) + parsedEnumValue2 should be(referenceEnumValue2) + } + + "Json format for Enums (external)" should "read and write correct JSON" in { + + sealed trait MyEnum extends EnumEntry + object MyEnum extends Enum[MyEnum] { + case object Val1 extends MyEnum + case object `Val 2` extends MyEnum + case object `Val/3` extends MyEnum + + val values: IndexedSeq[MyEnum] = findValues + } + + val format = new enumeratum.EnumJsonFormat(MyEnum) + + val referenceEnumValue1 = MyEnum.`Val 2` + val referenceEnumValue2 = MyEnum.`Val/3` + + val writtenJson1 = format.write(referenceEnumValue1) + writtenJson1 shouldBe JsString("Val 2") + + val writtenJson2 = format.write(referenceEnumValue2) + writtenJson2 shouldBe JsString("Val/3") + + val parsedEnumValue1 = format.read(writtenJson1) + val parsedEnumValue2 = format.read(writtenJson2) + + parsedEnumValue1 shouldBe referenceEnumValue1 + parsedEnumValue2 shouldBe referenceEnumValue2 + + intercept[DeserializationException] { + format.read(JsString("Val4")) + }.getMessage shouldBe "Unexpected value Val4. Expected one of: [Val1, Val 2, Val/3]" + } + + "Json format for Enums (automatic)" should "read and write correct JSON and not require import" in { + + sealed trait MyEnum extends EnumEntry + object MyEnum extends Enum[MyEnum] with HasJsonFormat[MyEnum] { + case object Val1 extends MyEnum + case object `Val 2` extends MyEnum + case object `Val/3` extends MyEnum + + val values: IndexedSeq[MyEnum] = findValues + } + + val referenceEnumValue1: MyEnum = MyEnum.`Val 2` + val referenceEnumValue2: MyEnum = MyEnum.`Val/3` + + val writtenJson1 = referenceEnumValue1.toJson + writtenJson1 shouldBe JsString("Val 2") + + val writtenJson2 = referenceEnumValue2.toJson + writtenJson2 shouldBe JsString("Val/3") + + import spray.json._ + + val parsedEnumValue1 = writtenJson1.prettyPrint.parseJson.convertTo[MyEnum] + val parsedEnumValue2 = writtenJson2.prettyPrint.parseJson.convertTo[MyEnum] + + parsedEnumValue1 should be(referenceEnumValue1) + parsedEnumValue2 should be(referenceEnumValue2) + + intercept[DeserializationException] { + JsString("Val4").convertTo[MyEnum] + }.getMessage shouldBe "Unexpected value Val4. Expected one of: [Val1, Val 2, Val/3]" + } + + // Should be defined outside of case to have a TypeTag + case class CustomWrapperClass(value: Int) + + "Json format for Value classes" should "read and write correct JSON" in { + + val format = new ValueClassFormat[CustomWrapperClass](v => BigDecimal(v.value), d => CustomWrapperClass(d.toInt)) + + val referenceValue1 = CustomWrapperClass(-2) + val referenceValue2 = CustomWrapperClass(10) + + val writtenJson1 = format.write(referenceValue1) + writtenJson1.prettyPrint should be("-2") + + val writtenJson2 = format.write(referenceValue2) + writtenJson2.prettyPrint should be("10") + + val parsedValue1 = format.read(writtenJson1) + val parsedValue2 = format.read(writtenJson2) + + parsedValue1 should be(referenceValue1) + parsedValue2 should be(referenceValue2) + } + + "Json format for classes GADT" should "read and write correct JSON" in { + + import CustomGADT._ + import DefaultJsonProtocol._ + implicit val case1Format = jsonFormat1(GadtCase1) + implicit val case2Format = jsonFormat1(GadtCase2) + implicit val case3Format = jsonFormat1(GadtCase3) + + val format = GadtJsonFormat.create[CustomGADT]("gadtTypeField") { + case _: CustomGADT.GadtCase1 => "case1" + case _: CustomGADT.GadtCase2 => "case2" + case _: CustomGADT.GadtCase3 => "case3" + } { + case "case1" => case1Format + case "case2" => case2Format + case "case3" => case3Format + } + + val referenceValue1 = CustomGADT.GadtCase1("4") + val referenceValue2 = CustomGADT.GadtCase2("Hi!") + + val writtenJson1 = format.write(referenceValue1) + writtenJson1 should be("{\n \"field\": \"4\",\n\"gadtTypeField\": \"case1\"\n}".parseJson) + + val writtenJson2 = format.write(referenceValue2) + writtenJson2 should be("{\"field\":\"Hi!\",\"gadtTypeField\":\"case2\"}".parseJson) + + val parsedValue1 = format.read(writtenJson1) + val parsedValue2 = format.read(writtenJson2) + + parsedValue1 should be(referenceValue1) + parsedValue2 should be(referenceValue2) + } + + "Json format for a Refined value" should "read and write correct JSON" in { + + val jsonFormat = json.refinedJsonFormat[Int, Positive] + + val referenceRefinedNumber = refineMV[Positive](42) + + val writtenJson = jsonFormat.write(referenceRefinedNumber) + writtenJson should be("42".parseJson) + + val parsedRefinedNumber = jsonFormat.read(writtenJson) + parsedRefinedNumber should be(referenceRefinedNumber) + } + + "InetAddress format" should "read and write correct JSON" in { + val address = InetAddress.getByName("127.0.0.1") + val json = inetAddressFormat.write(address) + + json shouldBe JsString("127.0.0.1") + + val parsed = inetAddressFormat.read(json) + parsed shouldBe address + } + + it should "throw a DeserializationException for an invalid IP Address" in { + assertThrows[DeserializationException] { + val invalidAddress = JsString("foobar") + inetAddressFormat.read(invalidAddress) + } + } + + "AuthCredentials format" should "read and write correct JSON" in { + val email = Email("someone", "noehere.com") + val phoneId = PhoneNumber.parse("1 207 8675309") + val password = "nopassword" + + phoneId.isDefined should be(true) // test this real quick + + val emailAuth = AuthCredentials(email.toString, password) + val pnAuth = AuthCredentials(phoneId.get.toString, password) + + val emailWritten = authCredentialsFormat.write(emailAuth) + emailWritten should be("""{"identifier":"someone@noehere.com","password":"nopassword"}""".parseJson) + + val phoneWritten = authCredentialsFormat.write(pnAuth) + phoneWritten should be("""{"identifier":"+1 2078675309","password":"nopassword"}""".parseJson) + + val identifierEmailParsed = + authCredentialsFormat.read("""{"identifier":"someone@nowhere.com","password":"nopassword"}""".parseJson) + var written = authCredentialsFormat.write(identifierEmailParsed) + written should be("{\"identifier\":\"someone@nowhere.com\",\"password\":\"nopassword\"}".parseJson) + + val emailEmailParsed = + authCredentialsFormat.read("""{"email":"someone@nowhere.com","password":"nopassword"}""".parseJson) + written = authCredentialsFormat.write(emailEmailParsed) + written should be("{\"identifier\":\"someone@nowhere.com\",\"password\":\"nopassword\"}".parseJson) + + } +} diff --git a/jvm/src/test/scala/xyz/driver/core/MessagesTest.scala b/jvm/src/test/scala/xyz/driver/core/MessagesTest.scala new file mode 100644 index 0000000..07b0158 --- /dev/null +++ b/jvm/src/test/scala/xyz/driver/core/MessagesTest.scala @@ -0,0 +1,85 @@ +package xyz.driver.core + +import java.util.Locale + +import com.typesafe.config.{ConfigException, ConfigFactory} +import org.scalatest.{FlatSpec, Matchers} +import xyz.driver.core.messages.Messages +import xyz.driver.core.logging.NoLogger + +import scala.collection.JavaConverters._ + +class MessagesTest extends FlatSpec with Matchers { + + val englishLocaleMessages = + Map("en.greeting" -> "Hello {0}!", "en.greetingFullName" -> "Hello {0} {1} {2}!", "en.hello" -> "Hello world!") + + "Messages" should "read messages from config and format with parameters" in { + + val messagesConfig = ConfigFactory.parseMap(englishLocaleMessages.asJava) + + val messages = Messages.messages(messagesConfig, NoLogger, Locale.US) + + messages("hello") should be("Hello world!") + messages("greeting", "Homer") should be("Hello Homer!") + messages("greetingFullName", "Homer", "J", "Simpson") should be("Hello Homer J Simpson!") + } + + it should "be able to read messages for different locales" in { + + val messagesConfig = ConfigFactory.parseMap( + (englishLocaleMessages ++ Map( + "zh.hello" -> "你好,世界!", + "zh.greeting" -> "你好,{0}!", + "zh.greetingFullName" -> "你好,{0} {1} {2}!" + )).asJava) + + val englishMessages = Messages.messages(messagesConfig, NoLogger, Locale.US) + val englishMessagesToo = Messages.messages(messagesConfig, NoLogger, Locale.ENGLISH) + val chineseMessages = Messages.messages(messagesConfig, NoLogger, Locale.CHINESE) + + englishMessages("hello") should be("Hello world!") + englishMessages("greeting", "Homer") should be("Hello Homer!") + englishMessages("greetingFullName", "Homer", "J", "Simpson") should be("Hello Homer J Simpson!") + + englishMessagesToo("hello") should be(englishMessages("hello")) + englishMessagesToo("greeting", "Homer") should be(englishMessages("greeting", "Homer")) + englishMessagesToo("greetingFullName", "Homer", "J", "Simpson") should be( + englishMessages("greetingFullName", "Homer", "J", "Simpson")) + + chineseMessages("hello") should be("你好,世界!") + chineseMessages("greeting", "Homer") should be("你好,Homer!") + chineseMessages("greetingFullName", "Homer", "J", "Simpson") should be("你好,Homer J Simpson!") + } + + it should "raise exception when locale is not available" in { + + val messagesConfig = ConfigFactory.parseMap(englishLocaleMessages.asJava) + + an[ConfigException.Missing] should be thrownBy + Messages.messages(messagesConfig, NoLogger, Locale.GERMAN) + } + + it should "be able to read nested keys in multiple forms" in { + + val configString = + """ + | en { + | foo.bar = "Foo Bar" + | + | baz { + | boo = "Baz Boo" + | booFormat = "Baz Boo {0}" + | } + | } + """.stripMargin + + val messagesConfig = ConfigFactory.parseString(configString) + + val messages = Messages.messages(messagesConfig, NoLogger, Locale.US) + + messages("foo.bar") should be("Foo Bar") + messages("baz.boo") should be("Baz Boo") + messages("baz.booFormat", "Test") should be("Baz Boo Test") + } +} diff --git a/jvm/src/test/scala/xyz/driver/core/PhoneNumberTest.scala b/jvm/src/test/scala/xyz/driver/core/PhoneNumberTest.scala new file mode 100644 index 0000000..384c7be --- /dev/null +++ b/jvm/src/test/scala/xyz/driver/core/PhoneNumberTest.scala @@ -0,0 +1,79 @@ +package xyz.driver.core + +import org.scalatest.{FlatSpec, Matchers} +import xyz.driver.core.domain.PhoneNumber + +class PhoneNumberTest extends FlatSpec with Matchers { + + "PhoneNumber.parse" should "recognize US numbers in international format, ignoring non-digits" in { + // format: off + val numbers = List( + "+18005252225", + "+1 800 525 2225", + "+1 (800) 525-2225", + "+1.800.525.2225") + // format: on + + val parsed = numbers.flatMap(PhoneNumber.parse) + + parsed should have size numbers.size + parsed should contain only PhoneNumber("1", "8005252225") + } + + it should "recognize US numbers without the plus sign" in { + PhoneNumber.parse("18005252225") shouldBe Some(PhoneNumber("1", "8005252225")) + } + + it should "recognize US numbers without country code" in { + // format: off + val numbers = List( + "8005252225", + "800 525 2225", + "(800) 525-2225", + "800.525.2225") + // format: on + + val parsed = numbers.flatMap(PhoneNumber.parse) + + parsed should have size numbers.size + parsed should contain only PhoneNumber("1", "8005252225") + } + + it should "recognize CN numbers in international format" in { + PhoneNumber.parse("+868005252225") shouldBe Some(PhoneNumber("86", "8005252225")) + PhoneNumber.parse("+86 134 52 52 2256") shouldBe Some(PhoneNumber("86", "13452522256")) + } + + it should "return None on numbers that are shorter than the minimum number of digits for the country (i.e. US - 10, AR - 11)" in { + withClue("US and CN numbers are 10 digits - 9 digit (and shorter) numbers should not fit") { + // format: off + val numbers = List( + "+1 800 525-222", + "+1 800 525-2", + "+86 800 525-222", + "+86 800 525-2") + // format: on + + numbers.flatMap(PhoneNumber.parse) shouldBe empty + } + + withClue("Argentinian numbers are 11 digits (when prefixed with 0) - 10 digit numbers shouldn't fit") { + // format: off + val numbers = List( + "+54 011 525-22256", + "+54 011 525-2225", + "+54 011 525-222") + // format: on + + numbers.flatMap(PhoneNumber.parse) should contain theSameElementsAs List(PhoneNumber("54", "1152522256")) + } + } + + it should "return None on numbers that are longer than the maximum number of digits for the country (i.e. DK - 8, CN - 11)" in { + val numbers = List("+45 27 45 25 22", "+45 135 525 223", "+86 134 525 22256", "+86 135 525 22256 7") + + numbers.flatMap(PhoneNumber.parse) should contain theSameElementsAs + List(PhoneNumber("45", "27452522"), PhoneNumber("86", "13452522256")) + } + +} diff --git a/jvm/src/test/scala/xyz/driver/core/TestTypes.scala b/jvm/src/test/scala/xyz/driver/core/TestTypes.scala new file mode 100644 index 0000000..bb25deb --- /dev/null +++ b/jvm/src/test/scala/xyz/driver/core/TestTypes.scala @@ -0,0 +1,14 @@ +package xyz.driver.core + +object TestTypes { + + sealed trait CustomGADT { + val field: String + } + + object CustomGADT { + final case class GadtCase1(field: String) extends CustomGADT + final case class GadtCase2(field: String) extends CustomGADT + final case class GadtCase3(field: String) extends CustomGADT + } +} diff --git a/jvm/src/test/scala/xyz/driver/core/TimeTest.scala b/jvm/src/test/scala/xyz/driver/core/TimeTest.scala new file mode 100644 index 0000000..7a888b6 --- /dev/null +++ b/jvm/src/test/scala/xyz/driver/core/TimeTest.scala @@ -0,0 +1,142 @@ +package xyz.driver.core + +import java.util.TimeZone + +import org.scalacheck.Arbitrary._ +import org.scalacheck.Prop.BooleanOperators +import org.scalacheck.{Arbitrary, Gen} +import org.scalatest.prop.Checkers +import org.scalatest.{FlatSpec, Matchers} +import xyz.driver.core.date.Month +import xyz.driver.core.time.{Time, _} + +import scala.concurrent.duration._ + +class TimeTest extends FlatSpec with Matchers with Checkers { + + implicit val arbDuration = Arbitrary[Duration](Gen.chooseNum(0L, 9999999999L).map(_.milliseconds)) + implicit val arbTime = Arbitrary[Time](Gen.chooseNum(0L, 9999999999L).map(millis => Time(millis))) + + "Time" should "have correct methods to compare" in { + + Time(234L).isAfter(Time(123L)) should be(true) + Time(123L).isAfter(Time(123L)) should be(false) + Time(123L).isAfter(Time(234L)) should be(false) + + check((a: Time, b: Time) => (a.millis > b.millis) ==> a.isAfter(b)) + + Time(234L).isBefore(Time(123L)) should be(false) + Time(123L).isBefore(Time(123L)) should be(false) + Time(123L).isBefore(Time(234L)) should be(true) + + check { (a: Time, b: Time) => + (a.millis < b.millis) ==> a.isBefore(b) + } + } + + it should "not modify time" in { + + Time(234L).millis should be(234L) + + check { millis: Long => + Time(millis).millis == millis + } + } + + it should "support arithmetic with scala.concurrent.duration" in { + + Time(123L).advanceBy(0 minutes).millis should be(123L) + Time(123L).advanceBy(1 second).millis should be(123L + Second) + Time(123L).advanceBy(4 days).millis should be(123L + 4 * Days) + + check { (time: Time, duration: Duration) => + time.advanceBy(duration).millis == (time.millis + duration.toMillis) + } + } + + it should "have ordering defined correctly" in { + + Seq(Time(321L), Time(123L), Time(231L)).sorted should + contain theSameElementsInOrderAs Seq(Time(123L), Time(231L), Time(321L)) + + check { times: List[Time] => + times.sorted.sliding(2).filter(_.size == 2).forall { + case Seq(a, b) => + a.millis <= b.millis + } + } + } + + it should "reset to the start of the period, e.g. month" in { + + startOfMonth(Time(1468937089834L)) should be(Time(1467381889834L)) + startOfMonth(Time(1467381889834L)) should be(Time(1467381889834L)) // idempotent + } + + it should "have correct textual representations" in { + import java.util.Locale + import java.util.Locale._ + Locale.setDefault(US) + + textualDate(TimeZone.getTimeZone("EDT"))(Time(1468937089834L)) should be("July 19, 2016") + textualTime(TimeZone.getTimeZone("PDT"))(Time(1468937089834L)) should be("Jul 19, 2016 02:04:49 PM") + } + + "TimeRange" should "have duration defined as a difference of start and end times" in { + + TimeRange(Time(321L), Time(432L)).duration should be(111.milliseconds) + TimeRange(Time(432L), Time(321L)).duration should be((-111).milliseconds) + TimeRange(Time(333L), Time(333L)).duration should be(0.milliseconds) + } + + "Time" should "use TimeZone correctly when converting to Date" in { + + val EST = java.util.TimeZone.getTimeZone("EST") + val PST = java.util.TimeZone.getTimeZone("PST") + + val timestamp = { + import java.util.Calendar + val cal = Calendar.getInstance(EST) + cal.set(Calendar.HOUR_OF_DAY, 1) + Time(cal.getTime().getTime()) + } + + textualDate(EST)(timestamp) should not be textualDate(PST)(timestamp) + timestamp.toDate(EST) should not be timestamp.toDate(PST) + } + + "TimeOfDay" should "be created from valid strings and convert to java.sql.Time" in { + val s = "07:30:45" + val defaultTimeZone = TimeZone.getDefault() + val todFactory = TimeOfDay.parseTimeString(defaultTimeZone)(_) + val tod = todFactory(s) + tod.timeString shouldBe s + tod.timeZoneString shouldBe defaultTimeZone.getID + val sqlTime = tod.toTime + sqlTime.toLocalTime shouldBe tod.localTime + a[java.time.format.DateTimeParseException] should be thrownBy { + val illegal = "7:15" + todFactory(illegal) + } + } + + "TimeOfDay" should "have correct temporal relationships" in { + val s = "07:30:45" + val t = "09:30:45" + val pst = TimeZone.getTimeZone("America/Los_Angeles") + val est = TimeZone.getTimeZone("America/New_York") + val pstTodFactory = TimeOfDay.parseTimeString(pst)(_) + val estTodFactory = TimeOfDay.parseTimeString(est)(_) + val day = 1 + val month = Month.JANUARY + val year = 2018 + val sTodPst = pstTodFactory(s) + val sTodPst2 = pstTodFactory(s) + val tTodPst = pstTodFactory(t) + val tTodEst = estTodFactory(t) + sTodPst.isBefore(tTodPst, day, month, year) shouldBe true + tTodPst.isAfter(sTodPst, day, month, year) shouldBe true + tTodEst.isBefore(sTodPst, day, month, year) shouldBe true + sTodPst.sameTimeAs(sTodPst2, day, month, year) shouldBe true + } +} diff --git a/jvm/src/test/scala/xyz/driver/core/database/DatabaseTest.scala b/jvm/src/test/scala/xyz/driver/core/database/DatabaseTest.scala new file mode 100644 index 0000000..8d2a4ac --- /dev/null +++ b/jvm/src/test/scala/xyz/driver/core/database/DatabaseTest.scala @@ -0,0 +1,42 @@ +package xyz.driver.core.database + +import org.scalatest.{FlatSpec, Matchers} +import org.scalatest.prop.Checkers +import xyz.driver.core.rest.errors.DatabaseException + +class DatabaseTest extends FlatSpec with Matchers with Checkers { + import xyz.driver.core.generators._ + "Date SQL converter" should "correctly convert back and forth to SQL dates" in { + for (date <- 1 to 100 map (_ => nextDate())) { + sqlDateToDate(dateToSqlDate(date)) should be(date) + } + } + + "Converter helper methods" should "work correctly" in { + object TestConverter extends Converters + + val validLength = nextInt(10) + val valid = nextToken(validLength) + val validOp = Some(valid) + val invalid = nextToken(validLength + nextInt(10, 1)) + val invalidOp = Some(invalid) + def mapper(s: String): Option[String] = if (s.length == validLength) Some(s) else None + + TestConverter.fromStringOrThrow(valid, mapper, valid) should be(valid) + + TestConverter.expectValid(mapper, valid) should be(valid) + + TestConverter.expectExistsAndValid(mapper, validOp) should be(valid) + + TestConverter.expectValidOrEmpty(mapper, validOp) should be(Some(valid)) + TestConverter.expectValidOrEmpty(mapper, None) should be(None) + + an[DatabaseException] should be thrownBy TestConverter.fromStringOrThrow(invalid, mapper, invalid) + + an[DatabaseException] should be thrownBy TestConverter.expectValid(mapper, invalid) + + an[DatabaseException] should be thrownBy TestConverter.expectExistsAndValid(mapper, invalidOp) + + an[DatabaseException] should be thrownBy TestConverter.expectValidOrEmpty(mapper, invalidOp) + } +} diff --git a/jvm/src/test/scala/xyz/driver/core/rest/DriverAppTest.scala b/jvm/src/test/scala/xyz/driver/core/rest/DriverAppTest.scala new file mode 100644 index 0000000..eda6a8c --- /dev/null +++ b/jvm/src/test/scala/xyz/driver/core/rest/DriverAppTest.scala @@ -0,0 +1,84 @@ +package xyz.driver.core.rest + +import akka.http.scaladsl.model.{HttpMethod, StatusCodes} +import akka.http.scaladsl.model.headers._ +import akka.http.scaladsl.server.{Directives, Route} +import akka.http.scaladsl.testkit.ScalatestRouteTest +import com.typesafe.config.ConfigFactory +import org.scalatest.{AsyncFlatSpec, Matchers} +import xyz.driver.core.app.{DriverApp, SimpleModule} + +class DriverAppTest extends AsyncFlatSpec with ScalatestRouteTest with Matchers with Directives { + val config = ConfigFactory.parseString(""" + |application { + | cors { + | allowedMethods: ["GET", "PUT", "POST", "PATCH", "DELETE", "OPTIONS"] + | allowedOrigins: [{scheme: https, hostSuffix: example.com}] + | } + |} + """.stripMargin).withFallback(ConfigFactory.load) + + val allowedOrigins = Set(HttpOrigin("https", Host("example.com"))) + val allowedMethods: collection.immutable.Seq[HttpMethod] = { + import akka.http.scaladsl.model.HttpMethods._ + collection.immutable.Seq(GET, PUT, POST, PATCH, DELETE, OPTIONS) + } + + import scala.reflect.runtime.universe.typeOf + class TestApp(testRoute: Route) + extends DriverApp( + appName = "test-app", + version = "0.0.1", + gitHash = "deadb33f", + modules = Seq(new SimpleModule("test-module", theRoute = testRoute, routeType = typeOf[DriverApp])), + config = config, + log = xyz.driver.core.logging.NoLogger + ) + + it should "respond with the correct CORS headers for the swagger OPTIONS route" in { + val route = new TestApp(get(complete(StatusCodes.OK))) + Options(s"/api-docs/swagger.json") ~> route.appRoute ~> check { + status shouldBe StatusCodes.OK + headers should contain(`Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigins.toSeq: _*))) + header[`Access-Control-Allow-Methods`].get.methods should contain theSameElementsAs allowedMethods + } + } + + it should "respond with the correct CORS headers for the test route" in { + val route = new TestApp(get(complete(StatusCodes.OK))) + Get(s"/api/v1/test") ~> route.appRoute ~> check { + status shouldBe StatusCodes.OK + headers should contain(`Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigins.toSeq: _*))) + header[`Access-Control-Allow-Methods`].get.methods should contain theSameElementsAs allowedMethods + } + } + + it should "respond with the correct CORS headers for a concatenated route" in { + val route = new TestApp(get(complete(StatusCodes.OK)) ~ post(complete(StatusCodes.OK))) + Post(s"/api/v1/test") ~> route.appRoute ~> check { + status shouldBe StatusCodes.OK + headers should contain(`Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigins.toSeq: _*))) + header[`Access-Control-Allow-Methods`].get.methods should contain theSameElementsAs allowedMethods + } + } + + it should "allow subdomains of allowed origin suffixes" in { + val route = new TestApp(get(complete(StatusCodes.OK))) + Get(s"/api/v1/test") + .withHeaders(Origin(HttpOrigin("https", Host("foo.example.com")))) ~> route.appRoute ~> check { + status shouldBe StatusCodes.OK + headers should contain(`Access-Control-Allow-Origin`(HttpOrigin("https", Host("foo.example.com")))) + header[`Access-Control-Allow-Methods`].get.methods should contain theSameElementsAs allowedMethods + } + } + + it should "respond with default domains for invalid origins" in { + val route = new TestApp(get(complete(StatusCodes.OK))) + Get(s"/api/v1/test") + .withHeaders(Origin(HttpOrigin("https", Host("invalid.foo.bar.com")))) ~> route.appRoute ~> check { + status shouldBe StatusCodes.OK + headers should contain(`Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigins.toSeq: _*))) + header[`Access-Control-Allow-Methods`].get.methods should contain theSameElementsAs allowedMethods + } + } +} diff --git a/jvm/src/test/scala/xyz/driver/core/rest/DriverRouteTest.scala b/jvm/src/test/scala/xyz/driver/core/rest/DriverRouteTest.scala new file mode 100644 index 0000000..d32fefd --- /dev/null +++ b/jvm/src/test/scala/xyz/driver/core/rest/DriverRouteTest.scala @@ -0,0 +1,123 @@ +package xyz.driver.core.rest + +import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport +import akka.http.scaladsl.model.StatusCodes +import akka.http.scaladsl.model.headers.Connection +import akka.http.scaladsl.server.Directives.{complete => akkaComplete} +import akka.http.scaladsl.server.{Directives, Rejection, RejectionHandler, Route} +import akka.http.scaladsl.testkit.ScalatestRouteTest +import com.typesafe.scalalogging.Logger +import org.scalatest.{AsyncFlatSpec, Matchers} +import xyz.driver.core.logging.NoLogger +import xyz.driver.core.json.serviceExceptionFormat +import xyz.driver.core.FutureExtensions +import xyz.driver.core.rest.errors._ + +import scala.collection.immutable +import scala.concurrent.Future + +class DriverRouteTest + extends AsyncFlatSpec with ScalatestRouteTest with SprayJsonSupport with Matchers with Directives { + class TestRoute(override val route: Route) extends DriverRoute { + override def log: Logger = NoLogger + } + + "DriverRoute" should "respond with 200 OK for a basic route" in { + val route = new TestRoute(akkaComplete(StatusCodes.OK)) + + Get("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check { + handled shouldBe true + status shouldBe StatusCodes.OK + } + } + + it should "respond with a 401 for an InvalidInputException" in { + val route = new TestRoute(akkaComplete(Future.failed[String](InvalidInputException()))) + + Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check { + handled shouldBe true + status shouldBe StatusCodes.BadRequest + responseAs[ServiceException] shouldBe InvalidInputException() + } + } + + it should "respond with a 403 for InvalidActionException" in { + val route = new TestRoute(akkaComplete(Future.failed[String](InvalidActionException()))) + + Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check { + handled shouldBe true + status shouldBe StatusCodes.Forbidden + responseAs[ServiceException] shouldBe InvalidActionException() + } + } + + it should "respond with a 404 for ResourceNotFoundException" in { + val route = new TestRoute(akkaComplete(Future.failed[String](ResourceNotFoundException()))) + + Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check { + handled shouldBe true + status shouldBe StatusCodes.NotFound + responseAs[ServiceException] shouldBe ResourceNotFoundException() + } + } + + it should "respond with a 500 for ExternalServiceException" in { + val error = ExternalServiceException("GET /api/v1/users/", "Permission denied", None) + val route = new TestRoute(akkaComplete(Future.failed[String](error))) + + Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check { + handled shouldBe true + status shouldBe StatusCodes.InternalServerError + responseAs[ServiceException] shouldBe error + } + } + + it should "allow pass-through of external service exceptions" in { + val innerError = InvalidInputException() + val error = ExternalServiceException("GET /api/v1/users/", "Permission denied", Some(innerError)) + val future = Future.failed[String](error) + val route = new TestRoute(akkaComplete(future.passThroughExternalServiceException)) + + Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check { + handled shouldBe true + status shouldBe StatusCodes.BadRequest + responseAs[ServiceException] shouldBe innerError + } + } + + it should "respond with a 503 for ExternalServiceTimeoutException" in { + val error = ExternalServiceTimeoutException("GET /api/v1/users/") + val route = new TestRoute(akkaComplete(Future.failed[String](error))) + + Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check { + handled shouldBe true + status shouldBe StatusCodes.GatewayTimeout + responseAs[ServiceException] shouldBe error + } + } + + it should "respond with a 500 for DatabaseException" in { + val route = new TestRoute(akkaComplete(Future.failed[String](DatabaseException()))) + + Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check { + handled shouldBe true + status shouldBe StatusCodes.InternalServerError + responseAs[ServiceException] shouldBe DatabaseException() + } + } + + it should "add a `Connection: close` header to avoid clashing with envoy's timeouts" in { + val rejectionHandler = RejectionHandler.newBuilder().handleNotFound(complete(StatusCodes.NotFound)).result() + val route = new TestRoute(handleRejections(rejectionHandler)((get & path("foo"))(complete("OK")))) + + Get("/foo") ~> route.routeWithDefaults ~> check { + status shouldBe StatusCodes.OK + headers should contain(Connection("close")) + } + + Get("/bar") ~> route.routeWithDefaults ~> check { + status shouldBe StatusCodes.NotFound + headers should contain(Connection("close")) + } + } +} diff --git a/jvm/src/test/scala/xyz/driver/core/rest/PatchDirectivesTest.scala b/jvm/src/test/scala/xyz/driver/core/rest/PatchDirectivesTest.scala new file mode 100644 index 0000000..987717d --- /dev/null +++ b/jvm/src/test/scala/xyz/driver/core/rest/PatchDirectivesTest.scala @@ -0,0 +1,101 @@ +package xyz.driver.core.rest + +import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport +import akka.http.scaladsl.model._ +import akka.http.scaladsl.model.headers.`Content-Type` +import akka.http.scaladsl.server.{Directives, Route} +import akka.http.scaladsl.testkit.ScalatestRouteTest +import org.scalatest.{FlatSpec, Matchers} +import spray.json._ +import xyz.driver.core.{Id, Name} +import xyz.driver.core.json._ + +import scala.concurrent.Future + +class PatchDirectivesTest + extends FlatSpec with Matchers with ScalatestRouteTest with SprayJsonSupport with DefaultJsonProtocol + with Directives with PatchDirectives { + case class Bar(name: Name[Bar], size: Int) + case class Foo(id: Id[Foo], name: Name[Foo], rank: Int, bar: Option[Bar]) + implicit val barFormat: RootJsonFormat[Bar] = jsonFormat2(Bar) + implicit val fooFormat: RootJsonFormat[Foo] = jsonFormat4(Foo) + + val testFoo: Foo = Foo(Id("1"), Name(s"Foo"), 1, Some(Bar(Name("Bar"), 10))) + + def route(retrieve: => Future[Option[Foo]]): Route = + Route.seal(path("api" / "v1" / "foos" / IdInPath[Foo]) { fooId => + entity(as[Patchable[Foo]]) { fooPatchable => + mergePatch(fooPatchable, retrieve) { updatedFoo => + complete(updatedFoo) + } + } + }) + + val MergePatchContentType = ContentType(`application/merge-patch+json`) + val ContentTypeHeader = `Content-Type`(MergePatchContentType) + def jsonEntity(json: String, contentType: ContentType.NonBinary = MergePatchContentType): RequestEntity = + HttpEntity(contentType, json) + + "PatchSupport" should "allow partial updates to an existing object" in { + val fooRetrieve = Future.successful(Some(testFoo)) + + Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""")) ~> route(fooRetrieve) ~> check { + handled shouldBe true + responseAs[Foo] shouldBe testFoo.copy(rank = 4) + } + } + + it should "merge deeply nested objects" in { + val fooRetrieve = Future.successful(Some(testFoo)) + + Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4, "bar": {"name": "My Bar"}}""")) ~> route(fooRetrieve) ~> check { + handled shouldBe true + responseAs[Foo] shouldBe testFoo.copy(rank = 4, bar = Some(Bar(Name("My Bar"), 10))) + } + } + + it should "return a 404 if the object is not found" in { + val fooRetrieve = Future.successful(None) + + Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""")) ~> route(fooRetrieve) ~> check { + handled shouldBe true + status shouldBe StatusCodes.NotFound + } + } + + it should "handle nulls on optional values correctly" in { + val fooRetrieve = Future.successful(Some(testFoo)) + + Patch("/api/v1/foos/1", jsonEntity("""{"bar": null}""")) ~> route(fooRetrieve) ~> check { + handled shouldBe true + responseAs[Foo] shouldBe testFoo.copy(bar = None) + } + } + + it should "handle optional values correctly when old value is null" in { + val fooRetrieve = Future.successful(Some(testFoo.copy(bar = None))) + + Patch("/api/v1/foos/1", jsonEntity("""{"bar": {"name": "My Bar","size":10}}""")) ~> route(fooRetrieve) ~> check { + handled shouldBe true + responseAs[Foo] shouldBe testFoo.copy(bar = Some(Bar(Name("My Bar"), 10))) + } + } + + it should "return a 400 for nulls on non-optional values" in { + val fooRetrieve = Future.successful(Some(testFoo)) + + Patch("/api/v1/foos/1", jsonEntity("""{"rank": null}""")) ~> route(fooRetrieve) ~> check { + handled shouldBe true + status shouldBe StatusCodes.BadRequest + } + } + + it should "return a 415 for incorrect Content-Type" in { + val fooRetrieve = Future.successful(Some(testFoo)) + + Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""", ContentTypes.`application/json`)) ~> route(fooRetrieve) ~> check { + status shouldBe StatusCodes.UnsupportedMediaType + responseAs[String] should include("application/merge-patch+json") + } + } +} diff --git a/jvm/src/test/scala/xyz/driver/core/rest/RestTest.scala b/jvm/src/test/scala/xyz/driver/core/rest/RestTest.scala new file mode 100644 index 0000000..68fe419 --- /dev/null +++ b/jvm/src/test/scala/xyz/driver/core/rest/RestTest.scala @@ -0,0 +1,73 @@ +package xyz.driver.core.rest + +import akka.http.scaladsl.model.StatusCodes +import akka.http.scaladsl.server.{Directives, Route, ValidationRejection} +import akka.http.scaladsl.testkit.ScalatestRouteTest +import akka.util.ByteString +import org.scalatest.{Matchers, WordSpec} +import xyz.driver.core.rest + +class RestTest extends WordSpec with Matchers with ScalatestRouteTest with Directives { + "`escapeScriptTags` function" should { + "escape script tags properly" in { + val dirtyString = "</sc----</sc----</sc" + val cleanString = "--------------------" + + (escapeScriptTags(ByteString(dirtyString)).utf8String) should be(dirtyString.replace("</sc", "< /sc")) + + (escapeScriptTags(ByteString(cleanString)).utf8String) should be(cleanString) + } + } + + "paginated directive" should { + val route: Route = rest.paginated { paginated => + complete(StatusCodes.OK -> s"${paginated.pageNumber},${paginated.pageSize}") + } + "accept a pagination" in { + Get("/?pageNumber=2&pageSize=42") ~> route ~> check { + assert(status == StatusCodes.OK) + assert(entityAs[String] == "2,42") + } + } + "provide a default pagination" in { + Get("/") ~> route ~> check { + assert(status == StatusCodes.OK) + assert(entityAs[String] == "1,100") + } + } + "provide default values for a partial pagination" in { + Get("/?pageSize=2") ~> route ~> check { + assert(status == StatusCodes.OK) + assert(entityAs[String] == "1,2") + } + } + "reject an invalid pagination" in { + Get("/?pageNumber=-1") ~> route ~> check { + assert(rejection.isInstanceOf[ValidationRejection]) + } + } + } + + "optional paginated directive" should { + val route: Route = rest.optionalPagination { paginated => + complete(StatusCodes.OK -> paginated.map(p => s"${p.pageNumber},${p.pageSize}").getOrElse("no pagination")) + } + "accept a pagination" in { + Get("/?pageNumber=2&pageSize=42") ~> route ~> check { + assert(status == StatusCodes.OK) + assert(entityAs[String] == "2,42") + } + } + "without pagination" in { + Get("/") ~> route ~> check { + assert(status == StatusCodes.OK) + assert(entityAs[String] == "no pagination") + } + } + "reject an invalid pagination" in { + Get("/?pageNumber=1") ~> route ~> check { + assert(rejection.isInstanceOf[ValidationRejection]) + } + } + } +} |