aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xbuild.sh2
-rw-r--r--src/async/library/scala/async/Async.scala156
-rw-r--r--src/async/library/scala/async/AsyncUtils.scala6
-rw-r--r--src/async/library/scala/async/ExprBuilder.scala11
-rw-r--r--test/pending/run/fallback0/MinimalScalaTest.scala102
-rw-r--r--test/pending/run/fallback0/fallback0-manual.scala72
-rw-r--r--test/pending/run/fallback0/fallback0.scala49
7 files changed, 340 insertions, 58 deletions
diff --git a/build.sh b/build.sh
index e543143..90d590e 100755
--- a/build.sh
+++ b/build.sh
@@ -1,4 +1,4 @@
#!/bin/bash
scalac -version
mkdir -p classes
-scalac -d classes -deprecation -feature src/async/library/scala/async/*.scala
+scalac -P:continuations:enable -d classes -deprecation -feature src/async/library/scala/async/*.scala
diff --git a/src/async/library/scala/async/Async.scala b/src/async/library/scala/async/Async.scala
index d3e0904..1d9b2d9 100644
--- a/src/async/library/scala/async/Async.scala
+++ b/src/async/library/scala/async/Async.scala
@@ -7,8 +7,13 @@ import language.experimental.macros
import scala.reflect.macros.Context
import scala.collection.mutable.ListBuffer
-import scala.concurrent.{ Future, Promise }
+import scala.concurrent.{ Future, Promise, ExecutionContext, future }
+import ExecutionContext.Implicits.global
import scala.util.control.NonFatal
+import scala.util.continuations.{ shift, reset, cpsParam }
+
+/* Extending `ControlThrowable`, by default, also avoids filling in the stack trace. */
+class FallbackToCpsException extends scala.util.control.ControlThrowable
/*
* @author Philipp Haller
@@ -19,6 +24,16 @@ object Async extends AsyncUtils {
def await[T](awaitable: Future[T]): T = ???
+ /* Fall back for `await` when it is called at an unsupported position.
+ */
+ def awaitCps[T, U](awaitable: Future[T], p: Promise[U]): T @cpsParam[U, Unit] =
+ shift {
+ (k: (T => U)) =>
+ awaitable onComplete {
+ case tr => p.success(k(tr.get))
+ }
+ }
+
def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[Future[T]] = {
import c.universe._
import Flag._
@@ -26,48 +41,49 @@ object Async extends AsyncUtils {
val builder = new ExprBuilder[c.type](c)
val awaitMethod = awaitSym(c)
- body.tree match {
- case Block(stats, expr) =>
- val asyncBlockBuilder = new builder.AsyncBlockBuilder(stats, expr, 0, 1000, 1000)
-
- vprintln(s"states of current method (${ asyncBlockBuilder.asyncStates }):")
- asyncBlockBuilder.asyncStates foreach vprintln
+ try {
+ body.tree match {
+ case Block(stats, expr) =>
+ val asyncBlockBuilder = new builder.AsyncBlockBuilder(stats, expr, 0, 1000, 1000)
- val handlerExpr = asyncBlockBuilder.mkHandlerExpr()
-
- vprintln(s"GENERATED handler expr:")
- vprintln(handlerExpr)
-
- val handlerForLastState: c.Expr[PartialFunction[Int, Unit]] = {
- val tree = Apply(Select(Ident("result"), c.universe.newTermName("success")),
- List(asyncBlockBuilder.asyncStates.last.body))
- builder.mkHandler(asyncBlockBuilder.asyncStates.last.state, c.Expr[Unit](tree))
- }
-
- vprintln("GENERATED handler for last state:")
- vprintln(handlerForLastState)
-
- val localVarTrees = asyncBlockBuilder.asyncStates.init.flatMap(_.allVarDefs).toList
-
- val unitIdent = Ident(definitions.UnitClass)
-
- val resumeFunTree: c.Tree = DefDef(Modifiers(), newTermName("resume"), List(), List(List()), unitIdent,
- Try(Apply(Select(
- Apply(Select(handlerExpr.tree, newTermName("orElse")), List(handlerForLastState.tree)),
- newTermName("apply")), List(Ident(newTermName("state")))),
- List(
- CaseDef(
- Apply(Select(Select(Select(Ident(newTermName("scala")), newTermName("util")), newTermName("control")), newTermName("NonFatal")), List(Bind(newTermName("t"), Ident(nme.WILDCARD)))),
- EmptyTree,
- Block(List(
- Apply(Select(Ident(newTermName("result")), newTermName("failure")), List(Ident(newTermName("t"))))),
- Literal(Constant(()))))), EmptyTree))
-
- val methodBody = reify {
- val result = Promise[T]()
- var state = 0
-
- /*
+ vprintln(s"states of current method (${asyncBlockBuilder.asyncStates}):")
+ asyncBlockBuilder.asyncStates foreach vprintln
+
+ val handlerExpr = asyncBlockBuilder.mkHandlerExpr()
+
+ vprintln(s"GENERATED handler expr:")
+ vprintln(handlerExpr)
+
+ val handlerForLastState: c.Expr[PartialFunction[Int, Unit]] = {
+ val tree = Apply(Select(Ident("result"), c.universe.newTermName("success")),
+ List(asyncBlockBuilder.asyncStates.last.body))
+ builder.mkHandler(asyncBlockBuilder.asyncStates.last.state, c.Expr[Unit](tree))
+ }
+
+ vprintln("GENERATED handler for last state:")
+ vprintln(handlerForLastState)
+
+ val localVarTrees = asyncBlockBuilder.asyncStates.init.flatMap(_.allVarDefs).toList
+
+ val unitIdent = Ident(definitions.UnitClass)
+
+ val resumeFunTree: c.Tree = DefDef(Modifiers(), newTermName("resume"), List(), List(List()), unitIdent,
+ Try(Apply(Select(
+ Apply(Select(handlerExpr.tree, newTermName("orElse")), List(handlerForLastState.tree)),
+ newTermName("apply")), List(Ident(newTermName("state")))),
+ List(
+ CaseDef(
+ Apply(Select(Select(Select(Ident(newTermName("scala")), newTermName("util")), newTermName("control")), newTermName("NonFatal")), List(Bind(newTermName("t"), Ident(nme.WILDCARD)))),
+ EmptyTree,
+ Block(List(
+ Apply(Select(Ident(newTermName("result")), newTermName("failure")), List(Ident(newTermName("t"))))),
+ Literal(Constant(()))))), EmptyTree))
+
+ val methodBody = reify {
+ val result = Promise[T]()
+ var state = 0
+
+ /*
def resume(): Unit = {
try {
(handlerExpr.splice orElse handlerForLastState.splice)(state)
@@ -77,24 +93,50 @@ object Async extends AsyncUtils {
}
resume()
*/
-
- c.Expr(Block(
- localVarTrees :+ resumeFunTree,
- Apply(Ident(newTermName("resume")), List())
- )).splice
-
- result.future
- }
- //vprintln("ASYNC: Generated method body:")
- //vprintln(c.universe.showRaw(methodBody))
- //vprintln(c.universe.show(methodBody))
- methodBody
+ c.Expr(Block(
+ localVarTrees :+ resumeFunTree,
+ Apply(Ident(newTermName("resume")), List()))).splice
+
+ result.future
+ }
- case _ =>
- // issue error message
+ //vprintln("ASYNC: Generated method body:")
+ //vprintln(c.universe.showRaw(methodBody))
+ //vprintln(c.universe.show(methodBody))
+ methodBody
+
+ case _ =>
+ // issue error message
+ reify {
+ sys.error("expression not supported by async")
+ }
+ }
+ } catch {
+ case _: FallbackToCpsException =>
+ // replace `await` invocations with `awaitCps` invocations
+ val awaitReplacer = new Transformer {
+ val awaitCpsMethod = awaitCpsSym(c)
+ override def transform(tree: Tree): Tree = tree match {
+ case Apply(fun @ TypeApply(_, List(futArgTpt)), args) if fun.symbol == awaitMethod =>
+ val typeApp = treeCopy.TypeApply(fun, Ident(awaitCpsMethod), List(TypeTree(futArgTpt.tpe), TypeTree(body.tree.tpe)))
+ treeCopy.Apply(tree, typeApp, args.map(arg => c.resetAllAttrs(arg.duplicate)) :+ Ident(newTermName("p")))
+
+ case _ =>
+ super.transform(tree)
+ }
+ }
+
+ val newBody = awaitReplacer.transform(body.tree)
+
reify {
- sys.error("expression not supported by async")
+ val p = Promise[T]()
+ future {
+ reset {
+ c.Expr(c.resetAllAttrs(newBody.duplicate)).asInstanceOf[c.Expr[T]].splice
+ }
+ }
+ p.future
}
}
}
diff --git a/src/async/library/scala/async/AsyncUtils.scala b/src/async/library/scala/async/AsyncUtils.scala
index 820541b..adc8c87 100644
--- a/src/async/library/scala/async/AsyncUtils.scala
+++ b/src/async/library/scala/async/AsyncUtils.scala
@@ -22,5 +22,11 @@ trait AsyncUtils {
val tpe = asyncMod.moduleClass.asType.toType
tpe.member(c.universe.newTermName("await"))
}
+
+ protected def awaitCpsSym(c: Context): c.universe.Symbol = {
+ val asyncMod = c.mirror.staticModule("scala.async.Async")
+ val tpe = asyncMod.moduleClass.asType.toType
+ tpe.member(c.universe.newTermName("awaitCps"))
+ }
}
diff --git a/src/async/library/scala/async/ExprBuilder.scala b/src/async/library/scala/async/ExprBuilder.scala
index 776cc7b..4d068b5 100644
--- a/src/async/library/scala/async/ExprBuilder.scala
+++ b/src/async/library/scala/async/ExprBuilder.scala
@@ -383,6 +383,12 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
private var remainingBudget = budget
+ /* Fall back to CPS plug-in if tree contains an `await` call. */
+ def checkForUnsupportedAwait(tree: c.Tree) = if (tree exists {
+ case Apply(fun, _) if fun.symbol == awaitMethod => true
+ case _ => false
+ }) throw new FallbackToCpsException
+
// populate asyncStates
for (stat <- stats) stat match {
// the val name = await(..) pattern
@@ -396,11 +402,15 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
stateBuilder = new builder.AsyncStateBuilder(currState)
case ValDef(mods, name, tpt, rhs) =>
+ checkForUnsupportedAwait(rhs)
+
stateBuilder.addVarDef(mods, name, tpt)
stateBuilder += // instead of adding `stat` we add a simple assignment
Assign(Ident(name), c.resetAllAttrs(rhs.duplicate))
case If(cond, thenp, elsep) =>
+ checkForUnsupportedAwait(cond)
+
val ifBudget: Int = remainingBudget / 2
remainingBudget -= ifBudget //TODO test if budget > 0
vprintln(s"ASYNC IF: ifBudget = $ifBudget")
@@ -446,6 +456,7 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
stateBuilder = new builder.AsyncStateBuilder(currState)
case _ =>
+ checkForUnsupportedAwait(stat)
stateBuilder += stat
}
// complete last state builder (representing the expressions after the last await)
diff --git a/test/pending/run/fallback0/MinimalScalaTest.scala b/test/pending/run/fallback0/MinimalScalaTest.scala
new file mode 100644
index 0000000..91de1fc
--- /dev/null
+++ b/test/pending/run/fallback0/MinimalScalaTest.scala
@@ -0,0 +1,102 @@
+import language.reflectiveCalls
+import language.postfixOps
+import language.implicitConversions
+
+import scala.reflect.{ ClassTag, classTag }
+
+import scala.collection.mutable
+import scala.concurrent.{ Future, Awaitable, CanAwait }
+import java.util.concurrent.{ TimeoutException, CountDownLatch, TimeUnit }
+import scala.concurrent.duration.Duration
+
+
+
+trait Output {
+ val buffer = new StringBuilder
+
+ def bufferPrintln(a: Any): Unit = buffer.synchronized {
+ buffer.append(a.toString + "\n")
+ }
+}
+
+
+trait MinimalScalaTest extends Output {
+
+ val throwables = mutable.ArrayBuffer[Throwable]()
+
+ def check() {
+ if (throwables.nonEmpty) println(buffer.toString)
+ }
+
+ implicit def stringops(s: String) = new {
+
+ def should[U](snippets: =>U): U = {
+ bufferPrintln(s + " should:")
+ snippets
+ }
+
+ def in[U](snippet: =>U): Unit = {
+ try {
+ bufferPrintln("- " + s)
+ snippet
+ bufferPrintln("[OK] Test passed.")
+ } catch {
+ case e: Throwable =>
+ bufferPrintln("[FAILED] " + e)
+ bufferPrintln(e.getStackTrace().mkString("\n"))
+ throwables += e
+ }
+ }
+
+ }
+
+ implicit def objectops(obj: Any) = new {
+
+ def mustBe(other: Any) = assert(obj == other, obj + " is not " + other)
+ def mustEqual(other: Any) = mustBe(other)
+
+ }
+
+ def intercept[T <: Throwable: ClassTag](body: =>Any): T = {
+ try {
+ body
+ throw new Exception("Exception of type %s was not thrown".format(classTag[T]))
+ } catch {
+ case t: Throwable =>
+ if (classTag[T].runtimeClass != t.getClass) throw t
+ else t.asInstanceOf[T]
+ }
+ }
+
+ def checkType[T: ClassTag, S](in: Future[T], refclasstag: ClassTag[S]): Boolean = classTag[T] == refclasstag
+}
+
+
+object TestLatch {
+ val DefaultTimeout = Duration(5, TimeUnit.SECONDS)
+
+ def apply(count: Int = 1) = new TestLatch(count)
+}
+
+
+class TestLatch(count: Int = 1) extends Awaitable[Unit] {
+ private var latch = new CountDownLatch(count)
+
+ def countDown() = latch.countDown()
+ def isOpen: Boolean = latch.getCount == 0
+ def open() = while (!isOpen) countDown()
+ def reset() = latch = new CountDownLatch(count)
+
+ @throws(classOf[TimeoutException])
+ def ready(atMost: Duration)(implicit permit: CanAwait) = {
+ val opened = latch.await(atMost.toNanos, TimeUnit.NANOSECONDS)
+ if (!opened) throw new TimeoutException("Timeout of %s." format (atMost.toString))
+ this
+ }
+
+ @throws(classOf[Exception])
+ def result(atMost: Duration)(implicit permit: CanAwait): Unit = {
+ ready(atMost)
+ }
+
+}
diff --git a/test/pending/run/fallback0/fallback0-manual.scala b/test/pending/run/fallback0/fallback0-manual.scala
new file mode 100644
index 0000000..611d09d
--- /dev/null
+++ b/test/pending/run/fallback0/fallback0-manual.scala
@@ -0,0 +1,72 @@
+/**
+ * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
+ */
+
+import language.{ reflectiveCalls, postfixOps }
+import scala.concurrent.{ Future, ExecutionContext, future, Await, Promise }
+import scala.concurrent.duration._
+import scala.async.EndTaskException
+import scala.async.Async.{ async, await, awaitCps }
+import scala.util.continuations.reset
+
+object TestManual extends App {
+
+ Fallback0ManualSpec.check()
+
+}
+
+class TestFallback0ManualClass {
+ import ExecutionContext.Implicits.global
+
+ def m1(x: Int): Future[Int] = future {
+ Thread.sleep(1000)
+ x + 2
+ }
+
+ def m2(y: Int): Future[Int] = {
+ val p = Promise[Int]()
+ future { reset {
+ val f = m1(y)
+ var z = 0
+ val res = awaitCps(f, p) + 5
+ if (res > 0) {
+ z = 2
+ } else {
+ z = 4
+ }
+ z
+ } }
+ p.future
+ }
+
+ /* that isn't even supported by current CPS plugin
+ def m3(y: Int): Future[Int] = {
+ val p = Promise[Int]()
+ future { reset {
+ val f = m1(y)
+ var z = 0
+ val res: Option[Int] = Some(5)
+ res match {
+ case None => z = 4
+ case Some(a) => z = awaitCps(f, p) - 10
+ }
+ z
+ } }
+ p.future
+ }
+ */
+}
+
+
+object Fallback0ManualSpec extends MinimalScalaTest {
+
+ "An async method" should {
+ "support await in a simple if-else expression" in {
+ val o = new TestFallback0ManualClass
+ val fut = o.m2(10)
+ val res = Await.result(fut, 2 seconds)
+ res mustBe(2)
+ }
+ }
+
+}
diff --git a/test/pending/run/fallback0/fallback0.scala b/test/pending/run/fallback0/fallback0.scala
new file mode 100644
index 0000000..75b0739
--- /dev/null
+++ b/test/pending/run/fallback0/fallback0.scala
@@ -0,0 +1,49 @@
+/**
+ * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
+ */
+
+import language.{ reflectiveCalls, postfixOps }
+import scala.concurrent.{ Future, ExecutionContext, future, Await }
+import scala.concurrent.duration._
+import scala.async.Async.{ async, await, awaitCps }
+
+object Test extends App {
+
+ Fallback0Spec.check()
+
+}
+
+class TestFallback0Class {
+ import ExecutionContext.Implicits.global
+
+ def m1(x: Int): Future[Int] = future {
+ Thread.sleep(1000)
+ x + 2
+ }
+
+ def m2(y: Int): Future[Int] = async {
+ val f = m1(y)
+ var z = 0
+ val res = await(f) + 5
+ if (res > 0) {
+ z = 2
+ } else {
+ z = 4
+ }
+ z
+ }
+}
+
+
+object Fallback0Spec extends MinimalScalaTest {
+
+ "An async method" should {
+ "support await in a simple if-else expression" in {
+ val o = new TestFallback0Class
+ val fut = o.m2(10)
+ val res = Await.result(fut, 2 seconds)
+ res mustBe(2)
+ }
+ }
+
+}