From ce099ac3ba85f71adeba0bb8398b69dc7cd2e8d1 Mon Sep 17 00:00:00 2001 From: Zach Smith Date: Fri, 16 Mar 2018 18:47:30 +0100 Subject: Add PatchSupport trait and tests --- .../scala/xyz/driver/core/rest/PatchSupport.scala | 81 ++++++++++++++++++++++ .../xyz/driver/core/rest/PatchSupportTest.scala | 77 ++++++++++++++++++++ 2 files changed, 158 insertions(+) create mode 100644 src/main/scala/xyz/driver/core/rest/PatchSupport.scala create mode 100644 src/test/scala/xyz/driver/core/rest/PatchSupportTest.scala diff --git a/src/main/scala/xyz/driver/core/rest/PatchSupport.scala b/src/main/scala/xyz/driver/core/rest/PatchSupport.scala new file mode 100644 index 0000000..5fd9149 --- /dev/null +++ b/src/main/scala/xyz/driver/core/rest/PatchSupport.scala @@ -0,0 +1,81 @@ +package xyz.driver.core.rest + +import akka.http.javadsl.server.Rejections +import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport +import akka.http.scaladsl.server.{Directive, Directive1, Directives, Route} +import spray.json._ +import xyz.driver.core.Id + +import scala.concurrent.Future + +trait PatchSupport extends Directives with SprayJsonSupport { + + trait PatchRetrievable[T] { + def apply(id: Id[T]): Future[Option[T]] + } + + protected def mergeObjects(oldObj: JsObject, newObj: JsObject, maxLevels: Option[Int] = None): JsObject = { + JsObject(oldObj.fields.map({ + case (key, oldValue) => + val newValue = newObj.fields.get(key).fold(oldValue)(mergeJsValues(oldValue, _, maxLevels.map(_ - 1))) + key -> newValue + })(collection.breakOut): _*) + } + + protected def mergeJsValues(oldValue: JsValue, newValue: JsValue, maxLevels: Option[Int] = None): JsValue = { + def mergeError(typ: String): Nothing = + deserializationError(s"Expected $typ value, got $newValue") + + if (maxLevels.exists(_ < 0)) oldValue + else { + (oldValue, newValue) match { + case (_: JsString, newString @ (JsString(_) | JsNull)) => newString + case (_: JsString, _) => mergeError("string") + case (_: JsNumber, newNumber @ (JsNumber(_) | JsNull)) => newNumber + case (_: JsNumber, _) => mergeError("number") + case (_: JsBoolean, newBool @ (JsBoolean(_) | JsNull)) => newBool + case (_: JsBoolean, _) => mergeError("boolean") + case (_: JsArray, newArr @ (JsArray(_) | JsNull)) => newArr + case (_: JsArray, _) => mergeError("array") + case (oldObj: JsObject, newObj: JsObject) => mergeObjects(oldObj, newObj) + case (_: JsObject, JsNull) => JsNull + case (_: JsObject, _) => mergeError("object") + case (JsNull, _) => newValue + } + } + } + + def as[T]( + implicit patchable: PatchRetrievable[T], + jsonFormat: RootJsonFormat[T]): (PatchRetrievable[T], RootJsonFormat[T]) = + (patchable, jsonFormat) + + def patch[T](patchable: (PatchRetrievable[T], RootJsonFormat[T]), id: Id[T]): Directive1[T] = Directive { + inner => requestCtx => + import requestCtx.executionContext + val retriever = patchable._1 + implicit val jsonFormat: RootJsonFormat[T] = patchable._2 + Directives.patch { + entity(as[JsValue]) { newValue => + onSuccess(retriever(id).map(_.map(_.toJson))) { + case Some(oldValue) => + val mergedObj = mergeJsValues(oldValue, newValue) + util + .Try(mergedObj.convertTo[T]) + .transform[Route]( + mergedT => util.Success(inner(Tuple1(mergedT))), { + case jsonException: DeserializationException => + util.Success(reject(Rejections.malformedRequestContent(jsonException.getMessage, jsonException))) + case t => util.Failure(t) + } + ) + .get // intentionally re-throw all other errors + case None => + reject() + } + } + }(requestCtx) + } +} + +object PatchSupport extends PatchSupport diff --git a/src/test/scala/xyz/driver/core/rest/PatchSupportTest.scala b/src/test/scala/xyz/driver/core/rest/PatchSupportTest.scala new file mode 100644 index 0000000..dcb3a93 --- /dev/null +++ b/src/test/scala/xyz/driver/core/rest/PatchSupportTest.scala @@ -0,0 +1,77 @@ +package xyz.driver.core.rest + +import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport +import akka.http.scaladsl.model.{ContentTypes, HttpEntity, RequestEntity, StatusCodes} +import akka.http.scaladsl.server.{Directives, Route} +import akka.http.scaladsl.testkit.ScalatestRouteTest +import org.scalatest.{FlatSpec, Matchers} +import spray.json._ +import xyz.driver.core.{Id, Name} +import xyz.driver.core.json._ + +import scala.concurrent.Future + +class PatchSupportTest + extends FlatSpec with Matchers with ScalatestRouteTest with SprayJsonSupport with DefaultJsonProtocol + with Directives with PatchSupport { + case class Bar(name: Name[Bar], size: Int) + case class Foo(id: Id[Foo], name: Name[Foo], rank: Int, bar: Option[Bar]) + implicit val barFormat: RootJsonFormat[Bar] = jsonFormat2(Bar) + implicit val fooFormat: RootJsonFormat[Foo] = jsonFormat4(Foo) + + val testFoo: Foo = Foo(Id("1"), Name(s"Foo"), 1, Some(Bar(Name("Bar"), 10))) + + def route(implicit patchRetrievable: PatchRetrievable[Foo]): Route = + Route.seal(path("api" / "v1" / "foos" / IdInPath[Foo]) { fooId => + patch(as[Foo], fooId) { patchedFoo => + complete(patchedFoo) + } + }) + + def jsonEntity(json: String): RequestEntity = HttpEntity(ContentTypes.`application/json`, json) + + "PatchSupport" should "allow partial updates to an existing object" in { + implicit val fooPatchable: PatchRetrievable[Foo] = id => Future.successful(Some(testFoo.copy(id = id))) + + Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""")) ~> route ~> check { + handled shouldBe true + responseAs[Foo] shouldBe testFoo.copy(rank = 4) + } + } + + it should "merge deeply nested objects" in { + implicit val fooPatchable: PatchRetrievable[Foo] = id => Future.successful(Some(testFoo.copy(id = id))) + + Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4, "bar": {"name": "My Bar"}}""")) ~> route ~> check { + handled shouldBe true + responseAs[Foo] shouldBe testFoo.copy(rank = 4, bar = Some(Bar(Name("My Bar"), 10))) + } + } + + it should "return a 404 if the object is not found" in { + implicit val fooPatchable: PatchRetrievable[Foo] = _ => Future.successful(Option.empty[Foo]) + + Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""")) ~> route ~> check { + handled shouldBe true + status shouldBe StatusCodes.NotFound + } + } + + it should "handle nulls on optional values correctly" in { + implicit val fooPatchable: PatchRetrievable[Foo] = id => Future.successful(Some(testFoo.copy(id = id))) + + Patch("/api/v1/foos/1", jsonEntity("""{"bar": null}""")) ~> route ~> check { + handled shouldBe true + responseAs[Foo] shouldBe testFoo.copy(bar = None) + } + } + + it should "return a 400 for nulls on non-optional values" in { + implicit val fooPatchable: PatchRetrievable[Foo] = id => Future.successful(Some(testFoo.copy(id = id))) + + Patch("/api/v1/foos/1", jsonEntity("""{"rank": null}""")) ~> route ~> check { + handled shouldBe true + status shouldBe StatusCodes.BadRequest + } + } +} -- cgit v1.2.3