path: root/python/pyspark/sql/dataframe.py
author     Liang-Chi Hsieh <viirya@gmail.com>      2015-06-08 23:27:05 -0700
committer  Reynold Xin <rxin@databricks.com>       2015-06-08 23:27:05 -0700
commit     7658eb28a2ea28c06e3b5a26f7734a7dc36edc19 (patch)
tree       223a4bfef4f1abceeb471bb23fc371fe96a6f5a0 /python/pyspark/sql/dataframe.py
parent     a5c52c1a3488b69bec19e460d2d1fdb0c9ada58d (diff)
[SPARK-7990][SQL] Add methods to facilitate equi-join on multiple joining keys
JIRA: https://issues.apache.org/jira/browse/SPARK-7990

Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #6616 from viirya/multi_keys_equi_join and squashes the following commits:

cd5c888 [Liang-Chi Hsieh] Import reduce in python3.
c43722c [Liang-Chi Hsieh] For comments.
0400e89 [Liang-Chi Hsieh] Fix scala style.
cc90015 [Liang-Chi Hsieh] Add methods to facilitate equi-join on multiple joining keys.
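For reference, a minimal sketch of how the new `on` parameter can be used after this change. It assumes an existing SQLContext named sqlContext; the sample data and variable names are illustrative only and are not part of the patch:

    from pyspark.sql import Row

    # Two DataFrames that share the 'name' and 'age' columns.
    left = sqlContext.createDataFrame([Row(name='Alice', age=2), Row(name='Bob', age=5)])
    right = sqlContext.createDataFrame([Row(name='Bob', age=5, height=85)])

    # Equi-join on multiple keys by passing a list of column names;
    # the columns must exist on both sides.
    left.join(right, ['name', 'age']).collect()

    # Equivalent join expressed as a list of Column expressions,
    # which the patch combines with AND before calling the JVM side.
    cond = [left.name == right.name, left.age == right.age]
    left.join(right, cond, 'inner').collect()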
Diffstat (limited to 'python/pyspark/sql/dataframe.py')
-rw-r--r--  python/pyspark/sql/dataframe.py  45
1 file changed, 32 insertions, 13 deletions
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 2d8c59518b..e9dd05e2d0 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -22,6 +22,7 @@ import random
 if sys.version >= '3':
     basestring = unicode = str
     long = int
+    from functools import reduce
 else:
     from itertools import imap as map
@@ -503,36 +504,52 @@ class DataFrame(object):
     @ignore_unicode_prefix
     @since(1.3)
-    def join(self, other, joinExprs=None, joinType=None):
+    def join(self, other, on=None, how=None):
         """Joins with another :class:`DataFrame`, using the given join expression.

         The following performs a full outer join between ``df1`` and ``df2``.

         :param other: Right side of the join
-        :param joinExprs: a string for join column name, or a join expression (Column).
-            If joinExprs is a string indicating the name of the join column,
-            the column must exist on both sides, and this performs an inner equi-join.
-        :param joinType: str, default 'inner'.
+        :param on: a string for join column name, a list of column names,
+            a join expression (Column), or a list of Columns.
+            If `on` is a string or a list of strings indicating the name of the join column(s),
+            the column(s) must exist on both sides, and this performs an inner equi-join.
+        :param how: str, default 'inner'.
             One of `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`.

         >>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect()
         [Row(name=None, height=80), Row(name=u'Alice', height=None), Row(name=u'Bob', height=85)]

+        >>> cond = [df.name == df3.name, df.age == df3.age]
+        >>> df.join(df3, cond, 'outer').select(df.name, df3.age).collect()
+        [Row(name=u'Bob', age=5), Row(name=u'Alice', age=2)]
+
         >>> df.join(df2, 'name').select(df.name, df2.height).collect()
         [Row(name=u'Bob', height=85)]
+
+        >>> df.join(df4, ['name', 'age']).select(df.name, df.age).collect()
+        [Row(name=u'Bob', age=5)]
         """
-        if joinExprs is None:
+        if on is not None and not isinstance(on, list):
+            on = [on]
+
+        if on is None or len(on) == 0:
             jdf = self._jdf.join(other._jdf)
-        elif isinstance(joinExprs, basestring):
-            jdf = self._jdf.join(other._jdf, joinExprs)
+        elif isinstance(on[0], basestring):
+            jdf = self._jdf.join(other._jdf, self._jseq(on))
         else:
-            assert isinstance(joinExprs, Column), "joinExprs should be Column"
-            if joinType is None:
-                jdf = self._jdf.join(other._jdf, joinExprs._jc)
+            assert isinstance(on[0], Column), "on should be Column or list of Column"
+            if len(on) > 1:
+                on = reduce(lambda x, y: x.__and__(y), on)
+            else:
+                on = on[0]
+            if how is None:
+                jdf = self._jdf.join(other._jdf, on._jc, "inner")
             else:
-                assert isinstance(joinType, basestring), "joinType should be basestring"
-                jdf = self._jdf.join(other._jdf, joinExprs._jc, joinType)
+                assert isinstance(how, basestring), "how should be basestring"
+                jdf = self._jdf.join(other._jdf, on._jc, how)
         return DataFrame(jdf, self.sql_ctx)

     @ignore_unicode_prefix
@@ -1315,6 +1332,8 @@ def _test():
         .toDF(StructType([StructField('age', IntegerType()),
                           StructField('name', StringType())]))
     globs['df2'] = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]).toDF()
+    globs['df3'] = sc.parallelize([Row(name='Alice', age=2),
+                                   Row(name='Bob', age=5)]).toDF()
     globs['df4'] = sc.parallelize([Row(name='Alice', age=10, height=80),
                                    Row(name='Bob', age=5, height=None),
                                    Row(name='Tom', age=None, height=None),
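The core of the Python-side change is collapsing a list of Column conditions into a single conjunction before handing it to the JVM, which is why `reduce` is now imported from functools on Python 3. Below is a standalone sketch of that reduction pattern; it uses a hypothetical stand-in class instead of pyspark.sql.Column so it runs without Spark:

    from functools import reduce  # available in functools on both Python 2.6+ and Python 3


    class Cond(object):
        """Hypothetical stand-in for pyspark.sql.Column, just to show the shape."""
        def __init__(self, expr):
            self.expr = expr

        def __and__(self, other):
            # Mirrors Column.__and__, which builds an AND expression.
            return Cond('(%s AND %s)' % (self.expr, other.expr))


    conds = [Cond('a.name = b.name'), Cond('a.age = b.age')]
    # Same reduction the patch applies when `on` is a list of Columns.
    combined = reduce(lambda x, y: x.__and__(y), conds)
    print(combined.expr)  # (a.name = b.name AND a.age = b.age)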