From 1c70b7650f21fc51a07db1e4f28cebbc1fb47e94 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 24 Mar 2016 22:56:34 -0700 Subject: [SPARK-14145][SQL] Remove the untyped version of Dataset.groupByKey ## What changes were proposed in this pull request? Dataset has two variants of groupByKey, one for untyped and the other for typed. It actually doesn't make as much sense to have an untyped API here, since apps that want to use untyped APIs should just use the groupBy "DataFrame" API. ## How was this patch tested? This patch removes a method, and removes the associated tests. Author: Reynold Xin Closes #11949 from rxin/SPARK-14145. --- .../main/scala/org/apache/spark/sql/Dataset.scala | 26 ------------ .../org/apache/spark/sql/JavaDatasetSuite.java | 23 ---------- .../org/apache/spark/sql/DatasetCacheSuite.scala | 2 +- .../scala/org/apache/spark/sql/DatasetSuite.scala | 49 ---------------------- 4 files changed, 1 insertion(+), 99 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index ec0b3c78ed..703ea4d149 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1178,32 +1178,6 @@ class Dataset[T] private[sql]( withGroupingKey.newColumns) } - /** - * :: Experimental :: - * Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given [[Column]] - * expressions. - * - * @group typedrel - * @since 2.0.0 - */ - @Experimental - @scala.annotation.varargs - def groupByKey(cols: Column*): KeyValueGroupedDataset[Row, T] = { - val withKeyColumns = logicalPlan.output ++ cols.map(_.expr).map(UnresolvedAlias(_)) - val withKey = Project(withKeyColumns, logicalPlan) - val executed = sqlContext.executePlan(withKey) - - val dataAttributes = executed.analyzed.output.dropRight(cols.size) - val keyAttributes = executed.analyzed.output.takeRight(cols.size) - - new KeyValueGroupedDataset( - RowEncoder(keyAttributes.toStructType), - encoderFor[T], - executed, - dataAttributes, - keyAttributes) - } - /** * :: Experimental :: * (Java-specific) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 18f17a85a9..86db8df4c0 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -245,29 +245,6 @@ public class JavaDatasetSuite implements Serializable { Assert.assertEquals(asSet("1a#2", "3foobar#6", "5#10"), toSet(cogrouped.collectAsList())); } - @Test - public void testGroupByColumn() { - List data = Arrays.asList("a", "foo", "bar"); - Dataset ds = context.createDataset(data, Encoders.STRING()); - KeyValueGroupedDataset grouped = - ds.groupByKey(length(col("value"))).keyAs(Encoders.INT()); - - Dataset mapped = grouped.mapGroups( - new MapGroupsFunction() { - @Override - public String call(Integer key, Iterator data) throws Exception { - StringBuilder sb = new StringBuilder(key.toString()); - while (data.hasNext()) { - sb.append(data.next()); - } - return sb.toString(); - } - }, - Encoders.STRING()); - - Assert.assertEquals(asSet("1a", "3foobar"), toSet(mapped.collectAsList())); - } - @Test public void testSelect() { List data = Arrays.asList(2, 6); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala index 2e5179a8d2..942cc09b6d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala @@ -63,7 +63,7 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext { test("persist and then groupBy columns asKey, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() - val grouped = ds.groupByKey($"_1").keyAs[String] + val grouped = ds.groupByKey(_._1) val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) } agged.persist() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 0bcc512d71..553bc436a6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -322,55 +322,6 @@ class DatasetSuite extends QueryTest with SharedSQLContext { ) } - test("groupBy columns, map") { - val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() - val grouped = ds.groupByKey($"_1") - val agged = grouped.mapGroups { case (g, iter) => (g.getString(0), iter.map(_._2).sum) } - - checkDataset( - agged, - ("a", 30), ("b", 3), ("c", 1)) - } - - test("groupBy columns, count") { - val ds = Seq("a" -> 1, "b" -> 1, "a" -> 2).toDS() - val count = ds.groupByKey($"_1").count() - - checkDataset( - count, - (Row("a"), 2L), (Row("b"), 1L)) - } - - test("groupBy columns asKey, map") { - val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() - val grouped = ds.groupByKey($"_1").keyAs[String] - val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) } - - checkDataset( - agged, - ("a", 30), ("b", 3), ("c", 1)) - } - - test("groupBy columns asKey tuple, map") { - val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() - val grouped = ds.groupByKey($"_1", lit(1)).keyAs[(String, Int)] - val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) } - - checkDataset( - agged, - (("a", 1), 30), (("b", 1), 3), (("c", 1), 1)) - } - - test("groupBy columns asKey class, map") { - val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() - val grouped = ds.groupByKey($"_1".as("a"), lit(1).as("b")).keyAs[ClassData] - val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) } - - checkDataset( - agged, - (ClassData("a", 1), 30), (ClassData("b", 1), 3), (ClassData("c", 1), 1)) - } - test("typed aggregation: expr") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() -- cgit v1.2.3