about | summary | refs | log | tree | commit | diff
path: root/sql
diff options
context:
space:
mode:
author: Xiu Guo <xguo27@gmail.com> 2016-01-04 12:34:04 -0800
committer: Michael Armbrust <michael@databricks.com> 2016-01-04 12:34:04 -0800
commit: 573ac55d7469ea2ea7a5979b4d3eea99c98f6560 (patch)
tree: 02110ab48e2aae29a6873a83e726e98acf39b68c /sql
parent: 43706bf8bdfe08010bb11848788e0718d15363b3 (diff)
download: spark-573ac55d7469ea2ea7a5979b4d3eea99c98f6560.tar.gz
download: spark-573ac55d7469ea2ea7a5979b4d3eea99c98f6560.tar.bz2
download: spark-573ac55d7469ea2ea7a5979b4d3eea99c98f6560.zip
[SPARK-12512][SQL] support column name with dot in withColumn()
Author: Xiu Guo <xguo27@gmail.com>

Closes #10500 from xguo27/SPARK-12512.
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala      | 32
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala |  7
2 files changed, 27 insertions, 12 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 965eaa9efe..0763aa4ed9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -1171,13 +1171,17 @@ class DataFrame private[sql](
*/
def withColumn(colName: String, col: Column): DataFrame = {
val resolver = sqlContext.analyzer.resolver
- val replaced = schema.exists(f => resolver(f.name, colName))
- if (replaced) {
- val colNames = schema.map { field =>
- val name = field.name
- if (resolver(name, colName)) col.as(colName) else Column(name)
+ val output = queryExecution.analyzed.output
+ val shouldReplace = output.exists(f => resolver(f.name, colName))
+ if (shouldReplace) {
+ val columns = output.map { field =>
+ if (resolver(field.name, colName)) {
+ col.as(colName)
+ } else {
+ Column(field)
+ }
}
- select(colNames : _*)
+ select(columns : _*)
} else {
select(Column("*"), col.as(colName))
}
@@ -1188,13 +1192,17 @@ class DataFrame private[sql](
*/
private[spark] def withColumn(colName: String, col: Column, metadata: Metadata): DataFrame = {
val resolver = sqlContext.analyzer.resolver
- val replaced = schema.exists(f => resolver(f.name, colName))
- if (replaced) {
- val colNames = schema.map { field =>
- val name = field.name
- if (resolver(name, colName)) col.as(colName, metadata) else Column(name)
+ val output = queryExecution.analyzed.output
+ val shouldReplace = output.exists(f => resolver(f.name, colName))
+ if (shouldReplace) {
+ val columns = output.map { field =>
+ if (resolver(field.name, colName)) {
+ col.as(colName, metadata)
+ } else {
+ Column(field)
+ }
}
- select(colNames : _*)
+ select(columns : _*)
} else {
select(Column("*"), col.as(colName, metadata))
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index ad478b0511..ab02b32f91 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1221,4 +1221,11 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
" _2: bigint ... 2 more fields> ... 2 more fields> ... 2 more fields]")
}
+
+ test("SPARK-12512: support `.` in column name for withColumn()") {
+ val df = Seq("a" -> "b").toDF("col.a", "col.b")
+ checkAnswer(df.select(df("*")), Row("a", "b"))
+ checkAnswer(df.withColumn("col.a", lit("c")), Row("c", "b"))
+ checkAnswer(df.withColumn("col.c", lit("c")), Row("a", "b", "c"))
+ }
}