diff options
author | Stephen Haberman <stephen@exigencecorp.com> | 2013-02-02 01:57:18 -0600 |
---|---|---|
committer | Stephen Haberman <stephen@exigencecorp.com> | 2013-02-02 01:57:18 -0600 |
commit | 103c375ba044b4fb1061298d6375587ed30832a4 (patch) | |
tree | 45b1a45260b0d266965caed1f4612f0b2eb49246 /python | |
parent | fdec42385a1a8f10f9dd803525cb3c132a25ba53 (diff) | |
parent | ae26911ec0d768dcdae8b7d706ca4544e36535e6 (diff) | |
download | spark-103c375ba044b4fb1061298d6375587ed30832a4.tar.gz spark-103c375ba044b4fb1061298d6375587ed30832a4.tar.bz2 spark-103c375ba044b4fb1061298d6375587ed30832a4.zip |
Merge branch 'master' into sparkmem
Diffstat (limited to 'python')
-rw-r--r-- | python/epydoc.conf | 2 | ||||
-rwxr-xr-x | python/examples/als.py | 71 | ||||
-rw-r--r-- | python/pyspark/__init__.py | 9 | ||||
-rw-r--r-- | python/pyspark/accumulators.py | 207 | ||||
-rw-r--r-- | python/pyspark/context.py | 147 | ||||
-rw-r--r-- | python/pyspark/files.py | 38 | ||||
-rw-r--r-- | python/pyspark/rdd.py | 71 | ||||
-rw-r--r-- | python/pyspark/serializers.py | 7 | ||||
-rw-r--r-- | python/pyspark/shell.py | 5 | ||||
-rw-r--r-- | python/pyspark/tests.py | 121 | ||||
-rw-r--r-- | python/pyspark/worker.py | 27 | ||||
-rwxr-xr-x | python/run-tests | 9 | ||||
-rwxr-xr-x | python/test_support/hello.txt | 1 | ||||
-rwxr-xr-x | python/test_support/userlibrary.py | 7 |
14 files changed, 672 insertions, 50 deletions
diff --git a/python/epydoc.conf b/python/epydoc.conf index 91ac984ba2..45102cd9fe 100644 --- a/python/epydoc.conf +++ b/python/epydoc.conf @@ -16,4 +16,4 @@ target: docs/ private: no exclude: pyspark.cloudpickle pyspark.worker pyspark.join pyspark.serializers - pyspark.java_gateway pyspark.examples pyspark.shell + pyspark.java_gateway pyspark.examples pyspark.shell pyspark.test diff --git a/python/examples/als.py b/python/examples/als.py new file mode 100755 index 0000000000..010f80097f --- /dev/null +++ b/python/examples/als.py @@ -0,0 +1,71 @@ +""" +This example requires numpy (http://www.numpy.org/) +""" +from os.path import realpath +import sys + +import numpy as np +from numpy.random import rand +from numpy import matrix +from pyspark import SparkContext + +LAMBDA = 0.01 # regularization +np.random.seed(42) + +def rmse(R, ms, us): + diff = R - ms * us.T + return np.sqrt(np.sum(np.power(diff, 2)) / M * U) + +def update(i, vec, mat, ratings): + uu = mat.shape[0] + ff = mat.shape[1] + XtX = matrix(np.zeros((ff, ff))) + Xty = np.zeros((ff, 1)) + + for j in range(uu): + v = mat[j, :] + XtX += v.T * v + Xty += v.T * ratings[i, j] + XtX += np.eye(ff, ff) * LAMBDA * uu + return np.linalg.solve(XtX, Xty) + +if __name__ == "__main__": + if len(sys.argv) < 2: + print >> sys.stderr, \ + "Usage: PythonALS <master> <M> <U> <F> <iters> <slices>" + exit(-1) + sc = SparkContext(sys.argv[1], "PythonALS", pyFiles=[realpath(__file__)]) + M = int(sys.argv[2]) if len(sys.argv) > 2 else 100 + U = int(sys.argv[3]) if len(sys.argv) > 3 else 500 + F = int(sys.argv[4]) if len(sys.argv) > 4 else 10 + ITERATIONS = int(sys.argv[5]) if len(sys.argv) > 5 else 5 + slices = int(sys.argv[6]) if len(sys.argv) > 6 else 2 + + print "Running ALS with M=%d, U=%d, F=%d, iters=%d, slices=%d\n" % \ + (M, U, F, ITERATIONS, slices) + + R = matrix(rand(M, F)) * matrix(rand(U, F).T) + ms = matrix(rand(M ,F)) + us = matrix(rand(U, F)) + + Rb = sc.broadcast(R) + msb = sc.broadcast(ms) + usb = sc.broadcast(us) + + for i in range(ITERATIONS): + ms = sc.parallelize(range(M), slices) \ + .map(lambda x: update(x, msb.value[x, :], usb.value, Rb.value)) \ + .collect() + ms = matrix(np.array(ms)[:, :, 0]) # collect() returns a list, so array ends up being + # a 3-d array, we take the first 2 dims for the matrix + msb = sc.broadcast(ms) + + us = sc.parallelize(range(U), slices) \ + .map(lambda x: update(x, usb.value[x, :], msb.value, Rb.value.T)) \ + .collect() + us = matrix(np.array(us)[:, :, 0]) + usb = sc.broadcast(us) + + error = rmse(R, ms, us) + print "Iteration %d:" % i + print "\nRMSE: %5.4f\n" % error diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index c595ae0842..3e8bca62f0 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -7,6 +7,12 @@ Public classes: Main entry point for Spark functionality. - L{RDD<pyspark.rdd.RDD>} A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. + - L{Broadcast<pyspark.broadcast.Broadcast>} + A broadcast variable that gets reused across tasks. + - L{Accumulator<pyspark.accumulators.Accumulator>} + An "add-only" shared variable that tasks can only add values to. + - L{SparkFiles<pyspark.files.SparkFiles>} + Access files shipped with jobs. """ import sys import os @@ -15,6 +21,7 @@ sys.path.insert(0, os.path.join(os.environ["SPARK_HOME"], "python/lib/py4j0.7.eg from pyspark.context import SparkContext from pyspark.rdd import RDD +from pyspark.files import SparkFiles -__all__ = ["SparkContext", "RDD"] +__all__ = ["SparkContext", "RDD", "SparkFiles"] diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py new file mode 100644 index 0000000000..61fcbbd376 --- /dev/null +++ b/python/pyspark/accumulators.py @@ -0,0 +1,207 @@ +""" +>>> from pyspark.context import SparkContext +>>> sc = SparkContext('local', 'test') +>>> a = sc.accumulator(1) +>>> a.value +1 +>>> a.value = 2 +>>> a.value +2 +>>> a += 5 +>>> a.value +7 + +>>> sc.accumulator(1.0).value +1.0 + +>>> sc.accumulator(1j).value +1j + +>>> rdd = sc.parallelize([1,2,3]) +>>> def f(x): +... global a +... a += x +>>> rdd.foreach(f) +>>> a.value +13 + +>>> from pyspark.accumulators import AccumulatorParam +>>> class VectorAccumulatorParam(AccumulatorParam): +... def zero(self, value): +... return [0.0] * len(value) +... def addInPlace(self, val1, val2): +... for i in xrange(len(val1)): +... val1[i] += val2[i] +... return val1 +>>> va = sc.accumulator([1.0, 2.0, 3.0], VectorAccumulatorParam()) +>>> va.value +[1.0, 2.0, 3.0] +>>> def g(x): +... global va +... va += [x] * 3 +>>> rdd.foreach(g) +>>> va.value +[7.0, 8.0, 9.0] + +>>> rdd.map(lambda x: a.value).collect() # doctest: +IGNORE_EXCEPTION_DETAIL +Traceback (most recent call last): + ... +Py4JJavaError:... + +>>> def h(x): +... global a +... a.value = 7 +>>> rdd.foreach(h) # doctest: +IGNORE_EXCEPTION_DETAIL +Traceback (most recent call last): + ... +Py4JJavaError:... + +>>> sc.accumulator([1.0, 2.0, 3.0]) # doctest: +IGNORE_EXCEPTION_DETAIL +Traceback (most recent call last): + ... +Exception:... +""" + +import struct +import SocketServer +import threading +from pyspark.cloudpickle import CloudPickler +from pyspark.serializers import read_int, read_with_length, load_pickle + + +# Holds accumulators registered on the current machine, keyed by ID. This is then used to send +# the local accumulator updates back to the driver program at the end of a task. +_accumulatorRegistry = {} + + +def _deserialize_accumulator(aid, zero_value, accum_param): + from pyspark.accumulators import _accumulatorRegistry + accum = Accumulator(aid, zero_value, accum_param) + accum._deserialized = True + _accumulatorRegistry[aid] = accum + return accum + + +class Accumulator(object): + """ + A shared variable that can be accumulated, i.e., has a commutative and associative "add" + operation. Worker tasks on a Spark cluster can add values to an Accumulator with the C{+=} + operator, but only the driver program is allowed to access its value, using C{value}. + Updates from the workers get propagated automatically to the driver program. + + While C{SparkContext} supports accumulators for primitive data types like C{int} and + C{float}, users can also define accumulators for custom types by providing a custom + L{AccumulatorParam} object. Refer to the doctest of this module for an example. + """ + + def __init__(self, aid, value, accum_param): + """Create a new Accumulator with a given initial value and AccumulatorParam object""" + from pyspark.accumulators import _accumulatorRegistry + self.aid = aid + self.accum_param = accum_param + self._value = value + self._deserialized = False + _accumulatorRegistry[aid] = self + + def __reduce__(self): + """Custom serialization; saves the zero value from our AccumulatorParam""" + param = self.accum_param + return (_deserialize_accumulator, (self.aid, param.zero(self._value), param)) + + @property + def value(self): + """Get the accumulator's value; only usable in driver program""" + if self._deserialized: + raise Exception("Accumulator.value cannot be accessed inside tasks") + return self._value + + @value.setter + def value(self, value): + """Sets the accumulator's value; only usable in driver program""" + if self._deserialized: + raise Exception("Accumulator.value cannot be accessed inside tasks") + self._value = value + + def __iadd__(self, term): + """The += operator; adds a term to this accumulator's value""" + self._value = self.accum_param.addInPlace(self._value, term) + return self + + def __str__(self): + return str(self._value) + + def __repr__(self): + return "Accumulator<id=%i, value=%s>" % (self.aid, self._value) + + +class AccumulatorParam(object): + """ + Helper object that defines how to accumulate values of a given type. + """ + + def zero(self, value): + """ + Provide a "zero value" for the type, compatible in dimensions with the + provided C{value} (e.g., a zero vector) + """ + raise NotImplementedError + + def addInPlace(self, value1, value2): + """ + Add two values of the accumulator's data type, returning a new value; + for efficiency, can also update C{value1} in place and return it. + """ + raise NotImplementedError + + +class AddingAccumulatorParam(AccumulatorParam): + """ + An AccumulatorParam that uses the + operators to add values. Designed for simple types + such as integers, floats, and lists. Requires the zero value for the underlying type + as a parameter. + """ + + def __init__(self, zero_value): + self.zero_value = zero_value + + def zero(self, value): + return self.zero_value + + def addInPlace(self, value1, value2): + value1 += value2 + return value1 + + +# Singleton accumulator params for some standard types +INT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0) +FLOAT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0) +COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j) + + +class _UpdateRequestHandler(SocketServer.StreamRequestHandler): + def handle(self): + from pyspark.accumulators import _accumulatorRegistry + num_updates = read_int(self.rfile) + for _ in range(num_updates): + (aid, update) = load_pickle(read_with_length(self.rfile)) + _accumulatorRegistry[aid] += update + # Write a byte in acknowledgement + self.wfile.write(struct.pack("!b", 1)) + + +def _start_update_server(): + """Start a TCP server to receive accumulator updates in a daemon thread, and returns it""" + server = SocketServer.TCPServer(("localhost", 0), _UpdateRequestHandler) + thread = threading.Thread(target=server.serve_forever) + thread.daemon = True + thread.start() + return server + + +def _test(): + import doctest + doctest.testmod() + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/context.py b/python/pyspark/context.py index e486f206b0..6831f9b7f8 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -1,8 +1,13 @@ import os -import atexit +import shutil +import sys +from threading import Lock from tempfile import NamedTemporaryFile +from pyspark import accumulators +from pyspark.accumulators import Accumulator from pyspark.broadcast import Broadcast +from pyspark.files import SparkFiles from pyspark.java_gateway import launch_gateway from pyspark.serializers import dump_pickle, write_with_length, batched from pyspark.rdd import RDD @@ -17,11 +22,13 @@ class SparkContext(object): broadcast variables on that cluster. """ - gateway = launch_gateway() - jvm = gateway.jvm - _readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile - _writeIteratorToPickleFile = jvm.PythonRDD.writeIteratorToPickleFile - _takePartition = jvm.PythonRDD.takePartition + _gateway = None + _jvm = None + _writeIteratorToPickleFile = None + _takePartition = None + _next_accum_id = 0 + _active_spark_context = None + _lock = Lock() def __init__(self, master, jobName, sparkHome=None, pyFiles=None, environment=None, batchSize=1024): @@ -41,6 +48,18 @@ class SparkContext(object): Java object. Set 1 to disable batching or -1 to use an unlimited batch size. """ + with SparkContext._lock: + if SparkContext._active_spark_context: + raise ValueError("Cannot run multiple SparkContexts at once") + else: + SparkContext._active_spark_context = self + if not SparkContext._gateway: + SparkContext._gateway = launch_gateway() + SparkContext._jvm = SparkContext._gateway.jvm + SparkContext._writeIteratorToPickleFile = \ + SparkContext._jvm.PythonRDD.writeIteratorToPickleFile + SparkContext._takePartition = \ + SparkContext._jvm.PythonRDD.takePartition self.master = master self.jobName = jobName self.sparkHome = sparkHome or None # None becomes null in Py4J @@ -48,10 +67,18 @@ class SparkContext(object): self.batchSize = batchSize # -1 represents a unlimited batch size # Create the Java SparkContext through Py4J - empty_string_array = self.gateway.new_array(self.jvm.String, 0) - self._jsc = self.jvm.JavaSparkContext(master, jobName, sparkHome, + empty_string_array = self._gateway.new_array(self._jvm.String, 0) + self._jsc = self._jvm.JavaSparkContext(master, jobName, sparkHome, empty_string_array) + # Create a single Accumulator in Java that we'll send all our updates through; + # they will be passed back to us through a TCP server + self._accumulatorServer = accumulators._start_update_server() + (host, port) = self._accumulatorServer.server_address + self._javaAccumulator = self._jsc.accumulator( + self._jvm.java.util.ArrayList(), + self._jvm.PythonAccumulatorParam(host, port)) + self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python') # Broadcast's __reduce__ method stores Broadcast instances here. # This allows other code to determine which Broadcast instances have @@ -62,6 +89,13 @@ class SparkContext(object): # Deploy any code dependencies specified in the constructor for path in (pyFiles or []): self.addPyFile(path) + SparkFiles._sc = self + sys.path.append(SparkFiles.getRootDirectory()) + + # Create a temporary directory inside spark.local.dir: + local_dir = self._jvm.spark.Utils.getLocalDir() + self._temp_dir = \ + self._jvm.spark.Utils.createTempDir(local_dir).getAbsolutePath() @property def defaultParallelism(self): @@ -72,15 +106,20 @@ class SparkContext(object): return self._jsc.sc().defaultParallelism() def __del__(self): - if self._jsc: - self._jsc.stop() + self.stop() def stop(self): """ Shut down the SparkContext. """ - self._jsc.stop() - self._jsc = None + if self._jsc: + self._jsc.stop() + self._jsc = None + if self._accumulatorServer: + self._accumulatorServer.shutdown() + self._accumulatorServer = None + with SparkContext._lock: + SparkContext._active_spark_context = None def parallelize(self, c, numSlices=None): """ @@ -90,14 +129,14 @@ class SparkContext(object): # Calling the Java parallelize() method with an ArrayList is too slow, # because it sends O(n) Py4J commands. As an alternative, serialized # objects are written to a file and loaded through textFile(). - tempFile = NamedTemporaryFile(delete=False) - atexit.register(lambda: os.unlink(tempFile.name)) + tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir) if self.batchSize != 1: c = batched(c, self.batchSize) for x in c: write_with_length(dump_pickle(x), tempFile) tempFile.close() - jrdd = self._readRDDFromPickleFile(self._jsc, tempFile.name, numSlices) + readRDDFromPickleFile = self._jvm.PythonRDD.readRDDFromPickleFile + jrdd = readRDDFromPickleFile(self._jsc, tempFile.name, numSlices) return RDD(jrdd, self) def textFile(self, name, minSplits=None): @@ -110,6 +149,10 @@ class SparkContext(object): jrdd = self._jsc.textFile(name, minSplits) return RDD(jrdd, self) + def _checkpointFile(self, name): + jrdd = self._jsc.checkpointFile(name) + return RDD(jrdd, self) + def union(self, rdds): """ Build the union of a list of RDDs. @@ -129,12 +172,48 @@ class SparkContext(object): return Broadcast(jbroadcast.id(), value, jbroadcast, self._pickled_broadcast_vars) + def accumulator(self, value, accum_param=None): + """ + Create an L{Accumulator} with the given initial value, using a given + L{AccumulatorParam} helper object to define how to add values of the + data type if provided. Default AccumulatorParams are used for integers + and floating-point numbers if you do not provide one. For other types, + a custom AccumulatorParam can be used. + """ + if accum_param == None: + if isinstance(value, int): + accum_param = accumulators.INT_ACCUMULATOR_PARAM + elif isinstance(value, float): + accum_param = accumulators.FLOAT_ACCUMULATOR_PARAM + elif isinstance(value, complex): + accum_param = accumulators.COMPLEX_ACCUMULATOR_PARAM + else: + raise Exception("No default accumulator param for type %s" % type(value)) + SparkContext._next_accum_id += 1 + return Accumulator(SparkContext._next_accum_id - 1, value, accum_param) + def addFile(self, path): """ - Add a file to be downloaded into the working directory of this Spark - job on every node. The C{path} passed can be either a local file, - a file in HDFS (or other Hadoop-supported filesystems), or an HTTP, - HTTPS or FTP URI. + Add a file to be downloaded with this Spark job on every node. + The C{path} passed can be either a local file, a file in HDFS + (or other Hadoop-supported filesystems), or an HTTP, HTTPS or + FTP URI. + + To access the file in Spark jobs, use + L{SparkFiles.get(path)<pyspark.files.SparkFiles.get>} to find its + download location. + + >>> from pyspark import SparkFiles + >>> path = os.path.join(tempdir, "test.txt") + >>> with open(path, "w") as testFile: + ... testFile.write("100") + >>> sc.addFile(path) + >>> def func(iterator): + ... with open(SparkFiles.get("test.txt")) as testFile: + ... fileVal = int(testFile.readline()) + ... return [x * 100 for x in iterator] + >>> sc.parallelize([1, 2, 3, 4]).mapPartitions(func).collect() + [100, 200, 300, 400] """ self._jsc.sc().addFile(path) @@ -155,5 +234,31 @@ class SparkContext(object): """ self.addFile(path) filename = path.split("/")[-1] - os.environ["PYTHONPATH"] = \ - "%s:%s" % (filename, os.environ["PYTHONPATH"]) + + def setCheckpointDir(self, dirName, useExisting=False): + """ + Set the directory under which RDDs are going to be checkpointed. The + directory must be a HDFS path if running on a cluster. + + If the directory does not exist, it will be created. If the directory + exists and C{useExisting} is set to true, then the exisiting directory + will be used. Otherwise an exception will be thrown to prevent + accidental overriding of checkpoint files in the existing directory. + """ + self._jsc.sc().setCheckpointDir(dirName, useExisting) + + +def _test(): + import atexit + import doctest + import tempfile + globs = globals().copy() + globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) + globs['tempdir'] = tempfile.mkdtemp() + atexit.register(lambda: shutil.rmtree(globs['tempdir'])) + doctest.testmod(globs=globs) + globs['sc'].stop() + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/files.py b/python/pyspark/files.py new file mode 100644 index 0000000000..001b7a28b6 --- /dev/null +++ b/python/pyspark/files.py @@ -0,0 +1,38 @@ +import os + + +class SparkFiles(object): + """ + Resolves paths to files added through + L{SparkContext.addFile()<pyspark.context.SparkContext.addFile>}. + + SparkFiles contains only classmethods; users should not create SparkFiles + instances. + """ + + _root_directory = None + _is_running_on_worker = False + _sc = None + + def __init__(self): + raise NotImplementedError("Do not construct SparkFiles objects") + + @classmethod + def get(cls, filename): + """ + Get the absolute path of a file added through C{SparkContext.addFile()}. + """ + path = os.path.join(SparkFiles.getRootDirectory(), filename) + return os.path.abspath(path) + + @classmethod + def getRootDirectory(cls): + """ + Get the root directory that contains files added through + C{SparkContext.addFile()}. + """ + if cls._is_running_on_worker: + return cls._root_directory + else: + # This will have to change if we support multiple SparkContexts: + return cls._sc._jvm.spark.SparkFiles.getRootDirectory() diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 1d36da42b0..41ea6e6e14 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -1,4 +1,3 @@ -import atexit from base64 import standard_b64encode as b64enc import copy from collections import defaultdict @@ -32,7 +31,9 @@ class RDD(object): def __init__(self, jrdd, ctx): self._jrdd = jrdd self.is_cached = False + self.is_checkpointed = False self.ctx = ctx + self._partitionFunc = None @property def context(self): @@ -49,6 +50,34 @@ class RDD(object): self._jrdd.cache() return self + def checkpoint(self): + """ + Mark this RDD for checkpointing. It will be saved to a file inside the + checkpoint directory set with L{SparkContext.setCheckpointDir()} and + all references to its parent RDDs will be removed. This function must + be called before any job has been executed on this RDD. It is strongly + recommended that this RDD is persisted in memory, otherwise saving it + on a file will require recomputation. + """ + self.is_checkpointed = True + self._jrdd.rdd().checkpoint() + + def isCheckpointed(self): + """ + Return whether this RDD has been checkpointed or not + """ + return self._jrdd.rdd().isCheckpointed() + + def getCheckpointFile(self): + """ + Gets the name of the file to which this RDD was checkpointed + """ + checkpointFile = self._jrdd.rdd().getCheckpointFile() + if checkpointFile.isDefined(): + return checkpointFile.get() + else: + return None + # TODO persist(self, storageLevel) def map(self, f, preservesPartitioning=False): @@ -234,12 +263,8 @@ class RDD(object): # Transferring lots of data through Py4J can be slow because # socket.readline() is inefficient. Instead, we'll dump the data to a # file and read it back. - tempFile = NamedTemporaryFile(delete=False) + tempFile = NamedTemporaryFile(delete=False, dir=self.ctx._temp_dir) tempFile.close() - def clean_up_file(): - try: os.unlink(tempFile.name) - except: pass - atexit.register(clean_up_file) self.ctx._writeIteratorToPickleFile(iterator, tempFile.name) # Read the data into Python and deserialize it: with open(tempFile.name, 'rb') as tempFile: @@ -377,7 +402,7 @@ class RDD(object): return (str(x).encode("utf-8") for x in iterator) keyed = PipelinedRDD(self, func) keyed._bypass_serializer = True - keyed._jrdd.map(self.ctx.jvm.BytesToString()).saveAsTextFile(path) + keyed._jrdd.map(self.ctx._jvm.BytesToString()).saveAsTextFile(path) # Pair functions @@ -497,7 +522,7 @@ class RDD(object): return python_right_outer_join(self, other, numSplits) # TODO: add option to control map-side combining - def partitionBy(self, numSplits, hashFunc=hash): + def partitionBy(self, numSplits, partitionFunc=hash): """ Return a copy of the RDD partitioned using the specified partitioner. @@ -514,17 +539,21 @@ class RDD(object): def add_shuffle_key(split, iterator): buckets = defaultdict(list) for (k, v) in iterator: - buckets[hashFunc(k) % numSplits].append((k, v)) + buckets[partitionFunc(k) % numSplits].append((k, v)) for (split, items) in buckets.iteritems(): yield str(split) yield dump_pickle(Batch(items)) keyed = PipelinedRDD(self, add_shuffle_key) keyed._bypass_serializer = True - pairRDD = self.ctx.jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD() - partitioner = self.ctx.jvm.spark.api.python.PythonPartitioner(numSplits) - jrdd = pairRDD.partitionBy(partitioner) - jrdd = jrdd.map(self.ctx.jvm.ExtractValue()) - return RDD(jrdd, self.ctx) + pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD() + partitioner = self.ctx._jvm.PythonPartitioner(numSplits, + id(partitionFunc)) + jrdd = pairRDD.partitionBy(partitioner).values() + rdd = RDD(jrdd, self.ctx) + # This is required so that id(partitionFunc) remains unique, even if + # partitionFunc is a lambda: + rdd._partitionFunc = partitionFunc + return rdd # TODO: add control over map-side aggregation def combineByKey(self, createCombiner, mergeValue, mergeCombiners, @@ -662,7 +691,7 @@ class PipelinedRDD(RDD): 20 """ def __init__(self, prev, func, preservesPartitioning=False): - if isinstance(prev, PipelinedRDD) and not prev.is_cached: + if isinstance(prev, PipelinedRDD) and prev._is_pipelinable(): prev_func = prev.func def pipeline_func(split, iterator): return func(split, prev_func(split, iterator)) @@ -675,6 +704,7 @@ class PipelinedRDD(RDD): self.preservesPartitioning = preservesPartitioning self._prev_jrdd = prev._jrdd self.is_cached = False + self.is_checkpointed = False self.ctx = prev.ctx self.prev = prev self._jrdd_val = None @@ -695,18 +725,21 @@ class PipelinedRDD(RDD): pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds) broadcast_vars = ListConverter().convert( [x._jbroadcast for x in self.ctx._pickled_broadcast_vars], - self.ctx.gateway._gateway_client) + self.ctx._gateway._gateway_client) self.ctx._pickled_broadcast_vars.clear() class_manifest = self._prev_jrdd.classManifest() env = copy.copy(self.ctx.environment) env['PYTHONPATH'] = os.environ.get("PYTHONPATH", "") - env = MapConverter().convert(env, self.ctx.gateway._gateway_client) - python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(), + env = MapConverter().convert(env, self.ctx._gateway._gateway_client) + python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(), pipe_command, env, self.preservesPartitioning, self.ctx.pythonExec, - broadcast_vars, class_manifest) + broadcast_vars, self.ctx._javaAccumulator, class_manifest) self._jrdd_val = python_rdd.asJavaRDD() return self._jrdd_val + def _is_pipelinable(self): + return not (self.is_cached or self.is_checkpointed) + def _test(): import doctest diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 9a5151ea00..115cf28cc2 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -52,8 +52,13 @@ def read_int(stream): raise EOFError return struct.unpack("!i", length)[0] + +def write_int(value, stream): + stream.write(struct.pack("!i", value)) + + def write_with_length(obj, stream): - stream.write(struct.pack("!i", len(obj))) + write_int(len(obj), stream) stream.write(obj) diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py index 7e6ad3aa76..54ff1bf8e7 100644 --- a/python/pyspark/shell.py +++ b/python/pyspark/shell.py @@ -1,9 +1,10 @@ """ An interactive shell. -This fle is designed to be launched as a PYTHONSTARTUP script. +This file is designed to be launched as a PYTHONSTARTUP script. """ import os +import pyspark from pyspark.context import SparkContext @@ -14,4 +15,4 @@ print "Spark context avaiable as sc." # which allows us to execute the user's PYTHONSTARTUP file: _pythonstartup = os.environ.get('OLD_PYTHONSTARTUP') if _pythonstartup and os.path.isfile(_pythonstartup): - execfile(_pythonstartup) + execfile(_pythonstartup) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py new file mode 100644 index 0000000000..6a1962d267 --- /dev/null +++ b/python/pyspark/tests.py @@ -0,0 +1,121 @@ +""" +Unit tests for PySpark; additional tests are implemented as doctests in +individual modules. +""" +import os +import shutil +import sys +from tempfile import NamedTemporaryFile +import time +import unittest + +from pyspark.context import SparkContext +from pyspark.files import SparkFiles +from pyspark.java_gateway import SPARK_HOME + + +class PySparkTestCase(unittest.TestCase): + + def setUp(self): + self._old_sys_path = list(sys.path) + class_name = self.__class__.__name__ + self.sc = SparkContext('local[4]', class_name , batchSize=2) + + def tearDown(self): + self.sc.stop() + sys.path = self._old_sys_path + # To avoid Akka rebinding to the same port, since it doesn't unbind + # immediately on shutdown + self.sc._jvm.System.clearProperty("spark.driver.port") + + +class TestCheckpoint(PySparkTestCase): + + def setUp(self): + PySparkTestCase.setUp(self) + self.checkpointDir = NamedTemporaryFile(delete=False) + os.unlink(self.checkpointDir.name) + self.sc.setCheckpointDir(self.checkpointDir.name) + + def tearDown(self): + PySparkTestCase.tearDown(self) + shutil.rmtree(self.checkpointDir.name) + + def test_basic_checkpointing(self): + parCollection = self.sc.parallelize([1, 2, 3, 4]) + flatMappedRDD = parCollection.flatMap(lambda x: range(1, x + 1)) + + self.assertFalse(flatMappedRDD.isCheckpointed()) + self.assertIsNone(flatMappedRDD.getCheckpointFile()) + + flatMappedRDD.checkpoint() + result = flatMappedRDD.collect() + time.sleep(1) # 1 second + self.assertTrue(flatMappedRDD.isCheckpointed()) + self.assertEqual(flatMappedRDD.collect(), result) + self.assertEqual(self.checkpointDir.name, + os.path.dirname(flatMappedRDD.getCheckpointFile())) + + def test_checkpoint_and_restore(self): + parCollection = self.sc.parallelize([1, 2, 3, 4]) + flatMappedRDD = parCollection.flatMap(lambda x: [x]) + + self.assertFalse(flatMappedRDD.isCheckpointed()) + self.assertIsNone(flatMappedRDD.getCheckpointFile()) + + flatMappedRDD.checkpoint() + flatMappedRDD.count() # forces a checkpoint to be computed + time.sleep(1) # 1 second + + self.assertIsNotNone(flatMappedRDD.getCheckpointFile()) + recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile()) + self.assertEquals([1, 2, 3, 4], recovered.collect()) + + +class TestAddFile(PySparkTestCase): + + def test_add_py_file(self): + # To ensure that we're actually testing addPyFile's effects, check that + # this job fails due to `userlibrary` not being on the Python path: + def func(x): + from userlibrary import UserClass + return UserClass().hello() + self.assertRaises(Exception, + self.sc.parallelize(range(2)).map(func).first) + # Add the file, so the job should now succeed: + path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py") + self.sc.addPyFile(path) + res = self.sc.parallelize(range(2)).map(func).first() + self.assertEqual("Hello World!", res) + + def test_add_file_locally(self): + path = os.path.join(SPARK_HOME, "python/test_support/hello.txt") + self.sc.addFile(path) + download_path = SparkFiles.get("hello.txt") + self.assertNotEqual(path, download_path) + with open(download_path) as test_file: + self.assertEquals("Hello World!\n", test_file.readline()) + + def test_add_py_file_locally(self): + # To ensure that we're actually testing addPyFile's effects, check that + # this fails due to `userlibrary` not being on the Python path: + def func(): + from userlibrary import UserClass + self.assertRaises(ImportError, func) + path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py") + self.sc.addFile(path) + from userlibrary import UserClass + self.assertEqual("Hello World!", UserClass().hello()) + + +class TestIO(PySparkTestCase): + + def test_stdout_redirection(self): + import subprocess + def func(x): + subprocess.check_call('ls', shell=True) + self.sc.parallelize([1]).foreach(func) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 3d792bbaa2..812e7a9da5 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -1,19 +1,23 @@ """ Worker that receives input from Piped RDD. """ +import os import sys +import traceback from base64 import standard_b64decode # CloudPickler needs to be imported so that depicklers are registered using the # copy_reg module. +from pyspark.accumulators import _accumulatorRegistry from pyspark.broadcast import Broadcast, _broadcastRegistry from pyspark.cloudpickle import CloudPickler -from pyspark.serializers import write_with_length, read_with_length, \ +from pyspark.files import SparkFiles +from pyspark.serializers import write_with_length, read_with_length, write_int, \ read_long, read_int, dump_pickle, load_pickle, read_from_pickle_file # Redirect stdout to stderr so that users must return values from functions. -old_stdout = sys.stdout -sys.stdout = sys.stderr +old_stdout = os.fdopen(os.dup(1), 'w') +os.dup2(2, 1) def load_obj(): @@ -22,6 +26,10 @@ def load_obj(): def main(): split_index = read_int(sys.stdin) + spark_files_dir = load_pickle(read_with_length(sys.stdin)) + SparkFiles._root_directory = spark_files_dir + SparkFiles._is_running_on_worker = True + sys.path.append(spark_files_dir) num_broadcast_variables = read_int(sys.stdin) for _ in range(num_broadcast_variables): bid = read_long(sys.stdin) @@ -34,8 +42,17 @@ def main(): else: dumps = dump_pickle iterator = read_from_pickle_file(sys.stdin) - for obj in func(split_index, iterator): - write_with_length(dumps(obj), old_stdout) + try: + for obj in func(split_index, iterator): + write_with_length(dumps(obj), old_stdout) + except Exception as e: + write_int(-2, old_stdout) + write_with_length(traceback.format_exc(), old_stdout) + sys.exit(-1) + # Mark the beginning of the accumulators section of the output + write_int(-1, old_stdout) + for aid, accum in _accumulatorRegistry.items(): + write_with_length(dump_pickle((aid, accum._value)), old_stdout) if __name__ == '__main__': diff --git a/python/run-tests b/python/run-tests index fcdd1e27a7..a3a9ff5dcb 100755 --- a/python/run-tests +++ b/python/run-tests @@ -8,9 +8,18 @@ FAILED=0 $FWDIR/pyspark pyspark/rdd.py FAILED=$(($?||$FAILED)) +$FWDIR/pyspark pyspark/context.py +FAILED=$(($?||$FAILED)) + $FWDIR/pyspark -m doctest pyspark/broadcast.py FAILED=$(($?||$FAILED)) +$FWDIR/pyspark -m doctest pyspark/accumulators.py +FAILED=$(($?||$FAILED)) + +$FWDIR/pyspark -m unittest pyspark.tests +FAILED=$(($?||$FAILED)) + if [[ $FAILED != 0 ]]; then echo -en "\033[31m" # Red echo "Had test failures; see logs." diff --git a/python/test_support/hello.txt b/python/test_support/hello.txt new file mode 100755 index 0000000000..980a0d5f19 --- /dev/null +++ b/python/test_support/hello.txt @@ -0,0 +1 @@ +Hello World! diff --git a/python/test_support/userlibrary.py b/python/test_support/userlibrary.py new file mode 100755 index 0000000000..5bb6f5009f --- /dev/null +++ b/python/test_support/userlibrary.py @@ -0,0 +1,7 @@ +""" +Used to test shipping of code depenencies with SparkContext.addPyFile(). +""" + +class UserClass(object): + def hello(self): + return "Hello World!" |