author     Takuya UESHIN <ueshin@happy-camper.st>     2014-11-20 15:41:24 -0800
committer  Michael Armbrust <michael@databricks.com>  2014-11-20 15:41:24 -0800
commit     2c2e7a44db2ebe44121226f3eac924a0668b991a (patch)
tree       3989b7928ea90368f01a8d613205207eb8581d5a /sql/core/src
parent     98e9419784a9ad5096cfd563fa9a433786a90bd4 (diff)
[SPARK-4318][SQL] Fix empty sum distinct.
Executing sum distinct for an empty table throws `java.lang.UnsupportedOperationException: empty.reduceLeft`.

Author: Takuya UESHIN <ueshin@happy-camper.st>

Closes #3184 from ueshin/issues/SPARK-4318 and squashes the following commits:

8168c42 [Takuya UESHIN] Merge branch 'master' into issues/SPARK-4318
66fdb0a [Takuya UESHIN] Re-refine aggregate functions.
6186eb4 [Takuya UESHIN] Fix Sum of GeneratedAggregate.
d2975f6 [Takuya UESHIN] Refine Sum and Average of GeneratedAggregate.
1bba675 [Takuya UESHIN] Refine Sum, SumDistinct and Average functions.
917e533 [Takuya UESHIN] Use aggregate instead of groupBy().
1a5f874 [Takuya UESHIN] Add tests to be executed as non-partial aggregation.
a5a57d2 [Takuya UESHIN] Fix empty Average.
22799dc [Takuya UESHIN] Fix empty Sum and SumDistinct.
65b7dd2 [Takuya UESHIN] Fix empty sum distinct.
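For context, the exception comes from the standard Scala collections contract: `reduceLeft` on an empty collection throws, which is what the distinct-sum path hit on an empty table. A minimal sketch in plain Scala (an illustration, not the Spark code path itself) of the crash and of the "empty means NULL" behavior the patch adopts:

// Plain-Scala sketch of the failure mode and the intended semantics.
val empty = Seq.empty[Long]
// empty.reduceLeft(_ + _)  // java.lang.UnsupportedOperationException: empty.reduceLeft
// reduceLeftOption mirrors the fix: an empty input yields None, i.e. SQL NULL.
assert(empty.reduceLeftOption(_ + _).isEmpty)
assert(Seq(1L, 2L, 3L).reduceLeftOption(_ + _) == Some(6L))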
Diffstat (limited to 'sql/core/src')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala | 68
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala                | 65
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/TestData.scala                     | 11
3 files changed, 116 insertions(+), 28 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index 087b0ecbb2..18afc5d741 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -83,29 +83,45 @@ case class GeneratedAggregate(
AggregateEvaluation(currentCount :: Nil, initialValue :: Nil, updateFunction :: Nil, result)
- case Sum(expr) =>
- val resultType = expr.dataType match {
- case DecimalType.Fixed(precision, scale) =>
- DecimalType(precision + 10, scale)
- case _ =>
- expr.dataType
- }
+ case s @ Sum(expr) =>
+ val calcType =
+ expr.dataType match {
+ case DecimalType.Fixed(_, _) =>
+ DecimalType.Unlimited
+ case _ =>
+ expr.dataType
+ }
- val currentSum = AttributeReference("currentSum", resultType, nullable = false)()
- val initialValue = Cast(Literal(0L), resultType)
+ val currentSum = AttributeReference("currentSum", calcType, nullable = true)()
+ val initialValue = Literal(null, calcType)
// Coalesce avoids double calculation...
// but really, common subexpression elimination would be better....
- val updateFunction = Coalesce(Add(expr, currentSum) :: currentSum :: Nil)
- val result = currentSum
+ val zero = Cast(Literal(0), calcType)
+ val updateFunction = Coalesce(
+ Add(Coalesce(currentSum :: zero :: Nil), Cast(expr, calcType)) :: currentSum :: Nil)
+ val result =
+ expr.dataType match {
+ case DecimalType.Fixed(_, _) =>
+ Cast(currentSum, s.dataType)
+ case _ => currentSum
+ }
AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result)
case a @ Average(expr) =>
+ val calcType =
+ expr.dataType match {
+ case DecimalType.Fixed(_, _) =>
+ DecimalType.Unlimited
+ case _ =>
+ expr.dataType
+ }
+
val currentCount = AttributeReference("currentCount", LongType, nullable = false)()
- val currentSum = AttributeReference("currentSum", expr.dataType, nullable = false)()
+ val currentSum = AttributeReference("currentSum", calcType, nullable = false)()
val initialCount = Literal(0L)
- val initialSum = Cast(Literal(0L), expr.dataType)
+ val initialSum = Cast(Literal(0L), calcType)
// If we're evaluating UnscaledValue(x), we can do Count on x directly, since its
// UnscaledValue will be null if and only if x is null; helps with Average on decimals
@@ -115,17 +131,21 @@ case class GeneratedAggregate(
}
val updateCount = If(IsNotNull(toCount), Add(currentCount, Literal(1L)), currentCount)
- val updateSum = Coalesce(Add(expr, currentSum) :: currentSum :: Nil)
-
- val resultType = expr.dataType match {
- case DecimalType.Fixed(precision, scale) =>
- DecimalType(precision + 4, scale + 4)
- case DecimalType.Unlimited =>
- DecimalType.Unlimited
- case _ =>
- DoubleType
- }
- val result = Divide(Cast(currentSum, resultType), Cast(currentCount, resultType))
+ val updateSum = Coalesce(Add(Cast(expr, calcType), currentSum) :: currentSum :: Nil)
+
+ val result =
+ expr.dataType match {
+ case DecimalType.Fixed(_, _) =>
+ If(EqualTo(currentCount, Literal(0L)),
+ Literal(null, a.dataType),
+ Cast(Divide(
+ Cast(currentSum, DecimalType.Unlimited),
+ Cast(currentCount, DecimalType.Unlimited)), a.dataType))
+ case _ =>
+ If(EqualTo(currentCount, Literal(0L)),
+ Literal(null, a.dataType),
+ Divide(Cast(currentSum, a.dataType), Cast(currentCount, a.dataType)))
+ }
AggregateEvaluation(
currentCount :: currentSum :: Nil,
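
To make the new Catalyst expressions concrete, here is a hedged model in plain Scala, with Option standing in for nullable SQL values (an illustration of the semantics, not the generated code): the null-seeded Sum update, and Average's zero-count guard that returns NULL instead of dividing by zero.

// Sum: currentSum starts as NULL (None). A null input leaves it unchanged via the outer
// Coalesce; the first non-null input seeds it through Coalesce(currentSum :: zero :: Nil).
def updateSum(currentSum: Option[Long], expr: Option[Long]): Option[Long] =
  expr match {
    case Some(v) => Some(currentSum.getOrElse(0L) + v) // Add(Coalesce(currentSum, zero), expr)
    case None    => currentSum                         // Add is null, so Coalesce keeps currentSum
  }

// Average: guard on currentCount == 0 so an empty table yields NULL, not 0 / 0.
def averageResult(currentSum: Long, currentCount: Long): Option[Double] =
  if (currentCount == 0L) None else Some(currentSum.toDouble / currentCount)

assert(updateSum(None, None).isEmpty)            // empty table: Sum stays NULL
assert(updateSum(None, Some(3L)) == Some(3L))
assert(updateSum(Some(3L), Some(4L)) == Some(7L))
assert(averageResult(0L, 0L).isEmpty)            // the "zero average" test case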
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
index e70ad891ee..94bd97758f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
@@ -156,22 +156,58 @@ class DslQuerySuite extends QueryTest {
test("average") {
checkAnswer(
- testData2.groupBy()(avg('a)),
+ testData2.aggregate(avg('a)),
2.0)
+
+ checkAnswer(
+ testData2.aggregate(avg('a), sumDistinct('a)), // non-partial
+ (2.0, 6.0) :: Nil)
+
+ checkAnswer(
+ decimalData.aggregate(avg('a)),
+ BigDecimal(2.0))
+ checkAnswer(
+ decimalData.aggregate(avg('a), sumDistinct('a)), // non-partial
+ (BigDecimal(2.0), BigDecimal(6)) :: Nil)
+
+ checkAnswer(
+ decimalData.aggregate(avg('a cast DecimalType(10, 2))),
+ BigDecimal(2.0))
+ checkAnswer(
+ decimalData.aggregate(avg('a cast DecimalType(10, 2)), sumDistinct('a cast DecimalType(10, 2))), // non-partial
+ (BigDecimal(2.0), BigDecimal(6)) :: Nil)
}
test("null average") {
checkAnswer(
- testData3.groupBy()(avg('b)),
+ testData3.aggregate(avg('b)),
2.0)
checkAnswer(
- testData3.groupBy()(avg('b), countDistinct('b)),
+ testData3.aggregate(avg('b), countDistinct('b)),
(2.0, 1) :: Nil)
+
+ checkAnswer(
+ testData3.aggregate(avg('b), sumDistinct('b)), // non-partial
+ (2.0, 2.0) :: Nil)
+ }
+
+ test("zero average") {
+ checkAnswer(
+ emptyTableData.aggregate(avg('a)),
+ null)
+
+ checkAnswer(
+ emptyTableData.aggregate(avg('a), sumDistinct('b)), // non-partial
+ (null, null) :: Nil)
}
test("count") {
assert(testData2.count() === testData2.map(_ => 1).count())
+
+ checkAnswer(
+ testData2.aggregate(count('a), sumDistinct('a)), // non-partial
+ (6, 6.0) :: Nil)
}
test("null count") {
@@ -186,13 +222,34 @@ class DslQuerySuite extends QueryTest {
)
checkAnswer(
- testData3.groupBy()(count('a), count('b), count(1), countDistinct('a), countDistinct('b)),
+ testData3.aggregate(count('a), count('b), count(1), countDistinct('a), countDistinct('b)),
(2, 1, 2, 2, 1) :: Nil
)
+
+ checkAnswer(
+ testData3.aggregate(count('b), countDistinct('b), sumDistinct('b)), // non-partial
+ (1, 1, 2) :: Nil
+ )
}
test("zero count") {
assert(emptyTableData.count() === 0)
+
+ checkAnswer(
+ emptyTableData.aggregate(count('a), sumDistinct('a)), // non-partial
+ (0, null) :: Nil)
+ }
+
+ test("zero sum") {
+ checkAnswer(
+ emptyTableData.aggregate(sum('a)),
+ null)
+ }
+
+ test("zero sum distinct") {
+ checkAnswer(
+ emptyTableData.aggregate(sumDistinct('a)),
+ null)
}
test("except") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
index 92b49e8155..933e027436 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -54,6 +54,17 @@ object TestData {
TestData2(3, 2) :: Nil).toSchemaRDD
testData2.registerTempTable("testData2")
+ case class DecimalData(a: BigDecimal, b: BigDecimal)
+ val decimalData =
+ TestSQLContext.sparkContext.parallelize(
+ DecimalData(1, 1) ::
+ DecimalData(1, 2) ::
+ DecimalData(2, 1) ::
+ DecimalData(2, 2) ::
+ DecimalData(3, 1) ::
+ DecimalData(3, 2) :: Nil).toSchemaRDD
+ decimalData.registerTempTable("decimalData")
+
case class BinaryData(a: Array[Byte], b: Int)
val binaryData =
TestSQLContext.sparkContext.parallelize(
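
A side note on the new fixture: the integer literals in `DecimalData(1, 1)` compile because scala.math.BigDecimal's companion object provides an implicit Int-to-BigDecimal widening. A standalone sketch of that conversion, mirroring the case class above:

// Mirrors the DecimalData fixture; Int arguments widen implicitly to BigDecimal.
case class DecimalData(a: BigDecimal, b: BigDecimal)
val d = DecimalData(1, 2)
assert(d.a == BigDecimal(1) && d.b == BigDecimal(2))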