aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHerman van Hovell <hvanhovell@questtec.nl>2016-05-12 13:56:00 -0700
committerReynold Xin <rxin@databricks.com>2016-05-12 13:56:00 -0700
commitbb1362eb3b36b553dca246b95f59ba7fd8adcc8a (patch)
treed47608b04e9bc54f6a77cf6648f157982aa8788a
parenta57aadae84aca27e5f02ac0bd64fd0ea34a64b61 (diff)
downloadspark-bb1362eb3b36b553dca246b95f59ba7fd8adcc8a.tar.gz
spark-bb1362eb3b36b553dca246b95f59ba7fd8adcc8a.tar.bz2
spark-bb1362eb3b36b553dca246b95f59ba7fd8adcc8a.zip
[SPARK-10605][SQL] Create native collect_list/collect_set aggregates
## What changes were proposed in this pull request? We currently use the Hive implementations for the collect_list/collect_set aggregate functions. This has a few major drawbacks: the use of HiveUDAF (which has quite a bit of overhead) and the lack of support for struct datatypes. This PR adds native implementation of these functions to Spark. The size of the collected list/set may vary, this means we cannot use the fast, Tungsten, aggregation path to perform the aggregation, and that we fallback to the slower sort based path. Another big issue with these operators is that when the size of the collected list/set grows too large, we can start experiencing large GC pauzes and OOMEs. This `collect*` aggregates implemented in this PR rely on the sort based aggregate path for correctness. They maintain their own internal buffer which holds the rows for one group at a time. The sortbased aggregation path is triggered by disabling `partialAggregation` for these aggregates (which is kinda funny); this technique is also employed in `org.apache.spark.sql.hiveHiveUDAFFunction`. I have done some performance testing: ```scala import org.apache.spark.sql.{Dataset, Row} sql("create function collect_list2 as 'org.apache.hadoop.hive.ql.udf.generic.GenericUDAFCollectList'") val df = range(0, 10000000).select($"id", (rand(213123L) * 100000).cast("int").as("grp")) df.select(countDistinct($"grp")).show def benchmark(name: String, plan: Dataset[Row], maxItr: Int = 5): Unit = { // Do not measure planning. plan1.queryExecution.executedPlan // Execute the plan a number of times and average the result. val start = System.nanoTime var i = 0 while (i < maxItr) { plan.rdd.foreach(row => Unit) i += 1 } val time = (System.nanoTime - start) / (maxItr * 1000000L) println(s"[$name] $maxItr iterations completed in an average time of $time ms.") } val plan1 = df.groupBy($"grp").agg(collect_list($"id")) val plan2 = df.groupBy($"grp").agg(callUDF("collect_list2", $"id")) benchmark("Spark collect_list", plan1) ... > [Spark collect_list] 5 iterations completed in an average time of 3371 ms. benchmark("Hive collect_list", plan2) ... > [Hive collect_list] 5 iterations completed in an average time of 9109 ms. ``` Performance is improved by a factor 2-3. ## How was this patch tested? Added tests to `DataFrameAggregateSuite`. Author: Herman van Hovell <hvanhovell@questtec.nl> Closes #12874 from hvanhovell/implode.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala119
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/functions.scala12
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala26
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala16
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala11
6 files changed, 149 insertions, 37 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index ac05dd3d0e..c459fe5878 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -252,6 +252,8 @@ object FunctionRegistry {
expression[VarianceSamp]("variance"),
expression[VariancePop]("var_pop"),
expression[VarianceSamp]("var_samp"),
+ expression[CollectList]("collect_list"),
+ expression[CollectSet]("collect_set"),
// string functions
expression[Ascii]("ascii"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
new file mode 100644
index 0000000000..1f4ff9c4b1
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import scala.collection.generic.Growable
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.GenericArrayData
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.types._
+
+/**
+ * The Collect aggregate function collects all seen expression values into a list of values.
+ *
+ * The operator is bound to the slower sort based aggregation path because the number of
+ * elements (and their memory usage) can not be determined in advance. This also means that the
+ * collected elements are stored on heap, and that too many elements can cause GC pauses and
+ * eventually Out of Memory Errors.
+ */
+abstract class Collect extends ImperativeAggregate {
+
+ val child: Expression
+
+ override def children: Seq[Expression] = child :: Nil
+
+ override def nullable: Boolean = true
+
+ override def dataType: DataType = ArrayType(child.dataType)
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
+
+ override def supportsPartial: Boolean = false
+
+ override def aggBufferAttributes: Seq[AttributeReference] = Nil
+
+ override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
+
+ override def inputAggBufferAttributes: Seq[AttributeReference] = Nil
+
+ protected[this] val buffer: Growable[Any] with Iterable[Any]
+
+ override def initialize(b: MutableRow): Unit = {
+ buffer.clear()
+ }
+
+ override def update(b: MutableRow, input: InternalRow): Unit = {
+ buffer += child.eval(input)
+ }
+
+ override def merge(buffer: MutableRow, input: InternalRow): Unit = {
+ sys.error("Collect cannot be used in partial aggregations.")
+ }
+
+ override def eval(input: InternalRow): Any = {
+ new GenericArrayData(buffer.toArray)
+ }
+}
+
+/**
+ * Collect a list of elements.
+ */
+@ExpressionDescription(
+ usage = "_FUNC_(expr) - Collects and returns a list of non-unique elements.")
+case class CollectList(
+ child: Expression,
+ mutableAggBufferOffset: Int = 0,
+ inputAggBufferOffset: Int = 0) extends Collect {
+
+ def this(child: Expression) = this(child, 0, 0)
+
+ override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+ copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+ override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+ copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+ override def prettyName: String = "collect_list"
+
+ override protected[this] val buffer: mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty
+}
+
+/**
+ * Collect a list of unique elements.
+ */
+@ExpressionDescription(
+ usage = "_FUNC_(expr) - Collects and returns a set of unique elements.")
+case class CollectSet(
+ child: Expression,
+ mutableAggBufferOffset: Int = 0,
+ inputAggBufferOffset: Int = 0) extends Collect {
+
+ def this(child: Expression) = this(child, 0, 0)
+
+ override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+ copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+ override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+ copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+ override def prettyName: String = "collect_set"
+
+ override protected[this] val buffer: mutable.HashSet[Any] = mutable.HashSet.empty
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 3e295c20b6..07f55042ee 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -195,18 +195,14 @@ object functions {
/**
* Aggregate function: returns a list of objects with duplicates.
*
- * For now this is an alias for the collect_list Hive UDAF.
- *
* @group agg_funcs
* @since 1.6.0
*/
- def collect_list(e: Column): Column = callUDF("collect_list", e)
+ def collect_list(e: Column): Column = withAggregateFunction { CollectList(e.expr) }
/**
* Aggregate function: returns a list of objects with duplicates.
*
- * For now this is an alias for the collect_list Hive UDAF.
- *
* @group agg_funcs
* @since 1.6.0
*/
@@ -215,18 +211,14 @@ object functions {
/**
* Aggregate function: returns a set of objects with duplicate elements eliminated.
*
- * For now this is an alias for the collect_set Hive UDAF.
- *
* @group agg_funcs
* @since 1.6.0
*/
- def collect_set(e: Column): Column = callUDF("collect_set", e)
+ def collect_set(e: Column): Column = withAggregateFunction { CollectSet(e.expr) }
/**
* Aggregate function: returns a set of objects with duplicate elements eliminated.
*
- * For now this is an alias for the collect_set Hive UDAF.
- *
* @group agg_funcs
* @since 1.6.0
*/
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 8a99866a33..69a990789b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -431,6 +431,32 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
Row(null, null, null, null, null))
}
+ test("collect functions") {
+ val df = Seq((1, 2), (2, 2), (3, 4)).toDF("a", "b")
+ checkAnswer(
+ df.select(collect_list($"a"), collect_list($"b")),
+ Seq(Row(Seq(1, 2, 3), Seq(2, 2, 4)))
+ )
+ checkAnswer(
+ df.select(collect_set($"a"), collect_set($"b")),
+ Seq(Row(Seq(1, 2, 3), Seq(2, 4)))
+ )
+ }
+
+ test("collect functions structs") {
+ val df = Seq((1, 2, 2), (2, 2, 2), (3, 4, 1))
+ .toDF("a", "x", "y")
+ .select($"a", struct($"x", $"y").as("b"))
+ checkAnswer(
+ df.select(collect_list($"a"), sort_array(collect_list($"b"))),
+ Seq(Row(Seq(1, 2, 3), Seq(Row(2, 2), Row(2, 2), Row(4, 1))))
+ )
+ checkAnswer(
+ df.select(collect_set($"a"), sort_array(collect_set($"b"))),
+ Seq(Row(Seq(1, 2, 3), Seq(Row(2, 2), Row(4, 1))))
+ )
+ }
+
test("SPARK-14664: Decimal sum/avg over window should work.") {
checkAnswer(
spark.sql("select sum(a) over () from values 1.0, 2.0, 3.0 T(a)"),
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
index 75a252ccba..4f8aac8c2f 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
@@ -222,20 +222,4 @@ private[sql] class HiveSessionCatalog(
}
}
}
-
- // Pre-load a few commonly used Hive built-in functions.
- HiveSessionCatalog.preloadedHiveBuiltinFunctions.foreach {
- case (functionName, clazz) =>
- val builder = makeFunctionBuilder(functionName, clazz)
- val info = new ExpressionInfo(clazz.getCanonicalName, functionName)
- createTempFunction(functionName, info, builder, ignoreIfExists = false)
- }
-}
-
-private[sql] object HiveSessionCatalog {
- // This is the list of Hive's built-in functions that are commonly used and we want to
- // pre-load when we create the FunctionRegistry.
- val preloadedHiveBuiltinFunctions =
- ("collect_set", classOf[org.apache.hadoop.hive.ql.udf.generic.GenericUDAFCollectSet]) ::
- ("collect_list", classOf[org.apache.hadoop.hive.ql.udf.generic.GenericUDAFCollectList]) :: Nil
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala
index 57f96e725a..cc41c04c71 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala
@@ -58,17 +58,6 @@ class HiveDataFrameAnalyticsSuite extends QueryTest with TestHiveSingleton with
)
}
- test("collect functions") {
- checkAnswer(
- testData.select(collect_list($"a"), collect_list($"b")),
- Seq(Row(Seq(1, 2, 3), Seq(2, 2, 4)))
- )
- checkAnswer(
- testData.select(collect_set($"a"), collect_set($"b")),
- Seq(Row(Seq(1, 2, 3), Seq(2, 4)))
- )
- }
-
test("cube") {
checkAnswer(
testData.cube($"a" + $"b", $"b").agg(sum($"a" - $"b")),