diff options
author | Ye Xianjin <advancedxy@gmail.com> | 2014-08-01 00:34:39 -0700 |
---|---|---|
committer | Aaron Davidson <aaron@databricks.com> | 2014-08-01 00:34:39 -0700 |
commit | 284771efbef2d6b22212afd49dd62732a2cf52a8 (patch) | |
tree | 7c94184089aef254539a1783c6f1e27c95eb4ee9 /core | |
parent | f1957e11652a537efd40771f843591a4c9341014 (diff) | |
download | spark-284771efbef2d6b22212afd49dd62732a2cf52a8.tar.gz spark-284771efbef2d6b22212afd49dd62732a2cf52a8.tar.bz2 spark-284771efbef2d6b22212afd49dd62732a2cf52a8.zip |
[Spark 2557] fix LOCAL_N_REGEX in createTaskScheduler and make local-n and local-n-failures consistent
[SPARK-2557](https://issues.apache.org/jira/browse/SPARK-2557)
Author: Ye Xianjin <advancedxy@gmail.com>
Closes #1464 from advancedxy/SPARK-2557 and squashes the following commits:
d844d67 [Ye Xianjin] add local-*-n-failures, bad-local-n, bad-local-n-failures test case
3bbc668 [Ye Xianjin] fix LOCAL_N_REGEX regular expression and make local_n_failures accept * as all cores on the computer
Diffstat (limited to 'core')
-rw-r--r-- | core/src/main/scala/org/apache/spark/SparkContext.scala | 10 | ||||
-rw-r--r-- | core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala | 23 |
2 files changed, 30 insertions, 3 deletions
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index f5a0549834..0e513568b0 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1452,9 +1452,9 @@ object SparkContext extends Logging { /** Creates a task scheduler based on a given master URL. Extracted for testing. */ private def createTaskScheduler(sc: SparkContext, master: String): TaskScheduler = { // Regular expression used for local[N] and local[*] master formats - val LOCAL_N_REGEX = """local\[([0-9\*]+)\]""".r + val LOCAL_N_REGEX = """local\[([0-9]+|\*)\]""".r // Regular expression for local[N, maxRetries], used in tests with failing tasks - val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+)\s*,\s*([0-9]+)\]""".r + val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+|\*)\s*,\s*([0-9]+)\]""".r // Regular expression for simulating a Spark cluster of [N, cores, memory] locally val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r // Regular expression for connecting to Spark deploy clusters @@ -1484,8 +1484,12 @@ object SparkContext extends Logging { scheduler case LOCAL_N_FAILURES_REGEX(threads, maxFailures) => + def localCpuCount = Runtime.getRuntime.availableProcessors() + // local[*, M] means the number of cores on the computer with M failures + // local[N, M] means exactly N threads with M failures + val threadCount = if (threads == "*") localCpuCount else threads.toInt val scheduler = new TaskSchedulerImpl(sc, maxFailures.toInt, isLocal = true) - val backend = new LocalBackend(scheduler, threads.toInt) + val backend = new LocalBackend(scheduler, threadCount) scheduler.initialize(backend) scheduler diff --git a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala index 67e3be21c3..4b727e50db 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala @@ -68,6 +68,15 @@ class SparkContextSchedulerCreationSuite } } + test("local-*-n-failures") { + val sched = createTaskScheduler("local[* ,2]") + assert(sched.maxTaskFailures === 2) + sched.backend match { + case s: LocalBackend => assert(s.totalCores === Runtime.getRuntime.availableProcessors()) + case _ => fail() + } + } + test("local-n-failures") { val sched = createTaskScheduler("local[4, 2]") assert(sched.maxTaskFailures === 2) @@ -77,6 +86,20 @@ class SparkContextSchedulerCreationSuite } } + test("bad-local-n") { + val e = intercept[SparkException] { + createTaskScheduler("local[2*]") + } + assert(e.getMessage.contains("Could not parse Master URL")) + } + + test("bad-local-n-failures") { + val e = intercept[SparkException] { + createTaskScheduler("local[2*,4]") + } + assert(e.getMessage.contains("Could not parse Master URL")) + } + test("local-default-parallelism") { val defaultParallelism = System.getProperty("spark.default.parallelism") System.setProperty("spark.default.parallelism", "16") |