about summary refs log tree commit diff
diff options
context:
space:
mode:
authorYin Huai <yhuai@databricks.com>2015-08-17 15:30:50 -0700
committerMichael Armbrust <michael@databricks.com>2015-08-17 15:30:50 -0700
commit772e7c18fb1a79c0f080408cb43307fe89a4fa04 (patch)
treeb73cde0d245ff87f312ba7ce8000a238831c92e9
parentb265e282b62954548740a5767e97ab1678c65194 (diff)
downloadspark-772e7c18fb1a79c0f080408cb43307fe89a4fa04.tar.gz
spark-772e7c18fb1a79c0f080408cb43307fe89a4fa04.tar.bz2
spark-772e7c18fb1a79c0f080408cb43307fe89a4fa04.zip
[SPARK-9592] [SQL] Fix Last function implemented based on AggregateExpression1.
https://issues.apache.org/jira/browse/SPARK-9592

#8113 has the fundamental fix. But, if we want to minimize the number of changed lines, we can go with this one. Then, in 1.6, we merge #8113.

Author: Yin Huai <yhuai@databricks.com>

Closes #8172 from yhuai/lastFix and squashes the following commits:

b28c42a [Yin Huai] Regression test.
af87086 [Yin Huai] Fix last.
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala | 9
-rw-r--r-- sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala | 15
2 files changed, 22 insertions(+), 2 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index 2cf8312ea5..5e8298aaaa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -650,6 +650,7 @@ case class FirstFunction(expr: Expression, base: AggregateExpression1) extends A
var result: Any = null
override def update(input: InternalRow): Unit = {
+ // We ignore null values.
if (result == null) {
result = expr.eval(input)
}
@@ -679,10 +680,14 @@ case class LastFunction(expr: Expression, base: AggregateExpression1) extends Ag
var result: Any = null
override def update(input: InternalRow): Unit = {
- result = input
+ val value = expr.eval(input)
+ // We ignore null values.
+ if (value != null) {
+ result = value
+ }
}
override def eval(input: InternalRow): Any = {
- if (result != null) expr.eval(result.asInstanceOf[InternalRow]) else null
+ result
}
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index a312f84958..119663af18 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -480,6 +480,21 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be
Row(0, null, 1, 1, null, 0) :: Nil)
}
+ test("test Last implemented based on AggregateExpression1") {
+ // TODO: Remove this test once we remove AggregateExpression1.
+ import org.apache.spark.sql.functions._
+ val df = Seq((1, 1), (2, 2), (3, 3)).toDF("i", "j").repartition(1)
+ withSQLConf(
+ SQLConf.SHUFFLE_PARTITIONS.key -> "1",
+ SQLConf.USE_SQL_AGGREGATE2.key -> "false") {
+
+ checkAnswer(
+ df.groupBy("i").agg(last("j")),
+ df
+ )
+ }
+ }
+
test("error handling") {
withSQLConf("spark.sql.useAggregate2" -> "false") {
val errorMessage = intercept[AnalysisException] {