aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDongjoon Hyun <dongjoon@apache.org>2016-06-11 15:47:51 -0700
committerReynold Xin <rxin@databricks.com>2016-06-11 15:47:51 -0700
commit3fd2ff4dd85633af49865456a52bf0c09c99708b (patch)
tree79e54d89347a1e8cea18aab8654c3fe99109f023
parentc06c58bbbb2de0c22cfc70c486d23a94c3079ba4 (diff)
downloadspark-3fd2ff4dd85633af49865456a52bf0c09c99708b.tar.gz
spark-3fd2ff4dd85633af49865456a52bf0c09c99708b.tar.bz2
spark-3fd2ff4dd85633af49865456a52bf0c09c99708b.zip
[SPARK-15807][SQL] Support varargs for dropDuplicates in Dataset/DataFrame
## What changes were proposed in this pull request? This PR adds `varargs`-types `dropDuplicates` functions in `Dataset/DataFrame`. Currently, `dropDuplicates` supports only `Seq` or `Array`. **Before** ```scala scala> val ds = spark.createDataFrame(Seq(("a", 1), ("b", 2), ("a", 2))) ds: org.apache.spark.sql.DataFrame = [_1: string, _2: int] scala> ds.dropDuplicates(Seq("_1", "_2")) res0: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [_1: string, _2: int] scala> ds.dropDuplicates("_1", "_2") <console>:26: error: overloaded method value dropDuplicates with alternatives: (colNames: Array[String])org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] <and> (colNames: Seq[String])org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] <and> ()org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] cannot be applied to (String, String) ds.dropDuplicates("_1", "_2") ^ ``` **After** ```scala scala> val ds = spark.createDataFrame(Seq(("a", 1), ("b", 2), ("a", 2))) ds: org.apache.spark.sql.DataFrame = [_1: string, _2: int] scala> ds.dropDuplicates("_1", "_2") res0: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [_1: string, _2: int] ``` ## How was this patch tested? Pass the Jenkins tests with new testcases. Author: Dongjoon Hyun <dongjoon@apache.org> Closes #13545 from dongjoon-hyun/SPARK-15807.
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala13
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala4
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala13
3 files changed, 30 insertions, 0 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 16bbf30a94..5a67fc79ef 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1834,6 +1834,19 @@ class Dataset[T] private[sql](
def dropDuplicates(colNames: Array[String]): Dataset[T] = dropDuplicates(colNames.toSeq)
/**
+ * Returns a new [[Dataset]] with duplicate rows removed, considering only
+ * the subset of columns.
+ *
+ * @group typedrel
+ * @since 2.0.0
+ */
+ @scala.annotation.varargs
+ def dropDuplicates(col1: String, cols: String*): Dataset[T] = {
+ val colNames: Seq[String] = col1 +: cols
+ dropDuplicates(colNames)
+ }
+
+ /**
* Computes statistics for numeric columns, including count, mean, stddev, min, and max.
* If no columns are given, this function computes statistics for all numerical columns.
*
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index a02e48d849..6bb0ce95c4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -906,6 +906,10 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
checkAnswer(
testData.dropDuplicates(Seq("value2")),
Seq(Row(2, 1, 2), Row(1, 1, 1)))
+
+ checkAnswer(
+ testData.dropDuplicates("key", "value1"),
+ Seq(Row(2, 1, 2), Row(1, 2, 1), Row(1, 1, 1), Row(2, 2, 2)))
}
test("SPARK-7150 range api") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 11b52bdead..4536a7356f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -806,6 +806,19 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
assert(e.getMessage.contains("Null value appeared in non-nullable field"))
assert(e.getMessage.contains("top level non-flat input object"))
}
+
+ test("dropDuplicates") {
+ val ds = Seq(("a", 1), ("a", 2), ("b", 1), ("a", 1)).toDS()
+ checkDataset(
+ ds.dropDuplicates("_1"),
+ ("a", 1), ("b", 1))
+ checkDataset(
+ ds.dropDuplicates("_2"),
+ ("a", 1), ("a", 2))
+ checkDataset(
+ ds.dropDuplicates("_1", "_2"),
+ ("a", 1), ("a", 2), ("b", 1))
+ }
}
case class Generic[T](id: T, value: Double)