aboutsummaryrefslogtreecommitdiff
path: root/src/dotty/tools/dotc/transform
diff options
context:
space:
mode:
Diffstat (limited to 'src/dotty/tools/dotc/transform')
-rw-r--r--src/dotty/tools/dotc/transform/TreeTransform.scala27
1 files changed, 23 insertions, 4 deletions
diff --git a/src/dotty/tools/dotc/transform/TreeTransform.scala b/src/dotty/tools/dotc/transform/TreeTransform.scala
index a70ab8aed..e12f6a8a6 100644
--- a/src/dotty/tools/dotc/transform/TreeTransform.scala
+++ b/src/dotty/tools/dotc/transform/TreeTransform.scala
@@ -92,6 +92,8 @@ object TreeTransforms {
def prepareForPackageDef(tree: PackageDef)(implicit ctx: Context) = this
def prepareForStats(trees: List[Tree])(implicit ctx: Context) = this
+ def prepareForUnit(tree: Tree)(implicit ctx: Context) = this
+
def transformIdent(tree: Ident)(implicit ctx: Context, info: TransformerInfo): Tree = tree
def transformSelect(tree: Select)(implicit ctx: Context, info: TransformerInfo): Tree = tree
def transformThis(tree: This)(implicit ctx: Context, info: TransformerInfo): Tree = tree
@@ -125,6 +127,8 @@ object TreeTransforms {
def transformStats(trees: List[Tree])(implicit ctx: Context, info: TransformerInfo): List[Tree] = trees
def transformOther(tree: Tree)(implicit ctx: Context, info: TransformerInfo): Tree = tree
+ def transformUnit(tree: Tree)(implicit ctx: Context, info: TransformerInfo): Tree = tree
+
/** Transform tree using all transforms of current group (including this one) */
def transform(tree: Tree)(implicit ctx: Context, info: TransformerInfo): Tree = info.group.transform(tree, info, 0)
@@ -273,6 +277,7 @@ object TreeTransforms {
nxPrepTemplate = index(transformations, "prepareForTemplate")
nxPrepPackageDef = index(transformations, "prepareForPackageDef")
nxPrepStats = index(transformations, "prepareForStats")
+ nxPrepUnit = index(transformations, "prepareForUnit")
nxTransIdent = index(transformations, "transformIdent")
nxTransSelect = index(transformations, "transformSelect")
@@ -305,6 +310,7 @@ object TreeTransforms {
nxTransTemplate = index(transformations, "transformTemplate")
nxTransPackageDef = index(transformations, "transformPackageDef")
nxTransStats = index(transformations, "transformStats")
+ nxTransUnit = index(transformations, "transformUnit")
nxTransOther = index(transformations, "transformOther")
}
@@ -412,6 +418,7 @@ object TreeTransforms {
var nxPrepTemplate: Array[Int] = _
var nxPrepPackageDef: Array[Int] = _
var nxPrepStats: Array[Int] = _
+ var nxPrepUnit: Array[Int] = _
var nxTransIdent: Array[Int] = _
var nxTransSelect: Array[Int] = _
@@ -444,6 +451,7 @@ object TreeTransforms {
var nxTransTemplate: Array[Int] = _
var nxTransPackageDef: Array[Int] = _
var nxTransStats: Array[Int] = _
+ var nxTransUnit: Array[Int] = _
var nxTransOther: Array[Int] = _
}
@@ -454,7 +462,7 @@ object TreeTransforms {
override def run(implicit ctx: Context): Unit = {
val curTree = ctx.compilationUnit.tpdTree
- val newTree = transform(curTree)
+ val newTree = macroTransform(curTree)
ctx.compilationUnit.tpdTree = newTree
}
@@ -517,8 +525,9 @@ object TreeTransforms {
val prepForTemplate: Mutator[Template] = (trans, tree, ctx) => trans.prepareForTemplate(tree)(ctx)
val prepForPackageDef: Mutator[PackageDef] = (trans, tree, ctx) => trans.prepareForPackageDef(tree)(ctx)
val prepForStats: Mutator[List[Tree]] = (trans, trees, ctx) => trans.prepareForStats(trees)(ctx)
+ val prepForUnit: Mutator[Tree] = (trans, tree, ctx) => trans.prepareForUnit(tree)(ctx)
- def transform(t: Tree)(implicit ctx: Context): Tree = {
+ def macroTransform(t: Tree)(implicit ctx: Context): Tree = {
val initialTransformations = transformations
val info = new TransformerInfo(initialTransformations, new NXTransformations(initialTransformations), this)
initialTransformations.zipWithIndex.foreach {
@@ -526,7 +535,9 @@ object TreeTransforms {
transform.idx = id
transform.init(ctx, info)
}
- transform(t, info, 0)
+ implicit val mutatedInfo: TransformerInfo = mutateTransformers(info, prepForUnit, info.nx.nxPrepUnit, t, 0)
+ if (mutatedInfo eq null) t
+ else goUnit(transform(t, mutatedInfo, 0), mutatedInfo.nx.nxTransUnit(0))
}
@tailrec
@@ -859,6 +870,15 @@ object TreeTransforms {
} else tree
}
+ @tailrec
+ final private[TreeTransforms] def goUnit(tree: Tree, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = {
+ if (cur < info.transformers.length) {
+ val trans = info.transformers(cur)
+ val t = trans.transformUnit(tree)(ctx.withPhase(trans.treeTransformPhase), info)
+ goUnit(t, info.nx.nxTransUnit(cur + 1))
+ } else tree
+ }
+
final private[TreeTransforms] def goOther(tree: Tree, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = {
if (cur < info.transformers.length) {
val trans = info.transformers(cur)
@@ -1219,5 +1239,4 @@ object TreeTransforms {
def transformSubTrees[Tr <: Tree](trees: List[Tr], info: TransformerInfo, current: Int)(implicit ctx: Context): List[Tr] =
transformTrees(trees, info, current)(ctx).asInstanceOf[List[Tr]]
}
-
}