author     Yin Huai <yhuai@databricks.com>      2015-08-05 21:50:35 -0700
committer  Reynold Xin <rxin@databricks.com>    2015-08-05 21:50:35 -0700
commit     d5a9af3230925c347d0904fe7f2402e468e80bc8 (patch)
tree       43bb5d76cfbe758511923a0e983f93821ee5b99c
parent     9270bd06fd0b16892e3f37213b5bc7813ea11fdd (diff)
[SPARK-9664] [SQL] Remove UDAFRegistration and add apply to UserDefinedAggregateFunction.
https://issues.apache.org/jira/browse/SPARK-9664

Author: Yin Huai <yhuai@databricks.com>

Closes #7982 from yhuai/udafRegister and squashes the following commits:

0cc2287 [Yin Huai] Remove UDAFRegistration and add apply to UserDefinedAggregateFunction.
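In effect, the commit folds UDAF registration into the existing udf registry and lets a UDAF instance be applied directly to Columns. A minimal before/after sketch, assuming a MyDoubleSum aggregate like the one in the test suite below and a hypothetical DataFrame df with a numeric "value" column:

    import org.apache.spark.sql.functions.{callUDF, col}

    // Before this commit: a dedicated registry.
    // sqlContext.udaf.register("mydoublesum", new MyDoubleSum)

    // After this commit: registration goes through sqlContext.udf,
    // and the new apply builds an aggregate Column directly.
    val udaf = new MyDoubleSum
    sqlContext.udf.register("mydoublesum", udaf)
    val aggregated = df.groupBy().agg(
      udaf(col("value")),                    // via apply(exprs: Column*)
      callUDF("mydoublesum", col("value")))  // via the registered name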
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala  3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala  36
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala  16
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala  8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala  32
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/functions.scala  1
-rw-r--r--  sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java  26
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala  4
8 files changed, 80 insertions(+), 46 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index ffc2baf7a8..6f8ffb5440 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -291,9 +291,6 @@ class SQLContext(@transient val sparkContext: SparkContext)
@transient
val udf: UDFRegistration = new UDFRegistration(this)
- @transient
- val udaf: UDAFRegistration = new UDAFRegistration(this)
-
/**
* Returns true if the table is currently cached in-memory.
* @group cachemgmt
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala
deleted file mode 100644
index 0d4e30f292..0000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala
+++ /dev/null
@@ -1,36 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql
-
-import org.apache.spark.Logging
-import org.apache.spark.sql.catalyst.expressions.{Expression}
-import org.apache.spark.sql.execution.aggregate.ScalaUDAF
-import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
-
-class UDAFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
-
- private val functionRegistry = sqlContext.functionRegistry
-
- def register(
- name: String,
- func: UserDefinedAggregateFunction): UserDefinedAggregateFunction = {
- def builder(children: Seq[Expression]) = ScalaUDAF(children, func)
- functionRegistry.registerFunction(name, builder)
- func
- }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
index 7cd7421a51..1f270560d7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
@@ -26,6 +26,8 @@ import org.apache.spark.Logging
import org.apache.spark.sql.api.java._
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF}
+import org.apache.spark.sql.execution.aggregate.ScalaUDAF
+import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.types.DataType
/**
@@ -52,6 +54,20 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
functionRegistry.registerFunction(name, udf.builder)
}
+ /**
+ * Register a user-defined aggregate function (UDAF).
+ * @param name the name of the UDAF.
+ * @param udaf the UDAF that needs to be registered.
+ * @return the registered UDAF.
+ */
+ def register(
+ name: String,
+ udaf: UserDefinedAggregateFunction): UserDefinedAggregateFunction = {
+ def builder(children: Seq[Expression]) = ScalaUDAF(children, udaf)
+ functionRegistry.registerFunction(name, builder)
+ udaf
+ }
+
// scalastyle:off
/* register 0-22 were generated by this script
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
index 5fafc916bf..7619f3ec9f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
@@ -316,7 +316,7 @@ private[sql] case class ScalaUDAF(
override lazy val cloneBufferAttributes = bufferAttributes.map(_.newInstance())
- private[this] val childrenSchema: StructType = {
+ private[this] lazy val childrenSchema: StructType = {
val inputFields = children.zipWithIndex.map {
case (child, index) =>
StructField(s"input$index", child.dataType, child.nullable, Metadata.empty)
@@ -337,16 +337,16 @@ private[sql] case class ScalaUDAF(
}
}
- private[this] val inputToScalaConverters: Any => Any =
+ private[this] lazy val inputToScalaConverters: Any => Any =
CatalystTypeConverters.createToScalaConverter(childrenSchema)
- private[this] val bufferValuesToCatalystConverters: Array[Any => Any] = {
+ private[this] lazy val bufferValuesToCatalystConverters: Array[Any => Any] = {
bufferSchema.fields.map { field =>
CatalystTypeConverters.createToCatalystConverter(field.dataType)
}
}
- private[this] val bufferValuesToScalaConverters: Array[Any => Any] = {
+ private[this] lazy val bufferValuesToScalaConverters: Array[Any => Any] = {
bufferSchema.fields.map { field =>
CatalystTypeConverters.createToScalaConverter(field.dataType)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala
index 278dd438fa..5180871585 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala
@@ -17,7 +17,10 @@
package org.apache.spark.sql.expressions
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions.ScalaUDF
+import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2}
+import org.apache.spark.sql.execution.aggregate.ScalaUDAF
+import org.apache.spark.sql.{Column, Row}
import org.apache.spark.sql.types._
import org.apache.spark.annotation.Experimental
@@ -87,6 +90,33 @@ abstract class UserDefinedAggregateFunction extends Serializable {
* aggregation buffer.
*/
def evaluate(buffer: Row): Any
+
+ /**
+ * Creates a [[Column]] for this UDAF with the given [[Column]]s as arguments.
+ */
+ @scala.annotation.varargs
+ def apply(exprs: Column*): Column = {
+ val aggregateExpression =
+ AggregateExpression2(
+ ScalaUDAF(exprs.map(_.expr), this),
+ Complete,
+ isDistinct = false)
+ Column(aggregateExpression)
+ }
+
+ /**
+ * Creates a [[Column]] for this UDAF with the given [[Column]]s as arguments.
+ * If `isDistinct` is true, this UDAF works on distinct input values.
+ */
+ @scala.annotation.varargs
+ def apply(isDistinct: Boolean, exprs: Column*): Column = {
+ val aggregateExpression =
+ AggregateExpression2(
+ ScalaUDAF(exprs.map(_.expr), this),
+ Complete,
+ isDistinct = isDistinct)
+ Column(aggregateExpression)
+ }
}
/**
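With the two apply methods, a UDAF instance reads like a built-in aggregate at the call site. A short sketch (df and MyDoubleSum as in the earlier example):

    import org.apache.spark.sql.functions.col

    val udaf = new MyDoubleSum
    df.groupBy().agg(
      udaf(col("value")),        // apply(exprs: Column*)
      udaf(true, col("value")))  // apply(isDistinct, exprs: Column*): distinct inputs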
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 5a10c3891a..39aa905c85 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -2500,6 +2500,7 @@ object functions {
* @group udf_funcs
* @since 1.5.0
*/
+ @scala.annotation.varargs
def callUDF(udfName: String, cols: Column*): Column = {
UnresolvedFunction(udfName, cols.map(_.expr), isDistinct = false)
}
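Adding @scala.annotation.varargs makes the compiler emit an extra Java-friendly overload taking an array, so Java callers (such as the test below) can write callUDF("mydoublesum", col("value")) rather than constructing a Scala Seq. The same mechanism in isolation:

    object VarargsDemo {
      // Compiles to both sumAll(xs: Int*) for Scala and sumAll(int... xs) for Java.
      @scala.annotation.varargs
      def sumAll(xs: Int*): Int = xs.sum
    }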
diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java
index 613b2bcc80..21b053f07a 100644
--- a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java
+++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java
@@ -29,8 +29,12 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.*;
import org.apache.spark.sql.expressions.Window;
+import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
+import static org.apache.spark.sql.functions.*;
import org.apache.spark.sql.hive.HiveContext;
import org.apache.spark.sql.hive.test.TestHive$;
+import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
+import test.org.apache.spark.sql.hive.aggregate.MyDoubleSum;
public class JavaDataFrameSuite {
private transient JavaSparkContext sc;
@@ -77,4 +81,26 @@ public class JavaDataFrameSuite {
" ROWS BETWEEN 1 preceding and 1 following) " +
"FROM window_table").collectAsList());
}
+
+ @Test
+ public void testUDAF() {
+ DataFrame df = hc.range(0, 100).unionAll(hc.range(0, 100)).select(col("id").as("value"));
+ UserDefinedAggregateFunction udaf = new MyDoubleSum();
+ UserDefinedAggregateFunction registeredUDAF = hc.udf().register("mydoublesum", udaf);
+ // Create Columns for the UDAF. For now, callUDF does not take an argument to specify whether
+ // we want to use distinct aggregation.
+ DataFrame aggregatedDF =
+ df.groupBy()
+ .agg(
+ udaf.apply(true, col("value")),
+ udaf.apply(col("value")),
+ registeredUDAF.apply(col("value")),
+ callUDF("mydoublesum", col("value")));
+
+ List<Row> expectedResult = new ArrayList<Row>();
+ expectedResult.add(RowFactory.create(4950.0, 9900.0, 9900.0, 9900.0));
+ checkAnswer(
+ aggregatedDF,
+ expectedResult);
+ }
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index 6f0db27775..4b35c8fd83 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -73,8 +73,8 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be
emptyDF.registerTempTable("emptyTable")
// Register UDAFs
- sqlContext.udaf.register("mydoublesum", new MyDoubleSum)
- sqlContext.udaf.register("mydoubleavg", new MyDoubleAvg)
+ sqlContext.udf.register("mydoublesum", new MyDoubleSum)
+ sqlContext.udf.register("mydoubleavg", new MyDoubleAvg)
}
override def afterAll(): Unit = {