summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDen Shabalin <den.shabalin@gmail.com>2013-10-31 13:41:53 +0100
committerDen Shabalin <den.shabalin@gmail.com>2013-11-12 14:04:42 +0100
commita4a3ab0d722412b9ecf267b178bb866087867cf9 (patch)
tree2f3a55eeb5c8a8e05fbf302f86ed26201b7b9c7b
parentd73680557d4037bed69bc0ce982566f3915361c3 (diff)
downloadscala-a4a3ab0d722412b9ecf267b178bb866087867cf9.tar.gz
scala-a4a3ab0d722412b9ecf267b178bb866087867cf9.tar.bz2
scala-a4a3ab0d722412b9ecf267b178bb866087867cf9.zip
implement inverse transformation to mkFor
This effectively reconstructs a sequence of enumerators and body from the tree produced by mkFor. This lets to define bi-directional SyntacticFor and SyntacticForYield constructors/extractors to work with for loops. Correctness of the transformation is tested by a scalacheck test that generates a sequence of random enumerators, sugars them into maps/flatMaps/foreach/withFilter calls and reconstructs them back.
-rw-r--r--src/reflect/scala/reflect/api/BuildUtils.scala8
-rw-r--r--src/reflect/scala/reflect/internal/BuildUtils.scala141
-rw-r--r--src/reflect/scala/reflect/internal/TreeGen.scala26
-rw-r--r--test/files/scalacheck/quasiquotes/ForProps.scala38
-rw-r--r--test/files/scalacheck/quasiquotes/Test.scala1
5 files changed, 205 insertions, 9 deletions
diff --git a/src/reflect/scala/reflect/api/BuildUtils.scala b/src/reflect/scala/reflect/api/BuildUtils.scala
index 4d2a1dcc30..cf05aefe72 100644
--- a/src/reflect/scala/reflect/api/BuildUtils.scala
+++ b/src/reflect/scala/reflect/api/BuildUtils.scala
@@ -244,5 +244,13 @@ private[reflect] trait BuildUtils { self: Universe =>
def apply(test: Tree): Tree
def unapply(tree: Tree): Option[(Tree)]
}
+
+ val SyntacticFor: SyntacticForExtractor
+ val SyntacticForYield: SyntacticForExtractor
+
+ trait SyntacticForExtractor {
+ def apply(enums: List[Tree], body: Tree): Tree
+ def unapply(tree: Tree): Option[(List[Tree], Tree)]
+ }
}
}
diff --git a/src/reflect/scala/reflect/internal/BuildUtils.scala b/src/reflect/scala/reflect/internal/BuildUtils.scala
index db4e5685fd..8fc1869dd2 100644
--- a/src/reflect/scala/reflect/internal/BuildUtils.scala
+++ b/src/reflect/scala/reflect/internal/BuildUtils.scala
@@ -463,6 +463,48 @@ trait BuildUtils { self: SymbolTable =>
def unapply(tree: Tree): Option[Tree] = gen.Filter.unapply(tree)
}
+ // abstract over possible alternative representations of no type in valdef
+ protected object EmptyTypTree {
+ def unapply(tree: Tree): Boolean = tree match {
+ case EmptyTree => true
+ case tt: TypeTree if (tt.original == null || tt.original.isEmpty) => true
+ case _ => false
+ }
+ }
+
+ // match a sequence of desugared `val $pat = $value`
+ protected object UnPatSeq {
+ def unapply(trees: List[Tree]): Option[List[(Tree, Tree)]] = trees match {
+ case Nil => Some(Nil)
+ // case q"$mods val ${_}: ${_} = ${MaybeUnchecked(value)} match { case $pat => (..$ids) }" :: tail
+ case ValDef(mods, _, _, Match(MaybeUnchecked(value), CaseDef(pat, EmptyTree, SyntacticTuple(ids)) :: Nil)) :: tail
+ if mods.hasFlag(SYNTHETIC) && mods.hasFlag(ARTIFACT) =>
+ tail.drop(ids.length) match {
+ case UnPatSeq(rest) => Some((pat, value) :: rest)
+ case _ => None
+ }
+ // case q"${_} val $name1: ${_} = ${MaybeUnchecked(value)} match { case $pat => ${Ident(name2)} }" :: UnPatSeq(rest)
+ case ValDef(_, name1, _, Match(MaybeUnchecked(value), CaseDef(pat, EmptyTree, Ident(name2)) :: Nil)) :: UnPatSeq(rest)
+ if name1 == name2 =>
+ Some((pat, value) :: rest)
+ // case q"${_} val $name: ${EmptyTypTree()} = $value" :: UnPatSeq(rest) =>
+ case ValDef(_, name, EmptyTypTree(), value) :: UnPatSeq(rest) =>
+ Some((Bind(name, self.Ident(nme.WILDCARD)), value) :: rest)
+ // case q"${_} val $name: $tpt = $value" :: UnPatSeq(rest) =>
+ case ValDef(_, name, tpt, value) :: UnPatSeq(rest) =>
+ Some((Bind(name, Typed(self.Ident(nme.WILDCARD), tpt)), value) :: rest)
+ case _ => None
+ }
+ }
+
+ // match a sequence of desugared `val $pat = $value` with a tuple in the end
+ protected object UnPatSeqWithRes {
+ def unapply(tree: Tree): Option[(List[(Tree, Tree)], List[Tree])] = tree match {
+ case SyntacticBlock(UnPatSeq(trees) :+ SyntacticTuple(elems)) => Some((trees, elems))
+ case _ => None
+ }
+ }
+
// undo gen.mkSyntheticParam
protected object UnSyntheticParam {
def unapply(tree: Tree): Option[TermName] = tree match {
@@ -483,6 +525,20 @@ trait BuildUtils { self: SymbolTable =>
}
}
+ // undo gen.mkFor:makeClosure
+ protected object UnClosure {
+ def unapply(tree: Tree): Option[(Tree, Tree)] = tree match {
+ case Function(ValDef(Modifiers(PARAM, _, _), name, tpt, EmptyTree) :: Nil, body) =>
+ tpt match {
+ case EmptyTypTree() => Some((Bind(name, self.Ident(nme.WILDCARD)), body))
+ case _ => Some((Bind(name, Typed(self.Ident(nme.WILDCARD), tpt)), body))
+ }
+ case UnVisitor(_, CaseDef(pat, EmptyTree, body) :: Nil) =>
+ Some((pat, body))
+ case _ => None
+ }
+ }
+
// match call to either withFilter or filter
protected object FilterCall {
def unapply(tree: Tree): Option[(Tree,Tree)] = tree match {
@@ -492,6 +548,18 @@ trait BuildUtils { self: SymbolTable =>
}
}
+ // transform a chain of withFilter calls into a sequence of for filters
+ protected object UnFilter {
+ def unapply(tree: Tree): Some[(Tree, List[Tree])] = tree match {
+ case UnCheckIfRefutable(_, _) =>
+ Some((tree, Nil))
+ case FilterCall(UnFilter(rhs, rest), UnClosure(_, test)) =>
+ Some((rhs, rest :+ SyntacticFilter(test)))
+ case _ =>
+ Some((tree, Nil))
+ }
+ }
+
// undo gen.mkCheckIfRefutable
protected object UnCheckIfRefutable {
def unapply(tree: Tree): Option[(Tree, Tree)] = tree match {
@@ -504,6 +572,79 @@ trait BuildUtils { self: SymbolTable =>
}
}
+ // undo gen.mkFor:makeCombination accounting for possible extra implicit argument
+ protected class UnForCombination(name: TermName) {
+ def unapply(tree: Tree) = tree match {
+ case SyntacticApplied(SyntacticTypeApplied(sel @ Select(lhs, meth), _), (f :: Nil) :: Nil)
+ if name == meth && sel.hasAttachment[ForAttachment.type] =>
+ Some(lhs, f)
+ case SyntacticApplied(SyntacticTypeApplied(sel @ Select(lhs, meth), _), (f :: Nil) :: _ :: Nil)
+ if name == meth && sel.hasAttachment[ForAttachment.type] =>
+ Some(lhs, f)
+ case _ => None
+ }
+ }
+ protected object UnMap extends UnForCombination(nme.map)
+ protected object UnForeach extends UnForCombination(nme.foreach)
+ protected object UnFlatMap extends UnForCombination(nme.flatMap)
+
+ // undo desugaring done in gen.mkFor
+ protected object UnFor {
+ def unapply(tree: Tree): Option[(List[Tree], Tree)] = {
+ val interm = tree match {
+ case UnFlatMap(UnFilter(rhs, filters), UnClosure(pat, UnFor(rest, body))) =>
+ Some(((pat, rhs), filters ::: rest, body))
+ case UnForeach(UnFilter(rhs, filters), UnClosure(pat, UnFor(rest, body))) =>
+ Some(((pat, rhs), filters ::: rest, body))
+ case UnMap(UnFilter(rhs, filters), UnClosure(pat, cbody)) =>
+ Some(((pat, rhs), filters, gen.Yield(cbody)))
+ case UnForeach(UnFilter(rhs, filters), UnClosure(pat, cbody)) =>
+ Some(((pat, rhs), filters, cbody))
+ case _ => None
+ }
+ interm.flatMap {
+ case ((Bind(_, SyntacticTuple(_)) | SyntacticTuple(_),
+ UnFor(SyntacticValFrom(pat, rhs) :: innerRest, gen.Yield(UnPatSeqWithRes(pats, elems2)))),
+ outerRest, fbody) =>
+ val valeqs = pats.map { case (pat, rhs) => SyntacticValEq(pat, rhs) }
+ Some((SyntacticValFrom(pat, rhs) :: innerRest ::: valeqs ::: outerRest, fbody))
+ case ((pat, rhs), filters, body) =>
+ Some((SyntacticValFrom(pat, rhs) :: filters, body))
+ }
+ }
+ }
+
+ // check that enumerators are valid
+ protected def mkEnumerators(enums: List[Tree]): List[Tree] = {
+ require(enums.nonEmpty, "enumerators can't be empty")
+ enums.head match {
+ case SyntacticValFrom(_, _) =>
+ case t => throw new IllegalArgumentException(s"$t is not a valid fist enumerator of for loop")
+ }
+ enums.tail.foreach {
+ case SyntacticValEq(_, _) | SyntacticValFrom(_, _) | SyntacticFilter(_) =>
+ case t => throw new IllegalArgumentException(s"$t is not a valid representation of a for loop enumerator")
+ }
+ enums
+ }
+
+ object SyntacticFor extends SyntacticForExtractor {
+ def apply(enums: List[Tree], body: Tree): Tree = gen.mkFor(mkEnumerators(enums), body)
+ def unapply(tree: Tree) = tree match {
+ case UnFor(enums, gen.Yield(body)) => None
+ case UnFor(enums, body) => Some((enums, body))
+ case _ => None
+ }
+ }
+
+ object SyntacticForYield extends SyntacticForExtractor {
+ def apply(enums: List[Tree], body: Tree): Tree = gen.mkFor(mkEnumerators(enums), gen.Yield(body))
+ def unapply(tree: Tree) = tree match {
+ case UnFor(enums, gen.Yield(body)) => Some((enums, body))
+ case _ => None
+ }
+ }
+
// use typetree's original instead of typetree itself
protected object MaybeTypeTreeOriginal {
def unapply(tree: Tree): Some[Tree] = tree match {
diff --git a/src/reflect/scala/reflect/internal/TreeGen.scala b/src/reflect/scala/reflect/internal/TreeGen.scala
index d5eff8a9b8..baa8ac27f6 100644
--- a/src/reflect/scala/reflect/internal/TreeGen.scala
+++ b/src/reflect/scala/reflect/internal/TreeGen.scala
@@ -574,7 +574,8 @@ abstract class TreeGen extends macros.TreeBuilder {
* the limits given by pat and body.
*/
def makeClosure(pos: Position, pat: Tree, body: Tree): Tree = {
- def splitpos = wrappingPos(List(pat, body)).withPoint(pos.point).makeTransparent
+ def wrapped = wrappingPos(List(pat, body))
+ def splitpos = (if (pos != NoPosition) wrapped.withPoint(pos.point) else pos).makeTransparent
matchVarPattern(pat) match {
case Some((name, tpt)) =>
Function(
@@ -590,7 +591,8 @@ abstract class TreeGen extends macros.TreeBuilder {
/* Make an application qual.meth(pat => body) positioned at `pos`.
*/
def makeCombination(pos: Position, meth: TermName, qual: Tree, pat: Tree, body: Tree): Tree =
- Apply(Select(qual, meth) setPos qual.pos, List(makeClosure(pos, pat, body))) setPos pos
+ Apply(Select(qual, meth) setPos qual.pos updateAttachment ForAttachment,
+ List(makeClosure(pos, pat, body))) setPos pos
/* If `pat` is not yet a `Bind` wrap it in one with a fresh name */
def makeBind(pat: Tree): Tree = pat match {
@@ -604,13 +606,15 @@ abstract class TreeGen extends macros.TreeBuilder {
}
/* The position of the closure that starts with generator at position `genpos`. */
- def closurePos(genpos: Position) = {
- val end = body.pos match {
- case NoPosition => genpos.point
- case bodypos => bodypos.end
+ def closurePos(genpos: Position) =
+ if (genpos == NoPosition) NoPosition
+ else {
+ val end = body.pos match {
+ case NoPosition => genpos.point
+ case bodypos => bodypos.end
+ }
+ rangePos(genpos.source, genpos.start, genpos.point, end)
}
- rangePos(genpos.source ,genpos.start, genpos.point, end)
- }
enums match {
case (t @ ValFrom(pat, rhs)) :: Nil =>
@@ -634,10 +638,14 @@ abstract class TreeGen extends macros.TreeBuilder {
List(ValFrom(defpat1, rhs).setPos(t.pos)),
Yield(Block(pdefs, atPos(wrappingPos(ids)) { mkTuple(ids) }) setPos wrappingPos(pdefs)))
val allpats = (pat :: pats) map (_.duplicate)
- val vfrom1 = ValFrom(atPos(wrappingPos(allpats)) { mkTuple(allpats) }, rhs1).setPos(rangePos(t.pos.source, t.pos.start, t.pos.point, rhs1.pos.end))
+ val pos1 =
+ if (t.pos == NoPosition) NoPosition
+ else rangePos(t.pos.source, t.pos.start, t.pos.point, rhs1.pos.end)
+ val vfrom1 = ValFrom(atPos(wrappingPos(allpats)) { mkTuple(allpats) }, rhs1).setPos(pos1)
mkFor(vfrom1 :: rest1, sugarBody)
case _ =>
EmptyTree //may happen for erroneous input
+
}
}
diff --git a/test/files/scalacheck/quasiquotes/ForProps.scala b/test/files/scalacheck/quasiquotes/ForProps.scala
new file mode 100644
index 0000000000..234f4d10eb
--- /dev/null
+++ b/test/files/scalacheck/quasiquotes/ForProps.scala
@@ -0,0 +1,38 @@
+import org.scalacheck._, Prop._, Gen._, Arbitrary._
+import scala.reflect.runtime.universe._, Flag._, build.{Ident => _, _}
+
+object ForProps extends QuasiquoteProperties("for") {
+ case class ForEnums(val value: List[Tree])
+
+ def genSimpleBind: Gen[Bind] =
+ for(name <- genTermName)
+ yield pq"$name @ _"
+
+ def genForFilter: Gen[Tree] =
+ for(cond <- genIdent(genTermName))
+ yield SyntacticFilter(cond)
+
+ def genForFrom: Gen[Tree] =
+ for(lhs <- genSimpleBind; rhs <- genIdent(genTermName))
+ yield SyntacticValFrom(lhs, rhs)
+
+ def genForEq: Gen[Tree] =
+ for(lhs <- genSimpleBind; rhs <- genIdent(genTermName))
+ yield SyntacticValEq(lhs, rhs)
+
+ def genForEnums(size: Int): Gen[ForEnums] =
+ for(first <- genForFrom; rest <- listOfN(size, oneOf(genForFrom, genForFilter, genForEq)))
+ yield new ForEnums(first :: rest)
+
+ implicit val arbForEnums: Arbitrary[ForEnums] = arbitrarySized(genForEnums)
+
+ property("construct-reconstruct for") = forAll { (enums: ForEnums, body: Tree) =>
+ val SyntacticFor(recoveredEnums, recoveredBody) = SyntacticFor(enums.value, body)
+ recoveredEnums ≈ enums.value && recoveredBody ≈ body
+ }
+
+ property("construct-reconstruct for-yield") = forAll { (enums: ForEnums, body: Tree) =>
+ val SyntacticForYield(recoveredEnums, recoveredBody) = SyntacticForYield(enums.value, body)
+ recoveredEnums ≈ enums.value && recoveredBody ≈ body
+ }
+} \ No newline at end of file
diff --git a/test/files/scalacheck/quasiquotes/Test.scala b/test/files/scalacheck/quasiquotes/Test.scala
index 73cac0368c..8b1e779ab2 100644
--- a/test/files/scalacheck/quasiquotes/Test.scala
+++ b/test/files/scalacheck/quasiquotes/Test.scala
@@ -12,5 +12,6 @@ object Test extends Properties("quasiquotes") {
include(DefinitionConstructionProps)
include(DefinitionDeconstructionProps)
include(DeprecationProps)
+ include(ForProps)
include(TypecheckedProps)
}