From 4e81783e92f464d479baaf93eccc3adb1496989a Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 25 Nov 2015 12:58:18 -0800 Subject: [SPARK-11866][NETWORK][CORE] Make sure timed out RPCs are cleaned up. This change does a couple of different things to make sure that the RpcEnv-level code and the network library agree about the status of outstanding RPCs. For RPCs that do not expect a reply ("RpcEnv.send"), support for one way messages (hello CORBA!) was added to the network layer. This is a "fire and forget" message that does not require any state to be kept by the TransportClient; as a result, the RpcEnv 'Ack' message is not needed anymore. For RPCs that do expect a reply ("RpcEnv.ask"), the network library now returns the internal RPC id; if the RpcEnv layer decides to time out the RPC before the network layer does, it now asks the TransportClient to forget about the RPC, so that if the network-level timeout occurs, the client is not killed. As part of implementing the above, I cleaned up some of the code in the netty rpc backend, removing types that were not necessary and factoring out some common code. Of interest is a slight change in the exceptions when posting messages to a stopped RpcEnv; that's mostly to avoid nasty error messages from the local-cluster backend when shutting down, which pollutes the terminal output. Author: Marcelo Vanzin Closes #9917 from vanzin/SPARK-11866. --- .../spark/network/client/TransportClient.java | 34 +++++++++- .../org/apache/spark/network/protocol/Message.java | 4 +- .../spark/network/protocol/MessageDecoder.java | 3 + .../spark/network/protocol/OneWayMessage.java | 75 ++++++++++++++++++++++ .../apache/spark/network/sasl/SaslRpcHandler.java | 5 ++ .../apache/spark/network/server/RpcHandler.java | 36 +++++++++++ .../network/server/TransportRequestHandler.java | 18 +++++- .../org/apache/spark/network/ProtocolSuite.java | 2 + .../apache/spark/network/RpcIntegrationSuite.java | 31 +++++++++ .../apache/spark/network/sasl/SparkSaslSuite.java | 9 +++ 10 files changed, 212 insertions(+), 5 deletions(-) create mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java (limited to 'network') diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java index 876fcd8467..8a58e7b245 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -25,6 +25,7 @@ import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import javax.annotation.Nullable; +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Objects; import com.google.common.base.Preconditions; import com.google.common.base.Throwables; @@ -36,6 +37,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.network.protocol.ChunkFetchRequest; +import org.apache.spark.network.protocol.OneWayMessage; import org.apache.spark.network.protocol.RpcRequest; import org.apache.spark.network.protocol.StreamChunkId; import org.apache.spark.network.protocol.StreamRequest; @@ -205,8 +207,12 @@ public class TransportClient implements Closeable { /** * Sends an opaque message to the RpcHandler on the server-side. The callback will be invoked * with the server's response or upon any failure. + * + * @param message The message to send. + * @param callback Callback to handle the RPC's reply. + * @return The RPC's id. */ - public void sendRpc(byte[] message, final RpcResponseCallback callback) { + public long sendRpc(byte[] message, final RpcResponseCallback callback) { final String serverAddr = NettyUtils.getRemoteAddress(channel); final long startTime = System.currentTimeMillis(); logger.trace("Sending RPC to {}", serverAddr); @@ -235,6 +241,8 @@ public class TransportClient implements Closeable { } } }); + + return requestId; } /** @@ -265,11 +273,35 @@ public class TransportClient implements Closeable { } } + /** + * Sends an opaque message to the RpcHandler on the server-side. No reply is expected for the + * message, and no delivery guarantees are made. + * + * @param message The message to send. + */ + public void send(byte[] message) { + channel.writeAndFlush(new OneWayMessage(message)); + } + + /** + * Removes any state associated with the given RPC. + * + * @param requestId The RPC id returned by {@link #sendRpc(byte[], RpcResponseCallback)}. + */ + public void removeRpcRequest(long requestId) { + handler.removeRpcRequest(requestId); + } + /** Mark this channel as having timed out. */ public void timeOut() { this.timedOut = true; } + @VisibleForTesting + public TransportResponseHandler getHandler() { + return handler; + } + @Override public void close() { // close is a local operation and should finish with milliseconds; timeout just to be safe diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Message.java b/network/common/src/main/java/org/apache/spark/network/protocol/Message.java index d01598c20f..39afd03db6 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/Message.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/Message.java @@ -28,7 +28,8 @@ public interface Message extends Encodable { public static enum Type implements Encodable { ChunkFetchRequest(0), ChunkFetchSuccess(1), ChunkFetchFailure(2), RpcRequest(3), RpcResponse(4), RpcFailure(5), - StreamRequest(6), StreamResponse(7), StreamFailure(8); + StreamRequest(6), StreamResponse(7), StreamFailure(8), + OneWayMessage(9); private final byte id; @@ -55,6 +56,7 @@ public interface Message extends Encodable { case 6: return StreamRequest; case 7: return StreamResponse; case 8: return StreamFailure; + case 9: return OneWayMessage; default: throw new IllegalArgumentException("Unknown message type: " + id); } } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java index 3c04048f38..074780f2b9 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java @@ -63,6 +63,9 @@ public final class MessageDecoder extends MessageToMessageDecoder { case RpcFailure: return RpcFailure.decode(in); + case OneWayMessage: + return OneWayMessage.decode(in); + case StreamRequest: return StreamRequest.decode(in); diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java b/network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java new file mode 100644 index 0000000000..95a0270be3 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import java.util.Arrays; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +/** + * A RPC that does not expect a reply, which is handled by a remote + * {@link org.apache.spark.network.server.RpcHandler}. + */ +public final class OneWayMessage implements RequestMessage { + /** Serialized message to send to remote RpcHandler. */ + public final byte[] message; + + public OneWayMessage(byte[] message) { + this.message = message; + } + + @Override + public Type type() { return Type.OneWayMessage; } + + @Override + public int encodedLength() { + return Encoders.ByteArrays.encodedLength(message); + } + + @Override + public void encode(ByteBuf buf) { + Encoders.ByteArrays.encode(buf, message); + } + + public static OneWayMessage decode(ByteBuf buf) { + byte[] message = Encoders.ByteArrays.decode(buf); + return new OneWayMessage(message); + } + + @Override + public int hashCode() { + return Arrays.hashCode(message); + } + + @Override + public boolean equals(Object other) { + if (other instanceof OneWayMessage) { + OneWayMessage o = (OneWayMessage) other; + return Arrays.equals(message, o.message); + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("message", message) + .toString(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java index 7033adb9ca..830db94b89 100644 --- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java @@ -108,6 +108,11 @@ class SaslRpcHandler extends RpcHandler { } } + @Override + public void receive(TransportClient client, byte[] message) { + delegate.receive(client, message); + } + @Override public StreamManager getStreamManager() { return delegate.getStreamManager(); diff --git a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java index dbb7f95f55..65109ddfe1 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java @@ -17,6 +17,9 @@ package org.apache.spark.network.server; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; @@ -24,6 +27,9 @@ import org.apache.spark.network.client.TransportClient; * Handler for sendRPC() messages sent by {@link org.apache.spark.network.client.TransportClient}s. */ public abstract class RpcHandler { + + private static final RpcResponseCallback ONE_WAY_CALLBACK = new OneWayRpcCallback(); + /** * Receive a single RPC message. Any exception thrown while in this method will be sent back to * the client in string form as a standard RPC failure. @@ -47,6 +53,19 @@ public abstract class RpcHandler { */ public abstract StreamManager getStreamManager(); + /** + * Receives an RPC message that does not expect a reply. The default implementation will + * call "{@link receive(TransportClient, byte[], RpcResponseCallback}" and log a warning if + * any of the callback methods are called. + * + * @param client A channel client which enables the handler to make requests back to the sender + * of this RPC. This will always be the exact same object for a particular channel. + * @param message The serialized bytes of the RPC. + */ + public void receive(TransportClient client, byte[] message) { + receive(client, message, ONE_WAY_CALLBACK); + } + /** * Invoked when the connection associated with the given client has been invalidated. * No further requests will come from this client. @@ -54,4 +73,21 @@ public abstract class RpcHandler { public void connectionTerminated(TransportClient client) { } public void exceptionCaught(Throwable cause, TransportClient client) { } + + private static class OneWayRpcCallback implements RpcResponseCallback { + + private final Logger logger = LoggerFactory.getLogger(OneWayRpcCallback.class); + + @Override + public void onSuccess(byte[] response) { + logger.warn("Response provided for one-way RPC."); + } + + @Override + public void onFailure(Throwable e) { + logger.error("Error response provided for one-way RPC.", e); + } + + } + } diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index 4f67bd573b..db18ea77d1 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -17,6 +17,7 @@ package org.apache.spark.network.server; +import com.google.common.base.Preconditions; import com.google.common.base.Throwables; import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; @@ -27,13 +28,14 @@ import org.slf4j.LoggerFactory; import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.protocol.Encodable; -import org.apache.spark.network.protocol.RequestMessage; import org.apache.spark.network.protocol.ChunkFetchRequest; -import org.apache.spark.network.protocol.RpcRequest; import org.apache.spark.network.protocol.ChunkFetchFailure; import org.apache.spark.network.protocol.ChunkFetchSuccess; +import org.apache.spark.network.protocol.Encodable; +import org.apache.spark.network.protocol.OneWayMessage; +import org.apache.spark.network.protocol.RequestMessage; import org.apache.spark.network.protocol.RpcFailure; +import org.apache.spark.network.protocol.RpcRequest; import org.apache.spark.network.protocol.RpcResponse; import org.apache.spark.network.protocol.StreamFailure; import org.apache.spark.network.protocol.StreamRequest; @@ -95,6 +97,8 @@ public class TransportRequestHandler extends MessageHandler { processFetchRequest((ChunkFetchRequest) request); } else if (request instanceof RpcRequest) { processRpcRequest((RpcRequest) request); + } else if (request instanceof OneWayMessage) { + processOneWayMessage((OneWayMessage) request); } else if (request instanceof StreamRequest) { processStreamRequest((StreamRequest) request); } else { @@ -156,6 +160,14 @@ public class TransportRequestHandler extends MessageHandler { } } + private void processOneWayMessage(OneWayMessage req) { + try { + rpcHandler.receive(reverseClient, req.message); + } catch (Exception e) { + logger.error("Error while invoking RpcHandler#receive() for one-way message.", e); + } + } + /** * Responds to a single message with some Encodable object. If a failure occurs while sending, * it will be logged and the channel closed. diff --git a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java index 22b451fc0e..1aa20900ff 100644 --- a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java @@ -35,6 +35,7 @@ import org.apache.spark.network.protocol.ChunkFetchSuccess; import org.apache.spark.network.protocol.Message; import org.apache.spark.network.protocol.MessageDecoder; import org.apache.spark.network.protocol.MessageEncoder; +import org.apache.spark.network.protocol.OneWayMessage; import org.apache.spark.network.protocol.RpcFailure; import org.apache.spark.network.protocol.RpcRequest; import org.apache.spark.network.protocol.RpcResponse; @@ -84,6 +85,7 @@ public class ProtocolSuite { testClientToServer(new RpcRequest(12345, new byte[0])); testClientToServer(new RpcRequest(12345, new byte[100])); testClientToServer(new StreamRequest("abcde")); + testClientToServer(new OneWayMessage(new byte[100])); } @Test diff --git a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java index 8eb56bdd98..88fa2258bb 100644 --- a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java @@ -17,9 +17,11 @@ package org.apache.spark.network; +import java.util.ArrayList; import java.util.Collections; import java.util.HashSet; import java.util.Iterator; +import java.util.List; import java.util.Set; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; @@ -46,6 +48,7 @@ public class RpcIntegrationSuite { static TransportServer server; static TransportClientFactory clientFactory; static RpcHandler rpcHandler; + static List oneWayMsgs; @BeforeClass public static void setUp() throws Exception { @@ -64,12 +67,19 @@ public class RpcIntegrationSuite { } } + @Override + public void receive(TransportClient client, byte[] message) { + String msg = new String(message, Charsets.UTF_8); + oneWayMsgs.add(msg); + } + @Override public StreamManager getStreamManager() { return new OneForOneStreamManager(); } }; TransportContext context = new TransportContext(conf, rpcHandler); server = context.createServer(); clientFactory = context.createClientFactory(); + oneWayMsgs = new ArrayList<>(); } @AfterClass @@ -158,6 +168,27 @@ public class RpcIntegrationSuite { assertErrorsContain(res.errorMessages, Sets.newHashSet("Thrown: the", "Returned: !")); } + @Test + public void sendOneWayMessage() throws Exception { + final String message = "no reply"; + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + try { + client.send(message.getBytes(Charsets.UTF_8)); + assertEquals(0, client.getHandler().numOutstandingRequests()); + + // Make sure the message arrives. + long deadline = System.nanoTime() + TimeUnit.NANOSECONDS.convert(10, TimeUnit.SECONDS); + while (System.nanoTime() < deadline && oneWayMsgs.size() == 0) { + TimeUnit.MILLISECONDS.sleep(10); + } + + assertEquals(1, oneWayMsgs.size()); + assertEquals(message, oneWayMsgs.get(0)); + } finally { + client.close(); + } + } + private void assertErrorsContain(Set errors, Set contains) { assertEquals(contains.size(), errors.size()); diff --git a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java index b146899670..a6f180bc40 100644 --- a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -21,6 +21,7 @@ import static org.junit.Assert.*; import static org.mockito.Mockito.*; import java.io.File; +import java.lang.reflect.Method; import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.List; @@ -353,6 +354,14 @@ public class SparkSaslSuite { verify(handler).exceptionCaught(any(Throwable.class), any(TransportClient.class)); } + @Test + public void testDelegates() throws Exception { + Method[] rpcHandlerMethods = RpcHandler.class.getDeclaredMethods(); + for (Method m : rpcHandlerMethods) { + SaslRpcHandler.class.getDeclaredMethod(m.getName(), m.getParameterTypes()); + } + } + private static class SaslTestCtx { final TransportClient client; -- cgit v1.2.3