From a42cceae99ca8517ecff77fecdb23eba4d2c1036 Mon Sep 17 00:00:00 2001 From: Jon Pretty Date: Sun, 5 Nov 2017 19:57:30 +0000 Subject: Deduplication within case class parameter typeclasses --- core/src/main/scala/magnolia.scala | 119 +++++++++++++++++++----------- examples/src/main/scala/typeclasses.scala | 10 +-- tests/src/main/scala/tests.scala | 2 - 3 files changed, 80 insertions(+), 51 deletions(-) diff --git a/core/src/main/scala/magnolia.scala b/core/src/main/scala/magnolia.scala index 2d2bbcd..50d3ab5 100644 --- a/core/src/main/scala/magnolia.scala +++ b/core/src/main/scala/magnolia.scala @@ -48,8 +48,9 @@ object JoinContext { } } -abstract class JoinContext[Tc[_], T](val typeName: String, val isObject: Boolean, val parameters: Array[Param[Tc, T]]) { +abstract class JoinContext[Tc[_], T](val typeName: String, val isObject: Boolean, params: Array[Param[Tc, T]]) { def construct(param: ((Param[Tc, T]) => Any)): T + def parameters: Seq[Param[Tc, T]] = params } object Magnolia { @@ -65,6 +66,8 @@ object Magnolia { def findType(key: Type): Option[TermName] = recursionStack(c.enclosingPosition).frames.find(_.genericType == key).map(_.termName(c)) + case class Typeclass(typ: c.Type, tree: c.Tree) + def recurse[T](path: TypePath, key: Type, value: TermName)(fn: => T): Option[T] = { recursionStack = recursionStack.updated( @@ -89,11 +92,11 @@ object Magnolia { } } - def typeclassTree(paramName: Option[String], - genericType: Type, - typeConstructor: Type, - assignedName: TermName): Tree = { + def typeclassTree(paramName: Option[String], genericType: Type, typeConstructor: Type, + assignedName: TermName): Tree = { + val searchType = appliedType(typeConstructor, genericType) + findType(genericType).map { methodName => val methodAsString = methodName.encodedName.toString q"_root_.magnolia.Deferred.apply[$searchType]($methodAsString)" @@ -105,7 +108,7 @@ object Magnolia { recurse(ChainedImplicit(genericType.toString), genericType, assignedName) { c.inferImplicitValue(searchType, false, false) }.get - }.toOption.orElse(directInferImplicit(genericType, typeConstructor)) + }.toOption.orElse(directInferImplicit(genericType, typeConstructor).map(_.tree)) } recursionStack = recursionStack.updated(c.enclosingPosition, newStack) inferredImplicit @@ -123,7 +126,7 @@ object Magnolia { } def directInferImplicit(genericType: c.Type, - typeConstructor: Type): Option[c.Tree] = { + typeConstructor: Type): Option[Typeclass] = { val genericTypeName: String = genericType.typeSymbol.name.encodedName.toString.toLowerCase val assignedName: TermName = TermName(c.freshName(s"${genericTypeName}Typeclass")) @@ -138,51 +141,72 @@ object Magnolia { // FIXME: Handle AnyVals val result = if(isCaseObject) { - val termSym = genericType.typeSymbol.companionSymbol - val obj = termSym.asTerm + val obj = genericType.typeSymbol.companion.asTerm val className = obj.name.toString val impl = q""" - ${c.prefix}.join(_root_.magnolia.JoinContext[$typeConstructor, $genericType]($className, true, _root_.scala.Array(), $obj)) + ${c.prefix}.join(_root_.magnolia.JoinContext[$typeConstructor, $genericType]($className, true, new _root_.scala.Array(0), $obj)) """ - Some(impl) + Some(Typeclass(genericType, impl)) } else if(isCaseClass) { val caseClassParameters = genericType.decls.collect { case m: MethodSymbol if m.isCaseAccessor => m.asMethod } val className = genericType.toString - val typeclasses: List[(c.universe.MethodSymbol, c.Tree, c.Type)] = caseClassParameters.map { param => + case class CaseParam(sym: c.universe.MethodSymbol, typeclass: c.Tree, paramType: c.Type, ref: c.TermName) + + val caseParams: List[CaseParam] = caseClassParameters.foldLeft(List[CaseParam]()) { case (acc, param) => val paramName = param.name.encodedName.toString val paramType = param.returnType.substituteTypes(genericType.etaExpand.typeParams, genericType.typeArgs) - val derivedImplicit = recurse(ProductType(paramName, genericType.toString), genericType, - assignedName) { - - typeclassTree(Some(paramName), paramType, typeConstructor, assignedName) - - }.getOrElse(c.abort(c.enclosingPosition, s"failed to get implicit for type $genericType")) - - (param, derivedImplicit, paramType) - }.to[List] + acc.find(_.paramType == paramType).map { backRef => + CaseParam(param, q"()", paramType, backRef.ref) :: acc + }.getOrElse { + val derivedImplicit = recurse(ProductType(paramName, genericType.toString), genericType, + assignedName) { + typeclassTree(Some(paramName), paramType, typeConstructor, assignedName) + }.getOrElse(c.abort(c.enclosingPosition, s"failed to get implicit for type $genericType")) + + val ref = TermName(c.freshName("paramTypeclass")) + val assigned = q"""val $ref = $derivedImplicit""" + CaseParam(param, assigned, paramType, ref) :: acc + } + }.to[List].reverse - val callables = typeclasses.map { case (param, typeclass, paramType) => - q"""_root_.magnolia.Param[$typeConstructor, $genericType, $paramType](${param.name.toString}, $typeclass, p => p.${TermName(param.name.toString)})""" + val paramsVal: TermName = TermName(c.freshName("parameters")) + val fnVal: TermName = TermName(c.freshName("fn")) + + val preAssignments = caseParams.map(_.typeclass) + + val assignments = caseParams.zipWithIndex.map { case (CaseParam(param, typeclass, paramType, ref), idx) => + q"""$paramsVal($idx) = _root_.magnolia.Param[$typeConstructor, $genericType, $paramType]( + ${param.name.toString}, $ref, _.${TermName(param.name.toString)} + )""" } - Some { - q""" - val parameters: _root_.scala.Array[Param[$typeConstructor, $genericType]] = _root_.scala.Array(..$callables) - ${c.prefix}.join(_root_.magnolia.JoinContext[$typeConstructor, $genericType]($className, false, parameters, - (fn: (Param[$typeConstructor, $genericType] => Any)) => new $genericType(..${typeclasses.zipWithIndex.map { case (typeclass, idx) => - q"fn(parameters($idx)).asInstanceOf[${typeclass._3}]" - } }) + Some(Typeclass(genericType, + q"""{ + ..$preAssignments + val $paramsVal: _root_.scala.Array[Param[$typeConstructor, $genericType]] = + new _root_.scala.Array(${assignments.length}) + ..$assignments + + ${c.prefix}.join(_root_.magnolia.JoinContext[$typeConstructor, $genericType]( + $className, + false, + $paramsVal, + ($fnVal: Param[$typeConstructor, $genericType] => Any) => + new $genericType(..${caseParams.zipWithIndex.map { case (typeclass, idx) => + q"$fnVal($paramsVal($idx)).asInstanceOf[${typeclass.paramType}]" + } }) )) - """ - } + }""" + )) } else if(isSealedTrait) { val genericSubtypes = classType.get.knownDirectSubclasses.to[List] val subtypes = genericSubtypes.map { sub => - val mapping = sub.asType.typeSignature.baseType(genericType.typeSymbol).typeArgs.zip(genericType.typeArgs).toMap + val typeArgs = sub.asType.typeSignature.baseType(genericType.typeSymbol).typeArgs + val mapping = typeArgs.zip(genericType.typeArgs).toMap val newTypeParams = sub.asType.toType.typeArgs.map(mapping(_)) appliedType(sub.asType.toType.typeConstructor, newTypeParams) } @@ -193,15 +217,17 @@ object Magnolia { c.abort(c.enclosingPosition, "") } + + val subclassesVal: TermName = TermName(c.freshName("subclasses")) - val subclasses = subtypes.map { searchType => + val assignments = subtypes.map { searchType => recurse(CoproductType(genericType.toString), genericType, assignedName) { (searchType, typeclassTree(None, searchType, typeConstructor, assignedName)) }.getOrElse { c.abort(c.enclosingPosition, s"failed to get implicit for type $searchType") } - }.map { case (typ, typeclass) => - q"""_root_.magnolia.Subclass[$typeConstructor, $genericType, $typ]( + }.zipWithIndex.map { case ((typ, typeclass), idx) => + q"""$subclassesVal($idx) = _root_.magnolia.Subclass[$typeConstructor, $genericType, $typ]( ${typ.typeSymbol.name.toString}, $typeclass, (t: $genericType) => t.isInstanceOf[$typ], @@ -210,17 +236,22 @@ object Magnolia { } Some { - q"""{ - ${c.prefix}.split(_root_.scala.collection.immutable.List[_root_.magnolia.Subclass[$typeConstructor, $genericType]](..$subclasses)) - }""" + Typeclass(genericType, q"""{ + val $subclassesVal: _root_.scala.Array[_root_.magnolia.Subclass[$typeConstructor, $genericType]] = + new _root_.scala.Array(${assignments.size}) + + ..$assignments + + ${c.prefix}.dispatch($subclassesVal: _root_.scala.Seq[_root_.magnolia.Subclass[$typeConstructor, $genericType]]) + }""") } } else None - result.map { r => - q"""{ + result.map { case Typeclass(t, r) => + Typeclass(t, q"""{ def $assignedName: $resultType = $r $assignedName - }""" + }""") } } @@ -248,13 +279,13 @@ object Magnolia { val result: Option[Tree] = if(!currentStack.frames.isEmpty) { findType(genericType) match { case None => - directInferImplicit(genericType, typeConstructor) + directInferImplicit(genericType, typeConstructor).map(_.tree) case Some(enclosingRef) => val methodAsString = enclosingRef.toString val searchType = appliedType(typeConstructor, genericType) Some(q"_root_.magnolia.Deferred[$searchType]($methodAsString)") } - } else directInferImplicit(genericType, typeConstructor) + } else directInferImplicit(genericType, typeConstructor).map(_.tree) if(currentStack.frames.isEmpty) recursionStack = ListMap() diff --git a/examples/src/main/scala/typeclasses.scala b/examples/src/main/scala/typeclasses.scala index c79a6d7..aaa74fd 100644 --- a/examples/src/main/scala/typeclasses.scala +++ b/examples/src/main/scala/typeclasses.scala @@ -19,7 +19,7 @@ object Show { }.mkString(s"${context.typeName.split("\\.").last}(", ",", ")") } - def split[T](subclasses: List[Subclass[Typeclass, T]])(value: T): String = + def dispatch[T](subclasses: Seq[Subclass[Typeclass, T]])(value: T): String = subclasses.map { sub => sub.cast.andThen { value => sub.typeclass.show(sub.cast(value)) } }.reduce(_ orElse _)(value) @@ -38,7 +38,7 @@ object Eq { context.parameters.forall { param => param.typeclass.equal(param.dereference(value1), param.dereference(value2)) } } - def split[T](subclasses: List[Subclass[Eq, T]]): Eq[T] = new Eq[T] { + def dispatch[T](subclasses: Seq[Subclass[Eq, T]]): Eq[T] = new Eq[T] { def equal(value1: T, value2: T) = subclasses.map { case subclass => subclass.cast.andThen { value => subclass.typeclass.equal(subclass.cast(value1), subclass.cast(value2)) } @@ -58,7 +58,7 @@ object Default { def default = context.construct { param => param.typeclass.default } } - def split[T](subclasses: List[Subclass[Default, T]])(): Default[T] = new Default[T] { + def dispatch[T](subclasses: Seq[Subclass[Default, T]])(): Default[T] = new Default[T] { def default = subclasses.head.typeclass.default } @@ -73,7 +73,7 @@ object Decoder { def join[T](context: JoinContext[Decoder, T])(value: String): T = context.construct { param => param.typeclass.decode(value) } - def split[T](subclasses: List[Subclass[Decoder, T]])(param: String): T = + def dispatch[T](subclasses: Seq[Subclass[Decoder, T]])(param: String): T = subclasses.map { subclass => { case _ if decodes(subclass.typeclass, param) => subclass.typeclass.decode(param) }: PartialFunction[String, T] }.reduce(_ orElse _)(param) @@ -106,7 +106,7 @@ case class Cyrillic(б: Letter, в: Letter, г: Letter, д: Letter, ж: Letter, case class Latin(a: Letter, b: Letter, c: Letter, d: Letter, e: Letter, f: Letter, g: Letter, h: Letter, i: Letter, j: Letter, k: Letter, l: Letter, m: Letter) extends Alphabet case class Letter(name: String, phonetic: String) -//case class Country(name: String, language: Language, leader: Person) +case class Country(name: String, language: Language, leader: Person) case class Language(name: String, code: String, alphabet: Alphabet) //case class Person(name: String, dateOfBirth: Date) case class Date(year: Int, month: Month, day: Int) diff --git a/tests/src/main/scala/tests.scala b/tests/src/main/scala/tests.scala index f269643..3245e1b 100644 --- a/tests/src/main/scala/tests.scala +++ b/tests/src/main/scala/tests.scala @@ -25,8 +25,6 @@ object Tests extends TestApp { Show.generic[Person].show(Person("John Smith", 34)) }.assert(_ == "Person(name=John Smith,age=34)") - //Show.generic[Tree[String]] - test("serialize a Branch") { import magnolia.examples._ implicitly[Show[String, Branch[String]]].show(Branch(Leaf("LHS"), Leaf("RHS"))) -- cgit v1.2.3