diff options
-rw-r--r-- | src/main/scala/xyz/driver/core/rest/PatchSupport.scala | 54 | ||||
-rw-r--r-- | src/test/scala/xyz/driver/core/rest/PatchSupportTest.scala | 25 |
2 files changed, 54 insertions, 25 deletions
diff --git a/src/main/scala/xyz/driver/core/rest/PatchSupport.scala b/src/main/scala/xyz/driver/core/rest/PatchSupport.scala index 9a28de1..5ded61f 100644 --- a/src/main/scala/xyz/driver/core/rest/PatchSupport.scala +++ b/src/main/scala/xyz/driver/core/rest/PatchSupport.scala @@ -2,13 +2,17 @@ package xyz.driver.core.rest import akka.http.javadsl.server.Rejections import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport -import akka.http.scaladsl.server.{Directive, Directive1, Directives, Route} +import akka.http.scaladsl.server._ import spray.json._ import xyz.driver.core.Id import scala.concurrent.Future +import scala.util.{Failure, Success, Try} +import scalaz.syntax.equal._ +import scalaz.Scalaz.stringInstance trait PatchSupport extends Directives with SprayJsonSupport { + protected val MergePatchPlusJson: String = "merge-patch+json" trait PatchRetrievable[T] { def apply(id: Id[T])(implicit ctx: ServiceRequestContext): Future[Option[T]] @@ -52,6 +56,18 @@ trait PatchSupport extends Directives with SprayJsonSupport { } } + def rejectNonMergePatchContentType: Directive0 = Directive { inner => requestCtx => + val contentType = requestCtx.request.header[akka.http.scaladsl.model.headers.`Content-Type`] + val isCorrectContentType = + contentType.map(_.contentType.mediaType).exists(mt => mt.isApplication && mt.subType === MergePatchPlusJson) + if (!isCorrectContentType) { + reject( + Rejections.malformedRequestContent( + s"Request Content-Type must be application/$MergePatchPlusJson for PATCH requests", + new RuntimeException))(requestCtx) + } else inner(())(requestCtx) + } + def as[T]( implicit patchable: PatchRetrievable[T], jsonFormat: RootJsonFormat[T]): (PatchRetrievable[T], RootJsonFormat[T]) = @@ -63,24 +79,24 @@ trait PatchSupport extends Directives with SprayJsonSupport { val retriever = patchable._1 implicit val jsonFormat: RootJsonFormat[T] = patchable._2 Directives.patch { - entity(as[JsValue]) { newValue => - serviceContext { implicit ctx => - onSuccess(retriever(id).map(_.map(_.toJson))) { - case Some(oldValue) => - val mergedObj = mergeJsValues(oldValue, newValue) - util - .Try(mergedObj.convertTo[T]) - .transform[Route]( - mergedT => util.Success(inner(Tuple1(mergedT))), { - case jsonException: DeserializationException => - util.Success( - reject(Rejections.malformedRequestContent(jsonException.getMessage, jsonException))) - case t => util.Failure(t) - } - ) - .get // intentionally re-throw all other errors - case None => - reject() + rejectNonMergePatchContentType { + entity(as[JsValue]) { newValue => + serviceContext { implicit ctx => + onSuccess(retriever(id).map(_.map(_.toJson))) { + case Some(oldValue) => + val mergedObj = mergeJsValues(oldValue, newValue) + Try(mergedObj.convertTo[T]) + .transform[Route]( + mergedT => scala.util.Success(inner(Tuple1(mergedT))), { + case jsonException: DeserializationException => + Success(reject(Rejections.malformedRequestContent(jsonException.getMessage, jsonException))) + case t => Failure(t) + } + ) + .get // intentionally re-throw all other errors + case None => + reject() + } } } } diff --git a/src/test/scala/xyz/driver/core/rest/PatchSupportTest.scala b/src/test/scala/xyz/driver/core/rest/PatchSupportTest.scala index 469ac84..20d667d 100644 --- a/src/test/scala/xyz/driver/core/rest/PatchSupportTest.scala +++ b/src/test/scala/xyz/driver/core/rest/PatchSupportTest.scala @@ -1,7 +1,8 @@ package xyz.driver.core.rest import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport -import akka.http.scaladsl.model.{ContentTypes, HttpEntity, RequestEntity, StatusCodes} +import akka.http.scaladsl.model._ +import akka.http.scaladsl.model.headers.`Content-Type` import akka.http.scaladsl.server.{Directives, Route} import akka.http.scaladsl.testkit.ScalatestRouteTest import org.scalatest.{FlatSpec, Matchers} @@ -30,10 +31,12 @@ class PatchSupportTest def jsonEntity(json: String): RequestEntity = HttpEntity(ContentTypes.`application/json`, json) + val ContentTypeHeader = `Content-Type`(ContentType.parse("application/merge-patch+json").right.get) + "PatchSupport" should "allow partial updates to an existing object" in { implicit val fooPatchable = PatchRetrievable[Foo]((id, _) => Future.successful(Some(testFoo.copy(id = id)))) - Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""")) ~> route ~> check { + Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""")).withHeaders(ContentTypeHeader) ~> route ~> check { handled shouldBe true responseAs[Foo] shouldBe testFoo.copy(rank = 4) } @@ -42,7 +45,8 @@ class PatchSupportTest it should "merge deeply nested objects" in { implicit val fooPatchable = PatchRetrievable[Foo]((id, _) => Future.successful(Some(testFoo.copy(id = id)))) - Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4, "bar": {"name": "My Bar"}}""")) ~> route ~> check { + Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4, "bar": {"name": "My Bar"}}""")) + .withHeaders(ContentTypeHeader) ~> route ~> check { handled shouldBe true responseAs[Foo] shouldBe testFoo.copy(rank = 4, bar = Some(Bar(Name("My Bar"), 10))) } @@ -51,7 +55,7 @@ class PatchSupportTest it should "return a 404 if the object is not found" in { implicit val fooPatchable = PatchRetrievable[Foo]((id, _) => Future.successful(None)) - Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""")) ~> route ~> check { + Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""")).withHeaders(ContentTypeHeader) ~> route ~> check { handled shouldBe true status shouldBe StatusCodes.NotFound } @@ -60,7 +64,7 @@ class PatchSupportTest it should "handle nulls on optional values correctly" in { implicit val fooPatchable = PatchRetrievable[Foo]((id, _) => Future.successful(Some(testFoo.copy(id = id)))) - Patch("/api/v1/foos/1", jsonEntity("""{"bar": null}""")) ~> route ~> check { + Patch("/api/v1/foos/1", jsonEntity("""{"bar": null}""")).withHeaders(ContentTypeHeader) ~> route ~> check { handled shouldBe true responseAs[Foo] shouldBe testFoo.copy(bar = None) } @@ -69,9 +73,18 @@ class PatchSupportTest it should "return a 400 for nulls on non-optional values" in { implicit val fooPatchable = PatchRetrievable[Foo]((id, _) => Future.successful(Some(testFoo.copy(id = id)))) - Patch("/api/v1/foos/1", jsonEntity("""{"rank": null}""")) ~> route ~> check { + Patch("/api/v1/foos/1", jsonEntity("""{"rank": null}""")).withHeaders(ContentTypeHeader) ~> route ~> check { handled shouldBe true status shouldBe StatusCodes.BadRequest } } + + it should "return a 400 for incorrect Content-Type" in { + implicit val fooPatchable = PatchRetrievable[Foo]((id, _) => Future.successful(Some(testFoo.copy(id = id)))) + + Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""")) ~> route ~> check { + status shouldBe StatusCodes.BadRequest + responseAs[String] should include("application/merge-patch+json") + } + } } |