From c19726b180758dc3b9d4dd070dff626fce5836d7 Mon Sep 17 00:00:00 2001 From: Martin Odersky Date: Fri, 10 Feb 2006 14:49:29 +0000 Subject: --- .../scala/tools/nsc/ast/parser/Parsers.scala | 18 ++- .../scala/tools/nsc/ast/parser/TreeBuilder.scala | 169 ++++++++++++++++----- src/compiler/scala/tools/nsc/symtab/StdNames.scala | 1 + .../scala/tools/nsc/typechecker/RefChecks.scala | 31 ++++ test/files/run/forvaleq.check | 9 +- test/files/run/forvaleq.scala | 17 ++- 6 files changed, 197 insertions(+), 48 deletions(-) diff --git a/src/compiler/scala/tools/nsc/ast/parser/Parsers.scala b/src/compiler/scala/tools/nsc/ast/parser/Parsers.scala index 5f74667dc4..329b2604c7 100644 --- a/src/compiler/scala/tools/nsc/ast/parser/Parsers.scala +++ b/src/compiler/scala/tools/nsc/ast/parser/Parsers.scala @@ -931,21 +931,25 @@ mixin class Parsers requires SyntaxAnalyzer { * Enumerator ::= Generator * | Expr */ - def enumerators(): List[Tree] = { - val enums = new ListBuffer[Tree] + generator(); + def enumerators(): List[Enumerator] = { + val enums = new ListBuffer[Enumerator] + generator(false); while (in.token == SEMI || in.token == NEWLINE) { in.nextToken(); - enums += (if (in.token == VAL) generator() else expr()) + enums += (if (in.token == VAL) generator(true) else Filter(expr())) } enums.toList } /** Generator ::= val Pattern1 `<-' Expr */ - def generator(): Tree = - atPos(accept(VAL)) { - makeGenerator(pattern1(false), { accept(LARROW); expr() }) - } + def generator(eqOK: boolean): Enumerator = { + val pos = accept(VAL); + val pat = pattern1(false); + val tok = in.token; + if (tok == EQUALS && eqOK) in.nextToken() + else accept(LARROW); + makeGenerator(pos, pat, tok == EQUALS, expr) + } //////// PATTERNS //////////////////////////////////////////////////////////// diff --git a/src/compiler/scala/tools/nsc/ast/parser/TreeBuilder.scala b/src/compiler/scala/tools/nsc/ast/parser/TreeBuilder.scala index 611defb2b0..8a91c603d6 100644 --- a/src/compiler/scala/tools/nsc/ast/parser/TreeBuilder.scala +++ b/src/compiler/scala/tools/nsc/ast/parser/TreeBuilder.scala @@ -18,6 +18,10 @@ abstract class TreeBuilder { def freshName(): Name = freshName("x$"); + /** Convert all occurrences of (lower-case) variables in a pattern as follows: + * x becomes x @ _ + * x: T becomes x @ (_: T) + */ private object patvarTransformer extends Transformer { override def transform(tree: Tree): Tree = tree match { case Ident(name) if (treeInfo.isVariableName(name) && name != nme.WILDCARD) => @@ -57,10 +61,17 @@ abstract class TreeBuilder { getvarTraverser.buf.toList } - private def mkTuple(trees: List[Tree]): Tree = trees match { + private def makeTuple(trees: List[Tree], isType: boolean): Tree = { + val tupString = "Tuple" + trees.length; + Apply( + Select(Ident(nme.scala_), if (isType) newTypeName(tupString) else newTermName(tupString)), + trees) + } + + private def makeTupleTerm(trees: List[Tree]): Tree = trees match { case List() => Literal(()) case List(tree) => tree - case _ => Apply(Select(Ident(nme.scala_), newTermName("Tuple" + trees.length)), trees) + case _ => makeTuple(trees, false) } /** If tree is a variable pattern, return Some("its name and type"). @@ -136,57 +147,139 @@ abstract class TreeBuilder { } /** Create tree for for-comprehension generator */ - def makeGenerator(pat: Tree, rhs: Tree): Tree = { + def makeGenerator(pos: int, pat: Tree, valeq: boolean, rhs: Tree): Enumerator = { val pat1 = patvarTransformer.transform(pat); - val rhs1 = matchVarPattern(pat1) match { - case Some(_) => - rhs - case None => - Apply( - Select(rhs, nme.filter), - List(makeVisitor(List( - CaseDef(pat1.duplicate, EmptyTree, Literal(true)), - CaseDef(Ident(nme.WILDCARD), EmptyTree, Literal(false)))))) - } - CaseDef(pat1, EmptyTree, rhs1) + val rhs1 = + if (valeq) rhs + else matchVarPattern(pat1) match { + case Some(_) => + rhs + case None => + atPos(pos) { + Apply( + Select(rhs, nme.filter), + List( + makeVisitor( + List( + CaseDef(pat1.duplicate, EmptyTree, Literal(true)), + CaseDef(Ident(nme.WILDCARD), EmptyTree, Literal(false))), + nme.CHECK_IF_REFUTABLE_STRING + ))) + } + } + if (valeq) ValEq(pos, pat1, rhs1) else ValFrom(pos, pat1, rhs1) } + abstract class Enumerator + case class ValFrom(pos: int, pat: Tree, rhs: Tree) extends Enumerator + case class ValEq(pos: int, pat: Tree, rhs: Tree) extends Enumerator + case class Filter(test: Tree) extends Enumerator + /** Create tree for for-comprehension or * where mapName and flatMapName are chosen * corresponding to whether this is a for-do or a for-yield. + * The creation performs the following rewrite rules: + * + * 1. + * + * for (val P <- G) E ==> G.foreach (P => E) + * + * Here and in the following (P => E) is interpreted as the function (P => E) + * if P is a a variable pattern and as the partial function { case P => E } otherwise. + * + * 2. + * + * for (val P <- G) yield E ==> G.map (P => E) + * + * 3. + * + * for (val P_1 <- G_1; val P_2 <- G_2; ...) ... + * ==> + * G_1.flatMap (P_1 => for (val P_2 <- G_2; ...) ...) + * + * 4. + * + * for (val P <- G; E; ...) ... + * => + * for (val P <- G.filter (P => E); ...) ... + * + * 5. For N < MaxTupleArity: + * + * for (val P_1 <- G; val P_2 = E_2; val P_N = E_N; ...) + * ==> + * for (val TupleN(P_1, P_2, ... P_N) <- + * for (val x_1 @ P_1 <- G) yield { + * val x_2 @ P_2 = E_2 + * ... + * val x_N & P_N = E_N + * TupleN(x_1, ..., x_N) + * } ...) + * + * If any of the P_i are variable patterns, the corresponding `x_i @ P_i' is not generated + * and the variable constituting P_i is used instead of x_i + * */ - private def makeFor(mapName: Name, flatMapName: Name, enums: List[Tree], body: Tree): Tree = { + private def makeFor(mapName: Name, flatMapName: Name, enums: List[Enumerator], body: Tree): Tree = { - def makeCont(pat: Tree, body: Tree): Tree = matchVarPattern(pat) match { + def makeClosure(pat: Tree, body: Tree): Tree = matchVarPattern(pat) match { case Some(Pair(name, tpt)) => Function(List(ValDef(Modifiers(PARAM), name, tpt, EmptyTree)), body) case None => makeVisitor(List(CaseDef(pat, EmptyTree, body))) } - def makeBind(meth: Name, qual: Tree, pat: Tree, body: Tree): Tree = - Apply(Select(qual, meth), List(makeCont(pat, body))); - - atPos(enums.head.pos) { - enums match { - case CaseDef(pat, g, rhs) :: Nil => - makeBind(mapName, rhs, pat, body) - case CaseDef(pat, g, rhs) :: (rest @ (CaseDef(_, _, _) :: _)) => - makeBind(flatMapName, rhs, pat, makeFor(mapName, flatMapName, rest, body)) - case CaseDef(pat, g, rhs) :: test :: rest => - makeFor(mapName, flatMapName, - CaseDef(pat, g, makeBind(nme.filter, rhs, pat.duplicate, test)) :: rest, - body) - } + def makeCombination(meth: Name, qual: Tree, pat: Tree, body: Tree): Tree = + Apply(Select(qual, meth), List(makeClosure(pat, body))); + + def patternVar(pat: Tree): Option[Name] = pat match { + case Bind(name, _) => Some(name) + case _ => None + } + + def makeBind(pat: Tree): Tree = pat match { + case Bind(_, _) => pat + case _ => Bind(freshName(), pat) + } + + def makeValue(pat: Tree): Tree = pat match { + case Bind(name, _) => Ident(name) + } + + enums match { + case ValFrom(pos, pat, rhs) :: Nil => + atPos(pos) { + makeCombination(mapName, rhs, pat, body) + } + case ValFrom(pos, pat, rhs) :: (rest @ (ValFrom(_, _, _) :: _)) => + atPos(pos) { + makeCombination(flatMapName, rhs, pat, makeFor(mapName, flatMapName, rest, body)) + } + case ValFrom(pos, pat, rhs) :: Filter(test) :: rest => + makeFor(mapName, flatMapName, + ValFrom(pos, pat, makeCombination(nme.filter, rhs, pat.duplicate, test)) :: rest, + body) + case ValFrom(pos, pat, rhs) :: rest => + val valeqs = rest.take(definitions.MaxTupleArity - 1).takeWhile(.isInstanceOf[ValEq]); + assert(!valeqs.isEmpty); + val rest1 = rest.drop(valeqs.length); + val pats = valeqs map { case ValEq(_, pat, _) => pat } + val rhss = valeqs map { case ValEq(_, _, rhs) => rhs } + val defpats = pats map (x => makeBind(x.duplicate)) + val pdefs = List.flatten(List.map2(defpats, rhss)(makePatDef)) + val ids = (pat :: defpats) map makeValue + val rhs1 = makeForYield( + List(ValFrom(pos, makeBind(pat.duplicate), rhs)), + Block(pdefs, makeTupleTerm(ids))) + makeFor(mapName, flatMapName, ValFrom(pos, makeTuple(pat :: pats, true), rhs1) :: rest1, body) } } /** Create tree for for-do comprehension */ - def makeFor(enums: List[Tree], body: Tree): Tree = + def makeFor(enums: List[Enumerator], body: Tree): Tree = makeFor(nme.foreach, nme.foreach, enums, body); /** Create tree for for-yield comprehension */ - def makeForYield(enums: List[Tree], body: Tree): Tree = + def makeForYield(enums: List[Enumerator], body: Tree): Tree = makeFor(nme.map, nme.flatMap, enums, body); /** Create tree for a pattern alternative */ @@ -216,8 +309,11 @@ abstract class TreeBuilder { makeAlternative(List(p, Sequence(List()))); /** Create visitor x match cases> */ - def makeVisitor(cases: List[CaseDef]): Tree = { - val x = freshName(); + def makeVisitor(cases: List[CaseDef]): Tree = makeVisitor(cases, "x$"); + + /** Create visitor x match cases> */ + def makeVisitor(cases: List[CaseDef], prefix: String): Tree = { + val x = freshName(prefix); Function(List(ValDef(Modifiers(PARAM | SYNTHETIC), x, TypeTree(), EmptyTree)), Match(Ident(x), cases)) } @@ -226,6 +322,9 @@ abstract class TreeBuilder { CaseDef(patvarTransformer.transform(pat), guard, rhs); } + /** Create tree for pattern definition */ + def makePatDef(pat: Tree, rhs: Tree): List[Tree] = makePatDef(Modifiers(0), pat, rhs) + /** Create tree for pattern definition */ def makePatDef(mods: Modifiers, pat: Tree, rhs: Tree): List[Tree] = matchVarPattern(pat) match { case Some(Pair(name, tpt)) => @@ -246,7 +345,7 @@ abstract class TreeBuilder { val pat1 = patvarTransformer.transform(pat); val vars = getVariables(pat1); val matchExpr = atPos(pat1.pos){ - Match(rhs, List(CaseDef(pat1, EmptyTree, mkTuple(vars map Ident)))) + Match(rhs, List(CaseDef(pat1, EmptyTree, makeTupleTerm(vars map Ident)))) } vars match { case List() => diff --git a/src/compiler/scala/tools/nsc/symtab/StdNames.scala b/src/compiler/scala/tools/nsc/symtab/StdNames.scala index 26e941cd87..0e18f0e739 100644 --- a/src/compiler/scala/tools/nsc/symtab/StdNames.scala +++ b/src/compiler/scala/tools/nsc/symtab/StdNames.scala @@ -68,6 +68,7 @@ mixin class StdNames requires SymbolTable { val SUPER_PREFIX_STRING = "super$"; val EXPAND_SEPARATOR_STRING = "$$"; val TUPLE_FIELD_PREFIX_STRING = "_"; + val CHECK_IF_REFUTABLE_STRING = "check$ifrefutable$"; def LOCAL(clazz: Symbol) = newTermName(LOCALDUMMY_PREFIX_STRING + clazz.name); def TUPLE_FIELD(index: int) = newTermName(TUPLE_FIELD_PREFIX_STRING + index); diff --git a/src/compiler/scala/tools/nsc/typechecker/RefChecks.scala b/src/compiler/scala/tools/nsc/typechecker/RefChecks.scala index 7057638749..5c43c640ec 100644 --- a/src/compiler/scala/tools/nsc/typechecker/RefChecks.scala +++ b/src/compiler/scala/tools/nsc/typechecker/RefChecks.scala @@ -504,6 +504,28 @@ abstract class RefChecks extends InfoTransform { case ex: TypeError => unit.error(tree.pos, ex.getMessage()); } + def isIrrefutable(pat: Tree, seltpe: Type): boolean = { + val result = pat match { + case Apply(_, args) => + val clazz = pat.tpe.symbol; + clazz == seltpe.symbol && + clazz.isClass && (clazz hasFlag CASE) && + List.forall2( + args, + clazz.primaryConstructor.tpe.asSeenFrom(seltpe, clazz).paramTypes)(isIrrefutable) + case Typed(pat, tpt) => + seltpe <:< tpt.tpe + case Ident(nme.WILDCARD) => + true + case Bind(_, pat) => + isIrrefutable(pat, seltpe) + case _ => + false + } + //System.out.println("is irefutable? " + pat + ":" + pat.tpe + " against " + seltpe + ": " + result);//DEBUG + result + } + val savedLocalTyper = localTyper; val sym = tree.symbol; var result = tree; @@ -541,6 +563,15 @@ abstract class RefChecks extends InfoTransform { checkBounds(fn.tpe.typeParams, args map (.tpe)); if (sym.isSourceMethod && sym.hasFlag(CASE)) result = toConstructor; + case Apply( + Select(qual, nme.filter), + List(Function( + List(ValDef(_, pname, tpt, _)), + Match(_, CaseDef(pat1, _, _) :: _)))) + if ((pname startsWith nme.CHECK_IF_REFUTABLE_STRING) && + isIrrefutable(pat1, tpt.tpe)) => + result = qual + case New(tpt) => enterReference(tree.pos, tpt.tpe.symbol); diff --git a/test/files/run/forvaleq.check b/test/files/run/forvaleq.check index 141ac1ebfe..27f5269789 100644 --- a/test/files/run/forvaleq.check +++ b/test/files/run/forvaleq.check @@ -1,4 +1,5 @@ -List(2, 6, 10, 14, 18, 22, 24, 26, 28, 30, 32, 34, 36, 38) -List(2, 6, 10, 14, 18, 22, 24, 26, 28, 30, 32, 34, 36, 38) -List(2, 6, 10, 14, 18, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2) -called 21 times +List(2,6,10,14,18,20,22,24,26,28,30,32,34,36,38) +List(2,6,10,14,18,2,2,2,2,2,2,2,2,2,2) +List(2,6,10,14,18,20,22,24,26,28,30,32,34,36,38) +List(2,6,10,14,18,2,2,2,2,2,2,2,2,2,2) +called 20 times diff --git a/test/files/run/forvaleq.scala b/test/files/run/forvaleq.scala index 47bc545d47..a30c659ecc 100644 --- a/test/files/run/forvaleq.scala +++ b/test/files/run/forvaleq.scala @@ -7,7 +7,6 @@ import scala.{List=>L} object Test { // redefine some symbols to make it extra hard class List - class Pair class Tuple2 def List[A](as:A*) = 5 @@ -32,6 +31,20 @@ object Test { Console.println(oddFirstTimesTwo) } + { + // a test case with patterns + + val input = L.range(0,20) + val oddFirstTimesTwo = + for{val x <- input + val xf = firstDigit(x) + val yf = x - firstDigit(x) / 10 + val Pair(a, b) = Pair(xf - yf, xf + yf) + xf % 2 == 1} + yield a + b + Console.println(oddFirstTimesTwo) + } + { // make sure it works on non-Ls @@ -42,7 +55,7 @@ object Test { val xf = firstDigit(x) xf % 2 == 1} yield x*2 - Console.println(oddFirstTimesTwo.toL) + Console.println(oddFirstTimesTwo.toList) } { -- cgit v1.2.3