path: root/sql/core/src/test/scala
author     Reynold Xin <rxin@databricks.com>        2016-08-18 16:37:25 +0800
committer  Wenchen Fan <wenchen@databricks.com>     2016-08-18 16:37:25 +0800
commit     1748f824101870b845dbbd118763c6885744f98a (patch)
tree       5d82b921bdad119f417c3e8c291dab2f37491f60 /sql/core/src/test/scala
parent     3e6ef2e8a435a91b6a76876e9833917e5aa0945e (diff)
[SPARK-16391][SQL] Support partial aggregation for reduceGroups
## What changes were proposed in this pull request?

This patch introduces a new private ReduceAggregator interface that is a subclass of Aggregator. ReduceAggregator only requires a single associative and commutative reduce function. ReduceAggregator is also used to implement KeyValueGroupedDataset.reduceGroups in order to support partial aggregation.

Note that the pull request was initially done by viirya.

## How was this patch tested?

Covered by the original tests for reduceGroups, as well as a new test suite for ReduceAggregator.

Author: Reynold Xin <rxin@databricks.com>
Author: Liang-Chi Hsieh <simonh@tw.ibm.com>

Closes #14576 from rxin/reduceAggregator.
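To make the description concrete, below is a minimal, hedged sketch of what an Aggregator built around a single reduce function could look like, reconstructed from the behaviour the new test suite exercises: a `(Boolean, T)` buffer whose flag records whether any row has been seen. The class name `ReduceAggregatorSketch` and the use of the public `Encoders.tuple` helper are illustrative, not the committed (private) implementation.

```scala
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

// Sketch: an Aggregator driven by a single associative, commutative reduce
// function. The buffer is (Boolean, T); the flag records whether any input
// row has been folded in, so an empty group can be rejected in finish().
class ReduceAggregatorSketch[T](func: (T, T) => T)(implicit enc: Encoder[T])
  extends Aggregator[T, (Boolean, T), T] {

  // No value seen yet; the payload slot is a placeholder.
  override def zero: (Boolean, T) = (false, null.asInstanceOf[T])

  // Fold one input value into the buffer.
  override def reduce(b: (Boolean, T), a: T): (Boolean, T) =
    if (b._1) (true, func(b._2, a)) else (true, a)

  // Combine two partial buffers; an empty side contributes nothing.
  override def merge(b1: (Boolean, T), b2: (Boolean, T)): (Boolean, T) =
    if (!b1._1) b2
    else if (!b2._1) b1
    else (true, func(b1._2, b2._2))

  // Reducing an empty group has no defined result.
  override def finish(reduction: (Boolean, T)): T = {
    if (!reduction._1) {
      throw new IllegalStateException("ReduceAggregator requires at least one input row")
    }
    reduction._2
  }

  override def bufferEncoder: Encoder[(Boolean, T)] =
    Encoders.tuple(Encoders.scalaBoolean, enc)
  override def outputEncoder: Encoder[T] = enc
}
```

The empty flag in the buffer is what makes partial aggregation safe: a map-side partial buffer that saw no rows merges away without contributing a value, and only a group with no rows at all fails in finish.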
Diffstat (limited to 'sql/core/src/test/scala')
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/expressions/ReduceAggregatorSuite.scala  73
1 file changed, 73 insertions(+), 0 deletions(-)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ReduceAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ReduceAggregatorSuite.scala
new file mode 100644
index 0000000000..d826d3f54d
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ReduceAggregatorSuite.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.expressions
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.Encoders
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+
+class ReduceAggregatorSuite extends SparkFunSuite {
+
+ test("zero value") {
+ val encoder: ExpressionEncoder[Int] = ExpressionEncoder()
+ val func = (v1: Int, v2: Int) => v1 + v2
+ val aggregator: ReduceAggregator[Int] = new ReduceAggregator(func)(Encoders.scalaInt)
+ assert(aggregator.zero == (false, null))
+ }
+
+ test("reduce, merge and finish") {
+ val encoder: ExpressionEncoder[Int] = ExpressionEncoder()
+ val func = (v1: Int, v2: Int) => v1 + v2
+ val aggregator: ReduceAggregator[Int] = new ReduceAggregator(func)(Encoders.scalaInt)
+
+ val firstReduce = aggregator.reduce(aggregator.zero, 1)
+ assert(firstReduce == (true, 1))
+
+ val secondReduce = aggregator.reduce(firstReduce, 2)
+ assert(secondReduce == (true, 3))
+
+ val thirdReduce = aggregator.reduce(secondReduce, 3)
+ assert(thirdReduce == (true, 6))
+
+ val mergeWithZero1 = aggregator.merge(aggregator.zero, firstReduce)
+ assert(mergeWithZero1 == (true, 1))
+
+ val mergeWithZero2 = aggregator.merge(secondReduce, aggregator.zero)
+ assert(mergeWithZero2 == (true, 3))
+
+ val mergeTwoReduced = aggregator.merge(firstReduce, secondReduce)
+ assert(mergeTwoReduced == (true, 4))
+
+ assert(aggregator.finish(firstReduce) == 1)
+ assert(aggregator.finish(secondReduce) == 3)
+ assert(aggregator.finish(thirdReduce) == 6)
+ assert(aggregator.finish(mergeWithZero1) == 1)
+ assert(aggregator.finish(mergeWithZero2) == 3)
+ assert(aggregator.finish(mergeTwoReduced) == 4)
+ }
+
+ test("requires at least one input row") {
+ val encoder: ExpressionEncoder[Int] = ExpressionEncoder()
+ val func = (v1: Int, v2: Int) => v1 + v2
+ val aggregator: ReduceAggregator[Int] = new ReduceAggregator(func)(Encoders.scalaInt)
+
+ intercept[IllegalStateException] {
+ aggregator.finish(aggregator.zero)
+ }
+ }
+}
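For context, here is a hedged usage sketch (not part of this commit) of the user-facing API whose physical execution this change improves, written in spark-shell style; the local-mode session settings and the sample data are illustrative.

```scala
import org.apache.spark.sql.SparkSession

// Illustrative session setup; not part of the commit.
val spark = SparkSession.builder()
  .master("local[*]")
  .appName("reduceGroups-example")
  .getOrCreate()
import spark.implicits._

val ds = Seq(1, 2, 3, 4, 5).toDS()

// Group by parity and reduce each group with an associative, commutative
// function. With partial aggregation, the sums can be combined map-side
// before the shuffle instead of shipping every value to the reducer.
val sums = ds.groupByKey(_ % 2).reduceGroups(_ + _)

// The collected result contains (0, 6) and (1, 9): evens sum to 6, odds to 9.
sums.collect()
```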