diff options
author | Takuya UESHIN <ueshin@happy-camper.st> | 2014-11-20 15:41:24 -0800 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2014-11-20 15:41:36 -0800 |
commit | 1d7ee2b79b23f08f73a6d53f41ac8fa140b91c19 (patch) | |
tree | 276dd4919d01624f189f47f9efdb9b3757eaf8d8 /sql/catalyst | |
parent | 8608ff59881b3cfa6c4cd407ba2c0af7a78e88a9 (diff) | |
download | spark-1d7ee2b79b23f08f73a6d53f41ac8fa140b91c19.tar.gz spark-1d7ee2b79b23f08f73a6d53f41ac8fa140b91c19.tar.bz2 spark-1d7ee2b79b23f08f73a6d53f41ac8fa140b91c19.zip |
[SPARK-4318][SQL] Fix empty sum distinct.
Executing sum distinct for an empty table throws `java.lang.UnsupportedOperationException: empty.reduceLeft`.
Author: Takuya UESHIN <ueshin@happy-camper.st>
Closes #3184 from ueshin/issues/SPARK-4318 and squashes the following commits:
8168c42 [Takuya UESHIN] Merge branch 'master' into issues/SPARK-4318
66fdb0a [Takuya UESHIN] Re-refine aggregate functions.
6186eb4 [Takuya UESHIN] Fix Sum of GeneratedAggregate.
d2975f6 [Takuya UESHIN] Refine Sum and Average of GeneratedAggregate.
1bba675 [Takuya UESHIN] Refine Sum, SumDistinct and Average functions.
917e533 [Takuya UESHIN] Use aggregate instead of groupBy().
1a5f874 [Takuya UESHIN] Add tests to be executed as non-partial aggregation.
a5a57d2 [Takuya UESHIN] Fix empty Average.
22799dc [Takuya UESHIN] Fix empty Sum and SumDistinct.
65b7dd2 [Takuya UESHIN] Fix empty sum distinct.
(cherry picked from commit 2c2e7a44db2ebe44121226f3eac924a0668b991a)
Signed-off-by: Michael Armbrust <michael@databricks.com>
Diffstat (limited to 'sql/catalyst')
-rwxr-xr-x | sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala | 103 |
1 file changed, 79 insertions, 24 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 3ceb5ecaf6..0cd90866e1 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -158,7 +158,7 @@ case class Count(child: Expression) extends PartialAggregate with trees.UnaryNod override def asPartial: SplitEvaluation = { val partialCount = Alias(Count(child), "PartialCount")() - SplitEvaluation(Sum(partialCount.toAttribute), partialCount :: Nil) + SplitEvaluation(Coalesce(Seq(Sum(partialCount.toAttribute), Literal(0L))), partialCount :: Nil) } override def newInstance() = new CountFunction(child, this) @@ -285,7 +285,7 @@ case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05) case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def nullable = false + override def nullable = true override def dataType = child.dataType match { case DecimalType.Fixed(precision, scale) => @@ -299,12 +299,12 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN override def toString = s"AVG($child)" override def asPartial: SplitEvaluation = { - val partialSum = Alias(Sum(child), "PartialSum")() - val partialCount = Alias(Count(child), "PartialCount")() - child.dataType match { case DecimalType.Fixed(_, _) => - // Turn the results to unlimited decimals for the division, before going back to fixed + // Turn the child to unlimited decimals for calculation, before going back to fixed + val partialSum = Alias(Sum(Cast(child, DecimalType.Unlimited)), "PartialSum")() + val partialCount = Alias(Count(child), "PartialCount")() + val castedSum = Cast(Sum(partialSum.toAttribute), DecimalType.Unlimited) val castedCount = Cast(Sum(partialCount.toAttribute), 
DecimalType.Unlimited) SplitEvaluation( @@ -312,6 +312,9 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN partialCount :: partialSum :: Nil) case _ => + val partialSum = Alias(Sum(child), "PartialSum")() + val partialCount = Alias(Count(child), "PartialCount")() + val castedSum = Cast(Sum(partialSum.toAttribute), dataType) val castedCount = Cast(Sum(partialCount.toAttribute), dataType) SplitEvaluation( @@ -325,7 +328,7 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def nullable = false + override def nullable = true override def dataType = child.dataType match { case DecimalType.Fixed(precision, scale) => @@ -339,10 +342,19 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[ override def toString = s"SUM($child)" override def asPartial: SplitEvaluation = { - val partialSum = Alias(Sum(child), "PartialSum")() - SplitEvaluation( - Sum(partialSum.toAttribute), - partialSum :: Nil) + child.dataType match { + case DecimalType.Fixed(_, _) => + val partialSum = Alias(Sum(Cast(child, DecimalType.Unlimited)), "PartialSum")() + SplitEvaluation( + Cast(Sum(partialSum.toAttribute), dataType), + partialSum :: Nil) + + case _ => + val partialSum = Alias(Sum(child), "PartialSum")() + SplitEvaluation( + Sum(partialSum.toAttribute), + partialSum :: Nil) + } } override def newInstance() = new SumFunction(child, this) @@ -351,7 +363,7 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[ case class SumDistinct(child: Expression) extends AggregateExpression with trees.UnaryNode[Expression] { - override def nullable = false + override def nullable = true override def dataType = child.dataType match { case DecimalType.Fixed(precision, scale) => @@ -401,16 +413,37 @@ case class AverageFunction(expr: Expression, base: AggregateExpression) def this() = 
this(null, null) // Required for serialization. - private val zero = Cast(Literal(0), expr.dataType) + private val calcType = + expr.dataType match { + case DecimalType.Fixed(_, _) => + DecimalType.Unlimited + case _ => + expr.dataType + } + + private val zero = Cast(Literal(0), calcType) private var count: Long = _ - private val sum = MutableLiteral(zero.eval(null), expr.dataType) - private val sumAsDouble = Cast(sum, DoubleType) + private val sum = MutableLiteral(zero.eval(null), calcType) - private def addFunction(value: Any) = Add(sum, Literal(value)) + private def addFunction(value: Any) = Add(sum, Cast(Literal(value, expr.dataType), calcType)) - override def eval(input: Row): Any = - sumAsDouble.eval(EmptyRow).asInstanceOf[Double] / count.toDouble + override def eval(input: Row): Any = { + if (count == 0L) { + null + } else { + expr.dataType match { + case DecimalType.Fixed(_, _) => + Cast(Divide( + Cast(sum, DecimalType.Unlimited), + Cast(Literal(count), DecimalType.Unlimited)), dataType).eval(null) + case _ => + Divide( + Cast(sum, dataType), + Cast(Literal(count), dataType)).eval(null) + } + } + } override def update(input: Row): Unit = { val evaluatedExpr = expr.eval(input) @@ -475,17 +508,31 @@ case class ApproxCountDistinctMergeFunction( case class SumFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { def this() = this(null, null) // Required for serialization. 
- private val zero = Cast(Literal(0), expr.dataType) + private val calcType = + expr.dataType match { + case DecimalType.Fixed(_, _) => + DecimalType.Unlimited + case _ => + expr.dataType + } + + private val zero = Cast(Literal(0), calcType) - private val sum = MutableLiteral(zero.eval(null), expr.dataType) + private val sum = MutableLiteral(null, calcType) - private val addFunction = Add(sum, Coalesce(Seq(expr, zero))) + private val addFunction = Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum)) override def update(input: Row): Unit = { sum.update(addFunction, input) } - override def eval(input: Row): Any = sum.eval(null) + override def eval(input: Row): Any = { + expr.dataType match { + case DecimalType.Fixed(_, _) => + Cast(sum, dataType).eval(null) + case _ => sum.eval(null) + } + } } case class SumDistinctFunction(expr: Expression, base: AggregateExpression) @@ -502,8 +549,16 @@ case class SumDistinctFunction(expr: Expression, base: AggregateExpression) } } - override def eval(input: Row): Any = - seen.reduceLeft(base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus) + override def eval(input: Row): Any = { + if (seen.size == 0) { + null + } else { + Cast(Literal( + seen.reduceLeft( + dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)), + dataType).eval(null) + } + } } case class CountDistinctFunction( |