diff options
author | Reynold Xin <rxin@databricks.com> | 2016-01-09 11:21:58 -0800 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2016-01-09 11:21:58 -0800 |
commit | b23c4521f5df905e4fe4d79dd5b670286e2697f7 (patch) | |
tree | 5191e9fa906fa06aaa2064edaf91aeacffa72128 /core | |
parent | 3d77cffec093bed4d330969f1a996f3358b9a772 (diff) | |
download | spark-b23c4521f5df905e4fe4d79dd5b670286e2697f7.tar.gz spark-b23c4521f5df905e4fe4d79dd5b670286e2697f7.tar.bz2 spark-b23c4521f5df905e4fe4d79dd5b670286e2697f7.zip |
[SPARK-12340] Fix overflow in various take functions.
This is a follow-up for the original patch #10562.
Author: Reynold Xin <rxin@databricks.com>
Closes #10670 from rxin/SPARK-12340.
Diffstat (limited to 'core')
3 files changed, 10 insertions, 6 deletions
diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala index 94719a4572..7de9df1e48 100644 --- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -77,7 +77,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi This implementation is non-blocking, asynchronously handling the results of each job and triggering the next job using callbacks on futures. */ - def continue(partsScanned: Long)(implicit jobSubmitter: JobSubmitter) : Future[Seq[T]] = + def continue(partsScanned: Int)(implicit jobSubmitter: JobSubmitter): Future[Seq[T]] = if (results.size >= num || partsScanned >= totalParts) { Future.successful(results.toSeq) } else { @@ -99,7 +99,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi } val left = num - results.size - val p = partsScanned.toInt until math.min(partsScanned + numPartsToTry, totalParts).toInt + val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt) val buf = new Array[Array[T]](p.size) self.context.setCallSite(callSite) @@ -109,13 +109,13 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi p, (index: Int, data: Array[T]) => buf(index) = data, Unit) - job.flatMap {_ => + job.flatMap { _ => buf.foreach(results ++= _.take(num - results.size)) continue(partsScanned + p.size) } } - new ComplexFutureAction[Seq[T]](continue(0L)(_)) + new ComplexFutureAction[Seq[T]](continue(0)(_)) } /** diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index e25657cc10..de7102f5b6 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -1190,7 +1190,7 @@ abstract class RDD[T: ClassTag]( } else { val buf = new ArrayBuffer[T] val totalParts = this.partitions.length - var partsScanned = 0L + var partsScanned = 0 while (buf.size < num && partsScanned < totalParts) { // The number of partitions to try in this iteration. It is ok for this number to be // greater than totalParts because we actually cap it at totalParts in runJob. @@ -1209,7 +1209,7 @@ abstract class RDD[T: ClassTag]( } val left = num - buf.size - val p = partsScanned.toInt until math.min(partsScanned + numPartsToTry, totalParts).toInt + val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt) val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, p) res.foreach(buf ++= _.take(num - buf.size)) diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 24acbed4d7..ef2ed44500 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -482,6 +482,10 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { assert(nums.take(501) === (1 to 501).toArray) assert(nums.take(999) === (1 to 999).toArray) assert(nums.take(1000) === (1 to 999).toArray) + + nums = sc.parallelize(1 to 2, 2) + assert(nums.take(2147483638).size === 2) + assert(nums.takeAsync(2147483638).get.size === 2) } test("top with predefined ordering") { |