From 7dc4965f34e37b37f4fab69859fcce6476f87811 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 5 Feb 2015 23:02:40 -0800 Subject: [SPARK-5639][SQL] Support DataFrame.renameColumn. Author: Reynold Xin Closes #4410 from rxin/df-renameCol and squashes the following commits: a6a796e [Reynold Xin] [SPARK-5639][SQL] Support DataFrame.renameColumn. --- .../main/scala/org/apache/spark/sql/DataFrame.scala | 9 ++++++++- .../scala/org/apache/spark/sql/DataFrameImpl.scala | 8 ++++++++ .../org/apache/spark/sql/IncomputableColumn.scala | 2 ++ .../scala/org/apache/spark/sql/DataFrameSuite.scala | 21 +++++++++++++++++++++ 4 files changed, 39 insertions(+), 1 deletion(-) (limited to 'sql/core') diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 92e04ce17c..8ad6526f87 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -36,7 +36,8 @@ private[sql] object DataFrame { /** - * A collection of rows that have the same columns. + * :: Experimental :: + * A distributed collection of data organized into named columns. * * A [[DataFrame]] is equivalent to a relational table in Spark SQL, and can be created using * various functions in [[SQLContext]]. @@ -72,6 +73,7 @@ private[sql] object DataFrame { * }}} */ // TODO: Improve documentation. +@Experimental trait DataFrame extends RDDApi[Row] { val sqlContext: SQLContext @@ -425,6 +427,11 @@ trait DataFrame extends RDDApi[Row] { */ def addColumn(colName: String, col: Column): DataFrame + /** + * Returns a new [[DataFrame]] with a column renamed. + */ + def renameColumn(existingName: String, newName: String): DataFrame + /** * Returns the first `n` rows. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala index 4911443dd6..789bcf6184 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala @@ -238,6 +238,14 @@ private[sql] class DataFrameImpl protected[sql]( select(Column("*"), col.as(colName)) } + override def renameColumn(existingName: String, newName: String): DataFrame = { + val colNames = schema.map { field => + val name = field.name + if (name == existingName) Column(name).as(newName) else Column(name) + } + select(colNames :_*) + } + override def head(n: Int): Array[Row] = limit(n).collect() override def head(): Row = head(1).head diff --git a/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala b/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala index fedd7f06ef..6043fb4dee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala @@ -108,6 +108,8 @@ private[sql] class IncomputableColumn(protected[sql] val expr: Expression) exten override def addColumn(colName: String, col: Column): DataFrame = err() + override def renameColumn(existingName: String, newName: String): DataFrame = err() + override def head(n: Int): Array[Row] = err() override def head(): Row = err() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 77fd3165f1..5aa3db720c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -313,6 +313,27 @@ class DataFrameSuite extends QueryTest { ) } + test("addColumn") { + val df = testData.toDataFrame.addColumn("newCol", col("key") + 1) + checkAnswer( + df, + testData.collect().map { case Row(key: Int, value: String) => + Row(key, value, key + 1) + }.toSeq) + assert(df.schema.map(_.name).toSeq === Seq("key", "value", "newCol")) + } + + test("renameColumn") { + val df = testData.toDataFrame.addColumn("newCol", col("key") + 1) + .renameColumn("value", "valueRenamed") + checkAnswer( + df, + testData.collect().map { case Row(key: Int, value: String) => + Row(key, value, key + 1) + }.toSeq) + assert(df.schema.map(_.name).toSeq === Seq("key", "valueRenamed", "newCol")) + } + test("apply on query results (SPARK-5462)") { val df = testData.sqlContext.sql("select key from testData") checkAnswer(df("key"), testData.select('key).collect().toSeq) -- cgit v1.2.3