Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala  |  26
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala       |  43
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala       |   7
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala                   | 143
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala   |   5
5 files changed, 198 insertions(+), 26 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
index 59db28d58a..d7b493d521 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
@@ -47,7 +47,6 @@ abstract class SubqueryExpression(
plan: LogicalPlan,
children: Seq[Expression],
exprId: ExprId) extends PlanExpression[LogicalPlan] {
-
override lazy val resolved: Boolean = childrenResolved && plan.resolved
override lazy val references: AttributeSet =
if (plan.resolved) super.references -- plan.outputSet else super.references
@@ -59,6 +58,13 @@ abstract class SubqueryExpression(
children.zip(p.children).forall(p => p._1.semanticEquals(p._2))
case _ => false
}
+ def canonicalize(attrs: AttributeSeq): SubqueryExpression = {
+ // Normalize the outer references in the subquery plan.
+ val normalizedPlan = plan.transformAllExpressions {
+ case OuterReference(r) => OuterReference(QueryPlan.normalizeExprId(r, attrs))
+ }
+ withNewPlan(normalizedPlan).canonicalized.asInstanceOf[SubqueryExpression]
+ }
}
object SubqueryExpression {
@@ -236,6 +242,12 @@ case class ScalarSubquery(
override def nullable: Boolean = true
override def withNewPlan(plan: LogicalPlan): ScalarSubquery = copy(plan = plan)
override def toString: String = s"scalar-subquery#${exprId.id} $conditionString"
+ override lazy val canonicalized: Expression = {
+ ScalarSubquery(
+ plan.canonicalized,
+ children.map(_.canonicalized),
+ ExprId(0))
+ }
}
object ScalarSubquery {
@@ -268,6 +280,12 @@ case class ListQuery(
override def nullable: Boolean = false
override def withNewPlan(plan: LogicalPlan): ListQuery = copy(plan = plan)
override def toString: String = s"list#${exprId.id} $conditionString"
+ override lazy val canonicalized: Expression = {
+ ListQuery(
+ plan.canonicalized,
+ children.map(_.canonicalized),
+ ExprId(0))
+ }
}
/**
@@ -290,4 +308,10 @@ case class Exists(
override def nullable: Boolean = false
override def withNewPlan(plan: LogicalPlan): Exists = copy(plan = plan)
override def toString: String = s"exists#${exprId.id} $conditionString"
+ override lazy val canonicalized: Expression = {
+ Exists(
+ plan.canonicalized,
+ children.map(_.canonicalized),
+ ExprId(0))
+ }
}
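
The canonicalized overrides above are what allow CacheManager to match a freshly analyzed query against a cached plan whose predicates contain subqueries: each re-analysis assigns new ExprIds, so without normalization the two plans never compare equal. A minimal spark-shell sketch of the resulting behavior (it assumes an active SparkSession bound to `spark`, as in spark-shell; the view names are illustrative):

import spark.implicits._

Seq(1).toDF("c1").createOrReplaceTempView("t1")
Seq(2).toDF("c1").createOrReplaceTempView("t2")

// Cache a query whose WHERE clause contains an uncorrelated EXISTS subquery.
spark.sql("SELECT * FROM t1 WHERE NOT EXISTS (SELECT * FROM t2)").cache()

// Re-running the same text yields fresh ExprIds, but with Exists.canonicalized
// defined above both plans normalize to the same form, so the second query is
// answered from the InMemoryRelation instead of being planned from scratch.
val reused = spark.sql("SELECT * FROM t1 WHERE NOT EXISTS (SELECT * FROM t2)")
reused.explain()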
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 3008e8cb84..2fb65bd435 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -377,7 +377,8 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
// As the root of the expression, Alias will always take an arbitrary exprId, we need to
// normalize that for equality testing, by assigning expr id from 0 incrementally. The
// alias name doesn't matter and should be erased.
- Alias(normalizeExprId(a.child), "")(ExprId(id), a.qualifier, isGenerated = a.isGenerated)
+ val normalizedChild = QueryPlan.normalizeExprId(a.child, allAttributes)
+ Alias(normalizedChild, "")(ExprId(id), a.qualifier, isGenerated = a.isGenerated)
case ar: AttributeReference if allAttributes.indexOf(ar.exprId) == -1 =>
// Top level `AttributeReference` may also be used for output like `Alias`, we should
@@ -385,7 +386,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
id += 1
ar.withExprId(ExprId(id))
- case other => normalizeExprId(other)
+ case other => QueryPlan.normalizeExprId(other, allAttributes)
}.withNewChildren(canonicalizedChildren)
}
@@ -395,23 +396,6 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
*/
protected def preCanonicalized: PlanType = this
- /**
- * Normalize the exprIds in the given expression, by updating the exprId in `AttributeReference`
- * with its referenced ordinal from input attributes. It's similar to `BindReferences` but we
- * do not use `BindReferences` here as the plan may take the expression as a parameter with type
- * `Attribute`, and replace it with `BoundReference` will cause error.
- */
- protected def normalizeExprId[T <: Expression](e: T, input: AttributeSeq = allAttributes): T = {
- e.transformUp {
- case ar: AttributeReference =>
- val ordinal = input.indexOf(ar.exprId)
- if (ordinal == -1) {
- ar
- } else {
- ar.withExprId(ExprId(ordinal))
- }
- }.canonicalized.asInstanceOf[T]
- }
/**
* Returns true when the given query plan will return the same results as this query plan.
@@ -438,3 +422,24 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
*/
lazy val allAttributes: AttributeSeq = children.flatMap(_.output)
}
+
+object QueryPlan {
+ /**
+ * Normalize the exprIds in the given expression, by updating the exprId in `AttributeReference`
+ * with its referenced ordinal from input attributes. It's similar to `BindReferences` but we
+ * do not use `BindReferences` here as the plan may take the expression as a parameter with type
+ * `Attribute`, and replacing it with `BoundReference` will cause an error.
+ */
+ def normalizeExprId[T <: Expression](e: T, input: AttributeSeq): T = {
+ e.transformUp {
+ case s: SubqueryExpression => s.canonicalize(input)
+ case ar: AttributeReference =>
+ val ordinal = input.indexOf(ar.exprId)
+ if (ordinal == -1) {
+ ar
+ } else {
+ ar.withExprId(ExprId(ordinal))
+ }
+ }.canonicalized.asInstanceOf[T]
+ }
+}
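
Moving normalizeExprId into the QueryPlan companion object lets SubqueryExpression.canonicalize reuse it for the OuterReferences that point back into the enclosing plan. The ordinal rewrite itself can be pictured with a small self-contained sketch (plain Scala, not Catalyst types; the Attr case class and the ids are hypothetical):

// Two analyses of the same query assign different, arbitrary expression ids.
case class Attr(name: String, exprId: Long)

// Mirror what normalizeExprId does for AttributeReference: rewrite each id to the
// attribute's ordinal in the plan's own input, leaving unresolvable ids untouched.
def normalize(attrs: Seq[Attr], input: Seq[Attr]): Seq[Attr] =
  attrs.map { a =>
    val ordinal = input.indexWhere(_.exprId == a.exprId)
    if (ordinal == -1) a else a.copy(exprId = ordinal.toLong)
  }

val planA = Seq(Attr("c1", 101L), Attr("c2", 102L))
val planB = Seq(Attr("c1", 201L), Attr("c2", 202L))

// Normalized against their own outputs, the two plans compare equal, which is
// what canonicalization needs for cache lookups and plan reuse.
assert(normalize(planA, planA) == normalize(planB, planB))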
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
index 3a9132d74a..866fa98533 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
+import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, UnknownPartitioning}
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat => ParquetSource}
@@ -516,10 +517,10 @@ case class FileSourceScanExec(
override lazy val canonicalized: FileSourceScanExec = {
FileSourceScanExec(
relation,
- output.map(normalizeExprId(_, output)),
+ output.map(QueryPlan.normalizeExprId(_, output)),
requiredSchema,
- partitionFilters.map(normalizeExprId(_, output)),
- dataFilters.map(normalizeExprId(_, output)),
+ partitionFilters.map(QueryPlan.normalizeExprId(_, output)),
+ dataFilters.map(QueryPlan.normalizeExprId(_, output)),
None)
}
}
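
The FileSourceScanExec change is mechanical, but the predicate normalization it delegates to QueryPlan.normalizeExprId is easy to observe through sameResult. A hedged spark-shell sketch (the parquet path is made up for illustration):

import spark.implicits._

spark.range(20).write.mode("overwrite").parquet("/tmp/spark_19993_example")

val scan1 = spark.read.parquet("/tmp/spark_19993_example").filter($"id" > 4)
val scan2 = spark.read.parquet("/tmp/spark_19993_example").filter($"id" > 4)

// Both executed plans contain a FileSourceScanExec; normalizing output and filter
// ExprIds against the scan's own output makes the two plans canonically equal.
assert(scan1.queryExecution.executedPlan.sameResult(scan2.queryExecution.executedPlan))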
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index 7a7d52b214..e66fe97afa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -26,7 +26,7 @@ import org.scalatest.concurrent.Eventually._
import org.apache.spark.CleanerListener
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
-import org.apache.spark.sql.execution.RDDScanExec
+import org.apache.spark.sql.execution.{RDDScanExec, SparkPlan}
import org.apache.spark.sql.execution.columnar._
import org.apache.spark.sql.execution.exchange.ShuffleExchange
import org.apache.spark.sql.functions._
@@ -76,6 +76,13 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
sum
}
+ private def getNumInMemoryTablesRecursively(plan: SparkPlan): Int = {
+ plan.collect {
+ case InMemoryTableScanExec(_, _, relation) =>
+ getNumInMemoryTablesRecursively(relation.child) + 1
+ }.sum
+ }
+
test("withColumn doesn't invalidate cached dataframe") {
var evalCount = 0
val myUDF = udf((x: String) => { evalCount += 1; "result" })
@@ -670,4 +677,138 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
assert(spark.read.parquet(path).filter($"id" > 4).count() == 15)
}
}
+
+ test("SPARK-19993 simple subquery caching") {
+ withTempView("t1", "t2") {
+ Seq(1).toDF("c1").createOrReplaceTempView("t1")
+ Seq(2).toDF("c1").createOrReplaceTempView("t2")
+
+ sql(
+ """
+ |SELECT * FROM t1
+ |WHERE
+ |NOT EXISTS (SELECT * FROM t2)
+ """.stripMargin).cache()
+
+ val cachedDs =
+ sql(
+ """
+ |SELECT * FROM t1
+ |WHERE
+ |NOT EXISTS (SELECT * FROM t2)
+ """.stripMargin)
+ assert(getNumInMemoryRelations(cachedDs) == 1)
+
+ // Additional predicate in the subquery plan should cause a cache miss
+ val cachedMissDs =
+ sql(
+ """
+ |SELECT * FROM t1
+ |WHERE
+ |NOT EXISTS (SELECT * FROM t2 where c1 = 0)
+ """.stripMargin)
+ assert(getNumInMemoryRelations(cachedMissDs) == 0)
+ }
+ }
+
+ test("SPARK-19993 subquery caching with correlated predicates") {
+ withTempView("t1", "t2") {
+ Seq(1).toDF("c1").createOrReplaceTempView("t1")
+ Seq(1).toDF("c1").createOrReplaceTempView("t2")
+
+ // Simple correlated predicate in subquery
+ sql(
+ """
+ |SELECT * FROM t1
+ |WHERE
+ |t1.c1 in (SELECT t2.c1 FROM t2 where t1.c1 = t2.c1)
+ """.stripMargin).cache()
+
+ val cachedDs =
+ sql(
+ """
+ |SELECT * FROM t1
+ |WHERE
+ |t1.c1 in (SELECT t2.c1 FROM t2 where t1.c1 = t2.c1)
+ """.stripMargin)
+ assert(getNumInMemoryRelations(cachedDs) == 1)
+ }
+ }
+
+ test("SPARK-19993 subquery with cached underlying relation") {
+ withTempView("t1") {
+ Seq(1).toDF("c1").createOrReplaceTempView("t1")
+ spark.catalog.cacheTable("t1")
+
+ // The underlying table t1 is cached, as well as the query that refers to it.
+ val ds =
+ sql(
+ """
+ |SELECT * FROM t1
+ |WHERE
+ |NOT EXISTS (SELECT * FROM t1)
+ """.stripMargin)
+ assert(getNumInMemoryRelations(ds) == 2)
+
+ val cachedDs =
+ sql(
+ """
+ |SELECT * FROM t1
+ |WHERE
+ |NOT EXISTS (SELECT * FROM t1)
+ """.stripMargin).cache()
+ assert(getNumInMemoryTablesRecursively(cachedDs.queryExecution.sparkPlan) == 3)
+ }
+ }
+
+ test("SPARK-19993 nested subquery caching and scalar + predicate subqueris") {
+ withTempView("t1", "t2", "t3", "t4") {
+ Seq(1).toDF("c1").createOrReplaceTempView("t1")
+ Seq(2).toDF("c1").createOrReplaceTempView("t2")
+ Seq(1).toDF("c1").createOrReplaceTempView("t3")
+ Seq(1).toDF("c1").createOrReplaceTempView("t4")
+
+ // Nested predicate subquery
+ sql(
+ """
+ |SELECT * FROM t1
+ |WHERE
+ |c1 IN (SELECT c1 FROM t2 WHERE c1 IN (SELECT c1 FROM t3 WHERE c1 = 1))
+ """.stripMargin).cache()
+
+ val cachedDs =
+ sql(
+ """
+ |SELECT * FROM t1
+ |WHERE
+ |c1 IN (SELECT c1 FROM t2 WHERE c1 IN (SELECT c1 FROM t3 WHERE c1 = 1))
+ """.stripMargin)
+ assert(getNumInMemoryRelations(cachedDs) == 1)
+
+ // Scalar subquery and predicate subquery
+ sql(
+ """
+ |SELECT * FROM (SELECT max(c1) FROM t1 GROUP BY c1)
+ |WHERE
+ |c1 = (SELECT max(c1) FROM t2 GROUP BY c1)
+ |OR
+ |EXISTS (SELECT c1 FROM t3)
+ |OR
+ |c1 IN (SELECT c1 FROM t4)
+ """.stripMargin).cache()
+
+ val cachedDs2 =
+ sql(
+ """
+ |SELECT * FROM (SELECT max(c1) FROM t1 GROUP BY c1)
+ |WHERE
+ |c1 = (SELECT max(c1) FROM t2 GROUP BY c1)
+ |OR
+ |EXISTS (SELECT c1 FROM t3)
+ |OR
+ |c1 IN (SELECT c1 FROM t4)
+ """.stripMargin)
+ assert(getNumInMemoryRelations(cachedDs2) == 1)
+ }
+ }
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
index fab0d7fa84..666548d1a4 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.catalog.CatalogRelation
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.hive._
@@ -203,9 +204,9 @@ case class HiveTableScanExec(
override lazy val canonicalized: HiveTableScanExec = {
val input: AttributeSeq = relation.output
HiveTableScanExec(
- requestedAttributes.map(normalizeExprId(_, input)),
+ requestedAttributes.map(QueryPlan.normalizeExprId(_, input)),
relation.canonicalized.asInstanceOf[CatalogRelation],
- partitionPruningPred.map(normalizeExprId(_, input)))(sparkSession)
+ partitionPruningPred.map(QueryPlan.normalizeExprId(_, input)))(sparkSession)
}
override def otherCopyArgs: Seq[AnyRef] = Seq(sparkSession)