aboutsummaryrefslogtreecommitdiff
path: root/sql/core/src/test/java
diff options
context:
space:
mode:
authorMichael Armbrust <michael@databricks.com>2015-11-09 16:11:00 -0800
committerMichael Armbrust <michael@databricks.com>2015-11-09 16:11:00 -0800
commit9c740a9ddf6344a03b4b45380eaf0cfc6e2299b5 (patch)
treec821f0b8bbcce9410bdc5b54968251f8bdfb0b6a /sql/core/src/test/java
parent2f38378856fb56bdd9be7ccedf56427e81701f4e (diff)
downloadspark-9c740a9ddf6344a03b4b45380eaf0cfc6e2299b5.tar.gz
spark-9c740a9ddf6344a03b4b45380eaf0cfc6e2299b5.tar.bz2
spark-9c740a9ddf6344a03b4b45380eaf0cfc6e2299b5.zip
[SPARK-11578][SQL] User API for Typed Aggregation
This PR adds a new interface for user-defined aggregations, that can be used in `DataFrame` and `Dataset` operations to take all of the elements of a group and reduce them to a single value. For example, the following aggregator extracts an `int` from a specific class and adds them up: ```scala case class Data(i: Int) val customSummer = new Aggregator[Data, Int, Int] { def prepare(d: Data) = d.i def reduce(l: Int, r: Int) = l + r def present(r: Int) = r }.toColumn() val ds: Dataset[Data] = ... val aggregated = ds.select(customSummer) ``` By using helper functions, users can make a generic `Aggregator` that works on any input type: ```scala /** An `Aggregator` that adds up any numeric type returned by the given function. */ class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] with Serializable { val numeric = implicitly[Numeric[N]] override def zero: N = numeric.zero override def reduce(b: N, a: I): N = numeric.plus(b, f(a)) override def present(reduction: N): N = reduction } def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] = new SumOf(f).toColumn ``` These aggregators can then be used alongside other built-in SQL aggregations. ```scala val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() ds .groupBy(_._1) .agg( sum(_._2), // The aggregator defined above. expr("sum(_2)").as[Int], // A built-in dynatically typed aggregation. count("*")) // A built-in statically typed aggregation. .collect() res0: ("a", 30, 30, 2L), ("b", 3, 3, 2L), ("c", 1, 1, 1L) ``` The current implementation focuses on integrating this into the typed API, but currently only supports running aggregations that return a single long value as explained in `TypedAggregateExpression`. This will be improved in a followup PR. Author: Michael Armbrust <michael@databricks.com> Closes #9555 from marmbrus/dataset-useragg.
Diffstat (limited to 'sql/core/src/test/java')
-rw-r--r--sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java4
1 files changed, 2 insertions, 2 deletions
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 312cf33e4c..2da63d1b96 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -258,8 +258,8 @@ public class JavaDatasetSuite implements Serializable {
Dataset<Integer> ds = context.createDataset(data, e.INT());
Dataset<Tuple2<Integer, String>> selected = ds.select(
- expr("value + 1").as(e.INT()),
- col("value").cast("string").as(e.STRING()));
+ expr("value + 1"),
+ col("value").cast("string")).as(e.tuple(e.INT(), e.STRING()));
Assert.assertEquals(
Arrays.asList(tuple2(3, "2"), tuple2(7, "6")),