Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala  47
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala  39
-rw-r--r--  sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java  17
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala  75
4 files changed, 102 insertions, 76 deletions
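In short, this commit moves encoder resolution from the call site into the Aggregator itself: toColumn previously took the buffer and output encoders as implicit parameters, while now each Aggregator declares them through the new bufferEncoder and outputEncoder members. A minimal before/after sketch (signatures abbreviated; see the hunks below for the real ones):

```scala
// Before: callers supplied both encoders when converting to a column.
//   def toColumn(implicit bEncoder: Encoder[B], cEncoder: Encoder[O]): TypedColumn[I, O]
//   agg.toColumn(Encoders.scalaLong, Encoders.scalaLong)

// After: the Aggregator carries its own encoders, so toColumn is parameterless.
//   def bufferEncoder: Encoder[BUF]
//   def outputEncoder: Encoder[OUT]
//   agg.toColumn
```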
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala
index 7a18d0afce..c39a78da6f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.execution.aggregate
import org.apache.spark.api.java.function.MapFunction
-import org.apache.spark.sql.TypedColumn
+import org.apache.spark.sql.{Encoder, TypedColumn}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator
@@ -27,28 +27,20 @@ import org.apache.spark.sql.expressions.Aggregator
////////////////////////////////////////////////////////////////////////////////////////////////////
-class TypedSum[IN, OUT : Numeric](f: IN => OUT) extends Aggregator[IN, OUT, OUT] {
- val numeric = implicitly[Numeric[OUT]]
- override def zero: OUT = numeric.zero
- override def reduce(b: OUT, a: IN): OUT = numeric.plus(b, f(a))
- override def merge(b1: OUT, b2: OUT): OUT = numeric.plus(b1, b2)
- override def finish(reduction: OUT): OUT = reduction
-
- // TODO(ekl) java api support once this is exposed in scala
-}
-
-
class TypedSumDouble[IN](f: IN => Double) extends Aggregator[IN, Double, Double] {
override def zero: Double = 0.0
override def reduce(b: Double, a: IN): Double = b + f(a)
override def merge(b1: Double, b2: Double): Double = b1 + b2
override def finish(reduction: Double): Double = reduction
+ override def bufferEncoder: Encoder[Double] = ExpressionEncoder[Double]()
+ override def outputEncoder: Encoder[Double] = ExpressionEncoder[Double]()
+
// Java api support
def this(f: MapFunction[IN, java.lang.Double]) = this(x => f.call(x).asInstanceOf[Double])
- def toColumnJava(): TypedColumn[IN, java.lang.Double] = {
- toColumn(ExpressionEncoder(), ExpressionEncoder())
- .asInstanceOf[TypedColumn[IN, java.lang.Double]]
+
+ def toColumnJava: TypedColumn[IN, java.lang.Double] = {
+ toColumn.asInstanceOf[TypedColumn[IN, java.lang.Double]]
}
}
@@ -59,11 +51,14 @@ class TypedSumLong[IN](f: IN => Long) extends Aggregator[IN, Long, Long] {
override def merge(b1: Long, b2: Long): Long = b1 + b2
override def finish(reduction: Long): Long = reduction
+ override def bufferEncoder: Encoder[Long] = ExpressionEncoder[Long]()
+ override def outputEncoder: Encoder[Long] = ExpressionEncoder[Long]()
+
// Java api support
def this(f: MapFunction[IN, java.lang.Long]) = this(x => f.call(x).asInstanceOf[Long])
- def toColumnJava(): TypedColumn[IN, java.lang.Long] = {
- toColumn(ExpressionEncoder(), ExpressionEncoder())
- .asInstanceOf[TypedColumn[IN, java.lang.Long]]
+
+ def toColumnJava: TypedColumn[IN, java.lang.Long] = {
+ toColumn.asInstanceOf[TypedColumn[IN, java.lang.Long]]
}
}
@@ -76,11 +71,13 @@ class TypedCount[IN](f: IN => Any) extends Aggregator[IN, Long, Long] {
override def merge(b1: Long, b2: Long): Long = b1 + b2
override def finish(reduction: Long): Long = reduction
+ override def bufferEncoder: Encoder[Long] = ExpressionEncoder[Long]()
+ override def outputEncoder: Encoder[Long] = ExpressionEncoder[Long]()
+
// Java api support
def this(f: MapFunction[IN, Object]) = this(x => f.call(x))
- def toColumnJava(): TypedColumn[IN, java.lang.Long] = {
- toColumn(ExpressionEncoder(), ExpressionEncoder())
- .asInstanceOf[TypedColumn[IN, java.lang.Long]]
+ def toColumnJava: TypedColumn[IN, java.lang.Long] = {
+ toColumn.asInstanceOf[TypedColumn[IN, java.lang.Long]]
}
}
@@ -93,10 +90,12 @@ class TypedAverage[IN](f: IN => Double) extends Aggregator[IN, (Double, Long), D
(b1._1 + b2._1, b1._2 + b2._2)
}
+ override def bufferEncoder: Encoder[(Double, Long)] = ExpressionEncoder[(Double, Long)]()
+ override def outputEncoder: Encoder[Double] = ExpressionEncoder[Double]()
+
// Java api support
def this(f: MapFunction[IN, java.lang.Double]) = this(x => f.call(x).asInstanceOf[Double])
- def toColumnJava(): TypedColumn[IN, java.lang.Double] = {
- toColumn(ExpressionEncoder(), ExpressionEncoder())
- .asInstanceOf[TypedColumn[IN, java.lang.Double]]
+ def toColumnJava: TypedColumn[IN, java.lang.Double] = {
+ toColumn.asInstanceOf[TypedColumn[IN, java.lang.Double]]
}
}
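With the encoders now declared on the aggregators themselves, the built-in typed aggregates can be applied without passing any encoders at the call site. A short usage sketch, assuming a SparkSession named `spark` with its implicits imported and `typed.sum` wrapping the `TypedSumDouble` shown above (the dataset values are illustrative):

```scala
import org.apache.spark.sql.expressions.scala.typed
import spark.implicits._  // assumes a SparkSession named `spark`

val ds = Seq(("a", 1), ("a", 3), ("b", 3)).toDS()

// Since the aggregator now supplies its own encoders, no encoder
// arguments appear when it is turned into a column.
val summed = ds.groupByKey(_._1)
  .agg(typed.sum((x: (String, Int)) => x._2.toDouble))
// rows: ("a", 4.0), ("b", 3.0)
```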
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
index 9cb356f1ca..7da8379c9a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
@@ -43,52 +43,65 @@ import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
*
* Based loosely on Aggregator from Algebird: https://github.com/twitter/algebird
*
- * @tparam I The input type for the aggregation.
- * @tparam B The type of the intermediate value of the reduction.
- * @tparam O The type of the final output result.
+ * @tparam IN The input type for the aggregation.
+ * @tparam BUF The type of the intermediate value of the reduction.
+ * @tparam OUT The type of the final output result.
* @since 1.6.0
*/
-abstract class Aggregator[-I, B, O] extends Serializable {
+abstract class Aggregator[-IN, BUF, OUT] extends Serializable {
/**
* A zero value for this aggregation. Should satisfy the property that any b + zero = b.
* @since 1.6.0
*/
- def zero: B
+ def zero: BUF
/**
* Combine two values to produce a new value. For performance, the function may modify `b` and
* return it instead of constructing a new object for `b`.
* @since 1.6.0
*/
- def reduce(b: B, a: I): B
+ def reduce(b: BUF, a: IN): BUF
/**
* Merge two intermediate values.
* @since 1.6.0
*/
- def merge(b1: B, b2: B): B
+ def merge(b1: BUF, b2: BUF): BUF
/**
* Transform the output of the reduction.
* @since 1.6.0
*/
- def finish(reduction: B): O
+ def finish(reduction: BUF): OUT
/**
- * Returns this `Aggregator` as a [[TypedColumn]] that can be used in [[Dataset]] or [[DataFrame]]
+ * Specifies the [[Encoder]] for the intermediate value type.
+ * @since 2.0.0
+ */
+ def bufferEncoder: Encoder[BUF]
+
+ /**
+ * Specifies the [[Encoder]] for the final output value type.
+ * @since 2.0.0
+ */
+ def outputEncoder: Encoder[OUT]
+
+ /**
+ * Returns this `Aggregator` as a [[TypedColumn]] that can be used in [[Dataset]]
* operations.
* @since 1.6.0
*/
- def toColumn(
- implicit bEncoder: Encoder[B],
- cEncoder: Encoder[O]): TypedColumn[I, O] = {
+ def toColumn: TypedColumn[IN, OUT] = {
+ implicit val bEncoder = bufferEncoder
+ implicit val cEncoder = outputEncoder
+
val expr =
AggregateExpression(
TypedAggregateExpression(this),
Complete,
isDistinct = false)
- new TypedColumn[I, O](expr, encoderFor[O])
+ new TypedColumn[IN, OUT](expr, encoderFor[OUT])
}
}
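For reference, here is a complete custom Aggregator under the new contract — a minimal sketch using only APIs visible in this diff (`Encoders.scalaLong` is the standard encoder for a primitive Scala Long):

```scala
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

// Sums Long values; the encoders are part of the aggregator definition,
// so callers no longer pass them to toColumn.
object LongSum extends Aggregator[Long, Long, Long] {
  override def zero: Long = 0L
  override def reduce(b: Long, a: Long): Long = b + a
  override def merge(b1: Long, b2: Long): Long = b1 + b2
  override def finish(reduction: Long): Long = reduction
  override def bufferEncoder: Encoder[Long] = Encoders.scalaLong
  override def outputEncoder: Encoder[Long] = Encoders.scalaLong
}

// Usage, assuming spark.implicits._ is in scope:
//   Seq(1L, 2L, 3L).toDS().select(LongSum.toColumn).first()  // 6L
```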
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java
index 8cb174b906..0e49f871de 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java
@@ -26,6 +26,7 @@ import org.junit.Test;
import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.KeyValueGroupedDataset;
import org.apache.spark.sql.expressions.Aggregator;
@@ -39,12 +40,10 @@ public class JavaDatasetAggregatorSuite extends JavaDatasetAggregatorSuiteBase {
public void testTypedAggregationAnonClass() {
KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
- Dataset<Tuple2<String, Integer>> agged =
- grouped.agg(new IntSumOf().toColumn(Encoders.INT(), Encoders.INT()));
+ Dataset<Tuple2<String, Integer>> agged = grouped.agg(new IntSumOf().toColumn());
Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList());
- Dataset<Tuple2<String, Integer>> agged2 = grouped.agg(
- new IntSumOf().toColumn(Encoders.INT(), Encoders.INT()))
+ Dataset<Tuple2<String, Integer>> agged2 = grouped.agg(new IntSumOf().toColumn())
.as(Encoders.tuple(Encoders.STRING(), Encoders.INT()));
Assert.assertEquals(
Arrays.asList(
@@ -73,6 +72,16 @@ public class JavaDatasetAggregatorSuite extends JavaDatasetAggregatorSuiteBase {
public Integer finish(Integer reduction) {
return reduction;
}
+
+ @Override
+ public Encoder<Integer> bufferEncoder() {
+ return Encoders.INT();
+ }
+
+ @Override
+ public Encoder<Integer> outputEncoder() {
+ return Encoders.INT();
+ }
}
@Test
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
index 08b3389ad9..3a7215ee39 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql
import scala.language.postfixOps
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.expressions.scala.typed
import org.apache.spark.sql.functions._
@@ -26,74 +27,65 @@ import org.apache.spark.sql.test.SharedSQLContext
object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, Long)] {
-
override def zero: (Long, Long) = (0, 0)
-
override def reduce(countAndSum: (Long, Long), input: (String, Int)): (Long, Long) = {
(countAndSum._1 + 1, countAndSum._2 + input._2)
}
-
override def merge(b1: (Long, Long), b2: (Long, Long)): (Long, Long) = {
(b1._1 + b2._1, b1._2 + b2._2)
}
-
override def finish(reduction: (Long, Long)): (Long, Long) = reduction
+ override def bufferEncoder: Encoder[(Long, Long)] = Encoders.product[(Long, Long)]
+ override def outputEncoder: Encoder[(Long, Long)] = Encoders.product[(Long, Long)]
}
+
case class AggData(a: Int, b: String)
+
object ClassInputAgg extends Aggregator[AggData, Int, Int] {
- /** A zero value for this aggregation. Should satisfy the property that any b + zero = b */
override def zero: Int = 0
-
- /**
- * Combine two values to produce a new value. For performance, the function may modify `b` and
- * return it instead of constructing new object for b.
- */
override def reduce(b: Int, a: AggData): Int = b + a.a
-
- /**
- * Transform the output of the reduction.
- */
override def finish(reduction: Int): Int = reduction
-
- /**
- * Merge two intermediate values
- */
override def merge(b1: Int, b2: Int): Int = b1 + b2
+ override def bufferEncoder: Encoder[Int] = Encoders.scalaInt
+ override def outputEncoder: Encoder[Int] = Encoders.scalaInt
}
+
object ComplexBufferAgg extends Aggregator[AggData, (Int, AggData), Int] {
- /** A zero value for this aggregation. Should satisfy the property that any b + zero = b */
override def zero: (Int, AggData) = 0 -> AggData(0, "0")
-
- /**
- * Combine two values to produce a new value. For performance, the function may modify `b` and
- * return it instead of constructing new object for b.
- */
override def reduce(b: (Int, AggData), a: AggData): (Int, AggData) = (b._1 + 1, a)
-
- /**
- * Transform the output of the reduction.
- */
override def finish(reduction: (Int, AggData)): Int = reduction._1
-
- /**
- * Merge two intermediate values
- */
override def merge(b1: (Int, AggData), b2: (Int, AggData)): (Int, AggData) =
(b1._1 + b2._1, b1._2)
+ override def bufferEncoder: Encoder[(Int, AggData)] = Encoders.product[(Int, AggData)]
+ override def outputEncoder: Encoder[Int] = Encoders.scalaInt
}
+
object NameAgg extends Aggregator[AggData, String, String] {
def zero: String = ""
-
def reduce(b: String, a: AggData): String = a.b + b
-
def merge(b1: String, b2: String): String = b1 + b2
-
def finish(r: String): String = r
+ override def bufferEncoder: Encoder[String] = Encoders.STRING
+ override def outputEncoder: Encoder[String] = Encoders.STRING
+}
+
+
+class ParameterizedTypeSum[IN, OUT : Numeric : Encoder](f: IN => OUT)
+ extends Aggregator[IN, OUT, OUT] {
+
+ private val numeric = implicitly[Numeric[OUT]]
+ override def zero: OUT = numeric.zero
+ override def reduce(b: OUT, a: IN): OUT = numeric.plus(b, f(a))
+ override def merge(b1: OUT, b2: OUT): OUT = numeric.plus(b1, b2)
+ override def finish(reduction: OUT): OUT = reduction
+ override def bufferEncoder: Encoder[OUT] = implicitly[Encoder[OUT]]
+ override def outputEncoder: Encoder[OUT] = implicitly[Encoder[OUT]]
}
+
class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
import testImplicits._
@@ -187,6 +179,19 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
("a", 2.0, 2L, 4.0, 4L), ("b", 3.0, 1L, 3.0, 3L))
}
+ test("generic typed sum") {
+ val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS()
+ checkDataset(
+ ds.groupByKey(_._1)
+ .agg(new ParameterizedTypeSum[(String, Int), Double](_._2.toDouble).toColumn),
+ ("a", 4.0), ("b", 3.0))
+
+ checkDataset(
+ ds.groupByKey(_._1)
+ .agg(new ParameterizedTypeSum((x: (String, Int)) => x._2.toInt).toColumn),
+ ("a", 4), ("b", 3))
+ }
+
test("SPARK-12555 - result should not be corrupted after input columns are reordered") {
val ds = sql("SELECT 'Some String' AS b, 1279869254 AS a").as[AggData]