diff options
Diffstat (limited to 'core/src/main/scala/magnolia.scala')
-rw-r--r-- | core/src/main/scala/magnolia.scala | 56 |
1 files changed, 34 insertions, 22 deletions
diff --git a/core/src/main/scala/magnolia.scala b/core/src/main/scala/magnolia.scala index cabe88e..4d3d3d3 100644 --- a/core/src/main/scala/magnolia.scala +++ b/core/src/main/scala/magnolia.scala @@ -45,7 +45,7 @@ object Magnolia { recursionStack = recursionStack.updated( c.enclosingPosition, recursionStack.get(c.enclosingPosition).map(_.push(path, key, value)).getOrElse( - Stack(List(Frame(path, key, value)), Nil)) + Stack(Map(), List(Frame(path, key, value)), Nil)) ) try Some(fn) catch { case e: Exception => None } finally { @@ -73,17 +73,21 @@ object Magnolia { val methodAsString = methodName.encodedName.toString q"_root_.magnolia.Deferred.apply[$searchType]($methodAsString)" }.orElse { - scala.util.Try { - val genericTypeName: String = genericType.typeSymbol.name.encodedName.toString.toLowerCase - val assignedName: TermName = TermName(c.freshName(s"${genericTypeName}Typeclass")) - recurse(ChainedImplicit(genericType.toString), genericType, assignedName) { - val inferredImplicit = c.inferImplicitValue(searchType, false, false) - q"""{ - def $assignedName: $searchType = $inferredImplicit - $assignedName - }""" - }.get - }.toOption.orElse(directInferImplicit(genericType, typeConstructor)) + val (inferredImplicit, newStack) = recursionStack(c.enclosingPosition).lookup(c)(searchType) { + scala.util.Try { + val genericTypeName: String = genericType.typeSymbol.name.encodedName.toString.toLowerCase + val assignedName: TermName = TermName(c.freshName(s"${genericTypeName}Typeclass")) + recurse(ChainedImplicit(genericType.toString), genericType, assignedName) { + val inferredImplicit = c.inferImplicitValue(searchType, false, false) + q"""{ + def $assignedName: $searchType = $inferredImplicit + $assignedName + }""" + }.get + }.toOption.orElse(directInferImplicit(genericType, typeConstructor)) + } + recursionStack = recursionStack.updated(c.enclosingPosition, newStack) + inferredImplicit }.getOrElse { val currentStack: Stack = recursionStack(c.enclosingPosition) @@ -111,8 +115,6 @@ object Magnolia { val resultType = appliedType(typeConstructor, genericType) - println(s"Deriving $genericType") - // FIXME: Handle AnyVals if(isCaseObject) { val termSym = genericType.typeSymbol.companionSymbol @@ -162,6 +164,7 @@ object Magnolia { def label: _root_.java.lang.String = $label def dereference(param: ${genericType}): ${paramType} = param.${TermName(label)} }""" + } val constructor = q"""new $genericType(..${callables.zip(implicits).map { case (call, imp) => @@ -236,7 +239,7 @@ object Magnolia { val genericType: Type = weakTypeOf[T] val currentStack: Stack = - recursionStack.get(c.enclosingPosition).getOrElse(Stack(List(), List())) + recursionStack.get(c.enclosingPosition).getOrElse(Stack(Map(), List(), List())) val directlyReentrant = Some(genericType) == currentStack.frames.headOption.map(_.genericType) @@ -268,10 +271,11 @@ object Magnolia { if(currentStack.frames.isEmpty) recursionStack = ListMap() result.map { tree => - val out = if(currentStack.frames.isEmpty) c.untypecheck(removeDeferred.transform(tree)) - else tree - println(out) - out + if(currentStack.frames.isEmpty) { + val out = c.untypecheck(removeDeferred.transform(tree)) + //println(out) + out + } else tree }.getOrElse { c.abort(c.enclosingPosition, s"magnolia: could not infer typeclass for type $genericType") } @@ -297,13 +301,21 @@ private[magnolia] object CompileTimeState { case class ImplicitNotFound(genericType: String, path: List[TypePath]) - case class Stack(frames: List[Frame], errors: List[ImplicitNotFound]) { + case class Stack(cache: Map[whitebox.Context#Type, Option[whitebox.Context#Tree]], frames: List[Frame], errors: List[ImplicitNotFound]) { + def lookup(c: whitebox.Context)(t: c.Type)(orElse: => Option[c.Tree]): (Option[c.Tree], Stack) = + if(cache.contains(t)) { + (cache(t).asInstanceOf[Option[c.Tree]], this) + } else { + val value = orElse + (value, copy(cache.updated(t, value))) + } + def push(path: TypePath, key: whitebox.Context#Type, value: whitebox.Context#TermName): Stack = - Stack(Frame(path, key, value) :: frames, errors) + Stack(cache, Frame(path, key, value) :: frames, errors) - def pop(): Stack = Stack(frames.tail, errors) + def pop(): Stack = Stack(cache, frames.tail, errors) } case class Frame(path: TypePath, genericType: whitebox.Context#Type, |