From a42cceae99ca8517ecff77fecdb23eba4d2c1036 Mon Sep 17 00:00:00 2001 From: Jon Pretty Date: Sun, 5 Nov 2017 19:57:30 +0000 Subject: Deduplication within case class parameter typeclasses --- core/src/main/scala/magnolia.scala | 119 +++++++++++++++++++++++-------------- 1 file changed, 75 insertions(+), 44 deletions(-) (limited to 'core/src/main/scala/magnolia.scala') diff --git a/core/src/main/scala/magnolia.scala b/core/src/main/scala/magnolia.scala index 2d2bbcd..50d3ab5 100644 --- a/core/src/main/scala/magnolia.scala +++ b/core/src/main/scala/magnolia.scala @@ -48,8 +48,9 @@ object JoinContext { } } -abstract class JoinContext[Tc[_], T](val typeName: String, val isObject: Boolean, val parameters: Array[Param[Tc, T]]) { +abstract class JoinContext[Tc[_], T](val typeName: String, val isObject: Boolean, params: Array[Param[Tc, T]]) { def construct(param: ((Param[Tc, T]) => Any)): T + def parameters: Seq[Param[Tc, T]] = params } object Magnolia { @@ -65,6 +66,8 @@ object Magnolia { def findType(key: Type): Option[TermName] = recursionStack(c.enclosingPosition).frames.find(_.genericType == key).map(_.termName(c)) + case class Typeclass(typ: c.Type, tree: c.Tree) + def recurse[T](path: TypePath, key: Type, value: TermName)(fn: => T): Option[T] = { recursionStack = recursionStack.updated( @@ -89,11 +92,11 @@ object Magnolia { } } - def typeclassTree(paramName: Option[String], - genericType: Type, - typeConstructor: Type, - assignedName: TermName): Tree = { + def typeclassTree(paramName: Option[String], genericType: Type, typeConstructor: Type, + assignedName: TermName): Tree = { + val searchType = appliedType(typeConstructor, genericType) + findType(genericType).map { methodName => val methodAsString = methodName.encodedName.toString q"_root_.magnolia.Deferred.apply[$searchType]($methodAsString)" @@ -105,7 +108,7 @@ object Magnolia { recurse(ChainedImplicit(genericType.toString), genericType, assignedName) { c.inferImplicitValue(searchType, false, false) }.get - }.toOption.orElse(directInferImplicit(genericType, typeConstructor)) + }.toOption.orElse(directInferImplicit(genericType, typeConstructor).map(_.tree)) } recursionStack = recursionStack.updated(c.enclosingPosition, newStack) inferredImplicit @@ -123,7 +126,7 @@ object Magnolia { } def directInferImplicit(genericType: c.Type, - typeConstructor: Type): Option[c.Tree] = { + typeConstructor: Type): Option[Typeclass] = { val genericTypeName: String = genericType.typeSymbol.name.encodedName.toString.toLowerCase val assignedName: TermName = TermName(c.freshName(s"${genericTypeName}Typeclass")) @@ -138,51 +141,72 @@ object Magnolia { // FIXME: Handle AnyVals val result = if(isCaseObject) { - val termSym = genericType.typeSymbol.companionSymbol - val obj = termSym.asTerm + val obj = genericType.typeSymbol.companion.asTerm val className = obj.name.toString val impl = q""" - ${c.prefix}.join(_root_.magnolia.JoinContext[$typeConstructor, $genericType]($className, true, _root_.scala.Array(), $obj)) + ${c.prefix}.join(_root_.magnolia.JoinContext[$typeConstructor, $genericType]($className, true, new _root_.scala.Array(0), $obj)) """ - Some(impl) + Some(Typeclass(genericType, impl)) } else if(isCaseClass) { val caseClassParameters = genericType.decls.collect { case m: MethodSymbol if m.isCaseAccessor => m.asMethod } val className = genericType.toString - val typeclasses: List[(c.universe.MethodSymbol, c.Tree, c.Type)] = caseClassParameters.map { param => + case class CaseParam(sym: c.universe.MethodSymbol, typeclass: c.Tree, paramType: c.Type, ref: c.TermName) + + val caseParams: List[CaseParam] = caseClassParameters.foldLeft(List[CaseParam]()) { case (acc, param) => val paramName = param.name.encodedName.toString val paramType = param.returnType.substituteTypes(genericType.etaExpand.typeParams, genericType.typeArgs) - val derivedImplicit = recurse(ProductType(paramName, genericType.toString), genericType, - assignedName) { - - typeclassTree(Some(paramName), paramType, typeConstructor, assignedName) - - }.getOrElse(c.abort(c.enclosingPosition, s"failed to get implicit for type $genericType")) - - (param, derivedImplicit, paramType) - }.to[List] + acc.find(_.paramType == paramType).map { backRef => + CaseParam(param, q"()", paramType, backRef.ref) :: acc + }.getOrElse { + val derivedImplicit = recurse(ProductType(paramName, genericType.toString), genericType, + assignedName) { + typeclassTree(Some(paramName), paramType, typeConstructor, assignedName) + }.getOrElse(c.abort(c.enclosingPosition, s"failed to get implicit for type $genericType")) + + val ref = TermName(c.freshName("paramTypeclass")) + val assigned = q"""val $ref = $derivedImplicit""" + CaseParam(param, assigned, paramType, ref) :: acc + } + }.to[List].reverse - val callables = typeclasses.map { case (param, typeclass, paramType) => - q"""_root_.magnolia.Param[$typeConstructor, $genericType, $paramType](${param.name.toString}, $typeclass, p => p.${TermName(param.name.toString)})""" + val paramsVal: TermName = TermName(c.freshName("parameters")) + val fnVal: TermName = TermName(c.freshName("fn")) + + val preAssignments = caseParams.map(_.typeclass) + + val assignments = caseParams.zipWithIndex.map { case (CaseParam(param, typeclass, paramType, ref), idx) => + q"""$paramsVal($idx) = _root_.magnolia.Param[$typeConstructor, $genericType, $paramType]( + ${param.name.toString}, $ref, _.${TermName(param.name.toString)} + )""" } - Some { - q""" - val parameters: _root_.scala.Array[Param[$typeConstructor, $genericType]] = _root_.scala.Array(..$callables) - ${c.prefix}.join(_root_.magnolia.JoinContext[$typeConstructor, $genericType]($className, false, parameters, - (fn: (Param[$typeConstructor, $genericType] => Any)) => new $genericType(..${typeclasses.zipWithIndex.map { case (typeclass, idx) => - q"fn(parameters($idx)).asInstanceOf[${typeclass._3}]" - } }) + Some(Typeclass(genericType, + q"""{ + ..$preAssignments + val $paramsVal: _root_.scala.Array[Param[$typeConstructor, $genericType]] = + new _root_.scala.Array(${assignments.length}) + ..$assignments + + ${c.prefix}.join(_root_.magnolia.JoinContext[$typeConstructor, $genericType]( + $className, + false, + $paramsVal, + ($fnVal: Param[$typeConstructor, $genericType] => Any) => + new $genericType(..${caseParams.zipWithIndex.map { case (typeclass, idx) => + q"$fnVal($paramsVal($idx)).asInstanceOf[${typeclass.paramType}]" + } }) )) - """ - } + }""" + )) } else if(isSealedTrait) { val genericSubtypes = classType.get.knownDirectSubclasses.to[List] val subtypes = genericSubtypes.map { sub => - val mapping = sub.asType.typeSignature.baseType(genericType.typeSymbol).typeArgs.zip(genericType.typeArgs).toMap + val typeArgs = sub.asType.typeSignature.baseType(genericType.typeSymbol).typeArgs + val mapping = typeArgs.zip(genericType.typeArgs).toMap val newTypeParams = sub.asType.toType.typeArgs.map(mapping(_)) appliedType(sub.asType.toType.typeConstructor, newTypeParams) } @@ -193,15 +217,17 @@ object Magnolia { c.abort(c.enclosingPosition, "") } + + val subclassesVal: TermName = TermName(c.freshName("subclasses")) - val subclasses = subtypes.map { searchType => + val assignments = subtypes.map { searchType => recurse(CoproductType(genericType.toString), genericType, assignedName) { (searchType, typeclassTree(None, searchType, typeConstructor, assignedName)) }.getOrElse { c.abort(c.enclosingPosition, s"failed to get implicit for type $searchType") } - }.map { case (typ, typeclass) => - q"""_root_.magnolia.Subclass[$typeConstructor, $genericType, $typ]( + }.zipWithIndex.map { case ((typ, typeclass), idx) => + q"""$subclassesVal($idx) = _root_.magnolia.Subclass[$typeConstructor, $genericType, $typ]( ${typ.typeSymbol.name.toString}, $typeclass, (t: $genericType) => t.isInstanceOf[$typ], @@ -210,17 +236,22 @@ object Magnolia { } Some { - q"""{ - ${c.prefix}.split(_root_.scala.collection.immutable.List[_root_.magnolia.Subclass[$typeConstructor, $genericType]](..$subclasses)) - }""" + Typeclass(genericType, q"""{ + val $subclassesVal: _root_.scala.Array[_root_.magnolia.Subclass[$typeConstructor, $genericType]] = + new _root_.scala.Array(${assignments.size}) + + ..$assignments + + ${c.prefix}.dispatch($subclassesVal: _root_.scala.Seq[_root_.magnolia.Subclass[$typeConstructor, $genericType]]) + }""") } } else None - result.map { r => - q"""{ + result.map { case Typeclass(t, r) => + Typeclass(t, q"""{ def $assignedName: $resultType = $r $assignedName - }""" + }""") } } @@ -248,13 +279,13 @@ object Magnolia { val result: Option[Tree] = if(!currentStack.frames.isEmpty) { findType(genericType) match { case None => - directInferImplicit(genericType, typeConstructor) + directInferImplicit(genericType, typeConstructor).map(_.tree) case Some(enclosingRef) => val methodAsString = enclosingRef.toString val searchType = appliedType(typeConstructor, genericType) Some(q"_root_.magnolia.Deferred[$searchType]($methodAsString)") } - } else directInferImplicit(genericType, typeConstructor) + } else directInferImplicit(genericType, typeConstructor).map(_.tree) if(currentStack.frames.isEmpty) recursionStack = ListMap() -- cgit v1.2.3