author    Reynold Xin <rxin@databricks.com>  2016-04-01 22:46:56 -0700
committer Reynold Xin <rxin@databricks.com>  2016-04-01 22:46:56 -0700
commit    f414154418c2291448954b9f0890d592b2d823ae (patch)
tree      1663d938faacb33b1607e4beb0e9ec5afdf3f493 /sql/core/src/main
parent    fa1af0aff7bde9bbf7bfa6a3ac74699734c2fd8a (diff)
[SPARK-14285][SQL] Implement common type-safe aggregate functions
## What changes were proposed in this pull request?

In the Dataset API, it is currently fairly difficult for users to perform simple aggregations in a type-safe way, because no aggregators have been implemented. This pull request adds a few common aggregate functions to the expressions.scala.typed package, and also creates the expressions.java.typed package without an implementation. The Java implementation should probably come as a separate pull request; one challenge there is resolving the type difference between Scala primitive types and Java boxed types.

## How was this patch tested?

Added unit tests for them.

Author: Reynold Xin <rxin@databricks.com>

Closes #12077 from rxin/SPARK-14285.
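For context, a minimal usage sketch of the new Scala API (the Person case class, the sample data, and the surrounding SparkSession with its implicits are illustrative assumptions, not part of this patch):

import org.apache.spark.sql.expressions.scala.typed

case class Person(name: String, age: Double)

// Assumes a SparkSession `spark` and `import spark.implicits._` are already in scope.
val people = Seq(Person("a", 30.0), Person("b", 40.0), Person("a", 50.0)).toDS()

// Type-safe average per key: the aggregate is a TypedColumn[Person, Double], so a typo
// in the field access fails at compile time rather than at runtime.
val avgAgeByName = people.groupByKey(_.name).agg(typed.avg(_.age))
// avgAgeByName: Dataset[(String, Double)]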
Diffstat (limited to 'sql/core/src/main')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala   69
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala                      8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala                  8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/expressions/java/typed.java                   34
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/expressions/scala/typed.scala                 89
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala                        4
6 files changed, 202 insertions(+), 10 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala
new file mode 100644
index 0000000000..9afc29038b
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.sql.expressions.Aggregator
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+// This file defines internal implementations for aggregators.
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+
+class TypedSum[IN, OUT : Numeric](f: IN => OUT) extends Aggregator[IN, OUT, OUT] {
+ val numeric = implicitly[Numeric[OUT]]
+ override def zero: OUT = numeric.zero
+ override def reduce(b: OUT, a: IN): OUT = numeric.plus(b, f(a))
+ override def merge(b1: OUT, b2: OUT): OUT = numeric.plus(b1, b2)
+ override def finish(reduction: OUT): OUT = reduction
+}
+
+
+class TypedSumDouble[IN](f: IN => Double) extends Aggregator[IN, Double, Double] {
+ override def zero: Double = 0.0
+ override def reduce(b: Double, a: IN): Double = b + f(a)
+ override def merge(b1: Double, b2: Double): Double = b1 + b2
+ override def finish(reduction: Double): Double = reduction
+}
+
+
+class TypedSumLong[IN](f: IN => Long) extends Aggregator[IN, Long, Long] {
+ override def zero: Long = 0L
+ override def reduce(b: Long, a: IN): Long = b + f(a)
+ override def merge(b1: Long, b2: Long): Long = b1 + b2
+ override def finish(reduction: Long): Long = reduction
+}
+
+
+class TypedCount[IN](f: IN => Any) extends Aggregator[IN, Long, Long] {
+ override def zero: Long = 0
+ override def reduce(b: Long, a: IN): Long = {
+ if (f(a) == null) b else b + 1
+ }
+ override def merge(b1: Long, b2: Long): Long = b1 + b2
+ override def finish(reduction: Long): Long = reduction
+}
+
+
+class TypedAverage[IN](f: IN => Double) extends Aggregator[IN, (Double, Long), Double] {
+ override def zero: (Double, Long) = (0.0, 0L)
+ override def reduce(b: (Double, Long), a: IN): (Double, Long) = (f(a) + b._1, 1 + b._2)
+ override def finish(reduction: (Double, Long)): Double = reduction._1 / reduction._2
+ override def merge(b1: (Double, Long), b2: (Double, Long)): (Double, Long) = {
+ (b1._1 + b2._1, b1._2 + b2._2)
+ }
+}
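The Aggregator contract above becomes clearer when driven by hand the way Spark does across partitions: zero seeds a buffer, reduce folds input rows into it, merge combines partial buffers, and finish produces the result. A small sketch for TypedAverage (illustration only, not part of the patch):

import org.apache.spark.sql.execution.aggregate.TypedAverage

val avg = new TypedAverage[Double](identity)

// Partition 1 folds two rows into a (sum, count) buffer.
val p1 = Seq(1.0, 2.0).foldLeft(avg.zero)(avg.reduce)   // (3.0, 2)
// Partition 2 folds one row.
val p2 = Seq(6.0).foldLeft(avg.zero)(avg.reduce)        // (6.0, 1)

// Partial buffers are merged, then finished into the final average.
val result = avg.finish(avg.merge(p1, p2))              // 9.0 / 3 = 3.0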
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala
index e9b60841fc..350c283646 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala
@@ -42,7 +42,7 @@ object Window {
* Creates a [[WindowSpec]] with the partitioning defined.
* @since 1.4.0
*/
- @scala.annotation.varargs
+ @_root_.scala.annotation.varargs
def partitionBy(colName: String, colNames: String*): WindowSpec = {
spec.partitionBy(colName, colNames : _*)
}
@@ -51,7 +51,7 @@ object Window {
* Creates a [[WindowSpec]] with the partitioning defined.
* @since 1.4.0
*/
- @scala.annotation.varargs
+ @_root_.scala.annotation.varargs
def partitionBy(cols: Column*): WindowSpec = {
spec.partitionBy(cols : _*)
}
@@ -60,7 +60,7 @@ object Window {
* Creates a [[WindowSpec]] with the ordering defined.
* @since 1.4.0
*/
- @scala.annotation.varargs
+ @_root_.scala.annotation.varargs
def orderBy(colName: String, colNames: String*): WindowSpec = {
spec.orderBy(colName, colNames : _*)
}
@@ -69,7 +69,7 @@ object Window {
* Creates a [[WindowSpec]] with the ordering defined.
* @since 1.4.0
*/
- @scala.annotation.varargs
+ @_root_.scala.annotation.varargs
def orderBy(cols: Column*): WindowSpec = {
spec.orderBy(cols : _*)
}
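Note the switch to _root_: this patch adds the sibling package org.apache.spark.sql.expressions.scala, so inside org.apache.spark.sql.expressions the bare name scala now resolves to that subpackage rather than the top-level scala package, and @scala.annotation.varargs would no longer resolve. A minimal, self-contained illustration of the same shadowing (the example package and its members are made up):

// Illustration only; the `example` package and its members are hypothetical.
package example {

  // A nested package literally named `scala`, mirroring the new expressions.scala package.
  package scala {
    object Placeholder
  }

  object Varargs {
    // Inside `example`, the bare name `scala` refers to `example.scala`, so
    // `@scala.annotation.varargs` would fail to resolve here.
    // Qualifying from the root package always works:
    @_root_.scala.annotation.varargs
    def concat(parts: String*): String = parts.mkString(",")
  }
}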
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
index 9e9c58cb66..d716da2668 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
@@ -39,7 +39,7 @@ class WindowSpec private[sql](
* Defines the partitioning columns in a [[WindowSpec]].
* @since 1.4.0
*/
- @scala.annotation.varargs
+ @_root_.scala.annotation.varargs
def partitionBy(colName: String, colNames: String*): WindowSpec = {
partitionBy((colName +: colNames).map(Column(_)): _*)
}
@@ -48,7 +48,7 @@ class WindowSpec private[sql](
* Defines the partitioning columns in a [[WindowSpec]].
* @since 1.4.0
*/
- @scala.annotation.varargs
+ @_root_.scala.annotation.varargs
def partitionBy(cols: Column*): WindowSpec = {
new WindowSpec(cols.map(_.expr), orderSpec, frame)
}
@@ -57,7 +57,7 @@ class WindowSpec private[sql](
* Defines the ordering columns in a [[WindowSpec]].
* @since 1.4.0
*/
- @scala.annotation.varargs
+ @_root_.scala.annotation.varargs
def orderBy(colName: String, colNames: String*): WindowSpec = {
orderBy((colName +: colNames).map(Column(_)): _*)
}
@@ -66,7 +66,7 @@ class WindowSpec private[sql](
* Defines the ordering columns in a [[WindowSpec]].
* @since 1.4.0
*/
- @scala.annotation.varargs
+ @_root_.scala.annotation.varargs
def orderBy(cols: Column*): WindowSpec = {
val sortOrder: Seq[SortOrder] = cols.map { col =>
col.expr match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/java/typed.java b/sql/core/src/main/scala/org/apache/spark/sql/expressions/java/typed.java
new file mode 100644
index 0000000000..cdba970d8f
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/java/typed.java
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.expressions.java;
+
+import org.apache.spark.annotation.Experimental;
+import org.apache.spark.sql.Dataset;
+
+/**
+ * :: Experimental ::
+ * Type-safe functions available for {@link Dataset} operations in Java.
+ *
+ * Scala users should use {@link org.apache.spark.sql.expressions.scala.typed}.
+ *
+ * @since 2.0.0
+ */
+@Experimental
+public class typed {
+
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/scala/typed.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/scala/typed.scala
new file mode 100644
index 0000000000..d0eb190afd
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/scala/typed.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.expressions.scala
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.sql._
+import org.apache.spark.sql.execution.aggregate._
+
+/**
+ * :: Experimental ::
+ * Type-safe functions available for [[Dataset]] operations in Scala.
+ *
+ * Java users should use [[org.apache.spark.sql.expressions.java.typed]].
+ *
+ * @since 2.0.0
+ */
+@Experimental
+// scalastyle:off
+object typed {
+ // scalastyle:on
+
+ // Note: whenever we update this file, we should update the corresponding Java version too.
+ // The reason we have separate files for Java and Scala is because in the Scala version, we can
+ // use tighter types (primitive types) for return types, whereas in the Java version we can only
+ // use boxed primitive types.
+  // For example, avg in the Scala version returns Scala primitive Double, whose bytecode
+ // signature is just a java.lang.Object; avg in the Java version returns java.lang.Double.
+
+ // TODO: This is pretty hacky. Maybe we should have an object for implicit encoders.
+ private val implicits = new SQLImplicits {
+ override protected def _sqlContext: SQLContext = null
+ }
+
+ import implicits._
+
+ /**
+ * Average aggregate function.
+ *
+ * @since 2.0.0
+ */
+ def avg[IN](f: IN => Double): TypedColumn[IN, Double] = new TypedAverage(f).toColumn
+
+ /**
+ * Count aggregate function.
+ *
+ * @since 2.0.0
+ */
+ def count[IN](f: IN => Any): TypedColumn[IN, Long] = new TypedCount(f).toColumn
+
+ /**
+ * Sum aggregate function for floating point (double) type.
+ *
+ * @since 2.0.0
+ */
+ def sum[IN](f: IN => Double): TypedColumn[IN, Double] = new TypedSumDouble[IN](f).toColumn
+
+ /**
+ * Sum aggregate function for integral (long, i.e. 64 bit integer) type.
+ *
+ * @since 2.0.0
+ */
+ def sumLong[IN](f: IN => Long): TypedColumn[IN, Long] = new TypedSumLong[IN](f).toColumn
+
+ // TODO:
+ // stddevOf: Double
+ // varianceOf: Double
+ // approxCountDistinct: Long
+
+ // minOf: T
+ // maxOf: T
+
+ // firstOf: T
+ // lastOf: T
+}
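These functions compose in a single agg call; a short sketch (the Sale case class and the data are hypothetical, and a SparkSession with its implicits is assumed to be in scope):

import org.apache.spark.sql.expressions.scala.typed

case class Sale(store: String, amount: Double)

val sales = Seq(Sale("north", 10.0), Sale("north", 30.0), Sale("south", 5.0)).toDS()

// Several typed aggregates in one call; the result is a strongly typed
// Dataset[(String, Double, Long, Double)] of (store, sum, count, avg).
val summary = sales
  .groupByKey(_.store)
  .agg(typed.sum(_.amount), typed.count(_.amount), typed.avg(_.amount))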
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala
index 8b355befc3..48925910ac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala
@@ -106,7 +106,7 @@ abstract class UserDefinedAggregateFunction extends Serializable {
/**
* Creates a [[Column]] for this UDAF using given [[Column]]s as input arguments.
*/
- @scala.annotation.varargs
+ @_root_.scala.annotation.varargs
def apply(exprs: Column*): Column = {
val aggregateExpression =
AggregateExpression(
@@ -120,7 +120,7 @@ abstract class UserDefinedAggregateFunction extends Serializable {
* Creates a [[Column]] for this UDAF using the distinct values of the given
* [[Column]]s as input arguments.
*/
- @scala.annotation.varargs
+ @_root_.scala.annotation.varargs
def distinct(exprs: Column*): Column = {
val aggregateExpression =
AggregateExpression(