Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkContext.scala | 8
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkEnv.scala | 14
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala | 46
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala | 60
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala | 138
-rw-r--r--  core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala | 63
-rw-r--r--  core/src/main/scala/org/apache/spark/util/Utils.scala | 9
-rw-r--r--  core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala | 39
-rw-r--r--  core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala | 10
9 files changed, 345 insertions(+), 42 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index af4456c05b..b153a7b08e 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1379,7 +1379,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
}
val key = if (!isLocal && scheme == "file") {
- env.httpFileServer.addFile(new File(uri.getPath))
+ env.rpcEnv.fileServer.addFile(new File(uri.getPath))
} else {
schemeCorrectedPath
}
@@ -1630,7 +1630,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
var key = ""
if (path.contains("\\")) {
// For local paths with backslashes on Windows, URI throws an exception
- key = env.httpFileServer.addJar(new File(path))
+ key = env.rpcEnv.fileServer.addJar(new File(path))
} else {
val uri = new URI(path)
key = uri.getScheme match {
@@ -1644,7 +1644,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
// of the AM to make it show up in the current working directory.
val fileName = new Path(uri.getPath).getName()
try {
- env.httpFileServer.addJar(new File(fileName))
+ env.rpcEnv.fileServer.addJar(new File(fileName))
} catch {
case e: Exception =>
// For now just log an error but allow it to go through so the Spark examples work.
@@ -1655,7 +1655,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
}
} else {
try {
- env.httpFileServer.addJar(new File(uri.getPath))
+ env.rpcEnv.fileServer.addJar(new File(uri.getPath))
} catch {
case exc: FileNotFoundException =>
logError(s"Jar not found at $path")
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 88df27f733..84230e32a4 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -66,7 +66,6 @@ class SparkEnv (
val blockTransferService: BlockTransferService,
val blockManager: BlockManager,
val securityManager: SecurityManager,
- val httpFileServer: HttpFileServer,
val sparkFilesDir: String,
val metricsSystem: MetricsSystem,
val memoryManager: MemoryManager,
@@ -91,7 +90,6 @@ class SparkEnv (
if (!isStopped) {
isStopped = true
pythonWorkers.values.foreach(_.stop())
- Option(httpFileServer).foreach(_.stop())
mapOutputTracker.stop()
shuffleManager.stop()
broadcastManager.stop()
@@ -367,17 +365,6 @@ object SparkEnv extends Logging {
val cacheManager = new CacheManager(blockManager)
- val httpFileServer =
- if (isDriver) {
- val fileServerPort = conf.getInt("spark.fileserver.port", 0)
- val server = new HttpFileServer(conf, securityManager, fileServerPort)
- server.initialize()
- conf.set("spark.fileserver.uri", server.serverUri)
- server
- } else {
- null
- }
-
val metricsSystem = if (isDriver) {
// Don't start metrics system right now for Driver.
// We need to wait for the task scheduler to give us an app ID.
@@ -422,7 +409,6 @@ object SparkEnv extends Logging {
blockTransferService,
blockManager,
securityManager,
- httpFileServer,
sparkFilesDir,
metricsSystem,
memoryManager,
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
index a560fd10cd..3d7d281b0d 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
@@ -17,6 +17,9 @@
package org.apache.spark.rpc
+import java.io.File
+import java.nio.channels.ReadableByteChannel
+
import scala.concurrent.Future
import org.apache.spark.{SecurityManager, SparkConf}
@@ -132,8 +135,51 @@ private[spark] abstract class RpcEnv(conf: SparkConf) {
 * that contains [[RpcEndpointRef]]s, the deserialization code should be wrapped by this method.
*/
def deserialize[T](deserializationAction: () => T): T
+
+ /**
+ * Return the instance of the file server used to serve files. This may be `null` if the
+ * RpcEnv is not operating in server mode.
+ */
+ def fileServer: RpcEnvFileServer
+
+ /**
+ * Open a channel to download a file from the given URI. If the URIs returned by the
+ * RpcEnvFileServer use the "spark" scheme, this method will be called by the Utils class to
+ * retrieve the files.
+ *
+ * @param uri URI with location of the file.
+ */
+ def openChannel(uri: String): ReadableByteChannel
+
}
+/**
+ * A server used by the RpcEnv to serve files to other processes owned by the application.
+ *
+ * The file server can return URIs handled by common libraries (such as "http" or "hdfs"), or
+ * it can return "spark" URIs which will be handled by `RpcEnv#openChannel`.
+ */
+private[spark] trait RpcEnvFileServer {
+
+ /**
+ * Adds a file to be served by this RpcEnv. This is used to serve files from the driver
+ * to executors when they're stored on the driver's local file system.
+ *
+ * @param file Local file to serve.
+ * @return A URI for the location of the file.
+ */
+ def addFile(file: File): String
+
+ /**
+ * Adds a jar to be served by this RpcEnv. Similar to `addFile` but for jars added using
+ * `SparkContext.addJar`.
+ *
+ * @param file Local file to serve.
+ * @return A URI for the location of the file.
+ */
+ def addJar(file: File): String
+
+}
private[spark] case class RpcEnvConfig(
conf: SparkConf,
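For context, the new RpcEnvFileServer trait only requires mapping a local File to a URI that Utils.fetchFile can later resolve; the URIs do not have to use the "spark" scheme. A deliberately trivial, illustrative implementation (not part of this change) could look like:

import java.io.File
import java.util.concurrent.ConcurrentHashMap

// Illustrative only: a file server that hands out plain "file:" URIs. This is enough
// for local-mode experiments because Utils.fetchFile can copy local files directly;
// a real implementation (NettyStreamManager below) serves the bytes over the network.
private class LocalOnlyFileServer extends RpcEnvFileServer {
  private val registered = new ConcurrentHashMap[String, File]()

  override def addFile(file: File): String = {
    registered.put(file.getName(), file)
    file.toURI().toString()            // e.g. "file:/tmp/app.conf"
  }

  override def addJar(file: File): String = addFile(file)
}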
diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
index 059a7e10ec..94dbec593c 100644
--- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
@@ -17,6 +17,8 @@
package org.apache.spark.rpc.akka
+import java.io.File
+import java.nio.channels.ReadableByteChannel
import java.util.concurrent.ConcurrentHashMap
import scala.concurrent.Future
@@ -30,7 +32,7 @@ import akka.pattern.{ask => akkaAsk}
import akka.remote.{AssociationEvent, AssociatedEvent, DisassociatedEvent, AssociationErrorEvent}
import akka.serialization.JavaSerializer
-import org.apache.spark.{SparkException, Logging, SparkConf}
+import org.apache.spark.{HttpFileServer, Logging, SecurityManager, SparkConf, SparkException}
import org.apache.spark.rpc._
import org.apache.spark.util.{ActorLogReceive, AkkaUtils, ThreadUtils}
@@ -41,7 +43,10 @@ import org.apache.spark.util.{ActorLogReceive, AkkaUtils, ThreadUtils}
* remove Akka from the dependencies.
*/
private[spark] class AkkaRpcEnv private[akka] (
- val actorSystem: ActorSystem, conf: SparkConf, boundPort: Int)
+ val actorSystem: ActorSystem,
+ val securityManager: SecurityManager,
+ conf: SparkConf,
+ boundPort: Int)
extends RpcEnv(conf) with Logging {
private val defaultAddress: RpcAddress = {
@@ -64,6 +69,8 @@ private[spark] class AkkaRpcEnv private[akka] (
*/
private val refToEndpoint = new ConcurrentHashMap[RpcEndpointRef, RpcEndpoint]()
+ private val _fileServer = new AkkaFileServer(conf, securityManager)
+
private def registerEndpoint(endpoint: RpcEndpoint, endpointRef: RpcEndpointRef): Unit = {
endpointToRef.put(endpoint, endpointRef)
refToEndpoint.put(endpointRef, endpoint)
@@ -223,6 +230,7 @@ private[spark] class AkkaRpcEnv private[akka] (
override def shutdown(): Unit = {
actorSystem.shutdown()
+ _fileServer.shutdown()
}
override def stop(endpoint: RpcEndpointRef): Unit = {
@@ -241,6 +249,52 @@ private[spark] class AkkaRpcEnv private[akka] (
deserializationAction()
}
}
+
+ override def openChannel(uri: String): ReadableByteChannel = {
+ throw new UnsupportedOperationException(
+ "AkkaRpcEnv's files should be retrieved using an HTTP client.")
+ }
+
+ override def fileServer: RpcEnvFileServer = _fileServer
+
+}
+
+private[akka] class AkkaFileServer(
+ conf: SparkConf,
+ securityManager: SecurityManager) extends RpcEnvFileServer {
+
+ @volatile private var httpFileServer: HttpFileServer = _
+
+ override def addFile(file: File): String = {
+ getFileServer().addFile(file)
+ }
+
+ override def addJar(file: File): String = {
+ getFileServer().addJar(file)
+ }
+
+ def shutdown(): Unit = {
+ if (httpFileServer != null) {
+ httpFileServer.stop()
+ }
+ }
+
+ private def getFileServer(): HttpFileServer = {
+ if (httpFileServer == null) synchronized {
+ if (httpFileServer == null) {
+ httpFileServer = startFileServer()
+ }
+ }
+ httpFileServer
+ }
+
+ private def startFileServer(): HttpFileServer = {
+ val fileServerPort = conf.getInt("spark.fileserver.port", 0)
+ val server = new HttpFileServer(conf, securityManager, fileServerPort)
+ server.initialize()
+ server
+ }
+
}
private[spark] class AkkaRpcEnvFactory extends RpcEnvFactory {
@@ -249,7 +303,7 @@ private[spark] class AkkaRpcEnvFactory extends RpcEnvFactory {
val (actorSystem, boundPort) = AkkaUtils.createActorSystem(
config.name, config.host, config.port, config.conf, config.securityManager)
actorSystem.actorOf(Props(classOf[ErrorMonitor]), "ErrorMonitor")
- new AkkaRpcEnv(actorSystem, config.conf, boundPort)
+ new AkkaRpcEnv(actorSystem, config.securityManager, config.conf, boundPort)
}
}
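Note that AkkaFileServer above starts the embedded HttpFileServer lazily with double-checked locking, so executors (which never call addFile/addJar) do not open an extra listening port. The same idiom in isolation, as a generic sketch with illustrative names:

// Generic sketch of the lazy-initialization idiom used by getFileServer() above.
// The field must be @volatile so that threads taking the fast path (the unsynchronized
// null check) observe a fully constructed instance.
class LazyRef[T <: AnyRef](create: () => T) {
  @volatile private var value: T = _

  def get(): T = {
    if (value == null) synchronized {
      if (value == null) {
        value = create()
      }
    }
    value
  }
}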
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
index 3ce3598680..68701f609f 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
@@ -20,6 +20,7 @@ import java.io._
import java.lang.{Boolean => JBoolean}
import java.net.{InetSocketAddress, URI}
import java.nio.ByteBuffer
+import java.nio.channels.{Pipe, ReadableByteChannel, WritableByteChannel}
import java.util.concurrent._
import java.util.concurrent.atomic.AtomicBoolean
import javax.annotation.Nullable
@@ -45,27 +46,39 @@ private[netty] class NettyRpcEnv(
host: String,
securityManager: SecurityManager) extends RpcEnv(conf) with Logging {
- private val transportConf = SparkTransportConf.fromSparkConf(
+ private[netty] val transportConf = SparkTransportConf.fromSparkConf(
conf.clone.set("spark.rpc.io.numConnectionsPerPeer", "1"),
"rpc",
conf.getInt("spark.rpc.io.threads", 0))
private val dispatcher: Dispatcher = new Dispatcher(this)
+ private val streamManager = new NettyStreamManager(this)
+
private val transportContext = new TransportContext(transportConf,
- new NettyRpcHandler(dispatcher, this))
+ new NettyRpcHandler(dispatcher, this, streamManager))
- private val clientFactory = {
- val bootstraps: java.util.List[TransportClientBootstrap] =
- if (securityManager.isAuthenticationEnabled()) {
- java.util.Arrays.asList(new SaslClientBootstrap(transportConf, "", securityManager,
- securityManager.isSaslEncryptionEnabled()))
- } else {
- java.util.Collections.emptyList[TransportClientBootstrap]
- }
- transportContext.createClientFactory(bootstraps)
+ private def createClientBootstraps(): java.util.List[TransportClientBootstrap] = {
+ if (securityManager.isAuthenticationEnabled()) {
+ java.util.Arrays.asList(new SaslClientBootstrap(transportConf, "", securityManager,
+ securityManager.isSaslEncryptionEnabled()))
+ } else {
+ java.util.Collections.emptyList[TransportClientBootstrap]
+ }
}
+ private val clientFactory = transportContext.createClientFactory(createClientBootstraps())
+
+ /**
+ * A separate client factory for file downloads. This avoids using the same RPC handler as
+ * the main RPC context, so that events caused by these clients are kept isolated from the
+ * main RPC traffic.
+ *
+ * It also allows for different configuration of certain properties, such as the number of
+ * connections per peer.
+ */
+ @volatile private var fileDownloadFactory: TransportClientFactory = _
+
val timeoutScheduler = ThreadUtils.newDaemonSingleThreadScheduledExecutor("netty-rpc-env-timeout")
// Because TransportClientFactory.createClient is blocking, we need to run it in this thread pool
@@ -292,6 +305,9 @@ private[netty] class NettyRpcEnv(
if (clientConnectionExecutor != null) {
clientConnectionExecutor.shutdownNow()
}
+ if (fileDownloadFactory != null) {
+ fileDownloadFactory.close()
+ }
}
override def deserialize[T](deserializationAction: () => T): T = {
@@ -300,6 +316,96 @@ private[netty] class NettyRpcEnv(
}
}
+ override def fileServer: RpcEnvFileServer = streamManager
+
+ override def openChannel(uri: String): ReadableByteChannel = {
+ val parsedUri = new URI(uri)
+ require(parsedUri.getHost() != null, "Host name must be defined.")
+ require(parsedUri.getPort() > 0, "Port must be defined.")
+ require(parsedUri.getPath() != null && parsedUri.getPath().nonEmpty, "Path must be defined.")
+
+ val pipe = Pipe.open()
+ val source = new FileDownloadChannel(pipe.source())
+ try {
+ val client = downloadClient(parsedUri.getHost(), parsedUri.getPort())
+ val callback = new FileDownloadCallback(pipe.sink(), source, client)
+ client.stream(parsedUri.getPath(), callback)
+ } catch {
+ case e: Exception =>
+ pipe.sink().close()
+ source.close()
+ throw e
+ }
+
+ source
+ }
+
+ private def downloadClient(host: String, port: Int): TransportClient = {
+ if (fileDownloadFactory == null) synchronized {
+ if (fileDownloadFactory == null) {
+ val module = "files"
+ val prefix = "spark.rpc.io."
+ val clone = conf.clone()
+
+ // Copy any RPC configuration that is not overridden in the spark.files namespace.
+ conf.getAll.foreach { case (key, value) =>
+ if (key.startsWith(prefix)) {
+ val opt = key.substring(prefix.length())
+ clone.setIfMissing(s"spark.$module.io.$opt", value)
+ }
+ }
+
+ val ioThreads = clone.getInt("spark.files.io.threads", 1)
+ val downloadConf = SparkTransportConf.fromSparkConf(clone, module, ioThreads)
+ val downloadContext = new TransportContext(downloadConf, new NoOpRpcHandler(), true)
+ fileDownloadFactory = downloadContext.createClientFactory(createClientBootstraps())
+ }
+ }
+ fileDownloadFactory.createClient(host, port)
+ }
+
+ private class FileDownloadChannel(source: ReadableByteChannel) extends ReadableByteChannel {
+
+ @volatile private var error: Throwable = _
+
+ def setError(e: Throwable): Unit = error = e
+
+ override def read(dst: ByteBuffer): Int = {
+ if (error != null) {
+ throw error
+ }
+ source.read(dst)
+ }
+
+ override def close(): Unit = source.close()
+
+ override def isOpen(): Boolean = source.isOpen()
+
+ }
+
+ private class FileDownloadCallback(
+ sink: WritableByteChannel,
+ source: FileDownloadChannel,
+ client: TransportClient) extends StreamCallback {
+
+ override def onData(streamId: String, buf: ByteBuffer): Unit = {
+ while (buf.remaining() > 0) {
+ sink.write(buf)
+ }
+ }
+
+ override def onComplete(streamId: String): Unit = {
+ sink.close()
+ }
+
+ override def onFailure(streamId: String, cause: Throwable): Unit = {
+ logError(s"Error downloading stream $streamId.", cause)
+ source.setError(cause)
+ sink.close()
+ }
+
+ }
+
}
private[netty] object NettyRpcEnv extends Logging {
@@ -420,7 +526,7 @@ private[netty] class NettyRpcEndpointRef(
override def toString: String = s"NettyRpcEndpointRef(${_address})"
- def toURI: URI = new URI(s"spark://${_address}")
+ def toURI: URI = new URI(_address.toString)
final override def equals(that: Any): Boolean = that match {
case other: NettyRpcEndpointRef => _address == other._address
@@ -471,7 +577,9 @@ private[netty] case class RpcFailure(e: Throwable)
* with different `RpcAddress` information).
*/
private[netty] class NettyRpcHandler(
- dispatcher: Dispatcher, nettyEnv: NettyRpcEnv) extends RpcHandler with Logging {
+ dispatcher: Dispatcher,
+ nettyEnv: NettyRpcEnv,
+ streamManager: StreamManager) extends RpcHandler with Logging {
// TODO: Can we add connection callback (channel registered) to the underlying framework?
// A variable to track whether we should dispatch the RemoteProcessConnected message.
@@ -498,7 +606,7 @@ private[netty] class NettyRpcHandler(
dispatcher.postRemoteMessage(messageToDispatch, callback)
}
- override def getStreamManager: StreamManager = new OneForOneStreamManager
+ override def getStreamManager: StreamManager = streamManager
override def exceptionCaught(cause: Throwable, client: TransportClient): Unit = {
val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress]
@@ -516,8 +624,8 @@ private[netty] class NettyRpcHandler(
override def connectionTerminated(client: TransportClient): Unit = {
val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress]
if (addr != null) {
- val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
clients.remove(client)
+ val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
nettyEnv.removeOutbox(clientAddr)
dispatcher.postToAll(RemoteProcessDisconnected(clientAddr))
} else {
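openChannel above bridges the asynchronous netty StreamCallback and the synchronous caller with a java.nio Pipe: the callback writes into the sink while the caller reads from the source. A hedged sketch of how a caller might drain such a channel to a local file (the helper name is made up and error handling is minimal):

import java.io.{File, FileOutputStream}
import java.nio.ByteBuffer
import java.nio.channels.ReadableByteChannel

import org.apache.spark.rpc.RpcEnv

// Sketch: copy a "spark://host:port/files/<name>" stream to a local file by draining
// the ReadableByteChannel returned by RpcEnv.openChannel. Reads block on the pipe
// until the StreamCallback has written data (or closed the sink on completion/failure).
def downloadTo(rpcEnv: RpcEnv, uri: String, dest: File): Unit = {
  val in = rpcEnv.openChannel(uri)
  val out = new FileOutputStream(dest).getChannel()
  try {
    val buf = ByteBuffer.allocate(64 * 1024)
    while (in.read(buf) >= 0) {        // -1 signals end of stream
      buf.flip()
      while (buf.hasRemaining()) out.write(buf)
      buf.clear()
    }
  } finally {
    in.close()
    out.close()
  }
}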
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala
new file mode 100644
index 0000000000..eb1d2604fb
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.rpc.netty
+
+import java.io.File
+import java.util.concurrent.ConcurrentHashMap
+
+import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
+import org.apache.spark.network.server.StreamManager
+import org.apache.spark.rpc.RpcEnvFileServer
+
+/**
+ * StreamManager implementation for serving files from a NettyRpcEnv.
+ */
+private[netty] class NettyStreamManager(rpcEnv: NettyRpcEnv)
+ extends StreamManager with RpcEnvFileServer {
+
+ private val files = new ConcurrentHashMap[String, File]()
+ private val jars = new ConcurrentHashMap[String, File]()
+
+ override def getChunk(streamId: Long, chunkIndex: Int): ManagedBuffer = {
+ throw new UnsupportedOperationException()
+ }
+
+ override def openStream(streamId: String): ManagedBuffer = {
+ val Array(ftype, fname) = streamId.stripPrefix("/").split("/", 2)
+ val file = ftype match {
+ case "files" => files.get(fname)
+ case "jars" => jars.get(fname)
+ case _ => throw new IllegalArgumentException(s"Invalid file type: $ftype")
+ }
+
+ require(file != null, s"File not found: $streamId")
+ new FileSegmentManagedBuffer(rpcEnv.transportConf, file, 0, file.length())
+ }
+
+ override def addFile(file: File): String = {
+ require(files.putIfAbsent(file.getName(), file) == null,
+ s"File ${file.getName()} already registered.")
+ s"${rpcEnv.address.toSparkURL}/files/${file.getName()}"
+ }
+
+ override def addJar(file: File): String = {
+ require(jars.putIfAbsent(file.getName(), file) == null,
+ s"JAR ${file.getName()} already registered.")
+ s"${rpcEnv.address.toSparkURL}/jars/${file.getName()}"
+ }
+
+}
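The stream id that openStream receives is simply the path component of the URI that addFile/addJar produce, so the mapping round-trips directly. A small worked example of the parsing above; the host, port and file name are made up:

// Illustration of the URI <-> streamId mapping implemented above.
// If the driver's RpcEnv listens on 192.0.2.10:7077 (made-up address), then
//   addFile(new File("/tmp/app.conf"))  returns  "spark://192.0.2.10:7077/files/app.conf"
// and an executor fetching that URI makes the transport ask this StreamManager for the
// stream id "/files/app.conf", which the pattern match decodes as follows:
val Array(ftype, fname) = "/files/app.conf".stripPrefix("/").split("/", 2)
assert(ftype == "files" && fname == "app.conf")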
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 1b3acb8ef7..af632349c9 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -21,6 +21,7 @@ import java.io._
import java.lang.management.ManagementFactory
import java.net._
import java.nio.ByteBuffer
+import java.nio.channels.Channels
import java.util.concurrent._
import java.util.{Locale, Properties, Random, UUID}
import javax.net.ssl.HttpsURLConnection
@@ -535,6 +536,14 @@ private[spark] object Utils extends Logging {
val uri = new URI(url)
val fileOverwrite = conf.getBoolean("spark.files.overwrite", defaultValue = false)
Option(uri.getScheme).getOrElse("file") match {
+ case "spark" =>
+ if (SparkEnv.get == null) {
+ throw new IllegalStateException(
+ "Cannot retrieve files with 'spark' scheme without an active SparkEnv.")
+ }
+ val source = SparkEnv.get.rpcEnv.openChannel(url)
+ val is = Channels.newInputStream(source)
+ downloadFile(url, is, targetFile, fileOverwrite)
case "http" | "https" | "ftp" =>
var uc: URLConnection = null
if (securityMgr.isAuthenticationEnabled()) {
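With this branch in place, an executor can fetch driver-served files like any other URI; the only extra requirement is an initialized SparkEnv, since the "spark" scheme is resolved through its RpcEnv. A usage sketch mirroring the fetchFile call in the RpcEnvSuite test further down (host, port and paths are made up):

import java.io.File

import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.util.Utils

// Sketch: fetch a file that the driver registered via rpcEnv.fileServer.addFile().
// SparkEnv.get must be set, otherwise the new case above throws IllegalStateException.
val conf = new SparkConf()
val securityMgr = new SecurityManager(conf)
val hadoopConf = SparkHadoopUtil.get.conf
Utils.fetchFile(
  "spark://192.0.2.10:7077/files/app.conf",   // URI handed out by the driver (made up)
  new File("/tmp/executor-work-dir"),         // destination directory
  conf, securityMgr, hadoopConf,
  0L,                                         // timestamp
  false)                                      // useCache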
diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
index 2f55006420..2b664c6313 100644
--- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
@@ -17,7 +17,9 @@
package org.apache.spark.rpc
-import java.io.NotSerializableException
+import java.io.{File, NotSerializableException}
+import java.util.UUID
+import java.nio.charset.StandardCharsets.UTF_8
import java.util.concurrent.{TimeUnit, CountDownLatch, TimeoutException}
import scala.collection.mutable
@@ -25,10 +27,14 @@ import scala.concurrent.Await
import scala.concurrent.duration._
import scala.language.postfixOps
+import com.google.common.io.Files
+import org.mockito.Mockito.{mock, when}
import org.scalatest.BeforeAndAfterAll
import org.scalatest.concurrent.Eventually._
-import org.apache.spark.{SparkConf, SparkException, SparkFunSuite}
+import org.apache.spark.{SecurityManager, SparkConf, SparkEnv, SparkException, SparkFunSuite}
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.util.Utils
/**
* Common tests for an RpcEnv implementation.
@@ -40,12 +46,17 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
override def beforeAll(): Unit = {
val conf = new SparkConf()
env = createRpcEnv(conf, "local", 0)
+
+ val sparkEnv = mock(classOf[SparkEnv])
+ when(sparkEnv.rpcEnv).thenReturn(env)
+ SparkEnv.set(sparkEnv)
}
override def afterAll(): Unit = {
if (env != null) {
env.shutdown()
}
+ SparkEnv.set(null)
}
def createRpcEnv(conf: SparkConf, name: String, port: Int, clientMode: Boolean = false): RpcEnv
@@ -713,6 +724,30 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
assert(shortTimeout.timeoutProp.r.findAllIn(reply4).length === 1)
}
+ test("file server") {
+ val conf = new SparkConf()
+ val tempDir = Utils.createTempDir()
+ val file = new File(tempDir, "file")
+ Files.write(UUID.randomUUID().toString(), file, UTF_8)
+ val jar = new File(tempDir, "jar")
+ Files.write(UUID.randomUUID().toString(), jar, UTF_8)
+
+ val fileUri = env.fileServer.addFile(file)
+ val jarUri = env.fileServer.addJar(jar)
+
+ val destDir = Utils.createTempDir()
+ val destFile = new File(destDir, file.getName())
+ val destJar = new File(destDir, jar.getName())
+
+ val sm = new SecurityManager(conf)
+ val hc = SparkHadoopUtil.get.conf
+ Utils.fetchFile(fileUri, destDir, conf, sm, hc, 0L, false)
+ Utils.fetchFile(jarUri, destDir, conf, sm, hc, 0L, false)
+
+ assert(Files.equal(file, destFile))
+ assert(Files.equal(jar, destJar))
+ }
+
}
class UnserializableClass
diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
index f9d8e80c98..ccca795683 100644
--- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
@@ -25,17 +25,19 @@ import org.mockito.Matchers._
import org.apache.spark.SparkFunSuite
import org.apache.spark.network.client.{TransportResponseHandler, TransportClient}
+import org.apache.spark.network.server.StreamManager
import org.apache.spark.rpc._
class NettyRpcHandlerSuite extends SparkFunSuite {
val env = mock(classOf[NettyRpcEnv])
- when(env.deserialize(any(classOf[TransportClient]), any(classOf[Array[Byte]]))(any())).
- thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null, false))
+ val sm = mock(classOf[StreamManager])
+ when(env.deserialize(any(classOf[TransportClient]), any(classOf[Array[Byte]]))(any()))
+ .thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null, false))
test("receive") {
val dispatcher = mock(classOf[Dispatcher])
- val nettyRpcHandler = new NettyRpcHandler(dispatcher, env)
+ val nettyRpcHandler = new NettyRpcHandler(dispatcher, env, sm)
val channel = mock(classOf[Channel])
val client = new TransportClient(channel, mock(classOf[TransportResponseHandler]))
@@ -47,7 +49,7 @@ class NettyRpcHandlerSuite extends SparkFunSuite {
test("connectionTerminated") {
val dispatcher = mock(classOf[Dispatcher])
- val nettyRpcHandler = new NettyRpcHandler(dispatcher, env)
+ val nettyRpcHandler = new NettyRpcHandler(dispatcher, env, sm)
val channel = mock(classOf[Channel])
val client = new TransportClient(channel, mock(classOf[TransportResponseHandler]))