Diffstat (limited to 'network/common/src/test/java/org')
-rw-r--r--  network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java  244
-rw-r--r--  network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java  127
-rw-r--r--  network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java  288
-rw-r--r--  network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java  215
-rw-r--r--  network/common/src/test/java/org/apache/spark/network/StreamSuite.java  349
-rw-r--r--  network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java  109
-rw-r--r--  network/common/src/test/java/org/apache/spark/network/TestUtils.java  30
-rw-r--r--  network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java  214
-rw-r--r--  network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java  146
-rw-r--r--  network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java  157
-rw-r--r--  network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java  476
-rw-r--r--  network/common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java  50
-rw-r--r--  network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java  258
13 files changed, 0 insertions, 2663 deletions
diff --git a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java
deleted file mode 100644
index 70c849d60e..0000000000
--- a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java
+++ /dev/null
@@ -1,244 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network;
-
-import java.io.File;
-import java.io.RandomAccessFile;
-import java.nio.ByteBuffer;
-import java.util.Collections;
-import java.util.HashSet;
-import java.util.LinkedList;
-import java.util.List;
-import java.util.Random;
-import java.util.Set;
-import java.util.concurrent.Semaphore;
-import java.util.concurrent.TimeUnit;
-
-import com.google.common.collect.Lists;
-import com.google.common.collect.Sets;
-import com.google.common.io.Closeables;
-import org.junit.AfterClass;
-import org.junit.BeforeClass;
-import org.junit.Test;
-
-import static org.junit.Assert.*;
-
-import org.apache.spark.network.buffer.FileSegmentManagedBuffer;
-import org.apache.spark.network.buffer.ManagedBuffer;
-import org.apache.spark.network.buffer.NioManagedBuffer;
-import org.apache.spark.network.client.ChunkReceivedCallback;
-import org.apache.spark.network.client.RpcResponseCallback;
-import org.apache.spark.network.client.TransportClient;
-import org.apache.spark.network.client.TransportClientFactory;
-import org.apache.spark.network.server.RpcHandler;
-import org.apache.spark.network.server.TransportServer;
-import org.apache.spark.network.server.StreamManager;
-import org.apache.spark.network.util.SystemPropertyConfigProvider;
-import org.apache.spark.network.util.TransportConf;
-
-public class ChunkFetchIntegrationSuite {
- static final long STREAM_ID = 1;
- static final int BUFFER_CHUNK_INDEX = 0;
- static final int FILE_CHUNK_INDEX = 1;
-
- static TransportServer server;
- static TransportClientFactory clientFactory;
- static StreamManager streamManager;
- static File testFile;
-
- static ManagedBuffer bufferChunk;
- static ManagedBuffer fileChunk;
-
- private TransportConf transportConf;
-
- @BeforeClass
- public static void setUp() throws Exception {
- int bufSize = 100000;
- final ByteBuffer buf = ByteBuffer.allocate(bufSize);
- for (int i = 0; i < bufSize; i ++) {
- buf.put((byte) i);
- }
- buf.flip();
- bufferChunk = new NioManagedBuffer(buf);
-
- testFile = File.createTempFile("shuffle-test-file", "txt");
- testFile.deleteOnExit();
- RandomAccessFile fp = new RandomAccessFile(testFile, "rw");
- boolean shouldSuppressIOException = true;
- try {
- byte[] fileContent = new byte[1024];
- new Random().nextBytes(fileContent);
- fp.write(fileContent);
- shouldSuppressIOException = false;
- } finally {
- Closeables.close(fp, shouldSuppressIOException);
- }
-
- final TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider());
- fileChunk = new FileSegmentManagedBuffer(conf, testFile, 10, testFile.length() - 25);
-
- streamManager = new StreamManager() {
- @Override
- public ManagedBuffer getChunk(long streamId, int chunkIndex) {
- assertEquals(STREAM_ID, streamId);
- if (chunkIndex == BUFFER_CHUNK_INDEX) {
- return new NioManagedBuffer(buf);
- } else if (chunkIndex == FILE_CHUNK_INDEX) {
- return new FileSegmentManagedBuffer(conf, testFile, 10, testFile.length() - 25);
- } else {
- throw new IllegalArgumentException("Invalid chunk index: " + chunkIndex);
- }
- }
- };
- RpcHandler handler = new RpcHandler() {
- @Override
- public void receive(
- TransportClient client,
- ByteBuffer message,
- RpcResponseCallback callback) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public StreamManager getStreamManager() {
- return streamManager;
- }
- };
- TransportContext context = new TransportContext(conf, handler);
- server = context.createServer();
- clientFactory = context.createClientFactory();
- }
-
- @AfterClass
- public static void tearDown() {
- bufferChunk.release();
- server.close();
- clientFactory.close();
- testFile.delete();
- }
-
- class FetchResult {
- public Set<Integer> successChunks;
- public Set<Integer> failedChunks;
- public List<ManagedBuffer> buffers;
-
- public void releaseBuffers() {
- for (ManagedBuffer buffer : buffers) {
- buffer.release();
- }
- }
- }
-
- private FetchResult fetchChunks(List<Integer> chunkIndices) throws Exception {
- TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
- final Semaphore sem = new Semaphore(0);
-
- final FetchResult res = new FetchResult();
- res.successChunks = Collections.synchronizedSet(new HashSet<Integer>());
- res.failedChunks = Collections.synchronizedSet(new HashSet<Integer>());
- res.buffers = Collections.synchronizedList(new LinkedList<ManagedBuffer>());
-
- ChunkReceivedCallback callback = new ChunkReceivedCallback() {
- @Override
- public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
- buffer.retain();
- res.successChunks.add(chunkIndex);
- res.buffers.add(buffer);
- sem.release();
- }
-
- @Override
- public void onFailure(int chunkIndex, Throwable e) {
- res.failedChunks.add(chunkIndex);
- sem.release();
- }
- };
-
- for (int chunkIndex : chunkIndices) {
- client.fetchChunk(STREAM_ID, chunkIndex, callback);
- }
- if (!sem.tryAcquire(chunkIndices.size(), 5, TimeUnit.SECONDS)) {
- fail("Timeout getting response from the server");
- }
- client.close();
- return res;
- }
-
- @Test
- public void fetchBufferChunk() throws Exception {
- FetchResult res = fetchChunks(Lists.newArrayList(BUFFER_CHUNK_INDEX));
- assertEquals(res.successChunks, Sets.newHashSet(BUFFER_CHUNK_INDEX));
- assertTrue(res.failedChunks.isEmpty());
- assertBufferListsEqual(res.buffers, Lists.newArrayList(bufferChunk));
- res.releaseBuffers();
- }
-
- @Test
- public void fetchFileChunk() throws Exception {
- FetchResult res = fetchChunks(Lists.newArrayList(FILE_CHUNK_INDEX));
- assertEquals(res.successChunks, Sets.newHashSet(FILE_CHUNK_INDEX));
- assertTrue(res.failedChunks.isEmpty());
- assertBufferListsEqual(res.buffers, Lists.newArrayList(fileChunk));
- res.releaseBuffers();
- }
-
- @Test
- public void fetchNonExistentChunk() throws Exception {
- FetchResult res = fetchChunks(Lists.newArrayList(12345));
- assertTrue(res.successChunks.isEmpty());
- assertEquals(res.failedChunks, Sets.newHashSet(12345));
- assertTrue(res.buffers.isEmpty());
- }
-
- @Test
- public void fetchBothChunks() throws Exception {
- FetchResult res = fetchChunks(Lists.newArrayList(BUFFER_CHUNK_INDEX, FILE_CHUNK_INDEX));
- assertEquals(res.successChunks, Sets.newHashSet(BUFFER_CHUNK_INDEX, FILE_CHUNK_INDEX));
- assertTrue(res.failedChunks.isEmpty());
- assertBufferListsEqual(res.buffers, Lists.newArrayList(bufferChunk, fileChunk));
- res.releaseBuffers();
- }
-
- @Test
- public void fetchChunkAndNonExistent() throws Exception {
- FetchResult res = fetchChunks(Lists.newArrayList(BUFFER_CHUNK_INDEX, 12345));
- assertEquals(res.successChunks, Sets.newHashSet(BUFFER_CHUNK_INDEX));
- assertEquals(res.failedChunks, Sets.newHashSet(12345));
- assertBufferListsEqual(res.buffers, Lists.newArrayList(bufferChunk));
- res.releaseBuffers();
- }
-
- private void assertBufferListsEqual(List<ManagedBuffer> list0, List<ManagedBuffer> list1)
- throws Exception {
- assertEquals(list0.size(), list1.size());
- for (int i = 0; i < list0.size(); i ++) {
- assertBuffersEqual(list0.get(i), list1.get(i));
- }
- }
-
- private void assertBuffersEqual(ManagedBuffer buffer0, ManagedBuffer buffer1) throws Exception {
- ByteBuffer nio0 = buffer0.nioByteBuffer();
- ByteBuffer nio1 = buffer1.nioByteBuffer();
-
- int len = nio0.remaining();
- assertEquals(nio0.remaining(), nio1.remaining());
- for (int i = 0; i < len; i ++) {
- assertEquals(nio0.get(), nio1.get());
- }
- }
-}
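The fetchChunks helper above captures the synchronization idiom used throughout these suites: issue N asynchronous requests against one shared callback, then block on a semaphore until N completions arrive or a timeout expires. A minimal, framework-free sketch of that idiom (class and method names here are illustrative, not part of the Spark API):

import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;

public class CompletionWaiter {
  private final Semaphore completions = new Semaphore(0);

  // Invoked once per finished request, whether it succeeded or failed.
  public void signalOne() {
    completions.release();
  }

  // Block until `expected` completions arrive; false means we timed out.
  public boolean awaitAll(int expected, long timeoutSeconds) throws InterruptedException {
    return completions.tryAcquire(expected, timeoutSeconds, TimeUnit.SECONDS);
  }
}
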
diff --git a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java
deleted file mode 100644
index 6c8dd742f4..0000000000
--- a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java
+++ /dev/null
@@ -1,127 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network;
-
-import java.util.List;
-
-import com.google.common.primitives.Ints;
-import io.netty.buffer.Unpooled;
-import io.netty.channel.ChannelHandlerContext;
-import io.netty.channel.FileRegion;
-import io.netty.channel.embedded.EmbeddedChannel;
-import io.netty.handler.codec.MessageToMessageEncoder;
-import org.junit.Test;
-
-import static org.junit.Assert.assertEquals;
-
-import org.apache.spark.network.protocol.ChunkFetchFailure;
-import org.apache.spark.network.protocol.ChunkFetchRequest;
-import org.apache.spark.network.protocol.ChunkFetchSuccess;
-import org.apache.spark.network.protocol.Message;
-import org.apache.spark.network.protocol.MessageDecoder;
-import org.apache.spark.network.protocol.MessageEncoder;
-import org.apache.spark.network.protocol.OneWayMessage;
-import org.apache.spark.network.protocol.RpcFailure;
-import org.apache.spark.network.protocol.RpcRequest;
-import org.apache.spark.network.protocol.RpcResponse;
-import org.apache.spark.network.protocol.StreamChunkId;
-import org.apache.spark.network.protocol.StreamFailure;
-import org.apache.spark.network.protocol.StreamRequest;
-import org.apache.spark.network.protocol.StreamResponse;
-import org.apache.spark.network.util.ByteArrayWritableChannel;
-import org.apache.spark.network.util.NettyUtils;
-
-public class ProtocolSuite {
- private void testServerToClient(Message msg) {
- EmbeddedChannel serverChannel = new EmbeddedChannel(new FileRegionEncoder(),
- new MessageEncoder());
- serverChannel.writeOutbound(msg);
-
- EmbeddedChannel clientChannel = new EmbeddedChannel(
- NettyUtils.createFrameDecoder(), new MessageDecoder());
-
- while (!serverChannel.outboundMessages().isEmpty()) {
- clientChannel.writeInbound(serverChannel.readOutbound());
- }
-
- assertEquals(1, clientChannel.inboundMessages().size());
- assertEquals(msg, clientChannel.readInbound());
- }
-
- private void testClientToServer(Message msg) {
- EmbeddedChannel clientChannel = new EmbeddedChannel(new FileRegionEncoder(),
- new MessageEncoder());
- clientChannel.writeOutbound(msg);
-
- EmbeddedChannel serverChannel = new EmbeddedChannel(
- NettyUtils.createFrameDecoder(), new MessageDecoder());
-
- while (!clientChannel.outboundMessages().isEmpty()) {
- serverChannel.writeInbound(clientChannel.readOutbound());
- }
-
- assertEquals(1, serverChannel.inboundMessages().size());
- assertEquals(msg, serverChannel.readInbound());
- }
-
- @Test
- public void requests() {
- testClientToServer(new ChunkFetchRequest(new StreamChunkId(1, 2)));
- testClientToServer(new RpcRequest(12345, new TestManagedBuffer(0)));
- testClientToServer(new RpcRequest(12345, new TestManagedBuffer(10)));
- testClientToServer(new StreamRequest("abcde"));
- testClientToServer(new OneWayMessage(new TestManagedBuffer(10)));
- }
-
- @Test
- public void responses() {
- testServerToClient(new ChunkFetchSuccess(new StreamChunkId(1, 2), new TestManagedBuffer(10)));
- testServerToClient(new ChunkFetchSuccess(new StreamChunkId(1, 2), new TestManagedBuffer(0)));
- testServerToClient(new ChunkFetchFailure(new StreamChunkId(1, 2), "this is an error"));
- testServerToClient(new ChunkFetchFailure(new StreamChunkId(1, 2), ""));
- testServerToClient(new RpcResponse(12345, new TestManagedBuffer(0)));
- testServerToClient(new RpcResponse(12345, new TestManagedBuffer(100)));
- testServerToClient(new RpcFailure(0, "this is an error"));
- testServerToClient(new RpcFailure(0, ""));
- // Note: buffer size must be "0" since StreamResponse's buffer is written differently to the
- // channel and cannot be tested like this.
- testServerToClient(new StreamResponse("anId", 12345L, new TestManagedBuffer(0)));
- testServerToClient(new StreamFailure("anId", "this is an error"));
- }
-
- /**
- * Handler to transform a FileRegion into a byte buffer. EmbeddedChannel doesn't actually transfer
- * bytes, but messages, so this is needed so that the frame decoder on the receiving side can
- * understand what MessageWithHeader actually contains.
- */
- private static class FileRegionEncoder extends MessageToMessageEncoder<FileRegion> {
-
- @Override
- public void encode(ChannelHandlerContext ctx, FileRegion in, List<Object> out)
- throws Exception {
-
- ByteArrayWritableChannel channel = new ByteArrayWritableChannel(Ints.checkedCast(in.count()));
- while (in.transfered() < in.count()) {
- in.transferTo(channel, in.transfered());
- }
- out.add(Unpooled.wrappedBuffer(channel.getData()));
- }
-
- }
-
-}
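ProtocolSuite's round trips run entirely in memory: EmbeddedChannel passes messages, not bytes, between pipelines, which is also why the FileRegionEncoder above is needed to flatten a FileRegion into a buffer the frame decoder can read. A stripped-down sketch of the same round-trip technique using stock Netty string codecs (the real suite wires up Spark's MessageEncoder, MessageDecoder, and frame decoder instead):

import io.netty.buffer.ByteBuf;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.string.StringDecoder;
import io.netty.handler.codec.string.StringEncoder;

public class EmbeddedRoundTrip {
  public static void main(String[] args) {
    // Sender side: encode a String into a ByteBuf.
    EmbeddedChannel sender = new EmbeddedChannel(new StringEncoder());
    sender.writeOutbound("hello");

    // Receiver side: decode the ByteBuf back into a String.
    EmbeddedChannel receiver = new EmbeddedChannel(new StringDecoder());
    while (!sender.outboundMessages().isEmpty()) {
      receiver.writeInbound((ByteBuf) sender.readOutbound());
    }
    System.out.println((String) receiver.readInbound()); // prints "hello"
  }
}
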
diff --git a/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java
deleted file mode 100644
index f9b5bf96d6..0000000000
--- a/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java
+++ /dev/null
@@ -1,288 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network;
-
-import com.google.common.collect.Maps;
-import com.google.common.util.concurrent.Uninterruptibles;
-import org.apache.spark.network.buffer.ManagedBuffer;
-import org.apache.spark.network.buffer.NioManagedBuffer;
-import org.apache.spark.network.client.ChunkReceivedCallback;
-import org.apache.spark.network.client.RpcResponseCallback;
-import org.apache.spark.network.client.TransportClient;
-import org.apache.spark.network.client.TransportClientFactory;
-import org.apache.spark.network.server.RpcHandler;
-import org.apache.spark.network.server.StreamManager;
-import org.apache.spark.network.server.TransportServer;
-import org.apache.spark.network.util.MapConfigProvider;
-import org.apache.spark.network.util.TransportConf;
-import org.junit.*;
-import static org.junit.Assert.*;
-
-import java.io.IOException;
-import java.nio.ByteBuffer;
-import java.util.*;
-import java.util.concurrent.Semaphore;
-import java.util.concurrent.TimeUnit;
-
-/**
- * Suite which ensures that requests that go without a response for the network timeout period are
- * failed, and the connection closed.
- *
- * In this suite, we use 2 seconds as the connection timeout, with some slack given in the tests,
- * to ensure stability in different test environments.
- */
-public class RequestTimeoutIntegrationSuite {
-
- private TransportServer server;
- private TransportClientFactory clientFactory;
-
- private StreamManager defaultManager;
- private TransportConf conf;
-
- // A large timeout that "shouldn't happen", for the sake of faulty tests not hanging forever.
- private final int FOREVER = 60 * 1000;
-
- @Before
- public void setUp() throws Exception {
- Map<String, String> configMap = Maps.newHashMap();
- configMap.put("spark.shuffle.io.connectionTimeout", "2s");
- conf = new TransportConf("shuffle", new MapConfigProvider(configMap));
-
- defaultManager = new StreamManager() {
- @Override
- public ManagedBuffer getChunk(long streamId, int chunkIndex) {
- throw new UnsupportedOperationException();
- }
- };
- }
-
- @After
- public void tearDown() {
- if (server != null) {
- server.close();
- }
- if (clientFactory != null) {
- clientFactory.close();
- }
- }
-
- // Basic suite: First request completes quickly, and second waits for longer than network timeout.
- @Test
- public void timeoutInactiveRequests() throws Exception {
- final Semaphore semaphore = new Semaphore(1);
- final int responseSize = 16;
- RpcHandler handler = new RpcHandler() {
- @Override
- public void receive(
- TransportClient client,
- ByteBuffer message,
- RpcResponseCallback callback) {
- try {
- semaphore.tryAcquire(FOREVER, TimeUnit.MILLISECONDS);
- callback.onSuccess(ByteBuffer.allocate(responseSize));
- } catch (InterruptedException e) {
- // do nothing
- }
- }
-
- @Override
- public StreamManager getStreamManager() {
- return defaultManager;
- }
- };
-
- TransportContext context = new TransportContext(conf, handler);
- server = context.createServer();
- clientFactory = context.createClientFactory();
- TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
-
- // First completes quickly (semaphore starts at 1).
- TestCallback callback0 = new TestCallback();
- synchronized (callback0) {
- client.sendRpc(ByteBuffer.allocate(0), callback0);
- callback0.wait(FOREVER);
- assertEquals(responseSize, callback0.successLength);
- }
-
- // Second times out after 2 seconds, with slack. Must be IOException.
- TestCallback callback1 = new TestCallback();
- synchronized (callback1) {
- client.sendRpc(ByteBuffer.allocate(0), callback1);
- callback1.wait(4 * 1000);
- assert (callback1.failure != null);
- assert (callback1.failure instanceof IOException);
- }
- semaphore.release();
- }
-
- // A timeout will cause the connection to be closed, invalidating the current TransportClient.
- // It should be the case that requesting a client from the factory produces a new, valid one.
- @Test
- public void timeoutCleanlyClosesClient() throws Exception {
- final Semaphore semaphore = new Semaphore(0);
- final int responseSize = 16;
- RpcHandler handler = new RpcHandler() {
- @Override
- public void receive(
- TransportClient client,
- ByteBuffer message,
- RpcResponseCallback callback) {
- try {
- semaphore.tryAcquire(FOREVER, TimeUnit.MILLISECONDS);
- callback.onSuccess(ByteBuffer.allocate(responseSize));
- } catch (InterruptedException e) {
- // do nothing
- }
- }
-
- @Override
- public StreamManager getStreamManager() {
- return defaultManager;
- }
- };
-
- TransportContext context = new TransportContext(conf, handler);
- server = context.createServer();
- clientFactory = context.createClientFactory();
-
- // First request should eventually fail.
- TransportClient client0 =
- clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
- TestCallback callback0 = new TestCallback();
- synchronized (callback0) {
- client0.sendRpc(ByteBuffer.allocate(0), callback0);
- callback0.wait(FOREVER);
- assert (callback0.failure instanceof IOException);
- assert (!client0.isActive());
- }
-
- // Increment the semaphore and the second request should succeed quickly.
- semaphore.release(2);
- TransportClient client1 =
- clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
- TestCallback callback1 = new TestCallback();
- synchronized (callback1) {
- client1.sendRpc(ByteBuffer.allocate(0), callback1);
- callback1.wait(FOREVER);
- assertEquals(responseSize, callback1.successLength);
- assertNull(callback1.failure);
- }
- }
-
- // The timeout is relative to the LAST request sent, which is kinda weird, but still.
- // This test also makes sure the timeout works for Fetch requests as well as RPCs.
- @Test
- public void furtherRequestsDelay() throws Exception {
- final byte[] response = new byte[16];
- final StreamManager manager = new StreamManager() {
- @Override
- public ManagedBuffer getChunk(long streamId, int chunkIndex) {
- Uninterruptibles.sleepUninterruptibly(FOREVER, TimeUnit.MILLISECONDS);
- return new NioManagedBuffer(ByteBuffer.wrap(response));
- }
- };
- RpcHandler handler = new RpcHandler() {
- @Override
- public void receive(
- TransportClient client,
- ByteBuffer message,
- RpcResponseCallback callback) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public StreamManager getStreamManager() {
- return manager;
- }
- };
-
- TransportContext context = new TransportContext(conf, handler);
- server = context.createServer();
- clientFactory = context.createClientFactory();
- TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
-
- // Send one request, which will eventually fail.
- TestCallback callback0 = new TestCallback();
- client.fetchChunk(0, 0, callback0);
- Uninterruptibles.sleepUninterruptibly(1200, TimeUnit.MILLISECONDS);
-
- // Send a second request before the first has failed.
- TestCallback callback1 = new TestCallback();
- client.fetchChunk(0, 1, callback1);
- Uninterruptibles.sleepUninterruptibly(1200, TimeUnit.MILLISECONDS);
-
- synchronized (callback0) {
- // not complete yet, but should complete soon
- assertEquals(-1, callback0.successLength);
- assertNull(callback0.failure);
- callback0.wait(2 * 1000);
- assertTrue(callback0.failure instanceof IOException);
- }
-
- synchronized (callback1) {
- // failed at same time as previous
- assert (callback0.failure instanceof IOException);
- }
- }
-
- /**
- * Callback which sets 'success' or 'failure' on completion.
- * Additionally notifies all waiters on this callback when invoked.
- */
- class TestCallback implements RpcResponseCallback, ChunkReceivedCallback {
-
- int successLength = -1;
- Throwable failure;
-
- @Override
- public void onSuccess(ByteBuffer response) {
- synchronized(this) {
- successLength = response.remaining();
- this.notifyAll();
- }
- }
-
- @Override
- public void onFailure(Throwable e) {
- synchronized(this) {
- failure = e;
- this.notifyAll();
- }
- }
-
- @Override
- public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
- synchronized(this) {
- try {
- successLength = buffer.nioByteBuffer().remaining();
- this.notifyAll();
- } catch (IOException e) {
- // weird
- }
- }
- }
-
- @Override
- public void onFailure(int chunkIndex, Throwable e) {
- synchronized(this) {
- failure = e;
- this.notifyAll();
- }
- }
- }
-}
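TestCallback above leans on Java's intrinsic monitor: test threads synchronize on the callback and wait(), and the callback notifyAll()s from inside the same lock when a response or failure lands. A self-contained sketch of that pattern with an explicit deadline loop, which guards against spurious wakeups in a way a single bare wait(timeout) does not (the class and method names are hypothetical):

public class NotifyingCallback {
  private boolean done = false;
  private Throwable failure;   // null means success

  public synchronized void complete(Throwable error) {
    failure = error;
    done = true;
    notifyAll();               // wake every thread blocked in awaitDone
  }

  // Returns true if completion arrived before the timeout elapsed.
  public synchronized boolean awaitDone(long timeoutMs) throws InterruptedException {
    long deadline = System.currentTimeMillis() + timeoutMs;
    while (!done) {
      long remaining = deadline - System.currentTimeMillis();
      if (remaining <= 0) {
        return false;
      }
      wait(remaining);
    }
    return true;
  }

  public synchronized Throwable getFailure() {
    return failure;
  }
}
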
diff --git a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java
deleted file mode 100644
index 9e9be98c14..0000000000
--- a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java
+++ /dev/null
@@ -1,215 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network;
-
-import java.nio.ByteBuffer;
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.HashSet;
-import java.util.Iterator;
-import java.util.List;
-import java.util.Set;
-import java.util.concurrent.Semaphore;
-import java.util.concurrent.TimeUnit;
-
-import com.google.common.collect.Sets;
-import org.junit.AfterClass;
-import org.junit.BeforeClass;
-import org.junit.Test;
-
-import static org.junit.Assert.*;
-
-import org.apache.spark.network.client.RpcResponseCallback;
-import org.apache.spark.network.client.TransportClient;
-import org.apache.spark.network.client.TransportClientFactory;
-import org.apache.spark.network.server.OneForOneStreamManager;
-import org.apache.spark.network.server.RpcHandler;
-import org.apache.spark.network.server.StreamManager;
-import org.apache.spark.network.server.TransportServer;
-import org.apache.spark.network.util.JavaUtils;
-import org.apache.spark.network.util.SystemPropertyConfigProvider;
-import org.apache.spark.network.util.TransportConf;
-
-public class RpcIntegrationSuite {
- static TransportServer server;
- static TransportClientFactory clientFactory;
- static RpcHandler rpcHandler;
- static List<String> oneWayMsgs;
-
- @BeforeClass
- public static void setUp() throws Exception {
- TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider());
- rpcHandler = new RpcHandler() {
- @Override
- public void receive(
- TransportClient client,
- ByteBuffer message,
- RpcResponseCallback callback) {
- String msg = JavaUtils.bytesToString(message);
- String[] parts = msg.split("/");
- if (parts[0].equals("hello")) {
- callback.onSuccess(JavaUtils.stringToBytes("Hello, " + parts[1] + "!"));
- } else if (parts[0].equals("return error")) {
- callback.onFailure(new RuntimeException("Returned: " + parts[1]));
- } else if (parts[0].equals("throw error")) {
- throw new RuntimeException("Thrown: " + parts[1]);
- }
- }
-
- @Override
- public void receive(TransportClient client, ByteBuffer message) {
- oneWayMsgs.add(JavaUtils.bytesToString(message));
- }
-
- @Override
- public StreamManager getStreamManager() { return new OneForOneStreamManager(); }
- };
- TransportContext context = new TransportContext(conf, rpcHandler);
- server = context.createServer();
- clientFactory = context.createClientFactory();
- oneWayMsgs = new ArrayList<>();
- }
-
- @AfterClass
- public static void tearDown() {
- server.close();
- clientFactory.close();
- }
-
- class RpcResult {
- public Set<String> successMessages;
- public Set<String> errorMessages;
- }
-
- private RpcResult sendRPC(String ... commands) throws Exception {
- TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
- final Semaphore sem = new Semaphore(0);
-
- final RpcResult res = new RpcResult();
- res.successMessages = Collections.synchronizedSet(new HashSet<String>());
- res.errorMessages = Collections.synchronizedSet(new HashSet<String>());
-
- RpcResponseCallback callback = new RpcResponseCallback() {
- @Override
- public void onSuccess(ByteBuffer message) {
- String response = JavaUtils.bytesToString(message);
- res.successMessages.add(response);
- sem.release();
- }
-
- @Override
- public void onFailure(Throwable e) {
- res.errorMessages.add(e.getMessage());
- sem.release();
- }
- };
-
- for (String command : commands) {
- client.sendRpc(JavaUtils.stringToBytes(command), callback);
- }
-
- if (!sem.tryAcquire(commands.length, 5, TimeUnit.SECONDS)) {
- fail("Timeout getting response from the server");
- }
- client.close();
- return res;
- }
-
- @Test
- public void singleRPC() throws Exception {
- RpcResult res = sendRPC("hello/Aaron");
- assertEquals(res.successMessages, Sets.newHashSet("Hello, Aaron!"));
- assertTrue(res.errorMessages.isEmpty());
- }
-
- @Test
- public void doubleRPC() throws Exception {
- RpcResult res = sendRPC("hello/Aaron", "hello/Reynold");
- assertEquals(res.successMessages, Sets.newHashSet("Hello, Aaron!", "Hello, Reynold!"));
- assertTrue(res.errorMessages.isEmpty());
- }
-
- @Test
- public void returnErrorRPC() throws Exception {
- RpcResult res = sendRPC("return error/OK");
- assertTrue(res.successMessages.isEmpty());
- assertErrorsContain(res.errorMessages, Sets.newHashSet("Returned: OK"));
- }
-
- @Test
- public void throwErrorRPC() throws Exception {
- RpcResult res = sendRPC("throw error/uh-oh");
- assertTrue(res.successMessages.isEmpty());
- assertErrorsContain(res.errorMessages, Sets.newHashSet("Thrown: uh-oh"));
- }
-
- @Test
- public void doubleTrouble() throws Exception {
- RpcResult res = sendRPC("return error/OK", "throw error/uh-oh");
- assertTrue(res.successMessages.isEmpty());
- assertErrorsContain(res.errorMessages, Sets.newHashSet("Returned: OK", "Thrown: uh-oh"));
- }
-
- @Test
- public void sendSuccessAndFailure() throws Exception {
- RpcResult res = sendRPC("hello/Bob", "throw error/the", "hello/Builder", "return error/!");
- assertEquals(res.successMessages, Sets.newHashSet("Hello, Bob!", "Hello, Builder!"));
- assertErrorsContain(res.errorMessages, Sets.newHashSet("Thrown: the", "Returned: !"));
- }
-
- @Test
- public void sendOneWayMessage() throws Exception {
- final String message = "no reply";
- TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
- try {
- client.send(JavaUtils.stringToBytes(message));
- assertEquals(0, client.getHandler().numOutstandingRequests());
-
- // Make sure the message arrives.
- long deadline = System.nanoTime() + TimeUnit.NANOSECONDS.convert(10, TimeUnit.SECONDS);
- while (System.nanoTime() < deadline && oneWayMsgs.size() == 0) {
- TimeUnit.MILLISECONDS.sleep(10);
- }
-
- assertEquals(1, oneWayMsgs.size());
- assertEquals(message, oneWayMsgs.get(0));
- } finally {
- client.close();
- }
- }
-
- private void assertErrorsContain(Set<String> errors, Set<String> contains) {
- assertEquals(contains.size(), errors.size());
-
- Set<String> remainingErrors = Sets.newHashSet(errors);
- for (String contain : contains) {
- Iterator<String> it = remainingErrors.iterator();
- boolean foundMatch = false;
- while (it.hasNext()) {
- if (it.next().contains(contain)) {
- it.remove();
- foundMatch = true;
- break;
- }
- }
- assertTrue("Could not find error containing " + contain + "; errors: " + errors, foundMatch);
- }
-
- assertTrue(remainingErrors.isEmpty());
- }
-}
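The handlers in this suite shuttle commands such as "hello/Aaron" as text via JavaUtils.stringToBytes and bytesToString. A plain-JDK sketch of that encode/decode round trip, assuming UTF-8 and using duplicate() so decoding leaves the original buffer's position untouched:

import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

public final class ByteBufferStrings {
  public static ByteBuffer toBytes(String s) {
    return ByteBuffer.wrap(s.getBytes(StandardCharsets.UTF_8));
  }

  public static String fromBytes(ByteBuffer buf) {
    ByteBuffer dup = buf.duplicate();        // independent position/limit
    byte[] bytes = new byte[dup.remaining()];
    dup.get(bytes);
    return new String(bytes, StandardCharsets.UTF_8);
  }

  public static void main(String[] args) {
    System.out.println(fromBytes(toBytes("hello/Aaron"))); // prints "hello/Aaron"
  }
}
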
diff --git a/network/common/src/test/java/org/apache/spark/network/StreamSuite.java b/network/common/src/test/java/org/apache/spark/network/StreamSuite.java
deleted file mode 100644
index 9c49556927..0000000000
--- a/network/common/src/test/java/org/apache/spark/network/StreamSuite.java
+++ /dev/null
@@ -1,349 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network;
-
-import java.io.ByteArrayOutputStream;
-import java.io.File;
-import java.io.FileOutputStream;
-import java.io.IOException;
-import java.io.OutputStream;
-import java.nio.ByteBuffer;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.List;
-import java.util.Random;
-import java.util.concurrent.Executors;
-import java.util.concurrent.ExecutorService;
-import java.util.concurrent.TimeUnit;
-
-import com.google.common.io.Files;
-import org.junit.AfterClass;
-import org.junit.BeforeClass;
-import org.junit.Test;
-import static org.junit.Assert.*;
-
-import org.apache.spark.network.buffer.FileSegmentManagedBuffer;
-import org.apache.spark.network.buffer.ManagedBuffer;
-import org.apache.spark.network.buffer.NioManagedBuffer;
-import org.apache.spark.network.client.RpcResponseCallback;
-import org.apache.spark.network.client.StreamCallback;
-import org.apache.spark.network.client.TransportClient;
-import org.apache.spark.network.client.TransportClientFactory;
-import org.apache.spark.network.server.RpcHandler;
-import org.apache.spark.network.server.StreamManager;
-import org.apache.spark.network.server.TransportServer;
-import org.apache.spark.network.util.SystemPropertyConfigProvider;
-import org.apache.spark.network.util.TransportConf;
-
-public class StreamSuite {
- private static final String[] STREAMS = { "largeBuffer", "smallBuffer", "emptyBuffer", "file" };
-
- private static TransportServer server;
- private static TransportClientFactory clientFactory;
- private static File testFile;
- private static File tempDir;
-
- private static ByteBuffer emptyBuffer;
- private static ByteBuffer smallBuffer;
- private static ByteBuffer largeBuffer;
-
- private static ByteBuffer createBuffer(int bufSize) {
- ByteBuffer buf = ByteBuffer.allocate(bufSize);
- for (int i = 0; i < bufSize; i ++) {
- buf.put((byte) i);
- }
- buf.flip();
- return buf;
- }
-
- @BeforeClass
- public static void setUp() throws Exception {
- tempDir = Files.createTempDir();
- emptyBuffer = createBuffer(0);
- smallBuffer = createBuffer(100);
- largeBuffer = createBuffer(100000);
-
- testFile = File.createTempFile("stream-test-file", "txt", tempDir);
- FileOutputStream fp = new FileOutputStream(testFile);
- try {
- Random rnd = new Random();
- for (int i = 0; i < 512; i++) {
- byte[] fileContent = new byte[1024];
- rnd.nextBytes(fileContent);
- fp.write(fileContent);
- }
- } finally {
- fp.close();
- }
-
- final TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider());
- final StreamManager streamManager = new StreamManager() {
- @Override
- public ManagedBuffer getChunk(long streamId, int chunkIndex) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public ManagedBuffer openStream(String streamId) {
- switch (streamId) {
- case "largeBuffer":
- return new NioManagedBuffer(largeBuffer);
- case "smallBuffer":
- return new NioManagedBuffer(smallBuffer);
- case "emptyBuffer":
- return new NioManagedBuffer(emptyBuffer);
- case "file":
- return new FileSegmentManagedBuffer(conf, testFile, 0, testFile.length());
- default:
- throw new IllegalArgumentException("Invalid stream: " + streamId);
- }
- }
- };
- RpcHandler handler = new RpcHandler() {
- @Override
- public void receive(
- TransportClient client,
- ByteBuffer message,
- RpcResponseCallback callback) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public StreamManager getStreamManager() {
- return streamManager;
- }
- };
- TransportContext context = new TransportContext(conf, handler);
- server = context.createServer();
- clientFactory = context.createClientFactory();
- }
-
- @AfterClass
- public static void tearDown() {
- server.close();
- clientFactory.close();
- if (tempDir != null) {
- for (File f : tempDir.listFiles()) {
- f.delete();
- }
- tempDir.delete();
- }
- }
-
- @Test
- public void testZeroLengthStream() throws Throwable {
- TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
- try {
- StreamTask task = new StreamTask(client, "emptyBuffer", TimeUnit.SECONDS.toMillis(5));
- task.run();
- task.check();
- } finally {
- client.close();
- }
- }
-
- @Test
- public void testSingleStream() throws Throwable {
- TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
- try {
- StreamTask task = new StreamTask(client, "largeBuffer", TimeUnit.SECONDS.toMillis(5));
- task.run();
- task.check();
- } finally {
- client.close();
- }
- }
-
- @Test
- public void testMultipleStreams() throws Throwable {
- TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
- try {
- for (int i = 0; i < 20; i++) {
- StreamTask task = new StreamTask(client, STREAMS[i % STREAMS.length],
- TimeUnit.SECONDS.toMillis(5));
- task.run();
- task.check();
- }
- } finally {
- client.close();
- }
- }
-
- @Test
- public void testConcurrentStreams() throws Throwable {
- ExecutorService executor = Executors.newFixedThreadPool(20);
- TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
-
- try {
- List<StreamTask> tasks = new ArrayList<>();
- for (int i = 0; i < 20; i++) {
- StreamTask task = new StreamTask(client, STREAMS[i % STREAMS.length],
- TimeUnit.SECONDS.toMillis(20));
- tasks.add(task);
- executor.submit(task);
- }
-
- executor.shutdown();
- assertTrue("Timed out waiting for tasks.", executor.awaitTermination(30, TimeUnit.SECONDS));
- for (StreamTask task : tasks) {
- task.check();
- }
- } finally {
- executor.shutdownNow();
- client.close();
- }
- }
-
- private static class StreamTask implements Runnable {
-
- private final TransportClient client;
- private final String streamId;
- private final long timeoutMs;
- private Throwable error;
-
- StreamTask(TransportClient client, String streamId, long timeoutMs) {
- this.client = client;
- this.streamId = streamId;
- this.timeoutMs = timeoutMs;
- }
-
- @Override
- public void run() {
- ByteBuffer srcBuffer = null;
- OutputStream out = null;
- File outFile = null;
- try {
- ByteArrayOutputStream baos = null;
-
- switch (streamId) {
- case "largeBuffer":
- baos = new ByteArrayOutputStream();
- out = baos;
- srcBuffer = largeBuffer;
- break;
- case "smallBuffer":
- baos = new ByteArrayOutputStream();
- out = baos;
- srcBuffer = smallBuffer;
- break;
- case "file":
- outFile = File.createTempFile("data", ".tmp", tempDir);
- out = new FileOutputStream(outFile);
- break;
- case "emptyBuffer":
- baos = new ByteArrayOutputStream();
- out = baos;
- srcBuffer = emptyBuffer;
- break;
- default:
- throw new IllegalArgumentException(streamId);
- }
-
- TestCallback callback = new TestCallback(out);
- client.stream(streamId, callback);
- waitForCompletion(callback);
-
- if (srcBuffer == null) {
- assertTrue("File stream did not match.", Files.equal(testFile, outFile));
- } else {
- ByteBuffer base;
- synchronized (srcBuffer) {
- base = srcBuffer.duplicate();
- }
- byte[] result = baos.toByteArray();
- byte[] expected = new byte[base.remaining()];
- base.get(expected);
- assertEquals(expected.length, result.length);
- assertTrue("buffers don't match", Arrays.equals(expected, result));
- }
- } catch (Throwable t) {
- error = t;
- } finally {
- if (out != null) {
- try {
- out.close();
- } catch (Exception e) {
- // ignore.
- }
- }
- if (outFile != null) {
- outFile.delete();
- }
- }
- }
-
- public void check() throws Throwable {
- if (error != null) {
- throw error;
- }
- }
-
- private void waitForCompletion(TestCallback callback) throws Exception {
- long now = System.currentTimeMillis();
- long deadline = now + timeoutMs;
- synchronized (callback) {
- while (!callback.completed && now < deadline) {
- callback.wait(deadline - now);
- now = System.currentTimeMillis();
- }
- }
- assertTrue("Timed out waiting for stream.", callback.completed);
- assertNull(callback.error);
- }
-
- }
-
- private static class TestCallback implements StreamCallback {
-
- private final OutputStream out;
- public volatile boolean completed;
- public volatile Throwable error;
-
- TestCallback(OutputStream out) {
- this.out = out;
- this.completed = false;
- }
-
- @Override
- public void onData(String streamId, ByteBuffer buf) throws IOException {
- byte[] tmp = new byte[buf.remaining()];
- buf.get(tmp);
- out.write(tmp);
- }
-
- @Override
- public void onComplete(String streamId) throws IOException {
- out.close();
- synchronized (this) {
- completed = true;
- notifyAll();
- }
- }
-
- @Override
- public void onFailure(String streamId, Throwable cause) {
- error = cause;
- synchronized (this) {
- completed = true;
- notifyAll();
- }
- }
-
- }
-
-}
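StreamSuite's TestCallback drains each onData buffer into an OutputStream by copying through a byte[]. An equivalent sketch using an NIO channel, which writes the buffer directly and advances its position as bytes are consumed (illustrative only; the copy-through-array version above is what the suite used):

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.Channels;
import java.nio.channels.WritableByteChannel;

public class DrainBuffer {
  public static void main(String[] args) throws IOException {
    ByteBuffer buf = ByteBuffer.wrap("stream payload".getBytes("UTF-8"));
    ByteArrayOutputStream sink = new ByteArrayOutputStream();
    WritableByteChannel channel = Channels.newChannel(sink);
    while (buf.hasRemaining()) {
      channel.write(buf);   // consumes from buf.position() onward
    }
    System.out.println(sink.toString("UTF-8")); // prints "stream payload"
  }
}
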
diff --git a/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java b/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java
deleted file mode 100644
index 83c90f9eff..0000000000
--- a/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java
+++ /dev/null
@@ -1,109 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network;
-
-import java.io.IOException;
-import java.io.InputStream;
-import java.nio.ByteBuffer;
-
-import com.google.common.base.Preconditions;
-import io.netty.buffer.Unpooled;
-
-import org.apache.spark.network.buffer.ManagedBuffer;
-import org.apache.spark.network.buffer.NettyManagedBuffer;
-
-/**
- * A ManagedBuffer implementation that contains 0, 1, 2, 3, ..., (len-1).
- *
- * Used for testing.
- */
-public class TestManagedBuffer extends ManagedBuffer {
-
- private final int len;
- private NettyManagedBuffer underlying;
-
- public TestManagedBuffer(int len) {
- Preconditions.checkArgument(len <= Byte.MAX_VALUE);
- this.len = len;
- byte[] byteArray = new byte[len];
- for (int i = 0; i < len; i ++) {
- byteArray[i] = (byte) i;
- }
- this.underlying = new NettyManagedBuffer(Unpooled.wrappedBuffer(byteArray));
- }
-
-
- @Override
- public long size() {
- return underlying.size();
- }
-
- @Override
- public ByteBuffer nioByteBuffer() throws IOException {
- return underlying.nioByteBuffer();
- }
-
- @Override
- public InputStream createInputStream() throws IOException {
- return underlying.createInputStream();
- }
-
- @Override
- public ManagedBuffer retain() {
- underlying.retain();
- return this;
- }
-
- @Override
- public ManagedBuffer release() {
- underlying.release();
- return this;
- }
-
- @Override
- public Object convertToNetty() throws IOException {
- return underlying.convertToNetty();
- }
-
- @Override
- public int hashCode() {
- return underlying.hashCode();
- }
-
- @Override
- public boolean equals(Object other) {
- if (other instanceof ManagedBuffer) {
- try {
- ByteBuffer nioBuf = ((ManagedBuffer) other).nioByteBuffer();
- if (nioBuf.remaining() != len) {
- return false;
- } else {
- for (int i = 0; i < len; i ++) {
- if (nioBuf.get() != i) {
- return false;
- }
- }
- return true;
- }
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
- }
- return false;
- }
-}
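TestManagedBuffer wraps the deterministic byte pattern 0, 1, 2, ..., (len-1), which is what lets its equals() check content rather than identity. A tiny sketch of building and reading back that pattern with plain NIO (a hypothetical helper, not part of the test sources):

import java.nio.ByteBuffer;

public class SequentialBytes {
  public static ByteBuffer make(int len) {
    ByteBuffer buf = ByteBuffer.allocate(len);
    for (int i = 0; i < len; i++) {
      buf.put((byte) i);
    }
    buf.flip();   // switch from writing to reading
    return buf;
  }

  public static void main(String[] args) {
    ByteBuffer buf = make(5);
    while (buf.hasRemaining()) {
      System.out.println(buf.get()); // prints 0 through 4
    }
  }
}
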
diff --git a/network/common/src/test/java/org/apache/spark/network/TestUtils.java b/network/common/src/test/java/org/apache/spark/network/TestUtils.java
deleted file mode 100644
index 56a2b805f1..0000000000
--- a/network/common/src/test/java/org/apache/spark/network/TestUtils.java
+++ /dev/null
@@ -1,30 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network;
-
-import java.net.InetAddress;
-
-public class TestUtils {
- public static String getLocalHost() {
- try {
- return InetAddress.getLocalHost().getHostAddress();
- } catch (Exception e) {
- throw new RuntimeException(e);
- }
- }
-}
diff --git a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java
deleted file mode 100644
index dac7d4a5b0..0000000000
--- a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java
+++ /dev/null
@@ -1,214 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network;
-
-import java.io.IOException;
-import java.util.Collections;
-import java.util.HashSet;
-import java.util.Map;
-import java.util.NoSuchElementException;
-import java.util.Set;
-import java.util.concurrent.atomic.AtomicInteger;
-
-import com.google.common.collect.Maps;
-import org.junit.After;
-import org.junit.Before;
-import org.junit.Test;
-
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertTrue;
-
-import org.apache.spark.network.client.TransportClient;
-import org.apache.spark.network.client.TransportClientFactory;
-import org.apache.spark.network.server.NoOpRpcHandler;
-import org.apache.spark.network.server.RpcHandler;
-import org.apache.spark.network.server.TransportServer;
-import org.apache.spark.network.util.ConfigProvider;
-import org.apache.spark.network.util.SystemPropertyConfigProvider;
-import org.apache.spark.network.util.JavaUtils;
-import org.apache.spark.network.util.MapConfigProvider;
-import org.apache.spark.network.util.TransportConf;
-
-public class TransportClientFactorySuite {
- private TransportConf conf;
- private TransportContext context;
- private TransportServer server1;
- private TransportServer server2;
-
- @Before
- public void setUp() {
- conf = new TransportConf("shuffle", new SystemPropertyConfigProvider());
- RpcHandler rpcHandler = new NoOpRpcHandler();
- context = new TransportContext(conf, rpcHandler);
- server1 = context.createServer();
- server2 = context.createServer();
- }
-
- @After
- public void tearDown() {
- JavaUtils.closeQuietly(server1);
- JavaUtils.closeQuietly(server2);
- }
-
- /**
- * Request a bunch of clients to a single server to test
- * we create up to maxConnections of clients.
- *
- * If concurrent is true, create multiple threads to create clients in parallel.
- */
- private void testClientReuse(final int maxConnections, boolean concurrent)
- throws IOException, InterruptedException {
-
- Map<String, String> configMap = Maps.newHashMap();
- configMap.put("spark.shuffle.io.numConnectionsPerPeer", Integer.toString(maxConnections));
- TransportConf conf = new TransportConf("shuffle", new MapConfigProvider(configMap));
-
- RpcHandler rpcHandler = new NoOpRpcHandler();
- TransportContext context = new TransportContext(conf, rpcHandler);
- final TransportClientFactory factory = context.createClientFactory();
- final Set<TransportClient> clients = Collections.synchronizedSet(
- new HashSet<TransportClient>());
-
- final AtomicInteger failed = new AtomicInteger();
- Thread[] attempts = new Thread[maxConnections * 10];
-
- // Launch a bunch of threads to create new clients.
- for (int i = 0; i < attempts.length; i++) {
- attempts[i] = new Thread() {
- @Override
- public void run() {
- try {
- TransportClient client =
- factory.createClient(TestUtils.getLocalHost(), server1.getPort());
- assert (client.isActive());
- clients.add(client);
- } catch (IOException e) {
- failed.incrementAndGet();
- }
- }
- };
-
- if (concurrent) {
- attempts[i].start();
- } else {
- attempts[i].run();
- }
- }
-
- // Wait until all the threads complete.
- for (int i = 0; i < attempts.length; i++) {
- attempts[i].join();
- }
-
- assert(failed.get() == 0);
- assert(clients.size() == maxConnections);
-
- for (TransportClient client : clients) {
- client.close();
- }
- }
-
- @Test
- public void reuseClientsUpToConfigVariable() throws Exception {
- testClientReuse(1, false);
- testClientReuse(2, false);
- testClientReuse(3, false);
- testClientReuse(4, false);
- }
-
- @Test
- public void reuseClientsUpToConfigVariableConcurrent() throws Exception {
- testClientReuse(1, true);
- testClientReuse(2, true);
- testClientReuse(3, true);
- testClientReuse(4, true);
- }
-
- @Test
- public void returnDifferentClientsForDifferentServers() throws IOException {
- TransportClientFactory factory = context.createClientFactory();
- TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
- TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server2.getPort());
- assertTrue(c1.isActive());
- assertTrue(c2.isActive());
- assertTrue(c1 != c2);
- factory.close();
- }
-
- @Test
- public void neverReturnInactiveClients() throws IOException, InterruptedException {
- TransportClientFactory factory = context.createClientFactory();
- TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
- c1.close();
-
- long start = System.currentTimeMillis();
- while (c1.isActive() && (System.currentTimeMillis() - start) < 3000) {
- Thread.sleep(10);
- }
- assertFalse(c1.isActive());
-
- TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
- assertFalse(c1 == c2);
- assertTrue(c2.isActive());
- factory.close();
- }
-
- @Test
- public void closeBlockClientsWithFactory() throws IOException {
- TransportClientFactory factory = context.createClientFactory();
- TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
- TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server2.getPort());
- assertTrue(c1.isActive());
- assertTrue(c2.isActive());
- factory.close();
- assertFalse(c1.isActive());
- assertFalse(c2.isActive());
- }
-
- @Test
- public void closeIdleConnectionForRequestTimeOut() throws IOException, InterruptedException {
- TransportConf conf = new TransportConf("shuffle", new ConfigProvider() {
-
- @Override
- public String get(String name) {
- if ("spark.shuffle.io.connectionTimeout".equals(name)) {
- // We should make sure there is enough time for us to observe the channel is active
- return "1s";
- }
- String value = System.getProperty(name);
- if (value == null) {
- throw new NoSuchElementException(name);
- }
- return value;
- }
- });
- TransportContext context = new TransportContext(conf, new NoOpRpcHandler(), true);
- TransportClientFactory factory = context.createClientFactory();
- try {
- TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
- assertTrue(c1.isActive());
- long expiredTime = System.currentTimeMillis() + 10000; // 10 seconds
- while (c1.isActive() && System.currentTimeMillis() < expiredTime) {
- Thread.sleep(10);
- }
- assertFalse(c1.isActive());
- } finally {
- factory.close();
- }
- }
-}
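testClientReuse above fans out maxConnections * 10 threads, collects the distinct clients they receive, and asserts that the factory handed out no more than maxConnections of them and that nothing failed. A minimal sketch of that fan-out-and-join skeleton with the Spark calls stubbed out (the modulo expression below is only a placeholder for factory.createClient):

import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;

public class FanOutAndJoin {
  public static void main(String[] args) throws InterruptedException {
    final int maxConnections = 4;
    final Set<Integer> clients = Collections.synchronizedSet(new HashSet<Integer>());
    final AtomicInteger failed = new AtomicInteger();
    Thread[] attempts = new Thread[maxConnections * 10];

    for (int i = 0; i < attempts.length; i++) {
      final int attempt = i;
      attempts[i] = new Thread(() -> {
        try {
          clients.add(attempt % maxConnections); // placeholder for createClient(...)
        } catch (RuntimeException e) {
          failed.incrementAndGet();
        }
      });
      attempts[i].start();
    }
    for (Thread t : attempts) {
      t.join(); // wait for every attempt before asserting
    }
    System.out.println(failed.get() == 0 && clients.size() == maxConnections);
  }
}
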
diff --git a/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java b/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java
deleted file mode 100644
index 128f7cba74..0000000000
--- a/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java
+++ /dev/null
@@ -1,146 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network;
-
-import java.nio.ByteBuffer;
-
-import io.netty.channel.Channel;
-import io.netty.channel.local.LocalChannel;
-import org.junit.Test;
-
-import static org.junit.Assert.assertEquals;
-import static org.mockito.Matchers.any;
-import static org.mockito.Matchers.eq;
-import static org.mockito.Mockito.*;
-
-import org.apache.spark.network.buffer.ManagedBuffer;
-import org.apache.spark.network.buffer.NioManagedBuffer;
-import org.apache.spark.network.client.ChunkReceivedCallback;
-import org.apache.spark.network.client.RpcResponseCallback;
-import org.apache.spark.network.client.StreamCallback;
-import org.apache.spark.network.client.TransportResponseHandler;
-import org.apache.spark.network.protocol.ChunkFetchFailure;
-import org.apache.spark.network.protocol.ChunkFetchSuccess;
-import org.apache.spark.network.protocol.RpcFailure;
-import org.apache.spark.network.protocol.RpcResponse;
-import org.apache.spark.network.protocol.StreamChunkId;
-import org.apache.spark.network.protocol.StreamFailure;
-import org.apache.spark.network.protocol.StreamResponse;
-import org.apache.spark.network.util.TransportFrameDecoder;
-
-public class TransportResponseHandlerSuite {
- @Test
- public void handleSuccessfulFetch() throws Exception {
- StreamChunkId streamChunkId = new StreamChunkId(1, 0);
-
- TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel());
- ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
- handler.addFetchRequest(streamChunkId, callback);
- assertEquals(1, handler.numOutstandingRequests());
-
- handler.handle(new ChunkFetchSuccess(streamChunkId, new TestManagedBuffer(123)));
- verify(callback, times(1)).onSuccess(eq(0), (ManagedBuffer) any());
- assertEquals(0, handler.numOutstandingRequests());
- }
-
- @Test
- public void handleFailedFetch() throws Exception {
- StreamChunkId streamChunkId = new StreamChunkId(1, 0);
- TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel());
- ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
- handler.addFetchRequest(streamChunkId, callback);
- assertEquals(1, handler.numOutstandingRequests());
-
- handler.handle(new ChunkFetchFailure(streamChunkId, "some error msg"));
- verify(callback, times(1)).onFailure(eq(0), (Throwable) any());
- assertEquals(0, handler.numOutstandingRequests());
- }
-
- @Test
- public void clearAllOutstandingRequests() throws Exception {
- TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel());
- ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
- handler.addFetchRequest(new StreamChunkId(1, 0), callback);
- handler.addFetchRequest(new StreamChunkId(1, 1), callback);
- handler.addFetchRequest(new StreamChunkId(1, 2), callback);
- assertEquals(3, handler.numOutstandingRequests());
-
- handler.handle(new ChunkFetchSuccess(new StreamChunkId(1, 0), new TestManagedBuffer(12)));
- handler.exceptionCaught(new Exception("duh duh duhhhh"));
-
- // The exception should fail the two remaining outstanding fetches (chunk indices 1 and 2).
- verify(callback, times(1)).onSuccess(eq(0), (ManagedBuffer) any());
- verify(callback, times(1)).onFailure(eq(1), (Throwable) any());
- verify(callback, times(1)).onFailure(eq(2), (Throwable) any());
- assertEquals(0, handler.numOutstandingRequests());
- }
-
- @Test
- public void handleSuccessfulRPC() throws Exception {
- TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel());
- RpcResponseCallback callback = mock(RpcResponseCallback.class);
- handler.addRpcRequest(12345, callback);
- assertEquals(1, handler.numOutstandingRequests());
-
- // This response should be ignored.
- handler.handle(new RpcResponse(54321, new NioManagedBuffer(ByteBuffer.allocate(7))));
- assertEquals(1, handler.numOutstandingRequests());
-
- ByteBuffer resp = ByteBuffer.allocate(10);
- handler.handle(new RpcResponse(12345, new NioManagedBuffer(resp)));
- verify(callback, times(1)).onSuccess(eq(ByteBuffer.allocate(10)));
- assertEquals(0, handler.numOutstandingRequests());
- }
-
- @Test
- public void handleFailedRPC() throws Exception {
- TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel());
- RpcResponseCallback callback = mock(RpcResponseCallback.class);
- handler.addRpcRequest(12345, callback);
- assertEquals(1, handler.numOutstandingRequests());
-
- handler.handle(new RpcFailure(54321, "uh-oh!")); // should be ignored
- assertEquals(1, handler.numOutstandingRequests());
-
- handler.handle(new RpcFailure(12345, "oh no"));
- verify(callback, times(1)).onFailure((Throwable) any());
- assertEquals(0, handler.numOutstandingRequests());
- }
-
- @Test
- public void testActiveStreams() throws Exception {
- Channel c = new LocalChannel();
- c.pipeline().addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder());
- TransportResponseHandler handler = new TransportResponseHandler(c);
-
- StreamResponse response = new StreamResponse("stream", 1234L, null);
- StreamCallback cb = mock(StreamCallback.class);
- handler.addStreamCallback(cb);
- assertEquals(1, handler.numOutstandingRequests());
- handler.handle(response);
- assertEquals(1, handler.numOutstandingRequests());
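- // The stream stays outstanding while the frame decoder's interceptor is still feeding
- // data to the callback; only deactivateStream() marks it as finished.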
- handler.deactivateStream();
- assertEquals(0, handler.numOutstandingRequests());
-
- StreamFailure failure = new StreamFailure("stream", "uh-oh");
- handler.addStreamCallback(cb);
- assertEquals(1, handler.numOutstandingRequests());
- handler.handle(failure);
- assertEquals(0, handler.numOutstandingRequests());
- }
-}
diff --git a/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java b/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java
deleted file mode 100644
index fbbe4b7014..0000000000
--- a/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java
+++ /dev/null
@@ -1,157 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.protocol;
-
-import java.io.IOException;
-import java.nio.ByteBuffer;
-import java.nio.channels.WritableByteChannel;
-
-import io.netty.buffer.ByteBuf;
-import io.netty.buffer.Unpooled;
-import io.netty.channel.FileRegion;
-import io.netty.util.AbstractReferenceCounted;
-import org.junit.Test;
-import org.mockito.Mockito;
-
-import static org.junit.Assert.*;
-
-import org.apache.spark.network.TestManagedBuffer;
-import org.apache.spark.network.buffer.ManagedBuffer;
-import org.apache.spark.network.buffer.NettyManagedBuffer;
-import org.apache.spark.network.util.ByteArrayWritableChannel;
-
-public class MessageWithHeaderSuite {
-
- @Test
- public void testSingleWrite() throws Exception {
- testFileRegionBody(8, 8);
- }
-
- @Test
- public void testShortWrite() throws Exception {
- testFileRegionBody(8, 1);
- }
-
- @Test
- public void testByteBufBody() throws Exception {
- ByteBuf header = Unpooled.copyLong(42);
- ByteBuf bodyPassedToNettyManagedBuffer = Unpooled.copyLong(84);
- assertEquals(1, header.refCnt());
- assertEquals(1, bodyPassedToNettyManagedBuffer.refCnt());
- ManagedBuffer managedBuf = new NettyManagedBuffer(bodyPassedToNettyManagedBuffer);
-
- Object body = managedBuf.convertToNetty();
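- // convertToNetty() duplicates the underlying buffer and retains it, which is why the
- // original buffer now has two references.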
- assertEquals(2, bodyPassedToNettyManagedBuffer.refCnt());
- assertEquals(1, header.refCnt());
-
- MessageWithHeader msg = new MessageWithHeader(managedBuf, header, body, managedBuf.size());
- ByteBuf result = doWrite(msg, 1);
- assertEquals(msg.count(), result.readableBytes());
- assertEquals(42, result.readLong());
- assertEquals(84, result.readLong());
-
- assertTrue(msg.release());
- assertEquals(0, bodyPassedToNettyManagedBuffer.refCnt());
- assertEquals(0, header.refCnt());
- }
-
- @Test
- public void testDeallocateReleasesManagedBuffer() throws Exception {
- ByteBuf header = Unpooled.copyLong(42);
- ManagedBuffer managedBuf = Mockito.spy(new TestManagedBuffer(84));
- ByteBuf body = (ByteBuf) managedBuf.convertToNetty();
- assertEquals(2, body.refCnt());
- MessageWithHeader msg = new MessageWithHeader(managedBuf, header, body, body.readableBytes());
- assertTrue(msg.release());
- Mockito.verify(managedBuf, Mockito.times(1)).release();
- assertEquals(0, body.refCnt());
- }
-
- private void testFileRegionBody(int totalWrites, int writesPerCall) throws Exception {
- ByteBuf header = Unpooled.copyLong(42);
- int headerLength = header.readableBytes();
- TestFileRegion region = new TestFileRegion(totalWrites, writesPerCall);
- MessageWithHeader msg = new MessageWithHeader(null, header, region, region.count());
-
- ByteBuf result = doWrite(msg, totalWrites / writesPerCall);
- assertEquals(headerLength + region.count(), result.readableBytes());
- assertEquals(42, result.readLong());
- for (long i = 0; i < totalWrites; i++) {
- assertEquals(i, result.readLong());
- }
- assertTrue(msg.release());
- }
-
- private ByteBuf doWrite(MessageWithHeader msg, int minExpectedWrites) throws Exception {
- int writes = 0;
- ByteArrayWritableChannel channel = new ByteArrayWritableChannel((int) msg.count());
- while (msg.transfered() < msg.count()) {
- msg.transferTo(channel, msg.transfered());
- writes++;
- }
- assertTrue("Not enough writes!", minExpectedWrites <= writes);
- return Unpooled.wrappedBuffer(channel.getData());
- }
-
- private static class TestFileRegion extends AbstractReferenceCounted implements FileRegion {
-
- private final int writeCount;
- private final int writesPerCall;
- private int written;
-
- TestFileRegion(int totalWrites, int writesPerCall) {
- this.writeCount = totalWrites;
- this.writesPerCall = writesPerCall;
- }
-
- @Override
- public long count() {
- return 8 * writeCount;
- }
-
- @Override
- public long position() {
- return 0;
- }
-
- @Override
- public long transfered() {
- return 8 * written;
- }
-
- @Override
- public long transferTo(WritableByteChannel target, long position) throws IOException {
- for (int i = 0; i < writesPerCall; i++) {
- ByteBuf buf = Unpooled.copyLong((position / 8) + i);
- ByteBuffer nio = buf.nioBuffer();
- while (nio.remaining() > 0) {
- target.write(nio);
- }
- buf.release();
- written++;
- }
- return 8 * writesPerCall;
- }
-
- @Override
- protected void deallocate() {
- }
-
- }
-
-}
diff --git a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
deleted file mode 100644
index 045773317a..0000000000
--- a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
+++ /dev/null
@@ -1,476 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.sasl;
-
-import static org.junit.Assert.*;
-import static org.mockito.Mockito.*;
-
-import java.io.File;
-import java.lang.reflect.Method;
-import java.nio.ByteBuffer;
-import java.util.Arrays;
-import java.util.List;
-import java.util.Random;
-import java.util.concurrent.TimeoutException;
-import java.util.concurrent.TimeUnit;
-import java.util.concurrent.atomic.AtomicReference;
-import javax.security.sasl.SaslException;
-
-import com.google.common.collect.Lists;
-import com.google.common.io.ByteStreams;
-import com.google.common.io.Files;
-import io.netty.buffer.ByteBuf;
-import io.netty.buffer.Unpooled;
-import io.netty.channel.Channel;
-import io.netty.channel.ChannelHandlerContext;
-import io.netty.channel.ChannelOutboundHandlerAdapter;
-import io.netty.channel.ChannelPromise;
-import org.junit.Test;
-import org.mockito.invocation.InvocationOnMock;
-import org.mockito.stubbing.Answer;
-
-import org.apache.spark.network.TestUtils;
-import org.apache.spark.network.TransportContext;
-import org.apache.spark.network.buffer.FileSegmentManagedBuffer;
-import org.apache.spark.network.buffer.ManagedBuffer;
-import org.apache.spark.network.client.ChunkReceivedCallback;
-import org.apache.spark.network.client.RpcResponseCallback;
-import org.apache.spark.network.client.TransportClient;
-import org.apache.spark.network.client.TransportClientBootstrap;
-import org.apache.spark.network.server.RpcHandler;
-import org.apache.spark.network.server.StreamManager;
-import org.apache.spark.network.server.TransportServer;
-import org.apache.spark.network.server.TransportServerBootstrap;
-import org.apache.spark.network.util.ByteArrayWritableChannel;
-import org.apache.spark.network.util.JavaUtils;
-import org.apache.spark.network.util.SystemPropertyConfigProvider;
-import org.apache.spark.network.util.TransportConf;
-
-/**
- * Jointly tests SparkSaslClient and SparkSaslServer, as both are black boxes.
- */
-public class SparkSaslSuite {
-
- /** Provides a secret key holder which returns secret key == appId */
- private SecretKeyHolder secretKeyHolder = new SecretKeyHolder() {
- @Override
- public String getSaslUser(String appId) {
- return "user";
- }
-
- @Override
- public String getSecretKey(String appId) {
- return appId;
- }
- };
-
- @Test
- public void testMatching() {
- SparkSaslClient client = new SparkSaslClient("shared-secret", secretKeyHolder, false);
- SparkSaslServer server = new SparkSaslServer("shared-secret", secretKeyHolder, false);
-
- assertFalse(client.isComplete());
- assertFalse(server.isComplete());
-
- byte[] clientMessage = client.firstToken();
-
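- // Drive the SASL challenge/response loop: each server challenge is fed back into the
- // client until the client reports that negotiation is complete.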
- while (!client.isComplete()) {
- clientMessage = client.response(server.response(clientMessage));
- }
- assertTrue(server.isComplete());
-
- // Disposal should invalidate the completed negotiation.
- server.dispose();
- assertFalse(server.isComplete());
- client.dispose();
- assertFalse(client.isComplete());
- }
-
- @Test
- public void testNonMatching() {
- SparkSaslClient client = new SparkSaslClient("my-secret", secretKeyHolder, false);
- SparkSaslServer server = new SparkSaslServer("your-secret", secretKeyHolder, false);
-
- assertFalse(client.isComplete());
- assertFalse(server.isComplete());
-
- byte[] clientMessage = client.firstToken();
-
- try {
- while (!client.isComplete()) {
- clientMessage = client.response(server.response(clientMessage));
- }
- fail("Should not have completed");
- } catch (Exception e) {
- assertTrue(e.getMessage().contains("Mismatched response"));
- assertFalse(client.isComplete());
- assertFalse(server.isComplete());
- }
- }
-
- @Test
- public void testSaslAuthentication() throws Throwable {
- testBasicSasl(false);
- }
-
- @Test
- public void testSaslEncryption() throws Throwable {
- testBasicSasl(true);
- }
-
- private void testBasicSasl(boolean encrypt) throws Throwable {
- RpcHandler rpcHandler = mock(RpcHandler.class);
- doAnswer(new Answer<Void>() {
- @Override
- public Void answer(InvocationOnMock invocation) {
- ByteBuffer message = (ByteBuffer) invocation.getArguments()[1];
- RpcResponseCallback cb = (RpcResponseCallback) invocation.getArguments()[2];
- assertEquals("Ping", JavaUtils.bytesToString(message));
- cb.onSuccess(JavaUtils.stringToBytes("Pong"));
- return null;
- }
- })
- .when(rpcHandler)
- .receive(any(TransportClient.class), any(ByteBuffer.class), any(RpcResponseCallback.class));
-
- SaslTestCtx ctx = new SaslTestCtx(rpcHandler, encrypt, false);
- try {
- ByteBuffer response = ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"),
- TimeUnit.SECONDS.toMillis(10));
- assertEquals("Pong", JavaUtils.bytesToString(response));
- } finally {
- ctx.close();
- // There should be 2 terminated events; one for the client, one for the server.
- Throwable error = null;
- long deadline = System.nanoTime() + TimeUnit.NANOSECONDS.convert(10, TimeUnit.SECONDS);
- while (deadline > System.nanoTime()) {
- try {
- verify(rpcHandler, times(2)).channelInactive(any(TransportClient.class));
- error = null;
- break;
- } catch (Throwable t) {
- error = t;
- TimeUnit.MILLISECONDS.sleep(10);
- }
- }
- if (error != null) {
- throw error;
- }
- }
- }
-
- @Test
- public void testEncryptedMessage() throws Exception {
- SaslEncryptionBackend backend = mock(SaslEncryptionBackend.class);
- byte[] data = new byte[1024];
- new Random().nextBytes(data);
- when(backend.wrap(any(byte[].class), anyInt(), anyInt())).thenReturn(data);
-
- ByteBuf msg = Unpooled.buffer();
- try {
- msg.writeBytes(data);
-
- // Create a channel with a really small buffer compared to the data. This means that on
- // each call the outbound data will not be fully written, so transferTo() should return a
- // dummy count of 1 to keep the channel alive when possible.
- ByteArrayWritableChannel channel = new ByteArrayWritableChannel(32);
-
- SaslEncryption.EncryptedMessage emsg =
- new SaslEncryption.EncryptedMessage(backend, msg, 1024);
- long count = emsg.transferTo(channel, emsg.transfered());
- assertTrue(count < data.length);
- assertTrue(count > 0);
-
- // Here, the output buffer is full so nothing should be transferred.
- assertEquals(0, emsg.transferTo(channel, emsg.transfered()));
-
- // Now there's room in the buffer, but not enough to transfer all the remaining data,
- // so the dummy count should be returned.
- channel.reset();
- assertEquals(1, emsg.transferTo(channel, emsg.transfered()));
-
- // Eventually, the whole message should be transferred.
- for (int i = 0; i < data.length / 32 - 2; i++) {
- channel.reset();
- assertEquals(1, emsg.transferTo(channel, emsg.transfered()));
- }
-
- channel.reset();
- count = emsg.transferTo(channel, emsg.transfered());
- assertTrue("Unexpected count: " + count, count > 1 && count < data.length);
- assertEquals(data.length, emsg.transfered());
- } finally {
- msg.release();
- }
- }
-
- @Test
- public void testEncryptedMessageChunking() throws Exception {
- File file = File.createTempFile("sasltest", ".txt");
- try {
- TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider());
-
- byte[] data = new byte[8 * 1024];
- new Random().nextBytes(data);
- Files.write(data, file);
-
- SaslEncryptionBackend backend = mock(SaslEncryptionBackend.class);
- // It doesn't really matter what we return here, as long as it's not null.
- when(backend.wrap(any(byte[].class), anyInt(), anyInt())).thenReturn(data);
-
- FileSegmentManagedBuffer msg = new FileSegmentManagedBuffer(conf, file, 0, file.length());
- SaslEncryption.EncryptedMessage emsg =
- new SaslEncryption.EncryptedMessage(backend, msg.convertToNetty(), data.length / 8);
-
- ByteArrayWritableChannel channel = new ByteArrayWritableChannel(data.length);
- while (emsg.transfered() < emsg.count()) {
- channel.reset();
- emsg.transferTo(channel, emsg.transfered());
- }
-
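- // 8 KB of data with a max encrypted block size of data.length / 8 = 1 KB means the
- // message should be wrapped in exactly 8 chunks.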
- verify(backend, times(8)).wrap(any(byte[].class), anyInt(), anyInt());
- } finally {
- file.delete();
- }
- }
-
- @Test
- public void testFileRegionEncryption() throws Exception {
- final String blockSizeConf = "spark.network.sasl.maxEncryptedBlockSize";
- System.setProperty(blockSizeConf, "1k");
-
- final AtomicReference<ManagedBuffer> response = new AtomicReference<>();
- final File file = File.createTempFile("sasltest", ".txt");
- SaslTestCtx ctx = null;
- try {
- final TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider());
- StreamManager sm = mock(StreamManager.class);
- when(sm.getChunk(anyLong(), anyInt())).thenAnswer(new Answer<ManagedBuffer>() {
- @Override
- public ManagedBuffer answer(InvocationOnMock invocation) {
- return new FileSegmentManagedBuffer(conf, file, 0, file.length());
- }
- });
-
- RpcHandler rpcHandler = mock(RpcHandler.class);
- when(rpcHandler.getStreamManager()).thenReturn(sm);
-
- byte[] data = new byte[8 * 1024];
- new Random().nextBytes(data);
- Files.write(data, file);
-
- ctx = new SaslTestCtx(rpcHandler, true, false);
-
- final Object lock = new Object();
-
- ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
- doAnswer(new Answer<Void>() {
- @Override
- public Void answer(InvocationOnMock invocation) {
- response.set((ManagedBuffer) invocation.getArguments()[1]);
- response.get().retain();
- synchronized (lock) {
- lock.notifyAll();
- }
- return null;
- }
- }).when(callback).onSuccess(anyInt(), any(ManagedBuffer.class));
-
- synchronized (lock) {
- ctx.client.fetchChunk(0, 0, callback);
- lock.wait(10 * 1000);
- }
-
- verify(callback, times(1)).onSuccess(anyInt(), any(ManagedBuffer.class));
- verify(callback, never()).onFailure(anyInt(), any(Throwable.class));
-
- byte[] received = ByteStreams.toByteArray(response.get().createInputStream());
- assertTrue(Arrays.equals(data, received));
- } finally {
- file.delete();
- if (ctx != null) {
- ctx.close();
- }
- if (response.get() != null) {
- response.get().release();
- }
- System.clearProperty(blockSizeConf);
- }
- }
-
- @Test
- public void testServerAlwaysEncrypt() throws Exception {
- final String alwaysEncryptConfName = "spark.network.sasl.serverAlwaysEncrypt";
- System.setProperty(alwaysEncryptConfName, "true");
-
- SaslTestCtx ctx = null;
- try {
- ctx = new SaslTestCtx(mock(RpcHandler.class), false, false);
- fail("Should have failed to connect without encryption.");
- } catch (Exception e) {
- assertTrue(e.getCause() instanceof SaslException);
- } finally {
- if (ctx != null) {
- ctx.close();
- }
- System.clearProperty(alwaysEncryptConfName);
- }
- }
-
- @Test
- public void testDataEncryptionIsActuallyEnabled() throws Exception {
- // This test sets up an encrypted connection but then, using a client bootstrap, removes
- // the encryption handler from the client side. The server is then unable to decode the
- // RPCs sent to it and should close the connection.
- SaslTestCtx ctx = null;
- try {
- ctx = new SaslTestCtx(mock(RpcHandler.class), true, true);
- ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"),
- TimeUnit.SECONDS.toMillis(10));
- fail("Should have failed to send RPC to server.");
- } catch (Exception e) {
- assertFalse(e.getCause() instanceof TimeoutException);
- } finally {
- if (ctx != null) {
- ctx.close();
- }
- }
- }
-
- @Test
- public void testRpcHandlerDelegate() throws Exception {
- // Tests that all methods delegate, except for receive(), which is more involved and is
- // already exercised by the other tests.
- RpcHandler handler = mock(RpcHandler.class);
- RpcHandler saslHandler = new SaslRpcHandler(null, null, handler, null);
-
- saslHandler.getStreamManager();
- verify(handler).getStreamManager();
-
- saslHandler.channelInactive(null);
- verify(handler).channelInactive(any(TransportClient.class));
-
- saslHandler.exceptionCaught(null, null);
- verify(handler).exceptionCaught(any(Throwable.class), any(TransportClient.class));
- }
-
- @Test
- public void testDelegates() throws Exception {
- Method[] rpcHandlerMethods = RpcHandler.class.getDeclaredMethods();
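- // getDeclaredMethod() throws NoSuchMethodException (failing the test) unless
- // SaslRpcHandler itself overrides the method.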
- for (Method m : rpcHandlerMethods) {
- SaslRpcHandler.class.getDeclaredMethod(m.getName(), m.getParameterTypes());
- }
- }
-
- private static class SaslTestCtx {
-
- final TransportClient client;
- final TransportServer server;
-
- private final boolean encrypt;
- private final boolean disableClientEncryption;
- private final EncryptionCheckerBootstrap checker;
-
- SaslTestCtx(
- RpcHandler rpcHandler,
- boolean encrypt,
- boolean disableClientEncryption)
- throws Exception {
-
- TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider());
-
- SecretKeyHolder keyHolder = mock(SecretKeyHolder.class);
- when(keyHolder.getSaslUser(anyString())).thenReturn("user");
- when(keyHolder.getSecretKey(anyString())).thenReturn("secret");
-
- TransportContext ctx = new TransportContext(conf, rpcHandler);
-
- this.checker = new EncryptionCheckerBootstrap();
- this.server = ctx.createServer(Arrays.asList(new SaslServerBootstrap(conf, keyHolder),
- checker));
-
- try {
- List<TransportClientBootstrap> clientBootstraps = Lists.newArrayList();
- clientBootstraps.add(new SaslClientBootstrap(conf, "user", keyHolder, encrypt));
- if (disableClientEncryption) {
- clientBootstraps.add(new EncryptionDisablerBootstrap());
- }
-
- this.client = ctx.createClientFactory(clientBootstraps)
- .createClient(TestUtils.getLocalHost(), server.getPort());
- } catch (Exception e) {
- close();
- throw e;
- }
-
- this.encrypt = encrypt;
- this.disableClientEncryption = disableClientEncryption;
- }
-
- void close() {
- if (!disableClientEncryption) {
- assertEquals(encrypt, checker.foundEncryptionHandler);
- }
- if (client != null) {
- client.close();
- }
- if (server != null) {
- server.close();
- }
- }
-
- }
-
- private static class EncryptionCheckerBootstrap extends ChannelOutboundHandlerAdapter
- implements TransportServerBootstrap {
-
- boolean foundEncryptionHandler;
-
- @Override
- public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
- throws Exception {
- if (!foundEncryptionHandler) {
- foundEncryptionHandler =
- ctx.channel().pipeline().get(SaslEncryption.ENCRYPTION_HANDLER_NAME) != null;
- }
- ctx.write(msg, promise);
- }
-
- @Override
- public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
- super.handlerRemoved(ctx);
- }
-
- @Override
- public RpcHandler doBootstrap(Channel channel, RpcHandler rpcHandler) {
- channel.pipeline().addFirst("encryptionChecker", this);
- return rpcHandler;
- }
-
- }
-
- private static class EncryptionDisablerBootstrap implements TransportClientBootstrap {
-
- @Override
- public void doBootstrap(TransportClient client, Channel channel) {
- channel.pipeline().remove(SaslEncryption.ENCRYPTION_HANDLER_NAME);
- }
-
- }
-
-}
diff --git a/network/common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java b/network/common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java
deleted file mode 100644
index c647525d8f..0000000000
--- a/network/common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java
+++ /dev/null
@@ -1,50 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.server;
-
-import java.util.ArrayList;
-import java.util.List;
-
-import io.netty.channel.Channel;
-import org.junit.Test;
-import org.mockito.Mockito;
-
-import org.apache.spark.network.TestManagedBuffer;
-import org.apache.spark.network.buffer.ManagedBuffer;
-
-public class OneForOneStreamManagerSuite {
-
- @Test
- public void managedBuffersAreFreedWhenConnectionIsClosed() throws Exception {
- OneForOneStreamManager manager = new OneForOneStreamManager();
- List<ManagedBuffer> buffers = new ArrayList<>();
- TestManagedBuffer buffer1 = Mockito.spy(new TestManagedBuffer(10));
- TestManagedBuffer buffer2 = Mockito.spy(new TestManagedBuffer(20));
- buffers.add(buffer1);
- buffers.add(buffer2);
- long streamId = manager.registerStream("appId", buffers.iterator());
-
- Channel dummyChannel = Mockito.mock(Channel.class, Mockito.RETURNS_SMART_NULLS);
- manager.registerChannel(dummyChannel, streamId);
-
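- // Terminating the connection must release every buffer still associated with the
- // channel's registered streams.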
- manager.connectionTerminated(dummyChannel);
-
- Mockito.verify(buffer1, Mockito.times(1)).release();
- Mockito.verify(buffer2, Mockito.times(1)).release();
- }
-}
diff --git a/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java b/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java
deleted file mode 100644
index d4de4a941d..0000000000
--- a/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java
+++ /dev/null
@@ -1,258 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.util;
-
-import java.nio.ByteBuffer;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Random;
-import java.util.concurrent.atomic.AtomicInteger;
-
-import io.netty.buffer.ByteBuf;
-import io.netty.buffer.Unpooled;
-import io.netty.channel.ChannelHandlerContext;
-import org.junit.AfterClass;
-import org.junit.Test;
-import org.mockito.invocation.InvocationOnMock;
-import org.mockito.stubbing.Answer;
-import static org.junit.Assert.*;
-import static org.mockito.Mockito.*;
-
-public class TransportFrameDecoderSuite {
-
- private static Random RND = new Random();
-
- @AfterClass
- public static void cleanup() {
- RND = null;
- }
-
- @Test
- public void testFrameDecoding() throws Exception {
- TransportFrameDecoder decoder = new TransportFrameDecoder();
- ChannelHandlerContext ctx = mockChannelHandlerContext();
- ByteBuf data = createAndFeedFrames(100, decoder, ctx);
- verifyAndCloseDecoder(decoder, ctx, data);
- }
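-
- // Each frame on the wire is an 8-byte length that counts itself, followed by the body;
- // see createAndFeedFrames() and the encodeFrame() sketch further below.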
-
- @Test
- public void testInterception() throws Exception {
- final int interceptedReads = 3;
- TransportFrameDecoder decoder = new TransportFrameDecoder();
- TransportFrameDecoder.Interceptor interceptor = spy(new MockInterceptor(interceptedReads));
- ChannelHandlerContext ctx = mockChannelHandlerContext();
-
- byte[] data = new byte[8];
- ByteBuf len = Unpooled.copyLong(8 + data.length);
- ByteBuf dataBuf = Unpooled.wrappedBuffer(data);
-
- try {
- decoder.setInterceptor(interceptor);
- for (int i = 0; i < interceptedReads; i++) {
- decoder.channelRead(ctx, dataBuf);
- assertEquals(0, dataBuf.refCnt());
- dataBuf = Unpooled.wrappedBuffer(data);
- }
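- // After interceptedReads reads the interceptor deactivates itself (handle() returns
- // false), so normal frame decoding resumes below.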
- decoder.channelRead(ctx, len);
- decoder.channelRead(ctx, dataBuf);
- verify(interceptor, times(interceptedReads)).handle(any(ByteBuf.class));
- verify(ctx).fireChannelRead(any(ByteBuffer.class));
- assertEquals(0, len.refCnt());
- assertEquals(0, dataBuf.refCnt());
- } finally {
- release(len);
- release(dataBuf);
- }
- }
-
- @Test
- public void testRetainedFrames() throws Exception {
- TransportFrameDecoder decoder = new TransportFrameDecoder();
-
- final AtomicInteger count = new AtomicInteger();
- final List<ByteBuf> retained = new ArrayList<>();
-
- ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
- when(ctx.fireChannelRead(any())).thenAnswer(new Answer<Void>() {
- @Override
- public Void answer(InvocationOnMock in) {
- // Retain every other frame and release the rest.
- ByteBuf buf = (ByteBuf) in.getArguments()[0];
- if (count.incrementAndGet() % 2 == 0) {
- retained.add(buf);
- } else {
- buf.release();
- }
- return null;
- }
- });
-
- ByteBuf data = createAndFeedFrames(100, decoder, ctx);
- try {
- // Verify all retained buffers are readable.
- for (ByteBuf b : retained) {
- byte[] tmp = new byte[b.readableBytes()];
- b.readBytes(tmp);
- b.release();
- }
- verifyAndCloseDecoder(decoder, ctx, data);
- } finally {
- for (ByteBuf b : retained) {
- release(b);
- }
- }
- }
-
- @Test
- public void testSplitLengthField() throws Exception {
- byte[] frame = new byte[1024 * (RND.nextInt(31) + 1)];
- ByteBuf buf = Unpooled.buffer(frame.length + 8);
- buf.writeLong(frame.length + 8);
- buf.writeBytes(frame);
-
- TransportFrameDecoder decoder = new TransportFrameDecoder();
- ChannelHandlerContext ctx = mockChannelHandlerContext();
- try {
- decoder.channelRead(ctx, buf.readSlice(RND.nextInt(7)).retain());
- verify(ctx, never()).fireChannelRead(any(ByteBuf.class));
- decoder.channelRead(ctx, buf);
- verify(ctx).fireChannelRead(any(ByteBuf.class));
- assertEquals(0, buf.refCnt());
- } finally {
- decoder.channelInactive(ctx);
- release(buf);
- }
- }
-
- @Test(expected = IllegalArgumentException.class)
- public void testNegativeFrameSize() throws Exception {
- testInvalidFrame(-1);
- }
-
- @Test(expected = IllegalArgumentException.class)
- public void testEmptyFrame() throws Exception {
- // 8 because the frame size includes the 8-byte length field itself, so this frame is empty.
- testInvalidFrame(8);
- }
-
- @Test(expected = IllegalArgumentException.class)
- public void testLargeFrame() throws Exception {
- // Frame length includes the frame size field, so a few more bytes are needed. The long
- // literal (9L) matters: with int arithmetic the sum would overflow to a negative value
- // and exercise the negative-size path instead of the large-frame path.
- testInvalidFrame(Integer.MAX_VALUE + 9L);
- }
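-
- // A minimal sketch, not part of the original suite, showing how a valid frame is laid
- // out, for contrast with the invalid sizes above. The suite itself builds frames inline
- // in createAndFeedFrames().
- private static ByteBuf encodeFrame(byte[] body) {
- ByteBuf frame = Unpooled.buffer(8 + body.length);
- frame.writeLong(8 + body.length); // the frame size counts the length field itself
- frame.writeBytes(body);
- return frame;
- }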
-
- /**
- * Creates a number of randomly sized frames, feeds them to the given decoder, and
- * verifies that all of the frames were read.
- */
- private ByteBuf createAndFeedFrames(
- int frameCount,
- TransportFrameDecoder decoder,
- ChannelHandlerContext ctx) throws Exception {
- ByteBuf data = Unpooled.buffer();
- for (int i = 0; i < frameCount; i++) {
- byte[] frame = new byte[1024 * (RND.nextInt(31) + 1)];
- data.writeLong(frame.length + 8);
- data.writeBytes(frame);
- }
-
- try {
- while (data.isReadable()) {
- int size = RND.nextInt(4 * 1024) + 256;
- decoder.channelRead(ctx, data.readSlice(Math.min(data.readableBytes(), size)).retain());
- }
-
- verify(ctx, times(frameCount)).fireChannelRead(any(ByteBuf.class));
- } catch (Exception e) {
- release(data);
- throw e;
- }
- return data;
- }
-
- private void verifyAndCloseDecoder(
- TransportFrameDecoder decoder,
- ChannelHandlerContext ctx,
- ByteBuf data) throws Exception {
- try {
- decoder.channelInactive(ctx);
- assertTrue("There shouldn't be dangling references to the data.", data.release());
- } finally {
- release(data);
- }
- }
-
- private void testInvalidFrame(long size) throws Exception {
- TransportFrameDecoder decoder = new TransportFrameDecoder();
- ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
- ByteBuf frame = Unpooled.copyLong(size);
- try {
- decoder.channelRead(ctx, frame);
- } finally {
- release(frame);
- }
- }
-
- private ChannelHandlerContext mockChannelHandlerContext() {
- ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
- when(ctx.fireChannelRead(any())).thenAnswer(new Answer<Void>() {
- @Override
- public Void answer(InvocationOnMock in) {
- ByteBuf buf = (ByteBuf) in.getArguments()[0];
- buf.release();
- return null;
- }
- });
- return ctx;
- }
-
- private void release(ByteBuf buf) {
- if (buf.refCnt() > 0) {
- buf.release(buf.refCnt());
- }
- }
-
- private static class MockInterceptor implements TransportFrameDecoder.Interceptor {
-
- private int remainingReads;
-
- MockInterceptor(int readCount) {
- this.remainingReads = readCount;
- }
-
- @Override
- public boolean handle(ByteBuf data) throws Exception {
- data.readerIndex(data.readerIndex() + data.readableBytes());
- assertFalse(data.isReadable());
- remainingReads -= 1;
- return remainingReads != 0;
- }
-
- @Override
- public void exceptionCaught(Throwable cause) throws Exception {
-
- }
-
- @Override
- public void channelInactive() throws Exception {
-
- }
-
- }
-
-}