summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAleksandar Prokopec <axel22@gmail.com>2012-04-30 16:32:26 +0200
committerAleksandar Prokopec <axel22@gmail.com>2012-04-30 16:32:26 +0200
commitb2472a5f057ca5e18dc8b3205600c14bf31dc1d1 (patch)
tree957d6cac253c81806a0baff6c0dfbf32f2dbac31
parent9c4baa93d906b161f501ae04f1552e1b7d448436 (diff)
downloadscala-b2472a5f057ca5e18dc8b3205600c14bf31dc1d1.tar.gz
scala-b2472a5f057ca5e18dc8b3205600c14bf31dc1d1.tar.bz2
scala-b2472a5f057ca5e18dc8b3205600c14bf31dc1d1.zip
Fixed a bug with setting execution contexts.
Ported most of the future tests.
-rw-r--r--src/library/scala/concurrent/ConcurrentPackageObject.scala9
-rw-r--r--src/library/scala/concurrent/ExecutionContext.scala11
-rw-r--r--src/library/scala/concurrent/impl/ExecutionContextImpl.scala20
-rw-r--r--src/library/scala/concurrent/impl/Promise.scala10
-rw-r--r--src/library/scala/concurrent/package.scala10
-rw-r--r--src/library/scala/concurrent/util/Duration.scala58
-rw-r--r--test/files/jvm/future-spec/FutureTests.scala140
7 files changed, 201 insertions, 57 deletions
diff --git a/src/library/scala/concurrent/ConcurrentPackageObject.scala b/src/library/scala/concurrent/ConcurrentPackageObject.scala
index fafd7fd238..a3d985733e 100644
--- a/src/library/scala/concurrent/ConcurrentPackageObject.scala
+++ b/src/library/scala/concurrent/ConcurrentPackageObject.scala
@@ -22,7 +22,7 @@ abstract class ConcurrentPackageObject {
*/
lazy val defaultExecutionContext = new impl.ExecutionContextImpl(null)
- private val currentExecutionContext = new ThreadLocal[ExecutionContext]
+ val currentExecutionContext = new ThreadLocal[ExecutionContext]
val handledFutureException: PartialFunction[Throwable, Throwable] = {
case t: Throwable if isFutureThrowable(t) => t
@@ -82,7 +82,7 @@ abstract class ConcurrentPackageObject {
* - TimeoutException - in the case that the blockable object timed out
*/
def blocking[T](body: =>T): T =
- blocking(impl.Future.body2awaitable(body), Duration.fromNanos(0))
+ blocking(impl.Future.body2awaitable(body), Duration.Inf)
/** Blocks on an awaitable object.
*
@@ -93,11 +93,12 @@ abstract class ConcurrentPackageObject {
* - InterruptedException - in the case that a wait within the blockable object was interrupted
* - TimeoutException - in the case that the blockable object timed out
*/
- def blocking[T](awaitable: Awaitable[T], atMost: Duration): T =
+ def blocking[T](awaitable: Awaitable[T], atMost: Duration): T = {
currentExecutionContext.get match {
- case null => Await.result(awaitable, atMost)
+ case null => awaitable.result(atMost)(Await.canAwaitEvidence)
case ec => ec.internalBlockingCall(awaitable, atMost)
}
+ }
@inline implicit final def int2durationops(x: Int): DurationOps = new DurationOps(x)
}
diff --git a/src/library/scala/concurrent/ExecutionContext.scala b/src/library/scala/concurrent/ExecutionContext.scala
index 3f62f58bf8..4666674b5b 100644
--- a/src/library/scala/concurrent/ExecutionContext.scala
+++ b/src/library/scala/concurrent/ExecutionContext.scala
@@ -47,11 +47,18 @@ object ExecutionContext {
/** Creates an `ExecutionContext` from the given `ExecutorService`.
*/
- def fromExecutorService(e: ExecutorService): ExecutionContext with Executor = new impl.ExecutionContextImpl(e)
+ def fromExecutorService(e: ExecutorService, reporter: Throwable => Unit = defaultReporter): ExecutionContext with Executor = new impl.ExecutionContextImpl(e, reporter)
/** Creates an `ExecutionContext` from the given `Executor`.
*/
- def fromExecutor(e: Executor): ExecutionContext with Executor = new impl.ExecutionContextImpl(e)
+ def fromExecutor(e: Executor, reporter: Throwable => Unit = defaultReporter): ExecutionContext with Executor = new impl.ExecutionContextImpl(e, reporter)
+
+ def defaultReporter: Throwable => Unit = {
+ // `Error`s are currently wrapped by `resolver`.
+ // Also, re-throwing `Error`s here causes an exception handling test to fail.
+ //case e: Error => throw e
+ case t => t.printStackTrace()
+ }
}
diff --git a/src/library/scala/concurrent/impl/ExecutionContextImpl.scala b/src/library/scala/concurrent/impl/ExecutionContextImpl.scala
index ad98331241..7dc3ed2988 100644
--- a/src/library/scala/concurrent/impl/ExecutionContextImpl.scala
+++ b/src/library/scala/concurrent/impl/ExecutionContextImpl.scala
@@ -17,7 +17,8 @@ import scala.concurrent.util.{ Duration }
-private[scala] class ExecutionContextImpl(es: AnyRef) extends ExecutionContext with Executor {
+private[scala] class ExecutionContextImpl(es: AnyRef, reporter: Throwable => Unit = ExecutionContext.defaultReporter)
+extends ExecutionContext with Executor {
import ExecutionContextImpl._
val executorService: AnyRef = if (es eq null) getExecutorService else es
@@ -26,7 +27,7 @@ private[scala] class ExecutionContextImpl(es: AnyRef) extends ExecutionContext w
def executorsThreadFactory = new ThreadFactory {
def newThread(r: Runnable) = new Thread(new Runnable {
override def run() {
- currentExecutionContext.set(ExecutionContextImpl.this)
+ scala.concurrent.currentExecutionContext.set(ExecutionContextImpl.this)
r.run()
}
})
@@ -36,7 +37,7 @@ private[scala] class ExecutionContextImpl(es: AnyRef) extends ExecutionContext w
def forkJoinPoolThreadFactory = new ForkJoinPool.ForkJoinWorkerThreadFactory {
def newThread(fjp: ForkJoinPool) = new ForkJoinWorkerThread(fjp) {
override def onStart() {
- currentExecutionContext.set(ExecutionContextImpl.this)
+ scala.concurrent.currentExecutionContext.set(ExecutionContextImpl.this)
}
}
}
@@ -92,22 +93,13 @@ private[scala] class ExecutionContextImpl(es: AnyRef) extends ExecutionContext w
}
}
- def reportFailure(t: Throwable) = t match {
- // `Error`s are currently wrapped by `resolver`.
- // Also, re-throwing `Error`s here causes an exception handling test to fail.
- //case e: Error => throw e
- case t => t.printStackTrace()
- }
+ def reportFailure(t: Throwable) = reporter(t)
}
private[concurrent] object ExecutionContextImpl {
-
- private[concurrent] def currentExecutionContext: ThreadLocal[ExecutionContext] = new ThreadLocal[ExecutionContext] {
- override protected def initialValue = null
- }
-
+
}
diff --git a/src/library/scala/concurrent/impl/Promise.scala b/src/library/scala/concurrent/impl/Promise.scala
index ee1841aaff..e8d569fdba 100644
--- a/src/library/scala/concurrent/impl/Promise.scala
+++ b/src/library/scala/concurrent/impl/Promise.scala
@@ -78,7 +78,7 @@ object Promise {
*/
class DefaultPromise[T](implicit val executor: ExecutionContext) extends AbstractPromise with Promise[T] { self =>
updater.set(this, Nil) // Start at "No callbacks" //FIXME switch to Unsafe instead of ARFU
-
+
protected final def tryAwait(atMost: Duration): Boolean = {
@tailrec
def awaitUnsafe(waitTimeNanos: Long): Boolean = {
@@ -88,7 +88,7 @@ object Promise {
val start = System.nanoTime()
try {
synchronized {
- while (!isCompleted) wait(ms, ns)
+ if (!isCompleted) wait(ms, ns) // previously - this was a `while`, ending up in an infinite loop
}
} catch {
case e: InterruptedException =>
@@ -99,7 +99,7 @@ object Promise {
isCompleted
}
//FIXME do not do this if there'll be no waiting
- blocking(Future.body2awaitable(awaitUnsafe(if (atMost.isFinite) atMost.toNanos else Long.MaxValue)), atMost)
+ awaitUnsafe(if (atMost.isFinite) atMost.toNanos else Long.MaxValue)
}
@throws(classOf[TimeoutException])
@@ -147,7 +147,9 @@ object Promise {
}
tryComplete(resolveEither(value))
} finally {
- synchronized { notifyAll() } //Notify any evil blockers
+ synchronized { //Notify any evil blockers
+ notifyAll()
+ }
}
}
diff --git a/src/library/scala/concurrent/package.scala b/src/library/scala/concurrent/package.scala
index b06c6f3c63..169826034b 100644
--- a/src/library/scala/concurrent/package.scala
+++ b/src/library/scala/concurrent/package.scala
@@ -26,9 +26,15 @@ package concurrent {
object Await {
private[concurrent] implicit val canAwaitEvidence = new CanAwait {}
- def ready[T](awaitable: Awaitable[T], atMost: Duration): Awaitable[T] = awaitable.ready(atMost)
+ def ready[T <: Awaitable[_]](awaitable: T, atMost: Duration): T = {
+ blocking(awaitable, atMost)
+ awaitable
+ }
+
+ def result[T](awaitable: Awaitable[T], atMost: Duration): T = {
+ blocking(awaitable, atMost)
+ }
- def result[T](awaitable: Awaitable[T], atMost: Duration): T = awaitable.result(atMost)
}
final class DurationOps private[concurrent] (x: Int) {
diff --git a/src/library/scala/concurrent/util/Duration.scala b/src/library/scala/concurrent/util/Duration.scala
index c4e5fa491a..0e39fdffa3 100644
--- a/src/library/scala/concurrent/util/Duration.scala
+++ b/src/library/scala/concurrent/util/Duration.scala
@@ -115,13 +115,13 @@ object Duration {
* Parse TimeUnit from string representation.
*/
protected[util] def timeUnit(unit: String): TimeUnit = unit.toLowerCase match {
- case "d" | "day" | "days" ⇒ DAYS
- case "h" | "hour" | "hours" ⇒ HOURS
- case "min" | "minute" | "minutes" ⇒ MINUTES
- case "s" | "sec" | "second" | "seconds" ⇒ SECONDS
- case "ms" | "milli" | "millis" | "millisecond" | "milliseconds" ⇒ MILLISECONDS
- case "µs" | "micro" | "micros" | "microsecond" | "microseconds" ⇒ MICROSECONDS
- case "ns" | "nano" | "nanos" | "nanosecond" | "nanoseconds" ⇒ NANOSECONDS
+ case "d" | "day" | "days" => DAYS
+ case "h" | "hour" | "hours" => HOURS
+ case "min" | "minute" | "minutes" => MINUTES
+ case "s" | "sec" | "second" | "seconds" => SECONDS
+ case "ms" | "milli" | "millis" | "millisecond" | "milliseconds" => MILLISECONDS
+ case "µs" | "micro" | "micros" | "microsecond" | "microseconds" => MICROSECONDS
+ case "ns" | "nano" | "nanos" | "nanosecond" | "nanoseconds" => NANOSECONDS
}
val Zero: FiniteDuration = new FiniteDuration(0, NANOSECONDS)
@@ -138,26 +138,26 @@ object Duration {
}
trait Infinite {
- this: Duration ⇒
+ this: Duration =>
def +(other: Duration): Duration =
other match {
- case _: this.type ⇒ this
- case _: Infinite ⇒ throw new IllegalArgumentException("illegal addition of infinities")
- case _ ⇒ this
+ case _: this.type => this
+ case _: Infinite => throw new IllegalArgumentException("illegal addition of infinities")
+ case _ => this
}
def -(other: Duration): Duration =
other match {
- case _: this.type ⇒ throw new IllegalArgumentException("illegal subtraction of infinities")
- case _ ⇒ this
+ case _: this.type => throw new IllegalArgumentException("illegal subtraction of infinities")
+ case _ => this
}
def *(factor: Double): Duration = this
def /(factor: Double): Duration = this
def /(other: Duration): Double =
other match {
- case _: Infinite ⇒ throw new IllegalArgumentException("illegal division of infinities")
+ case _: Infinite => throw new IllegalArgumentException("illegal division of infinities")
// maybe questionable but pragmatic: Inf / 0 => Inf
- case x ⇒ Double.PositiveInfinity * (if ((this > Zero) ^ (other >= Zero)) -1 else 1)
+ case x => Double.PositiveInfinity * (if ((this > Zero) ^ (other >= Zero)) -1 else 1)
}
def finite_? = false
@@ -300,20 +300,20 @@ class FiniteDuration(val length: Long, val unit: TimeUnit) extends Duration {
def toUnit(u: TimeUnit) = toNanos.toDouble / NANOSECONDS.convert(1, u)
override def toString = this match {
- case Duration(1, DAYS) ⇒ "1 day"
- case Duration(x, DAYS) ⇒ x + " days"
- case Duration(1, HOURS) ⇒ "1 hour"
- case Duration(x, HOURS) ⇒ x + " hours"
- case Duration(1, MINUTES) ⇒ "1 minute"
- case Duration(x, MINUTES) ⇒ x + " minutes"
- case Duration(1, SECONDS) ⇒ "1 second"
- case Duration(x, SECONDS) ⇒ x + " seconds"
- case Duration(1, MILLISECONDS) ⇒ "1 millisecond"
- case Duration(x, MILLISECONDS) ⇒ x + " milliseconds"
- case Duration(1, MICROSECONDS) ⇒ "1 microsecond"
- case Duration(x, MICROSECONDS) ⇒ x + " microseconds"
- case Duration(1, NANOSECONDS) ⇒ "1 nanosecond"
- case Duration(x, NANOSECONDS) ⇒ x + " nanoseconds"
+ case Duration(1, DAYS) => "1 day"
+ case Duration(x, DAYS) => x + " days"
+ case Duration(1, HOURS) => "1 hour"
+ case Duration(x, HOURS) => x + " hours"
+ case Duration(1, MINUTES) => "1 minute"
+ case Duration(x, MINUTES) => x + " minutes"
+ case Duration(1, SECONDS) => "1 second"
+ case Duration(x, SECONDS) => x + " seconds"
+ case Duration(1, MILLISECONDS) => "1 millisecond"
+ case Duration(x, MILLISECONDS) => x + " milliseconds"
+ case Duration(1, MICROSECONDS) => "1 microsecond"
+ case Duration(x, MICROSECONDS) => x + " microseconds"
+ case Duration(1, NANOSECONDS) => "1 nanosecond"
+ case Duration(x, NANOSECONDS) => x + " nanoseconds"
}
def compare(other: Duration) =
diff --git a/test/files/jvm/future-spec/FutureTests.scala b/test/files/jvm/future-spec/FutureTests.scala
index aa308945a1..9155985d7f 100644
--- a/test/files/jvm/future-spec/FutureTests.scala
+++ b/test/files/jvm/future-spec/FutureTests.scala
@@ -4,6 +4,7 @@
import scala.concurrent._
import scala.concurrent.util.duration._
import scala.concurrent.util.Duration.Inf
+import scala.collection._
@@ -317,8 +318,13 @@ object FutureTests extends MinimalScalaTest {
Await.result(traversedList, defaultTimeout).sum mustBe (10000)
}
- /* need configurable execution contexts here
"shouldHandleThrowables" in {
+ val ms = new mutable.HashSet[Throwable] with mutable.SynchronizedSet[Throwable]
+ implicit val ec = scala.concurrent.ExecutionContext.fromExecutor(new scala.concurrent.forkjoin.ForkJoinPool(), {
+ t =>
+ ms += t
+ })
+
class ThrowableTest(m: String) extends Throwable(m)
val f1 = future[Any] {
@@ -347,8 +353,138 @@ object FutureTests extends MinimalScalaTest {
f2 onSuccess { case _ => throw new ThrowableTest("current thread receive") }
Await.result(f3, defaultTimeout) mustBe ("SUCCESS")
+
+ val waiting = future {
+ Thread.sleep(1000)
+ }
+ Await.ready(waiting, 2000 millis)
+
+ ms.size mustBe (4)
+ }
+
+ "shouldBlockUntilResult" in {
+ val latch = new TestLatch
+
+ val f = future {
+ Await.ready(latch, 5 seconds)
+ 5
+ }
+ val f2 = future {
+ val res = Await.result(f, Inf)
+ res + 9
+ }
+
+ intercept[TimeoutException] {
+ Await.ready(f2, 100 millis)
+ }
+
+ latch.open()
+
+ Await.result(f2, defaultTimeout) mustBe (14)
+
+ val f3 = future {
+ Thread.sleep(100)
+ 5
+ }
+
+ intercept[TimeoutException] {
+ Await.ready(f3, 0 millis)
+ }
}
- */
+
+ "run callbacks async" in {
+ val latch = Vector.fill(10)(new TestLatch)
+
+ val f1 = future {
+ latch(0).open()
+ Await.ready(latch(1), TestLatch.DefaultTimeout)
+ "Hello"
+ }
+ val f2 = f1 map {
+ s =>
+ latch(2).open()
+ Await.ready(latch(3), TestLatch.DefaultTimeout)
+ s.length
+ }
+ for (_ <- f2) latch(4).open()
+
+ Await.ready(latch(0), TestLatch.DefaultTimeout)
+
+ f1.isCompleted mustBe (false)
+ f2.isCompleted mustBe (false)
+
+ latch(1).open()
+ Await.ready(latch(2), TestLatch.DefaultTimeout)
+
+ f1.isCompleted mustBe (true)
+ f2.isCompleted mustBe (false)
+
+ val f3 = f1 map {
+ s =>
+ latch(5).open()
+ Await.ready(latch(6), TestLatch.DefaultTimeout)
+ s.length * 2
+ }
+ for (_ <- f3) latch(3).open()
+
+ Await.ready(latch(5), TestLatch.DefaultTimeout)
+
+ f3.isCompleted mustBe (false)
+
+ latch(6).open()
+ Await.ready(latch(4), TestLatch.DefaultTimeout)
+
+ f2.isCompleted mustBe (true)
+ f3.isCompleted mustBe (true)
+
+ val p1 = Promise[String]()
+ val f4 = p1.future map {
+ s =>
+ latch(7).open()
+ Await.ready(latch(8), TestLatch.DefaultTimeout)
+ s.length
+ }
+ for (_ <- f4) latch(9).open()
+
+ p1.future.isCompleted mustBe (false)
+ f4.isCompleted mustBe (false)
+
+ p1 complete Right("Hello")
+
+ Await.ready(latch(7), TestLatch.DefaultTimeout)
+
+ p1.future.isCompleted mustBe (true)
+ f4.isCompleted mustBe (false)
+
+ latch(8).open()
+ Await.ready(latch(9), TestLatch.DefaultTimeout)
+
+ Await.ready(f4, defaultTimeout).isCompleted mustBe (true)
+ }
+
+ "should not deadlock with nested await (ticket 1313)" in {
+ val simple = Future() map {
+ _ =>
+ val unit = Future(())
+ val umap = unit map { _ => () }
+ Await.result(umap, Inf)
+ }
+ Await.ready(simple, Inf).isCompleted mustBe (true)
+
+ val l1, l2 = new TestLatch
+ val complex = Future() map {
+ _ =>
+ blocking {
+ val nested = Future(())
+ for (_ <- nested) l1.open()
+ Await.ready(l1, TestLatch.DefaultTimeout) // make sure nested is completed
+ for (_ <- nested) l2.open()
+ Await.ready(l2, TestLatch.DefaultTimeout)
+ }
+ }
+ Await.ready(complex, defaultTimeout).isCompleted mustBe (true)
+ }
+
}
}