From 0200375e670b5dcc865c8636faf00ae5e767a81b Mon Sep 17 00:00:00 2001 From: Denys Shabalin Date: Mon, 27 Jan 2014 18:38:39 +0100 Subject: Addresses feedback from Jason 1. Adds tests for new synthetic unit stripping. 2. Marks implementation-specific parts of Holes as private. 3. Trims description of iterated method a bit. 4. Provides a bit more clear wrapper for q interpolator. 5. Refactors SyntacticBlock, adds documentation. 6. Makes q"{ ..$Nil }" return q"" to be consist with extractor. --- .../scala/tools/nsc/ast/parser/Parsers.scala | 2 +- .../scala/tools/reflect/quasiquotes/Holes.scala | 49 ++++++++++------------ .../scala/tools/reflect/quasiquotes/Parsers.scala | 16 ++++--- .../scala/tools/reflect/quasiquotes/Reifiers.scala | 4 +- .../scala/reflect/internal/BuildUtils.scala | 38 ++++++++++++++--- src/reflect/scala/reflect/internal/TreeGen.scala | 18 ++++++-- .../quasiquotes/TermConstructionProps.scala | 25 ++++++++++- 7 files changed, 105 insertions(+), 47 deletions(-) diff --git a/src/compiler/scala/tools/nsc/ast/parser/Parsers.scala b/src/compiler/scala/tools/nsc/ast/parser/Parsers.scala index 0728fff74f..e6b353f82f 100644 --- a/src/compiler/scala/tools/nsc/ast/parser/Parsers.scala +++ b/src/compiler/scala/tools/nsc/ast/parser/Parsers.scala @@ -33,7 +33,7 @@ trait ParsersCommon extends ScannersCommon { self => import global.{currentUnit => _, _} def newLiteral(const: Any) = Literal(Constant(const)) - def literalUnit = newLiteral(()) + def literalUnit = gen.mkSyntheticUnit() /** This is now an abstract class, only to work around the optimizer: * methods in traits are never inlined. diff --git a/src/compiler/scala/tools/reflect/quasiquotes/Holes.scala b/src/compiler/scala/tools/reflect/quasiquotes/Holes.scala index 057a168b9b..2027d43264 100644 --- a/src/compiler/scala/tools/reflect/quasiquotes/Holes.scala +++ b/src/compiler/scala/tools/reflect/quasiquotes/Holes.scala @@ -31,19 +31,19 @@ trait Holes { self: Quasiquotes => import definitions._ import universeTypes._ - protected lazy val IterableTParam = IterableClass.typeParams(0).asType.toType - protected def inferParamImplicit(tfun: Type, targ: Type) = c.inferImplicitValue(appliedType(tfun, List(targ)), silent = true) - protected def inferLiftable(tpe: Type): Tree = inferParamImplicit(liftableType, tpe) - protected def inferUnliftable(tpe: Type): Tree = inferParamImplicit(unliftableType, tpe) - protected def isLiftableType(tpe: Type) = inferLiftable(tpe) != EmptyTree - protected def isNativeType(tpe: Type) = + private lazy val IterableTParam = IterableClass.typeParams(0).asType.toType + private def inferParamImplicit(tfun: Type, targ: Type) = c.inferImplicitValue(appliedType(tfun, List(targ)), silent = true) + private def inferLiftable(tpe: Type): Tree = inferParamImplicit(liftableType, tpe) + private def inferUnliftable(tpe: Type): Tree = inferParamImplicit(unliftableType, tpe) + private def isLiftableType(tpe: Type) = inferLiftable(tpe) != EmptyTree + private def isNativeType(tpe: Type) = (tpe <:< treeType) || (tpe <:< nameType) || (tpe <:< modsType) || (tpe <:< flagsType) || (tpe <:< symbolType) - protected def isBottomType(tpe: Type) = + private def isBottomType(tpe: Type) = tpe <:< NothingClass.tpe || tpe <:< NullClass.tpe - protected def extractIterableTParam(tpe: Type) = + private def extractIterableTParam(tpe: Type) = IterableTParam.asSeenFrom(tpe, IterableClass) - protected def stripIterable(tpe: Type, limit: Option[Cardinality] = None): (Cardinality, Type) = + private def stripIterable(tpe: Type, limit: Option[Cardinality] = None): (Cardinality, Type) = if (limit.map { _ == NoDot }.getOrElse { false }) (NoDot, tpe) else if (tpe != null && !isIterableType(tpe)) (NoDot, tpe) else if (isBottomType(tpe)) (NoDot, tpe) @@ -52,7 +52,7 @@ trait Holes { self: Quasiquotes => val (card, innerTpe) = stripIterable(targ, limit.map { _.pred }) (card.succ, innerTpe) } - protected def iterableTypeFromCard(n: Cardinality, tpe: Type): Type = { + private def iterableTypeFromCard(n: Cardinality, tpe: Type): Type = { if (n == NoDot) tpe else appliedType(IterableClass.toType, List(iterableTypeFromCard(n.pred, tpe))) } @@ -96,7 +96,7 @@ trait Holes { self: Quasiquotes => val cardinality = stripIterable(tpe)._1 - protected def cantSplice(): Nothing = { + private def cantSplice(): Nothing = { val (iterableCard, iterableType) = stripIterable(splicee.tpe) val holeCardMsg = if (card != NoDot) s" with $card" else "" val action = "splice " + splicee.tpe + holeCardMsg @@ -112,22 +112,22 @@ trait Holes { self: Quasiquotes => c.abort(splicee.pos, s"Can't $action, $advice") } - protected def lifted(tpe: Type)(tree: Tree): Tree = { + private def lifted(tpe: Type)(tree: Tree): Tree = { val lifter = inferLiftable(tpe) assert(lifter != EmptyTree, s"couldnt find a liftable for $tpe") val lifted = Apply(lifter, List(tree)) atPos(tree.pos)(lifted) } - protected def toStats(tree: Tree): Tree = + private def toStats(tree: Tree): Tree = // q"$u.build.toStats($tree)" Apply(Select(Select(u, nme.build), nme.toStats), tree :: Nil) - protected def toList(tree: Tree, tpe: Type): Tree = + private def toList(tree: Tree, tpe: Type): Tree = if (isListType(tpe)) tree else Select(tree, nme.toList) - protected def mapF(tree: Tree, f: Tree => Tree): Tree = + private def mapF(tree: Tree, f: Tree => Tree): Tree = if (f(Ident(TermName("x"))) equalsStructure Ident(TermName("x"))) tree else { val x: TermName = c.freshName() @@ -137,12 +137,12 @@ trait Holes { self: Quasiquotes => f(Ident(x))) :: Nil) } - protected object IterableType { + private object IterableType { def unapply(tpe: Type): Option[Type] = if (isIterableType(tpe)) Some(extractIterableTParam(tpe)) else None } - protected object LiftedType { + private object LiftedType { def unapply(tpe: Type): Option[Tree => Tree] = if (tpe <:< treeType) Some(t => t) else if (isLiftableType(tpe)) Some(lifted(tpe)(_)) @@ -155,23 +155,18 @@ trait Holes { self: Quasiquotes => * * input output for T <: Tree output for T: Liftable * - * ..${x: List[T]} x x.map(lift) * ..${x: Iterable[T]} x.toList x.toList.map(lift) * ..${x: T} toStats(x) toStats(lift(x)) * - * ...${x: List[List[T]]} x x.map { _.map(lift) } } - * ...${x: List[Iterable[T]} x.map { _.toList } x.map { _.toList.map(lift) } } - * ...${x: List[T]} x.map { toStats(_) } x.map { toStats(lift(_)) } - * ...${x: Iterable[List[T]]} x.toList x.toList.map { _.map(lift) } * ...${x: Iterable[Iterable[T]]} x.toList { _.toList } x.toList.map { _.toList.map(lift) } * ...${x: Iterable[T]} x.toList.map { toStats(_) } x.toList.map { toStats(lift(_)) } - * ...${x: T} toStats(x).map { toStats(_) } toStats(lift(x)).map(toStats) + * ...${x: T} toStats(x).map { toStats(_) } toStats(lift(x)).map { toStats(_) } * - * As you can see table is quite repetetive. Middle column is equivalent to the right one with - * lift function equal to identity. Cases with List are equivalent to Iterated ones (due to - * the fact that toList method call is just an identity we can omit it altogether.) + * For optimization purposes `x.toList` is represented as just `x` if it is statically known that + * x is not just an Iterable[T] but a List[T]. Similarly no mapping is performed if mapping function is + * known to be an identity. */ - protected def iterated(card: Cardinality, tree: Tree, tpe: Type): Tree = (card, tpe) match { + private def iterated(card: Cardinality, tree: Tree, tpe: Type): Tree = (card, tpe) match { case (DotDot, tpe @ IterableType(LiftedType(lift))) => mapF(toList(tree, tpe), lift) case (DotDot, LiftedType(lift)) => toStats(lift(tree)) case (DotDotDot, tpe @ IterableType(inner)) => mapF(toList(tree, tpe), t => iterated(DotDot, t, inner)) diff --git a/src/compiler/scala/tools/reflect/quasiquotes/Parsers.scala b/src/compiler/scala/tools/reflect/quasiquotes/Parsers.scala index 1bd9323752..5303d5eb58 100644 --- a/src/compiler/scala/tools/reflect/quasiquotes/Parsers.scala +++ b/src/compiler/scala/tools/reflect/quasiquotes/Parsers.scala @@ -160,15 +160,19 @@ trait Parsers { self: Quasiquotes => } } - object TermParser extends Parser { - def entryPoint = { parser => - parser.templateOrTopStatSeq() match { - case head :: Nil => Block(Nil, head) - case lst => gen.mkTreeOrBlock(lst) - } + /** Wrapper around tree parsed in q"..." quote. Needed to support ..$ splicing on top-level. */ + object Q { + def apply(tree: Tree): Block = Block(Nil, tree).updateAttachment(Q) + def unapply(tree: Tree): Option[Tree] = tree match { + case Block(Nil, contents) if tree.hasAttachment[Q.type] => Some(contents) + case _ => None } } + object TermParser extends Parser { + def entryPoint = parser => Q(gen.mkTreeOrBlock(parser.templateOrTopStatSeq())) + } + object TypeParser extends Parser { def entryPoint = _.typ() } diff --git a/src/compiler/scala/tools/reflect/quasiquotes/Reifiers.scala b/src/compiler/scala/tools/reflect/quasiquotes/Reifiers.scala index 5246592647..45bc2d776c 100644 --- a/src/compiler/scala/tools/reflect/quasiquotes/Reifiers.scala +++ b/src/compiler/scala/tools/reflect/quasiquotes/Reifiers.scala @@ -185,9 +185,9 @@ trait Reifiers { self: Quasiquotes => reifyBuildCall(nme.SyntacticFunction, args, body) case SyntacticIdent(name, isBackquoted) => reifyBuildCall(nme.SyntacticIdent, name, isBackquoted) - case Block(Nil, Placeholder(Hole(tree, DotDot))) => + case Q(Placeholder(Hole(tree, DotDot))) => mirrorBuildCall(nme.SyntacticBlock, tree) - case Block(Nil, other) => + case Q(other) => reifyTree(other) // Syntactic block always matches so we have to be careful // not to cause infinite recursion. diff --git a/src/reflect/scala/reflect/internal/BuildUtils.scala b/src/reflect/scala/reflect/internal/BuildUtils.scala index 46217f768c..ce77e95917 100644 --- a/src/reflect/scala/reflect/internal/BuildUtils.scala +++ b/src/reflect/scala/reflect/internal/BuildUtils.scala @@ -383,15 +383,41 @@ trait BuildUtils { self: SymbolTable => } } + object SyntheticUnit { + def unapply(tree: Tree): Boolean = tree match { + case Literal(Constant(())) if tree.hasAttachment[SyntheticUnitAttachment.type] => true + case _ => false + } + } + + /** Syntactic combinator that abstracts over Block tree. + * + * Apart from providing a more straightforward api that exposes + * block as a list of elements rather than (stats, expr) pair + * it also: + * + * 1. Treats of q"" (empty tree) as zero-element block. + * + * 2. Strips trailing synthetic units which are inserted by the + * compiler if the block ends with a definition rather + * than an expression. + * + * 3. Matches non-block term trees and recognizes them as + * single-element blocks for sake of consistency with + * compiler's default to treat single-element blocks with + * expressions as just expressions. + */ object SyntacticBlock extends SyntacticBlockExtractor { - def apply(stats: List[Tree]): Tree = gen.mkBlock(stats) + def apply(stats: List[Tree]): Tree = + if (stats.isEmpty) EmptyTree + else gen.mkBlock(stats) def unapply(tree: Tree): Option[List[Tree]] = tree match { - case EmptyTree => Some(Nil) - case self.Block(stats, expr) if expr.hasAttachment[SyntheticUnitAttachment.type] => Some(stats) - case self.Block(stats, expr) => Some(stats :+ expr) - case _ if tree.isTerm => Some(tree :: Nil) - case _ => None + case self.Block(stats, SyntheticUnit()) => Some(stats) + case self.Block(stats, expr) => Some(stats :+ expr) + case EmptyTree => Some(Nil) + case _ if tree.isTerm => Some(tree :: Nil) + case _ => None } } diff --git a/src/reflect/scala/reflect/internal/TreeGen.scala b/src/reflect/scala/reflect/internal/TreeGen.scala index 0e1eb7d3a9..e602a12175 100644 --- a/src/reflect/scala/reflect/internal/TreeGen.scala +++ b/src/reflect/scala/reflect/internal/TreeGen.scala @@ -342,7 +342,7 @@ abstract class TreeGen extends macros.TreeBuilder { } param } - + val (edefs, rest) = body span treeInfo.isEarlyDef val (evdefs, etdefs) = edefs partition treeInfo.isEarlyValDef val gvdefs = evdefs map { @@ -381,11 +381,11 @@ abstract class TreeGen extends macros.TreeBuilder { } constr foreach (ensureNonOverlapping(_, parents ::: gvdefs, focus = false)) // Field definitions for the class - remove defaults. - + val fieldDefs = vparamss.flatten map (vd => { val field = copyValDef(vd)(mods = vd.mods &~ DEFAULTPARAM, rhs = EmptyTree) // Prevent overlapping of `field` end's position with default argument's start position. - // This is needed for `Positions.Locator(pos).traverse` to return the correct tree when + // This is needed for `Positions.Locator(pos).traverse` to return the correct tree when // the `pos` is a point position with all its values equal to `vd.rhs.pos.start`. if(field.pos.isRange && vd.rhs.pos.isRange) field.pos = field.pos.withEnd(vd.rhs.pos.start - 1) field @@ -444,13 +444,23 @@ abstract class TreeGen extends macros.TreeBuilder { def mkFunctionTypeTree(argtpes: List[Tree], restpe: Tree): Tree = AppliedTypeTree(rootScalaDot(newTypeName("Function" + argtpes.length)), argtpes ::: List(restpe)) + /** Create a literal unit tree that is inserted by the compiler but not + * written by end user. It's important to distinguish the two so that + * quasiquotes can strip synthetic ones away. + */ + def mkSyntheticUnit() = Literal(Constant(())).updateAttachment(SyntheticUnitAttachment) + /** Create block of statements `stats` */ def mkBlock(stats: List[Tree]): Tree = if (stats.isEmpty) Literal(Constant(())) - else if (!stats.last.isTerm) Block(stats, Literal(Constant(())).updateAttachment(SyntheticUnitAttachment)) + else if (!stats.last.isTerm) Block(stats, mkSyntheticUnit()) else if (stats.length == 1) stats.head else Block(stats.init, stats.last) + /** Create a block that wraps multiple statements but don't + * do any wrapping if there is just one statement. Used by + * quasiquotes, macro c.parse api and toolbox. + */ def mkTreeOrBlock(stats: List[Tree]) = stats match { case Nil => EmptyTree case head :: Nil => head diff --git a/test/files/scalacheck/quasiquotes/TermConstructionProps.scala b/test/files/scalacheck/quasiquotes/TermConstructionProps.scala index 7bbd7a85b3..54187d68c2 100644 --- a/test/files/scalacheck/quasiquotes/TermConstructionProps.scala +++ b/test/files/scalacheck/quasiquotes/TermConstructionProps.scala @@ -116,7 +116,7 @@ object TermConstructionProps extends QuasiquoteProperties("term construction") { def blockInvariant(quote: Tree, trees: List[Tree]) = quote ≈ (trees match { - case Nil => q"()" + case Nil => q"" case _ :+ last if !last.isTerm => Block(trees, q"()") case head :: Nil => head case init :+ last => Block(init, last) @@ -268,4 +268,27 @@ object TermConstructionProps extends QuasiquoteProperties("term construction") { val t = q"{ a; b }; c; d" assertEqAst(q"f(...$t)", "f(a, b)(c)(d)") } + + property("remove synthetic unit") = test { + val q"{ ..$stats1 }" = q"{ def x = 2 }" + assert(stats1 ≈ List(q"def x = 2")) + val q"{ ..$stats2 }" = q"{ class X }" + assert(stats2 ≈ List(q"class X")) + val q"{ ..$stats3 }" = q"{ type X = Int }" + assert(stats3 ≈ List(q"type X = Int")) + val q"{ ..$stats4 }" = q"{ val x = 2 }" + assert(stats4 ≈ List(q"val x = 2")) + } + + property("don't remove user-defined unit") = test { + val q"{ ..$stats }" = q"{ def x = 2; () }" + assert(stats ≈ List(q"def x = 2", q"()")) + } + + property("empty-tree as block") = test { + val q"{ ..$stats1 }" = q" " + assert(stats1.isEmpty) + val stats2 = List.empty[Tree] + assert(q"{ ..$stats2 }" ≈ q"") + } } -- cgit v1.2.3