author     Josh Rosen <joshrosen@eecs.berkeley.edu>  2013-01-21 16:42:24 -0800
committer  Josh Rosen <joshrosen@eecs.berkeley.edu>  2013-01-21 17:34:17 -0800
commit     ef711902c1f42db14c8ddd524195f0a9efb56e65 (patch)
tree       e770a7439d3983c13346cbd81aa1eeeef23e2571
parent     506077c9938cd411842fe42404aa6b74b45b23a1 (diff)
Don't download files to master's working directory.
This should avoid exceptions caused by existing files with different contents. I also removed some unused code.
-rw-r--r--  core/src/main/scala/spark/HttpFileServer.scala  |  8
-rw-r--r--  core/src/main/scala/spark/SparkContext.scala  |  7
-rw-r--r--  core/src/main/scala/spark/SparkEnv.scala  | 20
-rw-r--r--  core/src/main/scala/spark/SparkFiles.java  | 25
-rw-r--r--  core/src/main/scala/spark/Utils.scala  | 16
-rw-r--r--  core/src/main/scala/spark/api/java/JavaSparkContext.scala  |  5
-rw-r--r--  core/src/main/scala/spark/api/python/PythonRDD.scala  |  2
-rw-r--r--  core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala  |  5
-rw-r--r--  core/src/main/scala/spark/executor/Executor.scala  |  6
-rw-r--r--  core/src/main/scala/spark/scheduler/local/LocalScheduler.scala  |  6
-rw-r--r--  core/src/test/scala/spark/FileServerSuite.scala  |  9
-rw-r--r--  python/pyspark/__init__.py  |  5
-rw-r--r--  python/pyspark/context.py  | 40
-rw-r--r--  python/pyspark/files.py  | 24
-rw-r--r--  python/pyspark/worker.py  |  3
-rwxr-xr-x  python/run-tests  |  3
16 files changed, 133 insertions, 51 deletions
diff --git a/core/src/main/scala/spark/HttpFileServer.scala b/core/src/main/scala/spark/HttpFileServer.scala
index 659d17718f..00901d95e2 100644
--- a/core/src/main/scala/spark/HttpFileServer.scala
+++ b/core/src/main/scala/spark/HttpFileServer.scala
@@ -1,9 +1,7 @@
package spark
-import java.io.{File, PrintWriter}
-import java.net.URL
-import scala.collection.mutable.HashMap
-import org.apache.hadoop.fs.FileUtil
+import java.io.{File}
+import com.google.common.io.Files
private[spark] class HttpFileServer extends Logging {
@@ -40,7 +38,7 @@ private[spark] class HttpFileServer extends Logging {
}
def addFileToDir(file: File, dir: File) : String = {
- Utils.copyFile(file, new File(dir, file.getName))
+ Files.copy(file, new File(dir, file.getName))
return dir + "/" + file.getName
}
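
For reference, the replacement call comes from Guava rather than Spark's own utilities. A minimal stand-alone sketch of the same copy, assuming Guava's com.google.common.io.Files is on the classpath (Spark already depends on it); the temp-file names here are illustrative:

import java.io.File
import com.google.common.io.Files

object GuavaCopySketch {
  def main(args: Array[String]) {
    val src = File.createTempFile("served", ".txt")   // illustrative source file
    Files.write("hello".getBytes("UTF-8"), src)       // Guava: fill the source with some bytes
    val serverDir = Files.createTempDir()             // stand-in for the HTTP file server's directory
    val dest = new File(serverDir, src.getName)
    Files.copy(src, dest)                             // the same byte-for-byte copy addFileToDir now makes
    println(dest.getAbsolutePath + " copied: " + dest.exists())
  }
}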
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index 8b6f4b3b7d..2eeca66ed6 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -439,9 +439,10 @@ class SparkContext(
def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T](value, isLocal)
/**
- * Add a file to be downloaded into the working directory of this Spark job on every node.
+ * Add a file to be downloaded with this Spark job on every node.
* The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported
- * filesystems), or an HTTP, HTTPS or FTP URI.
+ * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs,
+ * use `SparkFiles.get(path)` to find its download location.
*/
def addFile(path: String) {
val uri = new URI(path)
@@ -454,7 +455,7 @@ class SparkContext(
// Fetch the file locally in case a job is executed locally.
// Jobs that run through LocalScheduler will already fetch the required dependencies,
// but jobs run in DAGScheduler.runLocally() will not so we must fetch the files here.
- Utils.fetchFile(path, new File("."))
+ Utils.fetchFile(path, new File(SparkFiles.getRootDirectory))
logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key))
}
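
The updated doc comment is the contract: callers look files up through SparkFiles.get instead of assuming they land in the working directory. A hedged end-to-end sketch of that usage; the local path /tmp/config.txt and the "local" master are illustrative:

import spark.{SparkContext, SparkFiles}
import scala.io.Source

object AddFileSketch {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "AddFileSketch")
    sc.addFile("/tmp/config.txt")                     // shipped to every node alongside the job
    val counts = sc.parallelize(1 to 4, 2).map { _ =>
      // Resolve the per-node download location; "." is no longer guaranteed to hold the file.
      val path = SparkFiles.get("config.txt")
      Source.fromFile(path).getLines().size
    }.collect()
    println(counts.mkString(", "))
    sc.stop()
  }
}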
diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala
index 41441720a7..6b44e29f4c 100644
--- a/core/src/main/scala/spark/SparkEnv.scala
+++ b/core/src/main/scala/spark/SparkEnv.scala
@@ -28,14 +28,10 @@ class SparkEnv (
val broadcastManager: BroadcastManager,
val blockManager: BlockManager,
val connectionManager: ConnectionManager,
- val httpFileServer: HttpFileServer
+ val httpFileServer: HttpFileServer,
+ val sparkFilesDir: String
) {
- /** No-parameter constructor for unit tests. */
- def this() = {
- this(null, new JavaSerializer, new JavaSerializer, null, null, null, null, null, null, null)
- }
-
def stop() {
httpFileServer.stop()
mapOutputTracker.stop()
@@ -112,6 +108,15 @@ object SparkEnv extends Logging {
httpFileServer.initialize()
System.setProperty("spark.fileserver.uri", httpFileServer.serverUri)
+ // Set the sparkFiles directory, used when downloading dependencies. In local mode,
+ // this is a temporary directory; in distributed mode, this is the executor's current working
+ // directory.
+ val sparkFilesDir: String = if (isMaster) {
+ Utils.createTempDir().getAbsolutePath
+ } else {
+ "."
+ }
+
// Warn about deprecated spark.cache.class property
if (System.getProperty("spark.cache.class") != null) {
logWarning("The spark.cache.class property is no longer being used! Specify storage " +
@@ -128,6 +133,7 @@ object SparkEnv extends Logging {
broadcastManager,
blockManager,
connectionManager,
- httpFileServer)
+ httpFileServer,
+ sparkFilesDir)
}
}
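
In other words, the download root is a fresh temporary directory when SparkEnv is created on the master and the working directory otherwise. A stand-alone sketch of that decision, using java.nio.file.Files.createTempDirectory as a stand-in for Spark's internal Utils.createTempDir:

import java.nio.file.Files

object SparkFilesDirSketch {
  // Mirrors the SparkEnv rule: temp dir on the master, current working directory on executors.
  def sparkFilesDir(isMaster: Boolean): String =
    if (isMaster) Files.createTempDirectory("spark-files").toFile.getAbsolutePath
    else "."

  def main(args: Array[String]) {
    println("master:   " + sparkFilesDir(isMaster = true))
    println("executor: " + sparkFilesDir(isMaster = false))
  }
}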
diff --git a/core/src/main/scala/spark/SparkFiles.java b/core/src/main/scala/spark/SparkFiles.java
new file mode 100644
index 0000000000..b59d8ce93f
--- /dev/null
+++ b/core/src/main/scala/spark/SparkFiles.java
@@ -0,0 +1,25 @@
+package spark;
+
+import java.io.File;
+
+/**
+ * Resolves paths to files added through `addFile()`.
+ */
+public class SparkFiles {
+
+ private SparkFiles() {}
+
+ /**
+ * Get the absolute path of a file added through `addFile()`.
+ */
+ public static String get(String filename) {
+ return new File(getRootDirectory(), filename).getAbsolutePath();
+ }
+
+ /**
+ * Get the root directory that contains files added through `addFile()`.
+ */
+ public static String getRootDirectory() {
+ return SparkEnv.get().sparkFilesDir();
+ }
+}
\ No newline at end of file
diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala
index 692a3f4050..827c8bd81e 100644
--- a/core/src/main/scala/spark/Utils.scala
+++ b/core/src/main/scala/spark/Utils.scala
@@ -111,20 +111,6 @@ private object Utils extends Logging {
}
}
- /** Copy a file on the local file system */
- def copyFile(source: File, dest: File) {
- val in = new FileInputStream(source)
- val out = new FileOutputStream(dest)
- copyStream(in, out, true)
- }
-
- /** Download a file from a given URL to the local filesystem */
- def downloadFile(url: URL, localPath: String) {
- val in = url.openStream()
- val out = new FileOutputStream(localPath)
- Utils.copyStream(in, out, true)
- }
-
/**
* Download a file requested by the executor. Supports fetching the file in a variety of ways,
* including HTTP, HDFS and files on a standard filesystem, based on the URL parameter.
@@ -201,7 +187,7 @@ private object Utils extends Logging {
Utils.execute(Seq("tar", "-xf", filename), targetDir)
}
// Make the file executable - That's necessary for scripts
- FileUtil.chmod(filename, "a+x")
+ FileUtil.chmod(targetFile.getAbsolutePath, "a+x")
}
/**
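
The chmod change matters because filename is a bare name: java.io.File resolves bare names against the JVM's working directory, which after this patch is no longer where fetchFile writes. A small sketch of the difference, with an illustrative download directory:

import java.io.File

object ChmodPathSketch {
  def main(args: Array[String]) {
    val targetDir = new File("/tmp/spark-downloads")  // illustrative fetch directory
    val targetFile = new File(targetDir, "job.sh")
    // A bare name resolves against the current working directory...
    println(new File("job.sh").getAbsolutePath)       // <cwd>/job.sh -- not where the file was fetched
    // ...while the absolute target path points at the fetched copy.
    println(targetFile.getAbsolutePath)               // /tmp/spark-downloads/job.sh
  }
}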
diff --git a/core/src/main/scala/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/spark/api/java/JavaSparkContext.scala
index 16c122c584..50b8970cd8 100644
--- a/core/src/main/scala/spark/api/java/JavaSparkContext.scala
+++ b/core/src/main/scala/spark/api/java/JavaSparkContext.scala
@@ -323,9 +323,10 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
def getSparkHome(): Option[String] = sc.getSparkHome()
/**
- * Add a file to be downloaded into the working directory of this Spark job on every node.
+ * Add a file to be downloaded with this Spark job on every node.
* The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported
- * filesystems), or an HTTP, HTTPS or FTP URI.
+ * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs,
+ * use `SparkFiles.get(path)` to find its download location.
*/
def addFile(path: String) {
sc.addFile(path)
diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala
index 5526406a20..f43a152ca7 100644
--- a/core/src/main/scala/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/spark/api/python/PythonRDD.scala
@@ -67,6 +67,8 @@ private[spark] class PythonRDD[T: ClassManifest](
val dOut = new DataOutputStream(proc.getOutputStream)
// Split index
dOut.writeInt(split.index)
+ // sparkFilesDir
+ PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dOut)
// Broadcast variables
dOut.writeInt(broadcastVars.length)
for (broadcast <- broadcastVars) {
diff --git a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
index beceb55ecd..0d1fe2a6b4 100644
--- a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
+++ b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
@@ -106,11 +106,6 @@ private[spark] class ExecutorRunner(
throw new IOException("Failed to create directory " + executorDir)
}
- // Download the files it depends on into it (disabled for now)
- //for (url <- jobDesc.fileUrls) {
- // fetchFile(url, executorDir)
- //}
-
// Launch the process
val command = buildCommandSeq()
val builder = new ProcessBuilder(command: _*).directory(executorDir)
diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala
index 2552958d27..70629f6003 100644
--- a/core/src/main/scala/spark/executor/Executor.scala
+++ b/core/src/main/scala/spark/executor/Executor.scala
@@ -162,16 +162,16 @@ private[spark] class Executor extends Logging {
// Fetch missing dependencies
for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
logInfo("Fetching " + name + " with timestamp " + timestamp)
- Utils.fetchFile(name, new File("."))
+ Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
currentFiles(name) = timestamp
}
for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
logInfo("Fetching " + name + " with timestamp " + timestamp)
- Utils.fetchFile(name, new File("."))
+ Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
currentJars(name) = timestamp
// Add it to our class loader
val localName = name.split("/").last
- val url = new File(".", localName).toURI.toURL
+ val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL
if (!urlClassLoader.getURLs.contains(url)) {
logInfo("Adding " + url + " to class loader")
urlClassLoader.addURL(url)
diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
index dff550036d..4451d314e6 100644
--- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
@@ -116,16 +116,16 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
// Fetch missing dependencies
for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
logInfo("Fetching " + name + " with timestamp " + timestamp)
- Utils.fetchFile(name, new File("."))
+ Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
currentFiles(name) = timestamp
}
for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
logInfo("Fetching " + name + " with timestamp " + timestamp)
- Utils.fetchFile(name, new File("."))
+ Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
currentJars(name) = timestamp
// Add it to our class loader
val localName = name.split("/").last
- val url = new File(".", localName).toURI.toURL
+ val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL
if (!classLoader.getURLs.contains(url)) {
logInfo("Adding " + url + " to class loader")
classLoader.addURL(url)
diff --git a/core/src/test/scala/spark/FileServerSuite.scala b/core/src/test/scala/spark/FileServerSuite.scala
index b4283d9604..528c6b8424 100644
--- a/core/src/test/scala/spark/FileServerSuite.scala
+++ b/core/src/test/scala/spark/FileServerSuite.scala
@@ -40,7 +40,8 @@ class FileServerSuite extends FunSuite with BeforeAndAfter {
sc.addFile(tmpFile.toString)
val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0))
val result = sc.parallelize(testData).reduceByKey {
- val in = new BufferedReader(new FileReader("FileServerSuite.txt"))
+ val path = SparkFiles.get("FileServerSuite.txt")
+ val in = new BufferedReader(new FileReader(path))
val fileVal = in.readLine().toInt
in.close()
_ * fileVal + _ * fileVal
@@ -54,7 +55,8 @@ class FileServerSuite extends FunSuite with BeforeAndAfter {
sc.addFile((new File(tmpFile.toString)).toURL.toString)
val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0))
val result = sc.parallelize(testData).reduceByKey {
- val in = new BufferedReader(new FileReader("FileServerSuite.txt"))
+ val path = SparkFiles.get("FileServerSuite.txt")
+ val in = new BufferedReader(new FileReader(path))
val fileVal = in.readLine().toInt
in.close()
_ * fileVal + _ * fileVal
@@ -83,7 +85,8 @@ class FileServerSuite extends FunSuite with BeforeAndAfter {
sc.addFile(tmpFile.toString)
val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0))
val result = sc.parallelize(testData).reduceByKey {
- val in = new BufferedReader(new FileReader("FileServerSuite.txt"))
+ val path = SparkFiles.get("FileServerSuite.txt")
+ val in = new BufferedReader(new FileReader(path))
val fileVal = in.readLine().toInt
in.close()
_ * fileVal + _ * fileVal
diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py
index 00666bc0a3..3e8bca62f0 100644
--- a/python/pyspark/__init__.py
+++ b/python/pyspark/__init__.py
@@ -11,6 +11,8 @@ Public classes:
A broadcast variable that gets reused across tasks.
- L{Accumulator<pyspark.accumulators.Accumulator>}
An "add-only" shared variable that tasks can only add values to.
+ - L{SparkFiles<pyspark.files.SparkFiles>}
+ Access files shipped with jobs.
"""
import sys
import os
@@ -19,6 +21,7 @@ sys.path.insert(0, os.path.join(os.environ["SPARK_HOME"], "python/lib/py4j0.7.eg
from pyspark.context import SparkContext
from pyspark.rdd import RDD
+from pyspark.files import SparkFiles
-__all__ = ["SparkContext", "RDD"]
+__all__ = ["SparkContext", "RDD", "SparkFiles"]
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index dcbed37270..ec0cc7c2f9 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -1,5 +1,7 @@
import os
import atexit
+import shutil
+import tempfile
from tempfile import NamedTemporaryFile
from pyspark import accumulators
@@ -173,10 +175,26 @@ class SparkContext(object):
def addFile(self, path):
"""
- Add a file to be downloaded into the working directory of this Spark
- job on every node. The C{path} passed can be either a local file,
- a file in HDFS (or other Hadoop-supported filesystems), or an HTTP,
- HTTPS or FTP URI.
+ Add a file to be downloaded with this Spark job on every node.
+ The C{path} passed can be either a local file, a file in HDFS
+ (or other Hadoop-supported filesystems), or an HTTP, HTTPS or
+ FTP URI.
+
+ To access the file in Spark jobs, use
+ L{SparkFiles.get(path)<pyspark.files.SparkFiles.get>} to find its
+ download location.
+
+ >>> from pyspark import SparkFiles
+ >>> path = os.path.join(tempdir, "test.txt")
+ >>> with open(path, "w") as testFile:
+ ... testFile.write("100")
+ >>> sc.addFile(path)
+ >>> def func(iterator):
+ ... with open(SparkFiles.get("test.txt")) as testFile:
+ ... fileVal = int(testFile.readline())
+ ... return [x * 100 for x in iterator]
+ >>> sc.parallelize([1, 2, 3, 4]).mapPartitions(func).collect()
+ [100, 200, 300, 400]
"""
self._jsc.sc().addFile(path)
@@ -211,3 +229,17 @@ class SparkContext(object):
accidental overriding of checkpoint files in the existing directory.
"""
self._jsc.sc().setCheckpointDir(dirName, useExisting)
+
+
+def _test():
+ import doctest
+ globs = globals().copy()
+ globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
+ globs['tempdir'] = tempfile.mkdtemp()
+ atexit.register(lambda: shutil.rmtree(globs['tempdir']))
+ doctest.testmod(globs=globs)
+ globs['sc'].stop()
+
+
+if __name__ == "__main__":
+ _test()
diff --git a/python/pyspark/files.py b/python/pyspark/files.py
new file mode 100644
index 0000000000..de1334f046
--- /dev/null
+++ b/python/pyspark/files.py
@@ -0,0 +1,24 @@
+import os
+
+
+class SparkFiles(object):
+ """
+ Resolves paths to files added through
+ L{addFile()<pyspark.context.SparkContext.addFile>}.
+
+ SparkFiles contains only classmethods; users should not create SparkFiles
+ instances.
+ """
+
+ _root_directory = None
+
+ def __init__(self):
+ raise NotImplementedError("Do not construct SparkFiles objects")
+
+ @classmethod
+ def get(cls, filename):
+ """
+ Get the absolute path of a file added through C{addFile()}.
+ """
+ path = os.path.join(SparkFiles._root_directory, filename)
+ return os.path.abspath(path)
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index b2b9288089..e7bdb7682b 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -8,6 +8,7 @@ from base64 import standard_b64decode
from pyspark.accumulators import _accumulatorRegistry
from pyspark.broadcast import Broadcast, _broadcastRegistry
from pyspark.cloudpickle import CloudPickler
+from pyspark.files import SparkFiles
from pyspark.serializers import write_with_length, read_with_length, write_int, \
read_long, read_int, dump_pickle, load_pickle, read_from_pickle_file
@@ -23,6 +24,8 @@ def load_obj():
def main():
split_index = read_int(sys.stdin)
+ spark_files_dir = load_pickle(read_with_length(sys.stdin))
+ SparkFiles._root_directory = spark_files_dir
num_broadcast_variables = read_int(sys.stdin)
for _ in range(num_broadcast_variables):
bid = read_long(sys.stdin)
diff --git a/python/run-tests b/python/run-tests
index ce214e98a8..a3a9ff5dcb 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -8,6 +8,9 @@ FAILED=0
$FWDIR/pyspark pyspark/rdd.py
FAILED=$(($?||$FAILED))
+$FWDIR/pyspark pyspark/context.py
+FAILED=$(($?||$FAILED))
+
$FWDIR/pyspark -m doctest pyspark/broadcast.py
FAILED=$(($?||$FAILED))