diff options
author | Wenchen Fan <wenchen@databricks.com> | 2016-12-26 22:10:20 +0800 |
---|---|---|
committer | Wenchen Fan <wenchen@databricks.com> | 2016-12-26 22:10:20 +0800 |
commit | 8a7db8a608a9e27b10f205cc1b4ed5f2c3e83799 (patch) | |
tree | d7ce89b11fb0dbb9becb7fbf2e3358afc34fd0b3 /sql/core/src/main | |
parent | 7026ee23e0a684e13f9d7dfbb8f85e810106d022 (diff) | |
download | spark-8a7db8a608a9e27b10f205cc1b4ed5f2c3e83799.tar.gz spark-8a7db8a608a9e27b10f205cc1b4ed5f2c3e83799.tar.bz2 spark-8a7db8a608a9e27b10f205cc1b4ed5f2c3e83799.zip |
[SPARK-18980][SQL] implement Aggregator with TypedImperativeAggregate
## What changes were proposed in this pull request?
Currently we implement `Aggregator` with `DeclarativeAggregate`, which will serialize/deserialize the buffer object every time we process an input.
This PR implements `Aggregator` with `TypedImperativeAggregate` and avoids to serialize/deserialize buffer object many times. The benchmark shows we get about 2 times speed up.
For simple buffer object that doesn't need serialization, we still go with `DeclarativeAggregate`, to avoid performance regression.
## How was this patch tested?
N/A
Author: Wenchen Fan <wenchen@databricks.com>
Closes #16383 from cloud-fan/aggregator.
Diffstat (limited to 'sql/core/src/main')
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/Column.scala | 8 | ||||
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala | 185 |
2 files changed, 164 insertions, 29 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index e99d7865bd..a3f581ff27 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -75,10 +75,10 @@ class TypedColumn[-T, U]( val unresolvedDeserializer = UnresolvedDeserializer(inputEncoder.deserializer, inputAttributes) val newExpr = expr transform { case ta: TypedAggregateExpression if ta.inputDeserializer.isEmpty => - ta.copy( - inputDeserializer = Some(unresolvedDeserializer), - inputClass = Some(inputEncoder.clsTag.runtimeClass), - inputSchema = Some(inputEncoder.schema)) + ta.withInputInfo( + deser = unresolvedDeserializer, + cls = inputEncoder.clsTag.runtimeClass, + schema = inputEncoder.schema) } new TypedColumn[T, U](newExpr, encoder) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index 9911c0b33a..4146bf3269 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -20,10 +20,12 @@ package org.apache.spark.sql.execution.aggregate import scala.language.existentials import org.apache.spark.sql.Encoder -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedDeserializer, UnresolvedExtractValue} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, DeclarativeAggregate, TypedImperativeAggregate} +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection import org.apache.spark.sql.catalyst.expressions.objects.Invoke import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.types._ @@ -33,9 +35,6 @@ object TypedAggregateExpression { aggregator: Aggregator[_, BUF, OUT]): TypedAggregateExpression = { val bufferEncoder = encoderFor[BUF] val bufferSerializer = bufferEncoder.namedExpressions - val bufferDeserializer = UnresolvedDeserializer( - bufferEncoder.deserializer, - bufferSerializer.map(_.toAttribute)) val outputEncoder = encoderFor[OUT] val outputType = if (outputEncoder.flat) { @@ -44,24 +43,78 @@ object TypedAggregateExpression { outputEncoder.schema } - new TypedAggregateExpression( - aggregator.asInstanceOf[Aggregator[Any, Any, Any]], - None, - None, - None, - bufferSerializer, - bufferDeserializer, - outputEncoder.serializer, - outputEncoder.deserializer.dataType, - outputType, - !outputEncoder.flat || outputEncoder.schema.head.nullable) + // Checks if the buffer object is simple, i.e. the buffer encoder is flat and the serializer + // expression is an alias of `BoundReference`, which means the buffer object doesn't need + // serialization. + val isSimpleBuffer = { + bufferSerializer.head match { + case Alias(_: BoundReference, _) if bufferEncoder.flat => true + case _ => false + } + } + + // If the buffer object is simple, use `SimpleTypedAggregateExpression`, which supports whole + // stage codegen. + if (isSimpleBuffer) { + val bufferDeserializer = UnresolvedDeserializer( + bufferEncoder.deserializer, + bufferSerializer.map(_.toAttribute)) + + SimpleTypedAggregateExpression( + aggregator.asInstanceOf[Aggregator[Any, Any, Any]], + None, + None, + None, + bufferSerializer, + bufferDeserializer, + outputEncoder.serializer, + outputEncoder.deserializer.dataType, + outputType, + !outputEncoder.flat || outputEncoder.schema.head.nullable) + } else { + ComplexTypedAggregateExpression( + aggregator.asInstanceOf[Aggregator[Any, Any, Any]], + None, + None, + None, + bufferSerializer, + bufferEncoder.resolveAndBind().deserializer, + outputEncoder.serializer, + outputType, + !outputEncoder.flat || outputEncoder.schema.head.nullable) + } } } /** * A helper class to hook [[Aggregator]] into the aggregation system. */ -case class TypedAggregateExpression( +trait TypedAggregateExpression extends AggregateFunction { + + def aggregator: Aggregator[Any, Any, Any] + + def inputDeserializer: Option[Expression] + def inputClass: Option[Class[_]] + def inputSchema: Option[StructType] + + def withInputInfo(deser: Expression, cls: Class[_], schema: StructType): TypedAggregateExpression + + override def toString: String = { + val input = inputDeserializer match { + case Some(UnresolvedDeserializer(deserializer, _)) => deserializer.dataType.simpleString + case Some(deserializer) => deserializer.dataType.simpleString + case _ => "unknown" + } + + s"$nodeName($input)" + } + + override def nodeName: String = aggregator.getClass.getSimpleName.stripSuffix("$") +} + +// TODO: merge these 2 implementations once we refactor the `AggregateFunction` interface. + +case class SimpleTypedAggregateExpression( aggregator: Aggregator[Any, Any, Any], inputDeserializer: Option[Expression], inputClass: Option[Class[_]], @@ -71,7 +124,8 @@ case class TypedAggregateExpression( outputSerializer: Seq[Expression], outputExternalType: DataType, dataType: DataType, - nullable: Boolean) extends DeclarativeAggregate with NonSQLExpression { + nullable: Boolean) + extends DeclarativeAggregate with TypedAggregateExpression with NonSQLExpression { override def deterministic: Boolean = true @@ -143,15 +197,96 @@ case class TypedAggregateExpression( } } - override def toString: String = { - val input = inputDeserializer match { - case Some(UnresolvedDeserializer(deserializer, _)) => deserializer.dataType.simpleString - case Some(deserializer) => deserializer.dataType.simpleString - case _ => "unknown" + override def withInputInfo( + deser: Expression, + cls: Class[_], + schema: StructType): TypedAggregateExpression = { + copy(inputDeserializer = Some(deser), inputClass = Some(cls), inputSchema = Some(schema)) + } +} + +case class ComplexTypedAggregateExpression( + aggregator: Aggregator[Any, Any, Any], + inputDeserializer: Option[Expression], + inputClass: Option[Class[_]], + inputSchema: Option[StructType], + bufferSerializer: Seq[NamedExpression], + bufferDeserializer: Expression, + outputSerializer: Seq[Expression], + dataType: DataType, + nullable: Boolean, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) + extends TypedImperativeAggregate[Any] with TypedAggregateExpression with NonSQLExpression { + + override def deterministic: Boolean = true + + override def children: Seq[Expression] = inputDeserializer.toSeq + + override lazy val resolved: Boolean = inputDeserializer.isDefined && childrenResolved + + override def references: AttributeSet = AttributeSet(inputDeserializer.toSeq) + + override def createAggregationBuffer(): Any = aggregator.zero + + private lazy val inputRowToObj = GenerateSafeProjection.generate(inputDeserializer.get :: Nil) + + override def update(buffer: Any, input: InternalRow): Any = { + val inputObj = inputRowToObj(input).get(0, ObjectType(classOf[Any])) + if (inputObj != null) { + aggregator.reduce(buffer, inputObj) + } else { + buffer + } + } + + override def merge(buffer: Any, input: Any): Any = { + aggregator.merge(buffer, input) + } + + private lazy val resultObjToRow = dataType match { + case _: StructType => + UnsafeProjection.create(CreateStruct(outputSerializer)) + case _ => + assert(outputSerializer.length == 1) + UnsafeProjection.create(outputSerializer.head) + } + + override def eval(buffer: Any): Any = { + val resultObj = aggregator.finish(buffer) + if (resultObj == null) { + null + } else { + resultObjToRow(InternalRow(resultObj)).get(0, dataType) } + } - s"$nodeName($input)" + private lazy val bufferObjToRow = UnsafeProjection.create(bufferSerializer) + + override def serialize(buffer: Any): Array[Byte] = { + bufferObjToRow(InternalRow(buffer)).getBytes } - override def nodeName: String = aggregator.getClass.getSimpleName.stripSuffix("$") + private lazy val bufferRow = new UnsafeRow(bufferSerializer.length) + private lazy val bufferRowToObject = GenerateSafeProjection.generate(bufferDeserializer :: Nil) + + override def deserialize(storageFormat: Array[Byte]): Any = { + bufferRow.pointTo(storageFormat, storageFormat.length) + bufferRowToObject(bufferRow).get(0, ObjectType(classOf[Any])) + } + + override def withNewMutableAggBufferOffset( + newMutableAggBufferOffset: Int): ComplexTypedAggregateExpression = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def withNewInputAggBufferOffset( + newInputAggBufferOffset: Int): ComplexTypedAggregateExpression = + copy(inputAggBufferOffset = newInputAggBufferOffset) + + override def withInputInfo( + deser: Expression, + cls: Class[_], + schema: StructType): TypedAggregateExpression = { + copy(inputDeserializer = Some(deser), inputClass = Some(cls), inputSchema = Some(schema)) + } } |