diff options
author | Davies Liu <davies.liu@gmail.com> | 2014-07-29 01:02:18 -0700 |
---|---|---|
committer | Josh Rosen <joshrosen@apache.org> | 2014-07-29 01:02:18 -0700 |
commit | 92ef02626e793ea853cced4cbfee316f0b748ed7 (patch) | |
tree | c88aa9fc7fe29e5293a59c479a6535ba6deb404e /python | |
parent | ccd5ab5f82812abc2eb518448832cc20fb903345 (diff) | |
download | spark-92ef02626e793ea853cced4cbfee316f0b748ed7.tar.gz spark-92ef02626e793ea853cced4cbfee316f0b748ed7.tar.bz2 spark-92ef02626e793ea853cced4cbfee316f0b748ed7.zip |
[SPARK-791] [PySpark] fix pickle itemgetter with cloudpickle
fix the problem with pickle operator.itemgetter with multiple index.
Author: Davies Liu <davies.liu@gmail.com>
Closes #1627 from davies/itemgetter and squashes the following commits:
aabd7fa [Davies Liu] fix pickle itemgetter with cloudpickle
Diffstat (limited to 'python')
-rw-r--r-- | python/pyspark/cloudpickle.py | 5 | ||||
-rw-r--r-- | python/pyspark/tests.py | 6 |
2 files changed, 9 insertions, 2 deletions
diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py index 4fda2a9b95..68062483de 100644 --- a/python/pyspark/cloudpickle.py +++ b/python/pyspark/cloudpickle.py @@ -560,8 +560,9 @@ class CloudPickler(pickle.Pickler): ] - itemgetter_obj = ctypes.cast(ctypes.c_void_p(id(obj)), ctypes.POINTER(ItemGetterType)).contents - return self.save_reduce(operator.itemgetter, (itemgetter_obj.item,)) + obj = ctypes.cast(ctypes.c_void_p(id(obj)), ctypes.POINTER(ItemGetterType)).contents + return self.save_reduce(operator.itemgetter, + obj.item if obj.nitems > 1 else (obj.item,)) if PyObject_HEAD: dispatch[operator.itemgetter] = save_itemgetter diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 6dee7dc66c..8486c8595b 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -284,6 +284,12 @@ class TestRDDFunctions(PySparkTestCase): self.assertEqual(set([2]), sets[3]) self.assertEqual(set([1, 3]), sets[5]) + def test_itemgetter(self): + rdd = self.sc.parallelize([range(10)]) + from operator import itemgetter + self.assertEqual([1], rdd.map(itemgetter(1)).collect()) + self.assertEqual([(2, 3)], rdd.map(itemgetter(2, 3)).collect()) + class TestIO(PySparkTestCase): |