aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPhilipp Haller <hallerp@gmail.com>2012-11-27 03:07:11 -0800
committerPhilipp Haller <hallerp@gmail.com>2012-11-27 03:07:11 -0800
commit653a46dc7244aa06f3282150ca4e9694e14a2948 (patch)
treea4f4a2b8a6848577bae9f82805ffb61fe9b25e00
parent456fd6e561a52f34040d9af041cc2b74880e5579 (diff)
parent7c93a9e0e288b55027646016913c7368732d54e4 (diff)
downloadscala-async-653a46dc7244aa06f3282150ca4e9694e14a2948.tar.gz
scala-async-653a46dc7244aa06f3282150ca4e9694e14a2948.tar.bz2
scala-async-653a46dc7244aa06f3282150ca4e9694e14a2948.zip
Merge pull request #45 from phaller/ticket/33-by-name-2
Ticket/33 by name 2
-rw-r--r--src/main/scala/scala/async/AnfTransform.scala48
-rw-r--r--src/main/scala/scala/async/TransformUtils.scala9
-rw-r--r--src/test/scala/scala/async/TreeInterrogation.scala20
-rw-r--r--src/test/scala/scala/async/run/anf/AnfTransformSpec.scala86
4 files changed, 139 insertions, 24 deletions
diff --git a/src/main/scala/scala/async/AnfTransform.scala b/src/main/scala/scala/async/AnfTransform.scala
index a2d21f6..2fa96c9 100644
--- a/src/main/scala/scala/async/AnfTransform.scala
+++ b/src/main/scala/scala/async/AnfTransform.scala
@@ -81,14 +81,14 @@ private[async] final case class AnfTransform[C <: Context](c: C) {
treeCopy.ModuleDef(tree, mods, newName, transform(impl).asInstanceOf[Template])
case x => super.transform(x)
}
- case Ident(name) =>
+ case Ident(name) =>
if (renamed(tree.symbol)) treeCopy.Ident(tree, tree.symbol.name)
else tree
- case Select(fun, name) =>
+ case Select(fun, name) =>
if (renamed(tree.symbol)) {
treeCopy.Select(tree, transform(fun), tree.symbol.name)
} else super.transform(tree)
- case _ => super.transform(tree)
+ case _ => super.transform(tree)
}
}
}
@@ -170,12 +170,12 @@ private[async] final case class AnfTransform[C <: Context](c: C) {
vd.setPos(pos)
vd
}
+ }
- private def defineVal(prefix: String, lhs: Tree, pos: Position): ValDef = {
- val vd = ValDef(NoMods, name.fresh(prefix), TypeTree(), lhs)
- vd.setPos(pos)
- vd
- }
+ private def defineVal(prefix: String, lhs: Tree, pos: Position): ValDef = {
+ val vd = ValDef(NoMods, name.fresh(prefix), TypeTree(), lhs)
+ vd.setPos(pos)
+ vd
}
private object anf {
@@ -187,16 +187,26 @@ private[async] final case class AnfTransform[C <: Context](c: C) {
val stats :+ expr = inline.transformToList(qual)
stats :+ attachCopy(tree)(Select(expr, sel).setSymbol(tree.symbol))
- case Apply(fun, args) if containsAwait =>
+ case Apply(fun, args) if containsAwait =>
+ checkForAwaitInNonPrimaryParamSection(fun, args)
+
// we an assume that no await call appears in a by-name argument position,
// this has already been checked.
-
+ val isByName: (Int) => Boolean = utils.isByName(fun)
val funStats :+ simpleFun = inline.transformToList(fun)
- val argLists = args map inline.transformToList
+ def isAwaitRef(name: Name) = name.toString.startsWith(utils.name.await + "$")
+ val argLists: List[List[Tree]] = args.zipWithIndex map {
+ case (arg, i) if isByName(i) || isSafeToInline(arg) => List(arg)
+ case (arg@Ident(name), _) if isAwaitRef(name) => List(arg) // not typed, so it eludes the check in `isSafeToInline`
+ case (arg, i) => inline.transformToList(arg) match {
+ case stats :+ expr =>
+ val valDef = defineVal(name.arg(i), expr, arg.pos)
+ stats ::: List(valDef, Ident(valDef.name))
+ }
+ }
val allArgStats = argLists flatMap (_.init)
val simpleArgs = argLists map (_.last)
funStats ++ allArgStats :+ attachCopy(tree)(Apply(simpleFun, simpleArgs).setSymbol(tree.symbol))
-
case Block(stats, expr) if containsAwait =>
inline.transformToList(stats :+ expr)
@@ -259,4 +269,18 @@ private[async] final case class AnfTransform[C <: Context](c: C) {
}
}
+ def checkForAwaitInNonPrimaryParamSection(fun: Tree, args: List[Tree]) {
+ // TODO treat the Apply(Apply(.., argsN), ...), args0) holistically, and rewrite
+ // *all* argument lists in the correct order to preserve semantics.
+ fun match {
+ case Apply(fun1, _) =>
+ fun1.tpe match {
+ case MethodType(_, resultType: MethodType) if resultType =:= fun.tpe =>
+ c.error(fun.pos, "implementation restriction: await may only be used in the first parameter list.")
+ case _ =>
+ }
+ case _ =>
+ }
+
+ }
}
diff --git a/src/main/scala/scala/async/TransformUtils.scala b/src/main/scala/scala/async/TransformUtils.scala
index 7571f88..23f39d2 100644
--- a/src/main/scala/scala/async/TransformUtils.scala
+++ b/src/main/scala/scala/async/TransformUtils.scala
@@ -30,6 +30,7 @@ private[async] final case class TransformUtils[C <: Context](c: C) {
val ifRes = "ifres"
val await = "await"
val bindSuffix = "$bind"
+ def arg(i: Int) = "arg" + i
def fresh(name: TermName): TermName = newTermName(fresh(name.toString))
@@ -244,4 +245,12 @@ private[async] final case class TransformUtils[C <: Context](c: C) {
}
}
+ def isSafeToInline(tree: Tree) = {
+ val symtab = c.universe.asInstanceOf[scala.reflect.internal.SymbolTable]
+ object treeInfo extends {
+ val global: symtab.type = symtab
+ } with reflect.internal.TreeInfo
+ val castTree = tree.asInstanceOf[symtab.Tree]
+ treeInfo.isExprSafeToInline(castTree)
+ }
}
diff --git a/src/test/scala/scala/async/TreeInterrogation.scala b/src/test/scala/scala/async/TreeInterrogation.scala
index ecb1bca..b22faa9 100644
--- a/src/test/scala/scala/async/TreeInterrogation.scala
+++ b/src/test/scala/scala/async/TreeInterrogation.scala
@@ -70,17 +70,15 @@ object TreeInterrogation extends App {
val cm = reflect.runtime.currentMirror
val tb = mkToolbox("-cp target/scala-2.10/classes -Xprint:all")
val tree = tb.parse(
- """
- | import scala.async.Async.{async, await}
- | import scala.concurrent.{future, ExecutionContext, Await}
- | import ExecutionContext.Implicits._
- | import scala.concurrent.duration._
- |
- | try {
- | val f = async { throw new Exception("problem") }
- | Await.result(f, 1.second)
- | } catch {
- | case ex: Exception if ex.getMessage == "problem" => // okay
+ """ import scala.async.AsyncId.{async, await}
+ | def foo(a: Int, b: Int) = (a, b)
+ | val result = async {
+ | var i = 0
+ | def next() = {
+ | i += 1;
+ | i
+ | }
+ | foo(next(), await(next()))
| }
| ()
| """.stripMargin)
diff --git a/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala
index 6dd4db7..529386b 100644
--- a/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala
+++ b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala
@@ -163,7 +163,7 @@ class AnfTransformSpec {
val result = AsyncId.async {
val x = "" match {
case "" if false => AsyncId.await(1) + 1
- case _ => 2 + AsyncId.await(1)
+ case _ => 2 + AsyncId.await(1)
}
val y = x
"" match {
@@ -205,4 +205,88 @@ class AnfTransformSpec {
}
result mustBe (true)
}
+
+ @Test
+ def byNameExpressionsArentLifted() {
+ import _root_.scala.async.AsyncId.{async, await}
+ def foo(ignored: => Any, b: Int) = b
+ val result = async {
+ foo(???, await(1))
+ }
+ result mustBe (1)
+ }
+
+ @Test
+ def evaluationOrderRespected() {
+ import scala.async.AsyncId.{async, await}
+ def foo(a: Int, b: Int) = (a, b)
+ val result = async {
+ var i = 0
+ def next() = {
+ i += 1;
+ i
+ }
+ foo(next(), await(next()))
+ }
+ result mustBe ((1, 2))
+ }
+
+ @Test
+ def awaitNotAllowedInNonPrimaryParamSection1() {
+ expectError("implementation restriction: await may only be used in the first parameter list.") {
+ """
+ | import _root_.scala.async.AsyncId.{async, await}
+ | def foo(primary: Any)(i: Int) = i
+ | async {
+ | foo(???)(await(0))
+ | }
+ """.stripMargin
+ }
+ }
+
+ @Test
+ def awaitNotAllowedInNonPrimaryParamSection2() {
+ expectError("implementation restriction: await may only be used in the first parameter list.") {
+ """
+ | import _root_.scala.async.AsyncId.{async, await}
+ | def foo[T](primary: Any)(i: Int) = i
+ | async {
+ | foo[Int](???)(await(0))
+ | }
+ """.stripMargin
+ }
+ }
+
+ @Test
+ def namedArgumentsRespectEvaluationOrder() {
+ import scala.async.AsyncId.{async, await}
+ def foo(a: Int, b: Int) = (a, b)
+ val result = async {
+ var i = 0
+ def next() = {
+ i += 1;
+ i
+ }
+ foo(b = next(), a = await(next()))
+ }
+ result mustBe ((2, 1))
+ }
+
+ @Test
+ def namedAndDefaultArgumentsRespectEvaluationOrder() {
+ import scala.async.AsyncId.{async, await}
+ var i = 0
+ def next() = {
+ i += 1;
+ i
+ }
+ def foo(a: Int = next(), b: Int = next()) = (a, b)
+ async {
+ foo(b = await(next()))
+ } mustBe ((2, 1))
+ i = 0
+ async {
+ foo(a = await(next()))
+ } mustBe ((1, 2))
+ }
}