diff options
Diffstat (limited to 'src/main/scala/scala/async/continuations/AsyncBaseWithCPSFallback.scala')
-rw-r--r-- | src/main/scala/scala/async/continuations/AsyncBaseWithCPSFallback.scala | 43 |
1 files changed, 27 insertions, 16 deletions
diff --git a/src/main/scala/scala/async/continuations/AsyncBaseWithCPSFallback.scala b/src/main/scala/scala/async/continuations/AsyncBaseWithCPSFallback.scala index a669cfa..7abc6e8 100644 --- a/src/main/scala/scala/async/continuations/AsyncBaseWithCPSFallback.scala +++ b/src/main/scala/scala/async/continuations/AsyncBaseWithCPSFallback.scala @@ -22,21 +22,28 @@ trait AsyncBaseWithCPSFallback extends AsyncBase { /* Implements `async { ... }` using the CPS plugin. */ - protected def cpsBasedAsyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[futureSystem.Fut[T]] = { + protected def cpsBasedAsyncImpl[T: c.WeakTypeTag](c: Context) + (body: c.Expr[T]) + (execContext: c.Expr[futureSystem.ExecContext]): c.Expr[futureSystem.Fut[T]] = { import c.universe._ - def lookupMember(name: String) = { - val asyncTrait = c.mirror.staticClass("scala.async.continuations.AsyncBaseWithCPSFallback") + def lookupClassMember(clazz: String, name: String) = { + val asyncTrait = c.mirror.staticClass(clazz) val tpe = asyncTrait.asType.toType - tpe.member(newTermName(name)).ensuring(_ != NoSymbol) + tpe.member(newTermName(name)).ensuring(_ != NoSymbol, s"$clazz.$name") + } + def lookupObjectMember(clazz: String, name: String) = { + val moduleClass = c.mirror.staticModule(clazz).moduleClass + val tpe = moduleClass.asType.toType + tpe.member(newTermName(name)).ensuring(_ != NoSymbol, s"$clazz.$name") } AsyncUtils.vprintln("AsyncBaseWithCPSFallback.cpsBasedAsyncImpl") - val utils = TransformUtils[c.type](c) - val futureSystemOps = futureSystem.mkOps(c) - val awaitSym = utils.defn.Async_await - val awaitFallbackSym = lookupMember("awaitFallback") + val symTab = c.universe.asInstanceOf[reflect.internal.SymbolTable] + val futureSystemOps = futureSystem.mkOps(symTab) + val awaitSym = lookupObjectMember("scala.async.Async", "await") + val awaitFallbackSym = lookupClassMember("scala.async.continuations.AsyncBaseWithCPSFallback", "awaitFallback") // replace `await` invocations with `awaitFallback` invocations val awaitReplacer = new Transformer { @@ -60,10 +67,12 @@ trait AsyncBaseWithCPSFallback extends AsyncBase { }.asInstanceOf[Future[T]] */ + def spawn(expr: Tree) = futureSystemOps.spawn(expr.asInstanceOf[futureSystemOps.universe.Tree], execContext.tree.asInstanceOf[futureSystemOps.universe.Tree]).asInstanceOf[Tree] + val bodyWithFuture = { val tree = bodyWithAwaitFallback match { - case Block(stmts, expr) => Block(stmts, futureSystemOps.spawn(expr)) - case expr => futureSystemOps.spawn(expr) + case Block(stmts, expr) => Block(stmts, spawn(expr)) + case expr => spawn(expr) } c.Expr[futureSystem.Fut[Any]](c.resetAllAttrs(tree.duplicate)) } @@ -71,20 +80,22 @@ trait AsyncBaseWithCPSFallback extends AsyncBase { val bodyWithReset: c.Expr[futureSystem.Fut[Any]] = reify { reset { bodyWithFuture.splice } } - val bodyWithCast = futureSystemOps.castTo[T](bodyWithReset) + val bodyWithCast = futureSystemOps.castTo[T](bodyWithReset.asInstanceOf[futureSystemOps.universe.Expr[futureSystem.Fut[Any]]]).asInstanceOf[c.Expr[futureSystem.Fut[T]]] AsyncUtils.vprintln(s"CPS-based async transform expands to:\n${bodyWithCast.tree}") bodyWithCast } - override def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[futureSystem.Fut[T]] = { + override def asyncImpl[T: c.WeakTypeTag](c: Context) + (body: c.Expr[T]) + (execContext: c.Expr[futureSystem.ExecContext]): c.Expr[futureSystem.Fut[T]] = { AsyncUtils.vprintln("AsyncBaseWithCPSFallback.asyncImpl") - val analyzer = AsyncAnalysis[c.type](c, this) + val asyncMacro = AsyncMacro(c, futureSystem) - if (!analyzer.reportUnsupportedAwaits(body.tree)) - super.asyncImpl[T](c)(body) // no unsupported awaits + if (!asyncMacro.reportUnsupportedAwaits(body.tree.asInstanceOf[asyncMacro.global.Tree], report = fallbackEnabled)) + super.asyncImpl[T](c)(body)(execContext) // no unsupported awaits else - cpsBasedAsyncImpl[T](c)(body) // fallback to CPS + cpsBasedAsyncImpl[T](c)(body)(execContext) // fallback to CPS } } |