diff options
author | phaller <hallerp@gmail.com> | 2012-10-31 16:45:42 +0100 |
---|---|---|
committer | phaller <hallerp@gmail.com> | 2012-11-02 11:09:44 +0100 |
commit | f22998d343c01951254fc2020731986fb3219ff0 (patch) | |
tree | b5ae469b65156d3f4d3453dac213bfe6499cb985 /src/async/library/scala/async/Async.scala | |
parent | c15af267fc7ed5dc7ef40428d738dd5679606f66 (diff) | |
download | scala-async-f22998d343c01951254fc2020731986fb3219ff0.tar.gz scala-async-f22998d343c01951254fc2020731986fb3219ff0.tar.bz2 scala-async-f22998d343c01951254fc2020731986fb3219ff0.zip |
Fix for #1861: Add fall-back to CPS for all unsupported uses of await
Diffstat (limited to 'src/async/library/scala/async/Async.scala')
-rw-r--r-- | src/async/library/scala/async/Async.scala | 156 |
1 files changed, 99 insertions, 57 deletions
diff --git a/src/async/library/scala/async/Async.scala b/src/async/library/scala/async/Async.scala index d3e0904..1d9b2d9 100644 --- a/src/async/library/scala/async/Async.scala +++ b/src/async/library/scala/async/Async.scala @@ -7,8 +7,13 @@ import language.experimental.macros import scala.reflect.macros.Context import scala.collection.mutable.ListBuffer -import scala.concurrent.{ Future, Promise } +import scala.concurrent.{ Future, Promise, ExecutionContext, future } +import ExecutionContext.Implicits.global import scala.util.control.NonFatal +import scala.util.continuations.{ shift, reset, cpsParam } + +/* Extending `ControlThrowable`, by default, also avoids filling in the stack trace. */ +class FallbackToCpsException extends scala.util.control.ControlThrowable /* * @author Philipp Haller @@ -19,6 +24,16 @@ object Async extends AsyncUtils { def await[T](awaitable: Future[T]): T = ??? + /* Fall back for `await` when it is called at an unsupported position. + */ + def awaitCps[T, U](awaitable: Future[T], p: Promise[U]): T @cpsParam[U, Unit] = + shift { + (k: (T => U)) => + awaitable onComplete { + case tr => p.success(k(tr.get)) + } + } + def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[Future[T]] = { import c.universe._ import Flag._ @@ -26,48 +41,49 @@ object Async extends AsyncUtils { val builder = new ExprBuilder[c.type](c) val awaitMethod = awaitSym(c) - body.tree match { - case Block(stats, expr) => - val asyncBlockBuilder = new builder.AsyncBlockBuilder(stats, expr, 0, 1000, 1000) - - vprintln(s"states of current method (${ asyncBlockBuilder.asyncStates }):") - asyncBlockBuilder.asyncStates foreach vprintln + try { + body.tree match { + case Block(stats, expr) => + val asyncBlockBuilder = new builder.AsyncBlockBuilder(stats, expr, 0, 1000, 1000) - val handlerExpr = asyncBlockBuilder.mkHandlerExpr() - - vprintln(s"GENERATED handler expr:") - vprintln(handlerExpr) - - val handlerForLastState: c.Expr[PartialFunction[Int, Unit]] = { - val tree = Apply(Select(Ident("result"), c.universe.newTermName("success")), - List(asyncBlockBuilder.asyncStates.last.body)) - builder.mkHandler(asyncBlockBuilder.asyncStates.last.state, c.Expr[Unit](tree)) - } - - vprintln("GENERATED handler for last state:") - vprintln(handlerForLastState) - - val localVarTrees = asyncBlockBuilder.asyncStates.init.flatMap(_.allVarDefs).toList - - val unitIdent = Ident(definitions.UnitClass) - - val resumeFunTree: c.Tree = DefDef(Modifiers(), newTermName("resume"), List(), List(List()), unitIdent, - Try(Apply(Select( - Apply(Select(handlerExpr.tree, newTermName("orElse")), List(handlerForLastState.tree)), - newTermName("apply")), List(Ident(newTermName("state")))), - List( - CaseDef( - Apply(Select(Select(Select(Ident(newTermName("scala")), newTermName("util")), newTermName("control")), newTermName("NonFatal")), List(Bind(newTermName("t"), Ident(nme.WILDCARD)))), - EmptyTree, - Block(List( - Apply(Select(Ident(newTermName("result")), newTermName("failure")), List(Ident(newTermName("t"))))), - Literal(Constant(()))))), EmptyTree)) - - val methodBody = reify { - val result = Promise[T]() - var state = 0 - - /* + vprintln(s"states of current method (${asyncBlockBuilder.asyncStates}):") + asyncBlockBuilder.asyncStates foreach vprintln + + val handlerExpr = asyncBlockBuilder.mkHandlerExpr() + + vprintln(s"GENERATED handler expr:") + vprintln(handlerExpr) + + val handlerForLastState: c.Expr[PartialFunction[Int, Unit]] = { + val tree = Apply(Select(Ident("result"), c.universe.newTermName("success")), + List(asyncBlockBuilder.asyncStates.last.body)) + builder.mkHandler(asyncBlockBuilder.asyncStates.last.state, c.Expr[Unit](tree)) + } + + vprintln("GENERATED handler for last state:") + vprintln(handlerForLastState) + + val localVarTrees = asyncBlockBuilder.asyncStates.init.flatMap(_.allVarDefs).toList + + val unitIdent = Ident(definitions.UnitClass) + + val resumeFunTree: c.Tree = DefDef(Modifiers(), newTermName("resume"), List(), List(List()), unitIdent, + Try(Apply(Select( + Apply(Select(handlerExpr.tree, newTermName("orElse")), List(handlerForLastState.tree)), + newTermName("apply")), List(Ident(newTermName("state")))), + List( + CaseDef( + Apply(Select(Select(Select(Ident(newTermName("scala")), newTermName("util")), newTermName("control")), newTermName("NonFatal")), List(Bind(newTermName("t"), Ident(nme.WILDCARD)))), + EmptyTree, + Block(List( + Apply(Select(Ident(newTermName("result")), newTermName("failure")), List(Ident(newTermName("t"))))), + Literal(Constant(()))))), EmptyTree)) + + val methodBody = reify { + val result = Promise[T]() + var state = 0 + + /* def resume(): Unit = { try { (handlerExpr.splice orElse handlerForLastState.splice)(state) @@ -77,24 +93,50 @@ object Async extends AsyncUtils { } resume() */ - - c.Expr(Block( - localVarTrees :+ resumeFunTree, - Apply(Ident(newTermName("resume")), List()) - )).splice - - result.future - } - //vprintln("ASYNC: Generated method body:") - //vprintln(c.universe.showRaw(methodBody)) - //vprintln(c.universe.show(methodBody)) - methodBody + c.Expr(Block( + localVarTrees :+ resumeFunTree, + Apply(Ident(newTermName("resume")), List()))).splice + + result.future + } - case _ => - // issue error message + //vprintln("ASYNC: Generated method body:") + //vprintln(c.universe.showRaw(methodBody)) + //vprintln(c.universe.show(methodBody)) + methodBody + + case _ => + // issue error message + reify { + sys.error("expression not supported by async") + } + } + } catch { + case _: FallbackToCpsException => + // replace `await` invocations with `awaitCps` invocations + val awaitReplacer = new Transformer { + val awaitCpsMethod = awaitCpsSym(c) + override def transform(tree: Tree): Tree = tree match { + case Apply(fun @ TypeApply(_, List(futArgTpt)), args) if fun.symbol == awaitMethod => + val typeApp = treeCopy.TypeApply(fun, Ident(awaitCpsMethod), List(TypeTree(futArgTpt.tpe), TypeTree(body.tree.tpe))) + treeCopy.Apply(tree, typeApp, args.map(arg => c.resetAllAttrs(arg.duplicate)) :+ Ident(newTermName("p"))) + + case _ => + super.transform(tree) + } + } + + val newBody = awaitReplacer.transform(body.tree) + reify { - sys.error("expression not supported by async") + val p = Promise[T]() + future { + reset { + c.Expr(c.resetAllAttrs(newBody.duplicate)).asInstanceOf[c.Expr[T]].splice + } + } + p.future } } } |