-rw-r--r--  python/pyspark/sql/dataframe.py                                     21
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala        16
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala   45
3 files changed, 79 insertions, 3 deletions
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 7673153abe..03b01a1136 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1189,15 +1189,30 @@ class DataFrame(object):
@since(1.4)
@ignore_unicode_prefix
- def drop(self, colName):
+ def drop(self, col):
"""Returns a new :class:`DataFrame` that drops the specified column.
- :param colName: string, name of the column to drop.
+ :param col: a string name of the column to drop, or a
+ :class:`Column` to drop.
>>> df.drop('age').collect()
[Row(name=u'Alice'), Row(name=u'Bob')]
+
+ >>> df.drop(df.age).collect()
+ [Row(name=u'Alice'), Row(name=u'Bob')]
+
+ >>> df.join(df2, df.name == df2.name, 'inner').drop(df.name).collect()
+ [Row(age=5, height=85, name=u'Bob')]
+
+ >>> df.join(df2, df.name == df2.name, 'inner').drop(df2.name).collect()
+ [Row(age=5, name=u'Bob', height=85)]
"""
- jdf = self._jdf.drop(colName)
+ if isinstance(col, basestring):
+     jdf = self._jdf.drop(col)
+ elif isinstance(col, Column):
+     jdf = self._jdf.drop(col._jc)
+ else:
+     raise TypeError("col should be a string or a Column")
return DataFrame(jdf, self.sql_ctx)
@since(1.3)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 034d887901..d1a54ada7b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -1083,6 +1083,22 @@ class DataFrame private[sql](
}
/**
+ * Returns a new [[DataFrame]] with a column dropped.
+ * This version of drop accepts a Column rather than a name.
+ * This is a no-op if the DataFrame doesn't have a column
+ * with an equivalent expression.
+ * @group dfops
+ * @since 1.4.1
+ */
+ def drop(col: Column): DataFrame = {
+   val attrs = this.logicalPlan.output
+   val colsAfterDrop = attrs.filter { attr =>
+     attr != col.expr
+   }.map(attr => Column(attr))
+   select(colsAfterDrop : _*)
+ }
+
+ /**
* Returns a new [[DataFrame]] that contains only the unique rows from this [[DataFrame]].
* This is an alias for `distinct`.
* @group dfops
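For reference, a minimal usage sketch of the new overload (not part of the patch; it assumes a spark-shell style SparkContext named sc, and the example data and column names below are invented for illustration):

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.functions.col

val sqlContext = new SQLContext(sc)
import sqlContext.implicits._

// Two DataFrames that share the column name "name".
val people  = Seq(("Alice", 2), ("Bob", 5)).toDF("name", "age")
val heights = Seq(("Bob", 85)).toDF("name", "height")

// The inner join carries both "name" columns; the Column overload lets the
// caller say exactly which of the two attributes to remove.
val joined  = people.join(heights, people("name") === heights("name"), "inner")
val deduped = joined.drop(heights("name"))
// deduped.columns: Array("name", "age", "height") -- people("name") is kept.

// The no-op case from the Scaladoc above: drop(Column) compares the column's
// expression against the plan's resolved attributes, so a column built only
// from a name (an unresolved attribute) matches nothing, mirroring the
// "drop unknown column with same name (no-op)" test below.
val unchanged = people.drop(col("age"))
// unchanged.columns: Array("name", "age")
val dropped = people.drop(people("age"))
// dropped.columns: Array("name")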
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index b41b1b77d0..8e81dacb86 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -334,6 +334,51 @@ class DataFrameSuite extends QueryTest {
assert(df.schema.map(_.name) === Seq("key", "value"))
}
+ test("drop column using drop with column reference") {
+ val col = testData("key")
+ val df = testData.drop(col)
+ checkAnswer(
+ df,
+ testData.collect().map(x => Row(x.getString(1))).toSeq)
+ assert(df.schema.map(_.name) === Seq("value"))
+ }
+
+ test("drop unknown column (no-op) with column reference") {
+ val col = Column("random")
+ val df = testData.drop(col)
+ checkAnswer(
+ df,
+ testData.collect().toSeq)
+ assert(df.schema.map(_.name) === Seq("key", "value"))
+ }
+
+ test("drop unknown column with same name (no-op) with column reference") {
+ val col = Column("key")
+ val df = testData.drop(col)
+ checkAnswer(
+ df,
+ testData.collect().toSeq)
+ assert(df.schema.map(_.name) === Seq("key", "value"))
+ }
+
+ test("drop column after join with duplicate columns using column reference") {
+ val newSalary = salary.withColumnRenamed("personId", "id")
+ val col = newSalary("id")
+ // this join will result in duplicate "id" columns
+ val joinedDf = person.join(newSalary,
+ person("id") === newSalary("id"), "inner")
+ // remove only the "id" column that was associated with newSalary
+ val df = joinedDf.drop(col)
+ checkAnswer(
+ df,
+ joinedDf.collect().map {
+ case Row(id: Int, name: String, age: Int, idToDrop: Int, salary: Double) =>
+ Row(id, name, age, salary)
+ }.toSeq)
+ assert(df.schema.map(_.name) === Seq("id", "name", "age", "salary"))
+ assert(df("id") == person("id"))
+ }
+
test("withColumnRenamed") {
val df = testData.toDF().withColumn("newCol", col("key") + 1)
.withColumnRenamed("value", "valueRenamed")