From 8f90c151878571e20625e2a53561441ec0035dfc Mon Sep 17 00:00:00 2001
From: gatorsmile
Date: Wed, 20 Jan 2016 14:59:30 -0800
Subject: [SPARK-12616][SQL] Making Logical Operator `Union` Support Arbitrary Number of Children

The existing `Union` logical operator only supports two children. This patch
replaces it with a `Union` operator that accepts an arbitrary number of
children.

The `Union` logical plan used to be a binary node. However, a typical use case
for union is to union a very large number of input sources (DataFrames, RDDs,
or files). It is not uncommon to union hundreds of thousands of files. In such
cases, the optimizer becomes very slow due to the large number of logical
unions. This patch changes the `Union` logical plan to support an arbitrary
number of children, and adds a single optimizer rule, `CombineUnions`, to
collapse all adjacent `Union` nodes into one. Note that this problem does not
exist in the physical plan, because the physical `Union` already supports an
arbitrary number of children.

Author: gatorsmile
Author: xiaoli
Author: Xiao Li

Closes #10577 from gatorsmile/unionAllMultiChildren.
---
 .../src/main/scala/org/apache/spark/sql/DataFrame.scala   |  5 ++++-
 .../src/main/scala/org/apache/spark/sql/Dataset.scala     |  9 +++++++--
 .../org/apache/spark/sql/execution/SparkStrategies.scala  |  2 +-
 .../org/apache/spark/sql/execution/basicOperators.scala   | 11 ++++-------
 .../java/test/org/apache/spark/sql/JavaDatasetSuite.java  |  4 ++--
 .../test/scala/org/apache/spark/sql/DataFrameSuite.scala  | 16 +++++++++++++++-
 .../org/apache/spark/sql/execution/PlannerSuite.scala     | 12 ------------
 7 files changed, 33 insertions(+), 26 deletions(-)

(limited to 'sql/core')
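For illustration only, not part of the patch: the `CombineUnions` rule referenced
above is added under sql/catalyst and therefore does not appear in this
sql/core-only diff. The standalone Scala sketch below shows the collapsing idea
on a toy plan type; `Plan`, `Scan`, `flatten`, and `CombineUnionsSketch` are
names invented for this example, not Spark APIs.

sealed trait Plan
case class Scan(name: String) extends Plan
case class Union(children: Seq[Plan]) extends Plan

object CombineUnionsSketch {
  // Collapse adjacent unions: Union(Union(a, b), c) becomes Union(a, b, c).
  def flatten(plan: Plan): Plan = plan match {
    case Union(children) =>
      Union(children.map(flatten).flatMap {
        case Union(grandchildren) => grandchildren // splice nested unions in
        case other                => Seq(other)
      })
    case other => other
  }

  def main(args: Array[String]): Unit = {
    // Chained unionAll calls left-nest the plan: ((a union b) union c) union d.
    val nested =
      Union(Seq(Union(Seq(Union(Seq(Scan("a"), Scan("b"))), Scan("c"))), Scan("d")))
    println(flatten(nested)) // Union(List(Scan(a), Scan(b), Scan(c), Scan(d)))
  }
}

Flattening turns the depth-n binary tree produced by n chained `unionAll` calls
into a single n-ary node, which is what keeps the optimizer's tree traversals cheap.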
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 95e5fbb119..518f9dcf94 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst._
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.optimizer.CombineUnions
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, FileRelation, LogicalRDD, Queryable, QueryExecution, SQLExecution}
@@ -1002,7 +1003,9 @@ class DataFrame private[sql](
    * @since 1.3.0
    */
   def unionAll(other: DataFrame): DataFrame = withPlan {
-    Union(logicalPlan, other.logicalPlan)
+    // This breaks caching, but it's usually ok because it addresses a very specific use case:
+    // using union to union many files or partitions.
+    CombineUnions(Union(logicalPlan, other.logicalPlan))
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 9a9f7d111c..bd99c39957 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -19,13 +19,14 @@ package org.apache.spark.sql
 
 import scala.collection.JavaConverters._
 
-import org.apache.spark.Logging
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.function._
+import org.apache.spark.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias
 import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.optimizer.CombineUnions
 import org.apache.spark.sql.catalyst.plans.JoinType
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.{Queryable, QueryExecution}
@@ -603,7 +604,11 @@ class Dataset[T] private[sql](
    * duplicate items. As such, it is analogous to `UNION ALL` in SQL.
    * @since 1.6.0
    */
-  def union(other: Dataset[T]): Dataset[T] = withPlan[T](other)(Union)
+  def union(other: Dataset[T]): Dataset[T] = withPlan[T](other){ (left, right) =>
+    // This breaks caching, but it's usually ok because it addresses a very specific use case:
+    // using union to union many files or partitions.
+    CombineUnions(Union(left, right))
+  }
 
   /**
    * Returns a new [[Dataset]] where any elements present in `other` have been removed.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index c4ddb6d76b..60fbb595e5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -336,7 +336,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         LocalTableScan(output, data) :: Nil
       case logical.Limit(IntegerLiteral(limit), child) =>
         execution.Limit(limit, planLater(child)) :: Nil
-      case Unions(unionChildren) =>
+      case logical.Union(unionChildren) =>
         execution.Union(unionChildren.map(planLater)) :: Nil
       case logical.Except(left, right) =>
         execution.Except(planLater(left), planLater(right)) :: Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 9e2e0357c6..6deb72adad 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -281,13 +281,10 @@ case class Range(
  * Union two plans, without a distinct. This is UNION ALL in SQL.
  */
 case class Union(children: Seq[SparkPlan]) extends SparkPlan {
-  override def output: Seq[Attribute] = {
-    children.tail.foldLeft(children.head.output) { case (currentOutput, child) =>
-      currentOutput.zip(child.output).map { case (a1, a2) =>
-        a1.withNullability(a1.nullable || a2.nullable)
-      }
-    }
-  }
+  override def output: Seq[Attribute] =
+    children.map(_.output).transpose.map(attrs =>
+      attrs.head.withNullability(attrs.exists(_.nullable)))
+
   protected override def doExecute(): RDD[InternalRow] =
     sparkContext.union(children.map(_.execute()))
 }
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 1a3df1b117..3c0f25a5dc 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -298,9 +298,9 @@ public class JavaDatasetSuite implements Serializable {
     Dataset<String> intersected = ds.intersect(ds2);
     Assert.assertEquals(Arrays.asList("xyz"), intersected.collectAsList());
 
-    Dataset<String> unioned = ds.union(ds2);
+    Dataset<String> unioned = ds.union(ds2).union(ds);
     Assert.assertEquals(
-        Arrays.asList("abc", "abc", "xyz", "xyz", "foo", "foo"),
+        Arrays.asList("abc", "abc", "xyz", "xyz", "foo", "foo", "abc", "abc", "xyz"),
         unioned.collectAsList());
 
     Dataset<String> subtracted = ds.subtract(ds2);
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index bd11a387a1..09bbe57a43 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -25,7 +25,7 @@ import scala.util.Random
 import org.scalatest.Matchers._
 
 import org.apache.spark.SparkException
-import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation
+import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Union}
 import org.apache.spark.sql.execution.Exchange
 import org.apache.spark.sql.execution.aggregate.TungstenAggregate
 import org.apache.spark.sql.functions._
@@ -98,6 +98,20 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       testData.collect().toSeq)
   }
 
+  test("union all") {
+    val unionDF = testData.unionAll(testData).unionAll(testData)
+      .unionAll(testData).unionAll(testData)
+
+    // Before optimizer, Union should be combined.
+    assert(unionDF.queryExecution.analyzed.collect {
+      case j: Union if j.children.size == 5 => j }.size === 1)
+
+    checkAnswer(
+      unionDF.agg(avg('key), max('key), min('key), sum('key)),
+      Row(50.5, 100, 1, 25250) :: Nil
+    )
+  }
+
   test("empty data frame") {
     assert(sqlContext.emptyDataFrame.columns.toSeq === Seq.empty[String])
     assert(sqlContext.emptyDataFrame.count() === 0)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 49feeaf17d..8fca5e2167 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -51,18 +51,6 @@ class PlannerSuite extends SharedSQLContext {
       s"The plan of query $query does not have partial aggregations.")
   }
 
-  test("unions are collapsed") {
-    val planner = sqlContext.planner
-    import planner._
-    val query = testData.unionAll(testData).unionAll(testData).logicalPlan
-    val planned = BasicOperators(query).head
-    val logicalUnions = query collect { case u: logical.Union => u }
-    val physicalUnions = planned collect { case u: execution.Union => u }
-
-    assert(logicalUnions.size === 2)
-    assert(physicalUnions.size === 1)
-  }
-
   test("count is partially aggregated") {
     val query = testData.groupBy('value).agg(count('key)).queryExecution.analyzed
     testPartialAggregationPlan(query)
-- 
cgit v1.2.3
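One more illustration, also not part of the patch: the rewritten physical
`Union.output` marks each output column nullable if it is nullable in any
child. Below is a minimal sketch of that transpose-and-OR shape, with an
invented `Attr` type standing in for Spark's much richer `Attribute`.

object UnionNullabilitySketch {
  // Invented stand-in for Spark's Attribute: just a name and a nullability flag.
  case class Attr(name: String, nullable: Boolean)

  def main(args: Array[String]): Unit = {
    // Two children with the same schema but different per-column nullability.
    val childOutputs = Seq(
      Seq(Attr("id", nullable = false), Attr("v", nullable = true)),
      Seq(Attr("id", nullable = true), Attr("v", nullable = false)))

    // Group attributes by column position, then OR nullability across children,
    // mirroring the shape of the patched Union.output.
    val merged = childOutputs.transpose.map(attrs =>
      attrs.head.copy(nullable = attrs.exists(_.nullable)))

    println(merged) // List(Attr(id,true), Attr(v,true))
  }
}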