diff options
-rw-r--r-- | src/reflect/scala/reflect/api/Trees.scala | 28 | ||||
-rw-r--r-- | src/reflect/scala/reflect/internal/Trees.scala | 173 |
2 files changed, 122 insertions, 79 deletions
diff --git a/src/reflect/scala/reflect/api/Trees.scala b/src/reflect/scala/reflect/api/Trees.scala index fa7d41f0fc..241747e6d8 100644 --- a/src/reflect/scala/reflect/api/Trees.scala +++ b/src/reflect/scala/reflect/api/Trees.scala @@ -2544,18 +2544,32 @@ trait Trees { self: Universe => class Traverser { protected[scala] var currentOwner: Symbol = rootMirror.RootClass + /** Traverse something which Trees contain, but which isn't a Tree itself. */ + def traverseName(name: Name): Unit = () + def traverseConstant(c: Constant): Unit = () + def traverseImportSelector(sel: ImportSelector): Unit = () + def traverseModifiers(mods: Modifiers): Unit = traverseAnnotations(mods.annotations) + /** Traverses a single tree. */ - def traverse(tree: Tree): Unit = itraverse(this, tree) + def traverse(tree: Tree): Unit = itraverse(this, tree) + def traversePattern(pat: Tree): Unit = traverse(pat) + def traverseGuard(guard: Tree): Unit = traverse(guard) + def traverseTypeAscription(tpt: Tree): Unit = traverse(tpt) + // Special handling of noSelfType necessary for backward compat: existing + // traversers break down when they see the unexpected tree. + def traverseSelfType(self: ValDef): Unit = if (self ne noSelfType) traverse(self) /** Traverses a list of trees. */ - def traverseTrees(trees: List[Tree]) { - trees foreach traverse - } + def traverseTrees(trees: List[Tree]): Unit = trees foreach traverse + def traverseTypeArgs(args: List[Tree]): Unit = traverseTrees(args) + def traverseParents(parents: List[Tree]): Unit = traverseTrees(parents) + def traverseCases(cases: List[CaseDef]): Unit = traverseTrees(cases) + def traverseAnnotations(annots: List[Tree]): Unit = traverseTrees(annots) /** Traverses a list of lists of trees. */ - def traverseTreess(treess: List[List[Tree]]) { - treess foreach traverseTrees - } + def traverseTreess(treess: List[List[Tree]]): Unit = treess foreach traverseTrees + def traverseParams(params: List[Tree]): Unit = traverseTrees(params) + def traverseParamss(vparamss: List[List[Tree]]): Unit = vparamss foreach traverseParams /** Traverses a list of trees with a given owner symbol. */ def traverseStats(stats: List[Tree], exprOwner: Symbol) { diff --git a/src/reflect/scala/reflect/internal/Trees.scala b/src/reflect/scala/reflect/internal/Trees.scala index ccf5cf5249..bc1a111faa 100644 --- a/src/reflect/scala/reflect/internal/Trees.scala +++ b/src/reflect/scala/reflect/internal/Trees.scala @@ -8,10 +8,12 @@ package reflect package internal import Flags._ -import scala.collection.mutable.{ListBuffer, LinkedHashSet} +import pickling.PickleFormat._ +import scala.collection.{ mutable, immutable } import util.Statistics -trait Trees extends api.Trees { self: SymbolTable => +trait Trees extends api.Trees { + self: SymbolTable => private[scala] var nodeCount = 0 @@ -160,7 +162,7 @@ trait Trees extends api.Trees { self: SymbolTable => override def freeTypes: List[FreeTypeSymbol] = freeSyms[FreeTypeSymbol](_.isFreeType, _.typeSymbol) private def freeSyms[S <: Symbol](isFree: Symbol => Boolean, symOfType: Type => Symbol): List[S] = { - val s = scala.collection.mutable.LinkedHashSet[S]() + val s = mutable.LinkedHashSet[S]() def addIfFree(sym: Symbol): Unit = if (sym != null && isFree(sym)) s += sym.asInstanceOf[S] for (t <- this) { addIfFree(t.symbol) @@ -1019,18 +1021,22 @@ trait Trees extends api.Trees { self: SymbolTable => } trait CannotHaveAttrs extends Tree { - override def canHaveAttrs = false - - private def requireLegal(value: Any, allowed: Any, what: String) = - require(value == allowed, s"can't set $what for $self to value other than $allowed") - super.setPos(NoPosition) + super.setType(NoType) + + override def canHaveAttrs = false override def setPos(pos: Position) = { requireLegal(pos, NoPosition, "pos"); this } override def pos_=(pos: Position) = setPos(pos) - - super.setType(NoType) override def setType(t: Type) = { requireLegal(t, NoType, "tpe"); this } override def tpe_=(t: Type) = setType(t) + + private def requireLegal(value: Any, allowed: Any, what: String) = ( + if (value != allowed) { + log(s"can't set $what for $self to value other than $allowed") + if (settings.debug && settings.developer) + (new Throwable).printStackTrace + } + ) } case object EmptyTree extends TermTree with CannotHaveAttrs { override def isEmpty = true; val asList = List(this) } @@ -1159,113 +1165,136 @@ trait Trees extends api.Trees { self: SymbolTable => override protected def itraverse(traverser: Traverser, tree: Tree): Unit = { import traverser._ - tree match { - case EmptyTree => - ; - case PackageDef(pid, stats) => - traverse(pid) - atOwner(mclass(tree.symbol)) { - traverseTrees(stats) - } - case ClassDef(mods, name, tparams, impl) => - atOwner(tree.symbol) { - traverseTrees(mods.annotations); traverseTrees(tparams); traverse(impl) - } - case ModuleDef(mods, name, impl) => - atOwner(mclass(tree.symbol)) { - traverseTrees(mods.annotations); traverse(impl) - } - case ValDef(mods, name, tpt, rhs) => - atOwner(tree.symbol) { - traverseTrees(mods.annotations); traverse(tpt); traverse(rhs) - } - case DefDef(mods, name, tparams, vparamss, tpt, rhs) => - atOwner(tree.symbol) { - traverseTrees(mods.annotations); traverseTrees(tparams); traverseTreess(vparamss); traverse(tpt); traverse(rhs) - } - case TypeDef(mods, name, tparams, rhs) => - atOwner(tree.symbol) { - traverseTrees(mods.annotations); traverseTrees(tparams); traverse(rhs) - } + + def traverseMemberDef(md: MemberDef, owner: Symbol): Unit = atOwner(owner) { + traverseModifiers(md.mods) + traverseName(md.name) + md match { + case ClassDef(_, _, tparams, impl) => traverseParams(tparams) ; traverse(impl) + case ModuleDef(_, _, impl) => traverse(impl) + case ValDef(_, _, tpt, rhs) => traverseTypeAscription(tpt) ; traverse(rhs) + case TypeDef(_, _, tparams, rhs) => traverseParams(tparams) ; traverse(rhs) + case DefDef(_, _, tparams, vparamss, tpt, rhs) => + traverseParams(tparams) + traverseParamss(vparamss) + traverseTypeAscription(tpt) + traverse(rhs) + } + } + def traverseComponents(): Unit = tree match { case LabelDef(name, params, rhs) => - traverseTrees(params); traverse(rhs) + traverseName(name) + traverseParams(params) + traverse(rhs) case Import(expr, selectors) => traverse(expr) + selectors foreach traverseImportSelector case Annotated(annot, arg) => - traverse(annot); traverse(arg) + traverse(annot) + traverse(arg) case Template(parents, self, body) => - traverseTrees(parents) - if (self ne noSelfType) traverse(self) + traverseParents(parents) + traverseSelfType(self) traverseStats(body, tree.symbol) case Block(stats, expr) => - traverseTrees(stats); traverse(expr) + traverseTrees(stats) + traverse(expr) case CaseDef(pat, guard, body) => - traverse(pat); traverse(guard); traverse(body) + traversePattern(pat) + traverseGuard(guard) + traverse(body) case Alternative(trees) => traverseTrees(trees) case Star(elem) => traverse(elem) case Bind(name, body) => + traverseName(name) traverse(body) case UnApply(fun, args) => - traverse(fun); traverseTrees(args) + traverse(fun) + traverseTrees(args) case ArrayValue(elemtpt, trees) => - traverse(elemtpt); traverseTrees(trees) - case Function(vparams, body) => - atOwner(tree.symbol) { - traverseTrees(vparams); traverse(body) - } + traverse(elemtpt) + traverseTrees(trees) case Assign(lhs, rhs) => - traverse(lhs); traverse(rhs) + traverse(lhs) + traverse(rhs) case AssignOrNamedArg(lhs, rhs) => - traverse(lhs); traverse(rhs) + traverse(lhs) + traverse(rhs) case If(cond, thenp, elsep) => - traverse(cond); traverse(thenp); traverse(elsep) + traverse(cond) + traverse(thenp) + traverse(elsep) case Match(selector, cases) => - traverse(selector); traverseTrees(cases) + traverse(selector) + traverseCases(cases) case Return(expr) => traverse(expr) case Try(block, catches, finalizer) => - traverse(block); traverseTrees(catches); traverse(finalizer) + traverse(block) + traverseCases(catches) + traverse(finalizer) case Throw(expr) => traverse(expr) case New(tpt) => traverse(tpt) case Typed(expr, tpt) => - traverse(expr); traverse(tpt) + traverse(expr) + traverseTypeAscription(tpt) case TypeApply(fun, args) => - traverse(fun); traverseTrees(args) + traverse(fun) + traverseTypeArgs(args) case Apply(fun, args) => - traverse(fun); traverseTrees(args) + traverse(fun) + traverseTrees(args) case ApplyDynamic(qual, args) => - traverse(qual); traverseTrees(args) - case Super(qual, _) => traverse(qual) - case This(_) => - ; + traverseTrees(args) + case Super(qual, mix) => + traverse(qual) + traverseName(mix) + case This(qual) => + traverseName(qual) case Select(qualifier, selector) => traverse(qualifier) - case Ident(_) => - ; + traverseName(selector) + case Ident(name) => + traverseName(name) case ReferenceToBoxed(idt) => traverse(idt) - case Literal(_) => - ; + case Literal(const) => + traverseConstant(const) case TypeTree() => ; case SingletonTypeTree(ref) => traverse(ref) case SelectFromTypeTree(qualifier, selector) => traverse(qualifier) + traverseName(selector) case CompoundTypeTree(templ) => traverse(templ) case AppliedTypeTree(tpt, args) => - traverse(tpt); traverseTrees(args) + traverse(tpt) + traverseTypeArgs(args) case TypeBoundsTree(lo, hi) => - traverse(lo); traverse(hi) + traverse(lo) + traverse(hi) case ExistentialTypeTree(tpt, whereClauses) => - traverse(tpt); traverseTrees(whereClauses) - case _ => xtraverse(traverser, tree) + traverse(tpt) + traverseTrees(whereClauses) + case _ => + xtraverse(traverser, tree) + } + + if (tree.canHaveAttrs) { + tree match { + case PackageDef(pid, stats) => traverse(pid) ; traverseStats(stats, mclass(tree.symbol)) + case md: ModuleDef => traverseMemberDef(md, mclass(tree.symbol)) + case md: MemberDef => traverseMemberDef(md, tree.symbol) + case Function(vparams, body) => atOwner(tree.symbol) { traverseParams(vparams) ; traverse(body) } + case _ => traverseComponents() + } } } @@ -1563,7 +1592,7 @@ trait Trees extends api.Trees { self: SymbolTable => } class FilterTreeTraverser(p: Tree => Boolean) extends Traverser { - val hits = new ListBuffer[Tree] + val hits = mutable.ListBuffer[Tree]() override def traverse(t: Tree) { if (p(t)) hits += t super.traverse(t) @@ -1571,7 +1600,7 @@ trait Trees extends api.Trees { self: SymbolTable => } class CollectTreeTraverser[T](pf: PartialFunction[Tree, T]) extends Traverser { - val results = new ListBuffer[T] + val results = mutable.ListBuffer[T]() override def traverse(t: Tree) { if (pf.isDefinedAt(t)) results += pf(t) super.traverse(t) |