author    Eric Liang <ekl@databricks.com>    2016-06-09 22:28:31 -0700
committer Reynold Xin <rxin@databricks.com>  2016-06-09 22:28:31 -0700
commit    6c5fd977fbcb821a57cb4a13bc3d413a695fbc32 (patch)
tree      cb9d66080522eddca18968591e3ddcb9567f1753
parent    16df133d7f5f3115cd5baa696fa73a4694f9cba9 (diff)
[SPARK-15791] Fix NPE in ScalarSubquery
## What changes were proposed in this pull request?

The fix is pretty simple: just don't make the executedPlan transient in `ScalarSubquery`, since it is referenced at execution time.

## How was this patch tested?

I verified the fix manually in non-local mode. It's not clear to me why the problem did not manifest in local mode; any suggestions?

cc davies

Author: Eric Liang <ekl@databricks.com>

Closes #13569 from ericl/fix-scalar-npe.
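For context, a minimal standalone sketch of the failure mode (not Spark code; `Holder`, `plan`, and `TransientDemo` are hypothetical stand-ins): a `@transient` field is skipped by Java serialization, so it comes back null after the object crosses a serialization boundary, as `ScalarSubquery` does when a task is shipped to an executor.

```scala
import java.io._

// A @transient field is skipped during Java serialization, so after a
// round trip it is null on the receiving side.
case class Holder(@transient plan: String, id: Int)

object TransientDemo {
  // Serialize to bytes and read the object back, roughly what happens when
  // Spark ships a task (and the expressions it references) to an executor.
  def roundTrip[T](obj: T): T = {
    val buf = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(buf)
    out.writeObject(obj)
    out.close()
    val in = new ObjectInputStream(new ByteArrayInputStream(buf.toByteArray))
    in.readObject().asInstanceOf[T]
  }

  def main(args: Array[String]): Unit = {
    val copy = roundTrip(Holder("executedPlan stand-in", 1))
    println(copy.plan)          // prints "null": the field was not serialized
    println(copy.plan.length)   // throws NullPointerException, as in SPARK-15791
  }
}
```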
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala   2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala           10
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala        3
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala        4
4 files changed, 15 insertions(+), 4 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
index 4a1f12d685..461d3010ad 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.types.DataType
  * This is the physical copy of ScalarSubquery to be used inside SparkPlan.
  */
 case class ScalarSubquery(
-    @transient executedPlan: SparkPlan,
+    executedPlan: SparkPlan,
     exprId: ExprId)
   extends SubqueryExpression {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 9c044f4e8f..acb59d46e1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -341,10 +341,16 @@ object QueryTest {
    *
    * @param df the [[DataFrame]] to be executed
    * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
+   * @param checkToRDD whether to verify deserialization to an RDD. This runs the query twice.
    */
-  def checkAnswer(df: DataFrame, expectedAnswer: Seq[Row]): Option[String] = {
+  def checkAnswer(
+      df: DataFrame,
+      expectedAnswer: Seq[Row],
+      checkToRDD: Boolean = true): Option[String] = {
     val isSorted = df.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty
-
+    if (checkToRDD) {
+      df.rdd.count() // Also attempt to deserialize as an RDD [SPARK-15791]
+    }
     val sparkAnswer = try df.collect().toSeq catch {
       case e: Exception =>
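A hypothetical call site for the widened overload (sketch only; `df` stands for any DataFrame under test): by default the helper now also exercises the RDD deserialization path, and tests that count executions can opt out, as the SQLQuerySuite change below does.

```scala
QueryTest.checkAnswer(df, Seq(Row(1)))                      // default: also runs df.rdd.count()
QueryTest.checkAnswer(df, Seq(Row(1)), checkToRDD = false)  // single execution only
```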
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 8284e8d6d8..90465b65bd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2118,7 +2118,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
     // is correct.
     def verifyCallCount(df: DataFrame, expectedResult: Row, expectedCount: Int): Unit = {
       countAcc.setValue(0)
-      checkAnswer(df, expectedResult)
+      QueryTest.checkAnswer(
+        df, Seq(expectedResult), checkToRDD = false /* avoid duplicate exec */)
       assert(countAcc.value == expectedCount)
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index a932125f3c..05491a4a88 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -54,6 +54,10 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
     t.createOrReplaceTempView("t")
   }
 
+  test("rdd deserialization does not crash [SPARK-15791]") {
+    sql("select (select 1 as b) as b").rdd.count()
+  }
+
   test("simple uncorrelated scalar subquery") {
     checkAnswer(
       sql("select (select 1 as b) as b"),