aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala3
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala45
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala34
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala (renamed from sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLWindowFunctionSuite.scala)66
4 files changed, 108 insertions, 40 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index d1d2c59cae..61162ccdba 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1787,7 +1787,8 @@ class Analyzer(
s @ WindowSpecDefinition(_, o, UnspecifiedFrame))
if wf.frame != UnspecifiedFrame =>
WindowExpression(wf, s.copy(frameSpecification = wf.frame))
- case we @ WindowExpression(e, s @ WindowSpecDefinition(_, o, UnspecifiedFrame)) =>
+ case we @ WindowExpression(e, s @ WindowSpecDefinition(_, o, UnspecifiedFrame))
+ if e.resolved =>
val frame = SpecifiedWindowFrame.defaultWindowFrame(o.nonEmpty, acceptWindowFrame = true)
we.copy(windowSpec = s.copy(frameSpecification = frame))
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
index e35192ca2d..6806591f68 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
@@ -321,8 +321,7 @@ abstract class OffsetWindowFunction
val input: Expression
/**
- * Default result value for the function when the input expression returns NULL. The default will
- * evaluated against the current row instead of the offset row.
+ * Default result value for the function when the 'offset'th row does not exist.
*/
val default: Expression
@@ -348,7 +347,7 @@ abstract class OffsetWindowFunction
*/
override def foldable: Boolean = false
- override def nullable: Boolean = default == null || default.nullable
+ override def nullable: Boolean = default == null || default.nullable || input.nullable
override lazy val frame = {
// This will be triggered by the Analyzer.
@@ -373,20 +372,22 @@ abstract class OffsetWindowFunction
}
/**
- * The Lead function returns the value of 'x' at 'offset' rows after the current row in the window.
- * Offsets start at 0, which is the current row. The offset must be constant integer value. The
- * default offset is 1. When the value of 'x' is null at the offset, or when the offset is larger
- * than the window, the default expression is evaluated.
- *
- * This documentation has been based upon similar documentation for the Hive and Presto projects.
+ * The Lead function returns the value of 'x' at the 'offset'th row after the current row in
+ * the window. Offsets start at 0, which is the current row. The offset must be constant
+ * integer value. The default offset is 1. When the value of 'x' is null at the 'offset'th row,
+ * null is returned. If there is no such offset row, the default expression is evaluated.
*
* @param input expression to evaluate 'offset' rows after the current row.
* @param offset rows to jump ahead in the partition.
- * @param default to use when the input value is null or when the offset is larger than the window.
+ * @param default to use when the offset is larger than the window. The default value is null.
*/
@ExpressionDescription(usage =
- """_FUNC_(input, offset, default) - LEAD returns the value of 'x' at 'offset' rows
- after the current row in the window""")
+ """_FUNC_(input, offset, default) - LEAD returns the value of 'x' at the 'offset'th row
+ after the current row in the window.
+ The default value of 'offset' is 1 and the default value of 'default' is null.
+ If the value of 'x' at the 'offset'th row is null, null is returned.
+ If there is no such offset row (e.g. when the offset is 1, the last row of the window
+ does not have any subsequent row), 'default' is returned.""")
case class Lead(input: Expression, offset: Expression, default: Expression)
extends OffsetWindowFunction {
@@ -400,20 +401,22 @@ case class Lead(input: Expression, offset: Expression, default: Expression)
}
/**
- * The Lag function returns the value of 'x' at 'offset' rows before the current row in the window.
- * Offsets start at 0, which is the current row. The offset must be constant integer value. The
- * default offset is 1. When the value of 'x' is null at the offset, or when the offset is smaller
- * than the window, the default expression is evaluated.
- *
- * This documentation has been based upon similar documentation for the Hive and Presto projects.
+ * The Lag function returns the value of 'x' at the 'offset'th row before the current row in
+ * the window. Offsets start at 0, which is the current row. The offset must be constant
+ * integer value. The default offset is 1. When the value of 'x' is null at the 'offset'th row,
+ * null is returned. If there is no such offset row, the default expression is evaluated.
*
* @param input expression to evaluate 'offset' rows before the current row.
* @param offset rows to jump back in the partition.
- * @param default to use when the input value is null or when the offset is smaller than the window.
+ * @param default to use when the offset row does not exist.
*/
@ExpressionDescription(usage =
- """_FUNC_(input, offset, default) - LAG returns the value of 'x' at 'offset' rows
- before the current row in the window""")
+ """_FUNC_(input, offset, default) - LAG returns the value of 'x' at the 'offset'th row
+ before the current row in the window.
+ The default value of 'offset' is 1 and the default value of 'default' is null.
+ If the value of 'x' at the 'offset'th row is null, null is returned.
+ If there is no such offset row (e.g. when the offset is 1, the first row of the window
+ does not have any previous row), 'default' is returned.""")
case class Lag(input: Expression, offset: Expression, default: Expression)
extends OffsetWindowFunction {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala
index 93f007f5b3..7149603018 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala
@@ -582,25 +582,43 @@ private[execution] final class OffsetWindowFunctionFrame(
/** Row used to combine the offset and the current row. */
private[this] val join = new JoinedRow
- /** Create the projection. */
+ /**
+ * Create the projection used when the offset row exists.
+ * Please note that this project always respect null input values (like PostgreSQL).
+ */
private[this] val projection = {
// Collect the expressions and bind them.
val inputAttrs = inputSchema.map(_.withNullability(true))
- val numInputAttributes = inputAttrs.size
val boundExpressions = Seq.fill(ordinal)(NoOp) ++ expressions.toSeq.map {
case e: OffsetWindowFunction =>
val input = BindReferences.bindReference(e.input, inputAttrs)
+ input
+ case e =>
+ BindReferences.bindReference(e, inputAttrs)
+ }
+
+ // Create the projection.
+ newMutableProjection(boundExpressions, Nil).target(target)
+ }
+
+ /** Create the projection used when the offset row DOES NOT exists. */
+ private[this] val fillDefaultValue = {
+ // Collect the expressions and bind them.
+ val inputAttrs = inputSchema.map(_.withNullability(true))
+ val numInputAttributes = inputAttrs.size
+ val boundExpressions = Seq.fill(ordinal)(NoOp) ++ expressions.toSeq.map {
+ case e: OffsetWindowFunction =>
if (e.default == null || e.default.foldable && e.default.eval() == null) {
- // Without default value.
- input
+ // The default value is null.
+ Literal.create(null, e.dataType)
} else {
- // With default value.
+ // The default value is an expression.
val default = BindReferences.bindReference(e.default, inputAttrs).transform {
// Shift the input reference to its default version.
case BoundReference(o, dataType, nullable) =>
BoundReference(o + numInputAttributes, dataType, nullable)
}
- org.apache.spark.sql.catalyst.expressions.Coalesce(input :: default :: Nil)
+ default
}
case e =>
BindReferences.bindReference(e, inputAttrs)
@@ -625,10 +643,12 @@ private[execution] final class OffsetWindowFunctionFrame(
if (inputIndex >= 0 && inputIndex < input.size) {
val r = input.next()
join(r, current)
+ projection(join)
} else {
join(emptyRow, current)
+ // Use default values since the offset row does not exist.
+ fillDefaultValue(join)
}
- projection(join)
inputIndex += 1
}
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLWindowFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala
index 77e97dff8c..d3cfa953a3 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLWindowFunctionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala
@@ -15,12 +15,10 @@
* limitations under the License.
*/
-package org.apache.spark.sql.hive.execution
+package org.apache.spark.sql.execution
import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
-import org.apache.spark.sql.hive.test.TestHiveSingleton
-import org.apache.spark.sql.test.SQLTestUtils
-
+import org.apache.spark.sql.test.SharedSQLContext
case class WindowData(month: Int, area: String, product: Int)
@@ -28,8 +26,9 @@ case class WindowData(month: Int, area: String, product: Int)
/**
* Test suite for SQL window functions.
*/
-class SQLWindowFunctionSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
- import spark.implicits._
+class SQLWindowFunctionSuite extends QueryTest with SharedSQLContext {
+
+ import testImplicits._
test("window function: udaf with aggregate expression") {
val data = Seq(
@@ -357,14 +356,59 @@ class SQLWindowFunctionSuite extends QueryTest with SQLTestUtils with TestHiveSi
}
test("SPARK-7595: Window will cause resolve failed with self join") {
- sql("SELECT * FROM src") // Force loading of src table.
-
checkAnswer(sql(
"""
|with
- | v1 as (select key, count(value) over (partition by key) cnt_val from src),
+ | v0 as (select 0 as key, 1 as value),
+ | v1 as (select key, count(value) over (partition by key) cnt_val from v0),
| v2 as (select v1.key, v1_lag.cnt_val from v1, v1 v1_lag where v1.key = v1_lag.key)
- | select * from v2 order by key limit 1
- """.stripMargin), Row(0, 3))
+ | select key, cnt_val from v2 order by key limit 1
+ """.stripMargin), Row(0, 1))
+ }
+
+ test("SPARK-16633: lead/lag should return the default value if the offset row does not exist") {
+ checkAnswer(sql(
+ """
+ |SELECT
+ | lag(123, 100, 321) OVER (ORDER BY id) as lag,
+ | lead(123, 100, 321) OVER (ORDER BY id) as lead
+ |FROM (SELECT 1 as id) tmp
+ """.stripMargin),
+ Row(321, 321))
+
+ checkAnswer(sql(
+ """
+ |SELECT
+ | lag(123, 100, a) OVER (ORDER BY id) as lag,
+ | lead(123, 100, a) OVER (ORDER BY id) as lead
+ |FROM (SELECT 1 as id, 2 as a) tmp
+ """.stripMargin),
+ Row(2, 2))
+ }
+
+ test("lead/lag should respect null values") {
+ checkAnswer(sql(
+ """
+ |SELECT
+ | b,
+ | lag(a, 1, 321) OVER (ORDER BY b) as lag,
+ | lead(a, 1, 321) OVER (ORDER BY b) as lead
+ |FROM (SELECT cast(null as int) as a, 1 as b
+ | UNION ALL
+ | select cast(null as int) as id, 2 as b) tmp
+ """.stripMargin),
+ Row(1, 321, null) :: Row(2, null, 321) :: Nil)
+
+ checkAnswer(sql(
+ """
+ |SELECT
+ | b,
+ | lag(a, 1, c) OVER (ORDER BY b) as lag,
+ | lead(a, 1, c) OVER (ORDER BY b) as lead
+ |FROM (SELECT cast(null as int) as a, 1 as b, 3 as c
+ | UNION ALL
+ | select cast(null as int) as id, 2 as b, 4 as c) tmp
+ """.stripMargin),
+ Row(1, 3, null) :: Row(2, null, 4) :: Nil)
}
}