diff options
13 files changed, 206 insertions, 64 deletions
diff --git a/src/compiler/scala/tools/nsc/ast/parser/Parsers.scala b/src/compiler/scala/tools/nsc/ast/parser/Parsers.scala index 9e580d8bc8..137fc79004 100644 --- a/src/compiler/scala/tools/nsc/ast/parser/Parsers.scala +++ b/src/compiler/scala/tools/nsc/ast/parser/Parsers.scala @@ -33,7 +33,7 @@ trait ParsersCommon extends ScannersCommon { self => import global.{currentUnit => _, _} def newLiteral(const: Any) = Literal(Constant(const)) - def literalUnit = newLiteral(()) + def literalUnit = gen.mkSyntheticUnit() /** This is now an abstract class, only to work around the optimizer: * methods in traits are never inlined. diff --git a/src/compiler/scala/tools/reflect/quasiquotes/Holes.scala b/src/compiler/scala/tools/reflect/quasiquotes/Holes.scala index 8a54519401..2027d43264 100644 --- a/src/compiler/scala/tools/reflect/quasiquotes/Holes.scala +++ b/src/compiler/scala/tools/reflect/quasiquotes/Holes.scala @@ -31,26 +31,28 @@ trait Holes { self: Quasiquotes => import definitions._ import universeTypes._ - protected lazy val IterableTParam = IterableClass.typeParams(0).asType.toType - protected def inferParamImplicit(tfun: Type, targ: Type) = c.inferImplicitValue(appliedType(tfun, List(targ)), silent = true) - protected def inferLiftable(tpe: Type): Tree = inferParamImplicit(liftableType, tpe) - protected def inferUnliftable(tpe: Type): Tree = inferParamImplicit(unliftableType, tpe) - protected def isLiftableType(tpe: Type) = inferLiftable(tpe) != EmptyTree - protected def isNativeType(tpe: Type) = + private lazy val IterableTParam = IterableClass.typeParams(0).asType.toType + private def inferParamImplicit(tfun: Type, targ: Type) = c.inferImplicitValue(appliedType(tfun, List(targ)), silent = true) + private def inferLiftable(tpe: Type): Tree = inferParamImplicit(liftableType, tpe) + private def inferUnliftable(tpe: Type): Tree = inferParamImplicit(unliftableType, tpe) + private def isLiftableType(tpe: Type) = inferLiftable(tpe) != EmptyTree + private def isNativeType(tpe: Type) = (tpe <:< treeType) || (tpe <:< nameType) || (tpe <:< modsType) || (tpe <:< flagsType) || (tpe <:< symbolType) - protected def isBottomType(tpe: Type) = + private def isBottomType(tpe: Type) = tpe <:< NothingClass.tpe || tpe <:< NullClass.tpe - protected def stripIterable(tpe: Type, limit: Option[Cardinality] = None): (Cardinality, Type) = + private def extractIterableTParam(tpe: Type) = + IterableTParam.asSeenFrom(tpe, IterableClass) + private def stripIterable(tpe: Type, limit: Option[Cardinality] = None): (Cardinality, Type) = if (limit.map { _ == NoDot }.getOrElse { false }) (NoDot, tpe) else if (tpe != null && !isIterableType(tpe)) (NoDot, tpe) else if (isBottomType(tpe)) (NoDot, tpe) else { - val targ = IterableTParam.asSeenFrom(tpe, IterableClass) + val targ = extractIterableTParam(tpe) val (card, innerTpe) = stripIterable(targ, limit.map { _.pred }) (card.succ, innerTpe) } - protected def iterableTypeFromCard(n: Cardinality, tpe: Type): Type = { + private def iterableTypeFromCard(n: Cardinality, tpe: Type): Type = { if (n == NoDot) tpe else appliedType(IterableClass.toType, List(iterableTypeFromCard(n.pred, tpe))) } @@ -74,8 +76,7 @@ trait Holes { self: Quasiquotes => class ApplyHole(card: Cardinality, splicee: Tree) extends Hole { val (strippedTpe, tpe): (Type, Type) = { - if (stripIterable(splicee.tpe)._1.value < card.value) cantSplice() - val (_, strippedTpe) = stripIterable(splicee.tpe, limit = Some(card)) + val (strippedCard, strippedTpe) = stripIterable(splicee.tpe, limit = Some(card)) if (isBottomType(strippedTpe)) cantSplice() else if (isNativeType(strippedTpe)) (strippedTpe, iterableTypeFromCard(card, strippedTpe)) else if (isLiftableType(strippedTpe)) (strippedTpe, iterableTypeFromCard(card, treeType)) @@ -88,14 +89,14 @@ trait Holes { self: Quasiquotes => else if (isLiftableType(itpe)) lifted(itpe)(tree) else global.abort("unreachable") if (card == NoDot) inner(strippedTpe)(splicee) - else iterated(card, strippedTpe, inner(strippedTpe))(splicee) + else iterated(card, splicee, splicee.tpe) } val pos = splicee.pos val cardinality = stripIterable(tpe)._1 - protected def cantSplice(): Nothing = { + private def cantSplice(): Nothing = { val (iterableCard, iterableType) = stripIterable(splicee.tpe) val holeCardMsg = if (card != NoDot) s" with $card" else "" val action = "splice " + splicee.tpe + holeCardMsg @@ -111,28 +112,66 @@ trait Holes { self: Quasiquotes => c.abort(splicee.pos, s"Can't $action, $advice") } - protected def lifted(tpe: Type)(tree: Tree): Tree = { + private def lifted(tpe: Type)(tree: Tree): Tree = { val lifter = inferLiftable(tpe) assert(lifter != EmptyTree, s"couldnt find a liftable for $tpe") val lifted = Apply(lifter, List(tree)) atPos(tree.pos)(lifted) } - protected def iterated(card: Cardinality, tpe: Type, elementTransform: Tree => Tree = identity)(tree: Tree): Tree = { - assert(card != NoDot) - def reifyIterable(tree: Tree, n: Cardinality): Tree = { - def loop(tree: Tree, n: Cardinality): Tree = - if (n == NoDot) elementTransform(tree) - else { - val x: TermName = c.freshName() - val wrapped = reifyIterable(Ident(x), n.pred) - val xToWrapped = Function(List(ValDef(Modifiers(PARAM), x, TypeTree(), EmptyTree)), wrapped) - Select(Apply(Select(tree, nme.map), List(xToWrapped)), nme.toList) - } - if (tree.tpe != null && (tree.tpe <:< listTreeType || tree.tpe <:< listListTreeType)) tree - else atPos(tree.pos)(loop(tree, n)) + private def toStats(tree: Tree): Tree = + // q"$u.build.toStats($tree)" + Apply(Select(Select(u, nme.build), nme.toStats), tree :: Nil) + + private def toList(tree: Tree, tpe: Type): Tree = + if (isListType(tpe)) tree + else Select(tree, nme.toList) + + private def mapF(tree: Tree, f: Tree => Tree): Tree = + if (f(Ident(TermName("x"))) equalsStructure Ident(TermName("x"))) tree + else { + val x: TermName = c.freshName() + // q"$tree.map { $x => ${f(Ident(x))} }" + Apply(Select(tree, nme.map), + Function(ValDef(Modifiers(PARAM), x, TypeTree(), EmptyTree) :: Nil, + f(Ident(x))) :: Nil) } - reifyIterable(tree, card) + + private object IterableType { + def unapply(tpe: Type): Option[Type] = + if (isIterableType(tpe)) Some(extractIterableTParam(tpe)) else None + } + + private object LiftedType { + def unapply(tpe: Type): Option[Tree => Tree] = + if (tpe <:< treeType) Some(t => t) + else if (isLiftableType(tpe)) Some(lifted(tpe)(_)) + else None + } + + /** Map high-cardinality splice onto an expression that eveluates as a list of given cardinality. + * + * All possible combinations of representations are given in the table below: + * + * input output for T <: Tree output for T: Liftable + * + * ..${x: Iterable[T]} x.toList x.toList.map(lift) + * ..${x: T} toStats(x) toStats(lift(x)) + * + * ...${x: Iterable[Iterable[T]]} x.toList { _.toList } x.toList.map { _.toList.map(lift) } + * ...${x: Iterable[T]} x.toList.map { toStats(_) } x.toList.map { toStats(lift(_)) } + * ...${x: T} toStats(x).map { toStats(_) } toStats(lift(x)).map { toStats(_) } + * + * For optimization purposes `x.toList` is represented as just `x` if it is statically known that + * x is not just an Iterable[T] but a List[T]. Similarly no mapping is performed if mapping function is + * known to be an identity. + */ + private def iterated(card: Cardinality, tree: Tree, tpe: Type): Tree = (card, tpe) match { + case (DotDot, tpe @ IterableType(LiftedType(lift))) => mapF(toList(tree, tpe), lift) + case (DotDot, LiftedType(lift)) => toStats(lift(tree)) + case (DotDotDot, tpe @ IterableType(inner)) => mapF(toList(tree, tpe), t => iterated(DotDot, t, inner)) + case (DotDotDot, LiftedType(lift)) => mapF(toStats(lift(tree)), toStats) + case _ => global.abort("unreachable") } } diff --git a/src/compiler/scala/tools/reflect/quasiquotes/Parsers.scala b/src/compiler/scala/tools/reflect/quasiquotes/Parsers.scala index ec4ca1c845..c817b5122f 100644 --- a/src/compiler/scala/tools/reflect/quasiquotes/Parsers.scala +++ b/src/compiler/scala/tools/reflect/quasiquotes/Parsers.scala @@ -160,15 +160,19 @@ trait Parsers { self: Quasiquotes => } } - object TermParser extends Parser { - def entryPoint = { parser => - parser.templateOrTopStatSeq() match { - case head :: Nil => Block(Nil, head) - case lst => gen.mkTreeOrBlock(lst) - } + /** Wrapper around tree parsed in q"..." quote. Needed to support ..$ splicing on top-level. */ + object Q { + def apply(tree: Tree): Block = Block(Nil, tree).updateAttachment(Q) + def unapply(tree: Tree): Option[Tree] = tree match { + case Block(Nil, contents) if tree.hasAttachment[Q.type] => Some(contents) + case _ => None } } + object TermParser extends Parser { + def entryPoint = parser => Q(gen.mkTreeOrBlock(parser.templateOrTopStatSeq())) + } + object TypeParser extends Parser { def entryPoint = { parser => if (parser.in.token == EOF) diff --git a/src/compiler/scala/tools/reflect/quasiquotes/Reifiers.scala b/src/compiler/scala/tools/reflect/quasiquotes/Reifiers.scala index 87ab52414c..45bc2d776c 100644 --- a/src/compiler/scala/tools/reflect/quasiquotes/Reifiers.scala +++ b/src/compiler/scala/tools/reflect/quasiquotes/Reifiers.scala @@ -185,12 +185,14 @@ trait Reifiers { self: Quasiquotes => reifyBuildCall(nme.SyntacticFunction, args, body) case SyntacticIdent(name, isBackquoted) => reifyBuildCall(nme.SyntacticIdent, name, isBackquoted) - case Block(Nil, Placeholder(Hole(tree, DotDot))) => + case Q(Placeholder(Hole(tree, DotDot))) => mirrorBuildCall(nme.SyntacticBlock, tree) - case Block(Nil, other) => + case Q(other) => reifyTree(other) - case Block(stats, last) => - reifyBuildCall(nme.SyntacticBlock, stats :+ last) + // Syntactic block always matches so we have to be careful + // not to cause infinite recursion. + case block @ SyntacticBlock(stats) if block.isInstanceOf[Block] => + reifyBuildCall(nme.SyntacticBlock, stats) case Try(block, catches, finalizer) => reifyBuildCall(nme.SyntacticTry, block, catches, finalizer) case Match(selector, cases) => diff --git a/src/reflect/scala/reflect/api/BuildUtils.scala b/src/reflect/scala/reflect/api/BuildUtils.scala index 260974a981..3a69390bcf 100644 --- a/src/reflect/scala/reflect/api/BuildUtils.scala +++ b/src/reflect/scala/reflect/api/BuildUtils.scala @@ -72,6 +72,8 @@ private[reflect] trait BuildUtils { self: Universe => def setSymbol[T <: Tree](tree: T, sym: Symbol): T + def toStats(tree: Tree): List[Tree] + def mkAnnotation(tree: Tree): Tree def mkAnnotation(trees: List[Tree]): List[Tree] diff --git a/src/reflect/scala/reflect/internal/BuildUtils.scala b/src/reflect/scala/reflect/internal/BuildUtils.scala index 3061885549..16bb3e5989 100644 --- a/src/reflect/scala/reflect/internal/BuildUtils.scala +++ b/src/reflect/scala/reflect/internal/BuildUtils.scala @@ -61,6 +61,8 @@ trait BuildUtils { self: SymbolTable => def setSymbol[T <: Tree](tree: T, sym: Symbol): T = { tree.setSymbol(sym); tree } + def toStats(tree: Tree): List[Tree] = SyntacticBlock.unapply(tree).get + def mkAnnotation(tree: Tree): Tree = tree match { case SyntacticNew(Nil, SyntacticApplied(SyntacticTypeApplied(_, _), _) :: Nil, noSelfType, Nil) => tree @@ -381,13 +383,41 @@ trait BuildUtils { self: SymbolTable => } } + object SyntheticUnit { + def unapply(tree: Tree): Boolean = tree match { + case Literal(Constant(())) if tree.hasAttachment[SyntheticUnitAttachment.type] => true + case _ => false + } + } + + /** Syntactic combinator that abstracts over Block tree. + * + * Apart from providing a more straightforward api that exposes + * block as a list of elements rather than (stats, expr) pair + * it also: + * + * 1. Treats of q"" (empty tree) as zero-element block. + * + * 2. Strips trailing synthetic units which are inserted by the + * compiler if the block ends with a definition rather + * than an expression. + * + * 3. Matches non-block term trees and recognizes them as + * single-element blocks for sake of consistency with + * compiler's default to treat single-element blocks with + * expressions as just expressions. + */ object SyntacticBlock extends SyntacticBlockExtractor { - def apply(stats: List[Tree]): Tree = gen.mkBlock(stats) + def apply(stats: List[Tree]): Tree = + if (stats.isEmpty) EmptyTree + else gen.mkBlock(stats) def unapply(tree: Tree): Option[List[Tree]] = tree match { - case self.Block(stats, expr) => Some(stats :+ expr) - case _ if tree.isTerm => Some(tree :: Nil) - case _ => None + case self.Block(stats, SyntheticUnit()) => Some(stats) + case self.Block(stats, expr) => Some(stats :+ expr) + case EmptyTree => Some(Nil) + case _ if tree.isTerm => Some(tree :: Nil) + case _ => None } } diff --git a/src/reflect/scala/reflect/internal/Definitions.scala b/src/reflect/scala/reflect/internal/Definitions.scala index 5b06239863..4d24f0b219 100644 --- a/src/reflect/scala/reflect/internal/Definitions.scala +++ b/src/reflect/scala/reflect/internal/Definitions.scala @@ -625,6 +625,7 @@ trait Definitions extends api.StandardDefinitions { def isBlackboxMacroBundleType(tp: Type) = isMacroBundleType(tp) && (macroBundleParamInfo(tp) <:< BlackboxContextClass.tpe) + def isListType(tp: Type) = tp <:< classExistentialType(ListClass) def isIterableType(tp: Type) = tp <:< classExistentialType(IterableClass) // These "direct" calls perform no dealiasing. They are most needed when diff --git a/src/reflect/scala/reflect/internal/StdAttachments.scala b/src/reflect/scala/reflect/internal/StdAttachments.scala index 09fd996f39..139a79ffe1 100644 --- a/src/reflect/scala/reflect/internal/StdAttachments.scala +++ b/src/reflect/scala/reflect/internal/StdAttachments.scala @@ -36,6 +36,10 @@ trait StdAttachments { */ case object ForAttachment extends PlainAttachment + /** Identifies unit constants which were inserted by the compiler (e.g. gen.mkBlock) + */ + case object SyntheticUnitAttachment extends PlainAttachment + /** Untyped list of subpatterns attached to selector dummy. */ case class SubpatternsAttachment(patterns: List[Tree]) } diff --git a/src/reflect/scala/reflect/internal/StdNames.scala b/src/reflect/scala/reflect/internal/StdNames.scala index 7015105261..256d5759fa 100644 --- a/src/reflect/scala/reflect/internal/StdNames.scala +++ b/src/reflect/scala/reflect/internal/StdNames.scala @@ -752,6 +752,7 @@ trait StdNames { val toArray: NameType = "toArray" val toList: NameType = "toList" val toObjectArray : NameType = "toObjectArray" + val toStats: NameType = "toStats" val TopScope: NameType = "TopScope" val toString_ : NameType = "toString" val toTypeConstructor: NameType = "toTypeConstructor" diff --git a/src/reflect/scala/reflect/internal/TreeGen.scala b/src/reflect/scala/reflect/internal/TreeGen.scala index b16cbd8325..e602a12175 100644 --- a/src/reflect/scala/reflect/internal/TreeGen.scala +++ b/src/reflect/scala/reflect/internal/TreeGen.scala @@ -342,7 +342,7 @@ abstract class TreeGen extends macros.TreeBuilder { } param } - + val (edefs, rest) = body span treeInfo.isEarlyDef val (evdefs, etdefs) = edefs partition treeInfo.isEarlyValDef val gvdefs = evdefs map { @@ -381,11 +381,11 @@ abstract class TreeGen extends macros.TreeBuilder { } constr foreach (ensureNonOverlapping(_, parents ::: gvdefs, focus = false)) // Field definitions for the class - remove defaults. - + val fieldDefs = vparamss.flatten map (vd => { val field = copyValDef(vd)(mods = vd.mods &~ DEFAULTPARAM, rhs = EmptyTree) // Prevent overlapping of `field` end's position with default argument's start position. - // This is needed for `Positions.Locator(pos).traverse` to return the correct tree when + // This is needed for `Positions.Locator(pos).traverse` to return the correct tree when // the `pos` is a point position with all its values equal to `vd.rhs.pos.start`. if(field.pos.isRange && vd.rhs.pos.isRange) field.pos = field.pos.withEnd(vd.rhs.pos.start - 1) field @@ -444,13 +444,23 @@ abstract class TreeGen extends macros.TreeBuilder { def mkFunctionTypeTree(argtpes: List[Tree], restpe: Tree): Tree = AppliedTypeTree(rootScalaDot(newTypeName("Function" + argtpes.length)), argtpes ::: List(restpe)) + /** Create a literal unit tree that is inserted by the compiler but not + * written by end user. It's important to distinguish the two so that + * quasiquotes can strip synthetic ones away. + */ + def mkSyntheticUnit() = Literal(Constant(())).updateAttachment(SyntheticUnitAttachment) + /** Create block of statements `stats` */ def mkBlock(stats: List[Tree]): Tree = if (stats.isEmpty) Literal(Constant(())) - else if (!stats.last.isTerm) Block(stats, Literal(Constant(()))) + else if (!stats.last.isTerm) Block(stats, mkSyntheticUnit()) else if (stats.length == 1) stats.head else Block(stats.init, stats.last) + /** Create a block that wraps multiple statements but don't + * do any wrapping if there is just one statement. Used by + * quasiquotes, macro c.parse api and toolbox. + */ def mkTreeOrBlock(stats: List[Tree]) = stats match { case Nil => EmptyTree case head :: Nil => head diff --git a/src/reflect/scala/reflect/runtime/JavaUniverseForce.scala b/src/reflect/scala/reflect/runtime/JavaUniverseForce.scala index b9b171c7ed..8811b5513e 100644 --- a/src/reflect/scala/reflect/runtime/JavaUniverseForce.scala +++ b/src/reflect/scala/reflect/runtime/JavaUniverseForce.scala @@ -58,6 +58,7 @@ trait JavaUniverseForce { self: runtime.JavaUniverse => this.CompoundTypeTreeOriginalAttachment this.BackquotedIdentifierAttachment this.ForAttachment + this.SyntheticUnitAttachment this.SubpatternsAttachment this.noPrint this.typeDebug diff --git a/test/files/scalacheck/quasiquotes/ErrorProps.scala b/test/files/scalacheck/quasiquotes/ErrorProps.scala index 92d299bede..3a66574c7d 100644 --- a/test/files/scalacheck/quasiquotes/ErrorProps.scala +++ b/test/files/scalacheck/quasiquotes/ErrorProps.scala @@ -52,13 +52,6 @@ object ErrorProps extends QuasiquoteProperties("errors") { StringContext("\"", "\"").q(x) """) - property("expected different cardinality") = fails( - "Can't splice List[reflect.runtime.universe.Tree] with ..., consider using ..", - """ - val args: List[Tree] = Nil - q"f(...$args)" - """) - property("non-liftable type ..") = fails( "Can't splice List[StringBuilder] with .., consider omitting the dots or providing an implicit instance of Liftable[StringBuilder]", """ @@ -90,13 +83,6 @@ object ErrorProps extends QuasiquoteProperties("errors") { q"$xs" """) - property("use zero card") = fails( - "Can't splice reflect.runtime.universe.Tree with .., consider omitting the dots", - """ - val t = EmptyTree - q"f(..$t)" - """) - property("not liftable or natively supported") = fails( "Can't splice StringBuilder, consider providing an implicit instance of Liftable[StringBuilder]", """ @@ -188,4 +174,4 @@ object ErrorProps extends QuasiquoteProperties("errors") { // // Make sure a nice error is reported in this case // { import Flag._; val mods = NoMods; q"lazy $mods val x: Int" } -}
\ No newline at end of file +} diff --git a/test/files/scalacheck/quasiquotes/TermConstructionProps.scala b/test/files/scalacheck/quasiquotes/TermConstructionProps.scala index 38fbfa9f7f..54187d68c2 100644 --- a/test/files/scalacheck/quasiquotes/TermConstructionProps.scala +++ b/test/files/scalacheck/quasiquotes/TermConstructionProps.scala @@ -116,7 +116,7 @@ object TermConstructionProps extends QuasiquoteProperties("term construction") { def blockInvariant(quote: Tree, trees: List[Tree]) = quote ≈ (trees match { - case Nil => q"()" + case Nil => q"" case _ :+ last if !last.isTerm => Block(trees, q"()") case head :: Nil => head case init :+ last => Block(init, last) @@ -229,4 +229,66 @@ object TermConstructionProps extends QuasiquoteProperties("term construction") { val q"($a, $b) => $_" = q"_ + _" assert(a.name != b.name) } + + property("SI-7275 a") = test { + val t = q"stat1; stat2" + assertEqAst(q"..$t", "{stat1; stat2}") + } + + property("SI-7275 b") = test { + def f(t: Tree) = q"..$t" + assertEqAst(f(q"stat1; stat2"), "{stat1; stat2}") + } + + property("SI-7275 c1") = test { + object O + implicit val liftO = Liftable[O.type] { _ => q"foo; bar" } + assertEqAst(q"f(..$O)", "f(foo, bar)") + } + + property("SI-7275 c2") = test { + object O + implicit val liftO = Liftable[O.type] { _ => q"{ foo; bar }; { baz; bax }" } + assertEqAst(q"f(...$O)", "f(foo, bar)(baz, bax)") + } + + property("SI-7275 d") = test { + val l = q"a; b" :: q"c; d" :: Nil + assertEqAst(q"f(...$l)", "f(a, b)(c, d)") + val l2: Iterable[Tree] = l + assertEqAst(q"f(...$l2)", "f(a, b)(c, d)") + } + + property("SI-7275 e") = test { + val t = q"{ a; b }; { c; d }" + assertEqAst(q"f(...$t)", "f(a, b)(c, d)") + } + + property("SI-7275 e2") = test { + val t = q"{ a; b }; c; d" + assertEqAst(q"f(...$t)", "f(a, b)(c)(d)") + } + + property("remove synthetic unit") = test { + val q"{ ..$stats1 }" = q"{ def x = 2 }" + assert(stats1 ≈ List(q"def x = 2")) + val q"{ ..$stats2 }" = q"{ class X }" + assert(stats2 ≈ List(q"class X")) + val q"{ ..$stats3 }" = q"{ type X = Int }" + assert(stats3 ≈ List(q"type X = Int")) + val q"{ ..$stats4 }" = q"{ val x = 2 }" + assert(stats4 ≈ List(q"val x = 2")) + } + + property("don't remove user-defined unit") = test { + val q"{ ..$stats }" = q"{ def x = 2; () }" + assert(stats ≈ List(q"def x = 2", q"()")) + } + + property("empty-tree as block") = test { + val q"{ ..$stats1 }" = q" " + assert(stats1.isEmpty) + val stats2 = List.empty[Tree] + assert(q"{ ..$stats2 }" ≈ q"") + } } |