author | Dongjoon Hyun <dongjoon@apache.org> | 2016-05-17 13:05:07 -0700 |
---|---|---|
committer | Davies Liu <davies.liu@gmail.com> | 2016-05-17 13:05:07 -0700 |
commit | 0f576a5748244f7e874b925f8d841f1ca238f087 (patch) | |
tree | 364947fa82507553c5e4feb82d8f24b00aeb3969 | |
parent | e2efe0529acd748f26dbaa41331d1733ed256237 (diff) | |
[SPARK-15244] [PYTHON] Type of column name created with createDataFrame is not consistent.
## What changes were proposed in this pull request?
**createDataFrame** returns column names of inconsistent types, depending on whether the schema is passed as a **StructType** or as a list of names.
```python
>>> from pyspark.sql.types import StructType, StructField, StringType
>>> schema = StructType([StructField(u"col", StringType())])
>>> df1 = spark.createDataFrame([("a",)], schema)
>>> df1.columns # "col" is str
['col']
>>> df2 = spark.createDataFrame([("a",)], [u"col"])
>>> df2.columns # "col" is unicode
[u'col']
```
The reason is that only **StructField** contains the following code.
```
if not isinstance(name, str):
    name = name.encode('utf-8')
```
This PR adds the same logic to **createDataFrame** for consistency, applied when the schema is given as a list of column names.
```
if isinstance(schema, list):
    schema = [x.encode('utf-8') if not isinstance(x, str) else x for x in schema]
```
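The same normalization can be tried outside Spark. The following is a minimal sketch, not the PySpark implementation: the helper name `normalize_column_names` is hypothetical, and only the list comprehension is taken from the patch. It assumes the caller passes either a list of names or some other schema object that should be left alone.

```python
def normalize_column_names(schema):
    """Return `schema` with every list entry coerced to the native `str` type.

    Hypothetical helper for illustration only; it is not part of PySpark.
    """
    if isinstance(schema, list):
        # Same expression as in the patch: only non-`str` names are encoded.
        # On Python 2 this turns `unicode` names into byte `str`.
        # On Python 3 text names are already `str`, so they pass through unchanged.
        return [x.encode('utf-8') if not isinstance(x, str) else x for x in schema]
    # Anything that is not a list (for example a StructType) is returned untouched.
    return schema


names = normalize_column_names(['name', u'age'])
print(names, [type(n).__name__ for n in names])
```

The difference this PR addresses is only observable on Python 2, where `u'age'` is a `unicode` object; on Python 3 the call above leaves the list unchanged.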
## How was this patch tested?
Passes the Jenkins tests (with a new Python test).
Author: Dongjoon Hyun <dongjoon@apache.org>
Closes #13097 from dongjoon-hyun/SPARK-15244.
-rw-r--r-- | python/pyspark/sql/session.py | 2 |
-rw-r--r-- | python/pyspark/sql/tests.py | 7 |
2 files changed, 9 insertions, 0 deletions
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index ae314359d5..0781b442cb 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -465,6 +465,8 @@ class SparkSession(object):
                     return (obj, )
             schema = StructType().add("value", datatype)
         else:
+            if isinstance(schema, list):
+                schema = [x.encode('utf-8') if not isinstance(x, str) else x for x in schema]
             prepare = lambda obj: obj
 
         if isinstance(data, RDD):
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 0c73f58c3b..0977c43a39 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -228,6 +228,13 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertRaises(AnalysisException, lambda: df.select(df.c).first())
         self.assertRaises(AnalysisException, lambda: df.select(df["c"]).first())
 
+    def test_column_name_encoding(self):
+        """Ensure that created columns has `str` type consistently."""
+        columns = self.spark.createDataFrame([('Alice', 1)], ['name', u'age']).columns
+        self.assertEqual(columns, ['name', 'age'])
+        self.assertTrue(isinstance(columns[0], str))
+        self.assertTrue(isinstance(columns[1], str))
+
     def test_explode(self):
         from pyspark.sql.functions import explode
         d = [Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"})]