aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/dotty/tools/dotc/Compiler.scala2
-rw-r--r--src/dotty/tools/dotc/core/Contexts.scala3
-rw-r--r--src/dotty/tools/dotc/transform/PostTyperTransformers.scala4
-rw-r--r--src/dotty/tools/dotc/transform/Splitter.scala27
-rw-r--r--src/dotty/tools/dotc/transform/TreeTransform.scala159
5 files changed, 115 insertions, 80 deletions
diff --git a/src/dotty/tools/dotc/Compiler.scala b/src/dotty/tools/dotc/Compiler.scala
index 6fd69beb8..1e8f13578 100644
--- a/src/dotty/tools/dotc/Compiler.scala
+++ b/src/dotty/tools/dotc/Compiler.scala
@@ -21,7 +21,7 @@ class Compiler {
List(
List(new FrontEnd),
List(new LazyValsCreateCompanionObjects, new PatternMatcher), //force separataion between lazyVals and LVCreateCO
- List(new LazyValTranformContext().transformer, new TypeTestsCasts),
+ List(new LazyValTranformContext().transformer, new Splitter, new TypeTestsCasts),
List(new Erasure),
List(new UncurryTreeTransform)
)
diff --git a/src/dotty/tools/dotc/core/Contexts.scala b/src/dotty/tools/dotc/core/Contexts.scala
index 8d083b29c..8721bc548 100644
--- a/src/dotty/tools/dotc/core/Contexts.scala
+++ b/src/dotty/tools/dotc/core/Contexts.scala
@@ -277,6 +277,9 @@ object Contexts {
newctx.asInstanceOf[FreshContext]
}
+ final def withOwner(owner: Symbol): Context =
+ if (owner ne this.owner) fresh.setOwner(owner) else this
+
final def withMode(mode: Mode): Context =
if (mode != this.mode) fresh.setMode(mode) else this
diff --git a/src/dotty/tools/dotc/transform/PostTyperTransformers.scala b/src/dotty/tools/dotc/transform/PostTyperTransformers.scala
index 14e2cf35d..25f122cf5 100644
--- a/src/dotty/tools/dotc/transform/PostTyperTransformers.scala
+++ b/src/dotty/tools/dotc/transform/PostTyperTransformers.scala
@@ -48,8 +48,8 @@ object PostTyperTransformers {
reorder0(stats)
}
- override def transformStats(trees: List[tpd.Tree], info: TransformerInfo, current: Int)(implicit ctx: Context): List[tpd.Tree] =
- super.transformStats(reorder(trees)(ctx, info), info, current)
+ override def transformStats(trees: List[tpd.Tree], exprOwner: Symbol, info: TransformerInfo, current: Int)(implicit ctx: Context): List[tpd.Tree] =
+ super.transformStats(reorder(trees)(ctx, info), exprOwner, info, current)
override def transform(tree: tpd.Tree, info: TransformerInfo, cur: Int)(implicit ctx: Context): tpd.Tree = tree match {
case tree: Import => EmptyTree
diff --git a/src/dotty/tools/dotc/transform/Splitter.scala b/src/dotty/tools/dotc/transform/Splitter.scala
new file mode 100644
index 000000000..9c01574aa
--- /dev/null
+++ b/src/dotty/tools/dotc/transform/Splitter.scala
@@ -0,0 +1,27 @@
+package dotty.tools.dotc
+package transform
+
+import TreeTransforms._
+import ast.Trees._
+import core.Contexts._
+import core.Types._
+
+/** This transform makes usre every identifier and select node
+ * carries a symbol. To do this, certain qualifiers with a union type
+ * have to be "splitted" with a type test.
+ *
+ * For now, only self references are treated.
+ */
+class Splitter extends TreeTransform {
+ import ast.tpd._
+
+ override def name: String = "splitter"
+
+ /** Replace self referencing idents with ThisTypes. */
+ override def transformIdent(tree: Ident)(implicit ctx: Context, info: TransformerInfo) = tree.tpe match {
+ case ThisType(cls) =>
+ println(s"owner = ${ctx.owner}, context = ${ctx}")
+ This(cls) withPos tree.pos
+ case _ => tree
+ }
+} \ No newline at end of file
diff --git a/src/dotty/tools/dotc/transform/TreeTransform.scala b/src/dotty/tools/dotc/transform/TreeTransform.scala
index 684714199..425410ae7 100644
--- a/src/dotty/tools/dotc/transform/TreeTransform.scala
+++ b/src/dotty/tools/dotc/transform/TreeTransform.scala
@@ -3,7 +3,9 @@ package dotty.tools.dotc.transform
import dotty.tools.dotc.ast.tpd
import dotty.tools.dotc.core.Contexts.Context
import dotty.tools.dotc.core.Phases.Phase
+import dotty.tools.dotc.core.Symbols.Symbol
import dotty.tools.dotc.ast.Trees._
+import dotty.tools.dotc.core.Decorators._
import scala.annotation.tailrec
object TreeTransforms {
@@ -149,9 +151,7 @@ object TreeTransforms {
type Mutator[T] = (TreeTransform, T, Context) => TreeTransform
- class TransformerInfo(val transformers: Array[TreeTransform], val nx: NXTransformations, val group:TreeTransformer, val contexts:Array[Context]) {
- assert(transformers.size == contexts.size)
- }
+ class TransformerInfo(val transformers: Array[TreeTransform], val nx: NXTransformations, val group:TreeTransformer)
/**
* This class maintains track of which methods are redefined in MiniPhases and creates execution plans for transformXXX and prepareXXX
@@ -431,15 +431,15 @@ object TreeTransforms {
val l = result.length
var allDone = i < l
while (i < l) {
- val oldT = result(i)
- val newT = mutator(oldT, tree, info.contexts(i))
- allDone = allDone && (newT eq NoTransform)
- if (!(oldT eq newT)) {
+ val oldTransform = result(i)
+ val newTransform = mutator(oldTransform, tree, ctx.withPhase(oldTransform))
+ allDone = allDone && (newTransform eq NoTransform)
+ if (!(oldTransform eq newTransform)) {
if (!transformersCopied) result = result.clone()
transformersCopied = true
- result(i) = newT
- if (!(newT.getClass == oldT.getClass)) {
- resultNX = new NXTransformations(resultNX, newT, i, nxCopied)
+ result(i) = newTransform
+ if (!(newTransform.getClass == oldTransform.getClass)) {
+ resultNX = new NXTransformations(resultNX, newTransform, i, nxCopied)
nxCopied = true
}
}
@@ -447,7 +447,7 @@ object TreeTransforms {
}
if (allDone) null
else if (!transformersCopied) info
- else new TransformerInfo(result, resultNX, info.group, info.contexts)
+ else new TransformerInfo(result, resultNX, info.group)
}
val prepForIdent: Mutator[Ident] = (trans, tree, ctx) => trans.prepareForIdent(tree)(ctx)
@@ -484,8 +484,7 @@ object TreeTransforms {
def transform(t: Tree)(implicit ctx: Context): Tree = {
val initialTransformations = transformations
- val contexts = initialTransformations.map(tr => ctx.withPhase(tr).ctx)
- val info = new TransformerInfo(initialTransformations, new NXTransformations(initialTransformations), this, contexts)
+ val info = new TransformerInfo(initialTransformations, new NXTransformations(initialTransformations), this)
initialTransformations.zipWithIndex.foreach{
case (transform, id) =>
transform.idx = id
@@ -498,8 +497,7 @@ object TreeTransforms {
final private[TreeTransforms] def goIdent(tree: Ident, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = {
if (cur < info.transformers.length) {
val trans = info.transformers(cur)
-
- trans.transformIdent(tree)(info.contexts(cur), info) match {
+ trans.transformIdent(tree)(ctx.withPhase(trans), info) match {
case t: Ident => goIdent(t, info.nx.nxTransIdent(cur + 1))
case t => transformSingle(t, cur + 1)
}
@@ -510,7 +508,7 @@ object TreeTransforms {
final private[TreeTransforms] def goSelect(tree: Select, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = {
if (cur < info.transformers.length) {
val trans = info.transformers(cur)
- trans.transformSelect(tree)(info.contexts(cur), info) match {
+ trans.transformSelect(tree)(ctx.withPhase(trans), info) match {
case t: Select => goSelect(t, info.nx.nxTransSelect(cur + 1))
case t => transformSingle(t, cur + 1)
}
@@ -521,7 +519,7 @@ object TreeTransforms {
final private[TreeTransforms] def goThis(tree: This, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = {
if (cur < info.transformers.length) {
val trans = info.transformers(cur)
- trans.transformThis(tree)(info.contexts(cur), info) match {
+ trans.transformThis(tree)(ctx.withPhase(trans), info) match {
case t: This => goThis(t, info.nx.nxTransThis(cur + 1))
case t => transformSingle(t, cur + 1)
}
@@ -532,7 +530,7 @@ object TreeTransforms {
final private[TreeTransforms] def goSuper(tree: Super, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = {
if (cur < info.transformers.length) {
val trans = info.transformers(cur)
- trans.transformSuper(tree)(info.contexts(cur), info) match {
+ trans.transformSuper(tree)(ctx.withPhase(trans), info) match {
case t: Super => goSuper(t, info.nx.nxTransSuper(cur + 1))
case t => transformSingle(t, cur + 1)
}
@@ -543,7 +541,7 @@ object TreeTransforms {
final private[TreeTransforms] def goApply(tree: Apply, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = {
if (cur < info.transformers.length) {
val trans = info.transformers(cur)
- trans.transformApply(tree)(info.contexts(cur), info) match {
+ trans.transformApply(tree)(ctx.withPhase(trans), info) match {
case t: Apply => goApply(t, info.nx.nxTransApply(cur + 1))
case t => transformSingle(t, cur + 1)
}
@@ -554,7 +552,7 @@ object TreeTransforms {
final private[TreeTransforms] def goTypeApply(tree: TypeApply, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = {
if (cur < info.transformers.length) {
val trans = info.transformers(cur)
- trans.transformTypeApply(tree)(info.contexts(cur), info) match {
+ trans.transformTypeApply(tree)(ctx.withPhase(trans), info) match {
case t: TypeApply => goTypeApply(t, info.nx.nxTransTypeApply(cur + 1))
case t => transformSingle(t, cur + 1)
}
@@ -565,7 +563,7 @@ object TreeTransforms {
final private[TreeTransforms] def goNew(tree: New, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = {
if (cur < info.transformers.length) {
val trans = info.transformers(cur)
- trans.transformNew(tree)(info.contexts(cur), info) match {
+ trans.transformNew(tree)(ctx.withPhase(trans), info) match {
case t: New => goNew(t, info.nx.nxTransNew(cur + 1))
case t => transformSingle(t, cur + 1)
}
@@ -576,7 +574,7 @@ object TreeTransforms {
final private[TreeTransforms] def goPair(tree: Pair, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = {
if (cur < info.transformers.length) {
val trans = info.transformers(cur)
- trans.transformPair(tree)(info.contexts(cur), info) match {
+ trans.transformPair(tree)(ctx.withPhase(trans), info) match {
case t: Pair => goPair(t, info.nx.nxTransPair(cur + 1))
case t => transformSingle(t, cur + 1)
}
@@ -587,7 +585,7 @@ object TreeTransforms {
final private[TreeTransforms] def goTyped(tree: Typed, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = {
if (cur < info.transformers.length) {
val trans = info.transformers(cur)
- trans.transformTyped(tree)(info.contexts(cur), info) match {
+ trans.transformTyped(tree)(ctx.withPhase(trans), info) match {
case t: Typed => goTyped(t, info.nx.nxTransTyped(cur + 1))
case t => transformSingle(t, cur + 1)
}
@@ -598,7 +596,7 @@ object TreeTransforms {
final private[TreeTransforms] def goAssign(tree: Assign, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = {
if (cur < info.transformers.length) {
val trans = info.transformers(cur)
- trans.transformAssign(tree)(info.contexts(cur), info) match {
+ trans.transformAssign(tree)(ctx.withPhase(trans), info) match {
case t: Assign => goAssign(t, info.nx.nxTransAssign(cur + 1))
case t => transformSingle(t, cur + 1)
}
@@ -609,7 +607,7 @@ object TreeTransforms {
final private[TreeTransforms] def goLiteral(tree: Literal, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = {
if (cur < info.transformers.length) {
val trans = info.transformers(cur)
- trans.transformLiteral(tree)(info.contexts(cur), info) match {
+ trans.transformLiteral(tree)(ctx.withPhase(trans), info) match {
case t: Literal => goLiteral(t, info.nx.nxTransLiteral(cur + 1))
case t => transformSingle(t, cur + 1)
}
@@ -620,7 +618,7 @@ object TreeTransforms {
final private[TreeTransforms] def goBlock(tree: Block, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = {
if (cur < info.transformers.length) {
val trans = info.transformers(cur)
- trans.transformBlock(tree)(info.contexts(cur), info) match {
+ trans.transformBlock(tree)(ctx.withPhase(trans), info) match {
case t: Block => goBlock(t, info.nx.nxTransBlock(cur + 1))
case t => transformSingle(t, cur + 1)
}
@@ -631,7 +629,7 @@ object TreeTransforms {
final private[TreeTransforms] def goIf(tree: If, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = {
if (cur < info.transformers.length) {
val trans = info.transformers(cur)
- trans.transformIf(tree)(info.contexts(cur), info) match {
+ trans.transformIf(tree)(ctx.withPhase(trans), info) match {
case t: If => goIf(t, info.nx.nxTransIf(cur + 1))
case t => transformSingle(t, cur + 1)
}
@@ -642,7 +640,7 @@ object TreeTransforms {
final private[TreeTransforms] def goClosure(tree: Closure, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = {
if (cur < info.transformers.length) {
val trans = info.transformers(cur)
- trans.transformClosure(tree)(info.contexts(cur), info) match {
+ trans.transformClosure(tree)(ctx.withPhase(trans), info) match {
case t: Closure => goClosure(t, info.nx.nxTransClosure(cur + 1))
case t => transformSingle(t, cur + 1)
}
@@ -653,7 +651,7 @@ object TreeTransforms {
final private[TreeTransforms] def goMatch(tree: Match, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = {
if (cur < info.transformers.length) {
val trans = info.transformers(cur)
- trans.transformMatch(tree)(info.contexts(cur), info) match {
+ trans.transformMatch(tree)(ctx.withPhase(trans), info) match {
case t: Match => goMatch(t, info.nx.nxTransMatch(cur + 1))
case t => transformSingle(t, cur + 1)
}
@@ -664,7 +662,7 @@ object TreeTransforms {
final private[TreeTransforms] def goCaseDef(tree: CaseDef, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = {
if (cur < info.transformers.length) {
val trans = info.transformers(cur)
- trans.transformCaseDef(tree)(info.contexts(cur), info) match {
+ trans.transformCaseDef(tree)(ctx.withPhase(trans), info) match {
case t: CaseDef => goCaseDef(t, info.nx.nxTransCaseDef(cur + 1))
case t => transformSingle(t, cur + 1)
}
@@ -675,7 +673,7 @@ object TreeTransforms {
final private[TreeTransforms] def goReturn(tree: Return, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = {
if (cur < info.transformers.length) {
val trans = info.transformers(cur)
- trans.transformReturn(tree)(info.contexts(cur), info) match {
+ trans.transformReturn(tree)(ctx.withPhase(trans), info) match {
case t: Return => goReturn(t, info.nx.nxTransReturn(cur + 1))
case t => transformSingle(t, cur + 1)
}
@@ -686,7 +684,7 @@ object TreeTransforms {
final private[TreeTransforms] def goTry(tree: Try, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = {
if (cur < info.transformers.length) {
val trans = info.transformers(cur)
- trans.transformTry(tree)(info.contexts(cur), info) match {
+ trans.transformTry(tree)(ctx.withPhase(trans), info) match {
case t: Try => goTry(t, info.nx.nxTransTry(cur + 1))
case t => transformSingle(t, cur + 1)
}
@@ -697,7 +695,7 @@ object TreeTransforms {
final private[TreeTransforms] def goThrow(tree: Throw, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = {
if (cur < info.transformers.length) {
val trans = info.transformers(cur)
- trans.transformThrow(tree)(info.contexts(cur), info) match {
+ trans.transformThrow(tree)(ctx.withPhase(trans), info) match {
case t: Throw => goThrow(t, info.nx.nxTransThrow(cur + 1))
case t => transformSingle(t, cur + 1)
}
@@ -708,7 +706,7 @@ object TreeTransforms {
final private[TreeTransforms] def goSeqLiteral(tree: SeqLiteral, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = {
if (cur < info.transformers.length) {
val trans = info.transformers(cur)
- trans.transformSeqLiteral(tree)(info.contexts(cur), info) match {
+ trans.transformSeqLiteral(tree)(ctx.withPhase(trans), info) match {
case t: SeqLiteral => goSeqLiteral(t, info.nx.nxTransSeqLiteral(cur + 1))
case t => transformSingle(t, cur + 1)
}
@@ -719,7 +717,7 @@ object TreeTransforms {
final private[TreeTransforms] def goTypeTree(tree: TypeTree, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = {
if (cur < info.transformers.length) {
val trans = info.transformers(cur)
- trans.transformTypeTree(tree)(info.contexts(cur), info) match {
+ trans.transformTypeTree(tree)(ctx.withPhase(trans), info) match {
case t: TypeTree => goTypeTree(t, info.nx.nxTransTypeTree(cur + 1))
case t => transformSingle(t, cur + 1)
}
@@ -730,7 +728,7 @@ object TreeTransforms {
final private[TreeTransforms] def goSelectFromTypeTree(tree: SelectFromTypeTree, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = {
if (cur < info.transformers.length) {
val trans = info.transformers(cur)
- trans.transformSelectFromTypeTree(tree)(info.contexts(cur), info) match {
+ trans.transformSelectFromTypeTree(tree)(ctx.withPhase(trans), info) match {
case t: SelectFromTypeTree => goSelectFromTypeTree(t, info.nx.nxTransSelectFromTypeTree(cur + 1))
case t => transformSingle(t, cur + 1)
}
@@ -741,7 +739,7 @@ object TreeTransforms {
final private[TreeTransforms] def goBind(tree: Bind, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = {
if (cur < info.transformers.length) {
val trans = info.transformers(cur)
- trans.transformBind(tree)(info.contexts(cur), info) match {
+ trans.transformBind(tree)(ctx.withPhase(trans), info) match {
case t: Bind => goBind(t, info.nx.nxTransBind(cur + 1))
case t => transformSingle(t, cur + 1)
}
@@ -752,7 +750,7 @@ object TreeTransforms {
final private[TreeTransforms] def goAlternative(tree: Alternative, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = {
if (cur < info.transformers.length) {
val trans = info.transformers(cur)
- trans.transformAlternative(tree)(info.contexts(cur), info) match {
+ trans.transformAlternative(tree)(ctx.withPhase(trans), info) match {
case t: Alternative => goAlternative(t, info.nx.nxTransAlternative(cur + 1))
case t => transformSingle(t, cur + 1)
}
@@ -763,7 +761,7 @@ object TreeTransforms {
final private[TreeTransforms] def goValDef(tree: ValDef, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = {
if (cur < info.transformers.length) {
val trans = info.transformers(cur)
- trans.transformValDef(tree)(info.contexts(cur), info) match {
+ trans.transformValDef(tree)(ctx.withPhase(trans), info) match {
case t: ValDef => goValDef(t, info.nx.nxTransValDef(cur + 1))
case t => transformSingle(t, cur + 1)
}
@@ -774,7 +772,7 @@ object TreeTransforms {
final private[TreeTransforms] def goDefDef(tree: DefDef, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = {
if (cur < info.transformers.length) {
val trans = info.transformers(cur)
- trans.transformDefDef(tree)(info.contexts(cur), info) match {
+ trans.transformDefDef(tree)(ctx.withPhase(trans), info) match {
case t: DefDef => goDefDef(t, info.nx.nxTransDefDef(cur + 1))
case t => transformSingle(t, cur + 1)
}
@@ -785,7 +783,7 @@ object TreeTransforms {
final private[TreeTransforms] def goUnApply(tree: UnApply, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = {
if (cur < info.transformers.length) {
val trans = info.transformers(cur)
- trans.transformUnApply(tree)(info.contexts(cur), info) match {
+ trans.transformUnApply(tree)(ctx.withPhase(trans), info) match {
case t: UnApply => goUnApply(t, info.nx.nxTransUnApply(cur + 1))
case t => transformSingle(t, cur + 1)
}
@@ -796,7 +794,7 @@ object TreeTransforms {
final private[TreeTransforms] def goTypeDef(tree: TypeDef, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = {
if (cur < info.transformers.length) {
val trans = info.transformers(cur)
- trans.transformTypeDef(tree)(info.contexts(cur), info) match {
+ trans.transformTypeDef(tree)(ctx.withPhase(trans), info) match {
case t: TypeDef => goTypeDef(t, info.nx.nxTransTypeDef(cur + 1))
case t => transformSingle(t, cur + 1)
}
@@ -807,7 +805,7 @@ object TreeTransforms {
final private[TreeTransforms] def goTemplate(tree: Template, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = {
if (cur < info.transformers.length) {
val trans = info.transformers(cur)
- trans.transformTemplate(tree)(info.contexts(cur), info) match {
+ trans.transformTemplate(tree)(ctx.withPhase(trans), info) match {
case t: Template => goTemplate(t, info.nx.nxTransTemplate(cur + 1))
case t => transformSingle(t, cur + 1)
}
@@ -818,7 +816,7 @@ object TreeTransforms {
final private[TreeTransforms] def goPackageDef(tree: PackageDef, cur: Int)(implicit ctx: Context, info: TransformerInfo): Tree = {
if (cur < info.transformers.length) {
val trans = info.transformers(cur)
- trans.transformPackageDef(tree)(info.contexts(cur), info) match {
+ trans.transformPackageDef(tree)(ctx.withPhase(trans), info) match {
case t: PackageDef => goPackageDef(t, info.nx.nxTransPackageDef(cur + 1))
case t => transformSingle(t, cur + 1)
}
@@ -863,9 +861,7 @@ object TreeTransforms {
case tree: UnApply => goUnApply(tree, info.nx.nxTransUnApply(cur))
case tree: Template => goTemplate(tree, info.nx.nxTransTemplate(cur))
case tree: PackageDef => goPackageDef(tree, info.nx.nxTransPackageDef(cur))
- case Thicket(trees) if trees != Nil =>
- val trees1 = transformL(trees.asInstanceOf[List[tpd.Tree]], info, cur)
- if (trees1 eq trees) tree else Thicket(trees1)
+ case Thicket(trees) => cpy.Thicket(tree, transformTrees(trees, info, cur))
case tree => tree
}
@@ -876,7 +872,9 @@ object TreeTransforms {
case tree => goUnamed(tree, cur)
}
- final private[TreeTransforms] def transformNameTree(tree: NameTree, info: TransformerInfo, cur: Int)(implicit ctx: Context): Tree =
+ def localContext(owner: Symbol)(implicit ctx: Context) = ctx.fresh.setOwner(owner)
+
+ final private[TreeTransforms] def transformNamed(tree: NameTree, info: TransformerInfo, cur: Int)(implicit ctx: Context): Tree =
tree match {
case tree: Ident =>
implicit val mutatedInfo = mutateTransformers(info, prepForIdent, info.nx.nxPrepIdent, tree, cur)
@@ -907,25 +905,27 @@ object TreeTransforms {
implicit val mutatedInfo = mutateTransformers(info, prepForValDef, info.nx.nxPrepValDef, tree, cur)
if (mutatedInfo eq null) tree
else {
- val tpt = transform(tree.tpt, mutatedInfo, cur)
- val rhs = transform(tree.rhs, mutatedInfo, cur)
+ val nestedCtx = if (tree.symbol.exists) localContext(tree.symbol) else ctx
+ val tpt = transform(tree.tpt, mutatedInfo, cur)(nestedCtx)
+ val rhs = transform(tree.rhs, mutatedInfo, cur)(nestedCtx)
goValDef(cpy.ValDef(tree, tree.mods, tree.name, tpt, rhs), mutatedInfo.nx.nxTransValDef(cur))
}
case tree: DefDef =>
implicit val mutatedInfo = mutateTransformers(info, prepForDefDef, info.nx.nxPrepDefDef, tree, cur)
if (mutatedInfo eq null) tree
else {
- val tparams = transformSubL(tree.tparams, mutatedInfo, cur)
- val vparams = tree.vparamss.mapConserve(x => transformSubL(x, mutatedInfo, cur))
- val tpt = transform(tree.tpt, mutatedInfo, cur)
- val rhs = transform(tree.rhs, mutatedInfo, cur)
+ val nestedCtx = localContext(tree.symbol)
+ val tparams = transformSubTrees(tree.tparams, mutatedInfo, cur)(nestedCtx)
+ val vparams = tree.vparamss.mapConserve(x => transformSubTrees(x, mutatedInfo, cur)(nestedCtx))
+ val tpt = transform(tree.tpt, mutatedInfo, cur)(nestedCtx)
+ val rhs = transform(tree.rhs, mutatedInfo, cur)(nestedCtx)
goDefDef(cpy.DefDef(tree, tree.mods, tree.name, tparams, vparams, tpt, rhs), mutatedInfo.nx.nxTransDefDef(cur))
}
case tree: TypeDef =>
implicit val mutatedInfo = mutateTransformers(info, prepForTypeDef, info.nx.nxPrepTypeDef, tree, cur)
if (mutatedInfo eq null) tree
else {
- val rhs = transform(tree.rhs, mutatedInfo, cur)
+ val rhs = transform(tree.rhs, mutatedInfo, cur)(localContext(tree.symbol))
goTypeDef(cpy.TypeDef(tree, tree.mods, tree.name, rhs, tree.tparams), mutatedInfo.nx.nxTransTypeDef(cur))
}
case _ =>
@@ -950,7 +950,7 @@ object TreeTransforms {
if (mutatedInfo eq null) tree
else {
val fun = transform(tree.fun, mutatedInfo, cur)
- val args = transformSubL(tree.args, mutatedInfo, cur)
+ val args = transformSubTrees(tree.args, mutatedInfo, cur)
goApply(cpy.Apply(tree, fun, args), mutatedInfo.nx.nxTransApply(cur))
}
case tree: TypeApply =>
@@ -958,7 +958,7 @@ object TreeTransforms {
if (mutatedInfo eq null) tree
else {
val fun = transform(tree.fun, mutatedInfo, cur)
- val args = transformL(tree.args, mutatedInfo, cur)
+ val args = transformTrees(tree.args, mutatedInfo, cur)
goTypeApply(cpy.TypeApply(tree, fun, args), mutatedInfo.nx.nxTransTypeApply(cur))
}
case tree: Literal =>
@@ -1000,7 +1000,7 @@ object TreeTransforms {
implicit val mutatedInfo = mutateTransformers(info, prepForBlock, info.nx.nxPrepBlock, tree, cur)
if (mutatedInfo eq null) tree
else {
- val stats = transformStats(tree.stats, mutatedInfo, cur)
+ val stats = transformStats(tree.stats, ctx.owner, mutatedInfo, cur)
val expr = transform(tree.expr, mutatedInfo, cur)
goBlock(cpy.Block(tree, stats, expr), mutatedInfo.nx.nxTransBlock(cur))
}
@@ -1017,7 +1017,7 @@ object TreeTransforms {
implicit val mutatedInfo = mutateTransformers(info, prepForClosure, info.nx.nxPrepClosure, tree, cur)
if (mutatedInfo eq null) tree
else {
- val env = transformL(tree.env, mutatedInfo, cur)
+ val env = transformTrees(tree.env, mutatedInfo, cur)
val meth = transform(tree.meth, mutatedInfo, cur)
val tpt = transform(tree.tpt, mutatedInfo, cur)
goClosure(cpy.Closure(tree, env, meth, tpt), mutatedInfo.nx.nxTransClosure(cur))
@@ -1027,7 +1027,7 @@ object TreeTransforms {
if (mutatedInfo eq null) tree
else {
val selector = transform(tree.selector, mutatedInfo, cur)
- val cases = transformSubL(tree.cases, mutatedInfo, cur)
+ val cases = transformSubTrees(tree.cases, mutatedInfo, cur)
goMatch(cpy.Match(tree, selector, cases), mutatedInfo.nx.nxTransMatch(cur))
}
case tree: CaseDef =>
@@ -1067,7 +1067,7 @@ object TreeTransforms {
implicit val mutatedInfo = mutateTransformers(info, prepForSeqLiteral, info.nx.nxPrepSeqLiteral, tree, cur)
if (mutatedInfo eq null) tree
else {
- val elems = transformL(tree.elems, mutatedInfo, cur)
+ val elems = transformTrees(tree.elems, mutatedInfo, cur)
goSeqLiteral(cpy.SeqLiteral(tree, elems), mutatedInfo.nx.nxTransLiteral(cur))
}
case tree: TypeTree =>
@@ -1081,7 +1081,7 @@ object TreeTransforms {
implicit val mutatedInfo = mutateTransformers(info, prepForAlternative, info.nx.nxPrepAlternative, tree, cur)
if (mutatedInfo eq null) tree
else {
- val trees = transformL(tree.trees, mutatedInfo, cur)
+ val trees = transformTrees(tree.trees, mutatedInfo, cur)
goAlternative(cpy.Alternative(tree, trees), mutatedInfo.nx.nxTransAlternative(cur))
}
case tree: UnApply =>
@@ -1089,8 +1089,8 @@ object TreeTransforms {
if (mutatedInfo eq null) tree
else {
val fun = transform(tree.fun, mutatedInfo, cur)
- val implicits = transformL(tree.implicits, mutatedInfo, cur)
- val patterns = transformL(tree.patterns, mutatedInfo, cur)
+ val implicits = transformTrees(tree.implicits, mutatedInfo, cur)
+ val patterns = transformTrees(tree.patterns, mutatedInfo, cur)
goUnApply(cpy.UnApply(tree, fun, implicits, patterns), mutatedInfo.nx.nxTransUnApply(cur))
}
case tree: Template =>
@@ -1098,29 +1098,28 @@ object TreeTransforms {
if (mutatedInfo eq null) tree
else {
val constr = transformSub(tree.constr, mutatedInfo, cur)
- val parents = transformL(tree.parents, mutatedInfo, cur)
+ val parents = transformTrees(tree.parents, mutatedInfo, cur)
val self = transformSub(tree.self, mutatedInfo, cur)
- val body = transformStats(tree.body, mutatedInfo, cur)
+ val body = transformStats(tree.body, tree.symbol, mutatedInfo, cur)
goTemplate(cpy.Template(tree, constr, parents, self, body), mutatedInfo.nx.nxTransTemplate(cur))
}
case tree: PackageDef =>
implicit val mutatedInfo = mutateTransformers(info, prepForPackageDef, info.nx.nxPrepPackageDef, tree, cur)
if (mutatedInfo eq null) tree
else {
+ val nestedCtx = localContext(tree.symbol)
val pid = transformSub(tree.pid, mutatedInfo, cur)
- val stats = transformStats(tree.stats, mutatedInfo, cur)
+ val stats = transformStats(tree.stats, tree.symbol, mutatedInfo, cur)(nestedCtx)
goPackageDef(cpy.PackageDef(tree, pid, stats), mutatedInfo.nx.nxTransPackageDef(cur))
}
- case Thicket(trees) if trees != Nil =>
- val trees1 = transformL(trees.asInstanceOf[List[tpd.Tree]], info, cur)
- if (trees1 eq trees) tree else Thicket(trees1)
+ case Thicket(trees) => cpy.Thicket(tree, transformTrees(trees, info, cur))
case tree => tree
}
def transform(tree: Tree, info: TransformerInfo, cur: Int)(implicit ctx: Context): Tree = {
tree match {
//split one big match into 2 smaller ones
- case tree: NameTree => transformNameTree(tree, info, cur)
+ case tree: NameTree => transformNamed(tree, info, cur)
case tree => transformUnnamed(tree, info, cur)
}
}
@@ -1134,20 +1133,26 @@ object TreeTransforms {
} else trees
}
- def transformStats(trees: List[Tree], info: TransformerInfo, current: Int)(implicit ctx: Context): List[Tree] = {
+ def transformStats(trees: List[Tree], exprOwner: Symbol, info: TransformerInfo, current: Int)(implicit ctx: Context): List[Tree] = {
val newInfo = mutateTransformers(info, prepForStats, info.nx.nxPrepStats, trees, current)
- val newTrees = transformL(trees, newInfo, current)(ctx)
- flatten(goStats(newTrees, newInfo.nx.nxTransStats(current))(ctx, newInfo))
+ val exprCtx = ctx.withOwner(exprOwner)
+ def transformStat(stat: Tree): Tree = stat match {
+ case _: Import | _: DefTree => transform(stat, info, current)
+ case Thicket(stats) => cpy.Thicket(stat, stats mapConserve transformStat)
+ case _ => transform(stat, info, current)(exprCtx)
+ }
+ val newTrees = flatten(trees.mapconserve(transformStat))
+ goStats(newTrees, newInfo.nx.nxTransStats(current))(ctx, newInfo)
}
- def transformL(trees: List[Tree], info: TransformerInfo, current: Int)(implicit ctx: Context): List[Tree] =
+ def transformTrees(trees: List[Tree], info: TransformerInfo, current: Int)(implicit ctx: Context): List[Tree] =
flatten(trees mapConserve (x => transform(x, info, current)))
def transformSub[Tr <: Tree](tree: Tr, info: TransformerInfo, current: Int)(implicit ctx: Context): Tr =
transform(tree, info, current).asInstanceOf[Tr]
- def transformSubL[Tr <: Tree](trees: List[Tr], info: TransformerInfo, current: Int)(implicit ctx: Context): List[Tr] =
- transformL(trees, info, current)(ctx).asInstanceOf[List[Tr]]
+ def transformSubTrees[Tr <: Tree](trees: List[Tr], info: TransformerInfo, current: Int)(implicit ctx: Context): List[Tr] =
+ transformTrees(trees, info, current)(ctx).asInstanceOf[List[Tr]]
}
}