aboutsummaryrefslogtreecommitdiff
path: root/common/network-common
diff options
context:
space:
mode:
Diffstat (limited to 'common/network-common')
-rw-r--r--common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java94
-rw-r--r--common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java13
2 files changed, 44 insertions, 63 deletions
diff --git a/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java
index c0ff9dc5f5..dd0171d1d1 100644
--- a/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java
+++ b/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java
@@ -36,6 +36,7 @@ import static org.junit.Assert.*;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.*;
+import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
@@ -93,7 +94,7 @@ public class RequestTimeoutIntegrationSuite {
ByteBuffer message,
RpcResponseCallback callback) {
try {
- semaphore.tryAcquire(FOREVER, TimeUnit.MILLISECONDS);
+ semaphore.acquire();
callback.onSuccess(ByteBuffer.allocate(responseSize));
} catch (InterruptedException e) {
// do nothing
@@ -113,20 +114,17 @@ public class RequestTimeoutIntegrationSuite {
// First completes quickly (semaphore starts at 1).
TestCallback callback0 = new TestCallback();
- synchronized (callback0) {
- client.sendRpc(ByteBuffer.allocate(0), callback0);
- callback0.wait(FOREVER);
- assertEquals(responseSize, callback0.successLength);
- }
+ client.sendRpc(ByteBuffer.allocate(0), callback0);
+ callback0.latch.await();
+ assertEquals(responseSize, callback0.successLength);
// Second times out after 2 seconds, with slack. Must be IOException.
TestCallback callback1 = new TestCallback();
- synchronized (callback1) {
- client.sendRpc(ByteBuffer.allocate(0), callback1);
- callback1.wait(4 * 1000);
- assertNotNull(callback1.failure);
- assertTrue(callback1.failure instanceof IOException);
- }
+ client.sendRpc(ByteBuffer.allocate(0), callback1);
+ callback1.latch.await(4, TimeUnit.SECONDS);
+ assertNotNull(callback1.failure);
+ assertTrue(callback1.failure instanceof IOException);
+
semaphore.release();
}
@@ -143,7 +141,7 @@ public class RequestTimeoutIntegrationSuite {
ByteBuffer message,
RpcResponseCallback callback) {
try {
- semaphore.tryAcquire(FOREVER, TimeUnit.MILLISECONDS);
+ semaphore.acquire();
callback.onSuccess(ByteBuffer.allocate(responseSize));
} catch (InterruptedException e) {
// do nothing
@@ -164,24 +162,20 @@ public class RequestTimeoutIntegrationSuite {
TransportClient client0 =
clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
TestCallback callback0 = new TestCallback();
- synchronized (callback0) {
- client0.sendRpc(ByteBuffer.allocate(0), callback0);
- callback0.wait(FOREVER);
- assertTrue(callback0.failure instanceof IOException);
- assertFalse(client0.isActive());
- }
+ client0.sendRpc(ByteBuffer.allocate(0), callback0);
+ callback0.latch.await();
+ assertTrue(callback0.failure instanceof IOException);
+ assertFalse(client0.isActive());
// Increment the semaphore and the second request should succeed quickly.
semaphore.release(2);
TransportClient client1 =
clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
TestCallback callback1 = new TestCallback();
- synchronized (callback1) {
- client1.sendRpc(ByteBuffer.allocate(0), callback1);
- callback1.wait(FOREVER);
- assertEquals(responseSize, callback1.successLength);
- assertNull(callback1.failure);
- }
+ client1.sendRpc(ByteBuffer.allocate(0), callback1);
+ callback1.latch.await();
+ assertEquals(responseSize, callback1.successLength);
+ assertNull(callback1.failure);
}
// The timeout is relative to the LAST request sent, which is kinda weird, but still.
@@ -226,18 +220,14 @@ public class RequestTimeoutIntegrationSuite {
client.fetchChunk(0, 1, callback1);
Uninterruptibles.sleepUninterruptibly(1200, TimeUnit.MILLISECONDS);
- synchronized (callback0) {
- // not complete yet, but should complete soon
- assertEquals(-1, callback0.successLength);
- assertNull(callback0.failure);
- callback0.wait(2 * 1000);
- assertTrue(callback0.failure instanceof IOException);
- }
+ // not complete yet, but should complete soon
+ assertEquals(-1, callback0.successLength);
+ assertNull(callback0.failure);
+ callback0.latch.await(2, TimeUnit.SECONDS);
+ assertTrue(callback0.failure instanceof IOException);
- synchronized (callback1) {
- // failed at same time as previous
- assertTrue(callback0.failure instanceof IOException);
- }
+ // failed at same time as previous
+ assertTrue(callback1.failure instanceof IOException);
}
/**
@@ -248,41 +238,35 @@ public class RequestTimeoutIntegrationSuite {
int successLength = -1;
Throwable failure;
+ final CountDownLatch latch = new CountDownLatch(1);
@Override
public void onSuccess(ByteBuffer response) {
- synchronized(this) {
- successLength = response.remaining();
- this.notifyAll();
- }
+ successLength = response.remaining();
+ latch.countDown();
}
@Override
public void onFailure(Throwable e) {
- synchronized(this) {
- failure = e;
- this.notifyAll();
- }
+ failure = e;
+ latch.countDown();
}
@Override
public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
- synchronized(this) {
- try {
- successLength = buffer.nioByteBuffer().remaining();
- this.notifyAll();
- } catch (IOException e) {
- // weird
- }
+ try {
+ successLength = buffer.nioByteBuffer().remaining();
+ } catch (IOException e) {
+ // weird
+ } finally {
+ latch.countDown();
}
}
@Override
public void onFailure(int chunkIndex, Throwable e) {
- synchronized(this) {
- failure = e;
- this.notifyAll();
- }
+ failure = e;
+ latch.countDown();
}
}
}
diff --git a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
index 045773317a..45cc03df43 100644
--- a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
+++ b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
@@ -26,6 +26,7 @@ import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
+import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
@@ -276,7 +277,7 @@ public class SparkSaslSuite {
ctx = new SaslTestCtx(rpcHandler, true, false);
- final Object lock = new Object();
+ final CountDownLatch lock = new CountDownLatch(1);
ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
doAnswer(new Answer<Void>() {
@@ -284,17 +285,13 @@ public class SparkSaslSuite {
public Void answer(InvocationOnMock invocation) {
response.set((ManagedBuffer) invocation.getArguments()[1]);
response.get().retain();
- synchronized (lock) {
- lock.notifyAll();
- }
+ lock.countDown();
return null;
}
}).when(callback).onSuccess(anyInt(), any(ManagedBuffer.class));
- synchronized (lock) {
- ctx.client.fetchChunk(0, 0, callback);
- lock.wait(10 * 1000);
- }
+ ctx.client.fetchChunk(0, 0, callback);
+ lock.await(10, TimeUnit.SECONDS);
verify(callback, times(1)).onSuccess(anyInt(), any(ManagedBuffer.class));
verify(callback, never()).onFailure(anyInt(), any(Throwable.class));