-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala      2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala  119
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/functions.scala                                   12
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala                     26
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala                     16
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala            11
6 files changed, 149 insertions, 37 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index ac05dd3d0e..c459fe5878 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -252,6 +252,8 @@ object FunctionRegistry {
expression[VarianceSamp]("variance"),
expression[VariancePop]("var_pop"),
expression[VarianceSamp]("var_samp"),
+ expression[CollectList]("collect_list"),
+ expression[CollectSet]("collect_set"),

// string functions
expression[Ascii]("ascii"),
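
With both expressions registered in the built-in FunctionRegistry, collect_list and collect_set now resolve in plain SQL without a Hive dependency. A minimal usage sketch, assuming a SparkSession named `spark` and mirroring the VALUES syntax used in the test suites below:

    spark.sql("select collect_list(a), collect_set(a) from values 1, 2, 2 t(a)").show()
    // collect_list keeps duplicates: [1, 2, 2]
    // collect_set drops them ([1, 2]); element order is not guaranteed
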
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
new file mode 100644
index 0000000000..1f4ff9c4b1
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import scala.collection.generic.Growable
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.GenericArrayData
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.types._
+
+/**
+ * The Collect aggregate function collects all seen expression values into a list of values.
+ *
+ * The operator is bound to the slower sort-based aggregation path because the number of
+ * elements (and their memory usage) cannot be determined in advance. This also means that
+ * the collected elements are stored on the heap, and that too many elements can cause GC
+ * pauses and, eventually, OutOfMemoryErrors.
+ */
+abstract class Collect extends ImperativeAggregate {
+
+ val child: Expression
+
+ override def children: Seq[Expression] = child :: Nil
+
+ override def nullable: Boolean = true
+
+ override def dataType: DataType = ArrayType(child.dataType)
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
+
+ override def supportsPartial: Boolean = false
+
+ override def aggBufferAttributes: Seq[AttributeReference] = Nil
+
+ override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
+
+ override def inputAggBufferAttributes: Seq[AttributeReference] = Nil
+
+ protected[this] val buffer: Growable[Any] with Iterable[Any]
+
+ override def initialize(b: MutableRow): Unit = {
+ buffer.clear()
+ }
+
+ override def update(b: MutableRow, input: InternalRow): Unit = {
+ buffer += child.eval(input)
+ }
+
+ override def merge(buffer: MutableRow, input: InternalRow): Unit = {
+ sys.error("Collect cannot be used in partial aggregations.")
+ }
+
+ override def eval(input: InternalRow): Any = {
+ new GenericArrayData(buffer.toArray)
+ }
+}
+
+/**
+ * Collect a list of elements.
+ */
+@ExpressionDescription(
+ usage = "_FUNC_(expr) - Collects and returns a list of non-unique elements.")
+case class CollectList(
+ child: Expression,
+ mutableAggBufferOffset: Int = 0,
+ inputAggBufferOffset: Int = 0) extends Collect {
+
+ def this(child: Expression) = this(child, 0, 0)
+
+ override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+ copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+ override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+ copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+ override def prettyName: String = "collect_list"
+
+ override protected[this] val buffer: mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty
+}
+
+/**
+ * Collect a set of unique elements.
+ */
+@ExpressionDescription(
+ usage = "_FUNC_(expr) - Collects and returns a set of unique elements.")
+case class CollectSet(
+ child: Expression,
+ mutableAggBufferOffset: Int = 0,
+ inputAggBufferOffset: Int = 0) extends Collect {
+
+ def this(child: Expression) = this(child, 0, 0)
+
+ override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+ copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+ override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+ copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+ override def prettyName: String = "collect_set"
+
+ override protected[this] val buffer: mutable.HashSet[Any] = mutable.HashSet.empty
+}
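
Because supportsPartial is false and the buffer above lives on the executor heap, collecting an unbounded group can trigger exactly the GC pauses and OutOfMemoryErrors the class comment warns about. A hedged mitigation sketch (the `events` DataFrame and its columns are illustrative, not part of this change; assumes `import spark.implicits._`):

    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions.{collect_list, row_number}

    // Illustrative only: cap the per-key input so the ArrayBuffer/HashSet
    // buffer cannot grow without bound.
    val w = Window.partitionBy($"key").orderBy($"ts")
    val capped = events
      .withColumn("rn", row_number().over(w))
      .where($"rn" <= 1000)
      .groupBy($"key")
      .agg(collect_list($"value"))
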
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 3e295c20b6..07f55042ee 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -195,18 +195,14 @@ object functions {
/**
* Aggregate function: returns a list of objects with duplicates.
*
- * For now this is an alias for the collect_list Hive UDAF.
- *
* @group agg_funcs
* @since 1.6.0
*/
- def collect_list(e: Column): Column = callUDF("collect_list", e)
+ def collect_list(e: Column): Column = withAggregateFunction { CollectList(e.expr) }

/**
* Aggregate function: returns a list of objects with duplicates.
*
- * For now this is an alias for the collect_list Hive UDAF.
- *
* @group agg_funcs
* @since 1.6.0
*/
@@ -215,18 +211,14 @@ object functions {
/**
* Aggregate function: returns a set of objects with duplicate elements eliminated.
*
- * For now this is an alias for the collect_set Hive UDAF.
- *
* @group agg_funcs
* @since 1.6.0
*/
- def collect_set(e: Column): Column = callUDF("collect_set", e)
+ def collect_set(e: Column): Column = withAggregateFunction { CollectSet(e.expr) }

/**
* Aggregate function: returns a set of objects with duplicate elements eliminated.
*
- * For now this is an alias for the collect_set Hive UDAF.
- *
* @group agg_funcs
* @since 1.6.0
*/
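
With the helpers now building native expressions, the DataFrame API no longer requires Hive support. A short sketch based on the data used in the new DataFrameAggregateSuite tests (assumes `import spark.implicits._`); sort_array makes the non-deterministic collection order stable, as the struct test below also does:

    import org.apache.spark.sql.functions.{collect_list, collect_set, sort_array}

    val df = Seq((1, 2), (2, 2), (3, 4)).toDF("a", "b")
    df.agg(
      sort_array(collect_list($"b")),  // [2, 2, 4], duplicates kept
      sort_array(collect_set($"b"))    // [2, 4], duplicates removed
    ).show()
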
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 8a99866a33..69a990789b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -431,6 +431,32 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
Row(null, null, null, null, null))
}

+ test("collect functions") {
+ val df = Seq((1, 2), (2, 2), (3, 4)).toDF("a", "b")
+ checkAnswer(
+ df.select(collect_list($"a"), collect_list($"b")),
+ Seq(Row(Seq(1, 2, 3), Seq(2, 2, 4)))
+ )
+ checkAnswer(
+ df.select(collect_set($"a"), collect_set($"b")),
+ Seq(Row(Seq(1, 2, 3), Seq(2, 4)))
+ )
+ }
+
+ test("collect functions structs") {
+ val df = Seq((1, 2, 2), (2, 2, 2), (3, 4, 1))
+ .toDF("a", "x", "y")
+ .select($"a", struct($"x", $"y").as("b"))
+ checkAnswer(
+ df.select(collect_list($"a"), sort_array(collect_list($"b"))),
+ Seq(Row(Seq(1, 2, 3), Seq(Row(2, 2), Row(2, 2), Row(4, 1))))
+ )
+ checkAnswer(
+ df.select(collect_set($"a"), sort_array(collect_set($"b"))),
+ Seq(Row(Seq(1, 2, 3), Seq(Row(2, 2), Row(4, 1))))
+ )
+ }
+
test("SPARK-14664: Decimal sum/avg over window should work.") {
checkAnswer(
spark.sql("select sum(a) over () from values 1.0, 2.0, 3.0 T(a)"),
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
index 75a252ccba..4f8aac8c2f 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
@@ -222,20 +222,4 @@ private[sql] class HiveSessionCatalog(
}
}
}
-
- // Pre-load a few commonly used Hive built-in functions.
- HiveSessionCatalog.preloadedHiveBuiltinFunctions.foreach {
- case (functionName, clazz) =>
- val builder = makeFunctionBuilder(functionName, clazz)
- val info = new ExpressionInfo(clazz.getCanonicalName, functionName)
- createTempFunction(functionName, info, builder, ignoreIfExists = false)
- }
-}
-
-private[sql] object HiveSessionCatalog {
- // This is the list of Hive's built-in functions that are commonly used and we want to
- // pre-load when we create the FunctionRegistry.
- val preloadedHiveBuiltinFunctions =
- ("collect_set", classOf[org.apache.hadoop.hive.ql.udf.generic.GenericUDAFCollectSet]) ::
- ("collect_list", classOf[org.apache.hadoop.hive.ql.udf.generic.GenericUDAFCollectList]) :: Nil
}
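
The pre-loading block is no longer needed: function lookup in a Hive-enabled session finds the native expressions registered in FunctionRegistry before falling back to Hive, so GenericUDAFCollectSet and GenericUDAFCollectList are never loaded. A quick hedged check, assuming a Hive-enabled SparkSession named `spark`:

    // Resolves to the native CollectSet expression, not the Hive UDAF.
    spark.sql("select collect_set(a) from values 1, 1, 2 t(a)").show()
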
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala
index 57f96e725a..cc41c04c71 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala
@@ -58,17 +58,6 @@ class HiveDataFrameAnalyticsSuite extends QueryTest with TestHiveSingleton with
)
}

- test("collect functions") {
- checkAnswer(
- testData.select(collect_list($"a"), collect_list($"b")),
- Seq(Row(Seq(1, 2, 3), Seq(2, 2, 4)))
- )
- checkAnswer(
- testData.select(collect_set($"a"), collect_set($"b")),
- Seq(Row(Seq(1, 2, 3), Seq(2, 4)))
- )
- }
-
test("cube") {
checkAnswer(
testData.cube($"a" + $"b", $"b").agg(sum($"a" - $"b")),