diff options
author | Martin Odersky <odersky@gmail.com> | 2014-03-30 14:41:09 +0200 |
---|---|---|
committer | Dmitry Petrashko <dmitry.petrashko@gmail.com> | 2014-03-31 14:52:27 +0200 |
commit | 5c1a1498c7215935d466379d7fa85f88f4a001c7 (patch) | |
tree | 9825be12a157486430f8e0229f811f1f436b9847 /src/dotty/tools/dotc/transform/TreeTransform.scala | |
parent | 9bd1e6a99e1cb09a3527e548699d1561e72e36d3 (diff) | |
download | dotty-5c1a1498c7215935d466379d7fa85f88f4a001c7.tar.gz dotty-5c1a1498c7215935d466379d7fa85f88f4a001c7.tar.bz2 dotty-5c1a1498c7215935d466379d7fa85f88f4a001c7.zip |
Maintaining owners during transformations
The transformation framework needed to be changed so that contexts passed to
transformations have correct owner chains. These owner chins are demanded by
the Splitter phase.
Note: I eliminated the contexts array in TransformInfo because it interfered
with the owner computations. Generally, caching contexts with some phase is best
done in Contexts, because withPhase is also used heavily in othre code, not just in
Transformers.
New phase: Splitter
When it is complete, it will make sure that every term Ident and Select node
carries a symbol. Right now, all it does is coverting self reference idents to
"this"-nodes.
Diffstat (limited to 'src/dotty/tools/dotc/transform/TreeTransform.scala')
-rw-r--r-- | src/dotty/tools/dotc/transform/TreeTransform.scala | 159 |
1 files changed, 82 insertions, 77 deletions
diff --git a/src/dotty/tools/dotc/transform/TreeTransform.scala b/src/dotty/tools/dotc/transform/TreeTransform.scala index 684714199..425410ae7 100644 --- a/src/dotty/tools/dotc/transform/TreeTransform.scala +++ b/src/dotty/tools/dotc/transform/TreeTransform.scala @@ -3,7 +3,9 @@ package dotty.tools.dotc.transform import dotty.tools.dotc.ast.tpd import dotty.tools.dotc.core.Contexts.Context import dotty.tools.dotc.core.Phases.Phase +import dotty.tools.dotc.core.Symbols.Symbol import dotty.tools.dotc.ast.Trees._ +import dotty.tools.dotc.core.Decorators._ import scala.annotation.tailrec object TreeTransforms { @@ -149,9 +151,7 @@ object TreeTransforms { type Mutator[T] = (TreeTransform, T, Context) => TreeTransform - class TransformerInfo(val transformers: Array[TreeTransform], val nx: NXTransformations, val group:TreeTransformer, val contexts:Array[Context]) { - assert(transformers.size == contexts.size) - } + class TransformerInfo(val transformers: Array[TreeTransform], val nx: NXTransformations, val group:TreeTransformer) /** * This class maintains track of which methods are redefined in MiniPhases and creates execution plans for transformXXX and prepareXXX @@ -431,15 +431,15 @@ object TreeTransforms { val l = result.length var allDone = i < l while (i < l) { - val oldT = result(i) - val newT = mutator(oldT, tree, info.contexts(i)) - allDone = allDone && (newT eq NoTransform) - if (!(oldT eq newT)) { + val oldTransform = result(i) + val newTransform = mutator(oldTransform, tree, ctx.withPhase(oldTransform)) + allDone = allDone && (newTransform eq NoTransform) + if (!(oldTransform eq newTransform)) { if (!transformersCopied) result = result.clone() transformersCopied = true - result(i) = newT - if (!(newT.getClass == oldT.getClass)) { - resultNX = new NXTransformations(resultNX, newT, i, nxCopied) + result(i) = newTransform + if (!(newTransform.getClass == oldTransform.getClass)) { + resultNX = new NXTransformations(resultNX, newTransform, i, nxCopied) nxCopied = true } } @@ -447,7 +447,7 @@ object TreeTransforms { } if (allDone) null else if (!transformersCopied) info - else new TransformerInfo(result, resultNX, info.group, info.contexts) + else new TransformerInfo(result, resultNX, info.group) } val prepForIdent: Mutator[Ident] = (trans, tree, ctx) => trans.prepareForIdent(tree)(ctx) @@ -484,8 +484,7 @@ object TreeTransforms { def transform(t: Tree)(implicit ctx: Context): Tree = { val initialTransformations = transformations - val contexts = initialTransformations.map(tr => ctx.withPhase(tr).ctx) - val info = new TransformerInfo(initialTransformations, new NXTransformations(initialTransformations), this, contexts) + val info = new TransformerInfo(initialTransformations, new NXTransformations(initialTransformations), this) initialTransformations.zipWithIndex.foreach{ case (transform, id) => transform.idx = id @@ -498,8 +497,7 @@ object TreeTransforms { final private[TreeTransforms] def goIdent(tree: Ident, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) - - trans.transformIdent(tree)(info.contexts(cur), info) match { + trans.transformIdent(tree)(ctx.withPhase(trans), info) match { case t: Ident => goIdent(t, info.nx.nxTransIdent(cur + 1)) case t => transformSingle(t, cur + 1) } @@ -510,7 +508,7 @@ object TreeTransforms { final private[TreeTransforms] def goSelect(tree: Select, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) - trans.transformSelect(tree)(info.contexts(cur), info) match { + trans.transformSelect(tree)(ctx.withPhase(trans), info) match { case t: Select => goSelect(t, info.nx.nxTransSelect(cur + 1)) case t => transformSingle(t, cur + 1) } @@ -521,7 +519,7 @@ object TreeTransforms { final private[TreeTransforms] def goThis(tree: This, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) - trans.transformThis(tree)(info.contexts(cur), info) match { + trans.transformThis(tree)(ctx.withPhase(trans), info) match { case t: This => goThis(t, info.nx.nxTransThis(cur + 1)) case t => transformSingle(t, cur + 1) } @@ -532,7 +530,7 @@ object TreeTransforms { final private[TreeTransforms] def goSuper(tree: Super, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) - trans.transformSuper(tree)(info.contexts(cur), info) match { + trans.transformSuper(tree)(ctx.withPhase(trans), info) match { case t: Super => goSuper(t, info.nx.nxTransSuper(cur + 1)) case t => transformSingle(t, cur + 1) } @@ -543,7 +541,7 @@ object TreeTransforms { final private[TreeTransforms] def goApply(tree: Apply, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) - trans.transformApply(tree)(info.contexts(cur), info) match { + trans.transformApply(tree)(ctx.withPhase(trans), info) match { case t: Apply => goApply(t, info.nx.nxTransApply(cur + 1)) case t => transformSingle(t, cur + 1) } @@ -554,7 +552,7 @@ object TreeTransforms { final private[TreeTransforms] def goTypeApply(tree: TypeApply, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) - trans.transformTypeApply(tree)(info.contexts(cur), info) match { + trans.transformTypeApply(tree)(ctx.withPhase(trans), info) match { case t: TypeApply => goTypeApply(t, info.nx.nxTransTypeApply(cur + 1)) case t => transformSingle(t, cur + 1) } @@ -565,7 +563,7 @@ object TreeTransforms { final private[TreeTransforms] def goNew(tree: New, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) - trans.transformNew(tree)(info.contexts(cur), info) match { + trans.transformNew(tree)(ctx.withPhase(trans), info) match { case t: New => goNew(t, info.nx.nxTransNew(cur + 1)) case t => transformSingle(t, cur + 1) } @@ -576,7 +574,7 @@ object TreeTransforms { final private[TreeTransforms] def goPair(tree: Pair, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) - trans.transformPair(tree)(info.contexts(cur), info) match { + trans.transformPair(tree)(ctx.withPhase(trans), info) match { case t: Pair => goPair(t, info.nx.nxTransPair(cur + 1)) case t => transformSingle(t, cur + 1) } @@ -587,7 +585,7 @@ object TreeTransforms { final private[TreeTransforms] def goTyped(tree: Typed, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) - trans.transformTyped(tree)(info.contexts(cur), info) match { + trans.transformTyped(tree)(ctx.withPhase(trans), info) match { case t: Typed => goTyped(t, info.nx.nxTransTyped(cur + 1)) case t => transformSingle(t, cur + 1) } @@ -598,7 +596,7 @@ object TreeTransforms { final private[TreeTransforms] def goAssign(tree: Assign, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) - trans.transformAssign(tree)(info.contexts(cur), info) match { + trans.transformAssign(tree)(ctx.withPhase(trans), info) match { case t: Assign => goAssign(t, info.nx.nxTransAssign(cur + 1)) case t => transformSingle(t, cur + 1) } @@ -609,7 +607,7 @@ object TreeTransforms { final private[TreeTransforms] def goLiteral(tree: Literal, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) - trans.transformLiteral(tree)(info.contexts(cur), info) match { + trans.transformLiteral(tree)(ctx.withPhase(trans), info) match { case t: Literal => goLiteral(t, info.nx.nxTransLiteral(cur + 1)) case t => transformSingle(t, cur + 1) } @@ -620,7 +618,7 @@ object TreeTransforms { final private[TreeTransforms] def goBlock(tree: Block, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) - trans.transformBlock(tree)(info.contexts(cur), info) match { + trans.transformBlock(tree)(ctx.withPhase(trans), info) match { case t: Block => goBlock(t, info.nx.nxTransBlock(cur + 1)) case t => transformSingle(t, cur + 1) } @@ -631,7 +629,7 @@ object TreeTransforms { final private[TreeTransforms] def goIf(tree: If, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) - trans.transformIf(tree)(info.contexts(cur), info) match { + trans.transformIf(tree)(ctx.withPhase(trans), info) match { case t: If => goIf(t, info.nx.nxTransIf(cur + 1)) case t => transformSingle(t, cur + 1) } @@ -642,7 +640,7 @@ object TreeTransforms { final private[TreeTransforms] def goClosure(tree: Closure, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) - trans.transformClosure(tree)(info.contexts(cur), info) match { + trans.transformClosure(tree)(ctx.withPhase(trans), info) match { case t: Closure => goClosure(t, info.nx.nxTransClosure(cur + 1)) case t => transformSingle(t, cur + 1) } @@ -653,7 +651,7 @@ object TreeTransforms { final private[TreeTransforms] def goMatch(tree: Match, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) - trans.transformMatch(tree)(info.contexts(cur), info) match { + trans.transformMatch(tree)(ctx.withPhase(trans), info) match { case t: Match => goMatch(t, info.nx.nxTransMatch(cur + 1)) case t => transformSingle(t, cur + 1) } @@ -664,7 +662,7 @@ object TreeTransforms { final private[TreeTransforms] def goCaseDef(tree: CaseDef, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) - trans.transformCaseDef(tree)(info.contexts(cur), info) match { + trans.transformCaseDef(tree)(ctx.withPhase(trans), info) match { case t: CaseDef => goCaseDef(t, info.nx.nxTransCaseDef(cur + 1)) case t => transformSingle(t, cur + 1) } @@ -675,7 +673,7 @@ object TreeTransforms { final private[TreeTransforms] def goReturn(tree: Return, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) - trans.transformReturn(tree)(info.contexts(cur), info) match { + trans.transformReturn(tree)(ctx.withPhase(trans), info) match { case t: Return => goReturn(t, info.nx.nxTransReturn(cur + 1)) case t => transformSingle(t, cur + 1) } @@ -686,7 +684,7 @@ object TreeTransforms { final private[TreeTransforms] def goTry(tree: Try, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) - trans.transformTry(tree)(info.contexts(cur), info) match { + trans.transformTry(tree)(ctx.withPhase(trans), info) match { case t: Try => goTry(t, info.nx.nxTransTry(cur + 1)) case t => transformSingle(t, cur + 1) } @@ -697,7 +695,7 @@ object TreeTransforms { final private[TreeTransforms] def goThrow(tree: Throw, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) - trans.transformThrow(tree)(info.contexts(cur), info) match { + trans.transformThrow(tree)(ctx.withPhase(trans), info) match { case t: Throw => goThrow(t, info.nx.nxTransThrow(cur + 1)) case t => transformSingle(t, cur + 1) } @@ -708,7 +706,7 @@ object TreeTransforms { final private[TreeTransforms] def goSeqLiteral(tree: SeqLiteral, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) - trans.transformSeqLiteral(tree)(info.contexts(cur), info) match { + trans.transformSeqLiteral(tree)(ctx.withPhase(trans), info) match { case t: SeqLiteral => goSeqLiteral(t, info.nx.nxTransSeqLiteral(cur + 1)) case t => transformSingle(t, cur + 1) } @@ -719,7 +717,7 @@ object TreeTransforms { final private[TreeTransforms] def goTypeTree(tree: TypeTree, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) - trans.transformTypeTree(tree)(info.contexts(cur), info) match { + trans.transformTypeTree(tree)(ctx.withPhase(trans), info) match { case t: TypeTree => goTypeTree(t, info.nx.nxTransTypeTree(cur + 1)) case t => transformSingle(t, cur + 1) } @@ -730,7 +728,7 @@ object TreeTransforms { final private[TreeTransforms] def goSelectFromTypeTree(tree: SelectFromTypeTree, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) - trans.transformSelectFromTypeTree(tree)(info.contexts(cur), info) match { + trans.transformSelectFromTypeTree(tree)(ctx.withPhase(trans), info) match { case t: SelectFromTypeTree => goSelectFromTypeTree(t, info.nx.nxTransSelectFromTypeTree(cur + 1)) case t => transformSingle(t, cur + 1) } @@ -741,7 +739,7 @@ object TreeTransforms { final private[TreeTransforms] def goBind(tree: Bind, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) - trans.transformBind(tree)(info.contexts(cur), info) match { + trans.transformBind(tree)(ctx.withPhase(trans), info) match { case t: Bind => goBind(t, info.nx.nxTransBind(cur + 1)) case t => transformSingle(t, cur + 1) } @@ -752,7 +750,7 @@ object TreeTransforms { final private[TreeTransforms] def goAlternative(tree: Alternative, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) - trans.transformAlternative(tree)(info.contexts(cur), info) match { + trans.transformAlternative(tree)(ctx.withPhase(trans), info) match { case t: Alternative => goAlternative(t, info.nx.nxTransAlternative(cur + 1)) case t => transformSingle(t, cur + 1) } @@ -763,7 +761,7 @@ object TreeTransforms { final private[TreeTransforms] def goValDef(tree: ValDef, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) - trans.transformValDef(tree)(info.contexts(cur), info) match { + trans.transformValDef(tree)(ctx.withPhase(trans), info) match { case t: ValDef => goValDef(t, info.nx.nxTransValDef(cur + 1)) case t => transformSingle(t, cur + 1) } @@ -774,7 +772,7 @@ object TreeTransforms { final private[TreeTransforms] def goDefDef(tree: DefDef, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) - trans.transformDefDef(tree)(info.contexts(cur), info) match { + trans.transformDefDef(tree)(ctx.withPhase(trans), info) match { case t: DefDef => goDefDef(t, info.nx.nxTransDefDef(cur + 1)) case t => transformSingle(t, cur + 1) } @@ -785,7 +783,7 @@ object TreeTransforms { final private[TreeTransforms] def goUnApply(tree: UnApply, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) - trans.transformUnApply(tree)(info.contexts(cur), info) match { + trans.transformUnApply(tree)(ctx.withPhase(trans), info) match { case t: UnApply => goUnApply(t, info.nx.nxTransUnApply(cur + 1)) case t => transformSingle(t, cur + 1) } @@ -796,7 +794,7 @@ object TreeTransforms { final private[TreeTransforms] def goTypeDef(tree: TypeDef, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) - trans.transformTypeDef(tree)(info.contexts(cur), info) match { + trans.transformTypeDef(tree)(ctx.withPhase(trans), info) match { case t: TypeDef => goTypeDef(t, info.nx.nxTransTypeDef(cur + 1)) case t => transformSingle(t, cur + 1) } @@ -807,7 +805,7 @@ object TreeTransforms { final private[TreeTransforms] def goTemplate(tree: Template, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) - trans.transformTemplate(tree)(info.contexts(cur), info) match { + trans.transformTemplate(tree)(ctx.withPhase(trans), info) match { case t: Template => goTemplate(t, info.nx.nxTransTemplate(cur + 1)) case t => transformSingle(t, cur + 1) } @@ -818,7 +816,7 @@ object TreeTransforms { final private[TreeTransforms] def goPackageDef(tree: PackageDef, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) - trans.transformPackageDef(tree)(info.contexts(cur), info) match { + trans.transformPackageDef(tree)(ctx.withPhase(trans), info) match { case t: PackageDef => goPackageDef(t, info.nx.nxTransPackageDef(cur + 1)) case t => transformSingle(t, cur + 1) } @@ -863,9 +861,7 @@ object TreeTransforms { case tree: UnApply => goUnApply(tree, info.nx.nxTransUnApply(cur)) case tree: Template => goTemplate(tree, info.nx.nxTransTemplate(cur)) case tree: PackageDef => goPackageDef(tree, info.nx.nxTransPackageDef(cur)) - case Thicket(trees) if trees != Nil => - val trees1 = transformL(trees.asInstanceOf[List[tpd.Tree]], info, cur) - if (trees1 eq trees) tree else Thicket(trees1) + case Thicket(trees) => cpy.Thicket(tree, transformTrees(trees, info, cur)) case tree => tree } @@ -876,7 +872,9 @@ object TreeTransforms { case tree => goUnamed(tree, cur) } - final private[TreeTransforms] def transformNameTree(tree: NameTree, info: TransformerInfo, cur: Int)(implicit ctx: Context): Tree = + def localContext(owner: Symbol)(implicit ctx: Context) = ctx.fresh.setOwner(owner) + + final private[TreeTransforms] def transformNamed(tree: NameTree, info: TransformerInfo, cur: Int)(implicit ctx: Context): Tree = tree match { case tree: Ident => implicit val mutatedInfo = mutateTransformers(info, prepForIdent, info.nx.nxPrepIdent, tree, cur) @@ -907,25 +905,27 @@ object TreeTransforms { implicit val mutatedInfo = mutateTransformers(info, prepForValDef, info.nx.nxPrepValDef, tree, cur) if (mutatedInfo eq null) tree else { - val tpt = transform(tree.tpt, mutatedInfo, cur) - val rhs = transform(tree.rhs, mutatedInfo, cur) + val nestedCtx = if (tree.symbol.exists) localContext(tree.symbol) else ctx + val tpt = transform(tree.tpt, mutatedInfo, cur)(nestedCtx) + val rhs = transform(tree.rhs, mutatedInfo, cur)(nestedCtx) goValDef(cpy.ValDef(tree, tree.mods, tree.name, tpt, rhs), mutatedInfo.nx.nxTransValDef(cur)) } case tree: DefDef => implicit val mutatedInfo = mutateTransformers(info, prepForDefDef, info.nx.nxPrepDefDef, tree, cur) if (mutatedInfo eq null) tree else { - val tparams = transformSubL(tree.tparams, mutatedInfo, cur) - val vparams = tree.vparamss.mapConserve(x => transformSubL(x, mutatedInfo, cur)) - val tpt = transform(tree.tpt, mutatedInfo, cur) - val rhs = transform(tree.rhs, mutatedInfo, cur) + val nestedCtx = localContext(tree.symbol) + val tparams = transformSubTrees(tree.tparams, mutatedInfo, cur)(nestedCtx) + val vparams = tree.vparamss.mapConserve(x => transformSubTrees(x, mutatedInfo, cur)(nestedCtx)) + val tpt = transform(tree.tpt, mutatedInfo, cur)(nestedCtx) + val rhs = transform(tree.rhs, mutatedInfo, cur)(nestedCtx) goDefDef(cpy.DefDef(tree, tree.mods, tree.name, tparams, vparams, tpt, rhs), mutatedInfo.nx.nxTransDefDef(cur)) } case tree: TypeDef => implicit val mutatedInfo = mutateTransformers(info, prepForTypeDef, info.nx.nxPrepTypeDef, tree, cur) if (mutatedInfo eq null) tree else { - val rhs = transform(tree.rhs, mutatedInfo, cur) + val rhs = transform(tree.rhs, mutatedInfo, cur)(localContext(tree.symbol)) goTypeDef(cpy.TypeDef(tree, tree.mods, tree.name, rhs, tree.tparams), mutatedInfo.nx.nxTransTypeDef(cur)) } case _ => @@ -950,7 +950,7 @@ object TreeTransforms { if (mutatedInfo eq null) tree else { val fun = transform(tree.fun, mutatedInfo, cur) - val args = transformSubL(tree.args, mutatedInfo, cur) + val args = transformSubTrees(tree.args, mutatedInfo, cur) goApply(cpy.Apply(tree, fun, args), mutatedInfo.nx.nxTransApply(cur)) } case tree: TypeApply => @@ -958,7 +958,7 @@ object TreeTransforms { if (mutatedInfo eq null) tree else { val fun = transform(tree.fun, mutatedInfo, cur) - val args = transformL(tree.args, mutatedInfo, cur) + val args = transformTrees(tree.args, mutatedInfo, cur) goTypeApply(cpy.TypeApply(tree, fun, args), mutatedInfo.nx.nxTransTypeApply(cur)) } case tree: Literal => @@ -1000,7 +1000,7 @@ object TreeTransforms { implicit val mutatedInfo = mutateTransformers(info, prepForBlock, info.nx.nxPrepBlock, tree, cur) if (mutatedInfo eq null) tree else { - val stats = transformStats(tree.stats, mutatedInfo, cur) + val stats = transformStats(tree.stats, ctx.owner, mutatedInfo, cur) val expr = transform(tree.expr, mutatedInfo, cur) goBlock(cpy.Block(tree, stats, expr), mutatedInfo.nx.nxTransBlock(cur)) } @@ -1017,7 +1017,7 @@ object TreeTransforms { implicit val mutatedInfo = mutateTransformers(info, prepForClosure, info.nx.nxPrepClosure, tree, cur) if (mutatedInfo eq null) tree else { - val env = transformL(tree.env, mutatedInfo, cur) + val env = transformTrees(tree.env, mutatedInfo, cur) val meth = transform(tree.meth, mutatedInfo, cur) val tpt = transform(tree.tpt, mutatedInfo, cur) goClosure(cpy.Closure(tree, env, meth, tpt), mutatedInfo.nx.nxTransClosure(cur)) @@ -1027,7 +1027,7 @@ object TreeTransforms { if (mutatedInfo eq null) tree else { val selector = transform(tree.selector, mutatedInfo, cur) - val cases = transformSubL(tree.cases, mutatedInfo, cur) + val cases = transformSubTrees(tree.cases, mutatedInfo, cur) goMatch(cpy.Match(tree, selector, cases), mutatedInfo.nx.nxTransMatch(cur)) } case tree: CaseDef => @@ -1067,7 +1067,7 @@ object TreeTransforms { implicit val mutatedInfo = mutateTransformers(info, prepForSeqLiteral, info.nx.nxPrepSeqLiteral, tree, cur) if (mutatedInfo eq null) tree else { - val elems = transformL(tree.elems, mutatedInfo, cur) + val elems = transformTrees(tree.elems, mutatedInfo, cur) goSeqLiteral(cpy.SeqLiteral(tree, elems), mutatedInfo.nx.nxTransLiteral(cur)) } case tree: TypeTree => @@ -1081,7 +1081,7 @@ object TreeTransforms { implicit val mutatedInfo = mutateTransformers(info, prepForAlternative, info.nx.nxPrepAlternative, tree, cur) if (mutatedInfo eq null) tree else { - val trees = transformL(tree.trees, mutatedInfo, cur) + val trees = transformTrees(tree.trees, mutatedInfo, cur) goAlternative(cpy.Alternative(tree, trees), mutatedInfo.nx.nxTransAlternative(cur)) } case tree: UnApply => @@ -1089,8 +1089,8 @@ object TreeTransforms { if (mutatedInfo eq null) tree else { val fun = transform(tree.fun, mutatedInfo, cur) - val implicits = transformL(tree.implicits, mutatedInfo, cur) - val patterns = transformL(tree.patterns, mutatedInfo, cur) + val implicits = transformTrees(tree.implicits, mutatedInfo, cur) + val patterns = transformTrees(tree.patterns, mutatedInfo, cur) goUnApply(cpy.UnApply(tree, fun, implicits, patterns), mutatedInfo.nx.nxTransUnApply(cur)) } case tree: Template => @@ -1098,29 +1098,28 @@ object TreeTransforms { if (mutatedInfo eq null) tree else { val constr = transformSub(tree.constr, mutatedInfo, cur) - val parents = transformL(tree.parents, mutatedInfo, cur) + val parents = transformTrees(tree.parents, mutatedInfo, cur) val self = transformSub(tree.self, mutatedInfo, cur) - val body = transformStats(tree.body, mutatedInfo, cur) + val body = transformStats(tree.body, tree.symbol, mutatedInfo, cur) goTemplate(cpy.Template(tree, constr, parents, self, body), mutatedInfo.nx.nxTransTemplate(cur)) } case tree: PackageDef => implicit val mutatedInfo = mutateTransformers(info, prepForPackageDef, info.nx.nxPrepPackageDef, tree, cur) if (mutatedInfo eq null) tree else { + val nestedCtx = localContext(tree.symbol) val pid = transformSub(tree.pid, mutatedInfo, cur) - val stats = transformStats(tree.stats, mutatedInfo, cur) + val stats = transformStats(tree.stats, tree.symbol, mutatedInfo, cur)(nestedCtx) goPackageDef(cpy.PackageDef(tree, pid, stats), mutatedInfo.nx.nxTransPackageDef(cur)) } - case Thicket(trees) if trees != Nil => - val trees1 = transformL(trees.asInstanceOf[List[tpd.Tree]], info, cur) - if (trees1 eq trees) tree else Thicket(trees1) + case Thicket(trees) => cpy.Thicket(tree, transformTrees(trees, info, cur)) case tree => tree } def transform(tree: Tree, info: TransformerInfo, cur: Int)(implicit ctx: Context): Tree = { tree match { //split one big match into 2 smaller ones - case tree: NameTree => transformNameTree(tree, info, cur) + case tree: NameTree => transformNamed(tree, info, cur) case tree => transformUnnamed(tree, info, cur) } } @@ -1134,20 +1133,26 @@ object TreeTransforms { } else trees } - def transformStats(trees: List[Tree], info: TransformerInfo, current: Int)(implicit ctx: Context): List[Tree] = { + def transformStats(trees: List[Tree], exprOwner: Symbol, info: TransformerInfo, current: Int)(implicit ctx: Context): List[Tree] = { val newInfo = mutateTransformers(info, prepForStats, info.nx.nxPrepStats, trees, current) - val newTrees = transformL(trees, newInfo, current)(ctx) - flatten(goStats(newTrees, newInfo.nx.nxTransStats(current))(ctx, newInfo)) + val exprCtx = ctx.withOwner(exprOwner) + def transformStat(stat: Tree): Tree = stat match { + case _: Import | _: DefTree => transform(stat, info, current) + case Thicket(stats) => cpy.Thicket(stat, stats mapConserve transformStat) + case _ => transform(stat, info, current)(exprCtx) + } + val newTrees = flatten(trees.mapconserve(transformStat)) + goStats(newTrees, newInfo.nx.nxTransStats(current))(ctx, newInfo) } - def transformL(trees: List[Tree], info: TransformerInfo, current: Int)(implicit ctx: Context): List[Tree] = + def transformTrees(trees: List[Tree], info: TransformerInfo, current: Int)(implicit ctx: Context): List[Tree] = flatten(trees mapConserve (x => transform(x, info, current))) def transformSub[Tr <: Tree](tree: Tr, info: TransformerInfo, current: Int)(implicit ctx: Context): Tr = transform(tree, info, current).asInstanceOf[Tr] - def transformSubL[Tr <: Tree](trees: List[Tr], info: TransformerInfo, current: Int)(implicit ctx: Context): List[Tr] = - transformL(trees, info, current)(ctx).asInstanceOf[List[Tr]] + def transformSubTrees[Tr <: Tree](trees: List[Tr], info: TransformerInfo, current: Int)(implicit ctx: Context): List[Tr] = + transformTrees(trees, info, current)(ctx).asInstanceOf[List[Tr]] } } |