diff options
-rw-r--r-- | src/main/scala/scala/async/AnfTransform.scala | 29 | ||||
-rw-r--r-- | src/test/scala/scala/async/run/anf/AnfTransformSpec.scala | 32 |
2 files changed, 52 insertions, 9 deletions
diff --git a/src/main/scala/scala/async/AnfTransform.scala b/src/main/scala/scala/async/AnfTransform.scala index 449ea7b..6b17b94 100644 --- a/src/main/scala/scala/async/AnfTransform.scala +++ b/src/main/scala/scala/async/AnfTransform.scala @@ -81,14 +81,14 @@ private[async] final case class AnfTransform[C <: Context](c: C) { treeCopy.ModuleDef(tree, mods, newName, transform(impl).asInstanceOf[Template]) case x => super.transform(x) } - case Ident(name) => + case Ident(name) => if (renamed(tree.symbol)) treeCopy.Ident(tree, tree.symbol.name) else tree - case Select(fun, name) => + case Select(fun, name) => if (renamed(tree.symbol)) { treeCopy.Select(tree, transform(fun), tree.symbol.name) } else super.transform(tree) - case _ => super.transform(tree) + case _ => super.transform(tree) } } } @@ -187,7 +187,9 @@ private[async] final case class AnfTransform[C <: Context](c: C) { val stats :+ expr = inline.transformToList(qual) stats :+ attachCopy(tree)(Select(expr, sel).setSymbol(tree.symbol)) - case Apply(fun, args) if containsAwait => + case Apply(fun, args) if containsAwait => + checkForAwaitInNonPrimaryParamSection(fun, args) + // we an assume that no await call appears in a by-name argument position, // this has already been checked. val isByName: (Int) => Boolean = utils.isByName(fun) @@ -198,13 +200,12 @@ private[async] final case class AnfTransform[C <: Context](c: C) { case stats :+ expr => val valDef = defineVal(s"arg$i", expr, arg.pos) stats ::: List(valDef, Ident(valDef.name)) - case xs => xs + case xs => xs } } - val allArgStats = argLists flatMap (_.init) + val allArgStats = argLists flatMap (_.init) val simpleArgs = argLists map (_.last) funStats ++ allArgStats :+ attachCopy(tree)(Apply(simpleFun, simpleArgs).setSymbol(tree.symbol)) - case Block(stats, expr) if containsAwait => inline.transformToList(stats :+ expr) @@ -267,4 +268,18 @@ private[async] final case class AnfTransform[C <: Context](c: C) { } } + def checkForAwaitInNonPrimaryParamSection(fun: Tree, args: List[Tree]) { + // TODO treat the Apply(Apply(.., argsN), ...), args0) holistically, and rewrite + // *all* argument lists in the correct order to preserve semantics. + fun match { + case Apply(fun1, _) => + fun1.tpe match { + case MethodType(_, resultType: MethodType) if resultType =:= fun.tpe => + c.error(fun.pos, "implementation restriction: await may only be used in the first parameter list.") + case _ => + } + case _ => + } + + } } diff --git a/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala index 595fa6c..f274068 100644 --- a/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala +++ b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala @@ -163,7 +163,7 @@ class AnfTransformSpec { val result = AsyncId.async { val x = "" match { case "" if false => AsyncId.await(1) + 1 - case _ => 2 + AsyncId.await(1) + case _ => 2 + AsyncId.await(1) } val y = x "" match { @@ -222,9 +222,37 @@ class AnfTransformSpec { def foo(a: Int, b: Int) = (a, b) val result = async { var i = 0 - def next() = {i += 1; i} + def next() = { + i += 1; i + } foo(next(), await(next())) } result mustBe ((1, 2)) } + + @Test + def awaitNotAllowedInNonPrimaryParamSection1() { + expectError("implementation restriction: await may only be used in the first parameter list.") { + """ + | import _root_.scala.async.AsyncId.{async, await} + | def foo(primary: Any)(i: Int) = i + | async { + | foo(???)(await(0)) + | } + """.stripMargin + } + } + + @Test + def awaitNotAllowedInNonPrimaryParamSection2() { + expectError("implementation restriction: await may only be used in the first parameter list.") { + """ + | import _root_.scala.async.AsyncId.{async, await} + | def foo[T](primary: Any)(i: Int) = i + | async { + | foo[Int](???)(await(0)) + | } + """.stripMargin + } + } } |