aboutsummaryrefslogtreecommitdiff
path: root/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala
diff options
context:
space:
mode:
Diffstat (limited to 'core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala')
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala17
1 files changed, 15 insertions, 2 deletions
diff --git a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala
index 4089c3e771..20a9faa178 100644
--- a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala
@@ -27,6 +27,7 @@ import org.apache.spark.network.netty.SparkTransportConf
import org.apache.spark.network.sasl.SaslServerBootstrap
import org.apache.spark.network.server.TransportServer
import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler
+import org.apache.spark.network.util.TransportConf
import org.apache.spark.util.Utils
/**
@@ -45,11 +46,16 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana
private val useSasl: Boolean = securityManager.isAuthenticationEnabled()
private val transportConf = SparkTransportConf.fromSparkConf(sparkConf, numUsableCores = 0)
- private val blockHandler = new ExternalShuffleBlockHandler(transportConf)
+ private val blockHandler = newShuffleBlockHandler(transportConf)
private val transportContext: TransportContext = new TransportContext(transportConf, blockHandler)
private var server: TransportServer = _
+ /** Create a new shuffle block handler. Factored out for subclasses to override. */
+ protected def newShuffleBlockHandler(conf: TransportConf): ExternalShuffleBlockHandler = {
+ new ExternalShuffleBlockHandler(conf)
+ }
+
/** Starts the external shuffle service if the user has configured us to. */
def startIfEnabled() {
if (enabled) {
@@ -93,6 +99,13 @@ object ExternalShuffleService extends Logging {
private val barrier = new CountDownLatch(1)
def main(args: Array[String]): Unit = {
+ main(args, (conf: SparkConf, sm: SecurityManager) => new ExternalShuffleService(conf, sm))
+ }
+
+ /** A helper main method that allows the caller to call this with a custom shuffle service. */
+ private[spark] def main(
+ args: Array[String],
+ newShuffleService: (SparkConf, SecurityManager) => ExternalShuffleService): Unit = {
val sparkConf = new SparkConf
Utils.loadDefaultSparkProperties(sparkConf)
val securityManager = new SecurityManager(sparkConf)
@@ -100,7 +113,7 @@ object ExternalShuffleService extends Logging {
// we override this value since this service is started from the command line
// and we assume the user really wants it to be running
sparkConf.set("spark.shuffle.service.enabled", "true")
- server = new ExternalShuffleService(sparkConf, securityManager)
+ server = newShuffleService(sparkConf, securityManager)
server.start()
installShutdownHook()