aboutsummaryrefslogtreecommitdiff
path: root/sql/core/src/main
diff options
context:
space:
mode:
authorWenchen Fan <wenchen@databricks.com>2016-12-26 22:10:20 +0800
committerWenchen Fan <wenchen@databricks.com>2016-12-26 22:10:20 +0800
commit8a7db8a608a9e27b10f205cc1b4ed5f2c3e83799 (patch)
treed7ce89b11fb0dbb9becb7fbf2e3358afc34fd0b3 /sql/core/src/main
parent7026ee23e0a684e13f9d7dfbb8f85e810106d022 (diff)
downloadspark-8a7db8a608a9e27b10f205cc1b4ed5f2c3e83799.tar.gz
spark-8a7db8a608a9e27b10f205cc1b4ed5f2c3e83799.tar.bz2
spark-8a7db8a608a9e27b10f205cc1b4ed5f2c3e83799.zip
[SPARK-18980][SQL] implement Aggregator with TypedImperativeAggregate
## What changes were proposed in this pull request? Currently we implement `Aggregator` with `DeclarativeAggregate`, which will serialize/deserialize the buffer object every time we process an input. This PR implements `Aggregator` with `TypedImperativeAggregate` and avoids to serialize/deserialize buffer object many times. The benchmark shows we get about 2 times speed up. For simple buffer object that doesn't need serialization, we still go with `DeclarativeAggregate`, to avoid performance regression. ## How was this patch tested? N/A Author: Wenchen Fan <wenchen@databricks.com> Closes #16383 from cloud-fan/aggregator.
Diffstat (limited to 'sql/core/src/main')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/Column.scala8
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala185
2 files changed, 164 insertions, 29 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index e99d7865bd..a3f581ff27 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -75,10 +75,10 @@ class TypedColumn[-T, U](
val unresolvedDeserializer = UnresolvedDeserializer(inputEncoder.deserializer, inputAttributes)
val newExpr = expr transform {
case ta: TypedAggregateExpression if ta.inputDeserializer.isEmpty =>
- ta.copy(
- inputDeserializer = Some(unresolvedDeserializer),
- inputClass = Some(inputEncoder.clsTag.runtimeClass),
- inputSchema = Some(inputEncoder.schema))
+ ta.withInputInfo(
+ deser = unresolvedDeserializer,
+ cls = inputEncoder.clsTag.runtimeClass,
+ schema = inputEncoder.schema)
}
new TypedColumn[T, U](newExpr, encoder)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
index 9911c0b33a..4146bf3269 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
@@ -20,10 +20,12 @@ package org.apache.spark.sql.execution.aggregate
import scala.language.existentials
import org.apache.spark.sql.Encoder
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedDeserializer, UnresolvedExtractValue}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, DeclarativeAggregate, TypedImperativeAggregate}
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.types._
@@ -33,9 +35,6 @@ object TypedAggregateExpression {
aggregator: Aggregator[_, BUF, OUT]): TypedAggregateExpression = {
val bufferEncoder = encoderFor[BUF]
val bufferSerializer = bufferEncoder.namedExpressions
- val bufferDeserializer = UnresolvedDeserializer(
- bufferEncoder.deserializer,
- bufferSerializer.map(_.toAttribute))
val outputEncoder = encoderFor[OUT]
val outputType = if (outputEncoder.flat) {
@@ -44,24 +43,78 @@ object TypedAggregateExpression {
outputEncoder.schema
}
- new TypedAggregateExpression(
- aggregator.asInstanceOf[Aggregator[Any, Any, Any]],
- None,
- None,
- None,
- bufferSerializer,
- bufferDeserializer,
- outputEncoder.serializer,
- outputEncoder.deserializer.dataType,
- outputType,
- !outputEncoder.flat || outputEncoder.schema.head.nullable)
+ // Checks if the buffer object is simple, i.e. the buffer encoder is flat and the serializer
+ // expression is an alias of `BoundReference`, which means the buffer object doesn't need
+ // serialization.
+ val isSimpleBuffer = {
+ bufferSerializer.head match {
+ case Alias(_: BoundReference, _) if bufferEncoder.flat => true
+ case _ => false
+ }
+ }
+
+ // If the buffer object is simple, use `SimpleTypedAggregateExpression`, which supports whole
+ // stage codegen.
+ if (isSimpleBuffer) {
+ val bufferDeserializer = UnresolvedDeserializer(
+ bufferEncoder.deserializer,
+ bufferSerializer.map(_.toAttribute))
+
+ SimpleTypedAggregateExpression(
+ aggregator.asInstanceOf[Aggregator[Any, Any, Any]],
+ None,
+ None,
+ None,
+ bufferSerializer,
+ bufferDeserializer,
+ outputEncoder.serializer,
+ outputEncoder.deserializer.dataType,
+ outputType,
+ !outputEncoder.flat || outputEncoder.schema.head.nullable)
+ } else {
+ ComplexTypedAggregateExpression(
+ aggregator.asInstanceOf[Aggregator[Any, Any, Any]],
+ None,
+ None,
+ None,
+ bufferSerializer,
+ bufferEncoder.resolveAndBind().deserializer,
+ outputEncoder.serializer,
+ outputType,
+ !outputEncoder.flat || outputEncoder.schema.head.nullable)
+ }
}
}
/**
* A helper class to hook [[Aggregator]] into the aggregation system.
*/
-case class TypedAggregateExpression(
+trait TypedAggregateExpression extends AggregateFunction {
+
+ def aggregator: Aggregator[Any, Any, Any]
+
+ def inputDeserializer: Option[Expression]
+ def inputClass: Option[Class[_]]
+ def inputSchema: Option[StructType]
+
+ def withInputInfo(deser: Expression, cls: Class[_], schema: StructType): TypedAggregateExpression
+
+ override def toString: String = {
+ val input = inputDeserializer match {
+ case Some(UnresolvedDeserializer(deserializer, _)) => deserializer.dataType.simpleString
+ case Some(deserializer) => deserializer.dataType.simpleString
+ case _ => "unknown"
+ }
+
+ s"$nodeName($input)"
+ }
+
+ override def nodeName: String = aggregator.getClass.getSimpleName.stripSuffix("$")
+}
+
+// TODO: merge these 2 implementations once we refactor the `AggregateFunction` interface.
+
+case class SimpleTypedAggregateExpression(
aggregator: Aggregator[Any, Any, Any],
inputDeserializer: Option[Expression],
inputClass: Option[Class[_]],
@@ -71,7 +124,8 @@ case class TypedAggregateExpression(
outputSerializer: Seq[Expression],
outputExternalType: DataType,
dataType: DataType,
- nullable: Boolean) extends DeclarativeAggregate with NonSQLExpression {
+ nullable: Boolean)
+ extends DeclarativeAggregate with TypedAggregateExpression with NonSQLExpression {
override def deterministic: Boolean = true
@@ -143,15 +197,96 @@ case class TypedAggregateExpression(
}
}
- override def toString: String = {
- val input = inputDeserializer match {
- case Some(UnresolvedDeserializer(deserializer, _)) => deserializer.dataType.simpleString
- case Some(deserializer) => deserializer.dataType.simpleString
- case _ => "unknown"
+ override def withInputInfo(
+ deser: Expression,
+ cls: Class[_],
+ schema: StructType): TypedAggregateExpression = {
+ copy(inputDeserializer = Some(deser), inputClass = Some(cls), inputSchema = Some(schema))
+ }
+}
+
+case class ComplexTypedAggregateExpression(
+ aggregator: Aggregator[Any, Any, Any],
+ inputDeserializer: Option[Expression],
+ inputClass: Option[Class[_]],
+ inputSchema: Option[StructType],
+ bufferSerializer: Seq[NamedExpression],
+ bufferDeserializer: Expression,
+ outputSerializer: Seq[Expression],
+ dataType: DataType,
+ nullable: Boolean,
+ mutableAggBufferOffset: Int = 0,
+ inputAggBufferOffset: Int = 0)
+ extends TypedImperativeAggregate[Any] with TypedAggregateExpression with NonSQLExpression {
+
+ override def deterministic: Boolean = true
+
+ override def children: Seq[Expression] = inputDeserializer.toSeq
+
+ override lazy val resolved: Boolean = inputDeserializer.isDefined && childrenResolved
+
+ override def references: AttributeSet = AttributeSet(inputDeserializer.toSeq)
+
+ override def createAggregationBuffer(): Any = aggregator.zero
+
+ private lazy val inputRowToObj = GenerateSafeProjection.generate(inputDeserializer.get :: Nil)
+
+ override def update(buffer: Any, input: InternalRow): Any = {
+ val inputObj = inputRowToObj(input).get(0, ObjectType(classOf[Any]))
+ if (inputObj != null) {
+ aggregator.reduce(buffer, inputObj)
+ } else {
+ buffer
+ }
+ }
+
+ override def merge(buffer: Any, input: Any): Any = {
+ aggregator.merge(buffer, input)
+ }
+
+ private lazy val resultObjToRow = dataType match {
+ case _: StructType =>
+ UnsafeProjection.create(CreateStruct(outputSerializer))
+ case _ =>
+ assert(outputSerializer.length == 1)
+ UnsafeProjection.create(outputSerializer.head)
+ }
+
+ override def eval(buffer: Any): Any = {
+ val resultObj = aggregator.finish(buffer)
+ if (resultObj == null) {
+ null
+ } else {
+ resultObjToRow(InternalRow(resultObj)).get(0, dataType)
}
+ }
- s"$nodeName($input)"
+ private lazy val bufferObjToRow = UnsafeProjection.create(bufferSerializer)
+
+ override def serialize(buffer: Any): Array[Byte] = {
+ bufferObjToRow(InternalRow(buffer)).getBytes
}
- override def nodeName: String = aggregator.getClass.getSimpleName.stripSuffix("$")
+ private lazy val bufferRow = new UnsafeRow(bufferSerializer.length)
+ private lazy val bufferRowToObject = GenerateSafeProjection.generate(bufferDeserializer :: Nil)
+
+ override def deserialize(storageFormat: Array[Byte]): Any = {
+ bufferRow.pointTo(storageFormat, storageFormat.length)
+ bufferRowToObject(bufferRow).get(0, ObjectType(classOf[Any]))
+ }
+
+ override def withNewMutableAggBufferOffset(
+ newMutableAggBufferOffset: Int): ComplexTypedAggregateExpression =
+ copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+ override def withNewInputAggBufferOffset(
+ newInputAggBufferOffset: Int): ComplexTypedAggregateExpression =
+ copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+ override def withInputInfo(
+ deser: Expression,
+ cls: Class[_],
+ schema: StructType): TypedAggregateExpression = {
+ copy(inputDeserializer = Some(deser), inputClass = Some(cls), inputSchema = Some(schema))
+ }
}