aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPhilipp Haller <hallerp@gmail.com>2013-08-14 07:51:10 -0700
committerPhilipp Haller <hallerp@gmail.com>2013-08-14 07:51:10 -0700
commit0c5a1ea043c72bbc3568a8df7f75bc65a261ed21 (patch)
treeb2667b421b7dde7ee8e196f616661baddd307579
parent9156cbeb944db80245766c317f43434b4c1981e5 (diff)
parentb79c9ad864a27aea620254c3eade6d38adcf38f2 (diff)
downloadscala-async-0c5a1ea043c72bbc3568a8df7f75bc65a261ed21.tar.gz
scala-async-0c5a1ea043c72bbc3568a8df7f75bc65a261ed21.tar.bz2
scala-async-0c5a1ea043c72bbc3568a8df7f75bc65a261ed21.zip
Merge pull request #27 from retronym/topic/typed-transform-2
Typeful transformations
-rw-r--r--.gitignore1
-rw-r--r--build.sbt11
-rw-r--r--project/build.properties2
-rw-r--r--src/main/scala/scala/async/AnfTransform.scala275
-rw-r--r--src/main/scala/scala/async/Async.scala185
-rw-r--r--src/main/scala/scala/async/AsyncAnalysis.scala192
-rw-r--r--src/main/scala/scala/async/AsyncBase.scala23
-rw-r--r--src/main/scala/scala/async/StateMachine.scala12
-rw-r--r--src/main/scala/scala/async/TransformUtils.scala374
-rw-r--r--src/main/scala/scala/async/continuations/AsyncBaseWithCPSFallback.scala48
-rw-r--r--src/main/scala/scala/async/continuations/AsyncWithCPSFallback.scala9
-rw-r--r--src/main/scala/scala/async/continuations/CPSBasedAsync.scala11
-rw-r--r--src/main/scala/scala/async/continuations/CPSBasedAsyncBase.scala8
-rw-r--r--src/main/scala/scala/async/continuations/ScalaConcurrentCPSFallback.scala1
-rw-r--r--src/main/scala/scala/async/internal/AnfTransform.scala268
-rw-r--r--src/main/scala/scala/async/internal/AsyncAnalysis.scala94
-rw-r--r--src/main/scala/scala/async/internal/AsyncBase.scala61
-rw-r--r--src/main/scala/scala/async/internal/AsyncId.scala66
-rw-r--r--src/main/scala/scala/async/internal/AsyncMacro.scala32
-rw-r--r--src/main/scala/scala/async/internal/AsyncTransform.scala177
-rw-r--r--src/main/scala/scala/async/internal/AsyncUtils.scala (renamed from src/main/scala/scala/async/AsyncUtils.scala)2
-rw-r--r--src/main/scala/scala/async/internal/ExprBuilder.scala (renamed from src/main/scala/scala/async/ExprBuilder.scala)209
-rw-r--r--src/main/scala/scala/async/internal/FutureSystem.scala (renamed from src/main/scala/scala/async/FutureSystem.scala)83
-rw-r--r--src/main/scala/scala/async/internal/Lifter.scala150
-rw-r--r--src/main/scala/scala/async/internal/StateAssigner.scala (renamed from src/main/scala/scala/async/StateAssigner.scala)4
-rw-r--r--src/main/scala/scala/async/internal/TransformUtils.scala251
-rw-r--r--src/test/scala/scala/async/TreeInterrogation.scala24
-rw-r--r--src/test/scala/scala/async/neg/LocalClasses0Spec.scala124
-rw-r--r--src/test/scala/scala/async/neg/NakedAwait.scala38
-rw-r--r--src/test/scala/scala/async/package.scala17
-rw-r--r--src/test/scala/scala/async/run/anf/AnfTransformSpec.scala97
-rw-r--r--src/test/scala/scala/async/run/hygiene/Hygiene.scala3
-rw-r--r--src/test/scala/scala/async/run/ifelse0/IfElse0.scala1
-rw-r--r--src/test/scala/scala/async/run/ifelse0/WhileSpec.scala3
-rw-r--r--src/test/scala/scala/async/run/match0/Match0.scala35
-rw-r--r--src/test/scala/scala/async/run/nesteddef/NestedDef.scala57
-rw-r--r--src/test/scala/scala/async/run/noawait/NoAwaitSpec.scala1
-rw-r--r--src/test/scala/scala/async/run/toughtype/ToughType.scala71
38 files changed, 1629 insertions, 1391 deletions
diff --git a/.gitignore b/.gitignore
index 0c4d130..6bf5f1a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,3 +3,4 @@ target
.idea
.idea_modules
*.icode
+project/local.sbt \ No newline at end of file
diff --git a/build.sbt b/build.sbt
index c0e062e..d6dc3bb 100644
--- a/build.sbt
+++ b/build.sbt
@@ -1,4 +1,4 @@
-scalaVersion := "2.10.1"
+scalaVersion := "2.10.2"
organization := "org.typesafe.async" // TODO new org name under scala-lang.
@@ -8,8 +8,8 @@ version := "1.0.0-SNAPSHOT"
libraryDependencies <++= (scalaVersion) {
sv => Seq(
- "org.scala-lang" % "scala-reflect" % sv,
- "org.scala-lang" % "scala-compiler" % sv % "test"
+ "org.scala-lang" % "scala-reflect" % sv % "provided",
+ "org.scala-lang" % "scala-compiler" % sv % "provided"
)
}
@@ -32,6 +32,8 @@ scalacOptions += "-P:continuations:enable"
scalacOptions ++= Seq("-deprecation", "-unchecked", "-Xlint", "-feature")
+scalacOptions in Test ++= Seq("-Yrangepos")
+
description := "An asynchronous programming facility for Scala, in the spirit of C# await/async"
homepage := Some(url("http://github.com/scala/async"))
@@ -40,6 +42,9 @@ startYear := Some(2012)
licenses +=("Scala license", url("https://github.com/scala/async/blob/master/LICENSE"))
+// Uncomment to disable test compilation.
+// (sources in Test) ~= ((xs: Seq[File]) => xs.filter(f => Seq("TreeInterrogation", "package").exists(f.name.contains)))
+
pomExtra := (
<developers>
<developer>
diff --git a/project/build.properties b/project/build.properties
index 2b9d40c..5e96e96 100644
--- a/project/build.properties
+++ b/project/build.properties
@@ -1 +1 @@
-sbt.version=0.12.1 \ No newline at end of file
+sbt.version=0.12.4
diff --git a/src/main/scala/scala/async/AnfTransform.scala b/src/main/scala/scala/async/AnfTransform.scala
deleted file mode 100644
index 5b9901d..0000000
--- a/src/main/scala/scala/async/AnfTransform.scala
+++ /dev/null
@@ -1,275 +0,0 @@
-
-/*
- * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
- */
-
-package scala.async
-
-import scala.reflect.macros.Context
-
-private[async] final case class AnfTransform[C <: Context](c: C) {
-
- import c.universe._
-
- val utils = TransformUtils[c.type](c)
-
- import utils._
-
- def apply(tree: Tree): List[Tree] = {
- val unique = uniqueNames(tree)
- // Must prepend the () for issue #31.
- anf.transformToList(Block(List(c.literalUnit.tree), unique))
- }
-
- private def uniqueNames(tree: Tree): Tree = {
- new UniqueNames(tree).transform(tree)
- }
-
- /** Assigns unique names to all definitions in a tree, and adjusts references to use the new name.
- * Only modifies names that appear more than once in the tree.
- *
- * This step is needed to allow us to safely merge blocks during the `inline` transform below.
- */
- private final class UniqueNames(tree: Tree) extends Transformer {
- val repeatedNames: Set[Symbol] = {
- class DuplicateNameTraverser extends AsyncTraverser {
- val result = collection.mutable.Buffer[Symbol]()
-
- override def traverse(tree: Tree) {
- tree match {
- case dt: DefTree => result += dt.symbol
- case _ => super.traverse(tree)
- }
- }
- }
- val dupNameTraverser = new DuplicateNameTraverser
- dupNameTraverser.traverse(tree)
- dupNameTraverser.result.groupBy(x => x.name).filter(_._2.size > 1).values.flatten.toSet[Symbol]
- }
-
- /** Stepping outside of the public Macro API to call [[scala.reflect.internal.Symbols.Symbol.name_=]] */
- val symtab = c.universe.asInstanceOf[reflect.internal.SymbolTable]
-
- val renamed = collection.mutable.Set[Symbol]()
-
- override def transform(tree: Tree): Tree = {
- tree match {
- case defTree: DefTree if repeatedNames(defTree.symbol) =>
- val trans = super.transform(defTree)
- val origName = defTree.symbol.name
- val sym = defTree.symbol.asInstanceOf[symtab.Symbol]
- val fresh = name.fresh(sym.name.toString)
- sym.name = origName match {
- case _: TermName => symtab.newTermName(fresh)
- case _: TypeName => symtab.newTypeName(fresh)
- }
- renamed += trans.symbol
- val newName = trans.symbol.name
- trans match {
- case ValDef(mods, name, tpt, rhs) =>
- treeCopy.ValDef(trans, mods, newName, tpt, rhs)
- case Bind(name, body) =>
- treeCopy.Bind(trans, newName, body)
- case DefDef(mods, name, tparams, vparamss, tpt, rhs) =>
- treeCopy.DefDef(trans, mods, newName, tparams, vparamss, tpt, rhs)
- case TypeDef(mods, name, tparams, rhs) =>
- treeCopy.TypeDef(tree, mods, newName, tparams, transform(rhs))
- // If we were to allow local classes / objects, we would need to rename here.
- case ClassDef(mods, name, tparams, impl) =>
- treeCopy.ClassDef(tree, mods, newName, tparams, transform(impl).asInstanceOf[Template])
- case ModuleDef(mods, name, impl) =>
- treeCopy.ModuleDef(tree, mods, newName, transform(impl).asInstanceOf[Template])
- case x => super.transform(x)
- }
- case Ident(name) =>
- if (renamed(tree.symbol)) treeCopy.Ident(tree, tree.symbol.name)
- else tree
- case Select(fun, name) =>
- if (renamed(tree.symbol)) {
- treeCopy.Select(tree, transform(fun), tree.symbol.name)
- } else super.transform(tree)
- case tt: TypeTree =>
- val tt1 = tt.asInstanceOf[symtab.TypeTree]
- val orig = tt1.original
- if (orig != null) tt1.setOriginal(transform(orig.asInstanceOf[Tree]).asInstanceOf[symtab.Tree])
- super.transform(tt)
- case _ => super.transform(tree)
- }
- }
- }
-
- private object trace {
- private var indent = -1
-
- def indentString = " " * indent
-
- def apply[T](prefix: String, args: Any)(t: => T): T = {
- indent += 1
- def oneLine(s: Any) = s.toString.replaceAll( """\n""", "\\\\n").take(127)
- try {
- AsyncUtils.trace(s"${indentString}$prefix(${oneLine(args)})")
- val result = t
- AsyncUtils.trace(s"${indentString}= ${oneLine(result)}")
- result
- } finally {
- indent -= 1
- }
- }
- }
-
- private object inline {
- def transformToList(tree: Tree): List[Tree] = trace("inline", tree) {
- val stats :+ expr = anf.transformToList(tree)
- expr match {
- case Apply(fun, args) if isAwait(fun) =>
- val valDef = defineVal(name.await, expr, tree.pos)
- stats :+ valDef :+ Ident(valDef.name)
-
- case If(cond, thenp, elsep) =>
- // if type of if-else is Unit don't introduce assignment,
- // but add Unit value to bring it into form expected by async transform
- if (expr.tpe =:= definitions.UnitTpe) {
- stats :+ expr :+ Literal(Constant(()))
- } else {
- val varDef = defineVar(name.ifRes, expr.tpe, tree.pos)
- def branchWithAssign(orig: Tree) = orig match {
- case Block(thenStats, thenExpr) => Block(thenStats, Assign(Ident(varDef.name), thenExpr))
- case _ => Assign(Ident(varDef.name), orig)
- }
- val ifWithAssign = If(cond, branchWithAssign(thenp), branchWithAssign(elsep))
- stats :+ varDef :+ ifWithAssign :+ Ident(varDef.name)
- }
-
- case Match(scrut, cases) =>
- // if type of match is Unit don't introduce assignment,
- // but add Unit value to bring it into form expected by async transform
- if (expr.tpe =:= definitions.UnitTpe) {
- stats :+ expr :+ Literal(Constant(()))
- }
- else {
- val varDef = defineVar(name.matchRes, expr.tpe, tree.pos)
- val casesWithAssign = cases map {
- case cd@CaseDef(pat, guard, Block(caseStats, caseExpr)) =>
- attachCopy(cd)(CaseDef(pat, guard, Block(caseStats, Assign(Ident(varDef.name), caseExpr))))
- case cd@CaseDef(pat, guard, body) =>
- attachCopy(cd)(CaseDef(pat, guard, Assign(Ident(varDef.name), body)))
- }
- val matchWithAssign = attachCopy(tree)(Match(scrut, casesWithAssign))
- stats :+ varDef :+ matchWithAssign :+ Ident(varDef.name)
- }
- case _ =>
- stats :+ expr
- }
- }
-
- def transformToList(trees: List[Tree]): List[Tree] = trees flatMap transformToList
-
- def transformToBlock(tree: Tree): Block = transformToList(tree) match {
- case stats :+ expr => Block(stats, expr)
- }
-
- private def defineVar(prefix: String, tp: Type, pos: Position): ValDef = {
- val vd = ValDef(Modifiers(Flag.MUTABLE), name.fresh(prefix), TypeTree(tp), defaultValue(tp))
- vd.setPos(pos)
- vd
- }
- }
-
- private def defineVal(prefix: String, lhs: Tree, pos: Position): ValDef = {
- val vd = ValDef(NoMods, name.fresh(prefix), TypeTree(), lhs)
- vd.setPos(pos)
- vd
- }
-
- private object anf {
-
- private[AnfTransform] def transformToList(tree: Tree): List[Tree] = trace("anf", tree) {
- val containsAwait = tree exists isAwait
- if (!containsAwait) {
- List(tree)
- } else tree match {
- case Select(qual, sel) =>
- val stats :+ expr = inline.transformToList(qual)
- stats :+ attachCopy(tree)(Select(expr, sel).setSymbol(tree.symbol))
-
- case Applied(fun, targs, argss) if argss.nonEmpty =>
- // we an assume that no await call appears in a by-name argument position,
- // this has already been checked.
- val funStats :+ simpleFun = inline.transformToList(fun)
- def isAwaitRef(name: Name) = name.toString.startsWith(utils.name.await + "$")
- val (argStatss, argExprss): (List[List[List[Tree]]], List[List[Tree]]) =
- mapArgumentss[List[Tree]](fun, argss) {
- case Arg(expr, byName, _) if byName || isSafeToInline(expr) => (Nil, expr)
- case Arg(expr@Ident(name), _, _) if isAwaitRef(name) => (Nil, expr) // not typed, so it eludes the check in `isSafeToInline`
- case Arg(expr, _, argName) =>
- inline.transformToList(expr) match {
- case stats :+ expr1 =>
- val valDef = defineVal(argName, expr1, expr.pos)
- (stats :+ valDef, Ident(valDef.name))
- }
- }
- val core = if (targs.isEmpty) simpleFun else TypeApply(simpleFun, targs)
- val newApply = argExprss.foldLeft(core)(Apply(_, _)).setSymbol(tree.symbol)
- funStats ++ argStatss.flatten.flatten :+ attachCopy(tree)(newApply)
- case Block(stats, expr) =>
- inline.transformToList(stats :+ expr)
-
- case ValDef(mods, name, tpt, rhs) =>
- if (rhs exists isAwait) {
- val stats :+ expr = inline.transformToList(rhs)
- stats :+ attachCopy(tree)(ValDef(mods, name, tpt, expr).setSymbol(tree.symbol))
- } else List(tree)
-
- case Assign(lhs, rhs) =>
- val stats :+ expr = inline.transformToList(rhs)
- stats :+ attachCopy(tree)(Assign(lhs, expr))
-
- case If(cond, thenp, elsep) =>
- val condStats :+ condExpr = inline.transformToList(cond)
- val thenBlock = inline.transformToBlock(thenp)
- val elseBlock = inline.transformToBlock(elsep)
- // Typechecking with `condExpr` as the condition fails if the condition
- // contains an await. `ifTree.setType(tree.tpe)` also fails; it seems
- // we rely on this call to `typeCheck` descending into the branches.
- // But, we can get away with typechecking a throwaway `If` tree with the
- // original scrutinee and the new branches, and setting that type on
- // the real `If` tree.
- val ifType = c.typeCheck(If(cond, thenBlock, elseBlock)).tpe
- condStats :+
- attachCopy(tree)(If(condExpr, thenBlock, elseBlock)).setType(ifType)
-
- case Match(scrut, cases) =>
- val scrutStats :+ scrutExpr = inline.transformToList(scrut)
- val caseDefs = cases map {
- case CaseDef(pat, guard, body) =>
- // extract local variables for all names bound in `pat`, and rewrite `body`
- // to refer to these.
- // TODO we can move this into ExprBuilder once we get rid of `AsyncDefinitionUseAnalyzer`.
- val block = inline.transformToBlock(body)
- val (valDefs, mappings) = (pat collect {
- case b@Bind(name, _) =>
- val newName = newTermName(utils.name.fresh(name.toTermName + utils.name.bindSuffix))
- val vd = ValDef(NoMods, newName, TypeTree(), Ident(b.symbol))
- (vd, (b.symbol, newName))
- }).unzip
- val Block(stats1, expr1) = utils.substituteNames(block, mappings.toMap).asInstanceOf[Block]
- attachCopy(tree)(CaseDef(pat, guard, Block(valDefs ++ stats1, expr1)))
- }
- // Refer to comments the translation of `If` above.
- val matchType = c.typeCheck(Match(scrut, caseDefs)).tpe
- val typedMatch = attachCopy(tree)(Match(scrutExpr, caseDefs)).setType(tree.tpe)
- scrutStats :+ typedMatch
-
- case LabelDef(name, params, rhs) =>
- List(LabelDef(name, params, Block(inline.transformToList(rhs), Literal(Constant(())))).setSymbol(tree.symbol))
-
- case TypeApply(fun, targs) =>
- val funStats :+ simpleFun = inline.transformToList(fun)
- funStats :+ attachCopy(tree)(TypeApply(simpleFun, targs).setSymbol(tree.symbol))
-
- case _ =>
- List(tree)
- }
- }
- }
-}
diff --git a/src/main/scala/scala/async/Async.scala b/src/main/scala/scala/async/Async.scala
deleted file mode 100644
index 35d3687..0000000
--- a/src/main/scala/scala/async/Async.scala
+++ /dev/null
@@ -1,185 +0,0 @@
-/*
- * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
- */
-
-package scala.async
-
-import scala.language.experimental.macros
-import scala.reflect.macros.Context
-import scala.reflect.internal.annotations.compileTimeOnly
-
-object Async extends AsyncBase {
-
- import scala.concurrent.Future
-
- lazy val futureSystem = ScalaConcurrentFutureSystem
- type FS = ScalaConcurrentFutureSystem.type
-
- def async[T](body: T) = macro asyncImpl[T]
-
- override def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[Future[T]] = super.asyncImpl[T](c)(body)
-}
-
-object AsyncId extends AsyncBase {
- lazy val futureSystem = IdentityFutureSystem
- type FS = IdentityFutureSystem.type
-
- def async[T](body: T) = macro asyncImpl[T]
-
- override def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[T] = super.asyncImpl[T](c)(body)
-}
-
-/**
- * A base class for the `async` macro. Subclasses must provide:
- *
- * - Concrete types for a given future system
- * - Tree manipulations to create and complete the equivalent of Future and Promise
- * in that system.
- * - The `async` macro declaration itself, and a forwarder for the macro implementation.
- * (The latter is temporarily needed to workaround bug SI-6650 in the macro system)
- *
- * The default implementation, [[scala.async.Async]], binds the macro to `scala.concurrent._`.
- */
-abstract class AsyncBase {
- self =>
-
- type FS <: FutureSystem
- val futureSystem: FS
-
- /**
- * A call to `await` must be nested in an enclosing `async` block.
- *
- * A call to `await` does not block the current thread, rather it is a delimiter
- * used by the enclosing `async` macro. Code following the `await`
- * call is executed asynchronously, when the argument of `await` has been completed.
- *
- * @param awaitable the future from which a value is awaited.
- * @tparam T the type of that value.
- * @return the value.
- */
- @compileTimeOnly("`await` must be enclosed in an `async` block")
- def await[T](awaitable: futureSystem.Fut[T]): T = ???
-
- protected[async] def fallbackEnabled = false
-
- def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[futureSystem.Fut[T]] = {
- import c.universe._
-
- val analyzer = AsyncAnalysis[c.type](c, this)
- val utils = TransformUtils[c.type](c)
- import utils.{name, defn}
-
- analyzer.reportUnsupportedAwaits(body.tree)
-
- // Transform to A-normal form:
- // - no await calls in qualifiers or arguments,
- // - if/match only used in statement position.
- val anfTree: Block = {
- val anf = AnfTransform[c.type](c)
- val restored = utils.restorePatternMatchingFunctions(body.tree)
- val stats1 :+ expr1 = anf(restored)
- val block = Block(stats1, expr1)
- c.typeCheck(block).asInstanceOf[Block]
- }
-
- // Analyze the block to find locals that will be accessed from multiple
- // states of our generated state machine, e.g. a value assigned before
- // an `await` and read afterwards.
- val renameMap: Map[Symbol, TermName] = {
- analyzer.defTreesUsedInSubsequentStates(anfTree).map {
- vd =>
- (vd.symbol, name.fresh(vd.name.toTermName))
- }.toMap
- }
-
- val builder = ExprBuilder[c.type, futureSystem.type](c, self.futureSystem, anfTree)
- import builder.futureSystemOps
- val asyncBlock: builder.AsyncBlock = builder.build(anfTree, renameMap)
- import asyncBlock.asyncStates
- logDiagnostics(c)(anfTree, asyncStates.map(_.toString))
-
- // Important to retain the original declaration order here!
- val localVarTrees = anfTree.collect {
- case vd@ValDef(_, _, tpt, _) if renameMap contains vd.symbol =>
- utils.mkVarDefTree(tpt.tpe, renameMap(vd.symbol))
- case dd@DefDef(mods, name, tparams, vparamss, tpt, rhs) if renameMap contains dd.symbol =>
- DefDef(mods, renameMap(dd.symbol), tparams, vparamss, tpt, c.resetAllAttrs(utils.substituteNames(rhs, renameMap)))
- }
-
- val onCompleteHandler = {
- Function(
- List(ValDef(Modifiers(Flag.PARAM), name.tr, TypeTree(defn.TryAnyType), EmptyTree)),
- asyncBlock.onCompleteHandler)
- }
- val resumeFunTree = asyncBlock.resumeFunTree[T]
-
- val stateMachineType = utils.applied("scala.async.StateMachine", List(futureSystemOps.promType[T], futureSystemOps.execContextType))
-
- lazy val stateMachine: ClassDef = {
- val body: List[Tree] = {
- val stateVar = ValDef(Modifiers(Flag.MUTABLE), name.state, TypeTree(definitions.IntTpe), Literal(Constant(0)))
- val result = ValDef(NoMods, name.result, TypeTree(futureSystemOps.promType[T]), futureSystemOps.createProm[T].tree)
- val execContext = ValDef(NoMods, name.execContext, TypeTree(), futureSystemOps.execContext.tree)
- val applyDefDef: DefDef = {
- val applyVParamss = List(List(ValDef(Modifiers(Flag.PARAM), name.tr, TypeTree(defn.TryAnyType), EmptyTree)))
- val applyBody = asyncBlock.onCompleteHandler
- DefDef(NoMods, name.apply, Nil, applyVParamss, TypeTree(definitions.UnitTpe), applyBody)
- }
- val apply0DefDef: DefDef = {
- // We extend () => Unit so we can pass this class as the by-name argument to `Future.apply`.
- // See SI-1247 for the the optimization that avoids creatio
- val applyVParamss = List(List(ValDef(Modifiers(Flag.PARAM), name.tr, TypeTree(defn.TryAnyType), EmptyTree)))
- val applyBody = asyncBlock.onCompleteHandler
- DefDef(NoMods, name.apply, Nil, Nil, TypeTree(definitions.UnitTpe), Apply(Ident(name.resume), Nil))
- }
- List(utils.emptyConstructor, stateVar, result, execContext) ++ localVarTrees ++ List(resumeFunTree, applyDefDef, apply0DefDef)
- }
- val template = {
- Template(List(stateMachineType), emptyValDef, body)
- }
- ClassDef(NoMods, name.stateMachineT, Nil, template)
- }
-
- def selectStateMachine(selection: TermName) = Select(Ident(name.stateMachine), selection)
-
- val code: c.Expr[futureSystem.Fut[T]] = {
- val isSimple = asyncStates.size == 1
- val tree =
- if (isSimple)
- Block(Nil, futureSystemOps.spawn(body.tree)) // generate lean code for the simple case of `async { 1 + 1 }`
- else {
- Block(List[Tree](
- stateMachine,
- ValDef(NoMods, name.stateMachine, stateMachineType, Apply(Select(New(Ident(name.stateMachineT)), nme.CONSTRUCTOR), Nil)),
- futureSystemOps.spawn(Apply(selectStateMachine(name.apply), Nil))
- ),
- futureSystemOps.promiseToFuture(c.Expr[futureSystem.Prom[T]](selectStateMachine(name.result))).tree)
- }
- c.Expr[futureSystem.Fut[T]](tree)
- }
-
- AsyncUtils.vprintln(s"async state machine transform expands to:\n ${code.tree}")
- code
- }
-
- def logDiagnostics(c: Context)(anfTree: c.Tree, states: Seq[String]) {
- def location = try {
- c.macroApplication.pos.source.path
- } catch {
- case _: UnsupportedOperationException =>
- c.macroApplication.pos.toString
- }
-
- AsyncUtils.vprintln(s"In file '$location':")
- AsyncUtils.vprintln(s"${c.macroApplication}")
- AsyncUtils.vprintln(s"ANF transform expands to:\n $anfTree")
- states foreach (s => AsyncUtils.vprintln(s))
- }
-}
-
-/** Internal class used by the `async` macro; should not be manually extended by client code */
-abstract class StateMachine[Result, EC] extends (scala.util.Try[Any] => Unit) with (() => Unit) {
- def result$async: Result
-
- def execContext$async: EC
-}
diff --git a/src/main/scala/scala/async/AsyncAnalysis.scala b/src/main/scala/scala/async/AsyncAnalysis.scala
deleted file mode 100644
index 4f55f1b..0000000
--- a/src/main/scala/scala/async/AsyncAnalysis.scala
+++ /dev/null
@@ -1,192 +0,0 @@
-/*
- * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
- */
-
-package scala.async
-
-import scala.reflect.macros.Context
-import scala.collection.mutable
-
-private[async] final case class AsyncAnalysis[C <: Context](c: C, asyncBase: AsyncBase) {
- import c.universe._
-
- val utils = TransformUtils[c.type](c)
-
- import utils._
-
- /**
- * Analyze the contents of an `async` block in order to:
- * - Report unsupported `await` calls under nested templates, functions, by-name arguments.
- *
- * Must be called on the original tree, not on the ANF transformed tree.
- */
- def reportUnsupportedAwaits(tree: Tree): Boolean = {
- val analyzer = new UnsupportedAwaitAnalyzer
- analyzer.traverse(tree)
- analyzer.hasUnsupportedAwaits
- }
-
- /**
- * Analyze the contents of an `async` block in order to:
- * - Find which local `ValDef`-s need to be lifted to fields of the state machine, based
- * on whether or not they are accessed only from a single state.
- *
- * Must be called on the ANF transformed tree.
- */
- def defTreesUsedInSubsequentStates(tree: Tree): List[DefTree] = {
- val analyzer = new AsyncDefinitionUseAnalyzer
- analyzer.traverse(tree)
- val liftable: List[DefTree] = (analyzer.valDefsToLift ++ analyzer.nestedMethodsToLift).toList.distinct
- liftable
- }
-
- private class UnsupportedAwaitAnalyzer extends AsyncTraverser {
- var hasUnsupportedAwaits = false
-
- override def nestedClass(classDef: ClassDef) {
- val kind = if (classDef.symbol.asClass.isTrait) "trait" else "class"
- if (!reportUnsupportedAwait(classDef, s"nested $kind")) {
- // do not allow local class definitions, because of SI-5467 (specific to case classes, though)
- if (classDef.symbol.asClass.isCaseClass)
- c.error(classDef.pos, s"Local case class ${classDef.name.decoded} illegal within `async` block")
- }
- }
-
- override def nestedModule(module: ModuleDef) {
- if (!reportUnsupportedAwait(module, "nested object")) {
- // local object definitions lead to spurious type errors (because of resetAllAttrs?)
- c.error(module.pos, s"Local object ${module.name.decoded} illegal within `async` block")
- }
- }
-
- override def nestedMethod(module: DefDef) {
- reportUnsupportedAwait(module, "nested method")
- }
-
- override def byNameArgument(arg: Tree) {
- reportUnsupportedAwait(arg, "by-name argument")
- }
-
- override def function(function: Function) {
- reportUnsupportedAwait(function, "nested function")
- }
-
- override def patMatFunction(tree: Match) {
- reportUnsupportedAwait(tree, "nested function")
- }
-
- override def traverse(tree: Tree) {
- def containsAwait = tree exists isAwait
- tree match {
- case Try(_, _, _) if containsAwait =>
- reportUnsupportedAwait(tree, "try/catch")
- super.traverse(tree)
- case Return(_) =>
- c.abort(tree.pos, "return is illegal within a async block")
- case ValDef(mods, _, _, _) if mods.hasFlag(Flag.LAZY) =>
- c.abort(tree.pos, "lazy vals are illegal within an async block")
- case _ =>
- super.traverse(tree)
- }
- }
-
- /**
- * @return true, if the tree contained an unsupported await.
- */
- private def reportUnsupportedAwait(tree: Tree, whyUnsupported: String): Boolean = {
- val badAwaits: List[RefTree] = tree collect {
- case rt: RefTree if isAwait(rt) => rt
- }
- badAwaits foreach {
- tree =>
- reportError(tree.pos, s"await must not be used under a $whyUnsupported.")
- }
- badAwaits.nonEmpty
- }
-
- private def reportError(pos: Position, msg: String) {
- hasUnsupportedAwaits = true
- if (!asyncBase.fallbackEnabled)
- c.error(pos, msg)
- }
- }
-
- private class AsyncDefinitionUseAnalyzer extends AsyncTraverser {
- private var chunkId = 0
-
- private def nextChunk() = chunkId += 1
-
- private var valDefChunkId = Map[Symbol, (ValDef, Int)]()
-
- val valDefsToLift : mutable.Set[ValDef] = collection.mutable.Set()
- val nestedMethodsToLift: mutable.Set[DefDef] = collection.mutable.Set()
-
- override def nestedMethod(defDef: DefDef) {
- nestedMethodsToLift += defDef
- markReferencedVals(defDef)
- }
-
- override def function(function: Function) {
- markReferencedVals(function)
- }
-
- override def patMatFunction(tree: Match) {
- markReferencedVals(tree)
- }
-
- private def markReferencedVals(tree: Tree) {
- tree foreach {
- case rt: RefTree =>
- valDefChunkId.get(rt.symbol) match {
- case Some((vd, defChunkId)) =>
- valDefsToLift += vd // lift all vals referred to by nested functions.
- case _ =>
- }
- case _ =>
- }
- }
-
- override def traverse(tree: Tree) = {
- tree match {
- case If(cond, thenp, elsep) if tree exists isAwait =>
- traverseChunks(List(cond, thenp, elsep))
- case Match(selector, cases) if tree exists isAwait =>
- traverseChunks(selector :: cases)
- case LabelDef(name, params, rhs) if rhs exists isAwait =>
- traverseChunks(rhs :: Nil)
- case Apply(fun, args) if isAwait(fun) =>
- super.traverse(tree)
- nextChunk()
- case vd: ValDef =>
- super.traverse(tree)
- valDefChunkId += (vd.symbol -> (vd -> chunkId))
- val isPatternBinder = vd.name.toString.contains(name.bindSuffix)
- if (isAwait(vd.rhs) || isPatternBinder) valDefsToLift += vd
- case as: Assign =>
- if (isAwait(as.rhs)) {
- assert(as.lhs.symbol != null, "internal error: null symbol for Assign tree:" + as + " " + as.lhs.symbol)
-
- // TODO test the orElse case, try to remove the restriction.
- val (vd, defBlockId) = valDefChunkId.getOrElse(as.lhs.symbol, c.abort(as.pos, s"await may only be assigned to a var/val defined in the async block. ${as.lhs} ${as.lhs.symbol}"))
- valDefsToLift += vd
- }
- super.traverse(tree)
- case rt: RefTree =>
- valDefChunkId.get(rt.symbol) match {
- case Some((vd, defChunkId)) if defChunkId != chunkId =>
- valDefsToLift += vd
- case _ =>
- }
- super.traverse(tree)
- case _ => super.traverse(tree)
- }
- }
-
- private def traverseChunks(trees: List[Tree]) {
- trees.foreach {
- t => traverse(t); nextChunk()
- }
- }
- }
-
-}
diff --git a/src/main/scala/scala/async/AsyncBase.scala b/src/main/scala/scala/async/AsyncBase.scala
new file mode 100644
index 0000000..ff04a57
--- /dev/null
+++ b/src/main/scala/scala/async/AsyncBase.scala
@@ -0,0 +1,23 @@
+/*
+ * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
+ */
+
+package scala.async
+
+import scala.language.experimental.macros
+import scala.reflect.macros.Context
+import scala.concurrent.{Future, ExecutionContext}
+import scala.async.internal.{AsyncBase, ScalaConcurrentFutureSystem}
+
+object Async extends AsyncBase {
+ type FS = ScalaConcurrentFutureSystem.type
+ val futureSystem: FS = ScalaConcurrentFutureSystem
+
+ def async[T](body: T)(implicit execContext: ExecutionContext): Future[T] = macro asyncImpl[T]
+
+ override def asyncImpl[T: c.WeakTypeTag](c: Context)
+ (body: c.Expr[T])
+ (execContext: c.Expr[futureSystem.ExecContext]): c.Expr[Future[T]] = {
+ super.asyncImpl[T](c)(body)(execContext)
+ }
+}
diff --git a/src/main/scala/scala/async/StateMachine.scala b/src/main/scala/scala/async/StateMachine.scala
new file mode 100644
index 0000000..823df71
--- /dev/null
+++ b/src/main/scala/scala/async/StateMachine.scala
@@ -0,0 +1,12 @@
+/*
+ * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
+ */
+
+package scala.async
+
+/** Internal class used by the `async` macro; should not be manually extended by client code */
+abstract class StateMachine[Result, EC] extends (scala.util.Try[Any] => Unit) with (() => Unit) {
+ def result: Result
+
+ def execContext: EC
+}
diff --git a/src/main/scala/scala/async/TransformUtils.scala b/src/main/scala/scala/async/TransformUtils.scala
deleted file mode 100644
index ebd546f..0000000
--- a/src/main/scala/scala/async/TransformUtils.scala
+++ /dev/null
@@ -1,374 +0,0 @@
-/*
- * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
- */
-package scala.async
-
-import scala.reflect.macros.Context
-import reflect.ClassTag
-
-/**
- * Utilities used in both `ExprBuilder` and `AnfTransform`.
- */
-private[async] final case class TransformUtils[C <: Context](c: C) {
-
- import c.universe._
-
- object name {
- def suffix(string: String) = string + "$async"
-
- def suffixedName(prefix: String) = newTermName(suffix(prefix))
-
- val state = suffixedName("state")
- val result = suffixedName("result")
- val resume = suffixedName("resume")
- val execContext = suffixedName("execContext")
- val stateMachine = newTermName(fresh("stateMachine"))
- val stateMachineT = stateMachine.toTypeName
- val apply = newTermName("apply")
- val applyOrElse = newTermName("applyOrElse")
- val tr = newTermName("tr")
- val matchRes = "matchres"
- val ifRes = "ifres"
- val await = "await"
- val bindSuffix = "$bind"
-
- def fresh(name: TermName): TermName = newTermName(fresh(name.toString))
-
- def fresh(name: String): String = if (name.toString.contains("$")) name else c.fresh("" + name + "$")
- }
-
- def defaultValue(tpe: Type): Literal = {
- val defaultValue: Any =
- if (tpe <:< definitions.BooleanTpe) false
- else if (definitions.ScalaNumericValueClasses.exists(tpe <:< _.toType)) 0
- else if (tpe <:< definitions.AnyValTpe) 0
- else null
- Literal(Constant(defaultValue))
- }
-
- def isAwait(fun: Tree) =
- fun.symbol == defn.Async_await
-
- /** Replace all `Ident` nodes referring to one of the keys n `renameMap` with a node
- * referring to the corresponding new name
- */
- def substituteNames(tree: Tree, renameMap: Map[Symbol, Name]): Tree = {
- val renamer = new Transformer {
- override def transform(tree: Tree) = tree match {
- case Ident(_) => (renameMap get tree.symbol).fold(tree)(Ident(_))
- case tt: TypeTree if tt.original != EmptyTree && tt.original != null =>
- // We also have to apply our renaming transform on originals of TypeTrees.
- // TODO 2.10.1 Can we find a cleaner way?
- val symTab = c.universe.asInstanceOf[reflect.internal.SymbolTable]
- val tt1 = tt.asInstanceOf[symTab.TypeTree]
- tt1.setOriginal(transform(tt.original).asInstanceOf[symTab.Tree])
- super.transform(tree)
- case _ => super.transform(tree)
- }
- }
- renamer.transform(tree)
- }
-
- /** Descends into the regions of the tree that are subject to the
- * translation to a state machine by `async`. When a nested template,
- * function, or by-name argument is encountered, the descent stops,
- * and `nestedClass` etc are invoked.
- */
- trait AsyncTraverser extends Traverser {
- def nestedClass(classDef: ClassDef) {
- }
-
- def nestedModule(module: ModuleDef) {
- }
-
- def nestedMethod(module: DefDef) {
- }
-
- def byNameArgument(arg: Tree) {
- }
-
- def function(function: Function) {
- }
-
- def patMatFunction(tree: Match) {
- }
-
- override def traverse(tree: Tree) {
- tree match {
- case cd: ClassDef => nestedClass(cd)
- case md: ModuleDef => nestedModule(md)
- case dd: DefDef => nestedMethod(dd)
- case fun: Function => function(fun)
- case m@Match(EmptyTree, _) => patMatFunction(m) // Pattern matching anonymous function under -Xoldpatmat of after `restorePatternMatchingFunctions`
- case Applied(fun, targs, argss) if argss.nonEmpty =>
- val isInByName = isByName(fun)
- for ((args, i) <- argss.zipWithIndex) {
- for ((arg, j) <- args.zipWithIndex) {
- if (!isInByName(i, j)) traverse(arg)
- else byNameArgument(arg)
- }
- }
- traverse(fun)
- case _ => super.traverse(tree)
- }
- }
- }
-
- private lazy val Boolean_ShortCircuits: Set[Symbol] = {
- import definitions.BooleanClass
- def BooleanTermMember(name: String) = BooleanClass.typeSignature.member(newTermName(name).encodedName)
- val Boolean_&& = BooleanTermMember("&&")
- val Boolean_|| = BooleanTermMember("||")
- Set(Boolean_&&, Boolean_||)
- }
-
- def isByName(fun: Tree): ((Int, Int) => Boolean) = {
- if (Boolean_ShortCircuits contains fun.symbol) (i, j) => true
- else {
- val symtab = c.universe.asInstanceOf[reflect.internal.SymbolTable]
- val paramss = fun.tpe.asInstanceOf[symtab.Type].paramss
- val byNamess = paramss.map(_.map(_.isByNameParam))
- (i, j) => util.Try(byNamess(i)(j)).getOrElse(false)
- }
- }
- def argName(fun: Tree): ((Int, Int) => String) = {
- val symtab = c.universe.asInstanceOf[reflect.internal.SymbolTable]
- val paramss = fun.tpe.asInstanceOf[symtab.Type].paramss
- val namess = paramss.map(_.map(_.name.toString))
- (i, j) => util.Try(namess(i)(j)).getOrElse(s"arg_${i}_${j}")
- }
-
- object Applied {
- val symtab = c.universe.asInstanceOf[scala.reflect.internal.SymbolTable]
- object treeInfo extends {
- val global: symtab.type = symtab
- } with reflect.internal.TreeInfo
-
- def unapply(tree: Tree): Some[(Tree, List[Tree], List[List[Tree]])] = {
- val treeInfo.Applied(core, targs, argss) = tree.asInstanceOf[symtab.Tree]
- Some((core.asInstanceOf[Tree], targs.asInstanceOf[List[Tree]], argss.asInstanceOf[List[List[Tree]]]))
- }
- }
-
- def statsAndExpr(tree: Tree): (List[Tree], Tree) = tree match {
- case Block(stats, expr) => (stats, expr)
- case _ => (List(tree), Literal(Constant(())))
- }
-
- def mkVarDefTree(resultType: Type, resultName: TermName): c.Tree = {
- ValDef(Modifiers(Flag.MUTABLE), resultName, TypeTree(resultType), defaultValue(resultType))
- }
-
- def emptyConstructor: DefDef = {
- val emptySuperCall = Apply(Select(Super(This(tpnme.EMPTY), tpnme.EMPTY), nme.CONSTRUCTOR), Nil)
- DefDef(NoMods, nme.CONSTRUCTOR, List(), List(List()), TypeTree(), Block(List(emptySuperCall), c.literalUnit.tree))
- }
-
- def applied(className: String, types: List[Type]): AppliedTypeTree =
- AppliedTypeTree(Ident(c.mirror.staticClass(className)), types.map(TypeTree(_)))
-
- object defn {
- def mkList_apply[A](args: List[Expr[A]]): Expr[List[A]] = {
- c.Expr(Apply(Ident(definitions.List_apply), args.map(_.tree)))
- }
-
- def mkList_contains[A](self: Expr[List[A]])(elem: Expr[Any]) = reify(self.splice.contains(elem.splice))
-
- def mkFunction_apply[A, B](self: Expr[Function1[A, B]])(arg: Expr[A]) = reify {
- self.splice.apply(arg.splice)
- }
-
- def mkAny_==(self: Expr[Any])(other: Expr[Any]) = reify {
- self.splice == other.splice
- }
-
- def mkTry_get[A](self: Expr[util.Try[A]]) = reify {
- self.splice.get
- }
-
- val Try_get = methodSym(reify((null: scala.util.Try[Any]).get))
- val Try_isFailure = methodSym(reify((null: scala.util.Try[Any]).isFailure))
-
- val TryClass = c.mirror.staticClass("scala.util.Try")
- val TryAnyType = appliedType(TryClass.toType, List(definitions.AnyTpe))
- val NonFatalClass = c.mirror.staticModule("scala.util.control.NonFatal")
-
- private def asyncMember(name: String) = {
- val asyncMod = c.mirror.staticClass("scala.async.AsyncBase")
- val tpe = asyncMod.asType.toType
- tpe.member(newTermName(name)).ensuring(_ != NoSymbol)
- }
-
- val Async_await = asyncMember("await")
- }
-
- /** `termSym( (_: Foo).bar(null: A, null: B)` will return the symbol of `bar`, after overload resolution. */
- private def methodSym(apply: c.Expr[Any]): Symbol = {
- val tree2: Tree = c.typeCheck(apply.tree)
- tree2.collect {
- case s: SymTree if s.symbol.isMethod => s.symbol
- }.headOption.getOrElse(sys.error(s"Unable to find a method symbol in ${apply.tree}"))
- }
-
- /**
- * Using [[scala.reflect.api.Trees.TreeCopier]] copies more than we would like:
- * we don't want to copy types and symbols to the new trees in some cases.
- *
- * Instead, we just copy positions and attachments.
- */
- def attachCopy[T <: Tree](orig: Tree)(tree: T): tree.type = {
- tree.setPos(orig.pos)
- for (att <- orig.attachments.all)
- tree.updateAttachment[Any](att)(ClassTag.apply[Any](att.getClass))
- tree
- }
-
- def resetInternalAttrs(tree: Tree, internalSyms: List[Symbol]) =
- new ResetInternalAttrs(internalSyms.toSet).transform(tree)
-
- /**
- * Adaptation of [[scala.reflect.internal.Trees.ResetAttrs]]
- *
- * A transformer which resets symbol and tpe fields of all nodes in a given tree,
- * with special treatment of:
- * `TypeTree` nodes: are replaced by their original if it exists, otherwise tpe field is reset
- * to empty if it started out empty or refers to local symbols (which are erased).
- * `TypeApply` nodes: are deleted if type arguments end up reverted to empty
- *
- * `This` and `Ident` nodes referring to an external symbol are ''not'' reset.
- */
- private final class ResetInternalAttrs(internalSyms: Set[Symbol]) extends Transformer {
-
- import language.existentials
-
- override def transform(tree: Tree): Tree = super.transform {
- def isExternal = tree.symbol != NoSymbol && !internalSyms(tree.symbol)
-
- tree match {
- case tpt: TypeTree => resetTypeTree(tpt)
- case TypeApply(fn, args)
- if args map transform exists (_.isEmpty) => transform(fn)
- case EmptyTree => tree
- case (_: Ident | _: This) if isExternal => tree // #35 Don't reset the symbol of Ident/This bound outside of the async block
- case _ => resetTree(tree)
- }
- }
-
- private def resetTypeTree(tpt: TypeTree): Tree = {
- if (tpt.original != null)
- transform(tpt.original)
- else if (tpt.tpe != null && tpt.asInstanceOf[symtab.TypeTree forSome {val symtab: reflect.internal.SymbolTable}].wasEmpty) {
- val dupl = tpt.duplicate
- dupl.tpe = null
- dupl
- }
- else tpt
- }
-
- private def resetTree(tree: Tree): Tree = {
- val hasSymbol: Boolean = {
- val reflectInternalTree = tree.asInstanceOf[symtab.Tree forSome {val symtab: reflect.internal.SymbolTable}]
- reflectInternalTree.hasSymbol
- }
- val dupl = tree.duplicate
- if (hasSymbol)
- dupl.symbol = NoSymbol
- dupl.tpe = null
- dupl
- }
- }
-
- /**
- * Replaces expressions of the form `{ new $anon extends PartialFunction[A, B] { ... ; def applyOrElse[..](...) = ... match <cases> }`
- * with `Match(EmptyTree, cases`.
- *
- * This reverses the transformation performed in `Typers`, and works around non-idempotency of typechecking such trees.
- */
- // TODO Reference JIRA issue.
- final def restorePatternMatchingFunctions(tree: Tree) =
- RestorePatternMatchingFunctions transform tree
-
- private object RestorePatternMatchingFunctions extends Transformer {
-
- import language.existentials
- val DefaultCaseName: TermName = "defaultCase$"
-
- override def transform(tree: Tree): Tree = {
- val SYNTHETIC = (1 << 21).toLong.asInstanceOf[FlagSet]
- def isSynthetic(cd: ClassDef) = cd.mods hasFlag SYNTHETIC
-
- /** Is this pattern node a synthetic catch-all case, added during PartialFuction synthesis before we know
- * whether the user provided cases are exhaustive. */
- def isSyntheticDefaultCase(cdef: CaseDef) = cdef match {
- case CaseDef(Bind(DefaultCaseName, _), EmptyTree, _) => true
- case _ => false
- }
- tree match {
- case Block(
- (cd@ClassDef(_, _, _, Template(_, _, body))) :: Nil,
- Apply(Select(New(a), nme.CONSTRUCTOR), Nil)) if isSynthetic(cd) =>
- val restored = (body collectFirst {
- case DefDef(_, /*name.apply | */ name.applyOrElse, _, _, _, Match(_, cases)) =>
- val nonSyntheticCases = cases.takeWhile(cdef => !isSyntheticDefaultCase(cdef))
- val transformedCases = super.transformStats(nonSyntheticCases, currentOwner).asInstanceOf[List[CaseDef]]
- Match(EmptyTree, transformedCases)
- }).getOrElse(c.abort(tree.pos, s"Internal Error: Unable to find original pattern matching cases in: $body"))
- restored
- case t => super.transform(t)
- }
- }
- }
-
- def isSafeToInline(tree: Tree) = {
- val symtab = c.universe.asInstanceOf[scala.reflect.internal.SymbolTable]
- object treeInfo extends {
- val global: symtab.type = symtab
- } with reflect.internal.TreeInfo
- val castTree = tree.asInstanceOf[symtab.Tree]
- treeInfo.isExprSafeToInline(castTree)
- }
-
- /** Map a list of arguments to:
- * - A list of argument Trees
- * - A list of auxillary results.
- *
- * The function unwraps and rewraps the `arg :_*` construct.
- *
- * @param args The original argument trees
- * @param f A function from argument (with '_*' unwrapped) and argument index to argument.
- * @tparam A The type of the auxillary result
- */
- private def mapArguments[A](args: List[Tree])(f: (Tree, Int) => (A, Tree)): (List[A], List[Tree]) = {
- args match {
- case args :+ Typed(tree, Ident(tpnme.WILDCARD_STAR)) =>
- val (a, argExprs :+ lastArgExpr) = (args :+ tree).zipWithIndex.map(f.tupled).unzip
- val exprs = argExprs :+ Typed(lastArgExpr, Ident(tpnme.WILDCARD_STAR)).setPos(lastArgExpr.pos)
- (a, exprs)
- case args =>
- args.zipWithIndex.map(f.tupled).unzip
- }
- }
-
- case class Arg(expr: Tree, isByName: Boolean, argName: String)
-
- /**
- * Transform a list of argument lists, producing the transformed lists, and lists of auxillary
- * results.
- *
- * The function `f` need not concern itself with varargs arguments e.g (`xs : _*`). It will
- * receive `xs`, and it's result will be re-wrapped as `f(xs) : _*`.
- *
- * @param fun The function being applied
- * @param argss The argument lists
- * @return (auxillary results, mapped argument trees)
- */
- def mapArgumentss[A](fun: Tree, argss: List[List[Tree]])(f: Arg => (A, Tree)): (List[List[A]], List[List[Tree]]) = {
- val isByNamess: (Int, Int) => Boolean = isByName(fun)
- val argNamess: (Int, Int) => String = argName(fun)
- argss.zipWithIndex.map { case (args, i) =>
- mapArguments[A](args) {
- (tree, j) => f(Arg(tree, isByNamess(i, j), argNamess(i, j)))
- }
- }.unzip
- }
-}
diff --git a/src/main/scala/scala/async/continuations/AsyncBaseWithCPSFallback.scala b/src/main/scala/scala/async/continuations/AsyncBaseWithCPSFallback.scala
index a669cfa..1a6ac87 100644
--- a/src/main/scala/scala/async/continuations/AsyncBaseWithCPSFallback.scala
+++ b/src/main/scala/scala/async/continuations/AsyncBaseWithCPSFallback.scala
@@ -9,8 +9,9 @@ import scala.language.experimental.macros
import scala.reflect.macros.Context
import scala.util.continuations._
+import scala.async.internal.{AsyncMacro, AsyncUtils}
-trait AsyncBaseWithCPSFallback extends AsyncBase {
+trait AsyncBaseWithCPSFallback extends internal.AsyncBase {
/* Fall-back for `await` using CPS plugin.
*
@@ -22,27 +23,34 @@ trait AsyncBaseWithCPSFallback extends AsyncBase {
/* Implements `async { ... }` using the CPS plugin.
*/
- protected def cpsBasedAsyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[futureSystem.Fut[T]] = {
+ protected def cpsBasedAsyncImpl[T: c.WeakTypeTag](c: Context)
+ (body: c.Expr[T])
+ (execContext: c.Expr[futureSystem.ExecContext]): c.Expr[futureSystem.Fut[T]] = {
import c.universe._
- def lookupMember(name: String) = {
- val asyncTrait = c.mirror.staticClass("scala.async.continuations.AsyncBaseWithCPSFallback")
+ def lookupClassMember(clazz: String, name: String) = {
+ val asyncTrait = c.mirror.staticClass(clazz)
val tpe = asyncTrait.asType.toType
- tpe.member(newTermName(name)).ensuring(_ != NoSymbol)
+ tpe.member(newTermName(name)).ensuring(_ != NoSymbol, s"$clazz.$name")
+ }
+ def lookupObjectMember(clazz: String, name: String) = {
+ val moduleClass = c.mirror.staticModule(clazz).moduleClass
+ val tpe = moduleClass.asType.toType
+ tpe.member(newTermName(name)).ensuring(_ != NoSymbol, s"$clazz.$name")
}
AsyncUtils.vprintln("AsyncBaseWithCPSFallback.cpsBasedAsyncImpl")
- val utils = TransformUtils[c.type](c)
- val futureSystemOps = futureSystem.mkOps(c)
- val awaitSym = utils.defn.Async_await
- val awaitFallbackSym = lookupMember("awaitFallback")
+ val symTab = c.universe.asInstanceOf[reflect.internal.SymbolTable]
+ val futureSystemOps = futureSystem.mkOps(symTab)
+ val awaitSym = lookupObjectMember("scala.async.Async", "await")
+ val awaitFallbackSym = lookupClassMember("scala.async.continuations.AsyncBaseWithCPSFallback", "awaitFallback")
// replace `await` invocations with `awaitFallback` invocations
val awaitReplacer = new Transformer {
override def transform(tree: Tree): Tree = tree match {
case Apply(fun @ TypeApply(_, List(futArgTpt)), args) if fun.symbol == awaitSym =>
- val typeApp = treeCopy.TypeApply(fun, Ident(awaitFallbackSym), List(TypeTree(futArgTpt.tpe)))
+ val typeApp = treeCopy.TypeApply(fun, atPos(tree.pos)(Ident(awaitFallbackSym)), List(atPos(tree.pos)(TypeTree(futArgTpt.tpe))))
treeCopy.Apply(tree, typeApp, args.map(arg => c.resetAllAttrs(arg.duplicate)))
case _ =>
super.transform(tree)
@@ -60,10 +68,12 @@ trait AsyncBaseWithCPSFallback extends AsyncBase {
}.asInstanceOf[Future[T]]
*/
+ def spawn(expr: Tree) = futureSystemOps.spawn(expr.asInstanceOf[futureSystemOps.universe.Tree], execContext.tree.asInstanceOf[futureSystemOps.universe.Tree]).asInstanceOf[Tree]
+
val bodyWithFuture = {
val tree = bodyWithAwaitFallback match {
- case Block(stmts, expr) => Block(stmts, futureSystemOps.spawn(expr))
- case expr => futureSystemOps.spawn(expr)
+ case Block(stmts, expr) => Block(stmts, spawn(expr))
+ case expr => spawn(expr)
}
c.Expr[futureSystem.Fut[Any]](c.resetAllAttrs(tree.duplicate))
}
@@ -71,20 +81,22 @@ trait AsyncBaseWithCPSFallback extends AsyncBase {
val bodyWithReset: c.Expr[futureSystem.Fut[Any]] = reify {
reset { bodyWithFuture.splice }
}
- val bodyWithCast = futureSystemOps.castTo[T](bodyWithReset)
+ val bodyWithCast = futureSystemOps.castTo[T](bodyWithReset.asInstanceOf[futureSystemOps.universe.Expr[futureSystem.Fut[Any]]]).asInstanceOf[c.Expr[futureSystem.Fut[T]]]
AsyncUtils.vprintln(s"CPS-based async transform expands to:\n${bodyWithCast.tree}")
bodyWithCast
}
- override def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[futureSystem.Fut[T]] = {
+ override def asyncImpl[T: c.WeakTypeTag](c: Context)
+ (body: c.Expr[T])
+ (execContext: c.Expr[futureSystem.ExecContext]): c.Expr[futureSystem.Fut[T]] = {
AsyncUtils.vprintln("AsyncBaseWithCPSFallback.asyncImpl")
- val analyzer = AsyncAnalysis[c.type](c, this)
+ val asyncMacro = AsyncMacro(c, futureSystem)
- if (!analyzer.reportUnsupportedAwaits(body.tree))
- super.asyncImpl[T](c)(body) // no unsupported awaits
+ if (!asyncMacro.reportUnsupportedAwaits(body.tree.asInstanceOf[asyncMacro.global.Tree], report = fallbackEnabled))
+ super.asyncImpl[T](c)(body)(execContext) // no unsupported awaits
else
- cpsBasedAsyncImpl[T](c)(body) // fallback to CPS
+ cpsBasedAsyncImpl[T](c)(body)(execContext) // fallback to CPS
}
}
diff --git a/src/main/scala/scala/async/continuations/AsyncWithCPSFallback.scala b/src/main/scala/scala/async/continuations/AsyncWithCPSFallback.scala
index fe6e1a6..e0da5aa 100644
--- a/src/main/scala/scala/async/continuations/AsyncWithCPSFallback.scala
+++ b/src/main/scala/scala/async/continuations/AsyncWithCPSFallback.scala
@@ -13,8 +13,13 @@ import scala.concurrent.Future
trait AsyncWithCPSFallback extends AsyncBaseWithCPSFallback with ScalaConcurrentCPSFallback
object AsyncWithCPSFallback extends AsyncWithCPSFallback {
+ import scala.concurrent.{ExecutionContext, Future}
- def async[T](body: T) = macro asyncImpl[T]
+ def async[T](body: T)(implicit execContext: ExecutionContext): Future[T] = macro asyncImpl[T]
- override def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[Future[T]] = super.asyncImpl[T](c)(body)
+ override def asyncImpl[T: c.WeakTypeTag](c: Context)
+ (body: c.Expr[T])
+ (execContext: c.Expr[ExecutionContext]): c.Expr[Future[T]] = {
+ super.asyncImpl[T](c)(body)(execContext)
+ }
}
diff --git a/src/main/scala/scala/async/continuations/CPSBasedAsync.scala b/src/main/scala/scala/async/continuations/CPSBasedAsync.scala
index 922d1ac..2003082 100644
--- a/src/main/scala/scala/async/continuations/CPSBasedAsync.scala
+++ b/src/main/scala/scala/async/continuations/CPSBasedAsync.scala
@@ -8,14 +8,17 @@ package continuations
import scala.language.experimental.macros
import scala.reflect.macros.Context
-import scala.concurrent.Future
+import scala.concurrent.{ExecutionContext, Future}
trait CPSBasedAsync extends CPSBasedAsyncBase with ScalaConcurrentCPSFallback
object CPSBasedAsync extends CPSBasedAsync {
- def async[T](body: T) = macro asyncImpl[T]
-
- override def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[Future[T]] = super.asyncImpl[T](c)(body)
+ def async[T](body: T)(implicit execContext: ExecutionContext): Future[T] = macro asyncImpl[T]
+ override def asyncImpl[T: c.WeakTypeTag](c: Context)
+ (body: c.Expr[T])
+ (execContext: c.Expr[ExecutionContext]): c.Expr[Future[T]] = {
+ super.asyncImpl[T](c)(body)(execContext)
+ }
}
diff --git a/src/main/scala/scala/async/continuations/CPSBasedAsyncBase.scala b/src/main/scala/scala/async/continuations/CPSBasedAsyncBase.scala
index 4e8ec80..a350704 100644
--- a/src/main/scala/scala/async/continuations/CPSBasedAsyncBase.scala
+++ b/src/main/scala/scala/async/continuations/CPSBasedAsyncBase.scala
@@ -15,7 +15,9 @@ import scala.util.continuations._
*/
trait CPSBasedAsyncBase extends AsyncBaseWithCPSFallback {
- override def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[futureSystem.Fut[T]] =
- super.cpsBasedAsyncImpl[T](c)(body)
-
+ override def asyncImpl[T: c.WeakTypeTag](c: Context)
+ (body: c.Expr[T])
+ (execContext: c.Expr[futureSystem.ExecContext]): c.Expr[futureSystem.Fut[T]] = {
+ super.cpsBasedAsyncImpl[T](c)(body)(execContext)
+ }
}
diff --git a/src/main/scala/scala/async/continuations/ScalaConcurrentCPSFallback.scala b/src/main/scala/scala/async/continuations/ScalaConcurrentCPSFallback.scala
index 018ad05..f864ad6 100644
--- a/src/main/scala/scala/async/continuations/ScalaConcurrentCPSFallback.scala
+++ b/src/main/scala/scala/async/continuations/ScalaConcurrentCPSFallback.scala
@@ -7,6 +7,7 @@ package continuations
import scala.util.continuations._
import scala.concurrent.{Future, Promise, ExecutionContext}
+import scala.async.internal.ScalaConcurrentFutureSystem
trait ScalaConcurrentCPSFallback {
self: AsyncBaseWithCPSFallback =>
diff --git a/src/main/scala/scala/async/internal/AnfTransform.scala b/src/main/scala/scala/async/internal/AnfTransform.scala
new file mode 100644
index 0000000..6aeaba3
--- /dev/null
+++ b/src/main/scala/scala/async/internal/AnfTransform.scala
@@ -0,0 +1,268 @@
+
+/*
+ * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
+ */
+
+package scala.async.internal
+
+import scala.tools.nsc.Global
+import scala.Predef._
+
+private[async] trait AnfTransform {
+ self: AsyncMacro =>
+
+ import global._
+ import reflect.internal.Flags._
+
+ def anfTransform(tree: Tree): Block = {
+ // Must prepend the () for issue #31.
+ val block = callSiteTyper.typedPos(tree.pos)(Block(List(Literal(Constant(()))), tree)).setType(tree.tpe)
+
+ new SelectiveAnfTransform().transform(block)
+ }
+
+ sealed abstract class AnfMode
+
+ case object Anf extends AnfMode
+
+ case object Linearizing extends AnfMode
+
+ final class SelectiveAnfTransform extends MacroTypingTransformer {
+ var mode: AnfMode = Anf
+
+ def blockToList(tree: Tree): List[Tree] = tree match {
+ case Block(stats, expr) => stats :+ expr
+ case t => t :: Nil
+ }
+
+ def listToBlock(trees: List[Tree]): Block = trees match {
+ case trees @ (init :+ last) =>
+ val pos = trees.map(_.pos).reduceLeft(_ union _)
+ Block(init, last).setType(last.tpe).setPos(pos)
+ }
+
+ override def transform(tree: Tree): Block = {
+ def anfLinearize: Block = {
+ val trees: List[Tree] = mode match {
+ case Anf => anf._transformToList(tree)
+ case Linearizing => linearize._transformToList(tree)
+ }
+ listToBlock(trees)
+ }
+ tree match {
+ case _: ValDef | _: DefDef | _: Function | _: ClassDef | _: TypeDef =>
+ atOwner(tree.symbol)(anfLinearize)
+ case _: ModuleDef =>
+ atOwner(tree.symbol.moduleClass orElse tree.symbol)(anfLinearize)
+ case _ =>
+ anfLinearize
+ }
+ }
+
+ private object linearize {
+ def transformToList(tree: Tree): List[Tree] = {
+ mode = Linearizing; blockToList(transform(tree))
+ }
+
+ def transformToBlock(tree: Tree): Block = listToBlock(transformToList(tree))
+
+ def _transformToList(tree: Tree): List[Tree] = trace(tree) {
+ val stats :+ expr = anf.transformToList(tree)
+ expr match {
+ case Apply(fun, args) if isAwait(fun) =>
+ val valDef = defineVal(name.await, expr, tree.pos)
+ stats :+ valDef :+ gen.mkAttributedStableRef(valDef.symbol).setType(tree.tpe).setPos(tree.pos)
+
+ case If(cond, thenp, elsep) =>
+ // if type of if-else is Unit don't introduce assignment,
+ // but add Unit value to bring it into form expected by async transform
+ if (expr.tpe =:= definitions.UnitTpe) {
+ stats :+ expr :+ localTyper.typedPos(expr.pos)(Literal(Constant(())))
+ } else {
+ val varDef = defineVar(name.ifRes, expr.tpe, tree.pos)
+ def branchWithAssign(orig: Tree) = localTyper.typedPos(orig.pos) {
+ def cast(t: Tree) = mkAttributedCastPreservingAnnotations(t, varDef.symbol.tpe)
+ orig match {
+ case Block(thenStats, thenExpr) => Block(thenStats, Assign(Ident(varDef.symbol), cast(thenExpr)))
+ case _ => Assign(Ident(varDef.symbol), cast(orig))
+ }
+ }
+ val ifWithAssign = treeCopy.If(tree, cond, branchWithAssign(thenp), branchWithAssign(elsep)).setType(definitions.UnitTpe)
+ stats :+ varDef :+ ifWithAssign :+ gen.mkAttributedStableRef(varDef.symbol).setType(tree.tpe).setPos(tree.pos)
+ }
+
+ case Match(scrut, cases) =>
+ // if type of match is Unit don't introduce assignment,
+ // but add Unit value to bring it into form expected by async transform
+ if (expr.tpe =:= definitions.UnitTpe) {
+ stats :+ expr :+ localTyper.typedPos(expr.pos)(Literal(Constant(())))
+ }
+ else {
+ val varDef = defineVar(name.matchRes, expr.tpe, tree.pos)
+ def typedAssign(lhs: Tree) =
+ localTyper.typedPos(lhs.pos)(Assign(Ident(varDef.symbol), mkAttributedCastPreservingAnnotations(lhs, varDef.symbol.tpe)))
+ val casesWithAssign = cases map {
+ case cd@CaseDef(pat, guard, body) =>
+ val newBody = body match {
+ case b@Block(caseStats, caseExpr) => treeCopy.Block(b, caseStats, typedAssign(caseExpr)).setType(definitions.UnitTpe)
+ case _ => typedAssign(body)
+ }
+ treeCopy.CaseDef(cd, pat, guard, newBody).setType(definitions.UnitTpe)
+ }
+ val matchWithAssign = treeCopy.Match(tree, scrut, casesWithAssign).setType(definitions.UnitTpe)
+ require(matchWithAssign.tpe != null, matchWithAssign)
+ stats :+ varDef :+ matchWithAssign :+ gen.mkAttributedStableRef(varDef.symbol).setPos(tree.pos).setType(tree.tpe)
+ }
+ case _ =>
+ stats :+ expr
+ }
+ }
+
+ private def defineVar(prefix: String, tp: Type, pos: Position): ValDef = {
+ val sym = currOwner.newTermSymbol(name.fresh(prefix), pos, MUTABLE | SYNTHETIC).setInfo(tp)
+ ValDef(sym, gen.mkZero(tp)).setType(NoType).setPos(pos)
+ }
+ }
+
+ private object trace {
+ private var indent = -1
+
+ def indentString = " " * indent
+
+ def apply[T](args: Any)(t: => T): T = {
+ def prefix = mode.toString.toLowerCase
+ indent += 1
+ def oneLine(s: Any) = s.toString.replaceAll( """\n""", "\\\\n").take(127)
+ try {
+ AsyncUtils.trace(s"${indentString}$prefix(${oneLine(args)})")
+ val result = t
+ AsyncUtils.trace(s"${indentString}= ${oneLine(result)}")
+ result
+ } finally {
+ indent -= 1
+ }
+ }
+ }
+
+ private def defineVal(prefix: String, lhs: Tree, pos: Position): ValDef = {
+ val sym = currOwner.newTermSymbol(name.fresh(prefix), pos, SYNTHETIC).setInfo(lhs.tpe)
+ changeOwner(lhs, currentOwner, sym)
+ ValDef(sym, changeOwner(lhs, currentOwner, sym)).setType(NoType).setPos(pos)
+ }
+
+ private object anf {
+ def transformToList(tree: Tree): List[Tree] = {
+ mode = Anf; blockToList(transform(tree))
+ }
+
+ def _transformToList(tree: Tree): List[Tree] = trace(tree) {
+ val containsAwait = tree exists isAwait
+ if (!containsAwait) {
+ List(tree)
+ } else tree match {
+ case Select(qual, sel) =>
+ val stats :+ expr = linearize.transformToList(qual)
+ stats :+ treeCopy.Select(tree, expr, sel)
+
+ case Throw(expr) =>
+ val stats :+ expr1 = linearize.transformToList(expr)
+ stats :+ treeCopy.Throw(tree, expr1)
+
+ case Typed(expr, tpt) =>
+ val stats :+ expr1 = linearize.transformToList(expr)
+ stats :+ treeCopy.Typed(tree, expr1, tpt)
+
+ case treeInfo.Applied(fun, targs, argss) if argss.nonEmpty =>
+ // we an assume that no await call appears in a by-name argument position,
+ // this has already been checked.
+ val funStats :+ simpleFun = linearize.transformToList(fun)
+ val (argStatss, argExprss): (List[List[List[Tree]]], List[List[Tree]]) =
+ mapArgumentss[List[Tree]](fun, argss) {
+ case Arg(expr, byName, _) if byName /*|| isPure(expr) TODO */ => (Nil, expr)
+ case Arg(expr, _, argName) =>
+ linearize.transformToList(expr) match {
+ case stats :+ expr1 =>
+ val valDef = defineVal(argName, expr1, expr1.pos)
+ require(valDef.tpe != null, valDef)
+ val stats1 = stats :+ valDef
+ (stats1, atPos(tree.pos.makeTransparent)(gen.stabilize(gen.mkAttributedIdent(valDef.symbol))))
+ }
+ }
+
+ def copyApplied(tree: Tree, depth: Int): Tree = {
+ tree match {
+ case TypeApply(_, targs) => treeCopy.TypeApply(tree, simpleFun, targs)
+ case _ if depth == 0 => simpleFun
+ case Apply(fun, args) =>
+ val newTypedArgs = map2(args.map(_.pos), argExprss(depth - 1))((pos, arg) => localTyper.typedPos(pos)(arg))
+ treeCopy.Apply(tree, copyApplied(fun, depth - 1), newTypedArgs)
+ }
+ }
+
+ val typedNewApply = copyApplied(tree, treeInfo.dissectApplied(tree).applyDepth)
+
+ funStats ++ argStatss.flatten.flatten :+ typedNewApply
+
+ case Block(stats, expr) =>
+ (stats :+ expr).flatMap(linearize.transformToList)
+
+ case ValDef(mods, name, tpt, rhs) =>
+ if (rhs exists isAwait) {
+ val stats :+ expr = atOwner(currOwner.owner)(linearize.transformToList(rhs))
+ stats.foreach(changeOwner(_, currOwner, currOwner.owner))
+ stats :+ treeCopy.ValDef(tree, mods, name, tpt, expr)
+ } else List(tree)
+
+ case Assign(lhs, rhs) =>
+ val stats :+ expr = linearize.transformToList(rhs)
+ stats :+ treeCopy.Assign(tree, lhs, expr)
+
+ case If(cond, thenp, elsep) =>
+ val condStats :+ condExpr = linearize.transformToList(cond)
+ val thenBlock = linearize.transformToBlock(thenp)
+ val elseBlock = linearize.transformToBlock(elsep)
+ // Typechecking with `condExpr` as the condition fails if the condition
+ // contains an await. `ifTree.setType(tree.tpe)` also fails; it seems
+ // we rely on this call to `typeCheck` descending into the branches.
+ // But, we can get away with typechecking a throwaway `If` tree with the
+ // original scrutinee and the new branches, and setting that type on
+ // the real `If` tree.
+ val iff = treeCopy.If(tree, condExpr, thenBlock, elseBlock)
+ condStats :+ iff
+
+ case Match(scrut, cases) =>
+ val scrutStats :+ scrutExpr = linearize.transformToList(scrut)
+ val caseDefs = cases map {
+ case CaseDef(pat, guard, body) =>
+ // extract local variables for all names bound in `pat`, and rewrite `body`
+ // to refer to these.
+ // TODO we can move this into ExprBuilder once we get rid of `AsyncDefinitionUseAnalyzer`.
+ val block = linearize.transformToBlock(body)
+ val (valDefs, mappings) = (pat collect {
+ case b@Bind(name, _) =>
+ val vd = defineVal(name.toTermName + AnfTransform.this.name.bindSuffix, gen.mkAttributedStableRef(b.symbol).setPos(b.pos), b.pos)
+ (vd, (b.symbol, vd.symbol))
+ }).unzip
+ val (from, to) = mappings.unzip
+ val b@Block(stats1, expr1) = block.substituteSymbols(from, to).asInstanceOf[Block]
+ val newBlock = treeCopy.Block(b, valDefs ++ stats1, expr1)
+ treeCopy.CaseDef(tree, pat, guard, newBlock)
+ }
+ // Refer to comments the translation of `If` above.
+ val typedMatch = treeCopy.Match(tree, scrutExpr, caseDefs)
+ scrutStats :+ typedMatch
+
+ case LabelDef(name, params, rhs) =>
+ List(LabelDef(name, params, Block(linearize.transformToList(rhs), Literal(Constant(())))).setSymbol(tree.symbol))
+
+ case TypeApply(fun, targs) =>
+ val funStats :+ simpleFun = linearize.transformToList(fun)
+ funStats :+ treeCopy.TypeApply(tree, simpleFun, targs)
+
+ case _ =>
+ List(tree)
+ }
+ }
+ }
+ }
+}
diff --git a/src/main/scala/scala/async/internal/AsyncAnalysis.scala b/src/main/scala/scala/async/internal/AsyncAnalysis.scala
new file mode 100644
index 0000000..122109e
--- /dev/null
+++ b/src/main/scala/scala/async/internal/AsyncAnalysis.scala
@@ -0,0 +1,94 @@
+/*
+ * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
+ */
+
+package scala.async.internal
+
+import scala.reflect.macros.Context
+import scala.collection.mutable
+
+trait AsyncAnalysis {
+ self: AsyncMacro =>
+
+ import global._
+
+ /**
+ * Analyze the contents of an `async` block in order to:
+ * - Report unsupported `await` calls under nested templates, functions, by-name arguments.
+ *
+ * Must be called on the original tree, not on the ANF transformed tree.
+ */
+ def reportUnsupportedAwaits(tree: Tree, report: Boolean): Boolean = {
+ val analyzer = new UnsupportedAwaitAnalyzer(report)
+ analyzer.traverse(tree)
+ analyzer.hasUnsupportedAwaits
+ }
+
+ private class UnsupportedAwaitAnalyzer(report: Boolean) extends AsyncTraverser {
+ var hasUnsupportedAwaits = false
+
+ override def nestedClass(classDef: ClassDef) {
+ val kind = if (classDef.symbol.isTrait) "trait" else "class"
+ reportUnsupportedAwait(classDef, s"nested ${kind}")
+ }
+
+ override def nestedModule(module: ModuleDef) {
+ reportUnsupportedAwait(module, "nested object")
+ }
+
+ override def nestedMethod(defDef: DefDef) {
+ reportUnsupportedAwait(defDef, "nested method")
+ }
+
+ override def byNameArgument(arg: Tree) {
+ reportUnsupportedAwait(arg, "by-name argument")
+ }
+
+ override def function(function: Function) {
+ reportUnsupportedAwait(function, "nested function")
+ }
+
+ override def patMatFunction(tree: Match) {
+ reportUnsupportedAwait(tree, "nested function")
+ }
+
+ override def traverse(tree: Tree) {
+ def containsAwait = tree exists isAwait
+ tree match {
+ case Try(_, _, _) if containsAwait =>
+ reportUnsupportedAwait(tree, "try/catch")
+ super.traverse(tree)
+ case Return(_) =>
+ abort(tree.pos, "return is illegal within a async block")
+ case ValDef(mods, _, _, _) if mods.hasFlag(Flag.LAZY) =>
+ // TODO lift this restriction
+ abort(tree.pos, "lazy vals are illegal within an async block")
+ case CaseDef(_, guard, _) if guard exists isAwait =>
+ // TODO lift this restriction
+ reportUnsupportedAwait(tree, "pattern guard")
+ case _ =>
+ super.traverse(tree)
+ }
+ }
+
+ /**
+ * @return true, if the tree contained an unsupported await.
+ */
+ private def reportUnsupportedAwait(tree: Tree, whyUnsupported: String): Boolean = {
+ val badAwaits: List[RefTree] = tree collect {
+ case rt: RefTree if isAwait(rt) => rt
+ }
+ badAwaits foreach {
+ tree =>
+ reportError(tree.pos, s"await must not be used under a $whyUnsupported.")
+ }
+ badAwaits.nonEmpty
+ }
+
+ private def reportError(pos: Position, msg: String) {
+ hasUnsupportedAwaits = true
+ if (report)
+ abort(pos, msg)
+ }
+ }
+}
diff --git a/src/main/scala/scala/async/internal/AsyncBase.scala b/src/main/scala/scala/async/internal/AsyncBase.scala
new file mode 100644
index 0000000..ca06039
--- /dev/null
+++ b/src/main/scala/scala/async/internal/AsyncBase.scala
@@ -0,0 +1,61 @@
+/*
+ * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
+ */
+
+package scala.async.internal
+
+import scala.reflect.internal.annotations.compileTimeOnly
+import scala.reflect.macros.Context
+
+/**
+ * A base class for the `async` macro. Subclasses must provide:
+ *
+ * - Concrete types for a given future system
+ * - Tree manipulations to create and complete the equivalent of Future and Promise
+ * in that system.
+ * - The `async` macro declaration itself, and a forwarder for the macro implementation.
+ * (The latter is temporarily needed to workaround bug SI-6650 in the macro system)
+ *
+ * The default implementation, [[scala.async.Async]], binds the macro to `scala.concurrent._`.
+ */
+abstract class AsyncBase {
+ self =>
+
+ type FS <: FutureSystem
+ val futureSystem: FS
+
+ /**
+ * A call to `await` must be nested in an enclosing `async` block.
+ *
+ * A call to `await` does not block the current thread, rather it is a delimiter
+ * used by the enclosing `async` macro. Code following the `await`
+ * call is executed asynchronously, when the argument of `await` has been completed.
+ *
+ * @param awaitable the future from which a value is awaited.
+ * @tparam T the type of that value.
+ * @return the value.
+ */
+ @compileTimeOnly("`await` must be enclosed in an `async` block")
+ def await[T](awaitable: futureSystem.Fut[T]): T = ???
+
+ protected[async] def fallbackEnabled = false
+
+ def asyncImpl[T: c.WeakTypeTag](c: Context)
+ (body: c.Expr[T])
+ (execContext: c.Expr[futureSystem.ExecContext]): c.Expr[futureSystem.Fut[T]] = {
+ import c.universe._
+
+ val asyncMacro = AsyncMacro(c, futureSystem)
+
+ val code = asyncMacro.asyncTransform[T](
+ body.tree.asInstanceOf[asyncMacro.global.Tree],
+ execContext.tree.asInstanceOf[asyncMacro.global.Tree],
+ fallbackEnabled)(implicitly[c.WeakTypeTag[T]].asInstanceOf[asyncMacro.global.WeakTypeTag[T]]).asInstanceOf[Tree]
+
+ for (t <- code)
+ t.pos = t.pos.makeTransparent
+
+ AsyncUtils.vprintln(s"async state machine transform expands to:\n ${code}")
+ c.Expr[futureSystem.Fut[T]](code)
+ }
+}
diff --git a/src/main/scala/scala/async/internal/AsyncId.scala b/src/main/scala/scala/async/internal/AsyncId.scala
new file mode 100644
index 0000000..4334088
--- /dev/null
+++ b/src/main/scala/scala/async/internal/AsyncId.scala
@@ -0,0 +1,66 @@
+/*
+ * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
+ */
+
+package scala.async.internal
+
+import language.experimental.macros
+import scala.reflect.macros.Context
+import scala.reflect.internal.SymbolTable
+
+object AsyncId extends AsyncBase {
+ lazy val futureSystem = IdentityFutureSystem
+ type FS = IdentityFutureSystem.type
+
+ def async[T](body: T) = macro asyncIdImpl[T]
+
+ def asyncIdImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[T] = asyncImpl[T](c)(body)(c.literalUnit)
+}
+
+/**
+ * A trivial implementation of [[FutureSystem]] that performs computations
+ * on the current thread. Useful for testing.
+ */
+object IdentityFutureSystem extends FutureSystem {
+
+ class Prom[A] {
+ var a: A = _
+ }
+
+ type Fut[A] = A
+ type ExecContext = Unit
+
+ def mkOps(c: SymbolTable): Ops {val universe: c.type} = new Ops {
+ val universe: c.type = c
+
+ import universe._
+
+ def execContext: Expr[ExecContext] = Expr[Unit](Literal(Constant(())))
+
+ def promType[A: WeakTypeTag]: Type = weakTypeOf[Prom[A]]
+ def execContextType: Type = weakTypeOf[Unit]
+
+ def createProm[A: WeakTypeTag]: Expr[Prom[A]] = reify {
+ new Prom()
+ }
+
+ def promiseToFuture[A: WeakTypeTag](prom: Expr[Prom[A]]) = reify {
+ prom.splice.a
+ }
+
+ def future[A: WeakTypeTag](t: Expr[A])(execContext: Expr[ExecContext]) = t
+
+ def onComplete[A, U](future: Expr[Fut[A]], fun: Expr[scala.util.Try[A] => U],
+ execContext: Expr[ExecContext]): Expr[Unit] = reify {
+ fun.splice.apply(util.Success(future.splice))
+ Expr[Unit](Literal(Constant(()))).splice
+ }
+
+ def completeProm[A](prom: Expr[Prom[A]], value: Expr[scala.util.Try[A]]): Expr[Unit] = reify {
+ prom.splice.a = value.splice.get
+ Expr[Unit](Literal(Constant(()))).splice
+ }
+
+ def castTo[A: WeakTypeTag](future: Expr[Fut[Any]]): Expr[Fut[A]] = ???
+ }
+}
diff --git a/src/main/scala/scala/async/internal/AsyncMacro.scala b/src/main/scala/scala/async/internal/AsyncMacro.scala
new file mode 100644
index 0000000..23cc611
--- /dev/null
+++ b/src/main/scala/scala/async/internal/AsyncMacro.scala
@@ -0,0 +1,32 @@
+package scala.async.internal
+
+import scala.tools.nsc.Global
+import scala.tools.nsc.transform.TypingTransformers
+
+object AsyncMacro {
+ def apply(c: reflect.macros.Context, futureSystem0: FutureSystem): AsyncMacro = {
+ import language.reflectiveCalls
+ val powerContext = c.asInstanceOf[c.type {val universe: Global; val callsiteTyper: universe.analyzer.Typer}]
+ new AsyncMacro {
+ val global: powerContext.universe.type = powerContext.universe
+ val callSiteTyper: global.analyzer.Typer = powerContext.callsiteTyper
+ val futureSystem: futureSystem0.type = futureSystem0
+ val futureSystemOps: futureSystem.Ops {val universe: global.type} = futureSystem0.mkOps(global)
+ val macroApplication: global.Tree = c.macroApplication.asInstanceOf[global.Tree]
+ }
+ }
+}
+
+private[async] trait AsyncMacro
+ extends TypingTransformers
+ with AnfTransform with TransformUtils with Lifter
+ with ExprBuilder with AsyncTransform with AsyncAnalysis {
+
+ val global: Global
+ val callSiteTyper: global.analyzer.Typer
+ val macroApplication: global.Tree
+
+ def macroPos = macroApplication.pos.makeTransparent
+ def atMacroPos(t: global.Tree) = global.atPos(macroPos)(t)
+
+}
diff --git a/src/main/scala/scala/async/internal/AsyncTransform.scala b/src/main/scala/scala/async/internal/AsyncTransform.scala
new file mode 100644
index 0000000..c755c87
--- /dev/null
+++ b/src/main/scala/scala/async/internal/AsyncTransform.scala
@@ -0,0 +1,177 @@
+package scala.async.internal
+
+trait AsyncTransform {
+ self: AsyncMacro =>
+
+ import global._
+
+ def asyncTransform[T](body: Tree, execContext: Tree, cpsFallbackEnabled: Boolean)
+ (implicit resultType: WeakTypeTag[T]): Tree = {
+
+ reportUnsupportedAwaits(body, report = !cpsFallbackEnabled)
+
+ // Transform to A-normal form:
+ // - no await calls in qualifiers or arguments,
+ // - if/match only used in statement position.
+ val anfTree: Block = anfTransform(body)
+
+ val resumeFunTreeDummyBody = DefDef(Modifiers(), name.resume, Nil, List(Nil), Ident(definitions.UnitClass), Literal(Constant(())))
+
+ val applyDefDefDummyBody: DefDef = {
+ val applyVParamss = List(List(ValDef(Modifiers(Flag.PARAM), name.tr, TypeTree(defn.TryAnyType), EmptyTree)))
+ DefDef(NoMods, name.apply, Nil, applyVParamss, TypeTree(definitions.UnitTpe), Literal(Constant(())))
+ }
+
+ val stateMachineType = applied("scala.async.StateMachine", List(futureSystemOps.promType[T], futureSystemOps.execContextType))
+
+ val stateMachine: ClassDef = {
+ val body: List[Tree] = {
+ val stateVar = ValDef(Modifiers(Flag.MUTABLE | Flag.PRIVATE | Flag.LOCAL), name.state, TypeTree(definitions.IntTpe), Literal(Constant(0)))
+ val result = ValDef(NoMods, name.result, TypeTree(futureSystemOps.promType[T]), futureSystemOps.createProm[T].tree)
+ val execContextValDef = ValDef(NoMods, name.execContext, TypeTree(), execContext)
+
+ val apply0DefDef: DefDef = {
+ // We extend () => Unit so we can pass this class as the by-name argument to `Future.apply`.
+ // See SI-1247 for the the optimization that avoids creatio
+ DefDef(NoMods, name.apply, Nil, Nil, TypeTree(definitions.UnitTpe), Apply(Ident(name.resume), Nil))
+ }
+ List(emptyConstructor, stateVar, result, execContextValDef) ++ List(resumeFunTreeDummyBody, applyDefDefDummyBody, apply0DefDef)
+ }
+ val template = {
+ Template(List(stateMachineType), emptyValDef, body)
+ }
+ val t = ClassDef(NoMods, name.stateMachineT, Nil, template)
+ callSiteTyper.typedPos(macroPos)(Block(t :: Nil, Literal(Constant(()))))
+ t
+ }
+
+ val asyncBlock: AsyncBlock = {
+ val symLookup = new SymLookup(stateMachine.symbol, applyDefDefDummyBody.vparamss.head.head.symbol)
+ buildAsyncBlock(anfTree, symLookup)
+ }
+
+ logDiagnostics(anfTree, asyncBlock.asyncStates.map(_.toString))
+
+ def startStateMachine: Tree = {
+ val stateMachineSpliced: Tree = spliceMethodBodies(
+ liftables(asyncBlock.asyncStates),
+ stateMachine,
+ atMacroPos(asyncBlock.onCompleteHandler[T]),
+ atMacroPos(asyncBlock.resumeFunTree[T].rhs)
+ )
+
+ def selectStateMachine(selection: TermName) = Select(Ident(name.stateMachine), selection)
+
+ Block(List[Tree](
+ stateMachineSpliced,
+ ValDef(NoMods, name.stateMachine, stateMachineType, Apply(Select(New(Ident(stateMachine.symbol)), nme.CONSTRUCTOR), Nil)),
+ futureSystemOps.spawn(Apply(selectStateMachine(name.apply), Nil), selectStateMachine(name.execContext))
+ ),
+ futureSystemOps.promiseToFuture(Expr[futureSystem.Prom[T]](selectStateMachine(name.result))).tree)
+ }
+
+ val isSimple = asyncBlock.asyncStates.size == 1
+ if (isSimple)
+ futureSystemOps.spawn(body, execContext) // generate lean code for the simple case of `async { 1 + 1 }`
+ else
+ startStateMachine
+ }
+
+ def logDiagnostics(anfTree: Tree, states: Seq[String]) {
+ def location = try {
+ macroPos.source.path
+ } catch {
+ case _: UnsupportedOperationException =>
+ macroPos.toString
+ }
+
+ AsyncUtils.vprintln(s"In file '$location':")
+ AsyncUtils.vprintln(s"${macroApplication}")
+ AsyncUtils.vprintln(s"ANF transform expands to:\n $anfTree")
+ states foreach (s => AsyncUtils.vprintln(s))
+ }
+
+ def spliceMethodBodies(liftables: List[Tree], tree: Tree, applyBody: Tree,
+ resumeBody: Tree): Tree = {
+
+ val liftedSyms = liftables.map(_.symbol).toSet
+ val stateMachineClass = tree.symbol
+ liftedSyms.foreach {
+ sym =>
+ if (sym != null) {
+ sym.owner = stateMachineClass
+ if (sym.isModule)
+ sym.moduleClass.owner = stateMachineClass
+ }
+ }
+ // Replace the ValDefs in the splicee with Assigns to the corresponding lifted
+ // fields. Similarly, replace references to them with references to the field.
+ //
+ // This transform will be only be run on the RHS of `def foo`.
+ class UseFields extends MacroTypingTransformer {
+ override def transform(tree: Tree): Tree = tree match {
+ case _ if currentOwner == stateMachineClass =>
+ super.transform(tree)
+ case ValDef(_, _, _, rhs) if liftedSyms(tree.symbol) =>
+ atOwner(currentOwner) {
+ val fieldSym = tree.symbol
+ val set = Assign(gen.mkAttributedStableRef(fieldSym.owner.thisType, fieldSym), transform(rhs))
+ changeOwner(set, tree.symbol, currentOwner)
+ localTyper.typedPos(tree.pos)(set)
+ }
+ case _: DefTree if liftedSyms(tree.symbol) =>
+ EmptyTree
+ case Ident(name) if liftedSyms(tree.symbol) =>
+ val fieldSym = tree.symbol
+ atPos(tree.pos) {
+ gen.mkAttributedStableRef(fieldSym.owner.thisType, fieldSym).setType(tree.tpe)
+ }
+ case _ =>
+ super.transform(tree)
+ }
+ }
+
+ val liftablesUseFields = liftables.map {
+ case vd: ValDef => vd
+ case x =>
+ val useField = new UseFields()
+ //.substituteSymbols(fromSyms, toSyms)
+ useField.atOwner(stateMachineClass)(useField.transform(x))
+ }
+
+ tree.children.foreach {
+ t =>
+ new ChangeOwnerAndModuleClassTraverser(callSiteTyper.context.owner, tree.symbol).traverse(t)
+ }
+ val treeSubst = tree
+
+ def fixup(dd: DefDef, body: Tree, ctx: analyzer.Context): Tree = {
+ val spliceeAnfFixedOwnerSyms = body
+ val useField = new UseFields()
+ val newRhs = useField.atOwner(dd.symbol)(useField.transform(spliceeAnfFixedOwnerSyms))
+ val typer = global.analyzer.newTyper(ctx.make(dd, dd.symbol))
+ treeCopy.DefDef(dd, dd.mods, dd.name, dd.tparams, dd.vparamss, dd.tpt, typer.typed(newRhs))
+ }
+
+ liftablesUseFields.foreach(t => if (t.symbol != null) stateMachineClass.info.decls.enter(t.symbol))
+
+ val result0 = transformAt(treeSubst) {
+ case t@Template(parents, self, stats) =>
+ (ctx: analyzer.Context) => {
+ treeCopy.Template(t, parents, self, liftablesUseFields ++ stats)
+ }
+ }
+ val result = transformAt(result0) {
+ case dd@DefDef(_, name.apply, _, List(List(_)), _, _) if dd.symbol.owner == stateMachineClass =>
+ (ctx: analyzer.Context) =>
+ val typedTree = fixup(dd, changeOwner(applyBody, callSiteTyper.context.owner, dd.symbol), ctx)
+ typedTree
+ case dd@DefDef(_, name.resume, _, _, _, _) if dd.symbol.owner == stateMachineClass =>
+ (ctx: analyzer.Context) =>
+ val changed = changeOwner(resumeBody, callSiteTyper.context.owner, dd.symbol)
+ val res = fixup(dd, changed, ctx)
+ res
+ }
+ result
+ }
+}
diff --git a/src/main/scala/scala/async/AsyncUtils.scala b/src/main/scala/scala/async/internal/AsyncUtils.scala
index 1ade5f0..8700bd6 100644
--- a/src/main/scala/scala/async/AsyncUtils.scala
+++ b/src/main/scala/scala/async/internal/AsyncUtils.scala
@@ -1,7 +1,7 @@
/*
* Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
*/
-package scala.async
+package scala.async.internal
object AsyncUtils {
diff --git a/src/main/scala/scala/async/ExprBuilder.scala b/src/main/scala/scala/async/internal/ExprBuilder.scala
index ca46a83..e0da874 100644
--- a/src/main/scala/scala/async/ExprBuilder.scala
+++ b/src/main/scala/scala/async/internal/ExprBuilder.scala
@@ -1,23 +1,24 @@
/*
* Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
*/
-package scala.async
+package scala.async.internal
import scala.reflect.macros.Context
import scala.collection.mutable.ListBuffer
import collection.mutable
import language.existentials
+import scala.reflect.api.Universe
+import scala.reflect.api
+import scala.Some
-private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: C, futureSystem: FS, origTree: C#Tree) {
- builder =>
+trait ExprBuilder {
+ builder: AsyncMacro =>
- val utils = TransformUtils[c.type](c)
-
- import c.universe._
- import utils._
+ import global._
import defn._
- lazy val futureSystemOps = futureSystem.mkOps(c)
+ val futureSystem: FutureSystem
+ val futureSystemOps: futureSystem.Ops { val universe: global.type }
val stateAssigner = new StateAssigner
val labelDefStates = collection.mutable.Map[Symbol, Int]()
@@ -27,22 +28,27 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c:
def mkHandlerCaseForState: CaseDef
- def mkOnCompleteHandler[T: c.WeakTypeTag]: Option[CaseDef] = None
+ def mkOnCompleteHandler[T: WeakTypeTag]: Option[CaseDef] = None
def stats: List[Tree]
- final def body: c.Tree = stats match {
+ final def allStats: List[Tree] = this match {
+ case a: AsyncStateWithAwait => stats :+ a.awaitable.resultValDef
+ case _ => stats
+ }
+
+ final def body: Tree = stats match {
case stat :: Nil => stat
case init :+ last => Block(init, last)
}
}
/** A sequence of statements the concludes with a unconditional transition to `nextState` */
- final class SimpleAsyncState(val stats: List[Tree], val state: Int, nextState: Int)
+ final class SimpleAsyncState(val stats: List[Tree], val state: Int, nextState: Int, symLookup: SymLookup)
extends AsyncState {
def mkHandlerCaseForState: CaseDef =
- mkHandlerCase(state, stats :+ mkStateTree(nextState) :+ mkResumeApply)
+ mkHandlerCase(state, stats :+ mkStateTree(nextState, symLookup) :+ mkResumeApply(symLookup))
override val toString: String =
s"AsyncState #$state, next = $nextState"
@@ -51,7 +57,7 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c:
/** A sequence of statements with a conditional transition to the next state, which will represent
* a branch of an `if` or a `match`.
*/
- final class AsyncStateWithoutAwait(val stats: List[c.Tree], val state: Int) extends AsyncState {
+ final class AsyncStateWithoutAwait(val stats: List[Tree], val state: Int) extends AsyncState {
override def mkHandlerCaseForState: CaseDef =
mkHandlerCase(state, stats)
@@ -62,25 +68,25 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c:
/** A sequence of statements that concludes with an `await` call. The `onComplete`
* handler will unconditionally transition to `nestState`.``
*/
- final class AsyncStateWithAwait(val stats: List[c.Tree], val state: Int, nextState: Int,
- awaitable: Awaitable)
+ final class AsyncStateWithAwait(val stats: List[Tree], val state: Int, nextState: Int,
+ val awaitable: Awaitable, symLookup: SymLookup)
extends AsyncState {
override def mkHandlerCaseForState: CaseDef = {
- val callOnComplete = futureSystemOps.onComplete(c.Expr(awaitable.expr),
- c.Expr(This(tpnme.EMPTY)), c.Expr(Ident(name.execContext))).tree
+ val callOnComplete = futureSystemOps.onComplete(Expr(awaitable.expr),
+ Expr(This(tpnme.EMPTY)), Expr(Ident(name.execContext))).tree
mkHandlerCase(state, stats :+ callOnComplete)
}
- override def mkOnCompleteHandler[T: c.WeakTypeTag]: Option[CaseDef] = {
+ override def mkOnCompleteHandler[T: WeakTypeTag]: Option[CaseDef] = {
val tryGetTree =
Assign(
Ident(awaitable.resultName),
- TypeApply(Select(Select(Ident(name.tr), Try_get), newTermName("asInstanceOf")), List(TypeTree(awaitable.resultType)))
+ TypeApply(Select(Select(Ident(symLookup.applyTrParam), Try_get), newTermName("asInstanceOf")), List(TypeTree(awaitable.resultType)))
)
/* if (tr.isFailure)
- * result$async.complete(tr.asInstanceOf[Try[T]])
+ * result.complete(tr.asInstanceOf[Try[T]])
* else {
* <resultName> = tr.get.asInstanceOf[<resultType>]
* <nextState>
@@ -88,13 +94,13 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c:
* }
*/
val ifIsFailureTree =
- If(Select(Ident(name.tr), Try_isFailure),
+ If(Select(Ident(symLookup.applyTrParam), Try_isFailure),
futureSystemOps.completeProm[T](
- c.Expr[futureSystem.Prom[T]](Ident(name.result)),
- c.Expr[scala.util.Try[T]](
- TypeApply(Select(Ident(name.tr), newTermName("asInstanceOf")),
+ Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)),
+ Expr[scala.util.Try[T]](
+ TypeApply(Select(Ident(symLookup.applyTrParam), newTermName("asInstanceOf")),
List(TypeTree(weakTypeOf[scala.util.Try[T]]))))).tree,
- Block(List(tryGetTree, mkStateTree(nextState)), mkResumeApply)
+ Block(List(tryGetTree, mkStateTree(nextState, symLookup)), mkResumeApply(symLookup))
)
Some(mkHandlerCase(state, List(ifIsFailureTree)))
@@ -107,19 +113,16 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c:
/*
* Builder for a single state of an async method.
*/
- final class AsyncStateBuilder(state: Int, private val nameMap: Map[Symbol, c.Name]) {
+ final class AsyncStateBuilder(state: Int, private val symLookup: SymLookup) {
/* Statements preceding an await call. */
- private val stats = ListBuffer[c.Tree]()
+ private val stats = ListBuffer[Tree]()
/** The state of the target of a LabelDef application (while loop jump) */
private var nextJumpState: Option[Int] = None
- private def renameReset(tree: Tree) = resetInternalAttrs(substituteNames(tree, nameMap))
-
- def +=(stat: c.Tree): this.type = {
+ def +=(stat: Tree): this.type = {
assert(nextJumpState.isEmpty, s"statement appeared after a label jump: $stat")
- def addStat() = stats += renameReset(stat)
+ def addStat() = stats += stat
stat match {
- case _: DefDef => // these have been lifted.
case Apply(fun, Nil) =>
labelDefStates get fun.symbol match {
case Some(nextState) => nextJumpState = Some(nextState)
@@ -132,22 +135,18 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c:
def resultWithAwait(awaitable: Awaitable,
nextState: Int): AsyncState = {
- val sanitizedAwaitable = awaitable.copy(expr = renameReset(awaitable.expr))
val effectiveNextState = nextJumpState.getOrElse(nextState)
- new AsyncStateWithAwait(stats.toList, state, effectiveNextState, sanitizedAwaitable)
+ new AsyncStateWithAwait(stats.toList, state, effectiveNextState, awaitable, symLookup)
}
def resultSimple(nextState: Int): AsyncState = {
val effectiveNextState = nextJumpState.getOrElse(nextState)
- new SimpleAsyncState(stats.toList, state, effectiveNextState)
+ new SimpleAsyncState(stats.toList, state, effectiveNextState, symLookup)
}
- def resultWithIf(condTree: c.Tree, thenState: Int, elseState: Int): AsyncState = {
- // 1. build changed if-else tree
- // 2. insert that tree at the end of the current state
- val cond = renameReset(condTree)
- def mkBranch(state: Int) = Block(mkStateTree(state) :: Nil, mkResumeApply)
- this += If(cond, mkBranch(thenState), mkBranch(elseState))
+ def resultWithIf(condTree: Tree, thenState: Int, elseState: Int): AsyncState = {
+ def mkBranch(state: Int) = Block(mkStateTree(state, symLookup) :: Nil, mkResumeApply(symLookup))
+ this += If(condTree, mkBranch(thenState), mkBranch(elseState))
new AsyncStateWithoutAwait(stats.toList, state)
}
@@ -161,23 +160,20 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c:
* @param caseStates starting state of the right-hand side of the each case
* @return an `AsyncState` representing the match expression
*/
- def resultWithMatch(scrutTree: c.Tree, cases: List[CaseDef], caseStates: List[Int]): AsyncState = {
+ def resultWithMatch(scrutTree: Tree, cases: List[CaseDef], caseStates: List[Int], symLookup: SymLookup): AsyncState = {
// 1. build list of changed cases
val newCases = for ((cas, num) <- cases.zipWithIndex) yield cas match {
case CaseDef(pat, guard, rhs) =>
- val bindAssigns = rhs.children.takeWhile(isSyntheticBindVal).map {
- case ValDef(_, name, _, rhs) => Assign(Ident(name), rhs)
- case t => sys.error(s"Unexpected tree. Expected ValDef, found: $t")
- }
- CaseDef(pat, guard, Block(bindAssigns :+ mkStateTree(caseStates(num)), mkResumeApply))
+ val bindAssigns = rhs.children.takeWhile(isSyntheticBindVal)
+ CaseDef(pat, guard, Block(bindAssigns :+ mkStateTree(caseStates(num), symLookup), mkResumeApply(symLookup)))
}
// 2. insert changed match tree at the end of the current state
- this += Match(renameReset(scrutTree), newCases)
+ this += Match(scrutTree, newCases)
new AsyncStateWithoutAwait(stats.toList, state)
}
- def resultWithLabel(startLabelState: Int): AsyncState = {
- this += Block(mkStateTree(startLabelState) :: Nil, mkResumeApply)
+ def resultWithLabel(startLabelState: Int, symLookup: SymLookup): AsyncState = {
+ this += Block(mkStateTree(startLabelState, symLookup) :: Nil, mkResumeApply(symLookup))
new AsyncStateWithoutAwait(stats.toList, state)
}
@@ -194,24 +190,22 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c:
* @param expr the last expression of the block
* @param startState the start state
* @param endState the state to continue with
- * @param toRename a `Map` for renaming the given key symbols to the mangled value names
*/
- final private class AsyncBlockBuilder(stats: List[c.Tree], expr: c.Tree, startState: Int, endState: Int,
- private val toRename: Map[Symbol, c.Name]) {
+ final private class AsyncBlockBuilder(stats: List[Tree], expr: Tree, startState: Int, endState: Int,
+ private val symLookup: SymLookup) {
val asyncStates = ListBuffer[AsyncState]()
- var stateBuilder = new AsyncStateBuilder(startState, toRename)
+ var stateBuilder = new AsyncStateBuilder(startState, symLookup)
var currState = startState
- /* TODO Fall back to CPS plug-in if tree contains an `await` call. */
- def checkForUnsupportedAwait(tree: c.Tree) = if (tree exists {
+ def checkForUnsupportedAwait(tree: Tree) = if (tree exists {
case Apply(fun, _) if isAwait(fun) => true
case _ => false
- }) c.abort(tree.pos, "await must not be used in this position") //throw new FallbackToCpsException
+ }) abort(tree.pos, "await must not be used in this position")
def nestedBlockBuilder(nestedTree: Tree, startState: Int, endState: Int) = {
val (nestedStats, nestedExpr) = statsAndExpr(nestedTree)
- new AsyncBlockBuilder(nestedStats, nestedExpr, startState, endState, toRename)
+ new AsyncBlockBuilder(nestedStats, nestedExpr, startState, endState, symLookup)
}
import stateAssigner.nextState
@@ -219,16 +213,12 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c:
// populate asyncStates
for (stat <- stats) stat match {
// the val name = await(..) pattern
- case ValDef(mods, name, tpt, Apply(fun, arg :: Nil)) if isAwait(fun) =>
+ case vd @ ValDef(mods, name, tpt, Apply(fun, arg :: Nil)) if isAwait(fun) =>
val afterAwaitState = nextState()
- val awaitable = Awaitable(arg, toRename(stat.symbol).toTermName, tpt.tpe)
+ val awaitable = Awaitable(arg, stat.symbol, tpt.tpe, vd)
asyncStates += stateBuilder.resultWithAwait(awaitable, afterAwaitState) // complete with await
currState = afterAwaitState
- stateBuilder = new AsyncStateBuilder(currState, toRename)
-
- case ValDef(mods, name, tpt, rhs) if toRename contains stat.symbol =>
- checkForUnsupportedAwait(rhs)
- stateBuilder += Assign(Ident(toRename(stat.symbol).toTermName), rhs)
+ stateBuilder = new AsyncStateBuilder(currState, symLookup)
case If(cond, thenp, elsep) if stat exists isAwait =>
checkForUnsupportedAwait(cond)
@@ -248,7 +238,7 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c:
}
currState = afterIfState
- stateBuilder = new AsyncStateBuilder(currState, toRename)
+ stateBuilder = new AsyncStateBuilder(currState, symLookup)
case Match(scrutinee, cases) if stat exists isAwait =>
checkForUnsupportedAwait(scrutinee)
@@ -257,7 +247,7 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c:
val afterMatchState = nextState()
asyncStates +=
- stateBuilder.resultWithMatch(scrutinee, cases, caseStates)
+ stateBuilder.resultWithMatch(scrutinee, cases, caseStates, symLookup)
for ((cas, num) <- cases.zipWithIndex) {
val (stats, expr) = statsAndExpr(cas.body)
@@ -267,18 +257,18 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c:
}
currState = afterMatchState
- stateBuilder = new AsyncStateBuilder(currState, toRename)
+ stateBuilder = new AsyncStateBuilder(currState, symLookup)
case ld@LabelDef(name, params, rhs) if rhs exists isAwait =>
val startLabelState = nextState()
val afterLabelState = nextState()
- asyncStates += stateBuilder.resultWithLabel(startLabelState)
+ asyncStates += stateBuilder.resultWithLabel(startLabelState, symLookup)
labelDefStates(ld.symbol) = startLabelState
val builder = nestedBlockBuilder(rhs, startLabelState, afterLabelState)
asyncStates ++= builder.asyncStates
currState = afterLabelState
- stateBuilder = new AsyncStateBuilder(currState, toRename)
+ stateBuilder = new AsyncStateBuilder(currState, symLookup)
case _ =>
checkForUnsupportedAwait(stat)
stateBuilder += stat
@@ -292,17 +282,23 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c:
trait AsyncBlock {
def asyncStates: List[AsyncState]
- def onCompleteHandler[T: c.WeakTypeTag]: Tree
+ def onCompleteHandler[T: WeakTypeTag]: Tree
+
+ def resumeFunTree[T]: DefDef
+ }
- def resumeFunTree[T]: Tree
+ case class SymLookup(stateMachineClass: Symbol, applyTrParam: Symbol) {
+ def stateMachineMember(name: TermName): Symbol =
+ stateMachineClass.info.member(name)
+ def memberRef(name: TermName) = gen.mkAttributedRef(stateMachineMember(name))
}
- def build(block: Block, toRename: Map[Symbol, c.Name]): AsyncBlock = {
+ def buildAsyncBlock(block: Block, symLookup: SymLookup): AsyncBlock = {
val Block(stats, expr) = block
val startState = stateAssigner.nextState()
val endState = Int.MaxValue
- val blockBuilder = new AsyncBlockBuilder(stats, expr, startState, endState, toRename)
+ val blockBuilder = new AsyncBlockBuilder(stats, expr, startState, endState, symLookup)
new AsyncBlock {
def asyncStates = blockBuilder.asyncStates.toList
@@ -310,9 +306,9 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c:
def mkCombinedHandlerCases[T]: List[CaseDef] = {
val caseForLastState: CaseDef = {
val lastState = asyncStates.last
- val lastStateBody = c.Expr[T](lastState.body)
+ val lastStateBody = Expr[T](lastState.body)
val rhs = futureSystemOps.completeProm(
- c.Expr[futureSystem.Prom[T]](Ident(name.result)), reify(scala.util.Success(lastStateBody.splice)))
+ Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), reify(scala.util.Success(lastStateBody.splice)))
mkHandlerCase(lastState.state, rhs.tree)
}
asyncStates.toList match {
@@ -327,18 +323,6 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c:
val initStates = asyncStates.init
/**
- * // assumes tr: Try[Any] is in scope.
- * //
- * state match {
- * case 0 => {
- * x11 = tr.get.asInstanceOf[Double];
- * state = 1;
- * resume()
- * }
- */
- def onCompleteHandler[T: c.WeakTypeTag]: Tree = Match(Ident(name.state), initStates.flatMap(_.mkOnCompleteHandler[T]).toList)
-
- /**
* def resume(): Unit = {
* try {
* state match {
@@ -353,18 +337,31 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c:
* }
* }
*/
- def resumeFunTree[T]: Tree =
+ def resumeFunTree[T]: DefDef =
DefDef(Modifiers(), name.resume, Nil, List(Nil), Ident(definitions.UnitClass),
Try(
- Match(Ident(name.state), mkCombinedHandlerCases[T]),
+ Match(symLookup.memberRef(name.state), mkCombinedHandlerCases[T]),
List(
CaseDef(
- Apply(Ident(defn.NonFatalClass), List(Bind(name.tr, Ident(nme.WILDCARD)))),
- EmptyTree,
+ Bind(name.t, Ident(nme.WILDCARD)),
+ Apply(Ident(defn.NonFatalClass), List(Ident(name.t))),
Block(List({
- val t = c.Expr[Throwable](Ident(name.tr))
- futureSystemOps.completeProm[T](c.Expr[futureSystem.Prom[T]](Ident(name.result)), reify(scala.util.Failure(t.splice))).tree
- }), c.literalUnit.tree))), EmptyTree))
+ val t = Expr[Throwable](Ident(name.t))
+ futureSystemOps.completeProm[T](
+ Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), reify(scala.util.Failure(t.splice))).tree
+ }), literalUnit))), EmptyTree))
+
+ /**
+ * // assumes tr: Try[Any] is in scope.
+ * //
+ * state match {
+ * case 0 => {
+ * x11 = tr.get.asInstanceOf[Double];
+ * state = 1;
+ * resume()
+ * }
+ */
+ def onCompleteHandler[T: WeakTypeTag]: Tree = Match(symLookup.memberRef(name.state), initStates.flatMap(_.mkOnCompleteHandler[T]).toList)
}
}
@@ -373,22 +370,18 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c:
case _ => false
}
- private final case class Awaitable(expr: Tree, resultName: TermName, resultType: Type)
-
- private val internalSyms = origTree.collect {
- case dt: DefTree => dt.symbol
- }
+ case class Awaitable(expr: Tree, resultName: Symbol, resultType: Type, resultValDef: ValDef)
- private def resetInternalAttrs(tree: Tree) = utils.resetInternalAttrs(tree, internalSyms)
+ private def mkResumeApply(symLookup: SymLookup) = Apply(symLookup.memberRef(name.resume), Nil)
- private def mkResumeApply = Apply(Ident(name.resume), Nil)
+ private def mkStateTree(nextState: Int, symLookup: SymLookup): Tree =
+ Assign(symLookup.memberRef(name.state), Literal(Constant(nextState)))
- private def mkStateTree(nextState: Int): c.Tree =
- Assign(Ident(name.state), c.literal(nextState).tree)
+ private def mkHandlerCase(num: Int, rhs: List[Tree]): CaseDef =
+ mkHandlerCase(num, Block(rhs, literalUnit))
- private def mkHandlerCase(num: Int, rhs: List[c.Tree]): CaseDef =
- mkHandlerCase(num, Block(rhs, c.literalUnit.tree))
+ private def mkHandlerCase(num: Int, rhs: Tree): CaseDef =
+ CaseDef(Literal(Constant(num)), EmptyTree, rhs)
- private def mkHandlerCase(num: Int, rhs: c.Tree): CaseDef =
- CaseDef(c.literal(num).tree, EmptyTree, rhs)
+ private def literalUnit = Literal(Constant(()))
}
diff --git a/src/main/scala/scala/async/FutureSystem.scala b/src/main/scala/scala/async/internal/FutureSystem.scala
index a050bec..101b7bf 100644
--- a/src/main/scala/scala/async/FutureSystem.scala
+++ b/src/main/scala/scala/async/internal/FutureSystem.scala
@@ -1,11 +1,12 @@
/*
* Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
*/
-package scala.async
+package scala.async.internal
import scala.language.higherKinds
import scala.reflect.macros.Context
+import scala.reflect.internal.SymbolTable
/**
* An abstraction over a future system.
@@ -14,7 +15,7 @@ import scala.reflect.macros.Context
* customize the code generation.
*
* The API mirrors that of `scala.concurrent.Future`, see the instance
- * [[scala.async.ScalaConcurrentFutureSystem]] for an example of how
+ * [[ScalaConcurrentFutureSystem]] for an example of how
* to implement this.
*/
trait FutureSystem {
@@ -26,12 +27,10 @@ trait FutureSystem {
type ExecContext
trait Ops {
- val context: reflect.macros.Context
+ val universe: reflect.internal.SymbolTable
- import context.universe._
-
- /** Lookup the execution context, typically with an implicit search */
- def execContext: Expr[ExecContext]
+ import universe._
+ def Expr[T: WeakTypeTag](tree: Tree): Expr[T] = universe.Expr[T](rootMirror, universe.FixedMirrorTreeCreator(rootMirror, tree))
def promType[A: WeakTypeTag]: Type
def execContextType: Type
@@ -52,13 +51,14 @@ trait FutureSystem {
/** Complete a promise with a value */
def completeProm[A](prom: Expr[Prom[A]], value: Expr[scala.util.Try[A]]): Expr[Unit]
- def spawn(tree: context.Tree): context.Tree =
- future(context.Expr[Unit](tree))(execContext).tree
+ def spawn(tree: Tree, execContext: Tree): Tree =
+ future(Expr[Unit](tree))(Expr[ExecContext](execContext)).tree
+ // TODO Why is this needed?
def castTo[A: WeakTypeTag](future: Expr[Fut[Any]]): Expr[Fut[A]]
}
- def mkOps(c: Context): Ops { val context: c.type }
+ def mkOps(c: SymbolTable): Ops { val universe: c.type }
}
object ScalaConcurrentFutureSystem extends FutureSystem {
@@ -69,18 +69,13 @@ object ScalaConcurrentFutureSystem extends FutureSystem {
type Fut[A] = Future[A]
type ExecContext = ExecutionContext
- def mkOps(c: Context): Ops {val context: c.type} = new Ops {
- val context: c.type = c
-
- import context.universe._
+ def mkOps(c: SymbolTable): Ops {val universe: c.type} = new Ops {
+ val universe: c.type = c
- def execContext: Expr[ExecContext] = c.Expr(c.inferImplicitValue(c.weakTypeOf[ExecutionContext]) match {
- case EmptyTree => c.abort(c.macroApplication.pos, "Unable to resolve implicit ExecutionContext")
- case context => context
- })
+ import universe._
- def promType[A: WeakTypeTag]: Type = c.weakTypeOf[Promise[A]]
- def execContextType: Type = c.weakTypeOf[ExecutionContext]
+ def promType[A: WeakTypeTag]: Type = weakTypeOf[Promise[A]]
+ def execContextType: Type = weakTypeOf[ExecutionContext]
def createProm[A: WeakTypeTag]: Expr[Prom[A]] = reify {
Promise[A]()
@@ -101,7 +96,7 @@ object ScalaConcurrentFutureSystem extends FutureSystem {
def completeProm[A](prom: Expr[Prom[A]], value: Expr[scala.util.Try[A]]): Expr[Unit] = reify {
prom.splice.complete(value.splice)
- context.literalUnit.splice
+ Expr[Unit](Literal(Constant(()))).splice
}
def castTo[A: WeakTypeTag](future: Expr[Fut[Any]]): Expr[Fut[A]] = reify {
@@ -109,49 +104,3 @@ object ScalaConcurrentFutureSystem extends FutureSystem {
}
}
}
-
-/**
- * A trivial implementation of [[scala.async.FutureSystem]] that performs computations
- * on the current thread. Useful for testing.
- */
-object IdentityFutureSystem extends FutureSystem {
-
- class Prom[A](var a: A)
-
- type Fut[A] = A
- type ExecContext = Unit
-
- def mkOps(c: Context): Ops {val context: c.type} = new Ops {
- val context: c.type = c
-
- import context.universe._
-
- def execContext: Expr[ExecContext] = c.literalUnit
-
- def promType[A: WeakTypeTag]: Type = c.weakTypeOf[Prom[A]]
- def execContextType: Type = c.weakTypeOf[Unit]
-
- def createProm[A: WeakTypeTag]: Expr[Prom[A]] = reify {
- new Prom(null.asInstanceOf[A])
- }
-
- def promiseToFuture[A: WeakTypeTag](prom: Expr[Prom[A]]) = reify {
- prom.splice.a
- }
-
- def future[A: WeakTypeTag](t: Expr[A])(execContext: Expr[ExecContext]) = t
-
- def onComplete[A, U](future: Expr[Fut[A]], fun: Expr[scala.util.Try[A] => U],
- execContext: Expr[ExecContext]): Expr[Unit] = reify {
- fun.splice.apply(util.Success(future.splice))
- context.literalUnit.splice
- }
-
- def completeProm[A](prom: Expr[Prom[A]], value: Expr[scala.util.Try[A]]): Expr[Unit] = reify {
- prom.splice.a = value.splice.get
- context.literalUnit.splice
- }
-
- def castTo[A: WeakTypeTag](future: Expr[Fut[Any]]): Expr[Fut[A]] = ???
- }
-}
diff --git a/src/main/scala/scala/async/internal/Lifter.scala b/src/main/scala/scala/async/internal/Lifter.scala
new file mode 100644
index 0000000..f49dcbb
--- /dev/null
+++ b/src/main/scala/scala/async/internal/Lifter.scala
@@ -0,0 +1,150 @@
+package scala.async.internal
+
+trait Lifter {
+ self: AsyncMacro =>
+ import global._
+
+ /**
+ * Identify which DefTrees are used (including transitively) which are declared
+ * in some state but used (including transitively) in another state.
+ *
+ * These will need to be lifted to class members of the state machine.
+ */
+ def liftables(asyncStates: List[AsyncState]): List[Tree] = {
+ object companionship {
+ private val companions = collection.mutable.Map[Symbol, Symbol]()
+ private val companionsInverse = collection.mutable.Map[Symbol, Symbol]()
+ private def record(sym1: Symbol, sym2: Symbol) {
+ companions(sym1) = sym2
+ companions(sym2) = sym1
+ }
+
+ def record(defs: List[Tree]) {
+ // Keep note of local companions so we rename them consistently
+ // when lifting.
+ val comps = for {
+ cd@ClassDef(_, _, _, _) <- defs
+ md@ModuleDef(_, _, _) <- defs
+ if (cd.name.toTermName == md.name)
+ } record(cd.symbol, md.symbol)
+ }
+ def companionOf(sym: Symbol): Symbol = {
+ companions.get(sym).orElse(companionsInverse.get(sym)).getOrElse(NoSymbol)
+ }
+ }
+
+
+ val defs: Map[Tree, Int] = {
+ /** Collect the DefTrees directly enclosed within `t` that have the same owner */
+ def collectDirectlyEnclosedDefs(t: Tree): List[DefTree] = t match {
+ case dt: DefTree => dt :: Nil
+ case _: Function => Nil
+ case t =>
+ val childDefs = t.children.flatMap(collectDirectlyEnclosedDefs(_))
+ companionship.record(childDefs)
+ childDefs
+ }
+ asyncStates.flatMap {
+ asyncState =>
+ val defs = collectDirectlyEnclosedDefs(Block(asyncState.allStats: _*))
+ defs.map((_, asyncState.state))
+ }.toMap
+ }
+
+ // In which block are these symbols defined?
+ val symToDefiningState: Map[Symbol, Int] = defs.map {
+ case (k, v) => (k.symbol, v)
+ }
+
+ // The definitions trees
+ val symToTree: Map[Symbol, Tree] = defs.map {
+ case (k, v) => (k.symbol, k)
+ }
+
+ // The direct references of each definition tree
+ val defSymToReferenced: Map[Symbol, List[Symbol]] = defs.keys.map {
+ case tree => (tree.symbol, tree.collect {
+ case rt: RefTree if symToDefiningState.contains(rt.symbol) => rt.symbol
+ })
+ }.toMap
+
+ // The direct references of each block, excluding references of `DefTree`-s which
+ // are already accounted for.
+ val stateIdToDirectlyReferenced: Map[Int, List[Symbol]] = {
+ val refs: List[(Int, Symbol)] = asyncStates.flatMap(
+ asyncState => asyncState.stats.filterNot(_.isDef).flatMap(_.collect {
+ case rt: RefTree if symToDefiningState.contains(rt.symbol) => (asyncState.state, rt.symbol)
+ })
+ )
+ toMultiMap(refs)
+ }
+
+ def liftableSyms: Set[Symbol] = {
+ val liftableMutableSet = collection.mutable.Set[Symbol]()
+ def markForLift(sym: Symbol) {
+ if (!liftableMutableSet(sym)) {
+ liftableMutableSet += sym
+
+ // Only mark transitive references of defs, modules and classes. The RHS of lifted vals/vars
+ // stays in its original location, so things that it refers to need not be lifted.
+ if (!(sym.isVal || sym.isVar))
+ defSymToReferenced(sym).foreach(sym2 => markForLift(sym2))
+ }
+ }
+ // Start things with DefTrees directly referenced from statements from other states...
+ val liftableStatementRefs: List[Symbol] = stateIdToDirectlyReferenced.toList.flatMap {
+ case (i, syms) => syms.filter(sym => symToDefiningState(sym) != i)
+ }
+ // .. and likewise for DefTrees directly referenced by other DefTrees from other states
+ val liftableRefsOfDefTrees = defSymToReferenced.toList.flatMap {
+ case (referee, referents) => referents.filter(sym => symToDefiningState(sym) != symToDefiningState(referee))
+ }
+ // Mark these for lifting, which will follow transitive references.
+ (liftableStatementRefs ++ liftableRefsOfDefTrees).foreach(markForLift)
+ liftableMutableSet.toSet
+ }
+
+ val lifted = liftableSyms.map(symToTree).toList.map {
+ case vd@ValDef(_, _, tpt, rhs) =>
+ import reflect.internal.Flags._
+ val sym = vd.symbol
+ sym.setFlag(MUTABLE | STABLE | PRIVATE | LOCAL)
+ sym.name = name.fresh(sym.name.toTermName)
+ sym.modifyInfo(_.deconst)
+ ValDef(vd.symbol, gen.mkZero(vd.symbol.info)).setPos(vd.pos)
+ case dd@DefDef(mods, name, tparams, vparamss, tpt, rhs) =>
+ import reflect.internal.Flags._
+ val sym = dd.symbol
+ sym.name = this.name.fresh(sym.name.toTermName)
+ sym.setFlag(PRIVATE | LOCAL)
+ DefDef(dd.symbol, rhs).setPos(dd.pos)
+ case cd@ClassDef(_, _, _, impl) =>
+ import reflect.internal.Flags._
+ val sym = cd.symbol
+ sym.name = newTypeName(name.fresh(sym.name.toString).toString)
+ companionship.companionOf(cd.symbol) match {
+ case NoSymbol =>
+ case moduleSymbol =>
+ moduleSymbol.name = sym.name.toTermName
+ moduleSymbol.moduleClass.name = moduleSymbol.name.toTypeName
+ }
+ ClassDef(cd.symbol, impl).setPos(cd.pos)
+ case md@ModuleDef(_, _, impl) =>
+ import reflect.internal.Flags._
+ val sym = md.symbol
+ companionship.companionOf(md.symbol) match {
+ case NoSymbol =>
+ sym.name = name.fresh(sym.name.toTermName)
+ sym.moduleClass.name = sym.name.toTypeName
+ case classSymbol => // will be renamed by `case ClassDef` above.
+ }
+ ModuleDef(md.symbol, impl).setPos(md.pos)
+ case td@TypeDef(_, _, _, rhs) =>
+ import reflect.internal.Flags._
+ val sym = td.symbol
+ sym.name = newTypeName(name.fresh(sym.name.toString).toString)
+ TypeDef(td.symbol, rhs).setPos(td.pos)
+ }
+ lifted
+ }
+}
diff --git a/src/main/scala/scala/async/StateAssigner.scala b/src/main/scala/scala/async/internal/StateAssigner.scala
index bc60a6d..cdde7a4 100644
--- a/src/main/scala/scala/async/StateAssigner.scala
+++ b/src/main/scala/scala/async/internal/StateAssigner.scala
@@ -2,7 +2,7 @@
* Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
*/
-package scala.async
+package scala.async.internal
private[async] final class StateAssigner {
private var current = -1
@@ -11,4 +11,4 @@ private[async] final class StateAssigner {
current += 1
current
}
-} \ No newline at end of file
+}
diff --git a/src/main/scala/scala/async/internal/TransformUtils.scala b/src/main/scala/scala/async/internal/TransformUtils.scala
new file mode 100644
index 0000000..70237bc
--- /dev/null
+++ b/src/main/scala/scala/async/internal/TransformUtils.scala
@@ -0,0 +1,251 @@
+/*
+ * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
+ */
+package scala.async.internal
+
+import scala.reflect.macros.Context
+import reflect.ClassTag
+import scala.reflect.macros.runtime.AbortMacroException
+
+/**
+ * Utilities used in both `ExprBuilder` and `AnfTransform`.
+ */
+private[async] trait TransformUtils {
+ self: AsyncMacro =>
+
+ import global._
+
+ object name {
+ val resume = newTermName("resume")
+ val apply = newTermName("apply")
+ val matchRes = "matchres"
+ val ifRes = "ifres"
+ val await = "await"
+ val bindSuffix = "$bind"
+
+ val state = newTermName("state")
+ val result = newTermName("result")
+ val execContext = newTermName("execContext")
+ val stateMachine = newTermName(fresh("stateMachine"))
+ val stateMachineT = stateMachine.toTypeName
+ val tr = newTermName("tr")
+ val t = newTermName("throwable")
+
+ def fresh(name: TermName): TermName = newTermName(fresh(name.toString))
+
+ def fresh(name: String): String = currentUnit.freshTermName("" + name + "$").toString
+ }
+
+ def isAwait(fun: Tree) =
+ fun.symbol == defn.Async_await
+
+ private lazy val Boolean_ShortCircuits: Set[Symbol] = {
+ import definitions.BooleanClass
+ def BooleanTermMember(name: String) = BooleanClass.typeSignature.member(newTermName(name).encodedName)
+ val Boolean_&& = BooleanTermMember("&&")
+ val Boolean_|| = BooleanTermMember("||")
+ Set(Boolean_&&, Boolean_||)
+ }
+
+ private def isByName(fun: Tree): ((Int, Int) => Boolean) = {
+ if (Boolean_ShortCircuits contains fun.symbol) (i, j) => true
+ else {
+ val paramss = fun.tpe.paramss
+ val byNamess = paramss.map(_.map(_.isByNameParam))
+ (i, j) => util.Try(byNamess(i)(j)).getOrElse(false)
+ }
+ }
+ private def argName(fun: Tree): ((Int, Int) => String) = {
+ val paramss = fun.tpe.paramss
+ val namess = paramss.map(_.map(_.name.toString))
+ (i, j) => util.Try(namess(i)(j)).getOrElse(s"arg_${i}_${j}")
+ }
+
+ def Expr[A: WeakTypeTag](t: Tree) = global.Expr[A](rootMirror, new FixedMirrorTreeCreator(rootMirror, t))
+
+ object defn {
+ def mkList_apply[A](args: List[Expr[A]]): Expr[List[A]] = {
+ Expr(Apply(Ident(definitions.List_apply), args.map(_.tree)))
+ }
+
+ def mkList_contains[A](self: Expr[List[A]])(elem: Expr[Any]) = reify {
+ self.splice.contains(elem.splice)
+ }
+
+ def mkFunction_apply[A, B](self: Expr[Function1[A, B]])(arg: Expr[A]) = reify {
+ self.splice.apply(arg.splice)
+ }
+
+ def mkAny_==(self: Expr[Any])(other: Expr[Any]) = reify {
+ self.splice == other.splice
+ }
+
+ def mkTry_get[A](self: Expr[util.Try[A]]) = reify {
+ self.splice.get
+ }
+
+ val TryClass = rootMirror.staticClass("scala.util.Try")
+ val Try_get = TryClass.typeSignature.member(newTermName("get")).ensuring(_ != NoSymbol)
+ val Try_isFailure = TryClass.typeSignature.member(newTermName("isFailure")).ensuring(_ != NoSymbol)
+ val TryAnyType = appliedType(TryClass.toType, List(definitions.AnyTpe))
+ val NonFatalClass = rootMirror.staticModule("scala.util.control.NonFatal")
+ val AsyncClass = rootMirror.staticClass("scala.async.internal.AsyncBase")
+ val Async_await = AsyncClass.typeSignature.member(newTermName("await")).ensuring(_ != NoSymbol)
+ }
+
+ def isSafeToInline(tree: Tree) = {
+ treeInfo.isExprSafeToInline(tree)
+ }
+
+ /** Map a list of arguments to:
+ * - A list of argument Trees
+ * - A list of auxillary results.
+ *
+ * The function unwraps and rewraps the `arg :_*` construct.
+ *
+ * @param args The original argument trees
+ * @param f A function from argument (with '_*' unwrapped) and argument index to argument.
+ * @tparam A The type of the auxillary result
+ */
+ private def mapArguments[A](args: List[Tree])(f: (Tree, Int) => (A, Tree)): (List[A], List[Tree]) = {
+ args match {
+ case args :+ Typed(tree, Ident(tpnme.WILDCARD_STAR)) =>
+ val (a, argExprs :+ lastArgExpr) = (args :+ tree).zipWithIndex.map(f.tupled).unzip
+ val exprs = argExprs :+ atPos(lastArgExpr.pos.makeTransparent)(Typed(lastArgExpr, Ident(tpnme.WILDCARD_STAR)))
+ (a, exprs)
+ case args =>
+ args.zipWithIndex.map(f.tupled).unzip
+ }
+ }
+
+ case class Arg(expr: Tree, isByName: Boolean, argName: String)
+
+ /**
+ * Transform a list of argument lists, producing the transformed lists, and lists of auxillary
+ * results.
+ *
+ * The function `f` need not concern itself with varargs arguments e.g (`xs : _*`). It will
+ * receive `xs`, and it's result will be re-wrapped as `f(xs) : _*`.
+ *
+ * @param fun The function being applied
+ * @param argss The argument lists
+ * @return (auxillary results, mapped argument trees)
+ */
+ def mapArgumentss[A](fun: Tree, argss: List[List[Tree]])(f: Arg => (A, Tree)): (List[List[A]], List[List[Tree]]) = {
+ val isByNamess: (Int, Int) => Boolean = isByName(fun)
+ val argNamess: (Int, Int) => String = argName(fun)
+ argss.zipWithIndex.map { case (args, i) =>
+ mapArguments[A](args) {
+ (tree, j) => f(Arg(tree, isByNamess(i, j), argNamess(i, j)))
+ }
+ }.unzip
+ }
+
+
+ def statsAndExpr(tree: Tree): (List[Tree], Tree) = tree match {
+ case Block(stats, expr) => (stats, expr)
+ case _ => (List(tree), Literal(Constant(())))
+ }
+
+ def emptyConstructor: DefDef = {
+ val emptySuperCall = Apply(Select(Super(This(tpnme.EMPTY), tpnme.EMPTY), nme.CONSTRUCTOR), Nil)
+ DefDef(NoMods, nme.CONSTRUCTOR, List(), List(List()), TypeTree(), Block(List(emptySuperCall), Literal(Constant(()))))
+ }
+
+ def applied(className: String, types: List[Type]): AppliedTypeTree =
+ AppliedTypeTree(Ident(rootMirror.staticClass(className)), types.map(TypeTree(_)))
+
+ /** Descends into the regions of the tree that are subject to the
+ * translation to a state machine by `async`. When a nested template,
+ * function, or by-name argument is encountered, the descent stops,
+ * and `nestedClass` etc are invoked.
+ */
+ trait AsyncTraverser extends Traverser {
+ def nestedClass(classDef: ClassDef) {
+ }
+
+ def nestedModule(module: ModuleDef) {
+ }
+
+ def nestedMethod(module: DefDef) {
+ }
+
+ def byNameArgument(arg: Tree) {
+ }
+
+ def function(function: Function) {
+ }
+
+ def patMatFunction(tree: Match) {
+ }
+
+ override def traverse(tree: Tree) {
+ tree match {
+ case cd: ClassDef => nestedClass(cd)
+ case md: ModuleDef => nestedModule(md)
+ case dd: DefDef => nestedMethod(dd)
+ case fun: Function => function(fun)
+ case m@Match(EmptyTree, _) => patMatFunction(m) // Pattern matching anonymous function under -Xoldpatmat of after `restorePatternMatchingFunctions`
+ case treeInfo.Applied(fun, targs, argss) if argss.nonEmpty =>
+ val isInByName = isByName(fun)
+ for ((args, i) <- argss.zipWithIndex) {
+ for ((arg, j) <- args.zipWithIndex) {
+ if (!isInByName(i, j)) traverse(arg)
+ else byNameArgument(arg)
+ }
+ }
+ traverse(fun)
+ case _ => super.traverse(tree)
+ }
+ }
+ }
+
+ def abort(pos: Position, msg: String) = throw new AbortMacroException(pos, msg)
+
+ abstract class MacroTypingTransformer extends TypingTransformer(callSiteTyper.context.unit) {
+ currentOwner = callSiteTyper.context.owner
+
+ def currOwner: Symbol = currentOwner
+
+ localTyper = global.analyzer.newTyper(callSiteTyper.context.make(unit = callSiteTyper.context.unit))
+ }
+
+ def transformAt(tree: Tree)(f: PartialFunction[Tree, (analyzer.Context => Tree)]) = {
+ object trans extends MacroTypingTransformer {
+ override def transform(tree: Tree): Tree = {
+ if (f.isDefinedAt(tree)) {
+ f(tree)(localTyper.context)
+ } else super.transform(tree)
+ }
+ }
+ trans.transform(tree)
+ }
+
+ def changeOwner(tree: Tree, oldOwner: Symbol, newOwner: Symbol): tree.type = {
+ new ChangeOwnerAndModuleClassTraverser(oldOwner, newOwner).traverse(tree)
+ tree
+ }
+
+ class ChangeOwnerAndModuleClassTraverser(oldowner: Symbol, newowner: Symbol)
+ extends ChangeOwnerTraverser(oldowner, newowner) {
+
+ override def traverse(tree: Tree) {
+ tree match {
+ case _: DefTree => change(tree.symbol.moduleClass)
+ case _ =>
+ }
+ super.traverse(tree)
+ }
+ }
+
+ def toMultiMap[A, B](as: Iterable[(A, B)]): Map[A, List[B]] =
+ as.toList.groupBy(_._1).mapValues(_.map(_._2).toList).toMap
+
+ // Attributed version of `TreeGen#mkCastPreservingAnnotations`
+ def mkAttributedCastPreservingAnnotations(tree: Tree, tp: Type): Tree = {
+ atPos(tree.pos) {
+ val casted = gen.mkAttributedCast(tree, tp.withoutAnnotations.dealias)
+ Typed(casted, TypeTree(tp)).setType(tp)
+ }
+ }
+}
diff --git a/src/test/scala/scala/async/TreeInterrogation.scala b/src/test/scala/scala/async/TreeInterrogation.scala
index deaee03..770c0f9 100644
--- a/src/test/scala/scala/async/TreeInterrogation.scala
+++ b/src/test/scala/scala/async/TreeInterrogation.scala
@@ -7,6 +7,7 @@ package scala.async
import org.junit.runner.RunWith
import org.junit.runners.JUnit4
import org.junit.Test
+import scala.async.internal.AsyncId
import AsyncId._
import tools.reflect.ToolBox
@@ -15,9 +16,9 @@ class TreeInterrogation {
@Test
def `a minimal set of vals are lifted to vars`() {
val cm = reflect.runtime.currentMirror
- val tb = mkToolbox("-cp target/scala-2.10/classes")
+ val tb = mkToolbox(s"-cp ${toolboxClasspath}")
val tree = tb.parse(
- """| import _root_.scala.async.AsyncId._
+ """| import _root_.scala.async.internal.AsyncId._
| async {
| val x = await(1)
| val y = x * 2
@@ -40,8 +41,7 @@ class TreeInterrogation {
val varDefs = tree1.collect {
case ValDef(mods, name, _, _) if mods.hasFlag(Flag.MUTABLE) => name
}
- varDefs.map(_.decoded.trim).toSet mustBe (Set("state$async", "await$1", "await$2"))
- varDefs.map(_.decoded.trim).toSet mustBe (Set("state$async", "await$1", "await$2"))
+ varDefs.map(_.decoded.trim).toSet mustBe (Set("state", "await$1$1", "await$2$1"))
val defDefs = tree1.collect {
case t: Template =>
@@ -52,7 +52,7 @@ class TreeInterrogation {
&& !dd.symbol.asTerm.isAccessor && !dd.symbol.asTerm.isSetter => dd.name
}
}.flatten
- defDefs.map(_.decoded.trim).toSet mustBe (Set("foo$1", "apply", "resume$async", "<init>"))
+ defDefs.map(_.decoded.trim).toSet mustBe (Set("foo$1", "apply", "resume", "<init>"))
}
}
@@ -68,17 +68,15 @@ object TreeInterrogation extends App {
withDebug {
val cm = reflect.runtime.currentMirror
- val tb = mkToolbox("-cp target/scala-2.10/classes -Xprint:flatten")
+ val tb = mkToolbox("-cp ${toolboxClasspath} -Xprint:typer -uniqid")
import scala.async.Async._
val tree = tb.parse(
- """ import _root_.scala.async.AsyncId.{async, await}
- | def foo[T](a0: Int)(b0: Int*) = s"a0 = $a0, b0 = ${b0.head}"
- | val res = async {
- | var i = 0
- | def get = async {i += 1; i}
- | foo[Int](await(get))(await(get) :: Nil : _*)
+ """ import _root_.scala.async.internal.AsyncId.{async, await}
+ | import reflect.runtime.universe._
+ | async {
+ | implicit def view(a: Int): String = ""
+ | await(0).length
| }
- | res
| """.stripMargin)
println(tree)
val tree1 = tb.typeCheck(tree.duplicate)
diff --git a/src/test/scala/scala/async/neg/LocalClasses0Spec.scala b/src/test/scala/scala/async/neg/LocalClasses0Spec.scala
index 2569303..6ebc9ca 100644
--- a/src/test/scala/scala/async/neg/LocalClasses0Spec.scala
+++ b/src/test/scala/scala/async/neg/LocalClasses0Spec.scala
@@ -5,121 +5,33 @@
package scala.async
package neg
-/**
- * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
- */
-
import org.junit.runner.RunWith
import org.junit.runners.JUnit4
import org.junit.Test
+import scala.async.internal.AsyncId
@RunWith(classOf[JUnit4])
class LocalClasses0Spec {
-
@Test
- def `reject a local class`() {
- expectError("Local case class Person illegal within `async` block") {
- """
- | import scala.concurrent.ExecutionContext.Implicits.global
- | import scala.async.Async._
- |
- | async {
- | case class Person(name: String)
- | }
- """.stripMargin
- }
+ def localClassCrashIssue16() {
+ import AsyncId.{async, await}
+ async {
+ class B { def f = 1 }
+ await(new B()).f
+ } mustBe 1
}
@Test
- def `reject a local class 2`() {
- expectError("Local case class Person illegal within `async` block") {
- """
- | import scala.concurrent.{Future, ExecutionContext}
- | import ExecutionContext.Implicits.global
- | import scala.async.Async._
- |
- | async {
- | case class Person(name: String)
- | val fut = Future { 5 }
- | val x = await(fut)
- | x
- | }
- """.stripMargin
- }
+ def nestedCaseClassAndModuleAllowed() {
+ import AsyncId.{await, async}
+ async {
+ trait Base { def base = 0}
+ await(0)
+ case class Person(name: String) extends Base
+ val fut = async { "bob" }
+ val x = Person(await(fut))
+ x.base
+ x.name
+ } mustBe "bob"
}
-
- @Test
- def `reject a local class 3`() {
- expectError("Local case class Person illegal within `async` block") {
- """
- | import scala.concurrent.{Future, ExecutionContext}
- | import ExecutionContext.Implicits.global
- | import scala.async.Async._
- |
- | async {
- | val fut = Future { 5 }
- | val x = await(fut)
- | case class Person(name: String)
- | x
- | }
- """.stripMargin
- }
- }
-
- @Test
- def `reject a local class with symbols in its name`() {
- expectError("Local case class :: illegal within `async` block") {
- """
- | import scala.concurrent.{Future, ExecutionContext}
- | import ExecutionContext.Implicits.global
- | import scala.async.Async._
- |
- | async {
- | val fut = Future { 5 }
- | val x = await(fut)
- | case class ::(name: String)
- | x
- | }
- """.stripMargin
- }
- }
-
- @Test
- def `reject a nested local class`() {
- expectError("Local case class Person illegal within `async` block") {
- """
- | import scala.concurrent.{Future, ExecutionContext}
- | import ExecutionContext.Implicits.global
- | import scala.async.Async._
- |
- | async {
- | val fut = Future { 5 }
- | val x = 2 + 2
- | var y = 0
- | if (x > 0) {
- | case class Person(name: String)
- | y = await(fut)
- | } else {
- | y = x
- | }
- | y
- | }
- """.stripMargin
- }
- }
-
- @Test
- def `reject a local singleton object`() {
- expectError("Local object Person illegal within `async` block") {
- """
- | import scala.concurrent.ExecutionContext.Implicits.global
- | import scala.async.Async._
- |
- | async {
- | object Person { val name = "Joe" }
- | }
- """.stripMargin
- }
- }
-
}
diff --git a/src/test/scala/scala/async/neg/NakedAwait.scala b/src/test/scala/scala/async/neg/NakedAwait.scala
index b0d5fde..ba388c5 100644
--- a/src/test/scala/scala/async/neg/NakedAwait.scala
+++ b/src/test/scala/scala/async/neg/NakedAwait.scala
@@ -25,7 +25,7 @@ class NakedAwait {
def `await not allowed in by-name argument`() {
expectError("await must not be used under a by-name argument.") {
"""
- | import _root_.scala.async.AsyncId._
+ | import _root_.scala.async.internal.AsyncId._
| def foo(a: Int)(b: => Int) = 0
| async { foo(0)(await(0)) }
""".stripMargin
@@ -36,7 +36,7 @@ class NakedAwait {
def `await not allowed in boolean short circuit argument 1`() {
expectError("await must not be used under a by-name argument.") {
"""
- | import _root_.scala.async.AsyncId._
+ | import _root_.scala.async.internal.AsyncId._
| async { true && await(false) }
""".stripMargin
}
@@ -46,7 +46,7 @@ class NakedAwait {
def `await not allowed in boolean short circuit argument 2`() {
expectError("await must not be used under a by-name argument.") {
"""
- | import _root_.scala.async.AsyncId._
+ | import _root_.scala.async.internal.AsyncId._
| async { true || await(false) }
""".stripMargin
}
@@ -56,7 +56,7 @@ class NakedAwait {
def nestedObject() {
expectError("await must not be used under a nested object.") {
"""
- | import _root_.scala.async.AsyncId._
+ | import _root_.scala.async.internal.AsyncId._
| async { object Nested { await(false) } }
""".stripMargin
}
@@ -66,7 +66,7 @@ class NakedAwait {
def nestedTrait() {
expectError("await must not be used under a nested trait.") {
"""
- | import _root_.scala.async.AsyncId._
+ | import _root_.scala.async.internal.AsyncId._
| async { trait Nested { await(false) } }
""".stripMargin
}
@@ -76,7 +76,7 @@ class NakedAwait {
def nestedClass() {
expectError("await must not be used under a nested class.") {
"""
- | import _root_.scala.async.AsyncId._
+ | import _root_.scala.async.internal.AsyncId._
| async { class Nested { await(false) } }
""".stripMargin
}
@@ -86,7 +86,7 @@ class NakedAwait {
def nestedFunction() {
expectError("await must not be used under a nested function.") {
"""
- | import _root_.scala.async.AsyncId._
+ | import _root_.scala.async.internal.AsyncId._
| async { () => { await(false) } }
""".stripMargin
}
@@ -96,7 +96,7 @@ class NakedAwait {
def nestedPatMatFunction() {
expectError("await must not be used under a nested class.") { // TODO more specific error message
"""
- | import _root_.scala.async.AsyncId._
+ | import _root_.scala.async.internal.AsyncId._
| async { { case x => { await(false) } } : PartialFunction[Any, Any] }
""".stripMargin
}
@@ -106,7 +106,7 @@ class NakedAwait {
def tryBody() {
expectError("await must not be used under a try/catch.") {
"""
- | import _root_.scala.async.AsyncId._
+ | import _root_.scala.async.internal.AsyncId._
| async { try { await(false) } catch { case _ => } }
""".stripMargin
}
@@ -116,7 +116,7 @@ class NakedAwait {
def catchBody() {
expectError("await must not be used under a try/catch.") {
"""
- | import _root_.scala.async.AsyncId._
+ | import _root_.scala.async.internal.AsyncId._
| async { try { () } catch { case _ => await(false) } }
""".stripMargin
}
@@ -126,17 +126,27 @@ class NakedAwait {
def finallyBody() {
expectError("await must not be used under a try/catch.") {
"""
- | import _root_.scala.async.AsyncId._
+ | import _root_.scala.async.internal.AsyncId._
| async { try { () } finally { await(false) } }
""".stripMargin
}
}
@Test
+ def guard() {
+ expectError("await must not be used under a pattern guard.") {
+ """
+ | import _root_.scala.async.internal.AsyncId._
+ | async { 1 match { case _ if await(true) => } }
+ """.stripMargin
+ }
+ }
+
+ @Test
def nestedMethod() {
expectError("await must not be used under a nested method.") {
"""
- | import _root_.scala.async.AsyncId._
+ | import _root_.scala.async.internal.AsyncId._
| async { def foo = await(false) }
""".stripMargin
}
@@ -146,7 +156,7 @@ class NakedAwait {
def returnIllegal() {
expectError("return is illegal") {
"""
- | import _root_.scala.async.AsyncId._
+ | import _root_.scala.async.internal.AsyncId._
| def foo(): Any = async { return false }
| ()
|
@@ -158,7 +168,7 @@ class NakedAwait {
def lazyValIllegal() {
expectError("lazy vals are illegal") {
"""
- | import _root_.scala.async.AsyncId._
+ | import _root_.scala.async.internal.AsyncId._
| def foo(): Any = async { val x = { lazy val y = 0; y } }
| ()
|
diff --git a/src/test/scala/scala/async/package.scala b/src/test/scala/scala/async/package.scala
index 4a7a958..7c42024 100644
--- a/src/test/scala/scala/async/package.scala
+++ b/src/test/scala/scala/async/package.scala
@@ -42,7 +42,22 @@ package object async {
m.mkToolBox(options = compileOptions)
}
- def expectError(errorSnippet: String, compileOptions: String = "", baseCompileOptions: String = "-cp target/scala-2.10/classes")(code: String) {
+ def scalaBinaryVersion: String = {
+ val Pattern = """(\d+\.\d+)\..*""".r
+ scala.util.Properties.versionNumberString match {
+ case Pattern(v) => v
+ case _ => ""
+ }
+ }
+
+ def toolboxClasspath = {
+ val f = new java.io.File(s"target/scala-${scalaBinaryVersion}/classes")
+ if (!f.exists) sys.error(s"output directory ${f.getAbsolutePath} does not exist.")
+ f.getAbsolutePath
+ }
+
+ def expectError(errorSnippet: String, compileOptions: String = "",
+ baseCompileOptions: String = s"-cp ${toolboxClasspath}")(code: String) {
intercept[ToolBoxError] {
eval(code, compileOptions + " " + baseCompileOptions)
}.getMessage mustContain errorSnippet
diff --git a/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala
index 7be6299..c8cec28 100644
--- a/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala
+++ b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala
@@ -13,6 +13,7 @@ import scala.async.Async.{async, await}
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.JUnit4
+import scala.async.internal.AsyncId
class AnfTestClass {
@@ -114,8 +115,6 @@ class AnfTransformSpec {
@Test
def `inlining block does not produce duplicate definition`() {
- import scala.async.AsyncId
-
AsyncId.async {
val f = 12
val x = AsyncId.await(f)
@@ -132,8 +131,6 @@ class AnfTransformSpec {
@Test
def `inlining block in tail position does not produce duplicate definition`() {
- import scala.async.AsyncId
-
AsyncId.async {
val f = 12
val x = AsyncId.await(f)
@@ -176,7 +173,7 @@ class AnfTransformSpec {
@Test
def nestedAwaitAsBareExpression() {
import ExecutionContext.Implicits.global
- import _root_.scala.async.AsyncId.{async, await}
+ import AsyncId.{async, await}
val result = async {
await(await("").isEmpty)
}
@@ -186,7 +183,7 @@ class AnfTransformSpec {
@Test
def nestedAwaitInBlock() {
import ExecutionContext.Implicits.global
- import _root_.scala.async.AsyncId.{async, await}
+ import AsyncId.{async, await}
val result = async {
()
await(await("").isEmpty)
@@ -197,7 +194,7 @@ class AnfTransformSpec {
@Test
def nestedAwaitInIf() {
import ExecutionContext.Implicits.global
- import _root_.scala.async.AsyncId.{async, await}
+ import AsyncId.{async, await}
val result = async {
if ("".isEmpty)
await(await("").isEmpty)
@@ -208,7 +205,7 @@ class AnfTransformSpec {
@Test
def byNameExpressionsArentLifted() {
- import _root_.scala.async.AsyncId.{async, await}
+ import AsyncId.{async, await}
def foo(ignored: => Any, b: Int) = b
val result = async {
foo(???, await(1))
@@ -218,7 +215,7 @@ class AnfTransformSpec {
@Test
def evaluationOrderRespected() {
- import scala.async.AsyncId.{async, await}
+ import AsyncId.{async, await}
def foo(a: Int, b: Int) = (a, b)
val result = async {
var i = 0
@@ -233,19 +230,19 @@ class AnfTransformSpec {
@Test
def awaitInNonPrimaryParamSection1() {
- import _root_.scala.async.AsyncId.{async, await}
+ import AsyncId.{async, await}
def foo(a0: Int)(b0: Int) = s"a0 = $a0, b0 = $b0"
val res = async {
var i = 0
def get = {i += 1; i}
- foo(get)(get)
+ foo(get)(await(get))
}
res mustBe "a0 = 1, b0 = 2"
}
@Test
def awaitInNonPrimaryParamSection2() {
- import _root_.scala.async.AsyncId.{async, await}
+ import AsyncId.{async, await}
def foo[T](a0: Int)(b0: Int*) = s"a0 = $a0, b0 = ${b0.head}"
val res = async {
var i = 0
@@ -257,7 +254,7 @@ class AnfTransformSpec {
@Test
def awaitInNonPrimaryParamSectionWithLazy1() {
- import _root_.scala.async.AsyncId.{async, await}
+ import AsyncId.{async, await}
def foo[T](a: => Int)(b: Int) = b
val res = async {
def get = async {0}
@@ -268,7 +265,7 @@ class AnfTransformSpec {
@Test
def awaitInNonPrimaryParamSectionWithLazy2() {
- import _root_.scala.async.AsyncId.{async, await}
+ import AsyncId.{async, await}
def foo[T](a: Int)(b: => Int) = a
val res = async {
def get = async {0}
@@ -279,7 +276,7 @@ class AnfTransformSpec {
@Test
def awaitWithLazy() {
- import _root_.scala.async.AsyncId.{async, await}
+ import AsyncId.{async, await}
def foo[T](a: Int, b: => Int) = a
val res = async {
def get = async {0}
@@ -290,7 +287,7 @@ class AnfTransformSpec {
@Test
def awaitOkInReciever() {
- import scala.async.AsyncId.{async, await}
+ import AsyncId.{async, await}
class Foo { def bar(a: Int)(b: Int) = a + b }
async {
await(async(new Foo)).bar(1)(2)
@@ -299,7 +296,7 @@ class AnfTransformSpec {
@Test
def namedArgumentsRespectEvaluationOrder() {
- import scala.async.AsyncId.{async, await}
+ import AsyncId.{async, await}
def foo(a: Int, b: Int) = (a, b)
val result = async {
var i = 0
@@ -314,7 +311,7 @@ class AnfTransformSpec {
@Test
def namedAndDefaultArgumentsRespectEvaluationOrder() {
- import scala.async.AsyncId.{async, await}
+ import AsyncId.{async, await}
var i = 0
def next() = {
i += 1;
@@ -332,7 +329,7 @@ class AnfTransformSpec {
@Test
def repeatedParams1() {
- import scala.async.AsyncId.{async, await}
+ import AsyncId.{async, await}
var i = 0
def foo(a: Int, b: Int*) = b.toList
def id(i: Int) = i
@@ -343,7 +340,7 @@ class AnfTransformSpec {
@Test
def repeatedParams2() {
- import scala.async.AsyncId.{async, await}
+ import AsyncId.{async, await}
var i = 0
def foo(a: Int, b: Int*) = b.toList
def id(i: Int) = i
@@ -351,4 +348,64 @@ class AnfTransformSpec {
foo(await(0), List(id(1), id(2), id(3)): _*)
} mustBe (List(1, 2, 3))
}
+
+ @Test
+ def awaitInThrow() {
+ import _root_.scala.async.internal.AsyncId.{async, await}
+ intercept[Exception](
+ async {
+ throw new Exception("msg: " + await(0))
+ }
+ ).getMessage mustBe "msg: 0"
+ }
+
+ @Test
+ def awaitInTyped() {
+ import _root_.scala.async.internal.AsyncId.{async, await}
+ async {
+ (("msg: " + await(0)): String).toString
+ } mustBe "msg: 0"
+ }
+
+
+ @Test
+ def awaitInAssign() {
+ import _root_.scala.async.internal.AsyncId.{async, await}
+ async {
+ var x = 0
+ x = await(1)
+ x
+ } mustBe 1
+ }
+
+ @Test
+ def caseBodyMustBeTypedAsUnit() {
+ import _root_.scala.async.internal.AsyncId.{async, await}
+ val Up = 1
+ val Down = 2
+ val sign = async {
+ await(1) match {
+ case Up => 1.0
+ case Down => -1.0
+ }
+ }
+ sign mustBe 1.0
+ }
+
+ @Test
+ def awaitInImplicitApply() {
+ val tb = mkToolbox(s"-cp ${toolboxClasspath}")
+ val tree = tb.typeCheck(tb.parse {
+ """
+ | import language.implicitConversions
+ | import _root_.scala.async.internal.AsyncId.{async, await}
+ | implicit def view(a: Int): String = ""
+ | async {
+ | await(0).length
+ | }
+ """.stripMargin
+ })
+ val applyImplicitView = tree.collect { case x if x.getClass.getName.endsWith("ApplyImplicitView") => x }
+ applyImplicitView.map(_.toString) mustBe List("view(a$1)")
+ }
}
diff --git a/src/test/scala/scala/async/run/hygiene/Hygiene.scala b/src/test/scala/scala/async/run/hygiene/Hygiene.scala
index 9d1df21..8081ee7 100644
--- a/src/test/scala/scala/async/run/hygiene/Hygiene.scala
+++ b/src/test/scala/scala/async/run/hygiene/Hygiene.scala
@@ -9,11 +9,12 @@ package hygiene
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.JUnit4
+import scala.async.internal.AsyncId
@RunWith(classOf[JUnit4])
class HygieneSpec {
- import scala.async.AsyncId.{async, await}
+ import AsyncId.{async, await}
@Test
def `is hygenic`() {
diff --git a/src/test/scala/scala/async/run/ifelse0/IfElse0.scala b/src/test/scala/scala/async/run/ifelse0/IfElse0.scala
index e2b1ca6..fc438a1 100644
--- a/src/test/scala/scala/async/run/ifelse0/IfElse0.scala
+++ b/src/test/scala/scala/async/run/ifelse0/IfElse0.scala
@@ -13,6 +13,7 @@ import scala.async.Async.{async, await}
import org.junit.runner.RunWith
import org.junit.runners.JUnit4
import org.junit.Test
+import scala.async.internal.AsyncId
class TestIfElseClass {
diff --git a/src/test/scala/scala/async/run/ifelse0/WhileSpec.scala b/src/test/scala/scala/async/run/ifelse0/WhileSpec.scala
index 1f1033a..b8d88fb 100644
--- a/src/test/scala/scala/async/run/ifelse0/WhileSpec.scala
+++ b/src/test/scala/scala/async/run/ifelse0/WhileSpec.scala
@@ -9,6 +9,7 @@ package ifelse0
import org.junit.runner.RunWith
import org.junit.runners.JUnit4
import org.junit.Test
+import scala.async.internal.AsyncId
@RunWith(classOf[JUnit4])
class WhileSpec {
@@ -64,4 +65,4 @@ class WhileSpec {
}
result mustBe (100)
}
-} \ No newline at end of file
+}
diff --git a/src/test/scala/scala/async/run/match0/Match0.scala b/src/test/scala/scala/async/run/match0/Match0.scala
index 7624838..7c392ab 100644
--- a/src/test/scala/scala/async/run/match0/Match0.scala
+++ b/src/test/scala/scala/async/run/match0/Match0.scala
@@ -13,6 +13,7 @@ import scala.async.Async.{async, await}
import org.junit.runner.RunWith
import org.junit.runners.JUnit4
import org.junit.Test
+import scala.async.internal.AsyncId
class TestMatchClass {
@@ -111,4 +112,38 @@ class MatchSpec {
}
result mustBe (3)
}
+
+ @Test def duplicateBindName() {
+ import AsyncId.{async, await}
+ def m4(m: Any) = async {
+ m match {
+ case buf: String =>
+ await(0)
+ case buf: Double =>
+ await(2)
+ }
+ }
+ m4("") mustBe 0
+ }
+
+ @Test def bugCastBoxedUnitToStringMatch() {
+ import scala.async.internal.AsyncId.{async, await}
+ def foo = async {
+ val p2 = await(5)
+ "foo" match {
+ case p3: String =>
+ p2.toString
+ }
+ }
+ foo mustBe "5"
+ }
+
+ @Test def bugCastBoxedUnitToStringIf() {
+ import scala.async.internal.AsyncId.{async, await}
+ def foo = async {
+ val p2 = await(5)
+ if (true) p2.toString else p2.toString
+ }
+ foo mustBe "5"
+ }
}
diff --git a/src/test/scala/scala/async/run/nesteddef/NestedDef.scala b/src/test/scala/scala/async/run/nesteddef/NestedDef.scala
index ee0a78e..409f70a 100644
--- a/src/test/scala/scala/async/run/nesteddef/NestedDef.scala
+++ b/src/test/scala/scala/async/run/nesteddef/NestedDef.scala
@@ -5,6 +5,7 @@ package nesteddef
import org.junit.runner.RunWith
import org.junit.runners.JUnit4
import org.junit.Test
+import scala.async.internal.AsyncId
@RunWith(classOf[JUnit4])
class NestedDef {
@@ -37,4 +38,60 @@ class NestedDef {
}
result mustBe ((0d, 44d, 2))
}
+
+ // We must lift `foo` and `bar` in the next two tests.
+ @Test
+ def nestedDefTransitive1() {
+ import AsyncId._
+ val result = async {
+ val a = 0
+ val x = await(a) - 1
+ def bar = a
+ def foo = bar
+ foo
+ }
+ result mustBe 0
+ }
+
+ @Test
+ def nestedDefTransitive2() {
+ import AsyncId._
+ val result = async {
+ val a = 0
+ val x = await(a) - 1
+ def bar = a
+ def foo = bar
+ 0
+ }
+ result mustBe 0
+ }
+
+
+ // checking that our use/definition analysis doesn't cycle.
+ @Test
+ def mutuallyRecursive1() {
+ import AsyncId._
+ val result = async {
+ val a = 0
+ val x = await(a) - 1
+ def foo: Int = if (true) 0 else bar
+ def bar: Int = if (true) 0 else foo
+ bar
+ }
+ result mustBe 0
+ }
+
+ // checking that our use/definition analysis doesn't cycle.
+ @Test
+ def mutuallyRecursive2() {
+ import AsyncId._
+ val result = async {
+ val a = 0
+ def foo: Int = if (true) 0 else bar
+ def bar: Int = if (true) 0 else foo
+ val x = await(a) - 1
+ bar
+ }
+ result mustBe 0
+ }
}
diff --git a/src/test/scala/scala/async/run/noawait/NoAwaitSpec.scala b/src/test/scala/scala/async/run/noawait/NoAwaitSpec.scala
index e2c69d0..ba9c9be 100644
--- a/src/test/scala/scala/async/run/noawait/NoAwaitSpec.scala
+++ b/src/test/scala/scala/async/run/noawait/NoAwaitSpec.scala
@@ -6,6 +6,7 @@ package scala.async
package run
package noawait
+import scala.async.internal.AsyncId
import AsyncId._
import org.junit.Test
import org.junit.runner.RunWith
diff --git a/src/test/scala/scala/async/run/toughtype/ToughType.scala b/src/test/scala/scala/async/run/toughtype/ToughType.scala
index 83f5a2d..ec2278f 100644
--- a/src/test/scala/scala/async/run/toughtype/ToughType.scala
+++ b/src/test/scala/scala/async/run/toughtype/ToughType.scala
@@ -13,6 +13,7 @@ import scala.async.Async._
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.JUnit4
+import scala.async.internal.AsyncId
object ToughTypeObject {
@@ -67,4 +68,74 @@ class ToughTypeSpec {
await(f(2))
} mustBe 3
}
+
+ @Test def existentialBindIssue19() {
+ import AsyncId.{await, async}
+ def m7(a: Any) = async {
+ a match {
+ case s: Seq[_] =>
+ val x = s.size
+ var ss = s
+ ss = s
+ await(x)
+ }
+ }
+ m7(Nil) mustBe 0
+ }
+
+ @Test def existentialBind2Issue19() {
+ import scala.async.Async._, scala.concurrent.ExecutionContext.Implicits.global
+ def conjure[T]: T = null.asInstanceOf[T]
+
+ def m3 = async {
+ val p: List[Option[_]] = conjure[List[Option[_]]]
+ await(future(1))
+ }
+
+ def m4 = async {
+ await(future[List[_]](Nil))
+ }
+ }
+
+ @Test def singletonTypeIssue17() {
+ import AsyncId.{async, await}
+ class A { class B }
+ async {
+ val a = new A
+ def foo(b: a.B) = 0
+ await(foo(new a.B))
+ }
+ }
+
+ @Test def existentialMatch() {
+ import AsyncId.{async, await}
+ trait Container[+A]
+ case class ContainerImpl[A](value: A) extends Container[A]
+ def foo: Container[_] = async {
+ val a: Any = List(1)
+ a match {
+ case buf: Seq[_] =>
+ val foo = await(5)
+ val e0 = buf(0)
+ ContainerImpl(e0)
+ }
+ }
+ foo
+ }
+
+ @Test def existentialIfElse0() {
+ import AsyncId.{async, await}
+ trait Container[+A]
+ case class ContainerImpl[A](value: A) extends Container[A]
+ def foo: Container[_] = async {
+ val a: Any = List(1)
+ if (true) {
+ val buf: Seq[_] = List(1)
+ val foo = await(5)
+ val e0 = buf(0)
+ ContainerImpl(e0)
+ } else ???
+ }
+ foo
+ }
}