author     Wenchen Fan <wenchen@databricks.com>        2015-11-16 15:32:49 -0800
committer  Michael Armbrust <michael@databricks.com>   2015-11-16 15:32:49 -0800
commit     fd14936be7beff543dbbcf270f2f9749f7a803c4 (patch)
tree       9899abb3516a2b254cc3f961bc356159a72c9f45
parent     75ee12f09c2645c1ad682764d512965f641eb5c2 (diff)
[SPARK-11625][SQL] add java test for typed aggregate
Author: Wenchen Fan <wenchen@databricks.com>

Closes #9591 from cloud-fan/agg-test.
 core/src/main/java/org/apache/spark/api/java/function/Function.java       |  2 +-
 sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala         | 34 +++++++++---
 sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala |  2 +-
 sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java    | 56 ++++++++++++++
 sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala |  7 +++----
 5 files changed, 92 insertions(+), 9 deletions(-)
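
Before the diffs, a minimal sketch in Scala of the typed aggregation this patch tests. The data, the `sqlContext` value, and the `IntSum` name are illustrative assumptions against the Spark 1.6-era API; the aggregator mirrors the IntSumOf class in the Java test below.

  import org.apache.spark.sql.expressions.Aggregator

  // Sums the Int component of (String, Int) pairs.
  object IntSum extends Aggregator[(String, Int), Int, Int] {
    override def zero: Int = 0
    override def reduce(b: Int, a: (String, Int)): Int = b + a._2
    override def merge(b1: Int, b2: Int): Int = b1 + b2
    override def finish(b: Int): Int = b
  }

  import sqlContext.implicits._  // assumes an in-scope SQLContext named `sqlContext`
  val ds = Seq(("a", 1), ("a", 2), ("b", 3)).toDS()
  // toColumn turns the aggregator into a TypedColumn usable in agg.
  val summed = ds.groupBy(_._1).agg(IntSum.toColumn)  // Dataset[(String, Int)]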
diff --git a/core/src/main/java/org/apache/spark/api/java/function/Function.java b/core/src/main/java/org/apache/spark/api/java/function/Function.java
index d00551bb0a..b9d9777a75 100644
--- a/core/src/main/java/org/apache/spark/api/java/function/Function.java
+++ b/core/src/main/java/org/apache/spark/api/java/function/Function.java
@@ -25,5 +25,5 @@ import java.io.Serializable;
* when mapping RDDs of other types.
*/
public interface Function<T1, R> extends Serializable {
- public R call(T1 v1) throws Exception;
+ R call(T1 v1) throws Exception;
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index ebcf4c8bfe..467cd42b9b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -145,9 +145,37 @@ class GroupedDataset[K, T] private[sql](
reduce(f.call _)
}
- // To ensure valid overloading.
- protected def agg(expr: Column, exprs: Column*): DataFrame =
- groupedData.agg(expr, exprs: _*)
+ /**
+ * Compute aggregates by specifying a series of aggregate columns, and return a [[DataFrame]].
+ * We can call `as[T : Encoder]` to turn the returned [[DataFrame]] back into a [[Dataset]].
+ *
+ * The available aggregate methods are defined in [[org.apache.spark.sql.functions]].
+ *
+ * {{{
+ * // Selects the age of the oldest employee and the aggregate expense for each department
+ *
+ * // Scala:
+ * import org.apache.spark.sql.functions._
+ * df.groupBy("department").agg(max("age"), sum("expense"))
+ *
+ * // Java:
+ * import static org.apache.spark.sql.functions.*;
+ * df.groupBy("department").agg(max("age"), sum("expense"));
+ * }}}
+ *
+ * We can also use `Aggregator.toColumn` to pass in typed aggregate functions.
+ *
+ * @since 1.6.0
+ */
+ @scala.annotation.varargs
+ def agg(expr: Column, exprs: Column*): DataFrame =
+ groupedData.agg(withEncoder(expr), exprs.map(withEncoder): _*)
+
+ private def withEncoder(c: Column): Column = c match {
+ case tc: TypedColumn[_, _] =>
+ tc.withInputType(resolvedTEncoder.bind(dataAttributes), dataAttributes)
+ case _ => c
+ }
/**
* Internal helper function for building typed aggregations that return tuples. For simplicity
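
A hedged sketch of the mixed case the new `agg` above supports, continuing the `IntSum`/`ds` sketch from the top of the page: a typed aggregator column sits next to untyped expressions, and `as` turns the resulting DataFrame back into a Dataset. The expected rows are taken from the Java test below.

  import org.apache.spark.sql.functions._

  // withEncoder (above) binds the group's input encoder to the typed column,
  // so IntSum.toColumn can be mixed with plain SQL expressions.
  val mixed = ds.groupBy(_._1)
    .agg(IntSum.toColumn, expr("sum(_2)"), count("*"))
    .as[(String, Int, Long, Long)]
  // expected rows: ("a", 3, 3L, 2L) and ("b", 3, 3L, 1L)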
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
index 360c9a5bc1..72610e735f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
@@ -47,7 +47,7 @@ import org.apache.spark.sql.{Dataset, DataFrame, TypedColumn}
* @tparam B The type of the intermediate value of the reduction.
* @tparam C The type of the final result.
*/
-abstract class Aggregator[-A, B, C] {
+abstract class Aggregator[-A, B, C] extends Serializable {
/** A zero value for this aggregation. Should satisfy the property that any b + zero = b */
def zero: B
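
The new Serializable supertype matters because aggregator instances are embedded in the query plan and shipped to executors; until now every concrete Aggregator had to mix in Serializable by hand (see the test-suite cleanup below). A minimal sketch of a subclass that no longer needs the explicit mix-in; the `SumOfLengths` name is illustrative:

  // No `with Serializable` needed; the base class now provides it.
  class SumOfLengths extends Aggregator[String, Int, Int] {
    override def zero: Int = 0
    override def reduce(b: Int, a: String): Int = b + a.length
    override def merge(b1: Int, b2: Int): Int = b1 + b2
    override def finish(b: Int): Int = b
  }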
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index eb6fa1e72e..d9b22506fb 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -34,6 +34,7 @@ import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.GroupedDataset;
+import org.apache.spark.sql.expressions.Aggregator;
import org.apache.spark.sql.test.TestSQLContext;
import static org.apache.spark.sql.functions.*;
@@ -381,4 +382,59 @@ public class JavaDatasetSuite implements Serializable {
context.createDataset(data3, encoder3);
Assert.assertEquals(data3, ds3.collectAsList());
}
+
+ @Test
+ public void testTypedAggregation() {
+ Encoder<Tuple2<String, Integer>> encoder = Encoders.tuple(Encoders.STRING(), Encoders.INT());
+ List<Tuple2<String, Integer>> data =
+ Arrays.asList(tuple2("a", 1), tuple2("a", 2), tuple2("b", 3));
+ Dataset<Tuple2<String, Integer>> ds = context.createDataset(data, encoder);
+
+ GroupedDataset<String, Tuple2<String, Integer>> grouped = ds.groupBy(
+ new MapFunction<Tuple2<String, Integer>, String>() {
+ @Override
+ public String call(Tuple2<String, Integer> value) throws Exception {
+ return value._1();
+ }
+ },
+ Encoders.STRING());
+
+ Dataset<Tuple2<String, Integer>> agged =
+ grouped.agg(new IntSumOf().toColumn(Encoders.INT(), Encoders.INT()));
+ Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList());
+
+ Dataset<Tuple4<String, Integer, Long, Long>> agged2 = grouped.agg(
+ new IntSumOf().toColumn(Encoders.INT(), Encoders.INT()),
+ expr("sum(_2)"),
+ count("*"))
+ .as(Encoders.tuple(Encoders.STRING(), Encoders.INT(), Encoders.LONG(), Encoders.LONG()));
+ Assert.assertEquals(
+ Arrays.asList(
+ new Tuple4<String, Integer, Long, Long>("a", 3, 3L, 2L),
+ new Tuple4<String, Integer, Long, Long>("b", 3, 3L, 1L)),
+ agged2.collectAsList());
+ }
+
+ static class IntSumOf extends Aggregator<Tuple2<String, Integer>, Integer, Integer> {
+
+ @Override
+ public Integer zero() {
+ return 0;
+ }
+
+ @Override
+ public Integer reduce(Integer l, Tuple2<String, Integer> t) {
+ return l + t._2();
+ }
+
+ @Override
+ public Integer merge(Integer b1, Integer b2) {
+ return b1 + b2;
+ }
+
+ @Override
+ public Integer finish(Integer reduction) {
+ return reduction;
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
index 46f9f077fe..9377589790 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Aggregator
/** An `Aggregator` that adds up any numeric type returned by the given function. */
-class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] with Serializable {
+class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] {
val numeric = implicitly[Numeric[N]]
override def zero: N = numeric.zero
@@ -37,7 +37,7 @@ class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] with Serializ
override def finish(reduction: N): N = reduction
}
-object TypedAverage extends Aggregator[(String, Int), (Long, Long), Double] with Serializable {
+object TypedAverage extends Aggregator[(String, Int), (Long, Long), Double] {
override def zero: (Long, Long) = (0, 0)
override def reduce(countAndSum: (Long, Long), input: (String, Int)): (Long, Long) = {
@@ -51,8 +51,7 @@ object TypedAverage extends Aggregator[(String, Int), (Long, Long), Double] with
override def finish(countAndSum: (Long, Long)): Double = countAndSum._2 / countAndSum._1
}
-object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, Long)]
- with Serializable {
+object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, Long)] {
override def zero: (Long, Long) = (0, 0)