diff options
author | Michael Armbrust <michael@databricks.com> | 2015-02-11 12:31:56 -0800 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2015-02-11 12:31:56 -0800 |
commit | a60d2b70adff3a8fb3bdfac226b1d86fdb443da4 (patch) | |
tree | 92f50bd56e8ffc48b77d7845585d15327f169431 /sql/catalyst | |
parent | 03bf704bf442ac7dd960795295b51957ce972491 (diff) | |
download | spark-a60d2b70adff3a8fb3bdfac226b1d86fdb443da4.tar.gz spark-a60d2b70adff3a8fb3bdfac226b1d86fdb443da4.tar.bz2 spark-a60d2b70adff3a8fb3bdfac226b1d86fdb443da4.zip |
[SPARK-5454] More robust handling of self joins
Also I fix a bunch of bad output in test cases.
Author: Michael Armbrust <michael@databricks.com>
Closes #4520 from marmbrus/selfJoin and squashes the following commits:
4f4a85c [Michael Armbrust] comments
49c8e26 [Michael Armbrust] fix tests
6fc38de [Michael Armbrust] fix style
55d64b3 [Michael Armbrust] fix dataframe selfjoins
Diffstat (limited to 'sql/catalyst')
3 files changed, 24 insertions, 27 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 3f0d77ad63..2d1fa106a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -53,14 +53,11 @@ class Analyzer(catalog: Catalog, val extendedRules: Seq[Rule[LogicalPlan]] = Nil lazy val batches: Seq[Batch] = Seq( - Batch("MultiInstanceRelations", Once, - NewRelationInstances), Batch("Resolution", fixedPoint, - ResolveReferences :: ResolveRelations :: + ResolveReferences :: ResolveGroupingAnalytics :: ResolveSortReferences :: - NewRelationInstances :: ImplicitGenerate :: ResolveFunctions :: GlobalAggregates :: @@ -285,6 +282,27 @@ class Analyzer(catalog: Catalog, } ) + // Special handling for cases when self-join introduce duplicate expression ids. + case j @ Join(left, right, _, _) if left.outputSet.intersect(right.outputSet).nonEmpty => + val conflictingAttributes = left.outputSet.intersect(right.outputSet) + + val (oldRelation, newRelation, attributeRewrites) = right.collect { + case oldVersion: MultiInstanceRelation + if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty => + val newVersion = oldVersion.newInstance() + val newAttributes = AttributeMap(oldVersion.output.zip(newVersion.output)) + (oldVersion, newVersion, newAttributes) + }.head // Only handle first case found, others will be fixed on the next pass. + + val newRight = right transformUp { + case r if r == oldRelation => newRelation + case other => other transformExpressions { + case a: Attribute => attributeRewrites.get(a).getOrElse(a) + } + } + + j.copy(right = newRight) + case q: LogicalPlan => logTrace(s"Attempting to resolve ${q.simpleString}") q transformExpressionsUp { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala index 4c5fb3f45b..894c3500cf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala @@ -26,28 +26,9 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan * produced by distinct operators in a query tree as this breaks the guarantee that expression * ids, which are used to differentiate attributes, are unique. * - * Before analysis, all operators that include this trait will be asked to produce a new version + * During analysis, operators that include this trait may be asked to produce a new version * of itself with globally unique expression ids. */ trait MultiInstanceRelation { def newInstance(): this.type } - -/** - * If any MultiInstanceRelation appears more than once in the query plan then the plan is updated so - * that each instance has unique expression ids for the attributes produced. - */ -object NewRelationInstances extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = { - val localRelations = plan collect { case l: MultiInstanceRelation => l} - val multiAppearance = localRelations - .groupBy(identity[MultiInstanceRelation]) - .filter { case (_, ls) => ls.size > 1 } - .map(_._1) - .toSet - - plan transform { - case l: MultiInstanceRelation if multiAppearance.contains(l) => l.newInstance() - } - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index c4a1f899d8..7d609b9138 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -33,11 +33,9 @@ class PlanTest extends FunSuite { * we must normalize them to check if two different queries are identical. */ protected def normalizeExprIds(plan: LogicalPlan) = { - val list = plan.flatMap(_.expressions.flatMap(_.references).map(_.exprId.id)) - val minId = if (list.isEmpty) 0 else list.min plan transformAllExpressions { case a: AttributeReference => - AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(a.exprId.id - minId)) + AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0)) } } |