author     Marcelo Vanzin <vanzin@cloudera.com>    2015-11-23 13:54:19 -0800
committer  Marcelo Vanzin <vanzin@cloudera.com>    2015-11-23 13:54:19 -0800
commit     c2467dadae8ce44010a912ee91c429310f8add65 (patch)
tree       13f1eb5a8e2e2348a6a0b1296a720b8ef626eaf9 /core
parent     7cfa4c6bc36d97e459d4adee7b03d537d63c337e (diff)
[SPARK-11140][CORE] Transfer files using network lib when using NettyRpcEnv.
This change abstracts the code that serves jars / files to executors so that each RpcEnv can have its own implementation; the Akka version uses the existing HTTP-based file serving mechanism, while the Netty version uses the new stream support added to the network lib. This lets file transfers benefit from the network library's simpler security configuration and should also reduce overall overhead.

The change includes a small fix to TransportChannelHandler so that it propagates user events to downstream handlers.

Author: Marcelo Vanzin <vanzin@cloudera.com>

Closes #9530 from vanzin/SPARK-11140.
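To illustrate the new abstraction, here is a minimal sketch (not part of this patch; serveAndFetch, rpcEnv, localFile and target are hypothetical names) of how a driver-side registration and an executor-side fetch fit together through the API this change introduces:

import java.io.File
import java.nio.channels.Channels
import java.nio.file.Files

import org.apache.spark.rpc.RpcEnv

// Illustrative only: how the pieces added by this patch are meant to interact.
def serveAndFetch(rpcEnv: RpcEnv, localFile: File, target: File): Unit = {
  // Driver side: register the file with whatever file server the RpcEnv provides.
  // The Akka implementation returns an "http" URI, the Netty one a "spark" URI.
  val uri = rpcEnv.fileServer.addFile(localFile)

  // Executor side: "spark" URIs are resolved by opening a stream channel back to
  // the driver's RpcEnv, which is what Utils.fetchFile ends up doing for this scheme.
  if (uri.startsWith("spark://")) {
    val in = Channels.newInputStream(rpcEnv.openChannel(uri))
    try {
      Files.copy(in, target.toPath)
    } finally {
      in.close()
    }
  }
}

In the Netty case the returned URI is served by the new NettyStreamManager; in the Akka case the URI keeps pointing at the lazily started HttpFileServer and is fetched over the existing HTTP path.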
Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkContext.scala                   |   8
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkEnv.scala                       |  14
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala                     |  46
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala            |  60
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala          | 138
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala   |  63
-rw-r--r--  core/src/main/scala/org/apache/spark/util/Utils.scala                     |   9
-rw-r--r--  core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala                |  39
-rw-r--r--  core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala |  10
9 files changed, 345 insertions, 42 deletions
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index af4456c05b..b153a7b08e 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1379,7 +1379,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
}
val key = if (!isLocal && scheme == "file") {
- env.httpFileServer.addFile(new File(uri.getPath))
+ env.rpcEnv.fileServer.addFile(new File(uri.getPath))
} else {
schemeCorrectedPath
}
@@ -1630,7 +1630,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
var key = ""
if (path.contains("\\")) {
// For local paths with backslashes on Windows, URI throws an exception
- key = env.httpFileServer.addJar(new File(path))
+ key = env.rpcEnv.fileServer.addJar(new File(path))
} else {
val uri = new URI(path)
key = uri.getScheme match {
@@ -1644,7 +1644,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
// of the AM to make it show up in the current working directory.
val fileName = new Path(uri.getPath).getName()
try {
- env.httpFileServer.addJar(new File(fileName))
+ env.rpcEnv.fileServer.addJar(new File(fileName))
} catch {
case e: Exception =>
// For now just log an error but allow to go through so spark examples work.
@@ -1655,7 +1655,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
}
} else {
try {
- env.httpFileServer.addJar(new File(uri.getPath))
+ env.rpcEnv.fileServer.addJar(new File(uri.getPath))
} catch {
case exc: FileNotFoundException =>
logError(s"Jar not found at $path")
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 88df27f733..84230e32a4 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -66,7 +66,6 @@ class SparkEnv (
val blockTransferService: BlockTransferService,
val blockManager: BlockManager,
val securityManager: SecurityManager,
- val httpFileServer: HttpFileServer,
val sparkFilesDir: String,
val metricsSystem: MetricsSystem,
val memoryManager: MemoryManager,
@@ -91,7 +90,6 @@ class SparkEnv (
if (!isStopped) {
isStopped = true
pythonWorkers.values.foreach(_.stop())
- Option(httpFileServer).foreach(_.stop())
mapOutputTracker.stop()
shuffleManager.stop()
broadcastManager.stop()
@@ -367,17 +365,6 @@ object SparkEnv extends Logging {
val cacheManager = new CacheManager(blockManager)
- val httpFileServer =
- if (isDriver) {
- val fileServerPort = conf.getInt("spark.fileserver.port", 0)
- val server = new HttpFileServer(conf, securityManager, fileServerPort)
- server.initialize()
- conf.set("spark.fileserver.uri", server.serverUri)
- server
- } else {
- null
- }
-
val metricsSystem = if (isDriver) {
// Don't start metrics system right now for Driver.
// We need to wait for the task scheduler to give us an app ID.
@@ -422,7 +409,6 @@ object SparkEnv extends Logging {
blockTransferService,
blockManager,
securityManager,
- httpFileServer,
sparkFilesDir,
metricsSystem,
memoryManager,
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
index a560fd10cd..3d7d281b0d 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
@@ -17,6 +17,9 @@
package org.apache.spark.rpc
+import java.io.File
+import java.nio.channels.ReadableByteChannel
+
import scala.concurrent.Future
import org.apache.spark.{SecurityManager, SparkConf}
@@ -132,8 +135,51 @@ private[spark] abstract class RpcEnv(conf: SparkConf) {
* that contains [[RpcEndpointRef]]s, the deserialization codes should be wrapped by this method.
*/
def deserialize[T](deserializationAction: () => T): T
+
+ /**
+ * Return the instance of the file server used to serve files. This may be `null` if the
+ * RpcEnv is not operating in server mode.
+ */
+ def fileServer: RpcEnvFileServer
+
+ /**
+ * Open a channel to download a file from the given URI. If the URIs returned by the
+ * RpcEnvFileServer use the "spark" scheme, this method will be called by the Utils class to
+ * retrieve the files.
+ *
+ * @param uri URI with location of the file.
+ */
+ def openChannel(uri: String): ReadableByteChannel
+
}
+/**
+ * A server used by the RpcEnv to serve files to other processes owned by the application.
+ *
+ * The file server can return URIs handled by common libraries (such as "http" or "hdfs"), or
+ * it can return "spark" URIs which will be handled by `RpcEnv#fetchFile`.
+ */
+private[spark] trait RpcEnvFileServer {
+
+ /**
+ * Adds a file to be served by this RpcEnv. This is used to serve files from the driver
+ * to executors when they're stored on the driver's local file system.
+ *
+ * @param file Local file to serve.
+ * @return A URI for the location of the file.
+ */
+ def addFile(file: File): String
+
+ /**
+ * Adds a jar to be served by this RpcEnv. Similar to `addFile` but for jars added using
+ * `SparkContext.addJar`.
+ *
+ * @param file Local file to serve.
+ * @return A URI for the location of the file.
+ */
+ def addJar(file: File): String
+
+}
private[spark] case class RpcEnvConfig(
conf: SparkConf,
diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
index 059a7e10ec..94dbec593c 100644
--- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
@@ -17,6 +17,8 @@
package org.apache.spark.rpc.akka
+import java.io.File
+import java.nio.channels.ReadableByteChannel
import java.util.concurrent.ConcurrentHashMap
import scala.concurrent.Future
@@ -30,7 +32,7 @@ import akka.pattern.{ask => akkaAsk}
import akka.remote.{AssociationEvent, AssociatedEvent, DisassociatedEvent, AssociationErrorEvent}
import akka.serialization.JavaSerializer
-import org.apache.spark.{SparkException, Logging, SparkConf}
+import org.apache.spark.{HttpFileServer, Logging, SecurityManager, SparkConf, SparkException}
import org.apache.spark.rpc._
import org.apache.spark.util.{ActorLogReceive, AkkaUtils, ThreadUtils}
@@ -41,7 +43,10 @@ import org.apache.spark.util.{ActorLogReceive, AkkaUtils, ThreadUtils}
* remove Akka from the dependencies.
*/
private[spark] class AkkaRpcEnv private[akka] (
- val actorSystem: ActorSystem, conf: SparkConf, boundPort: Int)
+ val actorSystem: ActorSystem,
+ val securityManager: SecurityManager,
+ conf: SparkConf,
+ boundPort: Int)
extends RpcEnv(conf) with Logging {
private val defaultAddress: RpcAddress = {
@@ -64,6 +69,8 @@ private[spark] class AkkaRpcEnv private[akka] (
*/
private val refToEndpoint = new ConcurrentHashMap[RpcEndpointRef, RpcEndpoint]()
+ private val _fileServer = new AkkaFileServer(conf, securityManager)
+
private def registerEndpoint(endpoint: RpcEndpoint, endpointRef: RpcEndpointRef): Unit = {
endpointToRef.put(endpoint, endpointRef)
refToEndpoint.put(endpointRef, endpoint)
@@ -223,6 +230,7 @@ private[spark] class AkkaRpcEnv private[akka] (
override def shutdown(): Unit = {
actorSystem.shutdown()
+ _fileServer.shutdown()
}
override def stop(endpoint: RpcEndpointRef): Unit = {
@@ -241,6 +249,52 @@ private[spark] class AkkaRpcEnv private[akka] (
deserializationAction()
}
}
+
+ override def openChannel(uri: String): ReadableByteChannel = {
+ throw new UnsupportedOperationException(
+ "AkkaRpcEnv's files should be retrieved using an HTTP client.")
+ }
+
+ override def fileServer: RpcEnvFileServer = _fileServer
+
+}
+
+private[akka] class AkkaFileServer(
+ conf: SparkConf,
+ securityManager: SecurityManager) extends RpcEnvFileServer {
+
+ @volatile private var httpFileServer: HttpFileServer = _
+
+ override def addFile(file: File): String = {
+ getFileServer().addFile(file)
+ }
+
+ override def addJar(file: File): String = {
+ getFileServer().addJar(file)
+ }
+
+ def shutdown(): Unit = {
+ if (httpFileServer != null) {
+ httpFileServer.stop()
+ }
+ }
+
+ private def getFileServer(): HttpFileServer = {
+ if (httpFileServer == null) synchronized {
+ if (httpFileServer == null) {
+ httpFileServer = startFileServer()
+ }
+ }
+ httpFileServer
+ }
+
+ private def startFileServer(): HttpFileServer = {
+ val fileServerPort = conf.getInt("spark.fileserver.port", 0)
+ val server = new HttpFileServer(conf, securityManager, fileServerPort)
+ server.initialize()
+ server
+ }
+
}
private[spark] class AkkaRpcEnvFactory extends RpcEnvFactory {
@@ -249,7 +303,7 @@ private[spark] class AkkaRpcEnvFactory extends RpcEnvFactory {
val (actorSystem, boundPort) = AkkaUtils.createActorSystem(
config.name, config.host, config.port, config.conf, config.securityManager)
actorSystem.actorOf(Props(classOf[ErrorMonitor]), "ErrorMonitor")
- new AkkaRpcEnv(actorSystem, config.conf, boundPort)
+ new AkkaRpcEnv(actorSystem, config.securityManager, config.conf, boundPort)
}
}
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
index 3ce3598680..68701f609f 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
@@ -20,6 +20,7 @@ import java.io._
import java.lang.{Boolean => JBoolean}
import java.net.{InetSocketAddress, URI}
import java.nio.ByteBuffer
+import java.nio.channels.{Pipe, ReadableByteChannel, WritableByteChannel}
import java.util.concurrent._
import java.util.concurrent.atomic.AtomicBoolean
import javax.annotation.Nullable
@@ -45,27 +46,39 @@ private[netty] class NettyRpcEnv(
host: String,
securityManager: SecurityManager) extends RpcEnv(conf) with Logging {
- private val transportConf = SparkTransportConf.fromSparkConf(
+ private[netty] val transportConf = SparkTransportConf.fromSparkConf(
conf.clone.set("spark.rpc.io.numConnectionsPerPeer", "1"),
"rpc",
conf.getInt("spark.rpc.io.threads", 0))
private val dispatcher: Dispatcher = new Dispatcher(this)
+ private val streamManager = new NettyStreamManager(this)
+
private val transportContext = new TransportContext(transportConf,
- new NettyRpcHandler(dispatcher, this))
+ new NettyRpcHandler(dispatcher, this, streamManager))
- private val clientFactory = {
- val bootstraps: java.util.List[TransportClientBootstrap] =
- if (securityManager.isAuthenticationEnabled()) {
- java.util.Arrays.asList(new SaslClientBootstrap(transportConf, "", securityManager,
- securityManager.isSaslEncryptionEnabled()))
- } else {
- java.util.Collections.emptyList[TransportClientBootstrap]
- }
- transportContext.createClientFactory(bootstraps)
+ private def createClientBootstraps(): java.util.List[TransportClientBootstrap] = {
+ if (securityManager.isAuthenticationEnabled()) {
+ java.util.Arrays.asList(new SaslClientBootstrap(transportConf, "", securityManager,
+ securityManager.isSaslEncryptionEnabled()))
+ } else {
+ java.util.Collections.emptyList[TransportClientBootstrap]
+ }
}
+ private val clientFactory = transportContext.createClientFactory(createClientBootstraps())
+
+ /**
+ * A separate client factory for file downloads. This avoids using the same RPC handler as
+ * the main RPC context, so that events caused by these clients are kept isolated from the
+ * main RPC traffic.
+ *
+ * It also allows for different configuration of certain properties, such as the number of
+ * connections per peer.
+ */
+ @volatile private var fileDownloadFactory: TransportClientFactory = _
+
val timeoutScheduler = ThreadUtils.newDaemonSingleThreadScheduledExecutor("netty-rpc-env-timeout")
// Because TransportClientFactory.createClient is blocking, we need to run it in this thread pool
@@ -292,6 +305,9 @@ private[netty] class NettyRpcEnv(
if (clientConnectionExecutor != null) {
clientConnectionExecutor.shutdownNow()
}
+ if (fileDownloadFactory != null) {
+ fileDownloadFactory.close()
+ }
}
override def deserialize[T](deserializationAction: () => T): T = {
@@ -300,6 +316,96 @@ private[netty] class NettyRpcEnv(
}
}
+ override def fileServer: RpcEnvFileServer = streamManager
+
+ override def openChannel(uri: String): ReadableByteChannel = {
+ val parsedUri = new URI(uri)
+ require(parsedUri.getHost() != null, "Host name must be defined.")
+ require(parsedUri.getPort() > 0, "Port must be defined.")
+ require(parsedUri.getPath() != null && parsedUri.getPath().nonEmpty, "Path must be defined.")
+
+ val pipe = Pipe.open()
+ val source = new FileDownloadChannel(pipe.source())
+ try {
+ val client = downloadClient(parsedUri.getHost(), parsedUri.getPort())
+ val callback = new FileDownloadCallback(pipe.sink(), source, client)
+ client.stream(parsedUri.getPath(), callback)
+ } catch {
+ case e: Exception =>
+ pipe.sink().close()
+ source.close()
+ throw e
+ }
+
+ source
+ }
+
+ private def downloadClient(host: String, port: Int): TransportClient = {
+ if (fileDownloadFactory == null) synchronized {
+ if (fileDownloadFactory == null) {
+ val module = "files"
+ val prefix = "spark.rpc.io."
+ val clone = conf.clone()
+
+ // Copy any RPC configuration that is not overridden in the spark.files namespace.
+ conf.getAll.foreach { case (key, value) =>
+ if (key.startsWith(prefix)) {
+ val opt = key.substring(prefix.length())
+ clone.setIfMissing(s"spark.$module.io.$opt", value)
+ }
+ }
+
+ val ioThreads = clone.getInt("spark.files.io.threads", 1)
+ val downloadConf = SparkTransportConf.fromSparkConf(clone, module, ioThreads)
+ val downloadContext = new TransportContext(downloadConf, new NoOpRpcHandler(), true)
+ fileDownloadFactory = downloadContext.createClientFactory(createClientBootstraps())
+ }
+ }
+ fileDownloadFactory.createClient(host, port)
+ }
+
+ private class FileDownloadChannel(source: ReadableByteChannel) extends ReadableByteChannel {
+
+ @volatile private var error: Throwable = _
+
+ def setError(e: Throwable): Unit = error = e
+
+ override def read(dst: ByteBuffer): Int = {
+ if (error != null) {
+ throw error
+ }
+ source.read(dst)
+ }
+
+ override def close(): Unit = source.close()
+
+ override def isOpen(): Boolean = source.isOpen()
+
+ }
+
+ private class FileDownloadCallback(
+ sink: WritableByteChannel,
+ source: FileDownloadChannel,
+ client: TransportClient) extends StreamCallback {
+
+ override def onData(streamId: String, buf: ByteBuffer): Unit = {
+ while (buf.remaining() > 0) {
+ sink.write(buf)
+ }
+ }
+
+ override def onComplete(streamId: String): Unit = {
+ sink.close()
+ }
+
+ override def onFailure(streamId: String, cause: Throwable): Unit = {
+ logError(s"Error downloading stream $streamId.", cause)
+ source.setError(cause)
+ sink.close()
+ }
+
+ }
+
}
private[netty] object NettyRpcEnv extends Logging {
@@ -420,7 +526,7 @@ private[netty] class NettyRpcEndpointRef(
override def toString: String = s"NettyRpcEndpointRef(${_address})"
- def toURI: URI = new URI(s"spark://${_address}")
+ def toURI: URI = new URI(_address.toString)
final override def equals(that: Any): Boolean = that match {
case other: NettyRpcEndpointRef => _address == other._address
@@ -471,7 +577,9 @@ private[netty] case class RpcFailure(e: Throwable)
* with different `RpcAddress` information).
*/
private[netty] class NettyRpcHandler(
- dispatcher: Dispatcher, nettyEnv: NettyRpcEnv) extends RpcHandler with Logging {
+ dispatcher: Dispatcher,
+ nettyEnv: NettyRpcEnv,
+ streamManager: StreamManager) extends RpcHandler with Logging {
// TODO: Can we add connection callback (channel registered) to the underlying framework?
// A variable to track whether we should dispatch the RemoteProcessConnected message.
@@ -498,7 +606,7 @@ private[netty] class NettyRpcHandler(
dispatcher.postRemoteMessage(messageToDispatch, callback)
}
- override def getStreamManager: StreamManager = new OneForOneStreamManager
+ override def getStreamManager: StreamManager = streamManager
override def exceptionCaught(cause: Throwable, client: TransportClient): Unit = {
val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress]
@@ -516,8 +624,8 @@ private[netty] class NettyRpcHandler(
override def connectionTerminated(client: TransportClient): Unit = {
val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress]
if (addr != null) {
- val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
clients.remove(client)
+ val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
nettyEnv.removeOutbox(clientAddr)
dispatcher.postToAll(RemoteProcessDisconnected(clientAddr))
} else {
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala
new file mode 100644
index 0000000000..eb1d2604fb
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.rpc.netty
+
+import java.io.File
+import java.util.concurrent.ConcurrentHashMap
+
+import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
+import org.apache.spark.network.server.StreamManager
+import org.apache.spark.rpc.RpcEnvFileServer
+
+/**
+ * StreamManager implementation for serving files from a NettyRpcEnv.
+ */
+private[netty] class NettyStreamManager(rpcEnv: NettyRpcEnv)
+ extends StreamManager with RpcEnvFileServer {
+
+ private val files = new ConcurrentHashMap[String, File]()
+ private val jars = new ConcurrentHashMap[String, File]()
+
+ override def getChunk(streamId: Long, chunkIndex: Int): ManagedBuffer = {
+ throw new UnsupportedOperationException()
+ }
+
+ override def openStream(streamId: String): ManagedBuffer = {
+ val Array(ftype, fname) = streamId.stripPrefix("/").split("/", 2)
+ val file = ftype match {
+ case "files" => files.get(fname)
+ case "jars" => jars.get(fname)
+ case _ => throw new IllegalArgumentException(s"Invalid file type: $ftype")
+ }
+
+ require(file != null, s"File not found: $streamId")
+ new FileSegmentManagedBuffer(rpcEnv.transportConf, file, 0, file.length())
+ }
+
+ override def addFile(file: File): String = {
+ require(files.putIfAbsent(file.getName(), file) == null,
+ s"File ${file.getName()} already registered.")
+ s"${rpcEnv.address.toSparkURL}/files/${file.getName()}"
+ }
+
+ override def addJar(file: File): String = {
+ require(jars.putIfAbsent(file.getName(), file) == null,
+ s"JAR ${file.getName()} already registered.")
+ s"${rpcEnv.address.toSparkURL}/jars/${file.getName()}"
+ }
+
+}
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 1b3acb8ef7..af632349c9 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -21,6 +21,7 @@ import java.io._
import java.lang.management.ManagementFactory
import java.net._
import java.nio.ByteBuffer
+import java.nio.channels.Channels
import java.util.concurrent._
import java.util.{Locale, Properties, Random, UUID}
import javax.net.ssl.HttpsURLConnection
@@ -535,6 +536,14 @@ private[spark] object Utils extends Logging {
val uri = new URI(url)
val fileOverwrite = conf.getBoolean("spark.files.overwrite", defaultValue = false)
Option(uri.getScheme).getOrElse("file") match {
+ case "spark" =>
+ if (SparkEnv.get == null) {
+ throw new IllegalStateException(
+ "Cannot retrieve files with 'spark' scheme without an active SparkEnv.")
+ }
+ val source = SparkEnv.get.rpcEnv.openChannel(url)
+ val is = Channels.newInputStream(source)
+ downloadFile(url, is, targetFile, fileOverwrite)
case "http" | "https" | "ftp" =>
var uc: URLConnection = null
if (securityMgr.isAuthenticationEnabled()) {
diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
index 2f55006420..2b664c6313 100644
--- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
@@ -17,7 +17,9 @@
package org.apache.spark.rpc
-import java.io.NotSerializableException
+import java.io.{File, NotSerializableException}
+import java.util.UUID
+import java.nio.charset.StandardCharsets.UTF_8
import java.util.concurrent.{TimeUnit, CountDownLatch, TimeoutException}
import scala.collection.mutable
@@ -25,10 +27,14 @@ import scala.concurrent.Await
import scala.concurrent.duration._
import scala.language.postfixOps
+import com.google.common.io.Files
+import org.mockito.Mockito.{mock, when}
import org.scalatest.BeforeAndAfterAll
import org.scalatest.concurrent.Eventually._
-import org.apache.spark.{SparkConf, SparkException, SparkFunSuite}
+import org.apache.spark.{SecurityManager, SparkConf, SparkEnv, SparkException, SparkFunSuite}
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.util.Utils
/**
* Common tests for an RpcEnv implementation.
@@ -40,12 +46,17 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
override def beforeAll(): Unit = {
val conf = new SparkConf()
env = createRpcEnv(conf, "local", 0)
+
+ val sparkEnv = mock(classOf[SparkEnv])
+ when(sparkEnv.rpcEnv).thenReturn(env)
+ SparkEnv.set(sparkEnv)
}
override def afterAll(): Unit = {
if (env != null) {
env.shutdown()
}
+ SparkEnv.set(null)
}
def createRpcEnv(conf: SparkConf, name: String, port: Int, clientMode: Boolean = false): RpcEnv
@@ -713,6 +724,30 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
assert(shortTimeout.timeoutProp.r.findAllIn(reply4).length === 1)
}
+ test("file server") {
+ val conf = new SparkConf()
+ val tempDir = Utils.createTempDir()
+ val file = new File(tempDir, "file")
+ Files.write(UUID.randomUUID().toString(), file, UTF_8)
+ val jar = new File(tempDir, "jar")
+ Files.write(UUID.randomUUID().toString(), jar, UTF_8)
+
+ val fileUri = env.fileServer.addFile(file)
+ val jarUri = env.fileServer.addJar(jar)
+
+ val destDir = Utils.createTempDir()
+ val destFile = new File(destDir, file.getName())
+ val destJar = new File(destDir, jar.getName())
+
+ val sm = new SecurityManager(conf)
+ val hc = SparkHadoopUtil.get.conf
+ Utils.fetchFile(fileUri, destDir, conf, sm, hc, 0L, false)
+ Utils.fetchFile(jarUri, destDir, conf, sm, hc, 0L, false)
+
+ assert(Files.equal(file, destFile))
+ assert(Files.equal(jar, destJar))
+ }
+
}
class UnserializableClass
diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
index f9d8e80c98..ccca795683 100644
--- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
@@ -25,17 +25,19 @@ import org.mockito.Matchers._
import org.apache.spark.SparkFunSuite
import org.apache.spark.network.client.{TransportResponseHandler, TransportClient}
+import org.apache.spark.network.server.StreamManager
import org.apache.spark.rpc._
class NettyRpcHandlerSuite extends SparkFunSuite {
val env = mock(classOf[NettyRpcEnv])
- when(env.deserialize(any(classOf[TransportClient]), any(classOf[Array[Byte]]))(any())).
- thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null, false))
+ val sm = mock(classOf[StreamManager])
+ when(env.deserialize(any(classOf[TransportClient]), any(classOf[Array[Byte]]))(any()))
+ .thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null, false))
test("receive") {
val dispatcher = mock(classOf[Dispatcher])
- val nettyRpcHandler = new NettyRpcHandler(dispatcher, env)
+ val nettyRpcHandler = new NettyRpcHandler(dispatcher, env, sm)
val channel = mock(classOf[Channel])
val client = new TransportClient(channel, mock(classOf[TransportResponseHandler]))
@@ -47,7 +49,7 @@ class NettyRpcHandlerSuite extends SparkFunSuite {
test("connectionTerminated") {
val dispatcher = mock(classOf[Dispatcher])
- val nettyRpcHandler = new NettyRpcHandler(dispatcher, env)
+ val nettyRpcHandler = new NettyRpcHandler(dispatcher, env, sm)
val channel = mock(classOf[Channel])
val client = new TransportClient(channel, mock(classOf[TransportResponseHandler]))