author | Michael Armbrust <michael@databricks.com> | 2015-11-09 16:11:00 -0800
committer | Michael Armbrust <michael@databricks.com> | 2015-11-09 16:11:00 -0800
commit | 9c740a9ddf6344a03b4b45380eaf0cfc6e2299b5
parent | 2f38378856fb56bdd9be7ccedf56427e81701f4e
[SPARK-11578][SQL] User API for Typed Aggregation
This PR adds a new interface for user-defined aggregations that can be used in `DataFrame` and `Dataset` operations to take all of the elements of a group and reduce them to a single value.
For example, the following aggregator extracts an `Int` field from a class and sums it:
```scala
case class Data(i: Int)

val customSummer = new Aggregator[Data, Int, Int] {
  def zero: Int = 0
  def reduce(b: Int, a: Data): Int = b + a.i
  def present(reduction: Int): Int = reduction
}.toColumn
val ds: Dataset[Data] = ...
val aggregated = ds.select(customSummer)
```
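The three operations amount to a fold over the group's elements; as a plain-Scala sketch of that contract (no Spark dependency; `runLocally` is a hypothetical helper, not part of the API):

```scala
case class Data(i: Int)

// Hypothetical local evaluation of the aggregator contract:
// start from the zero buffer, fold each element in via reduce,
// then finish the buffer with present (the identity here).
def runLocally(group: Seq[Data]): Int =
  group.foldLeft(0)((buffer, d) => buffer + d.i)

println(runLocally(Seq(Data(1), Data(2), Data(3)))) // 6
```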
By using helper functions, users can make a generic `Aggregator` that works on any input type:
```scala
/** An `Aggregator` that adds up any numeric type returned by the given function. */
class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] with Serializable {
  val numeric = implicitly[Numeric[N]]
  override def zero: N = numeric.zero
  override def reduce(b: N, a: I): N = numeric.plus(b, f(a))
  override def present(reduction: N): N = reduction
}

def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] = new SumOf(f).toColumn
```
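The `N : Numeric` context bound is what keeps `SumOf` generic over numeric result types; a small plain-Scala sketch of the same mechanism (no Spark; `sumOf` is a hypothetical standalone analogue):

```scala
// Mirrors SumOf's use of Numeric[N] for its zero and plus operations,
// written as a standalone function instead of an Aggregator.
def sumOf[I, N](items: Seq[I])(f: I => N)(implicit num: Numeric[N]): N =
  items.foldLeft(num.zero)((acc, i) => num.plus(acc, f(i)))

println(sumOf(Seq("a" -> 10, "b" -> 20))(_._2))  // Int sum: 30
println(sumOf(Seq(1.5, 2.5))(identity[Double]))  // Double sum: 4.0
```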
These aggregators can then be used alongside other built-in SQL aggregations.
```scala
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()

ds
  .groupBy(_._1)
  .agg(
    sum(_._2),                // The aggregator defined above.
    expr("sum(_2)").as[Int],  // A built-in dynamically typed aggregation.
    count("*"))               // A built-in statically typed aggregation.
  .collect()

res0: ("a", 30, 30, 2L), ("b", 3, 3, 2L), ("c", 1, 1, 1L)
```
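The grouped sums and counts above can be checked against plain-Scala collections (a sketch, no Spark involved):

```scala
val data = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1))

// Group by key, then compute the sum and the element count per group,
// matching the per-key sums (30, 3, 1) and counts (2, 2, 1) shown above.
val grouped = data
  .groupBy(_._1)
  .map { case (k, vs) => (k, vs.map(_._2).sum, vs.size.toLong) }
  .toSeq
  .sortBy(_._1)

println(grouped)
```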
The current implementation focuses on integrating this into the typed API, but it currently only supports aggregations that return a single long value, as explained in `TypedAggregateExpression`. This will be improved in a follow-up PR.
Author: Michael Armbrust <michael@databricks.com>
Closes #9555 from marmbrus/dataset-useragg.