-rw-r--r--  core/src/main/scala/spark/SparkFiles.java |  8
-rw-r--r--  python/pyspark/context.py                 | 27
-rw-r--r--  python/pyspark/files.py                   | 20
-rw-r--r--  python/pyspark/tests.py                   | 23
-rw-r--r--  python/pyspark/worker.py                  |  1
-rwxr-xr-x  python/test_support/hello.txt             |  1
6 files changed, 67 insertions, 13 deletions
diff --git a/core/src/main/scala/spark/SparkFiles.java b/core/src/main/scala/spark/SparkFiles.java
index b59d8ce93f..566aec622c 100644
--- a/core/src/main/scala/spark/SparkFiles.java
+++ b/core/src/main/scala/spark/SparkFiles.java
@@ -3,23 +3,23 @@ package spark;
import java.io.File;
/**
- * Resolves paths to files added through `addFile().
+ * Resolves paths to files added through `SparkContext.addFile()`.
*/
public class SparkFiles {
private SparkFiles() {}
/**
- * Get the absolute path of a file added through `addFile()`.
+ * Get the absolute path of a file added through `SparkContext.addFile()`.
*/
public static String get(String filename) {
return new File(getRootDirectory(), filename).getAbsolutePath();
}
/**
- * Get the root directory that contains files added through `addFile()`.
+ * Get the root directory that contains files added through `SparkContext.addFile()`.
*/
public static String getRootDirectory() {
return SparkEnv.get().sparkFilesDir();
}
-}
\ No newline at end of file
+}
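Together with the Python changes below, this gives tasks a stable way to locate files shipped with addFile(). A minimal, hypothetical PySpark sketch of that flow (the file name, master URL, and job name are illustrative, not part of this diff):

    from pyspark.context import SparkContext
    from pyspark.files import SparkFiles

    sc = SparkContext("local", "SparkFilesExample")
    sc.addFile("hello.txt")                    # ship a local file to every node

    def read_first_line(_):
        # On a worker, SparkFiles.get() resolves to the shipped copy of the file.
        with open(SparkFiles.get("hello.txt")) as f:
            return f.readline()

    print(sc.parallelize(range(1)).map(read_first_line).first())
    sc.stop()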
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index b8d7dc05af..3e33776af0 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -1,12 +1,15 @@
import os
import atexit
import shutil
+import sys
import tempfile
+from threading import Lock
from tempfile import NamedTemporaryFile
from pyspark import accumulators
from pyspark.accumulators import Accumulator
from pyspark.broadcast import Broadcast
+from pyspark.files import SparkFiles
from pyspark.java_gateway import launch_gateway
from pyspark.serializers import dump_pickle, write_with_length, batched
from pyspark.rdd import RDD
@@ -27,6 +30,8 @@ class SparkContext(object):
_writeIteratorToPickleFile = jvm.PythonRDD.writeIteratorToPickleFile
_takePartition = jvm.PythonRDD.takePartition
_next_accum_id = 0
+ _active_spark_context = None
+ _lock = Lock()
def __init__(self, master, jobName, sparkHome=None, pyFiles=None,
environment=None, batchSize=1024):
@@ -46,6 +51,11 @@ class SparkContext(object):
Java object. Set 1 to disable batching or -1 to use an
unlimited batch size.
"""
+ with SparkContext._lock:
+ if SparkContext._active_spark_context:
+ raise ValueError("Cannot run multiple SparkContexts at once")
+ else:
+ SparkContext._active_spark_context = self
self.master = master
self.jobName = jobName
self.sparkHome = sparkHome or None # None becomes null in Py4J
@@ -75,6 +85,8 @@ class SparkContext(object):
# Deploy any code dependencies specified in the constructor
for path in (pyFiles or []):
self.addPyFile(path)
+ SparkFiles._sc = self
+ sys.path.append(SparkFiles.getRootDirectory())
@property
def defaultParallelism(self):
@@ -85,17 +97,20 @@ class SparkContext(object):
return self._jsc.sc().defaultParallelism()
def __del__(self):
- if self._jsc:
- self._jsc.stop()
- if self._accumulatorServer:
- self._accumulatorServer.shutdown()
+ self.stop()
def stop(self):
"""
Shut down the SparkContext.
"""
- self._jsc.stop()
- self._jsc = None
+ if self._jsc:
+ self._jsc.stop()
+ self._jsc = None
+ if self._accumulatorServer:
+ self._accumulatorServer.shutdown()
+ self._accumulatorServer = None
+ with SparkContext._lock:
+ SparkContext._active_spark_context = None
def parallelize(self, c, numSlices=None):
"""
diff --git a/python/pyspark/files.py b/python/pyspark/files.py
index de1334f046..98f6a399cc 100644
--- a/python/pyspark/files.py
+++ b/python/pyspark/files.py
@@ -4,13 +4,15 @@ import os
class SparkFiles(object):
"""
Resolves paths to files added through
- L{addFile()<pyspark.context.SparkContext.addFile>}.
+ L{SparkContext.addFile()<pyspark.context.SparkContext.addFile>}.
SparkFiles contains only classmethods; users should not create SparkFiles
instances.
"""
_root_directory = None
+ _is_running_on_worker = False
+ _sc = None
def __init__(self):
raise NotImplementedError("Do not construct SparkFiles objects")
@@ -18,7 +20,19 @@ class SparkFiles(object):
@classmethod
def get(cls, filename):
"""
- Get the absolute path of a file added through C{addFile()}.
+ Get the absolute path of a file added through C{SparkContext.addFile()}.
"""
- path = os.path.join(SparkFiles._root_directory, filename)
+ path = os.path.join(SparkFiles.getRootDirectory(), filename)
return os.path.abspath(path)
+
+ @classmethod
+ def getRootDirectory(cls):
+ """
+ Get the root directory that contains files added through
+ C{SparkContext.addFile()}.
+ """
+ if cls._is_running_on_worker:
+ return cls._root_directory
+ else:
+ # This will have to change if we support multiple SparkContexts:
+ return cls._sc.jvm.spark.SparkFiles.getRootDirectory()
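getRootDirectory() now picks its answer based on where it runs: on a worker it returns the directory that worker.py records before any task executes, while on the driver it delegates through the Py4J gateway to the JVM-side spark.SparkFiles. A hedged illustration of the worker-side path only (the directory value is made up, and worker.py normally sets these fields, not user code):

    from pyspark.files import SparkFiles

    # Simulate what worker.py does before running tasks:
    SparkFiles._is_running_on_worker = True
    SparkFiles._root_directory = "/tmp/spark-files-demo"      # illustrative value
    print(SparkFiles.getRootDirectory())                      # /tmp/spark-files-demo
    print(SparkFiles.get("hello.txt"))                        # /tmp/spark-files-demo/hello.txt

On the driver, the same call instead goes through SparkFiles._sc.jvm.spark.SparkFiles.getRootDirectory(), which is why SparkContext.__init__ sets SparkFiles._sc.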
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 4d70ee4f12..46ab34f063 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -4,22 +4,26 @@ individual modules.
"""
import os
import shutil
+import sys
from tempfile import NamedTemporaryFile
import time
import unittest
from pyspark.context import SparkContext
+from pyspark.files import SparkFiles
from pyspark.java_gateway import SPARK_HOME
class PySparkTestCase(unittest.TestCase):
def setUp(self):
+ self._old_sys_path = list(sys.path)
class_name = self.__class__.__name__
self.sc = SparkContext('local[4]', class_name , batchSize=2)
def tearDown(self):
self.sc.stop()
+ sys.path = self._old_sys_path
# To avoid Akka rebinding to the same port, since it doesn't unbind
# immediately on shutdown
self.sc.jvm.System.clearProperty("spark.master.port")
@@ -84,6 +88,25 @@ class TestAddFile(PySparkTestCase):
res = self.sc.parallelize(range(2)).map(func).first()
self.assertEqual("Hello World!", res)
+ def test_add_file_locally(self):
+ path = os.path.join(SPARK_HOME, "python/test_support/hello.txt")
+ self.sc.addFile(path)
+ download_path = SparkFiles.get("hello.txt")
+ self.assertNotEqual(path, download_path)
+ with open(download_path) as test_file:
+ self.assertEquals("Hello World!\n", test_file.readline())
+
+ def test_add_py_file_locally(self):
+ # To ensure that we're actually testing addPyFile's effects, check that
+ # this fails due to `userlibrary` not being on the Python path:
+ def func():
+ from userlibrary import UserClass
+ self.assertRaises(ImportError, func)
+ path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py")
+ self.sc.addFile(path)
+ from userlibrary import UserClass
+ self.assertEqual("Hello World!", UserClass().hello())
+
if __name__ == "__main__":
unittest.main()
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 4bf643da66..d33d6dd15f 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -26,6 +26,7 @@ def main():
split_index = read_int(sys.stdin)
spark_files_dir = load_pickle(read_with_length(sys.stdin))
SparkFiles._root_directory = spark_files_dir
+ SparkFiles._is_running_on_worker = True
sys.path.append(spark_files_dir)
num_broadcast_variables = read_int(sys.stdin)
for _ in range(num_broadcast_variables):
diff --git a/python/test_support/hello.txt b/python/test_support/hello.txt
new file mode 100755
index 0000000000..980a0d5f19
--- /dev/null
+++ b/python/test_support/hello.txt
@@ -0,0 +1 @@
+Hello World!