author     Josh Rosen <joshrosen@databricks.com>        2015-08-25 00:04:10 -0700
committer  Michael Armbrust <michael@databricks.com>    2015-08-25 00:04:10 -0700
commit     82268f07abfa658869df2354ae72f8d6ddd119e8 (patch)
tree       ffdb8d008d0314ab648412db96cb5cc1b65c5467 /sql
parent     bf03fe68d62f33dda70dff45c3bda1f57b032dfc (diff)
[SPARK-9293] [SPARK-9813] Analysis should check that set operations are only performed on tables with equal numbers of columns
This patch adds an analyzer rule to ensure that set operations (union, intersect, and except) are only applied to tables with the same number of columns. Without this rule, there are scenarios where invalid queries can return incorrect results instead of failing with error messages; SPARK-9813 provides one example of this problem. In other cases, the invalid query can crash at runtime with extremely confusing exceptions.

I also performed a bit of cleanup to refactor some of those logical operators' code into a common `SetOperation` base class.

Author: Josh Rosen <joshrosen@databricks.com>

Closes #7631 from JoshRosen/SPARK-9293.
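To illustrate the check in isolation, below is a minimal standalone Scala sketch, runnable as a script. It uses simplified stand-in types (a Plan trait with a plain Seq[String] output) rather than Spark's actual LogicalPlan and Attribute classes, so all names in it are illustrative assumptions, not the real API:

    // Minimal sketch with stand-in types; not Spark's real LogicalPlan API.
    sealed trait Plan { def output: Seq[String] }
    case class Relation(output: Seq[String]) extends Plan
    case class Union(left: Plan, right: Plan) extends Plan {
      // As in Spark, a set operation reports the left child's output.
      def output: Seq[String] = left.output
    }

    // Mirrors the spirit of the new CheckAnalysis rule: reject set operations
    // whose children produce different numbers of columns.
    def checkAnalysis(plan: Plan): Unit = plan match {
      case Union(left, right) if left.output.length != right.output.length =>
        sys.error(
          s"union can only be performed on tables with the same number of " +
          s"columns, but the left table has ${left.output.length} columns " +
          s"and the right has ${right.output.length}")
      case _ => // arity check passes
    }

    // A two-column relation unioned with a three-column relation now fails
    // during analysis instead of silently returning incorrect results.
    checkAnalysis(Union(Relation(Seq("a", "b")), Relation(Seq("a", "b", "c"))))

In the actual patch, a single `case s @ SetOperation(left, right)` handles Union, Intersect, and Except at once, via the new SetOperation extractor added in basicOperators.scala below.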
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala       |  6
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala    | 14
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala | 38
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala  | 18
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala                 |  2
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala        |  2
6 files changed, 48 insertions(+), 32 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 39f554c137..7701fd0451 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -137,6 +137,12 @@ trait CheckAnalysis {
}
}
+ case s @ SetOperation(left, right) if left.output.length != right.output.length =>
+ failAnalysis(
+ s"${s.nodeName} can only be performed on tables with the same number of columns, " +
+ s"but the left table has ${left.output.length} columns and the right has " +
+ s"${right.output.length}")
+
case _ => // Fallbacks to the following checks
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 2cb067f4aa..a1aa2a2b2c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -203,6 +203,7 @@ object HiveTypeCoercion {
planName: String,
left: LogicalPlan,
right: LogicalPlan): (LogicalPlan, LogicalPlan) = {
+ require(left.output.length == right.output.length)
val castedTypes = left.output.zip(right.output).map {
case (lhs, rhs) if lhs.dataType != rhs.dataType =>
@@ -229,15 +230,10 @@ object HiveTypeCoercion {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case p if p.analyzed => p
- case u @ Union(left, right) if u.childrenResolved && !u.resolved =>
- val (newLeft, newRight) = widenOutputTypes(u.nodeName, left, right)
- Union(newLeft, newRight)
- case e @ Except(left, right) if e.childrenResolved && !e.resolved =>
- val (newLeft, newRight) = widenOutputTypes(e.nodeName, left, right)
- Except(newLeft, newRight)
- case i @ Intersect(left, right) if i.childrenResolved && !i.resolved =>
- val (newLeft, newRight) = widenOutputTypes(i.nodeName, left, right)
- Intersect(newLeft, newRight)
+ case s @ SetOperation(left, right) if s.childrenResolved
+ && left.output.length == right.output.length && !s.resolved =>
+ val (newLeft, newRight) = widenOutputTypes(s.nodeName, left, right)
+ s.makeCopy(Array(newLeft, newRight))
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 73b8261260..722f69cdca 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -89,13 +89,21 @@ case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output
}
-case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
+abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
// TODO: These aren't really the same attributes as nullability etc might change.
- override def output: Seq[Attribute] = left.output
+ final override def output: Seq[Attribute] = left.output
- override lazy val resolved: Boolean =
+ final override lazy val resolved: Boolean =
childrenResolved &&
- left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType }
+ left.output.length == right.output.length &&
+ left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType }
+}
+
+private[sql] object SetOperation {
+ def unapply(p: SetOperation): Option[(LogicalPlan, LogicalPlan)] = Some((p.left, p.right))
+}
+
+case class Union(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) {
override def statistics: Statistics = {
val sizeInBytes = left.statistics.sizeInBytes + right.statistics.sizeInBytes
@@ -103,6 +111,10 @@ case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
}
}
+case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right)
+
+case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right)
+
case class Join(
left: LogicalPlan,
right: LogicalPlan,
@@ -142,15 +154,6 @@ case class BroadcastHint(child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output
}
-
-case class Except(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
- override def output: Seq[Attribute] = left.output
-
- override lazy val resolved: Boolean =
- childrenResolved &&
- left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType }
-}
-
case class InsertIntoTable(
table: LogicalPlan,
partition: Map[String, Option[String]],
@@ -160,7 +163,7 @@ case class InsertIntoTable(
extends LogicalPlan {
override def children: Seq[LogicalPlan] = child :: Nil
- override def output: Seq[Attribute] = child.output
+ override def output: Seq[Attribute] = Seq.empty
assert(overwrite || !ifNotExists)
override lazy val resolved: Boolean = childrenResolved && child.output.zip(table.output).forall {
@@ -440,10 +443,3 @@ case object OneRowRelation extends LeafNode {
override def statistics: Statistics = Statistics(sizeInBytes = 1)
}
-case class Intersect(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
- override def output: Seq[Attribute] = left.output
-
- override lazy val resolved: Boolean =
- childrenResolved &&
- left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType }
-}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index 7065adce04..fbdd3a7776 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -146,6 +146,24 @@ class AnalysisErrorSuite extends AnalysisTest {
"unresolved" :: Nil)
errorTest(
+ "union with unequal number of columns",
+ testRelation.unionAll(testRelation2),
+ "union" :: "number of columns" :: testRelation2.output.length.toString ::
+ testRelation.output.length.toString :: Nil)
+
+ errorTest(
+ "intersect with unequal number of columns",
+ testRelation.intersect(testRelation2),
+ "intersect" :: "number of columns" :: testRelation2.output.length.toString ::
+ testRelation.output.length.toString :: Nil)
+
+ errorTest(
+ "except with unequal number of columns",
+ testRelation.except(testRelation2),
+ "except" :: "number of columns" :: testRelation2.output.length.toString ::
+ testRelation.output.length.toString :: Nil)
+
+ errorTest(
"SPARK-9955: correct error message for aggregate",
// When parse SQL string, we will wrap aggregate expressions with UnresolvedAlias.
testRelation2.where('bad_column > 1).groupBy('a)(UnresolvedAlias(max('b))),
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index bbe8c1911b..98d21aa76d 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -751,7 +751,7 @@ private[hive] case class InsertIntoHiveTable(
extends LogicalPlan {
override def children: Seq[LogicalPlan] = child :: Nil
- override def output: Seq[Attribute] = child.output
+ override def output: Seq[Attribute] = Seq.empty
val numDynamicPartitions = partition.values.count(_.isEmpty)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
index 12c667e6e9..62efda613a 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
@@ -61,7 +61,7 @@ case class InsertIntoHiveTable(
serializer
}
- def output: Seq[Attribute] = child.output
+ def output: Seq[Attribute] = Seq.empty
def saveAsHiveFile(
rdd: RDD[InternalRow],