From 33e491adc58b3ee3a37194339f09afa70d42a0e2 Mon Sep 17 00:00:00 2001 From: Zach Smith Date: Sun, 25 Mar 2018 15:29:38 -0700 Subject: Use patch unmarshaller --- .../xyz/driver/core/rest/PatchDirectives.scala | 104 ++++++++++++++++++++ .../scala/xyz/driver/core/rest/PatchSupport.scala | 107 --------------------- .../xyz/driver/core/rest/PatchDirectivesTest.scala | 92 ++++++++++++++++++ .../xyz/driver/core/rest/PatchSupportTest.scala | 90 ----------------- 4 files changed, 196 insertions(+), 197 deletions(-) create mode 100644 src/main/scala/xyz/driver/core/rest/PatchDirectives.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/PatchSupport.scala create mode 100644 src/test/scala/xyz/driver/core/rest/PatchDirectivesTest.scala delete mode 100644 src/test/scala/xyz/driver/core/rest/PatchSupportTest.scala diff --git a/src/main/scala/xyz/driver/core/rest/PatchDirectives.scala b/src/main/scala/xyz/driver/core/rest/PatchDirectives.scala new file mode 100644 index 0000000..256358c --- /dev/null +++ b/src/main/scala/xyz/driver/core/rest/PatchDirectives.scala @@ -0,0 +1,104 @@ +package xyz.driver.core.rest + +import akka.http.javadsl.server.Rejections +import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport +import akka.http.scaladsl.model.{ContentTypeRange, HttpCharsets, MediaType} +import akka.http.scaladsl.server._ +import akka.http.scaladsl.unmarshalling.{FromEntityUnmarshaller, Unmarshaller} +import spray.json._ + +import scala.concurrent.Future +import scala.util.{Failure, Success, Try} + +trait PatchDirectives extends Directives with SprayJsonSupport { + + /** Media type for patches to JSON values, as specified in [[https://tools.ietf.org/html/rfc7396 RFC 7396]]. */ + val `application/merge-patch+json`: MediaType.WithFixedCharset = + MediaType.applicationWithFixedCharset("merge-patch+json", HttpCharsets.`UTF-8`) + + /** Wraps a JSON value that represents a patch. + * The patch must given in the format specified in [[https://tools.ietf.org/html/rfc7396 RFC 7396]]. */ + case class PatchValue(value: JsValue) { + + /** Applies this patch to a given original JSON value. In other words, merges the original with this "diff". */ + def applyTo(original: JsValue): JsValue = mergeJsValues(original, value) + } + + /** Witness that the given patch may be applied to an original domain value. + * @tparam A type of the domain value + * @param patch the patch that may be applied to a domain value + * @param format a JSON format that enables serialization and deserialization of a domain value */ + case class Patchable[A](patch: PatchValue, format: RootJsonFormat[A]) { + + /** Applies the patch to a given domain object. The result will be a combination + * of the original value, updates with the fields specified in this witness' patch. */ + def applyTo(original: A): A = { + val serialized = format.write(original) + val merged = patch.applyTo(serialized) + val deserialized = format.read(merged) + deserialized + } + } + + implicit def patchValueUnmarshaller: FromEntityUnmarshaller[PatchValue] = + Unmarshaller.byteStringUnmarshaller + .andThen(sprayJsValueByteStringUnmarshaller) + .forContentTypes(ContentTypeRange(`application/merge-patch+json`)) + .map(js => PatchValue(js)) + + implicit def patchableUnmarshaller[A]( + implicit patchUnmarshaller: FromEntityUnmarshaller[PatchValue], + format: RootJsonFormat[A]): FromEntityUnmarshaller[Patchable[A]] = { + patchUnmarshaller.map(patch => Patchable[A](patch, format)) + } + + protected def mergeObjects(oldObj: JsObject, newObj: JsObject, maxLevels: Option[Int] = None): JsObject = { + JsObject(oldObj.fields.map({ + case (key, oldValue) => + val newValue = newObj.fields.get(key).fold(oldValue)(mergeJsValues(oldValue, _, maxLevels.map(_ - 1))) + key -> newValue + })(collection.breakOut): _*) + } + + protected def mergeJsValues(oldValue: JsValue, newValue: JsValue, maxLevels: Option[Int] = None): JsValue = { + def mergeError(typ: String): Nothing = + deserializationError(s"Expected $typ value, got $newValue") + + if (maxLevels.exists(_ < 0)) oldValue + else { + (oldValue, newValue) match { + case (_: JsString, newString @ (JsString(_) | JsNull)) => newString + case (_: JsString, _) => mergeError("string") + case (_: JsNumber, newNumber @ (JsNumber(_) | JsNull)) => newNumber + case (_: JsNumber, _) => mergeError("number") + case (_: JsBoolean, newBool @ (JsBoolean(_) | JsNull)) => newBool + case (_: JsBoolean, _) => mergeError("boolean") + case (_: JsArray, newArr @ (JsArray(_) | JsNull)) => newArr + case (_: JsArray, _) => mergeError("array") + case (oldObj: JsObject, newObj: JsObject) => mergeObjects(oldObj, newObj) + case (_: JsObject, JsNull) => JsNull + case (_: JsObject, _) => mergeError("object") + case (JsNull, _) => newValue + } + } + } + + def mergePatch[T](patchable: Patchable[T], retrieve: => Future[Option[T]]): Directive1[T] = + Directive { inner => requestCtx => + onSuccess(retrieve)({ + case Some(oldT) => + Try(patchable.applyTo(oldT)) + .transform[Route]( + mergedT => scala.util.Success(inner(Tuple1(mergedT))), { + case jsonException: DeserializationException => + Success(reject(Rejections.malformedRequestContent(jsonException.getMessage, jsonException))) + case t => Failure(t) + } + ) + .get // intentionally re-throw all other errors + case None => reject() + })(requestCtx) + } +} + +object PatchDirectives extends PatchDirectives diff --git a/src/main/scala/xyz/driver/core/rest/PatchSupport.scala b/src/main/scala/xyz/driver/core/rest/PatchSupport.scala deleted file mode 100644 index d7c77c8..0000000 --- a/src/main/scala/xyz/driver/core/rest/PatchSupport.scala +++ /dev/null @@ -1,107 +0,0 @@ -package xyz.driver.core.rest - -import akka.http.javadsl.server.Rejections -import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport -import akka.http.scaladsl.server._ -import spray.json._ -import xyz.driver.core.Id - -import scala.concurrent.Future -import scala.util.{Failure, Success, Try} -import scalaz.syntax.equal._ -import scalaz.Scalaz.stringInstance - -trait PatchSupport extends Directives with SprayJsonSupport { - protected val MergePatchPlusJson: String = "merge-patch+json" - - trait PatchRetrievable[T] { - def apply(id: Id[T])(implicit ctx: ServiceRequestContext): Future[Option[T]] - } - - object PatchRetrievable { - def apply[T](retriever: (Id[T] => (ServiceRequestContext => Future[Option[T]]))): PatchRetrievable[T] = - new PatchRetrievable[T] { - override def apply(id: Id[T])(implicit ctx: ServiceRequestContext): Future[Option[T]] = retriever(id)(ctx) - } - } - - protected def mergeObjects(oldObj: JsObject, newObj: JsObject, maxLevels: Option[Int] = None): JsObject = { - JsObject(oldObj.fields.map({ - case (key, oldValue) => - val newValue = newObj.fields.get(key).fold(oldValue)(mergeJsValues(oldValue, _, maxLevels.map(_ - 1))) - key -> newValue - })(collection.breakOut): _*) - } - - protected def mergeJsValues(oldValue: JsValue, newValue: JsValue, maxLevels: Option[Int] = None): JsValue = { - def mergeError(typ: String): Nothing = - deserializationError(s"Expected $typ value, got $newValue") - - if (maxLevels.exists(_ < 0)) oldValue - else { - (oldValue, newValue) match { - case (_: JsString, newString @ (JsString(_) | JsNull)) => newString - case (_: JsString, _) => mergeError("string") - case (_: JsNumber, newNumber @ (JsNumber(_) | JsNull)) => newNumber - case (_: JsNumber, _) => mergeError("number") - case (_: JsBoolean, newBool @ (JsBoolean(_) | JsNull)) => newBool - case (_: JsBoolean, _) => mergeError("boolean") - case (_: JsArray, newArr @ (JsArray(_) | JsNull)) => newArr - case (_: JsArray, _) => mergeError("array") - case (oldObj: JsObject, newObj: JsObject) => mergeObjects(oldObj, newObj) - case (_: JsObject, JsNull) => JsNull - case (_: JsObject, _) => mergeError("object") - case (JsNull, _) => newValue - } - } - } - - def rejectNonMergePatchContentType: Directive0 = Directive { inner => requestCtx => - val contentType = requestCtx.request.header[akka.http.scaladsl.model.headers.`Content-Type`] - val isCorrectContentType = - contentType.map(_.contentType.mediaType).exists(mt => mt.isApplication && mt.subType === MergePatchPlusJson) - if (!isCorrectContentType) { - reject( - Rejections.malformedRequestContent( - s"Request Content-Type must be application/$MergePatchPlusJson for PATCH requests", - new RuntimeException))(requestCtx) - } else inner(())(requestCtx) - } - - def as[T]( - implicit patchable: PatchRetrievable[T], - jsonFormat: RootJsonFormat[T]): (PatchRetrievable[T], RootJsonFormat[T]) = - (patchable, jsonFormat) - - def patch[T](patchable: (PatchRetrievable[T], RootJsonFormat[T]), id: Id[T]): Directive1[T] = Directive { - inner => requestCtx => - import requestCtx.executionContext - val retriever = patchable._1 - implicit val jsonFormat: RootJsonFormat[T] = patchable._2 - Directives.patch { - rejectNonMergePatchContentType { - entity(as[JsValue]) { newValue => - serviceContext { implicit ctx => - onSuccess(retriever(id).map(_.map(_.toJson))) { - case Some(oldValue) => - val mergedObj = mergeJsValues(oldValue, newValue) - Try(mergedObj.convertTo[T]) - .transform[Route]( - mergedT => scala.util.Success(inner(Tuple1(mergedT))), { - case jsonException: DeserializationException => - Success(reject(Rejections.malformedRequestContent(jsonException.getMessage, jsonException))) - case t => Failure(t) - } - ) - .get // intentionally re-throw all other errors - case None => - reject() - } - } - } - } - }(requestCtx) - } -} - -object PatchSupport extends PatchSupport diff --git a/src/test/scala/xyz/driver/core/rest/PatchDirectivesTest.scala b/src/test/scala/xyz/driver/core/rest/PatchDirectivesTest.scala new file mode 100644 index 0000000..6a6b035 --- /dev/null +++ b/src/test/scala/xyz/driver/core/rest/PatchDirectivesTest.scala @@ -0,0 +1,92 @@ +package xyz.driver.core.rest + +import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport +import akka.http.scaladsl.model._ +import akka.http.scaladsl.model.headers.`Content-Type` +import akka.http.scaladsl.server.{Directives, Route} +import akka.http.scaladsl.testkit.ScalatestRouteTest +import org.scalatest.{FlatSpec, Matchers} +import spray.json._ +import xyz.driver.core.{Id, Name} +import xyz.driver.core.json._ + +import scala.concurrent.Future + +class PatchDirectivesTest + extends FlatSpec with Matchers with ScalatestRouteTest with SprayJsonSupport with DefaultJsonProtocol + with Directives with PatchDirectives { + case class Bar(name: Name[Bar], size: Int) + case class Foo(id: Id[Foo], name: Name[Foo], rank: Int, bar: Option[Bar]) + implicit val barFormat: RootJsonFormat[Bar] = jsonFormat2(Bar) + implicit val fooFormat: RootJsonFormat[Foo] = jsonFormat4(Foo) + + val testFoo: Foo = Foo(Id("1"), Name(s"Foo"), 1, Some(Bar(Name("Bar"), 10))) + + def route(retrieve: => Future[Option[Foo]]): Route = + Route.seal(path("api" / "v1" / "foos" / IdInPath[Foo]) { fooId => + entity(as[Patchable[Foo]]) { fooPatchable => + mergePatch(fooPatchable, retrieve) { updatedFoo => + complete(updatedFoo) + } + } + }) + + val MergePatchContentType = ContentType(`application/merge-patch+json`) + val ContentTypeHeader = `Content-Type`(MergePatchContentType) + def jsonEntity(json: String, contentType: ContentType.NonBinary = MergePatchContentType): RequestEntity = + HttpEntity(contentType, json) + + "PatchSupport" should "allow partial updates to an existing object" in { + val fooRetrieve = Future.successful(Some(testFoo)) + + Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""")) ~> route(fooRetrieve) ~> check { + handled shouldBe true + responseAs[Foo] shouldBe testFoo.copy(rank = 4) + } + } + + it should "merge deeply nested objects" in { + val fooRetrieve = Future.successful(Some(testFoo)) + + Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4, "bar": {"name": "My Bar"}}""")) ~> route(fooRetrieve) ~> check { + handled shouldBe true + responseAs[Foo] shouldBe testFoo.copy(rank = 4, bar = Some(Bar(Name("My Bar"), 10))) + } + } + + it should "return a 404 if the object is not found" in { + val fooRetrieve = Future.successful(None) + + Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""")) ~> route(fooRetrieve) ~> check { + handled shouldBe true + status shouldBe StatusCodes.NotFound + } + } + + it should "handle nulls on optional values correctly" in { + val fooRetrieve = Future.successful(Some(testFoo)) + + Patch("/api/v1/foos/1", jsonEntity("""{"bar": null}""")) ~> route(fooRetrieve) ~> check { + handled shouldBe true + responseAs[Foo] shouldBe testFoo.copy(bar = None) + } + } + + it should "return a 400 for nulls on non-optional values" in { + val fooRetrieve = Future.successful(Some(testFoo)) + + Patch("/api/v1/foos/1", jsonEntity("""{"rank": null}""")) ~> route(fooRetrieve) ~> check { + handled shouldBe true + status shouldBe StatusCodes.BadRequest + } + } + + it should "return a 415 for incorrect Content-Type" in { + val fooRetrieve = Future.successful(Some(testFoo)) + + Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""", ContentTypes.`application/json`)) ~> route(fooRetrieve) ~> check { + status shouldBe StatusCodes.UnsupportedMediaType + responseAs[String] should include("application/merge-patch+json") + } + } +} diff --git a/src/test/scala/xyz/driver/core/rest/PatchSupportTest.scala b/src/test/scala/xyz/driver/core/rest/PatchSupportTest.scala deleted file mode 100644 index 5c7faf8..0000000 --- a/src/test/scala/xyz/driver/core/rest/PatchSupportTest.scala +++ /dev/null @@ -1,90 +0,0 @@ -package xyz.driver.core.rest - -import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport -import akka.http.scaladsl.model._ -import akka.http.scaladsl.model.headers.`Content-Type` -import akka.http.scaladsl.server.{Directives, Route} -import akka.http.scaladsl.testkit.ScalatestRouteTest -import org.scalatest.{FlatSpec, Matchers} -import spray.json._ -import xyz.driver.core.{Id, Name} -import xyz.driver.core.json._ - -import scala.concurrent.Future - -class PatchSupportTest - extends FlatSpec with Matchers with ScalatestRouteTest with SprayJsonSupport with DefaultJsonProtocol - with Directives with PatchSupport { - case class Bar(name: Name[Bar], size: Int) - case class Foo(id: Id[Foo], name: Name[Foo], rank: Int, bar: Option[Bar]) - implicit val barFormat: RootJsonFormat[Bar] = jsonFormat2(Bar) - implicit val fooFormat: RootJsonFormat[Foo] = jsonFormat4(Foo) - - val testFoo: Foo = Foo(Id("1"), Name(s"Foo"), 1, Some(Bar(Name("Bar"), 10))) - - def route(implicit patchRetrievable: PatchRetrievable[Foo]): Route = - Route.seal(path("api" / "v1" / "foos" / IdInPath[Foo]) { fooId => - patch(as[Foo], fooId) { patchedFoo => - complete(patchedFoo) - } - }) - - def jsonEntity(json: String): RequestEntity = HttpEntity(ContentTypes.`application/json`, json) - - val ContentTypeHeader = `Content-Type`(ContentType.parse("application/merge-patch+json").right.get) - - "PatchSupport" should "allow partial updates to an existing object" in { - implicit val fooPatchable = PatchRetrievable[Foo](id => _ => Future.successful(Some(testFoo.copy(id = id)))) - - Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""")).withHeaders(ContentTypeHeader) ~> route ~> check { - handled shouldBe true - responseAs[Foo] shouldBe testFoo.copy(rank = 4) - } - } - - it should "merge deeply nested objects" in { - implicit val fooPatchable = PatchRetrievable[Foo](id => _ => Future.successful(Some(testFoo.copy(id = id)))) - - Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4, "bar": {"name": "My Bar"}}""")) - .withHeaders(ContentTypeHeader) ~> route ~> check { - handled shouldBe true - responseAs[Foo] shouldBe testFoo.copy(rank = 4, bar = Some(Bar(Name("My Bar"), 10))) - } - } - - it should "return a 404 if the object is not found" in { - implicit val fooPatchable = PatchRetrievable[Foo](_ => _ => Future.successful(None)) - - Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""")).withHeaders(ContentTypeHeader) ~> route ~> check { - handled shouldBe true - status shouldBe StatusCodes.NotFound - } - } - - it should "handle nulls on optional values correctly" in { - implicit val fooPatchable = PatchRetrievable[Foo](id => _ => Future.successful(Some(testFoo.copy(id = id)))) - - Patch("/api/v1/foos/1", jsonEntity("""{"bar": null}""")).withHeaders(ContentTypeHeader) ~> route ~> check { - handled shouldBe true - responseAs[Foo] shouldBe testFoo.copy(bar = None) - } - } - - it should "return a 400 for nulls on non-optional values" in { - implicit val fooPatchable = PatchRetrievable[Foo](id => _ => Future.successful(Some(testFoo.copy(id = id)))) - - Patch("/api/v1/foos/1", jsonEntity("""{"rank": null}""")).withHeaders(ContentTypeHeader) ~> route ~> check { - handled shouldBe true - status shouldBe StatusCodes.BadRequest - } - } - - it should "return a 400 for incorrect Content-Type" in { - implicit val fooPatchable = PatchRetrievable[Foo](id => _ => Future.successful(Some(testFoo.copy(id = id)))) - - Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""")) ~> route ~> check { - status shouldBe StatusCodes.BadRequest - responseAs[String] should include("application/merge-patch+json") - } - } -} -- cgit v1.2.3