diff options
-rw-r--r-- | src/dotty/tools/dotc/Compiler.scala | 3 | ||||
-rw-r--r-- | src/dotty/tools/dotc/transform/PatternMatcher.scala | 76 | ||||
-rw-r--r-- | src/dotty/tools/dotc/transform/TryCatchPatterns.scala | 99 | ||||
-rw-r--r-- | tests/neg/tryPatternMatchError.scala | 35 | ||||
-rw-r--r-- | tests/run/tryPatternMatch.check | 20 | ||||
-rw-r--r-- | tests/run/tryPatternMatch.scala | 139 |
6 files changed, 297 insertions, 75 deletions
diff --git a/src/dotty/tools/dotc/Compiler.scala b/src/dotty/tools/dotc/Compiler.scala index 3844f42a7..ce9280d82 100644 --- a/src/dotty/tools/dotc/Compiler.scala +++ b/src/dotty/tools/dotc/Compiler.scala @@ -57,7 +57,8 @@ class Compiler { new TailRec, // Rewrite tail recursion to loops new LiftTry, // Put try expressions that might execute on non-empty stacks into their own methods new ClassOf), // Expand `Predef.classOf` calls. - List(new PatternMatcher, // Compile pattern matches + List(new TryCatchPatterns, // Compile cases in try/catch + new PatternMatcher, // Compile pattern matches new ExplicitOuter, // Add accessors to outer classes from nested ones. new ExplicitSelf, // Make references to non-trivial self types explicit as casts new CrossCastAnd, // Normalize selections involving intersection types. diff --git a/src/dotty/tools/dotc/transform/PatternMatcher.scala b/src/dotty/tools/dotc/transform/PatternMatcher.scala index fd89696a8..974053769 100644 --- a/src/dotty/tools/dotc/transform/PatternMatcher.scala +++ b/src/dotty/tools/dotc/transform/PatternMatcher.scala @@ -1,6 +1,8 @@ package dotty.tools.dotc package transform +import scala.language.postfixOps + import TreeTransforms._ import core.Denotations._ import core.SymDenotations._ @@ -53,19 +55,6 @@ class PatternMatcher extends MiniPhaseTransform with DenotTransformer {thisTrans translated.ensureConforms(tree.tpe) } - - override def transformTry(tree: tpd.Try)(implicit ctx: Context, info: TransformerInfo): tpd.Tree = { - val selector = - ctx.newSymbol(ctx.owner, ctx.freshName("ex").toTermName, Flags.Synthetic | Flags.Case, defn.ThrowableType, coord = tree.pos) - val sel = Ident(selector.termRef).withPos(tree.pos) - val rethrow = tpd.CaseDef(EmptyTree, EmptyTree, Throw(ref(selector))) - val newCases = tpd.CaseDef( - Bind(selector, Underscore(selector.info).withPos(tree.pos)), - EmptyTree, - transformMatch(tpd.Match(sel, tree.cases ::: rethrow :: Nil))) - cpy.Try(tree)(tree.expr, newCases :: Nil, tree.finalizer) - } - class Translator(implicit ctx: Context) { def translator = { @@ -1264,27 +1253,6 @@ class PatternMatcher extends MiniPhaseTransform with DenotTransformer {thisTrans t } - /** Is this pattern node a catch-all or type-test pattern? */ - def isCatchCase(cdef: CaseDef) = cdef match { - case CaseDef(Typed(Ident(nme.WILDCARD), tpt), EmptyTree, _) => - isSimpleThrowable(tpt.tpe) - case CaseDef(Bind(_, Typed(Ident(nme.WILDCARD), tpt)), EmptyTree, _) => - isSimpleThrowable(tpt.tpe) - case _ => - isDefaultCase(cdef) - } - - private def isSimpleThrowable(tp: Type)(implicit ctx: Context): Boolean = tp match { - case tp @ TypeRef(pre, _) => - val sym = tp.symbol - (pre == NoPrefix || pre.widen.typeSymbol.isStatic) && - (sym.derivesFrom(defn.ThrowableClass)) && /* bq */ !(sym is Flags.Trait) - case _ => - false - } - - - /** Implement a pattern match by turning its cases (including the implicit failure case) * into the corresponding (monadic) extractors, and combining them with the `orElse` combinator. * @@ -1335,46 +1303,6 @@ class PatternMatcher extends MiniPhaseTransform with DenotTransformer {thisTrans Block(List(ValDef(selectorSym, sel)), combined) } - // return list of typed CaseDefs that are supported by the backend (typed/bind/wildcard) - // we don't have a global scrutinee -- the caught exception must be bound in each of the casedefs - // there's no need to check the scrutinee for null -- "throw null" becomes "throw new NullPointerException" - // try to simplify to a type-based switch, or fall back to a catch-all case that runs a normal pattern match - // unlike translateMatch, we type our result before returning it - /*def translateTry(caseDefs: List[CaseDef], pt: Type, pos: Position): List[CaseDef] = - // if they're already simple enough to be handled by the back-end, we're done - if (caseDefs forall isCatchCase) caseDefs - else { - val swatches = { // switch-catches - val bindersAndCases = caseDefs map { caseDef => - // generate a fresh symbol for each case, hoping we'll end up emitting a type-switch (we don't have a global scrut there) - // if we fail to emit a fine-grained switch, have to do translateCase again with a single scrutSym (TODO: uniformize substitution on treemakers so we can avoid this) - val caseScrutSym = freshSym(pos, pureType(defn.ThrowableType)) - (caseScrutSym, propagateSubstitution(translateCase(caseScrutSym, pt)(caseDef), EmptySubstitution)) - } - - for(cases <- emitTypeSwitch(bindersAndCases, pt).toList - if cases forall isCatchCase; // must check again, since it's not guaranteed -- TODO: can we eliminate this? e.g., a type test could test for a trait or a non-trivial prefix, which are not handled by the back-end - cse <- cases) yield /*fixerUpper(matchOwner, pos)*/(cse).asInstanceOf[CaseDef] - } - - val catches = if (swatches.nonEmpty) swatches else { - val scrutSym = freshSym(pos, pureType(defn.ThrowableType)) - val casesNoSubstOnly = caseDefs map { caseDef => (propagateSubstitution(translateCase(scrutSym, pt)(caseDef), EmptySubstitution))} - - val exSym = freshSym(pos, pureType(defn.ThrowableType), "ex") - - List( - CaseDef( - Bind(exSym, Ident(??? /*nme.WILDCARD*/)), // TODO: does this need fixing upping? - EmptyTree, - combineCasesNoSubstOnly(ref(exSym), scrutSym, casesNoSubstOnly, pt, matchOwner, Some((scrut: Symbol) => Throw(ref(exSym)))) - ) - ) - } - - /*typer.typedCases(*/catches/*, defn.ThrowableType, WildcardType)*/ - }*/ - /** The translation of `pat if guard => body` has two aspects: * 1) the substitution due to the variables bound by patterns * 2) the combination of the extractor calls using `flatMap`. diff --git a/src/dotty/tools/dotc/transform/TryCatchPatterns.scala b/src/dotty/tools/dotc/transform/TryCatchPatterns.scala new file mode 100644 index 000000000..9a6ecef51 --- /dev/null +++ b/src/dotty/tools/dotc/transform/TryCatchPatterns.scala @@ -0,0 +1,99 @@ +package dotty.tools.dotc +package transform + +import core.Symbols._ +import core.StdNames._ +import ast.Trees._ +import core.Types._ +import dotty.tools.dotc.core.Decorators._ +import dotty.tools.dotc.core.Flags +import dotty.tools.dotc.core.Contexts.Context +import dotty.tools.dotc.transform.TreeTransforms.{MiniPhaseTransform, TransformerInfo} +import dotty.tools.dotc.util.Positions.Position + +/** Compiles the cases that can not be handled by primitive catch cases as a common pattern match. + * + * The following code: + * ``` + * try { <code> } + * catch { + * <tryCases> // Cases that can be handled by catch + * <patternMatchCases> // Cases starting with first one that can't be handled by catch + * } + * ``` + * will become: + * ``` + * try { <code> } + * catch { + * <tryCases> + * case e => e match { + * <patternMatchCases> + * } + * } + * ``` + * + * Cases that are not supported include: + * - Applies and unapplies + * - Idents + * - Alternatives + * - `case _: T =>` where `T` is not `Throwable` + * + */ +class TryCatchPatterns extends MiniPhaseTransform { + import dotty.tools.dotc.ast.tpd._ + + def phaseName: String = "tryCatchPatterns" + + override def runsAfter = Set(classOf[ElimRepeated]) + + override def checkPostCondition(tree: Tree)(implicit ctx: Context): Unit = tree match { + case Try(_, cases, _) => + cases.foreach { + case CaseDef(Typed(_, _), guard, _) => assert(guard.isEmpty, "Try case should not contain a guard.") + case CaseDef(Bind(_, _), guard, _) => assert(guard.isEmpty, "Try case should not contain a guard.") + case c => + assert(isDefaultCase(c), "Pattern in Try should be Bind, Typed or default case.") + } + case _ => + } + + override def transformTry(tree: Try)(implicit ctx: Context, info: TransformerInfo): Tree = { + val (tryCases, patternMatchCases) = tree.cases.span(isCatchCase) + val fallbackCase = mkFallbackPatterMatchCase(patternMatchCases, tree.pos) + cpy.Try(tree)(cases = tryCases ++ fallbackCase) + } + + /** Is this pattern node a catch-all or type-test pattern? */ + private def isCatchCase(cdef: CaseDef)(implicit ctx: Context): Boolean = cdef match { + case CaseDef(Typed(Ident(nme.WILDCARD), tpt), EmptyTree, _) => isSimpleThrowable(tpt.tpe) + case CaseDef(Bind(_, Typed(Ident(nme.WILDCARD), tpt)), EmptyTree, _) => isSimpleThrowable(tpt.tpe) + case _ => isDefaultCase(cdef) + } + + private def isSimpleThrowable(tp: Type)(implicit ctx: Context): Boolean = tp match { + case tp @ TypeRef(pre, _) => + (pre == NoPrefix || pre.widen.typeSymbol.isStatic) && // Does not require outer class check + !tp.symbol.is(Flags.Trait) && // Traits not supported by JVM + tp.derivesFrom(defn.ThrowableClass) + case _ => + false + } + + private def mkFallbackPatterMatchCase(patternMatchCases: List[CaseDef], pos: Position)( + implicit ctx: Context, info: TransformerInfo): Option[CaseDef] = { + if (patternMatchCases.isEmpty) None + else { + val exName = ctx.freshName("ex").toTermName + val fallbackSelector = + ctx.newSymbol(ctx.owner, exName, Flags.Synthetic | Flags.Case, defn.ThrowableType, coord = pos) + val sel = Ident(fallbackSelector.termRef).withPos(pos) + val rethrow = CaseDef(EmptyTree, EmptyTree, Throw(ref(fallbackSelector))) + Some(CaseDef( + Bind(fallbackSelector, Underscore(fallbackSelector.info).withPos(pos)), + EmptyTree, + transformFollowing(Match(sel, patternMatchCases ::: rethrow :: Nil))) + ) + } + } + +} diff --git a/tests/neg/tryPatternMatchError.scala b/tests/neg/tryPatternMatchError.scala new file mode 100644 index 000000000..fe12a6232 --- /dev/null +++ b/tests/neg/tryPatternMatchError.scala @@ -0,0 +1,35 @@ +import java.io.IOException +import java.lang.NullPointerException +import java.lang.IllegalArgumentException + +object IAE { + def unapply(e: Exception): Option[String] = + if (e.isInstanceOf[IllegalArgumentException]) Some(e.getMessage) + else None +} + +object EX extends Exception + +trait ExceptionTrait extends Exception + +object Test { + def main(args: Array[String]): Unit = { + var a: Int = 1 + try { + throw new IllegalArgumentException() + } catch { + case e: IOException if e.getMessage == null => + case e: NullPointerException => + case e: IndexOutOfBoundsException => + case _: NoSuchElementException => + case _: ExceptionTrait => + case _: NoSuchElementException if a <= 1 => + case _: NullPointerException | _:IOException => + case `a` => // This case should probably emmit an error + case e: Int => // error + case EX => + case IAE(msg) => + case e: IllegalArgumentException => + } + } +} diff --git a/tests/run/tryPatternMatch.check b/tests/run/tryPatternMatch.check new file mode 100644 index 000000000..44f7b7d5a --- /dev/null +++ b/tests/run/tryPatternMatch.check @@ -0,0 +1,20 @@ +success 1 +success 2 +success 3 +success 4 +success 5 +success 6 +success 7 +success 8 +success 9.1 +success 9.2 +IllegalArgumentException: abc +IllegalArgumentException +NullPointerException | IOException +NoSuchElementException +EX +InnerException +NullPointerException +ExceptionTrait +ClassCastException +TimeoutException escaped diff --git a/tests/run/tryPatternMatch.scala b/tests/run/tryPatternMatch.scala new file mode 100644 index 000000000..06b469d4d --- /dev/null +++ b/tests/run/tryPatternMatch.scala @@ -0,0 +1,139 @@ +import java.io.IOException +import java.util.concurrent.TimeoutException + +object IAE { + def unapply(e: Exception): Option[String] = + if (e.isInstanceOf[IllegalArgumentException] && e.getMessage != null) Some(e.getMessage) + else None +} + +object EX extends Exception { + val msg = "a" + class InnerException extends Exception(msg) +} + +trait ExceptionTrait extends Exception + +trait TestTrait { + type ExceptionType <: Exception + + def traitTest(): Unit = { + try { + throw new IOException + } catch { + case _: ExceptionType => println("success 9.2") + case _ => println("failed 9.2") + } + } +} + +object Test extends TestTrait { + type ExceptionType = IOException + + def main(args: Array[String]): Unit = { + var a: Int = 1 + + try { + throw new Exception("abc") + } catch { + case _: Exception => println("success 1") + case _ => println("failed 1") + } + + try { + throw new Exception("abc") + } catch { + case e: Exception => println("success 2") + case _ => println("failed 2") + } + + try { + throw new Exception("abc") + } catch { + case e: Exception if e.getMessage == "abc" => println("success 3") + case _ => println("failed 3") + } + + try { + throw new Exception("abc") + } catch { + case e: Exception if e.getMessage == "" => println("failed 4") + case _ => println("success 4") + } + + try { + throw EX + } catch { + case EX => println("success 5") + case _ => println("failed 5") + } + + try { + throw new EX.InnerException + } catch { + case _: EX.InnerException => println("success 6") + case _ => println("failed 6") + } + + try { + throw new NullPointerException + } catch { + case _: NullPointerException | _:IOException => println("success 7") + case _ => println("failed 7") + } + + try { + throw new ExceptionTrait {} + } catch { + case _: ExceptionTrait => println("success 8") + case _ => println("failed 8") + } + + try { + throw new IOException + } catch { + case _: ExceptionType => println("success 9.1") + case _ => println("failed 9.1") + } + + traitTest() // test 9.2 + + def testThrow(throwIt: => Unit): Unit = { + try { + throwIt + } catch { + // These cases will be compiled as catch cases + case e: NullPointerException => println("NullPointerException") + case e: IndexOutOfBoundsException => println("IndexOutOfBoundsException") + case _: NoSuchElementException => println("NoSuchElementException") + case _: EX.InnerException => println("InnerException") + // All the following will be compiled as a match + case IAE(msg) => println("IllegalArgumentException: " + msg) + case _: ExceptionTrait => println("ExceptionTrait") + case e: IOException if e.getMessage == null => println("IOException") + case _: NullPointerException | _:IOException => println("NullPointerException | IOException") + case `a` => println("`a`") + case EX => println("EX") + case e: IllegalArgumentException => println("IllegalArgumentException") + case _: ClassCastException => println("ClassCastException") + } + } + + testThrow(throw new IllegalArgumentException("abc")) + testThrow(throw new IllegalArgumentException()) + testThrow(throw new IOException("abc")) + testThrow(throw new NoSuchElementException()) + testThrow(throw EX) + testThrow(throw new EX.InnerException) + testThrow(throw new NullPointerException()) + testThrow(throw new ExceptionTrait {}) + testThrow(throw a.asInstanceOf[Throwable]) + try { + testThrow(throw new TimeoutException) + println("TimeoutException did not escape") + } catch { + case _: TimeoutException => println("TimeoutException escaped") + } + } + +} |