From eada4c48b3954241e130f5d9b5e7feebe8c1e3f2 Mon Sep 17 00:00:00 2001 From: Jakob Odersky Date: Wed, 14 Feb 2018 12:17:15 -0800 Subject: Prepare for releasing --- src/main/scala/CustomFormats.scala | 19 ----- src/main/scala/DerivedFormats.scala | 119 ++++++++++++++++++++---------- src/main/scala/annotations.scala | 33 +++++++++ src/main/scala/main.scala | 58 --------------- src/test/scala/CoproductTypeFormats.scala | 58 +++++++++++++++ src/test/scala/FormatTests.scala | 20 +++++ src/test/scala/ProductTypeFormats.scala | 54 ++++++++++++++ 7 files changed, 247 insertions(+), 114 deletions(-) delete mode 100644 src/main/scala/CustomFormats.scala create mode 100644 src/main/scala/annotations.scala delete mode 100644 src/main/scala/main.scala create mode 100644 src/test/scala/CoproductTypeFormats.scala create mode 100644 src/test/scala/FormatTests.scala create mode 100644 src/test/scala/ProductTypeFormats.scala (limited to 'src') diff --git a/src/main/scala/CustomFormats.scala b/src/main/scala/CustomFormats.scala deleted file mode 100644 index 3656e3b..0000000 --- a/src/main/scala/CustomFormats.scala +++ /dev/null @@ -1,19 +0,0 @@ - -import spray.json._ - -trait CustomFormats extends DefaultJsonProtocol { - - implicit val fooFormat: JsonFormat[Foo] = new JsonFormat[Foo] { - def read(number: JsValue) = number match { - case JsNumber(x) => Foo(-x.toInt) - case tpe => sys.error(s"no way I'm reading that type $tpe!") - } - def write(number: Foo) = JsNumber(-number.x) - } - - implicit val z: JsonFormat[B] = new JsonFormat[B] { - def read(x: JsValue) = B("gone") - def write(x: B) = JsObject("a" -> JsString("A")) - } - -} diff --git a/src/main/scala/DerivedFormats.scala b/src/main/scala/DerivedFormats.scala index 6c2396e..79c1e4d 100644 --- a/src/main/scala/DerivedFormats.scala +++ b/src/main/scala/DerivedFormats.scala @@ -1,55 +1,100 @@ +package xyz.driver.json + import magnolia._ import spray.json._ import scala.language.experimental.macros -trait JsonFormatDerivation extends DefaultJsonProtocol { +trait DerivedFormats { self: BasicFormats => type Typeclass[T] = JsonFormat[T] - def combine[T](ctx: CaseClass[JsonFormat, T]): JsonFormat[T] = new JsonFormat[T] { - override def write(value: T): JsValue = { - val fields: Seq[(String, JsValue)] = ctx.parameters.map { param => - param.label -> param.typeclass.write(param.dereference(value)) - } - JsObject(fields: _*) - } - override def read(value: JsValue): T = value match { - case obj: JsObject => - ctx.construct { param => - param.typeclass.read(obj.fields(param.label)) + def combine[T](ctx: CaseClass[JsonFormat, T]): JsonFormat[T] = + new JsonFormat[T] { + override def write(value: T): JsValue = + if (ctx.isObject) { + JsString(ctx.typeName.short) + } else { + val fields: Seq[(String, JsValue)] = ctx.parameters.map { param => + param.label -> param.typeclass.write(param.dereference(value)) + } + JsObject(fields: _*) } - case js => - deserializationError(s"Cannot read JSON '$js' as a ${ctx.typeName}") - } - } - def dispatch[T](ctx: SealedTrait[JsonFormat, T]): JsonFormat[T] = new JsonFormat[T] { - override def write(value: T): JsValue = { - ctx.dispatch(value) { sub => - val obj = sub.typeclass.write(sub.cast(value)).asJsObject - JsObject((obj.fields ++ Map("type" -> JsString(sub.label))).toSeq: _*) + override def read(value: JsValue): T = value match { + case obj: JsObject => + ctx.construct { param => + param.typeclass.read(obj.fields(param.label)) + } + case str: JsString if ctx.isObject && str.value == ctx.typeName.short => + ctx.rawConstruct(Seq.empty) + + case js => + deserializationError( + s"Cannot read JSON '$js' as a ${ctx.typeName.full}") } } - override def read(value: JsValue): T = value match { - case obj: JsObject if obj.fields.contains("type") => - val fieldName = obj.fields("type").convertTo[String] - - ctx.subtypes.find(_.label == fieldName) match { - case Some(tpe) => tpe.typeclass.read(obj) - case None => - deserializationError( - s"Cannot deserialize JSON to ${ctx.typeName} because type field '${fieldName}' has an unsupported value.") - } - case js => - deserializationError(s"Cannot read JSON '$js' as a ${ctx.typeName}") - } + def dispatch[T](ctx: SealedTrait[JsonFormat, T]): JsonFormat[T] = + new JsonFormat[T] { + def tpe = + ctx.annotations + .find(_.isInstanceOf[JsonAnnotation]) + .getOrElse(new gadt("type")) + + override def write(value: T): JsValue = tpe match { + case _: enum => + ctx.dispatch(value) { sub => + JsString(sub.typeName.short) + } - } + case g: gadt => + ctx.dispatch(value) { sub => + val obj = sub.typeclass.write(sub.cast(value)).asJsObject + JsObject( + (Map(g.typeFieldName -> JsString(sub.typeName.short)) ++ + obj.fields).toSeq: _*) + } + } + + override def read(value: JsValue): T = tpe match { + case _: enum => + value match { + case str: JsString => + ctx.subtypes + .find(_.typeName.short == str.value) + .getOrElse(deserializationError( + s"Cannot deserialize JSON to ${ctx.typeName.full} because " + + "type '${str}' has an unsupported value.")) + .typeclass + .read(str) + case js => + deserializationError( + s"Cannot read JSON '$js' as a ${ctx.typeName.full}") + } + + case g: gadt => + value match { + case obj: JsObject if obj.fields.contains(g.typeFieldName) => + val fieldName = obj.fields(g.typeFieldName).convertTo[String] + + ctx.subtypes.find(_.typeName.short == fieldName) match { + case Some(tpe) => tpe.typeclass.read(obj) + case None => + deserializationError( + s"Cannot deserialize JSON to ${ctx.typeName.full} " + + s"because type field '${fieldName}' has an unsupported " + + "value.") + } + + case js => + deserializationError( + s"Cannot read JSON '$js' as a ${ctx.typeName}") + } + } + } implicit def gen[T]: JsonFormat[T] = macro Magnolia.gen[T] } -object JsonFormatDerivation extends JsonFormatDerivation -trait DerivedFormats extends JsonFormatDerivation +object DerivedFormats extends DerivedFormats with BasicFormats diff --git a/src/main/scala/annotations.scala b/src/main/scala/annotations.scala new file mode 100644 index 0000000..f23fbcb --- /dev/null +++ b/src/main/scala/annotations.scala @@ -0,0 +1,33 @@ +package xyz.driver.json + +import scala.annotation.StaticAnnotation + +/** Indicator trait of anontations related to JSON formatting. */ +sealed trait JsonAnnotation + +/** An annotation that designates that a sealed trait is a generalized algebraic + * datatype (GADT), and that a type field containing the serialized childrens' + * types should be added to the final JSON objects. + * + * Note that by default all sealed traits are treated as GADTs, with a type + * field called `type`. This annotation enables overriding the name of that + * field and is really only useful if a child itself has a field called `type` + * that would result in a conflict. + * + * Example + * ``` + * // the JSON field "kind" will contain the actual type of the serialized child + * @gadt("kind") sealed abstract class Keyword(`type`: String) + * case class If(`type`: String) extends Keyword(`type`) + * ``` + * + * @param typeFieldName the name of the field to inject into a serialized JSON + * object */ +final class gadt(val typeFieldName: String = "type") + extends StaticAnnotation + with JsonAnnotation + +/** An annotation that designates that a sealed trait is an enumeration (all + * children are strictly case objects), and that all children should be + * serialized as strings. */ +final class enum extends StaticAnnotation with JsonAnnotation diff --git a/src/main/scala/main.scala b/src/main/scala/main.scala deleted file mode 100644 index 223ac1c..0000000 --- a/src/main/scala/main.scala +++ /dev/null @@ -1,58 +0,0 @@ -import spray.json._ - -// product type -case class Foo(x: Int) -case class Bar(foo: Foo, str: String) - -// coproduct -sealed trait T -case object A extends T -case class B(a: String) extends T -case class C(x: T) extends T // inception! - -object Main extends App with DefaultJsonProtocol with DerivedFormats { - - println("//////////\nProducts:") - - { - val product = Bar(Foo(42), "hello world") - val js = product.toJson - println(js.prettyPrint) - println(js.convertTo[Bar]) - } - - println("//////////\nCoproducts:") - - { - val coproduct: T = B("hello wordld") //Seq(C(B("What's up?")), B("a"), A) - val js = coproduct.toJson - println(js.prettyPrint) - println(js.convertTo[T]) - } - -} - -/* -A potentital danger: - -Overriding generated formats is possible (see CustomFormats), however it can be -easy to forget to include the custom formats. -=> In that case, the program will still compile, however it won't use the - correct format! - -Possible workarounds? - - - Require explicit format declarations, i.e. remove implicit from `implicit def - gen[T] = macro Magnolia.gen[T]` and add `def myFormat = gen[Foo]` to every - format trait. - => requires manual code, thereby mostly defeats the advantages of automatic derivation - => (advantage, no more code duplication since macro is expanded only once) - - - Avoid custom formats. - => entities become "API objects", which will be hard to upgrade in a backwards-compatible, yet idiomatic way - (E.g. new fields could be made optional so that they won't be required in json, however the business logic - may not require them to be optional. We lose some typesafety.) - => we'd likely have an additional layer of indirection, that will convert "api objects" to "business objects" - implemented by services - => Is that a good or bad thing? -*/ diff --git a/src/test/scala/CoproductTypeFormats.scala b/src/test/scala/CoproductTypeFormats.scala new file mode 100644 index 0000000..f16e4c7 --- /dev/null +++ b/src/test/scala/CoproductTypeFormats.scala @@ -0,0 +1,58 @@ +package xyz.driver.json + +import spray.json._ + +import org.scalatest._ + +class CoproductTypeFormats + extends FlatSpec + with FormatTests + with DefaultJsonProtocol + with DerivedFormats { + + sealed trait Expr + case class Zero() extends Expr + case class Value(x: Int) extends Expr + case class Plus(lhs: Expr, rhs: Expr) extends Expr + case object One extends Expr + + "No-parameter case class child" should behave like checkCoherence[Expr]( + Zero(), + """{"type":"Zero"}""" + ) + + "Simple parameter case class child" should behave like checkCoherence[Expr]( + Value(42), + """{"type":"Value","x":42}""" + ) + + "Nested parameter case class child" should behave like checkCoherence[Expr]( + Plus(Value(42), Value(0)), + """{"type":"Plus","lhs":{"type":"Value","x":42},"rhs":{"type":"Value","x":0}}""" + ) + + // "Case object child" should behave like checkCoherence[Expr]( + // One, + // """{"type":"One"}""" + // ) + + @gadt("kind") + sealed abstract class Keyword(`type`: String) + case class If(`type`: String) extends Keyword(`type`) + + "GADT with type field alias" should behave like checkCoherence[Keyword]( + If("class"), + """{"kind":"If","type":"class"}""" + ) + + @enum + sealed trait Enum + case object A extends Enum + case object B extends Enum + + "Enum" should behave like checkCoherence[List[Enum]]( + A :: B :: Nil, + """["A", "B"]""" + ) + +} diff --git a/src/test/scala/FormatTests.scala b/src/test/scala/FormatTests.scala new file mode 100644 index 0000000..e29e49f --- /dev/null +++ b/src/test/scala/FormatTests.scala @@ -0,0 +1,20 @@ +package xyz.driver.json + +import spray.json._ +import org.scalatest._ + +trait FormatTests { self: FlatSpec => + + def checkCoherence[A: JsonFormat](a: A, expectedJson: String) = { + it should "serialize to the expected JSON value" in { + val expected: JsValue = expectedJson.parseJson + assert(a.toJson == expected) + } + + it should "serialize then deserialize back to itself" in { + val reread = a.toJson.convertTo[A] + assert(reread == a) + } + } + +} diff --git a/src/test/scala/ProductTypeFormats.scala b/src/test/scala/ProductTypeFormats.scala new file mode 100644 index 0000000..9755198 --- /dev/null +++ b/src/test/scala/ProductTypeFormats.scala @@ -0,0 +1,54 @@ +package xyz.driver.json + +import spray.json._ + +import org.scalatest._ + +class ProductTypeFormats + extends FlatSpec + with FormatTests + with DerivedFormats + with DefaultJsonProtocol { + + case class A() + case class B(x: Int, b: String, mp: Map[String, Int]) + case class C(b: B) + case object D + case class E(d: D.type) + case class F(x: Int) + + "No-parameter product" should behave like checkCoherence(A(), "{}") + + "Simple parameter product" should behave like checkCoherence( + B(42, "Hello World", Map("a" -> 1, "b" -> -1024)), + """{ "x": 42, "b": "Hello World", "mp": { "a": 1, "b": -1024 } }""" + ) + + "Nested parameter product" should behave like checkCoherence( + C(B(42, "Hello World", Map("a" -> 1, "b" -> -1024))), + """{"b" :{ "x": 42, "b": "Hello World", "mp": { "a": 1, "b": -1024 } } }""" + ) + + "Case object" should behave like checkCoherence( + D, + """"D"""" + ) + + "Case object as parameter" should behave like checkCoherence( + E(D), + """{"d":"D"}""" + ) + + // custom format for F, that inverts the value of parameter x + implicit val fFormat: JsonFormat[F] = new JsonFormat[F] { + override def write(f: F): JsValue = JsObject("x" -> JsNumber(-f.x)) + override def read(js: JsValue): F = + F(-js.asJsObject.fields("x").convertTo[Int]) + } + + "Overriding with a custom format" should behave like checkCoherence( + F(2), + """{"x":-2}""" + ) + +} -- cgit v1.2.3