author     Liang-Chi Hsieh <viirya@gmail.com>    2017-01-03 22:11:54 +0800
committer  Wenchen Fan <wenchen@databricks.com>  2017-01-03 22:11:54 +0800
commit     52636226dc8cb7fcf00381d65e280d651b25a382
tree       5c60b2b7fccd4c3310eeba7e76e16c6935e2f274
parent     e5c307c50a660f706799f1f7f6890bcec888d96b
[SPARK-18932][SQL] Support partial aggregation for collect_set/collect_list
## What changes were proposed in this pull request?

Currently the collect_set/collect_list aggregation expressions don't support partial aggregation. This patch enables partial aggregation for them.

## How was this patch tested?

Jenkins tests.

Please review http://spark.apache.org/contributing.html before opening a pull request.

Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #16371 from viirya/collect-partial-support.
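For context, a minimal usage sketch (not part of this patch; the session setup and sample data are illustrative) of an aggregate that benefits from the change:

```scala
// Illustrative only: a query whose collect_set aggregate can now use a
// partial (map-side) phase before the shuffle.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.collect_set

val spark = SparkSession.builder().master("local[*]").appName("collect-demo").getOrCreate()
import spark.implicits._

val df = Seq((1, "a"), (1, "b"), (2, "c"), (1, "a")).toDF("k", "v")
// Before this patch, Collect declared supportsPartial = false; with it, the
// planner may pre-aggregate each partition and merge the partial buffers.
df.groupBy($"k").agg(collect_set($"v")).show()
```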
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala         |  7
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala            | 62
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala         |  4
 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala |  9
 4 files changed, 39 insertions(+), 43 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
index 2f4d68d179..eaeb010b0e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
@@ -33,10 +33,9 @@ import org.apache.spark.util.collection.OpenHashMap
* The Percentile aggregate function returns the exact percentile(s) of numeric column `expr` at
* the given percentage(s) with value range in [0.0, 1.0].
*
- * The operator is bound to the slower sort based aggregation path because the number of elements
- * and their partial order cannot be determined in advance. Therefore we have to store all the
- * elements in memory, and that too many elements can cause GC paused and eventually OutOfMemory
- * Errors.
+ * Because the number of elements and their partial order cannot be determined in advance, we
+ * have to store all the elements in memory, and too many elements can cause GC pauses and
+ * eventually OutOfMemoryErrors.
*
* @param child child expression that produce numeric column value with `child.eval(inputRow)`
* @param percentageExpression Expression that represents a single percentage value or an array of
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
index b176e2a128..411f058510 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst.expressions.aggregate
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
+
import scala.collection.generic.Growable
import scala.collection.mutable
@@ -27,14 +29,12 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._
/**
- * The Collect aggregate function collects all seen expression values into a list of values.
+ * A base class for collect_list and collect_set aggregate functions.
*
- * The operator is bound to the slower sort based aggregation path because the number of
- * elements (and their memory usage) can not be determined in advance. This also means that the
- * collected elements are stored on heap, and that too many elements can cause GC pauses and
- * eventually Out of Memory Errors.
+ * We have to store all the collected elements in memory, and too many elements can cause
+ * GC pauses and eventually OutOfMemoryErrors.
*/
-abstract class Collect extends ImperativeAggregate {
+abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImperativeAggregate[T] {
val child: Expression
@@ -44,40 +44,44 @@ abstract class Collect extends ImperativeAggregate {
override def dataType: DataType = ArrayType(child.dataType)
- override def supportsPartial: Boolean = false
-
- override def aggBufferAttributes: Seq[AttributeReference] = Nil
-
- override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
-
- override def inputAggBufferAttributes: Seq[AttributeReference] = Nil
-
// Both `CollectList` and `CollectSet` are non-deterministic since their results depend on the
// actual order of input rows.
override def deterministic: Boolean = false
- protected[this] val buffer: Growable[Any] with Iterable[Any]
-
- override def initialize(b: InternalRow): Unit = {
- buffer.clear()
- }
+ override def update(buffer: T, input: InternalRow): T = {
+ val value = child.eval(input)
- override def update(b: InternalRow, input: InternalRow): Unit = {
// Do not allow null values. We follow the semantics of Hive's collect_list/collect_set here.
// See: org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMkCollectionEvaluator
- val value = child.eval(input)
if (value != null) {
buffer += value
}
+ buffer
}
- override def merge(buffer: InternalRow, input: InternalRow): Unit = {
- sys.error("Collect cannot be used in partial aggregations.")
+ override def merge(buffer: T, other: T): T = {
+ buffer ++= other
}
- override def eval(input: InternalRow): Any = {
+ override def eval(buffer: T): Any = {
new GenericArrayData(buffer.toArray)
}
+
+ private lazy val projection = UnsafeProjection.create(
+ Array[DataType](ArrayType(elementType = child.dataType, containsNull = false)))
+ private lazy val row = new UnsafeRow(1)
+
+ override def serialize(obj: T): Array[Byte] = {
+ val array = new GenericArrayData(obj.toArray)
+ projection.apply(InternalRow.apply(array)).getBytes()
+ }
+
+ override def deserialize(bytes: Array[Byte]): T = {
+ val buffer = createAggregationBuffer()
+ row.pointTo(bytes, bytes.length)
+ row.getArray(0).foreach(child.dataType, (_, x: Any) => buffer += x)
+ buffer
+ }
}
/**
@@ -88,7 +92,7 @@ abstract class Collect extends ImperativeAggregate {
case class CollectList(
child: Expression,
mutableAggBufferOffset: Int = 0,
- inputAggBufferOffset: Int = 0) extends Collect {
+ inputAggBufferOffset: Int = 0) extends Collect[mutable.ArrayBuffer[Any]] {
def this(child: Expression) = this(child, 0, 0)
@@ -98,9 +102,9 @@ case class CollectList(
override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
copy(inputAggBufferOffset = newInputAggBufferOffset)
- override def prettyName: String = "collect_list"
+ override def createAggregationBuffer(): mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty
- override protected[this] val buffer: mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty
+ override def prettyName: String = "collect_list"
}
/**
@@ -111,7 +115,7 @@ case class CollectList(
case class CollectSet(
child: Expression,
mutableAggBufferOffset: Int = 0,
- inputAggBufferOffset: Int = 0) extends Collect {
+ inputAggBufferOffset: Int = 0) extends Collect[mutable.HashSet[Any]] {
def this(child: Expression) = this(child, 0, 0)
@@ -131,5 +135,5 @@ case class CollectSet(
override def prettyName: String = "collect_set"
- override protected[this] val buffer: mutable.HashSet[Any] = mutable.HashSet.empty
+ override def createAggregationBuffer(): mutable.HashSet[Any] = mutable.HashSet.empty
}
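To make the new buffer lifecycle concrete, here is a minimal standalone sketch using plain Scala collections and java.io streams. The helper names and the string-only serialization are illustrative stand-ins; the patch itself serializes through UnsafeProjection/UnsafeRow as shown above.

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

import scala.collection.mutable

// Illustrative stand-ins for createAggregationBuffer/update/merge above.
def createBuffer(): mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty

def update(buf: mutable.ArrayBuffer[Any], value: Any): mutable.ArrayBuffer[Any] = {
  if (value != null) buf += value // nulls are skipped, matching Hive's collect_list semantics
  buf
}

def merge(a: mutable.ArrayBuffer[Any], b: mutable.ArrayBuffer[Any]): mutable.ArrayBuffer[Any] =
  a ++= b

// Toy byte-level round trip (string-only stand-in for the UnsafeProjection-based
// serialize/deserialize in the patch), so partial buffers can cross a shuffle.
def serialize(buf: mutable.ArrayBuffer[Any]): Array[Byte] = {
  val bos = new ByteArrayOutputStream()
  val out = new DataOutputStream(bos)
  out.writeInt(buf.size)
  buf.foreach(v => out.writeUTF(v.toString))
  out.flush()
  bos.toByteArray
}

def deserialize(bytes: Array[Byte]): mutable.ArrayBuffer[Any] = {
  val in = new DataInputStream(new ByteArrayInputStream(bytes))
  val buf = createBuffer()
  (1 to in.readInt()).foreach(_ => buf += in.readUTF())
  buf
}

// Each partition builds a partial buffer, serializes it for the exchange, and
// the final stage deserializes and merges.
val partition1 = Seq("a", "b", null).foldLeft(createBuffer())(update)
val partition2 = Seq("b", "c").foldLeft(createBuffer())(update)
val merged = merge(deserialize(serialize(partition1)), deserialize(serialize(partition2)))
println(merged) // ArrayBuffer(a, b, b, c)
```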
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index 8e63fba14c..ccd4ae6c2d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -458,7 +458,9 @@ abstract class DeclarativeAggregate
* instead of hash based aggregation, as TypedImperativeAggregate use BinaryType as aggregation
* buffer's storage format, which is not supported by hash based aggregation. Hash based
* aggregation only support aggregation buffer of mutable types (like LongType, IntType that have
- * fixed length and can be mutated in place in UnsafeRow)
+ * fixed length and can be mutated in place in UnsafeRow).
+ * NOTE: The newly added ObjectHashAggregateExec supports TypedImperativeAggregate functions in
+ * hash based aggregation under some constraints.
*/
abstract class TypedImperativeAggregate[T] extends ImperativeAggregate {
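For readers unfamiliar with the trait discussed in this comment, a simplified paraphrase of its contract looks like the sketch below. This is not Spark's actual definition, which also extends ImperativeAggregate and manages buffer offsets.

```scala
// Simplified paraphrase of the TypedImperativeAggregate contract: an arbitrary
// object serves as the aggregation buffer and is serialized to bytes whenever
// partial results cross an exchange.
trait TypedAggregateSketch[T] {
  def createAggregationBuffer(): T      // fresh per-group buffer
  def update(buffer: T, input: Any): T  // fold one input row into the buffer
  def merge(buffer: T, other: T): T     // combine two partial buffers
  def eval(buffer: T): Any              // produce the final result
  def serialize(buffer: T): Array[Byte] // buffer -> shuffle-safe bytes
  def deserialize(bytes: Array[Byte]): T // bytes -> buffer
}
```

The BinaryType storage format of the serialized buffer is what ties these functions to sort-based aggregation unless ObjectHashAggregateExec is available.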
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
index 0b973c3b65..5c1faaecdb 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
@@ -59,15 +59,6 @@ class RewriteDistinctAggregatesSuite extends PlanTest {
comparePlans(input, rewrite)
}
- test("single distinct group with non-partial aggregates") {
- val input = testRelation
- .groupBy('a, 'd)(
- countDistinct('e, 'c).as('agg1),
- CollectSet('b).toAggregateExpression().as('agg2))
- .analyze
- checkRewrite(RewriteDistinctAggregates(input))
- }
-
test("multiple distinct groups") {
val input = testRelation
.groupBy('a)(countDistinct('b, 'c), countDistinct('d))