From b2472a5f057ca5e18dc8b3205600c14bf31dc1d1 Mon Sep 17 00:00:00 2001 From: Aleksandar Prokopec Date: Mon, 30 Apr 2012 16:32:26 +0200 Subject: Fixed a bug with setting execution contexts. Ported most of the future tests. --- .../scala/concurrent/ConcurrentPackageObject.scala | 9 +- .../scala/concurrent/ExecutionContext.scala | 11 +- .../concurrent/impl/ExecutionContextImpl.scala | 20 +-- src/library/scala/concurrent/impl/Promise.scala | 10 +- src/library/scala/concurrent/package.scala | 10 +- src/library/scala/concurrent/util/Duration.scala | 58 ++++----- test/files/jvm/future-spec/FutureTests.scala | 140 ++++++++++++++++++++- 7 files changed, 201 insertions(+), 57 deletions(-) diff --git a/src/library/scala/concurrent/ConcurrentPackageObject.scala b/src/library/scala/concurrent/ConcurrentPackageObject.scala index fafd7fd238..a3d985733e 100644 --- a/src/library/scala/concurrent/ConcurrentPackageObject.scala +++ b/src/library/scala/concurrent/ConcurrentPackageObject.scala @@ -22,7 +22,7 @@ abstract class ConcurrentPackageObject { */ lazy val defaultExecutionContext = new impl.ExecutionContextImpl(null) - private val currentExecutionContext = new ThreadLocal[ExecutionContext] + val currentExecutionContext = new ThreadLocal[ExecutionContext] val handledFutureException: PartialFunction[Throwable, Throwable] = { case t: Throwable if isFutureThrowable(t) => t @@ -82,7 +82,7 @@ abstract class ConcurrentPackageObject { * - TimeoutException - in the case that the blockable object timed out */ def blocking[T](body: =>T): T = - blocking(impl.Future.body2awaitable(body), Duration.fromNanos(0)) + blocking(impl.Future.body2awaitable(body), Duration.Inf) /** Blocks on an awaitable object. * @@ -93,11 +93,12 @@ abstract class ConcurrentPackageObject { * - InterruptedException - in the case that a wait within the blockable object was interrupted * - TimeoutException - in the case that the blockable object timed out */ - def blocking[T](awaitable: Awaitable[T], atMost: Duration): T = + def blocking[T](awaitable: Awaitable[T], atMost: Duration): T = { currentExecutionContext.get match { - case null => Await.result(awaitable, atMost) + case null => awaitable.result(atMost)(Await.canAwaitEvidence) case ec => ec.internalBlockingCall(awaitable, atMost) } + } @inline implicit final def int2durationops(x: Int): DurationOps = new DurationOps(x) } diff --git a/src/library/scala/concurrent/ExecutionContext.scala b/src/library/scala/concurrent/ExecutionContext.scala index 3f62f58bf8..4666674b5b 100644 --- a/src/library/scala/concurrent/ExecutionContext.scala +++ b/src/library/scala/concurrent/ExecutionContext.scala @@ -47,11 +47,18 @@ object ExecutionContext { /** Creates an `ExecutionContext` from the given `ExecutorService`. */ - def fromExecutorService(e: ExecutorService): ExecutionContext with Executor = new impl.ExecutionContextImpl(e) + def fromExecutorService(e: ExecutorService, reporter: Throwable => Unit = defaultReporter): ExecutionContext with Executor = new impl.ExecutionContextImpl(e, reporter) /** Creates an `ExecutionContext` from the given `Executor`. */ - def fromExecutor(e: Executor): ExecutionContext with Executor = new impl.ExecutionContextImpl(e) + def fromExecutor(e: Executor, reporter: Throwable => Unit = defaultReporter): ExecutionContext with Executor = new impl.ExecutionContextImpl(e, reporter) + + def defaultReporter: Throwable => Unit = { + // `Error`s are currently wrapped by `resolver`. + // Also, re-throwing `Error`s here causes an exception handling test to fail. + //case e: Error => throw e + case t => t.printStackTrace() + } } diff --git a/src/library/scala/concurrent/impl/ExecutionContextImpl.scala b/src/library/scala/concurrent/impl/ExecutionContextImpl.scala index ad98331241..7dc3ed2988 100644 --- a/src/library/scala/concurrent/impl/ExecutionContextImpl.scala +++ b/src/library/scala/concurrent/impl/ExecutionContextImpl.scala @@ -17,7 +17,8 @@ import scala.concurrent.util.{ Duration } -private[scala] class ExecutionContextImpl(es: AnyRef) extends ExecutionContext with Executor { +private[scala] class ExecutionContextImpl(es: AnyRef, reporter: Throwable => Unit = ExecutionContext.defaultReporter) +extends ExecutionContext with Executor { import ExecutionContextImpl._ val executorService: AnyRef = if (es eq null) getExecutorService else es @@ -26,7 +27,7 @@ private[scala] class ExecutionContextImpl(es: AnyRef) extends ExecutionContext w def executorsThreadFactory = new ThreadFactory { def newThread(r: Runnable) = new Thread(new Runnable { override def run() { - currentExecutionContext.set(ExecutionContextImpl.this) + scala.concurrent.currentExecutionContext.set(ExecutionContextImpl.this) r.run() } }) @@ -36,7 +37,7 @@ private[scala] class ExecutionContextImpl(es: AnyRef) extends ExecutionContext w def forkJoinPoolThreadFactory = new ForkJoinPool.ForkJoinWorkerThreadFactory { def newThread(fjp: ForkJoinPool) = new ForkJoinWorkerThread(fjp) { override def onStart() { - currentExecutionContext.set(ExecutionContextImpl.this) + scala.concurrent.currentExecutionContext.set(ExecutionContextImpl.this) } } } @@ -92,22 +93,13 @@ private[scala] class ExecutionContextImpl(es: AnyRef) extends ExecutionContext w } } - def reportFailure(t: Throwable) = t match { - // `Error`s are currently wrapped by `resolver`. - // Also, re-throwing `Error`s here causes an exception handling test to fail. - //case e: Error => throw e - case t => t.printStackTrace() - } + def reportFailure(t: Throwable) = reporter(t) } private[concurrent] object ExecutionContextImpl { - - private[concurrent] def currentExecutionContext: ThreadLocal[ExecutionContext] = new ThreadLocal[ExecutionContext] { - override protected def initialValue = null - } - + } diff --git a/src/library/scala/concurrent/impl/Promise.scala b/src/library/scala/concurrent/impl/Promise.scala index ee1841aaff..e8d569fdba 100644 --- a/src/library/scala/concurrent/impl/Promise.scala +++ b/src/library/scala/concurrent/impl/Promise.scala @@ -78,7 +78,7 @@ object Promise { */ class DefaultPromise[T](implicit val executor: ExecutionContext) extends AbstractPromise with Promise[T] { self => updater.set(this, Nil) // Start at "No callbacks" //FIXME switch to Unsafe instead of ARFU - + protected final def tryAwait(atMost: Duration): Boolean = { @tailrec def awaitUnsafe(waitTimeNanos: Long): Boolean = { @@ -88,7 +88,7 @@ object Promise { val start = System.nanoTime() try { synchronized { - while (!isCompleted) wait(ms, ns) + if (!isCompleted) wait(ms, ns) // previously - this was a `while`, ending up in an infinite loop } } catch { case e: InterruptedException => @@ -99,7 +99,7 @@ object Promise { isCompleted } //FIXME do not do this if there'll be no waiting - blocking(Future.body2awaitable(awaitUnsafe(if (atMost.isFinite) atMost.toNanos else Long.MaxValue)), atMost) + awaitUnsafe(if (atMost.isFinite) atMost.toNanos else Long.MaxValue) } @throws(classOf[TimeoutException]) @@ -147,7 +147,9 @@ object Promise { } tryComplete(resolveEither(value)) } finally { - synchronized { notifyAll() } //Notify any evil blockers + synchronized { //Notify any evil blockers + notifyAll() + } } } diff --git a/src/library/scala/concurrent/package.scala b/src/library/scala/concurrent/package.scala index b06c6f3c63..169826034b 100644 --- a/src/library/scala/concurrent/package.scala +++ b/src/library/scala/concurrent/package.scala @@ -26,9 +26,15 @@ package concurrent { object Await { private[concurrent] implicit val canAwaitEvidence = new CanAwait {} - def ready[T](awaitable: Awaitable[T], atMost: Duration): Awaitable[T] = awaitable.ready(atMost) + def ready[T <: Awaitable[_]](awaitable: T, atMost: Duration): T = { + blocking(awaitable, atMost) + awaitable + } + + def result[T](awaitable: Awaitable[T], atMost: Duration): T = { + blocking(awaitable, atMost) + } - def result[T](awaitable: Awaitable[T], atMost: Duration): T = awaitable.result(atMost) } final class DurationOps private[concurrent] (x: Int) { diff --git a/src/library/scala/concurrent/util/Duration.scala b/src/library/scala/concurrent/util/Duration.scala index c4e5fa491a..0e39fdffa3 100644 --- a/src/library/scala/concurrent/util/Duration.scala +++ b/src/library/scala/concurrent/util/Duration.scala @@ -115,13 +115,13 @@ object Duration { * Parse TimeUnit from string representation. */ protected[util] def timeUnit(unit: String): TimeUnit = unit.toLowerCase match { - case "d" | "day" | "days" ⇒ DAYS - case "h" | "hour" | "hours" ⇒ HOURS - case "min" | "minute" | "minutes" ⇒ MINUTES - case "s" | "sec" | "second" | "seconds" ⇒ SECONDS - case "ms" | "milli" | "millis" | "millisecond" | "milliseconds" ⇒ MILLISECONDS - case "µs" | "micro" | "micros" | "microsecond" | "microseconds" ⇒ MICROSECONDS - case "ns" | "nano" | "nanos" | "nanosecond" | "nanoseconds" ⇒ NANOSECONDS + case "d" | "day" | "days" => DAYS + case "h" | "hour" | "hours" => HOURS + case "min" | "minute" | "minutes" => MINUTES + case "s" | "sec" | "second" | "seconds" => SECONDS + case "ms" | "milli" | "millis" | "millisecond" | "milliseconds" => MILLISECONDS + case "µs" | "micro" | "micros" | "microsecond" | "microseconds" => MICROSECONDS + case "ns" | "nano" | "nanos" | "nanosecond" | "nanoseconds" => NANOSECONDS } val Zero: FiniteDuration = new FiniteDuration(0, NANOSECONDS) @@ -138,26 +138,26 @@ object Duration { } trait Infinite { - this: Duration ⇒ + this: Duration => def +(other: Duration): Duration = other match { - case _: this.type ⇒ this - case _: Infinite ⇒ throw new IllegalArgumentException("illegal addition of infinities") - case _ ⇒ this + case _: this.type => this + case _: Infinite => throw new IllegalArgumentException("illegal addition of infinities") + case _ => this } def -(other: Duration): Duration = other match { - case _: this.type ⇒ throw new IllegalArgumentException("illegal subtraction of infinities") - case _ ⇒ this + case _: this.type => throw new IllegalArgumentException("illegal subtraction of infinities") + case _ => this } def *(factor: Double): Duration = this def /(factor: Double): Duration = this def /(other: Duration): Double = other match { - case _: Infinite ⇒ throw new IllegalArgumentException("illegal division of infinities") + case _: Infinite => throw new IllegalArgumentException("illegal division of infinities") // maybe questionable but pragmatic: Inf / 0 => Inf - case x ⇒ Double.PositiveInfinity * (if ((this > Zero) ^ (other >= Zero)) -1 else 1) + case x => Double.PositiveInfinity * (if ((this > Zero) ^ (other >= Zero)) -1 else 1) } def finite_? = false @@ -300,20 +300,20 @@ class FiniteDuration(val length: Long, val unit: TimeUnit) extends Duration { def toUnit(u: TimeUnit) = toNanos.toDouble / NANOSECONDS.convert(1, u) override def toString = this match { - case Duration(1, DAYS) ⇒ "1 day" - case Duration(x, DAYS) ⇒ x + " days" - case Duration(1, HOURS) ⇒ "1 hour" - case Duration(x, HOURS) ⇒ x + " hours" - case Duration(1, MINUTES) ⇒ "1 minute" - case Duration(x, MINUTES) ⇒ x + " minutes" - case Duration(1, SECONDS) ⇒ "1 second" - case Duration(x, SECONDS) ⇒ x + " seconds" - case Duration(1, MILLISECONDS) ⇒ "1 millisecond" - case Duration(x, MILLISECONDS) ⇒ x + " milliseconds" - case Duration(1, MICROSECONDS) ⇒ "1 microsecond" - case Duration(x, MICROSECONDS) ⇒ x + " microseconds" - case Duration(1, NANOSECONDS) ⇒ "1 nanosecond" - case Duration(x, NANOSECONDS) ⇒ x + " nanoseconds" + case Duration(1, DAYS) => "1 day" + case Duration(x, DAYS) => x + " days" + case Duration(1, HOURS) => "1 hour" + case Duration(x, HOURS) => x + " hours" + case Duration(1, MINUTES) => "1 minute" + case Duration(x, MINUTES) => x + " minutes" + case Duration(1, SECONDS) => "1 second" + case Duration(x, SECONDS) => x + " seconds" + case Duration(1, MILLISECONDS) => "1 millisecond" + case Duration(x, MILLISECONDS) => x + " milliseconds" + case Duration(1, MICROSECONDS) => "1 microsecond" + case Duration(x, MICROSECONDS) => x + " microseconds" + case Duration(1, NANOSECONDS) => "1 nanosecond" + case Duration(x, NANOSECONDS) => x + " nanoseconds" } def compare(other: Duration) = diff --git a/test/files/jvm/future-spec/FutureTests.scala b/test/files/jvm/future-spec/FutureTests.scala index aa308945a1..9155985d7f 100644 --- a/test/files/jvm/future-spec/FutureTests.scala +++ b/test/files/jvm/future-spec/FutureTests.scala @@ -4,6 +4,7 @@ import scala.concurrent._ import scala.concurrent.util.duration._ import scala.concurrent.util.Duration.Inf +import scala.collection._ @@ -317,8 +318,13 @@ object FutureTests extends MinimalScalaTest { Await.result(traversedList, defaultTimeout).sum mustBe (10000) } - /* need configurable execution contexts here "shouldHandleThrowables" in { + val ms = new mutable.HashSet[Throwable] with mutable.SynchronizedSet[Throwable] + implicit val ec = scala.concurrent.ExecutionContext.fromExecutor(new scala.concurrent.forkjoin.ForkJoinPool(), { + t => + ms += t + }) + class ThrowableTest(m: String) extends Throwable(m) val f1 = future[Any] { @@ -347,8 +353,138 @@ object FutureTests extends MinimalScalaTest { f2 onSuccess { case _ => throw new ThrowableTest("current thread receive") } Await.result(f3, defaultTimeout) mustBe ("SUCCESS") + + val waiting = future { + Thread.sleep(1000) + } + Await.ready(waiting, 2000 millis) + + ms.size mustBe (4) + } + + "shouldBlockUntilResult" in { + val latch = new TestLatch + + val f = future { + Await.ready(latch, 5 seconds) + 5 + } + val f2 = future { + val res = Await.result(f, Inf) + res + 9 + } + + intercept[TimeoutException] { + Await.ready(f2, 100 millis) + } + + latch.open() + + Await.result(f2, defaultTimeout) mustBe (14) + + val f3 = future { + Thread.sleep(100) + 5 + } + + intercept[TimeoutException] { + Await.ready(f3, 0 millis) + } } - */ + + "run callbacks async" in { + val latch = Vector.fill(10)(new TestLatch) + + val f1 = future { + latch(0).open() + Await.ready(latch(1), TestLatch.DefaultTimeout) + "Hello" + } + val f2 = f1 map { + s => + latch(2).open() + Await.ready(latch(3), TestLatch.DefaultTimeout) + s.length + } + for (_ <- f2) latch(4).open() + + Await.ready(latch(0), TestLatch.DefaultTimeout) + + f1.isCompleted mustBe (false) + f2.isCompleted mustBe (false) + + latch(1).open() + Await.ready(latch(2), TestLatch.DefaultTimeout) + + f1.isCompleted mustBe (true) + f2.isCompleted mustBe (false) + + val f3 = f1 map { + s => + latch(5).open() + Await.ready(latch(6), TestLatch.DefaultTimeout) + s.length * 2 + } + for (_ <- f3) latch(3).open() + + Await.ready(latch(5), TestLatch.DefaultTimeout) + + f3.isCompleted mustBe (false) + + latch(6).open() + Await.ready(latch(4), TestLatch.DefaultTimeout) + + f2.isCompleted mustBe (true) + f3.isCompleted mustBe (true) + + val p1 = Promise[String]() + val f4 = p1.future map { + s => + latch(7).open() + Await.ready(latch(8), TestLatch.DefaultTimeout) + s.length + } + for (_ <- f4) latch(9).open() + + p1.future.isCompleted mustBe (false) + f4.isCompleted mustBe (false) + + p1 complete Right("Hello") + + Await.ready(latch(7), TestLatch.DefaultTimeout) + + p1.future.isCompleted mustBe (true) + f4.isCompleted mustBe (false) + + latch(8).open() + Await.ready(latch(9), TestLatch.DefaultTimeout) + + Await.ready(f4, defaultTimeout).isCompleted mustBe (true) + } + + "should not deadlock with nested await (ticket 1313)" in { + val simple = Future() map { + _ => + val unit = Future(()) + val umap = unit map { _ => () } + Await.result(umap, Inf) + } + Await.ready(simple, Inf).isCompleted mustBe (true) + + val l1, l2 = new TestLatch + val complex = Future() map { + _ => + blocking { + val nested = Future(()) + for (_ <- nested) l1.open() + Await.ready(l1, TestLatch.DefaultTimeout) // make sure nested is completed + for (_ <- nested) l2.open() + Await.ready(l2, TestLatch.DefaultTimeout) + } + } + Await.ready(complex, defaultTimeout).isCompleted mustBe (true) + } + } } -- cgit v1.2.3