summaryrefslogtreecommitdiff
path: root/src/reflect/scala/reflect/internal/TreeGen.scala
diff options
context:
space:
mode:
authorDen Shabalin <den.shabalin@gmail.com>2013-11-05 13:43:24 +0100
committerDen Shabalin <den.shabalin@gmail.com>2013-11-12 14:04:42 +0100
commitef02d85ac6232665bac611d788d472665d15cade (patch)
tree7ee5573a2d3de950c45e27c9d16af6fa2f4da0fe /src/reflect/scala/reflect/internal/TreeGen.scala
parentd89bfbbaa4d86cd9ebd2dfd874ae4a3509533df0 (diff)
downloadscala-ef02d85ac6232665bac611d788d472665d15cade.tar.gz
scala-ef02d85ac6232665bac611d788d472665d15cade.tar.bz2
scala-ef02d85ac6232665bac611d788d472665d15cade.zip
move for loop desugaring into tree gen
Diffstat (limited to 'src/reflect/scala/reflect/internal/TreeGen.scala')
-rw-r--r--src/reflect/scala/reflect/internal/TreeGen.scala369
1 files changed, 369 insertions, 0 deletions
diff --git a/src/reflect/scala/reflect/internal/TreeGen.scala b/src/reflect/scala/reflect/internal/TreeGen.scala
index b75368717f..d5eff8a9b8 100644
--- a/src/reflect/scala/reflect/internal/TreeGen.scala
+++ b/src/reflect/scala/reflect/internal/TreeGen.scala
@@ -466,6 +466,375 @@ abstract class TreeGen extends macros.TreeBuilder {
atPos(pkgPos)(PackageDef(pid, module :: Nil))
}
+ object ValFrom {
+ def apply(pat: Tree, rhs: Tree): Tree =
+ Apply(Ident(nme.LARROWkw).updateAttachment(ForAttachment),
+ List(pat, rhs))
+
+ def unapply(tree: Tree): Option[(Tree, Tree)] = tree match {
+ case Apply(id @ Ident(nme.LARROWkw), List(pat, rhs))
+ if id.hasAttachment[ForAttachment.type] =>
+ Some((pat, rhs))
+ case _ => None
+ }
+ }
+
+ object ValEq {
+ def apply(pat: Tree, rhs: Tree): Tree =
+ Assign(pat, rhs).updateAttachment(ForAttachment)
+
+ def unapply(tree: Tree): Option[(Tree, Tree)] = tree match {
+ case Assign(pat, rhs)
+ if tree.hasAttachment[ForAttachment.type] =>
+ Some((pat, rhs))
+ case _ => None
+ }
+ }
+
+ object Filter {
+ def apply(tree: Tree) =
+ Apply(Ident(nme.IFkw).updateAttachment(ForAttachment), List(tree))
+
+ def unapply(tree: Tree): Option[Tree] = tree match {
+ case Apply(id @ Ident(nme.IFkw), List(cond))
+ if id.hasAttachment[ForAttachment.type] =>
+ Some((cond))
+ case _ => None
+ }
+ }
+
+ object Yield {
+ def apply(tree: Tree): Tree =
+ Apply(Ident(nme.YIELDkw).updateAttachment(ForAttachment), List(tree))
+
+ def unapply(tree: Tree): Option[Tree] = tree match {
+ case Apply(id @ Ident(nme.YIELDkw), List(tree))
+ if id.hasAttachment[ForAttachment.type] =>
+ Some(tree)
+ case _ => None
+ }
+ }
+
+ /** Create tree for for-comprehension <for (enums) do body> or
+ * <for (enums) yield body> where mapName and flatMapName are chosen
+ * corresponding to whether this is a for-do or a for-yield.
+ * The creation performs the following rewrite rules:
+ *
+ * 1.
+ *
+ * for (P <- G) E ==> G.foreach (P => E)
+ *
+ * Here and in the following (P => E) is interpreted as the function (P => E)
+ * if P is a variable pattern and as the partial function { case P => E } otherwise.
+ *
+ * 2.
+ *
+ * for (P <- G) yield E ==> G.map (P => E)
+ *
+ * 3.
+ *
+ * for (P_1 <- G_1; P_2 <- G_2; ...) ...
+ * ==>
+ * G_1.flatMap (P_1 => for (P_2 <- G_2; ...) ...)
+ *
+ * 4.
+ *
+ * for (P <- G; E; ...) ...
+ * =>
+ * for (P <- G.filter (P => E); ...) ...
+ *
+ * 5. For N < MaxTupleArity:
+ *
+ * for (P_1 <- G; P_2 = E_2; val P_N = E_N; ...)
+ * ==>
+ * for (TupleN(P_1, P_2, ... P_N) <-
+ * for (x_1 @ P_1 <- G) yield {
+ * val x_2 @ P_2 = E_2
+ * ...
+ * val x_N & P_N = E_N
+ * TupleN(x_1, ..., x_N)
+ * } ...)
+ *
+ * If any of the P_i are variable patterns, the corresponding `x_i @ P_i' is not generated
+ * and the variable constituting P_i is used instead of x_i
+ *
+ * @param mapName The name to be used for maps (either map or foreach)
+ * @param flatMapName The name to be used for flatMaps (either flatMap or foreach)
+ * @param enums The enumerators in the for expression
+ * @param body The body of the for expression
+ */
+ def mkFor(enums: List[Tree], sugarBody: Tree)(implicit fresh: FreshNameCreator): Tree = {
+ val (mapName, flatMapName, body) = sugarBody match {
+ case Yield(tree) => (nme.map, nme.flatMap, tree)
+ case _ => (nme.foreach, nme.foreach, sugarBody)
+ }
+
+ /* make a closure pat => body.
+ * The closure is assigned a transparent position with the point at pos.point and
+ * the limits given by pat and body.
+ */
+ def makeClosure(pos: Position, pat: Tree, body: Tree): Tree = {
+ def splitpos = wrappingPos(List(pat, body)).withPoint(pos.point).makeTransparent
+ matchVarPattern(pat) match {
+ case Some((name, tpt)) =>
+ Function(
+ List(atPos(pat.pos) { ValDef(Modifiers(PARAM), name.toTermName, tpt, EmptyTree) }),
+ body) setPos splitpos
+ case None =>
+ atPos(splitpos) {
+ mkVisitor(List(CaseDef(pat, EmptyTree, body)), checkExhaustive = false)
+ }
+ }
+ }
+
+ /* Make an application qual.meth(pat => body) positioned at `pos`.
+ */
+ def makeCombination(pos: Position, meth: TermName, qual: Tree, pat: Tree, body: Tree): Tree =
+ Apply(Select(qual, meth) setPos qual.pos, List(makeClosure(pos, pat, body))) setPos pos
+
+ /* If `pat` is not yet a `Bind` wrap it in one with a fresh name */
+ def makeBind(pat: Tree): Tree = pat match {
+ case Bind(_, _) => pat
+ case _ => Bind(freshTermName(), pat) setPos pat.pos
+ }
+
+ /* A reference to the name bound in Bind `pat`. */
+ def makeValue(pat: Tree): Tree = pat match {
+ case Bind(name, _) => Ident(name) setPos pat.pos.focus
+ }
+
+ /* The position of the closure that starts with generator at position `genpos`. */
+ def closurePos(genpos: Position) = {
+ val end = body.pos match {
+ case NoPosition => genpos.point
+ case bodypos => bodypos.end
+ }
+ rangePos(genpos.source ,genpos.start, genpos.point, end)
+ }
+
+ enums match {
+ case (t @ ValFrom(pat, rhs)) :: Nil =>
+ makeCombination(closurePos(t.pos), mapName, rhs, pat, body)
+ case (t @ ValFrom(pat, rhs)) :: (rest @ (ValFrom(_, _) :: _)) =>
+ makeCombination(closurePos(t.pos), flatMapName, rhs, pat,
+ mkFor(rest, sugarBody))
+ case (t @ ValFrom(pat, rhs)) :: Filter(test) :: rest =>
+ mkFor(ValFrom(pat, makeCombination(rhs.pos union test.pos, nme.withFilter, rhs, pat.duplicate, test)).setPos(t.pos) :: rest, sugarBody)
+ case (t @ ValFrom(pat, rhs)) :: rest =>
+ val valeqs = rest.take(definitions.MaxTupleArity - 1).takeWhile { ValEq.unapply(_).nonEmpty }
+ assert(!valeqs.isEmpty)
+ val rest1 = rest.drop(valeqs.length)
+ val pats = valeqs map { case ValEq(pat, _) => pat }
+ val rhss = valeqs map { case ValEq(_, rhs) => rhs }
+ val defpat1 = makeBind(pat)
+ val defpats = pats map makeBind
+ val pdefs = (defpats, rhss).zipped flatMap mkPatDef
+ val ids = (defpat1 :: defpats) map makeValue
+ val rhs1 = mkFor(
+ List(ValFrom(defpat1, rhs).setPos(t.pos)),
+ Yield(Block(pdefs, atPos(wrappingPos(ids)) { mkTuple(ids) }) setPos wrappingPos(pdefs)))
+ val allpats = (pat :: pats) map (_.duplicate)
+ val vfrom1 = ValFrom(atPos(wrappingPos(allpats)) { mkTuple(allpats) }, rhs1).setPos(rangePos(t.pos.source, t.pos.start, t.pos.point, rhs1.pos.end))
+ mkFor(vfrom1 :: rest1, sugarBody)
+ case _ =>
+ EmptyTree //may happen for erroneous input
+ }
+ }
+
+ /** Create tree for pattern definition <val pat0 = rhs> */
+ def mkPatDef(pat: Tree, rhs: Tree)(implicit fresh: FreshNameCreator): List[Tree] =
+ mkPatDef(Modifiers(0), pat, rhs)
+
+ /** Create tree for pattern definition <mods val pat0 = rhs> */
+ def mkPatDef(mods: Modifiers, pat: Tree, rhs: Tree)(implicit fresh: FreshNameCreator): List[Tree] = matchVarPattern(pat) match {
+ case Some((name, tpt)) =>
+ List(atPos(pat.pos union rhs.pos) {
+ ValDef(mods, name.toTermName, tpt, rhs)
+ })
+
+ case None =>
+ // in case there is exactly one variable x_1 in pattern
+ // val/var p = e ==> val/var x_1 = e.match (case p => (x_1))
+ //
+ // in case there are zero or more than one variables in pattern
+ // val/var p = e ==> private synthetic val t$ = e.match (case p => (x_1, ..., x_N))
+ // val/var x_1 = t$._1
+ // ...
+ // val/var x_N = t$._N
+
+ val rhsUnchecked = mkUnchecked(rhs)
+
+ // TODO: clean this up -- there is too much information packked into mkPatDef's `pat` argument
+ // when it's a simple identifier (case Some((name, tpt)) -- above),
+ // pat should have the type ascription that was specified by the user
+ // however, in `case None` (here), we must be careful not to generate illegal pattern trees (such as `(a, b): Tuple2[Int, String]`)
+ // i.e., this must hold: pat1 match { case Typed(expr, tp) => assert(expr.isInstanceOf[Ident]) case _ => }
+ // if we encounter such an erroneous pattern, we strip off the type ascription from pat and propagate the type information to rhs
+ val (pat1, rhs1) = patvarTransformer.transform(pat) match {
+ // move the Typed ascription to the rhs
+ case Typed(expr, tpt) if !expr.isInstanceOf[Ident] =>
+ val rhsTypedUnchecked =
+ if (tpt.isEmpty) rhsUnchecked
+ else Typed(rhsUnchecked, tpt) setPos (rhs.pos union tpt.pos)
+ (expr, rhsTypedUnchecked)
+ case ok =>
+ (ok, rhsUnchecked)
+ }
+ val vars = getVariables(pat1)
+ val matchExpr = atPos((pat1.pos union rhs.pos).makeTransparent) {
+ Match(
+ rhs1,
+ List(
+ atPos(pat1.pos) {
+ CaseDef(pat1, EmptyTree, mkTuple(vars map (_._1) map Ident.apply))
+ }
+ ))
+ }
+ vars match {
+ case List((vname, tpt, pos)) =>
+ List(atPos(pat.pos union pos union rhs.pos) {
+ ValDef(mods, vname.toTermName, tpt, matchExpr)
+ })
+ case _ =>
+ val tmp = freshTermName()
+ val firstDef =
+ atPos(matchExpr.pos) {
+ ValDef(Modifiers(PrivateLocal | SYNTHETIC | ARTIFACT | (mods.flags & LAZY)),
+ tmp, TypeTree(), matchExpr)
+ }
+ var cnt = 0
+ val restDefs = for ((vname, tpt, pos) <- vars) yield atPos(pos) {
+ cnt += 1
+ ValDef(mods, vname.toTermName, tpt, Select(Ident(tmp), newTermName("_" + cnt)))
+ }
+ firstDef :: restDefs
+ }
+ }
+
+ /** Create tree for for-comprehension generator <val pat0 <- rhs0> */
+ def mkGenerator(pos: Position, pat: Tree, valeq: Boolean, rhs: Tree)(implicit fresh: FreshNameCreator): Tree = {
+ val pat1 = patvarTransformer.transform(pat)
+ if (valeq) ValEq(pat1, rhs).setPos(pos)
+ else ValFrom(pat1, mkCheckIfRefutable(pat1, rhs)).setPos(pos)
+ }
+
+ def mkCheckIfRefutable(pat: Tree, rhs: Tree)(implicit fresh: FreshNameCreator) =
+ if (treeInfo.isVarPatternDeep(pat)) rhs
+ else {
+ val cases = List(
+ CaseDef(pat.duplicate, EmptyTree, Literal(Constant(true))),
+ CaseDef(Ident(nme.WILDCARD), EmptyTree, Literal(Constant(false)))
+ )
+ val visitor = mkVisitor(cases, checkExhaustive = false, nme.CHECK_IF_REFUTABLE_STRING)
+ atPos(rhs.pos)(Apply(Select(rhs, nme.withFilter), visitor :: Nil))
+ }
+
+ /** If tree is a variable pattern, return Some("its name and type").
+ * Otherwise return none */
+ private def matchVarPattern(tree: Tree): Option[(Name, Tree)] = {
+ def wildType(t: Tree): Option[Tree] = t match {
+ case Ident(x) if x.toTermName == nme.WILDCARD => Some(TypeTree())
+ case Typed(Ident(x), tpt) if x.toTermName == nme.WILDCARD => Some(tpt)
+ case _ => None
+ }
+ tree match {
+ case Ident(name) => Some((name, TypeTree()))
+ case Bind(name, body) => wildType(body) map (x => (name, x))
+ case Typed(Ident(name), tpt) => Some((name, tpt))
+ case _ => None
+ }
+ }
+
+ /** Create visitor <x => x match cases> */
+ def mkVisitor(cases: List[CaseDef], checkExhaustive: Boolean, prefix: String = "x$")(implicit fresh: FreshNameCreator): Tree = {
+ val x = freshTermName(prefix)
+ val id = Ident(x)
+ val sel = if (checkExhaustive) id else mkUnchecked(id)
+ Function(List(mkSyntheticParam(x)), Match(sel, cases))
+ }
+
+ /** Traverse pattern and collect all variable names with their types in buffer
+ * The variables keep their positions; whereas the pattern is converted to be
+ * synthetic for all nodes that contain a variable position.
+ */
+ class GetVarTraverser extends Traverser {
+ val buf = new ListBuffer[(Name, Tree, Position)]
+
+ def namePos(tree: Tree, name: Name): Position =
+ if (!tree.pos.isRange || name.containsName(nme.raw.DOLLAR)) tree.pos.focus
+ else {
+ val start = tree.pos.start
+ val end = start + name.decode.length
+ rangePos(tree.pos.source, start, start, end)
+ }
+
+ override def traverse(tree: Tree): Unit = {
+ def seenName(name: Name) = buf exists (_._1 == name)
+ def add(name: Name, t: Tree) = if (!seenName(name)) buf += ((name, t, namePos(tree, name)))
+ val bl = buf.length
+
+ tree match {
+ case Bind(nme.WILDCARD, _) =>
+ super.traverse(tree)
+
+ case Bind(name, Typed(tree1, tpt)) =>
+ val newTree = if (treeInfo.mayBeTypePat(tpt)) TypeTree() else tpt.duplicate
+ add(name, newTree)
+ traverse(tree1)
+
+ case Bind(name, tree1) =>
+ // can assume only name range as position, as otherwise might overlap
+ // with binds embedded in pattern tree1
+ add(name, TypeTree())
+ traverse(tree1)
+
+ case _ =>
+ super.traverse(tree)
+ }
+ if (buf.length > bl)
+ tree setPos tree.pos.makeTransparent
+ }
+ def apply(tree: Tree) = {
+ traverse(tree)
+ buf.toList
+ }
+ }
+
+ /** Returns list of all pattern variables, possibly with their types,
+ * without duplicates
+ */
+ private def getVariables(tree: Tree): List[(Name, Tree, Position)] =
+ new GetVarTraverser apply tree
+
+ /** Convert all occurrences of (lower-case) variables in a pattern as follows:
+ * x becomes x @ _
+ * x: T becomes x @ (_: T)
+ */
+ object patvarTransformer extends Transformer {
+ override def transform(tree: Tree): Tree = tree match {
+ case Ident(name) if (treeInfo.isVarPattern(tree) && name != nme.WILDCARD) =>
+ atPos(tree.pos)(Bind(name, atPos(tree.pos.focus) (Ident(nme.WILDCARD))))
+ case Typed(id @ Ident(name), tpt) if (treeInfo.isVarPattern(id) && name != nme.WILDCARD) =>
+ atPos(tree.pos.withPoint(id.pos.point)) {
+ Bind(name, atPos(tree.pos.withStart(tree.pos.point)) {
+ Typed(Ident(nme.WILDCARD), tpt)
+ })
+ }
+ case Apply(fn @ Apply(_, _), args) =>
+ treeCopy.Apply(tree, transform(fn), transformTrees(args))
+ case Apply(fn, args) =>
+ treeCopy.Apply(tree, fn, transformTrees(args))
+ case Typed(expr, tpt) =>
+ treeCopy.Typed(tree, transform(expr), tpt)
+ case Bind(name, body) =>
+ treeCopy.Bind(tree, name, transform(body))
+ case Alternative(_) | Star(_) =>
+ super.transform(tree)
+ case _ =>
+ tree
+ }
+ }
+
// annotate the expression with @unchecked
def mkUnchecked(expr: Tree): Tree = atPos(expr.pos) {
// This can't be "Annotated(New(UncheckedClass), expr)" because annotations