author     Takuya UESHIN <ueshin@happy-camper.st>      2014-11-20 15:41:24 -0800
committer  Michael Armbrust <michael@databricks.com>   2014-11-20 15:41:24 -0800
commit     2c2e7a44db2ebe44121226f3eac924a0668b991a
tree       3989b7928ea90368f01a8d613205207eb8581d5a
parent     98e9419784a9ad5096cfd563fa9a433786a90bd4
[SPARK-4318][SQL] Fix empty sum distinct.
Executing sum distinct for an empty table throws `java.lang.UnsupportedOperationException: empty.reduceLeft`.

Author: Takuya UESHIN <ueshin@happy-camper.st>

Closes #3184 from ueshin/issues/SPARK-4318 and squashes the following commits:

8168c42 [Takuya UESHIN] Merge branch 'master' into issues/SPARK-4318
66fdb0a [Takuya UESHIN] Re-refine aggregate functions.
6186eb4 [Takuya UESHIN] Fix Sum of GeneratedAggregate.
d2975f6 [Takuya UESHIN] Refine Sum and Average of GeneratedAggregate.
1bba675 [Takuya UESHIN] Refine Sum, SumDistinct and Average functions.
917e533 [Takuya UESHIN] Use aggregate instead of groupBy().
1a5f874 [Takuya UESHIN] Add tests to be executed as non-partial aggregation.
a5a57d2 [Takuya UESHIN] Fix empty Average.
22799dc [Takuya UESHIN] Fix empty Sum and SumDistinct.
65b7dd2 [Takuya UESHIN] Fix empty sum distinct.
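For context, a minimal standalone reproduction of the failure mode in plain Scala (ordinary collections standing in for the Catalyst classes, so the names here are illustrative): reduceLeft over an empty collection throws, which is exactly what SumDistinctFunction.eval hit on an empty table, and the guarded form mirrors the fix in this commit.

object EmptyReduceLeftRepro extends App {
  val seen = Set.empty[Int] // stands in for the distinct-value set of an empty table

  // The reported failure: reduceLeft has no element to start from.
  // Prints: java.lang.UnsupportedOperationException: empty.reduceLeft
  try seen.reduceLeft(_ + _)
  catch { case e: UnsupportedOperationException => println(e) }

  // The guarded form used by the fix: empty input yields null instead of throwing.
  val result: Any = if (seen.isEmpty) null else seen.reduceLeft(_ + _)
  assert(result == null)
}

Returning null rather than 0 matches SQL semantics, where SUM and AVG over zero rows are NULL.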
Diffstat (limited to 'sql')
-rwxr-xr-x  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala  103
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala          68
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala                         65
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/TestData.scala                              11
4 files changed, 195 insertions, 52 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index 3ceb5ecaf6..0cd90866e1 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -158,7 +158,7 @@ case class Count(child: Expression) extends PartialAggregate with trees.UnaryNod
override def asPartial: SplitEvaluation = {
val partialCount = Alias(Count(child), "PartialCount")()
- SplitEvaluation(Sum(partialCount.toAttribute), partialCount :: Nil)
+ SplitEvaluation(Coalesce(Seq(Sum(partialCount.toAttribute), Literal(0L))), partialCount :: Nil)
}
override def newInstance() = new CountFunction(child, this)
@@ -285,7 +285,7 @@ case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05)
case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
- override def nullable = false
+ override def nullable = true
override def dataType = child.dataType match {
case DecimalType.Fixed(precision, scale) =>
@@ -299,12 +299,12 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN
override def toString = s"AVG($child)"
override def asPartial: SplitEvaluation = {
- val partialSum = Alias(Sum(child), "PartialSum")()
- val partialCount = Alias(Count(child), "PartialCount")()
-
child.dataType match {
case DecimalType.Fixed(_, _) =>
- // Turn the results to unlimited decimals for the division, before going back to fixed
+ // Turn the child to unlimited decimals for calculation, before going back to fixed
+ val partialSum = Alias(Sum(Cast(child, DecimalType.Unlimited)), "PartialSum")()
+ val partialCount = Alias(Count(child), "PartialCount")()
+
val castedSum = Cast(Sum(partialSum.toAttribute), DecimalType.Unlimited)
val castedCount = Cast(Sum(partialCount.toAttribute), DecimalType.Unlimited)
SplitEvaluation(
@@ -312,6 +312,9 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN
partialCount :: partialSum :: Nil)
case _ =>
+ val partialSum = Alias(Sum(child), "PartialSum")()
+ val partialCount = Alias(Count(child), "PartialCount")()
+
val castedSum = Cast(Sum(partialSum.toAttribute), dataType)
val castedCount = Cast(Sum(partialCount.toAttribute), dataType)
SplitEvaluation(
@@ -325,7 +328,7 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN
case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
- override def nullable = false
+ override def nullable = true
override def dataType = child.dataType match {
case DecimalType.Fixed(precision, scale) =>
@@ -339,10 +342,19 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[
override def toString = s"SUM($child)"
override def asPartial: SplitEvaluation = {
- val partialSum = Alias(Sum(child), "PartialSum")()
- SplitEvaluation(
- Sum(partialSum.toAttribute),
- partialSum :: Nil)
+ child.dataType match {
+ case DecimalType.Fixed(_, _) =>
+ val partialSum = Alias(Sum(Cast(child, DecimalType.Unlimited)), "PartialSum")()
+ SplitEvaluation(
+ Cast(Sum(partialSum.toAttribute), dataType),
+ partialSum :: Nil)
+
+ case _ =>
+ val partialSum = Alias(Sum(child), "PartialSum")()
+ SplitEvaluation(
+ Sum(partialSum.toAttribute),
+ partialSum :: Nil)
+ }
}
override def newInstance() = new SumFunction(child, this)
@@ -351,7 +363,7 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[
case class SumDistinct(child: Expression)
extends AggregateExpression with trees.UnaryNode[Expression] {
- override def nullable = false
+ override def nullable = true
override def dataType = child.dataType match {
case DecimalType.Fixed(precision, scale) =>
@@ -401,16 +413,37 @@ case class AverageFunction(expr: Expression, base: AggregateExpression)
def this() = this(null, null) // Required for serialization.
- private val zero = Cast(Literal(0), expr.dataType)
+ private val calcType =
+ expr.dataType match {
+ case DecimalType.Fixed(_, _) =>
+ DecimalType.Unlimited
+ case _ =>
+ expr.dataType
+ }
+
+ private val zero = Cast(Literal(0), calcType)
private var count: Long = _
- private val sum = MutableLiteral(zero.eval(null), expr.dataType)
- private val sumAsDouble = Cast(sum, DoubleType)
+ private val sum = MutableLiteral(zero.eval(null), calcType)
- private def addFunction(value: Any) = Add(sum, Literal(value))
+ private def addFunction(value: Any) = Add(sum, Cast(Literal(value, expr.dataType), calcType))
- override def eval(input: Row): Any =
- sumAsDouble.eval(EmptyRow).asInstanceOf[Double] / count.toDouble
+ override def eval(input: Row): Any = {
+ if (count == 0L) {
+ null
+ } else {
+ expr.dataType match {
+ case DecimalType.Fixed(_, _) =>
+ Cast(Divide(
+ Cast(sum, DecimalType.Unlimited),
+ Cast(Literal(count), DecimalType.Unlimited)), dataType).eval(null)
+ case _ =>
+ Divide(
+ Cast(sum, dataType),
+ Cast(Literal(count), dataType)).eval(null)
+ }
+ }
+ }
override def update(input: Row): Unit = {
val evaluatedExpr = expr.eval(input)
@@ -475,17 +508,31 @@ case class ApproxCountDistinctMergeFunction(
case class SumFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
def this() = this(null, null) // Required for serialization.
- private val zero = Cast(Literal(0), expr.dataType)
+ private val calcType =
+ expr.dataType match {
+ case DecimalType.Fixed(_, _) =>
+ DecimalType.Unlimited
+ case _ =>
+ expr.dataType
+ }
+
+ private val zero = Cast(Literal(0), calcType)
- private val sum = MutableLiteral(zero.eval(null), expr.dataType)
+ private val sum = MutableLiteral(null, calcType)
- private val addFunction = Add(sum, Coalesce(Seq(expr, zero)))
+ private val addFunction = Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum))
override def update(input: Row): Unit = {
sum.update(addFunction, input)
}
- override def eval(input: Row): Any = sum.eval(null)
+ override def eval(input: Row): Any = {
+ expr.dataType match {
+ case DecimalType.Fixed(_, _) =>
+ Cast(sum, dataType).eval(null)
+ case _ => sum.eval(null)
+ }
+ }
}
case class SumDistinctFunction(expr: Expression, base: AggregateExpression)
@@ -502,8 +549,16 @@ case class SumDistinctFunction(expr: Expression, base: AggregateExpression)
}
}
- override def eval(input: Row): Any =
- seen.reduceLeft(base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)
+ override def eval(input: Row): Any = {
+ if (seen.size == 0) {
+ null
+ } else {
+ Cast(Literal(
+ seen.reduceLeft(
+ dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)),
+ dataType).eval(null)
+ }
+ }
}
case class CountDistinctFunction(
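The rewritten SumFunction above starts its accumulator at null and only materializes a value once a non-null input arrives. A minimal plain-Scala sketch of that Coalesce-based update rule (an illustration of the pattern, not the Catalyst expression tree itself):

object NullPreservingSumSketch extends App {
  // Mirrors Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum)):
  // a null input leaves the accumulator unchanged, the first non-null input seeds it,
  // so an empty or all-null column sums to null rather than 0.
  def update(sum: Option[Long], value: Option[Long]): Option[Long] =
    value match {
      case Some(v) => Some(sum.getOrElse(0L) + v) // Add(Coalesce(sum, zero), expr)
      case None    => sum                         // outer Coalesce falls back to sum
    }

  assert(Seq.empty[Option[Long]].foldLeft(Option.empty[Long])(update).isEmpty)
  assert(Seq(None, Some(1L), Some(2L)).foldLeft(Option.empty[Long])(update) == Some(3L))
}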
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index 087b0ecbb2..18afc5d741 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -83,29 +83,45 @@ case class GeneratedAggregate(
AggregateEvaluation(currentCount :: Nil, initialValue :: Nil, updateFunction :: Nil, result)
- case Sum(expr) =>
- val resultType = expr.dataType match {
- case DecimalType.Fixed(precision, scale) =>
- DecimalType(precision + 10, scale)
- case _ =>
- expr.dataType
- }
+ case s @ Sum(expr) =>
+ val calcType =
+ expr.dataType match {
+ case DecimalType.Fixed(_, _) =>
+ DecimalType.Unlimited
+ case _ =>
+ expr.dataType
+ }
- val currentSum = AttributeReference("currentSum", resultType, nullable = false)()
- val initialValue = Cast(Literal(0L), resultType)
+ val currentSum = AttributeReference("currentSum", calcType, nullable = true)()
+ val initialValue = Literal(null, calcType)
// Coalasce avoids double calculation...
// but really, common sub expression elimination would be better....
- val updateFunction = Coalesce(Add(expr, currentSum) :: currentSum :: Nil)
- val result = currentSum
+ val zero = Cast(Literal(0), calcType)
+ val updateFunction = Coalesce(
+ Add(Coalesce(currentSum :: zero :: Nil), Cast(expr, calcType)) :: currentSum :: Nil)
+ val result =
+ expr.dataType match {
+ case DecimalType.Fixed(_, _) =>
+ Cast(currentSum, s.dataType)
+ case _ => currentSum
+ }
AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result)
case a @ Average(expr) =>
+ val calcType =
+ expr.dataType match {
+ case DecimalType.Fixed(_, _) =>
+ DecimalType.Unlimited
+ case _ =>
+ expr.dataType
+ }
+
val currentCount = AttributeReference("currentCount", LongType, nullable = false)()
- val currentSum = AttributeReference("currentSum", expr.dataType, nullable = false)()
+ val currentSum = AttributeReference("currentSum", calcType, nullable = false)()
val initialCount = Literal(0L)
- val initialSum = Cast(Literal(0L), expr.dataType)
+ val initialSum = Cast(Literal(0L), calcType)
// If we're evaluating UnscaledValue(x), we can do Count on x directly, since its
// UnscaledValue will be null if and only if x is null; helps with Average on decimals
@@ -115,17 +131,21 @@ case class GeneratedAggregate(
}
val updateCount = If(IsNotNull(toCount), Add(currentCount, Literal(1L)), currentCount)
- val updateSum = Coalesce(Add(expr, currentSum) :: currentSum :: Nil)
-
- val resultType = expr.dataType match {
- case DecimalType.Fixed(precision, scale) =>
- DecimalType(precision + 4, scale + 4)
- case DecimalType.Unlimited =>
- DecimalType.Unlimited
- case _ =>
- DoubleType
- }
- val result = Divide(Cast(currentSum, resultType), Cast(currentCount, resultType))
+ val updateSum = Coalesce(Add(Cast(expr, calcType), currentSum) :: currentSum :: Nil)
+
+ val result =
+ expr.dataType match {
+ case DecimalType.Fixed(_, _) =>
+ If(EqualTo(currentCount, Literal(0L)),
+ Literal(null, a.dataType),
+ Cast(Divide(
+ Cast(currentSum, DecimalType.Unlimited),
+ Cast(currentCount, DecimalType.Unlimited)), a.dataType))
+ case _ =>
+ If(EqualTo(currentCount, Literal(0L)),
+ Literal(null, a.dataType),
+ Divide(Cast(currentSum, a.dataType), Cast(currentCount, a.dataType)))
+ }
AggregateEvaluation(
currentCount :: currentSum :: Nil,
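In the generated Average above, the count column drives a null guard at result time rather than a null accumulator: If(EqualTo(currentCount, Literal(0L)), ...) returns null before any division happens. A hedged plain-Scala sketch of that guard (names are illustrative, not from GeneratedAggregate):

object GuardedAverageSketch extends App {
  // Zero rows counted yields null instead of a 0/0 division.
  def average(currentSum: Double, currentCount: Long): Any =
    if (currentCount == 0L) null
    else currentSum / currentCount.toDouble

  assert(average(0.0, 0L) == null) // empty table: AVG is null
  assert(average(12.0, 6L) == 2.0) // matches the testData2 expectation below
}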
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
index e70ad891ee..94bd97758f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
@@ -156,22 +156,58 @@ class DslQuerySuite extends QueryTest {
test("average") {
checkAnswer(
- testData2.groupBy()(avg('a)),
+ testData2.aggregate(avg('a)),
2.0)
+
+ checkAnswer(
+ testData2.aggregate(avg('a), sumDistinct('a)), // non-partial
+ (2.0, 6.0) :: Nil)
+
+ checkAnswer(
+ decimalData.aggregate(avg('a)),
+ BigDecimal(2.0))
+ checkAnswer(
+ decimalData.aggregate(avg('a), sumDistinct('a)), // non-partial
+ (BigDecimal(2.0), BigDecimal(6)) :: Nil)
+
+ checkAnswer(
+ decimalData.aggregate(avg('a cast DecimalType(10, 2))),
+ BigDecimal(2.0))
+ checkAnswer(
+ decimalData.aggregate(avg('a cast DecimalType(10, 2)), sumDistinct('a cast DecimalType(10, 2))), // non-partial
+ (BigDecimal(2.0), BigDecimal(6)) :: Nil)
}
test("null average") {
checkAnswer(
- testData3.groupBy()(avg('b)),
+ testData3.aggregate(avg('b)),
2.0)
checkAnswer(
- testData3.groupBy()(avg('b), countDistinct('b)),
+ testData3.aggregate(avg('b), countDistinct('b)),
(2.0, 1) :: Nil)
+
+ checkAnswer(
+ testData3.aggregate(avg('b), sumDistinct('b)), // non-partial
+ (2.0, 2.0) :: Nil)
+ }
+
+ test("zero average") {
+ checkAnswer(
+ emptyTableData.aggregate(avg('a)),
+ null)
+
+ checkAnswer(
+ emptyTableData.aggregate(avg('a), sumDistinct('b)), // non-partial
+ (null, null) :: Nil)
}
test("count") {
assert(testData2.count() === testData2.map(_ => 1).count())
+
+ checkAnswer(
+ testData2.aggregate(count('a), sumDistinct('a)), // non-partial
+ (6, 6.0) :: Nil)
}
test("null count") {
@@ -186,13 +222,34 @@ class DslQuerySuite extends QueryTest {
)
checkAnswer(
- testData3.groupBy()(count('a), count('b), count(1), countDistinct('a), countDistinct('b)),
+ testData3.aggregate(count('a), count('b), count(1), countDistinct('a), countDistinct('b)),
(2, 1, 2, 2, 1) :: Nil
)
+
+ checkAnswer(
+ testData3.aggregate(count('b), countDistinct('b), sumDistinct('b)), // non-partial
+ (1, 1, 2) :: Nil
+ )
}
test("zero count") {
assert(emptyTableData.count() === 0)
+
+ checkAnswer(
+ emptyTableData.aggregate(count('a), sumDistinct('a)), // non-partial
+ (0, null) :: Nil)
+ }
+
+ test("zero sum") {
+ checkAnswer(
+ emptyTableData.aggregate(sum('a)),
+ null)
+ }
+
+ test("zero sum distinct") {
+ checkAnswer(
+ emptyTableData.aggregate(sumDistinct('a)),
+ null)
}
test("except") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
index 92b49e8155..933e027436 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -54,6 +54,17 @@ object TestData {
TestData2(3, 2) :: Nil).toSchemaRDD
testData2.registerTempTable("testData2")
+ case class DecimalData(a: BigDecimal, b: BigDecimal)
+ val decimalData =
+ TestSQLContext.sparkContext.parallelize(
+ DecimalData(1, 1) ::
+ DecimalData(1, 2) ::
+ DecimalData(2, 1) ::
+ DecimalData(2, 2) ::
+ DecimalData(3, 1) ::
+ DecimalData(3, 2) :: Nil).toSchemaRDD
+ decimalData.registerTempTable("decimalData")
+
case class BinaryData(a: Array[Byte], b: Int)
val binaryData =
TestSQLContext.sparkContext.parallelize(