Diffstat (limited to 'core')
 core/src/main/scala/org/apache/spark/HttpFileServer.scala                        |  7
 core/src/main/scala/org/apache/spark/HttpServer.scala                            | 88
 core/src/main/scala/org/apache/spark/SparkConf.scala                             | 10
 core/src/main/scala/org/apache/spark/SparkEnv.scala                              | 12
 core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala               |  3
 core/src/main/scala/org/apache/spark/deploy/Client.scala                         |  2
 core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala          |  2
 core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala          |  3
 core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala |  5
 core/src/main/scala/org/apache/spark/network/ConnectionManager.scala             | 14
 core/src/main/scala/org/apache/spark/storage/BlockManager.scala                  |  4
 core/src/main/scala/org/apache/spark/ui/JettyUtils.scala                         | 26
 core/src/main/scala/org/apache/spark/ui/SparkUI.scala                            |  2
 core/src/main/scala/org/apache/spark/ui/WebUI.scala                              |  5
 core/src/main/scala/org/apache/spark/util/AkkaUtils.scala                        | 24
 core/src/main/scala/org/apache/spark/util/Utils.scala                            | 73
 core/src/test/scala/org/apache/spark/util/UtilsSuite.scala                       | 34
 17 files changed, 233 insertions(+), 81 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/HttpFileServer.scala b/core/src/main/scala/org/apache/spark/HttpFileServer.scala
index 0e3750fdde..edc3889c9a 100644
--- a/core/src/main/scala/org/apache/spark/HttpFileServer.scala
+++ b/core/src/main/scala/org/apache/spark/HttpFileServer.scala
@@ -23,7 +23,10 @@ import com.google.common.io.Files
import org.apache.spark.util.Utils
-private[spark] class HttpFileServer(securityManager: SecurityManager) extends Logging {
+private[spark] class HttpFileServer(
+ securityManager: SecurityManager,
+ requestedPort: Int = 0)
+ extends Logging {
var baseDir : File = null
var fileDir : File = null
@@ -38,7 +41,7 @@ private[spark] class HttpFileServer(securityManager: SecurityManager) extends Lo
fileDir.mkdir()
jarDir.mkdir()
logInfo("HTTP File server directory is " + baseDir)
- httpServer = new HttpServer(baseDir, securityManager)
+ httpServer = new HttpServer(baseDir, securityManager, requestedPort, "HTTP file server")
httpServer.start()
serverUri = httpServer.uri
logDebug("HTTP file server started at: " + serverUri)
diff --git a/core/src/main/scala/org/apache/spark/HttpServer.scala b/core/src/main/scala/org/apache/spark/HttpServer.scala
index 7e9b517f90..912558d0ca 100644
--- a/core/src/main/scala/org/apache/spark/HttpServer.scala
+++ b/core/src/main/scala/org/apache/spark/HttpServer.scala
@@ -21,7 +21,7 @@ import java.io.File
import org.eclipse.jetty.util.security.{Constraint, Password}
import org.eclipse.jetty.security.authentication.DigestAuthenticator
-import org.eclipse.jetty.security.{ConstraintMapping, ConstraintSecurityHandler, HashLoginService, SecurityHandler}
+import org.eclipse.jetty.security.{ConstraintMapping, ConstraintSecurityHandler, HashLoginService}
import org.eclipse.jetty.server.Server
import org.eclipse.jetty.server.bio.SocketConnector
@@ -41,49 +41,69 @@ private[spark] class ServerStateException(message: String) extends Exception(mes
* as well as classes created by the interpreter when the user types in code. This is just a wrapper
* around a Jetty server.
*/
-private[spark] class HttpServer(resourceBase: File, securityManager: SecurityManager)
- extends Logging {
+private[spark] class HttpServer(
+ resourceBase: File,
+ securityManager: SecurityManager,
+ requestedPort: Int = 0,
+ serverName: String = "HTTP server")
+ extends Logging {
+
private var server: Server = null
- private var port: Int = -1
+ private var port: Int = requestedPort
def start() {
if (server != null) {
throw new ServerStateException("Server is already started")
} else {
logInfo("Starting HTTP Server")
- server = new Server()
- val connector = new SocketConnector
- connector.setMaxIdleTime(60*1000)
- connector.setSoLingerTime(-1)
- connector.setPort(0)
- server.addConnector(connector)
-
- val threadPool = new QueuedThreadPool
- threadPool.setDaemon(true)
- server.setThreadPool(threadPool)
- val resHandler = new ResourceHandler
- resHandler.setResourceBase(resourceBase.getAbsolutePath)
-
- val handlerList = new HandlerList
- handlerList.setHandlers(Array(resHandler, new DefaultHandler))
-
- if (securityManager.isAuthenticationEnabled()) {
- logDebug("HttpServer is using security")
- val sh = setupSecurityHandler(securityManager)
- // make sure we go through security handler to get resources
- sh.setHandler(handlerList)
- server.setHandler(sh)
- } else {
- logDebug("HttpServer is not using security")
- server.setHandler(handlerList)
- }
-
- server.start()
- port = server.getConnectors()(0).getLocalPort()
+ val (actualServer, actualPort) =
+ Utils.startServiceOnPort[Server](requestedPort, doStart, serverName)
+ server = actualServer
+ port = actualPort
}
}
/**
+ * Actually start the HTTP server on the given port.
+ *
+ * Note that this is only best effort in the sense that we may end up binding to a nearby port
+ * in the event of port collision. Return the bound server and the actual port used.
+ */
+ private def doStart(startPort: Int): (Server, Int) = {
+ val server = new Server()
+ val connector = new SocketConnector
+ connector.setMaxIdleTime(60 * 1000)
+ connector.setSoLingerTime(-1)
+ connector.setPort(startPort)
+ server.addConnector(connector)
+
+ val threadPool = new QueuedThreadPool
+ threadPool.setDaemon(true)
+ server.setThreadPool(threadPool)
+ val resHandler = new ResourceHandler
+ resHandler.setResourceBase(resourceBase.getAbsolutePath)
+
+ val handlerList = new HandlerList
+ handlerList.setHandlers(Array(resHandler, new DefaultHandler))
+
+ if (securityManager.isAuthenticationEnabled()) {
+ logDebug("HttpServer is using security")
+ val sh = setupSecurityHandler(securityManager)
+ // make sure we go through security handler to get resources
+ sh.setHandler(handlerList)
+ server.setHandler(sh)
+ } else {
+ logDebug("HttpServer is not using security")
+ server.setHandler(handlerList)
+ }
+
+ server.start()
+ val actualPort = server.getConnectors()(0).getLocalPort
+
+ (server, actualPort)
+ }
+
+ /**
* Setup Jetty to the HashLoginService using a single user with our
* shared secret. Configure it to use DIGEST-MD5 authentication so that the password
* isn't passed in plaintext.
@@ -134,7 +154,7 @@ private[spark] class HttpServer(resourceBase: File, securityManager: SecurityMan
if (server == null) {
throw new ServerStateException("Server is not started")
} else {
- return "http://" + Utils.localIpAddress + ":" + port
+ "http://" + Utils.localIpAddress + ":" + port
}
}
}
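
As the new doStart comment notes, binding is only best effort: if requestedPort is taken, the server may come up on a nearby port instead. A caller-side sketch (hypothetical, not part of this diff; conf, baseDir and securityManager are assumed to exist) that reads the bound address back rather than assuming the requested port:

    // Hypothetical caller of the new HttpServer constructor.
    val requestedPort = conf.getInt("spark.fileserver.port", 0)
    val server = new HttpServer(baseDir, securityManager, requestedPort, "HTTP file server")
    server.start()
    // `uri` reflects the port actually bound, which may differ from requestedPort on collision.
    conf.set("spark.fileserver.uri", server.uri)
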
diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala
index cce7a23d3b..13f0bff7ee 100644
--- a/core/src/main/scala/org/apache/spark/SparkConf.scala
+++ b/core/src/main/scala/org/apache/spark/SparkConf.scala
@@ -323,6 +323,14 @@ private[spark] object SparkConf {
* the scheduler, while the rest of the spark configs can be inherited from the driver later.
*/
def isExecutorStartupConf(name: String): Boolean = {
- isAkkaConf(name) || name.startsWith("spark.akka") || name.startsWith("spark.auth")
+ isAkkaConf(name) ||
+ name.startsWith("spark.akka") ||
+ name.startsWith("spark.auth") ||
+ isSparkPortConf(name)
}
+
+ /**
+ * Return whether the given config is a Spark port config.
+ */
+ def isSparkPortConf(name: String): Boolean = name.startsWith("spark.") && name.endsWith(".port")
}
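
isSparkPortConf is what pulls the new spark.*.port settings into the set of configs shipped to executors at startup. A standalone sketch of the matching rule (the helper below mirrors the private[spark] method purely for illustration):

    // Mirrors SparkConf.isSparkPortConf, which is private[spark]; illustrative only.
    def isSparkPortConf(name: String): Boolean =
      name.startsWith("spark.") && name.endsWith(".port")

    assert(isSparkPortConf("spark.driver.port"))        // forwarded at executor startup
    assert(isSparkPortConf("spark.blockManager.port"))  // forwarded at executor startup
    assert(!isSparkPortConf("spark.executor.memory"))   // not a port config
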
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index dd8e4ac66d..9d4edeb6d9 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -22,7 +22,6 @@ import java.net.Socket
import scala.collection.JavaConversions._
import scala.collection.mutable
-import scala.concurrent.Await
import scala.util.Properties
import akka.actor._
@@ -151,10 +150,10 @@ object SparkEnv extends Logging {
val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, port, conf = conf,
securityManager = securityManager)
- // Bit of a hack: If this is the driver and our port was 0 (meaning bind to any free port),
- // figure out which port number Akka actually bound to and set spark.driver.port to it.
- if (isDriver && port == 0) {
- conf.set("spark.driver.port", boundPort.toString)
+ // Figure out which port Akka actually bound to in case the original port is 0 or occupied.
+ // This is so that we tell the executors the correct port to connect to.
+ if (isDriver) {
+ conf.set("spark.driver.port", boundPort.toString)
}
// Create an instance of the class named by the given Java system property, or by
@@ -222,7 +221,8 @@ object SparkEnv extends Logging {
val httpFileServer =
if (isDriver) {
- val server = new HttpFileServer(securityManager)
+ val fileServerPort = conf.getInt("spark.fileserver.port", 0)
+ val server = new HttpFileServer(securityManager, fileServerPort)
server.initialize()
conf.set("spark.fileserver.uri", server.serverUri)
server
diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
index 487456467b..942dc7d7ea 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
@@ -152,7 +152,8 @@ private[broadcast] object HttpBroadcast extends Logging {
private def createServer(conf: SparkConf) {
broadcastDir = Utils.createTempDir(Utils.getLocalDir(conf))
- server = new HttpServer(broadcastDir, securityManager)
+ val broadcastPort = conf.getInt("spark.broadcast.port", 0)
+ server = new HttpServer(broadcastDir, securityManager, broadcastPort, "HTTP broadcast server")
server.start()
serverUri = server.uri
logInfo("Broadcast server started at " + serverUri)
diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala
index 17c507af26..c07003784e 100644
--- a/core/src/main/scala/org/apache/spark/deploy/Client.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala
@@ -155,8 +155,6 @@ object Client {
conf.set("akka.loglevel", driverArgs.logLevel.toString.replace("WARN", "WARNING"))
Logger.getRootLogger.setLevel(driverArgs.logLevel)
- // TODO: See if we can initialize akka so return messages are sent back using the same TCP
- // flow. Else, this (sadly) requires the DriverClient be routable from the Master.
val (actorSystem, _) = AkkaUtils.createActorSystem(
"driverClient", Utils.localHostName(), 0, conf, new SecurityManager(conf))
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
index 16aa049337..d86ec1e03e 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
@@ -28,7 +28,7 @@ import org.apache.spark.util.AkkaUtils
*/
private[spark]
class MasterWebUI(val master: Master, requestedPort: Int)
- extends WebUI(master.securityMgr, requestedPort, master.conf) with Logging {
+ extends WebUI(master.securityMgr, requestedPort, master.conf, name = "MasterUI") with Logging {
val masterActorRef = master.self
val timeout = AkkaUtils.askTimeout(master.conf)
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
index a9f531e9e4..47fbda600b 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
@@ -22,6 +22,7 @@ import javax.servlet.http.HttpServletRequest
import org.apache.spark.{Logging, SparkConf}
import org.apache.spark.deploy.worker.Worker
+import org.apache.spark.deploy.worker.ui.WorkerWebUI._
import org.apache.spark.ui.{SparkUI, WebUI}
import org.apache.spark.ui.JettyUtils._
import org.apache.spark.util.AkkaUtils
@@ -34,7 +35,7 @@ class WorkerWebUI(
val worker: Worker,
val workDir: File,
port: Option[Int] = None)
- extends WebUI(worker.securityMgr, WorkerWebUI.getUIPort(port, worker.conf), worker.conf)
+ extends WebUI(worker.securityMgr, getUIPort(port, worker.conf), worker.conf, name = "WorkerUI")
with Logging {
val timeout = AkkaUtils.askTimeout(worker.conf)
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index af736de405..1f46a0f176 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -115,8 +115,9 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
// Bootstrap to fetch the driver's Spark properties.
val executorConf = new SparkConf
+ val port = executorConf.getInt("spark.executor.port", 0)
val (fetcher, _) = AkkaUtils.createActorSystem(
- "driverPropsFetcher", hostname, 0, executorConf, new SecurityManager(executorConf))
+ "driverPropsFetcher", hostname, port, executorConf, new SecurityManager(executorConf))
val driver = fetcher.actorSelection(driverUrl)
val timeout = AkkaUtils.askTimeout(executorConf)
val fut = Patterns.ask(driver, RetrieveSparkProps, timeout)
@@ -126,7 +127,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
// Create a new ActorSystem using driver's Spark properties to run the backend.
val driverConf = new SparkConf().setAll(props)
val (actorSystem, boundPort) = AkkaUtils.createActorSystem(
- "sparkExecutor", hostname, 0, driverConf, new SecurityManager(driverConf))
+ "sparkExecutor", hostname, port, driverConf, new SecurityManager(driverConf))
// set it
val sparkHostPort = hostname + ":" + boundPort
actorSystem.actorOf(
diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
index 566e8a4aaa..4c00225280 100644
--- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
+++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
@@ -38,8 +38,12 @@ import scala.language.postfixOps
import org.apache.spark._
import org.apache.spark.util.{SystemClock, Utils}
-private[spark] class ConnectionManager(port: Int, conf: SparkConf,
- securityManager: SecurityManager) extends Logging {
+private[spark] class ConnectionManager(
+ port: Int,
+ conf: SparkConf,
+ securityManager: SecurityManager,
+ name: String = "Connection manager")
+ extends Logging {
class MessageStatus(
val message: Message,
@@ -105,7 +109,11 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf,
serverChannel.socket.setReuseAddress(true)
serverChannel.socket.setReceiveBufferSize(256 * 1024)
- serverChannel.socket.bind(new InetSocketAddress(port))
+ private def startService(port: Int): (ServerSocketChannel, Int) = {
+ serverChannel.socket.bind(new InetSocketAddress(port))
+ (serverChannel, serverChannel.socket.getLocalPort)
+ }
+ Utils.startServiceOnPort[ServerSocketChannel](port, startService, name)
serverChannel.register(selector, SelectionKey.OP_ACCEPT)
val id = new ConnectionManagerId(Utils.localHostName, serverChannel.socket.getLocalPort)
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index c0a0601794..3876cf43e2 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -60,10 +60,12 @@ private[spark] class BlockManager(
mapOutputTracker: MapOutputTracker)
extends Logging {
+ private val port = conf.getInt("spark.blockManager.port", 0)
val shuffleBlockManager = new ShuffleBlockManager(this)
val diskBlockManager = new DiskBlockManager(shuffleBlockManager,
conf.get("spark.local.dir", System.getProperty("java.io.tmpdir")))
- val connectionManager = new ConnectionManager(0, conf, securityManager)
+ val connectionManager =
+ new ConnectionManager(port, conf, securityManager, "Connection manager for block manager")
implicit val futureExecContext = connectionManager.futureExecContext
diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
index a2535e3c1c..29e9cf9478 100644
--- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
@@ -174,40 +174,32 @@ private[spark] object JettyUtils extends Logging {
hostName: String,
port: Int,
handlers: Seq[ServletContextHandler],
- conf: SparkConf): ServerInfo = {
+ conf: SparkConf,
+ serverName: String = ""): ServerInfo = {
val collection = new ContextHandlerCollection
collection.setHandlers(handlers.toArray)
addFilters(handlers, conf)
- @tailrec
+ // Bind to the given port, or throw a java.net.BindException if the port is occupied
def connect(currentPort: Int): (Server, Int) = {
val server = new Server(new InetSocketAddress(hostName, currentPort))
val pool = new QueuedThreadPool
pool.setDaemon(true)
server.setThreadPool(pool)
server.setHandler(collection)
-
- Try {
+ try {
server.start()
- } match {
- case s: Success[_] =>
- (server, server.getConnectors.head.getLocalPort)
- case f: Failure[_] =>
- val nextPort = (currentPort + 1) % 65536
+ (server, server.getConnectors.head.getLocalPort)
+ } catch {
+ case e: Exception =>
server.stop()
pool.stop()
- val msg = s"Failed to create UI on port $currentPort. Trying again on port $nextPort."
- if (f.toString.contains("Address already in use")) {
- logWarning(s"$msg - $f")
- } else {
- logError(msg, f.exception)
- }
- connect(nextPort)
+ throw e
}
}
- val (server, boundPort) = connect(port)
+ val (server, boundPort) = Utils.startServiceOnPort[Server](port, connect, serverName)
ServerInfo(server, boundPort, collection)
}
diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
index 097a1b81e1..6c788a37dc 100644
--- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
@@ -36,7 +36,7 @@ private[spark] class SparkUI(
val listenerBus: SparkListenerBus,
var appName: String,
val basePath: String = "")
- extends WebUI(securityManager, SparkUI.getUIPort(conf), conf, basePath)
+ extends WebUI(securityManager, SparkUI.getUIPort(conf), conf, basePath, "SparkUI")
with Logging {
def this(sc: SparkContext) = this(sc, sc.conf, sc.env.securityManager, sc.listenerBus, sc.appName)
diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala
index 856273e1d4..5f52f95088 100644
--- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala
@@ -39,7 +39,8 @@ private[spark] abstract class WebUI(
securityManager: SecurityManager,
port: Int,
conf: SparkConf,
- basePath: String = "")
+ basePath: String = "",
+ name: String = "")
extends Logging {
protected val tabs = ArrayBuffer[WebUITab]()
@@ -97,7 +98,7 @@ private[spark] abstract class WebUI(
def bind() {
assert(!serverInfo.isDefined, "Attempted to bind %s more than once!".format(className))
try {
- serverInfo = Some(startJettyServer("0.0.0.0", port, handlers, conf))
+ serverInfo = Some(startJettyServer("0.0.0.0", port, handlers, conf, name))
logInfo("Started %s at http://%s:%d".format(className, publicHostName, boundPort))
} catch {
case e: Exception =>
diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
index feafd654e9..d6afb73b74 100644
--- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
@@ -21,7 +21,7 @@ import scala.collection.JavaConversions.mapAsJavaMap
import scala.concurrent.Await
import scala.concurrent.duration.{Duration, FiniteDuration}
-import akka.actor.{Actor, ActorRef, ActorSystem, ExtendedActorSystem}
+import akka.actor.{ActorRef, ActorSystem, ExtendedActorSystem}
import akka.pattern.ask
import com.typesafe.config.ConfigFactory
@@ -44,14 +44,28 @@ private[spark] object AkkaUtils extends Logging {
* If indestructible is set to true, the Actor System will continue running in the event
* of a fatal exception. This is used by [[org.apache.spark.executor.Executor]].
*/
- def createActorSystem(name: String, host: String, port: Int,
- conf: SparkConf, securityManager: SecurityManager): (ActorSystem, Int) = {
+ def createActorSystem(
+ name: String,
+ host: String,
+ port: Int,
+ conf: SparkConf,
+ securityManager: SecurityManager): (ActorSystem, Int) = {
+ val startService: Int => (ActorSystem, Int) = { actualPort =>
+ doCreateActorSystem(name, host, actualPort, conf, securityManager)
+ }
+ Utils.startServiceOnPort(port, startService, name)
+ }
+
+ private def doCreateActorSystem(
+ name: String,
+ host: String,
+ port: Int,
+ conf: SparkConf,
+ securityManager: SecurityManager): (ActorSystem, Int) = {
val akkaThreads = conf.getInt("spark.akka.threads", 4)
val akkaBatchSize = conf.getInt("spark.akka.batchSize", 15)
-
val akkaTimeout = conf.getInt("spark.akka.timeout", 100)
-
val akkaFrameSize = maxFrameSizeBytes(conf)
val akkaLogLifecycleEvents = conf.getBoolean("spark.akka.logLifecycleEvents", false)
val lifecycleEvents = if (akkaLogLifecycleEvents) "on" else "off"
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 30073a8285..c60be4f8a1 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -18,7 +18,7 @@
package org.apache.spark.util
import java.io._
-import java.net.{InetAddress, Inet4Address, NetworkInterface, URI, URL, URLConnection}
+import java.net._
import java.nio.ByteBuffer
import java.util.{Locale, Random, UUID}
import java.util.concurrent.{ThreadFactory, ConcurrentHashMap, Executors, ThreadPoolExecutor}
@@ -1331,4 +1331,75 @@ private[spark] object Utils extends Logging {
.map { case (k, v) => s"-D$k=$v" }
}
+ /**
+ * Default number of retries in binding to a port.
+ */
+ val portMaxRetries: Int = {
+ if (sys.props.contains("spark.testing")) {
+ // Set a higher number of retries for tests...
+ sys.props.get("spark.ports.maxRetries").map(_.toInt).getOrElse(100)
+ } else {
+ Option(SparkEnv.get)
+ .flatMap(_.conf.getOption("spark.ports.maxRetries"))
+ .map(_.toInt)
+ .getOrElse(16)
+ }
+ }
+
+ /**
+ * Attempt to start a service on the given port, or fail after a number of attempts.
+ * Each subsequent attempt uses 1 + the port used in the previous attempt (unless the port is 0).
+ *
+ * @param startPort The initial port to start the service on.
+ * @param maxRetries Maximum number of retries to attempt.
+ * A value of 3 means attempting ports n, n+1, n+2, and n+3, for example.
+ * @param startService Function to start service on a given port.
+ * This is expected to throw java.net.BindException on port collision.
+ */
+ def startServiceOnPort[T](
+ startPort: Int,
+ startService: Int => (T, Int),
+ serviceName: String = "",
+ maxRetries: Int = portMaxRetries): (T, Int) = {
+ val serviceString = if (serviceName.isEmpty) "" else s" '$serviceName'"
+ for (offset <- 0 to maxRetries) {
+ // Do not increment port if startPort is 0, which is treated as a special port
+ val tryPort = if (startPort == 0) startPort else (startPort + offset) % 65536
+ try {
+ val (service, port) = startService(tryPort)
+ logInfo(s"Successfully started service$serviceString on port $port.")
+ return (service, port)
+ } catch {
+ case e: Exception if isBindCollision(e) =>
+ if (offset >= maxRetries) {
+ val exceptionMessage =
+ s"${e.getMessage}: Service$serviceString failed after $maxRetries retries!"
+ val exception = new BindException(exceptionMessage)
+ // restore original stack trace
+ exception.setStackTrace(e.getStackTrace)
+ throw exception
+ }
+ logWarning(s"Service$serviceString could not bind on port $tryPort. " +
+ s"Attempting port ${tryPort + 1}.")
+ }
+ }
+ // Should never happen
+ throw new SparkException(s"Failed to start service$serviceString on port $startPort")
+ }
+
+ /**
+ * Return whether the exception is caused by an address-port collision when binding.
+ */
+ def isBindCollision(exception: Throwable): Boolean = {
+ exception match {
+ case e: BindException =>
+ if (e.getMessage != null && e.getMessage.contains("Address already in use")) {
+ return true
+ }
+ isBindCollision(e.getCause)
+ case e: Exception => isBindCollision(e.getCause)
+ case _ => false
+ }
+ }
+
}
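
startServiceOnPort is the single retry loop every service above now goes through. A minimal usage sketch, assuming a caller inside the org.apache.spark package (Utils is private[spark]) and a plain ServerSocket as the service being started:

    import java.net.ServerSocket
    import org.apache.spark.util.Utils

    // The start function must throw java.net.BindException (or wrap one) on collision.
    def startSocket(port: Int): (ServerSocket, Int) = {
      val socket = new ServerSocket(port)
      (socket, socket.getLocalPort)
    }

    // Tries 4040, then 4041, 4042, ... for up to spark.ports.maxRetries further attempts;
    // a startPort of 0 is passed through unchanged so the OS picks a free port.
    val (socket, boundPort) =
      Utils.startServiceOnPort[ServerSocket](4040, startSocket, "demo service")
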
diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
index 1ee936bc78..70d423ba8a 100644
--- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.util
import scala.util.Random
import java.io.{File, ByteArrayOutputStream, ByteArrayInputStream, FileOutputStream}
-import java.net.URI
+import java.net.{BindException, ServerSocket, URI}
import java.nio.{ByteBuffer, ByteOrder}
import com.google.common.base.Charsets
@@ -265,4 +265,36 @@ class UtilsSuite extends FunSuite {
Array("hdfs:/a.jar", "s3:/another.jar"))
}
+ test("isBindCollision") {
+ // Negatives
+ assert(!Utils.isBindCollision(null))
+ assert(!Utils.isBindCollision(new Exception))
+ assert(!Utils.isBindCollision(new Exception(new Exception)))
+ assert(!Utils.isBindCollision(new Exception(new BindException)))
+ assert(!Utils.isBindCollision(new Exception(new BindException("Random message"))))
+
+ // Positives
+ val be = new BindException("Address already in use")
+ val be1 = new Exception(new BindException("Address already in use"))
+ val be2 = new Exception(new Exception(new BindException("Address already in use")))
+ assert(Utils.isBindCollision(be))
+ assert(Utils.isBindCollision(be1))
+ assert(Utils.isBindCollision(be2))
+
+ // Actual bind exception
+ var server1: ServerSocket = null
+ var server2: ServerSocket = null
+ try {
+ server1 = new java.net.ServerSocket(0)
+ server2 = new java.net.ServerSocket(server1.getLocalPort)
+ } catch {
+ case e: Exception =>
+ assert(e.isInstanceOf[java.net.BindException])
+ assert(Utils.isBindCollision(e))
+ } finally {
+ Option(server1).foreach(_.close())
+ Option(server2).foreach(_.close())
+ }
+ }
+
}
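
A possible follow-up test for the retry path itself, in the same style as the isBindCollision test above (illustrative only; it assumes some port other than the occupied one is free):

    test("startServiceOnPort retries on collision") {
      val occupied = new ServerSocket(0)
      val startPort = occupied.getLocalPort
      def startService(port: Int): (ServerSocket, Int) = {
        val socket = new ServerSocket(port)
        (socket, socket.getLocalPort)
      }
      val (socket, boundPort) =
        Utils.startServiceOnPort[ServerSocket](startPort, startService, "test service")
      try {
        // The first attempt collides with `occupied`, so a different port must have been chosen.
        assert(boundPort != startPort)
      } finally {
        socket.close()
        occupied.close()
      }
    }
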