author     Davies Liu <davies@databricks.com>   2015-12-17 08:04:11 -0800
committer  Davies Liu <davies.liu@gmail.com>    2015-12-17 08:04:11 -0800
commit     a170d34a1b309fecc76d1370063e0c4f44dc2142 (patch)
tree       b0371b08ebe08ec569e67072b29e74fd04bffc47
parent     cd3d937b0cc89cb5e4098f7d9f5db2712e3de71e (diff)
[SPARK-12395] [SQL] fix resulting columns of outer join
For the API DataFrame.join(right, usingColumns, joinType), if the joinType is right_outer or full_outer, the resulting join columns could be wrong (they came back null). The order of the output columns has been changed to match that of MySQL and PostgreSQL [1]. This PR also fixes the nullability of the output columns for outer joins.

[1] http://www.postgresql.org/docs/9.2/static/queries-table-expressions.html

Author: Davies Liu <davies@databricks.com>

Closes #10353 from davies/fix_join.
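To make the symptom concrete, here is a minimal sketch of the affected API (a hypothetical spark-shell session against a 1.6-era SQLContext, not taken from the patch; the data and column names are illustrative only):

    import sqlContext.implicits._   // sqlContext is provided by spark-shell

    val left  = Seq((1, "a"), (2, "b")).toDF("k", "l")
    val right = Seq((2, "x"), (3, "y")).toDF("k", "r")

    // With this fix, the join column "k" of a full outer join is
    // COALESCE(left.k, right.k), so the key survives even when one side
    // has no matching row; before the fix it could come back null.
    left.join(right, Seq("k"), "outer").show()
    // expected rows (schema: k, l, r):
    //   (1, "a", null), (2, "b", "x"), (3, null, "y")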
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala           | 25
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala  | 20
2 files changed, 36 insertions(+), 9 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 6250e95216..d741312314 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
+import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser}
import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, FileRelation, LogicalRDD, QueryExecution, Queryable, SQLExecution}
import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation}
@@ -455,10 +455,8 @@ class DataFrame private[sql](
// Analyze the self join. The assumption is that the analyzer will disambiguate left vs right
// by creating a new instance for one of the branches.
val joined = sqlContext.executePlan(
- Join(logicalPlan, right.logicalPlan, joinType = Inner, None)).analyzed.asInstanceOf[Join]
+ Join(logicalPlan, right.logicalPlan, JoinType(joinType), None)).analyzed.asInstanceOf[Join]
- // Project only one of the join columns.
- val joinedCols = usingColumns.map(col => withPlan(joined.right).resolve(col))
val condition = usingColumns.map { col =>
catalyst.expressions.EqualTo(
withPlan(joined.left).resolve(col),
@@ -467,9 +465,26 @@ class DataFrame private[sql](
catalyst.expressions.And(cond, eqTo)
}
+ // Project only one of the join columns.
+ val joinedCols = JoinType(joinType) match {
+ case Inner | LeftOuter | LeftSemi =>
+ usingColumns.map(col => withPlan(joined.left).resolve(col))
+ case RightOuter =>
+ usingColumns.map(col => withPlan(joined.right).resolve(col))
+ case FullOuter =>
+ usingColumns.map { col =>
+ val leftCol = withPlan(joined.left).resolve(col)
+ val rightCol = withPlan(joined.right).resolve(col)
+ Alias(Coalesce(Seq(leftCol, rightCol)), col)()
+ }
+ }
+ // The nullability of the joined output can differ from that of the
+ // original columns, so we can only compare them by exprId.
+ val joinRefs = condition.map(_.references.toSeq.map(_.exprId)).getOrElse(Nil)
+ val resultCols = joinedCols ++ joined.output.filterNot(e => joinRefs.contains(e.exprId))
withPlan {
Project(
- joined.output.filterNot(joinedCols.contains(_)),
+ resultCols,
Join(
joined.left,
joined.right,
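The exprId comparison in the hunk above is the subtle part. A rough sketch of why plain attribute equality is not enough (illustrative only, assuming Catalyst's internals are on the classpath; not part of the patch):

    import org.apache.spark.sql.catalyst.expressions.AttributeReference
    import org.apache.spark.sql.types.IntegerType

    // A non-nullable column as resolved from the join's child plan ...
    val col = AttributeReference("int", IntegerType, nullable = false)()
    // ... and the same column as emitted by an outer join, which the
    // analyzer marks nullable.
    val afterOuter = col.withNullability(true)

    col == afterOuter                // false: equality includes nullability
    col.exprId == afterOuter.exprId  // true: still the same logical column

This is why the old projection's joined.output.filterNot(joinedCols.contains(_)) could fail to drop the duplicate join columns after an outer join.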
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
index c70397f985..39a65413bd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
@@ -43,16 +43,28 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext {
}
test("join - join using multiple columns and specifying join type") {
- val df = Seq(1, 2, 3).map(i => (i, i + 1, i.toString)).toDF("int", "int2", "str")
- val df2 = Seq(1, 2, 3).map(i => (i, i + 1, (i + 1).toString)).toDF("int", "int2", "str")
+ val df = Seq((1, 2, "1"), (3, 4, "3")).toDF("int", "int2", "str")
+ val df2 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int", "int2", "str")
+
+ checkAnswer(
+ df.join(df2, Seq("int", "str"), "inner"),
+ Row(1, "1", 2, 3) :: Nil)
checkAnswer(
df.join(df2, Seq("int", "str"), "left"),
- Row(1, 2, "1", null) :: Row(2, 3, "2", null) :: Row(3, 4, "3", null) :: Nil)
+ Row(1, "1", 2, 3) :: Row(3, "3", 4, null) :: Nil)
checkAnswer(
df.join(df2, Seq("int", "str"), "right"),
- Row(null, null, null, 2) :: Row(null, null, null, 3) :: Row(null, null, null, 4) :: Nil)
+ Row(1, "1", 2, 3) :: Row(5, "5", null, 6) :: Nil)
+
+ checkAnswer(
+ df.join(df2, Seq("int", "str"), "outer"),
+ Row(1, "1", 2, 3) :: Row(3, "3", 4, null) :: Row(5, "5", null, 6) :: Nil)
+
+ checkAnswer(
+ df.join(df2, Seq("int", "str"), "left_semi"),
+ Row(1, "1", 2) :: Nil)
}
test("join - join using self join") {