about summary refs log tree commit diff
path: root/core
diff options
context:
space:
mode:
authorReynold Xin <rxin@databricks.com>2016-01-09 11:21:58 -0800
committerReynold Xin <rxin@databricks.com>2016-01-09 11:21:58 -0800
commitb23c4521f5df905e4fe4d79dd5b670286e2697f7 (patch)
tree5191e9fa906fa06aaa2064edaf91aeacffa72128 /core
parent3d77cffec093bed4d330969f1a996f3358b9a772 (diff)
downloadspark-b23c4521f5df905e4fe4d79dd5b670286e2697f7.tar.gz
spark-b23c4521f5df905e4fe4d79dd5b670286e2697f7.tar.bz2
spark-b23c4521f5df905e4fe4d79dd5b670286e2697f7.zip
[SPARK-12340] Fix overflow in various take functions.
This is a follow-up for the original patch #10562.

Author: Reynold Xin <rxin@databricks.com>

Closes #10670 from rxin/SPARK-12340.
Diffstat (limited to 'core')
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala8
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/RDD.scala4
-rw-r--r--core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala4
3 files changed, 10 insertions, 6 deletions
diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
index 94719a4572..7de9df1e48 100644
--- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
@@ -77,7 +77,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi
This implementation is non-blocking, asynchronously handling the
results of each job and triggering the next job using callbacks on futures.
*/
- def continue(partsScanned: Long)(implicit jobSubmitter: JobSubmitter) : Future[Seq[T]] =
+ def continue(partsScanned: Int)(implicit jobSubmitter: JobSubmitter): Future[Seq[T]] =
if (results.size >= num || partsScanned >= totalParts) {
Future.successful(results.toSeq)
} else {
@@ -99,7 +99,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi
}
val left = num - results.size
- val p = partsScanned.toInt until math.min(partsScanned + numPartsToTry, totalParts).toInt
+ val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt)
val buf = new Array[Array[T]](p.size)
self.context.setCallSite(callSite)
@@ -109,13 +109,13 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi
p,
(index: Int, data: Array[T]) => buf(index) = data,
Unit)
- job.flatMap {_ =>
+ job.flatMap { _ =>
buf.foreach(results ++= _.take(num - results.size))
continue(partsScanned + p.size)
}
}
- new ComplexFutureAction[Seq[T]](continue(0L)(_))
+ new ComplexFutureAction[Seq[T]](continue(0)(_))
}
/**
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index e25657cc10..de7102f5b6 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -1190,7 +1190,7 @@ abstract class RDD[T: ClassTag](
} else {
val buf = new ArrayBuffer[T]
val totalParts = this.partitions.length
- var partsScanned = 0L
+ var partsScanned = 0
while (buf.size < num && partsScanned < totalParts) {
// The number of partitions to try in this iteration. It is ok for this number to be
// greater than totalParts because we actually cap it at totalParts in runJob.
@@ -1209,7 +1209,7 @@ abstract class RDD[T: ClassTag](
}
val left = num - buf.size
- val p = partsScanned.toInt until math.min(partsScanned + numPartsToTry, totalParts).toInt
+ val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt)
val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, p)
res.foreach(buf ++= _.take(num - buf.size))
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 24acbed4d7..ef2ed44500 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -482,6 +482,10 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext {
assert(nums.take(501) === (1 to 501).toArray)
assert(nums.take(999) === (1 to 999).toArray)
assert(nums.take(1000) === (1 to 999).toArray)
+
+ nums = sc.parallelize(1 to 2, 2)
+ assert(nums.take(2147483638).size === 2)
+ assert(nums.takeAsync(2147483638).get.size === 2)
}
test("top with predefined ordering") {