diff options
Diffstat (limited to 'sql/core/src')
4 files changed, 1 insertion, 99 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index ec0b3c78ed..703ea4d149 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1180,32 +1180,6 @@ class Dataset[T] private[sql]( /** * :: Experimental :: - * Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given [[Column]] - * expressions. - * - * @group typedrel - * @since 2.0.0 - */ - @Experimental - @scala.annotation.varargs - def groupByKey(cols: Column*): KeyValueGroupedDataset[Row, T] = { - val withKeyColumns = logicalPlan.output ++ cols.map(_.expr).map(UnresolvedAlias(_)) - val withKey = Project(withKeyColumns, logicalPlan) - val executed = sqlContext.executePlan(withKey) - - val dataAttributes = executed.analyzed.output.dropRight(cols.size) - val keyAttributes = executed.analyzed.output.takeRight(cols.size) - - new KeyValueGroupedDataset( - RowEncoder(keyAttributes.toStructType), - encoderFor[T], - executed, - dataAttributes, - keyAttributes) - } - - /** - * :: Experimental :: * (Java-specific) * Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given key `func`. 
* diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 18f17a85a9..86db8df4c0 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -246,29 +246,6 @@ public class JavaDatasetSuite implements Serializable { } @Test - public void testGroupByColumn() { - List<String> data = Arrays.asList("a", "foo", "bar"); - Dataset<String> ds = context.createDataset(data, Encoders.STRING()); - KeyValueGroupedDataset<Integer, String> grouped = - ds.groupByKey(length(col("value"))).keyAs(Encoders.INT()); - - Dataset<String> mapped = grouped.mapGroups( - new MapGroupsFunction<Integer, String, String>() { - @Override - public String call(Integer key, Iterator<String> data) throws Exception { - StringBuilder sb = new StringBuilder(key.toString()); - while (data.hasNext()) { - sb.append(data.next()); - } - return sb.toString(); - } - }, - Encoders.STRING()); - - Assert.assertEquals(asSet("1a", "3foobar"), toSet(mapped.collectAsList())); - } - - @Test public void testSelect() { List<Integer> data = Arrays.asList(2, 6); Dataset<Integer> ds = context.createDataset(data, Encoders.INT()); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala index 2e5179a8d2..942cc09b6d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala @@ -63,7 +63,7 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext { test("persist and then groupBy columns asKey, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() - val grouped = ds.groupByKey($"_1").keyAs[String] + val grouped = ds.groupByKey(_._1) val agged = grouped.mapGroups { case (g, iter) => (g, 
iter.map(_._2).sum) } agged.persist() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 0bcc512d71..553bc436a6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -322,55 +322,6 @@ class DatasetSuite extends QueryTest with SharedSQLContext { ) } - test("groupBy columns, map") { - val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() - val grouped = ds.groupByKey($"_1") - val agged = grouped.mapGroups { case (g, iter) => (g.getString(0), iter.map(_._2).sum) } - - checkDataset( - agged, - ("a", 30), ("b", 3), ("c", 1)) - } - - test("groupBy columns, count") { - val ds = Seq("a" -> 1, "b" -> 1, "a" -> 2).toDS() - val count = ds.groupByKey($"_1").count() - - checkDataset( - count, - (Row("a"), 2L), (Row("b"), 1L)) - } - - test("groupBy columns asKey, map") { - val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() - val grouped = ds.groupByKey($"_1").keyAs[String] - val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) } - - checkDataset( - agged, - ("a", 30), ("b", 3), ("c", 1)) - } - - test("groupBy columns asKey tuple, map") { - val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() - val grouped = ds.groupByKey($"_1", lit(1)).keyAs[(String, Int)] - val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) } - - checkDataset( - agged, - (("a", 1), 30), (("b", 1), 3), (("c", 1), 1)) - } - - test("groupBy columns asKey class, map") { - val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() - val grouped = ds.groupByKey($"_1".as("a"), lit(1).as("b")).keyAs[ClassData] - val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) } - - checkDataset( - agged, - (ClassData("a", 1), 30), (ClassData("b", 1), 3), (ClassData("c", 1), 1)) - } - test("typed 
aggregation: expr") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() |