path: root/sql
author     Wenchen Fan <wenchen@databricks.com>        2015-11-08 12:59:35 -0800
committer  Michael Armbrust <michael@databricks.com>   2015-11-08 12:59:35 -0800
commit     b2d195e137fad88d567974659fa7023ff4da96cd (patch)
tree       aed5f4a83f961c61ce617b90ec458b8a2f91ce12 /sql
parent     26739059bc39cd7aa7e0b1c16089c1cf8d8e4d7d (diff)
download   spark-b2d195e137fad88d567974659fa7023ff4da96cd.tar.gz
           spark-b2d195e137fad88d567974659fa7023ff4da96cd.tar.bz2
           spark-b2d195e137fad88d567974659fa7023ff4da96cd.zip
[SPARK-11554][SQL] add map/flatMap to GroupedDataset
Author: Wenchen Fan <wenchen@databricks.com>

Closes #9521 from cloud-fan/map.
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala |  4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala | 29
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala |  2
-rw-r--r--  sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java | 16
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala | 16
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala | 40
6 files changed, 70 insertions(+), 37 deletions(-)
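For orientation, a minimal usage sketch of the two methods this patch adds. The sample data and variable names are illustrative only and are not taken from the patch; it assumes a SQLContext named sqlContext with its implicits in scope, as in the test suites touched below.

    import sqlContext.implicits._

    val ds = Seq(("a", 10), ("a", 20), ("b", 1)).toDS()
    val grouped = ds.groupBy(v => v._1)

    // map: exactly one output element per group
    val sums = grouped.map((key, values) => (key, values.map(_._2).sum))
    // sums contains ("a", 30) and ("b", 1)

    // flatMap: zero or more output elements per group
    val keysAndCounts = grouped.flatMap((key, values) => Iterator(key, values.size.toString))
    // keysAndCounts contains "a", "2", "b", "1"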
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 09aac00a45..e151ac04ed 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -494,7 +494,7 @@ case class AppendColumn[T, U](
/** Factory for constructing new `MapGroups` nodes. */
object MapGroups {
def apply[K : Encoder, T : Encoder, U : Encoder](
- func: (K, Iterator[T]) => Iterator[U],
+ func: (K, Iterator[T]) => TraversableOnce[U],
groupingAttributes: Seq[Attribute],
child: LogicalPlan): MapGroups[K, T, U] = {
new MapGroups(
@@ -514,7 +514,7 @@ object MapGroups {
* object representation of all the rows with that key.
*/
case class MapGroups[K, T, U](
- func: (K, Iterator[T]) => Iterator[U],
+ func: (K, Iterator[T]) => TraversableOnce[U],
kEncoder: ExpressionEncoder[K],
tEncoder: ExpressionEncoder[T],
uEncoder: ExpressionEncoder[U],
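The only change in this file is the widened function type: the grouping function may now return any TraversableOnce[U] rather than an Iterator[U], which is what lets GroupedDataset.map wrap a single result and lets flatMap accept ordinary collections. A small, hypothetical type-level illustration (not part of the patch):

    // All three values conform to (Int, Iterator[String]) => TraversableOnce[String];
    // only the first conforms to the old (Int, Iterator[String]) => Iterator[String].
    val returnsIterator: (Int, Iterator[String]) => TraversableOnce[String] =
      (key, it) => Iterator(key.toString, it.mkString)
    val returnsSeq: (Int, Iterator[String]) => TraversableOnce[String] =
      (key, it) => Seq(key.toString, it.mkString)
    val returnsListOrNil: (Int, Iterator[String]) => TraversableOnce[String] =
      (key, it) => if (it.hasNext) List(it.next()) else Nil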
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index b2803d5a9a..5c3f626545 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -102,16 +102,39 @@ class GroupedDataset[K, T] private[sql](
* (for example, by calling `toList`) unless they are sure that this is possible given the memory
* constraints of their cluster.
*/
- def mapGroups[U : Encoder](f: (K, Iterator[T]) => Iterator[U]): Dataset[U] = {
+ def flatMap[U : Encoder](f: (K, Iterator[T]) => TraversableOnce[U]): Dataset[U] = {
new Dataset[U](
sqlContext,
MapGroups(f, groupingAttributes, logicalPlan))
}
- def mapGroups[U](
+ def flatMap[U](
f: JFunction2[K, JIterator[T], JIterator[U]],
encoder: Encoder[U]): Dataset[U] = {
- mapGroups((key, data) => f.call(key, data.asJava).asScala)(encoder)
+ flatMap((key, data) => f.call(key, data.asJava).asScala)(encoder)
+ }
+
+ /**
+ * Applies the given function to each group of data. For each unique group, the function will
+ * be passed the group key and an iterator that contains all of the elements in the group. The
+ * function can return an element of arbitrary type which will be returned as a new [[Dataset]].
+ *
+ * Internally, the implementation will spill to disk if any given group is too large to fit into
+ * memory. However, users must take care to avoid materializing the whole iterator for a group
+ * (for example, by calling `toList`) unless they are sure that this is possible given the memory
+ * constraints of their cluster.
+ */
+ def map[U : Encoder](f: (K, Iterator[T]) => U): Dataset[U] = {
+ val func = (key: K, it: Iterator[T]) => Iterator(f(key, it))
+ new Dataset[U](
+ sqlContext,
+ MapGroups(func, groupingAttributes, logicalPlan))
+ }
+
+ def map[U](
+ f: JFunction2[K, JIterator[T], U],
+ encoder: Encoder[U]): Dataset[U] = {
+ map((key, data) => f.call(key, data.asJava))(encoder)
}
// To ensure valid overloading.
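A rough way to read the additions above: the Scala map overload is flatMap with the result wrapped in a single-element iterator, so both lower to the same MapGroups logical node. A short illustrative sketch with hypothetical names:

    // `grouped` is a GroupedDataset[K, T], `f` has type (K, Iterator[T]) => U,
    // and an implicit Encoder[U] is in scope.
    val viaMap     = grouped.map(f)
    val viaFlatMap = grouped.flatMap((key, it) => Iterator(f(key, it)))
    // Both build MapGroups(..., groupingAttributes, logicalPlan), so the two
    // resulting Datasets contain the same rows.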
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 799650a4f7..2593b16b1c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -356,7 +356,7 @@ case class AppendColumns[T, U](
* being output.
*/
case class MapGroups[K, T, U](
- func: (K, Iterator[T]) => Iterator[U],
+ func: (K, Iterator[T]) => TraversableOnce[U],
kEncoder: ExpressionEncoder[K],
tEncoder: ExpressionEncoder[T],
uEncoder: ExpressionEncoder[U],
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index a9493d576d..0d3b1a5af5 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -170,15 +170,15 @@ public class JavaDatasetSuite implements Serializable {
}
}, e.INT());
- Dataset<String> mapped = grouped.mapGroups(
- new Function2<Integer, Iterator<String>, Iterator<String>>() {
+ Dataset<String> mapped = grouped.map(
+ new Function2<Integer, Iterator<String>, String>() {
@Override
- public Iterator<String> call(Integer key, Iterator<String> data) throws Exception {
+ public String call(Integer key, Iterator<String> data) throws Exception {
StringBuilder sb = new StringBuilder(key.toString());
while (data.hasNext()) {
sb.append(data.next());
}
- return Collections.singletonList(sb.toString()).iterator();
+ return sb.toString();
}
},
e.STRING());
@@ -224,15 +224,15 @@ public class JavaDatasetSuite implements Serializable {
Dataset<String> ds = context.createDataset(data, e.STRING());
GroupedDataset<Integer, String> grouped = ds.groupBy(length(col("value"))).asKey(e.INT());
- Dataset<String> mapped = grouped.mapGroups(
- new Function2<Integer, Iterator<String>, Iterator<String>>() {
+ Dataset<String> mapped = grouped.map(
+ new Function2<Integer, Iterator<String>, String>() {
@Override
- public Iterator<String> call(Integer key, Iterator<String> data) throws Exception {
+ public String call(Integer key, Iterator<String> data) throws Exception {
StringBuilder sb = new StringBuilder(key.toString());
while (data.hasNext()) {
sb.append(data.next());
}
- return Collections.singletonList(sb.toString()).iterator();
+ return sb.toString();
}
},
e.STRING());
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
index e3b0346f85..fcf03f7180 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
@@ -88,16 +88,26 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
0, 1)
}
- test("groupBy function, mapGroups") {
+ test("groupBy function, map") {
val ds = Seq(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11).toDS()
val grouped = ds.groupBy(_ % 2)
- val agged = grouped.mapGroups { case (g, iter) =>
+ val agged = grouped.map { case (g, iter) =>
val name = if (g == 0) "even" else "odd"
- Iterator((name, iter.size))
+ (name, iter.size)
}
checkAnswer(
agged,
("even", 5), ("odd", 6))
}
+
+ test("groupBy function, flatMap") {
+ val ds = Seq("a", "b", "c", "xyz", "hello").toDS()
+ val grouped = ds.groupBy(_.length)
+ val agged = grouped.flatMap { case (g, iter) => Iterator(g.toString, iter.mkString) }
+
+ checkAnswer(
+ agged,
+ "1", "abc", "3", "xyz", "5", "hello")
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index d61e17edc6..6f1174e657 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -198,60 +198,60 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
(1, 1))
}
- test("groupBy function, mapGroups") {
+ test("groupBy function, map") {
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
val grouped = ds.groupBy(v => (v._1, "word"))
- val agged = grouped.mapGroups { case (g, iter) =>
- Iterator((g._1, iter.map(_._2).sum))
- }
+ val agged = grouped.map { case (g, iter) => (g._1, iter.map(_._2).sum) }
checkAnswer(
agged,
("a", 30), ("b", 3), ("c", 1))
}
- test("groupBy columns, mapGroups") {
+ test("groupBy function, fatMap") {
+ val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+ val grouped = ds.groupBy(v => (v._1, "word"))
+ val agged = grouped.flatMap { case (g, iter) => Iterator(g._1, iter.map(_._2).sum.toString) }
+
+ checkAnswer(
+ agged,
+ "a", "30", "b", "3", "c", "1")
+ }
+
+ test("groupBy columns, map") {
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
val grouped = ds.groupBy($"_1")
- val agged = grouped.mapGroups { case (g, iter) =>
- Iterator((g.getString(0), iter.map(_._2).sum))
- }
+ val agged = grouped.map { case (g, iter) => (g.getString(0), iter.map(_._2).sum) }
checkAnswer(
agged,
("a", 30), ("b", 3), ("c", 1))
}
- test("groupBy columns asKey, mapGroups") {
+ test("groupBy columns asKey, map") {
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
val grouped = ds.groupBy($"_1").asKey[String]
- val agged = grouped.mapGroups { case (g, iter) =>
- Iterator((g, iter.map(_._2).sum))
- }
+ val agged = grouped.map { case (g, iter) => (g, iter.map(_._2).sum) }
checkAnswer(
agged,
("a", 30), ("b", 3), ("c", 1))
}
- test("groupBy columns asKey tuple, mapGroups") {
+ test("groupBy columns asKey tuple, map") {
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
val grouped = ds.groupBy($"_1", lit(1)).asKey[(String, Int)]
- val agged = grouped.mapGroups { case (g, iter) =>
- Iterator((g, iter.map(_._2).sum))
- }
+ val agged = grouped.map { case (g, iter) => (g, iter.map(_._2).sum) }
checkAnswer(
agged,
(("a", 1), 30), (("b", 1), 3), (("c", 1), 1))
}
- test("groupBy columns asKey class, mapGroups") {
+ test("groupBy columns asKey class, map") {
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
val grouped = ds.groupBy($"_1".as("a"), lit(1).as("b")).asKey[ClassData]
- val agged = grouped.mapGroups { case (g, iter) =>
- Iterator((g, iter.map(_._2).sum))
- }
+ val agged = grouped.map { case (g, iter) => (g, iter.map(_._2).sum) }
checkAnswer(
agged,