diff options
author | Aleksandar Prokopec <axel22@gmail.com> | 2012-04-30 16:32:26 +0200 |
---|---|---|
committer | Aleksandar Prokopec <axel22@gmail.com> | 2012-04-30 16:32:26 +0200 |
commit | b2472a5f057ca5e18dc8b3205600c14bf31dc1d1 (patch) | |
tree | 957d6cac253c81806a0baff6c0dfbf32f2dbac31 /src | |
parent | 9c4baa93d906b161f501ae04f1552e1b7d448436 (diff) | |
download | scala-b2472a5f057ca5e18dc8b3205600c14bf31dc1d1.tar.gz scala-b2472a5f057ca5e18dc8b3205600c14bf31dc1d1.tar.bz2 scala-b2472a5f057ca5e18dc8b3205600c14bf31dc1d1.zip |
Fixed a bug with setting execution contexts.
Ported most of the future tests.
Diffstat (limited to 'src')
6 files changed, 63 insertions, 55 deletions
diff --git a/src/library/scala/concurrent/ConcurrentPackageObject.scala b/src/library/scala/concurrent/ConcurrentPackageObject.scala index fafd7fd238..a3d985733e 100644 --- a/src/library/scala/concurrent/ConcurrentPackageObject.scala +++ b/src/library/scala/concurrent/ConcurrentPackageObject.scala @@ -22,7 +22,7 @@ abstract class ConcurrentPackageObject { */ lazy val defaultExecutionContext = new impl.ExecutionContextImpl(null) - private val currentExecutionContext = new ThreadLocal[ExecutionContext] + val currentExecutionContext = new ThreadLocal[ExecutionContext] val handledFutureException: PartialFunction[Throwable, Throwable] = { case t: Throwable if isFutureThrowable(t) => t @@ -82,7 +82,7 @@ abstract class ConcurrentPackageObject { * - TimeoutException - in the case that the blockable object timed out */ def blocking[T](body: =>T): T = - blocking(impl.Future.body2awaitable(body), Duration.fromNanos(0)) + blocking(impl.Future.body2awaitable(body), Duration.Inf) /** Blocks on an awaitable object. * @@ -93,11 +93,12 @@ abstract class ConcurrentPackageObject { * - InterruptedException - in the case that a wait within the blockable object was interrupted * - TimeoutException - in the case that the blockable object timed out */ - def blocking[T](awaitable: Awaitable[T], atMost: Duration): T = + def blocking[T](awaitable: Awaitable[T], atMost: Duration): T = { currentExecutionContext.get match { - case null => Await.result(awaitable, atMost) + case null => awaitable.result(atMost)(Await.canAwaitEvidence) case ec => ec.internalBlockingCall(awaitable, atMost) } + } @inline implicit final def int2durationops(x: Int): DurationOps = new DurationOps(x) } diff --git a/src/library/scala/concurrent/ExecutionContext.scala b/src/library/scala/concurrent/ExecutionContext.scala index 3f62f58bf8..4666674b5b 100644 --- a/src/library/scala/concurrent/ExecutionContext.scala +++ b/src/library/scala/concurrent/ExecutionContext.scala @@ -47,11 +47,18 @@ object ExecutionContext { /** Creates an `ExecutionContext` from the given `ExecutorService`. */ - def fromExecutorService(e: ExecutorService): ExecutionContext with Executor = new impl.ExecutionContextImpl(e) + def fromExecutorService(e: ExecutorService, reporter: Throwable => Unit = defaultReporter): ExecutionContext with Executor = new impl.ExecutionContextImpl(e, reporter) /** Creates an `ExecutionContext` from the given `Executor`. */ - def fromExecutor(e: Executor): ExecutionContext with Executor = new impl.ExecutionContextImpl(e) + def fromExecutor(e: Executor, reporter: Throwable => Unit = defaultReporter): ExecutionContext with Executor = new impl.ExecutionContextImpl(e, reporter) + + def defaultReporter: Throwable => Unit = { + // `Error`s are currently wrapped by `resolver`. + // Also, re-throwing `Error`s here causes an exception handling test to fail. + //case e: Error => throw e + case t => t.printStackTrace() + } } diff --git a/src/library/scala/concurrent/impl/ExecutionContextImpl.scala b/src/library/scala/concurrent/impl/ExecutionContextImpl.scala index ad98331241..7dc3ed2988 100644 --- a/src/library/scala/concurrent/impl/ExecutionContextImpl.scala +++ b/src/library/scala/concurrent/impl/ExecutionContextImpl.scala @@ -17,7 +17,8 @@ import scala.concurrent.util.{ Duration } -private[scala] class ExecutionContextImpl(es: AnyRef) extends ExecutionContext with Executor { +private[scala] class ExecutionContextImpl(es: AnyRef, reporter: Throwable => Unit = ExecutionContext.defaultReporter) +extends ExecutionContext with Executor { import ExecutionContextImpl._ val executorService: AnyRef = if (es eq null) getExecutorService else es @@ -26,7 +27,7 @@ private[scala] class ExecutionContextImpl(es: AnyRef) extends ExecutionContext w def executorsThreadFactory = new ThreadFactory { def newThread(r: Runnable) = new Thread(new Runnable { override def run() { - currentExecutionContext.set(ExecutionContextImpl.this) + scala.concurrent.currentExecutionContext.set(ExecutionContextImpl.this) r.run() } }) @@ -36,7 +37,7 @@ private[scala] class ExecutionContextImpl(es: AnyRef) extends ExecutionContext w def forkJoinPoolThreadFactory = new ForkJoinPool.ForkJoinWorkerThreadFactory { def newThread(fjp: ForkJoinPool) = new ForkJoinWorkerThread(fjp) { override def onStart() { - currentExecutionContext.set(ExecutionContextImpl.this) + scala.concurrent.currentExecutionContext.set(ExecutionContextImpl.this) } } } @@ -92,22 +93,13 @@ private[scala] class ExecutionContextImpl(es: AnyRef) extends ExecutionContext w } } - def reportFailure(t: Throwable) = t match { - // `Error`s are currently wrapped by `resolver`. - // Also, re-throwing `Error`s here causes an exception handling test to fail. - //case e: Error => throw e - case t => t.printStackTrace() - } + def reportFailure(t: Throwable) = reporter(t) } private[concurrent] object ExecutionContextImpl { - - private[concurrent] def currentExecutionContext: ThreadLocal[ExecutionContext] = new ThreadLocal[ExecutionContext] { - override protected def initialValue = null - } - + } diff --git a/src/library/scala/concurrent/impl/Promise.scala b/src/library/scala/concurrent/impl/Promise.scala index ee1841aaff..e8d569fdba 100644 --- a/src/library/scala/concurrent/impl/Promise.scala +++ b/src/library/scala/concurrent/impl/Promise.scala @@ -78,7 +78,7 @@ object Promise { */ class DefaultPromise[T](implicit val executor: ExecutionContext) extends AbstractPromise with Promise[T] { self => updater.set(this, Nil) // Start at "No callbacks" //FIXME switch to Unsafe instead of ARFU - + protected final def tryAwait(atMost: Duration): Boolean = { @tailrec def awaitUnsafe(waitTimeNanos: Long): Boolean = { @@ -88,7 +88,7 @@ object Promise { val start = System.nanoTime() try { synchronized { - while (!isCompleted) wait(ms, ns) + if (!isCompleted) wait(ms, ns) // previously - this was a `while`, ending up in an infinite loop } } catch { case e: InterruptedException => @@ -99,7 +99,7 @@ object Promise { isCompleted } //FIXME do not do this if there'll be no waiting - blocking(Future.body2awaitable(awaitUnsafe(if (atMost.isFinite) atMost.toNanos else Long.MaxValue)), atMost) + awaitUnsafe(if (atMost.isFinite) atMost.toNanos else Long.MaxValue) } @throws(classOf[TimeoutException]) @@ -147,7 +147,9 @@ object Promise { } tryComplete(resolveEither(value)) } finally { - synchronized { notifyAll() } //Notify any evil blockers + synchronized { //Notify any evil blockers + notifyAll() + } } } diff --git a/src/library/scala/concurrent/package.scala b/src/library/scala/concurrent/package.scala index b06c6f3c63..169826034b 100644 --- a/src/library/scala/concurrent/package.scala +++ b/src/library/scala/concurrent/package.scala @@ -26,9 +26,15 @@ package concurrent { object Await { private[concurrent] implicit val canAwaitEvidence = new CanAwait {} - def ready[T](awaitable: Awaitable[T], atMost: Duration): Awaitable[T] = awaitable.ready(atMost) + def ready[T <: Awaitable[_]](awaitable: T, atMost: Duration): T = { + blocking(awaitable, atMost) + awaitable + } + + def result[T](awaitable: Awaitable[T], atMost: Duration): T = { + blocking(awaitable, atMost) + } - def result[T](awaitable: Awaitable[T], atMost: Duration): T = awaitable.result(atMost) } final class DurationOps private[concurrent] (x: Int) { diff --git a/src/library/scala/concurrent/util/Duration.scala b/src/library/scala/concurrent/util/Duration.scala index c4e5fa491a..0e39fdffa3 100644 --- a/src/library/scala/concurrent/util/Duration.scala +++ b/src/library/scala/concurrent/util/Duration.scala @@ -115,13 +115,13 @@ object Duration { * Parse TimeUnit from string representation. */ protected[util] def timeUnit(unit: String): TimeUnit = unit.toLowerCase match { - case "d" | "day" | "days" ⇒ DAYS - case "h" | "hour" | "hours" ⇒ HOURS - case "min" | "minute" | "minutes" ⇒ MINUTES - case "s" | "sec" | "second" | "seconds" ⇒ SECONDS - case "ms" | "milli" | "millis" | "millisecond" | "milliseconds" ⇒ MILLISECONDS - case "µs" | "micro" | "micros" | "microsecond" | "microseconds" ⇒ MICROSECONDS - case "ns" | "nano" | "nanos" | "nanosecond" | "nanoseconds" ⇒ NANOSECONDS + case "d" | "day" | "days" => DAYS + case "h" | "hour" | "hours" => HOURS + case "min" | "minute" | "minutes" => MINUTES + case "s" | "sec" | "second" | "seconds" => SECONDS + case "ms" | "milli" | "millis" | "millisecond" | "milliseconds" => MILLISECONDS + case "µs" | "micro" | "micros" | "microsecond" | "microseconds" => MICROSECONDS + case "ns" | "nano" | "nanos" | "nanosecond" | "nanoseconds" => NANOSECONDS } val Zero: FiniteDuration = new FiniteDuration(0, NANOSECONDS) @@ -138,26 +138,26 @@ object Duration { } trait Infinite { - this: Duration ⇒ + this: Duration => def +(other: Duration): Duration = other match { - case _: this.type ⇒ this - case _: Infinite ⇒ throw new IllegalArgumentException("illegal addition of infinities") - case _ ⇒ this + case _: this.type => this + case _: Infinite => throw new IllegalArgumentException("illegal addition of infinities") + case _ => this } def -(other: Duration): Duration = other match { - case _: this.type ⇒ throw new IllegalArgumentException("illegal subtraction of infinities") - case _ ⇒ this + case _: this.type => throw new IllegalArgumentException("illegal subtraction of infinities") + case _ => this } def *(factor: Double): Duration = this def /(factor: Double): Duration = this def /(other: Duration): Double = other match { - case _: Infinite ⇒ throw new IllegalArgumentException("illegal division of infinities") + case _: Infinite => throw new IllegalArgumentException("illegal division of infinities") // maybe questionable but pragmatic: Inf / 0 => Inf - case x ⇒ Double.PositiveInfinity * (if ((this > Zero) ^ (other >= Zero)) -1 else 1) + case x => Double.PositiveInfinity * (if ((this > Zero) ^ (other >= Zero)) -1 else 1) } def finite_? = false @@ -300,20 +300,20 @@ class FiniteDuration(val length: Long, val unit: TimeUnit) extends Duration { def toUnit(u: TimeUnit) = toNanos.toDouble / NANOSECONDS.convert(1, u) override def toString = this match { - case Duration(1, DAYS) ⇒ "1 day" - case Duration(x, DAYS) ⇒ x + " days" - case Duration(1, HOURS) ⇒ "1 hour" - case Duration(x, HOURS) ⇒ x + " hours" - case Duration(1, MINUTES) ⇒ "1 minute" - case Duration(x, MINUTES) ⇒ x + " minutes" - case Duration(1, SECONDS) ⇒ "1 second" - case Duration(x, SECONDS) ⇒ x + " seconds" - case Duration(1, MILLISECONDS) ⇒ "1 millisecond" - case Duration(x, MILLISECONDS) ⇒ x + " milliseconds" - case Duration(1, MICROSECONDS) ⇒ "1 microsecond" - case Duration(x, MICROSECONDS) ⇒ x + " microseconds" - case Duration(1, NANOSECONDS) ⇒ "1 nanosecond" - case Duration(x, NANOSECONDS) ⇒ x + " nanoseconds" + case Duration(1, DAYS) => "1 day" + case Duration(x, DAYS) => x + " days" + case Duration(1, HOURS) => "1 hour" + case Duration(x, HOURS) => x + " hours" + case Duration(1, MINUTES) => "1 minute" + case Duration(x, MINUTES) => x + " minutes" + case Duration(1, SECONDS) => "1 second" + case Duration(x, SECONDS) => x + " seconds" + case Duration(1, MILLISECONDS) => "1 millisecond" + case Duration(x, MILLISECONDS) => x + " milliseconds" + case Duration(1, MICROSECONDS) => "1 microsecond" + case Duration(x, MICROSECONDS) => x + " microseconds" + case Duration(1, NANOSECONDS) => "1 nanosecond" + case Duration(x, NANOSECONDS) => x + " nanoseconds" } def compare(other: Duration) = |