aboutsummaryrefslogtreecommitdiff
path: root/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
diff options
context:
space:
mode:
Diffstat (limited to 'core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala')
-rw-r--r--core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala10
1 files changed, 6 insertions, 4 deletions
diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
index f9d8e80c98..ccca795683 100644
--- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
@@ -25,17 +25,19 @@ import org.mockito.Matchers._
import org.apache.spark.SparkFunSuite
import org.apache.spark.network.client.{TransportResponseHandler, TransportClient}
+import org.apache.spark.network.server.StreamManager
import org.apache.spark.rpc._
class NettyRpcHandlerSuite extends SparkFunSuite {
val env = mock(classOf[NettyRpcEnv])
- when(env.deserialize(any(classOf[TransportClient]), any(classOf[Array[Byte]]))(any())).
- thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null, false))
+ val sm = mock(classOf[StreamManager])
+ when(env.deserialize(any(classOf[TransportClient]), any(classOf[Array[Byte]]))(any()))
+ .thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null, false))
test("receive") {
val dispatcher = mock(classOf[Dispatcher])
- val nettyRpcHandler = new NettyRpcHandler(dispatcher, env)
+ val nettyRpcHandler = new NettyRpcHandler(dispatcher, env, sm)
val channel = mock(classOf[Channel])
val client = new TransportClient(channel, mock(classOf[TransportResponseHandler]))
@@ -47,7 +49,7 @@ class NettyRpcHandlerSuite extends SparkFunSuite {
test("connectionTerminated") {
val dispatcher = mock(classOf[Dispatcher])
- val nettyRpcHandler = new NettyRpcHandler(dispatcher, env)
+ val nettyRpcHandler = new NettyRpcHandler(dispatcher, env, sm)
val channel = mock(classOf[Channel])
val client = new TransportClient(channel, mock(classOf[TransportResponseHandler]))