aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorTakuya UESHIN <ueshin@happy-camper.st>2014-05-27 14:55:23 -0700
committerReynold Xin <rxin@apache.org>2014-05-27 14:55:23 -0700
commit3b0babad1f0856ee16f9d58e1ead30779a4a6310 (patch)
tree9f191295ddce1a4554b7a12c201718e4c2703e70 /sql
parentd1375a2bff846f2c4274e14545924646852895f9 (diff)
downloadspark-3b0babad1f0856ee16f9d58e1ead30779a4a6310.tar.gz
spark-3b0babad1f0856ee16f9d58e1ead30779a4a6310.tar.bz2
spark-3b0babad1f0856ee16f9d58e1ead30779a4a6310.zip
[SPARK-1915] [SQL] AverageFunction should not count if the evaluated value is null.
Average values differ depending on whether the calculation is done partially or not, because `AverageFunction` (in the non-partial calculation) counts a row even if the evaluated value is null. Author: Takuya UESHIN <ueshin@happy-camper.st> Closes #862 from ueshin/issues/SPARK-1915 and squashes the following commits: b1ff3c0 [Takuya UESHIN] Modify AverageFunction not to count if the evaluated value is null.
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala9
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala10
2 files changed, 16 insertions, 3 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index b49a4614ea..c902433688 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -281,14 +281,17 @@ case class AverageFunction(expr: Expression, base: AggregateExpression)
private val sum = MutableLiteral(zero.eval(EmptyRow))
private val sumAsDouble = Cast(sum, DoubleType)
- private val addFunction = Add(sum, Coalesce(Seq(expr, zero)))
+ private def addFunction(value: Any) = Add(sum, Literal(value))
override def eval(input: Row): Any =
sumAsDouble.eval(EmptyRow).asInstanceOf[Double] / count.toDouble
override def update(input: Row): Unit = {
- count += 1
- sum.update(addFunction, input)
+ val evaluatedExpr = expr.eval(input)
+ if (evaluatedExpr != null) {
+ count += 1
+ sum.update(addFunction(evaluatedExpr), input)
+ }
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
index 8197e8a18d..fb599e1e01 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
@@ -115,6 +115,16 @@ class DslQuerySuite extends QueryTest {
2.0)
}
+ test("null average") {
+ checkAnswer(
+ testData3.groupBy()(Average('b)),
+ 2.0)
+
+ checkAnswer(
+ testData3.groupBy()(Average('b), CountDistinct('b :: Nil)),
+ (2.0, 1) :: Nil)
+ }
+
test("count") {
assert(testData2.count() === testData2.map(_ => 1).count())
}