diff options
author | Reynold Xin <rxin@databricks.com> | 2016-03-19 11:23:14 -0700 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2016-03-19 11:23:14 -0700 |
commit | dcaa016610ac2c11d7dd01803f3515b02ab32e64 (patch) | |
tree | 7d03000193cdcc5100fd7198e143680b2e5882e5 /sql | |
parent | 2082a49569cb5d900e318af9da1027821dfe93bc (diff) | |
download | spark-dcaa016610ac2c11d7dd01803f3515b02ab32e64.tar.gz spark-dcaa016610ac2c11d7dd01803f3515b02ab32e64.tar.bz2 spark-dcaa016610ac2c11d7dd01803f3515b02ab32e64.zip |
[SPARK-13897][SQL] RelationalGroupedDataset and KeyValueGroupedDataset
## What changes were proposed in this pull request?
Previously, Dataset.groupBy returned a GroupedData, and Dataset.groupByKey returned a GroupedDataset. The two names are very similar and, unfortunately, do not convey the real differences between the two.
Assume we are grouping by some keys (K). groupByKey is a key-value style group by, in which the schema of the returned dataset is a tuple of just two fields: key and value. groupBy, on the other hand, is a relational style group by, in which the schema of the returned dataset is flattened and contains |K| + |V| fields.
This pull request also removes the experimental tag from RelationalGroupedDataset. It has been with DataFrame since 1.3, and we have enough confidence now to stabilize it.
## How was this patch tested?
This is a rename to improve API understandability. It should be covered by the existing tests.
Author: Reynold Xin <rxin@databricks.com>
Closes #11841 from rxin/SPARK-13897.
Diffstat (limited to 'sql')
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 56 | ||||
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala (renamed from sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala) | 35 | ||||
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala (renamed from sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala) | 37 | ||||
-rw-r--r-- | sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java | 8 |
4 files changed, 69 insertions, 67 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 39f7f35def..6e7d208723 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1036,7 +1036,7 @@ class Dataset[T] private[sql]( /** * Groups the [[Dataset]] using the specified columns, so we can run aggregation on them. See - * [[GroupedData]] for all the available aggregate functions. + * [[RelationalGroupedDataset]] for all the available aggregate functions. * * {{{ * // Compute the average for all numeric columns grouped by department. @@ -1053,14 +1053,14 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ @scala.annotation.varargs - def groupBy(cols: Column*): GroupedData = { - GroupedData(toDF(), cols.map(_.expr), GroupedData.GroupByType) + def groupBy(cols: Column*): RelationalGroupedDataset = { + RelationalGroupedDataset(toDF(), cols.map(_.expr), RelationalGroupedDataset.GroupByType) } /** * Create a multi-dimensional rollup for the current [[Dataset]] using the specified columns, * so we can run aggregation on them. - * See [[GroupedData]] for all the available aggregate functions. + * See [[RelationalGroupedDataset]] for all the available aggregate functions. * * {{{ * // Compute the average for all numeric columns rolluped by department and group. @@ -1077,14 +1077,14 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ @scala.annotation.varargs - def rollup(cols: Column*): GroupedData = { - GroupedData(toDF(), cols.map(_.expr), GroupedData.RollupType) + def rollup(cols: Column*): RelationalGroupedDataset = { + RelationalGroupedDataset(toDF(), cols.map(_.expr), RelationalGroupedDataset.RollupType) } /** * Create a multi-dimensional cube for the current [[Dataset]] using the specified columns, * so we can run aggregation on them. - * See [[GroupedData]] for all the available aggregate functions. 
+ * See [[RelationalGroupedDataset]] for all the available aggregate functions. * * {{{ * // Compute the average for all numeric columns cubed by department and group. @@ -1101,11 +1101,13 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ @scala.annotation.varargs - def cube(cols: Column*): GroupedData = GroupedData(toDF(), cols.map(_.expr), GroupedData.CubeType) + def cube(cols: Column*): RelationalGroupedDataset = { + RelationalGroupedDataset(toDF(), cols.map(_.expr), RelationalGroupedDataset.CubeType) + } /** * Groups the [[Dataset]] using the specified columns, so we can run aggregation on them. - * See [[GroupedData]] for all the available aggregate functions. + * See [[RelationalGroupedDataset]] for all the available aggregate functions. * * This is a variant of groupBy that can only group by existing columns using column names * (i.e. cannot construct expressions). @@ -1124,9 +1126,10 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ @scala.annotation.varargs - def groupBy(col1: String, cols: String*): GroupedData = { + def groupBy(col1: String, cols: String*): RelationalGroupedDataset = { val colNames: Seq[String] = col1 +: cols - GroupedData(toDF(), colNames.map(colName => resolve(colName)), GroupedData.GroupByType) + RelationalGroupedDataset( + toDF(), colNames.map(colName => resolve(colName)), RelationalGroupedDataset.GroupByType) } /** @@ -1156,18 +1159,18 @@ class Dataset[T] private[sql]( /** * :: Experimental :: * (Scala-specific) - * Returns a [[GroupedDataset]] where the data is grouped by the given key `func`. + * Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given key `func`. 
* * @group typedrel * @since 2.0.0 */ @Experimental - def groupByKey[K: Encoder](func: T => K): GroupedDataset[K, T] = { + def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T] = { val inputPlan = logicalPlan val withGroupingKey = AppendColumns(func, inputPlan) val executed = sqlContext.executePlan(withGroupingKey) - new GroupedDataset( + new KeyValueGroupedDataset( encoderFor[K], encoderFor[T], executed, @@ -1177,14 +1180,15 @@ class Dataset[T] private[sql]( /** * :: Experimental :: - * Returns a [[GroupedDataset]] where the data is grouped by the given [[Column]] expressions. + * Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given [[Column]] + * expressions. * * @group typedrel * @since 2.0.0 */ @Experimental @scala.annotation.varargs - def groupByKey(cols: Column*): GroupedDataset[Row, T] = { + def groupByKey(cols: Column*): KeyValueGroupedDataset[Row, T] = { val withKeyColumns = logicalPlan.output ++ cols.map(_.expr).map(UnresolvedAlias(_)) val withKey = Project(withKeyColumns, logicalPlan) val executed = sqlContext.executePlan(withKey) @@ -1192,7 +1196,7 @@ class Dataset[T] private[sql]( val dataAttributes = executed.analyzed.output.dropRight(cols.size) val keyAttributes = executed.analyzed.output.takeRight(cols.size) - new GroupedDataset( + new KeyValueGroupedDataset( RowEncoder(keyAttributes.toStructType), encoderFor[T], executed, @@ -1203,19 +1207,19 @@ class Dataset[T] private[sql]( /** * :: Experimental :: * (Java-specific) - * Returns a [[GroupedDataset]] where the data is grouped by the given key `func`. + * Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given key `func`. 
* * @group typedrel * @since 2.0.0 */ @Experimental - def groupByKey[K](func: MapFunction[T, K], encoder: Encoder[K]): GroupedDataset[K, T] = + def groupByKey[K](func: MapFunction[T, K], encoder: Encoder[K]): KeyValueGroupedDataset[K, T] = groupByKey(func.call(_))(encoder) /** * Create a multi-dimensional rollup for the current [[Dataset]] using the specified columns, * so we can run aggregation on them. - * See [[GroupedData]] for all the available aggregate functions. + * See [[RelationalGroupedDataset]] for all the available aggregate functions. * * This is a variant of rollup that can only group by existing columns using column names * (i.e. cannot construct expressions). @@ -1235,15 +1239,16 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ @scala.annotation.varargs - def rollup(col1: String, cols: String*): GroupedData = { + def rollup(col1: String, cols: String*): RelationalGroupedDataset = { val colNames: Seq[String] = col1 +: cols - GroupedData(toDF(), colNames.map(colName => resolve(colName)), GroupedData.RollupType) + RelationalGroupedDataset( + toDF(), colNames.map(colName => resolve(colName)), RelationalGroupedDataset.RollupType) } /** * Create a multi-dimensional cube for the current [[Dataset]] using the specified columns, * so we can run aggregation on them. - * See [[GroupedData]] for all the available aggregate functions. + * See [[RelationalGroupedDataset]] for all the available aggregate functions. * * This is a variant of cube that can only group by existing columns using column names * (i.e. cannot construct expressions). 
@@ -1262,9 +1267,10 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ @scala.annotation.varargs - def cube(col1: String, cols: String*): GroupedData = { + def cube(col1: String, cols: String*): RelationalGroupedDataset = { val colNames: Seq[String] = col1 +: cols - GroupedData(toDF(), colNames.map(colName => resolve(colName)), GroupedData.CubeType) + RelationalGroupedDataset( + toDF(), colNames.map(colName => resolve(colName)), RelationalGroupedDataset.CubeType) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index a8700de135..f0f96825e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -29,18 +29,13 @@ import org.apache.spark.sql.execution.QueryExecution /** * :: Experimental :: * A [[Dataset]] has been logically grouped by a user specified grouping key. Users should not - * construct a [[GroupedDataset]] directly, but should instead call `groupBy` on an existing + * construct a [[KeyValueGroupedDataset]] directly, but should instead call `groupBy` on an existing * [[Dataset]]. * - * COMPATIBILITY NOTE: Long term we plan to make [[GroupedDataset)]] extend `GroupedData`. However, - * making this change to the class hierarchy would break some function signatures. As such, this - * class should be considered a preview of the final API. Changes will be made to the interface - * after Spark 1.6. 
- * - * @since 1.6.0 + * @since 2.0.0 */ @Experimental -class GroupedDataset[K, V] private[sql]( +class KeyValueGroupedDataset[K, V] private[sql]( kEncoder: Encoder[K], vEncoder: Encoder[V], val queryExecution: QueryExecution, @@ -62,18 +57,22 @@ class GroupedDataset[K, V] private[sql]( private def logicalPlan = queryExecution.analyzed private def sqlContext = queryExecution.sqlContext - private def groupedData = - new GroupedData( - Dataset.newDataFrame(sqlContext, logicalPlan), groupingAttributes, GroupedData.GroupByType) + private def groupedData = { + new RelationalGroupedDataset( + Dataset.newDataFrame(sqlContext, logicalPlan), + groupingAttributes, + RelationalGroupedDataset.GroupByType) + } /** - * Returns a new [[GroupedDataset]] where the type of the key has been mapped to the specified - * type. The mapping of key columns to the type follows the same rules as `as` on [[Dataset]]. + * Returns a new [[KeyValueGroupedDataset]] where the type of the key has been mapped to the + * specified type. The mapping of key columns to the type follows the same rules as `as` on + * [[Dataset]]. * * @since 1.6.0 */ - def keyAs[L : Encoder]: GroupedDataset[L, V] = - new GroupedDataset( + def keyAs[L : Encoder]: KeyValueGroupedDataset[L, V] = + new KeyValueGroupedDataset( encoderFor[L], unresolvedVEncoder, queryExecution, @@ -294,7 +293,7 @@ class GroupedDataset[K, V] private[sql]( * * @since 1.6.0 */ - def count(): Dataset[(K, Long)] = agg(functions.count("*").as(ExpressionEncoder[Long])) + def count(): Dataset[(K, Long)] = agg(functions.count("*").as(ExpressionEncoder[Long]())) /** * Applies the given function to each cogrouped data. 
For each unique group, the function will @@ -305,7 +304,7 @@ class GroupedDataset[K, V] private[sql]( * @since 1.6.0 */ def cogroup[U, R : Encoder]( - other: GroupedDataset[K, U])( + other: KeyValueGroupedDataset[K, U])( f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R] = { implicit val uEncoder = other.unresolvedVEncoder Dataset[R]( @@ -329,7 +328,7 @@ class GroupedDataset[K, V] private[sql]( * @since 1.6.0 */ def cogroup[U, R]( - other: GroupedDataset[K, U], + other: KeyValueGroupedDataset[K, U], f: CoGroupFunction[K, V, U, R], encoder: Encoder[R]): Dataset[R] = { cogroup(other)((key, left, right) => f.call(key, left.asJava, right.asJava).asScala)(encoder) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 04d277bed2..521032a8b3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql import scala.collection.JavaConverters._ import scala.language.implicitConversions -import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ @@ -30,19 +29,17 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.NumericType /** - * :: Experimental :: * A set of methods for aggregations on a [[DataFrame]], created by [[Dataset.groupBy]]. * * The main method is the agg function, which has multiple variants. This class also contains * convenience some first order statistics such as mean, sum for convenience. 
* - * @since 1.3.0 + * @since 2.0.0 */ -@Experimental -class GroupedData protected[sql]( +class RelationalGroupedDataset protected[sql]( df: DataFrame, groupingExprs: Seq[Expression], - groupType: GroupedData.GroupType) { + groupType: RelationalGroupedDataset.GroupType) { private[this] def toDF(aggExprs: Seq[Expression]): DataFrame = { val aggregates = if (df.sqlContext.conf.dataFrameRetainGroupColumns) { @@ -54,16 +51,16 @@ class GroupedData protected[sql]( val aliasedAgg = aggregates.map(alias) groupType match { - case GroupedData.GroupByType => + case RelationalGroupedDataset.GroupByType => Dataset.newDataFrame( df.sqlContext, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan)) - case GroupedData.RollupType => + case RelationalGroupedDataset.RollupType => Dataset.newDataFrame( df.sqlContext, Aggregate(Seq(Rollup(groupingExprs)), aliasedAgg, df.logicalPlan)) - case GroupedData.CubeType => + case RelationalGroupedDataset.CubeType => Dataset.newDataFrame( df.sqlContext, Aggregate(Seq(Cube(groupingExprs)), aliasedAgg, df.logicalPlan)) - case GroupedData.PivotType(pivotCol, values) => + case RelationalGroupedDataset.PivotType(pivotCol, values) => val aliasedGrps = groupingExprs.map(alias) Dataset.newDataFrame( df.sqlContext, Pivot(aliasedGrps, pivotCol, values, aggExprs, df.logicalPlan)) @@ -299,7 +296,7 @@ class GroupedData protected[sql]( * @param pivotColumn Name of the column to pivot. * @since 1.6.0 */ - def pivot(pivotColumn: String): GroupedData = { + def pivot(pivotColumn: String): RelationalGroupedDataset = { // This is to prevent unintended OOM errors when the number of distinct values is large val maxValues = df.sqlContext.conf.getConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES) // Get the distinct values of the column and sort them so its consistent @@ -340,14 +337,14 @@ class GroupedData protected[sql]( * @param values List of values that will be translated to columns in the output DataFrame. 
* @since 1.6.0 */ - def pivot(pivotColumn: String, values: Seq[Any]): GroupedData = { + def pivot(pivotColumn: String, values: Seq[Any]): RelationalGroupedDataset = { groupType match { - case GroupedData.GroupByType => - new GroupedData( + case RelationalGroupedDataset.GroupByType => + new RelationalGroupedDataset( df, groupingExprs, - GroupedData.PivotType(df.resolve(pivotColumn), values.map(Literal.apply))) - case _: GroupedData.PivotType => + RelationalGroupedDataset.PivotType(df.resolve(pivotColumn), values.map(Literal.apply))) + case _: RelationalGroupedDataset.PivotType => throw new UnsupportedOperationException("repeated pivots are not supported") case _ => throw new UnsupportedOperationException("pivot is only supported after a groupBy") @@ -372,7 +369,7 @@ class GroupedData protected[sql]( * @param values List of values that will be translated to columns in the output DataFrame. * @since 1.6.0 */ - def pivot(pivotColumn: String, values: java.util.List[Any]): GroupedData = { + def pivot(pivotColumn: String, values: java.util.List[Any]): RelationalGroupedDataset = { pivot(pivotColumn, values.asScala) } } @@ -381,13 +378,13 @@ class GroupedData protected[sql]( /** * Companion object for GroupedData. 
*/ -private[sql] object GroupedData { +private[sql] object RelationalGroupedDataset { def apply( df: DataFrame, groupingExprs: Seq[Expression], - groupType: GroupType): GroupedData = { - new GroupedData(df, groupingExprs, groupType: GroupType) + groupType: GroupType): RelationalGroupedDataset = { + new RelationalGroupedDataset(df, groupingExprs, groupType: GroupType) } /** diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 79b6e61767..4b8b0d9d4f 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -169,7 +169,7 @@ public class JavaDatasetSuite implements Serializable { public void testGroupBy() { List<String> data = Arrays.asList("a", "foo", "bar"); Dataset<String> ds = context.createDataset(data, Encoders.STRING()); - GroupedDataset<Integer, String> grouped = ds.groupByKey(new MapFunction<String, Integer>() { + KeyValueGroupedDataset<Integer, String> grouped = ds.groupByKey(new MapFunction<String, Integer>() { @Override public Integer call(String v) throws Exception { return v.length(); @@ -217,7 +217,7 @@ public class JavaDatasetSuite implements Serializable { List<Integer> data2 = Arrays.asList(2, 6, 10); Dataset<Integer> ds2 = context.createDataset(data2, Encoders.INT()); - GroupedDataset<Integer, Integer> grouped2 = ds2.groupByKey(new MapFunction<Integer, Integer>() { + KeyValueGroupedDataset<Integer, Integer> grouped2 = ds2.groupByKey(new MapFunction<Integer, Integer>() { @Override public Integer call(Integer v) throws Exception { return v / 2; @@ -249,7 +249,7 @@ public class JavaDatasetSuite implements Serializable { public void testGroupByColumn() { List<String> data = Arrays.asList("a", "foo", "bar"); Dataset<String> ds = context.createDataset(data, Encoders.STRING()); - GroupedDataset<Integer, String> grouped = + 
KeyValueGroupedDataset<Integer, String> grouped = ds.groupByKey(length(col("value"))).keyAs(Encoders.INT()); Dataset<String> mapped = grouped.mapGroups( @@ -410,7 +410,7 @@ public class JavaDatasetSuite implements Serializable { Arrays.asList(tuple2("a", 1), tuple2("a", 2), tuple2("b", 3)); Dataset<Tuple2<String, Integer>> ds = context.createDataset(data, encoder); - GroupedDataset<String, Tuple2<String, Integer>> grouped = ds.groupByKey( + KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = ds.groupByKey( new MapFunction<Tuple2<String, Integer>, String>() { @Override public String call(Tuple2<String, Integer> value) throws Exception { |