diff options
author | Liang-Chi Hsieh <viirya@gmail.com> | 2015-06-08 23:27:05 -0700 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2015-06-08 23:27:05 -0700 |
commit | 7658eb28a2ea28c06e3b5a26f7734a7dc36edc19 (patch) | |
tree | 223a4bfef4f1abceeb471bb23fc371fe96a6f5a0 /python | |
parent | a5c52c1a3488b69bec19e460d2d1fdb0c9ada58d (diff) | |
download | spark-7658eb28a2ea28c06e3b5a26f7734a7dc36edc19.tar.gz spark-7658eb28a2ea28c06e3b5a26f7734a7dc36edc19.tar.bz2 spark-7658eb28a2ea28c06e3b5a26f7734a7dc36edc19.zip |
[SPARK-7990][SQL] Add methods to facilitate equi-join on multiple joining keys
JIRA: https://issues.apache.org/jira/browse/SPARK-7990
Author: Liang-Chi Hsieh <viirya@gmail.com>
Closes #6616 from viirya/multi_keys_equi_join and squashes the following commits:
cd5c888 [Liang-Chi Hsieh] Import reduce in python3.
c43722c [Liang-Chi Hsieh] For comments.
0400e89 [Liang-Chi Hsieh] Fix scala style.
cc90015 [Liang-Chi Hsieh] Add methods to facilitate equi-join on multiple joining keys.
Diffstat (limited to 'python')
-rw-r--r-- | python/pyspark/sql/dataframe.py | 45 |
1 files changed, 32 insertions, 13 deletions
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 2d8c59518b..e9dd05e2d0 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -22,6 +22,7 @@ import random if sys.version >= '3': basestring = unicode = str long = int + from functools import reduce else: from itertools import imap as map @@ -503,36 +504,52 @@ class DataFrame(object): @ignore_unicode_prefix @since(1.3) - def join(self, other, joinExprs=None, joinType=None): + def join(self, other, on=None, how=None): """Joins with another :class:`DataFrame`, using the given join expression. The following performs a full outer join between ``df1`` and ``df2``. :param other: Right side of the join - :param joinExprs: a string for join column name, or a join expression (Column). - If joinExprs is a string indicating the name of the join column, - the column must exist on both sides, and this performs an inner equi-join. - :param joinType: str, default 'inner'. + :param on: a string for join column name, a list of column names, + , a join expression (Column) or a list of Columns. + If `on` is a string or a list of string indicating the name of the join column(s), + the column(s) must exist on both sides, and this performs an inner equi-join. + :param how: str, default 'inner'. One of `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`. >>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect() [Row(name=None, height=80), Row(name=u'Alice', height=None), Row(name=u'Bob', height=85)] + >>> cond = [df.name == df3.name, df.age == df3.age] + >>> df.join(df3, cond, 'outer').select(df.name, df3.age).collect() + [Row(name=u'Bob', age=5), Row(name=u'Alice', age=2)] + >>> df.join(df2, 'name').select(df.name, df2.height).collect() [Row(name=u'Bob', height=85)] + + >>> df.join(df4, ['name', 'age']).select(df.name, df.age).collect() + [Row(name=u'Bob', age=5)] """ - if joinExprs is None: + if on is not None and not isinstance(on, list): + on = [on] + + if on is None or len(on) == 0: jdf = self._jdf.join(other._jdf) - elif isinstance(joinExprs, basestring): - jdf = self._jdf.join(other._jdf, joinExprs) + + if isinstance(on[0], basestring): + jdf = self._jdf.join(other._jdf, self._jseq(on)) else: - assert isinstance(joinExprs, Column), "joinExprs should be Column" - if joinType is None: - jdf = self._jdf.join(other._jdf, joinExprs._jc) + assert isinstance(on[0], Column), "on should be Column or list of Column" + if len(on) > 1: + on = reduce(lambda x, y: x.__and__(y), on) + else: + on = on[0] + if how is None: + jdf = self._jdf.join(other._jdf, on._jc, "inner") else: - assert isinstance(joinType, basestring), "joinType should be basestring" - jdf = self._jdf.join(other._jdf, joinExprs._jc, joinType) + assert isinstance(how, basestring), "how should be basestring" + jdf = self._jdf.join(other._jdf, on._jc, how) return DataFrame(jdf, self.sql_ctx) @ignore_unicode_prefix @@ -1315,6 +1332,8 @@ def _test(): .toDF(StructType([StructField('age', IntegerType()), StructField('name', StringType())])) globs['df2'] = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]).toDF() + globs['df3'] = sc.parallelize([Row(name='Alice', age=2), + Row(name='Bob', age=5)]).toDF() globs['df4'] = sc.parallelize([Row(name='Alice', age=10, height=80), Row(name='Bob', age=5, height=None), Row(name='Tom', age=None, height=None), |