aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--core/src/main/scala/spark/api/python/PythonPartitioner.scala39
-rw-r--r--core/src/main/scala/spark/api/python/PythonRDD.scala248
-rw-r--r--core/src/main/scala/spark/broadcast/Broadcast.scala2
-rw-r--r--docs/README.md8
-rwxr-xr-xdocs/_layouts/global.html11
-rw-r--r--docs/_plugins/copy_api_dirs.rb17
-rw-r--r--docs/api.md1
-rw-r--r--docs/index.md15
-rw-r--r--docs/python-programming-guide.md111
-rw-r--r--docs/quick-start.md40
-rwxr-xr-xpyspark32
-rw-r--r--python/.gitignore2
-rw-r--r--python/epydoc.conf19
-rw-r--r--python/examples/kmeans.py54
-rwxr-xr-xpython/examples/logistic_regression.py57
-rw-r--r--python/examples/pi.py21
-rw-r--r--python/examples/transitive_closure.py50
-rw-r--r--python/examples/wordcount.py19
-rw-r--r--python/lib/PY4J_LICENSE.txt27
-rw-r--r--python/lib/PY4J_VERSION.txt1
-rw-r--r--python/lib/py4j0.7.eggbin0 -> 191756 bytes
-rw-r--r--python/lib/py4j0.7.jarbin0 -> 103286 bytes
-rw-r--r--python/pyspark/__init__.py20
-rw-r--r--python/pyspark/broadcast.py48
-rw-r--r--python/pyspark/cloudpickle.py974
-rw-r--r--python/pyspark/context.py159
-rw-r--r--python/pyspark/java_gateway.py38
-rw-r--r--python/pyspark/join.py92
-rw-r--r--python/pyspark/rdd.py723
-rw-r--r--python/pyspark/serializers.py78
-rw-r--r--python/pyspark/shell.py17
-rw-r--r--python/pyspark/worker.py42
-rwxr-xr-xpython/run-tests26
-rwxr-xr-xrun4
-rw-r--r--run2.cmd2
35 files changed, 2985 insertions, 12 deletions
diff --git a/core/src/main/scala/spark/api/python/PythonPartitioner.scala b/core/src/main/scala/spark/api/python/PythonPartitioner.scala
new file mode 100644
index 0000000000..648d9402b0
--- /dev/null
+++ b/core/src/main/scala/spark/api/python/PythonPartitioner.scala
@@ -0,0 +1,39 @@
+package spark.api.python
+
+import spark.Partitioner
+
+import java.util.Arrays
+
+/**
+ * A [[spark.Partitioner]] that performs handling of byte arrays, for use by the Python API.
+ */
+private[spark] class PythonPartitioner(override val numPartitions: Int) extends Partitioner {
+
+ override def getPartition(key: Any): Int = {
+ if (key == null) {
+ return 0
+ }
+ else {
+ val hashCode = {
+ if (key.isInstanceOf[Array[Byte]]) {
+ Arrays.hashCode(key.asInstanceOf[Array[Byte]])
+ } else {
+ key.hashCode()
+ }
+ }
+ val mod = hashCode % numPartitions
+ if (mod < 0) {
+ mod + numPartitions
+ } else {
+ mod // Guard against negative hash codes
+ }
+ }
+ }
+
+ override def equals(other: Any): Boolean = other match {
+ case h: PythonPartitioner =>
+ h.numPartitions == numPartitions
+ case _ =>
+ false
+ }
+}
diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala
new file mode 100644
index 0000000000..f431ef28d3
--- /dev/null
+++ b/core/src/main/scala/spark/api/python/PythonRDD.scala
@@ -0,0 +1,248 @@
+package spark.api.python
+
+import java.io._
+import java.util.{List => JList}
+
+import scala.collection.JavaConversions._
+import scala.io.Source
+
+import spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
+import spark.broadcast.Broadcast
+import spark._
+import spark.rdd.PipedRDD
+import java.util
+
+
+private[spark] class PythonRDD[T: ClassManifest](
+ parent: RDD[T],
+ command: Seq[String],
+ envVars: java.util.Map[String, String],
+ preservePartitoning: Boolean,
+ pythonExec: String,
+ broadcastVars: java.util.List[Broadcast[Array[Byte]]])
+ extends RDD[Array[Byte]](parent.context) {
+
+ // Similar to Runtime.exec(), if we are given a single string, split it into words
+ // using a standard StringTokenizer (i.e. by spaces)
+ def this(parent: RDD[T], command: String, envVars: java.util.Map[String, String],
+ preservePartitoning: Boolean, pythonExec: String,
+ broadcastVars: java.util.List[Broadcast[Array[Byte]]]) =
+ this(parent, PipedRDD.tokenize(command), envVars, preservePartitoning, pythonExec,
+ broadcastVars)
+
+ override def splits = parent.splits
+
+ override val dependencies = List(new OneToOneDependency(parent))
+
+ override val partitioner = if (preservePartitoning) parent.partitioner else None
+
+ override def compute(split: Split, context: TaskContext): Iterator[Array[Byte]] = {
+ val SPARK_HOME = new ProcessBuilder().environment().get("SPARK_HOME")
+
+ val pb = new ProcessBuilder(Seq(pythonExec, SPARK_HOME + "/python/pyspark/worker.py"))
+ // Add the environmental variables to the process.
+ val currentEnvVars = pb.environment()
+
+ for ((variable, value) <- envVars) {
+ currentEnvVars.put(variable, value)
+ }
+
+ val proc = pb.start()
+ val env = SparkEnv.get
+
+ // Start a thread to print the process's stderr to ours
+ new Thread("stderr reader for " + command) {
+ override def run() {
+ for (line <- Source.fromInputStream(proc.getErrorStream).getLines) {
+ System.err.println(line)
+ }
+ }
+ }.start()
+
+ // Start a thread to feed the process input from our parent's iterator
+ new Thread("stdin writer for " + command) {
+ override def run() {
+ SparkEnv.set(env)
+ val out = new PrintWriter(proc.getOutputStream)
+ val dOut = new DataOutputStream(proc.getOutputStream)
+ // Split index
+ dOut.writeInt(split.index)
+ // Broadcast variables
+ dOut.writeInt(broadcastVars.length)
+ for (broadcast <- broadcastVars) {
+ dOut.writeLong(broadcast.id)
+ dOut.writeInt(broadcast.value.length)
+ dOut.write(broadcast.value)
+ dOut.flush()
+ }
+ // Serialized user code
+ for (elem <- command) {
+ out.println(elem)
+ }
+ out.flush()
+ // Data values
+ for (elem <- parent.iterator(split, context)) {
+ PythonRDD.writeAsPickle(elem, dOut)
+ }
+ dOut.flush()
+ out.flush()
+ proc.getOutputStream.close()
+ }
+ }.start()
+
+ // Return an iterator that read lines from the process's stdout
+ val stream = new DataInputStream(proc.getInputStream)
+ return new Iterator[Array[Byte]] {
+ def next() = {
+ val obj = _nextObj
+ _nextObj = read()
+ obj
+ }
+
+ private def read() = {
+ try {
+ val length = stream.readInt()
+ val obj = new Array[Byte](length)
+ stream.readFully(obj)
+ obj
+ } catch {
+ case eof: EOFException => {
+ val exitStatus = proc.waitFor()
+ if (exitStatus != 0) {
+ throw new Exception("Subprocess exited with status " + exitStatus)
+ }
+ new Array[Byte](0)
+ }
+ case e => throw e
+ }
+ }
+
+ var _nextObj = read()
+
+ def hasNext = _nextObj.length != 0
+ }
+ }
+
+ val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
+}
+
+/**
+ * Form an RDD[(Array[Byte], Array[Byte])] from key-value pairs returned from Python.
+ * This is used by PySpark's shuffle operations.
+ */
+private class PairwiseRDD(prev: RDD[Array[Byte]]) extends
+ RDD[(Array[Byte], Array[Byte])](prev.context) {
+ override def splits = prev.splits
+ override val dependencies = List(new OneToOneDependency(prev))
+ override def compute(split: Split, context: TaskContext) =
+ prev.iterator(split, context).grouped(2).map {
+ case Seq(a, b) => (a, b)
+ case x => throw new Exception("PairwiseRDD: unexpected value: " + x)
+ }
+ val asJavaPairRDD : JavaPairRDD[Array[Byte], Array[Byte]] = JavaPairRDD.fromRDD(this)
+}
+
+private[spark] object PythonRDD {
+
+ /** Strips the pickle PROTO and STOP opcodes from the start and end of a pickle */
+ def stripPickle(arr: Array[Byte]) : Array[Byte] = {
+ arr.slice(2, arr.length - 1)
+ }
+
+ /**
+ * Write strings, pickled Python objects, or pairs of pickled objects to a data output stream.
+ * The data format is a 32-bit integer representing the pickled object's length (in bytes),
+ * followed by the pickled data.
+ *
+ * Pickle module:
+ *
+ * http://docs.python.org/2/library/pickle.html
+ *
+ * The pickle protocol is documented in the source of the `pickle` and `pickletools` modules:
+ *
+ * http://hg.python.org/cpython/file/2.6/Lib/pickle.py
+ * http://hg.python.org/cpython/file/2.6/Lib/pickletools.py
+ *
+ * @param elem the object to write
+ * @param dOut a data output stream
+ */
+ def writeAsPickle(elem: Any, dOut: DataOutputStream) {
+ if (elem.isInstanceOf[Array[Byte]]) {
+ val arr = elem.asInstanceOf[Array[Byte]]
+ dOut.writeInt(arr.length)
+ dOut.write(arr)
+ } else if (elem.isInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]) {
+ val t = elem.asInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]
+ val length = t._1.length + t._2.length - 3 - 3 + 4 // stripPickle() removes 3 bytes
+ dOut.writeInt(length)
+ dOut.writeByte(Pickle.PROTO)
+ dOut.writeByte(Pickle.TWO)
+ dOut.write(PythonRDD.stripPickle(t._1))
+ dOut.write(PythonRDD.stripPickle(t._2))
+ dOut.writeByte(Pickle.TUPLE2)
+ dOut.writeByte(Pickle.STOP)
+ } else if (elem.isInstanceOf[String]) {
+ // For uniformity, strings are wrapped into Pickles.
+ val s = elem.asInstanceOf[String].getBytes("UTF-8")
+ val length = 2 + 1 + 4 + s.length + 1
+ dOut.writeInt(length)
+ dOut.writeByte(Pickle.PROTO)
+ dOut.writeByte(Pickle.TWO)
+ dOut.write(Pickle.BINUNICODE)
+ dOut.writeInt(Integer.reverseBytes(s.length))
+ dOut.write(s)
+ dOut.writeByte(Pickle.STOP)
+ } else {
+ throw new Exception("Unexpected RDD type")
+ }
+ }
+
+ def readRDDFromPickleFile(sc: JavaSparkContext, filename: String, parallelism: Int) :
+ JavaRDD[Array[Byte]] = {
+ val file = new DataInputStream(new FileInputStream(filename))
+ val objs = new collection.mutable.ArrayBuffer[Array[Byte]]
+ try {
+ while (true) {
+ val length = file.readInt()
+ val obj = new Array[Byte](length)
+ file.readFully(obj)
+ objs.append(obj)
+ }
+ } catch {
+ case eof: EOFException => {}
+ case e => throw e
+ }
+ JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
+ }
+
+ def writeIteratorToPickleFile[T](items: java.util.Iterator[T], filename: String) {
+ val file = new DataOutputStream(new FileOutputStream(filename))
+ for (item <- items) {
+ writeAsPickle(item, file)
+ }
+ file.close()
+ }
+
+ def takePartition[T](rdd: RDD[T], partition: Int): java.util.Iterator[T] =
+ rdd.context.runJob(rdd, ((x: Iterator[T]) => x), Seq(partition), true).head
+}
+
+private object Pickle {
+ val PROTO: Byte = 0x80.toByte
+ val TWO: Byte = 0x02.toByte
+ val BINUNICODE: Byte = 'X'
+ val STOP: Byte = '.'
+ val TUPLE2: Byte = 0x86.toByte
+ val EMPTY_LIST: Byte = ']'
+ val MARK: Byte = '('
+ val APPENDS: Byte = 'e'
+}
+
+private class ExtractValue extends spark.api.java.function.Function[(Array[Byte],
+ Array[Byte]), Array[Byte]] {
+ override def call(pair: (Array[Byte], Array[Byte])) : Array[Byte] = pair._2
+}
+
+private class BytesToString extends spark.api.java.function.Function[Array[Byte], String] {
+ override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8")
+}
diff --git a/core/src/main/scala/spark/broadcast/Broadcast.scala b/core/src/main/scala/spark/broadcast/Broadcast.scala
index 6055bfd045..2ffe7f741d 100644
--- a/core/src/main/scala/spark/broadcast/Broadcast.scala
+++ b/core/src/main/scala/spark/broadcast/Broadcast.scala
@@ -5,7 +5,7 @@ import java.util.concurrent.atomic.AtomicLong
import spark._
-abstract class Broadcast[T](id: Long) extends Serializable {
+abstract class Broadcast[T](private[spark] val id: Long) extends Serializable {
def value: T
// We cannot have an abstract readObject here due to some weird issues with
diff --git a/docs/README.md b/docs/README.md
index 092153070e..887f407f18 100644
--- a/docs/README.md
+++ b/docs/README.md
@@ -25,10 +25,12 @@ To mark a block of code in your markdown to be syntax highlighted by jekyll duri
// supported languages too.
{% endhighlight %}
-## Scaladoc
+## API Docs (Scaladoc and Epydoc)
You can build just the Spark scaladoc by running `sbt/sbt doc` from the SPARK_PROJECT_ROOT directory.
-When you run `jekyll` in the docs directory, it will also copy over the scala doc for the various Spark subprojects into the docs directory (and then also into the _site directory). We use a jekyll plugin to run `sbt/sbt doc` before building the site so if you haven't run it (recently) it may take some time as it generates all of the scaladoc.
+Similarly, you can build just the PySpark epydoc by running `epydoc --config epydoc.conf` from the SPARK_PROJECT_ROOT/pyspark directory.
-NOTE: To skip the step of building and copying over the scaladoc when you build the docs, run `SKIP_SCALADOC=1 jekyll`.
+When you run `jekyll` in the docs directory, it will also copy over the scaladoc for the various Spark subprojects into the docs directory (and then also into the _site directory). We use a jekyll plugin to run `sbt/sbt doc` before building the site so if you haven't run it (recently) it may take some time as it generates all of the scaladoc. The jekyll plugin also generates the PySpark docs using [epydoc](http://epydoc.sourceforge.net/).
+
+NOTE: To skip the step of building and copying over the scaladoc when you build the docs, run `SKIP_SCALADOC=1 jekyll`. Similarly, `SKIP_EPYDOC=1 jekyll` will skip PySpark API doc generation.
diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html
index 7244ab6fc9..9804d449fc 100755
--- a/docs/_layouts/global.html
+++ b/docs/_layouts/global.html
@@ -47,10 +47,17 @@
<li><a href="quick-start.html">Quick Start</a></li>
<li><a href="scala-programming-guide.html">Scala</a></li>
<li><a href="java-programming-guide.html">Java</a></li>
+ <li><a href="python-programming-guide.html">Python</a></li>
+ </ul>
+ </li>
+
+ <li class="dropdown">
+ <a href="#" class="dropdown-toggle" data-toggle="dropdown">API<b class="caret"></b></a>
+ <ul class="dropdown-menu">
+ <li><a href="api/core/index.html">Scala/Java (Scaladoc)</a></li>
+ <li><a href="api/pyspark/index.html">Python (Epydoc)</a></li>
</ul>
</li>
-
- <li><a href="api/core/index.html">API (Scaladoc)</a></li>
<li class="dropdown">
<a href="#" class="dropdown-toggle" data-toggle="dropdown">Deploying<b class="caret"></b></a>
diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb
index e61c105449..c9ce589c1b 100644
--- a/docs/_plugins/copy_api_dirs.rb
+++ b/docs/_plugins/copy_api_dirs.rb
@@ -28,3 +28,20 @@ if ENV['SKIP_SCALADOC'] != '1'
cp_r(source + "/.", dest)
end
end
+
+if ENV['SKIP_EPYDOC'] != '1'
+ puts "Moving to python directory and building epydoc."
+ cd("../python")
+ puts `epydoc --config epydoc.conf`
+
+ puts "Moving back into docs dir."
+ cd("../docs")
+
+ puts "echo making directory pyspark"
+ mkdir_p "pyspark"
+
+ puts "cp -r ../python/docs/. api/pyspark"
+ cp_r("../python/docs/.", "api/pyspark")
+
+ cd("..")
+end
diff --git a/docs/api.md b/docs/api.md
index 43548b223c..b9c93ac5e8 100644
--- a/docs/api.md
+++ b/docs/api.md
@@ -8,3 +8,4 @@ Here you can find links to the Scaladoc generated for the Spark sbt subprojects.
- [Core](api/core/index.html)
- [Examples](api/examples/index.html)
- [Bagel](api/bagel/index.html)
+- [PySpark](api/pyspark/index.html)
diff --git a/docs/index.md b/docs/index.md
index ed9953a590..848b585333 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -7,11 +7,11 @@ title: Spark Overview
TODO(andyk): Rewrite to make the Java API a first class part of the story.
{% endcomment %}
-Spark is a MapReduce-like cluster computing framework designed for low-latency iterative jobs and interactive use from an
-interpreter. It provides clean, language-integrated APIs in Scala and Java, with a rich array of parallel operators. Spark can
-run on top of the [Apache Mesos](http://incubator.apache.org/mesos/) cluster manager,
+Spark is a MapReduce-like cluster computing framework designed for low-latency iterative jobs and interactive use from an interpreter.
+It provides clean, language-integrated APIs in [Scala](scala-programming-guide.html), [Java](java-programming-guide.html), and [Python](python-programming-guide.html), with a rich array of parallel operators.
+Spark can run on top of the [Apache Mesos](http://incubator.apache.org/mesos/) cluster manager,
[Hadoop YARN](http://hadoop.apache.org/docs/r2.0.1-alpha/hadoop-yarn/hadoop-yarn-site/YARN.html),
-Amazon EC2, or without an independent resource manager ("standalone mode").
+Amazon EC2, or without an independent resource manager ("standalone mode").
# Downloading
@@ -59,6 +59,12 @@ of `project/SparkBuild.scala`, then rebuilding Spark (`sbt/sbt clean compile`).
* [Quick Start](quick-start.html): a quick introduction to the Spark API; start here!
* [Spark Programming Guide](scala-programming-guide.html): an overview of Spark concepts, and details on the Scala API
* [Java Programming Guide](java-programming-guide.html): using Spark from Java
+* [Python Programming Guide](python-programming-guide.html): using Spark from Python
+
+**API Docs:**
+
+* [Java/Scala (Scaladoc)](api/core/index.html)
+* [Python (Epydoc)](api/pyspark/index.html)
**Deployment guides:**
@@ -72,7 +78,6 @@ of `project/SparkBuild.scala`, then rebuilding Spark (`sbt/sbt clean compile`).
* [Configuration](configuration.html): customize Spark via its configuration system
* [Tuning Guide](tuning.html): best practices to optimize performance and memory use
-* [API Docs (Scaladoc)](api/core/index.html)
* [Bagel](bagel-programming-guide.html): an implementation of Google's Pregel on Spark
* [Contributing to Spark](contributing-to-spark.html)
diff --git a/docs/python-programming-guide.md b/docs/python-programming-guide.md
new file mode 100644
index 0000000000..78ef310a00
--- /dev/null
+++ b/docs/python-programming-guide.md
@@ -0,0 +1,111 @@
+---
+layout: global
+title: Python Programming Guide
+---
+
+
+The Spark Python API (PySpark) exposes most of the Spark features available in the Scala version to Python.
+To learn the basics of Spark, we recommend reading through the
+[Scala programming guide](scala-programming-guide.html) first; it should be
+easy to follow even if you don't know Scala.
+This guide will show how to use the Spark features described there in Python.
+
+# Key Differences in the Python API
+
+There are a few key differences between the Python and Scala APIs:
+
+* Python is dynamically typed, so RDDs can hold objects of different types.
+* PySpark does not currently support the following Spark features:
+ - Accumulators
+ - Special functions on RDDs of doubles, such as `mean` and `stdev`
+ - `lookup`
+ - `persist` at storage levels other than `MEMORY_ONLY`
+ - `sample`
+ - `sort`
+
+In PySpark, RDDs support the same methods as their Scala counterparts but take Python functions and return Python collection types.
+Short functions can be passed to RDD methods using Python's [`lambda`](http://www.diveintopython.net/power_of_introspection/lambda_functions.html) syntax:
+
+{% highlight python %}
+logData = sc.textFile(logFile).cache()
+errors = logData.filter(lambda s: 'ERROR' in s.split())
+{% endhighlight %}
+
+You can also pass functions that are defined using the `def` keyword; this is useful for more complicated functions that cannot be expressed using `lambda`:
+
+{% highlight python %}
+def is_error(line):
+ return 'ERROR' in line.split()
+errors = logData.filter(is_error)
+{% endhighlight %}
+
+Functions can access objects in enclosing scopes, although modifications to those objects within RDD methods will not be propagated to other tasks:
+
+{% highlight python %}
+error_keywords = ["Exception", "Error"]
+def is_error(line):
+ words = line.split()
+ return any(keyword in words for keyword in error_keywords)
+errors = logData.filter(is_error)
+{% endhighlight %}
+
+PySpark will automatically ship these functions to workers, along with any objects that they reference.
+Instances of classes will be serialized and shipped to workers by PySpark, but classes themselves cannot be automatically distributed to workers.
+The [Standalone Use](#standalone-use) section describes how to ship code dependencies to workers.
+
+# Installing and Configuring PySpark
+
+PySpark requires Python 2.6 or higher.
+PySpark jobs are executed using a standard cPython interpreter in order to support Python modules that use C extensions.
+We have not tested PySpark with Python 3 or with alternative Python interpreters, such as [PyPy](http://pypy.org/) or [Jython](http://www.jython.org/).
+By default, PySpark's scripts will run programs using `python`; an alternate Python executable may be specified by setting the `PYSPARK_PYTHON` environment variable in `conf/spark-env.sh`.
+
+All of PySpark's library dependencies, including [Py4J](http://py4j.sourceforge.net/), are bundled with PySpark and automatically imported.
+
+Standalone PySpark jobs should be run using the `pyspark` script, which automatically configures the Java and Python environment using the settings in `conf/spark-env.sh`.
+The script automatically adds the `pyspark` package to the `PYTHONPATH`.
+
+
+# Interactive Use
+
+The `pyspark` script launches a Python interpreter that is configured to run PySpark jobs.
+When run without any input files, `pyspark` launches a shell that can be used explore data interactively, which is a simple way to learn the API:
+
+{% highlight python %}
+>>> words = sc.textFile("/usr/share/dict/words")
+>>> words.filter(lambda w: w.startswith("spar")).take(5)
+[u'spar', u'sparable', u'sparada', u'sparadrap', u'sparagrass']
+{% endhighlight %}
+
+By default, the `pyspark` shell creates SparkContext that runs jobs locally.
+To connect to a non-local cluster, set the `MASTER` environment variable.
+For example, to use the `pyspark` shell with a [standalone Spark cluster](spark-standalone.html):
+
+{% highlight shell %}
+$ MASTER=spark://IP:PORT ./pyspark
+{% endhighlight %}
+
+
+# Standalone Use
+
+PySpark can also be used from standalone Python scripts by creating a SparkContext in your script and running the script using `pyspark`.
+The Quick Start guide includes a [complete example](quick-start.html#a-standalone-job-in-python) of a standalone Python job.
+
+Code dependencies can be deployed by listing them in the `pyFiles` option in the SparkContext constructor:
+
+{% highlight python %}
+from pyspark import SparkContext
+sc = SparkContext("local", "Job Name", pyFiles=['MyFile.py', 'lib.zip', 'app.egg'])
+{% endhighlight %}
+
+Files listed here will be added to the `PYTHONPATH` and shipped to remote worker machines.
+Code dependencies can be added to an existing SparkContext using its `addPyFile()` method.
+
+# Where to Go from Here
+
+PySpark includes several sample programs using the Python API in `python/examples`.
+You can run them by passing the files to the `pyspark` script -- for example `./pyspark python/examples/wordcount.py`.
+Each example program prints usage help when run without any arguments.
+
+We currently provide [API documentation](api/pyspark/index.html) for the Python API as Epydoc.
+Many of the RDD method descriptions contain [doctests](http://docs.python.org/2/library/doctest.html) that provide additional usage examples.
diff --git a/docs/quick-start.md b/docs/quick-start.md
index d46dc2da3f..a4c4c9a8fb 100644
--- a/docs/quick-start.md
+++ b/docs/quick-start.md
@@ -6,7 +6,8 @@ title: Quick Start
* This will become a table of contents (this text will be scraped).
{:toc}
-This tutorial provides a quick introduction to using Spark. We will first introduce the API through Spark's interactive Scala shell (don't worry if you don't know Scala -- you will not need much for this), then show how to write standalone jobs in Scala and Java. See the [programming guide](scala-programming-guide.html) for a more complete reference.
+This tutorial provides a quick introduction to using Spark. We will first introduce the API through Spark's interactive Scala shell (don't worry if you don't know Scala -- you will not need much for this), then show how to write standalone jobs in Scala, Java, and Python.
+See the [programming guide](scala-programming-guide.html) for a more complete reference.
To follow along with this guide, you only need to have successfully built Spark on one machine. Simply go into your Spark directory and run:
@@ -240,3 +241,40 @@ Lines with a: 8422, Lines with b: 1836
{% endhighlight %}
This example only runs the job locally; for a tutorial on running jobs across several machines, see the [Standalone Mode](spark-standalone.html) documentation, and consider using a distributed input source, such as HDFS.
+
+# A Standalone Job In Python
+Now we will show how to write a standalone job using the Python API (PySpark).
+
+As an example, we'll create a simple Spark job, `SimpleJob.py`:
+
+{% highlight python %}
+"""SimpleJob.py"""
+from pyspark import SparkContext
+
+logFile = "/var/log/syslog" # Should be some file on your system
+sc = SparkContext("local", "Simple job")
+logData = sc.textFile(logFile).cache()
+
+numAs = logData.filter(lambda s: 'a' in s).count()
+numBs = logData.filter(lambda s: 'b' in s).count()
+
+print "Lines with a: %i, lines with b: %i" % (numAs, numBs)
+{% endhighlight %}
+
+
+This job simply counts the number of lines containing 'a' and the number containing 'b' in a system log file.
+Like in the Scala and Java examples, we use a SparkContext to create RDDs.
+We can pass Python functions to Spark, which are automatically serialized along with any variables that they reference.
+For jobs that use custom classes or third-party libraries, we can add those code dependencies to SparkContext to ensure that they will be available on remote machines; this is described in more detail in the [Python programming guide](python-programming-guide).
+`SimpleJob` is simple enough that we do not need to specify any code dependencies.
+
+We can run this job using the `pyspark` script:
+
+{% highlight python %}
+$ cd $SPARK_HOME
+$ ./pyspark SimpleJob.py
+...
+Lines with a: 8422, Lines with b: 1836
+{% endhighlight python %}
+
+This example only runs the job locally; for a tutorial on running jobs across several machines, see the [Standalone Mode](spark-standalone.html) documentation, and consider using a distributed input source, such as HDFS.
diff --git a/pyspark b/pyspark
new file mode 100755
index 0000000000..9e89d51ba2
--- /dev/null
+++ b/pyspark
@@ -0,0 +1,32 @@
+#!/usr/bin/env bash
+
+# Figure out where the Scala framework is installed
+FWDIR="$(cd `dirname $0`; pwd)"
+
+# Export this as SPARK_HOME
+export SPARK_HOME="$FWDIR"
+
+# Load environment variables from conf/spark-env.sh, if it exists
+if [ -e $FWDIR/conf/spark-env.sh ] ; then
+ . $FWDIR/conf/spark-env.sh
+fi
+
+# Figure out which Python executable to use
+if [ -z "$PYSPARK_PYTHON" ] ; then
+ PYSPARK_PYTHON="python"
+fi
+export PYSPARK_PYTHON
+
+# Add the PySpark classes to the Python path:
+export PYTHONPATH=$SPARK_HOME/python/:$PYTHONPATH
+
+# Load the PySpark shell.py script when ./pyspark is used interactively:
+export OLD_PYTHONSTARTUP=$PYTHONSTARTUP
+export PYTHONSTARTUP=$FWDIR/python/pyspark/shell.py
+
+# Launch with `scala` by default:
+if [[ "$SPARK_LAUNCH_WITH_SCALA" != "0" ]] ; then
+ export SPARK_LAUNCH_WITH_SCALA=1
+fi
+
+exec "$PYSPARK_PYTHON" "$@"
diff --git a/python/.gitignore b/python/.gitignore
new file mode 100644
index 0000000000..5c56e638f9
--- /dev/null
+++ b/python/.gitignore
@@ -0,0 +1,2 @@
+*.pyc
+docs/
diff --git a/python/epydoc.conf b/python/epydoc.conf
new file mode 100644
index 0000000000..91ac984ba2
--- /dev/null
+++ b/python/epydoc.conf
@@ -0,0 +1,19 @@
+[epydoc] # Epydoc section marker (required by ConfigParser)
+
+# Information about the project.
+name: PySpark
+url: http://spark-project.org
+
+# The list of modules to document. Modules can be named using
+# dotted names, module filenames, or package directory names.
+# This option may be repeated.
+modules: pyspark
+
+# Write html output to the directory "apidocs"
+output: html
+target: docs/
+
+private: no
+
+exclude: pyspark.cloudpickle pyspark.worker pyspark.join pyspark.serializers
+ pyspark.java_gateway pyspark.examples pyspark.shell
diff --git a/python/examples/kmeans.py b/python/examples/kmeans.py
new file mode 100644
index 0000000000..72cf9f88c6
--- /dev/null
+++ b/python/examples/kmeans.py
@@ -0,0 +1,54 @@
+"""
+This example requires numpy (http://www.numpy.org/)
+"""
+import sys
+
+import numpy as np
+from pyspark import SparkContext
+
+
+def parseVector(line):
+ return np.array([float(x) for x in line.split(' ')])
+
+
+def closestPoint(p, centers):
+ bestIndex = 0
+ closest = float("+inf")
+ for i in range(len(centers)):
+ tempDist = np.sum((p - centers[i]) ** 2)
+ if tempDist < closest:
+ closest = tempDist
+ bestIndex = i
+ return bestIndex
+
+
+if __name__ == "__main__":
+ if len(sys.argv) < 5:
+ print >> sys.stderr, \
+ "Usage: PythonKMeans <master> <file> <k> <convergeDist>"
+ exit(-1)
+ sc = SparkContext(sys.argv[1], "PythonKMeans")
+ lines = sc.textFile(sys.argv[2])
+ data = lines.map(parseVector).cache()
+ K = int(sys.argv[3])
+ convergeDist = float(sys.argv[4])
+
+ # TODO: change this after we port takeSample()
+ #kPoints = data.takeSample(False, K, 34)
+ kPoints = data.take(K)
+ tempDist = 1.0
+
+ while tempDist > convergeDist:
+ closest = data.map(
+ lambda p : (closestPoint(p, kPoints), (p, 1)))
+ pointStats = closest.reduceByKey(
+ lambda (x1, y1), (x2, y2): (x1 + x2, y1 + y2))
+ newPoints = pointStats.map(
+ lambda (x, (y, z)): (x, y / z)).collect()
+
+ tempDist = sum(np.sum((kPoints[x] - y) ** 2) for (x, y) in newPoints)
+
+ for (x, y) in newPoints:
+ kPoints[x] = y
+
+ print "Final centers: " + str(kPoints)
diff --git a/python/examples/logistic_regression.py b/python/examples/logistic_regression.py
new file mode 100755
index 0000000000..f13698a86f
--- /dev/null
+++ b/python/examples/logistic_regression.py
@@ -0,0 +1,57 @@
+"""
+This example requires numpy (http://www.numpy.org/)
+"""
+from collections import namedtuple
+from math import exp
+from os.path import realpath
+import sys
+
+import numpy as np
+from pyspark import SparkContext
+
+
+N = 100000 # Number of data points
+D = 10 # Number of dimensions
+R = 0.7 # Scaling factor
+ITERATIONS = 5
+np.random.seed(42)
+
+
+DataPoint = namedtuple("DataPoint", ['x', 'y'])
+from lr import DataPoint # So that DataPoint is properly serialized
+
+
+def generateData():
+ def generatePoint(i):
+ y = -1 if i % 2 == 0 else 1
+ x = np.random.normal(size=D) + (y * R)
+ return DataPoint(x, y)
+ return [generatePoint(i) for i in range(N)]
+
+
+if __name__ == "__main__":
+ if len(sys.argv) == 1:
+ print >> sys.stderr, \
+ "Usage: PythonLR <master> [<slices>]"
+ exit(-1)
+ sc = SparkContext(sys.argv[1], "PythonLR", pyFiles=[realpath(__file__)])
+ slices = int(sys.argv[2]) if len(sys.argv) > 2 else 2
+ points = sc.parallelize(generateData(), slices).cache()
+
+ # Initialize w to a random value
+ w = 2 * np.random.ranf(size=D) - 1
+ print "Initial w: " + str(w)
+
+ def add(x, y):
+ x += y
+ return x
+
+ for i in range(1, ITERATIONS + 1):
+ print "On iteration %i" % i
+
+ gradient = points.map(lambda p:
+ (1.0 / (1.0 + exp(-p.y * np.dot(w, p.x)))) * p.y * p.x
+ ).reduce(add)
+ w -= gradient
+
+ print "Final w: " + str(w)
diff --git a/python/examples/pi.py b/python/examples/pi.py
new file mode 100644
index 0000000000..127cba029b
--- /dev/null
+++ b/python/examples/pi.py
@@ -0,0 +1,21 @@
+import sys
+from random import random
+from operator import add
+
+from pyspark import SparkContext
+
+
+if __name__ == "__main__":
+ if len(sys.argv) == 1:
+ print >> sys.stderr, \
+ "Usage: PythonPi <master> [<slices>]"
+ exit(-1)
+ sc = SparkContext(sys.argv[1], "PythonPi")
+ slices = int(sys.argv[2]) if len(sys.argv) > 2 else 2
+ n = 100000 * slices
+ def f(_):
+ x = random() * 2 - 1
+ y = random() * 2 - 1
+ return 1 if x ** 2 + y ** 2 < 1 else 0
+ count = sc.parallelize(xrange(1, n+1), slices).map(f).reduce(add)
+ print "Pi is roughly %f" % (4.0 * count / n)
diff --git a/python/examples/transitive_closure.py b/python/examples/transitive_closure.py
new file mode 100644
index 0000000000..73f7f8fbaf
--- /dev/null
+++ b/python/examples/transitive_closure.py
@@ -0,0 +1,50 @@
+import sys
+from random import Random
+
+from pyspark import SparkContext
+
+numEdges = 200
+numVertices = 100
+rand = Random(42)
+
+
+def generateGraph():
+ edges = set()
+ while len(edges) < numEdges:
+ src = rand.randrange(0, numEdges)
+ dst = rand.randrange(0, numEdges)
+ if src != dst:
+ edges.add((src, dst))
+ return edges
+
+
+if __name__ == "__main__":
+ if len(sys.argv) == 1:
+ print >> sys.stderr, \
+ "Usage: PythonTC <master> [<slices>]"
+ exit(-1)
+ sc = SparkContext(sys.argv[1], "PythonTC")
+ slices = sys.argv[2] if len(sys.argv) > 2 else 2
+ tc = sc.parallelize(generateGraph(), slices).cache()
+
+ # Linear transitive closure: each round grows paths by one edge,
+ # by joining the graph's edges with the already-discovered paths.
+ # e.g. join the path (y, z) from the TC with the edge (x, y) from
+ # the graph to obtain the path (x, z).
+
+ # Because join() joins on keys, the edges are stored in reversed order.
+ edges = tc.map(lambda (x, y): (y, x))
+
+ oldCount = 0L
+ nextCount = tc.count()
+ while True:
+ oldCount = nextCount
+ # Perform the join, obtaining an RDD of (y, (z, x)) pairs,
+ # then project the result to obtain the new (x, z) paths.
+ new_edges = tc.join(edges).map(lambda (_, (a, b)): (b, a))
+ tc = tc.union(new_edges).distinct().cache()
+ nextCount = tc.count()
+ if nextCount == oldCount:
+ break
+
+ print "TC has %i edges" % tc.count()
diff --git a/python/examples/wordcount.py b/python/examples/wordcount.py
new file mode 100644
index 0000000000..857160624b
--- /dev/null
+++ b/python/examples/wordcount.py
@@ -0,0 +1,19 @@
+import sys
+from operator import add
+
+from pyspark import SparkContext
+
+
+if __name__ == "__main__":
+ if len(sys.argv) < 3:
+ print >> sys.stderr, \
+ "Usage: PythonWordCount <master> <file>"
+ exit(-1)
+ sc = SparkContext(sys.argv[1], "PythonWordCount")
+ lines = sc.textFile(sys.argv[2], 1)
+ counts = lines.flatMap(lambda x: x.split(' ')) \
+ .map(lambda x: (x, 1)) \
+ .reduceByKey(add)
+ output = counts.collect()
+ for (word, count) in output:
+ print "%s : %i" % (word, count)
diff --git a/python/lib/PY4J_LICENSE.txt b/python/lib/PY4J_LICENSE.txt
new file mode 100644
index 0000000000..a70279ca14
--- /dev/null
+++ b/python/lib/PY4J_LICENSE.txt
@@ -0,0 +1,27 @@
+
+Copyright (c) 2009-2011, Barthelemy Dagenais All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+- Redistributions of source code must retain the above copyright notice, this
+list of conditions and the following disclaimer.
+
+- Redistributions in binary form must reproduce the above copyright notice,
+this list of conditions and the following disclaimer in the documentation
+and/or other materials provided with the distribution.
+
+- The name of the author may not be used to endorse or promote products
+derived from this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+POSSIBILITY OF SUCH DAMAGE.
diff --git a/python/lib/PY4J_VERSION.txt b/python/lib/PY4J_VERSION.txt
new file mode 100644
index 0000000000..04a0cd52a8
--- /dev/null
+++ b/python/lib/PY4J_VERSION.txt
@@ -0,0 +1 @@
+b7924aabe9c5e63f0a4d8bbd17019534c7ec014e
diff --git a/python/lib/py4j0.7.egg b/python/lib/py4j0.7.egg
new file mode 100644
index 0000000000..f8a339d8ee
--- /dev/null
+++ b/python/lib/py4j0.7.egg
Binary files differ
diff --git a/python/lib/py4j0.7.jar b/python/lib/py4j0.7.jar
new file mode 100644
index 0000000000..73b7ddb7d1
--- /dev/null
+++ b/python/lib/py4j0.7.jar
Binary files differ
diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py
new file mode 100644
index 0000000000..c595ae0842
--- /dev/null
+++ b/python/pyspark/__init__.py
@@ -0,0 +1,20 @@
+"""
+PySpark is a Python API for Spark.
+
+Public classes:
+
+ - L{SparkContext<pyspark.context.SparkContext>}
+ Main entry point for Spark functionality.
+ - L{RDD<pyspark.rdd.RDD>}
+ A Resilient Distributed Dataset (RDD), the basic abstraction in Spark.
+"""
+import sys
+import os
+sys.path.insert(0, os.path.join(os.environ["SPARK_HOME"], "python/lib/py4j0.7.egg"))
+
+
+from pyspark.context import SparkContext
+from pyspark.rdd import RDD
+
+
+__all__ = ["SparkContext", "RDD"]
diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py
new file mode 100644
index 0000000000..93876fa738
--- /dev/null
+++ b/python/pyspark/broadcast.py
@@ -0,0 +1,48 @@
+"""
+>>> from pyspark.context import SparkContext
+>>> sc = SparkContext('local', 'test')
+>>> b = sc.broadcast([1, 2, 3, 4, 5])
+>>> b.value
+[1, 2, 3, 4, 5]
+
+>>> from pyspark.broadcast import _broadcastRegistry
+>>> _broadcastRegistry[b.bid] = b
+>>> from cPickle import dumps, loads
+>>> loads(dumps(b)).value
+[1, 2, 3, 4, 5]
+
+>>> sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect()
+[1, 2, 3, 4, 5, 1, 2, 3, 4, 5]
+
+>>> large_broadcast = sc.broadcast(list(range(10000)))
+"""
+# Holds broadcasted data received from Java, keyed by its id.
+_broadcastRegistry = {}
+
+
+def _from_id(bid):
+ from pyspark.broadcast import _broadcastRegistry
+ if bid not in _broadcastRegistry:
+ raise Exception("Broadcast variable '%s' not loaded!" % bid)
+ return _broadcastRegistry[bid]
+
+
+class Broadcast(object):
+ def __init__(self, bid, value, java_broadcast=None, pickle_registry=None):
+ self.value = value
+ self.bid = bid
+ self._jbroadcast = java_broadcast
+ self._pickle_registry = pickle_registry
+
+ def __reduce__(self):
+ self._pickle_registry.add(self)
+ return (_from_id, (self.bid, ))
+
+
+def _test():
+ import doctest
+ doctest.testmod()
+
+
+if __name__ == "__main__":
+ _test()
diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py
new file mode 100644
index 0000000000..6a7c23a069
--- /dev/null
+++ b/python/pyspark/cloudpickle.py
@@ -0,0 +1,974 @@
+"""
+This class is defined to override standard pickle functionality
+
+The goals of it follow:
+-Serialize lambdas and nested functions to compiled byte code
+-Deal with main module correctly
+-Deal with other non-serializable objects
+
+It does not include an unpickler, as standard python unpickling suffices.
+
+This module was extracted from the `cloud` package, developed by `PiCloud, Inc.
+<http://www.picloud.com>`_.
+
+Copyright (c) 2012, Regents of the University of California.
+Copyright (c) 2009 `PiCloud, Inc. <http://www.picloud.com>`_.
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions
+are met:
+ * Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+ * Neither the name of the University of California, Berkeley nor the
+ names of its contributors may be used to endorse or promote
+ products derived from this software without specific prior written
+ permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
+TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+"""
+
+
+import operator
+import os
+import pickle
+import struct
+import sys
+import types
+from functools import partial
+import itertools
+from copy_reg import _extension_registry, _inverted_registry, _extension_cache
+import new
+import dis
+import traceback
+
+#relevant opcodes
+STORE_GLOBAL = chr(dis.opname.index('STORE_GLOBAL'))
+DELETE_GLOBAL = chr(dis.opname.index('DELETE_GLOBAL'))
+LOAD_GLOBAL = chr(dis.opname.index('LOAD_GLOBAL'))
+GLOBAL_OPS = [STORE_GLOBAL, DELETE_GLOBAL, LOAD_GLOBAL]
+
+HAVE_ARGUMENT = chr(dis.HAVE_ARGUMENT)
+EXTENDED_ARG = chr(dis.EXTENDED_ARG)
+
+import logging
+cloudLog = logging.getLogger("Cloud.Transport")
+
+try:
+ import ctypes
+except (MemoryError, ImportError):
+ logging.warning('Exception raised on importing ctypes. Likely python bug.. some functionality will be disabled', exc_info = True)
+ ctypes = None
+ PyObject_HEAD = None
+else:
+
+ # for reading internal structures
+ PyObject_HEAD = [
+ ('ob_refcnt', ctypes.c_size_t),
+ ('ob_type', ctypes.c_void_p),
+ ]
+
+
+try:
+ from cStringIO import StringIO
+except ImportError:
+ from StringIO import StringIO
+
+# These helper functions were copied from PiCloud's util module.
+def islambda(func):
+ return getattr(func,'func_name') == '<lambda>'
+
+def xrange_params(xrangeobj):
+ """Returns a 3 element tuple describing the xrange start, step, and len
+ respectively
+
+ Note: Only guarentees that elements of xrange are the same. parameters may
+ be different.
+ e.g. xrange(1,1) is interpretted as xrange(0,0); both behave the same
+ though w/ iteration
+ """
+
+ xrange_len = len(xrangeobj)
+ if not xrange_len: #empty
+ return (0,1,0)
+ start = xrangeobj[0]
+ if xrange_len == 1: #one element
+ return start, 1, 1
+ return (start, xrangeobj[1] - xrangeobj[0], xrange_len)
+
+#debug variables intended for developer use:
+printSerialization = False
+printMemoization = False
+
+useForcedImports = True #Should I use forced imports for tracking?
+
+
+
+class CloudPickler(pickle.Pickler):
+
+ dispatch = pickle.Pickler.dispatch.copy()
+ savedForceImports = False
+ savedDjangoEnv = False #hack tro transport django environment
+
+ def __init__(self, file, protocol=None, min_size_to_save= 0):
+ pickle.Pickler.__init__(self,file,protocol)
+ self.modules = set() #set of modules needed to depickle
+ self.globals_ref = {} # map ids to dictionary. used to ensure that functions can share global env
+
+ def dump(self, obj):
+ # note: not thread safe
+ # minimal side-effects, so not fixing
+ recurse_limit = 3000
+ base_recurse = sys.getrecursionlimit()
+ if base_recurse < recurse_limit:
+ sys.setrecursionlimit(recurse_limit)
+ self.inject_addons()
+ try:
+ return pickle.Pickler.dump(self, obj)
+ except RuntimeError, e:
+ if 'recursion' in e.args[0]:
+ msg = """Could not pickle object as excessively deep recursion required.
+ Try _fast_serialization=2 or contact PiCloud support"""
+ raise pickle.PicklingError(msg)
+ finally:
+ new_recurse = sys.getrecursionlimit()
+ if new_recurse == recurse_limit:
+ sys.setrecursionlimit(base_recurse)
+
+ def save_buffer(self, obj):
+ """Fallback to save_string"""
+ pickle.Pickler.save_string(self,str(obj))
+ dispatch[buffer] = save_buffer
+
+ #block broken objects
+ def save_unsupported(self, obj, pack=None):
+ raise pickle.PicklingError("Cannot pickle objects of type %s" % type(obj))
+ dispatch[types.GeneratorType] = save_unsupported
+
+ #python2.6+ supports slice pickling. some py2.5 extensions might as well. We just test it
+ try:
+ slice(0,1).__reduce__()
+ except TypeError: #can't pickle -
+ dispatch[slice] = save_unsupported
+
+ #itertools objects do not pickle!
+ for v in itertools.__dict__.values():
+ if type(v) is type:
+ dispatch[v] = save_unsupported
+
+
+ def save_dict(self, obj):
+ """hack fix
+ If the dict is a global, deal with it in a special way
+ """
+ #print 'saving', obj
+ if obj is __builtins__:
+ self.save_reduce(_get_module_builtins, (), obj=obj)
+ else:
+ pickle.Pickler.save_dict(self, obj)
+ dispatch[pickle.DictionaryType] = save_dict
+
+
+ def save_module(self, obj, pack=struct.pack):
+ """
+ Save a module as an import
+ """
+ #print 'try save import', obj.__name__
+ self.modules.add(obj)
+ self.save_reduce(subimport,(obj.__name__,), obj=obj)
+ dispatch[types.ModuleType] = save_module #new type
+
+ def save_codeobject(self, obj, pack=struct.pack):
+ """
+ Save a code object
+ """
+ #print 'try to save codeobj: ', obj
+ args = (
+ obj.co_argcount, obj.co_nlocals, obj.co_stacksize, obj.co_flags, obj.co_code,
+ obj.co_consts, obj.co_names, obj.co_varnames, obj.co_filename, obj.co_name,
+ obj.co_firstlineno, obj.co_lnotab, obj.co_freevars, obj.co_cellvars
+ )
+ self.save_reduce(types.CodeType, args, obj=obj)
+ dispatch[types.CodeType] = save_codeobject #new type
+
+ def save_function(self, obj, name=None, pack=struct.pack):
+ """ Registered with the dispatch to handle all function types.
+
+ Determines what kind of function obj is (e.g. lambda, defined at
+ interactive prompt, etc) and handles the pickling appropriately.
+ """
+ write = self.write
+
+ name = obj.__name__
+ modname = pickle.whichmodule(obj, name)
+ #print 'which gives %s %s %s' % (modname, obj, name)
+ try:
+ themodule = sys.modules[modname]
+ except KeyError: # eval'd items such as namedtuple give invalid items for their function __module__
+ modname = '__main__'
+
+ if modname == '__main__':
+ themodule = None
+
+ if themodule:
+ self.modules.add(themodule)
+
+ if not self.savedDjangoEnv:
+ #hack for django - if we detect the settings module, we transport it
+ django_settings = os.environ.get('DJANGO_SETTINGS_MODULE', '')
+ if django_settings:
+ django_mod = sys.modules.get(django_settings)
+ if django_mod:
+ cloudLog.debug('Transporting django settings %s during save of %s', django_mod, name)
+ self.savedDjangoEnv = True
+ self.modules.add(django_mod)
+ write(pickle.MARK)
+ self.save_reduce(django_settings_load, (django_mod.__name__,), obj=django_mod)
+ write(pickle.POP_MARK)
+
+
+ # if func is lambda, def'ed at prompt, is in main, or is nested, then
+ # we'll pickle the actual function object rather than simply saving a
+ # reference (as is done in default pickler), via save_function_tuple.
+ if islambda(obj) or obj.func_code.co_filename == '<stdin>' or themodule == None:
+ #Force server to import modules that have been imported in main
+ modList = None
+ if themodule == None and not self.savedForceImports:
+ mainmod = sys.modules['__main__']
+ if useForcedImports and hasattr(mainmod,'___pyc_forcedImports__'):
+ modList = list(mainmod.___pyc_forcedImports__)
+ self.savedForceImports = True
+ self.save_function_tuple(obj, modList)
+ return
+ else: # func is nested
+ klass = getattr(themodule, name, None)
+ if klass is None or klass is not obj:
+ self.save_function_tuple(obj, [themodule])
+ return
+
+ if obj.__dict__:
+ # essentially save_reduce, but workaround needed to avoid recursion
+ self.save(_restore_attr)
+ write(pickle.MARK + pickle.GLOBAL + modname + '\n' + name + '\n')
+ self.memoize(obj)
+ self.save(obj.__dict__)
+ write(pickle.TUPLE + pickle.REDUCE)
+ else:
+ write(pickle.GLOBAL + modname + '\n' + name + '\n')
+ self.memoize(obj)
+ dispatch[types.FunctionType] = save_function
+
+ def save_function_tuple(self, func, forced_imports):
+ """ Pickles an actual func object.
+
+ A func comprises: code, globals, defaults, closure, and dict. We
+ extract and save these, injecting reducing functions at certain points
+ to recreate the func object. Keep in mind that some of these pieces
+ can contain a ref to the func itself. Thus, a naive save on these
+ pieces could trigger an infinite loop of save's. To get around that,
+ we first create a skeleton func object using just the code (this is
+ safe, since this won't contain a ref to the func), and memoize it as
+ soon as it's created. The other stuff can then be filled in later.
+ """
+ save = self.save
+ write = self.write
+
+ # save the modules (if any)
+ if forced_imports:
+ write(pickle.MARK)
+ save(_modules_to_main)
+ #print 'forced imports are', forced_imports
+
+ forced_names = map(lambda m: m.__name__, forced_imports)
+ save((forced_names,))
+
+ #save((forced_imports,))
+ write(pickle.REDUCE)
+ write(pickle.POP_MARK)
+
+ code, f_globals, defaults, closure, dct, base_globals = self.extract_func_data(func)
+
+ save(_fill_function) # skeleton function updater
+ write(pickle.MARK) # beginning of tuple that _fill_function expects
+
+ # create a skeleton function object and memoize it
+ save(_make_skel_func)
+ save((code, len(closure), base_globals))
+ write(pickle.REDUCE)
+ self.memoize(func)
+
+ # save the rest of the func data needed by _fill_function
+ save(f_globals)
+ save(defaults)
+ save(closure)
+ save(dct)
+ write(pickle.TUPLE)
+ write(pickle.REDUCE) # applies _fill_function on the tuple
+
+ @staticmethod
+ def extract_code_globals(co):
+ """
+ Find all globals names read or written to by codeblock co
+ """
+ code = co.co_code
+ names = co.co_names
+ out_names = set()
+
+ n = len(code)
+ i = 0
+ extended_arg = 0
+ while i < n:
+ op = code[i]
+
+ i = i+1
+ if op >= HAVE_ARGUMENT:
+ oparg = ord(code[i]) + ord(code[i+1])*256 + extended_arg
+ extended_arg = 0
+ i = i+2
+ if op == EXTENDED_ARG:
+ extended_arg = oparg*65536L
+ if op in GLOBAL_OPS:
+ out_names.add(names[oparg])
+ #print 'extracted', out_names, ' from ', names
+ return out_names
+
+ def extract_func_data(self, func):
+ """
+ Turn the function into a tuple of data necessary to recreate it:
+ code, globals, defaults, closure, dict
+ """
+ code = func.func_code
+
+ # extract all global ref's
+ func_global_refs = CloudPickler.extract_code_globals(code)
+ if code.co_consts: # see if nested function have any global refs
+ for const in code.co_consts:
+ if type(const) is types.CodeType and const.co_names:
+ func_global_refs = func_global_refs.union( CloudPickler.extract_code_globals(const))
+ # process all variables referenced by global environment
+ f_globals = {}
+ for var in func_global_refs:
+ #Some names, such as class functions are not global - we don't need them
+ if func.func_globals.has_key(var):
+ f_globals[var] = func.func_globals[var]
+
+ # defaults requires no processing
+ defaults = func.func_defaults
+
+ def get_contents(cell):
+ try:
+ return cell.cell_contents
+ except ValueError, e: #cell is empty error on not yet assigned
+ raise pickle.PicklingError('Function to be pickled has free variables that are referenced before assignment in enclosing scope')
+
+
+ # process closure
+ if func.func_closure:
+ closure = map(get_contents, func.func_closure)
+ else:
+ closure = []
+
+ # save the dict
+ dct = func.func_dict
+
+ if printSerialization:
+ outvars = ['code: ' + str(code) ]
+ outvars.append('globals: ' + str(f_globals))
+ outvars.append('defaults: ' + str(defaults))
+ outvars.append('closure: ' + str(closure))
+ print 'function ', func, 'is extracted to: ', ', '.join(outvars)
+
+ base_globals = self.globals_ref.get(id(func.func_globals), {})
+ self.globals_ref[id(func.func_globals)] = base_globals
+
+ return (code, f_globals, defaults, closure, dct, base_globals)
+
+ def save_global(self, obj, name=None, pack=struct.pack):
+ write = self.write
+ memo = self.memo
+
+ if name is None:
+ name = obj.__name__
+
+ modname = getattr(obj, "__module__", None)
+ if modname is None:
+ modname = pickle.whichmodule(obj, name)
+
+ try:
+ __import__(modname)
+ themodule = sys.modules[modname]
+ except (ImportError, KeyError, AttributeError): #should never occur
+ raise pickle.PicklingError(
+ "Can't pickle %r: Module %s cannot be found" %
+ (obj, modname))
+
+ if modname == '__main__':
+ themodule = None
+
+ if themodule:
+ self.modules.add(themodule)
+
+ sendRef = True
+ typ = type(obj)
+ #print 'saving', obj, typ
+ try:
+ try: #Deal with case when getattribute fails with exceptions
+ klass = getattr(themodule, name)
+ except (AttributeError):
+ if modname == '__builtin__': #new.* are misrepeported
+ modname = 'new'
+ __import__(modname)
+ themodule = sys.modules[modname]
+ try:
+ klass = getattr(themodule, name)
+ except AttributeError, a:
+ #print themodule, name, obj, type(obj)
+ raise pickle.PicklingError("Can't pickle builtin %s" % obj)
+ else:
+ raise
+
+ except (ImportError, KeyError, AttributeError):
+ if typ == types.TypeType or typ == types.ClassType:
+ sendRef = False
+ else: #we can't deal with this
+ raise
+ else:
+ if klass is not obj and (typ == types.TypeType or typ == types.ClassType):
+ sendRef = False
+ if not sendRef:
+ #note: Third party types might crash this - add better checks!
+ d = dict(obj.__dict__) #copy dict proxy to a dict
+ if not isinstance(d.get('__dict__', None), property): # don't extract dict that are properties
+ d.pop('__dict__',None)
+ d.pop('__weakref__',None)
+
+ # hack as __new__ is stored differently in the __dict__
+ new_override = d.get('__new__', None)
+ if new_override:
+ d['__new__'] = obj.__new__
+
+ self.save_reduce(type(obj),(obj.__name__,obj.__bases__,
+ d),obj=obj)
+ #print 'internal reduce dask %s %s' % (obj, d)
+ return
+
+ if self.proto >= 2:
+ code = _extension_registry.get((modname, name))
+ if code:
+ assert code > 0
+ if code <= 0xff:
+ write(pickle.EXT1 + chr(code))
+ elif code <= 0xffff:
+ write("%c%c%c" % (pickle.EXT2, code&0xff, code>>8))
+ else:
+ write(pickle.EXT4 + pack("<i", code))
+ return
+
+ write(pickle.GLOBAL + modname + '\n' + name + '\n')
+ self.memoize(obj)
+ dispatch[types.ClassType] = save_global
+ dispatch[types.BuiltinFunctionType] = save_global
+ dispatch[types.TypeType] = save_global
+
+ def save_instancemethod(self, obj):
+ #Memoization rarely is ever useful due to python bounding
+ self.save_reduce(types.MethodType, (obj.im_func, obj.im_self,obj.im_class), obj=obj)
+ dispatch[types.MethodType] = save_instancemethod
+
+ def save_inst_logic(self, obj):
+ """Inner logic to save instance. Based off pickle.save_inst
+ Supports __transient__"""
+ cls = obj.__class__
+
+ memo = self.memo
+ write = self.write
+ save = self.save
+
+ if hasattr(obj, '__getinitargs__'):
+ args = obj.__getinitargs__()
+ len(args) # XXX Assert it's a sequence
+ pickle._keep_alive(args, memo)
+ else:
+ args = ()
+
+ write(pickle.MARK)
+
+ if self.bin:
+ save(cls)
+ for arg in args:
+ save(arg)
+ write(pickle.OBJ)
+ else:
+ for arg in args:
+ save(arg)
+ write(pickle.INST + cls.__module__ + '\n' + cls.__name__ + '\n')
+
+ self.memoize(obj)
+
+ try:
+ getstate = obj.__getstate__
+ except AttributeError:
+ stuff = obj.__dict__
+ #remove items if transient
+ if hasattr(obj, '__transient__'):
+ transient = obj.__transient__
+ stuff = stuff.copy()
+ for k in list(stuff.keys()):
+ if k in transient:
+ del stuff[k]
+ else:
+ stuff = getstate()
+ pickle._keep_alive(stuff, memo)
+ save(stuff)
+ write(pickle.BUILD)
+
+
+ def save_inst(self, obj):
+ # Hack to detect PIL Image instances without importing Imaging
+ # PIL can be loaded with multiple names, so we don't check sys.modules for it
+ if hasattr(obj,'im') and hasattr(obj,'palette') and 'Image' in obj.__module__:
+ self.save_image(obj)
+ else:
+ self.save_inst_logic(obj)
+ dispatch[types.InstanceType] = save_inst
+
+ def save_property(self, obj):
+ # properties not correctly saved in python
+ self.save_reduce(property, (obj.fget, obj.fset, obj.fdel, obj.__doc__), obj=obj)
+ dispatch[property] = save_property
+
+ def save_itemgetter(self, obj):
+ """itemgetter serializer (needed for namedtuple support)
+ a bit of a pain as we need to read ctypes internals"""
+ class ItemGetterType(ctypes.Structure):
+ _fields_ = PyObject_HEAD + [
+ ('nitems', ctypes.c_size_t),
+ ('item', ctypes.py_object)
+ ]
+
+
+ itemgetter_obj = ctypes.cast(ctypes.c_void_p(id(obj)), ctypes.POINTER(ItemGetterType)).contents
+ return self.save_reduce(operator.itemgetter, (itemgetter_obj.item,))
+
+ if PyObject_HEAD:
+ dispatch[operator.itemgetter] = save_itemgetter
+
+
+
+ def save_reduce(self, func, args, state=None,
+ listitems=None, dictitems=None, obj=None):
+ """Modified to support __transient__ on new objects
+ Change only affects protocol level 2 (which is always used by PiCloud"""
+ # Assert that args is a tuple or None
+ if not isinstance(args, types.TupleType):
+ raise pickle.PicklingError("args from reduce() should be a tuple")
+
+ # Assert that func is callable
+ if not hasattr(func, '__call__'):
+ raise pickle.PicklingError("func from reduce should be callable")
+
+ save = self.save
+ write = self.write
+
+ # Protocol 2 special case: if func's name is __newobj__, use NEWOBJ
+ if self.proto >= 2 and getattr(func, "__name__", "") == "__newobj__":
+ #Added fix to allow transient
+ cls = args[0]
+ if not hasattr(cls, "__new__"):
+ raise pickle.PicklingError(
+ "args[0] from __newobj__ args has no __new__")
+ if obj is not None and cls is not obj.__class__:
+ raise pickle.PicklingError(
+ "args[0] from __newobj__ args has the wrong class")
+ args = args[1:]
+ save(cls)
+
+ #Don't pickle transient entries
+ if hasattr(obj, '__transient__'):
+ transient = obj.__transient__
+ state = state.copy()
+
+ for k in list(state.keys()):
+ if k in transient:
+ del state[k]
+
+ save(args)
+ write(pickle.NEWOBJ)
+ else:
+ save(func)
+ save(args)
+ write(pickle.REDUCE)
+
+ if obj is not None:
+ self.memoize(obj)
+
+ # More new special cases (that work with older protocols as
+ # well): when __reduce__ returns a tuple with 4 or 5 items,
+ # the 4th and 5th item should be iterators that provide list
+ # items and dict items (as (key, value) tuples), or None.
+
+ if listitems is not None:
+ self._batch_appends(listitems)
+
+ if dictitems is not None:
+ self._batch_setitems(dictitems)
+
+ if state is not None:
+ #print 'obj %s has state %s' % (obj, state)
+ save(state)
+ write(pickle.BUILD)
+
+
+ def save_xrange(self, obj):
+ """Save an xrange object in python 2.5
+ Python 2.6 supports this natively
+ """
+ range_params = xrange_params(obj)
+ self.save_reduce(_build_xrange,range_params)
+
+ #python2.6+ supports xrange pickling. some py2.5 extensions might as well. We just test it
+ try:
+ xrange(0).__reduce__()
+ except TypeError: #can't pickle -- use PiCloud pickler
+ dispatch[xrange] = save_xrange
+
+ def save_partial(self, obj):
+ """Partial objects do not serialize correctly in python2.x -- this fixes the bugs"""
+ self.save_reduce(_genpartial, (obj.func, obj.args, obj.keywords))
+
+ if sys.version_info < (2,7): #2.7 supports partial pickling
+ dispatch[partial] = save_partial
+
+
+ def save_file(self, obj):
+ """Save a file"""
+ import StringIO as pystringIO #we can't use cStringIO as it lacks the name attribute
+ from ..transport.adapter import SerializingAdapter
+
+ if not hasattr(obj, 'name') or not hasattr(obj, 'mode'):
+ raise pickle.PicklingError("Cannot pickle files that do not map to an actual file")
+ if obj.name == '<stdout>':
+ return self.save_reduce(getattr, (sys,'stdout'), obj=obj)
+ if obj.name == '<stderr>':
+ return self.save_reduce(getattr, (sys,'stderr'), obj=obj)
+ if obj.name == '<stdin>':
+ raise pickle.PicklingError("Cannot pickle standard input")
+ if hasattr(obj, 'isatty') and obj.isatty():
+ raise pickle.PicklingError("Cannot pickle files that map to tty objects")
+ if 'r' not in obj.mode:
+ raise pickle.PicklingError("Cannot pickle files that are not opened for reading")
+ name = obj.name
+ try:
+ fsize = os.stat(name).st_size
+ except OSError:
+ raise pickle.PicklingError("Cannot pickle file %s as it cannot be stat" % name)
+
+ if obj.closed:
+ #create an empty closed string io
+ retval = pystringIO.StringIO("")
+ retval.close()
+ elif not fsize: #empty file
+ retval = pystringIO.StringIO("")
+ try:
+ tmpfile = file(name)
+ tst = tmpfile.read(1)
+ except IOError:
+ raise pickle.PicklingError("Cannot pickle file %s as it cannot be read" % name)
+ tmpfile.close()
+ if tst != '':
+ raise pickle.PicklingError("Cannot pickle file %s as it does not appear to map to a physical, real file" % name)
+ elif fsize > SerializingAdapter.max_transmit_data:
+ raise pickle.PicklingError("Cannot pickle file %s as it exceeds cloudconf.py's max_transmit_data of %d" %
+ (name,SerializingAdapter.max_transmit_data))
+ else:
+ try:
+ tmpfile = file(name)
+ contents = tmpfile.read(SerializingAdapter.max_transmit_data)
+ tmpfile.close()
+ except IOError:
+ raise pickle.PicklingError("Cannot pickle file %s as it cannot be read" % name)
+ retval = pystringIO.StringIO(contents)
+ curloc = obj.tell()
+ retval.seek(curloc)
+
+ retval.name = name
+ self.save(retval) #save stringIO
+ self.memoize(obj)
+
+ dispatch[file] = save_file
+ """Special functions for Add-on libraries"""
+
+ def inject_numpy(self):
+ numpy = sys.modules.get('numpy')
+ if not numpy or not hasattr(numpy, 'ufunc'):
+ return
+ self.dispatch[numpy.ufunc] = self.__class__.save_ufunc
+
+ numpy_tst_mods = ['numpy', 'scipy.special']
+ def save_ufunc(self, obj):
+ """Hack function for saving numpy ufunc objects"""
+ name = obj.__name__
+ for tst_mod_name in self.numpy_tst_mods:
+ tst_mod = sys.modules.get(tst_mod_name, None)
+ if tst_mod:
+ if name in tst_mod.__dict__:
+ self.save_reduce(_getobject, (tst_mod_name, name))
+ return
+ raise pickle.PicklingError('cannot save %s. Cannot resolve what module it is defined in' % str(obj))
+
+ def inject_timeseries(self):
+ """Handle bugs with pickling scikits timeseries"""
+ tseries = sys.modules.get('scikits.timeseries.tseries')
+ if not tseries or not hasattr(tseries, 'Timeseries'):
+ return
+ self.dispatch[tseries.Timeseries] = self.__class__.save_timeseries
+
+ def save_timeseries(self, obj):
+ import scikits.timeseries.tseries as ts
+
+ func, reduce_args, state = obj.__reduce__()
+ if func != ts._tsreconstruct:
+ raise pickle.PicklingError('timeseries using unexpected reconstruction function %s' % str(func))
+ state = (1,
+ obj.shape,
+ obj.dtype,
+ obj.flags.fnc,
+ obj._data.tostring(),
+ ts.getmaskarray(obj).tostring(),
+ obj._fill_value,
+ obj._dates.shape,
+ obj._dates.__array__().tostring(),
+ obj._dates.dtype, #added -- preserve type
+ obj.freq,
+ obj._optinfo,
+ )
+ return self.save_reduce(_genTimeSeries, (reduce_args, state))
+
+ def inject_email(self):
+ """Block email LazyImporters from being saved"""
+ email = sys.modules.get('email')
+ if not email:
+ return
+ self.dispatch[email.LazyImporter] = self.__class__.save_unsupported
+
+ def inject_addons(self):
+ """Plug in system. Register additional pickling functions if modules already loaded"""
+ self.inject_numpy()
+ self.inject_timeseries()
+ self.inject_email()
+
+ """Python Imaging Library"""
+ def save_image(self, obj):
+ if not obj.im and obj.fp and 'r' in obj.fp.mode and obj.fp.name \
+ and not obj.fp.closed and (not hasattr(obj, 'isatty') or not obj.isatty()):
+ #if image not loaded yet -- lazy load
+ self.save_reduce(_lazyloadImage,(obj.fp,), obj=obj)
+ else:
+ #image is loaded - just transmit it over
+ self.save_reduce(_generateImage, (obj.size, obj.mode, obj.tostring()), obj=obj)
+
+ """
+ def memoize(self, obj):
+ pickle.Pickler.memoize(self, obj)
+ if printMemoization:
+ print 'memoizing ' + str(obj)
+ """
+
+
+
+# Shorthands for legacy support
+
+def dump(obj, file, protocol=2):
+ CloudPickler(file, protocol).dump(obj)
+
+def dumps(obj, protocol=2):
+ file = StringIO()
+
+ cp = CloudPickler(file,protocol)
+ cp.dump(obj)
+
+ #print 'cloud dumped', str(obj), str(cp.modules)
+
+ return file.getvalue()
+
+
+#hack for __import__ not working as desired
+def subimport(name):
+ __import__(name)
+ return sys.modules[name]
+
+#hack to load django settings:
+def django_settings_load(name):
+ modified_env = False
+
+ if 'DJANGO_SETTINGS_MODULE' not in os.environ:
+ os.environ['DJANGO_SETTINGS_MODULE'] = name # must set name first due to circular deps
+ modified_env = True
+ try:
+ module = subimport(name)
+ except Exception, i:
+ print >> sys.stderr, 'Cloud not import django settings %s:' % (name)
+ print_exec(sys.stderr)
+ if modified_env:
+ del os.environ['DJANGO_SETTINGS_MODULE']
+ else:
+ #add project directory to sys,path:
+ if hasattr(module,'__file__'):
+ dirname = os.path.split(module.__file__)[0] + '/'
+ sys.path.append(dirname)
+
+# restores function attributes
+def _restore_attr(obj, attr):
+ for key, val in attr.items():
+ setattr(obj, key, val)
+ return obj
+
+def _get_module_builtins():
+ return pickle.__builtins__
+
+def print_exec(stream):
+ ei = sys.exc_info()
+ traceback.print_exception(ei[0], ei[1], ei[2], None, stream)
+
+def _modules_to_main(modList):
+ """Force every module in modList to be placed into main"""
+ if not modList:
+ return
+
+ main = sys.modules['__main__']
+ for modname in modList:
+ if type(modname) is str:
+ try:
+ mod = __import__(modname)
+ except Exception, i: #catch all...
+ sys.stderr.write('warning: could not import %s\n. Your function may unexpectedly error due to this import failing; \
+A version mismatch is likely. Specific error was:\n' % modname)
+ print_exec(sys.stderr)
+ else:
+ setattr(main,mod.__name__, mod)
+ else:
+ #REVERSE COMPATIBILITY FOR CLOUD CLIENT 1.5 (WITH EPD)
+ #In old version actual module was sent
+ setattr(main,modname.__name__, modname)
+
+#object generators:
+def _build_xrange(start, step, len):
+ """Built xrange explicitly"""
+ return xrange(start, start + step*len, step)
+
+def _genpartial(func, args, kwds):
+ if not args:
+ args = ()
+ if not kwds:
+ kwds = {}
+ return partial(func, *args, **kwds)
+
+
+def _fill_function(func, globals, defaults, closure, dict):
+ """ Fills in the rest of function data into the skeleton function object
+ that were created via _make_skel_func().
+ """
+ func.func_globals.update(globals)
+ func.func_defaults = defaults
+ func.func_dict = dict
+
+ if len(closure) != len(func.func_closure):
+ raise pickle.UnpicklingError("closure lengths don't match up")
+ for i in range(len(closure)):
+ _change_cell_value(func.func_closure[i], closure[i])
+
+ return func
+
+def _make_skel_func(code, num_closures, base_globals = None):
+ """ Creates a skeleton function object that contains just the provided
+ code and the correct number of cells in func_closure. All other
+ func attributes (e.g. func_globals) are empty.
+ """
+ #build closure (cells):
+ if not ctypes:
+ raise Exception('ctypes failed to import; cannot build function')
+
+ cellnew = ctypes.pythonapi.PyCell_New
+ cellnew.restype = ctypes.py_object
+ cellnew.argtypes = (ctypes.py_object,)
+ dummy_closure = tuple(map(lambda i: cellnew(None), range(num_closures)))
+
+ if base_globals is None:
+ base_globals = {}
+ base_globals['__builtins__'] = __builtins__
+
+ return types.FunctionType(code, base_globals,
+ None, None, dummy_closure)
+
+# this piece of opaque code is needed below to modify 'cell' contents
+cell_changer_code = new.code(
+ 1, 1, 2, 0,
+ ''.join([
+ chr(dis.opmap['LOAD_FAST']), '\x00\x00',
+ chr(dis.opmap['DUP_TOP']),
+ chr(dis.opmap['STORE_DEREF']), '\x00\x00',
+ chr(dis.opmap['RETURN_VALUE'])
+ ]),
+ (), (), ('newval',), '<nowhere>', 'cell_changer', 1, '', ('c',), ()
+)
+
+def _change_cell_value(cell, newval):
+ """ Changes the contents of 'cell' object to newval """
+ return new.function(cell_changer_code, {}, None, (), (cell,))(newval)
+
+"""Constructors for 3rd party libraries
+Note: These can never be renamed due to client compatibility issues"""
+
+def _getobject(modname, attribute):
+ mod = __import__(modname)
+ return mod.__dict__[attribute]
+
+def _generateImage(size, mode, str_rep):
+ """Generate image from string representation"""
+ import Image
+ i = Image.new(mode, size)
+ i.fromstring(str_rep)
+ return i
+
+def _lazyloadImage(fp):
+ import Image
+ fp.seek(0) #works in almost any case
+ return Image.open(fp)
+
+"""Timeseries"""
+def _genTimeSeries(reduce_args, state):
+ import scikits.timeseries.tseries as ts
+ from numpy import ndarray
+ from numpy.ma import MaskedArray
+
+
+ time_series = ts._tsreconstruct(*reduce_args)
+
+ #from setstate modified
+ (ver, shp, typ, isf, raw, msk, flv, dsh, dtm, dtyp, frq, infodict) = state
+ #print 'regenerating %s' % dtyp
+
+ MaskedArray.__setstate__(time_series, (ver, shp, typ, isf, raw, msk, flv))
+ _dates = time_series._dates
+ #_dates.__setstate__((ver, dsh, typ, isf, dtm, frq)) #use remote typ
+ ndarray.__setstate__(_dates,(dsh,dtyp, isf, dtm))
+ _dates.freq = frq
+ _dates._cachedinfo.update(dict(full=None, hasdups=None, steps=None,
+ toobj=None, toord=None, tostr=None))
+ # Update the _optinfo dictionary
+ time_series._optinfo.update(infodict)
+ return time_series
+
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
new file mode 100644
index 0000000000..e486f206b0
--- /dev/null
+++ b/python/pyspark/context.py
@@ -0,0 +1,159 @@
+import os
+import atexit
+from tempfile import NamedTemporaryFile
+
+from pyspark.broadcast import Broadcast
+from pyspark.java_gateway import launch_gateway
+from pyspark.serializers import dump_pickle, write_with_length, batched
+from pyspark.rdd import RDD
+
+from py4j.java_collections import ListConverter
+
+
+class SparkContext(object):
+ """
+ Main entry point for Spark functionality. A SparkContext represents the
+ connection to a Spark cluster, and can be used to create L{RDD}s and
+ broadcast variables on that cluster.
+ """
+
+ gateway = launch_gateway()
+ jvm = gateway.jvm
+ _readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile
+ _writeIteratorToPickleFile = jvm.PythonRDD.writeIteratorToPickleFile
+ _takePartition = jvm.PythonRDD.takePartition
+
+ def __init__(self, master, jobName, sparkHome=None, pyFiles=None,
+ environment=None, batchSize=1024):
+ """
+ Create a new SparkContext.
+
+ @param master: Cluster URL to connect to
+ (e.g. mesos://host:port, spark://host:port, local[4]).
+ @param jobName: A name for your job, to display on the cluster web UI
+ @param sparkHome: Location where Spark is installed on cluster nodes.
+ @param pyFiles: Collection of .zip or .py files to send to the cluster
+ and add to PYTHONPATH. These can be paths on the local file
+ system or HDFS, HTTP, HTTPS, or FTP URLs.
+ @param environment: A dictionary of environment variables to set on
+ worker nodes.
+ @param batchSize: The number of Python objects represented as a single
+ Java object. Set 1 to disable batching or -1 to use an
+ unlimited batch size.
+ """
+ self.master = master
+ self.jobName = jobName
+ self.sparkHome = sparkHome or None # None becomes null in Py4J
+ self.environment = environment or {}
+ self.batchSize = batchSize # -1 represents a unlimited batch size
+
+ # Create the Java SparkContext through Py4J
+ empty_string_array = self.gateway.new_array(self.jvm.String, 0)
+ self._jsc = self.jvm.JavaSparkContext(master, jobName, sparkHome,
+ empty_string_array)
+
+ self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')
+ # Broadcast's __reduce__ method stores Broadcast instances here.
+ # This allows other code to determine which Broadcast instances have
+ # been pickled, so it can determine which Java broadcast objects to
+ # send.
+ self._pickled_broadcast_vars = set()
+
+ # Deploy any code dependencies specified in the constructor
+ for path in (pyFiles or []):
+ self.addPyFile(path)
+
+ @property
+ def defaultParallelism(self):
+ """
+ Default level of parallelism to use when not given by user (e.g. for
+ reduce tasks)
+ """
+ return self._jsc.sc().defaultParallelism()
+
+ def __del__(self):
+ if self._jsc:
+ self._jsc.stop()
+
+ def stop(self):
+ """
+ Shut down the SparkContext.
+ """
+ self._jsc.stop()
+ self._jsc = None
+
+ def parallelize(self, c, numSlices=None):
+ """
+ Distribute a local Python collection to form an RDD.
+ """
+ numSlices = numSlices or self.defaultParallelism
+ # Calling the Java parallelize() method with an ArrayList is too slow,
+ # because it sends O(n) Py4J commands. As an alternative, serialized
+ # objects are written to a file and loaded through textFile().
+ tempFile = NamedTemporaryFile(delete=False)
+ atexit.register(lambda: os.unlink(tempFile.name))
+ if self.batchSize != 1:
+ c = batched(c, self.batchSize)
+ for x in c:
+ write_with_length(dump_pickle(x), tempFile)
+ tempFile.close()
+ jrdd = self._readRDDFromPickleFile(self._jsc, tempFile.name, numSlices)
+ return RDD(jrdd, self)
+
+ def textFile(self, name, minSplits=None):
+ """
+ Read a text file from HDFS, a local file system (available on all
+ nodes), or any Hadoop-supported file system URI, and return it as an
+ RDD of Strings.
+ """
+ minSplits = minSplits or min(self.defaultParallelism, 2)
+ jrdd = self._jsc.textFile(name, minSplits)
+ return RDD(jrdd, self)
+
+ def union(self, rdds):
+ """
+ Build the union of a list of RDDs.
+ """
+ first = rdds[0]._jrdd
+ rest = [x._jrdd for x in rdds[1:]]
+ rest = ListConverter().convert(rest, self.gateway._gateway_client)
+ return RDD(self._jsc.union(first, rest), self)
+
+ def broadcast(self, value):
+ """
+ Broadcast a read-only variable to the cluster, returning a C{Broadcast}
+ object for reading it in distributed functions. The variable will be
+ sent to each cluster only once.
+ """
+ jbroadcast = self._jsc.broadcast(bytearray(dump_pickle(value)))
+ return Broadcast(jbroadcast.id(), value, jbroadcast,
+ self._pickled_broadcast_vars)
+
+ def addFile(self, path):
+ """
+ Add a file to be downloaded into the working directory of this Spark
+ job on every node. The C{path} passed can be either a local file,
+ a file in HDFS (or other Hadoop-supported filesystems), or an HTTP,
+ HTTPS or FTP URI.
+ """
+ self._jsc.sc().addFile(path)
+
+ def clearFiles(self):
+ """
+ Clear the job's list of files added by L{addFile} or L{addPyFile} so
+ that they do not get downloaded to any new nodes.
+ """
+ # TODO: remove added .py or .zip files from the PYTHONPATH?
+ self._jsc.sc().clearFiles()
+
+ def addPyFile(self, path):
+ """
+ Add a .py or .zip dependency for all tasks to be executed on this
+ SparkContext in the future. The C{path} passed can be either a local
+ file, a file in HDFS (or other Hadoop-supported filesystems), or an
+ HTTP, HTTPS or FTP URI.
+ """
+ self.addFile(path)
+ filename = path.split("/")[-1]
+ os.environ["PYTHONPATH"] = \
+ "%s:%s" % (filename, os.environ["PYTHONPATH"])
diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
new file mode 100644
index 0000000000..2329e536cc
--- /dev/null
+++ b/python/pyspark/java_gateway.py
@@ -0,0 +1,38 @@
+import os
+import sys
+from subprocess import Popen, PIPE
+from threading import Thread
+from py4j.java_gateway import java_import, JavaGateway, GatewayClient
+
+
+SPARK_HOME = os.environ["SPARK_HOME"]
+
+
+def launch_gateway():
+ # Launch the Py4j gateway using Spark's run command so that we pick up the
+ # proper classpath and SPARK_MEM settings from spark-env.sh
+ command = [os.path.join(SPARK_HOME, "run"), "py4j.GatewayServer",
+ "--die-on-broken-pipe", "0"]
+ proc = Popen(command, stdout=PIPE, stdin=PIPE)
+ # Determine which ephemeral port the server started on:
+ port = int(proc.stdout.readline())
+ # Create a thread to echo output from the GatewayServer, which is required
+ # for Java log output to show up:
+ class EchoOutputThread(Thread):
+ def __init__(self, stream):
+ Thread.__init__(self)
+ self.daemon = True
+ self.stream = stream
+
+ def run(self):
+ while True:
+ line = self.stream.readline()
+ sys.stderr.write(line)
+ EchoOutputThread(proc.stdout).start()
+ # Connect to the gateway
+ gateway = JavaGateway(GatewayClient(port=port), auto_convert=False)
+ # Import the classes used by PySpark
+ java_import(gateway.jvm, "spark.api.java.*")
+ java_import(gateway.jvm, "spark.api.python.*")
+ java_import(gateway.jvm, "scala.Tuple2")
+ return gateway
diff --git a/python/pyspark/join.py b/python/pyspark/join.py
new file mode 100644
index 0000000000..7036c47980
--- /dev/null
+++ b/python/pyspark/join.py
@@ -0,0 +1,92 @@
+"""
+Copyright (c) 2011, Douban Inc. <http://www.douban.com/>
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+
+ * Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above
+copyright notice, this list of conditions and the following disclaimer
+in the documentation and/or other materials provided with the
+distribution.
+
+ * Neither the name of the Douban Inc. nor the names of its
+contributors may be used to endorse or promote products derived from
+this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+"""
+
+
+def _do_python_join(rdd, other, numSplits, dispatch):
+ vs = rdd.map(lambda (k, v): (k, (1, v)))
+ ws = other.map(lambda (k, v): (k, (2, v)))
+ return vs.union(ws).groupByKey(numSplits).flatMapValues(dispatch)
+
+
+def python_join(rdd, other, numSplits):
+ def dispatch(seq):
+ vbuf, wbuf = [], []
+ for (n, v) in seq:
+ if n == 1:
+ vbuf.append(v)
+ elif n == 2:
+ wbuf.append(v)
+ return [(v, w) for v in vbuf for w in wbuf]
+ return _do_python_join(rdd, other, numSplits, dispatch)
+
+
+def python_right_outer_join(rdd, other, numSplits):
+ def dispatch(seq):
+ vbuf, wbuf = [], []
+ for (n, v) in seq:
+ if n == 1:
+ vbuf.append(v)
+ elif n == 2:
+ wbuf.append(v)
+ if not vbuf:
+ vbuf.append(None)
+ return [(v, w) for v in vbuf for w in wbuf]
+ return _do_python_join(rdd, other, numSplits, dispatch)
+
+
+def python_left_outer_join(rdd, other, numSplits):
+ def dispatch(seq):
+ vbuf, wbuf = [], []
+ for (n, v) in seq:
+ if n == 1:
+ vbuf.append(v)
+ elif n == 2:
+ wbuf.append(v)
+ if not wbuf:
+ wbuf.append(None)
+ return [(v, w) for v in vbuf for w in wbuf]
+ return _do_python_join(rdd, other, numSplits, dispatch)
+
+
+def python_cogroup(rdd, other, numSplits):
+ vs = rdd.map(lambda (k, v): (k, (1, v)))
+ ws = other.map(lambda (k, v): (k, (2, v)))
+ def dispatch(seq):
+ vbuf, wbuf = [], []
+ for (n, v) in seq:
+ if n == 1:
+ vbuf.append(v)
+ elif n == 2:
+ wbuf.append(v)
+ return (vbuf, wbuf)
+ return vs.union(ws).groupByKey(numSplits).mapValues(dispatch)
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
new file mode 100644
index 0000000000..1d36da42b0
--- /dev/null
+++ b/python/pyspark/rdd.py
@@ -0,0 +1,723 @@
+import atexit
+from base64 import standard_b64encode as b64enc
+import copy
+from collections import defaultdict
+from itertools import chain, ifilter, imap, product
+import operator
+import os
+import shlex
+from subprocess import Popen, PIPE
+from tempfile import NamedTemporaryFile
+from threading import Thread
+
+from pyspark import cloudpickle
+from pyspark.serializers import batched, Batch, dump_pickle, load_pickle, \
+ read_from_pickle_file
+from pyspark.join import python_join, python_left_outer_join, \
+ python_right_outer_join, python_cogroup
+
+from py4j.java_collections import ListConverter, MapConverter
+
+
+__all__ = ["RDD"]
+
+
+class RDD(object):
+ """
+ A Resilient Distributed Dataset (RDD), the basic abstraction in Spark.
+ Represents an immutable, partitioned collection of elements that can be
+ operated on in parallel.
+ """
+
+ def __init__(self, jrdd, ctx):
+ self._jrdd = jrdd
+ self.is_cached = False
+ self.ctx = ctx
+
+ @property
+ def context(self):
+ """
+ The L{SparkContext} that this RDD was created on.
+ """
+ return self.ctx
+
+ def cache(self):
+ """
+ Persist this RDD with the default storage level (C{MEMORY_ONLY}).
+ """
+ self.is_cached = True
+ self._jrdd.cache()
+ return self
+
+ # TODO persist(self, storageLevel)
+
+ def map(self, f, preservesPartitioning=False):
+ """
+ Return a new RDD containing the distinct elements in this RDD.
+ """
+ def func(split, iterator): return imap(f, iterator)
+ return PipelinedRDD(self, func, preservesPartitioning)
+
+ def flatMap(self, f, preservesPartitioning=False):
+ """
+ Return a new RDD by first applying a function to all elements of this
+ RDD, and then flattening the results.
+
+ >>> rdd = sc.parallelize([2, 3, 4])
+ >>> sorted(rdd.flatMap(lambda x: range(1, x)).collect())
+ [1, 1, 1, 2, 2, 3]
+ >>> sorted(rdd.flatMap(lambda x: [(x, x), (x, x)]).collect())
+ [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)]
+ """
+ def func(s, iterator): return chain.from_iterable(imap(f, iterator))
+ return self.mapPartitionsWithSplit(func, preservesPartitioning)
+
+ def mapPartitions(self, f, preservesPartitioning=False):
+ """
+ Return a new RDD by applying a function to each partition of this RDD.
+
+ >>> rdd = sc.parallelize([1, 2, 3, 4], 2)
+ >>> def f(iterator): yield sum(iterator)
+ >>> rdd.mapPartitions(f).collect()
+ [3, 7]
+ """
+ def func(s, iterator): return f(iterator)
+ return self.mapPartitionsWithSplit(func)
+
+ def mapPartitionsWithSplit(self, f, preservesPartitioning=False):
+ """
+ Return a new RDD by applying a function to each partition of this RDD,
+ while tracking the index of the original partition.
+
+ >>> rdd = sc.parallelize([1, 2, 3, 4], 4)
+ >>> def f(splitIndex, iterator): yield splitIndex
+ >>> rdd.mapPartitionsWithSplit(f).sum()
+ 6
+ """
+ return PipelinedRDD(self, f, preservesPartitioning)
+
+ def filter(self, f):
+ """
+ Return a new RDD containing only the elements that satisfy a predicate.
+
+ >>> rdd = sc.parallelize([1, 2, 3, 4, 5])
+ >>> rdd.filter(lambda x: x % 2 == 0).collect()
+ [2, 4]
+ """
+ def func(iterator): return ifilter(f, iterator)
+ return self.mapPartitions(func)
+
+ def distinct(self):
+ """
+ Return a new RDD containing the distinct elements in this RDD.
+
+ >>> sorted(sc.parallelize([1, 1, 2, 3]).distinct().collect())
+ [1, 2, 3]
+ """
+ return self.map(lambda x: (x, "")) \
+ .reduceByKey(lambda x, _: x) \
+ .map(lambda (x, _): x)
+
+ # TODO: sampling needs to be re-implemented due to Batch
+ #def sample(self, withReplacement, fraction, seed):
+ # jrdd = self._jrdd.sample(withReplacement, fraction, seed)
+ # return RDD(jrdd, self.ctx)
+
+ #def takeSample(self, withReplacement, num, seed):
+ # vals = self._jrdd.takeSample(withReplacement, num, seed)
+ # return [load_pickle(bytes(x)) for x in vals]
+
+ def union(self, other):
+ """
+ Return the union of this RDD and another one.
+
+ >>> rdd = sc.parallelize([1, 1, 2, 3])
+ >>> rdd.union(rdd).collect()
+ [1, 1, 2, 3, 1, 1, 2, 3]
+ """
+ return RDD(self._jrdd.union(other._jrdd), self.ctx)
+
+ def __add__(self, other):
+ """
+ Return the union of this RDD and another one.
+
+ >>> rdd = sc.parallelize([1, 1, 2, 3])
+ >>> (rdd + rdd).collect()
+ [1, 1, 2, 3, 1, 1, 2, 3]
+ """
+ if not isinstance(other, RDD):
+ raise TypeError
+ return self.union(other)
+
+ # TODO: sort
+
+ def glom(self):
+ """
+ Return an RDD created by coalescing all elements within each partition
+ into a list.
+
+ >>> rdd = sc.parallelize([1, 2, 3, 4], 2)
+ >>> sorted(rdd.glom().collect())
+ [[1, 2], [3, 4]]
+ """
+ def func(iterator): yield list(iterator)
+ return self.mapPartitions(func)
+
+ def cartesian(self, other):
+ """
+ Return the Cartesian product of this RDD and another one, that is, the
+ RDD of all pairs of elements C{(a, b)} where C{a} is in C{self} and
+ C{b} is in C{other}.
+
+ >>> rdd = sc.parallelize([1, 2])
+ >>> sorted(rdd.cartesian(rdd).collect())
+ [(1, 1), (1, 2), (2, 1), (2, 2)]
+ """
+ # Due to batching, we can't use the Java cartesian method.
+ java_cartesian = RDD(self._jrdd.cartesian(other._jrdd), self.ctx)
+ def unpack_batches(pair):
+ (x, y) = pair
+ if type(x) == Batch or type(y) == Batch:
+ xs = x.items if type(x) == Batch else [x]
+ ys = y.items if type(y) == Batch else [y]
+ for pair in product(xs, ys):
+ yield pair
+ else:
+ yield pair
+ return java_cartesian.flatMap(unpack_batches)
+
+ def groupBy(self, f, numSplits=None):
+ """
+ Return an RDD of grouped items.
+
+ >>> rdd = sc.parallelize([1, 1, 2, 3, 5, 8])
+ >>> result = rdd.groupBy(lambda x: x % 2).collect()
+ >>> sorted([(x, sorted(y)) for (x, y) in result])
+ [(0, [2, 8]), (1, [1, 1, 3, 5])]
+ """
+ return self.map(lambda x: (f(x), x)).groupByKey(numSplits)
+
+ def pipe(self, command, env={}):
+ """
+ Return an RDD created by piping elements to a forked external process.
+
+ >>> sc.parallelize([1, 2, 3]).pipe('cat').collect()
+ ['1', '2', '3']
+ """
+ def func(iterator):
+ pipe = Popen(shlex.split(command), env=env, stdin=PIPE, stdout=PIPE)
+ def pipe_objs(out):
+ for obj in iterator:
+ out.write(str(obj).rstrip('\n') + '\n')
+ out.close()
+ Thread(target=pipe_objs, args=[pipe.stdin]).start()
+ return (x.rstrip('\n') for x in pipe.stdout)
+ return self.mapPartitions(func)
+
+ def foreach(self, f):
+ """
+ Applies a function to all elements of this RDD.
+
+ >>> def f(x): print x
+ >>> sc.parallelize([1, 2, 3, 4, 5]).foreach(f)
+ """
+ self.map(f).collect() # Force evaluation
+
+ def collect(self):
+ """
+ Return a list that contains all of the elements in this RDD.
+ """
+ picklesInJava = self._jrdd.collect().iterator()
+ return list(self._collect_iterator_through_file(picklesInJava))
+
+ def _collect_iterator_through_file(self, iterator):
+ # Transferring lots of data through Py4J can be slow because
+ # socket.readline() is inefficient. Instead, we'll dump the data to a
+ # file and read it back.
+ tempFile = NamedTemporaryFile(delete=False)
+ tempFile.close()
+ def clean_up_file():
+ try: os.unlink(tempFile.name)
+ except: pass
+ atexit.register(clean_up_file)
+ self.ctx._writeIteratorToPickleFile(iterator, tempFile.name)
+ # Read the data into Python and deserialize it:
+ with open(tempFile.name, 'rb') as tempFile:
+ for item in read_from_pickle_file(tempFile):
+ yield item
+ os.unlink(tempFile.name)
+
+ def reduce(self, f):
+ """
+ Reduces the elements of this RDD using the specified associative binary
+ operator.
+
+ >>> from operator import add
+ >>> sc.parallelize([1, 2, 3, 4, 5]).reduce(add)
+ 15
+ >>> sc.parallelize((2 for _ in range(10))).map(lambda x: 1).cache().reduce(add)
+ 10
+ """
+ def func(iterator):
+ acc = None
+ for obj in iterator:
+ if acc is None:
+ acc = obj
+ else:
+ acc = f(obj, acc)
+ if acc is not None:
+ yield acc
+ vals = self.mapPartitions(func).collect()
+ return reduce(f, vals)
+
+ def fold(self, zeroValue, op):
+ """
+ Aggregate the elements of each partition, and then the results for all
+ the partitions, using a given associative function and a neutral "zero
+ value."
+
+ The function C{op(t1, t2)} is allowed to modify C{t1} and return it
+ as its result value to avoid object allocation; however, it should not
+ modify C{t2}.
+
+ >>> from operator import add
+ >>> sc.parallelize([1, 2, 3, 4, 5]).fold(0, add)
+ 15
+ """
+ def func(iterator):
+ acc = zeroValue
+ for obj in iterator:
+ acc = op(obj, acc)
+ yield acc
+ vals = self.mapPartitions(func).collect()
+ return reduce(op, vals, zeroValue)
+
+ # TODO: aggregate
+
+ def sum(self):
+ """
+ Add up the elements in this RDD.
+
+ >>> sc.parallelize([1.0, 2.0, 3.0]).sum()
+ 6.0
+ """
+ return self.mapPartitions(lambda x: [sum(x)]).reduce(operator.add)
+
+ def count(self):
+ """
+ Return the number of elements in this RDD.
+
+ >>> sc.parallelize([2, 3, 4]).count()
+ 3
+ """
+ return self.mapPartitions(lambda i: [sum(1 for _ in i)]).sum()
+
+ def countByValue(self):
+ """
+ Return the count of each unique value in this RDD as a dictionary of
+ (value, count) pairs.
+
+ >>> sorted(sc.parallelize([1, 2, 1, 2, 2], 2).countByValue().items())
+ [(1, 2), (2, 3)]
+ """
+ def countPartition(iterator):
+ counts = defaultdict(int)
+ for obj in iterator:
+ counts[obj] += 1
+ yield counts
+ def mergeMaps(m1, m2):
+ for (k, v) in m2.iteritems():
+ m1[k] += v
+ return m1
+ return self.mapPartitions(countPartition).reduce(mergeMaps)
+
+ def take(self, num):
+ """
+ Take the first num elements of the RDD.
+
+ This currently scans the partitions *one by one*, so it will be slow if
+ a lot of partitions are required. In that case, use L{collect} to get
+ the whole RDD instead.
+
+ >>> sc.parallelize([2, 3, 4, 5, 6]).cache().take(2)
+ [2, 3]
+ >>> sc.parallelize([2, 3, 4, 5, 6]).take(10)
+ [2, 3, 4, 5, 6]
+ """
+ items = []
+ for partition in range(self._jrdd.splits().size()):
+ iterator = self.ctx._takePartition(self._jrdd.rdd(), partition)
+ items.extend(self._collect_iterator_through_file(iterator))
+ if len(items) >= num:
+ break
+ return items[:num]
+
+ def first(self):
+ """
+ Return the first element in this RDD.
+
+ >>> sc.parallelize([2, 3, 4]).first()
+ 2
+ """
+ return self.take(1)[0]
+
+ def saveAsTextFile(self, path):
+ """
+ Save this RDD as a text file, using string representations of elements.
+
+ >>> tempFile = NamedTemporaryFile(delete=True)
+ >>> tempFile.close()
+ >>> sc.parallelize(range(10)).saveAsTextFile(tempFile.name)
+ >>> from fileinput import input
+ >>> from glob import glob
+ >>> ''.join(input(glob(tempFile.name + "/part-0000*")))
+ '0\\n1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n'
+ """
+ def func(split, iterator):
+ return (str(x).encode("utf-8") for x in iterator)
+ keyed = PipelinedRDD(self, func)
+ keyed._bypass_serializer = True
+ keyed._jrdd.map(self.ctx.jvm.BytesToString()).saveAsTextFile(path)
+
+ # Pair functions
+
+ def collectAsMap(self):
+ """
+ Return the key-value pairs in this RDD to the master as a dictionary.
+
+ >>> m = sc.parallelize([(1, 2), (3, 4)]).collectAsMap()
+ >>> m[1]
+ 2
+ >>> m[3]
+ 4
+ """
+ return dict(self.collect())
+
+ def reduceByKey(self, func, numSplits=None):
+ """
+ Merge the values for each key using an associative reduce function.
+
+ This will also perform the merging locally on each mapper before
+ sending results to a reducer, similarly to a "combiner" in MapReduce.
+
+ Output will be hash-partitioned with C{numSplits} splits, or the
+ default parallelism level if C{numSplits} is not specified.
+
+ >>> from operator import add
+ >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
+ >>> sorted(rdd.reduceByKey(add).collect())
+ [('a', 2), ('b', 1)]
+ """
+ return self.combineByKey(lambda x: x, func, func, numSplits)
+
+ def reduceByKeyLocally(self, func):
+ """
+ Merge the values for each key using an associative reduce function, but
+ return the results immediately to the master as a dictionary.
+
+ This will also perform the merging locally on each mapper before
+ sending results to a reducer, similarly to a "combiner" in MapReduce.
+
+ >>> from operator import add
+ >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
+ >>> sorted(rdd.reduceByKeyLocally(add).items())
+ [('a', 2), ('b', 1)]
+ """
+ def reducePartition(iterator):
+ m = {}
+ for (k, v) in iterator:
+ m[k] = v if k not in m else func(m[k], v)
+ yield m
+ def mergeMaps(m1, m2):
+ for (k, v) in m2.iteritems():
+ m1[k] = v if k not in m1 else func(m1[k], v)
+ return m1
+ return self.mapPartitions(reducePartition).reduce(mergeMaps)
+
+ def countByKey(self):
+ """
+ Count the number of elements for each key, and return the result to the
+ master as a dictionary.
+
+ >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
+ >>> sorted(rdd.countByKey().items())
+ [('a', 2), ('b', 1)]
+ """
+ return self.map(lambda x: x[0]).countByValue()
+
+ def join(self, other, numSplits=None):
+ """
+ Return an RDD containing all pairs of elements with matching keys in
+ C{self} and C{other}.
+
+ Each pair of elements will be returned as a (k, (v1, v2)) tuple, where
+ (k, v1) is in C{self} and (k, v2) is in C{other}.
+
+ Performs a hash join across the cluster.
+
+ >>> x = sc.parallelize([("a", 1), ("b", 4)])
+ >>> y = sc.parallelize([("a", 2), ("a", 3)])
+ >>> sorted(x.join(y).collect())
+ [('a', (1, 2)), ('a', (1, 3))]
+ """
+ return python_join(self, other, numSplits)
+
+ def leftOuterJoin(self, other, numSplits=None):
+ """
+ Perform a left outer join of C{self} and C{other}.
+
+ For each element (k, v) in C{self}, the resulting RDD will either
+ contain all pairs (k, (v, w)) for w in C{other}, or the pair
+ (k, (v, None)) if no elements in other have key k.
+
+ Hash-partitions the resulting RDD into the given number of partitions.
+
+ >>> x = sc.parallelize([("a", 1), ("b", 4)])
+ >>> y = sc.parallelize([("a", 2)])
+ >>> sorted(x.leftOuterJoin(y).collect())
+ [('a', (1, 2)), ('b', (4, None))]
+ """
+ return python_left_outer_join(self, other, numSplits)
+
+ def rightOuterJoin(self, other, numSplits=None):
+ """
+ Perform a right outer join of C{self} and C{other}.
+
+ For each element (k, w) in C{other}, the resulting RDD will either
+ contain all pairs (k, (v, w)) for v in this, or the pair (k, (None, w))
+ if no elements in C{self} have key k.
+
+ Hash-partitions the resulting RDD into the given number of partitions.
+
+ >>> x = sc.parallelize([("a", 1), ("b", 4)])
+ >>> y = sc.parallelize([("a", 2)])
+ >>> sorted(y.rightOuterJoin(x).collect())
+ [('a', (2, 1)), ('b', (None, 4))]
+ """
+ return python_right_outer_join(self, other, numSplits)
+
+ # TODO: add option to control map-side combining
+ def partitionBy(self, numSplits, hashFunc=hash):
+ """
+ Return a copy of the RDD partitioned using the specified partitioner.
+
+ >>> pairs = sc.parallelize([1, 2, 3, 4, 2, 4, 1]).map(lambda x: (x, x))
+ >>> sets = pairs.partitionBy(2).glom().collect()
+ >>> set(sets[0]).intersection(set(sets[1]))
+ set([])
+ """
+ if numSplits is None:
+ numSplits = self.ctx.defaultParallelism
+ # Transferring O(n) objects to Java is too expensive. Instead, we'll
+ # form the hash buckets in Python, transferring O(numSplits) objects
+ # to Java. Each object is a (splitNumber, [objects]) pair.
+ def add_shuffle_key(split, iterator):
+ buckets = defaultdict(list)
+ for (k, v) in iterator:
+ buckets[hashFunc(k) % numSplits].append((k, v))
+ for (split, items) in buckets.iteritems():
+ yield str(split)
+ yield dump_pickle(Batch(items))
+ keyed = PipelinedRDD(self, add_shuffle_key)
+ keyed._bypass_serializer = True
+ pairRDD = self.ctx.jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
+ partitioner = self.ctx.jvm.spark.api.python.PythonPartitioner(numSplits)
+ jrdd = pairRDD.partitionBy(partitioner)
+ jrdd = jrdd.map(self.ctx.jvm.ExtractValue())
+ return RDD(jrdd, self.ctx)
+
+ # TODO: add control over map-side aggregation
+ def combineByKey(self, createCombiner, mergeValue, mergeCombiners,
+ numSplits=None):
+ """
+ Generic function to combine the elements for each key using a custom
+ set of aggregation functions.
+
+ Turns an RDD[(K, V)] into a result of type RDD[(K, C)], for a "combined
+ type" C. Note that V and C can be different -- for example, one might
+ group an RDD of type (Int, Int) into an RDD of type (Int, List[Int]).
+
+ Users provide three functions:
+
+ - C{createCombiner}, which turns a V into a C (e.g., creates
+ a one-element list)
+ - C{mergeValue}, to merge a V into a C (e.g., adds it to the end of
+ a list)
+ - C{mergeCombiners}, to combine two C's into a single one.
+
+ In addition, users can control the partitioning of the output RDD.
+
+ >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
+ >>> def f(x): return x
+ >>> def add(a, b): return a + str(b)
+ >>> sorted(x.combineByKey(str, add, add).collect())
+ [('a', '11'), ('b', '1')]
+ """
+ if numSplits is None:
+ numSplits = self.ctx.defaultParallelism
+ def combineLocally(iterator):
+ combiners = {}
+ for (k, v) in iterator:
+ if k not in combiners:
+ combiners[k] = createCombiner(v)
+ else:
+ combiners[k] = mergeValue(combiners[k], v)
+ return combiners.iteritems()
+ locally_combined = self.mapPartitions(combineLocally)
+ shuffled = locally_combined.partitionBy(numSplits)
+ def _mergeCombiners(iterator):
+ combiners = {}
+ for (k, v) in iterator:
+ if not k in combiners:
+ combiners[k] = v
+ else:
+ combiners[k] = mergeCombiners(combiners[k], v)
+ return combiners.iteritems()
+ return shuffled.mapPartitions(_mergeCombiners)
+
+ # TODO: support variant with custom partitioner
+ def groupByKey(self, numSplits=None):
+ """
+ Group the values for each key in the RDD into a single sequence.
+ Hash-partitions the resulting RDD with into numSplits partitions.
+
+ >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
+ >>> sorted(x.groupByKey().collect())
+ [('a', [1, 1]), ('b', [1])]
+ """
+
+ def createCombiner(x):
+ return [x]
+
+ def mergeValue(xs, x):
+ xs.append(x)
+ return xs
+
+ def mergeCombiners(a, b):
+ return a + b
+
+ return self.combineByKey(createCombiner, mergeValue, mergeCombiners,
+ numSplits)
+
+ # TODO: add tests
+ def flatMapValues(self, f):
+ """
+ Pass each value in the key-value pair RDD through a flatMap function
+ without changing the keys; this also retains the original RDD's
+ partitioning.
+ """
+ flat_map_fn = lambda (k, v): ((k, x) for x in f(v))
+ return self.flatMap(flat_map_fn, preservesPartitioning=True)
+
+ def mapValues(self, f):
+ """
+ Pass each value in the key-value pair RDD through a map function
+ without changing the keys; this also retains the original RDD's
+ partitioning.
+ """
+ map_values_fn = lambda (k, v): (k, f(v))
+ return self.map(map_values_fn, preservesPartitioning=True)
+
+ # TODO: support varargs cogroup of several RDDs.
+ def groupWith(self, other):
+ """
+ Alias for cogroup.
+ """
+ return self.cogroup(other)
+
+ # TODO: add variant with custom parittioner
+ def cogroup(self, other, numSplits=None):
+ """
+ For each key k in C{self} or C{other}, return a resulting RDD that
+ contains a tuple with the list of values for that key in C{self} as well
+ as C{other}.
+
+ >>> x = sc.parallelize([("a", 1), ("b", 4)])
+ >>> y = sc.parallelize([("a", 2)])
+ >>> sorted(x.cogroup(y).collect())
+ [('a', ([1], [2])), ('b', ([4], []))]
+ """
+ return python_cogroup(self, other, numSplits)
+
+ # TODO: `lookup` is disabled because we can't make direct comparisons based
+ # on the key; we need to compare the hash of the key to the hash of the
+ # keys in the pairs. This could be an expensive operation, since those
+ # hashes aren't retained.
+
+
+class PipelinedRDD(RDD):
+ """
+ Pipelined maps:
+ >>> rdd = sc.parallelize([1, 2, 3, 4])
+ >>> rdd.map(lambda x: 2 * x).cache().map(lambda x: 2 * x).collect()
+ [4, 8, 12, 16]
+ >>> rdd.map(lambda x: 2 * x).map(lambda x: 2 * x).collect()
+ [4, 8, 12, 16]
+
+ Pipelined reduces:
+ >>> from operator import add
+ >>> rdd.map(lambda x: 2 * x).reduce(add)
+ 20
+ >>> rdd.flatMap(lambda x: [x, x]).reduce(add)
+ 20
+ """
+ def __init__(self, prev, func, preservesPartitioning=False):
+ if isinstance(prev, PipelinedRDD) and not prev.is_cached:
+ prev_func = prev.func
+ def pipeline_func(split, iterator):
+ return func(split, prev_func(split, iterator))
+ self.func = pipeline_func
+ self.preservesPartitioning = \
+ prev.preservesPartitioning and preservesPartitioning
+ self._prev_jrdd = prev._prev_jrdd
+ else:
+ self.func = func
+ self.preservesPartitioning = preservesPartitioning
+ self._prev_jrdd = prev._jrdd
+ self.is_cached = False
+ self.ctx = prev.ctx
+ self.prev = prev
+ self._jrdd_val = None
+ self._bypass_serializer = False
+
+ @property
+ def _jrdd(self):
+ if self._jrdd_val:
+ return self._jrdd_val
+ func = self.func
+ if not self._bypass_serializer and self.ctx.batchSize != 1:
+ oldfunc = self.func
+ batchSize = self.ctx.batchSize
+ def batched_func(split, iterator):
+ return batched(oldfunc(split, iterator), batchSize)
+ func = batched_func
+ cmds = [func, self._bypass_serializer]
+ pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds)
+ broadcast_vars = ListConverter().convert(
+ [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
+ self.ctx.gateway._gateway_client)
+ self.ctx._pickled_broadcast_vars.clear()
+ class_manifest = self._prev_jrdd.classManifest()
+ env = copy.copy(self.ctx.environment)
+ env['PYTHONPATH'] = os.environ.get("PYTHONPATH", "")
+ env = MapConverter().convert(env, self.ctx.gateway._gateway_client)
+ python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(),
+ pipe_command, env, self.preservesPartitioning, self.ctx.pythonExec,
+ broadcast_vars, class_manifest)
+ self._jrdd_val = python_rdd.asJavaRDD()
+ return self._jrdd_val
+
+
+def _test():
+ import doctest
+ from pyspark.context import SparkContext
+ globs = globals().copy()
+ # The small batch size here ensures that we see multiple batches,
+ # even in these small test examples:
+ globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
+ doctest.testmod(globs=globs)
+ globs['sc'].stop()
+
+
+if __name__ == "__main__":
+ _test()
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
new file mode 100644
index 0000000000..9a5151ea00
--- /dev/null
+++ b/python/pyspark/serializers.py
@@ -0,0 +1,78 @@
+import struct
+import cPickle
+
+
+class Batch(object):
+ """
+ Used to store multiple RDD entries as a single Java object.
+
+ This relieves us from having to explicitly track whether an RDD
+ is stored as batches of objects and avoids problems when processing
+ the union() of batched and unbatched RDDs (e.g. the union() of textFile()
+ with another RDD).
+ """
+ def __init__(self, items):
+ self.items = items
+
+
+def batched(iterator, batchSize):
+ if batchSize == -1: # unlimited batch size
+ yield Batch(list(iterator))
+ else:
+ items = []
+ count = 0
+ for item in iterator:
+ items.append(item)
+ count += 1
+ if count == batchSize:
+ yield Batch(items)
+ items = []
+ count = 0
+ if items:
+ yield Batch(items)
+
+
+def dump_pickle(obj):
+ return cPickle.dumps(obj, 2)
+
+
+load_pickle = cPickle.loads
+
+
+def read_long(stream):
+ length = stream.read(8)
+ if length == "":
+ raise EOFError
+ return struct.unpack("!q", length)[0]
+
+
+def read_int(stream):
+ length = stream.read(4)
+ if length == "":
+ raise EOFError
+ return struct.unpack("!i", length)[0]
+
+def write_with_length(obj, stream):
+ stream.write(struct.pack("!i", len(obj)))
+ stream.write(obj)
+
+
+def read_with_length(stream):
+ length = read_int(stream)
+ obj = stream.read(length)
+ if obj == "":
+ raise EOFError
+ return obj
+
+
+def read_from_pickle_file(stream):
+ try:
+ while True:
+ obj = load_pickle(read_with_length(stream))
+ if type(obj) == Batch: # We don't care about inheritance
+ for item in obj.items:
+ yield item
+ else:
+ yield obj
+ except EOFError:
+ return
diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py
new file mode 100644
index 0000000000..7e6ad3aa76
--- /dev/null
+++ b/python/pyspark/shell.py
@@ -0,0 +1,17 @@
+"""
+An interactive shell.
+
+This fle is designed to be launched as a PYTHONSTARTUP script.
+"""
+import os
+from pyspark.context import SparkContext
+
+
+sc = SparkContext(os.environ.get("MASTER", "local"), "PySparkShell")
+print "Spark context avaiable as sc."
+
+# The ./pyspark script stores the old PYTHONSTARTUP value in OLD_PYTHONSTARTUP,
+# which allows us to execute the user's PYTHONSTARTUP file:
+_pythonstartup = os.environ.get('OLD_PYTHONSTARTUP')
+if _pythonstartup and os.path.isfile(_pythonstartup):
+ execfile(_pythonstartup)
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
new file mode 100644
index 0000000000..3d792bbaa2
--- /dev/null
+++ b/python/pyspark/worker.py
@@ -0,0 +1,42 @@
+"""
+Worker that receives input from Piped RDD.
+"""
+import sys
+from base64 import standard_b64decode
+# CloudPickler needs to be imported so that depicklers are registered using the
+# copy_reg module.
+from pyspark.broadcast import Broadcast, _broadcastRegistry
+from pyspark.cloudpickle import CloudPickler
+from pyspark.serializers import write_with_length, read_with_length, \
+ read_long, read_int, dump_pickle, load_pickle, read_from_pickle_file
+
+
+# Redirect stdout to stderr so that users must return values from functions.
+old_stdout = sys.stdout
+sys.stdout = sys.stderr
+
+
+def load_obj():
+ return load_pickle(standard_b64decode(sys.stdin.readline().strip()))
+
+
+def main():
+ split_index = read_int(sys.stdin)
+ num_broadcast_variables = read_int(sys.stdin)
+ for _ in range(num_broadcast_variables):
+ bid = read_long(sys.stdin)
+ value = read_with_length(sys.stdin)
+ _broadcastRegistry[bid] = Broadcast(bid, load_pickle(value))
+ func = load_obj()
+ bypassSerializer = load_obj()
+ if bypassSerializer:
+ dumps = lambda x: x
+ else:
+ dumps = dump_pickle
+ iterator = read_from_pickle_file(sys.stdin)
+ for obj in func(split_index, iterator):
+ write_with_length(dumps(obj), old_stdout)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/python/run-tests b/python/run-tests
new file mode 100755
index 0000000000..fcdd1e27a7
--- /dev/null
+++ b/python/run-tests
@@ -0,0 +1,26 @@
+#!/usr/bin/env bash
+
+# Figure out where the Scala framework is installed
+FWDIR="$(cd `dirname $0`; cd ../; pwd)"
+
+FAILED=0
+
+$FWDIR/pyspark pyspark/rdd.py
+FAILED=$(($?||$FAILED))
+
+$FWDIR/pyspark -m doctest pyspark/broadcast.py
+FAILED=$(($?||$FAILED))
+
+if [[ $FAILED != 0 ]]; then
+ echo -en "\033[31m" # Red
+ echo "Had test failures; see logs."
+ echo -en "\033[0m" # No color
+ exit -1
+else
+ echo -en "\033[32m" # Green
+ echo "Tests passed."
+ echo -en "\033[0m" # No color
+fi
+
+# TODO: in the long-run, it would be nice to use a test runner like `nose`.
+# The doctest fixtures are the current barrier to doing this.
diff --git a/run b/run
index 6cfe9631af..ca23455386 100755
--- a/run
+++ b/run
@@ -63,6 +63,7 @@ CORE_DIR="$FWDIR/core"
REPL_DIR="$FWDIR/repl"
EXAMPLES_DIR="$FWDIR/examples"
BAGEL_DIR="$FWDIR/bagel"
+PYSPARK_DIR="$FWDIR/python"
# Build up classpath
CLASSPATH="$SPARK_CLASSPATH"
@@ -83,6 +84,9 @@ for jar in `find "$REPL_DIR/target" -name 'spark-repl-*-shaded-hadoop*.jar'`; do
CLASSPATH+=":$jar"
done
CLASSPATH+=":$BAGEL_DIR/target/scala-$SCALA_VERSION/classes"
+for jar in `find $PYSPARK_DIR/lib -name '*jar'`; do
+ CLASSPATH+=":$jar"
+done
export CLASSPATH # Needed for spark-shell
# Figure out whether to run our class with java or with the scala launcher.
diff --git a/run2.cmd b/run2.cmd
index 333d0506b0..83464b1166 100644
--- a/run2.cmd
+++ b/run2.cmd
@@ -34,6 +34,7 @@ set CORE_DIR=%FWDIR%core
set REPL_DIR=%FWDIR%repl
set EXAMPLES_DIR=%FWDIR%examples
set BAGEL_DIR=%FWDIR%bagel
+set PYSPARK_DIR=%FWDIR%python
rem Build up classpath
set CLASSPATH=%SPARK_CLASSPATH%;%MESOS_CLASSPATH%;%FWDIR%conf;%CORE_DIR%\target\scala-%SCALA_VERSION%\classes
@@ -42,6 +43,7 @@ set CLASSPATH=%CLASSPATH%;%REPL_DIR%\target\scala-%SCALA_VERSION%\classes;%EXAMP
for /R "%FWDIR%\lib_managed\jars" %%j in (*.jar) do set CLASSPATH=!CLASSPATH!;%%j
for /R "%FWDIR%\lib_managed\bundles" %%j in (*.jar) do set CLASSPATH=!CLASSPATH!;%%j
for /R "%REPL_DIR%\lib" %%j in (*.jar) do set CLASSPATH=!CLASSPATH!;%%j
+for /R "%PYSPARK_DIR%\lib" %%j in (*.jar) do set CLASSPATH=!CLASSPATH!;%%j
set CLASSPATH=%CLASSPATH%;%BAGEL_DIR%\target\scala-%SCALA_VERSION%\classes
rem Figure out whether to run our class with java or with the scala launcher.