about summary refs log tree commit diff
diff options
context:
space:
mode:
author Reynold Xin <rxin@apache.org> 2013-10-11 13:04:45 -0700
committer Reynold Xin <rxin@apache.org> 2013-10-11 13:04:45 -0700
commit e2047d3927e0032cc1d6de3fd09d00f96ce837d0 (patch)
tree e55fdb2ad60e5a57ad586b4781e8323a271946e6
parent 09f7609254a8b70a551e7403bc5378434318b3f4 (diff)
download spark-e2047d3927e0032cc1d6de3fd09d00f96ce837d0.tar.gz
spark-e2047d3927e0032cc1d6de3fd09d00f96ce837d0.tar.bz2
spark-e2047d3927e0032cc1d6de3fd09d00f96ce837d0.zip
Making takeAsync and collectAsync deterministic.
-rw-r--r-- core/src/main/scala/org/apache/spark/FutureAction.scala 4
-rw-r--r-- core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala 20
-rw-r--r-- core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala 10
3 files changed, 15 insertions, 19 deletions
diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala
index 9f41912d6c..eab2957632 100644
--- a/core/src/main/scala/org/apache/spark/FutureAction.scala
+++ b/core/src/main/scala/org/apache/spark/FutureAction.scala
@@ -177,10 +177,6 @@ class CancellablePromise[T] extends FutureAction[T] with Promise[T] {
def run(func: => T)(implicit executor: ExecutionContext): Unit = scala.concurrent.future {
thread = Thread.currentThread
try {
- if (cancelled) {
- // This action has been cancelled before this thread even started running.
- this.failure(new SparkException("action cancelled"))
- }
this.success(func)
} catch {
case e: Exception => this.failure(e)
diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
index 579832427e..32af795d4c 100644
--- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
@@ -54,9 +54,9 @@ class AsyncRDDActions[T: ClassManifest](self: RDD[T]) extends Serializable with
* Return a future for retrieving all elements of this RDD.
*/
def collectAsync(): FutureAction[Seq[T]] = {
- val results = new ArrayBuffer[T]
+ val results = new Array[Array[T]](self.partitions.size)
self.context.submitJob[T, Array[T], Seq[T]](self, _.toArray, Range(0, self.partitions.size),
- (index, data) => results ++= data, results)
+ (index, data) => results(index) = data, results.flatten.toSeq)
}
/**
@@ -66,10 +66,10 @@ class AsyncRDDActions[T: ClassManifest](self: RDD[T]) extends Serializable with
val promise = new CancellablePromise[Seq[T]]
promise.run {
- val buf = new ArrayBuffer[T](num)
+ val results = new ArrayBuffer[T](num)
val totalParts = self.partitions.length
var partsScanned = 0
- while (buf.size < num && partsScanned < totalParts) {
+ while (results.size < num && partsScanned < totalParts) {
// The number of partitions to try in this iteration. It is ok for this number to be
// greater than totalParts because we actually cap it at totalParts in runJob.
var numPartsToTry = 1
@@ -77,26 +77,28 @@ class AsyncRDDActions[T: ClassManifest](self: RDD[T]) extends Serializable with
// If we didn't find any rows after the first iteration, just try all partitions next.
// Otherwise, interpolate the number of partitions we need to try, but overestimate it
// by 50%.
- if (buf.size == 0) {
+ if (results.size == 0) {
numPartsToTry = totalParts - 1
} else {
- numPartsToTry = (1.5 * num * partsScanned / buf.size).toInt
+ numPartsToTry = (1.5 * num * partsScanned / results.size).toInt
}
}
numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions
- val left = num - buf.size
+ val left = num - results.size
val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
+ val buf = new Array[Array[T]](p.size)
promise.runJob(self,
(it: Iterator[T]) => it.take(left).toArray,
p,
- (index: Int, data: Array[T]) => buf ++= data.take(num - buf.size),
+ (index: Int, data: Array[T]) => buf(index) = data,
Unit)
+ buf.foreach(results ++= _.take(num - results.size))
partsScanned += numPartsToTry
}
- buf.toSeq
+ results.toSeq
}
promise.future
diff --git a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
index 131e2466ac..3ef000da4a 100644
--- a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
@@ -53,8 +53,7 @@ class AsyncRDDActionsSuite extends FunSuite with BeforeAndAfterAll {
test("collectAsync") {
assert(zeroPartRdd.collectAsync().get() === Seq.empty)
- // Note that we sort the collected output because the order is indeterministic.
- val collected = sc.parallelize(1 to 1000, 3).collectAsync().get().sorted
+ val collected = sc.parallelize(1 to 1000, 3).collectAsync().get()
assert(collected === (1 to 1000))
}
@@ -80,10 +79,9 @@ class AsyncRDDActionsSuite extends FunSuite with BeforeAndAfterAll {
test("takeAsync") {
def testTake(rdd: RDD[Int], input: Seq[Int], num: Int) {
- // Note that we sort the collected output because the order is indeterministic.
- val expected = input.take(num).size
- val saw = rdd.takeAsync(num).get().size
- assert(saw == expected, "incorrect result for rdd with %d partitions (expected %d, saw %d)"
+ val expected = input.take(num)
+ val saw = rdd.takeAsync(num).get()
+ assert(saw == expected, "incorrect result for rdd with %d partitions (expected %s, saw %s)"
.format(rdd.partitions.size, expected, saw))
}
val input = Range(1, 1000)