-rw-r--r--  core/src/main/scala/spark/api/python/PythonRDD.scala    9
-rw-r--r--  python/pyspark/context.py                               14
-rw-r--r--  python/pyspark/rdd.py                                    4
-rw-r--r--  python/pyspark/tests.py                                 11
-rw-r--r--  python/pyspark/worker.py                                13
-rw-r--r--  python/test_support/userlib-0.1-py2.7.egg               bin 0 -> 1945 bytes
6 files changed, 45 insertions(+), 6 deletions(-)
diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala
index 2dd79f7100..49671437d0 100644
--- a/core/src/main/scala/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/spark/api/python/PythonRDD.scala
@@ -33,6 +33,7 @@ private[spark] class PythonRDD[T: ClassManifest](
parent: RDD[T],
command: Seq[String],
envVars: JMap[String, String],
+ pythonIncludes: JList[String],
preservePartitoning: Boolean,
pythonExec: String,
broadcastVars: JList[Broadcast[Array[Byte]]],
@@ -44,10 +45,11 @@ private[spark] class PythonRDD[T: ClassManifest](
// Similar to Runtime.exec(), if we are given a single string, split it into words
// using a standard StringTokenizer (i.e. by spaces)
def this(parent: RDD[T], command: String, envVars: JMap[String, String],
+ pythonIncludes: JList[String],
preservePartitoning: Boolean, pythonExec: String,
broadcastVars: JList[Broadcast[Array[Byte]]],
accumulator: Accumulator[JList[Array[Byte]]]) =
- this(parent, PipedRDD.tokenize(command), envVars, preservePartitoning, pythonExec,
+ this(parent, PipedRDD.tokenize(command), envVars, pythonIncludes, preservePartitoning, pythonExec,
broadcastVars, accumulator)
override def getPartitions = parent.partitions
@@ -79,6 +81,11 @@ private[spark] class PythonRDD[T: ClassManifest](
dataOut.writeInt(broadcast.value.length)
dataOut.write(broadcast.value)
}
+ // Python includes (*.zip and *.egg files)
+ dataOut.writeInt(pythonIncludes.length)
+ for (f <- pythonIncludes) {
+ PythonRDD.writeAsPickle(f, dataOut)
+ }
dataOut.flush()
// Serialized user code
for (elem <- command) {
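Together with the worker.py change further down, the per-task stream now carries the include filenames between the broadcast variables and the serialized user code. A rough sketch of the worker-side read order, reconstructed from the hunks in this diff (helper names as in pyspark.worker):

    # sketch of the stream layout consumed by the Python worker
    split_index = read_int(infile)                          # partition index
    files_dir = load_pickle(read_with_length(infile))       # SparkFiles root dir
    for _ in range(read_int(infile)):                       # broadcast variables
        bid = read_long(infile)
        _broadcastRegistry[bid] = Broadcast(bid, load_pickle(read_with_length(infile)))
    for _ in range(read_int(infile)):                       # *.zip / *.egg includes
        sys.path.append(os.path.join(files_dir, load_pickle(read_with_length(infile))))
    func = load_obj(infile)                                 # serialized user code follows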
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index c2b49ff37a..2803ce90f3 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -46,6 +46,7 @@ class SparkContext(object):
_next_accum_id = 0
_active_spark_context = None
_lock = Lock()
+ _python_includes = None # zip and egg files that need to be added to PYTHONPATH
def __init__(self, master, jobName, sparkHome=None, pyFiles=None,
environment=None, batchSize=1024):
@@ -103,11 +104,14 @@ class SparkContext(object):
# send.
self._pickled_broadcast_vars = set()
+ SparkFiles._sc = self
+ root_dir = SparkFiles.getRootDirectory()
+ sys.path.append(root_dir)
+
# Deploy any code dependencies specified in the constructor
+ self._python_includes = list()
for path in (pyFiles or []):
self.addPyFile(path)
- SparkFiles._sc = self
- sys.path.append(SparkFiles.getRootDirectory())
# Create a temporary directory inside spark.local.dir:
local_dir = self._jvm.spark.Utils.getLocalDir()
@@ -257,7 +261,11 @@ class SparkContext(object):
HTTP, HTTPS or FTP URI.
"""
self.addFile(path)
- filename = path.split("/")[-1]
+ (dirname, filename) = os.path.split(path) # dirname may be directory or HDFS/S3 prefix
+
+ if filename.endswith('.zip') or filename.endswith('.ZIP') or filename.endswith('.egg'):
+ self._python_includes.append(filename)
+ sys.path.append(os.path.join(SparkFiles.getRootDirectory(), filename)) # for tests in local mode
def setCheckpointDir(self, dirName, useExisting=False):
"""
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 51c2cb9806..99f5967a8e 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -758,8 +758,10 @@ class PipelinedRDD(RDD):
class_manifest = self._prev_jrdd.classManifest()
env = MapConverter().convert(self.ctx.environment,
self.ctx._gateway._gateway_client)
+ includes = ListConverter().convert(self.ctx._python_includes,
+ self.ctx._gateway._gateway_client)
python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(),
- pipe_command, env, self.preservesPartitioning, self.ctx.pythonExec,
+ pipe_command, env, includes, self.preservesPartitioning, self.ctx.pythonExec,
broadcast_vars, self.ctx._javaAccumulator, class_manifest)
self._jrdd_val = python_rdd.asJavaRDD()
return self._jrdd_val
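The include list is handed to the JVM the same way the environment dict already is: Py4J's ListConverter turns the Python list of filenames into a java.util.List matching the new pythonIncludes constructor parameter. A standalone illustration of that conversion, assuming `gateway` is an already-connected py4j JavaGateway:

    from py4j.java_collections import ListConverter

    # convert a Python list into a java.util.List for a JVM-side call
    includes = ListConverter().convert(["userlib-0.1-py2.7.egg"],
                                       gateway._gateway_client)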
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index f75215a781..29d6a128f6 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -125,6 +125,17 @@ class TestAddFile(PySparkTestCase):
from userlibrary import UserClass
self.assertEqual("Hello World!", UserClass().hello())
+ def test_add_egg_file_locally(self):
+ # To ensure that we're actually testing addPyFile's effects, check that
+ # this fails due to `userlibrary` not being on the Python path:
+ def func():
+ from userlib import UserClass
+ self.assertRaises(ImportError, func)
+ path = os.path.join(SPARK_HOME, "python/test_support/userlib-0.1-py2.7.egg")
+ self.sc.addPyFile(path)
+ from userlib import UserClass
+ self.assertEqual("Hello World from inside a package!", UserClass().hello())
+
class TestIO(PySparkTestCase):
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 75d692beeb..695f6dfb84 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -49,15 +49,26 @@ def main(infile, outfile):
split_index = read_int(infile)
if split_index == -1: # for unit tests
return
+
+ # fetch name of workdir
spark_files_dir = load_pickle(read_with_length(infile))
SparkFiles._root_directory = spark_files_dir
SparkFiles._is_running_on_worker = True
- sys.path.append(spark_files_dir)
+
+ # fetch names and values of broadcast variables
num_broadcast_variables = read_int(infile)
for _ in range(num_broadcast_variables):
bid = read_long(infile)
value = read_with_length(infile)
_broadcastRegistry[bid] = Broadcast(bid, load_pickle(value))
+
+ # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
+ sys.path.append(spark_files_dir) # *.py files that were added will be copied here
+ num_python_includes = read_int(infile)
+ for _ in range(num_python_includes):
+ sys.path.append(os.path.join(spark_files_dir, load_pickle(read_with_length(infile))))
+
+ # now load function
func = load_obj(infile)
bypassSerializer = load_obj(infile)
if bypassSerializer:
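Appending the archive path itself to sys.path is sufficient because Python's zipimport machinery loads packages directly out of .zip (and therefore .egg) files. A standalone illustration, with a hypothetical worker directory; `userlib` and `UserClass` are the names used by the test above:

    import os
    import sys

    spark_files_dir = "/tmp/spark-worker-files"                  # hypothetical work dir
    egg = os.path.join(spark_files_dir, "userlib-0.1-py2.7.egg")
    sys.path.append(egg)                                         # zipimport handles the archive
    from userlib import UserClass                                # imported from inside the egg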
diff --git a/python/test_support/userlib-0.1-py2.7.egg b/python/test_support/userlib-0.1-py2.7.egg
new file mode 100644
index 0000000000..1674c9cb22
--- /dev/null
+++ b/python/test_support/userlib-0.1-py2.7.egg
Binary files differ
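The test egg is checked in as a binary artifact, so its contents are not visible in the diff. One plausible way such an egg could be produced with setuptools is sketched below; the setup.py itself is an assumption, with the package and class names taken from the test above:

    # setup.py (hypothetical) -- build with: python setup.py bdist_egg
    from setuptools import setup, find_packages

    setup(
        name="userlib",
        version="0.1",
        packages=find_packages(),   # expects a userlib/ package defining UserClass
    )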