From ce099ac3ba85f71adeba0bb8398b69dc7cd2e8d1 Mon Sep 17 00:00:00 2001 From: Zach Smith Date: Fri, 16 Mar 2018 18:47:30 +0100 Subject: Add PatchSupport trait and tests --- .../xyz/driver/core/rest/PatchSupportTest.scala | 77 ++++++++++++++++++++++ 1 file changed, 77 insertions(+) create mode 100644 src/test/scala/xyz/driver/core/rest/PatchSupportTest.scala (limited to 'src/test/scala/xyz/driver') diff --git a/src/test/scala/xyz/driver/core/rest/PatchSupportTest.scala b/src/test/scala/xyz/driver/core/rest/PatchSupportTest.scala new file mode 100644 index 0000000..dcb3a93 --- /dev/null +++ b/src/test/scala/xyz/driver/core/rest/PatchSupportTest.scala @@ -0,0 +1,77 @@ +package xyz.driver.core.rest + +import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport +import akka.http.scaladsl.model.{ContentTypes, HttpEntity, RequestEntity, StatusCodes} +import akka.http.scaladsl.server.{Directives, Route} +import akka.http.scaladsl.testkit.ScalatestRouteTest +import org.scalatest.{FlatSpec, Matchers} +import spray.json._ +import xyz.driver.core.{Id, Name} +import xyz.driver.core.json._ + +import scala.concurrent.Future + +class PatchSupportTest + extends FlatSpec with Matchers with ScalatestRouteTest with SprayJsonSupport with DefaultJsonProtocol + with Directives with PatchSupport { + case class Bar(name: Name[Bar], size: Int) + case class Foo(id: Id[Foo], name: Name[Foo], rank: Int, bar: Option[Bar]) + implicit val barFormat: RootJsonFormat[Bar] = jsonFormat2(Bar) + implicit val fooFormat: RootJsonFormat[Foo] = jsonFormat4(Foo) + + val testFoo: Foo = Foo(Id("1"), Name(s"Foo"), 1, Some(Bar(Name("Bar"), 10))) + + def route(implicit patchRetrievable: PatchRetrievable[Foo]): Route = + Route.seal(path("api" / "v1" / "foos" / IdInPath[Foo]) { fooId => + patch(as[Foo], fooId) { patchedFoo => + complete(patchedFoo) + } + }) + + def jsonEntity(json: String): RequestEntity = HttpEntity(ContentTypes.`application/json`, json) + + "PatchSupport" should "allow partial updates to an existing object" in { + implicit val fooPatchable: PatchRetrievable[Foo] = id => Future.successful(Some(testFoo.copy(id = id))) + + Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""")) ~> route ~> check { + handled shouldBe true + responseAs[Foo] shouldBe testFoo.copy(rank = 4) + } + } + + it should "merge deeply nested objects" in { + implicit val fooPatchable: PatchRetrievable[Foo] = id => Future.successful(Some(testFoo.copy(id = id))) + + Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4, "bar": {"name": "My Bar"}}""")) ~> route ~> check { + handled shouldBe true + responseAs[Foo] shouldBe testFoo.copy(rank = 4, bar = Some(Bar(Name("My Bar"), 10))) + } + } + + it should "return a 404 if the object is not found" in { + implicit val fooPatchable: PatchRetrievable[Foo] = _ => Future.successful(Option.empty[Foo]) + + Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""")) ~> route ~> check { + handled shouldBe true + status shouldBe StatusCodes.NotFound + } + } + + it should "handle nulls on optional values correctly" in { + implicit val fooPatchable: PatchRetrievable[Foo] = id => Future.successful(Some(testFoo.copy(id = id))) + + Patch("/api/v1/foos/1", jsonEntity("""{"bar": null}""")) ~> route ~> check { + handled shouldBe true + responseAs[Foo] shouldBe testFoo.copy(bar = None) + } + } + + it should "return a 400 for nulls on non-optional values" in { + implicit val fooPatchable: PatchRetrievable[Foo] = id => Future.successful(Some(testFoo.copy(id = id))) + + Patch("/api/v1/foos/1", jsonEntity("""{"rank": null}""")) ~> route ~> check { + handled shouldBe true + status shouldBe StatusCodes.BadRequest + } + } +} -- cgit v1.2.3 From cdcf028d96f5dea894ea31e1ab8cf0b6575bc11c Mon Sep 17 00:00:00 2001 From: Zach Smith Date: Fri, 16 Mar 2018 20:26:28 +0100 Subject: Add implicit ServiceRequestContext to PatchRetrievable --- .../scala/xyz/driver/core/rest/PatchSupport.scala | 42 +++++++++++++--------- .../xyz/driver/core/rest/PatchSupportTest.scala | 10 +++--- 2 files changed, 31 insertions(+), 21 deletions(-) (limited to 'src/test/scala/xyz/driver') diff --git a/src/main/scala/xyz/driver/core/rest/PatchSupport.scala b/src/main/scala/xyz/driver/core/rest/PatchSupport.scala index 5fd9149..9a28de1 100644 --- a/src/main/scala/xyz/driver/core/rest/PatchSupport.scala +++ b/src/main/scala/xyz/driver/core/rest/PatchSupport.scala @@ -11,7 +11,14 @@ import scala.concurrent.Future trait PatchSupport extends Directives with SprayJsonSupport { trait PatchRetrievable[T] { - def apply(id: Id[T]): Future[Option[T]] + def apply(id: Id[T])(implicit ctx: ServiceRequestContext): Future[Option[T]] + } + + object PatchRetrievable { + def apply[T](retriever: ((Id[T], ServiceRequestContext) => Future[Option[T]])): PatchRetrievable[T] = + new PatchRetrievable[T] { + override def apply(id: Id[T])(implicit ctx: ServiceRequestContext): Future[Option[T]] = retriever(id, ctx) + } } protected def mergeObjects(oldObj: JsObject, newObj: JsObject, maxLevels: Option[Int] = None): JsObject = { @@ -57,21 +64,24 @@ trait PatchSupport extends Directives with SprayJsonSupport { implicit val jsonFormat: RootJsonFormat[T] = patchable._2 Directives.patch { entity(as[JsValue]) { newValue => - onSuccess(retriever(id).map(_.map(_.toJson))) { - case Some(oldValue) => - val mergedObj = mergeJsValues(oldValue, newValue) - util - .Try(mergedObj.convertTo[T]) - .transform[Route]( - mergedT => util.Success(inner(Tuple1(mergedT))), { - case jsonException: DeserializationException => - util.Success(reject(Rejections.malformedRequestContent(jsonException.getMessage, jsonException))) - case t => util.Failure(t) - } - ) - .get // intentionally re-throw all other errors - case None => - reject() + serviceContext { implicit ctx => + onSuccess(retriever(id).map(_.map(_.toJson))) { + case Some(oldValue) => + val mergedObj = mergeJsValues(oldValue, newValue) + util + .Try(mergedObj.convertTo[T]) + .transform[Route]( + mergedT => util.Success(inner(Tuple1(mergedT))), { + case jsonException: DeserializationException => + util.Success( + reject(Rejections.malformedRequestContent(jsonException.getMessage, jsonException))) + case t => util.Failure(t) + } + ) + .get // intentionally re-throw all other errors + case None => + reject() + } } } }(requestCtx) diff --git a/src/test/scala/xyz/driver/core/rest/PatchSupportTest.scala b/src/test/scala/xyz/driver/core/rest/PatchSupportTest.scala index dcb3a93..469ac84 100644 --- a/src/test/scala/xyz/driver/core/rest/PatchSupportTest.scala +++ b/src/test/scala/xyz/driver/core/rest/PatchSupportTest.scala @@ -31,7 +31,7 @@ class PatchSupportTest def jsonEntity(json: String): RequestEntity = HttpEntity(ContentTypes.`application/json`, json) "PatchSupport" should "allow partial updates to an existing object" in { - implicit val fooPatchable: PatchRetrievable[Foo] = id => Future.successful(Some(testFoo.copy(id = id))) + implicit val fooPatchable = PatchRetrievable[Foo]((id, _) => Future.successful(Some(testFoo.copy(id = id)))) Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""")) ~> route ~> check { handled shouldBe true @@ -40,7 +40,7 @@ class PatchSupportTest } it should "merge deeply nested objects" in { - implicit val fooPatchable: PatchRetrievable[Foo] = id => Future.successful(Some(testFoo.copy(id = id))) + implicit val fooPatchable = PatchRetrievable[Foo]((id, _) => Future.successful(Some(testFoo.copy(id = id)))) Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4, "bar": {"name": "My Bar"}}""")) ~> route ~> check { handled shouldBe true @@ -49,7 +49,7 @@ class PatchSupportTest } it should "return a 404 if the object is not found" in { - implicit val fooPatchable: PatchRetrievable[Foo] = _ => Future.successful(Option.empty[Foo]) + implicit val fooPatchable = PatchRetrievable[Foo]((id, _) => Future.successful(None)) Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""")) ~> route ~> check { handled shouldBe true @@ -58,7 +58,7 @@ class PatchSupportTest } it should "handle nulls on optional values correctly" in { - implicit val fooPatchable: PatchRetrievable[Foo] = id => Future.successful(Some(testFoo.copy(id = id))) + implicit val fooPatchable = PatchRetrievable[Foo]((id, _) => Future.successful(Some(testFoo.copy(id = id)))) Patch("/api/v1/foos/1", jsonEntity("""{"bar": null}""")) ~> route ~> check { handled shouldBe true @@ -67,7 +67,7 @@ class PatchSupportTest } it should "return a 400 for nulls on non-optional values" in { - implicit val fooPatchable: PatchRetrievable[Foo] = id => Future.successful(Some(testFoo.copy(id = id))) + implicit val fooPatchable = PatchRetrievable[Foo]((id, _) => Future.successful(Some(testFoo.copy(id = id)))) Patch("/api/v1/foos/1", jsonEntity("""{"rank": null}""")) ~> route ~> check { handled shouldBe true -- cgit v1.2.3 From dd25ecb1e5cc93bd2da94f4e4bfddc9d3a5ebb5e Mon Sep 17 00:00:00 2001 From: Zach Smith Date: Tue, 20 Mar 2018 14:56:02 -0700 Subject: Enforce application/merge-patch+json Content-Type --- .../scala/xyz/driver/core/rest/PatchSupport.scala | 54 ++++++++++++++-------- .../xyz/driver/core/rest/PatchSupportTest.scala | 25 +++++++--- 2 files changed, 54 insertions(+), 25 deletions(-) (limited to 'src/test/scala/xyz/driver') diff --git a/src/main/scala/xyz/driver/core/rest/PatchSupport.scala b/src/main/scala/xyz/driver/core/rest/PatchSupport.scala index 9a28de1..5ded61f 100644 --- a/src/main/scala/xyz/driver/core/rest/PatchSupport.scala +++ b/src/main/scala/xyz/driver/core/rest/PatchSupport.scala @@ -2,13 +2,17 @@ package xyz.driver.core.rest import akka.http.javadsl.server.Rejections import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport -import akka.http.scaladsl.server.{Directive, Directive1, Directives, Route} +import akka.http.scaladsl.server._ import spray.json._ import xyz.driver.core.Id import scala.concurrent.Future +import scala.util.{Failure, Success, Try} +import scalaz.syntax.equal._ +import scalaz.Scalaz.stringInstance trait PatchSupport extends Directives with SprayJsonSupport { + protected val MergePatchPlusJson: String = "merge-patch+json" trait PatchRetrievable[T] { def apply(id: Id[T])(implicit ctx: ServiceRequestContext): Future[Option[T]] @@ -52,6 +56,18 @@ trait PatchSupport extends Directives with SprayJsonSupport { } } + def rejectNonMergePatchContentType: Directive0 = Directive { inner => requestCtx => + val contentType = requestCtx.request.header[akka.http.scaladsl.model.headers.`Content-Type`] + val isCorrectContentType = + contentType.map(_.contentType.mediaType).exists(mt => mt.isApplication && mt.subType === MergePatchPlusJson) + if (!isCorrectContentType) { + reject( + Rejections.malformedRequestContent( + s"Request Content-Type must be application/$MergePatchPlusJson for PATCH requests", + new RuntimeException))(requestCtx) + } else inner(())(requestCtx) + } + def as[T]( implicit patchable: PatchRetrievable[T], jsonFormat: RootJsonFormat[T]): (PatchRetrievable[T], RootJsonFormat[T]) = @@ -63,24 +79,24 @@ trait PatchSupport extends Directives with SprayJsonSupport { val retriever = patchable._1 implicit val jsonFormat: RootJsonFormat[T] = patchable._2 Directives.patch { - entity(as[JsValue]) { newValue => - serviceContext { implicit ctx => - onSuccess(retriever(id).map(_.map(_.toJson))) { - case Some(oldValue) => - val mergedObj = mergeJsValues(oldValue, newValue) - util - .Try(mergedObj.convertTo[T]) - .transform[Route]( - mergedT => util.Success(inner(Tuple1(mergedT))), { - case jsonException: DeserializationException => - util.Success( - reject(Rejections.malformedRequestContent(jsonException.getMessage, jsonException))) - case t => util.Failure(t) - } - ) - .get // intentionally re-throw all other errors - case None => - reject() + rejectNonMergePatchContentType { + entity(as[JsValue]) { newValue => + serviceContext { implicit ctx => + onSuccess(retriever(id).map(_.map(_.toJson))) { + case Some(oldValue) => + val mergedObj = mergeJsValues(oldValue, newValue) + Try(mergedObj.convertTo[T]) + .transform[Route]( + mergedT => scala.util.Success(inner(Tuple1(mergedT))), { + case jsonException: DeserializationException => + Success(reject(Rejections.malformedRequestContent(jsonException.getMessage, jsonException))) + case t => Failure(t) + } + ) + .get // intentionally re-throw all other errors + case None => + reject() + } } } } diff --git a/src/test/scala/xyz/driver/core/rest/PatchSupportTest.scala b/src/test/scala/xyz/driver/core/rest/PatchSupportTest.scala index 469ac84..20d667d 100644 --- a/src/test/scala/xyz/driver/core/rest/PatchSupportTest.scala +++ b/src/test/scala/xyz/driver/core/rest/PatchSupportTest.scala @@ -1,7 +1,8 @@ package xyz.driver.core.rest import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport -import akka.http.scaladsl.model.{ContentTypes, HttpEntity, RequestEntity, StatusCodes} +import akka.http.scaladsl.model._ +import akka.http.scaladsl.model.headers.`Content-Type` import akka.http.scaladsl.server.{Directives, Route} import akka.http.scaladsl.testkit.ScalatestRouteTest import org.scalatest.{FlatSpec, Matchers} @@ -30,10 +31,12 @@ class PatchSupportTest def jsonEntity(json: String): RequestEntity = HttpEntity(ContentTypes.`application/json`, json) + val ContentTypeHeader = `Content-Type`(ContentType.parse("application/merge-patch+json").right.get) + "PatchSupport" should "allow partial updates to an existing object" in { implicit val fooPatchable = PatchRetrievable[Foo]((id, _) => Future.successful(Some(testFoo.copy(id = id)))) - Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""")) ~> route ~> check { + Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""")).withHeaders(ContentTypeHeader) ~> route ~> check { handled shouldBe true responseAs[Foo] shouldBe testFoo.copy(rank = 4) } @@ -42,7 +45,8 @@ class PatchSupportTest it should "merge deeply nested objects" in { implicit val fooPatchable = PatchRetrievable[Foo]((id, _) => Future.successful(Some(testFoo.copy(id = id)))) - Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4, "bar": {"name": "My Bar"}}""")) ~> route ~> check { + Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4, "bar": {"name": "My Bar"}}""")) + .withHeaders(ContentTypeHeader) ~> route ~> check { handled shouldBe true responseAs[Foo] shouldBe testFoo.copy(rank = 4, bar = Some(Bar(Name("My Bar"), 10))) } @@ -51,7 +55,7 @@ class PatchSupportTest it should "return a 404 if the object is not found" in { implicit val fooPatchable = PatchRetrievable[Foo]((id, _) => Future.successful(None)) - Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""")) ~> route ~> check { + Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""")).withHeaders(ContentTypeHeader) ~> route ~> check { handled shouldBe true status shouldBe StatusCodes.NotFound } @@ -60,7 +64,7 @@ class PatchSupportTest it should "handle nulls on optional values correctly" in { implicit val fooPatchable = PatchRetrievable[Foo]((id, _) => Future.successful(Some(testFoo.copy(id = id)))) - Patch("/api/v1/foos/1", jsonEntity("""{"bar": null}""")) ~> route ~> check { + Patch("/api/v1/foos/1", jsonEntity("""{"bar": null}""")).withHeaders(ContentTypeHeader) ~> route ~> check { handled shouldBe true responseAs[Foo] shouldBe testFoo.copy(bar = None) } @@ -69,9 +73,18 @@ class PatchSupportTest it should "return a 400 for nulls on non-optional values" in { implicit val fooPatchable = PatchRetrievable[Foo]((id, _) => Future.successful(Some(testFoo.copy(id = id)))) - Patch("/api/v1/foos/1", jsonEntity("""{"rank": null}""")) ~> route ~> check { + Patch("/api/v1/foos/1", jsonEntity("""{"rank": null}""")).withHeaders(ContentTypeHeader) ~> route ~> check { handled shouldBe true status shouldBe StatusCodes.BadRequest } } + + it should "return a 400 for incorrect Content-Type" in { + implicit val fooPatchable = PatchRetrievable[Foo]((id, _) => Future.successful(Some(testFoo.copy(id = id)))) + + Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""")) ~> route ~> check { + status shouldBe StatusCodes.BadRequest + responseAs[String] should include("application/merge-patch+json") + } + } } -- cgit v1.2.3 From 424e025f1c719006fe6a6669e43667e1d39ff076 Mon Sep 17 00:00:00 2001 From: Zach Smith Date: Tue, 20 Mar 2018 14:59:01 -0700 Subject: Curry the PatchRetrievable apply method --- src/main/scala/xyz/driver/core/rest/PatchSupport.scala | 4 ++-- src/test/scala/xyz/driver/core/rest/PatchSupportTest.scala | 12 ++++++------ 2 files changed, 8 insertions(+), 8 deletions(-) (limited to 'src/test/scala/xyz/driver') diff --git a/src/main/scala/xyz/driver/core/rest/PatchSupport.scala b/src/main/scala/xyz/driver/core/rest/PatchSupport.scala index 5ded61f..d7c77c8 100644 --- a/src/main/scala/xyz/driver/core/rest/PatchSupport.scala +++ b/src/main/scala/xyz/driver/core/rest/PatchSupport.scala @@ -19,9 +19,9 @@ trait PatchSupport extends Directives with SprayJsonSupport { } object PatchRetrievable { - def apply[T](retriever: ((Id[T], ServiceRequestContext) => Future[Option[T]])): PatchRetrievable[T] = + def apply[T](retriever: (Id[T] => (ServiceRequestContext => Future[Option[T]]))): PatchRetrievable[T] = new PatchRetrievable[T] { - override def apply(id: Id[T])(implicit ctx: ServiceRequestContext): Future[Option[T]] = retriever(id, ctx) + override def apply(id: Id[T])(implicit ctx: ServiceRequestContext): Future[Option[T]] = retriever(id)(ctx) } } diff --git a/src/test/scala/xyz/driver/core/rest/PatchSupportTest.scala b/src/test/scala/xyz/driver/core/rest/PatchSupportTest.scala index 20d667d..5c7faf8 100644 --- a/src/test/scala/xyz/driver/core/rest/PatchSupportTest.scala +++ b/src/test/scala/xyz/driver/core/rest/PatchSupportTest.scala @@ -34,7 +34,7 @@ class PatchSupportTest val ContentTypeHeader = `Content-Type`(ContentType.parse("application/merge-patch+json").right.get) "PatchSupport" should "allow partial updates to an existing object" in { - implicit val fooPatchable = PatchRetrievable[Foo]((id, _) => Future.successful(Some(testFoo.copy(id = id)))) + implicit val fooPatchable = PatchRetrievable[Foo](id => _ => Future.successful(Some(testFoo.copy(id = id)))) Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""")).withHeaders(ContentTypeHeader) ~> route ~> check { handled shouldBe true @@ -43,7 +43,7 @@ class PatchSupportTest } it should "merge deeply nested objects" in { - implicit val fooPatchable = PatchRetrievable[Foo]((id, _) => Future.successful(Some(testFoo.copy(id = id)))) + implicit val fooPatchable = PatchRetrievable[Foo](id => _ => Future.successful(Some(testFoo.copy(id = id)))) Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4, "bar": {"name": "My Bar"}}""")) .withHeaders(ContentTypeHeader) ~> route ~> check { @@ -53,7 +53,7 @@ class PatchSupportTest } it should "return a 404 if the object is not found" in { - implicit val fooPatchable = PatchRetrievable[Foo]((id, _) => Future.successful(None)) + implicit val fooPatchable = PatchRetrievable[Foo](_ => _ => Future.successful(None)) Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""")).withHeaders(ContentTypeHeader) ~> route ~> check { handled shouldBe true @@ -62,7 +62,7 @@ class PatchSupportTest } it should "handle nulls on optional values correctly" in { - implicit val fooPatchable = PatchRetrievable[Foo]((id, _) => Future.successful(Some(testFoo.copy(id = id)))) + implicit val fooPatchable = PatchRetrievable[Foo](id => _ => Future.successful(Some(testFoo.copy(id = id)))) Patch("/api/v1/foos/1", jsonEntity("""{"bar": null}""")).withHeaders(ContentTypeHeader) ~> route ~> check { handled shouldBe true @@ -71,7 +71,7 @@ class PatchSupportTest } it should "return a 400 for nulls on non-optional values" in { - implicit val fooPatchable = PatchRetrievable[Foo]((id, _) => Future.successful(Some(testFoo.copy(id = id)))) + implicit val fooPatchable = PatchRetrievable[Foo](id => _ => Future.successful(Some(testFoo.copy(id = id)))) Patch("/api/v1/foos/1", jsonEntity("""{"rank": null}""")).withHeaders(ContentTypeHeader) ~> route ~> check { handled shouldBe true @@ -80,7 +80,7 @@ class PatchSupportTest } it should "return a 400 for incorrect Content-Type" in { - implicit val fooPatchable = PatchRetrievable[Foo]((id, _) => Future.successful(Some(testFoo.copy(id = id)))) + implicit val fooPatchable = PatchRetrievable[Foo](id => _ => Future.successful(Some(testFoo.copy(id = id)))) Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""")) ~> route ~> check { status shouldBe StatusCodes.BadRequest -- cgit v1.2.3 From 33e491adc58b3ee3a37194339f09afa70d42a0e2 Mon Sep 17 00:00:00 2001 From: Zach Smith Date: Sun, 25 Mar 2018 15:29:38 -0700 Subject: Use patch unmarshaller --- .../xyz/driver/core/rest/PatchDirectives.scala | 104 ++++++++++++++++++++ .../scala/xyz/driver/core/rest/PatchSupport.scala | 107 --------------------- .../xyz/driver/core/rest/PatchDirectivesTest.scala | 92 ++++++++++++++++++ .../xyz/driver/core/rest/PatchSupportTest.scala | 90 ----------------- 4 files changed, 196 insertions(+), 197 deletions(-) create mode 100644 src/main/scala/xyz/driver/core/rest/PatchDirectives.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/PatchSupport.scala create mode 100644 src/test/scala/xyz/driver/core/rest/PatchDirectivesTest.scala delete mode 100644 src/test/scala/xyz/driver/core/rest/PatchSupportTest.scala (limited to 'src/test/scala/xyz/driver') diff --git a/src/main/scala/xyz/driver/core/rest/PatchDirectives.scala b/src/main/scala/xyz/driver/core/rest/PatchDirectives.scala new file mode 100644 index 0000000..256358c --- /dev/null +++ b/src/main/scala/xyz/driver/core/rest/PatchDirectives.scala @@ -0,0 +1,104 @@ +package xyz.driver.core.rest + +import akka.http.javadsl.server.Rejections +import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport +import akka.http.scaladsl.model.{ContentTypeRange, HttpCharsets, MediaType} +import akka.http.scaladsl.server._ +import akka.http.scaladsl.unmarshalling.{FromEntityUnmarshaller, Unmarshaller} +import spray.json._ + +import scala.concurrent.Future +import scala.util.{Failure, Success, Try} + +trait PatchDirectives extends Directives with SprayJsonSupport { + + /** Media type for patches to JSON values, as specified in [[https://tools.ietf.org/html/rfc7396 RFC 7396]]. */ + val `application/merge-patch+json`: MediaType.WithFixedCharset = + MediaType.applicationWithFixedCharset("merge-patch+json", HttpCharsets.`UTF-8`) + + /** Wraps a JSON value that represents a patch. + * The patch must given in the format specified in [[https://tools.ietf.org/html/rfc7396 RFC 7396]]. */ + case class PatchValue(value: JsValue) { + + /** Applies this patch to a given original JSON value. In other words, merges the original with this "diff". */ + def applyTo(original: JsValue): JsValue = mergeJsValues(original, value) + } + + /** Witness that the given patch may be applied to an original domain value. + * @tparam A type of the domain value + * @param patch the patch that may be applied to a domain value + * @param format a JSON format that enables serialization and deserialization of a domain value */ + case class Patchable[A](patch: PatchValue, format: RootJsonFormat[A]) { + + /** Applies the patch to a given domain object. The result will be a combination + * of the original value, updates with the fields specified in this witness' patch. */ + def applyTo(original: A): A = { + val serialized = format.write(original) + val merged = patch.applyTo(serialized) + val deserialized = format.read(merged) + deserialized + } + } + + implicit def patchValueUnmarshaller: FromEntityUnmarshaller[PatchValue] = + Unmarshaller.byteStringUnmarshaller + .andThen(sprayJsValueByteStringUnmarshaller) + .forContentTypes(ContentTypeRange(`application/merge-patch+json`)) + .map(js => PatchValue(js)) + + implicit def patchableUnmarshaller[A]( + implicit patchUnmarshaller: FromEntityUnmarshaller[PatchValue], + format: RootJsonFormat[A]): FromEntityUnmarshaller[Patchable[A]] = { + patchUnmarshaller.map(patch => Patchable[A](patch, format)) + } + + protected def mergeObjects(oldObj: JsObject, newObj: JsObject, maxLevels: Option[Int] = None): JsObject = { + JsObject(oldObj.fields.map({ + case (key, oldValue) => + val newValue = newObj.fields.get(key).fold(oldValue)(mergeJsValues(oldValue, _, maxLevels.map(_ - 1))) + key -> newValue + })(collection.breakOut): _*) + } + + protected def mergeJsValues(oldValue: JsValue, newValue: JsValue, maxLevels: Option[Int] = None): JsValue = { + def mergeError(typ: String): Nothing = + deserializationError(s"Expected $typ value, got $newValue") + + if (maxLevels.exists(_ < 0)) oldValue + else { + (oldValue, newValue) match { + case (_: JsString, newString @ (JsString(_) | JsNull)) => newString + case (_: JsString, _) => mergeError("string") + case (_: JsNumber, newNumber @ (JsNumber(_) | JsNull)) => newNumber + case (_: JsNumber, _) => mergeError("number") + case (_: JsBoolean, newBool @ (JsBoolean(_) | JsNull)) => newBool + case (_: JsBoolean, _) => mergeError("boolean") + case (_: JsArray, newArr @ (JsArray(_) | JsNull)) => newArr + case (_: JsArray, _) => mergeError("array") + case (oldObj: JsObject, newObj: JsObject) => mergeObjects(oldObj, newObj) + case (_: JsObject, JsNull) => JsNull + case (_: JsObject, _) => mergeError("object") + case (JsNull, _) => newValue + } + } + } + + def mergePatch[T](patchable: Patchable[T], retrieve: => Future[Option[T]]): Directive1[T] = + Directive { inner => requestCtx => + onSuccess(retrieve)({ + case Some(oldT) => + Try(patchable.applyTo(oldT)) + .transform[Route]( + mergedT => scala.util.Success(inner(Tuple1(mergedT))), { + case jsonException: DeserializationException => + Success(reject(Rejections.malformedRequestContent(jsonException.getMessage, jsonException))) + case t => Failure(t) + } + ) + .get // intentionally re-throw all other errors + case None => reject() + })(requestCtx) + } +} + +object PatchDirectives extends PatchDirectives diff --git a/src/main/scala/xyz/driver/core/rest/PatchSupport.scala b/src/main/scala/xyz/driver/core/rest/PatchSupport.scala deleted file mode 100644 index d7c77c8..0000000 --- a/src/main/scala/xyz/driver/core/rest/PatchSupport.scala +++ /dev/null @@ -1,107 +0,0 @@ -package xyz.driver.core.rest - -import akka.http.javadsl.server.Rejections -import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport -import akka.http.scaladsl.server._ -import spray.json._ -import xyz.driver.core.Id - -import scala.concurrent.Future -import scala.util.{Failure, Success, Try} -import scalaz.syntax.equal._ -import scalaz.Scalaz.stringInstance - -trait PatchSupport extends Directives with SprayJsonSupport { - protected val MergePatchPlusJson: String = "merge-patch+json" - - trait PatchRetrievable[T] { - def apply(id: Id[T])(implicit ctx: ServiceRequestContext): Future[Option[T]] - } - - object PatchRetrievable { - def apply[T](retriever: (Id[T] => (ServiceRequestContext => Future[Option[T]]))): PatchRetrievable[T] = - new PatchRetrievable[T] { - override def apply(id: Id[T])(implicit ctx: ServiceRequestContext): Future[Option[T]] = retriever(id)(ctx) - } - } - - protected def mergeObjects(oldObj: JsObject, newObj: JsObject, maxLevels: Option[Int] = None): JsObject = { - JsObject(oldObj.fields.map({ - case (key, oldValue) => - val newValue = newObj.fields.get(key).fold(oldValue)(mergeJsValues(oldValue, _, maxLevels.map(_ - 1))) - key -> newValue - })(collection.breakOut): _*) - } - - protected def mergeJsValues(oldValue: JsValue, newValue: JsValue, maxLevels: Option[Int] = None): JsValue = { - def mergeError(typ: String): Nothing = - deserializationError(s"Expected $typ value, got $newValue") - - if (maxLevels.exists(_ < 0)) oldValue - else { - (oldValue, newValue) match { - case (_: JsString, newString @ (JsString(_) | JsNull)) => newString - case (_: JsString, _) => mergeError("string") - case (_: JsNumber, newNumber @ (JsNumber(_) | JsNull)) => newNumber - case (_: JsNumber, _) => mergeError("number") - case (_: JsBoolean, newBool @ (JsBoolean(_) | JsNull)) => newBool - case (_: JsBoolean, _) => mergeError("boolean") - case (_: JsArray, newArr @ (JsArray(_) | JsNull)) => newArr - case (_: JsArray, _) => mergeError("array") - case (oldObj: JsObject, newObj: JsObject) => mergeObjects(oldObj, newObj) - case (_: JsObject, JsNull) => JsNull - case (_: JsObject, _) => mergeError("object") - case (JsNull, _) => newValue - } - } - } - - def rejectNonMergePatchContentType: Directive0 = Directive { inner => requestCtx => - val contentType = requestCtx.request.header[akka.http.scaladsl.model.headers.`Content-Type`] - val isCorrectContentType = - contentType.map(_.contentType.mediaType).exists(mt => mt.isApplication && mt.subType === MergePatchPlusJson) - if (!isCorrectContentType) { - reject( - Rejections.malformedRequestContent( - s"Request Content-Type must be application/$MergePatchPlusJson for PATCH requests", - new RuntimeException))(requestCtx) - } else inner(())(requestCtx) - } - - def as[T]( - implicit patchable: PatchRetrievable[T], - jsonFormat: RootJsonFormat[T]): (PatchRetrievable[T], RootJsonFormat[T]) = - (patchable, jsonFormat) - - def patch[T](patchable: (PatchRetrievable[T], RootJsonFormat[T]), id: Id[T]): Directive1[T] = Directive { - inner => requestCtx => - import requestCtx.executionContext - val retriever = patchable._1 - implicit val jsonFormat: RootJsonFormat[T] = patchable._2 - Directives.patch { - rejectNonMergePatchContentType { - entity(as[JsValue]) { newValue => - serviceContext { implicit ctx => - onSuccess(retriever(id).map(_.map(_.toJson))) { - case Some(oldValue) => - val mergedObj = mergeJsValues(oldValue, newValue) - Try(mergedObj.convertTo[T]) - .transform[Route]( - mergedT => scala.util.Success(inner(Tuple1(mergedT))), { - case jsonException: DeserializationException => - Success(reject(Rejections.malformedRequestContent(jsonException.getMessage, jsonException))) - case t => Failure(t) - } - ) - .get // intentionally re-throw all other errors - case None => - reject() - } - } - } - } - }(requestCtx) - } -} - -object PatchSupport extends PatchSupport diff --git a/src/test/scala/xyz/driver/core/rest/PatchDirectivesTest.scala b/src/test/scala/xyz/driver/core/rest/PatchDirectivesTest.scala new file mode 100644 index 0000000..6a6b035 --- /dev/null +++ b/src/test/scala/xyz/driver/core/rest/PatchDirectivesTest.scala @@ -0,0 +1,92 @@ +package xyz.driver.core.rest + +import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport +import akka.http.scaladsl.model._ +import akka.http.scaladsl.model.headers.`Content-Type` +import akka.http.scaladsl.server.{Directives, Route} +import akka.http.scaladsl.testkit.ScalatestRouteTest +import org.scalatest.{FlatSpec, Matchers} +import spray.json._ +import xyz.driver.core.{Id, Name} +import xyz.driver.core.json._ + +import scala.concurrent.Future + +class PatchDirectivesTest + extends FlatSpec with Matchers with ScalatestRouteTest with SprayJsonSupport with DefaultJsonProtocol + with Directives with PatchDirectives { + case class Bar(name: Name[Bar], size: Int) + case class Foo(id: Id[Foo], name: Name[Foo], rank: Int, bar: Option[Bar]) + implicit val barFormat: RootJsonFormat[Bar] = jsonFormat2(Bar) + implicit val fooFormat: RootJsonFormat[Foo] = jsonFormat4(Foo) + + val testFoo: Foo = Foo(Id("1"), Name(s"Foo"), 1, Some(Bar(Name("Bar"), 10))) + + def route(retrieve: => Future[Option[Foo]]): Route = + Route.seal(path("api" / "v1" / "foos" / IdInPath[Foo]) { fooId => + entity(as[Patchable[Foo]]) { fooPatchable => + mergePatch(fooPatchable, retrieve) { updatedFoo => + complete(updatedFoo) + } + } + }) + + val MergePatchContentType = ContentType(`application/merge-patch+json`) + val ContentTypeHeader = `Content-Type`(MergePatchContentType) + def jsonEntity(json: String, contentType: ContentType.NonBinary = MergePatchContentType): RequestEntity = + HttpEntity(contentType, json) + + "PatchSupport" should "allow partial updates to an existing object" in { + val fooRetrieve = Future.successful(Some(testFoo)) + + Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""")) ~> route(fooRetrieve) ~> check { + handled shouldBe true + responseAs[Foo] shouldBe testFoo.copy(rank = 4) + } + } + + it should "merge deeply nested objects" in { + val fooRetrieve = Future.successful(Some(testFoo)) + + Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4, "bar": {"name": "My Bar"}}""")) ~> route(fooRetrieve) ~> check { + handled shouldBe true + responseAs[Foo] shouldBe testFoo.copy(rank = 4, bar = Some(Bar(Name("My Bar"), 10))) + } + } + + it should "return a 404 if the object is not found" in { + val fooRetrieve = Future.successful(None) + + Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""")) ~> route(fooRetrieve) ~> check { + handled shouldBe true + status shouldBe StatusCodes.NotFound + } + } + + it should "handle nulls on optional values correctly" in { + val fooRetrieve = Future.successful(Some(testFoo)) + + Patch("/api/v1/foos/1", jsonEntity("""{"bar": null}""")) ~> route(fooRetrieve) ~> check { + handled shouldBe true + responseAs[Foo] shouldBe testFoo.copy(bar = None) + } + } + + it should "return a 400 for nulls on non-optional values" in { + val fooRetrieve = Future.successful(Some(testFoo)) + + Patch("/api/v1/foos/1", jsonEntity("""{"rank": null}""")) ~> route(fooRetrieve) ~> check { + handled shouldBe true + status shouldBe StatusCodes.BadRequest + } + } + + it should "return a 415 for incorrect Content-Type" in { + val fooRetrieve = Future.successful(Some(testFoo)) + + Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""", ContentTypes.`application/json`)) ~> route(fooRetrieve) ~> check { + status shouldBe StatusCodes.UnsupportedMediaType + responseAs[String] should include("application/merge-patch+json") + } + } +} diff --git a/src/test/scala/xyz/driver/core/rest/PatchSupportTest.scala b/src/test/scala/xyz/driver/core/rest/PatchSupportTest.scala deleted file mode 100644 index 5c7faf8..0000000 --- a/src/test/scala/xyz/driver/core/rest/PatchSupportTest.scala +++ /dev/null @@ -1,90 +0,0 @@ -package xyz.driver.core.rest - -import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport -import akka.http.scaladsl.model._ -import akka.http.scaladsl.model.headers.`Content-Type` -import akka.http.scaladsl.server.{Directives, Route} -import akka.http.scaladsl.testkit.ScalatestRouteTest -import org.scalatest.{FlatSpec, Matchers} -import spray.json._ -import xyz.driver.core.{Id, Name} -import xyz.driver.core.json._ - -import scala.concurrent.Future - -class PatchSupportTest - extends FlatSpec with Matchers with ScalatestRouteTest with SprayJsonSupport with DefaultJsonProtocol - with Directives with PatchSupport { - case class Bar(name: Name[Bar], size: Int) - case class Foo(id: Id[Foo], name: Name[Foo], rank: Int, bar: Option[Bar]) - implicit val barFormat: RootJsonFormat[Bar] = jsonFormat2(Bar) - implicit val fooFormat: RootJsonFormat[Foo] = jsonFormat4(Foo) - - val testFoo: Foo = Foo(Id("1"), Name(s"Foo"), 1, Some(Bar(Name("Bar"), 10))) - - def route(implicit patchRetrievable: PatchRetrievable[Foo]): Route = - Route.seal(path("api" / "v1" / "foos" / IdInPath[Foo]) { fooId => - patch(as[Foo], fooId) { patchedFoo => - complete(patchedFoo) - } - }) - - def jsonEntity(json: String): RequestEntity = HttpEntity(ContentTypes.`application/json`, json) - - val ContentTypeHeader = `Content-Type`(ContentType.parse("application/merge-patch+json").right.get) - - "PatchSupport" should "allow partial updates to an existing object" in { - implicit val fooPatchable = PatchRetrievable[Foo](id => _ => Future.successful(Some(testFoo.copy(id = id)))) - - Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""")).withHeaders(ContentTypeHeader) ~> route ~> check { - handled shouldBe true - responseAs[Foo] shouldBe testFoo.copy(rank = 4) - } - } - - it should "merge deeply nested objects" in { - implicit val fooPatchable = PatchRetrievable[Foo](id => _ => Future.successful(Some(testFoo.copy(id = id)))) - - Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4, "bar": {"name": "My Bar"}}""")) - .withHeaders(ContentTypeHeader) ~> route ~> check { - handled shouldBe true - responseAs[Foo] shouldBe testFoo.copy(rank = 4, bar = Some(Bar(Name("My Bar"), 10))) - } - } - - it should "return a 404 if the object is not found" in { - implicit val fooPatchable = PatchRetrievable[Foo](_ => _ => Future.successful(None)) - - Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""")).withHeaders(ContentTypeHeader) ~> route ~> check { - handled shouldBe true - status shouldBe StatusCodes.NotFound - } - } - - it should "handle nulls on optional values correctly" in { - implicit val fooPatchable = PatchRetrievable[Foo](id => _ => Future.successful(Some(testFoo.copy(id = id)))) - - Patch("/api/v1/foos/1", jsonEntity("""{"bar": null}""")).withHeaders(ContentTypeHeader) ~> route ~> check { - handled shouldBe true - responseAs[Foo] shouldBe testFoo.copy(bar = None) - } - } - - it should "return a 400 for nulls on non-optional values" in { - implicit val fooPatchable = PatchRetrievable[Foo](id => _ => Future.successful(Some(testFoo.copy(id = id)))) - - Patch("/api/v1/foos/1", jsonEntity("""{"rank": null}""")).withHeaders(ContentTypeHeader) ~> route ~> check { - handled shouldBe true - status shouldBe StatusCodes.BadRequest - } - } - - it should "return a 400 for incorrect Content-Type" in { - implicit val fooPatchable = PatchRetrievable[Foo](id => _ => Future.successful(Some(testFoo.copy(id = id)))) - - Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""")) ~> route ~> check { - status shouldBe StatusCodes.BadRequest - responseAs[String] should include("application/merge-patch+json") - } - } -} -- cgit v1.2.3 From fc6ecfe212c84271a3454617054aaf25890e886a Mon Sep 17 00:00:00 2001 From: Arthur Rand Date: Wed, 28 Mar 2018 05:56:21 -0700 Subject: [API-1468] add TimeOfDay (#141) * add TimeOfDay * add formatter * . * Revert "." This reverts commit 89576de98092dd75d3af7d82d244d5eaa24d31d9. * scalafmt * add before and after to ToD, and tests * rearrage, make fromStrings * add generator * address comments * use explicit string for TimeZoneId * renaming * revert Converters changes * change name of private method * change apply method * use month --- src/main/scala/xyz/driver/core/generators.scala | 4 +- src/main/scala/xyz/driver/core/json.scala | 31 +++++++- src/main/scala/xyz/driver/core/time.scala | 87 ++++++++++++++++++++++ src/test/scala/xyz/driver/core/JsonTest.scala | 10 +++ src/test/scala/xyz/driver/core/TimeTest.scala | 36 +++++++++ .../xyz/driver/core/database/DatabaseTest.scala | 1 - 6 files changed, 165 insertions(+), 4 deletions(-) (limited to 'src/test/scala/xyz/driver') diff --git a/src/main/scala/xyz/driver/core/generators.scala b/src/main/scala/xyz/driver/core/generators.scala index e3ff326..143044c 100644 --- a/src/main/scala/xyz/driver/core/generators.scala +++ b/src/main/scala/xyz/driver/core/generators.scala @@ -3,7 +3,7 @@ package xyz.driver.core import java.math.MathContext import java.util.UUID -import xyz.driver.core.time.{Time, TimeRange} +import xyz.driver.core.time.{Time, TimeOfDay, TimeRange} import xyz.driver.core.date.{Date, DayOfWeek} import scala.reflect.ClassTag @@ -69,6 +69,8 @@ object generators { def nextTime(): Time = Time(math.abs(nextLong() % System.currentTimeMillis)) + def nextTimeOfDay: TimeOfDay = TimeOfDay(java.time.LocalTime.MIN.plusSeconds(nextLong), java.util.TimeZone.getDefault) + def nextTimeRange(): TimeRange = { val oneTime = nextTime() val anotherTime = nextTime() diff --git a/src/main/scala/xyz/driver/core/json.scala b/src/main/scala/xyz/driver/core/json.scala index 02a35fd..4d7fa04 100644 --- a/src/main/scala/xyz/driver/core/json.scala +++ b/src/main/scala/xyz/driver/core/json.scala @@ -1,7 +1,7 @@ package xyz.driver.core import java.net.InetAddress -import java.util.UUID +import java.util.{TimeZone, UUID} import scala.reflect.runtime.universe._ import scala.util.Try @@ -14,7 +14,7 @@ import spray.json._ import xyz.driver.core.auth.AuthCredentials import xyz.driver.core.date.{Date, DayOfWeek, Month} import xyz.driver.core.domain.{Email, PhoneNumber} -import xyz.driver.core.time.Time +import xyz.driver.core.time.{Time, TimeOfDay} import eu.timepit.refined.refineV import eu.timepit.refined.api.{Refined, Validate} import eu.timepit.refined.collection.NonEmpty @@ -80,6 +80,33 @@ object json { } } + implicit object localTimeFormat extends JsonFormat[java.time.LocalTime] { + private val formatter = TimeOfDay.getFormatter + def read(json: JsValue): java.time.LocalTime = json match { + case JsString(chars) => + java.time.LocalTime.parse(chars) + case _ => deserializationError(s"Expected time string got ${json.toString}") + } + + def write(obj: java.time.LocalTime): JsValue = { + JsString(obj.format(formatter)) + } + } + + implicit object timeZoneFormat extends JsonFormat[java.util.TimeZone] { + override def write(obj: TimeZone): JsValue = { + JsString(obj.getID()) + } + + override def read(json: JsValue): TimeZone = json match { + case JsString(chars) => + java.util.TimeZone.getTimeZone(chars) + case _ => deserializationError(s"Expected time zone string got ${json.toString}") + } + } + + implicit val timeOfDayFormat: RootJsonFormat[TimeOfDay] = jsonFormat2(TimeOfDay.apply) + implicit val dayOfWeekFormat: JsonFormat[DayOfWeek] = new EnumJsonFormat[DayOfWeek](DayOfWeek.All.map(w => w.toString -> w)(collection.breakOut): _*) diff --git a/src/main/scala/xyz/driver/core/time.scala b/src/main/scala/xyz/driver/core/time.scala index 3bcc7bc..bab304d 100644 --- a/src/main/scala/xyz/driver/core/time.scala +++ b/src/main/scala/xyz/driver/core/time.scala @@ -4,7 +4,10 @@ import java.text.SimpleDateFormat import java.util._ import java.util.concurrent.TimeUnit +import xyz.driver.core.date.Month + import scala.concurrent.duration._ +import scala.util.Try object time { @@ -39,6 +42,90 @@ object time { } } + /** + * Encapsulates a time and timezone without a specific date. + */ + final case class TimeOfDay(localTime: java.time.LocalTime, timeZone: TimeZone) { + + /** + * Is this time before another time on a specific day. Day light savings safe. These are zero-indexed + * for month/day. + */ + def isBefore(other: TimeOfDay, day: Int, month: Month, year: Int): Boolean = { + toCalendar(day, month, year).before(other.toCalendar(day, month, year)) + } + + /** + * Is this time after another time on a specific day. Day light savings safe. + */ + def isAfter(other: TimeOfDay, day: Int, month: Month, year: Int): Boolean = { + toCalendar(day, month, year).after(other.toCalendar(day, month, year)) + } + + def sameTimeAs(other: TimeOfDay, day: Int, month: Month, year: Int): Boolean = { + toCalendar(day, month, year).equals(other.toCalendar(day, month, year)) + } + + /** + * Enforces the same formatting as expected by [[java.sql.Time]] + * @return string formatted for `java.sql.Time` + */ + def timeString: String = { + localTime.format(TimeOfDay.getFormatter) + } + + /** + * @return a string parsable by [[java.util.TimeZone]] + */ + def timeZoneString: String = { + timeZone.getID + } + + /** + * @return this [[TimeOfDay]] as [[java.sql.Time]] object, [[java.sql.Time.valueOf]] will + * throw when the string is not valid, but this is protected by [[timeString]] method. + */ + def toTime: java.sql.Time = { + java.sql.Time.valueOf(timeString) + } + + private def toCalendar(day: Int, month: Int, year: Int): Calendar = { + val cal = Calendar.getInstance(timeZone) + cal.set(year, month, day, localTime.getHour, localTime.getMinute, localTime.getSecond) + cal + } + } + + object TimeOfDay { + def now(): TimeOfDay = { + TimeOfDay(java.time.LocalTime.now(), TimeZone.getDefault) + } + + /** + * Throws when [s] is not parsable by [[java.time.LocalTime.parse]], uses default [[java.util.TimeZone]] + */ + def parseTimeString(tz: TimeZone = TimeZone.getDefault)(s: String): TimeOfDay = { + TimeOfDay(java.time.LocalTime.parse(s), tz) + } + + def fromString(tz: TimeZone)(s: String): Option[TimeOfDay] = { + val op = Try(java.time.LocalTime.parse(s)).toOption + op.map(lt => TimeOfDay(lt, tz)) + } + + def fromStrings(zoneId: String)(s: String): Option[TimeOfDay] = { + val op = Try(TimeZone.getTimeZone(zoneId)).toOption + op.map(tz => TimeOfDay.parseTimeString(tz)(s)) + } + + /** + * Formatter that enforces `HH:mm:ss` which is expected by [[java.sql.Time]] + */ + def getFormatter: java.time.format.DateTimeFormatter = { + java.time.format.DateTimeFormatter.ofPattern("HH:mm:ss") + } + } + object Time { implicit def timeOrdering: Ordering[Time] = Ordering.by(_.millis) diff --git a/src/test/scala/xyz/driver/core/JsonTest.scala b/src/test/scala/xyz/driver/core/JsonTest.scala index a45025a..827624c 100644 --- a/src/test/scala/xyz/driver/core/JsonTest.scala +++ b/src/test/scala/xyz/driver/core/JsonTest.scala @@ -11,6 +11,7 @@ import xyz.driver.core.time.provider.SystemTimeProvider import spray.json._ import xyz.driver.core.TestTypes.CustomGADT import xyz.driver.core.domain.{Email, PhoneNumber} +import xyz.driver.core.time.TimeOfDay class JsonTest extends FlatSpec with Matchers { import DefaultJsonProtocol._ @@ -61,6 +62,15 @@ class JsonTest extends FlatSpec with Matchers { parsedTime should be(referenceTime) } + "Json format for TimeOfDay" should "read and write correct JSON" in { + val utcTimeZone = java.util.TimeZone.getTimeZone("UTC") + val referenceTimeOfDay = TimeOfDay.parseTimeString(utcTimeZone)("08:00:00") + val writtenJson = json.timeOfDayFormat.write(referenceTimeOfDay) + writtenJson should be("""{"localTime":"08:00:00","timeZone":"UTC"}""".parseJson) + val parsed = json.timeOfDayFormat.read(writtenJson) + parsed should be(referenceTimeOfDay) + } + "Json format for Date" should "read and write correct JSON" in { import date._ diff --git a/src/test/scala/xyz/driver/core/TimeTest.scala b/src/test/scala/xyz/driver/core/TimeTest.scala index b83137c..b72fde8 100644 --- a/src/test/scala/xyz/driver/core/TimeTest.scala +++ b/src/test/scala/xyz/driver/core/TimeTest.scala @@ -7,6 +7,7 @@ import org.scalacheck.Prop.BooleanOperators import org.scalacheck.{Arbitrary, Gen} import org.scalatest.prop.Checkers import org.scalatest.{FlatSpec, Matchers} +import xyz.driver.core.date.Month import xyz.driver.core.time.{Time, _} import scala.concurrent.duration._ @@ -100,4 +101,39 @@ class TimeTest extends FlatSpec with Matchers with Checkers { textualDate(EST)(timestamp) should not be textualDate(PST)(timestamp) timestamp.toDate(EST) should not be timestamp.toDate(PST) } + + "TimeOfDay" should "be created from valid strings and convert to java.sql.Time" in { + val s = "07:30:45" + val defaultTimeZone = TimeZone.getDefault() + val todFactory = TimeOfDay.parseTimeString(defaultTimeZone)(_) + val tod = todFactory(s) + tod.timeString shouldBe s + tod.timeZoneString shouldBe defaultTimeZone.getID + val sqlTime = tod.toTime + sqlTime.toLocalTime shouldBe tod.localTime + a[java.time.format.DateTimeParseException] should be thrownBy { + val illegal = "7:15" + todFactory(illegal) + } + } + + "TimeOfDay" should "have correct temporal relationships" in { + val s = "07:30:45" + val t = "09:30:45" + val pst = TimeZone.getTimeZone("America/Los_Angeles") + val est = TimeZone.getTimeZone("America/New_York") + val pstTodFactory = TimeOfDay.parseTimeString(pst)(_) + val estTodFactory = TimeOfDay.parseTimeString(est)(_) + val day = 1 + val month = Month.JANUARY + val year = 2018 + val sTodPst = pstTodFactory(s) + val sTodPst2 = pstTodFactory(s) + val tTodPst = pstTodFactory(t) + val tTodEst = estTodFactory(t) + sTodPst.isBefore(tTodPst, day, month, year) shouldBe true + tTodPst.isAfter(sTodPst, day, month, year) shouldBe true + tTodEst.isBefore(sTodPst, day, month, year) shouldBe true + sTodPst.sameTimeAs(sTodPst2, day, month, year) shouldBe true + } } diff --git a/src/test/scala/xyz/driver/core/database/DatabaseTest.scala b/src/test/scala/xyz/driver/core/database/DatabaseTest.scala index f85dcad..8d2a4ac 100644 --- a/src/test/scala/xyz/driver/core/database/DatabaseTest.scala +++ b/src/test/scala/xyz/driver/core/database/DatabaseTest.scala @@ -39,5 +39,4 @@ class DatabaseTest extends FlatSpec with Matchers with Checkers { an[DatabaseException] should be thrownBy TestConverter.expectValidOrEmpty(mapper, invalidOp) } - } -- cgit v1.2.3 From d35332b7e67d6ae6bea3fd50b9405b554a18b491 Mon Sep 17 00:00:00 2001 From: Sergey Nastich Date: Wed, 28 Mar 2018 13:48:02 -0400 Subject: SCALA-20 Use liphonenumber in `PhoneNumber.parse` to accomodate chinese numbers (and other countries) --- build.sbt | 37 +++++----- src/main/scala/xyz/driver/core/domain.scala | 18 +++-- .../scala/xyz/driver/core/PhoneNumberTest.scala | 79 ++++++++++++++++++++++ 3 files changed, 106 insertions(+), 28 deletions(-) create mode 100644 src/test/scala/xyz/driver/core/PhoneNumberTest.scala (limited to 'src/test/scala/xyz/driver') diff --git a/build.sbt b/build.sbt index 9f878f1..88e4582 100644 --- a/build.sbt +++ b/build.sbt @@ -7,22 +7,23 @@ lazy val core = (project in file(".")) .driverLibrary("core") .settings(lintingSettings ++ formatSettings) .settings(libraryDependencies ++= Seq( - "xyz.driver" %% "tracing" % "0.0.2", - "com.typesafe.akka" %% "akka-http-core" % akkaHttpV, - "com.typesafe.akka" %% "akka-http-spray-json" % akkaHttpV, - "com.typesafe.akka" %% "akka-http-testkit" % akkaHttpV, - "com.pauldijou" %% "jwt-core" % "0.14.0", - "org.scalatest" %% "scalatest" % "3.0.2" % "test", - "org.scalacheck" %% "scalacheck" % "1.13.4" % "test", - "org.scalaz" %% "scalaz-core" % "7.2.19", - "org.mockito" % "mockito-core" % "1.9.5" % "test", - "com.github.swagger-akka-http" %% "swagger-akka-http" % "0.11.2", - "com.amazonaws" % "aws-java-sdk-s3" % "1.11.26", - "com.google.cloud" % "google-cloud-pubsub" % "0.25.0-beta", - "com.google.cloud" % "google-cloud-storage" % "1.7.0", - "com.typesafe.slick" %% "slick" % "3.2.1", - "com.typesafe" % "config" % "1.3.1", - "com.typesafe.scala-logging" %% "scala-logging" % "3.5.0", - "eu.timepit" %% "refined" % "0.8.4", - "ch.qos.logback" % "logback-classic" % "1.1.11" + "xyz.driver" %% "tracing" % "0.0.2", + "com.typesafe.akka" %% "akka-http-core" % akkaHttpV, + "com.typesafe.akka" %% "akka-http-spray-json" % akkaHttpV, + "com.typesafe.akka" %% "akka-http-testkit" % akkaHttpV, + "com.pauldijou" %% "jwt-core" % "0.14.0", + "org.scalatest" %% "scalatest" % "3.0.2" % "test", + "org.scalacheck" %% "scalacheck" % "1.13.4" % "test", + "org.scalaz" %% "scalaz-core" % "7.2.19", + "com.github.swagger-akka-http" %% "swagger-akka-http" % "0.11.2", + "com.typesafe.scala-logging" %% "scala-logging" % "3.5.0", + "eu.timepit" %% "refined" % "0.8.4", + "com.typesafe.slick" %% "slick" % "3.2.1", + "org.mockito" % "mockito-core" % "1.9.5" % "test", + "com.amazonaws" % "aws-java-sdk-s3" % "1.11.26", + "com.google.cloud" % "google-cloud-pubsub" % "0.25.0-beta", + "com.google.cloud" % "google-cloud-storage" % "1.7.0", + "com.typesafe" % "config" % "1.3.1", + "ch.qos.logback" % "logback-classic" % "1.1.11", + "com.googlecode.libphonenumber" % "libphonenumber" % "8.9.2" )) diff --git a/src/main/scala/xyz/driver/core/domain.scala b/src/main/scala/xyz/driver/core/domain.scala index 48943a7..7731345 100644 --- a/src/main/scala/xyz/driver/core/domain.scala +++ b/src/main/scala/xyz/driver/core/domain.scala @@ -1,13 +1,14 @@ package xyz.driver.core +import com.google.i18n.phonenumbers.PhoneNumberUtil import scalaz.Equal -import scalaz.syntax.equal._ import scalaz.std.string._ +import scalaz.syntax.equal._ object domain { final case class Email(username: String, domain: String) { - override def toString = username + "@" + domain + override def toString: String = username + "@" + domain } object Email { @@ -27,16 +28,13 @@ object domain { } object PhoneNumber { - def parse(phoneNumberString: String): Option[PhoneNumber] = { - val onlyDigits = phoneNumberString.replaceAll("[^\\d.]", "") - if (onlyDigits.length < 10) None - else { - val tenDigitNumber = onlyDigits.takeRight(10) - val countryCode = Option(onlyDigits.dropRight(10)).filter(_.nonEmpty).getOrElse("1") + private val phoneUtil = PhoneNumberUtil.getInstance() - Some(PhoneNumber(countryCode, tenDigitNumber)) - } + def parse(phoneNumber: String): Option[PhoneNumber] = { + val phone = phoneUtil.parseAndKeepRawInput(phoneNumber, "US") + if (!phoneUtil.isValidNumber(phone)) None + else Some(PhoneNumber(phone.getCountryCode.toString, phone.getNationalNumber.toString)) } } } diff --git a/src/test/scala/xyz/driver/core/PhoneNumberTest.scala b/src/test/scala/xyz/driver/core/PhoneNumberTest.scala new file mode 100644 index 0000000..384c7be --- /dev/null +++ b/src/test/scala/xyz/driver/core/PhoneNumberTest.scala @@ -0,0 +1,79 @@ +package xyz.driver.core + +import org.scalatest.{FlatSpec, Matchers} +import xyz.driver.core.domain.PhoneNumber + +class PhoneNumberTest extends FlatSpec with Matchers { + + "PhoneNumber.parse" should "recognize US numbers in international format, ignoring non-digits" in { + // format: off + val numbers = List( + "+18005252225", + "+1 800 525 2225", + "+1 (800) 525-2225", + "+1.800.525.2225") + // format: on + + val parsed = numbers.flatMap(PhoneNumber.parse) + + parsed should have size numbers.size + parsed should contain only PhoneNumber("1", "8005252225") + } + + it should "recognize US numbers without the plus sign" in { + PhoneNumber.parse("18005252225") shouldBe Some(PhoneNumber("1", "8005252225")) + } + + it should "recognize US numbers without country code" in { + // format: off + val numbers = List( + "8005252225", + "800 525 2225", + "(800) 525-2225", + "800.525.2225") + // format: on + + val parsed = numbers.flatMap(PhoneNumber.parse) + + parsed should have size numbers.size + parsed should contain only PhoneNumber("1", "8005252225") + } + + it should "recognize CN numbers in international format" in { + PhoneNumber.parse("+868005252225") shouldBe Some(PhoneNumber("86", "8005252225")) + PhoneNumber.parse("+86 134 52 52 2256") shouldBe Some(PhoneNumber("86", "13452522256")) + } + + it should "return None on numbers that are shorter than the minimum number of digits for the country (i.e. US - 10, AR - 11)" in { + withClue("US and CN numbers are 10 digits - 9 digit (and shorter) numbers should not fit") { + // format: off + val numbers = List( + "+1 800 525-222", + "+1 800 525-2", + "+86 800 525-222", + "+86 800 525-2") + // format: on + + numbers.flatMap(PhoneNumber.parse) shouldBe empty + } + + withClue("Argentinian numbers are 11 digits (when prefixed with 0) - 10 digit numbers shouldn't fit") { + // format: off + val numbers = List( + "+54 011 525-22256", + "+54 011 525-2225", + "+54 011 525-222") + // format: on + + numbers.flatMap(PhoneNumber.parse) should contain theSameElementsAs List(PhoneNumber("54", "1152522256")) + } + } + + it should "return None on numbers that are longer than the maximum number of digits for the country (i.e. DK - 8, CN - 11)" in { + val numbers = List("+45 27 45 25 22", "+45 135 525 223", "+86 134 525 22256", "+86 135 525 22256 7") + + numbers.flatMap(PhoneNumber.parse) should contain theSameElementsAs + List(PhoneNumber("45", "27452522"), PhoneNumber("86", "13452522256")) + } + +} -- cgit v1.2.3