aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorEric Liang <ekl@databricks.com>2016-04-05 00:30:55 -0500
committerReynold Xin <rxin@databricks.com>2016-04-05 00:30:55 -0500
commit064623014e0d6dfb0376722f24e81027fde649de (patch)
tree4ef26a921ede1724428746ae97f414e705ac9033
parent7db56244fa3dba92246bad6694f31bbf68ea47ec (diff)
downloadspark-064623014e0d6dfb0376722f24e81027fde649de.tar.gz
spark-064623014e0d6dfb0376722f24e81027fde649de.tar.bz2
spark-064623014e0d6dfb0376722f24e81027fde649de.zip
[SPARK-14359] Create built-in functions for typed aggregates in Java
## What changes were proposed in this pull request?

This adds the corresponding Java static functions for built-in typed aggregates already exposed in Scala.

## How was this patch tested?

Unit tests. rxin

Author: Eric Liang <ekl@databricks.com>

Closes #12168 from ericl/sc-2794.
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala33
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/expressions/java/typed.java42
-rw-r--r--sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java49
3 files changed, 124 insertions(+), 0 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala
index 9afc29038b..7a18d0afce 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala
@@ -17,6 +17,9 @@
package org.apache.spark.sql.execution.aggregate
+import org.apache.spark.api.java.function.MapFunction
+import org.apache.spark.sql.TypedColumn
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator
////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -30,6 +33,8 @@ class TypedSum[IN, OUT : Numeric](f: IN => OUT) extends Aggregator[IN, OUT, OUT]
override def reduce(b: OUT, a: IN): OUT = numeric.plus(b, f(a))
override def merge(b1: OUT, b2: OUT): OUT = numeric.plus(b1, b2)
override def finish(reduction: OUT): OUT = reduction
+
+ // TODO(ekl) java api support once this is exposed in scala
}
@@ -38,6 +43,13 @@ class TypedSumDouble[IN](f: IN => Double) extends Aggregator[IN, Double, Double]
override def reduce(b: Double, a: IN): Double = b + f(a)
override def merge(b1: Double, b2: Double): Double = b1 + b2
override def finish(reduction: Double): Double = reduction
+
+ // Java api support
+ def this(f: MapFunction[IN, java.lang.Double]) = this(x => f.call(x).asInstanceOf[Double])
+ def toColumnJava(): TypedColumn[IN, java.lang.Double] = {
+ toColumn(ExpressionEncoder(), ExpressionEncoder())
+ .asInstanceOf[TypedColumn[IN, java.lang.Double]]
+ }
}
@@ -46,6 +58,13 @@ class TypedSumLong[IN](f: IN => Long) extends Aggregator[IN, Long, Long] {
override def reduce(b: Long, a: IN): Long = b + f(a)
override def merge(b1: Long, b2: Long): Long = b1 + b2
override def finish(reduction: Long): Long = reduction
+
+ // Java api support
+ def this(f: MapFunction[IN, java.lang.Long]) = this(x => f.call(x).asInstanceOf[Long])
+ def toColumnJava(): TypedColumn[IN, java.lang.Long] = {
+ toColumn(ExpressionEncoder(), ExpressionEncoder())
+ .asInstanceOf[TypedColumn[IN, java.lang.Long]]
+ }
}
@@ -56,6 +75,13 @@ class TypedCount[IN](f: IN => Any) extends Aggregator[IN, Long, Long] {
}
override def merge(b1: Long, b2: Long): Long = b1 + b2
override def finish(reduction: Long): Long = reduction
+
+ // Java api support
+ def this(f: MapFunction[IN, Object]) = this(x => f.call(x))
+ def toColumnJava(): TypedColumn[IN, java.lang.Long] = {
+ toColumn(ExpressionEncoder(), ExpressionEncoder())
+ .asInstanceOf[TypedColumn[IN, java.lang.Long]]
+ }
}
@@ -66,4 +92,11 @@ class TypedAverage[IN](f: IN => Double) extends Aggregator[IN, (Double, Long), D
override def merge(b1: (Double, Long), b2: (Double, Long)): (Double, Long) = {
(b1._1 + b2._1, b1._2 + b2._2)
}
+
+ // Java api support
+ def this(f: MapFunction[IN, java.lang.Double]) = this(x => f.call(x).asInstanceOf[Double])
+ def toColumnJava(): TypedColumn[IN, java.lang.Double] = {
+ toColumn(ExpressionEncoder(), ExpressionEncoder())
+ .asInstanceOf[TypedColumn[IN, java.lang.Double]]
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/java/typed.java b/sql/core/src/main/scala/org/apache/spark/sql/expressions/java/typed.java
index cdba970d8f..8ff7b6549b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/java/typed.java
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/java/typed.java
@@ -18,7 +18,13 @@
package org.apache.spark.sql.expressions.java;
import org.apache.spark.annotation.Experimental;
+import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.TypedColumn;
+import org.apache.spark.sql.execution.aggregate.TypedAverage;
+import org.apache.spark.sql.execution.aggregate.TypedCount;
+import org.apache.spark.sql.execution.aggregate.TypedSumDouble;
+import org.apache.spark.sql.execution.aggregate.TypedSumLong;
/**
* :: Experimental ::
@@ -30,5 +36,41 @@ import org.apache.spark.sql.Dataset;
*/
@Experimental
public class typed {
+ // Note: make sure to keep in sync with typed.scala
+ /**
+ * Average aggregate function.
+ *
+ * @since 2.0.0
+ */
+ public static<T> TypedColumn<T, Double> avg(MapFunction<T, Double> f) {
+ return new TypedAverage<T>(f).toColumnJava();
+ }
+
+ /**
+ * Count aggregate function.
+ *
+ * @since 2.0.0
+ */
+ public static<T> TypedColumn<T, Long> count(MapFunction<T, Object> f) {
+ return new TypedCount<T>(f).toColumnJava();
+ }
+
+ /**
+ * Sum aggregate function for floating point (double) type.
+ *
+ * @since 2.0.0
+ */
+ public static<T> TypedColumn<T, Double> sum(MapFunction<T, Double> f) {
+ return new TypedSumDouble<T>(f).toColumnJava();
+ }
+
+ /**
+ * Sum aggregate function for integral (long, i.e. 64 bit integer) type.
+ *
+ * @since 2.0.0
+ */
+ public static<T> TypedColumn<T, Long> sumLong(MapFunction<T, Long> f) {
+ return new TypedSumLong<T>(f).toColumnJava();
+ }
}
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java
index c4c455b6e6..c8d0eecd5c 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java
@@ -35,6 +35,7 @@ import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.KeyValueGroupedDataset;
import org.apache.spark.sql.expressions.Aggregator;
+import org.apache.spark.sql.expressions.java.typed;
import org.apache.spark.sql.test.TestSQLContext;
/**
@@ -120,4 +121,52 @@ public class JavaDatasetAggregatorSuite implements Serializable {
return reduction;
}
}
+
+ @Test
+ public void testTypedAggregationAverage() {
+ KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
+ Dataset<Tuple2<String, Double>> agged = grouped.agg(typed.avg(
+ new MapFunction<Tuple2<String, Integer>, Double>() {
+ public Double call(Tuple2<String, Integer> value) throws Exception {
+ return (double)(value._2() * 2);
+ }
+ }));
+ Assert.assertEquals(Arrays.asList(tuple2("a", 3.0), tuple2("b", 6.0)), agged.collectAsList());
+ }
+
+ @Test
+ public void testTypedAggregationCount() {
+ KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
+ Dataset<Tuple2<String, Long>> agged = grouped.agg(typed.count(
+ new MapFunction<Tuple2<String, Integer>, Object>() {
+ public Object call(Tuple2<String, Integer> value) throws Exception {
+ return value;
+ }
+ }));
+ Assert.assertEquals(Arrays.asList(tuple2("a", 2), tuple2("b", 1)), agged.collectAsList());
+ }
+
+ @Test
+ public void testTypedAggregationSumDouble() {
+ KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
+ Dataset<Tuple2<String, Double>> agged = grouped.agg(typed.sum(
+ new MapFunction<Tuple2<String, Integer>, Double>() {
+ public Double call(Tuple2<String, Integer> value) throws Exception {
+ return (double)value._2();
+ }
+ }));
+ Assert.assertEquals(Arrays.asList(tuple2("a", 3.0), tuple2("b", 3.0)), agged.collectAsList());
+ }
+
+ @Test
+ public void testTypedAggregationSumLong() {
+ KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
+ Dataset<Tuple2<String, Long>> agged = grouped.agg(typed.sumLong(
+ new MapFunction<Tuple2<String, Integer>, Long>() {
+ public Long call(Tuple2<String, Integer> value) throws Exception {
+ return (long)value._2();
+ }
+ }));
+ Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList());
+ }
}