From 5a0b1918238cb385401f304b22132f51936d795b Mon Sep 17 00:00:00 2001 From: Jason Zaugg Date: Wed, 10 Apr 2013 23:52:31 +0200 Subject: Allow await in applications with multiple argument lists Before, we levied an implementation restriction to prevent this. As it turned out, that needlessly prevented use of `await` in the receiver of a multi-param-list application. This commit lifts the restriction altogether, and treats such applications holistically, being careful to preserve the left-to-right evaluation order of arguments in the translated code. - use `TreeInfo.Applied` and `Type#paramss` from `reflect.internal` to get the info we need - use the parameter name for the lifted argument val, rather than `argN` - encapsulate handling of by-name-ness and parameter names in `mapArgumentss` - test for evaluation order preservation --- src/main/scala/scala/async/AnfTransform.scala | 39 ++++------- src/main/scala/scala/async/TransformUtils.scala | 70 +++++++++++++++----- src/test/scala/scala/async/TreeInterrogation.scala | 17 ++--- .../scala/async/run/anf/AnfTransformSpec.scala | 76 +++++++++++++++++----- 4 files changed, 132 insertions(+), 70 deletions(-) diff --git a/src/main/scala/scala/async/AnfTransform.scala b/src/main/scala/scala/async/AnfTransform.scala index c5fbfd7..82af3c6 100644 --- a/src/main/scala/scala/async/AnfTransform.scala +++ b/src/main/scala/scala/async/AnfTransform.scala @@ -187,31 +187,31 @@ private[async] final case class AnfTransform[C <: Context](c: C) { private[AnfTransform] def transformToList(tree: Tree): List[Tree] = trace("anf", tree) { def containsAwait = tree exists isAwait + tree match { case Select(qual, sel) if containsAwait => val stats :+ expr = inline.transformToList(qual) stats :+ attachCopy(tree)(Select(expr, sel).setSymbol(tree.symbol)) - case Apply(fun, args) if containsAwait => - checkForAwaitInNonPrimaryParamSection(fun, args) - + case utils.Applied(fun, targs, argss @ (args :: rest)) if containsAwait => // we an assume that no await call appears in a by-name argument position, // this has already been checked. - val isByName: (Int) => Boolean = utils.isByName(fun) val funStats :+ simpleFun = inline.transformToList(fun) def isAwaitRef(name: Name) = name.toString.startsWith(utils.name.await + "$") - val (argStats, argExprs): (List[List[Tree]], List[Tree]) = - mapArguments[List[Tree]](args) { - case (arg, i) if isByName(i) || isSafeToInline(arg) => (Nil, arg) - case (arg@Ident(name), _) if isAwaitRef(name) => (Nil, arg) // not typed, so it eludes the check in `isSafeToInline` - case (arg, i) => - inline.transformToList(arg) match { + val (argStatss, argExprss): (List[List[List[Tree]]], List[List[Tree]]) = + mapArgumentss[List[Tree]](fun, argss) { + case arg if arg.isByName || isSafeToInline(arg.expr) => (Nil, arg.expr) + case Arg(arg@Ident(name), _, _) if isAwaitRef(name) => (Nil, arg) // not typed, so it eludes the check in `isSafeToInline` + case arg => + inline.transformToList(arg.expr) match { case stats :+ expr => - val valDef = defineVal(name.arg(i), expr, arg.pos) + val valDef = defineVal(arg.argName, expr, arg.expr.pos) (stats :+ valDef, Ident(valDef.name)) } } - funStats ++ argStats.flatten :+ attachCopy(tree)(Apply(simpleFun, argExprs).setSymbol(tree.symbol)) + val core = if (targs.isEmpty) simpleFun else TypeApply(simpleFun, targs) + val newApply = argExprss.foldLeft(core)(Apply(_, _).setSymbol(tree.symbol)) + funStats ++ argStatss.flatten.flatten :+ attachCopy(tree)(newApply) case Block(stats, expr) if containsAwait => inline.transformToList(stats :+ expr) @@ -273,19 +273,4 @@ private[async] final case class AnfTransform[C <: Context](c: C) { } } } - - def checkForAwaitInNonPrimaryParamSection(fun: Tree, args: List[Tree]) { - // TODO treat the Apply(Apply(.., argsN), ...), args0) holistically, and rewrite - // *all* argument lists in the correct order to preserve semantics. - fun match { - case Apply(fun1, _) => - fun1.tpe match { - case MethodType(_, resultType: MethodType) if resultType =:= fun.tpe => - c.error(fun.pos, "implementation restriction: await may only be used in the first parameter list.") - case _ => - } - case _ => - } - - } } diff --git a/src/main/scala/scala/async/TransformUtils.scala b/src/main/scala/scala/async/TransformUtils.scala index 38c33a4..239bea1 100644 --- a/src/main/scala/scala/async/TransformUtils.scala +++ b/src/main/scala/scala/async/TransformUtils.scala @@ -32,8 +32,6 @@ private[async] final case class TransformUtils[C <: Context](c: C) { val await = "await" val bindSuffix = "$bind" - def arg(i: Int) = "arg" + i - def fresh(name: TermName): TermName = newTermName(fresh(name.toString)) def fresh(name: String): String = if (name.toString.contains("$")) name else c.fresh("" + name + "$") @@ -102,11 +100,13 @@ private[async] final case class TransformUtils[C <: Context](c: C) { case dd: DefDef => nestedMethod(dd) case fun: Function => function(fun) case m@Match(EmptyTree, _) => patMatFunction(m) // Pattern matching anonymous function under -Xoldpatmat of after `restorePatternMatchingFunctions` - case Apply(fun, args) => + case Applied(fun, targs, argss @ (_ :: _)) => val isInByName = isByName(fun) - for ((arg, index) <- args.zipWithIndex) { - if (!isInByName(index)) traverse(arg) - else byNameArgument(arg) + for ((args, i) <- argss.zipWithIndex) { + for ((arg, j) <- args.zipWithIndex) { + if (!isInByName(i, j)) traverse(arg) + else byNameArgument(arg) + } } traverse(fun) case _ => super.traverse(tree) @@ -122,13 +122,31 @@ private[async] final case class TransformUtils[C <: Context](c: C) { Set(Boolean_&&, Boolean_||) } - def isByName(fun: Tree): (Int => Boolean) = { - if (Boolean_ShortCircuits contains fun.symbol) i => true - else fun.tpe match { - case MethodType(params, _) => - val isByNameParams = params.map(_.asTerm.isByNameParam) - (i: Int) => isByNameParams.applyOrElse(i, (_: Int) => false) - case _ => Map() + def isByName(fun: Tree): ((Int, Int) => Boolean) = { + if (Boolean_ShortCircuits contains fun.symbol) (i, j) => true + else { + val symtab = c.universe.asInstanceOf[reflect.internal.SymbolTable] + val paramss = fun.tpe.asInstanceOf[symtab.Type].paramss + val byNamess = paramss.map(_.map(_.isByNameParam)) + (i, j) => util.Try(byNamess(i)(j)).getOrElse(false) + } + } + def argName(fun: Tree): ((Int, Int) => String) = { + val symtab = c.universe.asInstanceOf[reflect.internal.SymbolTable] + val paramss = fun.tpe.asInstanceOf[symtab.Type].paramss + val namess = paramss.map(_.map(_.name.toString)) + (i, j) => util.Try(namess(i)(j)).getOrElse(s"arg_${i}_${j}") + } + + object Applied { + val symtab = c.universe.asInstanceOf[scala.reflect.internal.SymbolTable] + object treeInfo extends { + val global: symtab.type = symtab + } with reflect.internal.TreeInfo + + def unapply(tree: Tree): Some[(Tree, List[Tree], List[List[Tree]])] = { + val treeInfo.Applied(core, targs, argss) = tree.asInstanceOf[symtab.Tree] + Some((core.asInstanceOf[Tree], targs.asInstanceOf[List[Tree]], argss.asInstanceOf[List[List[Tree]]])) } } @@ -302,7 +320,6 @@ private[async] final case class TransformUtils[C <: Context](c: C) { } } - def isSafeToInline(tree: Tree) = { val symtab = c.universe.asInstanceOf[scala.reflect.internal.SymbolTable] object treeInfo extends { @@ -322,7 +339,7 @@ private[async] final case class TransformUtils[C <: Context](c: C) { * @param f A function from argument (with '_*' unwrapped) and argument index to argument. * @tparam A The type of the auxillary result */ - def mapArguments[A](args: List[Tree])(f: (Tree, Int) => (A, Tree)): (List[A], List[Tree]) = { + private def mapArguments[A](args: List[Tree])(f: (Tree, Int) => (A, Tree)): (List[A], List[Tree]) = { args match { case args :+ Typed(tree, Ident(tpnme.WILDCARD_STAR)) => val (a, argExprs :+ lastArgExpr) = (args :+ tree).zipWithIndex.map(f.tupled).unzip @@ -332,4 +349,27 @@ private[async] final case class TransformUtils[C <: Context](c: C) { args.zipWithIndex.map(f.tupled).unzip } } + + case class Arg(expr: Tree, isByName: Boolean, argName: String) + + /** + * Transform a list of argument lists, producing the transformed lists, and lists of auxillary + * results. + * + * The function `f` need not concern itself with varargs arguments e.g (`xs : _*`). It will + * receive `xs`, and it's result will be re-wrapped as `f(xs) : _*`. + * + * @param fun The function being applied + * @param argss The argument lists + * @return (auxillary results, mapped argument trees) + */ + def mapArgumentss[A](fun: Tree, argss: List[List[Tree]])(f: Arg => (A, Tree)): (List[List[A]], List[List[Tree]]) = { + val isByNamess: (Int, Int) => Boolean = isByName(fun) + val argNamess: (Int, Int) => String = argName(fun) + argss.zipWithIndex.map { case (args, i) => + mapArguments[A](args) { + (tree, j) => f(Arg(tree, isByNamess(i, j), argNamess(i, j))) + } + }.unzip + } } diff --git a/src/test/scala/scala/async/TreeInterrogation.scala b/src/test/scala/scala/async/TreeInterrogation.scala index 4d611e5..deaee03 100644 --- a/src/test/scala/scala/async/TreeInterrogation.scala +++ b/src/test/scala/scala/async/TreeInterrogation.scala @@ -71,17 +71,14 @@ object TreeInterrogation extends App { val tb = mkToolbox("-cp target/scala-2.10/classes -Xprint:flatten") import scala.async.Async._ val tree = tb.parse( - """ import scala.async.AsyncId._ - | async { - | val x = 1 - | val opt = Some("") - | await(0) - | val o @ Some(y) = opt - | - | { - | val o @ Some(y) = Some(".") - | } + """ import _root_.scala.async.AsyncId.{async, await} + | def foo[T](a0: Int)(b0: Int*) = s"a0 = $a0, b0 = ${b0.head}" + | val res = async { + | var i = 0 + | def get = async {i += 1; i} + | foo[Int](await(get))(await(get) :: Nil : _*) | } + | res | """.stripMargin) println(tree) val tree1 = tb.typeCheck(tree.duplicate) diff --git a/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala index 41c13e0..7be6299 100644 --- a/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala +++ b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala @@ -232,28 +232,68 @@ class AnfTransformSpec { } @Test - def awaitNotAllowedInNonPrimaryParamSection1() { - expectError("implementation restriction: await may only be used in the first parameter list.") { - """ - | import _root_.scala.async.AsyncId.{async, await} - | def foo(primary: Any)(i: Int) = i - | async { - | foo(???)(await(0)) - | } - """.stripMargin + def awaitInNonPrimaryParamSection1() { + import _root_.scala.async.AsyncId.{async, await} + def foo(a0: Int)(b0: Int) = s"a0 = $a0, b0 = $b0" + val res = async { + var i = 0 + def get = {i += 1; i} + foo(get)(get) + } + res mustBe "a0 = 1, b0 = 2" + } + + @Test + def awaitInNonPrimaryParamSection2() { + import _root_.scala.async.AsyncId.{async, await} + def foo[T](a0: Int)(b0: Int*) = s"a0 = $a0, b0 = ${b0.head}" + val res = async { + var i = 0 + def get = async {i += 1; i} + foo[Int](await(get))(await(get) :: await(async(Nil)) : _*) + } + res mustBe "a0 = 1, b0 = 2" + } + + @Test + def awaitInNonPrimaryParamSectionWithLazy1() { + import _root_.scala.async.AsyncId.{async, await} + def foo[T](a: => Int)(b: Int) = b + val res = async { + def get = async {0} + foo[Int](???)(await(get)) } + res mustBe 0 } @Test - def awaitNotAllowedInNonPrimaryParamSection2() { - expectError("implementation restriction: await may only be used in the first parameter list.") { - """ - | import _root_.scala.async.AsyncId.{async, await} - | def foo[T](primary: Any)(i: Int) = i - | async { - | foo[Int](???)(await(0)) - | } - """.stripMargin + def awaitInNonPrimaryParamSectionWithLazy2() { + import _root_.scala.async.AsyncId.{async, await} + def foo[T](a: Int)(b: => Int) = a + val res = async { + def get = async {0} + foo[Int](await(get))(???) + } + res mustBe 0 + } + + @Test + def awaitWithLazy() { + import _root_.scala.async.AsyncId.{async, await} + def foo[T](a: Int, b: => Int) = a + val res = async { + def get = async {0} + foo[Int](await(get), ???) + } + res mustBe 0 + } + + @Test + def awaitOkInReciever() { + import scala.async.AsyncId.{async, await} + class Foo { def bar(a: Int)(b: Int) = a + b } + async { + await(async(new Foo)).bar(1)(2) } } -- cgit v1.2.3 From b38f991ab4948f3358a937604dc28ffa4901270e Mon Sep 17 00:00:00 2001 From: Jason Zaugg Date: Mon, 15 Apr 2013 13:22:47 +0200 Subject: Rephrase a few pattern matches, fix ANF tracing. Addresses review comments --- src/main/scala/scala/async/AnfTransform.scala | 18 ++++++++---------- src/main/scala/scala/async/TransformUtils.scala | 2 +- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/src/main/scala/scala/async/AnfTransform.scala b/src/main/scala/scala/async/AnfTransform.scala index 82af3c6..da375a5 100644 --- a/src/main/scala/scala/async/AnfTransform.scala +++ b/src/main/scala/scala/async/AnfTransform.scala @@ -107,9 +107,7 @@ private[async] final case class AnfTransform[C <: Context](c: C) { indent += 1 def oneLine(s: Any) = s.toString.replaceAll( """\n""", "\\\\n").take(127) try { - AsyncUtils.trace(s"${ - indentString - }$prefix(${oneLine(args)})") + AsyncUtils.trace(s"${indentString}$prefix(${oneLine(args)})") val result = t AsyncUtils.trace(s"${indentString}= ${oneLine(result)}") result @@ -193,19 +191,19 @@ private[async] final case class AnfTransform[C <: Context](c: C) { val stats :+ expr = inline.transformToList(qual) stats :+ attachCopy(tree)(Select(expr, sel).setSymbol(tree.symbol)) - case utils.Applied(fun, targs, argss @ (args :: rest)) if containsAwait => + case Applied(fun, targs, argss) if argss.nonEmpty && containsAwait => // we an assume that no await call appears in a by-name argument position, // this has already been checked. val funStats :+ simpleFun = inline.transformToList(fun) def isAwaitRef(name: Name) = name.toString.startsWith(utils.name.await + "$") val (argStatss, argExprss): (List[List[List[Tree]]], List[List[Tree]]) = mapArgumentss[List[Tree]](fun, argss) { - case arg if arg.isByName || isSafeToInline(arg.expr) => (Nil, arg.expr) - case Arg(arg@Ident(name), _, _) if isAwaitRef(name) => (Nil, arg) // not typed, so it eludes the check in `isSafeToInline` - case arg => - inline.transformToList(arg.expr) match { - case stats :+ expr => - val valDef = defineVal(arg.argName, expr, arg.expr.pos) + case Arg(expr, byName, _) if byName || isSafeToInline(expr) => (Nil, expr) + case Arg(expr@Ident(name), _, _) if isAwaitRef(name) => (Nil, expr) // not typed, so it eludes the check in `isSafeToInline` + case Arg(expr, _, argName) => + inline.transformToList(expr) match { + case stats :+ expr1 => + val valDef = defineVal(argName, expr1, expr.pos) (stats :+ valDef, Ident(valDef.name)) } } diff --git a/src/main/scala/scala/async/TransformUtils.scala b/src/main/scala/scala/async/TransformUtils.scala index 239bea1..7731b83 100644 --- a/src/main/scala/scala/async/TransformUtils.scala +++ b/src/main/scala/scala/async/TransformUtils.scala @@ -100,7 +100,7 @@ private[async] final case class TransformUtils[C <: Context](c: C) { case dd: DefDef => nestedMethod(dd) case fun: Function => function(fun) case m@Match(EmptyTree, _) => patMatFunction(m) // Pattern matching anonymous function under -Xoldpatmat of after `restorePatternMatchingFunctions` - case Applied(fun, targs, argss @ (_ :: _)) => + case Applied(fun, targs, argss) if argss.nonEmpty => val isInByName = isByName(fun) for ((args, i) <- argss.zipWithIndex) { for ((arg, j) <- args.zipWithIndex) { -- cgit v1.2.3