aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--common/network-common/src/main/java/org/apache/spark/network/TransportContext.java22
-rw-r--r--common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java4
-rw-r--r--common/network-common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java4
-rw-r--r--common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java11
-rw-r--r--common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java8
-rw-r--r--core/src/main/scala/org/apache/spark/util/Utils.scala16
6 files changed, 38 insertions, 27 deletions
diff --git a/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java b/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java
index 5b69e2bb03..37ba543380 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java
@@ -62,8 +62,20 @@ public class TransportContext {
private final RpcHandler rpcHandler;
private final boolean closeIdleConnections;
- private final MessageEncoder encoder;
- private final MessageDecoder decoder;
+ /**
+ * MessageEncoder and MessageDecoder are created eagerly here to make sure they will be created
+ * before switching the current context class loader to ExecutorClassLoader.
+ *
+ * Netty's MessageToMessageEncoder uses Javassist to generate a matcher class and the
+ * implementation calls "Class.forName" to check if this class is already generated. If the
+ * following two objects are created in "ExecutorClassLoader.findClass", it will cause
+ * "ClassCircularityError". This is because loading this Netty generated class will call
+ * "ExecutorClassLoader.findClass" to search for this class, and "ExecutorClassLoader" will try to
+ * use RPC to load it, which triggers loading the non-existent matcher class again. The JVM will report
+ * `ClassCircularityError` to prevent such infinite recursion. (See SPARK-17714)
+ */
+ private static final MessageEncoder ENCODER = MessageEncoder.INSTANCE;
+ private static final MessageDecoder DECODER = MessageDecoder.INSTANCE;
public TransportContext(TransportConf conf, RpcHandler rpcHandler) {
this(conf, rpcHandler, false);
@@ -75,8 +87,6 @@ public class TransportContext {
boolean closeIdleConnections) {
this.conf = conf;
this.rpcHandler = rpcHandler;
- this.encoder = new MessageEncoder();
- this.decoder = new MessageDecoder();
this.closeIdleConnections = closeIdleConnections;
}
@@ -135,9 +145,9 @@ public class TransportContext {
try {
TransportChannelHandler channelHandler = createChannelHandler(channel, channelRpcHandler);
channel.pipeline()
- .addLast("encoder", encoder)
+ .addLast("encoder", ENCODER)
.addLast(TransportFrameDecoder.HANDLER_NAME, NettyUtils.createFrameDecoder())
- .addLast("decoder", decoder)
+ .addLast("decoder", DECODER)
.addLast("idleStateHandler", new IdleStateHandler(0, 0, conf.connectionTimeoutMs() / 1000))
// NOTE: Chunks are currently guaranteed to be returned in the order of request, but this
// would require more logic to guarantee if this were not part of the same event loop.
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java
index f0956438ad..39a7495828 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java
@@ -35,6 +35,10 @@ public final class MessageDecoder extends MessageToMessageDecoder<ByteBuf> {
private static final Logger logger = LoggerFactory.getLogger(MessageDecoder.class);
+ public static final MessageDecoder INSTANCE = new MessageDecoder();
+
+ private MessageDecoder() {}
+
@Override
public void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) {
Message.Type msgType = Message.Type.decode(in);
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
index 276f16637e..997f74e1a2 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
@@ -35,6 +35,10 @@ public final class MessageEncoder extends MessageToMessageEncoder<Message> {
private static final Logger logger = LoggerFactory.getLogger(MessageEncoder.class);
+ public static final MessageEncoder INSTANCE = new MessageEncoder();
+
+ private MessageEncoder() {}
+
/***
* Encodes a Message by invoking its encode() method. For non-data messages, we will add one
* ByteBuf to 'out' containing the total frame length, the message type, and the message itself.
diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java
index c6ccae18b5..56782a8327 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java
@@ -18,7 +18,7 @@
package org.apache.spark.network.server;
import io.netty.channel.ChannelHandlerContext;
-import io.netty.channel.SimpleChannelInboundHandler;
+import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.timeout.IdleState;
import io.netty.handler.timeout.IdleStateEvent;
import org.slf4j.Logger;
@@ -26,7 +26,6 @@ import org.slf4j.LoggerFactory;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportResponseHandler;
-import org.apache.spark.network.protocol.Message;
import org.apache.spark.network.protocol.RequestMessage;
import org.apache.spark.network.protocol.ResponseMessage;
import static org.apache.spark.network.util.NettyUtils.getRemoteAddress;
@@ -48,7 +47,7 @@ import static org.apache.spark.network.util.NettyUtils.getRemoteAddress;
* on the channel for at least `requestTimeoutMs`. Note that this is duplex traffic; we will not
* timeout if the client is continuously sending but getting no responses, for simplicity.
*/
-public class TransportChannelHandler extends SimpleChannelInboundHandler<Message> {
+public class TransportChannelHandler extends ChannelInboundHandlerAdapter {
private static final Logger logger = LoggerFactory.getLogger(TransportChannelHandler.class);
private final TransportClient client;
@@ -114,11 +113,13 @@ public class TransportChannelHandler extends SimpleChannelInboundHandler<Message
}
@Override
- public void channelRead0(ChannelHandlerContext ctx, Message request) throws Exception {
+ public void channelRead(ChannelHandlerContext ctx, Object request) throws Exception {
if (request instanceof RequestMessage) {
requestHandler.handle((RequestMessage) request);
- } else {
+ } else if (request instanceof ResponseMessage) {
responseHandler.handle((ResponseMessage) request);
+ } else {
+ ctx.fireChannelRead(request);
}
}
diff --git a/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java
index 6c8dd742f4..bb1c40c4b0 100644
--- a/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java
+++ b/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java
@@ -49,11 +49,11 @@ import org.apache.spark.network.util.NettyUtils;
public class ProtocolSuite {
private void testServerToClient(Message msg) {
EmbeddedChannel serverChannel = new EmbeddedChannel(new FileRegionEncoder(),
- new MessageEncoder());
+ MessageEncoder.INSTANCE);
serverChannel.writeOutbound(msg);
EmbeddedChannel clientChannel = new EmbeddedChannel(
- NettyUtils.createFrameDecoder(), new MessageDecoder());
+ NettyUtils.createFrameDecoder(), MessageDecoder.INSTANCE);
while (!serverChannel.outboundMessages().isEmpty()) {
clientChannel.writeInbound(serverChannel.readOutbound());
@@ -65,11 +65,11 @@ public class ProtocolSuite {
private void testClientToServer(Message msg) {
EmbeddedChannel clientChannel = new EmbeddedChannel(new FileRegionEncoder(),
- new MessageEncoder());
+ MessageEncoder.INSTANCE);
clientChannel.writeOutbound(msg);
EmbeddedChannel serverChannel = new EmbeddedChannel(
- NettyUtils.createFrameDecoder(), new MessageDecoder());
+ NettyUtils.createFrameDecoder(), MessageDecoder.INSTANCE);
while (!clientChannel.outboundMessages().isEmpty()) {
serverChannel.writeInbound(clientChannel.readOutbound());
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index c225e1a0cc..fe6fe6aa4f 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -2608,12 +2608,8 @@ private[util] object CallerContext extends Logging {
val callerContextSupported: Boolean = {
SparkHadoopUtil.get.conf.getBoolean("hadoop.caller.context.enabled", false) && {
try {
- // `Utils.classForName` will make `ReplSuite` fail with `ClassCircularityError` in
- // master Maven build, so do not use it before resolving SPARK-17714.
- // scalastyle:off classforname
- Class.forName("org.apache.hadoop.ipc.CallerContext")
- Class.forName("org.apache.hadoop.ipc.CallerContext$Builder")
- // scalastyle:on classforname
+ Utils.classForName("org.apache.hadoop.ipc.CallerContext")
+ Utils.classForName("org.apache.hadoop.ipc.CallerContext$Builder")
true
} catch {
case _: ClassNotFoundException =>
@@ -2688,12 +2684,8 @@ private[spark] class CallerContext(
def setCurrentContext(): Unit = {
if (CallerContext.callerContextSupported) {
try {
- // `Utils.classForName` will make `ReplSuite` fail with `ClassCircularityError` in
- // master Maven build, so do not use it before resolving SPARK-17714.
- // scalastyle:off classforname
- val callerContext = Class.forName("org.apache.hadoop.ipc.CallerContext")
- val builder = Class.forName("org.apache.hadoop.ipc.CallerContext$Builder")
- // scalastyle:on classforname
+ val callerContext = Utils.classForName("org.apache.hadoop.ipc.CallerContext")
+ val builder = Utils.classForName("org.apache.hadoop.ipc.CallerContext$Builder")
val builderInst = builder.getConstructor(classOf[String]).newInstance(context)
val hdfsContext = builder.getMethod("build").invoke(builderInst)
callerContext.getMethod("setCurrent", callerContext).invoke(null, hdfsContext)