summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/compiler/scala/tools/nsc/ast/parser/Parsers.scala18
-rw-r--r--src/compiler/scala/tools/nsc/ast/parser/TreeBuilder.scala169
-rw-r--r--src/compiler/scala/tools/nsc/symtab/StdNames.scala1
-rw-r--r--src/compiler/scala/tools/nsc/typechecker/RefChecks.scala31
-rw-r--r--test/files/run/forvaleq.check9
-rw-r--r--test/files/run/forvaleq.scala17
6 files changed, 197 insertions, 48 deletions
diff --git a/src/compiler/scala/tools/nsc/ast/parser/Parsers.scala b/src/compiler/scala/tools/nsc/ast/parser/Parsers.scala
index 5f74667dc4..329b2604c7 100644
--- a/src/compiler/scala/tools/nsc/ast/parser/Parsers.scala
+++ b/src/compiler/scala/tools/nsc/ast/parser/Parsers.scala
@@ -931,21 +931,25 @@ mixin class Parsers requires SyntaxAnalyzer {
* Enumerator ::= Generator
* | Expr
*/
- def enumerators(): List[Tree] = {
- val enums = new ListBuffer[Tree] + generator();
+ def enumerators(): List[Enumerator] = {
+ val enums = new ListBuffer[Enumerator] + generator(false);
while (in.token == SEMI || in.token == NEWLINE) {
in.nextToken();
- enums += (if (in.token == VAL) generator() else expr())
+ enums += (if (in.token == VAL) generator(true) else Filter(expr()))
}
enums.toList
}
/** Generator ::= val Pattern1 `<-' Expr
*/
- def generator(): Tree =
- atPos(accept(VAL)) {
- makeGenerator(pattern1(false), { accept(LARROW); expr() })
- }
+ def generator(eqOK: boolean): Enumerator = {
+ val pos = accept(VAL);
+ val pat = pattern1(false);
+ val tok = in.token;
+ if (tok == EQUALS && eqOK) in.nextToken()
+ else accept(LARROW);
+ makeGenerator(pos, pat, tok == EQUALS, expr)
+ }
//////// PATTERNS ////////////////////////////////////////////////////////////
diff --git a/src/compiler/scala/tools/nsc/ast/parser/TreeBuilder.scala b/src/compiler/scala/tools/nsc/ast/parser/TreeBuilder.scala
index 611defb2b0..8a91c603d6 100644
--- a/src/compiler/scala/tools/nsc/ast/parser/TreeBuilder.scala
+++ b/src/compiler/scala/tools/nsc/ast/parser/TreeBuilder.scala
@@ -18,6 +18,10 @@ abstract class TreeBuilder {
def freshName(): Name = freshName("x$");
+ /** Convert all occurrences of (lower-case) variables in a pattern as follows:
+ * x becomes x @ _
+ * x: T becomes x @ (_: T)
+ */
private object patvarTransformer extends Transformer {
override def transform(tree: Tree): Tree = tree match {
case Ident(name) if (treeInfo.isVariableName(name) && name != nme.WILDCARD) =>
@@ -57,10 +61,17 @@ abstract class TreeBuilder {
getvarTraverser.buf.toList
}
- private def mkTuple(trees: List[Tree]): Tree = trees match {
+ private def makeTuple(trees: List[Tree], isType: boolean): Tree = {
+ val tupString = "Tuple" + trees.length;
+ Apply(
+ Select(Ident(nme.scala_), if (isType) newTypeName(tupString) else newTermName(tupString)),
+ trees)
+ }
+
+ private def makeTupleTerm(trees: List[Tree]): Tree = trees match {
case List() => Literal(())
case List(tree) => tree
- case _ => Apply(Select(Ident(nme.scala_), newTermName("Tuple" + trees.length)), trees)
+ case _ => makeTuple(trees, false)
}
/** If tree is a variable pattern, return Some("its name and type").
@@ -136,57 +147,139 @@ abstract class TreeBuilder {
}
/** Create tree for for-comprehension generator <val pat0 <- rhs0> */
- def makeGenerator(pat: Tree, rhs: Tree): Tree = {
+ def makeGenerator(pos: int, pat: Tree, valeq: boolean, rhs: Tree): Enumerator = {
val pat1 = patvarTransformer.transform(pat);
- val rhs1 = matchVarPattern(pat1) match {
- case Some(_) =>
- rhs
- case None =>
- Apply(
- Select(rhs, nme.filter),
- List(makeVisitor(List(
- CaseDef(pat1.duplicate, EmptyTree, Literal(true)),
- CaseDef(Ident(nme.WILDCARD), EmptyTree, Literal(false))))))
- }
- CaseDef(pat1, EmptyTree, rhs1)
+ val rhs1 =
+ if (valeq) rhs
+ else matchVarPattern(pat1) match {
+ case Some(_) =>
+ rhs
+ case None =>
+ atPos(pos) {
+ Apply(
+ Select(rhs, nme.filter),
+ List(
+ makeVisitor(
+ List(
+ CaseDef(pat1.duplicate, EmptyTree, Literal(true)),
+ CaseDef(Ident(nme.WILDCARD), EmptyTree, Literal(false))),
+ nme.CHECK_IF_REFUTABLE_STRING
+ )))
+ }
+ }
+ if (valeq) ValEq(pos, pat1, rhs1) else ValFrom(pos, pat1, rhs1)
}
+ abstract class Enumerator
+ case class ValFrom(pos: int, pat: Tree, rhs: Tree) extends Enumerator
+ case class ValEq(pos: int, pat: Tree, rhs: Tree) extends Enumerator
+ case class Filter(test: Tree) extends Enumerator
+
/** Create tree for for-comprehension <for (enums) do body> or
* <for (enums) yield body> where mapName and flatMapName are chosen
* corresponding to whether this is a for-do or a for-yield.
+ * The creation performs the following rewrite rules:
+ *
+ * 1.
+ *
+ * for (val P <- G) E ==> G.foreach (P => E)
+ *
+ * Here and in the following (P => E) is interpreted as the function (P => E)
+ * if P is a a variable pattern and as the partial function { case P => E } otherwise.
+ *
+ * 2.
+ *
+ * for (val P <- G) yield E ==> G.map (P => E)
+ *
+ * 3.
+ *
+ * for (val P_1 <- G_1; val P_2 <- G_2; ...) ...
+ * ==>
+ * G_1.flatMap (P_1 => for (val P_2 <- G_2; ...) ...)
+ *
+ * 4.
+ *
+ * for (val P <- G; E; ...) ...
+ * =>
+ * for (val P <- G.filter (P => E); ...) ...
+ *
+ * 5. For N < MaxTupleArity:
+ *
+ * for (val P_1 <- G; val P_2 = E_2; val P_N = E_N; ...)
+ * ==>
+ * for (val TupleN(P_1, P_2, ... P_N) <-
+ * for (val x_1 @ P_1 <- G) yield {
+ * val x_2 @ P_2 = E_2
+ * ...
+ * val x_N & P_N = E_N
+ * TupleN(x_1, ..., x_N)
+ * } ...)
+ *
+ * If any of the P_i are variable patterns, the corresponding `x_i @ P_i' is not generated
+ * and the variable constituting P_i is used instead of x_i
+ *
*/
- private def makeFor(mapName: Name, flatMapName: Name, enums: List[Tree], body: Tree): Tree = {
+ private def makeFor(mapName: Name, flatMapName: Name, enums: List[Enumerator], body: Tree): Tree = {
- def makeCont(pat: Tree, body: Tree): Tree = matchVarPattern(pat) match {
+ def makeClosure(pat: Tree, body: Tree): Tree = matchVarPattern(pat) match {
case Some(Pair(name, tpt)) =>
Function(List(ValDef(Modifiers(PARAM), name, tpt, EmptyTree)), body)
case None =>
makeVisitor(List(CaseDef(pat, EmptyTree, body)))
}
- def makeBind(meth: Name, qual: Tree, pat: Tree, body: Tree): Tree =
- Apply(Select(qual, meth), List(makeCont(pat, body)));
-
- atPos(enums.head.pos) {
- enums match {
- case CaseDef(pat, g, rhs) :: Nil =>
- makeBind(mapName, rhs, pat, body)
- case CaseDef(pat, g, rhs) :: (rest @ (CaseDef(_, _, _) :: _)) =>
- makeBind(flatMapName, rhs, pat, makeFor(mapName, flatMapName, rest, body))
- case CaseDef(pat, g, rhs) :: test :: rest =>
- makeFor(mapName, flatMapName,
- CaseDef(pat, g, makeBind(nme.filter, rhs, pat.duplicate, test)) :: rest,
- body)
- }
+ def makeCombination(meth: Name, qual: Tree, pat: Tree, body: Tree): Tree =
+ Apply(Select(qual, meth), List(makeClosure(pat, body)));
+
+ def patternVar(pat: Tree): Option[Name] = pat match {
+ case Bind(name, _) => Some(name)
+ case _ => None
+ }
+
+ def makeBind(pat: Tree): Tree = pat match {
+ case Bind(_, _) => pat
+ case _ => Bind(freshName(), pat)
+ }
+
+ def makeValue(pat: Tree): Tree = pat match {
+ case Bind(name, _) => Ident(name)
+ }
+
+ enums match {
+ case ValFrom(pos, pat, rhs) :: Nil =>
+ atPos(pos) {
+ makeCombination(mapName, rhs, pat, body)
+ }
+ case ValFrom(pos, pat, rhs) :: (rest @ (ValFrom(_, _, _) :: _)) =>
+ atPos(pos) {
+ makeCombination(flatMapName, rhs, pat, makeFor(mapName, flatMapName, rest, body))
+ }
+ case ValFrom(pos, pat, rhs) :: Filter(test) :: rest =>
+ makeFor(mapName, flatMapName,
+ ValFrom(pos, pat, makeCombination(nme.filter, rhs, pat.duplicate, test)) :: rest,
+ body)
+ case ValFrom(pos, pat, rhs) :: rest =>
+ val valeqs = rest.take(definitions.MaxTupleArity - 1).takeWhile(.isInstanceOf[ValEq]);
+ assert(!valeqs.isEmpty);
+ val rest1 = rest.drop(valeqs.length);
+ val pats = valeqs map { case ValEq(_, pat, _) => pat }
+ val rhss = valeqs map { case ValEq(_, _, rhs) => rhs }
+ val defpats = pats map (x => makeBind(x.duplicate))
+ val pdefs = List.flatten(List.map2(defpats, rhss)(makePatDef))
+ val ids = (pat :: defpats) map makeValue
+ val rhs1 = makeForYield(
+ List(ValFrom(pos, makeBind(pat.duplicate), rhs)),
+ Block(pdefs, makeTupleTerm(ids)))
+ makeFor(mapName, flatMapName, ValFrom(pos, makeTuple(pat :: pats, true), rhs1) :: rest1, body)
}
}
/** Create tree for for-do comprehension <for (enums) body> */
- def makeFor(enums: List[Tree], body: Tree): Tree =
+ def makeFor(enums: List[Enumerator], body: Tree): Tree =
makeFor(nme.foreach, nme.foreach, enums, body);
/** Create tree for for-yield comprehension <for (enums) yield body> */
- def makeForYield(enums: List[Tree], body: Tree): Tree =
+ def makeForYield(enums: List[Enumerator], body: Tree): Tree =
makeFor(nme.map, nme.flatMap, enums, body);
/** Create tree for a pattern alternative */
@@ -216,8 +309,11 @@ abstract class TreeBuilder {
makeAlternative(List(p, Sequence(List())));
/** Create visitor <x => x match cases> */
- def makeVisitor(cases: List[CaseDef]): Tree = {
- val x = freshName();
+ def makeVisitor(cases: List[CaseDef]): Tree = makeVisitor(cases, "x$");
+
+ /** Create visitor <x => x match cases> */
+ def makeVisitor(cases: List[CaseDef], prefix: String): Tree = {
+ val x = freshName(prefix);
Function(List(ValDef(Modifiers(PARAM | SYNTHETIC), x, TypeTree(), EmptyTree)), Match(Ident(x), cases))
}
@@ -226,6 +322,9 @@ abstract class TreeBuilder {
CaseDef(patvarTransformer.transform(pat), guard, rhs);
}
+ /** Create tree for pattern definition <val pat0 = rhs> */
+ def makePatDef(pat: Tree, rhs: Tree): List[Tree] = makePatDef(Modifiers(0), pat, rhs)
+
/** Create tree for pattern definition <mods val pat0 = rhs> */
def makePatDef(mods: Modifiers, pat: Tree, rhs: Tree): List[Tree] = matchVarPattern(pat) match {
case Some(Pair(name, tpt)) =>
@@ -246,7 +345,7 @@ abstract class TreeBuilder {
val pat1 = patvarTransformer.transform(pat);
val vars = getVariables(pat1);
val matchExpr = atPos(pat1.pos){
- Match(rhs, List(CaseDef(pat1, EmptyTree, mkTuple(vars map Ident))))
+ Match(rhs, List(CaseDef(pat1, EmptyTree, makeTupleTerm(vars map Ident))))
}
vars match {
case List() =>
diff --git a/src/compiler/scala/tools/nsc/symtab/StdNames.scala b/src/compiler/scala/tools/nsc/symtab/StdNames.scala
index 26e941cd87..0e18f0e739 100644
--- a/src/compiler/scala/tools/nsc/symtab/StdNames.scala
+++ b/src/compiler/scala/tools/nsc/symtab/StdNames.scala
@@ -68,6 +68,7 @@ mixin class StdNames requires SymbolTable {
val SUPER_PREFIX_STRING = "super$";
val EXPAND_SEPARATOR_STRING = "$$";
val TUPLE_FIELD_PREFIX_STRING = "_";
+ val CHECK_IF_REFUTABLE_STRING = "check$ifrefutable$";
def LOCAL(clazz: Symbol) = newTermName(LOCALDUMMY_PREFIX_STRING + clazz.name);
def TUPLE_FIELD(index: int) = newTermName(TUPLE_FIELD_PREFIX_STRING + index);
diff --git a/src/compiler/scala/tools/nsc/typechecker/RefChecks.scala b/src/compiler/scala/tools/nsc/typechecker/RefChecks.scala
index 7057638749..5c43c640ec 100644
--- a/src/compiler/scala/tools/nsc/typechecker/RefChecks.scala
+++ b/src/compiler/scala/tools/nsc/typechecker/RefChecks.scala
@@ -504,6 +504,28 @@ abstract class RefChecks extends InfoTransform {
case ex: TypeError => unit.error(tree.pos, ex.getMessage());
}
+ def isIrrefutable(pat: Tree, seltpe: Type): boolean = {
+ val result = pat match {
+ case Apply(_, args) =>
+ val clazz = pat.tpe.symbol;
+ clazz == seltpe.symbol &&
+ clazz.isClass && (clazz hasFlag CASE) &&
+ List.forall2(
+ args,
+ clazz.primaryConstructor.tpe.asSeenFrom(seltpe, clazz).paramTypes)(isIrrefutable)
+ case Typed(pat, tpt) =>
+ seltpe <:< tpt.tpe
+ case Ident(nme.WILDCARD) =>
+ true
+ case Bind(_, pat) =>
+ isIrrefutable(pat, seltpe)
+ case _ =>
+ false
+ }
+ //System.out.println("is irefutable? " + pat + ":" + pat.tpe + " against " + seltpe + ": " + result);//DEBUG
+ result
+ }
+
val savedLocalTyper = localTyper;
val sym = tree.symbol;
var result = tree;
@@ -541,6 +563,15 @@ abstract class RefChecks extends InfoTransform {
checkBounds(fn.tpe.typeParams, args map (.tpe));
if (sym.isSourceMethod && sym.hasFlag(CASE)) result = toConstructor;
+ case Apply(
+ Select(qual, nme.filter),
+ List(Function(
+ List(ValDef(_, pname, tpt, _)),
+ Match(_, CaseDef(pat1, _, _) :: _))))
+ if ((pname startsWith nme.CHECK_IF_REFUTABLE_STRING) &&
+ isIrrefutable(pat1, tpt.tpe)) =>
+ result = qual
+
case New(tpt) =>
enterReference(tree.pos, tpt.tpe.symbol);
diff --git a/test/files/run/forvaleq.check b/test/files/run/forvaleq.check
index 141ac1ebfe..27f5269789 100644
--- a/test/files/run/forvaleq.check
+++ b/test/files/run/forvaleq.check
@@ -1,4 +1,5 @@
-List(2, 6, 10, 14, 18, 22, 24, 26, 28, 30, 32, 34, 36, 38)
-List(2, 6, 10, 14, 18, 22, 24, 26, 28, 30, 32, 34, 36, 38)
-List(2, 6, 10, 14, 18, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2)
-called 21 times
+List(2,6,10,14,18,20,22,24,26,28,30,32,34,36,38)
+List(2,6,10,14,18,2,2,2,2,2,2,2,2,2,2)
+List(2,6,10,14,18,20,22,24,26,28,30,32,34,36,38)
+List(2,6,10,14,18,2,2,2,2,2,2,2,2,2,2)
+called 20 times
diff --git a/test/files/run/forvaleq.scala b/test/files/run/forvaleq.scala
index 47bc545d47..a30c659ecc 100644
--- a/test/files/run/forvaleq.scala
+++ b/test/files/run/forvaleq.scala
@@ -7,7 +7,6 @@ import scala.{List=>L}
object Test {
// redefine some symbols to make it extra hard
class List
- class Pair
class Tuple2
def List[A](as:A*) = 5
@@ -33,6 +32,20 @@ object Test {
}
{
+ // a test case with patterns
+
+ val input = L.range(0,20)
+ val oddFirstTimesTwo =
+ for{val x <- input
+ val xf = firstDigit(x)
+ val yf = x - firstDigit(x) / 10
+ val Pair(a, b) = Pair(xf - yf, xf + yf)
+ xf % 2 == 1}
+ yield a + b
+ Console.println(oddFirstTimesTwo)
+ }
+
+ {
// make sure it works on non-Ls
// val input: Queue = Queue.Empty[int].incl(L.range(0,20))
@@ -42,7 +55,7 @@ object Test {
val xf = firstDigit(x)
xf % 2 == 1}
yield x*2
- Console.println(oddFirstTimesTwo.toL)
+ Console.println(oddFirstTimesTwo.toList)
}
{