From 064623014e0d6dfb0376722f24e81027fde649de Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Tue, 5 Apr 2016 00:30:55 -0500 Subject: [SPARK-14359] Create built-in functions for typed aggregates in Java ## What changes were proposed in this pull request? This adds the corresponding Java static functions for built-in typed aggregates already exposed in Scala. ## How was this patch tested? Unit tests. rxin Author: Eric Liang Closes #12168 from ericl/sc-2794. --- .../sql/execution/aggregate/typedaggregators.scala | 33 +++++++++++++++ .../apache/spark/sql/expressions/java/typed.java | 42 +++++++++++++++++++ .../sql/sources/JavaDatasetAggregatorSuite.java | 49 ++++++++++++++++++++++ 3 files changed, 124 insertions(+) (limited to 'sql/core') diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala index 9afc29038b..7a18d0afce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.execution.aggregate +import org.apache.spark.api.java.function.MapFunction +import org.apache.spark.sql.TypedColumn +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.expressions.Aggregator //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -30,6 +33,8 @@ class TypedSum[IN, OUT : Numeric](f: IN => OUT) extends Aggregator[IN, OUT, OUT] override def reduce(b: OUT, a: IN): OUT = numeric.plus(b, f(a)) override def merge(b1: OUT, b2: OUT): OUT = numeric.plus(b1, b2) override def finish(reduction: OUT): OUT = reduction + + // TODO(ekl) java api support once this is exposed in scala } @@ -38,6 +43,13 @@ class TypedSumDouble[IN](f: IN => Double) extends Aggregator[IN, Double, Double] override def reduce(b: Double, a: IN): Double = b + f(a) override def merge(b1: Double, b2: Double): Double = b1 + b2 override def finish(reduction: Double): Double = reduction + + // Java api support + def this(f: MapFunction[IN, java.lang.Double]) = this(x => f.call(x).asInstanceOf[Double]) + def toColumnJava(): TypedColumn[IN, java.lang.Double] = { + toColumn(ExpressionEncoder(), ExpressionEncoder()) + .asInstanceOf[TypedColumn[IN, java.lang.Double]] + } } @@ -46,6 +58,13 @@ class TypedSumLong[IN](f: IN => Long) extends Aggregator[IN, Long, Long] { override def reduce(b: Long, a: IN): Long = b + f(a) override def merge(b1: Long, b2: Long): Long = b1 + b2 override def finish(reduction: Long): Long = reduction + + // Java api support + def this(f: MapFunction[IN, java.lang.Long]) = this(x => f.call(x).asInstanceOf[Long]) + def toColumnJava(): TypedColumn[IN, java.lang.Long] = { + toColumn(ExpressionEncoder(), ExpressionEncoder()) + .asInstanceOf[TypedColumn[IN, java.lang.Long]] + } } @@ -56,6 +75,13 @@ class TypedCount[IN](f: IN => Any) extends Aggregator[IN, Long, Long] { } override def merge(b1: Long, b2: Long): Long = b1 + b2 override def finish(reduction: Long): Long = reduction + + // Java api support + def this(f: MapFunction[IN, Object]) = this(x => f.call(x)) + def toColumnJava(): TypedColumn[IN, java.lang.Long] = { + toColumn(ExpressionEncoder(), ExpressionEncoder()) + .asInstanceOf[TypedColumn[IN, java.lang.Long]] + } } @@ -66,4 +92,11 @@ class TypedAverage[IN](f: IN => Double) extends Aggregator[IN, (Double, Long), D override def merge(b1: (Double, Long), b2: (Double, Long)): (Double, Long) = { (b1._1 + b2._1, b1._2 + b2._2) } + + // Java api support + def this(f: MapFunction[IN, java.lang.Double]) = this(x => f.call(x).asInstanceOf[Double]) + def toColumnJava(): TypedColumn[IN, java.lang.Double] = { + toColumn(ExpressionEncoder(), ExpressionEncoder()) + .asInstanceOf[TypedColumn[IN, java.lang.Double]] + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/java/typed.java b/sql/core/src/main/scala/org/apache/spark/sql/expressions/java/typed.java index cdba970d8f..8ff7b6549b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/java/typed.java +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/java/typed.java @@ -18,7 +18,13 @@ package org.apache.spark.sql.expressions.java; import org.apache.spark.annotation.Experimental; +import org.apache.spark.api.java.function.MapFunction; import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.TypedColumn; +import org.apache.spark.sql.execution.aggregate.TypedAverage; +import org.apache.spark.sql.execution.aggregate.TypedCount; +import org.apache.spark.sql.execution.aggregate.TypedSumDouble; +import org.apache.spark.sql.execution.aggregate.TypedSumLong; /** * :: Experimental :: @@ -30,5 +36,41 @@ import org.apache.spark.sql.Dataset; */ @Experimental public class typed { + // Note: make sure to keep in sync with typed.scala + /** + * Average aggregate function. + * + * @since 2.0.0 + */ + public static TypedColumn avg(MapFunction f) { + return new TypedAverage(f).toColumnJava(); + } + + /** + * Count aggregate function. + * + * @since 2.0.0 + */ + public static TypedColumn count(MapFunction f) { + return new TypedCount(f).toColumnJava(); + } + + /** + * Sum aggregate function for floating point (double) type. + * + * @since 2.0.0 + */ + public static TypedColumn sum(MapFunction f) { + return new TypedSumDouble(f).toColumnJava(); + } + + /** + * Sum aggregate function for integral (long, i.e. 64 bit integer) type. + * + * @since 2.0.0 + */ + public static TypedColumn sumLong(MapFunction f) { + return new TypedSumLong(f).toColumnJava(); + } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java index c4c455b6e6..c8d0eecd5c 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java @@ -35,6 +35,7 @@ import org.apache.spark.sql.Encoder; import org.apache.spark.sql.Encoders; import org.apache.spark.sql.KeyValueGroupedDataset; import org.apache.spark.sql.expressions.Aggregator; +import org.apache.spark.sql.expressions.java.typed; import org.apache.spark.sql.test.TestSQLContext; /** @@ -120,4 +121,52 @@ public class JavaDatasetAggregatorSuite implements Serializable { return reduction; } } + + @Test + public void testTypedAggregationAverage() { + KeyValueGroupedDataset> grouped = generateGroupedDataset(); + Dataset> agged = grouped.agg(typed.avg( + new MapFunction, Double>() { + public Double call(Tuple2 value) throws Exception { + return (double)(value._2() * 2); + } + })); + Assert.assertEquals(Arrays.asList(tuple2("a", 3.0), tuple2("b", 6.0)), agged.collectAsList()); + } + + @Test + public void testTypedAggregationCount() { + KeyValueGroupedDataset> grouped = generateGroupedDataset(); + Dataset> agged = grouped.agg(typed.count( + new MapFunction, Object>() { + public Object call(Tuple2 value) throws Exception { + return value; + } + })); + Assert.assertEquals(Arrays.asList(tuple2("a", 2), tuple2("b", 1)), agged.collectAsList()); + } + + @Test + public void testTypedAggregationSumDouble() { + KeyValueGroupedDataset> grouped = generateGroupedDataset(); + Dataset> agged = grouped.agg(typed.sum( + new MapFunction, Double>() { + public Double call(Tuple2 value) throws Exception { + return (double)value._2(); + } + })); + Assert.assertEquals(Arrays.asList(tuple2("a", 3.0), tuple2("b", 3.0)), agged.collectAsList()); + } + + @Test + public void testTypedAggregationSumLong() { + KeyValueGroupedDataset> grouped = generateGroupedDataset(); + Dataset> agged = grouped.agg(typed.sumLong( + new MapFunction, Long>() { + public Long call(Tuple2 value) throws Exception { + return (long)value._2(); + } + })); + Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList()); + } } -- cgit v1.2.3