aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYin Huai <yhuai@databricks.com>2015-07-27 13:26:57 -0700
committerReynold Xin <rxin@databricks.com>2015-07-27 13:26:57 -0700
commit55946e76fd136958081f073c0c5e3ff8563d505b (patch)
treeaaf73aefcfe18bdc87b7e406f0f67e66e296450b
parentfa84e4a7ba6eab476487185178a556e4f04e4199 (diff)
downloadspark-55946e76fd136958081f073c0c5e3ff8563d505b.tar.gz
spark-55946e76fd136958081f073c0c5e3ff8563d505b.tar.bz2
spark-55946e76fd136958081f073c0c5e3ff8563d505b.zip
[SPARK-9349] [SQL] UDAF cleanup
https://issues.apache.org/jira/browse/SPARK-9349 With this PR, we only expose `UserDefinedAggregateFunction` (an abstract class) and `MutableAggregationBuffer` (an interface). Other internal wrappers and helper classes are moved to `org.apache.spark.sql.execution.aggregate` and marked as `private[sql]`. Author: Yin Huai <yhuai@databricks.com> Closes #7687 from yhuai/UDAF-cleanup and squashes the following commits: db36542 [Yin Huai] Add comments to UDAF examples. ae17f66 [Yin Huai] Address comments. 9c9fa5f [Yin Huai] UDAF cleanup.
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala3
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala (renamed from sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala)122
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala101
-rw-r--r--sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java34
-rw-r--r--sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java28
5 files changed, 187 insertions, 101 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala
index 5b872f5e3e..0d4e30f292 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala
@@ -19,7 +19,8 @@ package org.apache.spark.sql
import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.expressions.{Expression}
-import org.apache.spark.sql.expressions.aggregate.{ScalaUDAF, UserDefinedAggregateFunction}
+import org.apache.spark.sql.execution.aggregate.ScalaUDAF
+import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
class UDAFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
index 4ada9eca7a..073c45ae2f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
@@ -15,87 +15,29 @@
* limitations under the License.
*/
-package org.apache.spark.sql.expressions.aggregate
+package org.apache.spark.sql.execution.aggregate
import org.apache.spark.Logging
-import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
+import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters}
-import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
+import org.apache.spark.sql.catalyst.expressions.{MutableRow, InterpretedMutableProjection, AttributeReference, Expression}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction2
-import org.apache.spark.sql.types._
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
+import org.apache.spark.sql.types.{Metadata, StructField, StructType, DataType}
/**
- * The abstract class for implementing user-defined aggregate function.
+ * A Mutable [[Row]] representing a mutable aggregation buffer.
*/
-abstract class UserDefinedAggregateFunction extends Serializable {
-
- /**
- * A [[StructType]] represents data types of input arguments of this aggregate function.
- * For example, if a [[UserDefinedAggregateFunction]] expects two input arguments
- * with type of [[DoubleType]] and [[LongType]], the returned [[StructType]] will look like
- *
- * ```
- * StructType(Seq(StructField("doubleInput", DoubleType), StructField("longInput", LongType)))
- * ```
- *
- * The name of a field of this [[StructType]] is only used to identify the corresponding
- * input argument. Users can choose names to identify the input arguments.
- */
- def inputSchema: StructType
-
- /**
- * A [[StructType]] represents data types of values in the aggregation buffer.
- * For example, if a [[UserDefinedAggregateFunction]]'s buffer has two values
- * (i.e. two intermediate values) with type of [[DoubleType]] and [[LongType]],
- * the returned [[StructType]] will look like
- *
- * ```
- * StructType(Seq(StructField("doubleInput", DoubleType), StructField("longInput", LongType)))
- * ```
- *
- * The name of a field of this [[StructType]] is only used to identify the corresponding
- * buffer value. Users can choose names to identify the input arguments.
- */
- def bufferSchema: StructType
-
- /**
- * The [[DataType]] of the returned value of this [[UserDefinedAggregateFunction]].
- */
- def returnDataType: DataType
-
- /** Indicates if this function is deterministic. */
- def deterministic: Boolean
-
- /**
- * Initializes the given aggregation buffer. Initial values set by this method should satisfy
- * the condition that when merging two buffers with initial values, the new buffer should
- * still store initial values.
- */
- def initialize(buffer: MutableAggregationBuffer): Unit
-
- /** Updates the given aggregation buffer `buffer` with new input data from `input`. */
- def update(buffer: MutableAggregationBuffer, input: Row): Unit
-
- /** Merges two aggregation buffers and stores the updated buffer values back in `buffer1`. */
- def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit
-
- /**
- * Calculates the final result of this [[UserDefinedAggregateFunction]] based on the given
- * aggregation buffer.
- */
- def evaluate(buffer: Row): Any
-}
-
-private[sql] abstract class AggregationBuffer(
+private[sql] class MutableAggregationBufferImpl (
+ schema: StructType,
toCatalystConverters: Array[Any => Any],
toScalaConverters: Array[Any => Any],
- bufferOffset: Int)
- extends Row {
-
- override def length: Int = toCatalystConverters.length
+ bufferOffset: Int,
+ var underlyingBuffer: MutableRow)
+ extends MutableAggregationBuffer {
- protected val offsets: Array[Int] = {
+ private[this] val offsets: Array[Int] = {
val newOffsets = new Array[Int](length)
var i = 0
while (i < newOffsets.length) {
@@ -104,18 +46,8 @@ private[sql] abstract class AggregationBuffer(
}
newOffsets
}
-}
-/**
- * A Mutable [[Row]] representing an mutable aggregation buffer.
- */
-class MutableAggregationBuffer private[sql] (
- schema: StructType,
- toCatalystConverters: Array[Any => Any],
- toScalaConverters: Array[Any => Any],
- bufferOffset: Int,
- var underlyingBuffer: MutableRow)
- extends AggregationBuffer(toCatalystConverters, toScalaConverters, bufferOffset) {
+ override def length: Int = toCatalystConverters.length
override def get(i: Int): Any = {
if (i >= length || i < 0) {
@@ -133,8 +65,8 @@ class MutableAggregationBuffer private[sql] (
underlyingBuffer.update(offsets(i), toCatalystConverters(i)(value))
}
- override def copy(): MutableAggregationBuffer = {
- new MutableAggregationBuffer(
+ override def copy(): MutableAggregationBufferImpl = {
+ new MutableAggregationBufferImpl(
schema,
toCatalystConverters,
toScalaConverters,
@@ -146,13 +78,25 @@ class MutableAggregationBuffer private[sql] (
/**
* A [[Row]] representing an immutable aggregation buffer.
*/
-class InputAggregationBuffer private[sql] (
+private[sql] class InputAggregationBuffer private[sql] (
schema: StructType,
toCatalystConverters: Array[Any => Any],
toScalaConverters: Array[Any => Any],
bufferOffset: Int,
var underlyingInputBuffer: InternalRow)
- extends AggregationBuffer(toCatalystConverters, toScalaConverters, bufferOffset) {
+ extends Row {
+
+ private[this] val offsets: Array[Int] = {
+ val newOffsets = new Array[Int](length)
+ var i = 0
+ while (i < newOffsets.length) {
+ newOffsets(i) = bufferOffset + i
+ i += 1
+ }
+ newOffsets
+ }
+
+ override def length: Int = toCatalystConverters.length
override def get(i: Int): Any = {
if (i >= length || i < 0) {
@@ -179,7 +123,7 @@ class InputAggregationBuffer private[sql] (
* @param children
* @param udaf
*/
-case class ScalaUDAF(
+private[sql] case class ScalaUDAF(
children: Seq[Expression],
udaf: UserDefinedAggregateFunction)
extends AggregateFunction2 with Logging {
@@ -243,8 +187,8 @@ case class ScalaUDAF(
bufferOffset,
null)
- lazy val mutableAggregateBuffer: MutableAggregationBuffer =
- new MutableAggregationBuffer(
+ lazy val mutableAggregateBuffer: MutableAggregationBufferImpl =
+ new MutableAggregationBufferImpl(
bufferSchema,
bufferValuesToCatalystConverters,
bufferValuesToScalaConverters,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala
new file mode 100644
index 0000000000..278dd438fa
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.expressions
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.types._
+import org.apache.spark.annotation.Experimental
+
+/**
+ * :: Experimental ::
+ * The abstract class for implementing user-defined aggregate functions.
+ */
+@Experimental
+abstract class UserDefinedAggregateFunction extends Serializable {
+
+ /**
+ * A [[StructType]] represents data types of input arguments of this aggregate function.
+ * For example, if a [[UserDefinedAggregateFunction]] expects two input arguments
+ * with type of [[DoubleType]] and [[LongType]], the returned [[StructType]] will look like
+ *
+ * ```
+ * new StructType()
+ * .add("doubleInput", DoubleType)
+ * .add("longInput", LongType)
+ * ```
+ *
+ * The name of a field of this [[StructType]] is only used to identify the corresponding
+ * input argument. Users can choose names to identify the input arguments.
+ */
+ def inputSchema: StructType
+
+ /**
+ * A [[StructType]] represents data types of values in the aggregation buffer.
+ * For example, if a [[UserDefinedAggregateFunction]]'s buffer has two values
+ * (i.e. two intermediate values) with type of [[DoubleType]] and [[LongType]],
+ * the returned [[StructType]] will look like
+ *
+ * ```
+ * new StructType()
+ * .add("doubleInput", DoubleType)
+ * .add("longInput", LongType)
+ * ```
+ *
+ * The name of a field of this [[StructType]] is only used to identify the corresponding
+ * buffer value. Users can choose names to identify the buffer values.
+ */
+ def bufferSchema: StructType
+
+ /**
+ * The [[DataType]] of the returned value of this [[UserDefinedAggregateFunction]].
+ */
+ def returnDataType: DataType
+
+ /** Indicates if this function is deterministic. */
+ def deterministic: Boolean
+
+ /**
+ * Initializes the given aggregation buffer. Initial values set by this method should satisfy
+ * the condition that when merging two buffers with initial values, the new buffer
+ * should still store initial values.
+ */
+ def initialize(buffer: MutableAggregationBuffer): Unit
+
+ /** Updates the given aggregation buffer `buffer` with new input data from `input`. */
+ def update(buffer: MutableAggregationBuffer, input: Row): Unit
+
+ /** Merges two aggregation buffers and stores the updated buffer values back to `buffer1`. */
+ def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit
+
+ /**
+ * Calculates the final result of this [[UserDefinedAggregateFunction]] based on the given
+ * aggregation buffer.
+ */
+ def evaluate(buffer: Row): Any
+}
+
+/**
+ * :: Experimental ::
+ * A [[Row]] representing a mutable aggregation buffer.
+ */
+@Experimental
+trait MutableAggregationBuffer extends Row {
+
+ /** Update the ith value of this buffer. */
+ def update(i: Int, value: Any): Unit
+}
diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java
index 5c9d0e97a9..a2247e3da1 100644
--- a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java
+++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java
@@ -21,13 +21,18 @@ import java.util.ArrayList;
import java.util.List;
import org.apache.spark.sql.Row;
-import org.apache.spark.sql.expressions.aggregate.MutableAggregationBuffer;
-import org.apache.spark.sql.expressions.aggregate.UserDefinedAggregateFunction;
+import org.apache.spark.sql.expressions.MutableAggregationBuffer;
+import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
+/**
+ * An example {@link UserDefinedAggregateFunction} to calculate a special average value of a
+ * {@link org.apache.spark.sql.types.DoubleType} column. This special average value is the sum
+ * of the average value of input values and 100.0.
+ */
public class MyDoubleAvg extends UserDefinedAggregateFunction {
private StructType _inputDataType;
@@ -37,10 +42,13 @@ public class MyDoubleAvg extends UserDefinedAggregateFunction {
private DataType _returnDataType;
public MyDoubleAvg() {
- List<StructField> inputfields = new ArrayList<StructField>();
- inputfields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true));
- _inputDataType = DataTypes.createStructType(inputfields);
+ List<StructField> inputFields = new ArrayList<StructField>();
+ inputFields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true));
+ _inputDataType = DataTypes.createStructType(inputFields);
+ // The buffer has two values, bufferSum for storing the current sum and
+ // bufferCount for storing the number of non-null input values that have contributed
+ // to the current sum.
List<StructField> bufferFields = new ArrayList<StructField>();
bufferFields.add(DataTypes.createStructField("bufferSum", DataTypes.DoubleType, true));
bufferFields.add(DataTypes.createStructField("bufferCount", DataTypes.LongType, true));
@@ -66,16 +74,23 @@ public class MyDoubleAvg extends UserDefinedAggregateFunction {
}
@Override public void initialize(MutableAggregationBuffer buffer) {
+ // The initial value of the sum is null.
buffer.update(0, null);
+ // The initial value of the count is 0.
buffer.update(1, 0L);
}
@Override public void update(MutableAggregationBuffer buffer, Row input) {
+ // This input Row only has a single column storing the input value in Double.
+ // We only update the buffer when the input value is not null.
if (!input.isNullAt(0)) {
+ // If the buffer value (the intermediate result of the sum) is still null,
+ // we set the input value to the buffer and set the bufferCount to 1.
if (buffer.isNullAt(0)) {
buffer.update(0, input.getDouble(0));
buffer.update(1, 1L);
} else {
+ // Otherwise, update the bufferSum and increment bufferCount.
Double newValue = input.getDouble(0) + buffer.getDouble(0);
buffer.update(0, newValue);
buffer.update(1, buffer.getLong(1) + 1L);
@@ -84,11 +99,16 @@ public class MyDoubleAvg extends UserDefinedAggregateFunction {
}
@Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
+ // buffer1 and buffer2 have the same structure.
+ // We only update the buffer1 when the input buffer2's sum value is not null.
if (!buffer2.isNullAt(0)) {
if (buffer1.isNullAt(0)) {
+ // If the buffer value (intermediate result of the sum) is still null,
+ // we set it to the input buffer's value.
buffer1.update(0, buffer2.getDouble(0));
buffer1.update(1, buffer2.getLong(1));
} else {
+ // Otherwise, we update the bufferSum and bufferCount.
Double newValue = buffer2.getDouble(0) + buffer1.getDouble(0);
buffer1.update(0, newValue);
buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1));
@@ -98,10 +118,12 @@ public class MyDoubleAvg extends UserDefinedAggregateFunction {
@Override public Object evaluate(Row buffer) {
if (buffer.isNullAt(0)) {
+ // If the bufferSum is still null, we return null because this function has not got
+ // any input row.
return null;
} else {
+ // Otherwise, we calculate the special average value.
return buffer.getDouble(0) / buffer.getLong(1) + 100.0;
}
}
}
-
diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java
index 1d4587a27c..da29e24d26 100644
--- a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java
+++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java
@@ -20,14 +20,18 @@ package test.org.apache.spark.sql.hive.aggregate;
import java.util.ArrayList;
import java.util.List;
-import org.apache.spark.sql.expressions.aggregate.MutableAggregationBuffer;
-import org.apache.spark.sql.expressions.aggregate.UserDefinedAggregateFunction;
+import org.apache.spark.sql.expressions.MutableAggregationBuffer;
+import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.Row;
+/**
+ * An example {@link UserDefinedAggregateFunction} to calculate the sum of a
+ * {@link org.apache.spark.sql.types.DoubleType} column.
+ */
public class MyDoubleSum extends UserDefinedAggregateFunction {
private StructType _inputDataType;
@@ -37,9 +41,9 @@ public class MyDoubleSum extends UserDefinedAggregateFunction {
private DataType _returnDataType;
public MyDoubleSum() {
- List<StructField> inputfields = new ArrayList<StructField>();
- inputfields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true));
- _inputDataType = DataTypes.createStructType(inputfields);
+ List<StructField> inputFields = new ArrayList<StructField>();
+ inputFields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true));
+ _inputDataType = DataTypes.createStructType(inputFields);
List<StructField> bufferFields = new ArrayList<StructField>();
bufferFields.add(DataTypes.createStructField("bufferDouble", DataTypes.DoubleType, true));
@@ -65,14 +69,20 @@ public class MyDoubleSum extends UserDefinedAggregateFunction {
}
@Override public void initialize(MutableAggregationBuffer buffer) {
+ // The initial value of the sum is null.
buffer.update(0, null);
}
@Override public void update(MutableAggregationBuffer buffer, Row input) {
+ // This input Row only has a single column storing the input value in Double.
+ // We only update the buffer when the input value is not null.
if (!input.isNullAt(0)) {
if (buffer.isNullAt(0)) {
+ // If the buffer value (the intermediate result of the sum) is still null,
+ // we set the input value to the buffer.
buffer.update(0, input.getDouble(0));
} else {
+ // Otherwise, we add the input value to the buffer value.
Double newValue = input.getDouble(0) + buffer.getDouble(0);
buffer.update(0, newValue);
}
@@ -80,10 +90,16 @@ public class MyDoubleSum extends UserDefinedAggregateFunction {
}
@Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
+ // buffer1 and buffer2 have the same structure.
+ // We only update the buffer1 when the input buffer2's value is not null.
if (!buffer2.isNullAt(0)) {
if (buffer1.isNullAt(0)) {
+ // If the buffer value (intermediate result of the sum) is still null,
+ // we set it to the input buffer's value.
buffer1.update(0, buffer2.getDouble(0));
} else {
+ // Otherwise, we add the input buffer's value (buffer2) to the mutable
+ // buffer's value (buffer1).
Double newValue = buffer2.getDouble(0) + buffer1.getDouble(0);
buffer1.update(0, newValue);
}
@@ -92,8 +108,10 @@ public class MyDoubleSum extends UserDefinedAggregateFunction {
@Override public Object evaluate(Row buffer) {
if (buffer.isNullAt(0)) {
+ // If the buffer value is still null, we return null.
return null;
} else {
+ // Otherwise, the intermediate sum is the final result.
return buffer.getDouble(0);
}
}