diff options
author | Jason Zaugg <jzaugg@gmail.com> | 2012-11-27 17:42:40 +0100 |
---|---|---|
committer | Jason Zaugg <jzaugg@gmail.com> | 2012-11-27 17:44:38 +0100 |
commit | c79d8c07d194aeaa0565f2c136b8308519d59015 (patch) | |
tree | 01cdb87f6ae7d67921994e0540035462122c1aa3 /src | |
parent | 5ee41166ea1684672ddb9a0d605a664661ba5f47 (diff) | |
download | scala-async-c79d8c07d194aeaa0565f2c136b8308519d59015.tar.gz scala-async-c79d8c07d194aeaa0565f2c136b8308519d59015.tar.bz2 scala-async-c79d8c07d194aeaa0565f2c136b8308519d59015.zip |
Fix ANF transform involving `xs: _*` trees.
We need to unwrap and inline `xs`, then
rewrap the result expression with the wildcard
star.
Addresses the first half of #46.
Diffstat (limited to 'src')
-rw-r--r-- | src/main/scala/scala/async/AnfTransform.scala | 22 | ||||
-rw-r--r-- | src/main/scala/scala/async/TransformUtils.scala | 21 | ||||
-rw-r--r-- | src/test/scala/scala/async/run/anf/AnfTransformSpec.scala | 22 |
3 files changed, 54 insertions, 11 deletions
diff --git a/src/main/scala/scala/async/AnfTransform.scala b/src/main/scala/scala/async/AnfTransform.scala index 2fa96c9..afcf6bd 100644 --- a/src/main/scala/scala/async/AnfTransform.scala +++ b/src/main/scala/scala/async/AnfTransform.scala @@ -195,18 +195,18 @@ private[async] final case class AnfTransform[C <: Context](c: C) { val isByName: (Int) => Boolean = utils.isByName(fun) val funStats :+ simpleFun = inline.transformToList(fun) def isAwaitRef(name: Name) = name.toString.startsWith(utils.name.await + "$") - val argLists: List[List[Tree]] = args.zipWithIndex map { - case (arg, i) if isByName(i) || isSafeToInline(arg) => List(arg) - case (arg@Ident(name), _) if isAwaitRef(name) => List(arg) // not typed, so it eludes the check in `isSafeToInline` - case (arg, i) => inline.transformToList(arg) match { - case stats :+ expr => - val valDef = defineVal(name.arg(i), expr, arg.pos) - stats ::: List(valDef, Ident(valDef.name)) + val (argStats, argExprs): (List[List[Tree]], List[Tree]) = + mapArguments[List[Tree]](args) { + case (arg, i) if isByName(i) || isSafeToInline(arg) => (Nil, arg) + case (arg@Ident(name), _) if isAwaitRef(name) => (Nil, arg) // not typed, so it eludes the check in `isSafeToInline` + case (arg, i) => + inline.transformToList(arg) match { + case stats :+ expr => + val valDef = defineVal(name.arg(i), expr, arg.pos) + (stats :+ valDef, Ident(valDef.name)) + } } - } - val allArgStats = argLists flatMap (_.init) - val simpleArgs = argLists map (_.last) - funStats ++ allArgStats :+ attachCopy(tree)(Apply(simpleFun, simpleArgs).setSymbol(tree.symbol)) + funStats ++ argStats.flatten :+ attachCopy(tree)(Apply(simpleFun, argExprs).setSymbol(tree.symbol)) case Block(stats, expr) if containsAwait => inline.transformToList(stats :+ expr) diff --git a/src/main/scala/scala/async/TransformUtils.scala b/src/main/scala/scala/async/TransformUtils.scala index 23f39d2..f780799 100644 --- a/src/main/scala/scala/async/TransformUtils.scala +++ b/src/main/scala/scala/async/TransformUtils.scala @@ -253,4 +253,25 @@ private[async] final case class TransformUtils[C <: Context](c: C) { val castTree = tree.asInstanceOf[symtab.Tree] treeInfo.isExprSafeToInline(castTree) } + + /** Map a list of arguments to: + * - A list of argument Trees + * - A list of auxillary results. + * + * The function unwraps and rewraps the `arg :_*` construct. + * + * @param args The original argument trees + * @param f A function from argument (with '_*' unwrapped) and argument index to argument. + * @tparam A The type of the auxillary result + */ + def mapArguments[A](args: List[Tree])(f: (Tree, Int) => (A, Tree)): (List[A], List[Tree]) = { + args match { + case args :+ Typed(tree, Ident(tpnme.WILDCARD_STAR)) => + val (a, argExprs :+ lastArgExpr) = (args :+ tree).zipWithIndex.map(f.tupled).unzip + val exprs = argExprs :+ Typed(lastArgExpr, Ident(tpnme.WILDCARD_STAR)).setPos(lastArgExpr.pos) + (a, exprs) + case args => + args.zipWithIndex.map(f.tupled).unzip + } + } } diff --git a/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala index 529386b..41c13e0 100644 --- a/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala +++ b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala @@ -289,4 +289,26 @@ class AnfTransformSpec { foo(a = await(next())) } mustBe ((1, 2)) } + + @Test + def repeatedParams1() { + import scala.async.AsyncId.{async, await} + var i = 0 + def foo(a: Int, b: Int*) = b.toList + def id(i: Int) = i + async { + foo(await(0), id(1), id(2), id(3), await(4)) + } mustBe (List(1, 2, 3, 4)) + } + + @Test + def repeatedParams2() { + import scala.async.AsyncId.{async, await} + var i = 0 + def foo(a: Int, b: Int*) = b.toList + def id(i: Int) = i + async { + foo(await(0), List(id(1), id(2), id(3)): _*) + } mustBe (List(1, 2, 3)) + } } |