path: root/python/pyspark/mllib/clustering.py
author     Michael Armbrust <michael@databricks.com>  2015-11-09 16:11:00 -0800
committer  Michael Armbrust <michael@databricks.com>  2015-11-09 16:11:00 -0800
commit     9c740a9ddf6344a03b4b45380eaf0cfc6e2299b5 (patch)
tree       c821f0b8bbcce9410bdc5b54968251f8bdfb0b6a /python/pyspark/mllib/clustering.py
parent     2f38378856fb56bdd9be7ccedf56427e81701f4e (diff)
download   spark-9c740a9ddf6344a03b4b45380eaf0cfc6e2299b5.tar.gz
           spark-9c740a9ddf6344a03b4b45380eaf0cfc6e2299b5.tar.bz2
           spark-9c740a9ddf6344a03b4b45380eaf0cfc6e2299b5.zip
[SPARK-11578][SQL] User API for Typed Aggregation
This PR adds a new interface for user-defined aggregations that can be used in `DataFrame` and `Dataset` operations to take all of the elements of a group and reduce them to a single value. For example, the following aggregator extracts an `Int` from a specific class and adds them up:

```scala
case class Data(i: Int)

val customSummer = new Aggregator[Data, Int, Int] {
  def prepare(d: Data) = d.i
  def reduce(l: Int, r: Int) = l + r
  def present(r: Int) = r
}.toColumn()

val ds: Dataset[Data] = ...
val aggregated = ds.select(customSummer)
```

By using helper functions, users can make a generic `Aggregator` that works on any input type:

```scala
/** An `Aggregator` that adds up any numeric type returned by the given function. */
class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] with Serializable {
  val numeric = implicitly[Numeric[N]]
  override def zero: N = numeric.zero
  override def reduce(b: N, a: I): N = numeric.plus(b, f(a))
  override def present(reduction: N): N = reduction
}

def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] = new SumOf(f).toColumn
```

These aggregators can then be used alongside other built-in SQL aggregations:

```scala
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()

ds
  .groupBy(_._1)
  .agg(
    sum(_._2),                // The aggregator defined above.
    expr("sum(_2)").as[Int],  // A built-in dynamically typed aggregation.
    count("*"))               // A built-in statically typed aggregation.
  .collect()

res0: ("a", 30, 30, 2L), ("b", 3, 3, 2L), ("c", 1, 1, 1L)
```

The current implementation focuses on integrating this into the typed API, but it currently only supports aggregations that return a single long value, as explained in `TypedAggregateExpression`. This will be improved in a follow-up PR.

Author: Michael Armbrust <michael@databricks.com>

Closes #9555 from marmbrus/dataset-useragg.
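The same `zero`/`reduce`/`present` pattern also covers aggregations whose intermediate buffer type differs from both the input and output types. The sketch below is illustrative only and not part of this patch: the `AvgOf` class is a hypothetical name, and it assumes the `Aggregator[I, B, O]` interface and `toColumn` conversion shown in the examples above. It computes a mean by carrying a running (sum, count) pair as the buffer:

```scala
// Assumed import; the Aggregator class is added by this patch.
import org.apache.spark.sql.expressions.Aggregator

/** Hypothetical aggregator: averages the Double extracted from each input element.
  * The buffer type (Double, Long) carries a running (sum, count) pair, so the
  * intermediate type differs from both the input type I and the output Double. */
class AvgOf[I](f: I => Double) extends Aggregator[I, (Double, Long), Double] with Serializable {
  override def zero: (Double, Long) = (0.0, 0L)
  override def reduce(b: (Double, Long), a: I): (Double, Long) = (b._1 + f(a), b._2 + 1L)
  override def present(r: (Double, Long)): Double = r._1 / r._2
}

// Usage, mirroring the grouped-aggregation example above:
// ds.groupBy(_._1).agg(new AvgOf[(String, Int)](_._2.toDouble).toColumn)
```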
Diffstat (limited to 'python/pyspark/mllib/clustering.py')
0 files changed, 0 insertions, 0 deletions