author     Josh Rosen <joshrosen@eecs.berkeley.edu>    2013-01-22 17:54:11 -0800
committer  Josh Rosen <joshrosen@eecs.berkeley.edu>    2013-01-22 17:54:11 -0800
commit     35168d9c89904f0dc0bb470c1799f5ca3b04221f (patch)
tree       39bfa48768f4382ea1104506eed4ff822c881c6c /python
parent     7b9e96c99206c0679d9925e0161fde738a5c7c3a (diff)
Fix sys.path bug in PySpark SparkContext.addPyFile
Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/context.py            2
-rw-r--r--  python/pyspark/tests.py             38
-rw-r--r--  python/pyspark/worker.py             1
-rwxr-xr-x  python/test_support/userlibrary.py   7
4 files changed, 41 insertions, 7 deletions
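
The commit message is terse, so a quick illustration of the intended behavior may help: addPyFile ships a single Python file to every worker so that functions executed on the cluster can import it by module name. The sketch below mirrors the new test added in tests.py; the file path and master URL are placeholders for illustration, not part of the patch.

from pyspark.context import SparkContext

sc = SparkContext('local[4]', 'AddPyFileExample')
sc.addPyFile('/path/to/userlibrary.py')  # ship the dependency to every worker

def use_dependency(x):
    # importable on workers once the shipped files directory is on sys.path
    from userlibrary import UserClass
    return UserClass().hello()

print(sc.parallelize(range(2)).map(use_dependency).first())  # "Hello World!"
sc.stop()
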
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index ec0cc7c2f9..b8d7dc05af 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -215,8 +215,6 @@ class SparkContext(object):
"""
self.addFile(path)
filename = path.split("/")[-1]
- os.environ["PYTHONPATH"] = \
- "%s:%s" % (filename, os.environ["PYTHONPATH"])
def setCheckpointDir(self, dirName, useExisting=False):
"""
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index b0a403b580..4d70ee4f12 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -9,21 +9,32 @@ import time
import unittest
from pyspark.context import SparkContext
+from pyspark.java_gateway import SPARK_HOME
-class TestCheckpoint(unittest.TestCase):
+class PySparkTestCase(unittest.TestCase):
def setUp(self):
- self.sc = SparkContext('local[4]', 'TestPartitioning', batchSize=2)
- self.checkpointDir = NamedTemporaryFile(delete=False)
- os.unlink(self.checkpointDir.name)
- self.sc.setCheckpointDir(self.checkpointDir.name)
+ class_name = self.__class__.__name__
+ self.sc = SparkContext('local[4]', class_name, batchSize=2)
def tearDown(self):
self.sc.stop()
# To avoid Akka rebinding to the same port, since it doesn't unbind
# immediately on shutdown
self.sc.jvm.System.clearProperty("spark.master.port")
+
+
+class TestCheckpoint(PySparkTestCase):
+
+ def setUp(self):
+ PySparkTestCase.setUp(self)
+ self.checkpointDir = NamedTemporaryFile(delete=False)
+ os.unlink(self.checkpointDir.name)
+ self.sc.setCheckpointDir(self.checkpointDir.name)
+
+ def tearDown(self):
+ PySparkTestCase.tearDown(self)
shutil.rmtree(self.checkpointDir.name)
def test_basic_checkpointing(self):
@@ -57,5 +68,22 @@ class TestCheckpoint(unittest.TestCase):
self.assertEquals([1, 2, 3, 4], recovered.collect())
+class TestAddFile(PySparkTestCase):
+
+ def test_add_py_file(self):
+ # To ensure that we're actually testing addPyFile's effects, check that
+ # this job fails due to `userlibrary` not being on the Python path:
+ def func(x):
+ from userlibrary import UserClass
+ return UserClass().hello()
+ self.assertRaises(Exception,
+ self.sc.parallelize(range(2)).map(func).first)
+ # Add the file, so the job should now succeed:
+ path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py")
+ self.sc.addPyFile(path)
+ res = self.sc.parallelize(range(2)).map(func).first()
+ self.assertEqual("Hello World!", res)
+
+
if __name__ == "__main__":
unittest.main()
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index e7bdb7682b..4bf643da66 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -26,6 +26,7 @@ def main():
split_index = read_int(sys.stdin)
spark_files_dir = load_pickle(read_with_length(sys.stdin))
SparkFiles._root_directory = spark_files_dir
+ sys.path.append(spark_files_dir)
num_broadcast_variables = read_int(sys.stdin)
for _ in range(num_broadcast_variables):
bid = read_long(sys.stdin)
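
This one-line worker change is the heart of the fix: files shipped through addFile/addPyFile are materialized in the per-task files directory, and appending that directory to sys.path lets the worker's interpreter import them by module name. A standalone sketch of the mechanism, using a throwaway directory in place of a real SparkFiles root:

import os
import sys
import tempfile

files_dir = tempfile.mkdtemp()   # stands in for the SparkFiles root directory
with open(os.path.join(files_dir, "userlibrary.py"), "w") as f:
    f.write("def hello():\n    return 'Hello World!'\n")

sys.path.append(files_dir)   # same idea as worker.py's sys.path.append(spark_files_dir)
import userlibrary           # now resolvable by module name
print(userlibrary.hello())   # Hello World!
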
diff --git a/python/test_support/userlibrary.py b/python/test_support/userlibrary.py
new file mode 100755
index 0000000000..5bb6f5009f
--- /dev/null
+++ b/python/test_support/userlibrary.py
@@ -0,0 +1,7 @@
+"""
+Used to test shipping of code dependencies with SparkContext.addPyFile().
+"""
+
+class UserClass(object):
+ def hello(self):
+ return "Hello World!"