1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
|
package xyz.driver.core.rest
import akka.http.javadsl.server.Rejections
import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport
import akka.http.scaladsl.server.{Directive, Directive1, Directives, Route}
import spray.json._
import xyz.driver.core.Id
import scala.concurrent.Future
/**
 * Mixin providing a `PATCH` directive that merges a partial JSON request body into an existing
 * entity.
 *
 * The current entity is loaded through a `PatchRetrievable`, serialized with its
 * `RootJsonFormat`, merged with the request body by [[mergeJsValues]], and converted back into
 * `T`. The merged `T` is handed to the inner route.
 */
trait PatchSupport extends Directives with SprayJsonSupport {

  /** Loads the current value of an entity by id; `None` makes [[patch]] reject the request. */
  trait PatchRetrievable[T] {
    def apply(id: Id[T])(implicit ctx: ServiceRequestContext): Future[Option[T]]
  }

  object PatchRetrievable {

    /** Adapts a plain two-argument function into a `PatchRetrievable`. */
    def apply[T](retriever: ((Id[T], ServiceRequestContext) => Future[Option[T]])): PatchRetrievable[T] =
      new PatchRetrievable[T] {
        override def apply(id: Id[T])(implicit ctx: ServiceRequestContext): Future[Option[T]] = retriever(id, ctx)
      }
  }

  /**
   * Merges `newObj` into `oldObj` field by field.
   *
   * How each field is handled:
   *  - present in both objects: merged recursively with [[mergeJsValues]];
   *  - present only in `oldObj`: kept unchanged;
   *  - present only in `newObj`: added as-is.
   *
   * Adding patch-only fields matters because spray-json's default product formats omit `None`
   * options when serializing. Without it, a patch could never set such a field.
   *
   * @param maxLevels optional depth limit. It is decremented for each nested level. Once it drops
   *                  below zero, existing values are kept and no new fields are added.
   */
  protected def mergeObjects(oldObj: JsObject, newObj: JsObject, maxLevels: Option[Int] = None): JsObject = {
    val childLevels = maxLevels.map(_ - 1)

    val updatedFields: Map[String, JsValue] = oldObj.fields.map {
      case (key, oldValue) =>
        key -> newObj.fields.get(key).fold(oldValue)(mergeJsValues(oldValue, _, childLevels))
    }

    // A missing old field behaves like `null`, which accepts any new value, unless the depth
    // limit forbids changes at this level.
    val addedFields: Map[String, JsValue] =
      if (childLevels.exists(_ < 0)) Map.empty
      else newObj.fields -- oldObj.fields.keys

    JsObject(updatedFields ++ addedFields)
  }

  /**
   * Merges a single patched JSON value into the existing one.
   *
   * Merge rules:
   *  - strings, numbers, booleans and arrays are replaced wholesale;
   *  - objects are merged recursively with [[mergeObjects]], which receives the same depth limit;
   *  - a `null` in the patch clears any existing value;
   *  - an existing `null` accepts any new value.
   *
   * @param maxLevels optional depth limit; when it is negative, `oldValue` is returned untouched.
   * @return the merged value.
   * @throws DeserializationException if the patch value's JSON type differs from the existing one.
   */
  protected def mergeJsValues(oldValue: JsValue, newValue: JsValue, maxLevels: Option[Int] = None): JsValue = {
    def mergeError(typ: String): Nothing =
      deserializationError(s"Expected $typ value, got $newValue")

    if (maxLevels.exists(_ < 0)) oldValue
    else {
      (oldValue, newValue) match {
        case (_: JsString, newString @ (JsString(_) | JsNull)) => newString
        case (_: JsString, _)                                  => mergeError("string")
        case (_: JsNumber, newNumber @ (JsNumber(_) | JsNull)) => newNumber
        case (_: JsNumber, _)                                  => mergeError("number")
        case (_: JsBoolean, newBool @ (JsBoolean(_) | JsNull)) => newBool
        case (_: JsBoolean, _)                                 => mergeError("boolean")
        case (_: JsArray, newArr @ (JsArray(_) | JsNull))      => newArr
        case (_: JsArray, _)                                   => mergeError("array")
        // Propagate the depth limit so it applies at every nesting level, not only the first.
        case (oldObj: JsObject, newObj: JsObject) => mergeObjects(oldObj, newObj, maxLevels)
        case (_: JsObject, JsNull)                => JsNull
        case (_: JsObject, _)                     => mergeError("object")
        case (JsNull, _)                          => newValue
      }
    }
  }

  /** Summons the retriever and JSON format for `T` as a pair, e.g. for `patch(as[Foo], id)`. */
  def as[T](
      implicit patchable: PatchRetrievable[T],
      jsonFormat: RootJsonFormat[T]): (PatchRetrievable[T], RootJsonFormat[T]) =
    (patchable, jsonFormat)

  /**
   * Directive that matches `PATCH` requests and provides the patched entity to the inner route.
   *
   * Steps:
   *  1. Reads the request body as JSON.
   *  2. Loads the current entity with `patchable._1`.
   *  3. Merges the body into the entity's JSON with [[mergeJsValues]].
   *  4. Converts the result back to `T` with `patchable._2`.
   *
   * Outcomes:
   *  - If merging or conversion fails with a `DeserializationException`, the request is rejected
   *    with a malformed-request-content rejection.
   *  - Any other failure is rethrown.
   *  - If the entity is not found, the request is rejected with no specific rejection.
   */
  def patch[T](patchable: (PatchRetrievable[T], RootJsonFormat[T]), id: Id[T]): Directive1[T] = Directive {
    inner => requestCtx =>
      import requestCtx.executionContext
      val retriever = patchable._1
      implicit val jsonFormat: RootJsonFormat[T] = patchable._2
      Directives.patch {
        entity(as[JsValue]) { newValue =>
          serviceContext { implicit ctx =>
            onSuccess(retriever(id).map(_.map(_.toJson))) {
              case Some(oldValue) =>
                // Merging itself can throw DeserializationException (JSON type mismatch), so it
                // must run inside the Try. Otherwise a bad patch body surfaces as a server error
                // instead of a rejection.
                util
                  .Try(mergeJsValues(oldValue, newValue).convertTo[T])
                  .transform[Route](
                    mergedT => util.Success(inner(Tuple1(mergedT))), {
                      case jsonException: DeserializationException =>
                        util.Success(
                          reject(Rejections.malformedRequestContent(jsonException.getMessage, jsonException)))
                      case t => util.Failure(t)
                    }
                  )
                  .get // intentionally re-throw all other errors
              case None =>
                reject()
            }
          }
        }
      }(requestCtx)
  }
}
/** Ready-made instance of [[PatchSupport]], so its members can be imported instead of mixed in. */
object PatchSupport extends PatchSupport
|