diff options
Diffstat (limited to 'src/main/scala/scala/async/ExprBuilder.scala')
-rw-r--r-- | src/main/scala/scala/async/ExprBuilder.scala | 90 |
1 files changed, 88 insertions, 2 deletions
diff --git a/src/main/scala/scala/async/ExprBuilder.scala b/src/main/scala/scala/async/ExprBuilder.scala index 07aa1ee..1ca9e8f 100644 --- a/src/main/scala/scala/async/ExprBuilder.scala +++ b/src/main/scala/scala/async/ExprBuilder.scala @@ -310,7 +310,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val // when adding assignment need to take `toRename` into account stateBuilder.addVarDef(mods, newName, tpt, rhs, toRename) - case If(cond, thenp, elsep) => + case If(cond, thenp, elsep) if stat exists isAwait => checkForUnsupportedAwait(cond) val ifBudget: Int = remainingBudget / 2 @@ -335,7 +335,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val currState = currState + ifBudget stateBuilder = new builder.AsyncStateBuilder(currState, toRename) - case Match(scrutinee, cases) => + case Match(scrutinee, cases) if stat exists isAwait => checkForUnsupportedAwait(scrutinee) val matchBudget: Int = remainingBudget / 2 @@ -395,6 +395,92 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val } } + private val Boolean_ShortCircuits: Set[Symbol] = { + import definitions.BooleanClass + def BooleanTermMember(name: String) = BooleanClass.typeSignature.member(newTermName(name).encodedName) + val Boolean_&& = BooleanTermMember("&&") + val Boolean_|| = BooleanTermMember("||") + Set(Boolean_&&, Boolean_||) + } + + def isByName(fun: Tree): (Int => Boolean) = { + if (Boolean_ShortCircuits contains fun.symbol) i => true + else fun.tpe match { + case MethodType(params, _) => + val isByNameParams = params.map(_.asTerm.isByNameParam) + (i: Int) => isByNameParams.applyOrElse(i, (_: Int) => false) + case _ => Map() + } + } + + private def isAwait(fun: Tree) = { + fun.symbol == defn.Async_await + } + + private[async] class LiftableVarTraverser extends Traverser { + var blockId = 0 + var valDefBlockId = Map[Symbol, (ValDef, Int)]() + val liftable = collection.mutable.Set[ValDef]() + + + def reportUnsupportedAwait(tree: Tree, whyUnsupported: String) { + val badAwaits = tree collect { + case rt: RefTree if rt.symbol == Async_await => rt + } + badAwaits foreach { + tree => + c.error(tree.pos, s"await must not be used under a $whyUnsupported.") + } + } + + override def traverse(tree: Tree) = tree match { + case cd: ClassDef => + val kind = if (cd.symbol.asClass.isTrait) "trait" else "class" + reportUnsupportedAwait(tree, s"nested ${kind}") + case md: ModuleDef => + reportUnsupportedAwait(tree, "nested object") + case _: Function => + reportUnsupportedAwait(tree, "nested anonymous function") + case If(cond, thenp, elsep) if tree exists isAwait => + traverse(cond) + blockId += 1 + traverse(thenp) + blockId += 1 + traverse(elsep) + blockId += 1 + case Match(selector, cases) if tree exists isAwait => + traverse(selector) + blockId += 1 + cases foreach {c => traverse(c); blockId += 1} + case Apply(fun, args) if isAwait(fun) => + traverseTrees(args) + traverse(fun) + blockId += 1 + case Apply(fun, args) => + val isInByName = isByName(fun) + for ((arg, index) <- args.zipWithIndex) { + if (!isInByName(index)) traverse(arg) + else reportUnsupportedAwait(arg, "by-name argument") + } + traverse(fun) + case vd: ValDef => + super.traverse(tree) + valDefBlockId += (vd.symbol -> (vd, blockId)) + if (vd.rhs.symbol == Async_await) liftable += vd + case as: Assign => + if (as.rhs.symbol == Async_await) liftable += valDefBlockId.getOrElse(as.symbol, c.abort(as.pos, "await may only be assigned to a var/val defined in the async block."))._1 + case rt: RefTree => + valDefBlockId.get(rt.symbol) match { + case Some((vd, defBlockId)) if defBlockId != blockId => + liftable += vd + case _ => + } + super.traverse(tree) + case _ => super.traverse(tree) + } + } + + /** `termSym( (_: Foo).bar(null: A, null: B)` will return the symbol of `bar`, after overload resolution. */ private def methodSym(apply: c.Expr[Any]): Symbol = { val tree2: Tree = c.typeCheck(apply.tree) // TODO why is this needed? |