diff options
-rwxr-xr-x | build.sh | 2 | ||||
-rw-r--r-- | src/async/library/scala/async/Async.scala | 156 | ||||
-rw-r--r-- | src/async/library/scala/async/AsyncUtils.scala | 6 | ||||
-rw-r--r-- | src/async/library/scala/async/ExprBuilder.scala | 11 | ||||
-rw-r--r-- | test/pending/run/fallback0/MinimalScalaTest.scala | 102 | ||||
-rw-r--r-- | test/pending/run/fallback0/fallback0-manual.scala | 72 | ||||
-rw-r--r-- | test/pending/run/fallback0/fallback0.scala | 49 |
7 files changed, 340 insertions, 58 deletions
@@ -1,4 +1,4 @@ #!/bin/bash scalac -version mkdir -p classes -scalac -d classes -deprecation -feature src/async/library/scala/async/*.scala +scalac -P:continuations:enable -d classes -deprecation -feature src/async/library/scala/async/*.scala diff --git a/src/async/library/scala/async/Async.scala b/src/async/library/scala/async/Async.scala index d3e0904..1d9b2d9 100644 --- a/src/async/library/scala/async/Async.scala +++ b/src/async/library/scala/async/Async.scala @@ -7,8 +7,13 @@ import language.experimental.macros import scala.reflect.macros.Context import scala.collection.mutable.ListBuffer -import scala.concurrent.{ Future, Promise } +import scala.concurrent.{ Future, Promise, ExecutionContext, future } +import ExecutionContext.Implicits.global import scala.util.control.NonFatal +import scala.util.continuations.{ shift, reset, cpsParam } + +/* Extending `ControlThrowable`, by default, also avoids filling in the stack trace. */ +class FallbackToCpsException extends scala.util.control.ControlThrowable /* * @author Philipp Haller @@ -19,6 +24,16 @@ object Async extends AsyncUtils { def await[T](awaitable: Future[T]): T = ??? + /* Fall back for `await` when it is called at an unsupported position. + */ + def awaitCps[T, U](awaitable: Future[T], p: Promise[U]): T @cpsParam[U, Unit] = + shift { + (k: (T => U)) => + awaitable onComplete { + case tr => p.success(k(tr.get)) + } + } + def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[Future[T]] = { import c.universe._ import Flag._ @@ -26,48 +41,49 @@ object Async extends AsyncUtils { val builder = new ExprBuilder[c.type](c) val awaitMethod = awaitSym(c) - body.tree match { - case Block(stats, expr) => - val asyncBlockBuilder = new builder.AsyncBlockBuilder(stats, expr, 0, 1000, 1000) - - vprintln(s"states of current method (${ asyncBlockBuilder.asyncStates }):") - asyncBlockBuilder.asyncStates foreach vprintln + try { + body.tree match { + case Block(stats, expr) => + val asyncBlockBuilder = new builder.AsyncBlockBuilder(stats, expr, 0, 1000, 1000) - val handlerExpr = asyncBlockBuilder.mkHandlerExpr() - - vprintln(s"GENERATED handler expr:") - vprintln(handlerExpr) - - val handlerForLastState: c.Expr[PartialFunction[Int, Unit]] = { - val tree = Apply(Select(Ident("result"), c.universe.newTermName("success")), - List(asyncBlockBuilder.asyncStates.last.body)) - builder.mkHandler(asyncBlockBuilder.asyncStates.last.state, c.Expr[Unit](tree)) - } - - vprintln("GENERATED handler for last state:") - vprintln(handlerForLastState) - - val localVarTrees = asyncBlockBuilder.asyncStates.init.flatMap(_.allVarDefs).toList - - val unitIdent = Ident(definitions.UnitClass) - - val resumeFunTree: c.Tree = DefDef(Modifiers(), newTermName("resume"), List(), List(List()), unitIdent, - Try(Apply(Select( - Apply(Select(handlerExpr.tree, newTermName("orElse")), List(handlerForLastState.tree)), - newTermName("apply")), List(Ident(newTermName("state")))), - List( - CaseDef( - Apply(Select(Select(Select(Ident(newTermName("scala")), newTermName("util")), newTermName("control")), newTermName("NonFatal")), List(Bind(newTermName("t"), Ident(nme.WILDCARD)))), - EmptyTree, - Block(List( - Apply(Select(Ident(newTermName("result")), newTermName("failure")), List(Ident(newTermName("t"))))), - Literal(Constant(()))))), EmptyTree)) - - val methodBody = reify { - val result = Promise[T]() - var state = 0 - - /* + vprintln(s"states of current method (${asyncBlockBuilder.asyncStates}):") + asyncBlockBuilder.asyncStates foreach vprintln + + val handlerExpr = asyncBlockBuilder.mkHandlerExpr() + + vprintln(s"GENERATED handler expr:") + vprintln(handlerExpr) + + val handlerForLastState: c.Expr[PartialFunction[Int, Unit]] = { + val tree = Apply(Select(Ident("result"), c.universe.newTermName("success")), + List(asyncBlockBuilder.asyncStates.last.body)) + builder.mkHandler(asyncBlockBuilder.asyncStates.last.state, c.Expr[Unit](tree)) + } + + vprintln("GENERATED handler for last state:") + vprintln(handlerForLastState) + + val localVarTrees = asyncBlockBuilder.asyncStates.init.flatMap(_.allVarDefs).toList + + val unitIdent = Ident(definitions.UnitClass) + + val resumeFunTree: c.Tree = DefDef(Modifiers(), newTermName("resume"), List(), List(List()), unitIdent, + Try(Apply(Select( + Apply(Select(handlerExpr.tree, newTermName("orElse")), List(handlerForLastState.tree)), + newTermName("apply")), List(Ident(newTermName("state")))), + List( + CaseDef( + Apply(Select(Select(Select(Ident(newTermName("scala")), newTermName("util")), newTermName("control")), newTermName("NonFatal")), List(Bind(newTermName("t"), Ident(nme.WILDCARD)))), + EmptyTree, + Block(List( + Apply(Select(Ident(newTermName("result")), newTermName("failure")), List(Ident(newTermName("t"))))), + Literal(Constant(()))))), EmptyTree)) + + val methodBody = reify { + val result = Promise[T]() + var state = 0 + + /* def resume(): Unit = { try { (handlerExpr.splice orElse handlerForLastState.splice)(state) @@ -77,24 +93,50 @@ object Async extends AsyncUtils { } resume() */ - - c.Expr(Block( - localVarTrees :+ resumeFunTree, - Apply(Ident(newTermName("resume")), List()) - )).splice - - result.future - } - //vprintln("ASYNC: Generated method body:") - //vprintln(c.universe.showRaw(methodBody)) - //vprintln(c.universe.show(methodBody)) - methodBody + c.Expr(Block( + localVarTrees :+ resumeFunTree, + Apply(Ident(newTermName("resume")), List()))).splice + + result.future + } - case _ => - // issue error message + //vprintln("ASYNC: Generated method body:") + //vprintln(c.universe.showRaw(methodBody)) + //vprintln(c.universe.show(methodBody)) + methodBody + + case _ => + // issue error message + reify { + sys.error("expression not supported by async") + } + } + } catch { + case _: FallbackToCpsException => + // replace `await` invocations with `awaitCps` invocations + val awaitReplacer = new Transformer { + val awaitCpsMethod = awaitCpsSym(c) + override def transform(tree: Tree): Tree = tree match { + case Apply(fun @ TypeApply(_, List(futArgTpt)), args) if fun.symbol == awaitMethod => + val typeApp = treeCopy.TypeApply(fun, Ident(awaitCpsMethod), List(TypeTree(futArgTpt.tpe), TypeTree(body.tree.tpe))) + treeCopy.Apply(tree, typeApp, args.map(arg => c.resetAllAttrs(arg.duplicate)) :+ Ident(newTermName("p"))) + + case _ => + super.transform(tree) + } + } + + val newBody = awaitReplacer.transform(body.tree) + reify { - sys.error("expression not supported by async") + val p = Promise[T]() + future { + reset { + c.Expr(c.resetAllAttrs(newBody.duplicate)).asInstanceOf[c.Expr[T]].splice + } + } + p.future } } } diff --git a/src/async/library/scala/async/AsyncUtils.scala b/src/async/library/scala/async/AsyncUtils.scala index 820541b..adc8c87 100644 --- a/src/async/library/scala/async/AsyncUtils.scala +++ b/src/async/library/scala/async/AsyncUtils.scala @@ -22,5 +22,11 @@ trait AsyncUtils { val tpe = asyncMod.moduleClass.asType.toType tpe.member(c.universe.newTermName("await")) } + + protected def awaitCpsSym(c: Context): c.universe.Symbol = { + val asyncMod = c.mirror.staticModule("scala.async.Async") + val tpe = asyncMod.moduleClass.asType.toType + tpe.member(c.universe.newTermName("awaitCps")) + } } diff --git a/src/async/library/scala/async/ExprBuilder.scala b/src/async/library/scala/async/ExprBuilder.scala index 776cc7b..4d068b5 100644 --- a/src/async/library/scala/async/ExprBuilder.scala +++ b/src/async/library/scala/async/ExprBuilder.scala @@ -383,6 +383,12 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { private var remainingBudget = budget + /* Fall back to CPS plug-in if tree contains an `await` call. */ + def checkForUnsupportedAwait(tree: c.Tree) = if (tree exists { + case Apply(fun, _) if fun.symbol == awaitMethod => true + case _ => false + }) throw new FallbackToCpsException + // populate asyncStates for (stat <- stats) stat match { // the val name = await(..) pattern @@ -396,11 +402,15 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { stateBuilder = new builder.AsyncStateBuilder(currState) case ValDef(mods, name, tpt, rhs) => + checkForUnsupportedAwait(rhs) + stateBuilder.addVarDef(mods, name, tpt) stateBuilder += // instead of adding `stat` we add a simple assignment Assign(Ident(name), c.resetAllAttrs(rhs.duplicate)) case If(cond, thenp, elsep) => + checkForUnsupportedAwait(cond) + val ifBudget: Int = remainingBudget / 2 remainingBudget -= ifBudget //TODO test if budget > 0 vprintln(s"ASYNC IF: ifBudget = $ifBudget") @@ -446,6 +456,7 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { stateBuilder = new builder.AsyncStateBuilder(currState) case _ => + checkForUnsupportedAwait(stat) stateBuilder += stat } // complete last state builder (representing the expressions after the last await) diff --git a/test/pending/run/fallback0/MinimalScalaTest.scala b/test/pending/run/fallback0/MinimalScalaTest.scala new file mode 100644 index 0000000..91de1fc --- /dev/null +++ b/test/pending/run/fallback0/MinimalScalaTest.scala @@ -0,0 +1,102 @@ +import language.reflectiveCalls +import language.postfixOps +import language.implicitConversions + +import scala.reflect.{ ClassTag, classTag } + +import scala.collection.mutable +import scala.concurrent.{ Future, Awaitable, CanAwait } +import java.util.concurrent.{ TimeoutException, CountDownLatch, TimeUnit } +import scala.concurrent.duration.Duration + + + +trait Output { + val buffer = new StringBuilder + + def bufferPrintln(a: Any): Unit = buffer.synchronized { + buffer.append(a.toString + "\n") + } +} + + +trait MinimalScalaTest extends Output { + + val throwables = mutable.ArrayBuffer[Throwable]() + + def check() { + if (throwables.nonEmpty) println(buffer.toString) + } + + implicit def stringops(s: String) = new { + + def should[U](snippets: =>U): U = { + bufferPrintln(s + " should:") + snippets + } + + def in[U](snippet: =>U): Unit = { + try { + bufferPrintln("- " + s) + snippet + bufferPrintln("[OK] Test passed.") + } catch { + case e: Throwable => + bufferPrintln("[FAILED] " + e) + bufferPrintln(e.getStackTrace().mkString("\n")) + throwables += e + } + } + + } + + implicit def objectops(obj: Any) = new { + + def mustBe(other: Any) = assert(obj == other, obj + " is not " + other) + def mustEqual(other: Any) = mustBe(other) + + } + + def intercept[T <: Throwable: ClassTag](body: =>Any): T = { + try { + body + throw new Exception("Exception of type %s was not thrown".format(classTag[T])) + } catch { + case t: Throwable => + if (classTag[T].runtimeClass != t.getClass) throw t + else t.asInstanceOf[T] + } + } + + def checkType[T: ClassTag, S](in: Future[T], refclasstag: ClassTag[S]): Boolean = classTag[T] == refclasstag +} + + +object TestLatch { + val DefaultTimeout = Duration(5, TimeUnit.SECONDS) + + def apply(count: Int = 1) = new TestLatch(count) +} + + +class TestLatch(count: Int = 1) extends Awaitable[Unit] { + private var latch = new CountDownLatch(count) + + def countDown() = latch.countDown() + def isOpen: Boolean = latch.getCount == 0 + def open() = while (!isOpen) countDown() + def reset() = latch = new CountDownLatch(count) + + @throws(classOf[TimeoutException]) + def ready(atMost: Duration)(implicit permit: CanAwait) = { + val opened = latch.await(atMost.toNanos, TimeUnit.NANOSECONDS) + if (!opened) throw new TimeoutException("Timeout of %s." format (atMost.toString)) + this + } + + @throws(classOf[Exception]) + def result(atMost: Duration)(implicit permit: CanAwait): Unit = { + ready(atMost) + } + +} diff --git a/test/pending/run/fallback0/fallback0-manual.scala b/test/pending/run/fallback0/fallback0-manual.scala new file mode 100644 index 0000000..611d09d --- /dev/null +++ b/test/pending/run/fallback0/fallback0-manual.scala @@ -0,0 +1,72 @@ +/** + * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> + */ + +import language.{ reflectiveCalls, postfixOps } +import scala.concurrent.{ Future, ExecutionContext, future, Await, Promise } +import scala.concurrent.duration._ +import scala.async.EndTaskException +import scala.async.Async.{ async, await, awaitCps } +import scala.util.continuations.reset + +object TestManual extends App { + + Fallback0ManualSpec.check() + +} + +class TestFallback0ManualClass { + import ExecutionContext.Implicits.global + + def m1(x: Int): Future[Int] = future { + Thread.sleep(1000) + x + 2 + } + + def m2(y: Int): Future[Int] = { + val p = Promise[Int]() + future { reset { + val f = m1(y) + var z = 0 + val res = awaitCps(f, p) + 5 + if (res > 0) { + z = 2 + } else { + z = 4 + } + z + } } + p.future + } + + /* that isn't even supported by current CPS plugin + def m3(y: Int): Future[Int] = { + val p = Promise[Int]() + future { reset { + val f = m1(y) + var z = 0 + val res: Option[Int] = Some(5) + res match { + case None => z = 4 + case Some(a) => z = awaitCps(f, p) - 10 + } + z + } } + p.future + } + */ +} + + +object Fallback0ManualSpec extends MinimalScalaTest { + + "An async method" should { + "support await in a simple if-else expression" in { + val o = new TestFallback0ManualClass + val fut = o.m2(10) + val res = Await.result(fut, 2 seconds) + res mustBe(2) + } + } + +} diff --git a/test/pending/run/fallback0/fallback0.scala b/test/pending/run/fallback0/fallback0.scala new file mode 100644 index 0000000..75b0739 --- /dev/null +++ b/test/pending/run/fallback0/fallback0.scala @@ -0,0 +1,49 @@ +/** + * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> + */ + +import language.{ reflectiveCalls, postfixOps } +import scala.concurrent.{ Future, ExecutionContext, future, Await } +import scala.concurrent.duration._ +import scala.async.Async.{ async, await, awaitCps } + +object Test extends App { + + Fallback0Spec.check() + +} + +class TestFallback0Class { + import ExecutionContext.Implicits.global + + def m1(x: Int): Future[Int] = future { + Thread.sleep(1000) + x + 2 + } + + def m2(y: Int): Future[Int] = async { + val f = m1(y) + var z = 0 + val res = await(f) + 5 + if (res > 0) { + z = 2 + } else { + z = 4 + } + z + } +} + + +object Fallback0Spec extends MinimalScalaTest { + + "An async method" should { + "support await in a simple if-else expression" in { + val o = new TestFallback0Class + val fut = o.m2(10) + val res = Await.result(fut, 2 seconds) + res mustBe(2) + } + } + +} |