about | summary | refs | log | tree | commit | diff
path: root/python/pyspark
diff options
context:
space:
mode:
author: Davies Liu <davies.liu@gmail.com> 2014-07-29 01:02:18 -0700
committer: Josh Rosen <joshrosen@apache.org> 2014-07-29 01:02:18 -0700
commit: 92ef02626e793ea853cced4cbfee316f0b748ed7 (patch)
tree: c88aa9fc7fe29e5293a59c479a6535ba6deb404e /python/pyspark
parent: ccd5ab5f82812abc2eb518448832cc20fb903345 (diff)
download: spark-92ef02626e793ea853cced4cbfee316f0b748ed7.tar.gz
spark-92ef02626e793ea853cced4cbfee316f0b748ed7.tar.bz2
spark-92ef02626e793ea853cced4cbfee316f0b748ed7.zip
[SPARK-791] [PySpark] fix pickle itemgetter with cloudpickle
Fix the pickling of operator.itemgetter objects created with multiple indices. Author: Davies Liu <davies.liu@gmail.com> Closes #1627 from davies/itemgetter and squashes the following commits: aabd7fa [Davies Liu] fix pickle itemgetter with cloudpickle
Diffstat (limited to 'python/pyspark')
-rw-r--r-- python/pyspark/cloudpickle.py | 5
-rw-r--r-- python/pyspark/tests.py | 6
2 files changed, 9 insertions, 2 deletions
diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py
index 4fda2a9b95..68062483de 100644
--- a/python/pyspark/cloudpickle.py
+++ b/python/pyspark/cloudpickle.py
@@ -560,8 +560,9 @@ class CloudPickler(pickle.Pickler):
]
- itemgetter_obj = ctypes.cast(ctypes.c_void_p(id(obj)), ctypes.POINTER(ItemGetterType)).contents
- return self.save_reduce(operator.itemgetter, (itemgetter_obj.item,))
+ obj = ctypes.cast(ctypes.c_void_p(id(obj)), ctypes.POINTER(ItemGetterType)).contents
+ return self.save_reduce(operator.itemgetter,
+ obj.item if obj.nitems > 1 else (obj.item,))
if PyObject_HEAD:
dispatch[operator.itemgetter] = save_itemgetter
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 6dee7dc66c..8486c8595b 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -284,6 +284,12 @@ class TestRDDFunctions(PySparkTestCase):
self.assertEqual(set([2]), sets[3])
self.assertEqual(set([1, 3]), sets[5])
+ def test_itemgetter(self):
+ rdd = self.sc.parallelize([range(10)])
+ from operator import itemgetter
+ self.assertEqual([1], rdd.map(itemgetter(1)).collect())
+ self.assertEqual([(2, 3)], rdd.map(itemgetter(2, 3)).collect())
+
class TestIO(PySparkTestCase):