From 5fd54b994e2078dbf0794932b4e0ffa9a9eda0c3 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Wed, 5 Oct 2016 16:05:30 -0700 Subject: [SPARK-17758][SQL] Last returns wrong result in case of empty partition ## What changes were proposed in this pull request? The result of the `Last` function can be wrong when the last partition processed is empty. It can return `null` instead of the expected value. For example, this can happen when we process partitions in the following order: ``` - Partition 1 [Row1, Row2] - Partition 2 [Row3] - Partition 3 [] ``` In this case the `Last` function will currently return a null, instead of the value of `Row3`. This PR fixes this by adding a `valueSet` flag to the `Last` function. ## How was this patch tested? We only used end to end tests for `DeclarativeAggregateFunction`s. I have added an evaluator for these functions so we can tests them in catalyst. I have added a `LastTestSuite` to test the `Last` aggregate function. Author: Herman van Hovell Closes #15348 from hvanhovell/SPARK-17758. --- .../sql/catalyst/expressions/aggregate/Last.scala | 27 ++--- .../aggregate/DeclarativeAggregateEvaluator.scala | 61 ++++++++++++ .../expressions/aggregate/LastTestSuite.scala | 109 +++++++++++++++++++++ 3 files changed, 184 insertions(+), 13 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/DeclarativeAggregateEvaluator.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/LastTestSuite.scala (limited to 'sql') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala index af88403058..8579f7292d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala @@ -55,34 +55,35 @@ case class Last(child: Expression, ignoreNullsExpr: Expression) extends Declarat private lazy val last = AttributeReference("last", child.dataType)() - override lazy val aggBufferAttributes: Seq[AttributeReference] = last :: Nil + private lazy val valueSet = AttributeReference("valueSet", BooleanType)() + + override lazy val aggBufferAttributes: Seq[AttributeReference] = last :: valueSet :: Nil override lazy val initialValues: Seq[Literal] = Seq( - /* last = */ Literal.create(null, child.dataType) + /* last = */ Literal.create(null, child.dataType), + /* valueSet = */ Literal.create(false, BooleanType) ) override lazy val updateExpressions: Seq[Expression] = { if (ignoreNulls) { Seq( - /* last = */ If(IsNull(child), last, child) + /* last = */ If(IsNull(child), last, child), + /* valueSet = */ Or(valueSet, IsNotNull(child)) ) } else { Seq( - /* last = */ child + /* last = */ child, + /* valueSet = */ Literal.create(true, BooleanType) ) } } override lazy val mergeExpressions: Seq[Expression] = { - if (ignoreNulls) { - Seq( - /* last = */ If(IsNull(last.right), last.left, last.right) - ) - } else { - Seq( - /* last = */ last.right - ) - } + // Prefer the right hand expression if it has been set. + Seq( + /* last = */ If(valueSet.right, last.right, last.left), + /* valueSet = */ Or(valueSet.right, valueSet.left) + ) } override lazy val evaluateExpression: AttributeReference = last diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/DeclarativeAggregateEvaluator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/DeclarativeAggregateEvaluator.scala new file mode 100644 index 0000000000..614f24db0a --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/DeclarativeAggregateEvaluator.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, JoinedRow} +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection + +/** + * Evaluator for a [[DeclarativeAggregate]]. + */ +case class DeclarativeAggregateEvaluator(function: DeclarativeAggregate, input: Seq[Attribute]) { + + lazy val initializer = GenerateSafeProjection.generate(function.initialValues) + + lazy val updater = GenerateSafeProjection.generate( + function.updateExpressions, + function.aggBufferAttributes ++ input) + + lazy val merger = GenerateSafeProjection.generate( + function.mergeExpressions, + function.aggBufferAttributes ++ function.inputAggBufferAttributes) + + lazy val evaluator = GenerateSafeProjection.generate( + function.evaluateExpression :: Nil, + function.aggBufferAttributes) + + def initialize(): InternalRow = initializer.apply(InternalRow.empty).copy() + + def update(values: InternalRow*): InternalRow = { + val joiner = new JoinedRow + val buffer = values.foldLeft(initialize()) { (buffer, input) => + updater(joiner(buffer, input)) + } + buffer.copy() + } + + def merge(buffers: InternalRow*): InternalRow = { + val joiner = new JoinedRow + val buffer = buffers.foldLeft(initialize()) { (left, right) => + merger(joiner(left, right)) + } + buffer.copy() + } + + def eval(buffer: InternalRow): InternalRow = evaluator(buffer).copy() +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/LastTestSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/LastTestSuite.scala new file mode 100644 index 0000000000..ba36bc074e --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/LastTestSuite.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Literal} +import org.apache.spark.sql.types.IntegerType + +class LastTestSuite extends SparkFunSuite { + val input = AttributeReference("input", IntegerType, nullable = true)() + val evaluator = DeclarativeAggregateEvaluator(Last(input, Literal(false)), Seq(input)) + val evaluatorIgnoreNulls = DeclarativeAggregateEvaluator(Last(input, Literal(true)), Seq(input)) + + test("empty buffer") { + assert(evaluator.initialize() === InternalRow(null, false)) + } + + test("update") { + val result = evaluator.update( + InternalRow(1), + InternalRow(9), + InternalRow(-1)) + assert(result === InternalRow(-1, true)) + } + + test("update - ignore nulls") { + val result1 = evaluatorIgnoreNulls.update( + InternalRow(null), + InternalRow(9), + InternalRow(null)) + assert(result1 === InternalRow(9, true)) + + val result2 = evaluatorIgnoreNulls.update( + InternalRow(null), + InternalRow(null)) + assert(result2 === InternalRow(null, false)) + } + + test("merge") { + // Empty merge + val p0 = evaluator.initialize() + assert(evaluator.merge(p0) === InternalRow(null, false)) + + // Single merge + val p1 = evaluator.update(InternalRow(1), InternalRow(-99)) + assert(evaluator.merge(p1) === p1) + + // Multiple merges. + val p2 = evaluator.update(InternalRow(2), InternalRow(10)) + assert(evaluator.merge(p1, p2) === p2) + + // Empty partitions (p0 is empty) + assert(evaluator.merge(p1, p0, p2) === p2) + assert(evaluator.merge(p2, p1, p0) === p1) + } + + test("merge - ignore nulls") { + // Multi merges + val p1 = evaluatorIgnoreNulls.update(InternalRow(1), InternalRow(null)) + val p2 = evaluatorIgnoreNulls.update(InternalRow(null), InternalRow(null)) + assert(evaluatorIgnoreNulls.merge(p1, p2) === p1) + } + + test("eval") { + // Null Eval + assert(evaluator.eval(InternalRow(null, true)) === InternalRow(null)) + assert(evaluator.eval(InternalRow(null, false)) === InternalRow(null)) + + // Empty Eval + val p0 = evaluator.initialize() + assert(evaluator.eval(p0) === InternalRow(null)) + + // Update - Eval + val p1 = evaluator.update(InternalRow(1), InternalRow(-99)) + assert(evaluator.eval(p1) === InternalRow(-99)) + + // Update - Merge - Eval + val p2 = evaluator.update(InternalRow(2), InternalRow(10)) + val m1 = evaluator.merge(p1, p0, p2) + assert(evaluator.eval(m1) === InternalRow(10)) + + // Update - Merge - Eval (empty partition at the end) + val m2 = evaluator.merge(p2, p1, p0) + assert(evaluator.eval(m2) === InternalRow(-99)) + } + + test("eval - ignore nulls") { + // Update - Merge - Eval + val p1 = evaluatorIgnoreNulls.update(InternalRow(1), InternalRow(null)) + val p2 = evaluatorIgnoreNulls.update(InternalRow(null), InternalRow(null)) + val m1 = evaluatorIgnoreNulls.merge(p1, p2) + assert(evaluatorIgnoreNulls.eval(m1) === InternalRow(1)) + } +} -- cgit v1.2.3