aboutsummaryrefslogtreecommitdiff
path: root/sql/core
diff options
context:
space:
mode:
authorReynold Xin <rxin@databricks.com>2016-08-18 16:37:25 +0800
committerWenchen Fan <wenchen@databricks.com>2016-08-18 16:37:25 +0800
commit1748f824101870b845dbbd118763c6885744f98a (patch)
tree5d82b921bdad119f417c3e8c291dab2f37491f60 /sql/core
parent3e6ef2e8a435a91b6a76876e9833917e5aa0945e (diff)
downloadspark-1748f824101870b845dbbd118763c6885744f98a.tar.gz
spark-1748f824101870b845dbbd118763c6885744f98a.tar.bz2
spark-1748f824101870b845dbbd118763c6885744f98a.zip
[SPARK-16391][SQL] Support partial aggregation for reduceGroups
## What changes were proposed in this pull request? This patch introduces a new private ReduceAggregator interface that is a subclass of Aggregator. ReduceAggregator only requires a single associative and commutative reduce function. ReduceAggregator is also used to implement KeyValueGroupedDataset.reduceGroups in order to support partial aggregation. Note that the pull request was initially done by viirya. ## How was this patch tested? Covered by original tests for reduceGroups, as well as a new test suite for ReduceAggregator. Author: Reynold Xin <rxin@databricks.com> Author: Liang-Chi Hsieh <simonh@tw.ibm.com> Closes #14576 from rxin/reduceAggregator.
Diffstat (limited to 'sql/core')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala10
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala68
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/expressions/ReduceAggregatorSuite.scala73
3 files changed, 146 insertions, 5 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index 65a725f3d4..61a3e6e0bc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -21,10 +21,11 @@ import scala.collection.JavaConverters._
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.function._
-import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, OuterScopes}
+import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CreateStruct}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.expressions.ReduceAggregator
/**
* :: Experimental ::
@@ -177,10 +178,9 @@ class KeyValueGroupedDataset[K, V] private[sql](
* @since 1.6.0
*/
def reduceGroups(f: (V, V) => V): Dataset[(K, V)] = {
- val func = (key: K, it: Iterator[V]) => Iterator((key, it.reduce(f)))
-
- implicit val resultEncoder = ExpressionEncoder.tuple(kExprEnc, vExprEnc)
- flatMapGroups(func)
+ val vEncoder = encoderFor[V]
+ val aggregator: TypedColumn[V, V] = new ReduceAggregator[V](f)(vEncoder).toColumn
+ agg(aggregator)
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala
new file mode 100644
index 0000000000..174378304d
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.expressions
+
+import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+
+/**
+ * An aggregator that uses a single associative and commutative reduce function. This reduce
+ * function can be used to go through all input values and reduces them to a single value.
+ * If there is no input, a null value is returned.
+ *
+ * This class currently assumes there is at least one input row.
+ */
+private[sql] class ReduceAggregator[T: Encoder](func: (T, T) => T)
+ extends Aggregator[T, (Boolean, T), T] {
+
+ private val encoder = implicitly[Encoder[T]]
+
+ override def zero: (Boolean, T) = (false, null.asInstanceOf[T])
+
+ override def bufferEncoder: Encoder[(Boolean, T)] =
+ ExpressionEncoder.tuple(
+ ExpressionEncoder[Boolean](),
+ encoder.asInstanceOf[ExpressionEncoder[T]])
+
+ override def outputEncoder: Encoder[T] = encoder
+
+ override def reduce(b: (Boolean, T), a: T): (Boolean, T) = {
+ if (b._1) {
+ (true, func(b._2, a))
+ } else {
+ (true, a)
+ }
+ }
+
+ override def merge(b1: (Boolean, T), b2: (Boolean, T)): (Boolean, T) = {
+ if (!b1._1) {
+ b2
+ } else if (!b2._1) {
+ b1
+ } else {
+ (true, func(b1._2, b2._2))
+ }
+ }
+
+ override def finish(reduction: (Boolean, T)): T = {
+ if (!reduction._1) {
+ throw new IllegalStateException("ReduceAggregator requires at least one input row")
+ }
+ reduction._2
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ReduceAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ReduceAggregatorSuite.scala
new file mode 100644
index 0000000000..d826d3f54d
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ReduceAggregatorSuite.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.expressions
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.Encoders
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+
+class ReduceAggregatorSuite extends SparkFunSuite {
+
+ test("zero value") {
+ val encoder: ExpressionEncoder[Int] = ExpressionEncoder()
+ val func = (v1: Int, v2: Int) => v1 + v2
+ val aggregator: ReduceAggregator[Int] = new ReduceAggregator(func)(Encoders.scalaInt)
+ assert(aggregator.zero == (false, null))
+ }
+
+ test("reduce, merge and finish") {
+ val encoder: ExpressionEncoder[Int] = ExpressionEncoder()
+ val func = (v1: Int, v2: Int) => v1 + v2
+ val aggregator: ReduceAggregator[Int] = new ReduceAggregator(func)(Encoders.scalaInt)
+
+ val firstReduce = aggregator.reduce(aggregator.zero, 1)
+ assert(firstReduce == (true, 1))
+
+ val secondReduce = aggregator.reduce(firstReduce, 2)
+ assert(secondReduce == (true, 3))
+
+ val thirdReduce = aggregator.reduce(secondReduce, 3)
+ assert(thirdReduce == (true, 6))
+
+ val mergeWithZero1 = aggregator.merge(aggregator.zero, firstReduce)
+ assert(mergeWithZero1 == (true, 1))
+
+ val mergeWithZero2 = aggregator.merge(secondReduce, aggregator.zero)
+ assert(mergeWithZero2 == (true, 3))
+
+ val mergeTwoReduced = aggregator.merge(firstReduce, secondReduce)
+ assert(mergeTwoReduced == (true, 4))
+
+ assert(aggregator.finish(firstReduce)== 1)
+ assert(aggregator.finish(secondReduce) == 3)
+ assert(aggregator.finish(thirdReduce) == 6)
+ assert(aggregator.finish(mergeWithZero1) == 1)
+ assert(aggregator.finish(mergeWithZero2) == 3)
+ assert(aggregator.finish(mergeTwoReduced) == 4)
+ }
+
+ test("requires at least one input row") {
+ val encoder: ExpressionEncoder[Int] = ExpressionEncoder()
+ val func = (v1: Int, v2: Int) => v1 + v2
+ val aggregator: ReduceAggregator[Int] = new ReduceAggregator(func)(Encoders.scalaInt)
+
+ intercept[IllegalStateException] {
+ aggregator.finish(aggregator.zero)
+ }
+ }
+}