From b58340dbd9a741331fc4c3829b08c093560056c2 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 1 Jan 2013 14:48:45 -0800 Subject: Rename top-level 'pyspark' directory to 'python' --- python/.gitignore | 2 + python/epydoc.conf | 19 + python/examples/kmeans.py | 52 ++ python/examples/logistic_regression.py | 57 ++ python/examples/pi.py | 21 + python/examples/transitive_closure.py | 50 ++ python/examples/wordcount.py | 19 + python/lib/PY4J_LICENSE.txt | 27 + python/lib/PY4J_VERSION.txt | 1 + python/lib/py4j0.7.egg | Bin 0 -> 191756 bytes python/lib/py4j0.7.jar | Bin 0 -> 103286 bytes python/pyspark/__init__.py | 20 + python/pyspark/broadcast.py | 48 ++ python/pyspark/cloudpickle.py | 974 +++++++++++++++++++++++++++++++++ python/pyspark/context.py | 158 ++++++ python/pyspark/java_gateway.py | 38 ++ python/pyspark/join.py | 92 ++++ python/pyspark/rdd.py | 713 ++++++++++++++++++++++++ python/pyspark/serializers.py | 78 +++ python/pyspark/shell.py | 33 ++ python/pyspark/worker.py | 40 ++ 21 files changed, 2442 insertions(+) create mode 100644 python/.gitignore create mode 100644 python/epydoc.conf create mode 100644 python/examples/kmeans.py create mode 100755 python/examples/logistic_regression.py create mode 100644 python/examples/pi.py create mode 100644 python/examples/transitive_closure.py create mode 100644 python/examples/wordcount.py create mode 100644 python/lib/PY4J_LICENSE.txt create mode 100644 python/lib/PY4J_VERSION.txt create mode 100644 python/lib/py4j0.7.egg create mode 100644 python/lib/py4j0.7.jar create mode 100644 python/pyspark/__init__.py create mode 100644 python/pyspark/broadcast.py create mode 100644 python/pyspark/cloudpickle.py create mode 100644 python/pyspark/context.py create mode 100644 python/pyspark/java_gateway.py create mode 100644 python/pyspark/join.py create mode 100644 python/pyspark/rdd.py create mode 100644 python/pyspark/serializers.py create mode 100644 python/pyspark/shell.py create mode 100644 python/pyspark/worker.py (limited to 'python') diff --git a/python/.gitignore b/python/.gitignore new file mode 100644 index 0000000000..5c56e638f9 --- /dev/null +++ b/python/.gitignore @@ -0,0 +1,2 @@ +*.pyc +docs/ diff --git a/python/epydoc.conf b/python/epydoc.conf new file mode 100644 index 0000000000..91ac984ba2 --- /dev/null +++ b/python/epydoc.conf @@ -0,0 +1,19 @@ +[epydoc] # Epydoc section marker (required by ConfigParser) + +# Information about the project. +name: PySpark +url: http://spark-project.org + +# The list of modules to document. Modules can be named using +# dotted names, module filenames, or package directory names. +# This option may be repeated. 
+modules: pyspark + +# Write html output to the directory "apidocs" +output: html +target: docs/ + +private: no + +exclude: pyspark.cloudpickle pyspark.worker pyspark.join pyspark.serializers + pyspark.java_gateway pyspark.examples pyspark.shell diff --git a/python/examples/kmeans.py b/python/examples/kmeans.py new file mode 100644 index 0000000000..ad2be21178 --- /dev/null +++ b/python/examples/kmeans.py @@ -0,0 +1,52 @@ +""" +This example requires numpy (http://www.numpy.org/) +""" +import sys + +import numpy as np +from pyspark import SparkContext + + +def parseVector(line): + return np.array([float(x) for x in line.split(' ')]) + + +def closestPoint(p, centers): + bestIndex = 0 + closest = float("+inf") + for i in range(len(centers)): + tempDist = np.sum((p - centers[i]) ** 2) + if tempDist < closest: + closest = tempDist + bestIndex = i + return bestIndex + + +if __name__ == "__main__": + if len(sys.argv) < 5: + print >> sys.stderr, \ + "Usage: PythonKMeans " + exit(-1) + sc = SparkContext(sys.argv[1], "PythonKMeans") + lines = sc.textFile(sys.argv[2]) + data = lines.map(parseVector).cache() + K = int(sys.argv[3]) + convergeDist = float(sys.argv[4]) + + kPoints = data.takeSample(False, K, 34) + tempDist = 1.0 + + while tempDist > convergeDist: + closest = data.map( + lambda p : (closestPoint(p, kPoints), (p, 1))) + pointStats = closest.reduceByKey( + lambda (x1, y1), (x2, y2): (x1 + x2, y1 + y2)) + newPoints = pointStats.map( + lambda (x, (y, z)): (x, y / z)).collect() + + tempDist = sum(np.sum((kPoints[x] - y) ** 2) for (x, y) in newPoints) + + for (x, y) in newPoints: + kPoints[x] = y + + print "Final centers: " + str(kPoints) diff --git a/python/examples/logistic_regression.py b/python/examples/logistic_regression.py new file mode 100755 index 0000000000..f13698a86f --- /dev/null +++ b/python/examples/logistic_regression.py @@ -0,0 +1,57 @@ +""" +This example requires numpy (http://www.numpy.org/) +""" +from collections import namedtuple +from math import exp +from os.path import realpath +import sys + +import numpy as np +from pyspark import SparkContext + + +N = 100000 # Number of data points +D = 10 # Number of dimensions +R = 0.7 # Scaling factor +ITERATIONS = 5 +np.random.seed(42) + + +DataPoint = namedtuple("DataPoint", ['x', 'y']) +from lr import DataPoint # So that DataPoint is properly serialized + + +def generateData(): + def generatePoint(i): + y = -1 if i % 2 == 0 else 1 + x = np.random.normal(size=D) + (y * R) + return DataPoint(x, y) + return [generatePoint(i) for i in range(N)] + + +if __name__ == "__main__": + if len(sys.argv) == 1: + print >> sys.stderr, \ + "Usage: PythonLR []" + exit(-1) + sc = SparkContext(sys.argv[1], "PythonLR", pyFiles=[realpath(__file__)]) + slices = int(sys.argv[2]) if len(sys.argv) > 2 else 2 + points = sc.parallelize(generateData(), slices).cache() + + # Initialize w to a random value + w = 2 * np.random.ranf(size=D) - 1 + print "Initial w: " + str(w) + + def add(x, y): + x += y + return x + + for i in range(1, ITERATIONS + 1): + print "On iteration %i" % i + + gradient = points.map(lambda p: + (1.0 / (1.0 + exp(-p.y * np.dot(w, p.x)))) * p.y * p.x + ).reduce(add) + w -= gradient + + print "Final w: " + str(w) diff --git a/python/examples/pi.py b/python/examples/pi.py new file mode 100644 index 0000000000..127cba029b --- /dev/null +++ b/python/examples/pi.py @@ -0,0 +1,21 @@ +import sys +from random import random +from operator import add + +from pyspark import SparkContext + + +if __name__ == "__main__": + if len(sys.argv) == 1: + print 
>> sys.stderr, \ + "Usage: PythonPi []" + exit(-1) + sc = SparkContext(sys.argv[1], "PythonPi") + slices = int(sys.argv[2]) if len(sys.argv) > 2 else 2 + n = 100000 * slices + def f(_): + x = random() * 2 - 1 + y = random() * 2 - 1 + return 1 if x ** 2 + y ** 2 < 1 else 0 + count = sc.parallelize(xrange(1, n+1), slices).map(f).reduce(add) + print "Pi is roughly %f" % (4.0 * count / n) diff --git a/python/examples/transitive_closure.py b/python/examples/transitive_closure.py new file mode 100644 index 0000000000..73f7f8fbaf --- /dev/null +++ b/python/examples/transitive_closure.py @@ -0,0 +1,50 @@ +import sys +from random import Random + +from pyspark import SparkContext + +numEdges = 200 +numVertices = 100 +rand = Random(42) + + +def generateGraph(): + edges = set() + while len(edges) < numEdges: + src = rand.randrange(0, numEdges) + dst = rand.randrange(0, numEdges) + if src != dst: + edges.add((src, dst)) + return edges + + +if __name__ == "__main__": + if len(sys.argv) == 1: + print >> sys.stderr, \ + "Usage: PythonTC []" + exit(-1) + sc = SparkContext(sys.argv[1], "PythonTC") + slices = sys.argv[2] if len(sys.argv) > 2 else 2 + tc = sc.parallelize(generateGraph(), slices).cache() + + # Linear transitive closure: each round grows paths by one edge, + # by joining the graph's edges with the already-discovered paths. + # e.g. join the path (y, z) from the TC with the edge (x, y) from + # the graph to obtain the path (x, z). + + # Because join() joins on keys, the edges are stored in reversed order. + edges = tc.map(lambda (x, y): (y, x)) + + oldCount = 0L + nextCount = tc.count() + while True: + oldCount = nextCount + # Perform the join, obtaining an RDD of (y, (z, x)) pairs, + # then project the result to obtain the new (x, z) paths. + new_edges = tc.join(edges).map(lambda (_, (a, b)): (b, a)) + tc = tc.union(new_edges).distinct().cache() + nextCount = tc.count() + if nextCount == oldCount: + break + + print "TC has %i edges" % tc.count() diff --git a/python/examples/wordcount.py b/python/examples/wordcount.py new file mode 100644 index 0000000000..857160624b --- /dev/null +++ b/python/examples/wordcount.py @@ -0,0 +1,19 @@ +import sys +from operator import add + +from pyspark import SparkContext + + +if __name__ == "__main__": + if len(sys.argv) < 3: + print >> sys.stderr, \ + "Usage: PythonWordCount " + exit(-1) + sc = SparkContext(sys.argv[1], "PythonWordCount") + lines = sc.textFile(sys.argv[2], 1) + counts = lines.flatMap(lambda x: x.split(' ')) \ + .map(lambda x: (x, 1)) \ + .reduceByKey(add) + output = counts.collect() + for (word, count) in output: + print "%s : %i" % (word, count) diff --git a/python/lib/PY4J_LICENSE.txt b/python/lib/PY4J_LICENSE.txt new file mode 100644 index 0000000000..a70279ca14 --- /dev/null +++ b/python/lib/PY4J_LICENSE.txt @@ -0,0 +1,27 @@ + +Copyright (c) 2009-2011, Barthelemy Dagenais All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +- Redistributions of source code must retain the above copyright notice, this +list of conditions and the following disclaimer. + +- Redistributions in binary form must reproduce the above copyright notice, +this list of conditions and the following disclaimer in the documentation +and/or other materials provided with the distribution. + +- The name of the author may not be used to endorse or promote products +derived from this software without specific prior written permission. 
+ +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. diff --git a/python/lib/PY4J_VERSION.txt b/python/lib/PY4J_VERSION.txt new file mode 100644 index 0000000000..04a0cd52a8 --- /dev/null +++ b/python/lib/PY4J_VERSION.txt @@ -0,0 +1 @@ +b7924aabe9c5e63f0a4d8bbd17019534c7ec014e diff --git a/python/lib/py4j0.7.egg b/python/lib/py4j0.7.egg new file mode 100644 index 0000000000..f8a339d8ee Binary files /dev/null and b/python/lib/py4j0.7.egg differ diff --git a/python/lib/py4j0.7.jar b/python/lib/py4j0.7.jar new file mode 100644 index 0000000000..73b7ddb7d1 Binary files /dev/null and b/python/lib/py4j0.7.jar differ diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py new file mode 100644 index 0000000000..c595ae0842 --- /dev/null +++ b/python/pyspark/__init__.py @@ -0,0 +1,20 @@ +""" +PySpark is a Python API for Spark. + +Public classes: + + - L{SparkContext} + Main entry point for Spark functionality. + - L{RDD} + A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. +""" +import sys +import os +sys.path.insert(0, os.path.join(os.environ["SPARK_HOME"], "python/lib/py4j0.7.egg")) + + +from pyspark.context import SparkContext +from pyspark.rdd import RDD + + +__all__ = ["SparkContext", "RDD"] diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py new file mode 100644 index 0000000000..93876fa738 --- /dev/null +++ b/python/pyspark/broadcast.py @@ -0,0 +1,48 @@ +""" +>>> from pyspark.context import SparkContext +>>> sc = SparkContext('local', 'test') +>>> b = sc.broadcast([1, 2, 3, 4, 5]) +>>> b.value +[1, 2, 3, 4, 5] + +>>> from pyspark.broadcast import _broadcastRegistry +>>> _broadcastRegistry[b.bid] = b +>>> from cPickle import dumps, loads +>>> loads(dumps(b)).value +[1, 2, 3, 4, 5] + +>>> sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect() +[1, 2, 3, 4, 5, 1, 2, 3, 4, 5] + +>>> large_broadcast = sc.broadcast(list(range(10000))) +""" +# Holds broadcasted data received from Java, keyed by its id. +_broadcastRegistry = {} + + +def _from_id(bid): + from pyspark.broadcast import _broadcastRegistry + if bid not in _broadcastRegistry: + raise Exception("Broadcast variable '%s' not loaded!" 
% bid) + return _broadcastRegistry[bid] + + +class Broadcast(object): + def __init__(self, bid, value, java_broadcast=None, pickle_registry=None): + self.value = value + self.bid = bid + self._jbroadcast = java_broadcast + self._pickle_registry = pickle_registry + + def __reduce__(self): + self._pickle_registry.add(self) + return (_from_id, (self.bid, )) + + +def _test(): + import doctest + doctest.testmod() + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py new file mode 100644 index 0000000000..6a7c23a069 --- /dev/null +++ b/python/pyspark/cloudpickle.py @@ -0,0 +1,974 @@ +""" +This class is defined to override standard pickle functionality + +The goals of it follow: +-Serialize lambdas and nested functions to compiled byte code +-Deal with main module correctly +-Deal with other non-serializable objects + +It does not include an unpickler, as standard python unpickling suffices. + +This module was extracted from the `cloud` package, developed by `PiCloud, Inc. +`_. + +Copyright (c) 2012, Regents of the University of California. +Copyright (c) 2009 `PiCloud, Inc. `_. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + * Neither the name of the University of California, Berkeley nor the + names of its contributors may be used to endorse or promote + products derived from this software without specific prior written + permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED +TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +""" + + +import operator +import os +import pickle +import struct +import sys +import types +from functools import partial +import itertools +from copy_reg import _extension_registry, _inverted_registry, _extension_cache +import new +import dis +import traceback + +#relevant opcodes +STORE_GLOBAL = chr(dis.opname.index('STORE_GLOBAL')) +DELETE_GLOBAL = chr(dis.opname.index('DELETE_GLOBAL')) +LOAD_GLOBAL = chr(dis.opname.index('LOAD_GLOBAL')) +GLOBAL_OPS = [STORE_GLOBAL, DELETE_GLOBAL, LOAD_GLOBAL] + +HAVE_ARGUMENT = chr(dis.HAVE_ARGUMENT) +EXTENDED_ARG = chr(dis.EXTENDED_ARG) + +import logging +cloudLog = logging.getLogger("Cloud.Transport") + +try: + import ctypes +except (MemoryError, ImportError): + logging.warning('Exception raised on importing ctypes. Likely python bug.. 
some functionality will be disabled', exc_info = True) + ctypes = None + PyObject_HEAD = None +else: + + # for reading internal structures + PyObject_HEAD = [ + ('ob_refcnt', ctypes.c_size_t), + ('ob_type', ctypes.c_void_p), + ] + + +try: + from cStringIO import StringIO +except ImportError: + from StringIO import StringIO + +# These helper functions were copied from PiCloud's util module. +def islambda(func): + return getattr(func,'func_name') == '' + +def xrange_params(xrangeobj): + """Returns a 3 element tuple describing the xrange start, step, and len + respectively + + Note: Only guarentees that elements of xrange are the same. parameters may + be different. + e.g. xrange(1,1) is interpretted as xrange(0,0); both behave the same + though w/ iteration + """ + + xrange_len = len(xrangeobj) + if not xrange_len: #empty + return (0,1,0) + start = xrangeobj[0] + if xrange_len == 1: #one element + return start, 1, 1 + return (start, xrangeobj[1] - xrangeobj[0], xrange_len) + +#debug variables intended for developer use: +printSerialization = False +printMemoization = False + +useForcedImports = True #Should I use forced imports for tracking? + + + +class CloudPickler(pickle.Pickler): + + dispatch = pickle.Pickler.dispatch.copy() + savedForceImports = False + savedDjangoEnv = False #hack tro transport django environment + + def __init__(self, file, protocol=None, min_size_to_save= 0): + pickle.Pickler.__init__(self,file,protocol) + self.modules = set() #set of modules needed to depickle + self.globals_ref = {} # map ids to dictionary. used to ensure that functions can share global env + + def dump(self, obj): + # note: not thread safe + # minimal side-effects, so not fixing + recurse_limit = 3000 + base_recurse = sys.getrecursionlimit() + if base_recurse < recurse_limit: + sys.setrecursionlimit(recurse_limit) + self.inject_addons() + try: + return pickle.Pickler.dump(self, obj) + except RuntimeError, e: + if 'recursion' in e.args[0]: + msg = """Could not pickle object as excessively deep recursion required. + Try _fast_serialization=2 or contact PiCloud support""" + raise pickle.PicklingError(msg) + finally: + new_recurse = sys.getrecursionlimit() + if new_recurse == recurse_limit: + sys.setrecursionlimit(base_recurse) + + def save_buffer(self, obj): + """Fallback to save_string""" + pickle.Pickler.save_string(self,str(obj)) + dispatch[buffer] = save_buffer + + #block broken objects + def save_unsupported(self, obj, pack=None): + raise pickle.PicklingError("Cannot pickle objects of type %s" % type(obj)) + dispatch[types.GeneratorType] = save_unsupported + + #python2.6+ supports slice pickling. some py2.5 extensions might as well. We just test it + try: + slice(0,1).__reduce__() + except TypeError: #can't pickle - + dispatch[slice] = save_unsupported + + #itertools objects do not pickle! 
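# (illustrative aside: the standard pickler rejects these iterators, e.g. on
#  CPython 2.x
#      import pickle, itertools
#      pickle.dumps(itertools.count())   # TypeError: can't pickle count objects
#  so the loop below routes every type exposed by the itertools module to
#  save_unsupported, turning the failure into a clearer PicklingError.)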
+ for v in itertools.__dict__.values(): + if type(v) is type: + dispatch[v] = save_unsupported + + + def save_dict(self, obj): + """hack fix + If the dict is a global, deal with it in a special way + """ + #print 'saving', obj + if obj is __builtins__: + self.save_reduce(_get_module_builtins, (), obj=obj) + else: + pickle.Pickler.save_dict(self, obj) + dispatch[pickle.DictionaryType] = save_dict + + + def save_module(self, obj, pack=struct.pack): + """ + Save a module as an import + """ + #print 'try save import', obj.__name__ + self.modules.add(obj) + self.save_reduce(subimport,(obj.__name__,), obj=obj) + dispatch[types.ModuleType] = save_module #new type + + def save_codeobject(self, obj, pack=struct.pack): + """ + Save a code object + """ + #print 'try to save codeobj: ', obj + args = ( + obj.co_argcount, obj.co_nlocals, obj.co_stacksize, obj.co_flags, obj.co_code, + obj.co_consts, obj.co_names, obj.co_varnames, obj.co_filename, obj.co_name, + obj.co_firstlineno, obj.co_lnotab, obj.co_freevars, obj.co_cellvars + ) + self.save_reduce(types.CodeType, args, obj=obj) + dispatch[types.CodeType] = save_codeobject #new type + + def save_function(self, obj, name=None, pack=struct.pack): + """ Registered with the dispatch to handle all function types. + + Determines what kind of function obj is (e.g. lambda, defined at + interactive prompt, etc) and handles the pickling appropriately. + """ + write = self.write + + name = obj.__name__ + modname = pickle.whichmodule(obj, name) + #print 'which gives %s %s %s' % (modname, obj, name) + try: + themodule = sys.modules[modname] + except KeyError: # eval'd items such as namedtuple give invalid items for their function __module__ + modname = '__main__' + + if modname == '__main__': + themodule = None + + if themodule: + self.modules.add(themodule) + + if not self.savedDjangoEnv: + #hack for django - if we detect the settings module, we transport it + django_settings = os.environ.get('DJANGO_SETTINGS_MODULE', '') + if django_settings: + django_mod = sys.modules.get(django_settings) + if django_mod: + cloudLog.debug('Transporting django settings %s during save of %s', django_mod, name) + self.savedDjangoEnv = True + self.modules.add(django_mod) + write(pickle.MARK) + self.save_reduce(django_settings_load, (django_mod.__name__,), obj=django_mod) + write(pickle.POP_MARK) + + + # if func is lambda, def'ed at prompt, is in main, or is nested, then + # we'll pickle the actual function object rather than simply saving a + # reference (as is done in default pickler), via save_function_tuple. 
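# (illustrative aside: the default pickler stores a function only as a
#  "module\nname\n" reference and re-imports it at load time, which cannot
#  work for, e.g.,
#      f = lambda x: x * 2     # has no importable name
#  or for functions defined in the driver's __main__, since that module does
#  not exist on the workers; those cases are pickled by value below.)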
+ if islambda(obj) or obj.func_code.co_filename == '' or themodule == None: + #Force server to import modules that have been imported in main + modList = None + if themodule == None and not self.savedForceImports: + mainmod = sys.modules['__main__'] + if useForcedImports and hasattr(mainmod,'___pyc_forcedImports__'): + modList = list(mainmod.___pyc_forcedImports__) + self.savedForceImports = True + self.save_function_tuple(obj, modList) + return + else: # func is nested + klass = getattr(themodule, name, None) + if klass is None or klass is not obj: + self.save_function_tuple(obj, [themodule]) + return + + if obj.__dict__: + # essentially save_reduce, but workaround needed to avoid recursion + self.save(_restore_attr) + write(pickle.MARK + pickle.GLOBAL + modname + '\n' + name + '\n') + self.memoize(obj) + self.save(obj.__dict__) + write(pickle.TUPLE + pickle.REDUCE) + else: + write(pickle.GLOBAL + modname + '\n' + name + '\n') + self.memoize(obj) + dispatch[types.FunctionType] = save_function + + def save_function_tuple(self, func, forced_imports): + """ Pickles an actual func object. + + A func comprises: code, globals, defaults, closure, and dict. We + extract and save these, injecting reducing functions at certain points + to recreate the func object. Keep in mind that some of these pieces + can contain a ref to the func itself. Thus, a naive save on these + pieces could trigger an infinite loop of save's. To get around that, + we first create a skeleton func object using just the code (this is + safe, since this won't contain a ref to the func), and memoize it as + soon as it's created. The other stuff can then be filled in later. + """ + save = self.save + write = self.write + + # save the modules (if any) + if forced_imports: + write(pickle.MARK) + save(_modules_to_main) + #print 'forced imports are', forced_imports + + forced_names = map(lambda m: m.__name__, forced_imports) + save((forced_names,)) + + #save((forced_imports,)) + write(pickle.REDUCE) + write(pickle.POP_MARK) + + code, f_globals, defaults, closure, dct, base_globals = self.extract_func_data(func) + + save(_fill_function) # skeleton function updater + write(pickle.MARK) # beginning of tuple that _fill_function expects + + # create a skeleton function object and memoize it + save(_make_skel_func) + save((code, len(closure), base_globals)) + write(pickle.REDUCE) + self.memoize(func) + + # save the rest of the func data needed by _fill_function + save(f_globals) + save(defaults) + save(closure) + save(dct) + write(pickle.TUPLE) + write(pickle.REDUCE) # applies _fill_function on the tuple + + @staticmethod + def extract_code_globals(co): + """ + Find all globals names read or written to by codeblock co + """ + code = co.co_code + names = co.co_names + out_names = set() + + n = len(code) + i = 0 + extended_arg = 0 + while i < n: + op = code[i] + + i = i+1 + if op >= HAVE_ARGUMENT: + oparg = ord(code[i]) + ord(code[i+1])*256 + extended_arg + extended_arg = 0 + i = i+2 + if op == EXTENDED_ARG: + extended_arg = oparg*65536L + if op in GLOBAL_OPS: + out_names.add(names[oparg]) + #print 'extracted', out_names, ' from ', names + return out_names + + def extract_func_data(self, func): + """ + Turn the function into a tuple of data necessary to recreate it: + code, globals, defaults, closure, dict + """ + code = func.func_code + + # extract all global ref's + func_global_refs = CloudPickler.extract_code_globals(code) + if code.co_consts: # see if nested function have any global refs + for const in code.co_consts: + if type(const) 
is types.CodeType and const.co_names: + func_global_refs = func_global_refs.union( CloudPickler.extract_code_globals(const)) + # process all variables referenced by global environment + f_globals = {} + for var in func_global_refs: + #Some names, such as class functions are not global - we don't need them + if func.func_globals.has_key(var): + f_globals[var] = func.func_globals[var] + + # defaults requires no processing + defaults = func.func_defaults + + def get_contents(cell): + try: + return cell.cell_contents + except ValueError, e: #cell is empty error on not yet assigned + raise pickle.PicklingError('Function to be pickled has free variables that are referenced before assignment in enclosing scope') + + + # process closure + if func.func_closure: + closure = map(get_contents, func.func_closure) + else: + closure = [] + + # save the dict + dct = func.func_dict + + if printSerialization: + outvars = ['code: ' + str(code) ] + outvars.append('globals: ' + str(f_globals)) + outvars.append('defaults: ' + str(defaults)) + outvars.append('closure: ' + str(closure)) + print 'function ', func, 'is extracted to: ', ', '.join(outvars) + + base_globals = self.globals_ref.get(id(func.func_globals), {}) + self.globals_ref[id(func.func_globals)] = base_globals + + return (code, f_globals, defaults, closure, dct, base_globals) + + def save_global(self, obj, name=None, pack=struct.pack): + write = self.write + memo = self.memo + + if name is None: + name = obj.__name__ + + modname = getattr(obj, "__module__", None) + if modname is None: + modname = pickle.whichmodule(obj, name) + + try: + __import__(modname) + themodule = sys.modules[modname] + except (ImportError, KeyError, AttributeError): #should never occur + raise pickle.PicklingError( + "Can't pickle %r: Module %s cannot be found" % + (obj, modname)) + + if modname == '__main__': + themodule = None + + if themodule: + self.modules.add(themodule) + + sendRef = True + typ = type(obj) + #print 'saving', obj, typ + try: + try: #Deal with case when getattribute fails with exceptions + klass = getattr(themodule, name) + except (AttributeError): + if modname == '__builtin__': #new.* are misrepeported + modname = 'new' + __import__(modname) + themodule = sys.modules[modname] + try: + klass = getattr(themodule, name) + except AttributeError, a: + #print themodule, name, obj, type(obj) + raise pickle.PicklingError("Can't pickle builtin %s" % obj) + else: + raise + + except (ImportError, KeyError, AttributeError): + if typ == types.TypeType or typ == types.ClassType: + sendRef = False + else: #we can't deal with this + raise + else: + if klass is not obj and (typ == types.TypeType or typ == types.ClassType): + sendRef = False + if not sendRef: + #note: Third party types might crash this - add better checks! 
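# (illustrative aside, assuming CPython 2.x: a new-style class's __dict__ is a
#  read-only dictproxy, e.g.
#      class Foo(object): pass
#      type(Foo.__dict__)      # <type 'dictproxy'> -- not picklable as-is
#  so it is copied into a plain dict below, the __dict__/__weakref__
#  descriptors are dropped, and the class is rebuilt via type(name, bases, d).)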
+ d = dict(obj.__dict__) #copy dict proxy to a dict + if not isinstance(d.get('__dict__', None), property): # don't extract dict that are properties + d.pop('__dict__',None) + d.pop('__weakref__',None) + + # hack as __new__ is stored differently in the __dict__ + new_override = d.get('__new__', None) + if new_override: + d['__new__'] = obj.__new__ + + self.save_reduce(type(obj),(obj.__name__,obj.__bases__, + d),obj=obj) + #print 'internal reduce dask %s %s' % (obj, d) + return + + if self.proto >= 2: + code = _extension_registry.get((modname, name)) + if code: + assert code > 0 + if code <= 0xff: + write(pickle.EXT1 + chr(code)) + elif code <= 0xffff: + write("%c%c%c" % (pickle.EXT2, code&0xff, code>>8)) + else: + write(pickle.EXT4 + pack("= 2 and getattr(func, "__name__", "") == "__newobj__": + #Added fix to allow transient + cls = args[0] + if not hasattr(cls, "__new__"): + raise pickle.PicklingError( + "args[0] from __newobj__ args has no __new__") + if obj is not None and cls is not obj.__class__: + raise pickle.PicklingError( + "args[0] from __newobj__ args has the wrong class") + args = args[1:] + save(cls) + + #Don't pickle transient entries + if hasattr(obj, '__transient__'): + transient = obj.__transient__ + state = state.copy() + + for k in list(state.keys()): + if k in transient: + del state[k] + + save(args) + write(pickle.NEWOBJ) + else: + save(func) + save(args) + write(pickle.REDUCE) + + if obj is not None: + self.memoize(obj) + + # More new special cases (that work with older protocols as + # well): when __reduce__ returns a tuple with 4 or 5 items, + # the 4th and 5th item should be iterators that provide list + # items and dict items (as (key, value) tuples), or None. + + if listitems is not None: + self._batch_appends(listitems) + + if dictitems is not None: + self._batch_setitems(dictitems) + + if state is not None: + #print 'obj %s has state %s' % (obj, state) + save(state) + write(pickle.BUILD) + + + def save_xrange(self, obj): + """Save an xrange object in python 2.5 + Python 2.6 supports this natively + """ + range_params = xrange_params(obj) + self.save_reduce(_build_xrange,range_params) + + #python2.6+ supports xrange pickling. some py2.5 extensions might as well. 
We just test it + try: + xrange(0).__reduce__() + except TypeError: #can't pickle -- use PiCloud pickler + dispatch[xrange] = save_xrange + + def save_partial(self, obj): + """Partial objects do not serialize correctly in python2.x -- this fixes the bugs""" + self.save_reduce(_genpartial, (obj.func, obj.args, obj.keywords)) + + if sys.version_info < (2,7): #2.7 supports partial pickling + dispatch[partial] = save_partial + + + def save_file(self, obj): + """Save a file""" + import StringIO as pystringIO #we can't use cStringIO as it lacks the name attribute + from ..transport.adapter import SerializingAdapter + + if not hasattr(obj, 'name') or not hasattr(obj, 'mode'): + raise pickle.PicklingError("Cannot pickle files that do not map to an actual file") + if obj.name == '': + return self.save_reduce(getattr, (sys,'stdout'), obj=obj) + if obj.name == '': + return self.save_reduce(getattr, (sys,'stderr'), obj=obj) + if obj.name == '': + raise pickle.PicklingError("Cannot pickle standard input") + if hasattr(obj, 'isatty') and obj.isatty(): + raise pickle.PicklingError("Cannot pickle files that map to tty objects") + if 'r' not in obj.mode: + raise pickle.PicklingError("Cannot pickle files that are not opened for reading") + name = obj.name + try: + fsize = os.stat(name).st_size + except OSError: + raise pickle.PicklingError("Cannot pickle file %s as it cannot be stat" % name) + + if obj.closed: + #create an empty closed string io + retval = pystringIO.StringIO("") + retval.close() + elif not fsize: #empty file + retval = pystringIO.StringIO("") + try: + tmpfile = file(name) + tst = tmpfile.read(1) + except IOError: + raise pickle.PicklingError("Cannot pickle file %s as it cannot be read" % name) + tmpfile.close() + if tst != '': + raise pickle.PicklingError("Cannot pickle file %s as it does not appear to map to a physical, real file" % name) + elif fsize > SerializingAdapter.max_transmit_data: + raise pickle.PicklingError("Cannot pickle file %s as it exceeds cloudconf.py's max_transmit_data of %d" % + (name,SerializingAdapter.max_transmit_data)) + else: + try: + tmpfile = file(name) + contents = tmpfile.read(SerializingAdapter.max_transmit_data) + tmpfile.close() + except IOError: + raise pickle.PicklingError("Cannot pickle file %s as it cannot be read" % name) + retval = pystringIO.StringIO(contents) + curloc = obj.tell() + retval.seek(curloc) + + retval.name = name + self.save(retval) #save stringIO + self.memoize(obj) + + dispatch[file] = save_file + """Special functions for Add-on libraries""" + + def inject_numpy(self): + numpy = sys.modules.get('numpy') + if not numpy or not hasattr(numpy, 'ufunc'): + return + self.dispatch[numpy.ufunc] = self.__class__.save_ufunc + + numpy_tst_mods = ['numpy', 'scipy.special'] + def save_ufunc(self, obj): + """Hack function for saving numpy ufunc objects""" + name = obj.__name__ + for tst_mod_name in self.numpy_tst_mods: + tst_mod = sys.modules.get(tst_mod_name, None) + if tst_mod: + if name in tst_mod.__dict__: + self.save_reduce(_getobject, (tst_mod_name, name)) + return + raise pickle.PicklingError('cannot save %s. 
Cannot resolve what module it is defined in' % str(obj)) + + def inject_timeseries(self): + """Handle bugs with pickling scikits timeseries""" + tseries = sys.modules.get('scikits.timeseries.tseries') + if not tseries or not hasattr(tseries, 'Timeseries'): + return + self.dispatch[tseries.Timeseries] = self.__class__.save_timeseries + + def save_timeseries(self, obj): + import scikits.timeseries.tseries as ts + + func, reduce_args, state = obj.__reduce__() + if func != ts._tsreconstruct: + raise pickle.PicklingError('timeseries using unexpected reconstruction function %s' % str(func)) + state = (1, + obj.shape, + obj.dtype, + obj.flags.fnc, + obj._data.tostring(), + ts.getmaskarray(obj).tostring(), + obj._fill_value, + obj._dates.shape, + obj._dates.__array__().tostring(), + obj._dates.dtype, #added -- preserve type + obj.freq, + obj._optinfo, + ) + return self.save_reduce(_genTimeSeries, (reduce_args, state)) + + def inject_email(self): + """Block email LazyImporters from being saved""" + email = sys.modules.get('email') + if not email: + return + self.dispatch[email.LazyImporter] = self.__class__.save_unsupported + + def inject_addons(self): + """Plug in system. Register additional pickling functions if modules already loaded""" + self.inject_numpy() + self.inject_timeseries() + self.inject_email() + + """Python Imaging Library""" + def save_image(self, obj): + if not obj.im and obj.fp and 'r' in obj.fp.mode and obj.fp.name \ + and not obj.fp.closed and (not hasattr(obj, 'isatty') or not obj.isatty()): + #if image not loaded yet -- lazy load + self.save_reduce(_lazyloadImage,(obj.fp,), obj=obj) + else: + #image is loaded - just transmit it over + self.save_reduce(_generateImage, (obj.size, obj.mode, obj.tostring()), obj=obj) + + """ + def memoize(self, obj): + pickle.Pickler.memoize(self, obj) + if printMemoization: + print 'memoizing ' + str(obj) + """ + + + +# Shorthands for legacy support + +def dump(obj, file, protocol=2): + CloudPickler(file, protocol).dump(obj) + +def dumps(obj, protocol=2): + file = StringIO() + + cp = CloudPickler(file,protocol) + cp.dump(obj) + + #print 'cloud dumped', str(obj), str(cp.modules) + + return file.getvalue() + + +#hack for __import__ not working as desired +def subimport(name): + __import__(name) + return sys.modules[name] + +#hack to load django settings: +def django_settings_load(name): + modified_env = False + + if 'DJANGO_SETTINGS_MODULE' not in os.environ: + os.environ['DJANGO_SETTINGS_MODULE'] = name # must set name first due to circular deps + modified_env = True + try: + module = subimport(name) + except Exception, i: + print >> sys.stderr, 'Cloud not import django settings %s:' % (name) + print_exec(sys.stderr) + if modified_env: + del os.environ['DJANGO_SETTINGS_MODULE'] + else: + #add project directory to sys,path: + if hasattr(module,'__file__'): + dirname = os.path.split(module.__file__)[0] + '/' + sys.path.append(dirname) + +# restores function attributes +def _restore_attr(obj, attr): + for key, val in attr.items(): + setattr(obj, key, val) + return obj + +def _get_module_builtins(): + return pickle.__builtins__ + +def print_exec(stream): + ei = sys.exc_info() + traceback.print_exception(ei[0], ei[1], ei[2], None, stream) + +def _modules_to_main(modList): + """Force every module in modList to be placed into main""" + if not modList: + return + + main = sys.modules['__main__'] + for modname in modList: + if type(modname) is str: + try: + mod = __import__(modname) + except Exception, i: #catch all... 
+ sys.stderr.write('warning: could not import %s\n. Your function may unexpectedly error due to this import failing; \ +A version mismatch is likely. Specific error was:\n' % modname) + print_exec(sys.stderr) + else: + setattr(main,mod.__name__, mod) + else: + #REVERSE COMPATIBILITY FOR CLOUD CLIENT 1.5 (WITH EPD) + #In old version actual module was sent + setattr(main,modname.__name__, modname) + +#object generators: +def _build_xrange(start, step, len): + """Built xrange explicitly""" + return xrange(start, start + step*len, step) + +def _genpartial(func, args, kwds): + if not args: + args = () + if not kwds: + kwds = {} + return partial(func, *args, **kwds) + + +def _fill_function(func, globals, defaults, closure, dict): + """ Fills in the rest of function data into the skeleton function object + that were created via _make_skel_func(). + """ + func.func_globals.update(globals) + func.func_defaults = defaults + func.func_dict = dict + + if len(closure) != len(func.func_closure): + raise pickle.UnpicklingError("closure lengths don't match up") + for i in range(len(closure)): + _change_cell_value(func.func_closure[i], closure[i]) + + return func + +def _make_skel_func(code, num_closures, base_globals = None): + """ Creates a skeleton function object that contains just the provided + code and the correct number of cells in func_closure. All other + func attributes (e.g. func_globals) are empty. + """ + #build closure (cells): + if not ctypes: + raise Exception('ctypes failed to import; cannot build function') + + cellnew = ctypes.pythonapi.PyCell_New + cellnew.restype = ctypes.py_object + cellnew.argtypes = (ctypes.py_object,) + dummy_closure = tuple(map(lambda i: cellnew(None), range(num_closures))) + + if base_globals is None: + base_globals = {} + base_globals['__builtins__'] = __builtins__ + + return types.FunctionType(code, base_globals, + None, None, dummy_closure) + +# this piece of opaque code is needed below to modify 'cell' contents +cell_changer_code = new.code( + 1, 1, 2, 0, + ''.join([ + chr(dis.opmap['LOAD_FAST']), '\x00\x00', + chr(dis.opmap['DUP_TOP']), + chr(dis.opmap['STORE_DEREF']), '\x00\x00', + chr(dis.opmap['RETURN_VALUE']) + ]), + (), (), ('newval',), '', 'cell_changer', 1, '', ('c',), () +) + +def _change_cell_value(cell, newval): + """ Changes the contents of 'cell' object to newval """ + return new.function(cell_changer_code, {}, None, (), (cell,))(newval) + +"""Constructors for 3rd party libraries +Note: These can never be renamed due to client compatibility issues""" + +def _getobject(modname, attribute): + mod = __import__(modname) + return mod.__dict__[attribute] + +def _generateImage(size, mode, str_rep): + """Generate image from string representation""" + import Image + i = Image.new(mode, size) + i.fromstring(str_rep) + return i + +def _lazyloadImage(fp): + import Image + fp.seek(0) #works in almost any case + return Image.open(fp) + +"""Timeseries""" +def _genTimeSeries(reduce_args, state): + import scikits.timeseries.tseries as ts + from numpy import ndarray + from numpy.ma import MaskedArray + + + time_series = ts._tsreconstruct(*reduce_args) + + #from setstate modified + (ver, shp, typ, isf, raw, msk, flv, dsh, dtm, dtyp, frq, infodict) = state + #print 'regenerating %s' % dtyp + + MaskedArray.__setstate__(time_series, (ver, shp, typ, isf, raw, msk, flv)) + _dates = time_series._dates + #_dates.__setstate__((ver, dsh, typ, isf, dtm, frq)) #use remote typ + ndarray.__setstate__(_dates,(dsh,dtyp, isf, dtm)) + _dates.freq = frq + 
_dates._cachedinfo.update(dict(full=None, hasdups=None, steps=None, + toobj=None, toord=None, tostr=None)) + # Update the _optinfo dictionary + time_series._optinfo.update(infodict) + return time_series + diff --git a/python/pyspark/context.py b/python/pyspark/context.py new file mode 100644 index 0000000000..6172d69dcf --- /dev/null +++ b/python/pyspark/context.py @@ -0,0 +1,158 @@ +import os +import atexit +from tempfile import NamedTemporaryFile + +from pyspark.broadcast import Broadcast +from pyspark.java_gateway import launch_gateway +from pyspark.serializers import dump_pickle, write_with_length, batched +from pyspark.rdd import RDD + +from py4j.java_collections import ListConverter + + +class SparkContext(object): + """ + Main entry point for Spark functionality. A SparkContext represents the + connection to a Spark cluster, and can be used to create L{RDD}s and + broadcast variables on that cluster. + """ + + gateway = launch_gateway() + jvm = gateway.jvm + _readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile + _writeIteratorToPickleFile = jvm.PythonRDD.writeIteratorToPickleFile + + def __init__(self, master, jobName, sparkHome=None, pyFiles=None, + environment=None, batchSize=1024): + """ + Create a new SparkContext. + + @param master: Cluster URL to connect to + (e.g. mesos://host:port, spark://host:port, local[4]). + @param jobName: A name for your job, to display on the cluster web UI + @param sparkHome: Location where Spark is installed on cluster nodes. + @param pyFiles: Collection of .zip or .py files to send to the cluster + and add to PYTHONPATH. These can be paths on the local file + system or HDFS, HTTP, HTTPS, or FTP URLs. + @param environment: A dictionary of environment variables to set on + worker nodes. + @param batchSize: The number of Python objects represented as a single + Java object. Set 1 to disable batching or -1 to use an + unlimited batch size. + """ + self.master = master + self.jobName = jobName + self.sparkHome = sparkHome or None # None becomes null in Py4J + self.environment = environment or {} + self.batchSize = batchSize # -1 represents a unlimited batch size + + # Create the Java SparkContext through Py4J + empty_string_array = self.gateway.new_array(self.jvm.String, 0) + self._jsc = self.jvm.JavaSparkContext(master, jobName, sparkHome, + empty_string_array) + + self.pythonExec = os.environ.get("PYSPARK_PYTHON_EXEC", 'python') + # Broadcast's __reduce__ method stores Broadcast instances here. + # This allows other code to determine which Broadcast instances have + # been pickled, so it can determine which Java broadcast objects to + # send. + self._pickled_broadcast_vars = set() + + # Deploy any code dependencies specified in the constructor + for path in (pyFiles or []): + self.addPyFile(path) + + @property + def defaultParallelism(self): + """ + Default level of parallelism to use when not given by user (e.g. for + reduce tasks) + """ + return self._jsc.sc().defaultParallelism() + + def __del__(self): + if self._jsc: + self._jsc.stop() + + def stop(self): + """ + Shut down the SparkContext. + """ + self._jsc.stop() + self._jsc = None + + def parallelize(self, c, numSlices=None): + """ + Distribute a local Python collection to form an RDD. + """ + numSlices = numSlices or self.defaultParallelism + # Calling the Java parallelize() method with an ArrayList is too slow, + # because it sends O(n) Py4J commands. As an alternative, serialized + # objects are written to a file and loaded through textFile(). 
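# (illustrative aside: for example
#      rdd = sc.parallelize(range(1000), 4)
#  pickles the (batched) elements into the NamedTemporaryFile created below
#  and then passes only the file name to the JVM through
#  PythonRDD.readRDDFromPickleFile, so the hand-off costs O(1) Py4J calls
#  instead of O(n).)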
+ tempFile = NamedTemporaryFile(delete=False) + atexit.register(lambda: os.unlink(tempFile.name)) + if self.batchSize != 1: + c = batched(c, self.batchSize) + for x in c: + write_with_length(dump_pickle(x), tempFile) + tempFile.close() + jrdd = self._readRDDFromPickleFile(self._jsc, tempFile.name, numSlices) + return RDD(jrdd, self) + + def textFile(self, name, minSplits=None): + """ + Read a text file from HDFS, a local file system (available on all + nodes), or any Hadoop-supported file system URI, and return it as an + RDD of Strings. + """ + minSplits = minSplits or min(self.defaultParallelism, 2) + jrdd = self._jsc.textFile(name, minSplits) + return RDD(jrdd, self) + + def union(self, rdds): + """ + Build the union of a list of RDDs. + """ + first = rdds[0]._jrdd + rest = [x._jrdd for x in rdds[1:]] + rest = ListConverter().convert(rest, self.gateway._gateway_client) + return RDD(self._jsc.union(first, rest), self) + + def broadcast(self, value): + """ + Broadcast a read-only variable to the cluster, returning a C{Broadcast} + object for reading it in distributed functions. The variable will be + sent to each cluster only once. + """ + jbroadcast = self._jsc.broadcast(bytearray(dump_pickle(value))) + return Broadcast(jbroadcast.id(), value, jbroadcast, + self._pickled_broadcast_vars) + + def addFile(self, path): + """ + Add a file to be downloaded into the working directory of this Spark + job on every node. The C{path} passed can be either a local file, + a file in HDFS (or other Hadoop-supported filesystems), or an HTTP, + HTTPS or FTP URI. + """ + self._jsc.sc().addFile(path) + + def clearFiles(self): + """ + Clear the job's list of files added by L{addFile} or L{addPyFile} so + that they do not get downloaded to any new nodes. + """ + # TODO: remove added .py or .zip files from the PYTHONPATH? + self._jsc.sc().clearFiles() + + def addPyFile(self, path): + """ + Add a .py or .zip dependency for all tasks to be executed on this + SparkContext in the future. The C{path} passed can be either a local + file, a file in HDFS (or other Hadoop-supported filesystems), or an + HTTP, HTTPS or FTP URI. 
+ """ + self.addFile(path) + filename = path.split("/")[-1] + os.environ["PYTHONPATH"] = \ + "%s:%s" % (filename, os.environ["PYTHONPATH"]) diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py new file mode 100644 index 0000000000..2329e536cc --- /dev/null +++ b/python/pyspark/java_gateway.py @@ -0,0 +1,38 @@ +import os +import sys +from subprocess import Popen, PIPE +from threading import Thread +from py4j.java_gateway import java_import, JavaGateway, GatewayClient + + +SPARK_HOME = os.environ["SPARK_HOME"] + + +def launch_gateway(): + # Launch the Py4j gateway using Spark's run command so that we pick up the + # proper classpath and SPARK_MEM settings from spark-env.sh + command = [os.path.join(SPARK_HOME, "run"), "py4j.GatewayServer", + "--die-on-broken-pipe", "0"] + proc = Popen(command, stdout=PIPE, stdin=PIPE) + # Determine which ephemeral port the server started on: + port = int(proc.stdout.readline()) + # Create a thread to echo output from the GatewayServer, which is required + # for Java log output to show up: + class EchoOutputThread(Thread): + def __init__(self, stream): + Thread.__init__(self) + self.daemon = True + self.stream = stream + + def run(self): + while True: + line = self.stream.readline() + sys.stderr.write(line) + EchoOutputThread(proc.stdout).start() + # Connect to the gateway + gateway = JavaGateway(GatewayClient(port=port), auto_convert=False) + # Import the classes used by PySpark + java_import(gateway.jvm, "spark.api.java.*") + java_import(gateway.jvm, "spark.api.python.*") + java_import(gateway.jvm, "scala.Tuple2") + return gateway diff --git a/python/pyspark/join.py b/python/pyspark/join.py new file mode 100644 index 0000000000..7036c47980 --- /dev/null +++ b/python/pyspark/join.py @@ -0,0 +1,92 @@ +""" +Copyright (c) 2011, Douban Inc. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + + * Neither the name of the Douban Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+""" + + +def _do_python_join(rdd, other, numSplits, dispatch): + vs = rdd.map(lambda (k, v): (k, (1, v))) + ws = other.map(lambda (k, v): (k, (2, v))) + return vs.union(ws).groupByKey(numSplits).flatMapValues(dispatch) + + +def python_join(rdd, other, numSplits): + def dispatch(seq): + vbuf, wbuf = [], [] + for (n, v) in seq: + if n == 1: + vbuf.append(v) + elif n == 2: + wbuf.append(v) + return [(v, w) for v in vbuf for w in wbuf] + return _do_python_join(rdd, other, numSplits, dispatch) + + +def python_right_outer_join(rdd, other, numSplits): + def dispatch(seq): + vbuf, wbuf = [], [] + for (n, v) in seq: + if n == 1: + vbuf.append(v) + elif n == 2: + wbuf.append(v) + if not vbuf: + vbuf.append(None) + return [(v, w) for v in vbuf for w in wbuf] + return _do_python_join(rdd, other, numSplits, dispatch) + + +def python_left_outer_join(rdd, other, numSplits): + def dispatch(seq): + vbuf, wbuf = [], [] + for (n, v) in seq: + if n == 1: + vbuf.append(v) + elif n == 2: + wbuf.append(v) + if not wbuf: + wbuf.append(None) + return [(v, w) for v in vbuf for w in wbuf] + return _do_python_join(rdd, other, numSplits, dispatch) + + +def python_cogroup(rdd, other, numSplits): + vs = rdd.map(lambda (k, v): (k, (1, v))) + ws = other.map(lambda (k, v): (k, (2, v))) + def dispatch(seq): + vbuf, wbuf = [], [] + for (n, v) in seq: + if n == 1: + vbuf.append(v) + elif n == 2: + wbuf.append(v) + return (vbuf, wbuf) + return vs.union(ws).groupByKey(numSplits).mapValues(dispatch) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py new file mode 100644 index 0000000000..cbffb6cc1f --- /dev/null +++ b/python/pyspark/rdd.py @@ -0,0 +1,713 @@ +import atexit +from base64 import standard_b64encode as b64enc +import copy +from collections import defaultdict +from itertools import chain, ifilter, imap, product +import operator +import os +import shlex +from subprocess import Popen, PIPE +from tempfile import NamedTemporaryFile +from threading import Thread + +from pyspark import cloudpickle +from pyspark.serializers import batched, Batch, dump_pickle, load_pickle, \ + read_from_pickle_file +from pyspark.join import python_join, python_left_outer_join, \ + python_right_outer_join, python_cogroup + +from py4j.java_collections import ListConverter, MapConverter + + +__all__ = ["RDD"] + + +class RDD(object): + """ + A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. + Represents an immutable, partitioned collection of elements that can be + operated on in parallel. + """ + + def __init__(self, jrdd, ctx): + self._jrdd = jrdd + self.is_cached = False + self.ctx = ctx + + @property + def context(self): + """ + The L{SparkContext} that this RDD was created on. + """ + return self.ctx + + def cache(self): + """ + Persist this RDD with the default storage level (C{MEMORY_ONLY}). + """ + self.is_cached = True + self._jrdd.cache() + return self + + # TODO persist(self, storageLevel) + + def map(self, f, preservesPartitioning=False): + """ + Return a new RDD containing the distinct elements in this RDD. + """ + def func(iterator): return imap(f, iterator) + return PipelinedRDD(self, func, preservesPartitioning) + + def flatMap(self, f, preservesPartitioning=False): + """ + Return a new RDD by first applying a function to all elements of this + RDD, and then flattening the results. 
+ + >>> rdd = sc.parallelize([2, 3, 4]) + >>> sorted(rdd.flatMap(lambda x: range(1, x)).collect()) + [1, 1, 1, 2, 2, 3] + >>> sorted(rdd.flatMap(lambda x: [(x, x), (x, x)]).collect()) + [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)] + """ + def func(iterator): return chain.from_iterable(imap(f, iterator)) + return self.mapPartitions(func, preservesPartitioning) + + def mapPartitions(self, f, preservesPartitioning=False): + """ + Return a new RDD by applying a function to each partition of this RDD. + + >>> rdd = sc.parallelize([1, 2, 3, 4], 2) + >>> def f(iterator): yield sum(iterator) + >>> rdd.mapPartitions(f).collect() + [3, 7] + """ + return PipelinedRDD(self, f, preservesPartitioning) + + # TODO: mapPartitionsWithSplit + + def filter(self, f): + """ + Return a new RDD containing only the elements that satisfy a predicate. + + >>> rdd = sc.parallelize([1, 2, 3, 4, 5]) + >>> rdd.filter(lambda x: x % 2 == 0).collect() + [2, 4] + """ + def func(iterator): return ifilter(f, iterator) + return self.mapPartitions(func) + + def distinct(self): + """ + Return a new RDD containing the distinct elements in this RDD. + + >>> sorted(sc.parallelize([1, 1, 2, 3]).distinct().collect()) + [1, 2, 3] + """ + return self.map(lambda x: (x, "")) \ + .reduceByKey(lambda x, _: x) \ + .map(lambda (x, _): x) + + # TODO: sampling needs to be re-implemented due to Batch + #def sample(self, withReplacement, fraction, seed): + # jrdd = self._jrdd.sample(withReplacement, fraction, seed) + # return RDD(jrdd, self.ctx) + + #def takeSample(self, withReplacement, num, seed): + # vals = self._jrdd.takeSample(withReplacement, num, seed) + # return [load_pickle(bytes(x)) for x in vals] + + def union(self, other): + """ + Return the union of this RDD and another one. + + >>> rdd = sc.parallelize([1, 1, 2, 3]) + >>> rdd.union(rdd).collect() + [1, 1, 2, 3, 1, 1, 2, 3] + """ + return RDD(self._jrdd.union(other._jrdd), self.ctx) + + def __add__(self, other): + """ + Return the union of this RDD and another one. + + >>> rdd = sc.parallelize([1, 1, 2, 3]) + >>> (rdd + rdd).collect() + [1, 1, 2, 3, 1, 1, 2, 3] + """ + if not isinstance(other, RDD): + raise TypeError + return self.union(other) + + # TODO: sort + + def glom(self): + """ + Return an RDD created by coalescing all elements within each partition + into a list. + + >>> rdd = sc.parallelize([1, 2, 3, 4], 2) + >>> sorted(rdd.glom().collect()) + [[1, 2], [3, 4]] + """ + def func(iterator): yield list(iterator) + return self.mapPartitions(func) + + def cartesian(self, other): + """ + Return the Cartesian product of this RDD and another one, that is, the + RDD of all pairs of elements C{(a, b)} where C{a} is in C{self} and + C{b} is in C{other}. + + >>> rdd = sc.parallelize([1, 2]) + >>> sorted(rdd.cartesian(rdd).collect()) + [(1, 1), (1, 2), (2, 1), (2, 2)] + """ + # Due to batching, we can't use the Java cartesian method. + java_cartesian = RDD(self._jrdd.cartesian(other._jrdd), self.ctx) + def unpack_batches(pair): + (x, y) = pair + if type(x) == Batch or type(y) == Batch: + xs = x.items if type(x) == Batch else [x] + ys = y.items if type(y) == Batch else [y] + for pair in product(xs, ys): + yield pair + else: + yield pair + return java_cartesian.flatMap(unpack_batches) + + def groupBy(self, f, numSplits=None): + """ + Return an RDD of grouped items. 
+ + >>> rdd = sc.parallelize([1, 1, 2, 3, 5, 8]) + >>> result = rdd.groupBy(lambda x: x % 2).collect() + >>> sorted([(x, sorted(y)) for (x, y) in result]) + [(0, [2, 8]), (1, [1, 1, 3, 5])] + """ + return self.map(lambda x: (f(x), x)).groupByKey(numSplits) + + def pipe(self, command, env={}): + """ + Return an RDD created by piping elements to a forked external process. + + >>> sc.parallelize([1, 2, 3]).pipe('cat').collect() + ['1', '2', '3'] + """ + def func(iterator): + pipe = Popen(shlex.split(command), env=env, stdin=PIPE, stdout=PIPE) + def pipe_objs(out): + for obj in iterator: + out.write(str(obj).rstrip('\n') + '\n') + out.close() + Thread(target=pipe_objs, args=[pipe.stdin]).start() + return (x.rstrip('\n') for x in pipe.stdout) + return self.mapPartitions(func) + + def foreach(self, f): + """ + Applies a function to all elements of this RDD. + + >>> def f(x): print x + >>> sc.parallelize([1, 2, 3, 4, 5]).foreach(f) + """ + self.map(f).collect() # Force evaluation + + def collect(self): + """ + Return a list that contains all of the elements in this RDD. + """ + picklesInJava = self._jrdd.collect().iterator() + return list(self._collect_iterator_through_file(picklesInJava)) + + def _collect_iterator_through_file(self, iterator): + # Transferring lots of data through Py4J can be slow because + # socket.readline() is inefficient. Instead, we'll dump the data to a + # file and read it back. + tempFile = NamedTemporaryFile(delete=False) + tempFile.close() + def clean_up_file(): + try: os.unlink(tempFile.name) + except: pass + atexit.register(clean_up_file) + self.ctx._writeIteratorToPickleFile(iterator, tempFile.name) + # Read the data into Python and deserialize it: + with open(tempFile.name, 'rb') as tempFile: + for item in read_from_pickle_file(tempFile): + yield item + os.unlink(tempFile.name) + + def reduce(self, f): + """ + Reduces the elements of this RDD using the specified associative binary + operator. + + >>> from operator import add + >>> sc.parallelize([1, 2, 3, 4, 5]).reduce(add) + 15 + >>> sc.parallelize((2 for _ in range(10))).map(lambda x: 1).cache().reduce(add) + 10 + """ + def func(iterator): + acc = None + for obj in iterator: + if acc is None: + acc = obj + else: + acc = f(obj, acc) + if acc is not None: + yield acc + vals = self.mapPartitions(func).collect() + return reduce(f, vals) + + def fold(self, zeroValue, op): + """ + Aggregate the elements of each partition, and then the results for all + the partitions, using a given associative function and a neutral "zero + value." + + The function C{op(t1, t2)} is allowed to modify C{t1} and return it + as its result value to avoid object allocation; however, it should not + modify C{t2}. + + >>> from operator import add + >>> sc.parallelize([1, 2, 3, 4, 5]).fold(0, add) + 15 + """ + def func(iterator): + acc = zeroValue + for obj in iterator: + acc = op(obj, acc) + yield acc + vals = self.mapPartitions(func).collect() + return reduce(op, vals, zeroValue) + + # TODO: aggregate + + def sum(self): + """ + Add up the elements in this RDD. + + >>> sc.parallelize([1.0, 2.0, 3.0]).sum() + 6.0 + """ + return self.mapPartitions(lambda x: [sum(x)]).reduce(operator.add) + + def count(self): + """ + Return the number of elements in this RDD. + + >>> sc.parallelize([2, 3, 4]).count() + 3 + """ + return self.mapPartitions(lambda i: [sum(1 for _ in i)]).sum() + + def countByValue(self): + """ + Return the count of each unique value in this RDD as a dictionary of + (value, count) pairs. 
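Internally, each partition is counted into its own dictionary and the
per-partition dictionaries are then merged into the final result.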
+ + >>> sorted(sc.parallelize([1, 2, 1, 2, 2], 2).countByValue().items()) + [(1, 2), (2, 3)] + """ + def countPartition(iterator): + counts = defaultdict(int) + for obj in iterator: + counts[obj] += 1 + yield counts + def mergeMaps(m1, m2): + for (k, v) in m2.iteritems(): + m1[k] += v + return m1 + return self.mapPartitions(countPartition).reduce(mergeMaps) + + def take(self, num): + """ + Take the first num elements of the RDD. + + This currently scans the partitions *one by one*, so it will be slow if + a lot of partitions are required. In that case, use L{collect} to get + the whole RDD instead. + + >>> sc.parallelize([2, 3, 4, 5, 6]).take(2) + [2, 3] + >>> sc.parallelize([2, 3, 4, 5, 6]).take(10) + [2, 3, 4, 5, 6] + """ + items = [] + splits = self._jrdd.splits() + taskContext = self.ctx.jvm.spark.TaskContext(0, 0, 0) + while len(items) < num and splits: + split = splits.pop(0) + iterator = self._jrdd.iterator(split, taskContext) + items.extend(self._collect_iterator_through_file(iterator)) + return items[:num] + + def first(self): + """ + Return the first element in this RDD. + + >>> sc.parallelize([2, 3, 4]).first() + 2 + """ + return self.take(1)[0] + + def saveAsTextFile(self, path): + """ + Save this RDD as a text file, using string representations of elements. + + >>> tempFile = NamedTemporaryFile(delete=True) + >>> tempFile.close() + >>> sc.parallelize(range(10)).saveAsTextFile(tempFile.name) + >>> from fileinput import input + >>> from glob import glob + >>> ''.join(input(glob(tempFile.name + "/part-0000*"))) + '0\\n1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n' + """ + def func(iterator): + return (str(x).encode("utf-8") for x in iterator) + keyed = PipelinedRDD(self, func) + keyed._bypass_serializer = True + keyed._jrdd.map(self.ctx.jvm.BytesToString()).saveAsTextFile(path) + + # Pair functions + + def collectAsMap(self): + """ + Return the key-value pairs in this RDD to the master as a dictionary. + + >>> m = sc.parallelize([(1, 2), (3, 4)]).collectAsMap() + >>> m[1] + 2 + >>> m[3] + 4 + """ + return dict(self.collect()) + + def reduceByKey(self, func, numSplits=None): + """ + Merge the values for each key using an associative reduce function. + + This will also perform the merging locally on each mapper before + sending results to a reducer, similarly to a "combiner" in MapReduce. + + Output will be hash-partitioned with C{numSplits} splits, or the + default parallelism level if C{numSplits} is not specified. + + >>> from operator import add + >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) + >>> sorted(rdd.reduceByKey(add).collect()) + [('a', 2), ('b', 1)] + """ + return self.combineByKey(lambda x: x, func, func, numSplits) + + def reduceByKeyLocally(self, func): + """ + Merge the values for each key using an associative reduce function, but + return the results immediately to the master as a dictionary. + + This will also perform the merging locally on each mapper before + sending results to a reducer, similarly to a "combiner" in MapReduce. 
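+
+        Each partition is first reduced to a local dictionary, and the
+        per-partition dictionaries are then merged into a single result
+        with C{func}.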
+ + >>> from operator import add + >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) + >>> sorted(rdd.reduceByKeyLocally(add).items()) + [('a', 2), ('b', 1)] + """ + def reducePartition(iterator): + m = {} + for (k, v) in iterator: + m[k] = v if k not in m else func(m[k], v) + yield m + def mergeMaps(m1, m2): + for (k, v) in m2.iteritems(): + m1[k] = v if k not in m1 else func(m1[k], v) + return m1 + return self.mapPartitions(reducePartition).reduce(mergeMaps) + + def countByKey(self): + """ + Count the number of elements for each key, and return the result to the + master as a dictionary. + + >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) + >>> sorted(rdd.countByKey().items()) + [('a', 2), ('b', 1)] + """ + return self.map(lambda x: x[0]).countByValue() + + def join(self, other, numSplits=None): + """ + Return an RDD containing all pairs of elements with matching keys in + C{self} and C{other}. + + Each pair of elements will be returned as a (k, (v1, v2)) tuple, where + (k, v1) is in C{self} and (k, v2) is in C{other}. + + Performs a hash join across the cluster. + + >>> x = sc.parallelize([("a", 1), ("b", 4)]) + >>> y = sc.parallelize([("a", 2), ("a", 3)]) + >>> sorted(x.join(y).collect()) + [('a', (1, 2)), ('a', (1, 3))] + """ + return python_join(self, other, numSplits) + + def leftOuterJoin(self, other, numSplits=None): + """ + Perform a left outer join of C{self} and C{other}. + + For each element (k, v) in C{self}, the resulting RDD will either + contain all pairs (k, (v, w)) for w in C{other}, or the pair + (k, (v, None)) if no elements in other have key k. + + Hash-partitions the resulting RDD into the given number of partitions. + + >>> x = sc.parallelize([("a", 1), ("b", 4)]) + >>> y = sc.parallelize([("a", 2)]) + >>> sorted(x.leftOuterJoin(y).collect()) + [('a', (1, 2)), ('b', (4, None))] + """ + return python_left_outer_join(self, other, numSplits) + + def rightOuterJoin(self, other, numSplits=None): + """ + Perform a right outer join of C{self} and C{other}. + + For each element (k, w) in C{other}, the resulting RDD will either + contain all pairs (k, (v, w)) for v in this, or the pair (k, (None, w)) + if no elements in C{self} have key k. + + Hash-partitions the resulting RDD into the given number of partitions. + + >>> x = sc.parallelize([("a", 1), ("b", 4)]) + >>> y = sc.parallelize([("a", 2)]) + >>> sorted(y.rightOuterJoin(x).collect()) + [('a', (2, 1)), ('b', (None, 4))] + """ + return python_right_outer_join(self, other, numSplits) + + # TODO: add option to control map-side combining + def partitionBy(self, numSplits, hashFunc=hash): + """ + Return a copy of the RDD partitioned using the specified partitioner. + + >>> pairs = sc.parallelize([1, 2, 3, 4, 2, 4, 1]).map(lambda x: (x, x)) + >>> sets = pairs.partitionBy(2).glom().collect() + >>> set(sets[0]).intersection(set(sets[1])) + set([]) + """ + if numSplits is None: + numSplits = self.ctx.defaultParallelism + # Transferring O(n) objects to Java is too expensive. Instead, we'll + # form the hash buckets in Python, transferring O(numSplits) objects + # to Java. Each object is a (splitNumber, [objects]) pair. 
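+        # Each bucket is written out as two consecutive elements: the split
+        # number as a string, followed by a pickled Batch of that bucket's
+        # (key, value) pairs.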
+        def add_shuffle_key(iterator):
+            buckets = defaultdict(list)
+            for (k, v) in iterator:
+                buckets[hashFunc(k) % numSplits].append((k, v))
+            for (split, items) in buckets.iteritems():
+                yield str(split)
+                yield dump_pickle(Batch(items))
+        keyed = PipelinedRDD(self, add_shuffle_key)
+        keyed._bypass_serializer = True
+        pairRDD = self.ctx.jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
+        partitioner = self.ctx.jvm.spark.api.python.PythonPartitioner(numSplits)
+        jrdd = pairRDD.partitionBy(partitioner)
+        jrdd = jrdd.map(self.ctx.jvm.ExtractValue())
+        return RDD(jrdd, self.ctx)
+
+    # TODO: add control over map-side aggregation
+    def combineByKey(self, createCombiner, mergeValue, mergeCombiners,
+                     numSplits=None):
+        """
+        Generic function to combine the elements for each key using a custom
+        set of aggregation functions.
+
+        Turns an RDD[(K, V)] into a result of type RDD[(K, C)], for a "combined
+        type" C. Note that V and C can be different -- for example, one might
+        group an RDD of type (Int, Int) into an RDD of type (Int, List[Int]).
+
+        Users provide three functions:
+
+            - C{createCombiner}, which turns a V into a C (e.g., creates
+              a one-element list)
+            - C{mergeValue}, to merge a V into a C (e.g., adds it to the end of
+              a list)
+            - C{mergeCombiners}, to combine two C's into a single one.
+
+        In addition, users can control the partitioning of the output RDD.
+
+        >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
+        >>> def f(x): return x
+        >>> def add(a, b): return a + str(b)
+        >>> sorted(x.combineByKey(str, add, add).collect())
+        [('a', '11'), ('b', '1')]
+        """
+        if numSplits is None:
+            numSplits = self.ctx.defaultParallelism
+        def combineLocally(iterator):
+            combiners = {}
+            for (k, v) in iterator:
+                if k not in combiners:
+                    combiners[k] = createCombiner(v)
+                else:
+                    combiners[k] = mergeValue(combiners[k], v)
+            return combiners.iteritems()
+        locally_combined = self.mapPartitions(combineLocally)
+        shuffled = locally_combined.partitionBy(numSplits)
+        def _mergeCombiners(iterator):
+            combiners = {}
+            for (k, v) in iterator:
+                if k not in combiners:
+                    combiners[k] = v
+                else:
+                    combiners[k] = mergeCombiners(combiners[k], v)
+            return combiners.iteritems()
+        return shuffled.mapPartitions(_mergeCombiners)
+
+    # TODO: support variant with custom partitioner
+    def groupByKey(self, numSplits=None):
+        """
+        Group the values for each key in the RDD into a single sequence.
+        Hash-partitions the resulting RDD into numSplits partitions.
+
+        >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
+        >>> sorted(x.groupByKey().collect())
+        [('a', [1, 1]), ('b', [1])]
+        """
+
+        def createCombiner(x):
+            return [x]
+
+        def mergeValue(xs, x):
+            xs.append(x)
+            return xs
+
+        def mergeCombiners(a, b):
+            return a + b
+
+        return self.combineByKey(createCombiner, mergeValue, mergeCombiners,
+                                 numSplits)
+
+    # TODO: add tests
+    def flatMapValues(self, f):
+        """
+        Pass each value in the key-value pair RDD through a flatMap function
+        without changing the keys; this also retains the original RDD's
+        partitioning.
+        """
+        flat_map_fn = lambda (k, v): ((k, x) for x in f(v))
+        return self.flatMap(flat_map_fn, preservesPartitioning=True)
+
+    def mapValues(self, f):
+        """
+        Pass each value in the key-value pair RDD through a map function
+        without changing the keys; this also retains the original RDD's
+        partitioning.
+        """
+        map_values_fn = lambda (k, v): (k, f(v))
+        return self.map(map_values_fn, preservesPartitioning=True)
+
+    # TODO: support varargs cogroup of several RDDs.
+    def groupWith(self, other):
+        """
+        Alias for cogroup.
+        """
+        return self.cogroup(other)
+
+    # TODO: add variant with custom partitioner
+    def cogroup(self, other, numSplits=None):
+        """
+        For each key k in C{self} or C{other}, return a resulting RDD that
+        contains a tuple with the list of values for that key in C{self} as well
+        as C{other}.
+
+        >>> x = sc.parallelize([("a", 1), ("b", 4)])
+        >>> y = sc.parallelize([("a", 2)])
+        >>> sorted(x.cogroup(y).collect())
+        [('a', ([1], [2])), ('b', ([4], []))]
+        """
+        return python_cogroup(self, other, numSplits)
+
+    # TODO: `lookup` is disabled because we can't make direct comparisons based
+    # on the key; we need to compare the hash of the key to the hash of the
+    # keys in the pairs. This could be an expensive operation, since those
+    # hashes aren't retained.
+
+
+class PipelinedRDD(RDD):
+    """
+    Pipelined maps:
+    >>> rdd = sc.parallelize([1, 2, 3, 4])
+    >>> rdd.map(lambda x: 2 * x).cache().map(lambda x: 2 * x).collect()
+    [4, 8, 12, 16]
+    >>> rdd.map(lambda x: 2 * x).map(lambda x: 2 * x).collect()
+    [4, 8, 12, 16]
+
+    Pipelined reduces:
+    >>> from operator import add
+    >>> rdd.map(lambda x: 2 * x).reduce(add)
+    20
+    >>> rdd.flatMap(lambda x: [x, x]).reduce(add)
+    20
+    """
+    def __init__(self, prev, func, preservesPartitioning=False):
+        if isinstance(prev, PipelinedRDD) and not prev.is_cached:
+            prev_func = prev.func
+            def pipeline_func(iterator):
+                return func(prev_func(iterator))
+            self.func = pipeline_func
+            self.preservesPartitioning = \
+                prev.preservesPartitioning and preservesPartitioning
+            self._prev_jrdd = prev._prev_jrdd
+        else:
+            self.func = func
+            self.preservesPartitioning = preservesPartitioning
+            self._prev_jrdd = prev._jrdd
+        self.is_cached = False
+        self.ctx = prev.ctx
+        self.prev = prev
+        self._jrdd_val = None
+        self._bypass_serializer = False
+
+    @property
+    def _jrdd(self):
+        if self._jrdd_val:
+            return self._jrdd_val
+        func = self.func
+        if not self._bypass_serializer and self.ctx.batchSize != 1:
+            oldfunc = self.func
+            batchSize = self.ctx.batchSize
+            def batched_func(iterator):
+                return batched(oldfunc(iterator), batchSize)
+            func = batched_func
+        cmds = [func, self._bypass_serializer]
+        pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds)
+        broadcast_vars = ListConverter().convert(
+            [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
+            self.ctx.gateway._gateway_client)
+        self.ctx._pickled_broadcast_vars.clear()
+        class_manifest = self._prev_jrdd.classManifest()
+        env = copy.copy(self.ctx.environment)
+        env['PYTHONPATH'] = os.environ.get("PYTHONPATH", "")
+        env = MapConverter().convert(env, self.ctx.gateway._gateway_client)
+        python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(),
+            pipe_command, env, self.preservesPartitioning, self.ctx.pythonExec,
+            broadcast_vars, class_manifest)
+        self._jrdd_val = python_rdd.asJavaRDD()
+        return self._jrdd_val


+def _test():
+    import doctest
+    from pyspark.context import SparkContext
+    globs = globals().copy()
+    # The small batch size here ensures that we see multiple batches,
+    # even in these small test examples:
+    globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
+    doctest.testmod(globs=globs)
+    globs['sc'].stop()


+if __name__ == "__main__":
+    _test()
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
new file mode 100644
index 0000000000..9a5151ea00
--- /dev/null
+++ b/python/pyspark/serializers.py
@@ -0,0 +1,78 @@
+import struct
+import cPickle
+
+
+class Batch(object):
+    """
+    Used to store multiple RDD entries as a single Java object.
+
+    This relieves us from having to explicitly track whether an RDD
+    is stored as batches of objects and avoids problems when processing
+    the union() of batched and unbatched RDDs (e.g. the union() of textFile()
+    with another RDD).
+    """
+    def __init__(self, items):
+        self.items = items
+
+
+def batched(iterator, batchSize):
+    if batchSize == -1: # unlimited batch size
+        yield Batch(list(iterator))
+    else:
+        items = []
+        count = 0
+        for item in iterator:
+            items.append(item)
+            count += 1
+            if count == batchSize:
+                yield Batch(items)
+                items = []
+                count = 0
+        if items:
+            yield Batch(items)
+
+
+def dump_pickle(obj):
+    return cPickle.dumps(obj, 2)
+
+
+load_pickle = cPickle.loads
+
+
+def read_long(stream):
+    length = stream.read(8)
+    if length == "":
+        raise EOFError
+    return struct.unpack("!q", length)[0]
+
+
+def read_int(stream):
+    length = stream.read(4)
+    if length == "":
+        raise EOFError
+    return struct.unpack("!i", length)[0]
+
+def write_with_length(obj, stream):
+    stream.write(struct.pack("!i", len(obj)))
+    stream.write(obj)
+
+
+def read_with_length(stream):
+    length = read_int(stream)
+    obj = stream.read(length)
+    if obj == "":
+        raise EOFError
+    return obj
+
+
+def read_from_pickle_file(stream):
+    try:
+        while True:
+            obj = load_pickle(read_with_length(stream))
+            if type(obj) == Batch: # We don't care about inheritance
+                for item in obj.items:
+                    yield item
+            else:
+                yield obj
+    except EOFError:
+        return
diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py
new file mode 100644
index 0000000000..bd39b0283f
--- /dev/null
+++ b/python/pyspark/shell.py
@@ -0,0 +1,33 @@
+"""
+An interactive shell.
+"""
+import optparse # I prefer argparse, but it's not included with Python < 2.7
+import code
+import sys
+
+from pyspark.context import SparkContext
+
+
+def main(master='local', ipython=False):
+    sc = SparkContext(master, 'PySparkShell')
+    user_ns = {'sc': sc}
+    banner = "Spark context available as sc."
+    if ipython:
+        import IPython
+        IPython.embed(user_ns=user_ns, banner2=banner)
+    else:
+        print banner
+        code.interact(local=user_ns)
+
+
+if __name__ == '__main__':
+    usage = "usage: %prog [options] master"
+    parser = optparse.OptionParser(usage=usage)
+    parser.add_option("-i", "--ipython", help="Run IPython shell",
+                      action="store_true")
+    (options, args) = parser.parse_args()
+    if args:
+        master = args[0]
+    else:
+        master = 'local'
+    main(master, options.ipython)
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
new file mode 100644
index 0000000000..9f6b507dbd
--- /dev/null
+++ b/python/pyspark/worker.py
@@ -0,0 +1,40 @@
+"""
+Worker that receives input from Piped RDD.
+"""
+import sys
+from base64 import standard_b64decode
+# CloudPickler needs to be imported so that depicklers are registered using the
+# copy_reg module.
+from pyspark.broadcast import Broadcast, _broadcastRegistry
+from pyspark.cloudpickle import CloudPickler
+from pyspark.serializers import write_with_length, read_with_length, \
+    read_long, read_int, dump_pickle, load_pickle, read_from_pickle_file
+
+
+# Redirect stdout to stderr so that users must return values from functions.
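+# Keep a handle on the real stdout so that serialized results can still be
+# written back through it (see main() below).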
+old_stdout = sys.stdout +sys.stdout = sys.stderr + + +def load_obj(): + return load_pickle(standard_b64decode(sys.stdin.readline().strip())) + + +def main(): + num_broadcast_variables = read_int(sys.stdin) + for _ in range(num_broadcast_variables): + bid = read_long(sys.stdin) + value = read_with_length(sys.stdin) + _broadcastRegistry[bid] = Broadcast(bid, load_pickle(value)) + func = load_obj() + bypassSerializer = load_obj() + if bypassSerializer: + dumps = lambda x: x + else: + dumps = dump_pickle + for obj in func(read_from_pickle_file(sys.stdin)): + write_with_length(dumps(obj), old_stdout) + + +if __name__ == '__main__': + main() -- cgit v1.2.3