author    Michael Armbrust <michael@databricks.com>  2015-11-09 16:11:00 -0800
committer Michael Armbrust <michael@databricks.com>  2015-11-09 16:11:00 -0800
commit    9c740a9ddf6344a03b4b45380eaf0cfc6e2299b5 (patch)
tree      c821f0b8bbcce9410bdc5b54968251f8bdfb0b6a /sql
parent    2f38378856fb56bdd9be7ccedf56427e81701f4e (diff)
[SPARK-11578][SQL] User API for Typed Aggregation
This PR adds a new interface for user-defined aggregations that can be used in `DataFrame` and `Dataset` operations to take all of the elements of a group and reduce them to a single value. For example, the following aggregator extracts an `int` from a specific class and adds them up:

```scala
case class Data(i: Int)

val customSummer = new Aggregator[Data, Int, Int] {
  def zero = 0
  def reduce(b: Int, a: Data) = b + a.i
  def present(r: Int) = r
}.toColumn()

val ds: Dataset[Data] = ...
val aggregated = ds.select(customSummer)
```

By using helper functions, users can make a generic `Aggregator` that works on any input type:

```scala
/** An `Aggregator` that adds up any numeric type returned by the given function. */
class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] with Serializable {
  val numeric = implicitly[Numeric[N]]
  override def zero: N = numeric.zero
  override def reduce(b: N, a: I): N = numeric.plus(b, f(a))
  override def present(reduction: N): N = reduction
}

def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] = new SumOf(f).toColumn
```

These aggregators can then be used alongside other built-in SQL aggregations:

```scala
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()

ds
  .groupBy(_._1)
  .agg(
    sum(_._2),                // The aggregator defined above.
    expr("sum(_2)").as[Int],  // A built-in dynamically typed aggregation.
    count("*"))               // A built-in statically typed aggregation.
  .collect()

res0: ("a", 30, 30, 2L), ("b", 3, 3, 2L), ("c", 1, 1, 1L)
```

The current implementation focuses on integrating this into the typed API, but currently only supports running aggregations that return a single long value, as explained in `TypedAggregateExpression`. This will be improved in a follow-up PR.

Author: Michael Armbrust <michael@databricks.com>

Closes #9555 from marmbrus/dataset-useragg.
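As a further illustration of the `Aggregator[-A, B, C]` contract introduced here (`zero` / `reduce` / `present`), below is a minimal sketch of another aggregator. The `Purchase` class and `maxAmount` helper are hypothetical names used only for illustration, not part of this change, and whether it runs end-to-end is subject to the single-long-value limitation described in `TypedAggregateExpression`.

```scala
import org.apache.spark.sql.TypedColumn
import org.apache.spark.sql.catalyst.encoders.Encoder
import org.apache.spark.sql.expressions.Aggregator

// Hypothetical input type, used only for illustration.
case class Purchase(user: String, amount: Long)

/** Keeps the largest `amount` seen in a group. `Long.MinValue` acts as the zero,
 *  since max(b, Long.MinValue) == b for any buffer value b. */
class MaxAmount extends Aggregator[Purchase, Long, Long] with Serializable {
  override def zero: Long = Long.MinValue
  override def reduce(b: Long, a: Purchase): Long = math.max(b, a.amount)
  override def present(reduction: Long): Long = reduction
}

// toColumn requires implicit Encoders for the buffer and result types, e.g. from
// sqlContext.implicits._ (or an Encoder[Long] like the one defined in functions.scala below).
def maxAmount(implicit enc: Encoder[Long]): TypedColumn[Purchase, Long] =
  new MaxAmount().toColumn

// Usage, mirroring DatasetAggregatorSuite:
//   val ds: Dataset[Purchase] = ...
//   ds.groupBy(_.user).agg(maxAmount, count("*"))
```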
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Column.scala                                         |  11
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala                                        |  30
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala                                 |  51
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala                                     |   1
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala   | 129
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala                         |  81
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/functions.scala                                      |  30
-rw-r--r--  sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java                            |   4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala                         |  65
9 files changed, 360 insertions, 42 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index c32c93897c..d26b6c3579 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -23,7 +23,7 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.Logging
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.encoders.Encoder
+import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.DataTypeParser
import org.apache.spark.sql.types._
@@ -39,10 +39,13 @@ private[sql] object Column {
}
/**
- * A [[Column]] where an [[Encoder]] has been given for the expected return type.
+ * A [[Column]] where an [[Encoder]] has been given for the expected input and return type.
* @since 1.6.0
+ * @tparam T The input type expected for this expression. Can be `Any` if the expression is type
+ * checked by the analyzer instead of the compiler (i.e. `expr("sum(...)")`).
+ * @tparam U The output type of this column.
*/
-class TypedColumn[T](expr: Expression)(implicit val encoder: Encoder[T]) extends Column(expr)
+class TypedColumn[-T, U](expr: Expression, val encoder: Encoder[U]) extends Column(expr)
/**
* :: Experimental ::
@@ -85,7 +88,7 @@ class Column(protected[sql] val expr: Expression) extends Logging {
* results into the correct JVM types.
* @since 1.6.0
*/
- def as[T : Encoder]: TypedColumn[T] = new TypedColumn[T](expr)
+ def as[U : Encoder]: TypedColumn[Any, U] = new TypedColumn[Any, U](expr, encoderFor[U])
/**
* Extracts a value or values from a complex type.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 959e0f5ba0..6d2968e288 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -358,7 +358,7 @@ class Dataset[T] private[sql](
* }}}
* @since 1.6.0
*/
- def select[U1: Encoder](c1: TypedColumn[U1]): Dataset[U1] = {
+ def select[U1: Encoder](c1: TypedColumn[T, U1]): Dataset[U1] = {
new Dataset[U1](sqlContext, Project(Alias(c1.expr, "_1")() :: Nil, logicalPlan))
}
@@ -367,7 +367,7 @@ class Dataset[T] private[sql](
* code reuse, we do this without the help of the type system and then use helper functions
* that cast appropriately for the user facing interface.
*/
- protected def selectUntyped(columns: TypedColumn[_]*): Dataset[_] = {
+ protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
val aliases = columns.zipWithIndex.map { case (c, i) => Alias(c.expr, s"_${i + 1}")() }
val unresolvedPlan = Project(aliases, logicalPlan)
val execution = new QueryExecution(sqlContext, unresolvedPlan)
@@ -385,7 +385,7 @@ class Dataset[T] private[sql](
* Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
* @since 1.6.0
*/
- def select[U1, U2](c1: TypedColumn[U1], c2: TypedColumn[U2]): Dataset[(U1, U2)] =
+ def select[U1, U2](c1: TypedColumn[T, U1], c2: TypedColumn[T, U2]): Dataset[(U1, U2)] =
selectUntyped(c1, c2).asInstanceOf[Dataset[(U1, U2)]]
/**
@@ -393,9 +393,9 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
def select[U1, U2, U3](
- c1: TypedColumn[U1],
- c2: TypedColumn[U2],
- c3: TypedColumn[U3]): Dataset[(U1, U2, U3)] =
+ c1: TypedColumn[T, U1],
+ c2: TypedColumn[T, U2],
+ c3: TypedColumn[T, U3]): Dataset[(U1, U2, U3)] =
selectUntyped(c1, c2, c3).asInstanceOf[Dataset[(U1, U2, U3)]]
/**
@@ -403,10 +403,10 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
def select[U1, U2, U3, U4](
- c1: TypedColumn[U1],
- c2: TypedColumn[U2],
- c3: TypedColumn[U3],
- c4: TypedColumn[U4]): Dataset[(U1, U2, U3, U4)] =
+ c1: TypedColumn[T, U1],
+ c2: TypedColumn[T, U2],
+ c3: TypedColumn[T, U3],
+ c4: TypedColumn[T, U4]): Dataset[(U1, U2, U3, U4)] =
selectUntyped(c1, c2, c3, c4).asInstanceOf[Dataset[(U1, U2, U3, U4)]]
/**
@@ -414,11 +414,11 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
def select[U1, U2, U3, U4, U5](
- c1: TypedColumn[U1],
- c2: TypedColumn[U2],
- c3: TypedColumn[U3],
- c4: TypedColumn[U4],
- c5: TypedColumn[U5]): Dataset[(U1, U2, U3, U4, U5)] =
+ c1: TypedColumn[T, U1],
+ c2: TypedColumn[T, U2],
+ c3: TypedColumn[T, U3],
+ c4: TypedColumn[T, U4],
+ c5: TypedColumn[T, U5]): Dataset[(U1, U2, U3, U4, U5)] =
selectUntyped(c1, c2, c3, c4, c5).asInstanceOf[Dataset[(U1, U2, U3, U4, U5)]]
/* **************** *
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index 850315e281..db61499229 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql
import java.util.{Iterator => JIterator}
+
import scala.collection.JavaConverters._
import org.apache.spark.annotation.Experimental
@@ -26,8 +27,10 @@ import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttrib
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor, Encoder}
import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression, Alias, Attribute}
import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
import org.apache.spark.sql.execution.QueryExecution
+
/**
* :: Experimental ::
* A [[Dataset]] has been logically grouped by a user specified grouping key. Users should not
@@ -143,7 +146,7 @@ class GroupedDataset[K, T] private[sql](
* that cast appropriately for the user facing interface.
* TODO: does not handle aggregations that return nonflat results,
*/
- protected def aggUntyped(columns: TypedColumn[_]*): Dataset[_] = {
+ protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
val aliases = (groupingAttributes ++ columns.map(_.expr)).map {
case u: UnresolvedAttribute => UnresolvedAlias(u)
case expr: NamedExpression => expr
@@ -151,7 +154,15 @@ class GroupedDataset[K, T] private[sql](
}
val unresolvedPlan = Aggregate(groupingAttributes, aliases, logicalPlan)
- val execution = new QueryExecution(sqlContext, unresolvedPlan)
+
+ // Fill in the input encoders for any aggregators in the plan.
+ val withEncoders = unresolvedPlan transformAllExpressions {
+ case ta: TypedAggregateExpression if ta.aEncoder.isEmpty =>
+ ta.copy(
+ aEncoder = Some(tEnc.asInstanceOf[ExpressionEncoder[Any]]),
+ children = dataAttributes)
+ }
+ val execution = new QueryExecution(sqlContext, withEncoders)
val columnEncoders = columns.map(_.encoder.asInstanceOf[ExpressionEncoder[_]])
@@ -162,43 +173,47 @@ class GroupedDataset[K, T] private[sql](
case (e, a) =>
e.unbind(a :: Nil).resolve(execution.analyzed.output)
}
- new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders))
+
+ new Dataset(
+ sqlContext,
+ execution,
+ ExpressionEncoder.tuple(encoders))
}
/**
* Computes the given aggregation, returning a [[Dataset]] of tuples for each unique key
* and the result of computing this aggregation over all elements in the group.
*/
- def agg[A1](col1: TypedColumn[A1]): Dataset[(K, A1)] =
- aggUntyped(col1).asInstanceOf[Dataset[(K, A1)]]
+ def agg[U1](col1: TypedColumn[T, U1]): Dataset[(K, U1)] =
+ aggUntyped(col1).asInstanceOf[Dataset[(K, U1)]]
/**
* Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key
* and the result of computing these aggregations over all elements in the group.
*/
- def agg[A1, A2](col1: TypedColumn[A1], col2: TypedColumn[A2]): Dataset[(K, A1, A2)] =
- aggUntyped(col1, col2).asInstanceOf[Dataset[(K, A1, A2)]]
+ def agg[U1, U2](col1: TypedColumn[T, U1], col2: TypedColumn[T, U2]): Dataset[(K, U1, U2)] =
+ aggUntyped(col1, col2).asInstanceOf[Dataset[(K, U1, U2)]]
/**
* Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key
* and the result of computing these aggregations over all elements in the group.
*/
- def agg[A1, A2, A3](
- col1: TypedColumn[A1],
- col2: TypedColumn[A2],
- col3: TypedColumn[A3]): Dataset[(K, A1, A2, A3)] =
- aggUntyped(col1, col2, col3).asInstanceOf[Dataset[(K, A1, A2, A3)]]
+ def agg[U1, U2, U3](
+ col1: TypedColumn[T, U1],
+ col2: TypedColumn[T, U2],
+ col3: TypedColumn[T, U3]): Dataset[(K, U1, U2, U3)] =
+ aggUntyped(col1, col2, col3).asInstanceOf[Dataset[(K, U1, U2, U3)]]
/**
* Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key
* and the result of computing these aggregations over all elements in the group.
*/
- def agg[A1, A2, A3, A4](
- col1: TypedColumn[A1],
- col2: TypedColumn[A2],
- col3: TypedColumn[A3],
- col4: TypedColumn[A4]): Dataset[(K, A1, A2, A3, A4)] =
- aggUntyped(col1, col2, col3, col4).asInstanceOf[Dataset[(K, A1, A2, A3, A4)]]
+ def agg[U1, U2, U3, U4](
+ col1: TypedColumn[T, U1],
+ col2: TypedColumn[T, U2],
+ col3: TypedColumn[T, U3],
+ col4: TypedColumn[T, U4]): Dataset[(K, U1, U2, U3, U4)] =
+ aggUntyped(col1, col2, col3, col4).asInstanceOf[Dataset[(K, U1, U2, U3, U4)]]
/**
* Returns a [[Dataset]] that contains a tuple with each key and the number of items present
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 5598731af5..1cf1e30f96 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -21,7 +21,6 @@ import java.beans.{BeanInfo, Introspector}
import java.util.Properties
import java.util.concurrent.atomic.AtomicReference
-
import scala.collection.JavaConverters._
import scala.collection.immutable
import scala.reflect.runtime.universe.TypeTag
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
new file mode 100644
index 0000000000..24d8122b62
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import scala.language.existentials
+
+import org.apache.spark.Logging
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder}
+import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate
+import org.apache.spark.sql.expressions.Aggregator
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types.{StructType, DataType}
+
+object TypedAggregateExpression {
+ def apply[A, B : Encoder, C : Encoder](
+ aggregator: Aggregator[A, B, C]): TypedAggregateExpression = {
+ new TypedAggregateExpression(
+ aggregator.asInstanceOf[Aggregator[Any, Any, Any]],
+ None,
+ encoderFor[B].asInstanceOf[ExpressionEncoder[Any]],
+ encoderFor[C].asInstanceOf[ExpressionEncoder[Any]],
+ Nil,
+ 0,
+ 0)
+ }
+}
+
+/**
+ * This class is a rough sketch of how to hook `Aggregator` into the Aggregation system. It has
+ * the following limitations:
+ * - It assumes the aggregator reduces and returns a single column of type `long`.
+ * - It might only work when there is a single aggregator in the first column.
+ * - It assumes the aggregator has a zero, `0`.
+ */
+case class TypedAggregateExpression(
+ aggregator: Aggregator[Any, Any, Any],
+ aEncoder: Option[ExpressionEncoder[Any]],
+ bEncoder: ExpressionEncoder[Any],
+ cEncoder: ExpressionEncoder[Any],
+ children: Seq[Expression],
+ mutableAggBufferOffset: Int,
+ inputAggBufferOffset: Int)
+ extends ImperativeAggregate with Logging {
+
+ override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+ copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+ override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+ copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+ override def nullable: Boolean = true
+
+ // TODO: this assumes flat results...
+ override def dataType: DataType = cEncoder.schema.head.dataType
+
+ override def deterministic: Boolean = true
+
+ override lazy val resolved: Boolean = aEncoder.isDefined
+
+ override lazy val inputTypes: Seq[DataType] =
+ aEncoder.map(_.schema.map(_.dataType)).getOrElse(Nil)
+
+ override val aggBufferSchema: StructType = bEncoder.schema
+
+ override val aggBufferAttributes: Seq[AttributeReference] = aggBufferSchema.toAttributes
+
+ // Note: although this simply copies aggBufferAttributes, this common code can not be placed
+ // in the superclass because that will lead to initialization ordering issues.
+ override val inputAggBufferAttributes: Seq[AttributeReference] =
+ aggBufferAttributes.map(_.newInstance())
+
+ lazy val inputAttributes = aEncoder.get.schema.toAttributes
+ lazy val inputMapping = AttributeMap(inputAttributes.zip(children))
+ lazy val boundA =
+ aEncoder.get.copy(constructExpression = aEncoder.get.constructExpression transform {
+ case a: AttributeReference => inputMapping(a)
+ })
+
+ // TODO: this probably only works when we are in the first column.
+ val bAttributes = bEncoder.schema.toAttributes
+ lazy val boundB = bEncoder.resolve(bAttributes).bind(bAttributes)
+
+ override def initialize(buffer: MutableRow): Unit = {
+ // TODO: We need to either force Aggregator to have a zero or we need to eliminate the need for
+ // this in execution.
+ buffer.setInt(mutableAggBufferOffset, aggregator.zero.asInstanceOf[Int])
+ }
+
+ override def update(buffer: MutableRow, input: InternalRow): Unit = {
+ val inputA = boundA.fromRow(input)
+ val currentB = boundB.fromRow(buffer)
+ val merged = aggregator.reduce(currentB, inputA)
+ val returned = boundB.toRow(merged)
+ buffer.setInt(mutableAggBufferOffset, returned.getInt(0))
+ }
+
+ override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
+ buffer1.setLong(
+ mutableAggBufferOffset,
+ buffer1.getLong(mutableAggBufferOffset) + buffer2.getLong(inputAggBufferOffset))
+ }
+
+ override def eval(buffer: InternalRow): Any = {
+ buffer.getInt(mutableAggBufferOffset)
+ }
+
+ override def toString: String = {
+ s"""${aggregator.getClass.getSimpleName}(${children.mkString(",")})"""
+ }
+
+ override def nodeName: String = aggregator.getClass.getSimpleName
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
new file mode 100644
index 0000000000..0b3192a6da
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.expressions
+
+import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2}
+import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
+import org.apache.spark.sql.{Dataset, DataFrame, TypedColumn}
+
+/**
+ * A base class for user-defined aggregations, which can be used in [[DataFrame]] and [[Dataset]]
+ * operations to take all of the elements of a group and reduce them to a single value.
+ *
+ * For example, the following aggregator extracts an `int` from a specific class and adds them up:
+ * {{{
+ * case class Data(i: Int)
+ *
+ * val customSummer = new Aggregator[Data, Int, Int] {
+ * def zero = 0
+ * def reduce(b: Int, a: Data) = b + a.i
+ * def present(r: Int) = r
+ * }.toColumn()
+ *
+ * val ds: Dataset[Data]
+ * val aggregated = ds.select(customSummer)
+ * }}}
+ *
+ * Based loosely on Aggregator from Algebird: https://github.com/twitter/algebird
+ *
+ * @tparam A The input type for the aggregation.
+ * @tparam B The type of the intermediate value of the reduction.
+ * @tparam C The type of the final result.
+ */
+abstract class Aggregator[-A, B, C] {
+
+ /** A zero value for this aggregation. Should satisfy the property that any b + zero = b */
+ def zero: B
+
+ /**
+ * Combine two values to produce a new value. For performance, the function may modify `b` and
+ * return it instead of constructing a new object for b.
+ */
+ def reduce(b: B, a: A): B
+
+ /**
+ * Transform the output of the reduction.
+ */
+ def present(reduction: B): C
+
+ /**
+ * Returns this `Aggregator` as a [[TypedColumn]] that can be used in [[Dataset]] or [[DataFrame]]
+ * operations.
+ */
+ def toColumn(
+ implicit bEncoder: Encoder[B],
+ cEncoder: Encoder[C]): TypedColumn[A, C] = {
+ val expr =
+ new AggregateExpression2(
+ TypedAggregateExpression(this),
+ Complete,
+ false)
+
+ new TypedColumn[A, C](expr, encoderFor[C])
+ }
+}
+
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 3f0b24b68b..6d56542ee0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql
+
+
import scala.language.implicitConversions
import scala.reflect.runtime.universe.{TypeTag, typeTag}
import scala.util.Try
@@ -24,12 +26,33 @@ import scala.util.Try
import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star}
+import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, Encoder}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
/**
+ * Ensures that Java function signatures for methods that now return a [[TypedColumn]] still have
+ * legacy equivalents in bytecode. This compatibility is done by forcing the compiler to generate
+ * "bridge" methods due to the use of covariant return types.
+ *
+ * {{{
+ * In LegacyFunctions:
+ * public abstract org.apache.spark.sql.Column avg(java.lang.String);
+ *
+ * In functions:
+ * public static org.apache.spark.sql.TypedColumn<java.lang.Object, java.lang.Object> avg(...);
+ * }}}
+ *
+ * This allows us to use the same functions both in typed [[Dataset]] operations and untyped
+ * [[DataFrame]] operations when the return type for a given function is statically known.
+ */
+private[sql] abstract class LegacyFunctions {
+ def count(columnName: String): Column
+}
+
+/**
* :: Experimental ::
* Functions available for [[DataFrame]].
*
@@ -48,11 +71,14 @@ import org.apache.spark.util.Utils
*/
@Experimental
// scalastyle:off
-object functions {
+object functions extends LegacyFunctions {
// scalastyle:on
private def withExpr(expr: Expression): Column = Column(expr)
+ private implicit def newLongEncoder: Encoder[Long] = ExpressionEncoder[Long](flat = true)
+
+
/**
* Returns a [[Column]] based on the given column name.
*
@@ -234,7 +260,7 @@ object functions {
* @group agg_funcs
* @since 1.3.0
*/
- def count(columnName: String): Column = count(Column(columnName))
+ def count(columnName: String): TypedColumn[Any, Long] = count(Column(columnName)).as[Long]
/**
* Aggregate function: returns the number of distinct items in a group.
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 312cf33e4c..2da63d1b96 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -258,8 +258,8 @@ public class JavaDatasetSuite implements Serializable {
Dataset<Integer> ds = context.createDataset(data, e.INT());
Dataset<Tuple2<Integer, String>> selected = ds.select(
- expr("value + 1").as(e.INT()),
- col("value").cast("string").as(e.STRING()));
+ expr("value + 1"),
+ col("value").cast("string")).as(e.tuple(e.INT(), e.STRING()));
Assert.assertEquals(
Arrays.asList(tuple2(3, "2"), tuple2(7, "6")),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
new file mode 100644
index 0000000000..340470c096
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.encoders.Encoder
+import org.apache.spark.sql.functions._
+
+import scala.language.postfixOps
+
+import org.apache.spark.sql.test.SharedSQLContext
+
+import org.apache.spark.sql.expressions.Aggregator
+
+/** An `Aggregator` that adds up any numeric type returned by the given function. */
+class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] with Serializable {
+ val numeric = implicitly[Numeric[N]]
+
+ override def zero: N = numeric.zero
+
+ override def reduce(b: N, a: I): N = numeric.plus(b, f(a))
+
+ override def present(reduction: N): N = reduction
+}
+
+class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
+
+ import testImplicits._
+
+ def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] =
+ new SumOf(f).toColumn
+
+ test("typed aggregation: TypedAggregator") {
+ val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+
+ checkAnswer(
+ ds.groupBy(_._1).agg(sum(_._2)),
+ ("a", 30), ("b", 3), ("c", 1))
+ }
+
+ test("typed aggregation: TypedAggregator, expr, expr") {
+ val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+
+ checkAnswer(
+ ds.groupBy(_._1).agg(
+ sum(_._2),
+ expr("sum(_2)").as[Int],
+ count("*")),
+ ("a", 30, 30, 2L), ("b", 3, 3, 2L), ("c", 1, 1, 1L))
+ }
+}
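The `LegacyFunctions` comment in functions.scala above relies on the JVM bridge methods that scalac emits for covariant return-type overrides. A minimal, standalone sketch of that mechanism follows; `Legacy` and `Functions` are made-up names unrelated to Spark's actual classes.

```scala
// Standalone illustration of the bridge-method trick; these names are hypothetical.
abstract class Legacy {
  // The old, wider signature that previously compiled byte code may link against.
  def count(columnName: String): AnyRef
}

object Functions extends Legacy {
  // Overriding with a narrower return type. scalac also emits a synthetic bridge
  // `count(String): Object` that forwards here, so callers compiled against the old
  // signature keep working -- the same idea LegacyFunctions uses for Column vs. TypedColumn.
  override def count(columnName: String): String = s"count($columnName)"
}
```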