diff options
Diffstat (limited to 'src/main')
-rw-r--r-- | src/main/scala/scala/async/AnfTransform.scala | 68 | ||||
-rw-r--r-- | src/main/scala/scala/async/Async.scala | 99 | ||||
-rw-r--r-- | src/main/scala/scala/async/AsyncAnalysis.scala | 66 | ||||
-rw-r--r-- | src/main/scala/scala/async/AsyncUtils.scala | 4 | ||||
-rw-r--r-- | src/main/scala/scala/async/ExprBuilder.scala | 41 | ||||
-rw-r--r-- | src/main/scala/scala/async/FutureSystem.scala | 9 | ||||
-rw-r--r-- | src/main/scala/scala/async/TransformUtils.scala | 91 |
7 files changed, 282 insertions, 96 deletions
diff --git a/src/main/scala/scala/async/AnfTransform.scala b/src/main/scala/scala/async/AnfTransform.scala index f52bdad..a2d21f6 100644 --- a/src/main/scala/scala/async/AnfTransform.scala +++ b/src/main/scala/scala/async/AnfTransform.scala @@ -10,7 +10,9 @@ import scala.reflect.macros.Context private[async] final case class AnfTransform[C <: Context](c: C) { import c.universe._ + val utils = TransformUtils[c.type](c) + import utils._ def apply(tree: Tree): List[Tree] = { @@ -29,9 +31,21 @@ private[async] final case class AnfTransform[C <: Context](c: C) { * This step is needed to allow us to safely merge blocks during the `inline` transform below. */ private final class UniqueNames(tree: Tree) extends Transformer { - val repeatedNames: Set[Name] = tree.collect { - case dt: DefTree => dt.symbol.name - }.groupBy(x => x).filter(_._2.size > 1).keySet + val repeatedNames: Set[Symbol] = { + class DuplicateNameTraverser extends AsyncTraverser { + val result = collection.mutable.Buffer[Symbol]() + + override def traverse(tree: Tree) { + tree match { + case dt: DefTree => result += dt.symbol + case _ => super.traverse(tree) + } + } + } + val dupNameTraverser = new DuplicateNameTraverser + dupNameTraverser.traverse(tree) + dupNameTraverser.result.groupBy(x => x.name).filter(_._2.size > 1).values.flatten.toSet[Symbol] + } /** Stepping outside of the public Macro API to call [[scala.reflect.internal.Symbols.Symbol.name_=]] */ val symtab = c.universe.asInstanceOf[reflect.internal.SymbolTable] @@ -40,7 +54,7 @@ private[async] final case class AnfTransform[C <: Context](c: C) { override def transform(tree: Tree): Tree = { tree match { - case defTree: DefTree if repeatedNames(defTree.symbol.name) => + case defTree: DefTree if repeatedNames(defTree.symbol) => val trans = super.transform(defTree) val origName = defTree.symbol.name val sym = defTree.symbol.asInstanceOf[symtab.Symbol] @@ -54,6 +68,8 @@ private[async] final case class AnfTransform[C <: Context](c: C) { trans match { case ValDef(mods, name, tpt, rhs) => treeCopy.ValDef(trans, mods, newName, tpt, rhs) + case Bind(name, body) => + treeCopy.Bind(trans, newName, body) case DefDef(mods, name, tparams, vparamss, tpt, rhs) => treeCopy.DefDef(trans, mods, newName, tparams, vparamss, tpt, rhs) case TypeDef(mods, name, tparams, rhs) => @@ -79,12 +95,16 @@ private[async] final case class AnfTransform[C <: Context](c: C) { private object trace { private var indent = -1 + def indentString = " " * indent + def apply[T](prefix: String, args: Any)(t: => T): T = { indent += 1 - def oneLine(s: Any) = s.toString.replaceAll("""\n""", "\\\\n").take(127) + def oneLine(s: Any) = s.toString.replaceAll( """\n""", "\\\\n").take(127) try { - AsyncUtils.trace(s"${indentString}$prefix(${oneLine(args)})") + AsyncUtils.trace(s"${ + indentString + }$prefix(${oneLine(args)})") val result = t AsyncUtils.trace(s"${indentString}= ${oneLine(result)}") result @@ -139,10 +159,7 @@ private[async] final case class AnfTransform[C <: Context](c: C) { } } - def transformToList(trees: List[Tree]): List[Tree] = trees match { - case fst :: rest => transformToList(fst) ++ transformToList(rest) - case Nil => Nil - } + def transformToList(trees: List[Tree]): List[Tree] = trees flatMap transformToList def transformToBlock(tree: Tree): Block = transformToList(tree) match { case stats :+ expr => Block(stats, expr) @@ -194,20 +211,40 @@ private[async] final case class AnfTransform[C <: Context](c: C) { stats :+ attachCopy(tree)(Assign(lhs, expr)) case If(cond, thenp, elsep) if containsAwait => - val stats :+ expr = inline.transformToList(cond) + val condStats :+ condExpr = inline.transformToList(cond) val thenBlock = inline.transformToBlock(thenp) val elseBlock = inline.transformToBlock(elsep) - stats :+ - c.typeCheck(attachCopy(tree)(If(expr, thenBlock, elseBlock))) + // Typechecking with `condExpr` as the condition fails if the condition + // contains an await. `ifTree.setType(tree.tpe)` also fails; it seems + // we rely on this call to `typeCheck` descending into the branches. + // But, we can get away with typechecking a throwaway `If` tree with the + // original scrutinee and the new branches, and setting that type on + // the real `If` tree. + val ifType = c.typeCheck(If(cond, thenBlock, elseBlock)).tpe + condStats :+ + attachCopy(tree)(If(condExpr, thenBlock, elseBlock)).setType(ifType) case Match(scrut, cases) if containsAwait => val scrutStats :+ scrutExpr = inline.transformToList(scrut) val caseDefs = cases map { case CaseDef(pat, guard, body) => + // extract local variables for all names bound in `pat`, and rewrite `body` + // to refer to these. + // TODO we can move this into ExprBuilder once we get rid of `AsyncDefinitionUseAnalyzer`. val block = inline.transformToBlock(body) - attachCopy(tree)(CaseDef(pat, guard, block)) + val (valDefs, mappings) = (pat collect { + case b@Bind(name, _) => + val newName = newTermName(utils.name.fresh(name.toTermName + utils.name.bindSuffix)) + val vd = ValDef(NoMods, newName, TypeTree(), Ident(b.symbol)) + (vd, (b.symbol, newName)) + }).unzip + val Block(stats1, expr1) = utils.substituteNames(block, mappings.toMap).asInstanceOf[Block] + attachCopy(tree)(CaseDef(pat, guard, Block(valDefs ++ stats1, expr1))) } - scrutStats :+ c.typeCheck(attachCopy(tree)(Match(scrutExpr, caseDefs))) + // Refer to comments the translation of `If` above. + val matchType = c.typeCheck(Match(scrut, caseDefs)).tpe + val typedMatch = attachCopy(tree)(Match(scrutExpr, caseDefs)).setType(tree.tpe) + scrutStats :+ typedMatch case LabelDef(name, params, rhs) if containsAwait => List(LabelDef(name, params, Block(inline.transformToList(rhs), Literal(Constant(())))).setSymbol(tree.symbol)) @@ -221,4 +258,5 @@ private[async] final case class AnfTransform[C <: Context](c: C) { } } } + } diff --git a/src/main/scala/scala/async/Async.scala b/src/main/scala/scala/async/Async.scala index ef506a5..4a770ed 100644 --- a/src/main/scala/scala/async/Async.scala +++ b/src/main/scala/scala/async/Async.scala @@ -11,6 +11,7 @@ import scala.reflect.macros.Context * @author Philipp Haller */ object Async extends AsyncBase { + import scala.concurrent.Future lazy val futureSystem = ScalaConcurrentFutureSystem @@ -65,7 +66,6 @@ abstract class AsyncBase { def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[futureSystem.Fut[T]] = { import c.universe._ - val builder = ExprBuilder[c.type, futureSystem.type](c, self.futureSystem) val anaylzer = AsyncAnalysis[c.type](c) val utils = TransformUtils[c.type](c) import utils.{name, defn} @@ -87,54 +87,82 @@ abstract class AsyncBase { // states of our generated state machine, e.g. a value assigned before // an `await` and read afterwards. val renameMap: Map[Symbol, TermName] = { - anaylzer.valDefsUsedInSubsequentStates(anfTree).map { + anaylzer.defTreesUsedInSubsequentStates(anfTree).map { vd => - (vd.symbol, name.fresh(vd.name)) + (vd.symbol, name.fresh(vd.name.toTermName)) }.toMap } + val builder = ExprBuilder[c.type, futureSystem.type](c, self.futureSystem, anfTree) val asyncBlock: builder.AsyncBlock = builder.build(anfTree, renameMap) import asyncBlock.asyncStates logDiagnostics(c)(anfTree, asyncStates.map(_.toString)) + // Important to retain the original declaration order here! val localVarTrees = anfTree.collect { - case vd@ValDef(_, _, tpt, _) if renameMap contains vd.symbol => + case vd@ValDef(_, _, tpt, _) if renameMap contains vd.symbol => utils.mkVarDefTree(tpt.tpe, renameMap(vd.symbol)) + case dd@DefDef(mods, name, tparams, vparamss, tpt, rhs) if renameMap contains dd.symbol => + DefDef(mods, renameMap(dd.symbol), tparams, vparamss, tpt, c.resetAllAttrs(utils.substituteNames(rhs, renameMap))) } - val onCompleteHandler = asyncBlock.onCompleteHandler + val onCompleteHandler = { + Function( + List(ValDef(Modifiers(Flag.PARAM), name.tr, TypeTree(defn.TryAnyType), EmptyTree)), + asyncBlock.onCompleteHandler) + } val resumeFunTree = asyncBlock.resumeFunTree[T] - val prom: Expr[futureSystem.Prom[T]] = reify { - // Create the empty promise - val result$async = futureSystemOps.createProm[T].splice - // Initialize the state - var state$async = 0 - // Resolve the execution context - val execContext$async = futureSystemOps.execContext.splice - var onCompleteHandler$async: util.Try[Any] => Unit = null - - // Spawn a future to: - futureSystemOps.future[Unit] { - c.Expr[Unit](Block( - // define vars for all intermediate results that are accessed from multiple states - localVarTrees :+ - // define the resume() method - resumeFunTree :+ - // assign onComplete function. (The var breaks the circular dependency with resume)` - Assign(Ident(name.onCompleteHandler), onCompleteHandler), - // and get things started by calling resume() - Apply(Ident(name.resume), Nil))) - }(c.Expr[futureSystem.ExecContext](Ident(name.execContext))).splice - // Return the promise from this reify block... - result$async + val stateMachineType = utils.applied("scala.async.StateMachine", List(futureSystemOps.promType[T], futureSystemOps.execContextType)) + + lazy val stateMachine: ClassDef = { + val body: List[Tree] = { + val stateVar = ValDef(Modifiers(Flag.MUTABLE), name.state, TypeTree(definitions.IntTpe), Literal(Constant(0))) + val result = ValDef(NoMods, name.result, TypeTree(futureSystemOps.promType[T]), futureSystemOps.createProm[T].tree) + val execContext = ValDef(NoMods, name.execContext, TypeTree(), futureSystemOps.execContext.tree) + val applyDefDef: DefDef = { + val applyVParamss = List(List(ValDef(Modifiers(Flag.PARAM), name.tr, TypeTree(defn.TryAnyType), EmptyTree))) + val applyBody = asyncBlock.onCompleteHandler + DefDef(NoMods, name.apply, Nil, applyVParamss, TypeTree(definitions.UnitTpe), applyBody) + } + val apply0DefDef: DefDef = { + // We extend () => Unit so we can pass this class as the by-name argument to `Future.apply`. + // See SI-1247 for the the optimization that avoids creatio + val applyVParamss = List(List(ValDef(Modifiers(Flag.PARAM), name.tr, TypeTree(defn.TryAnyType), EmptyTree))) + val applyBody = asyncBlock.onCompleteHandler + DefDef(NoMods, name.apply, Nil, Nil, TypeTree(definitions.UnitTpe), Apply(Ident(name.resume), Nil)) + } + List(utils.emptyConstructor, stateVar, result, execContext) ++ localVarTrees ++ List(resumeFunTree, applyDefDef, apply0DefDef) + } + val template = { + Template(List(stateMachineType), emptyValDef, body) + } + ClassDef(NoMods, name.stateMachineT, Nil, template) } - // ... and return its Future from the macro. - val result = futureSystemOps.promiseToFuture(prom) - AsyncUtils.vprintln(s"async state machine transform expands to:\n ${result.tree}") + def selectStateMachine(selection: TermName) = Select(Ident(name.stateMachine), selection) + + def spawn(tree: Tree): Tree = + futureSystemOps.future(c.Expr[Unit](tree))(c.Expr[futureSystem.ExecContext](selectStateMachine(name.execContext))).tree + + val code: c.Expr[futureSystem.Fut[T]] = { + val isSimple = asyncStates.size == 1 + val tree = + if (isSimple) + Block(Nil, spawn(body.tree)) // generate lean code for the simple case of `async { 1 + 1 }` + else { + Block(List[Tree]( + stateMachine, + ValDef(NoMods, name.stateMachine, stateMachineType, New(Ident(name.stateMachineT), Nil)), + spawn(Apply(selectStateMachine(name.apply), Nil)) + ), + futureSystemOps.promiseToFuture(c.Expr[futureSystem.Prom[T]](selectStateMachine(name.result))).tree) + } + c.Expr[futureSystem.Fut[T]](tree) + } - result + AsyncUtils.vprintln(s"async state machine transform expands to:\n ${code.tree}") + code } def logDiagnostics(c: Context)(anfTree: c.Tree, states: Seq[String]) { @@ -151,3 +179,10 @@ abstract class AsyncBase { states foreach (s => AsyncUtils.vprintln(s)) } } + +/** Internal class used by the `async` macro; should not be manually extended by client code */ +abstract class StateMachine[Result, EC] extends (scala.util.Try[Any] => Unit) with (() => Unit) { + def result$async: Result + + def execContext$async: EC +} diff --git a/src/main/scala/scala/async/AsyncAnalysis.scala b/src/main/scala/scala/async/AsyncAnalysis.scala index 645d9f5..8bb5bcd 100644 --- a/src/main/scala/scala/async/AsyncAnalysis.scala +++ b/src/main/scala/scala/async/AsyncAnalysis.scala @@ -11,6 +11,7 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C) { import c.universe._ val utils = TransformUtils[c.type](c) + import utils._ /** @@ -30,10 +31,11 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C) { * * Must be called on the ANF transformed tree. */ - def valDefsUsedInSubsequentStates(tree: Tree): List[ValDef] = { + def defTreesUsedInSubsequentStates(tree: Tree): List[DefTree] = { val analyzer = new AsyncDefinitionUseAnalyzer analyzer.traverse(tree) - analyzer.valDefsToLift.toList + val liftable: List[DefTree] = (analyzer.valDefsToLift ++ analyzer.nestedMethodsToLift).toList.distinct + liftable } private class UnsupportedAwaitAnalyzer extends AsyncTraverser { @@ -41,7 +43,8 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C) { val kind = if (classDef.symbol.asClass.isTrait) "trait" else "class" if (!reportUnsupportedAwait(classDef, s"nested $kind")) { // do not allow local class definitions, because of SI-5467 (specific to case classes, though) - c.error(classDef.pos, s"Local class ${classDef.name.decoded} illegal within `async` block") + if (classDef.symbol.asClass.isCaseClass) + c.error(classDef.pos, s"Local case class ${classDef.name.decoded} illegal within `async` block") } } @@ -70,12 +73,9 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C) { case Try(_, _, _) if containsAwait => reportUnsupportedAwait(tree, "try/catch") super.traverse(tree) - case If(cond, _, _) if containsAwait => - reportUnsupportedAwait(cond, "condition") - super.traverse(tree) - case Return(_) => + case Return(_) => c.abort(tree.pos, "return is illegal within a async block") - case _ => + case _ => super.traverse(tree) } } @@ -92,7 +92,7 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C) { c.error(tree.pos, s"await must not be used under a $whyUnsupported.") } badAwaits.nonEmpty - } + } } private class AsyncDefinitionUseAnalyzer extends AsyncTraverser { @@ -102,40 +102,67 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C) { private var valDefChunkId = Map[Symbol, (ValDef, Int)]() - val valDefsToLift = mutable.Set[ValDef]() + val valDefsToLift : mutable.Set[ValDef] = collection.mutable.Set() + val nestedMethodsToLift: mutable.Set[DefDef] = collection.mutable.Set() + + override def nestedMethod(defDef: DefDef) { + nestedMethodsToLift += defDef + defDef.rhs foreach { + case rt: RefTree => + valDefChunkId.get(rt.symbol) match { + case Some((vd, defChunkId)) => + valDefsToLift += vd // lift all vals referred to by nested methods. + case _ => + } + case _ => + } + } + + override def function(function: Function) { + function foreach { + case rt: RefTree => + valDefChunkId.get(rt.symbol) match { + case Some((vd, defChunkId)) => + valDefsToLift += vd // lift all vals referred to by nested functions. + case _ => + } + case _ => + } + } override def traverse(tree: Tree) = { tree match { - case If(cond, thenp, elsep) if tree exists isAwait => + case If(cond, thenp, elsep) if tree exists isAwait => traverseChunks(List(cond, thenp, elsep)) - case Match(selector, cases) if tree exists isAwait => + case Match(selector, cases) if tree exists isAwait => traverseChunks(selector :: cases) case LabelDef(name, params, rhs) if rhs exists isAwait => traverseChunks(rhs :: Nil) - case Apply(fun, args) if isAwait(fun) => + case Apply(fun, args) if isAwait(fun) => super.traverse(tree) nextChunk() - case vd: ValDef => + case vd: ValDef => super.traverse(tree) valDefChunkId += (vd.symbol ->(vd, chunkId)) - if (isAwait(vd.rhs)) valDefsToLift += vd - case as: Assign => + val isPatternBinder = vd.name.toString.contains(name.bindSuffix) + if (isAwait(vd.rhs) || isPatternBinder) valDefsToLift += vd + case as: Assign => if (isAwait(as.rhs)) { - assert(as.lhs.symbol != null, "internal error: null symbol for Assign tree:" + as + " " + as.lhs.symbol) + assert(as.lhs.symbol != null, "internal error: null symbol for Assign tree:" + as + " " + as.lhs.symbol) // TODO test the orElse case, try to remove the restriction. val (vd, defBlockId) = valDefChunkId.getOrElse(as.lhs.symbol, c.abort(as.pos, s"await may only be assigned to a var/val defined in the async block. ${as.lhs} ${as.lhs.symbol}")) valDefsToLift += vd } super.traverse(tree) - case rt: RefTree => + case rt: RefTree => valDefChunkId.get(rt.symbol) match { case Some((vd, defChunkId)) if defChunkId != chunkId => valDefsToLift += vd case _ => } super.traverse(tree) - case _ => super.traverse(tree) + case _ => super.traverse(tree) } } @@ -145,4 +172,5 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C) { } } } + } diff --git a/src/main/scala/scala/async/AsyncUtils.scala b/src/main/scala/scala/async/AsyncUtils.scala index 87a63d7..999cb95 100644 --- a/src/main/scala/scala/async/AsyncUtils.scala +++ b/src/main/scala/scala/async/AsyncUtils.scala @@ -10,8 +10,8 @@ object AsyncUtils { private def enabled(level: String) = sys.props.getOrElse(s"scala.async.$level", "false").equalsIgnoreCase("true") - private val verbose = enabled("debug") - private val trace = enabled("trace") + var verbose = enabled("debug") + var trace = enabled("trace") private[async] def vprintln(s: => Any): Unit = if (verbose) println(s"[async] $s") diff --git a/src/main/scala/scala/async/ExprBuilder.scala b/src/main/scala/scala/async/ExprBuilder.scala index d9faad5..7b4ccb8 100644 --- a/src/main/scala/scala/async/ExprBuilder.scala +++ b/src/main/scala/scala/async/ExprBuilder.scala @@ -6,11 +6,12 @@ package scala.async import scala.reflect.macros.Context import scala.collection.mutable.ListBuffer import collection.mutable +import language.existentials /* * @author Philipp Haller */ -private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: C, futureSystem: FS) { +private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: C, futureSystem: FS, origTree: C#Tree) { builder => val utils = TransformUtils[c.type](c) @@ -70,7 +71,7 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: override def mkHandlerCaseForState: CaseDef = { val callOnComplete = futureSystemOps.onComplete(c.Expr(awaitable.expr), - c.Expr(Ident(name.onCompleteHandler)), c.Expr(Ident(name.execContext))).tree + c.Expr(This(tpnme.EMPTY)), c.Expr(Ident(name.execContext))).tree mkHandlerCase(state, stats :+ callOnComplete) } @@ -96,12 +97,13 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: /** The state of the target of a LabelDef application (while loop jump) */ private var nextJumpState: Option[Int] = None - private def renameReset(tree: Tree) = resetDuplicate(substituteNames(tree, nameMap)) + private def renameReset(tree: Tree) = resetInternalAttrs(substituteNames(tree, nameMap)) def +=(stat: c.Tree): this.type = { assert(nextJumpState.isEmpty, s"statement appeared after a label jump: $stat") def addStat() = stats += renameReset(stat) stat match { + case _: DefDef => // these have been lifted. case Apply(fun, Nil) => labelDefStates get fun.symbol match { case Some(nextState) => nextJumpState = Some(nextState) @@ -146,7 +148,12 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: def resultWithMatch(scrutTree: c.Tree, cases: List[CaseDef], caseStates: List[Int]): AsyncState = { // 1. build list of changed cases val newCases = for ((cas, num) <- cases.zipWithIndex) yield cas match { - case CaseDef(pat, guard, rhs) => CaseDef(pat, guard, Block(mkStateTree(caseStates(num)), mkResumeApply)) + case CaseDef(pat, guard, rhs) => + val bindAssigns = rhs.children.takeWhile(isSyntheticBindVal).map { + case ValDef(_, name, _, rhs) => Assign(Ident(name), rhs) + case t => sys.error(s"Unexpected tree. Expected ValDef, found: $t") + } + CaseDef(pat, guard, Block(bindAssigns :+ mkStateTree(caseStates(num)), mkResumeApply)) } // 2. insert changed match tree at the end of the current state this += Match(renameReset(scrutTree), newCases) @@ -237,7 +244,9 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: stateBuilder.resultWithMatch(scrutinee, cases, caseStates) for ((cas, num) <- cases.zipWithIndex) { - val builder = nestedBlockBuilder(cas.body, caseStates(num), afterMatchState) + val (stats, expr) = statsAndExpr(cas.body) + val stats1 = stats.dropWhile(isSyntheticBindVal) + val builder = nestedBlockBuilder(Block(stats1, expr), caseStates(num), afterMatchState) asyncStates ++= builder.asyncStates } @@ -302,19 +311,16 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: val initStates = asyncStates.init /** - * lazy val onCompleteHandler = (tr: Try[Any]) => state match { + * // assumes tr: Try[Any] is in scope. + * // + * state match { * case 0 => { * x11 = tr.get.asInstanceOf[Double]; * state = 1; * resume() * } */ - val onCompleteHandler: Tree = { - val onCompleteHandlers = initStates.flatMap(_.mkOnCompleteHandler).toList - Function( - List(ValDef(Modifiers(Flag.PARAM), name.tr, TypeTree(defn.TryAnyType), EmptyTree)), - Match(Ident(name.state), onCompleteHandlers)) - } + val onCompleteHandler: Tree = Match(Ident(name.state), initStates.flatMap(_.mkOnCompleteHandler).toList) /** * def resume(): Unit = { @@ -346,9 +352,18 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: } } + private def isSyntheticBindVal(tree: Tree) = tree match { + case vd@ValDef(_, lname, _, Ident(rname)) => lname.toString.contains(name.bindSuffix) + case _ => false + } + private final case class Awaitable(expr: Tree, resultName: TermName, resultType: Type) - private def resetDuplicate(tree: Tree) = c.resetAllAttrs(tree.duplicate) + private val internalSyms = origTree.collect { + case dt: DefTree => dt.symbol + } + + private def resetInternalAttrs(tree: Tree) = utils.resetInternalAttrs(tree, internalSyms) private def mkResumeApply = Apply(Ident(name.resume), Nil) diff --git a/src/main/scala/scala/async/FutureSystem.scala b/src/main/scala/scala/async/FutureSystem.scala index 20bbea3..e9373b3 100644 --- a/src/main/scala/scala/async/FutureSystem.scala +++ b/src/main/scala/scala/async/FutureSystem.scala @@ -33,6 +33,9 @@ trait FutureSystem { /** Lookup the execution context, typically with an implicit search */ def execContext: Expr[ExecContext] + def promType[A: WeakTypeTag]: Type + def execContextType: Type + /** Create an empty promise */ def createProm[A: WeakTypeTag]: Expr[Prom[A]] @@ -71,6 +74,9 @@ object ScalaConcurrentFutureSystem extends FutureSystem { case context => context }) + def promType[A: WeakTypeTag]: Type = c.weakTypeOf[Promise[A]] + def execContextType: Type = c.weakTypeOf[ExecutionContext] + def createProm[A: WeakTypeTag]: Expr[Prom[A]] = reify { Promise[A]() } @@ -113,6 +119,9 @@ object IdentityFutureSystem extends FutureSystem { def execContext: Expr[ExecContext] = c.literalUnit + def promType[A: WeakTypeTag]: Type = c.weakTypeOf[Prom[A]] + def execContextType: Type = c.weakTypeOf[Unit] + def createProm[A: WeakTypeTag]: Expr[Prom[A]] = reify { new Prom(null.asInstanceOf[A]) } diff --git a/src/main/scala/scala/async/TransformUtils.scala b/src/main/scala/scala/async/TransformUtils.scala index c5bbba1..5b1fcbe 100644 --- a/src/main/scala/scala/async/TransformUtils.scala +++ b/src/main/scala/scala/async/TransformUtils.scala @@ -18,18 +18,18 @@ private[async] final case class TransformUtils[C <: Context](c: C) { def suffixedName(prefix: String) = newTermName(suffix(prefix)) - val state = suffixedName("state") - val result = suffixedName("result") - val resume = suffixedName("resume") - val execContext = suffixedName("execContext") - - // TODO do we need to freshen any of these? - val tr = newTermName("tr") - val onCompleteHandler = suffixedName("onCompleteHandler") - - val matchRes = "matchres" - val ifRes = "ifres" - val await = "await" + val state = suffixedName("state") + val result = suffixedName("result") + val resume = suffixedName("resume") + val execContext = suffixedName("execContext") + val stateMachine = newTermName(fresh("stateMachine")) + val stateMachineT = stateMachine.toTypeName + val apply = newTermName("apply") + val tr = newTermName("tr") + val matchRes = "matchres" + val ifRes = "ifres" + val await = "await" + val bindSuffix = "$bind" def fresh(name: TermName): TermName = newTermName(fresh(name.toString)) @@ -127,6 +127,14 @@ private[async] final case class TransformUtils[C <: Context](c: C) { ValDef(Modifiers(Flag.MUTABLE), resultName, TypeTree(resultType), defaultValue(resultType)) } + def emptyConstructor: DefDef = { + val emptySuperCall = Apply(Select(Super(This(tpnme.EMPTY), tpnme.EMPTY), nme.CONSTRUCTOR), Nil) + DefDef(NoMods, nme.CONSTRUCTOR, List(), List(List()), TypeTree(), Block(List(emptySuperCall), c.literalUnit.tree)) + } + + def applied(className: String, types: List[Type]): AppliedTypeTree = + AppliedTypeTree(Ident(c.mirror.staticClass(className)), types.map(TypeTree(_))) + object defn { def mkList_apply[A](args: List[Expr[A]]): Expr[List[A]] = { c.Expr(Apply(Ident(definitions.List_apply), args.map(_.tree))) @@ -146,8 +154,7 @@ private[async] final case class TransformUtils[C <: Context](c: C) { self.splice.get } - val Try_get = methodSym(reify((null: scala.util.Try[Any]).get)) - + val Try_get = methodSym(reify((null: scala.util.Try[Any]).get)) val TryClass = c.mirror.staticClass("scala.util.Try") val TryAnyType = appliedType(TryClass.toType, List(definitions.AnyTpe)) val NonFatalClass = c.mirror.staticModule("scala.util.control.NonFatal") @@ -159,7 +166,6 @@ private[async] final case class TransformUtils[C <: Context](c: C) { } } - /** `termSym( (_: Foo).bar(null: A, null: B)` will return the symbol of `bar`, after overload resolution. */ private def methodSym(apply: c.Expr[Any]): Symbol = { val tree2: Tree = c.typeCheck(apply.tree) @@ -181,4 +187,59 @@ private[async] final case class TransformUtils[C <: Context](c: C) { tree } + def resetInternalAttrs(tree: Tree, internalSyms: List[Symbol]) = + new ResetInternalAttrs(internalSyms.toSet).transform(tree) + + /** + * Adaptation of [[scala.reflect.internal.Trees.ResetAttrs]] + * + * A transformer which resets symbol and tpe fields of all nodes in a given tree, + * with special treatment of: + * `TypeTree` nodes: are replaced by their original if it exists, otherwise tpe field is reset + * to empty if it started out empty or refers to local symbols (which are erased). + * `TypeApply` nodes: are deleted if type arguments end up reverted to empty + * + * `This` and `Ident` nodes referring to an external symbol are ''not'' reset. + */ + private final class ResetInternalAttrs(internalSyms: Set[Symbol]) extends Transformer { + + import language.existentials + + override def transform(tree: Tree): Tree = super.transform { + def isExternal = tree.symbol != NoSymbol && !internalSyms(tree.symbol) + + tree match { + case tpt: TypeTree => resetTypeTree(tpt) + case TypeApply(fn, args) + if args map transform exists (_.isEmpty) => transform(fn) + case EmptyTree => tree + case (_: Ident | _: This) if isExternal => tree // #35 Don't reset the symbol of Ident/This bound outside of the async block + case _ => resetTree(tree) + } + } + + private def resetTypeTree(tpt: TypeTree): Tree = { + if (tpt.original != null) + transform(tpt.original) + else if (tpt.tpe != null && tpt.asInstanceOf[symtab.TypeTree forSome {val symtab: reflect.internal.SymbolTable}].wasEmpty) { + val dupl = tpt.duplicate + dupl.tpe = null + dupl + } + else tpt + } + + private def resetTree(tree: Tree): Tree = { + val hasSymbol: Boolean = { + val reflectInternalTree = tree.asInstanceOf[symtab.Tree forSome {val symtab: reflect.internal.SymbolTable}] + reflectInternalTree.hasSymbol + } + val dupl = tree.duplicate + if (hasSymbol) + dupl.symbol = NoSymbol + dupl.tpe = null + dupl + } + } + } |