about | summary | refs | log | tree | commit | diff
path: root/python/pyspark/cloudpickle.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/cloudpickle.py')
-rw-r--r-- python/pyspark/cloudpickle.py 98
1 files changed, 67 insertions, 31 deletions
diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py
index da2b2f3757..959fb8b357 100644
--- a/python/pyspark/cloudpickle.py
+++ b/python/pyspark/cloudpickle.py
@@ -43,6 +43,7 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from __future__ import print_function
import operator
+import opcode
import os
import io
import pickle
@@ -53,6 +54,8 @@ from functools import partial
import itertools
import dis
import traceback
+import weakref
+
if sys.version < '3':
from pickle import Pickler
@@ -68,10 +71,10 @@ else:
PY3 = True
#relevant opcodes
-STORE_GLOBAL = dis.opname.index('STORE_GLOBAL')
-DELETE_GLOBAL = dis.opname.index('DELETE_GLOBAL')
-LOAD_GLOBAL = dis.opname.index('LOAD_GLOBAL')
-GLOBAL_OPS = [STORE_GLOBAL, DELETE_GLOBAL, LOAD_GLOBAL]
+STORE_GLOBAL = opcode.opmap['STORE_GLOBAL']
+DELETE_GLOBAL = opcode.opmap['DELETE_GLOBAL']
+LOAD_GLOBAL = opcode.opmap['LOAD_GLOBAL']
+GLOBAL_OPS = (STORE_GLOBAL, DELETE_GLOBAL, LOAD_GLOBAL)
HAVE_ARGUMENT = dis.HAVE_ARGUMENT
EXTENDED_ARG = dis.EXTENDED_ARG
@@ -90,6 +93,43 @@ def _builtin_type(name):
return getattr(types, name)
+if sys.version_info < (3, 4):
+ def _walk_global_ops(code):
+ """
+ Yield (opcode, argument number) tuples for all
+ global-referencing instructions in *code*.
+ """
+ code = getattr(code, 'co_code', b'')
+ if not PY3:
+ code = map(ord, code)
+
+ n = len(code)
+ i = 0
+ extended_arg = 0
+ while i < n:
+ op = code[i]
+ i += 1
+ if op >= HAVE_ARGUMENT:
+ oparg = code[i] + code[i + 1] * 256 + extended_arg
+ extended_arg = 0
+ i += 2
+ if op == EXTENDED_ARG:
+ extended_arg = oparg * 65536
+ if op in GLOBAL_OPS:
+ yield op, oparg
+
+else:
+ def _walk_global_ops(code):
+ """
+ Yield (opcode, argument number) tuples for all
+ global-referencing instructions in *code*.
+ """
+ for instr in dis.get_instructions(code):
+ op = instr.opcode
+ if op in GLOBAL_OPS:
+ yield op, instr.arg
+
+
class CloudPickler(Pickler):
dispatch = Pickler.dispatch.copy()
@@ -260,38 +300,34 @@ class CloudPickler(Pickler):
write(pickle.TUPLE)
write(pickle.REDUCE) # applies _fill_function on the tuple
- @staticmethod
- def extract_code_globals(co):
+ _extract_code_globals_cache = (
+ weakref.WeakKeyDictionary()
+ if sys.version_info >= (2, 7) and not hasattr(sys, "pypy_version_info")
+ else {})
+
+ @classmethod
+ def extract_code_globals(cls, co):
"""
Find all globals names read or written to by codeblock co
"""
- code = co.co_code
- if not PY3:
- code = [ord(c) for c in code]
- names = co.co_names
- out_names = set()
-
- n = len(code)
- i = 0
- extended_arg = 0
- while i < n:
- op = code[i]
+ out_names = cls._extract_code_globals_cache.get(co)
+ if out_names is None:
+ try:
+ names = co.co_names
+ except AttributeError:
+ # PyPy "builtin-code" object
+ out_names = set()
+ else:
+ out_names = set(names[oparg]
+ for op, oparg in _walk_global_ops(co))
- i += 1
- if op >= HAVE_ARGUMENT:
- oparg = code[i] + code[i+1] * 256 + extended_arg
- extended_arg = 0
- i += 2
- if op == EXTENDED_ARG:
- extended_arg = oparg*65536
- if op in GLOBAL_OPS:
- out_names.add(names[oparg])
+ # see if nested function have any global refs
+ if co.co_consts:
+ for const in co.co_consts:
+ if type(const) is types.CodeType:
+ out_names |= cls.extract_code_globals(const)
- # see if nested function have any global refs
- if co.co_consts:
- for const in co.co_consts:
- if type(const) is types.CodeType:
- out_names |= CloudPickler.extract_code_globals(const)
+ cls._extract_code_globals_cache[co] = out_names
return out_names