aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorWenchen Fan <wenchen@databricks.com>2015-11-23 10:39:33 -0800
committerMichael Armbrust <michael@databricks.com>2015-11-23 10:39:33 -0800
commit946b406519af58c79041217e6f93854b6cf80acd (patch)
tree81d03b4d8c5884da8259550a190b363adbf83d78
parentf2996e0d12eeb989b1bfa51a3f6fa54ce1ed4fca (diff)
downloadspark-946b406519af58c79041217e6f93854b6cf80acd.tar.gz
spark-946b406519af58c79041217e6f93854b6cf80acd.tar.bz2
spark-946b406519af58c79041217e6f93854b6cf80acd.zip
[SPARK-11913][SQL] support typed aggregate with complex buffer schema
Author: Wenchen Fan <wenchen@databricks.com> Closes #9898 from cloud-fan/agg.
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala25
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala41
2 files changed, 56 insertions, 10 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
index 6ce41aaf01..a9719128a6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
@@ -23,9 +23,8 @@ import org.apache.spark.Logging
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.encoderFor
+import org.apache.spark.sql.catalyst.encoders.{OuterScopes, encoderFor, ExpressionEncoder}
import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
@@ -46,14 +45,12 @@ object TypedAggregateExpression {
/**
* This class is a rough sketch of how to hook `Aggregator` into the Aggregation system. It has
* the following limitations:
- * - It assumes the aggregator reduces and returns a single column of type `long`.
- * - It might only work when there is a single aggregator in the first column.
* - It assumes the aggregator has a zero, `0`.
*/
case class TypedAggregateExpression(
aggregator: Aggregator[Any, Any, Any],
aEncoder: Option[ExpressionEncoder[Any]], // Should be bound.
- bEncoder: ExpressionEncoder[Any], // Should be bound.
+ unresolvedBEncoder: ExpressionEncoder[Any],
cEncoder: ExpressionEncoder[Any],
children: Seq[Attribute],
mutableAggBufferOffset: Int,
@@ -80,10 +77,14 @@ case class TypedAggregateExpression(
override lazy val inputTypes: Seq[DataType] = Nil
- override val aggBufferSchema: StructType = bEncoder.schema
+ override val aggBufferSchema: StructType = unresolvedBEncoder.schema
override val aggBufferAttributes: Seq[AttributeReference] = aggBufferSchema.toAttributes
+ val bEncoder = unresolvedBEncoder
+ .resolve(aggBufferAttributes, OuterScopes.outerScopes)
+ .bind(aggBufferAttributes)
+
// Note: although this simply copies aggBufferAttributes, this common code can not be placed
// in the superclass because that will lead to initialization ordering issues.
override val inputAggBufferAttributes: Seq[AttributeReference] =
@@ -93,12 +94,18 @@ case class TypedAggregateExpression(
lazy val boundA = aEncoder.get
private def updateBuffer(buffer: MutableRow, value: InternalRow): Unit = {
- // todo: need a more neat way to assign the value.
var i = 0
while (i < aggBufferAttributes.length) {
+ val offset = mutableAggBufferOffset + i
aggBufferSchema(i).dataType match {
- case IntegerType => buffer.setInt(mutableAggBufferOffset + i, value.getInt(i))
- case LongType => buffer.setLong(mutableAggBufferOffset + i, value.getLong(i))
+ case BooleanType => buffer.setBoolean(offset, value.getBoolean(i))
+ case ByteType => buffer.setByte(offset, value.getByte(i))
+ case ShortType => buffer.setShort(offset, value.getShort(i))
+ case IntegerType => buffer.setInt(offset, value.getInt(i))
+ case LongType => buffer.setLong(offset, value.getLong(i))
+ case FloatType => buffer.setFloat(offset, value.getFloat(i))
+ case DoubleType => buffer.setDouble(offset, value.getDouble(i))
+ case other => buffer.update(offset, value.get(i, other))
}
i += 1
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
index 9377589790..19dce5d1e2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -67,7 +67,7 @@ object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, L
}
case class AggData(a: Int, b: String)
-object ClassInputAgg extends Aggregator[AggData, Int, Int] with Serializable {
+object ClassInputAgg extends Aggregator[AggData, Int, Int] {
/** A zero value for this aggregation. Should satisfy the property that any b + zero = b */
override def zero: Int = 0
@@ -88,6 +88,28 @@ object ClassInputAgg extends Aggregator[AggData, Int, Int] with Serializable {
override def merge(b1: Int, b2: Int): Int = b1 + b2
}
+object ComplexBufferAgg extends Aggregator[AggData, (Int, AggData), Int] {
+ /** A zero value for this aggregation. Should satisfy the property that any b + zero = b */
+ override def zero: (Int, AggData) = 0 -> AggData(0, "0")
+
+ /**
+ * Combine two values to produce a new value. For performance, the function may modify `b` and
+ * return it instead of constructing new object for b.
+ */
+ override def reduce(b: (Int, AggData), a: AggData): (Int, AggData) = (b._1 + 1, a)
+
+ /**
+ * Transform the output of the reduction.
+ */
+ override def finish(reduction: (Int, AggData)): Int = reduction._1
+
+ /**
+ * Merge two intermediate values
+ */
+ override def merge(b1: (Int, AggData), b2: (Int, AggData)): (Int, AggData) =
+ (b1._1 + b2._1, b1._2)
+}
+
class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
import testImplicits._
@@ -168,4 +190,21 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
ds.groupBy(_.b).agg(ClassInputAgg.toColumn),
("one", 1))
}
+
+ test("typed aggregation: complex input") {
+ val ds = Seq(AggData(1, "one"), AggData(2, "two")).toDS()
+
+ checkAnswer(
+ ds.select(ComplexBufferAgg.toColumn),
+ 2
+ )
+
+ checkAnswer(
+ ds.select(expr("avg(a)").as[Double], ComplexBufferAgg.toColumn),
+ (1.5, 2))
+
+ checkAnswer(
+ ds.groupBy(_.b).agg(ComplexBufferAgg.toColumn),
+ ("one", 1), ("two", 1))
+ }
}