From ef02d85ac6232665bac611d788d472665d15cade Mon Sep 17 00:00:00 2001 From: Den Shabalin Date: Tue, 5 Nov 2013 13:43:24 +0100 Subject: move for loop desugaring into tree gen --- src/reflect/scala/reflect/internal/TreeGen.scala | 369 +++++++++++++++++++++++ 1 file changed, 369 insertions(+) (limited to 'src/reflect/scala/reflect/internal/TreeGen.scala') diff --git a/src/reflect/scala/reflect/internal/TreeGen.scala b/src/reflect/scala/reflect/internal/TreeGen.scala index b75368717f..d5eff8a9b8 100644 --- a/src/reflect/scala/reflect/internal/TreeGen.scala +++ b/src/reflect/scala/reflect/internal/TreeGen.scala @@ -466,6 +466,375 @@ abstract class TreeGen extends macros.TreeBuilder { atPos(pkgPos)(PackageDef(pid, module :: Nil)) } + object ValFrom { + def apply(pat: Tree, rhs: Tree): Tree = + Apply(Ident(nme.LARROWkw).updateAttachment(ForAttachment), + List(pat, rhs)) + + def unapply(tree: Tree): Option[(Tree, Tree)] = tree match { + case Apply(id @ Ident(nme.LARROWkw), List(pat, rhs)) + if id.hasAttachment[ForAttachment.type] => + Some((pat, rhs)) + case _ => None + } + } + + object ValEq { + def apply(pat: Tree, rhs: Tree): Tree = + Assign(pat, rhs).updateAttachment(ForAttachment) + + def unapply(tree: Tree): Option[(Tree, Tree)] = tree match { + case Assign(pat, rhs) + if tree.hasAttachment[ForAttachment.type] => + Some((pat, rhs)) + case _ => None + } + } + + object Filter { + def apply(tree: Tree) = + Apply(Ident(nme.IFkw).updateAttachment(ForAttachment), List(tree)) + + def unapply(tree: Tree): Option[Tree] = tree match { + case Apply(id @ Ident(nme.IFkw), List(cond)) + if id.hasAttachment[ForAttachment.type] => + Some((cond)) + case _ => None + } + } + + object Yield { + def apply(tree: Tree): Tree = + Apply(Ident(nme.YIELDkw).updateAttachment(ForAttachment), List(tree)) + + def unapply(tree: Tree): Option[Tree] = tree match { + case Apply(id @ Ident(nme.YIELDkw), List(tree)) + if id.hasAttachment[ForAttachment.type] => + Some(tree) + case _ => None + } + } + + /** Create tree for for-comprehension or + * where mapName and flatMapName are chosen + * corresponding to whether this is a for-do or a for-yield. + * The creation performs the following rewrite rules: + * + * 1. + * + * for (P <- G) E ==> G.foreach (P => E) + * + * Here and in the following (P => E) is interpreted as the function (P => E) + * if P is a variable pattern and as the partial function { case P => E } otherwise. + * + * 2. + * + * for (P <- G) yield E ==> G.map (P => E) + * + * 3. + * + * for (P_1 <- G_1; P_2 <- G_2; ...) ... + * ==> + * G_1.flatMap (P_1 => for (P_2 <- G_2; ...) ...) + * + * 4. + * + * for (P <- G; E; ...) ... + * => + * for (P <- G.filter (P => E); ...) ... + * + * 5. For N < MaxTupleArity: + * + * for (P_1 <- G; P_2 = E_2; val P_N = E_N; ...) + * ==> + * for (TupleN(P_1, P_2, ... P_N) <- + * for (x_1 @ P_1 <- G) yield { + * val x_2 @ P_2 = E_2 + * ... + * val x_N & P_N = E_N + * TupleN(x_1, ..., x_N) + * } ...) + * + * If any of the P_i are variable patterns, the corresponding `x_i @ P_i' is not generated + * and the variable constituting P_i is used instead of x_i + * + * @param mapName The name to be used for maps (either map or foreach) + * @param flatMapName The name to be used for flatMaps (either flatMap or foreach) + * @param enums The enumerators in the for expression + * @param body The body of the for expression + */ + def mkFor(enums: List[Tree], sugarBody: Tree)(implicit fresh: FreshNameCreator): Tree = { + val (mapName, flatMapName, body) = sugarBody match { + case Yield(tree) => (nme.map, nme.flatMap, tree) + case _ => (nme.foreach, nme.foreach, sugarBody) + } + + /* make a closure pat => body. + * The closure is assigned a transparent position with the point at pos.point and + * the limits given by pat and body. + */ + def makeClosure(pos: Position, pat: Tree, body: Tree): Tree = { + def splitpos = wrappingPos(List(pat, body)).withPoint(pos.point).makeTransparent + matchVarPattern(pat) match { + case Some((name, tpt)) => + Function( + List(atPos(pat.pos) { ValDef(Modifiers(PARAM), name.toTermName, tpt, EmptyTree) }), + body) setPos splitpos + case None => + atPos(splitpos) { + mkVisitor(List(CaseDef(pat, EmptyTree, body)), checkExhaustive = false) + } + } + } + + /* Make an application qual.meth(pat => body) positioned at `pos`. + */ + def makeCombination(pos: Position, meth: TermName, qual: Tree, pat: Tree, body: Tree): Tree = + Apply(Select(qual, meth) setPos qual.pos, List(makeClosure(pos, pat, body))) setPos pos + + /* If `pat` is not yet a `Bind` wrap it in one with a fresh name */ + def makeBind(pat: Tree): Tree = pat match { + case Bind(_, _) => pat + case _ => Bind(freshTermName(), pat) setPos pat.pos + } + + /* A reference to the name bound in Bind `pat`. */ + def makeValue(pat: Tree): Tree = pat match { + case Bind(name, _) => Ident(name) setPos pat.pos.focus + } + + /* The position of the closure that starts with generator at position `genpos`. */ + def closurePos(genpos: Position) = { + val end = body.pos match { + case NoPosition => genpos.point + case bodypos => bodypos.end + } + rangePos(genpos.source ,genpos.start, genpos.point, end) + } + + enums match { + case (t @ ValFrom(pat, rhs)) :: Nil => + makeCombination(closurePos(t.pos), mapName, rhs, pat, body) + case (t @ ValFrom(pat, rhs)) :: (rest @ (ValFrom(_, _) :: _)) => + makeCombination(closurePos(t.pos), flatMapName, rhs, pat, + mkFor(rest, sugarBody)) + case (t @ ValFrom(pat, rhs)) :: Filter(test) :: rest => + mkFor(ValFrom(pat, makeCombination(rhs.pos union test.pos, nme.withFilter, rhs, pat.duplicate, test)).setPos(t.pos) :: rest, sugarBody) + case (t @ ValFrom(pat, rhs)) :: rest => + val valeqs = rest.take(definitions.MaxTupleArity - 1).takeWhile { ValEq.unapply(_).nonEmpty } + assert(!valeqs.isEmpty) + val rest1 = rest.drop(valeqs.length) + val pats = valeqs map { case ValEq(pat, _) => pat } + val rhss = valeqs map { case ValEq(_, rhs) => rhs } + val defpat1 = makeBind(pat) + val defpats = pats map makeBind + val pdefs = (defpats, rhss).zipped flatMap mkPatDef + val ids = (defpat1 :: defpats) map makeValue + val rhs1 = mkFor( + List(ValFrom(defpat1, rhs).setPos(t.pos)), + Yield(Block(pdefs, atPos(wrappingPos(ids)) { mkTuple(ids) }) setPos wrappingPos(pdefs))) + val allpats = (pat :: pats) map (_.duplicate) + val vfrom1 = ValFrom(atPos(wrappingPos(allpats)) { mkTuple(allpats) }, rhs1).setPos(rangePos(t.pos.source, t.pos.start, t.pos.point, rhs1.pos.end)) + mkFor(vfrom1 :: rest1, sugarBody) + case _ => + EmptyTree //may happen for erroneous input + } + } + + /** Create tree for pattern definition */ + def mkPatDef(pat: Tree, rhs: Tree)(implicit fresh: FreshNameCreator): List[Tree] = + mkPatDef(Modifiers(0), pat, rhs) + + /** Create tree for pattern definition */ + def mkPatDef(mods: Modifiers, pat: Tree, rhs: Tree)(implicit fresh: FreshNameCreator): List[Tree] = matchVarPattern(pat) match { + case Some((name, tpt)) => + List(atPos(pat.pos union rhs.pos) { + ValDef(mods, name.toTermName, tpt, rhs) + }) + + case None => + // in case there is exactly one variable x_1 in pattern + // val/var p = e ==> val/var x_1 = e.match (case p => (x_1)) + // + // in case there are zero or more than one variables in pattern + // val/var p = e ==> private synthetic val t$ = e.match (case p => (x_1, ..., x_N)) + // val/var x_1 = t$._1 + // ... + // val/var x_N = t$._N + + val rhsUnchecked = mkUnchecked(rhs) + + // TODO: clean this up -- there is too much information packked into mkPatDef's `pat` argument + // when it's a simple identifier (case Some((name, tpt)) -- above), + // pat should have the type ascription that was specified by the user + // however, in `case None` (here), we must be careful not to generate illegal pattern trees (such as `(a, b): Tuple2[Int, String]`) + // i.e., this must hold: pat1 match { case Typed(expr, tp) => assert(expr.isInstanceOf[Ident]) case _ => } + // if we encounter such an erroneous pattern, we strip off the type ascription from pat and propagate the type information to rhs + val (pat1, rhs1) = patvarTransformer.transform(pat) match { + // move the Typed ascription to the rhs + case Typed(expr, tpt) if !expr.isInstanceOf[Ident] => + val rhsTypedUnchecked = + if (tpt.isEmpty) rhsUnchecked + else Typed(rhsUnchecked, tpt) setPos (rhs.pos union tpt.pos) + (expr, rhsTypedUnchecked) + case ok => + (ok, rhsUnchecked) + } + val vars = getVariables(pat1) + val matchExpr = atPos((pat1.pos union rhs.pos).makeTransparent) { + Match( + rhs1, + List( + atPos(pat1.pos) { + CaseDef(pat1, EmptyTree, mkTuple(vars map (_._1) map Ident.apply)) + } + )) + } + vars match { + case List((vname, tpt, pos)) => + List(atPos(pat.pos union pos union rhs.pos) { + ValDef(mods, vname.toTermName, tpt, matchExpr) + }) + case _ => + val tmp = freshTermName() + val firstDef = + atPos(matchExpr.pos) { + ValDef(Modifiers(PrivateLocal | SYNTHETIC | ARTIFACT | (mods.flags & LAZY)), + tmp, TypeTree(), matchExpr) + } + var cnt = 0 + val restDefs = for ((vname, tpt, pos) <- vars) yield atPos(pos) { + cnt += 1 + ValDef(mods, vname.toTermName, tpt, Select(Ident(tmp), newTermName("_" + cnt))) + } + firstDef :: restDefs + } + } + + /** Create tree for for-comprehension generator */ + def mkGenerator(pos: Position, pat: Tree, valeq: Boolean, rhs: Tree)(implicit fresh: FreshNameCreator): Tree = { + val pat1 = patvarTransformer.transform(pat) + if (valeq) ValEq(pat1, rhs).setPos(pos) + else ValFrom(pat1, mkCheckIfRefutable(pat1, rhs)).setPos(pos) + } + + def mkCheckIfRefutable(pat: Tree, rhs: Tree)(implicit fresh: FreshNameCreator) = + if (treeInfo.isVarPatternDeep(pat)) rhs + else { + val cases = List( + CaseDef(pat.duplicate, EmptyTree, Literal(Constant(true))), + CaseDef(Ident(nme.WILDCARD), EmptyTree, Literal(Constant(false))) + ) + val visitor = mkVisitor(cases, checkExhaustive = false, nme.CHECK_IF_REFUTABLE_STRING) + atPos(rhs.pos)(Apply(Select(rhs, nme.withFilter), visitor :: Nil)) + } + + /** If tree is a variable pattern, return Some("its name and type"). + * Otherwise return none */ + private def matchVarPattern(tree: Tree): Option[(Name, Tree)] = { + def wildType(t: Tree): Option[Tree] = t match { + case Ident(x) if x.toTermName == nme.WILDCARD => Some(TypeTree()) + case Typed(Ident(x), tpt) if x.toTermName == nme.WILDCARD => Some(tpt) + case _ => None + } + tree match { + case Ident(name) => Some((name, TypeTree())) + case Bind(name, body) => wildType(body) map (x => (name, x)) + case Typed(Ident(name), tpt) => Some((name, tpt)) + case _ => None + } + } + + /** Create visitor x match cases> */ + def mkVisitor(cases: List[CaseDef], checkExhaustive: Boolean, prefix: String = "x$")(implicit fresh: FreshNameCreator): Tree = { + val x = freshTermName(prefix) + val id = Ident(x) + val sel = if (checkExhaustive) id else mkUnchecked(id) + Function(List(mkSyntheticParam(x)), Match(sel, cases)) + } + + /** Traverse pattern and collect all variable names with their types in buffer + * The variables keep their positions; whereas the pattern is converted to be + * synthetic for all nodes that contain a variable position. + */ + class GetVarTraverser extends Traverser { + val buf = new ListBuffer[(Name, Tree, Position)] + + def namePos(tree: Tree, name: Name): Position = + if (!tree.pos.isRange || name.containsName(nme.raw.DOLLAR)) tree.pos.focus + else { + val start = tree.pos.start + val end = start + name.decode.length + rangePos(tree.pos.source, start, start, end) + } + + override def traverse(tree: Tree): Unit = { + def seenName(name: Name) = buf exists (_._1 == name) + def add(name: Name, t: Tree) = if (!seenName(name)) buf += ((name, t, namePos(tree, name))) + val bl = buf.length + + tree match { + case Bind(nme.WILDCARD, _) => + super.traverse(tree) + + case Bind(name, Typed(tree1, tpt)) => + val newTree = if (treeInfo.mayBeTypePat(tpt)) TypeTree() else tpt.duplicate + add(name, newTree) + traverse(tree1) + + case Bind(name, tree1) => + // can assume only name range as position, as otherwise might overlap + // with binds embedded in pattern tree1 + add(name, TypeTree()) + traverse(tree1) + + case _ => + super.traverse(tree) + } + if (buf.length > bl) + tree setPos tree.pos.makeTransparent + } + def apply(tree: Tree) = { + traverse(tree) + buf.toList + } + } + + /** Returns list of all pattern variables, possibly with their types, + * without duplicates + */ + private def getVariables(tree: Tree): List[(Name, Tree, Position)] = + new GetVarTraverser apply tree + + /** Convert all occurrences of (lower-case) variables in a pattern as follows: + * x becomes x @ _ + * x: T becomes x @ (_: T) + */ + object patvarTransformer extends Transformer { + override def transform(tree: Tree): Tree = tree match { + case Ident(name) if (treeInfo.isVarPattern(tree) && name != nme.WILDCARD) => + atPos(tree.pos)(Bind(name, atPos(tree.pos.focus) (Ident(nme.WILDCARD)))) + case Typed(id @ Ident(name), tpt) if (treeInfo.isVarPattern(id) && name != nme.WILDCARD) => + atPos(tree.pos.withPoint(id.pos.point)) { + Bind(name, atPos(tree.pos.withStart(tree.pos.point)) { + Typed(Ident(nme.WILDCARD), tpt) + }) + } + case Apply(fn @ Apply(_, _), args) => + treeCopy.Apply(tree, transform(fn), transformTrees(args)) + case Apply(fn, args) => + treeCopy.Apply(tree, fn, transformTrees(args)) + case Typed(expr, tpt) => + treeCopy.Typed(tree, transform(expr), tpt) + case Bind(name, body) => + treeCopy.Bind(tree, name, transform(body)) + case Alternative(_) | Star(_) => + super.transform(tree) + case _ => + tree + } + } + // annotate the expression with @unchecked def mkUnchecked(expr: Tree): Tree = atPos(expr.pos) { // This can't be "Annotated(New(UncheckedClass), expr)" because annotations -- cgit v1.2.3