author     Reynold Xin <rxin@databricks.com>  2015-11-21 15:00:37 -0800
committer  Reynold Xin <rxin@databricks.com>  2015-11-21 15:00:37 -0800
commit     ff442bbcffd4f93cfcc2f76d160011e725d2fb3f (patch)
tree       9a0a5756f29de2f3021cf2c0c9ac5aae6bc7e7e2
parent     596710268e29e8f624c3ba2fade08b66ec7084eb (diff)
[SPARK-11899][SQL] API audit for GroupedDataset.
1. Renamed map to mapGroup, flatMap to flatMapGroup.
2. Renamed asKey -> keyAs.
3. Added more documentation.
4. Changed type parameter T to V on GroupedDataset.
5. Added since versions for all functions.

Author: Reynold Xin <rxin@databricks.com>

Closes #9880 from rxin/SPARK-11899.
-rw-r--r--  core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java       2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala                     4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala  3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Column.scala                          2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala                       1
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala                  132
-rw-r--r--  sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java             8
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala           4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala                    20
9 files changed, 131 insertions, 45 deletions
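
For orientation, a minimal Scala sketch of how the renamed typed API reads after this patch. The method names (keyAs, mapGroup, flatMapGroup) and the grouping calls mirror the test changes in the diff below; the sample data, the variable names, and the assumption of a SQLContext whose implicits are imported are illustrative only.

import sqlContext.implicits._

val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()

// keyAs (formerly asKey) maps the grouping columns to a typed key.
val grouped = ds.groupBy($"_1").keyAs[String]

// mapGroup (formerly map) emits exactly one record per key.
val sums = grouped.mapGroup { case (key, values) => (key, values.map(_._2).sum) }

// flatMapGroup (formerly flatMap) may emit zero or more records per key.
val expanded = grouped.flatMapGroup { case (key, values) =>
  Iterator(key, values.map(_._2).sum.toString)
}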
diff --git a/core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java b/core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java
index 2935f9986a..4f3f222e06 100644
--- a/core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java
+++ b/core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java
@@ -21,7 +21,7 @@ import java.io.Serializable;
import java.util.Iterator;
/**
- * Base interface for a map function used in GroupedDataset's map function.
+ * Base interface for a map function used in GroupedDataset's mapGroup function.
*/
public interface MapGroupFunction<K, V, R> extends Serializable {
R call(K key, Iterator<V> values) throws Exception;
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
index 5cb8edf64e..03aa25eda8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
@@ -30,6 +30,8 @@ import org.apache.spark.sql.types._
*
* Encoders are not intended to be thread-safe and thus they are allowed to avoid internal locking
* and reuse internal buffers to improve performance.
+ *
+ * @since 1.6.0
*/
trait Encoder[T] extends Serializable {
@@ -42,6 +44,8 @@ trait Encoder[T] extends Serializable {
/**
* Methods for creating encoders.
+ *
+ * @since 1.6.0
*/
object Encoders {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 88a457f87c..7d4cfbe6fa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.types._
/**
* Type-inference utilities for POJOs and Java collections.
*/
-private [sql] object JavaTypeInference {
+object JavaTypeInference {
private val iterableType = TypeToken.of(classOf[JIterable[_]])
private val mapType = TypeToken.of(classOf[JMap[_, _]])
@@ -53,7 +53,6 @@ private [sql] object JavaTypeInference {
* @return (SQL data type, nullable)
*/
private def inferDataType(typeToken: TypeToken[_]): (DataType, Boolean) = {
- // TODO: All of this could probably be moved to Catalyst as it is mostly not Spark specific.
typeToken.getRawType match {
case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) =>
(c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 82e9cd7f50..30c554a85e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -46,6 +46,8 @@ private[sql] object Column {
* @tparam T The input type expected for this expression. Can be `Any` if the expression is type
* checked by the analyzer instead of the compiler (i.e. `expr("sum(...)")`).
* @tparam U The output type of this column.
+ *
+ * @since 1.6.0
*/
class TypedColumn[-T, U](
expr: Expression,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 7abcecaa28..5586fc994b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -110,7 +110,6 @@ private[sql] object DataFrame {
* @groupname action Actions
* @since 1.3.0
*/
-// TODO: Improve documentation.
@Experimental
class DataFrame private[sql](
@transient val sqlContext: SQLContext,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index 263f049104..7f43ce1690 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor, Ou
import org.apache.spark.sql.catalyst.expressions.{Alias, CreateStruct, Attribute}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.expressions.Aggregator
/**
* :: Experimental ::
@@ -36,11 +37,13 @@ import org.apache.spark.sql.execution.QueryExecution
* making this change to the class hierarchy would break some function signatures. As such, this
* class should be considered a preview of the final API. Changes will be made to the interface
* after Spark 1.6.
+ *
+ * @since 1.6.0
*/
@Experimental
-class GroupedDataset[K, T] private[sql](
+class GroupedDataset[K, V] private[sql](
kEncoder: Encoder[K],
- tEncoder: Encoder[T],
+ tEncoder: Encoder[V],
val queryExecution: QueryExecution,
private val dataAttributes: Seq[Attribute],
private val groupingAttributes: Seq[Attribute]) extends Serializable {
@@ -67,8 +70,10 @@ class GroupedDataset[K, T] private[sql](
/**
* Returns a new [[GroupedDataset]] where the type of the key has been mapped to the specified
* type. The mapping of key columns to the type follows the same rules as `as` on [[Dataset]].
+ *
+ * @since 1.6.0
*/
- def asKey[L : Encoder]: GroupedDataset[L, T] =
+ def keyAs[L : Encoder]: GroupedDataset[L, V] =
new GroupedDataset(
encoderFor[L],
unresolvedTEncoder,
@@ -78,6 +83,8 @@ class GroupedDataset[K, T] private[sql](
/**
* Returns a [[Dataset]] that contains each unique key.
+ *
+ * @since 1.6.0
*/
def keys: Dataset[K] = {
new Dataset[K](
@@ -92,12 +99,18 @@ class GroupedDataset[K, T] private[sql](
* function can return an iterator containing elements of an arbitrary type which will be returned
* as a new [[Dataset]].
*
+ * This function does not support partial aggregation, and as a result requires shuffling all
+ * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+ * key, it is best to use the reduce function or an [[Aggregator]].
+ *
* Internally, the implementation will spill to disk if any given group is too large to fit into
* memory. However, users must take care to avoid materializing the whole iterator for a group
* (for example, by calling `toList`) unless they are sure that this is possible given the memory
* constraints of their cluster.
+ *
+ * @since 1.6.0
*/
- def flatMap[U : Encoder](f: (K, Iterator[T]) => TraversableOnce[U]): Dataset[U] = {
+ def flatMapGroup[U : Encoder](f: (K, Iterator[V]) => TraversableOnce[U]): Dataset[U] = {
new Dataset[U](
sqlContext,
MapGroups(
@@ -108,8 +121,25 @@ class GroupedDataset[K, T] private[sql](
logicalPlan))
}
- def flatMap[U](f: FlatMapGroupFunction[K, T, U], encoder: Encoder[U]): Dataset[U] = {
- flatMap((key, data) => f.call(key, data.asJava).asScala)(encoder)
+ /**
+ * Applies the given function to each group of data. For each unique group, the function will
+ * be passed the group key and an iterator that contains all of the elements in the group. The
+ * function can return an iterator containing elements of an arbitrary type which will be returned
+ * as a new [[Dataset]].
+ *
+ * This function does not support partial aggregation, and as a result requires shuffling all
+ * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+ * key, it is best to use the reduce function or an [[Aggregator]].
+ *
+ * Internally, the implementation will spill to disk if any given group is too large to fit into
+ * memory. However, users must take care to avoid materializing the whole iterator for a group
+ * (for example, by calling `toList`) unless they are sure that this is possible given the memory
+ * constraints of their cluster.
+ *
+ * @since 1.6.0
+ */
+ def flatMapGroup[U](f: FlatMapGroupFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = {
+ flatMapGroup((key, data) => f.call(key, data.asJava).asScala)(encoder)
}
/**
@@ -117,32 +147,62 @@ class GroupedDataset[K, T] private[sql](
* be passed the group key and an iterator that contains all of the elements in the group. The
* function can return an element of arbitrary type which will be returned as a new [[Dataset]].
*
+ * This function does not support partial aggregation, and as a result requires shuffling all
+ * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+ * key, it is best to use the reduce function or an [[Aggregator]].
+ *
* Internally, the implementation will spill to disk if any given group is too large to fit into
* memory. However, users must take care to avoid materializing the whole iterator for a group
* (for example, by calling `toList`) unless they are sure that this is possible given the memory
* constraints of their cluster.
+ *
+ * @since 1.6.0
*/
- def map[U : Encoder](f: (K, Iterator[T]) => U): Dataset[U] = {
- val func = (key: K, it: Iterator[T]) => Iterator(f(key, it))
- flatMap(func)
+ def mapGroup[U : Encoder](f: (K, Iterator[V]) => U): Dataset[U] = {
+ val func = (key: K, it: Iterator[V]) => Iterator(f(key, it))
+ flatMapGroup(func)
}
- def map[U](f: MapGroupFunction[K, T, U], encoder: Encoder[U]): Dataset[U] = {
- map((key, data) => f.call(key, data.asJava))(encoder)
+ /**
+ * Applies the given function to each group of data. For each unique group, the function will
+ * be passed the group key and an iterator that contains all of the elements in the group. The
+ * function can return an element of arbitrary type which will be returned as a new [[Dataset]].
+ *
+ * This function does not support partial aggregation, and as a result requires shuffling all
+ * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+ * key, it is best to use the reduce function or an [[Aggregator]].
+ *
+ * Internally, the implementation will spill to disk if any given group is too large to fit into
+ * memory. However, users must take care to avoid materializing the whole iterator for a group
+ * (for example, by calling `toList`) unless they are sure that this is possible given the memory
+ * constraints of their cluster.
+ *
+ * @since 1.6.0
+ */
+ def mapGroup[U](f: MapGroupFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = {
+ mapGroup((key, data) => f.call(key, data.asJava))(encoder)
}
/**
* Reduces the elements of each group of data using the specified binary function.
* The given function must be commutative and associative or the result may be non-deterministic.
+ *
+ * @since 1.6.0
*/
- def reduce(f: (T, T) => T): Dataset[(K, T)] = {
- val func = (key: K, it: Iterator[T]) => Iterator(key -> it.reduce(f))
+ def reduce(f: (V, V) => V): Dataset[(K, V)] = {
+ val func = (key: K, it: Iterator[V]) => Iterator((key, it.reduce(f)))
implicit val resultEncoder = ExpressionEncoder.tuple(unresolvedKEncoder, unresolvedTEncoder)
- flatMap(func)
+ flatMapGroup(func)
}
- def reduce(f: ReduceFunction[T]): Dataset[(K, T)] = {
+ /**
+ * Reduces the elements of each group of data using the specified binary function.
+ * The given function must be commutative and associative or the result may be non-deterministic.
+ *
+ * @since 1.6.0
+ */
+ def reduce(f: ReduceFunction[V]): Dataset[(K, V)] = {
reduce(f.call _)
}
@@ -185,41 +245,51 @@ class GroupedDataset[K, T] private[sql](
/**
* Computes the given aggregation, returning a [[Dataset]] of tuples for each unique key
* and the result of computing this aggregation over all elements in the group.
+ *
+ * @since 1.6.0
*/
- def agg[U1](col1: TypedColumn[T, U1]): Dataset[(K, U1)] =
+ def agg[U1](col1: TypedColumn[V, U1]): Dataset[(K, U1)] =
aggUntyped(col1).asInstanceOf[Dataset[(K, U1)]]
/**
* Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key
* and the result of computing these aggregations over all elements in the group.
+ *
+ * @since 1.6.0
*/
- def agg[U1, U2](col1: TypedColumn[T, U1], col2: TypedColumn[T, U2]): Dataset[(K, U1, U2)] =
+ def agg[U1, U2](col1: TypedColumn[V, U1], col2: TypedColumn[V, U2]): Dataset[(K, U1, U2)] =
aggUntyped(col1, col2).asInstanceOf[Dataset[(K, U1, U2)]]
/**
* Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key
* and the result of computing these aggregations over all elements in the group.
+ *
+ * @since 1.6.0
*/
def agg[U1, U2, U3](
- col1: TypedColumn[T, U1],
- col2: TypedColumn[T, U2],
- col3: TypedColumn[T, U3]): Dataset[(K, U1, U2, U3)] =
+ col1: TypedColumn[V, U1],
+ col2: TypedColumn[V, U2],
+ col3: TypedColumn[V, U3]): Dataset[(K, U1, U2, U3)] =
aggUntyped(col1, col2, col3).asInstanceOf[Dataset[(K, U1, U2, U3)]]
/**
* Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key
* and the result of computing these aggregations over all elements in the group.
+ *
+ * @since 1.6.0
*/
def agg[U1, U2, U3, U4](
- col1: TypedColumn[T, U1],
- col2: TypedColumn[T, U2],
- col3: TypedColumn[T, U3],
- col4: TypedColumn[T, U4]): Dataset[(K, U1, U2, U3, U4)] =
+ col1: TypedColumn[V, U1],
+ col2: TypedColumn[V, U2],
+ col3: TypedColumn[V, U3],
+ col4: TypedColumn[V, U4]): Dataset[(K, U1, U2, U3, U4)] =
aggUntyped(col1, col2, col3, col4).asInstanceOf[Dataset[(K, U1, U2, U3, U4)]]
/**
* Returns a [[Dataset]] that contains a tuple with each key and the number of items present
* for that key.
+ *
+ * @since 1.6.0
*/
def count(): Dataset[(K, Long)] = agg(functions.count("*").as(ExpressionEncoder[Long]))
@@ -228,10 +298,12 @@ class GroupedDataset[K, T] private[sql](
* be passed the grouping key and 2 iterators containing all elements in the group from
* [[Dataset]] `this` and `other`. The function can return an iterator containing elements of an
* arbitrary type which will be returned as a new [[Dataset]].
+ *
+ * @since 1.6.0
*/
def cogroup[U, R : Encoder](
other: GroupedDataset[K, U])(
- f: (K, Iterator[T], Iterator[U]) => TraversableOnce[R]): Dataset[R] = {
+ f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R] = {
implicit def uEnc: Encoder[U] = other.unresolvedTEncoder
new Dataset[R](
sqlContext,
@@ -243,9 +315,17 @@ class GroupedDataset[K, T] private[sql](
other.logicalPlan))
}
+ /**
+ * Applies the given function to each cogrouped data. For each unique group, the function will
+ * be passed the grouping key and 2 iterators containing all elements in the group from
+ * [[Dataset]] `this` and `other`. The function can return an iterator containing elements of an
+ * arbitrary type which will be returned as a new [[Dataset]].
+ *
+ * @since 1.6.0
+ */
def cogroup[U, R](
other: GroupedDataset[K, U],
- f: CoGroupFunction[K, T, U, R],
+ f: CoGroupFunction[K, V, U, R],
encoder: Encoder[R]): Dataset[R] = {
cogroup(other)((key, left, right) => f.call(key, left.asJava, right.asJava).asScala)(encoder)
}
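
To round out the GroupedDataset changes above, a second short sketch covering reduce, count, and cogroup under the new V type parameter. It reuses the hypothetical grouped value and imports from the earlier sketch; only the method signatures come from the diff, the data and names are made up.

// reduce folds the values of each group with a commutative, associative function.
val reduced = grouped.reduce((a, b) => (a._1, a._2 + b._2))   // Dataset[(String, (String, Int))]

// count returns each key together with the number of records for that key.
val counts = grouped.count()                                   // Dataset[(String, Long)]

// cogroup pairs up the groups of two GroupedDatasets that share a key type.
val other = Seq(("a", "x"), ("b", "y")).toDS().groupBy(_._1)
val merged = grouped.cogroup(other) { (key, left, right) =>
  Iterator(key + ":" + (left.size + right.size))
}                                                              // Dataset[String]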
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index f32374b4c0..cf335efdd2 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -170,7 +170,7 @@ public class JavaDatasetSuite implements Serializable {
}
}, Encoders.INT());
- Dataset<String> mapped = grouped.map(new MapGroupFunction<Integer, String, String>() {
+ Dataset<String> mapped = grouped.mapGroup(new MapGroupFunction<Integer, String, String>() {
@Override
public String call(Integer key, Iterator<String> values) throws Exception {
StringBuilder sb = new StringBuilder(key.toString());
@@ -183,7 +183,7 @@ public class JavaDatasetSuite implements Serializable {
Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList());
- Dataset<String> flatMapped = grouped.flatMap(
+ Dataset<String> flatMapped = grouped.flatMapGroup(
new FlatMapGroupFunction<Integer, String, String>() {
@Override
public Iterable<String> call(Integer key, Iterator<String> values) throws Exception {
@@ -247,9 +247,9 @@ public class JavaDatasetSuite implements Serializable {
List<String> data = Arrays.asList("a", "foo", "bar");
Dataset<String> ds = context.createDataset(data, Encoders.STRING());
GroupedDataset<Integer, String> grouped =
- ds.groupBy(length(col("value"))).asKey(Encoders.INT());
+ ds.groupBy(length(col("value"))).keyAs(Encoders.INT());
- Dataset<String> mapped = grouped.map(
+ Dataset<String> mapped = grouped.mapGroup(
new MapGroupFunction<Integer, String, String>() {
@Override
public String call(Integer key, Iterator<String> data) throws Exception {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
index 63b00975e4..d387710357 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
@@ -86,7 +86,7 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
test("groupBy function, map") {
val ds = Seq(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11).toDS()
val grouped = ds.groupBy(_ % 2)
- val agged = grouped.map { case (g, iter) =>
+ val agged = grouped.mapGroup { case (g, iter) =>
val name = if (g == 0) "even" else "odd"
(name, iter.size)
}
@@ -99,7 +99,7 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
test("groupBy function, flatMap") {
val ds = Seq("a", "b", "c", "xyz", "hello").toDS()
val grouped = ds.groupBy(_.length)
- val agged = grouped.flatMap { case (g, iter) => Iterator(g.toString, iter.mkString) }
+ val agged = grouped.flatMapGroup { case (g, iter) => Iterator(g.toString, iter.mkString) }
checkAnswer(
agged,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 89d964aa3e..9da02550b3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -224,7 +224,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
test("groupBy function, map") {
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
val grouped = ds.groupBy(v => (v._1, "word"))
- val agged = grouped.map { case (g, iter) => (g._1, iter.map(_._2).sum) }
+ val agged = grouped.mapGroup { case (g, iter) => (g._1, iter.map(_._2).sum) }
checkAnswer(
agged,
@@ -234,7 +234,9 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
test("groupBy function, flatMap") {
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
val grouped = ds.groupBy(v => (v._1, "word"))
- val agged = grouped.flatMap { case (g, iter) => Iterator(g._1, iter.map(_._2).sum.toString) }
+ val agged = grouped.flatMapGroup { case (g, iter) =>
+ Iterator(g._1, iter.map(_._2).sum.toString)
+ }
checkAnswer(
agged,
@@ -253,7 +255,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
test("groupBy columns, map") {
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
val grouped = ds.groupBy($"_1")
- val agged = grouped.map { case (g, iter) => (g.getString(0), iter.map(_._2).sum) }
+ val agged = grouped.mapGroup { case (g, iter) => (g.getString(0), iter.map(_._2).sum) }
checkAnswer(
agged,
@@ -262,8 +264,8 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
test("groupBy columns asKey, map") {
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
- val grouped = ds.groupBy($"_1").asKey[String]
- val agged = grouped.map { case (g, iter) => (g, iter.map(_._2).sum) }
+ val grouped = ds.groupBy($"_1").keyAs[String]
+ val agged = grouped.mapGroup { case (g, iter) => (g, iter.map(_._2).sum) }
checkAnswer(
agged,
@@ -272,8 +274,8 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
test("groupBy columns asKey tuple, map") {
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
- val grouped = ds.groupBy($"_1", lit(1)).asKey[(String, Int)]
- val agged = grouped.map { case (g, iter) => (g, iter.map(_._2).sum) }
+ val grouped = ds.groupBy($"_1", lit(1)).keyAs[(String, Int)]
+ val agged = grouped.mapGroup { case (g, iter) => (g, iter.map(_._2).sum) }
checkAnswer(
agged,
@@ -282,8 +284,8 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
test("groupBy columns asKey class, map") {
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
- val grouped = ds.groupBy($"_1".as("a"), lit(1).as("b")).asKey[ClassData]
- val agged = grouped.map { case (g, iter) => (g, iter.map(_._2).sum) }
+ val grouped = ds.groupBy($"_1".as("a"), lit(1).as("b")).keyAs[ClassData]
+ val agged = grouped.mapGroup { case (g, iter) => (g, iter.map(_._2).sum) }
checkAnswer(
agged,