aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/main/scala/scala/async/TransformUtils.scala10
-rw-r--r--src/test/scala/scala/async/run/toughtype/ToughType.scala40
2 files changed, 28 insertions, 22 deletions
diff --git a/src/main/scala/scala/async/TransformUtils.scala b/src/main/scala/scala/async/TransformUtils.scala
index 434d6fd..38c33a4 100644
--- a/src/main/scala/scala/async/TransformUtils.scala
+++ b/src/main/scala/scala/async/TransformUtils.scala
@@ -274,18 +274,26 @@ private[async] final case class TransformUtils[C <: Context](c: C) {
private object RestorePatternMatchingFunctions extends Transformer {
import language.existentials
+ val DefaultCaseName: TermName = "defaultCase$"
override def transform(tree: Tree): Tree = {
val SYNTHETIC = (1 << 21).toLong.asInstanceOf[FlagSet]
def isSynthetic(cd: ClassDef) = cd.mods hasFlag SYNTHETIC
+ /** Is this pattern node a synthetic catch-all case, added during PartialFuction synthesis before we know
+ * whether the user provided cases are exhaustive. */
+ def isSyntheticDefaultCase(cdef: CaseDef) = cdef match {
+ case CaseDef(Bind(DefaultCaseName, _), EmptyTree, _) => true
+ case _ => false
+ }
tree match {
case Block(
(cd@ClassDef(_, _, _, Template(_, _, body))) :: Nil,
Apply(Select(New(a), nme.CONSTRUCTOR), Nil)) if isSynthetic(cd) =>
val restored = (body collectFirst {
case DefDef(_, /*name.apply | */ name.applyOrElse, _, _, _, Match(_, cases)) =>
- val transformedCases = super.transformStats(cases, currentOwner).asInstanceOf[List[CaseDef]]
+ val nonSyntheticCases = cases.takeWhile(cdef => !isSyntheticDefaultCase(cdef))
+ val transformedCases = super.transformStats(nonSyntheticCases, currentOwner).asInstanceOf[List[CaseDef]]
Match(EmptyTree, transformedCases)
}).getOrElse(c.abort(tree.pos, s"Internal Error: Unable to find original pattern matching cases in: $body"))
restored
diff --git a/src/test/scala/scala/async/run/toughtype/ToughType.scala b/src/test/scala/scala/async/run/toughtype/ToughType.scala
index ab0b60d..83f5a2d 100644
--- a/src/test/scala/scala/async/run/toughtype/ToughType.scala
+++ b/src/test/scala/scala/async/run/toughtype/ToughType.scala
@@ -37,28 +37,26 @@ class ToughTypeSpec {
res._1 mustBe (Nil)
}
-// TODO 2.10.1
-// @Test def patternMatchingPartialFunction() {
-// import AsyncId.{await, async}
-// async {
-// await(1)
-// val a = await(1)
-// val f = { case x => x + a }: PartialFunction[Int, Int]
-// await(f(2))
-// } mustBe 3
-// }
+ @Test def patternMatchingPartialFunction() {
+ import AsyncId.{await, async}
+ async {
+ await(1)
+ val a = await(1)
+ val f = { case x => x + a }: PartialFunction[Int, Int]
+ await(f(2))
+ } mustBe 3
+ }
-// TODO 2.10.1
-// @Test def patternMatchingPartialFunctionNested() {
-// import AsyncId.{await, async}
-// async {
-// await(1)
-// val neg1 = -1
-// val a = await(1)
-// val f = { case x => ({case x => neg1 * x}: PartialFunction[Int, Int])(x + a) }: PartialFunction[Int, Int]
-// await(f(2))
-// } mustBe -3
-// }
+ @Test def patternMatchingPartialFunctionNested() {
+ import AsyncId.{await, async}
+ async {
+ await(1)
+ val neg1 = -1
+ val a = await(1)
+ val f = { case x => ({case x => neg1 * x}: PartialFunction[Int, Int])(x + a) }: PartialFunction[Int, Int]
+ await(f(2))
+ } mustBe -3
+ }
@Test def patternMatchingFunction() {
import AsyncId.{await, async}