diff options
author | Michael Armbrust <michael@databricks.com> | 2015-05-14 19:49:44 -0700 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2015-05-14 19:49:44 -0700 |
commit | 6d0633e3ec9518278fcc7eba58549d4ad3d5813f (patch) | |
tree | 4bd5a50dcd8af0b37c44180f607dcd8a8e5b26da /python | |
parent | 48fc38f5844f6c12bf440f2990b6d7f1630fafac (diff) | |
download | spark-6d0633e3ec9518278fcc7eba58549d4ad3d5813f.tar.gz spark-6d0633e3ec9518278fcc7eba58549d4ad3d5813f.tar.bz2 spark-6d0633e3ec9518278fcc7eba58549d4ad3d5813f.zip |
[SPARK-7548] [SQL] Add explode function for DataFrames
Add an `explode` function for dataframes and modify the analyzer so that single table generating functions can be present in a select clause along with other expressions. There are currently the following restrictions:
- only top level TGFs are allowed (i.e. no `select(explode('list) + 1)`)
- only one may be present in a single select to avoid potentially confusing implicit Cartesian products.
TODO:
- [ ] Python
Author: Michael Armbrust <michael@databricks.com>
Closes #6107 from marmbrus/explodeFunction and squashes the following commits:
7ee2c87 [Michael Armbrust] whitespace
6f80ba3 [Michael Armbrust] Update dataframe.py
c176c89 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into explodeFunction
81b5da3 [Michael Armbrust] style
d3faa05 [Michael Armbrust] fix self join case
f9e1e3e [Michael Armbrust] fix python, add since
4f0d0a9 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into explodeFunction
e710fe4 [Michael Armbrust] add java and python
52ca0dc [Michael Armbrust] [SPARK-7548][SQL] Add explode function for dataframes.
Diffstat (limited to 'python')
-rw-r--r-- | python/pyspark/sql/dataframe.py | 12 | ||||
-rw-r--r-- | python/pyspark/sql/functions.py | 20 | ||||
-rw-r--r-- | python/pyspark/sql/tests.py | 15 |
3 files changed, 44 insertions, 3 deletions
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 82cb1c2fdb..2ed95ac8e2 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1511,13 +1511,19 @@ class Column(object): isNull = _unary_op("isNull", "True if the current expression is null.") isNotNull = _unary_op("isNotNull", "True if the current expression is not null.") - def alias(self, alias): - """Return a alias for this column + def alias(self, *alias): + """Returns this column aliased with a new name or names (in the case of expressions that + return more than one column, such as explode). >>> df.select(df.age.alias("age2")).collect() [Row(age2=2), Row(age2=5)] """ - return Column(getattr(self._jc, "as")(alias)) + + if len(alias) == 1: + return Column(getattr(self._jc, "as")(alias[0])) + else: + sc = SparkContext._active_spark_context + return Column(getattr(self._jc, "as")(_to_seq(sc, list(alias)))) @ignore_unicode_prefix def cast(self, dataType): diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index d91265ee0b..6cd6974b0e 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -169,6 +169,26 @@ def approxCountDistinct(col, rsd=None): return Column(jc) +def explode(col): + """Returns a new row for each element in the given array or map. + + >>> from pyspark.sql import Row + >>> eDF = sqlContext.createDataFrame([Row(a=1, intlist=[1,2,3], mapfield={"a": "b"})]) + >>> eDF.select(explode(eDF.intlist).alias("anInt")).collect() + [Row(anInt=1), Row(anInt=2), Row(anInt=3)] + + >>> eDF.select(explode(eDF.mapfield).alias("key", "value")).show() + +---+-----+ + |key|value| + +---+-----+ + | a| b| + +---+-----+ + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.explode(_to_java_column(col)) + return Column(jc) + + def coalesce(*cols): """Returns the first column that is not null. diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 1922d03af6..d37c5dbed7 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -117,6 +117,21 @@ class SQLTests(ReusedPySparkTestCase): ReusedPySparkTestCase.tearDownClass() shutil.rmtree(cls.tempdir.name, ignore_errors=True) + def test_explode(self): + from pyspark.sql.functions import explode + d = [Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"})] + rdd = self.sc.parallelize(d) + data = self.sqlCtx.createDataFrame(rdd) + + result = data.select(explode(data.intlist).alias("a")).select("a").collect() + self.assertEqual(result[0][0], 1) + self.assertEqual(result[1][0], 2) + self.assertEqual(result[2][0], 3) + + result = data.select(explode(data.mapfield).alias("a", "b")).select("a", "b").collect() + self.assertEqual(result[0][0], "a") + self.assertEqual(result[0][1], "b") + def test_udf_with_callable(self): d = [Row(number=i, squared=i**2) for i in range(10)] rdd = self.sc.parallelize(d) |