From fe9a0023e685a2924cba10ec738e8babe9e7bd7b Mon Sep 17 00:00:00 2001 From: Jason Zaugg Date: Mon, 26 Nov 2012 18:20:00 +0100 Subject: Fix semantics of by-name application - If we lift one arg, we must lift them all. This preserves evaluation order. - But, never lift an by-name arg Addresses the first half of #33. --- src/main/scala/scala/async/AnfTransform.scala | 24 ++++++++++++++-------- .../scala/async/run/anf/AnfTransformSpec.scala | 22 ++++++++++++++++++++ 2 files changed, 38 insertions(+), 8 deletions(-) diff --git a/src/main/scala/scala/async/AnfTransform.scala b/src/main/scala/scala/async/AnfTransform.scala index a2d21f6..449ea7b 100644 --- a/src/main/scala/scala/async/AnfTransform.scala +++ b/src/main/scala/scala/async/AnfTransform.scala @@ -170,12 +170,12 @@ private[async] final case class AnfTransform[C <: Context](c: C) { vd.setPos(pos) vd } + } - private def defineVal(prefix: String, lhs: Tree, pos: Position): ValDef = { - val vd = ValDef(NoMods, name.fresh(prefix), TypeTree(), lhs) - vd.setPos(pos) - vd - } + private def defineVal(prefix: String, lhs: Tree, pos: Position): ValDef = { + val vd = ValDef(NoMods, name.fresh(prefix), TypeTree(), lhs) + vd.setPos(pos) + vd } private object anf { @@ -190,10 +190,18 @@ private[async] final case class AnfTransform[C <: Context](c: C) { case Apply(fun, args) if containsAwait => // we an assume that no await call appears in a by-name argument position, // this has already been checked. - + val isByName: (Int) => Boolean = utils.isByName(fun) val funStats :+ simpleFun = inline.transformToList(fun) - val argLists = args map inline.transformToList - val allArgStats = argLists flatMap (_.init) + val argLists: List[List[Tree]] = args.zipWithIndex map { + case (arg, i) if isByName(i) => List(arg) + case (arg, i) => inline.transformToList(arg) match { + case stats :+ expr => + val valDef = defineVal(s"arg$i", expr, arg.pos) + stats ::: List(valDef, Ident(valDef.name)) + case xs => xs + } + } + val allArgStats = argLists flatMap (_.init) val simpleArgs = argLists map (_.last) funStats ++ allArgStats :+ attachCopy(tree)(Apply(simpleFun, simpleArgs).setSymbol(tree.symbol)) diff --git a/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala index 6dd4db7..595fa6c 100644 --- a/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala +++ b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala @@ -205,4 +205,26 @@ class AnfTransformSpec { } result mustBe (true) } + + @Test + def byNameExpressionsArentLifted() { + import _root_.scala.async.AsyncId.{async, await} + def foo(ignored: => Any, b: Int) = b + val result = async { + foo(???, await(1)) + } + result mustBe (1) + } + + @Test + def evaluationOrderRespected() { + import scala.async.AsyncId.{async, await} + def foo(a: Int, b: Int) = (a, b) + val result = async { + var i = 0 + def next() = {i += 1; i} + foo(next(), await(next())) + } + result mustBe ((1, 2)) + } } -- cgit v1.2.3 From 9ad8783d39848d2c5dc5a2a73ac8d54c2859dd0e Mon Sep 17 00:00:00 2001 From: Jason Zaugg Date: Mon, 26 Nov 2012 18:26:11 +0100 Subject: Disallow await in non-primary param sections. We can allow it, but we need to treat nested Apply trees holistically, in order to lift out all the arguments and maintain the correct evaluation order. Fixes #33. --- src/main/scala/scala/async/AnfTransform.scala | 29 +++++++++++++++----- .../scala/async/run/anf/AnfTransformSpec.scala | 32 ++++++++++++++++++++-- 2 files changed, 52 insertions(+), 9 deletions(-) diff --git a/src/main/scala/scala/async/AnfTransform.scala b/src/main/scala/scala/async/AnfTransform.scala index 449ea7b..6b17b94 100644 --- a/src/main/scala/scala/async/AnfTransform.scala +++ b/src/main/scala/scala/async/AnfTransform.scala @@ -81,14 +81,14 @@ private[async] final case class AnfTransform[C <: Context](c: C) { treeCopy.ModuleDef(tree, mods, newName, transform(impl).asInstanceOf[Template]) case x => super.transform(x) } - case Ident(name) => + case Ident(name) => if (renamed(tree.symbol)) treeCopy.Ident(tree, tree.symbol.name) else tree - case Select(fun, name) => + case Select(fun, name) => if (renamed(tree.symbol)) { treeCopy.Select(tree, transform(fun), tree.symbol.name) } else super.transform(tree) - case _ => super.transform(tree) + case _ => super.transform(tree) } } } @@ -187,7 +187,9 @@ private[async] final case class AnfTransform[C <: Context](c: C) { val stats :+ expr = inline.transformToList(qual) stats :+ attachCopy(tree)(Select(expr, sel).setSymbol(tree.symbol)) - case Apply(fun, args) if containsAwait => + case Apply(fun, args) if containsAwait => + checkForAwaitInNonPrimaryParamSection(fun, args) + // we an assume that no await call appears in a by-name argument position, // this has already been checked. val isByName: (Int) => Boolean = utils.isByName(fun) @@ -198,13 +200,12 @@ private[async] final case class AnfTransform[C <: Context](c: C) { case stats :+ expr => val valDef = defineVal(s"arg$i", expr, arg.pos) stats ::: List(valDef, Ident(valDef.name)) - case xs => xs + case xs => xs } } - val allArgStats = argLists flatMap (_.init) + val allArgStats = argLists flatMap (_.init) val simpleArgs = argLists map (_.last) funStats ++ allArgStats :+ attachCopy(tree)(Apply(simpleFun, simpleArgs).setSymbol(tree.symbol)) - case Block(stats, expr) if containsAwait => inline.transformToList(stats :+ expr) @@ -267,4 +268,18 @@ private[async] final case class AnfTransform[C <: Context](c: C) { } } + def checkForAwaitInNonPrimaryParamSection(fun: Tree, args: List[Tree]) { + // TODO treat the Apply(Apply(.., argsN), ...), args0) holistically, and rewrite + // *all* argument lists in the correct order to preserve semantics. + fun match { + case Apply(fun1, _) => + fun1.tpe match { + case MethodType(_, resultType: MethodType) if resultType =:= fun.tpe => + c.error(fun.pos, "implementation restriction: await may only be used in the first parameter list.") + case _ => + } + case _ => + } + + } } diff --git a/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala index 595fa6c..f274068 100644 --- a/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala +++ b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala @@ -163,7 +163,7 @@ class AnfTransformSpec { val result = AsyncId.async { val x = "" match { case "" if false => AsyncId.await(1) + 1 - case _ => 2 + AsyncId.await(1) + case _ => 2 + AsyncId.await(1) } val y = x "" match { @@ -222,9 +222,37 @@ class AnfTransformSpec { def foo(a: Int, b: Int) = (a, b) val result = async { var i = 0 - def next() = {i += 1; i} + def next() = { + i += 1; i + } foo(next(), await(next())) } result mustBe ((1, 2)) } + + @Test + def awaitNotAllowedInNonPrimaryParamSection1() { + expectError("implementation restriction: await may only be used in the first parameter list.") { + """ + | import _root_.scala.async.AsyncId.{async, await} + | def foo(primary: Any)(i: Int) = i + | async { + | foo(???)(await(0)) + | } + """.stripMargin + } + } + + @Test + def awaitNotAllowedInNonPrimaryParamSection2() { + expectError("implementation restriction: await may only be used in the first parameter list.") { + """ + | import _root_.scala.async.AsyncId.{async, await} + | def foo[T](primary: Any)(i: Int) = i + | async { + | foo[Int](???)(await(0)) + | } + """.stripMargin + } + } } -- cgit v1.2.3 From 38c362b45fa3f5ae006ebdebaaf163c701313967 Mon Sep 17 00:00:00 2001 From: Jason Zaugg Date: Mon, 26 Nov 2012 18:58:03 +0100 Subject: Test cases for named and default args. --- .../scala/async/run/anf/AnfTransformSpec.scala | 36 +++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala index f274068..529386b 100644 --- a/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala +++ b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala @@ -223,7 +223,8 @@ class AnfTransformSpec { val result = async { var i = 0 def next() = { - i += 1; i + i += 1; + i } foo(next(), await(next())) } @@ -255,4 +256,37 @@ class AnfTransformSpec { """.stripMargin } } + + @Test + def namedArgumentsRespectEvaluationOrder() { + import scala.async.AsyncId.{async, await} + def foo(a: Int, b: Int) = (a, b) + val result = async { + var i = 0 + def next() = { + i += 1; + i + } + foo(b = next(), a = await(next())) + } + result mustBe ((2, 1)) + } + + @Test + def namedAndDefaultArgumentsRespectEvaluationOrder() { + import scala.async.AsyncId.{async, await} + var i = 0 + def next() = { + i += 1; + i + } + def foo(a: Int = next(), b: Int = next()) = (a, b) + async { + foo(b = await(next())) + } mustBe ((2, 1)) + i = 0 + async { + foo(a = await(next())) + } mustBe ((1, 2)) + } } -- cgit v1.2.3 From 7c93a9e0e288b55027646016913c7368732d54e4 Mon Sep 17 00:00:00 2001 From: Jason Zaugg Date: Tue, 27 Nov 2012 08:36:26 +0100 Subject: No need to extract vals for inlinable args. We use `isExprSafeToInline` from the non-public reflection API to check. In addtion, we now that an untyped Ident("await$N") is also an inlinable expression. --- src/main/scala/scala/async/AnfTransform.scala | 9 +++++---- src/main/scala/scala/async/TransformUtils.scala | 9 +++++++++ src/test/scala/scala/async/TreeInterrogation.scala | 20 +++++++++----------- 3 files changed, 23 insertions(+), 15 deletions(-) diff --git a/src/main/scala/scala/async/AnfTransform.scala b/src/main/scala/scala/async/AnfTransform.scala index 6b17b94..2fa96c9 100644 --- a/src/main/scala/scala/async/AnfTransform.scala +++ b/src/main/scala/scala/async/AnfTransform.scala @@ -194,13 +194,14 @@ private[async] final case class AnfTransform[C <: Context](c: C) { // this has already been checked. val isByName: (Int) => Boolean = utils.isByName(fun) val funStats :+ simpleFun = inline.transformToList(fun) + def isAwaitRef(name: Name) = name.toString.startsWith(utils.name.await + "$") val argLists: List[List[Tree]] = args.zipWithIndex map { - case (arg, i) if isByName(i) => List(arg) - case (arg, i) => inline.transformToList(arg) match { + case (arg, i) if isByName(i) || isSafeToInline(arg) => List(arg) + case (arg@Ident(name), _) if isAwaitRef(name) => List(arg) // not typed, so it eludes the check in `isSafeToInline` + case (arg, i) => inline.transformToList(arg) match { case stats :+ expr => - val valDef = defineVal(s"arg$i", expr, arg.pos) + val valDef = defineVal(name.arg(i), expr, arg.pos) stats ::: List(valDef, Ident(valDef.name)) - case xs => xs } } val allArgStats = argLists flatMap (_.init) diff --git a/src/main/scala/scala/async/TransformUtils.scala b/src/main/scala/scala/async/TransformUtils.scala index 7571f88..23f39d2 100644 --- a/src/main/scala/scala/async/TransformUtils.scala +++ b/src/main/scala/scala/async/TransformUtils.scala @@ -30,6 +30,7 @@ private[async] final case class TransformUtils[C <: Context](c: C) { val ifRes = "ifres" val await = "await" val bindSuffix = "$bind" + def arg(i: Int) = "arg" + i def fresh(name: TermName): TermName = newTermName(fresh(name.toString)) @@ -244,4 +245,12 @@ private[async] final case class TransformUtils[C <: Context](c: C) { } } + def isSafeToInline(tree: Tree) = { + val symtab = c.universe.asInstanceOf[scala.reflect.internal.SymbolTable] + object treeInfo extends { + val global: symtab.type = symtab + } with reflect.internal.TreeInfo + val castTree = tree.asInstanceOf[symtab.Tree] + treeInfo.isExprSafeToInline(castTree) + } } diff --git a/src/test/scala/scala/async/TreeInterrogation.scala b/src/test/scala/scala/async/TreeInterrogation.scala index ecb1bca..b22faa9 100644 --- a/src/test/scala/scala/async/TreeInterrogation.scala +++ b/src/test/scala/scala/async/TreeInterrogation.scala @@ -70,17 +70,15 @@ object TreeInterrogation extends App { val cm = reflect.runtime.currentMirror val tb = mkToolbox("-cp target/scala-2.10/classes -Xprint:all") val tree = tb.parse( - """ - | import scala.async.Async.{async, await} - | import scala.concurrent.{future, ExecutionContext, Await} - | import ExecutionContext.Implicits._ - | import scala.concurrent.duration._ - | - | try { - | val f = async { throw new Exception("problem") } - | Await.result(f, 1.second) - | } catch { - | case ex: Exception if ex.getMessage == "problem" => // okay + """ import scala.async.AsyncId.{async, await} + | def foo(a: Int, b: Int) = (a, b) + | val result = async { + | var i = 0 + | def next() = { + | i += 1; + | i + | } + | foo(next(), await(next())) | } | () | """.stripMargin) -- cgit v1.2.3