From a4a3ab0d722412b9ecf267b178bb866087867cf9 Mon Sep 17 00:00:00 2001 From: Den Shabalin Date: Thu, 31 Oct 2013 13:41:53 +0100 Subject: implement inverse transformation to mkFor This effectively reconstructs a sequence of enumerators and body from the tree produced by mkFor. This lets to define bi-directional SyntacticFor and SyntacticForYield constructors/extractors to work with for loops. Correctness of the transformation is tested by a scalacheck test that generates a sequence of random enumerators, sugars them into maps/flatMaps/foreach/withFilter calls and reconstructs them back. --- src/reflect/scala/reflect/api/BuildUtils.scala | 8 ++ .../scala/reflect/internal/BuildUtils.scala | 141 +++++++++++++++++++++ src/reflect/scala/reflect/internal/TreeGen.scala | 26 ++-- test/files/scalacheck/quasiquotes/ForProps.scala | 38 ++++++ test/files/scalacheck/quasiquotes/Test.scala | 1 + 5 files changed, 205 insertions(+), 9 deletions(-) create mode 100644 test/files/scalacheck/quasiquotes/ForProps.scala diff --git a/src/reflect/scala/reflect/api/BuildUtils.scala b/src/reflect/scala/reflect/api/BuildUtils.scala index 4d2a1dcc30..cf05aefe72 100644 --- a/src/reflect/scala/reflect/api/BuildUtils.scala +++ b/src/reflect/scala/reflect/api/BuildUtils.scala @@ -244,5 +244,13 @@ private[reflect] trait BuildUtils { self: Universe => def apply(test: Tree): Tree def unapply(tree: Tree): Option[(Tree)] } + + val SyntacticFor: SyntacticForExtractor + val SyntacticForYield: SyntacticForExtractor + + trait SyntacticForExtractor { + def apply(enums: List[Tree], body: Tree): Tree + def unapply(tree: Tree): Option[(List[Tree], Tree)] + } } } diff --git a/src/reflect/scala/reflect/internal/BuildUtils.scala b/src/reflect/scala/reflect/internal/BuildUtils.scala index db4e5685fd..8fc1869dd2 100644 --- a/src/reflect/scala/reflect/internal/BuildUtils.scala +++ b/src/reflect/scala/reflect/internal/BuildUtils.scala @@ -463,6 +463,48 @@ trait BuildUtils { self: SymbolTable => def unapply(tree: Tree): Option[Tree] = gen.Filter.unapply(tree) } + // abstract over possible alternative representations of no type in valdef + protected object EmptyTypTree { + def unapply(tree: Tree): Boolean = tree match { + case EmptyTree => true + case tt: TypeTree if (tt.original == null || tt.original.isEmpty) => true + case _ => false + } + } + + // match a sequence of desugared `val $pat = $value` + protected object UnPatSeq { + def unapply(trees: List[Tree]): Option[List[(Tree, Tree)]] = trees match { + case Nil => Some(Nil) + // case q"$mods val ${_}: ${_} = ${MaybeUnchecked(value)} match { case $pat => (..$ids) }" :: tail + case ValDef(mods, _, _, Match(MaybeUnchecked(value), CaseDef(pat, EmptyTree, SyntacticTuple(ids)) :: Nil)) :: tail + if mods.hasFlag(SYNTHETIC) && mods.hasFlag(ARTIFACT) => + tail.drop(ids.length) match { + case UnPatSeq(rest) => Some((pat, value) :: rest) + case _ => None + } + // case q"${_} val $name1: ${_} = ${MaybeUnchecked(value)} match { case $pat => ${Ident(name2)} }" :: UnPatSeq(rest) + case ValDef(_, name1, _, Match(MaybeUnchecked(value), CaseDef(pat, EmptyTree, Ident(name2)) :: Nil)) :: UnPatSeq(rest) + if name1 == name2 => + Some((pat, value) :: rest) + // case q"${_} val $name: ${EmptyTypTree()} = $value" :: UnPatSeq(rest) => + case ValDef(_, name, EmptyTypTree(), value) :: UnPatSeq(rest) => + Some((Bind(name, self.Ident(nme.WILDCARD)), value) :: rest) + // case q"${_} val $name: $tpt = $value" :: UnPatSeq(rest) => + case ValDef(_, name, tpt, value) :: UnPatSeq(rest) => + Some((Bind(name, Typed(self.Ident(nme.WILDCARD), tpt)), value) :: rest) + case _ => None + } + } + + // match a sequence of desugared `val $pat = $value` with a tuple in the end + protected object UnPatSeqWithRes { + def unapply(tree: Tree): Option[(List[(Tree, Tree)], List[Tree])] = tree match { + case SyntacticBlock(UnPatSeq(trees) :+ SyntacticTuple(elems)) => Some((trees, elems)) + case _ => None + } + } + // undo gen.mkSyntheticParam protected object UnSyntheticParam { def unapply(tree: Tree): Option[TermName] = tree match { @@ -483,6 +525,20 @@ trait BuildUtils { self: SymbolTable => } } + // undo gen.mkFor:makeClosure + protected object UnClosure { + def unapply(tree: Tree): Option[(Tree, Tree)] = tree match { + case Function(ValDef(Modifiers(PARAM, _, _), name, tpt, EmptyTree) :: Nil, body) => + tpt match { + case EmptyTypTree() => Some((Bind(name, self.Ident(nme.WILDCARD)), body)) + case _ => Some((Bind(name, Typed(self.Ident(nme.WILDCARD), tpt)), body)) + } + case UnVisitor(_, CaseDef(pat, EmptyTree, body) :: Nil) => + Some((pat, body)) + case _ => None + } + } + // match call to either withFilter or filter protected object FilterCall { def unapply(tree: Tree): Option[(Tree,Tree)] = tree match { @@ -492,6 +548,18 @@ trait BuildUtils { self: SymbolTable => } } + // transform a chain of withFilter calls into a sequence of for filters + protected object UnFilter { + def unapply(tree: Tree): Some[(Tree, List[Tree])] = tree match { + case UnCheckIfRefutable(_, _) => + Some((tree, Nil)) + case FilterCall(UnFilter(rhs, rest), UnClosure(_, test)) => + Some((rhs, rest :+ SyntacticFilter(test))) + case _ => + Some((tree, Nil)) + } + } + // undo gen.mkCheckIfRefutable protected object UnCheckIfRefutable { def unapply(tree: Tree): Option[(Tree, Tree)] = tree match { @@ -504,6 +572,79 @@ trait BuildUtils { self: SymbolTable => } } + // undo gen.mkFor:makeCombination accounting for possible extra implicit argument + protected class UnForCombination(name: TermName) { + def unapply(tree: Tree) = tree match { + case SyntacticApplied(SyntacticTypeApplied(sel @ Select(lhs, meth), _), (f :: Nil) :: Nil) + if name == meth && sel.hasAttachment[ForAttachment.type] => + Some(lhs, f) + case SyntacticApplied(SyntacticTypeApplied(sel @ Select(lhs, meth), _), (f :: Nil) :: _ :: Nil) + if name == meth && sel.hasAttachment[ForAttachment.type] => + Some(lhs, f) + case _ => None + } + } + protected object UnMap extends UnForCombination(nme.map) + protected object UnForeach extends UnForCombination(nme.foreach) + protected object UnFlatMap extends UnForCombination(nme.flatMap) + + // undo desugaring done in gen.mkFor + protected object UnFor { + def unapply(tree: Tree): Option[(List[Tree], Tree)] = { + val interm = tree match { + case UnFlatMap(UnFilter(rhs, filters), UnClosure(pat, UnFor(rest, body))) => + Some(((pat, rhs), filters ::: rest, body)) + case UnForeach(UnFilter(rhs, filters), UnClosure(pat, UnFor(rest, body))) => + Some(((pat, rhs), filters ::: rest, body)) + case UnMap(UnFilter(rhs, filters), UnClosure(pat, cbody)) => + Some(((pat, rhs), filters, gen.Yield(cbody))) + case UnForeach(UnFilter(rhs, filters), UnClosure(pat, cbody)) => + Some(((pat, rhs), filters, cbody)) + case _ => None + } + interm.flatMap { + case ((Bind(_, SyntacticTuple(_)) | SyntacticTuple(_), + UnFor(SyntacticValFrom(pat, rhs) :: innerRest, gen.Yield(UnPatSeqWithRes(pats, elems2)))), + outerRest, fbody) => + val valeqs = pats.map { case (pat, rhs) => SyntacticValEq(pat, rhs) } + Some((SyntacticValFrom(pat, rhs) :: innerRest ::: valeqs ::: outerRest, fbody)) + case ((pat, rhs), filters, body) => + Some((SyntacticValFrom(pat, rhs) :: filters, body)) + } + } + } + + // check that enumerators are valid + protected def mkEnumerators(enums: List[Tree]): List[Tree] = { + require(enums.nonEmpty, "enumerators can't be empty") + enums.head match { + case SyntacticValFrom(_, _) => + case t => throw new IllegalArgumentException(s"$t is not a valid fist enumerator of for loop") + } + enums.tail.foreach { + case SyntacticValEq(_, _) | SyntacticValFrom(_, _) | SyntacticFilter(_) => + case t => throw new IllegalArgumentException(s"$t is not a valid representation of a for loop enumerator") + } + enums + } + + object SyntacticFor extends SyntacticForExtractor { + def apply(enums: List[Tree], body: Tree): Tree = gen.mkFor(mkEnumerators(enums), body) + def unapply(tree: Tree) = tree match { + case UnFor(enums, gen.Yield(body)) => None + case UnFor(enums, body) => Some((enums, body)) + case _ => None + } + } + + object SyntacticForYield extends SyntacticForExtractor { + def apply(enums: List[Tree], body: Tree): Tree = gen.mkFor(mkEnumerators(enums), gen.Yield(body)) + def unapply(tree: Tree) = tree match { + case UnFor(enums, gen.Yield(body)) => Some((enums, body)) + case _ => None + } + } + // use typetree's original instead of typetree itself protected object MaybeTypeTreeOriginal { def unapply(tree: Tree): Some[Tree] = tree match { diff --git a/src/reflect/scala/reflect/internal/TreeGen.scala b/src/reflect/scala/reflect/internal/TreeGen.scala index d5eff8a9b8..baa8ac27f6 100644 --- a/src/reflect/scala/reflect/internal/TreeGen.scala +++ b/src/reflect/scala/reflect/internal/TreeGen.scala @@ -574,7 +574,8 @@ abstract class TreeGen extends macros.TreeBuilder { * the limits given by pat and body. */ def makeClosure(pos: Position, pat: Tree, body: Tree): Tree = { - def splitpos = wrappingPos(List(pat, body)).withPoint(pos.point).makeTransparent + def wrapped = wrappingPos(List(pat, body)) + def splitpos = (if (pos != NoPosition) wrapped.withPoint(pos.point) else pos).makeTransparent matchVarPattern(pat) match { case Some((name, tpt)) => Function( @@ -590,7 +591,8 @@ abstract class TreeGen extends macros.TreeBuilder { /* Make an application qual.meth(pat => body) positioned at `pos`. */ def makeCombination(pos: Position, meth: TermName, qual: Tree, pat: Tree, body: Tree): Tree = - Apply(Select(qual, meth) setPos qual.pos, List(makeClosure(pos, pat, body))) setPos pos + Apply(Select(qual, meth) setPos qual.pos updateAttachment ForAttachment, + List(makeClosure(pos, pat, body))) setPos pos /* If `pat` is not yet a `Bind` wrap it in one with a fresh name */ def makeBind(pat: Tree): Tree = pat match { @@ -604,13 +606,15 @@ abstract class TreeGen extends macros.TreeBuilder { } /* The position of the closure that starts with generator at position `genpos`. */ - def closurePos(genpos: Position) = { - val end = body.pos match { - case NoPosition => genpos.point - case bodypos => bodypos.end + def closurePos(genpos: Position) = + if (genpos == NoPosition) NoPosition + else { + val end = body.pos match { + case NoPosition => genpos.point + case bodypos => bodypos.end + } + rangePos(genpos.source, genpos.start, genpos.point, end) } - rangePos(genpos.source ,genpos.start, genpos.point, end) - } enums match { case (t @ ValFrom(pat, rhs)) :: Nil => @@ -634,10 +638,14 @@ abstract class TreeGen extends macros.TreeBuilder { List(ValFrom(defpat1, rhs).setPos(t.pos)), Yield(Block(pdefs, atPos(wrappingPos(ids)) { mkTuple(ids) }) setPos wrappingPos(pdefs))) val allpats = (pat :: pats) map (_.duplicate) - val vfrom1 = ValFrom(atPos(wrappingPos(allpats)) { mkTuple(allpats) }, rhs1).setPos(rangePos(t.pos.source, t.pos.start, t.pos.point, rhs1.pos.end)) + val pos1 = + if (t.pos == NoPosition) NoPosition + else rangePos(t.pos.source, t.pos.start, t.pos.point, rhs1.pos.end) + val vfrom1 = ValFrom(atPos(wrappingPos(allpats)) { mkTuple(allpats) }, rhs1).setPos(pos1) mkFor(vfrom1 :: rest1, sugarBody) case _ => EmptyTree //may happen for erroneous input + } } diff --git a/test/files/scalacheck/quasiquotes/ForProps.scala b/test/files/scalacheck/quasiquotes/ForProps.scala new file mode 100644 index 0000000000..234f4d10eb --- /dev/null +++ b/test/files/scalacheck/quasiquotes/ForProps.scala @@ -0,0 +1,38 @@ +import org.scalacheck._, Prop._, Gen._, Arbitrary._ +import scala.reflect.runtime.universe._, Flag._, build.{Ident => _, _} + +object ForProps extends QuasiquoteProperties("for") { + case class ForEnums(val value: List[Tree]) + + def genSimpleBind: Gen[Bind] = + for(name <- genTermName) + yield pq"$name @ _" + + def genForFilter: Gen[Tree] = + for(cond <- genIdent(genTermName)) + yield SyntacticFilter(cond) + + def genForFrom: Gen[Tree] = + for(lhs <- genSimpleBind; rhs <- genIdent(genTermName)) + yield SyntacticValFrom(lhs, rhs) + + def genForEq: Gen[Tree] = + for(lhs <- genSimpleBind; rhs <- genIdent(genTermName)) + yield SyntacticValEq(lhs, rhs) + + def genForEnums(size: Int): Gen[ForEnums] = + for(first <- genForFrom; rest <- listOfN(size, oneOf(genForFrom, genForFilter, genForEq))) + yield new ForEnums(first :: rest) + + implicit val arbForEnums: Arbitrary[ForEnums] = arbitrarySized(genForEnums) + + property("construct-reconstruct for") = forAll { (enums: ForEnums, body: Tree) => + val SyntacticFor(recoveredEnums, recoveredBody) = SyntacticFor(enums.value, body) + recoveredEnums ≈ enums.value && recoveredBody ≈ body + } + + property("construct-reconstruct for-yield") = forAll { (enums: ForEnums, body: Tree) => + val SyntacticForYield(recoveredEnums, recoveredBody) = SyntacticForYield(enums.value, body) + recoveredEnums ≈ enums.value && recoveredBody ≈ body + } +} \ No newline at end of file diff --git a/test/files/scalacheck/quasiquotes/Test.scala b/test/files/scalacheck/quasiquotes/Test.scala index 73cac0368c..8b1e779ab2 100644 --- a/test/files/scalacheck/quasiquotes/Test.scala +++ b/test/files/scalacheck/quasiquotes/Test.scala @@ -12,5 +12,6 @@ object Test extends Properties("quasiquotes") { include(DefinitionConstructionProps) include(DefinitionDeconstructionProps) include(DeprecationProps) + include(ForProps) include(TypecheckedProps) } -- cgit v1.2.3