path: root/sql/catalyst
author    Michael Armbrust <michael@databricks.com>  2015-03-17 19:47:51 -0700
committer Michael Armbrust <michael@databricks.com>  2015-03-17 19:47:51 -0700
commit    3579003115fa3217cff6aa400729d96b0c7b257b (patch)
tree      bad3f848417dfe25c37dff200fb77d0083d13019 /sql/catalyst
parent    a6ee2f7940b9a64a81667615586ae597da837974 (diff)
[SPARK-6247][SQL] Fix resolution of ambiguous joins caused by new aliases
We need to handle ambiguous `exprId`s that are produced by new aliases as well as those caused by leaf nodes (`MultiInstanceRelation`). Attempting to fix this revealed a bug in `equals` for `Alias`, as these objects were comparing equal even when the expression ids did not match. Additionally, `LocalRelation` did not correctly provide statistics, and some tests in `catalyst` and `hive` were not using the helper functions for comparing plans.

Based on #4991 by chenghao-intel.

Author: Michael Armbrust <michael@databricks.com>

Closes #5062 from marmbrus/selfJoins and squashes the following commits:

8e9b84b [Michael Armbrust] check qualifier too
8038a36 [Michael Armbrust] handle aggs too
0b9c687 [Michael Armbrust] fix more tests
c3c574b [Michael Armbrust] revert change.
725f1ab [Michael Armbrust] add statistics
a925d08 [Michael Armbrust] check for conflicting attributes in join resolution
b022ef7 [Michael Armbrust] Handle project aliases.
d8caa40 [Michael Armbrust] test case: SPARK-6247
f9c67c2 [Michael Armbrust] Check for duplicate attributes in join resolution.
898af73 [Michael Armbrust] Fix Alias equality.
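For context, here is a minimal, illustrative sketch (not part of the patch; the object name and the literal expression are hypothetical) of the two behaviors the fix relies on: regenerating an `Alias` keeps its child and name but assigns a fresh `exprId`, and collecting the attributes an alias list produces lets the analyzer check them against a join's conflicting attribute set.

    import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeSet, Literal}

    object AliasRegenerationSketch {
      def main(args: Array[String]): Unit = {
        val a1 = Alias(Literal(1), "x")()    // Alias with an auto-generated exprId
        val a2 = Alias(a1.child, a1.name)()  // same child and name, but a new exprId

        // With the corrected Alias.equals, a1 and a2 no longer compare equal,
        // because their expression ids differ.
        assert(a1.name == a2.name && a1.exprId != a2.exprId && a1 != a2)

        // Collect the attributes produced by aliases, as findAliases does for a
        // project list, so they can be checked against conflicting attributes.
        val produced = AttributeSet(Seq(a1, a2).collect { case a: Alias => a.toAttribute })
        println(produced)
      }
    }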
Diffstat (limited to 'sql/catalyst')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala             30
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala   6
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala     3
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala    7
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala 10
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala                 11
6 files changed, 57 insertions, 10 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 7753331748..92d3db077c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -237,22 +237,33 @@ class Analyzer(catalog: Catalog,
// Special handling for cases when a self-join introduces duplicate expression ids.
case j @ Join(left, right, _, _) if left.outputSet.intersect(right.outputSet).nonEmpty =>
val conflictingAttributes = left.outputSet.intersect(right.outputSet)
+ logDebug(s"Conflicting attributes ${conflictingAttributes.mkString(",")} in $j")
- val (oldRelation, newRelation, attributeRewrites) = right.collect {
+ val (oldRelation, newRelation) = right.collect {
+ // Handle base relations that might appear more than once.
case oldVersion: MultiInstanceRelation
if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
val newVersion = oldVersion.newInstance()
- val newAttributes = AttributeMap(oldVersion.output.zip(newVersion.output))
- (oldVersion, newVersion, newAttributes)
+ (oldVersion, newVersion)
+
+ // Handle projects that create conflicting aliases.
+ case oldVersion @ Project(projectList, _)
+ if findAliases(projectList).intersect(conflictingAttributes).nonEmpty =>
+ (oldVersion, oldVersion.copy(projectList = newAliases(projectList)))
+
+ case oldVersion @ Aggregate(_, aggregateExpressions, _)
+ if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty =>
+ (oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions)))
}.head // Only handle first case found, others will be fixed on the next pass.
+ val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output))
val newRight = right transformUp {
case r if r == oldRelation => newRelation
+ } transformUp {
case other => other transformExpressions {
case a: Attribute => attributeRewrites.get(a).getOrElse(a)
}
}
-
j.copy(right = newRight)
case q: LogicalPlan =>
@@ -272,6 +283,17 @@ class Analyzer(catalog: Catalog,
}
}
+ def newAliases(expressions: Seq[NamedExpression]): Seq[NamedExpression] = {
+ expressions.map {
+ case a: Alias => Alias(a.child, a.name)()
+ case other => other
+ }
+ }
+
+ def findAliases(projectList: Seq[NamedExpression]): AttributeSet = {
+ AttributeSet(projectList.collect { case a: Alias => a.toAttribute })
+ }
+
/**
* Returns true if `exprs` contains a [[Star]].
*/
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 62c062be6d..17f7f9fe51 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -124,6 +124,12 @@ case class Alias(child: Expression, name: String)
override def toString: String = s"$child AS $name#${exprId.id}$typeSuffix"
override protected final def otherCopyArgs = exprId :: qualifiers :: Nil
+
+ override def equals(other: Any): Boolean = other match {
+ case a: Alias =>
+ name == a.name && exprId == a.exprId && child == a.child && qualifiers == a.qualifiers
+ case _ => false
+ }
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala
index 92bd057c6f..bb79dc3405 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala
@@ -54,4 +54,7 @@ case class LocalRelation(output: Seq[Attribute], data: Seq[Row] = Nil)
otherOutput.map(_.dataType) == output.map(_.dataType) && otherData == data
case _ => false
}
+
+ override lazy val statistics =
+ Statistics(sizeInBytes = output.map(_.dataType.defaultSize).sum * data.length)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 624912dab4..1e7b449d75 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -108,6 +108,13 @@ case class Join(
left.output ++ right.output
}
}
+
+ def selfJoinResolved = left.outputSet.intersect(right.outputSet).isEmpty
+
+ // Joins are only resolved if they don't introduce ambiguous expression ids.
+ override lazy val resolved: Boolean = {
+ childrenResolved && !expressions.exists(!_.resolved) && selfJoinResolved
+ }
}
case class Except(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index 85798d0871..ecbb54218d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -17,13 +17,13 @@
package org.apache.spark.sql.catalyst.analysis
-import org.scalatest.FunSuite
+import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
import org.apache.spark.sql.types._
-class HiveTypeCoercionSuite extends FunSuite {
+class HiveTypeCoercionSuite extends PlanTest {
test("tightest common bound for types") {
def widenTest(t1: DataType, t2: DataType, tightestCommon: Option[DataType]) {
@@ -106,7 +106,8 @@ class HiveTypeCoercionSuite extends FunSuite {
val booleanCasts = new HiveTypeCoercion { }.BooleanCasts
def ruleTest(initial: Expression, transformed: Expression) {
val testRelation = LocalRelation(AttributeReference("a", IntegerType)())
- assert(booleanCasts(Project(Seq(Alias(initial, "a")()), testRelation)) ==
+ comparePlans(
+ booleanCasts(Project(Seq(Alias(initial, "a")()), testRelation)),
Project(Seq(Alias(transformed, "a")()), testRelation))
}
// Remove superfluous boolean -> boolean casts.
@@ -119,7 +120,8 @@ class HiveTypeCoercionSuite extends FunSuite {
val fac = new HiveTypeCoercion { }.FunctionArgumentConversion
def ruleTest(initial: Expression, transformed: Expression) {
val testRelation = LocalRelation(AttributeReference("a", IntegerType)())
- assert(fac(Project(Seq(Alias(initial, "a")()), testRelation)) ==
+ comparePlans(
+ fac(Project(Seq(Alias(initial, "a")()), testRelation)),
Project(Seq(Alias(transformed, "a")()), testRelation))
}
ruleTest(
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
index 7d609b9138..48884040bf 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
@@ -19,8 +19,8 @@ package org.apache.spark.sql.catalyst.plans
import org.scalatest.FunSuite
-import org.apache.spark.sql.catalyst.expressions.{ExprId, AttributeReference}
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.{NoRelation, Filter, LogicalPlan}
import org.apache.spark.sql.catalyst.util._
/**
@@ -36,6 +36,8 @@ class PlanTest extends FunSuite {
plan transformAllExpressions {
case a: AttributeReference =>
AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0))
+ case a: Alias =>
+ Alias(a.child, a.name)(exprId = ExprId(0))
}
}
@@ -50,4 +52,9 @@ class PlanTest extends FunSuite {
|${sideBySide(normalized1.treeString, normalized2.treeString).mkString("\n")}
""".stripMargin)
}
+
+ /** Fails the test if the two expressions do not match */
+ protected def compareExpressions(e1: Expression, e2: Expression): Unit = {
+ comparePlans(Filter(e1, NoRelation), Filter(e2, NoRelation))
+ }
}