path: root/sql/catalyst
author     ravipesala <ravindra.pesala@huawei.com>        2014-12-18 20:19:10 -0800
committer  Michael Armbrust <michael@databricks.com>      2014-12-18 20:19:10 -0800
commit     7687415c2578b5bdc79c9646c246e52da9a4dd4a (patch)
tree       b3223db3b78f0e7f10feb551fff346145bf8da47 /sql/catalyst
parent     e7de7e5f46821e1ba3b070b21d6bcf6d5ec8a796 (diff)
download   spark-7687415c2578b5bdc79c9646c246e52da9a4dd4a.tar.gz
           spark-7687415c2578b5bdc79c9646c246e52da9a4dd4a.tar.bz2
           spark-7687415c2578b5bdc79c9646c246e52da9a4dd4a.zip
[SPARK-2554][SQL] Supporting SumDistinct partial aggregation
Adding support to the partial aggregation of SumDistinct

Author: ravipesala <ravindra.pesala@huawei.com>

Closes #3348 from ravipesala/SPARK-2554 and squashes the following commits:

fd28e4d [ravipesala] Fixed review comments
e60e67f [ravipesala] Fixed test cases and made it as nullable
32fe234 [ravipesala] Supporting SumDistinct partial aggregation

Conflicts:
	sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
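For orientation, here is a minimal plain-Scala sketch of the two-phase evaluation this patch enables for SUM(DISTINCT x): each partition first collects the distinct values it has seen into a set, and the combine phase unions the partial sets before summing, so a value that appears in several partitions is only counted once. The helper names (partialDistinct, combineAndSum) and the use of an ordinary mutable.HashSet in place of Spark's OpenHashSet are illustrative only, not part of the commit.

import scala.collection.mutable

// Hypothetical stand-ins for the partial and merge phases of SUM(DISTINCT x).
object SumDistinctSketch {
  // Map side: each partition collects the distinct values it has seen.
  def partialDistinct(partition: Iterator[Int]): mutable.HashSet[Int] = {
    val seen = mutable.HashSet.empty[Int]
    partition.foreach(seen += _)
    seen
  }

  // Reduce side: union the partial sets first, then sum once, so a value
  // appearing in several partitions is still counted a single time.
  def combineAndSum(partials: Seq[mutable.HashSet[Int]]): Option[Int] = {
    val merged = partials.foldLeft(mutable.HashSet.empty[Int])(_ ++= _)
    if (merged.isEmpty) None else Some(merged.sum)
  }

  def main(args: Array[String]): Unit = {
    val partitions = Seq(Iterator(1, 2, 2, 3), Iterator(3, 4, 4))
    println(combineAndSum(partitions.map(partialDistinct))) // Some(10) = 1 + 2 + 3 + 4
  }
}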
Diffstat (limited to 'sql/catalyst')
-rwxr-xr-x  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala  53
1 file changed, 49 insertions(+), 4 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index 0cd90866e1..5ea9868e9e 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -361,10 +361,10 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[
 }
 
 case class SumDistinct(child: Expression)
-  extends AggregateExpression with trees.UnaryNode[Expression] {
+  extends PartialAggregate with trees.UnaryNode[Expression] {
 
+  def this() = this(null)
   override def nullable = true
-
   override def dataType = child.dataType match {
     case DecimalType.Fixed(precision, scale) =>
       DecimalType(precision + 10, scale)  // Add 10 digits left of decimal point, like Hive
@@ -373,10 +373,55 @@ case class SumDistinct(child: Expression)
     case _ =>
       child.dataType
   }
+  override def toString = s"SUM(DISTINCT ${child})"
+  override def newInstance() = new SumDistinctFunction(child, this)
+
+  override def asPartial = {
+    val partialSet = Alias(CollectHashSet(child :: Nil), "partialSets")()
+    SplitEvaluation(
+      CombineSetsAndSum(partialSet.toAttribute, this),
+      partialSet :: Nil)
+  }
+}
 
-  override def toString = s"SUM(DISTINCT $child)"
+case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends AggregateExpression {
+  def this() = this(null, null)
 
-  override def newInstance() = new SumDistinctFunction(child, this)
+  override def children = inputSet :: Nil
+  override def nullable = true
+  override def dataType = base.dataType
+  override def toString = s"CombineAndSum($inputSet)"
+  override def newInstance() = new CombineSetsAndSumFunction(inputSet, this)
+}
+
+case class CombineSetsAndSumFunction(
+    @transient inputSet: Expression,
+    @transient base: AggregateExpression)
+  extends AggregateFunction {
+
+  def this() = this(null, null) // Required for serialization.
+
+  val seen = new OpenHashSet[Any]()
+
+  override def update(input: Row): Unit = {
+    val inputSetEval = inputSet.eval(input).asInstanceOf[OpenHashSet[Any]]
+    val inputIterator = inputSetEval.iterator
+    while (inputIterator.hasNext) {
+      seen.add(inputIterator.next)
+    }
+  }
+
+  override def eval(input: Row): Any = {
+    val casted = seen.asInstanceOf[OpenHashSet[Row]]
+    if (casted.size == 0) {
+      null
+    } else {
+      Cast(Literal(
+        casted.iterator.map(f => f.apply(0)).reduceLeft(
+          base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)),
+        base.dataType).eval(null)
+    }
+  }
 }
 
 case class First(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
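The eval in CombineSetsAndSumFunction above sums the merged set by looking up the Numeric instance for the result type and folding the values with reduceLeft. Below is a standalone sketch of that summation trick, with the element type fixed to Long rather than derived from base.dataType; the object and method names are illustrative, not Spark API.

// Standalone sketch of the reduceLeft-over-Numeric summation used in
// CombineSetsAndSumFunction.eval, with the element type fixed to Long.
object NumericSumSketch {
  // Sums already-deduplicated values with an explicit Numeric, mirroring how
  // eval() folds the merged hash set; returns null for an empty set, as the patch does.
  def sumDistinct(values: Set[Any], num: Numeric[Any]): Any =
    if (values.isEmpty) null else values.iterator.reduceLeft(num.plus)

  def main(args: Array[String]): Unit = {
    // The patch obtains the Numeric from base.dataType; here we take the implicit
    // Numeric[Long] and apply the same unchecked cast to Numeric[Any].
    val num = implicitly[Numeric[Long]].asInstanceOf[Numeric[Any]]
    println(sumDistinct(Set[Any](1L, 2L, 3L), num)) // prints 6
  }
}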