diff options
author | Wenchen Fan <wenchen@databricks.com> | 2016-09-01 13:19:15 +0800 |
---|---|---|
committer | Wenchen Fan <wenchen@databricks.com> | 2016-09-01 13:19:15 +0800 |
commit | aaf632b2132750c697dddd0469b902d9308dbf36 (patch) | |
tree | 45f8c6d5d852f2ec8ad8b100969c482b18a8b68f /sql/core/src/test/scala | |
parent | 7a5000f39ef4f195696836f8a4e8ab4ff5c14dd2 (diff) | |
download | spark-aaf632b2132750c697dddd0469b902d9308dbf36.tar.gz spark-aaf632b2132750c697dddd0469b902d9308dbf36.tar.bz2 spark-aaf632b2132750c697dddd0469b902d9308dbf36.zip |
revert PR#10896 and PR#14865
## What changes were proposed in this pull request?
According to the discussion in the original PR #10896 and the new approach in PR #14876, we decided to revert these 2 PRs (#10896 and #14865) and go with the new approach.
## How was this patch tested?
N/A
Author: Wenchen Fan <wenchen@databricks.com>
Closes #14909 from cloud-fan/revert.
Diffstat (limited to 'sql/core/src/test/scala')
-rw-r--r-- | sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 15 | ||||
-rw-r--r-- | sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala | 77 |
2 files changed, 21 insertions, 71 deletions
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index ce0b92a461..f89951760f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1248,17 +1248,17 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } /** - * Verifies that there is a single Aggregation for `df` + * Verifies that there is no Exchange between the Aggregations for `df` */ - private def verifyNonExchangingSingleAgg(df: DataFrame) = { + private def verifyNonExchangingAgg(df: DataFrame) = { var atFirstAgg: Boolean = false df.queryExecution.executedPlan.foreach { case agg: HashAggregateExec => + atFirstAgg = !atFirstAgg + case _ => if (atFirstAgg) { - fail("Should not have back to back Aggregates") + fail("Should not have operators between the two aggregations") } - atFirstAgg = true - case _ => } } @@ -1292,10 +1292,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { // Group by the column we are distributed by. This should generate a plan with no exchange // between the aggregates val df3 = testData.repartition($"key").groupBy("key").count() - verifyNonExchangingSingleAgg(df3) - verifyNonExchangingSingleAgg(testData.repartition($"key", $"value") + verifyNonExchangingAgg(df3) + verifyNonExchangingAgg(testData.repartition($"key", $"value") .groupBy("key", "value").count()) - verifyNonExchangingSingleAgg(testData.repartition($"key").groupBy("key", "value").count()) // Grouping by just the first distributeBy expr, need to exchange. 
verifyExchangingAgg(testData.repartition($"key", $"value") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index b0aa3378e5..375da224aa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition} import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.aggregate.SortAggregateExec import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchange} import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec} @@ -38,84 +37,36 @@ class PlannerSuite extends SharedSQLContext { setupTestData() - private def testPartialAggregationPlan(query: LogicalPlan): Seq[SparkPlan] = { + private def testPartialAggregationPlan(query: LogicalPlan): Unit = { val planner = spark.sessionState.planner import planner._ - val ensureRequirements = EnsureRequirements(spark.sessionState.conf) - val planned = Aggregation(query).headOption.map(ensureRequirements(_)) - .getOrElse(fail(s"Could query play aggregation query $query. Is it an aggregation query?")) - planned.collect { case n if n.nodeName contains "Aggregate" => n } + val plannedOption = Aggregation(query).headOption + val planned = + plannedOption.getOrElse( + fail(s"Could query play aggregation query $query. Is it an aggregation query?")) + val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n } + + // For the new aggregation code path, there will be four aggregate operator for + // distinct aggregations. 
+ assert( + aggregations.size == 2 || aggregations.size == 4, + s"The plan of query $query does not have partial aggregations.") } test("count is partially aggregated") { val query = testData.groupBy('value).agg(count('key)).queryExecution.analyzed - assert(testPartialAggregationPlan(query).size == 2, - s"The plan of query $query does not have partial aggregations.") + testPartialAggregationPlan(query) } test("count distinct is partially aggregated") { val query = testData.groupBy('value).agg(countDistinct('key)).queryExecution.analyzed testPartialAggregationPlan(query) - // For the new aggregation code path, there will be four aggregate operator for distinct - // aggregations. - assert(testPartialAggregationPlan(query).size == 4, - s"The plan of query $query does not have partial aggregations.") } test("mixed aggregates are partially aggregated") { val query = testData.groupBy('value).agg(count('value), countDistinct('key)).queryExecution.analyzed - // For the new aggregation code path, there will be four aggregate operator for distinct - // aggregations. 
- assert(testPartialAggregationPlan(query).size == 4, - s"The plan of query $query does not have partial aggregations.") - } - - test("SPARK-17289 sort-based partial aggregation needs a sort operator as a child") { - withTempView("testSortBasedPartialAggregation") { - val schema = StructType( - StructField(s"key", IntegerType, true) :: StructField(s"value", StringType, true) :: Nil) - val rowRDD = sparkContext.parallelize((0 until 1000).map(d => Row(d % 2, d.toString))) - spark.createDataFrame(rowRDD, schema) - .createOrReplaceTempView("testSortBasedPartialAggregation") - - // This test assumes a query below uses sort-based aggregations - val planned = sql("SELECT MAX(value) FROM testSortBasedPartialAggregation GROUP BY key") - .queryExecution.executedPlan - // This line extracts both SortAggregate and Sort operators - val extractedOps = planned.collect { case n if n.nodeName contains "Sort" => n } - val aggOps = extractedOps.collect { case n if n.nodeName contains "SortAggregate" => n } - assert(extractedOps.size == 4 && aggOps.size == 2, - s"The plan $planned does not have correct sort-based partial aggregate pairs.") - } - } - - test("non-partial aggregation for aggregates") { - withTempView("testNonPartialAggregation") { - val schema = StructType(StructField(s"value", IntegerType, true) :: Nil) - val row = Row.fromSeq(Seq.fill(1)(null)) - val rowRDD = sparkContext.parallelize(row :: Nil) - spark.createDataFrame(rowRDD, schema).repartition($"value") - .createOrReplaceTempView("testNonPartialAggregation") - - val planned1 = sql("SELECT SUM(value) FROM testNonPartialAggregation GROUP BY value") - .queryExecution.executedPlan - - // If input data are already partitioned and the same columns are used in grouping keys and - // aggregation values, no partial aggregation exist in query plans. 
- val aggOps1 = planned1.collect { case n if n.nodeName contains "Aggregate" => n } - assert(aggOps1.size == 1, s"The plan $planned1 has partial aggregations.") - - val planned2 = sql( - """ - |SELECT t.value, SUM(DISTINCT t.value) - |FROM (SELECT * FROM testNonPartialAggregation ORDER BY value) t - |GROUP BY t.value - """.stripMargin).queryExecution.executedPlan - - val aggOps2 = planned1.collect { case n if n.nodeName contains "Aggregate" => n } - assert(aggOps2.size == 1, s"The plan $planned2 has partial aggregations.") - } + testPartialAggregationPlan(query) } test("sizeInBytes estimation of limit operator for broadcast hash join optimization") { |