aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPhilipp Haller <hallerp@gmail.com>2013-04-17 05:35:14 -0700
committerPhilipp Haller <hallerp@gmail.com>2013-04-17 05:35:14 -0700
commit757930193a0a2225f4c5f35f449615a81d3aaf19 (patch)
tree3ffc5eb7a469e196ae3e7f74072c6d4e96ba7dcd
parentc26ba9db06a88dcd384e7dd1f0450206eef6f064 (diff)
parentb38f991ab4948f3358a937604dc28ffa4901270e (diff)
downloadscala-async-757930193a0a2225f4c5f35f449615a81d3aaf19.tar.gz
scala-async-757930193a0a2225f4c5f35f449615a81d3aaf19.tar.bz2
scala-async-757930193a0a2225f4c5f35f449615a81d3aaf19.zip
Merge pull request #9 from retronym/ticket/4-multi-param
Allow await in applications with multiple argument lists
-rw-r--r--src/main/scala/scala/async/AnfTransform.scala45
-rw-r--r--src/main/scala/scala/async/TransformUtils.scala70
-rw-r--r--src/test/scala/scala/async/TreeInterrogation.scala17
-rw-r--r--src/test/scala/scala/async/run/anf/AnfTransformSpec.scala76
4 files changed, 134 insertions, 74 deletions
diff --git a/src/main/scala/scala/async/AnfTransform.scala b/src/main/scala/scala/async/AnfTransform.scala
index c5fbfd7..da375a5 100644
--- a/src/main/scala/scala/async/AnfTransform.scala
+++ b/src/main/scala/scala/async/AnfTransform.scala
@@ -107,9 +107,7 @@ private[async] final case class AnfTransform[C <: Context](c: C) {
indent += 1
def oneLine(s: Any) = s.toString.replaceAll( """\n""", "\\\\n").take(127)
try {
- AsyncUtils.trace(s"${
- indentString
- }$prefix(${oneLine(args)})")
+ AsyncUtils.trace(s"${indentString}$prefix(${oneLine(args)})")
val result = t
AsyncUtils.trace(s"${indentString}= ${oneLine(result)}")
result
@@ -187,31 +185,31 @@ private[async] final case class AnfTransform[C <: Context](c: C) {
private[AnfTransform] def transformToList(tree: Tree): List[Tree] = trace("anf", tree) {
def containsAwait = tree exists isAwait
+
tree match {
case Select(qual, sel) if containsAwait =>
val stats :+ expr = inline.transformToList(qual)
stats :+ attachCopy(tree)(Select(expr, sel).setSymbol(tree.symbol))
- case Apply(fun, args) if containsAwait =>
- checkForAwaitInNonPrimaryParamSection(fun, args)
-
+ case Applied(fun, targs, argss) if argss.nonEmpty && containsAwait =>
// we an assume that no await call appears in a by-name argument position,
// this has already been checked.
- val isByName: (Int) => Boolean = utils.isByName(fun)
val funStats :+ simpleFun = inline.transformToList(fun)
def isAwaitRef(name: Name) = name.toString.startsWith(utils.name.await + "$")
- val (argStats, argExprs): (List[List[Tree]], List[Tree]) =
- mapArguments[List[Tree]](args) {
- case (arg, i) if isByName(i) || isSafeToInline(arg) => (Nil, arg)
- case (arg@Ident(name), _) if isAwaitRef(name) => (Nil, arg) // not typed, so it eludes the check in `isSafeToInline`
- case (arg, i) =>
- inline.transformToList(arg) match {
- case stats :+ expr =>
- val valDef = defineVal(name.arg(i), expr, arg.pos)
+ val (argStatss, argExprss): (List[List[List[Tree]]], List[List[Tree]]) =
+ mapArgumentss[List[Tree]](fun, argss) {
+ case Arg(expr, byName, _) if byName || isSafeToInline(expr) => (Nil, expr)
+ case Arg(expr@Ident(name), _, _) if isAwaitRef(name) => (Nil, expr) // not typed, so it eludes the check in `isSafeToInline`
+ case Arg(expr, _, argName) =>
+ inline.transformToList(expr) match {
+ case stats :+ expr1 =>
+ val valDef = defineVal(argName, expr1, expr.pos)
(stats :+ valDef, Ident(valDef.name))
}
}
- funStats ++ argStats.flatten :+ attachCopy(tree)(Apply(simpleFun, argExprs).setSymbol(tree.symbol))
+ val core = if (targs.isEmpty) simpleFun else TypeApply(simpleFun, targs)
+ val newApply = argExprss.foldLeft(core)(Apply(_, _).setSymbol(tree.symbol))
+ funStats ++ argStatss.flatten.flatten :+ attachCopy(tree)(newApply)
case Block(stats, expr) if containsAwait =>
inline.transformToList(stats :+ expr)
@@ -273,19 +271,4 @@ private[async] final case class AnfTransform[C <: Context](c: C) {
}
}
}
-
- def checkForAwaitInNonPrimaryParamSection(fun: Tree, args: List[Tree]) {
- // TODO treat the Apply(Apply(.., argsN), ...), args0) holistically, and rewrite
- // *all* argument lists in the correct order to preserve semantics.
- fun match {
- case Apply(fun1, _) =>
- fun1.tpe match {
- case MethodType(_, resultType: MethodType) if resultType =:= fun.tpe =>
- c.error(fun.pos, "implementation restriction: await may only be used in the first parameter list.")
- case _ =>
- }
- case _ =>
- }
-
- }
}
diff --git a/src/main/scala/scala/async/TransformUtils.scala b/src/main/scala/scala/async/TransformUtils.scala
index 090a334..ebd546f 100644
--- a/src/main/scala/scala/async/TransformUtils.scala
+++ b/src/main/scala/scala/async/TransformUtils.scala
@@ -32,8 +32,6 @@ private[async] final case class TransformUtils[C <: Context](c: C) {
val await = "await"
val bindSuffix = "$bind"
- def arg(i: Int) = "arg" + i
-
def fresh(name: TermName): TermName = newTermName(fresh(name.toString))
def fresh(name: String): String = if (name.toString.contains("$")) name else c.fresh("" + name + "$")
@@ -102,11 +100,13 @@ private[async] final case class TransformUtils[C <: Context](c: C) {
case dd: DefDef => nestedMethod(dd)
case fun: Function => function(fun)
case m@Match(EmptyTree, _) => patMatFunction(m) // Pattern matching anonymous function under -Xoldpatmat of after `restorePatternMatchingFunctions`
- case Apply(fun, args) =>
+ case Applied(fun, targs, argss) if argss.nonEmpty =>
val isInByName = isByName(fun)
- for ((arg, index) <- args.zipWithIndex) {
- if (!isInByName(index)) traverse(arg)
- else byNameArgument(arg)
+ for ((args, i) <- argss.zipWithIndex) {
+ for ((arg, j) <- args.zipWithIndex) {
+ if (!isInByName(i, j)) traverse(arg)
+ else byNameArgument(arg)
+ }
}
traverse(fun)
case _ => super.traverse(tree)
@@ -122,13 +122,31 @@ private[async] final case class TransformUtils[C <: Context](c: C) {
Set(Boolean_&&, Boolean_||)
}
- def isByName(fun: Tree): (Int => Boolean) = {
- if (Boolean_ShortCircuits contains fun.symbol) i => true
- else fun.tpe match {
- case MethodType(params, _) =>
- val isByNameParams = params.map(_.asTerm.isByNameParam)
- (i: Int) => isByNameParams.applyOrElse(i, (_: Int) => false)
- case _ => Map()
+ def isByName(fun: Tree): ((Int, Int) => Boolean) = {
+ if (Boolean_ShortCircuits contains fun.symbol) (i, j) => true
+ else {
+ val symtab = c.universe.asInstanceOf[reflect.internal.SymbolTable]
+ val paramss = fun.tpe.asInstanceOf[symtab.Type].paramss
+ val byNamess = paramss.map(_.map(_.isByNameParam))
+ (i, j) => util.Try(byNamess(i)(j)).getOrElse(false)
+ }
+ }
+ def argName(fun: Tree): ((Int, Int) => String) = {
+ val symtab = c.universe.asInstanceOf[reflect.internal.SymbolTable]
+ val paramss = fun.tpe.asInstanceOf[symtab.Type].paramss
+ val namess = paramss.map(_.map(_.name.toString))
+ (i, j) => util.Try(namess(i)(j)).getOrElse(s"arg_${i}_${j}")
+ }
+
+ object Applied {
+ val symtab = c.universe.asInstanceOf[scala.reflect.internal.SymbolTable]
+ object treeInfo extends {
+ val global: symtab.type = symtab
+ } with reflect.internal.TreeInfo
+
+ def unapply(tree: Tree): Some[(Tree, List[Tree], List[List[Tree]])] = {
+ val treeInfo.Applied(core, targs, argss) = tree.asInstanceOf[symtab.Tree]
+ Some((core.asInstanceOf[Tree], targs.asInstanceOf[List[Tree]], argss.asInstanceOf[List[List[Tree]]]))
}
}
@@ -301,7 +319,6 @@ private[async] final case class TransformUtils[C <: Context](c: C) {
}
}
-
def isSafeToInline(tree: Tree) = {
val symtab = c.universe.asInstanceOf[scala.reflect.internal.SymbolTable]
object treeInfo extends {
@@ -321,7 +338,7 @@ private[async] final case class TransformUtils[C <: Context](c: C) {
* @param f A function from argument (with '_*' unwrapped) and argument index to argument.
* @tparam A The type of the auxillary result
*/
- def mapArguments[A](args: List[Tree])(f: (Tree, Int) => (A, Tree)): (List[A], List[Tree]) = {
+ private def mapArguments[A](args: List[Tree])(f: (Tree, Int) => (A, Tree)): (List[A], List[Tree]) = {
args match {
case args :+ Typed(tree, Ident(tpnme.WILDCARD_STAR)) =>
val (a, argExprs :+ lastArgExpr) = (args :+ tree).zipWithIndex.map(f.tupled).unzip
@@ -331,4 +348,27 @@ private[async] final case class TransformUtils[C <: Context](c: C) {
args.zipWithIndex.map(f.tupled).unzip
}
}
+
+ case class Arg(expr: Tree, isByName: Boolean, argName: String)
+
+ /**
+ * Transform a list of argument lists, producing the transformed lists, and lists of auxillary
+ * results.
+ *
+ * The function `f` need not concern itself with varargs arguments e.g (`xs : _*`). It will
+ * receive `xs`, and it's result will be re-wrapped as `f(xs) : _*`.
+ *
+ * @param fun The function being applied
+ * @param argss The argument lists
+ * @return (auxillary results, mapped argument trees)
+ */
+ def mapArgumentss[A](fun: Tree, argss: List[List[Tree]])(f: Arg => (A, Tree)): (List[List[A]], List[List[Tree]]) = {
+ val isByNamess: (Int, Int) => Boolean = isByName(fun)
+ val argNamess: (Int, Int) => String = argName(fun)
+ argss.zipWithIndex.map { case (args, i) =>
+ mapArguments[A](args) {
+ (tree, j) => f(Arg(tree, isByNamess(i, j), argNamess(i, j)))
+ }
+ }.unzip
+ }
}
diff --git a/src/test/scala/scala/async/TreeInterrogation.scala b/src/test/scala/scala/async/TreeInterrogation.scala
index 4d611e5..deaee03 100644
--- a/src/test/scala/scala/async/TreeInterrogation.scala
+++ b/src/test/scala/scala/async/TreeInterrogation.scala
@@ -71,17 +71,14 @@ object TreeInterrogation extends App {
val tb = mkToolbox("-cp target/scala-2.10/classes -Xprint:flatten")
import scala.async.Async._
val tree = tb.parse(
- """ import scala.async.AsyncId._
- | async {
- | val x = 1
- | val opt = Some("")
- | await(0)
- | val o @ Some(y) = opt
- |
- | {
- | val o @ Some(y) = Some(".")
- | }
+ """ import _root_.scala.async.AsyncId.{async, await}
+ | def foo[T](a0: Int)(b0: Int*) = s"a0 = $a0, b0 = ${b0.head}"
+ | val res = async {
+ | var i = 0
+ | def get = async {i += 1; i}
+ | foo[Int](await(get))(await(get) :: Nil : _*)
| }
+ | res
| """.stripMargin)
println(tree)
val tree1 = tb.typeCheck(tree.duplicate)
diff --git a/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala
index 41c13e0..7be6299 100644
--- a/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala
+++ b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala
@@ -232,28 +232,68 @@ class AnfTransformSpec {
}
@Test
- def awaitNotAllowedInNonPrimaryParamSection1() {
- expectError("implementation restriction: await may only be used in the first parameter list.") {
- """
- | import _root_.scala.async.AsyncId.{async, await}
- | def foo(primary: Any)(i: Int) = i
- | async {
- | foo(???)(await(0))
- | }
- """.stripMargin
+ def awaitInNonPrimaryParamSection1() {
+ import _root_.scala.async.AsyncId.{async, await}
+ def foo(a0: Int)(b0: Int) = s"a0 = $a0, b0 = $b0"
+ val res = async {
+ var i = 0
+ def get = {i += 1; i}
+ foo(get)(get)
+ }
+ res mustBe "a0 = 1, b0 = 2"
+ }
+
+ @Test
+ def awaitInNonPrimaryParamSection2() {
+ import _root_.scala.async.AsyncId.{async, await}
+ def foo[T](a0: Int)(b0: Int*) = s"a0 = $a0, b0 = ${b0.head}"
+ val res = async {
+ var i = 0
+ def get = async {i += 1; i}
+ foo[Int](await(get))(await(get) :: await(async(Nil)) : _*)
+ }
+ res mustBe "a0 = 1, b0 = 2"
+ }
+
+ @Test
+ def awaitInNonPrimaryParamSectionWithLazy1() {
+ import _root_.scala.async.AsyncId.{async, await}
+ def foo[T](a: => Int)(b: Int) = b
+ val res = async {
+ def get = async {0}
+ foo[Int](???)(await(get))
}
+ res mustBe 0
}
@Test
- def awaitNotAllowedInNonPrimaryParamSection2() {
- expectError("implementation restriction: await may only be used in the first parameter list.") {
- """
- | import _root_.scala.async.AsyncId.{async, await}
- | def foo[T](primary: Any)(i: Int) = i
- | async {
- | foo[Int](???)(await(0))
- | }
- """.stripMargin
+ def awaitInNonPrimaryParamSectionWithLazy2() {
+ import _root_.scala.async.AsyncId.{async, await}
+ def foo[T](a: Int)(b: => Int) = a
+ val res = async {
+ def get = async {0}
+ foo[Int](await(get))(???)
+ }
+ res mustBe 0
+ }
+
+ @Test
+ def awaitWithLazy() {
+ import _root_.scala.async.AsyncId.{async, await}
+ def foo[T](a: Int, b: => Int) = a
+ val res = async {
+ def get = async {0}
+ foo[Int](await(get), ???)
+ }
+ res mustBe 0
+ }
+
+ @Test
+ def awaitOkInReciever() {
+ import scala.async.AsyncId.{async, await}
+ class Foo { def bar(a: Int)(b: Int) = a + b }
+ async {
+ await(async(new Foo)).bar(1)(2)
}
}