diff options
-rw-r--r-- | src/main/scala/xyz/driver/core/rest/PatchSupport.scala | 42 | ||||
-rw-r--r-- | src/test/scala/xyz/driver/core/rest/PatchSupportTest.scala | 10 |
2 files changed, 31 insertions, 21 deletions
diff --git a/src/main/scala/xyz/driver/core/rest/PatchSupport.scala b/src/main/scala/xyz/driver/core/rest/PatchSupport.scala index 5fd9149..9a28de1 100644 --- a/src/main/scala/xyz/driver/core/rest/PatchSupport.scala +++ b/src/main/scala/xyz/driver/core/rest/PatchSupport.scala @@ -11,7 +11,14 @@ import scala.concurrent.Future trait PatchSupport extends Directives with SprayJsonSupport { trait PatchRetrievable[T] { - def apply(id: Id[T]): Future[Option[T]] + def apply(id: Id[T])(implicit ctx: ServiceRequestContext): Future[Option[T]] + } + + object PatchRetrievable { + def apply[T](retriever: ((Id[T], ServiceRequestContext) => Future[Option[T]])): PatchRetrievable[T] = + new PatchRetrievable[T] { + override def apply(id: Id[T])(implicit ctx: ServiceRequestContext): Future[Option[T]] = retriever(id, ctx) + } } protected def mergeObjects(oldObj: JsObject, newObj: JsObject, maxLevels: Option[Int] = None): JsObject = { @@ -57,21 +64,24 @@ trait PatchSupport extends Directives with SprayJsonSupport { implicit val jsonFormat: RootJsonFormat[T] = patchable._2 Directives.patch { entity(as[JsValue]) { newValue => - onSuccess(retriever(id).map(_.map(_.toJson))) { - case Some(oldValue) => - val mergedObj = mergeJsValues(oldValue, newValue) - util - .Try(mergedObj.convertTo[T]) - .transform[Route]( - mergedT => util.Success(inner(Tuple1(mergedT))), { - case jsonException: DeserializationException => - util.Success(reject(Rejections.malformedRequestContent(jsonException.getMessage, jsonException))) - case t => util.Failure(t) - } - ) - .get // intentionally re-throw all other errors - case None => - reject() + serviceContext { implicit ctx => + onSuccess(retriever(id).map(_.map(_.toJson))) { + case Some(oldValue) => + val mergedObj = mergeJsValues(oldValue, newValue) + util + .Try(mergedObj.convertTo[T]) + .transform[Route]( + mergedT => util.Success(inner(Tuple1(mergedT))), { + case jsonException: DeserializationException => + util.Success( + reject(Rejections.malformedRequestContent(jsonException.getMessage, jsonException))) + case t => util.Failure(t) + } + ) + .get // intentionally re-throw all other errors + case None => + reject() + } } } }(requestCtx) diff --git a/src/test/scala/xyz/driver/core/rest/PatchSupportTest.scala b/src/test/scala/xyz/driver/core/rest/PatchSupportTest.scala index dcb3a93..469ac84 100644 --- a/src/test/scala/xyz/driver/core/rest/PatchSupportTest.scala +++ b/src/test/scala/xyz/driver/core/rest/PatchSupportTest.scala @@ -31,7 +31,7 @@ class PatchSupportTest def jsonEntity(json: String): RequestEntity = HttpEntity(ContentTypes.`application/json`, json) "PatchSupport" should "allow partial updates to an existing object" in { - implicit val fooPatchable: PatchRetrievable[Foo] = id => Future.successful(Some(testFoo.copy(id = id))) + implicit val fooPatchable = PatchRetrievable[Foo]((id, _) => Future.successful(Some(testFoo.copy(id = id)))) Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""")) ~> route ~> check { handled shouldBe true @@ -40,7 +40,7 @@ class PatchSupportTest } it should "merge deeply nested objects" in { - implicit val fooPatchable: PatchRetrievable[Foo] = id => Future.successful(Some(testFoo.copy(id = id))) + implicit val fooPatchable = PatchRetrievable[Foo]((id, _) => Future.successful(Some(testFoo.copy(id = id)))) Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4, "bar": {"name": "My Bar"}}""")) ~> route ~> check { handled shouldBe true @@ -49,7 +49,7 @@ class PatchSupportTest } it should "return a 404 if the object is not found" in { - implicit val fooPatchable: PatchRetrievable[Foo] = _ => Future.successful(Option.empty[Foo]) + implicit val fooPatchable = PatchRetrievable[Foo]((id, _) => Future.successful(None)) Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""")) ~> route ~> check { handled shouldBe true @@ -58,7 +58,7 @@ class PatchSupportTest } it should "handle nulls on optional values correctly" in { - implicit val fooPatchable: PatchRetrievable[Foo] = id => Future.successful(Some(testFoo.copy(id = id))) + implicit val fooPatchable = PatchRetrievable[Foo]((id, _) => Future.successful(Some(testFoo.copy(id = id)))) Patch("/api/v1/foos/1", jsonEntity("""{"bar": null}""")) ~> route ~> check { handled shouldBe true @@ -67,7 +67,7 @@ class PatchSupportTest } it should "return a 400 for nulls on non-optional values" in { - implicit val fooPatchable: PatchRetrievable[Foo] = id => Future.successful(Some(testFoo.copy(id = id))) + implicit val fooPatchable = PatchRetrievable[Foo]((id, _) => Future.successful(Some(testFoo.copy(id = id)))) Patch("/api/v1/foos/1", jsonEntity("""{"rank": null}""")) ~> route ~> check { handled shouldBe true |