aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--core/src/main/scala/spark/Utils.scala11
-rw-r--r--core/src/main/scala/spark/api/python/PythonRDD.scala83
-rw-r--r--core/src/main/scala/spark/broadcast/HttpBroadcast.scala2
-rw-r--r--core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala3
-rw-r--r--docs/ec2-scripts.md4
-rw-r--r--docs/python-programming-guide.md1
-rw-r--r--examples/src/main/scala/spark/examples/SparkALS.scala59
-rwxr-xr-xpyspark7
-rwxr-xr-xpython/examples/als.py71
-rw-r--r--python/pyspark/__init__.py4
-rw-r--r--python/pyspark/accumulators.py178
-rw-r--r--python/pyspark/context.py38
-rw-r--r--python/pyspark/rdd.py2
-rw-r--r--python/pyspark/serializers.py7
-rw-r--r--python/pyspark/shell.py4
-rw-r--r--python/pyspark/worker.py7
-rwxr-xr-xpython/run-tests3
-rwxr-xr-xrun7
18 files changed, 424 insertions, 67 deletions
diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala
index d08921b25f..b3421df27c 100644
--- a/core/src/main/scala/spark/Utils.scala
+++ b/core/src/main/scala/spark/Utils.scala
@@ -134,7 +134,7 @@ private object Utils extends Logging {
*/
def fetchFile(url: String, targetDir: File) {
val filename = url.split("/").last
- val tempDir = System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir"))
+ val tempDir = getLocalDir
val tempFile = File.createTempFile("fetchFileTemp", null, new File(tempDir))
val targetFile = new File(targetDir, filename)
val uri = new URI(url)
@@ -205,6 +205,15 @@ private object Utils extends Logging {
}
/**
+ * Get a temporary directory using Spark's spark.local.dir property, if set. This will always
+ * return a single directory, even though the spark.local.dir property might be a list of
+ * multiple paths.
+ */
+ def getLocalDir: String = {
+ System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")).split(',')(0)
+ }
+
+ /**
* Shuffle the elements of a collection into a random order, returning the
* result in a new collection. Unlike scala.util.Random.shuffle, this method
* uses a local random number generator, avoiding inter-thread contention.
diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala
index 0138b22d38..89f7c316dc 100644
--- a/core/src/main/scala/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/spark/api/python/PythonRDD.scala
@@ -1,7 +1,8 @@
package spark.api.python
import java.io._
-import java.util.{List => JList}
+import java.net._
+import java.util.{List => JList, ArrayList => JArrayList, Collections}
import scala.collection.JavaConversions._
import scala.io.Source
@@ -10,25 +11,26 @@ import spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
import spark.broadcast.Broadcast
import spark._
import spark.rdd.PipedRDD
-import java.util
private[spark] class PythonRDD[T: ClassManifest](
- parent: RDD[T],
- command: Seq[String],
- envVars: java.util.Map[String, String],
- preservePartitoning: Boolean,
- pythonExec: String,
- broadcastVars: java.util.List[Broadcast[Array[Byte]]])
+ parent: RDD[T],
+ command: Seq[String],
+ envVars: java.util.Map[String, String],
+ preservePartitoning: Boolean,
+ pythonExec: String,
+ broadcastVars: JList[Broadcast[Array[Byte]]],
+ accumulator: Accumulator[JList[Array[Byte]]])
extends RDD[Array[Byte]](parent) {
// Similar to Runtime.exec(), if we are given a single string, split it into words
// using a standard StringTokenizer (i.e. by spaces)
def this(parent: RDD[T], command: String, envVars: java.util.Map[String, String],
- preservePartitoning: Boolean, pythonExec: String,
- broadcastVars: java.util.List[Broadcast[Array[Byte]]]) =
+ preservePartitoning: Boolean, pythonExec: String,
+ broadcastVars: JList[Broadcast[Array[Byte]]],
+ accumulator: Accumulator[JList[Array[Byte]]]) =
this(parent, PipedRDD.tokenize(command), envVars, preservePartitoning, pythonExec,
- broadcastVars)
+ broadcastVars, accumulator)
override def getSplits = parent.splits
@@ -91,18 +93,30 @@ private[spark] class PythonRDD[T: ClassManifest](
// Return an iterator that read lines from the process's stdout
val stream = new DataInputStream(proc.getInputStream)
return new Iterator[Array[Byte]] {
- def next() = {
+ def next(): Array[Byte] = {
val obj = _nextObj
_nextObj = read()
obj
}
- private def read() = {
+ private def read(): Array[Byte] = {
try {
val length = stream.readInt()
- val obj = new Array[Byte](length)
- stream.readFully(obj)
- obj
+ if (length != -1) {
+ val obj = new Array[Byte](length)
+ stream.readFully(obj)
+ obj
+ } else {
+ // We've finished the data section of the output, but we can still read some
+ // accumulator updates; let's do that, breaking when we get EOFException
+ while (true) {
+ val len2 = stream.readInt()
+ val update = new Array[Byte](len2)
+ stream.readFully(update)
+ accumulator += Collections.singletonList(update)
+ }
+ new Array[Byte](0)
+ }
} catch {
case eof: EOFException => {
val exitStatus = proc.waitFor()
@@ -246,3 +260,40 @@ private class ExtractValue extends spark.api.java.function.Function[(Array[Byte]
private class BytesToString extends spark.api.java.function.Function[Array[Byte], String] {
override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8")
}
+
+/**
+ * Internal class that acts as an `AccumulatorParam` for Python accumulators. Inside, it
+ * collects a list of pickled strings that we pass to Python through a socket.
+ */
+class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
+ extends AccumulatorParam[JList[Array[Byte]]] {
+
+ override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList
+
+ override def addInPlace(val1: JList[Array[Byte]], val2: JList[Array[Byte]])
+ : JList[Array[Byte]] = {
+ if (serverHost == null) {
+ // This happens on the worker node, where we just want to remember all the updates
+ val1.addAll(val2)
+ val1
+ } else {
+ // This happens on the master, where we pass the updates to Python through a socket
+ val socket = new Socket(serverHost, serverPort)
+ val in = socket.getInputStream
+ val out = new DataOutputStream(socket.getOutputStream)
+ out.writeInt(val2.size)
+ for (array <- val2) {
+ out.writeInt(array.length)
+ out.write(array)
+ }
+ out.flush()
+ // Wait for a byte from the Python side as an acknowledgement
+ val byteRead = in.read()
+ if (byteRead == -1) {
+ throw new SparkException("EOF reached before Python server acknowledged")
+ }
+ socket.close()
+ null
+ }
+ }
+}
diff --git a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala
index fef264aab1..8e490e6bad 100644
--- a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala
+++ b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala
@@ -95,7 +95,7 @@ private object HttpBroadcast extends Logging {
}
private def createServer() {
- broadcastDir = Utils.createTempDir()
+ broadcastDir = Utils.createTempDir(Utils.getLocalDir)
server = new HttpServer(broadcastDir)
server.start()
serverUri = server.uri
diff --git a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala
index 915f71ba9f..a29bf974d2 100644
--- a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala
+++ b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala
@@ -24,9 +24,6 @@ private[spark] class StandaloneExecutorBackend(
with ExecutorBackend
with Logging {
- val threadPool = new ThreadPoolExecutor(
- 1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable])
-
var master: ActorRef = null
override def preStart() {
diff --git a/docs/ec2-scripts.md b/docs/ec2-scripts.md
index 6e1f7fd3b1..931b7a66bd 100644
--- a/docs/ec2-scripts.md
+++ b/docs/ec2-scripts.md
@@ -96,7 +96,9 @@ permissions on your private key file, you can run `launch` with the
`spark-ec2` to attach a persistent EBS volume to each node for
storing the persistent HDFS.
- Finally, if you get errors while running your jobs, look at the slave's logs
- for that job using the Mesos web UI (`http://<master-hostname>:8080`).
+ for that job inside of the Mesos work directory (/mnt/mesos-work). You can
+ also view the status of the cluster using the Mesos web UI
+ (`http://<master-hostname>:8080`).
# Configuration
diff --git a/docs/python-programming-guide.md b/docs/python-programming-guide.md
index 78ef310a00..a840b9b34b 100644
--- a/docs/python-programming-guide.md
+++ b/docs/python-programming-guide.md
@@ -16,7 +16,6 @@ There are a few key differences between the Python and Scala APIs:
* Python is dynamically typed, so RDDs can hold objects of different types.
* PySpark does not currently support the following Spark features:
- - Accumulators
- Special functions on RDDs of doubles, such as `mean` and `stdev`
- `lookup`
- `persist` at storage levels other than `MEMORY_ONLY`
diff --git a/examples/src/main/scala/spark/examples/SparkALS.scala b/examples/src/main/scala/spark/examples/SparkALS.scala
index fb28e2c932..5e01885dbb 100644
--- a/examples/src/main/scala/spark/examples/SparkALS.scala
+++ b/examples/src/main/scala/spark/examples/SparkALS.scala
@@ -7,6 +7,7 @@ import cern.jet.math._
import cern.colt.matrix._
import cern.colt.matrix.linalg._
import spark._
+import scala.Option
object SparkALS {
// Parameters set through command line arguments
@@ -42,7 +43,7 @@ object SparkALS {
return sqrt(sumSqs / (M * U))
}
- def updateMovie(i: Int, m: DoubleMatrix1D, us: Array[DoubleMatrix1D],
+ def update(i: Int, m: DoubleMatrix1D, us: Array[DoubleMatrix1D],
R: DoubleMatrix2D) : DoubleMatrix1D =
{
val U = us.size
@@ -68,50 +69,30 @@ object SparkALS {
return solved2D.viewColumn(0)
}
- def updateUser(j: Int, u: DoubleMatrix1D, ms: Array[DoubleMatrix1D],
- R: DoubleMatrix2D) : DoubleMatrix1D =
- {
- val M = ms.size
- val F = ms(0).size
- val XtX = factory2D.make(F, F)
- val Xty = factory1D.make(F)
- // For each movie that the user rated
- for (i <- 0 until M) {
- val m = ms(i)
- // Add m * m^t to XtX
- blas.dger(1, m, m, XtX)
- // Add m * rating to Xty
- blas.daxpy(R.get(i, j), m, Xty)
- }
- // Add regularization coefs to diagonal terms
- for (d <- 0 until F) {
- XtX.set(d, d, XtX.get(d, d) + LAMBDA * M)
- }
- // Solve it with Cholesky
- val ch = new CholeskyDecomposition(XtX)
- val Xty2D = factory2D.make(Xty.toArray, F)
- val solved2D = ch.solve(Xty2D)
- return solved2D.viewColumn(0)
- }
-
def main(args: Array[String]) {
var host = ""
var slices = 0
- args match {
- case Array(m, u, f, iters, slices_, host_) => {
- M = m.toInt
- U = u.toInt
- F = f.toInt
- ITERATIONS = iters.toInt
- slices = slices_.toInt
- host = host_
+
+ (0 to 5).map(i => {
+ i match {
+ case a if a < args.length => Some(args(a))
+ case _ => None
+ }
+ }).toArray match {
+ case Array(host_, m, u, f, iters, slices_) => {
+ host = host_ getOrElse "local"
+ M = (m getOrElse "100").toInt
+ U = (u getOrElse "500").toInt
+ F = (f getOrElse "10").toInt
+ ITERATIONS = (iters getOrElse "5").toInt
+ slices = (slices_ getOrElse "2").toInt
}
case _ => {
- System.err.println("Usage: SparkALS <M> <U> <F> <iters> <slices> <master>")
+ System.err.println("Usage: SparkALS [<master> <M> <U> <F> <iters> <slices>]")
System.exit(1)
}
}
- printf("Running with M=%d, U=%d, F=%d, iters=%d\n", M, U, F, ITERATIONS);
+ printf("Running with M=%d, U=%d, F=%d, iters=%d\n", M, U, F, ITERATIONS)
val spark = new SparkContext(host, "SparkALS")
val R = generateR()
@@ -127,11 +108,11 @@ object SparkALS {
for (iter <- 1 to ITERATIONS) {
println("Iteration " + iter + ":")
ms = spark.parallelize(0 until M, slices)
- .map(i => updateMovie(i, msc.value(i), usc.value, Rc.value))
+ .map(i => update(i, msc.value(i), usc.value, Rc.value))
.toArray
msc = spark.broadcast(ms) // Re-broadcast ms because it was updated
us = spark.parallelize(0 until U, slices)
- .map(i => updateUser(i, usc.value(i), msc.value, Rc.value))
+ .map(i => update(i, usc.value(i), msc.value, algebra.transpose(Rc.value)))
.toArray
usc = spark.broadcast(us) // Re-broadcast us because it was updated
println("RMSE = " + rmse(R, ms, us))
diff --git a/pyspark b/pyspark
index 9e89d51ba2..ab7f4f50c0 100755
--- a/pyspark
+++ b/pyspark
@@ -6,6 +6,13 @@ FWDIR="$(cd `dirname $0`; pwd)"
# Export this as SPARK_HOME
export SPARK_HOME="$FWDIR"
+# Exit if the user hasn't compiled Spark
+if [ ! -e "$SPARK_HOME/repl/target" ]; then
+ echo "Failed to find Spark classes in $SPARK_HOME/repl/target" >&2
+ echo "You need to compile Spark before running this program" >&2
+ exit 1
+fi
+
# Load environment variables from conf/spark-env.sh, if it exists
if [ -e $FWDIR/conf/spark-env.sh ] ; then
. $FWDIR/conf/spark-env.sh
diff --git a/python/examples/als.py b/python/examples/als.py
new file mode 100755
index 0000000000..010f80097f
--- /dev/null
+++ b/python/examples/als.py
@@ -0,0 +1,71 @@
+"""
+This example requires numpy (http://www.numpy.org/)
+"""
+from os.path import realpath
+import sys
+
+import numpy as np
+from numpy.random import rand
+from numpy import matrix
+from pyspark import SparkContext
+
+LAMBDA = 0.01 # regularization
+np.random.seed(42)
+
+def rmse(R, ms, us):
+ diff = R - ms * us.T
+ return np.sqrt(np.sum(np.power(diff, 2)) / M * U)
+
+def update(i, vec, mat, ratings):
+ uu = mat.shape[0]
+ ff = mat.shape[1]
+ XtX = matrix(np.zeros((ff, ff)))
+ Xty = np.zeros((ff, 1))
+
+ for j in range(uu):
+ v = mat[j, :]
+ XtX += v.T * v
+ Xty += v.T * ratings[i, j]
+ XtX += np.eye(ff, ff) * LAMBDA * uu
+ return np.linalg.solve(XtX, Xty)
+
+if __name__ == "__main__":
+ if len(sys.argv) < 2:
+ print >> sys.stderr, \
+ "Usage: PythonALS <master> <M> <U> <F> <iters> <slices>"
+ exit(-1)
+ sc = SparkContext(sys.argv[1], "PythonALS", pyFiles=[realpath(__file__)])
+ M = int(sys.argv[2]) if len(sys.argv) > 2 else 100
+ U = int(sys.argv[3]) if len(sys.argv) > 3 else 500
+ F = int(sys.argv[4]) if len(sys.argv) > 4 else 10
+ ITERATIONS = int(sys.argv[5]) if len(sys.argv) > 5 else 5
+ slices = int(sys.argv[6]) if len(sys.argv) > 6 else 2
+
+ print "Running ALS with M=%d, U=%d, F=%d, iters=%d, slices=%d\n" % \
+ (M, U, F, ITERATIONS, slices)
+
+ R = matrix(rand(M, F)) * matrix(rand(U, F).T)
+ ms = matrix(rand(M ,F))
+ us = matrix(rand(U, F))
+
+ Rb = sc.broadcast(R)
+ msb = sc.broadcast(ms)
+ usb = sc.broadcast(us)
+
+ for i in range(ITERATIONS):
+ ms = sc.parallelize(range(M), slices) \
+ .map(lambda x: update(x, msb.value[x, :], usb.value, Rb.value)) \
+ .collect()
+ ms = matrix(np.array(ms)[:, :, 0]) # collect() returns a list, so array ends up being
+ # a 3-d array, we take the first 2 dims for the matrix
+ msb = sc.broadcast(ms)
+
+ us = sc.parallelize(range(U), slices) \
+ .map(lambda x: update(x, usb.value[x, :], msb.value, Rb.value.T)) \
+ .collect()
+ us = matrix(np.array(us)[:, :, 0])
+ usb = sc.broadcast(us)
+
+ error = rmse(R, ms, us)
+ print "Iteration %d:" % i
+ print "\nRMSE: %5.4f\n" % error
diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py
index c595ae0842..00666bc0a3 100644
--- a/python/pyspark/__init__.py
+++ b/python/pyspark/__init__.py
@@ -7,6 +7,10 @@ Public classes:
Main entry point for Spark functionality.
- L{RDD<pyspark.rdd.RDD>}
A Resilient Distributed Dataset (RDD), the basic abstraction in Spark.
+ - L{Broadcast<pyspark.broadcast.Broadcast>}
+ A broadcast variable that gets reused across tasks.
+ - L{Accumulator<pyspark.accumulators.Accumulator>}
+ An "add-only" shared variable that tasks can only add values to.
"""
import sys
import os
diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py
new file mode 100644
index 0000000000..c00c3a37af
--- /dev/null
+++ b/python/pyspark/accumulators.py
@@ -0,0 +1,178 @@
+"""
+>>> from pyspark.context import SparkContext
+>>> sc = SparkContext('local', 'test')
+>>> a = sc.accumulator(1)
+>>> a.value
+1
+>>> a.value = 2
+>>> a.value
+2
+>>> a += 5
+>>> a.value
+7
+
+>>> rdd = sc.parallelize([1,2,3])
+>>> def f(x):
+... global a
+... a += x
+>>> rdd.foreach(f)
+>>> a.value
+13
+
+>>> class VectorAccumulatorParam(object):
+... def zero(self, value):
+... return [0.0] * len(value)
+... def addInPlace(self, val1, val2):
+... for i in xrange(len(val1)):
+... val1[i] += val2[i]
+... return val1
+>>> va = sc.accumulator([1.0, 2.0, 3.0], VectorAccumulatorParam())
+>>> va.value
+[1.0, 2.0, 3.0]
+>>> def g(x):
+... global va
+... va += [x] * 3
+>>> rdd.foreach(g)
+>>> va.value
+[7.0, 8.0, 9.0]
+
+>>> rdd.map(lambda x: a.value).collect() # doctest: +IGNORE_EXCEPTION_DETAIL
+Traceback (most recent call last):
+ ...
+Py4JJavaError:...
+
+>>> def h(x):
+... global a
+... a.value = 7
+>>> rdd.foreach(h) # doctest: +IGNORE_EXCEPTION_DETAIL
+Traceback (most recent call last):
+ ...
+Py4JJavaError:...
+
+>>> sc.accumulator([1.0, 2.0, 3.0]) # doctest: +IGNORE_EXCEPTION_DETAIL
+Traceback (most recent call last):
+ ...
+Exception:...
+"""
+
+import struct
+import SocketServer
+import threading
+from pyspark.cloudpickle import CloudPickler
+from pyspark.serializers import read_int, read_with_length, load_pickle
+
+
+# Holds accumulators registered on the current machine, keyed by ID. This is then used to send
+# the local accumulator updates back to the driver program at the end of a task.
+_accumulatorRegistry = {}
+
+
+def _deserialize_accumulator(aid, zero_value, accum_param):
+ from pyspark.accumulators import _accumulatorRegistry
+ accum = Accumulator(aid, zero_value, accum_param)
+ accum._deserialized = True
+ _accumulatorRegistry[aid] = accum
+ return accum
+
+
+class Accumulator(object):
+ """
+ A shared variable that can be accumulated, i.e., has a commutative and associative "add"
+ operation. Worker tasks on a Spark cluster can add values to an Accumulator with the C{+=}
+ operator, but only the driver program is allowed to access its value, using C{value}.
+ Updates from the workers get propagated automatically to the driver program.
+
+ While C{SparkContext} supports accumulators for primitive data types like C{int} and
+ C{float}, users can also define accumulators for custom types by providing a custom
+ C{AccumulatorParam} object with a C{zero} and C{addInPlace} method. Refer to the doctest
+ of this module for an example.
+ """
+
+ def __init__(self, aid, value, accum_param):
+ """Create a new Accumulator with a given initial value and AccumulatorParam object"""
+ from pyspark.accumulators import _accumulatorRegistry
+ self.aid = aid
+ self.accum_param = accum_param
+ self._value = value
+ self._deserialized = False
+ _accumulatorRegistry[aid] = self
+
+ def __reduce__(self):
+ """Custom serialization; saves the zero value from our AccumulatorParam"""
+ param = self.accum_param
+ return (_deserialize_accumulator, (self.aid, param.zero(self._value), param))
+
+ @property
+ def value(self):
+ """Get the accumulator's value; only usable in driver program"""
+ if self._deserialized:
+ raise Exception("Accumulator.value cannot be accessed inside tasks")
+ return self._value
+
+ @value.setter
+ def value(self, value):
+ """Sets the accumulator's value; only usable in driver program"""
+ if self._deserialized:
+ raise Exception("Accumulator.value cannot be accessed inside tasks")
+ self._value = value
+
+ def __iadd__(self, term):
+ """The += operator; adds a term to this accumulator's value"""
+ self._value = self.accum_param.addInPlace(self._value, term)
+ return self
+
+ def __str__(self):
+ return str(self._value)
+
+
+class AddingAccumulatorParam(object):
+ """
+ An AccumulatorParam that uses the + operators to add values. Designed for simple types
+ such as integers, floats, and lists. Requires the zero value for the underlying type
+ as a parameter.
+ """
+
+ def __init__(self, zero_value):
+ self.zero_value = zero_value
+
+ def zero(self, value):
+ return self.zero_value
+
+ def addInPlace(self, value1, value2):
+ value1 += value2
+ return value1
+
+
+# Singleton accumulator params for some standard types
+INT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0)
+DOUBLE_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0)
+COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j)
+
+
+class _UpdateRequestHandler(SocketServer.StreamRequestHandler):
+ def handle(self):
+ from pyspark.accumulators import _accumulatorRegistry
+ num_updates = read_int(self.rfile)
+ for _ in range(num_updates):
+ (aid, update) = load_pickle(read_with_length(self.rfile))
+ _accumulatorRegistry[aid] += update
+ # Write a byte in acknowledgement
+ self.wfile.write(struct.pack("!b", 1))
+
+
+def _start_update_server():
+ """Start a TCP server to receive accumulator updates in a daemon thread, and returns it"""
+ server = SocketServer.TCPServer(("localhost", 0), _UpdateRequestHandler)
+ thread = threading.Thread(target=server.serve_forever)
+ thread.daemon = True
+ thread.start()
+ return server
+
+
+def _test():
+ import doctest
+ doctest.testmod()
+
+
+if __name__ == "__main__":
+ _test()
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index e486f206b0..1e2f845f9c 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -2,6 +2,8 @@ import os
import atexit
from tempfile import NamedTemporaryFile
+from pyspark import accumulators
+from pyspark.accumulators import Accumulator
from pyspark.broadcast import Broadcast
from pyspark.java_gateway import launch_gateway
from pyspark.serializers import dump_pickle, write_with_length, batched
@@ -22,6 +24,7 @@ class SparkContext(object):
_readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile
_writeIteratorToPickleFile = jvm.PythonRDD.writeIteratorToPickleFile
_takePartition = jvm.PythonRDD.takePartition
+ _next_accum_id = 0
def __init__(self, master, jobName, sparkHome=None, pyFiles=None,
environment=None, batchSize=1024):
@@ -52,6 +55,14 @@ class SparkContext(object):
self._jsc = self.jvm.JavaSparkContext(master, jobName, sparkHome,
empty_string_array)
+ # Create a single Accumulator in Java that we'll send all our updates through;
+ # they will be passed back to us through a TCP server
+ self._accumulatorServer = accumulators._start_update_server()
+ (host, port) = self._accumulatorServer.server_address
+ self._javaAccumulator = self._jsc.accumulator(
+ self.jvm.java.util.ArrayList(),
+ self.jvm.PythonAccumulatorParam(host, port))
+
self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')
# Broadcast's __reduce__ method stores Broadcast instances here.
# This allows other code to determine which Broadcast instances have
@@ -74,6 +85,8 @@ class SparkContext(object):
def __del__(self):
if self._jsc:
self._jsc.stop()
+ if self._accumulatorServer:
+ self._accumulatorServer.shutdown()
def stop(self):
"""
@@ -129,6 +142,31 @@ class SparkContext(object):
return Broadcast(jbroadcast.id(), value, jbroadcast,
self._pickled_broadcast_vars)
+ def accumulator(self, value, accum_param=None):
+ """
+ Create an C{Accumulator} with the given initial value, using a given
+ AccumulatorParam helper object to define how to add values of the data
+ type if provided. Default AccumulatorParams are used for integers and
+ floating-point numbers if you do not provide one. For other types, the
+ AccumulatorParam must implement two methods:
+ - C{zero(value)}: provide a "zero value" for the type, compatible in
+ dimensions with the provided C{value} (e.g., a zero vector).
+ - C{addInPlace(val1, val2)}: add two values of the accumulator's data
+ type, returning a new value; for efficiency, can also update C{val1}
+ in place and return it.
+ """
+ if accum_param == None:
+ if isinstance(value, int):
+ accum_param = accumulators.INT_ACCUMULATOR_PARAM
+ elif isinstance(value, float):
+ accum_param = accumulators.FLOAT_ACCUMULATOR_PARAM
+ elif isinstance(value, complex):
+ accum_param = accumulators.COMPLEX_ACCUMULATOR_PARAM
+ else:
+ raise Exception("No default accumulator param for type %s" % type(value))
+ SparkContext._next_accum_id += 1
+ return Accumulator(SparkContext._next_accum_id - 1, value, accum_param)
+
def addFile(self, path):
"""
Add a file to be downloaded into the working directory of this Spark
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 1d36da42b0..d705f0f9e1 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -703,7 +703,7 @@ class PipelinedRDD(RDD):
env = MapConverter().convert(env, self.ctx.gateway._gateway_client)
python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(),
pipe_command, env, self.preservesPartitioning, self.ctx.pythonExec,
- broadcast_vars, class_manifest)
+ broadcast_vars, self.ctx._javaAccumulator, class_manifest)
self._jrdd_val = python_rdd.asJavaRDD()
return self._jrdd_val
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 9a5151ea00..115cf28cc2 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -52,8 +52,13 @@ def read_int(stream):
raise EOFError
return struct.unpack("!i", length)[0]
+
+def write_int(value, stream):
+ stream.write(struct.pack("!i", value))
+
+
def write_with_length(obj, stream):
- stream.write(struct.pack("!i", len(obj)))
+ write_int(len(obj), stream)
stream.write(obj)
diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py
index 7e6ad3aa76..f6328c561f 100644
--- a/python/pyspark/shell.py
+++ b/python/pyspark/shell.py
@@ -1,7 +1,7 @@
"""
An interactive shell.
-This fle is designed to be launched as a PYTHONSTARTUP script.
+This file is designed to be launched as a PYTHONSTARTUP script.
"""
import os
from pyspark.context import SparkContext
@@ -14,4 +14,4 @@ print "Spark context avaiable as sc."
# which allows us to execute the user's PYTHONSTARTUP file:
_pythonstartup = os.environ.get('OLD_PYTHONSTARTUP')
if _pythonstartup and os.path.isfile(_pythonstartup):
- execfile(_pythonstartup)
+ execfile(_pythonstartup)
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 3d792bbaa2..b2b9288089 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -5,9 +5,10 @@ import sys
from base64 import standard_b64decode
# CloudPickler needs to be imported so that depicklers are registered using the
# copy_reg module.
+from pyspark.accumulators import _accumulatorRegistry
from pyspark.broadcast import Broadcast, _broadcastRegistry
from pyspark.cloudpickle import CloudPickler
-from pyspark.serializers import write_with_length, read_with_length, \
+from pyspark.serializers import write_with_length, read_with_length, write_int, \
read_long, read_int, dump_pickle, load_pickle, read_from_pickle_file
@@ -36,6 +37,10 @@ def main():
iterator = read_from_pickle_file(sys.stdin)
for obj in func(split_index, iterator):
write_with_length(dumps(obj), old_stdout)
+ # Mark the beginning of the accumulators section of the output
+ write_int(-1, old_stdout)
+ for aid, accum in _accumulatorRegistry.items():
+ write_with_length(dump_pickle((aid, accum._value)), old_stdout)
if __name__ == '__main__':
diff --git a/python/run-tests b/python/run-tests
index fcdd1e27a7..32470911f9 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -11,6 +11,9 @@ FAILED=$(($?||$FAILED))
$FWDIR/pyspark -m doctest pyspark/broadcast.py
FAILED=$(($?||$FAILED))
+$FWDIR/pyspark -m doctest pyspark/accumulators.py
+FAILED=$(($?||$FAILED))
+
if [[ $FAILED != 0 ]]; then
echo -en "\033[31m" # Red
echo "Had test failures; see logs."
diff --git a/run b/run
index 7acbf3f3ff..060856007f 100755
--- a/run
+++ b/run
@@ -66,6 +66,13 @@ BAGEL_DIR="$FWDIR/bagel"
STREAMING_DIR="$FWDIR/streaming"
PYSPARK_DIR="$FWDIR/python"
+# Exit if the user hasn't compiled Spark
+if [ ! -e "$REPL_DIR/target" ]; then
+ echo "Failed to find Spark classes in $REPL_DIR/target" >&2
+ echo "You need to compile Spark before running this program" >&2
+ exit 1
+fi
+
# Build up classpath
CLASSPATH="$SPARK_CLASSPATH"
CLASSPATH+=":$FWDIR/conf"