author     Josh Rosen <joshrosen@eecs.berkeley.edu>   2013-01-01 14:48:45 -0800
committer  Josh Rosen <joshrosen@eecs.berkeley.edu>   2013-01-01 15:05:00 -0800
commit     b58340dbd9a741331fc4c3829b08c093560056c2 (patch)
tree       52b0e94c47892a8f884b2f80a59ccdb1a428b389 /python
parent     170e451fbdd308ae77065bd9c0f2bd278abf0cb7 (diff)
Rename top-level 'pyspark' directory to 'python'
Diffstat (limited to 'python')
-rw-r--r--  python/.gitignore                          2
-rw-r--r--  python/epydoc.conf                        19
-rw-r--r--  python/examples/kmeans.py                 52
-rwxr-xr-x  python/examples/logistic_regression.py    57
-rw-r--r--  python/examples/pi.py                     21
-rw-r--r--  python/examples/transitive_closure.py     50
-rw-r--r--  python/examples/wordcount.py              19
-rw-r--r--  python/lib/PY4J_LICENSE.txt               27
-rw-r--r--  python/lib/PY4J_VERSION.txt                1
-rw-r--r--  python/lib/py4j0.7.egg                   bin 0 -> 191756 bytes
-rw-r--r--  python/lib/py4j0.7.jar                   bin 0 -> 103286 bytes
-rw-r--r--  python/pyspark/__init__.py                20
-rw-r--r--  python/pyspark/broadcast.py               48
-rw-r--r--  python/pyspark/cloudpickle.py            974
-rw-r--r--  python/pyspark/context.py                158
-rw-r--r--  python/pyspark/java_gateway.py            38
-rw-r--r--  python/pyspark/join.py                    92
-rw-r--r--  python/pyspark/rdd.py                    713
-rw-r--r--  python/pyspark/serializers.py             78
-rw-r--r--  python/pyspark/shell.py                   33
-rw-r--r--  python/pyspark/worker.py                  40
21 files changed, 2442 insertions, 0 deletions
diff --git a/python/.gitignore b/python/.gitignore
new file mode 100644
index 0000000000..5c56e638f9
--- /dev/null
+++ b/python/.gitignore
@@ -0,0 +1,2 @@
+*.pyc
+docs/
diff --git a/python/epydoc.conf b/python/epydoc.conf
new file mode 100644
index 0000000000..91ac984ba2
--- /dev/null
+++ b/python/epydoc.conf
@@ -0,0 +1,19 @@
+[epydoc] # Epydoc section marker (required by ConfigParser)
+
+# Information about the project.
+name: PySpark
+url: http://spark-project.org
+
+# The list of modules to document. Modules can be named using
+# dotted names, module filenames, or package directory names.
+# This option may be repeated.
+modules: pyspark
+
+# Write html output to the directory "docs"
+output: html
+target: docs/
+
+private: no
+
+exclude: pyspark.cloudpickle pyspark.worker pyspark.join pyspark.serializers
+ pyspark.java_gateway pyspark.examples pyspark.shell
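
Not part of the commit: a config like this is normally fed to the epydoc tool. A minimal Python sketch of driving it is below; the `--config` flag is assumed from epydoc's documented command line, and the helper name is illustrative.

    # Hypothetical helper: build the PySpark API docs from python/epydoc.conf.
    # Assumes the `epydoc` executable is installed and supports --config (epydoc 3.x).
    import subprocess

    def build_api_docs(conf="epydoc.conf"):
        # Output lands in docs/ because of the `target: docs/` setting above.
        subprocess.check_call(["epydoc", "--config", conf])

    if __name__ == "__main__":
        build_api_docs()
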
diff --git a/python/examples/kmeans.py b/python/examples/kmeans.py
new file mode 100644
index 0000000000..ad2be21178
--- /dev/null
+++ b/python/examples/kmeans.py
@@ -0,0 +1,52 @@
+"""
+This example requires numpy (http://www.numpy.org/)
+"""
+import sys
+
+import numpy as np
+from pyspark import SparkContext
+
+
+def parseVector(line):
+ return np.array([float(x) for x in line.split(' ')])
+
+
+def closestPoint(p, centers):
+ bestIndex = 0
+ closest = float("+inf")
+ for i in range(len(centers)):
+ tempDist = np.sum((p - centers[i]) ** 2)
+ if tempDist < closest:
+ closest = tempDist
+ bestIndex = i
+ return bestIndex
+
+
+if __name__ == "__main__":
+ if len(sys.argv) < 5:
+ print >> sys.stderr, \
+ "Usage: PythonKMeans <master> <file> <k> <convergeDist>"
+ exit(-1)
+ sc = SparkContext(sys.argv[1], "PythonKMeans")
+ lines = sc.textFile(sys.argv[2])
+ data = lines.map(parseVector).cache()
+ K = int(sys.argv[3])
+ convergeDist = float(sys.argv[4])
+
+ kPoints = data.takeSample(False, K, 34)
+ tempDist = 1.0
+
+ while tempDist > convergeDist:
+ closest = data.map(
+ lambda p : (closestPoint(p, kPoints), (p, 1)))
+ pointStats = closest.reduceByKey(
+ lambda (x1, y1), (x2, y2): (x1 + x2, y1 + y2))
+ newPoints = pointStats.map(
+ lambda (x, (y, z)): (x, y / z)).collect()
+
+ tempDist = sum(np.sum((kPoints[x] - y) ** 2) for (x, y) in newPoints)
+
+ for (x, y) in newPoints:
+ kPoints[x] = y
+
+ print "Final centers: " + str(kPoints)
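
For context beyond the patch: the assign-then-update loop above can be exercised without a cluster. A minimal numpy-only sketch of one k-means iteration, mirroring closestPoint(); the sample points and names are illustrative.

    import numpy as np

    def closest_point(p, centers):
        # Index of the nearest center by squared Euclidean distance,
        # the same rule as closestPoint() in the example above.
        return int(np.argmin([np.sum((p - c) ** 2) for c in centers]))

    points = [np.array([0.0, 0.0]), np.array([0.1, 0.1]),
              np.array([9.0, 9.0]), np.array([9.1, 9.1])]
    centers = [points[0], points[2]]

    # One iteration: assign each point, then move each center to its members' mean.
    assignments = [closest_point(p, centers) for p in points]   # [0, 0, 1, 1]
    for k in range(len(centers)):
        members = [p for p, a in zip(points, assignments) if a == k]
        centers[k] = np.mean(members, axis=0)
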
diff --git a/python/examples/logistic_regression.py b/python/examples/logistic_regression.py
new file mode 100755
index 0000000000..f13698a86f
--- /dev/null
+++ b/python/examples/logistic_regression.py
@@ -0,0 +1,57 @@
+"""
+This example requires numpy (http://www.numpy.org/)
+"""
+from collections import namedtuple
+from math import exp
+from os.path import realpath
+import sys
+
+import numpy as np
+from pyspark import SparkContext
+
+
+N = 100000 # Number of data points
+D = 10 # Number of dimensions
+R = 0.7 # Scaling factor
+ITERATIONS = 5
+np.random.seed(42)
+
+
+DataPoint = namedtuple("DataPoint", ['x', 'y'])
+from logistic_regression import DataPoint # So that DataPoint is properly serialized
+
+
+def generateData():
+ def generatePoint(i):
+ y = -1 if i % 2 == 0 else 1
+ x = np.random.normal(size=D) + (y * R)
+ return DataPoint(x, y)
+ return [generatePoint(i) for i in range(N)]
+
+
+if __name__ == "__main__":
+ if len(sys.argv) == 1:
+ print >> sys.stderr, \
+ "Usage: PythonLR <master> [<slices>]"
+ exit(-1)
+ sc = SparkContext(sys.argv[1], "PythonLR", pyFiles=[realpath(__file__)])
+ slices = int(sys.argv[2]) if len(sys.argv) > 2 else 2
+ points = sc.parallelize(generateData(), slices).cache()
+
+ # Initialize w to a random value
+ w = 2 * np.random.ranf(size=D) - 1
+ print "Initial w: " + str(w)
+
+ def add(x, y):
+ x += y
+ return x
+
+ for i in range(1, ITERATIONS + 1):
+ print "On iteration %i" % i
+
+ gradient = points.map(lambda p:
+ (1.0 / (1.0 + exp(-p.y * np.dot(w, p.x)))) * p.y * p.x
+ ).reduce(add)
+ w -= gradient
+
+ print "Final w: " + str(w)
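
Beyond the patch itself: the body of the map() above is the logistic-loss gradient summed over points, and `w -= gradient` is one full-batch update. A numpy-only sketch of a single step under the same data layout; the two sample points are illustrative.

    from math import exp
    import numpy as np

    D = 10                                   # dimensions, as in the example
    np.random.seed(42)
    # Two labeled points in the DataPoint(x, y) layout used above.
    points = [(np.random.normal(size=D) + 0.7,  1),
              (np.random.normal(size=D) - 0.7, -1)]
    w = 2 * np.random.ranf(size=D) - 1       # random initial weights

    # One gradient step, mirroring the map(...).reduce(add) body.
    gradient = sum((1.0 / (1.0 + exp(-y * np.dot(w, x)))) * y * x
                   for (x, y) in points)
    w -= gradient
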
diff --git a/python/examples/pi.py b/python/examples/pi.py
new file mode 100644
index 0000000000..127cba029b
--- /dev/null
+++ b/python/examples/pi.py
@@ -0,0 +1,21 @@
+import sys
+from random import random
+from operator import add
+
+from pyspark import SparkContext
+
+
+if __name__ == "__main__":
+ if len(sys.argv) == 1:
+ print >> sys.stderr, \
+ "Usage: PythonPi <master> [<slices>]"
+ exit(-1)
+ sc = SparkContext(sys.argv[1], "PythonPi")
+ slices = int(sys.argv[2]) if len(sys.argv) > 2 else 2
+ n = 100000 * slices
+ def f(_):
+ x = random() * 2 - 1
+ y = random() * 2 - 1
+ return 1 if x ** 2 + y ** 2 < 1 else 0
+ count = sc.parallelize(xrange(1, n+1), slices).map(f).reduce(add)
+ print "Pi is roughly %f" % (4.0 * count / n)
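
Beyond the patch itself: the estimator is plain Monte Carlo; the fraction of random points in the square [-1, 1] x [-1, 1] that fall inside the unit circle approaches pi/4. A Spark-free sketch of the same computation (the function name is illustrative).

    from random import random

    def estimate_pi(n=100000):
        # Count darts that land inside the unit circle, as f() does above.
        inside = 0
        for _ in range(n):
            x = random() * 2 - 1
            y = random() * 2 - 1
            if x ** 2 + y ** 2 < 1:
                inside += 1
        return 4.0 * inside / n
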
diff --git a/python/examples/transitive_closure.py b/python/examples/transitive_closure.py
new file mode 100644
index 0000000000..73f7f8fbaf
--- /dev/null
+++ b/python/examples/transitive_closure.py
@@ -0,0 +1,50 @@
+import sys
+from random import Random
+
+from pyspark import SparkContext
+
+numEdges = 200
+numVertices = 100
+rand = Random(42)
+
+
+def generateGraph():
+ edges = set()
+ while len(edges) < numEdges:
+ src = rand.randrange(0, numVertices)
+ dst = rand.randrange(0, numVertices)
+ if src != dst:
+ edges.add((src, dst))
+ return edges
+
+
+if __name__ == "__main__":
+ if len(sys.argv) == 1:
+ print >> sys.stderr, \
+ "Usage: PythonTC <master> [<slices>]"
+ exit(-1)
+ sc = SparkContext(sys.argv[1], "PythonTC")
+ slices = int(sys.argv[2]) if len(sys.argv) > 2 else 2
+ tc = sc.parallelize(generateGraph(), slices).cache()
+
+ # Linear transitive closure: each round grows paths by one edge,
+ # by joining the graph's edges with the already-discovered paths.
+ # e.g. join the path (y, z) from the TC with the edge (x, y) from
+ # the graph to obtain the path (x, z).
+
+ # Because join() joins on keys, the edges are stored in reversed order.
+ edges = tc.map(lambda (x, y): (y, x))
+
+ oldCount = 0L
+ nextCount = tc.count()
+ while True:
+ oldCount = nextCount
+ # Perform the join, obtaining an RDD of (y, (z, x)) pairs,
+ # then project the result to obtain the new (x, z) paths.
+ new_edges = tc.join(edges).map(lambda (_, (a, b)): (b, a))
+ tc = tc.union(new_edges).distinct().cache()
+ nextCount = tc.count()
+ if nextCount == oldCount:
+ break
+
+ print "TC has %i edges" % tc.count()
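
For context beyond the patch: the loop above keeps joining the current path set with the reversed edge set until the count stops growing. The same fixed point can be computed on plain Python sets; a minimal sketch with a tiny illustrative graph.

    def transitive_closure(edges):
        # edges: set of (x, y) pairs; returns every reachable (x, z) pair.
        tc = set(edges)
        while True:
            # Join (x, y) with (y, z) to grow paths by one edge, as the RDD version does.
            new_paths = set((x, z) for (x, y) in tc for (y2, z) in tc if y == y2)
            before = len(tc)
            tc |= new_paths
            if len(tc) == before:
                return tc

    assert transitive_closure({(1, 2), (2, 3)}) == {(1, 2), (2, 3), (1, 3)}
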
diff --git a/python/examples/wordcount.py b/python/examples/wordcount.py
new file mode 100644
index 0000000000..857160624b
--- /dev/null
+++ b/python/examples/wordcount.py
@@ -0,0 +1,19 @@
+import sys
+from operator import add
+
+from pyspark import SparkContext
+
+
+if __name__ == "__main__":
+ if len(sys.argv) < 3:
+ print >> sys.stderr, \
+ "Usage: PythonWordCount <master> <file>"
+ exit(-1)
+ sc = SparkContext(sys.argv[1], "PythonWordCount")
+ lines = sc.textFile(sys.argv[2], 1)
+ counts = lines.flatMap(lambda x: x.split(' ')) \
+ .map(lambda x: (x, 1)) \
+ .reduceByKey(add)
+ output = counts.collect()
+ for (word, count) in output:
+ print "%s : %i" % (word, count)
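
For context beyond the patch: flatMap splits lines into words, map tags each word with a 1, and reduceByKey(add) sums the tags per word. On a single machine the same aggregation is just a dictionary; a small illustrative sketch.

    from collections import defaultdict

    lines = ["to be or not to be"]            # stands in for the text file
    counts = defaultdict(int)
    for line in lines:
        for word in line.split(' '):
            counts[word] += 1                 # plays the role of reduceByKey(add)

    assert counts["to"] == 2 and counts["be"] == 2
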
diff --git a/python/lib/PY4J_LICENSE.txt b/python/lib/PY4J_LICENSE.txt
new file mode 100644
index 0000000000..a70279ca14
--- /dev/null
+++ b/python/lib/PY4J_LICENSE.txt
@@ -0,0 +1,27 @@
+
+Copyright (c) 2009-2011, Barthelemy Dagenais All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+- Redistributions of source code must retain the above copyright notice, this
+list of conditions and the following disclaimer.
+
+- Redistributions in binary form must reproduce the above copyright notice,
+this list of conditions and the following disclaimer in the documentation
+and/or other materials provided with the distribution.
+
+- The name of the author may not be used to endorse or promote products
+derived from this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+POSSIBILITY OF SUCH DAMAGE.
diff --git a/python/lib/PY4J_VERSION.txt b/python/lib/PY4J_VERSION.txt
new file mode 100644
index 0000000000..04a0cd52a8
--- /dev/null
+++ b/python/lib/PY4J_VERSION.txt
@@ -0,0 +1 @@
+b7924aabe9c5e63f0a4d8bbd17019534c7ec014e
diff --git a/python/lib/py4j0.7.egg b/python/lib/py4j0.7.egg
new file mode 100644
index 0000000000..f8a339d8ee
--- /dev/null
+++ b/python/lib/py4j0.7.egg
Binary files differ
diff --git a/python/lib/py4j0.7.jar b/python/lib/py4j0.7.jar
new file mode 100644
index 0000000000..73b7ddb7d1
--- /dev/null
+++ b/python/lib/py4j0.7.jar
Binary files differ
diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py
new file mode 100644
index 0000000000..c595ae0842
--- /dev/null
+++ b/python/pyspark/__init__.py
@@ -0,0 +1,20 @@
+"""
+PySpark is a Python API for Spark.
+
+Public classes:
+
+ - L{SparkContext<pyspark.context.SparkContext>}
+ Main entry point for Spark functionality.
+ - L{RDD<pyspark.rdd.RDD>}
+ A Resilient Distributed Dataset (RDD), the basic abstraction in Spark.
+"""
+import sys
+import os
+sys.path.insert(0, os.path.join(os.environ["SPARK_HOME"], "python/lib/py4j0.7.egg"))
+
+
+from pyspark.context import SparkContext
+from pyspark.rdd import RDD
+
+
+__all__ = ["SparkContext", "RDD"]
diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py
new file mode 100644
index 0000000000..93876fa738
--- /dev/null
+++ b/python/pyspark/broadcast.py
@@ -0,0 +1,48 @@
+"""
+>>> from pyspark.context import SparkContext
+>>> sc = SparkContext('local', 'test')
+>>> b = sc.broadcast([1, 2, 3, 4, 5])
+>>> b.value
+[1, 2, 3, 4, 5]
+
+>>> from pyspark.broadcast import _broadcastRegistry
+>>> _broadcastRegistry[b.bid] = b
+>>> from cPickle import dumps, loads
+>>> loads(dumps(b)).value
+[1, 2, 3, 4, 5]
+
+>>> sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect()
+[1, 2, 3, 4, 5, 1, 2, 3, 4, 5]
+
+>>> large_broadcast = sc.broadcast(list(range(10000)))
+"""
+# Holds broadcasted data received from Java, keyed by its id.
+_broadcastRegistry = {}
+
+
+def _from_id(bid):
+ from pyspark.broadcast import _broadcastRegistry
+ if bid not in _broadcastRegistry:
+ raise Exception("Broadcast variable '%s' not loaded!" % bid)
+ return _broadcastRegistry[bid]
+
+
+class Broadcast(object):
+ def __init__(self, bid, value, java_broadcast=None, pickle_registry=None):
+ self.value = value
+ self.bid = bid
+ self._jbroadcast = java_broadcast
+ self._pickle_registry = pickle_registry
+
+ def __reduce__(self):
+ self._pickle_registry.add(self)
+ return (_from_id, (self.bid, ))
+
+
+def _test():
+ import doctest
+ doctest.testmod()
+
+
+if __name__ == "__main__":
+ _test()
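
For context beyond the patch: __reduce__() makes a Broadcast pickle down to just its id, and _from_id() resolves that id against _broadcastRegistry on the receiving side. That round trip can be checked directly, assuming SPARK_HOME is set so the pyspark package imports; the dummy registry entry and values are illustrative.

    import pickle
    from pyspark.broadcast import Broadcast, _broadcastRegistry

    b = Broadcast(0, [1, 2, 3], pickle_registry=set())
    _broadcastRegistry[0] = b           # what the receiving side would have done

    payload = pickle.dumps(b)           # serializes only (_from_id, (0,))
    restored = pickle.loads(payload)    # looks the value up in the registry
    assert restored.value == [1, 2, 3]
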
diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py
new file mode 100644
index 0000000000..6a7c23a069
--- /dev/null
+++ b/python/pyspark/cloudpickle.py
@@ -0,0 +1,974 @@
+"""
+This class is defined to override standard pickle functionality
+
+The goals of it follow:
+-Serialize lambdas and nested functions to compiled byte code
+-Deal with main module correctly
+-Deal with other non-serializable objects
+
+It does not include an unpickler, as standard python unpickling suffices.
+
+This module was extracted from the `cloud` package, developed by `PiCloud, Inc.
+<http://www.picloud.com>`_.
+
+Copyright (c) 2012, Regents of the University of California.
+Copyright (c) 2009 `PiCloud, Inc. <http://www.picloud.com>`_.
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions
+are met:
+ * Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+ * Neither the name of the University of California, Berkeley nor the
+ names of its contributors may be used to endorse or promote
+ products derived from this software without specific prior written
+ permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
+TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+"""
+
+
+import operator
+import os
+import pickle
+import struct
+import sys
+import types
+from functools import partial
+import itertools
+from copy_reg import _extension_registry, _inverted_registry, _extension_cache
+import new
+import dis
+import traceback
+
+#relevant opcodes
+STORE_GLOBAL = chr(dis.opname.index('STORE_GLOBAL'))
+DELETE_GLOBAL = chr(dis.opname.index('DELETE_GLOBAL'))
+LOAD_GLOBAL = chr(dis.opname.index('LOAD_GLOBAL'))
+GLOBAL_OPS = [STORE_GLOBAL, DELETE_GLOBAL, LOAD_GLOBAL]
+
+HAVE_ARGUMENT = chr(dis.HAVE_ARGUMENT)
+EXTENDED_ARG = chr(dis.EXTENDED_ARG)
+
+import logging
+cloudLog = logging.getLogger("Cloud.Transport")
+
+try:
+ import ctypes
+except (MemoryError, ImportError):
+ logging.warning('Exception raised on importing ctypes. Likely a Python bug; some functionality will be disabled', exc_info = True)
+ ctypes = None
+ PyObject_HEAD = None
+else:
+
+ # for reading internal structures
+ PyObject_HEAD = [
+ ('ob_refcnt', ctypes.c_size_t),
+ ('ob_type', ctypes.c_void_p),
+ ]
+
+
+try:
+ from cStringIO import StringIO
+except ImportError:
+ from StringIO import StringIO
+
+# These helper functions were copied from PiCloud's util module.
+def islambda(func):
+ return getattr(func,'func_name') == '<lambda>'
+
+def xrange_params(xrangeobj):
+ """Returns a 3 element tuple describing the xrange start, step, and len
+ respectively
+
+ Note: Only guarantees that elements of xrange are the same. Parameters may
+ be different.
+ e.g. xrange(1,1) is interpreted as xrange(0,0); both behave the same
+ though w/ iteration
+ """
+
+ xrange_len = len(xrangeobj)
+ if not xrange_len: #empty
+ return (0,1,0)
+ start = xrangeobj[0]
+ if xrange_len == 1: #one element
+ return start, 1, 1
+ return (start, xrangeobj[1] - xrangeobj[0], xrange_len)
+
+#debug variables intended for developer use:
+printSerialization = False
+printMemoization = False
+
+useForcedImports = True #Should I use forced imports for tracking?
+
+
+
+class CloudPickler(pickle.Pickler):
+
+ dispatch = pickle.Pickler.dispatch.copy()
+ savedForceImports = False
+ savedDjangoEnv = False #hack to transport django environment
+
+ def __init__(self, file, protocol=None, min_size_to_save= 0):
+ pickle.Pickler.__init__(self,file,protocol)
+ self.modules = set() #set of modules needed to depickle
+ self.globals_ref = {} # map ids to dictionary. used to ensure that functions can share global env
+
+ def dump(self, obj):
+ # note: not thread safe
+ # minimal side-effects, so not fixing
+ recurse_limit = 3000
+ base_recurse = sys.getrecursionlimit()
+ if base_recurse < recurse_limit:
+ sys.setrecursionlimit(recurse_limit)
+ self.inject_addons()
+ try:
+ return pickle.Pickler.dump(self, obj)
+ except RuntimeError, e:
+ if 'recursion' in e.args[0]:
+ msg = """Could not pickle object as excessively deep recursion required.
+ Try _fast_serialization=2 or contact PiCloud support"""
+ raise pickle.PicklingError(msg)
+ finally:
+ new_recurse = sys.getrecursionlimit()
+ if new_recurse == recurse_limit:
+ sys.setrecursionlimit(base_recurse)
+
+ def save_buffer(self, obj):
+ """Fallback to save_string"""
+ pickle.Pickler.save_string(self,str(obj))
+ dispatch[buffer] = save_buffer
+
+ #block broken objects
+ def save_unsupported(self, obj, pack=None):
+ raise pickle.PicklingError("Cannot pickle objects of type %s" % type(obj))
+ dispatch[types.GeneratorType] = save_unsupported
+
+ #python2.6+ supports slice pickling. some py2.5 extensions might as well. We just test it
+ try:
+ slice(0,1).__reduce__()
+ except TypeError: #can't pickle -
+ dispatch[slice] = save_unsupported
+
+ #itertools objects do not pickle!
+ for v in itertools.__dict__.values():
+ if type(v) is type:
+ dispatch[v] = save_unsupported
+
+
+ def save_dict(self, obj):
+ """hack fix
+ If the dict is a global, deal with it in a special way
+ """
+ #print 'saving', obj
+ if obj is __builtins__:
+ self.save_reduce(_get_module_builtins, (), obj=obj)
+ else:
+ pickle.Pickler.save_dict(self, obj)
+ dispatch[pickle.DictionaryType] = save_dict
+
+
+ def save_module(self, obj, pack=struct.pack):
+ """
+ Save a module as an import
+ """
+ #print 'try save import', obj.__name__
+ self.modules.add(obj)
+ self.save_reduce(subimport,(obj.__name__,), obj=obj)
+ dispatch[types.ModuleType] = save_module #new type
+
+ def save_codeobject(self, obj, pack=struct.pack):
+ """
+ Save a code object
+ """
+ #print 'try to save codeobj: ', obj
+ args = (
+ obj.co_argcount, obj.co_nlocals, obj.co_stacksize, obj.co_flags, obj.co_code,
+ obj.co_consts, obj.co_names, obj.co_varnames, obj.co_filename, obj.co_name,
+ obj.co_firstlineno, obj.co_lnotab, obj.co_freevars, obj.co_cellvars
+ )
+ self.save_reduce(types.CodeType, args, obj=obj)
+ dispatch[types.CodeType] = save_codeobject #new type
+
+ def save_function(self, obj, name=None, pack=struct.pack):
+ """ Registered with the dispatch to handle all function types.
+
+ Determines what kind of function obj is (e.g. lambda, defined at
+ interactive prompt, etc) and handles the pickling appropriately.
+ """
+ write = self.write
+
+ name = obj.__name__
+ modname = pickle.whichmodule(obj, name)
+ #print 'which gives %s %s %s' % (modname, obj, name)
+ try:
+ themodule = sys.modules[modname]
+ except KeyError: # eval'd items such as namedtuple give invalid items for their function __module__
+ modname = '__main__'
+
+ if modname == '__main__':
+ themodule = None
+
+ if themodule:
+ self.modules.add(themodule)
+
+ if not self.savedDjangoEnv:
+ #hack for django - if we detect the settings module, we transport it
+ django_settings = os.environ.get('DJANGO_SETTINGS_MODULE', '')
+ if django_settings:
+ django_mod = sys.modules.get(django_settings)
+ if django_mod:
+ cloudLog.debug('Transporting django settings %s during save of %s', django_mod, name)
+ self.savedDjangoEnv = True
+ self.modules.add(django_mod)
+ write(pickle.MARK)
+ self.save_reduce(django_settings_load, (django_mod.__name__,), obj=django_mod)
+ write(pickle.POP_MARK)
+
+
+ # if func is lambda, def'ed at prompt, is in main, or is nested, then
+ # we'll pickle the actual function object rather than simply saving a
+ # reference (as is done in default pickler), via save_function_tuple.
+ if islambda(obj) or obj.func_code.co_filename == '<stdin>' or themodule == None:
+ #Force server to import modules that have been imported in main
+ modList = None
+ if themodule == None and not self.savedForceImports:
+ mainmod = sys.modules['__main__']
+ if useForcedImports and hasattr(mainmod,'___pyc_forcedImports__'):
+ modList = list(mainmod.___pyc_forcedImports__)
+ self.savedForceImports = True
+ self.save_function_tuple(obj, modList)
+ return
+ else: # func is nested
+ klass = getattr(themodule, name, None)
+ if klass is None or klass is not obj:
+ self.save_function_tuple(obj, [themodule])
+ return
+
+ if obj.__dict__:
+ # essentially save_reduce, but workaround needed to avoid recursion
+ self.save(_restore_attr)
+ write(pickle.MARK + pickle.GLOBAL + modname + '\n' + name + '\n')
+ self.memoize(obj)
+ self.save(obj.__dict__)
+ write(pickle.TUPLE + pickle.REDUCE)
+ else:
+ write(pickle.GLOBAL + modname + '\n' + name + '\n')
+ self.memoize(obj)
+ dispatch[types.FunctionType] = save_function
+
+ def save_function_tuple(self, func, forced_imports):
+ """ Pickles an actual func object.
+
+ A func comprises: code, globals, defaults, closure, and dict. We
+ extract and save these, injecting reducing functions at certain points
+ to recreate the func object. Keep in mind that some of these pieces
+ can contain a ref to the func itself. Thus, a naive save on these
+ pieces could trigger an infinite loop of save's. To get around that,
+ we first create a skeleton func object using just the code (this is
+ safe, since this won't contain a ref to the func), and memoize it as
+ soon as it's created. The other stuff can then be filled in later.
+ """
+ save = self.save
+ write = self.write
+
+ # save the modules (if any)
+ if forced_imports:
+ write(pickle.MARK)
+ save(_modules_to_main)
+ #print 'forced imports are', forced_imports
+
+ forced_names = map(lambda m: m.__name__, forced_imports)
+ save((forced_names,))
+
+ #save((forced_imports,))
+ write(pickle.REDUCE)
+ write(pickle.POP_MARK)
+
+ code, f_globals, defaults, closure, dct, base_globals = self.extract_func_data(func)
+
+ save(_fill_function) # skeleton function updater
+ write(pickle.MARK) # beginning of tuple that _fill_function expects
+
+ # create a skeleton function object and memoize it
+ save(_make_skel_func)
+ save((code, len(closure), base_globals))
+ write(pickle.REDUCE)
+ self.memoize(func)
+
+ # save the rest of the func data needed by _fill_function
+ save(f_globals)
+ save(defaults)
+ save(closure)
+ save(dct)
+ write(pickle.TUPLE)
+ write(pickle.REDUCE) # applies _fill_function on the tuple
+
+ @staticmethod
+ def extract_code_globals(co):
+ """
+ Find all globals names read or written to by codeblock co
+ """
+ code = co.co_code
+ names = co.co_names
+ out_names = set()
+
+ n = len(code)
+ i = 0
+ extended_arg = 0
+ while i < n:
+ op = code[i]
+
+ i = i+1
+ if op >= HAVE_ARGUMENT:
+ oparg = ord(code[i]) + ord(code[i+1])*256 + extended_arg
+ extended_arg = 0
+ i = i+2
+ if op == EXTENDED_ARG:
+ extended_arg = oparg*65536L
+ if op in GLOBAL_OPS:
+ out_names.add(names[oparg])
+ #print 'extracted', out_names, ' from ', names
+ return out_names
+
+ def extract_func_data(self, func):
+ """
+ Turn the function into a tuple of data necessary to recreate it:
+ code, globals, defaults, closure, dict
+ """
+ code = func.func_code
+
+ # extract all global ref's
+ func_global_refs = CloudPickler.extract_code_globals(code)
+ if code.co_consts: # see if nested function have any global refs
+ for const in code.co_consts:
+ if type(const) is types.CodeType and const.co_names:
+ func_global_refs = func_global_refs.union( CloudPickler.extract_code_globals(const))
+ # process all variables referenced by global environment
+ f_globals = {}
+ for var in func_global_refs:
+ #Some names, such as class functions are not global - we don't need them
+ if func.func_globals.has_key(var):
+ f_globals[var] = func.func_globals[var]
+
+ # defaults requires no processing
+ defaults = func.func_defaults
+
+ def get_contents(cell):
+ try:
+ return cell.cell_contents
+ except ValueError, e: #cell is empty error on not yet assigned
+ raise pickle.PicklingError('Function to be pickled has free variables that are referenced before assignment in enclosing scope')
+
+
+ # process closure
+ if func.func_closure:
+ closure = map(get_contents, func.func_closure)
+ else:
+ closure = []
+
+ # save the dict
+ dct = func.func_dict
+
+ if printSerialization:
+ outvars = ['code: ' + str(code) ]
+ outvars.append('globals: ' + str(f_globals))
+ outvars.append('defaults: ' + str(defaults))
+ outvars.append('closure: ' + str(closure))
+ print 'function ', func, 'is extracted to: ', ', '.join(outvars)
+
+ base_globals = self.globals_ref.get(id(func.func_globals), {})
+ self.globals_ref[id(func.func_globals)] = base_globals
+
+ return (code, f_globals, defaults, closure, dct, base_globals)
+
+ def save_global(self, obj, name=None, pack=struct.pack):
+ write = self.write
+ memo = self.memo
+
+ if name is None:
+ name = obj.__name__
+
+ modname = getattr(obj, "__module__", None)
+ if modname is None:
+ modname = pickle.whichmodule(obj, name)
+
+ try:
+ __import__(modname)
+ themodule = sys.modules[modname]
+ except (ImportError, KeyError, AttributeError): #should never occur
+ raise pickle.PicklingError(
+ "Can't pickle %r: Module %s cannot be found" %
+ (obj, modname))
+
+ if modname == '__main__':
+ themodule = None
+
+ if themodule:
+ self.modules.add(themodule)
+
+ sendRef = True
+ typ = type(obj)
+ #print 'saving', obj, typ
+ try:
+ try: #Deal with case when getattribute fails with exceptions
+ klass = getattr(themodule, name)
+ except (AttributeError):
+ if modname == '__builtin__': #new.* are misreported
+ modname = 'new'
+ __import__(modname)
+ themodule = sys.modules[modname]
+ try:
+ klass = getattr(themodule, name)
+ except AttributeError, a:
+ #print themodule, name, obj, type(obj)
+ raise pickle.PicklingError("Can't pickle builtin %s" % obj)
+ else:
+ raise
+
+ except (ImportError, KeyError, AttributeError):
+ if typ == types.TypeType or typ == types.ClassType:
+ sendRef = False
+ else: #we can't deal with this
+ raise
+ else:
+ if klass is not obj and (typ == types.TypeType or typ == types.ClassType):
+ sendRef = False
+ if not sendRef:
+ #note: Third party types might crash this - add better checks!
+ d = dict(obj.__dict__) #copy dict proxy to a dict
+ if not isinstance(d.get('__dict__', None), property): # don't extract dict that are properties
+ d.pop('__dict__',None)
+ d.pop('__weakref__',None)
+
+ # hack as __new__ is stored differently in the __dict__
+ new_override = d.get('__new__', None)
+ if new_override:
+ d['__new__'] = obj.__new__
+
+ self.save_reduce(type(obj),(obj.__name__,obj.__bases__,
+ d),obj=obj)
+ #print 'internal reduce dask %s %s' % (obj, d)
+ return
+
+ if self.proto >= 2:
+ code = _extension_registry.get((modname, name))
+ if code:
+ assert code > 0
+ if code <= 0xff:
+ write(pickle.EXT1 + chr(code))
+ elif code <= 0xffff:
+ write("%c%c%c" % (pickle.EXT2, code&0xff, code>>8))
+ else:
+ write(pickle.EXT4 + pack("<i", code))
+ return
+
+ write(pickle.GLOBAL + modname + '\n' + name + '\n')
+ self.memoize(obj)
+ dispatch[types.ClassType] = save_global
+ dispatch[types.BuiltinFunctionType] = save_global
+ dispatch[types.TypeType] = save_global
+
+ def save_instancemethod(self, obj):
+ #Memoization is rarely useful here due to Python's method binding
+ self.save_reduce(types.MethodType, (obj.im_func, obj.im_self,obj.im_class), obj=obj)
+ dispatch[types.MethodType] = save_instancemethod
+
+ def save_inst_logic(self, obj):
+ """Inner logic to save instance. Based off pickle.save_inst
+ Supports __transient__"""
+ cls = obj.__class__
+
+ memo = self.memo
+ write = self.write
+ save = self.save
+
+ if hasattr(obj, '__getinitargs__'):
+ args = obj.__getinitargs__()
+ len(args) # XXX Assert it's a sequence
+ pickle._keep_alive(args, memo)
+ else:
+ args = ()
+
+ write(pickle.MARK)
+
+ if self.bin:
+ save(cls)
+ for arg in args:
+ save(arg)
+ write(pickle.OBJ)
+ else:
+ for arg in args:
+ save(arg)
+ write(pickle.INST + cls.__module__ + '\n' + cls.__name__ + '\n')
+
+ self.memoize(obj)
+
+ try:
+ getstate = obj.__getstate__
+ except AttributeError:
+ stuff = obj.__dict__
+ #remove items if transient
+ if hasattr(obj, '__transient__'):
+ transient = obj.__transient__
+ stuff = stuff.copy()
+ for k in list(stuff.keys()):
+ if k in transient:
+ del stuff[k]
+ else:
+ stuff = getstate()
+ pickle._keep_alive(stuff, memo)
+ save(stuff)
+ write(pickle.BUILD)
+
+
+ def save_inst(self, obj):
+ # Hack to detect PIL Image instances without importing Imaging
+ # PIL can be loaded with multiple names, so we don't check sys.modules for it
+ if hasattr(obj,'im') and hasattr(obj,'palette') and 'Image' in obj.__module__:
+ self.save_image(obj)
+ else:
+ self.save_inst_logic(obj)
+ dispatch[types.InstanceType] = save_inst
+
+ def save_property(self, obj):
+ # properties not correctly saved in python
+ self.save_reduce(property, (obj.fget, obj.fset, obj.fdel, obj.__doc__), obj=obj)
+ dispatch[property] = save_property
+
+ def save_itemgetter(self, obj):
+ """itemgetter serializer (needed for namedtuple support)
+ a bit of a pain as we need to read ctypes internals"""
+ class ItemGetterType(ctypes.Structure):
+ _fields_ = PyObject_HEAD + [
+ ('nitems', ctypes.c_size_t),
+ ('item', ctypes.py_object)
+ ]
+
+
+ itemgetter_obj = ctypes.cast(ctypes.c_void_p(id(obj)), ctypes.POINTER(ItemGetterType)).contents
+ return self.save_reduce(operator.itemgetter, (itemgetter_obj.item,))
+
+ if PyObject_HEAD:
+ dispatch[operator.itemgetter] = save_itemgetter
+
+
+
+ def save_reduce(self, func, args, state=None,
+ listitems=None, dictitems=None, obj=None):
+ """Modified to support __transient__ on new objects
+ Change only affects protocol level 2 (which is always used by PiCloud)"""
+ # Assert that args is a tuple or None
+ if not isinstance(args, types.TupleType):
+ raise pickle.PicklingError("args from reduce() should be a tuple")
+
+ # Assert that func is callable
+ if not hasattr(func, '__call__'):
+ raise pickle.PicklingError("func from reduce should be callable")
+
+ save = self.save
+ write = self.write
+
+ # Protocol 2 special case: if func's name is __newobj__, use NEWOBJ
+ if self.proto >= 2 and getattr(func, "__name__", "") == "__newobj__":
+ #Added fix to allow transient
+ cls = args[0]
+ if not hasattr(cls, "__new__"):
+ raise pickle.PicklingError(
+ "args[0] from __newobj__ args has no __new__")
+ if obj is not None and cls is not obj.__class__:
+ raise pickle.PicklingError(
+ "args[0] from __newobj__ args has the wrong class")
+ args = args[1:]
+ save(cls)
+
+ #Don't pickle transient entries
+ if hasattr(obj, '__transient__'):
+ transient = obj.__transient__
+ state = state.copy()
+
+ for k in list(state.keys()):
+ if k in transient:
+ del state[k]
+
+ save(args)
+ write(pickle.NEWOBJ)
+ else:
+ save(func)
+ save(args)
+ write(pickle.REDUCE)
+
+ if obj is not None:
+ self.memoize(obj)
+
+ # More new special cases (that work with older protocols as
+ # well): when __reduce__ returns a tuple with 4 or 5 items,
+ # the 4th and 5th item should be iterators that provide list
+ # items and dict items (as (key, value) tuples), or None.
+
+ if listitems is not None:
+ self._batch_appends(listitems)
+
+ if dictitems is not None:
+ self._batch_setitems(dictitems)
+
+ if state is not None:
+ #print 'obj %s has state %s' % (obj, state)
+ save(state)
+ write(pickle.BUILD)
+
+
+ def save_xrange(self, obj):
+ """Save an xrange object in python 2.5
+ Python 2.6 supports this natively
+ """
+ range_params = xrange_params(obj)
+ self.save_reduce(_build_xrange,range_params)
+
+ #python2.6+ supports xrange pickling. some py2.5 extensions might as well. We just test it
+ try:
+ xrange(0).__reduce__()
+ except TypeError: #can't pickle -- use PiCloud pickler
+ dispatch[xrange] = save_xrange
+
+ def save_partial(self, obj):
+ """Partial objects do not serialize correctly in python2.x -- this fixes the bugs"""
+ self.save_reduce(_genpartial, (obj.func, obj.args, obj.keywords))
+
+ if sys.version_info < (2,7): #2.7 supports partial pickling
+ dispatch[partial] = save_partial
+
+
+ def save_file(self, obj):
+ """Save a file"""
+ import StringIO as pystringIO #we can't use cStringIO as it lacks the name attribute
+ from ..transport.adapter import SerializingAdapter
+
+ if not hasattr(obj, 'name') or not hasattr(obj, 'mode'):
+ raise pickle.PicklingError("Cannot pickle files that do not map to an actual file")
+ if obj.name == '<stdout>':
+ return self.save_reduce(getattr, (sys,'stdout'), obj=obj)
+ if obj.name == '<stderr>':
+ return self.save_reduce(getattr, (sys,'stderr'), obj=obj)
+ if obj.name == '<stdin>':
+ raise pickle.PicklingError("Cannot pickle standard input")
+ if hasattr(obj, 'isatty') and obj.isatty():
+ raise pickle.PicklingError("Cannot pickle files that map to tty objects")
+ if 'r' not in obj.mode:
+ raise pickle.PicklingError("Cannot pickle files that are not opened for reading")
+ name = obj.name
+ try:
+ fsize = os.stat(name).st_size
+ except OSError:
+ raise pickle.PicklingError("Cannot pickle file %s as it cannot be stat" % name)
+
+ if obj.closed:
+ #create an empty closed string io
+ retval = pystringIO.StringIO("")
+ retval.close()
+ elif not fsize: #empty file
+ retval = pystringIO.StringIO("")
+ try:
+ tmpfile = file(name)
+ tst = tmpfile.read(1)
+ except IOError:
+ raise pickle.PicklingError("Cannot pickle file %s as it cannot be read" % name)
+ tmpfile.close()
+ if tst != '':
+ raise pickle.PicklingError("Cannot pickle file %s as it does not appear to map to a physical, real file" % name)
+ elif fsize > SerializingAdapter.max_transmit_data:
+ raise pickle.PicklingError("Cannot pickle file %s as it exceeds cloudconf.py's max_transmit_data of %d" %
+ (name,SerializingAdapter.max_transmit_data))
+ else:
+ try:
+ tmpfile = file(name)
+ contents = tmpfile.read(SerializingAdapter.max_transmit_data)
+ tmpfile.close()
+ except IOError:
+ raise pickle.PicklingError("Cannot pickle file %s as it cannot be read" % name)
+ retval = pystringIO.StringIO(contents)
+ curloc = obj.tell()
+ retval.seek(curloc)
+
+ retval.name = name
+ self.save(retval) #save stringIO
+ self.memoize(obj)
+
+ dispatch[file] = save_file
+ """Special functions for Add-on libraries"""
+
+ def inject_numpy(self):
+ numpy = sys.modules.get('numpy')
+ if not numpy or not hasattr(numpy, 'ufunc'):
+ return
+ self.dispatch[numpy.ufunc] = self.__class__.save_ufunc
+
+ numpy_tst_mods = ['numpy', 'scipy.special']
+ def save_ufunc(self, obj):
+ """Hack function for saving numpy ufunc objects"""
+ name = obj.__name__
+ for tst_mod_name in self.numpy_tst_mods:
+ tst_mod = sys.modules.get(tst_mod_name, None)
+ if tst_mod:
+ if name in tst_mod.__dict__:
+ self.save_reduce(_getobject, (tst_mod_name, name))
+ return
+ raise pickle.PicklingError('cannot save %s. Cannot resolve what module it is defined in' % str(obj))
+
+ def inject_timeseries(self):
+ """Handle bugs with pickling scikits timeseries"""
+ tseries = sys.modules.get('scikits.timeseries.tseries')
+ if not tseries or not hasattr(tseries, 'Timeseries'):
+ return
+ self.dispatch[tseries.Timeseries] = self.__class__.save_timeseries
+
+ def save_timeseries(self, obj):
+ import scikits.timeseries.tseries as ts
+
+ func, reduce_args, state = obj.__reduce__()
+ if func != ts._tsreconstruct:
+ raise pickle.PicklingError('timeseries using unexpected reconstruction function %s' % str(func))
+ state = (1,
+ obj.shape,
+ obj.dtype,
+ obj.flags.fnc,
+ obj._data.tostring(),
+ ts.getmaskarray(obj).tostring(),
+ obj._fill_value,
+ obj._dates.shape,
+ obj._dates.__array__().tostring(),
+ obj._dates.dtype, #added -- preserve type
+ obj.freq,
+ obj._optinfo,
+ )
+ return self.save_reduce(_genTimeSeries, (reduce_args, state))
+
+ def inject_email(self):
+ """Block email LazyImporters from being saved"""
+ email = sys.modules.get('email')
+ if not email:
+ return
+ self.dispatch[email.LazyImporter] = self.__class__.save_unsupported
+
+ def inject_addons(self):
+ """Plug in system. Register additional pickling functions if modules already loaded"""
+ self.inject_numpy()
+ self.inject_timeseries()
+ self.inject_email()
+
+ """Python Imaging Library"""
+ def save_image(self, obj):
+ if not obj.im and obj.fp and 'r' in obj.fp.mode and obj.fp.name \
+ and not obj.fp.closed and (not hasattr(obj, 'isatty') or not obj.isatty()):
+ #if image not loaded yet -- lazy load
+ self.save_reduce(_lazyloadImage,(obj.fp,), obj=obj)
+ else:
+ #image is loaded - just transmit it over
+ self.save_reduce(_generateImage, (obj.size, obj.mode, obj.tostring()), obj=obj)
+
+ """
+ def memoize(self, obj):
+ pickle.Pickler.memoize(self, obj)
+ if printMemoization:
+ print 'memoizing ' + str(obj)
+ """
+
+
+
+# Shorthands for legacy support
+
+def dump(obj, file, protocol=2):
+ CloudPickler(file, protocol).dump(obj)
+
+def dumps(obj, protocol=2):
+ file = StringIO()
+
+ cp = CloudPickler(file,protocol)
+ cp.dump(obj)
+
+ #print 'cloud dumped', str(obj), str(cp.modules)
+
+ return file.getvalue()
+
+
+#hack for __import__ not working as desired
+def subimport(name):
+ __import__(name)
+ return sys.modules[name]
+
+#hack to load django settings:
+def django_settings_load(name):
+ modified_env = False
+
+ if 'DJANGO_SETTINGS_MODULE' not in os.environ:
+ os.environ['DJANGO_SETTINGS_MODULE'] = name # must set name first due to circular deps
+ modified_env = True
+ try:
+ module = subimport(name)
+ except Exception, i:
+ print >> sys.stderr, 'Could not import django settings %s:' % (name)
+ print_exec(sys.stderr)
+ if modified_env:
+ del os.environ['DJANGO_SETTINGS_MODULE']
+ else:
+ #add project directory to sys.path:
+ if hasattr(module,'__file__'):
+ dirname = os.path.split(module.__file__)[0] + '/'
+ sys.path.append(dirname)
+
+# restores function attributes
+def _restore_attr(obj, attr):
+ for key, val in attr.items():
+ setattr(obj, key, val)
+ return obj
+
+def _get_module_builtins():
+ return pickle.__builtins__
+
+def print_exec(stream):
+ ei = sys.exc_info()
+ traceback.print_exception(ei[0], ei[1], ei[2], None, stream)
+
+def _modules_to_main(modList):
+ """Force every module in modList to be placed into main"""
+ if not modList:
+ return
+
+ main = sys.modules['__main__']
+ for modname in modList:
+ if type(modname) is str:
+ try:
+ mod = __import__(modname)
+ except Exception, i: #catch all...
+ sys.stderr.write('warning: could not import %s\n. Your function may unexpectedly error due to this import failing; \
+A version mismatch is likely. Specific error was:\n' % modname)
+ print_exec(sys.stderr)
+ else:
+ setattr(main,mod.__name__, mod)
+ else:
+ #REVERSE COMPATIBILITY FOR CLOUD CLIENT 1.5 (WITH EPD)
+ #In old version actual module was sent
+ setattr(main,modname.__name__, modname)
+
+#object generators:
+def _build_xrange(start, step, len):
+ """Build an xrange object explicitly"""
+ return xrange(start, start + step*len, step)
+
+def _genpartial(func, args, kwds):
+ if not args:
+ args = ()
+ if not kwds:
+ kwds = {}
+ return partial(func, *args, **kwds)
+
+
+def _fill_function(func, globals, defaults, closure, dict):
+ """ Fills in the rest of function data into the skeleton function object
+ that was created via _make_skel_func().
+ """
+ func.func_globals.update(globals)
+ func.func_defaults = defaults
+ func.func_dict = dict
+
+ if len(closure) != len(func.func_closure):
+ raise pickle.UnpicklingError("closure lengths don't match up")
+ for i in range(len(closure)):
+ _change_cell_value(func.func_closure[i], closure[i])
+
+ return func
+
+def _make_skel_func(code, num_closures, base_globals = None):
+ """ Creates a skeleton function object that contains just the provided
+ code and the correct number of cells in func_closure. All other
+ func attributes (e.g. func_globals) are empty.
+ """
+ #build closure (cells):
+ if not ctypes:
+ raise Exception('ctypes failed to import; cannot build function')
+
+ cellnew = ctypes.pythonapi.PyCell_New
+ cellnew.restype = ctypes.py_object
+ cellnew.argtypes = (ctypes.py_object,)
+ dummy_closure = tuple(map(lambda i: cellnew(None), range(num_closures)))
+
+ if base_globals is None:
+ base_globals = {}
+ base_globals['__builtins__'] = __builtins__
+
+ return types.FunctionType(code, base_globals,
+ None, None, dummy_closure)
+
+# this piece of opaque code is needed below to modify 'cell' contents
+cell_changer_code = new.code(
+ 1, 1, 2, 0,
+ ''.join([
+ chr(dis.opmap['LOAD_FAST']), '\x00\x00',
+ chr(dis.opmap['DUP_TOP']),
+ chr(dis.opmap['STORE_DEREF']), '\x00\x00',
+ chr(dis.opmap['RETURN_VALUE'])
+ ]),
+ (), (), ('newval',), '<nowhere>', 'cell_changer', 1, '', ('c',), ()
+)
+
+def _change_cell_value(cell, newval):
+ """ Changes the contents of 'cell' object to newval """
+ return new.function(cell_changer_code, {}, None, (), (cell,))(newval)
+
+"""Constructors for 3rd party libraries
+Note: These can never be renamed due to client compatibility issues"""
+
+def _getobject(modname, attribute):
+ mod = __import__(modname)
+ return mod.__dict__[attribute]
+
+def _generateImage(size, mode, str_rep):
+ """Generate image from string representation"""
+ import Image
+ i = Image.new(mode, size)
+ i.fromstring(str_rep)
+ return i
+
+def _lazyloadImage(fp):
+ import Image
+ fp.seek(0) #works in almost any case
+ return Image.open(fp)
+
+"""Timeseries"""
+def _genTimeSeries(reduce_args, state):
+ import scikits.timeseries.tseries as ts
+ from numpy import ndarray
+ from numpy.ma import MaskedArray
+
+
+ time_series = ts._tsreconstruct(*reduce_args)
+
+ #from setstate modified
+ (ver, shp, typ, isf, raw, msk, flv, dsh, dtm, dtyp, frq, infodict) = state
+ #print 'regenerating %s' % dtyp
+
+ MaskedArray.__setstate__(time_series, (ver, shp, typ, isf, raw, msk, flv))
+ _dates = time_series._dates
+ #_dates.__setstate__((ver, dsh, typ, isf, dtm, frq)) #use remote typ
+ ndarray.__setstate__(_dates,(dsh,dtyp, isf, dtm))
+ _dates.freq = frq
+ _dates._cachedinfo.update(dict(full=None, hasdups=None, steps=None,
+ toobj=None, toord=None, tostr=None))
+ # Update the _optinfo dictionary
+ time_series._optinfo.update(infodict)
+ return time_series
+
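
For context beyond the patch: as the module docstring says, only pickling is customized here and standard unpickling suffices. A minimal round trip for a lambda, which the stock pickler rejects; it assumes SPARK_HOME is set so the pyspark package imports, and the values are illustrative.

    import pickle
    from pyspark import cloudpickle

    add_one = lambda x: x + 1                # not picklable by the stock pickler
    payload = cloudpickle.dumps(add_one)     # goes through CloudPickler.save_function
    restored = pickle.loads(payload)         # plain unpickling is enough
    assert restored(41) == 42
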
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
new file mode 100644
index 0000000000..6172d69dcf
--- /dev/null
+++ b/python/pyspark/context.py
@@ -0,0 +1,158 @@
+import os
+import atexit
+from tempfile import NamedTemporaryFile
+
+from pyspark.broadcast import Broadcast
+from pyspark.java_gateway import launch_gateway
+from pyspark.serializers import dump_pickle, write_with_length, batched
+from pyspark.rdd import RDD
+
+from py4j.java_collections import ListConverter
+
+
+class SparkContext(object):
+ """
+ Main entry point for Spark functionality. A SparkContext represents the
+ connection to a Spark cluster, and can be used to create L{RDD}s and
+ broadcast variables on that cluster.
+ """
+
+ gateway = launch_gateway()
+ jvm = gateway.jvm
+ _readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile
+ _writeIteratorToPickleFile = jvm.PythonRDD.writeIteratorToPickleFile
+
+ def __init__(self, master, jobName, sparkHome=None, pyFiles=None,
+ environment=None, batchSize=1024):
+ """
+ Create a new SparkContext.
+
+ @param master: Cluster URL to connect to
+ (e.g. mesos://host:port, spark://host:port, local[4]).
+ @param jobName: A name for your job, to display on the cluster web UI
+ @param sparkHome: Location where Spark is installed on cluster nodes.
+ @param pyFiles: Collection of .zip or .py files to send to the cluster
+ and add to PYTHONPATH. These can be paths on the local file
+ system or HDFS, HTTP, HTTPS, or FTP URLs.
+ @param environment: A dictionary of environment variables to set on
+ worker nodes.
+ @param batchSize: The number of Python objects represented as a single
+ Java object. Set to 1 to disable batching, or -1 to use an
+ unlimited batch size.
+ """
+ self.master = master
+ self.jobName = jobName
+ self.sparkHome = sparkHome or None # None becomes null in Py4J
+ self.environment = environment or {}
+ self.batchSize = batchSize # -1 represents an unlimited batch size
+
+ # Create the Java SparkContext through Py4J
+ empty_string_array = self.gateway.new_array(self.jvm.String, 0)
+ self._jsc = self.jvm.JavaSparkContext(master, jobName, sparkHome,
+ empty_string_array)
+
+ self.pythonExec = os.environ.get("PYSPARK_PYTHON_EXEC", 'python')
+ # Broadcast's __reduce__ method stores Broadcast instances here.
+ # This allows other code to determine which Broadcast instances have
+ # been pickled, so it can determine which Java broadcast objects to
+ # send.
+ self._pickled_broadcast_vars = set()
+
+ # Deploy any code dependencies specified in the constructor
+ for path in (pyFiles or []):
+ self.addPyFile(path)
+
+ @property
+ def defaultParallelism(self):
+ """
+ Default level of parallelism to use when not given by user (e.g. for
+ reduce tasks)
+ """
+ return self._jsc.sc().defaultParallelism()
+
+ def __del__(self):
+ if self._jsc:
+ self._jsc.stop()
+
+ def stop(self):
+ """
+ Shut down the SparkContext.
+ """
+ self._jsc.stop()
+ self._jsc = None
+
+ def parallelize(self, c, numSlices=None):
+ """
+ Distribute a local Python collection to form an RDD.
+ """
+ numSlices = numSlices or self.defaultParallelism
+ # Calling the Java parallelize() method with an ArrayList is too slow,
+ # because it sends O(n) Py4J commands. As an alternative, serialized
+ # objects are written to a file and loaded through textFile().
+ tempFile = NamedTemporaryFile(delete=False)
+ atexit.register(lambda: os.unlink(tempFile.name))
+ if self.batchSize != 1:
+ c = batched(c, self.batchSize)
+ for x in c:
+ write_with_length(dump_pickle(x), tempFile)
+ tempFile.close()
+ jrdd = self._readRDDFromPickleFile(self._jsc, tempFile.name, numSlices)
+ return RDD(jrdd, self)
+
+ def textFile(self, name, minSplits=None):
+ """
+ Read a text file from HDFS, a local file system (available on all
+ nodes), or any Hadoop-supported file system URI, and return it as an
+ RDD of Strings.
+ """
+ minSplits = minSplits or min(self.defaultParallelism, 2)
+ jrdd = self._jsc.textFile(name, minSplits)
+ return RDD(jrdd, self)
+
+ def union(self, rdds):
+ """
+ Build the union of a list of RDDs.
+ """
+ first = rdds[0]._jrdd
+ rest = [x._jrdd for x in rdds[1:]]
+ rest = ListConverter().convert(rest, self.gateway._gateway_client)
+ return RDD(self._jsc.union(first, rest), self)
+
+ def broadcast(self, value):
+ """
+ Broadcast a read-only variable to the cluster, returning a C{Broadcast}
+ object for reading it in distributed functions. The variable will be
+ sent to each cluster only once.
+ """
+ jbroadcast = self._jsc.broadcast(bytearray(dump_pickle(value)))
+ return Broadcast(jbroadcast.id(), value, jbroadcast,
+ self._pickled_broadcast_vars)
+
+ def addFile(self, path):
+ """
+ Add a file to be downloaded into the working directory of this Spark
+ job on every node. The C{path} passed can be either a local file,
+ a file in HDFS (or other Hadoop-supported filesystems), or an HTTP,
+ HTTPS or FTP URI.
+ """
+ self._jsc.sc().addFile(path)
+
+ def clearFiles(self):
+ """
+ Clear the job's list of files added by L{addFile} or L{addPyFile} so
+ that they do not get downloaded to any new nodes.
+ """
+ # TODO: remove added .py or .zip files from the PYTHONPATH?
+ self._jsc.sc().clearFiles()
+
+ def addPyFile(self, path):
+ """
+ Add a .py or .zip dependency for all tasks to be executed on this
+ SparkContext in the future. The C{path} passed can be either a local
+ file, a file in HDFS (or other Hadoop-supported filesystems), or an
+ HTTP, HTTPS or FTP URI.
+ """
+ self.addFile(path)
+ filename = path.split("/")[-1]
+ os.environ["PYTHONPATH"] = \
+ "%s:%s" % (filename, os.environ["PYTHONPATH"])
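
For context beyond the patch: together with rdd.py below, this is enough for a local end-to-end run. A minimal usage sketch in the style of the doctests elsewhere in the patch; it assumes SPARK_HOME is set so the Py4J gateway can launch, and the master and job name are illustrative.

    from pyspark import SparkContext

    sc = SparkContext("local", "example")
    rdd = sc.parallelize(range(10), 2)       # written to a temp file, then read back
    assert rdd.filter(lambda x: x % 2 == 0).collect() == [0, 2, 4, 6, 8]
    sc.stop()
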
diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
new file mode 100644
index 0000000000..2329e536cc
--- /dev/null
+++ b/python/pyspark/java_gateway.py
@@ -0,0 +1,38 @@
+import os
+import sys
+from subprocess import Popen, PIPE
+from threading import Thread
+from py4j.java_gateway import java_import, JavaGateway, GatewayClient
+
+
+SPARK_HOME = os.environ["SPARK_HOME"]
+
+
+def launch_gateway():
+ # Launch the Py4j gateway using Spark's run command so that we pick up the
+ # proper classpath and SPARK_MEM settings from spark-env.sh
+ command = [os.path.join(SPARK_HOME, "run"), "py4j.GatewayServer",
+ "--die-on-broken-pipe", "0"]
+ proc = Popen(command, stdout=PIPE, stdin=PIPE)
+ # Determine which ephemeral port the server started on:
+ port = int(proc.stdout.readline())
+ # Create a thread to echo output from the GatewayServer, which is required
+ # for Java log output to show up:
+ class EchoOutputThread(Thread):
+ def __init__(self, stream):
+ Thread.__init__(self)
+ self.daemon = True
+ self.stream = stream
+
+ def run(self):
+ while True:
+ line = self.stream.readline()
+ sys.stderr.write(line)
+ EchoOutputThread(proc.stdout).start()
+ # Connect to the gateway
+ gateway = JavaGateway(GatewayClient(port=port), auto_convert=False)
+ # Import the classes used by PySpark
+ java_import(gateway.jvm, "spark.api.java.*")
+ java_import(gateway.jvm, "spark.api.python.*")
+ java_import(gateway.jvm, "scala.Tuple2")
+ return gateway
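
For context beyond the patch: launch_gateway() is what backs SparkContext's class-level `gateway` and `jvm` attributes in context.py above. A minimal sketch of using it directly, assuming SPARK_HOME is set so the `run` script can be found; the master and job name are illustrative.

    from pyspark.java_gateway import launch_gateway

    gateway = launch_gateway()               # starts the JVM-side GatewayServer
    jvm = gateway.jvm                        # spark.api.java.* was java_import'ed above
    jsc = jvm.JavaSparkContext("local", "gateway-demo", None,
                               gateway.new_array(jvm.String, 0))
    jsc.stop()
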
diff --git a/python/pyspark/join.py b/python/pyspark/join.py
new file mode 100644
index 0000000000..7036c47980
--- /dev/null
+++ b/python/pyspark/join.py
@@ -0,0 +1,92 @@
+"""
+Copyright (c) 2011, Douban Inc. <http://www.douban.com/>
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+
+ * Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above
+copyright notice, this list of conditions and the following disclaimer
+in the documentation and/or other materials provided with the
+distribution.
+
+ * Neither the name of the Douban Inc. nor the names of its
+contributors may be used to endorse or promote products derived from
+this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+"""
+
+
+def _do_python_join(rdd, other, numSplits, dispatch):
+ vs = rdd.map(lambda (k, v): (k, (1, v)))
+ ws = other.map(lambda (k, v): (k, (2, v)))
+ return vs.union(ws).groupByKey(numSplits).flatMapValues(dispatch)
+
+
+def python_join(rdd, other, numSplits):
+ def dispatch(seq):
+ vbuf, wbuf = [], []
+ for (n, v) in seq:
+ if n == 1:
+ vbuf.append(v)
+ elif n == 2:
+ wbuf.append(v)
+ return [(v, w) for v in vbuf for w in wbuf]
+ return _do_python_join(rdd, other, numSplits, dispatch)
+
+
+def python_right_outer_join(rdd, other, numSplits):
+ def dispatch(seq):
+ vbuf, wbuf = [], []
+ for (n, v) in seq:
+ if n == 1:
+ vbuf.append(v)
+ elif n == 2:
+ wbuf.append(v)
+ if not vbuf:
+ vbuf.append(None)
+ return [(v, w) for v in vbuf for w in wbuf]
+ return _do_python_join(rdd, other, numSplits, dispatch)
+
+
+def python_left_outer_join(rdd, other, numSplits):
+ def dispatch(seq):
+ vbuf, wbuf = [], []
+ for (n, v) in seq:
+ if n == 1:
+ vbuf.append(v)
+ elif n == 2:
+ wbuf.append(v)
+ if not wbuf:
+ wbuf.append(None)
+ return [(v, w) for v in vbuf for w in wbuf]
+ return _do_python_join(rdd, other, numSplits, dispatch)
+
+
+def python_cogroup(rdd, other, numSplits):
+ vs = rdd.map(lambda (k, v): (k, (1, v)))
+ ws = other.map(lambda (k, v): (k, (2, v)))
+ def dispatch(seq):
+ vbuf, wbuf = [], []
+ for (n, v) in seq:
+ if n == 1:
+ vbuf.append(v)
+ elif n == 2:
+ wbuf.append(v)
+ return (vbuf, wbuf)
+ return vs.union(ws).groupByKey(numSplits).mapValues(dispatch)
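
For context beyond the patch: every variant here tags values from the two RDDs with 1 or 2, groups them by key, and then expands each group with a dispatch function; only the dispatch step differs. That step is plain Python and can be checked in isolation; the sample group below is illustrative.

    def dispatch(seq):
        # Same body as python_join's dispatch: split tagged values, then cross them.
        vbuf, wbuf = [], []
        for (n, v) in seq:
            if n == 1:
                vbuf.append(v)
            elif n == 2:
                wbuf.append(v)
        return [(v, w) for v in vbuf for w in wbuf]

    # One key's grouped values: 'a' from the left RDD, 'x' and 'y' from the right.
    assert dispatch([(1, 'a'), (2, 'x'), (2, 'y')]) == [('a', 'x'), ('a', 'y')]
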
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
new file mode 100644
index 0000000000..cbffb6cc1f
--- /dev/null
+++ b/python/pyspark/rdd.py
@@ -0,0 +1,713 @@
+import atexit
+from base64 import standard_b64encode as b64enc
+import copy
+from collections import defaultdict
+from itertools import chain, ifilter, imap, product
+import operator
+import os
+import shlex
+from subprocess import Popen, PIPE
+from tempfile import NamedTemporaryFile
+from threading import Thread
+
+from pyspark import cloudpickle
+from pyspark.serializers import batched, Batch, dump_pickle, load_pickle, \
+ read_from_pickle_file
+from pyspark.join import python_join, python_left_outer_join, \
+ python_right_outer_join, python_cogroup
+
+from py4j.java_collections import ListConverter, MapConverter
+
+
+__all__ = ["RDD"]
+
+
+class RDD(object):
+ """
+ A Resilient Distributed Dataset (RDD), the basic abstraction in Spark.
+ Represents an immutable, partitioned collection of elements that can be
+ operated on in parallel.
+ """
+
+ def __init__(self, jrdd, ctx):
+ self._jrdd = jrdd
+ self.is_cached = False
+ self.ctx = ctx
+
+ @property
+ def context(self):
+ """
+ The L{SparkContext} that this RDD was created on.
+ """
+ return self.ctx
+
+ def cache(self):
+ """
+ Persist this RDD with the default storage level (C{MEMORY_ONLY}).
+ """
+ self.is_cached = True
+ self._jrdd.cache()
+ return self
+
+ # TODO persist(self, storageLevel)
+
+ def map(self, f, preservesPartitioning=False):
+ """
+ Return a new RDD by applying a function to each element of this RDD.
+ """
+ def func(iterator): return imap(f, iterator)
+ return PipelinedRDD(self, func, preservesPartitioning)
+
+ def flatMap(self, f, preservesPartitioning=False):
+ """
+ Return a new RDD by first applying a function to all elements of this
+ RDD, and then flattening the results.
+
+ >>> rdd = sc.parallelize([2, 3, 4])
+ >>> sorted(rdd.flatMap(lambda x: range(1, x)).collect())
+ [1, 1, 1, 2, 2, 3]
+ >>> sorted(rdd.flatMap(lambda x: [(x, x), (x, x)]).collect())
+ [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)]
+ """
+ def func(iterator): return chain.from_iterable(imap(f, iterator))
+ return self.mapPartitions(func, preservesPartitioning)
+
+ def mapPartitions(self, f, preservesPartitioning=False):
+ """
+ Return a new RDD by applying a function to each partition of this RDD.
+
+ >>> rdd = sc.parallelize([1, 2, 3, 4], 2)
+ >>> def f(iterator): yield sum(iterator)
+ >>> rdd.mapPartitions(f).collect()
+ [3, 7]
+ """
+ return PipelinedRDD(self, f, preservesPartitioning)
+
+ # TODO: mapPartitionsWithSplit
+
+ def filter(self, f):
+ """
+ Return a new RDD containing only the elements that satisfy a predicate.
+
+ >>> rdd = sc.parallelize([1, 2, 3, 4, 5])
+ >>> rdd.filter(lambda x: x % 2 == 0).collect()
+ [2, 4]
+ """
+ def func(iterator): return ifilter(f, iterator)
+ return self.mapPartitions(func)
+
+ def distinct(self):
+ """
+ Return a new RDD containing the distinct elements in this RDD.
+
+ >>> sorted(sc.parallelize([1, 1, 2, 3]).distinct().collect())
+ [1, 2, 3]
+ """
+ return self.map(lambda x: (x, "")) \
+ .reduceByKey(lambda x, _: x) \
+ .map(lambda (x, _): x)
+
+ # TODO: sampling needs to be re-implemented due to Batch
+ #def sample(self, withReplacement, fraction, seed):
+ # jrdd = self._jrdd.sample(withReplacement, fraction, seed)
+ # return RDD(jrdd, self.ctx)
+
+ #def takeSample(self, withReplacement, num, seed):
+ # vals = self._jrdd.takeSample(withReplacement, num, seed)
+ # return [load_pickle(bytes(x)) for x in vals]
+
+ def union(self, other):
+ """
+ Return the union of this RDD and another one.
+
+ >>> rdd = sc.parallelize([1, 1, 2, 3])
+ >>> rdd.union(rdd).collect()
+ [1, 1, 2, 3, 1, 1, 2, 3]
+ """
+ return RDD(self._jrdd.union(other._jrdd), self.ctx)
+
+ def __add__(self, other):
+ """
+ Return the union of this RDD and another one.
+
+ >>> rdd = sc.parallelize([1, 1, 2, 3])
+ >>> (rdd + rdd).collect()
+ [1, 1, 2, 3, 1, 1, 2, 3]
+ """
+ if not isinstance(other, RDD):
+ raise TypeError
+ return self.union(other)
+
+ # TODO: sort
+
+ def glom(self):
+ """
+ Return an RDD created by coalescing all elements within each partition
+ into a list.
+
+ >>> rdd = sc.parallelize([1, 2, 3, 4], 2)
+ >>> sorted(rdd.glom().collect())
+ [[1, 2], [3, 4]]
+ """
+ def func(iterator): yield list(iterator)
+ return self.mapPartitions(func)
+
+ def cartesian(self, other):
+ """
+ Return the Cartesian product of this RDD and another one, that is, the
+ RDD of all pairs of elements C{(a, b)} where C{a} is in C{self} and
+ C{b} is in C{other}.
+
+ >>> rdd = sc.parallelize([1, 2])
+ >>> sorted(rdd.cartesian(rdd).collect())
+ [(1, 1), (1, 2), (2, 1), (2, 2)]
+ """
+ # Due to batching, we can't use the Java cartesian method.
+ java_cartesian = RDD(self._jrdd.cartesian(other._jrdd), self.ctx)
+ def unpack_batches(pair):
+ (x, y) = pair
+ if type(x) == Batch or type(y) == Batch:
+ xs = x.items if type(x) == Batch else [x]
+ ys = y.items if type(y) == Batch else [y]
+ for pair in product(xs, ys):
+ yield pair
+ else:
+ yield pair
+ return java_cartesian.flatMap(unpack_batches)
+
+ def groupBy(self, f, numSplits=None):
+ """
+ Return an RDD of grouped items.
+
+ >>> rdd = sc.parallelize([1, 1, 2, 3, 5, 8])
+ >>> result = rdd.groupBy(lambda x: x % 2).collect()
+ >>> sorted([(x, sorted(y)) for (x, y) in result])
+ [(0, [2, 8]), (1, [1, 1, 3, 5])]
+ """
+ return self.map(lambda x: (f(x), x)).groupByKey(numSplits)
+
+ def pipe(self, command, env={}):
+ """
+ Return an RDD created by piping elements to a forked external process.
+
+ >>> sc.parallelize([1, 2, 3]).pipe('cat').collect()
+ ['1', '2', '3']
+ """
+ def func(iterator):
+ pipe = Popen(shlex.split(command), env=env, stdin=PIPE, stdout=PIPE)
+ def pipe_objs(out):
+ for obj in iterator:
+ out.write(str(obj).rstrip('\n') + '\n')
+ out.close()
+ Thread(target=pipe_objs, args=[pipe.stdin]).start()
+ return (x.rstrip('\n') for x in pipe.stdout)
+ return self.mapPartitions(func)
+
+ def foreach(self, f):
+ """
+ Applies a function to all elements of this RDD.
+
+ >>> def f(x): print x
+ >>> sc.parallelize([1, 2, 3, 4, 5]).foreach(f)
+ """
+ self.map(f).collect() # Force evaluation
+
+ def collect(self):
+ """
+ Return a list that contains all of the elements in this RDD.
+ """
+ picklesInJava = self._jrdd.collect().iterator()
+ return list(self._collect_iterator_through_file(picklesInJava))
+
+ def _collect_iterator_through_file(self, iterator):
+ # Transferring lots of data through Py4J can be slow because
+ # socket.readline() is inefficient. Instead, we'll dump the data to a
+ # file and read it back.
+ tempFile = NamedTemporaryFile(delete=False)
+ tempFile.close()
+ def clean_up_file():
+ try: os.unlink(tempFile.name)
+ except: pass
+ atexit.register(clean_up_file)
+ self.ctx._writeIteratorToPickleFile(iterator, tempFile.name)
+ # Read the data into Python and deserialize it:
+ with open(tempFile.name, 'rb') as tempFile:
+ for item in read_from_pickle_file(tempFile):
+ yield item
+ os.unlink(tempFile.name)
+
+ def reduce(self, f):
+ """
+ Reduces the elements of this RDD using the specified associative binary
+ operator.
+
+ >>> from operator import add
+ >>> sc.parallelize([1, 2, 3, 4, 5]).reduce(add)
+ 15
+ >>> sc.parallelize((2 for _ in range(10))).map(lambda x: 1).cache().reduce(add)
+ 10
+ """
+ def func(iterator):
+ acc = None
+ for obj in iterator:
+ if acc is None:
+ acc = obj
+ else:
+ acc = f(obj, acc)
+ if acc is not None:
+ yield acc
+ vals = self.mapPartitions(func).collect()
+ return reduce(f, vals)
+
+ def fold(self, zeroValue, op):
+ """
+ Aggregate the elements of each partition, and then the results for all
+ the partitions, using a given associative function and a neutral "zero
+ value."
+
+ The function C{op(t1, t2)} is allowed to modify C{t1} and return it
+ as its result value to avoid object allocation; however, it should not
+ modify C{t2}.
+
+ >>> from operator import add
+ >>> sc.parallelize([1, 2, 3, 4, 5]).fold(0, add)
+ 15
+ """
+ def func(iterator):
+ acc = zeroValue
+ for obj in iterator:
+ acc = op(obj, acc)
+ yield acc
+ vals = self.mapPartitions(func).collect()
+ return reduce(op, vals, zeroValue)
+
+ # TODO: aggregate
+
+ def sum(self):
+ """
+ Add up the elements in this RDD.
+
+ >>> sc.parallelize([1.0, 2.0, 3.0]).sum()
+ 6.0
+ """
+ return self.mapPartitions(lambda x: [sum(x)]).reduce(operator.add)
+
+ def count(self):
+ """
+ Return the number of elements in this RDD.
+
+ >>> sc.parallelize([2, 3, 4]).count()
+ 3
+ """
+ return self.mapPartitions(lambda i: [sum(1 for _ in i)]).sum()
+
+ def countByValue(self):
+ """
+ Return the count of each unique value in this RDD as a dictionary of
+ (value, count) pairs.
+
+ >>> sorted(sc.parallelize([1, 2, 1, 2, 2], 2).countByValue().items())
+ [(1, 2), (2, 3)]
+ """
+ def countPartition(iterator):
+ counts = defaultdict(int)
+ for obj in iterator:
+ counts[obj] += 1
+ yield counts
+ def mergeMaps(m1, m2):
+ for (k, v) in m2.iteritems():
+ m1[k] += v
+ return m1
+ return self.mapPartitions(countPartition).reduce(mergeMaps)
+
+ def take(self, num):
+ """
+ Take the first num elements of the RDD.
+
+ This currently scans the partitions *one by one*, so it will be slow if
+ a lot of partitions are required. In that case, use L{collect} to get
+ the whole RDD instead.
+
+ >>> sc.parallelize([2, 3, 4, 5, 6]).take(2)
+ [2, 3]
+ >>> sc.parallelize([2, 3, 4, 5, 6]).take(10)
+ [2, 3, 4, 5, 6]
+ """
+ items = []
+ splits = self._jrdd.splits()
+ taskContext = self.ctx.jvm.spark.TaskContext(0, 0, 0)
+ while len(items) < num and splits:
+ split = splits.pop(0)
+ iterator = self._jrdd.iterator(split, taskContext)
+ items.extend(self._collect_iterator_through_file(iterator))
+ return items[:num]
+
+ def first(self):
+ """
+ Return the first element in this RDD.
+
+ >>> sc.parallelize([2, 3, 4]).first()
+ 2
+ """
+ return self.take(1)[0]
+
+ def saveAsTextFile(self, path):
+ """
+ Save this RDD as a text file, using string representations of elements.
+
+ >>> tempFile = NamedTemporaryFile(delete=True)
+ >>> tempFile.close()
+ >>> sc.parallelize(range(10)).saveAsTextFile(tempFile.name)
+ >>> from fileinput import input
+ >>> from glob import glob
+ >>> ''.join(input(glob(tempFile.name + "/part-0000*")))
+ '0\\n1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n'
+ """
+ def func(iterator):
+ return (str(x).encode("utf-8") for x in iterator)
+ keyed = PipelinedRDD(self, func)
+ keyed._bypass_serializer = True
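+ # With the serializer bypassed, each partition's UTF-8 encoded strings
+ # reach the JVM as raw bytes for BytesToString to turn back into Strings.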
+ keyed._jrdd.map(self.ctx.jvm.BytesToString()).saveAsTextFile(path)
+
+ # Pair functions
+
+ def collectAsMap(self):
+ """
+ Return the key-value pairs in this RDD to the master as a dictionary.
+
+ >>> m = sc.parallelize([(1, 2), (3, 4)]).collectAsMap()
+ >>> m[1]
+ 2
+ >>> m[3]
+ 4
+ """
+ return dict(self.collect())
+
+ def reduceByKey(self, func, numSplits=None):
+ """
+ Merge the values for each key using an associative reduce function.
+
+ This will also perform the merging locally on each mapper before
+ sending results to a reducer, similarly to a "combiner" in MapReduce.
+
+ Output will be hash-partitioned with C{numSplits} splits, or the
+ default parallelism level if C{numSplits} is not specified.
+
+ >>> from operator import add
+ >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
+ >>> sorted(rdd.reduceByKey(add).collect())
+ [('a', 2), ('b', 1)]
+ """
+ return self.combineByKey(lambda x: x, func, func, numSplits)
+
+ def reduceByKeyLocally(self, func):
+ """
+ Merge the values for each key using an associative reduce function, but
+ return the results immediately to the master as a dictionary.
+
+ This will also perform the merging locally on each mapper before
+ sending results to a reducer, similarly to a "combiner" in MapReduce.
+
+ >>> from operator import add
+ >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
+ >>> sorted(rdd.reduceByKeyLocally(add).items())
+ [('a', 2), ('b', 1)]
+ """
+ def reducePartition(iterator):
+ m = {}
+ for (k, v) in iterator:
+ m[k] = v if k not in m else func(m[k], v)
+ yield m
+ def mergeMaps(m1, m2):
+ for (k, v) in m2.iteritems():
+ m1[k] = v if k not in m1 else func(m1[k], v)
+ return m1
+ return self.mapPartitions(reducePartition).reduce(mergeMaps)
+
+ def countByKey(self):
+ """
+ Count the number of elements for each key, and return the result to the
+ master as a dictionary.
+
+ >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
+ >>> sorted(rdd.countByKey().items())
+ [('a', 2), ('b', 1)]
+ """
+ return self.map(lambda x: x[0]).countByValue()
+
+ def join(self, other, numSplits=None):
+ """
+ Return an RDD containing all pairs of elements with matching keys in
+ C{self} and C{other}.
+
+ Each pair of elements will be returned as a (k, (v1, v2)) tuple, where
+ (k, v1) is in C{self} and (k, v2) is in C{other}.
+
+ Performs a hash join across the cluster.
+
+ >>> x = sc.parallelize([("a", 1), ("b", 4)])
+ >>> y = sc.parallelize([("a", 2), ("a", 3)])
+ >>> sorted(x.join(y).collect())
+ [('a', (1, 2)), ('a', (1, 3))]
+ """
+ return python_join(self, other, numSplits)
+
+ def leftOuterJoin(self, other, numSplits=None):
+ """
+ Perform a left outer join of C{self} and C{other}.
+
+ For each element (k, v) in C{self}, the resulting RDD will either
+ contain all pairs (k, (v, w)) for w in C{other}, or the pair
+ (k, (v, None)) if no elements in C{other} have key k.
+
+ Hash-partitions the resulting RDD into the given number of partitions.
+
+ >>> x = sc.parallelize([("a", 1), ("b", 4)])
+ >>> y = sc.parallelize([("a", 2)])
+ >>> sorted(x.leftOuterJoin(y).collect())
+ [('a', (1, 2)), ('b', (4, None))]
+ """
+ return python_left_outer_join(self, other, numSplits)
+
+ def rightOuterJoin(self, other, numSplits=None):
+ """
+ Perform a right outer join of C{self} and C{other}.
+
+ For each element (k, w) in C{other}, the resulting RDD will either
+ contain all pairs (k, (v, w)) for v in C{self}, or the pair (k, (None, w))
+ if no elements in C{self} have key k.
+
+ Hash-partitions the resulting RDD into the given number of partitions.
+
+ >>> x = sc.parallelize([("a", 1), ("b", 4)])
+ >>> y = sc.parallelize([("a", 2)])
+ >>> sorted(y.rightOuterJoin(x).collect())
+ [('a', (2, 1)), ('b', (None, 4))]
+ """
+ return python_right_outer_join(self, other, numSplits)
+
+ # TODO: add option to control map-side combining
+ def partitionBy(self, numSplits, hashFunc=hash):
+ """
+ Return a copy of the RDD partitioned using the specified partitioner.
+
+ >>> pairs = sc.parallelize([1, 2, 3, 4, 2, 4, 1]).map(lambda x: (x, x))
+ >>> sets = pairs.partitionBy(2).glom().collect()
+ >>> set(sets[0]).intersection(set(sets[1]))
+ set([])
+ """
+ if numSplits is None:
+ numSplits = self.ctx.defaultParallelism
+ # Transferring O(n) objects to Java is too expensive. Instead, we'll
+ # form the hash buckets in Python, transferring O(numSplits) objects
+ # to Java. Each object is a (splitNumber, [objects]) pair.
+ def add_shuffle_key(iterator):
+ buckets = defaultdict(list)
+ for (k, v) in iterator:
+ buckets[hashFunc(k) % numSplits].append((k, v))
+ for (split, items) in buckets.iteritems():
+ yield str(split)
+ yield dump_pickle(Batch(items))
+ keyed = PipelinedRDD(self, add_shuffle_key)
+ keyed._bypass_serializer = True
+ pairRDD = self.ctx.jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
+ partitioner = self.ctx.jvm.spark.api.python.PythonPartitioner(numSplits)
+ jrdd = pairRDD.partitionBy(partitioner)
+ jrdd = jrdd.map(self.ctx.jvm.ExtractValue())
+ return RDD(jrdd, self.ctx)
+
+ # TODO: add control over map-side aggregation
+ def combineByKey(self, createCombiner, mergeValue, mergeCombiners,
+ numSplits=None):
+ """
+ Generic function to combine the elements for each key using a custom
+ set of aggregation functions.
+
+ Turns an RDD[(K, V)] into a result of type RDD[(K, C)], for a "combined
+ type" C. Note that V and C can be different -- for example, one might
+ group an RDD of type (Int, Int) into an RDD of type (Int, List[Int]).
+
+ Users provide three functions:
+
+ - C{createCombiner}, which turns a V into a C (e.g., creates
+ a one-element list)
+ - C{mergeValue}, to merge a V into a C (e.g., adds it to the end of
+ a list)
+ - C{mergeCombiners}, to combine two C's into a single one.
+
+ In addition, users can control the partitioning of the output RDD.
+
+ >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
+ >>> def f(x): return x
+ >>> def add(a, b): return a + str(b)
+ >>> sorted(x.combineByKey(str, add, add).collect())
+ [('a', '11'), ('b', '1')]
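+
+ The same pattern can collect values into lists, as described above
+ (a small illustrative example, sorted to keep the doctest stable):
+
+ >>> def to_list(v): return [v]
+ >>> def append(xs, v): return xs + [v]
+ >>> def extend(xs, ys): return xs + ys
+ >>> sorted(x.combineByKey(to_list, append, extend).collect())
+ [('a', [1, 1]), ('b', [1])]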
+ """
+ if numSplits is None:
+ numSplits = self.ctx.defaultParallelism
+ def combineLocally(iterator):
+ combiners = {}
+ for (k, v) in iterator:
+ if k not in combiners:
+ combiners[k] = createCombiner(v)
+ else:
+ combiners[k] = mergeValue(combiners[k], v)
+ return combiners.iteritems()
+ locally_combined = self.mapPartitions(combineLocally)
+ shuffled = locally_combined.partitionBy(numSplits)
+ def _mergeCombiners(iterator):
+ combiners = {}
+ for (k, v) in iterator:
+ if k not in combiners:
+ combiners[k] = v
+ else:
+ combiners[k] = mergeCombiners(combiners[k], v)
+ return combiners.iteritems()
+ return shuffled.mapPartitions(_mergeCombiners)
+
+ # TODO: support variant with custom partitioner
+ def groupByKey(self, numSplits=None):
+ """
+ Group the values for each key in the RDD into a single sequence.
+ Hash-partitions the resulting RDD into numSplits partitions.
+
+ >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
+ >>> sorted(x.groupByKey().collect())
+ [('a', [1, 1]), ('b', [1])]
+ """
+
+ def createCombiner(x):
+ return [x]
+
+ def mergeValue(xs, x):
+ xs.append(x)
+ return xs
+
+ def mergeCombiners(a, b):
+ return a + b
+
+ return self.combineByKey(createCombiner, mergeValue, mergeCombiners,
+ numSplits)
+
+ # TODO: add tests
+ def flatMapValues(self, f):
+ """
+ Pass each value in the key-value pair RDD through a flatMap function
+ without changing the keys; this also retains the original RDD's
+ partitioning.
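+
+ A small example (sorted, since output order depends on partitioning):
+
+ >>> x = sc.parallelize([("a", ["x", "y", "z"]), ("b", ["p", "r"])])
+ >>> def f(x): return x
+ >>> sorted(x.flatMapValues(f).collect())
+ [('a', 'x'), ('a', 'y'), ('a', 'z'), ('b', 'p'), ('b', 'r')]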
+ """
+ flat_map_fn = lambda (k, v): ((k, x) for x in f(v))
+ return self.flatMap(flat_map_fn, preservesPartitioning=True)
+
+ def mapValues(self, f):
+ """
+ Pass each value in the key-value pair RDD through a map function
+ without changing the keys; this also retains the original RDD's
+ partitioning.
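+
+ A small example (sorted to keep the doctest deterministic):
+
+ >>> x = sc.parallelize([("a", ["apple", "banana", "lemon"]), ("b", ["grapes"])])
+ >>> def f(x): return len(x)
+ >>> sorted(x.mapValues(f).collect())
+ [('a', 3), ('b', 1)]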
+ """
+ map_values_fn = lambda (k, v): (k, f(v))
+ return self.map(map_values_fn, preservesPartitioning=True)
+
+ # TODO: support varargs cogroup of several RDDs.
+ def groupWith(self, other):
+ """
+ Alias for cogroup.
+ """
+ return self.cogroup(other)
+
+ # TODO: add variant with custom partitioner
+ def cogroup(self, other, numSplits=None):
+ """
+ For each key k in C{self} or C{other}, return a resulting RDD that
+ contains a tuple with the list of values for that key in C{self} as well
+ as C{other}.
+
+ >>> x = sc.parallelize([("a", 1), ("b", 4)])
+ >>> y = sc.parallelize([("a", 2)])
+ >>> sorted(x.cogroup(y).collect())
+ [('a', ([1], [2])), ('b', ([4], []))]
+ """
+ return python_cogroup(self, other, numSplits)
+
+ # TODO: `lookup` is disabled because we can't make direct comparisons based
+ # on the key; we need to compare the hash of the key to the hash of the
+ # keys in the pairs. This could be an expensive operation, since those
+ # hashes aren't retained.
+
+
+class PipelinedRDD(RDD):
+ """
+ Pipelined maps:
+ >>> rdd = sc.parallelize([1, 2, 3, 4])
+ >>> rdd.map(lambda x: 2 * x).cache().map(lambda x: 2 * x).collect()
+ [4, 8, 12, 16]
+ >>> rdd.map(lambda x: 2 * x).map(lambda x: 2 * x).collect()
+ [4, 8, 12, 16]
+
+ Pipelined reduces:
+ >>> from operator import add
+ >>> rdd.map(lambda x: 2 * x).reduce(add)
+ 20
+ >>> rdd.flatMap(lambda x: [x, x]).reduce(add)
+ 20
+ """
+ def __init__(self, prev, func, preservesPartitioning=False):
+ if isinstance(prev, PipelinedRDD) and not prev.is_cached:
+ prev_func = prev.func
+ def pipeline_func(iterator):
+ return func(prev_func(iterator))
+ self.func = pipeline_func
+ self.preservesPartitioning = \
+ prev.preservesPartitioning and preservesPartitioning
+ self._prev_jrdd = prev._prev_jrdd
+ else:
+ self.func = func
+ self.preservesPartitioning = preservesPartitioning
+ self._prev_jrdd = prev._jrdd
+ self.is_cached = False
+ self.ctx = prev.ctx
+ self.prev = prev
+ self._jrdd_val = None
+ self._bypass_serializer = False
+
+ @property
+ def _jrdd(self):
+ if self._jrdd_val:
+ return self._jrdd_val
+ func = self.func
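+ # Unless the serializer is bypassed or batching is disabled
+ # (batchSize == 1), wrap the pipelined function so its output is
+ # grouped into Batch objects before being pickled.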
+ if not self._bypass_serializer and self.ctx.batchSize != 1:
+ oldfunc = self.func
+ batchSize = self.ctx.batchSize
+ def batched_func(iterator):
+ return batched(oldfunc(iterator), batchSize)
+ func = batched_func
+ cmds = [func, self._bypass_serializer]
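+ # The command string sent to the JVM is the cloudpickled,
+ # base64-encoded function followed by the pickled bypass flag,
+ # joined with a space.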
+ pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds)
+ broadcast_vars = ListConverter().convert(
+ [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
+ self.ctx.gateway._gateway_client)
+ self.ctx._pickled_broadcast_vars.clear()
+ class_manifest = self._prev_jrdd.classManifest()
+ env = copy.copy(self.ctx.environment)
+ env['PYTHONPATH'] = os.environ.get("PYTHONPATH", "")
+ env = MapConverter().convert(env, self.ctx.gateway._gateway_client)
+ python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(),
+ pipe_command, env, self.preservesPartitioning, self.ctx.pythonExec,
+ broadcast_vars, class_manifest)
+ self._jrdd_val = python_rdd.asJavaRDD()
+ return self._jrdd_val
+
+
+def _test():
+ import doctest
+ from pyspark.context import SparkContext
+ globs = globals().copy()
+ # The small batch size here ensures that we see multiple batches,
+ # even in these small test examples:
+ globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
+ doctest.testmod(globs=globs)
+ globs['sc'].stop()
+
+
+if __name__ == "__main__":
+ _test()
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
new file mode 100644
index 0000000000..9a5151ea00
--- /dev/null
+++ b/python/pyspark/serializers.py
@@ -0,0 +1,78 @@
+import struct
+import cPickle
+
+
+class Batch(object):
+ """
+ Used to store multiple RDD entries as a single Java object.
+
+ This relieves us from having to explicitly track whether an RDD
+ is stored as batches of objects and avoids problems when processing
+ the union() of batched and unbatched RDDs (e.g. the union() of textFile()
+ with another RDD).
+ """
+ def __init__(self, items):
+ self.items = items
+
+
+def batched(iterator, batchSize):
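+ """
+ Group an iterator into Batch objects of at most batchSize items each;
+ a batchSize of -1 yields everything as a single unlimited Batch.
+ For example, batched(iter([1, 2, 3]), 2) yields a Batch of [1, 2]
+ followed by a Batch of [3].
+ """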
+ if batchSize == -1: # unlimited batch size
+ yield Batch(list(iterator))
+ else:
+ items = []
+ count = 0
+ for item in iterator:
+ items.append(item)
+ count += 1
+ if count == batchSize:
+ yield Batch(items)
+ items = []
+ count = 0
+ if items:
+ yield Batch(items)
+
+
+def dump_pickle(obj):
+ return cPickle.dumps(obj, 2)
+
+
+load_pickle = cPickle.loads
+
+
+def read_long(stream):
+ length = stream.read(8)
+ if length == "":
+ raise EOFError
+ return struct.unpack("!q", length)[0]
+
+
+def read_int(stream):
+ length = stream.read(4)
+ if length == "":
+ raise EOFError
+ return struct.unpack("!i", length)[0]
+
+def write_with_length(obj, stream):
+ stream.write(struct.pack("!i", len(obj)))
+ stream.write(obj)
+
+
+def read_with_length(stream):
+ length = read_int(stream)
+ obj = stream.read(length)
+ if obj == "":
+ raise EOFError
+ return obj
+
+
+def read_from_pickle_file(stream):
+ try:
+ while True:
+ obj = load_pickle(read_with_length(stream))
+ if type(obj) == Batch: # We don't care about inheritance
+ for item in obj.items:
+ yield item
+ else:
+ yield obj
+ except EOFError:
+ return
diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py
new file mode 100644
index 0000000000..bd39b0283f
--- /dev/null
+++ b/python/pyspark/shell.py
@@ -0,0 +1,33 @@
+"""
+An interactive shell.
+"""
+import optparse # I prefer argparse, but it's not included with Python < 2.7
+import code
+import sys
+
+from pyspark.context import SparkContext
+
+
+def main(master='local', ipython=False):
+ sc = SparkContext(master, 'PySparkShell')
+ user_ns = {'sc' : sc}
+ banner = "Spark context avaiable as sc."
+ if ipython:
+ import IPython
+ IPython.embed(user_ns=user_ns, banner2=banner)
+ else:
+ print banner
+ code.interact(local=user_ns)
+
+
+if __name__ == '__main__':
+ usage = "usage: %prog [options] master"
+ parser = optparse.OptionParser(usage=usage)
+ parser.add_option("-i", "--ipython", help="Run IPython shell",
+ action="store_true")
+ (options, args) = parser.parse_args()
+ if len(args) > 0:
+ master = args[0]
+ else:
+ master = 'local'
+ main(master, options.ipython)
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
new file mode 100644
index 0000000000..9f6b507dbd
--- /dev/null
+++ b/python/pyspark/worker.py
@@ -0,0 +1,40 @@
+"""
+Worker that receives input from Piped RDD.
+"""
+import sys
+from base64 import standard_b64decode
+# CloudPickler needs to be imported so that depicklers are registered using the
+# copy_reg module.
+from pyspark.broadcast import Broadcast, _broadcastRegistry
+from pyspark.cloudpickle import CloudPickler
+from pyspark.serializers import write_with_length, read_with_length, \
+ read_long, read_int, dump_pickle, load_pickle, read_from_pickle_file
+
+
+# Redirect stdout to stderr so that stray print statements in user code
+# don't corrupt the result stream; values must be returned, not printed.
+old_stdout = sys.stdout
+sys.stdout = sys.stderr
+
+
+def load_obj():
+ return load_pickle(standard_b64decode(sys.stdin.readline().strip()))
+
+
+def main():
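+ # Read, in order: an int count of broadcast variables; for each, a long
+ # id and a length-prefixed pickled value; then the base64-pickled
+ # function and bypass flag (one per line); and finally the stream of
+ # pickled input records. Each result is written back to the original
+ # stdout with a length prefix.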
+ num_broadcast_variables = read_int(sys.stdin)
+ for _ in range(num_broadcast_variables):
+ bid = read_long(sys.stdin)
+ value = read_with_length(sys.stdin)
+ _broadcastRegistry[bid] = Broadcast(bid, load_pickle(value))
+ func = load_obj()
+ bypassSerializer = load_obj()
+ if bypassSerializer:
+ dumps = lambda x: x
+ else:
+ dumps = dump_pickle
+ for obj in func(read_from_pickle_file(sys.stdin)):
+ write_with_length(dumps(obj), old_stdout)
+
+
+if __name__ == '__main__':
+ main()