diff options
author | Philipp Haller <hallerp@gmail.com> | 2013-08-14 07:51:10 -0700 |
---|---|---|
committer | Philipp Haller <hallerp@gmail.com> | 2013-08-14 07:51:10 -0700 |
commit | 0c5a1ea043c72bbc3568a8df7f75bc65a261ed21 (patch) | |
tree | b2667b421b7dde7ee8e196f616661baddd307579 | |
parent | 9156cbeb944db80245766c317f43434b4c1981e5 (diff) | |
parent | b79c9ad864a27aea620254c3eade6d38adcf38f2 (diff) | |
download | scala-async-0c5a1ea043c72bbc3568a8df7f75bc65a261ed21.tar.gz scala-async-0c5a1ea043c72bbc3568a8df7f75bc65a261ed21.tar.bz2 scala-async-0c5a1ea043c72bbc3568a8df7f75bc65a261ed21.zip |
Merge pull request #27 from retronym/topic/typed-transform-2
Typeful transformations
38 files changed, 1629 insertions, 1391 deletions
@@ -3,3 +3,4 @@ target .idea .idea_modules *.icode +project/local.sbt
\ No newline at end of file @@ -1,4 +1,4 @@ -scalaVersion := "2.10.1" +scalaVersion := "2.10.2" organization := "org.typesafe.async" // TODO new org name under scala-lang. @@ -8,8 +8,8 @@ version := "1.0.0-SNAPSHOT" libraryDependencies <++= (scalaVersion) { sv => Seq( - "org.scala-lang" % "scala-reflect" % sv, - "org.scala-lang" % "scala-compiler" % sv % "test" + "org.scala-lang" % "scala-reflect" % sv % "provided", + "org.scala-lang" % "scala-compiler" % sv % "provided" ) } @@ -32,6 +32,8 @@ scalacOptions += "-P:continuations:enable" scalacOptions ++= Seq("-deprecation", "-unchecked", "-Xlint", "-feature") +scalacOptions in Test ++= Seq("-Yrangepos") + description := "An asynchronous programming facility for Scala, in the spirit of C# await/async" homepage := Some(url("http://github.com/scala/async")) @@ -40,6 +42,9 @@ startYear := Some(2012) licenses +=("Scala license", url("https://github.com/scala/async/blob/master/LICENSE")) +// Uncomment to disable test compilation. +// (sources in Test) ~= ((xs: Seq[File]) => xs.filter(f => Seq("TreeInterrogation", "package").exists(f.name.contains))) + pomExtra := ( <developers> <developer> diff --git a/project/build.properties b/project/build.properties index 2b9d40c..5e96e96 100644 --- a/project/build.properties +++ b/project/build.properties @@ -1 +1 @@ -sbt.version=0.12.1
\ No newline at end of file +sbt.version=0.12.4 diff --git a/src/main/scala/scala/async/AnfTransform.scala b/src/main/scala/scala/async/AnfTransform.scala deleted file mode 100644 index 5b9901d..0000000 --- a/src/main/scala/scala/async/AnfTransform.scala +++ /dev/null @@ -1,275 +0,0 @@ - -/* - * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> - */ - -package scala.async - -import scala.reflect.macros.Context - -private[async] final case class AnfTransform[C <: Context](c: C) { - - import c.universe._ - - val utils = TransformUtils[c.type](c) - - import utils._ - - def apply(tree: Tree): List[Tree] = { - val unique = uniqueNames(tree) - // Must prepend the () for issue #31. - anf.transformToList(Block(List(c.literalUnit.tree), unique)) - } - - private def uniqueNames(tree: Tree): Tree = { - new UniqueNames(tree).transform(tree) - } - - /** Assigns unique names to all definitions in a tree, and adjusts references to use the new name. - * Only modifies names that appear more than once in the tree. - * - * This step is needed to allow us to safely merge blocks during the `inline` transform below. - */ - private final class UniqueNames(tree: Tree) extends Transformer { - val repeatedNames: Set[Symbol] = { - class DuplicateNameTraverser extends AsyncTraverser { - val result = collection.mutable.Buffer[Symbol]() - - override def traverse(tree: Tree) { - tree match { - case dt: DefTree => result += dt.symbol - case _ => super.traverse(tree) - } - } - } - val dupNameTraverser = new DuplicateNameTraverser - dupNameTraverser.traverse(tree) - dupNameTraverser.result.groupBy(x => x.name).filter(_._2.size > 1).values.flatten.toSet[Symbol] - } - - /** Stepping outside of the public Macro API to call [[scala.reflect.internal.Symbols.Symbol.name_=]] */ - val symtab = c.universe.asInstanceOf[reflect.internal.SymbolTable] - - val renamed = collection.mutable.Set[Symbol]() - - override def transform(tree: Tree): Tree = { - tree match { - case defTree: DefTree if repeatedNames(defTree.symbol) => - val trans = super.transform(defTree) - val origName = defTree.symbol.name - val sym = defTree.symbol.asInstanceOf[symtab.Symbol] - val fresh = name.fresh(sym.name.toString) - sym.name = origName match { - case _: TermName => symtab.newTermName(fresh) - case _: TypeName => symtab.newTypeName(fresh) - } - renamed += trans.symbol - val newName = trans.symbol.name - trans match { - case ValDef(mods, name, tpt, rhs) => - treeCopy.ValDef(trans, mods, newName, tpt, rhs) - case Bind(name, body) => - treeCopy.Bind(trans, newName, body) - case DefDef(mods, name, tparams, vparamss, tpt, rhs) => - treeCopy.DefDef(trans, mods, newName, tparams, vparamss, tpt, rhs) - case TypeDef(mods, name, tparams, rhs) => - treeCopy.TypeDef(tree, mods, newName, tparams, transform(rhs)) - // If we were to allow local classes / objects, we would need to rename here. - case ClassDef(mods, name, tparams, impl) => - treeCopy.ClassDef(tree, mods, newName, tparams, transform(impl).asInstanceOf[Template]) - case ModuleDef(mods, name, impl) => - treeCopy.ModuleDef(tree, mods, newName, transform(impl).asInstanceOf[Template]) - case x => super.transform(x) - } - case Ident(name) => - if (renamed(tree.symbol)) treeCopy.Ident(tree, tree.symbol.name) - else tree - case Select(fun, name) => - if (renamed(tree.symbol)) { - treeCopy.Select(tree, transform(fun), tree.symbol.name) - } else super.transform(tree) - case tt: TypeTree => - val tt1 = tt.asInstanceOf[symtab.TypeTree] - val orig = tt1.original - if (orig != null) tt1.setOriginal(transform(orig.asInstanceOf[Tree]).asInstanceOf[symtab.Tree]) - super.transform(tt) - case _ => super.transform(tree) - } - } - } - - private object trace { - private var indent = -1 - - def indentString = " " * indent - - def apply[T](prefix: String, args: Any)(t: => T): T = { - indent += 1 - def oneLine(s: Any) = s.toString.replaceAll( """\n""", "\\\\n").take(127) - try { - AsyncUtils.trace(s"${indentString}$prefix(${oneLine(args)})") - val result = t - AsyncUtils.trace(s"${indentString}= ${oneLine(result)}") - result - } finally { - indent -= 1 - } - } - } - - private object inline { - def transformToList(tree: Tree): List[Tree] = trace("inline", tree) { - val stats :+ expr = anf.transformToList(tree) - expr match { - case Apply(fun, args) if isAwait(fun) => - val valDef = defineVal(name.await, expr, tree.pos) - stats :+ valDef :+ Ident(valDef.name) - - case If(cond, thenp, elsep) => - // if type of if-else is Unit don't introduce assignment, - // but add Unit value to bring it into form expected by async transform - if (expr.tpe =:= definitions.UnitTpe) { - stats :+ expr :+ Literal(Constant(())) - } else { - val varDef = defineVar(name.ifRes, expr.tpe, tree.pos) - def branchWithAssign(orig: Tree) = orig match { - case Block(thenStats, thenExpr) => Block(thenStats, Assign(Ident(varDef.name), thenExpr)) - case _ => Assign(Ident(varDef.name), orig) - } - val ifWithAssign = If(cond, branchWithAssign(thenp), branchWithAssign(elsep)) - stats :+ varDef :+ ifWithAssign :+ Ident(varDef.name) - } - - case Match(scrut, cases) => - // if type of match is Unit don't introduce assignment, - // but add Unit value to bring it into form expected by async transform - if (expr.tpe =:= definitions.UnitTpe) { - stats :+ expr :+ Literal(Constant(())) - } - else { - val varDef = defineVar(name.matchRes, expr.tpe, tree.pos) - val casesWithAssign = cases map { - case cd@CaseDef(pat, guard, Block(caseStats, caseExpr)) => - attachCopy(cd)(CaseDef(pat, guard, Block(caseStats, Assign(Ident(varDef.name), caseExpr)))) - case cd@CaseDef(pat, guard, body) => - attachCopy(cd)(CaseDef(pat, guard, Assign(Ident(varDef.name), body))) - } - val matchWithAssign = attachCopy(tree)(Match(scrut, casesWithAssign)) - stats :+ varDef :+ matchWithAssign :+ Ident(varDef.name) - } - case _ => - stats :+ expr - } - } - - def transformToList(trees: List[Tree]): List[Tree] = trees flatMap transformToList - - def transformToBlock(tree: Tree): Block = transformToList(tree) match { - case stats :+ expr => Block(stats, expr) - } - - private def defineVar(prefix: String, tp: Type, pos: Position): ValDef = { - val vd = ValDef(Modifiers(Flag.MUTABLE), name.fresh(prefix), TypeTree(tp), defaultValue(tp)) - vd.setPos(pos) - vd - } - } - - private def defineVal(prefix: String, lhs: Tree, pos: Position): ValDef = { - val vd = ValDef(NoMods, name.fresh(prefix), TypeTree(), lhs) - vd.setPos(pos) - vd - } - - private object anf { - - private[AnfTransform] def transformToList(tree: Tree): List[Tree] = trace("anf", tree) { - val containsAwait = tree exists isAwait - if (!containsAwait) { - List(tree) - } else tree match { - case Select(qual, sel) => - val stats :+ expr = inline.transformToList(qual) - stats :+ attachCopy(tree)(Select(expr, sel).setSymbol(tree.symbol)) - - case Applied(fun, targs, argss) if argss.nonEmpty => - // we an assume that no await call appears in a by-name argument position, - // this has already been checked. - val funStats :+ simpleFun = inline.transformToList(fun) - def isAwaitRef(name: Name) = name.toString.startsWith(utils.name.await + "$") - val (argStatss, argExprss): (List[List[List[Tree]]], List[List[Tree]]) = - mapArgumentss[List[Tree]](fun, argss) { - case Arg(expr, byName, _) if byName || isSafeToInline(expr) => (Nil, expr) - case Arg(expr@Ident(name), _, _) if isAwaitRef(name) => (Nil, expr) // not typed, so it eludes the check in `isSafeToInline` - case Arg(expr, _, argName) => - inline.transformToList(expr) match { - case stats :+ expr1 => - val valDef = defineVal(argName, expr1, expr.pos) - (stats :+ valDef, Ident(valDef.name)) - } - } - val core = if (targs.isEmpty) simpleFun else TypeApply(simpleFun, targs) - val newApply = argExprss.foldLeft(core)(Apply(_, _)).setSymbol(tree.symbol) - funStats ++ argStatss.flatten.flatten :+ attachCopy(tree)(newApply) - case Block(stats, expr) => - inline.transformToList(stats :+ expr) - - case ValDef(mods, name, tpt, rhs) => - if (rhs exists isAwait) { - val stats :+ expr = inline.transformToList(rhs) - stats :+ attachCopy(tree)(ValDef(mods, name, tpt, expr).setSymbol(tree.symbol)) - } else List(tree) - - case Assign(lhs, rhs) => - val stats :+ expr = inline.transformToList(rhs) - stats :+ attachCopy(tree)(Assign(lhs, expr)) - - case If(cond, thenp, elsep) => - val condStats :+ condExpr = inline.transformToList(cond) - val thenBlock = inline.transformToBlock(thenp) - val elseBlock = inline.transformToBlock(elsep) - // Typechecking with `condExpr` as the condition fails if the condition - // contains an await. `ifTree.setType(tree.tpe)` also fails; it seems - // we rely on this call to `typeCheck` descending into the branches. - // But, we can get away with typechecking a throwaway `If` tree with the - // original scrutinee and the new branches, and setting that type on - // the real `If` tree. - val ifType = c.typeCheck(If(cond, thenBlock, elseBlock)).tpe - condStats :+ - attachCopy(tree)(If(condExpr, thenBlock, elseBlock)).setType(ifType) - - case Match(scrut, cases) => - val scrutStats :+ scrutExpr = inline.transformToList(scrut) - val caseDefs = cases map { - case CaseDef(pat, guard, body) => - // extract local variables for all names bound in `pat`, and rewrite `body` - // to refer to these. - // TODO we can move this into ExprBuilder once we get rid of `AsyncDefinitionUseAnalyzer`. - val block = inline.transformToBlock(body) - val (valDefs, mappings) = (pat collect { - case b@Bind(name, _) => - val newName = newTermName(utils.name.fresh(name.toTermName + utils.name.bindSuffix)) - val vd = ValDef(NoMods, newName, TypeTree(), Ident(b.symbol)) - (vd, (b.symbol, newName)) - }).unzip - val Block(stats1, expr1) = utils.substituteNames(block, mappings.toMap).asInstanceOf[Block] - attachCopy(tree)(CaseDef(pat, guard, Block(valDefs ++ stats1, expr1))) - } - // Refer to comments the translation of `If` above. - val matchType = c.typeCheck(Match(scrut, caseDefs)).tpe - val typedMatch = attachCopy(tree)(Match(scrutExpr, caseDefs)).setType(tree.tpe) - scrutStats :+ typedMatch - - case LabelDef(name, params, rhs) => - List(LabelDef(name, params, Block(inline.transformToList(rhs), Literal(Constant(())))).setSymbol(tree.symbol)) - - case TypeApply(fun, targs) => - val funStats :+ simpleFun = inline.transformToList(fun) - funStats :+ attachCopy(tree)(TypeApply(simpleFun, targs).setSymbol(tree.symbol)) - - case _ => - List(tree) - } - } - } -} diff --git a/src/main/scala/scala/async/Async.scala b/src/main/scala/scala/async/Async.scala deleted file mode 100644 index 35d3687..0000000 --- a/src/main/scala/scala/async/Async.scala +++ /dev/null @@ -1,185 +0,0 @@ -/* - * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> - */ - -package scala.async - -import scala.language.experimental.macros -import scala.reflect.macros.Context -import scala.reflect.internal.annotations.compileTimeOnly - -object Async extends AsyncBase { - - import scala.concurrent.Future - - lazy val futureSystem = ScalaConcurrentFutureSystem - type FS = ScalaConcurrentFutureSystem.type - - def async[T](body: T) = macro asyncImpl[T] - - override def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[Future[T]] = super.asyncImpl[T](c)(body) -} - -object AsyncId extends AsyncBase { - lazy val futureSystem = IdentityFutureSystem - type FS = IdentityFutureSystem.type - - def async[T](body: T) = macro asyncImpl[T] - - override def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[T] = super.asyncImpl[T](c)(body) -} - -/** - * A base class for the `async` macro. Subclasses must provide: - * - * - Concrete types for a given future system - * - Tree manipulations to create and complete the equivalent of Future and Promise - * in that system. - * - The `async` macro declaration itself, and a forwarder for the macro implementation. - * (The latter is temporarily needed to workaround bug SI-6650 in the macro system) - * - * The default implementation, [[scala.async.Async]], binds the macro to `scala.concurrent._`. - */ -abstract class AsyncBase { - self => - - type FS <: FutureSystem - val futureSystem: FS - - /** - * A call to `await` must be nested in an enclosing `async` block. - * - * A call to `await` does not block the current thread, rather it is a delimiter - * used by the enclosing `async` macro. Code following the `await` - * call is executed asynchronously, when the argument of `await` has been completed. - * - * @param awaitable the future from which a value is awaited. - * @tparam T the type of that value. - * @return the value. - */ - @compileTimeOnly("`await` must be enclosed in an `async` block") - def await[T](awaitable: futureSystem.Fut[T]): T = ??? - - protected[async] def fallbackEnabled = false - - def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[futureSystem.Fut[T]] = { - import c.universe._ - - val analyzer = AsyncAnalysis[c.type](c, this) - val utils = TransformUtils[c.type](c) - import utils.{name, defn} - - analyzer.reportUnsupportedAwaits(body.tree) - - // Transform to A-normal form: - // - no await calls in qualifiers or arguments, - // - if/match only used in statement position. - val anfTree: Block = { - val anf = AnfTransform[c.type](c) - val restored = utils.restorePatternMatchingFunctions(body.tree) - val stats1 :+ expr1 = anf(restored) - val block = Block(stats1, expr1) - c.typeCheck(block).asInstanceOf[Block] - } - - // Analyze the block to find locals that will be accessed from multiple - // states of our generated state machine, e.g. a value assigned before - // an `await` and read afterwards. - val renameMap: Map[Symbol, TermName] = { - analyzer.defTreesUsedInSubsequentStates(anfTree).map { - vd => - (vd.symbol, name.fresh(vd.name.toTermName)) - }.toMap - } - - val builder = ExprBuilder[c.type, futureSystem.type](c, self.futureSystem, anfTree) - import builder.futureSystemOps - val asyncBlock: builder.AsyncBlock = builder.build(anfTree, renameMap) - import asyncBlock.asyncStates - logDiagnostics(c)(anfTree, asyncStates.map(_.toString)) - - // Important to retain the original declaration order here! - val localVarTrees = anfTree.collect { - case vd@ValDef(_, _, tpt, _) if renameMap contains vd.symbol => - utils.mkVarDefTree(tpt.tpe, renameMap(vd.symbol)) - case dd@DefDef(mods, name, tparams, vparamss, tpt, rhs) if renameMap contains dd.symbol => - DefDef(mods, renameMap(dd.symbol), tparams, vparamss, tpt, c.resetAllAttrs(utils.substituteNames(rhs, renameMap))) - } - - val onCompleteHandler = { - Function( - List(ValDef(Modifiers(Flag.PARAM), name.tr, TypeTree(defn.TryAnyType), EmptyTree)), - asyncBlock.onCompleteHandler) - } - val resumeFunTree = asyncBlock.resumeFunTree[T] - - val stateMachineType = utils.applied("scala.async.StateMachine", List(futureSystemOps.promType[T], futureSystemOps.execContextType)) - - lazy val stateMachine: ClassDef = { - val body: List[Tree] = { - val stateVar = ValDef(Modifiers(Flag.MUTABLE), name.state, TypeTree(definitions.IntTpe), Literal(Constant(0))) - val result = ValDef(NoMods, name.result, TypeTree(futureSystemOps.promType[T]), futureSystemOps.createProm[T].tree) - val execContext = ValDef(NoMods, name.execContext, TypeTree(), futureSystemOps.execContext.tree) - val applyDefDef: DefDef = { - val applyVParamss = List(List(ValDef(Modifiers(Flag.PARAM), name.tr, TypeTree(defn.TryAnyType), EmptyTree))) - val applyBody = asyncBlock.onCompleteHandler - DefDef(NoMods, name.apply, Nil, applyVParamss, TypeTree(definitions.UnitTpe), applyBody) - } - val apply0DefDef: DefDef = { - // We extend () => Unit so we can pass this class as the by-name argument to `Future.apply`. - // See SI-1247 for the the optimization that avoids creatio - val applyVParamss = List(List(ValDef(Modifiers(Flag.PARAM), name.tr, TypeTree(defn.TryAnyType), EmptyTree))) - val applyBody = asyncBlock.onCompleteHandler - DefDef(NoMods, name.apply, Nil, Nil, TypeTree(definitions.UnitTpe), Apply(Ident(name.resume), Nil)) - } - List(utils.emptyConstructor, stateVar, result, execContext) ++ localVarTrees ++ List(resumeFunTree, applyDefDef, apply0DefDef) - } - val template = { - Template(List(stateMachineType), emptyValDef, body) - } - ClassDef(NoMods, name.stateMachineT, Nil, template) - } - - def selectStateMachine(selection: TermName) = Select(Ident(name.stateMachine), selection) - - val code: c.Expr[futureSystem.Fut[T]] = { - val isSimple = asyncStates.size == 1 - val tree = - if (isSimple) - Block(Nil, futureSystemOps.spawn(body.tree)) // generate lean code for the simple case of `async { 1 + 1 }` - else { - Block(List[Tree]( - stateMachine, - ValDef(NoMods, name.stateMachine, stateMachineType, Apply(Select(New(Ident(name.stateMachineT)), nme.CONSTRUCTOR), Nil)), - futureSystemOps.spawn(Apply(selectStateMachine(name.apply), Nil)) - ), - futureSystemOps.promiseToFuture(c.Expr[futureSystem.Prom[T]](selectStateMachine(name.result))).tree) - } - c.Expr[futureSystem.Fut[T]](tree) - } - - AsyncUtils.vprintln(s"async state machine transform expands to:\n ${code.tree}") - code - } - - def logDiagnostics(c: Context)(anfTree: c.Tree, states: Seq[String]) { - def location = try { - c.macroApplication.pos.source.path - } catch { - case _: UnsupportedOperationException => - c.macroApplication.pos.toString - } - - AsyncUtils.vprintln(s"In file '$location':") - AsyncUtils.vprintln(s"${c.macroApplication}") - AsyncUtils.vprintln(s"ANF transform expands to:\n $anfTree") - states foreach (s => AsyncUtils.vprintln(s)) - } -} - -/** Internal class used by the `async` macro; should not be manually extended by client code */ -abstract class StateMachine[Result, EC] extends (scala.util.Try[Any] => Unit) with (() => Unit) { - def result$async: Result - - def execContext$async: EC -} diff --git a/src/main/scala/scala/async/AsyncAnalysis.scala b/src/main/scala/scala/async/AsyncAnalysis.scala deleted file mode 100644 index 4f55f1b..0000000 --- a/src/main/scala/scala/async/AsyncAnalysis.scala +++ /dev/null @@ -1,192 +0,0 @@ -/* - * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> - */ - -package scala.async - -import scala.reflect.macros.Context -import scala.collection.mutable - -private[async] final case class AsyncAnalysis[C <: Context](c: C, asyncBase: AsyncBase) { - import c.universe._ - - val utils = TransformUtils[c.type](c) - - import utils._ - - /** - * Analyze the contents of an `async` block in order to: - * - Report unsupported `await` calls under nested templates, functions, by-name arguments. - * - * Must be called on the original tree, not on the ANF transformed tree. - */ - def reportUnsupportedAwaits(tree: Tree): Boolean = { - val analyzer = new UnsupportedAwaitAnalyzer - analyzer.traverse(tree) - analyzer.hasUnsupportedAwaits - } - - /** - * Analyze the contents of an `async` block in order to: - * - Find which local `ValDef`-s need to be lifted to fields of the state machine, based - * on whether or not they are accessed only from a single state. - * - * Must be called on the ANF transformed tree. - */ - def defTreesUsedInSubsequentStates(tree: Tree): List[DefTree] = { - val analyzer = new AsyncDefinitionUseAnalyzer - analyzer.traverse(tree) - val liftable: List[DefTree] = (analyzer.valDefsToLift ++ analyzer.nestedMethodsToLift).toList.distinct - liftable - } - - private class UnsupportedAwaitAnalyzer extends AsyncTraverser { - var hasUnsupportedAwaits = false - - override def nestedClass(classDef: ClassDef) { - val kind = if (classDef.symbol.asClass.isTrait) "trait" else "class" - if (!reportUnsupportedAwait(classDef, s"nested $kind")) { - // do not allow local class definitions, because of SI-5467 (specific to case classes, though) - if (classDef.symbol.asClass.isCaseClass) - c.error(classDef.pos, s"Local case class ${classDef.name.decoded} illegal within `async` block") - } - } - - override def nestedModule(module: ModuleDef) { - if (!reportUnsupportedAwait(module, "nested object")) { - // local object definitions lead to spurious type errors (because of resetAllAttrs?) - c.error(module.pos, s"Local object ${module.name.decoded} illegal within `async` block") - } - } - - override def nestedMethod(module: DefDef) { - reportUnsupportedAwait(module, "nested method") - } - - override def byNameArgument(arg: Tree) { - reportUnsupportedAwait(arg, "by-name argument") - } - - override def function(function: Function) { - reportUnsupportedAwait(function, "nested function") - } - - override def patMatFunction(tree: Match) { - reportUnsupportedAwait(tree, "nested function") - } - - override def traverse(tree: Tree) { - def containsAwait = tree exists isAwait - tree match { - case Try(_, _, _) if containsAwait => - reportUnsupportedAwait(tree, "try/catch") - super.traverse(tree) - case Return(_) => - c.abort(tree.pos, "return is illegal within a async block") - case ValDef(mods, _, _, _) if mods.hasFlag(Flag.LAZY) => - c.abort(tree.pos, "lazy vals are illegal within an async block") - case _ => - super.traverse(tree) - } - } - - /** - * @return true, if the tree contained an unsupported await. - */ - private def reportUnsupportedAwait(tree: Tree, whyUnsupported: String): Boolean = { - val badAwaits: List[RefTree] = tree collect { - case rt: RefTree if isAwait(rt) => rt - } - badAwaits foreach { - tree => - reportError(tree.pos, s"await must not be used under a $whyUnsupported.") - } - badAwaits.nonEmpty - } - - private def reportError(pos: Position, msg: String) { - hasUnsupportedAwaits = true - if (!asyncBase.fallbackEnabled) - c.error(pos, msg) - } - } - - private class AsyncDefinitionUseAnalyzer extends AsyncTraverser { - private var chunkId = 0 - - private def nextChunk() = chunkId += 1 - - private var valDefChunkId = Map[Symbol, (ValDef, Int)]() - - val valDefsToLift : mutable.Set[ValDef] = collection.mutable.Set() - val nestedMethodsToLift: mutable.Set[DefDef] = collection.mutable.Set() - - override def nestedMethod(defDef: DefDef) { - nestedMethodsToLift += defDef - markReferencedVals(defDef) - } - - override def function(function: Function) { - markReferencedVals(function) - } - - override def patMatFunction(tree: Match) { - markReferencedVals(tree) - } - - private def markReferencedVals(tree: Tree) { - tree foreach { - case rt: RefTree => - valDefChunkId.get(rt.symbol) match { - case Some((vd, defChunkId)) => - valDefsToLift += vd // lift all vals referred to by nested functions. - case _ => - } - case _ => - } - } - - override def traverse(tree: Tree) = { - tree match { - case If(cond, thenp, elsep) if tree exists isAwait => - traverseChunks(List(cond, thenp, elsep)) - case Match(selector, cases) if tree exists isAwait => - traverseChunks(selector :: cases) - case LabelDef(name, params, rhs) if rhs exists isAwait => - traverseChunks(rhs :: Nil) - case Apply(fun, args) if isAwait(fun) => - super.traverse(tree) - nextChunk() - case vd: ValDef => - super.traverse(tree) - valDefChunkId += (vd.symbol -> (vd -> chunkId)) - val isPatternBinder = vd.name.toString.contains(name.bindSuffix) - if (isAwait(vd.rhs) || isPatternBinder) valDefsToLift += vd - case as: Assign => - if (isAwait(as.rhs)) { - assert(as.lhs.symbol != null, "internal error: null symbol for Assign tree:" + as + " " + as.lhs.symbol) - - // TODO test the orElse case, try to remove the restriction. - val (vd, defBlockId) = valDefChunkId.getOrElse(as.lhs.symbol, c.abort(as.pos, s"await may only be assigned to a var/val defined in the async block. ${as.lhs} ${as.lhs.symbol}")) - valDefsToLift += vd - } - super.traverse(tree) - case rt: RefTree => - valDefChunkId.get(rt.symbol) match { - case Some((vd, defChunkId)) if defChunkId != chunkId => - valDefsToLift += vd - case _ => - } - super.traverse(tree) - case _ => super.traverse(tree) - } - } - - private def traverseChunks(trees: List[Tree]) { - trees.foreach { - t => traverse(t); nextChunk() - } - } - } - -} diff --git a/src/main/scala/scala/async/AsyncBase.scala b/src/main/scala/scala/async/AsyncBase.scala new file mode 100644 index 0000000..ff04a57 --- /dev/null +++ b/src/main/scala/scala/async/AsyncBase.scala @@ -0,0 +1,23 @@ +/* + * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> + */ + +package scala.async + +import scala.language.experimental.macros +import scala.reflect.macros.Context +import scala.concurrent.{Future, ExecutionContext} +import scala.async.internal.{AsyncBase, ScalaConcurrentFutureSystem} + +object Async extends AsyncBase { + type FS = ScalaConcurrentFutureSystem.type + val futureSystem: FS = ScalaConcurrentFutureSystem + + def async[T](body: T)(implicit execContext: ExecutionContext): Future[T] = macro asyncImpl[T] + + override def asyncImpl[T: c.WeakTypeTag](c: Context) + (body: c.Expr[T]) + (execContext: c.Expr[futureSystem.ExecContext]): c.Expr[Future[T]] = { + super.asyncImpl[T](c)(body)(execContext) + } +} diff --git a/src/main/scala/scala/async/StateMachine.scala b/src/main/scala/scala/async/StateMachine.scala new file mode 100644 index 0000000..823df71 --- /dev/null +++ b/src/main/scala/scala/async/StateMachine.scala @@ -0,0 +1,12 @@ +/* + * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> + */ + +package scala.async + +/** Internal class used by the `async` macro; should not be manually extended by client code */ +abstract class StateMachine[Result, EC] extends (scala.util.Try[Any] => Unit) with (() => Unit) { + def result: Result + + def execContext: EC +} diff --git a/src/main/scala/scala/async/TransformUtils.scala b/src/main/scala/scala/async/TransformUtils.scala deleted file mode 100644 index ebd546f..0000000 --- a/src/main/scala/scala/async/TransformUtils.scala +++ /dev/null @@ -1,374 +0,0 @@ -/* - * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> - */ -package scala.async - -import scala.reflect.macros.Context -import reflect.ClassTag - -/** - * Utilities used in both `ExprBuilder` and `AnfTransform`. - */ -private[async] final case class TransformUtils[C <: Context](c: C) { - - import c.universe._ - - object name { - def suffix(string: String) = string + "$async" - - def suffixedName(prefix: String) = newTermName(suffix(prefix)) - - val state = suffixedName("state") - val result = suffixedName("result") - val resume = suffixedName("resume") - val execContext = suffixedName("execContext") - val stateMachine = newTermName(fresh("stateMachine")) - val stateMachineT = stateMachine.toTypeName - val apply = newTermName("apply") - val applyOrElse = newTermName("applyOrElse") - val tr = newTermName("tr") - val matchRes = "matchres" - val ifRes = "ifres" - val await = "await" - val bindSuffix = "$bind" - - def fresh(name: TermName): TermName = newTermName(fresh(name.toString)) - - def fresh(name: String): String = if (name.toString.contains("$")) name else c.fresh("" + name + "$") - } - - def defaultValue(tpe: Type): Literal = { - val defaultValue: Any = - if (tpe <:< definitions.BooleanTpe) false - else if (definitions.ScalaNumericValueClasses.exists(tpe <:< _.toType)) 0 - else if (tpe <:< definitions.AnyValTpe) 0 - else null - Literal(Constant(defaultValue)) - } - - def isAwait(fun: Tree) = - fun.symbol == defn.Async_await - - /** Replace all `Ident` nodes referring to one of the keys n `renameMap` with a node - * referring to the corresponding new name - */ - def substituteNames(tree: Tree, renameMap: Map[Symbol, Name]): Tree = { - val renamer = new Transformer { - override def transform(tree: Tree) = tree match { - case Ident(_) => (renameMap get tree.symbol).fold(tree)(Ident(_)) - case tt: TypeTree if tt.original != EmptyTree && tt.original != null => - // We also have to apply our renaming transform on originals of TypeTrees. - // TODO 2.10.1 Can we find a cleaner way? - val symTab = c.universe.asInstanceOf[reflect.internal.SymbolTable] - val tt1 = tt.asInstanceOf[symTab.TypeTree] - tt1.setOriginal(transform(tt.original).asInstanceOf[symTab.Tree]) - super.transform(tree) - case _ => super.transform(tree) - } - } - renamer.transform(tree) - } - - /** Descends into the regions of the tree that are subject to the - * translation to a state machine by `async`. When a nested template, - * function, or by-name argument is encountered, the descent stops, - * and `nestedClass` etc are invoked. - */ - trait AsyncTraverser extends Traverser { - def nestedClass(classDef: ClassDef) { - } - - def nestedModule(module: ModuleDef) { - } - - def nestedMethod(module: DefDef) { - } - - def byNameArgument(arg: Tree) { - } - - def function(function: Function) { - } - - def patMatFunction(tree: Match) { - } - - override def traverse(tree: Tree) { - tree match { - case cd: ClassDef => nestedClass(cd) - case md: ModuleDef => nestedModule(md) - case dd: DefDef => nestedMethod(dd) - case fun: Function => function(fun) - case m@Match(EmptyTree, _) => patMatFunction(m) // Pattern matching anonymous function under -Xoldpatmat of after `restorePatternMatchingFunctions` - case Applied(fun, targs, argss) if argss.nonEmpty => - val isInByName = isByName(fun) - for ((args, i) <- argss.zipWithIndex) { - for ((arg, j) <- args.zipWithIndex) { - if (!isInByName(i, j)) traverse(arg) - else byNameArgument(arg) - } - } - traverse(fun) - case _ => super.traverse(tree) - } - } - } - - private lazy val Boolean_ShortCircuits: Set[Symbol] = { - import definitions.BooleanClass - def BooleanTermMember(name: String) = BooleanClass.typeSignature.member(newTermName(name).encodedName) - val Boolean_&& = BooleanTermMember("&&") - val Boolean_|| = BooleanTermMember("||") - Set(Boolean_&&, Boolean_||) - } - - def isByName(fun: Tree): ((Int, Int) => Boolean) = { - if (Boolean_ShortCircuits contains fun.symbol) (i, j) => true - else { - val symtab = c.universe.asInstanceOf[reflect.internal.SymbolTable] - val paramss = fun.tpe.asInstanceOf[symtab.Type].paramss - val byNamess = paramss.map(_.map(_.isByNameParam)) - (i, j) => util.Try(byNamess(i)(j)).getOrElse(false) - } - } - def argName(fun: Tree): ((Int, Int) => String) = { - val symtab = c.universe.asInstanceOf[reflect.internal.SymbolTable] - val paramss = fun.tpe.asInstanceOf[symtab.Type].paramss - val namess = paramss.map(_.map(_.name.toString)) - (i, j) => util.Try(namess(i)(j)).getOrElse(s"arg_${i}_${j}") - } - - object Applied { - val symtab = c.universe.asInstanceOf[scala.reflect.internal.SymbolTable] - object treeInfo extends { - val global: symtab.type = symtab - } with reflect.internal.TreeInfo - - def unapply(tree: Tree): Some[(Tree, List[Tree], List[List[Tree]])] = { - val treeInfo.Applied(core, targs, argss) = tree.asInstanceOf[symtab.Tree] - Some((core.asInstanceOf[Tree], targs.asInstanceOf[List[Tree]], argss.asInstanceOf[List[List[Tree]]])) - } - } - - def statsAndExpr(tree: Tree): (List[Tree], Tree) = tree match { - case Block(stats, expr) => (stats, expr) - case _ => (List(tree), Literal(Constant(()))) - } - - def mkVarDefTree(resultType: Type, resultName: TermName): c.Tree = { - ValDef(Modifiers(Flag.MUTABLE), resultName, TypeTree(resultType), defaultValue(resultType)) - } - - def emptyConstructor: DefDef = { - val emptySuperCall = Apply(Select(Super(This(tpnme.EMPTY), tpnme.EMPTY), nme.CONSTRUCTOR), Nil) - DefDef(NoMods, nme.CONSTRUCTOR, List(), List(List()), TypeTree(), Block(List(emptySuperCall), c.literalUnit.tree)) - } - - def applied(className: String, types: List[Type]): AppliedTypeTree = - AppliedTypeTree(Ident(c.mirror.staticClass(className)), types.map(TypeTree(_))) - - object defn { - def mkList_apply[A](args: List[Expr[A]]): Expr[List[A]] = { - c.Expr(Apply(Ident(definitions.List_apply), args.map(_.tree))) - } - - def mkList_contains[A](self: Expr[List[A]])(elem: Expr[Any]) = reify(self.splice.contains(elem.splice)) - - def mkFunction_apply[A, B](self: Expr[Function1[A, B]])(arg: Expr[A]) = reify { - self.splice.apply(arg.splice) - } - - def mkAny_==(self: Expr[Any])(other: Expr[Any]) = reify { - self.splice == other.splice - } - - def mkTry_get[A](self: Expr[util.Try[A]]) = reify { - self.splice.get - } - - val Try_get = methodSym(reify((null: scala.util.Try[Any]).get)) - val Try_isFailure = methodSym(reify((null: scala.util.Try[Any]).isFailure)) - - val TryClass = c.mirror.staticClass("scala.util.Try") - val TryAnyType = appliedType(TryClass.toType, List(definitions.AnyTpe)) - val NonFatalClass = c.mirror.staticModule("scala.util.control.NonFatal") - - private def asyncMember(name: String) = { - val asyncMod = c.mirror.staticClass("scala.async.AsyncBase") - val tpe = asyncMod.asType.toType - tpe.member(newTermName(name)).ensuring(_ != NoSymbol) - } - - val Async_await = asyncMember("await") - } - - /** `termSym( (_: Foo).bar(null: A, null: B)` will return the symbol of `bar`, after overload resolution. */ - private def methodSym(apply: c.Expr[Any]): Symbol = { - val tree2: Tree = c.typeCheck(apply.tree) - tree2.collect { - case s: SymTree if s.symbol.isMethod => s.symbol - }.headOption.getOrElse(sys.error(s"Unable to find a method symbol in ${apply.tree}")) - } - - /** - * Using [[scala.reflect.api.Trees.TreeCopier]] copies more than we would like: - * we don't want to copy types and symbols to the new trees in some cases. - * - * Instead, we just copy positions and attachments. - */ - def attachCopy[T <: Tree](orig: Tree)(tree: T): tree.type = { - tree.setPos(orig.pos) - for (att <- orig.attachments.all) - tree.updateAttachment[Any](att)(ClassTag.apply[Any](att.getClass)) - tree - } - - def resetInternalAttrs(tree: Tree, internalSyms: List[Symbol]) = - new ResetInternalAttrs(internalSyms.toSet).transform(tree) - - /** - * Adaptation of [[scala.reflect.internal.Trees.ResetAttrs]] - * - * A transformer which resets symbol and tpe fields of all nodes in a given tree, - * with special treatment of: - * `TypeTree` nodes: are replaced by their original if it exists, otherwise tpe field is reset - * to empty if it started out empty or refers to local symbols (which are erased). - * `TypeApply` nodes: are deleted if type arguments end up reverted to empty - * - * `This` and `Ident` nodes referring to an external symbol are ''not'' reset. - */ - private final class ResetInternalAttrs(internalSyms: Set[Symbol]) extends Transformer { - - import language.existentials - - override def transform(tree: Tree): Tree = super.transform { - def isExternal = tree.symbol != NoSymbol && !internalSyms(tree.symbol) - - tree match { - case tpt: TypeTree => resetTypeTree(tpt) - case TypeApply(fn, args) - if args map transform exists (_.isEmpty) => transform(fn) - case EmptyTree => tree - case (_: Ident | _: This) if isExternal => tree // #35 Don't reset the symbol of Ident/This bound outside of the async block - case _ => resetTree(tree) - } - } - - private def resetTypeTree(tpt: TypeTree): Tree = { - if (tpt.original != null) - transform(tpt.original) - else if (tpt.tpe != null && tpt.asInstanceOf[symtab.TypeTree forSome {val symtab: reflect.internal.SymbolTable}].wasEmpty) { - val dupl = tpt.duplicate - dupl.tpe = null - dupl - } - else tpt - } - - private def resetTree(tree: Tree): Tree = { - val hasSymbol: Boolean = { - val reflectInternalTree = tree.asInstanceOf[symtab.Tree forSome {val symtab: reflect.internal.SymbolTable}] - reflectInternalTree.hasSymbol - } - val dupl = tree.duplicate - if (hasSymbol) - dupl.symbol = NoSymbol - dupl.tpe = null - dupl - } - } - - /** - * Replaces expressions of the form `{ new $anon extends PartialFunction[A, B] { ... ; def applyOrElse[..](...) = ... match <cases> }` - * with `Match(EmptyTree, cases`. - * - * This reverses the transformation performed in `Typers`, and works around non-idempotency of typechecking such trees. - */ - // TODO Reference JIRA issue. - final def restorePatternMatchingFunctions(tree: Tree) = - RestorePatternMatchingFunctions transform tree - - private object RestorePatternMatchingFunctions extends Transformer { - - import language.existentials - val DefaultCaseName: TermName = "defaultCase$" - - override def transform(tree: Tree): Tree = { - val SYNTHETIC = (1 << 21).toLong.asInstanceOf[FlagSet] - def isSynthetic(cd: ClassDef) = cd.mods hasFlag SYNTHETIC - - /** Is this pattern node a synthetic catch-all case, added during PartialFuction synthesis before we know - * whether the user provided cases are exhaustive. */ - def isSyntheticDefaultCase(cdef: CaseDef) = cdef match { - case CaseDef(Bind(DefaultCaseName, _), EmptyTree, _) => true - case _ => false - } - tree match { - case Block( - (cd@ClassDef(_, _, _, Template(_, _, body))) :: Nil, - Apply(Select(New(a), nme.CONSTRUCTOR), Nil)) if isSynthetic(cd) => - val restored = (body collectFirst { - case DefDef(_, /*name.apply | */ name.applyOrElse, _, _, _, Match(_, cases)) => - val nonSyntheticCases = cases.takeWhile(cdef => !isSyntheticDefaultCase(cdef)) - val transformedCases = super.transformStats(nonSyntheticCases, currentOwner).asInstanceOf[List[CaseDef]] - Match(EmptyTree, transformedCases) - }).getOrElse(c.abort(tree.pos, s"Internal Error: Unable to find original pattern matching cases in: $body")) - restored - case t => super.transform(t) - } - } - } - - def isSafeToInline(tree: Tree) = { - val symtab = c.universe.asInstanceOf[scala.reflect.internal.SymbolTable] - object treeInfo extends { - val global: symtab.type = symtab - } with reflect.internal.TreeInfo - val castTree = tree.asInstanceOf[symtab.Tree] - treeInfo.isExprSafeToInline(castTree) - } - - /** Map a list of arguments to: - * - A list of argument Trees - * - A list of auxillary results. - * - * The function unwraps and rewraps the `arg :_*` construct. - * - * @param args The original argument trees - * @param f A function from argument (with '_*' unwrapped) and argument index to argument. - * @tparam A The type of the auxillary result - */ - private def mapArguments[A](args: List[Tree])(f: (Tree, Int) => (A, Tree)): (List[A], List[Tree]) = { - args match { - case args :+ Typed(tree, Ident(tpnme.WILDCARD_STAR)) => - val (a, argExprs :+ lastArgExpr) = (args :+ tree).zipWithIndex.map(f.tupled).unzip - val exprs = argExprs :+ Typed(lastArgExpr, Ident(tpnme.WILDCARD_STAR)).setPos(lastArgExpr.pos) - (a, exprs) - case args => - args.zipWithIndex.map(f.tupled).unzip - } - } - - case class Arg(expr: Tree, isByName: Boolean, argName: String) - - /** - * Transform a list of argument lists, producing the transformed lists, and lists of auxillary - * results. - * - * The function `f` need not concern itself with varargs arguments e.g (`xs : _*`). It will - * receive `xs`, and it's result will be re-wrapped as `f(xs) : _*`. - * - * @param fun The function being applied - * @param argss The argument lists - * @return (auxillary results, mapped argument trees) - */ - def mapArgumentss[A](fun: Tree, argss: List[List[Tree]])(f: Arg => (A, Tree)): (List[List[A]], List[List[Tree]]) = { - val isByNamess: (Int, Int) => Boolean = isByName(fun) - val argNamess: (Int, Int) => String = argName(fun) - argss.zipWithIndex.map { case (args, i) => - mapArguments[A](args) { - (tree, j) => f(Arg(tree, isByNamess(i, j), argNamess(i, j))) - } - }.unzip - } -} diff --git a/src/main/scala/scala/async/continuations/AsyncBaseWithCPSFallback.scala b/src/main/scala/scala/async/continuations/AsyncBaseWithCPSFallback.scala index a669cfa..1a6ac87 100644 --- a/src/main/scala/scala/async/continuations/AsyncBaseWithCPSFallback.scala +++ b/src/main/scala/scala/async/continuations/AsyncBaseWithCPSFallback.scala @@ -9,8 +9,9 @@ import scala.language.experimental.macros import scala.reflect.macros.Context import scala.util.continuations._ +import scala.async.internal.{AsyncMacro, AsyncUtils} -trait AsyncBaseWithCPSFallback extends AsyncBase { +trait AsyncBaseWithCPSFallback extends internal.AsyncBase { /* Fall-back for `await` using CPS plugin. * @@ -22,27 +23,34 @@ trait AsyncBaseWithCPSFallback extends AsyncBase { /* Implements `async { ... }` using the CPS plugin. */ - protected def cpsBasedAsyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[futureSystem.Fut[T]] = { + protected def cpsBasedAsyncImpl[T: c.WeakTypeTag](c: Context) + (body: c.Expr[T]) + (execContext: c.Expr[futureSystem.ExecContext]): c.Expr[futureSystem.Fut[T]] = { import c.universe._ - def lookupMember(name: String) = { - val asyncTrait = c.mirror.staticClass("scala.async.continuations.AsyncBaseWithCPSFallback") + def lookupClassMember(clazz: String, name: String) = { + val asyncTrait = c.mirror.staticClass(clazz) val tpe = asyncTrait.asType.toType - tpe.member(newTermName(name)).ensuring(_ != NoSymbol) + tpe.member(newTermName(name)).ensuring(_ != NoSymbol, s"$clazz.$name") + } + def lookupObjectMember(clazz: String, name: String) = { + val moduleClass = c.mirror.staticModule(clazz).moduleClass + val tpe = moduleClass.asType.toType + tpe.member(newTermName(name)).ensuring(_ != NoSymbol, s"$clazz.$name") } AsyncUtils.vprintln("AsyncBaseWithCPSFallback.cpsBasedAsyncImpl") - val utils = TransformUtils[c.type](c) - val futureSystemOps = futureSystem.mkOps(c) - val awaitSym = utils.defn.Async_await - val awaitFallbackSym = lookupMember("awaitFallback") + val symTab = c.universe.asInstanceOf[reflect.internal.SymbolTable] + val futureSystemOps = futureSystem.mkOps(symTab) + val awaitSym = lookupObjectMember("scala.async.Async", "await") + val awaitFallbackSym = lookupClassMember("scala.async.continuations.AsyncBaseWithCPSFallback", "awaitFallback") // replace `await` invocations with `awaitFallback` invocations val awaitReplacer = new Transformer { override def transform(tree: Tree): Tree = tree match { case Apply(fun @ TypeApply(_, List(futArgTpt)), args) if fun.symbol == awaitSym => - val typeApp = treeCopy.TypeApply(fun, Ident(awaitFallbackSym), List(TypeTree(futArgTpt.tpe))) + val typeApp = treeCopy.TypeApply(fun, atPos(tree.pos)(Ident(awaitFallbackSym)), List(atPos(tree.pos)(TypeTree(futArgTpt.tpe)))) treeCopy.Apply(tree, typeApp, args.map(arg => c.resetAllAttrs(arg.duplicate))) case _ => super.transform(tree) @@ -60,10 +68,12 @@ trait AsyncBaseWithCPSFallback extends AsyncBase { }.asInstanceOf[Future[T]] */ + def spawn(expr: Tree) = futureSystemOps.spawn(expr.asInstanceOf[futureSystemOps.universe.Tree], execContext.tree.asInstanceOf[futureSystemOps.universe.Tree]).asInstanceOf[Tree] + val bodyWithFuture = { val tree = bodyWithAwaitFallback match { - case Block(stmts, expr) => Block(stmts, futureSystemOps.spawn(expr)) - case expr => futureSystemOps.spawn(expr) + case Block(stmts, expr) => Block(stmts, spawn(expr)) + case expr => spawn(expr) } c.Expr[futureSystem.Fut[Any]](c.resetAllAttrs(tree.duplicate)) } @@ -71,20 +81,22 @@ trait AsyncBaseWithCPSFallback extends AsyncBase { val bodyWithReset: c.Expr[futureSystem.Fut[Any]] = reify { reset { bodyWithFuture.splice } } - val bodyWithCast = futureSystemOps.castTo[T](bodyWithReset) + val bodyWithCast = futureSystemOps.castTo[T](bodyWithReset.asInstanceOf[futureSystemOps.universe.Expr[futureSystem.Fut[Any]]]).asInstanceOf[c.Expr[futureSystem.Fut[T]]] AsyncUtils.vprintln(s"CPS-based async transform expands to:\n${bodyWithCast.tree}") bodyWithCast } - override def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[futureSystem.Fut[T]] = { + override def asyncImpl[T: c.WeakTypeTag](c: Context) + (body: c.Expr[T]) + (execContext: c.Expr[futureSystem.ExecContext]): c.Expr[futureSystem.Fut[T]] = { AsyncUtils.vprintln("AsyncBaseWithCPSFallback.asyncImpl") - val analyzer = AsyncAnalysis[c.type](c, this) + val asyncMacro = AsyncMacro(c, futureSystem) - if (!analyzer.reportUnsupportedAwaits(body.tree)) - super.asyncImpl[T](c)(body) // no unsupported awaits + if (!asyncMacro.reportUnsupportedAwaits(body.tree.asInstanceOf[asyncMacro.global.Tree], report = fallbackEnabled)) + super.asyncImpl[T](c)(body)(execContext) // no unsupported awaits else - cpsBasedAsyncImpl[T](c)(body) // fallback to CPS + cpsBasedAsyncImpl[T](c)(body)(execContext) // fallback to CPS } } diff --git a/src/main/scala/scala/async/continuations/AsyncWithCPSFallback.scala b/src/main/scala/scala/async/continuations/AsyncWithCPSFallback.scala index fe6e1a6..e0da5aa 100644 --- a/src/main/scala/scala/async/continuations/AsyncWithCPSFallback.scala +++ b/src/main/scala/scala/async/continuations/AsyncWithCPSFallback.scala @@ -13,8 +13,13 @@ import scala.concurrent.Future trait AsyncWithCPSFallback extends AsyncBaseWithCPSFallback with ScalaConcurrentCPSFallback object AsyncWithCPSFallback extends AsyncWithCPSFallback { + import scala.concurrent.{ExecutionContext, Future} - def async[T](body: T) = macro asyncImpl[T] + def async[T](body: T)(implicit execContext: ExecutionContext): Future[T] = macro asyncImpl[T] - override def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[Future[T]] = super.asyncImpl[T](c)(body) + override def asyncImpl[T: c.WeakTypeTag](c: Context) + (body: c.Expr[T]) + (execContext: c.Expr[ExecutionContext]): c.Expr[Future[T]] = { + super.asyncImpl[T](c)(body)(execContext) + } } diff --git a/src/main/scala/scala/async/continuations/CPSBasedAsync.scala b/src/main/scala/scala/async/continuations/CPSBasedAsync.scala index 922d1ac..2003082 100644 --- a/src/main/scala/scala/async/continuations/CPSBasedAsync.scala +++ b/src/main/scala/scala/async/continuations/CPSBasedAsync.scala @@ -8,14 +8,17 @@ package continuations import scala.language.experimental.macros import scala.reflect.macros.Context -import scala.concurrent.Future +import scala.concurrent.{ExecutionContext, Future} trait CPSBasedAsync extends CPSBasedAsyncBase with ScalaConcurrentCPSFallback object CPSBasedAsync extends CPSBasedAsync { - def async[T](body: T) = macro asyncImpl[T] - - override def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[Future[T]] = super.asyncImpl[T](c)(body) + def async[T](body: T)(implicit execContext: ExecutionContext): Future[T] = macro asyncImpl[T] + override def asyncImpl[T: c.WeakTypeTag](c: Context) + (body: c.Expr[T]) + (execContext: c.Expr[ExecutionContext]): c.Expr[Future[T]] = { + super.asyncImpl[T](c)(body)(execContext) + } } diff --git a/src/main/scala/scala/async/continuations/CPSBasedAsyncBase.scala b/src/main/scala/scala/async/continuations/CPSBasedAsyncBase.scala index 4e8ec80..a350704 100644 --- a/src/main/scala/scala/async/continuations/CPSBasedAsyncBase.scala +++ b/src/main/scala/scala/async/continuations/CPSBasedAsyncBase.scala @@ -15,7 +15,9 @@ import scala.util.continuations._ */ trait CPSBasedAsyncBase extends AsyncBaseWithCPSFallback { - override def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[futureSystem.Fut[T]] = - super.cpsBasedAsyncImpl[T](c)(body) - + override def asyncImpl[T: c.WeakTypeTag](c: Context) + (body: c.Expr[T]) + (execContext: c.Expr[futureSystem.ExecContext]): c.Expr[futureSystem.Fut[T]] = { + super.cpsBasedAsyncImpl[T](c)(body)(execContext) + } } diff --git a/src/main/scala/scala/async/continuations/ScalaConcurrentCPSFallback.scala b/src/main/scala/scala/async/continuations/ScalaConcurrentCPSFallback.scala index 018ad05..f864ad6 100644 --- a/src/main/scala/scala/async/continuations/ScalaConcurrentCPSFallback.scala +++ b/src/main/scala/scala/async/continuations/ScalaConcurrentCPSFallback.scala @@ -7,6 +7,7 @@ package continuations import scala.util.continuations._ import scala.concurrent.{Future, Promise, ExecutionContext} +import scala.async.internal.ScalaConcurrentFutureSystem trait ScalaConcurrentCPSFallback { self: AsyncBaseWithCPSFallback => diff --git a/src/main/scala/scala/async/internal/AnfTransform.scala b/src/main/scala/scala/async/internal/AnfTransform.scala new file mode 100644 index 0000000..6aeaba3 --- /dev/null +++ b/src/main/scala/scala/async/internal/AnfTransform.scala @@ -0,0 +1,268 @@ + +/* + * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> + */ + +package scala.async.internal + +import scala.tools.nsc.Global +import scala.Predef._ + +private[async] trait AnfTransform { + self: AsyncMacro => + + import global._ + import reflect.internal.Flags._ + + def anfTransform(tree: Tree): Block = { + // Must prepend the () for issue #31. + val block = callSiteTyper.typedPos(tree.pos)(Block(List(Literal(Constant(()))), tree)).setType(tree.tpe) + + new SelectiveAnfTransform().transform(block) + } + + sealed abstract class AnfMode + + case object Anf extends AnfMode + + case object Linearizing extends AnfMode + + final class SelectiveAnfTransform extends MacroTypingTransformer { + var mode: AnfMode = Anf + + def blockToList(tree: Tree): List[Tree] = tree match { + case Block(stats, expr) => stats :+ expr + case t => t :: Nil + } + + def listToBlock(trees: List[Tree]): Block = trees match { + case trees @ (init :+ last) => + val pos = trees.map(_.pos).reduceLeft(_ union _) + Block(init, last).setType(last.tpe).setPos(pos) + } + + override def transform(tree: Tree): Block = { + def anfLinearize: Block = { + val trees: List[Tree] = mode match { + case Anf => anf._transformToList(tree) + case Linearizing => linearize._transformToList(tree) + } + listToBlock(trees) + } + tree match { + case _: ValDef | _: DefDef | _: Function | _: ClassDef | _: TypeDef => + atOwner(tree.symbol)(anfLinearize) + case _: ModuleDef => + atOwner(tree.symbol.moduleClass orElse tree.symbol)(anfLinearize) + case _ => + anfLinearize + } + } + + private object linearize { + def transformToList(tree: Tree): List[Tree] = { + mode = Linearizing; blockToList(transform(tree)) + } + + def transformToBlock(tree: Tree): Block = listToBlock(transformToList(tree)) + + def _transformToList(tree: Tree): List[Tree] = trace(tree) { + val stats :+ expr = anf.transformToList(tree) + expr match { + case Apply(fun, args) if isAwait(fun) => + val valDef = defineVal(name.await, expr, tree.pos) + stats :+ valDef :+ gen.mkAttributedStableRef(valDef.symbol).setType(tree.tpe).setPos(tree.pos) + + case If(cond, thenp, elsep) => + // if type of if-else is Unit don't introduce assignment, + // but add Unit value to bring it into form expected by async transform + if (expr.tpe =:= definitions.UnitTpe) { + stats :+ expr :+ localTyper.typedPos(expr.pos)(Literal(Constant(()))) + } else { + val varDef = defineVar(name.ifRes, expr.tpe, tree.pos) + def branchWithAssign(orig: Tree) = localTyper.typedPos(orig.pos) { + def cast(t: Tree) = mkAttributedCastPreservingAnnotations(t, varDef.symbol.tpe) + orig match { + case Block(thenStats, thenExpr) => Block(thenStats, Assign(Ident(varDef.symbol), cast(thenExpr))) + case _ => Assign(Ident(varDef.symbol), cast(orig)) + } + } + val ifWithAssign = treeCopy.If(tree, cond, branchWithAssign(thenp), branchWithAssign(elsep)).setType(definitions.UnitTpe) + stats :+ varDef :+ ifWithAssign :+ gen.mkAttributedStableRef(varDef.symbol).setType(tree.tpe).setPos(tree.pos) + } + + case Match(scrut, cases) => + // if type of match is Unit don't introduce assignment, + // but add Unit value to bring it into form expected by async transform + if (expr.tpe =:= definitions.UnitTpe) { + stats :+ expr :+ localTyper.typedPos(expr.pos)(Literal(Constant(()))) + } + else { + val varDef = defineVar(name.matchRes, expr.tpe, tree.pos) + def typedAssign(lhs: Tree) = + localTyper.typedPos(lhs.pos)(Assign(Ident(varDef.symbol), mkAttributedCastPreservingAnnotations(lhs, varDef.symbol.tpe))) + val casesWithAssign = cases map { + case cd@CaseDef(pat, guard, body) => + val newBody = body match { + case b@Block(caseStats, caseExpr) => treeCopy.Block(b, caseStats, typedAssign(caseExpr)).setType(definitions.UnitTpe) + case _ => typedAssign(body) + } + treeCopy.CaseDef(cd, pat, guard, newBody).setType(definitions.UnitTpe) + } + val matchWithAssign = treeCopy.Match(tree, scrut, casesWithAssign).setType(definitions.UnitTpe) + require(matchWithAssign.tpe != null, matchWithAssign) + stats :+ varDef :+ matchWithAssign :+ gen.mkAttributedStableRef(varDef.symbol).setPos(tree.pos).setType(tree.tpe) + } + case _ => + stats :+ expr + } + } + + private def defineVar(prefix: String, tp: Type, pos: Position): ValDef = { + val sym = currOwner.newTermSymbol(name.fresh(prefix), pos, MUTABLE | SYNTHETIC).setInfo(tp) + ValDef(sym, gen.mkZero(tp)).setType(NoType).setPos(pos) + } + } + + private object trace { + private var indent = -1 + + def indentString = " " * indent + + def apply[T](args: Any)(t: => T): T = { + def prefix = mode.toString.toLowerCase + indent += 1 + def oneLine(s: Any) = s.toString.replaceAll( """\n""", "\\\\n").take(127) + try { + AsyncUtils.trace(s"${indentString}$prefix(${oneLine(args)})") + val result = t + AsyncUtils.trace(s"${indentString}= ${oneLine(result)}") + result + } finally { + indent -= 1 + } + } + } + + private def defineVal(prefix: String, lhs: Tree, pos: Position): ValDef = { + val sym = currOwner.newTermSymbol(name.fresh(prefix), pos, SYNTHETIC).setInfo(lhs.tpe) + changeOwner(lhs, currentOwner, sym) + ValDef(sym, changeOwner(lhs, currentOwner, sym)).setType(NoType).setPos(pos) + } + + private object anf { + def transformToList(tree: Tree): List[Tree] = { + mode = Anf; blockToList(transform(tree)) + } + + def _transformToList(tree: Tree): List[Tree] = trace(tree) { + val containsAwait = tree exists isAwait + if (!containsAwait) { + List(tree) + } else tree match { + case Select(qual, sel) => + val stats :+ expr = linearize.transformToList(qual) + stats :+ treeCopy.Select(tree, expr, sel) + + case Throw(expr) => + val stats :+ expr1 = linearize.transformToList(expr) + stats :+ treeCopy.Throw(tree, expr1) + + case Typed(expr, tpt) => + val stats :+ expr1 = linearize.transformToList(expr) + stats :+ treeCopy.Typed(tree, expr1, tpt) + + case treeInfo.Applied(fun, targs, argss) if argss.nonEmpty => + // we an assume that no await call appears in a by-name argument position, + // this has already been checked. + val funStats :+ simpleFun = linearize.transformToList(fun) + val (argStatss, argExprss): (List[List[List[Tree]]], List[List[Tree]]) = + mapArgumentss[List[Tree]](fun, argss) { + case Arg(expr, byName, _) if byName /*|| isPure(expr) TODO */ => (Nil, expr) + case Arg(expr, _, argName) => + linearize.transformToList(expr) match { + case stats :+ expr1 => + val valDef = defineVal(argName, expr1, expr1.pos) + require(valDef.tpe != null, valDef) + val stats1 = stats :+ valDef + (stats1, atPos(tree.pos.makeTransparent)(gen.stabilize(gen.mkAttributedIdent(valDef.symbol)))) + } + } + + def copyApplied(tree: Tree, depth: Int): Tree = { + tree match { + case TypeApply(_, targs) => treeCopy.TypeApply(tree, simpleFun, targs) + case _ if depth == 0 => simpleFun + case Apply(fun, args) => + val newTypedArgs = map2(args.map(_.pos), argExprss(depth - 1))((pos, arg) => localTyper.typedPos(pos)(arg)) + treeCopy.Apply(tree, copyApplied(fun, depth - 1), newTypedArgs) + } + } + + val typedNewApply = copyApplied(tree, treeInfo.dissectApplied(tree).applyDepth) + + funStats ++ argStatss.flatten.flatten :+ typedNewApply + + case Block(stats, expr) => + (stats :+ expr).flatMap(linearize.transformToList) + + case ValDef(mods, name, tpt, rhs) => + if (rhs exists isAwait) { + val stats :+ expr = atOwner(currOwner.owner)(linearize.transformToList(rhs)) + stats.foreach(changeOwner(_, currOwner, currOwner.owner)) + stats :+ treeCopy.ValDef(tree, mods, name, tpt, expr) + } else List(tree) + + case Assign(lhs, rhs) => + val stats :+ expr = linearize.transformToList(rhs) + stats :+ treeCopy.Assign(tree, lhs, expr) + + case If(cond, thenp, elsep) => + val condStats :+ condExpr = linearize.transformToList(cond) + val thenBlock = linearize.transformToBlock(thenp) + val elseBlock = linearize.transformToBlock(elsep) + // Typechecking with `condExpr` as the condition fails if the condition + // contains an await. `ifTree.setType(tree.tpe)` also fails; it seems + // we rely on this call to `typeCheck` descending into the branches. + // But, we can get away with typechecking a throwaway `If` tree with the + // original scrutinee and the new branches, and setting that type on + // the real `If` tree. + val iff = treeCopy.If(tree, condExpr, thenBlock, elseBlock) + condStats :+ iff + + case Match(scrut, cases) => + val scrutStats :+ scrutExpr = linearize.transformToList(scrut) + val caseDefs = cases map { + case CaseDef(pat, guard, body) => + // extract local variables for all names bound in `pat`, and rewrite `body` + // to refer to these. + // TODO we can move this into ExprBuilder once we get rid of `AsyncDefinitionUseAnalyzer`. + val block = linearize.transformToBlock(body) + val (valDefs, mappings) = (pat collect { + case b@Bind(name, _) => + val vd = defineVal(name.toTermName + AnfTransform.this.name.bindSuffix, gen.mkAttributedStableRef(b.symbol).setPos(b.pos), b.pos) + (vd, (b.symbol, vd.symbol)) + }).unzip + val (from, to) = mappings.unzip + val b@Block(stats1, expr1) = block.substituteSymbols(from, to).asInstanceOf[Block] + val newBlock = treeCopy.Block(b, valDefs ++ stats1, expr1) + treeCopy.CaseDef(tree, pat, guard, newBlock) + } + // Refer to comments the translation of `If` above. + val typedMatch = treeCopy.Match(tree, scrutExpr, caseDefs) + scrutStats :+ typedMatch + + case LabelDef(name, params, rhs) => + List(LabelDef(name, params, Block(linearize.transformToList(rhs), Literal(Constant(())))).setSymbol(tree.symbol)) + + case TypeApply(fun, targs) => + val funStats :+ simpleFun = linearize.transformToList(fun) + funStats :+ treeCopy.TypeApply(tree, simpleFun, targs) + + case _ => + List(tree) + } + } + } + } +} diff --git a/src/main/scala/scala/async/internal/AsyncAnalysis.scala b/src/main/scala/scala/async/internal/AsyncAnalysis.scala new file mode 100644 index 0000000..122109e --- /dev/null +++ b/src/main/scala/scala/async/internal/AsyncAnalysis.scala @@ -0,0 +1,94 @@ +/* + * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> + */ + +package scala.async.internal + +import scala.reflect.macros.Context +import scala.collection.mutable + +trait AsyncAnalysis { + self: AsyncMacro => + + import global._ + + /** + * Analyze the contents of an `async` block in order to: + * - Report unsupported `await` calls under nested templates, functions, by-name arguments. + * + * Must be called on the original tree, not on the ANF transformed tree. + */ + def reportUnsupportedAwaits(tree: Tree, report: Boolean): Boolean = { + val analyzer = new UnsupportedAwaitAnalyzer(report) + analyzer.traverse(tree) + analyzer.hasUnsupportedAwaits + } + + private class UnsupportedAwaitAnalyzer(report: Boolean) extends AsyncTraverser { + var hasUnsupportedAwaits = false + + override def nestedClass(classDef: ClassDef) { + val kind = if (classDef.symbol.isTrait) "trait" else "class" + reportUnsupportedAwait(classDef, s"nested ${kind}") + } + + override def nestedModule(module: ModuleDef) { + reportUnsupportedAwait(module, "nested object") + } + + override def nestedMethod(defDef: DefDef) { + reportUnsupportedAwait(defDef, "nested method") + } + + override def byNameArgument(arg: Tree) { + reportUnsupportedAwait(arg, "by-name argument") + } + + override def function(function: Function) { + reportUnsupportedAwait(function, "nested function") + } + + override def patMatFunction(tree: Match) { + reportUnsupportedAwait(tree, "nested function") + } + + override def traverse(tree: Tree) { + def containsAwait = tree exists isAwait + tree match { + case Try(_, _, _) if containsAwait => + reportUnsupportedAwait(tree, "try/catch") + super.traverse(tree) + case Return(_) => + abort(tree.pos, "return is illegal within a async block") + case ValDef(mods, _, _, _) if mods.hasFlag(Flag.LAZY) => + // TODO lift this restriction + abort(tree.pos, "lazy vals are illegal within an async block") + case CaseDef(_, guard, _) if guard exists isAwait => + // TODO lift this restriction + reportUnsupportedAwait(tree, "pattern guard") + case _ => + super.traverse(tree) + } + } + + /** + * @return true, if the tree contained an unsupported await. + */ + private def reportUnsupportedAwait(tree: Tree, whyUnsupported: String): Boolean = { + val badAwaits: List[RefTree] = tree collect { + case rt: RefTree if isAwait(rt) => rt + } + badAwaits foreach { + tree => + reportError(tree.pos, s"await must not be used under a $whyUnsupported.") + } + badAwaits.nonEmpty + } + + private def reportError(pos: Position, msg: String) { + hasUnsupportedAwaits = true + if (report) + abort(pos, msg) + } + } +} diff --git a/src/main/scala/scala/async/internal/AsyncBase.scala b/src/main/scala/scala/async/internal/AsyncBase.scala new file mode 100644 index 0000000..ca06039 --- /dev/null +++ b/src/main/scala/scala/async/internal/AsyncBase.scala @@ -0,0 +1,61 @@ +/* + * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> + */ + +package scala.async.internal + +import scala.reflect.internal.annotations.compileTimeOnly +import scala.reflect.macros.Context + +/** + * A base class for the `async` macro. Subclasses must provide: + * + * - Concrete types for a given future system + * - Tree manipulations to create and complete the equivalent of Future and Promise + * in that system. + * - The `async` macro declaration itself, and a forwarder for the macro implementation. + * (The latter is temporarily needed to workaround bug SI-6650 in the macro system) + * + * The default implementation, [[scala.async.Async]], binds the macro to `scala.concurrent._`. + */ +abstract class AsyncBase { + self => + + type FS <: FutureSystem + val futureSystem: FS + + /** + * A call to `await` must be nested in an enclosing `async` block. + * + * A call to `await` does not block the current thread, rather it is a delimiter + * used by the enclosing `async` macro. Code following the `await` + * call is executed asynchronously, when the argument of `await` has been completed. + * + * @param awaitable the future from which a value is awaited. + * @tparam T the type of that value. + * @return the value. + */ + @compileTimeOnly("`await` must be enclosed in an `async` block") + def await[T](awaitable: futureSystem.Fut[T]): T = ??? + + protected[async] def fallbackEnabled = false + + def asyncImpl[T: c.WeakTypeTag](c: Context) + (body: c.Expr[T]) + (execContext: c.Expr[futureSystem.ExecContext]): c.Expr[futureSystem.Fut[T]] = { + import c.universe._ + + val asyncMacro = AsyncMacro(c, futureSystem) + + val code = asyncMacro.asyncTransform[T]( + body.tree.asInstanceOf[asyncMacro.global.Tree], + execContext.tree.asInstanceOf[asyncMacro.global.Tree], + fallbackEnabled)(implicitly[c.WeakTypeTag[T]].asInstanceOf[asyncMacro.global.WeakTypeTag[T]]).asInstanceOf[Tree] + + for (t <- code) + t.pos = t.pos.makeTransparent + + AsyncUtils.vprintln(s"async state machine transform expands to:\n ${code}") + c.Expr[futureSystem.Fut[T]](code) + } +} diff --git a/src/main/scala/scala/async/internal/AsyncId.scala b/src/main/scala/scala/async/internal/AsyncId.scala new file mode 100644 index 0000000..4334088 --- /dev/null +++ b/src/main/scala/scala/async/internal/AsyncId.scala @@ -0,0 +1,66 @@ +/* + * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> + */ + +package scala.async.internal + +import language.experimental.macros +import scala.reflect.macros.Context +import scala.reflect.internal.SymbolTable + +object AsyncId extends AsyncBase { + lazy val futureSystem = IdentityFutureSystem + type FS = IdentityFutureSystem.type + + def async[T](body: T) = macro asyncIdImpl[T] + + def asyncIdImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[T] = asyncImpl[T](c)(body)(c.literalUnit) +} + +/** + * A trivial implementation of [[FutureSystem]] that performs computations + * on the current thread. Useful for testing. + */ +object IdentityFutureSystem extends FutureSystem { + + class Prom[A] { + var a: A = _ + } + + type Fut[A] = A + type ExecContext = Unit + + def mkOps(c: SymbolTable): Ops {val universe: c.type} = new Ops { + val universe: c.type = c + + import universe._ + + def execContext: Expr[ExecContext] = Expr[Unit](Literal(Constant(()))) + + def promType[A: WeakTypeTag]: Type = weakTypeOf[Prom[A]] + def execContextType: Type = weakTypeOf[Unit] + + def createProm[A: WeakTypeTag]: Expr[Prom[A]] = reify { + new Prom() + } + + def promiseToFuture[A: WeakTypeTag](prom: Expr[Prom[A]]) = reify { + prom.splice.a + } + + def future[A: WeakTypeTag](t: Expr[A])(execContext: Expr[ExecContext]) = t + + def onComplete[A, U](future: Expr[Fut[A]], fun: Expr[scala.util.Try[A] => U], + execContext: Expr[ExecContext]): Expr[Unit] = reify { + fun.splice.apply(util.Success(future.splice)) + Expr[Unit](Literal(Constant(()))).splice + } + + def completeProm[A](prom: Expr[Prom[A]], value: Expr[scala.util.Try[A]]): Expr[Unit] = reify { + prom.splice.a = value.splice.get + Expr[Unit](Literal(Constant(()))).splice + } + + def castTo[A: WeakTypeTag](future: Expr[Fut[Any]]): Expr[Fut[A]] = ??? + } +} diff --git a/src/main/scala/scala/async/internal/AsyncMacro.scala b/src/main/scala/scala/async/internal/AsyncMacro.scala new file mode 100644 index 0000000..23cc611 --- /dev/null +++ b/src/main/scala/scala/async/internal/AsyncMacro.scala @@ -0,0 +1,32 @@ +package scala.async.internal + +import scala.tools.nsc.Global +import scala.tools.nsc.transform.TypingTransformers + +object AsyncMacro { + def apply(c: reflect.macros.Context, futureSystem0: FutureSystem): AsyncMacro = { + import language.reflectiveCalls + val powerContext = c.asInstanceOf[c.type {val universe: Global; val callsiteTyper: universe.analyzer.Typer}] + new AsyncMacro { + val global: powerContext.universe.type = powerContext.universe + val callSiteTyper: global.analyzer.Typer = powerContext.callsiteTyper + val futureSystem: futureSystem0.type = futureSystem0 + val futureSystemOps: futureSystem.Ops {val universe: global.type} = futureSystem0.mkOps(global) + val macroApplication: global.Tree = c.macroApplication.asInstanceOf[global.Tree] + } + } +} + +private[async] trait AsyncMacro + extends TypingTransformers + with AnfTransform with TransformUtils with Lifter + with ExprBuilder with AsyncTransform with AsyncAnalysis { + + val global: Global + val callSiteTyper: global.analyzer.Typer + val macroApplication: global.Tree + + def macroPos = macroApplication.pos.makeTransparent + def atMacroPos(t: global.Tree) = global.atPos(macroPos)(t) + +} diff --git a/src/main/scala/scala/async/internal/AsyncTransform.scala b/src/main/scala/scala/async/internal/AsyncTransform.scala new file mode 100644 index 0000000..c755c87 --- /dev/null +++ b/src/main/scala/scala/async/internal/AsyncTransform.scala @@ -0,0 +1,177 @@ +package scala.async.internal + +trait AsyncTransform { + self: AsyncMacro => + + import global._ + + def asyncTransform[T](body: Tree, execContext: Tree, cpsFallbackEnabled: Boolean) + (implicit resultType: WeakTypeTag[T]): Tree = { + + reportUnsupportedAwaits(body, report = !cpsFallbackEnabled) + + // Transform to A-normal form: + // - no await calls in qualifiers or arguments, + // - if/match only used in statement position. + val anfTree: Block = anfTransform(body) + + val resumeFunTreeDummyBody = DefDef(Modifiers(), name.resume, Nil, List(Nil), Ident(definitions.UnitClass), Literal(Constant(()))) + + val applyDefDefDummyBody: DefDef = { + val applyVParamss = List(List(ValDef(Modifiers(Flag.PARAM), name.tr, TypeTree(defn.TryAnyType), EmptyTree))) + DefDef(NoMods, name.apply, Nil, applyVParamss, TypeTree(definitions.UnitTpe), Literal(Constant(()))) + } + + val stateMachineType = applied("scala.async.StateMachine", List(futureSystemOps.promType[T], futureSystemOps.execContextType)) + + val stateMachine: ClassDef = { + val body: List[Tree] = { + val stateVar = ValDef(Modifiers(Flag.MUTABLE | Flag.PRIVATE | Flag.LOCAL), name.state, TypeTree(definitions.IntTpe), Literal(Constant(0))) + val result = ValDef(NoMods, name.result, TypeTree(futureSystemOps.promType[T]), futureSystemOps.createProm[T].tree) + val execContextValDef = ValDef(NoMods, name.execContext, TypeTree(), execContext) + + val apply0DefDef: DefDef = { + // We extend () => Unit so we can pass this class as the by-name argument to `Future.apply`. + // See SI-1247 for the the optimization that avoids creatio + DefDef(NoMods, name.apply, Nil, Nil, TypeTree(definitions.UnitTpe), Apply(Ident(name.resume), Nil)) + } + List(emptyConstructor, stateVar, result, execContextValDef) ++ List(resumeFunTreeDummyBody, applyDefDefDummyBody, apply0DefDef) + } + val template = { + Template(List(stateMachineType), emptyValDef, body) + } + val t = ClassDef(NoMods, name.stateMachineT, Nil, template) + callSiteTyper.typedPos(macroPos)(Block(t :: Nil, Literal(Constant(())))) + t + } + + val asyncBlock: AsyncBlock = { + val symLookup = new SymLookup(stateMachine.symbol, applyDefDefDummyBody.vparamss.head.head.symbol) + buildAsyncBlock(anfTree, symLookup) + } + + logDiagnostics(anfTree, asyncBlock.asyncStates.map(_.toString)) + + def startStateMachine: Tree = { + val stateMachineSpliced: Tree = spliceMethodBodies( + liftables(asyncBlock.asyncStates), + stateMachine, + atMacroPos(asyncBlock.onCompleteHandler[T]), + atMacroPos(asyncBlock.resumeFunTree[T].rhs) + ) + + def selectStateMachine(selection: TermName) = Select(Ident(name.stateMachine), selection) + + Block(List[Tree]( + stateMachineSpliced, + ValDef(NoMods, name.stateMachine, stateMachineType, Apply(Select(New(Ident(stateMachine.symbol)), nme.CONSTRUCTOR), Nil)), + futureSystemOps.spawn(Apply(selectStateMachine(name.apply), Nil), selectStateMachine(name.execContext)) + ), + futureSystemOps.promiseToFuture(Expr[futureSystem.Prom[T]](selectStateMachine(name.result))).tree) + } + + val isSimple = asyncBlock.asyncStates.size == 1 + if (isSimple) + futureSystemOps.spawn(body, execContext) // generate lean code for the simple case of `async { 1 + 1 }` + else + startStateMachine + } + + def logDiagnostics(anfTree: Tree, states: Seq[String]) { + def location = try { + macroPos.source.path + } catch { + case _: UnsupportedOperationException => + macroPos.toString + } + + AsyncUtils.vprintln(s"In file '$location':") + AsyncUtils.vprintln(s"${macroApplication}") + AsyncUtils.vprintln(s"ANF transform expands to:\n $anfTree") + states foreach (s => AsyncUtils.vprintln(s)) + } + + def spliceMethodBodies(liftables: List[Tree], tree: Tree, applyBody: Tree, + resumeBody: Tree): Tree = { + + val liftedSyms = liftables.map(_.symbol).toSet + val stateMachineClass = tree.symbol + liftedSyms.foreach { + sym => + if (sym != null) { + sym.owner = stateMachineClass + if (sym.isModule) + sym.moduleClass.owner = stateMachineClass + } + } + // Replace the ValDefs in the splicee with Assigns to the corresponding lifted + // fields. Similarly, replace references to them with references to the field. + // + // This transform will be only be run on the RHS of `def foo`. + class UseFields extends MacroTypingTransformer { + override def transform(tree: Tree): Tree = tree match { + case _ if currentOwner == stateMachineClass => + super.transform(tree) + case ValDef(_, _, _, rhs) if liftedSyms(tree.symbol) => + atOwner(currentOwner) { + val fieldSym = tree.symbol + val set = Assign(gen.mkAttributedStableRef(fieldSym.owner.thisType, fieldSym), transform(rhs)) + changeOwner(set, tree.symbol, currentOwner) + localTyper.typedPos(tree.pos)(set) + } + case _: DefTree if liftedSyms(tree.symbol) => + EmptyTree + case Ident(name) if liftedSyms(tree.symbol) => + val fieldSym = tree.symbol + atPos(tree.pos) { + gen.mkAttributedStableRef(fieldSym.owner.thisType, fieldSym).setType(tree.tpe) + } + case _ => + super.transform(tree) + } + } + + val liftablesUseFields = liftables.map { + case vd: ValDef => vd + case x => + val useField = new UseFields() + //.substituteSymbols(fromSyms, toSyms) + useField.atOwner(stateMachineClass)(useField.transform(x)) + } + + tree.children.foreach { + t => + new ChangeOwnerAndModuleClassTraverser(callSiteTyper.context.owner, tree.symbol).traverse(t) + } + val treeSubst = tree + + def fixup(dd: DefDef, body: Tree, ctx: analyzer.Context): Tree = { + val spliceeAnfFixedOwnerSyms = body + val useField = new UseFields() + val newRhs = useField.atOwner(dd.symbol)(useField.transform(spliceeAnfFixedOwnerSyms)) + val typer = global.analyzer.newTyper(ctx.make(dd, dd.symbol)) + treeCopy.DefDef(dd, dd.mods, dd.name, dd.tparams, dd.vparamss, dd.tpt, typer.typed(newRhs)) + } + + liftablesUseFields.foreach(t => if (t.symbol != null) stateMachineClass.info.decls.enter(t.symbol)) + + val result0 = transformAt(treeSubst) { + case t@Template(parents, self, stats) => + (ctx: analyzer.Context) => { + treeCopy.Template(t, parents, self, liftablesUseFields ++ stats) + } + } + val result = transformAt(result0) { + case dd@DefDef(_, name.apply, _, List(List(_)), _, _) if dd.symbol.owner == stateMachineClass => + (ctx: analyzer.Context) => + val typedTree = fixup(dd, changeOwner(applyBody, callSiteTyper.context.owner, dd.symbol), ctx) + typedTree + case dd@DefDef(_, name.resume, _, _, _, _) if dd.symbol.owner == stateMachineClass => + (ctx: analyzer.Context) => + val changed = changeOwner(resumeBody, callSiteTyper.context.owner, dd.symbol) + val res = fixup(dd, changed, ctx) + res + } + result + } +} diff --git a/src/main/scala/scala/async/AsyncUtils.scala b/src/main/scala/scala/async/internal/AsyncUtils.scala index 1ade5f0..8700bd6 100644 --- a/src/main/scala/scala/async/AsyncUtils.scala +++ b/src/main/scala/scala/async/internal/AsyncUtils.scala @@ -1,7 +1,7 @@ /* * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> */ -package scala.async +package scala.async.internal object AsyncUtils { diff --git a/src/main/scala/scala/async/ExprBuilder.scala b/src/main/scala/scala/async/internal/ExprBuilder.scala index ca46a83..e0da874 100644 --- a/src/main/scala/scala/async/ExprBuilder.scala +++ b/src/main/scala/scala/async/internal/ExprBuilder.scala @@ -1,23 +1,24 @@ /* * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> */ -package scala.async +package scala.async.internal import scala.reflect.macros.Context import scala.collection.mutable.ListBuffer import collection.mutable import language.existentials +import scala.reflect.api.Universe +import scala.reflect.api +import scala.Some -private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: C, futureSystem: FS, origTree: C#Tree) { - builder => +trait ExprBuilder { + builder: AsyncMacro => - val utils = TransformUtils[c.type](c) - - import c.universe._ - import utils._ + import global._ import defn._ - lazy val futureSystemOps = futureSystem.mkOps(c) + val futureSystem: FutureSystem + val futureSystemOps: futureSystem.Ops { val universe: global.type } val stateAssigner = new StateAssigner val labelDefStates = collection.mutable.Map[Symbol, Int]() @@ -27,22 +28,27 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: def mkHandlerCaseForState: CaseDef - def mkOnCompleteHandler[T: c.WeakTypeTag]: Option[CaseDef] = None + def mkOnCompleteHandler[T: WeakTypeTag]: Option[CaseDef] = None def stats: List[Tree] - final def body: c.Tree = stats match { + final def allStats: List[Tree] = this match { + case a: AsyncStateWithAwait => stats :+ a.awaitable.resultValDef + case _ => stats + } + + final def body: Tree = stats match { case stat :: Nil => stat case init :+ last => Block(init, last) } } /** A sequence of statements the concludes with a unconditional transition to `nextState` */ - final class SimpleAsyncState(val stats: List[Tree], val state: Int, nextState: Int) + final class SimpleAsyncState(val stats: List[Tree], val state: Int, nextState: Int, symLookup: SymLookup) extends AsyncState { def mkHandlerCaseForState: CaseDef = - mkHandlerCase(state, stats :+ mkStateTree(nextState) :+ mkResumeApply) + mkHandlerCase(state, stats :+ mkStateTree(nextState, symLookup) :+ mkResumeApply(symLookup)) override val toString: String = s"AsyncState #$state, next = $nextState" @@ -51,7 +57,7 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: /** A sequence of statements with a conditional transition to the next state, which will represent * a branch of an `if` or a `match`. */ - final class AsyncStateWithoutAwait(val stats: List[c.Tree], val state: Int) extends AsyncState { + final class AsyncStateWithoutAwait(val stats: List[Tree], val state: Int) extends AsyncState { override def mkHandlerCaseForState: CaseDef = mkHandlerCase(state, stats) @@ -62,25 +68,25 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: /** A sequence of statements that concludes with an `await` call. The `onComplete` * handler will unconditionally transition to `nestState`.`` */ - final class AsyncStateWithAwait(val stats: List[c.Tree], val state: Int, nextState: Int, - awaitable: Awaitable) + final class AsyncStateWithAwait(val stats: List[Tree], val state: Int, nextState: Int, + val awaitable: Awaitable, symLookup: SymLookup) extends AsyncState { override def mkHandlerCaseForState: CaseDef = { - val callOnComplete = futureSystemOps.onComplete(c.Expr(awaitable.expr), - c.Expr(This(tpnme.EMPTY)), c.Expr(Ident(name.execContext))).tree + val callOnComplete = futureSystemOps.onComplete(Expr(awaitable.expr), + Expr(This(tpnme.EMPTY)), Expr(Ident(name.execContext))).tree mkHandlerCase(state, stats :+ callOnComplete) } - override def mkOnCompleteHandler[T: c.WeakTypeTag]: Option[CaseDef] = { + override def mkOnCompleteHandler[T: WeakTypeTag]: Option[CaseDef] = { val tryGetTree = Assign( Ident(awaitable.resultName), - TypeApply(Select(Select(Ident(name.tr), Try_get), newTermName("asInstanceOf")), List(TypeTree(awaitable.resultType))) + TypeApply(Select(Select(Ident(symLookup.applyTrParam), Try_get), newTermName("asInstanceOf")), List(TypeTree(awaitable.resultType))) ) /* if (tr.isFailure) - * result$async.complete(tr.asInstanceOf[Try[T]]) + * result.complete(tr.asInstanceOf[Try[T]]) * else { * <resultName> = tr.get.asInstanceOf[<resultType>] * <nextState> @@ -88,13 +94,13 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: * } */ val ifIsFailureTree = - If(Select(Ident(name.tr), Try_isFailure), + If(Select(Ident(symLookup.applyTrParam), Try_isFailure), futureSystemOps.completeProm[T]( - c.Expr[futureSystem.Prom[T]](Ident(name.result)), - c.Expr[scala.util.Try[T]]( - TypeApply(Select(Ident(name.tr), newTermName("asInstanceOf")), + Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), + Expr[scala.util.Try[T]]( + TypeApply(Select(Ident(symLookup.applyTrParam), newTermName("asInstanceOf")), List(TypeTree(weakTypeOf[scala.util.Try[T]]))))).tree, - Block(List(tryGetTree, mkStateTree(nextState)), mkResumeApply) + Block(List(tryGetTree, mkStateTree(nextState, symLookup)), mkResumeApply(symLookup)) ) Some(mkHandlerCase(state, List(ifIsFailureTree))) @@ -107,19 +113,16 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: /* * Builder for a single state of an async method. */ - final class AsyncStateBuilder(state: Int, private val nameMap: Map[Symbol, c.Name]) { + final class AsyncStateBuilder(state: Int, private val symLookup: SymLookup) { /* Statements preceding an await call. */ - private val stats = ListBuffer[c.Tree]() + private val stats = ListBuffer[Tree]() /** The state of the target of a LabelDef application (while loop jump) */ private var nextJumpState: Option[Int] = None - private def renameReset(tree: Tree) = resetInternalAttrs(substituteNames(tree, nameMap)) - - def +=(stat: c.Tree): this.type = { + def +=(stat: Tree): this.type = { assert(nextJumpState.isEmpty, s"statement appeared after a label jump: $stat") - def addStat() = stats += renameReset(stat) + def addStat() = stats += stat stat match { - case _: DefDef => // these have been lifted. case Apply(fun, Nil) => labelDefStates get fun.symbol match { case Some(nextState) => nextJumpState = Some(nextState) @@ -132,22 +135,18 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: def resultWithAwait(awaitable: Awaitable, nextState: Int): AsyncState = { - val sanitizedAwaitable = awaitable.copy(expr = renameReset(awaitable.expr)) val effectiveNextState = nextJumpState.getOrElse(nextState) - new AsyncStateWithAwait(stats.toList, state, effectiveNextState, sanitizedAwaitable) + new AsyncStateWithAwait(stats.toList, state, effectiveNextState, awaitable, symLookup) } def resultSimple(nextState: Int): AsyncState = { val effectiveNextState = nextJumpState.getOrElse(nextState) - new SimpleAsyncState(stats.toList, state, effectiveNextState) + new SimpleAsyncState(stats.toList, state, effectiveNextState, symLookup) } - def resultWithIf(condTree: c.Tree, thenState: Int, elseState: Int): AsyncState = { - // 1. build changed if-else tree - // 2. insert that tree at the end of the current state - val cond = renameReset(condTree) - def mkBranch(state: Int) = Block(mkStateTree(state) :: Nil, mkResumeApply) - this += If(cond, mkBranch(thenState), mkBranch(elseState)) + def resultWithIf(condTree: Tree, thenState: Int, elseState: Int): AsyncState = { + def mkBranch(state: Int) = Block(mkStateTree(state, symLookup) :: Nil, mkResumeApply(symLookup)) + this += If(condTree, mkBranch(thenState), mkBranch(elseState)) new AsyncStateWithoutAwait(stats.toList, state) } @@ -161,23 +160,20 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: * @param caseStates starting state of the right-hand side of the each case * @return an `AsyncState` representing the match expression */ - def resultWithMatch(scrutTree: c.Tree, cases: List[CaseDef], caseStates: List[Int]): AsyncState = { + def resultWithMatch(scrutTree: Tree, cases: List[CaseDef], caseStates: List[Int], symLookup: SymLookup): AsyncState = { // 1. build list of changed cases val newCases = for ((cas, num) <- cases.zipWithIndex) yield cas match { case CaseDef(pat, guard, rhs) => - val bindAssigns = rhs.children.takeWhile(isSyntheticBindVal).map { - case ValDef(_, name, _, rhs) => Assign(Ident(name), rhs) - case t => sys.error(s"Unexpected tree. Expected ValDef, found: $t") - } - CaseDef(pat, guard, Block(bindAssigns :+ mkStateTree(caseStates(num)), mkResumeApply)) + val bindAssigns = rhs.children.takeWhile(isSyntheticBindVal) + CaseDef(pat, guard, Block(bindAssigns :+ mkStateTree(caseStates(num), symLookup), mkResumeApply(symLookup))) } // 2. insert changed match tree at the end of the current state - this += Match(renameReset(scrutTree), newCases) + this += Match(scrutTree, newCases) new AsyncStateWithoutAwait(stats.toList, state) } - def resultWithLabel(startLabelState: Int): AsyncState = { - this += Block(mkStateTree(startLabelState) :: Nil, mkResumeApply) + def resultWithLabel(startLabelState: Int, symLookup: SymLookup): AsyncState = { + this += Block(mkStateTree(startLabelState, symLookup) :: Nil, mkResumeApply(symLookup)) new AsyncStateWithoutAwait(stats.toList, state) } @@ -194,24 +190,22 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: * @param expr the last expression of the block * @param startState the start state * @param endState the state to continue with - * @param toRename a `Map` for renaming the given key symbols to the mangled value names */ - final private class AsyncBlockBuilder(stats: List[c.Tree], expr: c.Tree, startState: Int, endState: Int, - private val toRename: Map[Symbol, c.Name]) { + final private class AsyncBlockBuilder(stats: List[Tree], expr: Tree, startState: Int, endState: Int, + private val symLookup: SymLookup) { val asyncStates = ListBuffer[AsyncState]() - var stateBuilder = new AsyncStateBuilder(startState, toRename) + var stateBuilder = new AsyncStateBuilder(startState, symLookup) var currState = startState - /* TODO Fall back to CPS plug-in if tree contains an `await` call. */ - def checkForUnsupportedAwait(tree: c.Tree) = if (tree exists { + def checkForUnsupportedAwait(tree: Tree) = if (tree exists { case Apply(fun, _) if isAwait(fun) => true case _ => false - }) c.abort(tree.pos, "await must not be used in this position") //throw new FallbackToCpsException + }) abort(tree.pos, "await must not be used in this position") def nestedBlockBuilder(nestedTree: Tree, startState: Int, endState: Int) = { val (nestedStats, nestedExpr) = statsAndExpr(nestedTree) - new AsyncBlockBuilder(nestedStats, nestedExpr, startState, endState, toRename) + new AsyncBlockBuilder(nestedStats, nestedExpr, startState, endState, symLookup) } import stateAssigner.nextState @@ -219,16 +213,12 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: // populate asyncStates for (stat <- stats) stat match { // the val name = await(..) pattern - case ValDef(mods, name, tpt, Apply(fun, arg :: Nil)) if isAwait(fun) => + case vd @ ValDef(mods, name, tpt, Apply(fun, arg :: Nil)) if isAwait(fun) => val afterAwaitState = nextState() - val awaitable = Awaitable(arg, toRename(stat.symbol).toTermName, tpt.tpe) + val awaitable = Awaitable(arg, stat.symbol, tpt.tpe, vd) asyncStates += stateBuilder.resultWithAwait(awaitable, afterAwaitState) // complete with await currState = afterAwaitState - stateBuilder = new AsyncStateBuilder(currState, toRename) - - case ValDef(mods, name, tpt, rhs) if toRename contains stat.symbol => - checkForUnsupportedAwait(rhs) - stateBuilder += Assign(Ident(toRename(stat.symbol).toTermName), rhs) + stateBuilder = new AsyncStateBuilder(currState, symLookup) case If(cond, thenp, elsep) if stat exists isAwait => checkForUnsupportedAwait(cond) @@ -248,7 +238,7 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: } currState = afterIfState - stateBuilder = new AsyncStateBuilder(currState, toRename) + stateBuilder = new AsyncStateBuilder(currState, symLookup) case Match(scrutinee, cases) if stat exists isAwait => checkForUnsupportedAwait(scrutinee) @@ -257,7 +247,7 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: val afterMatchState = nextState() asyncStates += - stateBuilder.resultWithMatch(scrutinee, cases, caseStates) + stateBuilder.resultWithMatch(scrutinee, cases, caseStates, symLookup) for ((cas, num) <- cases.zipWithIndex) { val (stats, expr) = statsAndExpr(cas.body) @@ -267,18 +257,18 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: } currState = afterMatchState - stateBuilder = new AsyncStateBuilder(currState, toRename) + stateBuilder = new AsyncStateBuilder(currState, symLookup) case ld@LabelDef(name, params, rhs) if rhs exists isAwait => val startLabelState = nextState() val afterLabelState = nextState() - asyncStates += stateBuilder.resultWithLabel(startLabelState) + asyncStates += stateBuilder.resultWithLabel(startLabelState, symLookup) labelDefStates(ld.symbol) = startLabelState val builder = nestedBlockBuilder(rhs, startLabelState, afterLabelState) asyncStates ++= builder.asyncStates currState = afterLabelState - stateBuilder = new AsyncStateBuilder(currState, toRename) + stateBuilder = new AsyncStateBuilder(currState, symLookup) case _ => checkForUnsupportedAwait(stat) stateBuilder += stat @@ -292,17 +282,23 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: trait AsyncBlock { def asyncStates: List[AsyncState] - def onCompleteHandler[T: c.WeakTypeTag]: Tree + def onCompleteHandler[T: WeakTypeTag]: Tree + + def resumeFunTree[T]: DefDef + } - def resumeFunTree[T]: Tree + case class SymLookup(stateMachineClass: Symbol, applyTrParam: Symbol) { + def stateMachineMember(name: TermName): Symbol = + stateMachineClass.info.member(name) + def memberRef(name: TermName) = gen.mkAttributedRef(stateMachineMember(name)) } - def build(block: Block, toRename: Map[Symbol, c.Name]): AsyncBlock = { + def buildAsyncBlock(block: Block, symLookup: SymLookup): AsyncBlock = { val Block(stats, expr) = block val startState = stateAssigner.nextState() val endState = Int.MaxValue - val blockBuilder = new AsyncBlockBuilder(stats, expr, startState, endState, toRename) + val blockBuilder = new AsyncBlockBuilder(stats, expr, startState, endState, symLookup) new AsyncBlock { def asyncStates = blockBuilder.asyncStates.toList @@ -310,9 +306,9 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: def mkCombinedHandlerCases[T]: List[CaseDef] = { val caseForLastState: CaseDef = { val lastState = asyncStates.last - val lastStateBody = c.Expr[T](lastState.body) + val lastStateBody = Expr[T](lastState.body) val rhs = futureSystemOps.completeProm( - c.Expr[futureSystem.Prom[T]](Ident(name.result)), reify(scala.util.Success(lastStateBody.splice))) + Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), reify(scala.util.Success(lastStateBody.splice))) mkHandlerCase(lastState.state, rhs.tree) } asyncStates.toList match { @@ -327,18 +323,6 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: val initStates = asyncStates.init /** - * // assumes tr: Try[Any] is in scope. - * // - * state match { - * case 0 => { - * x11 = tr.get.asInstanceOf[Double]; - * state = 1; - * resume() - * } - */ - def onCompleteHandler[T: c.WeakTypeTag]: Tree = Match(Ident(name.state), initStates.flatMap(_.mkOnCompleteHandler[T]).toList) - - /** * def resume(): Unit = { * try { * state match { @@ -353,18 +337,31 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: * } * } */ - def resumeFunTree[T]: Tree = + def resumeFunTree[T]: DefDef = DefDef(Modifiers(), name.resume, Nil, List(Nil), Ident(definitions.UnitClass), Try( - Match(Ident(name.state), mkCombinedHandlerCases[T]), + Match(symLookup.memberRef(name.state), mkCombinedHandlerCases[T]), List( CaseDef( - Apply(Ident(defn.NonFatalClass), List(Bind(name.tr, Ident(nme.WILDCARD)))), - EmptyTree, + Bind(name.t, Ident(nme.WILDCARD)), + Apply(Ident(defn.NonFatalClass), List(Ident(name.t))), Block(List({ - val t = c.Expr[Throwable](Ident(name.tr)) - futureSystemOps.completeProm[T](c.Expr[futureSystem.Prom[T]](Ident(name.result)), reify(scala.util.Failure(t.splice))).tree - }), c.literalUnit.tree))), EmptyTree)) + val t = Expr[Throwable](Ident(name.t)) + futureSystemOps.completeProm[T]( + Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), reify(scala.util.Failure(t.splice))).tree + }), literalUnit))), EmptyTree)) + + /** + * // assumes tr: Try[Any] is in scope. + * // + * state match { + * case 0 => { + * x11 = tr.get.asInstanceOf[Double]; + * state = 1; + * resume() + * } + */ + def onCompleteHandler[T: WeakTypeTag]: Tree = Match(symLookup.memberRef(name.state), initStates.flatMap(_.mkOnCompleteHandler[T]).toList) } } @@ -373,22 +370,18 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: case _ => false } - private final case class Awaitable(expr: Tree, resultName: TermName, resultType: Type) - - private val internalSyms = origTree.collect { - case dt: DefTree => dt.symbol - } + case class Awaitable(expr: Tree, resultName: Symbol, resultType: Type, resultValDef: ValDef) - private def resetInternalAttrs(tree: Tree) = utils.resetInternalAttrs(tree, internalSyms) + private def mkResumeApply(symLookup: SymLookup) = Apply(symLookup.memberRef(name.resume), Nil) - private def mkResumeApply = Apply(Ident(name.resume), Nil) + private def mkStateTree(nextState: Int, symLookup: SymLookup): Tree = + Assign(symLookup.memberRef(name.state), Literal(Constant(nextState))) - private def mkStateTree(nextState: Int): c.Tree = - Assign(Ident(name.state), c.literal(nextState).tree) + private def mkHandlerCase(num: Int, rhs: List[Tree]): CaseDef = + mkHandlerCase(num, Block(rhs, literalUnit)) - private def mkHandlerCase(num: Int, rhs: List[c.Tree]): CaseDef = - mkHandlerCase(num, Block(rhs, c.literalUnit.tree)) + private def mkHandlerCase(num: Int, rhs: Tree): CaseDef = + CaseDef(Literal(Constant(num)), EmptyTree, rhs) - private def mkHandlerCase(num: Int, rhs: c.Tree): CaseDef = - CaseDef(c.literal(num).tree, EmptyTree, rhs) + private def literalUnit = Literal(Constant(())) } diff --git a/src/main/scala/scala/async/FutureSystem.scala b/src/main/scala/scala/async/internal/FutureSystem.scala index a050bec..101b7bf 100644 --- a/src/main/scala/scala/async/FutureSystem.scala +++ b/src/main/scala/scala/async/internal/FutureSystem.scala @@ -1,11 +1,12 @@ /* * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> */ -package scala.async +package scala.async.internal import scala.language.higherKinds import scala.reflect.macros.Context +import scala.reflect.internal.SymbolTable /** * An abstraction over a future system. @@ -14,7 +15,7 @@ import scala.reflect.macros.Context * customize the code generation. * * The API mirrors that of `scala.concurrent.Future`, see the instance - * [[scala.async.ScalaConcurrentFutureSystem]] for an example of how + * [[ScalaConcurrentFutureSystem]] for an example of how * to implement this. */ trait FutureSystem { @@ -26,12 +27,10 @@ trait FutureSystem { type ExecContext trait Ops { - val context: reflect.macros.Context + val universe: reflect.internal.SymbolTable - import context.universe._ - - /** Lookup the execution context, typically with an implicit search */ - def execContext: Expr[ExecContext] + import universe._ + def Expr[T: WeakTypeTag](tree: Tree): Expr[T] = universe.Expr[T](rootMirror, universe.FixedMirrorTreeCreator(rootMirror, tree)) def promType[A: WeakTypeTag]: Type def execContextType: Type @@ -52,13 +51,14 @@ trait FutureSystem { /** Complete a promise with a value */ def completeProm[A](prom: Expr[Prom[A]], value: Expr[scala.util.Try[A]]): Expr[Unit] - def spawn(tree: context.Tree): context.Tree = - future(context.Expr[Unit](tree))(execContext).tree + def spawn(tree: Tree, execContext: Tree): Tree = + future(Expr[Unit](tree))(Expr[ExecContext](execContext)).tree + // TODO Why is this needed? def castTo[A: WeakTypeTag](future: Expr[Fut[Any]]): Expr[Fut[A]] } - def mkOps(c: Context): Ops { val context: c.type } + def mkOps(c: SymbolTable): Ops { val universe: c.type } } object ScalaConcurrentFutureSystem extends FutureSystem { @@ -69,18 +69,13 @@ object ScalaConcurrentFutureSystem extends FutureSystem { type Fut[A] = Future[A] type ExecContext = ExecutionContext - def mkOps(c: Context): Ops {val context: c.type} = new Ops { - val context: c.type = c - - import context.universe._ + def mkOps(c: SymbolTable): Ops {val universe: c.type} = new Ops { + val universe: c.type = c - def execContext: Expr[ExecContext] = c.Expr(c.inferImplicitValue(c.weakTypeOf[ExecutionContext]) match { - case EmptyTree => c.abort(c.macroApplication.pos, "Unable to resolve implicit ExecutionContext") - case context => context - }) + import universe._ - def promType[A: WeakTypeTag]: Type = c.weakTypeOf[Promise[A]] - def execContextType: Type = c.weakTypeOf[ExecutionContext] + def promType[A: WeakTypeTag]: Type = weakTypeOf[Promise[A]] + def execContextType: Type = weakTypeOf[ExecutionContext] def createProm[A: WeakTypeTag]: Expr[Prom[A]] = reify { Promise[A]() @@ -101,7 +96,7 @@ object ScalaConcurrentFutureSystem extends FutureSystem { def completeProm[A](prom: Expr[Prom[A]], value: Expr[scala.util.Try[A]]): Expr[Unit] = reify { prom.splice.complete(value.splice) - context.literalUnit.splice + Expr[Unit](Literal(Constant(()))).splice } def castTo[A: WeakTypeTag](future: Expr[Fut[Any]]): Expr[Fut[A]] = reify { @@ -109,49 +104,3 @@ object ScalaConcurrentFutureSystem extends FutureSystem { } } } - -/** - * A trivial implementation of [[scala.async.FutureSystem]] that performs computations - * on the current thread. Useful for testing. - */ -object IdentityFutureSystem extends FutureSystem { - - class Prom[A](var a: A) - - type Fut[A] = A - type ExecContext = Unit - - def mkOps(c: Context): Ops {val context: c.type} = new Ops { - val context: c.type = c - - import context.universe._ - - def execContext: Expr[ExecContext] = c.literalUnit - - def promType[A: WeakTypeTag]: Type = c.weakTypeOf[Prom[A]] - def execContextType: Type = c.weakTypeOf[Unit] - - def createProm[A: WeakTypeTag]: Expr[Prom[A]] = reify { - new Prom(null.asInstanceOf[A]) - } - - def promiseToFuture[A: WeakTypeTag](prom: Expr[Prom[A]]) = reify { - prom.splice.a - } - - def future[A: WeakTypeTag](t: Expr[A])(execContext: Expr[ExecContext]) = t - - def onComplete[A, U](future: Expr[Fut[A]], fun: Expr[scala.util.Try[A] => U], - execContext: Expr[ExecContext]): Expr[Unit] = reify { - fun.splice.apply(util.Success(future.splice)) - context.literalUnit.splice - } - - def completeProm[A](prom: Expr[Prom[A]], value: Expr[scala.util.Try[A]]): Expr[Unit] = reify { - prom.splice.a = value.splice.get - context.literalUnit.splice - } - - def castTo[A: WeakTypeTag](future: Expr[Fut[Any]]): Expr[Fut[A]] = ??? - } -} diff --git a/src/main/scala/scala/async/internal/Lifter.scala b/src/main/scala/scala/async/internal/Lifter.scala new file mode 100644 index 0000000..f49dcbb --- /dev/null +++ b/src/main/scala/scala/async/internal/Lifter.scala @@ -0,0 +1,150 @@ +package scala.async.internal + +trait Lifter { + self: AsyncMacro => + import global._ + + /** + * Identify which DefTrees are used (including transitively) which are declared + * in some state but used (including transitively) in another state. + * + * These will need to be lifted to class members of the state machine. + */ + def liftables(asyncStates: List[AsyncState]): List[Tree] = { + object companionship { + private val companions = collection.mutable.Map[Symbol, Symbol]() + private val companionsInverse = collection.mutable.Map[Symbol, Symbol]() + private def record(sym1: Symbol, sym2: Symbol) { + companions(sym1) = sym2 + companions(sym2) = sym1 + } + + def record(defs: List[Tree]) { + // Keep note of local companions so we rename them consistently + // when lifting. + val comps = for { + cd@ClassDef(_, _, _, _) <- defs + md@ModuleDef(_, _, _) <- defs + if (cd.name.toTermName == md.name) + } record(cd.symbol, md.symbol) + } + def companionOf(sym: Symbol): Symbol = { + companions.get(sym).orElse(companionsInverse.get(sym)).getOrElse(NoSymbol) + } + } + + + val defs: Map[Tree, Int] = { + /** Collect the DefTrees directly enclosed within `t` that have the same owner */ + def collectDirectlyEnclosedDefs(t: Tree): List[DefTree] = t match { + case dt: DefTree => dt :: Nil + case _: Function => Nil + case t => + val childDefs = t.children.flatMap(collectDirectlyEnclosedDefs(_)) + companionship.record(childDefs) + childDefs + } + asyncStates.flatMap { + asyncState => + val defs = collectDirectlyEnclosedDefs(Block(asyncState.allStats: _*)) + defs.map((_, asyncState.state)) + }.toMap + } + + // In which block are these symbols defined? + val symToDefiningState: Map[Symbol, Int] = defs.map { + case (k, v) => (k.symbol, v) + } + + // The definitions trees + val symToTree: Map[Symbol, Tree] = defs.map { + case (k, v) => (k.symbol, k) + } + + // The direct references of each definition tree + val defSymToReferenced: Map[Symbol, List[Symbol]] = defs.keys.map { + case tree => (tree.symbol, tree.collect { + case rt: RefTree if symToDefiningState.contains(rt.symbol) => rt.symbol + }) + }.toMap + + // The direct references of each block, excluding references of `DefTree`-s which + // are already accounted for. + val stateIdToDirectlyReferenced: Map[Int, List[Symbol]] = { + val refs: List[(Int, Symbol)] = asyncStates.flatMap( + asyncState => asyncState.stats.filterNot(_.isDef).flatMap(_.collect { + case rt: RefTree if symToDefiningState.contains(rt.symbol) => (asyncState.state, rt.symbol) + }) + ) + toMultiMap(refs) + } + + def liftableSyms: Set[Symbol] = { + val liftableMutableSet = collection.mutable.Set[Symbol]() + def markForLift(sym: Symbol) { + if (!liftableMutableSet(sym)) { + liftableMutableSet += sym + + // Only mark transitive references of defs, modules and classes. The RHS of lifted vals/vars + // stays in its original location, so things that it refers to need not be lifted. + if (!(sym.isVal || sym.isVar)) + defSymToReferenced(sym).foreach(sym2 => markForLift(sym2)) + } + } + // Start things with DefTrees directly referenced from statements from other states... + val liftableStatementRefs: List[Symbol] = stateIdToDirectlyReferenced.toList.flatMap { + case (i, syms) => syms.filter(sym => symToDefiningState(sym) != i) + } + // .. and likewise for DefTrees directly referenced by other DefTrees from other states + val liftableRefsOfDefTrees = defSymToReferenced.toList.flatMap { + case (referee, referents) => referents.filter(sym => symToDefiningState(sym) != symToDefiningState(referee)) + } + // Mark these for lifting, which will follow transitive references. + (liftableStatementRefs ++ liftableRefsOfDefTrees).foreach(markForLift) + liftableMutableSet.toSet + } + + val lifted = liftableSyms.map(symToTree).toList.map { + case vd@ValDef(_, _, tpt, rhs) => + import reflect.internal.Flags._ + val sym = vd.symbol + sym.setFlag(MUTABLE | STABLE | PRIVATE | LOCAL) + sym.name = name.fresh(sym.name.toTermName) + sym.modifyInfo(_.deconst) + ValDef(vd.symbol, gen.mkZero(vd.symbol.info)).setPos(vd.pos) + case dd@DefDef(mods, name, tparams, vparamss, tpt, rhs) => + import reflect.internal.Flags._ + val sym = dd.symbol + sym.name = this.name.fresh(sym.name.toTermName) + sym.setFlag(PRIVATE | LOCAL) + DefDef(dd.symbol, rhs).setPos(dd.pos) + case cd@ClassDef(_, _, _, impl) => + import reflect.internal.Flags._ + val sym = cd.symbol + sym.name = newTypeName(name.fresh(sym.name.toString).toString) + companionship.companionOf(cd.symbol) match { + case NoSymbol => + case moduleSymbol => + moduleSymbol.name = sym.name.toTermName + moduleSymbol.moduleClass.name = moduleSymbol.name.toTypeName + } + ClassDef(cd.symbol, impl).setPos(cd.pos) + case md@ModuleDef(_, _, impl) => + import reflect.internal.Flags._ + val sym = md.symbol + companionship.companionOf(md.symbol) match { + case NoSymbol => + sym.name = name.fresh(sym.name.toTermName) + sym.moduleClass.name = sym.name.toTypeName + case classSymbol => // will be renamed by `case ClassDef` above. + } + ModuleDef(md.symbol, impl).setPos(md.pos) + case td@TypeDef(_, _, _, rhs) => + import reflect.internal.Flags._ + val sym = td.symbol + sym.name = newTypeName(name.fresh(sym.name.toString).toString) + TypeDef(td.symbol, rhs).setPos(td.pos) + } + lifted + } +} diff --git a/src/main/scala/scala/async/StateAssigner.scala b/src/main/scala/scala/async/internal/StateAssigner.scala index bc60a6d..cdde7a4 100644 --- a/src/main/scala/scala/async/StateAssigner.scala +++ b/src/main/scala/scala/async/internal/StateAssigner.scala @@ -2,7 +2,7 @@ * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> */ -package scala.async +package scala.async.internal private[async] final class StateAssigner { private var current = -1 @@ -11,4 +11,4 @@ private[async] final class StateAssigner { current += 1 current } -}
\ No newline at end of file +} diff --git a/src/main/scala/scala/async/internal/TransformUtils.scala b/src/main/scala/scala/async/internal/TransformUtils.scala new file mode 100644 index 0000000..70237bc --- /dev/null +++ b/src/main/scala/scala/async/internal/TransformUtils.scala @@ -0,0 +1,251 @@ +/* + * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> + */ +package scala.async.internal + +import scala.reflect.macros.Context +import reflect.ClassTag +import scala.reflect.macros.runtime.AbortMacroException + +/** + * Utilities used in both `ExprBuilder` and `AnfTransform`. + */ +private[async] trait TransformUtils { + self: AsyncMacro => + + import global._ + + object name { + val resume = newTermName("resume") + val apply = newTermName("apply") + val matchRes = "matchres" + val ifRes = "ifres" + val await = "await" + val bindSuffix = "$bind" + + val state = newTermName("state") + val result = newTermName("result") + val execContext = newTermName("execContext") + val stateMachine = newTermName(fresh("stateMachine")) + val stateMachineT = stateMachine.toTypeName + val tr = newTermName("tr") + val t = newTermName("throwable") + + def fresh(name: TermName): TermName = newTermName(fresh(name.toString)) + + def fresh(name: String): String = currentUnit.freshTermName("" + name + "$").toString + } + + def isAwait(fun: Tree) = + fun.symbol == defn.Async_await + + private lazy val Boolean_ShortCircuits: Set[Symbol] = { + import definitions.BooleanClass + def BooleanTermMember(name: String) = BooleanClass.typeSignature.member(newTermName(name).encodedName) + val Boolean_&& = BooleanTermMember("&&") + val Boolean_|| = BooleanTermMember("||") + Set(Boolean_&&, Boolean_||) + } + + private def isByName(fun: Tree): ((Int, Int) => Boolean) = { + if (Boolean_ShortCircuits contains fun.symbol) (i, j) => true + else { + val paramss = fun.tpe.paramss + val byNamess = paramss.map(_.map(_.isByNameParam)) + (i, j) => util.Try(byNamess(i)(j)).getOrElse(false) + } + } + private def argName(fun: Tree): ((Int, Int) => String) = { + val paramss = fun.tpe.paramss + val namess = paramss.map(_.map(_.name.toString)) + (i, j) => util.Try(namess(i)(j)).getOrElse(s"arg_${i}_${j}") + } + + def Expr[A: WeakTypeTag](t: Tree) = global.Expr[A](rootMirror, new FixedMirrorTreeCreator(rootMirror, t)) + + object defn { + def mkList_apply[A](args: List[Expr[A]]): Expr[List[A]] = { + Expr(Apply(Ident(definitions.List_apply), args.map(_.tree))) + } + + def mkList_contains[A](self: Expr[List[A]])(elem: Expr[Any]) = reify { + self.splice.contains(elem.splice) + } + + def mkFunction_apply[A, B](self: Expr[Function1[A, B]])(arg: Expr[A]) = reify { + self.splice.apply(arg.splice) + } + + def mkAny_==(self: Expr[Any])(other: Expr[Any]) = reify { + self.splice == other.splice + } + + def mkTry_get[A](self: Expr[util.Try[A]]) = reify { + self.splice.get + } + + val TryClass = rootMirror.staticClass("scala.util.Try") + val Try_get = TryClass.typeSignature.member(newTermName("get")).ensuring(_ != NoSymbol) + val Try_isFailure = TryClass.typeSignature.member(newTermName("isFailure")).ensuring(_ != NoSymbol) + val TryAnyType = appliedType(TryClass.toType, List(definitions.AnyTpe)) + val NonFatalClass = rootMirror.staticModule("scala.util.control.NonFatal") + val AsyncClass = rootMirror.staticClass("scala.async.internal.AsyncBase") + val Async_await = AsyncClass.typeSignature.member(newTermName("await")).ensuring(_ != NoSymbol) + } + + def isSafeToInline(tree: Tree) = { + treeInfo.isExprSafeToInline(tree) + } + + /** Map a list of arguments to: + * - A list of argument Trees + * - A list of auxillary results. + * + * The function unwraps and rewraps the `arg :_*` construct. + * + * @param args The original argument trees + * @param f A function from argument (with '_*' unwrapped) and argument index to argument. + * @tparam A The type of the auxillary result + */ + private def mapArguments[A](args: List[Tree])(f: (Tree, Int) => (A, Tree)): (List[A], List[Tree]) = { + args match { + case args :+ Typed(tree, Ident(tpnme.WILDCARD_STAR)) => + val (a, argExprs :+ lastArgExpr) = (args :+ tree).zipWithIndex.map(f.tupled).unzip + val exprs = argExprs :+ atPos(lastArgExpr.pos.makeTransparent)(Typed(lastArgExpr, Ident(tpnme.WILDCARD_STAR))) + (a, exprs) + case args => + args.zipWithIndex.map(f.tupled).unzip + } + } + + case class Arg(expr: Tree, isByName: Boolean, argName: String) + + /** + * Transform a list of argument lists, producing the transformed lists, and lists of auxillary + * results. + * + * The function `f` need not concern itself with varargs arguments e.g (`xs : _*`). It will + * receive `xs`, and it's result will be re-wrapped as `f(xs) : _*`. + * + * @param fun The function being applied + * @param argss The argument lists + * @return (auxillary results, mapped argument trees) + */ + def mapArgumentss[A](fun: Tree, argss: List[List[Tree]])(f: Arg => (A, Tree)): (List[List[A]], List[List[Tree]]) = { + val isByNamess: (Int, Int) => Boolean = isByName(fun) + val argNamess: (Int, Int) => String = argName(fun) + argss.zipWithIndex.map { case (args, i) => + mapArguments[A](args) { + (tree, j) => f(Arg(tree, isByNamess(i, j), argNamess(i, j))) + } + }.unzip + } + + + def statsAndExpr(tree: Tree): (List[Tree], Tree) = tree match { + case Block(stats, expr) => (stats, expr) + case _ => (List(tree), Literal(Constant(()))) + } + + def emptyConstructor: DefDef = { + val emptySuperCall = Apply(Select(Super(This(tpnme.EMPTY), tpnme.EMPTY), nme.CONSTRUCTOR), Nil) + DefDef(NoMods, nme.CONSTRUCTOR, List(), List(List()), TypeTree(), Block(List(emptySuperCall), Literal(Constant(())))) + } + + def applied(className: String, types: List[Type]): AppliedTypeTree = + AppliedTypeTree(Ident(rootMirror.staticClass(className)), types.map(TypeTree(_))) + + /** Descends into the regions of the tree that are subject to the + * translation to a state machine by `async`. When a nested template, + * function, or by-name argument is encountered, the descent stops, + * and `nestedClass` etc are invoked. + */ + trait AsyncTraverser extends Traverser { + def nestedClass(classDef: ClassDef) { + } + + def nestedModule(module: ModuleDef) { + } + + def nestedMethod(module: DefDef) { + } + + def byNameArgument(arg: Tree) { + } + + def function(function: Function) { + } + + def patMatFunction(tree: Match) { + } + + override def traverse(tree: Tree) { + tree match { + case cd: ClassDef => nestedClass(cd) + case md: ModuleDef => nestedModule(md) + case dd: DefDef => nestedMethod(dd) + case fun: Function => function(fun) + case m@Match(EmptyTree, _) => patMatFunction(m) // Pattern matching anonymous function under -Xoldpatmat of after `restorePatternMatchingFunctions` + case treeInfo.Applied(fun, targs, argss) if argss.nonEmpty => + val isInByName = isByName(fun) + for ((args, i) <- argss.zipWithIndex) { + for ((arg, j) <- args.zipWithIndex) { + if (!isInByName(i, j)) traverse(arg) + else byNameArgument(arg) + } + } + traverse(fun) + case _ => super.traverse(tree) + } + } + } + + def abort(pos: Position, msg: String) = throw new AbortMacroException(pos, msg) + + abstract class MacroTypingTransformer extends TypingTransformer(callSiteTyper.context.unit) { + currentOwner = callSiteTyper.context.owner + + def currOwner: Symbol = currentOwner + + localTyper = global.analyzer.newTyper(callSiteTyper.context.make(unit = callSiteTyper.context.unit)) + } + + def transformAt(tree: Tree)(f: PartialFunction[Tree, (analyzer.Context => Tree)]) = { + object trans extends MacroTypingTransformer { + override def transform(tree: Tree): Tree = { + if (f.isDefinedAt(tree)) { + f(tree)(localTyper.context) + } else super.transform(tree) + } + } + trans.transform(tree) + } + + def changeOwner(tree: Tree, oldOwner: Symbol, newOwner: Symbol): tree.type = { + new ChangeOwnerAndModuleClassTraverser(oldOwner, newOwner).traverse(tree) + tree + } + + class ChangeOwnerAndModuleClassTraverser(oldowner: Symbol, newowner: Symbol) + extends ChangeOwnerTraverser(oldowner, newowner) { + + override def traverse(tree: Tree) { + tree match { + case _: DefTree => change(tree.symbol.moduleClass) + case _ => + } + super.traverse(tree) + } + } + + def toMultiMap[A, B](as: Iterable[(A, B)]): Map[A, List[B]] = + as.toList.groupBy(_._1).mapValues(_.map(_._2).toList).toMap + + // Attributed version of `TreeGen#mkCastPreservingAnnotations` + def mkAttributedCastPreservingAnnotations(tree: Tree, tp: Type): Tree = { + atPos(tree.pos) { + val casted = gen.mkAttributedCast(tree, tp.withoutAnnotations.dealias) + Typed(casted, TypeTree(tp)).setType(tp) + } + } +} diff --git a/src/test/scala/scala/async/TreeInterrogation.scala b/src/test/scala/scala/async/TreeInterrogation.scala index deaee03..770c0f9 100644 --- a/src/test/scala/scala/async/TreeInterrogation.scala +++ b/src/test/scala/scala/async/TreeInterrogation.scala @@ -7,6 +7,7 @@ package scala.async import org.junit.runner.RunWith import org.junit.runners.JUnit4 import org.junit.Test +import scala.async.internal.AsyncId import AsyncId._ import tools.reflect.ToolBox @@ -15,9 +16,9 @@ class TreeInterrogation { @Test def `a minimal set of vals are lifted to vars`() { val cm = reflect.runtime.currentMirror - val tb = mkToolbox("-cp target/scala-2.10/classes") + val tb = mkToolbox(s"-cp ${toolboxClasspath}") val tree = tb.parse( - """| import _root_.scala.async.AsyncId._ + """| import _root_.scala.async.internal.AsyncId._ | async { | val x = await(1) | val y = x * 2 @@ -40,8 +41,7 @@ class TreeInterrogation { val varDefs = tree1.collect { case ValDef(mods, name, _, _) if mods.hasFlag(Flag.MUTABLE) => name } - varDefs.map(_.decoded.trim).toSet mustBe (Set("state$async", "await$1", "await$2")) - varDefs.map(_.decoded.trim).toSet mustBe (Set("state$async", "await$1", "await$2")) + varDefs.map(_.decoded.trim).toSet mustBe (Set("state", "await$1$1", "await$2$1")) val defDefs = tree1.collect { case t: Template => @@ -52,7 +52,7 @@ class TreeInterrogation { && !dd.symbol.asTerm.isAccessor && !dd.symbol.asTerm.isSetter => dd.name } }.flatten - defDefs.map(_.decoded.trim).toSet mustBe (Set("foo$1", "apply", "resume$async", "<init>")) + defDefs.map(_.decoded.trim).toSet mustBe (Set("foo$1", "apply", "resume", "<init>")) } } @@ -68,17 +68,15 @@ object TreeInterrogation extends App { withDebug { val cm = reflect.runtime.currentMirror - val tb = mkToolbox("-cp target/scala-2.10/classes -Xprint:flatten") + val tb = mkToolbox("-cp ${toolboxClasspath} -Xprint:typer -uniqid") import scala.async.Async._ val tree = tb.parse( - """ import _root_.scala.async.AsyncId.{async, await} - | def foo[T](a0: Int)(b0: Int*) = s"a0 = $a0, b0 = ${b0.head}" - | val res = async { - | var i = 0 - | def get = async {i += 1; i} - | foo[Int](await(get))(await(get) :: Nil : _*) + """ import _root_.scala.async.internal.AsyncId.{async, await} + | import reflect.runtime.universe._ + | async { + | implicit def view(a: Int): String = "" + | await(0).length | } - | res | """.stripMargin) println(tree) val tree1 = tb.typeCheck(tree.duplicate) diff --git a/src/test/scala/scala/async/neg/LocalClasses0Spec.scala b/src/test/scala/scala/async/neg/LocalClasses0Spec.scala index 2569303..6ebc9ca 100644 --- a/src/test/scala/scala/async/neg/LocalClasses0Spec.scala +++ b/src/test/scala/scala/async/neg/LocalClasses0Spec.scala @@ -5,121 +5,33 @@ package scala.async package neg -/** - * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> - */ - import org.junit.runner.RunWith import org.junit.runners.JUnit4 import org.junit.Test +import scala.async.internal.AsyncId @RunWith(classOf[JUnit4]) class LocalClasses0Spec { - @Test - def `reject a local class`() { - expectError("Local case class Person illegal within `async` block") { - """ - | import scala.concurrent.ExecutionContext.Implicits.global - | import scala.async.Async._ - | - | async { - | case class Person(name: String) - | } - """.stripMargin - } + def localClassCrashIssue16() { + import AsyncId.{async, await} + async { + class B { def f = 1 } + await(new B()).f + } mustBe 1 } @Test - def `reject a local class 2`() { - expectError("Local case class Person illegal within `async` block") { - """ - | import scala.concurrent.{Future, ExecutionContext} - | import ExecutionContext.Implicits.global - | import scala.async.Async._ - | - | async { - | case class Person(name: String) - | val fut = Future { 5 } - | val x = await(fut) - | x - | } - """.stripMargin - } + def nestedCaseClassAndModuleAllowed() { + import AsyncId.{await, async} + async { + trait Base { def base = 0} + await(0) + case class Person(name: String) extends Base + val fut = async { "bob" } + val x = Person(await(fut)) + x.base + x.name + } mustBe "bob" } - - @Test - def `reject a local class 3`() { - expectError("Local case class Person illegal within `async` block") { - """ - | import scala.concurrent.{Future, ExecutionContext} - | import ExecutionContext.Implicits.global - | import scala.async.Async._ - | - | async { - | val fut = Future { 5 } - | val x = await(fut) - | case class Person(name: String) - | x - | } - """.stripMargin - } - } - - @Test - def `reject a local class with symbols in its name`() { - expectError("Local case class :: illegal within `async` block") { - """ - | import scala.concurrent.{Future, ExecutionContext} - | import ExecutionContext.Implicits.global - | import scala.async.Async._ - | - | async { - | val fut = Future { 5 } - | val x = await(fut) - | case class ::(name: String) - | x - | } - """.stripMargin - } - } - - @Test - def `reject a nested local class`() { - expectError("Local case class Person illegal within `async` block") { - """ - | import scala.concurrent.{Future, ExecutionContext} - | import ExecutionContext.Implicits.global - | import scala.async.Async._ - | - | async { - | val fut = Future { 5 } - | val x = 2 + 2 - | var y = 0 - | if (x > 0) { - | case class Person(name: String) - | y = await(fut) - | } else { - | y = x - | } - | y - | } - """.stripMargin - } - } - - @Test - def `reject a local singleton object`() { - expectError("Local object Person illegal within `async` block") { - """ - | import scala.concurrent.ExecutionContext.Implicits.global - | import scala.async.Async._ - | - | async { - | object Person { val name = "Joe" } - | } - """.stripMargin - } - } - } diff --git a/src/test/scala/scala/async/neg/NakedAwait.scala b/src/test/scala/scala/async/neg/NakedAwait.scala index b0d5fde..ba388c5 100644 --- a/src/test/scala/scala/async/neg/NakedAwait.scala +++ b/src/test/scala/scala/async/neg/NakedAwait.scala @@ -25,7 +25,7 @@ class NakedAwait { def `await not allowed in by-name argument`() { expectError("await must not be used under a by-name argument.") { """ - | import _root_.scala.async.AsyncId._ + | import _root_.scala.async.internal.AsyncId._ | def foo(a: Int)(b: => Int) = 0 | async { foo(0)(await(0)) } """.stripMargin @@ -36,7 +36,7 @@ class NakedAwait { def `await not allowed in boolean short circuit argument 1`() { expectError("await must not be used under a by-name argument.") { """ - | import _root_.scala.async.AsyncId._ + | import _root_.scala.async.internal.AsyncId._ | async { true && await(false) } """.stripMargin } @@ -46,7 +46,7 @@ class NakedAwait { def `await not allowed in boolean short circuit argument 2`() { expectError("await must not be used under a by-name argument.") { """ - | import _root_.scala.async.AsyncId._ + | import _root_.scala.async.internal.AsyncId._ | async { true || await(false) } """.stripMargin } @@ -56,7 +56,7 @@ class NakedAwait { def nestedObject() { expectError("await must not be used under a nested object.") { """ - | import _root_.scala.async.AsyncId._ + | import _root_.scala.async.internal.AsyncId._ | async { object Nested { await(false) } } """.stripMargin } @@ -66,7 +66,7 @@ class NakedAwait { def nestedTrait() { expectError("await must not be used under a nested trait.") { """ - | import _root_.scala.async.AsyncId._ + | import _root_.scala.async.internal.AsyncId._ | async { trait Nested { await(false) } } """.stripMargin } @@ -76,7 +76,7 @@ class NakedAwait { def nestedClass() { expectError("await must not be used under a nested class.") { """ - | import _root_.scala.async.AsyncId._ + | import _root_.scala.async.internal.AsyncId._ | async { class Nested { await(false) } } """.stripMargin } @@ -86,7 +86,7 @@ class NakedAwait { def nestedFunction() { expectError("await must not be used under a nested function.") { """ - | import _root_.scala.async.AsyncId._ + | import _root_.scala.async.internal.AsyncId._ | async { () => { await(false) } } """.stripMargin } @@ -96,7 +96,7 @@ class NakedAwait { def nestedPatMatFunction() { expectError("await must not be used under a nested class.") { // TODO more specific error message """ - | import _root_.scala.async.AsyncId._ + | import _root_.scala.async.internal.AsyncId._ | async { { case x => { await(false) } } : PartialFunction[Any, Any] } """.stripMargin } @@ -106,7 +106,7 @@ class NakedAwait { def tryBody() { expectError("await must not be used under a try/catch.") { """ - | import _root_.scala.async.AsyncId._ + | import _root_.scala.async.internal.AsyncId._ | async { try { await(false) } catch { case _ => } } """.stripMargin } @@ -116,7 +116,7 @@ class NakedAwait { def catchBody() { expectError("await must not be used under a try/catch.") { """ - | import _root_.scala.async.AsyncId._ + | import _root_.scala.async.internal.AsyncId._ | async { try { () } catch { case _ => await(false) } } """.stripMargin } @@ -126,17 +126,27 @@ class NakedAwait { def finallyBody() { expectError("await must not be used under a try/catch.") { """ - | import _root_.scala.async.AsyncId._ + | import _root_.scala.async.internal.AsyncId._ | async { try { () } finally { await(false) } } """.stripMargin } } @Test + def guard() { + expectError("await must not be used under a pattern guard.") { + """ + | import _root_.scala.async.internal.AsyncId._ + | async { 1 match { case _ if await(true) => } } + """.stripMargin + } + } + + @Test def nestedMethod() { expectError("await must not be used under a nested method.") { """ - | import _root_.scala.async.AsyncId._ + | import _root_.scala.async.internal.AsyncId._ | async { def foo = await(false) } """.stripMargin } @@ -146,7 +156,7 @@ class NakedAwait { def returnIllegal() { expectError("return is illegal") { """ - | import _root_.scala.async.AsyncId._ + | import _root_.scala.async.internal.AsyncId._ | def foo(): Any = async { return false } | () | @@ -158,7 +168,7 @@ class NakedAwait { def lazyValIllegal() { expectError("lazy vals are illegal") { """ - | import _root_.scala.async.AsyncId._ + | import _root_.scala.async.internal.AsyncId._ | def foo(): Any = async { val x = { lazy val y = 0; y } } | () | diff --git a/src/test/scala/scala/async/package.scala b/src/test/scala/scala/async/package.scala index 4a7a958..7c42024 100644 --- a/src/test/scala/scala/async/package.scala +++ b/src/test/scala/scala/async/package.scala @@ -42,7 +42,22 @@ package object async { m.mkToolBox(options = compileOptions) } - def expectError(errorSnippet: String, compileOptions: String = "", baseCompileOptions: String = "-cp target/scala-2.10/classes")(code: String) { + def scalaBinaryVersion: String = { + val Pattern = """(\d+\.\d+)\..*""".r + scala.util.Properties.versionNumberString match { + case Pattern(v) => v + case _ => "" + } + } + + def toolboxClasspath = { + val f = new java.io.File(s"target/scala-${scalaBinaryVersion}/classes") + if (!f.exists) sys.error(s"output directory ${f.getAbsolutePath} does not exist.") + f.getAbsolutePath + } + + def expectError(errorSnippet: String, compileOptions: String = "", + baseCompileOptions: String = s"-cp ${toolboxClasspath}")(code: String) { intercept[ToolBoxError] { eval(code, compileOptions + " " + baseCompileOptions) }.getMessage mustContain errorSnippet diff --git a/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala index 7be6299..c8cec28 100644 --- a/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala +++ b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala @@ -13,6 +13,7 @@ import scala.async.Async.{async, await} import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.JUnit4 +import scala.async.internal.AsyncId class AnfTestClass { @@ -114,8 +115,6 @@ class AnfTransformSpec { @Test def `inlining block does not produce duplicate definition`() { - import scala.async.AsyncId - AsyncId.async { val f = 12 val x = AsyncId.await(f) @@ -132,8 +131,6 @@ class AnfTransformSpec { @Test def `inlining block in tail position does not produce duplicate definition`() { - import scala.async.AsyncId - AsyncId.async { val f = 12 val x = AsyncId.await(f) @@ -176,7 +173,7 @@ class AnfTransformSpec { @Test def nestedAwaitAsBareExpression() { import ExecutionContext.Implicits.global - import _root_.scala.async.AsyncId.{async, await} + import AsyncId.{async, await} val result = async { await(await("").isEmpty) } @@ -186,7 +183,7 @@ class AnfTransformSpec { @Test def nestedAwaitInBlock() { import ExecutionContext.Implicits.global - import _root_.scala.async.AsyncId.{async, await} + import AsyncId.{async, await} val result = async { () await(await("").isEmpty) @@ -197,7 +194,7 @@ class AnfTransformSpec { @Test def nestedAwaitInIf() { import ExecutionContext.Implicits.global - import _root_.scala.async.AsyncId.{async, await} + import AsyncId.{async, await} val result = async { if ("".isEmpty) await(await("").isEmpty) @@ -208,7 +205,7 @@ class AnfTransformSpec { @Test def byNameExpressionsArentLifted() { - import _root_.scala.async.AsyncId.{async, await} + import AsyncId.{async, await} def foo(ignored: => Any, b: Int) = b val result = async { foo(???, await(1)) @@ -218,7 +215,7 @@ class AnfTransformSpec { @Test def evaluationOrderRespected() { - import scala.async.AsyncId.{async, await} + import AsyncId.{async, await} def foo(a: Int, b: Int) = (a, b) val result = async { var i = 0 @@ -233,19 +230,19 @@ class AnfTransformSpec { @Test def awaitInNonPrimaryParamSection1() { - import _root_.scala.async.AsyncId.{async, await} + import AsyncId.{async, await} def foo(a0: Int)(b0: Int) = s"a0 = $a0, b0 = $b0" val res = async { var i = 0 def get = {i += 1; i} - foo(get)(get) + foo(get)(await(get)) } res mustBe "a0 = 1, b0 = 2" } @Test def awaitInNonPrimaryParamSection2() { - import _root_.scala.async.AsyncId.{async, await} + import AsyncId.{async, await} def foo[T](a0: Int)(b0: Int*) = s"a0 = $a0, b0 = ${b0.head}" val res = async { var i = 0 @@ -257,7 +254,7 @@ class AnfTransformSpec { @Test def awaitInNonPrimaryParamSectionWithLazy1() { - import _root_.scala.async.AsyncId.{async, await} + import AsyncId.{async, await} def foo[T](a: => Int)(b: Int) = b val res = async { def get = async {0} @@ -268,7 +265,7 @@ class AnfTransformSpec { @Test def awaitInNonPrimaryParamSectionWithLazy2() { - import _root_.scala.async.AsyncId.{async, await} + import AsyncId.{async, await} def foo[T](a: Int)(b: => Int) = a val res = async { def get = async {0} @@ -279,7 +276,7 @@ class AnfTransformSpec { @Test def awaitWithLazy() { - import _root_.scala.async.AsyncId.{async, await} + import AsyncId.{async, await} def foo[T](a: Int, b: => Int) = a val res = async { def get = async {0} @@ -290,7 +287,7 @@ class AnfTransformSpec { @Test def awaitOkInReciever() { - import scala.async.AsyncId.{async, await} + import AsyncId.{async, await} class Foo { def bar(a: Int)(b: Int) = a + b } async { await(async(new Foo)).bar(1)(2) @@ -299,7 +296,7 @@ class AnfTransformSpec { @Test def namedArgumentsRespectEvaluationOrder() { - import scala.async.AsyncId.{async, await} + import AsyncId.{async, await} def foo(a: Int, b: Int) = (a, b) val result = async { var i = 0 @@ -314,7 +311,7 @@ class AnfTransformSpec { @Test def namedAndDefaultArgumentsRespectEvaluationOrder() { - import scala.async.AsyncId.{async, await} + import AsyncId.{async, await} var i = 0 def next() = { i += 1; @@ -332,7 +329,7 @@ class AnfTransformSpec { @Test def repeatedParams1() { - import scala.async.AsyncId.{async, await} + import AsyncId.{async, await} var i = 0 def foo(a: Int, b: Int*) = b.toList def id(i: Int) = i @@ -343,7 +340,7 @@ class AnfTransformSpec { @Test def repeatedParams2() { - import scala.async.AsyncId.{async, await} + import AsyncId.{async, await} var i = 0 def foo(a: Int, b: Int*) = b.toList def id(i: Int) = i @@ -351,4 +348,64 @@ class AnfTransformSpec { foo(await(0), List(id(1), id(2), id(3)): _*) } mustBe (List(1, 2, 3)) } + + @Test + def awaitInThrow() { + import _root_.scala.async.internal.AsyncId.{async, await} + intercept[Exception]( + async { + throw new Exception("msg: " + await(0)) + } + ).getMessage mustBe "msg: 0" + } + + @Test + def awaitInTyped() { + import _root_.scala.async.internal.AsyncId.{async, await} + async { + (("msg: " + await(0)): String).toString + } mustBe "msg: 0" + } + + + @Test + def awaitInAssign() { + import _root_.scala.async.internal.AsyncId.{async, await} + async { + var x = 0 + x = await(1) + x + } mustBe 1 + } + + @Test + def caseBodyMustBeTypedAsUnit() { + import _root_.scala.async.internal.AsyncId.{async, await} + val Up = 1 + val Down = 2 + val sign = async { + await(1) match { + case Up => 1.0 + case Down => -1.0 + } + } + sign mustBe 1.0 + } + + @Test + def awaitInImplicitApply() { + val tb = mkToolbox(s"-cp ${toolboxClasspath}") + val tree = tb.typeCheck(tb.parse { + """ + | import language.implicitConversions + | import _root_.scala.async.internal.AsyncId.{async, await} + | implicit def view(a: Int): String = "" + | async { + | await(0).length + | } + """.stripMargin + }) + val applyImplicitView = tree.collect { case x if x.getClass.getName.endsWith("ApplyImplicitView") => x } + applyImplicitView.map(_.toString) mustBe List("view(a$1)") + } } diff --git a/src/test/scala/scala/async/run/hygiene/Hygiene.scala b/src/test/scala/scala/async/run/hygiene/Hygiene.scala index 9d1df21..8081ee7 100644 --- a/src/test/scala/scala/async/run/hygiene/Hygiene.scala +++ b/src/test/scala/scala/async/run/hygiene/Hygiene.scala @@ -9,11 +9,12 @@ package hygiene import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.JUnit4 +import scala.async.internal.AsyncId @RunWith(classOf[JUnit4]) class HygieneSpec { - import scala.async.AsyncId.{async, await} + import AsyncId.{async, await} @Test def `is hygenic`() { diff --git a/src/test/scala/scala/async/run/ifelse0/IfElse0.scala b/src/test/scala/scala/async/run/ifelse0/IfElse0.scala index e2b1ca6..fc438a1 100644 --- a/src/test/scala/scala/async/run/ifelse0/IfElse0.scala +++ b/src/test/scala/scala/async/run/ifelse0/IfElse0.scala @@ -13,6 +13,7 @@ import scala.async.Async.{async, await} import org.junit.runner.RunWith import org.junit.runners.JUnit4 import org.junit.Test +import scala.async.internal.AsyncId class TestIfElseClass { diff --git a/src/test/scala/scala/async/run/ifelse0/WhileSpec.scala b/src/test/scala/scala/async/run/ifelse0/WhileSpec.scala index 1f1033a..b8d88fb 100644 --- a/src/test/scala/scala/async/run/ifelse0/WhileSpec.scala +++ b/src/test/scala/scala/async/run/ifelse0/WhileSpec.scala @@ -9,6 +9,7 @@ package ifelse0 import org.junit.runner.RunWith import org.junit.runners.JUnit4 import org.junit.Test +import scala.async.internal.AsyncId @RunWith(classOf[JUnit4]) class WhileSpec { @@ -64,4 +65,4 @@ class WhileSpec { } result mustBe (100) } -}
\ No newline at end of file +} diff --git a/src/test/scala/scala/async/run/match0/Match0.scala b/src/test/scala/scala/async/run/match0/Match0.scala index 7624838..7c392ab 100644 --- a/src/test/scala/scala/async/run/match0/Match0.scala +++ b/src/test/scala/scala/async/run/match0/Match0.scala @@ -13,6 +13,7 @@ import scala.async.Async.{async, await} import org.junit.runner.RunWith import org.junit.runners.JUnit4 import org.junit.Test +import scala.async.internal.AsyncId class TestMatchClass { @@ -111,4 +112,38 @@ class MatchSpec { } result mustBe (3) } + + @Test def duplicateBindName() { + import AsyncId.{async, await} + def m4(m: Any) = async { + m match { + case buf: String => + await(0) + case buf: Double => + await(2) + } + } + m4("") mustBe 0 + } + + @Test def bugCastBoxedUnitToStringMatch() { + import scala.async.internal.AsyncId.{async, await} + def foo = async { + val p2 = await(5) + "foo" match { + case p3: String => + p2.toString + } + } + foo mustBe "5" + } + + @Test def bugCastBoxedUnitToStringIf() { + import scala.async.internal.AsyncId.{async, await} + def foo = async { + val p2 = await(5) + if (true) p2.toString else p2.toString + } + foo mustBe "5" + } } diff --git a/src/test/scala/scala/async/run/nesteddef/NestedDef.scala b/src/test/scala/scala/async/run/nesteddef/NestedDef.scala index ee0a78e..409f70a 100644 --- a/src/test/scala/scala/async/run/nesteddef/NestedDef.scala +++ b/src/test/scala/scala/async/run/nesteddef/NestedDef.scala @@ -5,6 +5,7 @@ package nesteddef import org.junit.runner.RunWith import org.junit.runners.JUnit4 import org.junit.Test +import scala.async.internal.AsyncId @RunWith(classOf[JUnit4]) class NestedDef { @@ -37,4 +38,60 @@ class NestedDef { } result mustBe ((0d, 44d, 2)) } + + // We must lift `foo` and `bar` in the next two tests. + @Test + def nestedDefTransitive1() { + import AsyncId._ + val result = async { + val a = 0 + val x = await(a) - 1 + def bar = a + def foo = bar + foo + } + result mustBe 0 + } + + @Test + def nestedDefTransitive2() { + import AsyncId._ + val result = async { + val a = 0 + val x = await(a) - 1 + def bar = a + def foo = bar + 0 + } + result mustBe 0 + } + + + // checking that our use/definition analysis doesn't cycle. + @Test + def mutuallyRecursive1() { + import AsyncId._ + val result = async { + val a = 0 + val x = await(a) - 1 + def foo: Int = if (true) 0 else bar + def bar: Int = if (true) 0 else foo + bar + } + result mustBe 0 + } + + // checking that our use/definition analysis doesn't cycle. + @Test + def mutuallyRecursive2() { + import AsyncId._ + val result = async { + val a = 0 + def foo: Int = if (true) 0 else bar + def bar: Int = if (true) 0 else foo + val x = await(a) - 1 + bar + } + result mustBe 0 + } } diff --git a/src/test/scala/scala/async/run/noawait/NoAwaitSpec.scala b/src/test/scala/scala/async/run/noawait/NoAwaitSpec.scala index e2c69d0..ba9c9be 100644 --- a/src/test/scala/scala/async/run/noawait/NoAwaitSpec.scala +++ b/src/test/scala/scala/async/run/noawait/NoAwaitSpec.scala @@ -6,6 +6,7 @@ package scala.async package run package noawait +import scala.async.internal.AsyncId import AsyncId._ import org.junit.Test import org.junit.runner.RunWith diff --git a/src/test/scala/scala/async/run/toughtype/ToughType.scala b/src/test/scala/scala/async/run/toughtype/ToughType.scala index 83f5a2d..ec2278f 100644 --- a/src/test/scala/scala/async/run/toughtype/ToughType.scala +++ b/src/test/scala/scala/async/run/toughtype/ToughType.scala @@ -13,6 +13,7 @@ import scala.async.Async._ import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.JUnit4 +import scala.async.internal.AsyncId object ToughTypeObject { @@ -67,4 +68,74 @@ class ToughTypeSpec { await(f(2)) } mustBe 3 } + + @Test def existentialBindIssue19() { + import AsyncId.{await, async} + def m7(a: Any) = async { + a match { + case s: Seq[_] => + val x = s.size + var ss = s + ss = s + await(x) + } + } + m7(Nil) mustBe 0 + } + + @Test def existentialBind2Issue19() { + import scala.async.Async._, scala.concurrent.ExecutionContext.Implicits.global + def conjure[T]: T = null.asInstanceOf[T] + + def m3 = async { + val p: List[Option[_]] = conjure[List[Option[_]]] + await(future(1)) + } + + def m4 = async { + await(future[List[_]](Nil)) + } + } + + @Test def singletonTypeIssue17() { + import AsyncId.{async, await} + class A { class B } + async { + val a = new A + def foo(b: a.B) = 0 + await(foo(new a.B)) + } + } + + @Test def existentialMatch() { + import AsyncId.{async, await} + trait Container[+A] + case class ContainerImpl[A](value: A) extends Container[A] + def foo: Container[_] = async { + val a: Any = List(1) + a match { + case buf: Seq[_] => + val foo = await(5) + val e0 = buf(0) + ContainerImpl(e0) + } + } + foo + } + + @Test def existentialIfElse0() { + import AsyncId.{async, await} + trait Container[+A] + case class ContainerImpl[A](value: A) extends Container[A] + def foo: Container[_] = async { + val a: Any = List(1) + if (true) { + val buf: Seq[_] = List(1) + val foo = await(5) + val e0 = buf(0) + ContainerImpl(e0) + } else ??? + } + foo + } } |