aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorYin Huai <yhuai@databricks.com>2015-06-01 21:40:17 -0700
committerReynold Xin <rxin@databricks.com>2015-06-01 21:40:17 -0700
commite797dba58e8cafdd30683dd1e0263f00ce30ccc0 (patch)
tree5a8e92b38cdd756f35f11a28e2dc87b7272fdd01 /sql
parent7f74bb3bc6d29c53e67af6b6eec336f2d083322a (diff)
downloadspark-e797dba58e8cafdd30683dd1e0263f00ce30ccc0.tar.gz
spark-e797dba58e8cafdd30683dd1e0263f00ce30ccc0.tar.bz2
spark-e797dba58e8cafdd30683dd1e0263f00ce30ccc0.zip
[SPARK-7965] [SPARK-7972] [SQL] Handle expressions containing multiple window expressions and make parser match window frames in case insensitive way
JIRAs: https://issues.apache.org/jira/browse/SPARK-7965 https://issues.apache.org/jira/browse/SPARK-7972 Author: Yin Huai <yhuai@databricks.com> Closes #6524 from yhuai/7965-7972 and squashes the following commits: c12c79c [Yin Huai] Add doc for returned value. de64328 [Yin Huai] Address rxin's comments. fc9b1ad [Yin Huai] wip 2996da4 [Yin Huai] scala style 20b65b7 [Yin Huai] Handle expressions containing multiple window expressions. 9568b21 [Yin Huai] case insensitive matches 41f633d [Yin Huai] Failed test case.
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala108
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala22
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala36
3 files changed, 134 insertions, 32 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index df37889eed..8e9fec7070 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -633,10 +633,10 @@ class Analyzer(
* it into the plan tree.
*/
object ExtractWindowExpressions extends Rule[LogicalPlan] {
- def hasWindowFunction(projectList: Seq[NamedExpression]): Boolean =
+ private def hasWindowFunction(projectList: Seq[NamedExpression]): Boolean =
projectList.exists(hasWindowFunction)
- def hasWindowFunction(expr: NamedExpression): Boolean = {
+ private def hasWindowFunction(expr: NamedExpression): Boolean = {
expr.find {
case window: WindowExpression => true
case _ => false
@@ -644,14 +644,24 @@ class Analyzer(
}
/**
- * From a Seq of [[NamedExpression]]s, extract window expressions and
- * other regular expressions.
+ * From a Seq of [[NamedExpression]]s, extract expressions containing window expressions and
+ * other regular expressions that do not contain any window expression. For example, for
+ * `col1, Sum(col2 + col3) OVER (PARTITION BY col4 ORDER BY col5)`, we will extract
+ * `col1`, `col2 + col3`, `col4`, and `col5` out and replace them appearances in
+ * the window expression as attribute references. So, the first returned value will be
+ * `[Sum(_w0) OVER (PARTITION BY _w1 ORDER BY _w2)]` and the second returned value will be
+ * [col1, col2 + col3 as _w0, col4 as _w1, col5 as _w2].
+ *
+ * @return (seq of expressions containing at lease one window expressions,
+ * seq of non-window expressions)
*/
- def extract(
+ private def extract(
expressions: Seq[NamedExpression]): (Seq[NamedExpression], Seq[NamedExpression]) = {
- // First, we simple partition the input expressions to two part, one having
- // WindowExpressions and another one without WindowExpressions.
- val (windowExpressions, regularExpressions) = expressions.partition(hasWindowFunction)
+ // First, we partition the input expressions to two part. For the first part,
+ // every expression in it contain at least one WindowExpression.
+ // Expressions in the second part do not have any WindowExpression.
+ val (expressionsWithWindowFunctions, regularExpressions) =
+ expressions.partition(hasWindowFunction)
// Then, we need to extract those regular expressions used in the WindowExpression.
// For example, when we have col1 - Sum(col2 + col3) OVER (PARTITION BY col4 ORDER BY col5),
@@ -660,8 +670,8 @@ class Analyzer(
val extractedExprBuffer = new ArrayBuffer[NamedExpression]()
def extractExpr(expr: Expression): Expression = expr match {
case ne: NamedExpression =>
- // If a named expression is not in regularExpressions, add extract it and replace it
- // with an AttributeReference.
+ // If a named expression is not in regularExpressions, add it to
+ // extractedExprBuffer and replace it with an AttributeReference.
val missingExpr =
AttributeSet(Seq(expr)) -- (regularExpressions ++ extractedExprBuffer)
if (missingExpr.nonEmpty) {
@@ -678,8 +688,9 @@ class Analyzer(
withName.toAttribute
}
- // Now, we extract expressions from windowExpressions by using extractExpr.
- val newWindowExpressions = windowExpressions.map {
+ // Now, we extract regular expressions from expressionsWithWindowFunctions
+ // by using extractExpr.
+ val newExpressionsWithWindowFunctions = expressionsWithWindowFunctions.map {
_.transform {
// Extracts children expressions of a WindowFunction (input parameters of
// a WindowFunction).
@@ -705,37 +716,80 @@ class Analyzer(
}.asInstanceOf[NamedExpression]
}
- (newWindowExpressions, regularExpressions ++ extractedExprBuffer)
- }
+ (newExpressionsWithWindowFunctions, regularExpressions ++ extractedExprBuffer)
+ } // end of extract
/**
* Adds operators for Window Expressions. Every Window operator handles a single Window Spec.
*/
- def addWindow(windowExpressions: Seq[NamedExpression], child: LogicalPlan): LogicalPlan = {
- // First, we group window expressions based on their Window Spec.
- val groupedWindowExpression = windowExpressions.groupBy { expr =>
- val windowSpec = expr.collectFirst {
+ private def addWindow(
+ expressionsWithWindowFunctions: Seq[NamedExpression],
+ child: LogicalPlan): LogicalPlan = {
+ // First, we need to extract all WindowExpressions from expressionsWithWindowFunctions
+ // and put those extracted WindowExpressions to extractedWindowExprBuffer.
+ // This step is needed because it is possible that an expression contains multiple
+ // WindowExpressions with different Window Specs.
+ // After extracting WindowExpressions, we need to construct a project list to generate
+ // expressionsWithWindowFunctions based on extractedWindowExprBuffer.
+ // For example, for "sum(a) over (...) / sum(b) over (...)", we will first extract
+ // "sum(a) over (...)" and "sum(b) over (...)" out, and assign "_we0" as the alias to
+ // "sum(a) over (...)" and "_we1" as the alias to "sum(b) over (...)".
+ // Then, the projectList will be [_we0/_we1].
+ val extractedWindowExprBuffer = new ArrayBuffer[NamedExpression]()
+ val newExpressionsWithWindowFunctions = expressionsWithWindowFunctions.map {
+ // We need to use transformDown because we want to trigger
+ // "case alias @ Alias(window: WindowExpression, _)" first.
+ _.transformDown {
+ case alias @ Alias(window: WindowExpression, _) =>
+ // If a WindowExpression has an assigned alias, just use it.
+ extractedWindowExprBuffer += alias
+ alias.toAttribute
+ case window: WindowExpression =>
+ // If there is no alias assigned to the WindowExpressions. We create an
+ // internal column.
+ val withName = Alias(window, s"_we${extractedWindowExprBuffer.length}")()
+ extractedWindowExprBuffer += withName
+ withName.toAttribute
+ }.asInstanceOf[NamedExpression]
+ }
+
+ // Second, we group extractedWindowExprBuffer based on their Window Spec.
+ val groupedWindowExpressions = extractedWindowExprBuffer.groupBy { expr =>
+ val distinctWindowSpec = expr.collect {
case window: WindowExpression => window.windowSpec
+ }.distinct
+
+ // We do a final check and see if we only have a single Window Spec defined in an
+ // expressions.
+ if (distinctWindowSpec.length == 0 ) {
+ failAnalysis(s"$expr does not have any WindowExpression.")
+ } else if (distinctWindowSpec.length > 1) {
+ // newExpressionsWithWindowFunctions only have expressions with a single
+ // WindowExpression. If we reach here, we have a bug.
+ failAnalysis(s"$expr has multiple Window Specifications ($distinctWindowSpec)." +
+ s"Please file a bug report with this error message, stack trace, and the query.")
+ } else {
+ distinctWindowSpec.head
}
- windowSpec.getOrElse(
- failAnalysis(s"$windowExpressions does not have any WindowExpression."))
}.toSeq
- // For every Window Spec, we add a Window operator and set currentChild as the child of it.
+ // Third, for every Window Spec, we add a Window operator and set currentChild as the
+ // child of it.
var currentChild = child
var i = 0
- while (i < groupedWindowExpression.size) {
- val (windowSpec, windowExpressions) = groupedWindowExpression(i)
+ while (i < groupedWindowExpressions.size) {
+ val (windowSpec, windowExpressions) = groupedWindowExpressions(i)
// Set currentChild to the newly created Window operator.
currentChild = Window(currentChild.output, windowExpressions, windowSpec, currentChild)
- // Move to next WindowExpression.
+ // Move to next Window Spec.
i += 1
}
- // We return the top operator.
- currentChild
- }
+ // Finally, we create a Project to output currentChild's output
+ // newExpressionsWithWindowFunctions.
+ Project(currentChild.output ++ newExpressionsWithWindowFunctions, currentChild)
+ } // end of addWindow
// We have to use transformDown at here to make sure the rule of
// "Aggregate with Having clause" will be triggered.
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index 253bf11252..a5ca3613c5 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -1561,6 +1561,10 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
""".stripMargin)
}
+ /* Case insensitive matches for Window Specification */
+ val PRECEDING = "(?i)preceding".r
+ val FOLLOWING = "(?i)following".r
+ val CURRENT = "(?i)current".r
def nodesToWindowSpecification(nodes: Seq[ASTNode]): WindowSpec = nodes match {
case Token(windowName, Nil) :: Nil =>
// Refer to a window spec defined in the window clause.
@@ -1614,11 +1618,19 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
} else {
val frameType = rowFrame.map(_ => RowFrame).getOrElse(RangeFrame)
def nodeToBoundary(node: Node): FrameBoundary = node match {
- case Token("preceding", Token(count, Nil) :: Nil) =>
- if (count == "unbounded") UnboundedPreceding else ValuePreceding(count.toInt)
- case Token("following", Token(count, Nil) :: Nil) =>
- if (count == "unbounded") UnboundedFollowing else ValueFollowing(count.toInt)
- case Token("current", Nil) => CurrentRow
+ case Token(PRECEDING(), Token(count, Nil) :: Nil) =>
+ if (count.toLowerCase() == "unbounded") {
+ UnboundedPreceding
+ } else {
+ ValuePreceding(count.toInt)
+ }
+ case Token(FOLLOWING(), Token(count, Nil) :: Nil) =>
+ if (count.toLowerCase() == "unbounded") {
+ UnboundedFollowing
+ } else {
+ ValueFollowing(count.toInt)
+ }
+ case Token(CURRENT(), Nil) => CurrentRow
case _ =>
throw new NotImplementedError(
s"""No parse rules for the Window Frame Boundary based on Node ${node.getName}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index 27863a6014..aba3becb1b 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -780,6 +780,42 @@ class SQLQuerySuite extends QueryTest {
).map(i => Row(i._1, i._2, i._3, i._4)))
}
+ test("window function: multiple window expressions in a single expression") {
+ val nums = sparkContext.parallelize(1 to 10).map(x => (x, x % 2)).toDF("x", "y")
+ nums.registerTempTable("nums")
+
+ val expected =
+ Row(1, 1, 1, 55, 1, 57) ::
+ Row(0, 2, 3, 55, 2, 60) ::
+ Row(1, 3, 6, 55, 4, 65) ::
+ Row(0, 4, 10, 55, 6, 71) ::
+ Row(1, 5, 15, 55, 9, 79) ::
+ Row(0, 6, 21, 55, 12, 88) ::
+ Row(1, 7, 28, 55, 16, 99) ::
+ Row(0, 8, 36, 55, 20, 111) ::
+ Row(1, 9, 45, 55, 25, 125) ::
+ Row(0, 10, 55, 55, 30, 140) :: Nil
+
+ val actual = sql(
+ """
+ |SELECT
+ | y,
+ | x,
+ | sum(x) OVER w1 AS running_sum,
+ | sum(x) OVER w2 AS total_sum,
+ | sum(x) OVER w3 AS running_sum_per_y,
+ | ((sum(x) OVER w1) + (sum(x) OVER w2) + (sum(x) OVER w3)) as combined2
+ |FROM nums
+ |WINDOW w1 AS (ORDER BY x ROWS BETWEEN UnBOUNDED PRECEDiNG AND CuRRENT RoW),
+ | w2 AS (ORDER BY x ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOuNDED FoLLOWING),
+ | w3 AS (PARTITION BY y ORDER BY x ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)
+ """.stripMargin)
+
+ checkAnswer(actual, expected)
+
+ dropTempTable("nums")
+ }
+
test("test case key when") {
(1 to 5).map(i => (i, i.toString)).toDF("k", "v").registerTempTable("t")
checkAnswer(