aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
Diffstat (limited to 'python')
-rw-r--r-- python/pyspark/context.py 11
-rw-r--r-- python/pyspark/tests.py 12
-rw-r--r-- python/pyspark/worker.py 11
-rw-r--r-- python/test_support/SimpleHTTPServer.py 22
4 files changed, 48 insertions, 8 deletions
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 064a24bff5..8e7b00469e 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -171,7 +171,7 @@ class SparkContext(object):
SparkFiles._sc = self
root_dir = SparkFiles.getRootDirectory()
- sys.path.append(root_dir)
+ sys.path.insert(1, root_dir)
# Deploy any code dependencies specified in the constructor
self._python_includes = list()
@@ -183,10 +183,9 @@ class SparkContext(object):
for path in self._conf.get("spark.submit.pyFiles", "").split(","):
if path != "":
(dirname, filename) = os.path.split(path)
- self._python_includes.append(filename)
- sys.path.append(path)
- if dirname not in sys.path:
- sys.path.append(dirname)
+ if filename.lower().endswith("zip") or filename.lower().endswith("egg"):
+ self._python_includes.append(filename)
+ sys.path.insert(1, os.path.join(SparkFiles.getRootDirectory(), filename))
# Create a temporary directory inside spark.local.dir:
local_dir = self._jvm.org.apache.spark.util.Utils.getLocalDir(self._jsc.sc().conf())
@@ -667,7 +666,7 @@ class SparkContext(object):
if filename.endswith('.zip') or filename.endswith('.ZIP') or filename.endswith('.egg'):
self._python_includes.append(filename)
# for tests in local mode
- sys.path.append(os.path.join(SparkFiles.getRootDirectory(), filename))
+ sys.path.insert(1, os.path.join(SparkFiles.getRootDirectory(), filename))
def setCheckpointDir(self, dirName):
"""
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 1b8afb763b..4483bf80db 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -323,6 +323,18 @@ class TestAddFile(PySparkTestCase):
from userlib import UserClass
self.assertEqual("Hello World from inside a package!", UserClass().hello())
+ def test_overwrite_system_module(self):
+ self.sc.addPyFile(os.path.join(SPARK_HOME, "python/test_support/SimpleHTTPServer.py"))
+
+ import SimpleHTTPServer
+ self.assertEqual("My Server", SimpleHTTPServer.__name__)
+
+ def func(x):
+ import SimpleHTTPServer
+ return SimpleHTTPServer.__name__
+
+ self.assertEqual(["My Server"], self.sc.parallelize(range(1)).map(func).collect())
+
class TestRDDFunctions(PySparkTestCase):
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index d6c06e2dbe..c1f6e3e4a1 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -43,6 +43,13 @@ def report_times(outfile, boot, init, finish):
write_long(1000 * finish, outfile)
+def add_path(path):
+ # worker can be reused, so do not add path multiple times
+ if path not in sys.path:
+ # overwrite system packages
+ sys.path.insert(1, path)
+
+
def main(infile, outfile):
try:
boot_time = time.time()
@@ -61,11 +68,11 @@ def main(infile, outfile):
SparkFiles._is_running_on_worker = True
# fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
- sys.path.append(spark_files_dir) # *.py files that were added will be copied here
+ add_path(spark_files_dir) # *.py files that were added will be copied here
num_python_includes = read_int(infile)
for _ in range(num_python_includes):
filename = utf8_deserializer.loads(infile)
- sys.path.append(os.path.join(spark_files_dir, filename))
+ add_path(os.path.join(spark_files_dir, filename))
# fetch names and values of broadcast variables
num_broadcast_variables = read_int(infile)
diff --git a/python/test_support/SimpleHTTPServer.py b/python/test_support/SimpleHTTPServer.py
new file mode 100644
index 0000000000..eddbd588e0
--- /dev/null
+++ b/python/test_support/SimpleHTTPServer.py
@@ -0,0 +1,22 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+Used to test overriding the standard SimpleHTTPServer module.
+"""
+
+__name__ = "My Server"