From 0583f9aa26b12ef8509eb2beca12ee92ba13bec3 Mon Sep 17 00:00:00 2001 From: Philipp Haller Date: Wed, 27 Feb 2013 19:00:55 +0100 Subject: Fix #12 - Add support for await inside try-catch - `await` can now be used inside the body of `try` and `finally` - using `await` inside the cases of a `catch` is illegal - provides precise error messages ("await must not be used under catch") - adds 9 tests --- src/main/scala/scala/async/AnfTransform.scala | 61 +++--- src/main/scala/scala/async/Async.scala | 9 +- src/main/scala/scala/async/AsyncAnalysis.scala | 12 +- src/main/scala/scala/async/ExprBuilder.scala | 147 +++++++++++--- src/main/scala/scala/async/TransformUtils.scala | 3 + src/test/scala/scala/async/TreeInterrogation.scala | 3 +- src/test/scala/scala/async/neg/NakedAwait.scala | 22 +-- .../scala/scala/async/run/trycatch/TrySpec.scala | 215 +++++++++++++++++++++ 8 files changed, 385 insertions(+), 87 deletions(-) create mode 100644 src/test/scala/scala/async/run/trycatch/TrySpec.scala diff --git a/src/main/scala/scala/async/AnfTransform.scala b/src/main/scala/scala/async/AnfTransform.scala index afcf6bd..bf2d7b2 100644 --- a/src/main/scala/scala/async/AnfTransform.scala +++ b/src/main/scala/scala/async/AnfTransform.scala @@ -116,44 +116,42 @@ private[async] final case class AnfTransform[C <: Context](c: C) { private object inline { def transformToList(tree: Tree): List[Tree] = trace("inline", tree) { + def branchWithAssign(orig: Tree, varDef: ValDef) = orig match { + case Block(stats, expr) => Block(stats, Assign(Ident(varDef.name), expr)) + case _ => Assign(Ident(varDef.name), orig) + } + + def casesWithAssign(cases: List[CaseDef], varDef: ValDef) = cases map { + case cd @ CaseDef(pat, guard, orig) => + attachCopy(cd)(CaseDef(pat, guard, branchWithAssign(orig, varDef))) + } + val stats :+ expr = anf.transformToList(tree) expr match { + // if type of if-else/try/match is Unit don't introduce assignment, + // but add Unit value to bring it into form expected by async transform + case If(_, _, _) | Try(_, _, _) | Match(_, _) if expr.tpe =:= definitions.UnitTpe => + stats :+ expr :+ Literal(Constant(())) + case Apply(fun, args) if isAwait(fun) => val valDef = defineVal(name.await, expr, tree.pos) stats :+ valDef :+ Ident(valDef.name) case If(cond, thenp, elsep) => - // if type of if-else is Unit don't introduce assignment, - // but add Unit value to bring it into form expected by async transform - if (expr.tpe =:= definitions.UnitTpe) { - stats :+ expr :+ Literal(Constant(())) - } else { - val varDef = defineVar(name.ifRes, expr.tpe, tree.pos) - def branchWithAssign(orig: Tree) = orig match { - case Block(thenStats, thenExpr) => Block(thenStats, Assign(Ident(varDef.name), thenExpr)) - case _ => Assign(Ident(varDef.name), orig) - } - val ifWithAssign = If(cond, branchWithAssign(thenp), branchWithAssign(elsep)) - stats :+ varDef :+ ifWithAssign :+ Ident(varDef.name) - } + val varDef = defineVar(name.ifRes, expr.tpe, tree.pos) + val ifWithAssign = If(cond, branchWithAssign(thenp, varDef), branchWithAssign(elsep, varDef)) + stats :+ varDef :+ ifWithAssign :+ Ident(varDef.name) + + case Try(body, catches, finalizer) => + val varDef = defineVar(name.tryRes, expr.tpe, tree.pos) + val tryWithAssign = Try(branchWithAssign(body, varDef), casesWithAssign(catches, varDef), finalizer) + stats :+ varDef :+ tryWithAssign :+ Ident(varDef.name) case Match(scrut, cases) => - // if type of match is Unit don't introduce assignment, - // but add Unit value to bring it into form expected by async transform - if (expr.tpe =:= definitions.UnitTpe) { - stats :+ expr :+ Literal(Constant(())) - } - else { - val varDef = defineVar(name.matchRes, expr.tpe, tree.pos) - val casesWithAssign = cases map { - case cd@CaseDef(pat, guard, Block(caseStats, caseExpr)) => - attachCopy(cd)(CaseDef(pat, guard, Block(caseStats, Assign(Ident(varDef.name), caseExpr)))) - case cd@CaseDef(pat, guard, body) => - attachCopy(cd)(CaseDef(pat, guard, Assign(Ident(varDef.name), body))) - } - val matchWithAssign = attachCopy(tree)(Match(scrut, casesWithAssign)) - stats :+ varDef :+ matchWithAssign :+ Ident(varDef.name) - } + val varDef = defineVar(name.matchRes, expr.tpe, tree.pos) + val matchWithAssign = attachCopy(tree)(Match(scrut, casesWithAssign(cases, varDef))) + stats :+ varDef :+ matchWithAssign :+ Ident(varDef.name) + case _ => stats :+ expr } @@ -220,6 +218,11 @@ private[async] final case class AnfTransform[C <: Context](c: C) { val stats :+ expr = inline.transformToList(rhs) stats :+ attachCopy(tree)(Assign(lhs, expr)) + case Try(body, catches, finalizer) if containsAwait => + val stats :+ expr = inline.transformToList(body) + val tryType = c.typeCheck(Try(Block(stats, expr), catches, finalizer)).tpe + List(attachCopy(tree)(Try(Block(stats, expr), catches, finalizer)).setType(tryType)) + case If(cond, thenp, elsep) if containsAwait => val condStats :+ condExpr = inline.transformToList(cond) val thenBlock = inline.transformToBlock(thenp) diff --git a/src/main/scala/scala/async/Async.scala b/src/main/scala/scala/async/Async.scala index eab0ee4..272cc48 100644 --- a/src/main/scala/scala/async/Async.scala +++ b/src/main/scala/scala/async/Async.scala @@ -120,6 +120,13 @@ abstract class AsyncBase { val stateVar = ValDef(Modifiers(Flag.MUTABLE), name.state, TypeTree(definitions.IntTpe), Literal(Constant(0))) val result = ValDef(NoMods, name.result, TypeTree(futureSystemOps.promType[T]), futureSystemOps.createProm[T].tree) val execContext = ValDef(NoMods, name.execContext, TypeTree(), futureSystemOps.execContext.tree) + + // the stack of currently active exception handlers + val handlers = ValDef(Modifiers(Flag.MUTABLE), name.handlers, TypeTree(typeOf[List[PartialFunction[Throwable, Unit]]]), (reify { List() }).tree) + + // the exception that is currently in-flight or `null` otherwise + val exception = ValDef(Modifiers(Flag.MUTABLE), name.exception, TypeTree(typeOf[Throwable]), Literal(Constant(null))) + val applyDefDef: DefDef = { val applyVParamss = List(List(ValDef(Modifiers(Flag.PARAM), name.tr, TypeTree(defn.TryAnyType), EmptyTree))) val applyBody = asyncBlock.onCompleteHandler @@ -132,7 +139,7 @@ abstract class AsyncBase { val applyBody = asyncBlock.onCompleteHandler DefDef(NoMods, name.apply, Nil, Nil, TypeTree(definitions.UnitTpe), Apply(Ident(name.resume), Nil)) } - List(utils.emptyConstructor, stateVar, result, execContext) ++ localVarTrees ++ List(resumeFunTree, applyDefDef, apply0DefDef) + List(utils.emptyConstructor, stateVar, result, execContext, handlers, exception) ++ localVarTrees ++ List(resumeFunTree, applyDefDef, apply0DefDef) } val template = { Template(List(stateMachineType), emptyValDef, body) diff --git a/src/main/scala/scala/async/AsyncAnalysis.scala b/src/main/scala/scala/async/AsyncAnalysis.scala index 9184960..7c667c3 100644 --- a/src/main/scala/scala/async/AsyncAnalysis.scala +++ b/src/main/scala/scala/async/AsyncAnalysis.scala @@ -76,16 +76,16 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C, asyncBase: Asy } override def traverse(tree: Tree) { - def containsAwait = tree exists isAwait + def containsAwait(t: Tree) = t exists isAwait tree match { - case Try(_, _, _) if containsAwait => - reportUnsupportedAwait(tree, "try/catch") + case Try(_, catches, _) if catches exists containsAwait => + reportUnsupportedAwait(tree, "catch") super.traverse(tree) - case Return(_) => + case Return(_) => c.abort(tree.pos, "return is illegal within a async block") - case ValDef(mods, _, _, _) if mods.hasFlag(Flag.LAZY) => + case ValDef(mods, _, _, _) if mods.hasFlag(Flag.LAZY) => c.abort(tree.pos, "lazy vals are illegal within an async block") - case _ => + case _ => super.traverse(tree) } } diff --git a/src/main/scala/scala/async/ExprBuilder.scala b/src/main/scala/scala/async/ExprBuilder.scala index 180e7b9..22337d4 100644 --- a/src/main/scala/scala/async/ExprBuilder.scala +++ b/src/main/scala/scala/async/ExprBuilder.scala @@ -38,11 +38,11 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: } /** A sequence of statements the concludes with a unconditional transition to `nextState` */ - final class SimpleAsyncState(val stats: List[Tree], val state: Int, nextState: Int) + final class SimpleAsyncState(val stats: List[Tree], val state: Int, nextState: Int, excState: Option[Int]) extends AsyncState { def mkHandlerCaseForState: CaseDef = - mkHandlerCase(state, stats :+ mkStateTree(nextState) :+ mkResumeApply) + mkHandlerCase(state, stats :+ mkStateTree(nextState) :+ mkResumeApply, excState) override val toString: String = s"AsyncState #$state, next = $nextState" @@ -51,9 +51,9 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: /** A sequence of statements with a conditional transition to the next state, which will represent * a branch of an `if` or a `match`. */ - final class AsyncStateWithoutAwait(val stats: List[c.Tree], val state: Int) extends AsyncState { + final class AsyncStateWithoutAwait(val stats: List[c.Tree], val state: Int, excState: Option[Int]) extends AsyncState { override def mkHandlerCaseForState: CaseDef = - mkHandlerCase(state, stats) + mkHandlerCase(state, stats, excState) override val toString: String = s"AsyncStateWithoutAwait #$state" @@ -63,13 +63,13 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: * handler will unconditionally transition to `nestState`.`` */ final class AsyncStateWithAwait(val stats: List[c.Tree], val state: Int, nextState: Int, - awaitable: Awaitable) + awaitable: Awaitable, excState: Option[Int]) extends AsyncState { override def mkHandlerCaseForState: CaseDef = { val callOnComplete = futureSystemOps.onComplete(c.Expr(awaitable.expr), c.Expr(This(tpnme.EMPTY)), c.Expr(Ident(name.execContext))).tree - mkHandlerCase(state, stats :+ callOnComplete) + mkHandlerCase(state, stats :+ callOnComplete, excState) } override def mkOnCompleteHandler[T: c.WeakTypeTag]: Option[CaseDef] = { @@ -97,7 +97,7 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: Block(List(tryGetTree, mkStateTree(nextState), mkResumeApply): _*) ) - Some(mkHandlerCase(state, List(ifIsFailureTree))) + Some(mkHandlerCase(state, List(ifIsFailureTree), excState)) } override val toString: String = @@ -106,8 +106,11 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: /* * Builder for a single state of an async method. + * + * The `excState` parameter is implicit, so that it is passed implicitly + * when `AsyncBlockBuilder` creates new `AsyncStateBuilder`s. */ - final class AsyncStateBuilder(state: Int, private val nameMap: Map[Symbol, c.Name]) { + final class AsyncStateBuilder(state: Int, private val nameMap: Map[Symbol, c.Name])(implicit excState: Option[Int]) { /* Statements preceding an await call. */ private val stats = ListBuffer[c.Tree]() /** The state of the target of a LabelDef application (while loop jump) */ @@ -134,12 +137,17 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: nextState: Int): AsyncState = { val sanitizedAwaitable = awaitable.copy(expr = renameReset(awaitable.expr)) val effectiveNextState = nextJumpState.getOrElse(nextState) - new AsyncStateWithAwait(stats.toList, state, effectiveNextState, sanitizedAwaitable) + new AsyncStateWithAwait(stats.toList, state, effectiveNextState, sanitizedAwaitable, excState) + } + + def resultWithoutAwait(): AsyncState = { + this += mkResumeApply + new AsyncStateWithoutAwait(stats.toList, state, excState) } def resultSimple(nextState: Int): AsyncState = { val effectiveNextState = nextJumpState.getOrElse(nextState) - new SimpleAsyncState(stats.toList, state, effectiveNextState) + new SimpleAsyncState(stats.toList, state, effectiveNextState, excState) } def resultWithIf(condTree: c.Tree, thenState: Int, elseState: Int): AsyncState = { @@ -148,7 +156,7 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: val cond = renameReset(condTree) def mkBranch(state: Int) = Block(mkStateTree(state), mkResumeApply) this += If(cond, mkBranch(thenState), mkBranch(elseState)) - new AsyncStateWithoutAwait(stats.toList, state) + new AsyncStateWithoutAwait(stats.toList, state, excState) } /** @@ -173,12 +181,12 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: } // 2. insert changed match tree at the end of the current state this += Match(renameReset(scrutTree), newCases) - new AsyncStateWithoutAwait(stats.toList, state) + new AsyncStateWithoutAwait(stats.toList, state, excState) } def resultWithLabel(startLabelState: Int): AsyncState = { this += Block(mkStateTree(startLabelState), mkResumeApply) - new AsyncStateWithoutAwait(stats.toList, state) + new AsyncStateWithoutAwait(stats.toList, state, excState) } override def toString: String = { @@ -190,14 +198,16 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: /** * An `AsyncBlockBuilder` builds a `ListBuffer[AsyncState]` based on the expressions of a `Block(stats, expr)` (see `Async.asyncImpl`). * - * @param stats a list of expressions - * @param expr the last expression of the block - * @param startState the start state - * @param endState the state to continue with - * @param toRename a `Map` for renaming the given key symbols to the mangled value names + * @param stats a list of expressions + * @param expr the last expression of the block + * @param startState the start state + * @param endState the state to continue with + * @param toRename a `Map` for renaming the given key symbols to the mangled value names + * @param excState the state to continue with in case of an exception + * @param parentExcState the state to continue with in case of an exception not handled by the current exception handler */ - final private class AsyncBlockBuilder(stats: List[c.Tree], expr: c.Tree, startState: Int, endState: Int, - private val toRename: Map[Symbol, c.Name]) { + final private class AsyncBlockBuilder(stats: List[c.Tree], expr: c.Tree, startState: Int, endState: Int, private val toRename: Map[Symbol, c.Name], + parentExcState: Option[Int] = None)(implicit excState: Option[Int]) { val asyncStates = ListBuffer[AsyncState]() var stateBuilder = new AsyncStateBuilder(startState, toRename) @@ -209,9 +219,10 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: case _ => false }) c.abort(tree.pos, "await must not be used in this position") //throw new FallbackToCpsException - def nestedBlockBuilder(nestedTree: Tree, startState: Int, endState: Int) = { + def nestedBlockBuilder(nestedTree: Tree, startState: Int, endState: Int, + excState: Option[Int] = None, parentExcState: Option[Int] = None) = { val (nestedStats, nestedExpr) = statsAndExpr(nestedTree) - new AsyncBlockBuilder(nestedStats, nestedExpr, startState, endState, toRename) + new AsyncBlockBuilder(nestedStats, nestedExpr, startState, endState, toRename, parentExcState)(excState) } import stateAssigner.nextState @@ -250,6 +261,60 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: currState = afterIfState stateBuilder = new AsyncStateBuilder(currState, toRename) + case Try(block, catches, finalizer) if stat exists isAwait => + val tryStartState = nextState() + val afterTryState = nextState() + val ehState = nextState() + val finalizerState = if (!finalizer.isEmpty) Some(nextState()) else None + + // complete current state so that it continues with tryStartState + asyncStates += stateBuilder.resultWithLabel(tryStartState) + + if (!finalizer.isEmpty) { + val builder = nestedBlockBuilder(finalizer, finalizerState.get, afterTryState) + asyncStates ++= builder.asyncStates + } + + // create handler state + def handlersDot(m: String) = Select(Ident(name.handlers), m) + val exceptionExpr = c.Expr[Throwable](Ident(name.exception)) + // handler state does not have active exception handler --> None + val handlerStateBuilder = new AsyncStateBuilder(ehState, toRename)(None) + + val parentExpr: c.Expr[Unit] = + if (parentExcState.isEmpty) reify { throw exceptionExpr.splice } + else c.Expr[Unit](mkStateTree(parentExcState.get)) + + val handlerExpr = reify { + val h = c.Expr[PartialFunction[Throwable, Unit]](handlersDot("head")).splice + c.Expr[Unit](Assign(Ident(name.handlers), handlersDot("tail"))).splice + + if (h isDefinedAt exceptionExpr.splice) { + h(exceptionExpr.splice) + c.Expr[Unit](mkStateTree(if (!finalizer.isEmpty) finalizerState.get else afterTryState)).splice + } else { + parentExpr.splice + } + } + + handlerStateBuilder += handlerExpr.tree + asyncStates += handlerStateBuilder.resultWithoutAwait() + + val ehName = newTermName("handlerPF$" + ehState) + val partFunAssign = ValDef(Modifiers(), ehName, TypeTree(typeOf[PartialFunction[Throwable, Unit]]), Match(EmptyTree, catches)) + val newHandler = c.Expr[PartialFunction[Throwable, Unit]](Ident(ehName)) + val handlersIdent = c.Expr[List[PartialFunction[Throwable, Unit]]](Ident(name.handlers)) + val pushedHandlers = reify { handlersIdent.splice.+:(newHandler.splice) } + val pushAssign = Assign(Ident(name.handlers), pushedHandlers.tree) + + val (tryStats, tryExpr) = statsAndExpr(block) + val builder = nestedBlockBuilder(Block(partFunAssign :: pushAssign :: tryStats, tryExpr), + tryStartState, if (!finalizer.isEmpty) finalizerState.get else afterTryState, Some(ehState), excState) + asyncStates ++= builder.asyncStates + + currState = afterTryState + stateBuilder = new AsyncStateBuilder(currState, toRename) + case Match(scrutinee, cases) if stat exists isAwait => checkForUnsupportedAwait(scrutinee) @@ -302,7 +367,7 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: val startState = stateAssigner.nextState() val endState = Int.MaxValue - val blockBuilder = new AsyncBlockBuilder(stats, expr, startState, endState, toRename) + val blockBuilder = new AsyncBlockBuilder(stats, expr, startState, endState, toRename)(None) new AsyncBlock { def asyncStates = blockBuilder.asyncStates.toList @@ -313,7 +378,7 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: val lastStateBody = c.Expr[T](lastState.body) val rhs = futureSystemOps.completeProm( c.Expr[futureSystem.Prom[T]](Ident(name.result)), reify(scala.util.Success(lastStateBody.splice))) - mkHandlerCase(lastState.state, rhs.tree) + mkHandlerCase(lastState.state, rhs.tree, None) } asyncStates.toList match { case s :: Nil => @@ -386,9 +451,35 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: private def mkStateTree(nextState: Int): c.Tree = Assign(Ident(name.state), c.literal(nextState).tree) - private def mkHandlerCase(num: Int, rhs: List[c.Tree]): CaseDef = - mkHandlerCase(num, Block(rhs: _*)) + private def mkHandlerCase(num: Int, rhs: List[c.Tree], excState: Option[Int]): CaseDef = + mkHandlerCase(num, Block(rhs: _*), excState) - private def mkHandlerCase(num: Int, rhs: c.Tree): CaseDef = - CaseDef(c.literal(num).tree, EmptyTree, rhs) + /* Generates `case` clause with wrapping try-catch: + * + * case `num` => + * try { + * rhs + * } catch { + * case NonFatal(t) => + * exception$async = t + * state$async = excState.get + * resume$async() + * } + */ + private def mkHandlerCase(num: Int, rhs: c.Tree, excState: Option[Int]): CaseDef = { + val rhsWithTry = + if (excState.isEmpty) rhs + else Try(rhs, + List( + CaseDef( + Apply(Ident(defn.NonFatalClass), List(Bind(newTermName("t"), Ident(nme.WILDCARD)))), + EmptyTree, + Block(List( + Assign(Ident(name.exception), Ident(newTermName("t"))), + mkStateTree(excState.get), + mkResumeApply + ), c.literalUnit.tree))), EmptyTree + ) + CaseDef(c.literal(num).tree, EmptyTree, rhsWithTry) + } } diff --git a/src/main/scala/scala/async/TransformUtils.scala b/src/main/scala/scala/async/TransformUtils.scala index 2d3f210..00fa430 100644 --- a/src/main/scala/scala/async/TransformUtils.scala +++ b/src/main/scala/scala/async/TransformUtils.scala @@ -29,8 +29,11 @@ private[async] final case class TransformUtils[C <: Context](c: C) { val tr = newTermName("tr") val matchRes = "matchres" val ifRes = "ifres" + val tryRes = "tryRes" val await = "await" val bindSuffix = "$bind" + val handlers = suffixedName("handlers") + val exception = suffixedName("exception") def arg(i: Int) = "arg" + i diff --git a/src/test/scala/scala/async/TreeInterrogation.scala b/src/test/scala/scala/async/TreeInterrogation.scala index 93cfdf5..1876556 100644 --- a/src/test/scala/scala/async/TreeInterrogation.scala +++ b/src/test/scala/scala/async/TreeInterrogation.scala @@ -40,8 +40,7 @@ class TreeInterrogation { val varDefs = tree1.collect { case ValDef(mods, name, _, _) if mods.hasFlag(Flag.MUTABLE) => name } - varDefs.map(_.decoded.trim).toSet mustBe (Set("state$async", "await$1", "await$2")) - varDefs.map(_.decoded.trim).toSet mustBe (Set("state$async", "await$1", "await$2")) + varDefs.map(_.decoded.trim).toSet mustBe (Set("state$async", "await$1", "await$2", "handlers$async", "exception$async")) val defDefs = tree1.collect { case t: Template => diff --git a/src/test/scala/scala/async/neg/NakedAwait.scala b/src/test/scala/scala/async/neg/NakedAwait.scala index c3537ec..f297bed 100644 --- a/src/test/scala/scala/async/neg/NakedAwait.scala +++ b/src/test/scala/scala/async/neg/NakedAwait.scala @@ -102,19 +102,9 @@ class NakedAwait { } } - @Test - def tryBody() { - expectError("await must not be used under a try/catch.") { - """ - | import _root_.scala.async.AsyncId._ - | async { try { await(false) } catch { case _ => } } - """.stripMargin - } - } - @Test def catchBody() { - expectError("await must not be used under a try/catch.") { + expectError("await must not be used under a catch.") { """ | import _root_.scala.async.AsyncId._ | async { try { () } catch { case _ => await(false) } } @@ -122,16 +112,6 @@ class NakedAwait { } } - @Test - def finallyBody() { - expectError("await must not be used under a try/catch.") { - """ - | import _root_.scala.async.AsyncId._ - | async { try { () } finally { await(false) } } - """.stripMargin - } - } - @Test def nestedMethod() { expectError("await must not be used under a nested method.") { diff --git a/src/test/scala/scala/async/run/trycatch/TrySpec.scala b/src/test/scala/scala/async/run/trycatch/TrySpec.scala new file mode 100644 index 0000000..4f6e93c --- /dev/null +++ b/src/test/scala/scala/async/run/trycatch/TrySpec.scala @@ -0,0 +1,215 @@ +/* + * Copyright (C) 2012 Typesafe Inc. + */ + +package scala.async +package run +package trycatch + +import org.junit.runner.RunWith +import org.junit.runners.JUnit4 +import org.junit.Test + +@RunWith(classOf[JUnit4]) +class TrySpec { + + @Test + def tryCatch1() { + import AsyncId._ + + val result = async { + var xxx: Int = 0 + try { + val y = await(xxx) + xxx = xxx + 1 + y + } catch { + case e: Exception => + assert(false) + } + xxx + } + assert(result == 1) + } + + @Test + def tryCatch2() { + import AsyncId._ + + val result = async { + var xxx: Int = 0 + try { + val y = await(xxx) + throw new Exception("test msg") + assert(false) + xxx = xxx + 1 + y + } catch { + case e: Exception => + assert(e.getMessage == "test msg") + xxx = 7 + } + xxx + } + assert(result == 7) + } + + @Test + def nestedTry1() { + import AsyncId._ + + val result = async { + var xxx = 0 + try { + try { + val y = await(xxx) + throw new IllegalArgumentException("msg") + assert(false) + y + 2 + } catch { + case iae: IllegalArgumentException => + xxx = 6 + } + } catch { + case nsee: NoSuchElementException => + xxx = 7 + } + xxx + } + assert(result == 6) + } + + @Test + def nestedTry2() { + import AsyncId._ + + val result = async { + var xxx = 0 + try { + try { + val y = await(xxx) + throw new NoSuchElementException("msg") + assert(false) + y + 2 + } catch { + case iae: IllegalArgumentException => + xxx = 6 + } + } catch { + case nsee: NoSuchElementException => + xxx = 7 + } + xxx + } + assert(result == 7) + } + + @Test + def tryAsExpr() { + import AsyncId._ + + val result = async { + val xxx: Int = 0 + try { + val y = await(xxx) + y + 2 + } catch { + case e: Exception => + assert(false) + xxx + 4 + } + } + assert(result == 2) + } + + @Test + def tryFinally1() { + import AsyncId._ + + var xxx: Int = 0 + val result = async { + try { + val y = await(xxx) + y + 2 + } catch { + case e: Exception => + assert(false) + xxx + 4 + } finally { + xxx = 5 + } + } + assert(result == 2) + assert(xxx == 5) + } + + @Test + def tryFinally2() { + import AsyncId._ + + var xxx: Int = 0 + val result = async { + try { + val y = await(xxx) + throw new Exception("msg") + assert(false) + y + 2 + } catch { + case e: Exception => + xxx + 4 + } finally { + xxx = 6 + } + } + assert(result == 4) + assert(xxx == 6) + } + + @Test + def tryFinallyAwait1() { + import AsyncId._ + + var xxx: Int = 0 + var uuu: Int = 10 + val result = async { + try { + val y = await(xxx) + y + 2 + } catch { + case e: Exception => + assert(false) + xxx + 4 + } finally { + val v = await(uuu) + xxx = v + } + } + assert(result == 2) + assert(xxx == 10) + } + + @Test + def tryFinallyAwait2() { + import AsyncId._ + + var xxx: Int = 0 + var uuu: Int = 10 + val result = async { + try { + val y = await(xxx) + throw new Exception("msg") + assert(false) + y + 2 + } catch { + case e: Exception => + xxx + 4 + } finally { + val v = await(uuu) + xxx = v + } + } + assert(result == 4) + assert(xxx == 10) + } + +} -- cgit v1.2.3