aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/dotty/tools/dotc/Compiler.scala3
-rw-r--r--src/dotty/tools/dotc/transform/PatternMatcher.scala76
-rw-r--r--src/dotty/tools/dotc/transform/TryCatchPatterns.scala99
-rw-r--r--tests/neg/tryPatternMatchError.scala35
-rw-r--r--tests/run/tryPatternMatch.check20
-rw-r--r--tests/run/tryPatternMatch.scala139
6 files changed, 297 insertions, 75 deletions
diff --git a/src/dotty/tools/dotc/Compiler.scala b/src/dotty/tools/dotc/Compiler.scala
index 3844f42a7..ce9280d82 100644
--- a/src/dotty/tools/dotc/Compiler.scala
+++ b/src/dotty/tools/dotc/Compiler.scala
@@ -57,7 +57,8 @@ class Compiler {
new TailRec, // Rewrite tail recursion to loops
new LiftTry, // Put try expressions that might execute on non-empty stacks into their own methods
new ClassOf), // Expand `Predef.classOf` calls.
- List(new PatternMatcher, // Compile pattern matches
+ List(new TryCatchPatterns, // Compile cases in try/catch
+ new PatternMatcher, // Compile pattern matches
new ExplicitOuter, // Add accessors to outer classes from nested ones.
new ExplicitSelf, // Make references to non-trivial self types explicit as casts
new CrossCastAnd, // Normalize selections involving intersection types.
diff --git a/src/dotty/tools/dotc/transform/PatternMatcher.scala b/src/dotty/tools/dotc/transform/PatternMatcher.scala
index fd89696a8..974053769 100644
--- a/src/dotty/tools/dotc/transform/PatternMatcher.scala
+++ b/src/dotty/tools/dotc/transform/PatternMatcher.scala
@@ -1,6 +1,8 @@
package dotty.tools.dotc
package transform
+import scala.language.postfixOps
+
import TreeTransforms._
import core.Denotations._
import core.SymDenotations._
@@ -53,19 +55,6 @@ class PatternMatcher extends MiniPhaseTransform with DenotTransformer {thisTrans
translated.ensureConforms(tree.tpe)
}
-
- override def transformTry(tree: tpd.Try)(implicit ctx: Context, info: TransformerInfo): tpd.Tree = {
- val selector =
- ctx.newSymbol(ctx.owner, ctx.freshName("ex").toTermName, Flags.Synthetic | Flags.Case, defn.ThrowableType, coord = tree.pos)
- val sel = Ident(selector.termRef).withPos(tree.pos)
- val rethrow = tpd.CaseDef(EmptyTree, EmptyTree, Throw(ref(selector)))
- val newCases = tpd.CaseDef(
- Bind(selector, Underscore(selector.info).withPos(tree.pos)),
- EmptyTree,
- transformMatch(tpd.Match(sel, tree.cases ::: rethrow :: Nil)))
- cpy.Try(tree)(tree.expr, newCases :: Nil, tree.finalizer)
- }
-
class Translator(implicit ctx: Context) {
def translator = {
@@ -1264,27 +1253,6 @@ class PatternMatcher extends MiniPhaseTransform with DenotTransformer {thisTrans
t
}
- /** Is this pattern node a catch-all or type-test pattern? */
- def isCatchCase(cdef: CaseDef) = cdef match {
- case CaseDef(Typed(Ident(nme.WILDCARD), tpt), EmptyTree, _) =>
- isSimpleThrowable(tpt.tpe)
- case CaseDef(Bind(_, Typed(Ident(nme.WILDCARD), tpt)), EmptyTree, _) =>
- isSimpleThrowable(tpt.tpe)
- case _ =>
- isDefaultCase(cdef)
- }
-
- private def isSimpleThrowable(tp: Type)(implicit ctx: Context): Boolean = tp match {
- case tp @ TypeRef(pre, _) =>
- val sym = tp.symbol
- (pre == NoPrefix || pre.widen.typeSymbol.isStatic) &&
- (sym.derivesFrom(defn.ThrowableClass)) && /* bq */ !(sym is Flags.Trait)
- case _ =>
- false
- }
-
-
-
/** Implement a pattern match by turning its cases (including the implicit failure case)
* into the corresponding (monadic) extractors, and combining them with the `orElse` combinator.
*
@@ -1335,46 +1303,6 @@ class PatternMatcher extends MiniPhaseTransform with DenotTransformer {thisTrans
Block(List(ValDef(selectorSym, sel)), combined)
}
- // return list of typed CaseDefs that are supported by the backend (typed/bind/wildcard)
- // we don't have a global scrutinee -- the caught exception must be bound in each of the casedefs
- // there's no need to check the scrutinee for null -- "throw null" becomes "throw new NullPointerException"
- // try to simplify to a type-based switch, or fall back to a catch-all case that runs a normal pattern match
- // unlike translateMatch, we type our result before returning it
- /*def translateTry(caseDefs: List[CaseDef], pt: Type, pos: Position): List[CaseDef] =
- // if they're already simple enough to be handled by the back-end, we're done
- if (caseDefs forall isCatchCase) caseDefs
- else {
- val swatches = { // switch-catches
- val bindersAndCases = caseDefs map { caseDef =>
- // generate a fresh symbol for each case, hoping we'll end up emitting a type-switch (we don't have a global scrut there)
- // if we fail to emit a fine-grained switch, have to do translateCase again with a single scrutSym (TODO: uniformize substitution on treemakers so we can avoid this)
- val caseScrutSym = freshSym(pos, pureType(defn.ThrowableType))
- (caseScrutSym, propagateSubstitution(translateCase(caseScrutSym, pt)(caseDef), EmptySubstitution))
- }
-
- for(cases <- emitTypeSwitch(bindersAndCases, pt).toList
- if cases forall isCatchCase; // must check again, since it's not guaranteed -- TODO: can we eliminate this? e.g., a type test could test for a trait or a non-trivial prefix, which are not handled by the back-end
- cse <- cases) yield /*fixerUpper(matchOwner, pos)*/(cse).asInstanceOf[CaseDef]
- }
-
- val catches = if (swatches.nonEmpty) swatches else {
- val scrutSym = freshSym(pos, pureType(defn.ThrowableType))
- val casesNoSubstOnly = caseDefs map { caseDef => (propagateSubstitution(translateCase(scrutSym, pt)(caseDef), EmptySubstitution))}
-
- val exSym = freshSym(pos, pureType(defn.ThrowableType), "ex")
-
- List(
- CaseDef(
- Bind(exSym, Ident(??? /*nme.WILDCARD*/)), // TODO: does this need fixing upping?
- EmptyTree,
- combineCasesNoSubstOnly(ref(exSym), scrutSym, casesNoSubstOnly, pt, matchOwner, Some((scrut: Symbol) => Throw(ref(exSym))))
- )
- )
- }
-
- /*typer.typedCases(*/catches/*, defn.ThrowableType, WildcardType)*/
- }*/
-
/** The translation of `pat if guard => body` has two aspects:
* 1) the substitution due to the variables bound by patterns
* 2) the combination of the extractor calls using `flatMap`.
diff --git a/src/dotty/tools/dotc/transform/TryCatchPatterns.scala b/src/dotty/tools/dotc/transform/TryCatchPatterns.scala
new file mode 100644
index 000000000..9a6ecef51
--- /dev/null
+++ b/src/dotty/tools/dotc/transform/TryCatchPatterns.scala
@@ -0,0 +1,99 @@
+package dotty.tools.dotc
+package transform
+
+import core.Symbols._
+import core.StdNames._
+import ast.Trees._
+import core.Types._
+import dotty.tools.dotc.core.Decorators._
+import dotty.tools.dotc.core.Flags
+import dotty.tools.dotc.core.Contexts.Context
+import dotty.tools.dotc.transform.TreeTransforms.{MiniPhaseTransform, TransformerInfo}
+import dotty.tools.dotc.util.Positions.Position
+
+/** Compiles the cases that can not be handled by primitive catch cases as a common pattern match.
+ *
+ * The following code:
+ * ```
+ * try { <code> }
+ * catch {
+ * <tryCases> // Cases that can be handled by catch
+ * <patternMatchCases> // Cases starting with first one that can't be handled by catch
+ * }
+ * ```
+ * will become:
+ * ```
+ * try { <code> }
+ * catch {
+ * <tryCases>
+ * case e => e match {
+ * <patternMatchCases>
+ * }
+ * }
+ * ```
+ *
+ * Cases that are not supported include:
+ * - Applies and unapplies
+ * - Idents
+ * - Alternatives
+ * - `case _: T =>` where `T` is not `Throwable`
+ *
+ */
+class TryCatchPatterns extends MiniPhaseTransform {
+ import dotty.tools.dotc.ast.tpd._
+
+ def phaseName: String = "tryCatchPatterns"
+
+ override def runsAfter = Set(classOf[ElimRepeated])
+
+ override def checkPostCondition(tree: Tree)(implicit ctx: Context): Unit = tree match {
+ case Try(_, cases, _) =>
+ cases.foreach {
+ case CaseDef(Typed(_, _), guard, _) => assert(guard.isEmpty, "Try case should not contain a guard.")
+ case CaseDef(Bind(_, _), guard, _) => assert(guard.isEmpty, "Try case should not contain a guard.")
+ case c =>
+ assert(isDefaultCase(c), "Pattern in Try should be Bind, Typed or default case.")
+ }
+ case _ =>
+ }
+
+ override def transformTry(tree: Try)(implicit ctx: Context, info: TransformerInfo): Tree = {
+ val (tryCases, patternMatchCases) = tree.cases.span(isCatchCase)
+ val fallbackCase = mkFallbackPatterMatchCase(patternMatchCases, tree.pos)
+ cpy.Try(tree)(cases = tryCases ++ fallbackCase)
+ }
+
+ /** Is this pattern node a catch-all or type-test pattern? */
+ private def isCatchCase(cdef: CaseDef)(implicit ctx: Context): Boolean = cdef match {
+ case CaseDef(Typed(Ident(nme.WILDCARD), tpt), EmptyTree, _) => isSimpleThrowable(tpt.tpe)
+ case CaseDef(Bind(_, Typed(Ident(nme.WILDCARD), tpt)), EmptyTree, _) => isSimpleThrowable(tpt.tpe)
+ case _ => isDefaultCase(cdef)
+ }
+
+ private def isSimpleThrowable(tp: Type)(implicit ctx: Context): Boolean = tp match {
+ case tp @ TypeRef(pre, _) =>
+ (pre == NoPrefix || pre.widen.typeSymbol.isStatic) && // Does not require outer class check
+ !tp.symbol.is(Flags.Trait) && // Traits not supported by JVM
+ tp.derivesFrom(defn.ThrowableClass)
+ case _ =>
+ false
+ }
+
+ private def mkFallbackPatterMatchCase(patternMatchCases: List[CaseDef], pos: Position)(
+ implicit ctx: Context, info: TransformerInfo): Option[CaseDef] = {
+ if (patternMatchCases.isEmpty) None
+ else {
+ val exName = ctx.freshName("ex").toTermName
+ val fallbackSelector =
+ ctx.newSymbol(ctx.owner, exName, Flags.Synthetic | Flags.Case, defn.ThrowableType, coord = pos)
+ val sel = Ident(fallbackSelector.termRef).withPos(pos)
+ val rethrow = CaseDef(EmptyTree, EmptyTree, Throw(ref(fallbackSelector)))
+ Some(CaseDef(
+ Bind(fallbackSelector, Underscore(fallbackSelector.info).withPos(pos)),
+ EmptyTree,
+ transformFollowing(Match(sel, patternMatchCases ::: rethrow :: Nil)))
+ )
+ }
+ }
+
+}
diff --git a/tests/neg/tryPatternMatchError.scala b/tests/neg/tryPatternMatchError.scala
new file mode 100644
index 000000000..fe12a6232
--- /dev/null
+++ b/tests/neg/tryPatternMatchError.scala
@@ -0,0 +1,35 @@
+import java.io.IOException
+import java.lang.NullPointerException
+import java.lang.IllegalArgumentException
+
+object IAE {
+ def unapply(e: Exception): Option[String] =
+ if (e.isInstanceOf[IllegalArgumentException]) Some(e.getMessage)
+ else None
+}
+
+object EX extends Exception
+
+trait ExceptionTrait extends Exception
+
+object Test {
+ def main(args: Array[String]): Unit = {
+ var a: Int = 1
+ try {
+ throw new IllegalArgumentException()
+ } catch {
+ case e: IOException if e.getMessage == null =>
+ case e: NullPointerException =>
+ case e: IndexOutOfBoundsException =>
+ case _: NoSuchElementException =>
+ case _: ExceptionTrait =>
+ case _: NoSuchElementException if a <= 1 =>
+ case _: NullPointerException | _:IOException =>
+ case `a` => // This case should probably emmit an error
+ case e: Int => // error
+ case EX =>
+ case IAE(msg) =>
+ case e: IllegalArgumentException =>
+ }
+ }
+}
diff --git a/tests/run/tryPatternMatch.check b/tests/run/tryPatternMatch.check
new file mode 100644
index 000000000..44f7b7d5a
--- /dev/null
+++ b/tests/run/tryPatternMatch.check
@@ -0,0 +1,20 @@
+success 1
+success 2
+success 3
+success 4
+success 5
+success 6
+success 7
+success 8
+success 9.1
+success 9.2
+IllegalArgumentException: abc
+IllegalArgumentException
+NullPointerException | IOException
+NoSuchElementException
+EX
+InnerException
+NullPointerException
+ExceptionTrait
+ClassCastException
+TimeoutException escaped
diff --git a/tests/run/tryPatternMatch.scala b/tests/run/tryPatternMatch.scala
new file mode 100644
index 000000000..06b469d4d
--- /dev/null
+++ b/tests/run/tryPatternMatch.scala
@@ -0,0 +1,139 @@
+import java.io.IOException
+import java.util.concurrent.TimeoutException
+
+object IAE {
+ def unapply(e: Exception): Option[String] =
+ if (e.isInstanceOf[IllegalArgumentException] && e.getMessage != null) Some(e.getMessage)
+ else None
+}
+
+object EX extends Exception {
+ val msg = "a"
+ class InnerException extends Exception(msg)
+}
+
+trait ExceptionTrait extends Exception
+
+trait TestTrait {
+ type ExceptionType <: Exception
+
+ def traitTest(): Unit = {
+ try {
+ throw new IOException
+ } catch {
+ case _: ExceptionType => println("success 9.2")
+ case _ => println("failed 9.2")
+ }
+ }
+}
+
+object Test extends TestTrait {
+ type ExceptionType = IOException
+
+ def main(args: Array[String]): Unit = {
+ var a: Int = 1
+
+ try {
+ throw new Exception("abc")
+ } catch {
+ case _: Exception => println("success 1")
+ case _ => println("failed 1")
+ }
+
+ try {
+ throw new Exception("abc")
+ } catch {
+ case e: Exception => println("success 2")
+ case _ => println("failed 2")
+ }
+
+ try {
+ throw new Exception("abc")
+ } catch {
+ case e: Exception if e.getMessage == "abc" => println("success 3")
+ case _ => println("failed 3")
+ }
+
+ try {
+ throw new Exception("abc")
+ } catch {
+ case e: Exception if e.getMessage == "" => println("failed 4")
+ case _ => println("success 4")
+ }
+
+ try {
+ throw EX
+ } catch {
+ case EX => println("success 5")
+ case _ => println("failed 5")
+ }
+
+ try {
+ throw new EX.InnerException
+ } catch {
+ case _: EX.InnerException => println("success 6")
+ case _ => println("failed 6")
+ }
+
+ try {
+ throw new NullPointerException
+ } catch {
+ case _: NullPointerException | _:IOException => println("success 7")
+ case _ => println("failed 7")
+ }
+
+ try {
+ throw new ExceptionTrait {}
+ } catch {
+ case _: ExceptionTrait => println("success 8")
+ case _ => println("failed 8")
+ }
+
+ try {
+ throw new IOException
+ } catch {
+ case _: ExceptionType => println("success 9.1")
+ case _ => println("failed 9.1")
+ }
+
+ traitTest() // test 9.2
+
+ def testThrow(throwIt: => Unit): Unit = {
+ try {
+ throwIt
+ } catch {
+ // These cases will be compiled as catch cases
+ case e: NullPointerException => println("NullPointerException")
+ case e: IndexOutOfBoundsException => println("IndexOutOfBoundsException")
+ case _: NoSuchElementException => println("NoSuchElementException")
+ case _: EX.InnerException => println("InnerException")
+ // All the following will be compiled as a match
+ case IAE(msg) => println("IllegalArgumentException: " + msg)
+ case _: ExceptionTrait => println("ExceptionTrait")
+ case e: IOException if e.getMessage == null => println("IOException")
+ case _: NullPointerException | _:IOException => println("NullPointerException | IOException")
+ case `a` => println("`a`")
+ case EX => println("EX")
+ case e: IllegalArgumentException => println("IllegalArgumentException")
+ case _: ClassCastException => println("ClassCastException")
+ }
+ }
+
+ testThrow(throw new IllegalArgumentException("abc"))
+ testThrow(throw new IllegalArgumentException())
+ testThrow(throw new IOException("abc"))
+ testThrow(throw new NoSuchElementException())
+ testThrow(throw EX)
+ testThrow(throw new EX.InnerException)
+ testThrow(throw new NullPointerException())
+ testThrow(throw new ExceptionTrait {})
+ testThrow(throw a.asInstanceOf[Throwable])
+ try {
+ testThrow(throw new TimeoutException)
+ println("TimeoutException did not escape")
+ } catch {
+ case _: TimeoutException => println("TimeoutException escaped")
+ }
+ }
+
+}