From 1fd31ba08928f8554f74609f48f4344bd69444e5 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 5 May 2015 18:59:46 -0700 Subject: [SPARK-6231][SQL/DF] Automatically resolve join condition ambiguity for self-joins. See the comment in join function for more information. Author: Reynold Xin Closes #5919 from rxin/self-join-resolve and squashes the following commits: e2fb0da [Reynold Xin] Updated SQLConf comment. 7233a86 [Reynold Xin] Updated comment. 6be2b4d [Reynold Xin] Removed println 9f6b72f [Reynold Xin] [SPARK-6231][SQL/DF] Automatically resolve ambiguity in join condition for self-joins. --- .../scala/org/apache/spark/sql/DataFrame.scala | 38 +++++++++- .../main/scala/org/apache/spark/sql/SQLConf.scala | 7 ++ .../org/apache/spark/sql/DataFrameJoinSuite.scala | 86 ++++++++++++++++++++++ .../org/apache/spark/sql/DataFrameSuite.scala | 39 ---------- 4 files changed, 127 insertions(+), 43 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala (limited to 'sql/core') diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index cf344710ff..aceb7a9627 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -416,9 +416,7 @@ class DataFrame private[sql]( * }}} * @group dfops */ - def join(right: DataFrame, joinExprs: Column): DataFrame = { - Join(logicalPlan, right.logicalPlan, joinType = Inner, Some(joinExprs.expr)) - } + def join(right: DataFrame, joinExprs: Column): DataFrame = join(right, joinExprs, "inner") /** * Join with another [[DataFrame]], using the given join expression. 
The following performs @@ -440,7 +438,39 @@ class DataFrame private[sql]( * @group dfops */ def join(right: DataFrame, joinExprs: Column, joinType: String): DataFrame = { - Join(logicalPlan, right.logicalPlan, JoinType(joinType), Some(joinExprs.expr)) + // Note that in this function, we introduce a hack in the case of self-join to automatically + // resolve ambiguous join conditions into ones that might make sense [SPARK-6231]. + // Consider this case: df.join(df, df("key") === df("key")) + // Since df("key") === df("key") is a trivially true condition, this actually becomes a + // cartesian join. However, most likely users expect to perform a self join using "key". + // With that assumption, this hack turns the trivially true condition into equality on join + // keys that are resolved to both sides. + + // Trigger analysis so in the case of self-join, the analyzer will clone the plan. + // After the cloning, left and right side will have distinct expression ids. + val plan = Join(logicalPlan, right.logicalPlan, JoinType(joinType), Some(joinExprs.expr)) + .queryExecution.analyzed.asInstanceOf[Join] + + // If automatic self-join ambiguity resolution is disabled, return the plan. + if (!sqlContext.conf.dataFrameSelfJoinAutoResolveAmbiguity) { + return plan + } + + // If left/right have no output set intersection, return the plan. + val lanalyzed = this.logicalPlan.queryExecution.analyzed + val ranalyzed = right.logicalPlan.queryExecution.analyzed + if (lanalyzed.outputSet.intersect(ranalyzed.outputSet).isEmpty) { + return plan + } + + // Otherwise, find the trivially true predicates and automatically resolve them to both sides. + // By the time we get here, since we have already run analysis, all attributes should've been + // resolved and become AttributeReference. 
+ val cond = plan.condition.map { _.transform { + case EqualTo(a: AttributeReference, b: AttributeReference) if a.sameRef(b) => + EqualTo(plan.left.resolve(a.name), plan.right.resolve(b.name)) + }} + plan.copy(condition = cond) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 99db959a87..3ffc2091d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -67,6 +67,10 @@ private[spark] object SQLConf { // Set to false when debugging requires the ability to look at invalid query plans. val DATAFRAME_EAGER_ANALYSIS = "spark.sql.eagerAnalysis" + // Whether to automatically resolve ambiguity in join conditions for self-joins. + // See SPARK-6231. + val DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY = "spark.sql.selfJoinAutoResolveAmbiguity" + val USE_SQL_SERIALIZER2 = "spark.sql.useSerializer2" object Deprecated { @@ -219,6 +223,9 @@ private[sql] class SQLConf extends Serializable { private[spark] def dataFrameEagerAnalysis: Boolean = getConf(DATAFRAME_EAGER_ANALYSIS, "true").toBoolean + private[spark] def dataFrameSelfJoinAutoResolveAmbiguity: Boolean = + getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY, "true").toBoolean + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala new file mode 100644 index 0000000000..f005f55b64 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.TestData._ +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.sql.test.TestSQLContext.implicits._ + + +class DataFrameJoinSuite extends QueryTest { + + test("join - join using") { + val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str") + val df2 = Seq(1, 2, 3).map(i => (i, (i + 1).toString)).toDF("int", "str") + + checkAnswer( + df.join(df2, "int"), + Row(1, "1", "2") :: Row(2, "2", "3") :: Row(3, "3", "4") :: Nil) + } + + test("join - join using self join") { + val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str") + + // self join + checkAnswer( + df.join(df, "int"), + Row(1, "1", "1") :: Row(2, "2", "2") :: Row(3, "3", "3") :: Nil) + } + + test("join - self join") { + val df1 = testData.select(testData("key")).as('df1) + val df2 = testData.select(testData("key")).as('df2) + + checkAnswer( + df1.join(df2, $"df1.key" === $"df2.key"), + sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key").collect().toSeq) + } + + test("join - using aliases after self join") { + val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str") + checkAnswer( + df.as('x).join(df.as('y), $"x.str" === $"y.str").groupBy("x.str").count(), + Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil) + + checkAnswer( + 
df.as('x).join(df.as('y), $"x.str" === $"y.str").groupBy("y.str").count(), + Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil) + } + + test("[SPARK-6231] join - self join auto resolve ambiguity") { + val df = Seq((1, "1"), (2, "2")).toDF("key", "value") + checkAnswer( + df.join(df, df("key") === df("key")), + Row(1, "1", 1, "1") :: Row(2, "2", 2, "2") :: Nil) + + checkAnswer( + df.join(df.filter($"value" === "2"), df("key") === df("key")), + Row(2, "2", 2, "2") :: Nil) + + checkAnswer( + df.join(df, df("key") === df("key") && df("value") === 1), + Row(1, "1", 1, "1") :: Nil) + + val left = df.groupBy("key").agg($"key", count("*")) + val right = df.groupBy("key").agg($"key", sum("key")) + checkAnswer( + left.join(right, left("key") === right("key")), + Row(1, 1, 1, 1) :: Row(2, 1, 2, 2) :: Nil) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index ff31e15e2d..1515e9b843 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -24,7 +24,6 @@ import org.apache.spark.sql.types._ import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, TestSQLContext} import org.apache.spark.sql.test.TestSQLContext.logicalPlanToSparkQuery import org.apache.spark.sql.test.TestSQLContext.implicits._ -import org.apache.spark.sql.test.TestSQLContext.sql class DataFrameSuite extends QueryTest { @@ -118,44 +117,6 @@ class DataFrameSuite extends QueryTest { ) } - test("join - join using") { - val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str") - val df2 = Seq(1, 2, 3).map(i => (i, (i + 1).toString)).toDF("int", "str") - - checkAnswer( - df.join(df2, "int"), - Row(1, "1", "2") :: Row(2, "2", "3") :: Row(3, "3", "4") :: Nil) - } - - test("join - join using self join") { - val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str") - - // self join - checkAnswer( - 
df.join(df, "int"), - Row(1, "1", "1") :: Row(2, "2", "2") :: Row(3, "3", "3") :: Nil) - } - - test("join - self join") { - val df1 = testData.select(testData("key")).as('df1) - val df2 = testData.select(testData("key")).as('df2) - - checkAnswer( - df1.join(df2, $"df1.key" === $"df2.key"), - sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key").collect().toSeq) - } - - test("join - using aliases after self join") { - val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str") - checkAnswer( - df.as('x).join(df.as('y), $"x.str" === $"y.str").groupBy("x.str").count(), - Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil) - - checkAnswer( - df.as('x).join(df.as('y), $"x.str" === $"y.str").groupBy("y.str").count(), - Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil) - } - test("explode") { val df = Seq((1, "a b c"), (2, "a b"), (3, "a")).toDF("number", "letters") val df2 = -- cgit v1.2.3