65 files changed, 2216 insertions, 312 deletions
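The patch's central change is an optional external shuffle service, plus the client plumbing to fetch shuffle blocks through it. A minimal sketch of the application-side configuration involved; the keys and defaults are taken from the BlockManager and test-suite changes below, while the app name is illustrative:

import org.apache.spark.SparkConf

// Illustrative settings; the defaults ("false", 7337) come from BlockManager below.
val conf = new SparkConf()
  .setAppName("external-shuffle-demo")
  .set("spark.shuffle.service.enabled", "true")  // route shuffle fetches through the service
  .set("spark.shuffle.service.port", "7337")     // the default port added by this patch
  .set("spark.shuffle.manager", "sort")          // hash shuffle with consolidation is rejected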
diff --git a/core/pom.xml b/core/pom.xml index 6963ce4777..41296e0eca 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -50,6 +50,11 @@ <version>${project.version}</version> </dependency> <dependency> + <groupId>org.apache.spark</groupId> + <artifactId>spark-network-shuffle_2.10</artifactId> + <version>${project.version}</version> + </dependency> + <dependency> <groupId>net.java.dev.jets3t</groupId> <artifactId>jets3t</artifactId> </dependency> diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 4cb0bd4142..7d96962c4a 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -178,6 +178,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses) } } else { + logError("Missing all output locations for shuffle " + shuffleId) throw new MetadataFetchFailedException( shuffleId, reduceId, "Missing all output locations for shuffle " + shuffleId) } @@ -348,7 +349,7 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr new ConcurrentHashMap[Int, Array[MapStatus]] } -private[spark] object MapOutputTracker { +private[spark] object MapOutputTracker extends Logging { // Serialize an array of map output locations into an efficient byte format so that we can send // it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will @@ -381,6 +382,7 @@ private[spark] object MapOutputTracker { statuses.map { status => if (status == null) { + logError("Missing an output location for shuffle " + shuffleId) throw new MetadataFetchFailedException( shuffleId, reduceId, "Missing an output location for shuffle " + shuffleId) } else { diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 16c5d6648d..e2f13accdf 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -32,7 +32,7 @@ import org.apache.spark.api.python.PythonWorkerFactory import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.metrics.MetricsSystem import org.apache.spark.network.BlockTransferService -import org.apache.spark.network.netty.{NettyBlockTransferService} +import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.network.nio.NioBlockTransferService import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.Serializer diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index c4a8ec2e5e..f1f66d0903 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -186,11 +186,11 @@ private[spark] class Worker( private def retryConnectToMaster() { Utils.tryOrExit { connectionAttemptCount += 1 - logInfo(s"Attempting to connect to master (attempt # $connectionAttemptCount") if (registered) { registrationRetryTimer.foreach(_.cancel()) registrationRetryTimer = None } else if (connectionAttemptCount <= TOTAL_REGISTRATION_RETRIES) { + logInfo(s"Retrying connection to master (attempt # $connectionAttemptCount)") tryRegisterAllMasters() if (connectionAttemptCount == INITIAL_REGISTRATION_RETRIES) { registrationRetryTimer.foreach(_.cancel()) diff 
--git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 2889f59e33..c78e0ffca2 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -78,7 +78,7 @@ private[spark] class Executor( val executorSource = new ExecutorSource(this, executorId) // Initialize Spark environment (using system properties read above) - conf.set("spark.executor.id", "executor." + executorId) + conf.set("spark.executor.id", executorId) private val env = { if (!isLocal) { val port = conf.getInt("spark.executor.port", 0) diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index b083f46533..210a581db4 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -20,16 +20,16 @@ package org.apache.spark.network import java.io.Closeable import java.nio.ByteBuffer -import scala.concurrent.{Await, Future} +import scala.concurrent.{Promise, Await, Future} import scala.concurrent.duration.Duration import org.apache.spark.Logging import org.apache.spark.network.buffer.{NioManagedBuffer, ManagedBuffer} -import org.apache.spark.storage.{BlockId, StorageLevel} -import org.apache.spark.util.Utils +import org.apache.spark.network.shuffle.{ShuffleClient, BlockFetchingListener} +import org.apache.spark.storage.{BlockManagerId, BlockId, StorageLevel} private[spark] -abstract class BlockTransferService extends Closeable with Logging { +abstract class BlockTransferService extends ShuffleClient with Closeable with Logging { /** * Initialize the transfer service by giving it the BlockDataManager that can be used to fetch @@ -60,10 +60,11 @@ abstract class BlockTransferService extends Closeable with Logging { * return a future so the underlying implementation can invoke onBlockFetchSuccess as soon as * the data of a block is fetched, rather than waiting for all blocks to be fetched. */ - def fetchBlocks( - hostName: String, + override def fetchBlocks( + host: String, port: Int, - blockIds: Seq[String], + execId: String, + blockIds: Array[String], listener: BlockFetchingListener): Unit /** @@ -81,43 +82,23 @@ abstract class BlockTransferService extends Closeable with Logging { * * It is also only available after [[init]] is invoked. */ - def fetchBlockSync(hostName: String, port: Int, blockId: String): ManagedBuffer = { + def fetchBlockSync(host: String, port: Int, execId: String, blockId: String): ManagedBuffer = { // A monitor for the thread to wait on. 
- val lock = new Object - @volatile var result: Either[ManagedBuffer, Throwable] = null - fetchBlocks(hostName, port, Seq(blockId), new BlockFetchingListener { - override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = { - lock.synchronized { - result = Right(exception) - lock.notify() + val result = Promise[ManagedBuffer]() + fetchBlocks(host, port, execId, Array(blockId), + new BlockFetchingListener { + override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = { + result.failure(exception) } - } - override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { - lock.synchronized { + override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { val ret = ByteBuffer.allocate(data.size.toInt) ret.put(data.nioByteBuffer()) ret.flip() - result = Left(new NioManagedBuffer(ret)) - lock.notify() + result.success(new NioManagedBuffer(ret)) } - } - }) + }) - // Sleep until result is no longer null - lock.synchronized { - while (result == null) { - try { - lock.wait() - } catch { - case e: InterruptedException => - } - } - } - - result match { - case Left(data) => data - case Right(e) => throw e - } + Await.result(result.future, Duration.Inf) } /** diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockFetcher.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockFetcher.scala deleted file mode 100644 index 8c5ffd8da6..0000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockFetcher.scala +++ /dev/null @@ -1,95 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty - -import java.nio.ByteBuffer -import java.util - -import org.apache.spark.{SparkConf, Logging} -import org.apache.spark.network.BlockFetchingListener -import org.apache.spark.network.netty.NettyMessages._ -import org.apache.spark.serializer.{JavaSerializer, Serializer} -import org.apache.spark.network.buffer.ManagedBuffer -import org.apache.spark.network.client.{RpcResponseCallback, ChunkReceivedCallback, TransportClient} -import org.apache.spark.storage.BlockId -import org.apache.spark.util.Utils - -/** - * Responsible for holding the state for a request for a single set of blocks. This assumes that - * the chunks will be returned in the same order as requested, and that there will be exactly - * one chunk per block. - * - * Upon receipt of any block, the listener will be called back. Upon failure part way through, - * the listener will receive a failure callback for each outstanding block. 
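The fetchBlockSync rewrite above replaces a hand-rolled monitor (a lock object, wait/notify, and a @volatile Either) with a scala.concurrent.Promise. A minimal sketch of the same bridging pattern, where FetchListener and fetchAsync are hypothetical stand-ins for BlockFetchingListener and fetchBlocks:

import scala.concurrent.{Await, Promise}
import scala.concurrent.duration.Duration

// Hypothetical callback interface, modeled on BlockFetchingListener.
trait FetchListener {
  def onSuccess(data: Array[Byte]): Unit
  def onFailure(e: Throwable): Unit
}

// Hypothetical asynchronous fetch; the real call is fetchBlocks(...).
def fetchAsync(id: String, listener: FetchListener): Unit = ???

// Complete a Promise from the callbacks, then block on its Future.
def fetchSync(id: String): Array[Byte] = {
  val result = Promise[Array[Byte]]()
  fetchAsync(id, new FetchListener {
    override def onSuccess(data: Array[Byte]): Unit = result.success(data)
    override def onFailure(e: Throwable): Unit = result.failure(e)
  })
  Await.result(result.future, Duration.Inf)
}

A Promise can be completed only once, which also removes the need for the old code's loop around wait() and its swallowed InterruptedException.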
- */ -class NettyBlockFetcher( - serializer: Serializer, - client: TransportClient, - blockIds: Seq[String], - listener: BlockFetchingListener) - extends Logging { - - require(blockIds.nonEmpty) - - private val ser = serializer.newInstance() - - private var streamHandle: ShuffleStreamHandle = _ - - private val chunkCallback = new ChunkReceivedCallback { - // On receipt of a chunk, pass it upwards as a block. - def onSuccess(chunkIndex: Int, buffer: ManagedBuffer): Unit = Utils.logUncaughtExceptions { - listener.onBlockFetchSuccess(blockIds(chunkIndex), buffer) - } - - // On receipt of a failure, fail every block from chunkIndex onwards. - def onFailure(chunkIndex: Int, e: Throwable): Unit = { - blockIds.drop(chunkIndex).foreach { blockId => - listener.onBlockFetchFailure(blockId, e); - } - } - } - - /** Begins the fetching process, calling the listener with every block fetched. */ - def start(): Unit = { - // Send the RPC to open the given set of blocks. This will return a ShuffleStreamHandle. - client.sendRpc(ser.serialize(OpenBlocks(blockIds.map(BlockId.apply))).array(), - new RpcResponseCallback { - override def onSuccess(response: Array[Byte]): Unit = { - try { - streamHandle = ser.deserialize[ShuffleStreamHandle](ByteBuffer.wrap(response)) - logTrace(s"Successfully opened block set: $streamHandle! Preparing to fetch chunks.") - - // Immediately request all chunks -- we expect that the total size of the request is - // reasonable due to higher level chunking in [[ShuffleBlockFetcherIterator]]. - for (i <- 0 until streamHandle.numChunks) { - client.fetchChunk(streamHandle.streamId, i, chunkCallback) - } - } catch { - case e: Exception => - logError("Failed while starting block fetches", e) - blockIds.foreach(blockId => Utils.tryLog(listener.onBlockFetchFailure(blockId, e))) - } - } - - override def onFailure(e: Throwable): Unit = { - logError("Failed while starting block fetches", e) - blockIds.foreach(blockId => Utils.tryLog(listener.onBlockFetchFailure(blockId, e))) - } - }) - } -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala index 02c657e1d6..1950e7bd63 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala @@ -19,39 +19,41 @@ package org.apache.spark.network.netty import java.nio.ByteBuffer +import scala.collection.JavaConversions._ + import org.apache.spark.Logging import org.apache.spark.network.BlockDataManager +import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} +import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} +import org.apache.spark.network.server.{OneForOneStreamManager, RpcHandler, StreamManager} +import org.apache.spark.network.shuffle.ShuffleStreamHandle import org.apache.spark.serializer.Serializer -import org.apache.spark.network.buffer.{NioManagedBuffer, ManagedBuffer} -import org.apache.spark.network.client.{TransportClient, RpcResponseCallback} -import org.apache.spark.network.server.{DefaultStreamManager, RpcHandler} -import org.apache.spark.storage.{StorageLevel, BlockId} - -import scala.collection.JavaConversions._ +import org.apache.spark.storage.{BlockId, StorageLevel} object NettyMessages { - /** Request to read a set of blocks. Returns [[ShuffleStreamHandle]] to identify the stream. 
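The deleted NettyBlockFetcher's core idea survives in the OneForOneBlockFetcher used by NettyBlockTransferService below: transport-layer chunk i corresponds to blockIds(i), and because chunks arrive in request order, a failure at chunk i fails every block from i onward. A sketch of that index mapping, with hypothetical stand-in types:

// Stand-ins for the transport-layer buffer and listener types.
trait Buf
trait BlockListener {
  def onBlockFetchSuccess(blockId: String, data: Buf): Unit
  def onBlockFetchFailure(blockId: String, e: Throwable): Unit
}

// One-for-one mapping: chunk index i <-> blockIds(i).
class ChunkToBlockAdapter(blockIds: Array[String], listener: BlockListener) {
  def onChunkSuccess(chunkIndex: Int, buffer: Buf): Unit =
    listener.onBlockFetchSuccess(blockIds(chunkIndex), buffer)

  // Everything from chunkIndex onward is still outstanding, so fail it all.
  def onChunkFailure(chunkIndex: Int, e: Throwable): Unit =
    blockIds.drop(chunkIndex).foreach(listener.onBlockFetchFailure(_, e))
}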
*/ case class OpenBlocks(blockIds: Seq[BlockId]) /** Request to upload a block with a certain StorageLevel. Returns nothing (empty byte array). */ case class UploadBlock(blockId: BlockId, blockData: Array[Byte], level: StorageLevel) - - /** Identifier for a fixed number of chunks to read from a stream created by [[OpenBlocks]]. */ - case class ShuffleStreamHandle(streamId: Long, numChunks: Int) } /** * Serves requests to open blocks by simply registering one chunk per block requested. + * Handles opening and uploading arbitrary BlockManager blocks. + * + * Opened blocks are registered with the "one-for-one" strategy, meaning each Transport-layer Chunk + * is equivalent to one Spark-level shuffle block. */ class NettyBlockRpcServer( serializer: Serializer, - streamManager: DefaultStreamManager, blockManager: BlockDataManager) extends RpcHandler with Logging { import NettyMessages._ + private val streamManager = new OneForOneStreamManager() + override def receive( client: TransportClient, messageBytes: Array[Byte], @@ -73,4 +75,6 @@ class NettyBlockRpcServer( responseContext.onSuccess(new Array[Byte](0)) } } + + override def getStreamManager(): StreamManager = streamManager } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index 38a3e94515..ec3000e722 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -17,15 +17,15 @@ package org.apache.spark.network.netty -import scala.concurrent.{Promise, Future} +import scala.concurrent.{Future, Promise} import org.apache.spark.SparkConf import org.apache.spark.network._ import org.apache.spark.network.buffer.ManagedBuffer -import org.apache.spark.network.client.{RpcResponseCallback, TransportClient, TransportClientFactory} -import org.apache.spark.network.netty.NettyMessages.UploadBlock +import org.apache.spark.network.client.{RpcResponseCallback, TransportClientFactory} +import org.apache.spark.network.netty.NettyMessages.{OpenBlocks, UploadBlock} import org.apache.spark.network.server._ -import org.apache.spark.network.util.{ConfigProvider, TransportConf} +import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher} import org.apache.spark.serializer.JavaSerializer import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.util.Utils @@ -37,30 +37,29 @@ class NettyBlockTransferService(conf: SparkConf) extends BlockTransferService { // TODO: Don't use Java serialization, use a more cross-version compatible serialization format. val serializer = new JavaSerializer(conf) - // Create a TransportConfig using SparkConf. 
- private[this] val transportConf = new TransportConf( - new ConfigProvider { override def get(name: String) = conf.get(name) }) - private[this] var transportContext: TransportContext = _ private[this] var server: TransportServer = _ private[this] var clientFactory: TransportClientFactory = _ override def init(blockDataManager: BlockDataManager): Unit = { - val streamManager = new DefaultStreamManager - val rpcHandler = new NettyBlockRpcServer(serializer, streamManager, blockDataManager) - transportContext = new TransportContext(transportConf, streamManager, rpcHandler) + val rpcHandler = new NettyBlockRpcServer(serializer, blockDataManager) + transportContext = new TransportContext(SparkTransportConf.fromSparkConf(conf), rpcHandler) clientFactory = transportContext.createClientFactory() server = transportContext.createServer() + logInfo("Server created on " + server.getPort) } override def fetchBlocks( - hostname: String, + host: String, port: Int, - blockIds: Seq[String], + execId: String, + blockIds: Array[String], listener: BlockFetchingListener): Unit = { + logTrace(s"Fetch blocks from $host:$port (executor id $execId)") try { - val client = clientFactory.createClient(hostname, port) - new NettyBlockFetcher(serializer, client, blockIds, listener).start() + val client = clientFactory.createClient(host, port) + new OneForOneBlockFetcher(client, blockIds.toArray, listener) + .start(OpenBlocks(blockIds.map(BlockId.apply))) } catch { case e: Exception => logError("Exception while beginning fetchBlocks", e) diff --git a/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala new file mode 100644 index 0000000000..9fa4fa77b8 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty + +import org.apache.spark.SparkConf +import org.apache.spark.network.util.{TransportConf, ConfigProvider} + +/** + * Utility for creating a [[TransportConf]] from a [[SparkConf]]. 
+ */ +object SparkTransportConf { + def fromSparkConf(conf: SparkConf): TransportConf = { + new TransportConf(new ConfigProvider { + override def get(name: String): String = conf.get(name) + }) + } +} diff --git a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala index 11793ea92a..f56d165dab 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala @@ -21,6 +21,7 @@ import java.nio.ByteBuffer import org.apache.spark.network._ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} +import org.apache.spark.network.shuffle.BlockFetchingListener import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.util.Utils import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} @@ -79,13 +80,14 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa } override def fetchBlocks( - hostName: String, + host: String, port: Int, - blockIds: Seq[String], + execId: String, + blockIds: Array[String], listener: BlockFetchingListener): Unit = { checkInit() - val cmId = new ConnectionManagerId(hostName, port) + val cmId = new ConnectionManagerId(host, port) val blockMessageArray = new BlockMessageArray(blockIds.map { blockId => BlockMessage.fromGetBlock(GetBlock(BlockId(blockId))) }) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index f81fa6d808..af17b5d5d2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -124,6 +124,9 @@ class DAGScheduler( /** If enabled, we may run certain actions like take() and first() locally. */ private val localExecutionEnabled = sc.getConf.getBoolean("spark.localExecution.enabled", false) + /** If enabled, FetchFailed will not cause stage retry, in order to surface the problem. */ + private val disallowStageRetryForTest = sc.getConf.getBoolean("spark.test.noStageRetry", false) + private def initializeEventProcessActor() { // blocking the thread until supervisor is started, which ensures eventProcessActor is // not null before any job is submitted @@ -1064,7 +1067,9 @@ class DAGScheduler( runningStages -= failedStage } - if (failedStages.isEmpty && eventProcessActor != null) { + if (disallowStageRetryForTest) { + abortStage(failedStage, "Fetch failure will not retry stage due to testing config") + } else if (failedStages.isEmpty && eventProcessActor != null) { // Don't schedule an event to resubmit failed stages if failed isn't empty, because // in that case the event will already have been scheduled. eventProcessActor may be // null during unit tests. @@ -1086,7 +1091,7 @@ class DAGScheduler( // TODO: mark the executor as failed only if there were lots of fetch failures on it if (bmAddress != null) { - handleExecutorLost(bmAddress.executorId, Some(task.epoch)) + handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch)) } case ExceptionFailure(className, description, stackTrace, metrics) => @@ -1106,25 +1111,35 @@ class DAGScheduler( * Responds to an executor being lost. This is called inside the event loop, so it assumes it can * modify the scheduler's internal state. Use executorLost() to post a loss event from outside. 
* + * We will also assume that we've lost all shuffle blocks associated with the executor if the + * executor serves its own blocks (i.e., we're not using external shuffle) OR a FetchFailed + * occurred, in which case we presume all shuffle data related to this executor to be lost. + * * Optionally the epoch during which the failure was caught can be passed to avoid allowing * stray fetch failures from possibly retriggering the detection of a node as lost. */ - private[scheduler] def handleExecutorLost(execId: String, maybeEpoch: Option[Long] = None) { + private[scheduler] def handleExecutorLost( + execId: String, + fetchFailed: Boolean, + maybeEpoch: Option[Long] = None) { val currentEpoch = maybeEpoch.getOrElse(mapOutputTracker.getEpoch) if (!failedEpoch.contains(execId) || failedEpoch(execId) < currentEpoch) { failedEpoch(execId) = currentEpoch logInfo("Executor lost: %s (epoch %d)".format(execId, currentEpoch)) blockManagerMaster.removeExecutor(execId) - // TODO: This will be really slow if we keep accumulating shuffle map stages - for ((shuffleId, stage) <- shuffleToMapStage) { - stage.removeOutputsOnExecutor(execId) - val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray - mapOutputTracker.registerMapOutputs(shuffleId, locs, changeEpoch = true) - } - if (shuffleToMapStage.isEmpty) { - mapOutputTracker.incrementEpoch() + + if (!env.blockManager.externalShuffleServiceEnabled || fetchFailed) { + // TODO: This will be really slow if we keep accumulating shuffle map stages + for ((shuffleId, stage) <- shuffleToMapStage) { + stage.removeOutputsOnExecutor(execId) + val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray + mapOutputTracker.registerMapOutputs(shuffleId, locs, changeEpoch = true) + } + if (shuffleToMapStage.isEmpty) { + mapOutputTracker.incrementEpoch() + } + clearCacheLocs() } - clearCacheLocs() } else { logDebug("Additional executor lost message for " + execId + "(epoch " + currentEpoch + ")") @@ -1382,7 +1397,7 @@ private[scheduler] class DAGSchedulerEventProcessActor(dagScheduler: DAGSchedule dagScheduler.handleExecutorAdded(execId, host) case ExecutorLost(execId) => - dagScheduler.handleExecutorLost(execId) + dagScheduler.handleExecutorLost(execId, fetchFailed = false) case BeginEvent(task, taskInfo) => dagScheduler.handleBeginEvent(task, taskInfo) diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index 071568cdfb..cc13f57a49 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -102,6 +102,11 @@ private[spark] class Stage( } } + /** + * Removes all shuffle outputs associated with this executor. Note that this will also remove + * outputs which are served by an external shuffle server (if one exists), as they are still + * registered with this execId. 
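The reworked handleExecutorLost reduces to a single predicate: an executor's shuffle outputs are presumed lost unless an external service serves them and the loss was not signaled by a FetchFailed. A condensed sketch (names hypothetical):

// Condensation of the branch added above:
// !externalShuffleServiceEnabled || fetchFailed
def shuffleOutputsLost(externalShuffleEnabled: Boolean, fetchFailed: Boolean): Boolean =
  !externalShuffleEnabled || fetchFailed

// Outputs survive executor loss only in the (enabled, no-fetch-failure) case.
assert(shuffleOutputsLost(externalShuffleEnabled = false, fetchFailed = false))
assert(!shuffleOutputsLost(externalShuffleEnabled = true, fetchFailed = false))
assert(shuffleOutputsLost(externalShuffleEnabled = true, fetchFailed = true))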
+ */ def removeOutputsOnExecutor(execId: String) { var becameUnavailable = false for (partition <- 0 until numPartitions) { @@ -131,4 +136,9 @@ private[spark] class Stage( override def toString = "Stage " + id override def hashCode(): Int = id + + override def equals(other: Any): Boolean = other match { + case stage: Stage => stage != null && stage.id == id + case _ => false + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index a6c23fc85a..376821f89c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -687,10 +687,11 @@ private[spark] class TaskSetManager( addPendingTask(index, readding=true) } - // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage. + // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage, + // and we are not using an external shuffle server which could serve the shuffle outputs. // The reason is the next stage wouldn't be able to fetch the data from this dead executor // so we would need to rerun these tasks on other executors. - if (tasks(0).isInstanceOf[ShuffleMapTask]) { + if (tasks(0).isInstanceOf[ShuffleMapTask] && !env.blockManager.externalShuffleServiceEnabled) { for ((tid, info) <- taskInfos if info.executorId == execId) { val index = taskInfos(tid).index if (successful(index)) { diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala index 1fb5b2c454..f03e8e4bf1 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala @@ -62,7 +62,8 @@ private[spark] trait ShuffleWriterGroup { * each block stored in each file. In order to find the location of a shuffle block, we search the * files within a ShuffleFileGroup associated with the block's reducer. */ - +// Note: Changes to the format in this file should be kept in sync with +// org.apache.spark.network.shuffle.StandaloneShuffleBlockManager#getHashBasedShuffleBlockData(). private[spark] class FileShuffleBlockManager(conf: SparkConf) extends ShuffleBlockManager with Logging { diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala index e9805c9c13..a48f0c9ece 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala @@ -35,6 +35,8 @@ import org.apache.spark.storage._ * as the filename postfix for data file, and ".index" as the filename postfix for index file. * */ +// Note: Changes to the format in this file should be kept in sync with +// org.apache.spark.network.shuffle.StandaloneShuffleBlockManager#getSortBasedShuffleBlockData(). 
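Stage's new equals pairs with its existing id-based hashCode, keeping the two consistent as the Object contract requires. A sketch of the idiom; note that a null never matches a typed pattern such as case stage: Stage, so the stage != null guard above is purely defensive:

// Identity defined solely by id; equal objects must hash equally.
class Node(val id: Int) {
  override def hashCode(): Int = id
  override def equals(other: Any): Boolean = other match {
    case n: Node => n.id == id  // null falls through to the next case
    case _ => false
  }
}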
private[spark] class IndexShuffleBlockManager extends ShuffleBlockManager { diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala index 6cf9305977..f49917b7fe 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala @@ -74,7 +74,7 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { val blockFetcherItr = new ShuffleBlockFetcherIterator( context, - SparkEnv.get.blockTransferService, + SparkEnv.get.blockManager.shuffleClient, blockManager, blocksByAddress, serializer, diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala index 746ed33b54..183a30373b 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala @@ -107,7 +107,7 @@ private[spark] class HashShuffleWriter[K, V]( writer.commitAndClose() writer.fileSegment().length } - MapStatus(blockManager.blockManagerId, sizes) + MapStatus(blockManager.shuffleServerId, sizes) } private def revertWrites(): Unit = { diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 927481b72c..d75f9d7311 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -70,7 +70,7 @@ private[spark] class SortShuffleWriter[K, V, C]( val partitionLengths = sorter.writePartitionedFile(blockId, context, outputFile) shuffleBlockManager.writeIndexFile(dep.shuffleId, mapId, partitionLengths) - mapStatus = MapStatus(blockManager.blockManagerId, partitionLengths) + mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths) } /** Close this writer, passing along whether the map completed */ diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala index 8df5ec6bde..1f012941c8 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala @@ -53,6 +53,8 @@ case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockId { def name = "rdd_" + rddId + "_" + splitIndex } +// Format of the shuffle block ids (including data and index) should be kept in sync with +// org.apache.spark.network.shuffle.StandaloneShuffleBlockManager#getBlockData(). 
@DeveloperApi case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId { def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 58510d7232..1f8de28961 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -21,9 +21,9 @@ import java.io.{BufferedOutputStream, ByteArrayOutputStream, File, InputStream, import java.nio.{ByteBuffer, MappedByteBuffer} import scala.collection.mutable.{ArrayBuffer, HashMap} +import scala.concurrent.{Await, Future} import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.duration._ -import scala.concurrent.{Await, Future} import scala.util.Random import akka.actor.{ActorSystem, Props} @@ -34,8 +34,13 @@ import org.apache.spark.executor._ import org.apache.spark.io.CompressionCodec import org.apache.spark.network._ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} +import org.apache.spark.network.netty.{SparkTransportConf, NettyBlockTransferService} +import org.apache.spark.network.shuffle.{ExecutorShuffleInfo, ExternalShuffleClient} +import org.apache.spark.network.util.{ConfigProvider, TransportConf} import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.ShuffleManager +import org.apache.spark.shuffle.hash.HashShuffleManager +import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.util._ private[spark] sealed trait BlockValues @@ -85,9 +90,38 @@ private[spark] class BlockManager( new TachyonStore(this, tachyonBlockManager) } + private[spark] + val externalShuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false) + private val externalShuffleServicePort = conf.getInt("spark.shuffle.service.port", 7337) + // Check that we're not using external shuffle service with consolidated shuffle files. + if (externalShuffleServiceEnabled + && conf.getBoolean("spark.shuffle.consolidateFiles", false) + && shuffleManager.isInstanceOf[HashShuffleManager]) { + throw new UnsupportedOperationException("Cannot use external shuffle service with consolidated" + + " shuffle files in hash-based shuffle. Please disable spark.shuffle.consolidateFiles or " + + "switch to sort-based shuffle.") + } + val blockManagerId = BlockManagerId( executorId, blockTransferService.hostName, blockTransferService.port) + // Address of the server that serves this executor's shuffle files. This is either an external + // service, or just our own Executor's BlockManager. + private[spark] val shuffleServerId = if (externalShuffleServiceEnabled) { + BlockManagerId(executorId, blockTransferService.hostName, externalShuffleServicePort) + } else { + blockManagerId + } + + // Client to read other executors' shuffle files. This is either an external service, or just the + // standard BlockTransferService to directly connect to other Executors. 
+ private[spark] val shuffleClient = if (externalShuffleServiceEnabled) { + val appId = conf.get("spark.app.id", "unknown-app-id") + new ExternalShuffleClient(SparkTransportConf.fromSparkConf(conf), appId) + } else { + blockTransferService + } + // Whether to compress broadcast variables that are stored private val compressBroadcast = conf.getBoolean("spark.broadcast.compress", true) // Whether to compress shuffle output that are stored @@ -143,10 +177,40 @@ private[spark] class BlockManager( /** * Initialize the BlockManager. Register to the BlockManagerMaster, and start the - * BlockManagerWorker actor. + * BlockManagerWorker actor. Additionally registers with a local shuffle service if configured. */ private def initialize(): Unit = { master.registerBlockManager(blockManagerId, maxMemory, slaveActor) + + // Register Executors' configuration with the local shuffle service, if one should exist. + if (externalShuffleServiceEnabled && !blockManagerId.isDriver) { + registerWithExternalShuffleServer() + } + } + + private def registerWithExternalShuffleServer() { + logInfo("Registering executor with local external shuffle service.") + val shuffleConfig = new ExecutorShuffleInfo( + diskBlockManager.localDirs.map(_.toString), + diskBlockManager.subDirsPerLocalDir, + shuffleManager.getClass.getName) + + val MAX_ATTEMPTS = 3 + val SLEEP_TIME_SECS = 5 + + for (i <- 1 to MAX_ATTEMPTS) { + try { + // Synchronous and will throw an exception if we cannot connect. + shuffleClient.asInstanceOf[ExternalShuffleClient].registerWithShuffleServer( + shuffleServerId.host, shuffleServerId.port, shuffleServerId.executorId, shuffleConfig) + return + } catch { + case e: Exception if i < MAX_ATTEMPTS => + logError(s"Failed to connect to external shuffle server, will retry ${MAX_ATTEMPTS - i}" + + s" more times after waiting $SLEEP_TIME_SECS seconds...", e) + Thread.sleep(SLEEP_TIME_SECS * 1000) + } + } + } /** @@ -506,7 +571,7 @@ private[spark] class BlockManager( for (loc <- locations) { logDebug(s"Getting remote block $blockId from $loc") val data = blockTransferService.fetchBlockSync( - loc.host, loc.port, blockId.toString).nioByteBuffer() + loc.host, loc.port, loc.executorId, blockId.toString).nioByteBuffer() if (data != null) { if (asBlockResult) { diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 99e925328a..58fba54710 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -38,12 +38,13 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon extends Logging { private val MAX_DIR_CREATION_ATTEMPTS: Int = 10 - private val subDirsPerLocalDir = blockManager.conf.getInt("spark.diskStore.subDirectories", 64) + private[spark] + val subDirsPerLocalDir = blockManager.conf.getInt("spark.diskStore.subDirectories", 64) /* Create one local directory for each path mentioned in spark.local.dir; then, inside this * directory, create multiple subdirectories that we will hash files into, in order to avoid * having really large inodes at the top level. 
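registerWithExternalShuffleServer retries a synchronous registration a bounded number of times, sleeping between attempts and letting the final failure propagate. A generic sketch of that loop, with register as a hypothetical stand-in for the registerWithShuffleServer call:

// Bounded retry: the guard excludes the last attempt, so its exception
// propagates; earlier failures sleep and try again.
def retryBounded[T](maxAttempts: Int, sleepSecs: Int)(register: () => T): T = {
  for (i <- 1 to maxAttempts) {
    try {
      return register()
    } catch {
      case _: Exception if i < maxAttempts =>
        Thread.sleep(sleepSecs * 1000L)
    }
  }
  throw new IllegalStateException("unreachable when maxAttempts >= 1")
}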
*/ - val localDirs: Array[File] = createLocalDirs(conf) + private[spark] val localDirs: Array[File] = createLocalDirs(conf) if (localDirs.isEmpty) { logError("Failed to create any local dir.") System.exit(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR) @@ -52,6 +53,9 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon addShutdownHook() + /** Looks up a file by hashing it into one of our local subdirectories. */ + // This method should be kept in sync with + // org.apache.spark.network.shuffle.StandaloneShuffleBlockManager#getFile(). def getFile(filename: String): File = { // Figure out which local directory it hashes to, and which subdirectory in that val hash = Utils.nonNegativeHash(filename) @@ -159,13 +163,16 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon /** Cleanup local dirs and stop shuffle sender. */ private[spark] def stop() { - localDirs.foreach { localDir => - if (localDir.isDirectory() && localDir.exists()) { - try { - if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir) - } catch { - case e: Exception => - logError(s"Exception while deleting local spark dir: $localDir", e) + // Only perform cleanup if an external service is not serving our shuffle files. + if (!blockManager.externalShuffleServiceEnabled) { + localDirs.foreach { localDir => + if (localDir.isDirectory() && localDir.exists()) { + try { + if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir) + } catch { + case e: Exception => + logError(s"Exception while deleting local spark dir: $localDir", e) + } } } } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 0d6f3bf003..ee89c7e521 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -22,7 +22,8 @@ import java.util.concurrent.LinkedBlockingQueue import scala.collection.mutable.{ArrayBuffer, HashSet, Queue} import org.apache.spark.{Logging, TaskContext} -import org.apache.spark.network.{BlockFetchingListener, BlockTransferService} +import org.apache.spark.network.BlockTransferService +import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient} import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.serializer.Serializer import org.apache.spark.util.{CompletionIterator, Utils} @@ -38,8 +39,8 @@ import org.apache.spark.util.{CompletionIterator, Utils} * using too much memory. * * @param context [[TaskContext]], used for metrics update - * @param blockTransferService [[BlockTransferService]] for fetching remote blocks - * @param blockManager [[BlockManager]] for reading local blocks + * @param shuffleClient [[ShuffleClient]] for fetching remote blocks + * @param blockManager [[BlockManager]] for reading local blocks * @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]]. * For each block we also require the size (in bytes as a long field) in * order to throttle the memory usage. 
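The keep-in-sync notes exist because the external shuffle service re-implements this directory hashing in Java and must resolve the identical path for a given file name. A rough sketch of the lookup getFile performs, assuming math.abs as a stand-in for Utils.nonNegativeHash (the real helper also handles edge cases such as Int.MinValue):

import java.io.File

// Hash a file name to (local dir, sub dir); the shuffle service mirrors this.
def resolveFile(localDirs: Array[File], subDirsPerLocalDir: Int, filename: String): File = {
  val hash = math.abs(filename.hashCode)  // nonNegativeHash stand-in
  val dirId = hash % localDirs.length
  val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
  new File(new File(localDirs(dirId), "%02x".format(subDirId)), filename)
}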
@@ -49,7 +50,7 @@ import org.apache.spark.util.{CompletionIterator, Utils} private[spark] final class ShuffleBlockFetcherIterator( context: TaskContext, - blockTransferService: BlockTransferService, + shuffleClient: ShuffleClient, blockManager: BlockManager, blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], serializer: Serializer, @@ -140,7 +141,8 @@ final class ShuffleBlockFetcherIterator( val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap val blockIds = req.blocks.map(_._1.toString) - blockTransferService.fetchBlocks(req.address.host, req.address.port, blockIds, + val address = req.address + shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray, new BlockFetchingListener { override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = { // Only add the buffer to results queue if the iterator is not zombie, @@ -179,7 +181,7 @@ final class ShuffleBlockFetcherIterator( var totalBlocks = 0 for ((address, blockInfos) <- blocksByAddress) { totalBlocks += blockInfos.size - if (address == blockManager.blockManagerId) { + if (address.executorId == blockManager.blockManagerId.executorId) { // Filter out zero-sized blocks localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1) numBlocksToFetch += localBlocks.size diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 063895d3c5..68d378f3a2 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1237,6 +1237,8 @@ private[spark] object Utils extends Logging { } // Handles idiosyncrasies with hash (add more as required) + // This method should be kept in sync with + // org.apache.spark.network.util.JavaUtils#nonNegativeHash(). def nonNegativeHash(obj: AnyRef): Int = { // Required ? diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 81b64c36dd..429199f207 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -202,7 +202,8 @@ class DistributedSuite extends FunSuite with Matchers with BeforeAndAfter val blockManager = SparkEnv.get.blockManager val blockTransfer = SparkEnv.get.blockTransferService blockManager.master.getLocations(blockId).foreach { cmId => - val bytes = blockTransfer.fetchBlockSync(cmId.host, cmId.port, blockId.toString) + val bytes = blockTransfer.fetchBlockSync(cmId.host, cmId.port, cmId.executorId, + blockId.toString) val deserialized = blockManager.dataDeserialize(blockId, bytes.nioByteBuffer()) .asInstanceOf[Iterator[Int]].toList assert(deserialized === (1 to 100).toList) diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala new file mode 100644 index 0000000000..792b9cd8b6 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import java.util.concurrent.atomic.AtomicInteger + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.SparkContext._ +import org.apache.spark.network.TransportContext +import org.apache.spark.network.netty.SparkTransportConf +import org.apache.spark.network.server.TransportServer +import org.apache.spark.network.shuffle.{ExternalShuffleBlockHandler, ExternalShuffleClient} + +/** + * This suite creates an external shuffle server and routes all shuffle fetches through it. + * Note that failures in this suite may arise due to changes in Spark that invalidate expectations + * set up in [[ExternalShuffleBlockHandler]], such as changing the format of shuffle files or how + * we hash files into folders. + */ +class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll { + var server: TransportServer = _ + var rpcHandler: ExternalShuffleBlockHandler = _ + + override def beforeAll() { + val transportConf = SparkTransportConf.fromSparkConf(conf) + rpcHandler = new ExternalShuffleBlockHandler() + val transportContext = new TransportContext(transportConf, rpcHandler) + server = transportContext.createServer() + + conf.set("spark.shuffle.manager", "sort") + conf.set("spark.shuffle.service.enabled", "true") + conf.set("spark.shuffle.service.port", server.getPort.toString) + } + + override def afterAll() { + server.close() + } + + // This test ensures that the external shuffle service is actually in use for the other tests. + test("using external shuffle service") { + sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + sc.env.blockManager.externalShuffleServiceEnabled should equal(true) + sc.env.blockManager.shuffleClient.getClass should equal(classOf[ExternalShuffleClient]) + + val rdd = sc.parallelize(0 until 1000, 10).map(i => (i, 1)).reduceByKey(_ + _) + + rdd.count() + rdd.count() + + // Invalidate the registered executors, disallowing access to their shuffle blocks. + rpcHandler.clearRegisteredExecutors() + + // Now Spark will receive FetchFailed, and not retry the stage due to "spark.test.noStageRetry" + // being set. + val e = intercept[SparkException] { + rdd.count() + } + e.getMessage should include ("Fetch failure will not retry stage due to testing config") + } +} diff --git a/core/src/test/scala/org/apache/spark/HashShuffleSuite.scala b/core/src/test/scala/org/apache/spark/HashShuffleSuite.scala index 2acc02a54f..19180e88eb 100644 --- a/core/src/test/scala/org/apache/spark/HashShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/HashShuffleSuite.scala @@ -24,10 +24,6 @@ class HashShuffleSuite extends ShuffleSuite with BeforeAndAfterAll { // This test suite should run all tests in ShuffleSuite with hash-based shuffle. 
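HashShuffleSuite below, together with ShuffleNettySuite and SortShuffleSuite, now scopes each override to the shared SparkConf instead of mutating JVM-global system properties, so one suite can no longer leak settings into the next and no afterAll cleanup is needed. A sketch of the pattern with hypothetical suite names:

import org.apache.spark.SparkConf
import org.scalatest.{BeforeAndAfterAll, FunSuite}

// A base suite owning the conf, as ShuffleSuite does below.
abstract class BaseVariantSuite extends FunSuite with BeforeAndAfterAll {
  val conf = new SparkConf(loadDefaults = false)
}

class SortVariantSuite extends BaseVariantSuite {
  override def beforeAll() {
    conf.set("spark.shuffle.manager", "sort")  // scoped to this suite's conf
  }
}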
override def beforeAll() { - System.setProperty("spark.shuffle.manager", "hash") - } - - override def afterAll() { - System.clearProperty("spark.shuffle.manager") + conf.set("spark.shuffle.manager", "hash") } } diff --git a/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala b/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala index 840d8273cb..d78c99c2e1 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala @@ -24,10 +24,6 @@ class ShuffleNettySuite extends ShuffleSuite with BeforeAndAfterAll { // This test suite should run all tests in ShuffleSuite with Netty shuffle mode. override def beforeAll() { - System.setProperty("spark.shuffle.blockTransferService", "netty") - } - - override def afterAll() { - System.clearProperty("spark.shuffle.blockTransferService") + conf.set("spark.shuffle.blockTransferService", "netty") } } diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 2bdd84ce69..cda942e15a 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -30,10 +30,14 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex val conf = new SparkConf(loadDefaults = false) + // Ensure that the DAGScheduler doesn't retry stages whose fetches fail, so that we accurately + // test that the shuffle works (rather than retrying until all blocks are local to one Executor). + conf.set("spark.test.noStageRetry", "true") + test("groupByKey without compression") { try { System.setProperty("spark.shuffle.compress", "false") - sc = new SparkContext("local", "test") + sc = new SparkContext("local", "test", conf) val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)), 4) val groups = pairs.groupByKey(4).collect() assert(groups.size === 2) @@ -47,7 +51,7 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex } test("shuffle non-zero block size") { - sc = new SparkContext("local-cluster[2,1,512]", "test") + sc = new SparkContext("local-cluster[2,1,512]", "test", conf) val NUM_BLOCKS = 3 val a = sc.parallelize(1 to 10, 2) @@ -73,7 +77,7 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex test("shuffle serializer") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[2,1,512]", "test") + sc = new SparkContext("local-cluster[2,1,512]", "test", conf) val a = sc.parallelize(1 to 10, 2) val b = a.map { x => (x, new NonJavaSerializableClass(x * 2)) @@ -89,7 +93,7 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex test("zero sized blocks") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[2,1,512]", "test") + sc = new SparkContext("local-cluster[2,1,512]", "test", conf) // 10 partitions from 4 keys val NUM_BLOCKS = 10 @@ -116,7 +120,7 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex test("zero sized blocks without kryo") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[2,1,512]", "test") + sc = new SparkContext("local-cluster[2,1,512]", "test", conf) // 10 partitions from 4 keys val NUM_BLOCKS = 10 @@ -141,7 +145,7 @@ abstract class ShuffleSuite extends 
FunSuite with Matchers with LocalSparkContex test("shuffle on mutable pairs") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[2,1,512]", "test") + sc = new SparkContext("local-cluster[2,1,512]", "test", conf) def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2) val data = Array(p(1, 1), p(1, 2), p(1, 3), p(2, 1)) val pairs: RDD[MutablePair[Int, Int]] = sc.parallelize(data, 2) @@ -154,7 +158,7 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex test("sorting on mutable pairs") { // This is not in SortingSuite because of the local cluster setup. // Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[2,1,512]", "test") + sc = new SparkContext("local-cluster[2,1,512]", "test", conf) def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2) val data = Array(p(1, 11), p(3, 33), p(100, 100), p(2, 22)) val pairs: RDD[MutablePair[Int, Int]] = sc.parallelize(data, 2) @@ -168,7 +172,7 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex test("cogroup using mutable pairs") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[2,1,512]", "test") + sc = new SparkContext("local-cluster[2,1,512]", "test", conf) def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2) val data1 = Seq(p(1, 1), p(1, 2), p(1, 3), p(2, 1)) val data2 = Seq(p(1, "11"), p(1, "12"), p(2, "22"), p(3, "3")) @@ -195,7 +199,7 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex test("subtract mutable pairs") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[2,1,512]", "test") + sc = new SparkContext("local-cluster[2,1,512]", "test", conf) def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2) val data1 = Seq(p(1, 1), p(1, 2), p(1, 3), p(2, 1), p(3, 33)) val data2 = Seq(p(1, "11"), p(1, "12"), p(2, "22")) @@ -209,11 +213,8 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex test("sort with Java non serializable class - Kryo") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks - val conf = new SparkConf() - .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") - .setAppName("test") - .setMaster("local-cluster[2,1,512]") - sc = new SparkContext(conf) + val myConf = conf.clone().set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + sc = new SparkContext("local-cluster[2,1,512]", "test", myConf) val a = sc.parallelize(1 to 10, 2) val b = a.map { x => (new NonJavaSerializableClass(x), x) @@ -226,10 +227,7 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex test("sort with Java non serializable class - Java") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks - val conf = new SparkConf() - .setAppName("test") - .setMaster("local-cluster[2,1,512]") - sc = new SparkContext(conf) + sc = new SparkContext("local-cluster[2,1,512]", "test", conf) val a = sc.parallelize(1 to 10, 2) val b = a.map { x => (new NonJavaSerializableClass(x), x) diff --git a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala index 639e56c488..63358172ea 100644 --- a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala 
+++ b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala @@ -24,10 +24,6 @@ class SortShuffleSuite extends ShuffleSuite with BeforeAndAfterAll { // This test suite should run all tests in ShuffleSuite with sort-based shuffle. override def beforeAll() { - System.setProperty("spark.shuffle.manager", "sort") - } - - override def afterAll() { - System.clearProperty("spark.shuffle.manager") + conf.set("spark.shuffle.manager", "sort") } } diff --git a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala index 3925f0ccbd..bbdc9568a6 100644 --- a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala @@ -121,7 +121,7 @@ class MetricsSystemSuite extends FunSuite with BeforeAndAfter with PrivateMethod } val appId = "testId" - val executorId = "executor.1" + val executorId = "1" conf.set("spark.app.id", appId) conf.set("spark.executor.id", executorId) @@ -138,7 +138,7 @@ class MetricsSystemSuite extends FunSuite with BeforeAndAfter with PrivateMethod override val metricRegistry = new MetricRegistry() } - val executorId = "executor.1" + val executorId = "1" conf.set("spark.executor.id", executorId) val instanceName = "executor" diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 4e502cf65e..28f766570e 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -21,22 +21,19 @@ import java.util.concurrent.Semaphore import scala.concurrent.future import scala.concurrent.ExecutionContext.Implicits.global -import org.apache.spark.{TaskContextImpl, TaskContext} -import org.apache.spark.network.{BlockFetchingListener, BlockTransferService} -import org.mockito.Mockito._ import org.mockito.Matchers.{any, eq => meq} +import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer - import org.scalatest.FunSuite -import org.apache.spark.{SparkConf, TaskContext} +import org.apache.spark.{SparkConf, TaskContextImpl} import org.apache.spark.network._ import org.apache.spark.network.buffer.ManagedBuffer +import org.apache.spark.network.shuffle.BlockFetchingListener import org.apache.spark.serializer.TestSerializer - class ShuffleBlockFetcherIteratorSuite extends FunSuite { // Some of the tests are quite tricky because we are testing the cleanup behavior // in the presence of faults. @@ -44,10 +41,10 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite { /** Creates a mock [[BlockTransferService]] that returns data from the given map. 
*/ private def createMockTransfer(data: Map[BlockId, ManagedBuffer]): BlockTransferService = { val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any())).thenAnswer(new Answer[Unit] { + when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { - val blocks = invocation.getArguments()(2).asInstanceOf[Seq[String]] - val listener = invocation.getArguments()(3).asInstanceOf[BlockFetchingListener] + val blocks = invocation.getArguments()(3).asInstanceOf[Array[String]] + val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] for (blockId <- blocks) { if (data.contains(BlockId(blockId))) { @@ -118,7 +115,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite { // 3 local blocks, and 2 remote blocks // (but from the same block manager so one call to fetchBlocks) verify(blockManager, times(3)).getBlockData(any()) - verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any()) + verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any()) } test("release current unexhausted buffer in case the task completes early") { @@ -138,9 +135,9 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite { val sem = new Semaphore(0) val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any())).thenAnswer(new Answer[Unit] { + when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { - val listener = invocation.getArguments()(3).asInstanceOf[BlockFetchingListener] + val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] future { // Return the first two blocks, and wait till task completion before returning the 3rd one listener.onBlockFetchSuccess( @@ -201,9 +198,9 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite { val sem = new Semaphore(0) val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any())).thenAnswer(new Answer[Unit] { + when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { - val listener = invocation.getArguments()(3).asInstanceOf[BlockFetchingListener] + val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] future { // Return the first block, and then fail. listener.onBlockFetchSuccess( diff --git a/network/common/pom.xml b/network/common/pom.xml index a33e44b63d..ea887148d9 100644 --- a/network/common/pom.xml +++ b/network/common/pom.xml @@ -85,9 +85,25 @@ <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory> <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory> <plugins> + <!-- Create a test-jar so network-shuffle can depend on our test utilities. 
--> <plugin> - <groupId>org.scalatest</groupId> - <artifactId>scalatest-maven-plugin</artifactId> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-jar-plugin</artifactId> + <version>2.2</version> + <executions> + <execution> + <goals> + <goal>test-jar</goal> + </goals> + </execution> + <execution> + <id>test-jar-on-test-compile</id> + <phase>test-compile</phase> + <goals> + <goal>test-jar</goal> + </goals> + </execution> + </executions> </plugin> </plugins> </build> diff --git a/network/common/src/main/java/org/apache/spark/network/TransportContext.java b/network/common/src/main/java/org/apache/spark/network/TransportContext.java index 854aa6685f..a271841e4e 100644 --- a/network/common/src/main/java/org/apache/spark/network/TransportContext.java +++ b/network/common/src/main/java/org/apache/spark/network/TransportContext.java @@ -52,15 +52,13 @@ public class TransportContext { private final Logger logger = LoggerFactory.getLogger(TransportContext.class); private final TransportConf conf; - private final StreamManager streamManager; private final RpcHandler rpcHandler; private final MessageEncoder encoder; private final MessageDecoder decoder; - public TransportContext(TransportConf conf, StreamManager streamManager, RpcHandler rpcHandler) { + public TransportContext(TransportConf conf, RpcHandler rpcHandler) { this.conf = conf; - this.streamManager = streamManager; this.rpcHandler = rpcHandler; this.encoder = new MessageEncoder(); this.decoder = new MessageDecoder(); @@ -70,8 +68,14 @@ public class TransportContext { return new TransportClientFactory(this); } + /** Creates a server which will attempt to bind to a specific port. */ + public TransportServer createServer(int port) { + return new TransportServer(this, port); + } + + /** Creates a new server, binding to any available ephemeral port.
*/ public TransportServer createServer() { - return new TransportServer(this); + return new TransportServer(this, 0); } /** @@ -109,7 +113,7 @@ public class TransportContext { TransportResponseHandler responseHandler = new TransportResponseHandler(channel); TransportClient client = new TransportClient(channel, responseHandler); TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client, - streamManager, rpcHandler); + rpcHandler); return new TransportChannelHandler(client, responseHandler, requestHandler); } diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java index b1732fcde2..01c143fff4 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -19,9 +19,13 @@ package org.apache.spark.network.client; import java.io.Closeable; import java.util.UUID; +import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import com.google.common.base.Preconditions; +import com.google.common.base.Throwables; +import com.google.common.util.concurrent.SettableFuture; import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; @@ -129,7 +133,7 @@ public class TransportClient implements Closeable { final long startTime = System.currentTimeMillis(); logger.trace("Sending RPC to {}", serverAddr); - final long requestId = UUID.randomUUID().getLeastSignificantBits(); + final long requestId = Math.abs(UUID.randomUUID().getLeastSignificantBits()); handler.addRpcRequest(requestId, callback); channel.writeAndFlush(new RpcRequest(requestId, message)).addListener( @@ -151,6 +155,32 @@ public class TransportClient implements Closeable { }); } + /** + * Synchronously sends an opaque message to the RpcHandler on the server-side, waiting for up to + * a specified timeout for a response. + */ + public byte[] sendRpcSync(byte[] message, long timeoutMs) { + final SettableFuture<byte[]> result = SettableFuture.create(); + + sendRpc(message, new RpcResponseCallback() { + @Override + public void onSuccess(byte[] response) { + result.set(response); + } + + @Override + public void onFailure(Throwable e) { + result.setException(e); + } + }); + + try { + return result.get(timeoutMs, TimeUnit.MILLISECONDS); + } catch (Exception e) { + throw Throwables.propagate(e); + } + } + @Override public void close() { // close is a local operation and should finish with milliseconds; timeout just to be safe diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java index 10eb9ef7a0..e7fa4f6bf3 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java @@ -78,15 +78,17 @@ public class TransportClientFactory implements Closeable { * * Concurrency: This method is safe to call from multiple threads. */ - public TransportClient createClient(String remoteHost, int remotePort) throws TimeoutException { + public TransportClient createClient(String remoteHost, int remotePort) { // Get connection from the connection pool first. // If it is not found or not active, create a new one. 
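The pieces above compose into a compact client/server setup: a TransportContext built from a conf and an RpcHandler, a server bound to an ephemeral port, and a blocking RPC via sendRpcSync. A hedged end-to-end sketch using only APIs from this patch; the echo handler and payload are illustrative, and error handling is omitted:

```java
import org.apache.spark.network.TransportContext;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientFactory;
import org.apache.spark.network.server.OneForOneStreamManager;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.StreamManager;
import org.apache.spark.network.server.TransportServer;
import org.apache.spark.network.util.SystemPropertyConfigProvider;
import org.apache.spark.network.util.TransportConf;

class EchoRpcExample {
  public static void main(String[] args) {
    TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());
    RpcHandler echoHandler = new RpcHandler() {
      @Override
      public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
        callback.onSuccess(message);  // echo the payload straight back
      }

      @Override
      public StreamManager getStreamManager() { return new OneForOneStreamManager(); }
    };

    TransportContext context = new TransportContext(conf, echoHandler);
    TransportServer server = context.createServer();        // binds an ephemeral port
    TransportClientFactory factory = context.createClientFactory();

    TransportClient client = factory.createClient("localhost", server.getPort());
    byte[] reply = client.sendRpcSync(new byte[] { 42 }, 5000 /* timeoutMs */);
    assert reply.length == 1;

    client.close();
    server.close();
  }
}
```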
final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort); TransportClient cachedClient = connectionPool.get(address); - if (cachedClient != null && cachedClient.isActive()) { - return cachedClient; - } else if (cachedClient != null) { - connectionPool.remove(address, cachedClient); // Remove inactive clients. + if (cachedClient != null) { + if (cachedClient.isActive()) { + return cachedClient; + } else { + connectionPool.remove(address, cachedClient); // Remove inactive clients. + } } logger.debug("Creating new connection to " + address); @@ -115,13 +117,14 @@ public class TransportClientFactory implements Closeable { // Connect to the remote server ChannelFuture cf = bootstrap.connect(address); if (!cf.awaitUninterruptibly(conf.connectionTimeoutMs())) { - throw new TimeoutException( + throw new RuntimeException( String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs())); } else if (cf.cause() != null) { throw new RuntimeException(String.format("Failed to connect to %s", address), cf.cause()); } - // Successful connection + // Successful connection -- in the event that two threads raced to create a client, we will + // use the first one that was put into the connectionPool and close the one we made here. assert client.get() != null : "Channel future completed successfully with null client"; TransportClient oldClient = connectionPool.putIfAbsent(address, client.get()); if (oldClient == null) { diff --git a/network/common/src/test/java/org/apache/spark/network/NoOpRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java index 7aa37efc58..5a3f003726 100644 --- a/network/common/src/test/java/org/apache/spark/network/NoOpRpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java @@ -1,4 +1,6 @@ -package org.apache.spark.network;/* +package org.apache.spark.network.server; + +/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. @@ -17,12 +19,20 @@ package org.apache.spark.network;/* import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.server.RpcHandler; -/** Test RpcHandler which always returns a zero-sized success. */ +/** An RpcHandler suitable for a client-only TransportContext, which cannot receive RPCs. 
*/ public class NoOpRpcHandler implements RpcHandler { + private final StreamManager streamManager; + + public NoOpRpcHandler() { + streamManager = new OneForOneStreamManager(); + } + @Override public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { - callback.onSuccess(new byte[0]); + throw new UnsupportedOperationException("Cannot handle messages"); } + + @Override + public StreamManager getStreamManager() { return streamManager; } } diff --git a/network/common/src/main/java/org/apache/spark/network/server/DefaultStreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java index 9688705569..731d48d4d9 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/DefaultStreamManager.java +++ b/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java @@ -30,10 +30,10 @@ import org.apache.spark.network.buffer.ManagedBuffer; /** * StreamManager which allows registration of an Iterator<ManagedBuffer>, which are individually - * fetched as chunks by the client. + * fetched as chunks by the client. Each registered buffer is one chunk. */ -public class DefaultStreamManager extends StreamManager { - private final Logger logger = LoggerFactory.getLogger(DefaultStreamManager.class); +public class OneForOneStreamManager extends StreamManager { + private final Logger logger = LoggerFactory.getLogger(OneForOneStreamManager.class); private final AtomicLong nextStreamId; private final Map<Long, StreamState> streams; @@ -51,7 +51,7 @@ public class DefaultStreamManager extends StreamManager { } } - public DefaultStreamManager() { + public OneForOneStreamManager() { // For debugging purposes, start with a random stream id to help identifying different streams. // This does not need to be globally unique, only unique to this class. nextStreamId = new AtomicLong((long) new Random().nextInt(Integer.MAX_VALUE) * 1000); diff --git a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java index f54a696b8f..2369dc6203 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java @@ -35,4 +35,10 @@ public interface RpcHandler { * RPC. */ void receive(TransportClient client, byte[] message, RpcResponseCallback callback); + + /** + * Returns the StreamManager which contains the state about which streams are currently being + * fetched by a TransportClient. + */ + StreamManager getStreamManager(); } diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index 352f865935..17fe9001b3 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -56,24 +56,23 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> { /** Client on the same channel allowing us to talk back to the requester. */ private final TransportClient reverseClient; - /** Returns each chunk part of a stream. */ - private final StreamManager streamManager; - /** Handles all RPC messages. */ private final RpcHandler rpcHandler; + /** Returns each chunk part of a stream. 
*/ + private final StreamManager streamManager; + /** List of all stream ids that have been read on this handler, used for cleanup. */ private final Set<Long> streamIds; public TransportRequestHandler( Channel channel, TransportClient reverseClient, - StreamManager streamManager, RpcHandler rpcHandler) { this.channel = channel; this.reverseClient = reverseClient; - this.streamManager = streamManager; this.rpcHandler = rpcHandler; + this.streamManager = rpcHandler.getStreamManager(); this.streamIds = Sets.newHashSet(); } diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java index 243070750d..d1a1877a98 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java @@ -49,11 +49,11 @@ public class TransportServer implements Closeable { private ChannelFuture channelFuture; private int port = -1; - public TransportServer(TransportContext context) { + public TransportServer(TransportContext context, int portToBind) { this.context = context; this.conf = context.getConf(); - init(); + init(portToBind); } public int getPort() { @@ -63,7 +63,7 @@ public class TransportServer implements Closeable { return port; } - private void init() { + private void init(int portToBind) { IOMode ioMode = IOMode.valueOf(conf.ioMode()); EventLoopGroup bossGroup = @@ -95,7 +95,7 @@ public class TransportServer implements Closeable { } }); - channelFuture = bootstrap.bind(new InetSocketAddress(conf.serverPort())); + channelFuture = bootstrap.bind(new InetSocketAddress(portToBind)); channelFuture.syncUninterruptibly(); port = ((InetSocketAddress) channelFuture.channel().localAddress()).getPort(); diff --git a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java index 32ba3f5b07..40b71b0c87 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java +++ b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -17,8 +17,12 @@ package org.apache.spark.network.util; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.io.Closeable; import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; import com.google.common.io.Closeables; import org.slf4j.Logger; @@ -35,4 +39,38 @@ public class JavaUtils { logger.error("IOException should not have been thrown.", e); } } + + // TODO: Make this configurable, do not use Java serialization! + public static <T> T deserialize(byte[] bytes) { + try { + ObjectInputStream is = new ObjectInputStream(new ByteArrayInputStream(bytes)); + Object out = is.readObject(); + is.close(); + return (T) out; + } catch (ClassNotFoundException e) { + throw new RuntimeException("Could not deserialize object", e); + } catch (IOException e) { + throw new RuntimeException("Could not deserialize object", e); + } + } + + // TODO: Make this configurable, do not use Java serialization! 
+ public static byte[] serialize(Object object) { + try { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + ObjectOutputStream os = new ObjectOutputStream(baos); + os.writeObject(object); + os.close(); + return baos.toByteArray(); + } catch (IOException e) { + throw new RuntimeException("Could not serialize object", e); + } + } + + /** Returns a hash consistent with Spark's Utils.nonNegativeHash(). */ + public static int nonNegativeHash(Object obj) { + if (obj == null) { return 0; } + int hash = obj.hashCode(); + return hash != Integer.MIN_VALUE ? Math.abs(hash) : 0; + } } diff --git a/network/common/src/test/java/org/apache/spark/network/SystemPropertyConfigProvider.java b/network/common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java index f4e0a2426a..5f20b70678 100644 --- a/network/common/src/test/java/org/apache/spark/network/SystemPropertyConfigProvider.java +++ b/network/common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.network; +package org.apache.spark.network.util; import java.util.NoSuchElementException; diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java index 80f65d9803..a68f38e0e9 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -27,9 +27,6 @@ public class TransportConf { this.conf = conf; } - /** Port the server listens on. Default to a random port. */ - public int serverPort() { return conf.getInt("spark.shuffle.io.port", 0); } - /** IO mode: nio or epoll */ public String ioMode() { return conf.get("spark.shuffle.io.mode", "NIO").toUpperCase(); } diff --git a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java index 738dca9b6a..c415883397 100644 --- a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java @@ -41,10 +41,13 @@ import org.apache.spark.network.buffer.FileSegmentManagedBuffer; import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.buffer.NioManagedBuffer; import org.apache.spark.network.client.ChunkReceivedCallback; +import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientFactory; +import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.TransportServer; import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.util.SystemPropertyConfigProvider; import org.apache.spark.network.util.TransportConf; public class ChunkFetchIntegrationSuite { @@ -93,7 +96,18 @@ public class ChunkFetchIntegrationSuite { } } }; - TransportContext context = new TransportContext(conf, streamManager, new NoOpRpcHandler()); + RpcHandler handler = new RpcHandler() { + @Override + public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + throw new UnsupportedOperationException(); + } + + @Override + public StreamManager getStreamManager() { + return streamManager; + } + }; + TransportContext context = new 
TransportContext(conf, handler); server = context.createServer(); clientFactory = context.createClientFactory(); } diff --git a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java index 9f216dd2d7..64b457b4b3 100644 --- a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java @@ -35,9 +35,11 @@ import static org.junit.Assert.*; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientFactory; -import org.apache.spark.network.server.DefaultStreamManager; +import org.apache.spark.network.server.OneForOneStreamManager; import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.StreamManager; import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.util.SystemPropertyConfigProvider; import org.apache.spark.network.util.TransportConf; public class RpcIntegrationSuite { @@ -61,8 +63,11 @@ public class RpcIntegrationSuite { throw new RuntimeException("Thrown: " + parts[1]); } } + + @Override + public StreamManager getStreamManager() { return new OneForOneStreamManager(); } }; - TransportContext context = new TransportContext(conf, new DefaultStreamManager(), rpcHandler); + TransportContext context = new TransportContext(conf, rpcHandler); server = context.createServer(); clientFactory = context.createClientFactory(); } diff --git a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java index 3ef964616f..5a10fdb384 100644 --- a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java +++ b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java @@ -28,11 +28,11 @@ import static org.junit.Assert.assertTrue; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientFactory; -import org.apache.spark.network.server.DefaultStreamManager; +import org.apache.spark.network.server.NoOpRpcHandler; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.TransportServer; -import org.apache.spark.network.server.StreamManager; import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.util.SystemPropertyConfigProvider; import org.apache.spark.network.util.TransportConf; public class TransportClientFactorySuite { @@ -44,9 +44,8 @@ public class TransportClientFactorySuite { @Before public void setUp() { conf = new TransportConf(new SystemPropertyConfigProvider()); - StreamManager streamManager = new DefaultStreamManager(); RpcHandler rpcHandler = new NoOpRpcHandler(); - context = new TransportContext(conf, streamManager, rpcHandler); + context = new TransportContext(conf, rpcHandler); server1 = context.createServer(); server2 = context.createServer(); } diff --git a/network/shuffle/pom.xml b/network/shuffle/pom.xml new file mode 100644 index 0000000000..d271704d98 --- /dev/null +++ b/network/shuffle/pom.xml @@ -0,0 +1,96 @@ +<?xml version="1.0" encoding="UTF-8"?> +<!-- + ~ Licensed to the Apache Software Foundation (ASF) under one or more + ~ contributor license agreements. 
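Every control message in the new shuffle module rides on the JavaUtils.serialize/deserialize helpers added above. A quick round-trip sketch; Ping is a stand-in for any Serializable message:

```java
import java.io.Serializable;

import org.apache.spark.network.util.JavaUtils;

class SerializationRoundTrip {
  static class Ping implements Serializable {
    final int n;
    Ping(int n) { this.n = n; }
  }

  public static void main(String[] args) {
    byte[] wire = JavaUtils.serialize(new Ping(7));  // plain Java serialization, per the TODO above
    Ping back = JavaUtils.deserialize(wire);         // type argument inferred from the assignment
    assert back.n == 7;
  }
}
```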
See the NOTICE file distributed with + ~ this work for additional information regarding copyright ownership. + ~ The ASF licenses this file to You under the Apache License, Version 2.0 + ~ (the "License"); you may not use this file except in compliance with + ~ the License. You may obtain a copy of the License at + ~ + ~ http://www.apache.org/licenses/LICENSE-2.0 + ~ + ~ Unless required by applicable law or agreed to in writing, software + ~ distributed under the License is distributed on an "AS IS" BASIS, + ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + ~ See the License for the specific language governing permissions and + ~ limitations under the License. + --> + +<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" + xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> + <modelVersion>4.0.0</modelVersion> + <parent> + <groupId>org.apache.spark</groupId> + <artifactId>spark-parent</artifactId> + <version>1.2.0-SNAPSHOT</version> + <relativePath>../../pom.xml</relativePath> + </parent> + + <groupId>org.apache.spark</groupId> + <artifactId>spark-network-shuffle_2.10</artifactId> + <packaging>jar</packaging> + <name>Spark Project Shuffle Streaming Service Code</name> + <url>http://spark.apache.org/</url> + <properties> + <sbt.project.name>network-shuffle</sbt.project.name> + </properties> + + <dependencies> + <!-- Core dependencies --> + <dependency> + <groupId>org.apache.spark</groupId> + <artifactId>spark-network-common_2.10</artifactId> + <version>${project.version}</version> + </dependency> + <dependency> + <groupId>org.slf4j</groupId> + <artifactId>slf4j-api</artifactId> + </dependency> + + <!-- Provided dependencies --> + <dependency> + <groupId>com.google.guava</groupId> + <artifactId>guava</artifactId> + <scope>provided</scope> + </dependency> + + <!-- Test dependencies --> + <dependency> + <groupId>org.apache.spark</groupId> + <artifactId>spark-network-common_2.10</artifactId> + <version>${project.version}</version> + <type>test-jar</type> + <scope>test</scope> + </dependency> + <dependency> + <groupId>junit</groupId> + <artifactId>junit</artifactId> + <scope>test</scope> + </dependency> + <dependency> + <groupId>com.novocode</groupId> + <artifactId>junit-interface</artifactId> + <scope>test</scope> + </dependency> + <dependency> + <groupId>log4j</groupId> + <artifactId>log4j</artifactId> + <scope>test</scope> + </dependency> + <dependency> + <groupId>org.mockito</groupId> + <artifactId>mockito-all</artifactId> + <scope>test</scope> + </dependency> + <dependency> + <groupId>org.scalatest</groupId> + <artifactId>scalatest_${scala.binary.version}</artifactId> + <scope>test</scope> + </dependency> + </dependencies> + + <build> + <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory> + <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory> + </build> +</project> diff --git a/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java index 645793fde8..138fd5389c 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java @@ -15,28 +15,22 @@ * limitations under the License. 
*/ -package org.apache.spark.network +package org.apache.spark.network.shuffle; -import java.util.EventListener +import java.util.EventListener; -import org.apache.spark.network.buffer.ManagedBuffer - - -/** - * Listener callback interface for [[BlockTransferService.fetchBlocks]]. - */ -private[spark] -trait BlockFetchingListener extends EventListener { +import org.apache.spark.network.buffer.ManagedBuffer; +public interface BlockFetchingListener extends EventListener { /** * Called once per successfully fetched block. After this call returns, data will be released * automatically. If the data will be passed to another thread, the receiver should retain() * and release() the buffer on their own, or copy the data to a new buffer. */ - def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit + void onBlockFetchSuccess(String blockId, ManagedBuffer data); /** * Called at least once per block upon failures. */ - def onBlockFetchFailure(blockId: String, exception: Throwable): Unit + void onBlockFetchFailure(String blockId, Throwable exception); } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExecutorShuffleInfo.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExecutorShuffleInfo.java new file mode 100644 index 0000000000..d45e64656a --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExecutorShuffleInfo.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.io.Serializable; +import java.util.Arrays; + +import com.google.common.base.Objects; + +/** Contains all configuration necessary for locating the shuffle files of an executor. */ +public class ExecutorShuffleInfo implements Serializable { + /** The base set of local directories that the executor stores its shuffle files in. */ + final String[] localDirs; + /** Number of subdirectories created within each localDir. */ + final int subDirsPerLocalDir; + /** Shuffle manager (SortShuffleManager or HashShuffleManager) that the executor is using. 
*/ + final String shuffleManager; + + public ExecutorShuffleInfo(String[] localDirs, int subDirsPerLocalDir, String shuffleManager) { + this.localDirs = localDirs; + this.subDirsPerLocalDir = subDirsPerLocalDir; + this.shuffleManager = shuffleManager; + } + + @Override + public int hashCode() { + return Objects.hashCode(subDirsPerLocalDir, shuffleManager) * 41 + Arrays.hashCode(localDirs); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("localDirs", Arrays.toString(localDirs)) + .add("subDirsPerLocalDir", subDirsPerLocalDir) + .add("shuffleManager", shuffleManager) + .toString(); + } + + @Override + public boolean equals(Object other) { + if (other != null && other instanceof ExecutorShuffleInfo) { + ExecutorShuffleInfo o = (ExecutorShuffleInfo) other; + return Arrays.equals(localDirs, o.localDirs) + && Objects.equal(subDirsPerLocalDir, o.subDirsPerLocalDir) + && Objects.equal(shuffleManager, o.shuffleManager); + } + return false; + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java new file mode 100644 index 0000000000..a9dff31dec --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.util.List; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.Lists; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import static org.apache.spark.network.shuffle.ExternalShuffleMessages.*; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.server.OneForOneStreamManager; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.util.JavaUtils; + +/** + * RPC Handler for a server which can serve shuffle blocks from outside of an Executor process. + * + * Handles registering executors and opening shuffle blocks from them. Shuffle blocks are registered + * with the "one-for-one" strategy, meaning each Transport-layer Chunk is equivalent to one Spark- + * level shuffle block. 
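On the executor side, the registration payload above would be assembled roughly as follows; the directories and subdirectory count are placeholders, and the shuffle manager string should be the fully-qualified class name that the block manager matches on later in this patch:

```java
import org.apache.spark.network.shuffle.ExecutorShuffleInfo;

class BuildExecutorInfo {
  static ExecutorShuffleInfo localSortShuffleInfo() {
    return new ExecutorShuffleInfo(
        new String[] { "/data/1/spark-local", "/data/2/spark-local" },  // executor's local dirs
        64,                                                             // subdirs per local dir
        "org.apache.spark.shuffle.sort.SortShuffleManager");            // executor's shuffle manager
  }
}
```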
+ */ +public class ExternalShuffleBlockHandler implements RpcHandler { + private final Logger logger = LoggerFactory.getLogger(ExternalShuffleBlockHandler.class); + + private final ExternalShuffleBlockManager blockManager; + private final OneForOneStreamManager streamManager; + + public ExternalShuffleBlockHandler() { + this(new OneForOneStreamManager(), new ExternalShuffleBlockManager()); + } + + /** Enables mocking out the StreamManager and BlockManager. */ + @VisibleForTesting + ExternalShuffleBlockHandler( + OneForOneStreamManager streamManager, + ExternalShuffleBlockManager blockManager) { + this.streamManager = streamManager; + this.blockManager = blockManager; + } + + @Override + public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + Object msgObj = JavaUtils.deserialize(message); + + logger.trace("Received message: " + msgObj); + + if (msgObj instanceof OpenShuffleBlocks) { + OpenShuffleBlocks msg = (OpenShuffleBlocks) msgObj; + List<ManagedBuffer> blocks = Lists.newArrayList(); + + for (String blockId : msg.blockIds) { + blocks.add(blockManager.getBlockData(msg.appId, msg.execId, blockId)); + } + long streamId = streamManager.registerStream(blocks.iterator()); + logger.trace("Registered streamId {} with {} buffers", streamId, msg.blockIds.length); + callback.onSuccess(JavaUtils.serialize( + new ShuffleStreamHandle(streamId, msg.blockIds.length))); + + } else if (msgObj instanceof RegisterExecutor) { + RegisterExecutor msg = (RegisterExecutor) msgObj; + blockManager.registerExecutor(msg.appId, msg.execId, msg.executorInfo); + callback.onSuccess(new byte[0]); + + } else { + throw new UnsupportedOperationException(String.format( + "Unexpected message: %s (class = %s)", msgObj, msgObj.getClass())); + } + } + + @Override + public StreamManager getStreamManager() { + return streamManager; + } + + /** For testing, clears all executors registered with "RegisterExecutor". */ + @VisibleForTesting + public void clearRegisteredExecutors() { + blockManager.clearRegisteredExecutors(); + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java new file mode 100644 index 0000000000..6589889fe1 --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
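The dispatch above can be exercised directly against a handler instance, which is how the unit tests at the end of this patch drive it. A hedged sketch using Mockito; the app and executor ids are placeholders:

```java
import static org.mockito.Mockito.*;

import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.shuffle.ExecutorShuffleInfo;
import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler;
import org.apache.spark.network.shuffle.ExternalShuffleMessages.RegisterExecutor;
import org.apache.spark.network.util.JavaUtils;

class HandlerDispatchSketch {
  static void registerOneExecutor() {
    ExternalShuffleBlockHandler handler = new ExternalShuffleBlockHandler();
    TransportClient client = mock(TransportClient.class);
    RpcResponseCallback callback = mock(RpcResponseCallback.class);

    byte[] message = JavaUtils.serialize(new RegisterExecutor("app0", "exec1",
        new ExecutorShuffleInfo(new String[] { "/tmp/spark-local" }, 64,
            "org.apache.spark.shuffle.sort.SortShuffleManager")));

    handler.receive(client, message, callback);     // routes to blockManager.registerExecutor(...)
    verify(callback).onSuccess(any(byte[].class));  // an empty byte[] acknowledges the registration
  }
}
```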
+ */ + +package org.apache.spark.network.shuffle; + +import java.io.DataInputStream; +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.util.concurrent.ConcurrentHashMap; + +import com.google.common.annotations.VisibleForTesting; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.buffer.FileSegmentManagedBuffer; +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.util.JavaUtils; + +/** + * Manages converting shuffle BlockIds into physical segments of local files, from a process outside + * of Executors. Each Executor must register its own configuration about where it stores its files + * (local dirs) and how (shuffle manager). The logic for retrieval of individual files is replicated + * from Spark's FileShuffleBlockManager and IndexShuffleBlockManager. + * + * Executors with shuffle file consolidation are not currently supported, as the index is stored in + * the Executor's memory, unlike the IndexShuffleBlockManager. + */ +public class ExternalShuffleBlockManager { + private final Logger logger = LoggerFactory.getLogger(ExternalShuffleBlockManager.class); + + // Map from "appId-execId" to the executor's configuration. + private final ConcurrentHashMap<String, ExecutorShuffleInfo> executors = + new ConcurrentHashMap<String, ExecutorShuffleInfo>(); + + // Returns an id suitable for a single executor within a single application. + private String getAppExecId(String appId, String execId) { + return appId + "-" + execId; + } + + /** Registers a new Executor with all the configuration we need to find its shuffle files. */ + public void registerExecutor( + String appId, + String execId, + ExecutorShuffleInfo executorInfo) { + String fullId = getAppExecId(appId, execId); + logger.info("Registered executor {} with {}", fullId, executorInfo); + executors.put(fullId, executorInfo); + } + + /** + * Obtains a FileSegmentManagedBuffer from a shuffle block id. We expect the blockId has the + * format "shuffle_ShuffleId_MapId_ReduceId" (from ShuffleBlockId), and additionally make + * assumptions about how the hash and sort based shuffles store their data. + */ + public ManagedBuffer getBlockData(String appId, String execId, String blockId) { + String[] blockIdParts = blockId.split("_"); + if (blockIdParts.length < 4) { + throw new IllegalArgumentException("Unexpected block id format: " + blockId); + } else if (!blockIdParts[0].equals("shuffle")) { + throw new IllegalArgumentException("Expected shuffle block id, got: " + blockId); + } + int shuffleId = Integer.parseInt(blockIdParts[1]); + int mapId = Integer.parseInt(blockIdParts[2]); + int reduceId = Integer.parseInt(blockIdParts[3]); + + ExecutorShuffleInfo executor = executors.get(getAppExecId(appId, execId)); + if (executor == null) { + throw new RuntimeException( + String.format("Executor is not registered (appId=%s, execId=%s)", appId, execId)); + } + + if ("org.apache.spark.shuffle.hash.HashShuffleManager".equals(executor.shuffleManager)) { + return getHashBasedShuffleBlockData(executor, blockId); + } else if ("org.apache.spark.shuffle.sort.SortShuffleManager".equals(executor.shuffleManager)) { + return getSortBasedShuffleBlockData(executor, shuffleId, mapId, reduceId); + } else { + throw new UnsupportedOperationException( + "Unsupported shuffle manager: " + executor.shuffleManager); + } + } + + /** + * Hash-based shuffle data is simply stored as one file per block. + * This logic is from FileShuffleBlockManager. 
+ */ + // TODO: Support consolidated hash shuffle files + private ManagedBuffer getHashBasedShuffleBlockData(ExecutorShuffleInfo executor, String blockId) { + File shuffleFile = getFile(executor.localDirs, executor.subDirsPerLocalDir, blockId); + return new FileSegmentManagedBuffer(shuffleFile, 0, shuffleFile.length()); + } + + /** + * Sort-based shuffle data uses an index called "shuffle_ShuffleId_MapId_0.index" into a data file + * called "shuffle_ShuffleId_MapId_0.data". This logic is from IndexShuffleBlockManager, + * and the block id format is from ShuffleDataBlockId and ShuffleIndexBlockId. + */ + private ManagedBuffer getSortBasedShuffleBlockData( + ExecutorShuffleInfo executor, int shuffleId, int mapId, int reduceId) { + File indexFile = getFile(executor.localDirs, executor.subDirsPerLocalDir, + "shuffle_" + shuffleId + "_" + mapId + "_0.index"); + + DataInputStream in = null; + try { + in = new DataInputStream(new FileInputStream(indexFile)); + in.skipBytes(reduceId * 8); + long offset = in.readLong(); + long nextOffset = in.readLong(); + return new FileSegmentManagedBuffer( + getFile(executor.localDirs, executor.subDirsPerLocalDir, + "shuffle_" + shuffleId + "_" + mapId + "_0.data"), + offset, + nextOffset - offset); + } catch (IOException e) { + throw new RuntimeException("Failed to open file: " + indexFile, e); + } finally { + if (in != null) { + JavaUtils.closeQuietly(in); + } + } + } + + /** + * Hashes a filename into the corresponding local directory, in a manner consistent with + * Spark's DiskBlockManager.getFile(). + */ + @VisibleForTesting + static File getFile(String[] localDirs, int subDirsPerLocalDir, String filename) { + int hash = JavaUtils.nonNegativeHash(filename); + String localDir = localDirs[hash % localDirs.length]; + int subDirId = (hash / localDirs.length) % subDirsPerLocalDir; + return new File(new File(localDir, String.format("%02x", subDirId)), filename); + } + + /** For testing, clears all registered executors. */ + @VisibleForTesting + void clearRegisteredExecutors() { + executors.clear(); + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java new file mode 100644 index 0000000000..cc2f6261ca --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
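The directory hashing above is what allows a process outside the executor to locate files the executor wrote; both sides must compute the same mapping. A standalone sketch that mirrors getFile(); the local dirs and subdirectory count are placeholders and must match what the executor registered:

```java
import java.io.File;

import org.apache.spark.network.util.JavaUtils;

class ShuffleFileLocator {
  static File locate(String[] localDirs, int subDirsPerLocalDir, String filename) {
    int hash = JavaUtils.nonNegativeHash(filename);
    String localDir = localDirs[hash % localDirs.length];           // pick a root dir by hash
    int subDirId = (hash / localDirs.length) % subDirsPerLocalDir;  // then a subdir within it
    return new File(new File(localDir, String.format("%02x", subDirId)), filename);
  }

  public static void main(String[] args) {
    // Prints something like /data/1/spark-local/3e/shuffle_0_0_0.index
    // (the actual subdir depends on the filename's hash).
    System.out.println(locate(
        new String[] { "/data/1/spark-local", "/data/2/spark-local" }, 64,
        "shuffle_0_0_0.index"));
  }
}
```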
+ */ + +package org.apache.spark.network.shuffle; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.TransportContext; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientFactory; +import org.apache.spark.network.server.NoOpRpcHandler; +import org.apache.spark.network.shuffle.ExternalShuffleMessages.RegisterExecutor; +import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.util.TransportConf; + +/** + * Client for reading shuffle blocks which points to an external (outside of executor) server. + * This is instead of reading shuffle blocks directly from other executors (via + * BlockTransferService), which has the downside of losing the shuffle data if we lose the + * executors. + */ +public class ExternalShuffleClient implements ShuffleClient { + private final Logger logger = LoggerFactory.getLogger(ExternalShuffleClient.class); + + private final TransportClientFactory clientFactory; + private final String appId; + + public ExternalShuffleClient(TransportConf conf, String appId) { + TransportContext context = new TransportContext(conf, new NoOpRpcHandler()); + this.clientFactory = context.createClientFactory(); + this.appId = appId; + } + + @Override + public void fetchBlocks( + String host, + int port, + String execId, + String[] blockIds, + BlockFetchingListener listener) { + logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId); + try { + TransportClient client = clientFactory.createClient(host, port); + new OneForOneBlockFetcher(client, blockIds, listener) + .start(new ExternalShuffleMessages.OpenShuffleBlocks(appId, execId, blockIds)); + } catch (Exception e) { + logger.error("Exception while beginning fetchBlocks", e); + for (String blockId : blockIds) { + listener.onBlockFetchFailure(blockId, e); + } + } + } + + /** + * Registers this executor with an external shuffle server. This registration is required to + * inform the shuffle server about where and how we store our shuffle files. + * + * @param host Host of shuffle server. + * @param port Port of shuffle server. + * @param execId This Executor's id. + * @param executorInfo Contains all info necessary for the service to find our shuffle files. + */ + public void registerWithShuffleServer( + String host, + int port, + String execId, + ExecutorShuffleInfo executorInfo) { + TransportClient client = clientFactory.createClient(host, port); + byte[] registerExecutorMessage = + JavaUtils.serialize(new RegisterExecutor(appId, execId, executorInfo)); + client.sendRpcSync(registerExecutorMessage, 5000 /* timeoutMs */); + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleMessages.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleMessages.java new file mode 100644 index 0000000000..e79420ed82 --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleMessages.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
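Putting the client above to work takes two steps: a one-time registration per executor, then fetches from any reducer. A hedged usage sketch; the host, port, app id, and executor id are placeholders for whatever the shuffle server was started with:

```java
import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.shuffle.BlockFetchingListener;
import org.apache.spark.network.shuffle.ExecutorShuffleInfo;
import org.apache.spark.network.shuffle.ExternalShuffleClient;
import org.apache.spark.network.util.SystemPropertyConfigProvider;
import org.apache.spark.network.util.TransportConf;

class ExternalShuffleUsage {
  public static void main(String[] args) {
    TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());
    ExternalShuffleClient client = new ExternalShuffleClient(conf, "app-0000");

    // Once per executor: tell the service where this executor's shuffle files live.
    client.registerWithShuffleServer("shuffle-host", 7337, "exec-1",
        new ExecutorShuffleInfo(new String[] { "/tmp/spark-local" }, 64,
            "org.apache.spark.shuffle.sort.SortShuffleManager"));

    // Later: fetch that executor's blocks through the service rather than the executor.
    client.fetchBlocks("shuffle-host", 7337, "exec-1",
        new String[] { "shuffle_0_0_0" },
        new BlockFetchingListener() {
          @Override
          public void onBlockFetchSuccess(String blockId, ManagedBuffer data) {
            data.retain();  // keep the buffer alive past the callback
          }

          @Override
          public void onBlockFetchFailure(String blockId, Throwable t) {
            // Retry or fail the task; invoked at least once per lost block.
          }
        });
  }
}
```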
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.io.Serializable; +import java.util.Arrays; + +import com.google.common.base.Objects; + +/** Messages handled by the {@link ExternalShuffleBlockHandler}. */ +public class ExternalShuffleMessages { + + /** Request to read a set of shuffle blocks. Returns [[ShuffleStreamHandle]]. */ + public static class OpenShuffleBlocks implements Serializable { + public final String appId; + public final String execId; + public final String[] blockIds; + + public OpenShuffleBlocks(String appId, String execId, String[] blockIds) { + this.appId = appId; + this.execId = execId; + this.blockIds = blockIds; + } + + @Override + public int hashCode() { + return Objects.hashCode(appId, execId) * 41 + Arrays.hashCode(blockIds); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("appId", appId) + .add("execId", execId) + .add("blockIds", Arrays.toString(blockIds)) + .toString(); + } + + @Override + public boolean equals(Object other) { + if (other != null && other instanceof OpenShuffleBlocks) { + OpenShuffleBlocks o = (OpenShuffleBlocks) other; + return Objects.equal(appId, o.appId) + && Objects.equal(execId, o.execId) + && Arrays.equals(blockIds, o.blockIds); + } + return false; + } + } + + /** Initial registration message between an executor and its local shuffle server. */ + public static class RegisterExecutor implements Serializable { + public final String appId; + public final String execId; + public final ExecutorShuffleInfo executorInfo; + + public RegisterExecutor( + String appId, + String execId, + ExecutorShuffleInfo executorInfo) { + this.appId = appId; + this.execId = execId; + this.executorInfo = executorInfo; + } + + @Override + public int hashCode() { + return Objects.hashCode(appId, execId, executorInfo); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("appId", appId) + .add("execId", execId) + .add("executorInfo", executorInfo) + .toString(); + } + + @Override + public boolean equals(Object other) { + if (other != null && other instanceof RegisterExecutor) { + RegisterExecutor o = (RegisterExecutor) other; + return Objects.equal(appId, o.appId) + && Objects.equal(execId, o.execId) + && Objects.equal(executorInfo, o.executorInfo); + } + return false; + } + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java new file mode 100644 index 0000000000..39b6f30f92 --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.util.Arrays; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.client.ChunkReceivedCallback; +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.util.JavaUtils; + +/** + * Simple wrapper on top of a TransportClient which interprets each chunk as a whole block, and + * invokes the BlockFetchingListener appropriately. This class is agnostic to the actual RPC + * handler, as long as there is a single "open blocks" message which returns a ShuffleStreamHandle, + * and Java serialization is used. + * + * Note that this typically corresponds to a + * {@link org.apache.spark.network.server.OneForOneStreamManager} on the server side. + */ +public class OneForOneBlockFetcher { + private final Logger logger = LoggerFactory.getLogger(OneForOneBlockFetcher.class); + + private final TransportClient client; + private final String[] blockIds; + private final BlockFetchingListener listener; + private final ChunkReceivedCallback chunkCallback; + + private ShuffleStreamHandle streamHandle = null; + + public OneForOneBlockFetcher( + TransportClient client, + String[] blockIds, + BlockFetchingListener listener) { + if (blockIds.length == 0) { + throw new IllegalArgumentException("Zero-sized blockIds array"); + } + this.client = client; + this.blockIds = blockIds; + this.listener = listener; + this.chunkCallback = new ChunkCallback(); + } + + /** Callback invoked on receipt of each chunk. We equate a single chunk to a single block. */ + private class ChunkCallback implements ChunkReceivedCallback { + @Override + public void onSuccess(int chunkIndex, ManagedBuffer buffer) { + // On receipt of a chunk, pass it upwards as a block. + listener.onBlockFetchSuccess(blockIds[chunkIndex], buffer); + } + + @Override + public void onFailure(int chunkIndex, Throwable e) { + // On receipt of a failure, fail every block from chunkIndex onwards. + String[] remainingBlockIds = Arrays.copyOfRange(blockIds, chunkIndex, blockIds.length); + failRemainingBlocks(remainingBlockIds, e); + } + } + + /** + * Begins the fetching process, calling the listener with every block fetched. + * The given message will be serialized with the Java serializer, and the RPC must return a + * {@link ShuffleStreamHandle}. We will send all fetch requests immediately, without throttling. + */ + public void start(Object openBlocksMessage) { + client.sendRpc(JavaUtils.serialize(openBlocksMessage), new RpcResponseCallback() { + @Override + public void onSuccess(byte[] response) { + try { + streamHandle = JavaUtils.deserialize(response); + logger.trace("Successfully opened blocks {}, preparing to fetch chunks.", streamHandle); + + // Immediately request all chunks -- we expect that the total size of the request is + // reasonable due to higher level chunking in [[ShuffleBlockFetcherIterator]]. 
+ for (int i = 0; i < streamHandle.numChunks; i++) { + client.fetchChunk(streamHandle.streamId, i, chunkCallback); + } + } catch (Exception e) { + logger.error("Failed while starting block fetches", e); + failRemainingBlocks(blockIds, e); + } + } + + @Override + public void onFailure(Throwable e) { + logger.error("Failed while starting block fetches", e); + failRemainingBlocks(blockIds, e); + } + }); + } + + /** Invokes the "onBlockFetchFailure" callback for every listed block id. */ + private void failRemainingBlocks(String[] failedBlockIds, Throwable e) { + for (String blockId : failedBlockIds) { + try { + listener.onBlockFetchFailure(blockId, e); + } catch (Exception e2) { + logger.error("Error in block fetch failure callback", e2); + } + } + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java new file mode 100644 index 0000000000..9fa87c2c6e --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +/** Provides an interface for reading shuffle files, either from an Executor or external service. */ +public interface ShuffleClient { + /** + * Fetch a sequence of blocks from a remote node asynchronously. + * + * Note that this API takes a sequence so the implementation can batch requests, and does not + * return a future so the underlying implementation can invoke onBlockFetchSuccess as soon as + * the data of a block is fetched, rather than waiting for all blocks to be fetched. + */ + public void fetchBlocks( + String host, + int port, + String execId, + String[] blockIds, + BlockFetchingListener listener); +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleStreamHandle.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleStreamHandle.java new file mode 100644 index 0000000000..9c94691224 --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleStreamHandle.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
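Because the interface above is callback-based, a caller that needs every block before proceeding must add its own synchronization. One hedged way, using a latch; note that onBlockFetchFailure may fire more than once per block, so a production version would track distinct block ids rather than simply counting down:

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.shuffle.BlockFetchingListener;
import org.apache.spark.network.shuffle.ShuffleClient;

class BlockingFetch {
  static boolean fetchAll(ShuffleClient client, String host, int port, String execId,
      String[] blockIds) throws InterruptedException {
    final CountDownLatch remaining = new CountDownLatch(blockIds.length);
    client.fetchBlocks(host, port, execId, blockIds, new BlockFetchingListener() {
      @Override
      public void onBlockFetchSuccess(String blockId, ManagedBuffer data) {
        data.retain();           // hold the buffer past this callback
        remaining.countDown();
      }

      @Override
      public void onBlockFetchFailure(String blockId, Throwable t) {
        remaining.countDown();   // simplistic; see the note above about repeated failures
      }
    });
    return remaining.await(30, TimeUnit.SECONDS);  // placeholder timeout
  }
}
```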
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.io.Serializable; +import java.util.Arrays; + +import com.google.common.base.Objects; + +/** + * Identifier for a fixed number of chunks to read from a stream created by an "open blocks" + * message. This is used by {@link OneForOneBlockFetcher}. + */ +public class ShuffleStreamHandle implements Serializable { + public final long streamId; + public final int numChunks; + + public ShuffleStreamHandle(long streamId, int numChunks) { + this.streamId = streamId; + this.numChunks = numChunks; + } + + @Override + public int hashCode() { + return Objects.hashCode(streamId, numChunks); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("streamId", streamId) + .add("numChunks", numChunks) + .toString(); + } + + @Override + public boolean equals(Object other) { + if (other != null && other instanceof ShuffleStreamHandle) { + ShuffleStreamHandle o = (ShuffleStreamHandle) other; + return Objects.equal(streamId, o.streamId) + && Objects.equal(numChunks, o.numChunks); + } + return false; + } +} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java new file mode 100644 index 0000000000..7939cb4d32 --- /dev/null +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.shuffle; + +import java.nio.ByteBuffer; +import java.util.Iterator; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; + +import static org.apache.spark.network.shuffle.ExternalShuffleMessages.OpenShuffleBlocks; +import static org.apache.spark.network.shuffle.ExternalShuffleMessages.RegisterExecutor; +import static org.junit.Assert.*; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.*; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.server.OneForOneStreamManager; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.util.JavaUtils; + +public class ExternalShuffleBlockHandlerSuite { + TransportClient client = mock(TransportClient.class); + + OneForOneStreamManager streamManager; + ExternalShuffleBlockManager blockManager; + RpcHandler handler; + + @Before + public void beforeEach() { + streamManager = mock(OneForOneStreamManager.class); + blockManager = mock(ExternalShuffleBlockManager.class); + handler = new ExternalShuffleBlockHandler(streamManager, blockManager); + } + + @Test + public void testRegisterExecutor() { + RpcResponseCallback callback = mock(RpcResponseCallback.class); + + ExecutorShuffleInfo config = new ExecutorShuffleInfo(new String[] {"/a", "/b"}, 16, "sort"); + byte[] registerMessage = JavaUtils.serialize( + new RegisterExecutor("app0", "exec1", config)); + handler.receive(client, registerMessage, callback); + verify(blockManager, times(1)).registerExecutor("app0", "exec1", config); + + verify(callback, times(1)).onSuccess((byte[]) any()); + verify(callback, never()).onFailure((Throwable) any()); + } + + @SuppressWarnings("unchecked") + @Test + public void testOpenShuffleBlocks() { + RpcResponseCallback callback = mock(RpcResponseCallback.class); + + ManagedBuffer block0Marker = new NioManagedBuffer(ByteBuffer.wrap(new byte[3])); + ManagedBuffer block1Marker = new NioManagedBuffer(ByteBuffer.wrap(new byte[7])); + when(blockManager.getBlockData("app0", "exec1", "b0")).thenReturn(block0Marker); + when(blockManager.getBlockData("app0", "exec1", "b1")).thenReturn(block1Marker); + byte[] openBlocksMessage = JavaUtils.serialize( + new OpenShuffleBlocks("app0", "exec1", new String[] { "b0", "b1" })); + handler.receive(client, openBlocksMessage, callback); + verify(blockManager, times(1)).getBlockData("app0", "exec1", "b0"); + verify(blockManager, times(1)).getBlockData("app0", "exec1", "b1"); + + ArgumentCaptor<byte[]> response = ArgumentCaptor.forClass(byte[].class); + verify(callback, times(1)).onSuccess(response.capture()); + verify(callback, never()).onFailure((Throwable) any()); + + ShuffleStreamHandle handle = JavaUtils.deserialize(response.getValue()); + assertEquals(2, handle.numChunks); + + ArgumentCaptor<Iterator> stream = ArgumentCaptor.forClass(Iterator.class); + verify(streamManager, times(1)).registerStream(stream.capture()); + Iterator<ManagedBuffer> buffers = (Iterator<ManagedBuffer>) stream.getValue(); + assertEquals(block0Marker, buffers.next()); + assertEquals(block1Marker, buffers.next()); + assertFalse(buffers.hasNext()); + } + + @Test + public void testBadMessages() { + RpcResponseCallback callback = mock(RpcResponseCallback.class); + + byte[] unserializableMessage = new byte[] { 0x12, 0x34, 0x56 }; + try { + 
handler.receive(client, unserializableMessage, callback); + fail("Should have thrown"); + } catch (Exception e) { + // pass + } + + byte[] unexpectedMessage = JavaUtils.serialize( + new ExecutorShuffleInfo(new String[] {"/a", "/b"}, 16, "sort")); + try { + handler.receive(client, unexpectedMessage, callback); + fail("Should have thrown"); + } catch (UnsupportedOperationException e) { + // pass + } + + verify(callback, never()).onSuccess((byte[]) any()); + verify(callback, never()).onFailure((Throwable) any()); + } +} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManagerSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManagerSuite.java new file mode 100644 index 0000000000..da54797e89 --- /dev/null +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManagerSuite.java @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; + +import com.google.common.io.CharStreams; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +import static org.junit.Assert.*; + +public class ExternalShuffleBlockManagerSuite { + static String sortBlock0 = "Hello!"; + static String sortBlock1 = "World!"; + + static String hashBlock0 = "Elementary"; + static String hashBlock1 = "Tabular"; + + static TestShuffleDataContext dataContext; + + @BeforeClass + public static void beforeAll() throws IOException { + dataContext = new TestShuffleDataContext(2, 5); + + dataContext.create(); + // Write some sort and hash data. 
+ dataContext.insertSortShuffleData(0, 0, + new byte[][] { sortBlock0.getBytes(), sortBlock1.getBytes() } ); + dataContext.insertHashShuffleData(1, 0, + new byte[][] { hashBlock0.getBytes(), hashBlock1.getBytes() } ); + } + + @AfterClass + public static void afterAll() { + dataContext.cleanup(); + } + + @Test + public void testBadRequests() { + ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(); + // Unregistered executor + try { + manager.getBlockData("app0", "exec1", "shuffle_1_1_0"); + fail("Should have failed"); + } catch (RuntimeException e) { + assertTrue("Bad error message: " + e, e.getMessage().contains("not registered")); + } + + // Invalid shuffle manager + manager.registerExecutor("app0", "exec2", dataContext.createExecutorInfo("foobar")); + try { + manager.getBlockData("app0", "exec2", "shuffle_1_1_0"); + fail("Should have failed"); + } catch (UnsupportedOperationException e) { + // pass + } + + // Nonexistent shuffle block + manager.registerExecutor("app0", "exec3", + dataContext.createExecutorInfo("org.apache.spark.shuffle.sort.SortShuffleManager")); + try { + manager.getBlockData("app0", "exec3", "shuffle_1_1_0"); + fail("Should have failed"); + } catch (Exception e) { + // pass + } + } + + @Test + public void testSortShuffleBlocks() throws IOException { + ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(); + manager.registerExecutor("app0", "exec0", + dataContext.createExecutorInfo("org.apache.spark.shuffle.sort.SortShuffleManager")); + + InputStream block0Stream = + manager.getBlockData("app0", "exec0", "shuffle_0_0_0").createInputStream(); + String block0 = CharStreams.toString(new InputStreamReader(block0Stream)); + block0Stream.close(); + assertEquals(sortBlock0, block0); + + InputStream block1Stream = + manager.getBlockData("app0", "exec0", "shuffle_0_0_1").createInputStream(); + String block1 = CharStreams.toString(new InputStreamReader(block1Stream)); + block1Stream.close(); + assertEquals(sortBlock1, block1); + } + + @Test + public void testHashShuffleBlocks() throws IOException { + ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(); + manager.registerExecutor("app0", "exec0", + dataContext.createExecutorInfo("org.apache.spark.shuffle.hash.HashShuffleManager")); + + InputStream block0Stream = + manager.getBlockData("app0", "exec0", "shuffle_1_0_0").createInputStream(); + String block0 = CharStreams.toString(new InputStreamReader(block0Stream)); + block0Stream.close(); + assertEquals(hashBlock0, block0); + + InputStream block1Stream = + manager.getBlockData("app0", "exec0", "shuffle_1_0_1").createInputStream(); + String block1 = CharStreams.toString(new InputStreamReader(block1Stream)); + block1Stream.close(); + assertEquals(hashBlock1, block1); + } +} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java new file mode 100644 index 0000000000..b3bcf5fd68 --- /dev/null +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -0,0 +1,291 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.HashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; + +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +import static org.junit.Assert.*; + +import org.apache.spark.network.TestUtils; +import org.apache.spark.network.TransportContext; +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.TransportConf; + +public class ExternalShuffleIntegrationSuite { + + static String APP_ID = "app-id"; + static String SORT_MANAGER = "org.apache.spark.shuffle.sort.SortShuffleManager"; + static String HASH_MANAGER = "org.apache.spark.shuffle.hash.HashShuffleManager"; + + // Executor 0 is sort-based + static TestShuffleDataContext dataContext0; + // Executor 1 is hash-based + static TestShuffleDataContext dataContext1; + + static ExternalShuffleBlockHandler handler; + static TransportServer server; + static TransportConf conf; + + static byte[][] exec0Blocks = new byte[][] { + new byte[123], + new byte[12345], + new byte[1234567], + }; + + static byte[][] exec1Blocks = new byte[][] { + new byte[321], + new byte[54321], + }; + + @BeforeClass + public static void beforeAll() throws IOException { + Random rand = new Random(); + + for (byte[] block : exec0Blocks) { + rand.nextBytes(block); + } + for (byte[] block: exec1Blocks) { + rand.nextBytes(block); + } + + dataContext0 = new TestShuffleDataContext(2, 5); + dataContext0.create(); + dataContext0.insertSortShuffleData(0, 0, exec0Blocks); + + dataContext1 = new TestShuffleDataContext(6, 2); + dataContext1.create(); + dataContext1.insertHashShuffleData(1, 0, exec1Blocks); + + conf = new TransportConf(new SystemPropertyConfigProvider()); + handler = new ExternalShuffleBlockHandler(); + TransportContext transportContext = new TransportContext(conf, handler); + server = transportContext.createServer(); + } + + @AfterClass + public static void afterAll() { + dataContext0.cleanup(); + dataContext1.cleanup(); + server.close(); + } + + @After + public void afterEach() { + handler.clearRegisteredExecutors(); + } + + class FetchResult { + public Set<String> successBlocks; + public Set<String> failedBlocks; + public List<ManagedBuffer> buffers; + + public void releaseBuffers() { + for (ManagedBuffer buffer : buffers) { + buffer.release(); + } + } + } + + // Fetch a set of blocks from a pre-registered executor. 
+ private FetchResult fetchBlocks(String execId, String[] blockIds) throws Exception { + return fetchBlocks(execId, blockIds, server.getPort()); + } + + // Fetch a set of blocks from a pre-registered executor. Connects to the server on the given port, + // to allow connecting to invalid servers. + private FetchResult fetchBlocks(String execId, String[] blockIds, int port) throws Exception { + final FetchResult res = new FetchResult(); + res.successBlocks = Collections.synchronizedSet(new HashSet<String>()); + res.failedBlocks = Collections.synchronizedSet(new HashSet<String>()); + res.buffers = Collections.synchronizedList(new LinkedList<ManagedBuffer>()); + + final Semaphore requestsRemaining = new Semaphore(0); + + ExternalShuffleClient client = new ExternalShuffleClient(conf, APP_ID); + client.fetchBlocks(TestUtils.getLocalHost(), port, execId, blockIds, + new BlockFetchingListener() { + @Override + public void onBlockFetchSuccess(String blockId, ManagedBuffer data) { + synchronized (this) { + if (!res.successBlocks.contains(blockId) && !res.failedBlocks.contains(blockId)) { + data.retain(); + res.successBlocks.add(blockId); + res.buffers.add(data); + requestsRemaining.release(); + } + } + } + + @Override + public void onBlockFetchFailure(String blockId, Throwable exception) { + synchronized (this) { + if (!res.successBlocks.contains(blockId) && !res.failedBlocks.contains(blockId)) { + res.failedBlocks.add(blockId); + requestsRemaining.release(); + } + } + } + }); + + if (!requestsRemaining.tryAcquire(blockIds.length, 5, TimeUnit.SECONDS)) { + fail("Timeout getting response from the server"); + } + return res; + } + + @Test + public void testFetchOneSort() throws Exception { + registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER)); + FetchResult exec0Fetch = fetchBlocks("exec-0", new String[] { "shuffle_0_0_0" }); + assertEquals(Sets.newHashSet("shuffle_0_0_0"), exec0Fetch.successBlocks); + assertTrue(exec0Fetch.failedBlocks.isEmpty()); + assertBufferListsEqual(exec0Fetch.buffers, Lists.newArrayList(exec0Blocks[0])); + exec0Fetch.releaseBuffers(); + } + + @Test + public void testFetchThreeSort() throws Exception { + registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER)); + FetchResult exec0Fetch = fetchBlocks("exec-0", + new String[] { "shuffle_0_0_0", "shuffle_0_0_1", "shuffle_0_0_2" }); + assertEquals(Sets.newHashSet("shuffle_0_0_0", "shuffle_0_0_1", "shuffle_0_0_2"), + exec0Fetch.successBlocks); + assertTrue(exec0Fetch.failedBlocks.isEmpty()); + assertBufferListsEqual(exec0Fetch.buffers, Lists.newArrayList(exec0Blocks)); + exec0Fetch.releaseBuffers(); + } + + @Test + public void testFetchHash() throws Exception { + registerExecutor("exec-1", dataContext1.createExecutorInfo(HASH_MANAGER)); + FetchResult execFetch = fetchBlocks("exec-1", + new String[] { "shuffle_1_0_0", "shuffle_1_0_1" }); + assertEquals(Sets.newHashSet("shuffle_1_0_0", "shuffle_1_0_1"), execFetch.successBlocks); + assertTrue(execFetch.failedBlocks.isEmpty()); + assertBufferListsEqual(execFetch.buffers, Lists.newArrayList(exec1Blocks)); + execFetch.releaseBuffers(); + } + + @Test + public void testFetchWrongShuffle() throws Exception { + registerExecutor("exec-1", dataContext1.createExecutorInfo(SORT_MANAGER /* wrong manager */)); + FetchResult execFetch = fetchBlocks("exec-1", + new String[] { "shuffle_1_0_0", "shuffle_1_0_1" }); + assertTrue(execFetch.successBlocks.isEmpty()); + assertEquals(Sets.newHashSet("shuffle_1_0_0", "shuffle_1_0_1"), execFetch.failedBlocks); + } + + 
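The fetchBlocks helper above turns the asynchronous BlockFetchingListener callbacks into a synchronous result: each success or failure releases one Semaphore permit, and the caller then acquires blockIds.length permits under a timeout. A minimal, self-contained sketch of that wait pattern (the Completion interface and all names below are illustrative, not part of this patch):

import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;

public class CallbackWaitSketch {
  /** Hypothetical single-method callback, standing in for BlockFetchingListener. */
  interface Completion { void onDone(String id); }

  public static void main(String[] args) throws InterruptedException {
    final Semaphore remaining = new Semaphore(0);
    final Completion completion = new Completion() {
      @Override
      public void onDone(String id) {
        remaining.release(); // One permit per finished request.
      }
    };

    final int numRequests = 3;
    for (int i = 0; i < numRequests; i++) {
      final String id = "block-" + i;
      // Simulate asynchronous completions arriving on other threads.
      new Thread(new Runnable() {
        @Override
        public void run() {
          completion.onDone(id);
        }
      }).start();
    }

    // Block until every request completes, or fail after a bounded wait.
    if (!remaining.tryAcquire(numRequests, 5, TimeUnit.SECONDS)) {
      throw new AssertionError("Timed out waiting for callbacks");
    }
    System.out.println("All " + numRequests + " callbacks received");
  }
}

Acquiring all permits in a single tryAcquire call avoids counting races when callbacks arrive before the wait begins: permits released early are simply banked until the acquirer asks for them.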
@Test
+  public void testFetchInvalidShuffle() throws Exception {
+    registerExecutor("exec-1", dataContext1.createExecutorInfo("unknown sort manager"));
+    FetchResult execFetch = fetchBlocks("exec-1",
+      new String[] { "shuffle_1_0_0" });
+    assertTrue(execFetch.successBlocks.isEmpty());
+    assertEquals(Sets.newHashSet("shuffle_1_0_0"), execFetch.failedBlocks);
+  }
+
+  @Test
+  public void testFetchWrongBlockId() throws Exception {
+    registerExecutor("exec-1", dataContext1.createExecutorInfo(SORT_MANAGER /* wrong manager */));
+    FetchResult execFetch = fetchBlocks("exec-1",
+      new String[] { "rdd_1_0_0" });
+    assertTrue(execFetch.successBlocks.isEmpty());
+    assertEquals(Sets.newHashSet("rdd_1_0_0"), execFetch.failedBlocks);
+  }
+
+  @Test
+  public void testFetchNonexistent() throws Exception {
+    registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER));
+    FetchResult execFetch = fetchBlocks("exec-0",
+      new String[] { "shuffle_2_0_0" });
+    assertTrue(execFetch.successBlocks.isEmpty());
+    assertEquals(Sets.newHashSet("shuffle_2_0_0"), execFetch.failedBlocks);
+  }
+
+  @Test
+  public void testFetchWrongExecutor() throws Exception {
+    registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER));
+    FetchResult execFetch = fetchBlocks("exec-0",
+      new String[] { "shuffle_0_0_0" /* right */, "shuffle_1_0_0" /* wrong */ });
+    // Both still fail, as we start by checking for all blocks.
+    assertTrue(execFetch.successBlocks.isEmpty());
+    assertEquals(Sets.newHashSet("shuffle_0_0_0", "shuffle_1_0_0"), execFetch.failedBlocks);
+  }
+
+  @Test
+  public void testFetchUnregisteredExecutor() throws Exception {
+    registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER));
+    FetchResult execFetch = fetchBlocks("exec-2",
+      new String[] { "shuffle_0_0_0", "shuffle_1_0_0" });
+    assertTrue(execFetch.successBlocks.isEmpty());
+    assertEquals(Sets.newHashSet("shuffle_0_0_0", "shuffle_1_0_0"), execFetch.failedBlocks);
+  }
+
+  @Test
+  public void testFetchNoServer() throws Exception {
+    registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER));
+    FetchResult execFetch = fetchBlocks("exec-0",
+      new String[] { "shuffle_1_0_0", "shuffle_1_0_1" }, 1 /* port */);
+    assertTrue(execFetch.successBlocks.isEmpty());
+    assertEquals(Sets.newHashSet("shuffle_1_0_0", "shuffle_1_0_1"), execFetch.failedBlocks);
+  }
+
+  private void registerExecutor(String executorId, ExecutorShuffleInfo executorInfo) {
+    ExternalShuffleClient client = new ExternalShuffleClient(conf, APP_ID);
+    client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(),
+      executorId, executorInfo);
+  }
+
+  private void assertBufferListsEqual(List<ManagedBuffer> list0, List<byte[]> list1)
+    throws Exception {
+    assertEquals(list0.size(), list1.size());
+    for (int i = 0; i < list0.size(); i ++) {
+      assertBuffersEqual(list0.get(i), new NioManagedBuffer(ByteBuffer.wrap(list1.get(i))));
+    }
+  }
+
+  private void assertBuffersEqual(ManagedBuffer buffer0, ManagedBuffer buffer1) throws Exception {
+    ByteBuffer nio0 = buffer0.nioByteBuffer();
+    ByteBuffer nio1 = buffer1.nioByteBuffer();
+
+    int len = nio0.remaining();
+    assertEquals(nio0.remaining(), nio1.remaining());
+    for (int i = 0; i < len; i ++) {
+      assertEquals(nio0.get(), nio1.get());
+    }
+  }
+}
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java
new file mode 100644
index 0000000000..c18346f696
---
/dev/null +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.nio.ByteBuffer; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.concurrent.atomic.AtomicInteger; + +import com.google.common.collect.Maps; +import io.netty.buffer.Unpooled; +import org.junit.Test; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +import static org.junit.Assert.*; +import static org.junit.Assert.assertEquals; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.*; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NettyManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.client.ChunkReceivedCallback; +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.util.JavaUtils; + +public class OneForOneBlockFetcherSuite { + @Test + public void testFetchOne() { + LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap(); + blocks.put("shuffle_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[0]))); + + BlockFetchingListener listener = fetchBlocks(blocks); + + verify(listener).onBlockFetchSuccess("shuffle_0_0_0", blocks.get("shuffle_0_0_0")); + } + + @Test + public void testFetchThree() { + LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap(); + blocks.put("b0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12]))); + blocks.put("b1", new NioManagedBuffer(ByteBuffer.wrap(new byte[23]))); + blocks.put("b2", new NettyManagedBuffer(Unpooled.wrappedBuffer(new byte[23]))); + + BlockFetchingListener listener = fetchBlocks(blocks); + + for (int i = 0; i < 3; i ++) { + verify(listener, times(1)).onBlockFetchSuccess("b" + i, blocks.get("b" + i)); + } + } + + @Test + public void testFailure() { + LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap(); + blocks.put("b0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12]))); + blocks.put("b1", null); + blocks.put("b2", null); + + BlockFetchingListener listener = fetchBlocks(blocks); + + // Each failure will cause a failure to be invoked in all remaining block fetches. 
+ verify(listener, times(1)).onBlockFetchSuccess("b0", blocks.get("b0")); + verify(listener, times(1)).onBlockFetchFailure(eq("b1"), (Throwable) any()); + verify(listener, times(2)).onBlockFetchFailure(eq("b2"), (Throwable) any()); + } + + @Test + public void testFailureAndSuccess() { + LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap(); + blocks.put("b0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12]))); + blocks.put("b1", null); + blocks.put("b2", new NioManagedBuffer(ByteBuffer.wrap(new byte[21]))); + + BlockFetchingListener listener = fetchBlocks(blocks); + + // We may call both success and failure for the same block. + verify(listener, times(1)).onBlockFetchSuccess("b0", blocks.get("b0")); + verify(listener, times(1)).onBlockFetchFailure(eq("b1"), (Throwable) any()); + verify(listener, times(1)).onBlockFetchSuccess("b2", blocks.get("b2")); + verify(listener, times(1)).onBlockFetchFailure(eq("b2"), (Throwable) any()); + } + + @Test + public void testEmptyBlockFetch() { + try { + fetchBlocks(Maps.<String, ManagedBuffer>newLinkedHashMap()); + fail(); + } catch (IllegalArgumentException e) { + assertEquals("Zero-sized blockIds array", e.getMessage()); + } + } + + /** + * Begins a fetch on the given set of blocks by mocking out the server side of the RPC which + * simply returns the given (BlockId, Block) pairs. + * As "blocks" is a LinkedHashMap, the blocks are guaranteed to be returned in the same order + * that they were inserted in. + * + * If a block's buffer is "null", an exception will be thrown instead. + */ + private BlockFetchingListener fetchBlocks(final LinkedHashMap<String, ManagedBuffer> blocks) { + TransportClient client = mock(TransportClient.class); + BlockFetchingListener listener = mock(BlockFetchingListener.class); + String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]); + OneForOneBlockFetcher fetcher = new OneForOneBlockFetcher(client, blockIds, listener); + + // Respond to the "OpenBlocks" message with an appropirate ShuffleStreamHandle with streamId 123 + doAnswer(new Answer<Void>() { + @Override + public Void answer(InvocationOnMock invocationOnMock) throws Throwable { + String message = JavaUtils.deserialize((byte[]) invocationOnMock.getArguments()[0]); + RpcResponseCallback callback = (RpcResponseCallback) invocationOnMock.getArguments()[1]; + callback.onSuccess(JavaUtils.serialize(new ShuffleStreamHandle(123, blocks.size()))); + assertEquals("OpenZeBlocks", message); + return null; + } + }).when(client).sendRpc((byte[]) any(), (RpcResponseCallback) any()); + + // Respond to each chunk request with a single buffer from our blocks array. 
+ final AtomicInteger expectedChunkIndex = new AtomicInteger(0); + final Iterator<ManagedBuffer> blockIterator = blocks.values().iterator(); + doAnswer(new Answer<Void>() { + @Override + public Void answer(InvocationOnMock invocation) throws Throwable { + try { + long streamId = (Long) invocation.getArguments()[0]; + int myChunkIndex = (Integer) invocation.getArguments()[1]; + assertEquals(123, streamId); + assertEquals(expectedChunkIndex.getAndIncrement(), myChunkIndex); + + ChunkReceivedCallback callback = (ChunkReceivedCallback) invocation.getArguments()[2]; + ManagedBuffer result = blockIterator.next(); + if (result != null) { + callback.onSuccess(myChunkIndex, result); + } else { + callback.onFailure(myChunkIndex, new RuntimeException("Failed " + myChunkIndex)); + } + } catch (Exception e) { + e.printStackTrace(); + fail("Unexpected failure"); + } + return null; + } + }).when(client).fetchChunk(anyLong(), anyInt(), (ChunkReceivedCallback) any()); + + fetcher.start("OpenZeBlocks"); + return listener; + } +} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ShuffleMessagesSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ShuffleMessagesSuite.java new file mode 100644 index 0000000000..ee9482b49c --- /dev/null +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ShuffleMessagesSuite.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.shuffle; + +import org.junit.Test; + +import static org.junit.Assert.*; + +import org.apache.spark.network.util.JavaUtils; + +import static org.apache.spark.network.shuffle.ExternalShuffleMessages.*; + +public class ShuffleMessagesSuite { + @Test + public void serializeOpenShuffleBlocks() { + OpenShuffleBlocks msg = new OpenShuffleBlocks("app-1", "exec-2", + new String[] { "block0", "block1" }); + OpenShuffleBlocks msg2 = JavaUtils.deserialize(JavaUtils.serialize(msg)); + assertEquals(msg, msg2); + } + + @Test + public void serializeRegisterExecutor() { + RegisterExecutor msg = new RegisterExecutor("app-1", "exec-2", new ExecutorShuffleInfo( + new String[] { "/local1", "/local2" }, 32, "MyShuffleManager")); + RegisterExecutor msg2 = JavaUtils.deserialize(JavaUtils.serialize(msg)); + assertEquals(msg, msg2); + } + + @Test + public void serializeShuffleStreamHandle() { + ShuffleStreamHandle msg = new ShuffleStreamHandle(12345, 16); + ShuffleStreamHandle msg2 = JavaUtils.deserialize(JavaUtils.serialize(msg)); + assertEquals(msg, msg2); + } +} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java new file mode 100644 index 0000000000..442b756467 --- /dev/null +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.io.DataOutputStream; +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.OutputStream; + +import com.google.common.io.Files; + +/** + * Manages some sort- and hash-based shuffle data, including the creation + * and cleanup of directories that can be read by the {@link ExternalShuffleBlockManager}. + */ +public class TestShuffleDataContext { + private final String[] localDirs; + private final int subDirsPerLocalDir; + + public TestShuffleDataContext(int numLocalDirs, int subDirsPerLocalDir) { + this.localDirs = new String[numLocalDirs]; + this.subDirsPerLocalDir = subDirsPerLocalDir; + } + + public void create() { + for (int i = 0; i < localDirs.length; i ++) { + localDirs[i] = Files.createTempDir().getAbsolutePath(); + + for (int p = 0; p < subDirsPerLocalDir; p ++) { + new File(localDirs[i], String.format("%02x", p)).mkdirs(); + } + } + } + + public void cleanup() { + for (String localDir : localDirs) { + deleteRecursively(new File(localDir)); + } + } + + /** Creates reducer blocks in a sort-based data format within our local dirs. 
*/ + public void insertSortShuffleData(int shuffleId, int mapId, byte[][] blocks) throws IOException { + String blockId = "shuffle_" + shuffleId + "_" + mapId + "_0"; + + OutputStream dataStream = new FileOutputStream( + ExternalShuffleBlockManager.getFile(localDirs, subDirsPerLocalDir, blockId + ".data")); + DataOutputStream indexStream = new DataOutputStream(new FileOutputStream( + ExternalShuffleBlockManager.getFile(localDirs, subDirsPerLocalDir, blockId + ".index"))); + + long offset = 0; + indexStream.writeLong(offset); + for (byte[] block : blocks) { + offset += block.length; + dataStream.write(block); + indexStream.writeLong(offset); + } + + dataStream.close(); + indexStream.close(); + } + + /** Creates reducer blocks in a hash-based data format within our local dirs. */ + public void insertHashShuffleData(int shuffleId, int mapId, byte[][] blocks) throws IOException { + for (int i = 0; i < blocks.length; i ++) { + String blockId = "shuffle_" + shuffleId + "_" + mapId + "_" + i; + Files.write(blocks[i], + ExternalShuffleBlockManager.getFile(localDirs, subDirsPerLocalDir, blockId)); + } + } + + /** + * Creates an ExecutorShuffleInfo object based on the given shuffle manager which targets this + * context's directories. + */ + public ExecutorShuffleInfo createExecutorInfo(String shuffleManager) { + return new ExecutorShuffleInfo(localDirs, subDirsPerLocalDir, shuffleManager); + } + + private static void deleteRecursively(File f) { + assert f != null; + if (f.isDirectory()) { + File[] children = f.listFiles(); + if (children != null) { + for (File child : children) { + deleteRecursively(child); + } + } + } + f.delete(); + } +} @@ -92,6 +92,7 @@ <module>mllib</module> <module>tools</module> <module>network/common</module> + <module>network/shuffle</module> <module>streaming</module> <module>sql/catalyst</module> <module>sql/core</module> diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 77083518bb..33618f5401 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -31,11 +31,12 @@ object BuildCommons { private val buildLocation = file(".").getAbsoluteFile.getParentFile val allProjects@Seq(bagel, catalyst, core, graphx, hive, hiveThriftServer, mllib, repl, - sql, networkCommon, streaming, streamingFlumeSink, streamingFlume, streamingKafka, streamingMqtt, - streamingTwitter, streamingZeromq) = + sql, networkCommon, networkShuffle, streaming, streamingFlumeSink, streamingFlume, streamingKafka, + streamingMqtt, streamingTwitter, streamingZeromq) = Seq("bagel", "catalyst", "core", "graphx", "hive", "hive-thriftserver", "mllib", "repl", - "sql", "network-common", "streaming", "streaming-flume-sink", "streaming-flume", "streaming-kafka", - "streaming-mqtt", "streaming-twitter", "streaming-zeromq").map(ProjectRef(buildLocation, _)) + "sql", "network-common", "network-shuffle", "streaming", "streaming-flume-sink", + "streaming-flume", "streaming-kafka", "streaming-mqtt", "streaming-twitter", + "streaming-zeromq").map(ProjectRef(buildLocation, _)) val optionallyEnabledProjects@Seq(yarn, yarnStable, yarnAlpha, java8Tests, sparkGangliaLgpl, sparkKinesisAsl) = Seq("yarn", "yarn-stable", "yarn-alpha", "java8-tests", "ganglia-lgpl", "kinesis-asl") @@ -142,7 +143,7 @@ object SparkBuild extends PomBuild { // TODO: Add Sql to mima checks allProjects.filterNot(x => Seq(spark, sql, hive, hiveThriftServer, catalyst, repl, - streamingFlumeSink, networkCommon).contains(x)).foreach { + streamingFlumeSink, networkCommon, networkShuffle).contains(x)).foreach { x => 
enable(MimaBuild.mimaSettings(sparkHome, x))(x) }
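
Taken together, the new network/shuffle module gives an end-to-end client path: register an executor's shuffle directories with the external server, then fetch its blocks over the wire. A hedged usage sketch mirroring the calls made in ExternalShuffleIntegrationSuite (host, port, directory, and id values below are placeholders, not defaults taken from this patch):

import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.shuffle.BlockFetchingListener;
import org.apache.spark.network.shuffle.ExecutorShuffleInfo;
import org.apache.spark.network.shuffle.ExternalShuffleClient;
import org.apache.spark.network.util.SystemPropertyConfigProvider;
import org.apache.spark.network.util.TransportConf;

public class ExternalShuffleClientSketch {
  public static void main(String[] args) {
    TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());
    ExternalShuffleClient client = new ExternalShuffleClient(conf, "app-id");

    // Tell the shuffle server where this executor's sort-based shuffle files live.
    client.registerWithShuffleServer("shuffle-host", 7337, "exec-0",
      new ExecutorShuffleInfo(new String[] { "/tmp/spark-local" }, 5,
        "org.apache.spark.shuffle.sort.SortShuffleManager"));

    // Fetch a block; the listener fires per block as soon as its data arrives.
    client.fetchBlocks("shuffle-host", 7337, "exec-0",
      new String[] { "shuffle_0_0_0" },
      new BlockFetchingListener() {
        @Override
        public void onBlockFetchSuccess(String blockId, ManagedBuffer data) {
          // Consume the buffer here; call data.retain() if it must outlive this callback.
          System.out.println("Fetched " + blockId + " (" + data.size() + " bytes)");
        }

        @Override
        public void onBlockFetchFailure(String blockId, Throwable e) {
          // A failed "open blocks" RPC fails every remaining block id at once.
          System.err.println("Failed to fetch " + blockId + ": " + e);
        }
      });
  }
}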