path: root/sql/hive
author     Josh Rosen <joshrosen@databricks.com>    2015-10-07 13:19:49 -0700
committer  Yin Huai <yhuai@databricks.com>          2015-10-07 13:19:49 -0700
commit     a9ecd06149df4ccafd3927c35f63b9f03f170ae5 (patch)
tree       99d41c5b5f8cd786e2c161c339373a740b4048eb /sql/hive
parent     5be5d247440d6346d667c4b3d817666126f62906 (diff)
[SPARK-10941] [SQL] Refactor AggregateFunction2 and AlgebraicAggregate interfaces to improve code clarity
This patch refactors several of the Aggregate2 interfaces in order to improve code clarity.

The biggest change is a refactoring of the `AggregateFunction2` class hierarchy. In the old code, we had a class named `AlgebraicAggregate` that inherited from `AggregateFunction2`, added a new set of methods, then banned the use of the inherited methods. I found this to be fairly confusing. If you look carefully at the existing code, you'll see that subclasses of `AggregateFunction2` fall into two disjoint categories: imperative aggregation functions, which directly extended `AggregateFunction2`, and declarative, expression-based aggregate functions, which extended `AlgebraicAggregate`. In order to make this more explicit, this patch refactors things so that `AggregateFunction2` is a sealed abstract class with two subclasses, `ImperativeAggregateFunction` and `ExpressionAggregateFunction`. The superclass, `AggregateFunction2`, now only contains methods and fields that are common to both subclasses.

After making this change, I updated the various AggregationIterator classes to comply with this new naming scheme. I also performed several small renamings in the aggregate interfaces themselves in order to improve clarity and rewrote or expanded a number of comments.

Author: Josh Rosen <joshrosen@databricks.com>

Closes #8973 from JoshRosen/tungsten-agg-comments.
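For orientation, the hierarchy described above can be sketched roughly as follows. This is an illustrative outline only: the class and member names (`AggregateFunction2`, `ImperativeAggregate`, `ExpressionAggregateFunction`, `aggBufferSchema`, `aggBufferAttributes`) are taken from this commit message and from the diff below, and the exact signatures in Spark's source differ from this simplified sketch.

// Illustrative sketch of the refactored hierarchy; not the actual Spark source.
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, MutableRow}
import org.apache.spark.sql.types.StructType

// Common parent: only the methods and fields shared by both kinds of
// aggregate functions, e.g. the aggregation buffer's schema and attributes.
sealed abstract class AggregateFunction2 {
  def aggBufferSchema: StructType
  def aggBufferAttributes: Seq[AttributeReference]
}

// Imperative aggregates mutate a buffer row directly; this is the style
// HiveUDAFFunction uses in the diff below.
abstract class ImperativeAggregate extends AggregateFunction2 {
  def initialize(buffer: MutableRow): Unit
  def update(buffer: MutableRow, input: InternalRow): Unit
  def merge(buffer1: MutableRow, buffer2: InternalRow): Unit
  def eval(buffer: InternalRow): Any
}

// Declarative, expression-based aggregates describe the same steps with
// Catalyst expressions instead of imperative code.
abstract class ExpressionAggregateFunction extends AggregateFunction2 {
  def initialValues: Seq[Expression]
  def updateExpressions: Seq[Expression]
  def mergeExpressions: Seq[Expression]
  def evaluateExpression: Expression
}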
Diffstat (limited to 'sql/hive')
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala                         8
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala  4
2 files changed, 5 insertions, 7 deletions
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index a85d4db88d..18bbdb9908 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -554,7 +554,7 @@ private[hive] case class HiveUDAFFunction(
funcWrapper: HiveFunctionWrapper,
children: Seq[Expression],
isUDAFBridgeRequired: Boolean = false)
- extends AggregateFunction2 with HiveInspectors {
+ extends ImperativeAggregate with HiveInspectors {
def this() = this(null, null)
@@ -598,7 +598,7 @@ private[hive] case class HiveUDAFFunction(
// Hive UDAF has its own buffer, so we don't need to occupy a slot in the aggregation
// buffer for it.
- override def bufferSchema: StructType = StructType(Nil)
+ override def aggBufferSchema: StructType = StructType(Nil)
override def update(_buffer: MutableRow, input: InternalRow): Unit = {
val inputs = inputProjection(input)
@@ -610,13 +610,11 @@ private[hive] case class HiveUDAFFunction(
"Hive UDAF doesn't support partial aggregate")
}
- override def cloneBufferAttributes: Seq[Attribute] = Nil
-
override def initialize(_buffer: MutableRow): Unit = {
buffer = function.getNewAggregationBuffer
}
- override def bufferAttributes: Seq[AttributeReference] = Nil
+ override def aggBufferAttributes: Seq[AttributeReference] = Nil
// We rely on Hive to check the input data types, so use `AnyDataType` here to bypass our
// catalyst type checking framework.
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index 24b1846923..c9e1bb1995 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -343,7 +343,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
Row(null, null, 110.0, null, null, 10.0) :: Nil)
}
- test("non-AlgebraicAggregate aggreguate function") {
+ test("interpreted aggregate function") {
checkAnswer(
sqlContext.sql(
"""
@@ -368,7 +368,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
Row(null) :: Nil)
}
- test("non-AlgebraicAggregate and AlgebraicAggregate aggreguate function") {
+ test("interpreted and expression-based aggregation functions") {
checkAnswer(
sqlContext.sql(
"""