author     Sheamus K. Parkes <shea.parkes@milliman.com>   2017-02-14 09:57:43 -0800
committer  Holden Karau <holden@us.ibm.com>               2017-02-14 09:57:43 -0800
commit     7b64f7aa03a49adca5fcafe6fff422823b587514 (patch)
tree       07829b371f8008f83f373be0a69954e43cc09d72 /python/pyspark/sql
parent     e0eeb0f89fffb52cd4d15970bdf00c3c5d1eea88 (diff)
[SPARK-18541][PYTHON] Add metadata parameter to pyspark.sql.Column.alias()
## What changes were proposed in this pull request?

Add a `metadata` keyword parameter to `pyspark.sql.Column.alias()` to allow users to mix in metadata while manipulating `DataFrame`s in `pyspark`. Without this, I believe it was necessary to pass back through `SparkSession.createDataFrame` each time a user wanted to manipulate `StructField.metadata` in `pyspark`. This pull request also improves consistency between the Scala and Python APIs (i.e. I did not add any functionality that was not already in the Scala API). Discussed ahead of time on JIRA with marmbrus.

## How was this patch tested?

Added unit tests (and doc tests). Ran the pertinent tests manually.

Author: Sheamus K. Parkes <shea.parkes@milliman.com>

Closes #16094 from shea-parkes/pyspark-column-alias-metadata.
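For context, a minimal sketch of the call shape this commit enables (assuming a live `SparkSession` named `spark`; the two-row `age` DataFrame mirrors the one used in the doctests below):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(2, 'Alice'), (5, 'Bob')], ['age', 'name'])

# Attach metadata to the aliased column in one step; it surfaces on the
# resulting StructField instead of requiring a schema rebuild.
df2 = df.select(df.age.alias('age2', metadata={'max': 99}))
print(df2.schema['age2'].metadata)  # {'max': 99}
```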
Diffstat (limited to 'python/pyspark/sql')
-rw-r--r--  python/pyspark/sql/column.py  26
-rw-r--r--  python/pyspark/sql/tests.py   10
2 files changed, 33 insertions, 3 deletions
diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py
index 73c8672eff..0df187a9d3 100644
--- a/python/pyspark/sql/column.py
+++ b/python/pyspark/sql/column.py
@@ -17,6 +17,7 @@
 
 import sys
 import warnings
+import json
 
 if sys.version >= '3':
     basestring = str
@@ -303,19 +304,38 @@ class Column(object):
     isNotNull = _unary_op("isNotNull", "True if the current expression is not null.")
 
     @since(1.3)
-    def alias(self, *alias):
+    def alias(self, *alias, **kwargs):
         """
         Returns this column aliased with a new name or names (in the case of expressions that
         return more than one column, such as explode).
 
+        :param alias: strings of desired column names (collects all positional arguments passed)
+        :param metadata: a dict of information to be stored in ``metadata`` attribute of the
+            corresponding :class:`StructField` (optional, keyword only argument)
+
+        .. versionchanged:: 2.2
+           Added optional ``metadata`` argument.
+
         >>> df.select(df.age.alias("age2")).collect()
         [Row(age2=2), Row(age2=5)]
+        >>> df.select(df.age.alias("age3", metadata={'max': 99})).schema['age3'].metadata['max']
+        99
         """
+        metadata = kwargs.pop('metadata', None)
+        assert not kwargs, 'Unexpected kwargs were passed: %s' % kwargs
+
+        sc = SparkContext._active_spark_context
         if len(alias) == 1:
-            return Column(getattr(self._jc, "as")(alias[0]))
+            if metadata:
+                jmeta = sc._jvm.org.apache.spark.sql.types.Metadata.fromJson(
+                    json.dumps(metadata))
+                return Column(getattr(self._jc, "as")(alias[0], jmeta))
+            else:
+                return Column(getattr(self._jc, "as")(alias[0]))
         else:
-            sc = SparkContext._active_spark_context
+            if metadata:
+                raise ValueError('metadata can only be provided for a single column')
             return Column(getattr(self._jc, "as")(_to_seq(sc, list(alias))))
 
     name = copy_func(alias, sinceversion=2.0, doc=":func:`name` is an alias for :func:`alias`.")
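The implementation serializes the Python dict to JSON and rehydrates it on the JVM side via `Metadata.fromJson`, so the supplied dict must be JSON-serializable. For comparison, a hedged sketch of the pre-patch workaround the commit message alludes to, rebuilding the DataFrame with an explicit schema (reusing the `spark` and `df` assumed in the earlier example):

```python
from pyspark.sql.types import StructType, StructField, LongType

# Rebuild the DataFrame with a hand-built schema just to set metadata,
# the roundabout route that alias(..., metadata=...) now replaces.
schema = StructType([
    StructField('age2', LongType(), True, metadata={'max': 99}),
])
df_meta = spark.createDataFrame(df.select(df.age).rdd, schema)
print(df_meta.schema['age2'].metadata['max'])  # 99
```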
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 73721674f6..62e1a8c363 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -266,6 +266,9 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertEqual(result[0][0], "a")
         self.assertEqual(result[0][1], "b")
 
+        with self.assertRaises(ValueError):
+            data.select(explode(data.mapfield).alias("a", "b", metadata={'max': 99})).count()
+
     def test_and_in_expression(self):
         self.assertEqual(4, self.df.filter((self.df.key <= 10) & (self.df.value <= "2")).count())
         self.assertRaises(ValueError, lambda: (self.df.key <= 10) and (self.df.value <= "2"))
@@ -895,6 +898,13 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertEqual(self.testData, df.select(df.key, df.value).collect())
         self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect())
 
+    def test_column_alias_metadata(self):
+        df = self.df
+        df_with_meta = df.select(df.key.alias('pk', metadata={'label': 'Primary Key'}))
+        self.assertEqual(df_with_meta.schema['pk'].metadata['label'], 'Primary Key')
+        with self.assertRaises(AssertionError):
+            df.select(df.key.alias('pk', metdata={'label': 'Primary Key'}))
+
     def test_freqItems(self):
         vals = [Row(a=1, b=-2.0) if i % 2 == 0 else Row(a=i, b=i * 1.0) for i in range(100)]
         df = self.sc.parallelize(vals).toDF()
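Both failure modes exercised by these tests are raised eagerly at call time: passing `metadata` alongside multiple aliases raises `ValueError`, and any unrecognized keyword (such as the deliberately misspelled `metdata` in the test above) trips the kwargs assertion. A quick sketch, reusing the `df` from the earlier example:

```python
# metadata only makes sense for a single output column.
try:
    df.age.alias('a', 'b', metadata={'max': 99})
except ValueError as e:
    print(e)  # metadata can only be provided for a single column

# Misspelled keyword arguments fail fast instead of being silently ignored.
try:
    df.age.alias('a', metdata={'max': 99})
except AssertionError as e:
    print(e)  # Unexpected kwargs were passed: {'metdata': {'max': 99}}
```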