aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/xyz/driver/core/rest/PatchSupport.scala
blob: 5ded61feead2fc98169e499462ffa2297695d7ad (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
package xyz.driver.core.rest

import akka.http.javadsl.server.Rejections
import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport
import akka.http.scaladsl.server._
import spray.json._
import xyz.driver.core.Id

import scala.concurrent.Future
import scala.util.{Failure, Success, Try}
import scalaz.syntax.equal._
import scalaz.Scalaz.stringInstance

trait PatchSupport extends Directives with SprayJsonSupport {
  protected val MergePatchPlusJson: String = "merge-patch+json"

  trait PatchRetrievable[T] {
    def apply(id: Id[T])(implicit ctx: ServiceRequestContext): Future[Option[T]]
  }

  object PatchRetrievable {
    def apply[T](retriever: ((Id[T], ServiceRequestContext) => Future[Option[T]])): PatchRetrievable[T] =
      new PatchRetrievable[T] {
        override def apply(id: Id[T])(implicit ctx: ServiceRequestContext): Future[Option[T]] = retriever(id, ctx)
      }
  }

  protected def mergeObjects(oldObj: JsObject, newObj: JsObject, maxLevels: Option[Int] = None): JsObject = {
    JsObject(oldObj.fields.map({
      case (key, oldValue) =>
        val newValue = newObj.fields.get(key).fold(oldValue)(mergeJsValues(oldValue, _, maxLevels.map(_ - 1)))
        key -> newValue
    })(collection.breakOut): _*)
  }

  protected def mergeJsValues(oldValue: JsValue, newValue: JsValue, maxLevels: Option[Int] = None): JsValue = {
    def mergeError(typ: String): Nothing =
      deserializationError(s"Expected $typ value, got $newValue")

    if (maxLevels.exists(_ < 0)) oldValue
    else {
      (oldValue, newValue) match {
        case (_: JsString, newString @ (JsString(_) | JsNull)) => newString
        case (_: JsString, _)                                  => mergeError("string")
        case (_: JsNumber, newNumber @ (JsNumber(_) | JsNull)) => newNumber
        case (_: JsNumber, _)                                  => mergeError("number")
        case (_: JsBoolean, newBool @ (JsBoolean(_) | JsNull)) => newBool
        case (_: JsBoolean, _)                                 => mergeError("boolean")
        case (_: JsArray, newArr @ (JsArray(_) | JsNull))      => newArr
        case (_: JsArray, _)                                   => mergeError("array")
        case (oldObj: JsObject, newObj: JsObject)              => mergeObjects(oldObj, newObj)
        case (_: JsObject, JsNull)                             => JsNull
        case (_: JsObject, _)                                  => mergeError("object")
        case (JsNull, _)                                       => newValue
      }
    }
  }

  def rejectNonMergePatchContentType: Directive0 = Directive { inner => requestCtx =>
    val contentType = requestCtx.request.header[akka.http.scaladsl.model.headers.`Content-Type`]
    val isCorrectContentType =
      contentType.map(_.contentType.mediaType).exists(mt => mt.isApplication && mt.subType === MergePatchPlusJson)
    if (!isCorrectContentType) {
      reject(
        Rejections.malformedRequestContent(
          s"Request Content-Type must be application/$MergePatchPlusJson for PATCH requests",
          new RuntimeException))(requestCtx)
    } else inner(())(requestCtx)
  }

  def as[T](
      implicit patchable: PatchRetrievable[T],
      jsonFormat: RootJsonFormat[T]): (PatchRetrievable[T], RootJsonFormat[T]) =
    (patchable, jsonFormat)

  def patch[T](patchable: (PatchRetrievable[T], RootJsonFormat[T]), id: Id[T]): Directive1[T] = Directive {
    inner => requestCtx =>
      import requestCtx.executionContext
      val retriever                              = patchable._1
      implicit val jsonFormat: RootJsonFormat[T] = patchable._2
      Directives.patch {
        rejectNonMergePatchContentType {
          entity(as[JsValue]) { newValue =>
            serviceContext { implicit ctx =>
              onSuccess(retriever(id).map(_.map(_.toJson))) {
                case Some(oldValue) =>
                  val mergedObj = mergeJsValues(oldValue, newValue)
                  Try(mergedObj.convertTo[T])
                    .transform[Route](
                      mergedT => scala.util.Success(inner(Tuple1(mergedT))), {
                        case jsonException: DeserializationException =>
                          Success(reject(Rejections.malformedRequestContent(jsonException.getMessage, jsonException)))
                        case t => Failure(t)
                      }
                    )
                    .get // intentionally re-throw all other errors
                case None =>
                  reject()
              }
            }
          }
        }
      }(requestCtx)
  }
}

object PatchSupport extends PatchSupport