diff options
author | Josh Rosen <joshrosen@apache.org> | 2013-11-28 23:44:56 -0800 |
---|---|---|
committer | Josh Rosen <joshrosen@apache.org> | 2013-11-28 23:44:56 -0800 |
commit | 3787f514d9a8e45d2c257b4696e30bc1a1935748 (patch) | |
tree | 572553edf58b4d97b54afe1a536f30288bc1db4f /python | |
parent | 743a31a7ca4421cbbd5b615b773997a06a7ab4ee (diff) | |
download | spark-3787f514d9a8e45d2c257b4696e30bc1a1935748.tar.gz spark-3787f514d9a8e45d2c257b4696e30bc1a1935748.tar.bz2 spark-3787f514d9a8e45d2c257b4696e30bc1a1935748.zip |
Fix UnicodeEncodeError in PySpark saveAsTextFile().
Fixes SPARK-970.
Diffstat (limited to 'python')
-rw-r--r-- | python/pyspark/rdd.py | 5 | ||||
-rw-r--r-- | python/pyspark/tests.py | 15 |
2 files changed, 19 insertions, 1 deletions
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 957f3f89c0..d8da02072c 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -605,7 +605,10 @@ class RDD(object):
         '0\\n1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n'
         """
         def func(split, iterator):
-            return (str(x).encode("utf-8") for x in iterator)
+            for x in iterator:
+                if not isinstance(x, basestring):
+                    x = unicode(x)
+                yield x.encode("utf-8")
         keyed = PipelinedRDD(self, func)
         keyed._bypass_serializer = True
         keyed._jrdd.map(self.ctx._jvm.BytesToString()).saveAsTextFile(path)
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 621e1cb58c..3987642bf4 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -19,6 +19,8 @@
 Unit tests for PySpark; additional tests are implemented as doctests in
 individual modules.
 """
+from fileinput import input
+from glob import glob
 import os
 import shutil
 import sys
@@ -138,6 +140,19 @@ class TestAddFile(PySparkTestCase):
         self.assertEqual("Hello World from inside a package!", UserClass().hello())
 
 
+class TestRDDFunctions(PySparkTestCase):
+
+    def test_save_as_textfile_with_unicode(self):
+        # Regression test for SPARK-970
+        x = u"\u00A1Hola, mundo!"
+        data = self.sc.parallelize([x])
+        tempFile = NamedTemporaryFile(delete=True)
+        tempFile.close()
+        data.saveAsTextFile(tempFile.name)
+        raw_contents = ''.join(input(glob(tempFile.name + "/part-0000*")))
+        self.assertEqual(x, unicode(raw_contents.strip(), "utf-8"))
+
+
 class TestIO(PySparkTestCase):
 
     def test_stdout_redirection(self):