Diffstat (limited to 'sql')
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala                   | 41
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala | 19
 sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala               | 10
 sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala                    | 83
 sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala                    |  2
 5 files changed, 143 insertions(+), 12 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 9e9a856286..b7884f9b60 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -22,7 +22,7 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{CatalystConf, ScalaReflection, SimpleCatalystConf}
-import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
+import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.encoders.OuterScopes
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
@@ -444,8 +444,43 @@ class Analyzer(
}
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
- case i @ InsertIntoTable(u: UnresolvedRelation, _, _, _, _) =>
- i.copy(table = EliminateSubqueryAliases(lookupTableFromCatalog(u)))
+ case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _) if child.resolved =>
+ val table = lookupTableFromCatalog(u)
+ // Add the table's partition columns or validate the query's partition spec
+ table match {
+ case relation: CatalogRelation if relation.catalogTable.partitionColumns.nonEmpty =>
+ val tablePartitionNames = relation.catalogTable.partitionColumns.map(_.name)
+ if (parts.keys.nonEmpty) {
+ // the query's partitioning must match the table's partitioning
+ // this is set for queries like: insert into ... partition (one = "a", two = <expr>)
+ // TODO: add better checking to pre-inserts to avoid needing this here
+ if (tablePartitionNames.size != parts.keySet.size) {
+ throw new AnalysisException(
+ s"""Requested partitioning does not match the ${u.tableIdentifier} table:
+ |Requested partitions: ${parts.keys.mkString(",")}
+ |Table partitions: ${tablePartitionNames.mkString(",")}""".stripMargin)
+ }
+ // Assume partition columns are correctly placed at the end of the child's output
+ i.copy(table = EliminateSubqueryAliases(table))
+ } else {
+ // Set up the table's partition scheme with all dynamic partitions by moving partition
+ // columns to the end of the column list, in partition order.
+ val (inputPartCols, columns) = child.output.partition { attr =>
+ tablePartitionNames.contains(attr.name)
+ }
+ // All partition columns are dynamic because this InsertIntoTable had no partitioning
+ val partColumns = tablePartitionNames.map { name =>
+ inputPartCols.find(_.name == name).getOrElse(
+ throw new AnalysisException(s"Cannot find partition column $name"))
+ }
+ i.copy(
+ table = EliminateSubqueryAliases(table),
+ partition = tablePartitionNames.map(_ -> None).toMap,
+ child = Project(columns ++ partColumns, child))
+ }
+ case _ =>
+ i.copy(table = EliminateSubqueryAliases(table))
+ }
case u: UnresolvedRelation =>
val table = u.tableIdentifier
if (table.database.isDefined && conf.runSQLonFile &&
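
For readers following the Analyzer change above, here is a minimal, standalone sketch of the column reordering it performs when the INSERT has no PARTITION clause. `Attr` and `reorderForDynamicPartitions` are illustrative names only; the real rule operates on Catalyst `Attribute`s and wraps the reordered columns in a `Project` over the child plan.

// Sketch: append the table's partition columns after the data columns,
// in the table's declared partition order (all treated as dynamic).
case class Attr(name: String)

def reorderForDynamicPartitions(
    childOutput: Seq[Attr],
    tablePartitionNames: Seq[String]): Seq[Attr] = {
  // Split the query's output into partition columns and data columns.
  val (inputPartCols, dataCols) =
    childOutput.partition(a => tablePartitionNames.contains(a.name))
  // Re-emit the partition columns in the table's partition order.
  val partCols = tablePartitionNames.map { name =>
    inputPartCols.find(_.name == name)
      .getOrElse(sys.error(s"Cannot find partition column $name"))
  }
  dataCols ++ partCols
}

// For a table partitioned by (part1, part2):
//   reorderForDynamicPartitions(
//     Seq(Attr("id"), Attr("part2"), Attr("part1"), Attr("data")),
//     Seq("part1", "part2"))
//   == Seq(Attr("id"), Attr("data"), Attr("part1"), Attr("part2"))
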
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 8b438e40e6..732b0d7919 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -354,10 +354,23 @@ case class InsertIntoTable(
override def children: Seq[LogicalPlan] = child :: Nil
override def output: Seq[Attribute] = Seq.empty
+ private[spark] lazy val expectedColumns = {
+ if (table.output.isEmpty) {
+ None
+ } else {
+ val numDynamicPartitions = partition.values.count(_.isEmpty)
+ val (partitionColumns, dataColumns) = table.output
+ .partition(a => partition.keySet.contains(a.name))
+ Some(dataColumns ++ partitionColumns.takeRight(numDynamicPartitions))
+ }
+ }
+
assert(overwrite || !ifNotExists)
- override lazy val resolved: Boolean = childrenResolved && child.output.zip(table.output).forall {
- case (childAttr, tableAttr) =>
- DataType.equalsIgnoreCompatibleNullability(childAttr.dataType, tableAttr.dataType)
+ override lazy val resolved: Boolean = childrenResolved && expectedColumns.forall { expected =>
+ child.output.size == expected.size && child.output.zip(expected).forall {
+ case (childAttr, tableAttr) =>
+ DataType.equalsIgnoreCompatibleNullability(childAttr.dataType, tableAttr.dataType)
+ }
}
}
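
As a companion to the `expectedColumns` change above, this is a self-contained sketch of the same computation using plain case classes instead of Catalyst attributes (`Col` and `expectedColumns` here are illustrative names): static partition values are fixed by the PARTITION clause and drop out of the expected list, while dynamic partitions (mapped to `None`) must still be produced by the child plan.

// Sketch: the columns the child query is expected to produce.
case class Col(name: String)

def expectedColumns(
    tableOutput: Seq[Col],
    partition: Map[String, Option[String]]): Option[Seq[Col]] = {
  if (tableOutput.isEmpty) {
    None // table not resolved yet; nothing to check against
  } else {
    val numDynamic = partition.values.count(_.isEmpty)
    val (partCols, dataCols) =
      tableOutput.partition(c => partition.keySet.contains(c.name))
    // Data columns plus only the trailing dynamic partition columns.
    Some(dataCols ++ partCols.takeRight(numDynamic))
  }
}

// INSERT ... PARTITION (p1 = "a", p2) into a table (id, data, p1, p2):
//   expectedColumns(Seq(Col("id"), Col("data"), Col("p1"), Col("p2")),
//                   Map("p1" -> Some("a"), "p2" -> None))
//   == Some(Seq(Col("id"), Col("data"), Col("p2")))
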
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
index 73ccec2ee0..3805674d39 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
@@ -168,7 +168,15 @@ case class InsertIntoHiveTable(
// All partition column names in the format of "<column name 1>/<column name 2>/..."
val partitionColumns = fileSinkConf.getTableInfo.getProperties.getProperty("partition_columns")
- val partitionColumnNames = Option(partitionColumns).map(_.split("/")).orNull
+ val partitionColumnNames = Option(partitionColumns).map(_.split("/")).getOrElse(Array.empty)
+
+ // By this time, the partition map must match the table's partition columns
+ if (partitionColumnNames.toSet != partition.keySet) {
+ throw new SparkException(
+ s"""Requested partitioning does not match the ${table.tableName} table:
+ |Requested partitions: ${partition.keys.mkString(",")}
+ |Table partitions: ${table.partitionKeys.map(_.name).mkString(",")}""".stripMargin)
+ }
// Validate partition spec if there exist any dynamic partitions
if (numDynamicPartitions > 0) {
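
The check added above can be read in isolation as the following sketch (`checkPartitionSpec` is a hypothetical helper, and a plain `IllegalStateException` stands in for Spark's `SparkException`): by the time the physical Hive insert runs, the analyzer-produced partition spec must cover exactly the table's partition columns, so a mismatch fails fast instead of writing misaligned data.

// Sketch: fail fast when the requested partition spec does not match the table.
def checkPartitionSpec(
    tableName: String,
    tablePartitionNames: Seq[String],
    partition: Map[String, Option[String]]): Unit = {
  if (tablePartitionNames.toSet != partition.keySet) {
    throw new IllegalStateException(
      s"""Requested partitioning does not match the $tableName table:
         |Requested partitions: ${partition.keys.mkString(",")}
         |Table partitions: ${tablePartitionNames.mkString(",")}""".stripMargin)
  }
}

// checkPartitionSpec("partitioned", Seq("part"), Map("data" -> None))
// throws: Requested partitioning does not match the partitioned table: ...
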
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
index baf34d1cf0..52aba328de 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
@@ -22,9 +22,11 @@ import java.io.File
import org.apache.hadoop.hive.conf.HiveConf
import org.scalatest.BeforeAndAfter
+import org.apache.spark.SparkException
import org.apache.spark.sql.{QueryTest, _}
-import org.apache.spark.sql.execution.QueryExecutionException
+import org.apache.spark.sql.catalyst.plans.logical.InsertIntoTable
import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@@ -32,11 +34,11 @@ case class TestData(key: Int, value: String)
case class ThreeCloumntable(key: Int, value: String, key1: String)
-class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with BeforeAndAfter {
+class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with BeforeAndAfter
+ with SQLTestUtils {
import hiveContext.implicits._
- import hiveContext.sql
- val testData = hiveContext.sparkContext.parallelize(
+ override lazy val testData = hiveContext.sparkContext.parallelize(
(1 to 100).map(i => TestData(i, i.toString))).toDF()
before {
@@ -213,4 +215,77 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef
sql("DROP TABLE hiveTableWithStructValue")
}
+
+ test("Reject partitioning that does not match table") {
+ withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) {
+ sql("CREATE TABLE partitioned (id bigint, data string) PARTITIONED BY (part string)")
+ val data = (1 to 10).map(i => (i, s"data-$i", if ((i % 2) == 0) "even" else "odd"))
+ .toDF("id", "data", "part")
+
+ intercept[AnalysisException] {
+ // cannot partition by 2 fields when there is only one in the table definition
+ data.write.partitionBy("part", "data").insertInto("partitioned")
+ }
+ }
+ }
+
+ test("Test partition mode = strict") {
+ withSQLConf(("hive.exec.dynamic.partition.mode", "strict")) {
+ sql("CREATE TABLE partitioned (id bigint, data string) PARTITIONED BY (part string)")
+ val data = (1 to 10).map(i => (i, s"data-$i", if ((i % 2) == 0) "even" else "odd"))
+ .toDF("id", "data", "part")
+
+ intercept[SparkException] {
+ data.write.insertInto("partitioned")
+ }
+ }
+ }
+
+ test("Detect table partitioning") {
+ withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) {
+ sql("CREATE TABLE source (id bigint, data string, part string)")
+ val data = (1 to 10).map(i => (i, s"data-$i", if ((i % 2) == 0) "even" else "odd")).toDF()
+
+ data.write.insertInto("source")
+ checkAnswer(sql("SELECT * FROM source"), data.collect().toSeq)
+
+ sql("CREATE TABLE partitioned (id bigint, data string) PARTITIONED BY (part string)")
+ // this will pick up the output partitioning from the table definition
+ sqlContext.table("source").write.insertInto("partitioned")
+
+ checkAnswer(sql("SELECT * FROM partitioned"), data.collect().toSeq)
+ }
+ }
+
+ test("Detect table partitioning with correct partition order") {
+ withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) {
+ sql("CREATE TABLE source (id bigint, part2 string, part1 string, data string)")
+ val data = (1 to 10).map(i => (i, if ((i % 2) == 0) "even" else "odd", "p", s"data-$i"))
+ .toDF("id", "part2", "part1", "data")
+
+ data.write.insertInto("source")
+ checkAnswer(sql("SELECT * FROM source"), data.collect().toSeq)
+
+ // the original data with part1 and part2 at the end
+ val expected = data.select("id", "data", "part1", "part2")
+
+ sql(
+ """CREATE TABLE partitioned (id bigint, data string)
+ |PARTITIONED BY (part1 string, part2 string)""".stripMargin)
+ sqlContext.table("source").write.insertInto("partitioned")
+
+ checkAnswer(sql("SELECT * FROM partitioned"), expected.collect().toSeq)
+ }
+ }
+
+ test("InsertIntoTable#resolved should include dynamic partitions") {
+ withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) {
+ sql("CREATE TABLE partitioned (id bigint, data string) PARTITIONED BY (part string)")
+ val data = (1 to 10).map(i => (i.toLong, s"data-$i")).toDF("id", "data")
+
+ val logical = InsertIntoTable(sqlContext.table("partitioned").logicalPlan,
+ Map("part" -> None), data.logicalPlan, overwrite = false, ifNotExists = false)
+ assert(!logical.resolved, "Should not resolve: missing partition data")
+ }
+ }
}
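
For context, the user-facing behaviour these new tests exercise looks roughly like the snippet below (assuming a Hive-enabled `hiveContext` as in the suite; the table and column names are illustrative): with dynamic partitioning enabled, a DataFrame whose columns include the partition column can be inserted without an explicit PARTITION clause or `partitionBy`, and the analyzer reorders the columns to match the table's partitioning.

// Rough usage sketch of dynamic-partition insertInto; not part of the patch.
import hiveContext.implicits._

hiveContext.sql("SET hive.exec.dynamic.partition.mode=nonstrict")
hiveContext.sql(
  "CREATE TABLE events (id bigint, data string) PARTITIONED BY (part string)")

val df = Seq((1L, "a", "odd"), (2L, "b", "even")).toDF("id", "data", "part")

// No PARTITION clause and no partitionBy: the table's partitioning is
// detected and `part` is treated as a dynamic partition column.
df.write.insertInto("events")
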
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index 3bf0e84267..bbb775ef77 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -978,7 +978,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
sql("SET hive.exec.dynamic.partition.mode=strict")
// Should throw when using strict dynamic partition mode without any static partition
- intercept[SparkException] {
+ intercept[AnalysisException] {
sql(
"""INSERT INTO TABLE dp_test PARTITION(dp)
|SELECT key, value, key % 5 FROM src