author    Wenchen Fan <wenchen@databricks.com>    2015-11-08 21:01:53 -0800
committer Yin Huai <yhuai@databricks.com>         2015-11-08 21:01:53 -0800
commit    d8b50f70298dbf45e91074ee2d751fee7eecb119 (patch)
tree      ad2b1418e3684630bd0ac18349e9c559bbf4782c /sql
parent    97b7080cf2d2846c7257f8926f775f27d457fe7d (diff)
[SPARK-11453][SQL] Appending data to a partitioned table messes up the result
The reason is that:

1. For a partitioned Hive table, we move the partition columns after the data columns (e.g. `<a: Int, b: Int>` partitioned by `a` becomes `<b: Int, a: Int>`).
2. When appending data to a table, we match the input columns to the table's columns by position.

So when we append data to a partitioned table, we match the wrong columns between the input and the table. A solution is to reorder the input columns before matching by position, like what we did for [`InsertIntoHadoopFsRelation`](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala#L101-L105).

Author: Wenchen Fan <wenchen@databricks.com>

Closes #9408 from cloud-fan/append.
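To make the positional-matching pitfall concrete, here is a minimal, self-contained Scala sketch of the reordering idea. The `Column` case class and `reorderForAppend` helper are hypothetical illustrations, not Spark APIs; they only mirror what the `Project` added in the patch below does.

```scala
// Minimal sketch (assumed names, not Spark internals): reorder input columns
// so partition columns come last, matching a partitioned table's layout.
object AppendReorderSketch {
  case class Column(name: String, value: Any)

  // A partitioned table stores partition columns after the data columns,
  // e.g. a table <a: Int, b: Int> partitioned by `a` is laid out as <b, a>.
  // If the incoming rows keep the original order <a, b> and we match by
  // position, `a` lands in `b` and vice versa.
  def reorderForAppend(inputCols: Seq[Column], partitionCols: Set[String]): Seq[Column] = {
    val (partCols, dataCols) = inputCols.partition(c => partitionCols.contains(c.name))
    // Move partition columns to the end so positional matching lines up
    // with the table's layout.
    dataCols ++ partCols
  }

  def main(args: Array[String]): Unit = {
    val input = Seq(Column("a", 1), Column("b", 2)) // input order <a, b>
    val fixed = reorderForAppend(input, Set("a"))   // table layout <b, a>
    println(fixed.map(_.name).mkString(", "))       // prints: b, a
  }
}
```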
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala               | 29
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala  |  8
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala   | 20
3 files changed, 53 insertions(+), 4 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 7887e559a3..e63a4d5e8b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -23,8 +23,8 @@ import scala.collection.JavaConverters._
import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.catalyst.{SqlParser, TableIdentifier}
-import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
-import org.apache.spark.sql.catalyst.plans.logical.InsertIntoTable
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation}
+import org.apache.spark.sql.catalyst.plans.logical.{Project, InsertIntoTable}
import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, ResolvedDataSource}
import org.apache.spark.sql.sources.HadoopFsRelation
@@ -167,17 +167,38 @@ final class DataFrameWriter private[sql](df: DataFrame) {
}
private def insertInto(tableIdent: TableIdentifier): Unit = {
- val partitions = partitioningColumns.map(_.map(col => col -> (None: Option[String])).toMap)
+ val partitions = normalizedParCols.map(_.map(col => col -> (None: Option[String])).toMap)
val overwrite = mode == SaveMode.Overwrite
+
+ // A partitioned relation's schema can be different from the input logicalPlan, since
+ // partition columns are all moved after data columns. We Project to adjust the ordering.
+ // TODO: this belongs to the analyzer.
+ val input = normalizedParCols.map { parCols =>
+ val (inputPartCols, inputDataCols) = df.logicalPlan.output.partition { attr =>
+ parCols.contains(attr.name)
+ }
+ Project(inputDataCols ++ inputPartCols, df.logicalPlan)
+ }.getOrElse(df.logicalPlan)
+
df.sqlContext.executePlan(
InsertIntoTable(
UnresolvedRelation(tableIdent),
partitions.getOrElse(Map.empty[String, Option[String]]),
- df.logicalPlan,
+ input,
overwrite,
ifNotExists = false)).toRdd
}
+ private def normalizedParCols: Option[Seq[String]] = partitioningColumns.map { parCols =>
+ parCols.map { col =>
+ df.logicalPlan.output
+ .map(_.name)
+ .find(df.sqlContext.analyzer.resolver(_, col))
+ .getOrElse(throw new AnalysisException(s"Partition column $col not found in existing " +
+ s"columns (${df.logicalPlan.output.map(_.name).mkString(", ")})"))
+ }
+ }
+
/**
* Saves the content of the [[DataFrame]] as the specified table.
*
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala
index c9791879ec..3eaa817f9c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala
@@ -53,4 +53,12 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext {
Utils.deleteRecursively(path)
}
+
+ test("partitioned columns should appear at the end of schema") {
+ withTempPath { f =>
+ val path = f.getAbsolutePath
+ Seq(1 -> "a").toDF("i", "j").write.partitionBy("i").parquet(path)
+ assert(sqlContext.read.parquet(path).schema.map(_.name) == Seq("j", "i"))
+ }
+ }
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index af48d47895..9a425d7f6b 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -1428,4 +1428,24 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
checkAnswer(sql("SELECT val FROM tbl10562 WHERE Year == 2012"), Row("a"))
}
}
+
+ test("SPARK-11453: append data to partitioned table") {
+ withTable("tbl11453") {
+ Seq("1" -> "10", "2" -> "20").toDF("i", "j")
+ .write.partitionBy("i").saveAsTable("tbl11453")
+
+ Seq("3" -> "30").toDF("i", "j")
+ .write.mode(SaveMode.Append).partitionBy("i").saveAsTable("tbl11453")
+ checkAnswer(
+ sqlContext.read.table("tbl11453").select("i", "j").orderBy("i"),
+ Row("1", "10") :: Row("2", "20") :: Row("3", "30") :: Nil)
+
+ // make sure case sensitivity is correct.
+ Seq("4" -> "40").toDF("i", "j")
+ .write.mode(SaveMode.Append).partitionBy("I").saveAsTable("tbl11453")
+ checkAnswer(
+ sqlContext.read.table("tbl11453").select("i", "j").orderBy("i"),
+ Row("1", "10") :: Row("2", "20") :: Row("3", "30") :: Row("4", "40") :: Nil)
+ }
+ }
}