-rw-r--r--  core/pom.xml | 5
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala | 88
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala | 10
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala | 14
-rw-r--r--  core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala | 2
-rw-r--r--  core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala | 131
-rw-r--r--  core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala | 4
-rw-r--r--  core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala | 32
-rw-r--r--  pom.xml | 12
-rw-r--r--  project/SparkBuild.scala | 1
10 files changed, 264 insertions(+), 35 deletions(-)
diff --git a/core/pom.xml b/core/pom.xml
index aac0a9d11e..1c52b334d0 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -99,6 +99,11 @@
<artifactId>akka-slf4j_${scala.binary.version}</artifactId>
</dependency>
<dependency>
+ <groupId>${akka.group}</groupId>
+ <artifactId>akka-testkit_${scala.binary.version}</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
<groupId>org.scala-lang</groupId>
<artifactId>scala-library</artifactId>
</dependency>
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
index f726089faa..d13d7eff09 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
@@ -19,6 +19,7 @@ package org.apache.spark.deploy.worker
import java.io._
+import scala.collection.JavaConversions._
import scala.collection.mutable.Map
import akka.actor.ActorRef
@@ -47,6 +48,16 @@ private[spark] class DriverRunner(
@volatile var process: Option[Process] = None
@volatile var killed = false
+ // Decoupled for testing
+ private[deploy] def setClock(_clock: Clock) = clock = _clock
+ private[deploy] def setSleeper(_sleeper: Sleeper) = sleeper = _sleeper
+ private var clock = new Clock {
+ def currentTimeMillis(): Long = System.currentTimeMillis()
+ }
+ private var sleeper = new Sleeper {
+ def sleep(seconds: Int): Unit = (0 until seconds).takeWhile(f => {Thread.sleep(1000); !killed})
+ }
+
/** Starts a thread to run and manage the driver. */
def start() = {
new Thread("DriverRunner for " + driverId) {
@@ -63,10 +74,9 @@ private[spark] class DriverRunner(
env("SPARK_CLASSPATH") = env.getOrElse("SPARK_CLASSPATH", "") + s":$localJarFilename"
val newCommand = Command(driverDesc.command.mainClass,
driverDesc.command.arguments.map(substituteVariables), env)
-
val command = CommandUtils.buildCommandSeq(newCommand, driverDesc.mem,
sparkHome.getAbsolutePath)
- runCommand(command, env, driverDir, driverDesc.supervise)
+ launchDriver(command, env, driverDir, driverDesc.supervise)
}
catch {
case e: Exception => exn = Some(e)
@@ -116,7 +126,7 @@ private[spark] class DriverRunner(
val jarPath = new Path(driverDesc.jarUrl)
- val emptyConf = new Configuration() // TODO: In docs explain it needs to be full HDFS path
+ val emptyConf = new Configuration()
val jarFileSystem = jarPath.getFileSystem(emptyConf)
val destPath = new File(driverDir.getAbsolutePath, jarPath.getName)
@@ -136,51 +146,77 @@ private[spark] class DriverRunner(
localJarFilename
}
- /** Launch the supplied command. */
- private def runCommand(command: Seq[String], envVars: Map[String, String], baseDir: File,
- supervise: Boolean) {
+ private def launchDriver(command: Seq[String], envVars: Map[String, String], baseDir: File,
+ supervise: Boolean) {
+ val builder = new ProcessBuilder(command: _*).directory(baseDir)
+ envVars.map{ case(k,v) => builder.environment().put(k, v) }
+
+ def initialize(process: Process) = {
+ // Redirect stdout and stderr to files
+ val stdout = new File(baseDir, "stdout")
+ CommandUtils.redirectStream(process.getInputStream, stdout)
+
+ val stderr = new File(baseDir, "stderr")
+ val header = "Launch Command: %s\n%s\n\n".format(
+ command.mkString("\"", "\" \"", "\""), "=" * 40)
+ Files.append(header, stderr, Charsets.UTF_8)
+ CommandUtils.redirectStream(process.getErrorStream, stderr)
+ }
+ runCommandWithRetry(ProcessBuilderLike(builder), initialize, supervise)
+ }
+ private[deploy] def runCommandWithRetry(command: ProcessBuilderLike, initialize: Process => Unit,
+ supervise: Boolean) {
// Time to wait between submission retries.
var waitSeconds = 1
// A run of this many seconds resets the exponential back-off.
- val successfulRunDuration = 1
+ val successfulRunDuration = 5
var keepTrying = !killed
while (keepTrying) {
- logInfo("Launch Command: " + command.mkString("\"", "\" \"", "\""))
- val builder = new ProcessBuilder(command: _*).directory(baseDir)
- envVars.map{ case(k,v) => builder.environment().put(k, v) }
+ logInfo("Launch Command: " + command.command.mkString("\"", "\" \"", "\""))
synchronized {
if (killed) { return }
-
- process = Some(builder.start())
-
- // Redirect stdout and stderr to files
- val stdout = new File(baseDir, "stdout")
- CommandUtils.redirectStream(process.get.getInputStream, stdout)
-
- val stderr = new File(baseDir, "stderr")
- val header = "Launch Command: %s\n%s\n\n".format(
- command.mkString("\"", "\" \"", "\""), "=" * 40)
- Files.append(header, stderr, Charsets.UTF_8)
- CommandUtils.redirectStream(process.get.getErrorStream, stderr)
+ process = Some(command.start())
+ initialize(process.get)
}
- val processStart = System.currentTimeMillis()
+ val processStart = clock.currentTimeMillis()
val exitCode = process.get.waitFor()
- if (System.currentTimeMillis() - processStart > successfulRunDuration * 1000) {
+ if (clock.currentTimeMillis() - processStart > successfulRunDuration * 1000) {
waitSeconds = 1
}
if (supervise && exitCode != 0 && !killed) {
- waitSeconds = waitSeconds * 2 // exponential back-off
logInfo(s"Command exited with status $exitCode, re-launching after $waitSeconds s.")
- (0 until waitSeconds).takeWhile(f => {Thread.sleep(1000); !killed})
+ sleeper.sleep(waitSeconds)
+ waitSeconds = waitSeconds * 2 // exponential back-off
}
keepTrying = supervise && exitCode != 0 && !killed
}
}
}
+
+private[deploy] trait Clock {
+ def currentTimeMillis(): Long
+}
+
+private[deploy] trait Sleeper {
+ def sleep(seconds: Int)
+}
+
+// Needed because ProcessBuilder is a final class and cannot be mocked
+private[deploy] trait ProcessBuilderLike {
+ def start(): Process
+ def command: Seq[String]
+}
+
+private[deploy] object ProcessBuilderLike {
+ def apply(processBuilder: ProcessBuilder) = new ProcessBuilderLike {
+ def start() = processBuilder.start()
+ def command = processBuilder.command()
+ }
+}
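The Clock, Sleeper, and ProcessBuilderLike seams above exist solely so the retry loop can be driven deterministically from tests. A minimal sketch of installing a fake clock and a no-op sleeper (reusing the constructor arguments from DriverRunnerTest below; this snippet is illustrative and not part of the commit):

import java.io.File

import org.apache.spark.deploy.{Command, DriverDescription}

val command = new Command("mainClass", Seq(), Map())
val desc = new DriverDescription("jarUrl", 512, 1, true, command)
val runner = new DriverRunner("driverId", new File("workDir"), new File("sparkHome"),
  desc, null, "akka://1.2.3.4/worker/")

// A clock that advances one second per call, so "run duration" is controlled
// by the test rather than by wall-clock time.
runner.setClock(new Clock {
  var now = 0L
  def currentTimeMillis(): Long = { now += 1000; now }
})

// A sleeper that returns immediately, so exponential back-off adds no real delay.
runner.setSleeper(new Sleeper {
  def sleep(seconds: Int): Unit = ()
})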
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
index fdc9a34886..a9cb998cc2 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
@@ -98,6 +98,12 @@ private[spark] class ExecutorRunner(
case other => other
}
+ def getCommandSeq = {
+ val command = Command(appDesc.command.mainClass,
+ appDesc.command.arguments.map(substituteVariables), appDesc.command.environment)
+ CommandUtils.buildCommandSeq(command, memory, sparkHome.getAbsolutePath)
+ }
+
/**
* Download and run the executor described in our ApplicationDescription
*/
@@ -110,9 +116,7 @@ private[spark] class ExecutorRunner(
}
// Launch the process
- val fullCommand = new Command(appDesc.command.mainClass,
- appDesc.command.arguments.map(substituteVariables), appDesc.command.environment)
- val command = CommandUtils.buildCommandSeq(fullCommand, memory, sparkHome.getAbsolutePath)
+ val command = getCommandSeq
logInfo("Launch command: " + command.mkString("\"", "\" \"", "\""))
val builder = new ProcessBuilder(command: _*).directory(executorDir)
val env = builder.environment()
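Extracting getCommandSeq means the fully substituted command line can be inspected without spawning a process, which is exactly what the updated ExecutorRunnerTest below relies on. A hypothetical check, assuming an ExecutorRunner `er` and application ID `appId` set up as in that test:

// The application ID is substituted in as the last argument of the built command.
assert(er.getCommandSeq.last == appId)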
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala
index f4184bc5db..0e0d0cd626 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala
@@ -10,7 +10,8 @@ import org.apache.spark.deploy.DeployMessages.SendHeartbeat
* Actor which connects to a worker process and terminates the JVM if the connection is severed.
* Provides fate sharing between a worker and its associated child processes.
*/
-private[spark] class WorkerWatcher(workerUrl: String) extends Actor with Logging {
+private[spark] class WorkerWatcher(workerUrl: String) extends Actor
+ with Logging {
override def preStart() {
context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
@@ -19,10 +20,17 @@ private[spark] class WorkerWatcher(workerUrl: String) extends Actor with Logging
worker ! SendHeartbeat // need to send a message here to initiate connection
}
+ // Used to avoid shutting down JVM during tests
+ private[deploy] var isShutDown = false
+ private[deploy] def setTesting(testing: Boolean) = isTesting = testing
+ private var isTesting = false
+
// Lets us filter events only from the worker's actor system
private val expectedHostPort = AddressFromURIString(workerUrl).hostPort
private def isWorker(address: Address) = address.hostPort == expectedHostPort
+ def exitNonZero() = if (isTesting) isShutDown = true else System.exit(-1)
+
override def receive = {
case AssociatedEvent(localAddress, remoteAddress, inbound) if isWorker(remoteAddress) =>
logInfo(s"Successfully connected to $workerUrl")
@@ -32,12 +40,12 @@ private[spark] class WorkerWatcher(workerUrl: String) extends Actor with Logging
// These logs may not be seen if the worker (and associated pipe) has died
logError(s"Could not initialize connection to worker $workerUrl. Exiting.")
logError(s"Error was: $cause")
- System.exit(-1)
+ exitNonZero()
case DisassociatedEvent(localAddress, remoteAddress, inbound) if isWorker(remoteAddress) =>
// This log message will never be seen
logError(s"Lost connection to worker actor $workerUrl. Exiting.")
- System.exit(-1)
+ exitNonZero()
case e: AssociationEvent =>
// pass through association events relating to other remote actor systems
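The isTesting flag funnels both failure paths through exitNonZero, so a suite can observe the would-be JVM exit instead of dying with it. A minimal sketch, assuming a WorkerWatcher instance `watcher` constructed as in WorkerWatcherSuite below:

// With testing enabled, a fatal event sets a flag instead of calling System.exit.
watcher.setTesting(testing = true)
watcher.exitNonZero()
assert(watcher.isShutDown)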
diff --git a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala
index 372c9f4378..028196fe86 100644
--- a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala
@@ -86,7 +86,7 @@ class JsonProtocolSuite extends FunSuite {
)
def createDriverDesc() = new DriverDescription("hdfs://some-dir/some.jar", 100, 3,
- createDriverCommand())
+ false, createDriverCommand())
def createDriverInfo(): DriverInfo = new DriverInfo(3, "driver-3", createDriverDesc(), new Date())
diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala
new file mode 100644
index 0000000000..45dbcaffae
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala
@@ -0,0 +1,131 @@
+package org.apache.spark.deploy.worker
+
+import java.io.File
+
+import scala.collection.JavaConversions._
+
+import org.mockito.Matchers._
+import org.mockito.Mockito._
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
+import org.scalatest.FunSuite
+
+import org.apache.spark.deploy.{Command, DriverDescription}
+
+class DriverRunnerTest extends FunSuite {
+ private def createDriverRunner() = {
+ val command = new Command("mainClass", Seq(), Map())
+ val driverDescription = new DriverDescription("jarUrl", 512, 1, true, command)
+ new DriverRunner("driverId", new File("workDir"), new File("sparkHome"), driverDescription,
+ null, "akka://1.2.3.4/worker/")
+ }
+
+ private def createProcessBuilderAndProcess(): (ProcessBuilderLike, Process) = {
+ val processBuilder = mock(classOf[ProcessBuilderLike])
+ when(processBuilder.command).thenReturn(Seq("mocked", "command"))
+ val process = mock(classOf[Process])
+ when(processBuilder.start()).thenReturn(process)
+ (processBuilder, process)
+ }
+
+ test("Process succeeds instantly") {
+ val runner = createDriverRunner()
+
+ val sleeper = mock(classOf[Sleeper])
+ runner.setSleeper(sleeper)
+
+ val (processBuilder, process) = createProcessBuilderAndProcess()
+ // A single successful run, so no retries and no sleeping
+ when(process.waitFor()).thenReturn(0)
+ runner.runCommandWithRetry(processBuilder, p => (), supervise = true)
+
+ verify(process, times(1)).waitFor()
+ verify(sleeper, times(0)).sleep(anyInt())
+ }
+
+ test("Process failing several times and then succeeding") {
+ val runner = createDriverRunner()
+
+ val sleeper = mock(classOf[Sleeper])
+ runner.setSleeper(sleeper)
+
+ val (processBuilder, process) = createProcessBuilderAndProcess()
+ // fail, fail, fail, success
+ when(process.waitFor()).thenReturn(-1).thenReturn(-1).thenReturn(-1).thenReturn(0)
+ runner.runCommandWithRetry(processBuilder, p => (), supervise = true)
+
+ verify(process, times(4)).waitFor()
+ verify(sleeper, times(3)).sleep(anyInt())
+ verify(sleeper, times(1)).sleep(1)
+ verify(sleeper, times(1)).sleep(2)
+ verify(sleeper, times(1)).sleep(4)
+ }
+
+ test("Process doesn't restart if not supervised") {
+ val runner = createDriverRunner()
+
+ val sleeper = mock(classOf[Sleeper])
+ runner.setSleeper(sleeper)
+
+ val (processBuilder, process) = createProcessBuilderAndProcess()
+ when(process.waitFor()).thenReturn(-1)
+
+ runner.runCommandWithRetry(processBuilder, p => (), supervise = false)
+
+ verify(process, times(1)).waitFor()
+ verify(sleeper, times(0)).sleep(anyInt())
+ }
+
+ test("Process doesn't restart if killed") {
+ val runner = createDriverRunner()
+
+ val sleeper = mock(classOf[Sleeper])
+ runner.setSleeper(sleeper)
+
+ val (processBuilder, process) = createProcessBuilderAndProcess()
+ when(process.waitFor()).thenAnswer(new Answer[Int] {
+ def answer(invocation: InvocationOnMock): Int = {
+ runner.kill()
+ -1
+ }
+ })
+
+ runner.runCommandWithRetry(processBuilder, p => (), supervise = true)
+
+ verify(process, times(1)).waitFor()
+ verify(sleeper, times(0)).sleep(anyInt())
+ }
+
+ test("Reset of backoff counter") {
+ val runner = createDriverRunner()
+
+ val sleeper = mock(classOf[Sleeper])
+ runner.setSleeper(sleeper)
+
+ val clock = mock(classOf[Clock])
+ runner.setClock(clock)
+
+ val (processBuilder, process) = createProcessBuilderAndProcess()
+
+ when(process.waitFor())
+ .thenReturn(-1) // fail 1
+ .thenReturn(-1) // fail 2
+ .thenReturn(-1) // fail 3
+ .thenReturn(-1) // fail 4
+ .thenReturn(0) // success
+ when(clock.currentTimeMillis())
+ .thenReturn(0).thenReturn(1000) // fail 1 (short)
+ .thenReturn(1000).thenReturn(2000) // fail 2 (short)
+ .thenReturn(2000).thenReturn(10000) // fail 3 (long)
+ .thenReturn(10000).thenReturn(11000) // fail 4 (short)
+ .thenReturn(11000).thenReturn(21000) // success (long)
+
+ runner.runCommandWithRetry(processBuilder, p => (), supervise = true)
+
+ verify(sleeper, times(4)).sleep(anyInt())
+ // Expected sequence of sleeps is 1,2,1,2
+ verify(sleeper, times(2)).sleep(1)
+ verify(sleeper, times(2)).sleep(2)
+ }
+
+}
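Tracing the last test against the retry loop above (successfulRunDuration = 5): runs 1 and 2 each last 1 s, too short to reset the back-off, so the runner sleeps 1 s and then 2 s; run 3 lasts 8 s, which resets waitSeconds to 1 before sleeping, giving 1 s; run 4 lasts 1 s again, giving 2 s; run 5 exits 0 and the loop ends. That is why sleep(1) and sleep(2) are each verified exactly twice.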
diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala
index 7e5aaa3f98..bdb2c86d89 100644
--- a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala
@@ -31,8 +31,8 @@ class ExecutorRunnerTest extends FunSuite {
sparkHome, "appUiUrl")
val appId = "12345-worker321-9876"
val er = new ExecutorRunner(appId, 1, appDesc, 8, 500, null, "blah", "worker321", f(sparkHome),
- f("ooga"), ExecutorState.RUNNING)
+ f("ooga"), "blah", ExecutorState.RUNNING)
- assert(er.buildCommandSeq().last === appId)
+ assert(er.getCommandSeq.last === appId)
}
}
diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala
new file mode 100644
index 0000000000..94d88d307a
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala
@@ -0,0 +1,32 @@
+package org.apache.spark.deploy.worker
+
+
+import akka.actor.{ActorSystem, AddressFromURIString, Props}
+import akka.remote.DisassociatedEvent
+import akka.testkit.TestActorRef
+import org.scalatest.FunSuite
+
+class WorkerWatcherSuite extends FunSuite {
+ test("WorkerWatcher shuts down on valid disassociation") {
+ val actorSystem = ActorSystem("test")
+ val targetWorkerUrl = "akka://1.2.3.4/user/Worker"
+ val targetWorkerAddress = AddressFromURIString(targetWorkerUrl)
+ val actorRef = TestActorRef[WorkerWatcher](Props(classOf[WorkerWatcher], targetWorkerUrl))(actorSystem)
+ val workerWatcher = actorRef.underlyingActor
+ workerWatcher.setTesting(testing = true)
+ actorRef.underlyingActor.receive(new DisassociatedEvent(null, targetWorkerAddress, false))
+ assert(actorRef.underlyingActor.isShutDown)
+ }
+
+ test("WorkerWatcher stays alive on invalid disassociation") {
+ val actorSystem = ActorSystem("test")
+ val targetWorkerUrl = "akka://1.2.3.4/user/Worker"
+ val otherAkkaURL = "akka://4.3.2.1/user/OtherActor"
+ val otherAkkaAddress = AddressFromURIString(otherAkkaURL)
+ val actorRef = TestActorRef[WorkerWatcher](Props(classOf[WorkerWatcher], targetWorkerUrl))(actorSystem)
+ val workerWatcher = actorRef.underlyingActor
+ workerWatcher.setTesting(testing = true)
+ actorRef.underlyingActor.receive(new DisassociatedEvent(null, otherAkkaAddress, false))
+ assert(!actorRef.underlyingActor.isShutDown)
+ }
+}
\ No newline at end of file
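akka-testkit's TestActorRef, wired in by the build changes below, delivers messages synchronously on the calling thread and exposes the actor instance via underlyingActor; that is what lets these tests invoke receive directly and assert on the watcher's state without waiting for asynchronous message delivery.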
diff --git a/pom.xml b/pom.xml
index 78d2f162b5..7b734c5371 100644
--- a/pom.xml
+++ b/pom.xml
@@ -270,6 +270,18 @@
</exclusions>
</dependency>
<dependency>
+ <groupId>${akka.group}</groupId>
+ <artifactId>akka-testkit_${scala.binary.version}</artifactId>
+ <version>${akka.version}</version>
+ <scope>test</scope>
+ <exclusions>
+ <exclusion>
+ <groupId>org.jboss.netty</groupId>
+ <artifactId>netty</artifactId>
+ </exclusion>
+ </exclusions>
+ </dependency>
+ <dependency>
<groupId>it.unimi.dsi</groupId>
<artifactId>fastutil</artifactId>
<version>6.4.4</version>
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 051e5105f3..bd5f3f79c9 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -233,6 +233,7 @@ object SparkBuild extends Build {
"org.ow2.asm" % "asm" % "4.0",
"org.spark-project.akka" %% "akka-remote" % "2.2.3-shaded-protobuf" excludeAll(excludeNetty),
"org.spark-project.akka" %% "akka-slf4j" % "2.2.3-shaded-protobuf" excludeAll(excludeNetty),
+ "org.spark-project.akka" %% "akka-testkit" % "2.2.3-shaded-protobuf" % "test",
"net.liftweb" %% "lift-json" % "2.5.1" excludeAll(excludeNetty),
"it.unimi.dsi" % "fastutil" % "6.4.4",
"colt" % "colt" % "1.2.0",