aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala7
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala27
2 files changed, 33 insertions, 1 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 11cd84b396..1e10d73a4b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1511,10 +1511,15 @@ object DecimalAggregates extends Rule[LogicalPlan] {
*/
object ConvertToLocalRelation extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case Project(projectList, LocalRelation(output, data)) =>
+ case Project(projectList, LocalRelation(output, data))
+ if !projectList.exists(hasUnevaluableExpr) =>
val projection = new InterpretedProjection(projectList, output)
LocalRelation(projectList.map(_.toAttribute), data.map(projection))
}
+
+ private def hasUnevaluableExpr(expr: Expression): Boolean = {
+ expr.find(e => e.isInstanceOf[Unevaluable] && !e.isInstanceOf[AttributeReference]).isDefined
+ }
}
/**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index 4819692733..a932125f3c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -123,6 +123,33 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
)
}
+ test("SPARK-15677: Queries against local relations with scalar subquery in Select list") {
+ withTempTable("t1", "t2") {
+ Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t1")
+ Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t2")
+
+ checkAnswer(
+ sql("SELECT (select 1 as col) from t1"),
+ Row(1) :: Row(1) :: Nil)
+
+ checkAnswer(
+ sql("SELECT (select max(c1) from t2) from t1"),
+ Row(2) :: Row(2) :: Nil)
+
+ checkAnswer(
+ sql("SELECT 1 + (select 1 as col) from t1"),
+ Row(2) :: Row(2) :: Nil)
+
+ checkAnswer(
+ sql("SELECT c1, (select max(c1) from t2) + c2 from t1"),
+ Row(1, 3) :: Row(2, 4) :: Nil)
+
+ checkAnswer(
+ sql("SELECT c1, (select max(c1) from t2 where t1.c2 = t2.c2) from t1"),
+ Row(1, 1) :: Row(2, 2) :: Nil)
+ }
+ }
+
test("SPARK-14791: scalar subquery inside broadcast join") {
val df = sql("select a, sum(b) as s from l group by a having a > (select avg(a) from l)")
val expected = Row(3, 2.0, 3, 3.0) :: Row(6, null, 6, null) :: Nil