diff options
author | Reynold Xin <rxin@databricks.com> | 2016-04-09 00:00:39 -0700 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2016-04-09 00:00:39 -0700 |
commit | 520dde48d0d52dbbbbe1710a3275fdd5355dd69d (patch) | |
tree | c6fdf831b43c044e894f5cf24af30650e6aa82c0 /sql/core/src/test | |
parent | 2f0b882e5c8787b09bedcc8208e6dcc5662dbbab (diff) | |
download | spark-520dde48d0d52dbbbbe1710a3275fdd5355dd69d.tar.gz spark-520dde48d0d52dbbbbe1710a3275fdd5355dd69d.tar.bz2 spark-520dde48d0d52dbbbbe1710a3275fdd5355dd69d.zip |
[SPARK-14451][SQL] Move encoder definition into Aggregator interface
## What changes were proposed in this pull request?
When we first introduced Aggregators, we required the user of Aggregators to (implicitly) specify the encoders. It would actually make more sense to have the encoders be specified by the implementation of Aggregators, since each implementation should have the most state about how to encode its own data type.
Note that this simplifies the Java API because Java users no longer need to explicitly specify encoders for aggregators.
## How was this patch tested?
Updated unit tests.
Author: Reynold Xin <rxin@databricks.com>
Closes #12231 from rxin/SPARK-14451.
Diffstat (limited to 'sql/core/src/test')
-rw-r--r-- | sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java | 17 | ||||
-rw-r--r-- | sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala | 75 |
2 files changed, 53 insertions, 39 deletions
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java index 8cb174b906..0e49f871de 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java @@ -26,6 +26,7 @@ import org.junit.Test; import org.apache.spark.api.java.function.MapFunction; import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoder; import org.apache.spark.sql.Encoders; import org.apache.spark.sql.KeyValueGroupedDataset; import org.apache.spark.sql.expressions.Aggregator; @@ -39,12 +40,10 @@ public class JavaDatasetAggregatorSuite extends JavaDatasetAggregatorSuiteBase { public void testTypedAggregationAnonClass() { KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset(); - Dataset<Tuple2<String, Integer>> agged = - grouped.agg(new IntSumOf().toColumn(Encoders.INT(), Encoders.INT())); + Dataset<Tuple2<String, Integer>> agged = grouped.agg(new IntSumOf().toColumn()); Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList()); - Dataset<Tuple2<String, Integer>> agged2 = grouped.agg( - new IntSumOf().toColumn(Encoders.INT(), Encoders.INT())) + Dataset<Tuple2<String, Integer>> agged2 = grouped.agg(new IntSumOf().toColumn()) .as(Encoders.tuple(Encoders.STRING(), Encoders.INT())); Assert.assertEquals( Arrays.asList( @@ -73,6 +72,16 @@ public class JavaDatasetAggregatorSuite extends JavaDatasetAggregatorSuiteBase { public Integer finish(Integer reduction) { return reduction; } + + @Override + public Encoder<Integer> bufferEncoder() { + return Encoders.INT(); + } + + @Override + public Encoder<Integer> outputEncoder() { + return Encoders.INT(); + } } @Test diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index 08b3389ad9..3a7215ee39 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import scala.language.postfixOps +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.expressions.scala.typed import org.apache.spark.sql.functions._ @@ -26,74 +27,65 @@ import org.apache.spark.sql.test.SharedSQLContext object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, Long)] { - override def zero: (Long, Long) = (0, 0) - override def reduce(countAndSum: (Long, Long), input: (String, Int)): (Long, Long) = { (countAndSum._1 + 1, countAndSum._2 + input._2) } - override def merge(b1: (Long, Long), b2: (Long, Long)): (Long, Long) = { (b1._1 + b2._1, b1._2 + b2._2) } - override def finish(reduction: (Long, Long)): (Long, Long) = reduction + override def bufferEncoder: Encoder[(Long, Long)] = Encoders.product[(Long, Long)] + override def outputEncoder: Encoder[(Long, Long)] = Encoders.product[(Long, Long)] } + case class AggData(a: Int, b: String) + object ClassInputAgg extends Aggregator[AggData, Int, Int] { - /** A zero value for this aggregation. Should satisfy the property that any b + zero = b */ override def zero: Int = 0 - - /** - * Combine two values to produce a new value. For performance, the function may modify `b` and - * return it instead of constructing new object for b. - */ override def reduce(b: Int, a: AggData): Int = b + a.a - - /** - * Transform the output of the reduction. - */ override def finish(reduction: Int): Int = reduction - - /** - * Merge two intermediate values - */ override def merge(b1: Int, b2: Int): Int = b1 + b2 + override def bufferEncoder: Encoder[Int] = Encoders.scalaInt + override def outputEncoder: Encoder[Int] = Encoders.scalaInt } + object ComplexBufferAgg extends Aggregator[AggData, (Int, AggData), Int] { - /** A zero value for this aggregation. Should satisfy the property that any b + zero = b */ override def zero: (Int, AggData) = 0 -> AggData(0, "0") - - /** - * Combine two values to produce a new value. For performance, the function may modify `b` and - * return it instead of constructing new object for b. - */ override def reduce(b: (Int, AggData), a: AggData): (Int, AggData) = (b._1 + 1, a) - - /** - * Transform the output of the reduction. - */ override def finish(reduction: (Int, AggData)): Int = reduction._1 - - /** - * Merge two intermediate values - */ override def merge(b1: (Int, AggData), b2: (Int, AggData)): (Int, AggData) = (b1._1 + b2._1, b1._2) + override def bufferEncoder: Encoder[(Int, AggData)] = Encoders.product[(Int, AggData)] + override def outputEncoder: Encoder[Int] = Encoders.scalaInt } + object NameAgg extends Aggregator[AggData, String, String] { def zero: String = "" - def reduce(b: String, a: AggData): String = a.b + b - def merge(b1: String, b2: String): String = b1 + b2 - def finish(r: String): String = r + override def bufferEncoder: Encoder[String] = Encoders.STRING + override def outputEncoder: Encoder[String] = Encoders.STRING +} + + +class ParameterizedTypeSum[IN, OUT : Numeric : Encoder](f: IN => OUT) + extends Aggregator[IN, OUT, OUT] { + + private val numeric = implicitly[Numeric[OUT]] + override def zero: OUT = numeric.zero + override def reduce(b: OUT, a: IN): OUT = numeric.plus(b, f(a)) + override def merge(b1: OUT, b2: OUT): OUT = numeric.plus(b1, b2) + override def finish(reduction: OUT): OUT = reduction + override def bufferEncoder: Encoder[OUT] = implicitly[Encoder[OUT]] + override def outputEncoder: Encoder[OUT] = implicitly[Encoder[OUT]] } + class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -187,6 +179,19 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { ("a", 2.0, 2L, 4.0, 4L), ("b", 3.0, 1L, 3.0, 3L)) } + test("generic typed sum") { + val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS() + checkDataset( + ds.groupByKey(_._1) + .agg(new ParameterizedTypeSum[(String, Int), Double](_._2.toDouble).toColumn), + ("a", 4.0), ("b", 3.0)) + + checkDataset( + ds.groupByKey(_._1) + .agg(new ParameterizedTypeSum((x: (String, Int)) => x._2.toInt).toColumn), + ("a", 4), ("b", 3)) + } + test("SPARK-12555 - result should not be corrupted after input columns are reordered") { val ds = sql("SELECT 'Some String' AS b, 1279869254 AS a").as[AggData] |