 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala           | 12 ---
 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala  | 15 +++
 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala | 17 ---
 sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala                               | 25 +++
 4 files changed, 40 insertions(+), 29 deletions(-)
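
This patch removes the PushProjectThroughSample optimizer rule. The rule unconditionally rewrote Project(projectList, Sample(...)) into Sample(..., Project(projectList, child)), which moved arbitrary project expressions, UDFs included, below the sampling step, so they were evaluated on rows that the sample later discarded. ColumnPruning already performs the safe, attribute-only form of this pushdown, so the rule can simply go. A minimal repro of the misbehaviour, as a sketch rather than part of the patch (assumes a Spark 2.x dependency; the object and UDF names here are illustrative):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.udf

object SampleUdfRepro {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("SPARK-16686").getOrCreate()
    import spark.implicits._

    // Fails loudly if it is ever evaluated on id=1.
    val failOnOne = udf((n: Int) => { require(n != 1, "saw id=1"); 1 })

    val df = (0 until 10).map(i => (i, s"string$i")).toDF("id", "stringData")
    // Assuming, as the new DatasetSuite test below does, that this seed
    // excludes id=1 from the sample...
    val sampled = df.sample(withReplacement = false, fraction = 0.7, seed = 50)
    // ...before this patch, PushProjectThroughSample moved the Project (and
    // the UDF inside it) below Sample, so failOnOne also ran on the rows that
    // sampling discarded, including id=1, and this query threw.
    sampled.select(failOnOne($"id")).collect()
    spark.stop()
  }
}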
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index c8e9d8e2f9..fe328fd598 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -76,7 +76,6 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
     Batch("Operator Optimizations", fixedPoint,
       // Operator push down
       PushThroughSetOperations,
-      PushProjectThroughSample,
       ReorderJoin,
       EliminateOuterJoin,
       PushPredicateThroughJoin,
@@ -149,17 +148,6 @@ class SimpleTestOptimizer extends Optimizer(
   new SimpleCatalystConf(caseSensitiveAnalysis = true))
 
 /**
- * Pushes projects down beneath Sample to enable column pruning with sampling.
- */
-object PushProjectThroughSample extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    // Push down projection into sample
-    case Project(projectList, Sample(lb, up, replace, seed, child)) =>
-      Sample(lb, up, replace, seed, Project(projectList, child))()
-  }
-}
-
-/**
  * Removes the Project only conducting Alias of its child node.
  * It is created mainly for removing extra Project added in EliminateSerialization rule,
  * but can also benefit other operators.
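
The safe replacement already exists: ColumnPruning only moves a projection of bare attribute references into Sample, so no expression evaluation crosses the sampling boundary. A minimal sketch of the shape of that guarded pushdown (this is not the actual ColumnPruning code, and PruneColumnsThroughSample is a made-up name for illustration):

import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Sample}
import org.apache.spark.sql.catalyst.rules.Rule

object PruneColumnsThroughSample extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    // Only pure attribute projections may move below Sample: they prune
    // columns without evaluating anything on the yet-unsampled rows.
    case Project(projectList, Sample(lb, up, replace, seed, child))
        if projectList.forall(_.isInstanceOf[Attribute]) =>
      Sample(lb, up, replace, seed, Project(projectList, child))()
  }
}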
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
index b5664a5e69..589607e3ad 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
@@ -346,5 +346,20 @@ class ColumnPruningSuite extends PlanTest {
     comparePlans(Optimize.execute(plan1.analyze), correctAnswer1)
   }
 
+  test("push project down into sample") {
+    val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
+    val x = testRelation.subquery('x)
+
+    val query1 = Sample(0.0, 0.6, false, 11L, x)().select('a)
+    val optimized1 = Optimize.execute(query1.analyze)
+    val expected1 = Sample(0.0, 0.6, false, 11L, x.select('a))()
+    comparePlans(optimized1, expected1.analyze)
+
+    val query2 = Sample(0.0, 0.6, false, 11L, x)().select('a as 'aa)
+    val optimized2 = Optimize.execute(query2.analyze)
+    val expected2 = Sample(0.0, 0.6, false, 11L, x.select('a))().select('a as 'aa)
+    comparePlans(optimized2, expected2.analyze)
+  }
+
   // todo: add more tests for column pruning
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index 780e78ed1c..596b8fcea1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -34,7 +34,6 @@ class FilterPushdownSuite extends PlanTest {
       Batch("Subqueries", Once,
         EliminateSubqueryAliases) ::
       Batch("Filter Pushdown", FixedPoint(10),
-        PushProjectThroughSample,
         CombineFilters,
         PushDownPredicate,
         BooleanSimplification,
@@ -585,22 +584,6 @@ class FilterPushdownSuite extends PlanTest {
     comparePlans(optimized, originalQuery)
   }
 
-  test("push project and filter down into sample") {
-    val x = testRelation.subquery('x)
-    val originalQuery =
-      Sample(0.0, 0.6, false, 11L, x)().select('a)
-
-    val originalQueryAnalyzed =
-      EliminateSubqueryAliases(analysis.SimpleAnalyzer.execute(originalQuery))
-
-    val optimized = Optimize.execute(originalQueryAnalyzed)
-
-    val correctAnswer =
-      Sample(0.0, 0.6, false, 11L, x.select('a))()
-
-    comparePlans(optimized, correctAnswer.analyze)
-  }
-
   test("aggregate: push down filter when filter on group by expression") {
     val originalQuery = testRelation
       .groupBy('a)('a, count('b) as 'c)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 0b6f40872f..7e3b7b63d8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -422,6 +422,31 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       3, 17, 27, 58, 62)
   }
 
+  test("SPARK-16686: Dataset.sample with seed results shouldn't depend on downstream usage") {
+    val simpleUdf = udf((n: Int) => {
+      require(n != 1, "simpleUdf shouldn't see id=1!")
+      1
+    })
+
+    val df = Seq(
+      (0, "string0"),
+      (1, "string1"),
+      (2, "string2"),
+      (3, "string3"),
+      (4, "string4"),
+      (5, "string5"),
+      (6, "string6"),
+      (7, "string7"),
+      (8, "string8"),
+      (9, "string9")
+    ).toDF("id", "stringData")
+    val sampleDF = df.sample(false, 0.7, 50)
+    // After sampling, sampleDF doesn't contain id=1.
+    assert(!sampleDF.select("id").as[Int].collect.contains(1))
+    // simpleUdf should not encounter id=1.
+    checkAnswer(sampleDF.select(simpleUdf($"id")), List.fill(sampleDF.count.toInt)(Row(1)))
+  }
+
   test("SPARK-11436: we should rebind right encoder when join 2 datasets") {
     val ds1 = Seq("1", "2").toDS().as("a")
     val ds2 = Seq(2, 3).toDS().as("b")
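
The new DatasetSuite test pins down the user-visible guarantee: with a fixed seed, a downstream UDF must only ever see the sampled rows. One way to observe this directly, sketched here as an illustration rather than part of the patch (assumes a SparkSession named spark with its implicits in scope, as in DatasetSuite; the accumulator name is made up):

import scala.collection.JavaConverters._
import org.apache.spark.sql.functions.udf

val seen = spark.sparkContext.collectionAccumulator[Int]("seenIds")
val recordId = udf((n: Int) => { seen.add(n); n })

val df = (0 until 10).map(i => (i, s"string$i")).toDF("id", "stringData")
val sampled = df.sample(false, 0.7, 50)

val sampledIds = sampled.select("id").as[Int].collect().toSet
sampled.select(recordId($"id")).collect()
// Before this patch the Project could be pushed below Sample, so the
// accumulator picked up unsampled ids as well; now the two sets agree.
assert(seen.value.asScala.toSet == sampledIds)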