diff options
Diffstat (limited to 'src/dotty/tools/dotc/transform')
-rw-r--r-- | src/dotty/tools/dotc/transform/TreeTransform.scala | 27 |
1 files changed, 23 insertions, 4 deletions
diff --git a/src/dotty/tools/dotc/transform/TreeTransform.scala b/src/dotty/tools/dotc/transform/TreeTransform.scala index a70ab8aed..e12f6a8a6 100644 --- a/src/dotty/tools/dotc/transform/TreeTransform.scala +++ b/src/dotty/tools/dotc/transform/TreeTransform.scala @@ -92,6 +92,8 @@ object TreeTransforms { def prepareForPackageDef(tree: PackageDef)(implicit ctx: Context) = this def prepareForStats(trees: List[Tree])(implicit ctx: Context) = this + def prepareForUnit(tree: Tree)(implicit ctx: Context) = this + def transformIdent(tree: Ident)(implicit ctx: Context, info: TransformerInfo): Tree = tree def transformSelect(tree: Select)(implicit ctx: Context, info: TransformerInfo): Tree = tree def transformThis(tree: This)(implicit ctx: Context, info: TransformerInfo): Tree = tree @@ -125,6 +127,8 @@ object TreeTransforms { def transformStats(trees: List[Tree])(implicit ctx: Context, info: TransformerInfo): List[Tree] = trees def transformOther(tree: Tree)(implicit ctx: Context, info: TransformerInfo): Tree = tree + def transformUnit(tree: Tree)(implicit ctx: Context, info: TransformerInfo): Tree = tree + /** Transform tree using all transforms of current group (including this one) */ def transform(tree: Tree)(implicit ctx: Context, info: TransformerInfo): Tree = info.group.transform(tree, info, 0) @@ -273,6 +277,7 @@ object TreeTransforms { nxPrepTemplate = index(transformations, "prepareForTemplate") nxPrepPackageDef = index(transformations, "prepareForPackageDef") nxPrepStats = index(transformations, "prepareForStats") + nxPrepUnit = index(transformations, "prepareForUnit") nxTransIdent = index(transformations, "transformIdent") nxTransSelect = index(transformations, "transformSelect") @@ -305,6 +310,7 @@ object TreeTransforms { nxTransTemplate = index(transformations, "transformTemplate") nxTransPackageDef = index(transformations, "transformPackageDef") nxTransStats = index(transformations, "transformStats") + nxTransUnit = index(transformations, "transformUnit") nxTransOther = index(transformations, "transformOther") } @@ -412,6 +418,7 @@ object TreeTransforms { var nxPrepTemplate: Array[Int] = _ var nxPrepPackageDef: Array[Int] = _ var nxPrepStats: Array[Int] = _ + var nxPrepUnit: Array[Int] = _ var nxTransIdent: Array[Int] = _ var nxTransSelect: Array[Int] = _ @@ -444,6 +451,7 @@ object TreeTransforms { var nxTransTemplate: Array[Int] = _ var nxTransPackageDef: Array[Int] = _ var nxTransStats: Array[Int] = _ + var nxTransUnit: Array[Int] = _ var nxTransOther: Array[Int] = _ } @@ -454,7 +462,7 @@ object TreeTransforms { override def run(implicit ctx: Context): Unit = { val curTree = ctx.compilationUnit.tpdTree - val newTree = transform(curTree) + val newTree = macroTransform(curTree) ctx.compilationUnit.tpdTree = newTree } @@ -517,8 +525,9 @@ object TreeTransforms { val prepForTemplate: Mutator[Template] = (trans, tree, ctx) => trans.prepareForTemplate(tree)(ctx) val prepForPackageDef: Mutator[PackageDef] = (trans, tree, ctx) => trans.prepareForPackageDef(tree)(ctx) val prepForStats: Mutator[List[Tree]] = (trans, trees, ctx) => trans.prepareForStats(trees)(ctx) + val prepForUnit: Mutator[Tree] = (trans, tree, ctx) => trans.prepareForUnit(tree)(ctx) - def transform(t: Tree)(implicit ctx: Context): Tree = { + def macroTransform(t: Tree)(implicit ctx: Context): Tree = { val initialTransformations = transformations val info = new TransformerInfo(initialTransformations, new NXTransformations(initialTransformations), this) initialTransformations.zipWithIndex.foreach { @@ -526,7 +535,9 @@ object TreeTransforms { transform.idx = id transform.init(ctx, info) } - transform(t, info, 0) + implicit val mutatedInfo: TransformerInfo = mutateTransformers(info, prepForUnit, info.nx.nxPrepUnit, t, 0) + if (mutatedInfo eq null) t + else goUnit(transform(t, mutatedInfo, 0), mutatedInfo.nx.nxTransUnit(0)) } @tailrec @@ -859,6 +870,15 @@ object TreeTransforms { } else tree } + @tailrec + final private[TreeTransforms] def goUnit(tree: Tree, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { + if (cur < info.transformers.length) { + val trans = info.transformers(cur) + val t = trans.transformUnit(tree)(ctx.withPhase(trans.treeTransformPhase), info) + goUnit(t, info.nx.nxTransUnit(cur + 1)) + } else tree + } + final private[TreeTransforms] def goOther(tree: Tree, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = { if (cur < info.transformers.length) { val trans = info.transformers(cur) @@ -1219,5 +1239,4 @@ object TreeTransforms { def transformSubTrees[Tr <: Tree](trees: List[Tr], info: TransformerInfo, current: Int)(implicit ctx: Context): List[Tr] = transformTrees(trees, info, current)(ctx).asInstanceOf[List[Tr]] } - } |