author     Liang-Chi Hsieh <viirya@gmail.com>    2015-06-08 23:27:05 -0700
committer  Reynold Xin <rxin@databricks.com>     2015-06-08 23:27:05 -0700
commit     7658eb28a2ea28c06e3b5a26f7734a7dc36edc19 (patch)
tree       223a4bfef4f1abceeb471bb23fc371fe96a6f5a0
parent     a5c52c1a3488b69bec19e460d2d1fdb0c9ada58d (diff)
[SPARK-7990][SQL] Add methods to facilitate equi-join on multiple joining keys
JIRA: https://issues.apache.org/jira/browse/SPARK-7990

Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #6616 from viirya/multi_keys_equi_join and squashes the following commits:

cd5c888 [Liang-Chi Hsieh] Import reduce in python3.
c43722c [Liang-Chi Hsieh] For comments.
0400e89 [Liang-Chi Hsieh] Fix scala style.
cc90015 [Liang-Chi Hsieh] Add methods to facilitate equi-join on multiple joining keys.
-rw-r--r--  python/pyspark/sql/dataframe.py                                        | 45
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala           | 40
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala  |  9
3 files changed, 75 insertions, 19 deletions
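In short, the patch widens the Python `join` parameter `on` to accept a list of column names or a list of Column expressions. A minimal sketch of the new usage, assuming DataFrames `df`, `df3`, and `df4` shaped as in the doctests below:

    # Inner equi-join on multiple named columns (SQL JOIN USING semantics);
    # "name" and "age" must exist on both sides and appear once in the output.
    df.join(df4, ["name", "age"])

    # The same multi-key idea written as a list of Column expressions; the
    # list is AND-ed into a single condition before the join is dispatched.
    cond = [df.name == df3.name, df.age == df3.age]
    df.join(df3, cond, "outer")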
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 2d8c59518b..e9dd05e2d0 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -22,6 +22,7 @@ import random
if sys.version >= '3':
basestring = unicode = str
long = int
+ from functools import reduce
else:
from itertools import imap as map
@@ -503,36 +504,52 @@ class DataFrame(object):
@ignore_unicode_prefix
@since(1.3)
- def join(self, other, joinExprs=None, joinType=None):
+ def join(self, other, on=None, how=None):
"""Joins with another :class:`DataFrame`, using the given join expression.
The following performs a full outer join between ``df`` and ``df2``.
:param other: Right side of the join
- :param joinExprs: a string for join column name, or a join expression (Column).
- If joinExprs is a string indicating the name of the join column,
- the column must exist on both sides, and this performs an inner equi-join.
- :param joinType: str, default 'inner'.
+ :param on: a string for the join column name, a list of column names,
+ a join expression (Column), or a list of Columns.
+ If `on` is a string or a list of strings indicating the name of the join column(s),
+ the column(s) must exist on both sides, and this performs an inner equi-join.
+ :param how: str, default 'inner'.
One of `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`.
>>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect()
[Row(name=None, height=80), Row(name=u'Alice', height=None), Row(name=u'Bob', height=85)]
+ >>> cond = [df.name == df3.name, df.age == df3.age]
+ >>> df.join(df3, cond, 'outer').select(df.name, df3.age).collect()
+ [Row(name=u'Bob', age=5), Row(name=u'Alice', age=2)]
+
>>> df.join(df2, 'name').select(df.name, df2.height).collect()
[Row(name=u'Bob', height=85)]
+
+ >>> df.join(df4, ['name', 'age']).select(df.name, df.age).collect()
+ [Row(name=u'Bob', age=5)]
"""
- if joinExprs is None:
+ if on is not None and not isinstance(on, list):
+ on = [on]
+
+ if on is None or len(on) == 0:
jdf = self._jdf.join(other._jdf)
- elif isinstance(joinExprs, basestring):
- jdf = self._jdf.join(other._jdf, joinExprs)
+
+ elif isinstance(on[0], basestring):
+ jdf = self._jdf.join(other._jdf, self._jseq(on))
else:
- assert isinstance(joinExprs, Column), "joinExprs should be Column"
- if joinType is None:
- jdf = self._jdf.join(other._jdf, joinExprs._jc)
+ assert isinstance(on[0], Column), "on should be Column or list of Column"
+ if len(on) > 1:
+ on = reduce(lambda x, y: x.__and__(y), on)
+ else:
+ on = on[0]
+ if how is None:
+ jdf = self._jdf.join(other._jdf, on._jc, "inner")
else:
- assert isinstance(joinType, basestring), "joinType should be basestring"
- jdf = self._jdf.join(other._jdf, joinExprs._jc, joinType)
+ assert isinstance(how, basestring), "how should be basestring"
+ jdf = self._jdf.join(other._jdf, on._jc, how)
return DataFrame(jdf, self.sql_ctx)
@ignore_unicode_prefix
@@ -1315,6 +1332,8 @@ def _test():
.toDF(StructType([StructField('age', IntegerType()),
StructField('name', StringType())]))
globs['df2'] = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]).toDF()
+ globs['df3'] = sc.parallelize([Row(name='Alice', age=2),
+ Row(name='Bob', age=5)]).toDF()
globs['df4'] = sc.parallelize([Row(name='Alice', age=10, height=80),
Row(name='Bob', age=5, height=None),
Row(name='Tom', age=None, height=None),
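The `reduce(lambda x, y: x.__and__(y), on)` call above folds a list of Columns into one conjoined condition. A self-contained toy sketch of that fold, where `Cond` is a hypothetical stand-in for pyspark's Column (which likewise combines via `__and__`, i.e. the `&` operator):

    from functools import reduce

    class Cond(object):
        # Hypothetical stand-in for pyspark.sql.Column: combining two
        # conditions with & yields a new AND-ed condition.
        def __init__(self, text):
            self.text = text

        def __and__(self, other):
            return Cond("(%s AND %s)" % (self.text, other.text))

    conds = [Cond("df.name = df3.name"), Cond("df.age = df3.age")]
    combined = reduce(lambda x, y: x & y, conds)
    print(combined.text)  # (df.name = df3.name AND df.age = df3.age)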
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 4a224153e1..59f64dd4bc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -395,22 +395,50 @@ class DataFrame private[sql](
* @since 1.4.0
*/
def join(right: DataFrame, usingColumn: String): DataFrame = {
+ join(right, Seq(usingColumn))
+ }
+
+ /**
+ * Inner equi-join with another [[DataFrame]] using the given columns.
+ *
+ * Different from other join functions, the join columns will only appear once in the output,
+ * i.e. similar to SQL's `JOIN USING` syntax.
+ *
+ * {{{
+ * // Joining df1 and df2 using the columns "user_id" and "user_name"
+ * df1.join(df2, Seq("user_id", "user_name"))
+ * }}}
+ *
+ * Note that if you perform a self-join using this function without aliasing the input
+ * [[DataFrame]]s, you will NOT be able to reference any columns after the join, since
+ * there is no way to disambiguate which side of the join you would like to reference.
+ *
+ * @param right Right side of the join operation.
+ * @param usingColumns Names of the columns to join on. These columns must exist on both sides.
+ * @group dfops
+ * @since 1.4.0
+ */
+ def join(right: DataFrame, usingColumns: Seq[String]): DataFrame = {
// Analyze the self join. The assumption is that the analyzer will disambiguate left vs right
// by creating a new instance for one of the branches.
val joined = sqlContext.executePlan(
Join(logicalPlan, right.logicalPlan, joinType = Inner, None)).analyzed.asInstanceOf[Join]
- // Project only one of the join column.
- val joinedCol = joined.right.resolve(usingColumn)
+ // Project only one of the join columns.
+ val joinedCols = usingColumns.map(col => joined.right.resolve(col))
+ val condition = usingColumns.map { col =>
+ catalyst.expressions.EqualTo(joined.left.resolve(col), joined.right.resolve(col))
+ }.reduceLeftOption[catalyst.expressions.BinaryExpression] { (cond, eqTo) =>
+ catalyst.expressions.And(cond, eqTo)
+ }
+
Project(
- joined.output.filterNot(_ == joinedCol),
+ joined.output.filterNot(joinedCols.contains(_)),
Join(
joined.left,
joined.right,
joinType = Inner,
- Some(catalyst.expressions.EqualTo(
- joined.left.resolve(usingColumn),
- joined.right.resolve(usingColumn))))
+ condition)
)
}
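On the Scala side, `reduceLeftOption` builds the join condition only when `usingColumns` is non-empty; an empty sequence folds to `None`, i.e. a join with no condition. A small Python sketch of that pattern (the helper name `reduce_left_option` is illustrative, not Spark API):

    from functools import reduce

    def reduce_left_option(items, f):
        # Mimics Scala's reduceLeftOption: None for an empty sequence,
        # otherwise a left fold of f over the elements.
        return reduce(f, items) if items else None

    eqs = ["left.int = right.int", "left.int2 = right.int2"]
    print(reduce_left_option(eqs, lambda a, b: "(%s AND %s)" % (a, b)))
    # (left.int = right.int AND left.int2 = right.int2)
    print(reduce_left_option([], lambda a, b: a))  # None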
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
index 051d13e9a5..6165764632 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
@@ -34,6 +34,15 @@ class DataFrameJoinSuite extends QueryTest {
Row(1, "1", "2") :: Row(2, "2", "3") :: Row(3, "3", "4") :: Nil)
}
+ test("join - join using multiple columns") {
+ val df = Seq(1, 2, 3).map(i => (i, i + 1, i.toString)).toDF("int", "int2", "str")
+ val df2 = Seq(1, 2, 3).map(i => (i, i + 1, (i + 1).toString)).toDF("int", "int2", "str")
+
+ checkAnswer(
+ df.join(df2, Seq("int", "int2")),
+ Row(1, 2, "1", "2") :: Row(2, 3, "2", "3") :: Row(3, 4, "3", "4") :: Nil)
+ }
+
test("join - join using self join") {
val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str")
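For reference, the scenario of the new multi-column test expressed through the Python API; a sketch assuming an active SQLContext bound to `sqlContext`:

    # Build the two test DataFrames (column layout as in DataFrameJoinSuite).
    df = sqlContext.createDataFrame(
        [(i, i + 1, str(i)) for i in [1, 2, 3]], ["int", "int2", "str"])
    df2 = sqlContext.createDataFrame(
        [(i, i + 1, str(i + 1)) for i in [1, 2, 3]], ["int", "int2", "str"])

    # JOIN USING semantics: "int" and "int2" appear only once in the result.
    df.join(df2, ["int", "int2"]).collect()
    # Rows: (1, 2, "1", "2"), (2, 3, "2", "3"), (3, 4, "3", "4")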