diff options
author | Reynold Xin <rxin@databricks.com> | 2016-04-01 22:46:56 -0700 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2016-04-01 22:46:56 -0700 |
commit | f414154418c2291448954b9f0890d592b2d823ae (patch) | |
tree | 1663d938faacb33b1607e4beb0e9ec5afdf3f493 /sql/core/src/main | |
parent | fa1af0aff7bde9bbf7bfa6a3ac74699734c2fd8a (diff) | |
download | spark-f414154418c2291448954b9f0890d592b2d823ae.tar.gz spark-f414154418c2291448954b9f0890d592b2d823ae.tar.bz2 spark-f414154418c2291448954b9f0890d592b2d823ae.zip |
[SPARK-14285][SQL] Implement common type-safe aggregate functions
## What changes were proposed in this pull request?
In the Dataset API, it is fairly difficult for users to perform simple aggregations in a type-safe way at the moment because there are no aggregators that have been implemented. This pull request adds a few common aggregate functions in expressions.scala.typed package, and also creates the expressions.java.typed package without implementation. The java implementation should probably come as a separate pull request. One challenge there is to resolve the type difference between Scala primitive types and Java boxed types.
## How was this patch tested?
Added unit tests for them.
Author: Reynold Xin <rxin@databricks.com>
Closes #12077 from rxin/SPARK-14285.
Diffstat (limited to 'sql/core/src/main')
6 files changed, 202 insertions, 10 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala new file mode 100644 index 0000000000..9afc29038b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.sql.expressions.Aggregator + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// This file defines internal implementations for aggregators. +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +class TypedSum[IN, OUT : Numeric](f: IN => OUT) extends Aggregator[IN, OUT, OUT] { + val numeric = implicitly[Numeric[OUT]] + override def zero: OUT = numeric.zero + override def reduce(b: OUT, a: IN): OUT = numeric.plus(b, f(a)) + override def merge(b1: OUT, b2: OUT): OUT = numeric.plus(b1, b2) + override def finish(reduction: OUT): OUT = reduction +} + + +class TypedSumDouble[IN](f: IN => Double) extends Aggregator[IN, Double, Double] { + override def zero: Double = 0.0 + override def reduce(b: Double, a: IN): Double = b + f(a) + override def merge(b1: Double, b2: Double): Double = b1 + b2 + override def finish(reduction: Double): Double = reduction +} + + +class TypedSumLong[IN](f: IN => Long) extends Aggregator[IN, Long, Long] { + override def zero: Long = 0L + override def reduce(b: Long, a: IN): Long = b + f(a) + override def merge(b1: Long, b2: Long): Long = b1 + b2 + override def finish(reduction: Long): Long = reduction +} + + +class TypedCount[IN](f: IN => Any) extends Aggregator[IN, Long, Long] { + override def zero: Long = 0 + override def reduce(b: Long, a: IN): Long = { + if (f(a) == null) b else b + 1 + } + override def merge(b1: Long, b2: Long): Long = b1 + b2 + override def finish(reduction: Long): Long = reduction +} + + +class TypedAverage[IN](f: IN => Double) extends Aggregator[IN, (Double, Long), Double] { + override def zero: (Double, Long) = (0.0, 0L) + override def reduce(b: (Double, Long), a: IN): (Double, Long) = (f(a) + b._1, 1 + b._2) + override def finish(reduction: (Double, Long)): Double = reduction._1 / reduction._2 + override def merge(b1: (Double, Long), b2: (Double, Long)): (Double, Long) = { + (b1._1 + b2._1, b1._2 + b2._2) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala index e9b60841fc..350c283646 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala @@ -42,7 +42,7 @@ object Window { * Creates a [[WindowSpec]] with the partitioning defined. * @since 1.4.0 */ - @scala.annotation.varargs + @_root_.scala.annotation.varargs def partitionBy(colName: String, colNames: String*): WindowSpec = { spec.partitionBy(colName, colNames : _*) } @@ -51,7 +51,7 @@ object Window { * Creates a [[WindowSpec]] with the partitioning defined. * @since 1.4.0 */ - @scala.annotation.varargs + @_root_.scala.annotation.varargs def partitionBy(cols: Column*): WindowSpec = { spec.partitionBy(cols : _*) } @@ -60,7 +60,7 @@ object Window { * Creates a [[WindowSpec]] with the ordering defined. * @since 1.4.0 */ - @scala.annotation.varargs + @_root_.scala.annotation.varargs def orderBy(colName: String, colNames: String*): WindowSpec = { spec.orderBy(colName, colNames : _*) } @@ -69,7 +69,7 @@ object Window { * Creates a [[WindowSpec]] with the ordering defined. * @since 1.4.0 */ - @scala.annotation.varargs + @_root_.scala.annotation.varargs def orderBy(cols: Column*): WindowSpec = { spec.orderBy(cols : _*) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala index 9e9c58cb66..d716da2668 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala @@ -39,7 +39,7 @@ class WindowSpec private[sql]( * Defines the partitioning columns in a [[WindowSpec]]. * @since 1.4.0 */ - @scala.annotation.varargs + @_root_.scala.annotation.varargs def partitionBy(colName: String, colNames: String*): WindowSpec = { partitionBy((colName +: colNames).map(Column(_)): _*) } @@ -48,7 +48,7 @@ class WindowSpec private[sql]( * Defines the partitioning columns in a [[WindowSpec]]. * @since 1.4.0 */ - @scala.annotation.varargs + @_root_.scala.annotation.varargs def partitionBy(cols: Column*): WindowSpec = { new WindowSpec(cols.map(_.expr), orderSpec, frame) } @@ -57,7 +57,7 @@ class WindowSpec private[sql]( * Defines the ordering columns in a [[WindowSpec]]. * @since 1.4.0 */ - @scala.annotation.varargs + @_root_.scala.annotation.varargs def orderBy(colName: String, colNames: String*): WindowSpec = { orderBy((colName +: colNames).map(Column(_)): _*) } @@ -66,7 +66,7 @@ class WindowSpec private[sql]( * Defines the ordering columns in a [[WindowSpec]]. * @since 1.4.0 */ - @scala.annotation.varargs + @_root_.scala.annotation.varargs def orderBy(cols: Column*): WindowSpec = { val sortOrder: Seq[SortOrder] = cols.map { col => col.expr match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/java/typed.java b/sql/core/src/main/scala/org/apache/spark/sql/expressions/java/typed.java new file mode 100644 index 0000000000..cdba970d8f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/java/typed.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.expressions.java; + +import org.apache.spark.annotation.Experimental; +import org.apache.spark.sql.Dataset; + +/** + * :: Experimental :: + * Type-safe functions available for {@link Dataset} operations in Java. + * + * Scala users should use {@link org.apache.spark.sql.expressions.scala.typed}. + * + * @since 2.0.0 + */ +@Experimental +public class typed { + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/scala/typed.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/scala/typed.scala new file mode 100644 index 0000000000..d0eb190afd --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/scala/typed.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.expressions.scala + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql._ +import org.apache.spark.sql.execution.aggregate._ + +/** + * :: Experimental :: + * Type-safe functions available for [[Dataset]] operations in Scala. + * + * Java users should use [[org.apache.spark.sql.expressions.java.typed]]. + * + * @since 2.0.0 + */ +@Experimental +// scalastyle:off +object typed { + // scalastyle:on + + // Note: whenever we update this file, we should update the corresponding Java version too. + // The reason we have separate files for Java and Scala is because in the Scala version, we can + // use tighter types (primitive types) for return types, whereas in the Java version we can only + // use boxed primitive types. + // For example, avg in the Scala veresion returns Scala primitive Double, whose bytecode + // signature is just a java.lang.Object; avg in the Java version returns java.lang.Double. + + // TODO: This is pretty hacky. Maybe we should have an object for implicit encoders. + private val implicits = new SQLImplicits { + override protected def _sqlContext: SQLContext = null + } + + import implicits._ + + /** + * Average aggregate function. + * + * @since 2.0.0 + */ + def avg[IN](f: IN => Double): TypedColumn[IN, Double] = new TypedAverage(f).toColumn + + /** + * Count aggregate function. + * + * @since 2.0.0 + */ + def count[IN](f: IN => Any): TypedColumn[IN, Long] = new TypedCount(f).toColumn + + /** + * Sum aggregate function for floating point (double) type. + * + * @since 2.0.0 + */ + def sum[IN](f: IN => Double): TypedColumn[IN, Double] = new TypedSumDouble[IN](f).toColumn + + /** + * Sum aggregate function for integral (long, i.e. 64 bit integer) type. + * + * @since 2.0.0 + */ + def sumLong[IN](f: IN => Long): TypedColumn[IN, Long] = new TypedSumLong[IN](f).toColumn + + // TODO: + // stddevOf: Double + // varianceOf: Double + // approxCountDistinct: Long + + // minOf: T + // maxOf: T + + // firstOf: T + // lastOf: T +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala index 8b355befc3..48925910ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala @@ -106,7 +106,7 @@ abstract class UserDefinedAggregateFunction extends Serializable { /** * Creates a [[Column]] for this UDAF using given [[Column]]s as input arguments. */ - @scala.annotation.varargs + @_root_.scala.annotation.varargs def apply(exprs: Column*): Column = { val aggregateExpression = AggregateExpression( @@ -120,7 +120,7 @@ abstract class UserDefinedAggregateFunction extends Serializable { * Creates a [[Column]] for this UDAF using the distinct values of the given * [[Column]]s as input arguments. */ - @scala.annotation.varargs + @_root_.scala.annotation.varargs def distinct(exprs: Column*): Column = { val aggregateExpression = AggregateExpression( |