aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorHerman van Hovell <hvanhovell@questtec.nl>2015-11-10 16:28:21 -0800
committerYin Huai <yhuai@databricks.com>2015-11-10 16:28:21 -0800
commit21c562fa03430365f5c2b7d6de1f8f60ab2140d4 (patch)
treed974f6433f33a0a401293a43315b98f002e04548 /sql
parent3121e78168808c015fb21da8b0d44bb33649fb81 (diff)
downloadspark-21c562fa03430365f5c2b7d6de1f8f60ab2140d4.tar.gz
spark-21c562fa03430365f5c2b7d6de1f8f60ab2140d4.tar.bz2
spark-21c562fa03430365f5c2b7d6de1f8f60ab2140d4.zip
[SPARK-9241][SQL] Supporting multiple DISTINCT columns - follow-up (3)
This PR is a 2nd follow-up for [SPARK-9241](https://issues.apache.org/jira/browse/SPARK-9241). It contains the following improvements: * Fix for a potential bug in distinct child expression and attribute alignment. * Improved handling of duplicate distinct child expressions. * Added test for distinct UDAF with multiple children. cc yhuai Author: Herman van Hovell <hvanhovell@questtec.nl> Closes #9566 from hvanhovell/SPARK-9241-followup-2.
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala9
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala41
2 files changed, 42 insertions, 8 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala
index 397eff0568..c0c960471a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala
@@ -151,11 +151,12 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP
}
// Setup unique distinct aggregate children.
- val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq
- val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair).toMap
- val distinctAggChildAttrs = distinctAggChildAttrMap.values.toSeq
+ val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq.distinct
+ val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair)
+ val distinctAggChildAttrs = distinctAggChildAttrMap.map(_._2)
// Setup expand & aggregate operators for distinct aggregate expressions.
+ val distinctAggChildAttrLookup = distinctAggChildAttrMap.toMap
val distinctAggOperatorMap = distinctAggGroups.toSeq.zipWithIndex.map {
case ((group, expressions), i) =>
val id = Literal(i + 1)
@@ -170,7 +171,7 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP
val operators = expressions.map { e =>
val af = e.aggregateFunction
val naf = patchAggregateFunctionChildren(af) { x =>
- evalWithinGroup(id, distinctAggChildAttrMap(x))
+ evalWithinGroup(id, distinctAggChildAttrLookup(x))
}
(e, e.copy(aggregateFunction = naf, isDistinct = false))
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index 6bf2c53440..8253921563 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -66,6 +66,36 @@ class ScalaAggregateFunction(schema: StructType) extends UserDefinedAggregateFun
}
}
+class LongProductSum extends UserDefinedAggregateFunction {
+ def inputSchema: StructType = new StructType()
+ .add("a", LongType)
+ .add("b", LongType)
+
+ def bufferSchema: StructType = new StructType()
+ .add("product", LongType)
+
+ def dataType: DataType = LongType
+
+ def deterministic: Boolean = true
+
+ def initialize(buffer: MutableAggregationBuffer): Unit = {
+ buffer(0) = 0L
+ }
+
+ def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
+ if (!(input.isNullAt(0) || input.isNullAt(1))) {
+ buffer(0) = buffer.getLong(0) + input.getLong(0) * input.getLong(1)
+ }
+ }
+
+ def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
+ buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
+ }
+
+ def evaluate(buffer: Row): Any =
+ buffer.getLong(0)
+}
+
abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
import testImplicits._
@@ -110,6 +140,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
// Register UDAFs
sqlContext.udf.register("mydoublesum", new MyDoubleSum)
sqlContext.udf.register("mydoubleavg", new MyDoubleAvg)
+ sqlContext.udf.register("longProductSum", new LongProductSum)
}
override def afterAll(): Unit = {
@@ -545,19 +576,21 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
| count(distinct value2),
| sum(distinct value2),
| count(distinct value1, value2),
+ | longProductSum(distinct value1, value2),
| count(value1),
| sum(value1),
| count(value2),
| sum(value2),
+ | longProductSum(value1, value2),
| count(*),
| count(1)
|FROM agg2
|GROUP BY key
""".stripMargin),
- Row(null, 3, 30, 3, 60, 3, 3, 30, 3, 60, 4, 4) ::
- Row(1, 2, 40, 3, -10, 3, 3, 70, 3, -10, 3, 3) ::
- Row(2, 2, 0, 1, 1, 1, 3, 1, 3, 3, 4, 4) ::
- Row(3, 0, null, 1, 3, 0, 0, null, 1, 3, 2, 2) :: Nil)
+ Row(null, 3, 30, 3, 60, 3, -4700, 3, 30, 3, 60, -4700, 4, 4) ::
+ Row(1, 2, 40, 3, -10, 3, -100, 3, 70, 3, -10, -100, 3, 3) ::
+ Row(2, 2, 0, 1, 1, 1, 1, 3, 1, 3, 3, 2, 4, 4) ::
+ Row(3, 0, null, 1, 3, 0, 0, 0, null, 1, 3, 0, 2, 2) :: Nil)
}
test("test count") {