aboutsummaryrefslogtreecommitdiff
path: root/core/src/test/scala/org
diff options
context:
space:
mode:
Diffstat (limited to 'core/src/test/scala/org')
-rw-r--r--core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala39
-rw-r--r--core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala10
2 files changed, 43 insertions, 6 deletions
diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
index 2f55006420..2b664c6313 100644
--- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
@@ -17,7 +17,9 @@
package org.apache.spark.rpc
-import java.io.NotSerializableException
+import java.io.{File, NotSerializableException}
+import java.util.UUID
+import java.nio.charset.StandardCharsets.UTF_8
import java.util.concurrent.{TimeUnit, CountDownLatch, TimeoutException}
import scala.collection.mutable
@@ -25,10 +27,14 @@ import scala.concurrent.Await
import scala.concurrent.duration._
import scala.language.postfixOps
+import com.google.common.io.Files
+import org.mockito.Mockito.{mock, when}
import org.scalatest.BeforeAndAfterAll
import org.scalatest.concurrent.Eventually._
-import org.apache.spark.{SparkConf, SparkException, SparkFunSuite}
+import org.apache.spark.{SecurityManager, SparkConf, SparkEnv, SparkException, SparkFunSuite}
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.util.Utils
/**
* Common tests for an RpcEnv implementation.
@@ -40,12 +46,17 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
override def beforeAll(): Unit = {
val conf = new SparkConf()
env = createRpcEnv(conf, "local", 0)
+
+ val sparkEnv = mock(classOf[SparkEnv])
+ when(sparkEnv.rpcEnv).thenReturn(env)
+ SparkEnv.set(sparkEnv)
}
override def afterAll(): Unit = {
if (env != null) {
env.shutdown()
}
+ SparkEnv.set(null)
}
def createRpcEnv(conf: SparkConf, name: String, port: Int, clientMode: Boolean = false): RpcEnv
@@ -713,6 +724,30 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
assert(shortTimeout.timeoutProp.r.findAllIn(reply4).length === 1)
}
+ test("file server") {
+ val conf = new SparkConf()
+ val tempDir = Utils.createTempDir()
+ val file = new File(tempDir, "file")
+ Files.write(UUID.randomUUID().toString(), file, UTF_8)
+ val jar = new File(tempDir, "jar")
+ Files.write(UUID.randomUUID().toString(), jar, UTF_8)
+
+ val fileUri = env.fileServer.addFile(file)
+ val jarUri = env.fileServer.addJar(jar)
+
+ val destDir = Utils.createTempDir()
+ val destFile = new File(destDir, file.getName())
+ val destJar = new File(destDir, jar.getName())
+
+ val sm = new SecurityManager(conf)
+ val hc = SparkHadoopUtil.get.conf
+ Utils.fetchFile(fileUri, destDir, conf, sm, hc, 0L, false)
+ Utils.fetchFile(jarUri, destDir, conf, sm, hc, 0L, false)
+
+ assert(Files.equal(file, destFile))
+ assert(Files.equal(jar, destJar))
+ }
+
}
class UnserializableClass
diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
index f9d8e80c98..ccca795683 100644
--- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
@@ -25,17 +25,19 @@ import org.mockito.Matchers._
import org.apache.spark.SparkFunSuite
import org.apache.spark.network.client.{TransportResponseHandler, TransportClient}
+import org.apache.spark.network.server.StreamManager
import org.apache.spark.rpc._
class NettyRpcHandlerSuite extends SparkFunSuite {
val env = mock(classOf[NettyRpcEnv])
- when(env.deserialize(any(classOf[TransportClient]), any(classOf[Array[Byte]]))(any())).
- thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null, false))
+ val sm = mock(classOf[StreamManager])
+ when(env.deserialize(any(classOf[TransportClient]), any(classOf[Array[Byte]]))(any()))
+ .thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null, false))
test("receive") {
val dispatcher = mock(classOf[Dispatcher])
- val nettyRpcHandler = new NettyRpcHandler(dispatcher, env)
+ val nettyRpcHandler = new NettyRpcHandler(dispatcher, env, sm)
val channel = mock(classOf[Channel])
val client = new TransportClient(channel, mock(classOf[TransportResponseHandler]))
@@ -47,7 +49,7 @@ class NettyRpcHandlerSuite extends SparkFunSuite {
test("connectionTerminated") {
val dispatcher = mock(classOf[Dispatcher])
- val nettyRpcHandler = new NettyRpcHandler(dispatcher, env)
+ val nettyRpcHandler = new NettyRpcHandler(dispatcher, env, sm)
val channel = mock(classOf[Channel])
val client = new TransportClient(channel, mock(classOf[TransportResponseHandler]))