diff options
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala | 3 | ||||
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala (renamed from sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala) | 122 | ||||
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala | 101 | ||||
-rw-r--r-- | sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java | 34 | ||||
-rw-r--r-- | sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java | 28 |
5 files changed, 187 insertions, 101 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala index 5b872f5e3e..0d4e30f292 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala @@ -19,7 +19,8 @@ package org.apache.spark.sql import org.apache.spark.Logging import org.apache.spark.sql.catalyst.expressions.{Expression} -import org.apache.spark.sql.expressions.aggregate.{ScalaUDAF, UserDefinedAggregateFunction} +import org.apache.spark.sql.execution.aggregate.ScalaUDAF +import org.apache.spark.sql.expressions.UserDefinedAggregateFunction class UDAFRegistration private[sql] (sqlContext: SQLContext) extends Logging { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index 4ada9eca7a..073c45ae2f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -15,87 +15,29 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.expressions.aggregate +package org.apache.spark.sql.execution.aggregate import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters} -import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection +import org.apache.spark.sql.catalyst.expressions.{MutableRow, InterpretedMutableProjection, AttributeReference, Expression} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction2 -import org.apache.spark.sql.types._ -import org.apache.spark.sql.Row +import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} +import org.apache.spark.sql.types.{Metadata, StructField, StructType, DataType} /** - * The abstract class for implementing user-defined aggregate function. + * A Mutable [[Row]] representing an mutable aggregation buffer. */ -abstract class UserDefinedAggregateFunction extends Serializable { - - /** - * A [[StructType]] represents data types of input arguments of this aggregate function. - * For example, if a [[UserDefinedAggregateFunction]] expects two input arguments - * with type of [[DoubleType]] and [[LongType]], the returned [[StructType]] will look like - * - * ``` - * StructType(Seq(StructField("doubleInput", DoubleType), StructField("longInput", LongType))) - * ``` - * - * The name of a field of this [[StructType]] is only used to identify the corresponding - * input argument. Users can choose names to identify the input arguments. - */ - def inputSchema: StructType - - /** - * A [[StructType]] represents data types of values in the aggregation buffer. - * For example, if a [[UserDefinedAggregateFunction]]'s buffer has two values - * (i.e. 
two intermediate values) with type of [[DoubleType]] and [[LongType]], - * the returned [[StructType]] will look like - * - * ``` - * StructType(Seq(StructField("doubleInput", DoubleType), StructField("longInput", LongType))) - * ``` - * - * The name of a field of this [[StructType]] is only used to identify the corresponding - * buffer value. Users can choose names to identify the input arguments. - */ - def bufferSchema: StructType - - /** - * The [[DataType]] of the returned value of this [[UserDefinedAggregateFunction]]. - */ - def returnDataType: DataType - - /** Indicates if this function is deterministic. */ - def deterministic: Boolean - - /** - * Initializes the given aggregation buffer. Initial values set by this method should satisfy - * the condition that when merging two buffers with initial values, the new buffer should - * still store initial values. - */ - def initialize(buffer: MutableAggregationBuffer): Unit - - /** Updates the given aggregation buffer `buffer` with new input data from `input`. */ - def update(buffer: MutableAggregationBuffer, input: Row): Unit - - /** Merges two aggregation buffers and stores the updated buffer values back in `buffer1`. */ - def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit - - /** - * Calculates the final result of this [[UserDefinedAggregateFunction]] based on the given - * aggregation buffer. 
- */ - def evaluate(buffer: Row): Any -} - -private[sql] abstract class AggregationBuffer( +private[sql] class MutableAggregationBufferImpl ( + schema: StructType, toCatalystConverters: Array[Any => Any], toScalaConverters: Array[Any => Any], - bufferOffset: Int) - extends Row { - - override def length: Int = toCatalystConverters.length + bufferOffset: Int, + var underlyingBuffer: MutableRow) + extends MutableAggregationBuffer { - protected val offsets: Array[Int] = { + private[this] val offsets: Array[Int] = { val newOffsets = new Array[Int](length) var i = 0 while (i < newOffsets.length) { @@ -104,18 +46,8 @@ private[sql] abstract class AggregationBuffer( } newOffsets } -} -/** - * A Mutable [[Row]] representing an mutable aggregation buffer. - */ -class MutableAggregationBuffer private[sql] ( - schema: StructType, - toCatalystConverters: Array[Any => Any], - toScalaConverters: Array[Any => Any], - bufferOffset: Int, - var underlyingBuffer: MutableRow) - extends AggregationBuffer(toCatalystConverters, toScalaConverters, bufferOffset) { + override def length: Int = toCatalystConverters.length override def get(i: Int): Any = { if (i >= length || i < 0) { @@ -133,8 +65,8 @@ class MutableAggregationBuffer private[sql] ( underlyingBuffer.update(offsets(i), toCatalystConverters(i)(value)) } - override def copy(): MutableAggregationBuffer = { - new MutableAggregationBuffer( + override def copy(): MutableAggregationBufferImpl = { + new MutableAggregationBufferImpl( schema, toCatalystConverters, toScalaConverters, @@ -146,13 +78,25 @@ class MutableAggregationBuffer private[sql] ( /** * A [[Row]] representing an immutable aggregation buffer. 
*/ -class InputAggregationBuffer private[sql] ( +private[sql] class InputAggregationBuffer private[sql] ( schema: StructType, toCatalystConverters: Array[Any => Any], toScalaConverters: Array[Any => Any], bufferOffset: Int, var underlyingInputBuffer: InternalRow) - extends AggregationBuffer(toCatalystConverters, toScalaConverters, bufferOffset) { + extends Row { + + private[this] val offsets: Array[Int] = { + val newOffsets = new Array[Int](length) + var i = 0 + while (i < newOffsets.length) { + newOffsets(i) = bufferOffset + i + i += 1 + } + newOffsets + } + + override def length: Int = toCatalystConverters.length override def get(i: Int): Any = { if (i >= length || i < 0) { @@ -179,7 +123,7 @@ class InputAggregationBuffer private[sql] ( * @param children * @param udaf */ -case class ScalaUDAF( +private[sql] case class ScalaUDAF( children: Seq[Expression], udaf: UserDefinedAggregateFunction) extends AggregateFunction2 with Logging { @@ -243,8 +187,8 @@ case class ScalaUDAF( bufferOffset, null) - lazy val mutableAggregateBuffer: MutableAggregationBuffer = - new MutableAggregationBuffer( + lazy val mutableAggregateBuffer: MutableAggregationBufferImpl = + new MutableAggregationBufferImpl( bufferSchema, bufferValuesToCatalystConverters, bufferValuesToScalaConverters, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala new file mode 100644 index 0000000000..278dd438fa --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.expressions + +import org.apache.spark.sql.Row +import org.apache.spark.sql.types._ +import org.apache.spark.annotation.Experimental + +/** + * :: Experimental :: + * The abstract class for implementing user-defined aggregate functions. + */ +@Experimental +abstract class UserDefinedAggregateFunction extends Serializable { + + /** + * A [[StructType]] represents data types of input arguments of this aggregate function. + * For example, if a [[UserDefinedAggregateFunction]] expects two input arguments + * with type of [[DoubleType]] and [[LongType]], the returned [[StructType]] will look like + * + * ``` + * new StructType() + * .add("doubleInput", DoubleType) + * .add("longInput", LongType) + * ``` + * + * The name of a field of this [[StructType]] is only used to identify the corresponding + * input argument. Users can choose names to identify the input arguments. + */ + def inputSchema: StructType + + /** + * A [[StructType]] represents data types of values in the aggregation buffer. + * For example, if a [[UserDefinedAggregateFunction]]'s buffer has two values + * (i.e. two intermediate values) with type of [[DoubleType]] and [[LongType]], + * the returned [[StructType]] will look like + * + * ``` + * new StructType() + * .add("doubleInput", DoubleType) + * .add("longInput", LongType) + * ``` + * + * The name of a field of this [[StructType]] is only used to identify the corresponding + * buffer value. Users can choose names to identify the input arguments. 
+ */ + def bufferSchema: StructType + + /** + * The [[DataType]] of the returned value of this [[UserDefinedAggregateFunction]]. + */ + def returnDataType: DataType + + /** Indicates if this function is deterministic. */ + def deterministic: Boolean + + /** + * Initializes the given aggregation buffer. Initial values set by this method should satisfy + * the condition that when merging two buffers with initial values, the new buffer + * should still store initial values. + */ + def initialize(buffer: MutableAggregationBuffer): Unit + + /** Updates the given aggregation buffer `buffer` with new input data from `input`. */ + def update(buffer: MutableAggregationBuffer, input: Row): Unit + + /** Merges two aggregation buffers and stores the updated buffer values back to `buffer1`. */ + def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit + + /** + * Calculates the final result of this [[UserDefinedAggregateFunction]] based on the given + * aggregation buffer. + */ + def evaluate(buffer: Row): Any +} + +/** + * :: Experimental :: + * A [[Row]] representing a mutable aggregation buffer. + */ +@Experimental +trait MutableAggregationBuffer extends Row { + + /** Update the ith value of this buffer. 
*/ + def update(i: Int, value: Any): Unit +} diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java index 5c9d0e97a9..a2247e3da1 100644 --- a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java +++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java @@ -21,13 +21,18 @@ import java.util.ArrayList; import java.util.List; import org.apache.spark.sql.Row; -import org.apache.spark.sql.expressions.aggregate.MutableAggregationBuffer; -import org.apache.spark.sql.expressions.aggregate.UserDefinedAggregateFunction; +import org.apache.spark.sql.expressions.MutableAggregationBuffer; +import org.apache.spark.sql.expressions.UserDefinedAggregateFunction; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; +/** + * An example {@link UserDefinedAggregateFunction} to calculate a special average value of a + * {@link org.apache.spark.sql.types.DoubleType} column. This special average value is the sum + * of the average value of input values and 100.0. 
+ */ public class MyDoubleAvg extends UserDefinedAggregateFunction { private StructType _inputDataType; @@ -37,10 +42,13 @@ public class MyDoubleAvg extends UserDefinedAggregateFunction { private DataType _returnDataType; public MyDoubleAvg() { - List<StructField> inputfields = new ArrayList<StructField>(); - inputfields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true)); - _inputDataType = DataTypes.createStructType(inputfields); + List<StructField> inputFields = new ArrayList<StructField>(); + inputFields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true)); + _inputDataType = DataTypes.createStructType(inputFields); + // The buffer has two values, bufferSum for storing the current sum and + // bufferCount for storing the number of non-null input values that have been contributed + // to the current sum. List<StructField> bufferFields = new ArrayList<StructField>(); bufferFields.add(DataTypes.createStructField("bufferSum", DataTypes.DoubleType, true)); bufferFields.add(DataTypes.createStructField("bufferCount", DataTypes.LongType, true)); @@ -66,16 +74,23 @@ public class MyDoubleAvg extends UserDefinedAggregateFunction { } @Override public void initialize(MutableAggregationBuffer buffer) { + // The initial value of the sum is null. buffer.update(0, null); + // The initial value of the count is 0. buffer.update(1, 0L); } @Override public void update(MutableAggregationBuffer buffer, Row input) { + // This input Row only has a single column storing the input value in Double. + // We only update the buffer when the input value is not null. if (!input.isNullAt(0)) { + // If the buffer value (the intermediate result of the sum) is still null, + // we set the input value to the buffer and set the bufferCount to 1. if (buffer.isNullAt(0)) { buffer.update(0, input.getDouble(0)); buffer.update(1, 1L); } else { + // Otherwise, update the bufferSum and increment bufferCount. 
Double newValue = input.getDouble(0) + buffer.getDouble(0); buffer.update(0, newValue); buffer.update(1, buffer.getLong(1) + 1L); @@ -84,11 +99,16 @@ public class MyDoubleAvg extends UserDefinedAggregateFunction { } @Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) { + // buffer1 and buffer2 have the same structure. + // We only update the buffer1 when the input buffer2's sum value is not null. if (!buffer2.isNullAt(0)) { if (buffer1.isNullAt(0)) { + // If the buffer value (intermediate result of the sum) is still null, + // we set it to the input buffer's value. buffer1.update(0, buffer2.getDouble(0)); buffer1.update(1, buffer2.getLong(1)); } else { + // Otherwise, we update the bufferSum and bufferCount. Double newValue = buffer2.getDouble(0) + buffer1.getDouble(0); buffer1.update(0, newValue); buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1)); @@ -98,10 +118,12 @@ public class MyDoubleAvg extends UserDefinedAggregateFunction { @Override public Object evaluate(Row buffer) { if (buffer.isNullAt(0)) { + // If the bufferSum is still null, we return null because this function has not got + // any input row. return null; } else { + // Otherwise, we calculate the special average value. 
return buffer.getDouble(0) / buffer.getLong(1) + 100.0; } } } - diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java index 1d4587a27c..da29e24d26 100644 --- a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java +++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java @@ -20,14 +20,18 @@ package test.org.apache.spark.sql.hive.aggregate; import java.util.ArrayList; import java.util.List; -import org.apache.spark.sql.expressions.aggregate.MutableAggregationBuffer; -import org.apache.spark.sql.expressions.aggregate.UserDefinedAggregateFunction; +import org.apache.spark.sql.expressions.MutableAggregationBuffer; +import org.apache.spark.sql.expressions.UserDefinedAggregateFunction; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.Row; +/** + * An example {@link UserDefinedAggregateFunction} to calculate the sum of a + * {@link org.apache.spark.sql.types.DoubleType} column. 
+ */ public class MyDoubleSum extends UserDefinedAggregateFunction { private StructType _inputDataType; @@ -37,9 +41,9 @@ public class MyDoubleSum extends UserDefinedAggregateFunction { private DataType _returnDataType; public MyDoubleSum() { - List<StructField> inputfields = new ArrayList<StructField>(); - inputfields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true)); - _inputDataType = DataTypes.createStructType(inputfields); + List<StructField> inputFields = new ArrayList<StructField>(); + inputFields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true)); + _inputDataType = DataTypes.createStructType(inputFields); List<StructField> bufferFields = new ArrayList<StructField>(); bufferFields.add(DataTypes.createStructField("bufferDouble", DataTypes.DoubleType, true)); @@ -65,14 +69,20 @@ public class MyDoubleSum extends UserDefinedAggregateFunction { } @Override public void initialize(MutableAggregationBuffer buffer) { + // The initial value of the sum is null. buffer.update(0, null); } @Override public void update(MutableAggregationBuffer buffer, Row input) { + // This input Row only has a single column storing the input value in Double. + // We only update the buffer when the input value is not null. if (!input.isNullAt(0)) { if (buffer.isNullAt(0)) { + // If the buffer value (the intermediate result of the sum) is still null, + // we set the input value to the buffer. buffer.update(0, input.getDouble(0)); } else { + // Otherwise, we add the input value to the buffer value. Double newValue = input.getDouble(0) + buffer.getDouble(0); buffer.update(0, newValue); } @@ -80,10 +90,16 @@ public class MyDoubleSum extends UserDefinedAggregateFunction { } @Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) { + // buffer1 and buffer2 have the same structure. + // We only update the buffer1 when the input buffer2's value is not null. 
if (!buffer2.isNullAt(0)) { if (buffer1.isNullAt(0)) { + // If the buffer value (intermediate result of the sum) is still null, + // we set it to the input buffer's value. buffer1.update(0, buffer2.getDouble(0)); } else { + // Otherwise, we add the input buffer's value (buffer2) to the mutable + // buffer's value (buffer1). Double newValue = buffer2.getDouble(0) + buffer1.getDouble(0); buffer1.update(0, newValue); } @@ -92,8 +108,10 @@ public class MyDoubleSum extends UserDefinedAggregateFunction { @Override public Object evaluate(Row buffer) { if (buffer.isNullAt(0)) { + // If the buffer value is still null, we return null. return null; } else { + // Otherwise, the intermediate sum is the final result. return buffer.getDouble(0); } } |