diff options
4 files changed, 25 insertions, 22 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 8c1e4d74f9..0b9f621fed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -136,12 +136,17 @@ abstract class Expression extends TreeNode[Expression] { * cosmetically (i.e. capitalization of names in attributes may be different). */ def semanticEquals(other: Expression): Boolean = this.getClass == other.getClass && { + def checkSemantic(elements1: Seq[Any], elements2: Seq[Any]): Boolean = { + elements1.length == elements2.length && elements1.zip(elements2).forall { + case (e1: Expression, e2: Expression) => e1 semanticEquals e2 + case (Some(e1: Expression), Some(e2: Expression)) => e1 semanticEquals e2 + case (t1: Traversable[_], t2: Traversable[_]) => checkSemantic(t1.toSeq, t2.toSeq) + case (i1, i2) => i1 == i2 + } + } val elements1 = this.productIterator.toSeq val elements2 = other.asInstanceOf[Product].productIterator.toSeq - elements1.length == elements2.length && elements1.zip(elements2).forall { - case (e1: Expression, e2: Expression) => e1 semanticEquals e2 - case (i1, i2) => i1 == i2 - } + checkSemantic(elements1, elements2) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 1dd75a8846..3b6f8bfd9f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -143,11 +143,11 @@ object PartialAggregation { // We need to pass all grouping expressions though so the grouping can happen a second // time. However some of them might be unnamed so we alias them allowing them to be // referenced in the second aggregation. - val namedGroupingExpressions: Map[Expression, NamedExpression] = + val namedGroupingExpressions: Seq[(Expression, NamedExpression)] = groupingExpressions.filter(!_.isInstanceOf[Literal]).map { case n: NamedExpression => (n, n) case other => (other, Alias(other, "PartialGroup")()) - }.toMap + } // Replace aggregations with a new expression that computes the result from the already // computed partial evaluations and grouping values. @@ -160,17 +160,15 @@ object PartialAggregation { // resolving struct field accesses, because `GetField` is not a `NamedExpression`. // (Should we just turn `GetField` into a `NamedExpression`?) val trimmed = e.transform { case Alias(g: ExtractValue, _) => g } - namedGroupingExpressions - .find { case (k, v) => k semanticEquals trimmed } - .map(_._2.toAttribute) - .getOrElse(e) + namedGroupingExpressions.collectFirst { + case (expr, ne) if expr semanticEquals trimmed => ne.toAttribute + }.getOrElse(e) }).asInstanceOf[Seq[NamedExpression]] - val partialComputation = - (namedGroupingExpressions.values ++ - partialEvaluations.values.flatMap(_.partialEvaluations)).toSeq + val partialComputation = namedGroupingExpressions.map(_._2) ++ + partialEvaluations.values.flatMap(_.partialEvaluations) - val namedGroupingAttributes = namedGroupingExpressions.values.map(_.toAttribute).toSeq + val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute) Some( (namedGroupingAttributes, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index af3791734d..1c40a9209f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -214,18 +214,18 @@ case class GeneratedAggregate( }.toMap val namedGroups = groupingExpressions.zipWithIndex.map { - case (ne: NamedExpression, _) => (ne, ne) - case (e, i) => (e, Alias(e, s"GroupingExpr$i")()) + case (ne: NamedExpression, _) => (ne, ne.toAttribute) + case (e, i) => (e, Alias(e, s"GroupingExpr$i")().toAttribute) } - val groupMap: Map[Expression, Attribute] = - namedGroups.map { case (k, v) => k -> v.toAttribute}.toMap - // The set of expressions that produce the final output given the aggregation buffer and the // grouping expressions. val resultExpressions = aggregateExpressions.map(_.transform { case e: Expression if resultMap.contains(new TreeNodeRef(e)) => resultMap(new TreeNodeRef(e)) - case e: Expression if groupMap.contains(e) => groupMap(e) + case e: Expression => + namedGroups.collectFirst { + case (expr, attr) if expr semanticEquals e => attr + }.getOrElse(e) }) val aggregationBufferSchema: StructType = StructType.fromAttributes(computationSchema) @@ -265,7 +265,7 @@ case class GeneratedAggregate( val resultProjectionBuilder = newMutableProjection( resultExpressions, - (namedGroups.map(_._2.toAttribute) ++ computationSchema).toSeq) + namedGroups.map(_._2) ++ computationSchema) log.info(s"Result Projection: ${resultExpressions.mkString(",")}") val joinedRow = new JoinedRow3 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 14ecd4e9a7..6898d58441 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -697,7 +697,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { row => Seq.fill(16)(Row.merge(row, row))).collect().toSeq) } - ignore("cartesian product join") { + test("cartesian product join") { checkAnswer( testData3.join(testData3), Row(1, null, 1, null) :: |