aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorYin Huai <yhuai@databricks.com>2015-05-07 11:46:49 -0700
committerMichael Armbrust <michael@databricks.com>2015-05-07 11:46:49 -0700
commit5784c8d95561dce432a85401e1510776fdf723a8 (patch)
tree1c7619266932086c717e2c08a5df3916f8f2b991 /sql
parent1712a7c7057bf6dd5da8aea1d7fbecdf96ea4b32 (diff)
downloadspark-5784c8d95561dce432a85401e1510776fdf723a8.tar.gz
spark-5784c8d95561dce432a85401e1510776fdf723a8.tar.bz2
spark-5784c8d95561dce432a85401e1510776fdf723a8.zip
[SPARK-1442] [SQL] [FOLLOW-UP] Address minor comments in Window Function PR (#5604).
Address marmbrus and scwf's comments in #5604. Author: Yin Huai <yhuai@databricks.com> Closes #5945 from yhuai/windowFollowup and squashes the following commits: 0ef879d [Yin Huai] Add collectFirst to TreeNode. 2373968 [Yin Huai] wip 4a16df9 [Yin Huai] Address minor comments for [SPARK-1442].
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala13
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala13
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala50
3 files changed, 68 insertions, 8 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 7b543b6c2a..7e46ad851c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -638,11 +638,10 @@ class Analyzer(
def addWindow(windowExpressions: Seq[NamedExpression], child: LogicalPlan): LogicalPlan = {
// First, we group window expressions based on their Window Spec.
val groupedWindowExpression = windowExpressions.groupBy { expr =>
- val windowExpression = expr.find {
- case window: WindowExpression => true
- case other => false
- }.map(_.asInstanceOf[WindowExpression].windowSpec)
- windowExpression.getOrElse(
+ val windowSpec = expr.collectFirst {
+ case window: WindowExpression => window.windowSpec
+ }
+ windowSpec.getOrElse(
failAnalysis(s"$windowExpressions does not have any WindowExpression."))
}.toSeq
@@ -685,7 +684,7 @@ class Analyzer(
case f @ Filter(condition, a @ Aggregate(groupingExprs, aggregateExprs, child))
if child.resolved &&
hasWindowFunction(aggregateExprs) &&
- !a.expressions.exists(!_.resolved) =>
+ a.expressions.forall(_.resolved) =>
val (windowExpressions, aggregateExpressions) = extract(aggregateExprs)
// Create an Aggregate operator to evaluate aggregation functions.
val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child)
@@ -702,7 +701,7 @@ class Analyzer(
// Aggregate without Having clause.
case a @ Aggregate(groupingExprs, aggregateExprs, child)
if hasWindowFunction(aggregateExprs) &&
- !a.expressions.exists(!_.resolved) =>
+ a.expressions.forall(_.resolved) =>
val (windowExpressions, aggregateExpressions) = extract(aggregateExprs)
// Create an Aggregate operator to evaluate aggregation functions.
val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index 4b93f7d31b..bc2ad34523 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -131,6 +131,17 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
}
/**
+ * Finds and returns the first [[TreeNode]] of the tree for which the given partial function
+ * is defined (pre-order), and applies the partial function to it.
+ */
+ def collectFirst[B](pf: PartialFunction[BaseType, B]): Option[B] = {
+ val lifted = pf.lift
+ lifted(this).orElse {
+ children.foldLeft(None: Option[B]) { (l, r) => l.orElse(r.collectFirst(pf)) }
+ }
+ }
+
+ /**
* Returns a copy of this node where `f` has been applied to all the nodes children.
*/
def mapChildren(f: BaseType => BaseType): this.type = {
@@ -160,7 +171,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
val remainingNewChildren = newChildren.toBuffer
val remainingOldChildren = children.toBuffer
val newArgs = productIterator.map {
- // This rule is used to handle children is a input argument.
+ // Handle Seq[TreeNode] in TreeNode parameters.
case s: Seq[_] => s.map {
case arg: TreeNode[_] if children contains arg =>
val newChild = remainingNewChildren.remove(0)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
index 786ddba403..3d10dab5ba 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
@@ -172,4 +172,54 @@ class TreeNodeSuite extends FunSuite {
expected = None
assert(expected === actual)
}
+
+ test("collectFirst") {
+ val expression = Add(Literal(1), Multiply(Literal(2), Subtract(Literal(3), Literal(4))))
+
+ // Collect the top node.
+ {
+ val actual = expression.collectFirst {
+ case add: Add => add
+ }
+ val expected =
+ Some(Add(Literal(1), Multiply(Literal(2), Subtract(Literal(3), Literal(4)))))
+ assert(expected === actual)
+ }
+
+ // Collect the first children.
+ {
+ val actual = expression.collectFirst {
+ case l @ Literal(1, IntegerType) => l
+ }
+ val expected = Some(Literal(1))
+ assert(expected === actual)
+ }
+
+ // Collect an internal node (Subtract).
+ {
+ val actual = expression.collectFirst {
+ case sub: Subtract => sub
+ }
+ val expected = Some(Subtract(Literal(3), Literal(4)))
+ assert(expected === actual)
+ }
+
+ // Collect a leaf node.
+ {
+ val actual = expression.collectFirst {
+ case l @ Literal(3, IntegerType) => l
+ }
+ val expected = Some(Literal(3))
+ assert(expected === actual)
+ }
+
+ // Collect nothing.
+ {
+ val actual = expression.collectFirst {
+ case l @ Literal(100, IntegerType) => l
+ }
+ val expected = None
+ assert(expected === actual)
+ }
+ }
}