From 55946e76fd136958081f073c0c5e3ff8563d505b Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Mon, 27 Jul 2015 13:26:57 -0700 Subject: [SPARK-9349] [SQL] UDAF cleanup https://issues.apache.org/jira/browse/SPARK-9349 With this PR, we only expose `UserDefinedAggregateFunction` (an abstract class) and `MutableAggregationBuffer` (an interface). Other internal wrappers and helper classes are moved to `org.apache.spark.sql.execution.aggregate` and marked as `private[sql]`. Author: Yin Huai Closes #7687 from yhuai/UDAF-cleanup and squashes the following commits: db36542 [Yin Huai] Add comments to UDAF examples. ae17f66 [Yin Huai] Address comments. 9c9fa5f [Yin Huai] UDAF cleanup. --- .../org/apache/spark/sql/UDAFRegistration.scala | 3 +- .../spark/sql/execution/aggregate/udaf.scala | 231 +++++++++++++++++ .../spark/sql/expressions/aggregate/udaf.scala | 287 --------------------- .../org/apache/spark/sql/expressions/udaf.scala | 101 ++++++++ .../spark/sql/hive/aggregate/MyDoubleAvg.java | 34 ++- .../spark/sql/hive/aggregate/MyDoubleSum.java | 28 +- 6 files changed, 385 insertions(+), 299 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala index 5b872f5e3e..0d4e30f292 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala @@ -19,7 +19,8 @@ package org.apache.spark.sql import org.apache.spark.Logging import org.apache.spark.sql.catalyst.expressions.{Expression} -import org.apache.spark.sql.expressions.aggregate.{ScalaUDAF, UserDefinedAggregateFunction} +import org.apache.spark.sql.execution.aggregate.ScalaUDAF +import org.apache.spark.sql.expressions.UserDefinedAggregateFunction class UDAFRegistration private[sql] (sqlContext: SQLContext) extends Logging { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala new file mode 100644 index 0000000000..073c45ae2f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -0,0 +1,231 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.Logging +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters} +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection +import org.apache.spark.sql.catalyst.expressions.{MutableRow, InterpretedMutableProjection, AttributeReference, Expression} +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction2 +import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} +import org.apache.spark.sql.types.{Metadata, StructField, StructType, DataType} + +/** + * A Mutable [[Row]] representing an mutable aggregation buffer. + */ +private[sql] class MutableAggregationBufferImpl ( + schema: StructType, + toCatalystConverters: Array[Any => Any], + toScalaConverters: Array[Any => Any], + bufferOffset: Int, + var underlyingBuffer: MutableRow) + extends MutableAggregationBuffer { + + private[this] val offsets: Array[Int] = { + val newOffsets = new Array[Int](length) + var i = 0 + while (i < newOffsets.length) { + newOffsets(i) = bufferOffset + i + i += 1 + } + newOffsets + } + + override def length: Int = toCatalystConverters.length + + override def get(i: Int): Any = { + if (i >= length || i < 0) { + throw new IllegalArgumentException( + s"Could not access ${i}th value in this buffer because it only has $length values.") + } + toScalaConverters(i)(underlyingBuffer.get(offsets(i), schema(i).dataType)) + } + + def update(i: Int, value: Any): Unit = { + if (i >= length || i < 0) { + throw new IllegalArgumentException( + s"Could not update ${i}th value in this buffer because it only has $length values.") + } + underlyingBuffer.update(offsets(i), toCatalystConverters(i)(value)) + } + + override def copy(): MutableAggregationBufferImpl = { + new MutableAggregationBufferImpl( + schema, + toCatalystConverters, + toScalaConverters, + bufferOffset, + underlyingBuffer) + } +} + +/** + * A [[Row]] representing an immutable aggregation buffer. + */ +private[sql] class InputAggregationBuffer private[sql] ( + schema: StructType, + toCatalystConverters: Array[Any => Any], + toScalaConverters: Array[Any => Any], + bufferOffset: Int, + var underlyingInputBuffer: InternalRow) + extends Row { + + private[this] val offsets: Array[Int] = { + val newOffsets = new Array[Int](length) + var i = 0 + while (i < newOffsets.length) { + newOffsets(i) = bufferOffset + i + i += 1 + } + newOffsets + } + + override def length: Int = toCatalystConverters.length + + override def get(i: Int): Any = { + if (i >= length || i < 0) { + throw new IllegalArgumentException( + s"Could not access ${i}th value in this buffer because it only has $length values.") + } + // TODO: Use buffer schema to avoid using generic getter. + toScalaConverters(i)(underlyingInputBuffer.get(offsets(i), schema(i).dataType)) + } + + override def copy(): InputAggregationBuffer = { + new InputAggregationBuffer( + schema, + toCatalystConverters, + toScalaConverters, + bufferOffset, + underlyingInputBuffer) + } +} + +/** + * The internal wrapper used to hook a [[UserDefinedAggregateFunction]] `udaf` in the + * internal aggregation code path. + * @param children + * @param udaf + */ +private[sql] case class ScalaUDAF( + children: Seq[Expression], + udaf: UserDefinedAggregateFunction) + extends AggregateFunction2 with Logging { + + require( + children.length == udaf.inputSchema.length, + s"$udaf only accepts ${udaf.inputSchema.length} arguments, " + + s"but ${children.length} are provided.") + + override def nullable: Boolean = true + + override def dataType: DataType = udaf.returnDataType + + override def deterministic: Boolean = udaf.deterministic + + override val inputTypes: Seq[DataType] = udaf.inputSchema.map(_.dataType) + + override val bufferSchema: StructType = udaf.bufferSchema + + override val bufferAttributes: Seq[AttributeReference] = bufferSchema.toAttributes + + override lazy val cloneBufferAttributes = bufferAttributes.map(_.newInstance()) + + val childrenSchema: StructType = { + val inputFields = children.zipWithIndex.map { + case (child, index) => + StructField(s"input$index", child.dataType, child.nullable, Metadata.empty) + } + StructType(inputFields) + } + + lazy val inputProjection = { + val inputAttributes = childrenSchema.toAttributes + log.debug( + s"Creating MutableProj: $children, inputSchema: $inputAttributes.") + try { + GenerateMutableProjection.generate(children, inputAttributes)() + } catch { + case e: Exception => + log.error("Failed to generate mutable projection, fallback to interpreted", e) + new InterpretedMutableProjection(children, inputAttributes) + } + } + + val inputToScalaConverters: Any => Any = + CatalystTypeConverters.createToScalaConverter(childrenSchema) + + val bufferValuesToCatalystConverters: Array[Any => Any] = bufferSchema.fields.map { field => + CatalystTypeConverters.createToCatalystConverter(field.dataType) + } + + val bufferValuesToScalaConverters: Array[Any => Any] = bufferSchema.fields.map { field => + CatalystTypeConverters.createToScalaConverter(field.dataType) + } + + lazy val inputAggregateBuffer: InputAggregationBuffer = + new InputAggregationBuffer( + bufferSchema, + bufferValuesToCatalystConverters, + bufferValuesToScalaConverters, + bufferOffset, + null) + + lazy val mutableAggregateBuffer: MutableAggregationBufferImpl = + new MutableAggregationBufferImpl( + bufferSchema, + bufferValuesToCatalystConverters, + bufferValuesToScalaConverters, + bufferOffset, + null) + + + override def initialize(buffer: MutableRow): Unit = { + mutableAggregateBuffer.underlyingBuffer = buffer + + udaf.initialize(mutableAggregateBuffer) + } + + override def update(buffer: MutableRow, input: InternalRow): Unit = { + mutableAggregateBuffer.underlyingBuffer = buffer + + udaf.update( + mutableAggregateBuffer, + inputToScalaConverters(inputProjection(input)).asInstanceOf[Row]) + } + + override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { + mutableAggregateBuffer.underlyingBuffer = buffer1 + inputAggregateBuffer.underlyingInputBuffer = buffer2 + + udaf.merge(mutableAggregateBuffer, inputAggregateBuffer) + } + + override def eval(buffer: InternalRow = null): Any = { + inputAggregateBuffer.underlyingInputBuffer = buffer + + udaf.evaluate(inputAggregateBuffer) + } + + override def toString: String = { + s"""${udaf.getClass.getSimpleName}(${children.mkString(",")})""" + } + + override def nodeName: String = udaf.getClass.getSimpleName +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala deleted file mode 100644 index 4ada9eca7a..0000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala +++ /dev/null @@ -1,287 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.expressions.aggregate - -import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection -import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters} -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction2 -import org.apache.spark.sql.types._ -import org.apache.spark.sql.Row - -/** - * The abstract class for implementing user-defined aggregate function. - */ -abstract class UserDefinedAggregateFunction extends Serializable { - - /** - * A [[StructType]] represents data types of input arguments of this aggregate function. - * For example, if a [[UserDefinedAggregateFunction]] expects two input arguments - * with type of [[DoubleType]] and [[LongType]], the returned [[StructType]] will look like - * - * ``` - * StructType(Seq(StructField("doubleInput", DoubleType), StructField("longInput", LongType))) - * ``` - * - * The name of a field of this [[StructType]] is only used to identify the corresponding - * input argument. Users can choose names to identify the input arguments. - */ - def inputSchema: StructType - - /** - * A [[StructType]] represents data types of values in the aggregation buffer. - * For example, if a [[UserDefinedAggregateFunction]]'s buffer has two values - * (i.e. two intermediate values) with type of [[DoubleType]] and [[LongType]], - * the returned [[StructType]] will look like - * - * ``` - * StructType(Seq(StructField("doubleInput", DoubleType), StructField("longInput", LongType))) - * ``` - * - * The name of a field of this [[StructType]] is only used to identify the corresponding - * buffer value. Users can choose names to identify the input arguments. - */ - def bufferSchema: StructType - - /** - * The [[DataType]] of the returned value of this [[UserDefinedAggregateFunction]]. - */ - def returnDataType: DataType - - /** Indicates if this function is deterministic. */ - def deterministic: Boolean - - /** - * Initializes the given aggregation buffer. Initial values set by this method should satisfy - * the condition that when merging two buffers with initial values, the new buffer should - * still store initial values. - */ - def initialize(buffer: MutableAggregationBuffer): Unit - - /** Updates the given aggregation buffer `buffer` with new input data from `input`. */ - def update(buffer: MutableAggregationBuffer, input: Row): Unit - - /** Merges two aggregation buffers and stores the updated buffer values back in `buffer1`. */ - def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit - - /** - * Calculates the final result of this [[UserDefinedAggregateFunction]] based on the given - * aggregation buffer. - */ - def evaluate(buffer: Row): Any -} - -private[sql] abstract class AggregationBuffer( - toCatalystConverters: Array[Any => Any], - toScalaConverters: Array[Any => Any], - bufferOffset: Int) - extends Row { - - override def length: Int = toCatalystConverters.length - - protected val offsets: Array[Int] = { - val newOffsets = new Array[Int](length) - var i = 0 - while (i < newOffsets.length) { - newOffsets(i) = bufferOffset + i - i += 1 - } - newOffsets - } -} - -/** - * A Mutable [[Row]] representing an mutable aggregation buffer. - */ -class MutableAggregationBuffer private[sql] ( - schema: StructType, - toCatalystConverters: Array[Any => Any], - toScalaConverters: Array[Any => Any], - bufferOffset: Int, - var underlyingBuffer: MutableRow) - extends AggregationBuffer(toCatalystConverters, toScalaConverters, bufferOffset) { - - override def get(i: Int): Any = { - if (i >= length || i < 0) { - throw new IllegalArgumentException( - s"Could not access ${i}th value in this buffer because it only has $length values.") - } - toScalaConverters(i)(underlyingBuffer.get(offsets(i), schema(i).dataType)) - } - - def update(i: Int, value: Any): Unit = { - if (i >= length || i < 0) { - throw new IllegalArgumentException( - s"Could not update ${i}th value in this buffer because it only has $length values.") - } - underlyingBuffer.update(offsets(i), toCatalystConverters(i)(value)) - } - - override def copy(): MutableAggregationBuffer = { - new MutableAggregationBuffer( - schema, - toCatalystConverters, - toScalaConverters, - bufferOffset, - underlyingBuffer) - } -} - -/** - * A [[Row]] representing an immutable aggregation buffer. - */ -class InputAggregationBuffer private[sql] ( - schema: StructType, - toCatalystConverters: Array[Any => Any], - toScalaConverters: Array[Any => Any], - bufferOffset: Int, - var underlyingInputBuffer: InternalRow) - extends AggregationBuffer(toCatalystConverters, toScalaConverters, bufferOffset) { - - override def get(i: Int): Any = { - if (i >= length || i < 0) { - throw new IllegalArgumentException( - s"Could not access ${i}th value in this buffer because it only has $length values.") - } - // TODO: Use buffer schema to avoid using generic getter. - toScalaConverters(i)(underlyingInputBuffer.get(offsets(i), schema(i).dataType)) - } - - override def copy(): InputAggregationBuffer = { - new InputAggregationBuffer( - schema, - toCatalystConverters, - toScalaConverters, - bufferOffset, - underlyingInputBuffer) - } -} - -/** - * The internal wrapper used to hook a [[UserDefinedAggregateFunction]] `udaf` in the - * internal aggregation code path. - * @param children - * @param udaf - */ -case class ScalaUDAF( - children: Seq[Expression], - udaf: UserDefinedAggregateFunction) - extends AggregateFunction2 with Logging { - - require( - children.length == udaf.inputSchema.length, - s"$udaf only accepts ${udaf.inputSchema.length} arguments, " + - s"but ${children.length} are provided.") - - override def nullable: Boolean = true - - override def dataType: DataType = udaf.returnDataType - - override def deterministic: Boolean = udaf.deterministic - - override val inputTypes: Seq[DataType] = udaf.inputSchema.map(_.dataType) - - override val bufferSchema: StructType = udaf.bufferSchema - - override val bufferAttributes: Seq[AttributeReference] = bufferSchema.toAttributes - - override lazy val cloneBufferAttributes = bufferAttributes.map(_.newInstance()) - - val childrenSchema: StructType = { - val inputFields = children.zipWithIndex.map { - case (child, index) => - StructField(s"input$index", child.dataType, child.nullable, Metadata.empty) - } - StructType(inputFields) - } - - lazy val inputProjection = { - val inputAttributes = childrenSchema.toAttributes - log.debug( - s"Creating MutableProj: $children, inputSchema: $inputAttributes.") - try { - GenerateMutableProjection.generate(children, inputAttributes)() - } catch { - case e: Exception => - log.error("Failed to generate mutable projection, fallback to interpreted", e) - new InterpretedMutableProjection(children, inputAttributes) - } - } - - val inputToScalaConverters: Any => Any = - CatalystTypeConverters.createToScalaConverter(childrenSchema) - - val bufferValuesToCatalystConverters: Array[Any => Any] = bufferSchema.fields.map { field => - CatalystTypeConverters.createToCatalystConverter(field.dataType) - } - - val bufferValuesToScalaConverters: Array[Any => Any] = bufferSchema.fields.map { field => - CatalystTypeConverters.createToScalaConverter(field.dataType) - } - - lazy val inputAggregateBuffer: InputAggregationBuffer = - new InputAggregationBuffer( - bufferSchema, - bufferValuesToCatalystConverters, - bufferValuesToScalaConverters, - bufferOffset, - null) - - lazy val mutableAggregateBuffer: MutableAggregationBuffer = - new MutableAggregationBuffer( - bufferSchema, - bufferValuesToCatalystConverters, - bufferValuesToScalaConverters, - bufferOffset, - null) - - - override def initialize(buffer: MutableRow): Unit = { - mutableAggregateBuffer.underlyingBuffer = buffer - - udaf.initialize(mutableAggregateBuffer) - } - - override def update(buffer: MutableRow, input: InternalRow): Unit = { - mutableAggregateBuffer.underlyingBuffer = buffer - - udaf.update( - mutableAggregateBuffer, - inputToScalaConverters(inputProjection(input)).asInstanceOf[Row]) - } - - override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { - mutableAggregateBuffer.underlyingBuffer = buffer1 - inputAggregateBuffer.underlyingInputBuffer = buffer2 - - udaf.merge(mutableAggregateBuffer, inputAggregateBuffer) - } - - override def eval(buffer: InternalRow = null): Any = { - inputAggregateBuffer.underlyingInputBuffer = buffer - - udaf.evaluate(inputAggregateBuffer) - } - - override def toString: String = { - s"""${udaf.getClass.getSimpleName}(${children.mkString(",")})""" - } - - override def nodeName: String = udaf.getClass.getSimpleName -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala new file mode 100644 index 0000000000..278dd438fa --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.expressions + +import org.apache.spark.sql.Row +import org.apache.spark.sql.types._ +import org.apache.spark.annotation.Experimental + +/** + * :: Experimental :: + * The abstract class for implementing user-defined aggregate functions. + */ +@Experimental +abstract class UserDefinedAggregateFunction extends Serializable { + + /** + * A [[StructType]] represents data types of input arguments of this aggregate function. + * For example, if a [[UserDefinedAggregateFunction]] expects two input arguments + * with type of [[DoubleType]] and [[LongType]], the returned [[StructType]] will look like + * + * ``` + * new StructType() + * .add("doubleInput", DoubleType) + * .add("longInput", LongType) + * ``` + * + * The name of a field of this [[StructType]] is only used to identify the corresponding + * input argument. Users can choose names to identify the input arguments. + */ + def inputSchema: StructType + + /** + * A [[StructType]] represents data types of values in the aggregation buffer. + * For example, if a [[UserDefinedAggregateFunction]]'s buffer has two values + * (i.e. two intermediate values) with type of [[DoubleType]] and [[LongType]], + * the returned [[StructType]] will look like + * + * ``` + * new StructType() + * .add("doubleInput", DoubleType) + * .add("longInput", LongType) + * ``` + * + * The name of a field of this [[StructType]] is only used to identify the corresponding + * buffer value. Users can choose names to identify the input arguments. + */ + def bufferSchema: StructType + + /** + * The [[DataType]] of the returned value of this [[UserDefinedAggregateFunction]]. + */ + def returnDataType: DataType + + /** Indicates if this function is deterministic. */ + def deterministic: Boolean + + /** + * Initializes the given aggregation buffer. Initial values set by this method should satisfy + * the condition that when merging two buffers with initial values, the new buffer + * still store initial values. + */ + def initialize(buffer: MutableAggregationBuffer): Unit + + /** Updates the given aggregation buffer `buffer` with new input data from `input`. */ + def update(buffer: MutableAggregationBuffer, input: Row): Unit + + /** Merges two aggregation buffers and stores the updated buffer values back to `buffer1`. */ + def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit + + /** + * Calculates the final result of this [[UserDefinedAggregateFunction]] based on the given + * aggregation buffer. + */ + def evaluate(buffer: Row): Any +} + +/** + * :: Experimental :: + * A [[Row]] representing an mutable aggregation buffer. + */ +@Experimental +trait MutableAggregationBuffer extends Row { + + /** Update the ith value of this buffer. */ + def update(i: Int, value: Any): Unit +} diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java index 5c9d0e97a9..a2247e3da1 100644 --- a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java +++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java @@ -21,13 +21,18 @@ import java.util.ArrayList; import java.util.List; import org.apache.spark.sql.Row; -import org.apache.spark.sql.expressions.aggregate.MutableAggregationBuffer; -import org.apache.spark.sql.expressions.aggregate.UserDefinedAggregateFunction; +import org.apache.spark.sql.expressions.MutableAggregationBuffer; +import org.apache.spark.sql.expressions.UserDefinedAggregateFunction; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; +/** + * An example {@link UserDefinedAggregateFunction} to calculate a special average value of a + * {@link org.apache.spark.sql.types.DoubleType} column. This special average value is the sum + * of the average value of input values and 100.0. + */ public class MyDoubleAvg extends UserDefinedAggregateFunction { private StructType _inputDataType; @@ -37,10 +42,13 @@ public class MyDoubleAvg extends UserDefinedAggregateFunction { private DataType _returnDataType; public MyDoubleAvg() { - List inputfields = new ArrayList(); - inputfields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true)); - _inputDataType = DataTypes.createStructType(inputfields); + List inputFields = new ArrayList(); + inputFields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true)); + _inputDataType = DataTypes.createStructType(inputFields); + // The buffer has two values, bufferSum for storing the current sum and + // bufferCount for storing the number of non-null input values that have been contribuetd + // to the current sum. List bufferFields = new ArrayList(); bufferFields.add(DataTypes.createStructField("bufferSum", DataTypes.DoubleType, true)); bufferFields.add(DataTypes.createStructField("bufferCount", DataTypes.LongType, true)); @@ -66,16 +74,23 @@ public class MyDoubleAvg extends UserDefinedAggregateFunction { } @Override public void initialize(MutableAggregationBuffer buffer) { + // The initial value of the sum is null. buffer.update(0, null); + // The initial value of the count is 0. buffer.update(1, 0L); } @Override public void update(MutableAggregationBuffer buffer, Row input) { + // This input Row only has a single column storing the input value in Double. + // We only update the buffer when the input value is not null. if (!input.isNullAt(0)) { + // If the buffer value (the intermediate result of the sum) is still null, + // we set the input value to the buffer and set the bufferCount to 1. if (buffer.isNullAt(0)) { buffer.update(0, input.getDouble(0)); buffer.update(1, 1L); } else { + // Otherwise, update the bufferSum and increment bufferCount. Double newValue = input.getDouble(0) + buffer.getDouble(0); buffer.update(0, newValue); buffer.update(1, buffer.getLong(1) + 1L); @@ -84,11 +99,16 @@ public class MyDoubleAvg extends UserDefinedAggregateFunction { } @Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) { + // buffer1 and buffer2 have the same structure. + // We only update the buffer1 when the input buffer2's sum value is not null. if (!buffer2.isNullAt(0)) { if (buffer1.isNullAt(0)) { + // If the buffer value (intermediate result of the sum) is still null, + // we set the it as the input buffer's value. buffer1.update(0, buffer2.getDouble(0)); buffer1.update(1, buffer2.getLong(1)); } else { + // Otherwise, we update the bufferSum and bufferCount. Double newValue = buffer2.getDouble(0) + buffer1.getDouble(0); buffer1.update(0, newValue); buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1)); @@ -98,10 +118,12 @@ public class MyDoubleAvg extends UserDefinedAggregateFunction { @Override public Object evaluate(Row buffer) { if (buffer.isNullAt(0)) { + // If the bufferSum is still null, we return null because this function has not got + // any input row. return null; } else { + // Otherwise, we calculate the special average value. return buffer.getDouble(0) / buffer.getLong(1) + 100.0; } } } - diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java index 1d4587a27c..da29e24d26 100644 --- a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java +++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java @@ -20,14 +20,18 @@ package test.org.apache.spark.sql.hive.aggregate; import java.util.ArrayList; import java.util.List; -import org.apache.spark.sql.expressions.aggregate.MutableAggregationBuffer; -import org.apache.spark.sql.expressions.aggregate.UserDefinedAggregateFunction; +import org.apache.spark.sql.expressions.MutableAggregationBuffer; +import org.apache.spark.sql.expressions.UserDefinedAggregateFunction; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.Row; +/** + * An example {@link UserDefinedAggregateFunction} to calculate the sum of a + * {@link org.apache.spark.sql.types.DoubleType} column. + */ public class MyDoubleSum extends UserDefinedAggregateFunction { private StructType _inputDataType; @@ -37,9 +41,9 @@ public class MyDoubleSum extends UserDefinedAggregateFunction { private DataType _returnDataType; public MyDoubleSum() { - List inputfields = new ArrayList(); - inputfields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true)); - _inputDataType = DataTypes.createStructType(inputfields); + List inputFields = new ArrayList(); + inputFields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true)); + _inputDataType = DataTypes.createStructType(inputFields); List bufferFields = new ArrayList(); bufferFields.add(DataTypes.createStructField("bufferDouble", DataTypes.DoubleType, true)); @@ -65,14 +69,20 @@ public class MyDoubleSum extends UserDefinedAggregateFunction { } @Override public void initialize(MutableAggregationBuffer buffer) { + // The initial value of the sum is null. buffer.update(0, null); } @Override public void update(MutableAggregationBuffer buffer, Row input) { + // This input Row only has a single column storing the input value in Double. + // We only update the buffer when the input value is not null. if (!input.isNullAt(0)) { if (buffer.isNullAt(0)) { + // If the buffer value (the intermediate result of the sum) is still null, + // we set the input value to the buffer. buffer.update(0, input.getDouble(0)); } else { + // Otherwise, we add the input value to the buffer value. Double newValue = input.getDouble(0) + buffer.getDouble(0); buffer.update(0, newValue); } @@ -80,10 +90,16 @@ public class MyDoubleSum extends UserDefinedAggregateFunction { } @Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) { + // buffer1 and buffer2 have the same structure. + // We only update the buffer1 when the input buffer2's value is not null. if (!buffer2.isNullAt(0)) { if (buffer1.isNullAt(0)) { + // If the buffer value (intermediate result of the sum) is still null, + // we set the it as the input buffer's value. buffer1.update(0, buffer2.getDouble(0)); } else { + // Otherwise, we add the input buffer's value (buffer1) to the mutable + // buffer's value (buffer2). Double newValue = buffer2.getDouble(0) + buffer1.getDouble(0); buffer1.update(0, newValue); } @@ -92,8 +108,10 @@ public class MyDoubleSum extends UserDefinedAggregateFunction { @Override public Object evaluate(Row buffer) { if (buffer.isNullAt(0)) { + // If the buffer value is still null, we return null. return null; } else { + // Otherwise, the intermediate sum is the final result. return buffer.getDouble(0); } } -- cgit v1.2.3