diff options
-rw-r--r-- | python/pyspark/context.py | 11 | ||||
-rw-r--r-- | python/pyspark/tests.py | 12 | ||||
-rw-r--r-- | python/pyspark/worker.py | 11 | ||||
-rw-r--r-- | python/test_support/SimpleHTTPServer.py | 22 |
4 files changed, 48 insertions, 8 deletions
diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 064a24bff5..8e7b00469e 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -171,7 +171,7 @@ class SparkContext(object): SparkFiles._sc = self root_dir = SparkFiles.getRootDirectory() - sys.path.append(root_dir) + sys.path.insert(1, root_dir) # Deploy any code dependencies specified in the constructor self._python_includes = list() @@ -183,10 +183,9 @@ class SparkContext(object): for path in self._conf.get("spark.submit.pyFiles", "").split(","): if path != "": (dirname, filename) = os.path.split(path) - self._python_includes.append(filename) - sys.path.append(path) - if dirname not in sys.path: - sys.path.append(dirname) + if filename.lower().endswith("zip") or filename.lower().endswith("egg"): + self._python_includes.append(filename) + sys.path.insert(1, os.path.join(SparkFiles.getRootDirectory(), filename)) # Create a temporary directory inside spark.local.dir: local_dir = self._jvm.org.apache.spark.util.Utils.getLocalDir(self._jsc.sc().conf()) @@ -667,7 +666,7 @@ class SparkContext(object): if filename.endswith('.zip') or filename.endswith('.ZIP') or filename.endswith('.egg'): self._python_includes.append(filename) # for tests in local mode - sys.path.append(os.path.join(SparkFiles.getRootDirectory(), filename)) + sys.path.insert(1, os.path.join(SparkFiles.getRootDirectory(), filename)) def setCheckpointDir(self, dirName): """ diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 1b8afb763b..4483bf80db 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -323,6 +323,18 @@ class TestAddFile(PySparkTestCase): from userlib import UserClass self.assertEqual("Hello World from inside a package!", UserClass().hello()) + def test_overwrite_system_module(self): + self.sc.addPyFile(os.path.join(SPARK_HOME, "python/test_support/SimpleHTTPServer.py")) + + import SimpleHTTPServer + self.assertEqual("My Server", SimpleHTTPServer.__name__) + + def func(x): + import SimpleHTTPServer + return SimpleHTTPServer.__name__ + + self.assertEqual(["My Server"], self.sc.parallelize(range(1)).map(func).collect()) + class TestRDDFunctions(PySparkTestCase): diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index d6c06e2dbe..c1f6e3e4a1 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -43,6 +43,13 @@ def report_times(outfile, boot, init, finish): write_long(1000 * finish, outfile) +def add_path(path): + # worker can be used, so donot add path multiple times + if path not in sys.path: + # overwrite system packages + sys.path.insert(1, path) + + def main(infile, outfile): try: boot_time = time.time() @@ -61,11 +68,11 @@ def main(infile, outfile): SparkFiles._is_running_on_worker = True # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH - sys.path.append(spark_files_dir) # *.py files that were added will be copied here + add_path(spark_files_dir) # *.py files that were added will be copied here num_python_includes = read_int(infile) for _ in range(num_python_includes): filename = utf8_deserializer.loads(infile) - sys.path.append(os.path.join(spark_files_dir, filename)) + add_path(os.path.join(spark_files_dir, filename)) # fetch names and values of broadcast variables num_broadcast_variables = read_int(infile) diff --git a/python/test_support/SimpleHTTPServer.py b/python/test_support/SimpleHTTPServer.py new file mode 100644 index 0000000000..eddbd588e0 --- /dev/null +++ b/python/test_support/SimpleHTTPServer.py @@ -0,0 +1,22 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Used to test override standard SimpleHTTPServer module. +""" + +__name__ = "My Server" |