author     Stephen Haberman <stephen@exigencecorp.com>  2013-02-02 01:57:18 -0600
committer  Stephen Haberman <stephen@exigencecorp.com>  2013-02-02 01:57:18 -0600
commit     103c375ba044b4fb1061298d6375587ed30832a4 (patch)
tree       45b1a45260b0d266965caed1f4612f0b2eb49246 /python
parent     fdec42385a1a8f10f9dd803525cb3c132a25ba53 (diff)
parent     ae26911ec0d768dcdae8b7d706ca4544e36535e6 (diff)
download   spark-103c375ba044b4fb1061298d6375587ed30832a4.tar.gz
           spark-103c375ba044b4fb1061298d6375587ed30832a4.tar.bz2
           spark-103c375ba044b4fb1061298d6375587ed30832a4.zip
Merge branch 'master' into sparkmem
Diffstat (limited to 'python')
-rw-r--r--  python/epydoc.conf                  |   2
-rwxr-xr-x  python/examples/als.py              |  71
-rw-r--r--  python/pyspark/__init__.py          |   9
-rw-r--r--  python/pyspark/accumulators.py      | 207
-rw-r--r--  python/pyspark/context.py           | 147
-rw-r--r--  python/pyspark/files.py             |  38
-rw-r--r--  python/pyspark/rdd.py               |  71
-rw-r--r--  python/pyspark/serializers.py       |   7
-rw-r--r--  python/pyspark/shell.py             |   5
-rw-r--r--  python/pyspark/tests.py             | 121
-rw-r--r--  python/pyspark/worker.py            |  27
-rwxr-xr-x  python/run-tests                    |   9
-rwxr-xr-x  python/test_support/hello.txt       |   1
-rwxr-xr-x  python/test_support/userlibrary.py  |   7
14 files changed, 672 insertions(+), 50 deletions(-)
diff --git a/python/epydoc.conf b/python/epydoc.conf
index 91ac984ba2..45102cd9fe 100644
--- a/python/epydoc.conf
+++ b/python/epydoc.conf
@@ -16,4 +16,4 @@ target: docs/
private: no
exclude: pyspark.cloudpickle pyspark.worker pyspark.join pyspark.serializers
- pyspark.java_gateway pyspark.examples pyspark.shell
+ pyspark.java_gateway pyspark.examples pyspark.shell pyspark.test
diff --git a/python/examples/als.py b/python/examples/als.py
new file mode 100755
index 0000000000..010f80097f
--- /dev/null
+++ b/python/examples/als.py
@@ -0,0 +1,71 @@
+"""
+This example requires numpy (http://www.numpy.org/)
+"""
+from os.path import realpath
+import sys
+
+import numpy as np
+from numpy.random import rand
+from numpy import matrix
+from pyspark import SparkContext
+
+LAMBDA = 0.01 # regularization
+np.random.seed(42)
+
+def rmse(R, ms, us):
+ diff = R - ms * us.T
+ return np.sqrt(np.sum(np.power(diff, 2)) / (M * U))
+
+def update(i, vec, mat, ratings):
+ uu = mat.shape[0]
+ ff = mat.shape[1]
+ XtX = matrix(np.zeros((ff, ff)))
+ Xty = np.zeros((ff, 1))
+
+ for j in range(uu):
+ v = mat[j, :]
+ XtX += v.T * v
+ Xty += v.T * ratings[i, j]
+ XtX += np.eye(ff, ff) * LAMBDA * uu
+ return np.linalg.solve(XtX, Xty)
+
+if __name__ == "__main__":
+ if len(sys.argv) < 2:
+ print >> sys.stderr, \
+ "Usage: PythonALS <master> <M> <U> <F> <iters> <slices>"
+ exit(-1)
+ sc = SparkContext(sys.argv[1], "PythonALS", pyFiles=[realpath(__file__)])
+ M = int(sys.argv[2]) if len(sys.argv) > 2 else 100
+ U = int(sys.argv[3]) if len(sys.argv) > 3 else 500
+ F = int(sys.argv[4]) if len(sys.argv) > 4 else 10
+ ITERATIONS = int(sys.argv[5]) if len(sys.argv) > 5 else 5
+ slices = int(sys.argv[6]) if len(sys.argv) > 6 else 2
+
+ print "Running ALS with M=%d, U=%d, F=%d, iters=%d, slices=%d\n" % \
+ (M, U, F, ITERATIONS, slices)
+
+ R = matrix(rand(M, F)) * matrix(rand(U, F).T)
+ ms = matrix(rand(M, F))
+ us = matrix(rand(U, F))
+
+ Rb = sc.broadcast(R)
+ msb = sc.broadcast(ms)
+ usb = sc.broadcast(us)
+
+ for i in range(ITERATIONS):
+ ms = sc.parallelize(range(M), slices) \
+ .map(lambda x: update(x, msb.value[x, :], usb.value, Rb.value)) \
+ .collect()
+ ms = matrix(np.array(ms)[:, :, 0]) # collect() returns a list, so array ends up being
+ # a 3-d array, we take the first 2 dims for the matrix
+ msb = sc.broadcast(ms)
+
+ us = sc.parallelize(range(U), slices) \
+ .map(lambda x: update(x, usb.value[x, :], msb.value, Rb.value.T)) \
+ .collect()
+ us = matrix(np.array(us)[:, :, 0])
+ usb = sc.broadcast(us)
+
+ error = rmse(R, ms, us)
+ print "Iteration %d:" % i
+ print "\nRMSE: %5.4f\n" % error
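
Note: the update() step above solves the regularized normal equations (X^T X + LAMBDA * n * I) w = X^T y for one row of a factor matrix. Below is a minimal standalone sketch of that step, using only numpy; the sizes n and f are illustrative and not taken from the example.

    # Regularized least-squares step, as in update() above:
    # solve (X^T X + LAMBDA * n * I) w = X^T y for one factor vector.
    import numpy as np

    LAMBDA = 0.01
    n, f = 50, 10                          # illustrative sizes
    X = np.matrix(np.random.rand(n, f))    # the "other side" factor matrix
    y = np.matrix(np.random.rand(n, 1))    # ratings observed for this row

    XtX = X.T * X + np.eye(f) * LAMBDA * n
    Xty = X.T * y
    w = np.linalg.solve(XtX, Xty)

    # The solution satisfies the normal equations to numerical precision.
    assert np.allclose(XtX * w, Xty)
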
diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py
index c595ae0842..3e8bca62f0 100644
--- a/python/pyspark/__init__.py
+++ b/python/pyspark/__init__.py
@@ -7,6 +7,12 @@ Public classes:
Main entry point for Spark functionality.
- L{RDD<pyspark.rdd.RDD>}
A Resilient Distributed Dataset (RDD), the basic abstraction in Spark.
+ - L{Broadcast<pyspark.broadcast.Broadcast>}
+ A broadcast variable that gets reused across tasks.
+ - L{Accumulator<pyspark.accumulators.Accumulator>}
+ An "add-only" shared variable that tasks can only add values to.
+ - L{SparkFiles<pyspark.files.SparkFiles>}
+ Access files shipped with jobs.
"""
import sys
import os
@@ -15,6 +21,7 @@ sys.path.insert(0, os.path.join(os.environ["SPARK_HOME"], "python/lib/py4j0.7.eg
from pyspark.context import SparkContext
from pyspark.rdd import RDD
+from pyspark.files import SparkFiles
-__all__ = ["SparkContext", "RDD"]
+__all__ = ["SparkContext", "RDD", "SparkFiles"]
diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py
new file mode 100644
index 0000000000..61fcbbd376
--- /dev/null
+++ b/python/pyspark/accumulators.py
@@ -0,0 +1,207 @@
+"""
+>>> from pyspark.context import SparkContext
+>>> sc = SparkContext('local', 'test')
+>>> a = sc.accumulator(1)
+>>> a.value
+1
+>>> a.value = 2
+>>> a.value
+2
+>>> a += 5
+>>> a.value
+7
+
+>>> sc.accumulator(1.0).value
+1.0
+
+>>> sc.accumulator(1j).value
+1j
+
+>>> rdd = sc.parallelize([1,2,3])
+>>> def f(x):
+... global a
+... a += x
+>>> rdd.foreach(f)
+>>> a.value
+13
+
+>>> from pyspark.accumulators import AccumulatorParam
+>>> class VectorAccumulatorParam(AccumulatorParam):
+... def zero(self, value):
+... return [0.0] * len(value)
+... def addInPlace(self, val1, val2):
+... for i in xrange(len(val1)):
+... val1[i] += val2[i]
+... return val1
+>>> va = sc.accumulator([1.0, 2.0, 3.0], VectorAccumulatorParam())
+>>> va.value
+[1.0, 2.0, 3.0]
+>>> def g(x):
+... global va
+... va += [x] * 3
+>>> rdd.foreach(g)
+>>> va.value
+[7.0, 8.0, 9.0]
+
+>>> rdd.map(lambda x: a.value).collect() # doctest: +IGNORE_EXCEPTION_DETAIL
+Traceback (most recent call last):
+ ...
+Py4JJavaError:...
+
+>>> def h(x):
+... global a
+... a.value = 7
+>>> rdd.foreach(h) # doctest: +IGNORE_EXCEPTION_DETAIL
+Traceback (most recent call last):
+ ...
+Py4JJavaError:...
+
+>>> sc.accumulator([1.0, 2.0, 3.0]) # doctest: +IGNORE_EXCEPTION_DETAIL
+Traceback (most recent call last):
+ ...
+Exception:...
+"""
+
+import struct
+import SocketServer
+import threading
+from pyspark.cloudpickle import CloudPickler
+from pyspark.serializers import read_int, read_with_length, load_pickle
+
+
+# Holds accumulators registered on the current machine, keyed by ID. This is then used to send
+# the local accumulator updates back to the driver program at the end of a task.
+_accumulatorRegistry = {}
+
+
+def _deserialize_accumulator(aid, zero_value, accum_param):
+ from pyspark.accumulators import _accumulatorRegistry
+ accum = Accumulator(aid, zero_value, accum_param)
+ accum._deserialized = True
+ _accumulatorRegistry[aid] = accum
+ return accum
+
+
+class Accumulator(object):
+ """
+ A shared variable that can be accumulated, i.e., has a commutative and associative "add"
+ operation. Worker tasks on a Spark cluster can add values to an Accumulator with the C{+=}
+ operator, but only the driver program is allowed to access its value, using C{value}.
+ Updates from the workers get propagated automatically to the driver program.
+
+ While C{SparkContext} supports accumulators for primitive data types like C{int} and
+ C{float}, users can also define accumulators for custom types by providing a custom
+ L{AccumulatorParam} object. Refer to the doctest of this module for an example.
+ """
+
+ def __init__(self, aid, value, accum_param):
+ """Create a new Accumulator with a given initial value and AccumulatorParam object"""
+ from pyspark.accumulators import _accumulatorRegistry
+ self.aid = aid
+ self.accum_param = accum_param
+ self._value = value
+ self._deserialized = False
+ _accumulatorRegistry[aid] = self
+
+ def __reduce__(self):
+ """Custom serialization; saves the zero value from our AccumulatorParam"""
+ param = self.accum_param
+ return (_deserialize_accumulator, (self.aid, param.zero(self._value), param))
+
+ @property
+ def value(self):
+ """Get the accumulator's value; only usable in driver program"""
+ if self._deserialized:
+ raise Exception("Accumulator.value cannot be accessed inside tasks")
+ return self._value
+
+ @value.setter
+ def value(self, value):
+ """Sets the accumulator's value; only usable in driver program"""
+ if self._deserialized:
+ raise Exception("Accumulator.value cannot be accessed inside tasks")
+ self._value = value
+
+ def __iadd__(self, term):
+ """The += operator; adds a term to this accumulator's value"""
+ self._value = self.accum_param.addInPlace(self._value, term)
+ return self
+
+ def __str__(self):
+ return str(self._value)
+
+ def __repr__(self):
+ return "Accumulator<id=%i, value=%s>" % (self.aid, self._value)
+
+
+class AccumulatorParam(object):
+ """
+ Helper object that defines how to accumulate values of a given type.
+ """
+
+ def zero(self, value):
+ """
+ Provide a "zero value" for the type, compatible in dimensions with the
+ provided C{value} (e.g., a zero vector)
+ """
+ raise NotImplementedError
+
+ def addInPlace(self, value1, value2):
+ """
+ Add two values of the accumulator's data type, returning a new value;
+ for efficiency, can also update C{value1} in place and return it.
+ """
+ raise NotImplementedError
+
+
+class AddingAccumulatorParam(AccumulatorParam):
+ """
+ An AccumulatorParam that uses the + operator to add values. Designed for simple types
+ such as integers, floats, and lists. Requires the zero value for the underlying type
+ as a parameter.
+ """
+
+ def __init__(self, zero_value):
+ self.zero_value = zero_value
+
+ def zero(self, value):
+ return self.zero_value
+
+ def addInPlace(self, value1, value2):
+ value1 += value2
+ return value1
+
+
+# Singleton accumulator params for some standard types
+INT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0)
+FLOAT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0)
+COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j)
+
+
+class _UpdateRequestHandler(SocketServer.StreamRequestHandler):
+ def handle(self):
+ from pyspark.accumulators import _accumulatorRegistry
+ num_updates = read_int(self.rfile)
+ for _ in range(num_updates):
+ (aid, update) = load_pickle(read_with_length(self.rfile))
+ _accumulatorRegistry[aid] += update
+ # Write a byte in acknowledgement
+ self.wfile.write(struct.pack("!b", 1))
+
+
+def _start_update_server():
+ """Start a TCP server to receive accumulator updates in a daemon thread, and return it"""
+ server = SocketServer.TCPServer(("localhost", 0), _UpdateRequestHandler)
+ thread = threading.Thread(target=server.serve_forever)
+ thread.daemon = True
+ thread.start()
+ return server
+
+
+def _test():
+ import doctest
+ doctest.testmod()
+
+
+if __name__ == "__main__":
+ _test()
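
Note: _UpdateRequestHandler implies a small wire protocol: a 4-byte big-endian count of updates, then each update as a length-prefixed pickle of an (accumulator_id, value) pair, answered with a single acknowledgement byte. The sketch below is an illustrative client for that framing; normally only the JVM worker talks to this server, and the sketch assumes load_pickle accepts standard cPickle output.

    # Illustrative client for the server started by _start_update_server():
    # send a count, then length-prefixed pickled (aid, value) pairs, and
    # read back the one-byte acknowledgement.
    import socket
    import struct
    import cPickle

    def send_updates(host, port, updates):
        """updates is a list of (accumulator_id, value) pairs."""
        sock = socket.create_connection((host, port))
        try:
            sock.sendall(struct.pack("!i", len(updates)))      # number of updates
            for aid, value in updates:
                payload = cPickle.dumps((aid, value), 2)
                sock.sendall(struct.pack("!i", len(payload)))  # length prefix
                sock.sendall(payload)
            ack = sock.recv(1)                                 # struct.pack("!b", 1)
            return len(ack) == 1 and struct.unpack("!b", ack)[0] == 1
        finally:
            sock.close()
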
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index e486f206b0..6831f9b7f8 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -1,8 +1,13 @@
import os
-import atexit
+import shutil
+import sys
+from threading import Lock
from tempfile import NamedTemporaryFile
+from pyspark import accumulators
+from pyspark.accumulators import Accumulator
from pyspark.broadcast import Broadcast
+from pyspark.files import SparkFiles
from pyspark.java_gateway import launch_gateway
from pyspark.serializers import dump_pickle, write_with_length, batched
from pyspark.rdd import RDD
@@ -17,11 +22,13 @@ class SparkContext(object):
broadcast variables on that cluster.
"""
- gateway = launch_gateway()
- jvm = gateway.jvm
- _readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile
- _writeIteratorToPickleFile = jvm.PythonRDD.writeIteratorToPickleFile
- _takePartition = jvm.PythonRDD.takePartition
+ _gateway = None
+ _jvm = None
+ _writeIteratorToPickleFile = None
+ _takePartition = None
+ _next_accum_id = 0
+ _active_spark_context = None
+ _lock = Lock()
def __init__(self, master, jobName, sparkHome=None, pyFiles=None,
environment=None, batchSize=1024):
@@ -41,6 +48,18 @@ class SparkContext(object):
Java object. Set 1 to disable batching or -1 to use an
unlimited batch size.
"""
+ with SparkContext._lock:
+ if SparkContext._active_spark_context:
+ raise ValueError("Cannot run multiple SparkContexts at once")
+ else:
+ SparkContext._active_spark_context = self
+ if not SparkContext._gateway:
+ SparkContext._gateway = launch_gateway()
+ SparkContext._jvm = SparkContext._gateway.jvm
+ SparkContext._writeIteratorToPickleFile = \
+ SparkContext._jvm.PythonRDD.writeIteratorToPickleFile
+ SparkContext._takePartition = \
+ SparkContext._jvm.PythonRDD.takePartition
self.master = master
self.jobName = jobName
self.sparkHome = sparkHome or None # None becomes null in Py4J
@@ -48,10 +67,18 @@ class SparkContext(object):
self.batchSize = batchSize # -1 represents an unlimited batch size
# Create the Java SparkContext through Py4J
- empty_string_array = self.gateway.new_array(self.jvm.String, 0)
- self._jsc = self.jvm.JavaSparkContext(master, jobName, sparkHome,
+ empty_string_array = self._gateway.new_array(self._jvm.String, 0)
+ self._jsc = self._jvm.JavaSparkContext(master, jobName, sparkHome,
empty_string_array)
+ # Create a single Accumulator in Java that we'll send all our updates through;
+ # they will be passed back to us through a TCP server
+ self._accumulatorServer = accumulators._start_update_server()
+ (host, port) = self._accumulatorServer.server_address
+ self._javaAccumulator = self._jsc.accumulator(
+ self._jvm.java.util.ArrayList(),
+ self._jvm.PythonAccumulatorParam(host, port))
+
self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')
# Broadcast's __reduce__ method stores Broadcast instances here.
# This allows other code to determine which Broadcast instances have
@@ -62,6 +89,13 @@ class SparkContext(object):
# Deploy any code dependencies specified in the constructor
for path in (pyFiles or []):
self.addPyFile(path)
+ SparkFiles._sc = self
+ sys.path.append(SparkFiles.getRootDirectory())
+
+ # Create a temporary directory inside spark.local.dir:
+ local_dir = self._jvm.spark.Utils.getLocalDir()
+ self._temp_dir = \
+ self._jvm.spark.Utils.createTempDir(local_dir).getAbsolutePath()
@property
def defaultParallelism(self):
@@ -72,15 +106,20 @@ class SparkContext(object):
return self._jsc.sc().defaultParallelism()
def __del__(self):
- if self._jsc:
- self._jsc.stop()
+ self.stop()
def stop(self):
"""
Shut down the SparkContext.
"""
- self._jsc.stop()
- self._jsc = None
+ if self._jsc:
+ self._jsc.stop()
+ self._jsc = None
+ if self._accumulatorServer:
+ self._accumulatorServer.shutdown()
+ self._accumulatorServer = None
+ with SparkContext._lock:
+ SparkContext._active_spark_context = None
def parallelize(self, c, numSlices=None):
"""
@@ -90,14 +129,14 @@ class SparkContext(object):
# Calling the Java parallelize() method with an ArrayList is too slow,
# because it sends O(n) Py4J commands. As an alternative, serialized
# objects are written to a file and loaded through textFile().
- tempFile = NamedTemporaryFile(delete=False)
- atexit.register(lambda: os.unlink(tempFile.name))
+ tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
if self.batchSize != 1:
c = batched(c, self.batchSize)
for x in c:
write_with_length(dump_pickle(x), tempFile)
tempFile.close()
- jrdd = self._readRDDFromPickleFile(self._jsc, tempFile.name, numSlices)
+ readRDDFromPickleFile = self._jvm.PythonRDD.readRDDFromPickleFile
+ jrdd = readRDDFromPickleFile(self._jsc, tempFile.name, numSlices)
return RDD(jrdd, self)
def textFile(self, name, minSplits=None):
@@ -110,6 +149,10 @@ class SparkContext(object):
jrdd = self._jsc.textFile(name, minSplits)
return RDD(jrdd, self)
+ def _checkpointFile(self, name):
+ jrdd = self._jsc.checkpointFile(name)
+ return RDD(jrdd, self)
+
def union(self, rdds):
"""
Build the union of a list of RDDs.
@@ -129,12 +172,48 @@ class SparkContext(object):
return Broadcast(jbroadcast.id(), value, jbroadcast,
self._pickled_broadcast_vars)
+ def accumulator(self, value, accum_param=None):
+ """
+ Create an L{Accumulator} with the given initial value, using a given
+ L{AccumulatorParam} helper object to define how to add values of the
+ data type if provided. Default AccumulatorParams are used for integers
+ and floating-point numbers if you do not provide one. For other types,
+ a custom AccumulatorParam can be used.
+ """
+ if accum_param == None:
+ if isinstance(value, int):
+ accum_param = accumulators.INT_ACCUMULATOR_PARAM
+ elif isinstance(value, float):
+ accum_param = accumulators.FLOAT_ACCUMULATOR_PARAM
+ elif isinstance(value, complex):
+ accum_param = accumulators.COMPLEX_ACCUMULATOR_PARAM
+ else:
+ raise Exception("No default accumulator param for type %s" % type(value))
+ SparkContext._next_accum_id += 1
+ return Accumulator(SparkContext._next_accum_id - 1, value, accum_param)
+
def addFile(self, path):
"""
- Add a file to be downloaded into the working directory of this Spark
- job on every node. The C{path} passed can be either a local file,
- a file in HDFS (or other Hadoop-supported filesystems), or an HTTP,
- HTTPS or FTP URI.
+ Add a file to be downloaded with this Spark job on every node.
+ The C{path} passed can be either a local file, a file in HDFS
+ (or other Hadoop-supported filesystems), or an HTTP, HTTPS or
+ FTP URI.
+
+ To access the file in Spark jobs, use
+ L{SparkFiles.get(path)<pyspark.files.SparkFiles.get>} to find its
+ download location.
+
+ >>> from pyspark import SparkFiles
+ >>> path = os.path.join(tempdir, "test.txt")
+ >>> with open(path, "w") as testFile:
+ ... testFile.write("100")
+ >>> sc.addFile(path)
+ >>> def func(iterator):
+ ... with open(SparkFiles.get("test.txt")) as testFile:
+ ... fileVal = int(testFile.readline())
+ ... return [x * 100 for x in iterator]
+ >>> sc.parallelize([1, 2, 3, 4]).mapPartitions(func).collect()
+ [100, 200, 300, 400]
"""
self._jsc.sc().addFile(path)
@@ -155,5 +234,31 @@ class SparkContext(object):
"""
self.addFile(path)
filename = path.split("/")[-1]
- os.environ["PYTHONPATH"] = \
- "%s:%s" % (filename, os.environ["PYTHONPATH"])
+
+ def setCheckpointDir(self, dirName, useExisting=False):
+ """
+ Set the directory under which RDDs are going to be checkpointed. The
+ directory must be a HDFS path if running on a cluster.
+
+ If the directory does not exist, it will be created. If the directory
+ exists and C{useExisting} is set to true, then the existing directory
+ will be used. Otherwise an exception will be thrown to prevent
+ accidental overriding of checkpoint files in the existing directory.
+ """
+ self._jsc.sc().setCheckpointDir(dirName, useExisting)
+
+
+def _test():
+ import atexit
+ import doctest
+ import tempfile
+ globs = globals().copy()
+ globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
+ globs['tempdir'] = tempfile.mkdtemp()
+ atexit.register(lambda: shutil.rmtree(globs['tempdir']))
+ doctest.testmod(globs=globs)
+ globs['sc'].stop()
+
+
+if __name__ == "__main__":
+ _test()
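
Note: taken together, these SparkContext changes add accumulators, file shipping via SparkFiles, and checkpoint-directory configuration, and they enforce a single active context per process. The driver-side sketch below is based on the doctests above; the master string, file names, and values are illustrative.

    # Illustrative driver program for the new SparkContext features.
    import os
    import tempfile

    from pyspark import SparkContext, SparkFiles

    sc = SparkContext("local[2]", "NewFeaturesDemo", batchSize=2)

    # Accumulator: tasks add to it, only the driver may read .value.
    counter = sc.accumulator(0)

    def add_to_counter(x):
        global counter
        counter += x

    sc.parallelize(range(10)).foreach(add_to_counter)
    print counter.value                        # 45

    # addFile() ships a file; SparkFiles.get() resolves it inside tasks.
    path = os.path.join(tempfile.mkdtemp(), "factor.txt")
    with open(path, "w") as f:
        f.write("100")
    sc.addFile(path)

    def scale(iterator):
        with open(SparkFiles.get("factor.txt")) as f:
            factor = int(f.readline())
        return [x * factor for x in iterator]

    print sc.parallelize([1, 2, 3]).mapPartitions(scale).collect()

    # Checkpointing requires a checkpoint directory to be set first.
    sc.setCheckpointDir(tempfile.mkdtemp(), useExisting=True)
    rdd = sc.parallelize(range(5)).cache()
    rdd.checkpoint()
    rdd.count()

    sc.stop()
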
diff --git a/python/pyspark/files.py b/python/pyspark/files.py
new file mode 100644
index 0000000000..001b7a28b6
--- /dev/null
+++ b/python/pyspark/files.py
@@ -0,0 +1,38 @@
+import os
+
+
+class SparkFiles(object):
+ """
+ Resolves paths to files added through
+ L{SparkContext.addFile()<pyspark.context.SparkContext.addFile>}.
+
+ SparkFiles contains only classmethods; users should not create SparkFiles
+ instances.
+ """
+
+ _root_directory = None
+ _is_running_on_worker = False
+ _sc = None
+
+ def __init__(self):
+ raise NotImplementedError("Do not construct SparkFiles objects")
+
+ @classmethod
+ def get(cls, filename):
+ """
+ Get the absolute path of a file added through C{SparkContext.addFile()}.
+ """
+ path = os.path.join(SparkFiles.getRootDirectory(), filename)
+ return os.path.abspath(path)
+
+ @classmethod
+ def getRootDirectory(cls):
+ """
+ Get the root directory that contains files added through
+ C{SparkContext.addFile()}.
+ """
+ if cls._is_running_on_worker:
+ return cls._root_directory
+ else:
+ # This will have to change if we support multiple SparkContexts:
+ return cls._sc._jvm.spark.SparkFiles.getRootDirectory()
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 1d36da42b0..41ea6e6e14 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -1,4 +1,3 @@
-import atexit
from base64 import standard_b64encode as b64enc
import copy
from collections import defaultdict
@@ -32,7 +31,9 @@ class RDD(object):
def __init__(self, jrdd, ctx):
self._jrdd = jrdd
self.is_cached = False
+ self.is_checkpointed = False
self.ctx = ctx
+ self._partitionFunc = None
@property
def context(self):
@@ -49,6 +50,34 @@ class RDD(object):
self._jrdd.cache()
return self
+ def checkpoint(self):
+ """
+ Mark this RDD for checkpointing. It will be saved to a file inside the
+ checkpoint directory set with L{SparkContext.setCheckpointDir()} and
+ all references to its parent RDDs will be removed. This function must
+ be called before any job has been executed on this RDD. It is strongly
+ recommended that this RDD is persisted in memory, otherwise saving it
+ on a file will require recomputation.
+ """
+ self.is_checkpointed = True
+ self._jrdd.rdd().checkpoint()
+
+ def isCheckpointed(self):
+ """
+ Return whether this RDD has been checkpointed or not
+ """
+ return self._jrdd.rdd().isCheckpointed()
+
+ def getCheckpointFile(self):
+ """
+ Gets the name of the file to which this RDD was checkpointed
+ """
+ checkpointFile = self._jrdd.rdd().getCheckpointFile()
+ if checkpointFile.isDefined():
+ return checkpointFile.get()
+ else:
+ return None
+
# TODO persist(self, storageLevel)
def map(self, f, preservesPartitioning=False):
@@ -234,12 +263,8 @@ class RDD(object):
# Transferring lots of data through Py4J can be slow because
# socket.readline() is inefficient. Instead, we'll dump the data to a
# file and read it back.
- tempFile = NamedTemporaryFile(delete=False)
+ tempFile = NamedTemporaryFile(delete=False, dir=self.ctx._temp_dir)
tempFile.close()
- def clean_up_file():
- try: os.unlink(tempFile.name)
- except: pass
- atexit.register(clean_up_file)
self.ctx._writeIteratorToPickleFile(iterator, tempFile.name)
# Read the data into Python and deserialize it:
with open(tempFile.name, 'rb') as tempFile:
@@ -377,7 +402,7 @@ class RDD(object):
return (str(x).encode("utf-8") for x in iterator)
keyed = PipelinedRDD(self, func)
keyed._bypass_serializer = True
- keyed._jrdd.map(self.ctx.jvm.BytesToString()).saveAsTextFile(path)
+ keyed._jrdd.map(self.ctx._jvm.BytesToString()).saveAsTextFile(path)
# Pair functions
@@ -497,7 +522,7 @@ class RDD(object):
return python_right_outer_join(self, other, numSplits)
# TODO: add option to control map-side combining
- def partitionBy(self, numSplits, hashFunc=hash):
+ def partitionBy(self, numSplits, partitionFunc=hash):
"""
Return a copy of the RDD partitioned using the specified partitioner.
@@ -514,17 +539,21 @@ class RDD(object):
def add_shuffle_key(split, iterator):
buckets = defaultdict(list)
for (k, v) in iterator:
- buckets[hashFunc(k) % numSplits].append((k, v))
+ buckets[partitionFunc(k) % numSplits].append((k, v))
for (split, items) in buckets.iteritems():
yield str(split)
yield dump_pickle(Batch(items))
keyed = PipelinedRDD(self, add_shuffle_key)
keyed._bypass_serializer = True
- pairRDD = self.ctx.jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
- partitioner = self.ctx.jvm.spark.api.python.PythonPartitioner(numSplits)
- jrdd = pairRDD.partitionBy(partitioner)
- jrdd = jrdd.map(self.ctx.jvm.ExtractValue())
- return RDD(jrdd, self.ctx)
+ pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
+ partitioner = self.ctx._jvm.PythonPartitioner(numSplits,
+ id(partitionFunc))
+ jrdd = pairRDD.partitionBy(partitioner).values()
+ rdd = RDD(jrdd, self.ctx)
+ # This is required so that id(partitionFunc) remains unique, even if
+ # partitionFunc is a lambda:
+ rdd._partitionFunc = partitionFunc
+ return rdd
# TODO: add control over map-side aggregation
def combineByKey(self, createCombiner, mergeValue, mergeCombiners,
@@ -662,7 +691,7 @@ class PipelinedRDD(RDD):
20
"""
def __init__(self, prev, func, preservesPartitioning=False):
- if isinstance(prev, PipelinedRDD) and not prev.is_cached:
+ if isinstance(prev, PipelinedRDD) and prev._is_pipelinable():
prev_func = prev.func
def pipeline_func(split, iterator):
return func(split, prev_func(split, iterator))
@@ -675,6 +704,7 @@ class PipelinedRDD(RDD):
self.preservesPartitioning = preservesPartitioning
self._prev_jrdd = prev._jrdd
self.is_cached = False
+ self.is_checkpointed = False
self.ctx = prev.ctx
self.prev = prev
self._jrdd_val = None
@@ -695,18 +725,21 @@ class PipelinedRDD(RDD):
pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds)
broadcast_vars = ListConverter().convert(
[x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
- self.ctx.gateway._gateway_client)
+ self.ctx._gateway._gateway_client)
self.ctx._pickled_broadcast_vars.clear()
class_manifest = self._prev_jrdd.classManifest()
env = copy.copy(self.ctx.environment)
env['PYTHONPATH'] = os.environ.get("PYTHONPATH", "")
- env = MapConverter().convert(env, self.ctx.gateway._gateway_client)
- python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(),
+ env = MapConverter().convert(env, self.ctx._gateway._gateway_client)
+ python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(),
pipe_command, env, self.preservesPartitioning, self.ctx.pythonExec,
- broadcast_vars, class_manifest)
+ broadcast_vars, self.ctx._javaAccumulator, class_manifest)
self._jrdd_val = python_rdd.asJavaRDD()
return self._jrdd_val
+ def _is_pipelinable(self):
+ return not (self.is_cached or self.is_checkpointed)
+
def _test():
import doctest
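
Note: partitionBy() now passes id(partitionFunc) to the JVM-side PythonPartitioner and pins the function object on the returned RDD so that id stays stable; reusing one function object across RDDs therefore gives them the same partitioner identity. A brief sketch with illustrative keys and data:

    # Illustrative use of partitionBy with a custom partition function.
    # Reusing one function object (rather than fresh lambdas) keeps
    # id(partition_by_region) identical across RDDs, which is what the
    # JVM-side PythonPartitioner is constructed from.
    from pyspark import SparkContext

    def partition_by_region(key):
        # toy scheme: "eu-*" keys to bucket 0, everything else to bucket 1
        return 0 if key.startswith("eu-") else 1

    sc = SparkContext("local[2]", "PartitionByDemo")

    orders = sc.parallelize([("eu-1", 10), ("us-1", 20), ("eu-2", 5)])
    users = sc.parallelize([("eu-1", "alice"), ("us-1", "bob")])

    orders_p = orders.partitionBy(2, partitionFunc=partition_by_region)
    users_p = users.partitionBy(2, partitionFunc=partition_by_region)

    print orders_p.join(users_p).collect()

    sc.stop()
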
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 9a5151ea00..115cf28cc2 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -52,8 +52,13 @@ def read_int(stream):
raise EOFError
return struct.unpack("!i", length)[0]
+
+def write_int(value, stream):
+ stream.write(struct.pack("!i", value))
+
+
def write_with_length(obj, stream):
- stream.write(struct.pack("!i", len(obj)))
+ write_int(len(obj), stream)
stream.write(obj)
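
Note: write_int() factors the 4-byte big-endian header out of write_with_length(), so callers such as worker.py below can also emit bare sentinel integers on the same stream. A small round-trip sketch of the framing against an in-memory buffer, assuming pyspark is importable:

    # Round trip of the length-prefixed framing shared by driver and worker.
    from cStringIO import StringIO

    from pyspark.serializers import read_int, write_int, write_with_length

    buf = StringIO()
    write_int(-1, buf)                       # a bare sentinel, no payload
    write_with_length("hello world", buf)    # 4-byte length, then the bytes

    buf.seek(0)
    assert read_int(buf) == -1
    length = read_int(buf)
    assert buf.read(length) == "hello world"
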
diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py
index 7e6ad3aa76..54ff1bf8e7 100644
--- a/python/pyspark/shell.py
+++ b/python/pyspark/shell.py
@@ -1,9 +1,10 @@
"""
An interactive shell.
-This fle is designed to be launched as a PYTHONSTARTUP script.
+This file is designed to be launched as a PYTHONSTARTUP script.
"""
import os
+import pyspark
from pyspark.context import SparkContext
@@ -14,4 +15,4 @@ print "Spark context avaiable as sc."
# which allows us to execute the user's PYTHONSTARTUP file:
_pythonstartup = os.environ.get('OLD_PYTHONSTARTUP')
if _pythonstartup and os.path.isfile(_pythonstartup):
- execfile(_pythonstartup)
+ execfile(_pythonstartup)
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
new file mode 100644
index 0000000000..6a1962d267
--- /dev/null
+++ b/python/pyspark/tests.py
@@ -0,0 +1,121 @@
+"""
+Unit tests for PySpark; additional tests are implemented as doctests in
+individual modules.
+"""
+import os
+import shutil
+import sys
+from tempfile import NamedTemporaryFile
+import time
+import unittest
+
+from pyspark.context import SparkContext
+from pyspark.files import SparkFiles
+from pyspark.java_gateway import SPARK_HOME
+
+
+class PySparkTestCase(unittest.TestCase):
+
+ def setUp(self):
+ self._old_sys_path = list(sys.path)
+ class_name = self.__class__.__name__
+ self.sc = SparkContext('local[4]', class_name, batchSize=2)
+
+ def tearDown(self):
+ self.sc.stop()
+ sys.path = self._old_sys_path
+ # To avoid Akka rebinding to the same port, since it doesn't unbind
+ # immediately on shutdown
+ self.sc._jvm.System.clearProperty("spark.driver.port")
+
+
+class TestCheckpoint(PySparkTestCase):
+
+ def setUp(self):
+ PySparkTestCase.setUp(self)
+ self.checkpointDir = NamedTemporaryFile(delete=False)
+ os.unlink(self.checkpointDir.name)
+ self.sc.setCheckpointDir(self.checkpointDir.name)
+
+ def tearDown(self):
+ PySparkTestCase.tearDown(self)
+ shutil.rmtree(self.checkpointDir.name)
+
+ def test_basic_checkpointing(self):
+ parCollection = self.sc.parallelize([1, 2, 3, 4])
+ flatMappedRDD = parCollection.flatMap(lambda x: range(1, x + 1))
+
+ self.assertFalse(flatMappedRDD.isCheckpointed())
+ self.assertIsNone(flatMappedRDD.getCheckpointFile())
+
+ flatMappedRDD.checkpoint()
+ result = flatMappedRDD.collect()
+ time.sleep(1) # 1 second
+ self.assertTrue(flatMappedRDD.isCheckpointed())
+ self.assertEqual(flatMappedRDD.collect(), result)
+ self.assertEqual(self.checkpointDir.name,
+ os.path.dirname(flatMappedRDD.getCheckpointFile()))
+
+ def test_checkpoint_and_restore(self):
+ parCollection = self.sc.parallelize([1, 2, 3, 4])
+ flatMappedRDD = parCollection.flatMap(lambda x: [x])
+
+ self.assertFalse(flatMappedRDD.isCheckpointed())
+ self.assertIsNone(flatMappedRDD.getCheckpointFile())
+
+ flatMappedRDD.checkpoint()
+ flatMappedRDD.count() # forces a checkpoint to be computed
+ time.sleep(1) # 1 second
+
+ self.assertIsNotNone(flatMappedRDD.getCheckpointFile())
+ recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile())
+ self.assertEquals([1, 2, 3, 4], recovered.collect())
+
+
+class TestAddFile(PySparkTestCase):
+
+ def test_add_py_file(self):
+ # To ensure that we're actually testing addPyFile's effects, check that
+ # this job fails due to `userlibrary` not being on the Python path:
+ def func(x):
+ from userlibrary import UserClass
+ return UserClass().hello()
+ self.assertRaises(Exception,
+ self.sc.parallelize(range(2)).map(func).first)
+ # Add the file, so the job should now succeed:
+ path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py")
+ self.sc.addPyFile(path)
+ res = self.sc.parallelize(range(2)).map(func).first()
+ self.assertEqual("Hello World!", res)
+
+ def test_add_file_locally(self):
+ path = os.path.join(SPARK_HOME, "python/test_support/hello.txt")
+ self.sc.addFile(path)
+ download_path = SparkFiles.get("hello.txt")
+ self.assertNotEqual(path, download_path)
+ with open(download_path) as test_file:
+ self.assertEquals("Hello World!\n", test_file.readline())
+
+ def test_add_py_file_locally(self):
+ # To ensure that we're actually testing addPyFile's effects, check that
+ # this fails due to `userlibrary` not being on the Python path:
+ def func():
+ from userlibrary import UserClass
+ self.assertRaises(ImportError, func)
+ path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py")
+ self.sc.addFile(path)
+ from userlibrary import UserClass
+ self.assertEqual("Hello World!", UserClass().hello())
+
+
+class TestIO(PySparkTestCase):
+
+ def test_stdout_redirection(self):
+ import subprocess
+ def func(x):
+ subprocess.check_call('ls', shell=True)
+ self.sc.parallelize([1]).foreach(func)
+
+
+if __name__ == "__main__":
+ unittest.main()
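
Note: PySparkTestCase gives each test class a local SparkContext and restores sys.path afterwards, so new suites only need to subclass it. A hypothetical test following that pattern:

    # Hypothetical test case reusing the PySparkTestCase fixture above.
    import unittest

    from pyspark.tests import PySparkTestCase

    class TestWordCount(PySparkTestCase):

        def test_count(self):
            rdd = self.sc.parallelize(["a", "b", "a"])
            counts = dict(rdd.map(lambda w: (w, 1))
                             .reduceByKey(lambda x, y: x + y)
                             .collect())
            self.assertEqual({"a": 2, "b": 1}, counts)

    if __name__ == "__main__":
        unittest.main()
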
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 3d792bbaa2..812e7a9da5 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -1,19 +1,23 @@
"""
Worker that receives input from Piped RDD.
"""
+import os
import sys
+import traceback
from base64 import standard_b64decode
# CloudPickler needs to be imported so that depicklers are registered using the
# copy_reg module.
+from pyspark.accumulators import _accumulatorRegistry
from pyspark.broadcast import Broadcast, _broadcastRegistry
from pyspark.cloudpickle import CloudPickler
-from pyspark.serializers import write_with_length, read_with_length, \
+from pyspark.files import SparkFiles
+from pyspark.serializers import write_with_length, read_with_length, write_int, \
read_long, read_int, dump_pickle, load_pickle, read_from_pickle_file
# Redirect stdout to stderr so that users must return values from functions.
-old_stdout = sys.stdout
-sys.stdout = sys.stderr
+old_stdout = os.fdopen(os.dup(1), 'w')
+os.dup2(2, 1)
def load_obj():
@@ -22,6 +26,10 @@ def load_obj():
def main():
split_index = read_int(sys.stdin)
+ spark_files_dir = load_pickle(read_with_length(sys.stdin))
+ SparkFiles._root_directory = spark_files_dir
+ SparkFiles._is_running_on_worker = True
+ sys.path.append(spark_files_dir)
num_broadcast_variables = read_int(sys.stdin)
for _ in range(num_broadcast_variables):
bid = read_long(sys.stdin)
@@ -34,8 +42,17 @@ def main():
else:
dumps = dump_pickle
iterator = read_from_pickle_file(sys.stdin)
- for obj in func(split_index, iterator):
- write_with_length(dumps(obj), old_stdout)
+ try:
+ for obj in func(split_index, iterator):
+ write_with_length(dumps(obj), old_stdout)
+ except Exception as e:
+ write_int(-2, old_stdout)
+ write_with_length(traceback.format_exc(), old_stdout)
+ sys.exit(-1)
+ # Mark the beginning of the accumulators section of the output
+ write_int(-1, old_stdout)
+ for aid, accum in _accumulatorRegistry.items():
+ write_with_length(dump_pickle((aid, accum._value)), old_stdout)
if __name__ == '__main__':
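
Note: the worker's output stream now carries three kinds of frames: length-prefixed result objects, a -2 sentinel followed by a length-prefixed traceback on failure, and a -1 sentinel marking the start of the accumulator updates. The JVM side of PythonRDD is the real consumer; the sketch below is only an illustrative Python rendering of that framing, and read_worker_output is a made-up name.

    # Sketch of a reader for the worker output framing (illustrative only;
    # in Spark the JVM side of PythonRDD is what consumes this stream).
    from pyspark.serializers import read_int, load_pickle

    def read_worker_output(stream):
        results, accum_updates = [], []
        while True:
            length = read_int(stream)
            if length == -2:                        # worker raised an exception
                raise RuntimeError(stream.read(read_int(stream)))
            if length == -1:                        # accumulator section begins
                break
            results.append(stream.read(length))     # one serialized result
        # Remaining frames are pickled (accumulator_id, value) pairs.
        while True:
            try:
                length = read_int(stream)
            except EOFError:
                break
            accum_updates.append(load_pickle(stream.read(length)))
        return results, accum_updates
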
diff --git a/python/run-tests b/python/run-tests
index fcdd1e27a7..a3a9ff5dcb 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -8,9 +8,18 @@ FAILED=0
$FWDIR/pyspark pyspark/rdd.py
FAILED=$(($?||$FAILED))
+$FWDIR/pyspark pyspark/context.py
+FAILED=$(($?||$FAILED))
+
$FWDIR/pyspark -m doctest pyspark/broadcast.py
FAILED=$(($?||$FAILED))
+$FWDIR/pyspark -m doctest pyspark/accumulators.py
+FAILED=$(($?||$FAILED))
+
+$FWDIR/pyspark -m unittest pyspark.tests
+FAILED=$(($?||$FAILED))
+
if [[ $FAILED != 0 ]]; then
echo -en "\033[31m" # Red
echo "Had test failures; see logs."
diff --git a/python/test_support/hello.txt b/python/test_support/hello.txt
new file mode 100755
index 0000000000..980a0d5f19
--- /dev/null
+++ b/python/test_support/hello.txt
@@ -0,0 +1 @@
+Hello World!
diff --git a/python/test_support/userlibrary.py b/python/test_support/userlibrary.py
new file mode 100755
index 0000000000..5bb6f5009f
--- /dev/null
+++ b/python/test_support/userlibrary.py
@@ -0,0 +1,7 @@
+"""
+Used to test shipping of code dependencies with SparkContext.addPyFile().
+"""
+
+class UserClass(object):
+ def hello(self):
+ return "Hello World!"