diff options
-rw-r--r-- | core/src/main/scala/spark/Utils.scala | 11 | ||||
-rw-r--r-- | core/src/main/scala/spark/api/python/PythonRDD.scala | 83 | ||||
-rw-r--r-- | core/src/main/scala/spark/broadcast/HttpBroadcast.scala | 2 | ||||
-rw-r--r-- | core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala | 3 | ||||
-rw-r--r-- | docs/ec2-scripts.md | 4 | ||||
-rw-r--r-- | docs/python-programming-guide.md | 1 | ||||
-rw-r--r-- | examples/src/main/scala/spark/examples/SparkALS.scala | 59 | ||||
-rwxr-xr-x | pyspark | 7 | ||||
-rwxr-xr-x | python/examples/als.py | 71 | ||||
-rw-r--r-- | python/pyspark/__init__.py | 4 | ||||
-rw-r--r-- | python/pyspark/accumulators.py | 178 | ||||
-rw-r--r-- | python/pyspark/context.py | 38 | ||||
-rw-r--r-- | python/pyspark/rdd.py | 2 | ||||
-rw-r--r-- | python/pyspark/serializers.py | 7 | ||||
-rw-r--r-- | python/pyspark/shell.py | 4 | ||||
-rw-r--r-- | python/pyspark/worker.py | 7 | ||||
-rwxr-xr-x | python/run-tests | 3 | ||||
-rwxr-xr-x | run | 7 |
18 files changed, 424 insertions, 67 deletions
diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index d08921b25f..b3421df27c 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -134,7 +134,7 @@ private object Utils extends Logging { */ def fetchFile(url: String, targetDir: File) { val filename = url.split("/").last - val tempDir = System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")) + val tempDir = getLocalDir val tempFile = File.createTempFile("fetchFileTemp", null, new File(tempDir)) val targetFile = new File(targetDir, filename) val uri = new URI(url) @@ -205,6 +205,15 @@ private object Utils extends Logging { } /** + * Get a temporary directory using Spark's spark.local.dir property, if set. This will always + * return a single directory, even though the spark.local.dir property might be a list of + * multiple paths. + */ + def getLocalDir: String = { + System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")).split(',')(0) + } + + /** * Shuffle the elements of a collection into a random order, returning the * result in a new collection. Unlike scala.util.Random.shuffle, this method * uses a local random number generator, avoiding inter-thread contention. diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 0138b22d38..89f7c316dc 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -1,7 +1,8 @@ package spark.api.python import java.io._ -import java.util.{List => JList} +import java.net._ +import java.util.{List => JList, ArrayList => JArrayList, Collections} import scala.collection.JavaConversions._ import scala.io.Source @@ -10,25 +11,26 @@ import spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} import spark.broadcast.Broadcast import spark._ import spark.rdd.PipedRDD -import java.util private[spark] class PythonRDD[T: ClassManifest]( - parent: RDD[T], - command: Seq[String], - envVars: java.util.Map[String, String], - preservePartitoning: Boolean, - pythonExec: String, - broadcastVars: java.util.List[Broadcast[Array[Byte]]]) + parent: RDD[T], + command: Seq[String], + envVars: java.util.Map[String, String], + preservePartitoning: Boolean, + pythonExec: String, + broadcastVars: JList[Broadcast[Array[Byte]]], + accumulator: Accumulator[JList[Array[Byte]]]) extends RDD[Array[Byte]](parent) { // Similar to Runtime.exec(), if we are given a single string, split it into words // using a standard StringTokenizer (i.e. by spaces) def this(parent: RDD[T], command: String, envVars: java.util.Map[String, String], - preservePartitoning: Boolean, pythonExec: String, - broadcastVars: java.util.List[Broadcast[Array[Byte]]]) = + preservePartitoning: Boolean, pythonExec: String, + broadcastVars: JList[Broadcast[Array[Byte]]], + accumulator: Accumulator[JList[Array[Byte]]]) = this(parent, PipedRDD.tokenize(command), envVars, preservePartitoning, pythonExec, - broadcastVars) + broadcastVars, accumulator) override def getSplits = parent.splits @@ -91,18 +93,30 @@ private[spark] class PythonRDD[T: ClassManifest]( // Return an iterator that read lines from the process's stdout val stream = new DataInputStream(proc.getInputStream) return new Iterator[Array[Byte]] { - def next() = { + def next(): Array[Byte] = { val obj = _nextObj _nextObj = read() obj } - private def read() = { + private def read(): Array[Byte] = { try { val length = stream.readInt() - val obj = new Array[Byte](length) - stream.readFully(obj) - obj + if (length != -1) { + val obj = new Array[Byte](length) + stream.readFully(obj) + obj + } else { + // We've finished the data section of the output, but we can still read some + // accumulator updates; let's do that, breaking when we get EOFException + while (true) { + val len2 = stream.readInt() + val update = new Array[Byte](len2) + stream.readFully(update) + accumulator += Collections.singletonList(update) + } + new Array[Byte](0) + } } catch { case eof: EOFException => { val exitStatus = proc.waitFor() @@ -246,3 +260,40 @@ private class ExtractValue extends spark.api.java.function.Function[(Array[Byte] private class BytesToString extends spark.api.java.function.Function[Array[Byte], String] { override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8") } + +/** + * Internal class that acts as an `AccumulatorParam` for Python accumulators. Inside, it + * collects a list of pickled strings that we pass to Python through a socket. + */ +class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int) + extends AccumulatorParam[JList[Array[Byte]]] { + + override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList + + override def addInPlace(val1: JList[Array[Byte]], val2: JList[Array[Byte]]) + : JList[Array[Byte]] = { + if (serverHost == null) { + // This happens on the worker node, where we just want to remember all the updates + val1.addAll(val2) + val1 + } else { + // This happens on the master, where we pass the updates to Python through a socket + val socket = new Socket(serverHost, serverPort) + val in = socket.getInputStream + val out = new DataOutputStream(socket.getOutputStream) + out.writeInt(val2.size) + for (array <- val2) { + out.writeInt(array.length) + out.write(array) + } + out.flush() + // Wait for a byte from the Python side as an acknowledgement + val byteRead = in.read() + if (byteRead == -1) { + throw new SparkException("EOF reached before Python server acknowledged") + } + socket.close() + null + } + } +} diff --git a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala index fef264aab1..8e490e6bad 100644 --- a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala @@ -95,7 +95,7 @@ private object HttpBroadcast extends Logging { } private def createServer() { - broadcastDir = Utils.createTempDir() + broadcastDir = Utils.createTempDir(Utils.getLocalDir) server = new HttpServer(broadcastDir) server.start() serverUri = server.uri diff --git a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala index 915f71ba9f..a29bf974d2 100644 --- a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala +++ b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala @@ -24,9 +24,6 @@ private[spark] class StandaloneExecutorBackend( with ExecutorBackend with Logging { - val threadPool = new ThreadPoolExecutor( - 1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable]) - var master: ActorRef = null override def preStart() { diff --git a/docs/ec2-scripts.md b/docs/ec2-scripts.md index 6e1f7fd3b1..931b7a66bd 100644 --- a/docs/ec2-scripts.md +++ b/docs/ec2-scripts.md @@ -96,7 +96,9 @@ permissions on your private key file, you can run `launch` with the `spark-ec2` to attach a persistent EBS volume to each node for storing the persistent HDFS. - Finally, if you get errors while running your jobs, look at the slave's logs - for that job using the Mesos web UI (`http://<master-hostname>:8080`). + for that job inside of the Mesos work directory (/mnt/mesos-work). You can + also view the status of the cluster using the Mesos web UI + (`http://<master-hostname>:8080`). # Configuration diff --git a/docs/python-programming-guide.md b/docs/python-programming-guide.md index 78ef310a00..a840b9b34b 100644 --- a/docs/python-programming-guide.md +++ b/docs/python-programming-guide.md @@ -16,7 +16,6 @@ There are a few key differences between the Python and Scala APIs: * Python is dynamically typed, so RDDs can hold objects of different types. * PySpark does not currently support the following Spark features: - - Accumulators - Special functions on RDDs of doubles, such as `mean` and `stdev` - `lookup` - `persist` at storage levels other than `MEMORY_ONLY` diff --git a/examples/src/main/scala/spark/examples/SparkALS.scala b/examples/src/main/scala/spark/examples/SparkALS.scala index fb28e2c932..5e01885dbb 100644 --- a/examples/src/main/scala/spark/examples/SparkALS.scala +++ b/examples/src/main/scala/spark/examples/SparkALS.scala @@ -7,6 +7,7 @@ import cern.jet.math._ import cern.colt.matrix._ import cern.colt.matrix.linalg._ import spark._ +import scala.Option object SparkALS { // Parameters set through command line arguments @@ -42,7 +43,7 @@ object SparkALS { return sqrt(sumSqs / (M * U)) } - def updateMovie(i: Int, m: DoubleMatrix1D, us: Array[DoubleMatrix1D], + def update(i: Int, m: DoubleMatrix1D, us: Array[DoubleMatrix1D], R: DoubleMatrix2D) : DoubleMatrix1D = { val U = us.size @@ -68,50 +69,30 @@ object SparkALS { return solved2D.viewColumn(0) } - def updateUser(j: Int, u: DoubleMatrix1D, ms: Array[DoubleMatrix1D], - R: DoubleMatrix2D) : DoubleMatrix1D = - { - val M = ms.size - val F = ms(0).size - val XtX = factory2D.make(F, F) - val Xty = factory1D.make(F) - // For each movie that the user rated - for (i <- 0 until M) { - val m = ms(i) - // Add m * m^t to XtX - blas.dger(1, m, m, XtX) - // Add m * rating to Xty - blas.daxpy(R.get(i, j), m, Xty) - } - // Add regularization coefs to diagonal terms - for (d <- 0 until F) { - XtX.set(d, d, XtX.get(d, d) + LAMBDA * M) - } - // Solve it with Cholesky - val ch = new CholeskyDecomposition(XtX) - val Xty2D = factory2D.make(Xty.toArray, F) - val solved2D = ch.solve(Xty2D) - return solved2D.viewColumn(0) - } - def main(args: Array[String]) { var host = "" var slices = 0 - args match { - case Array(m, u, f, iters, slices_, host_) => { - M = m.toInt - U = u.toInt - F = f.toInt - ITERATIONS = iters.toInt - slices = slices_.toInt - host = host_ + + (0 to 5).map(i => { + i match { + case a if a < args.length => Some(args(a)) + case _ => None + } + }).toArray match { + case Array(host_, m, u, f, iters, slices_) => { + host = host_ getOrElse "local" + M = (m getOrElse "100").toInt + U = (u getOrElse "500").toInt + F = (f getOrElse "10").toInt + ITERATIONS = (iters getOrElse "5").toInt + slices = (slices_ getOrElse "2").toInt } case _ => { - System.err.println("Usage: SparkALS <M> <U> <F> <iters> <slices> <master>") + System.err.println("Usage: SparkALS [<master> <M> <U> <F> <iters> <slices>]") System.exit(1) } } - printf("Running with M=%d, U=%d, F=%d, iters=%d\n", M, U, F, ITERATIONS); + printf("Running with M=%d, U=%d, F=%d, iters=%d\n", M, U, F, ITERATIONS) val spark = new SparkContext(host, "SparkALS") val R = generateR() @@ -127,11 +108,11 @@ object SparkALS { for (iter <- 1 to ITERATIONS) { println("Iteration " + iter + ":") ms = spark.parallelize(0 until M, slices) - .map(i => updateMovie(i, msc.value(i), usc.value, Rc.value)) + .map(i => update(i, msc.value(i), usc.value, Rc.value)) .toArray msc = spark.broadcast(ms) // Re-broadcast ms because it was updated us = spark.parallelize(0 until U, slices) - .map(i => updateUser(i, usc.value(i), msc.value, Rc.value)) + .map(i => update(i, usc.value(i), msc.value, algebra.transpose(Rc.value))) .toArray usc = spark.broadcast(us) // Re-broadcast us because it was updated println("RMSE = " + rmse(R, ms, us)) @@ -6,6 +6,13 @@ FWDIR="$(cd `dirname $0`; pwd)" # Export this as SPARK_HOME export SPARK_HOME="$FWDIR" +# Exit if the user hasn't compiled Spark +if [ ! -e "$SPARK_HOME/repl/target" ]; then + echo "Failed to find Spark classes in $SPARK_HOME/repl/target" >&2 + echo "You need to compile Spark before running this program" >&2 + exit 1 +fi + # Load environment variables from conf/spark-env.sh, if it exists if [ -e $FWDIR/conf/spark-env.sh ] ; then . $FWDIR/conf/spark-env.sh diff --git a/python/examples/als.py b/python/examples/als.py new file mode 100755 index 0000000000..010f80097f --- /dev/null +++ b/python/examples/als.py @@ -0,0 +1,71 @@ +""" +This example requires numpy (http://www.numpy.org/) +""" +from os.path import realpath +import sys + +import numpy as np +from numpy.random import rand +from numpy import matrix +from pyspark import SparkContext + +LAMBDA = 0.01 # regularization +np.random.seed(42) + +def rmse(R, ms, us): + diff = R - ms * us.T + return np.sqrt(np.sum(np.power(diff, 2)) / M * U) + +def update(i, vec, mat, ratings): + uu = mat.shape[0] + ff = mat.shape[1] + XtX = matrix(np.zeros((ff, ff))) + Xty = np.zeros((ff, 1)) + + for j in range(uu): + v = mat[j, :] + XtX += v.T * v + Xty += v.T * ratings[i, j] + XtX += np.eye(ff, ff) * LAMBDA * uu + return np.linalg.solve(XtX, Xty) + +if __name__ == "__main__": + if len(sys.argv) < 2: + print >> sys.stderr, \ + "Usage: PythonALS <master> <M> <U> <F> <iters> <slices>" + exit(-1) + sc = SparkContext(sys.argv[1], "PythonALS", pyFiles=[realpath(__file__)]) + M = int(sys.argv[2]) if len(sys.argv) > 2 else 100 + U = int(sys.argv[3]) if len(sys.argv) > 3 else 500 + F = int(sys.argv[4]) if len(sys.argv) > 4 else 10 + ITERATIONS = int(sys.argv[5]) if len(sys.argv) > 5 else 5 + slices = int(sys.argv[6]) if len(sys.argv) > 6 else 2 + + print "Running ALS with M=%d, U=%d, F=%d, iters=%d, slices=%d\n" % \ + (M, U, F, ITERATIONS, slices) + + R = matrix(rand(M, F)) * matrix(rand(U, F).T) + ms = matrix(rand(M ,F)) + us = matrix(rand(U, F)) + + Rb = sc.broadcast(R) + msb = sc.broadcast(ms) + usb = sc.broadcast(us) + + for i in range(ITERATIONS): + ms = sc.parallelize(range(M), slices) \ + .map(lambda x: update(x, msb.value[x, :], usb.value, Rb.value)) \ + .collect() + ms = matrix(np.array(ms)[:, :, 0]) # collect() returns a list, so array ends up being + # a 3-d array, we take the first 2 dims for the matrix + msb = sc.broadcast(ms) + + us = sc.parallelize(range(U), slices) \ + .map(lambda x: update(x, usb.value[x, :], msb.value, Rb.value.T)) \ + .collect() + us = matrix(np.array(us)[:, :, 0]) + usb = sc.broadcast(us) + + error = rmse(R, ms, us) + print "Iteration %d:" % i + print "\nRMSE: %5.4f\n" % error diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index c595ae0842..00666bc0a3 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -7,6 +7,10 @@ Public classes: Main entry point for Spark functionality. - L{RDD<pyspark.rdd.RDD>} A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. + - L{Broadcast<pyspark.broadcast.Broadcast>} + A broadcast variable that gets reused across tasks. + - L{Accumulator<pyspark.accumulators.Accumulator>} + An "add-only" shared variable that tasks can only add values to. """ import sys import os diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py new file mode 100644 index 0000000000..c00c3a37af --- /dev/null +++ b/python/pyspark/accumulators.py @@ -0,0 +1,178 @@ +""" +>>> from pyspark.context import SparkContext +>>> sc = SparkContext('local', 'test') +>>> a = sc.accumulator(1) +>>> a.value +1 +>>> a.value = 2 +>>> a.value +2 +>>> a += 5 +>>> a.value +7 + +>>> rdd = sc.parallelize([1,2,3]) +>>> def f(x): +... global a +... a += x +>>> rdd.foreach(f) +>>> a.value +13 + +>>> class VectorAccumulatorParam(object): +... def zero(self, value): +... return [0.0] * len(value) +... def addInPlace(self, val1, val2): +... for i in xrange(len(val1)): +... val1[i] += val2[i] +... return val1 +>>> va = sc.accumulator([1.0, 2.0, 3.0], VectorAccumulatorParam()) +>>> va.value +[1.0, 2.0, 3.0] +>>> def g(x): +... global va +... va += [x] * 3 +>>> rdd.foreach(g) +>>> va.value +[7.0, 8.0, 9.0] + +>>> rdd.map(lambda x: a.value).collect() # doctest: +IGNORE_EXCEPTION_DETAIL +Traceback (most recent call last): + ... +Py4JJavaError:... + +>>> def h(x): +... global a +... a.value = 7 +>>> rdd.foreach(h) # doctest: +IGNORE_EXCEPTION_DETAIL +Traceback (most recent call last): + ... +Py4JJavaError:... + +>>> sc.accumulator([1.0, 2.0, 3.0]) # doctest: +IGNORE_EXCEPTION_DETAIL +Traceback (most recent call last): + ... +Exception:... +""" + +import struct +import SocketServer +import threading +from pyspark.cloudpickle import CloudPickler +from pyspark.serializers import read_int, read_with_length, load_pickle + + +# Holds accumulators registered on the current machine, keyed by ID. This is then used to send +# the local accumulator updates back to the driver program at the end of a task. +_accumulatorRegistry = {} + + +def _deserialize_accumulator(aid, zero_value, accum_param): + from pyspark.accumulators import _accumulatorRegistry + accum = Accumulator(aid, zero_value, accum_param) + accum._deserialized = True + _accumulatorRegistry[aid] = accum + return accum + + +class Accumulator(object): + """ + A shared variable that can be accumulated, i.e., has a commutative and associative "add" + operation. Worker tasks on a Spark cluster can add values to an Accumulator with the C{+=} + operator, but only the driver program is allowed to access its value, using C{value}. + Updates from the workers get propagated automatically to the driver program. + + While C{SparkContext} supports accumulators for primitive data types like C{int} and + C{float}, users can also define accumulators for custom types by providing a custom + C{AccumulatorParam} object with a C{zero} and C{addInPlace} method. Refer to the doctest + of this module for an example. + """ + + def __init__(self, aid, value, accum_param): + """Create a new Accumulator with a given initial value and AccumulatorParam object""" + from pyspark.accumulators import _accumulatorRegistry + self.aid = aid + self.accum_param = accum_param + self._value = value + self._deserialized = False + _accumulatorRegistry[aid] = self + + def __reduce__(self): + """Custom serialization; saves the zero value from our AccumulatorParam""" + param = self.accum_param + return (_deserialize_accumulator, (self.aid, param.zero(self._value), param)) + + @property + def value(self): + """Get the accumulator's value; only usable in driver program""" + if self._deserialized: + raise Exception("Accumulator.value cannot be accessed inside tasks") + return self._value + + @value.setter + def value(self, value): + """Sets the accumulator's value; only usable in driver program""" + if self._deserialized: + raise Exception("Accumulator.value cannot be accessed inside tasks") + self._value = value + + def __iadd__(self, term): + """The += operator; adds a term to this accumulator's value""" + self._value = self.accum_param.addInPlace(self._value, term) + return self + + def __str__(self): + return str(self._value) + + +class AddingAccumulatorParam(object): + """ + An AccumulatorParam that uses the + operators to add values. Designed for simple types + such as integers, floats, and lists. Requires the zero value for the underlying type + as a parameter. + """ + + def __init__(self, zero_value): + self.zero_value = zero_value + + def zero(self, value): + return self.zero_value + + def addInPlace(self, value1, value2): + value1 += value2 + return value1 + + +# Singleton accumulator params for some standard types +INT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0) +DOUBLE_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0) +COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j) + + +class _UpdateRequestHandler(SocketServer.StreamRequestHandler): + def handle(self): + from pyspark.accumulators import _accumulatorRegistry + num_updates = read_int(self.rfile) + for _ in range(num_updates): + (aid, update) = load_pickle(read_with_length(self.rfile)) + _accumulatorRegistry[aid] += update + # Write a byte in acknowledgement + self.wfile.write(struct.pack("!b", 1)) + + +def _start_update_server(): + """Start a TCP server to receive accumulator updates in a daemon thread, and returns it""" + server = SocketServer.TCPServer(("localhost", 0), _UpdateRequestHandler) + thread = threading.Thread(target=server.serve_forever) + thread.daemon = True + thread.start() + return server + + +def _test(): + import doctest + doctest.testmod() + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/context.py b/python/pyspark/context.py index e486f206b0..1e2f845f9c 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -2,6 +2,8 @@ import os import atexit from tempfile import NamedTemporaryFile +from pyspark import accumulators +from pyspark.accumulators import Accumulator from pyspark.broadcast import Broadcast from pyspark.java_gateway import launch_gateway from pyspark.serializers import dump_pickle, write_with_length, batched @@ -22,6 +24,7 @@ class SparkContext(object): _readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile _writeIteratorToPickleFile = jvm.PythonRDD.writeIteratorToPickleFile _takePartition = jvm.PythonRDD.takePartition + _next_accum_id = 0 def __init__(self, master, jobName, sparkHome=None, pyFiles=None, environment=None, batchSize=1024): @@ -52,6 +55,14 @@ class SparkContext(object): self._jsc = self.jvm.JavaSparkContext(master, jobName, sparkHome, empty_string_array) + # Create a single Accumulator in Java that we'll send all our updates through; + # they will be passed back to us through a TCP server + self._accumulatorServer = accumulators._start_update_server() + (host, port) = self._accumulatorServer.server_address + self._javaAccumulator = self._jsc.accumulator( + self.jvm.java.util.ArrayList(), + self.jvm.PythonAccumulatorParam(host, port)) + self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python') # Broadcast's __reduce__ method stores Broadcast instances here. # This allows other code to determine which Broadcast instances have @@ -74,6 +85,8 @@ class SparkContext(object): def __del__(self): if self._jsc: self._jsc.stop() + if self._accumulatorServer: + self._accumulatorServer.shutdown() def stop(self): """ @@ -129,6 +142,31 @@ class SparkContext(object): return Broadcast(jbroadcast.id(), value, jbroadcast, self._pickled_broadcast_vars) + def accumulator(self, value, accum_param=None): + """ + Create an C{Accumulator} with the given initial value, using a given + AccumulatorParam helper object to define how to add values of the data + type if provided. Default AccumulatorParams are used for integers and + floating-point numbers if you do not provide one. For other types, the + AccumulatorParam must implement two methods: + - C{zero(value)}: provide a "zero value" for the type, compatible in + dimensions with the provided C{value} (e.g., a zero vector). + - C{addInPlace(val1, val2)}: add two values of the accumulator's data + type, returning a new value; for efficiency, can also update C{val1} + in place and return it. + """ + if accum_param == None: + if isinstance(value, int): + accum_param = accumulators.INT_ACCUMULATOR_PARAM + elif isinstance(value, float): + accum_param = accumulators.FLOAT_ACCUMULATOR_PARAM + elif isinstance(value, complex): + accum_param = accumulators.COMPLEX_ACCUMULATOR_PARAM + else: + raise Exception("No default accumulator param for type %s" % type(value)) + SparkContext._next_accum_id += 1 + return Accumulator(SparkContext._next_accum_id - 1, value, accum_param) + def addFile(self, path): """ Add a file to be downloaded into the working directory of this Spark diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 1d36da42b0..d705f0f9e1 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -703,7 +703,7 @@ class PipelinedRDD(RDD): env = MapConverter().convert(env, self.ctx.gateway._gateway_client) python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(), pipe_command, env, self.preservesPartitioning, self.ctx.pythonExec, - broadcast_vars, class_manifest) + broadcast_vars, self.ctx._javaAccumulator, class_manifest) self._jrdd_val = python_rdd.asJavaRDD() return self._jrdd_val diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 9a5151ea00..115cf28cc2 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -52,8 +52,13 @@ def read_int(stream): raise EOFError return struct.unpack("!i", length)[0] + +def write_int(value, stream): + stream.write(struct.pack("!i", value)) + + def write_with_length(obj, stream): - stream.write(struct.pack("!i", len(obj))) + write_int(len(obj), stream) stream.write(obj) diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py index 7e6ad3aa76..f6328c561f 100644 --- a/python/pyspark/shell.py +++ b/python/pyspark/shell.py @@ -1,7 +1,7 @@ """ An interactive shell. -This fle is designed to be launched as a PYTHONSTARTUP script. +This file is designed to be launched as a PYTHONSTARTUP script. """ import os from pyspark.context import SparkContext @@ -14,4 +14,4 @@ print "Spark context avaiable as sc." # which allows us to execute the user's PYTHONSTARTUP file: _pythonstartup = os.environ.get('OLD_PYTHONSTARTUP') if _pythonstartup and os.path.isfile(_pythonstartup): - execfile(_pythonstartup) + execfile(_pythonstartup) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 3d792bbaa2..b2b9288089 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -5,9 +5,10 @@ import sys from base64 import standard_b64decode # CloudPickler needs to be imported so that depicklers are registered using the # copy_reg module. +from pyspark.accumulators import _accumulatorRegistry from pyspark.broadcast import Broadcast, _broadcastRegistry from pyspark.cloudpickle import CloudPickler -from pyspark.serializers import write_with_length, read_with_length, \ +from pyspark.serializers import write_with_length, read_with_length, write_int, \ read_long, read_int, dump_pickle, load_pickle, read_from_pickle_file @@ -36,6 +37,10 @@ def main(): iterator = read_from_pickle_file(sys.stdin) for obj in func(split_index, iterator): write_with_length(dumps(obj), old_stdout) + # Mark the beginning of the accumulators section of the output + write_int(-1, old_stdout) + for aid, accum in _accumulatorRegistry.items(): + write_with_length(dump_pickle((aid, accum._value)), old_stdout) if __name__ == '__main__': diff --git a/python/run-tests b/python/run-tests index fcdd1e27a7..32470911f9 100755 --- a/python/run-tests +++ b/python/run-tests @@ -11,6 +11,9 @@ FAILED=$(($?||$FAILED)) $FWDIR/pyspark -m doctest pyspark/broadcast.py FAILED=$(($?||$FAILED)) +$FWDIR/pyspark -m doctest pyspark/accumulators.py +FAILED=$(($?||$FAILED)) + if [[ $FAILED != 0 ]]; then echo -en "\033[31m" # Red echo "Had test failures; see logs." @@ -66,6 +66,13 @@ BAGEL_DIR="$FWDIR/bagel" STREAMING_DIR="$FWDIR/streaming" PYSPARK_DIR="$FWDIR/python" +# Exit if the user hasn't compiled Spark +if [ ! -e "$REPL_DIR/target" ]; then + echo "Failed to find Spark classes in $REPL_DIR/target" >&2 + echo "You need to compile Spark before running this program" >&2 + exit 1 +fi + # Build up classpath CLASSPATH="$SPARK_CLASSPATH" CLASSPATH+=":$FWDIR/conf" |