author    QiangCai <david.caiq@gmail.com>              2016-01-06 18:13:07 +0900
committer Kousuke Saruta <sarutak@oss.nttdata.co.jp>   2016-01-06 18:13:07 +0900
commit    5d871ea43efdde59e05896a50d57021388412d30 (patch)
tree      e548f49a535d9c4aca5682dec81494420a4780ab
parent    b2467b381096804b862990d9ecda554f67e07ee1 (diff)
[SPARK-12340][SQL] Fix Int overflow in SparkPlan.executeTake, RDD.take and AsyncRDDActions.takeAsync
I have closed pull request https://github.com/apache/spark/pull/10487 and created this pull request to resolve the problem.

Spark JIRA: https://issues.apache.org/jira/browse/SPARK-12340

Author: QiangCai <david.caiq@gmail.com>

Closes #10562 from QiangCai/bugfix.
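To make the failure mode concrete, the following is a minimal standalone sketch of the arithmetic this patch changes (illustrative only, not Spark code; the object name and the numbers are made up). With Int counters, partsScanned + numPartsToTry wraps negative once the interpolated estimate approaches Int.MaxValue, so the range of partitions to scan next comes out empty and the loop stops making progress; with Long counters the sum stays well defined and is clamped by math.min before converting back to Int.

object Spark12340OverflowSketch {
  def main(args: Array[String]): Unit = {
    val totalParts = 3                      // a tiny RDD with 3 partitions
    val partsScanned: Int = 1               // one partition already scanned
    val numPartsToTry: Int = Int.MaxValue   // interpolated estimate for a huge take(n)

    // Int arithmetic: the sum wraps to -2147483648, so the partition range is empty
    // and the scanned-partition counter gets corrupted on the next "+=".
    val intUpper = math.min(partsScanned + numPartsToTry, totalParts)
    println(intUpper)                               // -2147483648
    println((partsScanned until intUpper).size)     // 0: nothing gets scanned

    // Long arithmetic (what the patch switches to): the sum cannot wrap, so the
    // range is clamped to the real partition count before going back to Int.
    val longUpper = math.min(partsScanned.toLong + numPartsToTry, totalParts.toLong).toInt
    println((partsScanned until longUpper).size)    // 2: partitions 1 and 2 get scanned
  }
}

The patch also advances partsScanned by p.size, the number of partitions actually scanned, rather than by the unclamped numPartsToTry estimate, so the counter stays consistent with the work that was really done.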
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala          | 12
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/RDD.scala                      |  8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala  |  8
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala        | 12
4 files changed, 26 insertions(+), 14 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
index ec48925823..94719a4572 100644
--- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
@@ -68,7 +68,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Logging
val localProperties = self.context.getLocalProperties
// Cached thread pool to handle aggregation of subtasks.
implicit val executionContext = AsyncRDDActions.futureExecutionContext
- val results = new ArrayBuffer[T](num)
+ val results = new ArrayBuffer[T]
val totalParts = self.partitions.length
/*
@@ -77,13 +77,13 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Logging
This implementation is non-blocking, asynchronously handling the
results of each job and triggering the next job using callbacks on futures.
*/
- def continue(partsScanned: Int)(implicit jobSubmitter: JobSubmitter) : Future[Seq[T]] =
+ def continue(partsScanned: Long)(implicit jobSubmitter: JobSubmitter) : Future[Seq[T]] =
if (results.size >= num || partsScanned >= totalParts) {
Future.successful(results.toSeq)
} else {
// The number of partitions to try in this iteration. It is ok for this number to be
// greater than totalParts because we actually cap it at totalParts in runJob.
- var numPartsToTry = 1
+ var numPartsToTry = 1L
if (partsScanned > 0) {
// If we didn't find any rows after the previous iteration, quadruple and retry.
// Otherwise, interpolate the number of partitions we need to try, but overestimate it
@@ -99,7 +99,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Logging
}
val left = num - results.size
- val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
+ val p = partsScanned.toInt until math.min(partsScanned + numPartsToTry, totalParts).toInt
val buf = new Array[Array[T]](p.size)
self.context.setCallSite(callSite)
@@ -111,11 +111,11 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Logging
Unit)
job.flatMap {_ =>
buf.foreach(results ++= _.take(num - results.size))
- continue(partsScanned + numPartsToTry)
+ continue(partsScanned + p.size)
}
}
- new ComplexFutureAction[Seq[T]](continue(0)(_))
+ new ComplexFutureAction[Seq[T]](continue(0L)(_))
}
/**
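The comments in the hunk above explain where the oversized estimate comes from: quadruple when the previous pass found no rows, otherwise interpolate how many partitions should hold the requested rows and overestimate by 50%. Below is a simplified, hypothetical rendering of that heuristic (the function name and the omitted caps are illustrative, not Spark's exact code); it shows why the estimate must be a Long once take() is called with a count close to 2^31.

object TakeEstimateSketch {
  // Quadruple after an empty pass; otherwise interpolate and overestimate by 50%.
  def estimateNumPartsToTry(num: Long, rowsSoFar: Int, partsScanned: Long): Long =
    if (partsScanned == 0L) 1L                 // first pass: scan a single partition
    else if (rowsSoFar == 0) partsScanned * 4  // found nothing yet: quadruple and retry
    else (1.5 * num * partsScanned / rowsSoFar).toLong

  def main(args: Array[String]): Unit = {
    // One partition scanned, one row found, 2147483638 rows requested (the count used
    // by the new regression test below).
    val estimate = estimateNumPartsToTry(num = 2147483638L, rowsSoFar = 1, partsScanned = 1L)
    println(estimate)   // 3221225457: already past Int.MaxValue (2147483647)
  }
}

With the estimate held as a Long, the math.min(...).toInt clamp added by the patch keeps the actual partition range within totalParts.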
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index d6eac7888d..e25657cc10 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -1190,11 +1190,11 @@ abstract class RDD[T: ClassTag](
} else {
val buf = new ArrayBuffer[T]
val totalParts = this.partitions.length
- var partsScanned = 0
+ var partsScanned = 0L
while (buf.size < num && partsScanned < totalParts) {
// The number of partitions to try in this iteration. It is ok for this number to be
// greater than totalParts because we actually cap it at totalParts in runJob.
- var numPartsToTry = 1
+ var numPartsToTry = 1L
if (partsScanned > 0) {
// If we didn't find any rows after the previous iteration, quadruple and retry.
// Otherwise, interpolate the number of partitions we need to try, but overestimate
@@ -1209,11 +1209,11 @@ abstract class RDD[T: ClassTag](
}
val left = num - buf.size
- val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
+ val p = partsScanned.toInt until math.min(partsScanned + numPartsToTry, totalParts).toInt
val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, p)
res.foreach(buf ++= _.take(num - buf.size))
- partsScanned += numPartsToTry
+ partsScanned += p.size
}
buf.toArray
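As a sanity check of the patched control flow in RDD.take, here is a driver-side simulation (hypothetical, not Spark code: a local Seq of Seqs stands in for the RDD's partitions, the map over p stands in for sc.runJob, and the estimation is simplified). It combines the same three ingredients as the patch: Long counters, the range clamped through math.min before .toInt, and partsScanned advanced by p.size. The SparkPlan.executeTake hunk below applies the same changes.

import scala.collection.mutable.ArrayBuffer

object TakeLoopSketch {
  def simulatedTake[T](partitions: Seq[Seq[T]], num: Int): List[T] = {
    val buf = ArrayBuffer.empty[T]
    val totalParts = partitions.length
    var partsScanned = 0L
    while (buf.size < num && partsScanned < totalParts) {
      // Estimate how many partitions to scan next (simplified; see the sketch above).
      var numPartsToTry = 1L
      if (partsScanned > 0) {
        numPartsToTry =
          if (buf.isEmpty) partsScanned * 4
          else (1.5 * num * partsScanned / buf.size).toLong
      }
      val left = num - buf.size
      // Clamp in Long space, then convert back to Int partition indices.
      val p = partsScanned.toInt until math.min(partsScanned + numPartsToTry, totalParts.toLong).toInt
      val res = p.map(i => partitions(i).take(left))   // stand-in for sc.runJob
      res.foreach(buf ++= _.take(num - buf.size))
      partsScanned += p.size                           // advance by partitions actually scanned
    }
    buf.toList
  }

  def main(args: Array[String]): Unit = {
    val parts = Seq(Seq(1), Seq(2), Seq(3))            // three tiny "partitions"
    println(simulatedTake(parts, 2147483638))          // List(1, 2, 3): terminates, no wrap
  }
}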
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index f20f32aace..21a6fba907 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -165,11 +165,11 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable
val buf = new ArrayBuffer[InternalRow]
val totalParts = childRDD.partitions.length
- var partsScanned = 0
+ var partsScanned = 0L
while (buf.size < n && partsScanned < totalParts) {
// The number of partitions to try in this iteration. It is ok for this number to be
// greater than totalParts because we actually cap it at totalParts in runJob.
- var numPartsToTry = 1
+ var numPartsToTry = 1L
if (partsScanned > 0) {
// If we didn't find any rows after the first iteration, just try all partitions next.
// Otherwise, interpolate the number of partitions we need to try, but overestimate it
@@ -183,13 +183,13 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable
numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions
val left = n - buf.size
- val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
+ val p = partsScanned.toInt until math.min(partsScanned + numPartsToTry, totalParts).toInt
val sc = sqlContext.sparkContext
val res =
sc.runJob(childRDD, (it: Iterator[InternalRow]) => it.take(left).toArray, p)
res.foreach(buf ++= _.take(n - buf.size))
- partsScanned += numPartsToTry
+ partsScanned += p.size
}
buf.toArray
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 5de0979606..bd987ae1bb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2067,4 +2067,16 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
)
}
}
+
+ test("SPARK-12340: overstep the bounds of Int in SparkPlan.executeTake") {
+ val rdd = sqlContext.sparkContext.parallelize(1 to 3 , 3 )
+ rdd.toDF("key").registerTempTable("spark12340")
+ checkAnswer(
+ sql("select key from spark12340 limit 2147483638"),
+ Row(1) :: Row(2) :: Row(3) :: Nil
+ )
+ assert(rdd.take(2147483638).size === 3)
+ assert(rdd.takeAsync(2147483638).get.size === 3)
+ }
+
}