From 35168d9c89904f0dc0bb470c1799f5ca3b04221f Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Tue, 22 Jan 2013 17:54:11 -0800
Subject: Fix sys.path bug in PySpark SparkContext.addPyFile

---
 python/pyspark/context.py          |  2 --
 python/pyspark/tests.py            | 38 +++++++++++++++++++++++++++++++++-----
 python/pyspark/worker.py           |  1 +
 python/test_support/userlibrary.py |  7 +++++++
 4 files changed, 41 insertions(+), 7 deletions(-)
 create mode 100755 python/test_support/userlibrary.py

diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index ec0cc7c2f9..b8d7dc05af 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -215,8 +215,6 @@ class SparkContext(object):
         """
         self.addFile(path)
         filename = path.split("/")[-1]
-        os.environ["PYTHONPATH"] = \
-            "%s:%s" % (filename, os.environ["PYTHONPATH"])
 
     def setCheckpointDir(self, dirName, useExisting=False):
         """

diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index b0a403b580..4d70ee4f12 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -9,21 +9,32 @@ import time
 import unittest
 
 from pyspark.context import SparkContext
+from pyspark.java_gateway import SPARK_HOME
 
 
-class TestCheckpoint(unittest.TestCase):
+class PySparkTestCase(unittest.TestCase):
 
     def setUp(self):
-        self.sc = SparkContext('local[4]', 'TestPartitioning', batchSize=2)
-        self.checkpointDir = NamedTemporaryFile(delete=False)
-        os.unlink(self.checkpointDir.name)
-        self.sc.setCheckpointDir(self.checkpointDir.name)
+        class_name = self.__class__.__name__
+        self.sc = SparkContext('local[4]', class_name, batchSize=2)
 
     def tearDown(self):
         self.sc.stop()
         # To avoid Akka rebinding to the same port, since it doesn't unbind
         # immediately on shutdown
         self.sc.jvm.System.clearProperty("spark.master.port")
+
+
+class TestCheckpoint(PySparkTestCase):
+
+    def setUp(self):
+        PySparkTestCase.setUp(self)
+        self.checkpointDir = NamedTemporaryFile(delete=False)
+        os.unlink(self.checkpointDir.name)
+        self.sc.setCheckpointDir(self.checkpointDir.name)
+
+    def tearDown(self):
+        PySparkTestCase.tearDown(self)
         shutil.rmtree(self.checkpointDir.name)
 
     def test_basic_checkpointing(self):
@@ -57,5 +68,22 @@ class TestCheckpoint(unittest.TestCase):
         self.assertEquals([1, 2, 3, 4], recovered.collect())
 
 
+class TestAddFile(PySparkTestCase):
+
+    def test_add_py_file(self):
+        # To ensure that we're actually testing addPyFile's effects, check that
+        # this job fails due to `userlibrary` not being on the Python path:
+        def func(x):
+            from userlibrary import UserClass
+            return UserClass().hello()
+        self.assertRaises(Exception,
+                          self.sc.parallelize(range(2)).map(func).first)
+        # Add the file, so the job should now succeed:
+        path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py")
+        self.sc.addPyFile(path)
+        res = self.sc.parallelize(range(2)).map(func).first()
+        self.assertEqual("Hello World!", res)
+
+
 if __name__ == "__main__":
     unittest.main()

diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index e7bdb7682b..4bf643da66 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -26,6 +26,7 @@ def main():
     split_index = read_int(sys.stdin)
     spark_files_dir = load_pickle(read_with_length(sys.stdin))
     SparkFiles._root_directory = spark_files_dir
+    sys.path.append(spark_files_dir)
     num_broadcast_variables = read_int(sys.stdin)
     for _ in range(num_broadcast_variables):
         bid = read_long(sys.stdin)

diff --git a/python/test_support/userlibrary.py b/python/test_support/userlibrary.py
new file mode 100755
index 0000000000..5bb6f5009f
--- /dev/null
+++ b/python/test_support/userlibrary.py
@@ -0,0 +1,7 @@
+"""
+Used to test shipping of code dependencies with SparkContext.addPyFile().
+"""
+
+class UserClass(object):
+    def hello(self):
+        return "Hello World!"
-- 
cgit v1.2.3
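
For illustration, a minimal driver-side sketch of the behavior this patch
enables, modeled on the new TestAddFile test above. The master URL, app name,
and file path here are assumptions for the example, not part of the patch:

    from pyspark.context import SparkContext

    sc = SparkContext('local[4]', 'AddPyFileExample')

    # Ship a dependency to every worker. With this patch, the directory that
    # the shipped file is downloaded into is appended to each worker's
    # sys.path, so the module becomes importable inside remote tasks.
    sc.addPyFile("python/test_support/userlibrary.py")

    def func(x):
        # This import runs on the worker, where userlibrary is now on sys.path.
        from userlibrary import UserClass
        return UserClass().hello()

    print(sc.parallelize(range(2)).map(func).collect())
    # ['Hello World!', 'Hello World!']
    sc.stop()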