aboutsummaryrefslogtreecommitdiff
path: root/sql/hive
diff options
context:
space:
mode:
authorYin Huai <yhuai@databricks.com>2015-07-27 13:26:57 -0700
committerReynold Xin <rxin@databricks.com>2015-07-27 13:26:57 -0700
commit55946e76fd136958081f073c0c5e3ff8563d505b (patch)
treeaaf73aefcfe18bdc87b7e406f0f67e66e296450b /sql/hive
parentfa84e4a7ba6eab476487185178a556e4f04e4199 (diff)
downloadspark-55946e76fd136958081f073c0c5e3ff8563d505b.tar.gz
spark-55946e76fd136958081f073c0c5e3ff8563d505b.tar.bz2
spark-55946e76fd136958081f073c0c5e3ff8563d505b.zip
[SPARK-9349] [SQL] UDAF cleanup
https://issues.apache.org/jira/browse/SPARK-9349 With this PR, we only expose `UserDefinedAggregateFunction` (an abstract class) and `MutableAggregationBuffer` (an interface). Other internal wrappers and helper classes are moved to `org.apache.spark.sql.execution.aggregate` and marked as `private[sql]`. Author: Yin Huai <yhuai@databricks.com> Closes #7687 from yhuai/UDAF-cleanup and squashes the following commits: db36542 [Yin Huai] Add comments to UDAF examples. ae17f66 [Yin Huai] Address comments. 9c9fa5f [Yin Huai] UDAF cleanup.
Diffstat (limited to 'sql/hive')
-rw-r--r--sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java34
-rw-r--r--sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java28
2 files changed, 51 insertions, 11 deletions
diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java
index 5c9d0e97a9..a2247e3da1 100644
--- a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java
+++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java
@@ -21,13 +21,18 @@ import java.util.ArrayList;
import java.util.List;
import org.apache.spark.sql.Row;
-import org.apache.spark.sql.expressions.aggregate.MutableAggregationBuffer;
-import org.apache.spark.sql.expressions.aggregate.UserDefinedAggregateFunction;
+import org.apache.spark.sql.expressions.MutableAggregationBuffer;
+import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
+/**
+ * An example {@link UserDefinedAggregateFunction} to calculate a special average value of a
+ * {@link org.apache.spark.sql.types.DoubleType} column. This special average value is the sum
+ * of the average value of input values and 100.0.
+ */
public class MyDoubleAvg extends UserDefinedAggregateFunction {
private StructType _inputDataType;
@@ -37,10 +42,13 @@ public class MyDoubleAvg extends UserDefinedAggregateFunction {
private DataType _returnDataType;
public MyDoubleAvg() {
- List<StructField> inputfields = new ArrayList<StructField>();
- inputfields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true));
- _inputDataType = DataTypes.createStructType(inputfields);
+ List<StructField> inputFields = new ArrayList<StructField>();
+ inputFields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true));
+ _inputDataType = DataTypes.createStructType(inputFields);
+ // The buffer has two values, bufferSum for storing the current sum and
+ // bufferCount for storing the number of non-null input values that have been contribuetd
+ // to the current sum.
List<StructField> bufferFields = new ArrayList<StructField>();
bufferFields.add(DataTypes.createStructField("bufferSum", DataTypes.DoubleType, true));
bufferFields.add(DataTypes.createStructField("bufferCount", DataTypes.LongType, true));
@@ -66,16 +74,23 @@ public class MyDoubleAvg extends UserDefinedAggregateFunction {
}
@Override public void initialize(MutableAggregationBuffer buffer) {
+ // The initial value of the sum is null.
buffer.update(0, null);
+ // The initial value of the count is 0.
buffer.update(1, 0L);
}
@Override public void update(MutableAggregationBuffer buffer, Row input) {
+ // This input Row only has a single column storing the input value in Double.
+ // We only update the buffer when the input value is not null.
if (!input.isNullAt(0)) {
+ // If the buffer value (the intermediate result of the sum) is still null,
+ // we set the input value to the buffer and set the bufferCount to 1.
if (buffer.isNullAt(0)) {
buffer.update(0, input.getDouble(0));
buffer.update(1, 1L);
} else {
+ // Otherwise, update the bufferSum and increment bufferCount.
Double newValue = input.getDouble(0) + buffer.getDouble(0);
buffer.update(0, newValue);
buffer.update(1, buffer.getLong(1) + 1L);
@@ -84,11 +99,16 @@ public class MyDoubleAvg extends UserDefinedAggregateFunction {
}
@Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
+ // buffer1 and buffer2 have the same structure.
+ // We only update the buffer1 when the input buffer2's sum value is not null.
if (!buffer2.isNullAt(0)) {
if (buffer1.isNullAt(0)) {
+ // If the buffer value (intermediate result of the sum) is still null,
+ // we set the it as the input buffer's value.
buffer1.update(0, buffer2.getDouble(0));
buffer1.update(1, buffer2.getLong(1));
} else {
+ // Otherwise, we update the bufferSum and bufferCount.
Double newValue = buffer2.getDouble(0) + buffer1.getDouble(0);
buffer1.update(0, newValue);
buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1));
@@ -98,10 +118,12 @@ public class MyDoubleAvg extends UserDefinedAggregateFunction {
@Override public Object evaluate(Row buffer) {
if (buffer.isNullAt(0)) {
+ // If the bufferSum is still null, we return null because this function has not got
+ // any input row.
return null;
} else {
+ // Otherwise, we calculate the special average value.
return buffer.getDouble(0) / buffer.getLong(1) + 100.0;
}
}
}
-
diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java
index 1d4587a27c..da29e24d26 100644
--- a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java
+++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java
@@ -20,14 +20,18 @@ package test.org.apache.spark.sql.hive.aggregate;
import java.util.ArrayList;
import java.util.List;
-import org.apache.spark.sql.expressions.aggregate.MutableAggregationBuffer;
-import org.apache.spark.sql.expressions.aggregate.UserDefinedAggregateFunction;
+import org.apache.spark.sql.expressions.MutableAggregationBuffer;
+import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.Row;
+/**
+ * An example {@link UserDefinedAggregateFunction} to calculate the sum of a
+ * {@link org.apache.spark.sql.types.DoubleType} column.
+ */
public class MyDoubleSum extends UserDefinedAggregateFunction {
private StructType _inputDataType;
@@ -37,9 +41,9 @@ public class MyDoubleSum extends UserDefinedAggregateFunction {
private DataType _returnDataType;
public MyDoubleSum() {
- List<StructField> inputfields = new ArrayList<StructField>();
- inputfields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true));
- _inputDataType = DataTypes.createStructType(inputfields);
+ List<StructField> inputFields = new ArrayList<StructField>();
+ inputFields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true));
+ _inputDataType = DataTypes.createStructType(inputFields);
List<StructField> bufferFields = new ArrayList<StructField>();
bufferFields.add(DataTypes.createStructField("bufferDouble", DataTypes.DoubleType, true));
@@ -65,14 +69,20 @@ public class MyDoubleSum extends UserDefinedAggregateFunction {
}
@Override public void initialize(MutableAggregationBuffer buffer) {
+ // The initial value of the sum is null.
buffer.update(0, null);
}
@Override public void update(MutableAggregationBuffer buffer, Row input) {
+ // This input Row only has a single column storing the input value in Double.
+ // We only update the buffer when the input value is not null.
if (!input.isNullAt(0)) {
if (buffer.isNullAt(0)) {
+ // If the buffer value (the intermediate result of the sum) is still null,
+ // we set the input value to the buffer.
buffer.update(0, input.getDouble(0));
} else {
+ // Otherwise, we add the input value to the buffer value.
Double newValue = input.getDouble(0) + buffer.getDouble(0);
buffer.update(0, newValue);
}
@@ -80,10 +90,16 @@ public class MyDoubleSum extends UserDefinedAggregateFunction {
}
@Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
+ // buffer1 and buffer2 have the same structure.
+ // We only update the buffer1 when the input buffer2's value is not null.
if (!buffer2.isNullAt(0)) {
if (buffer1.isNullAt(0)) {
+ // If the buffer value (intermediate result of the sum) is still null,
+ // we set the it as the input buffer's value.
buffer1.update(0, buffer2.getDouble(0));
} else {
+ // Otherwise, we add the input buffer's value (buffer1) to the mutable
+ // buffer's value (buffer2).
Double newValue = buffer2.getDouble(0) + buffer1.getDouble(0);
buffer1.update(0, newValue);
}
@@ -92,8 +108,10 @@ public class MyDoubleSum extends UserDefinedAggregateFunction {
@Override public Object evaluate(Row buffer) {
if (buffer.isNullAt(0)) {
+ // If the buffer value is still null, we return null.
return null;
} else {
+ // Otherwise, the intermediate sum is the final result.
return buffer.getDouble(0);
}
}