diff options
Diffstat (limited to 'src/main/scala/scala')
-rw-r--r-- | src/main/scala/scala/async/AnfTransform.scala | 45 | ||||
-rw-r--r-- | src/main/scala/scala/async/TransformUtils.scala | 70 |
2 files changed, 69 insertions, 46 deletions
diff --git a/src/main/scala/scala/async/AnfTransform.scala b/src/main/scala/scala/async/AnfTransform.scala index c5fbfd7..da375a5 100644 --- a/src/main/scala/scala/async/AnfTransform.scala +++ b/src/main/scala/scala/async/AnfTransform.scala @@ -107,9 +107,7 @@ private[async] final case class AnfTransform[C <: Context](c: C) { indent += 1 def oneLine(s: Any) = s.toString.replaceAll( """\n""", "\\\\n").take(127) try { - AsyncUtils.trace(s"${ - indentString - }$prefix(${oneLine(args)})") + AsyncUtils.trace(s"${indentString}$prefix(${oneLine(args)})") val result = t AsyncUtils.trace(s"${indentString}= ${oneLine(result)}") result @@ -187,31 +185,31 @@ private[async] final case class AnfTransform[C <: Context](c: C) { private[AnfTransform] def transformToList(tree: Tree): List[Tree] = trace("anf", tree) { def containsAwait = tree exists isAwait + tree match { case Select(qual, sel) if containsAwait => val stats :+ expr = inline.transformToList(qual) stats :+ attachCopy(tree)(Select(expr, sel).setSymbol(tree.symbol)) - case Apply(fun, args) if containsAwait => - checkForAwaitInNonPrimaryParamSection(fun, args) - + case Applied(fun, targs, argss) if argss.nonEmpty && containsAwait => // we an assume that no await call appears in a by-name argument position, // this has already been checked. - val isByName: (Int) => Boolean = utils.isByName(fun) val funStats :+ simpleFun = inline.transformToList(fun) def isAwaitRef(name: Name) = name.toString.startsWith(utils.name.await + "$") - val (argStats, argExprs): (List[List[Tree]], List[Tree]) = - mapArguments[List[Tree]](args) { - case (arg, i) if isByName(i) || isSafeToInline(arg) => (Nil, arg) - case (arg@Ident(name), _) if isAwaitRef(name) => (Nil, arg) // not typed, so it eludes the check in `isSafeToInline` - case (arg, i) => - inline.transformToList(arg) match { - case stats :+ expr => - val valDef = defineVal(name.arg(i), expr, arg.pos) + val (argStatss, argExprss): (List[List[List[Tree]]], List[List[Tree]]) = + mapArgumentss[List[Tree]](fun, argss) { + case Arg(expr, byName, _) if byName || isSafeToInline(expr) => (Nil, expr) + case Arg(expr@Ident(name), _, _) if isAwaitRef(name) => (Nil, expr) // not typed, so it eludes the check in `isSafeToInline` + case Arg(expr, _, argName) => + inline.transformToList(expr) match { + case stats :+ expr1 => + val valDef = defineVal(argName, expr1, expr.pos) (stats :+ valDef, Ident(valDef.name)) } } - funStats ++ argStats.flatten :+ attachCopy(tree)(Apply(simpleFun, argExprs).setSymbol(tree.symbol)) + val core = if (targs.isEmpty) simpleFun else TypeApply(simpleFun, targs) + val newApply = argExprss.foldLeft(core)(Apply(_, _).setSymbol(tree.symbol)) + funStats ++ argStatss.flatten.flatten :+ attachCopy(tree)(newApply) case Block(stats, expr) if containsAwait => inline.transformToList(stats :+ expr) @@ -273,19 +271,4 @@ private[async] final case class AnfTransform[C <: Context](c: C) { } } } - - def checkForAwaitInNonPrimaryParamSection(fun: Tree, args: List[Tree]) { - // TODO treat the Apply(Apply(.., argsN), ...), args0) holistically, and rewrite - // *all* argument lists in the correct order to preserve semantics. - fun match { - case Apply(fun1, _) => - fun1.tpe match { - case MethodType(_, resultType: MethodType) if resultType =:= fun.tpe => - c.error(fun.pos, "implementation restriction: await may only be used in the first parameter list.") - case _ => - } - case _ => - } - - } } diff --git a/src/main/scala/scala/async/TransformUtils.scala b/src/main/scala/scala/async/TransformUtils.scala index 090a334..ebd546f 100644 --- a/src/main/scala/scala/async/TransformUtils.scala +++ b/src/main/scala/scala/async/TransformUtils.scala @@ -32,8 +32,6 @@ private[async] final case class TransformUtils[C <: Context](c: C) { val await = "await" val bindSuffix = "$bind" - def arg(i: Int) = "arg" + i - def fresh(name: TermName): TermName = newTermName(fresh(name.toString)) def fresh(name: String): String = if (name.toString.contains("$")) name else c.fresh("" + name + "$") @@ -102,11 +100,13 @@ private[async] final case class TransformUtils[C <: Context](c: C) { case dd: DefDef => nestedMethod(dd) case fun: Function => function(fun) case m@Match(EmptyTree, _) => patMatFunction(m) // Pattern matching anonymous function under -Xoldpatmat of after `restorePatternMatchingFunctions` - case Apply(fun, args) => + case Applied(fun, targs, argss) if argss.nonEmpty => val isInByName = isByName(fun) - for ((arg, index) <- args.zipWithIndex) { - if (!isInByName(index)) traverse(arg) - else byNameArgument(arg) + for ((args, i) <- argss.zipWithIndex) { + for ((arg, j) <- args.zipWithIndex) { + if (!isInByName(i, j)) traverse(arg) + else byNameArgument(arg) + } } traverse(fun) case _ => super.traverse(tree) @@ -122,13 +122,31 @@ private[async] final case class TransformUtils[C <: Context](c: C) { Set(Boolean_&&, Boolean_||) } - def isByName(fun: Tree): (Int => Boolean) = { - if (Boolean_ShortCircuits contains fun.symbol) i => true - else fun.tpe match { - case MethodType(params, _) => - val isByNameParams = params.map(_.asTerm.isByNameParam) - (i: Int) => isByNameParams.applyOrElse(i, (_: Int) => false) - case _ => Map() + def isByName(fun: Tree): ((Int, Int) => Boolean) = { + if (Boolean_ShortCircuits contains fun.symbol) (i, j) => true + else { + val symtab = c.universe.asInstanceOf[reflect.internal.SymbolTable] + val paramss = fun.tpe.asInstanceOf[symtab.Type].paramss + val byNamess = paramss.map(_.map(_.isByNameParam)) + (i, j) => util.Try(byNamess(i)(j)).getOrElse(false) + } + } + def argName(fun: Tree): ((Int, Int) => String) = { + val symtab = c.universe.asInstanceOf[reflect.internal.SymbolTable] + val paramss = fun.tpe.asInstanceOf[symtab.Type].paramss + val namess = paramss.map(_.map(_.name.toString)) + (i, j) => util.Try(namess(i)(j)).getOrElse(s"arg_${i}_${j}") + } + + object Applied { + val symtab = c.universe.asInstanceOf[scala.reflect.internal.SymbolTable] + object treeInfo extends { + val global: symtab.type = symtab + } with reflect.internal.TreeInfo + + def unapply(tree: Tree): Some[(Tree, List[Tree], List[List[Tree]])] = { + val treeInfo.Applied(core, targs, argss) = tree.asInstanceOf[symtab.Tree] + Some((core.asInstanceOf[Tree], targs.asInstanceOf[List[Tree]], argss.asInstanceOf[List[List[Tree]]])) } } @@ -301,7 +319,6 @@ private[async] final case class TransformUtils[C <: Context](c: C) { } } - def isSafeToInline(tree: Tree) = { val symtab = c.universe.asInstanceOf[scala.reflect.internal.SymbolTable] object treeInfo extends { @@ -321,7 +338,7 @@ private[async] final case class TransformUtils[C <: Context](c: C) { * @param f A function from argument (with '_*' unwrapped) and argument index to argument. * @tparam A The type of the auxillary result */ - def mapArguments[A](args: List[Tree])(f: (Tree, Int) => (A, Tree)): (List[A], List[Tree]) = { + private def mapArguments[A](args: List[Tree])(f: (Tree, Int) => (A, Tree)): (List[A], List[Tree]) = { args match { case args :+ Typed(tree, Ident(tpnme.WILDCARD_STAR)) => val (a, argExprs :+ lastArgExpr) = (args :+ tree).zipWithIndex.map(f.tupled).unzip @@ -331,4 +348,27 @@ private[async] final case class TransformUtils[C <: Context](c: C) { args.zipWithIndex.map(f.tupled).unzip } } + + case class Arg(expr: Tree, isByName: Boolean, argName: String) + + /** + * Transform a list of argument lists, producing the transformed lists, and lists of auxillary + * results. + * + * The function `f` need not concern itself with varargs arguments e.g (`xs : _*`). It will + * receive `xs`, and it's result will be re-wrapped as `f(xs) : _*`. + * + * @param fun The function being applied + * @param argss The argument lists + * @return (auxillary results, mapped argument trees) + */ + def mapArgumentss[A](fun: Tree, argss: List[List[Tree]])(f: Arg => (A, Tree)): (List[List[A]], List[List[Tree]]) = { + val isByNamess: (Int, Int) => Boolean = isByName(fun) + val argNamess: (Int, Int) => String = argName(fun) + argss.zipWithIndex.map { case (args, i) => + mapArguments[A](args) { + (tree, j) => f(Arg(tree, isByNamess(i, j), argNamess(i, j))) + } + }.unzip + } } |