diff options
author | Sheamus K. Parkes <shea.parkes@milliman.com> | 2017-02-14 09:57:43 -0800 |
---|---|---|
committer | Holden Karau <holden@us.ibm.com> | 2017-02-14 09:57:43 -0800 |
commit | 7b64f7aa03a49adca5fcafe6fff422823b587514 (patch) | |
tree | 07829b371f8008f83f373be0a69954e43cc09d72 /python/pyspark/sql | |
parent | e0eeb0f89fffb52cd4d15970bdf00c3c5d1eea88 (diff) | |
download | spark-7b64f7aa03a49adca5fcafe6fff422823b587514.tar.gz spark-7b64f7aa03a49adca5fcafe6fff422823b587514.tar.bz2 spark-7b64f7aa03a49adca5fcafe6fff422823b587514.zip |
[SPARK-18541][PYTHON] Add metadata parameter to pyspark.sql.Column.alias()
## What changes were proposed in this pull request?
Add a `metadata` keyword parameter to `pyspark.sql.Column.alias()` to allow users to mix-in metadata while manipulating `DataFrame`s in `pyspark`. Without this, I believe it was necessary to pass back through `SparkSession.createDataFrame` each time a user wanted to manipulate `StructField.metadata` in `pyspark`.
This pull request also improves consistency between the Scala and Python APIs (i.e. I did not add any functionality that was not already in the Scala API).
Discussed ahead of time on JIRA with marmbrus
## How was this patch tested?
Added unit tests (and doc tests). Ran the pertinent tests manually.
Author: Sheamus K. Parkes <shea.parkes@milliman.com>
Closes #16094 from shea-parkes/pyspark-column-alias-metadata.
Diffstat (limited to 'python/pyspark/sql')
-rw-r--r-- | python/pyspark/sql/column.py | 26 | ||||
-rw-r--r-- | python/pyspark/sql/tests.py | 10 |
2 files changed, 33 insertions, 3 deletions
diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 73c8672eff..0df187a9d3 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -17,6 +17,7 @@ import sys import warnings +import json if sys.version >= '3': basestring = str @@ -303,19 +304,38 @@ class Column(object): isNotNull = _unary_op("isNotNull", "True if the current expression is not null.") @since(1.3) - def alias(self, *alias): + def alias(self, *alias, **kwargs): """ Returns this column aliased with a new name or names (in the case of expressions that return more than one column, such as explode). + :param alias: strings of desired column names (collects all positional arguments passed) + :param metadata: a dict of information to be stored in ``metadata`` attribute of the + corresponding :class: `StructField` (optional, keyword only argument) + + .. versionchanged:: 2.2 + Added optional ``metadata`` argument. + >>> df.select(df.age.alias("age2")).collect() [Row(age2=2), Row(age2=5)] + >>> df.select(df.age.alias("age3", metadata={'max': 99})).schema['age3'].metadata['max'] + 99 """ + metadata = kwargs.pop('metadata', None) + assert not kwargs, 'Unexpected kwargs where passed: %s' % kwargs + + sc = SparkContext._active_spark_context if len(alias) == 1: - return Column(getattr(self._jc, "as")(alias[0])) + if metadata: + jmeta = sc._jvm.org.apache.spark.sql.types.Metadata.fromJson( + json.dumps(metadata)) + return Column(getattr(self._jc, "as")(alias[0], jmeta)) + else: + return Column(getattr(self._jc, "as")(alias[0])) else: - sc = SparkContext._active_spark_context + if metadata: + raise ValueError('metadata can only be provided for a single column') return Column(getattr(self._jc, "as")(_to_seq(sc, list(alias)))) name = copy_func(alias, sinceversion=2.0, doc=":func:`name` is an alias for :func:`alias`.") diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 73721674f6..62e1a8c363 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -266,6 +266,9 @@ class SQLTests(ReusedPySparkTestCase): self.assertEqual(result[0][0], "a") self.assertEqual(result[0][1], "b") + with self.assertRaises(ValueError): + data.select(explode(data.mapfield).alias("a", "b", metadata={'max': 99})).count() + def test_and_in_expression(self): self.assertEqual(4, self.df.filter((self.df.key <= 10) & (self.df.value <= "2")).count()) self.assertRaises(ValueError, lambda: (self.df.key <= 10) and (self.df.value <= "2")) @@ -895,6 +898,13 @@ class SQLTests(ReusedPySparkTestCase): self.assertEqual(self.testData, df.select(df.key, df.value).collect()) self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect()) + def test_column_alias_metadata(self): + df = self.df + df_with_meta = df.select(df.key.alias('pk', metadata={'label': 'Primary Key'})) + self.assertEqual(df_with_meta.schema['pk'].metadata['label'], 'Primary Key') + with self.assertRaises(AssertionError): + df.select(df.key.alias('pk', metdata={'label': 'Primary Key'})) + def test_freqItems(self): vals = [Row(a=1, b=-2.0) if i % 2 == 0 else Row(a=i, b=i * 1.0) for i in range(100)] df = self.sc.parallelize(vals).toDF() |