author    gatorsmile <gatorsmile@gmail.com>    2016-01-20 14:59:30 -0800
committer Reynold Xin <rxin@databricks.com>    2016-01-20 14:59:30 -0800
commit    8f90c151878571e20625e2a53561441ec0035dfc
tree      b9b4354468e5e2f220c14ac520a960c94e0274b5
parent    b7d74a602f622d8e105b349bd6d17ba42e7668dc
[SPARK-12616][SQL] Making Logical Operator `Union` Support Arbitrary Number of Children
The existing `Union` logical operator is a binary node that only supports two children. However, a typical use case for union is to combine a very large number of input sources (DataFrames, RDDs, or files); it is not uncommon to union hundreds of thousands of files, and with that many nested logical unions the optimizer can become very slow.

This patch changes the `Union` logical plan to support an arbitrary number of children and adds a single optimizer rule, `CombineUnions`, that collapses all adjacent `Union` operators into one. Note that this problem does not exist in the physical plan, because the physical `Union` already supports an arbitrary number of children.

Author: gatorsmile <gatorsmile@gmail.com>
Author: xiaoli <lixiao1983@gmail.com>
Author: Xiao Li <xiaoli@Xiaos-MacBook-Pro.local>

Closes #10577 from gatorsmile/unionAllMultiChildren.
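The heart of the change can be sketched outside of Catalyst. The snippet below is an illustrative, self-contained Scala sketch of the flattening idea (the Plan/Leaf ADT and FlattenUnionsDemo are made-up stand-ins, not Spark types); the real implementation is the n-ary Union node in basicOperators.scala, the Unions extractor in patterns.scala, and the CombineUnions rule shown in the diff below.

import scala.annotation.tailrec
import scala.collection.mutable

object FlattenUnionsDemo {
  // Toy logical-plan ADT standing in for Catalyst's LogicalPlan hierarchy.
  sealed trait Plan
  case class Leaf(name: String) extends Plan
  case class Union(children: Seq[Plan]) extends Plan

  // Collect the non-Union descendants of a Union subtree depth-first,
  // mirroring the stack-based collectUnionChildren in patterns.scala.
  @tailrec
  private def collect(stack: mutable.Stack[Plan], acc: Seq[Plan]): Seq[Plan] =
    if (stack.isEmpty) acc
    else stack.pop() match {
      case Union(grandchildren) =>
        grandchildren.reverse.foreach(stack.push)
        collect(stack, acc)
      case other => collect(stack, acc :+ other)
    }

  // Collapse a nest of adjacent Unions into a single n-ary Union.
  def flatten(plan: Plan): Plan = plan match {
    case u: Union => Union(collect(mutable.Stack[Plan](u), Seq.empty))
    case other    => other
  }

  def main(args: Array[String]): Unit = {
    // Union(Union(a, b), Union(c, d)) collapses to Union(a, b, c, d).
    val nested = Union(Seq(
      Union(Seq(Leaf("a"), Leaf("b"))),
      Union(Seq(Leaf("c"), Leaf("d")))))
    println(flatten(nested))  // Union(List(Leaf(a), Leaf(b), Leaf(c), Leaf(d)))
  }
}

With binary unions, unioning N inputs produces a left-deep tree of N - 1 Union nodes that every optimizer rule has to walk; after flattening, the optimizer sees a single node with N children, which is why DataFrame.unionAll and Dataset.union below eagerly call CombineUnions on the freshly built plan.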
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala                     |  4
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala              | 12
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala         |  8
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala      | 71
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala            | 71
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala              | 23
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala   | 47
 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala         |  6
 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala |  2
 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala | 82
 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala (renamed from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala) | 43
 sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala                                   |  5
 sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala                                     |  9
 sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala                   |  2
 sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala                    | 11
 sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java                         |  4
 sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala                              | 16
 sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala                      | 12
 sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala                             | 12
 sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala                  |  4
 20 files changed, 322 insertions(+), 122 deletions(-)
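Besides the flattening itself, the patch generalizes how a union derives its output schema: rather than folding nullability pairwise, both the new logical Union and the rewritten physical Union transpose their children's outputs and mark a column nullable whenever any child's column is nullable. A minimal sketch of that computation, assuming the illustrative Attr type below (not a Spark class):

object UnionOutputDemo {
  case class Attr(name: String, nullable: Boolean)

  // A union column is nullable if the corresponding column of any child is nullable,
  // mirroring children.map(_.output).transpose in the patch.
  def unionOutput(childOutputs: Seq[Seq[Attr]]): Seq[Attr] =
    childOutputs.transpose.map(cols => cols.head.copy(nullable = cols.exists(_.nullable)))

  def main(args: Array[String]): Unit = {
    val left  = Seq(Attr("id", nullable = false), Attr("name", nullable = true))
    val right = Seq(Attr("id", nullable = true),  Attr("name", nullable = false))
    println(unionOutput(Seq(left, right)))  // List(Attr(id,true), Attr(name,true))
  }
}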
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala
index 5fb41f7e4b..35273c7e24 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala
@@ -402,8 +402,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
overwrite)
}
- // If there are multiple INSERTS just UNION them together into on query.
- val query = queries.reduceLeft(Union)
+ // If there are multiple INSERTS just UNION them together into one query.
+ val query = if (queries.length == 1) queries.head else Union(queries)
// return With plan if there is CTE
cteRelations.map(With(query, _)).getOrElse(query)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index d4b4bc88b3..33d76eeb21 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -66,7 +66,8 @@ class Analyzer(
lazy val batches: Seq[Batch] = Seq(
Batch("Substitution", fixedPoint,
CTESubstitution,
- WindowsSubstitution),
+ WindowsSubstitution,
+ EliminateUnions),
Batch("Resolution", fixedPoint,
ResolveRelations ::
ResolveReferences ::
@@ -1171,6 +1172,15 @@ object EliminateSubQueries extends Rule[LogicalPlan] {
}
/**
+ * Removes [[Union]] operators from the plan if it just has one child.
+ */
+object EliminateUnions extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case Union(children) if children.size == 1 => children.head
+ }
+}
+
+/**
* Cleans up unnecessary Aliases inside the plan. Basically we only need Alias as a top level
* expression in Project(project list) or Aggregate(aggregate expressions) or
* Window(window expressions).
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 2a2e0d27d9..f2e78d9744 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -189,6 +189,14 @@ trait CheckAnalysis {
s"but the left table has ${left.output.length} columns and the right has " +
s"${right.output.length}")
+ case s: Union if s.children.exists(_.output.length != s.children.head.output.length) =>
+ val firstError = s.children.find(_.output.length != s.children.head.output.length).get
+ failAnalysis(
+ s"""
+ |Unions can only be performed on tables with the same number of columns,
+ | but one table has '${firstError.output.length}' columns and another table has
+ | '${s.children.head.output.length}' columns""".stripMargin)
+
case _ => // Fallbacks to the following checks
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 7df3787e6d..c557c32319 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -19,6 +19,9 @@ package org.apache.spark.sql.catalyst.analysis
import javax.annotation.Nullable
+import scala.annotation.tailrec
+import scala.collection.mutable
+
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical._
@@ -27,7 +30,7 @@ import org.apache.spark.sql.types._
/**
- * A collection of [[Rule Rules]] that can be used to coerce differing types that participate in
+ * A collection of [[Rule]] that can be used to coerce differing types that participate in
* operations into compatible ones.
*
* Most of these rules are based on Hive semantics, but they do not introduce any dependencies on
@@ -219,31 +222,59 @@ object HiveTypeCoercion {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case p if p.analyzed => p
- case s @ SetOperation(left, right) if s.childrenResolved
- && left.output.length == right.output.length && !s.resolved =>
+ case s @ SetOperation(left, right) if s.childrenResolved &&
+ left.output.length == right.output.length && !s.resolved =>
+ val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil)
+ assert(newChildren.length == 2)
+ s.makeCopy(Array(newChildren.head, newChildren.last))
- // Tracks the list of data types to widen.
- // Some(dataType) means the right-hand side and the left-hand side have different types,
- // and there is a target type to widen both sides to.
- val targetTypes: Seq[Option[DataType]] = left.output.zip(right.output).map {
- case (lhs, rhs) if lhs.dataType != rhs.dataType =>
- findWiderTypeForTwo(lhs.dataType, rhs.dataType)
- case other => None
- }
+ case s: Union if s.childrenResolved &&
+ s.children.forall(_.output.length == s.children.head.output.length) && !s.resolved =>
+ val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(s.children)
+ s.makeCopy(Array(newChildren))
+ }
- if (targetTypes.exists(_.isDefined)) {
- // There is at least one column to widen.
- s.makeCopy(Array(widenTypes(left, targetTypes), widenTypes(right, targetTypes)))
- } else {
- // If we cannot find any column to widen, then just return the original set.
- s
- }
+ /** Build new children with the widest types for each attribute among all the children */
+ private def buildNewChildrenWithWiderTypes(children: Seq[LogicalPlan]): Seq[LogicalPlan] = {
+ require(children.forall(_.output.length == children.head.output.length))
+
+ // Get a sequence of data types, each of which is the widest type of this specific attribute
+ // in all the children
+ val targetTypes: Seq[DataType] =
+ getWidestTypes(children, attrIndex = 0, mutable.Queue[DataType]())
+
+ if (targetTypes.nonEmpty) {
+ // Add an extra Project if the targetTypes are different from the original types.
+ children.map(widenTypes(_, targetTypes))
+ } else {
+ // Unable to find a target type to widen, then just return the original set.
+ children
+ }
+ }
+
+ /** Get the widest type for each attribute in all the children */
+ @tailrec private def getWidestTypes(
+ children: Seq[LogicalPlan],
+ attrIndex: Int,
+ castedTypes: mutable.Queue[DataType]): Seq[DataType] = {
+ // Return the result after the widen data types have been found for all the children
+ if (attrIndex >= children.head.output.length) return castedTypes.toSeq
+
+ // For the attrIndex-th attribute, find the widest type
+ findWiderCommonType(children.map(_.output(attrIndex).dataType)) match {
+ // If unable to find an appropriate widen type for this column, return an empty Seq
+ case None => Seq.empty[DataType]
+ // Otherwise, record the result in the queue and find the type for the next column
+ case Some(widenType) =>
+ castedTypes.enqueue(widenType)
+ getWidestTypes(children, attrIndex + 1, castedTypes)
+ }
}
/** Given a plan, add an extra project on top to widen some columns' data types. */
- private def widenTypes(plan: LogicalPlan, targetTypes: Seq[Option[DataType]]): LogicalPlan = {
+ private def widenTypes(plan: LogicalPlan, targetTypes: Seq[DataType]): LogicalPlan = {
val casted = plan.output.zip(targetTypes).map {
- case (e, Some(dt)) if e.dataType != dt => Alias(Cast(e, dt), e.name)()
+ case (e, dt) if e.dataType != dt => Alias(Cast(e, dt), e.name)()
case (e, _) => e
}
Project(casted, plan)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 04643f0274..44455b4820 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, EliminateSubQueri
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins
+import org.apache.spark.sql.catalyst.planning.{ExtractFiltersAndInnerJoins, Unions}
import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftOuter, LeftSemi, RightOuter}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
@@ -45,6 +45,13 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
//////////////////////////////////////////////////////////////////////////////////////////
// Optimizer rules start here
//////////////////////////////////////////////////////////////////////////////////////////
+ // - Do the first call of CombineUnions before starting the major Optimizer rules,
+ // since it can reduce the number of iteration and the other rules could add/move
+ // extra operators between two adjacent Union operators.
+ // - Call CombineUnions again in Batch("Operator Optimizations"),
+ // since the other rules might make two separate Unions operators adjacent.
+ Batch("Union", Once,
+ CombineUnions) ::
Batch("Aggregate", FixedPoint(100),
ReplaceDistinctWithAggregate,
RemoveLiteralFromGroupExpressions) ::
@@ -62,6 +69,7 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
ProjectCollapsing,
CombineFilters,
CombineLimits,
+ CombineUnions,
// Constant folding and strength reduction
NullPropagation,
OptimizeIn,
@@ -138,11 +146,9 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
/**
* Maps Attributes from the left side to the corresponding Attribute on the right side.
*/
- private def buildRewrites(bn: BinaryNode): AttributeMap[Attribute] = {
- assert(bn.isInstanceOf[Union] || bn.isInstanceOf[Intersect] || bn.isInstanceOf[Except])
- assert(bn.left.output.size == bn.right.output.size)
-
- AttributeMap(bn.left.output.zip(bn.right.output))
+ private def buildRewrites(left: LogicalPlan, right: LogicalPlan): AttributeMap[Attribute] = {
+ assert(left.output.size == right.output.size)
+ AttributeMap(left.output.zip(right.output))
}
/**
@@ -176,32 +182,38 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
}
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- // Push down filter into union
- case Filter(condition, u @ Union(left, right)) =>
- val (deterministic, nondeterministic) = partitionByDeterministic(condition)
- val rewrites = buildRewrites(u)
- Filter(nondeterministic,
- Union(
- Filter(deterministic, left),
- Filter(pushToRight(deterministic, rewrites), right)
- )
- )
// Push down deterministic projection through UNION ALL
- case p @ Project(projectList, u @ Union(left, right)) =>
+ case p @ Project(projectList, Union(children)) =>
+ assert(children.nonEmpty)
if (projectList.forall(_.deterministic)) {
- val rewrites = buildRewrites(u)
- Union(
- Project(projectList, left),
- Project(projectList.map(pushToRight(_, rewrites)), right))
+ val newFirstChild = Project(projectList, children.head)
+ val newOtherChildren = children.tail.map ( child => {
+ val rewrites = buildRewrites(children.head, child)
+ Project(projectList.map(pushToRight(_, rewrites)), child)
+ } )
+ Union(newFirstChild +: newOtherChildren)
} else {
p
}
+ // Push down filter into union
+ case Filter(condition, Union(children)) =>
+ assert(children.nonEmpty)
+ val (deterministic, nondeterministic) = partitionByDeterministic(condition)
+ val newFirstChild = Filter(deterministic, children.head)
+ val newOtherChildren = children.tail.map {
+ child => {
+ val rewrites = buildRewrites(children.head, child)
+ Filter(pushToRight(deterministic, rewrites), child)
+ }
+ }
+ Filter(nondeterministic, Union(newFirstChild +: newOtherChildren))
+
// Push down filter through INTERSECT
- case Filter(condition, i @ Intersect(left, right)) =>
+ case Filter(condition, Intersect(left, right)) =>
val (deterministic, nondeterministic) = partitionByDeterministic(condition)
- val rewrites = buildRewrites(i)
+ val rewrites = buildRewrites(left, right)
Filter(nondeterministic,
Intersect(
Filter(deterministic, left),
@@ -210,9 +222,9 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
)
// Push down filter through EXCEPT
- case Filter(condition, e @ Except(left, right)) =>
+ case Filter(condition, Except(left, right)) =>
val (deterministic, nondeterministic) = partitionByDeterministic(condition)
- val rewrites = buildRewrites(e)
+ val rewrites = buildRewrites(left, right)
Filter(nondeterministic,
Except(
Filter(deterministic, left),
@@ -663,6 +675,15 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
}
/**
+ * Combines all adjacent [[Union]] operators into a single [[Union]].
+ */
+object CombineUnions extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case Unions(children) => Union(children)
+ }
+}
+
+/**
* Combines two adjacent [[Filter]] operators into one, merging the
* conditions into one conjunctive predicate.
*/
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index cd3f15cbe1..f0ee124e88 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -17,6 +17,9 @@
package org.apache.spark.sql.catalyst.planning
+import scala.annotation.tailrec
+import scala.collection.mutable
+
import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
@@ -170,17 +173,29 @@ object ExtractFiltersAndInnerJoins extends PredicateHelper {
}
}
+
/**
* A pattern that collects all adjacent unions and returns their children as a Seq.
*/
object Unions {
def unapply(plan: LogicalPlan): Option[Seq[LogicalPlan]] = plan match {
- case u: Union => Some(collectUnionChildren(u))
+ case u: Union => Some(collectUnionChildren(mutable.Stack(u), Seq.empty[LogicalPlan]))
case _ => None
}
- private def collectUnionChildren(plan: LogicalPlan): Seq[LogicalPlan] = plan match {
- case Union(l, r) => collectUnionChildren(l) ++ collectUnionChildren(r)
- case other => other :: Nil
+ // Doing a depth-first tree traversal to combine all the union children.
+ @tailrec
+ private def collectUnionChildren(
+ plans: mutable.Stack[LogicalPlan],
+ children: Seq[LogicalPlan]): Seq[LogicalPlan] = {
+ if (plans.isEmpty) children
+ else {
+ plans.pop match {
+ case Union(grandchildren) =>
+ grandchildren.reverseMap(plans.push(_))
+ collectUnionChildren(plans, children)
+ case other => collectUnionChildren(plans, children :+ other)
+ }
+ }
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index f4a3d85d2a..e9c970cd08 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -101,19 +101,6 @@ private[sql] object SetOperation {
def unapply(p: SetOperation): Option[(LogicalPlan, LogicalPlan)] = Some((p.left, p.right))
}
-case class Union(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) {
-
- override def output: Seq[Attribute] =
- left.output.zip(right.output).map { case (leftAttr, rightAttr) =>
- leftAttr.withNullability(leftAttr.nullable || rightAttr.nullable)
- }
-
- override def statistics: Statistics = {
- val sizeInBytes = left.statistics.sizeInBytes + right.statistics.sizeInBytes
- Statistics(sizeInBytes = sizeInBytes)
- }
-}
-
case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) {
override def output: Seq[Attribute] =
@@ -127,6 +114,40 @@ case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(le
override def output: Seq[Attribute] = left.output
}
+/** Factory for constructing new `Union` nodes. */
+object Union {
+ def apply(left: LogicalPlan, right: LogicalPlan): Union = {
+ Union (left :: right :: Nil)
+ }
+}
+
+case class Union(children: Seq[LogicalPlan]) extends LogicalPlan {
+
+ // updating nullability to make all the children consistent
+ override def output: Seq[Attribute] =
+ children.map(_.output).transpose.map(attrs =>
+ attrs.head.withNullability(attrs.exists(_.nullable)))
+
+ override lazy val resolved: Boolean = {
+ // allChildrenCompatible needs to be evaluated after childrenResolved
+ def allChildrenCompatible: Boolean =
+ children.tail.forall( child =>
+ // compare the attribute number with the first child
+ child.output.length == children.head.output.length &&
+ // compare the data types with the first child
+ child.output.zip(children.head.output).forall {
+ case (l, r) => l.dataType == r.dataType }
+ )
+
+ children.length > 1 && childrenResolved && allChildrenCompatible
+ }
+
+ override def statistics: Statistics = {
+ val sizeInBytes = children.map(_.statistics.sizeInBytes).sum
+ Statistics(sizeInBytes = sizeInBytes)
+ }
+}
+
case class Join(
left: LogicalPlan,
right: LogicalPlan,
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 975cd87d09..ab68028220 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -237,6 +237,12 @@ class AnalysisSuite extends AnalysisTest {
checkAnalysis(plan, expected)
}
+ test("Eliminate the unnecessary union") {
+ val plan = Union(testRelation :: Nil)
+ val expected = testRelation
+ checkAnalysis(plan, expected)
+ }
+
test("SPARK-12102: Ignore nullablity when comparing two sides of case") {
val relation = LocalRelation('a.struct('x.int), 'b.struct('x.int.withNullability(false)))
val plan = relation.select(CaseWhen(Seq((Literal(true), 'a.attr)), 'b).as("val"))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
index 39c8f56c1b..24c608eaa5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
@@ -70,7 +70,7 @@ class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter {
Union(Project(Seq(Alias(left, "l")()), relation),
Project(Seq(Alias(right, "r")()), relation))
val (l, r) = analyzer.execute(plan).collect {
- case Union(left, right) => (left.output.head, right.output.head)
+ case Union(Seq(child1, child2)) => (child1.output.head, child2.output.head)
}.head
assert(l.dataType === expectedType)
assert(r.dataType === expectedType)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index b326aa9c55..c30434a006 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -387,19 +387,19 @@ class HiveTypeCoercionSuite extends PlanTest {
)
}
- test("WidenSetOperationTypes for union, except, and intersect") {
- def checkOutput(logical: LogicalPlan, expectTypes: Seq[DataType]): Unit = {
- logical.output.zip(expectTypes).foreach { case (attr, dt) =>
- assert(attr.dataType === dt)
- }
+ private def checkOutput(logical: LogicalPlan, expectTypes: Seq[DataType]): Unit = {
+ logical.output.zip(expectTypes).foreach { case (attr, dt) =>
+ assert(attr.dataType === dt)
}
+ }
- val left = LocalRelation(
+ test("WidenSetOperationTypes for except and intersect") {
+ val firstTable = LocalRelation(
AttributeReference("i", IntegerType)(),
AttributeReference("u", DecimalType.SYSTEM_DEFAULT)(),
AttributeReference("b", ByteType)(),
AttributeReference("d", DoubleType)())
- val right = LocalRelation(
+ val secondTable = LocalRelation(
AttributeReference("s", StringType)(),
AttributeReference("d", DecimalType(2, 1))(),
AttributeReference("f", FloatType)(),
@@ -408,15 +408,65 @@ class HiveTypeCoercionSuite extends PlanTest {
val wt = HiveTypeCoercion.WidenSetOperationTypes
val expectedTypes = Seq(StringType, DecimalType.SYSTEM_DEFAULT, FloatType, DoubleType)
- val r1 = wt(Union(left, right)).asInstanceOf[Union]
- val r2 = wt(Except(left, right)).asInstanceOf[Except]
- val r3 = wt(Intersect(left, right)).asInstanceOf[Intersect]
+ val r1 = wt(Except(firstTable, secondTable)).asInstanceOf[Except]
+ val r2 = wt(Intersect(firstTable, secondTable)).asInstanceOf[Intersect]
checkOutput(r1.left, expectedTypes)
checkOutput(r1.right, expectedTypes)
checkOutput(r2.left, expectedTypes)
checkOutput(r2.right, expectedTypes)
- checkOutput(r3.left, expectedTypes)
- checkOutput(r3.right, expectedTypes)
+
+ // Check if a Project is added
+ assert(r1.left.isInstanceOf[Project])
+ assert(r1.right.isInstanceOf[Project])
+ assert(r2.left.isInstanceOf[Project])
+ assert(r2.right.isInstanceOf[Project])
+
+ val r3 = wt(Except(firstTable, firstTable)).asInstanceOf[Except]
+ checkOutput(r3.left, Seq(IntegerType, DecimalType.SYSTEM_DEFAULT, ByteType, DoubleType))
+ checkOutput(r3.right, Seq(IntegerType, DecimalType.SYSTEM_DEFAULT, ByteType, DoubleType))
+
+ // Check if no Project is added
+ assert(r3.left.isInstanceOf[LocalRelation])
+ assert(r3.right.isInstanceOf[LocalRelation])
+ }
+
+ test("WidenSetOperationTypes for union") {
+ val firstTable = LocalRelation(
+ AttributeReference("i", IntegerType)(),
+ AttributeReference("u", DecimalType.SYSTEM_DEFAULT)(),
+ AttributeReference("b", ByteType)(),
+ AttributeReference("d", DoubleType)())
+ val secondTable = LocalRelation(
+ AttributeReference("s", StringType)(),
+ AttributeReference("d", DecimalType(2, 1))(),
+ AttributeReference("f", FloatType)(),
+ AttributeReference("l", LongType)())
+ val thirdTable = LocalRelation(
+ AttributeReference("m", StringType)(),
+ AttributeReference("n", DecimalType.SYSTEM_DEFAULT)(),
+ AttributeReference("p", FloatType)(),
+ AttributeReference("q", DoubleType)())
+ val forthTable = LocalRelation(
+ AttributeReference("m", StringType)(),
+ AttributeReference("n", DecimalType.SYSTEM_DEFAULT)(),
+ AttributeReference("p", ByteType)(),
+ AttributeReference("q", DoubleType)())
+
+ val wt = HiveTypeCoercion.WidenSetOperationTypes
+ val expectedTypes = Seq(StringType, DecimalType.SYSTEM_DEFAULT, FloatType, DoubleType)
+
+ val unionRelation = wt(
+ Union(firstTable :: secondTable :: thirdTable :: forthTable :: Nil)).asInstanceOf[Union]
+ assert(unionRelation.children.length == 4)
+ checkOutput(unionRelation.children.head, expectedTypes)
+ checkOutput(unionRelation.children(1), expectedTypes)
+ checkOutput(unionRelation.children(2), expectedTypes)
+ checkOutput(unionRelation.children(3), expectedTypes)
+
+ assert(unionRelation.children.head.isInstanceOf[Project])
+ assert(unionRelation.children(1).isInstanceOf[Project])
+ assert(unionRelation.children(2).isInstanceOf[Project])
+ assert(unionRelation.children(3).isInstanceOf[Project])
}
test("Transform Decimal precision/scale for union except and intersect") {
@@ -438,8 +488,8 @@ class HiveTypeCoercionSuite extends PlanTest {
val r2 = dp(Except(left1, right1)).asInstanceOf[Except]
val r3 = dp(Intersect(left1, right1)).asInstanceOf[Intersect]
- checkOutput(r1.left, expectedType1)
- checkOutput(r1.right, expectedType1)
+ checkOutput(r1.children.head, expectedType1)
+ checkOutput(r1.children.last, expectedType1)
checkOutput(r2.left, expectedType1)
checkOutput(r2.right, expectedType1)
checkOutput(r3.left, expectedType1)
@@ -459,7 +509,7 @@ class HiveTypeCoercionSuite extends PlanTest {
val r2 = dp(Except(plan1, plan2)).asInstanceOf[Except]
val r3 = dp(Intersect(plan1, plan2)).asInstanceOf[Intersect]
- checkOutput(r1.right, Seq(expectedType))
+ checkOutput(r1.children.last, Seq(expectedType))
checkOutput(r2.right, Seq(expectedType))
checkOutput(r3.right, Seq(expectedType))
@@ -467,7 +517,7 @@ class HiveTypeCoercionSuite extends PlanTest {
val r5 = dp(Except(plan2, plan1)).asInstanceOf[Except]
val r6 = dp(Intersect(plan2, plan1)).asInstanceOf[Intersect]
- checkOutput(r4.left, Seq(expectedType))
+ checkOutput(r4.children.last, Seq(expectedType))
checkOutput(r5.left, Seq(expectedType))
checkOutput(r6.left, Seq(expectedType))
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
index a498b463a6..2283f7c008 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
@@ -24,48 +24,73 @@ import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
-class SetOperationPushDownSuite extends PlanTest {
+class SetOperationSuite extends PlanTest {
object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
Batch("Subqueries", Once,
EliminateSubQueries) ::
Batch("Union Pushdown", Once,
+ CombineUnions,
SetOperationPushDown,
SimplifyFilters) :: Nil
}
val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int)
- val testUnion = Union(testRelation, testRelation2)
+ val testRelation3 = LocalRelation('g.int, 'h.int, 'i.int)
+ val testUnion = Union(testRelation :: testRelation2 :: testRelation3 :: Nil)
val testIntersect = Intersect(testRelation, testRelation2)
val testExcept = Except(testRelation, testRelation2)
- test("union/intersect/except: filter to each side") {
- val unionQuery = testUnion.where('a === 1)
+ test("union: combine unions into one unions") {
+ val unionQuery1 = Union(Union(testRelation, testRelation2), testRelation)
+ val unionQuery2 = Union(testRelation, Union(testRelation2, testRelation))
+ val unionOptimized1 = Optimize.execute(unionQuery1.analyze)
+ val unionOptimized2 = Optimize.execute(unionQuery2.analyze)
+
+ comparePlans(unionOptimized1, unionOptimized2)
+
+ val combinedUnions = Union(unionOptimized1 :: unionOptimized2 :: Nil)
+ val combinedUnionsOptimized = Optimize.execute(combinedUnions.analyze)
+ val unionQuery3 = Union(unionQuery1, unionQuery2)
+ val unionOptimized3 = Optimize.execute(unionQuery3.analyze)
+ comparePlans(combinedUnionsOptimized, unionOptimized3)
+ }
+
+ test("intersect/except: filter to each side") {
val intersectQuery = testIntersect.where('b < 10)
val exceptQuery = testExcept.where('c >= 5)
- val unionOptimized = Optimize.execute(unionQuery.analyze)
val intersectOptimized = Optimize.execute(intersectQuery.analyze)
val exceptOptimized = Optimize.execute(exceptQuery.analyze)
- val unionCorrectAnswer =
- Union(testRelation.where('a === 1), testRelation2.where('d === 1)).analyze
val intersectCorrectAnswer =
Intersect(testRelation.where('b < 10), testRelation2.where('e < 10)).analyze
val exceptCorrectAnswer =
Except(testRelation.where('c >= 5), testRelation2.where('f >= 5)).analyze
- comparePlans(unionOptimized, unionCorrectAnswer)
comparePlans(intersectOptimized, intersectCorrectAnswer)
comparePlans(exceptOptimized, exceptCorrectAnswer)
}
+ test("union: filter to each side") {
+ val unionQuery = testUnion.where('a === 1)
+ val unionOptimized = Optimize.execute(unionQuery.analyze)
+ val unionCorrectAnswer =
+ Union(testRelation.where('a === 1) ::
+ testRelation2.where('d === 1) ::
+ testRelation3.where('g === 1) :: Nil).analyze
+
+ comparePlans(unionOptimized, unionCorrectAnswer)
+ }
+
test("union: project to each side") {
val unionQuery = testUnion.select('a)
val unionOptimized = Optimize.execute(unionQuery.analyze)
val unionCorrectAnswer =
- Union(testRelation.select('a), testRelation2.select('d)).analyze
+ Union(testRelation.select('a) ::
+ testRelation2.select('d) ::
+ testRelation3.select('g) :: Nil).analyze
comparePlans(unionOptimized, unionCorrectAnswer)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 95e5fbb119..518f9dcf94 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.optimizer.CombineUnions
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, FileRelation, LogicalRDD, Queryable, QueryExecution, SQLExecution}
@@ -1002,7 +1003,9 @@ class DataFrame private[sql](
* @since 1.3.0
*/
def unionAll(other: DataFrame): DataFrame = withPlan {
- Union(logicalPlan, other.logicalPlan)
+ // This breaks caching, but it's usually ok because it addresses a very specific use case:
+ // using union to union many files or partitions.
+ CombineUnions(Union(logicalPlan, other.logicalPlan))
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 9a9f7d111c..bd99c39957 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -19,13 +19,14 @@ package org.apache.spark.sql
import scala.collection.JavaConverters._
-import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.function._
+import org.apache.spark.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.optimizer.CombineUnions
import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.{Queryable, QueryExecution}
@@ -603,7 +604,11 @@ class Dataset[T] private[sql](
* duplicate items. As such, it is analogous to `UNION ALL` in SQL.
* @since 1.6.0
*/
- def union(other: Dataset[T]): Dataset[T] = withPlan[T](other)(Union)
+ def union(other: Dataset[T]): Dataset[T] = withPlan[T](other){ (left, right) =>
+ // This breaks caching, but it's usually ok because it addresses a very specific use case:
+ // using union to union many files or partitions.
+ CombineUnions(Union(left, right))
+ }
/**
* Returns a new [[Dataset]] where any elements present in `other` have been removed.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index c4ddb6d76b..60fbb595e5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -336,7 +336,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
LocalTableScan(output, data) :: Nil
case logical.Limit(IntegerLiteral(limit), child) =>
execution.Limit(limit, planLater(child)) :: Nil
- case Unions(unionChildren) =>
+ case logical.Union(unionChildren) =>
execution.Union(unionChildren.map(planLater)) :: Nil
case logical.Except(left, right) =>
execution.Except(planLater(left), planLater(right)) :: Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 9e2e0357c6..6deb72adad 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -281,13 +281,10 @@ case class Range(
* Union two plans, without a distinct. This is UNION ALL in SQL.
*/
case class Union(children: Seq[SparkPlan]) extends SparkPlan {
- override def output: Seq[Attribute] = {
- children.tail.foldLeft(children.head.output) { case (currentOutput, child) =>
- currentOutput.zip(child.output).map { case (a1, a2) =>
- a1.withNullability(a1.nullable || a2.nullable)
- }
- }
- }
+ override def output: Seq[Attribute] =
+ children.map(_.output).transpose.map(attrs =>
+ attrs.head.withNullability(attrs.exists(_.nullable)))
+
protected override def doExecute(): RDD[InternalRow] =
sparkContext.union(children.map(_.execute()))
}
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 1a3df1b117..3c0f25a5dc 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -298,9 +298,9 @@ public class JavaDatasetSuite implements Serializable {
Dataset<String> intersected = ds.intersect(ds2);
Assert.assertEquals(Arrays.asList("xyz"), intersected.collectAsList());
- Dataset<String> unioned = ds.union(ds2);
+ Dataset<String> unioned = ds.union(ds2).union(ds);
Assert.assertEquals(
- Arrays.asList("abc", "abc", "xyz", "xyz", "foo", "foo"),
+ Arrays.asList("abc", "abc", "xyz", "xyz", "foo", "foo", "abc", "abc", "xyz"),
unioned.collectAsList());
Dataset<String> subtracted = ds.subtract(ds2);
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index bd11a387a1..09bbe57a43 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -25,7 +25,7 @@ import scala.util.Random
import org.scalatest.Matchers._
import org.apache.spark.SparkException
-import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation
+import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Union}
import org.apache.spark.sql.execution.Exchange
import org.apache.spark.sql.execution.aggregate.TungstenAggregate
import org.apache.spark.sql.functions._
@@ -98,6 +98,20 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
testData.collect().toSeq)
}
+ test("union all") {
+ val unionDF = testData.unionAll(testData).unionAll(testData)
+ .unionAll(testData).unionAll(testData)
+
+ // Before optimizer, Union should be combined.
+ assert(unionDF.queryExecution.analyzed.collect {
+ case j: Union if j.children.size == 5 => j }.size === 1)
+
+ checkAnswer(
+ unionDF.agg(avg('key), max('key), min('key), sum('key)),
+ Row(50.5, 100, 1, 25250) :: Nil
+ )
+ }
+
test("empty data frame") {
assert(sqlContext.emptyDataFrame.columns.toSeq === Seq.empty[String])
assert(sqlContext.emptyDataFrame.count() === 0)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 49feeaf17d..8fca5e2167 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -51,18 +51,6 @@ class PlannerSuite extends SharedSQLContext {
s"The plan of query $query does not have partial aggregations.")
}
- test("unions are collapsed") {
- val planner = sqlContext.planner
- import planner._
- val query = testData.unionAll(testData).unionAll(testData).logicalPlan
- val planned = BasicOperators(query).head
- val logicalUnions = query collect { case u: logical.Union => u }
- val physicalUnions = planned collect { case u: execution.Union => u }
-
- assert(logicalUnions.size === 2)
- assert(physicalUnions.size === 1)
- }
-
test("count is partially aggregated") {
val query = testData.groupBy('value).agg(count('key)).queryExecution.analyzed
testPartialAggregationPlan(query)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
index e83b4bffff..1654594538 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
@@ -129,11 +129,13 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
conditionSQL = condition.sql
} yield s"$childSQL $whereOrHaving $conditionSQL"
- case Union(left, right) =>
- for {
- leftSQL <- toSQL(left)
- rightSQL <- toSQL(right)
- } yield s"$leftSQL UNION ALL $rightSQL"
+ case Union(children) if children.length > 1 =>
+ val childrenSql = children.map(toSQL(_))
+ if (childrenSql.exists(_.isEmpty)) {
+ None
+ } else {
+ Some(childrenSql.map(_.get).mkString(" UNION ALL "))
+ }
// Persisted data source relation
case Subquery(alias, LogicalRelation(_, _, Some(TableIdentifier(table, Some(database))))) =>
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala
index 0604d9f47c..261a4746f4 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala
@@ -105,6 +105,10 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils {
checkHiveQl("SELECT id FROM t0 UNION ALL SELECT CAST(id AS INT) AS id FROM t0")
}
+ test("three-child union") {
+ checkHiveQl("SELECT id FROM t0 UNION ALL SELECT id FROM t0 UNION ALL SELECT id FROM t0")
+ }
+
test("case") {
checkHiveQl("SELECT CASE WHEN id % 2 > 0 THEN 0 WHEN id % 2 = 0 THEN 1 END FROM t0")
}