diff options
-rw-r--r-- | src/dotty/tools/dotc/Compiler.scala | 2 | ||||
-rw-r--r-- | src/dotty/tools/dotc/core/Contexts.scala | 3 | ||||
-rw-r--r-- | src/dotty/tools/dotc/transform/PostTyperTransformers.scala | 4 | ||||
-rw-r--r-- | src/dotty/tools/dotc/transform/Splitter.scala | 27 | ||||
-rw-r--r-- | src/dotty/tools/dotc/transform/TreeTransform.scala | 159 |
5 files changed, 115 insertions, 80 deletions
diff --git a/src/dotty/tools/dotc/Compiler.scala b/src/dotty/tools/dotc/Compiler.scala index 6fd69beb8..1e8f13578 100644 --- a/src/dotty/tools/dotc/Compiler.scala +++ b/src/dotty/tools/dotc/Compiler.scala @@ -21,7 +21,7 @@ class Compiler { List( List(new FrontEnd), List(new LazyValsCreateCompanionObjects, new PatternMatcher), //force separataion between lazyVals and LVCreateCO - List(new LazyValTranformContext().transformer, new TypeTestsCasts), + List(new LazyValTranformContext().transformer, new Splitter, new TypeTestsCasts), List(new Erasure), List(new UncurryTreeTransform) ) diff --git a/src/dotty/tools/dotc/core/Contexts.scala b/src/dotty/tools/dotc/core/Contexts.scala index 8d083b29c..8721bc548 100644 --- a/src/dotty/tools/dotc/core/Contexts.scala +++ b/src/dotty/tools/dotc/core/Contexts.scala @@ -277,6 +277,9 @@ object Contexts { newctx.asInstanceOf[FreshContext] } + final def withOwner(owner: Symbol): Context = + if (owner ne this.owner) fresh.setOwner(owner) else this + final def withMode(mode: Mode): Context = if (mode != this.mode) fresh.setMode(mode) else this diff --git a/src/dotty/tools/dotc/transform/PostTyperTransformers.scala b/src/dotty/tools/dotc/transform/PostTyperTransformers.scala index 14e2cf35d..25f122cf5 100644 --- a/src/dotty/tools/dotc/transform/PostTyperTransformers.scala +++ b/src/dotty/tools/dotc/transform/PostTyperTransformers.scala @@ -48,8 +48,8 @@ object PostTyperTransformers { reorder0(stats) } - override def transformStats(trees: List[tpd.Tree], info: TransformerInfo, current: Int)(implicit ctx: Context): List[tpd.Tree] = - super.transformStats(reorder(trees)(ctx, info), info, current) + override def transformStats(trees: List[tpd.Tree], exprOwner: Symbol, info: TransformerInfo, current: Int)(implicit ctx: Context): List[tpd.Tree] = + super.transformStats(reorder(trees)(ctx, info), exprOwner, info, current) override def transform(tree: tpd.Tree, info: TransformerInfo, cur: Int)(implicit ctx: Context): tpd.Tree = tree match { case tree: Import => EmptyTree diff --git a/src/dotty/tools/dotc/transform/Splitter.scala b/src/dotty/tools/dotc/transform/Splitter.scala new file mode 100644 index 000000000..9c01574aa --- /dev/null +++ b/src/dotty/tools/dotc/transform/Splitter.scala @@ -0,0 +1,27 @@ +package dotty.tools.dotc +package transform + +import TreeTransforms._ +import ast.Trees._ +import core.Contexts._ +import core.Types._ + +/** This transform makes usre every identifier and select node + * carries a symbol. To do this, certain qualifiers with a union type + * have to be "splitted" with a type test. + * + * For now, only self references are treated. + */ +class Splitter extends TreeTransform { + import ast.tpd._ + + override def name: String = "splitter" + + /** Replace self referencing idents with ThisTypes. */ + override def transformIdent(tree: Ident)(implicit ctx: Context, info: TransformerInfo) = tree.tpe match { + case ThisType(cls) => + println(s"owner = ${ctx.owner}, context = ${ctx}") + This(cls) withPos tree.pos + case _ => tree + } +}
\ No newline at end of file diff --git a/src/dotty/tools/dotc/transform/TreeTransform.scala b/src/dotty/tools/dotc/transform/TreeTransform.scala index 684714199..425410ae7 100644 --- a/src/dotty/tools/dotc/transform/TreeTransform.scala +++ b/src/dotty/tools/dotc/transform/TreeTransform.scala @@ -3,7 +3,9 @@ package dotty.tools.dotc.transform import dotty.tools.dotc.ast.tpd import dotty.tools.dotc.core.Contexts.Context import dotty.tools.dotc.core.Phases.Phase +import dotty.tools.dotc.core.Symbols.Symbol import dotty.tools.dotc.ast.Trees._ +import dotty.tools.dotc.core.Decorators._ import scala.annotation.tailrec object TreeTransforms { @@ -149,9 +151,7 @@ object TreeTransforms { type Mutator[T] = (TreeTransform, T, Context) => TreeTransform - class TransformerInfo(val transformers: Array[TreeTransform], val nx: NXTransformations, val group:TreeTransformer, val contexts:Array[Context]) { - assert(transformers.size == contexts.size) - } + class TransformerInfo(val transformers: Array[TreeTransform], val nx: NXTransformations, val group:TreeTransformer) /** * This class maintains track of which methods are redefined in MiniPhases and creates execution plans for transformXXX and prepareXXX @@ -431,15 +431,15 @@ object TreeTransforms { val l = result.length var allDone = i < l while (i < l) { - val oldT = result(i) - val newT = mutator(oldT, tree, info.contexts(i)) - allDone = allDone && (newT eq NoTransform) - if (!(oldT eq newT)) { + val oldTransform = result(i) + val newTransform = mutator(oldTransform, tree, ctx.withPhase(oldTransform)) + allDone = allDone && (newTransform eq NoTransform) + if (!(oldTransform eq newTransform)) { if (!transformersCopied) result = result.clone() transformersCopied = true - result(i) = newT - if (!(newT.getClass == oldT.getClass)) { - resultNX = new NXTransformations(resultNX, newT, i, nxCopied) + result(i) = newTransform + if (!(newTransform.getClass == oldTransform.getClass)) { + resultNX = new NXTransformations(resultNX, newTransform, i, nxCopied) nxCopied = true } } @@ -447,7 +447,7 @@ object TreeTransforms { } if (allDone) null else if (!transformersCopied) info - else new TransformerInfo(result, resultNX, info.group, info.contexts) + else new TransformerInfo(result, resultNX, info.group) } val prepForIdent: Mutator[Ident] = (trans, tree, ctx) => trans.prepareForIdent(tree)(ctx) @@ -484,8 +484,7 @@ object TreeTransforms { def transform(t: Tree)(implicit ctx: Context): Tree = { val initialTransformations = transformations - val contexts = initialTransformations.map(tr => ctx.withPhase(tr).ctx) - val info = new TransformerInfo(initialTransformations, new NXTransformations(initialTransformations), this, contexts) + val info = new TransformerInfo(initialTransformations, new NXTransformations(initialTransformations), this) initialTransformations.zipWithIndex.foreach{ case (transform, id) => transform.idx = id @@ -498,8 +497,7 @@ object TreeTransforms { final private[TreeTransforms] def goIdent(tree: Ident, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) - - trans.transformIdent(tree)(info.contexts(cur), info) match { + trans.transformIdent(tree)(ctx.withPhase(trans), info) match { case t: Ident => goIdent(t, info.nx.nxTransIdent(cur + 1)) case t => transformSingle(t, cur + 1) } @@ -510,7 +508,7 @@ object TreeTransforms { final private[TreeTransforms] def goSelect(tree: Select, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) - trans.transformSelect(tree)(info.contexts(cur), info) match { + trans.transformSelect(tree)(ctx.withPhase(trans), info) match { case t: Select => goSelect(t, info.nx.nxTransSelect(cur + 1)) case t => transformSingle(t, cur + 1) } @@ -521,7 +519,7 @@ object TreeTransforms { final private[TreeTransforms] def goThis(tree: This, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) - trans.transformThis(tree)(info.contexts(cur), info) match { + trans.transformThis(tree)(ctx.withPhase(trans), info) match { case t: This => goThis(t, info.nx.nxTransThis(cur + 1)) case t => transformSingle(t, cur + 1) } @@ -532,7 +530,7 @@ object TreeTransforms { final private[TreeTransforms] def goSuper(tree: Super, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) - trans.transformSuper(tree)(info.contexts(cur), info) match { + trans.transformSuper(tree)(ctx.withPhase(trans), info) match { case t: Super => goSuper(t, info.nx.nxTransSuper(cur + 1)) case t => transformSingle(t, cur + 1) } @@ -543,7 +541,7 @@ object TreeTransforms { final private[TreeTransforms] def goApply(tree: Apply, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) - trans.transformApply(tree)(info.contexts(cur), info) match { + trans.transformApply(tree)(ctx.withPhase(trans), info) match { case t: Apply => goApply(t, info.nx.nxTransApply(cur + 1)) case t => transformSingle(t, cur + 1) } @@ -554,7 +552,7 @@ object TreeTransforms { final private[TreeTransforms] def goTypeApply(tree: TypeApply, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) - trans.transformTypeApply(tree)(info.contexts(cur), info) match { + trans.transformTypeApply(tree)(ctx.withPhase(trans), info) match { case t: TypeApply => goTypeApply(t, info.nx.nxTransTypeApply(cur + 1)) case t => transformSingle(t, cur + 1) } @@ -565,7 +563,7 @@ object TreeTransforms { final private[TreeTransforms] def goNew(tree: New, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) - trans.transformNew(tree)(info.contexts(cur), info) match { + trans.transformNew(tree)(ctx.withPhase(trans), info) match { case t: New => goNew(t, info.nx.nxTransNew(cur + 1)) case t => transformSingle(t, cur + 1) } @@ -576,7 +574,7 @@ object TreeTransforms { final private[TreeTransforms] def goPair(tree: Pair, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) - trans.transformPair(tree)(info.contexts(cur), info) match { + trans.transformPair(tree)(ctx.withPhase(trans), info) match { case t: Pair => goPair(t, info.nx.nxTransPair(cur + 1)) case t => transformSingle(t, cur + 1) } @@ -587,7 +585,7 @@ object TreeTransforms { final private[TreeTransforms] def goTyped(tree: Typed, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) - trans.transformTyped(tree)(info.contexts(cur), info) match { + trans.transformTyped(tree)(ctx.withPhase(trans), info) match { case t: Typed => goTyped(t, info.nx.nxTransTyped(cur + 1)) case t => transformSingle(t, cur + 1) } @@ -598,7 +596,7 @@ object TreeTransforms { final private[TreeTransforms] def goAssign(tree: Assign, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) - trans.transformAssign(tree)(info.contexts(cur), info) match { + trans.transformAssign(tree)(ctx.withPhase(trans), info) match { case t: Assign => goAssign(t, info.nx.nxTransAssign(cur + 1)) case t => transformSingle(t, cur + 1) } @@ -609,7 +607,7 @@ object TreeTransforms { final private[TreeTransforms] def goLiteral(tree: Literal, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) - trans.transformLiteral(tree)(info.contexts(cur), info) match { + trans.transformLiteral(tree)(ctx.withPhase(trans), info) match { case t: Literal => goLiteral(t, info.nx.nxTransLiteral(cur + 1)) case t => transformSingle(t, cur + 1) } @@ -620,7 +618,7 @@ object TreeTransforms { final private[TreeTransforms] def goBlock(tree: Block, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) - trans.transformBlock(tree)(info.contexts(cur), info) match { + trans.transformBlock(tree)(ctx.withPhase(trans), info) match { case t: Block => goBlock(t, info.nx.nxTransBlock(cur + 1)) case t => transformSingle(t, cur + 1) } @@ -631,7 +629,7 @@ object TreeTransforms { final private[TreeTransforms] def goIf(tree: If, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) - trans.transformIf(tree)(info.contexts(cur), info) match { + trans.transformIf(tree)(ctx.withPhase(trans), info) match { case t: If => goIf(t, info.nx.nxTransIf(cur + 1)) case t => transformSingle(t, cur + 1) } @@ -642,7 +640,7 @@ object TreeTransforms { final private[TreeTransforms] def goClosure(tree: Closure, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) - trans.transformClosure(tree)(info.contexts(cur), info) match { + trans.transformClosure(tree)(ctx.withPhase(trans), info) match { case t: Closure => goClosure(t, info.nx.nxTransClosure(cur + 1)) case t => transformSingle(t, cur + 1) } @@ -653,7 +651,7 @@ object TreeTransforms { final private[TreeTransforms] def goMatch(tree: Match, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) - trans.transformMatch(tree)(info.contexts(cur), info) match { + trans.transformMatch(tree)(ctx.withPhase(trans), info) match { case t: Match => goMatch(t, info.nx.nxTransMatch(cur + 1)) case t => transformSingle(t, cur + 1) } @@ -664,7 +662,7 @@ object TreeTransforms { final private[TreeTransforms] def goCaseDef(tree: CaseDef, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) - trans.transformCaseDef(tree)(info.contexts(cur), info) match { + trans.transformCaseDef(tree)(ctx.withPhase(trans), info) match { case t: CaseDef => goCaseDef(t, info.nx.nxTransCaseDef(cur + 1)) case t => transformSingle(t, cur + 1) } @@ -675,7 +673,7 @@ object TreeTransforms { final private[TreeTransforms] def goReturn(tree: Return, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) - trans.transformReturn(tree)(info.contexts(cur), info) match { + trans.transformReturn(tree)(ctx.withPhase(trans), info) match { case t: Return => goReturn(t, info.nx.nxTransReturn(cur + 1)) case t => transformSingle(t, cur + 1) } @@ -686,7 +684,7 @@ object TreeTransforms { final private[TreeTransforms] def goTry(tree: Try, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) - trans.transformTry(tree)(info.contexts(cur), info) match { + trans.transformTry(tree)(ctx.withPhase(trans), info) match { case t: Try => goTry(t, info.nx.nxTransTry(cur + 1)) case t => transformSingle(t, cur + 1) } @@ -697,7 +695,7 @@ object TreeTransforms { final private[TreeTransforms] def goThrow(tree: Throw, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) - trans.transformThrow(tree)(info.contexts(cur), info) match { + trans.transformThrow(tree)(ctx.withPhase(trans), info) match { case t: Throw => goThrow(t, info.nx.nxTransThrow(cur + 1)) case t => transformSingle(t, cur + 1) } @@ -708,7 +706,7 @@ object TreeTransforms { final private[TreeTransforms] def goSeqLiteral(tree: SeqLiteral, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) - trans.transformSeqLiteral(tree)(info.contexts(cur), info) match { + trans.transformSeqLiteral(tree)(ctx.withPhase(trans), info) match { case t: SeqLiteral => goSeqLiteral(t, info.nx.nxTransSeqLiteral(cur + 1)) case t => transformSingle(t, cur + 1) } @@ -719,7 +717,7 @@ object TreeTransforms { final private[TreeTransforms] def goTypeTree(tree: TypeTree, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) - trans.transformTypeTree(tree)(info.contexts(cur), info) match { + trans.transformTypeTree(tree)(ctx.withPhase(trans), info) match { case t: TypeTree => goTypeTree(t, info.nx.nxTransTypeTree(cur + 1)) case t => transformSingle(t, cur + 1) } @@ -730,7 +728,7 @@ object TreeTransforms { final private[TreeTransforms] def goSelectFromTypeTree(tree: SelectFromTypeTree, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) - trans.transformSelectFromTypeTree(tree)(info.contexts(cur), info) match { + trans.transformSelectFromTypeTree(tree)(ctx.withPhase(trans), info) match { case t: SelectFromTypeTree => goSelectFromTypeTree(t, info.nx.nxTransSelectFromTypeTree(cur + 1)) case t => transformSingle(t, cur + 1) } @@ -741,7 +739,7 @@ object TreeTransforms { final private[TreeTransforms] def goBind(tree: Bind, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) - trans.transformBind(tree)(info.contexts(cur), info) match { + trans.transformBind(tree)(ctx.withPhase(trans), info) match { case t: Bind => goBind(t, info.nx.nxTransBind(cur + 1)) case t => transformSingle(t, cur + 1) } @@ -752,7 +750,7 @@ object TreeTransforms { final private[TreeTransforms] def goAlternative(tree: Alternative, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) - trans.transformAlternative(tree)(info.contexts(cur), info) match { + trans.transformAlternative(tree)(ctx.withPhase(trans), info) match { case t: Alternative => goAlternative(t, info.nx.nxTransAlternative(cur + 1)) case t => transformSingle(t, cur + 1) } @@ -763,7 +761,7 @@ object TreeTransforms { final private[TreeTransforms] def goValDef(tree: ValDef, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) - trans.transformValDef(tree)(info.contexts(cur), info) match { + trans.transformValDef(tree)(ctx.withPhase(trans), info) match { case t: ValDef => goValDef(t, info.nx.nxTransValDef(cur + 1)) case t => transformSingle(t, cur + 1) } @@ -774,7 +772,7 @@ object TreeTransforms { final private[TreeTransforms] def goDefDef(tree: DefDef, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) - trans.transformDefDef(tree)(info.contexts(cur), info) match { + trans.transformDefDef(tree)(ctx.withPhase(trans), info) match { case t: DefDef => goDefDef(t, info.nx.nxTransDefDef(cur + 1)) case t => transformSingle(t, cur + 1) } @@ -785,7 +783,7 @@ object TreeTransforms { final private[TreeTransforms] def goUnApply(tree: UnApply, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) - trans.transformUnApply(tree)(info.contexts(cur), info) match { + trans.transformUnApply(tree)(ctx.withPhase(trans), info) match { case t: UnApply => goUnApply(t, info.nx.nxTransUnApply(cur + 1)) case t => transformSingle(t, cur + 1) } @@ -796,7 +794,7 @@ object TreeTransforms { final private[TreeTransforms] def goTypeDef(tree: TypeDef, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) - trans.transformTypeDef(tree)(info.contexts(cur), info) match { + trans.transformTypeDef(tree)(ctx.withPhase(trans), info) match { case t: TypeDef => goTypeDef(t, info.nx.nxTransTypeDef(cur + 1)) case t => transformSingle(t, cur + 1) } @@ -807,7 +805,7 @@ object TreeTransforms { final private[TreeTransforms] def goTemplate(tree: Template, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) - trans.transformTemplate(tree)(info.contexts(cur), info) match { + trans.transformTemplate(tree)(ctx.withPhase(trans), info) match { case t: Template => goTemplate(t, info.nx.nxTransTemplate(cur + 1)) case t => transformSingle(t, cur + 1) } @@ -818,7 +816,7 @@ object TreeTransforms { final private[TreeTransforms] def goPackageDef(tree: PackageDef, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) - trans.transformPackageDef(tree)(info.contexts(cur), info) match { + trans.transformPackageDef(tree)(ctx.withPhase(trans), info) match { case t: PackageDef => goPackageDef(t, info.nx.nxTransPackageDef(cur + 1)) case t => transformSingle(t, cur + 1) } @@ -863,9 +861,7 @@ object TreeTransforms { case tree: UnApply => goUnApply(tree, info.nx.nxTransUnApply(cur)) case tree: Template => goTemplate(tree, info.nx.nxTransTemplate(cur)) case tree: PackageDef => goPackageDef(tree, info.nx.nxTransPackageDef(cur)) - case Thicket(trees) if trees != Nil => - val trees1 = transformL(trees.asInstanceOf[List[tpd.Tree]], info, cur) - if (trees1 eq trees) tree else Thicket(trees1) + case Thicket(trees) => cpy.Thicket(tree, transformTrees(trees, info, cur)) case tree => tree } @@ -876,7 +872,9 @@ object TreeTransforms { case tree => goUnamed(tree, cur) } - final private[TreeTransforms] def transformNameTree(tree: NameTree, info: TransformerInfo, cur: Int)(implicit ctx: Context): Tree = + def localContext(owner: Symbol)(implicit ctx: Context) = ctx.fresh.setOwner(owner) + + final private[TreeTransforms] def transformNamed(tree: NameTree, info: TransformerInfo, cur: Int)(implicit ctx: Context): Tree = tree match { case tree: Ident => implicit val mutatedInfo = mutateTransformers(info, prepForIdent, info.nx.nxPrepIdent, tree, cur) @@ -907,25 +905,27 @@ object TreeTransforms { implicit val mutatedInfo = mutateTransformers(info, prepForValDef, info.nx.nxPrepValDef, tree, cur) if (mutatedInfo eq null) tree else { - val tpt = transform(tree.tpt, mutatedInfo, cur) - val rhs = transform(tree.rhs, mutatedInfo, cur) + val nestedCtx = if (tree.symbol.exists) localContext(tree.symbol) else ctx + val tpt = transform(tree.tpt, mutatedInfo, cur)(nestedCtx) + val rhs = transform(tree.rhs, mutatedInfo, cur)(nestedCtx) goValDef(cpy.ValDef(tree, tree.mods, tree.name, tpt, rhs), mutatedInfo.nx.nxTransValDef(cur)) } case tree: DefDef => implicit val mutatedInfo = mutateTransformers(info, prepForDefDef, info.nx.nxPrepDefDef, tree, cur) if (mutatedInfo eq null) tree else { - val tparams = transformSubL(tree.tparams, mutatedInfo, cur) - val vparams = tree.vparamss.mapConserve(x => transformSubL(x, mutatedInfo, cur)) - val tpt = transform(tree.tpt, mutatedInfo, cur) - val rhs = transform(tree.rhs, mutatedInfo, cur) + val nestedCtx = localContext(tree.symbol) + val tparams = transformSubTrees(tree.tparams, mutatedInfo, cur)(nestedCtx) + val vparams = tree.vparamss.mapConserve(x => transformSubTrees(x, mutatedInfo, cur)(nestedCtx)) + val tpt = transform(tree.tpt, mutatedInfo, cur)(nestedCtx) + val rhs = transform(tree.rhs, mutatedInfo, cur)(nestedCtx) goDefDef(cpy.DefDef(tree, tree.mods, tree.name, tparams, vparams, tpt, rhs), mutatedInfo.nx.nxTransDefDef(cur)) } case tree: TypeDef => implicit val mutatedInfo = mutateTransformers(info, prepForTypeDef, info.nx.nxPrepTypeDef, tree, cur) if (mutatedInfo eq null) tree else { - val rhs = transform(tree.rhs, mutatedInfo, cur) + val rhs = transform(tree.rhs, mutatedInfo, cur)(localContext(tree.symbol)) goTypeDef(cpy.TypeDef(tree, tree.mods, tree.name, rhs, tree.tparams), mutatedInfo.nx.nxTransTypeDef(cur)) } case _ => @@ -950,7 +950,7 @@ object TreeTransforms { if (mutatedInfo eq null) tree else { val fun = transform(tree.fun, mutatedInfo, cur) - val args = transformSubL(tree.args, mutatedInfo, cur) + val args = transformSubTrees(tree.args, mutatedInfo, cur) goApply(cpy.Apply(tree, fun, args), mutatedInfo.nx.nxTransApply(cur)) } case tree: TypeApply => @@ -958,7 +958,7 @@ object TreeTransforms { if (mutatedInfo eq null) tree else { val fun = transform(tree.fun, mutatedInfo, cur) - val args = transformL(tree.args, mutatedInfo, cur) + val args = transformTrees(tree.args, mutatedInfo, cur) goTypeApply(cpy.TypeApply(tree, fun, args), mutatedInfo.nx.nxTransTypeApply(cur)) } case tree: Literal => @@ -1000,7 +1000,7 @@ object TreeTransforms { implicit val mutatedInfo = mutateTransformers(info, prepForBlock, info.nx.nxPrepBlock, tree, cur) if (mutatedInfo eq null) tree else { - val stats = transformStats(tree.stats, mutatedInfo, cur) + val stats = transformStats(tree.stats, ctx.owner, mutatedInfo, cur) val expr = transform(tree.expr, mutatedInfo, cur) goBlock(cpy.Block(tree, stats, expr), mutatedInfo.nx.nxTransBlock(cur)) } @@ -1017,7 +1017,7 @@ object TreeTransforms { implicit val mutatedInfo = mutateTransformers(info, prepForClosure, info.nx.nxPrepClosure, tree, cur) if (mutatedInfo eq null) tree else { - val env = transformL(tree.env, mutatedInfo, cur) + val env = transformTrees(tree.env, mutatedInfo, cur) val meth = transform(tree.meth, mutatedInfo, cur) val tpt = transform(tree.tpt, mutatedInfo, cur) goClosure(cpy.Closure(tree, env, meth, tpt), mutatedInfo.nx.nxTransClosure(cur)) @@ -1027,7 +1027,7 @@ object TreeTransforms { if (mutatedInfo eq null) tree else { val selector = transform(tree.selector, mutatedInfo, cur) - val cases = transformSubL(tree.cases, mutatedInfo, cur) + val cases = transformSubTrees(tree.cases, mutatedInfo, cur) goMatch(cpy.Match(tree, selector, cases), mutatedInfo.nx.nxTransMatch(cur)) } case tree: CaseDef => @@ -1067,7 +1067,7 @@ object TreeTransforms { implicit val mutatedInfo = mutateTransformers(info, prepForSeqLiteral, info.nx.nxPrepSeqLiteral, tree, cur) if (mutatedInfo eq null) tree else { - val elems = transformL(tree.elems, mutatedInfo, cur) + val elems = transformTrees(tree.elems, mutatedInfo, cur) goSeqLiteral(cpy.SeqLiteral(tree, elems), mutatedInfo.nx.nxTransLiteral(cur)) } case tree: TypeTree => @@ -1081,7 +1081,7 @@ object TreeTransforms { implicit val mutatedInfo = mutateTransformers(info, prepForAlternative, info.nx.nxPrepAlternative, tree, cur) if (mutatedInfo eq null) tree else { - val trees = transformL(tree.trees, mutatedInfo, cur) + val trees = transformTrees(tree.trees, mutatedInfo, cur) goAlternative(cpy.Alternative(tree, trees), mutatedInfo.nx.nxTransAlternative(cur)) } case tree: UnApply => @@ -1089,8 +1089,8 @@ object TreeTransforms { if (mutatedInfo eq null) tree else { val fun = transform(tree.fun, mutatedInfo, cur) - val implicits = transformL(tree.implicits, mutatedInfo, cur) - val patterns = transformL(tree.patterns, mutatedInfo, cur) + val implicits = transformTrees(tree.implicits, mutatedInfo, cur) + val patterns = transformTrees(tree.patterns, mutatedInfo, cur) goUnApply(cpy.UnApply(tree, fun, implicits, patterns), mutatedInfo.nx.nxTransUnApply(cur)) } case tree: Template => @@ -1098,29 +1098,28 @@ object TreeTransforms { if (mutatedInfo eq null) tree else { val constr = transformSub(tree.constr, mutatedInfo, cur) - val parents = transformL(tree.parents, mutatedInfo, cur) + val parents = transformTrees(tree.parents, mutatedInfo, cur) val self = transformSub(tree.self, mutatedInfo, cur) - val body = transformStats(tree.body, mutatedInfo, cur) + val body = transformStats(tree.body, tree.symbol, mutatedInfo, cur) goTemplate(cpy.Template(tree, constr, parents, self, body), mutatedInfo.nx.nxTransTemplate(cur)) } case tree: PackageDef => implicit val mutatedInfo = mutateTransformers(info, prepForPackageDef, info.nx.nxPrepPackageDef, tree, cur) if (mutatedInfo eq null) tree else { + val nestedCtx = localContext(tree.symbol) val pid = transformSub(tree.pid, mutatedInfo, cur) - val stats = transformStats(tree.stats, mutatedInfo, cur) + val stats = transformStats(tree.stats, tree.symbol, mutatedInfo, cur)(nestedCtx) goPackageDef(cpy.PackageDef(tree, pid, stats), mutatedInfo.nx.nxTransPackageDef(cur)) } - case Thicket(trees) if trees != Nil => - val trees1 = transformL(trees.asInstanceOf[List[tpd.Tree]], info, cur) - if (trees1 eq trees) tree else Thicket(trees1) + case Thicket(trees) => cpy.Thicket(tree, transformTrees(trees, info, cur)) case tree => tree } def transform(tree: Tree, info: TransformerInfo, cur: Int)(implicit ctx: Context): Tree = { tree match { //split one big match into 2 smaller ones - case tree: NameTree => transformNameTree(tree, info, cur) + case tree: NameTree => transformNamed(tree, info, cur) case tree => transformUnnamed(tree, info, cur) } } @@ -1134,20 +1133,26 @@ object TreeTransforms { } else trees } - def transformStats(trees: List[Tree], info: TransformerInfo, current: Int)(implicit ctx: Context): List[Tree] = { + def transformStats(trees: List[Tree], exprOwner: Symbol, info: TransformerInfo, current: Int)(implicit ctx: Context): List[Tree] = { val newInfo = mutateTransformers(info, prepForStats, info.nx.nxPrepStats, trees, current) - val newTrees = transformL(trees, newInfo, current)(ctx) - flatten(goStats(newTrees, newInfo.nx.nxTransStats(current))(ctx, newInfo)) + val exprCtx = ctx.withOwner(exprOwner) + def transformStat(stat: Tree): Tree = stat match { + case _: Import | _: DefTree => transform(stat, info, current) + case Thicket(stats) => cpy.Thicket(stat, stats mapConserve transformStat) + case _ => transform(stat, info, current)(exprCtx) + } + val newTrees = flatten(trees.mapconserve(transformStat)) + goStats(newTrees, newInfo.nx.nxTransStats(current))(ctx, newInfo) } - def transformL(trees: List[Tree], info: TransformerInfo, current: Int)(implicit ctx: Context): List[Tree] = + def transformTrees(trees: List[Tree], info: TransformerInfo, current: Int)(implicit ctx: Context): List[Tree] = flatten(trees mapConserve (x => transform(x, info, current))) def transformSub[Tr <: Tree](tree: Tr, info: TransformerInfo, current: Int)(implicit ctx: Context): Tr = transform(tree, info, current).asInstanceOf[Tr] - def transformSubL[Tr <: Tree](trees: List[Tr], info: TransformerInfo, current: Int)(implicit ctx: Context): List[Tr] = - transformL(trees, info, current)(ctx).asInstanceOf[List[Tr]] + def transformSubTrees[Tr <: Tree](trees: List[Tr], info: TransformerInfo, current: Int)(implicit ctx: Context): List[Tr] = + transformTrees(trees, info, current)(ctx).asInstanceOf[List[Tr]] } } |