path: root/sql/catalyst
author    Takuya UESHIN <ueshin@happy-camper.st>     2014-11-20 15:41:24 -0800
committer Michael Armbrust <michael@databricks.com>  2014-11-20 15:41:36 -0800
commit    1d7ee2b79b23f08f73a6d53f41ac8fa140b91c19 (patch)
tree      276dd4919d01624f189f47f9efdb9b3757eaf8d8 /sql/catalyst
parent    8608ff59881b3cfa6c4cd407ba2c0af7a78e88a9 (diff)
[SPARK-4318][SQL] Fix empty sum distinct.
Executing sum distinct on an empty table throws `java.lang.UnsupportedOperationException: empty.reduceLeft`.

Author: Takuya UESHIN <ueshin@happy-camper.st>

Closes #3184 from ueshin/issues/SPARK-4318 and squashes the following commits:

8168c42 [Takuya UESHIN] Merge branch 'master' into issues/SPARK-4318
66fdb0a [Takuya UESHIN] Re-refine aggregate functions.
6186eb4 [Takuya UESHIN] Fix Sum of GeneratedAggregate.
d2975f6 [Takuya UESHIN] Refine Sum and Average of GeneratedAggregate.
1bba675 [Takuya UESHIN] Refine Sum, SumDistinct and Average functions.
917e533 [Takuya UESHIN] Use aggregate instead of groupBy().
1a5f874 [Takuya UESHIN] Add tests to be executed as non-partial aggregation.
a5a57d2 [Takuya UESHIN] Fix empty Average.
22799dc [Takuya UESHIN] Fix empty Sum and SumDistinct.
65b7dd2 [Takuya UESHIN] Fix empty sum distinct.

(cherry picked from commit 2c2e7a44db2ebe44121226f3eac924a0668b991a)
Signed-off-by: Michael Armbrust <michael@databricks.com>
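A minimal, self-contained Scala sketch of the failure mode and the guard this commit introduces (illustrative only, plain Scala rather than Spark's AggregateFunction API): reduceLeft on an empty collection throws UnsupportedOperationException, so an aggregate that has seen no rows must short-circuit and return SQL NULL instead of reducing.

// Illustrative only: not Spark code; names are made up for the example.
object EmptySumDistinctSketch {
  // Pre-fix shape: reduce the distinct values directly.
  def sumDistinctUnsafe(values: Seq[Int]): Int =
    values.distinct.reduceLeft(_ + _)            // empty input => UnsupportedOperationException: empty.reduceLeft

  // Post-fix shape: guard the empty case and return NULL (None here).
  def sumDistinctSafe(values: Seq[Int]): Option[Int] =
    if (values.isEmpty) None
    else Some(values.distinct.reduceLeft(_ + _))

  def main(args: Array[String]): Unit = {
    println(sumDistinctSafe(Seq(1, 2, 2, 3)))    // Some(6)
    println(sumDistinctSafe(Seq.empty))          // None, where the old code threw
  }
}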
Diffstat (limited to 'sql/catalyst')
-rwxr-xr-x  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala  103
1 file changed, 79 insertions(+), 24 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index 3ceb5ecaf6..0cd90866e1 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -158,7 +158,7 @@ case class Count(child: Expression) extends PartialAggregate with trees.UnaryNod
override def asPartial: SplitEvaluation = {
val partialCount = Alias(Count(child), "PartialCount")()
- SplitEvaluation(Sum(partialCount.toAttribute), partialCount :: Nil)
+ SplitEvaluation(Coalesce(Seq(Sum(partialCount.toAttribute), Literal(0L))), partialCount :: Nil)
}
override def newInstance() = new CountFunction(child, this)
@@ -285,7 +285,7 @@ case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05)
case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
- override def nullable = false
+ override def nullable = true
override def dataType = child.dataType match {
case DecimalType.Fixed(precision, scale) =>
@@ -299,12 +299,12 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN
override def toString = s"AVG($child)"
override def asPartial: SplitEvaluation = {
- val partialSum = Alias(Sum(child), "PartialSum")()
- val partialCount = Alias(Count(child), "PartialCount")()
-
child.dataType match {
case DecimalType.Fixed(_, _) =>
- // Turn the results to unlimited decimals for the division, before going back to fixed
+ // Turn the child to unlimited decimals for calculation, before going back to fixed
+ val partialSum = Alias(Sum(Cast(child, DecimalType.Unlimited)), "PartialSum")()
+ val partialCount = Alias(Count(child), "PartialCount")()
+
val castedSum = Cast(Sum(partialSum.toAttribute), DecimalType.Unlimited)
val castedCount = Cast(Sum(partialCount.toAttribute), DecimalType.Unlimited)
SplitEvaluation(
@@ -312,6 +312,9 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN
partialCount :: partialSum :: Nil)
case _ =>
+ val partialSum = Alias(Sum(child), "PartialSum")()
+ val partialCount = Alias(Count(child), "PartialCount")()
+
val castedSum = Cast(Sum(partialSum.toAttribute), dataType)
val castedCount = Cast(Sum(partialCount.toAttribute), dataType)
SplitEvaluation(
@@ -325,7 +328,7 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN
case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
- override def nullable = false
+ override def nullable = true
override def dataType = child.dataType match {
case DecimalType.Fixed(precision, scale) =>
@@ -339,10 +342,19 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[
override def toString = s"SUM($child)"
override def asPartial: SplitEvaluation = {
- val partialSum = Alias(Sum(child), "PartialSum")()
- SplitEvaluation(
- Sum(partialSum.toAttribute),
- partialSum :: Nil)
+ child.dataType match {
+ case DecimalType.Fixed(_, _) =>
+ val partialSum = Alias(Sum(Cast(child, DecimalType.Unlimited)), "PartialSum")()
+ SplitEvaluation(
+ Cast(Sum(partialSum.toAttribute), dataType),
+ partialSum :: Nil)
+
+ case _ =>
+ val partialSum = Alias(Sum(child), "PartialSum")()
+ SplitEvaluation(
+ Sum(partialSum.toAttribute),
+ partialSum :: Nil)
+ }
}
override def newInstance() = new SumFunction(child, this)
@@ -351,7 +363,7 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[
case class SumDistinct(child: Expression)
extends AggregateExpression with trees.UnaryNode[Expression] {
- override def nullable = false
+ override def nullable = true
override def dataType = child.dataType match {
case DecimalType.Fixed(precision, scale) =>
@@ -401,16 +413,37 @@ case class AverageFunction(expr: Expression, base: AggregateExpression)
def this() = this(null, null) // Required for serialization.
- private val zero = Cast(Literal(0), expr.dataType)
+ private val calcType =
+ expr.dataType match {
+ case DecimalType.Fixed(_, _) =>
+ DecimalType.Unlimited
+ case _ =>
+ expr.dataType
+ }
+
+ private val zero = Cast(Literal(0), calcType)
private var count: Long = _
- private val sum = MutableLiteral(zero.eval(null), expr.dataType)
- private val sumAsDouble = Cast(sum, DoubleType)
+ private val sum = MutableLiteral(zero.eval(null), calcType)
- private def addFunction(value: Any) = Add(sum, Literal(value))
+ private def addFunction(value: Any) = Add(sum, Cast(Literal(value, expr.dataType), calcType))
- override def eval(input: Row): Any =
- sumAsDouble.eval(EmptyRow).asInstanceOf[Double] / count.toDouble
+ override def eval(input: Row): Any = {
+ if (count == 0L) {
+ null
+ } else {
+ expr.dataType match {
+ case DecimalType.Fixed(_, _) =>
+ Cast(Divide(
+ Cast(sum, DecimalType.Unlimited),
+ Cast(Literal(count), DecimalType.Unlimited)), dataType).eval(null)
+ case _ =>
+ Divide(
+ Cast(sum, dataType),
+ Cast(Literal(count), dataType)).eval(null)
+ }
+ }
+ }
override def update(input: Row): Unit = {
val evaluatedExpr = expr.eval(input)
@@ -475,17 +508,31 @@ case class ApproxCountDistinctMergeFunction(
case class SumFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
def this() = this(null, null) // Required for serialization.
- private val zero = Cast(Literal(0), expr.dataType)
+ private val calcType =
+ expr.dataType match {
+ case DecimalType.Fixed(_, _) =>
+ DecimalType.Unlimited
+ case _ =>
+ expr.dataType
+ }
+
+ private val zero = Cast(Literal(0), calcType)
- private val sum = MutableLiteral(zero.eval(null), expr.dataType)
+ private val sum = MutableLiteral(null, calcType)
- private val addFunction = Add(sum, Coalesce(Seq(expr, zero)))
+ private val addFunction = Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum))
override def update(input: Row): Unit = {
sum.update(addFunction, input)
}
- override def eval(input: Row): Any = sum.eval(null)
+ override def eval(input: Row): Any = {
+ expr.dataType match {
+ case DecimalType.Fixed(_, _) =>
+ Cast(sum, dataType).eval(null)
+ case _ => sum.eval(null)
+ }
+ }
}
case class SumDistinctFunction(expr: Expression, base: AggregateExpression)
@@ -502,8 +549,16 @@ case class SumDistinctFunction(expr: Expression, base: AggregateExpression)
}
}
- override def eval(input: Row): Any =
- seen.reduceLeft(base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)
+ override def eval(input: Row): Any = {
+ if (seen.size == 0) {
+ null
+ } else {
+ Cast(Literal(
+ seen.reduceLeft(
+ dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)),
+ dataType).eval(null)
+ }
+ }
}
case class CountDistinctFunction(
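
Beyond the empty-input guard, the diff also accumulates fixed-precision decimal inputs in DecimalType.Unlimited and only casts back to the result type when emitting the final value, and it starts the running sum at null so that summing zero rows yields NULL rather than 0. A rough standalone sketch of that accumulation pattern, using plain scala.math.BigDecimal instead of Spark's Decimal/MutableLiteral machinery (helper names are assumptions, not Spark identifiers):

// Illustrative only: mirrors the Coalesce(Add(Coalesce(sum, zero), value), sum) accumulation.
object DecimalSumSketch {
  def sumFixed(values: Seq[BigDecimal], resultScale: Int): Option[BigDecimal] = {
    // Accumulate with unbounded precision; start from None so zero rows -> NULL, not 0.
    val acc = values.foldLeft(Option.empty[BigDecimal]) { (sum, v) =>
      Some(sum.getOrElse(BigDecimal(0)) + v)
    }
    // Narrow back to the fixed result scale only when producing the final value.
    acc.map(_.setScale(resultScale, BigDecimal.RoundingMode.HALF_UP))
  }

  def main(args: Array[String]): Unit = {
    println(sumFixed(Seq(BigDecimal("1.25"), BigDecimal("2.50")), 2))  // Some(3.75)
    println(sumFixed(Seq.empty, 2))                                    // None
  }
}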