From 33e491adc58b3ee3a37194339f09afa70d42a0e2 Mon Sep 17 00:00:00 2001 From: Zach Smith Date: Sun, 25 Mar 2018 15:29:38 -0700 Subject: Use patch unmarshaller --- .../xyz/driver/core/rest/PatchDirectives.scala | 104 +++++++++++++++++++++ 1 file changed, 104 insertions(+) create mode 100644 src/main/scala/xyz/driver/core/rest/PatchDirectives.scala (limited to 'src/main/scala/xyz/driver/core/rest/PatchDirectives.scala') diff --git a/src/main/scala/xyz/driver/core/rest/PatchDirectives.scala b/src/main/scala/xyz/driver/core/rest/PatchDirectives.scala new file mode 100644 index 0000000..256358c --- /dev/null +++ b/src/main/scala/xyz/driver/core/rest/PatchDirectives.scala @@ -0,0 +1,104 @@ +package xyz.driver.core.rest + +import akka.http.javadsl.server.Rejections +import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport +import akka.http.scaladsl.model.{ContentTypeRange, HttpCharsets, MediaType} +import akka.http.scaladsl.server._ +import akka.http.scaladsl.unmarshalling.{FromEntityUnmarshaller, Unmarshaller} +import spray.json._ + +import scala.concurrent.Future +import scala.util.{Failure, Success, Try} + +trait PatchDirectives extends Directives with SprayJsonSupport { + + /** Media type for patches to JSON values, as specified in [[https://tools.ietf.org/html/rfc7396 RFC 7396]]. */ + val `application/merge-patch+json`: MediaType.WithFixedCharset = + MediaType.applicationWithFixedCharset("merge-patch+json", HttpCharsets.`UTF-8`) + + /** Wraps a JSON value that represents a patch. + * The patch must given in the format specified in [[https://tools.ietf.org/html/rfc7396 RFC 7396]]. */ + case class PatchValue(value: JsValue) { + + /** Applies this patch to a given original JSON value. In other words, merges the original with this "diff". */ + def applyTo(original: JsValue): JsValue = mergeJsValues(original, value) + } + + /** Witness that the given patch may be applied to an original domain value. + * @tparam A type of the domain value + * @param patch the patch that may be applied to a domain value + * @param format a JSON format that enables serialization and deserialization of a domain value */ + case class Patchable[A](patch: PatchValue, format: RootJsonFormat[A]) { + + /** Applies the patch to a given domain object. The result will be a combination + * of the original value, updates with the fields specified in this witness' patch. */ + def applyTo(original: A): A = { + val serialized = format.write(original) + val merged = patch.applyTo(serialized) + val deserialized = format.read(merged) + deserialized + } + } + + implicit def patchValueUnmarshaller: FromEntityUnmarshaller[PatchValue] = + Unmarshaller.byteStringUnmarshaller + .andThen(sprayJsValueByteStringUnmarshaller) + .forContentTypes(ContentTypeRange(`application/merge-patch+json`)) + .map(js => PatchValue(js)) + + implicit def patchableUnmarshaller[A]( + implicit patchUnmarshaller: FromEntityUnmarshaller[PatchValue], + format: RootJsonFormat[A]): FromEntityUnmarshaller[Patchable[A]] = { + patchUnmarshaller.map(patch => Patchable[A](patch, format)) + } + + protected def mergeObjects(oldObj: JsObject, newObj: JsObject, maxLevels: Option[Int] = None): JsObject = { + JsObject(oldObj.fields.map({ + case (key, oldValue) => + val newValue = newObj.fields.get(key).fold(oldValue)(mergeJsValues(oldValue, _, maxLevels.map(_ - 1))) + key -> newValue + })(collection.breakOut): _*) + } + + protected def mergeJsValues(oldValue: JsValue, newValue: JsValue, maxLevels: Option[Int] = None): JsValue = { + def mergeError(typ: String): Nothing = + deserializationError(s"Expected $typ value, got $newValue") + + if (maxLevels.exists(_ < 0)) oldValue + else { + (oldValue, newValue) match { + case (_: JsString, newString @ (JsString(_) | JsNull)) => newString + case (_: JsString, _) => mergeError("string") + case (_: JsNumber, newNumber @ (JsNumber(_) | JsNull)) => newNumber + case (_: JsNumber, _) => mergeError("number") + case (_: JsBoolean, newBool @ (JsBoolean(_) | JsNull)) => newBool + case (_: JsBoolean, _) => mergeError("boolean") + case (_: JsArray, newArr @ (JsArray(_) | JsNull)) => newArr + case (_: JsArray, _) => mergeError("array") + case (oldObj: JsObject, newObj: JsObject) => mergeObjects(oldObj, newObj) + case (_: JsObject, JsNull) => JsNull + case (_: JsObject, _) => mergeError("object") + case (JsNull, _) => newValue + } + } + } + + def mergePatch[T](patchable: Patchable[T], retrieve: => Future[Option[T]]): Directive1[T] = + Directive { inner => requestCtx => + onSuccess(retrieve)({ + case Some(oldT) => + Try(patchable.applyTo(oldT)) + .transform[Route]( + mergedT => scala.util.Success(inner(Tuple1(mergedT))), { + case jsonException: DeserializationException => + Success(reject(Rejections.malformedRequestContent(jsonException.getMessage, jsonException))) + case t => Failure(t) + } + ) + .get // intentionally re-throw all other errors + case None => reject() + })(requestCtx) + } +} + +object PatchDirectives extends PatchDirectives -- cgit v1.2.3