author    Herman van Hovell <hvanhovell@questtec.nl>  2016-04-19 15:16:02 -0700
committer Davies Liu <davies.liu@gmail.com>           2016-04-19 15:16:02 -0700
commit    da8859226e09aa6ebcf6a1c5c1369dec3c216eac (patch)
tree      a72601d6d067bf81e5531e4de7d93f226186aef5 /sql/core/src
parent    3c91afec20607e0d853433a904105ee22df73c73 (diff)
[SPARK-4226] [SQL] Support IN/EXISTS Subqueries
### What changes were proposed in this pull request?

This PR adds support for IN/EXISTS predicate subqueries to Spark. Predicate subqueries are used as a filtering condition in a query (this is the only supported use case). A predicate subquery comes in two forms:

- `[NOT] EXISTS(subquery)`
- `[NOT] IN (subquery)`

This PR is (loosely) based on the work of davies (https://github.com/apache/spark/pull/10706) and chenghao-intel (https://github.com/apache/spark/pull/9055). They should be credited for the work they did.

### How was this patch tested?

Modified parsing unit tests. Added tests to `org.apache.spark.sql.SQLQuerySuite`.

cc rxin, davies & chenghao-intel

Author: Herman van Hovell <hvanhovell@questtec.nl>

Closes #12306 from hvanhovell/SPARK-4226.
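To make the two new predicate forms concrete, here is a minimal usage sketch, mirroring the queries the new `SubquerySuite` runs against its fixtures `l(a, b)` and `r(c, d)` (a `sqlContext` in scope is assumed):

```scala
// The two predicate-subquery forms this patch adds, used as filter conditions.
// Assumes tables l(a, b) and r(c, d) are registered, as in SubquerySuite below.
sqlContext.sql("select * from l where exists (select * from r where l.a = r.c)")
sqlContext.sql("select * from l where not exists (select * from r where l.a = r.c)")
sqlContext.sql("select * from l where a in (select c from r)")
sqlContext.sql("select * from l where a not in (select c from r where c is not null)")
```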
Diffstat (limited to 'sql/core/src')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala |  6
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala          | 53
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala      | 98
3 files changed, 144 insertions(+), 13 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
index b3e8b37a2e..71b6a97852 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
@@ -18,8 +18,9 @@
package org.apache.spark.sql.execution
import org.apache.spark.sql.SQLContext
-import org.apache.spark.sql.catalyst.{expressions, InternalRow}
-import org.apache.spark.sql.catalyst.expressions.{ExprId, Literal, SubqueryExpression}
+import org.apache.spark.sql.catalyst.expressions
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Expression, ExprId, Literal, SubqueryExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
@@ -42,6 +43,7 @@ case class ScalarSubquery(
override def plan: SparkPlan = Subquery(simpleString, executedPlan)
override def dataType: DataType = executedPlan.schema.fields.head.dataType
+ override def children: Seq[Expression] = Nil
override def nullable: Boolean = true
override def toString: String = s"subquery#${exprId.id}"
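The only functional change in this file is the `children` override: `ScalarSubquery` embeds an executed physical plan rather than child expressions, so expression-tree traversals must treat it as a leaf. A self-contained sketch of the idea, using simplified stand-ins rather than the actual Catalyst classes:

```scala
// Simplified stand-ins for Catalyst's expression tree (not Spark code).
sealed trait Expr { def children: Seq[Expr] }

case class Add(left: Expr, right: Expr) extends Expr {
  def children: Seq[Expr] = Seq(left, right)
}

// A subquery expression carries its plan out-of-band; traversals that walk
// `children` (transform, foreach, ...) must not try to descend into it,
// hence children is explicitly Nil.
case class ScalarSubquerySketch(planDescription: String) extends Expr {
  def children: Seq[Expr] = Nil
}

// Traversals recurse only through `children`, so the subquery node is atomic:
def countNodes(e: Expr): Int = 1 + e.children.map(countNodes).sum
// countNodes(Add(ScalarSubquerySketch("select max(c) from r"),
//                ScalarSubquerySketch("select 1"))) == 3
```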
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 2dca792c83..cbacb5e103 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql
-import java.util.{Locale, TimeZone}
+import java.util.{ArrayDeque, Locale, TimeZone}
import scala.collection.JavaConverters._
import scala.util.control.NonFatal
@@ -35,6 +35,8 @@ import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.streaming.MemoryPlan
import org.apache.spark.sql.types.ObjectType
+
+
abstract class QueryTest extends PlanTest {
protected def sqlContext: SQLContext
@@ -47,6 +49,7 @@ abstract class QueryTest extends PlanTest {
/**
* Runs the plan and makes sure the answer contains all of the keywords, or that
* none of the keywords is listed in the answer
+ *
* @param df the [[DataFrame]] to be executed
* @param exists true to make sure the keywords are listed in the output; otherwise
* make sure none of the keywords is listed in the output
@@ -119,6 +122,7 @@ abstract class QueryTest extends PlanTest {
/**
* Runs the plan and makes sure the answer matches the expected result.
+ *
* @param df the [[DataFrame]] to be executed
* @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
*/
@@ -158,6 +162,7 @@ abstract class QueryTest extends PlanTest {
/**
* Runs the plan and makes sure the answer is within absTol of the expected result.
+ *
* @param dataFrame the [[DataFrame]] to be executed
* @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
* @param absTol the absolute tolerance between actual and expected answers.
@@ -198,7 +203,10 @@ abstract class QueryTest extends PlanTest {
}
private def checkJsonFormat(df: DataFrame): Unit = {
+ // Get the analyzed plan and rewrite the PredicateSubqueries in order to make sure that
+ // RDD and Data resolution does not break.
val logicalPlan = df.queryExecution.analyzed
+
// bypass some cases that we can't handle currently.
logicalPlan.transform {
case _: ObjectConsumer => return
@@ -236,9 +244,27 @@ abstract class QueryTest extends PlanTest {
// RDDs/data are not serializable to JSON, so we need to collect the LogicalPlans that contain
// these non-serializable things, and use the originals to replace the null placeholders
// in the logical plans parsed from JSON.
- var logicalRDDs = logicalPlan.collect { case l: LogicalRDD => l }
- var localRelations = logicalPlan.collect { case l: LocalRelation => l }
- var inMemoryRelations = logicalPlan.collect { case i: InMemoryRelation => i }
+ val logicalRDDs = new ArrayDeque[LogicalRDD]()
+ val localRelations = new ArrayDeque[LocalRelation]()
+ val inMemoryRelations = new ArrayDeque[InMemoryRelation]()
+ def collectData: (LogicalPlan => Unit) = {
+ case l: LogicalRDD =>
+ logicalRDDs.offer(l)
+ case l: LocalRelation =>
+ localRelations.offer(l)
+ case i: InMemoryRelation =>
+ inMemoryRelations.offer(i)
+ case p =>
+ p.expressions.foreach {
+ _.foreach {
+ case s: SubqueryExpression =>
+ s.query.foreach(collectData)
+ case _ =>
+ }
+ }
+ }
+ logicalPlan.foreach(collectData)
+
val jsonBackPlan = try {
TreeNode.fromJSON[LogicalPlan](jsonString, sqlContext.sparkContext)
@@ -253,18 +279,15 @@ abstract class QueryTest extends PlanTest {
""".stripMargin, e)
}
- val normalized2 = jsonBackPlan transformDown {
+ def renormalize: PartialFunction[LogicalPlan, LogicalPlan] = {
case l: LogicalRDD =>
- val origin = logicalRDDs.head
- logicalRDDs = logicalRDDs.drop(1)
+ val origin = logicalRDDs.pop()
LogicalRDD(l.output, origin.rdd)(sqlContext)
case l: LocalRelation =>
- val origin = localRelations.head
- localRelations = localRelations.drop(1)
+ val origin = localRelations.pop()
l.copy(data = origin.data)
case l: InMemoryRelation =>
- val origin = inMemoryRelations.head
- inMemoryRelations = inMemoryRelations.drop(1)
+ val origin = inMemoryRelations.pop()
InMemoryRelation(
l.output,
l.useCompression,
@@ -275,7 +298,13 @@ abstract class QueryTest extends PlanTest {
origin.cachedColumnBuffers,
l._statistics,
origin._batchStats)
+ case p =>
+ p.transformExpressions {
+ case s: SubqueryExpression =>
+ s.withNewPlan(s.query.transformDown(renormalize))
+ }
}
+ val normalized2 = jsonBackPlan.transformDown(renormalize)
assert(logicalRDDs.isEmpty)
assert(localRelations.isEmpty)
@@ -309,6 +338,7 @@ object QueryTest {
* If there was an exception during the execution, or the contents of the DataFrame do not
* match the expected result, an error message will be returned. Otherwise, [[None]] will
* be returned.
+ *
* @param df the [[DataFrame]] to be executed
* @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
*/
@@ -383,6 +413,7 @@ object QueryTest {
/**
* Runs the plan and makes sure the answer is within absTol of the expected result.
+ *
* @param actualAnswer the actual result in a [[Row]].
* @param expectedAnswer the expected result in a [[Row]].
* @param absTol the absolute tolerance between actual and expected answers.
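The QueryTest change above swaps the mutable `var` + `head`/`drop(1)` bookkeeping for `java.util.ArrayDeque`s, and recurses into subquery plans both when collecting the non-serializable relations and when splicing them back after the JSON round-trip. The pairing works because both passes visit nodes in the same pre-order: `offer` appends at the tail, `pop` removes from the head, so each deque acts as a FIFO queue. A self-contained sketch of that invariant, using toy tree types rather than Spark's:

```scala
import java.util.ArrayDeque

// Toy tree standing in for a logical plan (not Spark code).
sealed trait Node
case class Branch(children: Seq[Node]) extends Node
case class Leaf(payload: String) extends Node

// Pre-order traversal: visit the node, then its children, like TreeNode.foreach.
def preOrder(n: Node)(f: Node => Unit): Unit = {
  f(n)
  n match {
    case Branch(cs) => cs.foreach(c => preOrder(c)(f))
    case _ =>
  }
}

val originals = new ArrayDeque[Leaf]()
val tree: Node = Branch(Seq(Leaf("rdd-1"), Branch(Seq(Leaf("rdd-2")))))

// Pass 1: collect leaves in pre-order (offer appends at the tail).
preOrder(tree) { case l: Leaf => originals.offer(l); case _ => }

// Pass 2: rebuild in the same pre-order; pop removes from the head, so each
// placeholder is matched with the original collected at the same position.
def rebuild(n: Node): Node = n match {
  case _: Leaf    => originals.pop()
  case Branch(cs) => Branch(cs.map(rebuild))
}
assert(rebuild(tree) == tree && originals.isEmpty)
```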
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index 21b19fe7df..5742983fb9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -22,6 +22,38 @@ import org.apache.spark.sql.test.SharedSQLContext
class SubquerySuite extends QueryTest with SharedSQLContext {
import testImplicits._
+ setupTestData()
+
+ val row = identity[(java.lang.Integer, java.lang.Double)](_)
+
+ lazy val l = Seq(
+ row(1, 2.0),
+ row(1, 2.0),
+ row(2, 1.0),
+ row(2, 1.0),
+ row(3, 3.0),
+ row(null, null),
+ row(null, 5.0),
+ row(6, null)).toDF("a", "b")
+
+ lazy val r = Seq(
+ row(2, 3.0),
+ row(2, 3.0),
+ row(3, 2.0),
+ row(4, 1.0),
+ row(null, null),
+ row(null, 5.0),
+ row(6, null)).toDF("c", "d")
+
+ lazy val t = r.filter($"c".isNotNull && $"d".isNotNull)
+
+ protected override def beforeAll(): Unit = {
+ super.beforeAll()
+ l.registerTempTable("l")
+ r.registerTempTable("r")
+ t.registerTempTable("t")
+ }
+
test("simple uncorrelated scalar subquery") {
assertResult(Array(Row(1))) {
sql("select (select 1 as b) as b").collect()
@@ -80,4 +112,70 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
" where key = (select max(key) from subqueryData) - 1)").collect()
}
}
+
+ test("EXISTS predicate subquery") {
+ checkAnswer(
+ sql("select * from l where exists(select * from r where l.a = r.c)"),
+ Row(2, 1.0) :: Row(2, 1.0) :: Row(3, 3.0) :: Row(6, null) :: Nil)
+
+ checkAnswer(
+ sql("select * from l where exists(select * from r where l.a = r.c) and l.a <= 2"),
+ Row(2, 1.0) :: Row(2, 1.0) :: Nil)
+ }
+
+ test("NOT EXISTS predicate subquery") {
+ checkAnswer(
+ sql("select * from l where not exists(select * from r where l.a = r.c)"),
+ Row(1, 2.0) :: Row(1, 2.0) :: Row(null, null) :: Row(null, 5.0) :: Nil)
+
+ checkAnswer(
+ sql("select * from l where not exists(select * from r where l.a = r.c and l.b < r.d)"),
+ Row(1, 2.0) :: Row(1, 2.0) :: Row(3, 3.0) ::
+ Row(null, null) :: Row(null, 5.0) :: Row(6, null) :: Nil)
+ }
+
+ test("IN predicate subquery") {
+ checkAnswer(
+ sql("select * from l where l.a in (select c from r)"),
+ Row(2, 1.0) :: Row(2, 1.0) :: Row(3, 3.0) :: Row(6, null) :: Nil)
+
+ checkAnswer(
+ sql("select * from l where l.a in (select c from r where l.b < r.d)"),
+ Row(2, 1.0) :: Row(2, 1.0) :: Nil)
+
+ checkAnswer(
+ sql("select * from l where l.a in (select c from r) and l.a > 2 and l.b is not null"),
+ Row(3, 3.0) :: Nil)
+ }
+
+ test("NOT IN predicate subquery") {
+ checkAnswer(
+ sql("select * from l where a not in(select c from r)"),
+ Nil)
+
+ checkAnswer(
+ sql("select * from l where a not in(select c from r where c is not null)"),
+ Row(1, 2.0) :: Row(1, 2.0) :: Nil)
+
+ checkAnswer(
+ sql("select * from l where a not in(select c from t where b < d)"),
+ Row(1, 2.0) :: Row(1, 2.0) :: Row(3, 3.0) :: Nil)
+
+ // Empty sub-query
+ checkAnswer(
+ sql("select * from l where a not in(select c from r where c > 10 and b < d)"),
+ Row(1, 2.0) :: Row(1, 2.0) :: Row(2, 1.0) :: Row(2, 1.0) ::
+ Row(3, 3.0) :: Row(null, null) :: Row(null, 5.0) :: Row(6, null) :: Nil)
+
+ }
+
+ test("complex IN predicate subquery") {
+ checkAnswer(
+ sql("select * from l where (a, b) not in(select c, d from r)"),
+ Nil)
+
+ checkAnswer(
+ sql("select * from l where (a, b) not in(select c, d from t) and (a + b) is not null"),
+ Row(1, 2.0) :: Row(1, 2.0) :: Row(2, 1.0) :: Row(2, 1.0) :: Row(3, 3.0) :: Nil)
+ }
}
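The first `NOT IN` assertion deserves a note: it expects zero rows because `r.c` contains a NULL. Under SQL's three-valued logic, `a NOT IN (subquery)` expands to a conjunction of `a <> c_i` terms; any comparison with NULL yields UNKNOWN, so the predicate can never be TRUE and the filter rejects every row. That is also why the other assertions filter with `c is not null` or query `t`, the null-free view of `r`. A small model of that logic, as a hypothetical helper that is not part of the patch:

```scala
// Hypothetical model of three-valued NOT IN, not part of the patch.
// None stands for SQL NULL / UNKNOWN.
def notIn(a: Option[Int], subquery: Seq[Option[Int]]): Option[Boolean] =
  subquery.foldLeft(Option(true)) { (acc, c) =>
    // a <> c is UNKNOWN if either side is NULL
    val neq: Option[Boolean] = for (av <- a; cv <- c) yield av != cv
    (acc, neq) match {
      case (Some(false), _) | (_, Some(false)) => Some(false) // FALSE absorbs
      case (Some(true), Some(true))            => Some(true)
      case _                                   => None        // UNKNOWN
    }
  }

assert(notIn(Some(1), Seq(Some(2), None)).isEmpty)        // UNKNOWN -> row filtered out
assert(notIn(Some(1), Seq(Some(2))) == Some(true))        // no NULLs -> TRUE
assert(notIn(Some(2), Seq(Some(2), None)) == Some(false)) // match -> FALSE
```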