-rw-r--r--  python/pyspark/sql/dataframe.py                                     9
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala       37
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala  40
3 files changed, 74 insertions(+), 12 deletions(-)
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index ca9bf8efb9..c8c30ce402 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -459,16 +459,23 @@ class DataFrame(object):
The following performs a full outer join between ``df1`` and ``df2``.
:param other: Right side of the join
- :param joinExprs: Join expression
+ :param joinExprs: a string for the join column name, or a join expression (Column).
+ If joinExprs is a string indicating the name of the join column,
+ the column must exist on both sides, and this performs an inner equi-join.
:param joinType: str, default 'inner'.
One of `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`.
>>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect()
[Row(name=None, height=80), Row(name=u'Alice', height=None), Row(name=u'Bob', height=85)]
+
+ >>> df.join(df2, 'name').select(df.name, df2.height).collect()
+ [Row(name=u'Bob', height=85)]
"""
if joinExprs is None:
jdf = self._jdf.join(other._jdf)
+ elif isinstance(joinExprs, basestring):
+ jdf = self._jdf.join(other._jdf, joinExprs)
else:
assert isinstance(joinExprs, Column), "joinExprs should be Column"
if joinType is None:
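On the JVM side, each Python call form above forwards to a DataFrame.join overload via self._jdf. A minimal Scala sketch of those three underlying calls (hypothetical data and local-mode setup, not part of this patch):

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

object JoinDispatchSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("join-dispatch-sketch"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    val df  = Seq(("Alice", 2), ("Bob", 5)).toDF("name", "age")
    val df2 = Seq(("Bob", 85), ("Tom", 80)).toDF("name", "height")

    df.join(df2).show()                                       // joinExprs is None: cartesian join
    df.join(df2, "name").show()                               // string: inner equi-join (added in this patch)
    df.join(df2, df("name") === df2("name"), "outer").show()  // Column expression plus explicit join type

    sc.stop()
  }
}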
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 03d9834d1d..ca6ae482eb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -343,6 +343,43 @@ class DataFrame private[sql](
}
/**
+ * Inner equi-join with another [[DataFrame]] using the given column.
+ *
+ * Unlike other join functions, the join column appears only once in the output,
+ * similar to SQL's `JOIN USING` syntax.
+ *
+ * {{{
+ * // Joining df1 and df2 using the column "user_id"
+ * df1.join(df2, "user_id")
+ * }}}
+ *
+ * Note that if you perform a self-join using this function without aliasing the input
+ * [[DataFrame]]s, you will NOT be able to reference any columns after the join, since
+ * there is no way to disambiguate which side of the join you would like to reference.
+ *
+ * @param right Right side of the join operation.
+ * @param usingColumn Name of the column to join on. This column must exist on both sides.
+ * @group dfops
+ */
+ def join(right: DataFrame, usingColumn: String): DataFrame = {
+ // Analyze the self-join. The assumption is that the analyzer will disambiguate left vs. right
+ // by creating a new instance for one of the branches.
+ val joined = sqlContext.executePlan(
+ Join(logicalPlan, right.logicalPlan, joinType = Inner, None)).analyzed.asInstanceOf[Join]
+
+ // Project only one of the join columns.
+ val joinedCol = joined.right.resolve(usingColumn)
+ Project(
+ joined.output.filterNot(_ == joinedCol),
+ Join(
+ joined.left,
+ joined.right,
+ joinType = Inner,
+ Some(EqualTo(joined.left.resolve(usingColumn), joined.right.resolve(usingColumn))))
+ )
+ }
+
+ /**
* Inner join with another [[DataFrame]], using the given join expression.
*
* {{{
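The new overload is, in effect, sugar for an inner equi-join followed by a projection that drops the right-hand copy of the join column. A minimal sketch of that expansion (hypothetical column names, local-mode setup assumed):

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

object JoinUsingExpansionSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("join-using-sketch"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    val df1 = Seq((1, "a"), (2, "b")).toDF("user_id", "left_val")
    val df2 = Seq((1, "x"), (3, "y")).toDF("user_id", "right_val")

    // New API: "user_id" appears exactly once in the output schema.
    df1.join(df2, "user_id").show()

    // Roughly equivalent expansion using the pre-existing expression join:
    // join on equality, then project away df2's copy of the join column.
    df1.join(df2, df1("user_id") === df2("user_id"))
      .select(df1("user_id"), df1("left_val"), df2("right_val"))
      .show()

    sc.stop()
  }
}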
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index b9b6a400ae..5ec06d448e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -109,15 +109,6 @@ class DataFrameSuite extends QueryTest {
assert(testData.head(2).head.schema === testData.schema)
}
- test("self join") {
- val df1 = testData.select(testData("key")).as('df1)
- val df2 = testData.select(testData("key")).as('df2)
-
- checkAnswer(
- df1.join(df2, $"df1.key" === $"df2.key"),
- sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key").collect().toSeq)
- }
-
test("simple explode") {
val df = Seq(Tuple1("a b c"), Tuple1("d e")).toDF("words")
@@ -127,8 +118,35 @@ class DataFrameSuite extends QueryTest {
)
}
- test("self join with aliases") {
- val df = Seq(1,2,3).map(i => (i, i.toString)).toDF("int", "str")
+ test("join - join using") {
+ val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str")
+ val df2 = Seq(1, 2, 3).map(i => (i, (i + 1).toString)).toDF("int", "str")
+
+ checkAnswer(
+ df.join(df2, "int"),
+ Row(1, "1", "2") :: Row(2, "2", "3") :: Row(3, "3", "4") :: Nil)
+ }
+
+ test("join - join using self join") {
+ val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str")
+
+ // self join
+ checkAnswer(
+ df.join(df, "int"),
+ Row(1, "1", "1") :: Row(2, "2", "2") :: Row(3, "3", "3") :: Nil)
+ }
+
+ test("join - self join") {
+ val df1 = testData.select(testData("key")).as('df1)
+ val df2 = testData.select(testData("key")).as('df2)
+
+ checkAnswer(
+ df1.join(df2, $"df1.key" === $"df2.key"),
+ sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key").collect().toSeq)
+ }
+
+ test("join - using aliases after self join") {
+ val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str")
checkAnswer(
df.as('x).join(df.as('y), $"x.str" === $"y.str").groupBy("x.str").count(),
Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil)