author     Denny <dennybritz@gmail.com>  2012-08-30 11:01:43 -0700
committer  Denny <dennybritz@gmail.com>  2012-09-10 12:48:59 -0700
commit     f275fb07da33cfa38fc02ed121a52caef20f61d0 (patch)
tree       47963ae4d93a19f7082f4d854c1c4fa4a960f453
parent     a13780670d8810a9fb52d8cc4e42d3c5155a8d1d (diff)
General FileServer
A general file server for both JARs and regular files.
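
For context, a minimal usage sketch of the API this patch introduces (the master string, app name, and file path are illustrative, not taken from the patch): the driver registers a local file with addFile, and each task can then read it by its plain name from its working directory.

import spark.SparkContext

object AddFileExample {
  def main(args: Array[String]) {
    val sc = new SparkContext("local[2]", "addFile example")
    // The file is copied into the driver's HttpFileServer directory and its URL
    // is recorded, together with a timestamp, in sc.files.
    sc.addFile("/tmp/lookup.txt")
    val sum = sc.parallelize(1 to 4).map { _ =>
      // Before a task runs, downloadFileDependencies fetches the file into the
      // task's working directory, so it is readable by its plain name.
      scala.io.Source.fromFile("lookup.txt").getLines().next().toInt
    }.reduce(_ + _)
    println(sum)
    sc.stop()
  }
}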
-rw-r--r--  core/src/main/scala/spark/HttpFileServer.scala                       31
-rw-r--r--  core/src/main/scala/spark/SparkContext.scala                         35
-rw-r--r--  core/src/main/scala/spark/SparkEnv.scala                             15
-rw-r--r--  core/src/main/scala/spark/Utils.scala                                50
-rw-r--r--  core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala         32
-rw-r--r--  core/src/main/scala/spark/executor/Executor.scala                    15
-rw-r--r--  core/src/main/scala/spark/scheduler/Task.scala                       22
-rw-r--r--  core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala   23
-rw-r--r--  core/src/main/scala/spark/scheduler/local/LocalScheduler.scala        9
-rw-r--r--  core/src/test/scala/spark/FileServerSuite.scala                      43
10 files changed, 204 insertions, 71 deletions
diff --git a/core/src/main/scala/spark/HttpFileServer.scala b/core/src/main/scala/spark/HttpFileServer.scala
new file mode 100644
index 0000000000..3659de02c7
--- /dev/null
+++ b/core/src/main/scala/spark/HttpFileServer.scala
@@ -0,0 +1,31 @@
+package spark
+
+import java.io.{File, PrintWriter}
+import java.net.URL
+import scala.collection.mutable.HashMap
+import org.apache.hadoop.fs.FileUtil
+
+class HttpFileServer extends Logging {
+
+ var fileDir : File = null
+ var httpServer : HttpServer = null
+ var serverUri : String = null
+
+ def initialize() {
+ fileDir = Utils.createTempDir()
+ logInfo("HTTP File server directory is " + fileDir)
+ httpServer = new HttpServer(fileDir)
+ httpServer.start()
+ serverUri = httpServer.uri
+ }
+
+ def addFile(file: File) : String = {
+ Utils.copyFile(file, new File(fileDir, file.getName))
+ return serverUri + "/" + file.getName
+ }
+
+ def stop() {
+ httpServer.stop()
+ }
+
+}
\ No newline at end of file
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index 5d0f2950d6..dee7cd4925 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -2,11 +2,12 @@ package spark
import java.io._
import java.util.concurrent.atomic.AtomicInteger
+import java.net.URI
import akka.actor.Actor
import akka.actor.Actor._
-import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.{ArrayBuffer, HashMap}
import org.apache.hadoop.fs.Path
import org.apache.hadoop.conf.Configuration
@@ -76,7 +77,10 @@ class SparkContext(
true,
isLocal)
SparkEnv.set(env)
-
+
+ // Used to store a URL for each static file together with the file's local timestamp
+ val files = HashMap[String, Long]()
+
// Create and start the scheduler
private var taskScheduler: TaskScheduler = {
// Regular expression used for local[N] master format
@@ -90,13 +94,13 @@ class SparkContext(
master match {
case "local" =>
- new LocalScheduler(1, 0)
+ new LocalScheduler(1, 0, this)
case LOCAL_N_REGEX(threads) =>
- new LocalScheduler(threads.toInt, 0)
+ new LocalScheduler(threads.toInt, 0, this)
case LOCAL_N_FAILURES_REGEX(threads, maxFailures) =>
- new LocalScheduler(threads.toInt, maxFailures.toInt)
+ new LocalScheduler(threads.toInt, maxFailures.toInt, this)
case SPARK_REGEX(sparkUrl) =>
val scheduler = new ClusterScheduler(this)
@@ -131,7 +135,7 @@ class SparkContext(
taskScheduler.start()
private var dagScheduler = new DAGScheduler(taskScheduler)
-
+
// Methods for creating RDDs
def parallelize[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism ): RDD[T] = {
@@ -310,7 +314,24 @@ class SparkContext(
// Keep around a weak hash map of values to Cached versions?
def broadcast[T](value: T) = SparkEnv.get.broadcastManager.newBroadcast[T] (value, isLocal)
-
+
+ // Adds a file dependency to all Tasks executed in the future.
+ def addFile(path: String) : String = {
+ val uri = new URI(path)
+ uri.getScheme match {
+ // A local file
+ case null | "file" =>
+ val file = new File(uri.getPath)
+ val url = env.httpFileServer.addFile(file)
+ files(url) = System.currentTimeMillis
+ logInfo("Added file " + path + " at " + url + " with timestamp " + files(url))
+ return url
+ case _ =>
+ files(path) = System.currentTimeMillis
+ return path
+ }
+ }
+
// Stop the SparkContext
def stop() {
dagScheduler.stop()
diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala
index add8fcec51..a95d1bc8ea 100644
--- a/core/src/main/scala/spark/SparkEnv.scala
+++ b/core/src/main/scala/spark/SparkEnv.scala
@@ -19,15 +19,17 @@ class SparkEnv (
val shuffleManager: ShuffleManager,
val broadcastManager: BroadcastManager,
val blockManager: BlockManager,
- val connectionManager: ConnectionManager
+ val connectionManager: ConnectionManager,
+ val httpFileServer: HttpFileServer
) {
/** No-parameter constructor for unit tests. */
def this() = {
- this(null, null, new JavaSerializer, new JavaSerializer, null, null, null, null, null, null, null)
+ this(null, null, new JavaSerializer, new JavaSerializer, null, null, null, null, null, null, null, null)
}
def stop() {
+ httpFileServer.stop()
mapOutputTracker.stop()
cacheTracker.stop()
shuffleFetcher.stop()
@@ -95,7 +97,11 @@ object SparkEnv {
System.getProperty("spark.shuffle.fetcher", "spark.BlockStoreShuffleFetcher")
val shuffleFetcher =
Class.forName(shuffleFetcherClass).newInstance().asInstanceOf[ShuffleFetcher]
-
+
+ val httpFileServer = new HttpFileServer()
+ httpFileServer.initialize()
+ System.setProperty("spark.fileserver.uri", httpFileServer.serverUri)
+
/*
if (System.getProperty("spark.stream.distributed", "false") == "true") {
val blockManagerClass = classOf[spark.storage.BlockManager].asInstanceOf[Class[_]]
@@ -126,6 +132,7 @@ object SparkEnv {
shuffleManager,
broadcastManager,
blockManager,
- connectionManager)
+ connectionManager,
+ httpFileServer)
}
}
diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala
index 5eda1011f9..eb0a4c99bb 100644
--- a/core/src/main/scala/spark/Utils.scala
+++ b/core/src/main/scala/spark/Utils.scala
@@ -1,18 +1,19 @@
package spark
import java.io._
-import java.net.InetAddress
+import java.net.{InetAddress, URL, URI}
+import java.util.{Locale, UUID}
import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
-
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{Path, FileSystem}
import scala.collection.mutable.ArrayBuffer
import scala.util.Random
-import java.util.{Locale, UUID}
import scala.io.Source
/**
* Various utility methods used by Spark.
*/
-object Utils {
+object Utils extends Logging {
/** Serialize an object using Java serialization */
def serialize[T](o: T): Array[Byte] = {
val bos = new ByteArrayOutputStream()
@@ -115,6 +116,47 @@ object Utils {
val out = new FileOutputStream(dest)
copyStream(in, out, true)
}
+
+
+
+ /* Download a file from a given URL to the local filesystem */
+ def downloadFile(url: URL, localPath: String) {
+ val in = url.openStream()
+ val out = new FileOutputStream(localPath)
+ Utils.copyStream(in, out, true)
+ }
+
+ /**
+ * Download a file requested by the executor. Supports fetching the file in a variety of ways,
+ * including HTTP, HDFS and files on a standard filesystem, based on the URL parameter.
+ */
+ def fetchFile(url: String, targetDir: File) {
+ val filename = url.split("/").last
+ val targetFile = new File(targetDir, filename)
+ if (url.startsWith("http://") || url.startsWith("https://") || url.startsWith("ftp://")) {
+ // Use the java.net library to fetch it
+ logInfo("Fetching " + url + " to " + targetFile)
+ val in = new URL(url).openStream()
+ val out = new FileOutputStream(targetFile)
+ Utils.copyStream(in, out, true)
+ } else {
+ // Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others
+ val uri = new URI(url)
+ val conf = new Configuration()
+ val fs = FileSystem.get(uri, conf)
+ val in = fs.open(new Path(uri))
+ val out = new FileOutputStream(targetFile)
+ Utils.copyStream(in, out, true)
+ }
+ // Decompress the file if it's a .tar or .tar.gz
+ if (filename.endsWith(".tar.gz") || filename.endsWith(".tgz")) {
+ logInfo("Untarring " + filename)
+ Utils.execute(Seq("tar", "-xzf", filename), targetDir)
+ } else if (filename.endsWith(".tar")) {
+ logInfo("Untarring " + filename)
+ Utils.execute(Seq("tar", "-xf", filename), targetDir)
+ }
+ }
/**
* Shuffle the elements of a collection into a random order, returning the
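
As a reference point, a hedged sketch of calling the relocated helper directly; the URL and archive name are illustrative, and in the patch fetchFile is driven by Task.downloadFileDependencies rather than called by hand.

import spark.Utils

// Fetch an HTTP-served archive into a scratch directory; because the name ends
// in .tgz, fetchFile also untars it into the same directory.
val scratch = Utils.createTempDir()
Utils.fetchFile("http://example.com/deps/data.tgz", scratch)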
diff --git a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
index 1740a42a7e..7043361020 100644
--- a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
+++ b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
@@ -65,38 +65,6 @@ class ExecutorRunner(
}
}
- /**
- * Download a file requested by the executor. Supports fetching the file in a variety of ways,
- * including HTTP, HDFS and files on a standard filesystem, based on the URL parameter.
- */
- def fetchFile(url: String, targetDir: File) {
- val filename = url.split("/").last
- val targetFile = new File(targetDir, filename)
- if (url.startsWith("http://") || url.startsWith("https://") || url.startsWith("ftp://")) {
- // Use the java.net library to fetch it
- logInfo("Fetching " + url + " to " + targetFile)
- val in = new URL(url).openStream()
- val out = new FileOutputStream(targetFile)
- Utils.copyStream(in, out, true)
- } else {
- // Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others
- val uri = new URI(url)
- val conf = new Configuration()
- val fs = FileSystem.get(uri, conf)
- val in = fs.open(new Path(uri))
- val out = new FileOutputStream(targetFile)
- Utils.copyStream(in, out, true)
- }
- // Decompress the file if it's a .tar or .tar.gz
- if (filename.endsWith(".tar.gz") || filename.endsWith(".tgz")) {
- logInfo("Untarring " + filename)
- Utils.execute(Seq("tar", "-xzf", filename), targetDir)
- } else if (filename.endsWith(".tar")) {
- logInfo("Untarring " + filename)
- Utils.execute(Seq("tar", "-xf", filename), targetDir)
- }
- }
-
/** Replace variables such as {{SLAVEID}} and {{CORES}} in a command argument passed to us */
def substituteVariables(argument: String): String = argument match {
case "{{SLAVEID}}" => workerId
diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala
index dba209ac27..ce3aa49726 100644
--- a/core/src/main/scala/spark/executor/Executor.scala
+++ b/core/src/main/scala/spark/executor/Executor.scala
@@ -4,7 +4,9 @@ import java.io.{File, FileOutputStream}
import java.net.{URL, URLClassLoader}
import java.util.concurrent._
-import scala.collection.mutable.ArrayBuffer
+import org.apache.hadoop.fs.FileUtil
+
+import scala.collection.mutable.{ArrayBuffer, HashMap}
import spark.broadcast._
import spark.scheduler._
@@ -18,6 +20,8 @@ class Executor extends Logging {
var classLoader: ClassLoader = null
var threadPool: ExecutorService = null
var env: SparkEnv = null
+
+ val fileSet: HashMap[String, Long] = new HashMap[String, Long]()
val EMPTY_BYTE_BUFFER = ByteBuffer.wrap(new Array[Byte](0))
@@ -63,6 +67,7 @@ class Executor extends Logging {
Thread.currentThread.setContextClassLoader(classLoader)
Accumulators.clear()
val task = ser.deserialize[Task[Any]](serializedTask, classLoader)
+ task.downloadFileDependencies(fileSet)
logInfo("Its generation is " + task.generation)
env.mapOutputTracker.updateGeneration(task.generation)
val value = task.run(taskId.toInt)
@@ -108,7 +113,7 @@ class Executor extends Logging {
for (uri <- uris.split(",").filter(_.size > 0)) {
val url = new URL(uri)
val filename = url.getPath.split("/").last
- downloadFile(url, filename)
+ Utils.downloadFile(url, filename)
localFiles += filename
}
if (localFiles.size > 0) {
@@ -136,10 +141,4 @@ class Executor extends Logging {
return loader
}
- // Download a file from a given URL to the local filesystem
- private def downloadFile(url: URL, localPath: String) {
- val in = url.openStream()
- val out = new FileOutputStream(localPath)
- Utils.copyStream(in, out, true)
- }
}
diff --git a/core/src/main/scala/spark/scheduler/Task.scala b/core/src/main/scala/spark/scheduler/Task.scala
index f84d8d9c4f..faf042ad02 100644
--- a/core/src/main/scala/spark/scheduler/Task.scala
+++ b/core/src/main/scala/spark/scheduler/Task.scala
@@ -1,5 +1,10 @@
package spark.scheduler
+import scala.collection.mutable.HashMap
+import spark.HttpFileServer
+import spark.Utils
+import java.io.File
+
/**
* A task to execute on a worker node.
*/
@@ -8,4 +13,21 @@ abstract class Task[T](val stageId: Int) extends Serializable {
def preferredLocations: Seq[String] = Nil
var generation: Long = -1 // Map output tracker generation. Will be set by TaskScheduler.
+
+ // Stores file dependencies for this task.
+ var fileSet : HashMap[String, Long] = new HashMap[String, Long]()
+
+ // Downloads all file dependencies from the Master file server
+ def downloadFileDependencies(currentFileSet : HashMap[String, Long]) {
+ // Find files that either don't exist or have an earlier timestamp
+ val missingFiles = fileSet.filter { case(k,v) =>
+ !currentFileSet.isDefinedAt(k) || currentFileSet(k) <= v
+ }
+ // Fetch each missing file
+ missingFiles.foreach { case (k,v) =>
+ Utils.fetchFile(k, new File(System.getProperty("user.dir")))
+ currentFileSet(k) = v
+ }
+ }
+
}
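
A brief note on the non-local branch, with a hedged sketch (the HDFS URL is illustrative): for any scheme other than a local file, SparkContext.addFile records the path unchanged, and downloadFileDependencies above later pulls it on the executor through Utils.fetchFile, which falls back to Hadoop's FileSystem API for hdfs:// and similar schemes.

// Driver side: the path is stored as-is in sc.files with a timestamp.
sc.addFile("hdfs://namenode:9000/data/lookup.txt")
// Executor side, roughly what downloadFileDependencies does for a missing entry:
// Utils.fetchFile("hdfs://namenode:9000/data/lookup.txt",
//                 new File(System.getProperty("user.dir")))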
diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
index 5b59479682..a9ab82040c 100644
--- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
@@ -88,6 +88,7 @@ class ClusterScheduler(sc: SparkContext)
def submitTasks(taskSet: TaskSet) {
val tasks = taskSet.tasks
+ tasks.foreach { task => task.fileSet ++= sc.files }
logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
this.synchronized {
val manager = new TaskSetManager(this, taskSet)
@@ -235,30 +236,24 @@ class ClusterScheduler(sc: SparkContext)
}
override def defaultParallelism() = backend.defaultParallelism()
-
- // Create a server for all the JARs added by the user to SparkContext.
- // We first copy the JARs to a temp directory for easier server setup.
+
+ // Copies all the JARs added by the user to the SparkContext
+ // to the fileserver directory.
private def createJarServer() {
- val jarDir = Utils.createTempDir()
- logInfo("Temp directory for JARs: " + jarDir)
+ val fileServerDir = SparkEnv.get.httpFileServer.fileDir
+ val fileServerUri = SparkEnv.get.httpFileServer.serverUri
val filenames = ArrayBuffer[String]()
- // Copy each JAR to a unique filename in the jarDir
for ((path, index) <- sc.jars.zipWithIndex) {
val file = new File(path)
if (file.exists) {
val filename = index + "_" + file.getName
- Utils.copyFile(file, new File(jarDir, filename))
+ Utils.copyFile(file, new File(fileServerDir, filename))
filenames += filename
}
}
- // Create the server
- jarServer = new HttpServer(jarDir)
- jarServer.start()
- // Build up the jar URI list
- val serverUri = jarServer.uri
- jarUris = filenames.map(f => serverUri + "/" + f).mkString(",")
+ jarUris = filenames.map(f => fileServerUri + "/" + f).mkString(",")
System.setProperty("spark.jar.uris", jarUris)
- logInfo("JAR server started at " + serverUri)
+ logInfo("JARs available at " + jarUris)
}
// Check for speculatable tasks in all our active jobs.
diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
index eb47988f0c..4bd9d13637 100644
--- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
@@ -2,6 +2,7 @@ package spark.scheduler.local
import java.util.concurrent.Executors
import java.util.concurrent.atomic.AtomicInteger
+import scala.collection.mutable.HashMap
import spark._
import spark.scheduler._
@@ -11,12 +12,13 @@ import spark.scheduler._
* the scheduler also allows each task to fail up to maxFailures times, which is useful for
* testing fault recovery.
*/
-class LocalScheduler(threads: Int, maxFailures: Int) extends TaskScheduler with Logging {
+class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkContext) extends TaskScheduler with Logging {
var attemptId = new AtomicInteger(0)
var threadPool = Executors.newFixedThreadPool(threads, DaemonThreadFactory)
val env = SparkEnv.get
var listener: TaskSchedulerListener = null
-
+ val fileSet: HashMap[String, Long] = new HashMap[String, Long]()
+
// TODO: Need to take into account stage priority in scheduling
override def start() {}
@@ -30,6 +32,7 @@ class LocalScheduler(threads: Int, maxFailures: Int) extends TaskScheduler with
val failCount = new Array[Int](tasks.size)
def submitTask(task: Task[_], idInJob: Int) {
+ task.fileSet ++= sc.files
val myAttemptId = attemptId.getAndIncrement()
threadPool.submit(new Runnable {
def run() {
@@ -42,6 +45,7 @@ class LocalScheduler(threads: Int, maxFailures: Int) extends TaskScheduler with
logInfo("Running task " + idInJob)
// Set the Spark execution environment for the worker thread
SparkEnv.set(env)
+ task.downloadFileDependencies(fileSet)
try {
// Serialize and deserialize the task so that accumulators are changed to thread-local ones;
// this adds a bit of unnecessary overhead but matches how the Mesos Executor works.
@@ -81,6 +85,7 @@ class LocalScheduler(threads: Int, maxFailures: Int) extends TaskScheduler with
}
}
+
override def stop() {
threadPool.shutdownNow()
}
diff --git a/core/src/test/scala/spark/FileServerSuite.scala b/core/src/test/scala/spark/FileServerSuite.scala
new file mode 100644
index 0000000000..883149feca
--- /dev/null
+++ b/core/src/test/scala/spark/FileServerSuite.scala
@@ -0,0 +1,43 @@
+package spark
+
+import org.scalatest.FunSuite
+import org.scalatest.BeforeAndAfter
+import java.io.{File, PrintWriter}
+
+class FileServerSuite extends FunSuite with BeforeAndAfter {
+
+ var sc: SparkContext = _
+
+ before {
+ // Create a sample text file
+ val pw = new PrintWriter(System.getProperty("java.io.tmpdir") + "FileServerSuite.txt")
+ pw.println("100")
+ pw.close()
+ }
+
+ after {
+ if (sc != null) {
+ sc.stop()
+ sc = null
+ }
+ // Clean up downloaded file
+ val tmpFile = new File("FileServerSuite.txt")
+ if (tmpFile.exists) {
+ tmpFile.delete()
+ }
+ }
+
+ test("Distributing files") {
+ sc = new SparkContext("local[4]", "test")
+ sc.addFile(System.getProperty("java.io.tmpdir") + "FileServerSuite.txt")
+ val testRdd = sc.parallelize(List(1,2,3,4))
+ val result = testRdd.map { x =>
+ val in = new java.io.BufferedReader(new java.io.FileReader("FileServerSuite.txt"))
+ val fileVal = in.readLine().toInt
+ in.close()
+ fileVal
+ }.reduce(_ + _)
+ assert(result == 400)
+ }
+
+}
\ No newline at end of file