aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorYin Huai <yhuai@databricks.com>2015-10-21 13:43:17 -0700
committerMichael Armbrust <michael@databricks.com>2015-10-21 13:43:17 -0700
commit3afe448d39dc4877b2f2c62b3059aeb3ced0bd96 (patch)
tree4923d317fbcbeb4c0ca9dbc257406708ce705241 /sql
parentf8c6bec65784de89b47e96a367d3f9790c1b3115 (diff)
downloadspark-3afe448d39dc4877b2f2c62b3059aeb3ced0bd96.tar.gz
spark-3afe448d39dc4877b2f2c62b3059aeb3ced0bd96.tar.bz2
spark-3afe448d39dc4877b2f2c62b3059aeb3ced0bd96.zip
[SPARK-9740][SPARK-9592][SPARK-9210][SQL] Change the default behavior of First/Last to RESPECT NULLS.
I am changing the default behavior of `First`/`Last` to respect null values (the SQL standard default behavior). https://issues.apache.org/jira/browse/SPARK-9740 Author: Yin Huai <yhuai@databricks.com> Closes #8113 from yhuai/firstLast.
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala3
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala105
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala8
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala95
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala13
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala38
7 files changed, 219 insertions, 45 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index ab215407f7..98d6637c06 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -113,7 +113,8 @@ trait CheckAnalysis {
failAnalysis(
s"expression '${e.prettyString}' is neither present in the group by, " +
s"nor is it an aggregate function. " +
- "Add to group by or wrap in first() if you don't care which value you get.")
+ "Add to group by or wrap in first() (or first_value) if you don't care " +
+ "which value you get.")
case e if groupingExprs.exists(_.semanticEquals(e)) => // OK
case e if e.references.isEmpty => // OK
case e => e.children.foreach(checkValidAggregateExpression)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index ba77b70a37..f73b24e363 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -179,7 +179,9 @@ object FunctionRegistry {
expression[Average]("avg"),
expression[Count]("count"),
expression[First]("first"),
+ expression[First]("first_value"),
expression[Last]("last"),
+ expression[Last]("last_value"),
expression[Max]("max"),
expression[Min]("min"),
expression[Stddev]("stddev"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
index c0bc7ec09c..515246d344 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
@@ -21,6 +21,8 @@ import java.lang.{Long => JLong}
import java.util
import com.clearspring.analytics.hash.MurmurHash
+
+import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
@@ -118,7 +120,23 @@ case class Count(child: Expression) extends DeclarativeAggregate {
override val evaluateExpression = Cast(currentCount, LongType)
}
-case class First(child: Expression) extends DeclarativeAggregate {
+/**
+ * Returns the first value of `child` for a group of rows. If the first value of `child`
+ * is `null`, it returns `null` (respecting nulls). Even if [[First]] is used on a already
+ * sorted column, if we do partial aggregation and final aggregation (when mergeExpression
+ * is used) its result will not be deterministic (unless the input table is sorted and has
+ * a single partition, and we use a single reducer to do the aggregation.).
+ * @param child
+ */
+case class First(child: Expression, ignoreNullsExpr: Expression) extends DeclarativeAggregate {
+
+ def this(child: Expression) = this(child, Literal.create(false, BooleanType))
+
+ private val ignoreNulls: Boolean = ignoreNullsExpr match {
+ case Literal(b: Boolean, BooleanType) => b
+ case _ =>
+ throw new AnalysisException("The second argument of First should be a boolean literal.")
+ }
override def children: Seq[Expression] = child :: Nil
@@ -135,24 +153,61 @@ case class First(child: Expression) extends DeclarativeAggregate {
private val first = AttributeReference("first", child.dataType)()
- override val aggBufferAttributes = first :: Nil
+ private val valueSet = AttributeReference("valueSet", BooleanType)()
+
+ override val aggBufferAttributes = first :: valueSet :: Nil
override val initialValues = Seq(
- /* first = */ Literal.create(null, child.dataType)
+ /* first = */ Literal.create(null, child.dataType),
+ /* valueSet = */ Literal.create(false, BooleanType)
)
- override val updateExpressions = Seq(
- /* first = */ If(IsNull(first), child, first)
- )
+ override val updateExpressions = {
+ if (ignoreNulls) {
+ Seq(
+ /* first = */ If(Or(valueSet, IsNull(child)), first, child),
+ /* valueSet = */ Or(valueSet, IsNotNull(child))
+ )
+ } else {
+ Seq(
+ /* first = */ If(valueSet, first, child),
+ /* valueSet = */ Literal.create(true, BooleanType)
+ )
+ }
+ }
- override val mergeExpressions = Seq(
- /* first = */ If(IsNull(first.left), first.right, first.left)
- )
+ override val mergeExpressions = {
+ // For first, we can just check if valueSet.left is set to true. If it is set
+ // to true, we use first.right. If not, we use first.right (even if valueSet.right is
+ // false, we are safe to do so because first.right will be null in this case).
+ Seq(
+ /* first = */ If(valueSet.left, first.left, first.right),
+ /* valueSet = */ Or(valueSet.left, valueSet.right)
+ )
+ }
override val evaluateExpression = first
+
+ override def toString: String = s"FIRST($child)${if (ignoreNulls) " IGNORE NULLS"}"
}
-case class Last(child: Expression) extends DeclarativeAggregate {
+/**
+ * Returns the last value of `child` for a group of rows. If the last value of `child`
+ * is `null`, it returns `null` (respecting nulls). Even if [[Last]] is used on a already
+ * sorted column, if we do partial aggregation and final aggregation (when mergeExpression
+ * is used) its result will not be deterministic (unless the input table is sorted and has
+ * a single partition, and we use a single reducer to do the aggregation.).
+ * @param child
+ */
+case class Last(child: Expression, ignoreNullsExpr: Expression) extends DeclarativeAggregate {
+
+ def this(child: Expression) = this(child, Literal.create(false, BooleanType))
+
+ private val ignoreNulls: Boolean = ignoreNullsExpr match {
+ case Literal(b: Boolean, BooleanType) => b
+ case _ =>
+ throw new AnalysisException("The second argument of First should be a boolean literal.")
+ }
override def children: Seq[Expression] = child :: Nil
@@ -175,15 +230,33 @@ case class Last(child: Expression) extends DeclarativeAggregate {
/* last = */ Literal.create(null, child.dataType)
)
- override val updateExpressions = Seq(
- /* last = */ If(IsNull(child), last, child)
- )
+ override val updateExpressions = {
+ if (ignoreNulls) {
+ Seq(
+ /* last = */ If(IsNull(child), last, child)
+ )
+ } else {
+ Seq(
+ /* last = */ child
+ )
+ }
+ }
- override val mergeExpressions = Seq(
- /* last = */ If(IsNull(last.right), last.left, last.right)
- )
+ override val mergeExpressions = {
+ if (ignoreNulls) {
+ Seq(
+ /* last = */ If(IsNull(last.right), last.left, last.right)
+ )
+ } else {
+ Seq(
+ /* last = */ last.right
+ )
+ }
+ }
override val evaluateExpression = last
+
+ override def toString: String = s"LAST($child)${if (ignoreNulls) " IGNORE NULLS"}"
}
case class Max(child: Expression) extends DeclarativeAggregate {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala
index f656ccf13b..12bdab0915 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala
@@ -61,15 +61,15 @@ object Utils {
mode = aggregate.Complete,
isDistinct = true)
- case expressions.First(child) =>
+ case expressions.First(child, ignoreNulls) =>
aggregate.AggregateExpression2(
- aggregateFunction = aggregate.First(child),
+ aggregateFunction = aggregate.First(child, ignoreNulls),
mode = aggregate.Complete,
isDistinct = false)
- case expressions.Last(child) =>
+ case expressions.Last(child, ignoreNulls) =>
aggregate.AggregateExpression2(
- aggregateFunction = aggregate.Last(child),
+ aggregateFunction = aggregate.Last(child, ignoreNulls),
mode = aggregate.Complete,
isDistinct = false)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index f1c47f3904..95061c4635 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import com.clearspring.analytics.stream.cardinality.HyperLogLog
+import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
@@ -630,59 +631,113 @@ case class CombineSetsAndSumFunction(
}
}
-case class First(child: Expression) extends UnaryExpression with PartialAggregate1 {
+case class First(
+ child: Expression,
+ ignoreNullsExpr: Expression)
+ extends UnaryExpression with PartialAggregate1 {
+
+ def this(child: Expression) = this(child, Literal.create(false, BooleanType))
+
+ private val ignoreNulls: Boolean = ignoreNullsExpr match {
+ case Literal(b: Boolean, BooleanType) => b
+ case _ =>
+ throw new AnalysisException("The second argument of First should be a boolean literal.")
+ }
+
override def nullable: Boolean = true
override def dataType: DataType = child.dataType
- override def toString: String = s"FIRST($child)"
+ override def toString: String = s"FIRST(${child}${if (ignoreNulls) " IGNORE NULLS"})"
override def asPartial: SplitEvaluation = {
- val partialFirst = Alias(First(child), "PartialFirst")()
+ val partialFirst = Alias(First(child, ignoreNulls), "PartialFirst")()
SplitEvaluation(
- First(partialFirst.toAttribute),
+ First(partialFirst.toAttribute, ignoreNulls),
partialFirst :: Nil)
}
- override def newInstance(): FirstFunction = new FirstFunction(child, this)
+ override def newInstance(): FirstFunction = new FirstFunction(child, ignoreNulls, this)
}
-case class FirstFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 {
- def this() = this(null, null) // Required for serialization.
+object First {
+ def apply(child: Expression): First = First(child, ignoreNulls = false)
- var result: Any = null
+ def apply(child: Expression, ignoreNulls: Boolean): First =
+ First(child, Literal.create(ignoreNulls, BooleanType))
+}
+
+case class FirstFunction(
+ expr: Expression,
+ ignoreNulls: Boolean,
+ base: AggregateExpression1)
+ extends AggregateFunction1 {
+
+ def this() = this(null, null.asInstanceOf[Boolean], null) // Required for serialization.
+
+ private[this] var result: Any = null
+
+ private[this] var valueSet: Boolean = false
override def update(input: InternalRow): Unit = {
- // We ignore null values.
- if (result == null) {
- result = expr.eval(input)
+ if (!valueSet) {
+ val value = expr.eval(input)
+ // When we have not set the result, we will set the result if we respect nulls
+ // (i.e. ignoreNulls is false), or we ignore nulls and the evaluated value is not null.
+ if (!ignoreNulls || (ignoreNulls && value != null)) {
+ result = value
+ valueSet = true
+ }
}
}
override def eval(input: InternalRow): Any = result
}
-case class Last(child: Expression) extends UnaryExpression with PartialAggregate1 {
+case class Last(
+ child: Expression,
+ ignoreNullsExpr: Expression)
+ extends UnaryExpression with PartialAggregate1 {
+
+ def this(child: Expression) = this(child, Literal.create(false, BooleanType))
+
+ private val ignoreNulls: Boolean = ignoreNullsExpr match {
+ case Literal(b: Boolean, BooleanType) => b
+ case _ =>
+ throw new AnalysisException("The second argument of First should be a boolean literal.")
+ }
+
override def references: AttributeSet = child.references
override def nullable: Boolean = true
override def dataType: DataType = child.dataType
- override def toString: String = s"LAST($child)"
+ override def toString: String = s"LAST($child)${if (ignoreNulls) " IGNORE NULLS"}"
override def asPartial: SplitEvaluation = {
- val partialLast = Alias(Last(child), "PartialLast")()
+ val partialLast = Alias(Last(child, ignoreNulls), "PartialLast")()
SplitEvaluation(
- Last(partialLast.toAttribute),
+ Last(partialLast.toAttribute, ignoreNulls),
partialLast :: Nil)
}
- override def newInstance(): LastFunction = new LastFunction(child, this)
+ override def newInstance(): LastFunction = new LastFunction(child, ignoreNulls, this)
}
-case class LastFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 {
- def this() = this(null, null) // Required for serialization.
+object Last {
+ def apply(child: Expression): Last = Last(child, ignoreNulls = false)
+
+ def apply(child: Expression, ignoreNulls: Boolean): Last =
+ Last(child, Literal.create(ignoreNulls, BooleanType))
+}
+
+case class LastFunction(
+ expr: Expression,
+ ignoreNulls: Boolean,
+ base: AggregateExpression1)
+ extends AggregateFunction1 {
+
+ def this() = this(null, null.asInstanceOf[Boolean], null) // Required for serialization.
var result: Any = null
override def update(input: InternalRow): Unit = {
val value = expr.eval(input)
- // We ignore null values.
- if (value != null) {
+ if (!ignoreNulls || (ignoreNulls && value != null)) {
result = value
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
index c3d2246297..8b9247adea 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.expressions
import org.apache.spark.annotation.Experimental
+import org.apache.spark.sql.types.BooleanType
import org.apache.spark.sql.{Column, catalyst}
import org.apache.spark.sql.catalyst.expressions._
@@ -149,13 +150,17 @@ class WindowSpec private[sql](
case Count(child) => WindowExpression(
UnresolvedWindowFunction("count", child :: Nil),
WindowSpecDefinition(partitionSpec, orderSpec, frame))
- case First(child) => WindowExpression(
+ case First(child, ignoreNulls) => WindowExpression(
// TODO this is a hack for Hive UDAF first_value
- UnresolvedWindowFunction("first_value", child :: Nil),
+ UnresolvedWindowFunction(
+ "first_value",
+ child :: ignoreNulls :: Nil),
WindowSpecDefinition(partitionSpec, orderSpec, frame))
- case Last(child) => WindowExpression(
+ case Last(child, ignoreNulls) => WindowExpression(
// TODO this is a hack for Hive UDAF last_value
- UnresolvedWindowFunction("last_value", child :: Nil),
+ UnresolvedWindowFunction(
+ "last_value",
+ child :: ignoreNulls :: Nil),
WindowSpecDefinition(partitionSpec, orderSpec, frame))
case Min(child) => WindowExpression(
UnresolvedWindowFunction("min", child :: Nil),
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index c9e1bb1995..f38a3f63c3 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -323,6 +323,44 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
Row(11.125) :: Nil)
}
+ test("first_value and last_value") {
+ // We force to use a single partition for the sort and aggregate to make result
+ // deterministic.
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT
+ | first_valUE(key),
+ | lasT_value(key),
+ | firSt(key),
+ | lASt(key),
+ | first_valUE(key, true),
+ | lasT_value(key, true),
+ | firSt(key, true),
+ | lASt(key, true)
+ |FROM (SELECT key FROM agg1 ORDER BY key) tmp
+ """.stripMargin),
+ Row(null, 3, null, 3, 1, 3, 1, 3) :: Nil)
+
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT
+ | first_valUE(key),
+ | lasT_value(key),
+ | firSt(key),
+ | lASt(key),
+ | first_valUE(key, true),
+ | lasT_value(key, true),
+ | firSt(key, true),
+ | lASt(key, true)
+ |FROM (SELECT key FROM agg1 ORDER BY key DESC) tmp
+ """.stripMargin),
+ Row(3, null, 3, null, 3, 1, 3, 1) :: Nil)
+ }
+ }
+
test("udaf") {
checkAnswer(
sqlContext.sql(