author     Josh Rosen <joshrosen@eecs.berkeley.edu>  2013-01-21 16:42:24 -0800
committer  Josh Rosen <joshrosen@eecs.berkeley.edu>  2013-01-21 17:34:17 -0800
commit     ef711902c1f42db14c8ddd524195f0a9efb56e65 (patch)
tree       e770a7439d3983c13346cbd81aa1eeeef23e2571 /python/pyspark
parent     506077c9938cd411842fe42404aa6b74b45b23a1 (diff)
Don't download files to master's working directory.
This should avoid exceptions caused by existing files with different contents. I also removed some unused code.
Diffstat (limited to 'python/pyspark')
-rw-r--r--  python/pyspark/__init__.py |  5
-rw-r--r--  python/pyspark/context.py  | 40
-rw-r--r--  python/pyspark/files.py    | 24
-rw-r--r--  python/pyspark/worker.py   |  3
4 files changed, 67 insertions, 5 deletions
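
In short, the patch replaces an implicit working-directory convention with an explicit lookup. A rough before/after sketch of task code (the file name `test.txt` and the surrounding job are hypothetical):

from pyspark import SparkFiles

def process(iterator):
    # Before this change, tasks assumed an added file appeared in the
    # process's working directory, e.g. open("test.txt"); downloading into
    # the master's working directory could then collide with an existing
    # file that had different contents.
    # After this change, tasks ask for the per-node download location:
    with open(SparkFiles.get("test.txt")) as f:
        threshold = int(f.readline())
    return [x for x in iterator if x >= threshold]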
diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py
index 00666bc0a3..3e8bca62f0 100644
--- a/python/pyspark/__init__.py
+++ b/python/pyspark/__init__.py
@@ -11,6 +11,8 @@ Public classes:
        A broadcast variable that gets reused across tasks.
    - L{Accumulator<pyspark.accumulators.Accumulator>}
        An "add-only" shared variable that tasks can only add values to.
+    - L{SparkFiles<pyspark.files.SparkFiles>}
+        Access files shipped with jobs.
"""
import sys
import os
@@ -19,6 +21,7 @@ sys.path.insert(0, os.path.join(os.environ["SPARK_HOME"], "python/lib/py4j0.7.eg
from pyspark.context import SparkContext
from pyspark.rdd import RDD
+from pyspark.files import SparkFiles
-__all__ = ["SparkContext", "RDD"]
+__all__ = ["SparkContext", "RDD", "SparkFiles"]
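
With the new export in `__all__`, the resolver is importable straight from the top-level package. A quick check, assuming a PySpark checkout on `sys.path`:

# Quick check of the new top-level export (assumes PySpark importable).
from pyspark import SparkContext, RDD, SparkFiles

print(SparkFiles)  # the class itself; it is deliberately not instantiable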
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index dcbed37270..ec0cc7c2f9 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -1,5 +1,7 @@
import os
import atexit
+import shutil
+import tempfile
from tempfile import NamedTemporaryFile
from pyspark import accumulators
@@ -173,10 +175,26 @@ class SparkContext(object):
    def addFile(self, path):
        """
-        Add a file to be downloaded into the working directory of this Spark
-        job on every node. The C{path} passed can be either a local file,
-        a file in HDFS (or other Hadoop-supported filesystems), or an HTTP,
-        HTTPS or FTP URI.
+        Add a file to be downloaded with this Spark job on every node.
+        The C{path} passed can be either a local file, a file in HDFS
+        (or other Hadoop-supported filesystems), or an HTTP, HTTPS or
+        FTP URI.
+
+        To access the file in Spark jobs, use
+        L{SparkFiles.get(path)<pyspark.files.SparkFiles.get>} to find its
+        download location.
+
+        >>> from pyspark import SparkFiles
+        >>> path = os.path.join(tempdir, "test.txt")
+        >>> with open(path, "w") as testFile:
+        ...     testFile.write("100")
+        >>> sc.addFile(path)
+        >>> def func(iterator):
+        ...     with open(SparkFiles.get("test.txt")) as testFile:
+        ...         fileVal = int(testFile.readline())
+        ...         return [x * 100 for x in iterator]
+        >>> sc.parallelize([1, 2, 3, 4]).mapPartitions(func).collect()
+        [100, 200, 300, 400]
        """
        self._jsc.sc().addFile(path)
@@ -211,3 +229,17 @@ class SparkContext(object):
        accidental overriding of checkpoint files in the existing directory.
        """
        self._jsc.sc().setCheckpointDir(dirName, useExisting)
+
+
+def _test():
+    import doctest
+    globs = globals().copy()
+    globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
+    globs['tempdir'] = tempfile.mkdtemp()
+    atexit.register(lambda: shutil.rmtree(globs['tempdir']))
+    doctest.testmod(globs=globs)
+    globs['sc'].stop()
+
+
+if __name__ == "__main__":
+    _test()
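
The `_test()` harness is what makes the `addFile` doctest self-contained: it injects a live `sc` and a throwaway `tempdir` into the doctest globals and registers cleanup with `atexit`. The same pattern, sketched standalone with Spark removed so it runs anywhere (all names illustrative):

import atexit
import doctest
import os
import shutil
import tempfile


def demo():
    """
    >>> path = os.path.join(tempdir, "hello.txt")  # 'tempdir' injected below
    >>> with open(path, "w") as f:
    ...     _ = f.write("hi")
    >>> open(path).read()
    'hi'
    """


if __name__ == "__main__":
    globs = globals().copy()
    globs["tempdir"] = tempfile.mkdtemp()              # same trick as _test()
    atexit.register(lambda: shutil.rmtree(globs["tempdir"]))
    doctest.testmod(globs=globs)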
diff --git a/python/pyspark/files.py b/python/pyspark/files.py
new file mode 100644
index 0000000000..de1334f046
--- /dev/null
+++ b/python/pyspark/files.py
@@ -0,0 +1,24 @@
+import os
+
+
+class SparkFiles(object):
+ """
+ Resolves paths to files added through
+ L{addFile()<pyspark.context.SparkContext.addFile>}.
+
+ SparkFiles contains only classmethods; users should not create SparkFiles
+ instances.
+ """
+
+ _root_directory = None
+
+ def __init__(self):
+ raise NotImplementedError("Do not construct SparkFiles objects")
+
+ @classmethod
+ def get(cls, filename):
+ """
+ Get the absolute path of a file added through C{addFile()}.
+ """
+ path = os.path.join(SparkFiles._root_directory, filename)
+ return os.path.abspath(path)
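
`get()` is a pure path join: it only works once `_root_directory` has been populated, which `worker.py` does on executors (next file). An illustration of the resolution in isolation, with the root set by hand purely for demonstration:

from pyspark.files import SparkFiles

# Simulate what worker.py does before running a task; in real jobs this is
# set from the value shipped on the worker's stdin, never by hand.
SparkFiles._root_directory = "/tmp/spark-files-demo"

print(SparkFiles.get("test.txt"))
# -> /tmp/spark-files-demo/test.txt, absolute and independent of the CWD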
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index b2b9288089..e7bdb7682b 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -8,6 +8,7 @@ from base64 import standard_b64decode
from pyspark.accumulators import _accumulatorRegistry
from pyspark.broadcast import Broadcast, _broadcastRegistry
from pyspark.cloudpickle import CloudPickler
+from pyspark.files import SparkFiles
from pyspark.serializers import write_with_length, read_with_length, write_int, \
    read_long, read_int, dump_pickle, load_pickle, read_from_pickle_file
@@ -23,6 +24,8 @@ def load_obj():
def main():
    split_index = read_int(sys.stdin)
+    spark_files_dir = load_pickle(read_with_length(sys.stdin))
+    SparkFiles._root_directory = spark_files_dir
    num_broadcast_variables = read_int(sys.stdin)
    for _ in range(num_broadcast_variables):
        bid = read_long(sys.stdin)
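
On the wire, the worker now consumes one extra length-prefixed, pickled field, the files root directory, written by the JVM launcher before the broadcast variables. A sketch of that framing using `struct` directly rather than pyspark's own serializers (helper names here are illustrative, not the real module):

import cPickle  # Python 2, as PySpark used at the time
import struct


def read_int(stream):
    # 4-byte big-endian int, mirroring what pyspark.serializers does.
    return struct.unpack("!i", stream.read(4))[0]


def read_with_length(stream):
    # A length header followed by that many payload bytes.
    return stream.read(read_int(stream))


def read_spark_files_dir(stream):
    # The new field: a pickled string naming the per-node download
    # directory, read immediately after the task's split index.
    return cPickle.loads(read_with_length(stream))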