about summary refs log tree commit diff
path: root/sql
diff options
context:
space:
mode:
authorWenchen Fan <cloud0fan@outlook.com>2015-06-12 16:38:28 +0800
committerCheng Lian <lian@databricks.com>2015-06-12 16:38:28 +0800
commitc19c78577a211eefe1112ebd4670a4ce7c3cc3be (patch)
tree47f04bc59d8ccffb756424a310f034240e48f2bc /sql
parente428b3a951377d47aa80d5f26d6bab979e72e8ab (diff)
downloadspark-c19c78577a211eefe1112ebd4670a4ce7c3cc3be.tar.gz
spark-c19c78577a211eefe1112ebd4670a4ce7c3cc3be.tar.bz2
spark-c19c78577a211eefe1112ebd4670a4ce7c3cc3be.zip
[SQL] [MINOR] correct semanticEquals logic
It's a follow-up of https://github.com/apache/spark/pull/6173, for expressions like `Coalesce` that have a `Seq[Expression]`, when we do semantic equal check for it, we need to do semantic equal check for all of its children. Also we can just use `Seq[(Expression, NamedExpression)]` instead of `Map[Expression, NamedExpression]` as we only search it with `find`. chenghao-intel, I agree that we probably never know `semanticEquals` in a general way, but I think we have done that in `TreeNode`, so we can use similar logic. Then we can handle something like `Coalesce(children: Seq[Expression])` correctly. Author: Wenchen Fan <cloud0fan@outlook.com> Closes #6261 from cloud-fan/tmp and squashes the following commits: 4daef88 [Wenchen Fan] address comments dd8fbd9 [Wenchen Fan] correct semanticEquals
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala13
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala18
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala14
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala2
4 files changed, 25 insertions, 22 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 8c1e4d74f9..0b9f621fed 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -136,12 +136,17 @@ abstract class Expression extends TreeNode[Expression] {
* cosmetically (i.e. capitalization of names in attributes may be different).
*/
def semanticEquals(other: Expression): Boolean = this.getClass == other.getClass && {
+ def checkSemantic(elements1: Seq[Any], elements2: Seq[Any]): Boolean = {
+ elements1.length == elements2.length && elements1.zip(elements2).forall {
+ case (e1: Expression, e2: Expression) => e1 semanticEquals e2
+ case (Some(e1: Expression), Some(e2: Expression)) => e1 semanticEquals e2
+ case (t1: Traversable[_], t2: Traversable[_]) => checkSemantic(t1.toSeq, t2.toSeq)
+ case (i1, i2) => i1 == i2
+ }
+ }
val elements1 = this.productIterator.toSeq
val elements2 = other.asInstanceOf[Product].productIterator.toSeq
- elements1.length == elements2.length && elements1.zip(elements2).forall {
- case (e1: Expression, e2: Expression) => e1 semanticEquals e2
- case (i1, i2) => i1 == i2
- }
+ checkSemantic(elements1, elements2)
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 1dd75a8846..3b6f8bfd9f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -143,11 +143,11 @@ object PartialAggregation {
// We need to pass all grouping expressions though so the grouping can happen a second
// time. However some of them might be unnamed so we alias them allowing them to be
// referenced in the second aggregation.
- val namedGroupingExpressions: Map[Expression, NamedExpression] =
+ val namedGroupingExpressions: Seq[(Expression, NamedExpression)] =
groupingExpressions.filter(!_.isInstanceOf[Literal]).map {
case n: NamedExpression => (n, n)
case other => (other, Alias(other, "PartialGroup")())
- }.toMap
+ }
// Replace aggregations with a new expression that computes the result from the already
// computed partial evaluations and grouping values.
@@ -160,17 +160,15 @@ object PartialAggregation {
// resolving struct field accesses, because `GetField` is not a `NamedExpression`.
// (Should we just turn `GetField` into a `NamedExpression`?)
val trimmed = e.transform { case Alias(g: ExtractValue, _) => g }
- namedGroupingExpressions
- .find { case (k, v) => k semanticEquals trimmed }
- .map(_._2.toAttribute)
- .getOrElse(e)
+ namedGroupingExpressions.collectFirst {
+ case (expr, ne) if expr semanticEquals trimmed => ne.toAttribute
+ }.getOrElse(e)
}).asInstanceOf[Seq[NamedExpression]]
- val partialComputation =
- (namedGroupingExpressions.values ++
- partialEvaluations.values.flatMap(_.partialEvaluations)).toSeq
+ val partialComputation = namedGroupingExpressions.map(_._2) ++
+ partialEvaluations.values.flatMap(_.partialEvaluations)
- val namedGroupingAttributes = namedGroupingExpressions.values.map(_.toAttribute).toSeq
+ val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute)
Some(
(namedGroupingAttributes,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index af3791734d..1c40a9209f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -214,18 +214,18 @@ case class GeneratedAggregate(
}.toMap
val namedGroups = groupingExpressions.zipWithIndex.map {
- case (ne: NamedExpression, _) => (ne, ne)
- case (e, i) => (e, Alias(e, s"GroupingExpr$i")())
+ case (ne: NamedExpression, _) => (ne, ne.toAttribute)
+ case (e, i) => (e, Alias(e, s"GroupingExpr$i")().toAttribute)
}
- val groupMap: Map[Expression, Attribute] =
- namedGroups.map { case (k, v) => k -> v.toAttribute}.toMap
-
// The set of expressions that produce the final output given the aggregation buffer and the
// grouping expressions.
val resultExpressions = aggregateExpressions.map(_.transform {
case e: Expression if resultMap.contains(new TreeNodeRef(e)) => resultMap(new TreeNodeRef(e))
- case e: Expression if groupMap.contains(e) => groupMap(e)
+ case e: Expression =>
+ namedGroups.collectFirst {
+ case (expr, attr) if expr semanticEquals e => attr
+ }.getOrElse(e)
})
val aggregationBufferSchema: StructType = StructType.fromAttributes(computationSchema)
@@ -265,7 +265,7 @@ case class GeneratedAggregate(
val resultProjectionBuilder =
newMutableProjection(
resultExpressions,
- (namedGroups.map(_._2.toAttribute) ++ computationSchema).toSeq)
+ namedGroups.map(_._2) ++ computationSchema)
log.info(s"Result Projection: ${resultExpressions.mkString(",")}")
val joinedRow = new JoinedRow3
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 14ecd4e9a7..6898d58441 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -697,7 +697,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
row => Seq.fill(16)(Row.merge(row, row))).collect().toSeq)
}
- ignore("cartesian product join") {
+ test("cartesian product join") {
checkAnswer(
testData3.join(testData3),
Row(1, null, 1, null) ::