aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJason Zaugg <jzaugg@gmail.com>2012-12-10 11:44:39 +0100
committerJason Zaugg <jzaugg@gmail.com>2012-12-10 11:44:39 +0100
commit8b7e520b7d66abe14560508a24fe88d99fbedd9e (patch)
tree0b3bc74cbe02c6603d845c81fae53e4f69815eaf
parent7c93a9e0e288b55027646016913c7368732d54e4 (diff)
downloadscala-async-8b7e520b7d66abe14560508a24fe88d99fbedd9e.tar.gz
scala-async-8b7e520b7d66abe14560508a24fe88d99fbedd9e.tar.bz2
scala-async-8b7e520b7d66abe14560508a24fe88d99fbedd9e.zip
Workaround non-idempotency of typing pattern matching anonymous functions.
- Undo the transformation that takes place in Typers to leave us with Match(EmptyTree, cases). - Make sure we don't descend into the cases of such a tree when peforming the async transform
-rw-r--r--src/main/scala/scala/async/Async.scala3
-rw-r--r--src/main/scala/scala/async/AsyncAnalysis.scala24
-rw-r--r--src/main/scala/scala/async/TransformUtils.scala56
-rw-r--r--src/test/scala/scala/async/TreeInterrogation.scala20
-rw-r--r--src/test/scala/scala/async/neg/NakedAwait.scala10
-rw-r--r--src/test/scala/scala/async/run/toughtype/ToughType.scala31
6 files changed, 114 insertions, 30 deletions
diff --git a/src/main/scala/scala/async/Async.scala b/src/main/scala/scala/async/Async.scala
index 1961e69..2ff0f07 100644
--- a/src/main/scala/scala/async/Async.scala
+++ b/src/main/scala/scala/async/Async.scala
@@ -78,7 +78,8 @@ abstract class AsyncBase {
// - if/match only used in statement position.
val anfTree: Block = {
val anf = AnfTransform[c.type](c)
- val stats1 :+ expr1 = anf(body.tree)
+ val restored = utils.restorePatternMatchingFunctions(body.tree)
+ val stats1 :+ expr1 = anf(restored)
val block = Block(stats1, expr1)
c.typeCheck(block).asInstanceOf[Block]
}
diff --git a/src/main/scala/scala/async/AsyncAnalysis.scala b/src/main/scala/scala/async/AsyncAnalysis.scala
index 8bb5bcd..f62d7a1 100644
--- a/src/main/scala/scala/async/AsyncAnalysis.scala
+++ b/src/main/scala/scala/async/AsyncAnalysis.scala
@@ -67,6 +67,10 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C) {
reportUnsupportedAwait(function, "nested function")
}
+ override def patMatFunction(tree: Match) {
+ reportUnsupportedAwait(tree, "nested function")
+ }
+
override def traverse(tree: Tree) {
def containsAwait = tree exists isAwait
tree match {
@@ -107,19 +111,19 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C) {
override def nestedMethod(defDef: DefDef) {
nestedMethodsToLift += defDef
- defDef.rhs foreach {
- case rt: RefTree =>
- valDefChunkId.get(rt.symbol) match {
- case Some((vd, defChunkId)) =>
- valDefsToLift += vd // lift all vals referred to by nested methods.
- case _ =>
- }
- case _ =>
- }
+ markReferencedVals(defDef)
}
override def function(function: Function) {
- function foreach {
+ markReferencedVals(function)
+ }
+
+ override def patMatFunction(tree: Match) {
+ markReferencedVals(tree)
+ }
+
+ private def markReferencedVals(tree: Tree) {
+ tree foreach {
case rt: RefTree =>
valDefChunkId.get(rt.symbol) match {
case Some((vd, defChunkId)) =>
diff --git a/src/main/scala/scala/async/TransformUtils.scala b/src/main/scala/scala/async/TransformUtils.scala
index 23f39d2..22c1b90 100644
--- a/src/main/scala/scala/async/TransformUtils.scala
+++ b/src/main/scala/scala/async/TransformUtils.scala
@@ -25,12 +25,14 @@ private[async] final case class TransformUtils[C <: Context](c: C) {
val stateMachine = newTermName(fresh("stateMachine"))
val stateMachineT = stateMachine.toTypeName
val apply = newTermName("apply")
+ val applyOrElse = newTermName("applyOrElse")
val tr = newTermName("tr")
val matchRes = "matchres"
val ifRes = "ifres"
val await = "await"
val bindSuffix = "$bind"
- def arg(i: Int) = "arg" + i
+
+ def arg(i: Int) = "arg" + i
def fresh(name: TermName): TermName = newTermName(fresh(name.toString))
@@ -64,7 +66,7 @@ private[async] final case class TransformUtils[C <: Context](c: C) {
/** Descends into the regions of the tree that are subject to the
* translation to a state machine by `async`. When a nested template,
- * function, or by-name argument is encountered, the descend stops,
+ * function, or by-name argument is encountered, the descent stops,
* and `nestedClass` etc are invoked.
*/
trait AsyncTraverser extends Traverser {
@@ -83,20 +85,24 @@ private[async] final case class TransformUtils[C <: Context](c: C) {
def function(function: Function) {
}
+ def patMatFunction(tree: Match) {
+ }
+
override def traverse(tree: Tree) {
tree match {
- case cd: ClassDef => nestedClass(cd)
- case md: ModuleDef => nestedModule(md)
- case dd: DefDef => nestedMethod(dd)
- case fun: Function => function(fun)
- case Apply(fun, args) =>
+ case cd: ClassDef => nestedClass(cd)
+ case md: ModuleDef => nestedModule(md)
+ case dd: DefDef => nestedMethod(dd)
+ case fun: Function => function(fun)
+ case m@Match(EmptyTree, _) => patMatFunction(m) // Pattern matching anonymous function under -Xoldpatmat of after `restorePatternMatchingFunctions`
+ case Apply(fun, args) =>
val isInByName = isByName(fun)
for ((arg, index) <- args.zipWithIndex) {
if (!isInByName(index)) traverse(arg)
else byNameArgument(arg)
}
traverse(fun)
- case _ => super.traverse(tree)
+ case _ => super.traverse(tree)
}
}
}
@@ -245,6 +251,40 @@ private[async] final case class TransformUtils[C <: Context](c: C) {
}
}
+ /**
+ * Replaces expressions of the form `{ new $anon extends PartialFunction[A, B] { ... ; def applyOrElse[..](...) = ... match <cases> }`
+ * with `Match(EmptyTree, cases`.
+ *
+ * This reverses the transformation performed in `Typers`, and works around non-idempotency of typechecking such trees.
+ */
+ // TODO Reference JIRA issue.
+ final def restorePatternMatchingFunctions(tree: Tree) =
+ RestorePatternMatchingFunctions transform tree
+
+ private object RestorePatternMatchingFunctions extends Transformer {
+
+ import language.existentials
+
+ override def transform(tree: Tree): Tree = {
+ val SYNTHETIC = (1 << 21).toLong.asInstanceOf[FlagSet]
+ def isSynthetic(cd: ClassDef) = cd.mods hasFlag SYNTHETIC
+
+ tree match {
+ case Block(
+ (cd@ClassDef(_, _, _, Template(_, _, body))) :: Nil,
+ Apply(Select(New(a), nme.CONSTRUCTOR), Nil)) if isSynthetic(cd) =>
+ val restored = (body collectFirst {
+ case DefDef(_, /*name.apply | */ name.applyOrElse, _, _, _, Match(_, cases)) =>
+ val transformedCases = super.transformStats(cases, currentOwner).asInstanceOf[List[CaseDef]]
+ Match(EmptyTree, transformedCases)
+ }).getOrElse(c.abort(tree.pos, s"Internal Error: Unable to find original pattern matching cases in: $body"))
+ restored
+ case t => super.transform(t)
+ }
+ }
+ }
+
+
def isSafeToInline(tree: Tree) = {
val symtab = c.universe.asInstanceOf[scala.reflect.internal.SymbolTable]
object treeInfo extends {
diff --git a/src/test/scala/scala/async/TreeInterrogation.scala b/src/test/scala/scala/async/TreeInterrogation.scala
index b22faa9..93cfdf5 100644
--- a/src/test/scala/scala/async/TreeInterrogation.scala
+++ b/src/test/scala/scala/async/TreeInterrogation.scala
@@ -62,23 +62,21 @@ object TreeInterrogation extends App {
val levels = Seq("trace", "debug")
def setAll(value: Boolean) = levels.foreach(set(_, value))
- setAll(true)
- try t finally setAll(false)
+ setAll(value = true)
+ try t finally setAll(value = false)
}
withDebug {
val cm = reflect.runtime.currentMirror
- val tb = mkToolbox("-cp target/scala-2.10/classes -Xprint:all")
+ val tb = mkToolbox("-cp target/scala-2.10/classes -Xprint:flatten")
val tree = tb.parse(
""" import scala.async.AsyncId.{async, await}
- | def foo(a: Int, b: Int) = (a, b)
- | val result = async {
- | var i = 0
- | def next() = {
- | i += 1;
- | i
- | }
- | foo(next(), await(next()))
+ | async {
+ | await(1)
+ | val neg1 = -1
+ | val a = await(1)
+ | val f = { case x => ({case x => neg1 * x}: PartialFunction[Int, Int])(x + a) }: PartialFunction[Int, Int]
+ | await(f(2))
| }
| ()
| """.stripMargin)
diff --git a/src/test/scala/scala/async/neg/NakedAwait.scala b/src/test/scala/scala/async/neg/NakedAwait.scala
index ecc84f9..9974a07 100644
--- a/src/test/scala/scala/async/neg/NakedAwait.scala
+++ b/src/test/scala/scala/async/neg/NakedAwait.scala
@@ -93,6 +93,16 @@ class NakedAwait {
}
@Test
+ def nestedPatMatFunction() {
+ expectError("await must not be used under a nested class.") { // TODO more specific error message
+ """
+ | import _root_.scala.async.AsyncId._
+ | async { { case x => { await(false) } } : PartialFunction[Any, Any] }
+ """.stripMargin
+ }
+ }
+
+ @Test
def tryBody() {
expectError("await must not be used under a try/catch.") {
"""
diff --git a/src/test/scala/scala/async/run/toughtype/ToughType.scala b/src/test/scala/scala/async/run/toughtype/ToughType.scala
index 9cfc1ca..83f5a2d 100644
--- a/src/test/scala/scala/async/run/toughtype/ToughType.scala
+++ b/src/test/scala/scala/async/run/toughtype/ToughType.scala
@@ -36,4 +36,35 @@ class ToughTypeSpec {
val res: (List[_], scala.async.run.toughtype.ToughTypeObject.Inner) = Await.result(fut, 2 seconds)
res._1 mustBe (Nil)
}
+
+ @Test def patternMatchingPartialFunction() {
+ import AsyncId.{await, async}
+ async {
+ await(1)
+ val a = await(1)
+ val f = { case x => x + a }: PartialFunction[Int, Int]
+ await(f(2))
+ } mustBe 3
+ }
+
+ @Test def patternMatchingPartialFunctionNested() {
+ import AsyncId.{await, async}
+ async {
+ await(1)
+ val neg1 = -1
+ val a = await(1)
+ val f = { case x => ({case x => neg1 * x}: PartialFunction[Int, Int])(x + a) }: PartialFunction[Int, Int]
+ await(f(2))
+ } mustBe -3
+ }
+
+ @Test def patternMatchingFunction() {
+ import AsyncId.{await, async}
+ async {
+ await(1)
+ val a = await(1)
+ val f = { case x => x + a }: Function[Int, Int]
+ await(f(2))
+ } mustBe 3
+ }
}