From b2c4b2c345a8ac82bbf3240cc41b248b07924705 Mon Sep 17 00:00:00 2001 From: Georgi Krastev Date: Sat, 14 Apr 2018 23:23:13 +0200 Subject: Add Singleton bound on Stack context This improves type inference for the (path-dependent) `Tree`, `Type`, etc. used with a concrete `c`. --- core/shared/src/main/scala/magnolia.scala | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/core/shared/src/main/scala/magnolia.scala b/core/shared/src/main/scala/magnolia.scala index b61b7d7..a9e231a 100644 --- a/core/shared/src/main/scala/magnolia.scala +++ b/core/shared/src/main/scala/magnolia.scala @@ -155,7 +155,7 @@ object Magnolia { .filterNot(_.isEmpty) .orElse(directInferImplicit(genericType, typeConstructor)) .getOrElse { - val missingType = stack.top.fold(searchType)(_.searchType.asInstanceOf[Type]) + val missingType = stack.top.fold(searchType)(_.searchType) val typeClassName = s"${missingType.typeSymbol.name.decodedName}.Typeclass" val genericType = missingType.typeArgs.head val trace = stack.trace.mkString(" in ", "\n in ", "\n") @@ -295,7 +295,7 @@ object Magnolia { } val assignments = caseParams.zip(defaults).zip(annotations).zipWithIndex.map { - case (((CaseParam(param, repeated, typeclass, paramType, ref), defaultVal), annList), idx) => + case (((CaseParam(param, repeated, _, paramType, ref), defaultVal), annList), idx) => q"""$paramsVal($idx) = $magnoliaPkg.Magnolia.param[$typeConstructor, $genericType, $paramType]( ${param.name.decodedName.toString}, @@ -492,7 +492,7 @@ private[magnolia] object CompileTimeState { final case class ChainedImplicit(typeClassName: String, typeName: String) extends TypePath(s"chained implicit $typeClassName for type $typeName") - final class Stack[C <: whitebox.Context] { + final class Stack[C <: whitebox.Context with Singleton] { private var frames = List.empty[Frame] private val cache = mutable.Map.empty[C#Type, C#Tree] @@ -511,11 +511,11 @@ private[magnolia] object CompileTimeState { case Frame(_, tpe, term) if tpe =:= searchType => term } - def recurse[T <: C#Tree](frame: Frame, searchType: C#Type)(fn: => T): T = { + def recurse[T <: C#Tree](frame: Frame, searchType: C#Type)(fn: => C#Tree): C#Tree = { push(frame) val result = cache.getOrElseUpdate(searchType, fn) pop() - result.asInstanceOf[T] + result } def trace: List[TypePath] = @@ -538,7 +538,9 @@ private[magnolia] object CompileTimeState { } object Stack { - private val global = new Stack[whitebox.Context] + // Cheating to satisfy Singleton bound (which improves type inference). + private val dummyContext: whitebox.Context = null + private val global = new Stack[dummyContext.type] private val workSet = mutable.Set.empty[whitebox.Context#Symbol] def withContext(c: whitebox.Context)(fn: Stack[c.type] => c.Tree): c.Tree = { -- cgit v1.2.3