diff options
Diffstat (limited to 'core/src/main')
3 files changed, 120 insertions, 87 deletions
diff --git a/core/src/main/scala/org/apache/spark/SSLOptions.scala b/core/src/main/scala/org/apache/spark/SSLOptions.scala index 5f14102c3c..29163e7f30 100644 --- a/core/src/main/scala/org/apache/spark/SSLOptions.scala +++ b/core/src/main/scala/org/apache/spark/SSLOptions.scala @@ -34,6 +34,8 @@ import org.apache.spark.internal.Logging * * @param enabled enables or disables SSL; if it is set to false, the rest of the * settings are disregarded + * @param port the port where to bind the SSL server; if not defined, it will be + * based on the non-SSL port for the same service. * @param keyStore a path to the key-store file * @param keyStorePassword a password to access the key-store file * @param keyPassword a password to access the private key in the key-store @@ -47,6 +49,7 @@ import org.apache.spark.internal.Logging */ private[spark] case class SSLOptions( enabled: Boolean = false, + port: Option[Int] = None, keyStore: Option[File] = None, keyStorePassword: Option[String] = None, keyPassword: Option[String] = None, @@ -164,6 +167,11 @@ private[spark] object SSLOptions extends Logging { def parse(conf: SparkConf, ns: String, defaults: Option[SSLOptions] = None): SSLOptions = { val enabled = conf.getBoolean(s"$ns.enabled", defaultValue = defaults.exists(_.enabled)) + val port = conf.getOption(s"$ns.port").map(_.toInt) + port.foreach { p => + require(p >= 0, "Port number must be a non-negative value.") + } + val keyStore = conf.getOption(s"$ns.keyStore").map(new File(_)) .orElse(defaults.flatMap(_.keyStore)) @@ -198,6 +206,7 @@ private[spark] object SSLOptions extends Logging { new SSLOptions( enabled, + port, keyStore, keyStorePassword, keyPassword, diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index f713619cd7..7909821db9 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -27,7 +27,7 @@ import scala.xml.Node import org.eclipse.jetty.client.api.Response import org.eclipse.jetty.proxy.ProxyServlet -import org.eclipse.jetty.server.{HttpConnectionFactory, Request, Server, ServerConnector} +import org.eclipse.jetty.server._ import org.eclipse.jetty.server.handler._ import org.eclipse.jetty.servlet._ import org.eclipse.jetty.servlets.gzip.GzipHandler @@ -279,109 +279,125 @@ private[spark] object JettyUtils extends Logging { addFilters(handlers, conf) - val gzipHandlers = handlers.map { h => - h.setVirtualHosts(Array("@" + SPARK_CONNECTOR_NAME)) - - val gzipHandler = new GzipHandler - gzipHandler.setHandler(h) - gzipHandler + // Start the server first, with no connectors. + val pool = new QueuedThreadPool + if (serverName.nonEmpty) { + pool.setName(serverName) } + pool.setDaemon(true) - // Bind to the given port, or throw a java.net.BindException if the port is occupied - def connect(currentPort: Int): ((Server, Option[Int]), Int) = { - val pool = new QueuedThreadPool - if (serverName.nonEmpty) { - pool.setName(serverName) - } - pool.setDaemon(true) - - val server = new Server(pool) - val connectors = new ArrayBuffer[ServerConnector]() - val collection = new ContextHandlerCollection - - // Create a connector on port currentPort to listen for HTTP requests - val httpConnector = new ServerConnector( - server, - null, - // Call this full constructor to set this, which forces daemon threads: - new ScheduledExecutorScheduler(s"$serverName-JettyScheduler", true), - null, - -1, - -1, - new HttpConnectionFactory()) - httpConnector.setPort(currentPort) - connectors += httpConnector - - val httpsConnector = sslOptions.createJettySslContextFactory() match { - case Some(factory) => - // If the new port wraps around, do not try a privileged port. - val securePort = - if (currentPort != 0) { - (currentPort + 400 - 1024) % (65536 - 1024) + 1024 - } else { - 0 - } - val scheme = "https" - // Create a connector on port securePort to listen for HTTPS requests - val connector = new ServerConnector(server, factory) - connector.setPort(securePort) - connector.setName(SPARK_CONNECTOR_NAME) - connectors += connector - - // redirect the HTTP requests to HTTPS port - httpConnector.setName(REDIRECT_CONNECTOR_NAME) - collection.addHandler(createRedirectHttpsHandler(securePort, scheme)) - Some(connector) + val server = new Server(pool) - case None => - // No SSL, so the HTTP connector becomes the official one where all contexts bind. - httpConnector.setName(SPARK_CONNECTOR_NAME) - None - } + val errorHandler = new ErrorHandler() + errorHandler.setShowStacks(true) + errorHandler.setServer(server) + server.addBean(errorHandler) + + val collection = new ContextHandlerCollection + server.setHandler(collection) + + // Executor used to create daemon threads for the Jetty connectors. + val serverExecutor = new ScheduledExecutorScheduler(s"$serverName-JettyScheduler", true) + + try { + server.start() // As each acceptor and each selector will use one thread, the number of threads should at // least be the number of acceptors and selectors plus 1. (See SPARK-13776) var minThreads = 1 - connectors.foreach { connector => + + def newConnector( + connectionFactories: Array[ConnectionFactory], + port: Int): (ServerConnector, Int) = { + val connector = new ServerConnector( + server, + null, + serverExecutor, + null, + -1, + -1, + connectionFactories: _*) + connector.setPort(port) + connector.start() + // Currently we only use "SelectChannelConnector" // Limit the max acceptor number to 8 so that we don't waste a lot of threads connector.setAcceptQueueSize(math.min(connector.getAcceptors, 8)) connector.setHost(hostName) // The number of selectors always equals to the number of acceptors minThreads += connector.getAcceptors * 2 + + (connector, connector.getLocalPort()) } - pool.setMaxThreads(math.max(pool.getMaxThreads, minThreads)) - val errorHandler = new ErrorHandler() - errorHandler.setShowStacks(true) - errorHandler.setServer(server) - server.addBean(errorHandler) - - gzipHandlers.foreach(collection.addHandler) - server.setHandler(collection) - - server.setConnectors(connectors.toArray) - try { - server.start() - ((server, httpsConnector.map(_.getLocalPort())), httpConnector.getLocalPort) - } catch { - case e: Exception => - server.stop() - pool.stop() - throw e + // If SSL is configured, create the secure connector first. + val securePort = sslOptions.createJettySslContextFactory().map { factory => + val securePort = sslOptions.port.getOrElse(if (port > 0) Utils.userPort(port, 400) else 0) + val secureServerName = if (serverName.nonEmpty) s"$serverName (HTTPS)" else serverName + val connectionFactories = AbstractConnectionFactory.getFactories(factory, + new HttpConnectionFactory()) + + def sslConnect(currentPort: Int): (ServerConnector, Int) = { + newConnector(connectionFactories, currentPort) + } + + val (connector, boundPort) = Utils.startServiceOnPort[ServerConnector](securePort, + sslConnect, conf, secureServerName) + connector.setName(SPARK_CONNECTOR_NAME) + server.addConnector(connector) + boundPort } - } - val ((server, securePort), boundPort) = Utils.startServiceOnPort(port, connect, conf, - serverName) - ServerInfo(server, boundPort, securePort, - server.getHandler().asInstanceOf[ContextHandlerCollection]) + // Bind the HTTP port. + def httpConnect(currentPort: Int): (ServerConnector, Int) = { + newConnector(Array(new HttpConnectionFactory()), currentPort) + } + + val (httpConnector, httpPort) = Utils.startServiceOnPort[ServerConnector](port, httpConnect, + conf, serverName) + + // If SSL is configured, then configure redirection in the HTTP connector. + securePort match { + case Some(p) => + httpConnector.setName(REDIRECT_CONNECTOR_NAME) + val redirector = createRedirectHttpsHandler(p, "https") + collection.addHandler(redirector) + redirector.start() + + case None => + httpConnector.setName(SPARK_CONNECTOR_NAME) + } + + server.addConnector(httpConnector) + + // Add all the known handlers now that connectors are configured. + handlers.foreach { h => + h.setVirtualHosts(toVirtualHosts(SPARK_CONNECTOR_NAME)) + val gzipHandler = new GzipHandler() + gzipHandler.setHandler(h) + collection.addHandler(gzipHandler) + gzipHandler.start() + } + + pool.setMaxThreads(math.max(pool.getMaxThreads, minThreads)) + ServerInfo(server, httpPort, securePort, collection) + } catch { + case e: Exception => + server.stop() + if (serverExecutor.isStarted()) { + serverExecutor.stop() + } + if (pool.isStarted()) { + pool.stop() + } + throw e + } } private def createRedirectHttpsHandler(securePort: Int, scheme: String): ContextHandler = { val redirectHandler: ContextHandler = new ContextHandler redirectHandler.setContextPath("/") - redirectHandler.setVirtualHosts(Array("@" + REDIRECT_CONNECTOR_NAME)) + redirectHandler.setVirtualHosts(toVirtualHosts(REDIRECT_CONNECTOR_NAME)) redirectHandler.setHandler(new AbstractHandler { override def handle( target: String, @@ -394,8 +410,7 @@ private[spark] object JettyUtils extends Logging { val httpsURI = createRedirectURI(scheme, baseRequest.getServerName, securePort, baseRequest.getRequestURI, baseRequest.getQueryString) response.setContentLength(0) - response.encodeRedirectURL(httpsURI) - response.sendRedirect(httpsURI) + response.sendRedirect(response.encodeRedirectURL(httpsURI)) baseRequest.setHandled(true) } }) @@ -456,6 +471,8 @@ private[spark] object JettyUtils extends Logging { new URI(scheme, authority, path, query, null).toString } + def toVirtualHosts(connectors: String*): Array[String] = connectors.map("@" + _).toArray + } private[spark] case class ServerInfo( @@ -465,7 +482,7 @@ private[spark] case class ServerInfo( private val rootHandler: ContextHandlerCollection) { def addHandler(handler: ContextHandler): Unit = { - handler.setVirtualHosts(Array("@" + JettyUtils.SPARK_CONNECTOR_NAME)) + handler.setVirtualHosts(JettyUtils.toVirtualHosts(JettyUtils.SPARK_CONNECTOR_NAME)) rootHandler.addHandler(handler) if (!handler.isStarted()) { handler.start() diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 2c1d331b9a..c225e1a0cc 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2203,6 +2203,14 @@ private[spark] object Utils extends Logging { } /** + * Returns the user port to try when trying to bind a service. Handles wrapping and skipping + * privileged ports. + */ + def userPort(base: Int, offset: Int): Int = { + (base + offset - 1024) % (65536 - 1024) + 1024 + } + + /** * Attempt to start a service on the given port, or fail after a number of attempts. * Each subsequent attempt uses 1 + the port used in the previous attempt (unless the port is 0). * @@ -2229,8 +2237,7 @@ private[spark] object Utils extends Logging { val tryPort = if (startPort == 0) { startPort } else { - // If the new port wraps around, do not try a privilege port - ((startPort + offset - 1024) % (65536 - 1024)) + 1024 + userPort(startPort, offset) } try { val (service, port) = startService(tryPort) |