diff options
author | Herman van Hovell <hvanhovell@questtec.nl> | 2016-04-19 15:16:02 -0700 |
---|---|---|
committer | Davies Liu <davies.liu@gmail.com> | 2016-04-19 15:16:02 -0700 |
commit | da8859226e09aa6ebcf6a1c5c1369dec3c216eac (patch) | |
tree | a72601d6d067bf81e5531e4de7d93f226186aef5 /sql/core/src | |
parent | 3c91afec20607e0d853433a904105ee22df73c73 (diff) | |
download | spark-da8859226e09aa6ebcf6a1c5c1369dec3c216eac.tar.gz spark-da8859226e09aa6ebcf6a1c5c1369dec3c216eac.tar.bz2 spark-da8859226e09aa6ebcf6a1c5c1369dec3c216eac.zip |
[SPARK-4226] [SQL] Support IN/EXISTS Subqueries
### What changes were proposed in this pull request?
This PR adds support for in/exists predicate subqueries to Spark. Predicate sub-queries are used as a filtering condition in a query (this is the only supported use case). A predicate sub-query comes in two forms:
- `[NOT] EXISTS(subquery)`
- `[NOT] IN (subquery)`
This PR is (loosely) based on the work of davies (https://github.com/apache/spark/pull/10706) and chenghao-intel (https://github.com/apache/spark/pull/9055). They should be credited for the work they did.
### How was this patch tested?
Modified parsing unit tests.
Added tests to `org.apache.spark.sql.SQLQuerySuite`
cc rxin, davies & chenghao-intel
Author: Herman van Hovell <hvanhovell@questtec.nl>
Closes #12306 from hvanhovell/SPARK-4226.
Diffstat (limited to 'sql/core/src')
3 files changed, 144 insertions, 13 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala index b3e8b37a2e..71b6a97852 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala @@ -18,8 +18,9 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.catalyst.{expressions, InternalRow} -import org.apache.spark.sql.catalyst.expressions.{ExprId, Literal, SubqueryExpression} +import org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Expression, ExprId, Literal, SubqueryExpression} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule @@ -42,6 +43,7 @@ case class ScalarSubquery( override def plan: SparkPlan = Subquery(simpleString, executedPlan) override def dataType: DataType = executedPlan.schema.fields.head.dataType + override def children: Seq[Expression] = Nil override def nullable: Boolean = true override def toString: String = s"subquery#${exprId.id}" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 2dca792c83..cbacb5e103 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import java.util.{Locale, TimeZone} +import java.util.{ArrayDeque, Locale, TimeZone} import scala.collection.JavaConverters._ import scala.util.control.NonFatal @@ -35,6 +35,8 @@ import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.streaming.MemoryPlan import org.apache.spark.sql.types.ObjectType + + abstract class QueryTest extends PlanTest { protected def sqlContext: SQLContext @@ -47,6 +49,7 @@ abstract class QueryTest extends PlanTest { /** * Runs the plan and makes sure the answer contains all of the keywords, or the * none of keywords are listed in the answer + * * @param df the [[DataFrame]] to be executed * @param exists true for make sure the keywords are listed in the output, otherwise * to make sure none of the keyword are not listed in the output @@ -119,6 +122,7 @@ abstract class QueryTest extends PlanTest { /** * Runs the plan and makes sure the answer matches the expected result. + * * @param df the [[DataFrame]] to be executed * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. */ @@ -158,6 +162,7 @@ abstract class QueryTest extends PlanTest { /** * Runs the plan and makes sure the answer is within absTol of the expected result. + * * @param dataFrame the [[DataFrame]] to be executed * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. * @param absTol the absolute tolerance between actual and expected answers. @@ -198,7 +203,10 @@ abstract class QueryTest extends PlanTest { } private def checkJsonFormat(df: DataFrame): Unit = { + // Get the analyzed plan and rewrite the PredicateSubqueries in order to make sure that + // RDD and Data resolution does not break. val logicalPlan = df.queryExecution.analyzed + // bypass some cases that we can't handle currently. logicalPlan.transform { case _: ObjectConsumer => return @@ -236,9 +244,27 @@ abstract class QueryTest extends PlanTest { // RDDs/data are not serializable to JSON, so we need to collect LogicalPlans that contains // these non-serializable stuff, and use these original ones to replace the null-placeholders // in the logical plans parsed from JSON. - var logicalRDDs = logicalPlan.collect { case l: LogicalRDD => l } - var localRelations = logicalPlan.collect { case l: LocalRelation => l } - var inMemoryRelations = logicalPlan.collect { case i: InMemoryRelation => i } + val logicalRDDs = new ArrayDeque[LogicalRDD]() + val localRelations = new ArrayDeque[LocalRelation]() + val inMemoryRelations = new ArrayDeque[InMemoryRelation]() + def collectData: (LogicalPlan => Unit) = { + case l: LogicalRDD => + logicalRDDs.offer(l) + case l: LocalRelation => + localRelations.offer(l) + case i: InMemoryRelation => + inMemoryRelations.offer(i) + case p => + p.expressions.foreach { + _.foreach { + case s: SubqueryExpression => + s.query.foreach(collectData) + case _ => + } + } + } + logicalPlan.foreach(collectData) + val jsonBackPlan = try { TreeNode.fromJSON[LogicalPlan](jsonString, sqlContext.sparkContext) @@ -253,18 +279,15 @@ abstract class QueryTest extends PlanTest { """.stripMargin, e) } - val normalized2 = jsonBackPlan transformDown { + def renormalize: PartialFunction[LogicalPlan, LogicalPlan] = { case l: LogicalRDD => - val origin = logicalRDDs.head - logicalRDDs = logicalRDDs.drop(1) + val origin = logicalRDDs.pop() LogicalRDD(l.output, origin.rdd)(sqlContext) case l: LocalRelation => - val origin = localRelations.head - localRelations = localRelations.drop(1) + val origin = localRelations.pop() l.copy(data = origin.data) case l: InMemoryRelation => - val origin = inMemoryRelations.head - inMemoryRelations = inMemoryRelations.drop(1) + val origin = inMemoryRelations.pop() InMemoryRelation( l.output, l.useCompression, @@ -275,7 +298,13 @@ abstract class QueryTest extends PlanTest { origin.cachedColumnBuffers, l._statistics, origin._batchStats) + case p => + p.transformExpressions { + case s: SubqueryExpression => + s.withNewPlan(s.query.transformDown(renormalize)) + } } + val normalized2 = jsonBackPlan.transformDown(renormalize) assert(logicalRDDs.isEmpty) assert(localRelations.isEmpty) @@ -309,6 +338,7 @@ object QueryTest { * If there was exception during the execution or the contents of the DataFrame does not * match the expected result, an error message will be returned. Otherwise, a [[None]] will * be returned. + * * @param df the [[DataFrame]] to be executed * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. */ @@ -383,6 +413,7 @@ object QueryTest { /** * Runs the plan and makes sure the answer is within absTol of the expected result. + * * @param actualAnswer the actual result in a [[Row]]. * @param expectedAnswer the expected result in a[[Row]]. * @param absTol the absolute tolerance between actual and expected answers. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 21b19fe7df..5742983fb9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -22,6 +22,38 @@ import org.apache.spark.sql.test.SharedSQLContext class SubquerySuite extends QueryTest with SharedSQLContext { import testImplicits._ + setupTestData() + + val row = identity[(java.lang.Integer, java.lang.Double)](_) + + lazy val l = Seq( + row(1, 2.0), + row(1, 2.0), + row(2, 1.0), + row(2, 1.0), + row(3, 3.0), + row(null, null), + row(null, 5.0), + row(6, null)).toDF("a", "b") + + lazy val r = Seq( + row(2, 3.0), + row(2, 3.0), + row(3, 2.0), + row(4, 1.0), + row(null, null), + row(null, 5.0), + row(6, null)).toDF("c", "d") + + lazy val t = r.filter($"c".isNotNull && $"d".isNotNull) + + protected override def beforeAll(): Unit = { + super.beforeAll() + l.registerTempTable("l") + r.registerTempTable("r") + t.registerTempTable("t") + } + test("simple uncorrelated scalar subquery") { assertResult(Array(Row(1))) { sql("select (select 1 as b) as b").collect() @@ -80,4 +112,70 @@ class SubquerySuite extends QueryTest with SharedSQLContext { " where key = (select max(key) from subqueryData) - 1)").collect() } } + + test("EXISTS predicate subquery") { + checkAnswer( + sql("select * from l where exists(select * from r where l.a = r.c)"), + Row(2, 1.0) :: Row(2, 1.0) :: Row(3, 3.0) :: Row(6, null) :: Nil) + + checkAnswer( + sql("select * from l where exists(select * from r where l.a = r.c) and l.a <= 2"), + Row(2, 1.0) :: Row(2, 1.0) :: Nil) + } + + test("NOT EXISTS predicate subquery") { + checkAnswer( + sql("select * from l where not exists(select * from r where l.a = r.c)"), + Row(1, 2.0) :: Row(1, 2.0) :: Row(null, null) :: Row(null, 5.0) :: Nil) + + checkAnswer( + sql("select * from l where not exists(select * from r where l.a = r.c and l.b < r.d)"), + Row(1, 2.0) :: Row(1, 2.0) :: Row(3, 3.0) :: + Row(null, null) :: Row(null, 5.0) :: Row(6, null) :: Nil) + } + + test("IN predicate subquery") { + checkAnswer( + sql("select * from l where l.a in (select c from r)"), + Row(2, 1.0) :: Row(2, 1.0) :: Row(3, 3.0) :: Row(6, null) :: Nil) + + checkAnswer( + sql("select * from l where l.a in (select c from r where l.b < r.d)"), + Row(2, 1.0) :: Row(2, 1.0) :: Nil) + + checkAnswer( + sql("select * from l where l.a in (select c from r) and l.a > 2 and l.b is not null"), + Row(3, 3.0) :: Nil) + } + + test("NOT IN predicate subquery") { + checkAnswer( + sql("select * from l where a not in(select c from r)"), + Nil) + + checkAnswer( + sql("select * from l where a not in(select c from r where c is not null)"), + Row(1, 2.0) :: Row(1, 2.0) :: Nil) + + checkAnswer( + sql("select * from l where a not in(select c from t where b < d)"), + Row(1, 2.0) :: Row(1, 2.0) :: Row(3, 3.0) :: Nil) + + // Empty sub-query + checkAnswer( + sql("select * from l where a not in(select c from r where c > 10 and b < d)"), + Row(1, 2.0) :: Row(1, 2.0) :: Row(2, 1.0) :: Row(2, 1.0) :: + Row(3, 3.0) :: Row(null, null) :: Row(null, 5.0) :: Row(6, null) :: Nil) + + } + + test("complex IN predicate subquery") { + checkAnswer( + sql("select * from l where (a, b) not in(select c, d from r)"), + Nil) + + checkAnswer( + sql("select * from l where (a, b) not in(select c, d from t) and (a + b) is not null"), + Row(1, 2.0) :: Row(1, 2.0) :: Row(2, 1.0) :: Row(2, 1.0) :: Row(3, 3.0) :: Nil) + } } |