aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorDongjoon Hyun <dongjoon@apache.org>2016-05-17 13:05:07 -0700
committerDavies Liu <davies.liu@gmail.com>2016-05-17 13:05:07 -0700
commit0f576a5748244f7e874b925f8d841f1ca238f087 (patch)
tree364947fa82507553c5e4feb82d8f24b00aeb3969 /python
parente2efe0529acd748f26dbaa41331d1733ed256237 (diff)
downloadspark-0f576a5748244f7e874b925f8d841f1ca238f087.tar.gz
spark-0f576a5748244f7e874b925f8d841f1ca238f087.tar.bz2
spark-0f576a5748244f7e874b925f8d841f1ca238f087.zip
[SPARK-15244] [PYTHON] Type of column name created with createDataFrame is not consistent.
## What changes were proposed in this pull request? **createDataFrame** returns inconsistent types for column names. ```python >>> from pyspark.sql.types import StructType, StructField, StringType >>> schema = StructType([StructField(u"col", StringType())]) >>> df1 = spark.createDataFrame([("a",)], schema) >>> df1.columns # "col" is str ['col'] >>> df2 = spark.createDataFrame([("a",)], [u"col"]) >>> df2.columns # "col" is unicode [u'col'] ``` The reason is only **StructField** has the following code. ``` if not isinstance(name, str): name = name.encode('utf-8') ``` This PR adds the same logic into **createDataFrame** for consistency. ``` if isinstance(schema, list): schema = [x.encode('utf-8') if not isinstance(x, str) else x for x in schema] ``` ## How was this patch tested? Pass the Jenkins test (with new python doctest) Author: Dongjoon Hyun <dongjoon@apache.org> Closes #13097 from dongjoon-hyun/SPARK-15244.
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/sql/session.py2
-rw-r--r--python/pyspark/sql/tests.py7
2 files changed, 9 insertions, 0 deletions
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index ae314359d5..0781b442cb 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -465,6 +465,8 @@ class SparkSession(object):
return (obj, )
schema = StructType().add("value", datatype)
else:
+ if isinstance(schema, list):
+ schema = [x.encode('utf-8') if not isinstance(x, str) else x for x in schema]
prepare = lambda obj: obj
if isinstance(data, RDD):
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 0c73f58c3b..0977c43a39 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -228,6 +228,13 @@ class SQLTests(ReusedPySparkTestCase):
self.assertRaises(AnalysisException, lambda: df.select(df.c).first())
self.assertRaises(AnalysisException, lambda: df.select(df["c"]).first())
+ def test_column_name_encoding(self):
+ """Ensure that created columns has `str` type consistently."""
+ columns = self.spark.createDataFrame([('Alice', 1)], ['name', u'age']).columns
+ self.assertEqual(columns, ['name', 'age'])
+ self.assertTrue(isinstance(columns[0], str))
+ self.assertTrue(isinstance(columns[1], str))
+
def test_explode(self):
from pyspark.sql.functions import explode
d = [Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"})]