aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala4
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala5
2 files changed, 7 insertions, 2 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index 1bcd4e2276..79937b129a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -298,8 +298,8 @@ case class CountFunction(expr: Expression, base: AggregateExpression) extends Ag
var count: Long = _
override def update(input: Row): Unit = {
- val evaluatedExpr = expr.map(_.eval(input))
- if (evaluatedExpr.map(_ != null).reduceLeft(_ || _)) {
+ val evaluatedExpr = expr.eval(input)
+ if (evaluatedExpr != null) {
count += 1L
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
index 692569a73f..8197e8a18d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
@@ -126,6 +126,11 @@ class DslQuerySuite extends QueryTest {
)
checkAnswer(
+ testData3.groupBy('a)('a, Count('a + 'b)),
+ Seq((1,0), (2, 1))
+ )
+
+ checkAnswer(
testData3.groupBy()(Count('a), Count('b), Count(1), CountDistinct('a :: Nil), CountDistinct('b :: Nil)),
(2, 1, 2, 2, 1) :: Nil
)