author    Wenchen Fan <wenchen@databricks.com>    2015-11-10 11:14:25 -0800
committer Michael Armbrust <michael@databricks.com>    2015-11-10 11:14:25 -0800
commit    dfcfcbcc0448ebc6f02eba6bf0495832a321c87e
tree      ce199b98e0397da47033ec347789f56a4088f31f /sql
parent    47735cdc2a878cfdbe76316d3ff8314a45dabf54
[SPARK-11578][SQL][FOLLOW-UP] complete the user facing api for typed aggregation
Currently the user-facing API for typed aggregation has some limitations:
* the custom typed aggregation must be the first entry in the aggregation list
* the custom typed aggregation can only use Long as its buffer type
* the custom typed aggregation can only use a flat type as its result type
This PR removes these limitations.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #9599 from cloud-fan/agg.
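For context, a hedged sketch of the kind of aggregator these changes enable; the object name AverageAgg and the sample usage are invented for illustration, while the zero/reduce/merge/present API and toColumn come from this patch:

    import org.apache.spark.sql.expressions.Aggregator

    // A running average whose buffer is a (count, sum) pair -- a non-Long,
    // non-flat buffer type that the previous code could not handle.
    object AverageAgg extends Aggregator[(String, Int), (Long, Long), Double]
      with Serializable {
      override def zero: (Long, Long) = (0L, 0L)
      override def reduce(b: (Long, Long), a: (String, Int)): (Long, Long) =
        (b._1 + 1, b._2 + a._2)
      override def merge(b1: (Long, Long), b2: (Long, Long)): (Long, Long) =
        (b1._1 + b2._1, b1._2 + b2._2)
      override def present(b: (Long, Long)): Double = b._2.toDouble / b._1
    }

    // The typed aggregation no longer has to come first in the list:
    // ds.groupBy(_._1).agg(expr("avg(_2)").as[Double], AverageAgg.toColumn)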
Diffstat (limited to 'sql')
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala |  6
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala | 50
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala |  5
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala | 52
4 files changed, 99 insertions, 14 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index c287aebeee..005c0627f5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -185,6 +185,12 @@ case class ExpressionEncoder[T](
})
}
+ def shift(delta: Int): ExpressionEncoder[T] = {
+ copy(constructExpression = constructExpression transform {
+ case r: BoundReference => r.copy(ordinal = r.ordinal + delta)
+ })
+ }
+
/**
* Returns a copy of this encoder where the expressions used to create an object given an
* input row have been modified to pull the object out from a nested struct, instead of the
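The shift method added above rebinds the construct-side BoundReferences by a fixed delta, so an encoder resolved against columns [0, n) can read the same object out of columns [delta, delta + n) of a wider row. A minimal sketch of the intended use (bufferEncoder, offset value, and sharedBufferRow are illustrative names; shift and fromRow come from this patch):

    // Read a two-field buffer object from slots 3 and 4 of a shared
    // aggregation row instead of slots 0 and 1 (the offset is illustrative).
    val b = bufferEncoder.shift(3).fromRow(sharedBufferRow)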
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
index 24d8122b62..0e5bc1f9ab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
@@ -20,13 +20,13 @@ package org.apache.spark.sql.execution.aggregate
import scala.language.existentials
import org.apache.spark.Logging
+import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder}
import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate
-import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.types.{StructType, DataType}
+import org.apache.spark.sql.types._
object TypedAggregateExpression {
def apply[A, B : Encoder, C : Encoder](
@@ -67,8 +67,11 @@ case class TypedAggregateExpression(
override def nullable: Boolean = true
- // TODO: this assumes flat results...
- override def dataType: DataType = cEncoder.schema.head.dataType
+ override def dataType: DataType = if (cEncoder.flat) {
+ cEncoder.schema.head.dataType
+ } else {
+ cEncoder.schema
+ }
override def deterministic: Boolean = true
@@ -93,32 +96,51 @@ case class TypedAggregateExpression(
case a: AttributeReference => inputMapping(a)
})
- // TODO: this probably only works when we are in the first column.
val bAttributes = bEncoder.schema.toAttributes
lazy val boundB = bEncoder.resolve(bAttributes).bind(bAttributes)
+ private def updateBuffer(buffer: MutableRow, value: InternalRow): Unit = {
+ // TODO: find a neater way to assign the values.
+ var i = 0
+ while (i < aggBufferAttributes.length) {
+ aggBufferSchema(i).dataType match {
+ case IntegerType => buffer.setInt(mutableAggBufferOffset + i, value.getInt(i))
+ case LongType => buffer.setLong(mutableAggBufferOffset + i, value.getLong(i))
+ }
+ i += 1
+ }
+ }
+
override def initialize(buffer: MutableRow): Unit = {
- // TODO: We need to either force Aggregator to have a zero or we need to eliminate the need for
- // this in execution.
- buffer.setInt(mutableAggBufferOffset, aggregator.zero.asInstanceOf[Int])
+ val zero = bEncoder.toRow(aggregator.zero)
+ updateBuffer(buffer, zero)
}
override def update(buffer: MutableRow, input: InternalRow): Unit = {
val inputA = boundA.fromRow(input)
- val currentB = boundB.fromRow(buffer)
+ val currentB = boundB.shift(mutableAggBufferOffset).fromRow(buffer)
val merged = aggregator.reduce(currentB, inputA)
val returned = boundB.toRow(merged)
- buffer.setInt(mutableAggBufferOffset, returned.getInt(0))
+
+ updateBuffer(buffer, returned)
}
override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
- buffer1.setLong(
- mutableAggBufferOffset,
- buffer1.getLong(mutableAggBufferOffset) + buffer2.getLong(inputAggBufferOffset))
+ val b1 = boundB.shift(mutableAggBufferOffset).fromRow(buffer1)
+ val b2 = boundB.shift(inputAggBufferOffset).fromRow(buffer2)
+ val merged = aggregator.merge(b1, b2)
+ val returned = boundB.toRow(merged)
+
+ updateBuffer(buffer1, returned)
}
override def eval(buffer: InternalRow): Any = {
- buffer.getInt(mutableAggBufferOffset)
+ val b = boundB.shift(mutableAggBufferOffset).fromRow(buffer)
+ val result = cEncoder.toRow(aggregator.present(b))
+ dataType match {
+ case _: StructType => result
+ case _ => result.get(0, dataType)
+ }
}
override def toString: String = {
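Taken together, the offset handling works as follows: updateBuffer writes the buffer's i-th field to slot mutableAggBufferOffset + i of the shared MutableRow, while the shifted encoder reads the fields back from those same slots; merge reads the incoming partial buffer at inputAggBufferOffset instead. A worked sketch with illustrative offsets (not from this patch):

    // Suppose other aggregates occupy slots 0-2, so this aggregate's
    // (Long, Long) buffer lives in slots 3 and 4 (mutableAggBufferOffset = 3).
    val currentB = boundB.shift(3).fromRow(buffer)      // reads slots 3 and 4
    val merged   = aggregator.reduce(currentB, inputA)  // typed update
    updateBuffer(buffer, boundB.toRow(merged))          // writes slots 3 and 4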
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
index 8cc25c2440..3c1c457e06 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
@@ -58,6 +58,11 @@ abstract class Aggregator[-A, B, C] {
def reduce(b: B, a: A): B
/**
+ * Merge two intermediate values.
+ */
+ def merge(b1: B, b2: B): B
+
+ /**
* Transform the output of the reduction.
*/
def present(reduction: B): C
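Because Spark may combine per-partition buffers in any order, a merge implementation should in practice be commutative and associative. A hedged example for a hypothetical min aggregator (not part of this patch):

    // Each buffer holds a partial minimum, so merging keeps the smaller one.
    override def merge(b1: Int, b2: Int): Int = math.min(b1, b2)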
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
index 340470c096..206095a519 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -34,9 +34,41 @@ class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] with Serializ
override def reduce(b: N, a: I): N = numeric.plus(b, f(a))
+ override def merge(b1: N, b2: N): N = numeric.plus(b1, b2)
+
override def present(reduction: N): N = reduction
}
+object TypedAverage extends Aggregator[(String, Int), (Long, Long), Double] with Serializable {
+ override def zero: (Long, Long) = (0, 0)
+
+ override def reduce(countAndSum: (Long, Long), input: (String, Int)): (Long, Long) = {
+ (countAndSum._1 + 1, countAndSum._2 + input._2)
+ }
+
+ override def merge(b1: (Long, Long), b2: (Long, Long)): (Long, Long) = {
+ (b1._1 + b2._1, b1._2 + b2._2)
+ }
+
+ override def present(countAndSum: (Long, Long)): Double = countAndSum._2.toDouble / countAndSum._1
+}
+
+object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, Long)]
+ with Serializable {
+
+ override def zero: (Long, Long) = (0, 0)
+
+ override def reduce(countAndSum: (Long, Long), input: (String, Int)): (Long, Long) = {
+ (countAndSum._1 + 1, countAndSum._2 + input._2)
+ }
+
+ override def merge(b1: (Long, Long), b2: (Long, Long)): (Long, Long) = {
+ (b1._1 + b2._1, b1._2 + b2._2)
+ }
+
+ override def present(reduction: (Long, Long)): (Long, Long) = reduction
+}
+
class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
import testImplicits._
@@ -62,4 +94,24 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
count("*")),
("a", 30, 30, 2L), ("b", 3, 3, 2L), ("c", 1, 1, 1L))
}
+
+ test("typed aggregation: complex case") {
+ val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS()
+
+ checkAnswer(
+ ds.groupBy(_._1).agg(
+ expr("avg(_2)").as[Double],
+ TypedAverage.toColumn),
+ ("a", 2.0, 2.0), ("b", 3.0, 3.0))
+ }
+
+ test("typed aggregation: complex result type") {
+ val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS()
+
+ checkAnswer(
+ ds.groupBy(_._1).agg(
+ expr("avg(_2)").as[Double],
+ ComplexResultAgg.toColumn),
+ ("a", 2.0, (2L, 4L)), ("b", 3.0, (1L, 3L)))
+ }
}