Diffstat (limited to 'sql')
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala | 4
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 12
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala | 8
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala | 71
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 71
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala | 23
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala | 47
-rw-r--r-- sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala | 6
-rw-r--r-- sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala | 2
-rw-r--r-- sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala | 82
-rw-r--r-- sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala (renamed from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala) | 43
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala | 5
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 9
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala | 2
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala | 11
-rw-r--r-- sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java | 4
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 16
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala | 12
-rw-r--r-- sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala | 12
-rw-r--r-- sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala | 4
20 files changed, 322 insertions, 122 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala
index 5fb41f7e4b..35273c7e24 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala
@@ -402,8 +402,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
overwrite)
}
- // If there are multiple INSERTS just UNION them together into on query.
- val query = queries.reduceLeft(Union)
+ // If there are multiple INSERTS just UNION them together into one query.
+ val query = if (queries.length == 1) queries.head else Union(queries)
// return With plan if there is CTE
cteRelations.map(With(query, _)).getOrElse(query)
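
A minimal standalone sketch of the new multi-INSERT handling above, using toy plan types rather than the actual Catalyst classes: a single query is returned unchanged, while several queries become one flat n-ary Union instead of the old left-deep reduceLeft(Union) chain.

    object MultiInsertSketch extends App {
      // Toy stand-ins for Catalyst's LogicalPlan and the new n-ary Union.
      sealed trait Plan
      case class Insert(table: String) extends Plan
      case class Union(children: Seq[Plan]) extends Plan

      def combineInserts(queries: Seq[Plan]): Plan = {
        require(queries.nonEmpty, "at least one INSERT branch is expected")
        // One INSERT stays as-is; several INSERTs become one flat Union.
        if (queries.length == 1) queries.head else Union(queries)
      }

      // Three branches give Union(Seq(i1, i2, i3)), not Union(Union(i1, i2), i3).
      println(combineInserts(Seq(Insert("t1"), Insert("t2"), Insert("t3"))))
    }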
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index d4b4bc88b3..33d76eeb21 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -66,7 +66,8 @@ class Analyzer(
lazy val batches: Seq[Batch] = Seq(
Batch("Substitution", fixedPoint,
CTESubstitution,
- WindowsSubstitution),
+ WindowsSubstitution,
+ EliminateUnions),
Batch("Resolution", fixedPoint,
ResolveRelations ::
ResolveReferences ::
@@ -1171,6 +1172,15 @@ object EliminateSubQueries extends Rule[LogicalPlan] {
}
/**
+ * Removes [[Union]] operators from the plan if they have just one child.
+ */
+object EliminateUnions extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case Union(children) if children.size == 1 => children.head
+ }
+}
+
+/**
* Cleans up unnecessary Aliases inside the plan. Basically we only need Alias as a top level
* expression in Project(project list) or Aggregate(aggregate expressions) or
* Window(window expressions).
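
EliminateUnions is a small safety net for Union nodes that end up with a single child. A hedged standalone sketch of the same idea on a toy plan type, with plain recursion standing in for Catalyst's transform:

    object EliminateUnionsSketch extends App {
      sealed trait Plan
      case class Relation(name: String) extends Plan
      case class Union(children: Seq[Plan]) extends Plan

      // A Union with exactly one child is a no-op and is replaced by that child.
      def eliminateUnions(plan: Plan): Plan = plan match {
        case Union(Seq(only)) => eliminateUnions(only)
        case Union(children)  => Union(children.map(eliminateUnions))
        case other            => other
      }

      println(eliminateUnions(Union(Seq(Relation("t")))))  // Relation(t)
    }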
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 2a2e0d27d9..f2e78d9744 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -189,6 +189,14 @@ trait CheckAnalysis {
s"but the left table has ${left.output.length} columns and the right has " +
s"${right.output.length}")
+ case s: Union if s.children.exists(_.output.length != s.children.head.output.length) =>
+ val firstError = s.children.find(_.output.length != s.children.head.output.length).get
+ failAnalysis(
+ s"""
+ |Unions can only be performed on tables with the same number of columns,
+ | but one table has '${firstError.output.length}' columns and another table has
+ | '${s.children.head.output.length}' columns""".stripMargin)
+
case _ => // Fallbacks to the following checks
}
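
A minimal sketch of the arity check the new analysis rule performs, assuming toy relations that only expose their output column names (the real rule inspects LogicalPlan.output):

    object UnionArityCheckSketch extends App {
      case class Rel(output: Seq[String])

      // Fail if any Union child has a different number of columns than the first.
      def checkUnionArity(children: Seq[Rel]): Unit = {
        val expected = children.head.output.length
        children.find(_.output.length != expected).foreach { bad =>
          sys.error(
            s"Unions can only be performed on tables with the same number of columns, " +
            s"but one table has ${bad.output.length} columns and another has $expected columns")
        }
      }

      checkUnionArity(Seq(Rel(Seq("a", "b")), Rel(Seq("c", "d"))))  // passes
    }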
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 7df3787e6d..c557c32319 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -19,6 +19,9 @@ package org.apache.spark.sql.catalyst.analysis
import javax.annotation.Nullable
+import scala.annotation.tailrec
+import scala.collection.mutable
+
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical._
@@ -27,7 +30,7 @@ import org.apache.spark.sql.types._
/**
- * A collection of [[Rule Rules]] that can be used to coerce differing types that participate in
+ * A collection of [[Rule]] that can be used to coerce differing types that participate in
* operations into compatible ones.
*
* Most of these rules are based on Hive semantics, but they do not introduce any dependencies on
@@ -219,31 +222,59 @@ object HiveTypeCoercion {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case p if p.analyzed => p
- case s @ SetOperation(left, right) if s.childrenResolved
- && left.output.length == right.output.length && !s.resolved =>
+ case s @ SetOperation(left, right) if s.childrenResolved &&
+ left.output.length == right.output.length && !s.resolved =>
+ val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil)
+ assert(newChildren.length == 2)
+ s.makeCopy(Array(newChildren.head, newChildren.last))
- // Tracks the list of data types to widen.
- // Some(dataType) means the right-hand side and the left-hand side have different types,
- // and there is a target type to widen both sides to.
- val targetTypes: Seq[Option[DataType]] = left.output.zip(right.output).map {
- case (lhs, rhs) if lhs.dataType != rhs.dataType =>
- findWiderTypeForTwo(lhs.dataType, rhs.dataType)
- case other => None
- }
+ case s: Union if s.childrenResolved &&
+ s.children.forall(_.output.length == s.children.head.output.length) && !s.resolved =>
+ val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(s.children)
+ s.makeCopy(Array(newChildren))
+ }
- if (targetTypes.exists(_.isDefined)) {
- // There is at least one column to widen.
- s.makeCopy(Array(widenTypes(left, targetTypes), widenTypes(right, targetTypes)))
- } else {
- // If we cannot find any column to widen, then just return the original set.
- s
- }
+ /** Build new children with the widest types for each attribute among all the children */
+ private def buildNewChildrenWithWiderTypes(children: Seq[LogicalPlan]): Seq[LogicalPlan] = {
+ require(children.forall(_.output.length == children.head.output.length))
+
+ // Get a sequence of data types, each of which is the widest type of this specific attribute
+ // in all the children
+ val targetTypes: Seq[DataType] =
+ getWidestTypes(children, attrIndex = 0, mutable.Queue[DataType]())
+
+ if (targetTypes.nonEmpty) {
+ // Add an extra Project if the targetTypes are different from the original types.
+ children.map(widenTypes(_, targetTypes))
+ } else {
+ // Unable to find a target type to widen, so just return the original children.
+ children
+ }
+ }
+
+ /** Get the widest type for each attribute in all the children */
+ @tailrec private def getWidestTypes(
+ children: Seq[LogicalPlan],
+ attrIndex: Int,
+ castedTypes: mutable.Queue[DataType]): Seq[DataType] = {
+ // Return the result once the widened data types have been found for all the children
+ if (attrIndex >= children.head.output.length) return castedTypes.toSeq
+
+ // For the attrIndex-th attribute, find the widest type
+ findWiderCommonType(children.map(_.output(attrIndex).dataType)) match {
+ // If unable to find an appropriate wider type for this column, return an empty Seq
+ case None => Seq.empty[DataType]
+ // Otherwise, record the result in the queue and find the type for the next column
+ case Some(widenType) =>
+ castedTypes.enqueue(widenType)
+ getWidestTypes(children, attrIndex + 1, castedTypes)
+ }
}
/** Given a plan, add an extra project on top to widen some columns' data types. */
- private def widenTypes(plan: LogicalPlan, targetTypes: Seq[Option[DataType]]): LogicalPlan = {
+ private def widenTypes(plan: LogicalPlan, targetTypes: Seq[DataType]): LogicalPlan = {
val casted = plan.output.zip(targetTypes).map {
- case (e, Some(dt)) if e.dataType != dt => Alias(Cast(e, dt), e.name)()
+ case (e, dt) if e.dataType != dt => Alias(Cast(e, dt), e.name)()
case (e, _) => e
}
Project(casted, plan)
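
A simplified, standalone illustration of the per-column widening pass above, assuming a toy type lattice in place of Catalyst's DataType and findWiderCommonType: each child contributes one row of column types, and the result is either a widened type per column or None when some column cannot be reconciled.

    object WidenTypesSketch extends App {
      sealed trait ToyType
      case object IntT extends ToyType
      case object LongT extends ToyType
      case object DoubleT extends ToyType
      case object StringT extends ToyType

      // Hypothetical widening rule: numeric types widen upwards, anything else
      // only matches itself.
      def wider(a: ToyType, b: ToyType): Option[ToyType] = (a, b) match {
        case _ if a == b                         => Some(a)
        case (IntT, LongT) | (LongT, IntT)       => Some(LongT)
        case (IntT, DoubleT) | (DoubleT, IntT)   => Some(DoubleT)
        case (LongT, DoubleT) | (DoubleT, LongT) => Some(DoubleT)
        case _                                   => None
      }

      // Mirrors getWidestTypes: find the widest type at every column position
      // across all the children, or give up if any position has no common type.
      def widestTypes(children: Seq[Seq[ToyType]]): Option[Seq[ToyType]] = {
        require(children.forall(_.length == children.head.length))
        val perColumn = children.transpose.map { col =>
          col.tail.foldLeft(Option(col.head))((acc, t) => acc.flatMap(wider(_, t)))
        }
        if (perColumn.forall(_.isDefined)) Some(perColumn.map(_.get)) else None
      }

      println(widestTypes(Seq(Seq(IntT, StringT), Seq(DoubleT, StringT), Seq(LongT, StringT))))
      // Some(List(DoubleT, StringT))
    }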
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 04643f0274..44455b4820 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, EliminateSubQueri
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins
+import org.apache.spark.sql.catalyst.planning.{ExtractFiltersAndInnerJoins, Unions}
import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftOuter, LeftSemi, RightOuter}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
@@ -45,6 +45,13 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
//////////////////////////////////////////////////////////////////////////////////////////
// Optimizer rules start here
//////////////////////////////////////////////////////////////////////////////////////////
+ // - Do the first call of CombineUnions before starting the major Optimizer rules,
+ // since it can reduce the number of iterations, and the other rules could add/move
+ // extra operators between two adjacent Union operators.
+ // - Call CombineUnions again in Batch("Operator Optimizations"),
+ // since the other rules might make two separate Union operators adjacent.
+ Batch("Union", Once,
+ CombineUnions) ::
Batch("Aggregate", FixedPoint(100),
ReplaceDistinctWithAggregate,
RemoveLiteralFromGroupExpressions) ::
@@ -62,6 +69,7 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
ProjectCollapsing,
CombineFilters,
CombineLimits,
+ CombineUnions,
// Constant folding and strength reduction
NullPropagation,
OptimizeIn,
@@ -138,11 +146,9 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
/**
* Maps Attributes from the left side to the corresponding Attribute on the right side.
*/
- private def buildRewrites(bn: BinaryNode): AttributeMap[Attribute] = {
- assert(bn.isInstanceOf[Union] || bn.isInstanceOf[Intersect] || bn.isInstanceOf[Except])
- assert(bn.left.output.size == bn.right.output.size)
-
- AttributeMap(bn.left.output.zip(bn.right.output))
+ private def buildRewrites(left: LogicalPlan, right: LogicalPlan): AttributeMap[Attribute] = {
+ assert(left.output.size == right.output.size)
+ AttributeMap(left.output.zip(right.output))
}
/**
@@ -176,32 +182,38 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
}
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- // Push down filter into union
- case Filter(condition, u @ Union(left, right)) =>
- val (deterministic, nondeterministic) = partitionByDeterministic(condition)
- val rewrites = buildRewrites(u)
- Filter(nondeterministic,
- Union(
- Filter(deterministic, left),
- Filter(pushToRight(deterministic, rewrites), right)
- )
- )
// Push down deterministic projection through UNION ALL
- case p @ Project(projectList, u @ Union(left, right)) =>
+ case p @ Project(projectList, Union(children)) =>
+ assert(children.nonEmpty)
if (projectList.forall(_.deterministic)) {
- val rewrites = buildRewrites(u)
- Union(
- Project(projectList, left),
- Project(projectList.map(pushToRight(_, rewrites)), right))
+ val newFirstChild = Project(projectList, children.head)
+ val newOtherChildren = children.tail.map ( child => {
+ val rewrites = buildRewrites(children.head, child)
+ Project(projectList.map(pushToRight(_, rewrites)), child)
+ } )
+ Union(newFirstChild +: newOtherChildren)
} else {
p
}
+ // Push down filter into union
+ case Filter(condition, Union(children)) =>
+ assert(children.nonEmpty)
+ val (deterministic, nondeterministic) = partitionByDeterministic(condition)
+ val newFirstChild = Filter(deterministic, children.head)
+ val newOtherChildren = children.tail.map {
+ child => {
+ val rewrites = buildRewrites(children.head, child)
+ Filter(pushToRight(deterministic, rewrites), child)
+ }
+ }
+ Filter(nondeterministic, Union(newFirstChild +: newOtherChildren))
+
// Push down filter through INTERSECT
- case Filter(condition, i @ Intersect(left, right)) =>
+ case Filter(condition, Intersect(left, right)) =>
val (deterministic, nondeterministic) = partitionByDeterministic(condition)
- val rewrites = buildRewrites(i)
+ val rewrites = buildRewrites(left, right)
Filter(nondeterministic,
Intersect(
Filter(deterministic, left),
@@ -210,9 +222,9 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
)
// Push down filter through EXCEPT
- case Filter(condition, e @ Except(left, right)) =>
+ case Filter(condition, Except(left, right)) =>
val (deterministic, nondeterministic) = partitionByDeterministic(condition)
- val rewrites = buildRewrites(e)
+ val rewrites = buildRewrites(left, right)
Filter(nondeterministic,
Except(
Filter(deterministic, left),
@@ -663,6 +675,15 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
}
/**
+ * Combines all adjacent [[Union]] operators into a single [[Union]].
+ */
+object CombineUnions extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case Unions(children) => Union(children)
+ }
+}
+
+/**
* Combines two adjacent [[Filter]] operators into one, merging the
* conditions into one conjunctive predicate.
*/
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index cd3f15cbe1..f0ee124e88 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -17,6 +17,9 @@
package org.apache.spark.sql.catalyst.planning
+import scala.annotation.tailrec
+import scala.collection.mutable
+
import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
@@ -170,17 +173,29 @@ object ExtractFiltersAndInnerJoins extends PredicateHelper {
}
}
+
/**
* A pattern that collects all adjacent unions and returns their children as a Seq.
*/
object Unions {
def unapply(plan: LogicalPlan): Option[Seq[LogicalPlan]] = plan match {
- case u: Union => Some(collectUnionChildren(u))
+ case u: Union => Some(collectUnionChildren(mutable.Stack(u), Seq.empty[LogicalPlan]))
case _ => None
}
- private def collectUnionChildren(plan: LogicalPlan): Seq[LogicalPlan] = plan match {
- case Union(l, r) => collectUnionChildren(l) ++ collectUnionChildren(r)
- case other => other :: Nil
+ // Doing a depth-first tree traversal to combine all the union children.
+ @tailrec
+ private def collectUnionChildren(
+ plans: mutable.Stack[LogicalPlan],
+ children: Seq[LogicalPlan]): Seq[LogicalPlan] = {
+ if (plans.isEmpty) children
+ else {
+ plans.pop match {
+ case Union(grandchildren) =>
+ grandchildren.reverseMap(plans.push(_))
+ collectUnionChildren(plans, children)
+ case other => collectUnionChildren(plans, children :+ other)
+ }
+ }
}
}
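
The rewritten extractor above replaces unbounded recursion with an explicit stack plus tail recursion, so arbitrarily deep chains of unions cannot overflow the call stack. A self-contained sketch of the same traversal on a toy plan type:

    object FlattenUnionsSketch extends App {
      import scala.annotation.tailrec
      import scala.collection.mutable

      sealed trait Plan
      case class Leaf(name: String) extends Plan
      case class Union(children: Seq[Plan]) extends Plan

      @tailrec
      def flatten(pending: mutable.Stack[Plan], acc: Seq[Plan]): Seq[Plan] =
        if (pending.isEmpty) acc
        else pending.pop() match {
          // Push grandchildren in reverse so they are popped in their original order.
          case Union(grandchildren) =>
            grandchildren.reverse.foreach(pending.push)
            flatten(pending, acc)
          case leaf =>
            flatten(pending, acc :+ leaf)
        }

      val nested = Union(Seq(Union(Seq(Leaf("a"), Leaf("b"))), Leaf("c")))
      println(flatten(mutable.Stack[Plan](nested), Seq.empty))
      // List(Leaf(a), Leaf(b), Leaf(c))
    }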
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index f4a3d85d2a..e9c970cd08 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -101,19 +101,6 @@ private[sql] object SetOperation {
def unapply(p: SetOperation): Option[(LogicalPlan, LogicalPlan)] = Some((p.left, p.right))
}
-case class Union(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) {
-
- override def output: Seq[Attribute] =
- left.output.zip(right.output).map { case (leftAttr, rightAttr) =>
- leftAttr.withNullability(leftAttr.nullable || rightAttr.nullable)
- }
-
- override def statistics: Statistics = {
- val sizeInBytes = left.statistics.sizeInBytes + right.statistics.sizeInBytes
- Statistics(sizeInBytes = sizeInBytes)
- }
-}
-
case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) {
override def output: Seq[Attribute] =
@@ -127,6 +114,40 @@ case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(le
override def output: Seq[Attribute] = left.output
}
+/** Factory for constructing new `Union` nodes. */
+object Union {
+ def apply(left: LogicalPlan, right: LogicalPlan): Union = {
+ Union (left :: right :: Nil)
+ }
+}
+
+case class Union(children: Seq[LogicalPlan]) extends LogicalPlan {
+
+ // updating nullability to make all the children consistent
+ override def output: Seq[Attribute] =
+ children.map(_.output).transpose.map(attrs =>
+ attrs.head.withNullability(attrs.exists(_.nullable)))
+
+ override lazy val resolved: Boolean = {
+ // allChildrenCompatible needs to be evaluated after childrenResolved
+ def allChildrenCompatible: Boolean =
+ children.tail.forall( child =>
+ // compare the attribute number with the first child
+ child.output.length == children.head.output.length &&
+ // compare the data types with the first child
+ child.output.zip(children.head.output).forall {
+ case (l, r) => l.dataType == r.dataType }
+ )
+
+ children.length > 1 && childrenResolved && allChildrenCompatible
+ }
+
+ override def statistics: Statistics = {
+ val sizeInBytes = children.map(_.statistics.sizeInBytes).sum
+ Statistics(sizeInBytes = sizeInBytes)
+ }
+}
+
case class Join(
left: LogicalPlan,
right: LogicalPlan,
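
A hedged sketch of how the n-ary Union derives its output, with a toy attribute type standing in for Catalyst's Attribute: for each column position, the first child's attribute is kept and marked nullable if any child is nullable at that position (the same transpose trick the physical Union uses below).

    object UnionOutputSketch extends App {
      case class Attr(name: String, nullable: Boolean)

      // One Seq[Attr] per child; children are assumed to have the same arity.
      def unionOutput(childOutputs: Seq[Seq[Attr]]): Seq[Attr] =
        childOutputs.transpose.map { attrs =>
          attrs.head.copy(nullable = attrs.exists(_.nullable))
        }

      println(unionOutput(Seq(
        Seq(Attr("a", nullable = false), Attr("b", nullable = true)),
        Seq(Attr("a", nullable = true), Attr("b", nullable = false)))))
      // List(Attr(a,true), Attr(b,true))
    }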
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 975cd87d09..ab68028220 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -237,6 +237,12 @@ class AnalysisSuite extends AnalysisTest {
checkAnalysis(plan, expected)
}
+ test("Eliminate the unnecessary union") {
+ val plan = Union(testRelation :: Nil)
+ val expected = testRelation
+ checkAnalysis(plan, expected)
+ }
+
test("SPARK-12102: Ignore nullablity when comparing two sides of case") {
val relation = LocalRelation('a.struct('x.int), 'b.struct('x.int.withNullability(false)))
val plan = relation.select(CaseWhen(Seq((Literal(true), 'a.attr)), 'b).as("val"))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
index 39c8f56c1b..24c608eaa5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
@@ -70,7 +70,7 @@ class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter {
Union(Project(Seq(Alias(left, "l")()), relation),
Project(Seq(Alias(right, "r")()), relation))
val (l, r) = analyzer.execute(plan).collect {
- case Union(left, right) => (left.output.head, right.output.head)
+ case Union(Seq(child1, child2)) => (child1.output.head, child2.output.head)
}.head
assert(l.dataType === expectedType)
assert(r.dataType === expectedType)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index b326aa9c55..c30434a006 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -387,19 +387,19 @@ class HiveTypeCoercionSuite extends PlanTest {
)
}
- test("WidenSetOperationTypes for union, except, and intersect") {
- def checkOutput(logical: LogicalPlan, expectTypes: Seq[DataType]): Unit = {
- logical.output.zip(expectTypes).foreach { case (attr, dt) =>
- assert(attr.dataType === dt)
- }
+ private def checkOutput(logical: LogicalPlan, expectTypes: Seq[DataType]): Unit = {
+ logical.output.zip(expectTypes).foreach { case (attr, dt) =>
+ assert(attr.dataType === dt)
}
+ }
- val left = LocalRelation(
+ test("WidenSetOperationTypes for except and intersect") {
+ val firstTable = LocalRelation(
AttributeReference("i", IntegerType)(),
AttributeReference("u", DecimalType.SYSTEM_DEFAULT)(),
AttributeReference("b", ByteType)(),
AttributeReference("d", DoubleType)())
- val right = LocalRelation(
+ val secondTable = LocalRelation(
AttributeReference("s", StringType)(),
AttributeReference("d", DecimalType(2, 1))(),
AttributeReference("f", FloatType)(),
@@ -408,15 +408,65 @@ class HiveTypeCoercionSuite extends PlanTest {
val wt = HiveTypeCoercion.WidenSetOperationTypes
val expectedTypes = Seq(StringType, DecimalType.SYSTEM_DEFAULT, FloatType, DoubleType)
- val r1 = wt(Union(left, right)).asInstanceOf[Union]
- val r2 = wt(Except(left, right)).asInstanceOf[Except]
- val r3 = wt(Intersect(left, right)).asInstanceOf[Intersect]
+ val r1 = wt(Except(firstTable, secondTable)).asInstanceOf[Except]
+ val r2 = wt(Intersect(firstTable, secondTable)).asInstanceOf[Intersect]
checkOutput(r1.left, expectedTypes)
checkOutput(r1.right, expectedTypes)
checkOutput(r2.left, expectedTypes)
checkOutput(r2.right, expectedTypes)
- checkOutput(r3.left, expectedTypes)
- checkOutput(r3.right, expectedTypes)
+
+ // Check if a Project is added
+ assert(r1.left.isInstanceOf[Project])
+ assert(r1.right.isInstanceOf[Project])
+ assert(r2.left.isInstanceOf[Project])
+ assert(r2.right.isInstanceOf[Project])
+
+ val r3 = wt(Except(firstTable, firstTable)).asInstanceOf[Except]
+ checkOutput(r3.left, Seq(IntegerType, DecimalType.SYSTEM_DEFAULT, ByteType, DoubleType))
+ checkOutput(r3.right, Seq(IntegerType, DecimalType.SYSTEM_DEFAULT, ByteType, DoubleType))
+
+ // Check if no Project is added
+ assert(r3.left.isInstanceOf[LocalRelation])
+ assert(r3.right.isInstanceOf[LocalRelation])
+ }
+
+ test("WidenSetOperationTypes for union") {
+ val firstTable = LocalRelation(
+ AttributeReference("i", IntegerType)(),
+ AttributeReference("u", DecimalType.SYSTEM_DEFAULT)(),
+ AttributeReference("b", ByteType)(),
+ AttributeReference("d", DoubleType)())
+ val secondTable = LocalRelation(
+ AttributeReference("s", StringType)(),
+ AttributeReference("d", DecimalType(2, 1))(),
+ AttributeReference("f", FloatType)(),
+ AttributeReference("l", LongType)())
+ val thirdTable = LocalRelation(
+ AttributeReference("m", StringType)(),
+ AttributeReference("n", DecimalType.SYSTEM_DEFAULT)(),
+ AttributeReference("p", FloatType)(),
+ AttributeReference("q", DoubleType)())
+ val forthTable = LocalRelation(
+ AttributeReference("m", StringType)(),
+ AttributeReference("n", DecimalType.SYSTEM_DEFAULT)(),
+ AttributeReference("p", ByteType)(),
+ AttributeReference("q", DoubleType)())
+
+ val wt = HiveTypeCoercion.WidenSetOperationTypes
+ val expectedTypes = Seq(StringType, DecimalType.SYSTEM_DEFAULT, FloatType, DoubleType)
+
+ val unionRelation = wt(
+ Union(firstTable :: secondTable :: thirdTable :: forthTable :: Nil)).asInstanceOf[Union]
+ assert(unionRelation.children.length == 4)
+ checkOutput(unionRelation.children.head, expectedTypes)
+ checkOutput(unionRelation.children(1), expectedTypes)
+ checkOutput(unionRelation.children(2), expectedTypes)
+ checkOutput(unionRelation.children(3), expectedTypes)
+
+ assert(unionRelation.children.head.isInstanceOf[Project])
+ assert(unionRelation.children(1).isInstanceOf[Project])
+ assert(unionRelation.children(2).isInstanceOf[Project])
+ assert(unionRelation.children(3).isInstanceOf[Project])
}
test("Transform Decimal precision/scale for union except and intersect") {
@@ -438,8 +488,8 @@ class HiveTypeCoercionSuite extends PlanTest {
val r2 = dp(Except(left1, right1)).asInstanceOf[Except]
val r3 = dp(Intersect(left1, right1)).asInstanceOf[Intersect]
- checkOutput(r1.left, expectedType1)
- checkOutput(r1.right, expectedType1)
+ checkOutput(r1.children.head, expectedType1)
+ checkOutput(r1.children.last, expectedType1)
checkOutput(r2.left, expectedType1)
checkOutput(r2.right, expectedType1)
checkOutput(r3.left, expectedType1)
@@ -459,7 +509,7 @@ class HiveTypeCoercionSuite extends PlanTest {
val r2 = dp(Except(plan1, plan2)).asInstanceOf[Except]
val r3 = dp(Intersect(plan1, plan2)).asInstanceOf[Intersect]
- checkOutput(r1.right, Seq(expectedType))
+ checkOutput(r1.children.last, Seq(expectedType))
checkOutput(r2.right, Seq(expectedType))
checkOutput(r3.right, Seq(expectedType))
@@ -467,7 +517,7 @@ class HiveTypeCoercionSuite extends PlanTest {
val r5 = dp(Except(plan2, plan1)).asInstanceOf[Except]
val r6 = dp(Intersect(plan2, plan1)).asInstanceOf[Intersect]
- checkOutput(r4.left, Seq(expectedType))
+ checkOutput(r4.children.last, Seq(expectedType))
checkOutput(r5.left, Seq(expectedType))
checkOutput(r6.left, Seq(expectedType))
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
index a498b463a6..2283f7c008 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
@@ -24,48 +24,73 @@ import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
-class SetOperationPushDownSuite extends PlanTest {
+class SetOperationSuite extends PlanTest {
object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
Batch("Subqueries", Once,
EliminateSubQueries) ::
Batch("Union Pushdown", Once,
+ CombineUnions,
SetOperationPushDown,
SimplifyFilters) :: Nil
}
val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int)
- val testUnion = Union(testRelation, testRelation2)
+ val testRelation3 = LocalRelation('g.int, 'h.int, 'i.int)
+ val testUnion = Union(testRelation :: testRelation2 :: testRelation3 :: Nil)
val testIntersect = Intersect(testRelation, testRelation2)
val testExcept = Except(testRelation, testRelation2)
- test("union/intersect/except: filter to each side") {
- val unionQuery = testUnion.where('a === 1)
+ test("union: combine unions into one unions") {
+ val unionQuery1 = Union(Union(testRelation, testRelation2), testRelation)
+ val unionQuery2 = Union(testRelation, Union(testRelation2, testRelation))
+ val unionOptimized1 = Optimize.execute(unionQuery1.analyze)
+ val unionOptimized2 = Optimize.execute(unionQuery2.analyze)
+
+ comparePlans(unionOptimized1, unionOptimized2)
+
+ val combinedUnions = Union(unionOptimized1 :: unionOptimized2 :: Nil)
+ val combinedUnionsOptimized = Optimize.execute(combinedUnions.analyze)
+ val unionQuery3 = Union(unionQuery1, unionQuery2)
+ val unionOptimized3 = Optimize.execute(unionQuery3.analyze)
+ comparePlans(combinedUnionsOptimized, unionOptimized3)
+ }
+
+ test("intersect/except: filter to each side") {
val intersectQuery = testIntersect.where('b < 10)
val exceptQuery = testExcept.where('c >= 5)
- val unionOptimized = Optimize.execute(unionQuery.analyze)
val intersectOptimized = Optimize.execute(intersectQuery.analyze)
val exceptOptimized = Optimize.execute(exceptQuery.analyze)
- val unionCorrectAnswer =
- Union(testRelation.where('a === 1), testRelation2.where('d === 1)).analyze
val intersectCorrectAnswer =
Intersect(testRelation.where('b < 10), testRelation2.where('e < 10)).analyze
val exceptCorrectAnswer =
Except(testRelation.where('c >= 5), testRelation2.where('f >= 5)).analyze
- comparePlans(unionOptimized, unionCorrectAnswer)
comparePlans(intersectOptimized, intersectCorrectAnswer)
comparePlans(exceptOptimized, exceptCorrectAnswer)
}
+ test("union: filter to each side") {
+ val unionQuery = testUnion.where('a === 1)
+ val unionOptimized = Optimize.execute(unionQuery.analyze)
+ val unionCorrectAnswer =
+ Union(testRelation.where('a === 1) ::
+ testRelation2.where('d === 1) ::
+ testRelation3.where('g === 1) :: Nil).analyze
+
+ comparePlans(unionOptimized, unionCorrectAnswer)
+ }
+
test("union: project to each side") {
val unionQuery = testUnion.select('a)
val unionOptimized = Optimize.execute(unionQuery.analyze)
val unionCorrectAnswer =
- Union(testRelation.select('a), testRelation2.select('d)).analyze
+ Union(testRelation.select('a) ::
+ testRelation2.select('d) ::
+ testRelation3.select('g) :: Nil).analyze
comparePlans(unionOptimized, unionCorrectAnswer)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 95e5fbb119..518f9dcf94 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.optimizer.CombineUnions
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, FileRelation, LogicalRDD, Queryable, QueryExecution, SQLExecution}
@@ -1002,7 +1003,9 @@ class DataFrame private[sql](
* @since 1.3.0
*/
def unionAll(other: DataFrame): DataFrame = withPlan {
- Union(logicalPlan, other.logicalPlan)
+ // This breaks caching, but it's usually ok because it addresses a very specific use case:
+ // using union to union many files or partitions.
+ CombineUnions(Union(logicalPlan, other.logicalPlan))
}
/**
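
A hedged usage sketch of the pattern this change targets, assuming a Spark build of this era where DataFrame.unionAll exists: unioning many per-file or per-partition DataFrames in a loop. Because each unionAll now runs CombineUnions eagerly, the analyzed plan holds one flat Union node instead of a left-deep chain of binary Unions.

    import org.apache.spark.sql.DataFrame

    object UnionManySketch {
      // Fold many DataFrames into one; the plan stays a single flat Union node.
      def unionMany(dfs: Seq[DataFrame]): DataFrame = {
        require(dfs.nonEmpty, "need at least one DataFrame")
        dfs.reduce(_ unionAll _)
      }
    }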
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 9a9f7d111c..bd99c39957 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -19,13 +19,14 @@ package org.apache.spark.sql
import scala.collection.JavaConverters._
-import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.function._
+import org.apache.spark.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.optimizer.CombineUnions
import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.{Queryable, QueryExecution}
@@ -603,7 +604,11 @@ class Dataset[T] private[sql](
* duplicate items. As such, it is analogous to `UNION ALL` in SQL.
* @since 1.6.0
*/
- def union(other: Dataset[T]): Dataset[T] = withPlan[T](other)(Union)
+ def union(other: Dataset[T]): Dataset[T] = withPlan[T](other){ (left, right) =>
+ // This breaks caching, but it's usually ok because it addresses a very specific use case:
+ // using union to union many files or partitions.
+ CombineUnions(Union(left, right))
+ }
/**
* Returns a new [[Dataset]] where any elements present in `other` have been removed.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index c4ddb6d76b..60fbb595e5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -336,7 +336,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
LocalTableScan(output, data) :: Nil
case logical.Limit(IntegerLiteral(limit), child) =>
execution.Limit(limit, planLater(child)) :: Nil
- case Unions(unionChildren) =>
+ case logical.Union(unionChildren) =>
execution.Union(unionChildren.map(planLater)) :: Nil
case logical.Except(left, right) =>
execution.Except(planLater(left), planLater(right)) :: Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 9e2e0357c6..6deb72adad 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -281,13 +281,10 @@ case class Range(
* Union two plans, without a distinct. This is UNION ALL in SQL.
*/
case class Union(children: Seq[SparkPlan]) extends SparkPlan {
- override def output: Seq[Attribute] = {
- children.tail.foldLeft(children.head.output) { case (currentOutput, child) =>
- currentOutput.zip(child.output).map { case (a1, a2) =>
- a1.withNullability(a1.nullable || a2.nullable)
- }
- }
- }
+ override def output: Seq[Attribute] =
+ children.map(_.output).transpose.map(attrs =>
+ attrs.head.withNullability(attrs.exists(_.nullable)))
+
protected override def doExecute(): RDD[InternalRow] =
sparkContext.union(children.map(_.execute()))
}
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 1a3df1b117..3c0f25a5dc 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -298,9 +298,9 @@ public class JavaDatasetSuite implements Serializable {
Dataset<String> intersected = ds.intersect(ds2);
Assert.assertEquals(Arrays.asList("xyz"), intersected.collectAsList());
- Dataset<String> unioned = ds.union(ds2);
+ Dataset<String> unioned = ds.union(ds2).union(ds);
Assert.assertEquals(
- Arrays.asList("abc", "abc", "xyz", "xyz", "foo", "foo"),
+ Arrays.asList("abc", "abc", "xyz", "xyz", "foo", "foo", "abc", "abc", "xyz"),
unioned.collectAsList());
Dataset<String> subtracted = ds.subtract(ds2);
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index bd11a387a1..09bbe57a43 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -25,7 +25,7 @@ import scala.util.Random
import org.scalatest.Matchers._
import org.apache.spark.SparkException
-import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation
+import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Union}
import org.apache.spark.sql.execution.Exchange
import org.apache.spark.sql.execution.aggregate.TungstenAggregate
import org.apache.spark.sql.functions._
@@ -98,6 +98,20 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
testData.collect().toSeq)
}
+ test("union all") {
+ val unionDF = testData.unionAll(testData).unionAll(testData)
+ .unionAll(testData).unionAll(testData)
+
+ // The adjacent Unions should already be combined before the optimizer runs.
+ assert(unionDF.queryExecution.analyzed.collect {
+ case j: Union if j.children.size == 5 => j }.size === 1)
+
+ checkAnswer(
+ unionDF.agg(avg('key), max('key), min('key), sum('key)),
+ Row(50.5, 100, 1, 25250) :: Nil
+ )
+ }
+
test("empty data frame") {
assert(sqlContext.emptyDataFrame.columns.toSeq === Seq.empty[String])
assert(sqlContext.emptyDataFrame.count() === 0)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 49feeaf17d..8fca5e2167 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -51,18 +51,6 @@ class PlannerSuite extends SharedSQLContext {
s"The plan of query $query does not have partial aggregations.")
}
- test("unions are collapsed") {
- val planner = sqlContext.planner
- import planner._
- val query = testData.unionAll(testData).unionAll(testData).logicalPlan
- val planned = BasicOperators(query).head
- val logicalUnions = query collect { case u: logical.Union => u }
- val physicalUnions = planned collect { case u: execution.Union => u }
-
- assert(logicalUnions.size === 2)
- assert(physicalUnions.size === 1)
- }
-
test("count is partially aggregated") {
val query = testData.groupBy('value).agg(count('key)).queryExecution.analyzed
testPartialAggregationPlan(query)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
index e83b4bffff..1654594538 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
@@ -129,11 +129,13 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
conditionSQL = condition.sql
} yield s"$childSQL $whereOrHaving $conditionSQL"
- case Union(left, right) =>
- for {
- leftSQL <- toSQL(left)
- rightSQL <- toSQL(right)
- } yield s"$leftSQL UNION ALL $rightSQL"
+ case Union(children) if children.length > 1 =>
+ val childrenSql = children.map(toSQL(_))
+ if (childrenSql.exists(_.isEmpty)) {
+ None
+ } else {
+ Some(childrenSql.map(_.get).mkString(" UNION ALL "))
+ }
// Persisted data source relation
case Subquery(alias, LogicalRelation(_, _, Some(TableIdentifier(table, Some(database))))) =>
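
A standalone sketch of the new SQL generation for an n-ary Union: every child is converted independently and the fragments are joined with UNION ALL only if all of them could be converted. The Option-based conversion below stands in for SQLBuilder.toSQL.

    object UnionSqlSketch extends App {
      def unionToSQL(childrenSql: Seq[Option[String]]): Option[String] =
        if (childrenSql.exists(_.isEmpty)) None
        else Some(childrenSql.map(_.get).mkString(" UNION ALL "))

      println(unionToSQL(Seq(Some("SELECT id FROM t0"), Some("SELECT id FROM t0"))))
      // Some(SELECT id FROM t0 UNION ALL SELECT id FROM t0)
    }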
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala
index 0604d9f47c..261a4746f4 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala
@@ -105,6 +105,10 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils {
checkHiveQl("SELECT id FROM t0 UNION ALL SELECT CAST(id AS INT) AS id FROM t0")
}
+ test("three-child union") {
+ checkHiveQl("SELECT id FROM t0 UNION ALL SELECT id FROM t0 UNION ALL SELECT id FROM t0")
+ }
+
test("case") {
checkHiveQl("SELECT CASE WHEN id % 2 > 0 THEN 0 WHEN id % 2 = 0 THEN 1 END FROM t0")
}