aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPhilipp Haller <hallerp@gmail.com>2013-04-13 02:20:32 -0700
committerPhilipp Haller <hallerp@gmail.com>2013-04-13 02:20:32 -0700
commitc26ba9db06a88dcd384e7dd1f0450206eef6f064 (patch)
tree54f458e8e9061c8537a0972ca827a56cc0ce9548
parent1a8b72a0d1ee16ddcff637df57c2b22e2976b853 (diff)
parentffd7b9649cde61553f16c25f64c3072ecfa6c713 (diff)
downloadscala-async-c26ba9db06a88dcd384e7dd1f0450206eef6f064.tar.gz
scala-async-c26ba9db06a88dcd384e7dd1f0450206eef6f064.tar.bz2
scala-async-c26ba9db06a88dcd384e7dd1f0450206eef6f064.zip
Merge pull request #10 from phaller/topic/cps-dep
Remove CPS dependency from default async implementation
-rw-r--r--src/main/scala/scala/async/Async.scala37
-rw-r--r--src/main/scala/scala/async/AsyncWithCPSFallback.scala31
-rw-r--r--src/main/scala/scala/async/FutureSystem.scala8
-rw-r--r--src/main/scala/scala/async/TransformUtils.scala3
-rw-r--r--src/main/scala/scala/async/continuations/AsyncBaseWithCPSFallback.scala90
-rw-r--r--src/main/scala/scala/async/continuations/AsyncWithCPSFallback.scala20
-rw-r--r--src/main/scala/scala/async/continuations/CPSBasedAsync.scala21
-rw-r--r--src/main/scala/scala/async/continuations/CPSBasedAsyncBase.scala21
-rw-r--r--src/main/scala/scala/async/continuations/ScalaConcurrentCPSFallback.scala31
-rw-r--r--src/test/scala/scala/async/run/cps/CPSSpec.scala49
10 files changed, 244 insertions, 67 deletions
diff --git a/src/main/scala/scala/async/Async.scala b/src/main/scala/scala/async/Async.scala
index 8323ac5..5a71ebe 100644
--- a/src/main/scala/scala/async/Async.scala
+++ b/src/main/scala/scala/async/Async.scala
@@ -6,7 +6,6 @@ package scala.async
import scala.language.experimental.macros
import scala.reflect.macros.Context
-import scala.util.continuations.{cpsParam, reset}
object Async extends AsyncBase {
@@ -61,9 +60,7 @@ abstract class AsyncBase {
@deprecated("`await` must be enclosed in an `async` block", "0.1")
def await[T](awaitable: futureSystem.Fut[T]): T = ???
- def awaitFallback[T, U](awaitable: futureSystem.Fut[T], p: futureSystem.Prom[U]): T @cpsParam[U, Unit] = ???
-
- def fallbackEnabled = false
+ protected[async] def fallbackEnabled = false
def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[futureSystem.Fut[T]] = {
import c.universe._
@@ -72,7 +69,8 @@ abstract class AsyncBase {
val utils = TransformUtils[c.type](c)
import utils.{name, defn}
- if (!analyzer.reportUnsupportedAwaits(body.tree) || !fallbackEnabled) {
+ analyzer.reportUnsupportedAwaits(body.tree)
+
// Transform to A-normal form:
// - no await calls in qualifiers or arguments,
// - if/match only used in statement position.
@@ -162,35 +160,6 @@ abstract class AsyncBase {
AsyncUtils.vprintln(s"async state machine transform expands to:\n ${code.tree}")
code
- } else {
- // replace `await` invocations with `awaitFallback` invocations
- val awaitReplacer = new Transformer {
- override def transform(tree: Tree): Tree = tree match {
- case Apply(fun @ TypeApply(_, List(futArgTpt)), args) if fun.symbol == defn.Async_await =>
- val typeApp = treeCopy.TypeApply(fun, Ident(defn.Async_awaitFallback), List(TypeTree(futArgTpt.tpe), TypeTree(body.tree.tpe)))
- treeCopy.Apply(tree, typeApp, args.map(arg => c.resetAllAttrs(arg.duplicate)) :+ Ident(name.result))
- case _ =>
- super.transform(tree)
- }
- }
- val newBody = awaitReplacer.transform(body.tree)
-
- val resetBody = reify {
- reset { c.Expr(c.resetAllAttrs(newBody.duplicate)).splice }
- }
-
- val futureSystemOps = futureSystem.mkOps(c)
- val code = {
- val tree = Block(List(
- ValDef(NoMods, name.result, TypeTree(futureSystemOps.promType[T]), futureSystemOps.createProm[T].tree),
- futureSystemOps.spawn(resetBody.tree)
- ), futureSystemOps.promiseToFuture(c.Expr[futureSystem.Prom[T]](Ident(name.result))).tree)
- c.Expr[futureSystem.Fut[T]](tree)
- }
-
- AsyncUtils.vprintln(s"async CPS fallback transform expands to:\n ${code.tree}")
- code
- }
}
def logDiagnostics(c: Context)(anfTree: c.Tree, states: Seq[String]) {
diff --git a/src/main/scala/scala/async/AsyncWithCPSFallback.scala b/src/main/scala/scala/async/AsyncWithCPSFallback.scala
deleted file mode 100644
index 39e43a3..0000000
--- a/src/main/scala/scala/async/AsyncWithCPSFallback.scala
+++ /dev/null
@@ -1,31 +0,0 @@
-package scala.async
-
-import scala.language.experimental.macros
-
-import scala.reflect.macros.Context
-import scala.util.continuations._
-
-object AsyncWithCPSFallback extends AsyncBase {
-
- import scala.concurrent.{Future, ExecutionContext}
- import ExecutionContext.Implicits.global
-
- lazy val futureSystem = ScalaConcurrentFutureSystem
- type FS = ScalaConcurrentFutureSystem.type
-
- /* Fall-back for `await` when it is called at an unsupported position.
- */
- override def awaitFallback[T, U](awaitable: futureSystem.Fut[T], p: futureSystem.Prom[U]): T @cpsParam[U, Unit] =
- shift {
- (k: (T => U)) =>
- awaitable onComplete {
- case tr => p.success(k(tr.get))
- }
- }
-
- override def fallbackEnabled = true
-
- def async[T](body: T) = macro asyncImpl[T]
-
- override def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[Future[T]] = super.asyncImpl[T](c)(body)
-}
diff --git a/src/main/scala/scala/async/FutureSystem.scala b/src/main/scala/scala/async/FutureSystem.scala
index f0b4653..a050bec 100644
--- a/src/main/scala/scala/async/FutureSystem.scala
+++ b/src/main/scala/scala/async/FutureSystem.scala
@@ -54,6 +54,8 @@ trait FutureSystem {
def spawn(tree: context.Tree): context.Tree =
future(context.Expr[Unit](tree))(execContext).tree
+
+ def castTo[A: WeakTypeTag](future: Expr[Fut[Any]]): Expr[Fut[A]]
}
def mkOps(c: Context): Ops { val context: c.type }
@@ -101,6 +103,10 @@ object ScalaConcurrentFutureSystem extends FutureSystem {
prom.splice.complete(value.splice)
context.literalUnit.splice
}
+
+ def castTo[A: WeakTypeTag](future: Expr[Fut[Any]]): Expr[Fut[A]] = reify {
+ future.splice.asInstanceOf[Fut[A]]
+ }
}
}
@@ -145,5 +151,7 @@ object IdentityFutureSystem extends FutureSystem {
prom.splice.a = value.splice.get
context.literalUnit.splice
}
+
+ def castTo[A: WeakTypeTag](future: Expr[Fut[Any]]): Expr[Fut[A]] = ???
}
}
diff --git a/src/main/scala/scala/async/TransformUtils.scala b/src/main/scala/scala/async/TransformUtils.scala
index 38c33a4..090a334 100644
--- a/src/main/scala/scala/async/TransformUtils.scala
+++ b/src/main/scala/scala/async/TransformUtils.scala
@@ -181,8 +181,7 @@ private[async] final case class TransformUtils[C <: Context](c: C) {
tpe.member(newTermName(name)).ensuring(_ != NoSymbol)
}
- val Async_await = asyncMember("await")
- val Async_awaitFallback = asyncMember("awaitFallback")
+ val Async_await = asyncMember("await")
}
/** `termSym( (_: Foo).bar(null: A, null: B)` will return the symbol of `bar`, after overload resolution. */
diff --git a/src/main/scala/scala/async/continuations/AsyncBaseWithCPSFallback.scala b/src/main/scala/scala/async/continuations/AsyncBaseWithCPSFallback.scala
new file mode 100644
index 0000000..a669cfa
--- /dev/null
+++ b/src/main/scala/scala/async/continuations/AsyncBaseWithCPSFallback.scala
@@ -0,0 +1,90 @@
+/*
+ * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
+ */
+
+package scala.async
+package continuations
+
+import scala.language.experimental.macros
+
+import scala.reflect.macros.Context
+import scala.util.continuations._
+
+trait AsyncBaseWithCPSFallback extends AsyncBase {
+
+ /* Fall-back for `await` using CPS plugin.
+ *
+ * Note: This method is public, but is intended only for internal use.
+ */
+ def awaitFallback[T](awaitable: futureSystem.Fut[T]): T @cps[futureSystem.Fut[Any]]
+
+ override protected[async] def fallbackEnabled = true
+
+ /* Implements `async { ... }` using the CPS plugin.
+ */
+ protected def cpsBasedAsyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[futureSystem.Fut[T]] = {
+ import c.universe._
+
+ def lookupMember(name: String) = {
+ val asyncTrait = c.mirror.staticClass("scala.async.continuations.AsyncBaseWithCPSFallback")
+ val tpe = asyncTrait.asType.toType
+ tpe.member(newTermName(name)).ensuring(_ != NoSymbol)
+ }
+
+ AsyncUtils.vprintln("AsyncBaseWithCPSFallback.cpsBasedAsyncImpl")
+
+ val utils = TransformUtils[c.type](c)
+ val futureSystemOps = futureSystem.mkOps(c)
+ val awaitSym = utils.defn.Async_await
+ val awaitFallbackSym = lookupMember("awaitFallback")
+
+ // replace `await` invocations with `awaitFallback` invocations
+ val awaitReplacer = new Transformer {
+ override def transform(tree: Tree): Tree = tree match {
+ case Apply(fun @ TypeApply(_, List(futArgTpt)), args) if fun.symbol == awaitSym =>
+ val typeApp = treeCopy.TypeApply(fun, Ident(awaitFallbackSym), List(TypeTree(futArgTpt.tpe)))
+ treeCopy.Apply(tree, typeApp, args.map(arg => c.resetAllAttrs(arg.duplicate)))
+ case _ =>
+ super.transform(tree)
+ }
+ }
+ val bodyWithAwaitFallback = awaitReplacer.transform(body.tree)
+
+ /* generate an expression that looks like this:
+ reset {
+ val f = future { ... }
+ ...
+ val x = awaitFallback(f)
+ ...
+ future { expr }
+ }.asInstanceOf[Future[T]]
+ */
+
+ val bodyWithFuture = {
+ val tree = bodyWithAwaitFallback match {
+ case Block(stmts, expr) => Block(stmts, futureSystemOps.spawn(expr))
+ case expr => futureSystemOps.spawn(expr)
+ }
+ c.Expr[futureSystem.Fut[Any]](c.resetAllAttrs(tree.duplicate))
+ }
+
+ val bodyWithReset: c.Expr[futureSystem.Fut[Any]] = reify {
+ reset { bodyWithFuture.splice }
+ }
+ val bodyWithCast = futureSystemOps.castTo[T](bodyWithReset)
+
+ AsyncUtils.vprintln(s"CPS-based async transform expands to:\n${bodyWithCast.tree}")
+ bodyWithCast
+ }
+
+ override def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[futureSystem.Fut[T]] = {
+ AsyncUtils.vprintln("AsyncBaseWithCPSFallback.asyncImpl")
+
+ val analyzer = AsyncAnalysis[c.type](c, this)
+
+ if (!analyzer.reportUnsupportedAwaits(body.tree))
+ super.asyncImpl[T](c)(body) // no unsupported awaits
+ else
+ cpsBasedAsyncImpl[T](c)(body) // fallback to CPS
+ }
+}
diff --git a/src/main/scala/scala/async/continuations/AsyncWithCPSFallback.scala b/src/main/scala/scala/async/continuations/AsyncWithCPSFallback.scala
new file mode 100644
index 0000000..fe6e1a6
--- /dev/null
+++ b/src/main/scala/scala/async/continuations/AsyncWithCPSFallback.scala
@@ -0,0 +1,20 @@
+/*
+ * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
+ */
+
+package scala.async
+package continuations
+
+import scala.language.experimental.macros
+
+import scala.reflect.macros.Context
+import scala.concurrent.Future
+
+trait AsyncWithCPSFallback extends AsyncBaseWithCPSFallback with ScalaConcurrentCPSFallback
+
+object AsyncWithCPSFallback extends AsyncWithCPSFallback {
+
+ def async[T](body: T) = macro asyncImpl[T]
+
+ override def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[Future[T]] = super.asyncImpl[T](c)(body)
+}
diff --git a/src/main/scala/scala/async/continuations/CPSBasedAsync.scala b/src/main/scala/scala/async/continuations/CPSBasedAsync.scala
new file mode 100644
index 0000000..922d1ac
--- /dev/null
+++ b/src/main/scala/scala/async/continuations/CPSBasedAsync.scala
@@ -0,0 +1,21 @@
+/*
+ * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
+ */
+
+package scala.async
+package continuations
+
+import scala.language.experimental.macros
+
+import scala.reflect.macros.Context
+import scala.concurrent.Future
+
+trait CPSBasedAsync extends CPSBasedAsyncBase with ScalaConcurrentCPSFallback
+
+object CPSBasedAsync extends CPSBasedAsync {
+
+ def async[T](body: T) = macro asyncImpl[T]
+
+ override def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[Future[T]] = super.asyncImpl[T](c)(body)
+
+}
diff --git a/src/main/scala/scala/async/continuations/CPSBasedAsyncBase.scala b/src/main/scala/scala/async/continuations/CPSBasedAsyncBase.scala
new file mode 100644
index 0000000..4e8ec80
--- /dev/null
+++ b/src/main/scala/scala/async/continuations/CPSBasedAsyncBase.scala
@@ -0,0 +1,21 @@
+/*
+ * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
+ */
+
+package scala.async
+package continuations
+
+import scala.language.experimental.macros
+
+import scala.reflect.macros.Context
+import scala.util.continuations._
+
+/* Specializes `AsyncBaseWithCPSFallback` to always fall back to CPS, yielding a purely CPS-based
+ * implementation of async/await.
+ */
+trait CPSBasedAsyncBase extends AsyncBaseWithCPSFallback {
+
+ override def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[futureSystem.Fut[T]] =
+ super.cpsBasedAsyncImpl[T](c)(body)
+
+}
diff --git a/src/main/scala/scala/async/continuations/ScalaConcurrentCPSFallback.scala b/src/main/scala/scala/async/continuations/ScalaConcurrentCPSFallback.scala
new file mode 100644
index 0000000..018ad05
--- /dev/null
+++ b/src/main/scala/scala/async/continuations/ScalaConcurrentCPSFallback.scala
@@ -0,0 +1,31 @@
+/*
+ * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
+ */
+
+package scala.async
+package continuations
+
+import scala.util.continuations._
+import scala.concurrent.{Future, Promise, ExecutionContext}
+
+trait ScalaConcurrentCPSFallback {
+ self: AsyncBaseWithCPSFallback =>
+
+ import ExecutionContext.Implicits.global
+
+ lazy val futureSystem = ScalaConcurrentFutureSystem
+ type FS = ScalaConcurrentFutureSystem.type
+
+ /* Fall-back for `await` when it is called at an unsupported position.
+ */
+ override def awaitFallback[T](awaitable: futureSystem.Fut[T]): T @cps[Future[Any]] =
+ shift {
+ (k: (T => Future[Any])) =>
+ val fr = Promise[Any]()
+ awaitable onComplete {
+ case tr => fr completeWith k(tr.get)
+ }
+ fr.future
+ }
+
+}
diff --git a/src/test/scala/scala/async/run/cps/CPSSpec.scala b/src/test/scala/scala/async/run/cps/CPSSpec.scala
new file mode 100644
index 0000000..b56c6ad
--- /dev/null
+++ b/src/test/scala/scala/async/run/cps/CPSSpec.scala
@@ -0,0 +1,49 @@
+/*
+ * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
+ */
+
+package scala.async
+package run
+package cps
+
+import scala.concurrent.{Future, Promise, ExecutionContext, future, Await}
+import scala.concurrent.duration._
+import scala.async.continuations.CPSBasedAsync._
+import scala.util.continuations._
+
+import org.junit.runner.RunWith
+import org.junit.runners.JUnit4
+import org.junit.Test
+
+@RunWith(classOf[JUnit4])
+class CPSSpec {
+
+ import ExecutionContext.Implicits.global
+
+ def m1(y: Int): Future[Int] = async {
+ val f = future { y + 2 }
+ val f2 = future { y + 3 }
+ val x1 = await(f)
+ val x2 = await(f2)
+ x1 + x2
+ }
+
+ def m2(y: Int): Future[Int] = async {
+ val f = future { y + 2 }
+ val res = await(f)
+ if (y > 0) res + 2
+ else res - 2
+ }
+
+ @Test
+ def testCPSFallback() {
+ val fut1 = m1(10)
+ val res1 = Await.result(fut1, 2.seconds)
+ assert(res1 == 25, s"expected 25, got $res1")
+
+ val fut2 = m2(10)
+ val res2 = Await.result(fut2, 2.seconds)
+ assert(res2 == 14, s"expected 14, got $res2")
+ }
+
+}