diff options
author | Philipp Haller <hallerp@gmail.com> | 2012-11-27 03:07:11 -0800 |
---|---|---|
committer | Philipp Haller <hallerp@gmail.com> | 2012-11-27 03:07:11 -0800 |
commit | 653a46dc7244aa06f3282150ca4e9694e14a2948 (patch) | |
tree | a4f4a2b8a6848577bae9f82805ffb61fe9b25e00 | |
parent | 456fd6e561a52f34040d9af041cc2b74880e5579 (diff) | |
parent | 7c93a9e0e288b55027646016913c7368732d54e4 (diff) | |
download | scala-async-653a46dc7244aa06f3282150ca4e9694e14a2948.tar.gz scala-async-653a46dc7244aa06f3282150ca4e9694e14a2948.tar.bz2 scala-async-653a46dc7244aa06f3282150ca4e9694e14a2948.zip |
Merge pull request #45 from phaller/ticket/33-by-name-2
Ticket/33 by name 2
-rw-r--r-- | src/main/scala/scala/async/AnfTransform.scala | 48 | ||||
-rw-r--r-- | src/main/scala/scala/async/TransformUtils.scala | 9 | ||||
-rw-r--r-- | src/test/scala/scala/async/TreeInterrogation.scala | 20 | ||||
-rw-r--r-- | src/test/scala/scala/async/run/anf/AnfTransformSpec.scala | 86 |
4 files changed, 139 insertions, 24 deletions
diff --git a/src/main/scala/scala/async/AnfTransform.scala b/src/main/scala/scala/async/AnfTransform.scala index a2d21f6..2fa96c9 100644 --- a/src/main/scala/scala/async/AnfTransform.scala +++ b/src/main/scala/scala/async/AnfTransform.scala @@ -81,14 +81,14 @@ private[async] final case class AnfTransform[C <: Context](c: C) { treeCopy.ModuleDef(tree, mods, newName, transform(impl).asInstanceOf[Template]) case x => super.transform(x) } - case Ident(name) => + case Ident(name) => if (renamed(tree.symbol)) treeCopy.Ident(tree, tree.symbol.name) else tree - case Select(fun, name) => + case Select(fun, name) => if (renamed(tree.symbol)) { treeCopy.Select(tree, transform(fun), tree.symbol.name) } else super.transform(tree) - case _ => super.transform(tree) + case _ => super.transform(tree) } } } @@ -170,12 +170,12 @@ private[async] final case class AnfTransform[C <: Context](c: C) { vd.setPos(pos) vd } + } - private def defineVal(prefix: String, lhs: Tree, pos: Position): ValDef = { - val vd = ValDef(NoMods, name.fresh(prefix), TypeTree(), lhs) - vd.setPos(pos) - vd - } + private def defineVal(prefix: String, lhs: Tree, pos: Position): ValDef = { + val vd = ValDef(NoMods, name.fresh(prefix), TypeTree(), lhs) + vd.setPos(pos) + vd } private object anf { @@ -187,16 +187,26 @@ private[async] final case class AnfTransform[C <: Context](c: C) { val stats :+ expr = inline.transformToList(qual) stats :+ attachCopy(tree)(Select(expr, sel).setSymbol(tree.symbol)) - case Apply(fun, args) if containsAwait => + case Apply(fun, args) if containsAwait => + checkForAwaitInNonPrimaryParamSection(fun, args) + // we an assume that no await call appears in a by-name argument position, // this has already been checked. - + val isByName: (Int) => Boolean = utils.isByName(fun) val funStats :+ simpleFun = inline.transformToList(fun) - val argLists = args map inline.transformToList + def isAwaitRef(name: Name) = name.toString.startsWith(utils.name.await + "$") + val argLists: List[List[Tree]] = args.zipWithIndex map { + case (arg, i) if isByName(i) || isSafeToInline(arg) => List(arg) + case (arg@Ident(name), _) if isAwaitRef(name) => List(arg) // not typed, so it eludes the check in `isSafeToInline` + case (arg, i) => inline.transformToList(arg) match { + case stats :+ expr => + val valDef = defineVal(name.arg(i), expr, arg.pos) + stats ::: List(valDef, Ident(valDef.name)) + } + } val allArgStats = argLists flatMap (_.init) val simpleArgs = argLists map (_.last) funStats ++ allArgStats :+ attachCopy(tree)(Apply(simpleFun, simpleArgs).setSymbol(tree.symbol)) - case Block(stats, expr) if containsAwait => inline.transformToList(stats :+ expr) @@ -259,4 +269,18 @@ private[async] final case class AnfTransform[C <: Context](c: C) { } } + def checkForAwaitInNonPrimaryParamSection(fun: Tree, args: List[Tree]) { + // TODO treat the Apply(Apply(.., argsN), ...), args0) holistically, and rewrite + // *all* argument lists in the correct order to preserve semantics. + fun match { + case Apply(fun1, _) => + fun1.tpe match { + case MethodType(_, resultType: MethodType) if resultType =:= fun.tpe => + c.error(fun.pos, "implementation restriction: await may only be used in the first parameter list.") + case _ => + } + case _ => + } + + } } diff --git a/src/main/scala/scala/async/TransformUtils.scala b/src/main/scala/scala/async/TransformUtils.scala index 7571f88..23f39d2 100644 --- a/src/main/scala/scala/async/TransformUtils.scala +++ b/src/main/scala/scala/async/TransformUtils.scala @@ -30,6 +30,7 @@ private[async] final case class TransformUtils[C <: Context](c: C) { val ifRes = "ifres" val await = "await" val bindSuffix = "$bind" + def arg(i: Int) = "arg" + i def fresh(name: TermName): TermName = newTermName(fresh(name.toString)) @@ -244,4 +245,12 @@ private[async] final case class TransformUtils[C <: Context](c: C) { } } + def isSafeToInline(tree: Tree) = { + val symtab = c.universe.asInstanceOf[scala.reflect.internal.SymbolTable] + object treeInfo extends { + val global: symtab.type = symtab + } with reflect.internal.TreeInfo + val castTree = tree.asInstanceOf[symtab.Tree] + treeInfo.isExprSafeToInline(castTree) + } } diff --git a/src/test/scala/scala/async/TreeInterrogation.scala b/src/test/scala/scala/async/TreeInterrogation.scala index ecb1bca..b22faa9 100644 --- a/src/test/scala/scala/async/TreeInterrogation.scala +++ b/src/test/scala/scala/async/TreeInterrogation.scala @@ -70,17 +70,15 @@ object TreeInterrogation extends App { val cm = reflect.runtime.currentMirror val tb = mkToolbox("-cp target/scala-2.10/classes -Xprint:all") val tree = tb.parse( - """ - | import scala.async.Async.{async, await} - | import scala.concurrent.{future, ExecutionContext, Await} - | import ExecutionContext.Implicits._ - | import scala.concurrent.duration._ - | - | try { - | val f = async { throw new Exception("problem") } - | Await.result(f, 1.second) - | } catch { - | case ex: Exception if ex.getMessage == "problem" => // okay + """ import scala.async.AsyncId.{async, await} + | def foo(a: Int, b: Int) = (a, b) + | val result = async { + | var i = 0 + | def next() = { + | i += 1; + | i + | } + | foo(next(), await(next())) | } | () | """.stripMargin) diff --git a/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala index 6dd4db7..529386b 100644 --- a/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala +++ b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala @@ -163,7 +163,7 @@ class AnfTransformSpec { val result = AsyncId.async { val x = "" match { case "" if false => AsyncId.await(1) + 1 - case _ => 2 + AsyncId.await(1) + case _ => 2 + AsyncId.await(1) } val y = x "" match { @@ -205,4 +205,88 @@ class AnfTransformSpec { } result mustBe (true) } + + @Test + def byNameExpressionsArentLifted() { + import _root_.scala.async.AsyncId.{async, await} + def foo(ignored: => Any, b: Int) = b + val result = async { + foo(???, await(1)) + } + result mustBe (1) + } + + @Test + def evaluationOrderRespected() { + import scala.async.AsyncId.{async, await} + def foo(a: Int, b: Int) = (a, b) + val result = async { + var i = 0 + def next() = { + i += 1; + i + } + foo(next(), await(next())) + } + result mustBe ((1, 2)) + } + + @Test + def awaitNotAllowedInNonPrimaryParamSection1() { + expectError("implementation restriction: await may only be used in the first parameter list.") { + """ + | import _root_.scala.async.AsyncId.{async, await} + | def foo(primary: Any)(i: Int) = i + | async { + | foo(???)(await(0)) + | } + """.stripMargin + } + } + + @Test + def awaitNotAllowedInNonPrimaryParamSection2() { + expectError("implementation restriction: await may only be used in the first parameter list.") { + """ + | import _root_.scala.async.AsyncId.{async, await} + | def foo[T](primary: Any)(i: Int) = i + | async { + | foo[Int](???)(await(0)) + | } + """.stripMargin + } + } + + @Test + def namedArgumentsRespectEvaluationOrder() { + import scala.async.AsyncId.{async, await} + def foo(a: Int, b: Int) = (a, b) + val result = async { + var i = 0 + def next() = { + i += 1; + i + } + foo(b = next(), a = await(next())) + } + result mustBe ((2, 1)) + } + + @Test + def namedAndDefaultArgumentsRespectEvaluationOrder() { + import scala.async.AsyncId.{async, await} + var i = 0 + def next() = { + i += 1; + i + } + def foo(a: Int = next(), b: Int = next()) = (a, b) + async { + foo(b = await(next())) + } mustBe ((2, 1)) + i = 0 + async { + foo(a = await(next())) + } mustBe ((1, 2)) + } } |