From 649a22c469647f095b93082a00c01b44fa2a6570 Mon Sep 17 00:00:00 2001 From: Jakob Odersky Date: Tue, 6 Mar 2018 21:39:36 -0800 Subject: Generate RootJsonFormats and remove special handling of case objects --- src/main/scala/DerivedFormats.scala | 46 +++++++++++++++++++++---------- src/test/scala/CoproductTypeFormats.scala | 16 +++++++++-- src/test/scala/FormatTests.scala | 2 +- src/test/scala/ProductTypeFormats.scala | 6 ++-- 4 files changed, 49 insertions(+), 21 deletions(-) diff --git a/src/main/scala/DerivedFormats.scala b/src/main/scala/DerivedFormats.scala index 93e2640..cbfbd61 100644 --- a/src/main/scala/DerivedFormats.scala +++ b/src/main/scala/DerivedFormats.scala @@ -12,24 +12,22 @@ trait DerivedFormats { self: BasicFormats => def combine[T](ctx: CaseClass[JsonFormat, T]): JsonFormat[T] = new JsonFormat[T] { - override def write(value: T): JsValue = - if (ctx.isObject) { - JsString(ctx.typeName.short) - } else { - val fields: Seq[(String, JsValue)] = ctx.parameters.map { param => - param.label -> param.typeclass.write(param.dereference(value)) - } - JsObject(fields: _*) + override def write(value: T): JsValue = { + val fields: Seq[(String, JsValue)] = ctx.parameters.map { param => + param.label -> param.typeclass.write(param.dereference(value)) } + JsObject(fields: _*) + } override def read(value: JsValue): T = value match { case obj: JsObject => - ctx.construct { param => - param.typeclass.read(obj.fields(param.label)) + if (ctx.isObject) { + ctx.rawConstruct(Seq.empty) + } else { + ctx.construct { param => + param.typeclass.read(obj.fields(param.label)) + } } - case JsString(str) if ctx.isObject && str == ctx.typeName.short => - ctx.rawConstruct(Seq.empty) - case js => deserializationError( s"Cannot read JSON '$js' as a ${ctx.typeName.full}") @@ -77,8 +75,28 @@ trait DerivedFormats { self: BasicFormats => } } - implicit def gen[T]: JsonFormat[T] = macro Magnolia.gen[T] + implicit def derivedFormat[T]: RootJsonFormat[T] = + macro DerivedFormatHelper.derivedFormat[T] } object DerivedFormats extends DerivedFormats with BasicFormats + +object DerivedFormatHelper { + import scala.reflect.macros.whitebox._ + + /** Utility that converts a magnolia-generated JsonFormat to a RootJsonFormat. */ + def derivedFormat[T: c.WeakTypeTag](c: Context): c.Tree = { + import c.universe._ + val tpe = weakTypeOf[T].typeSymbol.asType + val sprayPkg = c.mirror.staticPackage("spray.json") + val valName = TermName(c.freshName("format")) + q"""{ + val $valName = ${Magnolia.gen[T](c)} + new $sprayPkg.RootJsonFormat[$tpe] { + def write(value: $tpe) = $valName.write(value) + def read(value: $sprayPkg.JsValue) = $valName.read(value) + } + }""" + } +} diff --git a/src/test/scala/CoproductTypeFormats.scala b/src/test/scala/CoproductTypeFormats.scala index cdf0201..e20871c 100644 --- a/src/test/scala/CoproductTypeFormats.scala +++ b/src/test/scala/CoproductTypeFormats.scala @@ -28,12 +28,12 @@ class CoproductTypeFormats "Nested parameter case class child" should behave like checkCoherence[Expr]( Plus(Value(42), One), - """{"type":"Plus","lhs":{"type":"Value","x":42},"rhs":"One"}""" + """{"type":"Plus","lhs":{"type":"Value","x":42},"rhs":{"type":"One"}}""" ) "Case object child" should behave like checkCoherence[Expr]( One, - """"One"""" + """{"type": "One"}""" ) @gadt("kind") @@ -45,13 +45,23 @@ class CoproductTypeFormats """{"kind":"If","type":"class"}""" ) + @gadt("""_`crazy type!`"""") + sealed abstract trait Crazy + case class CrazyType() extends Crazy + + "GADT with special characters in type field" should behave like checkCoherence[ + Crazy]( + CrazyType(), + """{"_`crazy type!`\"": "CrazyType"}""" + ) + sealed trait Enum case object A extends Enum case object B extends Enum "Enum" should behave like checkCoherence[List[Enum]]( A :: B :: Nil, - """["A", "B"]""" + """[{"type":"A"}, {"type":"B"}]""" ) "Serializing as sealed trait an deserializing as child" should "work" in { diff --git a/src/test/scala/FormatTests.scala b/src/test/scala/FormatTests.scala index e29e49f..68a4765 100644 --- a/src/test/scala/FormatTests.scala +++ b/src/test/scala/FormatTests.scala @@ -5,7 +5,7 @@ import org.scalatest._ trait FormatTests { self: FlatSpec => - def checkCoherence[A: JsonFormat](a: A, expectedJson: String) = { + def checkCoherence[A: RootJsonFormat](a: A, expectedJson: String) = { it should "serialize to the expected JSON value" in { val expected: JsValue = expectedJson.parseJson assert(a.toJson == expected) diff --git a/src/test/scala/ProductTypeFormats.scala b/src/test/scala/ProductTypeFormats.scala index 9755198..48f1bf1 100644 --- a/src/test/scala/ProductTypeFormats.scala +++ b/src/test/scala/ProductTypeFormats.scala @@ -31,16 +31,16 @@ class ProductTypeFormats "Case object" should behave like checkCoherence( D, - """"D"""" + "{}" ) "Case object as parameter" should behave like checkCoherence( E(D), - """{"d":"D"}""" + """{"d":{}}""" ) // custom format for F, that inverts the value of parameter x - implicit val fFormat: JsonFormat[F] = new JsonFormat[F] { + implicit val fFormat: RootJsonFormat[F] = new RootJsonFormat[F] { override def write(f: F): JsValue = JsObject("x" -> JsNumber(-f.x)) override def read(js: JsValue): F = F(-js.asJsObject.fields("x").convertTo[Int]) -- cgit v1.2.3