From 063492a14a004cb519f553a6cd30f8b3e41f0453 Mon Sep 17 00:00:00 2001 From: Jason Zaugg Date: Mon, 15 Dec 2014 17:34:51 +1000 Subject: Make `f(await(completedFuture))` execute `f` synchronously A worthy optimization, suggested by @danarmak. Closes #73 --- .../scala/scala/async/internal/ExprBuilder.scala | 76 ++++++++++++---------- .../scala/scala/async/internal/FutureSystem.scala | 15 +++++ .../scala/async/run/SyncOptimizationSpec.scala | 28 ++++++++ .../scala/scala/async/run/futures/FutureSpec.scala | 1 - 4 files changed, 85 insertions(+), 35 deletions(-) create mode 100644 src/test/scala/scala/async/run/SyncOptimizationSpec.scala (limited to 'src') diff --git a/src/main/scala/scala/async/internal/ExprBuilder.scala b/src/main/scala/scala/async/internal/ExprBuilder.scala index fc82f4c..2dd485d 100644 --- a/src/main/scala/scala/async/internal/ExprBuilder.scala +++ b/src/main/scala/scala/async/internal/ExprBuilder.scala @@ -27,7 +27,7 @@ trait ExprBuilder { def nextStates: List[Int] - def mkHandlerCaseForState: CaseDef + def mkHandlerCaseForState[T: WeakTypeTag]: CaseDef def mkOnCompleteHandler[T: WeakTypeTag]: Option[CaseDef] = None @@ -51,7 +51,7 @@ trait ExprBuilder { def nextStates: List[Int] = List(nextState) - def mkHandlerCaseForState: CaseDef = + def mkHandlerCaseForState[T: WeakTypeTag]: CaseDef = mkHandlerCase(state, stats :+ mkStateTree(nextState, symLookup)) override val toString: String = @@ -62,7 +62,7 @@ trait ExprBuilder { * a branch of an `if` or a `match`. */ final class AsyncStateWithoutAwait(var stats: List[Tree], val state: Int, val nextStates: List[Int]) extends AsyncState { - override def mkHandlerCaseForState: CaseDef = + override def mkHandlerCaseForState[T: WeakTypeTag]: CaseDef = mkHandlerCase(state, stats) override val toString: String = @@ -79,39 +79,47 @@ trait ExprBuilder { def nextStates: List[Int] = List(nextState) - override def mkHandlerCaseForState: CaseDef = { - val callOnComplete = futureSystemOps.onComplete(Expr(awaitable.expr), - Expr(This(tpnme.EMPTY)), Expr(Ident(name.execContext))).tree - mkHandlerCase(state, stats ++ List(mkStateTree(onCompleteState, symLookup), callOnComplete, Return(literalUnit))) + override def mkHandlerCaseForState[T: WeakTypeTag]: CaseDef = { + val fun = This(tpnme.EMPTY) + val callOnComplete = futureSystemOps.onComplete[Any, Unit](Expr[futureSystem.Fut[Any]](awaitable.expr), + Expr[futureSystem.Tryy[Any] => Unit](fun), Expr[futureSystem.ExecContext](Ident(name.execContext))).tree + val tryGetOrCallOnComplete = + if (futureSystemOps.continueCompletedFutureOnSameThread) + If(futureSystemOps.isCompleted(Expr[futureSystem.Fut[_]](awaitable.expr)).tree, + Block(ifIsFailureTree[T](futureSystemOps.getCompleted[Any](Expr[futureSystem.Fut[Any]](awaitable.expr)).tree) :: Nil, literalUnit), + Block(callOnComplete :: Nil, Return(literalUnit))) + else + Block(callOnComplete :: Nil, Return(literalUnit)) + mkHandlerCase(state, stats ++ List(mkStateTree(onCompleteState, symLookup), tryGetOrCallOnComplete)) } + private def tryGetTree(tryReference: => Tree) = + Assign( + Ident(awaitable.resultName), + TypeApply(Select(futureSystemOps.tryyGet[Any](Expr[futureSystem.Tryy[Any]](tryReference)).tree, newTermName("asInstanceOf")), List(TypeTree(awaitable.resultType))) + ) + + /* if (tr.isFailure) + * result.complete(tr.asInstanceOf[Try[T]]) + * else { + * = tr.get.asInstanceOf[] + * + * + * } + */ + def ifIsFailureTree[T: WeakTypeTag](tryReference: => Tree) = + If(futureSystemOps.tryyIsFailure(Expr[futureSystem.Tryy[T]](tryReference)).tree, + Block(futureSystemOps.completeProm[T]( + Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), + Expr[futureSystem.Tryy[T]]( + TypeApply(Select(tryReference, newTermName("asInstanceOf")), + List(TypeTree(futureSystemOps.tryType[T]))))).tree :: Nil, + Return(literalUnit)), + Block(List(tryGetTree(tryReference)), mkStateTree(nextState, symLookup)) + ) + override def mkOnCompleteHandler[T: WeakTypeTag]: Option[CaseDef] = { - val tryGetTree = - Assign( - Ident(awaitable.resultName), - TypeApply(Select(futureSystemOps.tryyGet[T](Expr[futureSystem.Tryy[T]](Ident(symLookup.applyTrParam))).tree, newTermName("asInstanceOf")), List(TypeTree(awaitable.resultType))) - ) - - /* if (tr.isFailure) - * result.complete(tr.asInstanceOf[Try[T]]) - * else { - * = tr.get.asInstanceOf[] - * - * - * } - */ - val ifIsFailureTree = - If(futureSystemOps.tryyIsFailure(Expr[futureSystem.Tryy[T]](Ident(symLookup.applyTrParam))).tree, - Block(futureSystemOps.completeProm[T]( - Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), - Expr[futureSystem.Tryy[T]]( - TypeApply(Select(Ident(symLookup.applyTrParam), newTermName("asInstanceOf")), - List(TypeTree(futureSystemOps.tryType[T]))))).tree :: Nil, - Return(literalUnit)), - Block(List(tryGetTree), mkStateTree(nextState, symLookup)) - ) - - Some(mkHandlerCase(onCompleteState, List(ifIsFailureTree))) + Some(mkHandlerCase(onCompleteState, List(ifIsFailureTree[T](Ident(symLookup.applyTrParam))))) } override val toString: String = @@ -337,7 +345,7 @@ trait ExprBuilder { case s :: Nil => List(caseForLastState) case _ => - val initCases = for (state <- asyncStates.toList.init) yield state.mkHandlerCaseForState + val initCases = for (state <- asyncStates.toList.init) yield state.mkHandlerCaseForState[T] initCases :+ caseForLastState } } diff --git a/src/main/scala/scala/async/internal/FutureSystem.scala b/src/main/scala/scala/async/internal/FutureSystem.scala index 46c0bcf..96356ed 100644 --- a/src/main/scala/scala/async/internal/FutureSystem.scala +++ b/src/main/scala/scala/async/internal/FutureSystem.scala @@ -49,6 +49,12 @@ trait FutureSystem { def onComplete[A, U](future: Expr[Fut[A]], fun: Expr[Tryy[A] => U], execContext: Expr[ExecContext]): Expr[Unit] + def continueCompletedFutureOnSameThread = false + def isCompleted(future: Expr[Fut[_]]): Expr[Boolean] = + throw new UnsupportedOperationException("isCompleted not supported by this FutureSystem") + def getCompleted[A: WeakTypeTag](future: Expr[Fut[A]]): Expr[Tryy[A]] = + throw new UnsupportedOperationException("getCompleted not supported by this FutureSystem") + /** Complete a promise with a value */ def completeProm[A](prom: Expr[Prom[A]], value: Expr[Tryy[A]]): Expr[Unit] @@ -103,6 +109,15 @@ object ScalaConcurrentFutureSystem extends FutureSystem { future.splice.onComplete(fun.splice)(execContext.splice) } + override def continueCompletedFutureOnSameThread: Boolean = true + + override def isCompleted(future: Expr[Fut[_]]): Expr[Boolean] = reify { + future.splice.isCompleted + } + override def getCompleted[A: WeakTypeTag](future: Expr[Fut[A]]): Expr[Tryy[A]] = reify { + future.splice.value.get + } + def completeProm[A](prom: Expr[Prom[A]], value: Expr[scala.util.Try[A]]): Expr[Unit] = reify { prom.splice.complete(value.splice) Expr[Unit](Literal(Constant(()))).splice diff --git a/src/test/scala/scala/async/run/SyncOptimizationSpec.scala b/src/test/scala/scala/async/run/SyncOptimizationSpec.scala new file mode 100644 index 0000000..dd649f4 --- /dev/null +++ b/src/test/scala/scala/async/run/SyncOptimizationSpec.scala @@ -0,0 +1,28 @@ +package scala.async.run + +import org.junit.Test +import scala.async.Async._ +import scala.concurrent._ +import scala.concurrent.duration._ +import ExecutionContext.Implicits._ + +class SyncOptimizationSpec { + @Test + def awaitOnCompletedFutureRunsOnSameThread: Unit = { + + def stackDepth = Thread.currentThread().getStackTrace.size + + val future = async { + val thread1 = Thread.currentThread + val stackDepth1 = stackDepth + + val f = await(Future.successful(1)) + val thread2 = Thread.currentThread + val stackDepth2 = stackDepth + assert(thread1 == thread2) + assert(stackDepth1 == stackDepth2) + } + Await.result(future, 10.seconds) + } + +} diff --git a/src/test/scala/scala/async/run/futures/FutureSpec.scala b/src/test/scala/scala/async/run/futures/FutureSpec.scala index 25be0b1..362303e 100644 --- a/src/test/scala/scala/async/run/futures/FutureSpec.scala +++ b/src/test/scala/scala/async/run/futures/FutureSpec.scala @@ -538,7 +538,6 @@ class FutureSpec { val f = async { await(future(5)) / 0 } Await.ready(f, defaultTimeout).value.get.toString mustBe expected.toString } - } -- cgit v1.2.3