diff options
author | Wenchen Fan <wenchen@databricks.com> | 2015-11-16 15:32:49 -0800 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2015-11-16 15:32:49 -0800 |
commit | fd14936be7beff543dbbcf270f2f9749f7a803c4 (patch) | |
tree | 9899abb3516a2b254cc3f961bc356159a72c9f45 /sql/core | |
parent | 75ee12f09c2645c1ad682764d512965f641eb5c2 (diff) | |
download | spark-fd14936be7beff543dbbcf270f2f9749f7a803c4.tar.gz spark-fd14936be7beff543dbbcf270f2f9749f7a803c4.tar.bz2 spark-fd14936be7beff543dbbcf270f2f9749f7a803c4.zip |
[SPARK-11625][SQL] add java test for typed aggregate
Author: Wenchen Fan <wenchen@databricks.com>
Closes #9591 from cloud-fan/agg-test.
Diffstat (limited to 'sql/core')
4 files changed, 91 insertions, 8 deletions
```diff
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index ebcf4c8bfe..467cd42b9b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -145,9 +145,37 @@ class GroupedDataset[K, T] private[sql](
     reduce(f.call _)
   }
 
-  // To ensure valid overloading.
-  protected def agg(expr: Column, exprs: Column*): DataFrame =
-    groupedData.agg(expr, exprs: _*)
+  /**
+   * Compute aggregates by specifying a series of aggregate columns, and return a [[DataFrame]].
+   * We can call `as[T : Encoder]` to turn the returned [[DataFrame]] to [[Dataset]] again.
+   *
+   * The available aggregate methods are defined in [[org.apache.spark.sql.functions]].
+   *
+   * {{{
+   *   // Selects the age of the oldest employee and the aggregate expense for each department
+   *
+   *   // Scala:
+   *   import org.apache.spark.sql.functions._
+   *   df.groupBy("department").agg(max("age"), sum("expense"))
+   *
+   *   // Java:
+   *   import static org.apache.spark.sql.functions.*;
+   *   df.groupBy("department").agg(max("age"), sum("expense"));
+   * }}}
+   *
+   * We can also use `Aggregator.toColumn` to pass in typed aggregate functions.
+   *
+   * @since 1.6.0
+   */
+  @scala.annotation.varargs
+  def agg(expr: Column, exprs: Column*): DataFrame =
+    groupedData.agg(withEncoder(expr), exprs.map(withEncoder): _*)
+
+  private def withEncoder(c: Column): Column = c match {
+    case tc: TypedColumn[_, _] =>
+      tc.withInputType(resolvedTEncoder.bind(dataAttributes), dataAttributes)
+    case _ => c
+  }
 
   /**
    * Internal helper function for building typed aggregations that return tuples. For simplicity
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
index 360c9a5bc1..72610e735f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
@@ -47,7 +47,7 @@ import org.apache.spark.sql.{Dataset, DataFrame, TypedColumn}
  * @tparam B The type of the intermediate value of the reduction.
  * @tparam C The type of the final result.
  */
-abstract class Aggregator[-A, B, C] {
+abstract class Aggregator[-A, B, C] extends Serializable {
 
   /** A zero value for this aggregation. Should satisfy the property that any b + zero = b */
   def zero: B
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index eb6fa1e72e..d9b22506fb 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -34,6 +34,7 @@ import org.apache.spark.sql.Encoder;
 import org.apache.spark.sql.Encoders;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.GroupedDataset;
+import org.apache.spark.sql.expressions.Aggregator;
 import org.apache.spark.sql.test.TestSQLContext;
 
 import static org.apache.spark.sql.functions.*;
@@ -381,4 +382,59 @@ public class JavaDatasetSuite implements Serializable {
       context.createDataset(data3, encoder3);
     Assert.assertEquals(data3, ds3.collectAsList());
   }
+
+  @Test
+  public void testTypedAggregation() {
+    Encoder<Tuple2<String, Integer>> encoder = Encoders.tuple(Encoders.STRING(), Encoders.INT());
+    List<Tuple2<String, Integer>> data =
+      Arrays.asList(tuple2("a", 1), tuple2("a", 2), tuple2("b", 3));
+    Dataset<Tuple2<String, Integer>> ds = context.createDataset(data, encoder);
+
+    GroupedDataset<String, Tuple2<String, Integer>> grouped = ds.groupBy(
+      new MapFunction<Tuple2<String, Integer>, String>() {
+        @Override
+        public String call(Tuple2<String, Integer> value) throws Exception {
+          return value._1();
+        }
+      },
+      Encoders.STRING());
+
+    Dataset<Tuple2<String, Integer>> agged =
+      grouped.agg(new IntSumOf().toColumn(Encoders.INT(), Encoders.INT()));
+    Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList());
+
+    Dataset<Tuple4<String, Integer, Long, Long>> agged2 = grouped.agg(
+      new IntSumOf().toColumn(Encoders.INT(), Encoders.INT()),
+      expr("sum(_2)"),
+      count("*"))
+      .as(Encoders.tuple(Encoders.STRING(), Encoders.INT(), Encoders.LONG(), Encoders.LONG()));
+    Assert.assertEquals(
+      Arrays.asList(
+        new Tuple4<String, Integer, Long, Long>("a", 3, 3L, 2L),
+        new Tuple4<String, Integer, Long, Long>("b", 3, 3L, 1L)),
+      agged2.collectAsList());
+  }
+
+  static class IntSumOf extends Aggregator<Tuple2<String, Integer>, Integer, Integer> {
+
+    @Override
+    public Integer zero() {
+      return 0;
+    }
+
+    @Override
+    public Integer reduce(Integer l, Tuple2<String, Integer> t) {
+      return l + t._2();
+    }
+
+    @Override
+    public Integer merge(Integer b1, Integer b2) {
+      return b1 + b2;
+    }
+
+    @Override
+    public Integer finish(Integer reduction) {
+      return reduction;
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
index 46f9f077fe..9377589790 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.functions._
 import org.apache.spark.sql.expressions.Aggregator
 
 /** An `Aggregator` that adds up any numeric type returned by the given function. */
-class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] with Serializable {
+class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] {
   val numeric = implicitly[Numeric[N]]
 
   override def zero: N = numeric.zero
@@ -37,7 +37,7 @@ class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] with Serializ
   override def finish(reduction: N): N = reduction
 }
 
-object TypedAverage extends Aggregator[(String, Int), (Long, Long), Double] with Serializable {
+object TypedAverage extends Aggregator[(String, Int), (Long, Long), Double] {
   override def zero: (Long, Long) = (0, 0)
 
   override def reduce(countAndSum: (Long, Long), input: (String, Int)): (Long, Long) = {
@@ -51,8 +51,7 @@ object TypedAverage extends Aggregator[(String, Int), (Long, Long), Double] with
   override def finish(countAndSum: (Long, Long)): Double = countAndSum._2 / countAndSum._1
 }
 
-object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, Long)]
-  with Serializable {
+object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, Long)] {
 
   override def zero: (Long, Long) = (0, 0)
 
```