diff options
Diffstat (limited to 'python')
-rw-r--r-- | python/pyspark/cloudpickle.py | 42 | ||||
-rw-r--r-- | python/pyspark/tests.py | 18 |
2 files changed, 54 insertions, 6 deletions
diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py index 32dda3888c..bb0783555a 100644 --- a/python/pyspark/cloudpickle.py +++ b/python/pyspark/cloudpickle.py @@ -52,6 +52,7 @@ from functools import partial import itertools from copy_reg import _extension_registry, _inverted_registry, _extension_cache import new +import dis import traceback import platform @@ -61,6 +62,14 @@ PyImp = platform.python_implementation() import logging cloudLog = logging.getLogger("Cloud.Transport") +#relevant opcodes +STORE_GLOBAL = chr(dis.opname.index('STORE_GLOBAL')) +DELETE_GLOBAL = chr(dis.opname.index('DELETE_GLOBAL')) +LOAD_GLOBAL = chr(dis.opname.index('LOAD_GLOBAL')) +GLOBAL_OPS = [STORE_GLOBAL, DELETE_GLOBAL, LOAD_GLOBAL] + +HAVE_ARGUMENT = chr(dis.HAVE_ARGUMENT) +EXTENDED_ARG = chr(dis.EXTENDED_ARG) if PyImp == "PyPy": # register builtin type in `new` @@ -304,16 +313,37 @@ class CloudPickler(pickle.Pickler): write(pickle.REDUCE) # applies _fill_function on the tuple @staticmethod - def extract_code_globals(code): + def extract_code_globals(co): """ Find all globals names read or written to by codeblock co """ - names = set(code.co_names) - if code.co_consts: # see if nested function have any global refs - for const in code.co_consts: + code = co.co_code + names = co.co_names + out_names = set() + + n = len(code) + i = 0 + extended_arg = 0 + while i < n: + op = code[i] + + i = i+1 + if op >= HAVE_ARGUMENT: + oparg = ord(code[i]) + ord(code[i+1])*256 + extended_arg + extended_arg = 0 + i = i+2 + if op == EXTENDED_ARG: + extended_arg = oparg*65536L + if op in GLOBAL_OPS: + out_names.add(names[oparg]) + #print 'extracted', out_names, ' from ', names + + if co.co_consts: # see if nested function have any global refs + for const in co.co_consts: if type(const) is types.CodeType: - names |= CloudPickler.extract_code_globals(const) - return names + out_names |= CloudPickler.extract_code_globals(const) + + return out_names def extract_func_data(self, func): """ diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 4483bf80db..d1bb2033b7 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -213,6 +213,24 @@ class SerializationTestCase(unittest.TestCase): out2 = ser.loads(ser.dumps(out1)) self.assertEquals(out1, out2) + def test_func_globals(self): + + class Unpicklable(object): + def __reduce__(self): + raise Exception("not picklable") + + global exit + exit = Unpicklable() + + ser = CloudPickleSerializer() + self.assertRaises(Exception, lambda: ser.dumps(exit)) + + def foo(): + sys.exit(0) + + self.assertTrue("exit" in foo.func_code.co_names) + ser.dumps(foo) + class PySparkTestCase(unittest.TestCase): |