From c854b9fcb5595b1d70b6ce257fc7574602ac5e49 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 24 Sep 2014 12:10:09 -0700 Subject: [SPARK-3634] [PySpark] User's module should take precedence over system modules Python modules added through addPyFile should take precedence over system modules. This patch put the path for user added module in the front of sys.path (just after ''). Author: Davies Liu Closes #2492 from davies/path and squashes the following commits: 4a2af78 [Davies Liu] fix tests f7ff4da [Davies Liu] ad license header 6b0002f [Davies Liu] add tests c16c392 [Davies Liu] put addPyFile in front of sys.path --- python/pyspark/context.py | 11 +++++------ python/pyspark/tests.py | 12 ++++++++++++ python/pyspark/worker.py | 11 +++++++++-- 3 files changed, 26 insertions(+), 8 deletions(-) (limited to 'python/pyspark') diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 064a24bff5..8e7b00469e 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -171,7 +171,7 @@ class SparkContext(object): SparkFiles._sc = self root_dir = SparkFiles.getRootDirectory() - sys.path.append(root_dir) + sys.path.insert(1, root_dir) # Deploy any code dependencies specified in the constructor self._python_includes = list() @@ -183,10 +183,9 @@ class SparkContext(object): for path in self._conf.get("spark.submit.pyFiles", "").split(","): if path != "": (dirname, filename) = os.path.split(path) - self._python_includes.append(filename) - sys.path.append(path) - if dirname not in sys.path: - sys.path.append(dirname) + if filename.lower().endswith("zip") or filename.lower().endswith("egg"): + self._python_includes.append(filename) + sys.path.insert(1, os.path.join(SparkFiles.getRootDirectory(), filename)) # Create a temporary directory inside spark.local.dir: local_dir = self._jvm.org.apache.spark.util.Utils.getLocalDir(self._jsc.sc().conf()) @@ -667,7 +666,7 @@ class SparkContext(object): if filename.endswith('.zip') or filename.endswith('.ZIP') or filename.endswith('.egg'): self._python_includes.append(filename) # for tests in local mode - sys.path.append(os.path.join(SparkFiles.getRootDirectory(), filename)) + sys.path.insert(1, os.path.join(SparkFiles.getRootDirectory(), filename)) def setCheckpointDir(self, dirName): """ diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 1b8afb763b..4483bf80db 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -323,6 +323,18 @@ class TestAddFile(PySparkTestCase): from userlib import UserClass self.assertEqual("Hello World from inside a package!", UserClass().hello()) + def test_overwrite_system_module(self): + self.sc.addPyFile(os.path.join(SPARK_HOME, "python/test_support/SimpleHTTPServer.py")) + + import SimpleHTTPServer + self.assertEqual("My Server", SimpleHTTPServer.__name__) + + def func(x): + import SimpleHTTPServer + return SimpleHTTPServer.__name__ + + self.assertEqual(["My Server"], self.sc.parallelize(range(1)).map(func).collect()) + class TestRDDFunctions(PySparkTestCase): diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index d6c06e2dbe..c1f6e3e4a1 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -43,6 +43,13 @@ def report_times(outfile, boot, init, finish): write_long(1000 * finish, outfile) +def add_path(path): + # worker can be used, so donot add path multiple times + if path not in sys.path: + # overwrite system packages + sys.path.insert(1, path) + + def main(infile, outfile): try: boot_time = time.time() @@ -61,11 +68,11 @@ def main(infile, outfile): SparkFiles._is_running_on_worker = True # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH - sys.path.append(spark_files_dir) # *.py files that were added will be copied here + add_path(spark_files_dir) # *.py files that were added will be copied here num_python_includes = read_int(infile) for _ in range(num_python_includes): filename = utf8_deserializer.loads(infile) - sys.path.append(os.path.join(spark_files_dir, filename)) + add_path(os.path.join(spark_files_dir, filename)) # fetch names and values of broadcast variables num_broadcast_variables = read_int(infile) -- cgit v1.2.3