path: root/sql/catalyst
authorYin Huai <yhuai@databricks.com>2015-05-06 10:43:00 -0700
committerMichael Armbrust <michael@databricks.com>2015-05-06 10:43:00 -0700
commitf2c47082c3412a4cf8cbabe12585147c5ec3ea40 (patch)
treee2c9b4112724d2ab79122486f1e2d98375081329 /sql/catalyst
parentc3eb441f5487c9b6476e1d6e2a2d852dcc43b986 (diff)
downloadspark-f2c47082c3412a4cf8cbabe12585147c5ec3ea40.tar.gz
spark-f2c47082c3412a4cf8cbabe12585147c5ec3ea40.tar.bz2
spark-f2c47082c3412a4cf8cbabe12585147c5ec3ea40.zip
[SPARK-1442] [SQL] Window Function Support for Spark SQL
This PR adds support for window functions to Spark SQL (specifically the OVER and WINDOW clauses). For every expression having an OVER clause, we use a WindowExpression as the container of a WindowFunction and the corresponding WindowSpecDefinition (the definition of a window specification, i.e. the partition specification, order specification, and frame specification appearing in an OVER clause).

# Implementation #

The high-level workflow of the implementation is as follows.

* Query parsing: During parsing, all WindowExpressions are originally placed in the projectList of a Project operator or the aggregateExpressions of an Aggregate operator. This keeps our changes simple and keeps all of the parsing rules for window functions in a single place (nodesToWindowSpecification). For the WINDOW clause in a query, we use a WithWindowDefinition operator as the container of the mapping from the name of a window specification to its WindowSpecDefinition. This change is similar to our common table expression support.
* Analysis: The query analysis process has three steps for window functions.
  * Resolve all WindowSpecReferences by replacing them with WindowSpecDefinitions according to the mapping table stored in the WithWindowDefinition node.
  * Resolve WindowFunctions in the projectList of a Project operator or the aggregateExpressions of an Aggregate operator. For this PR, we use Hive's functions for window functions because we will have a major refactoring of our internal UDAFs, and it is better to switch our UDAFs after that refactoring work.
  * Once we have resolved all WindowFunctions, we use ExtractWindowExpressions to extract WindowExpressions from the projectList and aggregateExpressions and then create a Window operator for every distinct WindowSpecDefinition. With this choice, at execution time we can rely on the Exchange operator to do all of the work of reorganizing the table, and we do not need to worry about it in the physical Window operator.

An example analyzed plan is shown below.

```
sql("""
  SELECT
    year, country, product, sales,
    avg(sales) over(partition by product) avg_product,
    sum(sales) over(partition by country) sum_country
  FROM sales
  ORDER BY year, country, product
""").explain(true)

== Analyzed Logical Plan ==
Sort [year#34 ASC,country#35 ASC,product#36 ASC], true
 Project [year#34,country#35,product#36,sales#37,avg_product#27,sum_country#28]
  Window [year#34,country#35,product#36,sales#37,avg_product#27], [HiveWindowFunction#org.apache.hadoop.hive.ql.udf.generic.GenericUDAFSum(sales#37) WindowSpecDefinition [country#35], [], ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS sum_country#28], WindowSpecDefinition [country#35], [], ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
   Window [year#34,country#35,product#36,sales#37], [HiveWindowFunction#org.apache.hadoop.hive.ql.udf.generic.GenericUDAFAverage(sales#37) WindowSpecDefinition [product#36], [], ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS avg_product#27], WindowSpecDefinition [product#36], [], ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
    Project [year#34,country#35,product#36,sales#37]
     MetastoreRelation default, sales, None
```

* Query planning: During query planning, we simply generate the physical Window operator based on the logical Window operator. Then, to prepare the executedPlan, the EnsureRequirements rule adds Exchange and Sort operators where necessary. The EnsureRequirements rule analyzes the data properties and tries not to add unnecessary shuffles and sorts. The physical plan for the above example query is shown below.

```
== Physical Plan ==
Sort [year#34 ASC,country#35 ASC,product#36 ASC], true
 Exchange (RangePartitioning [year#34 ASC,country#35 ASC,product#36 ASC], 200), []
  Window [year#34,country#35,product#36,sales#37,avg_product#27], [HiveWindowFunction#org.apache.hadoop.hive.ql.udf.generic.GenericUDAFSum(sales#37) WindowSpecDefinition [country#35], [], ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS sum_country#28], WindowSpecDefinition [country#35], [], ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
   Exchange (HashPartitioning [country#35], 200), [country#35 ASC]
    Window [year#34,country#35,product#36,sales#37], [HiveWindowFunction#org.apache.hadoop.hive.ql.udf.generic.GenericUDAFAverage(sales#37) WindowSpecDefinition [product#36], [], ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS avg_product#27], WindowSpecDefinition [product#36], [], ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
     Exchange (HashPartitioning [product#36], 200), [product#36 ASC]
      HiveTableScan [year#34,country#35,product#36,sales#37], (MetastoreRelation default, sales, None), None
```

* Execution time: At execution time, a physical Window operator buffers all rows in the partition specified by the partition spec of an OVER clause. If necessary, it also maintains a sliding window frame. The current implementation tries to buffer the input parameters of a window function according to the window frame, to avoid evaluating a row multiple times.

# Future work #

Here are three improvements that are not hard to add:

* Take advantage of the window frame specification to reduce the number of rows buffered in the physical Window operator. For some cases, we only need to buffer the rows appearing in the sliding window. But for other cases, we will not be able to reduce the number of rows buffered (e.g. ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING).
* When a RANGE frame is used, for <value> PRECEDING and <value> FOLLOWING, it will be great if the <value> part can be an expression (we can start with Literal). Then, when the data type of the ORDER BY expression is a FractionalType, we can support FractionalType as the type of <value> (<value> still needs to evaluate to a positive value).
* When a RANGE frame is used, we need to support DateType and TimestampType as the data type of the expression appearing in the order specification. Then, the <value> part of <value> PRECEDING and <value> FOLLOWING can support interval types (once we support them).

This is a joint work with guowei2 and yhuai. Thanks hbutani and hvanhovell for their comments. Thanks scwf for his comments and unit tests.

Author: Yin Huai <yhuai@databricks.com>

Closes #5604 from guowei2/windowImplement and squashes the following commits:

76fe1c8 [Yin Huai] Implementation.
aa2b0ae [Yin Huai] Tests.
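Two small usage sketches of what this enables, written against the same sales table as in the example above. They are not taken from the patch; identifiers such as `avg_product_w` and `sales_3yr` are illustrative. The first resolves a named specification through the WINDOW clause (the WithWindowDefinition / WindowSpecReference path), and the second uses an explicit sliding ROWS frame of the kind discussed under future work (the current operator still buffers the whole partition).

```
// Sketch 1: a named window specification resolved via the WINDOW clause.
sql("""
  SELECT year, country, product, sales,
         avg(sales) OVER w AS avg_product_w
  FROM sales
  WINDOW w AS (PARTITION BY product)
""").explain(true)

// Sketch 2: an explicit sliding frame; each output value depends only on the
// current row and the two preceding rows within its country, ordered by year.
sql("""
  SELECT year, country, product, sales,
         sum(sales) OVER (PARTITION BY country ORDER BY year
                          ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) AS sales_3yr
  FROM sales
""").explain(true)
```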
Diffstat (limited to 'sql/catalyst')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 200
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala | 5
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala | 94
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala | 340
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala | 30
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala | 23
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala | 44
7 files changed, 733 insertions, 3 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 5e42b409dc..7b543b6c2a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst.analysis
+import scala.collection.mutable.ArrayBuffer
+
import org.apache.spark.util.collection.OpenHashSet
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
@@ -61,6 +63,7 @@ class Analyzer(
ResolveGenerate ::
ImplicitGenerate ::
ResolveFunctions ::
+ ExtractWindowExpressions ::
GlobalAggregates ::
UnresolvedHavingClauseAttributes ::
TrimGroupingAliases ::
@@ -529,6 +532,203 @@ class Analyzer(
makeGeneratorOutput(p.generator, p.generatorOutput), p.child)
}
}
+
+ /**
+ * Extracts [[WindowExpression]]s from the projectList of a [[Project]] operator and
+ * aggregateExpressions of an [[Aggregate]] operator and creates individual [[Window]]
+ * operators for every distinct [[WindowSpecDefinition]].
+ *
+ * This rule handles three cases:
+ * - A [[Project]] having [[WindowExpression]]s in its projectList;
+ * - An [[Aggregate]] having [[WindowExpression]]s in its aggregateExpressions.
+ * - A [[Filter]]->[[Aggregate]] pattern representing GROUP BY with a HAVING
+ * clause and the [[Aggregate]] has [[WindowExpression]]s in its aggregateExpressions.
+ * Note: If there is a GROUP BY clause in the query, aggregations and corresponding
+ * filters (expressions in the HAVING clause) should be evaluated before any
+ * [[WindowExpression]]. If a query has SELECT DISTINCT, the DISTINCT part should be
+ * evaluated after all [[WindowExpression]]s.
+ *
+ * For every case, the transformation works as follows:
+ * 1. For a list of [[Expression]]s (a projectList or an aggregateExpressions), partitions
+ * it into two lists of [[Expression]]s, one for all [[WindowExpression]]s and another for
+ * all regular expressions.
+ * 2. For all [[WindowExpression]]s, groups them based on their [[WindowSpecDefinition]]s.
+ * 3. For every distinct [[WindowSpecDefinition]], creates a [[Window]] operator and inserts
+ * it into the plan tree.
+ */
+ object ExtractWindowExpressions extends Rule[LogicalPlan] {
+ def hasWindowFunction(projectList: Seq[NamedExpression]): Boolean =
+ projectList.exists(hasWindowFunction)
+
+ def hasWindowFunction(expr: NamedExpression): Boolean = {
+ expr.find {
+ case window: WindowExpression => true
+ case _ => false
+ }.isDefined
+ }
+
+ /**
+ * From a Seq of [[NamedExpression]]s, extract window expressions and
+ * other regular expressions.
+ */
+ def extract(
+ expressions: Seq[NamedExpression]): (Seq[NamedExpression], Seq[NamedExpression]) = {
+ // First, we simply partition the input expressions into two parts: one having
+ // WindowExpressions and another without WindowExpressions.
+ val (windowExpressions, regularExpressions) = expressions.partition(hasWindowFunction)
+
+ // Then, we need to extract those regular expressions used in the WindowExpression.
+ // For example, when we have col1 - Sum(col2 + col3) OVER (PARTITION BY col4 ORDER BY col5),
+ // we need to make sure that col1 to col5 are all projected from the child of the Window
+ // operator.
+ val extractedExprBuffer = new ArrayBuffer[NamedExpression]()
+ def extractExpr(expr: Expression): Expression = expr match {
+ case ne: NamedExpression =>
+ // If a named expression is not in regularExpressions, extract it and replace it
+ // with an AttributeReference.
+ val missingExpr =
+ AttributeSet(Seq(expr)) -- (regularExpressions ++ extractedExprBuffer)
+ if (missingExpr.nonEmpty) {
+ extractedExprBuffer += ne
+ }
+ ne.toAttribute
+ case e: Expression if e.foldable =>
+ e // No need to create an attribute reference if it will be evaluated as a Literal.
+ case e: Expression =>
+ // For other expressions, we extract them and replace them with an AttributeReference (with
+ // an internal column name, e.g. "_w0").
+ val withName = Alias(e, s"_w${extractedExprBuffer.length}")()
+ extractedExprBuffer += withName
+ withName.toAttribute
+ }
+
+ // Now, we extract expressions from windowExpressions by using extractExpr.
+ val newWindowExpressions = windowExpressions.map {
+ _.transform {
+ // Extracts children expressions of a WindowFunction (input parameters of
+ // a WindowFunction).
+ case wf : WindowFunction =>
+ val newChildren = wf.children.map(extractExpr(_))
+ wf.withNewChildren(newChildren)
+
+ // Extracts expressions from the partition spec and order spec.
+ case wsc @ WindowSpecDefinition(partitionSpec, orderSpec, _) =>
+ val newPartitionSpec = partitionSpec.map(extractExpr(_))
+ val newOrderSpec = orderSpec.map { so =>
+ val newChild = extractExpr(so.child)
+ so.copy(child = newChild)
+ }
+ wsc.copy(partitionSpec = newPartitionSpec, orderSpec = newOrderSpec)
+
+ // Extracts AggregateExpression. For example, for SUM(x) - Sum(y) OVER (...),
+ // we need to extract SUM(x).
+ case agg: AggregateExpression =>
+ val withName = Alias(agg, s"_w${extractedExprBuffer.length}")()
+ extractedExprBuffer += withName
+ withName.toAttribute
+ }.asInstanceOf[NamedExpression]
+ }
+
+ (newWindowExpressions, regularExpressions ++ extractedExprBuffer)
+ }
+
+ /**
+ * Adds operators for Window Expressions. Every Window operator handles a single Window Spec.
+ */
+ def addWindow(windowExpressions: Seq[NamedExpression], child: LogicalPlan): LogicalPlan = {
+ // First, we group window expressions based on their Window Spec.
+ val groupedWindowExpression = windowExpressions.groupBy { expr =>
+ val windowExpression = expr.find {
+ case window: WindowExpression => true
+ case other => false
+ }.map(_.asInstanceOf[WindowExpression].windowSpec)
+ windowExpression.getOrElse(
+ failAnalysis(s"$windowExpressions does not have any WindowExpression."))
+ }.toSeq
+
+ // For every Window Spec, we add a Window operator and set currentChild as the child of it.
+ var currentChild = child
+ var i = 0
+ while (i < groupedWindowExpression.size) {
+ val (windowSpec, windowExpressions) = groupedWindowExpression(i)
+ // Set currentChild to the newly created Window operator.
+ currentChild = Window(currentChild.output, windowExpressions, windowSpec, currentChild)
+
+ // Move to next WindowExpression.
+ i += 1
+ }
+
+ // We return the top operator.
+ currentChild
+ }
+
+ // We have to use transformDown here to make sure the
+ // "Aggregate with Having clause" case will be triggered.
+ def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
+ // Lookup WindowSpecDefinitions. This rule works with unresolved children.
+ case WithWindowDefinition(windowDefinitions, child) =>
+ child.transform {
+ case plan => plan.transformExpressions {
+ case UnresolvedWindowExpression(c, WindowSpecReference(windowName)) =>
+ val errorMessage =
+ s"Window specification $windowName is not defined in the WINDOW clause."
+ val windowSpecDefinition =
+ windowDefinitions
+ .get(windowName)
+ .getOrElse(failAnalysis(errorMessage))
+ WindowExpression(c, windowSpecDefinition)
+ }
+ }
+
+ // Aggregate with Having clause. This rule works with an unresolved Aggregate because
+ // a resolved Aggregate will not have Window Functions.
+ case f @ Filter(condition, a @ Aggregate(groupingExprs, aggregateExprs, child))
+ if child.resolved &&
+ hasWindowFunction(aggregateExprs) &&
+ !a.expressions.exists(!_.resolved) =>
+ val (windowExpressions, aggregateExpressions) = extract(aggregateExprs)
+ // Create an Aggregate operator to evaluate aggregation functions.
+ val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child)
+ // Add a Filter operator for conditions in the Having clause.
+ val withFilter = Filter(condition, withAggregate)
+ val withWindow = addWindow(windowExpressions, withFilter)
+
+ // Finally, generate output columns according to the original projectList.
+ val finalProjectList = aggregateExprs.map (_.toAttribute)
+ Project(finalProjectList, withWindow)
+
+ case p: LogicalPlan if !p.childrenResolved => p
+
+ // Aggregate without Having clause.
+ case a @ Aggregate(groupingExprs, aggregateExprs, child)
+ if hasWindowFunction(aggregateExprs) &&
+ !a.expressions.exists(!_.resolved) =>
+ val (windowExpressions, aggregateExpressions) = extract(aggregateExprs)
+ // Create an Aggregate operator to evaluate aggregation functions.
+ val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child)
+ // Add Window operators.
+ val withWindow = addWindow(windowExpressions, withAggregate)
+
+ // Finally, generate output columns according to the original projectList.
+ val finalProjectList = aggregateExprs.map (_.toAttribute)
+ Project(finalProjectList, withWindow)
+
+ // We only extract Window Expressions after all expressions of the Project
+ // have been resolved.
+ case p @ Project(projectList, child)
+ if hasWindowFunction(projectList) && !p.expressions.exists(!_.resolved) =>
+ val (windowExpressions, regularExpressions) = extract(projectList)
+ // We add a project to get all needed expressions for window expressions from the child
+ // of the original Project operator.
+ val withProject = Project(regularExpressions, child)
+ // Add Window operators.
+ val withWindow = addWindow(windowExpressions, withProject)
+
+ // Finally, generate output columns according to the original projectList.
+ val finalProjectList = projectList.map (_.toAttribute)
+ Project(finalProjectList, withWindow)
+ }
+ }
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 2381689e17..c8288c6767 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -70,6 +70,11 @@ trait CheckAnalysis {
failAnalysis(
s"invalid expression ${b.prettyString} " +
s"between ${b.left.simpleString} and ${b.right.simpleString}")
+
+ case w @ WindowExpression(windowFunction, windowSpec) if windowSpec.validate.nonEmpty =>
+ // The window spec is not valid.
+ val reason = windowSpec.validate.get
+ failAnalysis(s"Window specification $windowSpec is not valid because $reason")
}
operator match {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index c2866cd955..8cae548279 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -548,3 +548,97 @@ class JoinedRow5 extends Row {
}
}
}
+
+/**
+ * JIT HACK: Replace with macros
+ */
+class JoinedRow6 extends Row {
+ private[this] var row1: Row = _
+ private[this] var row2: Row = _
+
+ def this(left: Row, right: Row) = {
+ this()
+ row1 = left
+ row2 = right
+ }
+
+ /** Updates this JoinedRow to point at two new base rows. Returns itself. */
+ def apply(r1: Row, r2: Row): Row = {
+ row1 = r1
+ row2 = r2
+ this
+ }
+
+ /** Updates this JoinedRow by updating its left base row. Returns itself. */
+ def withLeft(newLeft: Row): Row = {
+ row1 = newLeft
+ this
+ }
+
+ /** Updates this JoinedRow by updating its right base row. Returns itself. */
+ def withRight(newRight: Row): Row = {
+ row2 = newRight
+ this
+ }
+
+ override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq
+
+ override def length: Int = row1.length + row2.length
+
+ override def apply(i: Int): Any =
+ if (i < row1.length) row1(i) else row2(i - row1.length)
+
+ override def isNullAt(i: Int): Boolean =
+ if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length)
+
+ override def getInt(i: Int): Int =
+ if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length)
+
+ override def getLong(i: Int): Long =
+ if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length)
+
+ override def getDouble(i: Int): Double =
+ if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length)
+
+ override def getBoolean(i: Int): Boolean =
+ if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length)
+
+ override def getShort(i: Int): Short =
+ if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length)
+
+ override def getByte(i: Int): Byte =
+ if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length)
+
+ override def getFloat(i: Int): Float =
+ if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length)
+
+ override def getString(i: Int): String =
+ if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length)
+
+ override def getAs[T](i: Int): T =
+ if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length)
+
+ override def copy(): Row = {
+ val totalSize = row1.length + row2.length
+ val copiedValues = new Array[Any](totalSize)
+ var i = 0
+ while(i < totalSize) {
+ copiedValues(i) = apply(i)
+ i += 1
+ }
+ new GenericRow(copiedValues)
+ }
+
+ override def toString: String = {
+ // Make sure toString never throws NullPointerException.
+ if ((row1 eq null) && (row2 eq null)) {
+ "[ empty row ]"
+ } else if (row1 eq null) {
+ row2.mkString("[", ",", "]")
+ } else if (row2 eq null) {
+ row1.mkString("[", ",", "]")
+ } else {
+ mkString("[", ",", "]")
+ }
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
new file mode 100644
index 0000000000..099d67ca7f
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
@@ -0,0 +1,340 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.catalyst.analysis.UnresolvedException
+import org.apache.spark.sql.catalyst.errors.TreeNodeException
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types.{NumericType, DataType}
+
+/**
+ * The trait of the Window Specification (specified in the OVER clause or WINDOW clause) for
+ * Window Functions.
+ */
+sealed trait WindowSpec
+
+/**
+ * The specification for a window function.
+ * @param partitionSpec Defines how input rows are partitioned.
+ * @param orderSpec Defines the ordering of rows within a partition.
+ * @param frameSpecification Defines the window frame within a partition.
+ */
+case class WindowSpecDefinition(
+ partitionSpec: Seq[Expression],
+ orderSpec: Seq[SortOrder],
+ frameSpecification: WindowFrame) extends Expression with WindowSpec {
+
+ def validate: Option[String] = frameSpecification match {
+ case UnspecifiedFrame =>
+ Some("Found a UnspecifiedFrame. It should be converted to a SpecifiedWindowFrame " +
+ "during analysis. Please file a bug report.")
+ case frame: SpecifiedWindowFrame => frame.validate.orElse {
+ def checkValueBasedBoundaryForRangeFrame(): Option[String] = {
+ if (orderSpec.length > 1) {
+ // It is not allowed to have a value-based PRECEDING and FOLLOWING
+ // as the boundary of a Range Window Frame.
+ Some("This Range Window Frame only accepts at most one ORDER BY expression.")
+ } else if (orderSpec.nonEmpty && !orderSpec.head.dataType.isInstanceOf[NumericType]) {
+ Some("The data type of the expression in the ORDER BY clause should be a numeric type.")
+ } else {
+ None
+ }
+ }
+
+ (frame.frameType, frame.frameStart, frame.frameEnd) match {
+ case (RangeFrame, vp: ValuePreceding, _) => checkValueBasedBoundaryForRangeFrame()
+ case (RangeFrame, vf: ValueFollowing, _) => checkValueBasedBoundaryForRangeFrame()
+ case (RangeFrame, _, vp: ValuePreceding) => checkValueBasedBoundaryForRangeFrame()
+ case (RangeFrame, _, vf: ValueFollowing) => checkValueBasedBoundaryForRangeFrame()
+ case (_, _, _) => None
+ }
+ }
+ }
+
+ type EvaluatedType = Any
+
+ override def children: Seq[Expression] = partitionSpec ++ orderSpec
+
+ override lazy val resolved: Boolean =
+ childrenResolved && frameSpecification.isInstanceOf[SpecifiedWindowFrame]
+
+
+ override def toString: String = simpleString
+
+ override def eval(input: Row): EvaluatedType = throw new UnsupportedOperationException
+ override def nullable: Boolean = true
+ override def foldable: Boolean = false
+ override def dataType: DataType = throw new UnsupportedOperationException
+}
+
+/**
+ * A Window specification reference that refers to the [[WindowSpecDefinition]] defined
+ * under the name `name`.
+ */
+case class WindowSpecReference(name: String) extends WindowSpec
+
+/**
+ * The trait used to represent the type of a Window Frame.
+ */
+sealed trait FrameType
+
+/**
+ * RowFrame treats rows in a partition individually. When a [[ValuePreceding]]
+ * or a [[ValueFollowing]] is used as its [[FrameBoundary]], the value is considered
+ * as a physical offset.
+ * For example, `ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING` represents a 3-row frame,
+ * from the row that precedes the current row to the row that follows it.
+ */
+case object RowFrame extends FrameType
+
+/**
+ * RangeFrame treats rows in a partition as groups of peers.
+ * All rows having the same `ORDER BY` ordering are considered as peers.
+ * When a [[ValuePreceding]] or a [[ValueFollowing]] is used as its [[FrameBoundary]],
+ * the value is considered as a logical offset.
+ * For example, assuming the value of the current row's `ORDER BY` expression `expr` is `v`,
+ * `RANGE BETWEEN 1 PRECEDING AND 1 FOLLOWING` represents a frame containing rows whose
+ * `expr` values are in the range [v-1, v+1].
+ *
+ * If the `ORDER BY` clause is not defined, all rows in the partition are considered peers
+ * of the current row.
+ */
+case object RangeFrame extends FrameType
+
+/**
+ * The trait used to represent the type of a Window Frame Boundary.
+ */
+sealed trait FrameBoundary {
+ def notFollows(other: FrameBoundary): Boolean
+}
+
+/** UNBOUNDED PRECEDING boundary. */
+case object UnboundedPreceding extends FrameBoundary {
+ def notFollows(other: FrameBoundary): Boolean = other match {
+ case UnboundedPreceding => true
+ case vp: ValuePreceding => true
+ case CurrentRow => true
+ case vf: ValueFollowing => true
+ case UnboundedFollowing => true
+ }
+
+ override def toString: String = "UNBOUNDED PRECEDING"
+}
+
+/** <value> PRECEDING boundary. */
+case class ValuePreceding(value: Int) extends FrameBoundary {
+ def notFollows(other: FrameBoundary): Boolean = other match {
+ case UnboundedPreceding => false
+ case ValuePreceding(anotherValue) => value >= anotherValue
+ case CurrentRow => true
+ case vf: ValueFollowing => true
+ case UnboundedFollowing => true
+ }
+
+ override def toString: String = s"$value PRECEDING"
+}
+
+/** CURRENT ROW boundary. */
+case object CurrentRow extends FrameBoundary {
+ def notFollows(other: FrameBoundary): Boolean = other match {
+ case UnboundedPreceding => false
+ case vp: ValuePreceding => false
+ case CurrentRow => true
+ case vf: ValueFollowing => true
+ case UnboundedFollowing => true
+ }
+
+ override def toString: String = "CURRENT ROW"
+}
+
+/** <value> FOLLOWING boundary. */
+case class ValueFollowing(value: Int) extends FrameBoundary {
+ def notFollows(other: FrameBoundary): Boolean = other match {
+ case UnboundedPreceding => false
+ case vp: ValuePreceding => false
+ case CurrentRow => false
+ case ValueFollowing(anotherValue) => value <= anotherValue
+ case UnboundedFollowing => true
+ }
+
+ override def toString: String = s"$value FOLLOWING"
+}
+
+/** UNBOUNDED FOLLOWING boundary. */
+case object UnboundedFollowing extends FrameBoundary {
+ def notFollows(other: FrameBoundary): Boolean = other match {
+ case UnboundedPreceding => false
+ case vp: ValuePreceding => false
+ case CurrentRow => false
+ case vf: ValueFollowing => false
+ case UnboundedFollowing => true
+ }
+
+ override def toString: String = "UNBOUNDED FOLLOWING"
+}
+
+/**
+ * The trait used to represent a Window Frame.
+ */
+sealed trait WindowFrame
+
+/** Used as a placeholder when a frame specification is not defined. */
+case object UnspecifiedFrame extends WindowFrame
+
+/** A specified Window Frame. */
+case class SpecifiedWindowFrame(
+ frameType: FrameType,
+ frameStart: FrameBoundary,
+ frameEnd: FrameBoundary) extends WindowFrame {
+
+ /** Returns the reason this WindowFrame is invalid, or None if it is valid. */
+ def validate: Option[String] = (frameType, frameStart, frameEnd) match {
+ case (_, UnboundedFollowing, _) =>
+ Some(s"$UnboundedFollowing is not allowed as the start of a Window Frame.")
+ case (_, _, UnboundedPreceding) =>
+ Some(s"$UnboundedPreceding is not allowed as the end of a Window Frame.")
+ // case (RowFrame, start, end) => ??? RowFrame specific rule
+ // case (RangeFrame, start, end) => ??? RangeFrame specific rule
+ case (_, start, end) =>
+ if (start.notFollows(end)) {
+ None
+ } else {
+ val reason =
+ s"The end of this Window Frame $end is smaller than the start of " +
+ s"this Window Frame $start."
+ Some(reason)
+ }
+ }
+
+ override def toString: String = frameType match {
+ case RowFrame => s"ROWS BETWEEN $frameStart AND $frameEnd"
+ case RangeFrame => s"RANGE BETWEEN $frameStart AND $frameEnd"
+ }
+}
+
+object SpecifiedWindowFrame {
+ /**
+ * Returns the default window frame for the given window function and specification.
+ * @param hasOrderSpecification Whether the window spec has ORDER BY expressions.
+ * @param acceptWindowFrame Whether the window function accepts a user-specified frame.
+ * @return The default [[SpecifiedWindowFrame]].
+ */
+ def defaultWindowFrame(
+ hasOrderSpecification: Boolean,
+ acceptWindowFrame: Boolean): SpecifiedWindowFrame = {
+ if (hasOrderSpecification && acceptWindowFrame) {
+ // If order spec is defined and the window function supports user specified window frames,
+ // the default frame is RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW.
+ SpecifiedWindowFrame(RangeFrame, UnboundedPreceding, CurrentRow)
+ } else {
+ // Otherwise, the default frame is
+ // ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING.
+ SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing)
+ }
+ }
+}
+
+/**
+ * Every window function needs to maintain an output buffer for its output.
+ * It should expect that for an n-row window frame, it will be called n times
+ * to retrieve the values corresponding to these n rows.
+ */
+trait WindowFunction extends Expression {
+ self: Product =>
+
+ def init(): Unit
+
+ def reset(): Unit
+
+ def prepareInputParameters(input: Row): AnyRef
+
+ def update(input: AnyRef): Unit
+
+ def batchUpdate(inputs: Array[AnyRef]): Unit
+
+ def evaluate(): Unit
+
+ def get(index: Int): Any
+
+ def newInstance(): WindowFunction
+}
+
+case class UnresolvedWindowFunction(
+ name: String,
+ children: Seq[Expression])
+ extends Expression with WindowFunction {
+
+ override def dataType: DataType = throw new UnresolvedException(this, "dataType")
+ override def foldable: Boolean = throw new UnresolvedException(this, "foldable")
+ override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
+ override lazy val resolved = false
+
+ override def init(): Unit =
+ throw new UnresolvedException(this, "init")
+ override def reset(): Unit =
+ throw new UnresolvedException(this, "reset")
+ override def prepareInputParameters(input: Row): AnyRef =
+ throw new UnresolvedException(this, "prepareInputParameters")
+ override def update(input: AnyRef): Unit =
+ throw new UnresolvedException(this, "update")
+ override def batchUpdate(inputs: Array[AnyRef]): Unit =
+ throw new UnresolvedException(this, "batchUpdate")
+ override def evaluate(): Unit =
+ throw new UnresolvedException(this, "evaluate")
+ override def get(index: Int): Any =
+ throw new UnresolvedException(this, "get")
+ // Unresolved functions are transient at compile time and don't get evaluated during execution.
+ override def eval(input: Row = null): EvaluatedType =
+ throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
+
+ override def toString: String = s"'$name(${children.mkString(",")})"
+
+ override def newInstance(): WindowFunction =
+ throw new UnresolvedException(this, "newInstance")
+}
+
+case class UnresolvedWindowExpression(
+ child: UnresolvedWindowFunction,
+ windowSpec: WindowSpecReference) extends UnaryExpression {
+ override def dataType: DataType = throw new UnresolvedException(this, "dataType")
+ override def foldable: Boolean = throw new UnresolvedException(this, "foldable")
+ override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
+ override lazy val resolved = false
+
+ // Unresolved functions are transient at compile time and don't get evaluated during execution.
+ override def eval(input: Row = null): EvaluatedType =
+ throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
+}
+
+case class WindowExpression(
+ windowFunction: WindowFunction,
+ windowSpec: WindowSpecDefinition) extends Expression {
+ override type EvaluatedType = Any
+
+ override def children: Seq[Expression] =
+ windowFunction :: windowSpec :: Nil
+
+ override def eval(input: Row): EvaluatedType =
+ throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
+
+ override def dataType: DataType = windowFunction.dataType
+ override def foldable: Boolean = windowFunction.foldable
+ override def nullable: Boolean = windowFunction.nullable
+
+ override def toString: String = s"$windowFunction $windowSpec"
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 21208c8a5c..ba0abb2df5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -25,13 +25,14 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend
override def output: Seq[Attribute] = projectList.map(_.toAttribute)
override lazy val resolved: Boolean = {
- val containsAggregatesOrGenerators = projectList.exists ( _.collect {
+ val hasSpecialExpressions = projectList.exists ( _.collect {
case agg: AggregateExpression => agg
case generator: Generator => generator
+ case window: WindowExpression => window
}.nonEmpty
)
- !expressions.exists(!_.resolved) && childrenResolved && !containsAggregatesOrGenerators
+ !expressions.exists(!_.resolved) && childrenResolved && !hasSpecialExpressions
}
}
@@ -170,6 +171,12 @@ case class With(child: LogicalPlan, cteRelations: Map[String, Subquery]) extends
override def output: Seq[Attribute] = child.output
}
+case class WithWindowDefinition(
+ windowDefinitions: Map[String, WindowSpecDefinition],
+ child: LogicalPlan) extends UnaryNode {
+ override def output: Seq[Attribute] = child.output
+}
+
case class WriteToFile(
path: String,
child: LogicalPlan) extends UnaryNode {
@@ -195,9 +202,28 @@ case class Aggregate(
child: LogicalPlan)
extends UnaryNode {
+ override lazy val resolved: Boolean = {
+ val hasWindowExpressions = aggregateExpressions.exists ( _.collect {
+ case window: WindowExpression => window
+ }.nonEmpty
+ )
+
+ !expressions.exists(!_.resolved) && childrenResolved && !hasWindowExpressions
+ }
+
override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute)
}
+case class Window(
+ projectList: Seq[Attribute],
+ windowExpressions: Seq[NamedExpression],
+ windowSpec: WindowSpecDefinition,
+ child: LogicalPlan) extends UnaryNode {
+
+ override def output: Seq[Attribute] =
+ (projectList ++ windowExpressions).map(_.toAttribute)
+}
+
/**
* Apply all of the GroupExpressions to every input row, hence we will get
* multiple output rows for an input row.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index 97502ed3af..4b93f7d31b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -72,6 +72,15 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
}
/**
+ * Find the first [[TreeNode]] that satisfies the condition specified by `f`.
+ * The condition is recursively applied to this node and all of its children (pre-order).
+ */
+ def find(f: BaseType => Boolean): Option[BaseType] = f(this) match {
+ case true => Some(this)
+ case false => children.foldLeft(None: Option[BaseType]) { (l, r) => l.orElse(r.find(f)) }
+ }
+
+ /**
* Runs the given function on this node and then recursively on [[children]].
* @param f the function to be applied to each node in the tree.
*/
@@ -151,6 +160,20 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
val remainingNewChildren = newChildren.toBuffer
val remainingOldChildren = children.toBuffer
val newArgs = productIterator.map {
+ // This case handles children that appear inside a Seq input argument.
+ case s: Seq[_] => s.map {
+ case arg: TreeNode[_] if children contains arg =>
+ val newChild = remainingNewChildren.remove(0)
+ val oldChild = remainingOldChildren.remove(0)
+ if (newChild fastEquals oldChild) {
+ oldChild
+ } else {
+ changed = true
+ newChild
+ }
+ case nonChild: AnyRef => nonChild
+ case null => null
+ }
case arg: TreeNode[_] if children contains arg =>
val newChild = remainingNewChildren.remove(0)
val oldChild = remainingOldChildren.remove(0)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
index 6b393327cc..786ddba403 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
@@ -22,7 +22,7 @@ import scala.collection.mutable.ArrayBuffer
import org.scalatest.FunSuite
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.types.{StringType, NullType}
+import org.apache.spark.sql.types.{IntegerType, StringType, NullType}
case class Dummy(optKey: Option[Expression]) extends Expression {
def children: Seq[Expression] = optKey.toSeq
@@ -129,5 +129,47 @@ class TreeNodeSuite extends FunSuite {
assert(expected === actual)
}
+ test("find") {
+ val expression = Add(Literal(1), Multiply(Literal(2), Subtract(Literal(3), Literal(4))))
+ // Find the top node.
+ var actual: Option[Expression] = expression.find {
+ case add: Add => true
+ case other => false
+ }
+ var expected: Option[Expression] =
+ Some(Add(Literal(1), Multiply(Literal(2), Subtract(Literal(3), Literal(4)))))
+ assert(expected === actual)
+
+ // Find the first child.
+ actual = expression.find {
+ case Literal(1, IntegerType) => true
+ case other => false
+ }
+ expected = Some(Literal(1))
+ assert(expected === actual)
+ // Find an internal node (Subtract).
+ actual = expression.find {
+ case sub: Subtract => true
+ case other => false
+ }
+ expected = Some(Subtract(Literal(3), Literal(4)))
+ assert(expected === actual)
+
+ // Find a leaf node.
+ actual = expression.find {
+ case Literal(3, IntegerType) => true
+ case other => false
+ }
+ expected = Some(Literal(3))
+ assert(expected === actual)
+
+ // Find nothing.
+ actual = expression.find {
+ case Literal(100, IntegerType) => true
+ case other => false
+ }
+ expected = None
+ assert(expected === actual)
+ }
}