aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/main/scala/xyz/driver/core/json.scala48
-rw-r--r--src/test/scala/xyz/driver/core/JsonTest.scala38
-rw-r--r--src/test/scala/xyz/driver/core/TestTypes.scala14
3 files changed, 99 insertions, 1 deletions
diff --git a/src/main/scala/xyz/driver/core/json.scala b/src/main/scala/xyz/driver/core/json.scala
index 66cae52..277543b 100644
--- a/src/main/scala/xyz/driver/core/json.scala
+++ b/src/main/scala/xyz/driver/core/json.scala
@@ -107,4 +107,52 @@ object json {
case _ => deserializationError(s"Expected number as ${typeOf[T].getClass.getName}, but got " + json)
}
}
+
+ class GadtJsonFormat[T: TypeTag](typeField: String,
+ typeValue: PartialFunction[T, String],
+ jsonFormat: PartialFunction[String, JsonFormat[_ <: T]])
+ extends RootJsonFormat[T] {
+
+ def write(value: T): JsValue = {
+
+ val valueType = typeValue.applyOrElse(value, { v: T =>
+ deserializationError(s"No Value type for this type of ${typeOf[T].getClass.getName}: " + v)
+ })
+
+ val valueFormat =
+ jsonFormat.applyOrElse(valueType, { f: String =>
+ deserializationError(s"No Json format for this type of $valueType")
+ })
+
+ valueFormat.asInstanceOf[JsonFormat[T]].write(value) match {
+ case JsObject(fields) => JsObject(fields ++ Map(typeField -> JsString(valueType)))
+ case _ => serializationError(s"${typeOf[T].getClass.getName} serialized not to a JSON object")
+ }
+ }
+
+ def read(json: JsValue): T = json match {
+ case JsObject(fields) =>
+ val valueJson = JsObject(fields.filterNot(_._1 == typeField))
+ fields(typeField) match {
+ case JsString(valueType) =>
+ val valueFormat = jsonFormat.applyOrElse(valueType, { t: String =>
+ deserializationError(s"Unknown ${typeOf[T].getClass.getName} type ${fields(typeField)}")
+ })
+ valueFormat.read(valueJson)
+ case _ =>
+ deserializationError(s"Unknown ${typeOf[T].getClass.getName} type ${fields(typeField)}")
+ }
+ case _ =>
+ deserializationError(s"Expected Json Object as ${typeOf[T].getClass.getName}, but got " + json)
+ }
+ }
+
+ object GadtJsonFormat {
+
+ def create[T: TypeTag](typeField: String)(typeValue: PartialFunction[T, String])(
+ jsonFormat: PartialFunction[String, JsonFormat[_ <: T]]) = {
+
+ new GadtJsonFormat[T](typeField, typeValue, jsonFormat)
+ }
+ }
}
diff --git a/src/test/scala/xyz/driver/core/JsonTest.scala b/src/test/scala/xyz/driver/core/JsonTest.scala
index c113c59..eb8d5d8 100644
--- a/src/test/scala/xyz/driver/core/JsonTest.scala
+++ b/src/test/scala/xyz/driver/core/JsonTest.scala
@@ -1,9 +1,11 @@
package xyz.driver.core
import org.scalatest.{FlatSpec, Matchers}
-import xyz.driver.core.json.{EnumJsonFormat, ValueClassFormat}
+import xyz.driver.core.json.{EnumJsonFormat, GadtJsonFormat, ValueClassFormat}
import xyz.driver.core.revision.Revision
import xyz.driver.core.time.provider.SystemTimeProvider
+import spray.json._
+import xyz.driver.core.TestTypes.CustomGADT
class JsonTest extends FlatSpec with Matchers {
@@ -98,4 +100,38 @@ class JsonTest extends FlatSpec with Matchers {
parsedValue1 should be(referenceValue1)
parsedValue2 should be(referenceValue2)
}
+
+ "Json format for classes GADT" should "read and write correct JSON" in {
+
+ import CustomGADT._
+ import DefaultJsonProtocol._
+ implicit val case1Format = jsonFormat1(GadtCase1)
+ implicit val case2Format = jsonFormat1(GadtCase2)
+ implicit val case3Format = jsonFormat1(GadtCase3)
+
+ val format = GadtJsonFormat.create[CustomGADT]("gadtTypeField") {
+ case t1: CustomGADT.GadtCase1 => "case1"
+ case t2: CustomGADT.GadtCase2 => "case2"
+ case t3: CustomGADT.GadtCase3 => "case3"
+ } {
+ case "case1" => case1Format
+ case "case2" => case2Format
+ case "case3" => case3Format
+ }
+
+ val referenceValue1 = CustomGADT.GadtCase1("4")
+ val referenceValue2 = CustomGADT.GadtCase2("Hi!")
+
+ val writtenJson1 = format.write(referenceValue1)
+ writtenJson1 should be("{\n \"field\": \"4\",\n\"gadtTypeField\": \"case1\"\n}".parseJson)
+
+ val writtenJson2 = format.write(referenceValue2)
+ writtenJson2 should be("{\"field\":\"Hi!\",\"gadtTypeField\":\"case2\"}".parseJson)
+
+ val parsedValue1 = format.read(writtenJson1)
+ val parsedValue2 = format.read(writtenJson2)
+
+ parsedValue1 should be(referenceValue1)
+ parsedValue2 should be(referenceValue2)
+ }
}
diff --git a/src/test/scala/xyz/driver/core/TestTypes.scala b/src/test/scala/xyz/driver/core/TestTypes.scala
new file mode 100644
index 0000000..bb25deb
--- /dev/null
+++ b/src/test/scala/xyz/driver/core/TestTypes.scala
@@ -0,0 +1,14 @@
+package xyz.driver.core
+
+object TestTypes {
+
+ sealed trait CustomGADT {
+ val field: String
+ }
+
+ object CustomGADT {
+ final case class GadtCase1(field: String) extends CustomGADT
+ final case class GadtCase2(field: String) extends CustomGADT
+ final case class GadtCase3(field: String) extends CustomGADT
+ }
+}