aboutsummaryrefslogtreecommitdiff
path: root/sql/core/src/test
diff options
context:
space:
mode:
Diffstat (limited to 'sql/core/src/test')
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala 143
1 files changed, 142 insertions, 1 deletions
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index 7a7d52b214..e66fe97afa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -26,7 +26,7 @@ import org.scalatest.concurrent.Eventually._
import org.apache.spark.CleanerListener
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
-import org.apache.spark.sql.execution.RDDScanExec
+import org.apache.spark.sql.execution.{RDDScanExec, SparkPlan}
import org.apache.spark.sql.execution.columnar._
import org.apache.spark.sql.execution.exchange.ShuffleExchange
import org.apache.spark.sql.functions._
@@ -76,6 +76,13 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
sum
}
+  /**
+   * Counts the `InMemoryTableScanExec` nodes reachable from `plan`, also descending into the
+   * child plan of each scanned relation.
+   *
+   * @param plan physical plan to inspect
+   * @return number of in-memory table scans found, including those nested inside the plans of
+   *         the scanned relations themselves
+   */
+  private def getNumInMemoryTablesRecursively(plan: SparkPlan): Int = {
+    // First gather the child plan behind every in-memory scan in this plan ...
+    val plansBehindScans = plan.collect {
+      case InMemoryTableScanExec(_, _, cachedRelation) => cachedRelation.child
+    }
+    // ... then count each scan once, plus whatever scans are nested beneath it.
+    plansBehindScans.map(nested => 1 + getNumInMemoryTablesRecursively(nested)).sum
+  }
+
test("withColumn doesn't invalidate cached dataframe") {
var evalCount = 0
val myUDF = udf((x: String) => { evalCount += 1; "result" })
@@ -670,4 +677,138 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
assert(spark.read.parquet(path).filter($"id" > 4).count() == 15)
}
}
+
+  // Verifies that a query containing an uncorrelated NOT EXISTS subquery is served from the cache
+  // when the identical query is issued again, and that changing the subquery causes a cache miss.
+  test("SPARK-19993 simple subquery caching") {
+    withTempView("t1", "t2") {
+      Seq(1).toDF("c1").createOrReplaceTempView("t1")
+      Seq(2).toDF("c1").createOrReplaceTempView("t2")
+
+      // Single definition of the query so the cached and the re-issued text cannot drift apart.
+      val query =
+        """
+          |SELECT * FROM t1
+          |WHERE
+          |NOT EXISTS (SELECT * FROM t2)
+        """.stripMargin
+
+      val cachedQuery = sql(query).cache()
+      try {
+        // Re-issuing the same query should be answered by the cached data.
+        val cachedDs = sql(query)
+        assert(getNumInMemoryRelations(cachedDs) == 1)
+
+        // Additional predicate in the subquery plan should cause a cache miss
+        val cachedMissDs =
+          sql(
+            """
+              |SELECT * FROM t1
+              |WHERE
+              |NOT EXISTS (SELECT * FROM t2 where c1 = 0)
+            """.stripMargin)
+        assert(getNumInMemoryRelations(cachedMissDs) == 0)
+      } finally {
+        // Drop the cache entry so it cannot leak into other tests sharing this session.
+        cachedQuery.unpersist()
+      }
+    }
+  }
+
+  // Verifies that a query with a correlated IN subquery is served from the cache when the
+  // identical query is issued again.
+  test("SPARK-19993 subquery caching with correlated predicates") {
+    withTempView("t1", "t2") {
+      Seq(1).toDF("c1").createOrReplaceTempView("t1")
+      Seq(1).toDF("c1").createOrReplaceTempView("t2")
+
+      // Simple correlated predicate in subquery
+      val query =
+        """
+          |SELECT * FROM t1
+          |WHERE
+          |t1.c1 in (SELECT t2.c1 FROM t2 where t1.c1 = t2.c1)
+        """.stripMargin
+
+      val cachedQuery = sql(query).cache()
+      try {
+        val cachedDs = sql(query)
+        assert(getNumInMemoryRelations(cachedDs) == 1)
+      } finally {
+        // Drop the cache entry so it cannot leak into other tests sharing this session.
+        cachedQuery.unpersist()
+      }
+    }
+  }
+
+  // Verifies the in-memory relation counts for a subquery over a table that is itself cached,
+  // both before and after the enclosing query is cached as well.
+  test("SPARK-19993 subquery with cached underlying relation") {
+    withTempView("t1") {
+      Seq(1).toDF("c1").createOrReplaceTempView("t1")
+      spark.catalog.cacheTable("t1")
+
+      val query =
+        """
+          |SELECT * FROM t1
+          |WHERE
+          |NOT EXISTS (SELECT * FROM t1)
+        """.stripMargin
+
+      // underlying table t1 is cached as well as the query that refers to it.
+      // NOTE(review): the expected count of 2 presumably corresponds to the two references to
+      // t1 in the query (outer query and subquery) -- confirm against getNumInMemoryRelations.
+      val ds = sql(query)
+      assert(getNumInMemoryRelations(ds) == 2)
+
+      val cachedDs = sql(query).cache()
+      try {
+        assert(getNumInMemoryTablesRecursively(cachedDs.queryExecution.sparkPlan) == 3)
+      } finally {
+        // Drop the cache entry so it cannot leak into other tests sharing this session.
+        cachedDs.unpersist()
+      }
+    }
+  }
+
+  // Verifies cache hits for (a) a query with nested IN subqueries and (b) a query mixing a
+  // scalar subquery with EXISTS and IN predicate subqueries.
+  test("SPARK-19993 nested subquery caching and scalar + predicate subqueris") {
+    withTempView("t1", "t2", "t3", "t4") {
+      Seq(1).toDF("c1").createOrReplaceTempView("t1")
+      Seq(2).toDF("c1").createOrReplaceTempView("t2")
+      Seq(1).toDF("c1").createOrReplaceTempView("t3")
+      Seq(1).toDF("c1").createOrReplaceTempView("t4")
+
+      // Nested predicate subquery
+      val nestedQuery =
+        """
+          |SELECT * FROM t1
+          |WHERE
+          |c1 IN (SELECT c1 FROM t2 WHERE c1 IN (SELECT c1 FROM t3 WHERE c1 = 1))
+        """.stripMargin
+
+      // Scalar subquery and predicate subquery
+      val scalarAndPredicateQuery =
+        """
+          |SELECT * FROM (SELECT max(c1) FROM t1 GROUP BY c1)
+          |WHERE
+          |c1 = (SELECT max(c1) FROM t2 GROUP BY c1)
+          |OR
+          |EXISTS (SELECT c1 FROM t3)
+          |OR
+          |c1 IN (SELECT c1 FROM t4)
+        """.stripMargin
+
+      val cachedNested = sql(nestedQuery).cache()
+      try {
+        val cachedDs = sql(nestedQuery)
+        assert(getNumInMemoryRelations(cachedDs) == 1)
+
+        // The first query stays cached while the second one is checked, as in the original
+        // sequence; both entries are released on the way out.
+        val cachedScalarAndPredicate = sql(scalarAndPredicateQuery).cache()
+        try {
+          val cachedDs2 = sql(scalarAndPredicateQuery)
+          assert(getNumInMemoryRelations(cachedDs2) == 1)
+        } finally {
+          cachedScalarAndPredicate.unpersist()
+        }
+      } finally {
+        // Drop the cache entry so it cannot leak into other tests sharing this session.
+        cachedNested.unpersist()
+      }
+    }
+  }
}