From 1333b3837d405c31baaa44d1db89aab0f7d09349 Mon Sep 17 00:00:00 2001 From: Jason Zaugg Date: Thu, 11 Dec 2014 21:30:35 +1000 Subject: Avoid unbounded stack consumption for synchronous control flow Previously, as sequence of state transitions that did not pass through an asynchrous boundary incurred stack frames. The trivial loop in the enclosed test case would then overflow the stack. This commit merges the `resume` and `apply(tr: Try[Any])` methods into a `apply`. It changes the body of this method to be an infinite loop with returns at the terminal points in the state machine (or at a terminal failure.) To allow merging of these previously separate matches, states that contain an await are now allocated two state ids: one for the setup code that calls `onComplete`, and one for the code in the continuation that records the result and advances the state machine. Fixes #93 --- .../scala/async/internal/AsyncTransform.scala | 24 +++------ .../scala/scala/async/internal/ExprBuilder.scala | 61 +++++++++++++--------- .../scala/scala/async/internal/StateAssigner.scala | 12 +++-- .../scala/async/internal/TransformUtils.scala | 5 +- src/test/scala/scala/async/TreeInterrogation.scala | 4 +- .../scala/scala/async/run/futures/FutureSpec.scala | 7 +++ .../run/stackoverflow/StackOverflowSpec.scala | 28 ++++++++++ 7 files changed, 87 insertions(+), 54 deletions(-) create mode 100644 src/test/scala/scala/async/run/stackoverflow/StackOverflowSpec.scala (limited to 'src') diff --git a/src/main/scala/scala/async/internal/AsyncTransform.scala b/src/main/scala/scala/async/internal/AsyncTransform.scala index 7d56043..c7a0c65 100644 --- a/src/main/scala/scala/async/internal/AsyncTransform.scala +++ b/src/main/scala/scala/async/internal/AsyncTransform.scala @@ -24,31 +24,29 @@ trait AsyncTransform { val anfTree = futureSystemOps.postAnfTransform(anfTree0) - val resumeFunTreeDummyBody = DefDef(Modifiers(), name.resume, Nil, List(Nil), Ident(definitions.UnitClass), Literal(Constant(()))) - val applyDefDefDummyBody: DefDef = { val applyVParamss = List(List(ValDef(Modifiers(Flag.PARAM), name.tr, TypeTree(futureSystemOps.tryType[Any]), EmptyTree))) - DefDef(NoMods, name.apply, Nil, applyVParamss, TypeTree(definitions.UnitTpe), Literal(Constant(()))) + DefDef(NoMods, name.apply, Nil, applyVParamss, TypeTree(definitions.UnitTpe), literalUnit) } // Create `ClassDef` of state machine with empty method bodies for `resume` and `apply`. val stateMachine: ClassDef = { val body: List[Tree] = { - val stateVar = ValDef(Modifiers(Flag.MUTABLE | Flag.PRIVATE | Flag.LOCAL), name.state, TypeTree(definitions.IntTpe), Literal(Constant(0))) + val stateVar = ValDef(Modifiers(Flag.MUTABLE | Flag.PRIVATE | Flag.LOCAL), name.state, TypeTree(definitions.IntTpe), Literal(Constant(StateAssigner.Initial))) val result = ValDef(NoMods, name.result, TypeTree(futureSystemOps.promType[T](uncheckedBoundsResultTag)), futureSystemOps.createProm[T](uncheckedBoundsResultTag).tree) val execContextValDef = ValDef(NoMods, name.execContext, TypeTree(), execContext) val apply0DefDef: DefDef = { // We extend () => Unit so we can pass this class as the by-name argument to `Future.apply`. // See SI-1247 for the the optimization that avoids creatio - DefDef(NoMods, name.apply, Nil, Nil, TypeTree(definitions.UnitTpe), Apply(Ident(name.resume), Nil)) + DefDef(NoMods, name.apply, Nil, Nil, TypeTree(definitions.UnitTpe), Apply(Ident(name.apply), literalNull :: Nil)) } val extraValDef: ValDef = { // We extend () => Unit so we can pass this class as the by-name argument to `Future.apply`. // See SI-1247 for the the optimization that avoids creatio - ValDef(NoMods, newTermName("extra"), TypeTree(definitions.UnitTpe), Literal(Constant(()))) + ValDef(NoMods, newTermName("extra"), TypeTree(definitions.UnitTpe), literalUnit) } - List(emptyConstructor, stateVar, result, execContextValDef) ++ List(resumeFunTreeDummyBody, applyDefDefDummyBody, apply0DefDef, extraValDef) + List(emptyConstructor, stateVar, result, execContextValDef) ++ List(applyDefDefDummyBody, apply0DefDef, extraValDef) } val tryToUnit = appliedType(definitions.FunctionClass(1), futureSystemOps.tryType[Any], typeOf[Unit]) @@ -90,8 +88,7 @@ trait AsyncTransform { val stateMachineSpliced: Tree = spliceMethodBodies( liftedFields, stateMachine, - atMacroPos(asyncBlock.onCompleteHandler[T]), - atMacroPos(asyncBlock.resumeFunTree[T].rhs) + atMacroPos(asyncBlock.onCompleteHandler[T]) ) def selectStateMachine(selection: TermName) = Select(Ident(name.stateMachine), selection) @@ -131,10 +128,9 @@ trait AsyncTransform { * @param liftables trees of definitions that are lifted to fields of the state machine class * @param tree `ClassDef` tree of the state machine class * @param applyBody tree of onComplete handler (`apply` method) - * @param resumeBody RHS of definition tree of `resume` method * @return transformed `ClassDef` tree of the state machine class */ - def spliceMethodBodies(liftables: List[Tree], tree: ClassDef, applyBody: Tree, resumeBody: Tree): Tree = { + def spliceMethodBodies(liftables: List[Tree], tree: ClassDef, applyBody: Tree): Tree = { val liftedSyms = liftables.map(_.symbol).toSet val stateMachineClass = tree.symbol liftedSyms.foreach { @@ -211,12 +207,6 @@ trait AsyncTransform { (ctx: analyzer.Context) => val typedTree = fixup(dd, changeOwner(applyBody, callSiteTyper.context.owner, dd.symbol), ctx) typedTree - - case dd@DefDef(_, name.resume, _, _, _, _) if dd.symbol.owner == stateMachineClass => - (ctx: analyzer.Context) => - val changed = changeOwner(resumeBody, callSiteTyper.context.owner, dd.symbol) - val res = fixup(dd, changed, ctx) - res } result } diff --git a/src/main/scala/scala/async/internal/ExprBuilder.scala b/src/main/scala/scala/async/internal/ExprBuilder.scala index 5314ae0..fc82f4c 100644 --- a/src/main/scala/scala/async/internal/ExprBuilder.scala +++ b/src/main/scala/scala/async/internal/ExprBuilder.scala @@ -52,7 +52,7 @@ trait ExprBuilder { List(nextState) def mkHandlerCaseForState: CaseDef = - mkHandlerCase(state, stats :+ mkStateTree(nextState, symLookup) :+ mkResumeApply(symLookup)) + mkHandlerCase(state, stats :+ mkStateTree(nextState, symLookup)) override val toString: String = s"AsyncState #$state, next = $nextState" @@ -72,7 +72,7 @@ trait ExprBuilder { /** A sequence of statements that concludes with an `await` call. The `onComplete` * handler will unconditionally transition to `nextState`. */ - final class AsyncStateWithAwait(var stats: List[Tree], val state: Int, nextState: Int, + final class AsyncStateWithAwait(var stats: List[Tree], val state: Int, onCompleteState: Int, nextState: Int, val awaitable: Awaitable, symLookup: SymLookup) extends AsyncState { @@ -82,7 +82,7 @@ trait ExprBuilder { override def mkHandlerCaseForState: CaseDef = { val callOnComplete = futureSystemOps.onComplete(Expr(awaitable.expr), Expr(This(tpnme.EMPTY)), Expr(Ident(name.execContext))).tree - mkHandlerCase(state, stats :+ callOnComplete) + mkHandlerCase(state, stats ++ List(mkStateTree(onCompleteState, symLookup), callOnComplete, Return(literalUnit))) } override def mkOnCompleteHandler[T: WeakTypeTag]: Option[CaseDef] = { @@ -102,15 +102,16 @@ trait ExprBuilder { */ val ifIsFailureTree = If(futureSystemOps.tryyIsFailure(Expr[futureSystem.Tryy[T]](Ident(symLookup.applyTrParam))).tree, - futureSystemOps.completeProm[T]( + Block(futureSystemOps.completeProm[T]( Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), Expr[futureSystem.Tryy[T]]( TypeApply(Select(Ident(symLookup.applyTrParam), newTermName("asInstanceOf")), - List(TypeTree(futureSystemOps.tryType[T]))))).tree, - Block(List(tryGetTree, mkStateTree(nextState, symLookup)), mkResumeApply(symLookup)) + List(TypeTree(futureSystemOps.tryType[T]))))).tree :: Nil, + Return(literalUnit)), + Block(List(tryGetTree), mkStateTree(nextState, symLookup)) ) - Some(mkHandlerCase(state, List(ifIsFailureTree))) + Some(mkHandlerCase(onCompleteState, List(ifIsFailureTree))) } override val toString: String = @@ -146,9 +147,10 @@ trait ExprBuilder { } def resultWithAwait(awaitable: Awaitable, + onCompleteState: Int, nextState: Int): AsyncState = { val effectiveNextState = nextJumpState.getOrElse(nextState) - new AsyncStateWithAwait(stats.toList, state, effectiveNextState, awaitable, symLookup) + new AsyncStateWithAwait(stats.toList, state, onCompleteState, effectiveNextState, awaitable, symLookup) } def resultSimple(nextState: Int): AsyncState = { @@ -157,7 +159,7 @@ trait ExprBuilder { } def resultWithIf(condTree: Tree, thenState: Int, elseState: Int): AsyncState = { - def mkBranch(state: Int) = Block(mkStateTree(state, symLookup) :: Nil, mkResumeApply(symLookup)) + def mkBranch(state: Int) = mkStateTree(state, symLookup) this += If(condTree, mkBranch(thenState), mkBranch(elseState)) new AsyncStateWithoutAwait(stats.toList, state, List(thenState, elseState)) } @@ -177,7 +179,7 @@ trait ExprBuilder { val newCases = for ((cas, num) <- cases.zipWithIndex) yield cas match { case CaseDef(pat, guard, rhs) => val bindAssigns = rhs.children.takeWhile(isSyntheticBindVal) - CaseDef(pat, guard, Block(bindAssigns :+ mkStateTree(caseStates(num), symLookup), mkResumeApply(symLookup))) + CaseDef(pat, guard, Block(bindAssigns, mkStateTree(caseStates(num), symLookup))) } // 2. insert changed match tree at the end of the current state this += Match(scrutTree, newCases) @@ -185,7 +187,7 @@ trait ExprBuilder { } def resultWithLabel(startLabelState: Int, symLookup: SymLookup): AsyncState = { - this += Block(mkStateTree(startLabelState, symLookup) :: Nil, mkResumeApply(symLookup)) + this += mkStateTree(startLabelState, symLookup) new AsyncStateWithoutAwait(stats.toList, state, List(startLabelState)) } @@ -226,9 +228,10 @@ trait ExprBuilder { for (stat <- stats) stat match { // the val name = await(..) pattern case vd @ ValDef(mods, name, tpt, Apply(fun, arg :: Nil)) if isAwait(fun) => + val onCompleteState = nextState() val afterAwaitState = nextState() val awaitable = Awaitable(arg, stat.symbol, tpt.tpe, vd) - asyncStates += stateBuilder.resultWithAwait(awaitable, afterAwaitState) // complete with await + asyncStates += stateBuilder.resultWithAwait(awaitable, onCompleteState, afterAwaitState) // complete with await currState = afterAwaitState stateBuilder = new AsyncStateBuilder(currState, symLookup) @@ -296,8 +299,6 @@ trait ExprBuilder { def asyncStates: List[AsyncState] def onCompleteHandler[T: WeakTypeTag]: Tree - - def resumeFunTree[T: WeakTypeTag]: DefDef } case class SymLookup(stateMachineClass: Symbol, applyTrParam: Symbol) { @@ -330,7 +331,7 @@ trait ExprBuilder { val lastStateBody = Expr[T](lastState.body) val rhs = futureSystemOps.completeProm( Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), futureSystemOps.tryySuccess[T](lastStateBody)) - mkHandlerCase(lastState.state, rhs.tree) + mkHandlerCase(lastState.state, Block(rhs.tree, Return(literalUnit))) } asyncStates.toList match { case s :: Nil => @@ -362,18 +363,23 @@ trait ExprBuilder { * } * } */ - def resumeFunTree[T: WeakTypeTag]: DefDef = - DefDef(Modifiers(), name.resume, Nil, List(Nil), Ident(definitions.UnitClass), + private def resumeFunTree[T: WeakTypeTag]: Tree = Try( - Match(symLookup.memberRef(name.state), mkCombinedHandlerCases[T]), + Match(symLookup.memberRef(name.state), mkCombinedHandlerCases[T] ++ initStates.flatMap(_.mkOnCompleteHandler[T]) ), List( CaseDef( Bind(name.t, Ident(nme.WILDCARD)), Apply(Ident(defn.NonFatalClass), List(Ident(name.t))), { val t = Expr[Throwable](Ident(name.t)) - futureSystemOps.completeProm[T]( + val complete = futureSystemOps.completeProm[T]( Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), futureSystemOps.tryyFailure[T](t)).tree - })), EmptyTree)) + Block(complete :: Nil, Return(literalUnit)) + })), EmptyTree) + + def forever(t: Tree): Tree = { + val labelName = name.fresh("while$") + LabelDef(labelName, Nil, Block(t :: Nil, Apply(Ident(labelName), Nil))) + } /** * Builds a `match` expression used as an onComplete handler. @@ -387,8 +393,12 @@ trait ExprBuilder { * resume() * } */ - def onCompleteHandler[T: WeakTypeTag]: Tree = - Match(symLookup.memberRef(name.state), initStates.flatMap(_.mkOnCompleteHandler[T]).toList) + def onCompleteHandler[T: WeakTypeTag]: Tree = { + val onCompletes = initStates.flatMap(_.mkOnCompleteHandler[T]).toList + forever { + Block(resumeFunTree :: Nil, literalUnit) + } + } } } @@ -399,9 +409,6 @@ trait ExprBuilder { case class Awaitable(expr: Tree, resultName: Symbol, resultType: Type, resultValDef: ValDef) - private def mkResumeApply(symLookup: SymLookup) = - Apply(symLookup.memberRef(name.resume), Nil) - private def mkStateTree(nextState: Int, symLookup: SymLookup): Tree = Assign(symLookup.memberRef(name.state), Literal(Constant(nextState))) @@ -411,5 +418,7 @@ trait ExprBuilder { private def mkHandlerCase(num: Int, rhs: Tree): CaseDef = CaseDef(Literal(Constant(num)), EmptyTree, rhs) - private def literalUnit = Literal(Constant(())) + def literalUnit = Literal(Constant(())) + + def literalNull = Literal(Constant(null)) } diff --git a/src/main/scala/scala/async/internal/StateAssigner.scala b/src/main/scala/scala/async/internal/StateAssigner.scala index 8f0d518..55e7a51 100644 --- a/src/main/scala/scala/async/internal/StateAssigner.scala +++ b/src/main/scala/scala/async/internal/StateAssigner.scala @@ -5,10 +5,12 @@ package scala.async.internal private[async] final class StateAssigner { - private var current = -1 + private var current = StateAssigner.Initial - def nextState(): Int = { - current += 1 - current - } + def nextState(): Int = + try current finally current += 1 } + +object StateAssigner { + final val Initial = 0 +} \ No newline at end of file diff --git a/src/main/scala/scala/async/internal/TransformUtils.scala b/src/main/scala/scala/async/internal/TransformUtils.scala index 5e73a7f..0b8cd00 100644 --- a/src/main/scala/scala/async/internal/TransformUtils.scala +++ b/src/main/scala/scala/async/internal/TransformUtils.scala @@ -49,6 +49,7 @@ private[async] trait TransformUtils { private def isByName(fun: Tree): ((Int, Int) => Boolean) = { if (Boolean_ShortCircuits contains fun.symbol) (i, j) => true + else if (fun.tpe == null) (x, y) => false else { val paramss = fun.tpe.paramss val byNamess = paramss.map(_.map(_.isByNameParam)) @@ -72,10 +73,6 @@ private[async] trait TransformUtils { self.splice.contains(elem.splice) } - def mkFunction_apply[A, B](self: Expr[Function1[A, B]])(arg: Expr[A]) = reify { - self.splice.apply(arg.splice) - } - def mkAny_==(self: Expr[Any])(other: Expr[Any]) = reify { self.splice == other.splice } diff --git a/src/test/scala/scala/async/TreeInterrogation.scala b/src/test/scala/scala/async/TreeInterrogation.scala index 8261898..b7c403a 100644 --- a/src/test/scala/scala/async/TreeInterrogation.scala +++ b/src/test/scala/scala/async/TreeInterrogation.scala @@ -36,7 +36,7 @@ class TreeInterrogation { functions.size mustBe 1 val varDefs = tree1.collect { - case ValDef(mods, name, _, _) if mods.hasFlag(Flag.MUTABLE) => name + case vd @ ValDef(mods, name, _, _) if mods.hasFlag(Flag.MUTABLE) && vd.symbol.owner.isClass => name } varDefs.map(_.decoded.trim).toSet mustBe (Set("state", "await$1$1", "await$2$1")) @@ -49,7 +49,7 @@ class TreeInterrogation { && !dd.symbol.asTerm.isAccessor && !dd.symbol.asTerm.isSetter => dd.name } }.flatten - defDefs.map(_.decoded.trim).toSet mustBe (Set("foo$1", "apply", "resume", "")) + defDefs.map(_.decoded.trim).toSet mustBe (Set("foo$1", "apply", "")) } } diff --git a/src/test/scala/scala/async/run/futures/FutureSpec.scala b/src/test/scala/scala/async/run/futures/FutureSpec.scala index 1761db5..25be0b1 100644 --- a/src/test/scala/scala/async/run/futures/FutureSpec.scala +++ b/src/test/scala/scala/async/run/futures/FutureSpec.scala @@ -134,6 +134,13 @@ class FutureSpec { Await.result(future1, defaultTimeout) mustBe ("10-14") intercept[NoSuchElementException] { Await.result(future2, defaultTimeout) } } + + @Test def mini() { + val future4 = async { + await(Future.successful(0)).toString + } + Await.result(future4, defaultTimeout) + } @Test def `recover from exceptions`() { val future1 = Future(5) diff --git a/src/test/scala/scala/async/run/stackoverflow/StackOverflowSpec.scala b/src/test/scala/scala/async/run/stackoverflow/StackOverflowSpec.scala new file mode 100644 index 0000000..2dc9b92 --- /dev/null +++ b/src/test/scala/scala/async/run/stackoverflow/StackOverflowSpec.scala @@ -0,0 +1,28 @@ +/* + * Copyright (C) 2012-2014 Typesafe Inc. + */ + +package scala.async +package run +package stackoverflow + +import org.junit.Test +import scala.async.internal.AsyncId + + +class StackOverflowSpec { + + @Test + def stackSafety() { + import AsyncId._ + async { + var i = 100000000 + while (i > 0) { + if (false) { + await(()) + } + i -= 1 + } + } + } +} -- cgit v1.2.3 From c0d711570b76fb17f8f58a80f8529dcb9cbfdb2c Mon Sep 17 00:00:00 2001 From: Jason Zaugg Date: Fri, 12 Dec 2014 07:42:36 +1000 Subject: Remove extraneous method in generated code. --- src/main/scala/scala/async/internal/AsyncTransform.scala | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) (limited to 'src') diff --git a/src/main/scala/scala/async/internal/AsyncTransform.scala b/src/main/scala/scala/async/internal/AsyncTransform.scala index c7a0c65..47a2704 100644 --- a/src/main/scala/scala/async/internal/AsyncTransform.scala +++ b/src/main/scala/scala/async/internal/AsyncTransform.scala @@ -38,15 +38,10 @@ trait AsyncTransform { val apply0DefDef: DefDef = { // We extend () => Unit so we can pass this class as the by-name argument to `Future.apply`. - // See SI-1247 for the the optimization that avoids creatio + // See SI-1247 for the the optimization that avoids creation. DefDef(NoMods, name.apply, Nil, Nil, TypeTree(definitions.UnitTpe), Apply(Ident(name.apply), literalNull :: Nil)) } - val extraValDef: ValDef = { - // We extend () => Unit so we can pass this class as the by-name argument to `Future.apply`. - // See SI-1247 for the the optimization that avoids creatio - ValDef(NoMods, newTermName("extra"), TypeTree(definitions.UnitTpe), literalUnit) - } - List(emptyConstructor, stateVar, result, execContextValDef) ++ List(applyDefDefDummyBody, apply0DefDef, extraValDef) + List(emptyConstructor, stateVar, result, execContextValDef) ++ List(applyDefDefDummyBody, apply0DefDef) } val tryToUnit = appliedType(definitions.FunctionClass(1), futureSystemOps.tryType[Any], typeOf[Unit]) -- cgit v1.2.3 From 063492a14a004cb519f553a6cd30f8b3e41f0453 Mon Sep 17 00:00:00 2001 From: Jason Zaugg Date: Mon, 15 Dec 2014 17:34:51 +1000 Subject: Make `f(await(completedFuture))` execute `f` synchronously A worthy optimization, suggested by @danarmak. Closes #73 --- .../scala/scala/async/internal/ExprBuilder.scala | 76 ++++++++++++---------- .../scala/scala/async/internal/FutureSystem.scala | 15 +++++ .../scala/async/run/SyncOptimizationSpec.scala | 28 ++++++++ .../scala/scala/async/run/futures/FutureSpec.scala | 1 - 4 files changed, 85 insertions(+), 35 deletions(-) create mode 100644 src/test/scala/scala/async/run/SyncOptimizationSpec.scala (limited to 'src') diff --git a/src/main/scala/scala/async/internal/ExprBuilder.scala b/src/main/scala/scala/async/internal/ExprBuilder.scala index fc82f4c..2dd485d 100644 --- a/src/main/scala/scala/async/internal/ExprBuilder.scala +++ b/src/main/scala/scala/async/internal/ExprBuilder.scala @@ -27,7 +27,7 @@ trait ExprBuilder { def nextStates: List[Int] - def mkHandlerCaseForState: CaseDef + def mkHandlerCaseForState[T: WeakTypeTag]: CaseDef def mkOnCompleteHandler[T: WeakTypeTag]: Option[CaseDef] = None @@ -51,7 +51,7 @@ trait ExprBuilder { def nextStates: List[Int] = List(nextState) - def mkHandlerCaseForState: CaseDef = + def mkHandlerCaseForState[T: WeakTypeTag]: CaseDef = mkHandlerCase(state, stats :+ mkStateTree(nextState, symLookup)) override val toString: String = @@ -62,7 +62,7 @@ trait ExprBuilder { * a branch of an `if` or a `match`. */ final class AsyncStateWithoutAwait(var stats: List[Tree], val state: Int, val nextStates: List[Int]) extends AsyncState { - override def mkHandlerCaseForState: CaseDef = + override def mkHandlerCaseForState[T: WeakTypeTag]: CaseDef = mkHandlerCase(state, stats) override val toString: String = @@ -79,39 +79,47 @@ trait ExprBuilder { def nextStates: List[Int] = List(nextState) - override def mkHandlerCaseForState: CaseDef = { - val callOnComplete = futureSystemOps.onComplete(Expr(awaitable.expr), - Expr(This(tpnme.EMPTY)), Expr(Ident(name.execContext))).tree - mkHandlerCase(state, stats ++ List(mkStateTree(onCompleteState, symLookup), callOnComplete, Return(literalUnit))) + override def mkHandlerCaseForState[T: WeakTypeTag]: CaseDef = { + val fun = This(tpnme.EMPTY) + val callOnComplete = futureSystemOps.onComplete[Any, Unit](Expr[futureSystem.Fut[Any]](awaitable.expr), + Expr[futureSystem.Tryy[Any] => Unit](fun), Expr[futureSystem.ExecContext](Ident(name.execContext))).tree + val tryGetOrCallOnComplete = + if (futureSystemOps.continueCompletedFutureOnSameThread) + If(futureSystemOps.isCompleted(Expr[futureSystem.Fut[_]](awaitable.expr)).tree, + Block(ifIsFailureTree[T](futureSystemOps.getCompleted[Any](Expr[futureSystem.Fut[Any]](awaitable.expr)).tree) :: Nil, literalUnit), + Block(callOnComplete :: Nil, Return(literalUnit))) + else + Block(callOnComplete :: Nil, Return(literalUnit)) + mkHandlerCase(state, stats ++ List(mkStateTree(onCompleteState, symLookup), tryGetOrCallOnComplete)) } + private def tryGetTree(tryReference: => Tree) = + Assign( + Ident(awaitable.resultName), + TypeApply(Select(futureSystemOps.tryyGet[Any](Expr[futureSystem.Tryy[Any]](tryReference)).tree, newTermName("asInstanceOf")), List(TypeTree(awaitable.resultType))) + ) + + /* if (tr.isFailure) + * result.complete(tr.asInstanceOf[Try[T]]) + * else { + * = tr.get.asInstanceOf[] + * + * + * } + */ + def ifIsFailureTree[T: WeakTypeTag](tryReference: => Tree) = + If(futureSystemOps.tryyIsFailure(Expr[futureSystem.Tryy[T]](tryReference)).tree, + Block(futureSystemOps.completeProm[T]( + Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), + Expr[futureSystem.Tryy[T]]( + TypeApply(Select(tryReference, newTermName("asInstanceOf")), + List(TypeTree(futureSystemOps.tryType[T]))))).tree :: Nil, + Return(literalUnit)), + Block(List(tryGetTree(tryReference)), mkStateTree(nextState, symLookup)) + ) + override def mkOnCompleteHandler[T: WeakTypeTag]: Option[CaseDef] = { - val tryGetTree = - Assign( - Ident(awaitable.resultName), - TypeApply(Select(futureSystemOps.tryyGet[T](Expr[futureSystem.Tryy[T]](Ident(symLookup.applyTrParam))).tree, newTermName("asInstanceOf")), List(TypeTree(awaitable.resultType))) - ) - - /* if (tr.isFailure) - * result.complete(tr.asInstanceOf[Try[T]]) - * else { - * = tr.get.asInstanceOf[] - * - * - * } - */ - val ifIsFailureTree = - If(futureSystemOps.tryyIsFailure(Expr[futureSystem.Tryy[T]](Ident(symLookup.applyTrParam))).tree, - Block(futureSystemOps.completeProm[T]( - Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), - Expr[futureSystem.Tryy[T]]( - TypeApply(Select(Ident(symLookup.applyTrParam), newTermName("asInstanceOf")), - List(TypeTree(futureSystemOps.tryType[T]))))).tree :: Nil, - Return(literalUnit)), - Block(List(tryGetTree), mkStateTree(nextState, symLookup)) - ) - - Some(mkHandlerCase(onCompleteState, List(ifIsFailureTree))) + Some(mkHandlerCase(onCompleteState, List(ifIsFailureTree[T](Ident(symLookup.applyTrParam))))) } override val toString: String = @@ -337,7 +345,7 @@ trait ExprBuilder { case s :: Nil => List(caseForLastState) case _ => - val initCases = for (state <- asyncStates.toList.init) yield state.mkHandlerCaseForState + val initCases = for (state <- asyncStates.toList.init) yield state.mkHandlerCaseForState[T] initCases :+ caseForLastState } } diff --git a/src/main/scala/scala/async/internal/FutureSystem.scala b/src/main/scala/scala/async/internal/FutureSystem.scala index 46c0bcf..96356ed 100644 --- a/src/main/scala/scala/async/internal/FutureSystem.scala +++ b/src/main/scala/scala/async/internal/FutureSystem.scala @@ -49,6 +49,12 @@ trait FutureSystem { def onComplete[A, U](future: Expr[Fut[A]], fun: Expr[Tryy[A] => U], execContext: Expr[ExecContext]): Expr[Unit] + def continueCompletedFutureOnSameThread = false + def isCompleted(future: Expr[Fut[_]]): Expr[Boolean] = + throw new UnsupportedOperationException("isCompleted not supported by this FutureSystem") + def getCompleted[A: WeakTypeTag](future: Expr[Fut[A]]): Expr[Tryy[A]] = + throw new UnsupportedOperationException("getCompleted not supported by this FutureSystem") + /** Complete a promise with a value */ def completeProm[A](prom: Expr[Prom[A]], value: Expr[Tryy[A]]): Expr[Unit] @@ -103,6 +109,15 @@ object ScalaConcurrentFutureSystem extends FutureSystem { future.splice.onComplete(fun.splice)(execContext.splice) } + override def continueCompletedFutureOnSameThread: Boolean = true + + override def isCompleted(future: Expr[Fut[_]]): Expr[Boolean] = reify { + future.splice.isCompleted + } + override def getCompleted[A: WeakTypeTag](future: Expr[Fut[A]]): Expr[Tryy[A]] = reify { + future.splice.value.get + } + def completeProm[A](prom: Expr[Prom[A]], value: Expr[scala.util.Try[A]]): Expr[Unit] = reify { prom.splice.complete(value.splice) Expr[Unit](Literal(Constant(()))).splice diff --git a/src/test/scala/scala/async/run/SyncOptimizationSpec.scala b/src/test/scala/scala/async/run/SyncOptimizationSpec.scala new file mode 100644 index 0000000..dd649f4 --- /dev/null +++ b/src/test/scala/scala/async/run/SyncOptimizationSpec.scala @@ -0,0 +1,28 @@ +package scala.async.run + +import org.junit.Test +import scala.async.Async._ +import scala.concurrent._ +import scala.concurrent.duration._ +import ExecutionContext.Implicits._ + +class SyncOptimizationSpec { + @Test + def awaitOnCompletedFutureRunsOnSameThread: Unit = { + + def stackDepth = Thread.currentThread().getStackTrace.size + + val future = async { + val thread1 = Thread.currentThread + val stackDepth1 = stackDepth + + val f = await(Future.successful(1)) + val thread2 = Thread.currentThread + val stackDepth2 = stackDepth + assert(thread1 == thread2) + assert(stackDepth1 == stackDepth2) + } + Await.result(future, 10.seconds) + } + +} diff --git a/src/test/scala/scala/async/run/futures/FutureSpec.scala b/src/test/scala/scala/async/run/futures/FutureSpec.scala index 25be0b1..362303e 100644 --- a/src/test/scala/scala/async/run/futures/FutureSpec.scala +++ b/src/test/scala/scala/async/run/futures/FutureSpec.scala @@ -538,7 +538,6 @@ class FutureSpec { val f = async { await(future(5)) / 0 } Await.ready(f, defaultTimeout).value.get.toString mustBe expected.toString } - } -- cgit v1.2.3