diff options
author | ravipesala <ravindra.pesala@huawei.com> | 2014-12-18 20:19:10 -0800 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2014-12-18 20:19:10 -0800 |
commit | 7687415c2578b5bdc79c9646c246e52da9a4dd4a (patch) | |
tree | b3223db3b78f0e7f10feb551fff346145bf8da47 /sql | |
parent | e7de7e5f46821e1ba3b070b21d6bcf6d5ec8a796 (diff) | |
download | spark-7687415c2578b5bdc79c9646c246e52da9a4dd4a.tar.gz spark-7687415c2578b5bdc79c9646c246e52da9a4dd4a.tar.bz2 spark-7687415c2578b5bdc79c9646c246e52da9a4dd4a.zip |
[SPARK-2554][SQL] Supporting SumDistinct partial aggregation
Adding support for partial aggregation of SumDistinct
Author: ravipesala <ravindra.pesala@huawei.com>
Closes #3348 from ravipesala/SPARK-2554 and squashes the following commits:
fd28e4d [ravipesala] Fixed review comments
e60e67f [ravipesala] Fixed test cases and made it as nullable
32fe234 [ravipesala] Supporting SumDistinct partial aggregation Conflicts: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
Diffstat (limited to 'sql')
-rwxr-xr-x | sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala | 53 | ||||
-rw-r--r-- | sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala | 13 |
2 files changed, 58 insertions, 8 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 0cd90866e1..5ea9868e9e 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -361,10 +361,10 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[ } case class SumDistinct(child: Expression) - extends AggregateExpression with trees.UnaryNode[Expression] { + extends PartialAggregate with trees.UnaryNode[Expression] { + def this() = this(null) override def nullable = true - override def dataType = child.dataType match { case DecimalType.Fixed(precision, scale) => DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive @@ -373,10 +373,55 @@ case class SumDistinct(child: Expression) case _ => child.dataType } + override def toString = s"SUM(DISTINCT ${child})" + override def newInstance() = new SumDistinctFunction(child, this) + + override def asPartial = { + val partialSet = Alias(CollectHashSet(child :: Nil), "partialSets")() + SplitEvaluation( + CombineSetsAndSum(partialSet.toAttribute, this), + partialSet :: Nil) + } +} - override def toString = s"SUM(DISTINCT $child)" +case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends AggregateExpression { + def this() = this(null, null) - override def newInstance() = new SumDistinctFunction(child, this) + override def children = inputSet :: Nil + override def nullable = true + override def dataType = base.dataType + override def toString = s"CombineAndSum($inputSet)" + override def newInstance() = new CombineSetsAndSumFunction(inputSet, this) +} + +case class CombineSetsAndSumFunction( + @transient inputSet: Expression, + @transient base: AggregateExpression) + extends AggregateFunction { + + def this() = 
this(null, null) // Required for serialization. + + val seen = new OpenHashSet[Any]() + + override def update(input: Row): Unit = { + val inputSetEval = inputSet.eval(input).asInstanceOf[OpenHashSet[Any]] + val inputIterator = inputSetEval.iterator + while (inputIterator.hasNext) { + seen.add(inputIterator.next) + } + } + + override def eval(input: Row): Any = { + val casted = seen.asInstanceOf[OpenHashSet[Row]] + if (casted.size == 0) { + null + } else { + Cast(Literal( + casted.iterator.map(f => f.apply(0)).reduceLeft( + base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)), + base.dataType).eval(null) + } + } } case class First(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 96f3430207..f57f31af15 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -185,9 +185,14 @@ class SQLQuerySuite extends QueryTest { sql("SELECT case when ~1=-2 then 1 else 0 end FROM src"), sql("SELECT 1 FROM src").collect().toSeq) } - - test("SPARK-4154 Query does not work if it has 'not between' in Spark SQL and HQL") { - checkAnswer(sql("SELECT key FROM src WHERE key not between 0 and 10 order by key"), - sql("SELECT key FROM src WHERE key between 11 and 500 order by key").collect().toSeq) + + test("SPARK-4154 Query does not work if it has 'not between' in Spark SQL and HQL") { + checkAnswer(sql("SELECT key FROM src WHERE key not between 0 and 10 order by key"), + sql("SELECT key FROM src WHERE key between 11 and 500 order by key").collect().toSeq) + } + + test("SPARK-2554 SumDistinct partial aggregation") { + checkAnswer(sql("SELECT sum( distinct key) FROM src group by key order by key"), + sql("SELECT distinct key FROM src 
order by key").collect().toSeq) } } |