From 092075d1eeca2dde829cba67bf3a96ef62738b15 Mon Sep 17 00:00:00 2001 From: Jon Pretty Date: Sun, 11 Jun 2017 10:25:55 +0200 Subject: More typesafety in the macro --- core/src/main/scala/magnolia.scala | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) (limited to 'core/src/main') diff --git a/core/src/main/scala/magnolia.scala b/core/src/main/scala/magnolia.scala index 815b613..e1aacde 100644 --- a/core/src/main/scala/magnolia.scala +++ b/core/src/main/scala/magnolia.scala @@ -18,20 +18,19 @@ abstract class MagnoliaMacro(val c: whitebox.Context) { protected def transformation(c: whitebox.Context): Transformation[c.type] - private def findType(key: c.universe.Type): Option[c.universe.TermName] = - recursionStack(c.enclosingPosition).find(_._1 == key).map(_._2.asInstanceOf[c.universe.TermName]) + private def findType(key: c.universe.Type): Option[c.TermName] = + recursionStack(c.enclosingPosition).find(_.genericType == key).map(_.termName(c)) - private def recurse[T](key: c.universe.Type, value: c.universe.TermName)(fn: => T): Option[T] = { + private def recurse[T](key: c.universe.Type, value: c.TermName)(fn: => T): Option[T] = { recursionStack = recursionStack.updated( c.enclosingPosition, - recursionStack.get(c.enclosingPosition).map { m => - m ::: List((key, value)) - }.getOrElse(List(key -> value)) + recursionStack.get(c.enclosingPosition).map(Frame(key, value) :: _).getOrElse( + List(Frame(key, value))) ) try Some(fn) catch { case e: Exception => None } finally { recursionStack = recursionStack.updated(c.enclosingPosition, - recursionStack(c.enclosingPosition).init) + recursionStack(c.enclosingPosition).tail) } } @@ -47,7 +46,7 @@ abstract class MagnoliaMacro(val c: whitebox.Context) { private def getImplicit(genericType: c.universe.Type, typeConstructor: c.universe.Type, - assignedName: c.universe.TermName): c.Tree = { + assignedName: c.TermName): c.Tree = { findType(genericType).map { methodName => val methodAsString = methodName.encodedName.toString @@ -183,9 +182,12 @@ private[magnolia] object Lazy { def apply[T](method: String): T = ??? } private[magnolia] object CompileTimeState { - private[magnolia] var recursionStack: Map[api.Position, List[ - (c.universe.Type, c.universe.TermName) forSome { val c: whitebox.Context } - ]] = Map() + case class Frame[C <: whitebox.Context](val genericType: C#Type, val term: C#TermName) { + def termName(c: whitebox.Context): c.TermName = term.asInstanceOf[c.TermName] + } + + private[magnolia] var recursionStack: Map[api.Position, List[Frame[_ <: whitebox.Context]]] = + Map() private[magnolia] var lastSearchType: Option[Universe#Type] = None } -- cgit v1.2.3