aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHerman van Hovell <hvanhovell@databricks.com>2016-10-05 16:05:30 -0700
committerYin Huai <yhuai@databricks.com>2016-10-05 16:05:30 -0700
commit5fd54b994e2078dbf0794932b4e0ffa9a9eda0c3 (patch)
treec3578544cb4d4b4431a5debd0ce6ea7ff4334e0b
parent221b418b1c9db7b04c600b6300d18b034a4f444e (diff)
downloadspark-5fd54b994e2078dbf0794932b4e0ffa9a9eda0c3.tar.gz
spark-5fd54b994e2078dbf0794932b4e0ffa9a9eda0c3.tar.bz2
spark-5fd54b994e2078dbf0794932b4e0ffa9a9eda0c3.zip
[SPARK-17758][SQL] Last returns wrong result in case of empty partition
## What changes were proposed in this pull request? The result of the `Last` function can be wrong when the last partition processed is empty. It can return `null` instead of the expected value. For example, this can happen when we process partitions in the following order: ``` - Partition 1 [Row1, Row2] - Partition 2 [Row3] - Partition 3 [] ``` In this case the `Last` function will currently return a null, instead of the value of `Row3`. This PR fixes this by adding a `valueSet` flag to the `Last` function. ## How was this patch tested? We only used end to end tests for `DeclarativeAggregateFunction`s. I have added an evaluator for these functions so we can tests them in catalyst. I have added a `LastTestSuite` to test the `Last` aggregate function. Author: Herman van Hovell <hvanhovell@databricks.com> Closes #15348 from hvanhovell/SPARK-17758.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala27
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/DeclarativeAggregateEvaluator.scala61
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/LastTestSuite.scala109
3 files changed, 184 insertions, 13 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
index af88403058..8579f7292d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
@@ -55,34 +55,35 @@ case class Last(child: Expression, ignoreNullsExpr: Expression) extends Declarat
private lazy val last = AttributeReference("last", child.dataType)()
- override lazy val aggBufferAttributes: Seq[AttributeReference] = last :: Nil
+ private lazy val valueSet = AttributeReference("valueSet", BooleanType)()
+
+ override lazy val aggBufferAttributes: Seq[AttributeReference] = last :: valueSet :: Nil
override lazy val initialValues: Seq[Literal] = Seq(
- /* last = */ Literal.create(null, child.dataType)
+ /* last = */ Literal.create(null, child.dataType),
+ /* valueSet = */ Literal.create(false, BooleanType)
)
override lazy val updateExpressions: Seq[Expression] = {
if (ignoreNulls) {
Seq(
- /* last = */ If(IsNull(child), last, child)
+ /* last = */ If(IsNull(child), last, child),
+ /* valueSet = */ Or(valueSet, IsNotNull(child))
)
} else {
Seq(
- /* last = */ child
+ /* last = */ child,
+ /* valueSet = */ Literal.create(true, BooleanType)
)
}
}
override lazy val mergeExpressions: Seq[Expression] = {
- if (ignoreNulls) {
- Seq(
- /* last = */ If(IsNull(last.right), last.left, last.right)
- )
- } else {
- Seq(
- /* last = */ last.right
- )
- }
+ // Prefer the right hand expression if it has been set.
+ Seq(
+ /* last = */ If(valueSet.right, last.right, last.left),
+ /* valueSet = */ Or(valueSet.right, valueSet.left)
+ )
}
override lazy val evaluateExpression: AttributeReference = last
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/DeclarativeAggregateEvaluator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/DeclarativeAggregateEvaluator.scala
new file mode 100644
index 0000000000..614f24db0a
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/DeclarativeAggregateEvaluator.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, JoinedRow}
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection
+
+/**
+ * Evaluator for a [[DeclarativeAggregate]].
+ */
+case class DeclarativeAggregateEvaluator(function: DeclarativeAggregate, input: Seq[Attribute]) {
+
+ lazy val initializer = GenerateSafeProjection.generate(function.initialValues)
+
+ lazy val updater = GenerateSafeProjection.generate(
+ function.updateExpressions,
+ function.aggBufferAttributes ++ input)
+
+ lazy val merger = GenerateSafeProjection.generate(
+ function.mergeExpressions,
+ function.aggBufferAttributes ++ function.inputAggBufferAttributes)
+
+ lazy val evaluator = GenerateSafeProjection.generate(
+ function.evaluateExpression :: Nil,
+ function.aggBufferAttributes)
+
+ def initialize(): InternalRow = initializer.apply(InternalRow.empty).copy()
+
+ def update(values: InternalRow*): InternalRow = {
+ val joiner = new JoinedRow
+ val buffer = values.foldLeft(initialize()) { (buffer, input) =>
+ updater(joiner(buffer, input))
+ }
+ buffer.copy()
+ }
+
+ def merge(buffers: InternalRow*): InternalRow = {
+ val joiner = new JoinedRow
+ val buffer = buffers.foldLeft(initialize()) { (left, right) =>
+ merger(joiner(left, right))
+ }
+ buffer.copy()
+ }
+
+ def eval(buffer: InternalRow): InternalRow = evaluator(buffer).copy()
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/LastTestSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/LastTestSuite.scala
new file mode 100644
index 0000000000..ba36bc074e
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/LastTestSuite.scala
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Literal}
+import org.apache.spark.sql.types.IntegerType
+
+class LastTestSuite extends SparkFunSuite {
+ val input = AttributeReference("input", IntegerType, nullable = true)()
+ val evaluator = DeclarativeAggregateEvaluator(Last(input, Literal(false)), Seq(input))
+ val evaluatorIgnoreNulls = DeclarativeAggregateEvaluator(Last(input, Literal(true)), Seq(input))
+
+ test("empty buffer") {
+ assert(evaluator.initialize() === InternalRow(null, false))
+ }
+
+ test("update") {
+ val result = evaluator.update(
+ InternalRow(1),
+ InternalRow(9),
+ InternalRow(-1))
+ assert(result === InternalRow(-1, true))
+ }
+
+ test("update - ignore nulls") {
+ val result1 = evaluatorIgnoreNulls.update(
+ InternalRow(null),
+ InternalRow(9),
+ InternalRow(null))
+ assert(result1 === InternalRow(9, true))
+
+ val result2 = evaluatorIgnoreNulls.update(
+ InternalRow(null),
+ InternalRow(null))
+ assert(result2 === InternalRow(null, false))
+ }
+
+ test("merge") {
+ // Empty merge
+ val p0 = evaluator.initialize()
+ assert(evaluator.merge(p0) === InternalRow(null, false))
+
+ // Single merge
+ val p1 = evaluator.update(InternalRow(1), InternalRow(-99))
+ assert(evaluator.merge(p1) === p1)
+
+ // Multiple merges.
+ val p2 = evaluator.update(InternalRow(2), InternalRow(10))
+ assert(evaluator.merge(p1, p2) === p2)
+
+ // Empty partitions (p0 is empty)
+ assert(evaluator.merge(p1, p0, p2) === p2)
+ assert(evaluator.merge(p2, p1, p0) === p1)
+ }
+
+ test("merge - ignore nulls") {
+ // Multi merges
+ val p1 = evaluatorIgnoreNulls.update(InternalRow(1), InternalRow(null))
+ val p2 = evaluatorIgnoreNulls.update(InternalRow(null), InternalRow(null))
+ assert(evaluatorIgnoreNulls.merge(p1, p2) === p1)
+ }
+
+ test("eval") {
+ // Null Eval
+ assert(evaluator.eval(InternalRow(null, true)) === InternalRow(null))
+ assert(evaluator.eval(InternalRow(null, false)) === InternalRow(null))
+
+ // Empty Eval
+ val p0 = evaluator.initialize()
+ assert(evaluator.eval(p0) === InternalRow(null))
+
+ // Update - Eval
+ val p1 = evaluator.update(InternalRow(1), InternalRow(-99))
+ assert(evaluator.eval(p1) === InternalRow(-99))
+
+ // Update - Merge - Eval
+ val p2 = evaluator.update(InternalRow(2), InternalRow(10))
+ val m1 = evaluator.merge(p1, p0, p2)
+ assert(evaluator.eval(m1) === InternalRow(10))
+
+ // Update - Merge - Eval (empty partition at the end)
+ val m2 = evaluator.merge(p2, p1, p0)
+ assert(evaluator.eval(m2) === InternalRow(-99))
+ }
+
+ test("eval - ignore nulls") {
+ // Update - Merge - Eval
+ val p1 = evaluatorIgnoreNulls.update(InternalRow(1), InternalRow(null))
+ val p2 = evaluatorIgnoreNulls.update(InternalRow(null), InternalRow(null))
+ val m1 = evaluatorIgnoreNulls.merge(p1, p2)
+ assert(evaluatorIgnoreNulls.eval(m1) === InternalRow(1))
+ }
+}