aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJason Zaugg <jzaugg@gmail.com>2013-04-10 17:24:09 +0200
committerJason Zaugg <jzaugg@gmail.com>2013-04-10 18:03:24 +0200
commitb60346cda87a4b2213d9b779ed4c6241f126647d (patch)
tree25d02b3bcca49879b37265c6f71d63337adbc0e1
parentd332b4614dfd07b4fd91576dac3ff3fddd657106 (diff)
downloadscala-async-b60346cda87a4b2213d9b779ed4c6241f126647d.tar.gz
scala-async-b60346cda87a4b2213d9b779ed4c6241f126647d.tar.bz2
scala-async-b60346cda87a4b2213d9b779ed4c6241f126647d.zip
Scala 2.10.1 compat: account for change in PartialFunction synthesis.
Since SI-6187, the default case of a partial function is now included in the tree. Before, it was a tree attachment, conditionally inserted in the pattern matcher. I had hoped that that change would allow us to do away with `RestorePatternMatchingFunctions` altogether, but it seems that we aren't so lucky. Instead, I've adapted that transformer to account for the new scheme.
-rw-r--r--src/main/scala/scala/async/TransformUtils.scala10
-rw-r--r--src/test/scala/scala/async/run/toughtype/ToughType.scala40
2 files changed, 28 insertions, 22 deletions
diff --git a/src/main/scala/scala/async/TransformUtils.scala b/src/main/scala/scala/async/TransformUtils.scala
index 434d6fd..38c33a4 100644
--- a/src/main/scala/scala/async/TransformUtils.scala
+++ b/src/main/scala/scala/async/TransformUtils.scala
@@ -274,18 +274,26 @@ private[async] final case class TransformUtils[C <: Context](c: C) {
private object RestorePatternMatchingFunctions extends Transformer {
import language.existentials
+ val DefaultCaseName: TermName = "defaultCase$"
override def transform(tree: Tree): Tree = {
val SYNTHETIC = (1 << 21).toLong.asInstanceOf[FlagSet]
def isSynthetic(cd: ClassDef) = cd.mods hasFlag SYNTHETIC
+ /** Is this pattern node a synthetic catch-all case, added during PartialFuction synthesis before we know
+ * whether the user provided cases are exhaustive. */
+ def isSyntheticDefaultCase(cdef: CaseDef) = cdef match {
+ case CaseDef(Bind(DefaultCaseName, _), EmptyTree, _) => true
+ case _ => false
+ }
tree match {
case Block(
(cd@ClassDef(_, _, _, Template(_, _, body))) :: Nil,
Apply(Select(New(a), nme.CONSTRUCTOR), Nil)) if isSynthetic(cd) =>
val restored = (body collectFirst {
case DefDef(_, /*name.apply | */ name.applyOrElse, _, _, _, Match(_, cases)) =>
- val transformedCases = super.transformStats(cases, currentOwner).asInstanceOf[List[CaseDef]]
+ val nonSyntheticCases = cases.takeWhile(cdef => !isSyntheticDefaultCase(cdef))
+ val transformedCases = super.transformStats(nonSyntheticCases, currentOwner).asInstanceOf[List[CaseDef]]
Match(EmptyTree, transformedCases)
}).getOrElse(c.abort(tree.pos, s"Internal Error: Unable to find original pattern matching cases in: $body"))
restored
diff --git a/src/test/scala/scala/async/run/toughtype/ToughType.scala b/src/test/scala/scala/async/run/toughtype/ToughType.scala
index ab0b60d..83f5a2d 100644
--- a/src/test/scala/scala/async/run/toughtype/ToughType.scala
+++ b/src/test/scala/scala/async/run/toughtype/ToughType.scala
@@ -37,28 +37,26 @@ class ToughTypeSpec {
res._1 mustBe (Nil)
}
-// TODO 2.10.1
-// @Test def patternMatchingPartialFunction() {
-// import AsyncId.{await, async}
-// async {
-// await(1)
-// val a = await(1)
-// val f = { case x => x + a }: PartialFunction[Int, Int]
-// await(f(2))
-// } mustBe 3
-// }
+ @Test def patternMatchingPartialFunction() {
+ import AsyncId.{await, async}
+ async {
+ await(1)
+ val a = await(1)
+ val f = { case x => x + a }: PartialFunction[Int, Int]
+ await(f(2))
+ } mustBe 3
+ }
-// TODO 2.10.1
-// @Test def patternMatchingPartialFunctionNested() {
-// import AsyncId.{await, async}
-// async {
-// await(1)
-// val neg1 = -1
-// val a = await(1)
-// val f = { case x => ({case x => neg1 * x}: PartialFunction[Int, Int])(x + a) }: PartialFunction[Int, Int]
-// await(f(2))
-// } mustBe -3
-// }
+ @Test def patternMatchingPartialFunctionNested() {
+ import AsyncId.{await, async}
+ async {
+ await(1)
+ val neg1 = -1
+ val a = await(1)
+ val f = { case x => ({case x => neg1 * x}: PartialFunction[Int, Int])(x + a) }: PartialFunction[Int, Int]
+ await(f(2))
+ } mustBe -3
+ }
@Test def patternMatchingFunction() {
import AsyncId.{await, async}