diff options
-rw-r--r-- | core/shared/src/main/scala/magnolia.scala | 80 | ||||
-rw-r--r-- | tests/src/main/scala/tests.scala | 13 |
2 files changed, 37 insertions, 56 deletions
diff --git a/core/shared/src/main/scala/magnolia.scala b/core/shared/src/main/scala/magnolia.scala index 9190c1d..1ec6649 100644 --- a/core/shared/src/main/scala/magnolia.scala +++ b/core/shared/src/main/scala/magnolia.scala @@ -14,6 +14,8 @@ */ package magnolia +import scala.annotation.compileTimeOnly +import scala.collection.breakOut import scala.collection.mutable import scala.language.existentials import scala.language.higherKinds @@ -95,42 +97,34 @@ object Magnolia { val prefixObject = prefixType.typeSymbol val prefixName = prefixObject.name.decodedName + def error(msg: String) = c.abort(c.enclosingPosition, msg) + val typeDefs = prefixType.baseClasses.flatMap { cls => cls.asType.toType.decls.filter(_.isType).find(_.name.toString == "Typeclass").map { tpe => tpe.asType.toType.asSeenFrom(prefixType, cls) } } - val typeConstructor = typeDefs.headOption.fold { - c.abort( - c.enclosingPosition, - s"magnolia: the derivation $prefixObject does not define the Typeclass type constructor" - ) - }(_.typeConstructor) + val typeConstructor = typeDefs.headOption.fold( + error(s"magnolia: the derivation $prefixObject does not define the Typeclass type constructor") + )(_.typeConstructor) def checkMethod(termName: String, category: String, expected: String): Unit = { val term = TermName(termName) val combineClass = c.prefix.tree.tpe.baseClasses - .find { cls => - cls.asType.toType.decl(term) != NoSymbol - } - .getOrElse { - c.abort( - c.enclosingPosition, - s"magnolia: the method `$termName` must be defined on the derivation $prefixObject to derive typeclasses for $category" - ) - } + .find(cls => cls.asType.toType.decl(term) != NoSymbol) + .getOrElse(error(s"magnolia: the method `$termName` must be defined on the derivation $prefixObject to derive typeclasses for $category")) + val firstParamBlock = combineClass.asType.toType.decl(term).asTerm.asMethod.paramLists.head if (firstParamBlock.lengthCompare(1) != 0) - c.abort(c.enclosingPosition, - s"magnolia: the method `combine` should take a single parameter of type $expected") + error(s"magnolia: the method `combine` should take a single parameter of type $expected") } // FIXME: Only run these methods if they're used, particularly `dispatch` checkMethod("combine", "case classes", "CaseClass[Typeclass, _]") checkMethod("dispatch", "sealed traits", "SealedTrait[Typeclass, _]") - val removeDeferred = new Transformer { + val expandDeferred = new Transformer { override def transform(tree: Tree) = tree match { case q"$magnolia.Deferred.apply[$_](${Literal(Constant(method: String))})" if magnolia.symbol == magnoliaPkg => @@ -155,12 +149,11 @@ object Magnolia { .filterNot(_.isEmpty) .orElse(directInferImplicit(genericType, typeConstructor)) .getOrElse { - val missingType = stack.top.fold(searchType)(_.searchType.asInstanceOf[Type]) + val missingType = stack.top.fold(searchType)(_.searchType) val typeClassName = s"${missingType.typeSymbol.name.decodedName}.Typeclass" val genericType = missingType.typeArgs.head val trace = stack.trace.mkString(" in ", "\n in ", "\n") - c.abort(c.enclosingPosition, - s"magnolia: could not find $typeClassName for type $genericType\n$trace") + error(s"magnolia: could not find $typeClassName for type $genericType\n$trace") } } } @@ -295,7 +288,7 @@ object Magnolia { } val assignments = caseParams.zip(defaults).zip(annotations).zipWithIndex.map { - case (((CaseParam(param, repeated, typeclass, paramType, ref), defaultVal), annList), idx) => + case (((CaseParam(param, repeated, _, paramType, ref), defaultVal), annList), idx) => q"""$paramsVal($idx) = $magnoliaPkg.Magnolia.param[$typeConstructor, $genericType, $paramType]( ${param.name.decodedName.toString}, @@ -349,7 +342,7 @@ object Magnolia { s"magnolia: could not find any direct subtypes of $typeSymbol", force = true) - c.abort(c.enclosingPosition, "") + error("") } val subtypesVal: TermName = TermName(c.freshName("subtypes")) @@ -401,12 +394,8 @@ object Magnolia { val result = stack .find(searchType) - .map { enclosingRef => - q"$magnoliaPkg.Deferred[$searchType](${enclosingRef.toString})" - } - .orElse { - directInferImplicit(genericType, typeConstructor) - } + .map(enclosingRef => q"$magnoliaPkg.Deferred[$searchType](${enclosingRef.toString})") + .orElse(directInferImplicit(genericType, typeConstructor)) for (tree <- result) if (debug.isDefined && genericType.toString.contains(debug.get)) { c.echo(c.enclosingPosition, s"Magnolia macro expansion for $genericType") @@ -415,11 +404,10 @@ object Magnolia { val dereferencedResult = if (stack.nonEmpty) result - else for (tree <- result) yield c.untypecheck(removeDeferred.transform(tree)) + else for (tree <- result) yield c.untypecheck(expandDeferred.transform(tree)) dereferencedResult.getOrElse { - c.abort(c.enclosingPosition, - s"magnolia: could not infer $prefixName.Typeclass for type $genericType") + error(s"magnolia: could not infer $prefixName.Typeclass for type $genericType") } } @@ -479,7 +467,8 @@ object Magnolia { private[magnolia] final case class DirectlyReentrantException() extends Exception("attempt to recurse directly") -private[magnolia] object Deferred { def apply[T](method: String): T = ??? } +@compileTimeOnly("magnolia.Deferred is used for derivation of recursive typeclasses") +object Deferred { def apply[T](method: String): T = ??? } private[magnolia] object CompileTimeState { @@ -492,7 +481,7 @@ private[magnolia] object CompileTimeState { final case class ChainedImplicit(typeClassName: String, typeName: String) extends TypePath(s"chained implicit $typeClassName for type $typeName") - final class Stack[C <: whitebox.Context] { + final class Stack[C <: whitebox.Context with Singleton] { private var frames = List.empty[Frame] private val cache = mutable.Map.empty[C#Type, C#Tree] @@ -511,25 +500,18 @@ private[magnolia] object CompileTimeState { case Frame(_, tpe, term) if tpe =:= searchType => term } - def recurse[T <: C#Tree](frame: Frame, searchType: C#Type)(fn: => T): T = { + def recurse[T <: C#Tree](frame: Frame, searchType: C#Type)(fn: => C#Tree): C#Tree = { push(frame) val result = cache.getOrElseUpdate(searchType, fn) pop() - result.asInstanceOf[T] + result } def trace: List[TypePath] = - frames - .drop(1) - .foldLeft[(C#Type, List[TypePath])]((null, Nil)) { - case ((_, Nil), frame) => - (frame.searchType, frame.path :: Nil) - case (continue @ (tpe, acc), frame) => - if (tpe =:= frame.searchType) continue - else (frame.searchType, frame.path :: acc) - } - ._2 - .reverse + (frames.drop(1), frames).zipped.collect { + case (Frame(path, tp1, _), Frame(_, tp2, _)) + if !(tp1 =:= tp2) => path + } (breakOut) override def toString: String = frames.mkString("magnolia stack:\n", "\n", "\n") @@ -538,7 +520,9 @@ private[magnolia] object CompileTimeState { } object Stack { - private val global = new Stack[whitebox.Context] + // Cheating to satisfy Singleton bound (which improves type inference). + private val dummyContext: whitebox.Context = null + private val global = new Stack[dummyContext.type] private val workSet = mutable.Set.empty[whitebox.Context#Symbol] def withContext(c: whitebox.Context)(fn: Stack[c.type] => c.Tree): c.Tree = { diff --git a/tests/src/main/scala/tests.scala b/tests/src/main/scala/tests.scala index fc3a012..c3b90cc 100644 --- a/tests/src/main/scala/tests.scala +++ b/tests/src/main/scala/tests.scala @@ -43,10 +43,10 @@ class Length(val value: Int) extends AnyVal case class FruitBasket(fruits: Fruit*) case class Lunchbox(fruit: Fruit, drink: String) +case class Fruit(name: String) object Fruit { implicit val showFruit: Show[String, Fruit] = (f: Fruit) => f.name } -case class Fruit(name: String) case class Item(name: String, quantity: Int = 1, price: Int) @@ -116,19 +116,16 @@ object Tests extends TestApp { }.assert(_ == "Address(line1=Home,occupant=nobody)") test("even low-priority implicit beats Magnolia for nested case") { - import Show.gen implicitly[Show[String, Lunchbox]].show(Lunchbox(Fruit("apple"), "lemonade")) }.assert(_ == "Lunchbox(fruit=apple,drink=lemonade)") - test("low-priority implicit does not beat Magnolia when not nested") { - import Show.gen + test("low-priority implicit beats Magnolia when not nested") { implicitly[Show[String, Fruit]].show(Fruit("apple")) - }.assert(_ == "Fruit(name=apple)") + }.assert(_ == "apple") - test("low-priority implicit does not beat Magnolia when chained") { - import Show.gen + test("low-priority implicit beats Magnolia when chained") { implicitly[Show[String, FruitBasket]].show(FruitBasket(Fruit("apple"), Fruit("banana"))) - }.assert(_ == "FruitBasket(fruits=[Fruit(name=apple),Fruit(name=banana)])") + }.assert(_ == "FruitBasket(fruits=[apple,banana])") test("typeclass implicit scope has lower priority than ADT implicit scope") { implicitly[Show[String, Fruit]].show(Fruit("apple")) |