From 9e01dcc6446f8648e61062f8afe62589b9d4b5ab Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 28 Feb 2016 17:25:07 -0800 Subject: [SPARK-13529][BUILD] Move network/* modules into common/network-* ## What changes were proposed in this pull request? As the title says, this moves the three modules currently in network/ into common/network-*. This removes one top level, non-user-facing folder. ## How was this patch tested? Compilation and existing tests. We should run both SBT and Maven. Author: Reynold Xin Closes #11409 from rxin/SPARK-13529. --- network/common/pom.xml | 103 ----- .../org/apache/spark/network/TransportContext.java | 166 ------- .../network/buffer/FileSegmentManagedBuffer.java | 154 ------- .../spark/network/buffer/LazyFileRegion.java | 111 ----- .../apache/spark/network/buffer/ManagedBuffer.java | 75 ---- .../spark/network/buffer/NettyManagedBuffer.java | 76 ---- .../spark/network/buffer/NioManagedBuffer.java | 75 ---- .../network/client/ChunkFetchFailureException.java | 31 -- .../network/client/ChunkReceivedCallback.java | 47 -- .../spark/network/client/RpcResponseCallback.java | 32 -- .../spark/network/client/StreamCallback.java | 40 -- .../spark/network/client/StreamInterceptor.java | 86 ---- .../spark/network/client/TransportClient.java | 321 -------------- .../network/client/TransportClientBootstrap.java | 34 -- .../network/client/TransportClientFactory.java | 264 ------------ .../network/client/TransportResponseHandler.java | 251 ----------- .../spark/network/protocol/AbstractMessage.java | 54 --- .../network/protocol/AbstractResponseMessage.java | 32 -- .../spark/network/protocol/ChunkFetchFailure.java | 76 ---- .../spark/network/protocol/ChunkFetchRequest.java | 71 --- .../spark/network/protocol/ChunkFetchSuccess.java | 89 ---- .../apache/spark/network/protocol/Encodable.java | 41 -- .../apache/spark/network/protocol/Encoders.java | 92 ---- .../org/apache/spark/network/protocol/Message.java | 73 ---- .../spark/network/protocol/MessageDecoder.java | 82 ---- .../spark/network/protocol/MessageEncoder.java | 93 ---- .../spark/network/protocol/MessageWithHeader.java | 135 ------ .../spark/network/protocol/OneWayMessage.java | 80 ---- .../spark/network/protocol/RequestMessage.java | 25 -- .../spark/network/protocol/ResponseMessage.java | 25 -- .../apache/spark/network/protocol/RpcFailure.java | 74 ---- .../apache/spark/network/protocol/RpcRequest.java | 87 ---- .../apache/spark/network/protocol/RpcResponse.java | 87 ---- .../spark/network/protocol/StreamChunkId.java | 73 ---- .../spark/network/protocol/StreamFailure.java | 80 ---- .../spark/network/protocol/StreamRequest.java | 78 ---- .../spark/network/protocol/StreamResponse.java | 92 ---- .../spark/network/sasl/SaslClientBootstrap.java | 109 ----- .../apache/spark/network/sasl/SaslEncryption.java | 291 ------------- .../spark/network/sasl/SaslEncryptionBackend.java | 33 -- .../org/apache/spark/network/sasl/SaslMessage.java | 78 ---- .../apache/spark/network/sasl/SaslRpcHandler.java | 158 ------- .../spark/network/sasl/SaslServerBootstrap.java | 49 --- .../apache/spark/network/sasl/SecretKeyHolder.java | 35 -- .../apache/spark/network/sasl/SparkSaslClient.java | 162 ------- .../apache/spark/network/sasl/SparkSaslServer.java | 200 --------- .../spark/network/server/MessageHandler.java | 39 -- .../spark/network/server/NoOpRpcHandler.java | 40 -- .../network/server/OneForOneStreamManager.java | 143 ------- .../apache/spark/network/server/RpcHandler.java | 100 ----- .../apache/spark/network/server/StreamManager.java | 86 ---- .../network/server/TransportChannelHandler.java | 163 ------- .../network/server/TransportRequestHandler.java | 209 --------- .../spark/network/server/TransportServer.java | 151 ------- .../network/server/TransportServerBootstrap.java | 36 -- .../network/util/ByteArrayWritableChannel.java | 69 --- .../org/apache/spark/network/util/ByteUnit.java | 67 --- .../apache/spark/network/util/ConfigProvider.java | 52 --- .../java/org/apache/spark/network/util/IOMode.java | 27 -- .../org/apache/spark/network/util/JavaUtils.java | 303 ------------- .../spark/network/util/LimitedInputStream.java | 105 ----- .../spark/network/util/MapConfigProvider.java | 41 -- .../org/apache/spark/network/util/NettyUtils.java | 139 ------ .../network/util/SystemPropertyConfigProvider.java | 34 -- .../apache/spark/network/util/TransportConf.java | 169 -------- .../spark/network/util/TransportFrameDecoder.java | 227 ---------- .../spark/network/ChunkFetchIntegrationSuite.java | 244 ----------- .../org/apache/spark/network/ProtocolSuite.java | 127 ------ .../network/RequestTimeoutIntegrationSuite.java | 288 ------------- .../apache/spark/network/RpcIntegrationSuite.java | 215 ---------- .../java/org/apache/spark/network/StreamSuite.java | 349 --------------- .../apache/spark/network/TestManagedBuffer.java | 109 ----- .../java/org/apache/spark/network/TestUtils.java | 30 -- .../spark/network/TransportClientFactorySuite.java | 214 --------- .../network/TransportResponseHandlerSuite.java | 146 ------- .../network/protocol/MessageWithHeaderSuite.java | 157 ------- .../apache/spark/network/sasl/SparkSaslSuite.java | 476 --------------------- .../server/OneForOneStreamManagerSuite.java | 50 --- .../network/util/TransportFrameDecoderSuite.java | 258 ----------- network/common/src/test/resources/log4j.properties | 27 -- network/shuffle/pom.xml | 101 ----- .../spark/network/sasl/ShuffleSecretManager.java | 97 ----- .../network/shuffle/BlockFetchingListener.java | 36 -- .../shuffle/ExternalShuffleBlockHandler.java | 140 ------ .../shuffle/ExternalShuffleBlockResolver.java | 449 ------------------- .../network/shuffle/ExternalShuffleClient.java | 154 ------- .../network/shuffle/OneForOneBlockFetcher.java | 129 ------ .../network/shuffle/RetryingBlockFetcher.java | 234 ---------- .../spark/network/shuffle/ShuffleClient.java | 44 -- .../shuffle/mesos/MesosExternalShuffleClient.java | 73 ---- .../shuffle/protocol/BlockTransferMessage.java | 81 ---- .../shuffle/protocol/ExecutorShuffleInfo.java | 94 ---- .../spark/network/shuffle/protocol/OpenBlocks.java | 90 ---- .../network/shuffle/protocol/RegisterExecutor.java | 94 ---- .../network/shuffle/protocol/StreamHandle.java | 81 ---- .../network/shuffle/protocol/UploadBlock.java | 117 ----- .../shuffle/protocol/mesos/RegisterDriver.java | 63 --- .../spark/network/sasl/SaslIntegrationSuite.java | 294 ------------- .../shuffle/BlockTransferMessagesSuite.java | 44 -- .../shuffle/ExternalShuffleBlockHandlerSuite.java | 127 ------ .../shuffle/ExternalShuffleBlockResolverSuite.java | 156 ------- .../shuffle/ExternalShuffleCleanupSuite.java | 149 ------- .../shuffle/ExternalShuffleIntegrationSuite.java | 301 ------------- .../shuffle/ExternalShuffleSecuritySuite.java | 124 ------ .../shuffle/OneForOneBlockFetcherSuite.java | 176 -------- .../network/shuffle/RetryingBlockFetcherSuite.java | 313 -------------- .../network/shuffle/TestShuffleDataContext.java | 117 ----- network/yarn/pom.xml | 148 ------- .../spark/network/yarn/YarnShuffleService.java | 224 ---------- .../network/yarn/util/HadoopConfigProvider.java | 42 -- 110 files changed, 13702 deletions(-) delete mode 100644 network/common/pom.xml delete mode 100644 network/common/src/main/java/org/apache/spark/network/TransportContext.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/buffer/LazyFileRegion.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/client/StreamCallback.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/client/TransportClient.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/AbstractMessage.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/AbstractResponseMessage.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/Encodable.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/Encoders.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/Message.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/RequestMessage.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/ResponseMessage.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/sasl/SaslEncryptionBackend.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/sasl/SaslServerBootstrap.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/server/StreamManager.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/server/TransportServer.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/server/TransportServerBootstrap.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/util/ByteArrayWritableChannel.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/util/ByteUnit.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/util/ConfigProvider.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/util/IOMode.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/util/TransportConf.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java delete mode 100644 network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java delete mode 100644 network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java delete mode 100644 network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java delete mode 100644 network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java delete mode 100644 network/common/src/test/java/org/apache/spark/network/StreamSuite.java delete mode 100644 network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java delete mode 100644 network/common/src/test/java/org/apache/spark/network/TestUtils.java delete mode 100644 network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java delete mode 100644 network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java delete mode 100644 network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java delete mode 100644 network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java delete mode 100644 network/common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java delete mode 100644 network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java delete mode 100644 network/common/src/test/resources/log4j.properties delete mode 100644 network/shuffle/pom.xml delete mode 100644 network/shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java delete mode 100644 network/shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java delete mode 100644 network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java delete mode 100644 network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java delete mode 100644 network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java delete mode 100644 network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java delete mode 100644 network/shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java delete mode 100644 network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java delete mode 100644 network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java delete mode 100644 network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java delete mode 100644 network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java delete mode 100644 network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java delete mode 100644 network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java delete mode 100644 network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java delete mode 100644 network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java delete mode 100644 network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java delete mode 100644 network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java delete mode 100644 network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java delete mode 100644 network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java delete mode 100644 network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java delete mode 100644 network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java delete mode 100644 network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java delete mode 100644 network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java delete mode 100644 network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java delete mode 100644 network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java delete mode 100644 network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java delete mode 100644 network/yarn/pom.xml delete mode 100644 network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java delete mode 100644 network/yarn/src/main/java/org/apache/spark/network/yarn/util/HadoopConfigProvider.java (limited to 'network') diff --git a/network/common/pom.xml b/network/common/pom.xml deleted file mode 100644 index bd507c2cb6..0000000000 --- a/network/common/pom.xml +++ /dev/null @@ -1,103 +0,0 @@ - - - - - 4.0.0 - - org.apache.spark - spark-parent_2.11 - 2.0.0-SNAPSHOT - ../../pom.xml - - - org.apache.spark - spark-network-common_2.11 - jar - Spark Project Networking - http://spark.apache.org/ - - network-common - - - - - - io.netty - netty-all - - - - - org.slf4j - slf4j-api - provided - - - com.google.code.findbugs - jsr305 - - - com.google.guava - guava - compile - - - - - log4j - log4j - test - - - org.apache.spark - spark-test-tags_${scala.binary.version} - - - org.mockito - mockito-core - test - - - org.slf4j - slf4j-log4j12 - test - - - - - target/scala-${scala.binary.version}/classes - target/scala-${scala.binary.version}/test-classes - - - - org.apache.maven.plugins - maven-jar-plugin - - - test-jar-on-test-compile - test-compile - - test-jar - - - - - - - diff --git a/network/common/src/main/java/org/apache/spark/network/TransportContext.java b/network/common/src/main/java/org/apache/spark/network/TransportContext.java deleted file mode 100644 index 238710d172..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/TransportContext.java +++ /dev/null @@ -1,166 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network; - -import java.util.List; - -import com.google.common.collect.Lists; -import io.netty.channel.Channel; -import io.netty.channel.socket.SocketChannel; -import io.netty.handler.timeout.IdleStateHandler; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.client.TransportClientBootstrap; -import org.apache.spark.network.client.TransportClientFactory; -import org.apache.spark.network.client.TransportResponseHandler; -import org.apache.spark.network.protocol.MessageDecoder; -import org.apache.spark.network.protocol.MessageEncoder; -import org.apache.spark.network.server.RpcHandler; -import org.apache.spark.network.server.TransportChannelHandler; -import org.apache.spark.network.server.TransportRequestHandler; -import org.apache.spark.network.server.TransportServer; -import org.apache.spark.network.server.TransportServerBootstrap; -import org.apache.spark.network.util.NettyUtils; -import org.apache.spark.network.util.TransportConf; -import org.apache.spark.network.util.TransportFrameDecoder; - -/** - * Contains the context to create a {@link TransportServer}, {@link TransportClientFactory}, and to - * setup Netty Channel pipelines with a {@link org.apache.spark.network.server.TransportChannelHandler}. - * - * There are two communication protocols that the TransportClient provides, control-plane RPCs and - * data-plane "chunk fetching". The handling of the RPCs is performed outside of the scope of the - * TransportContext (i.e., by a user-provided handler), and it is responsible for setting up streams - * which can be streamed through the data plane in chunks using zero-copy IO. - * - * The TransportServer and TransportClientFactory both create a TransportChannelHandler for each - * channel. As each TransportChannelHandler contains a TransportClient, this enables server - * processes to send messages back to the client on an existing channel. - */ -public class TransportContext { - private final Logger logger = LoggerFactory.getLogger(TransportContext.class); - - private final TransportConf conf; - private final RpcHandler rpcHandler; - private final boolean closeIdleConnections; - - private final MessageEncoder encoder; - private final MessageDecoder decoder; - - public TransportContext(TransportConf conf, RpcHandler rpcHandler) { - this(conf, rpcHandler, false); - } - - public TransportContext( - TransportConf conf, - RpcHandler rpcHandler, - boolean closeIdleConnections) { - this.conf = conf; - this.rpcHandler = rpcHandler; - this.encoder = new MessageEncoder(); - this.decoder = new MessageDecoder(); - this.closeIdleConnections = closeIdleConnections; - } - - /** - * Initializes a ClientFactory which runs the given TransportClientBootstraps prior to returning - * a new Client. Bootstraps will be executed synchronously, and must run successfully in order - * to create a Client. - */ - public TransportClientFactory createClientFactory(List bootstraps) { - return new TransportClientFactory(this, bootstraps); - } - - public TransportClientFactory createClientFactory() { - return createClientFactory(Lists.newArrayList()); - } - - /** Create a server which will attempt to bind to a specific port. */ - public TransportServer createServer(int port, List bootstraps) { - return new TransportServer(this, null, port, rpcHandler, bootstraps); - } - - /** Create a server which will attempt to bind to a specific host and port. */ - public TransportServer createServer( - String host, int port, List bootstraps) { - return new TransportServer(this, host, port, rpcHandler, bootstraps); - } - - /** Creates a new server, binding to any available ephemeral port. */ - public TransportServer createServer(List bootstraps) { - return createServer(0, bootstraps); - } - - public TransportServer createServer() { - return createServer(0, Lists.newArrayList()); - } - - public TransportChannelHandler initializePipeline(SocketChannel channel) { - return initializePipeline(channel, rpcHandler); - } - - /** - * Initializes a client or server Netty Channel Pipeline which encodes/decodes messages and - * has a {@link org.apache.spark.network.server.TransportChannelHandler} to handle request or - * response messages. - * - * @param channel The channel to initialize. - * @param channelRpcHandler The RPC handler to use for the channel. - * - * @return Returns the created TransportChannelHandler, which includes a TransportClient that can - * be used to communicate on this channel. The TransportClient is directly associated with a - * ChannelHandler to ensure all users of the same channel get the same TransportClient object. - */ - public TransportChannelHandler initializePipeline( - SocketChannel channel, - RpcHandler channelRpcHandler) { - try { - TransportChannelHandler channelHandler = createChannelHandler(channel, channelRpcHandler); - channel.pipeline() - .addLast("encoder", encoder) - .addLast(TransportFrameDecoder.HANDLER_NAME, NettyUtils.createFrameDecoder()) - .addLast("decoder", decoder) - .addLast("idleStateHandler", new IdleStateHandler(0, 0, conf.connectionTimeoutMs() / 1000)) - // NOTE: Chunks are currently guaranteed to be returned in the order of request, but this - // would require more logic to guarantee if this were not part of the same event loop. - .addLast("handler", channelHandler); - return channelHandler; - } catch (RuntimeException e) { - logger.error("Error while initializing Netty pipeline", e); - throw e; - } - } - - /** - * Creates the server- and client-side handler which is used to handle both RequestMessages and - * ResponseMessages. The channel is expected to have been successfully created, though certain - * properties (such as the remoteAddress()) may not be available yet. - */ - private TransportChannelHandler createChannelHandler(Channel channel, RpcHandler rpcHandler) { - TransportResponseHandler responseHandler = new TransportResponseHandler(channel); - TransportClient client = new TransportClient(channel, responseHandler); - TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client, - rpcHandler); - return new TransportChannelHandler(client, responseHandler, requestHandler, - conf.connectionTimeoutMs(), closeIdleConnections); - } - - public TransportConf getConf() { return conf; } -} diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java deleted file mode 100644 index 844eff4f4c..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java +++ /dev/null @@ -1,154 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.buffer; - -import java.io.File; -import java.io.FileInputStream; -import java.io.IOException; -import java.io.InputStream; -import java.io.RandomAccessFile; -import java.nio.ByteBuffer; -import java.nio.channels.FileChannel; - -import com.google.common.base.Objects; -import com.google.common.io.ByteStreams; -import io.netty.channel.DefaultFileRegion; - -import org.apache.spark.network.util.JavaUtils; -import org.apache.spark.network.util.LimitedInputStream; -import org.apache.spark.network.util.TransportConf; - -/** - * A {@link ManagedBuffer} backed by a segment in a file. - */ -public final class FileSegmentManagedBuffer extends ManagedBuffer { - private final TransportConf conf; - private final File file; - private final long offset; - private final long length; - - public FileSegmentManagedBuffer(TransportConf conf, File file, long offset, long length) { - this.conf = conf; - this.file = file; - this.offset = offset; - this.length = length; - } - - @Override - public long size() { - return length; - } - - @Override - public ByteBuffer nioByteBuffer() throws IOException { - FileChannel channel = null; - try { - channel = new RandomAccessFile(file, "r").getChannel(); - // Just copy the buffer if it's sufficiently small, as memory mapping has a high overhead. - if (length < conf.memoryMapBytes()) { - ByteBuffer buf = ByteBuffer.allocate((int) length); - channel.position(offset); - while (buf.remaining() != 0) { - if (channel.read(buf) == -1) { - throw new IOException(String.format("Reached EOF before filling buffer\n" + - "offset=%s\nfile=%s\nbuf.remaining=%s", - offset, file.getAbsoluteFile(), buf.remaining())); - } - } - buf.flip(); - return buf; - } else { - return channel.map(FileChannel.MapMode.READ_ONLY, offset, length); - } - } catch (IOException e) { - try { - if (channel != null) { - long size = channel.size(); - throw new IOException("Error in reading " + this + " (actual file length " + size + ")", - e); - } - } catch (IOException ignored) { - // ignore - } - throw new IOException("Error in opening " + this, e); - } finally { - JavaUtils.closeQuietly(channel); - } - } - - @Override - public InputStream createInputStream() throws IOException { - FileInputStream is = null; - try { - is = new FileInputStream(file); - ByteStreams.skipFully(is, offset); - return new LimitedInputStream(is, length); - } catch (IOException e) { - try { - if (is != null) { - long size = file.length(); - throw new IOException("Error in reading " + this + " (actual file length " + size + ")", - e); - } - } catch (IOException ignored) { - // ignore - } finally { - JavaUtils.closeQuietly(is); - } - throw new IOException("Error in opening " + this, e); - } catch (RuntimeException e) { - JavaUtils.closeQuietly(is); - throw e; - } - } - - @Override - public ManagedBuffer retain() { - return this; - } - - @Override - public ManagedBuffer release() { - return this; - } - - @Override - public Object convertToNetty() throws IOException { - if (conf.lazyFileDescriptor()) { - return new LazyFileRegion(file, offset, length); - } else { - FileChannel fileChannel = new FileInputStream(file).getChannel(); - return new DefaultFileRegion(fileChannel, offset, length); - } - } - - public File getFile() { return file; } - - public long getOffset() { return offset; } - - public long getLength() { return length; } - - @Override - public String toString() { - return Objects.toStringHelper(this) - .add("file", file) - .add("offset", offset) - .add("length", length) - .toString(); - } -} diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/LazyFileRegion.java b/network/common/src/main/java/org/apache/spark/network/buffer/LazyFileRegion.java deleted file mode 100644 index 162cf6da0d..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/buffer/LazyFileRegion.java +++ /dev/null @@ -1,111 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.buffer; - -import java.io.FileInputStream; -import java.io.File; -import java.io.IOException; -import java.nio.channels.FileChannel; -import java.nio.channels.WritableByteChannel; - -import com.google.common.base.Objects; -import io.netty.channel.FileRegion; -import io.netty.util.AbstractReferenceCounted; - -import org.apache.spark.network.util.JavaUtils; - -/** - * A FileRegion implementation that only creates the file descriptor when the region is being - * transferred. This cannot be used with Epoll because there is no native support for it. - * - * This is mostly copied from DefaultFileRegion implementation in Netty. In the future, we - * should push this into Netty so the native Epoll transport can support this feature. - */ -public final class LazyFileRegion extends AbstractReferenceCounted implements FileRegion { - - private final File file; - private final long position; - private final long count; - - private FileChannel channel; - - private long numBytesTransferred = 0L; - - /** - * @param file file to transfer. - * @param position start position for the transfer. - * @param count number of bytes to transfer starting from position. - */ - public LazyFileRegion(File file, long position, long count) { - this.file = file; - this.position = position; - this.count = count; - } - - @Override - protected void deallocate() { - JavaUtils.closeQuietly(channel); - } - - @Override - public long position() { - return position; - } - - @Override - public long transfered() { - return numBytesTransferred; - } - - @Override - public long count() { - return count; - } - - @Override - public long transferTo(WritableByteChannel target, long position) throws IOException { - if (channel == null) { - channel = new FileInputStream(file).getChannel(); - } - - long count = this.count - position; - if (count < 0 || position < 0) { - throw new IllegalArgumentException( - "position out of range: " + position + " (expected: 0 - " + (count - 1) + ')'); - } - - if (count == 0) { - return 0L; - } - - long written = channel.transferTo(this.position + position, count, target); - if (written > 0) { - numBytesTransferred += written; - } - return written; - } - - @Override - public String toString() { - return Objects.toStringHelper(this) - .add("file", file) - .add("position", position) - .add("count", count) - .toString(); - } -} diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java deleted file mode 100644 index 1861f8d7fd..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.buffer; - -import java.io.IOException; -import java.io.InputStream; -import java.nio.ByteBuffer; - -/** - * This interface provides an immutable view for data in the form of bytes. The implementation - * should specify how the data is provided: - * - * - {@link FileSegmentManagedBuffer}: data backed by part of a file - * - {@link NioManagedBuffer}: data backed by a NIO ByteBuffer - * - {@link NettyManagedBuffer}: data backed by a Netty ByteBuf - * - * The concrete buffer implementation might be managed outside the JVM garbage collector. - * For example, in the case of {@link NettyManagedBuffer}, the buffers are reference counted. - * In that case, if the buffer is going to be passed around to a different thread, retain/release - * should be called. - */ -public abstract class ManagedBuffer { - - /** Number of bytes of the data. */ - public abstract long size(); - - /** - * Exposes this buffer's data as an NIO ByteBuffer. Changing the position and limit of the - * returned ByteBuffer should not affect the content of this buffer. - */ - // TODO: Deprecate this, usage may require expensive memory mapping or allocation. - public abstract ByteBuffer nioByteBuffer() throws IOException; - - /** - * Exposes this buffer's data as an InputStream. The underlying implementation does not - * necessarily check for the length of bytes read, so the caller is responsible for making sure - * it does not go over the limit. - */ - public abstract InputStream createInputStream() throws IOException; - - /** - * Increment the reference count by one if applicable. - */ - public abstract ManagedBuffer retain(); - - /** - * If applicable, decrement the reference count by one and deallocates the buffer if the - * reference count reaches zero. - */ - public abstract ManagedBuffer release(); - - /** - * Convert the buffer into an Netty object, used to write the data out. The return value is either - * a {@link io.netty.buffer.ByteBuf} or a {@link io.netty.channel.FileRegion}. - * - * If this method returns a ByteBuf, then that buffer's reference count will be incremented and - * the caller will be responsible for releasing this new reference. - */ - public abstract Object convertToNetty() throws IOException; -} diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java deleted file mode 100644 index 4c8802af7a..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.buffer; - -import java.io.IOException; -import java.io.InputStream; -import java.nio.ByteBuffer; - -import com.google.common.base.Objects; -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufInputStream; - -/** - * A {@link ManagedBuffer} backed by a Netty {@link ByteBuf}. - */ -public final class NettyManagedBuffer extends ManagedBuffer { - private final ByteBuf buf; - - public NettyManagedBuffer(ByteBuf buf) { - this.buf = buf; - } - - @Override - public long size() { - return buf.readableBytes(); - } - - @Override - public ByteBuffer nioByteBuffer() throws IOException { - return buf.nioBuffer(); - } - - @Override - public InputStream createInputStream() throws IOException { - return new ByteBufInputStream(buf); - } - - @Override - public ManagedBuffer retain() { - buf.retain(); - return this; - } - - @Override - public ManagedBuffer release() { - buf.release(); - return this; - } - - @Override - public Object convertToNetty() throws IOException { - return buf.duplicate().retain(); - } - - @Override - public String toString() { - return Objects.toStringHelper(this) - .add("buf", buf) - .toString(); - } -} diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java deleted file mode 100644 index 631d767715..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.buffer; - -import java.io.IOException; -import java.io.InputStream; -import java.nio.ByteBuffer; - -import com.google.common.base.Objects; -import io.netty.buffer.ByteBufInputStream; -import io.netty.buffer.Unpooled; - -/** - * A {@link ManagedBuffer} backed by {@link ByteBuffer}. - */ -public class NioManagedBuffer extends ManagedBuffer { - private final ByteBuffer buf; - - public NioManagedBuffer(ByteBuffer buf) { - this.buf = buf; - } - - @Override - public long size() { - return buf.remaining(); - } - - @Override - public ByteBuffer nioByteBuffer() throws IOException { - return buf.duplicate(); - } - - @Override - public InputStream createInputStream() throws IOException { - return new ByteBufInputStream(Unpooled.wrappedBuffer(buf)); - } - - @Override - public ManagedBuffer retain() { - return this; - } - - @Override - public ManagedBuffer release() { - return this; - } - - @Override - public Object convertToNetty() throws IOException { - return Unpooled.wrappedBuffer(buf); - } - - @Override - public String toString() { - return Objects.toStringHelper(this) - .add("buf", buf) - .toString(); - } -} - diff --git a/network/common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java b/network/common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java deleted file mode 100644 index 1fbdcd6780..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.client; - -/** - * General exception caused by a remote exception while fetching a chunk. - */ -public class ChunkFetchFailureException extends RuntimeException { - public ChunkFetchFailureException(String errorMsg, Throwable cause) { - super(errorMsg, cause); - } - - public ChunkFetchFailureException(String errorMsg) { - super(errorMsg); - } -} diff --git a/network/common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java b/network/common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java deleted file mode 100644 index 519e6cb470..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.client; - -import org.apache.spark.network.buffer.ManagedBuffer; - -/** - * Callback for the result of a single chunk result. For a single stream, the callbacks are - * guaranteed to be called by the same thread in the same order as the requests for chunks were - * made. - * - * Note that if a general stream failure occurs, all outstanding chunk requests may be failed. - */ -public interface ChunkReceivedCallback { - /** - * Called upon receipt of a particular chunk. - * - * The given buffer will initially have a refcount of 1, but will be release()'d as soon as this - * call returns. You must therefore either retain() the buffer or copy its contents before - * returning. - */ - void onSuccess(int chunkIndex, ManagedBuffer buffer); - - /** - * Called upon failure to fetch a particular chunk. Note that this may actually be called due - * to failure to fetch a prior chunk in this stream. - * - * After receiving a failure, the stream may or may not be valid. The client should not assume - * that the server's side of the stream has been closed. - */ - void onFailure(int chunkIndex, Throwable e); -} diff --git a/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java b/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java deleted file mode 100644 index 47e93f9846..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.client; - -import java.nio.ByteBuffer; - -/** - * Callback for the result of a single RPC. This will be invoked once with either success or - * failure. - */ -public interface RpcResponseCallback { - /** Successful serialized result from server. */ - void onSuccess(ByteBuffer response); - - /** Exception either propagated from server or raised on client side. */ - void onFailure(Throwable e); -} diff --git a/network/common/src/main/java/org/apache/spark/network/client/StreamCallback.java b/network/common/src/main/java/org/apache/spark/network/client/StreamCallback.java deleted file mode 100644 index 29e6a30dc1..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/client/StreamCallback.java +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.client; - -import java.io.IOException; -import java.nio.ByteBuffer; - -/** - * Callback for streaming data. Stream data will be offered to the {@link #onData(String, ByteBuffer)} - * method as it arrives. Once all the stream data is received, {@link #onComplete(String)} will be - * called. - *

- * The network library guarantees that a single thread will call these methods at a time, but - * different call may be made by different threads. - */ -public interface StreamCallback { - /** Called upon receipt of stream data. */ - void onData(String streamId, ByteBuffer buf) throws IOException; - - /** Called when all data from the stream has been received. */ - void onComplete(String streamId) throws IOException; - - /** Called if there's an error reading data from the stream. */ - void onFailure(String streamId, Throwable cause) throws IOException; -} diff --git a/network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java b/network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java deleted file mode 100644 index 88ba3ccebd..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java +++ /dev/null @@ -1,86 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.client; - -import java.nio.ByteBuffer; -import java.nio.channels.ClosedChannelException; - -import io.netty.buffer.ByteBuf; - -import org.apache.spark.network.util.TransportFrameDecoder; - -/** - * An interceptor that is registered with the frame decoder to feed stream data to a - * callback. - */ -class StreamInterceptor implements TransportFrameDecoder.Interceptor { - - private final TransportResponseHandler handler; - private final String streamId; - private final long byteCount; - private final StreamCallback callback; - - private volatile long bytesRead; - - StreamInterceptor( - TransportResponseHandler handler, - String streamId, - long byteCount, - StreamCallback callback) { - this.handler = handler; - this.streamId = streamId; - this.byteCount = byteCount; - this.callback = callback; - this.bytesRead = 0; - } - - @Override - public void exceptionCaught(Throwable cause) throws Exception { - handler.deactivateStream(); - callback.onFailure(streamId, cause); - } - - @Override - public void channelInactive() throws Exception { - handler.deactivateStream(); - callback.onFailure(streamId, new ClosedChannelException()); - } - - @Override - public boolean handle(ByteBuf buf) throws Exception { - int toRead = (int) Math.min(buf.readableBytes(), byteCount - bytesRead); - ByteBuffer nioBuffer = buf.readSlice(toRead).nioBuffer(); - - int available = nioBuffer.remaining(); - callback.onData(streamId, nioBuffer); - bytesRead += available; - if (bytesRead > byteCount) { - RuntimeException re = new IllegalStateException(String.format( - "Read too many bytes? Expected %d, but read %d.", byteCount, bytesRead)); - callback.onFailure(streamId, re); - handler.deactivateStream(); - throw re; - } else if (bytesRead == byteCount) { - handler.deactivateStream(); - callback.onComplete(streamId); - } - - return bytesRead != byteCount; - } - -} diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java deleted file mode 100644 index e15f096d36..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ /dev/null @@ -1,321 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.client; - -import java.io.Closeable; -import java.io.IOException; -import java.net.SocketAddress; -import java.nio.ByteBuffer; -import java.util.UUID; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.TimeUnit; -import javax.annotation.Nullable; - -import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Objects; -import com.google.common.base.Preconditions; -import com.google.common.base.Throwables; -import com.google.common.util.concurrent.SettableFuture; -import io.netty.channel.Channel; -import io.netty.channel.ChannelFuture; -import io.netty.channel.ChannelFutureListener; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import org.apache.spark.network.buffer.NioManagedBuffer; -import org.apache.spark.network.protocol.ChunkFetchRequest; -import org.apache.spark.network.protocol.OneWayMessage; -import org.apache.spark.network.protocol.RpcRequest; -import org.apache.spark.network.protocol.StreamChunkId; -import org.apache.spark.network.protocol.StreamRequest; -import org.apache.spark.network.util.NettyUtils; - -/** - * Client for fetching consecutive chunks of a pre-negotiated stream. This API is intended to allow - * efficient transfer of a large amount of data, broken up into chunks with size ranging from - * hundreds of KB to a few MB. - * - * Note that while this client deals with the fetching of chunks from a stream (i.e., data plane), - * the actual setup of the streams is done outside the scope of the transport layer. The convenience - * method "sendRPC" is provided to enable control plane communication between the client and server - * to perform this setup. - * - * For example, a typical workflow might be: - * client.sendRPC(new OpenFile("/foo")) --> returns StreamId = 100 - * client.fetchChunk(streamId = 100, chunkIndex = 0, callback) - * client.fetchChunk(streamId = 100, chunkIndex = 1, callback) - * ... - * client.sendRPC(new CloseStream(100)) - * - * Construct an instance of TransportClient using {@link TransportClientFactory}. A single - * TransportClient may be used for multiple streams, but any given stream must be restricted to a - * single client, in order to avoid out-of-order responses. - * - * NB: This class is used to make requests to the server, while {@link TransportResponseHandler} is - * responsible for handling responses from the server. - * - * Concurrency: thread safe and can be called from multiple threads. - */ -public class TransportClient implements Closeable { - private final Logger logger = LoggerFactory.getLogger(TransportClient.class); - - private final Channel channel; - private final TransportResponseHandler handler; - @Nullable private String clientId; - private volatile boolean timedOut; - - public TransportClient(Channel channel, TransportResponseHandler handler) { - this.channel = Preconditions.checkNotNull(channel); - this.handler = Preconditions.checkNotNull(handler); - this.timedOut = false; - } - - public Channel getChannel() { - return channel; - } - - public boolean isActive() { - return !timedOut && (channel.isOpen() || channel.isActive()); - } - - public SocketAddress getSocketAddress() { - return channel.remoteAddress(); - } - - /** - * Returns the ID used by the client to authenticate itself when authentication is enabled. - * - * @return The client ID, or null if authentication is disabled. - */ - public String getClientId() { - return clientId; - } - - /** - * Sets the authenticated client ID. This is meant to be used by the authentication layer. - * - * Trying to set a different client ID after it's been set will result in an exception. - */ - public void setClientId(String id) { - Preconditions.checkState(clientId == null, "Client ID has already been set."); - this.clientId = id; - } - - /** - * Requests a single chunk from the remote side, from the pre-negotiated streamId. - * - * Chunk indices go from 0 onwards. It is valid to request the same chunk multiple times, though - * some streams may not support this. - * - * Multiple fetchChunk requests may be outstanding simultaneously, and the chunks are guaranteed - * to be returned in the same order that they were requested, assuming only a single - * TransportClient is used to fetch the chunks. - * - * @param streamId Identifier that refers to a stream in the remote StreamManager. This should - * be agreed upon by client and server beforehand. - * @param chunkIndex 0-based index of the chunk to fetch - * @param callback Callback invoked upon successful receipt of chunk, or upon any failure. - */ - public void fetchChunk( - long streamId, - final int chunkIndex, - final ChunkReceivedCallback callback) { - final String serverAddr = NettyUtils.getRemoteAddress(channel); - final long startTime = System.currentTimeMillis(); - logger.debug("Sending fetch chunk request {} to {}", chunkIndex, serverAddr); - - final StreamChunkId streamChunkId = new StreamChunkId(streamId, chunkIndex); - handler.addFetchRequest(streamChunkId, callback); - - channel.writeAndFlush(new ChunkFetchRequest(streamChunkId)).addListener( - new ChannelFutureListener() { - @Override - public void operationComplete(ChannelFuture future) throws Exception { - if (future.isSuccess()) { - long timeTaken = System.currentTimeMillis() - startTime; - logger.trace("Sending request {} to {} took {} ms", streamChunkId, serverAddr, - timeTaken); - } else { - String errorMsg = String.format("Failed to send request %s to %s: %s", streamChunkId, - serverAddr, future.cause()); - logger.error(errorMsg, future.cause()); - handler.removeFetchRequest(streamChunkId); - channel.close(); - try { - callback.onFailure(chunkIndex, new IOException(errorMsg, future.cause())); - } catch (Exception e) { - logger.error("Uncaught exception in RPC response callback handler!", e); - } - } - } - }); - } - - /** - * Request to stream the data with the given stream ID from the remote end. - * - * @param streamId The stream to fetch. - * @param callback Object to call with the stream data. - */ - public void stream(final String streamId, final StreamCallback callback) { - final String serverAddr = NettyUtils.getRemoteAddress(channel); - final long startTime = System.currentTimeMillis(); - logger.debug("Sending stream request for {} to {}", streamId, serverAddr); - - // Need to synchronize here so that the callback is added to the queue and the RPC is - // written to the socket atomically, so that callbacks are called in the right order - // when responses arrive. - synchronized (this) { - handler.addStreamCallback(callback); - channel.writeAndFlush(new StreamRequest(streamId)).addListener( - new ChannelFutureListener() { - @Override - public void operationComplete(ChannelFuture future) throws Exception { - if (future.isSuccess()) { - long timeTaken = System.currentTimeMillis() - startTime; - logger.trace("Sending request for {} to {} took {} ms", streamId, serverAddr, - timeTaken); - } else { - String errorMsg = String.format("Failed to send request for %s to %s: %s", streamId, - serverAddr, future.cause()); - logger.error(errorMsg, future.cause()); - channel.close(); - try { - callback.onFailure(streamId, new IOException(errorMsg, future.cause())); - } catch (Exception e) { - logger.error("Uncaught exception in RPC response callback handler!", e); - } - } - } - }); - } - } - - /** - * Sends an opaque message to the RpcHandler on the server-side. The callback will be invoked - * with the server's response or upon any failure. - * - * @param message The message to send. - * @param callback Callback to handle the RPC's reply. - * @return The RPC's id. - */ - public long sendRpc(ByteBuffer message, final RpcResponseCallback callback) { - final String serverAddr = NettyUtils.getRemoteAddress(channel); - final long startTime = System.currentTimeMillis(); - logger.trace("Sending RPC to {}", serverAddr); - - final long requestId = Math.abs(UUID.randomUUID().getLeastSignificantBits()); - handler.addRpcRequest(requestId, callback); - - channel.writeAndFlush(new RpcRequest(requestId, new NioManagedBuffer(message))).addListener( - new ChannelFutureListener() { - @Override - public void operationComplete(ChannelFuture future) throws Exception { - if (future.isSuccess()) { - long timeTaken = System.currentTimeMillis() - startTime; - logger.trace("Sending request {} to {} took {} ms", requestId, serverAddr, timeTaken); - } else { - String errorMsg = String.format("Failed to send RPC %s to %s: %s", requestId, - serverAddr, future.cause()); - logger.error(errorMsg, future.cause()); - handler.removeRpcRequest(requestId); - channel.close(); - try { - callback.onFailure(new IOException(errorMsg, future.cause())); - } catch (Exception e) { - logger.error("Uncaught exception in RPC response callback handler!", e); - } - } - } - }); - - return requestId; - } - - /** - * Synchronously sends an opaque message to the RpcHandler on the server-side, waiting for up to - * a specified timeout for a response. - */ - public ByteBuffer sendRpcSync(ByteBuffer message, long timeoutMs) { - final SettableFuture result = SettableFuture.create(); - - sendRpc(message, new RpcResponseCallback() { - @Override - public void onSuccess(ByteBuffer response) { - result.set(response); - } - - @Override - public void onFailure(Throwable e) { - result.setException(e); - } - }); - - try { - return result.get(timeoutMs, TimeUnit.MILLISECONDS); - } catch (ExecutionException e) { - throw Throwables.propagate(e.getCause()); - } catch (Exception e) { - throw Throwables.propagate(e); - } - } - - /** - * Sends an opaque message to the RpcHandler on the server-side. No reply is expected for the - * message, and no delivery guarantees are made. - * - * @param message The message to send. - */ - public void send(ByteBuffer message) { - channel.writeAndFlush(new OneWayMessage(new NioManagedBuffer(message))); - } - - /** - * Removes any state associated with the given RPC. - * - * @param requestId The RPC id returned by {@link #sendRpc(ByteBuffer, RpcResponseCallback)}. - */ - public void removeRpcRequest(long requestId) { - handler.removeRpcRequest(requestId); - } - - /** Mark this channel as having timed out. */ - public void timeOut() { - this.timedOut = true; - } - - @VisibleForTesting - public TransportResponseHandler getHandler() { - return handler; - } - - @Override - public void close() { - // close is a local operation and should finish with milliseconds; timeout just to be safe - channel.close().awaitUninterruptibly(10, TimeUnit.SECONDS); - } - - @Override - public String toString() { - return Objects.toStringHelper(this) - .add("remoteAdress", channel.remoteAddress()) - .add("clientId", clientId) - .add("isActive", isActive()) - .toString(); - } -} diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java deleted file mode 100644 index eaae2ee043..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.client; - -import io.netty.channel.Channel; - -/** - * A bootstrap which is executed on a TransportClient before it is returned to the user. - * This enables an initial exchange of information (e.g., SASL authentication tokens) on a once-per- - * connection basis. - * - * Since connections (and TransportClients) are reused as much as possible, it is generally - * reasonable to perform an expensive bootstrapping operation, as they often share a lifespan with - * the JVM itself. - */ -public interface TransportClientBootstrap { - /** Performs the bootstrapping operation, throwing an exception on failure. */ - void doBootstrap(TransportClient client, Channel channel) throws RuntimeException; -} diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java deleted file mode 100644 index 61bafc8380..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java +++ /dev/null @@ -1,264 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.client; - -import java.io.Closeable; -import java.io.IOException; -import java.net.InetSocketAddress; -import java.net.SocketAddress; -import java.util.List; -import java.util.Random; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.atomic.AtomicReference; - -import com.google.common.base.Preconditions; -import com.google.common.base.Throwables; -import com.google.common.collect.Lists; -import io.netty.bootstrap.Bootstrap; -import io.netty.buffer.PooledByteBufAllocator; -import io.netty.channel.Channel; -import io.netty.channel.ChannelFuture; -import io.netty.channel.ChannelInitializer; -import io.netty.channel.ChannelOption; -import io.netty.channel.EventLoopGroup; -import io.netty.channel.socket.SocketChannel; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import org.apache.spark.network.TransportContext; -import org.apache.spark.network.server.TransportChannelHandler; -import org.apache.spark.network.util.IOMode; -import org.apache.spark.network.util.JavaUtils; -import org.apache.spark.network.util.NettyUtils; -import org.apache.spark.network.util.TransportConf; - -/** - * Factory for creating {@link TransportClient}s by using createClient. - * - * The factory maintains a connection pool to other hosts and should return the same - * TransportClient for the same remote host. It also shares a single worker thread pool for - * all TransportClients. - * - * TransportClients will be reused whenever possible. Prior to completing the creation of a new - * TransportClient, all given {@link TransportClientBootstrap}s will be run. - */ -public class TransportClientFactory implements Closeable { - - /** A simple data structure to track the pool of clients between two peer nodes. */ - private static class ClientPool { - TransportClient[] clients; - Object[] locks; - - public ClientPool(int size) { - clients = new TransportClient[size]; - locks = new Object[size]; - for (int i = 0; i < size; i++) { - locks[i] = new Object(); - } - } - } - - private final Logger logger = LoggerFactory.getLogger(TransportClientFactory.class); - - private final TransportContext context; - private final TransportConf conf; - private final List clientBootstraps; - private final ConcurrentHashMap connectionPool; - - /** Random number generator for picking connections between peers. */ - private final Random rand; - private final int numConnectionsPerPeer; - - private final Class socketChannelClass; - private EventLoopGroup workerGroup; - private PooledByteBufAllocator pooledAllocator; - - public TransportClientFactory( - TransportContext context, - List clientBootstraps) { - this.context = Preconditions.checkNotNull(context); - this.conf = context.getConf(); - this.clientBootstraps = Lists.newArrayList(Preconditions.checkNotNull(clientBootstraps)); - this.connectionPool = new ConcurrentHashMap(); - this.numConnectionsPerPeer = conf.numConnectionsPerPeer(); - this.rand = new Random(); - - IOMode ioMode = IOMode.valueOf(conf.ioMode()); - this.socketChannelClass = NettyUtils.getClientChannelClass(ioMode); - // TODO: Make thread pool name configurable. - this.workerGroup = NettyUtils.createEventLoop(ioMode, conf.clientThreads(), "shuffle-client"); - this.pooledAllocator = NettyUtils.createPooledByteBufAllocator( - conf.preferDirectBufs(), false /* allowCache */, conf.clientThreads()); - } - - /** - * Create a {@link TransportClient} connecting to the given remote host / port. - * - * We maintains an array of clients (size determined by spark.shuffle.io.numConnectionsPerPeer) - * and randomly picks one to use. If no client was previously created in the randomly selected - * spot, this function creates a new client and places it there. - * - * Prior to the creation of a new TransportClient, we will execute all - * {@link TransportClientBootstrap}s that are registered with this factory. - * - * This blocks until a connection is successfully established and fully bootstrapped. - * - * Concurrency: This method is safe to call from multiple threads. - */ - public TransportClient createClient(String remoteHost, int remotePort) throws IOException { - // Get connection from the connection pool first. - // If it is not found or not active, create a new one. - final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort); - - // Create the ClientPool if we don't have it yet. - ClientPool clientPool = connectionPool.get(address); - if (clientPool == null) { - connectionPool.putIfAbsent(address, new ClientPool(numConnectionsPerPeer)); - clientPool = connectionPool.get(address); - } - - int clientIndex = rand.nextInt(numConnectionsPerPeer); - TransportClient cachedClient = clientPool.clients[clientIndex]; - - if (cachedClient != null && cachedClient.isActive()) { - // Make sure that the channel will not timeout by updating the last use time of the - // handler. Then check that the client is still alive, in case it timed out before - // this code was able to update things. - TransportChannelHandler handler = cachedClient.getChannel().pipeline() - .get(TransportChannelHandler.class); - synchronized (handler) { - handler.getResponseHandler().updateTimeOfLastRequest(); - } - - if (cachedClient.isActive()) { - logger.trace("Returning cached connection to {}: {}", address, cachedClient); - return cachedClient; - } - } - - // If we reach here, we don't have an existing connection open. Let's create a new one. - // Multiple threads might race here to create new connections. Keep only one of them active. - synchronized (clientPool.locks[clientIndex]) { - cachedClient = clientPool.clients[clientIndex]; - - if (cachedClient != null) { - if (cachedClient.isActive()) { - logger.trace("Returning cached connection to {}: {}", address, cachedClient); - return cachedClient; - } else { - logger.info("Found inactive connection to {}, creating a new one.", address); - } - } - clientPool.clients[clientIndex] = createClient(address); - return clientPool.clients[clientIndex]; - } - } - - /** - * Create a completely new {@link TransportClient} to the given remote host / port. - * This connection is not pooled. - * - * As with {@link #createClient(String, int)}, this method is blocking. - */ - public TransportClient createUnmanagedClient(String remoteHost, int remotePort) - throws IOException { - final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort); - return createClient(address); - } - - /** Create a completely new {@link TransportClient} to the remote address. */ - private TransportClient createClient(InetSocketAddress address) throws IOException { - logger.debug("Creating new connection to " + address); - - Bootstrap bootstrap = new Bootstrap(); - bootstrap.group(workerGroup) - .channel(socketChannelClass) - // Disable Nagle's Algorithm since we don't want packets to wait - .option(ChannelOption.TCP_NODELAY, true) - .option(ChannelOption.SO_KEEPALIVE, true) - .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectionTimeoutMs()) - .option(ChannelOption.ALLOCATOR, pooledAllocator); - - final AtomicReference clientRef = new AtomicReference(); - final AtomicReference channelRef = new AtomicReference(); - - bootstrap.handler(new ChannelInitializer() { - @Override - public void initChannel(SocketChannel ch) { - TransportChannelHandler clientHandler = context.initializePipeline(ch); - clientRef.set(clientHandler.getClient()); - channelRef.set(ch); - } - }); - - // Connect to the remote server - long preConnect = System.nanoTime(); - ChannelFuture cf = bootstrap.connect(address); - if (!cf.awaitUninterruptibly(conf.connectionTimeoutMs())) { - throw new IOException( - String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs())); - } else if (cf.cause() != null) { - throw new IOException(String.format("Failed to connect to %s", address), cf.cause()); - } - - TransportClient client = clientRef.get(); - Channel channel = channelRef.get(); - assert client != null : "Channel future completed successfully with null client"; - - // Execute any client bootstraps synchronously before marking the Client as successful. - long preBootstrap = System.nanoTime(); - logger.debug("Connection to {} successful, running bootstraps...", address); - try { - for (TransportClientBootstrap clientBootstrap : clientBootstraps) { - clientBootstrap.doBootstrap(client, channel); - } - } catch (Exception e) { // catch non-RuntimeExceptions too as bootstrap may be written in Scala - long bootstrapTimeMs = (System.nanoTime() - preBootstrap) / 1000000; - logger.error("Exception while bootstrapping client after " + bootstrapTimeMs + " ms", e); - client.close(); - throw Throwables.propagate(e); - } - long postBootstrap = System.nanoTime(); - - logger.debug("Successfully created connection to {} after {} ms ({} ms spent in bootstraps)", - address, (postBootstrap - preConnect) / 1000000, (postBootstrap - preBootstrap) / 1000000); - - return client; - } - - /** Close all connections in the connection pool, and shutdown the worker thread pool. */ - @Override - public void close() { - // Go through all clients and close them if they are active. - for (ClientPool clientPool : connectionPool.values()) { - for (int i = 0; i < clientPool.clients.length; i++) { - TransportClient client = clientPool.clients[i]; - if (client != null) { - clientPool.clients[i] = null; - JavaUtils.closeQuietly(client); - } - } - } - connectionPool.clear(); - - if (workerGroup != null) { - workerGroup.shutdownGracefully(); - workerGroup = null; - } - } -} diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java deleted file mode 100644 index f0e2004d2d..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ /dev/null @@ -1,251 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.client; - -import java.io.IOException; -import java.util.Map; -import java.util.Queue; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentLinkedQueue; -import java.util.concurrent.atomic.AtomicLong; - -import com.google.common.annotations.VisibleForTesting; -import io.netty.channel.Channel; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import org.apache.spark.network.protocol.ChunkFetchFailure; -import org.apache.spark.network.protocol.ChunkFetchSuccess; -import org.apache.spark.network.protocol.ResponseMessage; -import org.apache.spark.network.protocol.RpcFailure; -import org.apache.spark.network.protocol.RpcResponse; -import org.apache.spark.network.protocol.StreamChunkId; -import org.apache.spark.network.protocol.StreamFailure; -import org.apache.spark.network.protocol.StreamResponse; -import org.apache.spark.network.server.MessageHandler; -import org.apache.spark.network.util.NettyUtils; -import org.apache.spark.network.util.TransportFrameDecoder; - -/** - * Handler that processes server responses, in response to requests issued from a - * [[TransportClient]]. It works by tracking the list of outstanding requests (and their callbacks). - * - * Concurrency: thread safe and can be called from multiple threads. - */ -public class TransportResponseHandler extends MessageHandler { - private final Logger logger = LoggerFactory.getLogger(TransportResponseHandler.class); - - private final Channel channel; - - private final Map outstandingFetches; - - private final Map outstandingRpcs; - - private final Queue streamCallbacks; - private volatile boolean streamActive; - - /** Records the time (in system nanoseconds) that the last fetch or RPC request was sent. */ - private final AtomicLong timeOfLastRequestNs; - - public TransportResponseHandler(Channel channel) { - this.channel = channel; - this.outstandingFetches = new ConcurrentHashMap(); - this.outstandingRpcs = new ConcurrentHashMap(); - this.streamCallbacks = new ConcurrentLinkedQueue(); - this.timeOfLastRequestNs = new AtomicLong(0); - } - - public void addFetchRequest(StreamChunkId streamChunkId, ChunkReceivedCallback callback) { - updateTimeOfLastRequest(); - outstandingFetches.put(streamChunkId, callback); - } - - public void removeFetchRequest(StreamChunkId streamChunkId) { - outstandingFetches.remove(streamChunkId); - } - - public void addRpcRequest(long requestId, RpcResponseCallback callback) { - updateTimeOfLastRequest(); - outstandingRpcs.put(requestId, callback); - } - - public void removeRpcRequest(long requestId) { - outstandingRpcs.remove(requestId); - } - - public void addStreamCallback(StreamCallback callback) { - timeOfLastRequestNs.set(System.nanoTime()); - streamCallbacks.offer(callback); - } - - @VisibleForTesting - public void deactivateStream() { - streamActive = false; - } - - /** - * Fire the failure callback for all outstanding requests. This is called when we have an - * uncaught exception or pre-mature connection termination. - */ - private void failOutstandingRequests(Throwable cause) { - for (Map.Entry entry : outstandingFetches.entrySet()) { - entry.getValue().onFailure(entry.getKey().chunkIndex, cause); - } - for (Map.Entry entry : outstandingRpcs.entrySet()) { - entry.getValue().onFailure(cause); - } - - // It's OK if new fetches appear, as they will fail immediately. - outstandingFetches.clear(); - outstandingRpcs.clear(); - } - - @Override - public void channelActive() { - } - - @Override - public void channelInactive() { - if (numOutstandingRequests() > 0) { - String remoteAddress = NettyUtils.getRemoteAddress(channel); - logger.error("Still have {} requests outstanding when connection from {} is closed", - numOutstandingRequests(), remoteAddress); - failOutstandingRequests(new IOException("Connection from " + remoteAddress + " closed")); - } - } - - @Override - public void exceptionCaught(Throwable cause) { - if (numOutstandingRequests() > 0) { - String remoteAddress = NettyUtils.getRemoteAddress(channel); - logger.error("Still have {} requests outstanding when connection from {} is closed", - numOutstandingRequests(), remoteAddress); - failOutstandingRequests(cause); - } - } - - @Override - public void handle(ResponseMessage message) throws Exception { - String remoteAddress = NettyUtils.getRemoteAddress(channel); - if (message instanceof ChunkFetchSuccess) { - ChunkFetchSuccess resp = (ChunkFetchSuccess) message; - ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId); - if (listener == null) { - logger.warn("Ignoring response for block {} from {} since it is not outstanding", - resp.streamChunkId, remoteAddress); - resp.body().release(); - } else { - outstandingFetches.remove(resp.streamChunkId); - listener.onSuccess(resp.streamChunkId.chunkIndex, resp.body()); - resp.body().release(); - } - } else if (message instanceof ChunkFetchFailure) { - ChunkFetchFailure resp = (ChunkFetchFailure) message; - ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId); - if (listener == null) { - logger.warn("Ignoring response for block {} from {} ({}) since it is not outstanding", - resp.streamChunkId, remoteAddress, resp.errorString); - } else { - outstandingFetches.remove(resp.streamChunkId); - listener.onFailure(resp.streamChunkId.chunkIndex, new ChunkFetchFailureException( - "Failure while fetching " + resp.streamChunkId + ": " + resp.errorString)); - } - } else if (message instanceof RpcResponse) { - RpcResponse resp = (RpcResponse) message; - RpcResponseCallback listener = outstandingRpcs.get(resp.requestId); - if (listener == null) { - logger.warn("Ignoring response for RPC {} from {} ({} bytes) since it is not outstanding", - resp.requestId, remoteAddress, resp.body().size()); - } else { - outstandingRpcs.remove(resp.requestId); - try { - listener.onSuccess(resp.body().nioByteBuffer()); - } finally { - resp.body().release(); - } - } - } else if (message instanceof RpcFailure) { - RpcFailure resp = (RpcFailure) message; - RpcResponseCallback listener = outstandingRpcs.get(resp.requestId); - if (listener == null) { - logger.warn("Ignoring response for RPC {} from {} ({}) since it is not outstanding", - resp.requestId, remoteAddress, resp.errorString); - } else { - outstandingRpcs.remove(resp.requestId); - listener.onFailure(new RuntimeException(resp.errorString)); - } - } else if (message instanceof StreamResponse) { - StreamResponse resp = (StreamResponse) message; - StreamCallback callback = streamCallbacks.poll(); - if (callback != null) { - if (resp.byteCount > 0) { - StreamInterceptor interceptor = new StreamInterceptor(this, resp.streamId, resp.byteCount, - callback); - try { - TransportFrameDecoder frameDecoder = (TransportFrameDecoder) - channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME); - frameDecoder.setInterceptor(interceptor); - streamActive = true; - } catch (Exception e) { - logger.error("Error installing stream handler.", e); - deactivateStream(); - } - } else { - try { - callback.onComplete(resp.streamId); - } catch (Exception e) { - logger.warn("Error in stream handler onComplete().", e); - } - } - } else { - logger.error("Could not find callback for StreamResponse."); - } - } else if (message instanceof StreamFailure) { - StreamFailure resp = (StreamFailure) message; - StreamCallback callback = streamCallbacks.poll(); - if (callback != null) { - try { - callback.onFailure(resp.streamId, new RuntimeException(resp.error)); - } catch (IOException ioe) { - logger.warn("Error in stream failure handler.", ioe); - } - } else { - logger.warn("Stream failure with unknown callback: {}", resp.error); - } - } else { - throw new IllegalStateException("Unknown response type: " + message.type()); - } - } - - /** Returns total number of outstanding requests (fetch requests + rpcs) */ - public int numOutstandingRequests() { - return outstandingFetches.size() + outstandingRpcs.size() + streamCallbacks.size() + - (streamActive ? 1 : 0); - } - - /** Returns the time in nanoseconds of when the last request was sent out. */ - public long getTimeOfLastRequestNs() { - return timeOfLastRequestNs.get(); - } - - /** Updates the time of the last request to the current system time. */ - public void updateTimeOfLastRequest() { - timeOfLastRequestNs.set(System.nanoTime()); - } - -} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/AbstractMessage.java b/network/common/src/main/java/org/apache/spark/network/protocol/AbstractMessage.java deleted file mode 100644 index 2924218c2f..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/protocol/AbstractMessage.java +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.protocol; - -import com.google.common.base.Objects; - -import org.apache.spark.network.buffer.ManagedBuffer; - -/** - * Abstract class for messages which optionally contain a body kept in a separate buffer. - */ -public abstract class AbstractMessage implements Message { - private final ManagedBuffer body; - private final boolean isBodyInFrame; - - protected AbstractMessage() { - this(null, false); - } - - protected AbstractMessage(ManagedBuffer body, boolean isBodyInFrame) { - this.body = body; - this.isBodyInFrame = isBodyInFrame; - } - - @Override - public ManagedBuffer body() { - return body; - } - - @Override - public boolean isBodyInFrame() { - return isBodyInFrame; - } - - protected boolean equals(AbstractMessage other) { - return isBodyInFrame == other.isBodyInFrame && Objects.equal(body, other.body); - } - -} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/AbstractResponseMessage.java b/network/common/src/main/java/org/apache/spark/network/protocol/AbstractResponseMessage.java deleted file mode 100644 index c362c92fc4..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/protocol/AbstractResponseMessage.java +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.protocol; - -import org.apache.spark.network.buffer.ManagedBuffer; - -/** - * Abstract class for response messages. - */ -public abstract class AbstractResponseMessage extends AbstractMessage implements ResponseMessage { - - protected AbstractResponseMessage(ManagedBuffer body, boolean isBodyInFrame) { - super(body, isBodyInFrame); - } - - public abstract ResponseMessage createFailureResponse(String error); -} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java deleted file mode 100644 index 7b28a9a969..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.protocol; - -import com.google.common.base.Objects; -import io.netty.buffer.ByteBuf; - -/** - * Response to {@link ChunkFetchRequest} when there is an error fetching the chunk. - */ -public final class ChunkFetchFailure extends AbstractMessage implements ResponseMessage { - public final StreamChunkId streamChunkId; - public final String errorString; - - public ChunkFetchFailure(StreamChunkId streamChunkId, String errorString) { - this.streamChunkId = streamChunkId; - this.errorString = errorString; - } - - @Override - public Type type() { return Type.ChunkFetchFailure; } - - @Override - public int encodedLength() { - return streamChunkId.encodedLength() + Encoders.Strings.encodedLength(errorString); - } - - @Override - public void encode(ByteBuf buf) { - streamChunkId.encode(buf); - Encoders.Strings.encode(buf, errorString); - } - - public static ChunkFetchFailure decode(ByteBuf buf) { - StreamChunkId streamChunkId = StreamChunkId.decode(buf); - String errorString = Encoders.Strings.decode(buf); - return new ChunkFetchFailure(streamChunkId, errorString); - } - - @Override - public int hashCode() { - return Objects.hashCode(streamChunkId, errorString); - } - - @Override - public boolean equals(Object other) { - if (other instanceof ChunkFetchFailure) { - ChunkFetchFailure o = (ChunkFetchFailure) other; - return streamChunkId.equals(o.streamChunkId) && errorString.equals(o.errorString); - } - return false; - } - - @Override - public String toString() { - return Objects.toStringHelper(this) - .add("streamChunkId", streamChunkId) - .add("errorString", errorString) - .toString(); - } -} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java deleted file mode 100644 index 26d063feb5..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.protocol; - -import com.google.common.base.Objects; -import io.netty.buffer.ByteBuf; - -/** - * Request to fetch a sequence of a single chunk of a stream. This will correspond to a single - * {@link org.apache.spark.network.protocol.ResponseMessage} (either success or failure). - */ -public final class ChunkFetchRequest extends AbstractMessage implements RequestMessage { - public final StreamChunkId streamChunkId; - - public ChunkFetchRequest(StreamChunkId streamChunkId) { - this.streamChunkId = streamChunkId; - } - - @Override - public Type type() { return Type.ChunkFetchRequest; } - - @Override - public int encodedLength() { - return streamChunkId.encodedLength(); - } - - @Override - public void encode(ByteBuf buf) { - streamChunkId.encode(buf); - } - - public static ChunkFetchRequest decode(ByteBuf buf) { - return new ChunkFetchRequest(StreamChunkId.decode(buf)); - } - - @Override - public int hashCode() { - return streamChunkId.hashCode(); - } - - @Override - public boolean equals(Object other) { - if (other instanceof ChunkFetchRequest) { - ChunkFetchRequest o = (ChunkFetchRequest) other; - return streamChunkId.equals(o.streamChunkId); - } - return false; - } - - @Override - public String toString() { - return Objects.toStringHelper(this) - .add("streamChunkId", streamChunkId) - .toString(); - } -} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java deleted file mode 100644 index 94c2ac9b20..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.protocol; - -import com.google.common.base.Objects; -import io.netty.buffer.ByteBuf; - -import org.apache.spark.network.buffer.ManagedBuffer; -import org.apache.spark.network.buffer.NettyManagedBuffer; - -/** - * Response to {@link ChunkFetchRequest} when a chunk exists and has been successfully fetched. - * - * Note that the server-side encoding of this messages does NOT include the buffer itself, as this - * may be written by Netty in a more efficient manner (i.e., zero-copy write). - * Similarly, the client-side decoding will reuse the Netty ByteBuf as the buffer. - */ -public final class ChunkFetchSuccess extends AbstractResponseMessage { - public final StreamChunkId streamChunkId; - - public ChunkFetchSuccess(StreamChunkId streamChunkId, ManagedBuffer buffer) { - super(buffer, true); - this.streamChunkId = streamChunkId; - } - - @Override - public Type type() { return Type.ChunkFetchSuccess; } - - @Override - public int encodedLength() { - return streamChunkId.encodedLength(); - } - - /** Encoding does NOT include 'buffer' itself. See {@link MessageEncoder}. */ - @Override - public void encode(ByteBuf buf) { - streamChunkId.encode(buf); - } - - @Override - public ResponseMessage createFailureResponse(String error) { - return new ChunkFetchFailure(streamChunkId, error); - } - - /** Decoding uses the given ByteBuf as our data, and will retain() it. */ - public static ChunkFetchSuccess decode(ByteBuf buf) { - StreamChunkId streamChunkId = StreamChunkId.decode(buf); - buf.retain(); - NettyManagedBuffer managedBuf = new NettyManagedBuffer(buf.duplicate()); - return new ChunkFetchSuccess(streamChunkId, managedBuf); - } - - @Override - public int hashCode() { - return Objects.hashCode(streamChunkId, body()); - } - - @Override - public boolean equals(Object other) { - if (other instanceof ChunkFetchSuccess) { - ChunkFetchSuccess o = (ChunkFetchSuccess) other; - return streamChunkId.equals(o.streamChunkId) && super.equals(o); - } - return false; - } - - @Override - public String toString() { - return Objects.toStringHelper(this) - .add("streamChunkId", streamChunkId) - .add("buffer", body()) - .toString(); - } -} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Encodable.java b/network/common/src/main/java/org/apache/spark/network/protocol/Encodable.java deleted file mode 100644 index b4e299471b..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/protocol/Encodable.java +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.protocol; - -import io.netty.buffer.ByteBuf; - -/** - * Interface for an object which can be encoded into a ByteBuf. Multiple Encodable objects are - * stored in a single, pre-allocated ByteBuf, so Encodables must also provide their length. - * - * Encodable objects should provide a static "decode(ByteBuf)" method which is invoked by - * {@link MessageDecoder}. During decoding, if the object uses the ByteBuf as its data (rather than - * just copying data from it), then you must retain() the ByteBuf. - * - * Additionally, when adding a new Encodable Message, add it to {@link Message.Type}. - */ -public interface Encodable { - /** Number of bytes of the encoded form of this object. */ - int encodedLength(); - - /** - * Serializes this object by writing into the given ByteBuf. - * This method must write exactly encodedLength() bytes. - */ - void encode(ByteBuf buf); -} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Encoders.java b/network/common/src/main/java/org/apache/spark/network/protocol/Encoders.java deleted file mode 100644 index 9162d0b977..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/protocol/Encoders.java +++ /dev/null @@ -1,92 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.protocol; - - -import com.google.common.base.Charsets; -import io.netty.buffer.ByteBuf; - -/** Provides a canonical set of Encoders for simple types. */ -public class Encoders { - - /** Strings are encoded with their length followed by UTF-8 bytes. */ - public static class Strings { - public static int encodedLength(String s) { - return 4 + s.getBytes(Charsets.UTF_8).length; - } - - public static void encode(ByteBuf buf, String s) { - byte[] bytes = s.getBytes(Charsets.UTF_8); - buf.writeInt(bytes.length); - buf.writeBytes(bytes); - } - - public static String decode(ByteBuf buf) { - int length = buf.readInt(); - byte[] bytes = new byte[length]; - buf.readBytes(bytes); - return new String(bytes, Charsets.UTF_8); - } - } - - /** Byte arrays are encoded with their length followed by bytes. */ - public static class ByteArrays { - public static int encodedLength(byte[] arr) { - return 4 + arr.length; - } - - public static void encode(ByteBuf buf, byte[] arr) { - buf.writeInt(arr.length); - buf.writeBytes(arr); - } - - public static byte[] decode(ByteBuf buf) { - int length = buf.readInt(); - byte[] bytes = new byte[length]; - buf.readBytes(bytes); - return bytes; - } - } - - /** String arrays are encoded with the number of strings followed by per-String encoding. */ - public static class StringArrays { - public static int encodedLength(String[] strings) { - int totalLength = 4; - for (String s : strings) { - totalLength += Strings.encodedLength(s); - } - return totalLength; - } - - public static void encode(ByteBuf buf, String[] strings) { - buf.writeInt(strings.length); - for (String s : strings) { - Strings.encode(buf, s); - } - } - - public static String[] decode(ByteBuf buf) { - int numStrings = buf.readInt(); - String[] strings = new String[numStrings]; - for (int i = 0; i < strings.length; i ++) { - strings[i] = Strings.decode(buf); - } - return strings; - } - } -} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Message.java b/network/common/src/main/java/org/apache/spark/network/protocol/Message.java deleted file mode 100644 index 66f5b8b3a5..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/protocol/Message.java +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.protocol; - -import io.netty.buffer.ByteBuf; - -import org.apache.spark.network.buffer.ManagedBuffer; - -/** An on-the-wire transmittable message. */ -public interface Message extends Encodable { - /** Used to identify this request type. */ - Type type(); - - /** An optional body for the message. */ - ManagedBuffer body(); - - /** Whether to include the body of the message in the same frame as the message. */ - boolean isBodyInFrame(); - - /** Preceding every serialized Message is its type, which allows us to deserialize it. */ - public static enum Type implements Encodable { - ChunkFetchRequest(0), ChunkFetchSuccess(1), ChunkFetchFailure(2), - RpcRequest(3), RpcResponse(4), RpcFailure(5), - StreamRequest(6), StreamResponse(7), StreamFailure(8), - OneWayMessage(9), User(-1); - - private final byte id; - - private Type(int id) { - assert id < 128 : "Cannot have more than 128 message types"; - this.id = (byte) id; - } - - public byte id() { return id; } - - @Override public int encodedLength() { return 1; } - - @Override public void encode(ByteBuf buf) { buf.writeByte(id); } - - public static Type decode(ByteBuf buf) { - byte id = buf.readByte(); - switch (id) { - case 0: return ChunkFetchRequest; - case 1: return ChunkFetchSuccess; - case 2: return ChunkFetchFailure; - case 3: return RpcRequest; - case 4: return RpcResponse; - case 5: return RpcFailure; - case 6: return StreamRequest; - case 7: return StreamResponse; - case 8: return StreamFailure; - case 9: return OneWayMessage; - case -1: throw new IllegalArgumentException("User type messages cannot be decoded."); - default: throw new IllegalArgumentException("Unknown message type: " + id); - } - } - } -} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java deleted file mode 100644 index 074780f2b9..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.protocol; - -import java.util.List; - -import io.netty.buffer.ByteBuf; -import io.netty.channel.ChannelHandler; -import io.netty.channel.ChannelHandlerContext; -import io.netty.handler.codec.MessageToMessageDecoder; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * Decoder used by the client side to encode server-to-client responses. - * This encoder is stateless so it is safe to be shared by multiple threads. - */ -@ChannelHandler.Sharable -public final class MessageDecoder extends MessageToMessageDecoder { - - private final Logger logger = LoggerFactory.getLogger(MessageDecoder.class); - @Override - public void decode(ChannelHandlerContext ctx, ByteBuf in, List out) { - Message.Type msgType = Message.Type.decode(in); - Message decoded = decode(msgType, in); - assert decoded.type() == msgType; - logger.trace("Received message " + msgType + ": " + decoded); - out.add(decoded); - } - - private Message decode(Message.Type msgType, ByteBuf in) { - switch (msgType) { - case ChunkFetchRequest: - return ChunkFetchRequest.decode(in); - - case ChunkFetchSuccess: - return ChunkFetchSuccess.decode(in); - - case ChunkFetchFailure: - return ChunkFetchFailure.decode(in); - - case RpcRequest: - return RpcRequest.decode(in); - - case RpcResponse: - return RpcResponse.decode(in); - - case RpcFailure: - return RpcFailure.decode(in); - - case OneWayMessage: - return OneWayMessage.decode(in); - - case StreamRequest: - return StreamRequest.decode(in); - - case StreamResponse: - return StreamResponse.decode(in); - - case StreamFailure: - return StreamFailure.decode(in); - - default: - throw new IllegalArgumentException("Unexpected message type: " + msgType); - } - } -} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java deleted file mode 100644 index 664df57fec..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java +++ /dev/null @@ -1,93 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.protocol; - -import java.util.List; - -import io.netty.buffer.ByteBuf; -import io.netty.channel.ChannelHandler; -import io.netty.channel.ChannelHandlerContext; -import io.netty.handler.codec.MessageToMessageEncoder; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * Encoder used by the server side to encode server-to-client responses. - * This encoder is stateless so it is safe to be shared by multiple threads. - */ -@ChannelHandler.Sharable -public final class MessageEncoder extends MessageToMessageEncoder { - - private final Logger logger = LoggerFactory.getLogger(MessageEncoder.class); - - /*** - * Encodes a Message by invoking its encode() method. For non-data messages, we will add one - * ByteBuf to 'out' containing the total frame length, the message type, and the message itself. - * In the case of a ChunkFetchSuccess, we will also add the ManagedBuffer corresponding to the - * data to 'out', in order to enable zero-copy transfer. - */ - @Override - public void encode(ChannelHandlerContext ctx, Message in, List out) throws Exception { - Object body = null; - long bodyLength = 0; - boolean isBodyInFrame = false; - - // If the message has a body, take it out to enable zero-copy transfer for the payload. - if (in.body() != null) { - try { - bodyLength = in.body().size(); - body = in.body().convertToNetty(); - isBodyInFrame = in.isBodyInFrame(); - } catch (Exception e) { - in.body().release(); - if (in instanceof AbstractResponseMessage) { - AbstractResponseMessage resp = (AbstractResponseMessage) in; - // Re-encode this message as a failure response. - String error = e.getMessage() != null ? e.getMessage() : "null"; - logger.error(String.format("Error processing %s for client %s", - in, ctx.channel().remoteAddress()), e); - encode(ctx, resp.createFailureResponse(error), out); - } else { - throw e; - } - return; - } - } - - Message.Type msgType = in.type(); - // All messages have the frame length, message type, and message itself. The frame length - // may optionally include the length of the body data, depending on what message is being - // sent. - int headerLength = 8 + msgType.encodedLength() + in.encodedLength(); - long frameLength = headerLength + (isBodyInFrame ? bodyLength : 0); - ByteBuf header = ctx.alloc().heapBuffer(headerLength); - header.writeLong(frameLength); - msgType.encode(header); - in.encode(header); - assert header.writableBytes() == 0; - - if (body != null) { - // We transfer ownership of the reference on in.body() to MessageWithHeader. - // This reference will be freed when MessageWithHeader.deallocate() is called. - out.add(new MessageWithHeader(in.body(), header, body, bodyLength)); - } else { - out.add(header); - } - } - -} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java deleted file mode 100644 index 66227f96a1..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java +++ /dev/null @@ -1,135 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.protocol; - -import java.io.IOException; -import java.nio.channels.WritableByteChannel; -import javax.annotation.Nullable; - -import com.google.common.base.Preconditions; -import io.netty.buffer.ByteBuf; -import io.netty.channel.FileRegion; -import io.netty.util.AbstractReferenceCounted; -import io.netty.util.ReferenceCountUtil; - -import org.apache.spark.network.buffer.ManagedBuffer; - -/** - * A wrapper message that holds two separate pieces (a header and a body). - * - * The header must be a ByteBuf, while the body can be a ByteBuf or a FileRegion. - */ -class MessageWithHeader extends AbstractReferenceCounted implements FileRegion { - - @Nullable private final ManagedBuffer managedBuffer; - private final ByteBuf header; - private final int headerLength; - private final Object body; - private final long bodyLength; - private long totalBytesTransferred; - - /** - * Construct a new MessageWithHeader. - * - * @param managedBuffer the {@link ManagedBuffer} that the message body came from. This needs to - * be passed in so that the buffer can be freed when this message is - * deallocated. Ownership of the caller's reference to this buffer is - * transferred to this class, so if the caller wants to continue to use the - * ManagedBuffer in other messages then they will need to call retain() on - * it before passing it to this constructor. This may be null if and only if - * `body` is a {@link FileRegion}. - * @param header the message header. - * @param body the message body. Must be either a {@link ByteBuf} or a {@link FileRegion}. - * @param bodyLength the length of the message body, in bytes. - */ - MessageWithHeader( - @Nullable ManagedBuffer managedBuffer, - ByteBuf header, - Object body, - long bodyLength) { - Preconditions.checkArgument(body instanceof ByteBuf || body instanceof FileRegion, - "Body must be a ByteBuf or a FileRegion."); - this.managedBuffer = managedBuffer; - this.header = header; - this.headerLength = header.readableBytes(); - this.body = body; - this.bodyLength = bodyLength; - } - - @Override - public long count() { - return headerLength + bodyLength; - } - - @Override - public long position() { - return 0; - } - - @Override - public long transfered() { - return totalBytesTransferred; - } - - /** - * This code is more complicated than you would think because we might require multiple - * transferTo invocations in order to transfer a single MessageWithHeader to avoid busy waiting. - * - * The contract is that the caller will ensure position is properly set to the total number - * of bytes transferred so far (i.e. value returned by transfered()). - */ - @Override - public long transferTo(final WritableByteChannel target, final long position) throws IOException { - Preconditions.checkArgument(position == totalBytesTransferred, "Invalid position."); - // Bytes written for header in this call. - long writtenHeader = 0; - if (header.readableBytes() > 0) { - writtenHeader = copyByteBuf(header, target); - totalBytesTransferred += writtenHeader; - if (header.readableBytes() > 0) { - return writtenHeader; - } - } - - // Bytes written for body in this call. - long writtenBody = 0; - if (body instanceof FileRegion) { - writtenBody = ((FileRegion) body).transferTo(target, totalBytesTransferred - headerLength); - } else if (body instanceof ByteBuf) { - writtenBody = copyByteBuf((ByteBuf) body, target); - } - totalBytesTransferred += writtenBody; - - return writtenHeader + writtenBody; - } - - @Override - protected void deallocate() { - header.release(); - ReferenceCountUtil.release(body); - if (managedBuffer != null) { - managedBuffer.release(); - } - } - - private int copyByteBuf(ByteBuf buf, WritableByteChannel target) throws IOException { - int written = target.write(buf.nioBuffer()); - buf.skipBytes(written); - return written; - } -} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java b/network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java deleted file mode 100644 index efe0470f35..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.protocol; - -import com.google.common.base.Objects; -import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; - -import org.apache.spark.network.buffer.ManagedBuffer; -import org.apache.spark.network.buffer.NettyManagedBuffer; - -/** - * A RPC that does not expect a reply, which is handled by a remote - * {@link org.apache.spark.network.server.RpcHandler}. - */ -public final class OneWayMessage extends AbstractMessage implements RequestMessage { - - public OneWayMessage(ManagedBuffer body) { - super(body, true); - } - - @Override - public Type type() { return Type.OneWayMessage; } - - @Override - public int encodedLength() { - // The integer (a.k.a. the body size) is not really used, since that information is already - // encoded in the frame length. But this maintains backwards compatibility with versions of - // RpcRequest that use Encoders.ByteArrays. - return 4; - } - - @Override - public void encode(ByteBuf buf) { - // See comment in encodedLength(). - buf.writeInt((int) body().size()); - } - - public static OneWayMessage decode(ByteBuf buf) { - // See comment in encodedLength(). - buf.readInt(); - return new OneWayMessage(new NettyManagedBuffer(buf.retain())); - } - - @Override - public int hashCode() { - return Objects.hashCode(body()); - } - - @Override - public boolean equals(Object other) { - if (other instanceof OneWayMessage) { - OneWayMessage o = (OneWayMessage) other; - return super.equals(o); - } - return false; - } - - @Override - public String toString() { - return Objects.toStringHelper(this) - .add("body", body()) - .toString(); - } -} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RequestMessage.java b/network/common/src/main/java/org/apache/spark/network/protocol/RequestMessage.java deleted file mode 100644 index 31b15bb17a..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/protocol/RequestMessage.java +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.protocol; - -import org.apache.spark.network.protocol.Message; - -/** Messages from the client to the server. */ -public interface RequestMessage extends Message { - // token interface -} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ResponseMessage.java b/network/common/src/main/java/org/apache/spark/network/protocol/ResponseMessage.java deleted file mode 100644 index 6edffd11cf..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/protocol/ResponseMessage.java +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.protocol; - -import org.apache.spark.network.protocol.Message; - -/** Messages from the server to the client. */ -public interface ResponseMessage extends Message { - // token interface -} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java deleted file mode 100644 index a76624ef5d..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.protocol; - -import com.google.common.base.Objects; -import io.netty.buffer.ByteBuf; - -/** Response to {@link RpcRequest} for a failed RPC. */ -public final class RpcFailure extends AbstractMessage implements ResponseMessage { - public final long requestId; - public final String errorString; - - public RpcFailure(long requestId, String errorString) { - this.requestId = requestId; - this.errorString = errorString; - } - - @Override - public Type type() { return Type.RpcFailure; } - - @Override - public int encodedLength() { - return 8 + Encoders.Strings.encodedLength(errorString); - } - - @Override - public void encode(ByteBuf buf) { - buf.writeLong(requestId); - Encoders.Strings.encode(buf, errorString); - } - - public static RpcFailure decode(ByteBuf buf) { - long requestId = buf.readLong(); - String errorString = Encoders.Strings.decode(buf); - return new RpcFailure(requestId, errorString); - } - - @Override - public int hashCode() { - return Objects.hashCode(requestId, errorString); - } - - @Override - public boolean equals(Object other) { - if (other instanceof RpcFailure) { - RpcFailure o = (RpcFailure) other; - return requestId == o.requestId && errorString.equals(o.errorString); - } - return false; - } - - @Override - public String toString() { - return Objects.toStringHelper(this) - .add("requestId", requestId) - .add("errorString", errorString) - .toString(); - } -} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java deleted file mode 100644 index 96213794a8..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.protocol; - -import com.google.common.base.Objects; -import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; - -import org.apache.spark.network.buffer.ManagedBuffer; -import org.apache.spark.network.buffer.NettyManagedBuffer; - -/** - * A generic RPC which is handled by a remote {@link org.apache.spark.network.server.RpcHandler}. - * This will correspond to a single - * {@link org.apache.spark.network.protocol.ResponseMessage} (either success or failure). - */ -public final class RpcRequest extends AbstractMessage implements RequestMessage { - /** Used to link an RPC request with its response. */ - public final long requestId; - - public RpcRequest(long requestId, ManagedBuffer message) { - super(message, true); - this.requestId = requestId; - } - - @Override - public Type type() { return Type.RpcRequest; } - - @Override - public int encodedLength() { - // The integer (a.k.a. the body size) is not really used, since that information is already - // encoded in the frame length. But this maintains backwards compatibility with versions of - // RpcRequest that use Encoders.ByteArrays. - return 8 + 4; - } - - @Override - public void encode(ByteBuf buf) { - buf.writeLong(requestId); - // See comment in encodedLength(). - buf.writeInt((int) body().size()); - } - - public static RpcRequest decode(ByteBuf buf) { - long requestId = buf.readLong(); - // See comment in encodedLength(). - buf.readInt(); - return new RpcRequest(requestId, new NettyManagedBuffer(buf.retain())); - } - - @Override - public int hashCode() { - return Objects.hashCode(requestId, body()); - } - - @Override - public boolean equals(Object other) { - if (other instanceof RpcRequest) { - RpcRequest o = (RpcRequest) other; - return requestId == o.requestId && super.equals(o); - } - return false; - } - - @Override - public String toString() { - return Objects.toStringHelper(this) - .add("requestId", requestId) - .add("body", body()) - .toString(); - } -} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java deleted file mode 100644 index bae866e14a..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.protocol; - -import com.google.common.base.Objects; -import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; - -import org.apache.spark.network.buffer.ManagedBuffer; -import org.apache.spark.network.buffer.NettyManagedBuffer; - -/** Response to {@link RpcRequest} for a successful RPC. */ -public final class RpcResponse extends AbstractResponseMessage { - public final long requestId; - - public RpcResponse(long requestId, ManagedBuffer message) { - super(message, true); - this.requestId = requestId; - } - - @Override - public Type type() { return Type.RpcResponse; } - - @Override - public int encodedLength() { - // The integer (a.k.a. the body size) is not really used, since that information is already - // encoded in the frame length. But this maintains backwards compatibility with versions of - // RpcRequest that use Encoders.ByteArrays. - return 8 + 4; - } - - @Override - public void encode(ByteBuf buf) { - buf.writeLong(requestId); - // See comment in encodedLength(). - buf.writeInt((int) body().size()); - } - - @Override - public ResponseMessage createFailureResponse(String error) { - return new RpcFailure(requestId, error); - } - - public static RpcResponse decode(ByteBuf buf) { - long requestId = buf.readLong(); - // See comment in encodedLength(). - buf.readInt(); - return new RpcResponse(requestId, new NettyManagedBuffer(buf.retain())); - } - - @Override - public int hashCode() { - return Objects.hashCode(requestId, body()); - } - - @Override - public boolean equals(Object other) { - if (other instanceof RpcResponse) { - RpcResponse o = (RpcResponse) other; - return requestId == o.requestId && super.equals(o); - } - return false; - } - - @Override - public String toString() { - return Objects.toStringHelper(this) - .add("requestId", requestId) - .add("body", body()) - .toString(); - } -} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java b/network/common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java deleted file mode 100644 index d46a263884..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.protocol; - -import com.google.common.base.Objects; -import io.netty.buffer.ByteBuf; - -/** -* Encapsulates a request for a particular chunk of a stream. -*/ -public final class StreamChunkId implements Encodable { - public final long streamId; - public final int chunkIndex; - - public StreamChunkId(long streamId, int chunkIndex) { - this.streamId = streamId; - this.chunkIndex = chunkIndex; - } - - @Override - public int encodedLength() { - return 8 + 4; - } - - public void encode(ByteBuf buffer) { - buffer.writeLong(streamId); - buffer.writeInt(chunkIndex); - } - - public static StreamChunkId decode(ByteBuf buffer) { - assert buffer.readableBytes() >= 8 + 4; - long streamId = buffer.readLong(); - int chunkIndex = buffer.readInt(); - return new StreamChunkId(streamId, chunkIndex); - } - - @Override - public int hashCode() { - return Objects.hashCode(streamId, chunkIndex); - } - - @Override - public boolean equals(Object other) { - if (other instanceof StreamChunkId) { - StreamChunkId o = (StreamChunkId) other; - return streamId == o.streamId && chunkIndex == o.chunkIndex; - } - return false; - } - - @Override - public String toString() { - return Objects.toStringHelper(this) - .add("streamId", streamId) - .add("chunkIndex", chunkIndex) - .toString(); - } -} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java deleted file mode 100644 index 26747ee55b..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.protocol; - -import com.google.common.base.Objects; -import io.netty.buffer.ByteBuf; - -import org.apache.spark.network.buffer.ManagedBuffer; -import org.apache.spark.network.buffer.NettyManagedBuffer; - -/** - * Message indicating an error when transferring a stream. - */ -public final class StreamFailure extends AbstractMessage implements ResponseMessage { - public final String streamId; - public final String error; - - public StreamFailure(String streamId, String error) { - this.streamId = streamId; - this.error = error; - } - - @Override - public Type type() { return Type.StreamFailure; } - - @Override - public int encodedLength() { - return Encoders.Strings.encodedLength(streamId) + Encoders.Strings.encodedLength(error); - } - - @Override - public void encode(ByteBuf buf) { - Encoders.Strings.encode(buf, streamId); - Encoders.Strings.encode(buf, error); - } - - public static StreamFailure decode(ByteBuf buf) { - String streamId = Encoders.Strings.decode(buf); - String error = Encoders.Strings.decode(buf); - return new StreamFailure(streamId, error); - } - - @Override - public int hashCode() { - return Objects.hashCode(streamId, error); - } - - @Override - public boolean equals(Object other) { - if (other instanceof StreamFailure) { - StreamFailure o = (StreamFailure) other; - return streamId.equals(o.streamId) && error.equals(o.error); - } - return false; - } - - @Override - public String toString() { - return Objects.toStringHelper(this) - .add("streamId", streamId) - .add("error", error) - .toString(); - } - -} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java deleted file mode 100644 index 35af5a84ba..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.protocol; - -import com.google.common.base.Objects; -import io.netty.buffer.ByteBuf; - -import org.apache.spark.network.buffer.ManagedBuffer; -import org.apache.spark.network.buffer.NettyManagedBuffer; - -/** - * Request to stream data from the remote end. - *

- * The stream ID is an arbitrary string that needs to be negotiated between the two endpoints before - * the data can be streamed. - */ -public final class StreamRequest extends AbstractMessage implements RequestMessage { - public final String streamId; - - public StreamRequest(String streamId) { - this.streamId = streamId; - } - - @Override - public Type type() { return Type.StreamRequest; } - - @Override - public int encodedLength() { - return Encoders.Strings.encodedLength(streamId); - } - - @Override - public void encode(ByteBuf buf) { - Encoders.Strings.encode(buf, streamId); - } - - public static StreamRequest decode(ByteBuf buf) { - String streamId = Encoders.Strings.decode(buf); - return new StreamRequest(streamId); - } - - @Override - public int hashCode() { - return Objects.hashCode(streamId); - } - - @Override - public boolean equals(Object other) { - if (other instanceof StreamRequest) { - StreamRequest o = (StreamRequest) other; - return streamId.equals(o.streamId); - } - return false; - } - - @Override - public String toString() { - return Objects.toStringHelper(this) - .add("streamId", streamId) - .toString(); - } - -} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java b/network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java deleted file mode 100644 index 51b899930f..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java +++ /dev/null @@ -1,92 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.protocol; - -import com.google.common.base.Objects; -import io.netty.buffer.ByteBuf; - -import org.apache.spark.network.buffer.ManagedBuffer; -import org.apache.spark.network.buffer.NettyManagedBuffer; - -/** - * Response to {@link StreamRequest} when the stream has been successfully opened. - *

- * Note the message itself does not contain the stream data. That is written separately by the - * sender. The receiver is expected to set a temporary channel handler that will consume the - * number of bytes this message says the stream has. - */ -public final class StreamResponse extends AbstractResponseMessage { - public final String streamId; - public final long byteCount; - - public StreamResponse(String streamId, long byteCount, ManagedBuffer buffer) { - super(buffer, false); - this.streamId = streamId; - this.byteCount = byteCount; - } - - @Override - public Type type() { return Type.StreamResponse; } - - @Override - public int encodedLength() { - return 8 + Encoders.Strings.encodedLength(streamId); - } - - /** Encoding does NOT include 'buffer' itself. See {@link MessageEncoder}. */ - @Override - public void encode(ByteBuf buf) { - Encoders.Strings.encode(buf, streamId); - buf.writeLong(byteCount); - } - - @Override - public ResponseMessage createFailureResponse(String error) { - return new StreamFailure(streamId, error); - } - - public static StreamResponse decode(ByteBuf buf) { - String streamId = Encoders.Strings.decode(buf); - long byteCount = buf.readLong(); - return new StreamResponse(streamId, byteCount, null); - } - - @Override - public int hashCode() { - return Objects.hashCode(byteCount, streamId, body()); - } - - @Override - public boolean equals(Object other) { - if (other instanceof StreamResponse) { - StreamResponse o = (StreamResponse) other; - return byteCount == o.byteCount && streamId.equals(o.streamId); - } - return false; - } - - @Override - public String toString() { - return Objects.toStringHelper(this) - .add("streamId", streamId) - .add("byteCount", byteCount) - .add("body", body()) - .toString(); - } - -} diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java deleted file mode 100644 index 68381037d6..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java +++ /dev/null @@ -1,109 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.sasl; - -import java.io.IOException; -import java.nio.ByteBuffer; -import javax.security.sasl.Sasl; -import javax.security.sasl.SaslException; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; -import io.netty.channel.Channel; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.client.TransportClientBootstrap; -import org.apache.spark.network.util.JavaUtils; -import org.apache.spark.network.util.TransportConf; - -/** - * Bootstraps a {@link TransportClient} by performing SASL authentication on the connection. The - * server should be setup with a {@link SaslRpcHandler} with matching keys for the given appId. - */ -public class SaslClientBootstrap implements TransportClientBootstrap { - private final Logger logger = LoggerFactory.getLogger(SaslClientBootstrap.class); - - private final boolean encrypt; - private final TransportConf conf; - private final String appId; - private final SecretKeyHolder secretKeyHolder; - - public SaslClientBootstrap(TransportConf conf, String appId, SecretKeyHolder secretKeyHolder) { - this(conf, appId, secretKeyHolder, false); - } - - public SaslClientBootstrap( - TransportConf conf, - String appId, - SecretKeyHolder secretKeyHolder, - boolean encrypt) { - this.conf = conf; - this.appId = appId; - this.secretKeyHolder = secretKeyHolder; - this.encrypt = encrypt; - } - - /** - * Performs SASL authentication by sending a token, and then proceeding with the SASL - * challenge-response tokens until we either successfully authenticate or throw an exception - * due to mismatch. - */ - @Override - public void doBootstrap(TransportClient client, Channel channel) { - SparkSaslClient saslClient = new SparkSaslClient(appId, secretKeyHolder, encrypt); - try { - byte[] payload = saslClient.firstToken(); - - while (!saslClient.isComplete()) { - SaslMessage msg = new SaslMessage(appId, payload); - ByteBuf buf = Unpooled.buffer(msg.encodedLength() + (int) msg.body().size()); - msg.encode(buf); - buf.writeBytes(msg.body().nioByteBuffer()); - - ByteBuffer response = client.sendRpcSync(buf.nioBuffer(), conf.saslRTTimeoutMs()); - payload = saslClient.response(JavaUtils.bufferToArray(response)); - } - - client.setClientId(appId); - - if (encrypt) { - if (!SparkSaslServer.QOP_AUTH_CONF.equals(saslClient.getNegotiatedProperty(Sasl.QOP))) { - throw new RuntimeException( - new SaslException("Encryption requests by negotiated non-encrypted connection.")); - } - SaslEncryption.addToChannel(channel, saslClient, conf.maxSaslEncryptedBlockSize()); - saslClient = null; - logger.debug("Channel {} configured for SASL encryption.", client); - } - } catch (IOException ioe) { - throw new RuntimeException(ioe); - } finally { - if (saslClient != null) { - try { - // Once authentication is complete, the server will trust all remaining communication. - saslClient.dispose(); - } catch (RuntimeException e) { - logger.error("Error while disposing SASL client", e); - } - } - } - } - -} diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java deleted file mode 100644 index 127335e4d3..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java +++ /dev/null @@ -1,291 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.sasl; - -import java.io.IOException; -import java.nio.ByteBuffer; -import java.nio.channels.WritableByteChannel; -import java.util.List; - -import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Preconditions; -import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; -import io.netty.channel.Channel; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelOutboundHandlerAdapter; -import io.netty.channel.ChannelPromise; -import io.netty.channel.FileRegion; -import io.netty.handler.codec.MessageToMessageDecoder; -import io.netty.util.AbstractReferenceCounted; -import io.netty.util.ReferenceCountUtil; - -import org.apache.spark.network.util.ByteArrayWritableChannel; -import org.apache.spark.network.util.NettyUtils; - -/** - * Provides SASL-based encription for transport channels. The single method exposed by this - * class installs the needed channel handlers on a connected channel. - */ -class SaslEncryption { - - @VisibleForTesting - static final String ENCRYPTION_HANDLER_NAME = "saslEncryption"; - - /** - * Adds channel handlers that perform encryption / decryption of data using SASL. - * - * @param channel The channel. - * @param backend The SASL backend. - * @param maxOutboundBlockSize Max size in bytes of outgoing encrypted blocks, to control - * memory usage. - */ - static void addToChannel( - Channel channel, - SaslEncryptionBackend backend, - int maxOutboundBlockSize) { - channel.pipeline() - .addFirst(ENCRYPTION_HANDLER_NAME, new EncryptionHandler(backend, maxOutboundBlockSize)) - .addFirst("saslDecryption", new DecryptionHandler(backend)) - .addFirst("saslFrameDecoder", NettyUtils.createFrameDecoder()); - } - - private static class EncryptionHandler extends ChannelOutboundHandlerAdapter { - - private final int maxOutboundBlockSize; - private final SaslEncryptionBackend backend; - - EncryptionHandler(SaslEncryptionBackend backend, int maxOutboundBlockSize) { - this.backend = backend; - this.maxOutboundBlockSize = maxOutboundBlockSize; - } - - /** - * Wrap the incoming message in an implementation that will perform encryption lazily. This is - * needed to guarantee ordering of the outgoing encrypted packets - they need to be decrypted in - * the same order, and netty doesn't have an atomic ChannelHandlerContext.write() API, so it - * does not guarantee any ordering. - */ - @Override - public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) - throws Exception { - - ctx.write(new EncryptedMessage(backend, msg, maxOutboundBlockSize), promise); - } - - @Override - public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { - try { - backend.dispose(); - } finally { - super.handlerRemoved(ctx); - } - } - - } - - private static class DecryptionHandler extends MessageToMessageDecoder { - - private final SaslEncryptionBackend backend; - - DecryptionHandler(SaslEncryptionBackend backend) { - this.backend = backend; - } - - @Override - protected void decode(ChannelHandlerContext ctx, ByteBuf msg, List out) - throws Exception { - - byte[] data; - int offset; - int length = msg.readableBytes(); - if (msg.hasArray()) { - data = msg.array(); - offset = msg.arrayOffset(); - msg.skipBytes(length); - } else { - data = new byte[length]; - msg.readBytes(data); - offset = 0; - } - - out.add(Unpooled.wrappedBuffer(backend.unwrap(data, offset, length))); - } - - } - - @VisibleForTesting - static class EncryptedMessage extends AbstractReferenceCounted implements FileRegion { - - private final SaslEncryptionBackend backend; - private final boolean isByteBuf; - private final ByteBuf buf; - private final FileRegion region; - - /** - * A channel used to buffer input data for encryption. The channel has an upper size bound - * so that if the input is larger than the allowed buffer, it will be broken into multiple - * chunks. - */ - private final ByteArrayWritableChannel byteChannel; - - private ByteBuf currentHeader; - private ByteBuffer currentChunk; - private long currentChunkSize; - private long currentReportedBytes; - private long unencryptedChunkSize; - private long transferred; - - EncryptedMessage(SaslEncryptionBackend backend, Object msg, int maxOutboundBlockSize) { - Preconditions.checkArgument(msg instanceof ByteBuf || msg instanceof FileRegion, - "Unrecognized message type: %s", msg.getClass().getName()); - this.backend = backend; - this.isByteBuf = msg instanceof ByteBuf; - this.buf = isByteBuf ? (ByteBuf) msg : null; - this.region = isByteBuf ? null : (FileRegion) msg; - this.byteChannel = new ByteArrayWritableChannel(maxOutboundBlockSize); - } - - /** - * Returns the size of the original (unencrypted) message. - * - * This makes assumptions about how netty treats FileRegion instances, because there's no way - * to know beforehand what will be the size of the encrypted message. Namely, it assumes - * that netty will try to transfer data from this message while - * transfered() < count(). So these two methods return, technically, wrong data, - * but netty doesn't know better. - */ - @Override - public long count() { - return isByteBuf ? buf.readableBytes() : region.count(); - } - - @Override - public long position() { - return 0; - } - - /** - * Returns an approximation of the amount of data transferred. See {@link #count()}. - */ - @Override - public long transfered() { - return transferred; - } - - /** - * Transfers data from the original message to the channel, encrypting it in the process. - * - * This method also breaks down the original message into smaller chunks when needed. This - * is done to keep memory usage under control. This avoids having to copy the whole message - * data into memory at once, and can avoid ballooning memory usage when transferring large - * messages such as shuffle blocks. - * - * The {@link #transfered()} counter also behaves a little funny, in that it won't go forward - * until a whole chunk has been written. This is done because the code can't use the actual - * number of bytes written to the channel as the transferred count (see {@link #count()}). - * Instead, once an encrypted chunk is written to the output (including its header), the - * size of the original block will be added to the {@link #transfered()} amount. - */ - @Override - public long transferTo(final WritableByteChannel target, final long position) - throws IOException { - - Preconditions.checkArgument(position == transfered(), "Invalid position."); - - long reportedWritten = 0L; - long actuallyWritten = 0L; - do { - if (currentChunk == null) { - nextChunk(); - } - - if (currentHeader.readableBytes() > 0) { - int bytesWritten = target.write(currentHeader.nioBuffer()); - currentHeader.skipBytes(bytesWritten); - actuallyWritten += bytesWritten; - if (currentHeader.readableBytes() > 0) { - // Break out of loop if there are still header bytes left to write. - break; - } - } - - actuallyWritten += target.write(currentChunk); - if (!currentChunk.hasRemaining()) { - // Only update the count of written bytes once a full chunk has been written. - // See method javadoc. - long chunkBytesRemaining = unencryptedChunkSize - currentReportedBytes; - reportedWritten += chunkBytesRemaining; - transferred += chunkBytesRemaining; - currentHeader.release(); - currentHeader = null; - currentChunk = null; - currentChunkSize = 0; - currentReportedBytes = 0; - } - } while (currentChunk == null && transfered() + reportedWritten < count()); - - // Returning 0 triggers a backoff mechanism in netty which may harm performance. Instead, - // we return 1 until we can (i.e. until the reported count would actually match the size - // of the current chunk), at which point we resort to returning 0 so that the counts still - // match, at the cost of some performance. That situation should be rare, though. - if (reportedWritten != 0L) { - return reportedWritten; - } - - if (actuallyWritten > 0 && currentReportedBytes < currentChunkSize - 1) { - transferred += 1L; - currentReportedBytes += 1L; - return 1L; - } - - return 0L; - } - - private void nextChunk() throws IOException { - byteChannel.reset(); - if (isByteBuf) { - int copied = byteChannel.write(buf.nioBuffer()); - buf.skipBytes(copied); - } else { - region.transferTo(byteChannel, region.transfered()); - } - - byte[] encrypted = backend.wrap(byteChannel.getData(), 0, byteChannel.length()); - this.currentChunk = ByteBuffer.wrap(encrypted); - this.currentChunkSize = encrypted.length; - this.currentHeader = Unpooled.copyLong(8 + currentChunkSize); - this.unencryptedChunkSize = byteChannel.length(); - } - - @Override - protected void deallocate() { - if (currentHeader != null) { - currentHeader.release(); - } - if (buf != null) { - buf.release(); - } - if (region != null) { - region.release(); - } - } - - } - -} diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslEncryptionBackend.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslEncryptionBackend.java deleted file mode 100644 index 89b78bc7e1..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslEncryptionBackend.java +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.sasl; - -import javax.security.sasl.SaslException; - -interface SaslEncryptionBackend { - - /** Disposes of resources used by the backend. */ - void dispose(); - - /** Encrypt data. */ - byte[] wrap(byte[] data, int offset, int len) throws SaslException; - - /** Decrypt data. */ - byte[] unwrap(byte[] data, int offset, int len) throws SaslException; - -} diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java deleted file mode 100644 index e52b526f09..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.sasl; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; - -import org.apache.spark.network.buffer.NettyManagedBuffer; -import org.apache.spark.network.protocol.Encoders; -import org.apache.spark.network.protocol.AbstractMessage; - -/** - * Encodes a Sasl-related message which is attempting to authenticate using some credentials tagged - * with the given appId. This appId allows a single SaslRpcHandler to multiplex different - * applications which may be using different sets of credentials. - */ -class SaslMessage extends AbstractMessage { - - /** Serialization tag used to catch incorrect payloads. */ - private static final byte TAG_BYTE = (byte) 0xEA; - - public final String appId; - - public SaslMessage(String appId, byte[] message) { - this(appId, Unpooled.wrappedBuffer(message)); - } - - public SaslMessage(String appId, ByteBuf message) { - super(new NettyManagedBuffer(message), true); - this.appId = appId; - } - - @Override - public Type type() { return Type.User; } - - @Override - public int encodedLength() { - // The integer (a.k.a. the body size) is not really used, since that information is already - // encoded in the frame length. But this maintains backwards compatibility with versions of - // RpcRequest that use Encoders.ByteArrays. - return 1 + Encoders.Strings.encodedLength(appId) + 4; - } - - @Override - public void encode(ByteBuf buf) { - buf.writeByte(TAG_BYTE); - Encoders.Strings.encode(buf, appId); - // See comment in encodedLength(). - buf.writeInt((int) body().size()); - } - - public static SaslMessage decode(ByteBuf buf) { - if (buf.readByte() != TAG_BYTE) { - throw new IllegalStateException("Expected SaslMessage, received something else" - + " (maybe your client does not have SASL enabled?)"); - } - - String appId = Encoders.Strings.decode(buf); - // See comment in encodedLength(). - buf.readInt(); - return new SaslMessage(appId, buf.retain()); - } -} diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java deleted file mode 100644 index c41f5b6873..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java +++ /dev/null @@ -1,158 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.sasl; - -import java.io.IOException; -import java.nio.ByteBuffer; -import javax.security.sasl.Sasl; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; -import io.netty.channel.Channel; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import org.apache.spark.network.client.RpcResponseCallback; -import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.server.RpcHandler; -import org.apache.spark.network.server.StreamManager; -import org.apache.spark.network.util.JavaUtils; -import org.apache.spark.network.util.TransportConf; - -/** - * RPC Handler which performs SASL authentication before delegating to a child RPC handler. - * The delegate will only receive messages if the given connection has been successfully - * authenticated. A connection may be authenticated at most once. - * - * Note that the authentication process consists of multiple challenge-response pairs, each of - * which are individual RPCs. - */ -class SaslRpcHandler extends RpcHandler { - private static final Logger logger = LoggerFactory.getLogger(SaslRpcHandler.class); - - /** Transport configuration. */ - private final TransportConf conf; - - /** The client channel. */ - private final Channel channel; - - /** RpcHandler we will delegate to for authenticated connections. */ - private final RpcHandler delegate; - - /** Class which provides secret keys which are shared by server and client on a per-app basis. */ - private final SecretKeyHolder secretKeyHolder; - - private SparkSaslServer saslServer; - private boolean isComplete; - - SaslRpcHandler( - TransportConf conf, - Channel channel, - RpcHandler delegate, - SecretKeyHolder secretKeyHolder) { - this.conf = conf; - this.channel = channel; - this.delegate = delegate; - this.secretKeyHolder = secretKeyHolder; - this.saslServer = null; - this.isComplete = false; - } - - @Override - public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) { - if (isComplete) { - // Authentication complete, delegate to base handler. - delegate.receive(client, message, callback); - return; - } - - ByteBuf nettyBuf = Unpooled.wrappedBuffer(message); - SaslMessage saslMessage; - try { - saslMessage = SaslMessage.decode(nettyBuf); - } finally { - nettyBuf.release(); - } - - if (saslServer == null) { - // First message in the handshake, setup the necessary state. - client.setClientId(saslMessage.appId); - saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder, - conf.saslServerAlwaysEncrypt()); - } - - byte[] response; - try { - response = saslServer.response(JavaUtils.bufferToArray( - saslMessage.body().nioByteBuffer())); - } catch (IOException ioe) { - throw new RuntimeException(ioe); - } - callback.onSuccess(ByteBuffer.wrap(response)); - - // Setup encryption after the SASL response is sent, otherwise the client can't parse the - // response. It's ok to change the channel pipeline here since we are processing an incoming - // message, so the pipeline is busy and no new incoming messages will be fed to it before this - // method returns. This assumes that the code ensures, through other means, that no outbound - // messages are being written to the channel while negotiation is still going on. - if (saslServer.isComplete()) { - logger.debug("SASL authentication successful for channel {}", client); - isComplete = true; - if (SparkSaslServer.QOP_AUTH_CONF.equals(saslServer.getNegotiatedProperty(Sasl.QOP))) { - logger.debug("Enabling encryption for channel {}", client); - SaslEncryption.addToChannel(channel, saslServer, conf.maxSaslEncryptedBlockSize()); - saslServer = null; - } else { - saslServer.dispose(); - saslServer = null; - } - } - } - - @Override - public void receive(TransportClient client, ByteBuffer message) { - delegate.receive(client, message); - } - - @Override - public StreamManager getStreamManager() { - return delegate.getStreamManager(); - } - - @Override - public void channelActive(TransportClient client) { - delegate.channelActive(client); - } - - @Override - public void channelInactive(TransportClient client) { - try { - delegate.channelInactive(client); - } finally { - if (saslServer != null) { - saslServer.dispose(); - } - } - } - - @Override - public void exceptionCaught(Throwable cause, TransportClient client) { - delegate.exceptionCaught(cause, client); - } - -} diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslServerBootstrap.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslServerBootstrap.java deleted file mode 100644 index f2f983856f..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslServerBootstrap.java +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.sasl; - -import io.netty.channel.Channel; - -import org.apache.spark.network.server.RpcHandler; -import org.apache.spark.network.server.TransportServerBootstrap; -import org.apache.spark.network.util.TransportConf; - -/** - * A bootstrap which is executed on a TransportServer's client channel once a client connects - * to the server. This allows customizing the client channel to allow for things such as SASL - * authentication. - */ -public class SaslServerBootstrap implements TransportServerBootstrap { - - private final TransportConf conf; - private final SecretKeyHolder secretKeyHolder; - - public SaslServerBootstrap(TransportConf conf, SecretKeyHolder secretKeyHolder) { - this.conf = conf; - this.secretKeyHolder = secretKeyHolder; - } - - /** - * Wrap the given application handler in a SaslRpcHandler that will handle the initial SASL - * negotiation. - */ - public RpcHandler doBootstrap(Channel channel, RpcHandler rpcHandler) { - return new SaslRpcHandler(conf, channel, rpcHandler, secretKeyHolder); - } - -} diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java b/network/common/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java deleted file mode 100644 index 81d5766794..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.sasl; - -/** - * Interface for getting a secret key associated with some application. - */ -public interface SecretKeyHolder { - /** - * Gets an appropriate SASL User for the given appId. - * @throws IllegalArgumentException if the given appId is not associated with a SASL user. - */ - String getSaslUser(String appId); - - /** - * Gets an appropriate SASL secret key for the given appId. - * @throws IllegalArgumentException if the given appId is not associated with a SASL secret key. - */ - String getSecretKey(String appId); -} diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java b/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java deleted file mode 100644 index 94685e91b8..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java +++ /dev/null @@ -1,162 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.sasl; - -import java.io.IOException; -import java.util.Map; -import javax.security.auth.callback.Callback; -import javax.security.auth.callback.CallbackHandler; -import javax.security.auth.callback.NameCallback; -import javax.security.auth.callback.PasswordCallback; -import javax.security.auth.callback.UnsupportedCallbackException; -import javax.security.sasl.RealmCallback; -import javax.security.sasl.RealmChoiceCallback; -import javax.security.sasl.Sasl; -import javax.security.sasl.SaslClient; -import javax.security.sasl.SaslException; - -import com.google.common.base.Throwables; -import com.google.common.collect.ImmutableMap; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import static org.apache.spark.network.sasl.SparkSaslServer.*; - -/** - * A SASL Client for Spark which simply keeps track of the state of a single SASL session, from the - * initial state to the "authenticated" state. This client initializes the protocol via a - * firstToken, which is then followed by a set of challenges and responses. - */ -public class SparkSaslClient implements SaslEncryptionBackend { - private final Logger logger = LoggerFactory.getLogger(SparkSaslClient.class); - - private final String secretKeyId; - private final SecretKeyHolder secretKeyHolder; - private final String expectedQop; - private SaslClient saslClient; - - public SparkSaslClient(String secretKeyId, SecretKeyHolder secretKeyHolder, boolean encrypt) { - this.secretKeyId = secretKeyId; - this.secretKeyHolder = secretKeyHolder; - this.expectedQop = encrypt ? QOP_AUTH_CONF : QOP_AUTH; - - Map saslProps = ImmutableMap.builder() - .put(Sasl.QOP, expectedQop) - .build(); - try { - this.saslClient = Sasl.createSaslClient(new String[] { DIGEST }, null, null, DEFAULT_REALM, - saslProps, new ClientCallbackHandler()); - } catch (SaslException e) { - throw Throwables.propagate(e); - } - } - - /** Used to initiate SASL handshake with server. */ - public synchronized byte[] firstToken() { - if (saslClient != null && saslClient.hasInitialResponse()) { - try { - return saslClient.evaluateChallenge(new byte[0]); - } catch (SaslException e) { - throw Throwables.propagate(e); - } - } else { - return new byte[0]; - } - } - - /** Determines whether the authentication exchange has completed. */ - public synchronized boolean isComplete() { - return saslClient != null && saslClient.isComplete(); - } - - /** Returns the value of a negotiated property. */ - public Object getNegotiatedProperty(String name) { - return saslClient.getNegotiatedProperty(name); - } - - /** - * Respond to server's SASL token. - * @param token contains server's SASL token - * @return client's response SASL token - */ - public synchronized byte[] response(byte[] token) { - try { - return saslClient != null ? saslClient.evaluateChallenge(token) : new byte[0]; - } catch (SaslException e) { - throw Throwables.propagate(e); - } - } - - /** - * Disposes of any system resources or security-sensitive information the - * SaslClient might be using. - */ - @Override - public synchronized void dispose() { - if (saslClient != null) { - try { - saslClient.dispose(); - } catch (SaslException e) { - // ignore - } finally { - saslClient = null; - } - } - } - - /** - * Implementation of javax.security.auth.callback.CallbackHandler - * that works with share secrets. - */ - private class ClientCallbackHandler implements CallbackHandler { - @Override - public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException { - - for (Callback callback : callbacks) { - if (callback instanceof NameCallback) { - logger.trace("SASL client callback: setting username"); - NameCallback nc = (NameCallback) callback; - nc.setName(encodeIdentifier(secretKeyHolder.getSaslUser(secretKeyId))); - } else if (callback instanceof PasswordCallback) { - logger.trace("SASL client callback: setting password"); - PasswordCallback pc = (PasswordCallback) callback; - pc.setPassword(encodePassword(secretKeyHolder.getSecretKey(secretKeyId))); - } else if (callback instanceof RealmCallback) { - logger.trace("SASL client callback: setting realm"); - RealmCallback rc = (RealmCallback) callback; - rc.setText(rc.getDefaultText()); - } else if (callback instanceof RealmChoiceCallback) { - // ignore (?) - } else { - throw new UnsupportedCallbackException(callback, "Unrecognized SASL DIGEST-MD5 Callback"); - } - } - } - } - - @Override - public byte[] wrap(byte[] data, int offset, int len) throws SaslException { - return saslClient.wrap(data, offset, len); - } - - @Override - public byte[] unwrap(byte[] data, int offset, int len) throws SaslException { - return saslClient.unwrap(data, offset, len); - } - -} diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java b/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java deleted file mode 100644 index 431cb67a2a..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java +++ /dev/null @@ -1,200 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.sasl; - -import javax.security.auth.callback.Callback; -import javax.security.auth.callback.CallbackHandler; -import javax.security.auth.callback.NameCallback; -import javax.security.auth.callback.PasswordCallback; -import javax.security.auth.callback.UnsupportedCallbackException; -import javax.security.sasl.AuthorizeCallback; -import javax.security.sasl.RealmCallback; -import javax.security.sasl.Sasl; -import javax.security.sasl.SaslException; -import javax.security.sasl.SaslServer; -import java.io.IOException; -import java.util.Map; - -import com.google.common.base.Charsets; -import com.google.common.base.Preconditions; -import com.google.common.base.Throwables; -import com.google.common.collect.ImmutableMap; -import io.netty.buffer.Unpooled; -import io.netty.handler.codec.base64.Base64; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * A SASL Server for Spark which simply keeps track of the state of a single SASL session, from the - * initial state to the "authenticated" state. (It is not a server in the sense of accepting - * connections on some socket.) - */ -public class SparkSaslServer implements SaslEncryptionBackend { - private final Logger logger = LoggerFactory.getLogger(SparkSaslServer.class); - - /** - * This is passed as the server name when creating the sasl client/server. - * This could be changed to be configurable in the future. - */ - static final String DEFAULT_REALM = "default"; - - /** - * The authentication mechanism used here is DIGEST-MD5. This could be changed to be - * configurable in the future. - */ - static final String DIGEST = "DIGEST-MD5"; - - /** - * Quality of protection value that includes encryption. - */ - static final String QOP_AUTH_CONF = "auth-conf"; - - /** - * Quality of protection value that does not include encryption. - */ - static final String QOP_AUTH = "auth"; - - /** Identifier for a certain secret key within the secretKeyHolder. */ - private final String secretKeyId; - private final SecretKeyHolder secretKeyHolder; - private SaslServer saslServer; - - public SparkSaslServer( - String secretKeyId, - SecretKeyHolder secretKeyHolder, - boolean alwaysEncrypt) { - this.secretKeyId = secretKeyId; - this.secretKeyHolder = secretKeyHolder; - - // Sasl.QOP is a comma-separated list of supported values. The value that allows encryption - // is listed first since it's preferred over the non-encrypted one (if the client also - // lists both in the request). - String qop = alwaysEncrypt ? QOP_AUTH_CONF : String.format("%s,%s", QOP_AUTH_CONF, QOP_AUTH); - Map saslProps = ImmutableMap.builder() - .put(Sasl.SERVER_AUTH, "true") - .put(Sasl.QOP, qop) - .build(); - try { - this.saslServer = Sasl.createSaslServer(DIGEST, null, DEFAULT_REALM, saslProps, - new DigestCallbackHandler()); - } catch (SaslException e) { - throw Throwables.propagate(e); - } - } - - /** - * Determines whether the authentication exchange has completed successfully. - */ - public synchronized boolean isComplete() { - return saslServer != null && saslServer.isComplete(); - } - - /** Returns the value of a negotiated property. */ - public Object getNegotiatedProperty(String name) { - return saslServer.getNegotiatedProperty(name); - } - - /** - * Used to respond to server SASL tokens. - * @param token Server's SASL token - * @return response to send back to the server. - */ - public synchronized byte[] response(byte[] token) { - try { - return saslServer != null ? saslServer.evaluateResponse(token) : new byte[0]; - } catch (SaslException e) { - throw Throwables.propagate(e); - } - } - - /** - * Disposes of any system resources or security-sensitive information the - * SaslServer might be using. - */ - @Override - public synchronized void dispose() { - if (saslServer != null) { - try { - saslServer.dispose(); - } catch (SaslException e) { - // ignore - } finally { - saslServer = null; - } - } - } - - @Override - public byte[] wrap(byte[] data, int offset, int len) throws SaslException { - return saslServer.wrap(data, offset, len); - } - - @Override - public byte[] unwrap(byte[] data, int offset, int len) throws SaslException { - return saslServer.unwrap(data, offset, len); - } - - /** - * Implementation of javax.security.auth.callback.CallbackHandler for SASL DIGEST-MD5 mechanism. - */ - private class DigestCallbackHandler implements CallbackHandler { - @Override - public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException { - for (Callback callback : callbacks) { - if (callback instanceof NameCallback) { - logger.trace("SASL server callback: setting username"); - NameCallback nc = (NameCallback) callback; - nc.setName(encodeIdentifier(secretKeyHolder.getSaslUser(secretKeyId))); - } else if (callback instanceof PasswordCallback) { - logger.trace("SASL server callback: setting password"); - PasswordCallback pc = (PasswordCallback) callback; - pc.setPassword(encodePassword(secretKeyHolder.getSecretKey(secretKeyId))); - } else if (callback instanceof RealmCallback) { - logger.trace("SASL server callback: setting realm"); - RealmCallback rc = (RealmCallback) callback; - rc.setText(rc.getDefaultText()); - } else if (callback instanceof AuthorizeCallback) { - AuthorizeCallback ac = (AuthorizeCallback) callback; - String authId = ac.getAuthenticationID(); - String authzId = ac.getAuthorizationID(); - ac.setAuthorized(authId.equals(authzId)); - if (ac.isAuthorized()) { - ac.setAuthorizedID(authzId); - } - logger.debug("SASL Authorization complete, authorized set to {}", ac.isAuthorized()); - } else { - throw new UnsupportedCallbackException(callback, "Unrecognized SASL DIGEST-MD5 Callback"); - } - } - } - } - - /* Encode a byte[] identifier as a Base64-encoded string. */ - public static String encodeIdentifier(String identifier) { - Preconditions.checkNotNull(identifier, "User cannot be null if SASL is enabled"); - return Base64.encode(Unpooled.wrappedBuffer(identifier.getBytes(Charsets.UTF_8))) - .toString(Charsets.UTF_8); - } - - /** Encode a password as a base64-encoded char[] array. */ - public static char[] encodePassword(String password) { - Preconditions.checkNotNull(password, "Password cannot be null if SASL is enabled"); - return Base64.encode(Unpooled.wrappedBuffer(password.getBytes(Charsets.UTF_8))) - .toString(Charsets.UTF_8).toCharArray(); - } -} diff --git a/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java b/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java deleted file mode 100644 index 4a1f28e9ff..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.server; - -import org.apache.spark.network.protocol.Message; - -/** - * Handles either request or response messages coming off of Netty. A MessageHandler instance - * is associated with a single Netty Channel (though it may have multiple clients on the same - * Channel.) - */ -public abstract class MessageHandler { - /** Handles the receipt of a single message. */ - public abstract void handle(T message) throws Exception; - - /** Invoked when the channel this MessageHandler is on is active. */ - public abstract void channelActive(); - - /** Invoked when an exception was caught on the Channel. */ - public abstract void exceptionCaught(Throwable cause); - - /** Invoked when the channel this MessageHandler is on is inactive. */ - public abstract void channelInactive(); -} diff --git a/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java deleted file mode 100644 index 6ed61da5c7..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.server; - -import java.nio.ByteBuffer; - -import org.apache.spark.network.client.RpcResponseCallback; -import org.apache.spark.network.client.TransportClient; - -/** An RpcHandler suitable for a client-only TransportContext, which cannot receive RPCs. */ -public class NoOpRpcHandler extends RpcHandler { - private final StreamManager streamManager; - - public NoOpRpcHandler() { - streamManager = new OneForOneStreamManager(); - } - - @Override - public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) { - throw new UnsupportedOperationException("Cannot handle messages"); - } - - @Override - public StreamManager getStreamManager() { return streamManager; } -} diff --git a/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java deleted file mode 100644 index ea9e735e0a..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java +++ /dev/null @@ -1,143 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.server; - -import java.util.Iterator; -import java.util.Map; -import java.util.Random; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.atomic.AtomicLong; - -import com.google.common.base.Preconditions; -import io.netty.channel.Channel; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import org.apache.spark.network.buffer.ManagedBuffer; -import org.apache.spark.network.client.TransportClient; - -/** - * StreamManager which allows registration of an Iterator<ManagedBuffer>, which are individually - * fetched as chunks by the client. Each registered buffer is one chunk. - */ -public class OneForOneStreamManager extends StreamManager { - private final Logger logger = LoggerFactory.getLogger(OneForOneStreamManager.class); - - private final AtomicLong nextStreamId; - private final ConcurrentHashMap streams; - - /** State of a single stream. */ - private static class StreamState { - final String appId; - final Iterator buffers; - - // The channel associated to the stream - Channel associatedChannel = null; - - // Used to keep track of the index of the buffer that the user has retrieved, just to ensure - // that the caller only requests each chunk one at a time, in order. - int curChunk = 0; - - StreamState(String appId, Iterator buffers) { - this.appId = appId; - this.buffers = Preconditions.checkNotNull(buffers); - } - } - - public OneForOneStreamManager() { - // For debugging purposes, start with a random stream id to help identifying different streams. - // This does not need to be globally unique, only unique to this class. - nextStreamId = new AtomicLong((long) new Random().nextInt(Integer.MAX_VALUE) * 1000); - streams = new ConcurrentHashMap(); - } - - @Override - public void registerChannel(Channel channel, long streamId) { - if (streams.containsKey(streamId)) { - streams.get(streamId).associatedChannel = channel; - } - } - - @Override - public ManagedBuffer getChunk(long streamId, int chunkIndex) { - StreamState state = streams.get(streamId); - if (chunkIndex != state.curChunk) { - throw new IllegalStateException(String.format( - "Received out-of-order chunk index %s (expected %s)", chunkIndex, state.curChunk)); - } else if (!state.buffers.hasNext()) { - throw new IllegalStateException(String.format( - "Requested chunk index beyond end %s", chunkIndex)); - } - state.curChunk += 1; - ManagedBuffer nextChunk = state.buffers.next(); - - if (!state.buffers.hasNext()) { - logger.trace("Removing stream id {}", streamId); - streams.remove(streamId); - } - - return nextChunk; - } - - @Override - public void connectionTerminated(Channel channel) { - // Close all streams which have been associated with the channel. - for (Map.Entry entry: streams.entrySet()) { - StreamState state = entry.getValue(); - if (state.associatedChannel == channel) { - streams.remove(entry.getKey()); - - // Release all remaining buffers. - while (state.buffers.hasNext()) { - state.buffers.next().release(); - } - } - } - } - - @Override - public void checkAuthorization(TransportClient client, long streamId) { - if (client.getClientId() != null) { - StreamState state = streams.get(streamId); - Preconditions.checkArgument(state != null, "Unknown stream ID."); - if (!client.getClientId().equals(state.appId)) { - throw new SecurityException(String.format( - "Client %s not authorized to read stream %d (app %s).", - client.getClientId(), - streamId, - state.appId)); - } - } - } - - /** - * Registers a stream of ManagedBuffers which are served as individual chunks one at a time to - * callers. Each ManagedBuffer will be release()'d after it is transferred on the wire. If a - * client connection is closed before the iterator is fully drained, then the remaining buffers - * will all be release()'d. - * - * If an app ID is provided, only callers who've authenticated with the given app ID will be - * allowed to fetch from this stream. - */ - public long registerStream(String appId, Iterator buffers) { - long myStreamId = nextStreamId.getAndIncrement(); - streams.put(myStreamId, new StreamState(appId, buffers)); - return myStreamId; - } - -} diff --git a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java deleted file mode 100644 index a99c3015b0..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java +++ /dev/null @@ -1,100 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.server; - -import java.nio.ByteBuffer; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import org.apache.spark.network.client.RpcResponseCallback; -import org.apache.spark.network.client.TransportClient; - -/** - * Handler for sendRPC() messages sent by {@link org.apache.spark.network.client.TransportClient}s. - */ -public abstract class RpcHandler { - - private static final RpcResponseCallback ONE_WAY_CALLBACK = new OneWayRpcCallback(); - - /** - * Receive a single RPC message. Any exception thrown while in this method will be sent back to - * the client in string form as a standard RPC failure. - * - * This method will not be called in parallel for a single TransportClient (i.e., channel). - * - * @param client A channel client which enables the handler to make requests back to the sender - * of this RPC. This will always be the exact same object for a particular channel. - * @param message The serialized bytes of the RPC. - * @param callback Callback which should be invoked exactly once upon success or failure of the - * RPC. - */ - public abstract void receive( - TransportClient client, - ByteBuffer message, - RpcResponseCallback callback); - - /** - * Returns the StreamManager which contains the state about which streams are currently being - * fetched by a TransportClient. - */ - public abstract StreamManager getStreamManager(); - - /** - * Receives an RPC message that does not expect a reply. The default implementation will - * call "{@link #receive(TransportClient, ByteBuffer, RpcResponseCallback)}" and log a warning if - * any of the callback methods are called. - * - * @param client A channel client which enables the handler to make requests back to the sender - * of this RPC. This will always be the exact same object for a particular channel. - * @param message The serialized bytes of the RPC. - */ - public void receive(TransportClient client, ByteBuffer message) { - receive(client, message, ONE_WAY_CALLBACK); - } - - /** - * Invoked when the channel associated with the given client is active. - */ - public void channelActive(TransportClient client) { } - - /** - * Invoked when the channel associated with the given client is inactive. - * No further requests will come from this client. - */ - public void channelInactive(TransportClient client) { } - - public void exceptionCaught(Throwable cause, TransportClient client) { } - - private static class OneWayRpcCallback implements RpcResponseCallback { - - private final Logger logger = LoggerFactory.getLogger(OneWayRpcCallback.class); - - @Override - public void onSuccess(ByteBuffer response) { - logger.warn("Response provided for one-way RPC."); - } - - @Override - public void onFailure(Throwable e) { - logger.error("Error response provided for one-way RPC.", e); - } - - } - -} diff --git a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java deleted file mode 100644 index 07f161a29c..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java +++ /dev/null @@ -1,86 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.server; - -import io.netty.channel.Channel; - -import org.apache.spark.network.buffer.ManagedBuffer; -import org.apache.spark.network.client.TransportClient; - -/** - * The StreamManager is used to fetch individual chunks from a stream. This is used in - * {@link TransportRequestHandler} in order to respond to fetchChunk() requests. Creation of the - * stream is outside the scope of the transport layer, but a given stream is guaranteed to be read - * by only one client connection, meaning that getChunk() for a particular stream will be called - * serially and that once the connection associated with the stream is closed, that stream will - * never be used again. - */ -public abstract class StreamManager { - /** - * Called in response to a fetchChunk() request. The returned buffer will be passed as-is to the - * client. A single stream will be associated with a single TCP connection, so this method - * will not be called in parallel for a particular stream. - * - * Chunks may be requested in any order, and requests may be repeated, but it is not required - * that implementations support this behavior. - * - * The returned ManagedBuffer will be release()'d after being written to the network. - * - * @param streamId id of a stream that has been previously registered with the StreamManager. - * @param chunkIndex 0-indexed chunk of the stream that's requested - */ - public abstract ManagedBuffer getChunk(long streamId, int chunkIndex); - - /** - * Called in response to a stream() request. The returned data is streamed to the client - * through a single TCP connection. - * - * Note the streamId argument is not related to the similarly named argument in the - * {@link #getChunk(long, int)} method. - * - * @param streamId id of a stream that has been previously registered with the StreamManager. - * @return A managed buffer for the stream, or null if the stream was not found. - */ - public ManagedBuffer openStream(String streamId) { - throw new UnsupportedOperationException(); - } - - /** - * Associates a stream with a single client connection, which is guaranteed to be the only reader - * of the stream. The getChunk() method will be called serially on this connection and once the - * connection is closed, the stream will never be used again, enabling cleanup. - * - * This must be called before the first getChunk() on the stream, but it may be invoked multiple - * times with the same channel and stream id. - */ - public void registerChannel(Channel channel, long streamId) { } - - /** - * Indicates that the given channel has been terminated. After this occurs, we are guaranteed not - * to read from the associated streams again, so any state can be cleaned up. - */ - public void connectionTerminated(Channel channel) { } - - /** - * Verify that the client is authorized to read from the given stream. - * - * @throws SecurityException If client is not authorized. - */ - public void checkAuthorization(TransportClient client, long streamId) { } - -} diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java deleted file mode 100644 index 18a9b7887e..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java +++ /dev/null @@ -1,163 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.server; - -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.SimpleChannelInboundHandler; -import io.netty.handler.timeout.IdleState; -import io.netty.handler.timeout.IdleStateEvent; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.client.TransportResponseHandler; -import org.apache.spark.network.protocol.Message; -import org.apache.spark.network.protocol.RequestMessage; -import org.apache.spark.network.protocol.ResponseMessage; -import org.apache.spark.network.util.NettyUtils; - -/** - * The single Transport-level Channel handler which is used for delegating requests to the - * {@link TransportRequestHandler} and responses to the {@link TransportResponseHandler}. - * - * All channels created in the transport layer are bidirectional. When the Client initiates a Netty - * Channel with a RequestMessage (which gets handled by the Server's RequestHandler), the Server - * will produce a ResponseMessage (handled by the Client's ResponseHandler). However, the Server - * also gets a handle on the same Channel, so it may then begin to send RequestMessages to the - * Client. - * This means that the Client also needs a RequestHandler and the Server needs a ResponseHandler, - * for the Client's responses to the Server's requests. - * - * This class also handles timeouts from a {@link io.netty.handler.timeout.IdleStateHandler}. - * We consider a connection timed out if there are outstanding fetch or RPC requests but no traffic - * on the channel for at least `requestTimeoutMs`. Note that this is duplex traffic; we will not - * timeout if the client is continuously sending but getting no responses, for simplicity. - */ -public class TransportChannelHandler extends SimpleChannelInboundHandler { - private final Logger logger = LoggerFactory.getLogger(TransportChannelHandler.class); - - private final TransportClient client; - private final TransportResponseHandler responseHandler; - private final TransportRequestHandler requestHandler; - private final long requestTimeoutNs; - private final boolean closeIdleConnections; - - public TransportChannelHandler( - TransportClient client, - TransportResponseHandler responseHandler, - TransportRequestHandler requestHandler, - long requestTimeoutMs, - boolean closeIdleConnections) { - this.client = client; - this.responseHandler = responseHandler; - this.requestHandler = requestHandler; - this.requestTimeoutNs = requestTimeoutMs * 1000L * 1000; - this.closeIdleConnections = closeIdleConnections; - } - - public TransportClient getClient() { - return client; - } - - @Override - public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { - logger.warn("Exception in connection from " + NettyUtils.getRemoteAddress(ctx.channel()), - cause); - requestHandler.exceptionCaught(cause); - responseHandler.exceptionCaught(cause); - ctx.close(); - } - - @Override - public void channelActive(ChannelHandlerContext ctx) throws Exception { - try { - requestHandler.channelActive(); - } catch (RuntimeException e) { - logger.error("Exception from request handler while registering channel", e); - } - try { - responseHandler.channelActive(); - } catch (RuntimeException e) { - logger.error("Exception from response handler while registering channel", e); - } - super.channelRegistered(ctx); - } - - @Override - public void channelInactive(ChannelHandlerContext ctx) throws Exception { - try { - requestHandler.channelInactive(); - } catch (RuntimeException e) { - logger.error("Exception from request handler while unregistering channel", e); - } - try { - responseHandler.channelInactive(); - } catch (RuntimeException e) { - logger.error("Exception from response handler while unregistering channel", e); - } - super.channelUnregistered(ctx); - } - - @Override - public void channelRead0(ChannelHandlerContext ctx, Message request) throws Exception { - if (request instanceof RequestMessage) { - requestHandler.handle((RequestMessage) request); - } else { - responseHandler.handle((ResponseMessage) request); - } - } - - /** Triggered based on events from an {@link io.netty.handler.timeout.IdleStateHandler}. */ - @Override - public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { - if (evt instanceof IdleStateEvent) { - IdleStateEvent e = (IdleStateEvent) evt; - // See class comment for timeout semantics. In addition to ensuring we only timeout while - // there are outstanding requests, we also do a secondary consistency check to ensure - // there's no race between the idle timeout and incrementing the numOutstandingRequests - // (see SPARK-7003). - // - // To avoid a race between TransportClientFactory.createClient() and this code which could - // result in an inactive client being returned, this needs to run in a synchronized block. - synchronized (this) { - boolean isActuallyOverdue = - System.nanoTime() - responseHandler.getTimeOfLastRequestNs() > requestTimeoutNs; - if (e.state() == IdleState.ALL_IDLE && isActuallyOverdue) { - if (responseHandler.numOutstandingRequests() > 0) { - String address = NettyUtils.getRemoteAddress(ctx.channel()); - logger.error("Connection to {} has been quiet for {} ms while there are outstanding " + - "requests. Assuming connection is dead; please adjust spark.network.timeout if this " + - "is wrong.", address, requestTimeoutNs / 1000 / 1000); - client.timeOut(); - ctx.close(); - } else if (closeIdleConnections) { - // While CloseIdleConnections is enable, we also close idle connection - client.timeOut(); - ctx.close(); - } - } - } - } - ctx.fireUserEventTriggered(evt); - } - - public TransportResponseHandler getResponseHandler() { - return responseHandler; - } - -} diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java deleted file mode 100644 index 296ced3db0..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ /dev/null @@ -1,209 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.server; - -import java.nio.ByteBuffer; - -import com.google.common.base.Preconditions; -import com.google.common.base.Throwables; -import io.netty.channel.Channel; -import io.netty.channel.ChannelFuture; -import io.netty.channel.ChannelFutureListener; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import org.apache.spark.network.buffer.ManagedBuffer; -import org.apache.spark.network.buffer.NioManagedBuffer; -import org.apache.spark.network.client.RpcResponseCallback; -import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.protocol.ChunkFetchRequest; -import org.apache.spark.network.protocol.ChunkFetchFailure; -import org.apache.spark.network.protocol.ChunkFetchSuccess; -import org.apache.spark.network.protocol.Encodable; -import org.apache.spark.network.protocol.OneWayMessage; -import org.apache.spark.network.protocol.RequestMessage; -import org.apache.spark.network.protocol.RpcFailure; -import org.apache.spark.network.protocol.RpcRequest; -import org.apache.spark.network.protocol.RpcResponse; -import org.apache.spark.network.protocol.StreamFailure; -import org.apache.spark.network.protocol.StreamRequest; -import org.apache.spark.network.protocol.StreamResponse; -import org.apache.spark.network.util.NettyUtils; - -/** - * A handler that processes requests from clients and writes chunk data back. Each handler is - * attached to a single Netty channel, and keeps track of which streams have been fetched via this - * channel, in order to clean them up if the channel is terminated (see #channelUnregistered). - * - * The messages should have been processed by the pipeline setup by {@link TransportServer}. - */ -public class TransportRequestHandler extends MessageHandler { - private final Logger logger = LoggerFactory.getLogger(TransportRequestHandler.class); - - /** The Netty channel that this handler is associated with. */ - private final Channel channel; - - /** Client on the same channel allowing us to talk back to the requester. */ - private final TransportClient reverseClient; - - /** Handles all RPC messages. */ - private final RpcHandler rpcHandler; - - /** Returns each chunk part of a stream. */ - private final StreamManager streamManager; - - public TransportRequestHandler( - Channel channel, - TransportClient reverseClient, - RpcHandler rpcHandler) { - this.channel = channel; - this.reverseClient = reverseClient; - this.rpcHandler = rpcHandler; - this.streamManager = rpcHandler.getStreamManager(); - } - - @Override - public void exceptionCaught(Throwable cause) { - rpcHandler.exceptionCaught(cause, reverseClient); - } - - @Override - public void channelActive() { - rpcHandler.channelActive(reverseClient); - } - - @Override - public void channelInactive() { - if (streamManager != null) { - try { - streamManager.connectionTerminated(channel); - } catch (RuntimeException e) { - logger.error("StreamManager connectionTerminated() callback failed.", e); - } - } - rpcHandler.channelInactive(reverseClient); - } - - @Override - public void handle(RequestMessage request) { - if (request instanceof ChunkFetchRequest) { - processFetchRequest((ChunkFetchRequest) request); - } else if (request instanceof RpcRequest) { - processRpcRequest((RpcRequest) request); - } else if (request instanceof OneWayMessage) { - processOneWayMessage((OneWayMessage) request); - } else if (request instanceof StreamRequest) { - processStreamRequest((StreamRequest) request); - } else { - throw new IllegalArgumentException("Unknown request type: " + request); - } - } - - private void processFetchRequest(final ChunkFetchRequest req) { - final String client = NettyUtils.getRemoteAddress(channel); - - logger.trace("Received req from {} to fetch block {}", client, req.streamChunkId); - - ManagedBuffer buf; - try { - streamManager.checkAuthorization(reverseClient, req.streamChunkId.streamId); - streamManager.registerChannel(channel, req.streamChunkId.streamId); - buf = streamManager.getChunk(req.streamChunkId.streamId, req.streamChunkId.chunkIndex); - } catch (Exception e) { - logger.error(String.format( - "Error opening block %s for request from %s", req.streamChunkId, client), e); - respond(new ChunkFetchFailure(req.streamChunkId, Throwables.getStackTraceAsString(e))); - return; - } - - respond(new ChunkFetchSuccess(req.streamChunkId, buf)); - } - - private void processStreamRequest(final StreamRequest req) { - final String client = NettyUtils.getRemoteAddress(channel); - ManagedBuffer buf; - try { - buf = streamManager.openStream(req.streamId); - } catch (Exception e) { - logger.error(String.format( - "Error opening stream %s for request from %s", req.streamId, client), e); - respond(new StreamFailure(req.streamId, Throwables.getStackTraceAsString(e))); - return; - } - - if (buf != null) { - respond(new StreamResponse(req.streamId, buf.size(), buf)); - } else { - respond(new StreamFailure(req.streamId, String.format( - "Stream '%s' was not found.", req.streamId))); - } - } - - private void processRpcRequest(final RpcRequest req) { - try { - rpcHandler.receive(reverseClient, req.body().nioByteBuffer(), new RpcResponseCallback() { - @Override - public void onSuccess(ByteBuffer response) { - respond(new RpcResponse(req.requestId, new NioManagedBuffer(response))); - } - - @Override - public void onFailure(Throwable e) { - respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e))); - } - }); - } catch (Exception e) { - logger.error("Error while invoking RpcHandler#receive() on RPC id " + req.requestId, e); - respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e))); - } finally { - req.body().release(); - } - } - - private void processOneWayMessage(OneWayMessage req) { - try { - rpcHandler.receive(reverseClient, req.body().nioByteBuffer()); - } catch (Exception e) { - logger.error("Error while invoking RpcHandler#receive() for one-way message.", e); - } finally { - req.body().release(); - } - } - - /** - * Responds to a single message with some Encodable object. If a failure occurs while sending, - * it will be logged and the channel closed. - */ - private void respond(final Encodable result) { - final String remoteAddress = channel.remoteAddress().toString(); - channel.writeAndFlush(result).addListener( - new ChannelFutureListener() { - @Override - public void operationComplete(ChannelFuture future) throws Exception { - if (future.isSuccess()) { - logger.trace(String.format("Sent result %s to client %s", result, remoteAddress)); - } else { - logger.error(String.format("Error sending result %s to %s; closing connection", - result, remoteAddress), future.cause()); - channel.close(); - } - } - } - ); - } -} diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java deleted file mode 100644 index baae235e02..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java +++ /dev/null @@ -1,151 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.server; - -import java.io.Closeable; -import java.net.InetSocketAddress; -import java.util.List; -import java.util.concurrent.TimeUnit; - -import com.google.common.base.Preconditions; -import com.google.common.collect.Lists; -import io.netty.bootstrap.ServerBootstrap; -import io.netty.buffer.PooledByteBufAllocator; -import io.netty.channel.ChannelFuture; -import io.netty.channel.ChannelInitializer; -import io.netty.channel.ChannelOption; -import io.netty.channel.EventLoopGroup; -import io.netty.channel.socket.SocketChannel; -import org.apache.spark.network.util.JavaUtils; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import org.apache.spark.network.TransportContext; -import org.apache.spark.network.util.IOMode; -import org.apache.spark.network.util.NettyUtils; -import org.apache.spark.network.util.TransportConf; - -/** - * Server for the efficient, low-level streaming service. - */ -public class TransportServer implements Closeable { - private final Logger logger = LoggerFactory.getLogger(TransportServer.class); - - private final TransportContext context; - private final TransportConf conf; - private final RpcHandler appRpcHandler; - private final List bootstraps; - - private ServerBootstrap bootstrap; - private ChannelFuture channelFuture; - private int port = -1; - - /** - * Creates a TransportServer that binds to the given host and the given port, or to any available - * if 0. If you don't want to bind to any special host, set "hostToBind" to null. - * */ - public TransportServer( - TransportContext context, - String hostToBind, - int portToBind, - RpcHandler appRpcHandler, - List bootstraps) { - this.context = context; - this.conf = context.getConf(); - this.appRpcHandler = appRpcHandler; - this.bootstraps = Lists.newArrayList(Preconditions.checkNotNull(bootstraps)); - - try { - init(hostToBind, portToBind); - } catch (RuntimeException e) { - JavaUtils.closeQuietly(this); - throw e; - } - } - - public int getPort() { - if (port == -1) { - throw new IllegalStateException("Server not initialized"); - } - return port; - } - - private void init(String hostToBind, int portToBind) { - - IOMode ioMode = IOMode.valueOf(conf.ioMode()); - EventLoopGroup bossGroup = - NettyUtils.createEventLoop(ioMode, conf.serverThreads(), "shuffle-server"); - EventLoopGroup workerGroup = bossGroup; - - PooledByteBufAllocator allocator = NettyUtils.createPooledByteBufAllocator( - conf.preferDirectBufs(), true /* allowCache */, conf.serverThreads()); - - bootstrap = new ServerBootstrap() - .group(bossGroup, workerGroup) - .channel(NettyUtils.getServerChannelClass(ioMode)) - .option(ChannelOption.ALLOCATOR, allocator) - .childOption(ChannelOption.ALLOCATOR, allocator); - - if (conf.backLog() > 0) { - bootstrap.option(ChannelOption.SO_BACKLOG, conf.backLog()); - } - - if (conf.receiveBuf() > 0) { - bootstrap.childOption(ChannelOption.SO_RCVBUF, conf.receiveBuf()); - } - - if (conf.sendBuf() > 0) { - bootstrap.childOption(ChannelOption.SO_SNDBUF, conf.sendBuf()); - } - - bootstrap.childHandler(new ChannelInitializer() { - @Override - protected void initChannel(SocketChannel ch) throws Exception { - RpcHandler rpcHandler = appRpcHandler; - for (TransportServerBootstrap bootstrap : bootstraps) { - rpcHandler = bootstrap.doBootstrap(ch, rpcHandler); - } - context.initializePipeline(ch, rpcHandler); - } - }); - - InetSocketAddress address = hostToBind == null ? - new InetSocketAddress(portToBind): new InetSocketAddress(hostToBind, portToBind); - channelFuture = bootstrap.bind(address); - channelFuture.syncUninterruptibly(); - - port = ((InetSocketAddress) channelFuture.channel().localAddress()).getPort(); - logger.debug("Shuffle server started on port :" + port); - } - - @Override - public void close() { - if (channelFuture != null) { - // close is a local operation and should finish within milliseconds; timeout just to be safe - channelFuture.channel().close().awaitUninterruptibly(10, TimeUnit.SECONDS); - channelFuture = null; - } - if (bootstrap != null && bootstrap.group() != null) { - bootstrap.group().shutdownGracefully(); - } - if (bootstrap != null && bootstrap.childGroup() != null) { - bootstrap.childGroup().shutdownGracefully(); - } - bootstrap = null; - } -} diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportServerBootstrap.java b/network/common/src/main/java/org/apache/spark/network/server/TransportServerBootstrap.java deleted file mode 100644 index 05803ab1bb..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportServerBootstrap.java +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.server; - -import io.netty.channel.Channel; - -/** - * A bootstrap which is executed on a TransportServer's client channel once a client connects - * to the server. This allows customizing the client channel to allow for things such as SASL - * authentication. - */ -public interface TransportServerBootstrap { - /** - * Customizes the channel to include new features, if needed. - * - * @param channel The connected channel opened by the client. - * @param rpcHandler The RPC handler for the server. - * @return The RPC handler to use for the channel. - */ - RpcHandler doBootstrap(Channel channel, RpcHandler rpcHandler); -} diff --git a/network/common/src/main/java/org/apache/spark/network/util/ByteArrayWritableChannel.java b/network/common/src/main/java/org/apache/spark/network/util/ByteArrayWritableChannel.java deleted file mode 100644 index b141572004..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/util/ByteArrayWritableChannel.java +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.util; - -import java.nio.ByteBuffer; -import java.nio.channels.WritableByteChannel; - -/** - * A writable channel that stores the written data in a byte array in memory. - */ -public class ByteArrayWritableChannel implements WritableByteChannel { - - private final byte[] data; - private int offset; - - public ByteArrayWritableChannel(int size) { - this.data = new byte[size]; - } - - public byte[] getData() { - return data; - } - - public int length() { - return offset; - } - - /** Resets the channel so that writing to it will overwrite the existing buffer. */ - public void reset() { - offset = 0; - } - - /** - * Reads from the given buffer into the internal byte array. - */ - @Override - public int write(ByteBuffer src) { - int toTransfer = Math.min(src.remaining(), data.length - offset); - src.get(data, offset, toTransfer); - offset += toTransfer; - return toTransfer; - } - - @Override - public void close() { - - } - - @Override - public boolean isOpen() { - return true; - } - -} diff --git a/network/common/src/main/java/org/apache/spark/network/util/ByteUnit.java b/network/common/src/main/java/org/apache/spark/network/util/ByteUnit.java deleted file mode 100644 index a2f018373f..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/util/ByteUnit.java +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.network.util; - -public enum ByteUnit { - BYTE (1), - KiB (1024L), - MiB ((long) Math.pow(1024L, 2L)), - GiB ((long) Math.pow(1024L, 3L)), - TiB ((long) Math.pow(1024L, 4L)), - PiB ((long) Math.pow(1024L, 5L)); - - private ByteUnit(long multiplier) { - this.multiplier = multiplier; - } - - // Interpret the provided number (d) with suffix (u) as this unit type. - // E.g. KiB.interpret(1, MiB) interprets 1MiB as its KiB representation = 1024k - public long convertFrom(long d, ByteUnit u) { - return u.convertTo(d, this); - } - - // Convert the provided number (d) interpreted as this unit type to unit type (u). - public long convertTo(long d, ByteUnit u) { - if (multiplier > u.multiplier) { - long ratio = multiplier / u.multiplier; - if (Long.MAX_VALUE / ratio < d) { - throw new IllegalArgumentException("Conversion of " + d + " exceeds Long.MAX_VALUE in " - + name() + ". Try a larger unit (e.g. MiB instead of KiB)"); - } - return d * ratio; - } else { - // Perform operations in this order to avoid potential overflow - // when computing d * multiplier - return d / (u.multiplier / multiplier); - } - } - - public double toBytes(long d) { - if (d < 0) { - throw new IllegalArgumentException("Negative size value. Size must be positive: " + d); - } - return d * multiplier; - } - - public long toKiB(long d) { return convertTo(d, KiB); } - public long toMiB(long d) { return convertTo(d, MiB); } - public long toGiB(long d) { return convertTo(d, GiB); } - public long toTiB(long d) { return convertTo(d, TiB); } - public long toPiB(long d) { return convertTo(d, PiB); } - - private final long multiplier; -} diff --git a/network/common/src/main/java/org/apache/spark/network/util/ConfigProvider.java b/network/common/src/main/java/org/apache/spark/network/util/ConfigProvider.java deleted file mode 100644 index d944d9da1c..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/util/ConfigProvider.java +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.util; - -import java.util.NoSuchElementException; - -/** - * Provides a mechanism for constructing a {@link TransportConf} using some sort of configuration. - */ -public abstract class ConfigProvider { - /** Obtains the value of the given config, throws NoSuchElementException if it doesn't exist. */ - public abstract String get(String name); - - public String get(String name, String defaultValue) { - try { - return get(name); - } catch (NoSuchElementException e) { - return defaultValue; - } - } - - public int getInt(String name, int defaultValue) { - return Integer.parseInt(get(name, Integer.toString(defaultValue))); - } - - public long getLong(String name, long defaultValue) { - return Long.parseLong(get(name, Long.toString(defaultValue))); - } - - public double getDouble(String name, double defaultValue) { - return Double.parseDouble(get(name, Double.toString(defaultValue))); - } - - public boolean getBoolean(String name, boolean defaultValue) { - return Boolean.parseBoolean(get(name, Boolean.toString(defaultValue))); - } -} diff --git a/network/common/src/main/java/org/apache/spark/network/util/IOMode.java b/network/common/src/main/java/org/apache/spark/network/util/IOMode.java deleted file mode 100644 index 6b208d95bb..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/util/IOMode.java +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.util; - -/** - * Selector for which form of low-level IO we should use. - * NIO is always available, while EPOLL is only available on Linux. - * AUTO is used to select EPOLL if it's available, or NIO otherwise. - */ -public enum IOMode { - NIO, EPOLL -} diff --git a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java deleted file mode 100644 index b3d8e0cd7c..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java +++ /dev/null @@ -1,303 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.util; - -import java.io.Closeable; -import java.io.File; -import java.io.IOException; -import java.nio.ByteBuffer; -import java.util.concurrent.TimeUnit; -import java.util.regex.Matcher; -import java.util.regex.Pattern; - -import com.google.common.base.Charsets; -import com.google.common.base.Preconditions; -import com.google.common.collect.ImmutableMap; -import io.netty.buffer.Unpooled; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * General utilities available in the network package. Many of these are sourced from Spark's - * own Utils, just accessible within this package. - */ -public class JavaUtils { - private static final Logger logger = LoggerFactory.getLogger(JavaUtils.class); - - /** - * Define a default value for driver memory here since this value is referenced across the code - * base and nearly all files already use Utils.scala - */ - public static final long DEFAULT_DRIVER_MEM_MB = 1024; - - /** Closes the given object, ignoring IOExceptions. */ - public static void closeQuietly(Closeable closeable) { - try { - if (closeable != null) { - closeable.close(); - } - } catch (IOException e) { - logger.error("IOException should not have been thrown.", e); - } - } - - /** Returns a hash consistent with Spark's Utils.nonNegativeHash(). */ - public static int nonNegativeHash(Object obj) { - if (obj == null) { return 0; } - int hash = obj.hashCode(); - return hash != Integer.MIN_VALUE ? Math.abs(hash) : 0; - } - - /** - * Convert the given string to a byte buffer. The resulting buffer can be - * converted back to the same string through {@link #bytesToString(ByteBuffer)}. - */ - public static ByteBuffer stringToBytes(String s) { - return Unpooled.wrappedBuffer(s.getBytes(Charsets.UTF_8)).nioBuffer(); - } - - /** - * Convert the given byte buffer to a string. The resulting string can be - * converted back to the same byte buffer through {@link #stringToBytes(String)}. - */ - public static String bytesToString(ByteBuffer b) { - return Unpooled.wrappedBuffer(b).toString(Charsets.UTF_8); - } - - /* - * Delete a file or directory and its contents recursively. - * Don't follow directories if they are symlinks. - * Throws an exception if deletion is unsuccessful. - */ - public static void deleteRecursively(File file) throws IOException { - if (file == null) { return; } - - if (file.isDirectory() && !isSymlink(file)) { - IOException savedIOException = null; - for (File child : listFilesSafely(file)) { - try { - deleteRecursively(child); - } catch (IOException e) { - // In case of multiple exceptions, only last one will be thrown - savedIOException = e; - } - } - if (savedIOException != null) { - throw savedIOException; - } - } - - boolean deleted = file.delete(); - // Delete can also fail if the file simply did not exist. - if (!deleted && file.exists()) { - throw new IOException("Failed to delete: " + file.getAbsolutePath()); - } - } - - private static File[] listFilesSafely(File file) throws IOException { - if (file.exists()) { - File[] files = file.listFiles(); - if (files == null) { - throw new IOException("Failed to list files for dir: " + file); - } - return files; - } else { - return new File[0]; - } - } - - private static boolean isSymlink(File file) throws IOException { - Preconditions.checkNotNull(file); - File fileInCanonicalDir = null; - if (file.getParent() == null) { - fileInCanonicalDir = file; - } else { - fileInCanonicalDir = new File(file.getParentFile().getCanonicalFile(), file.getName()); - } - return !fileInCanonicalDir.getCanonicalFile().equals(fileInCanonicalDir.getAbsoluteFile()); - } - - private static final ImmutableMap timeSuffixes = - ImmutableMap.builder() - .put("us", TimeUnit.MICROSECONDS) - .put("ms", TimeUnit.MILLISECONDS) - .put("s", TimeUnit.SECONDS) - .put("m", TimeUnit.MINUTES) - .put("min", TimeUnit.MINUTES) - .put("h", TimeUnit.HOURS) - .put("d", TimeUnit.DAYS) - .build(); - - private static final ImmutableMap byteSuffixes = - ImmutableMap.builder() - .put("b", ByteUnit.BYTE) - .put("k", ByteUnit.KiB) - .put("kb", ByteUnit.KiB) - .put("m", ByteUnit.MiB) - .put("mb", ByteUnit.MiB) - .put("g", ByteUnit.GiB) - .put("gb", ByteUnit.GiB) - .put("t", ByteUnit.TiB) - .put("tb", ByteUnit.TiB) - .put("p", ByteUnit.PiB) - .put("pb", ByteUnit.PiB) - .build(); - - /** - * Convert a passed time string (e.g. 50s, 100ms, or 250us) to a time count for - * internal use. If no suffix is provided a direct conversion is attempted. - */ - private static long parseTimeString(String str, TimeUnit unit) { - String lower = str.toLowerCase().trim(); - - try { - Matcher m = Pattern.compile("(-?[0-9]+)([a-z]+)?").matcher(lower); - if (!m.matches()) { - throw new NumberFormatException("Failed to parse time string: " + str); - } - - long val = Long.parseLong(m.group(1)); - String suffix = m.group(2); - - // Check for invalid suffixes - if (suffix != null && !timeSuffixes.containsKey(suffix)) { - throw new NumberFormatException("Invalid suffix: \"" + suffix + "\""); - } - - // If suffix is valid use that, otherwise none was provided and use the default passed - return unit.convert(val, suffix != null ? timeSuffixes.get(suffix) : unit); - } catch (NumberFormatException e) { - String timeError = "Time must be specified as seconds (s), " + - "milliseconds (ms), microseconds (us), minutes (m or min), hour (h), or day (d). " + - "E.g. 50s, 100ms, or 250us."; - - throw new NumberFormatException(timeError + "\n" + e.getMessage()); - } - } - - /** - * Convert a time parameter such as (50s, 100ms, or 250us) to milliseconds for internal use. If - * no suffix is provided, the passed number is assumed to be in ms. - */ - public static long timeStringAsMs(String str) { - return parseTimeString(str, TimeUnit.MILLISECONDS); - } - - /** - * Convert a time parameter such as (50s, 100ms, or 250us) to seconds for internal use. If - * no suffix is provided, the passed number is assumed to be in seconds. - */ - public static long timeStringAsSec(String str) { - return parseTimeString(str, TimeUnit.SECONDS); - } - - /** - * Convert a passed byte string (e.g. 50b, 100kb, or 250mb) to a ByteUnit for - * internal use. If no suffix is provided a direct conversion of the provided default is - * attempted. - */ - private static long parseByteString(String str, ByteUnit unit) { - String lower = str.toLowerCase().trim(); - - try { - Matcher m = Pattern.compile("([0-9]+)([a-z]+)?").matcher(lower); - Matcher fractionMatcher = Pattern.compile("([0-9]+\\.[0-9]+)([a-z]+)?").matcher(lower); - - if (m.matches()) { - long val = Long.parseLong(m.group(1)); - String suffix = m.group(2); - - // Check for invalid suffixes - if (suffix != null && !byteSuffixes.containsKey(suffix)) { - throw new NumberFormatException("Invalid suffix: \"" + suffix + "\""); - } - - // If suffix is valid use that, otherwise none was provided and use the default passed - return unit.convertFrom(val, suffix != null ? byteSuffixes.get(suffix) : unit); - } else if (fractionMatcher.matches()) { - throw new NumberFormatException("Fractional values are not supported. Input was: " - + fractionMatcher.group(1)); - } else { - throw new NumberFormatException("Failed to parse byte string: " + str); - } - - } catch (NumberFormatException e) { - String timeError = "Size must be specified as bytes (b), " + - "kibibytes (k), mebibytes (m), gibibytes (g), tebibytes (t), or pebibytes(p). " + - "E.g. 50b, 100k, or 250m."; - - throw new NumberFormatException(timeError + "\n" + e.getMessage()); - } - } - - /** - * Convert a passed byte string (e.g. 50b, 100k, or 250m) to bytes for - * internal use. - * - * If no suffix is provided, the passed number is assumed to be in bytes. - */ - public static long byteStringAsBytes(String str) { - return parseByteString(str, ByteUnit.BYTE); - } - - /** - * Convert a passed byte string (e.g. 50b, 100k, or 250m) to kibibytes for - * internal use. - * - * If no suffix is provided, the passed number is assumed to be in kibibytes. - */ - public static long byteStringAsKb(String str) { - return parseByteString(str, ByteUnit.KiB); - } - - /** - * Convert a passed byte string (e.g. 50b, 100k, or 250m) to mebibytes for - * internal use. - * - * If no suffix is provided, the passed number is assumed to be in mebibytes. - */ - public static long byteStringAsMb(String str) { - return parseByteString(str, ByteUnit.MiB); - } - - /** - * Convert a passed byte string (e.g. 50b, 100k, or 250m) to gibibytes for - * internal use. - * - * If no suffix is provided, the passed number is assumed to be in gibibytes. - */ - public static long byteStringAsGb(String str) { - return parseByteString(str, ByteUnit.GiB); - } - - /** - * Returns a byte array with the buffer's contents, trying to avoid copying the data if - * possible. - */ - public static byte[] bufferToArray(ByteBuffer buffer) { - if (buffer.hasArray() && buffer.arrayOffset() == 0 && - buffer.array().length == buffer.remaining()) { - return buffer.array(); - } else { - byte[] bytes = new byte[buffer.remaining()]; - buffer.get(bytes); - return bytes; - } - } - -} diff --git a/network/common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java b/network/common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java deleted file mode 100644 index 922c37a10e..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java +++ /dev/null @@ -1,105 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * Based on LimitedInputStream.java from Google Guava - * - * Copyright (C) 2007 The Guava Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.util; - -import java.io.FilterInputStream; -import java.io.IOException; -import java.io.InputStream; - -import com.google.common.base.Preconditions; - -/** - * Wraps a {@link InputStream}, limiting the number of bytes which can be read. - * - * This code is from Guava's 14.0 source code, because there is no compatible way to - * use this functionality in both a Guava 11 environment and a Guava >14 environment. - */ -public final class LimitedInputStream extends FilterInputStream { - private long left; - private long mark = -1; - - public LimitedInputStream(InputStream in, long limit) { - super(in); - Preconditions.checkNotNull(in); - Preconditions.checkArgument(limit >= 0, "limit must be non-negative"); - left = limit; - } - @Override public int available() throws IOException { - return (int) Math.min(in.available(), left); - } - // it's okay to mark even if mark isn't supported, as reset won't work - @Override public synchronized void mark(int readLimit) { - in.mark(readLimit); - mark = left; - } - @Override public int read() throws IOException { - if (left == 0) { - return -1; - } - int result = in.read(); - if (result != -1) { - --left; - } - return result; - } - @Override public int read(byte[] b, int off, int len) throws IOException { - if (left == 0) { - return -1; - } - len = (int) Math.min(len, left); - int result = in.read(b, off, len); - if (result != -1) { - left -= result; - } - return result; - } - @Override public synchronized void reset() throws IOException { - if (!in.markSupported()) { - throw new IOException("Mark not supported"); - } - if (mark == -1) { - throw new IOException("Mark not set"); - } - in.reset(); - left = mark; - } - @Override public long skip(long n) throws IOException { - n = Math.min(n, left); - long skipped = in.skip(n); - left -= skipped; - return skipped; - } -} diff --git a/network/common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java b/network/common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java deleted file mode 100644 index 668d2356b9..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.util; - -import com.google.common.collect.Maps; - -import java.util.Map; -import java.util.NoSuchElementException; - -/** ConfigProvider based on a Map (copied in the constructor). */ -public class MapConfigProvider extends ConfigProvider { - private final Map config; - - public MapConfigProvider(Map config) { - this.config = Maps.newHashMap(config); - } - - @Override - public String get(String name) { - String value = config.get(name); - if (value == null) { - throw new NoSuchElementException(name); - } - return value; - } -} diff --git a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java deleted file mode 100644 index caa7260bc8..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java +++ /dev/null @@ -1,139 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.util; - -import java.lang.reflect.Field; -import java.util.concurrent.ThreadFactory; - -import com.google.common.util.concurrent.ThreadFactoryBuilder; -import io.netty.buffer.PooledByteBufAllocator; -import io.netty.channel.Channel; -import io.netty.channel.EventLoopGroup; -import io.netty.channel.ServerChannel; -import io.netty.channel.epoll.EpollEventLoopGroup; -import io.netty.channel.epoll.EpollServerSocketChannel; -import io.netty.channel.epoll.EpollSocketChannel; -import io.netty.channel.nio.NioEventLoopGroup; -import io.netty.channel.socket.nio.NioServerSocketChannel; -import io.netty.channel.socket.nio.NioSocketChannel; -import io.netty.handler.codec.ByteToMessageDecoder; -import io.netty.handler.codec.LengthFieldBasedFrameDecoder; -import io.netty.util.internal.PlatformDependent; - -/** - * Utilities for creating various Netty constructs based on whether we're using EPOLL or NIO. - */ -public class NettyUtils { - /** Creates a new ThreadFactory which prefixes each thread with the given name. */ - public static ThreadFactory createThreadFactory(String threadPoolPrefix) { - return new ThreadFactoryBuilder() - .setDaemon(true) - .setNameFormat(threadPoolPrefix + "-%d") - .build(); - } - - /** Creates a Netty EventLoopGroup based on the IOMode. */ - public static EventLoopGroup createEventLoop(IOMode mode, int numThreads, String threadPrefix) { - ThreadFactory threadFactory = createThreadFactory(threadPrefix); - - switch (mode) { - case NIO: - return new NioEventLoopGroup(numThreads, threadFactory); - case EPOLL: - return new EpollEventLoopGroup(numThreads, threadFactory); - default: - throw new IllegalArgumentException("Unknown io mode: " + mode); - } - } - - /** Returns the correct (client) SocketChannel class based on IOMode. */ - public static Class getClientChannelClass(IOMode mode) { - switch (mode) { - case NIO: - return NioSocketChannel.class; - case EPOLL: - return EpollSocketChannel.class; - default: - throw new IllegalArgumentException("Unknown io mode: " + mode); - } - } - - /** Returns the correct ServerSocketChannel class based on IOMode. */ - public static Class getServerChannelClass(IOMode mode) { - switch (mode) { - case NIO: - return NioServerSocketChannel.class; - case EPOLL: - return EpollServerSocketChannel.class; - default: - throw new IllegalArgumentException("Unknown io mode: " + mode); - } - } - - /** - * Creates a LengthFieldBasedFrameDecoder where the first 8 bytes are the length of the frame. - * This is used before all decoders. - */ - public static TransportFrameDecoder createFrameDecoder() { - return new TransportFrameDecoder(); - } - - /** Returns the remote address on the channel or "<unknown remote>" if none exists. */ - public static String getRemoteAddress(Channel channel) { - if (channel != null && channel.remoteAddress() != null) { - return channel.remoteAddress().toString(); - } - return ""; - } - - /** - * Create a pooled ByteBuf allocator but disables the thread-local cache. Thread-local caches - * are disabled for TransportClients because the ByteBufs are allocated by the event loop thread, - * but released by the executor thread rather than the event loop thread. Those thread-local - * caches actually delay the recycling of buffers, leading to larger memory usage. - */ - public static PooledByteBufAllocator createPooledByteBufAllocator( - boolean allowDirectBufs, - boolean allowCache, - int numCores) { - if (numCores == 0) { - numCores = Runtime.getRuntime().availableProcessors(); - } - return new PooledByteBufAllocator( - allowDirectBufs && PlatformDependent.directBufferPreferred(), - Math.min(getPrivateStaticField("DEFAULT_NUM_HEAP_ARENA"), numCores), - Math.min(getPrivateStaticField("DEFAULT_NUM_DIRECT_ARENA"), allowDirectBufs ? numCores : 0), - getPrivateStaticField("DEFAULT_PAGE_SIZE"), - getPrivateStaticField("DEFAULT_MAX_ORDER"), - allowCache ? getPrivateStaticField("DEFAULT_TINY_CACHE_SIZE") : 0, - allowCache ? getPrivateStaticField("DEFAULT_SMALL_CACHE_SIZE") : 0, - allowCache ? getPrivateStaticField("DEFAULT_NORMAL_CACHE_SIZE") : 0 - ); - } - - /** Used to get defaults from Netty's private static fields. */ - private static int getPrivateStaticField(String name) { - try { - Field f = PooledByteBufAllocator.DEFAULT.getClass().getDeclaredField(name); - f.setAccessible(true); - return f.getInt(null); - } catch (Exception e) { - throw new RuntimeException(e); - } - } -} diff --git a/network/common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java b/network/common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java deleted file mode 100644 index 5f20b70678..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.util; - -import java.util.NoSuchElementException; - -import org.apache.spark.network.util.ConfigProvider; - -/** Uses System properties to obtain config values. */ -public class SystemPropertyConfigProvider extends ConfigProvider { - @Override - public String get(String name) { - String value = System.getProperty(name); - if (value == null) { - throw new NoSuchElementException(name); - } - return value; - } -} diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java deleted file mode 100644 index 9f030da2b3..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ /dev/null @@ -1,169 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.util; - -import com.google.common.primitives.Ints; - -/** - * A central location that tracks all the settings we expose to users. - */ -public class TransportConf { - - private final String SPARK_NETWORK_IO_MODE_KEY; - private final String SPARK_NETWORK_IO_PREFERDIRECTBUFS_KEY; - private final String SPARK_NETWORK_IO_CONNECTIONTIMEOUT_KEY; - private final String SPARK_NETWORK_IO_BACKLOG_KEY; - private final String SPARK_NETWORK_IO_NUMCONNECTIONSPERPEER_KEY; - private final String SPARK_NETWORK_IO_SERVERTHREADS_KEY; - private final String SPARK_NETWORK_IO_CLIENTTHREADS_KEY; - private final String SPARK_NETWORK_IO_RECEIVEBUFFER_KEY; - private final String SPARK_NETWORK_IO_SENDBUFFER_KEY; - private final String SPARK_NETWORK_SASL_TIMEOUT_KEY; - private final String SPARK_NETWORK_IO_MAXRETRIES_KEY; - private final String SPARK_NETWORK_IO_RETRYWAIT_KEY; - private final String SPARK_NETWORK_IO_LAZYFD_KEY; - - private final ConfigProvider conf; - - private final String module; - - public TransportConf(String module, ConfigProvider conf) { - this.module = module; - this.conf = conf; - SPARK_NETWORK_IO_MODE_KEY = getConfKey("io.mode"); - SPARK_NETWORK_IO_PREFERDIRECTBUFS_KEY = getConfKey("io.preferDirectBufs"); - SPARK_NETWORK_IO_CONNECTIONTIMEOUT_KEY = getConfKey("io.connectionTimeout"); - SPARK_NETWORK_IO_BACKLOG_KEY = getConfKey("io.backLog"); - SPARK_NETWORK_IO_NUMCONNECTIONSPERPEER_KEY = getConfKey("io.numConnectionsPerPeer"); - SPARK_NETWORK_IO_SERVERTHREADS_KEY = getConfKey("io.serverThreads"); - SPARK_NETWORK_IO_CLIENTTHREADS_KEY = getConfKey("io.clientThreads"); - SPARK_NETWORK_IO_RECEIVEBUFFER_KEY = getConfKey("io.receiveBuffer"); - SPARK_NETWORK_IO_SENDBUFFER_KEY = getConfKey("io.sendBuffer"); - SPARK_NETWORK_SASL_TIMEOUT_KEY = getConfKey("sasl.timeout"); - SPARK_NETWORK_IO_MAXRETRIES_KEY = getConfKey("io.maxRetries"); - SPARK_NETWORK_IO_RETRYWAIT_KEY = getConfKey("io.retryWait"); - SPARK_NETWORK_IO_LAZYFD_KEY = getConfKey("io.lazyFD"); - } - - private String getConfKey(String suffix) { - return "spark." + module + "." + suffix; - } - - /** IO mode: nio or epoll */ - public String ioMode() { return conf.get(SPARK_NETWORK_IO_MODE_KEY, "NIO").toUpperCase(); } - - /** If true, we will prefer allocating off-heap byte buffers within Netty. */ - public boolean preferDirectBufs() { - return conf.getBoolean(SPARK_NETWORK_IO_PREFERDIRECTBUFS_KEY, true); - } - - /** Connect timeout in milliseconds. Default 120 secs. */ - public int connectionTimeoutMs() { - long defaultNetworkTimeoutS = JavaUtils.timeStringAsSec( - conf.get("spark.network.timeout", "120s")); - long defaultTimeoutMs = JavaUtils.timeStringAsSec( - conf.get(SPARK_NETWORK_IO_CONNECTIONTIMEOUT_KEY, defaultNetworkTimeoutS + "s")) * 1000; - return (int) defaultTimeoutMs; - } - - /** Number of concurrent connections between two nodes for fetching data. */ - public int numConnectionsPerPeer() { - return conf.getInt(SPARK_NETWORK_IO_NUMCONNECTIONSPERPEER_KEY, 1); - } - - /** Requested maximum length of the queue of incoming connections. Default -1 for no backlog. */ - public int backLog() { return conf.getInt(SPARK_NETWORK_IO_BACKLOG_KEY, -1); } - - /** Number of threads used in the server thread pool. Default to 0, which is 2x#cores. */ - public int serverThreads() { return conf.getInt(SPARK_NETWORK_IO_SERVERTHREADS_KEY, 0); } - - /** Number of threads used in the client thread pool. Default to 0, which is 2x#cores. */ - public int clientThreads() { return conf.getInt(SPARK_NETWORK_IO_CLIENTTHREADS_KEY, 0); } - - /** - * Receive buffer size (SO_RCVBUF). - * Note: the optimal size for receive buffer and send buffer should be - * latency * network_bandwidth. - * Assuming latency = 1ms, network_bandwidth = 10Gbps - * buffer size should be ~ 1.25MB - */ - public int receiveBuf() { return conf.getInt(SPARK_NETWORK_IO_RECEIVEBUFFER_KEY, -1); } - - /** Send buffer size (SO_SNDBUF). */ - public int sendBuf() { return conf.getInt(SPARK_NETWORK_IO_SENDBUFFER_KEY, -1); } - - /** Timeout for a single round trip of SASL token exchange, in milliseconds. */ - public int saslRTTimeoutMs() { - return (int) JavaUtils.timeStringAsSec(conf.get(SPARK_NETWORK_SASL_TIMEOUT_KEY, "30s")) * 1000; - } - - /** - * Max number of times we will try IO exceptions (such as connection timeouts) per request. - * If set to 0, we will not do any retries. - */ - public int maxIORetries() { return conf.getInt(SPARK_NETWORK_IO_MAXRETRIES_KEY, 3); } - - /** - * Time (in milliseconds) that we will wait in order to perform a retry after an IOException. - * Only relevant if maxIORetries > 0. - */ - public int ioRetryWaitTimeMs() { - return (int) JavaUtils.timeStringAsSec(conf.get(SPARK_NETWORK_IO_RETRYWAIT_KEY, "5s")) * 1000; - } - - /** - * Minimum size of a block that we should start using memory map rather than reading in through - * normal IO operations. This prevents Spark from memory mapping very small blocks. In general, - * memory mapping has high overhead for blocks close to or below the page size of the OS. - */ - public int memoryMapBytes() { - return Ints.checkedCast(JavaUtils.byteStringAsBytes( - conf.get("spark.storage.memoryMapThreshold", "2m"))); - } - - /** - * Whether to initialize FileDescriptor lazily or not. If true, file descriptors are - * created only when data is going to be transferred. This can reduce the number of open files. - */ - public boolean lazyFileDescriptor() { - return conf.getBoolean(SPARK_NETWORK_IO_LAZYFD_KEY, true); - } - - /** - * Maximum number of retries when binding to a port before giving up. - */ - public int portMaxRetries() { - return conf.getInt("spark.port.maxRetries", 16); - } - - /** - * Maximum number of bytes to be encrypted at a time when SASL encryption is enabled. - */ - public int maxSaslEncryptedBlockSize() { - return Ints.checkedCast(JavaUtils.byteStringAsBytes( - conf.get("spark.network.sasl.maxEncryptedBlockSize", "64k"))); - } - - /** - * Whether the server should enforce encryption on SASL-authenticated connections. - */ - public boolean saslServerAlwaysEncrypt() { - return conf.getBoolean("spark.network.sasl.serverAlwaysEncrypt", false); - } - -} diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java b/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java deleted file mode 100644 index a466c72915..0000000000 --- a/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java +++ /dev/null @@ -1,227 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.util; - -import java.util.Iterator; -import java.util.LinkedList; - -import com.google.common.base.Preconditions; -import io.netty.buffer.ByteBuf; -import io.netty.buffer.CompositeByteBuf; -import io.netty.buffer.Unpooled; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelInboundHandlerAdapter; - -/** - * A customized frame decoder that allows intercepting raw data. - *

- * This behaves like Netty's frame decoder (with harcoded parameters that match this library's - * needs), except it allows an interceptor to be installed to read data directly before it's - * framed. - *

- * Unlike Netty's frame decoder, each frame is dispatched to child handlers as soon as it's - * decoded, instead of building as many frames as the current buffer allows and dispatching - * all of them. This allows a child handler to install an interceptor if needed. - *

- * If an interceptor is installed, framing stops, and data is instead fed directly to the - * interceptor. When the interceptor indicates that it doesn't need to read any more data, - * framing resumes. Interceptors should not hold references to the data buffers provided - * to their handle() method. - */ -public class TransportFrameDecoder extends ChannelInboundHandlerAdapter { - - public static final String HANDLER_NAME = "frameDecoder"; - private static final int LENGTH_SIZE = 8; - private static final int MAX_FRAME_SIZE = Integer.MAX_VALUE; - private static final int UNKNOWN_FRAME_SIZE = -1; - - private final LinkedList buffers = new LinkedList<>(); - private final ByteBuf frameLenBuf = Unpooled.buffer(LENGTH_SIZE, LENGTH_SIZE); - - private long totalSize = 0; - private long nextFrameSize = UNKNOWN_FRAME_SIZE; - private volatile Interceptor interceptor; - - @Override - public void channelRead(ChannelHandlerContext ctx, Object data) throws Exception { - ByteBuf in = (ByteBuf) data; - buffers.add(in); - totalSize += in.readableBytes(); - - while (!buffers.isEmpty()) { - // First, feed the interceptor, and if it's still, active, try again. - if (interceptor != null) { - ByteBuf first = buffers.getFirst(); - int available = first.readableBytes(); - if (feedInterceptor(first)) { - assert !first.isReadable() : "Interceptor still active but buffer has data."; - } - - int read = available - first.readableBytes(); - if (read == available) { - buffers.removeFirst().release(); - } - totalSize -= read; - } else { - // Interceptor is not active, so try to decode one frame. - ByteBuf frame = decodeNext(); - if (frame == null) { - break; - } - ctx.fireChannelRead(frame); - } - } - } - - private long decodeFrameSize() { - if (nextFrameSize != UNKNOWN_FRAME_SIZE || totalSize < LENGTH_SIZE) { - return nextFrameSize; - } - - // We know there's enough data. If the first buffer contains all the data, great. Otherwise, - // hold the bytes for the frame length in a composite buffer until we have enough data to read - // the frame size. Normally, it should be rare to need more than one buffer to read the frame - // size. - ByteBuf first = buffers.getFirst(); - if (first.readableBytes() >= LENGTH_SIZE) { - nextFrameSize = first.readLong() - LENGTH_SIZE; - totalSize -= LENGTH_SIZE; - if (!first.isReadable()) { - buffers.removeFirst().release(); - } - return nextFrameSize; - } - - while (frameLenBuf.readableBytes() < LENGTH_SIZE) { - ByteBuf next = buffers.getFirst(); - int toRead = Math.min(next.readableBytes(), LENGTH_SIZE - frameLenBuf.readableBytes()); - frameLenBuf.writeBytes(next, toRead); - if (!next.isReadable()) { - buffers.removeFirst().release(); - } - } - - nextFrameSize = frameLenBuf.readLong() - LENGTH_SIZE; - totalSize -= LENGTH_SIZE; - frameLenBuf.clear(); - return nextFrameSize; - } - - private ByteBuf decodeNext() throws Exception { - long frameSize = decodeFrameSize(); - if (frameSize == UNKNOWN_FRAME_SIZE || totalSize < frameSize) { - return null; - } - - // Reset size for next frame. - nextFrameSize = UNKNOWN_FRAME_SIZE; - - Preconditions.checkArgument(frameSize < MAX_FRAME_SIZE, "Too large frame: %s", frameSize); - Preconditions.checkArgument(frameSize > 0, "Frame length should be positive: %s", frameSize); - - // If the first buffer holds the entire frame, return it. - int remaining = (int) frameSize; - if (buffers.getFirst().readableBytes() >= remaining) { - return nextBufferForFrame(remaining); - } - - // Otherwise, create a composite buffer. - CompositeByteBuf frame = buffers.getFirst().alloc().compositeBuffer(); - while (remaining > 0) { - ByteBuf next = nextBufferForFrame(remaining); - remaining -= next.readableBytes(); - frame.addComponent(next).writerIndex(frame.writerIndex() + next.readableBytes()); - } - assert remaining == 0; - return frame; - } - - /** - * Takes the first buffer in the internal list, and either adjust it to fit in the frame - * (by taking a slice out of it) or remove it from the internal list. - */ - private ByteBuf nextBufferForFrame(int bytesToRead) { - ByteBuf buf = buffers.getFirst(); - ByteBuf frame; - - if (buf.readableBytes() > bytesToRead) { - frame = buf.retain().readSlice(bytesToRead); - totalSize -= bytesToRead; - } else { - frame = buf; - buffers.removeFirst(); - totalSize -= frame.readableBytes(); - } - - return frame; - } - - @Override - public void channelInactive(ChannelHandlerContext ctx) throws Exception { - for (ByteBuf b : buffers) { - b.release(); - } - if (interceptor != null) { - interceptor.channelInactive(); - } - frameLenBuf.release(); - super.channelInactive(ctx); - } - - @Override - public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { - if (interceptor != null) { - interceptor.exceptionCaught(cause); - } - super.exceptionCaught(ctx, cause); - } - - public void setInterceptor(Interceptor interceptor) { - Preconditions.checkState(this.interceptor == null, "Already have an interceptor."); - this.interceptor = interceptor; - } - - /** - * @return Whether the interceptor is still active after processing the data. - */ - private boolean feedInterceptor(ByteBuf buf) throws Exception { - if (interceptor != null && !interceptor.handle(buf)) { - interceptor = null; - } - return interceptor != null; - } - - public static interface Interceptor { - - /** - * Handles data received from the remote end. - * - * @param data Buffer containing data. - * @return "true" if the interceptor expects more data, "false" to uninstall the interceptor. - */ - boolean handle(ByteBuf data) throws Exception; - - /** Called if an exception is thrown in the channel pipeline. */ - void exceptionCaught(Throwable cause) throws Exception; - - /** Called if the channel is closed and the interceptor is still installed. */ - void channelInactive() throws Exception; - - } - -} diff --git a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java deleted file mode 100644 index 70c849d60e..0000000000 --- a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java +++ /dev/null @@ -1,244 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network; - -import java.io.File; -import java.io.RandomAccessFile; -import java.nio.ByteBuffer; -import java.util.Collections; -import java.util.HashSet; -import java.util.LinkedList; -import java.util.List; -import java.util.Random; -import java.util.Set; -import java.util.concurrent.Semaphore; -import java.util.concurrent.TimeUnit; - -import com.google.common.collect.Lists; -import com.google.common.collect.Sets; -import com.google.common.io.Closeables; -import org.junit.AfterClass; -import org.junit.BeforeClass; -import org.junit.Test; - -import static org.junit.Assert.*; - -import org.apache.spark.network.buffer.FileSegmentManagedBuffer; -import org.apache.spark.network.buffer.ManagedBuffer; -import org.apache.spark.network.buffer.NioManagedBuffer; -import org.apache.spark.network.client.ChunkReceivedCallback; -import org.apache.spark.network.client.RpcResponseCallback; -import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.client.TransportClientFactory; -import org.apache.spark.network.server.RpcHandler; -import org.apache.spark.network.server.TransportServer; -import org.apache.spark.network.server.StreamManager; -import org.apache.spark.network.util.SystemPropertyConfigProvider; -import org.apache.spark.network.util.TransportConf; - -public class ChunkFetchIntegrationSuite { - static final long STREAM_ID = 1; - static final int BUFFER_CHUNK_INDEX = 0; - static final int FILE_CHUNK_INDEX = 1; - - static TransportServer server; - static TransportClientFactory clientFactory; - static StreamManager streamManager; - static File testFile; - - static ManagedBuffer bufferChunk; - static ManagedBuffer fileChunk; - - private TransportConf transportConf; - - @BeforeClass - public static void setUp() throws Exception { - int bufSize = 100000; - final ByteBuffer buf = ByteBuffer.allocate(bufSize); - for (int i = 0; i < bufSize; i ++) { - buf.put((byte) i); - } - buf.flip(); - bufferChunk = new NioManagedBuffer(buf); - - testFile = File.createTempFile("shuffle-test-file", "txt"); - testFile.deleteOnExit(); - RandomAccessFile fp = new RandomAccessFile(testFile, "rw"); - boolean shouldSuppressIOException = true; - try { - byte[] fileContent = new byte[1024]; - new Random().nextBytes(fileContent); - fp.write(fileContent); - shouldSuppressIOException = false; - } finally { - Closeables.close(fp, shouldSuppressIOException); - } - - final TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); - fileChunk = new FileSegmentManagedBuffer(conf, testFile, 10, testFile.length() - 25); - - streamManager = new StreamManager() { - @Override - public ManagedBuffer getChunk(long streamId, int chunkIndex) { - assertEquals(STREAM_ID, streamId); - if (chunkIndex == BUFFER_CHUNK_INDEX) { - return new NioManagedBuffer(buf); - } else if (chunkIndex == FILE_CHUNK_INDEX) { - return new FileSegmentManagedBuffer(conf, testFile, 10, testFile.length() - 25); - } else { - throw new IllegalArgumentException("Invalid chunk index: " + chunkIndex); - } - } - }; - RpcHandler handler = new RpcHandler() { - @Override - public void receive( - TransportClient client, - ByteBuffer message, - RpcResponseCallback callback) { - throw new UnsupportedOperationException(); - } - - @Override - public StreamManager getStreamManager() { - return streamManager; - } - }; - TransportContext context = new TransportContext(conf, handler); - server = context.createServer(); - clientFactory = context.createClientFactory(); - } - - @AfterClass - public static void tearDown() { - bufferChunk.release(); - server.close(); - clientFactory.close(); - testFile.delete(); - } - - class FetchResult { - public Set successChunks; - public Set failedChunks; - public List buffers; - - public void releaseBuffers() { - for (ManagedBuffer buffer : buffers) { - buffer.release(); - } - } - } - - private FetchResult fetchChunks(List chunkIndices) throws Exception { - TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); - final Semaphore sem = new Semaphore(0); - - final FetchResult res = new FetchResult(); - res.successChunks = Collections.synchronizedSet(new HashSet()); - res.failedChunks = Collections.synchronizedSet(new HashSet()); - res.buffers = Collections.synchronizedList(new LinkedList()); - - ChunkReceivedCallback callback = new ChunkReceivedCallback() { - @Override - public void onSuccess(int chunkIndex, ManagedBuffer buffer) { - buffer.retain(); - res.successChunks.add(chunkIndex); - res.buffers.add(buffer); - sem.release(); - } - - @Override - public void onFailure(int chunkIndex, Throwable e) { - res.failedChunks.add(chunkIndex); - sem.release(); - } - }; - - for (int chunkIndex : chunkIndices) { - client.fetchChunk(STREAM_ID, chunkIndex, callback); - } - if (!sem.tryAcquire(chunkIndices.size(), 5, TimeUnit.SECONDS)) { - fail("Timeout getting response from the server"); - } - client.close(); - return res; - } - - @Test - public void fetchBufferChunk() throws Exception { - FetchResult res = fetchChunks(Lists.newArrayList(BUFFER_CHUNK_INDEX)); - assertEquals(res.successChunks, Sets.newHashSet(BUFFER_CHUNK_INDEX)); - assertTrue(res.failedChunks.isEmpty()); - assertBufferListsEqual(res.buffers, Lists.newArrayList(bufferChunk)); - res.releaseBuffers(); - } - - @Test - public void fetchFileChunk() throws Exception { - FetchResult res = fetchChunks(Lists.newArrayList(FILE_CHUNK_INDEX)); - assertEquals(res.successChunks, Sets.newHashSet(FILE_CHUNK_INDEX)); - assertTrue(res.failedChunks.isEmpty()); - assertBufferListsEqual(res.buffers, Lists.newArrayList(fileChunk)); - res.releaseBuffers(); - } - - @Test - public void fetchNonExistentChunk() throws Exception { - FetchResult res = fetchChunks(Lists.newArrayList(12345)); - assertTrue(res.successChunks.isEmpty()); - assertEquals(res.failedChunks, Sets.newHashSet(12345)); - assertTrue(res.buffers.isEmpty()); - } - - @Test - public void fetchBothChunks() throws Exception { - FetchResult res = fetchChunks(Lists.newArrayList(BUFFER_CHUNK_INDEX, FILE_CHUNK_INDEX)); - assertEquals(res.successChunks, Sets.newHashSet(BUFFER_CHUNK_INDEX, FILE_CHUNK_INDEX)); - assertTrue(res.failedChunks.isEmpty()); - assertBufferListsEqual(res.buffers, Lists.newArrayList(bufferChunk, fileChunk)); - res.releaseBuffers(); - } - - @Test - public void fetchChunkAndNonExistent() throws Exception { - FetchResult res = fetchChunks(Lists.newArrayList(BUFFER_CHUNK_INDEX, 12345)); - assertEquals(res.successChunks, Sets.newHashSet(BUFFER_CHUNK_INDEX)); - assertEquals(res.failedChunks, Sets.newHashSet(12345)); - assertBufferListsEqual(res.buffers, Lists.newArrayList(bufferChunk)); - res.releaseBuffers(); - } - - private void assertBufferListsEqual(List list0, List list1) - throws Exception { - assertEquals(list0.size(), list1.size()); - for (int i = 0; i < list0.size(); i ++) { - assertBuffersEqual(list0.get(i), list1.get(i)); - } - } - - private void assertBuffersEqual(ManagedBuffer buffer0, ManagedBuffer buffer1) throws Exception { - ByteBuffer nio0 = buffer0.nioByteBuffer(); - ByteBuffer nio1 = buffer1.nioByteBuffer(); - - int len = nio0.remaining(); - assertEquals(nio0.remaining(), nio1.remaining()); - for (int i = 0; i < len; i ++) { - assertEquals(nio0.get(), nio1.get()); - } - } -} diff --git a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java deleted file mode 100644 index 6c8dd742f4..0000000000 --- a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java +++ /dev/null @@ -1,127 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network; - -import java.util.List; - -import com.google.common.primitives.Ints; -import io.netty.buffer.Unpooled; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.FileRegion; -import io.netty.channel.embedded.EmbeddedChannel; -import io.netty.handler.codec.MessageToMessageEncoder; -import org.junit.Test; - -import static org.junit.Assert.assertEquals; - -import org.apache.spark.network.protocol.ChunkFetchFailure; -import org.apache.spark.network.protocol.ChunkFetchRequest; -import org.apache.spark.network.protocol.ChunkFetchSuccess; -import org.apache.spark.network.protocol.Message; -import org.apache.spark.network.protocol.MessageDecoder; -import org.apache.spark.network.protocol.MessageEncoder; -import org.apache.spark.network.protocol.OneWayMessage; -import org.apache.spark.network.protocol.RpcFailure; -import org.apache.spark.network.protocol.RpcRequest; -import org.apache.spark.network.protocol.RpcResponse; -import org.apache.spark.network.protocol.StreamChunkId; -import org.apache.spark.network.protocol.StreamFailure; -import org.apache.spark.network.protocol.StreamRequest; -import org.apache.spark.network.protocol.StreamResponse; -import org.apache.spark.network.util.ByteArrayWritableChannel; -import org.apache.spark.network.util.NettyUtils; - -public class ProtocolSuite { - private void testServerToClient(Message msg) { - EmbeddedChannel serverChannel = new EmbeddedChannel(new FileRegionEncoder(), - new MessageEncoder()); - serverChannel.writeOutbound(msg); - - EmbeddedChannel clientChannel = new EmbeddedChannel( - NettyUtils.createFrameDecoder(), new MessageDecoder()); - - while (!serverChannel.outboundMessages().isEmpty()) { - clientChannel.writeInbound(serverChannel.readOutbound()); - } - - assertEquals(1, clientChannel.inboundMessages().size()); - assertEquals(msg, clientChannel.readInbound()); - } - - private void testClientToServer(Message msg) { - EmbeddedChannel clientChannel = new EmbeddedChannel(new FileRegionEncoder(), - new MessageEncoder()); - clientChannel.writeOutbound(msg); - - EmbeddedChannel serverChannel = new EmbeddedChannel( - NettyUtils.createFrameDecoder(), new MessageDecoder()); - - while (!clientChannel.outboundMessages().isEmpty()) { - serverChannel.writeInbound(clientChannel.readOutbound()); - } - - assertEquals(1, serverChannel.inboundMessages().size()); - assertEquals(msg, serverChannel.readInbound()); - } - - @Test - public void requests() { - testClientToServer(new ChunkFetchRequest(new StreamChunkId(1, 2))); - testClientToServer(new RpcRequest(12345, new TestManagedBuffer(0))); - testClientToServer(new RpcRequest(12345, new TestManagedBuffer(10))); - testClientToServer(new StreamRequest("abcde")); - testClientToServer(new OneWayMessage(new TestManagedBuffer(10))); - } - - @Test - public void responses() { - testServerToClient(new ChunkFetchSuccess(new StreamChunkId(1, 2), new TestManagedBuffer(10))); - testServerToClient(new ChunkFetchSuccess(new StreamChunkId(1, 2), new TestManagedBuffer(0))); - testServerToClient(new ChunkFetchFailure(new StreamChunkId(1, 2), "this is an error")); - testServerToClient(new ChunkFetchFailure(new StreamChunkId(1, 2), "")); - testServerToClient(new RpcResponse(12345, new TestManagedBuffer(0))); - testServerToClient(new RpcResponse(12345, new TestManagedBuffer(100))); - testServerToClient(new RpcFailure(0, "this is an error")); - testServerToClient(new RpcFailure(0, "")); - // Note: buffer size must be "0" since StreamResponse's buffer is written differently to the - // channel and cannot be tested like this. - testServerToClient(new StreamResponse("anId", 12345L, new TestManagedBuffer(0))); - testServerToClient(new StreamFailure("anId", "this is an error")); - } - - /** - * Handler to transform a FileRegion into a byte buffer. EmbeddedChannel doesn't actually transfer - * bytes, but messages, so this is needed so that the frame decoder on the receiving side can - * understand what MessageWithHeader actually contains. - */ - private static class FileRegionEncoder extends MessageToMessageEncoder { - - @Override - public void encode(ChannelHandlerContext ctx, FileRegion in, List out) - throws Exception { - - ByteArrayWritableChannel channel = new ByteArrayWritableChannel(Ints.checkedCast(in.count())); - while (in.transfered() < in.count()) { - in.transferTo(channel, in.transfered()); - } - out.add(Unpooled.wrappedBuffer(channel.getData())); - } - - } - -} diff --git a/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java deleted file mode 100644 index f9b5bf96d6..0000000000 --- a/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java +++ /dev/null @@ -1,288 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network; - -import com.google.common.collect.Maps; -import com.google.common.util.concurrent.Uninterruptibles; -import org.apache.spark.network.buffer.ManagedBuffer; -import org.apache.spark.network.buffer.NioManagedBuffer; -import org.apache.spark.network.client.ChunkReceivedCallback; -import org.apache.spark.network.client.RpcResponseCallback; -import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.client.TransportClientFactory; -import org.apache.spark.network.server.RpcHandler; -import org.apache.spark.network.server.StreamManager; -import org.apache.spark.network.server.TransportServer; -import org.apache.spark.network.util.MapConfigProvider; -import org.apache.spark.network.util.TransportConf; -import org.junit.*; -import static org.junit.Assert.*; - -import java.io.IOException; -import java.nio.ByteBuffer; -import java.util.*; -import java.util.concurrent.Semaphore; -import java.util.concurrent.TimeUnit; - -/** - * Suite which ensures that requests that go without a response for the network timeout period are - * failed, and the connection closed. - * - * In this suite, we use 2 seconds as the connection timeout, with some slack given in the tests, - * to ensure stability in different test environments. - */ -public class RequestTimeoutIntegrationSuite { - - private TransportServer server; - private TransportClientFactory clientFactory; - - private StreamManager defaultManager; - private TransportConf conf; - - // A large timeout that "shouldn't happen", for the sake of faulty tests not hanging forever. - private final int FOREVER = 60 * 1000; - - @Before - public void setUp() throws Exception { - Map configMap = Maps.newHashMap(); - configMap.put("spark.shuffle.io.connectionTimeout", "2s"); - conf = new TransportConf("shuffle", new MapConfigProvider(configMap)); - - defaultManager = new StreamManager() { - @Override - public ManagedBuffer getChunk(long streamId, int chunkIndex) { - throw new UnsupportedOperationException(); - } - }; - } - - @After - public void tearDown() { - if (server != null) { - server.close(); - } - if (clientFactory != null) { - clientFactory.close(); - } - } - - // Basic suite: First request completes quickly, and second waits for longer than network timeout. - @Test - public void timeoutInactiveRequests() throws Exception { - final Semaphore semaphore = new Semaphore(1); - final int responseSize = 16; - RpcHandler handler = new RpcHandler() { - @Override - public void receive( - TransportClient client, - ByteBuffer message, - RpcResponseCallback callback) { - try { - semaphore.tryAcquire(FOREVER, TimeUnit.MILLISECONDS); - callback.onSuccess(ByteBuffer.allocate(responseSize)); - } catch (InterruptedException e) { - // do nothing - } - } - - @Override - public StreamManager getStreamManager() { - return defaultManager; - } - }; - - TransportContext context = new TransportContext(conf, handler); - server = context.createServer(); - clientFactory = context.createClientFactory(); - TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); - - // First completes quickly (semaphore starts at 1). - TestCallback callback0 = new TestCallback(); - synchronized (callback0) { - client.sendRpc(ByteBuffer.allocate(0), callback0); - callback0.wait(FOREVER); - assertEquals(responseSize, callback0.successLength); - } - - // Second times out after 2 seconds, with slack. Must be IOException. - TestCallback callback1 = new TestCallback(); - synchronized (callback1) { - client.sendRpc(ByteBuffer.allocate(0), callback1); - callback1.wait(4 * 1000); - assert (callback1.failure != null); - assert (callback1.failure instanceof IOException); - } - semaphore.release(); - } - - // A timeout will cause the connection to be closed, invalidating the current TransportClient. - // It should be the case that requesting a client from the factory produces a new, valid one. - @Test - public void timeoutCleanlyClosesClient() throws Exception { - final Semaphore semaphore = new Semaphore(0); - final int responseSize = 16; - RpcHandler handler = new RpcHandler() { - @Override - public void receive( - TransportClient client, - ByteBuffer message, - RpcResponseCallback callback) { - try { - semaphore.tryAcquire(FOREVER, TimeUnit.MILLISECONDS); - callback.onSuccess(ByteBuffer.allocate(responseSize)); - } catch (InterruptedException e) { - // do nothing - } - } - - @Override - public StreamManager getStreamManager() { - return defaultManager; - } - }; - - TransportContext context = new TransportContext(conf, handler); - server = context.createServer(); - clientFactory = context.createClientFactory(); - - // First request should eventually fail. - TransportClient client0 = - clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); - TestCallback callback0 = new TestCallback(); - synchronized (callback0) { - client0.sendRpc(ByteBuffer.allocate(0), callback0); - callback0.wait(FOREVER); - assert (callback0.failure instanceof IOException); - assert (!client0.isActive()); - } - - // Increment the semaphore and the second request should succeed quickly. - semaphore.release(2); - TransportClient client1 = - clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); - TestCallback callback1 = new TestCallback(); - synchronized (callback1) { - client1.sendRpc(ByteBuffer.allocate(0), callback1); - callback1.wait(FOREVER); - assertEquals(responseSize, callback1.successLength); - assertNull(callback1.failure); - } - } - - // The timeout is relative to the LAST request sent, which is kinda weird, but still. - // This test also makes sure the timeout works for Fetch requests as well as RPCs. - @Test - public void furtherRequestsDelay() throws Exception { - final byte[] response = new byte[16]; - final StreamManager manager = new StreamManager() { - @Override - public ManagedBuffer getChunk(long streamId, int chunkIndex) { - Uninterruptibles.sleepUninterruptibly(FOREVER, TimeUnit.MILLISECONDS); - return new NioManagedBuffer(ByteBuffer.wrap(response)); - } - }; - RpcHandler handler = new RpcHandler() { - @Override - public void receive( - TransportClient client, - ByteBuffer message, - RpcResponseCallback callback) { - throw new UnsupportedOperationException(); - } - - @Override - public StreamManager getStreamManager() { - return manager; - } - }; - - TransportContext context = new TransportContext(conf, handler); - server = context.createServer(); - clientFactory = context.createClientFactory(); - TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); - - // Send one request, which will eventually fail. - TestCallback callback0 = new TestCallback(); - client.fetchChunk(0, 0, callback0); - Uninterruptibles.sleepUninterruptibly(1200, TimeUnit.MILLISECONDS); - - // Send a second request before the first has failed. - TestCallback callback1 = new TestCallback(); - client.fetchChunk(0, 1, callback1); - Uninterruptibles.sleepUninterruptibly(1200, TimeUnit.MILLISECONDS); - - synchronized (callback0) { - // not complete yet, but should complete soon - assertEquals(-1, callback0.successLength); - assertNull(callback0.failure); - callback0.wait(2 * 1000); - assertTrue(callback0.failure instanceof IOException); - } - - synchronized (callback1) { - // failed at same time as previous - assert (callback0.failure instanceof IOException); - } - } - - /** - * Callback which sets 'success' or 'failure' on completion. - * Additionally notifies all waiters on this callback when invoked. - */ - class TestCallback implements RpcResponseCallback, ChunkReceivedCallback { - - int successLength = -1; - Throwable failure; - - @Override - public void onSuccess(ByteBuffer response) { - synchronized(this) { - successLength = response.remaining(); - this.notifyAll(); - } - } - - @Override - public void onFailure(Throwable e) { - synchronized(this) { - failure = e; - this.notifyAll(); - } - } - - @Override - public void onSuccess(int chunkIndex, ManagedBuffer buffer) { - synchronized(this) { - try { - successLength = buffer.nioByteBuffer().remaining(); - this.notifyAll(); - } catch (IOException e) { - // weird - } - } - } - - @Override - public void onFailure(int chunkIndex, Throwable e) { - synchronized(this) { - failure = e; - this.notifyAll(); - } - } - } -} diff --git a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java deleted file mode 100644 index 9e9be98c14..0000000000 --- a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java +++ /dev/null @@ -1,215 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network; - -import java.nio.ByteBuffer; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashSet; -import java.util.Iterator; -import java.util.List; -import java.util.Set; -import java.util.concurrent.Semaphore; -import java.util.concurrent.TimeUnit; - -import com.google.common.collect.Sets; -import org.junit.AfterClass; -import org.junit.BeforeClass; -import org.junit.Test; - -import static org.junit.Assert.*; - -import org.apache.spark.network.client.RpcResponseCallback; -import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.client.TransportClientFactory; -import org.apache.spark.network.server.OneForOneStreamManager; -import org.apache.spark.network.server.RpcHandler; -import org.apache.spark.network.server.StreamManager; -import org.apache.spark.network.server.TransportServer; -import org.apache.spark.network.util.JavaUtils; -import org.apache.spark.network.util.SystemPropertyConfigProvider; -import org.apache.spark.network.util.TransportConf; - -public class RpcIntegrationSuite { - static TransportServer server; - static TransportClientFactory clientFactory; - static RpcHandler rpcHandler; - static List oneWayMsgs; - - @BeforeClass - public static void setUp() throws Exception { - TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); - rpcHandler = new RpcHandler() { - @Override - public void receive( - TransportClient client, - ByteBuffer message, - RpcResponseCallback callback) { - String msg = JavaUtils.bytesToString(message); - String[] parts = msg.split("/"); - if (parts[0].equals("hello")) { - callback.onSuccess(JavaUtils.stringToBytes("Hello, " + parts[1] + "!")); - } else if (parts[0].equals("return error")) { - callback.onFailure(new RuntimeException("Returned: " + parts[1])); - } else if (parts[0].equals("throw error")) { - throw new RuntimeException("Thrown: " + parts[1]); - } - } - - @Override - public void receive(TransportClient client, ByteBuffer message) { - oneWayMsgs.add(JavaUtils.bytesToString(message)); - } - - @Override - public StreamManager getStreamManager() { return new OneForOneStreamManager(); } - }; - TransportContext context = new TransportContext(conf, rpcHandler); - server = context.createServer(); - clientFactory = context.createClientFactory(); - oneWayMsgs = new ArrayList<>(); - } - - @AfterClass - public static void tearDown() { - server.close(); - clientFactory.close(); - } - - class RpcResult { - public Set successMessages; - public Set errorMessages; - } - - private RpcResult sendRPC(String ... commands) throws Exception { - TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); - final Semaphore sem = new Semaphore(0); - - final RpcResult res = new RpcResult(); - res.successMessages = Collections.synchronizedSet(new HashSet()); - res.errorMessages = Collections.synchronizedSet(new HashSet()); - - RpcResponseCallback callback = new RpcResponseCallback() { - @Override - public void onSuccess(ByteBuffer message) { - String response = JavaUtils.bytesToString(message); - res.successMessages.add(response); - sem.release(); - } - - @Override - public void onFailure(Throwable e) { - res.errorMessages.add(e.getMessage()); - sem.release(); - } - }; - - for (String command : commands) { - client.sendRpc(JavaUtils.stringToBytes(command), callback); - } - - if (!sem.tryAcquire(commands.length, 5, TimeUnit.SECONDS)) { - fail("Timeout getting response from the server"); - } - client.close(); - return res; - } - - @Test - public void singleRPC() throws Exception { - RpcResult res = sendRPC("hello/Aaron"); - assertEquals(res.successMessages, Sets.newHashSet("Hello, Aaron!")); - assertTrue(res.errorMessages.isEmpty()); - } - - @Test - public void doubleRPC() throws Exception { - RpcResult res = sendRPC("hello/Aaron", "hello/Reynold"); - assertEquals(res.successMessages, Sets.newHashSet("Hello, Aaron!", "Hello, Reynold!")); - assertTrue(res.errorMessages.isEmpty()); - } - - @Test - public void returnErrorRPC() throws Exception { - RpcResult res = sendRPC("return error/OK"); - assertTrue(res.successMessages.isEmpty()); - assertErrorsContain(res.errorMessages, Sets.newHashSet("Returned: OK")); - } - - @Test - public void throwErrorRPC() throws Exception { - RpcResult res = sendRPC("throw error/uh-oh"); - assertTrue(res.successMessages.isEmpty()); - assertErrorsContain(res.errorMessages, Sets.newHashSet("Thrown: uh-oh")); - } - - @Test - public void doubleTrouble() throws Exception { - RpcResult res = sendRPC("return error/OK", "throw error/uh-oh"); - assertTrue(res.successMessages.isEmpty()); - assertErrorsContain(res.errorMessages, Sets.newHashSet("Returned: OK", "Thrown: uh-oh")); - } - - @Test - public void sendSuccessAndFailure() throws Exception { - RpcResult res = sendRPC("hello/Bob", "throw error/the", "hello/Builder", "return error/!"); - assertEquals(res.successMessages, Sets.newHashSet("Hello, Bob!", "Hello, Builder!")); - assertErrorsContain(res.errorMessages, Sets.newHashSet("Thrown: the", "Returned: !")); - } - - @Test - public void sendOneWayMessage() throws Exception { - final String message = "no reply"; - TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); - try { - client.send(JavaUtils.stringToBytes(message)); - assertEquals(0, client.getHandler().numOutstandingRequests()); - - // Make sure the message arrives. - long deadline = System.nanoTime() + TimeUnit.NANOSECONDS.convert(10, TimeUnit.SECONDS); - while (System.nanoTime() < deadline && oneWayMsgs.size() == 0) { - TimeUnit.MILLISECONDS.sleep(10); - } - - assertEquals(1, oneWayMsgs.size()); - assertEquals(message, oneWayMsgs.get(0)); - } finally { - client.close(); - } - } - - private void assertErrorsContain(Set errors, Set contains) { - assertEquals(contains.size(), errors.size()); - - Set remainingErrors = Sets.newHashSet(errors); - for (String contain : contains) { - Iterator it = remainingErrors.iterator(); - boolean foundMatch = false; - while (it.hasNext()) { - if (it.next().contains(contain)) { - it.remove(); - foundMatch = true; - break; - } - } - assertTrue("Could not find error containing " + contain + "; errors: " + errors, foundMatch); - } - - assertTrue(remainingErrors.isEmpty()); - } -} diff --git a/network/common/src/test/java/org/apache/spark/network/StreamSuite.java b/network/common/src/test/java/org/apache/spark/network/StreamSuite.java deleted file mode 100644 index 9c49556927..0000000000 --- a/network/common/src/test/java/org/apache/spark/network/StreamSuite.java +++ /dev/null @@ -1,349 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network; - -import java.io.ByteArrayOutputStream; -import java.io.File; -import java.io.FileOutputStream; -import java.io.IOException; -import java.io.OutputStream; -import java.nio.ByteBuffer; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.Random; -import java.util.concurrent.Executors; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.TimeUnit; - -import com.google.common.io.Files; -import org.junit.AfterClass; -import org.junit.BeforeClass; -import org.junit.Test; -import static org.junit.Assert.*; - -import org.apache.spark.network.buffer.FileSegmentManagedBuffer; -import org.apache.spark.network.buffer.ManagedBuffer; -import org.apache.spark.network.buffer.NioManagedBuffer; -import org.apache.spark.network.client.RpcResponseCallback; -import org.apache.spark.network.client.StreamCallback; -import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.client.TransportClientFactory; -import org.apache.spark.network.server.RpcHandler; -import org.apache.spark.network.server.StreamManager; -import org.apache.spark.network.server.TransportServer; -import org.apache.spark.network.util.SystemPropertyConfigProvider; -import org.apache.spark.network.util.TransportConf; - -public class StreamSuite { - private static final String[] STREAMS = { "largeBuffer", "smallBuffer", "emptyBuffer", "file" }; - - private static TransportServer server; - private static TransportClientFactory clientFactory; - private static File testFile; - private static File tempDir; - - private static ByteBuffer emptyBuffer; - private static ByteBuffer smallBuffer; - private static ByteBuffer largeBuffer; - - private static ByteBuffer createBuffer(int bufSize) { - ByteBuffer buf = ByteBuffer.allocate(bufSize); - for (int i = 0; i < bufSize; i ++) { - buf.put((byte) i); - } - buf.flip(); - return buf; - } - - @BeforeClass - public static void setUp() throws Exception { - tempDir = Files.createTempDir(); - emptyBuffer = createBuffer(0); - smallBuffer = createBuffer(100); - largeBuffer = createBuffer(100000); - - testFile = File.createTempFile("stream-test-file", "txt", tempDir); - FileOutputStream fp = new FileOutputStream(testFile); - try { - Random rnd = new Random(); - for (int i = 0; i < 512; i++) { - byte[] fileContent = new byte[1024]; - rnd.nextBytes(fileContent); - fp.write(fileContent); - } - } finally { - fp.close(); - } - - final TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); - final StreamManager streamManager = new StreamManager() { - @Override - public ManagedBuffer getChunk(long streamId, int chunkIndex) { - throw new UnsupportedOperationException(); - } - - @Override - public ManagedBuffer openStream(String streamId) { - switch (streamId) { - case "largeBuffer": - return new NioManagedBuffer(largeBuffer); - case "smallBuffer": - return new NioManagedBuffer(smallBuffer); - case "emptyBuffer": - return new NioManagedBuffer(emptyBuffer); - case "file": - return new FileSegmentManagedBuffer(conf, testFile, 0, testFile.length()); - default: - throw new IllegalArgumentException("Invalid stream: " + streamId); - } - } - }; - RpcHandler handler = new RpcHandler() { - @Override - public void receive( - TransportClient client, - ByteBuffer message, - RpcResponseCallback callback) { - throw new UnsupportedOperationException(); - } - - @Override - public StreamManager getStreamManager() { - return streamManager; - } - }; - TransportContext context = new TransportContext(conf, handler); - server = context.createServer(); - clientFactory = context.createClientFactory(); - } - - @AfterClass - public static void tearDown() { - server.close(); - clientFactory.close(); - if (tempDir != null) { - for (File f : tempDir.listFiles()) { - f.delete(); - } - tempDir.delete(); - } - } - - @Test - public void testZeroLengthStream() throws Throwable { - TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); - try { - StreamTask task = new StreamTask(client, "emptyBuffer", TimeUnit.SECONDS.toMillis(5)); - task.run(); - task.check(); - } finally { - client.close(); - } - } - - @Test - public void testSingleStream() throws Throwable { - TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); - try { - StreamTask task = new StreamTask(client, "largeBuffer", TimeUnit.SECONDS.toMillis(5)); - task.run(); - task.check(); - } finally { - client.close(); - } - } - - @Test - public void testMultipleStreams() throws Throwable { - TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); - try { - for (int i = 0; i < 20; i++) { - StreamTask task = new StreamTask(client, STREAMS[i % STREAMS.length], - TimeUnit.SECONDS.toMillis(5)); - task.run(); - task.check(); - } - } finally { - client.close(); - } - } - - @Test - public void testConcurrentStreams() throws Throwable { - ExecutorService executor = Executors.newFixedThreadPool(20); - TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); - - try { - List tasks = new ArrayList<>(); - for (int i = 0; i < 20; i++) { - StreamTask task = new StreamTask(client, STREAMS[i % STREAMS.length], - TimeUnit.SECONDS.toMillis(20)); - tasks.add(task); - executor.submit(task); - } - - executor.shutdown(); - assertTrue("Timed out waiting for tasks.", executor.awaitTermination(30, TimeUnit.SECONDS)); - for (StreamTask task : tasks) { - task.check(); - } - } finally { - executor.shutdownNow(); - client.close(); - } - } - - private static class StreamTask implements Runnable { - - private final TransportClient client; - private final String streamId; - private final long timeoutMs; - private Throwable error; - - StreamTask(TransportClient client, String streamId, long timeoutMs) { - this.client = client; - this.streamId = streamId; - this.timeoutMs = timeoutMs; - } - - @Override - public void run() { - ByteBuffer srcBuffer = null; - OutputStream out = null; - File outFile = null; - try { - ByteArrayOutputStream baos = null; - - switch (streamId) { - case "largeBuffer": - baos = new ByteArrayOutputStream(); - out = baos; - srcBuffer = largeBuffer; - break; - case "smallBuffer": - baos = new ByteArrayOutputStream(); - out = baos; - srcBuffer = smallBuffer; - break; - case "file": - outFile = File.createTempFile("data", ".tmp", tempDir); - out = new FileOutputStream(outFile); - break; - case "emptyBuffer": - baos = new ByteArrayOutputStream(); - out = baos; - srcBuffer = emptyBuffer; - break; - default: - throw new IllegalArgumentException(streamId); - } - - TestCallback callback = new TestCallback(out); - client.stream(streamId, callback); - waitForCompletion(callback); - - if (srcBuffer == null) { - assertTrue("File stream did not match.", Files.equal(testFile, outFile)); - } else { - ByteBuffer base; - synchronized (srcBuffer) { - base = srcBuffer.duplicate(); - } - byte[] result = baos.toByteArray(); - byte[] expected = new byte[base.remaining()]; - base.get(expected); - assertEquals(expected.length, result.length); - assertTrue("buffers don't match", Arrays.equals(expected, result)); - } - } catch (Throwable t) { - error = t; - } finally { - if (out != null) { - try { - out.close(); - } catch (Exception e) { - // ignore. - } - } - if (outFile != null) { - outFile.delete(); - } - } - } - - public void check() throws Throwable { - if (error != null) { - throw error; - } - } - - private void waitForCompletion(TestCallback callback) throws Exception { - long now = System.currentTimeMillis(); - long deadline = now + timeoutMs; - synchronized (callback) { - while (!callback.completed && now < deadline) { - callback.wait(deadline - now); - now = System.currentTimeMillis(); - } - } - assertTrue("Timed out waiting for stream.", callback.completed); - assertNull(callback.error); - } - - } - - private static class TestCallback implements StreamCallback { - - private final OutputStream out; - public volatile boolean completed; - public volatile Throwable error; - - TestCallback(OutputStream out) { - this.out = out; - this.completed = false; - } - - @Override - public void onData(String streamId, ByteBuffer buf) throws IOException { - byte[] tmp = new byte[buf.remaining()]; - buf.get(tmp); - out.write(tmp); - } - - @Override - public void onComplete(String streamId) throws IOException { - out.close(); - synchronized (this) { - completed = true; - notifyAll(); - } - } - - @Override - public void onFailure(String streamId, Throwable cause) { - error = cause; - synchronized (this) { - completed = true; - notifyAll(); - } - } - - } - -} diff --git a/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java b/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java deleted file mode 100644 index 83c90f9eff..0000000000 --- a/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java +++ /dev/null @@ -1,109 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network; - -import java.io.IOException; -import java.io.InputStream; -import java.nio.ByteBuffer; - -import com.google.common.base.Preconditions; -import io.netty.buffer.Unpooled; - -import org.apache.spark.network.buffer.ManagedBuffer; -import org.apache.spark.network.buffer.NettyManagedBuffer; - -/** - * A ManagedBuffer implementation that contains 0, 1, 2, 3, ..., (len-1). - * - * Used for testing. - */ -public class TestManagedBuffer extends ManagedBuffer { - - private final int len; - private NettyManagedBuffer underlying; - - public TestManagedBuffer(int len) { - Preconditions.checkArgument(len <= Byte.MAX_VALUE); - this.len = len; - byte[] byteArray = new byte[len]; - for (int i = 0; i < len; i ++) { - byteArray[i] = (byte) i; - } - this.underlying = new NettyManagedBuffer(Unpooled.wrappedBuffer(byteArray)); - } - - - @Override - public long size() { - return underlying.size(); - } - - @Override - public ByteBuffer nioByteBuffer() throws IOException { - return underlying.nioByteBuffer(); - } - - @Override - public InputStream createInputStream() throws IOException { - return underlying.createInputStream(); - } - - @Override - public ManagedBuffer retain() { - underlying.retain(); - return this; - } - - @Override - public ManagedBuffer release() { - underlying.release(); - return this; - } - - @Override - public Object convertToNetty() throws IOException { - return underlying.convertToNetty(); - } - - @Override - public int hashCode() { - return underlying.hashCode(); - } - - @Override - public boolean equals(Object other) { - if (other instanceof ManagedBuffer) { - try { - ByteBuffer nioBuf = ((ManagedBuffer) other).nioByteBuffer(); - if (nioBuf.remaining() != len) { - return false; - } else { - for (int i = 0; i < len; i ++) { - if (nioBuf.get() != i) { - return false; - } - } - return true; - } - } catch (IOException e) { - throw new RuntimeException(e); - } - } - return false; - } -} diff --git a/network/common/src/test/java/org/apache/spark/network/TestUtils.java b/network/common/src/test/java/org/apache/spark/network/TestUtils.java deleted file mode 100644 index 56a2b805f1..0000000000 --- a/network/common/src/test/java/org/apache/spark/network/TestUtils.java +++ /dev/null @@ -1,30 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network; - -import java.net.InetAddress; - -public class TestUtils { - public static String getLocalHost() { - try { - return InetAddress.getLocalHost().getHostAddress(); - } catch (Exception e) { - throw new RuntimeException(e); - } - } -} diff --git a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java deleted file mode 100644 index dac7d4a5b0..0000000000 --- a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java +++ /dev/null @@ -1,214 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network; - -import java.io.IOException; -import java.util.Collections; -import java.util.HashSet; -import java.util.Map; -import java.util.NoSuchElementException; -import java.util.Set; -import java.util.concurrent.atomic.AtomicInteger; - -import com.google.common.collect.Maps; -import org.junit.After; -import org.junit.Before; -import org.junit.Test; - -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; - -import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.client.TransportClientFactory; -import org.apache.spark.network.server.NoOpRpcHandler; -import org.apache.spark.network.server.RpcHandler; -import org.apache.spark.network.server.TransportServer; -import org.apache.spark.network.util.ConfigProvider; -import org.apache.spark.network.util.SystemPropertyConfigProvider; -import org.apache.spark.network.util.JavaUtils; -import org.apache.spark.network.util.MapConfigProvider; -import org.apache.spark.network.util.TransportConf; - -public class TransportClientFactorySuite { - private TransportConf conf; - private TransportContext context; - private TransportServer server1; - private TransportServer server2; - - @Before - public void setUp() { - conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); - RpcHandler rpcHandler = new NoOpRpcHandler(); - context = new TransportContext(conf, rpcHandler); - server1 = context.createServer(); - server2 = context.createServer(); - } - - @After - public void tearDown() { - JavaUtils.closeQuietly(server1); - JavaUtils.closeQuietly(server2); - } - - /** - * Request a bunch of clients to a single server to test - * we create up to maxConnections of clients. - * - * If concurrent is true, create multiple threads to create clients in parallel. - */ - private void testClientReuse(final int maxConnections, boolean concurrent) - throws IOException, InterruptedException { - - Map configMap = Maps.newHashMap(); - configMap.put("spark.shuffle.io.numConnectionsPerPeer", Integer.toString(maxConnections)); - TransportConf conf = new TransportConf("shuffle", new MapConfigProvider(configMap)); - - RpcHandler rpcHandler = new NoOpRpcHandler(); - TransportContext context = new TransportContext(conf, rpcHandler); - final TransportClientFactory factory = context.createClientFactory(); - final Set clients = Collections.synchronizedSet( - new HashSet()); - - final AtomicInteger failed = new AtomicInteger(); - Thread[] attempts = new Thread[maxConnections * 10]; - - // Launch a bunch of threads to create new clients. - for (int i = 0; i < attempts.length; i++) { - attempts[i] = new Thread() { - @Override - public void run() { - try { - TransportClient client = - factory.createClient(TestUtils.getLocalHost(), server1.getPort()); - assert (client.isActive()); - clients.add(client); - } catch (IOException e) { - failed.incrementAndGet(); - } - } - }; - - if (concurrent) { - attempts[i].start(); - } else { - attempts[i].run(); - } - } - - // Wait until all the threads complete. - for (int i = 0; i < attempts.length; i++) { - attempts[i].join(); - } - - assert(failed.get() == 0); - assert(clients.size() == maxConnections); - - for (TransportClient client : clients) { - client.close(); - } - } - - @Test - public void reuseClientsUpToConfigVariable() throws Exception { - testClientReuse(1, false); - testClientReuse(2, false); - testClientReuse(3, false); - testClientReuse(4, false); - } - - @Test - public void reuseClientsUpToConfigVariableConcurrent() throws Exception { - testClientReuse(1, true); - testClientReuse(2, true); - testClientReuse(3, true); - testClientReuse(4, true); - } - - @Test - public void returnDifferentClientsForDifferentServers() throws IOException { - TransportClientFactory factory = context.createClientFactory(); - TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); - TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server2.getPort()); - assertTrue(c1.isActive()); - assertTrue(c2.isActive()); - assertTrue(c1 != c2); - factory.close(); - } - - @Test - public void neverReturnInactiveClients() throws IOException, InterruptedException { - TransportClientFactory factory = context.createClientFactory(); - TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); - c1.close(); - - long start = System.currentTimeMillis(); - while (c1.isActive() && (System.currentTimeMillis() - start) < 3000) { - Thread.sleep(10); - } - assertFalse(c1.isActive()); - - TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); - assertFalse(c1 == c2); - assertTrue(c2.isActive()); - factory.close(); - } - - @Test - public void closeBlockClientsWithFactory() throws IOException { - TransportClientFactory factory = context.createClientFactory(); - TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); - TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server2.getPort()); - assertTrue(c1.isActive()); - assertTrue(c2.isActive()); - factory.close(); - assertFalse(c1.isActive()); - assertFalse(c2.isActive()); - } - - @Test - public void closeIdleConnectionForRequestTimeOut() throws IOException, InterruptedException { - TransportConf conf = new TransportConf("shuffle", new ConfigProvider() { - - @Override - public String get(String name) { - if ("spark.shuffle.io.connectionTimeout".equals(name)) { - // We should make sure there is enough time for us to observe the channel is active - return "1s"; - } - String value = System.getProperty(name); - if (value == null) { - throw new NoSuchElementException(name); - } - return value; - } - }); - TransportContext context = new TransportContext(conf, new NoOpRpcHandler(), true); - TransportClientFactory factory = context.createClientFactory(); - try { - TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); - assertTrue(c1.isActive()); - long expiredTime = System.currentTimeMillis() + 10000; // 10 seconds - while (c1.isActive() && System.currentTimeMillis() < expiredTime) { - Thread.sleep(10); - } - assertFalse(c1.isActive()); - } finally { - factory.close(); - } - } -} diff --git a/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java b/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java deleted file mode 100644 index 128f7cba74..0000000000 --- a/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java +++ /dev/null @@ -1,146 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network; - -import java.nio.ByteBuffer; - -import io.netty.channel.Channel; -import io.netty.channel.local.LocalChannel; -import org.junit.Test; - -import static org.junit.Assert.assertEquals; -import static org.mockito.Matchers.any; -import static org.mockito.Matchers.eq; -import static org.mockito.Mockito.*; - -import org.apache.spark.network.buffer.ManagedBuffer; -import org.apache.spark.network.buffer.NioManagedBuffer; -import org.apache.spark.network.client.ChunkReceivedCallback; -import org.apache.spark.network.client.RpcResponseCallback; -import org.apache.spark.network.client.StreamCallback; -import org.apache.spark.network.client.TransportResponseHandler; -import org.apache.spark.network.protocol.ChunkFetchFailure; -import org.apache.spark.network.protocol.ChunkFetchSuccess; -import org.apache.spark.network.protocol.RpcFailure; -import org.apache.spark.network.protocol.RpcResponse; -import org.apache.spark.network.protocol.StreamChunkId; -import org.apache.spark.network.protocol.StreamFailure; -import org.apache.spark.network.protocol.StreamResponse; -import org.apache.spark.network.util.TransportFrameDecoder; - -public class TransportResponseHandlerSuite { - @Test - public void handleSuccessfulFetch() throws Exception { - StreamChunkId streamChunkId = new StreamChunkId(1, 0); - - TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); - ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); - handler.addFetchRequest(streamChunkId, callback); - assertEquals(1, handler.numOutstandingRequests()); - - handler.handle(new ChunkFetchSuccess(streamChunkId, new TestManagedBuffer(123))); - verify(callback, times(1)).onSuccess(eq(0), (ManagedBuffer) any()); - assertEquals(0, handler.numOutstandingRequests()); - } - - @Test - public void handleFailedFetch() throws Exception { - StreamChunkId streamChunkId = new StreamChunkId(1, 0); - TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); - ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); - handler.addFetchRequest(streamChunkId, callback); - assertEquals(1, handler.numOutstandingRequests()); - - handler.handle(new ChunkFetchFailure(streamChunkId, "some error msg")); - verify(callback, times(1)).onFailure(eq(0), (Throwable) any()); - assertEquals(0, handler.numOutstandingRequests()); - } - - @Test - public void clearAllOutstandingRequests() throws Exception { - TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); - ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); - handler.addFetchRequest(new StreamChunkId(1, 0), callback); - handler.addFetchRequest(new StreamChunkId(1, 1), callback); - handler.addFetchRequest(new StreamChunkId(1, 2), callback); - assertEquals(3, handler.numOutstandingRequests()); - - handler.handle(new ChunkFetchSuccess(new StreamChunkId(1, 0), new TestManagedBuffer(12))); - handler.exceptionCaught(new Exception("duh duh duhhhh")); - - // should fail both b2 and b3 - verify(callback, times(1)).onSuccess(eq(0), (ManagedBuffer) any()); - verify(callback, times(1)).onFailure(eq(1), (Throwable) any()); - verify(callback, times(1)).onFailure(eq(2), (Throwable) any()); - assertEquals(0, handler.numOutstandingRequests()); - } - - @Test - public void handleSuccessfulRPC() throws Exception { - TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); - RpcResponseCallback callback = mock(RpcResponseCallback.class); - handler.addRpcRequest(12345, callback); - assertEquals(1, handler.numOutstandingRequests()); - - // This response should be ignored. - handler.handle(new RpcResponse(54321, new NioManagedBuffer(ByteBuffer.allocate(7)))); - assertEquals(1, handler.numOutstandingRequests()); - - ByteBuffer resp = ByteBuffer.allocate(10); - handler.handle(new RpcResponse(12345, new NioManagedBuffer(resp))); - verify(callback, times(1)).onSuccess(eq(ByteBuffer.allocate(10))); - assertEquals(0, handler.numOutstandingRequests()); - } - - @Test - public void handleFailedRPC() throws Exception { - TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); - RpcResponseCallback callback = mock(RpcResponseCallback.class); - handler.addRpcRequest(12345, callback); - assertEquals(1, handler.numOutstandingRequests()); - - handler.handle(new RpcFailure(54321, "uh-oh!")); // should be ignored - assertEquals(1, handler.numOutstandingRequests()); - - handler.handle(new RpcFailure(12345, "oh no")); - verify(callback, times(1)).onFailure((Throwable) any()); - assertEquals(0, handler.numOutstandingRequests()); - } - - @Test - public void testActiveStreams() throws Exception { - Channel c = new LocalChannel(); - c.pipeline().addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder()); - TransportResponseHandler handler = new TransportResponseHandler(c); - - StreamResponse response = new StreamResponse("stream", 1234L, null); - StreamCallback cb = mock(StreamCallback.class); - handler.addStreamCallback(cb); - assertEquals(1, handler.numOutstandingRequests()); - handler.handle(response); - assertEquals(1, handler.numOutstandingRequests()); - handler.deactivateStream(); - assertEquals(0, handler.numOutstandingRequests()); - - StreamFailure failure = new StreamFailure("stream", "uh-oh"); - handler.addStreamCallback(cb); - assertEquals(1, handler.numOutstandingRequests()); - handler.handle(failure); - assertEquals(0, handler.numOutstandingRequests()); - } -} diff --git a/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java b/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java deleted file mode 100644 index fbbe4b7014..0000000000 --- a/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java +++ /dev/null @@ -1,157 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.protocol; - -import java.io.IOException; -import java.nio.ByteBuffer; -import java.nio.channels.WritableByteChannel; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; -import io.netty.channel.FileRegion; -import io.netty.util.AbstractReferenceCounted; -import org.junit.Test; -import org.mockito.Mockito; - -import static org.junit.Assert.*; - -import org.apache.spark.network.TestManagedBuffer; -import org.apache.spark.network.buffer.ManagedBuffer; -import org.apache.spark.network.buffer.NettyManagedBuffer; -import org.apache.spark.network.util.ByteArrayWritableChannel; - -public class MessageWithHeaderSuite { - - @Test - public void testSingleWrite() throws Exception { - testFileRegionBody(8, 8); - } - - @Test - public void testShortWrite() throws Exception { - testFileRegionBody(8, 1); - } - - @Test - public void testByteBufBody() throws Exception { - ByteBuf header = Unpooled.copyLong(42); - ByteBuf bodyPassedToNettyManagedBuffer = Unpooled.copyLong(84); - assertEquals(1, header.refCnt()); - assertEquals(1, bodyPassedToNettyManagedBuffer.refCnt()); - ManagedBuffer managedBuf = new NettyManagedBuffer(bodyPassedToNettyManagedBuffer); - - Object body = managedBuf.convertToNetty(); - assertEquals(2, bodyPassedToNettyManagedBuffer.refCnt()); - assertEquals(1, header.refCnt()); - - MessageWithHeader msg = new MessageWithHeader(managedBuf, header, body, managedBuf.size()); - ByteBuf result = doWrite(msg, 1); - assertEquals(msg.count(), result.readableBytes()); - assertEquals(42, result.readLong()); - assertEquals(84, result.readLong()); - - assert(msg.release()); - assertEquals(0, bodyPassedToNettyManagedBuffer.refCnt()); - assertEquals(0, header.refCnt()); - } - - @Test - public void testDeallocateReleasesManagedBuffer() throws Exception { - ByteBuf header = Unpooled.copyLong(42); - ManagedBuffer managedBuf = Mockito.spy(new TestManagedBuffer(84)); - ByteBuf body = (ByteBuf) managedBuf.convertToNetty(); - assertEquals(2, body.refCnt()); - MessageWithHeader msg = new MessageWithHeader(managedBuf, header, body, body.readableBytes()); - assert(msg.release()); - Mockito.verify(managedBuf, Mockito.times(1)).release(); - assertEquals(0, body.refCnt()); - } - - private void testFileRegionBody(int totalWrites, int writesPerCall) throws Exception { - ByteBuf header = Unpooled.copyLong(42); - int headerLength = header.readableBytes(); - TestFileRegion region = new TestFileRegion(totalWrites, writesPerCall); - MessageWithHeader msg = new MessageWithHeader(null, header, region, region.count()); - - ByteBuf result = doWrite(msg, totalWrites / writesPerCall); - assertEquals(headerLength + region.count(), result.readableBytes()); - assertEquals(42, result.readLong()); - for (long i = 0; i < 8; i++) { - assertEquals(i, result.readLong()); - } - assert(msg.release()); - } - - private ByteBuf doWrite(MessageWithHeader msg, int minExpectedWrites) throws Exception { - int writes = 0; - ByteArrayWritableChannel channel = new ByteArrayWritableChannel((int) msg.count()); - while (msg.transfered() < msg.count()) { - msg.transferTo(channel, msg.transfered()); - writes++; - } - assertTrue("Not enough writes!", minExpectedWrites <= writes); - return Unpooled.wrappedBuffer(channel.getData()); - } - - private static class TestFileRegion extends AbstractReferenceCounted implements FileRegion { - - private final int writeCount; - private final int writesPerCall; - private int written; - - TestFileRegion(int totalWrites, int writesPerCall) { - this.writeCount = totalWrites; - this.writesPerCall = writesPerCall; - } - - @Override - public long count() { - return 8 * writeCount; - } - - @Override - public long position() { - return 0; - } - - @Override - public long transfered() { - return 8 * written; - } - - @Override - public long transferTo(WritableByteChannel target, long position) throws IOException { - for (int i = 0; i < writesPerCall; i++) { - ByteBuf buf = Unpooled.copyLong((position / 8) + i); - ByteBuffer nio = buf.nioBuffer(); - while (nio.remaining() > 0) { - target.write(nio); - } - buf.release(); - written++; - } - return 8 * writesPerCall; - } - - @Override - protected void deallocate() { - } - - } - -} diff --git a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java deleted file mode 100644 index 045773317a..0000000000 --- a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java +++ /dev/null @@ -1,476 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.sasl; - -import static org.junit.Assert.*; -import static org.mockito.Mockito.*; - -import java.io.File; -import java.lang.reflect.Method; -import java.nio.ByteBuffer; -import java.util.Arrays; -import java.util.List; -import java.util.Random; -import java.util.concurrent.TimeoutException; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicReference; -import javax.security.sasl.SaslException; - -import com.google.common.collect.Lists; -import com.google.common.io.ByteStreams; -import com.google.common.io.Files; -import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; -import io.netty.channel.Channel; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelOutboundHandlerAdapter; -import io.netty.channel.ChannelPromise; -import org.junit.Test; -import org.mockito.invocation.InvocationOnMock; -import org.mockito.stubbing.Answer; - -import org.apache.spark.network.TestUtils; -import org.apache.spark.network.TransportContext; -import org.apache.spark.network.buffer.FileSegmentManagedBuffer; -import org.apache.spark.network.buffer.ManagedBuffer; -import org.apache.spark.network.client.ChunkReceivedCallback; -import org.apache.spark.network.client.RpcResponseCallback; -import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.client.TransportClientBootstrap; -import org.apache.spark.network.server.RpcHandler; -import org.apache.spark.network.server.StreamManager; -import org.apache.spark.network.server.TransportServer; -import org.apache.spark.network.server.TransportServerBootstrap; -import org.apache.spark.network.util.ByteArrayWritableChannel; -import org.apache.spark.network.util.JavaUtils; -import org.apache.spark.network.util.SystemPropertyConfigProvider; -import org.apache.spark.network.util.TransportConf; - -/** - * Jointly tests SparkSaslClient and SparkSaslServer, as both are black boxes. - */ -public class SparkSaslSuite { - - /** Provides a secret key holder which returns secret key == appId */ - private SecretKeyHolder secretKeyHolder = new SecretKeyHolder() { - @Override - public String getSaslUser(String appId) { - return "user"; - } - - @Override - public String getSecretKey(String appId) { - return appId; - } - }; - - @Test - public void testMatching() { - SparkSaslClient client = new SparkSaslClient("shared-secret", secretKeyHolder, false); - SparkSaslServer server = new SparkSaslServer("shared-secret", secretKeyHolder, false); - - assertFalse(client.isComplete()); - assertFalse(server.isComplete()); - - byte[] clientMessage = client.firstToken(); - - while (!client.isComplete()) { - clientMessage = client.response(server.response(clientMessage)); - } - assertTrue(server.isComplete()); - - // Disposal should invalidate - server.dispose(); - assertFalse(server.isComplete()); - client.dispose(); - assertFalse(client.isComplete()); - } - - @Test - public void testNonMatching() { - SparkSaslClient client = new SparkSaslClient("my-secret", secretKeyHolder, false); - SparkSaslServer server = new SparkSaslServer("your-secret", secretKeyHolder, false); - - assertFalse(client.isComplete()); - assertFalse(server.isComplete()); - - byte[] clientMessage = client.firstToken(); - - try { - while (!client.isComplete()) { - clientMessage = client.response(server.response(clientMessage)); - } - fail("Should not have completed"); - } catch (Exception e) { - assertTrue(e.getMessage().contains("Mismatched response")); - assertFalse(client.isComplete()); - assertFalse(server.isComplete()); - } - } - - @Test - public void testSaslAuthentication() throws Throwable { - testBasicSasl(false); - } - - @Test - public void testSaslEncryption() throws Throwable { - testBasicSasl(true); - } - - private void testBasicSasl(boolean encrypt) throws Throwable { - RpcHandler rpcHandler = mock(RpcHandler.class); - doAnswer(new Answer() { - @Override - public Void answer(InvocationOnMock invocation) { - ByteBuffer message = (ByteBuffer) invocation.getArguments()[1]; - RpcResponseCallback cb = (RpcResponseCallback) invocation.getArguments()[2]; - assertEquals("Ping", JavaUtils.bytesToString(message)); - cb.onSuccess(JavaUtils.stringToBytes("Pong")); - return null; - } - }) - .when(rpcHandler) - .receive(any(TransportClient.class), any(ByteBuffer.class), any(RpcResponseCallback.class)); - - SaslTestCtx ctx = new SaslTestCtx(rpcHandler, encrypt, false); - try { - ByteBuffer response = ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), - TimeUnit.SECONDS.toMillis(10)); - assertEquals("Pong", JavaUtils.bytesToString(response)); - } finally { - ctx.close(); - // There should be 2 terminated events; one for the client, one for the server. - Throwable error = null; - long deadline = System.nanoTime() + TimeUnit.NANOSECONDS.convert(10, TimeUnit.SECONDS); - while (deadline > System.nanoTime()) { - try { - verify(rpcHandler, times(2)).channelInactive(any(TransportClient.class)); - error = null; - break; - } catch (Throwable t) { - error = t; - TimeUnit.MILLISECONDS.sleep(10); - } - } - if (error != null) { - throw error; - } - } - } - - @Test - public void testEncryptedMessage() throws Exception { - SaslEncryptionBackend backend = mock(SaslEncryptionBackend.class); - byte[] data = new byte[1024]; - new Random().nextBytes(data); - when(backend.wrap(any(byte[].class), anyInt(), anyInt())).thenReturn(data); - - ByteBuf msg = Unpooled.buffer(); - try { - msg.writeBytes(data); - - // Create a channel with a really small buffer compared to the data. This means that on each - // call, the outbound data will not be fully written, so the write() method should return a - // dummy count to keep the channel alive when possible. - ByteArrayWritableChannel channel = new ByteArrayWritableChannel(32); - - SaslEncryption.EncryptedMessage emsg = - new SaslEncryption.EncryptedMessage(backend, msg, 1024); - long count = emsg.transferTo(channel, emsg.transfered()); - assertTrue(count < data.length); - assertTrue(count > 0); - - // Here, the output buffer is full so nothing should be transferred. - assertEquals(0, emsg.transferTo(channel, emsg.transfered())); - - // Now there's room in the buffer, but not enough to transfer all the remaining data, - // so the dummy count should be returned. - channel.reset(); - assertEquals(1, emsg.transferTo(channel, emsg.transfered())); - - // Eventually, the whole message should be transferred. - for (int i = 0; i < data.length / 32 - 2; i++) { - channel.reset(); - assertEquals(1, emsg.transferTo(channel, emsg.transfered())); - } - - channel.reset(); - count = emsg.transferTo(channel, emsg.transfered()); - assertTrue("Unexpected count: " + count, count > 1 && count < data.length); - assertEquals(data.length, emsg.transfered()); - } finally { - msg.release(); - } - } - - @Test - public void testEncryptedMessageChunking() throws Exception { - File file = File.createTempFile("sasltest", ".txt"); - try { - TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); - - byte[] data = new byte[8 * 1024]; - new Random().nextBytes(data); - Files.write(data, file); - - SaslEncryptionBackend backend = mock(SaslEncryptionBackend.class); - // It doesn't really matter what we return here, as long as it's not null. - when(backend.wrap(any(byte[].class), anyInt(), anyInt())).thenReturn(data); - - FileSegmentManagedBuffer msg = new FileSegmentManagedBuffer(conf, file, 0, file.length()); - SaslEncryption.EncryptedMessage emsg = - new SaslEncryption.EncryptedMessage(backend, msg.convertToNetty(), data.length / 8); - - ByteArrayWritableChannel channel = new ByteArrayWritableChannel(data.length); - while (emsg.transfered() < emsg.count()) { - channel.reset(); - emsg.transferTo(channel, emsg.transfered()); - } - - verify(backend, times(8)).wrap(any(byte[].class), anyInt(), anyInt()); - } finally { - file.delete(); - } - } - - @Test - public void testFileRegionEncryption() throws Exception { - final String blockSizeConf = "spark.network.sasl.maxEncryptedBlockSize"; - System.setProperty(blockSizeConf, "1k"); - - final AtomicReference response = new AtomicReference<>(); - final File file = File.createTempFile("sasltest", ".txt"); - SaslTestCtx ctx = null; - try { - final TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); - StreamManager sm = mock(StreamManager.class); - when(sm.getChunk(anyLong(), anyInt())).thenAnswer(new Answer() { - @Override - public ManagedBuffer answer(InvocationOnMock invocation) { - return new FileSegmentManagedBuffer(conf, file, 0, file.length()); - } - }); - - RpcHandler rpcHandler = mock(RpcHandler.class); - when(rpcHandler.getStreamManager()).thenReturn(sm); - - byte[] data = new byte[8 * 1024]; - new Random().nextBytes(data); - Files.write(data, file); - - ctx = new SaslTestCtx(rpcHandler, true, false); - - final Object lock = new Object(); - - ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); - doAnswer(new Answer() { - @Override - public Void answer(InvocationOnMock invocation) { - response.set((ManagedBuffer) invocation.getArguments()[1]); - response.get().retain(); - synchronized (lock) { - lock.notifyAll(); - } - return null; - } - }).when(callback).onSuccess(anyInt(), any(ManagedBuffer.class)); - - synchronized (lock) { - ctx.client.fetchChunk(0, 0, callback); - lock.wait(10 * 1000); - } - - verify(callback, times(1)).onSuccess(anyInt(), any(ManagedBuffer.class)); - verify(callback, never()).onFailure(anyInt(), any(Throwable.class)); - - byte[] received = ByteStreams.toByteArray(response.get().createInputStream()); - assertTrue(Arrays.equals(data, received)); - } finally { - file.delete(); - if (ctx != null) { - ctx.close(); - } - if (response.get() != null) { - response.get().release(); - } - System.clearProperty(blockSizeConf); - } - } - - @Test - public void testServerAlwaysEncrypt() throws Exception { - final String alwaysEncryptConfName = "spark.network.sasl.serverAlwaysEncrypt"; - System.setProperty(alwaysEncryptConfName, "true"); - - SaslTestCtx ctx = null; - try { - ctx = new SaslTestCtx(mock(RpcHandler.class), false, false); - fail("Should have failed to connect without encryption."); - } catch (Exception e) { - assertTrue(e.getCause() instanceof SaslException); - } finally { - if (ctx != null) { - ctx.close(); - } - System.clearProperty(alwaysEncryptConfName); - } - } - - @Test - public void testDataEncryptionIsActuallyEnabled() throws Exception { - // This test sets up an encrypted connection but then, using a client bootstrap, removes - // the encryption handler from the client side. This should cause the server to not be - // able to understand RPCs sent to it and thus close the connection. - SaslTestCtx ctx = null; - try { - ctx = new SaslTestCtx(mock(RpcHandler.class), true, true); - ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), - TimeUnit.SECONDS.toMillis(10)); - fail("Should have failed to send RPC to server."); - } catch (Exception e) { - assertFalse(e.getCause() instanceof TimeoutException); - } finally { - if (ctx != null) { - ctx.close(); - } - } - } - - @Test - public void testRpcHandlerDelegate() throws Exception { - // Tests all delegates exception for receive(), which is more complicated and already handled - // by all other tests. - RpcHandler handler = mock(RpcHandler.class); - RpcHandler saslHandler = new SaslRpcHandler(null, null, handler, null); - - saslHandler.getStreamManager(); - verify(handler).getStreamManager(); - - saslHandler.channelInactive(null); - verify(handler).channelInactive(any(TransportClient.class)); - - saslHandler.exceptionCaught(null, null); - verify(handler).exceptionCaught(any(Throwable.class), any(TransportClient.class)); - } - - @Test - public void testDelegates() throws Exception { - Method[] rpcHandlerMethods = RpcHandler.class.getDeclaredMethods(); - for (Method m : rpcHandlerMethods) { - SaslRpcHandler.class.getDeclaredMethod(m.getName(), m.getParameterTypes()); - } - } - - private static class SaslTestCtx { - - final TransportClient client; - final TransportServer server; - - private final boolean encrypt; - private final boolean disableClientEncryption; - private final EncryptionCheckerBootstrap checker; - - SaslTestCtx( - RpcHandler rpcHandler, - boolean encrypt, - boolean disableClientEncryption) - throws Exception { - - TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); - - SecretKeyHolder keyHolder = mock(SecretKeyHolder.class); - when(keyHolder.getSaslUser(anyString())).thenReturn("user"); - when(keyHolder.getSecretKey(anyString())).thenReturn("secret"); - - TransportContext ctx = new TransportContext(conf, rpcHandler); - - this.checker = new EncryptionCheckerBootstrap(); - this.server = ctx.createServer(Arrays.asList(new SaslServerBootstrap(conf, keyHolder), - checker)); - - try { - List clientBootstraps = Lists.newArrayList(); - clientBootstraps.add(new SaslClientBootstrap(conf, "user", keyHolder, encrypt)); - if (disableClientEncryption) { - clientBootstraps.add(new EncryptionDisablerBootstrap()); - } - - this.client = ctx.createClientFactory(clientBootstraps) - .createClient(TestUtils.getLocalHost(), server.getPort()); - } catch (Exception e) { - close(); - throw e; - } - - this.encrypt = encrypt; - this.disableClientEncryption = disableClientEncryption; - } - - void close() { - if (!disableClientEncryption) { - assertEquals(encrypt, checker.foundEncryptionHandler); - } - if (client != null) { - client.close(); - } - if (server != null) { - server.close(); - } - } - - } - - private static class EncryptionCheckerBootstrap extends ChannelOutboundHandlerAdapter - implements TransportServerBootstrap { - - boolean foundEncryptionHandler; - - @Override - public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) - throws Exception { - if (!foundEncryptionHandler) { - foundEncryptionHandler = - ctx.channel().pipeline().get(SaslEncryption.ENCRYPTION_HANDLER_NAME) != null; - } - ctx.write(msg, promise); - } - - @Override - public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { - super.handlerRemoved(ctx); - } - - @Override - public RpcHandler doBootstrap(Channel channel, RpcHandler rpcHandler) { - channel.pipeline().addFirst("encryptionChecker", this); - return rpcHandler; - } - - } - - private static class EncryptionDisablerBootstrap implements TransportClientBootstrap { - - @Override - public void doBootstrap(TransportClient client, Channel channel) { - channel.pipeline().remove(SaslEncryption.ENCRYPTION_HANDLER_NAME); - } - - } - -} diff --git a/network/common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java b/network/common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java deleted file mode 100644 index c647525d8f..0000000000 --- a/network/common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.server; - -import java.util.ArrayList; -import java.util.List; - -import io.netty.channel.Channel; -import org.junit.Test; -import org.mockito.Mockito; - -import org.apache.spark.network.TestManagedBuffer; -import org.apache.spark.network.buffer.ManagedBuffer; - -public class OneForOneStreamManagerSuite { - - @Test - public void managedBuffersAreFeedWhenConnectionIsClosed() throws Exception { - OneForOneStreamManager manager = new OneForOneStreamManager(); - List buffers = new ArrayList<>(); - TestManagedBuffer buffer1 = Mockito.spy(new TestManagedBuffer(10)); - TestManagedBuffer buffer2 = Mockito.spy(new TestManagedBuffer(20)); - buffers.add(buffer1); - buffers.add(buffer2); - long streamId = manager.registerStream("appId", buffers.iterator()); - - Channel dummyChannel = Mockito.mock(Channel.class, Mockito.RETURNS_SMART_NULLS); - manager.registerChannel(dummyChannel, streamId); - - manager.connectionTerminated(dummyChannel); - - Mockito.verify(buffer1, Mockito.times(1)).release(); - Mockito.verify(buffer2, Mockito.times(1)).release(); - } -} diff --git a/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java b/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java deleted file mode 100644 index d4de4a941d..0000000000 --- a/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java +++ /dev/null @@ -1,258 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.util; - -import java.nio.ByteBuffer; -import java.util.ArrayList; -import java.util.List; -import java.util.Random; -import java.util.concurrent.atomic.AtomicInteger; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; -import io.netty.channel.ChannelHandlerContext; -import org.junit.AfterClass; -import org.junit.Test; -import org.mockito.invocation.InvocationOnMock; -import org.mockito.stubbing.Answer; -import static org.junit.Assert.*; -import static org.mockito.Mockito.*; - -public class TransportFrameDecoderSuite { - - private static Random RND = new Random(); - - @AfterClass - public static void cleanup() { - RND = null; - } - - @Test - public void testFrameDecoding() throws Exception { - TransportFrameDecoder decoder = new TransportFrameDecoder(); - ChannelHandlerContext ctx = mockChannelHandlerContext(); - ByteBuf data = createAndFeedFrames(100, decoder, ctx); - verifyAndCloseDecoder(decoder, ctx, data); - } - - @Test - public void testInterception() throws Exception { - final int interceptedReads = 3; - TransportFrameDecoder decoder = new TransportFrameDecoder(); - TransportFrameDecoder.Interceptor interceptor = spy(new MockInterceptor(interceptedReads)); - ChannelHandlerContext ctx = mockChannelHandlerContext(); - - byte[] data = new byte[8]; - ByteBuf len = Unpooled.copyLong(8 + data.length); - ByteBuf dataBuf = Unpooled.wrappedBuffer(data); - - try { - decoder.setInterceptor(interceptor); - for (int i = 0; i < interceptedReads; i++) { - decoder.channelRead(ctx, dataBuf); - assertEquals(0, dataBuf.refCnt()); - dataBuf = Unpooled.wrappedBuffer(data); - } - decoder.channelRead(ctx, len); - decoder.channelRead(ctx, dataBuf); - verify(interceptor, times(interceptedReads)).handle(any(ByteBuf.class)); - verify(ctx).fireChannelRead(any(ByteBuffer.class)); - assertEquals(0, len.refCnt()); - assertEquals(0, dataBuf.refCnt()); - } finally { - release(len); - release(dataBuf); - } - } - - @Test - public void testRetainedFrames() throws Exception { - TransportFrameDecoder decoder = new TransportFrameDecoder(); - - final AtomicInteger count = new AtomicInteger(); - final List retained = new ArrayList<>(); - - ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); - when(ctx.fireChannelRead(any())).thenAnswer(new Answer() { - @Override - public Void answer(InvocationOnMock in) { - // Retain a few frames but not others. - ByteBuf buf = (ByteBuf) in.getArguments()[0]; - if (count.incrementAndGet() % 2 == 0) { - retained.add(buf); - } else { - buf.release(); - } - return null; - } - }); - - ByteBuf data = createAndFeedFrames(100, decoder, ctx); - try { - // Verify all retained buffers are readable. - for (ByteBuf b : retained) { - byte[] tmp = new byte[b.readableBytes()]; - b.readBytes(tmp); - b.release(); - } - verifyAndCloseDecoder(decoder, ctx, data); - } finally { - for (ByteBuf b : retained) { - release(b); - } - } - } - - @Test - public void testSplitLengthField() throws Exception { - byte[] frame = new byte[1024 * (RND.nextInt(31) + 1)]; - ByteBuf buf = Unpooled.buffer(frame.length + 8); - buf.writeLong(frame.length + 8); - buf.writeBytes(frame); - - TransportFrameDecoder decoder = new TransportFrameDecoder(); - ChannelHandlerContext ctx = mockChannelHandlerContext(); - try { - decoder.channelRead(ctx, buf.readSlice(RND.nextInt(7)).retain()); - verify(ctx, never()).fireChannelRead(any(ByteBuf.class)); - decoder.channelRead(ctx, buf); - verify(ctx).fireChannelRead(any(ByteBuf.class)); - assertEquals(0, buf.refCnt()); - } finally { - decoder.channelInactive(ctx); - release(buf); - } - } - - @Test(expected = IllegalArgumentException.class) - public void testNegativeFrameSize() throws Exception { - testInvalidFrame(-1); - } - - @Test(expected = IllegalArgumentException.class) - public void testEmptyFrame() throws Exception { - // 8 because frame size includes the frame length. - testInvalidFrame(8); - } - - @Test(expected = IllegalArgumentException.class) - public void testLargeFrame() throws Exception { - // Frame length includes the frame size field, so need to add a few more bytes. - testInvalidFrame(Integer.MAX_VALUE + 9); - } - - /** - * Creates a number of randomly sized frames and feed them to the given decoder, verifying - * that the frames were read. - */ - private ByteBuf createAndFeedFrames( - int frameCount, - TransportFrameDecoder decoder, - ChannelHandlerContext ctx) throws Exception { - ByteBuf data = Unpooled.buffer(); - for (int i = 0; i < frameCount; i++) { - byte[] frame = new byte[1024 * (RND.nextInt(31) + 1)]; - data.writeLong(frame.length + 8); - data.writeBytes(frame); - } - - try { - while (data.isReadable()) { - int size = RND.nextInt(4 * 1024) + 256; - decoder.channelRead(ctx, data.readSlice(Math.min(data.readableBytes(), size)).retain()); - } - - verify(ctx, times(frameCount)).fireChannelRead(any(ByteBuf.class)); - } catch (Exception e) { - release(data); - throw e; - } - return data; - } - - private void verifyAndCloseDecoder( - TransportFrameDecoder decoder, - ChannelHandlerContext ctx, - ByteBuf data) throws Exception { - try { - decoder.channelInactive(ctx); - assertTrue("There shouldn't be dangling references to the data.", data.release()); - } finally { - release(data); - } - } - - private void testInvalidFrame(long size) throws Exception { - TransportFrameDecoder decoder = new TransportFrameDecoder(); - ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); - ByteBuf frame = Unpooled.copyLong(size); - try { - decoder.channelRead(ctx, frame); - } finally { - release(frame); - } - } - - private ChannelHandlerContext mockChannelHandlerContext() { - ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); - when(ctx.fireChannelRead(any())).thenAnswer(new Answer() { - @Override - public Void answer(InvocationOnMock in) { - ByteBuf buf = (ByteBuf) in.getArguments()[0]; - buf.release(); - return null; - } - }); - return ctx; - } - - private void release(ByteBuf buf) { - if (buf.refCnt() > 0) { - buf.release(buf.refCnt()); - } - } - - private static class MockInterceptor implements TransportFrameDecoder.Interceptor { - - private int remainingReads; - - MockInterceptor(int readCount) { - this.remainingReads = readCount; - } - - @Override - public boolean handle(ByteBuf data) throws Exception { - data.readerIndex(data.readerIndex() + data.readableBytes()); - assertFalse(data.isReadable()); - remainingReads -= 1; - return remainingReads != 0; - } - - @Override - public void exceptionCaught(Throwable cause) throws Exception { - - } - - @Override - public void channelInactive() throws Exception { - - } - - } - -} diff --git a/network/common/src/test/resources/log4j.properties b/network/common/src/test/resources/log4j.properties deleted file mode 100644 index e8da774f7c..0000000000 --- a/network/common/src/test/resources/log4j.properties +++ /dev/null @@ -1,27 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# Set everything to be logged to the file target/unit-tests.log -log4j.rootCategory=DEBUG, file -log4j.appender.file=org.apache.log4j.FileAppender -log4j.appender.file.append=true -log4j.appender.file.file=target/unit-tests.log -log4j.appender.file.layout=org.apache.log4j.PatternLayout -log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n - -# Silence verbose logs from 3rd-party libraries. -log4j.logger.io.netty=INFO diff --git a/network/shuffle/pom.xml b/network/shuffle/pom.xml deleted file mode 100644 index 810ec10ca0..0000000000 --- a/network/shuffle/pom.xml +++ /dev/null @@ -1,101 +0,0 @@ - - - - - 4.0.0 - - org.apache.spark - spark-parent_2.11 - 2.0.0-SNAPSHOT - ../../pom.xml - - - org.apache.spark - spark-network-shuffle_2.11 - jar - Spark Project Shuffle Streaming Service - http://spark.apache.org/ - - network-shuffle - - - - - - org.apache.spark - spark-network-common_${scala.binary.version} - ${project.version} - - - - org.fusesource.leveldbjni - leveldbjni-all - 1.8 - - - - com.fasterxml.jackson.core - jackson-databind - - - - com.fasterxml.jackson.core - jackson-annotations - - - - - org.slf4j - slf4j-api - provided - - - com.google.guava - guava - - - - - org.apache.spark - spark-network-common_${scala.binary.version} - ${project.version} - test-jar - test - - - org.apache.spark - spark-test-tags_${scala.binary.version} - - - log4j - log4j - test - - - org.mockito - mockito-core - test - - - - - target/scala-${scala.binary.version}/classes - target/scala-${scala.binary.version}/test-classes - - diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java deleted file mode 100644 index 351c7930a9..0000000000 --- a/network/shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java +++ /dev/null @@ -1,97 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.sasl; - -import java.lang.Override; -import java.nio.ByteBuffer; -import java.util.concurrent.ConcurrentHashMap; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import org.apache.spark.network.sasl.SecretKeyHolder; -import org.apache.spark.network.util.JavaUtils; - -/** - * A class that manages shuffle secret used by the external shuffle service. - */ -public class ShuffleSecretManager implements SecretKeyHolder { - private final Logger logger = LoggerFactory.getLogger(ShuffleSecretManager.class); - private final ConcurrentHashMap shuffleSecretMap; - - // Spark user used for authenticating SASL connections - // Note that this must match the value in org.apache.spark.SecurityManager - private static final String SPARK_SASL_USER = "sparkSaslUser"; - - public ShuffleSecretManager() { - shuffleSecretMap = new ConcurrentHashMap(); - } - - /** - * Register an application with its secret. - * Executors need to first authenticate themselves with the same secret before - * fetching shuffle files written by other executors in this application. - */ - public void registerApp(String appId, String shuffleSecret) { - if (!shuffleSecretMap.contains(appId)) { - shuffleSecretMap.put(appId, shuffleSecret); - logger.info("Registered shuffle secret for application {}", appId); - } else { - logger.debug("Application {} already registered", appId); - } - } - - /** - * Register an application with its secret specified as a byte buffer. - */ - public void registerApp(String appId, ByteBuffer shuffleSecret) { - registerApp(appId, JavaUtils.bytesToString(shuffleSecret)); - } - - /** - * Unregister an application along with its secret. - * This is called when the application terminates. - */ - public void unregisterApp(String appId) { - if (shuffleSecretMap.contains(appId)) { - shuffleSecretMap.remove(appId); - logger.info("Unregistered shuffle secret for application {}", appId); - } else { - logger.warn("Attempted to unregister application {} when it is not registered", appId); - } - } - - /** - * Return the Spark user for authenticating SASL connections. - */ - @Override - public String getSaslUser(String appId) { - return SPARK_SASL_USER; - } - - /** - * Return the secret key registered with the given application. - * This key is used to authenticate the executors before they can fetch shuffle files - * written by this application from the external shuffle service. If the specified - * application is not registered, return null. - */ - @Override - public String getSecretKey(String appId) { - return shuffleSecretMap.get(appId); - } -} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java deleted file mode 100644 index 138fd5389c..0000000000 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.shuffle; - -import java.util.EventListener; - -import org.apache.spark.network.buffer.ManagedBuffer; - -public interface BlockFetchingListener extends EventListener { - /** - * Called once per successfully fetched block. After this call returns, data will be released - * automatically. If the data will be passed to another thread, the receiver should retain() - * and release() the buffer on their own, or copy the data to a new buffer. - */ - void onBlockFetchSuccess(String blockId, ManagedBuffer data); - - /** - * Called at least once per block upon failures. - */ - void onBlockFetchFailure(String blockId, Throwable exception); -} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java deleted file mode 100644 index f22187a01d..0000000000 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java +++ /dev/null @@ -1,140 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.shuffle; - -import java.io.File; -import java.io.IOException; -import java.nio.ByteBuffer; -import java.util.List; - -import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.Lists; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import org.apache.spark.network.buffer.ManagedBuffer; -import org.apache.spark.network.client.RpcResponseCallback; -import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.server.OneForOneStreamManager; -import org.apache.spark.network.server.RpcHandler; -import org.apache.spark.network.server.StreamManager; -import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver.AppExecId; -import org.apache.spark.network.shuffle.protocol.*; -import org.apache.spark.network.util.TransportConf; - - -/** - * RPC Handler for a server which can serve shuffle blocks from outside of an Executor process. - * - * Handles registering executors and opening shuffle blocks from them. Shuffle blocks are registered - * with the "one-for-one" strategy, meaning each Transport-layer Chunk is equivalent to one Spark- - * level shuffle block. - */ -public class ExternalShuffleBlockHandler extends RpcHandler { - private final Logger logger = LoggerFactory.getLogger(ExternalShuffleBlockHandler.class); - - @VisibleForTesting - final ExternalShuffleBlockResolver blockManager; - private final OneForOneStreamManager streamManager; - - public ExternalShuffleBlockHandler(TransportConf conf, File registeredExecutorFile) throws IOException { - this(new OneForOneStreamManager(), - new ExternalShuffleBlockResolver(conf, registeredExecutorFile)); - } - - /** Enables mocking out the StreamManager and BlockManager. */ - @VisibleForTesting - public ExternalShuffleBlockHandler( - OneForOneStreamManager streamManager, - ExternalShuffleBlockResolver blockManager) { - this.streamManager = streamManager; - this.blockManager = blockManager; - } - - @Override - public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) { - BlockTransferMessage msgObj = BlockTransferMessage.Decoder.fromByteBuffer(message); - handleMessage(msgObj, client, callback); - } - - protected void handleMessage( - BlockTransferMessage msgObj, - TransportClient client, - RpcResponseCallback callback) { - if (msgObj instanceof OpenBlocks) { - OpenBlocks msg = (OpenBlocks) msgObj; - checkAuth(client, msg.appId); - - List blocks = Lists.newArrayList(); - for (String blockId : msg.blockIds) { - blocks.add(blockManager.getBlockData(msg.appId, msg.execId, blockId)); - } - long streamId = streamManager.registerStream(client.getClientId(), blocks.iterator()); - logger.trace("Registered streamId {} with {} buffers", streamId, msg.blockIds.length); - callback.onSuccess(new StreamHandle(streamId, msg.blockIds.length).toByteBuffer()); - - } else if (msgObj instanceof RegisterExecutor) { - RegisterExecutor msg = (RegisterExecutor) msgObj; - checkAuth(client, msg.appId); - blockManager.registerExecutor(msg.appId, msg.execId, msg.executorInfo); - callback.onSuccess(ByteBuffer.wrap(new byte[0])); - - } else { - throw new UnsupportedOperationException("Unexpected message: " + msgObj); - } - } - - @Override - public StreamManager getStreamManager() { - return streamManager; - } - - /** - * Removes an application (once it has been terminated), and optionally will clean up any - * local directories associated with the executors of that application in a separate thread. - */ - public void applicationRemoved(String appId, boolean cleanupLocalDirs) { - blockManager.applicationRemoved(appId, cleanupLocalDirs); - } - - /** - * Register an (application, executor) with the given shuffle info. - * - * The "re-" is meant to highlight the intended use of this method -- when this service is - * restarted, this is used to restore the state of executors from before the restart. Normal - * registration will happen via a message handled in receive() - * - * @param appExecId - * @param executorInfo - */ - public void reregisterExecutor(AppExecId appExecId, ExecutorShuffleInfo executorInfo) { - blockManager.registerExecutor(appExecId.appId, appExecId.execId, executorInfo); - } - - public void close() { - blockManager.close(); - } - - private void checkAuth(TransportClient client, String appId) { - if (client.getClientId() != null && !client.getClientId().equals(appId)) { - throw new SecurityException(String.format( - "Client for %s not authorized for application %s.", client.getClientId(), appId)); - } - } - -} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java deleted file mode 100644 index fe933ed650..0000000000 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ /dev/null @@ -1,449 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.shuffle; - -import java.io.*; -import java.util.*; -import java.util.concurrent.ConcurrentMap; -import java.util.concurrent.Executor; -import java.util.concurrent.Executors; - -import com.fasterxml.jackson.annotation.JsonCreator; -import com.fasterxml.jackson.annotation.JsonProperty; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Charsets; -import com.google.common.base.Objects; -import com.google.common.collect.Maps; -import org.fusesource.leveldbjni.JniDBFactory; -import org.fusesource.leveldbjni.internal.NativeDB; -import org.iq80.leveldb.DB; -import org.iq80.leveldb.DBIterator; -import org.iq80.leveldb.Options; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import org.apache.spark.network.buffer.FileSegmentManagedBuffer; -import org.apache.spark.network.buffer.ManagedBuffer; -import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; -import org.apache.spark.network.util.JavaUtils; -import org.apache.spark.network.util.NettyUtils; -import org.apache.spark.network.util.TransportConf; - -/** - * Manages converting shuffle BlockIds into physical segments of local files, from a process outside - * of Executors. Each Executor must register its own configuration about where it stores its files - * (local dirs) and how (shuffle manager). The logic for retrieval of individual files is replicated - * from Spark's FileShuffleBlockResolver and IndexShuffleBlockResolver. - */ -public class ExternalShuffleBlockResolver { - private static final Logger logger = LoggerFactory.getLogger(ExternalShuffleBlockResolver.class); - - private static final ObjectMapper mapper = new ObjectMapper(); - /** - * This a common prefix to the key for each app registration we stick in leveldb, so they - * are easy to find, since leveldb lets you search based on prefix. - */ - private static final String APP_KEY_PREFIX = "AppExecShuffleInfo"; - private static final StoreVersion CURRENT_VERSION = new StoreVersion(1, 0); - - // Map containing all registered executors' metadata. - @VisibleForTesting - final ConcurrentMap executors; - - // Single-threaded Java executor used to perform expensive recursive directory deletion. - private final Executor directoryCleaner; - - private final TransportConf conf; - - @VisibleForTesting - final File registeredExecutorFile; - @VisibleForTesting - final DB db; - - public ExternalShuffleBlockResolver(TransportConf conf, File registeredExecutorFile) - throws IOException { - this(conf, registeredExecutorFile, Executors.newSingleThreadExecutor( - // Add `spark` prefix because it will run in NM in Yarn mode. - NettyUtils.createThreadFactory("spark-shuffle-directory-cleaner"))); - } - - // Allows tests to have more control over when directories are cleaned up. - @VisibleForTesting - ExternalShuffleBlockResolver( - TransportConf conf, - File registeredExecutorFile, - Executor directoryCleaner) throws IOException { - this.conf = conf; - this.registeredExecutorFile = registeredExecutorFile; - if (registeredExecutorFile != null) { - Options options = new Options(); - options.createIfMissing(false); - options.logger(new LevelDBLogger()); - DB tmpDb; - try { - tmpDb = JniDBFactory.factory.open(registeredExecutorFile, options); - } catch (NativeDB.DBException e) { - if (e.isNotFound() || e.getMessage().contains(" does not exist ")) { - logger.info("Creating state database at " + registeredExecutorFile); - options.createIfMissing(true); - try { - tmpDb = JniDBFactory.factory.open(registeredExecutorFile, options); - } catch (NativeDB.DBException dbExc) { - throw new IOException("Unable to create state store", dbExc); - } - } else { - // the leveldb file seems to be corrupt somehow. Lets just blow it away and create a new - // one, so we can keep processing new apps - logger.error("error opening leveldb file {}. Creating new file, will not be able to " + - "recover state for existing applications", registeredExecutorFile, e); - if (registeredExecutorFile.isDirectory()) { - for (File f : registeredExecutorFile.listFiles()) { - if (!f.delete()) { - logger.warn("error deleting {}", f.getPath()); - } - } - } - if (!registeredExecutorFile.delete()) { - logger.warn("error deleting {}", registeredExecutorFile.getPath()); - } - options.createIfMissing(true); - try { - tmpDb = JniDBFactory.factory.open(registeredExecutorFile, options); - } catch (NativeDB.DBException dbExc) { - throw new IOException("Unable to create state store", dbExc); - } - - } - } - // if there is a version mismatch, we throw an exception, which means the service is unusable - checkVersion(tmpDb); - executors = reloadRegisteredExecutors(tmpDb); - db = tmpDb; - } else { - db = null; - executors = Maps.newConcurrentMap(); - } - this.directoryCleaner = directoryCleaner; - } - - /** Registers a new Executor with all the configuration we need to find its shuffle files. */ - public void registerExecutor( - String appId, - String execId, - ExecutorShuffleInfo executorInfo) { - AppExecId fullId = new AppExecId(appId, execId); - logger.info("Registered executor {} with {}", fullId, executorInfo); - try { - if (db != null) { - byte[] key = dbAppExecKey(fullId); - byte[] value = mapper.writeValueAsString(executorInfo).getBytes(Charsets.UTF_8); - db.put(key, value); - } - } catch (Exception e) { - logger.error("Error saving registered executors", e); - } - executors.put(fullId, executorInfo); - } - - /** - * Obtains a FileSegmentManagedBuffer from a shuffle block id. We expect the blockId has the - * format "shuffle_ShuffleId_MapId_ReduceId" (from ShuffleBlockId), and additionally make - * assumptions about how the hash and sort based shuffles store their data. - */ - public ManagedBuffer getBlockData(String appId, String execId, String blockId) { - String[] blockIdParts = blockId.split("_"); - if (blockIdParts.length < 4) { - throw new IllegalArgumentException("Unexpected block id format: " + blockId); - } else if (!blockIdParts[0].equals("shuffle")) { - throw new IllegalArgumentException("Expected shuffle block id, got: " + blockId); - } - int shuffleId = Integer.parseInt(blockIdParts[1]); - int mapId = Integer.parseInt(blockIdParts[2]); - int reduceId = Integer.parseInt(blockIdParts[3]); - - ExecutorShuffleInfo executor = executors.get(new AppExecId(appId, execId)); - if (executor == null) { - throw new RuntimeException( - String.format("Executor is not registered (appId=%s, execId=%s)", appId, execId)); - } - - if ("sort".equals(executor.shuffleManager) || "tungsten-sort".equals(executor.shuffleManager)) { - return getSortBasedShuffleBlockData(executor, shuffleId, mapId, reduceId); - } else if ("hash".equals(executor.shuffleManager)) { - return getHashBasedShuffleBlockData(executor, blockId); - } else { - throw new UnsupportedOperationException( - "Unsupported shuffle manager: " + executor.shuffleManager); - } - } - - /** - * Removes our metadata of all executors registered for the given application, and optionally - * also deletes the local directories associated with the executors of that application in a - * separate thread. - * - * It is not valid to call registerExecutor() for an executor with this appId after invoking - * this method. - */ - public void applicationRemoved(String appId, boolean cleanupLocalDirs) { - logger.info("Application {} removed, cleanupLocalDirs = {}", appId, cleanupLocalDirs); - Iterator> it = executors.entrySet().iterator(); - while (it.hasNext()) { - Map.Entry entry = it.next(); - AppExecId fullId = entry.getKey(); - final ExecutorShuffleInfo executor = entry.getValue(); - - // Only touch executors associated with the appId that was removed. - if (appId.equals(fullId.appId)) { - it.remove(); - if (db != null) { - try { - db.delete(dbAppExecKey(fullId)); - } catch (IOException e) { - logger.error("Error deleting {} from executor state db", appId, e); - } - } - - if (cleanupLocalDirs) { - logger.info("Cleaning up executor {}'s {} local dirs", fullId, executor.localDirs.length); - - // Execute the actual deletion in a different thread, as it may take some time. - directoryCleaner.execute(new Runnable() { - @Override - public void run() { - deleteExecutorDirs(executor.localDirs); - } - }); - } - } - } - } - - /** - * Synchronously deletes each directory one at a time. - * Should be executed in its own thread, as this may take a long time. - */ - private void deleteExecutorDirs(String[] dirs) { - for (String localDir : dirs) { - try { - JavaUtils.deleteRecursively(new File(localDir)); - logger.debug("Successfully cleaned up directory: " + localDir); - } catch (Exception e) { - logger.error("Failed to delete directory: " + localDir, e); - } - } - } - - /** - * Hash-based shuffle data is simply stored as one file per block. - * This logic is from FileShuffleBlockResolver. - */ - private ManagedBuffer getHashBasedShuffleBlockData(ExecutorShuffleInfo executor, String blockId) { - File shuffleFile = getFile(executor.localDirs, executor.subDirsPerLocalDir, blockId); - return new FileSegmentManagedBuffer(conf, shuffleFile, 0, shuffleFile.length()); - } - - /** - * Sort-based shuffle data uses an index called "shuffle_ShuffleId_MapId_0.index" into a data file - * called "shuffle_ShuffleId_MapId_0.data". This logic is from IndexShuffleBlockResolver, - * and the block id format is from ShuffleDataBlockId and ShuffleIndexBlockId. - */ - private ManagedBuffer getSortBasedShuffleBlockData( - ExecutorShuffleInfo executor, int shuffleId, int mapId, int reduceId) { - File indexFile = getFile(executor.localDirs, executor.subDirsPerLocalDir, - "shuffle_" + shuffleId + "_" + mapId + "_0.index"); - - DataInputStream in = null; - try { - in = new DataInputStream(new FileInputStream(indexFile)); - in.skipBytes(reduceId * 8); - long offset = in.readLong(); - long nextOffset = in.readLong(); - return new FileSegmentManagedBuffer( - conf, - getFile(executor.localDirs, executor.subDirsPerLocalDir, - "shuffle_" + shuffleId + "_" + mapId + "_0.data"), - offset, - nextOffset - offset); - } catch (IOException e) { - throw new RuntimeException("Failed to open file: " + indexFile, e); - } finally { - if (in != null) { - JavaUtils.closeQuietly(in); - } - } - } - - /** - * Hashes a filename into the corresponding local directory, in a manner consistent with - * Spark's DiskBlockManager.getFile(). - */ - @VisibleForTesting - static File getFile(String[] localDirs, int subDirsPerLocalDir, String filename) { - int hash = JavaUtils.nonNegativeHash(filename); - String localDir = localDirs[hash % localDirs.length]; - int subDirId = (hash / localDirs.length) % subDirsPerLocalDir; - return new File(new File(localDir, String.format("%02x", subDirId)), filename); - } - - void close() { - if (db != null) { - try { - db.close(); - } catch (IOException e) { - logger.error("Exception closing leveldb with registered executors", e); - } - } - } - - /** Simply encodes an executor's full ID, which is appId + execId. */ - public static class AppExecId { - public final String appId; - public final String execId; - - @JsonCreator - public AppExecId(@JsonProperty("appId") String appId, @JsonProperty("execId") String execId) { - this.appId = appId; - this.execId = execId; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - - AppExecId appExecId = (AppExecId) o; - return Objects.equal(appId, appExecId.appId) && Objects.equal(execId, appExecId.execId); - } - - @Override - public int hashCode() { - return Objects.hashCode(appId, execId); - } - - @Override - public String toString() { - return Objects.toStringHelper(this) - .add("appId", appId) - .add("execId", execId) - .toString(); - } - } - - private static byte[] dbAppExecKey(AppExecId appExecId) throws IOException { - // we stick a common prefix on all the keys so we can find them in the DB - String appExecJson = mapper.writeValueAsString(appExecId); - String key = (APP_KEY_PREFIX + ";" + appExecJson); - return key.getBytes(Charsets.UTF_8); - } - - private static AppExecId parseDbAppExecKey(String s) throws IOException { - if (!s.startsWith(APP_KEY_PREFIX)) { - throw new IllegalArgumentException("expected a string starting with " + APP_KEY_PREFIX); - } - String json = s.substring(APP_KEY_PREFIX.length() + 1); - AppExecId parsed = mapper.readValue(json, AppExecId.class); - return parsed; - } - - @VisibleForTesting - static ConcurrentMap reloadRegisteredExecutors(DB db) - throws IOException { - ConcurrentMap registeredExecutors = Maps.newConcurrentMap(); - if (db != null) { - DBIterator itr = db.iterator(); - itr.seek(APP_KEY_PREFIX.getBytes(Charsets.UTF_8)); - while (itr.hasNext()) { - Map.Entry e = itr.next(); - String key = new String(e.getKey(), Charsets.UTF_8); - if (!key.startsWith(APP_KEY_PREFIX)) { - break; - } - AppExecId id = parseDbAppExecKey(key); - ExecutorShuffleInfo shuffleInfo = mapper.readValue(e.getValue(), ExecutorShuffleInfo.class); - registeredExecutors.put(id, shuffleInfo); - } - } - return registeredExecutors; - } - - private static class LevelDBLogger implements org.iq80.leveldb.Logger { - private static final Logger LOG = LoggerFactory.getLogger(LevelDBLogger.class); - - @Override - public void log(String message) { - LOG.info(message); - } - } - - /** - * Simple major.minor versioning scheme. Any incompatible changes should be across major - * versions. Minor version differences are allowed -- meaning we should be able to read - * dbs that are either earlier *or* later on the minor version. - */ - private static void checkVersion(DB db) throws IOException { - byte[] bytes = db.get(StoreVersion.KEY); - if (bytes == null) { - storeVersion(db); - } else { - StoreVersion version = mapper.readValue(bytes, StoreVersion.class); - if (version.major != CURRENT_VERSION.major) { - throw new IOException("cannot read state DB with version " + version + ", incompatible " + - "with current version " + CURRENT_VERSION); - } - storeVersion(db); - } - } - - private static void storeVersion(DB db) throws IOException { - db.put(StoreVersion.KEY, mapper.writeValueAsBytes(CURRENT_VERSION)); - } - - - public static class StoreVersion { - - static final byte[] KEY = "StoreVersion".getBytes(Charsets.UTF_8); - - public final int major; - public final int minor; - - @JsonCreator public StoreVersion(@JsonProperty("major") int major, @JsonProperty("minor") int minor) { - this.major = major; - this.minor = minor; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - - StoreVersion that = (StoreVersion) o; - - return major == that.major && minor == that.minor; - } - - @Override - public int hashCode() { - int result = major; - result = 31 * result + minor; - return result; - } - } - -} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java deleted file mode 100644 index 58ca87d9d3..0000000000 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ /dev/null @@ -1,154 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.shuffle; - -import java.io.IOException; -import java.nio.ByteBuffer; -import java.util.List; - -import com.google.common.base.Preconditions; -import com.google.common.collect.Lists; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import org.apache.spark.network.TransportContext; -import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.client.TransportClientBootstrap; -import org.apache.spark.network.client.TransportClientFactory; -import org.apache.spark.network.sasl.SaslClientBootstrap; -import org.apache.spark.network.sasl.SecretKeyHolder; -import org.apache.spark.network.server.NoOpRpcHandler; -import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; -import org.apache.spark.network.shuffle.protocol.RegisterExecutor; -import org.apache.spark.network.util.TransportConf; - -/** - * Client for reading shuffle blocks which points to an external (outside of executor) server. - * This is instead of reading shuffle blocks directly from other executors (via - * BlockTransferService), which has the downside of losing the shuffle data if we lose the - * executors. - */ -public class ExternalShuffleClient extends ShuffleClient { - private final Logger logger = LoggerFactory.getLogger(ExternalShuffleClient.class); - - private final TransportConf conf; - private final boolean saslEnabled; - private final boolean saslEncryptionEnabled; - private final SecretKeyHolder secretKeyHolder; - - protected TransportClientFactory clientFactory; - protected String appId; - - /** - * Creates an external shuffle client, with SASL optionally enabled. If SASL is not enabled, - * then secretKeyHolder may be null. - */ - public ExternalShuffleClient( - TransportConf conf, - SecretKeyHolder secretKeyHolder, - boolean saslEnabled, - boolean saslEncryptionEnabled) { - Preconditions.checkArgument( - !saslEncryptionEnabled || saslEnabled, - "SASL encryption can only be enabled if SASL is also enabled."); - this.conf = conf; - this.secretKeyHolder = secretKeyHolder; - this.saslEnabled = saslEnabled; - this.saslEncryptionEnabled = saslEncryptionEnabled; - } - - protected void checkInit() { - assert appId != null : "Called before init()"; - } - - @Override - public void init(String appId) { - this.appId = appId; - TransportContext context = new TransportContext(conf, new NoOpRpcHandler(), true); - List bootstraps = Lists.newArrayList(); - if (saslEnabled) { - bootstraps.add(new SaslClientBootstrap(conf, appId, secretKeyHolder, saslEncryptionEnabled)); - } - clientFactory = context.createClientFactory(bootstraps); - } - - @Override - public void fetchBlocks( - final String host, - final int port, - final String execId, - String[] blockIds, - BlockFetchingListener listener) { - checkInit(); - logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId); - try { - RetryingBlockFetcher.BlockFetchStarter blockFetchStarter = - new RetryingBlockFetcher.BlockFetchStarter() { - @Override - public void createAndStart(String[] blockIds, BlockFetchingListener listener) - throws IOException { - TransportClient client = clientFactory.createClient(host, port); - new OneForOneBlockFetcher(client, appId, execId, blockIds, listener).start(); - } - }; - - int maxRetries = conf.maxIORetries(); - if (maxRetries > 0) { - // Note this Fetcher will correctly handle maxRetries == 0; we avoid it just in case there's - // a bug in this code. We should remove the if statement once we're sure of the stability. - new RetryingBlockFetcher(conf, blockFetchStarter, blockIds, listener).start(); - } else { - blockFetchStarter.createAndStart(blockIds, listener); - } - } catch (Exception e) { - logger.error("Exception while beginning fetchBlocks", e); - for (String blockId : blockIds) { - listener.onBlockFetchFailure(blockId, e); - } - } - } - - /** - * Registers this executor with an external shuffle server. This registration is required to - * inform the shuffle server about where and how we store our shuffle files. - * - * @param host Host of shuffle server. - * @param port Port of shuffle server. - * @param execId This Executor's id. - * @param executorInfo Contains all info necessary for the service to find our shuffle files. - */ - public void registerWithShuffleServer( - String host, - int port, - String execId, - ExecutorShuffleInfo executorInfo) throws IOException { - checkInit(); - TransportClient client = clientFactory.createUnmanagedClient(host, port); - try { - ByteBuffer registerMessage = new RegisterExecutor(appId, execId, executorInfo).toByteBuffer(); - client.sendRpcSync(registerMessage, 5000 /* timeoutMs */); - } finally { - client.close(); - } - } - - @Override - public void close() { - clientFactory.close(); - } -} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java deleted file mode 100644 index 1b2ddbf1ed..0000000000 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java +++ /dev/null @@ -1,129 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.shuffle; - -import java.nio.ByteBuffer; -import java.util.Arrays; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import org.apache.spark.network.buffer.ManagedBuffer; -import org.apache.spark.network.client.ChunkReceivedCallback; -import org.apache.spark.network.client.RpcResponseCallback; -import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; -import org.apache.spark.network.shuffle.protocol.OpenBlocks; -import org.apache.spark.network.shuffle.protocol.StreamHandle; - -/** - * Simple wrapper on top of a TransportClient which interprets each chunk as a whole block, and - * invokes the BlockFetchingListener appropriately. This class is agnostic to the actual RPC - * handler, as long as there is a single "open blocks" message which returns a ShuffleStreamHandle, - * and Java serialization is used. - * - * Note that this typically corresponds to a - * {@link org.apache.spark.network.server.OneForOneStreamManager} on the server side. - */ -public class OneForOneBlockFetcher { - private final Logger logger = LoggerFactory.getLogger(OneForOneBlockFetcher.class); - - private final TransportClient client; - private final OpenBlocks openMessage; - private final String[] blockIds; - private final BlockFetchingListener listener; - private final ChunkReceivedCallback chunkCallback; - - private StreamHandle streamHandle = null; - - public OneForOneBlockFetcher( - TransportClient client, - String appId, - String execId, - String[] blockIds, - BlockFetchingListener listener) { - this.client = client; - this.openMessage = new OpenBlocks(appId, execId, blockIds); - this.blockIds = blockIds; - this.listener = listener; - this.chunkCallback = new ChunkCallback(); - } - - /** Callback invoked on receipt of each chunk. We equate a single chunk to a single block. */ - private class ChunkCallback implements ChunkReceivedCallback { - @Override - public void onSuccess(int chunkIndex, ManagedBuffer buffer) { - // On receipt of a chunk, pass it upwards as a block. - listener.onBlockFetchSuccess(blockIds[chunkIndex], buffer); - } - - @Override - public void onFailure(int chunkIndex, Throwable e) { - // On receipt of a failure, fail every block from chunkIndex onwards. - String[] remainingBlockIds = Arrays.copyOfRange(blockIds, chunkIndex, blockIds.length); - failRemainingBlocks(remainingBlockIds, e); - } - } - - /** - * Begins the fetching process, calling the listener with every block fetched. - * The given message will be serialized with the Java serializer, and the RPC must return a - * {@link StreamHandle}. We will send all fetch requests immediately, without throttling. - */ - public void start() { - if (blockIds.length == 0) { - throw new IllegalArgumentException("Zero-sized blockIds array"); - } - - client.sendRpc(openMessage.toByteBuffer(), new RpcResponseCallback() { - @Override - public void onSuccess(ByteBuffer response) { - try { - streamHandle = (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response); - logger.trace("Successfully opened blocks {}, preparing to fetch chunks.", streamHandle); - - // Immediately request all chunks -- we expect that the total size of the request is - // reasonable due to higher level chunking in [[ShuffleBlockFetcherIterator]]. - for (int i = 0; i < streamHandle.numChunks; i++) { - client.fetchChunk(streamHandle.streamId, i, chunkCallback); - } - } catch (Exception e) { - logger.error("Failed while starting block fetches after success", e); - failRemainingBlocks(blockIds, e); - } - } - - @Override - public void onFailure(Throwable e) { - logger.error("Failed while starting block fetches", e); - failRemainingBlocks(blockIds, e); - } - }); - } - - /** Invokes the "onBlockFetchFailure" callback for every listed block id. */ - private void failRemainingBlocks(String[] failedBlockIds, Throwable e) { - for (String blockId : failedBlockIds) { - try { - listener.onBlockFetchFailure(blockId, e); - } catch (Exception e2) { - logger.error("Error in block fetch failure callback", e2); - } - } - } -} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java deleted file mode 100644 index 4bb0498e5d..0000000000 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java +++ /dev/null @@ -1,234 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.shuffle; - -import java.io.IOException; -import java.util.Collections; -import java.util.LinkedHashSet; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; -import java.util.concurrent.TimeUnit; - -import com.google.common.collect.Sets; -import com.google.common.util.concurrent.Uninterruptibles; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import org.apache.spark.network.buffer.ManagedBuffer; -import org.apache.spark.network.util.NettyUtils; -import org.apache.spark.network.util.TransportConf; - -/** - * Wraps another BlockFetcher with the ability to automatically retry fetches which fail due to - * IOExceptions, which we hope are due to transient network conditions. - * - * This fetcher provides stronger guarantees regarding the parent BlockFetchingListener. In - * particular, the listener will be invoked exactly once per blockId, with a success or failure. - */ -public class RetryingBlockFetcher { - - /** - * Used to initiate the first fetch for all blocks, and subsequently for retrying the fetch on any - * remaining blocks. - */ - public static interface BlockFetchStarter { - /** - * Creates a new BlockFetcher to fetch the given block ids which may do some synchronous - * bootstrapping followed by fully asynchronous block fetching. - * The BlockFetcher must eventually invoke the Listener on every input blockId, or else this - * method must throw an exception. - * - * This method should always attempt to get a new TransportClient from the - * {@link org.apache.spark.network.client.TransportClientFactory} in order to fix connection - * issues. - */ - void createAndStart(String[] blockIds, BlockFetchingListener listener) throws IOException; - } - - /** Shared executor service used for waiting and retrying. */ - private static final ExecutorService executorService = Executors.newCachedThreadPool( - NettyUtils.createThreadFactory("Block Fetch Retry")); - - private final Logger logger = LoggerFactory.getLogger(RetryingBlockFetcher.class); - - /** Used to initiate new Block Fetches on our remaining blocks. */ - private final BlockFetchStarter fetchStarter; - - /** Parent listener which we delegate all successful or permanently failed block fetches to. */ - private final BlockFetchingListener listener; - - /** Max number of times we are allowed to retry. */ - private final int maxRetries; - - /** Milliseconds to wait before each retry. */ - private final int retryWaitTime; - - // NOTE: - // All of our non-final fields are synchronized under 'this' and should only be accessed/mutated - // while inside a synchronized block. - /** Number of times we've attempted to retry so far. */ - private int retryCount = 0; - - /** - * Set of all block ids which have not been fetched successfully or with a non-IO Exception. - * A retry involves requesting every outstanding block. Note that since this is a LinkedHashSet, - * input ordering is preserved, so we always request blocks in the same order the user provided. - */ - private final LinkedHashSet outstandingBlocksIds; - - /** - * The BlockFetchingListener that is active with our current BlockFetcher. - * When we start a retry, we immediately replace this with a new Listener, which causes all any - * old Listeners to ignore all further responses. - */ - private RetryingBlockFetchListener currentListener; - - public RetryingBlockFetcher( - TransportConf conf, - BlockFetchStarter fetchStarter, - String[] blockIds, - BlockFetchingListener listener) { - this.fetchStarter = fetchStarter; - this.listener = listener; - this.maxRetries = conf.maxIORetries(); - this.retryWaitTime = conf.ioRetryWaitTimeMs(); - this.outstandingBlocksIds = Sets.newLinkedHashSet(); - Collections.addAll(outstandingBlocksIds, blockIds); - this.currentListener = new RetryingBlockFetchListener(); - } - - /** - * Initiates the fetch of all blocks provided in the constructor, with possible retries in the - * event of transient IOExceptions. - */ - public void start() { - fetchAllOutstanding(); - } - - /** - * Fires off a request to fetch all blocks that have not been fetched successfully or permanently - * failed (i.e., by a non-IOException). - */ - private void fetchAllOutstanding() { - // Start by retrieving our shared state within a synchronized block. - String[] blockIdsToFetch; - int numRetries; - RetryingBlockFetchListener myListener; - synchronized (this) { - blockIdsToFetch = outstandingBlocksIds.toArray(new String[outstandingBlocksIds.size()]); - numRetries = retryCount; - myListener = currentListener; - } - - // Now initiate the fetch on all outstanding blocks, possibly initiating a retry if that fails. - try { - fetchStarter.createAndStart(blockIdsToFetch, myListener); - } catch (Exception e) { - logger.error(String.format("Exception while beginning fetch of %s outstanding blocks %s", - blockIdsToFetch.length, numRetries > 0 ? "(after " + numRetries + " retries)" : ""), e); - - if (shouldRetry(e)) { - initiateRetry(); - } else { - for (String bid : blockIdsToFetch) { - listener.onBlockFetchFailure(bid, e); - } - } - } - } - - /** - * Lightweight method which initiates a retry in a different thread. The retry will involve - * calling fetchAllOutstanding() after a configured wait time. - */ - private synchronized void initiateRetry() { - retryCount += 1; - currentListener = new RetryingBlockFetchListener(); - - logger.info("Retrying fetch ({}/{}) for {} outstanding blocks after {} ms", - retryCount, maxRetries, outstandingBlocksIds.size(), retryWaitTime); - - executorService.submit(new Runnable() { - @Override - public void run() { - Uninterruptibles.sleepUninterruptibly(retryWaitTime, TimeUnit.MILLISECONDS); - fetchAllOutstanding(); - } - }); - } - - /** - * Returns true if we should retry due a block fetch failure. We will retry if and only if - * the exception was an IOException and we haven't retried 'maxRetries' times already. - */ - private synchronized boolean shouldRetry(Throwable e) { - boolean isIOException = e instanceof IOException - || (e.getCause() != null && e.getCause() instanceof IOException); - boolean hasRemainingRetries = retryCount < maxRetries; - return isIOException && hasRemainingRetries; - } - - /** - * Our RetryListener intercepts block fetch responses and forwards them to our parent listener. - * Note that in the event of a retry, we will immediately replace the 'currentListener' field, - * indicating that any responses from non-current Listeners should be ignored. - */ - private class RetryingBlockFetchListener implements BlockFetchingListener { - @Override - public void onBlockFetchSuccess(String blockId, ManagedBuffer data) { - // We will only forward this success message to our parent listener if this block request is - // outstanding and we are still the active listener. - boolean shouldForwardSuccess = false; - synchronized (RetryingBlockFetcher.this) { - if (this == currentListener && outstandingBlocksIds.contains(blockId)) { - outstandingBlocksIds.remove(blockId); - shouldForwardSuccess = true; - } - } - - // Now actually invoke the parent listener, outside of the synchronized block. - if (shouldForwardSuccess) { - listener.onBlockFetchSuccess(blockId, data); - } - } - - @Override - public void onBlockFetchFailure(String blockId, Throwable exception) { - // We will only forward this failure to our parent listener if this block request is - // outstanding, we are still the active listener, AND we cannot retry the fetch. - boolean shouldForwardFailure = false; - synchronized (RetryingBlockFetcher.this) { - if (this == currentListener && outstandingBlocksIds.contains(blockId)) { - if (shouldRetry(exception)) { - initiateRetry(); - } else { - logger.error(String.format("Failed to fetch block %s, and will not retry (%s retries)", - blockId, retryCount), exception); - outstandingBlocksIds.remove(blockId); - shouldForwardFailure = true; - } - } - } - - // Now actually invoke the parent listener, outside of the synchronized block. - if (shouldForwardFailure) { - listener.onBlockFetchFailure(blockId, exception); - } - } - } -} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java deleted file mode 100644 index f72ab40690..0000000000 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.shuffle; - -import java.io.Closeable; - -/** Provides an interface for reading shuffle files, either from an Executor or external service. */ -public abstract class ShuffleClient implements Closeable { - - /** - * Initializes the ShuffleClient, specifying this Executor's appId. - * Must be called before any other method on the ShuffleClient. - */ - public void init(String appId) { } - - /** - * Fetch a sequence of blocks from a remote node asynchronously, - * - * Note that this API takes a sequence so the implementation can batch requests, and does not - * return a future so the underlying implementation can invoke onBlockFetchSuccess as soon as - * the data of a block is fetched, rather than waiting for all blocks to be fetched. - */ - public abstract void fetchBlocks( - String host, - int port, - String execId, - String[] blockIds, - BlockFetchingListener listener); -} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java deleted file mode 100644 index 675820308b..0000000000 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.shuffle.mesos; - -import java.io.IOException; -import java.nio.ByteBuffer; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import org.apache.spark.network.client.RpcResponseCallback; -import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.sasl.SecretKeyHolder; -import org.apache.spark.network.shuffle.ExternalShuffleClient; -import org.apache.spark.network.shuffle.protocol.mesos.RegisterDriver; -import org.apache.spark.network.util.TransportConf; - -/** - * A client for talking to the external shuffle service in Mesos coarse-grained mode. - * - * This is used by the Spark driver to register with each external shuffle service on the cluster. - * The reason why the driver has to talk to the service is for cleaning up shuffle files reliably - * after the application exits. Mesos does not provide a great alternative to do this, so Spark - * has to detect this itself. - */ -public class MesosExternalShuffleClient extends ExternalShuffleClient { - private final Logger logger = LoggerFactory.getLogger(MesosExternalShuffleClient.class); - - /** - * Creates an Mesos external shuffle client that wraps the {@link ExternalShuffleClient}. - * Please refer to docs on {@link ExternalShuffleClient} for more information. - */ - public MesosExternalShuffleClient( - TransportConf conf, - SecretKeyHolder secretKeyHolder, - boolean saslEnabled, - boolean saslEncryptionEnabled) { - super(conf, secretKeyHolder, saslEnabled, saslEncryptionEnabled); - } - - public void registerDriverWithShuffleService(String host, int port) throws IOException { - checkInit(); - ByteBuffer registerDriver = new RegisterDriver(appId).toByteBuffer(); - TransportClient client = clientFactory.createClient(host, port); - client.sendRpc(registerDriver, new RpcResponseCallback() { - @Override - public void onSuccess(ByteBuffer response) { - logger.info("Successfully registered app " + appId + " with external shuffle service."); - } - - @Override - public void onFailure(Throwable e) { - logger.warn("Unable to register app " + appId + " with external shuffle service. " + - "Please manually remove shuffle data after driver exit. Error: " + e); - } - }); - } -} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java deleted file mode 100644 index 7fbe3384b4..0000000000 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.shuffle.protocol; - -import java.nio.ByteBuffer; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; - -import org.apache.spark.network.protocol.Encodable; -import org.apache.spark.network.shuffle.protocol.mesos.RegisterDriver; - -/** - * Messages handled by the {@link org.apache.spark.network.shuffle.ExternalShuffleBlockHandler}, or - * by Spark's NettyBlockTransferService. - * - * At a high level: - * - OpenBlock is handled by both services, but only services shuffle files for the external - * shuffle service. It returns a StreamHandle. - * - UploadBlock is only handled by the NettyBlockTransferService. - * - RegisterExecutor is only handled by the external shuffle service. - */ -public abstract class BlockTransferMessage implements Encodable { - protected abstract Type type(); - - /** Preceding every serialized message is its type, which allows us to deserialize it. */ - public static enum Type { - OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3), REGISTER_DRIVER(4); - - private final byte id; - - private Type(int id) { - assert id < 128 : "Cannot have more than 128 message types"; - this.id = (byte) id; - } - - public byte id() { return id; } - } - - // NB: Java does not support static methods in interfaces, so we must put this in a static class. - public static class Decoder { - /** Deserializes the 'type' byte followed by the message itself. */ - public static BlockTransferMessage fromByteBuffer(ByteBuffer msg) { - ByteBuf buf = Unpooled.wrappedBuffer(msg); - byte type = buf.readByte(); - switch (type) { - case 0: return OpenBlocks.decode(buf); - case 1: return UploadBlock.decode(buf); - case 2: return RegisterExecutor.decode(buf); - case 3: return StreamHandle.decode(buf); - case 4: return RegisterDriver.decode(buf); - default: throw new IllegalArgumentException("Unknown message type: " + type); - } - } - } - - /** Serializes the 'type' byte followed by the message itself. */ - public ByteBuffer toByteBuffer() { - // Allow room for encoded message, plus the type byte - ByteBuf buf = Unpooled.buffer(encodedLength() + 1); - buf.writeByte(type().id); - encode(buf); - assert buf.writableBytes() == 0 : "Writable bytes remain: " + buf.writableBytes(); - return buf.nioBuffer(); - } -} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java deleted file mode 100644 index 102d4efb8b..0000000000 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java +++ /dev/null @@ -1,94 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.shuffle.protocol; - -import java.util.Arrays; - -import com.fasterxml.jackson.annotation.JsonCreator; -import com.fasterxml.jackson.annotation.JsonProperty; -import com.google.common.base.Objects; -import io.netty.buffer.ByteBuf; - -import org.apache.spark.network.protocol.Encodable; -import org.apache.spark.network.protocol.Encoders; - -/** Contains all configuration necessary for locating the shuffle files of an executor. */ -public class ExecutorShuffleInfo implements Encodable { - /** The base set of local directories that the executor stores its shuffle files in. */ - public final String[] localDirs; - /** Number of subdirectories created within each localDir. */ - public final int subDirsPerLocalDir; - /** Shuffle manager (SortShuffleManager or HashShuffleManager) that the executor is using. */ - public final String shuffleManager; - - @JsonCreator - public ExecutorShuffleInfo( - @JsonProperty("localDirs") String[] localDirs, - @JsonProperty("subDirsPerLocalDir") int subDirsPerLocalDir, - @JsonProperty("shuffleManager") String shuffleManager) { - this.localDirs = localDirs; - this.subDirsPerLocalDir = subDirsPerLocalDir; - this.shuffleManager = shuffleManager; - } - - @Override - public int hashCode() { - return Objects.hashCode(subDirsPerLocalDir, shuffleManager) * 41 + Arrays.hashCode(localDirs); - } - - @Override - public String toString() { - return Objects.toStringHelper(this) - .add("localDirs", Arrays.toString(localDirs)) - .add("subDirsPerLocalDir", subDirsPerLocalDir) - .add("shuffleManager", shuffleManager) - .toString(); - } - - @Override - public boolean equals(Object other) { - if (other != null && other instanceof ExecutorShuffleInfo) { - ExecutorShuffleInfo o = (ExecutorShuffleInfo) other; - return Arrays.equals(localDirs, o.localDirs) - && Objects.equal(subDirsPerLocalDir, o.subDirsPerLocalDir) - && Objects.equal(shuffleManager, o.shuffleManager); - } - return false; - } - - @Override - public int encodedLength() { - return Encoders.StringArrays.encodedLength(localDirs) - + 4 // int - + Encoders.Strings.encodedLength(shuffleManager); - } - - @Override - public void encode(ByteBuf buf) { - Encoders.StringArrays.encode(buf, localDirs); - buf.writeInt(subDirsPerLocalDir); - Encoders.Strings.encode(buf, shuffleManager); - } - - public static ExecutorShuffleInfo decode(ByteBuf buf) { - String[] localDirs = Encoders.StringArrays.decode(buf); - int subDirsPerLocalDir = buf.readInt(); - String shuffleManager = Encoders.Strings.decode(buf); - return new ExecutorShuffleInfo(localDirs, subDirsPerLocalDir, shuffleManager); - } -} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java deleted file mode 100644 index ce954b8a28..0000000000 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.shuffle.protocol; - -import java.util.Arrays; - -import com.google.common.base.Objects; -import io.netty.buffer.ByteBuf; - -import org.apache.spark.network.protocol.Encoders; - -// Needed by ScalaDoc. See SPARK-7726 -import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; - -/** Request to read a set of blocks. Returns {@link StreamHandle}. */ -public class OpenBlocks extends BlockTransferMessage { - public final String appId; - public final String execId; - public final String[] blockIds; - - public OpenBlocks(String appId, String execId, String[] blockIds) { - this.appId = appId; - this.execId = execId; - this.blockIds = blockIds; - } - - @Override - protected Type type() { return Type.OPEN_BLOCKS; } - - @Override - public int hashCode() { - return Objects.hashCode(appId, execId) * 41 + Arrays.hashCode(blockIds); - } - - @Override - public String toString() { - return Objects.toStringHelper(this) - .add("appId", appId) - .add("execId", execId) - .add("blockIds", Arrays.toString(blockIds)) - .toString(); - } - - @Override - public boolean equals(Object other) { - if (other != null && other instanceof OpenBlocks) { - OpenBlocks o = (OpenBlocks) other; - return Objects.equal(appId, o.appId) - && Objects.equal(execId, o.execId) - && Arrays.equals(blockIds, o.blockIds); - } - return false; - } - - @Override - public int encodedLength() { - return Encoders.Strings.encodedLength(appId) - + Encoders.Strings.encodedLength(execId) - + Encoders.StringArrays.encodedLength(blockIds); - } - - @Override - public void encode(ByteBuf buf) { - Encoders.Strings.encode(buf, appId); - Encoders.Strings.encode(buf, execId); - Encoders.StringArrays.encode(buf, blockIds); - } - - public static OpenBlocks decode(ByteBuf buf) { - String appId = Encoders.Strings.decode(buf); - String execId = Encoders.Strings.decode(buf); - String[] blockIds = Encoders.StringArrays.decode(buf); - return new OpenBlocks(appId, execId, blockIds); - } -} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java deleted file mode 100644 index 167ef33104..0000000000 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java +++ /dev/null @@ -1,94 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.shuffle.protocol; - -import com.google.common.base.Objects; -import io.netty.buffer.ByteBuf; - -import org.apache.spark.network.protocol.Encoders; - -// Needed by ScalaDoc. See SPARK-7726 -import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; - -/** - * Initial registration message between an executor and its local shuffle server. - * Returns nothing (empty byte array). - */ -public class RegisterExecutor extends BlockTransferMessage { - public final String appId; - public final String execId; - public final ExecutorShuffleInfo executorInfo; - - public RegisterExecutor( - String appId, - String execId, - ExecutorShuffleInfo executorInfo) { - this.appId = appId; - this.execId = execId; - this.executorInfo = executorInfo; - } - - @Override - protected Type type() { return Type.REGISTER_EXECUTOR; } - - @Override - public int hashCode() { - return Objects.hashCode(appId, execId, executorInfo); - } - - @Override - public String toString() { - return Objects.toStringHelper(this) - .add("appId", appId) - .add("execId", execId) - .add("executorInfo", executorInfo) - .toString(); - } - - @Override - public boolean equals(Object other) { - if (other != null && other instanceof RegisterExecutor) { - RegisterExecutor o = (RegisterExecutor) other; - return Objects.equal(appId, o.appId) - && Objects.equal(execId, o.execId) - && Objects.equal(executorInfo, o.executorInfo); - } - return false; - } - - @Override - public int encodedLength() { - return Encoders.Strings.encodedLength(appId) - + Encoders.Strings.encodedLength(execId) - + executorInfo.encodedLength(); - } - - @Override - public void encode(ByteBuf buf) { - Encoders.Strings.encode(buf, appId); - Encoders.Strings.encode(buf, execId); - executorInfo.encode(buf); - } - - public static RegisterExecutor decode(ByteBuf buf) { - String appId = Encoders.Strings.decode(buf); - String execId = Encoders.Strings.decode(buf); - ExecutorShuffleInfo executorShuffleInfo = ExecutorShuffleInfo.decode(buf); - return new RegisterExecutor(appId, execId, executorShuffleInfo); - } -} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java deleted file mode 100644 index 1915295aa6..0000000000 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.shuffle.protocol; - -import com.google.common.base.Objects; -import io.netty.buffer.ByteBuf; - -// Needed by ScalaDoc. See SPARK-7726 -import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; - -/** - * Identifier for a fixed number of chunks to read from a stream created by an "open blocks" - * message. This is used by {@link org.apache.spark.network.shuffle.OneForOneBlockFetcher}. - */ -public class StreamHandle extends BlockTransferMessage { - public final long streamId; - public final int numChunks; - - public StreamHandle(long streamId, int numChunks) { - this.streamId = streamId; - this.numChunks = numChunks; - } - - @Override - protected Type type() { return Type.STREAM_HANDLE; } - - @Override - public int hashCode() { - return Objects.hashCode(streamId, numChunks); - } - - @Override - public String toString() { - return Objects.toStringHelper(this) - .add("streamId", streamId) - .add("numChunks", numChunks) - .toString(); - } - - @Override - public boolean equals(Object other) { - if (other != null && other instanceof StreamHandle) { - StreamHandle o = (StreamHandle) other; - return Objects.equal(streamId, o.streamId) - && Objects.equal(numChunks, o.numChunks); - } - return false; - } - - @Override - public int encodedLength() { - return 8 + 4; - } - - @Override - public void encode(ByteBuf buf) { - buf.writeLong(streamId); - buf.writeInt(numChunks); - } - - public static StreamHandle decode(ByteBuf buf) { - long streamId = buf.readLong(); - int numChunks = buf.readInt(); - return new StreamHandle(streamId, numChunks); - } -} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java deleted file mode 100644 index 3caed59d50..0000000000 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java +++ /dev/null @@ -1,117 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.shuffle.protocol; - -import java.util.Arrays; - -import com.google.common.base.Objects; -import io.netty.buffer.ByteBuf; - -import org.apache.spark.network.protocol.Encoders; - -// Needed by ScalaDoc. See SPARK-7726 -import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; - - -/** Request to upload a block with a certain StorageLevel. Returns nothing (empty byte array). */ -public class UploadBlock extends BlockTransferMessage { - public final String appId; - public final String execId; - public final String blockId; - // TODO: StorageLevel is serialized separately in here because StorageLevel is not available in - // this package. We should avoid this hack. - public final byte[] metadata; - public final byte[] blockData; - - /** - * @param metadata Meta-information about block, typically StorageLevel. - * @param blockData The actual block's bytes. - */ - public UploadBlock( - String appId, - String execId, - String blockId, - byte[] metadata, - byte[] blockData) { - this.appId = appId; - this.execId = execId; - this.blockId = blockId; - this.metadata = metadata; - this.blockData = blockData; - } - - @Override - protected Type type() { return Type.UPLOAD_BLOCK; } - - @Override - public int hashCode() { - int objectsHashCode = Objects.hashCode(appId, execId, blockId); - return (objectsHashCode * 41 + Arrays.hashCode(metadata)) * 41 + Arrays.hashCode(blockData); - } - - @Override - public String toString() { - return Objects.toStringHelper(this) - .add("appId", appId) - .add("execId", execId) - .add("blockId", blockId) - .add("metadata size", metadata.length) - .add("block size", blockData.length) - .toString(); - } - - @Override - public boolean equals(Object other) { - if (other != null && other instanceof UploadBlock) { - UploadBlock o = (UploadBlock) other; - return Objects.equal(appId, o.appId) - && Objects.equal(execId, o.execId) - && Objects.equal(blockId, o.blockId) - && Arrays.equals(metadata, o.metadata) - && Arrays.equals(blockData, o.blockData); - } - return false; - } - - @Override - public int encodedLength() { - return Encoders.Strings.encodedLength(appId) - + Encoders.Strings.encodedLength(execId) - + Encoders.Strings.encodedLength(blockId) - + Encoders.ByteArrays.encodedLength(metadata) - + Encoders.ByteArrays.encodedLength(blockData); - } - - @Override - public void encode(ByteBuf buf) { - Encoders.Strings.encode(buf, appId); - Encoders.Strings.encode(buf, execId); - Encoders.Strings.encode(buf, blockId); - Encoders.ByteArrays.encode(buf, metadata); - Encoders.ByteArrays.encode(buf, blockData); - } - - public static UploadBlock decode(ByteBuf buf) { - String appId = Encoders.Strings.decode(buf); - String execId = Encoders.Strings.decode(buf); - String blockId = Encoders.Strings.decode(buf); - byte[] metadata = Encoders.ByteArrays.decode(buf); - byte[] blockData = Encoders.ByteArrays.decode(buf); - return new UploadBlock(appId, execId, blockId, metadata, blockData); - } -} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java deleted file mode 100644 index 94a61d6caa..0000000000 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.shuffle.protocol.mesos; - -import com.google.common.base.Objects; -import io.netty.buffer.ByteBuf; - -import org.apache.spark.network.protocol.Encoders; -import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; - -// Needed by ScalaDoc. See SPARK-7726 -import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; - -/** - * A message sent from the driver to register with the MesosExternalShuffleService. - */ -public class RegisterDriver extends BlockTransferMessage { - private final String appId; - - public RegisterDriver(String appId) { - this.appId = appId; - } - - public String getAppId() { return appId; } - - @Override - protected Type type() { return Type.REGISTER_DRIVER; } - - @Override - public int encodedLength() { - return Encoders.Strings.encodedLength(appId); - } - - @Override - public void encode(ByteBuf buf) { - Encoders.Strings.encode(buf, appId); - } - - @Override - public int hashCode() { - return Objects.hashCode(appId); - } - - public static RegisterDriver decode(ByteBuf buf) { - String appId = Encoders.Strings.decode(buf); - return new RegisterDriver(appId); - } -} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java deleted file mode 100644 index 0ea631ea14..0000000000 --- a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java +++ /dev/null @@ -1,294 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.sasl; - -import java.io.IOException; -import java.nio.ByteBuffer; -import java.util.Arrays; -import java.util.concurrent.atomic.AtomicReference; - -import com.google.common.collect.Lists; -import org.junit.After; -import org.junit.AfterClass; -import org.junit.BeforeClass; -import org.junit.Test; - -import static org.junit.Assert.*; -import static org.mockito.Mockito.*; - -import org.apache.spark.network.TestUtils; -import org.apache.spark.network.TransportContext; -import org.apache.spark.network.buffer.ManagedBuffer; -import org.apache.spark.network.client.ChunkReceivedCallback; -import org.apache.spark.network.client.RpcResponseCallback; -import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.client.TransportClientBootstrap; -import org.apache.spark.network.client.TransportClientFactory; -import org.apache.spark.network.server.OneForOneStreamManager; -import org.apache.spark.network.server.RpcHandler; -import org.apache.spark.network.server.StreamManager; -import org.apache.spark.network.server.TransportServer; -import org.apache.spark.network.server.TransportServerBootstrap; -import org.apache.spark.network.shuffle.BlockFetchingListener; -import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler; -import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver; -import org.apache.spark.network.shuffle.OneForOneBlockFetcher; -import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; -import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; -import org.apache.spark.network.shuffle.protocol.OpenBlocks; -import org.apache.spark.network.shuffle.protocol.RegisterExecutor; -import org.apache.spark.network.shuffle.protocol.StreamHandle; -import org.apache.spark.network.util.JavaUtils; -import org.apache.spark.network.util.SystemPropertyConfigProvider; -import org.apache.spark.network.util.TransportConf; - -public class SaslIntegrationSuite { - - // Use a long timeout to account for slow / overloaded build machines. In the normal case, - // tests should finish way before the timeout expires. - private static final long TIMEOUT_MS = 10_000; - - static TransportServer server; - static TransportConf conf; - static TransportContext context; - static SecretKeyHolder secretKeyHolder; - - TransportClientFactory clientFactory; - - @BeforeClass - public static void beforeAll() throws IOException { - conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); - context = new TransportContext(conf, new TestRpcHandler()); - - secretKeyHolder = mock(SecretKeyHolder.class); - when(secretKeyHolder.getSaslUser(eq("app-1"))).thenReturn("app-1"); - when(secretKeyHolder.getSecretKey(eq("app-1"))).thenReturn("app-1"); - when(secretKeyHolder.getSaslUser(eq("app-2"))).thenReturn("app-2"); - when(secretKeyHolder.getSecretKey(eq("app-2"))).thenReturn("app-2"); - when(secretKeyHolder.getSaslUser(anyString())).thenReturn("other-app"); - when(secretKeyHolder.getSecretKey(anyString())).thenReturn("correct-password"); - - TransportServerBootstrap bootstrap = new SaslServerBootstrap(conf, secretKeyHolder); - server = context.createServer(Arrays.asList(bootstrap)); - } - - - @AfterClass - public static void afterAll() { - server.close(); - } - - @After - public void afterEach() { - if (clientFactory != null) { - clientFactory.close(); - clientFactory = null; - } - } - - @Test - public void testGoodClient() throws IOException { - clientFactory = context.createClientFactory( - Lists.newArrayList( - new SaslClientBootstrap(conf, "app-1", secretKeyHolder))); - - TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); - String msg = "Hello, World!"; - ByteBuffer resp = client.sendRpcSync(JavaUtils.stringToBytes(msg), TIMEOUT_MS); - assertEquals(msg, JavaUtils.bytesToString(resp)); - } - - @Test - public void testBadClient() { - SecretKeyHolder badKeyHolder = mock(SecretKeyHolder.class); - when(badKeyHolder.getSaslUser(anyString())).thenReturn("other-app"); - when(badKeyHolder.getSecretKey(anyString())).thenReturn("wrong-password"); - clientFactory = context.createClientFactory( - Lists.newArrayList( - new SaslClientBootstrap(conf, "unknown-app", badKeyHolder))); - - try { - // Bootstrap should fail on startup. - clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); - fail("Connection should have failed."); - } catch (Exception e) { - assertTrue(e.getMessage(), e.getMessage().contains("Mismatched response")); - } - } - - @Test - public void testNoSaslClient() throws IOException { - clientFactory = context.createClientFactory( - Lists.newArrayList()); - - TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); - try { - client.sendRpcSync(ByteBuffer.allocate(13), TIMEOUT_MS); - fail("Should have failed"); - } catch (Exception e) { - assertTrue(e.getMessage(), e.getMessage().contains("Expected SaslMessage")); - } - - try { - // Guessing the right tag byte doesn't magically get you in... - client.sendRpcSync(ByteBuffer.wrap(new byte[] { (byte) 0xEA }), TIMEOUT_MS); - fail("Should have failed"); - } catch (Exception e) { - assertTrue(e.getMessage(), e.getMessage().contains("java.lang.IndexOutOfBoundsException")); - } - } - - @Test - public void testNoSaslServer() { - RpcHandler handler = new TestRpcHandler(); - TransportContext context = new TransportContext(conf, handler); - clientFactory = context.createClientFactory( - Lists.newArrayList( - new SaslClientBootstrap(conf, "app-1", secretKeyHolder))); - TransportServer server = context.createServer(); - try { - clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); - } catch (Exception e) { - assertTrue(e.getMessage(), e.getMessage().contains("Digest-challenge format violation")); - } finally { - server.close(); - } - } - - /** - * This test is not actually testing SASL behavior, but testing that the shuffle service - * performs correct authorization checks based on the SASL authentication data. - */ - @Test - public void testAppIsolation() throws Exception { - // Start a new server with the correct RPC handler to serve block data. - ExternalShuffleBlockResolver blockResolver = mock(ExternalShuffleBlockResolver.class); - ExternalShuffleBlockHandler blockHandler = new ExternalShuffleBlockHandler( - new OneForOneStreamManager(), blockResolver); - TransportServerBootstrap bootstrap = new SaslServerBootstrap(conf, secretKeyHolder); - TransportContext blockServerContext = new TransportContext(conf, blockHandler); - TransportServer blockServer = blockServerContext.createServer(Arrays.asList(bootstrap)); - - TransportClient client1 = null; - TransportClient client2 = null; - TransportClientFactory clientFactory2 = null; - try { - // Create a client, and make a request to fetch blocks from a different app. - clientFactory = blockServerContext.createClientFactory( - Lists.newArrayList( - new SaslClientBootstrap(conf, "app-1", secretKeyHolder))); - client1 = clientFactory.createClient(TestUtils.getLocalHost(), - blockServer.getPort()); - - final AtomicReference exception = new AtomicReference<>(); - - BlockFetchingListener listener = new BlockFetchingListener() { - @Override - public synchronized void onBlockFetchSuccess(String blockId, ManagedBuffer data) { - notifyAll(); - } - - @Override - public synchronized void onBlockFetchFailure(String blockId, Throwable t) { - exception.set(t); - notifyAll(); - } - }; - - String[] blockIds = new String[] { "shuffle_2_3_4", "shuffle_6_7_8" }; - OneForOneBlockFetcher fetcher = new OneForOneBlockFetcher(client1, "app-2", "0", - blockIds, listener); - synchronized (listener) { - fetcher.start(); - listener.wait(); - } - checkSecurityException(exception.get()); - - // Register an executor so that the next steps work. - ExecutorShuffleInfo executorInfo = new ExecutorShuffleInfo( - new String[] { System.getProperty("java.io.tmpdir") }, 1, "sort"); - RegisterExecutor regmsg = new RegisterExecutor("app-1", "0", executorInfo); - client1.sendRpcSync(regmsg.toByteBuffer(), TIMEOUT_MS); - - // Make a successful request to fetch blocks, which creates a new stream. But do not actually - // fetch any blocks, to keep the stream open. - OpenBlocks openMessage = new OpenBlocks("app-1", "0", blockIds); - ByteBuffer response = client1.sendRpcSync(openMessage.toByteBuffer(), TIMEOUT_MS); - StreamHandle stream = (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response); - long streamId = stream.streamId; - - // Create a second client, authenticated with a different app ID, and try to read from - // the stream created for the previous app. - clientFactory2 = blockServerContext.createClientFactory( - Lists.newArrayList( - new SaslClientBootstrap(conf, "app-2", secretKeyHolder))); - client2 = clientFactory2.createClient(TestUtils.getLocalHost(), - blockServer.getPort()); - - ChunkReceivedCallback callback = new ChunkReceivedCallback() { - @Override - public synchronized void onSuccess(int chunkIndex, ManagedBuffer buffer) { - notifyAll(); - } - - @Override - public synchronized void onFailure(int chunkIndex, Throwable t) { - exception.set(t); - notifyAll(); - } - }; - - exception.set(null); - synchronized (callback) { - client2.fetchChunk(streamId, 0, callback); - callback.wait(); - } - checkSecurityException(exception.get()); - } finally { - if (client1 != null) { - client1.close(); - } - if (client2 != null) { - client2.close(); - } - if (clientFactory2 != null) { - clientFactory2.close(); - } - blockServer.close(); - } - } - - /** RPC handler which simply responds with the message it received. */ - public static class TestRpcHandler extends RpcHandler { - @Override - public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) { - callback.onSuccess(message); - } - - @Override - public StreamManager getStreamManager() { - return new OneForOneStreamManager(); - } - } - - private void checkSecurityException(Throwable t) { - assertNotNull("No exception was caught.", t); - assertTrue("Expected SecurityException.", - t.getMessage().contains(SecurityException.class.getName())); - } -} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java deleted file mode 100644 index 86c8609e70..0000000000 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.shuffle; - -import org.junit.Test; - -import static org.junit.Assert.*; - -import org.apache.spark.network.shuffle.protocol.*; - -/** Verifies that all BlockTransferMessages can be serialized correctly. */ -public class BlockTransferMessagesSuite { - @Test - public void serializeOpenShuffleBlocks() { - checkSerializeDeserialize(new OpenBlocks("app-1", "exec-2", new String[] { "b1", "b2" })); - checkSerializeDeserialize(new RegisterExecutor("app-1", "exec-2", new ExecutorShuffleInfo( - new String[] { "/local1", "/local2" }, 32, "MyShuffleManager"))); - checkSerializeDeserialize(new UploadBlock("app-1", "exec-2", "block-3", new byte[] { 1, 2 }, - new byte[] { 4, 5, 6, 7} )); - checkSerializeDeserialize(new StreamHandle(12345, 16)); - } - - private void checkSerializeDeserialize(BlockTransferMessage msg) { - BlockTransferMessage msg2 = BlockTransferMessage.Decoder.fromByteBuffer(msg.toByteBuffer()); - assertEquals(msg, msg2); - assertEquals(msg.hashCode(), msg2.hashCode()); - assertEquals(msg.toString(), msg2.toString()); - } -} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java deleted file mode 100644 index 9379412155..0000000000 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java +++ /dev/null @@ -1,127 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.shuffle; - -import java.nio.ByteBuffer; -import java.util.Iterator; - -import org.junit.Before; -import org.junit.Test; -import org.mockito.ArgumentCaptor; - -import static org.junit.Assert.*; -import static org.mockito.Matchers.any; -import static org.mockito.Mockito.*; - -import org.apache.spark.network.buffer.ManagedBuffer; -import org.apache.spark.network.buffer.NioManagedBuffer; -import org.apache.spark.network.client.RpcResponseCallback; -import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.server.OneForOneStreamManager; -import org.apache.spark.network.server.RpcHandler; -import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; -import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; -import org.apache.spark.network.shuffle.protocol.OpenBlocks; -import org.apache.spark.network.shuffle.protocol.RegisterExecutor; -import org.apache.spark.network.shuffle.protocol.StreamHandle; -import org.apache.spark.network.shuffle.protocol.UploadBlock; - -public class ExternalShuffleBlockHandlerSuite { - TransportClient client = mock(TransportClient.class); - - OneForOneStreamManager streamManager; - ExternalShuffleBlockResolver blockResolver; - RpcHandler handler; - - @Before - public void beforeEach() { - streamManager = mock(OneForOneStreamManager.class); - blockResolver = mock(ExternalShuffleBlockResolver.class); - handler = new ExternalShuffleBlockHandler(streamManager, blockResolver); - } - - @Test - public void testRegisterExecutor() { - RpcResponseCallback callback = mock(RpcResponseCallback.class); - - ExecutorShuffleInfo config = new ExecutorShuffleInfo(new String[] {"/a", "/b"}, 16, "sort"); - ByteBuffer registerMessage = new RegisterExecutor("app0", "exec1", config).toByteBuffer(); - handler.receive(client, registerMessage, callback); - verify(blockResolver, times(1)).registerExecutor("app0", "exec1", config); - - verify(callback, times(1)).onSuccess(any(ByteBuffer.class)); - verify(callback, never()).onFailure(any(Throwable.class)); - } - - @SuppressWarnings("unchecked") - @Test - public void testOpenShuffleBlocks() { - RpcResponseCallback callback = mock(RpcResponseCallback.class); - - ManagedBuffer block0Marker = new NioManagedBuffer(ByteBuffer.wrap(new byte[3])); - ManagedBuffer block1Marker = new NioManagedBuffer(ByteBuffer.wrap(new byte[7])); - when(blockResolver.getBlockData("app0", "exec1", "b0")).thenReturn(block0Marker); - when(blockResolver.getBlockData("app0", "exec1", "b1")).thenReturn(block1Marker); - ByteBuffer openBlocks = new OpenBlocks("app0", "exec1", new String[] { "b0", "b1" }) - .toByteBuffer(); - handler.receive(client, openBlocks, callback); - verify(blockResolver, times(1)).getBlockData("app0", "exec1", "b0"); - verify(blockResolver, times(1)).getBlockData("app0", "exec1", "b1"); - - ArgumentCaptor response = ArgumentCaptor.forClass(ByteBuffer.class); - verify(callback, times(1)).onSuccess(response.capture()); - verify(callback, never()).onFailure((Throwable) any()); - - StreamHandle handle = - (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response.getValue()); - assertEquals(2, handle.numChunks); - - @SuppressWarnings("unchecked") - ArgumentCaptor> stream = (ArgumentCaptor>) - (ArgumentCaptor) ArgumentCaptor.forClass(Iterator.class); - verify(streamManager, times(1)).registerStream(anyString(), stream.capture()); - Iterator buffers = stream.getValue(); - assertEquals(block0Marker, buffers.next()); - assertEquals(block1Marker, buffers.next()); - assertFalse(buffers.hasNext()); - } - - @Test - public void testBadMessages() { - RpcResponseCallback callback = mock(RpcResponseCallback.class); - - ByteBuffer unserializableMsg = ByteBuffer.wrap(new byte[] { 0x12, 0x34, 0x56 }); - try { - handler.receive(client, unserializableMsg, callback); - fail("Should have thrown"); - } catch (Exception e) { - // pass - } - - ByteBuffer unexpectedMsg = new UploadBlock("a", "e", "b", new byte[1], new byte[2]).toByteBuffer(); - try { - handler.receive(client, unexpectedMsg, callback); - fail("Should have thrown"); - } catch (UnsupportedOperationException e) { - // pass - } - - verify(callback, never()).onSuccess(any(ByteBuffer.class)); - verify(callback, never()).onFailure(any(Throwable.class)); - } -} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java deleted file mode 100644 index 60a1b8b045..0000000000 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java +++ /dev/null @@ -1,156 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.shuffle; - -import java.io.IOException; -import java.io.InputStream; -import java.io.InputStreamReader; - -import com.fasterxml.jackson.databind.ObjectMapper; -import com.google.common.io.CharStreams; -import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; -import org.apache.spark.network.util.SystemPropertyConfigProvider; -import org.apache.spark.network.util.TransportConf; -import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver.AppExecId; -import org.junit.AfterClass; -import org.junit.BeforeClass; -import org.junit.Test; - -import static org.junit.Assert.*; - -public class ExternalShuffleBlockResolverSuite { - static String sortBlock0 = "Hello!"; - static String sortBlock1 = "World!"; - - static String hashBlock0 = "Elementary"; - static String hashBlock1 = "Tabular"; - - static TestShuffleDataContext dataContext; - - static TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); - - @BeforeClass - public static void beforeAll() throws IOException { - dataContext = new TestShuffleDataContext(2, 5); - - dataContext.create(); - // Write some sort and hash data. - dataContext.insertSortShuffleData(0, 0, - new byte[][] { sortBlock0.getBytes(), sortBlock1.getBytes() } ); - dataContext.insertHashShuffleData(1, 0, - new byte[][] { hashBlock0.getBytes(), hashBlock1.getBytes() } ); - } - - @AfterClass - public static void afterAll() { - dataContext.cleanup(); - } - - @Test - public void testBadRequests() throws IOException { - ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf, null); - // Unregistered executor - try { - resolver.getBlockData("app0", "exec1", "shuffle_1_1_0"); - fail("Should have failed"); - } catch (RuntimeException e) { - assertTrue("Bad error message: " + e, e.getMessage().contains("not registered")); - } - - // Invalid shuffle manager - resolver.registerExecutor("app0", "exec2", dataContext.createExecutorInfo("foobar")); - try { - resolver.getBlockData("app0", "exec2", "shuffle_1_1_0"); - fail("Should have failed"); - } catch (UnsupportedOperationException e) { - // pass - } - - // Nonexistent shuffle block - resolver.registerExecutor("app0", "exec3", - dataContext.createExecutorInfo("sort")); - try { - resolver.getBlockData("app0", "exec3", "shuffle_1_1_0"); - fail("Should have failed"); - } catch (Exception e) { - // pass - } - } - - @Test - public void testSortShuffleBlocks() throws IOException { - ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf, null); - resolver.registerExecutor("app0", "exec0", - dataContext.createExecutorInfo("sort")); - - InputStream block0Stream = - resolver.getBlockData("app0", "exec0", "shuffle_0_0_0").createInputStream(); - String block0 = CharStreams.toString(new InputStreamReader(block0Stream)); - block0Stream.close(); - assertEquals(sortBlock0, block0); - - InputStream block1Stream = - resolver.getBlockData("app0", "exec0", "shuffle_0_0_1").createInputStream(); - String block1 = CharStreams.toString(new InputStreamReader(block1Stream)); - block1Stream.close(); - assertEquals(sortBlock1, block1); - } - - @Test - public void testHashShuffleBlocks() throws IOException { - ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf, null); - resolver.registerExecutor("app0", "exec0", - dataContext.createExecutorInfo("hash")); - - InputStream block0Stream = - resolver.getBlockData("app0", "exec0", "shuffle_1_0_0").createInputStream(); - String block0 = CharStreams.toString(new InputStreamReader(block0Stream)); - block0Stream.close(); - assertEquals(hashBlock0, block0); - - InputStream block1Stream = - resolver.getBlockData("app0", "exec0", "shuffle_1_0_1").createInputStream(); - String block1 = CharStreams.toString(new InputStreamReader(block1Stream)); - block1Stream.close(); - assertEquals(hashBlock1, block1); - } - - @Test - public void jsonSerializationOfExecutorRegistration() throws IOException { - ObjectMapper mapper = new ObjectMapper(); - AppExecId appId = new AppExecId("foo", "bar"); - String appIdJson = mapper.writeValueAsString(appId); - AppExecId parsedAppId = mapper.readValue(appIdJson, AppExecId.class); - assertEquals(parsedAppId, appId); - - ExecutorShuffleInfo shuffleInfo = - new ExecutorShuffleInfo(new String[]{"/bippy", "/flippy"}, 7, "hash"); - String shuffleJson = mapper.writeValueAsString(shuffleInfo); - ExecutorShuffleInfo parsedShuffleInfo = - mapper.readValue(shuffleJson, ExecutorShuffleInfo.class); - assertEquals(parsedShuffleInfo, shuffleInfo); - - // Intentionally keep these hard-coded strings in here, to check backwards-compatability. - // its not legacy yet, but keeping this here in case anybody changes it - String legacyAppIdJson = "{\"appId\":\"foo\", \"execId\":\"bar\"}"; - assertEquals(appId, mapper.readValue(legacyAppIdJson, AppExecId.class)); - String legacyShuffleJson = "{\"localDirs\": [\"/bippy\", \"/flippy\"], " + - "\"subDirsPerLocalDir\": 7, \"shuffleManager\": \"hash\"}"; - assertEquals(shuffleInfo, mapper.readValue(legacyShuffleJson, ExecutorShuffleInfo.class)); - } -} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java deleted file mode 100644 index 532d7ab8d0..0000000000 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java +++ /dev/null @@ -1,149 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.shuffle; - -import java.io.File; -import java.io.IOException; -import java.util.Random; -import java.util.concurrent.Executor; -import java.util.concurrent.atomic.AtomicBoolean; - -import com.google.common.util.concurrent.MoreExecutors; -import org.junit.Test; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; - -import org.apache.spark.network.util.SystemPropertyConfigProvider; -import org.apache.spark.network.util.TransportConf; - -public class ExternalShuffleCleanupSuite { - - // Same-thread Executor used to ensure cleanup happens synchronously in test thread. - Executor sameThreadExecutor = MoreExecutors.sameThreadExecutor(); - TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); - - @Test - public void noCleanupAndCleanup() throws IOException { - TestShuffleDataContext dataContext = createSomeData(); - - ExternalShuffleBlockResolver resolver = - new ExternalShuffleBlockResolver(conf, null, sameThreadExecutor); - resolver.registerExecutor("app", "exec0", dataContext.createExecutorInfo("shuffleMgr")); - resolver.applicationRemoved("app", false /* cleanup */); - - assertStillThere(dataContext); - - resolver.registerExecutor("app", "exec1", dataContext.createExecutorInfo("shuffleMgr")); - resolver.applicationRemoved("app", true /* cleanup */); - - assertCleanedUp(dataContext); - } - - @Test - public void cleanupUsesExecutor() throws IOException { - TestShuffleDataContext dataContext = createSomeData(); - - final AtomicBoolean cleanupCalled = new AtomicBoolean(false); - - // Executor which does nothing to ensure we're actually using it. - Executor noThreadExecutor = new Executor() { - @Override public void execute(Runnable runnable) { cleanupCalled.set(true); } - }; - - ExternalShuffleBlockResolver manager = - new ExternalShuffleBlockResolver(conf, null, noThreadExecutor); - - manager.registerExecutor("app", "exec0", dataContext.createExecutorInfo("shuffleMgr")); - manager.applicationRemoved("app", true); - - assertTrue(cleanupCalled.get()); - assertStillThere(dataContext); - - dataContext.cleanup(); - assertCleanedUp(dataContext); - } - - @Test - public void cleanupMultipleExecutors() throws IOException { - TestShuffleDataContext dataContext0 = createSomeData(); - TestShuffleDataContext dataContext1 = createSomeData(); - - ExternalShuffleBlockResolver resolver = - new ExternalShuffleBlockResolver(conf, null, sameThreadExecutor); - - resolver.registerExecutor("app", "exec0", dataContext0.createExecutorInfo("shuffleMgr")); - resolver.registerExecutor("app", "exec1", dataContext1.createExecutorInfo("shuffleMgr")); - resolver.applicationRemoved("app", true); - - assertCleanedUp(dataContext0); - assertCleanedUp(dataContext1); - } - - @Test - public void cleanupOnlyRemovedApp() throws IOException { - TestShuffleDataContext dataContext0 = createSomeData(); - TestShuffleDataContext dataContext1 = createSomeData(); - - ExternalShuffleBlockResolver resolver = - new ExternalShuffleBlockResolver(conf, null, sameThreadExecutor); - - resolver.registerExecutor("app-0", "exec0", dataContext0.createExecutorInfo("shuffleMgr")); - resolver.registerExecutor("app-1", "exec0", dataContext1.createExecutorInfo("shuffleMgr")); - - resolver.applicationRemoved("app-nonexistent", true); - assertStillThere(dataContext0); - assertStillThere(dataContext1); - - resolver.applicationRemoved("app-0", true); - assertCleanedUp(dataContext0); - assertStillThere(dataContext1); - - resolver.applicationRemoved("app-1", true); - assertCleanedUp(dataContext0); - assertCleanedUp(dataContext1); - - // Make sure it's not an error to cleanup multiple times - resolver.applicationRemoved("app-1", true); - assertCleanedUp(dataContext0); - assertCleanedUp(dataContext1); - } - - private void assertStillThere(TestShuffleDataContext dataContext) { - for (String localDir : dataContext.localDirs) { - assertTrue(localDir + " was cleaned up prematurely", new File(localDir).exists()); - } - } - - private void assertCleanedUp(TestShuffleDataContext dataContext) { - for (String localDir : dataContext.localDirs) { - assertFalse(localDir + " wasn't cleaned up", new File(localDir).exists()); - } - } - - private TestShuffleDataContext createSomeData() throws IOException { - Random rand = new Random(123); - TestShuffleDataContext dataContext = new TestShuffleDataContext(10, 5); - - dataContext.create(); - dataContext.insertSortShuffleData(rand.nextInt(1000), rand.nextInt(1000), - new byte[][] { "ABC".getBytes(), "DEF".getBytes() } ); - dataContext.insertHashShuffleData(rand.nextInt(1000), rand.nextInt(1000) + 1000, - new byte[][] { "GHI".getBytes(), "JKLMNOPQRSTUVWXYZ".getBytes() } ); - return dataContext; - } -} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java deleted file mode 100644 index 5e706bf401..0000000000 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ /dev/null @@ -1,301 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.shuffle; - -import java.io.IOException; -import java.nio.ByteBuffer; -import java.util.Collections; -import java.util.HashSet; -import java.util.LinkedList; -import java.util.List; -import java.util.Random; -import java.util.Set; -import java.util.concurrent.Semaphore; -import java.util.concurrent.TimeUnit; - -import com.google.common.collect.Lists; -import com.google.common.collect.Sets; -import org.junit.After; -import org.junit.AfterClass; -import org.junit.BeforeClass; -import org.junit.Test; - -import static org.junit.Assert.*; - -import org.apache.spark.network.TestUtils; -import org.apache.spark.network.TransportContext; -import org.apache.spark.network.buffer.ManagedBuffer; -import org.apache.spark.network.buffer.NioManagedBuffer; -import org.apache.spark.network.server.TransportServer; -import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; -import org.apache.spark.network.util.SystemPropertyConfigProvider; -import org.apache.spark.network.util.TransportConf; - -public class ExternalShuffleIntegrationSuite { - - static String APP_ID = "app-id"; - static String SORT_MANAGER = "sort"; - static String HASH_MANAGER = "hash"; - - // Executor 0 is sort-based - static TestShuffleDataContext dataContext0; - // Executor 1 is hash-based - static TestShuffleDataContext dataContext1; - - static ExternalShuffleBlockHandler handler; - static TransportServer server; - static TransportConf conf; - - static byte[][] exec0Blocks = new byte[][] { - new byte[123], - new byte[12345], - new byte[1234567], - }; - - static byte[][] exec1Blocks = new byte[][] { - new byte[321], - new byte[54321], - }; - - @BeforeClass - public static void beforeAll() throws IOException { - Random rand = new Random(); - - for (byte[] block : exec0Blocks) { - rand.nextBytes(block); - } - for (byte[] block: exec1Blocks) { - rand.nextBytes(block); - } - - dataContext0 = new TestShuffleDataContext(2, 5); - dataContext0.create(); - dataContext0.insertSortShuffleData(0, 0, exec0Blocks); - - dataContext1 = new TestShuffleDataContext(6, 2); - dataContext1.create(); - dataContext1.insertHashShuffleData(1, 0, exec1Blocks); - - conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); - handler = new ExternalShuffleBlockHandler(conf, null); - TransportContext transportContext = new TransportContext(conf, handler); - server = transportContext.createServer(); - } - - @AfterClass - public static void afterAll() { - dataContext0.cleanup(); - dataContext1.cleanup(); - server.close(); - } - - @After - public void afterEach() { - handler.applicationRemoved(APP_ID, false /* cleanupLocalDirs */); - } - - class FetchResult { - public Set successBlocks; - public Set failedBlocks; - public List buffers; - - public void releaseBuffers() { - for (ManagedBuffer buffer : buffers) { - buffer.release(); - } - } - } - - // Fetch a set of blocks from a pre-registered executor. - private FetchResult fetchBlocks(String execId, String[] blockIds) throws Exception { - return fetchBlocks(execId, blockIds, server.getPort()); - } - - // Fetch a set of blocks from a pre-registered executor. Connects to the server on the given port, - // to allow connecting to invalid servers. - private FetchResult fetchBlocks(String execId, String[] blockIds, int port) throws Exception { - final FetchResult res = new FetchResult(); - res.successBlocks = Collections.synchronizedSet(new HashSet()); - res.failedBlocks = Collections.synchronizedSet(new HashSet()); - res.buffers = Collections.synchronizedList(new LinkedList()); - - final Semaphore requestsRemaining = new Semaphore(0); - - ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false, false); - client.init(APP_ID); - client.fetchBlocks(TestUtils.getLocalHost(), port, execId, blockIds, - new BlockFetchingListener() { - @Override - public void onBlockFetchSuccess(String blockId, ManagedBuffer data) { - synchronized (this) { - if (!res.successBlocks.contains(blockId) && !res.failedBlocks.contains(blockId)) { - data.retain(); - res.successBlocks.add(blockId); - res.buffers.add(data); - requestsRemaining.release(); - } - } - } - - @Override - public void onBlockFetchFailure(String blockId, Throwable exception) { - synchronized (this) { - if (!res.successBlocks.contains(blockId) && !res.failedBlocks.contains(blockId)) { - res.failedBlocks.add(blockId); - requestsRemaining.release(); - } - } - } - }); - - if (!requestsRemaining.tryAcquire(blockIds.length, 5, TimeUnit.SECONDS)) { - fail("Timeout getting response from the server"); - } - client.close(); - return res; - } - - @Test - public void testFetchOneSort() throws Exception { - registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER)); - FetchResult exec0Fetch = fetchBlocks("exec-0", new String[] { "shuffle_0_0_0" }); - assertEquals(Sets.newHashSet("shuffle_0_0_0"), exec0Fetch.successBlocks); - assertTrue(exec0Fetch.failedBlocks.isEmpty()); - assertBufferListsEqual(exec0Fetch.buffers, Lists.newArrayList(exec0Blocks[0])); - exec0Fetch.releaseBuffers(); - } - - @Test - public void testFetchThreeSort() throws Exception { - registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER)); - FetchResult exec0Fetch = fetchBlocks("exec-0", - new String[] { "shuffle_0_0_0", "shuffle_0_0_1", "shuffle_0_0_2" }); - assertEquals(Sets.newHashSet("shuffle_0_0_0", "shuffle_0_0_1", "shuffle_0_0_2"), - exec0Fetch.successBlocks); - assertTrue(exec0Fetch.failedBlocks.isEmpty()); - assertBufferListsEqual(exec0Fetch.buffers, Lists.newArrayList(exec0Blocks)); - exec0Fetch.releaseBuffers(); - } - - @Test - public void testFetchHash() throws Exception { - registerExecutor("exec-1", dataContext1.createExecutorInfo(HASH_MANAGER)); - FetchResult execFetch = fetchBlocks("exec-1", - new String[] { "shuffle_1_0_0", "shuffle_1_0_1" }); - assertEquals(Sets.newHashSet("shuffle_1_0_0", "shuffle_1_0_1"), execFetch.successBlocks); - assertTrue(execFetch.failedBlocks.isEmpty()); - assertBufferListsEqual(execFetch.buffers, Lists.newArrayList(exec1Blocks)); - execFetch.releaseBuffers(); - } - - @Test - public void testFetchWrongShuffle() throws Exception { - registerExecutor("exec-1", dataContext1.createExecutorInfo(SORT_MANAGER /* wrong manager */)); - FetchResult execFetch = fetchBlocks("exec-1", - new String[] { "shuffle_1_0_0", "shuffle_1_0_1" }); - assertTrue(execFetch.successBlocks.isEmpty()); - assertEquals(Sets.newHashSet("shuffle_1_0_0", "shuffle_1_0_1"), execFetch.failedBlocks); - } - - @Test - public void testFetchInvalidShuffle() throws Exception { - registerExecutor("exec-1", dataContext1.createExecutorInfo("unknown sort manager")); - FetchResult execFetch = fetchBlocks("exec-1", - new String[] { "shuffle_1_0_0" }); - assertTrue(execFetch.successBlocks.isEmpty()); - assertEquals(Sets.newHashSet("shuffle_1_0_0"), execFetch.failedBlocks); - } - - @Test - public void testFetchWrongBlockId() throws Exception { - registerExecutor("exec-1", dataContext1.createExecutorInfo(SORT_MANAGER /* wrong manager */)); - FetchResult execFetch = fetchBlocks("exec-1", - new String[] { "rdd_1_0_0" }); - assertTrue(execFetch.successBlocks.isEmpty()); - assertEquals(Sets.newHashSet("rdd_1_0_0"), execFetch.failedBlocks); - } - - @Test - public void testFetchNonexistent() throws Exception { - registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER)); - FetchResult execFetch = fetchBlocks("exec-0", - new String[] { "shuffle_2_0_0" }); - assertTrue(execFetch.successBlocks.isEmpty()); - assertEquals(Sets.newHashSet("shuffle_2_0_0"), execFetch.failedBlocks); - } - - @Test - public void testFetchWrongExecutor() throws Exception { - registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER)); - FetchResult execFetch = fetchBlocks("exec-0", - new String[] { "shuffle_0_0_0" /* right */, "shuffle_1_0_0" /* wrong */ }); - // Both still fail, as we start by checking for all block. - assertTrue(execFetch.successBlocks.isEmpty()); - assertEquals(Sets.newHashSet("shuffle_0_0_0", "shuffle_1_0_0"), execFetch.failedBlocks); - } - - @Test - public void testFetchUnregisteredExecutor() throws Exception { - registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER)); - FetchResult execFetch = fetchBlocks("exec-2", - new String[] { "shuffle_0_0_0", "shuffle_1_0_0" }); - assertTrue(execFetch.successBlocks.isEmpty()); - assertEquals(Sets.newHashSet("shuffle_0_0_0", "shuffle_1_0_0"), execFetch.failedBlocks); - } - - @Test - public void testFetchNoServer() throws Exception { - System.setProperty("spark.shuffle.io.maxRetries", "0"); - try { - registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER)); - FetchResult execFetch = fetchBlocks("exec-0", - new String[]{"shuffle_1_0_0", "shuffle_1_0_1"}, 1 /* port */); - assertTrue(execFetch.successBlocks.isEmpty()); - assertEquals(Sets.newHashSet("shuffle_1_0_0", "shuffle_1_0_1"), execFetch.failedBlocks); - } finally { - System.clearProperty("spark.shuffle.io.maxRetries"); - } - } - - private void registerExecutor(String executorId, ExecutorShuffleInfo executorInfo) - throws IOException { - ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false, false); - client.init(APP_ID); - client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(), - executorId, executorInfo); - } - - private void assertBufferListsEqual(List list0, List list1) - throws Exception { - assertEquals(list0.size(), list1.size()); - for (int i = 0; i < list0.size(); i ++) { - assertBuffersEqual(list0.get(i), new NioManagedBuffer(ByteBuffer.wrap(list1.get(i)))); - } - } - - private void assertBuffersEqual(ManagedBuffer buffer0, ManagedBuffer buffer1) throws Exception { - ByteBuffer nio0 = buffer0.nioByteBuffer(); - ByteBuffer nio1 = buffer1.nioByteBuffer(); - - int len = nio0.remaining(); - assertEquals(nio0.remaining(), nio1.remaining()); - for (int i = 0; i < len; i ++) { - assertEquals(nio0.get(), nio1.get()); - } - } -} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java deleted file mode 100644 index 08ddb3755b..0000000000 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java +++ /dev/null @@ -1,124 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.shuffle; - -import java.io.IOException; -import java.util.Arrays; - -import org.junit.After; -import org.junit.Before; -import org.junit.Test; - -import static org.junit.Assert.*; - -import org.apache.spark.network.TestUtils; -import org.apache.spark.network.TransportContext; -import org.apache.spark.network.sasl.SaslServerBootstrap; -import org.apache.spark.network.sasl.SecretKeyHolder; -import org.apache.spark.network.server.RpcHandler; -import org.apache.spark.network.server.TransportServer; -import org.apache.spark.network.server.TransportServerBootstrap; -import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; -import org.apache.spark.network.util.SystemPropertyConfigProvider; -import org.apache.spark.network.util.TransportConf; - -public class ExternalShuffleSecuritySuite { - - TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); - TransportServer server; - - @Before - public void beforeEach() throws IOException { - TransportContext context = - new TransportContext(conf, new ExternalShuffleBlockHandler(conf, null)); - TransportServerBootstrap bootstrap = new SaslServerBootstrap(conf, - new TestSecretKeyHolder("my-app-id", "secret")); - this.server = context.createServer(Arrays.asList(bootstrap)); - } - - @After - public void afterEach() { - if (server != null) { - server.close(); - server = null; - } - } - - @Test - public void testValid() throws IOException { - validate("my-app-id", "secret", false); - } - - @Test - public void testBadAppId() { - try { - validate("wrong-app-id", "secret", false); - } catch (Exception e) { - assertTrue(e.getMessage(), e.getMessage().contains("Wrong appId!")); - } - } - - @Test - public void testBadSecret() { - try { - validate("my-app-id", "bad-secret", false); - } catch (Exception e) { - assertTrue(e.getMessage(), e.getMessage().contains("Mismatched response")); - } - } - - @Test - public void testEncryption() throws IOException { - validate("my-app-id", "secret", true); - } - - /** Creates an ExternalShuffleClient and attempts to register with the server. */ - private void validate(String appId, String secretKey, boolean encrypt) throws IOException { - ExternalShuffleClient client = - new ExternalShuffleClient(conf, new TestSecretKeyHolder(appId, secretKey), true, encrypt); - client.init(appId); - // Registration either succeeds or throws an exception. - client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(), "exec0", - new ExecutorShuffleInfo(new String[0], 0, "")); - client.close(); - } - - /** Provides a secret key holder which always returns the given secret key, for a single appId. */ - static class TestSecretKeyHolder implements SecretKeyHolder { - private final String appId; - private final String secretKey; - - TestSecretKeyHolder(String appId, String secretKey) { - this.appId = appId; - this.secretKey = secretKey; - } - - @Override - public String getSaslUser(String appId) { - return "user"; - } - - @Override - public String getSecretKey(String appId) { - if (!appId.equals(this.appId)) { - throw new IllegalArgumentException("Wrong appId!"); - } - return secretKey; - } - } -} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java deleted file mode 100644 index 2590b9ce4c..0000000000 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java +++ /dev/null @@ -1,176 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.shuffle; - -import java.nio.ByteBuffer; -import java.util.Iterator; -import java.util.LinkedHashMap; -import java.util.concurrent.atomic.AtomicInteger; - -import com.google.common.collect.Maps; -import io.netty.buffer.Unpooled; -import org.junit.Test; -import org.mockito.invocation.InvocationOnMock; -import org.mockito.stubbing.Answer; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.fail; -import static org.mockito.Matchers.any; -import static org.mockito.Matchers.anyInt; -import static org.mockito.Matchers.anyLong; -import static org.mockito.Matchers.eq; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; - -import org.apache.spark.network.buffer.ManagedBuffer; -import org.apache.spark.network.buffer.NettyManagedBuffer; -import org.apache.spark.network.buffer.NioManagedBuffer; -import org.apache.spark.network.client.ChunkReceivedCallback; -import org.apache.spark.network.client.RpcResponseCallback; -import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; -import org.apache.spark.network.shuffle.protocol.OpenBlocks; -import org.apache.spark.network.shuffle.protocol.StreamHandle; - -public class OneForOneBlockFetcherSuite { - @Test - public void testFetchOne() { - LinkedHashMap blocks = Maps.newLinkedHashMap(); - blocks.put("shuffle_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[0]))); - - BlockFetchingListener listener = fetchBlocks(blocks); - - verify(listener).onBlockFetchSuccess("shuffle_0_0_0", blocks.get("shuffle_0_0_0")); - } - - @Test - public void testFetchThree() { - LinkedHashMap blocks = Maps.newLinkedHashMap(); - blocks.put("b0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12]))); - blocks.put("b1", new NioManagedBuffer(ByteBuffer.wrap(new byte[23]))); - blocks.put("b2", new NettyManagedBuffer(Unpooled.wrappedBuffer(new byte[23]))); - - BlockFetchingListener listener = fetchBlocks(blocks); - - for (int i = 0; i < 3; i ++) { - verify(listener, times(1)).onBlockFetchSuccess("b" + i, blocks.get("b" + i)); - } - } - - @Test - public void testFailure() { - LinkedHashMap blocks = Maps.newLinkedHashMap(); - blocks.put("b0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12]))); - blocks.put("b1", null); - blocks.put("b2", null); - - BlockFetchingListener listener = fetchBlocks(blocks); - - // Each failure will cause a failure to be invoked in all remaining block fetches. - verify(listener, times(1)).onBlockFetchSuccess("b0", blocks.get("b0")); - verify(listener, times(1)).onBlockFetchFailure(eq("b1"), (Throwable) any()); - verify(listener, times(2)).onBlockFetchFailure(eq("b2"), (Throwable) any()); - } - - @Test - public void testFailureAndSuccess() { - LinkedHashMap blocks = Maps.newLinkedHashMap(); - blocks.put("b0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12]))); - blocks.put("b1", null); - blocks.put("b2", new NioManagedBuffer(ByteBuffer.wrap(new byte[21]))); - - BlockFetchingListener listener = fetchBlocks(blocks); - - // We may call both success and failure for the same block. - verify(listener, times(1)).onBlockFetchSuccess("b0", blocks.get("b0")); - verify(listener, times(1)).onBlockFetchFailure(eq("b1"), (Throwable) any()); - verify(listener, times(1)).onBlockFetchSuccess("b2", blocks.get("b2")); - verify(listener, times(1)).onBlockFetchFailure(eq("b2"), (Throwable) any()); - } - - @Test - public void testEmptyBlockFetch() { - try { - fetchBlocks(Maps.newLinkedHashMap()); - fail(); - } catch (IllegalArgumentException e) { - assertEquals("Zero-sized blockIds array", e.getMessage()); - } - } - - /** - * Begins a fetch on the given set of blocks by mocking out the server side of the RPC which - * simply returns the given (BlockId, Block) pairs. - * As "blocks" is a LinkedHashMap, the blocks are guaranteed to be returned in the same order - * that they were inserted in. - * - * If a block's buffer is "null", an exception will be thrown instead. - */ - private BlockFetchingListener fetchBlocks(final LinkedHashMap blocks) { - TransportClient client = mock(TransportClient.class); - BlockFetchingListener listener = mock(BlockFetchingListener.class); - final String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]); - OneForOneBlockFetcher fetcher = - new OneForOneBlockFetcher(client, "app-id", "exec-id", blockIds, listener); - - // Respond to the "OpenBlocks" message with an appropirate ShuffleStreamHandle with streamId 123 - doAnswer(new Answer() { - @Override - public Void answer(InvocationOnMock invocationOnMock) throws Throwable { - BlockTransferMessage message = BlockTransferMessage.Decoder.fromByteBuffer( - (ByteBuffer) invocationOnMock.getArguments()[0]); - RpcResponseCallback callback = (RpcResponseCallback) invocationOnMock.getArguments()[1]; - callback.onSuccess(new StreamHandle(123, blocks.size()).toByteBuffer()); - assertEquals(new OpenBlocks("app-id", "exec-id", blockIds), message); - return null; - } - }).when(client).sendRpc(any(ByteBuffer.class), any(RpcResponseCallback.class)); - - // Respond to each chunk request with a single buffer from our blocks array. - final AtomicInteger expectedChunkIndex = new AtomicInteger(0); - final Iterator blockIterator = blocks.values().iterator(); - doAnswer(new Answer() { - @Override - public Void answer(InvocationOnMock invocation) throws Throwable { - try { - long streamId = (Long) invocation.getArguments()[0]; - int myChunkIndex = (Integer) invocation.getArguments()[1]; - assertEquals(123, streamId); - assertEquals(expectedChunkIndex.getAndIncrement(), myChunkIndex); - - ChunkReceivedCallback callback = (ChunkReceivedCallback) invocation.getArguments()[2]; - ManagedBuffer result = blockIterator.next(); - if (result != null) { - callback.onSuccess(myChunkIndex, result); - } else { - callback.onFailure(myChunkIndex, new RuntimeException("Failed " + myChunkIndex)); - } - } catch (Exception e) { - e.printStackTrace(); - fail("Unexpected failure"); - } - return null; - } - }).when(client).fetchChunk(anyLong(), anyInt(), (ChunkReceivedCallback) any()); - - fetcher.start(); - return listener; - } -} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java deleted file mode 100644 index 3a6ef0d3f8..0000000000 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java +++ /dev/null @@ -1,313 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.shuffle; - - -import java.io.IOException; -import java.nio.ByteBuffer; -import java.util.Arrays; -import java.util.LinkedHashSet; -import java.util.List; -import java.util.Map; - -import com.google.common.collect.ImmutableMap; -import com.google.common.collect.Sets; -import org.junit.After; -import org.junit.Before; -import org.junit.Test; -import org.mockito.invocation.InvocationOnMock; -import org.mockito.stubbing.Answer; -import org.mockito.stubbing.Stubber; - -import static org.junit.Assert.*; -import static org.mockito.Mockito.*; - -import org.apache.spark.network.buffer.ManagedBuffer; -import org.apache.spark.network.buffer.NioManagedBuffer; -import org.apache.spark.network.util.SystemPropertyConfigProvider; -import org.apache.spark.network.util.TransportConf; -import static org.apache.spark.network.shuffle.RetryingBlockFetcher.BlockFetchStarter; - -/** - * Tests retry logic by throwing IOExceptions and ensuring that subsequent attempts are made to - * fetch the lost blocks. - */ -public class RetryingBlockFetcherSuite { - - ManagedBuffer block0 = new NioManagedBuffer(ByteBuffer.wrap(new byte[13])); - ManagedBuffer block1 = new NioManagedBuffer(ByteBuffer.wrap(new byte[7])); - ManagedBuffer block2 = new NioManagedBuffer(ByteBuffer.wrap(new byte[19])); - - @Before - public void beforeEach() { - System.setProperty("spark.shuffle.io.maxRetries", "2"); - System.setProperty("spark.shuffle.io.retryWait", "0"); - } - - @After - public void afterEach() { - System.clearProperty("spark.shuffle.io.maxRetries"); - System.clearProperty("spark.shuffle.io.retryWait"); - } - - @Test - public void testNoFailures() throws IOException { - BlockFetchingListener listener = mock(BlockFetchingListener.class); - - List> interactions = Arrays.asList( - // Immediately return both blocks successfully. - ImmutableMap.builder() - .put("b0", block0) - .put("b1", block1) - .build() - ); - - performInteractions(interactions, listener); - - verify(listener).onBlockFetchSuccess("b0", block0); - verify(listener).onBlockFetchSuccess("b1", block1); - verifyNoMoreInteractions(listener); - } - - @Test - public void testUnrecoverableFailure() throws IOException { - BlockFetchingListener listener = mock(BlockFetchingListener.class); - - List> interactions = Arrays.asList( - // b0 throws a non-IOException error, so it will be failed without retry. - ImmutableMap.builder() - .put("b0", new RuntimeException("Ouch!")) - .put("b1", block1) - .build() - ); - - performInteractions(interactions, listener); - - verify(listener).onBlockFetchFailure(eq("b0"), (Throwable) any()); - verify(listener).onBlockFetchSuccess("b1", block1); - verifyNoMoreInteractions(listener); - } - - @Test - public void testSingleIOExceptionOnFirst() throws IOException { - BlockFetchingListener listener = mock(BlockFetchingListener.class); - - List> interactions = Arrays.asList( - // IOException will cause a retry. Since b0 fails, we will retry both. - ImmutableMap.builder() - .put("b0", new IOException("Connection failed or something")) - .put("b1", block1) - .build(), - ImmutableMap.builder() - .put("b0", block0) - .put("b1", block1) - .build() - ); - - performInteractions(interactions, listener); - - verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0); - verify(listener, timeout(5000)).onBlockFetchSuccess("b1", block1); - verifyNoMoreInteractions(listener); - } - - @Test - public void testSingleIOExceptionOnSecond() throws IOException { - BlockFetchingListener listener = mock(BlockFetchingListener.class); - - List> interactions = Arrays.asList( - // IOException will cause a retry. Since b1 fails, we will not retry b0. - ImmutableMap.builder() - .put("b0", block0) - .put("b1", new IOException("Connection failed or something")) - .build(), - ImmutableMap.builder() - .put("b1", block1) - .build() - ); - - performInteractions(interactions, listener); - - verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0); - verify(listener, timeout(5000)).onBlockFetchSuccess("b1", block1); - verifyNoMoreInteractions(listener); - } - - @Test - public void testTwoIOExceptions() throws IOException { - BlockFetchingListener listener = mock(BlockFetchingListener.class); - - List> interactions = Arrays.asList( - // b0's IOException will trigger retry, b1's will be ignored. - ImmutableMap.builder() - .put("b0", new IOException()) - .put("b1", new IOException()) - .build(), - // Next, b0 is successful and b1 errors again, so we just request that one. - ImmutableMap.builder() - .put("b0", block0) - .put("b1", new IOException()) - .build(), - // b1 returns successfully within 2 retries. - ImmutableMap.builder() - .put("b1", block1) - .build() - ); - - performInteractions(interactions, listener); - - verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0); - verify(listener, timeout(5000)).onBlockFetchSuccess("b1", block1); - verifyNoMoreInteractions(listener); - } - - @Test - public void testThreeIOExceptions() throws IOException { - BlockFetchingListener listener = mock(BlockFetchingListener.class); - - List> interactions = Arrays.asList( - // b0's IOException will trigger retry, b1's will be ignored. - ImmutableMap.builder() - .put("b0", new IOException()) - .put("b1", new IOException()) - .build(), - // Next, b0 is successful and b1 errors again, so we just request that one. - ImmutableMap.builder() - .put("b0", block0) - .put("b1", new IOException()) - .build(), - // b1 errors again, but this was the last retry - ImmutableMap.builder() - .put("b1", new IOException()) - .build(), - // This is not reached -- b1 has failed. - ImmutableMap.builder() - .put("b1", block1) - .build() - ); - - performInteractions(interactions, listener); - - verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0); - verify(listener, timeout(5000)).onBlockFetchFailure(eq("b1"), (Throwable) any()); - verifyNoMoreInteractions(listener); - } - - @Test - public void testRetryAndUnrecoverable() throws IOException { - BlockFetchingListener listener = mock(BlockFetchingListener.class); - - List> interactions = Arrays.asList( - // b0's IOException will trigger retry, subsequent messages will be ignored. - ImmutableMap.builder() - .put("b0", new IOException()) - .put("b1", new RuntimeException()) - .put("b2", block2) - .build(), - // Next, b0 is successful, b1 errors unrecoverably, and b2 triggers a retry. - ImmutableMap.builder() - .put("b0", block0) - .put("b1", new RuntimeException()) - .put("b2", new IOException()) - .build(), - // b2 succeeds in its last retry. - ImmutableMap.builder() - .put("b2", block2) - .build() - ); - - performInteractions(interactions, listener); - - verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0); - verify(listener, timeout(5000)).onBlockFetchFailure(eq("b1"), (Throwable) any()); - verify(listener, timeout(5000)).onBlockFetchSuccess("b2", block2); - verifyNoMoreInteractions(listener); - } - - /** - * Performs a set of interactions in response to block requests from a RetryingBlockFetcher. - * Each interaction is a Map from BlockId to either ManagedBuffer or Exception. This interaction - * means "respond to the next block fetch request with these Successful buffers and these Failure - * exceptions". We verify that the expected block ids are exactly the ones requested. - * - * If multiple interactions are supplied, they will be used in order. This is useful for encoding - * retries -- the first interaction may include an IOException, which causes a retry of some - * subset of the original blocks in a second interaction. - */ - @SuppressWarnings("unchecked") - private static void performInteractions(List> interactions, - BlockFetchingListener listener) - throws IOException { - - TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); - BlockFetchStarter fetchStarter = mock(BlockFetchStarter.class); - - Stubber stub = null; - - // Contains all blockIds that are referenced across all interactions. - final LinkedHashSet blockIds = Sets.newLinkedHashSet(); - - for (final Map interaction : interactions) { - blockIds.addAll(interaction.keySet()); - - Answer answer = new Answer() { - @Override - public Void answer(InvocationOnMock invocationOnMock) throws Throwable { - try { - // Verify that the RetryingBlockFetcher requested the expected blocks. - String[] requestedBlockIds = (String[]) invocationOnMock.getArguments()[0]; - String[] desiredBlockIds = interaction.keySet().toArray(new String[interaction.size()]); - assertArrayEquals(desiredBlockIds, requestedBlockIds); - - // Now actually invoke the success/failure callbacks on each block. - BlockFetchingListener retryListener = - (BlockFetchingListener) invocationOnMock.getArguments()[1]; - for (Map.Entry block : interaction.entrySet()) { - String blockId = block.getKey(); - Object blockValue = block.getValue(); - - if (blockValue instanceof ManagedBuffer) { - retryListener.onBlockFetchSuccess(blockId, (ManagedBuffer) blockValue); - } else if (blockValue instanceof Exception) { - retryListener.onBlockFetchFailure(blockId, (Exception) blockValue); - } else { - fail("Can only handle ManagedBuffers and Exceptions, got " + blockValue); - } - } - return null; - } catch (Throwable e) { - e.printStackTrace(); - throw e; - } - } - }; - - // This is either the first stub, or should be chained behind the prior ones. - if (stub == null) { - stub = doAnswer(answer); - } else { - stub.doAnswer(answer); - } - } - - assert stub != null; - stub.when(fetchStarter).createAndStart((String[]) any(), (BlockFetchingListener) anyObject()); - String[] blockIdArray = blockIds.toArray(new String[blockIds.size()]); - new RetryingBlockFetcher(conf, fetchStarter, blockIdArray, listener).start(); - } -} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java deleted file mode 100644 index 7ac1ca128a..0000000000 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java +++ /dev/null @@ -1,117 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.shuffle; - -import java.io.DataOutputStream; -import java.io.File; -import java.io.FileOutputStream; -import java.io.IOException; -import java.io.OutputStream; - -import com.google.common.io.Closeables; -import com.google.common.io.Files; - -import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; - -/** - * Manages some sort- and hash-based shuffle data, including the creation - * and cleanup of directories that can be read by the {@link ExternalShuffleBlockResolver}. - */ -public class TestShuffleDataContext { - public final String[] localDirs; - public final int subDirsPerLocalDir; - - public TestShuffleDataContext(int numLocalDirs, int subDirsPerLocalDir) { - this.localDirs = new String[numLocalDirs]; - this.subDirsPerLocalDir = subDirsPerLocalDir; - } - - public void create() { - for (int i = 0; i < localDirs.length; i ++) { - localDirs[i] = Files.createTempDir().getAbsolutePath(); - - for (int p = 0; p < subDirsPerLocalDir; p ++) { - new File(localDirs[i], String.format("%02x", p)).mkdirs(); - } - } - } - - public void cleanup() { - for (String localDir : localDirs) { - deleteRecursively(new File(localDir)); - } - } - - /** Creates reducer blocks in a sort-based data format within our local dirs. */ - public void insertSortShuffleData(int shuffleId, int mapId, byte[][] blocks) throws IOException { - String blockId = "shuffle_" + shuffleId + "_" + mapId + "_0"; - - OutputStream dataStream = null; - DataOutputStream indexStream = null; - boolean suppressExceptionsDuringClose = true; - - try { - dataStream = new FileOutputStream( - ExternalShuffleBlockResolver.getFile(localDirs, subDirsPerLocalDir, blockId + ".data")); - indexStream = new DataOutputStream(new FileOutputStream( - ExternalShuffleBlockResolver.getFile(localDirs, subDirsPerLocalDir, blockId + ".index"))); - - long offset = 0; - indexStream.writeLong(offset); - for (byte[] block : blocks) { - offset += block.length; - dataStream.write(block); - indexStream.writeLong(offset); - } - suppressExceptionsDuringClose = false; - } finally { - Closeables.close(dataStream, suppressExceptionsDuringClose); - Closeables.close(indexStream, suppressExceptionsDuringClose); - } - } - - /** Creates reducer blocks in a hash-based data format within our local dirs. */ - public void insertHashShuffleData(int shuffleId, int mapId, byte[][] blocks) throws IOException { - for (int i = 0; i < blocks.length; i ++) { - String blockId = "shuffle_" + shuffleId + "_" + mapId + "_" + i; - Files.write(blocks[i], - ExternalShuffleBlockResolver.getFile(localDirs, subDirsPerLocalDir, blockId)); - } - } - - /** - * Creates an ExecutorShuffleInfo object based on the given shuffle manager which targets this - * context's directories. - */ - public ExecutorShuffleInfo createExecutorInfo(String shuffleManager) { - return new ExecutorShuffleInfo(localDirs, subDirsPerLocalDir, shuffleManager); - } - - private static void deleteRecursively(File f) { - assert f != null; - if (f.isDirectory()) { - File[] children = f.listFiles(); - if (children != null) { - for (File child : children) { - deleteRecursively(child); - } - } - } - f.delete(); - } -} diff --git a/network/yarn/pom.xml b/network/yarn/pom.xml deleted file mode 100644 index 3cb44324f2..0000000000 --- a/network/yarn/pom.xml +++ /dev/null @@ -1,148 +0,0 @@ - - - - - 4.0.0 - - org.apache.spark - spark-parent_2.11 - 2.0.0-SNAPSHOT - ../../pom.xml - - - org.apache.spark - spark-network-yarn_2.11 - jar - Spark Project YARN Shuffle Service - http://spark.apache.org/ - - network-yarn - - provided - ${project.build.directory}/scala-${scala.binary.version}/spark-${project.version}-yarn-shuffle.jar - org/spark-project/ - - - - - - org.apache.spark - spark-network-shuffle_${scala.binary.version} - ${project.version} - - - org.apache.spark - spark-test-tags_${scala.binary.version} - - - - - org.apache.hadoop - hadoop-client - - - org.slf4j - slf4j-api - provided - - - - - target/scala-${scala.binary.version}/classes - target/scala-${scala.binary.version}/test-classes - - - org.apache.maven.plugins - maven-shade-plugin - - false - ${shuffle.jar} - - - *:* - - - - - *:* - - META-INF/*.SF - META-INF/*.DSA - META-INF/*.RSA - - - - - - com.fasterxml.jackson - org.spark-project.com.fasterxml.jackson - - com.fasterxml.jackson.** - - - - - - - package - - shade - - - - - - - - org.apache.maven.plugins - maven-antrun-plugin - - - verify - - run - - - - - - - - - - - - - - - - - - Verifying dependency shading - - - - - - - - - - - diff --git a/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java b/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java deleted file mode 100644 index ba6d30a74c..0000000000 --- a/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java +++ /dev/null @@ -1,224 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.yarn; - -import java.io.File; -import java.nio.ByteBuffer; -import java.util.List; - -import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.Lists; -import org.apache.hadoop.conf.Configuration; -import org.apache.hadoop.yarn.api.records.ContainerId; -import org.apache.hadoop.yarn.server.api.*; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import org.apache.spark.network.TransportContext; -import org.apache.spark.network.sasl.SaslServerBootstrap; -import org.apache.spark.network.sasl.ShuffleSecretManager; -import org.apache.spark.network.server.TransportServer; -import org.apache.spark.network.server.TransportServerBootstrap; -import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler; -import org.apache.spark.network.util.TransportConf; -import org.apache.spark.network.yarn.util.HadoopConfigProvider; - -/** - * An external shuffle service used by Spark on Yarn. - * - * This is intended to be a long-running auxiliary service that runs in the NodeManager process. - * A Spark application may connect to this service by setting `spark.shuffle.service.enabled`. - * The application also automatically derives the service port through `spark.shuffle.service.port` - * specified in the Yarn configuration. This is so that both the clients and the server agree on - * the same port to communicate on. - * - * The service also optionally supports authentication. This ensures that executors from one - * application cannot read the shuffle files written by those from another. This feature can be - * enabled by setting `spark.authenticate` in the Yarn configuration before starting the NM. - * Note that the Spark application must also set `spark.authenticate` manually and, unlike in - * the case of the service port, will not inherit this setting from the Yarn configuration. This - * is because an application running on the same Yarn cluster may choose to not use the external - * shuffle service, in which case its setting of `spark.authenticate` should be independent of - * the service's. - */ -public class YarnShuffleService extends AuxiliaryService { - private final Logger logger = LoggerFactory.getLogger(YarnShuffleService.class); - - // Port on which the shuffle server listens for fetch requests - private static final String SPARK_SHUFFLE_SERVICE_PORT_KEY = "spark.shuffle.service.port"; - private static final int DEFAULT_SPARK_SHUFFLE_SERVICE_PORT = 7337; - - // Whether the shuffle server should authenticate fetch requests - private static final String SPARK_AUTHENTICATE_KEY = "spark.authenticate"; - private static final boolean DEFAULT_SPARK_AUTHENTICATE = false; - - // An entity that manages the shuffle secret per application - // This is used only if authentication is enabled - private ShuffleSecretManager secretManager; - - // The actual server that serves shuffle files - private TransportServer shuffleServer = null; - - // Handles registering executors and opening shuffle blocks - @VisibleForTesting - ExternalShuffleBlockHandler blockHandler; - - // Where to store & reload executor info for recovering state after an NM restart - @VisibleForTesting - File registeredExecutorFile; - - // just for testing when you want to find an open port - @VisibleForTesting - static int boundPort = -1; - - // just for integration tests that want to look at this file -- in general not sensible as - // a static - @VisibleForTesting - static YarnShuffleService instance; - - public YarnShuffleService() { - super("spark_shuffle"); - logger.info("Initializing YARN shuffle service for Spark"); - instance = this; - } - - /** - * Return whether authentication is enabled as specified by the configuration. - * If so, fetch requests will fail unless the appropriate authentication secret - * for the application is provided. - */ - private boolean isAuthenticationEnabled() { - return secretManager != null; - } - - /** - * Start the shuffle server with the given configuration. - */ - @Override - protected void serviceInit(Configuration conf) { - - // In case this NM was killed while there were running spark applications, we need to restore - // lost state for the existing executors. We look for an existing file in the NM's local dirs. - // If we don't find one, then we choose a file to use to save the state next time. Even if - // an application was stopped while the NM was down, we expect yarn to call stopApplication() - // when it comes back - registeredExecutorFile = - findRegisteredExecutorFile(conf.getStrings("yarn.nodemanager.local-dirs")); - - TransportConf transportConf = new TransportConf("shuffle", new HadoopConfigProvider(conf)); - // If authentication is enabled, set up the shuffle server to use a - // special RPC handler that filters out unauthenticated fetch requests - boolean authEnabled = conf.getBoolean(SPARK_AUTHENTICATE_KEY, DEFAULT_SPARK_AUTHENTICATE); - try { - blockHandler = new ExternalShuffleBlockHandler(transportConf, registeredExecutorFile); - } catch (Exception e) { - logger.error("Failed to initialize external shuffle service", e); - } - - List bootstraps = Lists.newArrayList(); - if (authEnabled) { - secretManager = new ShuffleSecretManager(); - bootstraps.add(new SaslServerBootstrap(transportConf, secretManager)); - } - - int port = conf.getInt( - SPARK_SHUFFLE_SERVICE_PORT_KEY, DEFAULT_SPARK_SHUFFLE_SERVICE_PORT); - TransportContext transportContext = new TransportContext(transportConf, blockHandler); - shuffleServer = transportContext.createServer(port, bootstraps); - // the port should normally be fixed, but for tests its useful to find an open port - port = shuffleServer.getPort(); - boundPort = port; - String authEnabledString = authEnabled ? "enabled" : "not enabled"; - logger.info("Started YARN shuffle service for Spark on port {}. " + - "Authentication is {}. Registered executor file is {}", port, authEnabledString, - registeredExecutorFile); - } - - @Override - public void initializeApplication(ApplicationInitializationContext context) { - String appId = context.getApplicationId().toString(); - try { - ByteBuffer shuffleSecret = context.getApplicationDataForService(); - logger.info("Initializing application {}", appId); - if (isAuthenticationEnabled()) { - secretManager.registerApp(appId, shuffleSecret); - } - } catch (Exception e) { - logger.error("Exception when initializing application {}", appId, e); - } - } - - @Override - public void stopApplication(ApplicationTerminationContext context) { - String appId = context.getApplicationId().toString(); - try { - logger.info("Stopping application {}", appId); - if (isAuthenticationEnabled()) { - secretManager.unregisterApp(appId); - } - blockHandler.applicationRemoved(appId, false /* clean up local dirs */); - } catch (Exception e) { - logger.error("Exception when stopping application {}", appId, e); - } - } - - @Override - public void initializeContainer(ContainerInitializationContext context) { - ContainerId containerId = context.getContainerId(); - logger.info("Initializing container {}", containerId); - } - - @Override - public void stopContainer(ContainerTerminationContext context) { - ContainerId containerId = context.getContainerId(); - logger.info("Stopping container {}", containerId); - } - - private File findRegisteredExecutorFile(String[] localDirs) { - for (String dir: localDirs) { - File f = new File(dir, "registeredExecutors.ldb"); - if (f.exists()) { - return f; - } - } - return new File(localDirs[0], "registeredExecutors.ldb"); - } - - /** - * Close the shuffle server to clean up any associated state. - */ - @Override - protected void serviceStop() { - try { - if (shuffleServer != null) { - shuffleServer.close(); - } - if (blockHandler != null) { - blockHandler.close(); - } - } catch (Exception e) { - logger.error("Exception when stopping service", e); - } - } - - // Not currently used - @Override - public ByteBuffer getMetaData() { - return ByteBuffer.allocate(0); - } -} diff --git a/network/yarn/src/main/java/org/apache/spark/network/yarn/util/HadoopConfigProvider.java b/network/yarn/src/main/java/org/apache/spark/network/yarn/util/HadoopConfigProvider.java deleted file mode 100644 index 884861752e..0000000000 --- a/network/yarn/src/main/java/org/apache/spark/network/yarn/util/HadoopConfigProvider.java +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.yarn.util; - -import java.util.NoSuchElementException; - -import org.apache.hadoop.conf.Configuration; - -import org.apache.spark.network.util.ConfigProvider; - -/** Use the Hadoop configuration to obtain config values. */ -public class HadoopConfigProvider extends ConfigProvider { - private final Configuration conf; - - public HadoopConfigProvider(Configuration conf) { - this.conf = conf; - } - - @Override - public String get(String name) { - String value = conf.get(name); - if (value == null) { - throw new NoSuchElementException(name); - } - return value; - } -} -- cgit v1.2.3