From e3fddd33620eb5d124d271b7f27859295ef2d267 Mon Sep 17 00:00:00 2001 From: Georgi Krastev Date: Fri, 5 Jan 2018 14:48:53 +0100 Subject: Rework compile time stack * Use a classic mutable stack (a case class without lenses is cumbersome) * Add typeclass constructor to stack frames, cache and error messages * Clean-up usage of `Option`s --- core/shared/src/main/scala/magnolia.scala | 322 ++++++++++++------------------ 1 file changed, 129 insertions(+), 193 deletions(-) (limited to 'core/shared/src/main/scala') diff --git a/core/shared/src/main/scala/magnolia.scala b/core/shared/src/main/scala/magnolia.scala index b5c196a..3135b57 100644 --- a/core/shared/src/main/scala/magnolia.scala +++ b/core/shared/src/main/scala/magnolia.scala @@ -1,11 +1,9 @@ package magnolia -import scala.util.control.NonFatal -import scala.reflect._ -import macros._ -import scala.collection.immutable.ListMap -import language.existentials -import language.higherKinds +import scala.collection.mutable +import scala.language.existentials +import scala.language.higherKinds +import scala.reflect.macros._ /** the object which defines the Magnolia macro */ object Magnolia { @@ -65,7 +63,7 @@ object Magnolia { * will suffice, however the qualifications regarding additional type parameters and implicit * parameters apply equally to `dispatch` as to `combine`. * */ - def gen[T: c.WeakTypeTag](c: whitebox.Context): c.Tree = { + def gen[T: c.WeakTypeTag](c: whitebox.Context): c.Tree = Stack.withContext(c) { stack => import c.universe._ import internal._ @@ -73,13 +71,15 @@ object Magnolia { .find(_.tree.tpe <:< typeOf[debug]) .flatMap(_.tree.children.tail.collectFirst { case Literal(Constant(s: String)) => s }) - val magnoliaPkg = q"_root_.magnolia" - val scalaPkg = q"_root_.scala" + val magnoliaPkg = c.mirror.staticPackage("magnolia") + val scalaPkg = c.mirror.staticPackage("scala") val repeatedParamClass = definitions.RepeatedParamClass val scalaSeqType = typeOf[Seq[_]].typeConstructor val prefixType = c.prefix.tree.tpe + val prefixObject = prefixType.typeSymbol + val prefixName = prefixObject.name.decodedName val typeDefs = prefixType.baseClasses.flatMap { cls => cls.asType.toType.decls.filter(_.isType).find(_.name.toString == "Typeclass").map { tpe => @@ -87,13 +87,10 @@ object Magnolia { } } - val typeConstructorOpt = - typeDefs.headOption.map(_.typeConstructor) - - val typeConstructor = typeConstructorOpt.getOrElse { + val typeConstructor = typeDefs.headOption.fold { c.abort(c.enclosingPosition, - "magnolia: the derivation object does not define the Typeclass type constructor") - } + s"magnolia: the derivation $prefixObject does not define the Typeclass type constructor") + } (_.typeConstructor) def checkMethod(termName: String, category: String, expected: String): Unit = { val term = TermName(termName) @@ -104,7 +101,7 @@ object Magnolia { .getOrElse { c.abort( c.enclosingPosition, - s"magnolia: the method `$termName` must be defined on the derivation object to derive typeclasses for $category" + s"magnolia: the method `$termName` must be defined on the derivation $prefixObject to derive typeclasses for $category" ) } val firstParamBlock = combineClass.asType.toType.decl(term).asTerm.asMethod.paramLists.head @@ -117,91 +114,43 @@ object Magnolia { checkMethod("combine", "case classes", "CaseClass[Typeclass, _]") checkMethod("dispatch", "sealed traits", "SealedTrait[Typeclass, _]") - def findType(key: Type): Option[TermName] = - recursionStack(c.enclosingPosition).frames.find(_.genericType == key).map(_.termName(c)) - - final case class Typeclass(typ: c.Type, tree: c.Tree) - - def recurse[A](path: TypePath, key: Type, value: TermName)(fn: => A): Option[A] = { - val oldRecursionStack = recursionStack.get(c.enclosingPosition) - recursionStack = recursionStack.updated( - c.enclosingPosition, - oldRecursionStack.map(_.push(path, key, value)).getOrElse { - Stack(Map(), List(Frame(path, key, value)), Nil) - } - ) - - try Some(fn) - catch { case NonFatal(_) => None } finally { - val currentStack = recursionStack(c.enclosingPosition) - recursionStack = recursionStack.updated(c.enclosingPosition, currentStack.pop()) - } - } - - val removeDeferred: Transformer = new Transformer { - override def transform(tree: Tree): Tree = tree match { - case q"$magnoliaPkg.Deferred.apply[$returnType](${Literal(Constant(method: String))})" => - q"${TermName(method)}" + val removeDeferred = new Transformer { + override def transform(tree: Tree) = tree match { + case q"$magnolia.Deferred.apply[$_](${Literal(Constant(method: String))})" + if magnolia.symbol == magnoliaPkg => q"${TermName(method)}" case _ => super.transform(tree) } } - def typeclassTree(paramName: Option[String], - genericType: Type, - typeConstructor: Type, - assignedName: TermName): Tree = { - + def typeclassTree(genericType: Type, typeConstructor: Type): Tree = { val searchType = appliedType(typeConstructor, genericType) - - val deferredRef = findType(genericType).map { methodName => + val deferredRef = for (methodName <- stack find searchType) yield { val methodAsString = methodName.decodedName.toString q"$magnoliaPkg.Deferred.apply[$searchType]($methodAsString)" } - val foundImplicit = deferredRef.orElse { - val (inferredImplicit, newStack) = - recursionStack(c.enclosingPosition).lookup(c)(searchType) { - val implicitSearchTry = scala.util.Try { - val genericTypeName: String = - genericType.typeSymbol.name.decodedName.toString.toLowerCase - - val assignedName: TermName = TermName(c.freshName(s"${genericTypeName}Typeclass")) - - recurse(ChainedImplicit(genericType.toString), genericType, assignedName) { - c.inferImplicitValue(searchType, false, false) - }.get + deferredRef.getOrElse { + val path = ChainedImplicit(s"$prefixName.Typeclass", genericType.toString) + val frame = stack.Frame(path, searchType, termNames.EMPTY) + stack.recurse(frame, searchType) { + Option(c.inferImplicitValue(searchType)).filterNot(_.isEmpty) + .orElse(directInferImplicit(genericType, typeConstructor)) + .getOrElse { + val missingType = stack.top.fold(searchType)(_.searchType.asInstanceOf[Type]) + val typeClassName = s"${missingType.typeSymbol.name.decodedName}.Typeclass" + val genericType = missingType.typeArgs.head + val trace = stack.trace.mkString(" in ", "\n in ", "\n") + c.abort(c.enclosingPosition, + s"magnolia: could not find $typeClassName for type $genericType\n$trace") } - - implicitSearchTry.toOption.orElse( - directInferImplicit(genericType, typeConstructor).map(_.tree) - ) - } - recursionStack = recursionStack.updated(c.enclosingPosition, newStack) - inferredImplicit - } - - foundImplicit.getOrElse { - val currentStack: Stack = recursionStack(c.enclosingPosition) - - val error = ImplicitNotFound(genericType.toString, - recursionStack(c.enclosingPosition).frames.map(_.path)) - - val updatedStack = currentStack.copy(errors = error :: currentStack.errors) - recursionStack = recursionStack.updated(c.enclosingPosition, updatedStack) - - val stackPaths = recursionStack(c.enclosingPosition).frames.map(_.path) - val stack = stackPaths.mkString(" in ", "\n in ", "\n") - - c.abort(c.enclosingPosition, - s"magnolia: could not find typeclass for type $genericType\n$stack") + } } } - def directInferImplicit(genericType: c.Type, typeConstructor: Type): Option[Typeclass] = { - - val genericTypeName: String = genericType.typeSymbol.name.decodedName.toString.toLowerCase - val assignedName: TermName = TermName(c.freshName(s"${genericTypeName}Typeclass")) + def directInferImplicit(genericType: Type, typeConstructor: Type): Option[Tree] = { + val genericTypeName = genericType.typeSymbol.name.decodedName.toString.toLowerCase + val assignedName = TermName(c.freshName(s"${genericTypeName}Typeclass")) val typeSymbol = genericType.typeSymbol val classType = if (typeSymbol.isClass) Some(typeSymbol.asClass) else None val isCaseClass = classType.exists(_.isCaseClass) @@ -235,18 +184,18 @@ object Magnolia { $typeName, true, false, new $scalaPkg.Array(0), _ => ${genericType.typeSymbol.asClass.module}) ) """ - Some(Typeclass(genericType, impl)) + Some(impl) } else if (isCaseClass || isValueClass) { val caseClassParameters = genericType.decls.collect { case m: MethodSymbol if m.isCaseAccessor || (isValueClass && m.isParamAccessor) => m.asMethod } - final case class CaseParam(sym: c.universe.MethodSymbol, + case class CaseParam(sym: MethodSymbol, repeated: Boolean, - typeclass: c.Tree, - paramType: c.Type, - ref: c.TermName) + typeclass: Tree, + paramType: Type, + ref: TermName) val caseParamsReversed = caseClassParameters.foldLeft[List[CaseParam]](Nil) { (acc, param) => @@ -260,23 +209,18 @@ object Magnolia { false -> tpe } - val predefinedRef = acc.find(_.paramType == paramType) - - val caseParamOpt = predefinedRef.map { backRef => - CaseParam(param, repeated, q"()", paramType, backRef.ref) :: acc - } - - caseParamOpt.getOrElse { - val derivedImplicit = - recurse(ProductType(paramName, genericType.toString), genericType, assignedName) { - typeclassTree(Some(paramName), paramType, typeConstructor, assignedName) - }.getOrElse( - c.abort(c.enclosingPosition, s"failed to get implicit for type $genericType") - ) + acc.find(_.paramType =:= paramType).fold { + val path = ProductType(paramName, genericType.toString) + val frame = stack.Frame(path, resultType, assignedName) + val derivedImplicit = stack.recurse(frame, appliedType(typeConstructor, paramType)) { + typeclassTree(paramType, typeConstructor) + } val ref = TermName(c.freshName("paramTypeclass")) - val assigned = q"""val $ref = $derivedImplicit""" + val assigned = q"""lazy val $ref = $derivedImplicit""" CaseParam(param, repeated, assigned, paramType, ref) :: acc + } { backRef => + CaseParam(param, repeated, q"()", paramType, backRef.ref) :: acc } } @@ -288,7 +232,7 @@ object Magnolia { val preAssignments = caseParams.map(_.typeclass) val defaults = if (!isValueClass) { - val companionRef = GlobalUtil.patchedCompanionRef(c)(genericType) + val companionRef = GlobalUtil.patchedCompanionRef(c)(genericType.dealias) val companionSym = companionRef.symbol.asModule.info // If a companion object is defined with alternative apply methods @@ -318,10 +262,7 @@ object Magnolia { )""" } - Some( - Typeclass( - genericType, - q"""{ + Some(q"""{ ..$preAssignments val $paramsVal: $scalaPkg.Array[$magnoliaPkg.Param[$typeConstructor, $genericType]] = new $scalaPkg.Array(${assignments.length}) @@ -346,7 +287,6 @@ object Magnolia { } })})) }""") - ) } else if (isSealedTrait) { val genericSubtypes = classType.get.knownDirectSubclasses.to[List] val subtypes = genericSubtypes.map { sub => @@ -362,18 +302,18 @@ object Magnolia { if (subtypes.isEmpty) { c.info(c.enclosingPosition, s"magnolia: could not find any direct subtypes of $typeSymbol", - true) + force = true) c.abort(c.enclosingPosition, "") } val subtypesVal: TermName = TermName(c.freshName("subtypes")) - val typeclasses = subtypes.map { searchType => - recurse(CoproductType(genericType.toString), genericType, assignedName) { - (searchType, typeclassTree(None, searchType, typeConstructor, assignedName)) - }.getOrElse { - c.abort(c.enclosingPosition, s"failed to get implicit for type $searchType") + val typeclasses = for (subType <- subtypes) yield { + val path = CoproductType(genericType.toString) + val frame = stack.Frame(path, resultType, assignedName) + subType -> stack.recurse(frame, appliedType(typeConstructor, subType)) { + typeclassTree(subType, typeConstructor) } } @@ -386,11 +326,8 @@ object Magnolia { (t: $genericType) => t.asInstanceOf[$typ] )""" } - - Some { - Typeclass( - genericType, - q"""{ + + Some(q"""{ val $subtypesVal: $scalaPkg.Array[$magnoliaPkg.Subtype[$typeConstructor, $genericType]] = new $scalaPkg.Array(${assignments.size}) @@ -402,64 +339,37 @@ object Magnolia { $typeName, $subtypesVal: $scalaPkg.Array[$magnoliaPkg.Subtype[$typeConstructor, $genericType]]) ): $resultType - }""" - ) - } + }""") } else None - result.map { - case Typeclass(t, r) => - Typeclass(t, q"""{ - def $assignedName: $resultType = $r - $assignedName - }""") - } + for (term <- result) yield q"""{ + lazy val $assignedName: $resultType = $term + $assignedName + }""" } val genericType: Type = weakTypeOf[T] - - val currentStack: Stack = - recursionStack.getOrElse(c.enclosingPosition, Stack(Map(), List(), List())) - - val directlyReentrant = currentStack.frames.headOption.exists(_.genericType == genericType) - + val searchType = appliedType(typeConstructor, genericType) + val directlyReentrant = stack.top.exists(_.searchType =:= searchType) if (directlyReentrant) throw DirectlyReentrantException() - currentStack.errors.foreach { error => - if (!emittedErrors.contains(error)) { - emittedErrors += error - val trace = error.path.mkString("\n in ", "\n in ", "\n \n") - - val msg = s"magnolia: could not derive $typeConstructor instance for type " + - s"${error.genericType}" - - c.info(c.enclosingPosition, msg + trace, true) - } + val result = stack.find(searchType).map { enclosingRef => + q"$magnoliaPkg.Deferred[$searchType](${enclosingRef.toString})" + }.orElse { + directInferImplicit(genericType, typeConstructor) } - val result: Option[Tree] = if (currentStack.frames.nonEmpty) { - findType(genericType) match { - case None => - directInferImplicit(genericType, typeConstructor).map(_.tree) - case Some(enclosingRef) => - val methodAsString = enclosingRef.toString - val searchType = appliedType(typeConstructor, genericType) - Some(q"$magnoliaPkg.Deferred[$searchType]($methodAsString)") - } - } else directInferImplicit(genericType, typeConstructor).map(_.tree) - - if (currentStack.frames.isEmpty) recursionStack = ListMap() - - val dereferencedResult = result.map { tree => - if (debug.isDefined && genericType.toString.contains(debug.get)) { - c.echo(c.enclosingPosition, s"Magnolia macro expansion for $genericType") - c.echo(NoPosition, s"... = ${showCode(tree)}\n\n") - } - if (currentStack.frames.isEmpty) c.untypecheck(removeDeferred.transform(tree)) else tree + for (tree <- result) if (debug.isDefined && genericType.toString.contains(debug.get)) { + c.echo(c.enclosingPosition, s"Magnolia macro expansion for $genericType") + c.echo(NoPosition, s"... = ${showCode(tree)}\n\n") } + val dereferencedResult = if (stack.nonEmpty) result + else for (tree <- result) yield c.untypecheck(removeDeferred.transform(tree)) + dereferencedResult.getOrElse { - c.abort(c.enclosingPosition, s"magnolia: could not infer typeclass for type $genericType") + c.abort(c.enclosingPosition, + s"magnolia: could not infer $prefixName.Typeclass for type $genericType") } } @@ -483,7 +393,7 @@ object Magnolia { * should not be called directly from users' code. */ def param[Tc[_], T, P](name: String, isRepeated: Boolean, - typeclassParam: Tc[P], + typeclassParam: => Tc[P], defaultVal: => Option[P], deref: T => P): Param[Tc, T] = new Param[Tc, T] { type PType = P @@ -521,35 +431,61 @@ private[magnolia] object CompileTimeState { final case class ProductType(paramName: String, typeName: String) extends TypePath(s"parameter '$paramName' of product type $typeName") - final case class ChainedImplicit(typeName: String) - extends TypePath(s"chained implicit of type $typeName") + final case class ChainedImplicit(typeClassName: String, typeName: String) + extends TypePath(s"chained implicit $typeClassName for type $typeName") - final case class ImplicitNotFound(genericType: String, path: List[TypePath]) + final class Stack[C <: whitebox.Context] { + private var frames = List.empty[Frame] + private val cache = mutable.Map.empty[C#Type, C#Tree] - final case class Stack(cache: Map[whitebox.Context#Type, Option[whitebox.Context#Tree]], - frames: List[Frame], - errors: List[ImplicitNotFound]) { + def isEmpty: Boolean = frames.isEmpty + def nonEmpty: Boolean = frames.nonEmpty + def top: Option[Frame] = frames.headOption + def pop(): Unit = frames = frames drop 1 + def push(frame: Frame): Unit = frames ::= frame - def lookup(c: whitebox.Context)(t: c.Type)(orElse: => Option[c.Tree]): (Option[c.Tree], Stack) = - if (cache.contains(t)) { - (cache(t).asInstanceOf[Option[c.Tree]], this) - } else { - val value = orElse - (value, copy(cache.updated(t, value))) - } + def clear(): Unit = { + frames = Nil + cache.clear() + } + + def find(searchType: C#Type): Option[C#TermName] = frames.collectFirst { + case Frame(_, tpe, term) if tpe =:= searchType => term + } + + def recurse[T <: C#Tree](frame: Frame, searchType: C#Type)(fn: => T): T = { + push(frame) + val result = cache.getOrElseUpdate(searchType, fn) + pop() + result.asInstanceOf[T] + } - def push(path: TypePath, key: whitebox.Context#Type, value: whitebox.Context#TermName): Stack = - Stack(cache, Frame(path, key, value) :: frames, errors) + def trace: List[TypePath] = frames.drop(1).foldLeft[(C#Type, List[TypePath])]((null, Nil)) { + case ((_, Nil), frame) => + (frame.searchType, frame.path :: Nil) + case (continue @ (tpe, acc), frame) => + if (tpe =:= frame.searchType) continue + else (frame.searchType, frame.path :: acc) + }._2.reverse - def pop(): Stack = Stack(cache, frames.tail, errors) - } + override def toString: String = + frames.mkString("magnolia stack:\n", "\n", "\n") - final case class Frame(path: TypePath, - genericType: whitebox.Context#Type, - term: whitebox.Context#TermName) { - def termName(c: whitebox.Context): c.TermName = term.asInstanceOf[c.TermName] + final case class Frame(path: TypePath, searchType: C#Type, term: C#TermName) } - var recursionStack: ListMap[api.Position, Stack] = ListMap() - var emittedErrors: Set[ImplicitNotFound] = Set() + object Stack { + private val global = new Stack[whitebox.Context] + private val workSet = mutable.Set.empty[whitebox.Context#Symbol] + + def withContext(c: whitebox.Context)(fn: Stack[c.type] => c.Tree): c.Tree = { + workSet += c.macroApplication.symbol + val depth = c.enclosingMacros.count(m => workSet(m.macroApplication.symbol)) + try fn(global.asInstanceOf[Stack[c.type]]) + finally if (depth <= 1) { + global.clear() + workSet.clear() + } + } + } } -- cgit v1.2.3