author     Wenchen Fan <wenchen@databricks.com>  2016-04-15 12:10:00 +0800
committer  Wenchen Fan <wenchen@databricks.com>  2016-04-15 12:10:00 +0800
commit     297ba3f1b49cc37d9891a529142c553e0a5e2d62 (patch)
tree       2a61d490100de8b609a15fb52561524dddaca0e8 /sql
parent     b5c60bcdca3bcace607b204a6c196a5386e8a896 (diff)
[SPARK-14275][SQL] Reimplement TypedAggregateExpression to DeclarativeAggregate
## What changes were proposed in this pull request?

`ExpressionEncoder` is just a container for serialization and deserialization expressions, so we can use these expressions to build `TypedAggregateExpression` directly and make it fit into `DeclarativeAggregate`, which is more efficient than the previous `ImperativeAggregate`-based implementation.

One trick is needed: each buffer serializer expression references the result object of the buffer deserialization and aggregator function call. To avoid re-computing this result object once per serializer expression, we serialize the buffer object into a single struct field, so that a special `Expression` can evaluate the result object only once.

## How was this patch tested?

Existing tests.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #12067 from cloud-fan/typed_udaf.
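For orientation (not part of the commit itself): the user-facing path this patch reimplements is a typed `Aggregator` applied to a `Dataset` as a `TypedColumn`. A minimal sketch, assuming spark-sql on the classpath; `Data`, `SumOfL`, and `TypedUdafExample` are illustrative names, patterned after `ComplexAggregator` in the `DatasetBenchmark` changes below:

```scala
import org.apache.spark.SparkContext
import org.apache.spark.sql.{Encoder, Encoders, SQLContext}
import org.apache.spark.sql.expressions.Aggregator

// Illustrative input type, mirroring the one used in DatasetBenchmark.
case class Data(l: Long, s: String)

// A typed UDAF that sums the `l` field, using a Long buffer.
object SumOfL extends Aggregator[Data, Long, Long] {
  override def zero: Long = 0L                            // initial buffer value
  override def reduce(b: Long, a: Data): Long = b + a.l   // fold one input into the buffer
  override def merge(b1: Long, b2: Long): Long = b1 + b2  // combine two partial buffers
  override def finish(b: Long): Long = b                  // final output from the buffer
  override def bufferEncoder: Encoder[Long] = Encoders.scalaLong
  override def outputEncoder: Encoder[Long] = Encoders.scalaLong
}

object TypedUdafExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local[*]", "typed-udaf-example")
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    val ds = Seq(Data(1, "a"), Data(2, "b"), Data(3, "c")).toDS()
    // toColumn wraps the Aggregator in a TypedAggregateExpression, which after this patch
    // is a DeclarativeAggregate and can participate in whole-stage codegen.
    println(ds.select(SumOfL.toColumn).collect().head)  // expected: 6 (1 + 2 + 3)
  }
}
```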
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala | 77
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala | 3
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Column.scala | 16
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala | 3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala | 192
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala | 6
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala | 112
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala | 14
12 files changed, 303 insertions, 130 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index a24a5db8d4..718bb4b118 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -185,7 +185,7 @@ abstract class Expression extends TreeNode[Expression] {
* Returns a user-facing string representation of this expression's name.
* This should usually match the name of the function in SQL.
*/
- def prettyName: String = getClass.getSimpleName.toLowerCase
+ def prettyName: String = nodeName.toLowerCase
private def flatArguments = productIterator.flatMap {
case t: Traversable[_] => t
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala
new file mode 100644
index 0000000000..22645c952e
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.types.DataType
+
+/**
+ * A special expression that evaluates [[BoundReference]]s by given expressions instead of the
+ * input row.
+ *
+ * @param result The expression that contains [[BoundReference]] and produces the final output.
+ * @param children The expressions that used as input values for [[BoundReference]].
+ */
+case class ReferenceToExpressions(result: Expression, children: Seq[Expression])
+ extends Expression {
+
+ override def nullable: Boolean = result.nullable
+ override def dataType: DataType = result.dataType
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ if (result.references.nonEmpty) {
+ return TypeCheckFailure("The result expression cannot reference to any attributes.")
+ }
+
+ var maxOrdinal = -1
+ result foreach {
+ case b: BoundReference if b.ordinal > maxOrdinal => maxOrdinal = b.ordinal
+ }
+ if (maxOrdinal > children.length) {
+ return TypeCheckFailure(s"The result expression need $maxOrdinal input expressions, but " +
+ s"there are only ${children.length} inputs.")
+ }
+
+ TypeCheckSuccess
+ }
+
+ private lazy val projection = UnsafeProjection.create(children)
+
+ override def eval(input: InternalRow): Any = {
+ result.eval(projection(input))
+ }
+
+ override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = {
+ val childrenGen = children.map(_.gen(ctx))
+ val childrenVars = childrenGen.zip(children).map {
+ case (childGen, child) => LambdaVariable(childGen.value, childGen.isNull, child.dataType)
+ }
+
+ val resultGen = result.transform {
+ case b: BoundReference => childrenVars(b.ordinal)
+ }.gen(ctx)
+
+ ev.value = resultGen.value
+ ev.isNull = resultGen.isNull
+
+ childrenGen.map(_.code).mkString("\n") + "\n" + resultGen.code
+ }
+}
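Not part of the commit: a rough usage sketch of the new expression, assuming the spark-catalyst module on the classpath. The `BoundReference` ordinals inside `result` index into `children` rather than into the input row:

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Add, BoundReference, Literal, ReferenceToExpressions}
import org.apache.spark.sql.types.IntegerType

object ReferenceToExpressionsSketch {
  def main(args: Array[String]): Unit = {
    // `result` refers to slot 0 and slot 1, which are filled by the two child expressions.
    val result = Add(
      BoundReference(0, IntegerType, nullable = false),
      BoundReference(1, IntegerType, nullable = false))
    val expr = ReferenceToExpressions(result, Seq(Literal(40), Literal(2)))

    // eval projects the children into a row first, then evaluates `result` against that row.
    println(expr.eval(InternalRow.fromSeq(Nil)))  // expected: 42
  }
}
```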
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index e6804d096c..7fd4bc3066 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -60,7 +60,8 @@ object Literal {
* Constructs a [[Literal]] of [[ObjectType]], for example when you need to pass an object
* into code generation.
*/
- def fromObject(obj: AnyRef): Literal = new Literal(obj, ObjectType(obj.getClass))
+ def fromObject(obj: Any, objType: DataType): Literal = new Literal(obj, objType)
+ def fromObject(obj: Any): Literal = new Literal(obj, ObjectType(obj.getClass))
def fromJSON(json: JValue): Literal = {
val dataType = DataType.parseDataType(json \ "dataType")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala
index 06ee0fbfe9..b7b1acc582 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala
@@ -41,4 +41,6 @@ private[sql] case class ObjectType(cls: Class[_]) extends DataType {
throw new UnsupportedOperationException("No size estimation available for objects.")
def asNullable: DataType = this
+
+ override def simpleString: String = cls.getName
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index d64736e111..bd96941da7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -59,14 +59,14 @@ class TypedColumn[-T, U](
* on a decoded object.
*/
private[sql] def withInputType(
- inputEncoder: ExpressionEncoder[_],
- schema: Seq[Attribute]): TypedColumn[T, U] = {
- val boundEncoder = inputEncoder.bind(schema).asInstanceOf[ExpressionEncoder[Any]]
- new TypedColumn[T, U](
- expr transform { case ta: TypedAggregateExpression if ta.aEncoder.isEmpty =>
- ta.copy(aEncoder = Some(boundEncoder), children = schema)
- },
- encoder)
+ inputDeserializer: Expression,
+ inputAttributes: Seq[Attribute]): TypedColumn[T, U] = {
+ val unresolvedDeserializer = UnresolvedDeserializer(inputDeserializer, inputAttributes)
+ val newExpr = expr transform {
+ case ta: TypedAggregateExpression if ta.inputDeserializer.isEmpty =>
+ ta.copy(inputDeserializer = Some(unresolvedDeserializer))
+ }
+ new TypedColumn[T, U](newExpr, encoder)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index e216945fbe..4edc90d9c3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -992,7 +992,7 @@ class Dataset[T] private[sql](
sqlContext,
Project(
c1.withInputType(
- boundTEncoder,
+ unresolvedTEncoder.deserializer,
logicalPlan.output).named :: Nil,
logicalPlan),
implicitly[Encoder[U1]])
@@ -1006,7 +1006,7 @@ class Dataset[T] private[sql](
protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
val encoders = columns.map(_.encoder)
val namedColumns =
- columns.map(_.withInputType(resolvedTEncoder, logicalPlan.output).named)
+ columns.map(_.withInputType(unresolvedTEncoder.deserializer, logicalPlan.output).named)
val execution = new QueryExecution(sqlContext, Project(namedColumns, logicalPlan))
new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index f19ad6e707..05e13e66d1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -209,8 +209,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
val encoders = columns.map(_.encoder)
val namedColumns =
- columns.map(
- _.withInputType(resolvedVEncoder, dataAttributes).named)
+ columns.map(_.withInputType(unresolvedVEncoder.deserializer, dataAttributes).named)
val keyColumn = if (resolvedKEncoder.flat) {
assert(groupingAttributes.length == 1)
groupingAttributes.head
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
index 9abae53579..535e64cb34 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
@@ -19,133 +19,153 @@ package org.apache.spark.sql.execution.aggregate
import scala.language.existentials
-import org.apache.spark.internal.Logging
import org.apache.spark.sql.Encoder
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, OuterScopes}
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedDeserializer, UnresolvedExtractValue}
+import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate
+import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.types._
object TypedAggregateExpression {
- def apply[A, B : Encoder, C : Encoder](
- aggregator: Aggregator[A, B, C]): TypedAggregateExpression = {
+ def apply[BUF : Encoder, OUT : Encoder](
+ aggregator: Aggregator[_, BUF, OUT]): TypedAggregateExpression = {
+ val bufferEncoder = encoderFor[BUF]
+ // We will insert the deserializer and function call expression at the bottom of each serializer
+ // expression while executing `TypedAggregateExpression`, which means multiple serializer
+ // expressions will all evaluate the same sub-expression at the bottom. To avoid re-evaluating it,
+ // here we always use one single serializer expression to serialize the buffer object into a
+ // single-field row, no matter whether the encoder is flat or not. We also need to update the
+ // deserializer to read in all fields from that single-field row.
+ // TODO: remove this trick after we have better integration of subexpression elimination and
+ // whole stage codegen.
+ val bufferSerializer = if (bufferEncoder.flat) {
+ bufferEncoder.namedExpressions.head
+ } else {
+ Alias(CreateStruct(bufferEncoder.serializer), "buffer")()
+ }
+
+ val bufferDeserializer = if (bufferEncoder.flat) {
+ bufferEncoder.deserializer transformUp {
+ case b: BoundReference => bufferSerializer.toAttribute
+ }
+ } else {
+ bufferEncoder.deserializer transformUp {
+ case UnresolvedAttribute(nameParts) =>
+ assert(nameParts.length == 1)
+ UnresolvedExtractValue(bufferSerializer.toAttribute, Literal(nameParts.head))
+ case BoundReference(ordinal, dt, _) => GetStructField(bufferSerializer.toAttribute, ordinal)
+ }
+ }
+
+ val outputEncoder = encoderFor[OUT]
+ val outputType = if (outputEncoder.flat) {
+ outputEncoder.schema.head.dataType
+ } else {
+ outputEncoder.schema
+ }
+
new TypedAggregateExpression(
aggregator.asInstanceOf[Aggregator[Any, Any, Any]],
None,
- encoderFor[B].asInstanceOf[ExpressionEncoder[Any]],
- encoderFor[C].asInstanceOf[ExpressionEncoder[Any]],
- Nil,
- 0,
- 0)
+ bufferSerializer,
+ bufferDeserializer,
+ outputEncoder.serializer,
+ outputEncoder.deserializer.dataType,
+ outputType)
}
}
/**
- * This class is a rough sketch of how to hook `Aggregator` into the Aggregation system. It has
- * the following limitations:
- * - It assumes the aggregator has a zero, `0`.
+ * A helper class to hook [[Aggregator]] into the aggregation system.
*/
case class TypedAggregateExpression(
aggregator: Aggregator[Any, Any, Any],
- aEncoder: Option[ExpressionEncoder[Any]], // Should be bound.
- unresolvedBEncoder: ExpressionEncoder[Any],
- cEncoder: ExpressionEncoder[Any],
- children: Seq[Attribute],
- mutableAggBufferOffset: Int,
- inputAggBufferOffset: Int)
- extends ImperativeAggregate with Logging {
-
- override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
- copy(mutableAggBufferOffset = newMutableAggBufferOffset)
-
- override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
- copy(inputAggBufferOffset = newInputAggBufferOffset)
+ inputDeserializer: Option[Expression],
+ bufferSerializer: NamedExpression,
+ bufferDeserializer: Expression,
+ outputSerializer: Seq[Expression],
+ outputExternalType: DataType,
+ dataType: DataType) extends DeclarativeAggregate with NonSQLExpression {
override def nullable: Boolean = true
- override def dataType: DataType = if (cEncoder.flat) {
- cEncoder.schema.head.dataType
- } else {
- cEncoder.schema
- }
-
override def deterministic: Boolean = true
- override lazy val resolved: Boolean = aEncoder.isDefined
-
- override lazy val inputTypes: Seq[DataType] = Nil
+ override def children: Seq[Expression] = inputDeserializer.toSeq :+ bufferDeserializer
- override val aggBufferSchema: StructType = unresolvedBEncoder.schema
+ override lazy val resolved: Boolean = inputDeserializer.isDefined && childrenResolved
- override val aggBufferAttributes: Seq[AttributeReference] = aggBufferSchema.toAttributes
+ override def references: AttributeSet = AttributeSet(inputDeserializer.toSeq)
- val bEncoder = unresolvedBEncoder
- .resolve(aggBufferAttributes, OuterScopes.outerScopes)
- .bind(aggBufferAttributes)
+ override def inputTypes: Seq[AbstractDataType] = Nil
- // Note: although this simply copies aggBufferAttributes, this common code can not be placed
- // in the superclass because that will lead to initialization ordering issues.
- override val inputAggBufferAttributes: Seq[AttributeReference] =
- aggBufferAttributes.map(_.newInstance())
+ private def aggregatorLiteral =
+ Literal.create(aggregator, ObjectType(classOf[Aggregator[Any, Any, Any]]))
- // We let the dataset do the binding for us.
- lazy val boundA = aEncoder.get
+ private def bufferExternalType = bufferDeserializer.dataType
- private def updateBuffer(buffer: MutableRow, value: InternalRow): Unit = {
- var i = 0
- while (i < aggBufferAttributes.length) {
- val offset = mutableAggBufferOffset + i
- aggBufferSchema(i).dataType match {
- case BooleanType => buffer.setBoolean(offset, value.getBoolean(i))
- case ByteType => buffer.setByte(offset, value.getByte(i))
- case ShortType => buffer.setShort(offset, value.getShort(i))
- case IntegerType => buffer.setInt(offset, value.getInt(i))
- case LongType => buffer.setLong(offset, value.getLong(i))
- case FloatType => buffer.setFloat(offset, value.getFloat(i))
- case DoubleType => buffer.setDouble(offset, value.getDouble(i))
- case other => buffer.update(offset, value.get(i, other))
- }
- i += 1
- }
- }
+ override lazy val aggBufferAttributes: Seq[AttributeReference] =
+ bufferSerializer.toAttribute.asInstanceOf[AttributeReference] :: Nil
- override def initialize(buffer: MutableRow): Unit = {
- val zero = bEncoder.toRow(aggregator.zero)
- updateBuffer(buffer, zero)
+ override lazy val initialValues: Seq[Expression] = {
+ val zero = Literal.fromObject(aggregator.zero, bufferExternalType)
+ ReferenceToExpressions(bufferSerializer, zero :: Nil) :: Nil
}
- override def update(buffer: MutableRow, input: InternalRow): Unit = {
- val inputA = boundA.fromRow(input)
- val currentB = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer)
- val merged = aggregator.reduce(currentB, inputA)
- val returned = bEncoder.toRow(merged)
+ override lazy val updateExpressions: Seq[Expression] = {
+ val reduced = Invoke(
+ aggregatorLiteral,
+ "reduce",
+ bufferExternalType,
+ bufferDeserializer :: inputDeserializer.get :: Nil)
- updateBuffer(buffer, returned)
+ ReferenceToExpressions(bufferSerializer, reduced :: Nil) :: Nil
}
- override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
- val b1 = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer1)
- val b2 = bEncoder.shift(inputAggBufferOffset).fromRow(buffer2)
- val merged = aggregator.merge(b1, b2)
- val returned = bEncoder.toRow(merged)
+ override lazy val mergeExpressions: Seq[Expression] = {
+ val leftBuffer = bufferDeserializer transform {
+ case a: AttributeReference => a.left
+ }
+ val rightBuffer = bufferDeserializer transform {
+ case a: AttributeReference => a.right
+ }
+ val merged = Invoke(
+ aggregatorLiteral,
+ "merge",
+ bufferExternalType,
+ leftBuffer :: rightBuffer :: Nil)
- updateBuffer(buffer1, returned)
+ ReferenceToExpressions(bufferSerializer, merged :: Nil) :: Nil
}
- override def eval(buffer: InternalRow): Any = {
- val b = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer)
- val result = cEncoder.toRow(aggregator.finish(b))
+ override lazy val evaluateExpression: Expression = {
+ val resultObj = Invoke(
+ aggregatorLiteral,
+ "finish",
+ outputExternalType,
+ bufferDeserializer :: Nil)
+
dataType match {
- case _: StructType => result
- case _ => result.get(0, dataType)
+ case s: StructType =>
+ ReferenceToExpressions(CreateStruct(outputSerializer), resultObj :: Nil)
+ case _ =>
+ assert(outputSerializer.length == 1)
+ outputSerializer.head transform {
+ case b: BoundReference => resultObj
+ }
}
}
override def toString: String = {
- s"""${aggregator.getClass.getSimpleName}(${children.mkString(",")})"""
+ val input = inputDeserializer match {
+ case Some(UnresolvedDeserializer(deserializer, _)) => deserializer.dataType.simpleString
+ case Some(deserializer) => deserializer.dataType.simpleString
+ case _ => "unknown"
+ }
+
+ s"$nodeName($input)"
}
- override def nodeName: String = aggregator.getClass.getSimpleName
+ override def nodeName: String = aggregator.getClass.getSimpleName.stripSuffix("$")
}
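Not part of the commit: a rough sketch of why the buffer is wrapped into a single struct field (see the comment in `TypedAggregateExpression.apply` above). A flat encoder already serializes to one column, while a product encoder produces one serializer expression per field, each of which would re-evaluate the shared deserializer-plus-`reduce` sub-expression if used directly. Assumes the spark-catalyst module on the classpath; names and printed comments are indicative:

```scala
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

object BufferEncodingSketch {
  def main(args: Array[String]): Unit = {
    // A flat encoder (e.g. Long) serializes straight into a single column.
    val flat = ExpressionEncoder[Long]()
    println(flat.flat)               // true
    println(flat.namedExpressions)   // a single named serializer expression

    // A product encoder (e.g. a pair) produces several serializer expressions, one per field;
    // the patch wraps them in CreateStruct so the aggregation buffer is one struct field.
    val product = ExpressionEncoder[(Long, Long)]()
    println(product.flat)            // false
    println(product.serializer)      // multiple serializer expressions, one per field
  }
}
```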
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
index 7da8379c9a..baae9dd2d5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
@@ -17,14 +17,14 @@
package org.apache.spark.sql.expressions
-import org.apache.spark.sql.{DataFrame, Dataset, Encoder, TypedColumn}
+import org.apache.spark.sql.{Dataset, Encoder, TypedColumn}
import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete}
import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
/**
- * A base class for user-defined aggregations, which can be used in [[DataFrame]] and [[Dataset]]
- * operations to take all of the elements of a group and reduce them to a single value.
+ * A base class for user-defined aggregations, which can be used in [[Dataset]] operations to take
+ * all of the elements of a group and reduce them to a single value.
*
* For example, the following aggregator extracts an `int` from a specific class and adds them up:
* {{{
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala
index 5f3dd906fe..ae9fb80c68 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala
@@ -18,6 +18,9 @@
package org.apache.spark.sql
import org.apache.spark.SparkContext
+import org.apache.spark.sql.expressions.Aggregator
+import org.apache.spark.sql.expressions.scala.typed
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.StringType
import org.apache.spark.util.Benchmark
@@ -33,16 +36,17 @@ object DatasetBenchmark {
val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s"))
val benchmark = new Benchmark("back-to-back map", numRows)
-
val func = (d: Data) => Data(d.l + 1, d.s)
- benchmark.addCase("Dataset") { iter =>
- var res = df.as[Data]
+
+ val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString))
+ benchmark.addCase("RDD") { iter =>
+ var res = rdd
var i = 0
while (i < numChains) {
- res = res.map(func)
+ res = rdd.map(func)
i += 1
}
- res.queryExecution.toRdd.foreach(_ => Unit)
+ res.foreach(_ => Unit)
}
benchmark.addCase("DataFrame") { iter =>
@@ -55,15 +59,14 @@ object DatasetBenchmark {
res.queryExecution.toRdd.foreach(_ => Unit)
}
- val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString))
- benchmark.addCase("RDD") { iter =>
- var res = rdd
+ benchmark.addCase("Dataset") { iter =>
+ var res = df.as[Data]
var i = 0
while (i < numChains) {
- res = rdd.map(func)
+ res = res.map(func)
i += 1
}
- res.foreach(_ => Unit)
+ res.queryExecution.toRdd.foreach(_ => Unit)
}
benchmark
@@ -74,19 +77,20 @@ object DatasetBenchmark {
val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s"))
val benchmark = new Benchmark("back-to-back filter", numRows)
-
val func = (d: Data, i: Int) => d.l % (100L + i) == 0L
val funcs = 0.until(numChains).map { i =>
(d: Data) => func(d, i)
}
- benchmark.addCase("Dataset") { iter =>
- var res = df.as[Data]
+
+ val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString))
+ benchmark.addCase("RDD") { iter =>
+ var res = rdd
var i = 0
while (i < numChains) {
- res = res.filter(funcs(i))
+ res = rdd.filter(funcs(i))
i += 1
}
- res.queryExecution.toRdd.foreach(_ => Unit)
+ res.foreach(_ => Unit)
}
benchmark.addCase("DataFrame") { iter =>
@@ -99,15 +103,54 @@ object DatasetBenchmark {
res.queryExecution.toRdd.foreach(_ => Unit)
}
- val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString))
- benchmark.addCase("RDD") { iter =>
- var res = rdd
+ benchmark.addCase("Dataset") { iter =>
+ var res = df.as[Data]
var i = 0
while (i < numChains) {
- res = rdd.filter(funcs(i))
+ res = res.filter(funcs(i))
i += 1
}
- res.foreach(_ => Unit)
+ res.queryExecution.toRdd.foreach(_ => Unit)
+ }
+
+ benchmark
+ }
+
+ object ComplexAggregator extends Aggregator[Data, Data, Long] {
+ override def zero: Data = Data(0, "")
+
+ override def reduce(b: Data, a: Data): Data = Data(b.l + a.l, "")
+
+ override def finish(reduction: Data): Long = reduction.l
+
+ override def merge(b1: Data, b2: Data): Data = Data(b1.l + b2.l, "")
+
+ override def bufferEncoder: Encoder[Data] = Encoders.product[Data]
+
+ override def outputEncoder: Encoder[Long] = Encoders.scalaLong
+ }
+
+ def aggregate(sqlContext: SQLContext, numRows: Long): Benchmark = {
+ import sqlContext.implicits._
+
+ val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s"))
+ val benchmark = new Benchmark("aggregate", numRows)
+
+ val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString))
+ benchmark.addCase("RDD sum") { iter =>
+ rdd.aggregate(0L)(_ + _.l, _ + _)
+ }
+
+ benchmark.addCase("DataFrame sum") { iter =>
+ df.select(sum($"l")).queryExecution.toRdd.foreach(_ => Unit)
+ }
+
+ benchmark.addCase("Dataset sum using Aggregator") { iter =>
+ df.as[Data].select(typed.sumLong((d: Data) => d.l)).queryExecution.toRdd.foreach(_ => Unit)
+ }
+
+ benchmark.addCase("Dataset complex Aggregator") { iter =>
+ df.as[Data].select(ComplexAggregator.toColumn).queryExecution.toRdd.foreach(_ => Unit)
}
benchmark
@@ -117,30 +160,45 @@ object DatasetBenchmark {
val sparkContext = new SparkContext("local[*]", "Dataset benchmark")
val sqlContext = new SQLContext(sparkContext)
- val numRows = 10000000
+ val numRows = 100000000
val numChains = 10
val benchmark = backToBackMap(sqlContext, numRows, numChains)
val benchmark2 = backToBackFilter(sqlContext, numRows, numChains)
+ val benchmark3 = aggregate(sqlContext, numRows)
/*
Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.11.4
Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz
back-to-back map: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
- Dataset 902 / 995 11.1 90.2 1.0X
- DataFrame 132 / 167 75.5 13.2 6.8X
- RDD 216 / 237 46.3 21.6 4.2X
+ RDD 1935 / 2105 51.7 19.3 1.0X
+ DataFrame 756 / 799 132.3 7.6 2.6X
+ Dataset 7359 / 7506 13.6 73.6 0.3X
*/
benchmark.run()
/*
+ Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.11.4
+ Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz
back-to-back filter: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
- Dataset 585 / 628 17.1 58.5 1.0X
- DataFrame 62 / 80 160.7 6.2 9.4X
- RDD 205 / 220 48.7 20.5 2.8X
+ RDD 1974 / 2036 50.6 19.7 1.0X
+ DataFrame 103 / 127 967.4 1.0 19.1X
+ Dataset 4343 / 4477 23.0 43.4 0.5X
*/
benchmark2.run()
+
+ /*
+ Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.11.4
+ Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz
+ aggregate: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ -------------------------------------------------------------------------------------------
+ RDD sum 2130 / 2166 46.9 21.3 1.0X
+ DataFrame sum 92 / 128 1085.3 0.9 23.1X
+ Dataset sum using Aggregator 4111 / 4282 24.3 41.1 0.5X
+ Dataset complex Aggregator 8782 / 9036 11.4 87.8 0.2X
+ */
+ benchmark3.run()
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 826862835a..23a0ce215f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.execution.LogicalRDD
+import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
import org.apache.spark.sql.execution.columnar.InMemoryRelation
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.streaming.MemoryPlan
@@ -205,6 +206,7 @@ abstract class QueryTest extends PlanTest {
case _: MemoryPlan => return
}.transformAllExpressions {
case a: ImperativeAggregate => return
+ case _: TypedAggregateExpression => return
case Literal(_, _: ObjectType) => return
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index 4474cfcf6e..8efd9de29e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.aggregate.TungstenAggregate
import org.apache.spark.sql.execution.joins.BroadcastHashJoin
+import org.apache.spark.sql.expressions.scala.typed
import org.apache.spark.sql.functions.{avg, broadcast, col, max}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
@@ -99,4 +100,17 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[SerializeFromObject]).isDefined)
assert(ds.collect() === Array(0, 6))
}
+
+ test("simple typed UDAF should be included in WholeStageCodegen") {
+ import testImplicits._
+
+ val ds = Seq(("a", 10), ("b", 1), ("b", 2), ("c", 1)).toDS()
+ .groupByKey(_._1).agg(typed.sum(_._2))
+
+ val plan = ds.queryExecution.executedPlan
+ assert(plan.find(p =>
+ p.isInstanceOf[WholeStageCodegen] &&
+ p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[TungstenAggregate]).isDefined)
+ assert(ds.collect() === Array(("a", 10.0), ("b", 3.0), ("c", 1.0)))
+ }
}