diff options
author | Liang-Chi Hsieh <viirya@gmail.com> | 2017-01-03 22:11:54 +0800 |
---|---|---|
committer | Wenchen Fan <wenchen@databricks.com> | 2017-01-03 22:11:54 +0800 |
commit | 52636226dc8cb7fcf00381d65e280d651b25a382 (patch) | |
tree | 5c60b2b7fccd4c3310eeba7e76e16c6935e2f274 | |
parent | e5c307c50a660f706799f1f7f6890bcec888d96b (diff) | |
download | spark-52636226dc8cb7fcf00381d65e280d651b25a382.tar.gz spark-52636226dc8cb7fcf00381d65e280d651b25a382.tar.bz2 spark-52636226dc8cb7fcf00381d65e280d651b25a382.zip |
[SPARK-18932][SQL] Support partial aggregation for collect_set/collect_list
## What changes were proposed in this pull request?
Currently the collect_set/collect_list aggregation expressions don't support partial aggregation. This patch enables partial aggregation for them.
## How was this patch tested?
Jenkins tests.
Please review http://spark.apache.org/contributing.html before opening a pull request.
Author: Liang-Chi Hsieh <viirya@gmail.com>
Closes #16371 from viirya/collect-partial-support.
4 files changed, 39 insertions, 43 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala index 2f4d68d179..eaeb010b0e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala @@ -33,10 +33,9 @@ import org.apache.spark.util.collection.OpenHashMap * The Percentile aggregate function returns the exact percentile(s) of numeric column `expr` at * the given percentage(s) with value range in [0.0, 1.0]. * - * The operator is bound to the slower sort based aggregation path because the number of elements - * and their partial order cannot be determined in advance. Therefore we have to store all the - * elements in memory, and that too many elements can cause GC paused and eventually OutOfMemory - * Errors. + * The number of elements and their partial order cannot be determined in advance. + * Therefore we have to store all the elements in memory, and so notice that too many elements can + * cause GC pauses and eventually OutOfMemory Errors. 
* * @param child child expression that produce numeric column value with `child.eval(inputRow)` * @param percentageExpression Expression that represents a single percentage value or an array of diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala index b176e2a128..411f058510 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions.aggregate +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} + import scala.collection.generic.Growable import scala.collection.mutable @@ -27,14 +29,12 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ /** - * The Collect aggregate function collects all seen expression values into a list of values. + * A base class for collect_list and collect_set aggregate functions. * - * The operator is bound to the slower sort based aggregation path because the number of - * elements (and their memory usage) can not be determined in advance. This also means that the - * collected elements are stored on heap, and that too many elements can cause GC pauses and - * eventually Out of Memory Errors. + * We have to store all the collected elements in memory, and so notice that too many elements + * can cause GC pauses and eventually OutOfMemory Errors. 
*/ -abstract class Collect extends ImperativeAggregate { +abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImperativeAggregate[T] { val child: Expression @@ -44,40 +44,44 @@ abstract class Collect extends ImperativeAggregate { override def dataType: DataType = ArrayType(child.dataType) - override def supportsPartial: Boolean = false - - override def aggBufferAttributes: Seq[AttributeReference] = Nil - - override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) - - override def inputAggBufferAttributes: Seq[AttributeReference] = Nil - // Both `CollectList` and `CollectSet` are non-deterministic since their results depend on the // actual order of input rows. override def deterministic: Boolean = false - protected[this] val buffer: Growable[Any] with Iterable[Any] - - override def initialize(b: InternalRow): Unit = { - buffer.clear() - } + override def update(buffer: T, input: InternalRow): T = { + val value = child.eval(input) - override def update(b: InternalRow, input: InternalRow): Unit = { // Do not allow null values. We follow the semantics of Hive's collect_list/collect_set here. 
// See: org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMkCollectionEvaluator - val value = child.eval(input) if (value != null) { buffer += value } + buffer } - override def merge(buffer: InternalRow, input: InternalRow): Unit = { - sys.error("Collect cannot be used in partial aggregations.") + override def merge(buffer: T, other: T): T = { + buffer ++= other } - override def eval(input: InternalRow): Any = { + override def eval(buffer: T): Any = { new GenericArrayData(buffer.toArray) } + + private lazy val projection = UnsafeProjection.create( + Array[DataType](ArrayType(elementType = child.dataType, containsNull = false))) + private lazy val row = new UnsafeRow(1) + + override def serialize(obj: T): Array[Byte] = { + val array = new GenericArrayData(obj.toArray) + projection.apply(InternalRow.apply(array)).getBytes() + } + + override def deserialize(bytes: Array[Byte]): T = { + val buffer = createAggregationBuffer() + row.pointTo(bytes, bytes.length) + row.getArray(0).foreach(child.dataType, (_, x: Any) => buffer += x) + buffer + } } /** @@ -88,7 +92,7 @@ abstract class Collect extends ImperativeAggregate { case class CollectList( child: Expression, mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) extends Collect { + inputAggBufferOffset: Int = 0) extends Collect[mutable.ArrayBuffer[Any]] { def this(child: Expression) = this(child, 0, 0) @@ -98,9 +102,9 @@ case class CollectList( override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = copy(inputAggBufferOffset = newInputAggBufferOffset) - override def prettyName: String = "collect_list" + override def createAggregationBuffer(): mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty - override protected[this] val buffer: mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty + override def prettyName: String = "collect_list" } /** @@ -111,7 +115,7 @@ case class CollectList( case class CollectSet( child: Expression, mutableAggBufferOffset: Int = 0, - 
inputAggBufferOffset: Int = 0) extends Collect { + inputAggBufferOffset: Int = 0) extends Collect[mutable.HashSet[Any]] { def this(child: Expression) = this(child, 0, 0) @@ -131,5 +135,5 @@ case class CollectSet( override def prettyName: String = "collect_set" - override protected[this] val buffer: mutable.HashSet[Any] = mutable.HashSet.empty + override def createAggregationBuffer(): mutable.HashSet[Any] = mutable.HashSet.empty } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index 8e63fba14c..ccd4ae6c2d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -458,7 +458,9 @@ abstract class DeclarativeAggregate * instead of hash based aggregation, as TypedImperativeAggregate use BinaryType as aggregation * buffer's storage format, which is not supported by hash based aggregation. Hash based * aggregation only support aggregation buffer of mutable types (like LongType, IntType that have - * fixed length and can be mutated in place in UnsafeRow) + * fixed length and can be mutated in place in UnsafeRow). + * NOTE: The newly added ObjectHashAggregateExec supports TypedImperativeAggregate functions in + * hash based aggregation under some constraints. 
*/ abstract class TypedImperativeAggregate[T] extends ImperativeAggregate { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala index 0b973c3b65..5c1faaecdb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala @@ -59,15 +59,6 @@ class RewriteDistinctAggregatesSuite extends PlanTest { comparePlans(input, rewrite) } - test("single distinct group with non-partial aggregates") { - val input = testRelation - .groupBy('a, 'd)( - countDistinct('e, 'c).as('agg1), - CollectSet('b).toAggregateExpression().as('agg2)) - .analyze - checkRewrite(RewriteDistinctAggregates(input)) - } - test("multiple distinct groups") { val input = testRelation .groupBy('a)(countDistinct('b, 'c), countDistinct('d)) |