diff options
author | Dongjoon Hyun <dongjoon@apache.org> | 2016-06-11 15:47:51 -0700 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2016-06-11 15:47:51 -0700 |
commit | 3fd2ff4dd85633af49865456a52bf0c09c99708b (patch) | |
tree | 79e54d89347a1e8cea18aab8654c3fe99109f023 /sql/core | |
parent | c06c58bbbb2de0c22cfc70c486d23a94c3079ba4 (diff) | |
download | spark-3fd2ff4dd85633af49865456a52bf0c09c99708b.tar.gz spark-3fd2ff4dd85633af49865456a52bf0c09c99708b.tar.bz2 spark-3fd2ff4dd85633af49865456a52bf0c09c99708b.zip |
[SPARK-15807][SQL] Support varargs for dropDuplicates in Dataset/DataFrame
## What changes were proposed in this pull request?
This PR adds a `varargs`-typed `dropDuplicates` function to `Dataset`/`DataFrame`. Currently, `dropDuplicates` supports only `Seq` or `Array` arguments.
**Before**
```scala
scala> val ds = spark.createDataFrame(Seq(("a", 1), ("b", 2), ("a", 2)))
ds: org.apache.spark.sql.DataFrame = [_1: string, _2: int]
scala> ds.dropDuplicates(Seq("_1", "_2"))
res0: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [_1: string, _2: int]
scala> ds.dropDuplicates("_1", "_2")
<console>:26: error: overloaded method value dropDuplicates with alternatives:
(colNames: Array[String])org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] <and>
(colNames: Seq[String])org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] <and>
()org.apache.spark.sql.Dataset[org.apache.spark.sql.Row]
cannot be applied to (String, String)
ds.dropDuplicates("_1", "_2")
^
```
**After**
```scala
scala> val ds = spark.createDataFrame(Seq(("a", 1), ("b", 2), ("a", 2)))
ds: org.apache.spark.sql.DataFrame = [_1: string, _2: int]
scala> ds.dropDuplicates("_1", "_2")
res0: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [_1: string, _2: int]
```
## How was this patch tested?
Passes the Jenkins tests with new test cases.
Author: Dongjoon Hyun <dongjoon@apache.org>
Closes #13545 from dongjoon-hyun/SPARK-15807.
Diffstat (limited to 'sql/core')
3 files changed, 30 insertions, 0 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 16bbf30a94..5a67fc79ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1834,6 +1834,19 @@ class Dataset[T] private[sql]( def dropDuplicates(colNames: Array[String]): Dataset[T] = dropDuplicates(colNames.toSeq) /** + * Returns a new [[Dataset]] with duplicate rows removed, considering only + * the subset of columns. + * + * @group typedrel + * @since 2.0.0 + */ + @scala.annotation.varargs + def dropDuplicates(col1: String, cols: String*): Dataset[T] = { + val colNames: Seq[String] = col1 +: cols + dropDuplicates(colNames) + } + + /** * Computes statistics for numeric columns, including count, mean, stddev, min, and max. * If no columns are given, this function computes statistics for all numerical columns. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index a02e48d849..6bb0ce95c4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -906,6 +906,10 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { checkAnswer( testData.dropDuplicates(Seq("value2")), Seq(Row(2, 1, 2), Row(1, 1, 1))) + + checkAnswer( + testData.dropDuplicates("key", "value1"), + Seq(Row(2, 1, 2), Row(1, 2, 1), Row(1, 1, 1), Row(2, 2, 2))) } test("SPARK-7150 range api") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 11b52bdead..4536a7356f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -806,6 +806,19 @@ class DatasetSuite extends QueryTest with 
SharedSQLContext { assert(e.getMessage.contains("Null value appeared in non-nullable field")) assert(e.getMessage.contains("top level non-flat input object")) } + + test("dropDuplicates") { + val ds = Seq(("a", 1), ("a", 2), ("b", 1), ("a", 1)).toDS() + checkDataset( + ds.dropDuplicates("_1"), + ("a", 1), ("b", 1)) + checkDataset( + ds.dropDuplicates("_2"), + ("a", 1), ("a", 2)) + checkDataset( + ds.dropDuplicates("_1", "_2"), + ("a", 1), ("a", 2), ("b", 1)) + } } case class Generic[T](id: T, value: Double) |