author    Reynold Xin <rxin@databricks.com>   2016-01-09 11:21:58 -0800
committer Reynold Xin <rxin@databricks.com>   2016-01-09 11:21:58 -0800
commit    b23c4521f5df905e4fe4d79dd5b670286e2697f7 (patch)
tree      5191e9fa906fa06aaa2064edaf91aeacffa72128 /sql
parent    3d77cffec093bed4d330969f1a996f3358b9a772 (diff)
[SPARK-12340] Fix overflow in various take functions.
This is a follow-up for the original patch #10562.

Author: Reynold Xin <rxin@databricks.com>

Closes #10670 from rxin/SPARK-12340.
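For context (this note and the sketch below are commentary, not part of the commit): the overflow sits in the loop that fetches partitions incrementally. When the requested limit n is close to Int.MaxValue, the estimated number of partitions to try can itself be near Int.MaxValue, and adding it to partsScanned with Int arithmetic wraps around to a negative value; math.min then yields that negative bound, the partition range comes out empty, and the loop makes no progress. Computing the sum in Long and capping it at totalParts before converting back to Int avoids the wraparound. A minimal, self-contained Scala sketch of both behaviours (names mirror executeTake, values are illustrative):

object TakeOverflowSketch {
  def main(args: Array[String]): Unit = {
    val totalParts = 10    // illustrative partition count
    val partsScanned = 4   // partitions already scanned

    // Int arithmetic: the sum wraps around to a negative number, so the
    // range of partitions to scan next is empty and take() cannot advance.
    val numPartsToTryInt = Int.MaxValue
    val badUpper = math.min(partsScanned + numPartsToTryInt, totalParts)
    println(s"Int bound: $badUpper, partitions: ${(partsScanned until badUpper).size}")    // negative, 0

    // Long arithmetic (the shape of the fixed code): the sum cannot overflow,
    // and the bound is capped at totalParts before converting back to Int.
    val numPartsToTryLong = Int.MaxValue.toLong
    val goodUpper = math.min(partsScanned + numPartsToTryLong, totalParts).toInt
    println(s"Long bound: $goodUpper, partitions: ${(partsScanned until goodUpper).size}") // 10, 6
  }
}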
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala |  7
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala      |  6
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala       | 12
3 files changed, 9 insertions(+), 16 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 21a6fba907..2355de3d05 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -165,7 +165,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
val buf = new ArrayBuffer[InternalRow]
val totalParts = childRDD.partitions.length
- var partsScanned = 0L
+ var partsScanned = 0
while (buf.size < n && partsScanned < totalParts) {
// The number of partitions to try in this iteration. It is ok for this number to be
// greater than totalParts because we actually cap it at totalParts in runJob.
@@ -183,10 +183,9 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions
val left = n - buf.size
- val p = partsScanned.toInt until math.min(partsScanned + numPartsToTry, totalParts).toInt
+ val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt)
val sc = sqlContext.sparkContext
- val res =
- sc.runJob(childRDD, (it: Iterator[InternalRow]) => it.take(left).toArray, p)
+ val res = sc.runJob(childRDD, (it: Iterator[InternalRow]) => it.take(left).toArray, p)
res.foreach(buf ++= _.take(n - buf.size))
partsScanned += p.size
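A note on the design choice above (commentary, not part of the patch): partsScanned goes back to being an Int, while numPartsToTry, declared outside this hunk and presumably still a Long, drives the arithmetic. The sum partsScanned + numPartsToTry is therefore computed in Long, capped at totalParts, and only then converted back to Int for the partition range, so the indices passed to runJob stay plain Ints while the intermediate sum can no longer wrap around.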
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index ade1391ecd..983dfbdede 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -308,6 +308,12 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
checkAnswer(
mapData.toDF().limit(1),
mapData.take(1).map(r => Row.fromSeq(r.productIterator.toSeq)))
+
+ // SPARK-12340: overstep the bounds of Int in SparkPlan.executeTake
+ checkAnswer(
+ sqlContext.range(2).limit(2147483638),
+ Row(0) :: Row(1) :: Nil
+ )
}
test("except") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index bd987ae1bb..5de0979606 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2067,16 +2067,4 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
)
}
}
-
- test("SPARK-12340: overstep the bounds of Int in SparkPlan.executeTake") {
- val rdd = sqlContext.sparkContext.parallelize(1 to 3 , 3 )
- rdd.toDF("key").registerTempTable("spark12340")
- checkAnswer(
- sql("select key from spark12340 limit 2147483638"),
- Row(1) :: Row(2) :: Row(3) :: Nil
- )
- assert(rdd.take(2147483638).size === 3)
- assert(rdd.takeAsync(2147483638).get.size === 3)
- }
-
}