aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--python/pyspark/sql/tests.py11
-rw-r--r--python/pyspark/sql/types.py3
2 files changed, 11 insertions, 3 deletions
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index e86f44281d..1790432edd 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -1044,8 +1044,15 @@ class SQLTests(ReusedPySparkTestCase):
self.assertRaises(TypeError, lambda: df[{}])
def test_column_name_with_non_ascii(self):
- df = self.spark.createDataFrame([(1,)], ["数量"])
- self.assertEqual(StructType([StructField("数量", LongType(), True)]), df.schema)
+ if sys.version >= '3':
+ columnName = "数量"
+ self.assertTrue(isinstance(columnName, str))
+ else:
+ columnName = unicode("数量", "utf-8")
+ self.assertTrue(isinstance(columnName, unicode))
+ schema = StructType([StructField(columnName, LongType(), True)])
+ df = self.spark.createDataFrame([(1,)], schema)
+ self.assertEqual(schema, df.schema)
self.assertEqual("DataFrame[数量: bigint]", str(df))
self.assertEqual([("数量", 'bigint')], df.dtypes)
self.assertEqual(1, df.select("数量").first()[0])
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 30ab130f29..7d8d0230b4 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -27,7 +27,7 @@ from array import array
if sys.version >= "3":
long = int
- unicode = str
+ basestring = unicode = str
from py4j.protocol import register_input_converter
from py4j.java_gateway import JavaClass
@@ -401,6 +401,7 @@ class StructField(DataType):
False
"""
assert isinstance(dataType, DataType), "dataType should be DataType"
+ assert isinstance(name, basestring), "field name should be string"
if not isinstance(name, str):
name = name.encode('utf-8')
self.name = name