aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--core/src/main/scala/org/apache/spark/SparkContext.scala5
-rw-r--r--core/src/test/scala/org/apache/spark/FileSuite.scala4
-rw-r--r--core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala19
3 files changed, 22 insertions, 6 deletions
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index a1003b7925..8f74607278 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1262,7 +1262,10 @@ object SparkContext extends Logging {
master match {
case "local" =>
val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true)
- val backend = new LocalBackend(scheduler, 1)
+ // Use user specified in config, up to all available cores
+ val realCores = Runtime.getRuntime.availableProcessors()
+ val toUseCores = math.min(sc.conf.getInt("spark.cores.max", realCores), realCores)
+ val backend = new LocalBackend(scheduler, toUseCores)
scheduler.initialize(backend)
scheduler
diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala
index 01af940771..b4a5881cd9 100644
--- a/core/src/test/scala/org/apache/spark/FileSuite.scala
+++ b/core/src/test/scala/org/apache/spark/FileSuite.scala
@@ -34,7 +34,7 @@ import org.apache.spark.SparkContext._
class FileSuite extends FunSuite with LocalSparkContext {
test("text files") {
- sc = new SparkContext("local", "test")
+ sc = new SparkContext("local[1]", "test")
val tempDir = Files.createTempDir()
val outputDir = new File(tempDir, "output").getAbsolutePath
val nums = sc.makeRDD(1 to 4)
@@ -176,7 +176,7 @@ class FileSuite extends FunSuite with LocalSparkContext {
test("write SequenceFile using new Hadoop API") {
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat
- sc = new SparkContext("local", "test")
+ sc = new SparkContext("local[1]", "test")
val tempDir = Files.createTempDir()
val outputDir = new File(tempDir, "output").getAbsolutePath
val nums = sc.makeRDD(1 to 3).map(x => (new IntWritable(x), new Text("a" * x)))
diff --git a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala
index b543471a5d..9dd42be1d7 100644
--- a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala
@@ -27,10 +27,10 @@ import org.apache.spark.scheduler.local.LocalBackend
class SparkContextSchedulerCreationSuite
extends FunSuite with PrivateMethodTester with LocalSparkContext with Logging {
- def createTaskScheduler(master: String): TaskSchedulerImpl = {
+ def createTaskScheduler(master: String, conf: SparkConf = new SparkConf()): TaskSchedulerImpl = {
// Create local SparkContext to setup a SparkEnv. We don't actually want to start() the
// real schedulers, so we don't want to create a full SparkContext with the desired scheduler.
- sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test", conf)
val createTaskSchedulerMethod = PrivateMethod[TaskScheduler]('createTaskScheduler)
val sched = SparkContext invokePrivate createTaskSchedulerMethod(sc, master)
sched.asInstanceOf[TaskSchedulerImpl]
@@ -44,13 +44,26 @@ class SparkContextSchedulerCreationSuite
}
test("local") {
- val sched = createTaskScheduler("local")
+ var conf = new SparkConf()
+ conf.set("spark.cores.max", "1")
+ val sched = createTaskScheduler("local", conf)
sched.backend match {
case s: LocalBackend => assert(s.totalCores === 1)
case _ => fail()
}
}
+ test("local-cores-exceed") {
+ val cores = Runtime.getRuntime.availableProcessors() + 1
+ var conf = new SparkConf()
+ conf.set("spark.cores.max", cores.toString)
+ val sched = createTaskScheduler("local", conf)
+ sched.backend match {
+ case s: LocalBackend => assert(s.totalCores === Runtime.getRuntime.availableProcessors())
+ case _ => fail()
+ }
+ }
+
test("local-n") {
val sched = createTaskScheduler("local[5]")
assert(sched.maxTaskFailures === 1)