 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala      |  19
 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala | 127
 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala              |   7
 sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala                    |   2
 sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala          | 202
 sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala                    |   2
 sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala       |   3
 sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala            |  97
 sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala            |   2
 9 files changed, 436 insertions(+), 25 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 7b451baaa0..899227674f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -313,6 +313,8 @@ trait CheckAnalysis extends PredicateHelper {
|${s.catalogTable.identifier}
""".stripMargin)
+ // TODO: We need to consolidate this kind of check for InsertIntoTable
+ // with the rule of PreWriteCheck defined in extendedCheckRules.
case InsertIntoTable(s: SimpleCatalogRelation, _, _, _, _) =>
failAnalysis(
s"""
@@ -320,6 +322,23 @@ trait CheckAnalysis extends PredicateHelper {
|${s.catalogTable.identifier}
""".stripMargin)
+ case InsertIntoTable(t, _, _, _, _)
+ if !t.isInstanceOf[LeafNode] ||
+ t == OneRowRelation ||
+ t.isInstanceOf[LocalRelation] =>
+ failAnalysis(s"Inserting into an RDD-based table is not allowed.")
+
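+ // The number of columns produced by the query plus the number of static partition
+ // values must match the number of columns of the target table.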
+ case i @ InsertIntoTable(table, partitions, query, _, _) =>
+ val numStaticPartitions = partitions.values.count(_.isDefined)
+ if (table.output.size != (query.output.size + numStaticPartitions)) {
+ failAnalysis(
+ s"$table requires that the data to be inserted have the same number of " +
+ s"columns as the target table: target table has ${table.output.size} " +
+ s"column(s) but the inserted data has " +
+ s"${query.output.size + numStaticPartitions} column(s), including " +
+ s"$numStaticPartitions partition column(s) having constant value(s).")
+ }
+
case o if !o.resolved =>
failAnalysis(
s"unresolved operator ${operator.simpleString}")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 2b4786542c..27133f0a43 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -22,15 +22,15 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
+import org.apache.spark.sql.catalyst.{CatalystConf, CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala
-import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.catalog.{CatalogTable, SimpleCatalogRelation}
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.DataSourceScanExec.PUSHED_FILTERS
@@ -43,8 +43,127 @@ import org.apache.spark.unsafe.types.UTF8String
* Replaces generic operations with specific variants that are designed to work with Spark
* SQL Data Sources.
*/
-private[sql] object DataSourceAnalysis extends Rule[LogicalPlan] {
+private[sql] case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] {
+
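+ // Pick the attribute-name resolver that matches the analyzer's case-sensitivity setting.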
+ def resolver: Resolver = {
+ if (conf.caseSensitiveAnalysis) {
+ caseSensitiveResolution
+ } else {
+ caseInsensitiveResolution
+ }
+ }
+
+ // The access modifier is used to expose this method to tests.
+ private[sql] def convertStaticPartitions(
+ sourceAttributes: Seq[Attribute],
+ providedPartitions: Map[String, Option[String]],
+ targetAttributes: Seq[Attribute],
+ targetPartitionSchema: StructType): Seq[NamedExpression] = {
+
+ assert(providedPartitions.exists(_._2.isDefined))
+
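+ // Keep only the partition columns that were assigned a constant value in the PARTITION clause.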
+ val staticPartitions = providedPartitions.flatMap {
+ case (partKey, Some(partValue)) => (partKey, partValue) :: Nil
+ case (_, None) => Nil
+ }
+
+ // The sum of the number of static partition columns and columns provided in the SELECT
+ // clause needs to match the number of columns of the target table.
+ if (staticPartitions.size + sourceAttributes.size != targetAttributes.size) {
+ throw new AnalysisException(
+ s"The data to be inserted needs to have the same number of " +
+ s"columns as the target table: target table has ${targetAttributes.size} " +
+ s"column(s) but the inserted data has ${sourceAttributes.size + staticPartitions.size} " +
+ s"column(s), which contain ${staticPartitions.size} partition column(s) having " +
+ s"assigned constant values.")
+ }
+
+ if (providedPartitions.size != targetPartitionSchema.fields.size) {
+ throw new AnalysisException(
+ s"The data to be inserted needs to have the same number of " +
+ s"partition columns as the target table: target table " +
+ s"has ${targetPartitionSchema.fields.size} partition column(s) but the inserted " +
+ s"data has ${providedPartitions.size} partition columns specified.")
+ }
+
+ staticPartitions.foreach {
+ case (partKey, partValue) =>
+ if (!targetPartitionSchema.fields.exists(field => resolver(field.name, partKey))) {
+ throw new AnalysisException(
+ s"$partKey is not a partition column. Partition columns are " +
+ s"${targetPartitionSchema.fields.map(_.name).mkString("[", ",", "]")}")
+ }
+ }
+
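+ // For each partition column of the target table (in partition schema order), produce
+ // Some(literal) if a constant value was provided for it, or None if it is dynamic.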
+ val partitionList = targetPartitionSchema.fields.map { field =>
+ val potentialSpecs = staticPartitions.filter {
+ case (partKey, partValue) => resolver(field.name, partKey)
+ }
+ if (potentialSpecs.isEmpty) {
+ None
+ } else if (potentialSpecs.size == 1) {
+ val partValue = potentialSpecs.head._2
+ Some(Alias(Cast(Literal(partValue), field.dataType), "_staticPart")())
+ } else {
+ throw new AnalysisException(
+ s"Partition column ${field.name} have multiple values specified, " +
+ s"${potentialSpecs.mkString("[", ", ", "]")}. Please only specify a single value.")
+ }
+ }
+
+ // We first drop all leading static partitions using dropWhile and check whether any
+ // static partition appears after a dynamic partition.
+ partitionList.dropWhile(_.isDefined).collectFirst {
+ case Some(_) =>
+ throw new AnalysisException(
+ s"The ordering of partition columns is " +
+ s"${targetPartitionSchema.fields.map(_.name).mkString("[", ",", "]")}. " +
+ "All partition columns having constant values need to appear before other " +
+ "partition columns that do not have an assigned constant value.")
+ }
+
+ assert(partitionList.take(staticPartitions.size).forall(_.isDefined))
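+ // Lay out the final SELECT list as: the non-partition columns from the source, then
+ // the literals for the static partition columns, then the source attributes that feed
+ // the remaining dynamic partition columns.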
+ val projectList =
+ sourceAttributes.take(targetAttributes.size - targetPartitionSchema.fields.size) ++
+ partitionList.take(staticPartitions.size).map(_.get) ++
+ sourceAttributes.takeRight(targetPartitionSchema.fields.size - staticPartitions.size)
+
+ projectList
+ }
+
override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ // If the InsertIntoTable command is for a partitioned HadoopFsRelation and
+ // the user has specified static partitions, we add a Project operator on top of the query
+ // to include those constant column values in the query result.
+ //
+ // Example:
+ // Let's say that we have a table "t", which is created by
+ // CREATE TABLE t (a INT, b INT, c INT) USING parquet PARTITIONED BY (b, c)
+ // The statement of "INSERT INTO TABLE t PARTITION (b=2, c) SELECT 1, 3"
+ // will be converted to "INSERT INTO TABLE t PARTITION (b, c) SELECT 1, 2, 3".
+ //
+ // Basically, we will put those partition columns having an assigned value back
+ // to the SELECT clause. The output of the SELECT clause is organized as
+ // normal_columns static_partitioning_columns dynamic_partitioning_columns.
+ // static_partitioning_columns are partitioning columns having assigned
+ // values in the PARTITION clause (e.g. b in the above example).
+ // dynamic_partitioning_columns are partitioning columns that are not assigned
+ // values in the PARTITION clause (e.g. c in the above example).
+ case insert @ logical.InsertIntoTable(
+ relation @ LogicalRelation(t: HadoopFsRelation, _, _), parts, query, overwrite, false)
+ if query.resolved && parts.exists(_._2.isDefined) =>
+
+ val projectList = convertStaticPartitions(
+ sourceAttributes = query.output,
+ providedPartitions = parts,
+ targetAttributes = relation.output,
+ targetPartitionSchema = t.partitionSchema)
+
+ // We remove the values assigned to the static partitions because they have been
+ // moved into the projectList.
+ insert.copy(partition = parts.map(p => (p._1, None)), child = Project(projectList, query))
+
case i @ logical.InsertIntoTable(
l @ LogicalRelation(t: HadoopFsRelation, _, _), part, query, overwrite, false)
if query.resolved && t.schema.asNullable == query.schema.asNullable =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
index 10425af3e1..15b9d14bd7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
@@ -206,13 +206,6 @@ private[sql] case class PreWriteCheck(conf: SQLConf, catalog: SessionCatalog)
// The relation in l is not an InsertableRelation.
failAnalysis(s"$l does not allow insertion.")
- case logical.InsertIntoTable(t, _, _, _, _) =>
- if (!t.isInstanceOf[LeafNode] || t == OneRowRelation || t.isInstanceOf[LocalRelation]) {
- failAnalysis(s"Inserting into an RDD-based table is not allowed.")
- } else {
- // OK
- }
-
case c: CreateTableUsingAsSelect =>
// When the SaveMode is Overwrite, we need to check if the table is an input table of
// the query. If so, we will throw an AnalysisException to let users know it is not allowed.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
index 5300cfa8a7..5f5cf5c6d3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
@@ -113,7 +113,7 @@ private[sql] class SessionState(sparkSession: SparkSession) {
override val extendedResolutionRules =
PreprocessTableInsertion(conf) ::
new FindDataSourceTable(sparkSession) ::
- DataSourceAnalysis ::
+ DataSourceAnalysis(conf) ::
(if (conf.runSQLonFile) new ResolveDataSource(sparkSession) :: Nil else Nil)
override val extendedCheckRules = Seq(datasources.PreWriteCheck(conf, catalog))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala
new file mode 100644
index 0000000000..448adcf11d
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala
@@ -0,0 +1,202 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources
+
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.SimpleCatalystConf
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, Expression, Literal}
+import org.apache.spark.sql.execution.datasources.DataSourceAnalysis
+import org.apache.spark.sql.types.{IntegerType, StructType}
+
+class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll {
+
+ private var targetAttributes: Seq[Attribute] = _
+ private var targetPartitionSchema: StructType = _
+
+ override def beforeAll(): Unit = {
+ targetAttributes = Seq('a.int, 'd.int, 'b.int, 'c.int)
+ targetPartitionSchema = new StructType()
+ .add("b", IntegerType)
+ .add("c", IntegerType)
+ }
+
+ private def checkProjectList(actual: Seq[Expression], expected: Seq[Expression]): Unit = {
+ // Remove aliases since we have no control on their exprId.
+ val withoutAliases = actual.map {
+ case alias: Alias => alias.child
+ case other => other
+ }
+ assert(withoutAliases === expected)
+ }
+
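+ // Each test below is generated twice: once with a case-sensitive resolver and once
+ // with a case-insensitive one.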
+ Seq(true, false).foreach { caseSensitive =>
+ val rule = DataSourceAnalysis(SimpleCatalystConf(caseSensitive))
+ test(
+ s"convertStaticPartitions only handle INSERT having at least static partitions " +
+ s"(caseSensitive: $caseSensitive)") {
+ intercept[AssertionError] {
+ rule.convertStaticPartitions(
+ sourceAttributes = Seq('e.int, 'f.int),
+ providedPartitions = Map("b" -> None, "c" -> None),
+ targetAttributes = targetAttributes,
+ targetPartitionSchema = targetPartitionSchema)
+ }
+ }
+
+ test(s"Missing columns (caseSensitive: $caseSensitive)") {
+ // Missing columns.
+ intercept[AnalysisException] {
+ rule.convertStaticPartitions(
+ sourceAttributes = Seq('e.int),
+ providedPartitions = Map("b" -> Some("1"), "c" -> None),
+ targetAttributes = targetAttributes,
+ targetPartitionSchema = targetPartitionSchema)
+ }
+ }
+
+ test(s"Missing partitioning columns (caseSensitive: $caseSensitive)") {
+ // Missing partitioning columns.
+ intercept[AnalysisException] {
+ rule.convertStaticPartitions(
+ sourceAttributes = Seq('e.int, 'f.int),
+ providedPartitions = Map("b" -> Some("1")),
+ targetAttributes = targetAttributes,
+ targetPartitionSchema = targetPartitionSchema)
+ }
+
+ // Missing partitioning columns.
+ intercept[AnalysisException] {
+ rule.convertStaticPartitions(
+ sourceAttributes = Seq('e.int, 'f.int, 'g.int),
+ providedPartitions = Map("b" -> Some("1")),
+ targetAttributes = targetAttributes,
+ targetPartitionSchema = targetPartitionSchema)
+ }
+
+ // Wrong partitioning columns.
+ intercept[AnalysisException] {
+ rule.convertStaticPartitions(
+ sourceAttributes = Seq('e.int, 'f.int),
+ providedPartitions = Map("b" -> Some("1"), "d" -> None),
+ targetAttributes = targetAttributes,
+ targetPartitionSchema = targetPartitionSchema)
+ }
+ }
+
+ test(s"Wrong partitioning columns (caseSensitive: $caseSensitive)") {
+ // Wrong partitioning columns.
+ intercept[AnalysisException] {
+ rule.convertStaticPartitions(
+ sourceAttributes = Seq('e.int, 'f.int),
+ providedPartitions = Map("b" -> Some("1"), "d" -> Some("2")),
+ targetAttributes = targetAttributes,
+ targetPartitionSchema = targetPartitionSchema)
+ }
+
+ // Wrong partitioning columns.
+ intercept[AnalysisException] {
+ rule.convertStaticPartitions(
+ sourceAttributes = Seq('e.int),
+ providedPartitions = Map("b" -> Some("1"), "c" -> Some("3"), "d" -> Some("2")),
+ targetAttributes = targetAttributes,
+ targetPartitionSchema = targetPartitionSchema)
+ }
+
+ if (caseSensitive) {
+ // Wrong partitioning columns.
+ intercept[AnalysisException] {
+ rule.convertStaticPartitions(
+ sourceAttributes = Seq('e.int, 'f.int),
+ providedPartitions = Map("b" -> Some("1"), "C" -> Some("3")),
+ targetAttributes = targetAttributes,
+ targetPartitionSchema = targetPartitionSchema)
+ }
+ }
+ }
+
+ test(
+ s"Static partitions need to appear before dynamic partitions" +
+ s" (caseSensitive: $caseSensitive)") {
+ // Static partitions need to appear before dynamic partitions.
+ intercept[AnalysisException] {
+ rule.convertStaticPartitions(
+ sourceAttributes = Seq('e.int, 'f.int),
+ providedPartitions = Map("b" -> None, "c" -> Some("3")),
+ targetAttributes = targetAttributes,
+ targetPartitionSchema = targetPartitionSchema)
+ }
+ }
+
+ test(s"All static partitions (caseSensitive: $caseSensitive)") {
+ if (!caseSensitive) {
+ val nonPartitionedAttributes = Seq('e.int, 'f.int)
+ val expected = nonPartitionedAttributes ++
+ Seq(Cast(Literal("1"), IntegerType), Cast(Literal("3"), IntegerType))
+ val actual = rule.convertStaticPartitions(
+ sourceAttributes = nonPartitionedAttributes,
+ providedPartitions = Map("b" -> Some("1"), "C" -> Some("3")),
+ targetAttributes = targetAttributes,
+ targetPartitionSchema = targetPartitionSchema)
+ checkProjectList(actual, expected)
+ }
+
+ {
+ val nonPartitionedAttributes = Seq('e.int, 'f.int)
+ val expected = nonPartitionedAttributes ++
+ Seq(Cast(Literal("1"), IntegerType), Cast(Literal("3"), IntegerType))
+ val actual = rule.convertStaticPartitions(
+ sourceAttributes = nonPartitionedAttributes,
+ providedPartitions = Map("b" -> Some("1"), "c" -> Some("3")),
+ targetAttributes = targetAttributes,
+ targetPartitionSchema = targetPartitionSchema)
+ checkProjectList(actual, expected)
+ }
+
+ // Test the case having a single static partition column.
+ {
+ val nonPartitionedAttributes = Seq('e.int, 'f.int)
+ val expected = nonPartitionedAttributes ++ Seq(Cast(Literal("1"), IntegerType))
+ val actual = rule.convertStaticPartitions(
+ sourceAttributes = nonPartitionedAttributes,
+ providedPartitions = Map("b" -> Some("1")),
+ targetAttributes = Seq('a.int, 'd.int, 'b.int),
+ targetPartitionSchema = new StructType().add("b", IntegerType))
+ checkProjectList(actual, expected)
+ }
+ }
+
+ test(s"Static partition and dynamic partition (caseSensitive: $caseSensitive)") {
+ val nonPartitionedAttributes = Seq('e.int, 'f.int)
+ val dynamicPartitionAttributes = Seq('g.int)
+ val expected =
+ nonPartitionedAttributes ++
+ Seq(Cast(Literal("1"), IntegerType)) ++
+ dynamicPartitionAttributes
+ val actual = rule.convertStaticPartitions(
+ sourceAttributes = nonPartitionedAttributes ++ dynamicPartitionAttributes,
+ providedPartitions = Map("b" -> Some("1"), "c" -> None),
+ targetAttributes = targetAttributes,
+ targetPartitionSchema = targetPartitionSchema)
+ checkProjectList(actual, expected)
+ }
+ }
+}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
index f6675f0904..8773993d36 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
@@ -66,7 +66,7 @@ private[hive] class HiveSessionState(sparkSession: SparkSession)
catalog.OrcConversions ::
catalog.CreateTables ::
PreprocessTableInsertion(conf) ::
- DataSourceAnalysis ::
+ DataSourceAnalysis(conf) ::
(if (conf.runSQLonFile) new ResolveDataSource(sparkSession) :: Nil else Nil)
override val extendedCheckRules = Seq(PreWriteCheck(conf, catalog))
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
index b3896484da..97cd29f541 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
@@ -33,6 +33,7 @@ import org.apache.hadoop.hive.ql.ErrorMsg
import org.apache.hadoop.mapred.{FileOutputFormat, JobConf}
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
@@ -196,7 +197,7 @@ case class InsertIntoHiveTable(
// Report error if any static partition appears after a dynamic partition
val isDynamic = partitionColumnNames.map(partitionSpec(_).isEmpty)
if (isDynamic.init.zip(isDynamic.tail).contains((true, false))) {
- throw new SparkException(ErrorMsg.PARTITION_DYN_STA_ORDER.getMsg)
+ throw new AnalysisException(ErrorMsg.PARTITION_DYN_STA_ORDER.getMsg)
}
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
index d4ebd051d2..46432512ba 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
@@ -19,11 +19,11 @@ package org.apache.spark.sql.hive
import java.io.File
-import org.apache.hadoop.hive.conf.HiveConf
import org.scalatest.BeforeAndAfter
import org.apache.spark.SparkException
import org.apache.spark.sql.{QueryTest, _}
+import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.catalyst.plans.logical.InsertIntoTable
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.test.SQLTestUtils
@@ -331,7 +331,11 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef
withTable(hiveTable) {
withSQLConf("hive.exec.dynamic.partition.mode" -> "nonstrict") {
- sql(s"CREATE TABLE $hiveTable (a INT) PARTITIONED BY (b INT, c INT) STORED AS TEXTFILE")
+ sql(
+ s"""
+ |CREATE TABLE $hiveTable (a INT, d INT)
+ |PARTITIONED BY (b INT, c INT) STORED AS TEXTFILE
+ """.stripMargin)
f(hiveTable)
}
}
@@ -343,7 +347,11 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef
val dsTable = "ds_table"
withTable(dsTable) {
- sql(s"CREATE TABLE $dsTable (a INT, b INT, c INT) USING PARQUET PARTITIONED BY (b, c)")
+ sql(
+ s"""
+ |CREATE TABLE $dsTable (a INT, b INT, c INT, d INT)
+ |USING PARQUET PARTITIONED BY (b, c)
+ """.stripMargin)
f(dsTable)
}
}
@@ -356,7 +364,7 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef
testPartitionedTable("partitionBy() can't be used together with insertInto()") { tableName =>
val cause = intercept[AnalysisException] {
- Seq((1, 2, 3)).toDF("a", "b", "c").write.partitionBy("b", "c").insertInto(tableName)
+ Seq((1, 2, 3, 4)).toDF("a", "b", "c", "d").write.partitionBy("b", "c").insertInto(tableName)
}
assert(cause.getMessage.contains("insertInto() can't be used together with partitionBy()."))
@@ -382,14 +390,83 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef
assert(e.message.contains("the number of columns are different"))
}
- testPartitionedTable(
- "SPARK-16037: INSERT statement should match columns by position") {
+ testPartitionedTable("SPARK-16037: INSERT statement should match columns by position") {
tableName =>
withSQLConf("hive.exec.dynamic.partition.mode" -> "nonstrict") {
- sql(s"INSERT INTO TABLE $tableName SELECT 1, 2 AS c, 3 AS b")
- checkAnswer(sql(s"SELECT a, b, c FROM $tableName"), Row(1, 2, 3))
- sql(s"INSERT OVERWRITE TABLE $tableName SELECT 1, 2, 3")
- checkAnswer(sql(s"SELECT a, b, c FROM $tableName"), Row(1, 2, 3))
+ sql(s"INSERT INTO TABLE $tableName SELECT 1, 4, 2 AS c, 3 AS b")
+ checkAnswer(sql(s"SELECT a, b, c, d FROM $tableName"), Row(1, 2, 3, 4))
+ sql(s"INSERT OVERWRITE TABLE $tableName SELECT 1, 4, 2, 3")
+ checkAnswer(sql(s"SELECT a, b, c, 4 FROM $tableName"), Row(1, 2, 3, 4))
+ }
+ }
+
+ testPartitionedTable("INSERT INTO a partitioned table (semantic and error handling)") {
+ tableName =>
+ withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) {
+ sql(s"INSERT INTO TABLE $tableName PARTITION (b=2, c=3) SELECT 1, 4")
+
+ sql(s"INSERT INTO TABLE $tableName PARTITION (b=6, c=7) SELECT 5, 8")
+
+ sql(s"INSERT INTO TABLE $tableName PARTITION (c=11, b=10) SELECT 9, 12")
+
+ // c is defined twice. Parser will complain.
+ intercept[ParseException] {
+ sql(s"INSERT INTO TABLE $tableName PARTITION (b=14, c=15, c=16) SELECT 13")
+ }
+
+ // d is not a partitioning column.
+ intercept[AnalysisException] {
+ sql(s"INSERT INTO TABLE $tableName PARTITION (b=14, c=15, d=16) SELECT 13, 14")
+ }
+
+ // d is not a partitioning column. The total number of columns is correct.
+ intercept[AnalysisException] {
+ sql(s"INSERT INTO TABLE $tableName PARTITION (b=14, c=15, d=16) SELECT 13")
+ }
+
+ // The data is missing a column.
+ intercept[AnalysisException] {
+ sql(s"INSERT INTO TABLE $tableName PARTITION (c=15, b=16) SELECT 13")
+ }
+
+ // d is not a partitioning column.
+ intercept[AnalysisException] {
+ sql(s"INSERT INTO TABLE $tableName PARTITION (b=15, d=15) SELECT 13, 14")
+ }
+
+ // The statement is missing a column.
+ intercept[AnalysisException] {
+ sql(s"INSERT INTO TABLE $tableName PARTITION (b=15) SELECT 13, 14")
+ }
+
+ // The statement is missing a column.
+ intercept[AnalysisException] {
+ sql(s"INSERT INTO TABLE $tableName PARTITION (b=15) SELECT 13, 14, 16")
+ }
+
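+ // b is static and c is dynamic; the last column of the SELECT list feeds c.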
+ sql(s"INSERT INTO TABLE $tableName PARTITION (b=14, c) SELECT 13, 16, 15")
+
+ // Dynamic partitioning columns need to be after static partitioning columns.
+ intercept[AnalysisException] {
+ sql(s"INSERT INTO TABLE $tableName PARTITION (b, c=19) SELECT 17, 20, 18")
+ }
+
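+ // Both partition columns are dynamic; values are matched to b and c by position,
+ // regardless of the order in the PARTITION clause.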
+ sql(s"INSERT INTO TABLE $tableName PARTITION (b, c) SELECT 17, 20, 18, 19")
+
+ sql(s"INSERT INTO TABLE $tableName PARTITION (c, b) SELECT 21, 24, 22, 23")
+
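+ // No PARTITION clause: the partition columns b and c are also filled from the
+ // SELECT list by position.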
+ sql(s"INSERT INTO TABLE $tableName SELECT 25, 28, 26, 27")
+
+ checkAnswer(
+ sql(s"SELECT a, b, c, d FROM $tableName"),
+ Row(1, 2, 3, 4) ::
+ Row(5, 6, 7, 8) ::
+ Row(9, 10, 11, 12) ::
+ Row(13, 14, 15, 16) ::
+ Row(17, 18, 19, 20) ::
+ Row(21, 22, 23, 24) ::
+ Row(25, 26, 27, 28) :: Nil
+ )
}
}
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index 85b159e2a5..f8c55ec456 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -1006,7 +1006,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
sql("SET hive.exec.dynamic.partition.mode=nonstrict")
// Should throw when a static partition appears after a dynamic partition
- intercept[SparkException] {
+ intercept[AnalysisException] {
sql(
"""INSERT INTO TABLE dp_test PARTITION(dp, sp = 1)
|SELECT key, value, key % 5 FROM src