aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/dotty/tools/dotc/ast/TreeInfo.scala33
-rw-r--r--src/dotty/tools/dotc/ast/tpd.scala3
-rw-r--r--test/dotty/tools/dotc/ast/TreeInfoTest.scala30
3 files changed, 44 insertions, 22 deletions
diff --git a/src/dotty/tools/dotc/ast/TreeInfo.scala b/src/dotty/tools/dotc/ast/TreeInfo.scala
index a1dd37e27..734963ea3 100644
--- a/src/dotty/tools/dotc/ast/TreeInfo.scala
+++ b/src/dotty/tools/dotc/ast/TreeInfo.scala
@@ -421,30 +421,19 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] =>
* Pre: `sym` must have a position.
*/
def defPath(sym: Symbol, root: Tree)(implicit ctx: Context): List[Tree] = ctx.debugTraceIndented(s"defpath($sym with position ${sym.pos}, ${root.show})") {
- def show(from: Any): String = from match {
- case tree: Trees.Tree[_] => s"${tree.show} with attachments ${tree.allAttachments}"
- case x: printing.Showable => x.show
- case x => x.toString
- }
-
- def search(from: Any): List[Tree] = ctx.debugTraceIndented(s"search(${show(from)})") {
- from match {
- case tree: Tree => // Dotty problem: cannot write Tree @ unchecked, this currently gives a syntax error
- if (definedSym(tree) == sym) tree :: Nil
- else if (tree.envelope.contains(sym.pos)) {
- val p = search(tree.productIterator)
- if (p.isEmpty) p else tree :: p
- } else Nil
- case xs: Iterable[_] =>
- search(xs.iterator)
- case xs: Iterator[_] =>
- xs.map(search).find(_.nonEmpty).getOrElse(Nil)
- case _ =>
- Nil
+ require(sym.pos.exists)
+ object accum extends TreeAccumulator[List[Tree]] {
+ def apply(x: List[Tree], tree: Tree): List[Tree] = {
+ if (tree.envelope.contains(sym.pos))
+ if (definedSym(tree) == sym) tree :: x
+ else {
+ val x1 = foldOver(x, tree)
+ if (x1 ne x) tree :: x1 else x1
+ }
+ else x
}
}
- require(sym.pos.exists)
- search(root)
+ accum(Nil, root)
}
/** The statement sequence that contains a definition of `sym`, or Nil
diff --git a/src/dotty/tools/dotc/ast/tpd.scala b/src/dotty/tools/dotc/ast/tpd.scala
index 0dce4c324..4de98d8f8 100644
--- a/src/dotty/tools/dotc/ast/tpd.scala
+++ b/src/dotty/tools/dotc/ast/tpd.scala
@@ -330,6 +330,9 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
def deepFold[T](z: T)(op: (T, tpd.Tree) => T) =
new DeepFolder(op).apply(z, tree)
+ def find[T](pred: (tpd.Tree) => Boolean): Option[tpd.Tree] =
+ shallowFold[Option[tpd.Tree]](None)((accum, tree) => if (pred(tree)) Some(tree) else accum)
+
def subst(from: List[Symbol], to: List[Symbol])(implicit ctx: Context): ThisTree =
new TreeMapper(typeMap = new ctx.SubstSymMap(from, to)).apply(tree)
diff --git a/test/dotty/tools/dotc/ast/TreeInfoTest.scala b/test/dotty/tools/dotc/ast/TreeInfoTest.scala
new file mode 100644
index 000000000..6e02ee813
--- /dev/null
+++ b/test/dotty/tools/dotc/ast/TreeInfoTest.scala
@@ -0,0 +1,30 @@
+package dotty.tools.dotc
+package ast
+
+import org.junit.Test
+import test.DottyTest
+import core.Names._
+import core.Types._
+import core.Symbols._
+import org.junit.Assert._
+
+class TreeInfoTest extends DottyTest {
+
+ import tpd._
+
+ @Test
+ def testDefPath = checkCompile("frontend", "class A { def bar = { val x = { val z = 0; 0} }} ") {
+ (tree, context) =>
+ implicit val ctx = context
+ val xTree = tree.find(tree => tree.symbol.name == termName("x")).get
+ val path = defPath(xTree.symbol, tree)
+ assertEquals(List(
+ ("PackageDef", EMPTY_PACKAGE),
+ ("TypeDef", typeName("A")),
+ ("Template", termName("<local A>")),
+ ("DefDef", termName("bar")),
+ ("Block", NoSymbol.name),
+ ("ValDef", termName("x"))
+ ), path.map(x => (x.productPrefix, x.symbol.name)))
+ }
+}