aboutsummaryrefslogtreecommitdiff
path: root/core-rest/src/main/scala/xyz/driver/core/rest/PatchDirectives.scala
blob: f33bf9dbc2f74a5c0aee1b0290346c85b3b3b49f (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
package xyz.driver.core.rest

import akka.http.javadsl.server.Rejections
import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport
import akka.http.scaladsl.model.{ContentTypeRange, HttpCharsets, MediaType}
import akka.http.scaladsl.server._
import akka.http.scaladsl.unmarshalling.{FromEntityUnmarshaller, Unmarshaller}
import spray.json._

import scala.concurrent.Future
import scala.util.{Failure, Success, Try}

/** Directives and unmarshallers for partially updating ("patching") JSON-serializable
  * domain values with JSON merge patches.
  *
  * The patch format follows [[https://tools.ietf.org/html/rfc7396 RFC 7396]], with two
  * differences that are visible in the merge rules implemented here:
  *  - a `null` in the patch sets the corresponding value to `JsNull` rather than removing the key;
  *  - replacing an existing value with a value of a different JSON type is a deserialization error.
  */
trait PatchDirectives extends Directives with SprayJsonSupport {

  /** Media type for patches to JSON values, as specified in [[https://tools.ietf.org/html/rfc7396 RFC 7396]]. */
  val `application/merge-patch+json`: MediaType.WithFixedCharset =
    MediaType.applicationWithFixedCharset("merge-patch+json", HttpCharsets.`UTF-8`)

  /** Wraps a JSON value that represents a patch.
    * The patch must be given in the format specified in [[https://tools.ietf.org/html/rfc7396 RFC 7396]].
    * @param value the raw JSON patch document */
  case class PatchValue(value: JsValue) {

    /** Applies this patch to a given original JSON value. In other words, merges the original with this "diff".
      * @param original the JSON value to be patched
      * @return the merged JSON value */
    def applyTo(original: JsValue): JsValue = mergeJsValues(original, value)
  }

  /** Witness that the given patch may be applied to an original domain value.
    * @tparam A type of the domain value
    * @param patch the patch that may be applied to a domain value
    * @param format a JSON format that enables serialization and deserialization of a domain value */
  case class Patchable[A](patch: PatchValue, format: RootJsonFormat[A]) {

    /** Applies the patch to a given domain object. The result will be a combination
      * of the original value, updated with the fields specified in this witness' patch.
      *
      * The original is serialized with `format`, merged with the patch on the JSON level,
      * and the merged JSON is read back with `format`.
      * @param original the domain value to be patched
      * @return the patched domain value
      * @throws spray.json.DeserializationException if the patch contains a value whose JSON type
      *         does not match the original's (the merged JSON may presumably also be rejected by
      *         `format.read` with the same exception type) */
    def applyTo(original: A): A = {
      val serialized = format.write(original)
      val merged     = patch.applyTo(serialized)
      format.read(merged)
    }
  }

  /** Unmarshals a request entity into a [[PatchValue]] by parsing its bytes as JSON.
    * Only entities with content type `application/merge-patch+json` are accepted. */
  implicit def patchValueUnmarshaller: FromEntityUnmarshaller[PatchValue] =
    Unmarshaller.byteStringUnmarshaller
      .andThen(sprayJsValueByteStringUnmarshaller)
      .forContentTypes(ContentTypeRange(`application/merge-patch+json`))
      .map(js => PatchValue(js))

  /** Unmarshals a request entity into a [[Patchable]] by pairing the unmarshalled
    * patch with the JSON format of the domain type.
    * @tparam A type of the domain value that the patch is meant for */
  implicit def patchableUnmarshaller[A](
      implicit patchUnmarshaller: FromEntityUnmarshaller[PatchValue],
      format: RootJsonFormat[A]): FromEntityUnmarshaller[Patchable[A]] = {
    patchUnmarshaller.map(patch => Patchable[A](patch, format))
  }

  /** Merges two JSON objects field by field.
    *
    * The result contains every key of either object. Keys that are absent from `newObj`
    * keep their old value; keys present in `newObj` are merged with [[mergeJsValues]],
    * using `JsNull` as the old value when the key is absent from `oldObj`.
    * @param oldObj the original object
    * @param newObj the patch object
    * @param maxLevels optional limit on the nesting depth up to which the patch is applied; it is
    *                  decremented by one for the fields of this object. `None` means unlimited.
    * @return the merged object */
  protected def mergeObjects(oldObj: JsObject, newObj: JsObject, maxLevels: Option[Int] = None): JsObject = {
    val remainingLevels = maxLevels.map(_ - 1)
    val mergedFields: Map[String, JsValue] =
      (oldObj.fields.keys ++ newObj.fields.keys).iterator.map { key =>
        val oldValue = oldObj.fields.getOrElse(key, JsNull)
        val merged   = newObj.fields.get(key).fold(oldValue)(mergeJsValues(oldValue, _, remainingLevels))
        key -> merged
      }.toMap
    JsObject(mergedFields)
  }

  /** Merges a patch value into an original JSON value.
    *
    * Strings, numbers, booleans and arrays are replaced by a patch value of the same JSON
    * type or by `JsNull`; objects are merged recursively with [[mergeObjects]] or replaced
    * by `JsNull`; an original `JsNull` is replaced by whatever the patch contains.
    * @param oldValue the original value
    * @param newValue the patch value
    * @param maxLevels optional limit on the remaining nesting depth. If it is negative the patch
    *                  is ignored and `oldValue` is returned. `None` means unlimited.
    * @return the merged value
    * @throws spray.json.DeserializationException if `newValue` has a different JSON type than a
    *         non-null `oldValue` */
  protected def mergeJsValues(oldValue: JsValue, newValue: JsValue, maxLevels: Option[Int] = None): JsValue = {
    def mergeError(typ: String): Nothing =
      deserializationError(s"Expected $typ value, got $newValue")

    if (maxLevels.exists(_ < 0)) oldValue
    else {
      (oldValue, newValue) match {
        case (_: JsString, newString @ (JsString(_) | JsNull)) => newString
        case (_: JsString, _)                                  => mergeError("string")
        case (_: JsNumber, newNumber @ (JsNumber(_) | JsNull)) => newNumber
        case (_: JsNumber, _)                                  => mergeError("number")
        case (_: JsBoolean, newBool @ (JsBoolean(_) | JsNull)) => newBool
        case (_: JsBoolean, _)                                 => mergeError("boolean")
        case (_: JsArray, newArr @ (JsArray(_) | JsNull))      => newArr
        case (_: JsArray, _)                                   => mergeError("array")
        // The depth limit must be forwarded, otherwise it is lost below the first nested object.
        case (oldObj: JsObject, newObj: JsObject)              => mergeObjects(oldObj, newObj, maxLevels)
        case (_: JsObject, JsNull)                             => JsNull
        case (_: JsObject, _)                                  => mergeError("object")
        case (JsNull, _)                                       => newValue
      }
    }
  }

  /** Directive that applies a patch to a stored value and provides the patched value to the inner route.
    *
    * `retrieve` is evaluated for every request. If it yields `None`, the request is rejected
    * without any rejection (which akka-http's default rejection handler reports as "not found").
    * If applying the patch fails with a `DeserializationException`, the request is rejected
    * as malformed request content. Any other error is re-thrown.
    * @tparam T type of the domain value
    * @param patchable the patch together with the JSON format of the domain value
    * @param retrieve asynchronous lookup of the original value
    * @return a directive extracting the patched value */
  def mergePatch[T](patchable: Patchable[T], retrieve: => Future[Option[T]]): Directive1[T] =
    Directive { inner => requestCtx =>
      onSuccess(retrieve)({
        case Some(oldT) =>
          Try(patchable.applyTo(oldT)) match {
            case Success(mergedT) =>
              inner(Tuple1(mergedT))
            case Failure(jsonException: DeserializationException) =>
              reject(Rejections.malformedRequestContent(jsonException.getMessage, jsonException))
            case Failure(other) =>
              throw other // intentionally re-throw all other errors
          }
        case None => reject()
      })(requestCtx)
    }
}

/** Singleton instance of the [[PatchDirectives]] trait, so that its members can be
  * used via `import PatchDirectives._` instead of mixing in the trait. */
object PatchDirectives extends PatchDirectives