author     Matei Zaharia <matei@eecs.berkeley.edu>  2012-10-02 17:31:01 -0700
committer  Matei Zaharia <matei@eecs.berkeley.edu>  2012-10-02 17:31:01 -0700
commit     97cbd699d73130525d7bf004d7dc47233c33ed52 (patch)
tree       319bb2b8a874dd49e91382f9dd0ecfd1ab04c037 /core/src/main
parent     5fda59ab990f37a3633f88b8d4d15ce96df08266 (diff)
parent     c8ca6bc59b1b7b4798e5815a23551801061623ca (diff)
Merge branch 'dev' of github.com:mesos/spark into dev
Diffstat (limited to 'core/src/main')
-rw-r--r--  core/src/main/scala/spark/SparkContext.scala  83
1 file changed, 47 insertions(+), 36 deletions(-)
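
The main functional change in the diff below is that the local[N, maxRetries] and local-cluster[N, cores, memory] master-string regexes now tolerate whitespace around the commas, and local-cluster mode fails fast when SPARK_MEM exceeds the per-slave memory. As a rough illustration (a standalone Scala sketch, not part of the patched file; only the two patterns are copied verbatim from the diff), the relaxed regexes accept spaced-out master strings that previously failed to match:

    // Standalone sketch: exercises the two relaxed regexes from the diff below.
    object MasterRegexSketch {
      // Patterns copied verbatim from the patched SparkContext.scala
      val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+)\s*,\s*([0-9]+)\]""".r
      val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r

      def main(args: Array[String]) {
        // Both strings contain spaces after the commas; the pre-patch regexes rejected them.
        "local[4, 2]" match {
          case LOCAL_N_FAILURES_REGEX(threads, maxFailures) =>
            println("threads=" + threads + ", maxFailures=" + maxFailures)
          case _ =>
            println("no match")
        }
        "local-cluster[2, 1, 512]" match {
          case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerSlave) =>
            println("slaves=" + numSlaves + ", cores=" + coresPerSlave + ", memoryMB=" + memoryPerSlave)
          case _ =>
            println("no match")
        }
      }
    }
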
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index 79a9e8e34e..1ef1712c56 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -55,7 +55,7 @@ class SparkContext(
val sparkHome: String,
val jars: Seq[String])
extends Logging {
-
+
def this(master: String, frameworkName: String) = this(master, frameworkName, null, Nil)
// Ensure logging is initialized before we spawn any threads
@@ -78,30 +78,30 @@ class SparkContext(
true,
isLocal)
SparkEnv.set(env)
-
+
// Used to store a URL for each static file/jar together with the file's local timestamp
val addedFiles = HashMap[String, Long]()
val addedJars = HashMap[String, Long]()
-
+
// Add each JAR given through the constructor
jars.foreach { addJar(_) }
-
+
// Create and start the scheduler
private var taskScheduler: TaskScheduler = {
// Regular expression used for local[N] master format
val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r
// Regular expression for local[N, maxRetries], used in tests with failing tasks
- val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+),([0-9]+)\]""".r
+ val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+)\s*,\s*([0-9]+)\]""".r
// Regular expression for simulating a Spark cluster of [N, cores, memory] locally
- val LOCAL_CLUSTER_REGEX = """local-cluster\[([0-9]+),([0-9]+),([0-9]+)]""".r
+ val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r
// Regular expression for connecting to Spark deploy clusters
val SPARK_REGEX = """(spark://.*)""".r
-
+
master match {
- case "local" =>
+ case "local" =>
new LocalScheduler(1, 0, this)
- case LOCAL_N_REGEX(threads) =>
+ case LOCAL_N_REGEX(threads) =>
new LocalScheduler(threads.toInt, 0, this)
case LOCAL_N_FAILURES_REGEX(threads, maxFailures) =>
@@ -112,10 +112,21 @@ class SparkContext(
val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, frameworkName)
scheduler.initialize(backend)
scheduler
-
- case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerlave) =>
+
+ case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerSlave) =>
+ // Check to make sure SPARK_MEM <= memoryPerSlave. Otherwise Spark will just hang.
+ val memoryPerSlaveInt = memoryPerSlave.toInt
+ val sparkMemEnv = System.getenv("SPARK_MEM")
+ val sparkMemEnvInt = if (sparkMemEnv != null) Utils.memoryStringToMb(sparkMemEnv) else 512
+ if (sparkMemEnvInt > memoryPerSlaveInt) {
+ throw new SparkException(
+ "Slave memory (%d MB) cannot be smaller than SPARK_MEM (%d MB)".format(
+ memoryPerSlaveInt, sparkMemEnvInt))
+ }
+
val scheduler = new ClusterScheduler(this)
- val localCluster = new LocalSparkCluster(numSlaves.toInt, coresPerSlave.toInt, memoryPerlave.toInt)
+ val localCluster = new LocalSparkCluster(
+ numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt)
val sparkUrl = localCluster.start()
val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, frameworkName)
scheduler.initialize(backend)
@@ -140,13 +151,13 @@ class SparkContext(
taskScheduler.start()
private var dagScheduler = new DAGScheduler(taskScheduler)
-
+
// Methods for creating RDDs
def parallelize[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism ): RDD[T] = {
new ParallelCollection[T](this, seq, numSlices)
}
-
+
def makeRDD[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism ): RDD[T] = {
parallelize(seq, numSlices)
}
@@ -187,14 +198,14 @@ class SparkContext(
}
/**
- * Smarter version of hadoopFile() that uses class manifests to figure out the classes of keys,
+ * Smarter version of hadoopFile() that uses class manifests to figure out the classes of keys,
* values and the InputFormat so that users don't need to pass them directly.
*/
def hadoopFile[K, V, F <: InputFormat[K, V]](path: String, minSplits: Int)
(implicit km: ClassManifest[K], vm: ClassManifest[V], fm: ClassManifest[F])
: RDD[(K, V)] = {
hadoopFile(path,
- fm.erasure.asInstanceOf[Class[F]],
+ fm.erasure.asInstanceOf[Class[F]],
km.erasure.asInstanceOf[Class[K]],
vm.erasure.asInstanceOf[Class[V]],
minSplits)
@@ -215,7 +226,7 @@ class SparkContext(
new Configuration)
}
- /**
+ /**
* Get an RDD for a given Hadoop file with an arbitrary new API InputFormat
* and extra configuration options to pass to the input format.
*/
@@ -231,7 +242,7 @@ class SparkContext(
new NewHadoopRDD(this, fClass, kClass, vClass, updatedConf)
}
- /**
+ /**
* Get an RDD for a given Hadoop file with an arbitrary new API InputFormat
* and extra configuration options to pass to the input format.
*/
@@ -257,14 +268,14 @@ class SparkContext(
sequenceFile(path, keyClass, valueClass, defaultMinSplits)
/**
- * Version of sequenceFile() for types implicitly convertible to Writables through a
+ * Version of sequenceFile() for types implicitly convertible to Writables through a
* WritableConverter.
*
* WritableConverters are provided in a somewhat strange way (by an implicit function) to support
- * both subclasses of Writable and types for which we define a converter (e.g. Int to
+ * both subclasses of Writable and types for which we define a converter (e.g. Int to
* IntWritable). The most natural thing would've been to have implicit objects for the
* converters, but then we couldn't have an object for every subclass of Writable (you can't
- * have a parameterized singleton object). We use functions instead to create a new converter
+ * have a parameterized singleton object). We use functions instead to create a new converter
* for the appropriate type. In addition, we pass the converter a ClassManifest of its type to
* allow it to figure out the Writable class to use in the subclass case.
*/
@@ -289,7 +300,7 @@ class SparkContext(
* that there's very little effort required to save arbitrary objects.
*/
def objectFile[T: ClassManifest](
- path: String,
+ path: String,
minSplits: Int = defaultMinSplits
): RDD[T] = {
sequenceFile(path, classOf[NullWritable], classOf[BytesWritable], minSplits)
@@ -318,7 +329,7 @@ class SparkContext(
/**
* Create an accumulator from a "mutable collection" type.
- *
+ *
* Growable and TraversableOnce are the standard APIs that guarantee += and ++=, implemented by
* standard mutable collections. So you can use this with mutable Map, Set, etc.
*/
@@ -329,7 +340,7 @@ class SparkContext(
// Keep around a weak hash map of values to Cached versions?
def broadcast[T](value: T) = SparkEnv.get.broadcastManager.newBroadcast[T] (value, isLocal)
-
+
// Adds a file dependency to all Tasks executed in the future.
def addFile(path: String) {
val uri = new URI(path)
@@ -338,11 +349,11 @@ class SparkContext(
case _ => path
}
addedFiles(key) = System.currentTimeMillis
-
+
// Fetch the file locally in case the task is executed locally
val filename = new File(path.split("/").last)
Utils.fetchFile(path, new File("."))
-
+
logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key))
}
@@ -350,7 +361,7 @@ class SparkContext(
addedFiles.keySet.map(_.split("/").last).foreach { k => new File(k).delete() }
addedFiles.clear()
}
-
+
// Adds a jar dependency to all Tasks executed in the future.
def addJar(path: String) {
val uri = new URI(path)
@@ -366,7 +377,7 @@ class SparkContext(
addedJars.keySet.map(_.split("/").last).foreach { k => new File(k).delete() }
addedJars.clear()
}
-
+
// Stop the SparkContext
def stop() {
dagScheduler.stop()
@@ -400,7 +411,7 @@ class SparkContext(
/**
* Run a function on a given set of partitions in an RDD and return the results. This is the main
* entry point to the scheduler, by which all actions get launched. The allowLocal flag specifies
- * whether the scheduler can run the computation on the master rather than shipping it out to the
+ * whether the scheduler can run the computation on the master rather than shipping it out to the
* cluster, for short actions like first().
*/
def runJob[T, U: ClassManifest](
@@ -419,13 +430,13 @@ class SparkContext(
def runJob[T, U: ClassManifest](
rdd: RDD[T],
- func: Iterator[T] => U,
+ func: Iterator[T] => U,
partitions: Seq[Int],
allowLocal: Boolean
): Array[U] = {
runJob(rdd, (context: TaskContext, iter: Iterator[T]) => func(iter), partitions, allowLocal)
}
-
+
/**
* Run a job on all partitions in an RDD and return the results in an array.
*/
@@ -472,7 +483,7 @@ class SparkContext(
private[spark] def newShuffleId(): Int = {
nextShuffleId.getAndIncrement()
}
-
+
private var nextRddId = new AtomicInteger(0)
// Register a new RDD, returning its RDD ID
@@ -500,7 +511,7 @@ object SparkContext {
implicit def rddToPairRDDFunctions[K: ClassManifest, V: ClassManifest](rdd: RDD[(K, V)]) =
new PairRDDFunctions(rdd)
-
+
implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable: ClassManifest](
rdd: RDD[(K, V)]) =
new SequenceFileRDDFunctions(rdd)
@@ -521,7 +532,7 @@ object SparkContext {
implicit def longToLongWritable(l: Long) = new LongWritable(l)
implicit def floatToFloatWritable(f: Float) = new FloatWritable(f)
-
+
implicit def doubleToDoubleWritable(d: Double) = new DoubleWritable(d)
implicit def boolToBoolWritable (b: Boolean) = new BooleanWritable(b)
@@ -532,7 +543,7 @@ object SparkContext {
private implicit def arrayToArrayWritable[T <% Writable: ClassManifest](arr: Traversable[T]): ArrayWritable = {
def anyToWritable[U <% Writable](u: U): Writable = u
-
+
new ArrayWritable(classManifest[T].erasure.asInstanceOf[Class[Writable]],
arr.map(x => anyToWritable(x)).toArray)
}
@@ -576,7 +587,7 @@ object SparkContext {
Nil
}
}
-
+
// Find the JAR that contains the class of a particular object
def jarOfObject(obj: AnyRef): Seq[String] = jarOfClass(obj.getClass)
}