diff options
author | Eric Liang <ekl@databricks.com> | 2016-04-05 00:30:55 -0500 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2016-04-05 00:30:55 -0500 |
commit | 064623014e0d6dfb0376722f24e81027fde649de (patch) | |
tree | 4ef26a921ede1724428746ae97f414e705ac9033 /sql/core/src | |
parent | 7db56244fa3dba92246bad6694f31bbf68ea47ec (diff) | |
download | spark-064623014e0d6dfb0376722f24e81027fde649de.tar.gz spark-064623014e0d6dfb0376722f24e81027fde649de.tar.bz2 spark-064623014e0d6dfb0376722f24e81027fde649de.zip |
[SPARK-14359] Create built-in functions for typed aggregates in Java
## What changes were proposed in this pull request?
This adds Java static functions (`typed.avg`, `typed.count`, `typed.sum`, and `typed.sumLong`) corresponding to the built-in typed aggregates already exposed in Scala.
## How was this patch tested?
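Unit tests added to `JavaDatasetAggregatorSuite`, one each for `typed.avg`, `typed.count`, `typed.sum`, and `typed.sumLong`.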
rxin
Author: Eric Liang <ekl@databricks.com>
Closes #12168 from ericl/sc-2794.
Diffstat (limited to 'sql/core/src')
3 files changed, 124 insertions, 0 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala index 9afc29038b..7a18d0afce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.execution.aggregate +import org.apache.spark.api.java.function.MapFunction +import org.apache.spark.sql.TypedColumn +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.expressions.Aggregator //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -30,6 +33,8 @@ class TypedSum[IN, OUT : Numeric](f: IN => OUT) extends Aggregator[IN, OUT, OUT] override def reduce(b: OUT, a: IN): OUT = numeric.plus(b, f(a)) override def merge(b1: OUT, b2: OUT): OUT = numeric.plus(b1, b2) override def finish(reduction: OUT): OUT = reduction + + // TODO(ekl) java api support once this is exposed in scala } @@ -38,6 +43,13 @@ class TypedSumDouble[IN](f: IN => Double) extends Aggregator[IN, Double, Double] override def reduce(b: Double, a: IN): Double = b + f(a) override def merge(b1: Double, b2: Double): Double = b1 + b2 override def finish(reduction: Double): Double = reduction + + // Java api support + def this(f: MapFunction[IN, java.lang.Double]) = this(x => f.call(x).asInstanceOf[Double]) + def toColumnJava(): TypedColumn[IN, java.lang.Double] = { + toColumn(ExpressionEncoder(), ExpressionEncoder()) + .asInstanceOf[TypedColumn[IN, java.lang.Double]] + } } @@ -46,6 +58,13 @@ class TypedSumLong[IN](f: IN => Long) extends Aggregator[IN, Long, Long] { override def reduce(b: Long, a: IN): Long = b + f(a) override def merge(b1: Long, b2: Long): Long = b1 + b2 override def finish(reduction: Long): Long = reduction + + // Java api support + def 
this(f: MapFunction[IN, java.lang.Long]) = this(x => f.call(x).asInstanceOf[Long]) + def toColumnJava(): TypedColumn[IN, java.lang.Long] = { + toColumn(ExpressionEncoder(), ExpressionEncoder()) + .asInstanceOf[TypedColumn[IN, java.lang.Long]] + } } @@ -56,6 +75,13 @@ class TypedCount[IN](f: IN => Any) extends Aggregator[IN, Long, Long] { } override def merge(b1: Long, b2: Long): Long = b1 + b2 override def finish(reduction: Long): Long = reduction + + // Java api support + def this(f: MapFunction[IN, Object]) = this(x => f.call(x)) + def toColumnJava(): TypedColumn[IN, java.lang.Long] = { + toColumn(ExpressionEncoder(), ExpressionEncoder()) + .asInstanceOf[TypedColumn[IN, java.lang.Long]] + } } @@ -66,4 +92,11 @@ class TypedAverage[IN](f: IN => Double) extends Aggregator[IN, (Double, Long), D override def merge(b1: (Double, Long), b2: (Double, Long)): (Double, Long) = { (b1._1 + b2._1, b1._2 + b2._2) } + + // Java api support + def this(f: MapFunction[IN, java.lang.Double]) = this(x => f.call(x).asInstanceOf[Double]) + def toColumnJava(): TypedColumn[IN, java.lang.Double] = { + toColumn(ExpressionEncoder(), ExpressionEncoder()) + .asInstanceOf[TypedColumn[IN, java.lang.Double]] + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/java/typed.java b/sql/core/src/main/scala/org/apache/spark/sql/expressions/java/typed.java index cdba970d8f..8ff7b6549b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/java/typed.java +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/java/typed.java @@ -18,7 +18,13 @@ package org.apache.spark.sql.expressions.java; import org.apache.spark.annotation.Experimental; +import org.apache.spark.api.java.function.MapFunction; import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.TypedColumn; +import org.apache.spark.sql.execution.aggregate.TypedAverage; +import org.apache.spark.sql.execution.aggregate.TypedCount; +import org.apache.spark.sql.execution.aggregate.TypedSumDouble; 
+import org.apache.spark.sql.execution.aggregate.TypedSumLong; /** * :: Experimental :: @@ -30,5 +36,41 @@ import org.apache.spark.sql.Dataset; */ @Experimental public class typed { + // Note: make sure to keep in sync with typed.scala + /** + * Average aggregate function. + * + * @since 2.0.0 + */ + public static<T> TypedColumn<T, Double> avg(MapFunction<T, Double> f) { + return new TypedAverage<T>(f).toColumnJava(); + } + + /** + * Count aggregate function. + * + * @since 2.0.0 + */ + public static<T> TypedColumn<T, Long> count(MapFunction<T, Object> f) { + return new TypedCount<T>(f).toColumnJava(); + } + + /** + * Sum aggregate function for floating point (double) type. + * + * @since 2.0.0 + */ + public static<T> TypedColumn<T, Double> sum(MapFunction<T, Double> f) { + return new TypedSumDouble<T>(f).toColumnJava(); + } + + /** + * Sum aggregate function for integral (long, i.e. 64 bit integer) type. + * + * @since 2.0.0 + */ + public static<T> TypedColumn<T, Long> sumLong(MapFunction<T, Long> f) { + return new TypedSumLong<T>(f).toColumnJava(); + } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java index c4c455b6e6..c8d0eecd5c 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java @@ -35,6 +35,7 @@ import org.apache.spark.sql.Encoder; import org.apache.spark.sql.Encoders; import org.apache.spark.sql.KeyValueGroupedDataset; import org.apache.spark.sql.expressions.Aggregator; +import org.apache.spark.sql.expressions.java.typed; import org.apache.spark.sql.test.TestSQLContext; /** @@ -120,4 +121,52 @@ public class JavaDatasetAggregatorSuite implements Serializable { return reduction; } } + + @Test + public void testTypedAggregationAverage() { + KeyValueGroupedDataset<String, 
Tuple2<String, Integer>> grouped = generateGroupedDataset(); + Dataset<Tuple2<String, Double>> agged = grouped.agg(typed.avg( + new MapFunction<Tuple2<String, Integer>, Double>() { + public Double call(Tuple2<String, Integer> value) throws Exception { + return (double)(value._2() * 2); + } + })); + Assert.assertEquals(Arrays.asList(tuple2("a", 3.0), tuple2("b", 6.0)), agged.collectAsList()); + } + + @Test + public void testTypedAggregationCount() { + KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset(); + Dataset<Tuple2<String, Long>> agged = grouped.agg(typed.count( + new MapFunction<Tuple2<String, Integer>, Object>() { + public Object call(Tuple2<String, Integer> value) throws Exception { + return value; + } + })); + Assert.assertEquals(Arrays.asList(tuple2("a", 2), tuple2("b", 1)), agged.collectAsList()); + } + + @Test + public void testTypedAggregationSumDouble() { + KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset(); + Dataset<Tuple2<String, Double>> agged = grouped.agg(typed.sum( + new MapFunction<Tuple2<String, Integer>, Double>() { + public Double call(Tuple2<String, Integer> value) throws Exception { + return (double)value._2(); + } + })); + Assert.assertEquals(Arrays.asList(tuple2("a", 3.0), tuple2("b", 3.0)), agged.collectAsList()); + } + + @Test + public void testTypedAggregationSumLong() { + KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset(); + Dataset<Tuple2<String, Long>> agged = grouped.agg(typed.sumLong( + new MapFunction<Tuple2<String, Integer>, Long>() { + public Long call(Tuple2<String, Integer> value) throws Exception { + return (long)value._2(); + } + })); + Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList()); + } } |