aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorWenchen Fan <cloud0fan@outlook.com>2015-05-18 12:08:28 -0700
committerMichael Armbrust <michael@databricks.com>2015-05-18 12:12:55 -0700
commit103c863c2ef3d9e6186cfc7d95251a9515e9f180 (patch)
tree1e8bb956a8ec2a314055dad30220130627f4af45 /sql
parentfc2480ed13742a99470b5012ca3a75ab91e5a5e5 (diff)
downloadspark-103c863c2ef3d9e6186cfc7d95251a9515e9f180.tar.gz
spark-103c863c2ef3d9e6186cfc7d95251a9515e9f180.tar.bz2
spark-103c863c2ef3d9e6186cfc7d95251a9515e9f180.zip
[SPARK-7269] [SQL] Incorrect analysis for aggregation (use semanticEquals)
A modified version of https://github.com/apache/spark/pull/6110; uses `semanticEquals` to make it more efficient. Author: Wenchen Fan <cloud0fan@outlook.com> Closes #6173 from cloud-fan/7269 and squashes the following commits: e4a3cc7 [Wenchen Fan] address comments cc02045 [Wenchen Fan] consider elements length equal d7ff8f4 [Wenchen Fan] fix 7269
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala29
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala4
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala13
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala5
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala5
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala18
6 files changed, 48 insertions, 26 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 0b6e1d44b9..dfa4215f2e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -25,7 +25,6 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.types._
-import org.apache.spark.util.collection.OpenHashSet
/**
* A trivial [[Analyzer]] with an [[EmptyCatalog]] and [[EmptyFunctionRegistry]]. Used for testing
@@ -142,25 +141,6 @@ class Analyzer(
}
object ResolveGroupingAnalytics extends Rule[LogicalPlan] {
- /**
- * Extract attribute set according to the grouping id
- * @param bitmask bitmask to represent the selected of the attribute sequence
- * @param exprs the attributes in sequence
- * @return the attributes of non selected specified via bitmask (with the bit set to 1)
- */
- private def buildNonSelectExprSet(bitmask: Int, exprs: Seq[Expression])
- : OpenHashSet[Expression] = {
- val set = new OpenHashSet[Expression](2)
-
- var bit = exprs.length - 1
- while (bit >= 0) {
- if (((bitmask >> bit) & 1) == 0) set.add(exprs(bit))
- bit -= 1
- }
-
- set
- }
-
/*
* GROUP BY a, b, c WITH ROLLUP
* is equivalent to
@@ -197,10 +177,15 @@ class Analyzer(
g.bitmasks.foreach { bitmask =>
// get the non selected grouping attributes according to the bit mask
- val nonSelectedGroupExprSet = buildNonSelectExprSet(bitmask, g.groupByExprs)
+ val nonSelectedGroupExprs = ArrayBuffer.empty[Expression]
+ var bit = g.groupByExprs.length - 1
+ while (bit >= 0) {
+ if (((bitmask >> bit) & 1) == 0) nonSelectedGroupExprs += g.groupByExprs(bit)
+ bit -= 1
+ }
val substitution = (g.child.output :+ g.gid).map(expr => expr transformDown {
- case x: Expression if nonSelectedGroupExprSet.contains(x) =>
+ case x: Expression if nonSelectedGroupExprs.find(_ semanticEquals x).isDefined =>
// if the input attribute in the Invalid Grouping Expression set of for this group
// replace it with constant null
Literal.create(null, expr.dataType)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index f104e742c9..06a0504359 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -86,12 +86,12 @@ trait CheckAnalysis {
case Aggregate(groupingExprs, aggregateExprs, child) =>
def checkValidAggregateExpression(expr: Expression): Unit = expr match {
case _: AggregateExpression => // OK
- case e: Attribute if !groupingExprs.contains(e) =>
+ case e: Attribute if groupingExprs.find(_ semanticEquals e).isEmpty =>
failAnalysis(
s"expression '${e.prettyString}' is neither present in the group by, " +
s"nor is it an aggregate function. " +
"Add to group by or wrap in first() if you don't care which value you get.")
- case e if groupingExprs.contains(e) => // OK
+ case e if groupingExprs.find(_ semanticEquals e).isDefined => // OK
case e if e.references.isEmpty => // OK
case e => e.children.foreach(checkValidAggregateExpression)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 0837a3179d..c7ae9da7fc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -76,6 +76,19 @@ abstract class Expression extends TreeNode[Expression] {
case u: UnresolvedAttribute => PrettyAttribute(u.name)
}.toString
}
+
+ /**
+ * Returns true when two expressions will always compute the same result, even if they differ
+ * cosmetically (i.e. capitalization of names in attributes may be different).
+ */
+ def semanticEquals(other: Expression): Boolean = this.getClass == other.getClass && {
+ val elements1 = this.productIterator.toSeq
+ val elements2 = other.asInstanceOf[Product].productIterator.toSeq
+ elements1.length == elements2.length && elements1.zip(elements2).forall {
+ case (e1: Expression, e2: Expression) => e1 semanticEquals e2
+ case (i1, i2) => i1 == i2
+ }
+ }
}
abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index a9170589f8..50be26d0b0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -181,6 +181,11 @@ case class AttributeReference(
case _ => false
}
+ override def semanticEquals(other: Expression): Boolean = other match {
+ case ar: AttributeReference => sameRef(ar)
+ case _ => false
+ }
+
override def hashCode: Int = {
// See http://stackoverflow.com/questions/113511/hash-code-implementation
var h = 17
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index cd54d04814..1dd75a8846 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -159,9 +159,10 @@ object PartialAggregation {
// Should trim aliases around `GetField`s. These aliases are introduced while
// resolving struct field accesses, because `GetField` is not a `NamedExpression`.
// (Should we just turn `GetField` into a `NamedExpression`?)
+ val trimmed = e.transform { case Alias(g: ExtractValue, _) => g }
namedGroupingExpressions
- .get(e.transform { case Alias(g: ExtractValue, _) => g })
- .map(_.toAttribute)
+ .find { case (k, v) => k semanticEquals trimmed }
+ .map(_._2.toAttribute)
.getOrElse(e)
}).asInstanceOf[Seq[NamedExpression]]
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index ca2c4b4019..e60d00e635 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -773,4 +773,22 @@ class SQLQuerySuite extends QueryTest {
| select * from v2 order by key limit 1
""".stripMargin), Row(0, 3))
}
+
+ test("SPARK-7269 Check analysis failed in case in-sensitive") {
+ Seq(1, 2, 3).map { i =>
+ (i.toString, i.toString)
+ }.toDF("key", "value").registerTempTable("df_analysis")
+ sql("SELECT kEy from df_analysis group by key").collect()
+ sql("SELECT kEy+3 from df_analysis group by key+3").collect()
+ sql("SELECT kEy+3, a.kEy, A.kEy from df_analysis A group by key").collect()
+ sql("SELECT cast(kEy+1 as Int) from df_analysis A group by cast(key+1 as int)").collect()
+ sql("SELECT cast(kEy+1 as Int) from df_analysis A group by key+1").collect()
+ sql("SELECT 2 from df_analysis A group by key+1").collect()
+ intercept[AnalysisException] {
+ sql("SELECT kEy+1 from df_analysis group by key+3")
+ }
+ intercept[AnalysisException] {
+ sql("SELECT cast(key+2 as Int) from df_analysis A group by cast(key+1 as int)")
+ }
+ }
}