aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorZach Smith <zach@driver.xyz>2018-03-25 15:29:38 -0700
committerZach Smith <zach@driver.xyz>2018-03-25 15:29:38 -0700
commit33e491adc58b3ee3a37194339f09afa70d42a0e2 (patch)
tree6464053d059f4024f0c97559ba1de9eef30185cc
parent424e025f1c719006fe6a6669e43667e1d39ff076 (diff)
downloaddriver-core-33e491adc58b3ee3a37194339f09afa70d42a0e2.tar.gz
driver-core-33e491adc58b3ee3a37194339f09afa70d42a0e2.tar.bz2
driver-core-33e491adc58b3ee3a37194339f09afa70d42a0e2.zip
Use patch unmarshaller
-rw-r--r--src/main/scala/xyz/driver/core/rest/PatchDirectives.scala104
-rw-r--r--src/main/scala/xyz/driver/core/rest/PatchSupport.scala107
-rw-r--r--src/test/scala/xyz/driver/core/rest/PatchDirectivesTest.scala (renamed from src/test/scala/xyz/driver/core/rest/PatchSupportTest.scala)48
3 files changed, 129 insertions, 130 deletions
diff --git a/src/main/scala/xyz/driver/core/rest/PatchDirectives.scala b/src/main/scala/xyz/driver/core/rest/PatchDirectives.scala
new file mode 100644
index 0000000..256358c
--- /dev/null
+++ b/src/main/scala/xyz/driver/core/rest/PatchDirectives.scala
@@ -0,0 +1,104 @@
+package xyz.driver.core.rest
+
+import akka.http.javadsl.server.Rejections
+import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport
+import akka.http.scaladsl.model.{ContentTypeRange, HttpCharsets, MediaType}
+import akka.http.scaladsl.server._
+import akka.http.scaladsl.unmarshalling.{FromEntityUnmarshaller, Unmarshaller}
+import spray.json._
+
+import scala.concurrent.Future
+import scala.util.{Failure, Success, Try}
+
+trait PatchDirectives extends Directives with SprayJsonSupport {
+
+ /** Media type for patches to JSON values, as specified in [[https://tools.ietf.org/html/rfc7396 RFC 7396]]. */
+ val `application/merge-patch+json`: MediaType.WithFixedCharset =
+ MediaType.applicationWithFixedCharset("merge-patch+json", HttpCharsets.`UTF-8`)
+
+ /** Wraps a JSON value that represents a patch.
+ * The patch must given in the format specified in [[https://tools.ietf.org/html/rfc7396 RFC 7396]]. */
+ case class PatchValue(value: JsValue) {
+
+ /** Applies this patch to a given original JSON value. In other words, merges the original with this "diff". */
+ def applyTo(original: JsValue): JsValue = mergeJsValues(original, value)
+ }
+
+ /** Witness that the given patch may be applied to an original domain value.
+ * @tparam A type of the domain value
+ * @param patch the patch that may be applied to a domain value
+ * @param format a JSON format that enables serialization and deserialization of a domain value */
+ case class Patchable[A](patch: PatchValue, format: RootJsonFormat[A]) {
+
+ /** Applies the patch to a given domain object. The result will be a combination
+ * of the original value, updates with the fields specified in this witness' patch. */
+ def applyTo(original: A): A = {
+ val serialized = format.write(original)
+ val merged = patch.applyTo(serialized)
+ val deserialized = format.read(merged)
+ deserialized
+ }
+ }
+
+ implicit def patchValueUnmarshaller: FromEntityUnmarshaller[PatchValue] =
+ Unmarshaller.byteStringUnmarshaller
+ .andThen(sprayJsValueByteStringUnmarshaller)
+ .forContentTypes(ContentTypeRange(`application/merge-patch+json`))
+ .map(js => PatchValue(js))
+
+ implicit def patchableUnmarshaller[A](
+ implicit patchUnmarshaller: FromEntityUnmarshaller[PatchValue],
+ format: RootJsonFormat[A]): FromEntityUnmarshaller[Patchable[A]] = {
+ patchUnmarshaller.map(patch => Patchable[A](patch, format))
+ }
+
+ protected def mergeObjects(oldObj: JsObject, newObj: JsObject, maxLevels: Option[Int] = None): JsObject = {
+ JsObject(oldObj.fields.map({
+ case (key, oldValue) =>
+ val newValue = newObj.fields.get(key).fold(oldValue)(mergeJsValues(oldValue, _, maxLevels.map(_ - 1)))
+ key -> newValue
+ })(collection.breakOut): _*)
+ }
+
+ protected def mergeJsValues(oldValue: JsValue, newValue: JsValue, maxLevels: Option[Int] = None): JsValue = {
+ def mergeError(typ: String): Nothing =
+ deserializationError(s"Expected $typ value, got $newValue")
+
+ if (maxLevels.exists(_ < 0)) oldValue
+ else {
+ (oldValue, newValue) match {
+ case (_: JsString, newString @ (JsString(_) | JsNull)) => newString
+ case (_: JsString, _) => mergeError("string")
+ case (_: JsNumber, newNumber @ (JsNumber(_) | JsNull)) => newNumber
+ case (_: JsNumber, _) => mergeError("number")
+ case (_: JsBoolean, newBool @ (JsBoolean(_) | JsNull)) => newBool
+ case (_: JsBoolean, _) => mergeError("boolean")
+ case (_: JsArray, newArr @ (JsArray(_) | JsNull)) => newArr
+ case (_: JsArray, _) => mergeError("array")
+ case (oldObj: JsObject, newObj: JsObject) => mergeObjects(oldObj, newObj)
+ case (_: JsObject, JsNull) => JsNull
+ case (_: JsObject, _) => mergeError("object")
+ case (JsNull, _) => newValue
+ }
+ }
+ }
+
+ def mergePatch[T](patchable: Patchable[T], retrieve: => Future[Option[T]]): Directive1[T] =
+ Directive { inner => requestCtx =>
+ onSuccess(retrieve)({
+ case Some(oldT) =>
+ Try(patchable.applyTo(oldT))
+ .transform[Route](
+ mergedT => scala.util.Success(inner(Tuple1(mergedT))), {
+ case jsonException: DeserializationException =>
+ Success(reject(Rejections.malformedRequestContent(jsonException.getMessage, jsonException)))
+ case t => Failure(t)
+ }
+ )
+ .get // intentionally re-throw all other errors
+ case None => reject()
+ })(requestCtx)
+ }
+}
+
+object PatchDirectives extends PatchDirectives
diff --git a/src/main/scala/xyz/driver/core/rest/PatchSupport.scala b/src/main/scala/xyz/driver/core/rest/PatchSupport.scala
deleted file mode 100644
index d7c77c8..0000000
--- a/src/main/scala/xyz/driver/core/rest/PatchSupport.scala
+++ /dev/null
@@ -1,107 +0,0 @@
-package xyz.driver.core.rest
-
-import akka.http.javadsl.server.Rejections
-import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport
-import akka.http.scaladsl.server._
-import spray.json._
-import xyz.driver.core.Id
-
-import scala.concurrent.Future
-import scala.util.{Failure, Success, Try}
-import scalaz.syntax.equal._
-import scalaz.Scalaz.stringInstance
-
-trait PatchSupport extends Directives with SprayJsonSupport {
- protected val MergePatchPlusJson: String = "merge-patch+json"
-
- trait PatchRetrievable[T] {
- def apply(id: Id[T])(implicit ctx: ServiceRequestContext): Future[Option[T]]
- }
-
- object PatchRetrievable {
- def apply[T](retriever: (Id[T] => (ServiceRequestContext => Future[Option[T]]))): PatchRetrievable[T] =
- new PatchRetrievable[T] {
- override def apply(id: Id[T])(implicit ctx: ServiceRequestContext): Future[Option[T]] = retriever(id)(ctx)
- }
- }
-
- protected def mergeObjects(oldObj: JsObject, newObj: JsObject, maxLevels: Option[Int] = None): JsObject = {
- JsObject(oldObj.fields.map({
- case (key, oldValue) =>
- val newValue = newObj.fields.get(key).fold(oldValue)(mergeJsValues(oldValue, _, maxLevels.map(_ - 1)))
- key -> newValue
- })(collection.breakOut): _*)
- }
-
- protected def mergeJsValues(oldValue: JsValue, newValue: JsValue, maxLevels: Option[Int] = None): JsValue = {
- def mergeError(typ: String): Nothing =
- deserializationError(s"Expected $typ value, got $newValue")
-
- if (maxLevels.exists(_ < 0)) oldValue
- else {
- (oldValue, newValue) match {
- case (_: JsString, newString @ (JsString(_) | JsNull)) => newString
- case (_: JsString, _) => mergeError("string")
- case (_: JsNumber, newNumber @ (JsNumber(_) | JsNull)) => newNumber
- case (_: JsNumber, _) => mergeError("number")
- case (_: JsBoolean, newBool @ (JsBoolean(_) | JsNull)) => newBool
- case (_: JsBoolean, _) => mergeError("boolean")
- case (_: JsArray, newArr @ (JsArray(_) | JsNull)) => newArr
- case (_: JsArray, _) => mergeError("array")
- case (oldObj: JsObject, newObj: JsObject) => mergeObjects(oldObj, newObj)
- case (_: JsObject, JsNull) => JsNull
- case (_: JsObject, _) => mergeError("object")
- case (JsNull, _) => newValue
- }
- }
- }
-
- def rejectNonMergePatchContentType: Directive0 = Directive { inner => requestCtx =>
- val contentType = requestCtx.request.header[akka.http.scaladsl.model.headers.`Content-Type`]
- val isCorrectContentType =
- contentType.map(_.contentType.mediaType).exists(mt => mt.isApplication && mt.subType === MergePatchPlusJson)
- if (!isCorrectContentType) {
- reject(
- Rejections.malformedRequestContent(
- s"Request Content-Type must be application/$MergePatchPlusJson for PATCH requests",
- new RuntimeException))(requestCtx)
- } else inner(())(requestCtx)
- }
-
- def as[T](
- implicit patchable: PatchRetrievable[T],
- jsonFormat: RootJsonFormat[T]): (PatchRetrievable[T], RootJsonFormat[T]) =
- (patchable, jsonFormat)
-
- def patch[T](patchable: (PatchRetrievable[T], RootJsonFormat[T]), id: Id[T]): Directive1[T] = Directive {
- inner => requestCtx =>
- import requestCtx.executionContext
- val retriever = patchable._1
- implicit val jsonFormat: RootJsonFormat[T] = patchable._2
- Directives.patch {
- rejectNonMergePatchContentType {
- entity(as[JsValue]) { newValue =>
- serviceContext { implicit ctx =>
- onSuccess(retriever(id).map(_.map(_.toJson))) {
- case Some(oldValue) =>
- val mergedObj = mergeJsValues(oldValue, newValue)
- Try(mergedObj.convertTo[T])
- .transform[Route](
- mergedT => scala.util.Success(inner(Tuple1(mergedT))), {
- case jsonException: DeserializationException =>
- Success(reject(Rejections.malformedRequestContent(jsonException.getMessage, jsonException)))
- case t => Failure(t)
- }
- )
- .get // intentionally re-throw all other errors
- case None =>
- reject()
- }
- }
- }
- }
- }(requestCtx)
- }
-}
-
-object PatchSupport extends PatchSupport
diff --git a/src/test/scala/xyz/driver/core/rest/PatchSupportTest.scala b/src/test/scala/xyz/driver/core/rest/PatchDirectivesTest.scala
index 5c7faf8..6a6b035 100644
--- a/src/test/scala/xyz/driver/core/rest/PatchSupportTest.scala
+++ b/src/test/scala/xyz/driver/core/rest/PatchDirectivesTest.scala
@@ -12,9 +12,9 @@ import xyz.driver.core.json._
import scala.concurrent.Future
-class PatchSupportTest
+class PatchDirectivesTest
extends FlatSpec with Matchers with ScalatestRouteTest with SprayJsonSupport with DefaultJsonProtocol
- with Directives with PatchSupport {
+ with Directives with PatchDirectives {
case class Bar(name: Name[Bar], size: Int)
case class Foo(id: Id[Foo], name: Name[Foo], rank: Int, bar: Option[Bar])
implicit val barFormat: RootJsonFormat[Bar] = jsonFormat2(Bar)
@@ -22,68 +22,70 @@ class PatchSupportTest
val testFoo: Foo = Foo(Id("1"), Name(s"Foo"), 1, Some(Bar(Name("Bar"), 10)))
- def route(implicit patchRetrievable: PatchRetrievable[Foo]): Route =
+ def route(retrieve: => Future[Option[Foo]]): Route =
Route.seal(path("api" / "v1" / "foos" / IdInPath[Foo]) { fooId =>
- patch(as[Foo], fooId) { patchedFoo =>
- complete(patchedFoo)
+ entity(as[Patchable[Foo]]) { fooPatchable =>
+ mergePatch(fooPatchable, retrieve) { updatedFoo =>
+ complete(updatedFoo)
+ }
}
})
- def jsonEntity(json: String): RequestEntity = HttpEntity(ContentTypes.`application/json`, json)
-
- val ContentTypeHeader = `Content-Type`(ContentType.parse("application/merge-patch+json").right.get)
+ val MergePatchContentType = ContentType(`application/merge-patch+json`)
+ val ContentTypeHeader = `Content-Type`(MergePatchContentType)
+ def jsonEntity(json: String, contentType: ContentType.NonBinary = MergePatchContentType): RequestEntity =
+ HttpEntity(contentType, json)
"PatchSupport" should "allow partial updates to an existing object" in {
- implicit val fooPatchable = PatchRetrievable[Foo](id => _ => Future.successful(Some(testFoo.copy(id = id))))
+ val fooRetrieve = Future.successful(Some(testFoo))
- Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""")).withHeaders(ContentTypeHeader) ~> route ~> check {
+ Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""")) ~> route(fooRetrieve) ~> check {
handled shouldBe true
responseAs[Foo] shouldBe testFoo.copy(rank = 4)
}
}
it should "merge deeply nested objects" in {
- implicit val fooPatchable = PatchRetrievable[Foo](id => _ => Future.successful(Some(testFoo.copy(id = id))))
+ val fooRetrieve = Future.successful(Some(testFoo))
- Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4, "bar": {"name": "My Bar"}}"""))
- .withHeaders(ContentTypeHeader) ~> route ~> check {
+ Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4, "bar": {"name": "My Bar"}}""")) ~> route(fooRetrieve) ~> check {
handled shouldBe true
responseAs[Foo] shouldBe testFoo.copy(rank = 4, bar = Some(Bar(Name("My Bar"), 10)))
}
}
it should "return a 404 if the object is not found" in {
- implicit val fooPatchable = PatchRetrievable[Foo](_ => _ => Future.successful(None))
+ val fooRetrieve = Future.successful(None)
- Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""")).withHeaders(ContentTypeHeader) ~> route ~> check {
+ Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""")) ~> route(fooRetrieve) ~> check {
handled shouldBe true
status shouldBe StatusCodes.NotFound
}
}
it should "handle nulls on optional values correctly" in {
- implicit val fooPatchable = PatchRetrievable[Foo](id => _ => Future.successful(Some(testFoo.copy(id = id))))
+ val fooRetrieve = Future.successful(Some(testFoo))
- Patch("/api/v1/foos/1", jsonEntity("""{"bar": null}""")).withHeaders(ContentTypeHeader) ~> route ~> check {
+ Patch("/api/v1/foos/1", jsonEntity("""{"bar": null}""")) ~> route(fooRetrieve) ~> check {
handled shouldBe true
responseAs[Foo] shouldBe testFoo.copy(bar = None)
}
}
it should "return a 400 for nulls on non-optional values" in {
- implicit val fooPatchable = PatchRetrievable[Foo](id => _ => Future.successful(Some(testFoo.copy(id = id))))
+ val fooRetrieve = Future.successful(Some(testFoo))
- Patch("/api/v1/foos/1", jsonEntity("""{"rank": null}""")).withHeaders(ContentTypeHeader) ~> route ~> check {
+ Patch("/api/v1/foos/1", jsonEntity("""{"rank": null}""")) ~> route(fooRetrieve) ~> check {
handled shouldBe true
status shouldBe StatusCodes.BadRequest
}
}
- it should "return a 400 for incorrect Content-Type" in {
- implicit val fooPatchable = PatchRetrievable[Foo](id => _ => Future.successful(Some(testFoo.copy(id = id))))
+ it should "return a 415 for incorrect Content-Type" in {
+ val fooRetrieve = Future.successful(Some(testFoo))
- Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""")) ~> route ~> check {
- status shouldBe StatusCodes.BadRequest
+ Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""", ContentTypes.`application/json`)) ~> route(fooRetrieve) ~> check {
+ status shouldBe StatusCodes.UnsupportedMediaType
responseAs[String] should include("application/merge-patch+json")
}
}