diff options
Diffstat (limited to 'src/main/scala/scala/async/internal/TransformUtils.scala')
-rw-r--r-- | src/main/scala/scala/async/internal/TransformUtils.scala | 200 |
1 files changed, 142 insertions, 58 deletions
diff --git a/src/main/scala/scala/async/internal/TransformUtils.scala b/src/main/scala/scala/async/internal/TransformUtils.scala index 0b8cd00..bd7093f 100644 --- a/src/main/scala/scala/async/internal/TransformUtils.scala +++ b/src/main/scala/scala/async/internal/TransformUtils.scala @@ -5,7 +5,7 @@ package scala.async.internal import scala.reflect.macros.Context import reflect.ClassTag -import scala.reflect.macros.runtime.AbortMacroException +import scala.collection.immutable.ListMap /** * Utilities used in both `ExprBuilder` and `AnfTransform`. @@ -13,7 +13,9 @@ import scala.reflect.macros.runtime.AbortMacroException private[async] trait TransformUtils { self: AsyncMacro => - import global._ + import c.universe._ + import c.internal._ + import decorators._ object name { val resume = newTermName("resume") @@ -31,14 +33,82 @@ private[async] trait TransformUtils { val tr = newTermName("tr") val t = newTermName("throwable") - def fresh(name: TermName): TermName = newTermName(fresh(name.toString)) + def fresh(name: TermName): TermName = c.freshName(name) - def fresh(name: String): String = currentUnit.freshTermName("" + name + "$").toString + def fresh(name: String): String = c.freshName(name) } def isAwait(fun: Tree) = fun.symbol == defn.Async_await + // Copy pasted from TreeInfo in the compiler. + // Using a quasiquote pattern like `case q"$fun[..$targs](...$args)" => is not + // sufficient since https://github.com/scala/scala/pull/3656 as it doesn't match + // constructor invocations. + class Applied(val tree: Tree) { + /** The tree stripped of the possibly nested applications. + * The original tree if it's not an application. + */ + def callee: Tree = { + def loop(tree: Tree): Tree = tree match { + case Apply(fn, _) => loop(fn) + case tree => tree + } + loop(tree) + } + + /** The `callee` unwrapped from type applications. + * The original `callee` if it's not a type application. + */ + def core: Tree = callee match { + case TypeApply(fn, _) => fn + case AppliedTypeTree(fn, _) => fn + case tree => tree + } + + /** The type arguments of the `callee`. + * `Nil` if the `callee` is not a type application. + */ + def targs: List[Tree] = callee match { + case TypeApply(_, args) => args + case AppliedTypeTree(_, args) => args + case _ => Nil + } + + /** (Possibly multiple lists of) value arguments of an application. + * `Nil` if the `callee` is not an application. + */ + def argss: List[List[Tree]] = { + def loop(tree: Tree): List[List[Tree]] = tree match { + case Apply(fn, args) => loop(fn) :+ args + case _ => Nil + } + loop(tree) + } + } + + /** Returns a wrapper that knows how to destructure and analyze applications. + */ + def dissectApplied(tree: Tree) = new Applied(tree) + + /** Destructures applications into important subparts described in `Applied` class, + * namely into: core, targs and argss (in the specified order). + * + * Trees which are not applications are also accepted. Their callee and core will + * be equal to the input, while targs and argss will be Nil. + * + * The provided extractors don't expose all the API of the `Applied` class. + * For advanced use, call `dissectApplied` explicitly and use its methods instead of pattern matching. + */ + object Applied { + def apply(tree: Tree): Applied = new Applied(tree) + + def unapply(applied: Applied): Option[(Tree, List[Tree], List[List[Tree]])] = + Some((applied.core, applied.targs, applied.argss)) + + def unapply(tree: Tree): Option[(Tree, List[Tree], List[List[Tree]])] = + unapply(dissectApplied(tree)) + } private lazy val Boolean_ShortCircuits: Set[Symbol] = { import definitions.BooleanClass def BooleanTermMember(name: String) = BooleanClass.typeSignature.member(newTermName(name).encodedName) @@ -52,7 +122,7 @@ private[async] trait TransformUtils { else if (fun.tpe == null) (x, y) => false else { val paramss = fun.tpe.paramss - val byNamess = paramss.map(_.map(_.isByNameParam)) + val byNamess = paramss.map(_.map(_.asTerm.isByNameParam)) (i, j) => util.Try(byNamess(i)(j)).getOrElse(false) } } @@ -62,11 +132,9 @@ private[async] trait TransformUtils { (i, j) => util.Try(namess(i)(j)).getOrElse(s"arg_${i}_${j}") } - def Expr[A: WeakTypeTag](t: Tree) = global.Expr[A](rootMirror, new FixedMirrorTreeCreator(rootMirror, t)) - object defn { def mkList_apply[A](args: List[Expr[A]]): Expr[List[A]] = { - Expr(Apply(Ident(definitions.List_apply), args.map(_.tree))) + c.Expr(Apply(Ident(definitions.List_apply), args.map(_.tree))) } def mkList_contains[A](self: Expr[List[A]])(elem: Expr[Any]) = reify { @@ -82,11 +150,7 @@ private[async] trait TransformUtils { } val NonFatalClass = rootMirror.staticModule("scala.util.control.NonFatal") - val Async_await = asyncBase.awaitMethod(global)(macroApplication.symbol).ensuring(_ != NoSymbol) - } - - def isSafeToInline(tree: Tree) = { - treeInfo.isExprSafeToInline(tree) + val Async_await = asyncBase.awaitMethod(c.universe)(c.macroApplication.symbol).ensuring(_ != NoSymbol) } // `while(await(x))` ... or `do { await(x); ... } while(...)` contain an `If` that loops; @@ -97,11 +161,17 @@ private[async] trait TransformUtils { case ld: LabelDef => ld.symbol }.toSet t.exists { - case rt: RefTree => rt.symbol != null && rt.symbol.isLabel && !(labelDefs contains rt.symbol) + case rt: RefTree => rt.symbol != null && isLabel(rt.symbol) && !(labelDefs contains rt.symbol) case _ => false } } + private def isLabel(sym: Symbol): Boolean = { + val LABEL = 1L << 17 // not in the public reflection API. + (internal.flags(sym).asInstanceOf[Long] & LABEL) != 0L + } + + /** Map a list of arguments to: * - A list of argument Trees * - A list of auxillary results. @@ -191,7 +261,7 @@ private[async] trait TransformUtils { case dd: DefDef => nestedMethod(dd) case fun: Function => function(fun) case m@Match(EmptyTree, _) => patMatFunction(m) // Pattern matching anonymous function under -Xoldpatmat of after `restorePatternMatchingFunctions` - case treeInfo.Applied(fun, targs, argss) if argss.nonEmpty => + case q"$fun[..$targs](...$argss)" if argss.nonEmpty => val isInByName = isByName(fun) for ((args, i) <- argss.zipWithIndex) { for ((arg, j) <- args.zipWithIndex) { @@ -205,63 +275,76 @@ private[async] trait TransformUtils { } } - def abort(pos: Position, msg: String) = throw new AbortMacroException(pos, msg) - - abstract class MacroTypingTransformer extends TypingTransformer(callSiteTyper.context.unit) { - currentOwner = callSiteTyper.context.owner - curTree = EmptyTree - - def currOwner: Symbol = currentOwner - - localTyper = global.analyzer.newTyper(callSiteTyper.context.make(unit = callSiteTyper.context.unit)) + def transformAt(tree: Tree)(f: PartialFunction[Tree, (TypingTransformApi => Tree)]) = { + typingTransform(tree)((tree, api) => { + if (f.isDefinedAt(tree)) f(tree)(api) + else api.default(tree) + }) } - def transformAt(tree: Tree)(f: PartialFunction[Tree, (analyzer.Context => Tree)]) = { - object trans extends MacroTypingTransformer { - override def transform(tree: Tree): Tree = { - if (f.isDefinedAt(tree)) { - f(tree)(localTyper.context) - } else super.transform(tree) - } + def toMultiMap[A, B](as: Iterable[(A, B)]): Map[A, List[B]] = + as.toList.groupBy(_._1).mapValues(_.map(_._2).toList).toMap + + // Attributed version of `TreeGen#mkCastPreservingAnnotations` + def mkAttributedCastPreservingAnnotations(tree: Tree, tp: Type): Tree = { + atPos(tree.pos) { + val casted = c.typecheck(gen.mkCast(tree, uncheckedBounds(withoutAnnotations(tp)).dealias)) + Typed(casted, TypeTree(tp)).setType(tp) } - trans.transform(tree) } - def changeOwner(tree: Tree, oldOwner: Symbol, newOwner: Symbol): tree.type = { - new ChangeOwnerAndModuleClassTraverser(oldOwner, newOwner).traverse(tree) - tree + def deconst(tp: Type): Type = tp match { + case AnnotatedType(anns, underlying) => annotatedType(anns, deconst(underlying)) + case ExistentialType(quants, underlying) => existentialType(quants, deconst(underlying)) + case ConstantType(value) => deconst(value.tpe) + case _ => tp } - class ChangeOwnerAndModuleClassTraverser(oldowner: Symbol, newowner: Symbol) - extends ChangeOwnerTraverser(oldowner, newowner) { + def withAnnotation(tp: Type, ann: Annotation): Type = withAnnotations(tp, List(ann)) - override def traverse(tree: Tree) { - tree match { - case _: DefTree => change(tree.symbol.moduleClass) - case _ => - } - super.traverse(tree) - } + def withAnnotations(tp: Type, anns: List[Annotation]): Type = tp match { + case AnnotatedType(existingAnns, underlying) => annotatedType(anns ::: existingAnns, underlying) + case ExistentialType(quants, underlying) => existentialType(quants, withAnnotations(underlying, anns)) + case _ => annotatedType(anns, tp) } - def toMultiMap[A, B](as: Iterable[(A, B)]): Map[A, List[B]] = - as.toList.groupBy(_._1).mapValues(_.map(_._2).toList).toMap + def withoutAnnotations(tp: Type): Type = tp match { + case AnnotatedType(anns, underlying) => withoutAnnotations(underlying) + case ExistentialType(quants, underlying) => existentialType(quants, withoutAnnotations(underlying)) + case _ => tp + } - // Attributed version of `TreeGen#mkCastPreservingAnnotations` - def mkAttributedCastPreservingAnnotations(tree: Tree, tp: Type): Tree = { - atPos(tree.pos) { - val casted = gen.mkAttributedCast(tree, uncheckedBounds(tp.withoutAnnotations).dealias) - Typed(casted, TypeTree(tp)).setType(tp) - } + def tpe(sym: Symbol): Type = { + if (sym.isType) sym.asType.toType + else sym.info } + def thisType(sym: Symbol): Type = { + if (sym.isClass) sym.asClass.thisPrefix + else NoPrefix + } + + private def derivedValueClassUnbox(cls: Symbol) = + (cls.info.decls.find(sym => sym.isMethod && sym.asTerm.isParamAccessor) getOrElse NoSymbol) + def mkZero(tp: Type): Tree = { - if (tp.typeSymbol.isDerivedValueClass) { - val argZero = mkZero(tp.memberType(tp.typeSymbol.derivedValueClassUnbox).resultType) + val tpSym = tp.typeSymbol + if (tpSym.isClass && tpSym.asClass.isDerivedValueClass) { + val argZero = mkZero(derivedValueClassUnbox(tpSym).infoIn(tp).resultType) + val baseType = tp.baseType(tpSym) // use base type here to dealias / strip phantom "tagged types" etc. + + // By explicitly attributing the types and symbols here, we subvert privacy. + // Otherwise, ticket86PrivateValueClass would fail. + + // Approximately: + // q"new ${valueClass}[$..targs](argZero)" val target: Tree = gen.mkAttributedSelect( - typer.typedPos(macroPos)( - New(TypeTree(tp.baseType(tp.typeSymbol)))), tp.typeSymbol.primaryConstructor) + c.typecheck(atMacroPos( + New(TypeTree(baseType)))), tpSym.asClass.primaryConstructor) + val zero = gen.mkMethodCall(target, argZero :: Nil) + + // restore the original type which we might otherwise have weakened with `baseType` above gen.mkCast(zero, tp) } else { gen.mkZero(tp) @@ -271,11 +354,12 @@ private[async] trait TransformUtils { // ===================================== // Copy/Pasted from Scala 2.10.3. See SI-7694. private lazy val UncheckedBoundsClass = { - global.rootMirror.getClassIfDefined("scala.reflect.internal.annotations.uncheckedBounds") + try c.mirror.staticClass("scala.reflect.internal.annotations.uncheckedBounds") + catch { case _: ScalaReflectionException => NoSymbol } } final def uncheckedBounds(tp: Type): Type = { if (tp.typeArgs.isEmpty || UncheckedBoundsClass == NoSymbol) tp - else tp.withAnnotation(AnnotationInfo marker UncheckedBoundsClass.tpe) + else withAnnotation(tp, Annotation(UncheckedBoundsClass.asType.toType, Nil, ListMap())) } // ===================================== } |