From 9c740a9ddf6344a03b4b45380eaf0cfc6e2299b5 Mon Sep 17 00:00:00 2001
From: Michael Armbrust
Date: Mon, 9 Nov 2015 16:11:00 -0800
Subject: [SPARK-11578][SQL] User API for Typed Aggregation

This PR adds a new interface for user-defined aggregations that can be used in
`DataFrame` and `Dataset` operations to take all of the elements of a group and
reduce them to a single value. For example, the following aggregator extracts
an `int` from a specific class and sums the extracted values:

```scala
case class Data(i: Int)

val customSummer = new Aggregator[Data, Int, Int] {
  def zero = 0
  def reduce(b: Int, a: Data) = b + a.i
  def present(r: Int) = r
}.toColumn()

val ds: Dataset[Data] = ...
val aggregated = ds.select(customSummer)
```

By using helper functions, users can make a generic `Aggregator` that works on
any input type:

```scala
/** An `Aggregator` that adds up any numeric type returned by the given function. */
class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] with Serializable {
  val numeric = implicitly[Numeric[N]]

  override def zero: N = numeric.zero

  override def reduce(b: N, a: I): N = numeric.plus(b, f(a))

  override def present(reduction: N): N = reduction
}

def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] = new SumOf(f).toColumn
```

These aggregators can then be used alongside other built-in SQL aggregations:

```scala
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()

ds
  .groupBy(_._1)
  .agg(
    sum(_._2),                // The aggregator defined above.
    expr("sum(_2)").as[Int],  // A built-in dynamically typed aggregation.
    count("*"))               // A built-in statically typed aggregation.
  .collect()

res0: ("a", 30, 30, 2L), ("b", 3, 3, 2L), ("c", 1, 1, 1L)
```

The current implementation focuses on integrating this into the typed API, but
it only supports aggregations that return a single long value, as explained in
`TypedAggregateExpression`. This will be improved in a follow-up PR.

Author: Michael Armbrust

Closes #9555 from marmbrus/dataset-useragg.
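As a further illustration of the contract (not part of this patch), here is a rough
sketch of a predicate-counting aggregator that stays within the current
single-numeric-column limitation described in `TypedAggregateExpression`. The
`CountIf` class and `countIf` helper are hypothetical names, and the commented
usage assumes an implicit `Int` encoder is in scope (e.g. via
`sqlContext.implicits._`):

```scala
import org.apache.spark.sql.TypedColumn
import org.apache.spark.sql.catalyst.encoders.Encoder
import org.apache.spark.sql.expressions.Aggregator

/** Counts the elements of a group that satisfy the given predicate. */
class CountIf[I](p: I => Boolean) extends Aggregator[I, Int, Int] with Serializable {
  override def zero: Int = 0
  override def reduce(b: Int, a: I): Int = if (p(a)) b + 1 else b
  override def present(reduction: Int): Int = reduction
}

def countIf[I](p: I => Boolean)(implicit enc: Encoder[Int]): TypedColumn[I, Int] =
  new CountIf(p).toColumn

// Hypothetical usage with the Dataset from the example above:
// ds.groupBy(_._1).agg(countIf[(String, Int)](_._2 > 5))
```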
--- .../main/scala/org/apache/spark/sql/Column.scala | 11 +- .../main/scala/org/apache/spark/sql/Dataset.scala | 30 ++--- .../org/apache/spark/sql/GroupedDataset.scala | 51 +++++--- .../scala/org/apache/spark/sql/SQLContext.scala | 1 - .../aggregate/TypedAggregateExpression.scala | 129 +++++++++++++++++++++ .../apache/spark/sql/expressions/Aggregator.scala | 81 +++++++++++++ .../scala/org/apache/spark/sql/functions.scala | 30 ++++- .../org/apache/spark/sql/JavaDatasetSuite.java | 4 +- .../apache/spark/sql/DatasetAggregatorSuite.scala | 65 +++++++++++ 9 files changed, 360 insertions(+), 42 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index c32c93897c..d26b6c3579 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -23,7 +23,7 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.Logging import org.apache.spark.sql.functions.lit import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.encoders.Encoder +import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.DataTypeParser import org.apache.spark.sql.types._ @@ -39,10 +39,13 @@ private[sql] object Column { } /** - * A [[Column]] where an [[Encoder]] has been given for the expected return type. + * A [[Column]] where an [[Encoder]] has been given for the expected input and return type. * @since 1.6.0 + * @tparam T The input type expected for this expression. Can be `Any` if the expression is type + * checked by the analyzer instead of the compiler (i.e. `expr("sum(...)")`). + * @tparam U The output type of this column. */ -class TypedColumn[T](expr: Expression)(implicit val encoder: Encoder[T]) extends Column(expr) +class TypedColumn[-T, U](expr: Expression, val encoder: Encoder[U]) extends Column(expr) /** * :: Experimental :: @@ -85,7 +88,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * results into the correct JVM types. * @since 1.6.0 */ - def as[T : Encoder]: TypedColumn[T] = new TypedColumn[T](expr) + def as[U : Encoder]: TypedColumn[Any, U] = new TypedColumn[Any, U](expr, encoderFor[U]) /** * Extracts a value or values from a complex type. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 959e0f5ba0..6d2968e288 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -358,7 +358,7 @@ class Dataset[T] private[sql]( * }}} * @since 1.6.0 */ - def select[U1: Encoder](c1: TypedColumn[U1]): Dataset[U1] = { + def select[U1: Encoder](c1: TypedColumn[T, U1]): Dataset[U1] = { new Dataset[U1](sqlContext, Project(Alias(c1.expr, "_1")() :: Nil, logicalPlan)) } @@ -367,7 +367,7 @@ class Dataset[T] private[sql]( * code reuse, we do this without the help of the type system and then use helper functions * that cast appropriately for the user facing interface. 
*/ - protected def selectUntyped(columns: TypedColumn[_]*): Dataset[_] = { + protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { val aliases = columns.zipWithIndex.map { case (c, i) => Alias(c.expr, s"_${i + 1}")() } val unresolvedPlan = Project(aliases, logicalPlan) val execution = new QueryExecution(sqlContext, unresolvedPlan) @@ -385,7 +385,7 @@ class Dataset[T] private[sql]( * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element. * @since 1.6.0 */ - def select[U1, U2](c1: TypedColumn[U1], c2: TypedColumn[U2]): Dataset[(U1, U2)] = + def select[U1, U2](c1: TypedColumn[T, U1], c2: TypedColumn[T, U2]): Dataset[(U1, U2)] = selectUntyped(c1, c2).asInstanceOf[Dataset[(U1, U2)]] /** @@ -393,9 +393,9 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def select[U1, U2, U3]( - c1: TypedColumn[U1], - c2: TypedColumn[U2], - c3: TypedColumn[U3]): Dataset[(U1, U2, U3)] = + c1: TypedColumn[T, U1], + c2: TypedColumn[T, U2], + c3: TypedColumn[T, U3]): Dataset[(U1, U2, U3)] = selectUntyped(c1, c2, c3).asInstanceOf[Dataset[(U1, U2, U3)]] /** @@ -403,10 +403,10 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def select[U1, U2, U3, U4]( - c1: TypedColumn[U1], - c2: TypedColumn[U2], - c3: TypedColumn[U3], - c4: TypedColumn[U4]): Dataset[(U1, U2, U3, U4)] = + c1: TypedColumn[T, U1], + c2: TypedColumn[T, U2], + c3: TypedColumn[T, U3], + c4: TypedColumn[T, U4]): Dataset[(U1, U2, U3, U4)] = selectUntyped(c1, c2, c3, c4).asInstanceOf[Dataset[(U1, U2, U3, U4)]] /** @@ -414,11 +414,11 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def select[U1, U2, U3, U4, U5]( - c1: TypedColumn[U1], - c2: TypedColumn[U2], - c3: TypedColumn[U3], - c4: TypedColumn[U4], - c5: TypedColumn[U5]): Dataset[(U1, U2, U3, U4, U5)] = + c1: TypedColumn[T, U1], + c2: TypedColumn[T, U2], + c3: TypedColumn[T, U3], + c4: TypedColumn[T, U4], + c5: TypedColumn[T, U5]): Dataset[(U1, U2, U3, U4, U5)] = selectUntyped(c1, c2, c3, c4, c5).asInstanceOf[Dataset[(U1, U2, U3, U4, U5)]] /* **************** * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index 850315e281..db61499229 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import java.util.{Iterator => JIterator} + import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental @@ -26,8 +27,10 @@ import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttrib import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor, Encoder} import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression, Alias, Attribute} import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression import org.apache.spark.sql.execution.QueryExecution + /** * :: Experimental :: * A [[Dataset]] has been logically grouped by a user specified grouping key. Users should not @@ -143,7 +146,7 @@ class GroupedDataset[K, T] private[sql]( * that cast appropriately for the user facing interface. 
* TODO: does not handle aggrecations that return nonflat results, */ - protected def aggUntyped(columns: TypedColumn[_]*): Dataset[_] = { + protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { val aliases = (groupingAttributes ++ columns.map(_.expr)).map { case u: UnresolvedAttribute => UnresolvedAlias(u) case expr: NamedExpression => expr @@ -151,7 +154,15 @@ class GroupedDataset[K, T] private[sql]( } val unresolvedPlan = Aggregate(groupingAttributes, aliases, logicalPlan) - val execution = new QueryExecution(sqlContext, unresolvedPlan) + + // Fill in the input encoders for any aggregators in the plan. + val withEncoders = unresolvedPlan transformAllExpressions { + case ta: TypedAggregateExpression if ta.aEncoder.isEmpty => + ta.copy( + aEncoder = Some(tEnc.asInstanceOf[ExpressionEncoder[Any]]), + children = dataAttributes) + } + val execution = new QueryExecution(sqlContext, withEncoders) val columnEncoders = columns.map(_.encoder.asInstanceOf[ExpressionEncoder[_]]) @@ -162,43 +173,47 @@ class GroupedDataset[K, T] private[sql]( case (e, a) => e.unbind(a :: Nil).resolve(execution.analyzed.output) } - new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders)) + + new Dataset( + sqlContext, + execution, + ExpressionEncoder.tuple(encoders)) } /** * Computes the given aggregation, returning a [[Dataset]] of tuples for each unique key * and the result of computing this aggregation over all elements in the group. */ - def agg[A1](col1: TypedColumn[A1]): Dataset[(K, A1)] = - aggUntyped(col1).asInstanceOf[Dataset[(K, A1)]] + def agg[U1](col1: TypedColumn[T, U1]): Dataset[(K, U1)] = + aggUntyped(col1).asInstanceOf[Dataset[(K, U1)]] /** * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key * and the result of computing these aggregations over all elements in the group. */ - def agg[A1, A2](col1: TypedColumn[A1], col2: TypedColumn[A2]): Dataset[(K, A1, A2)] = - aggUntyped(col1, col2).asInstanceOf[Dataset[(K, A1, A2)]] + def agg[U1, U2](col1: TypedColumn[T, U1], col2: TypedColumn[T, U2]): Dataset[(K, U1, U2)] = + aggUntyped(col1, col2).asInstanceOf[Dataset[(K, U1, U2)]] /** * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key * and the result of computing these aggregations over all elements in the group. */ - def agg[A1, A2, A3]( - col1: TypedColumn[A1], - col2: TypedColumn[A2], - col3: TypedColumn[A3]): Dataset[(K, A1, A2, A3)] = - aggUntyped(col1, col2, col3).asInstanceOf[Dataset[(K, A1, A2, A3)]] + def agg[U1, U2, U3]( + col1: TypedColumn[T, U1], + col2: TypedColumn[T, U2], + col3: TypedColumn[T, U3]): Dataset[(K, U1, U2, U3)] = + aggUntyped(col1, col2, col3).asInstanceOf[Dataset[(K, U1, U2, U3)]] /** * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key * and the result of computing these aggregations over all elements in the group. 
*/ - def agg[A1, A2, A3, A4]( - col1: TypedColumn[A1], - col2: TypedColumn[A2], - col3: TypedColumn[A3], - col4: TypedColumn[A4]): Dataset[(K, A1, A2, A3, A4)] = - aggUntyped(col1, col2, col3, col4).asInstanceOf[Dataset[(K, A1, A2, A3, A4)]] + def agg[U1, U2, U3, U4]( + col1: TypedColumn[T, U1], + col2: TypedColumn[T, U2], + col3: TypedColumn[T, U3], + col4: TypedColumn[T, U4]): Dataset[(K, U1, U2, U3, U4)] = + aggUntyped(col1, col2, col3, col4).asInstanceOf[Dataset[(K, U1, U2, U3, U4)]] /** * Returns a [[Dataset]] that contains a tuple with each key and the number of items present diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 5598731af5..1cf1e30f96 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -21,7 +21,6 @@ import java.beans.{BeanInfo, Introspector} import java.util.Properties import java.util.concurrent.atomic.AtomicReference - import scala.collection.JavaConverters._ import scala.collection.immutable import scala.reflect.runtime.universe.TypeTag diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala new file mode 100644 index 0000000000..24d8122b62 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import scala.language.existentials + +import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder} +import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate +import org.apache.spark.sql.expressions.Aggregator +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.{StructType, DataType} + +object TypedAggregateExpression { + def apply[A, B : Encoder, C : Encoder]( + aggregator: Aggregator[A, B, C]): TypedAggregateExpression = { + new TypedAggregateExpression( + aggregator.asInstanceOf[Aggregator[Any, Any, Any]], + None, + encoderFor[B].asInstanceOf[ExpressionEncoder[Any]], + encoderFor[C].asInstanceOf[ExpressionEncoder[Any]], + Nil, + 0, + 0) + } +} + +/** + * This class is a rough sketch of how to hook `Aggregator` into the Aggregation system. It has + * the following limitations: + * - It assumes the aggregator reduces and returns a single column of type `long`. 
+ * - It might only work when there is a single aggregator in the first column. + * - It assumes the aggregator has a zero, `0`. + */ +case class TypedAggregateExpression( + aggregator: Aggregator[Any, Any, Any], + aEncoder: Option[ExpressionEncoder[Any]], + bEncoder: ExpressionEncoder[Any], + cEncoder: ExpressionEncoder[Any], + children: Seq[Expression], + mutableAggBufferOffset: Int, + inputAggBufferOffset: Int) + extends ImperativeAggregate with Logging { + + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) + + override def nullable: Boolean = true + + // TODO: this assumes flat results... + override def dataType: DataType = cEncoder.schema.head.dataType + + override def deterministic: Boolean = true + + override lazy val resolved: Boolean = aEncoder.isDefined + + override lazy val inputTypes: Seq[DataType] = + aEncoder.map(_.schema.map(_.dataType)).getOrElse(Nil) + + override val aggBufferSchema: StructType = bEncoder.schema + + override val aggBufferAttributes: Seq[AttributeReference] = aggBufferSchema.toAttributes + + // Note: although this simply copies aggBufferAttributes, this common code can not be placed + // in the superclass because that will lead to initialization ordering issues. + override val inputAggBufferAttributes: Seq[AttributeReference] = + aggBufferAttributes.map(_.newInstance()) + + lazy val inputAttributes = aEncoder.get.schema.toAttributes + lazy val inputMapping = AttributeMap(inputAttributes.zip(children)) + lazy val boundA = + aEncoder.get.copy(constructExpression = aEncoder.get.constructExpression transform { + case a: AttributeReference => inputMapping(a) + }) + + // TODO: this probably only works when we are in the first column. + val bAttributes = bEncoder.schema.toAttributes + lazy val boundB = bEncoder.resolve(bAttributes).bind(bAttributes) + + override def initialize(buffer: MutableRow): Unit = { + // TODO: We need to either force Aggregator to have a zero or we need to eliminate the need for + // this in execution. 
+ buffer.setInt(mutableAggBufferOffset, aggregator.zero.asInstanceOf[Int]) + } + + override def update(buffer: MutableRow, input: InternalRow): Unit = { + val inputA = boundA.fromRow(input) + val currentB = boundB.fromRow(buffer) + val merged = aggregator.reduce(currentB, inputA) + val returned = boundB.toRow(merged) + buffer.setInt(mutableAggBufferOffset, returned.getInt(0)) + } + + override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { + buffer1.setLong( + mutableAggBufferOffset, + buffer1.getLong(mutableAggBufferOffset) + buffer2.getLong(inputAggBufferOffset)) + } + + override def eval(buffer: InternalRow): Any = { + buffer.getInt(mutableAggBufferOffset) + } + + override def toString: String = { + s"""${aggregator.getClass.getSimpleName}(${children.mkString(",")})""" + } + + override def nodeName: String = aggregator.getClass.getSimpleName +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala new file mode 100644 index 0000000000..0b3192a6da --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.expressions + +import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder} +import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2} +import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression +import org.apache.spark.sql.{Dataset, DataFrame, TypedColumn} + +/** + * A base class for user-defined aggregations, which can be used in [[DataFrame]] and [[Dataset]] + * operations to take all of the elements of a group and reduce them to a single value. + * + * For example, the following aggregator extracts an `int` from a specific class and adds them up: + * {{{ + * case class Data(i: Int) + * + * val customSummer = new Aggregator[Data, Int, Int] { + * def zero = 0 + * def reduce(b: Int, a: Data) = b + a.i + * def present(r: Int) = r + * }.toColumn() + * + * val ds: Dataset[Data] + * val aggregated = ds.select(customSummer) + * }}} + * + * Based loosely on Aggregator from Algebird: https://github.com/twitter/algebird + * + * @tparam A The input type for the aggregation. + * @tparam B The type of the intermediate value of the reduction. + * @tparam C The type of the final result. + */ +abstract class Aggregator[-A, B, C] { + + /** A zero value for this aggregation. Should satisfy the property that any b + zero = b */ + def zero: B + + /** + * Combine two values to produce a new value. For performance, the function may modify `b` and + * return it instead of constructing new object for b. 
+ */ + def reduce(b: B, a: A): B + + /** + * Transform the output of the reduction. + */ + def present(reduction: B): C + + /** + * Returns this `Aggregator` as a [[TypedColumn]] that can be used in [[Dataset]] or [[DataFrame]] + * operations. + */ + def toColumn( + implicit bEncoder: Encoder[B], + cEncoder: Encoder[C]): TypedColumn[A, C] = { + val expr = + new AggregateExpression2( + TypedAggregateExpression(this), + Complete, + false) + + new TypedColumn[A, C](expr, encoderFor[C]) + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 3f0b24b68b..6d56542ee0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql + + import scala.language.implicitConversions import scala.reflect.runtime.universe.{TypeTag, typeTag} import scala.util.Try @@ -24,11 +26,32 @@ import scala.util.Try import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection} import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star} +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, Encoder} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint import org.apache.spark.sql.types._ import org.apache.spark.util.Utils +/** + * Ensures that java functions signatures for methods that now return a [[TypedColumn]] still have + * legacy equivalents in bytecode. This compatibility is done by forcing the compiler to generate + * "bridge" methods due to the use of covariant return types. + * + * {{{ + * In LegacyFunctions: + * public abstract org.apache.spark.sql.Column avg(java.lang.String); + * + * In functions: + * public static org.apache.spark.sql.TypedColumn avg(...); + * }}} + * + * This allows us to use the same functions both in typed [[Dataset]] operations and untyped + * [[DataFrame]] operations when the return type for a given function is statically known. + */ +private[sql] abstract class LegacyFunctions { + def count(columnName: String): Column +} + /** * :: Experimental :: * Functions available for [[DataFrame]]. @@ -48,11 +71,14 @@ import org.apache.spark.util.Utils */ @Experimental // scalastyle:off -object functions { +object functions extends LegacyFunctions { // scalastyle:on private def withExpr(expr: Expression): Column = Column(expr) + private implicit def newLongEncoder: Encoder[Long] = ExpressionEncoder[Long](flat = true) + + /** * Returns a [[Column]] based on the given column name. * @@ -234,7 +260,7 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def count(columnName: String): Column = count(Column(columnName)) + def count(columnName: String): TypedColumn[Any, Long] = count(Column(columnName)).as[Long] /** * Aggregate function: returns the number of distinct items in a group. 
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 312cf33e4c..2da63d1b96 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -258,8 +258,8 @@ public class JavaDatasetSuite implements Serializable { Dataset ds = context.createDataset(data, e.INT()); Dataset> selected = ds.select( - expr("value + 1").as(e.INT()), - col("value").cast("string").as(e.STRING())); + expr("value + 1"), + col("value").cast("string")).as(e.tuple(e.INT(), e.STRING())); Assert.assertEquals( Arrays.asList(tuple2(3, "2"), tuple2(7, "6")), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala new file mode 100644 index 0000000000..340470c096 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.encoders.Encoder +import org.apache.spark.sql.functions._ + +import scala.language.postfixOps + +import org.apache.spark.sql.test.SharedSQLContext + +import org.apache.spark.sql.expressions.Aggregator + +/** An `Aggregator` that adds up any numeric type returned by the given function. */ +class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] with Serializable { + val numeric = implicitly[Numeric[N]] + + override def zero: N = numeric.zero + + override def reduce(b: N, a: I): N = numeric.plus(b, f(a)) + + override def present(reduction: N): N = reduction +} + +class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { + + import testImplicits._ + + def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] = + new SumOf(f).toColumn + + test("typed aggregation: TypedAggregator") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + + checkAnswer( + ds.groupBy(_._1).agg(sum(_._2)), + ("a", 30), ("b", 3), ("c", 1)) + } + + test("typed aggregation: TypedAggregator, expr, expr") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + + checkAnswer( + ds.groupBy(_._1).agg( + sum(_._2), + expr("sum(_2)").as[Int], + count("*")), + ("a", 30, 30, 2L), ("b", 3, 3, 2L), ("c", 1, 1, 1L)) + } +} -- cgit v1.2.3