author     Davies Liu <davies.liu@gmail.com>    2014-09-12 18:42:50 -0700
committer  Josh Rosen <joshrosen@apache.org>    2014-09-12 18:42:50 -0700
commit     71af030b46a89aaa9a87f18f56b9e1f1cd8ce2e7 (patch)
tree       1cff9839428a177b129de830c3c175b1316a6626 /python
parent     25311c2c545a60eb9dcf704814d4600987852155 (diff)
[SPARK-3094] [PySpark] compatible with PyPy
After this patch, we can run PySpark in PyPy (tested with PyPy 2.3.1 on Mac OS X 10.9), for example:

```
PYSPARK_PYTHON=pypy ./bin/spark-submit wordcount.py
```

The performance speedup depends on the workload (from 20% to 3000%). Here are some benchmarks:

Job | CPython 2.7 | PyPy 2.3.1 | Speedup
------- | ------------ | ------------- | -------
Word Count | 41s | 15s | 2.7x
Sort | 46s | 44s | 1.05x
Stats | 174s | 3.6s | 48x

Here is the code used for the benchmarks:

```python
rdd = sc.textFile("text")

def wordcount():
    rdd.flatMap(lambda x: x.split('/'))\
       .map(lambda x: (x, 1)).reduceByKey(lambda x, y: x + y).collectAsMap()

def sort():
    rdd.sortBy(lambda x: x, 1).count()

def stats():
    sc.parallelize(range(1024), 20).flatMap(lambda x: xrange(5024)).stats()
```

Author: Davies Liu <davies.liu@gmail.com>

Closes #2144 from davies/pypy and squashes the following commits:

9aed6c5 [Davies Liu] use protocol 2 in CloudPickle
4bc1f04 [Davies Liu] refactor
b20ab3a [Davies Liu] pickle sys.stdout and stderr in portable way
3ca2351 [Davies Liu] Merge branch 'master' into pypy
fae8b19 [Davies Liu] improve attrgetter, add tests
591f830 [Davies Liu] try to run tests with PyPy in run-tests
c8d62ba [Davies Liu] cleanup
f651fd0 [Davies Liu] fix tests using array with PyPy
1b98fb3 [Davies Liu] serialize itemgetter/attrgetter in portable ways
3c1dbfe [Davies Liu] Merge branch 'master' into pypy
42fb5fa [Davies Liu] Merge branch 'master' into pypy
cb2d724 [Davies Liu] fix tests
9986692 [Davies Liu] Merge branch 'master' into pypy
25b4ca7 [Davies Liu] support PyPy
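The `wordcount.py` driver named in the command above is not part of this patch; a minimal, hypothetical version consistent with the benchmark snippet might look like the following (the application name and input path are illustrative only):

```python
# wordcount.py -- illustrative sketch, not included in this commit.
# Run with either interpreter, e.g.:
#   PYSPARK_PYTHON=pypy ./bin/spark-submit wordcount.py
from pyspark import SparkContext

if __name__ == "__main__":
    sc = SparkContext(appName="PyPyWordCount")
    rdd = sc.textFile("text")  # assumed input path, as in the benchmark above
    counts = (rdd.flatMap(lambda x: x.split('/'))
                 .map(lambda x: (x, 1))
                 .reduceByKey(lambda x, y: x + y)
                 .collectAsMap())
    print len(counts), "distinct words"
    sc.stop()
```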
Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/cloudpickle.py  168
-rw-r--r--  python/pyspark/daemon.py         6
-rw-r--r--  python/pyspark/serializers.py   10
-rw-r--r--  python/pyspark/tests.py         85
-rwxr-xr-x  python/run-tests                21
5 files changed, 172 insertions, 118 deletions
diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py
index 80e51d1a58..32dda3888c 100644
--- a/python/pyspark/cloudpickle.py
+++ b/python/pyspark/cloudpickle.py
@@ -52,35 +52,19 @@ from functools import partial
import itertools
from copy_reg import _extension_registry, _inverted_registry, _extension_cache
import new
-import dis
import traceback
+import platform
-#relevant opcodes
-STORE_GLOBAL = chr(dis.opname.index('STORE_GLOBAL'))
-DELETE_GLOBAL = chr(dis.opname.index('DELETE_GLOBAL'))
-LOAD_GLOBAL = chr(dis.opname.index('LOAD_GLOBAL'))
-GLOBAL_OPS = [STORE_GLOBAL, DELETE_GLOBAL, LOAD_GLOBAL]
+PyImp = platform.python_implementation()
-HAVE_ARGUMENT = chr(dis.HAVE_ARGUMENT)
-EXTENDED_ARG = chr(dis.EXTENDED_ARG)
import logging
cloudLog = logging.getLogger("Cloud.Transport")
-try:
- import ctypes
-except (MemoryError, ImportError):
- logging.warning('Exception raised on importing ctypes. Likely python bug.. some functionality will be disabled', exc_info = True)
- ctypes = None
- PyObject_HEAD = None
-else:
-
- # for reading internal structures
- PyObject_HEAD = [
- ('ob_refcnt', ctypes.c_size_t),
- ('ob_type', ctypes.c_void_p),
- ]
+if PyImp == "PyPy":
+ # register builtin type in `new`
+ new.method = types.MethodType
try:
from cStringIO import StringIO
@@ -225,6 +209,8 @@ class CloudPickler(pickle.Pickler):
if themodule:
self.modules.add(themodule)
+ if getattr(themodule, name, None) is obj:
+ return self.save_global(obj, name)
if not self.savedDjangoEnv:
#hack for django - if we detect the settings module, we transport it
@@ -306,44 +292,28 @@ class CloudPickler(pickle.Pickler):
# create a skeleton function object and memoize it
save(_make_skel_func)
- save((code, len(closure), base_globals))
+ save((code, closure, base_globals))
write(pickle.REDUCE)
self.memoize(func)
# save the rest of the func data needed by _fill_function
save(f_globals)
save(defaults)
- save(closure)
save(dct)
write(pickle.TUPLE)
write(pickle.REDUCE) # applies _fill_function on the tuple
@staticmethod
- def extract_code_globals(co):
+ def extract_code_globals(code):
"""
Find all globals names read or written to by codeblock co
"""
- code = co.co_code
- names = co.co_names
- out_names = set()
-
- n = len(code)
- i = 0
- extended_arg = 0
- while i < n:
- op = code[i]
-
- i = i+1
- if op >= HAVE_ARGUMENT:
- oparg = ord(code[i]) + ord(code[i+1])*256 + extended_arg
- extended_arg = 0
- i = i+2
- if op == EXTENDED_ARG:
- extended_arg = oparg*65536L
- if op in GLOBAL_OPS:
- out_names.add(names[oparg])
- #print 'extracted', out_names, ' from ', names
- return out_names
+ names = set(code.co_names)
+ if code.co_consts: # see if nested function have any global refs
+ for const in code.co_consts:
+ if type(const) is types.CodeType:
+ names |= CloudPickler.extract_code_globals(const)
+ return names
def extract_func_data(self, func):
"""
@@ -354,10 +324,7 @@ class CloudPickler(pickle.Pickler):
# extract all global ref's
func_global_refs = CloudPickler.extract_code_globals(code)
- if code.co_consts: # see if nested function have any global refs
- for const in code.co_consts:
- if type(const) is types.CodeType and const.co_names:
- func_global_refs = func_global_refs.union( CloudPickler.extract_code_globals(const))
+
# process all variables referenced by global environment
f_globals = {}
for var in func_global_refs:
@@ -396,6 +363,12 @@ class CloudPickler(pickle.Pickler):
return (code, f_globals, defaults, closure, dct, base_globals)
+ def save_builtin_function(self, obj):
+ if obj.__module__ is "__builtin__":
+ return self.save_global(obj)
+ return self.save_function(obj)
+ dispatch[types.BuiltinFunctionType] = save_builtin_function
+
def save_global(self, obj, name=None, pack=struct.pack):
write = self.write
memo = self.memo
@@ -435,7 +408,7 @@ class CloudPickler(pickle.Pickler):
try:
klass = getattr(themodule, name)
except AttributeError, a:
- #print themodule, name, obj, type(obj)
+ # print themodule, name, obj, type(obj)
raise pickle.PicklingError("Can't pickle builtin %s" % obj)
else:
raise
@@ -480,7 +453,6 @@ class CloudPickler(pickle.Pickler):
write(pickle.GLOBAL + modname + '\n' + name + '\n')
self.memoize(obj)
dispatch[types.ClassType] = save_global
- dispatch[types.BuiltinFunctionType] = save_global
dispatch[types.TypeType] = save_global
def save_instancemethod(self, obj):
@@ -551,23 +523,39 @@ class CloudPickler(pickle.Pickler):
dispatch[property] = save_property
def save_itemgetter(self, obj):
- """itemgetter serializer (needed for namedtuple support)
- a bit of a pain as we need to read ctypes internals"""
- class ItemGetterType(ctypes.Structure):
- _fields_ = PyObject_HEAD + [
- ('nitems', ctypes.c_size_t),
- ('item', ctypes.py_object)
- ]
-
-
- obj = ctypes.cast(ctypes.c_void_p(id(obj)), ctypes.POINTER(ItemGetterType)).contents
- return self.save_reduce(operator.itemgetter,
- obj.item if obj.nitems > 1 else (obj.item,))
-
- if PyObject_HEAD:
+ """itemgetter serializer (needed for namedtuple support)"""
+ class Dummy:
+ def __getitem__(self, item):
+ return item
+ items = obj(Dummy())
+ if not isinstance(items, tuple):
+ items = (items, )
+ return self.save_reduce(operator.itemgetter, items)
+
+ if type(operator.itemgetter) is type:
dispatch[operator.itemgetter] = save_itemgetter
+ def save_attrgetter(self, obj):
+ """attrgetter serializer"""
+ class Dummy(object):
+ def __init__(self, attrs, index=None):
+ self.attrs = attrs
+ self.index = index
+ def __getattribute__(self, item):
+ attrs = object.__getattribute__(self, "attrs")
+ index = object.__getattribute__(self, "index")
+ if index is None:
+ index = len(attrs)
+ attrs.append(item)
+ else:
+ attrs[index] = ".".join([attrs[index], item])
+ return type(self)(attrs, index)
+ attrs = []
+ obj(Dummy(attrs))
+ return self.save_reduce(operator.attrgetter, tuple(attrs))
+ if type(operator.attrgetter) is type:
+ dispatch[operator.attrgetter] = save_attrgetter
def save_reduce(self, func, args, state=None,
listitems=None, dictitems=None, obj=None):
@@ -660,11 +648,11 @@ class CloudPickler(pickle.Pickler):
if not hasattr(obj, 'name') or not hasattr(obj, 'mode'):
raise pickle.PicklingError("Cannot pickle files that do not map to an actual file")
- if obj.name == '<stdout>':
+ if obj is sys.stdout:
return self.save_reduce(getattr, (sys,'stdout'), obj=obj)
- if obj.name == '<stderr>':
+ if obj is sys.stderr:
return self.save_reduce(getattr, (sys,'stderr'), obj=obj)
- if obj.name == '<stdin>':
+ if obj is sys.stdin:
raise pickle.PicklingError("Cannot pickle standard input")
if hasattr(obj, 'isatty') and obj.isatty():
raise pickle.PicklingError("Cannot pickle files that map to tty objects")
@@ -873,8 +861,7 @@ def _genpartial(func, args, kwds):
kwds = {}
return partial(func, *args, **kwds)
-
-def _fill_function(func, globals, defaults, closure, dict):
+def _fill_function(func, globals, defaults, dict):
""" Fills in the rest of function data into the skeleton function object
that were created via _make_skel_func().
"""
@@ -882,49 +869,28 @@ def _fill_function(func, globals, defaults, closure, dict):
func.func_defaults = defaults
func.func_dict = dict
- if len(closure) != len(func.func_closure):
- raise pickle.UnpicklingError("closure lengths don't match up")
- for i in range(len(closure)):
- _change_cell_value(func.func_closure[i], closure[i])
-
return func
-def _make_skel_func(code, num_closures, base_globals = None):
+def _make_cell(value):
+ return (lambda: value).func_closure[0]
+
+def _reconstruct_closure(values):
+ return tuple([_make_cell(v) for v in values])
+
+def _make_skel_func(code, closures, base_globals = None):
""" Creates a skeleton function object that contains just the provided
code and the correct number of cells in func_closure. All other
func attributes (e.g. func_globals) are empty.
"""
- #build closure (cells):
- if not ctypes:
- raise Exception('ctypes failed to import; cannot build function')
-
- cellnew = ctypes.pythonapi.PyCell_New
- cellnew.restype = ctypes.py_object
- cellnew.argtypes = (ctypes.py_object,)
- dummy_closure = tuple(map(lambda i: cellnew(None), range(num_closures)))
+ closure = _reconstruct_closure(closures) if closures else None
if base_globals is None:
base_globals = {}
base_globals['__builtins__'] = __builtins__
return types.FunctionType(code, base_globals,
- None, None, dummy_closure)
-
-# this piece of opaque code is needed below to modify 'cell' contents
-cell_changer_code = new.code(
- 1, 1, 2, 0,
- ''.join([
- chr(dis.opmap['LOAD_FAST']), '\x00\x00',
- chr(dis.opmap['DUP_TOP']),
- chr(dis.opmap['STORE_DEREF']), '\x00\x00',
- chr(dis.opmap['RETURN_VALUE'])
- ]),
- (), (), ('newval',), '<nowhere>', 'cell_changer', 1, '', ('c',), ()
-)
-
-def _change_cell_value(cell, newval):
- """ Changes the contents of 'cell' object to newval """
- return new.function(cell_changer_code, {}, None, (), (cell,))(newval)
+ None, None, closure)
+
"""Constructors for 3rd party libraries
Note: These can never be renamed due to client compatibility issues"""
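
The rewritten `save_itemgetter` and the new `save_attrgetter` above recover a getter's arguments by calling it on a dummy object that records what gets accessed, instead of reading ctypes internals. A standalone sketch of that probing idea (the helper names here are illustrative, not part of the patch):

```python
import operator

def itemgetter_args(getter):
    """Recover the indices an operator.itemgetter was built with (sketch)."""
    class Dummy(object):
        def __getitem__(self, item):
            return item  # echo the requested key back to the caller
    items = getter(Dummy())
    return items if isinstance(items, tuple) else (items,)

def attrgetter_args(getter):
    """Recover the (possibly dotted) names an operator.attrgetter was built with (sketch)."""
    class Dummy(object):
        def __init__(self, attrs, index=None):
            self.attrs = attrs
            self.index = index
        def __getattribute__(self, item):
            attrs = object.__getattribute__(self, "attrs")
            index = object.__getattribute__(self, "index")
            if index is None:
                index = len(attrs)
                attrs.append(item)       # new top-level attribute
            else:
                attrs[index] = ".".join([attrs[index], item])  # extend dotted path
            return type(self)(attrs, index)
    attrs = []
    getter(Dummy(attrs))
    return tuple(attrs)

print itemgetter_args(operator.itemgetter(0, 3))        # (0, 3)
print attrgetter_args(operator.attrgetter("e.a", "b"))  # ('e.a', 'b')
```

The same file also drops the old bytecode/ctypes cell machinery: `_make_cell` captures a value in a lambda's closure and `_reconstruct_closure` rebuilds the tuple of cells, which avoids the CPython-only `ctypes.pythonapi` calls and therefore works under PyPy.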
diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py
index 22ab8d30c0..15445abf67 100644
--- a/python/pyspark/daemon.py
+++ b/python/pyspark/daemon.py
@@ -42,10 +42,6 @@ def worker(sock):
"""
Called by a worker process after the fork().
"""
- # Redirect stdout to stderr
- os.dup2(2, 1)
- sys.stdout = sys.stderr # The sys.stdout object is different from file descriptor 1
-
signal.signal(SIGHUP, SIG_DFL)
signal.signal(SIGCHLD, SIG_DFL)
signal.signal(SIGTERM, SIG_DFL)
@@ -102,6 +98,7 @@ def manager():
listen_sock.listen(max(1024, SOMAXCONN))
listen_host, listen_port = listen_sock.getsockname()
write_int(listen_port, sys.stdout)
+ sys.stdout.flush()
def shutdown(code):
signal.signal(SIGTERM, SIG_DFL)
@@ -115,7 +112,6 @@ def manager():
signal.signal(SIGHUP, SIG_IGN) # Don't die on SIGHUP
# Initialization complete
- sys.stdout.close()
try:
while True:
try:
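
With the `close()` removed, the daemon keeps stdout open and simply flushes after reporting its listen port. A rough sketch of that handshake, assuming `write_int` (defined in `pyspark/serializers.py`) frames the value as a 4-byte big-endian integer:

```python
import struct
import sys

def write_int(value, stream):
    # Assumed to mirror pyspark.serializers.write_int: 4-byte big-endian frame.
    stream.write(struct.pack("!i", value))

# What manager() does after binding the listening socket:
listen_port = 12345  # placeholder; the real value comes from getsockname()
write_int(listen_port, sys.stdout)
sys.stdout.flush()   # flush so the parent sees the port; the old close() is gone
```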
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 7b2710b913..a5f9341e81 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -355,7 +355,8 @@ class PickleSerializer(FramedSerializer):
def dumps(self, obj):
return cPickle.dumps(obj, 2)
- loads = cPickle.loads
+ def loads(self, obj):
+ return cPickle.loads(obj)
class CloudPickleSerializer(PickleSerializer):
@@ -374,8 +375,11 @@ class MarshalSerializer(FramedSerializer):
This serializer is faster than PickleSerializer but supports fewer datatypes.
"""
- dumps = marshal.dumps
- loads = marshal.loads
+ def dumps(self, obj):
+ return marshal.dumps(obj)
+
+ def loads(self, obj):
+ return marshal.loads(obj)
class AutoSerializer(FramedSerializer):
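
The serializer change replaces builtin functions assigned as class attributes (`loads = cPickle.loads`, `dumps = marshal.dumps`) with explicit instance methods. A plausible motivation, not spelled out in the commit message, is that builtin and plain Python functions behave differently as class attributes across interpreters; the wrapper form below is unambiguous on both CPython and PyPy (class name is illustrative):

```python
import marshal

class MarshalLike(object):
    # Explicit wrappers: `self` is consumed here, so the call resolves the
    # same way regardless of how the interpreter treats the underlying builtin.
    def dumps(self, obj):
        return marshal.dumps(obj)

    def loads(self, obj):
        return marshal.loads(obj)

ser = MarshalLike()
data = ser.dumps([1, 2, 3])
assert ser.loads(data) == [1, 2, 3]
```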
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index bb84ebe72c..2e7c2750a8 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -31,6 +31,7 @@ import tempfile
import time
import zipfile
import random
+from platform import python_implementation
if sys.version_info[:2] <= (2, 6):
import unittest2 as unittest
@@ -41,7 +42,8 @@ else:
from pyspark.conf import SparkConf
from pyspark.context import SparkContext
from pyspark.files import SparkFiles
-from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer
+from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \
+ CloudPickleSerializer
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter
from pyspark.sql import SQLContext, IntegerType
@@ -168,15 +170,46 @@ class SerializationTestCase(unittest.TestCase):
p2 = loads(dumps(p1, 2))
self.assertEquals(p1, p2)
-
-# Regression test for SPARK-3415
-class CloudPickleTest(unittest.TestCase):
+ def test_itemgetter(self):
+ from operator import itemgetter
+ ser = CloudPickleSerializer()
+ d = range(10)
+ getter = itemgetter(1)
+ getter2 = ser.loads(ser.dumps(getter))
+ self.assertEqual(getter(d), getter2(d))
+
+ getter = itemgetter(0, 3)
+ getter2 = ser.loads(ser.dumps(getter))
+ self.assertEqual(getter(d), getter2(d))
+
+ def test_attrgetter(self):
+ from operator import attrgetter
+ ser = CloudPickleSerializer()
+
+ class C(object):
+ def __getattr__(self, item):
+ return item
+ d = C()
+ getter = attrgetter("a")
+ getter2 = ser.loads(ser.dumps(getter))
+ self.assertEqual(getter(d), getter2(d))
+ getter = attrgetter("a", "b")
+ getter2 = ser.loads(ser.dumps(getter))
+ self.assertEqual(getter(d), getter2(d))
+
+ d.e = C()
+ getter = attrgetter("e.a")
+ getter2 = ser.loads(ser.dumps(getter))
+ self.assertEqual(getter(d), getter2(d))
+ getter = attrgetter("e.a", "e.b")
+ getter2 = ser.loads(ser.dumps(getter))
+ self.assertEqual(getter(d), getter2(d))
+
+ # Regression test for SPARK-3415
def test_pickling_file_handles(self):
- from pyspark.cloudpickle import dumps
- from StringIO import StringIO
- from pickle import load
+ ser = CloudPickleSerializer()
out1 = sys.stderr
- out2 = load(StringIO(dumps(out1)))
+ out2 = ser.loads(ser.dumps(out1))
self.assertEquals(out1, out2)
@@ -861,9 +894,43 @@ class TestOutputFormat(PySparkTestCase):
conf=input_conf).collect())
self.assertEqual(old_dataset, dict_data)
- @unittest.skipIf(sys.version_info[:2] <= (2, 6), "Skipped on 2.6 until SPARK-2951 is fixed")
def test_newhadoop(self):
basepath = self.tempdir.name
+ data = [(1, ""),
+ (1, "a"),
+ (2, "bcdf")]
+ self.sc.parallelize(data).saveAsNewAPIHadoopFile(
+ basepath + "/newhadoop/",
+ "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat",
+ "org.apache.hadoop.io.IntWritable",
+ "org.apache.hadoop.io.Text")
+ result = sorted(self.sc.newAPIHadoopFile(
+ basepath + "/newhadoop/",
+ "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat",
+ "org.apache.hadoop.io.IntWritable",
+ "org.apache.hadoop.io.Text").collect())
+ self.assertEqual(result, data)
+
+ conf = {
+ "mapreduce.outputformat.class":
+ "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat",
+ "mapred.output.key.class": "org.apache.hadoop.io.IntWritable",
+ "mapred.output.value.class": "org.apache.hadoop.io.Text",
+ "mapred.output.dir": basepath + "/newdataset/"
+ }
+ self.sc.parallelize(data).saveAsNewAPIHadoopDataset(conf)
+ input_conf = {"mapred.input.dir": basepath + "/newdataset/"}
+ new_dataset = sorted(self.sc.newAPIHadoopRDD(
+ "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat",
+ "org.apache.hadoop.io.IntWritable",
+ "org.apache.hadoop.io.Text",
+ conf=input_conf).collect())
+ self.assertEqual(new_dataset, data)
+
+ @unittest.skipIf(sys.version_info[:2] <= (2, 6) or python_implementation() == "PyPy",
+ "Skipped on 2.6 and PyPy until SPARK-2951 is fixed")
+ def test_newhadoop_with_array(self):
+ basepath = self.tempdir.name
# use custom ArrayWritable types and converters to handle arrays
array_data = [(1, array('d')),
(1, array('d', [1.0, 2.0, 3.0])),
diff --git a/python/run-tests b/python/run-tests
index d98840de59..a67e5a99fb 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -85,6 +85,27 @@ run_test "pyspark/mllib/tests.py"
run_test "pyspark/mllib/tree.py"
run_test "pyspark/mllib/util.py"
+# Try to test with PyPy
+if [ $(which pypy) ]; then
+ export PYSPARK_PYTHON="pypy"
+ echo "Testing with PyPy version:"
+ $PYSPARK_PYTHON --version
+
+ run_test "pyspark/rdd.py"
+ run_test "pyspark/context.py"
+ run_test "pyspark/conf.py"
+ run_test "pyspark/sql.py"
+ # These tests are included in the module-level docs, and so must
+ # be handled on a higher level rather than within the python file.
+ export PYSPARK_DOC_TEST=1
+ run_test "pyspark/broadcast.py"
+ run_test "pyspark/accumulators.py"
+ run_test "pyspark/serializers.py"
+ unset PYSPARK_DOC_TEST
+ run_test "pyspark/shuffle.py"
+ run_test "pyspark/tests.py"
+fi
+
if [[ $FAILED == 0 ]]; then
echo -en "\033[32m" # Green
echo "Tests passed."