aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
diff options
context:
space:
mode:
Diffstat (limited to 'sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala')
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala76
1 files changed, 69 insertions, 7 deletions
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index b84ae7c5bb..df7529d83f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -33,14 +33,12 @@ class FilterPushdownSuite extends PlanTest {
val batches =
Batch("Subqueries", Once,
EliminateSubqueryAliases) ::
- Batch("Filter Pushdown", Once,
+ Batch("Filter Pushdown", FixedPoint(10),
SamplePushDown,
CombineFilters,
- PushPredicateThroughProject,
+ PushDownPredicate,
BooleanSimplification,
PushPredicateThroughJoin,
- PushPredicateThroughGenerate,
- PushPredicateThroughAggregate,
CollapseProject) :: Nil
}
@@ -620,8 +618,8 @@ class FilterPushdownSuite extends PlanTest {
val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer = testRelation
- .select('a, 'b)
.where('a === 3)
+ .select('a, 'b)
.groupBy('a)('a, count('b) as 'c)
.where('c === 2L)
.analyze
@@ -638,8 +636,8 @@ class FilterPushdownSuite extends PlanTest {
val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer = testRelation
- .select('a, 'b)
.where('a + 1 < 3)
+ .select('a, 'b)
.groupBy('a)(('a + 1) as 'aa, count('b) as 'c)
.where('c === 2L || 'aa > 4)
.analyze
@@ -656,8 +654,8 @@ class FilterPushdownSuite extends PlanTest {
val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer = testRelation
- .select('a, 'b)
.where("s" === "s")
+ .select('a, 'b)
.groupBy('a)('a, count('b) as 'c, "s" as 'd)
.where('c === 2L)
.analyze
@@ -681,4 +679,68 @@ class FilterPushdownSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}
+
+ test("broadcast hint") {
+ val originalQuery = BroadcastHint(testRelation)
+ .where('a === 2L && 'b + Rand(10).as("rnd") === 3)
+
+ val optimized = Optimize.execute(originalQuery.analyze)
+
+ val correctAnswer = BroadcastHint(testRelation.where('a === 2L))
+ .where('b + Rand(10).as("rnd") === 3)
+ .analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("union") {
+ val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int)
+
+ val originalQuery = Union(Seq(testRelation, testRelation2))
+ .where('a === 2L && 'b + Rand(10).as("rnd") === 3)
+
+ val optimized = Optimize.execute(originalQuery.analyze)
+
+ val correctAnswer = Union(Seq(
+ testRelation.where('a === 2L),
+ testRelation2.where('d === 2L)))
+ .where('b + Rand(10).as("rnd") === 3)
+ .analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("intersect") {
+ val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int)
+
+ val originalQuery = Intersect(testRelation, testRelation2)
+ .where('a === 2L && 'b + Rand(10).as("rnd") === 3)
+
+ val optimized = Optimize.execute(originalQuery.analyze)
+
+ val correctAnswer = Intersect(
+ testRelation.where('a === 2L),
+ testRelation2.where('d === 2L))
+ .where('b + Rand(10).as("rnd") === 3)
+ .analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("except") {
+ val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int)
+
+ val originalQuery = Except(testRelation, testRelation2)
+ .where('a === 2L && 'b + Rand(10).as("rnd") === 3)
+
+ val optimized = Optimize.execute(originalQuery.analyze)
+
+ val correctAnswer = Except(
+ testRelation.where('a === 2L),
+ testRelation2)
+ .where('b + Rand(10).as("rnd") === 3)
+ .analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
}