author      Aaron Davidson <aaron@databricks.com>  2014-11-01 14:37:45 -0700
committer   Reynold Xin <rxin@databricks.com>      2014-11-01 14:37:45 -0700
commit      f55218aeb1e9d638df6229b36a59a15ce5363482
tree        84e4454c224b3f14b7fcbe8259c90d06b6fd969b
parent      1d4f3552037cb667971bea2e5078d8b3ce6c2eae
[SPARK-3796] Create external service which can serve shuffle files

This patch introduces the tooling necessary to construct an external shuffle service which is independent of Spark executors, and to use this service inside Spark. An example (just for the sake of this PR) of the service creation can be found in Worker, and the service itself is used by plugging in the StandaloneShuffleClient as Spark's ShuffleClient (set up in BlockManager).

This PR continues the work from #2753, which extracted the transport layer of Spark's block transfer into an independent package within Spark. A new package was created which contains the Spark business logic necessary to retrieve the actual shuffle data, and it is completely independent of the transport layer introduced in the previous patch. Like the transport layer, this package must not depend on Spark, as we anticipate plugging this service in as a lightweight process within, say, the YARN NodeManager, and do not wish to include Spark's dependencies (including Scala itself).

There are several outstanding tasks which must be completed before this PR can be merged:
- [x] Complete unit testing of the network/shuffle package.
- [x] Performance and correctness testing on a real cluster.
- [x] Remove example service instantiation from Worker.scala.

There are even more shortcomings of this PR which should be addressed in followup patches:
- Don't use the Java serializer for the RPC layer! It is not cross-version compatible.
- Handle shuffle file cleanup for dead executors once the application terminates or the ContextCleaner triggers.
- Document the feature in the Spark docs.
- Improve behavior if the shuffle service itself goes down (right now we don't blacklist it, and new executors cannot spawn on that machine).
- SSL and SASL integration.
- Nice to have: handle shuffle file consolidation (this would require changes to Spark's implementation).

Author: Aaron Davidson <aaron@databricks.com>

Closes #3001 from aarondav/shuffle-service and squashes the following commits:

4d1f8c1 [Aaron Davidson] Remove changes to Worker
705748f [Aaron Davidson] Rename Standalone* to External*
fd3928b [Aaron Davidson] Do not unregister executor outputs unduly
9883918 [Aaron Davidson] Make suggested build changes
3d62679 [Aaron Davidson] Add Spark integration test
7fe51d5 [Aaron Davidson] Fix SBT integration
56caa50 [Aaron Davidson] Address comments
c8d1ac3 [Aaron Davidson] Add unit tests
2f70c0c [Aaron Davidson] Fix unit tests
5483e96 [Aaron Davidson] Fix unit tests
46a70bf [Aaron Davidson] Whoops, bracket
5ea4df6 [Aaron Davidson] [SPARK-3796] Create external service which can serve shuffle files
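
For orientation, here is a minimal sketch of how the pieces introduced by this patch fit together, modeled directly on the new ExternalShuffleServiceSuite further down in this diff. It stands the shuffle service up in-process purely for illustration (the intent, per the message above, is a lightweight external process such as one inside the YARN NodeManager); the config keys and API calls are taken from this diff, while the surrounding driver program is an assumption.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._
import org.apache.spark.network.TransportContext
import org.apache.spark.network.netty.SparkTransportConf
import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler

object ExternalShuffleServiceSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf(loadDefaults = false)

    // Stand up the external shuffle service (in-process here; external in practice).
    val transportConf = SparkTransportConf.fromSparkConf(conf)
    val handler = new ExternalShuffleBlockHandler()
    val server = new TransportContext(transportConf, handler).createServer()

    // Route all shuffle fetches through the service: BlockManager swaps in
    // ExternalShuffleClient for the usual executor-to-executor
    // BlockTransferService, registers each executor's shuffle directory
    // layout with the service, and advertises map outputs under the
    // service's port (shuffleServerId) rather than the executor's own.
    conf.set("spark.shuffle.service.enabled", "true")
    conf.set("spark.shuffle.service.port", server.getPort.toString)

    val sc = new SparkContext("local-cluster[2,1,512]", "shuffle-service-sketch", conf)
    // Shuffle files are now served by `server`, so they outlive their executors.
    sc.parallelize(0 until 1000, 10).map(i => (i, 1)).reduceByKey(_ + _).count()
    sc.stop()
    server.close()
  }
}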
-rw-r--r--  core/pom.xml | 5
-rw-r--r--  core/src/main/scala/org/apache/spark/MapOutputTracker.scala | 4
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkEnv.scala | 2
-rwxr-xr-x  core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/Executor.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/network/BlockTransferService.scala | 55
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/NettyBlockFetcher.scala | 95
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala | 26
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala | 29
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala | 32
-rw-r--r--  core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala | 8
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala | 41
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/Stage.scala | 10
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala | 5
-rw-r--r--  core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala | 3
-rw-r--r--  core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockId.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManager.scala | 71
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala | 25
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala | 14
-rw-r--r--  core/src/main/scala/org/apache/spark/util/Utils.scala | 2
-rw-r--r--  core/src/test/scala/org/apache/spark/DistributedSuite.scala | 3
-rw-r--r--  core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala | 76
-rw-r--r--  core/src/test/scala/org/apache/spark/HashShuffleSuite.scala | 6
-rw-r--r--  core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala | 6
-rw-r--r--  core/src/test/scala/org/apache/spark/ShuffleSuite.scala | 34
-rw-r--r--  core/src/test/scala/org/apache/spark/SortShuffleSuite.scala | 6
-rw-r--r--  core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala | 4
-rw-r--r--  core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala | 25
-rw-r--r--  network/common/pom.xml | 20
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/TransportContext.java | 14
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/client/TransportClient.java | 32
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java | 17
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java (renamed from network/common/src/test/java/org/apache/spark/network/NoOpRpcHandler.java) | 18
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java (renamed from network/common/src/main/java/org/apache/spark/network/server/DefaultStreamManager.java) | 8
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java | 6
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java | 9
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/server/TransportServer.java | 8
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java | 38
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java (renamed from network/common/src/test/java/org/apache/spark/network/SystemPropertyConfigProvider.java) | 2
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/util/TransportConf.java | 3
-rw-r--r--  network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java | 16
-rw-r--r--  network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java | 9
-rw-r--r--  network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java | 7
-rw-r--r--  network/shuffle/pom.xml | 96
-rw-r--r--  network/shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java (renamed from core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala) | 18
-rw-r--r--  network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExecutorShuffleInfo.java | 64
-rw-r--r--  network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java | 102
-rw-r--r--  network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java | 154
-rw-r--r--  network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java | 88
-rw-r--r--  network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleMessages.java | 106
-rw-r--r--  network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java | 121
-rw-r--r--  network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java | 35
-rw-r--r--  network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleStreamHandle.java | 60
-rw-r--r--  network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java | 123
-rw-r--r--  network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManagerSuite.java | 125
-rw-r--r--  network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java | 291
-rw-r--r--  network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java | 167
-rw-r--r--  network/shuffle/src/test/java/org/apache/spark/network/shuffle/ShuffleMessagesSuite.java | 51
-rw-r--r--  network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java | 107
-rw-r--r--  pom.xml | 1
-rw-r--r--  project/SparkBuild.scala | 11
65 files changed, 2216 insertions(+), 312 deletions(-)
diff --git a/core/pom.xml b/core/pom.xml
index 6963ce4777..41296e0eca 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -50,6 +50,11 @@
<version>${project.version}</version>
</dependency>
<dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-network-shuffle_2.10</artifactId>
+ <version>${project.version}</version>
+ </dependency>
+ <dependency>
<groupId>net.java.dev.jets3t</groupId>
<artifactId>jets3t</artifactId>
</dependency>
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 4cb0bd4142..7d96962c4a 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -178,6 +178,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses)
}
} else {
+ logError("Missing all output locations for shuffle " + shuffleId)
throw new MetadataFetchFailedException(
shuffleId, reduceId, "Missing all output locations for shuffle " + shuffleId)
}
@@ -348,7 +349,7 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
new ConcurrentHashMap[Int, Array[MapStatus]]
}
-private[spark] object MapOutputTracker {
+private[spark] object MapOutputTracker extends Logging {
// Serialize an array of map output locations into an efficient byte format so that we can send
// it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will
@@ -381,6 +382,7 @@ private[spark] object MapOutputTracker {
statuses.map {
status =>
if (status == null) {
+ logError("Missing an output location for shuffle " + shuffleId)
throw new MetadataFetchFailedException(
shuffleId, reduceId, "Missing an output location for shuffle " + shuffleId)
} else {
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 16c5d6648d..e2f13accdf 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -32,7 +32,7 @@ import org.apache.spark.api.python.PythonWorkerFactory
import org.apache.spark.broadcast.BroadcastManager
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.network.BlockTransferService
-import org.apache.spark.network.netty.{NettyBlockTransferService}
+import org.apache.spark.network.netty.NettyBlockTransferService
import org.apache.spark.network.nio.NioBlockTransferService
import org.apache.spark.scheduler.LiveListenerBus
import org.apache.spark.serializer.Serializer
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
index c4a8ec2e5e..f1f66d0903 100755
--- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
@@ -186,11 +186,11 @@ private[spark] class Worker(
private def retryConnectToMaster() {
Utils.tryOrExit {
connectionAttemptCount += 1
- logInfo(s"Attempting to connect to master (attempt # $connectionAttemptCount")
if (registered) {
registrationRetryTimer.foreach(_.cancel())
registrationRetryTimer = None
} else if (connectionAttemptCount <= TOTAL_REGISTRATION_RETRIES) {
+ logInfo(s"Retrying connection to master (attempt # $connectionAttemptCount)")
tryRegisterAllMasters()
if (connectionAttemptCount == INITIAL_REGISTRATION_RETRIES) {
registrationRetryTimer.foreach(_.cancel())
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 2889f59e33..c78e0ffca2 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -78,7 +78,7 @@ private[spark] class Executor(
val executorSource = new ExecutorSource(this, executorId)
// Initialize Spark environment (using system properties read above)
- conf.set("spark.executor.id", "executor." + executorId)
+ conf.set("spark.executor.id", executorId)
private val env = {
if (!isLocal) {
val port = conf.getInt("spark.executor.port", 0)
diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
index b083f46533..210a581db4 100644
--- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
@@ -20,16 +20,16 @@ package org.apache.spark.network
import java.io.Closeable
import java.nio.ByteBuffer
-import scala.concurrent.{Await, Future}
+import scala.concurrent.{Promise, Await, Future}
import scala.concurrent.duration.Duration
import org.apache.spark.Logging
import org.apache.spark.network.buffer.{NioManagedBuffer, ManagedBuffer}
-import org.apache.spark.storage.{BlockId, StorageLevel}
-import org.apache.spark.util.Utils
+import org.apache.spark.network.shuffle.{ShuffleClient, BlockFetchingListener}
+import org.apache.spark.storage.{BlockManagerId, BlockId, StorageLevel}
private[spark]
-abstract class BlockTransferService extends Closeable with Logging {
+abstract class BlockTransferService extends ShuffleClient with Closeable with Logging {
/**
* Initialize the transfer service by giving it the BlockDataManager that can be used to fetch
@@ -60,10 +60,11 @@ abstract class BlockTransferService extends Closeable with Logging {
* return a future so the underlying implementation can invoke onBlockFetchSuccess as soon as
* the data of a block is fetched, rather than waiting for all blocks to be fetched.
*/
- def fetchBlocks(
- hostName: String,
+ override def fetchBlocks(
+ host: String,
port: Int,
- blockIds: Seq[String],
+ execId: String,
+ blockIds: Array[String],
listener: BlockFetchingListener): Unit
/**
@@ -81,43 +82,23 @@ abstract class BlockTransferService extends Closeable with Logging {
*
* It is also only available after [[init]] is invoked.
*/
- def fetchBlockSync(hostName: String, port: Int, blockId: String): ManagedBuffer = {
+ def fetchBlockSync(host: String, port: Int, execId: String, blockId: String): ManagedBuffer = {
// A monitor for the thread to wait on.
- val lock = new Object
- @volatile var result: Either[ManagedBuffer, Throwable] = null
- fetchBlocks(hostName, port, Seq(blockId), new BlockFetchingListener {
- override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = {
- lock.synchronized {
- result = Right(exception)
- lock.notify()
+ val result = Promise[ManagedBuffer]()
+ fetchBlocks(host, port, execId, Array(blockId),
+ new BlockFetchingListener {
+ override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = {
+ result.failure(exception)
}
- }
- override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = {
- lock.synchronized {
+ override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = {
val ret = ByteBuffer.allocate(data.size.toInt)
ret.put(data.nioByteBuffer())
ret.flip()
- result = Left(new NioManagedBuffer(ret))
- lock.notify()
+ result.success(new NioManagedBuffer(ret))
}
- }
- })
+ })
- // Sleep until result is no longer null
- lock.synchronized {
- while (result == null) {
- try {
- lock.wait()
- } catch {
- case e: InterruptedException =>
- }
- }
- }
-
- result match {
- case Left(data) => data
- case Right(e) => throw e
- }
+ Await.result(result.future, Duration.Inf)
}
/**
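
Aside on the fetchBlockSync rewrite above: it replaces the hand-rolled monitor (wait/notify around a @volatile Either) with a scala.concurrent.Promise that the asynchronous listener completes. A self-contained sketch of the same pattern, assuming nothing beyond the Scala standard library (the names here are illustrative, not from the patch):

import scala.concurrent.{Await, Promise}
import scala.concurrent.duration.Duration

// Bridge a callback-based async API into a blocking call. The Promise is
// completed exactly once by whichever callback fires; Await.result either
// returns the value or rethrows the failure, replacing the manual
// Left/Right unwrapping and lock.notify() signaling of the old code.
def awaitCallback[T](register: (T => Unit, Throwable => Unit) => Unit): T = {
  val result = Promise[T]()
  register(v => result.success(v), e => result.failure(e))
  Await.result(result.future, Duration.Inf)
}

One behavioral difference worth noting: the old loop swallowed InterruptedException and kept waiting, whereas Await.result propagates it to the caller.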
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockFetcher.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockFetcher.scala
deleted file mode 100644
index 8c5ffd8da6..0000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockFetcher.scala
+++ /dev/null
@@ -1,95 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty
-
-import java.nio.ByteBuffer
-import java.util
-
-import org.apache.spark.{SparkConf, Logging}
-import org.apache.spark.network.BlockFetchingListener
-import org.apache.spark.network.netty.NettyMessages._
-import org.apache.spark.serializer.{JavaSerializer, Serializer}
-import org.apache.spark.network.buffer.ManagedBuffer
-import org.apache.spark.network.client.{RpcResponseCallback, ChunkReceivedCallback, TransportClient}
-import org.apache.spark.storage.BlockId
-import org.apache.spark.util.Utils
-
-/**
- * Responsible for holding the state for a request for a single set of blocks. This assumes that
- * the chunks will be returned in the same order as requested, and that there will be exactly
- * one chunk per block.
- *
- * Upon receipt of any block, the listener will be called back. Upon failure part way through,
- * the listener will receive a failure callback for each outstanding block.
- */
-class NettyBlockFetcher(
- serializer: Serializer,
- client: TransportClient,
- blockIds: Seq[String],
- listener: BlockFetchingListener)
- extends Logging {
-
- require(blockIds.nonEmpty)
-
- private val ser = serializer.newInstance()
-
- private var streamHandle: ShuffleStreamHandle = _
-
- private val chunkCallback = new ChunkReceivedCallback {
- // On receipt of a chunk, pass it upwards as a block.
- def onSuccess(chunkIndex: Int, buffer: ManagedBuffer): Unit = Utils.logUncaughtExceptions {
- listener.onBlockFetchSuccess(blockIds(chunkIndex), buffer)
- }
-
- // On receipt of a failure, fail every block from chunkIndex onwards.
- def onFailure(chunkIndex: Int, e: Throwable): Unit = {
- blockIds.drop(chunkIndex).foreach { blockId =>
- listener.onBlockFetchFailure(blockId, e);
- }
- }
- }
-
- /** Begins the fetching process, calling the listener with every block fetched. */
- def start(): Unit = {
- // Send the RPC to open the given set of blocks. This will return a ShuffleStreamHandle.
- client.sendRpc(ser.serialize(OpenBlocks(blockIds.map(BlockId.apply))).array(),
- new RpcResponseCallback {
- override def onSuccess(response: Array[Byte]): Unit = {
- try {
- streamHandle = ser.deserialize[ShuffleStreamHandle](ByteBuffer.wrap(response))
- logTrace(s"Successfully opened block set: $streamHandle! Preparing to fetch chunks.")
-
- // Immediately request all chunks -- we expect that the total size of the request is
- // reasonable due to higher level chunking in [[ShuffleBlockFetcherIterator]].
- for (i <- 0 until streamHandle.numChunks) {
- client.fetchChunk(streamHandle.streamId, i, chunkCallback)
- }
- } catch {
- case e: Exception =>
- logError("Failed while starting block fetches", e)
- blockIds.foreach(blockId => Utils.tryLog(listener.onBlockFetchFailure(blockId, e)))
- }
- }
-
- override def onFailure(e: Throwable): Unit = {
- logError("Failed while starting block fetches", e)
- blockIds.foreach(blockId => Utils.tryLog(listener.onBlockFetchFailure(blockId, e)))
- }
- })
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
index 02c657e1d6..1950e7bd63 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
@@ -19,39 +19,41 @@ package org.apache.spark.network.netty
import java.nio.ByteBuffer
+import scala.collection.JavaConversions._
+
import org.apache.spark.Logging
import org.apache.spark.network.BlockDataManager
+import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
+import org.apache.spark.network.client.{RpcResponseCallback, TransportClient}
+import org.apache.spark.network.server.{OneForOneStreamManager, RpcHandler, StreamManager}
+import org.apache.spark.network.shuffle.ShuffleStreamHandle
import org.apache.spark.serializer.Serializer
-import org.apache.spark.network.buffer.{NioManagedBuffer, ManagedBuffer}
-import org.apache.spark.network.client.{TransportClient, RpcResponseCallback}
-import org.apache.spark.network.server.{DefaultStreamManager, RpcHandler}
-import org.apache.spark.storage.{StorageLevel, BlockId}
-
-import scala.collection.JavaConversions._
+import org.apache.spark.storage.{BlockId, StorageLevel}
object NettyMessages {
-
/** Request to read a set of blocks. Returns [[ShuffleStreamHandle]] to identify the stream. */
case class OpenBlocks(blockIds: Seq[BlockId])
/** Request to upload a block with a certain StorageLevel. Returns nothing (empty byte array). */
case class UploadBlock(blockId: BlockId, blockData: Array[Byte], level: StorageLevel)
-
- /** Identifier for a fixed number of chunks to read from a stream created by [[OpenBlocks]]. */
- case class ShuffleStreamHandle(streamId: Long, numChunks: Int)
}
/**
* Serves requests to open blocks by simply registering one chunk per block requested.
+ * Handles opening and uploading arbitrary BlockManager blocks.
+ *
+ * Opened blocks are registered with the "one-for-one" strategy, meaning each Transport-layer Chunk
+ * is equivalent to one Spark-level shuffle block.
*/
class NettyBlockRpcServer(
serializer: Serializer,
- streamManager: DefaultStreamManager,
blockManager: BlockDataManager)
extends RpcHandler with Logging {
import NettyMessages._
+ private val streamManager = new OneForOneStreamManager()
+
override def receive(
client: TransportClient,
messageBytes: Array[Byte],
@@ -73,4 +75,6 @@ class NettyBlockRpcServer(
responseContext.onSuccess(new Array[Byte](0))
}
}
+
+ override def getStreamManager(): StreamManager = streamManager
}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
index 38a3e94515..ec3000e722 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
@@ -17,15 +17,15 @@
package org.apache.spark.network.netty
-import scala.concurrent.{Promise, Future}
+import scala.concurrent.{Future, Promise}
import org.apache.spark.SparkConf
import org.apache.spark.network._
import org.apache.spark.network.buffer.ManagedBuffer
-import org.apache.spark.network.client.{RpcResponseCallback, TransportClient, TransportClientFactory}
-import org.apache.spark.network.netty.NettyMessages.UploadBlock
+import org.apache.spark.network.client.{RpcResponseCallback, TransportClientFactory}
+import org.apache.spark.network.netty.NettyMessages.{OpenBlocks, UploadBlock}
import org.apache.spark.network.server._
-import org.apache.spark.network.util.{ConfigProvider, TransportConf}
+import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher}
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.storage.{BlockId, StorageLevel}
import org.apache.spark.util.Utils
@@ -37,30 +37,29 @@ class NettyBlockTransferService(conf: SparkConf) extends BlockTransferService {
// TODO: Don't use Java serialization, use a more cross-version compatible serialization format.
val serializer = new JavaSerializer(conf)
- // Create a TransportConfig using SparkConf.
- private[this] val transportConf = new TransportConf(
- new ConfigProvider { override def get(name: String) = conf.get(name) })
-
private[this] var transportContext: TransportContext = _
private[this] var server: TransportServer = _
private[this] var clientFactory: TransportClientFactory = _
override def init(blockDataManager: BlockDataManager): Unit = {
- val streamManager = new DefaultStreamManager
- val rpcHandler = new NettyBlockRpcServer(serializer, streamManager, blockDataManager)
- transportContext = new TransportContext(transportConf, streamManager, rpcHandler)
+ val rpcHandler = new NettyBlockRpcServer(serializer, blockDataManager)
+ transportContext = new TransportContext(SparkTransportConf.fromSparkConf(conf), rpcHandler)
clientFactory = transportContext.createClientFactory()
server = transportContext.createServer()
+ logInfo("Server created on " + server.getPort)
}
override def fetchBlocks(
- hostname: String,
+ host: String,
port: Int,
- blockIds: Seq[String],
+ execId: String,
+ blockIds: Array[String],
listener: BlockFetchingListener): Unit = {
+ logTrace(s"Fetch blocks from $host:$port (executor id $execId)")
try {
- val client = clientFactory.createClient(hostname, port)
- new NettyBlockFetcher(serializer, client, blockIds, listener).start()
+ val client = clientFactory.createClient(host, port)
+ new OneForOneBlockFetcher(client, blockIds.toArray, listener)
+ .start(OpenBlocks(blockIds.map(BlockId.apply)))
} catch {
case e: Exception =>
logError("Exception while beginning fetchBlocks", e)
diff --git a/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala
new file mode 100644
index 0000000000..9fa4fa77b8
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.netty
+
+import org.apache.spark.SparkConf
+import org.apache.spark.network.util.{TransportConf, ConfigProvider}
+
+/**
+ * Utility for creating a [[TransportConf]] from a [[SparkConf]].
+ */
+object SparkTransportConf {
+ def fromSparkConf(conf: SparkConf): TransportConf = {
+ new TransportConf(new ConfigProvider {
+ override def get(name: String): String = conf.get(name)
+ })
+ }
+}
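
Usage of the new helper is a one-liner wherever a TransportConf is needed; this same call appears in NettyBlockTransferService, BlockManager, and the new ExternalShuffleServiceSuite in this diff:

import org.apache.spark.SparkConf
import org.apache.spark.network.netty.SparkTransportConf

val transportConf = SparkTransportConf.fromSparkConf(new SparkConf())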
diff --git a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala
index 11793ea92a..f56d165dab 100644
--- a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala
@@ -21,6 +21,7 @@ import java.nio.ByteBuffer
import org.apache.spark.network._
import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
+import org.apache.spark.network.shuffle.BlockFetchingListener
import org.apache.spark.storage.{BlockId, StorageLevel}
import org.apache.spark.util.Utils
import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException}
@@ -79,13 +80,14 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa
}
override def fetchBlocks(
- hostName: String,
+ host: String,
port: Int,
- blockIds: Seq[String],
+ execId: String,
+ blockIds: Array[String],
listener: BlockFetchingListener): Unit = {
checkInit()
- val cmId = new ConnectionManagerId(hostName, port)
+ val cmId = new ConnectionManagerId(host, port)
val blockMessageArray = new BlockMessageArray(blockIds.map { blockId =>
BlockMessage.fromGetBlock(GetBlock(BlockId(blockId)))
})
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index f81fa6d808..af17b5d5d2 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -124,6 +124,9 @@ class DAGScheduler(
/** If enabled, we may run certain actions like take() and first() locally. */
private val localExecutionEnabled = sc.getConf.getBoolean("spark.localExecution.enabled", false)
+ /** If enabled, FetchFailed will not cause stage retry, in order to surface the problem. */
+ private val disallowStageRetryForTest = sc.getConf.getBoolean("spark.test.noStageRetry", false)
+
private def initializeEventProcessActor() {
// blocking the thread until supervisor is started, which ensures eventProcessActor is
// not null before any job is submitted
@@ -1064,7 +1067,9 @@ class DAGScheduler(
runningStages -= failedStage
}
- if (failedStages.isEmpty && eventProcessActor != null) {
+ if (disallowStageRetryForTest) {
+ abortStage(failedStage, "Fetch failure will not retry stage due to testing config")
+ } else if (failedStages.isEmpty && eventProcessActor != null) {
// Don't schedule an event to resubmit failed stages if failed isn't empty, because
// in that case the event will already have been scheduled. eventProcessActor may be
// null during unit tests.
@@ -1086,7 +1091,7 @@ class DAGScheduler(
// TODO: mark the executor as failed only if there were lots of fetch failures on it
if (bmAddress != null) {
- handleExecutorLost(bmAddress.executorId, Some(task.epoch))
+ handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch))
}
case ExceptionFailure(className, description, stackTrace, metrics) =>
@@ -1106,25 +1111,35 @@ class DAGScheduler(
* Responds to an executor being lost. This is called inside the event loop, so it assumes it can
* modify the scheduler's internal state. Use executorLost() to post a loss event from outside.
*
+ * We will also assume that we've lost all shuffle blocks associated with the executor if the
+ * executor serves its own blocks (i.e., we're not using external shuffle) OR a FetchFailed
+ * occurred, in which case we presume all shuffle data related to this executor to be lost.
+ *
* Optionally the epoch during which the failure was caught can be passed to avoid allowing
* stray fetch failures from possibly retriggering the detection of a node as lost.
*/
- private[scheduler] def handleExecutorLost(execId: String, maybeEpoch: Option[Long] = None) {
+ private[scheduler] def handleExecutorLost(
+ execId: String,
+ fetchFailed: Boolean,
+ maybeEpoch: Option[Long] = None) {
val currentEpoch = maybeEpoch.getOrElse(mapOutputTracker.getEpoch)
if (!failedEpoch.contains(execId) || failedEpoch(execId) < currentEpoch) {
failedEpoch(execId) = currentEpoch
logInfo("Executor lost: %s (epoch %d)".format(execId, currentEpoch))
blockManagerMaster.removeExecutor(execId)
- // TODO: This will be really slow if we keep accumulating shuffle map stages
- for ((shuffleId, stage) <- shuffleToMapStage) {
- stage.removeOutputsOnExecutor(execId)
- val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray
- mapOutputTracker.registerMapOutputs(shuffleId, locs, changeEpoch = true)
- }
- if (shuffleToMapStage.isEmpty) {
- mapOutputTracker.incrementEpoch()
+
+ if (!env.blockManager.externalShuffleServiceEnabled || fetchFailed) {
+ // TODO: This will be really slow if we keep accumulating shuffle map stages
+ for ((shuffleId, stage) <- shuffleToMapStage) {
+ stage.removeOutputsOnExecutor(execId)
+ val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray
+ mapOutputTracker.registerMapOutputs(shuffleId, locs, changeEpoch = true)
+ }
+ if (shuffleToMapStage.isEmpty) {
+ mapOutputTracker.incrementEpoch()
+ }
+ clearCacheLocs()
}
- clearCacheLocs()
} else {
logDebug("Additional executor lost message for " + execId +
"(epoch " + currentEpoch + ")")
@@ -1382,7 +1397,7 @@ private[scheduler] class DAGSchedulerEventProcessActor(dagScheduler: DAGSchedule
dagScheduler.handleExecutorAdded(execId, host)
case ExecutorLost(execId) =>
- dagScheduler.handleExecutorLost(execId)
+ dagScheduler.handleExecutorLost(execId, fetchFailed = false)
case BeginEvent(task, taskInfo) =>
dagScheduler.handleBeginEvent(task, taskInfo)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
index 071568cdfb..cc13f57a49 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -102,6 +102,11 @@ private[spark] class Stage(
}
}
+ /**
+ * Removes all shuffle outputs associated with this executor. Note that this will also remove
+ * outputs which are served by an external shuffle server (if one exists), as they are still
+ * registered with this execId.
+ */
def removeOutputsOnExecutor(execId: String) {
var becameUnavailable = false
for (partition <- 0 until numPartitions) {
@@ -131,4 +136,9 @@ private[spark] class Stage(
override def toString = "Stage " + id
override def hashCode(): Int = id
+
+ override def equals(other: Any): Boolean = other match {
+ case stage: Stage => stage != null && stage.id == id
+ case _ => false
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index a6c23fc85a..376821f89c 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -687,10 +687,11 @@ private[spark] class TaskSetManager(
addPendingTask(index, readding=true)
}
- // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage.
+ // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage,
+ // and we are not using an external shuffle server which could serve the shuffle outputs.
// The reason is the next stage wouldn't be able to fetch the data from this dead executor
// so we would need to rerun these tasks on other executors.
- if (tasks(0).isInstanceOf[ShuffleMapTask]) {
+ if (tasks(0).isInstanceOf[ShuffleMapTask] && !env.blockManager.externalShuffleServiceEnabled) {
for ((tid, info) <- taskInfos if info.executorId == execId) {
val index = taskInfos(tid).index
if (successful(index)) {
diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala
index 1fb5b2c454..f03e8e4bf1 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala
@@ -62,7 +62,8 @@ private[spark] trait ShuffleWriterGroup {
* each block stored in each file. In order to find the location of a shuffle block, we search the
* files within a ShuffleFileGroups associated with the block's reducer.
*/
-
+// Note: Changes to the format in this file should be kept in sync with
+// org.apache.spark.network.shuffle.StandaloneShuffleBlockManager#getHashBasedShuffleBlockData().
private[spark]
class FileShuffleBlockManager(conf: SparkConf)
extends ShuffleBlockManager with Logging {
diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala
index e9805c9c13..a48f0c9ece 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala
@@ -35,6 +35,8 @@ import org.apache.spark.storage._
* as the filename postfix for data file, and ".index" as the filename postfix for index file.
*
*/
+// Note: Changes to the format in this file should be kept in sync with
+// org.apache.spark.network.shuffle.StandaloneShuffleBlockManager#getSortBasedShuffleBlockData().
private[spark]
class IndexShuffleBlockManager extends ShuffleBlockManager {
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
index 6cf9305977..f49917b7fe 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
@@ -74,7 +74,7 @@ private[hash] object BlockStoreShuffleFetcher extends Logging {
val blockFetcherItr = new ShuffleBlockFetcherIterator(
context,
- SparkEnv.get.blockTransferService,
+ SparkEnv.get.blockManager.shuffleClient,
blockManager,
blocksByAddress,
serializer,
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
index 746ed33b54..183a30373b 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
@@ -107,7 +107,7 @@ private[spark] class HashShuffleWriter[K, V](
writer.commitAndClose()
writer.fileSegment().length
}
- MapStatus(blockManager.blockManagerId, sizes)
+ MapStatus(blockManager.shuffleServerId, sizes)
}
private def revertWrites(): Unit = {
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index 927481b72c..d75f9d7311 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -70,7 +70,7 @@ private[spark] class SortShuffleWriter[K, V, C](
val partitionLengths = sorter.writePartitionedFile(blockId, context, outputFile)
shuffleBlockManager.writeIndexFile(dep.shuffleId, mapId, partitionLengths)
- mapStatus = MapStatus(blockManager.blockManagerId, partitionLengths)
+ mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
}
/** Close this writer, passing along whether the map completed */
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
index 8df5ec6bde..1f012941c8 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
@@ -53,6 +53,8 @@ case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockId {
def name = "rdd_" + rddId + "_" + splitIndex
}
+// Format of the shuffle block ids (including data and index) should be kept in sync with
+// org.apache.spark.network.shuffle.StandaloneShuffleBlockManager#getBlockData().
@DeveloperApi
case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId {
def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 58510d7232..1f8de28961 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -21,9 +21,9 @@ import java.io.{BufferedOutputStream, ByteArrayOutputStream, File, InputStream,
import java.nio.{ByteBuffer, MappedByteBuffer}
import scala.collection.mutable.{ArrayBuffer, HashMap}
+import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._
-import scala.concurrent.{Await, Future}
import scala.util.Random
import akka.actor.{ActorSystem, Props}
@@ -34,8 +34,13 @@ import org.apache.spark.executor._
import org.apache.spark.io.CompressionCodec
import org.apache.spark.network._
import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
+import org.apache.spark.network.netty.{SparkTransportConf, NettyBlockTransferService}
+import org.apache.spark.network.shuffle.{ExecutorShuffleInfo, ExternalShuffleClient}
+import org.apache.spark.network.util.{ConfigProvider, TransportConf}
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.ShuffleManager
+import org.apache.spark.shuffle.hash.HashShuffleManager
+import org.apache.spark.shuffle.sort.SortShuffleManager
import org.apache.spark.util._
private[spark] sealed trait BlockValues
@@ -85,9 +90,38 @@ private[spark] class BlockManager(
new TachyonStore(this, tachyonBlockManager)
}
+ private[spark]
+ val externalShuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false)
+ private val externalShuffleServicePort = conf.getInt("spark.shuffle.service.port", 7337)
+ // Check that we're not using external shuffle service with consolidated shuffle files.
+ if (externalShuffleServiceEnabled
+ && conf.getBoolean("spark.shuffle.consolidateFiles", false)
+ && shuffleManager.isInstanceOf[HashShuffleManager]) {
+ throw new UnsupportedOperationException("Cannot use external shuffle service with consolidated"
+ + " shuffle files in hash-based shuffle. Please disable spark.shuffle.consolidateFiles or "
+ + " switch to sort-based shuffle.")
+ }
+
val blockManagerId = BlockManagerId(
executorId, blockTransferService.hostName, blockTransferService.port)
+ // Address of the server that serves this executor's shuffle files. This is either an external
+ // service, or just our own Executor's BlockManager.
+ private[spark] val shuffleServerId = if (externalShuffleServiceEnabled) {
+ BlockManagerId(executorId, blockTransferService.hostName, externalShuffleServicePort)
+ } else {
+ blockManagerId
+ }
+
+ // Client to read other executors' shuffle files. This is either an external service, or just the
+ // standard BlockTranserService to directly connect to other Executors.
+ private[spark] val shuffleClient = if (externalShuffleServiceEnabled) {
+ val appId = conf.get("spark.app.id", "unknown-app-id")
+ new ExternalShuffleClient(SparkTransportConf.fromSparkConf(conf), appId)
+ } else {
+ blockTransferService
+ }
+
// Whether to compress broadcast variables that are stored
private val compressBroadcast = conf.getBoolean("spark.broadcast.compress", true)
// Whether to compress shuffle output that are stored
@@ -143,10 +177,41 @@ private[spark] class BlockManager(
/**
* Initialize the BlockManager. Register to the BlockManagerMaster, and start the
- * BlockManagerWorker actor.
+ * BlockManagerWorker actor. Additionally registers with a local shuffle service if configured.
*/
private def initialize(): Unit = {
master.registerBlockManager(blockManagerId, maxMemory, slaveActor)
+
+ // Register Executors' configuration with the local shuffle service, if one should exist.
+ if (externalShuffleServiceEnabled && !blockManagerId.isDriver) {
+ registerWithExternalShuffleServer()
+ }
+ }
+
+ private def registerWithExternalShuffleServer() {
+ logInfo("Registering executor with local external shuffle service.")
+ val shuffleConfig = new ExecutorShuffleInfo(
+ diskBlockManager.localDirs.map(_.toString),
+ diskBlockManager.subDirsPerLocalDir,
+ shuffleManager.getClass.getName)
+
+ val MAX_ATTEMPTS = 3
+ val SLEEP_TIME_SECS = 5
+
+ for (i <- 1 to MAX_ATTEMPTS) {
+ try {
+ // Synchronous and will throw an exception if we cannot connect.
+ shuffleClient.asInstanceOf[ExternalShuffleClient].registerWithShuffleServer(
+ shuffleServerId.host, shuffleServerId.port, shuffleServerId.executorId, shuffleConfig)
+ return
+ } catch {
+ case e: Exception if i < MAX_ATTEMPTS =>
+ logError(s"Failed to connect to external shuffle server, will retry ${MAX_ATTEMPTS - i}"
+ + s" more times after waiting $SLEEP_TIME_SECS seconds...", e)
+ Thread.sleep(SLEEP_TIME_SECS * 1000)
+ }
+ }
}
/**
@@ -506,7 +571,7 @@ private[spark] class BlockManager(
for (loc <- locations) {
logDebug(s"Getting remote block $blockId from $loc")
val data = blockTransferService.fetchBlockSync(
- loc.host, loc.port, blockId.toString).nioByteBuffer()
+ loc.host, loc.port, loc.executorId, blockId.toString).nioByteBuffer()
if (data != null) {
if (asBlockResult) {
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
index 99e925328a..58fba54710 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
@@ -38,12 +38,13 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon
extends Logging {
private val MAX_DIR_CREATION_ATTEMPTS: Int = 10
- private val subDirsPerLocalDir = blockManager.conf.getInt("spark.diskStore.subDirectories", 64)
+ private[spark]
+ val subDirsPerLocalDir = blockManager.conf.getInt("spark.diskStore.subDirectories", 64)
/* Create one local directory for each path mentioned in spark.local.dir; then, inside this
* directory, create multiple subdirectories that we will hash files into, in order to avoid
* having really large inodes at the top level. */
- val localDirs: Array[File] = createLocalDirs(conf)
+ private[spark] val localDirs: Array[File] = createLocalDirs(conf)
if (localDirs.isEmpty) {
logError("Failed to create any local dir.")
System.exit(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR)
@@ -52,6 +53,9 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon
addShutdownHook()
+ /** Looks up a file by hashing it into one of our local subdirectories. */
+ // This method should be kept in sync with
+ // org.apache.spark.network.shuffle.StandaloneShuffleBlockManager#getFile().
def getFile(filename: String): File = {
// Figure out which local directory it hashes to, and which subdirectory in that
val hash = Utils.nonNegativeHash(filename)
@@ -159,13 +163,16 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon
/** Cleanup local dirs and stop shuffle sender. */
private[spark] def stop() {
- localDirs.foreach { localDir =>
- if (localDir.isDirectory() && localDir.exists()) {
- try {
- if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir)
- } catch {
- case e: Exception =>
- logError(s"Exception while deleting local spark dir: $localDir", e)
+ // Only perform cleanup if an external service is not serving our shuffle files.
+ if (!blockManager.externalShuffleServiceEnabled) {
+ localDirs.foreach { localDir =>
+ if (localDir.isDirectory() && localDir.exists()) {
+ try {
+ if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir)
+ } catch {
+ case e: Exception =>
+ logError(s"Exception while deleting local spark dir: $localDir", e)
+ }
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index 0d6f3bf003..ee89c7e521 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -22,7 +22,8 @@ import java.util.concurrent.LinkedBlockingQueue
import scala.collection.mutable.{ArrayBuffer, HashSet, Queue}
import org.apache.spark.{Logging, TaskContext}
-import org.apache.spark.network.{BlockFetchingListener, BlockTransferService}
+import org.apache.spark.network.BlockTransferService
+import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.serializer.Serializer
import org.apache.spark.util.{CompletionIterator, Utils}
@@ -38,8 +39,8 @@ import org.apache.spark.util.{CompletionIterator, Utils}
* using too much memory.
*
* @param context [[TaskContext]], used for metrics update
- * @param blockTransferService [[BlockTransferService]] for fetching remote blocks
- * @param blockManager [[BlockManager]] for reading local blocks
+ * @param shuffleClient [[ShuffleClient]] for fetching remote blocks
+ * @param blockManager [[BlockManager]] for reading local blocks
* @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]].
* For each block we also require the size (in bytes as a long field) in
* order to throttle the memory usage.
@@ -49,7 +50,7 @@ import org.apache.spark.util.{CompletionIterator, Utils}
private[spark]
final class ShuffleBlockFetcherIterator(
context: TaskContext,
- blockTransferService: BlockTransferService,
+ shuffleClient: ShuffleClient,
blockManager: BlockManager,
blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
serializer: Serializer,
@@ -140,7 +141,8 @@ final class ShuffleBlockFetcherIterator(
val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap
val blockIds = req.blocks.map(_._1.toString)
- blockTransferService.fetchBlocks(req.address.host, req.address.port, blockIds,
+ val address = req.address
+ shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
new BlockFetchingListener {
override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
// Only add the buffer to results queue if the iterator is not zombie,
@@ -179,7 +181,7 @@ final class ShuffleBlockFetcherIterator(
var totalBlocks = 0
for ((address, blockInfos) <- blocksByAddress) {
totalBlocks += blockInfos.size
- if (address == blockManager.blockManagerId) {
+ if (address.executorId == blockManager.blockManagerId.executorId) {
// Filter out zero-sized blocks
localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1)
numBlocksToFetch += localBlocks.size
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 063895d3c5..68d378f3a2 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -1237,6 +1237,8 @@ private[spark] object Utils extends Logging {
}
// Handles idiosyncracies with hash (add more as required)
+ // This method should be kept in sync with
+ // org.apache.spark.network.util.JavaUtils#nonNegativeHash().
def nonNegativeHash(obj: AnyRef): Int = {
// Required ?
diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
index 81b64c36dd..429199f207 100644
--- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala
+++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
@@ -202,7 +202,8 @@ class DistributedSuite extends FunSuite with Matchers with BeforeAndAfter
val blockManager = SparkEnv.get.blockManager
val blockTransfer = SparkEnv.get.blockTransferService
blockManager.master.getLocations(blockId).foreach { cmId =>
- val bytes = blockTransfer.fetchBlockSync(cmId.host, cmId.port, blockId.toString)
+ val bytes = blockTransfer.fetchBlockSync(cmId.host, cmId.port, cmId.executorId,
+ blockId.toString)
val deserialized = blockManager.dataDeserialize(blockId, bytes.nioByteBuffer())
.asInstanceOf[Iterator[Int]].toList
assert(deserialized === (1 to 100).toList)
diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala
new file mode 100644
index 0000000000..792b9cd8b6
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.util.concurrent.atomic.AtomicInteger
+
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark.SparkContext._
+import org.apache.spark.network.TransportContext
+import org.apache.spark.network.netty.SparkTransportConf
+import org.apache.spark.network.server.TransportServer
+import org.apache.spark.network.shuffle.{ExternalShuffleBlockHandler, ExternalShuffleClient}
+
+/**
+ * This suite creates an external shuffle server and routes all shuffle fetches through it.
+ * Note that failures in this suite may arise due to changes in Spark that invalidate expectations
+ * set up in [[ExternalShuffleBlockHandler]], such as changing the format of shuffle files or how
+ * we hash files into folders.
+ */
+class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll {
+ var server: TransportServer = _
+ var rpcHandler: ExternalShuffleBlockHandler = _
+
+ override def beforeAll() {
+ val transportConf = SparkTransportConf.fromSparkConf(conf)
+ rpcHandler = new ExternalShuffleBlockHandler()
+ val transportContext = new TransportContext(transportConf, rpcHandler)
+ server = transportContext.createServer()
+
+ conf.set("spark.shuffle.manager", "sort")
+ conf.set("spark.shuffle.service.enabled", "true")
+ conf.set("spark.shuffle.service.port", server.getPort.toString)
+ }
+
+ override def afterAll() {
+ server.close()
+ }
+
+ // This test ensures that the external shuffle service is actually in use for the other tests.
+ test("using external shuffle service") {
+ sc = new SparkContext("local-cluster[2,1,512]", "test", conf)
+ sc.env.blockManager.externalShuffleServiceEnabled should equal(true)
+ sc.env.blockManager.shuffleClient.getClass should equal(classOf[ExternalShuffleClient])
+
+ val rdd = sc.parallelize(0 until 1000, 10).map(i => (i, 1)).reduceByKey(_ + _)
+
+ rdd.count()
+ rdd.count()
+
+ // Invalidate the registered executors, disallowing access to their shuffle blocks.
+ rpcHandler.clearRegisteredExecutors()
+
+ // Now Spark will receive FetchFailed, and not retry the stage due to "spark.test.noStageRetry"
+ // being set.
+ val e = intercept[SparkException] {
+ rdd.count()
+ }
+ e.getMessage should include ("Fetch failure will not retry stage due to testing config")
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/HashShuffleSuite.scala b/core/src/test/scala/org/apache/spark/HashShuffleSuite.scala
index 2acc02a54f..19180e88eb 100644
--- a/core/src/test/scala/org/apache/spark/HashShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/HashShuffleSuite.scala
@@ -24,10 +24,6 @@ class HashShuffleSuite extends ShuffleSuite with BeforeAndAfterAll {
// This test suite should run all tests in ShuffleSuite with hash-based shuffle.
override def beforeAll() {
- System.setProperty("spark.shuffle.manager", "hash")
- }
-
- override def afterAll() {
- System.clearProperty("spark.shuffle.manager")
+ conf.set("spark.shuffle.manager", "hash")
}
}
diff --git a/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala b/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala
index 840d8273cb..d78c99c2e1 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala
@@ -24,10 +24,6 @@ class ShuffleNettySuite extends ShuffleSuite with BeforeAndAfterAll {
// This test suite should run all tests in ShuffleSuite with Netty shuffle mode.
override def beforeAll() {
- System.setProperty("spark.shuffle.blockTransferService", "netty")
- }
-
- override def afterAll() {
- System.clearProperty("spark.shuffle.blockTransferService")
+ conf.set("spark.shuffle.blockTransferService", "netty")
}
}
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index 2bdd84ce69..cda942e15a 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -30,10 +30,14 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex
val conf = new SparkConf(loadDefaults = false)
+ // Ensure that the DAGScheduler doesn't retry stages whose fetches fail, so that we accurately
+ // test that the shuffle works (rather than retrying until all blocks are local to one Executor).
+ conf.set("spark.test.noStageRetry", "true")
+
test("groupByKey without compression") {
try {
System.setProperty("spark.shuffle.compress", "false")
- sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test", conf)
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)), 4)
val groups = pairs.groupByKey(4).collect()
assert(groups.size === 2)
@@ -47,7 +51,7 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex
}
test("shuffle non-zero block size") {
- sc = new SparkContext("local-cluster[2,1,512]", "test")
+ sc = new SparkContext("local-cluster[2,1,512]", "test", conf)
val NUM_BLOCKS = 3
val a = sc.parallelize(1 to 10, 2)
@@ -73,7 +77,7 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex
test("shuffle serializer") {
// Use a local cluster with 2 processes to make sure there are both local and remote blocks
- sc = new SparkContext("local-cluster[2,1,512]", "test")
+ sc = new SparkContext("local-cluster[2,1,512]", "test", conf)
val a = sc.parallelize(1 to 10, 2)
val b = a.map { x =>
(x, new NonJavaSerializableClass(x * 2))
@@ -89,7 +93,7 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex
test("zero sized blocks") {
// Use a local cluster with 2 processes to make sure there are both local and remote blocks
- sc = new SparkContext("local-cluster[2,1,512]", "test")
+ sc = new SparkContext("local-cluster[2,1,512]", "test", conf)
// 10 partitions from 4 keys
val NUM_BLOCKS = 10
@@ -116,7 +120,7 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex
test("zero sized blocks without kryo") {
// Use a local cluster with 2 processes to make sure there are both local and remote blocks
- sc = new SparkContext("local-cluster[2,1,512]", "test")
+ sc = new SparkContext("local-cluster[2,1,512]", "test", conf)
// 10 partitions from 4 keys
val NUM_BLOCKS = 10
@@ -141,7 +145,7 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex
test("shuffle on mutable pairs") {
// Use a local cluster with 2 processes to make sure there are both local and remote blocks
- sc = new SparkContext("local-cluster[2,1,512]", "test")
+ sc = new SparkContext("local-cluster[2,1,512]", "test", conf)
def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2)
val data = Array(p(1, 1), p(1, 2), p(1, 3), p(2, 1))
val pairs: RDD[MutablePair[Int, Int]] = sc.parallelize(data, 2)
@@ -154,7 +158,7 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex
test("sorting on mutable pairs") {
// This is not in SortingSuite because of the local cluster setup.
// Use a local cluster with 2 processes to make sure there are both local and remote blocks
- sc = new SparkContext("local-cluster[2,1,512]", "test")
+ sc = new SparkContext("local-cluster[2,1,512]", "test", conf)
def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2)
val data = Array(p(1, 11), p(3, 33), p(100, 100), p(2, 22))
val pairs: RDD[MutablePair[Int, Int]] = sc.parallelize(data, 2)
@@ -168,7 +172,7 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex
test("cogroup using mutable pairs") {
// Use a local cluster with 2 processes to make sure there are both local and remote blocks
- sc = new SparkContext("local-cluster[2,1,512]", "test")
+ sc = new SparkContext("local-cluster[2,1,512]", "test", conf)
def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2)
val data1 = Seq(p(1, 1), p(1, 2), p(1, 3), p(2, 1))
val data2 = Seq(p(1, "11"), p(1, "12"), p(2, "22"), p(3, "3"))
@@ -195,7 +199,7 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex
test("subtract mutable pairs") {
// Use a local cluster with 2 processes to make sure there are both local and remote blocks
- sc = new SparkContext("local-cluster[2,1,512]", "test")
+ sc = new SparkContext("local-cluster[2,1,512]", "test", conf)
def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2)
val data1 = Seq(p(1, 1), p(1, 2), p(1, 3), p(2, 1), p(3, 33))
val data2 = Seq(p(1, "11"), p(1, "12"), p(2, "22"))
@@ -209,11 +213,8 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex
test("sort with Java non serializable class - Kryo") {
// Use a local cluster with 2 processes to make sure there are both local and remote blocks
- val conf = new SparkConf()
- .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
- .setAppName("test")
- .setMaster("local-cluster[2,1,512]")
- sc = new SparkContext(conf)
+ val myConf = conf.clone().set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
+ sc = new SparkContext("local-cluster[2,1,512]", "test", myConf)
val a = sc.parallelize(1 to 10, 2)
val b = a.map { x =>
(new NonJavaSerializableClass(x), x)
@@ -226,10 +227,7 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex
test("sort with Java non serializable class - Java") {
// Use a local cluster with 2 processes to make sure there are both local and remote blocks
- val conf = new SparkConf()
- .setAppName("test")
- .setMaster("local-cluster[2,1,512]")
- sc = new SparkContext(conf)
+ sc = new SparkContext("local-cluster[2,1,512]", "test", conf)
val a = sc.parallelize(1 to 10, 2)
val b = a.map { x =>
(new NonJavaSerializableClass(x), x)
diff --git a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala
index 639e56c488..63358172ea 100644
--- a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala
@@ -24,10 +24,6 @@ class SortShuffleSuite extends ShuffleSuite with BeforeAndAfterAll {
// This test suite should run all tests in ShuffleSuite with sort-based shuffle.
override def beforeAll() {
- System.setProperty("spark.shuffle.manager", "sort")
- }
-
- override def afterAll() {
- System.clearProperty("spark.shuffle.manager")
+ conf.set("spark.shuffle.manager", "sort")
}
}
diff --git a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala
index 3925f0ccbd..bbdc9568a6 100644
--- a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala
+++ b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala
@@ -121,7 +121,7 @@ class MetricsSystemSuite extends FunSuite with BeforeAndAfter with PrivateMethod
}
val appId = "testId"
- val executorId = "executor.1"
+ val executorId = "1"
conf.set("spark.app.id", appId)
conf.set("spark.executor.id", executorId)
@@ -138,7 +138,7 @@ class MetricsSystemSuite extends FunSuite with BeforeAndAfter with PrivateMethod
override val metricRegistry = new MetricRegistry()
}
- val executorId = "executor.1"
+ val executorId = "1"
conf.set("spark.executor.id", executorId)
val instanceName = "executor"
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index 4e502cf65e..28f766570e 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -21,22 +21,19 @@ import java.util.concurrent.Semaphore
import scala.concurrent.future
import scala.concurrent.ExecutionContext.Implicits.global
-import org.apache.spark.{TaskContextImpl, TaskContext}
-import org.apache.spark.network.{BlockFetchingListener, BlockTransferService}
-import org.mockito.Mockito._
import org.mockito.Matchers.{any, eq => meq}
+import org.mockito.Mockito._
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
-
import org.scalatest.FunSuite
-import org.apache.spark.{SparkConf, TaskContext}
+import org.apache.spark.{SparkConf, TaskContextImpl}
import org.apache.spark.network._
import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.shuffle.BlockFetchingListener
import org.apache.spark.serializer.TestSerializer
-
class ShuffleBlockFetcherIteratorSuite extends FunSuite {
// Some of the tests are quite tricky because we are testing the cleanup behavior
// in the presence of faults.
@@ -44,10 +41,10 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite {
/** Creates a mock [[BlockTransferService]] that returns data from the given map. */
private def createMockTransfer(data: Map[BlockId, ManagedBuffer]): BlockTransferService = {
val transfer = mock(classOf[BlockTransferService])
- when(transfer.fetchBlocks(any(), any(), any(), any())).thenAnswer(new Answer[Unit] {
+ when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] {
override def answer(invocation: InvocationOnMock): Unit = {
- val blocks = invocation.getArguments()(2).asInstanceOf[Seq[String]]
- val listener = invocation.getArguments()(3).asInstanceOf[BlockFetchingListener]
+ val blocks = invocation.getArguments()(3).asInstanceOf[Array[String]]
+ val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
for (blockId <- blocks) {
if (data.contains(BlockId(blockId))) {
@@ -118,7 +115,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite {
// 3 local blocks, and 2 remote blocks
// (but from the same block manager so one call to fetchBlocks)
verify(blockManager, times(3)).getBlockData(any())
- verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any())
+ verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any())
}
test("release current unexhausted buffer in case the task completes early") {
@@ -138,9 +135,9 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite {
val sem = new Semaphore(0)
val transfer = mock(classOf[BlockTransferService])
- when(transfer.fetchBlocks(any(), any(), any(), any())).thenAnswer(new Answer[Unit] {
+ when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] {
override def answer(invocation: InvocationOnMock): Unit = {
- val listener = invocation.getArguments()(3).asInstanceOf[BlockFetchingListener]
+ val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
future {
// Return the first two blocks, and wait till task completion before returning the 3rd one
listener.onBlockFetchSuccess(
@@ -201,9 +198,9 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite {
val sem = new Semaphore(0)
val transfer = mock(classOf[BlockTransferService])
- when(transfer.fetchBlocks(any(), any(), any(), any())).thenAnswer(new Answer[Unit] {
+ when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] {
override def answer(invocation: InvocationOnMock): Unit = {
- val listener = invocation.getArguments()(3).asInstanceOf[BlockFetchingListener]
+ val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
future {
// Return the first block, and then fail.
listener.onBlockFetchSuccess(
diff --git a/network/common/pom.xml b/network/common/pom.xml
index a33e44b63d..ea887148d9 100644
--- a/network/common/pom.xml
+++ b/network/common/pom.xml
@@ -85,9 +85,25 @@
<outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
<testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
<plugins>
+ <!-- Create a test-jar so network-shuffle can depend on our test utilities. -->
<plugin>
- <groupId>org.scalatest</groupId>
- <artifactId>scalatest-maven-plugin</artifactId>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-jar-plugin</artifactId>
+ <version>2.2</version>
+ <executions>
+ <execution>
+ <goals>
+ <goal>test-jar</goal>
+ </goals>
+ </execution>
+ <execution>
+ <id>test-jar-on-test-compile</id>
+ <phase>test-compile</phase>
+ <goals>
+ <goal>test-jar</goal>
+ </goals>
+ </execution>
+ </executions>
</plugin>
</plugins>
</build>
diff --git a/network/common/src/main/java/org/apache/spark/network/TransportContext.java b/network/common/src/main/java/org/apache/spark/network/TransportContext.java
index 854aa6685f..a271841e4e 100644
--- a/network/common/src/main/java/org/apache/spark/network/TransportContext.java
+++ b/network/common/src/main/java/org/apache/spark/network/TransportContext.java
@@ -52,15 +52,13 @@ public class TransportContext {
private final Logger logger = LoggerFactory.getLogger(TransportContext.class);
private final TransportConf conf;
- private final StreamManager streamManager;
private final RpcHandler rpcHandler;
private final MessageEncoder encoder;
private final MessageDecoder decoder;
- public TransportContext(TransportConf conf, StreamManager streamManager, RpcHandler rpcHandler) {
+ public TransportContext(TransportConf conf, RpcHandler rpcHandler) {
this.conf = conf;
- this.streamManager = streamManager;
this.rpcHandler = rpcHandler;
this.encoder = new MessageEncoder();
this.decoder = new MessageDecoder();
@@ -70,8 +68,14 @@ public class TransportContext {
return new TransportClientFactory(this);
}
+ /** Creates a server which will attempt to bind to a specific port. */
+ public TransportServer createServer(int port) {
+ return new TransportServer(this, port);
+ }
+
+ /** Creates a new server, binding to any available ephemeral port. */
public TransportServer createServer() {
- return new TransportServer(this);
+ return new TransportServer(this, 0);
}
/**
@@ -109,7 +113,7 @@ public class TransportContext {
TransportResponseHandler responseHandler = new TransportResponseHandler(channel);
TransportClient client = new TransportClient(channel, responseHandler);
TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client,
- streamManager, rpcHandler);
+ rpcHandler);
return new TransportChannelHandler(client, responseHandler, requestHandler);
}
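
Since the StreamManager now comes from the RpcHandler rather than the TransportContext constructor, one handler object carries both the RPC and streaming responsibilities. A hedged sketch (the echo behavior and fixed port are illustrative only; conf is a TransportConf as in the tests below):

    RpcHandler echoHandler = new RpcHandler() {
      private final StreamManager streamManager = new OneForOneStreamManager();

      @Override
      public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
        callback.onSuccess(message); // echo the payload back to the caller
      }

      @Override
      public StreamManager getStreamManager() { return streamManager; }
    };
    // Bind a specific port, or call createServer() for an ephemeral one.
    TransportServer server = new TransportContext(conf, echoHandler).createServer(12345);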
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
index b1732fcde2..01c143fff4 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
@@ -19,9 +19,13 @@ package org.apache.spark.network.client;
import java.io.Closeable;
import java.util.UUID;
+import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
import com.google.common.base.Preconditions;
+import com.google.common.base.Throwables;
+import com.google.common.util.concurrent.SettableFuture;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
@@ -129,7 +133,7 @@ public class TransportClient implements Closeable {
final long startTime = System.currentTimeMillis();
logger.trace("Sending RPC to {}", serverAddr);
- final long requestId = UUID.randomUUID().getLeastSignificantBits();
+ final long requestId = Math.abs(UUID.randomUUID().getLeastSignificantBits());
handler.addRpcRequest(requestId, callback);
channel.writeAndFlush(new RpcRequest(requestId, message)).addListener(
@@ -151,6 +155,32 @@ public class TransportClient implements Closeable {
});
}
+ /**
+ * Synchronously sends an opaque message to the RpcHandler on the server-side, waiting for up to
+ * a specified timeout for a response.
+ */
+ public byte[] sendRpcSync(byte[] message, long timeoutMs) {
+ final SettableFuture<byte[]> result = SettableFuture.create();
+
+ sendRpc(message, new RpcResponseCallback() {
+ @Override
+ public void onSuccess(byte[] response) {
+ result.set(response);
+ }
+
+ @Override
+ public void onFailure(Throwable e) {
+ result.setException(e);
+ }
+ });
+
+ try {
+ return result.get(timeoutMs, TimeUnit.MILLISECONDS);
+ } catch (Exception e) {
+ throw Throwables.propagate(e);
+ }
+ }
+
@Override
public void close() {
// close is a local operation and should finish with milliseconds; timeout just to be safe
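
For callers, sendRpcSync turns the callback-based sendRpc into a blocking call. A usage sketch, assuming clientFactory came from TransportContext#createClientFactory() and that the server's handler echoes the serialized payload (JavaUtils.serialize/deserialize are added later in this patch):

    TransportClient client = clientFactory.createClient("localhost", serverPort);
    byte[] response = client.sendRpcSync(JavaUtils.serialize("ping"), 5000 /* timeoutMs */);
    String reply = JavaUtils.deserialize(response); // blocks at most five seconds end to end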
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
index 10eb9ef7a0..e7fa4f6bf3 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
@@ -78,15 +78,17 @@ public class TransportClientFactory implements Closeable {
*
* Concurrency: This method is safe to call from multiple threads.
*/
- public TransportClient createClient(String remoteHost, int remotePort) throws TimeoutException {
+ public TransportClient createClient(String remoteHost, int remotePort) {
// Get connection from the connection pool first.
// If it is not found or not active, create a new one.
final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort);
TransportClient cachedClient = connectionPool.get(address);
- if (cachedClient != null && cachedClient.isActive()) {
- return cachedClient;
- } else if (cachedClient != null) {
- connectionPool.remove(address, cachedClient); // Remove inactive clients.
+ if (cachedClient != null) {
+ if (cachedClient.isActive()) {
+ return cachedClient;
+ } else {
+ connectionPool.remove(address, cachedClient); // Remove inactive clients.
+ }
}
logger.debug("Creating new connection to " + address);
@@ -115,13 +117,14 @@ public class TransportClientFactory implements Closeable {
// Connect to the remote server
ChannelFuture cf = bootstrap.connect(address);
if (!cf.awaitUninterruptibly(conf.connectionTimeoutMs())) {
- throw new TimeoutException(
+ throw new RuntimeException(
String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs()));
} else if (cf.cause() != null) {
throw new RuntimeException(String.format("Failed to connect to %s", address), cf.cause());
}
- // Successful connection
+ // Successful connection -- in the event that two threads raced to create a client, we will
+ // use the first one that was put into the connectionPool and close the one we made here.
assert client.get() != null : "Channel future completed successfully with null client";
TransportClient oldClient = connectionPool.putIfAbsent(address, client.get());
if (oldClient == null) {
diff --git a/network/common/src/test/java/org/apache/spark/network/NoOpRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java
index 7aa37efc58..5a3f003726 100644
--- a/network/common/src/test/java/org/apache/spark/network/NoOpRpcHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java
@@ -1,4 +1,6 @@
-package org.apache.spark.network;/*
+package org.apache.spark.network.server;
+
+/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
@@ -17,12 +19,20 @@ package org.apache.spark.network;/*
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
-import org.apache.spark.network.server.RpcHandler;
-/** Test RpcHandler which always returns a zero-sized success. */
+/** An RpcHandler suitable for a client-only TransportContext, which cannot receive RPCs. */
public class NoOpRpcHandler implements RpcHandler {
+ private final StreamManager streamManager;
+
+ public NoOpRpcHandler() {
+ streamManager = new OneForOneStreamManager();
+ }
+
@Override
public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
- callback.onSuccess(new byte[0]);
+ throw new UnsupportedOperationException("Cannot handle messages");
}
+
+ @Override
+ public StreamManager getStreamManager() { return streamManager; }
}
diff --git a/network/common/src/main/java/org/apache/spark/network/server/DefaultStreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
index 9688705569..731d48d4d9 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/DefaultStreamManager.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
@@ -30,10 +30,10 @@ import org.apache.spark.network.buffer.ManagedBuffer;
/**
* StreamManager which allows registration of an Iterator<ManagedBuffer>, which are individually
- * fetched as chunks by the client.
+ * fetched as chunks by the client. Each registered buffer is one chunk.
*/
-public class DefaultStreamManager extends StreamManager {
- private final Logger logger = LoggerFactory.getLogger(DefaultStreamManager.class);
+public class OneForOneStreamManager extends StreamManager {
+ private final Logger logger = LoggerFactory.getLogger(OneForOneStreamManager.class);
private final AtomicLong nextStreamId;
private final Map<Long, StreamState> streams;
@@ -51,7 +51,7 @@ public class DefaultStreamManager extends StreamManager {
}
}
- public DefaultStreamManager() {
+ public OneForOneStreamManager() {
// For debugging purposes, start with a random stream id to help identify different streams.
// This does not need to be globally unique, only unique to this class.
nextStreamId = new AtomicLong((long) new Random().nextInt(Integer.MAX_VALUE) * 1000);
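
The "one-for-one" naming is literal: each buffer registered with this manager is served as exactly one transport-level chunk. A sketch, where buf0..buf2 are hypothetical ManagedBuffers:

    OneForOneStreamManager streams = new OneForOneStreamManager();
    long streamId = streams.registerStream(Arrays.asList(buf0, buf1, buf2).iterator());
    // Clients may now fetch chunk indices 0, 1, and 2 of streamId, one buffer apiece.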
diff --git a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java
index f54a696b8f..2369dc6203 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java
@@ -35,4 +35,10 @@ public interface RpcHandler {
* RPC.
*/
void receive(TransportClient client, byte[] message, RpcResponseCallback callback);
+
+ /**
+ * Returns the StreamManager which contains the state about which streams are currently being
+ * fetched by a TransportClient.
+ */
+ StreamManager getStreamManager();
}
diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
index 352f865935..17fe9001b3 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
@@ -56,24 +56,23 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
/** Client on the same channel allowing us to talk back to the requester. */
private final TransportClient reverseClient;
- /** Returns each chunk part of a stream. */
- private final StreamManager streamManager;
-
/** Handles all RPC messages. */
private final RpcHandler rpcHandler;
+ /** Returns each chunk that makes up a stream. */
+ private final StreamManager streamManager;
+
/** List of all stream ids that have been read on this handler, used for cleanup. */
private final Set<Long> streamIds;
public TransportRequestHandler(
Channel channel,
TransportClient reverseClient,
- StreamManager streamManager,
RpcHandler rpcHandler) {
this.channel = channel;
this.reverseClient = reverseClient;
- this.streamManager = streamManager;
this.rpcHandler = rpcHandler;
+ this.streamManager = rpcHandler.getStreamManager();
this.streamIds = Sets.newHashSet();
}
diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java
index 243070750d..d1a1877a98 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java
@@ -49,11 +49,11 @@ public class TransportServer implements Closeable {
private ChannelFuture channelFuture;
private int port = -1;
- public TransportServer(TransportContext context) {
+ public TransportServer(TransportContext context, int portToBind) {
this.context = context;
this.conf = context.getConf();
- init();
+ init(portToBind);
}
public int getPort() {
@@ -63,7 +63,7 @@ public class TransportServer implements Closeable {
return port;
}
- private void init() {
+ private void init(int portToBind) {
IOMode ioMode = IOMode.valueOf(conf.ioMode());
EventLoopGroup bossGroup =
@@ -95,7 +95,7 @@ public class TransportServer implements Closeable {
}
});
- channelFuture = bootstrap.bind(new InetSocketAddress(conf.serverPort()));
+ channelFuture = bootstrap.bind(new InetSocketAddress(portToBind));
channelFuture.syncUninterruptibly();
port = ((InetSocketAddress) channelFuture.channel().localAddress()).getPort();
diff --git a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java
index 32ba3f5b07..40b71b0c87 100644
--- a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java
+++ b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java
@@ -17,8 +17,12 @@
package org.apache.spark.network.util;
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
import java.io.Closeable;
import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
import com.google.common.io.Closeables;
import org.slf4j.Logger;
@@ -35,4 +39,38 @@ public class JavaUtils {
logger.error("IOException should not have been thrown.", e);
}
}
+
+ // TODO: Make this configurable, do not use Java serialization!
+ public static <T> T deserialize(byte[] bytes) {
+ try {
+ ObjectInputStream is = new ObjectInputStream(new ByteArrayInputStream(bytes));
+ Object out = is.readObject();
+ is.close();
+ return (T) out;
+ } catch (ClassNotFoundException e) {
+ throw new RuntimeException("Could not deserialize object", e);
+ } catch (IOException e) {
+ throw new RuntimeException("Could not deserialize object", e);
+ }
+ }
+
+ // TODO: Make this configurable, do not use Java serialization!
+ public static byte[] serialize(Object object) {
+ try {
+ ByteArrayOutputStream baos = new ByteArrayOutputStream();
+ ObjectOutputStream os = new ObjectOutputStream(baos);
+ os.writeObject(object);
+ os.close();
+ return baos.toByteArray();
+ } catch (IOException e) {
+ throw new RuntimeException("Could not serialize object", e);
+ }
+ }
+
+ /** Returns a hash consistent with Spark's Utils.nonNegativeHash(). */
+ public static int nonNegativeHash(Object obj) {
+ if (obj == null) { return 0; }
+ int hash = obj.hashCode();
+ return hash != Integer.MIN_VALUE ? Math.abs(hash) : 0;
+ }
}
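
A round-trip sketch of the helpers above (the TODOs rightly flag Java serialization as a stopgap that is not cross-version compatible):

    byte[] bytes = JavaUtils.serialize("hello");
    String roundTripped = JavaUtils.deserialize(bytes); // type is inferred at the call site
    assert "hello".equals(roundTripped);
    // The hash helper buckets filenames the same way Spark's DiskBlockManager does:
    int bucket = JavaUtils.nonNegativeHash("shuffle_0_0_0") % 64;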
diff --git a/network/common/src/test/java/org/apache/spark/network/SystemPropertyConfigProvider.java b/network/common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java
index f4e0a2426a..5f20b70678 100644
--- a/network/common/src/test/java/org/apache/spark/network/SystemPropertyConfigProvider.java
+++ b/network/common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.network;
+package org.apache.spark.network.util;
import java.util.NoSuchElementException;
diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java
index 80f65d9803..a68f38e0e9 100644
--- a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java
+++ b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java
@@ -27,9 +27,6 @@ public class TransportConf {
this.conf = conf;
}
- /** Port the server listens on. Default to a random port. */
- public int serverPort() { return conf.getInt("spark.shuffle.io.port", 0); }
-
/** IO mode: nio or epoll */
public String ioMode() { return conf.get("spark.shuffle.io.mode", "NIO").toUpperCase(); }
diff --git a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java
index 738dca9b6a..c415883397 100644
--- a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java
@@ -41,10 +41,13 @@ import org.apache.spark.network.buffer.FileSegmentManagedBuffer;
import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.buffer.NioManagedBuffer;
import org.apache.spark.network.client.ChunkReceivedCallback;
+import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientFactory;
+import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.TransportServer;
import org.apache.spark.network.server.StreamManager;
+import org.apache.spark.network.util.SystemPropertyConfigProvider;
import org.apache.spark.network.util.TransportConf;
public class ChunkFetchIntegrationSuite {
@@ -93,7 +96,18 @@ public class ChunkFetchIntegrationSuite {
}
}
};
- TransportContext context = new TransportContext(conf, streamManager, new NoOpRpcHandler());
+ RpcHandler handler = new RpcHandler() {
+ @Override
+ public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public StreamManager getStreamManager() {
+ return streamManager;
+ }
+ };
+ TransportContext context = new TransportContext(conf, handler);
server = context.createServer();
clientFactory = context.createClientFactory();
}
diff --git a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java
index 9f216dd2d7..64b457b4b3 100644
--- a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java
@@ -35,9 +35,11 @@ import static org.junit.Assert.*;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientFactory;
-import org.apache.spark.network.server.DefaultStreamManager;
+import org.apache.spark.network.server.OneForOneStreamManager;
import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.StreamManager;
import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.util.SystemPropertyConfigProvider;
import org.apache.spark.network.util.TransportConf;
public class RpcIntegrationSuite {
@@ -61,8 +63,11 @@ public class RpcIntegrationSuite {
throw new RuntimeException("Thrown: " + parts[1]);
}
}
+
+ @Override
+ public StreamManager getStreamManager() { return new OneForOneStreamManager(); }
};
- TransportContext context = new TransportContext(conf, new DefaultStreamManager(), rpcHandler);
+ TransportContext context = new TransportContext(conf, rpcHandler);
server = context.createServer();
clientFactory = context.createClientFactory();
}
diff --git a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java
index 3ef964616f..5a10fdb384 100644
--- a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java
@@ -28,11 +28,11 @@ import static org.junit.Assert.assertTrue;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientFactory;
-import org.apache.spark.network.server.DefaultStreamManager;
+import org.apache.spark.network.server.NoOpRpcHandler;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.TransportServer;
-import org.apache.spark.network.server.StreamManager;
import org.apache.spark.network.util.JavaUtils;
+import org.apache.spark.network.util.SystemPropertyConfigProvider;
import org.apache.spark.network.util.TransportConf;
public class TransportClientFactorySuite {
@@ -44,9 +44,8 @@ public class TransportClientFactorySuite {
@Before
public void setUp() {
conf = new TransportConf(new SystemPropertyConfigProvider());
- StreamManager streamManager = new DefaultStreamManager();
RpcHandler rpcHandler = new NoOpRpcHandler();
- context = new TransportContext(conf, streamManager, rpcHandler);
+ context = new TransportContext(conf, rpcHandler);
server1 = context.createServer();
server2 = context.createServer();
}
diff --git a/network/shuffle/pom.xml b/network/shuffle/pom.xml
new file mode 100644
index 0000000000..d271704d98
--- /dev/null
+++ b/network/shuffle/pom.xml
@@ -0,0 +1,96 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+ ~ Licensed to the Apache Software Foundation (ASF) under one or more
+ ~ contributor license agreements. See the NOTICE file distributed with
+ ~ this work for additional information regarding copyright ownership.
+ ~ The ASF licenses this file to You under the Apache License, Version 2.0
+ ~ (the "License"); you may not use this file except in compliance with
+ ~ the License. You may obtain a copy of the License at
+ ~
+ ~ http://www.apache.org/licenses/LICENSE-2.0
+ ~
+ ~ Unless required by applicable law or agreed to in writing, software
+ ~ distributed under the License is distributed on an "AS IS" BASIS,
+ ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ ~ See the License for the specific language governing permissions and
+ ~ limitations under the License.
+ -->
+
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+ xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+ <modelVersion>4.0.0</modelVersion>
+ <parent>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-parent</artifactId>
+ <version>1.2.0-SNAPSHOT</version>
+ <relativePath>../../pom.xml</relativePath>
+ </parent>
+
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-network-shuffle_2.10</artifactId>
+ <packaging>jar</packaging>
+  <name>Spark Project Shuffle Streaming Service</name>
+ <url>http://spark.apache.org/</url>
+ <properties>
+ <sbt.project.name>network-shuffle</sbt.project.name>
+ </properties>
+
+ <dependencies>
+ <!-- Core dependencies -->
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-network-common_2.10</artifactId>
+ <version>${project.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.slf4j</groupId>
+ <artifactId>slf4j-api</artifactId>
+ </dependency>
+
+ <!-- Provided dependencies -->
+ <dependency>
+ <groupId>com.google.guava</groupId>
+ <artifactId>guava</artifactId>
+ <scope>provided</scope>
+ </dependency>
+
+ <!-- Test dependencies -->
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-network-common_2.10</artifactId>
+ <version>${project.version}</version>
+ <type>test-jar</type>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>junit</groupId>
+ <artifactId>junit</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>com.novocode</groupId>
+ <artifactId>junit-interface</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>log4j</groupId>
+ <artifactId>log4j</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.mockito</groupId>
+ <artifactId>mockito-all</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.scalatest</groupId>
+ <artifactId>scalatest_${scala.binary.version}</artifactId>
+ <scope>test</scope>
+ </dependency>
+ </dependencies>
+
+ <build>
+ <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
+ <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
+ </build>
+</project>
diff --git a/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java
index 645793fde8..138fd5389c 100644
--- a/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java
@@ -15,28 +15,22 @@
* limitations under the License.
*/
-package org.apache.spark.network
+package org.apache.spark.network.shuffle;
-import java.util.EventListener
+import java.util.EventListener;
-import org.apache.spark.network.buffer.ManagedBuffer
-
-
-/**
- * Listener callback interface for [[BlockTransferService.fetchBlocks]].
- */
-private[spark]
-trait BlockFetchingListener extends EventListener {
+import org.apache.spark.network.buffer.ManagedBuffer;
+public interface BlockFetchingListener extends EventListener {
/**
* Called once per successfully fetched block. After this call returns, data will be released
* automatically. If the data will be passed to another thread, the receiver should retain()
* and release() the buffer on their own, or copy the data to a new buffer.
*/
- def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit
+ void onBlockFetchSuccess(String blockId, ManagedBuffer data);
/**
* Called at least once per block upon failures.
*/
- def onBlockFetchFailure(blockId: String, exception: Throwable): Unit
+ void onBlockFetchFailure(String blockId, Throwable exception);
}
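
The retain()/release() contract described above matters whenever the buffer outlives the callback. A hedged sketch, where executor is a hypothetical worker pool:

    BlockFetchingListener listener = new BlockFetchingListener() {
      @Override
      public void onBlockFetchSuccess(String blockId, final ManagedBuffer data) {
        data.retain(); // keep the buffer alive beyond this callback
        executor.submit(new Runnable() {
          @Override
          public void run() {
            try { /* consume data */ } finally { data.release(); }
          }
        });
      }

      @Override
      public void onBlockFetchFailure(String blockId, Throwable exception) {
        // may be invoked more than once per block; log and let the caller's failure path react
      }
    };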
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExecutorShuffleInfo.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExecutorShuffleInfo.java
new file mode 100644
index 0000000000..d45e64656a
--- /dev/null
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExecutorShuffleInfo.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import java.io.Serializable;
+import java.util.Arrays;
+
+import com.google.common.base.Objects;
+
+/** Contains all configuration necessary for locating the shuffle files of an executor. */
+public class ExecutorShuffleInfo implements Serializable {
+ /** The base set of local directories that the executor stores its shuffle files in. */
+ final String[] localDirs;
+ /** Number of subdirectories created within each localDir. */
+ final int subDirsPerLocalDir;
+ /** Shuffle manager (SortShuffleManager or HashShuffleManager) that the executor is using. */
+ final String shuffleManager;
+
+ public ExecutorShuffleInfo(String[] localDirs, int subDirsPerLocalDir, String shuffleManager) {
+ this.localDirs = localDirs;
+ this.subDirsPerLocalDir = subDirsPerLocalDir;
+ this.shuffleManager = shuffleManager;
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(subDirsPerLocalDir, shuffleManager) * 41 + Arrays.hashCode(localDirs);
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("localDirs", Arrays.toString(localDirs))
+ .add("subDirsPerLocalDir", subDirsPerLocalDir)
+ .add("shuffleManager", shuffleManager)
+ .toString();
+ }
+
+ @Override
+ public boolean equals(Object other) {
+    if (other instanceof ExecutorShuffleInfo) {
+ ExecutorShuffleInfo o = (ExecutorShuffleInfo) other;
+ return Arrays.equals(localDirs, o.localDirs)
+ && Objects.equal(subDirsPerLocalDir, o.subDirsPerLocalDir)
+ && Objects.equal(shuffleManager, o.shuffleManager);
+ }
+ return false;
+ }
+}
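
A hypothetical registration payload; the two directories and 64 subdirectories mirror common Spark defaults but are assumptions here:

    ExecutorShuffleInfo info = new ExecutorShuffleInfo(
      new String[] { "/tmp/spark-local-1", "/tmp/spark-local-2" }, // executor's local dirs
      64,                                                          // subdirs per local dir
      "org.apache.spark.shuffle.sort.SortShuffleManager");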
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
new file mode 100644
index 0000000000..a9dff31dec
--- /dev/null
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import java.util.List;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.Lists;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import static org.apache.spark.network.shuffle.ExternalShuffleMessages.*;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.server.OneForOneStreamManager;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.StreamManager;
+import org.apache.spark.network.util.JavaUtils;
+
+/**
+ * RPC Handler for a server which can serve shuffle blocks from outside of an Executor process.
+ *
+ * Handles registering executors and opening shuffle blocks from them. Shuffle blocks are served
+ * with the "one-for-one" strategy, meaning each Transport-layer Chunk is equivalent to one Spark-
+ * level shuffle block.
+ */
+public class ExternalShuffleBlockHandler implements RpcHandler {
+ private final Logger logger = LoggerFactory.getLogger(ExternalShuffleBlockHandler.class);
+
+ private final ExternalShuffleBlockManager blockManager;
+ private final OneForOneStreamManager streamManager;
+
+ public ExternalShuffleBlockHandler() {
+ this(new OneForOneStreamManager(), new ExternalShuffleBlockManager());
+ }
+
+ /** Enables mocking out the StreamManager and BlockManager. */
+ @VisibleForTesting
+ ExternalShuffleBlockHandler(
+ OneForOneStreamManager streamManager,
+ ExternalShuffleBlockManager blockManager) {
+ this.streamManager = streamManager;
+ this.blockManager = blockManager;
+ }
+
+ @Override
+ public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
+ Object msgObj = JavaUtils.deserialize(message);
+
+ logger.trace("Received message: " + msgObj);
+
+ if (msgObj instanceof OpenShuffleBlocks) {
+ OpenShuffleBlocks msg = (OpenShuffleBlocks) msgObj;
+ List<ManagedBuffer> blocks = Lists.newArrayList();
+
+ for (String blockId : msg.blockIds) {
+ blocks.add(blockManager.getBlockData(msg.appId, msg.execId, blockId));
+ }
+ long streamId = streamManager.registerStream(blocks.iterator());
+ logger.trace("Registered streamId {} with {} buffers", streamId, msg.blockIds.length);
+ callback.onSuccess(JavaUtils.serialize(
+ new ShuffleStreamHandle(streamId, msg.blockIds.length)));
+
+ } else if (msgObj instanceof RegisterExecutor) {
+ RegisterExecutor msg = (RegisterExecutor) msgObj;
+ blockManager.registerExecutor(msg.appId, msg.execId, msg.executorInfo);
+ callback.onSuccess(new byte[0]);
+
+ } else {
+ throw new UnsupportedOperationException(String.format(
+ "Unexpected message: %s (class = %s)", msgObj, msgObj.getClass()));
+ }
+ }
+
+ @Override
+ public StreamManager getStreamManager() {
+ return streamManager;
+ }
+
+ /** For testing, clears all executors registered with "RegisterExecutor". */
+ @VisibleForTesting
+ public void clearRegisteredExecutors() {
+ blockManager.clearRegisteredExecutors();
+ }
+}
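
From a client's perspective, the receive() dispatch above implies a two-step protocol: an OpenShuffleBlocks RPC returns a serialized ShuffleStreamHandle (defined elsewhere in this patch), whose stream id then drives ordinary chunk fetches. A sketch of step one (OneForOneBlockFetcher below automates both steps):

    byte[] response = client.sendRpcSync(
      JavaUtils.serialize(new ExternalShuffleMessages.OpenShuffleBlocks(
        "app-1", "exec-1", new String[] { "shuffle_0_0_0" })),
      5000 /* timeoutMs */);
    ShuffleStreamHandle handle = JavaUtils.deserialize(response);
    // The handle identifies a registered stream with one chunk per requested block.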
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java
new file mode 100644
index 0000000000..6589889fe1
--- /dev/null
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import java.io.DataInputStream;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.util.concurrent.ConcurrentHashMap;
+
+import com.google.common.annotations.VisibleForTesting;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.buffer.FileSegmentManagedBuffer;
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.util.JavaUtils;
+
+/**
+ * Manages converting shuffle BlockIds into physical segments of local files, from a process outside
+ * of Executors. Each Executor must register its own configuration about where it stores its files
+ * (local dirs) and how (shuffle manager). The logic for retrieval of individual files is replicated
+ * from Spark's FileShuffleBlockManager and IndexShuffleBlockManager.
+ *
+ * Executors with shuffle file consolidation are not currently supported, as the consolidation
+ * index is held only in the Executor's memory (unlike the IndexShuffleBlockManager, which
+ * persists its index to disk where this service can read it).
+ */
+public class ExternalShuffleBlockManager {
+ private final Logger logger = LoggerFactory.getLogger(ExternalShuffleBlockManager.class);
+
+ // Map from "appId-execId" to the executor's configuration.
+ private final ConcurrentHashMap<String, ExecutorShuffleInfo> executors =
+ new ConcurrentHashMap<String, ExecutorShuffleInfo>();
+
+ // Returns an id suitable for a single executor within a single application.
+ private String getAppExecId(String appId, String execId) {
+ return appId + "-" + execId;
+ }
+
+ /** Registers a new Executor with all the configuration we need to find its shuffle files. */
+ public void registerExecutor(
+ String appId,
+ String execId,
+ ExecutorShuffleInfo executorInfo) {
+ String fullId = getAppExecId(appId, execId);
+ logger.info("Registered executor {} with {}", fullId, executorInfo);
+ executors.put(fullId, executorInfo);
+ }
+
+ /**
+ * Obtains a FileSegmentManagedBuffer from a shuffle block id. We expect the blockId to have the
+ * format "shuffle_ShuffleId_MapId_ReduceId" (from ShuffleBlockId), and additionally make
+ * assumptions about how the hash and sort based shuffles store their data.
+ */
+ public ManagedBuffer getBlockData(String appId, String execId, String blockId) {
+ String[] blockIdParts = blockId.split("_");
+ if (blockIdParts.length < 4) {
+ throw new IllegalArgumentException("Unexpected block id format: " + blockId);
+ } else if (!blockIdParts[0].equals("shuffle")) {
+ throw new IllegalArgumentException("Expected shuffle block id, got: " + blockId);
+ }
+ int shuffleId = Integer.parseInt(blockIdParts[1]);
+ int mapId = Integer.parseInt(blockIdParts[2]);
+ int reduceId = Integer.parseInt(blockIdParts[3]);
+
+ ExecutorShuffleInfo executor = executors.get(getAppExecId(appId, execId));
+ if (executor == null) {
+ throw new RuntimeException(
+ String.format("Executor is not registered (appId=%s, execId=%s)", appId, execId));
+ }
+
+ if ("org.apache.spark.shuffle.hash.HashShuffleManager".equals(executor.shuffleManager)) {
+ return getHashBasedShuffleBlockData(executor, blockId);
+ } else if ("org.apache.spark.shuffle.sort.SortShuffleManager".equals(executor.shuffleManager)) {
+ return getSortBasedShuffleBlockData(executor, shuffleId, mapId, reduceId);
+ } else {
+ throw new UnsupportedOperationException(
+ "Unsupported shuffle manager: " + executor.shuffleManager);
+ }
+ }
+
+ /**
+ * Hash-based shuffle data is simply stored as one file per block.
+ * This logic is from FileShuffleBlockManager.
+ */
+ // TODO: Support consolidated hash shuffle files
+ private ManagedBuffer getHashBasedShuffleBlockData(ExecutorShuffleInfo executor, String blockId) {
+ File shuffleFile = getFile(executor.localDirs, executor.subDirsPerLocalDir, blockId);
+ return new FileSegmentManagedBuffer(shuffleFile, 0, shuffleFile.length());
+ }
+
+ /**
+ * Sort-based shuffle data uses an index file named "shuffle_ShuffleId_MapId_0.index" that points
+ * into a data file named "shuffle_ShuffleId_MapId_0.data". This logic is from
+ * IndexShuffleBlockManager, and the block id format is from ShuffleDataBlockId and
+ * ShuffleIndexBlockId.
+ */
+ private ManagedBuffer getSortBasedShuffleBlockData(
+ ExecutorShuffleInfo executor, int shuffleId, int mapId, int reduceId) {
+ File indexFile = getFile(executor.localDirs, executor.subDirsPerLocalDir,
+ "shuffle_" + shuffleId + "_" + mapId + "_0.index");
+
+ DataInputStream in = null;
+ try {
+ in = new DataInputStream(new FileInputStream(indexFile));
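+ // The index file stores one 8-byte long offset per reducer: entry r begins at
+ // byte r * 8, and block r spans [offset, nextOffset) within the .data file.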
+ in.skipBytes(reduceId * 8);
+ long offset = in.readLong();
+ long nextOffset = in.readLong();
+ return new FileSegmentManagedBuffer(
+ getFile(executor.localDirs, executor.subDirsPerLocalDir,
+ "shuffle_" + shuffleId + "_" + mapId + "_0.data"),
+ offset,
+ nextOffset - offset);
+ } catch (IOException e) {
+ throw new RuntimeException("Failed to open file: " + indexFile, e);
+ } finally {
+ if (in != null) {
+ JavaUtils.closeQuietly(in);
+ }
+ }
+ }
+
+ /**
+ * Hashes a filename into the corresponding local directory, in a manner consistent with
+ * Spark's DiskBlockManager.getFile().
+ */
+ @VisibleForTesting
+ static File getFile(String[] localDirs, int subDirsPerLocalDir, String filename) {
+ int hash = JavaUtils.nonNegativeHash(filename);
+ String localDir = localDirs[hash % localDirs.length];
+ int subDirId = (hash / localDirs.length) % subDirsPerLocalDir;
+ return new File(new File(localDir, String.format("%02x", subDirId)), filename);
+ }
+
+ /** For testing, clears all registered executors. */
+ @VisibleForTesting
+ void clearRegisteredExecutors() {
+ executors.clear();
+ }
+}
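
A worked example of the hashing in getFile(), with illustrative values (the actual subdirectory depends on the filename's hash):

    // With two local dirs and 64 subdirs per dir, "shuffle_0_0_0" hashes to
    // h = JavaUtils.nonNegativeHash("shuffle_0_0_0"); the file then lands in
    // localDirs[h % 2], subdirectory (h / 2) % 64, matching DiskBlockManager.getFile().
    File f = ExternalShuffleBlockManager.getFile(
      new String[] { "/d1", "/d2" }, 64, "shuffle_0_0_0");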
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
new file mode 100644
index 0000000000..cc2f6261ca
--- /dev/null
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.TransportContext;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientFactory;
+import org.apache.spark.network.server.NoOpRpcHandler;
+import org.apache.spark.network.shuffle.ExternalShuffleMessages.RegisterExecutor;
+import org.apache.spark.network.util.JavaUtils;
+import org.apache.spark.network.util.TransportConf;
+
+/**
+ * Client for reading shuffle blocks from an external (outside-of-executor) server, rather than
+ * reading them directly from other executors (via BlockTransferService). The direct approach
+ * has the downside of losing the shuffle data if we lose the executors.
+ */
+public class ExternalShuffleClient implements ShuffleClient {
+ private final Logger logger = LoggerFactory.getLogger(ExternalShuffleClient.class);
+
+ private final TransportClientFactory clientFactory;
+ private final String appId;
+
+ public ExternalShuffleClient(TransportConf conf, String appId) {
+ TransportContext context = new TransportContext(conf, new NoOpRpcHandler());
+ this.clientFactory = context.createClientFactory();
+ this.appId = appId;
+ }
+
+ @Override
+ public void fetchBlocks(
+ String host,
+ int port,
+ String execId,
+ String[] blockIds,
+ BlockFetchingListener listener) {
+ logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId);
+ try {
+ TransportClient client = clientFactory.createClient(host, port);
+ new OneForOneBlockFetcher(client, blockIds, listener)
+ .start(new ExternalShuffleMessages.OpenShuffleBlocks(appId, execId, blockIds));
+ } catch (Exception e) {
+ logger.error("Exception while beginning fetchBlocks", e);
+ for (String blockId : blockIds) {
+ listener.onBlockFetchFailure(blockId, e);
+ }
+ }
+ }
+
+ /**
+ * Registers this executor with an external shuffle server. This registration is required to
+ * inform the shuffle server about where and how we store our shuffle files.
+ *
+ * @param host Host of shuffle server.
+ * @param port Port of shuffle server.
+ * @param execId This Executor's id.
+ * @param executorInfo Contains all info necessary for the service to find our shuffle files.
+ */
+ public void registerWithShuffleServer(
+ String host,
+ int port,
+ String execId,
+ ExecutorShuffleInfo executorInfo) {
+ TransportClient client = clientFactory.createClient(host, port);
+ byte[] registerExecutorMessage =
+ JavaUtils.serialize(new RegisterExecutor(appId, execId, executorInfo));
+ client.sendRpcSync(registerExecutorMessage, 5000 /* timeoutMs */);
+ }
+}
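
Putting the two public methods together: an executor registers once, after which any reader can fetch its blocks through the external server. A sketch, assuming host/port locate a running shuffle service and that info and listener are as sketched earlier:

    ExternalShuffleClient shuffleClient = new ExternalShuffleClient(transportConf, "app-1");
    shuffleClient.registerWithShuffleServer(host, port, "exec-1", info);
    shuffleClient.fetchBlocks(host, port, "exec-1",
      new String[] { "shuffle_0_0_0" }, listener);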
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleMessages.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleMessages.java
new file mode 100644
index 0000000000..e79420ed82
--- /dev/null
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleMessages.java
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import java.io.Serializable;
+import java.util.Arrays;
+
+import com.google.common.base.Objects;
+
+/** Messages handled by the {@link ExternalShuffleBlockHandler}. */
+public class ExternalShuffleMessages {
+
+  /** Request to read a set of shuffle blocks. Returns a {@link ShuffleStreamHandle}. */
+ public static class OpenShuffleBlocks implements Serializable {
+ public final String appId;
+ public final String execId;
+ public final String[] blockIds;
+
+ public OpenShuffleBlocks(String appId, String execId, String[] blockIds) {
+ this.appId = appId;
+ this.execId = execId;
+ this.blockIds = blockIds;
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(appId, execId) * 41 + Arrays.hashCode(blockIds);
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("appId", appId)
+ .add("execId", execId)
+ .add("blockIds", Arrays.toString(blockIds))
+ .toString();
+ }
+
+ @Override
+ public boolean equals(Object other) {
+      if (other instanceof OpenShuffleBlocks) {
+ OpenShuffleBlocks o = (OpenShuffleBlocks) other;
+ return Objects.equal(appId, o.appId)
+ && Objects.equal(execId, o.execId)
+ && Arrays.equals(blockIds, o.blockIds);
+ }
+ return false;
+ }
+ }
+
+ /** Initial registration message between an executor and its local shuffle server. */
+ public static class RegisterExecutor implements Serializable {
+ public final String appId;
+ public final String execId;
+ public final ExecutorShuffleInfo executorInfo;
+
+ public RegisterExecutor(
+ String appId,
+ String execId,
+ ExecutorShuffleInfo executorInfo) {
+ this.appId = appId;
+ this.execId = execId;
+ this.executorInfo = executorInfo;
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(appId, execId, executorInfo);
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("appId", appId)
+ .add("execId", execId)
+ .add("executorInfo", executorInfo)
+ .toString();
+ }
+
+ @Override
+ public boolean equals(Object other) {
+      if (other instanceof RegisterExecutor) {
+ RegisterExecutor o = (RegisterExecutor) other;
+ return Objects.equal(appId, o.appId)
+ && Objects.equal(execId, o.execId)
+ && Objects.equal(executorInfo, o.executorInfo);
+ }
+ return false;
+ }
+ }
+}
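
These messages travel over the RPC layer via Java serialization, so they only need to be Serializable; equals and hashCode let a deserialized copy compare equal to the original, which the tests below rely on. A minimal round-trip sketch (the class name is illustrative):

    import org.apache.spark.network.shuffle.ExternalShuffleMessages.OpenShuffleBlocks;
    import org.apache.spark.network.util.JavaUtils;

    public class MessageRoundTripExample {
      public static void main(String[] args) {
        OpenShuffleBlocks msg =
          new OpenShuffleBlocks("app-1", "exec-0", new String[] { "shuffle_0_0_0" });
        byte[] wire = JavaUtils.serialize(msg);     // bytes as sent over the RPC channel
        OpenShuffleBlocks decoded = JavaUtils.deserialize(wire);
        System.out.println(msg.equals(decoded));    // true: the round trip preserves equality
      }
    }
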
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
new file mode 100644
index 0000000000..39b6f30f92
--- /dev/null
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import java.util.Arrays;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.client.ChunkReceivedCallback;
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.util.JavaUtils;
+
+/**
+ * Simple wrapper on top of a TransportClient which interprets each chunk as a whole block, and
+ * invokes the BlockFetchingListener appropriately. This class is agnostic to the actual RPC
+ * handler, as long as there is a single "open blocks" message which returns a ShuffleStreamHandle,
+ * and Java serialization is used.
+ *
+ * Note that this typically corresponds to a
+ * {@link org.apache.spark.network.server.OneForOneStreamManager} on the server side.
+ */
+public class OneForOneBlockFetcher {
+ private final Logger logger = LoggerFactory.getLogger(OneForOneBlockFetcher.class);
+
+ private final TransportClient client;
+ private final String[] blockIds;
+ private final BlockFetchingListener listener;
+ private final ChunkReceivedCallback chunkCallback;
+
+ private ShuffleStreamHandle streamHandle = null;
+
+ public OneForOneBlockFetcher(
+ TransportClient client,
+ String[] blockIds,
+ BlockFetchingListener listener) {
+ if (blockIds.length == 0) {
+ throw new IllegalArgumentException("Zero-sized blockIds array");
+ }
+ this.client = client;
+ this.blockIds = blockIds;
+ this.listener = listener;
+ this.chunkCallback = new ChunkCallback();
+ }
+
+ /** Callback invoked on receipt of each chunk. We equate a single chunk to a single block. */
+ private class ChunkCallback implements ChunkReceivedCallback {
+ @Override
+ public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
+ // On receipt of a chunk, pass it upwards as a block.
+ listener.onBlockFetchSuccess(blockIds[chunkIndex], buffer);
+ }
+
+ @Override
+ public void onFailure(int chunkIndex, Throwable e) {
+ // On receipt of a failure, fail every block from chunkIndex onwards.
+ String[] remainingBlockIds = Arrays.copyOfRange(blockIds, chunkIndex, blockIds.length);
+ failRemainingBlocks(remainingBlockIds, e);
+ }
+ }
+
+ /**
+ * Begins the fetching process, calling the listener with every block fetched.
+ * The given message will be serialized with the Java serializer, and the RPC must return a
+ * {@link ShuffleStreamHandle}. We will send all fetch requests immediately, without throttling.
+ */
+ public void start(Object openBlocksMessage) {
+ client.sendRpc(JavaUtils.serialize(openBlocksMessage), new RpcResponseCallback() {
+ @Override
+ public void onSuccess(byte[] response) {
+ try {
+ streamHandle = JavaUtils.deserialize(response);
+ logger.trace("Successfully opened blocks {}, preparing to fetch chunks.", streamHandle);
+
+ // Immediately request all chunks -- we expect that the total size of the request is
+          // reasonable due to higher-level chunking in ShuffleBlockFetcherIterator.
+ for (int i = 0; i < streamHandle.numChunks; i++) {
+ client.fetchChunk(streamHandle.streamId, i, chunkCallback);
+ }
+ } catch (Exception e) {
+ logger.error("Failed while starting block fetches", e);
+ failRemainingBlocks(blockIds, e);
+ }
+ }
+
+ @Override
+ public void onFailure(Throwable e) {
+ logger.error("Failed while starting block fetches", e);
+ failRemainingBlocks(blockIds, e);
+ }
+ });
+ }
+
+ /** Invokes the "onBlockFetchFailure" callback for every listed block id. */
+ private void failRemainingBlocks(String[] failedBlockIds, Throwable e) {
+ for (String blockId : failedBlockIds) {
+ try {
+ listener.onBlockFetchFailure(blockId, e);
+ } catch (Exception e2) {
+ logger.error("Error in block fetch failure callback", e2);
+ }
+ }
+ }
+}
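
Because the fetcher is agnostic to the RPC handler, any server that answers the single "open blocks" message with a serialized ShuffleStreamHandle can serve it. Below is a sketch of such a handler, under the assumptions that RpcHandler is the interface used elsewhere in this patch and that OneForOneStreamManager.registerStream returns the new stream id; MyOpenBlocks and lookUpBlocks are hypothetical stand-ins for application-defined pieces.

    package org.apache.spark.network.shuffle;

    import java.io.Serializable;
    import java.util.List;

    import org.apache.spark.network.buffer.ManagedBuffer;
    import org.apache.spark.network.client.RpcResponseCallback;
    import org.apache.spark.network.client.TransportClient;
    import org.apache.spark.network.server.OneForOneStreamManager;
    import org.apache.spark.network.server.RpcHandler;
    import org.apache.spark.network.server.StreamManager;
    import org.apache.spark.network.util.JavaUtils;

    public class MyOpenBlocksHandler implements RpcHandler {
      /** Hypothetical application-defined "open blocks" message. */
      public static class MyOpenBlocks implements Serializable {
        public final String[] blockIds;
        public MyOpenBlocks(String[] blockIds) { this.blockIds = blockIds; }
      }

      private final OneForOneStreamManager streamManager = new OneForOneStreamManager();

      @Override
      public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
        MyOpenBlocks open = JavaUtils.deserialize(message);
        List<ManagedBuffer> blocks = lookUpBlocks(open.blockIds);  // hypothetical block lookup
        long streamId = streamManager.registerStream(blocks.iterator());
        // One chunk per block: the fetcher maps chunk index i back to blockIds[i].
        callback.onSuccess(JavaUtils.serialize(new ShuffleStreamHandle(streamId, blocks.size())));
      }

      @Override
      public StreamManager getStreamManager() { return streamManager; }

      private List<ManagedBuffer> lookUpBlocks(String[] blockIds) {
        throw new UnsupportedOperationException("application-specific lookup");
      }
    }
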
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java
new file mode 100644
index 0000000000..9fa87c2c6e
--- /dev/null
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+/** Provides an interface for reading shuffle files, either from an Executor or an external service. */
+public interface ShuffleClient {
+ /**
+   * Fetch a sequence of blocks from a remote node asynchronously.
+ *
+ * Note that this API takes a sequence so the implementation can batch requests, and does not
+ * return a future so the underlying implementation can invoke onBlockFetchSuccess as soon as
+ * the data of a block is fetched, rather than waiting for all blocks to be fetched.
+ */
+ public void fetchBlocks(
+ String host,
+ int port,
+ String execId,
+ String[] blockIds,
+ BlockFetchingListener listener);
+}
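
Since fetchBlocks is fire-and-forget, callers that need synchronous semantics typically bridge it with a latch or semaphore, as the integration tests later in this patch do. A minimal blocking-adapter sketch (the class and method names are illustrative, and the 5-second timeout is arbitrary):

    package org.apache.spark.network.shuffle;

    import java.util.concurrent.Semaphore;
    import java.util.concurrent.TimeUnit;

    import org.apache.spark.network.buffer.ManagedBuffer;

    public class BlockingFetchAdapter {
      /** Starts the fetch and waits until every requested block has succeeded or failed. */
      public static void fetchAndWait(ShuffleClient client, String host, int port,
          String execId, String[] blockIds) throws InterruptedException {
        final Semaphore remaining = new Semaphore(0);
        client.fetchBlocks(host, port, execId, blockIds, new BlockFetchingListener() {
          @Override
          public void onBlockFetchSuccess(String blockId, ManagedBuffer data) {
            data.retain();        // keep the buffer alive beyond this callback
            remaining.release();
          }

          @Override
          public void onBlockFetchFailure(String blockId, Throwable t) {
            remaining.release();
          }
        });
        if (!remaining.tryAcquire(blockIds.length, 5, TimeUnit.SECONDS)) {
          throw new IllegalStateException("Timed out waiting for shuffle blocks");
        }
      }
    }
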
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleStreamHandle.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleStreamHandle.java
new file mode 100644
index 0000000000..9c94691224
--- /dev/null
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleStreamHandle.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import java.io.Serializable;
+import java.util.Arrays;
+
+import com.google.common.base.Objects;
+
+/**
+ * Identifier for a fixed number of chunks to read from a stream created by an "open blocks"
+ * message. This is used by {@link OneForOneBlockFetcher}.
+ */
+public class ShuffleStreamHandle implements Serializable {
+ public final long streamId;
+ public final int numChunks;
+
+ public ShuffleStreamHandle(long streamId, int numChunks) {
+ this.streamId = streamId;
+ this.numChunks = numChunks;
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(streamId, numChunks);
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("streamId", streamId)
+ .add("numChunks", numChunks)
+ .toString();
+ }
+
+ @Override
+ public boolean equals(Object other) {
+    if (other instanceof ShuffleStreamHandle) {
+ ShuffleStreamHandle o = (ShuffleStreamHandle) other;
+ return Objects.equal(streamId, o.streamId)
+ && Objects.equal(numChunks, o.numChunks);
+ }
+ return false;
+ }
+}
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
new file mode 100644
index 0000000000..7939cb4d32
--- /dev/null
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import java.nio.ByteBuffer;
+import java.util.Iterator;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.ArgumentCaptor;
+
+import static org.apache.spark.network.shuffle.ExternalShuffleMessages.OpenShuffleBlocks;
+import static org.apache.spark.network.shuffle.ExternalShuffleMessages.RegisterExecutor;
+import static org.junit.Assert.*;
+import static org.mockito.Matchers.any;
+import static org.mockito.Mockito.*;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NioManagedBuffer;
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.server.OneForOneStreamManager;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.util.JavaUtils;
+
+public class ExternalShuffleBlockHandlerSuite {
+ TransportClient client = mock(TransportClient.class);
+
+ OneForOneStreamManager streamManager;
+ ExternalShuffleBlockManager blockManager;
+ RpcHandler handler;
+
+ @Before
+ public void beforeEach() {
+ streamManager = mock(OneForOneStreamManager.class);
+ blockManager = mock(ExternalShuffleBlockManager.class);
+ handler = new ExternalShuffleBlockHandler(streamManager, blockManager);
+ }
+
+ @Test
+ public void testRegisterExecutor() {
+ RpcResponseCallback callback = mock(RpcResponseCallback.class);
+
+ ExecutorShuffleInfo config = new ExecutorShuffleInfo(new String[] {"/a", "/b"}, 16, "sort");
+ byte[] registerMessage = JavaUtils.serialize(
+ new RegisterExecutor("app0", "exec1", config));
+ handler.receive(client, registerMessage, callback);
+ verify(blockManager, times(1)).registerExecutor("app0", "exec1", config);
+
+ verify(callback, times(1)).onSuccess((byte[]) any());
+ verify(callback, never()).onFailure((Throwable) any());
+ }
+
+ @SuppressWarnings("unchecked")
+ @Test
+ public void testOpenShuffleBlocks() {
+ RpcResponseCallback callback = mock(RpcResponseCallback.class);
+
+ ManagedBuffer block0Marker = new NioManagedBuffer(ByteBuffer.wrap(new byte[3]));
+ ManagedBuffer block1Marker = new NioManagedBuffer(ByteBuffer.wrap(new byte[7]));
+ when(blockManager.getBlockData("app0", "exec1", "b0")).thenReturn(block0Marker);
+ when(blockManager.getBlockData("app0", "exec1", "b1")).thenReturn(block1Marker);
+ byte[] openBlocksMessage = JavaUtils.serialize(
+ new OpenShuffleBlocks("app0", "exec1", new String[] { "b0", "b1" }));
+ handler.receive(client, openBlocksMessage, callback);
+ verify(blockManager, times(1)).getBlockData("app0", "exec1", "b0");
+ verify(blockManager, times(1)).getBlockData("app0", "exec1", "b1");
+
+ ArgumentCaptor<byte[]> response = ArgumentCaptor.forClass(byte[].class);
+ verify(callback, times(1)).onSuccess(response.capture());
+ verify(callback, never()).onFailure((Throwable) any());
+
+ ShuffleStreamHandle handle = JavaUtils.deserialize(response.getValue());
+ assertEquals(2, handle.numChunks);
+
+ ArgumentCaptor<Iterator> stream = ArgumentCaptor.forClass(Iterator.class);
+ verify(streamManager, times(1)).registerStream(stream.capture());
+ Iterator<ManagedBuffer> buffers = (Iterator<ManagedBuffer>) stream.getValue();
+ assertEquals(block0Marker, buffers.next());
+ assertEquals(block1Marker, buffers.next());
+ assertFalse(buffers.hasNext());
+ }
+
+ @Test
+ public void testBadMessages() {
+ RpcResponseCallback callback = mock(RpcResponseCallback.class);
+
+ byte[] unserializableMessage = new byte[] { 0x12, 0x34, 0x56 };
+ try {
+ handler.receive(client, unserializableMessage, callback);
+ fail("Should have thrown");
+ } catch (Exception e) {
+ // pass
+ }
+
+ byte[] unexpectedMessage = JavaUtils.serialize(
+ new ExecutorShuffleInfo(new String[] {"/a", "/b"}, 16, "sort"));
+ try {
+ handler.receive(client, unexpectedMessage, callback);
+ fail("Should have thrown");
+ } catch (UnsupportedOperationException e) {
+ // pass
+ }
+
+ verify(callback, never()).onSuccess((byte[]) any());
+ verify(callback, never()).onFailure((Throwable) any());
+ }
+}
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManagerSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManagerSuite.java
new file mode 100644
index 0000000000..da54797e89
--- /dev/null
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManagerSuite.java
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+
+import com.google.common.io.CharStreams;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import static org.junit.Assert.*;
+
+public class ExternalShuffleBlockManagerSuite {
+ static String sortBlock0 = "Hello!";
+ static String sortBlock1 = "World!";
+
+ static String hashBlock0 = "Elementary";
+ static String hashBlock1 = "Tabular";
+
+ static TestShuffleDataContext dataContext;
+
+ @BeforeClass
+ public static void beforeAll() throws IOException {
+ dataContext = new TestShuffleDataContext(2, 5);
+
+ dataContext.create();
+ // Write some sort and hash data.
+ dataContext.insertSortShuffleData(0, 0,
+ new byte[][] { sortBlock0.getBytes(), sortBlock1.getBytes() } );
+ dataContext.insertHashShuffleData(1, 0,
+ new byte[][] { hashBlock0.getBytes(), hashBlock1.getBytes() } );
+ }
+
+ @AfterClass
+ public static void afterAll() {
+ dataContext.cleanup();
+ }
+
+ @Test
+ public void testBadRequests() {
+ ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager();
+ // Unregistered executor
+ try {
+ manager.getBlockData("app0", "exec1", "shuffle_1_1_0");
+ fail("Should have failed");
+ } catch (RuntimeException e) {
+ assertTrue("Bad error message: " + e, e.getMessage().contains("not registered"));
+ }
+
+ // Invalid shuffle manager
+ manager.registerExecutor("app0", "exec2", dataContext.createExecutorInfo("foobar"));
+ try {
+ manager.getBlockData("app0", "exec2", "shuffle_1_1_0");
+ fail("Should have failed");
+ } catch (UnsupportedOperationException e) {
+ // pass
+ }
+
+ // Nonexistent shuffle block
+ manager.registerExecutor("app0", "exec3",
+ dataContext.createExecutorInfo("org.apache.spark.shuffle.sort.SortShuffleManager"));
+ try {
+ manager.getBlockData("app0", "exec3", "shuffle_1_1_0");
+ fail("Should have failed");
+ } catch (Exception e) {
+ // pass
+ }
+ }
+
+ @Test
+ public void testSortShuffleBlocks() throws IOException {
+ ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager();
+ manager.registerExecutor("app0", "exec0",
+ dataContext.createExecutorInfo("org.apache.spark.shuffle.sort.SortShuffleManager"));
+
+ InputStream block0Stream =
+ manager.getBlockData("app0", "exec0", "shuffle_0_0_0").createInputStream();
+ String block0 = CharStreams.toString(new InputStreamReader(block0Stream));
+ block0Stream.close();
+ assertEquals(sortBlock0, block0);
+
+ InputStream block1Stream =
+ manager.getBlockData("app0", "exec0", "shuffle_0_0_1").createInputStream();
+ String block1 = CharStreams.toString(new InputStreamReader(block1Stream));
+ block1Stream.close();
+ assertEquals(sortBlock1, block1);
+ }
+
+ @Test
+ public void testHashShuffleBlocks() throws IOException {
+ ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager();
+ manager.registerExecutor("app0", "exec0",
+ dataContext.createExecutorInfo("org.apache.spark.shuffle.hash.HashShuffleManager"));
+
+ InputStream block0Stream =
+ manager.getBlockData("app0", "exec0", "shuffle_1_0_0").createInputStream();
+ String block0 = CharStreams.toString(new InputStreamReader(block0Stream));
+ block0Stream.close();
+ assertEquals(hashBlock0, block0);
+
+ InputStream block1Stream =
+ manager.getBlockData("app0", "exec0", "shuffle_1_0_1").createInputStream();
+ String block1 = CharStreams.toString(new InputStreamReader(block1Stream));
+ block1Stream.close();
+ assertEquals(hashBlock1, block1);
+ }
+}
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
new file mode 100644
index 0000000000..b3bcf5fd68
--- /dev/null
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
@@ -0,0 +1,291 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Random;
+import java.util.Set;
+import java.util.concurrent.Semaphore;
+import java.util.concurrent.TimeUnit;
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
+import org.junit.After;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import static org.junit.Assert.*;
+
+import org.apache.spark.network.TestUtils;
+import org.apache.spark.network.TransportContext;
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NioManagedBuffer;
+import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.util.SystemPropertyConfigProvider;
+import org.apache.spark.network.util.TransportConf;
+
+public class ExternalShuffleIntegrationSuite {
+
+ static String APP_ID = "app-id";
+ static String SORT_MANAGER = "org.apache.spark.shuffle.sort.SortShuffleManager";
+ static String HASH_MANAGER = "org.apache.spark.shuffle.hash.HashShuffleManager";
+
+ // Executor 0 is sort-based
+ static TestShuffleDataContext dataContext0;
+ // Executor 1 is hash-based
+ static TestShuffleDataContext dataContext1;
+
+ static ExternalShuffleBlockHandler handler;
+ static TransportServer server;
+ static TransportConf conf;
+
+ static byte[][] exec0Blocks = new byte[][] {
+ new byte[123],
+ new byte[12345],
+ new byte[1234567],
+ };
+
+ static byte[][] exec1Blocks = new byte[][] {
+ new byte[321],
+ new byte[54321],
+ };
+
+ @BeforeClass
+ public static void beforeAll() throws IOException {
+ Random rand = new Random();
+
+ for (byte[] block : exec0Blocks) {
+ rand.nextBytes(block);
+ }
+    for (byte[] block : exec1Blocks) {
+ rand.nextBytes(block);
+ }
+
+ dataContext0 = new TestShuffleDataContext(2, 5);
+ dataContext0.create();
+ dataContext0.insertSortShuffleData(0, 0, exec0Blocks);
+
+ dataContext1 = new TestShuffleDataContext(6, 2);
+ dataContext1.create();
+ dataContext1.insertHashShuffleData(1, 0, exec1Blocks);
+
+ conf = new TransportConf(new SystemPropertyConfigProvider());
+ handler = new ExternalShuffleBlockHandler();
+ TransportContext transportContext = new TransportContext(conf, handler);
+ server = transportContext.createServer();
+ }
+
+ @AfterClass
+ public static void afterAll() {
+ dataContext0.cleanup();
+ dataContext1.cleanup();
+ server.close();
+ }
+
+ @After
+ public void afterEach() {
+ handler.clearRegisteredExecutors();
+ }
+
+ class FetchResult {
+ public Set<String> successBlocks;
+ public Set<String> failedBlocks;
+ public List<ManagedBuffer> buffers;
+
+ public void releaseBuffers() {
+ for (ManagedBuffer buffer : buffers) {
+ buffer.release();
+ }
+ }
+ }
+
+ // Fetch a set of blocks from a pre-registered executor.
+ private FetchResult fetchBlocks(String execId, String[] blockIds) throws Exception {
+ return fetchBlocks(execId, blockIds, server.getPort());
+ }
+
+  // Fetch a set of blocks from a pre-registered executor. Connects to the server on the given
+  // port, so tests can deliberately target an invalid server.
+ private FetchResult fetchBlocks(String execId, String[] blockIds, int port) throws Exception {
+ final FetchResult res = new FetchResult();
+ res.successBlocks = Collections.synchronizedSet(new HashSet<String>());
+ res.failedBlocks = Collections.synchronizedSet(new HashSet<String>());
+ res.buffers = Collections.synchronizedList(new LinkedList<ManagedBuffer>());
+
+ final Semaphore requestsRemaining = new Semaphore(0);
+
+ ExternalShuffleClient client = new ExternalShuffleClient(conf, APP_ID);
+ client.fetchBlocks(TestUtils.getLocalHost(), port, execId, blockIds,
+ new BlockFetchingListener() {
+ @Override
+ public void onBlockFetchSuccess(String blockId, ManagedBuffer data) {
+ synchronized (this) {
+ if (!res.successBlocks.contains(blockId) && !res.failedBlocks.contains(blockId)) {
+ data.retain();
+ res.successBlocks.add(blockId);
+ res.buffers.add(data);
+ requestsRemaining.release();
+ }
+ }
+ }
+
+ @Override
+ public void onBlockFetchFailure(String blockId, Throwable exception) {
+ synchronized (this) {
+ if (!res.successBlocks.contains(blockId) && !res.failedBlocks.contains(blockId)) {
+ res.failedBlocks.add(blockId);
+ requestsRemaining.release();
+ }
+ }
+ }
+ });
+
+ if (!requestsRemaining.tryAcquire(blockIds.length, 5, TimeUnit.SECONDS)) {
+ fail("Timeout getting response from the server");
+ }
+ return res;
+ }
+
+ @Test
+ public void testFetchOneSort() throws Exception {
+ registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER));
+ FetchResult exec0Fetch = fetchBlocks("exec-0", new String[] { "shuffle_0_0_0" });
+ assertEquals(Sets.newHashSet("shuffle_0_0_0"), exec0Fetch.successBlocks);
+ assertTrue(exec0Fetch.failedBlocks.isEmpty());
+ assertBufferListsEqual(exec0Fetch.buffers, Lists.newArrayList(exec0Blocks[0]));
+ exec0Fetch.releaseBuffers();
+ }
+
+ @Test
+ public void testFetchThreeSort() throws Exception {
+ registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER));
+ FetchResult exec0Fetch = fetchBlocks("exec-0",
+ new String[] { "shuffle_0_0_0", "shuffle_0_0_1", "shuffle_0_0_2" });
+ assertEquals(Sets.newHashSet("shuffle_0_0_0", "shuffle_0_0_1", "shuffle_0_0_2"),
+ exec0Fetch.successBlocks);
+ assertTrue(exec0Fetch.failedBlocks.isEmpty());
+ assertBufferListsEqual(exec0Fetch.buffers, Lists.newArrayList(exec0Blocks));
+ exec0Fetch.releaseBuffers();
+ }
+
+ @Test
+ public void testFetchHash() throws Exception {
+ registerExecutor("exec-1", dataContext1.createExecutorInfo(HASH_MANAGER));
+ FetchResult execFetch = fetchBlocks("exec-1",
+ new String[] { "shuffle_1_0_0", "shuffle_1_0_1" });
+ assertEquals(Sets.newHashSet("shuffle_1_0_0", "shuffle_1_0_1"), execFetch.successBlocks);
+ assertTrue(execFetch.failedBlocks.isEmpty());
+ assertBufferListsEqual(execFetch.buffers, Lists.newArrayList(exec1Blocks));
+ execFetch.releaseBuffers();
+ }
+
+ @Test
+ public void testFetchWrongShuffle() throws Exception {
+ registerExecutor("exec-1", dataContext1.createExecutorInfo(SORT_MANAGER /* wrong manager */));
+ FetchResult execFetch = fetchBlocks("exec-1",
+ new String[] { "shuffle_1_0_0", "shuffle_1_0_1" });
+ assertTrue(execFetch.successBlocks.isEmpty());
+ assertEquals(Sets.newHashSet("shuffle_1_0_0", "shuffle_1_0_1"), execFetch.failedBlocks);
+ }
+
+ @Test
+ public void testFetchInvalidShuffle() throws Exception {
+ registerExecutor("exec-1", dataContext1.createExecutorInfo("unknown sort manager"));
+ FetchResult execFetch = fetchBlocks("exec-1",
+ new String[] { "shuffle_1_0_0" });
+ assertTrue(execFetch.successBlocks.isEmpty());
+ assertEquals(Sets.newHashSet("shuffle_1_0_0"), execFetch.failedBlocks);
+ }
+
+ @Test
+ public void testFetchWrongBlockId() throws Exception {
+ registerExecutor("exec-1", dataContext1.createExecutorInfo(SORT_MANAGER /* wrong manager */));
+ FetchResult execFetch = fetchBlocks("exec-1",
+ new String[] { "rdd_1_0_0" });
+ assertTrue(execFetch.successBlocks.isEmpty());
+ assertEquals(Sets.newHashSet("rdd_1_0_0"), execFetch.failedBlocks);
+ }
+
+ @Test
+ public void testFetchNonexistent() throws Exception {
+ registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER));
+ FetchResult execFetch = fetchBlocks("exec-0",
+ new String[] { "shuffle_2_0_0" });
+ assertTrue(execFetch.successBlocks.isEmpty());
+ assertEquals(Sets.newHashSet("shuffle_2_0_0"), execFetch.failedBlocks);
+ }
+
+ @Test
+ public void testFetchWrongExecutor() throws Exception {
+ registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER));
+ FetchResult execFetch = fetchBlocks("exec-0",
+ new String[] { "shuffle_0_0_0" /* right */, "shuffle_1_0_0" /* wrong */ });
+    // Both still fail, as we start by checking that all requested blocks exist.
+ assertTrue(execFetch.successBlocks.isEmpty());
+ assertEquals(Sets.newHashSet("shuffle_0_0_0", "shuffle_1_0_0"), execFetch.failedBlocks);
+ }
+
+ @Test
+ public void testFetchUnregisteredExecutor() throws Exception {
+ registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER));
+ FetchResult execFetch = fetchBlocks("exec-2",
+ new String[] { "shuffle_0_0_0", "shuffle_1_0_0" });
+ assertTrue(execFetch.successBlocks.isEmpty());
+ assertEquals(Sets.newHashSet("shuffle_0_0_0", "shuffle_1_0_0"), execFetch.failedBlocks);
+ }
+
+ @Test
+ public void testFetchNoServer() throws Exception {
+ registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER));
+ FetchResult execFetch = fetchBlocks("exec-0",
+ new String[] { "shuffle_1_0_0", "shuffle_1_0_1" }, 1 /* port */);
+ assertTrue(execFetch.successBlocks.isEmpty());
+ assertEquals(Sets.newHashSet("shuffle_1_0_0", "shuffle_1_0_1"), execFetch.failedBlocks);
+ }
+
+ private void registerExecutor(String executorId, ExecutorShuffleInfo executorInfo) {
+ ExternalShuffleClient client = new ExternalShuffleClient(conf, APP_ID);
+ client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(),
+ executorId, executorInfo);
+ }
+
+ private void assertBufferListsEqual(List<ManagedBuffer> list0, List<byte[]> list1)
+ throws Exception {
+ assertEquals(list0.size(), list1.size());
+ for (int i = 0; i < list0.size(); i ++) {
+ assertBuffersEqual(list0.get(i), new NioManagedBuffer(ByteBuffer.wrap(list1.get(i))));
+ }
+ }
+
+ private void assertBuffersEqual(ManagedBuffer buffer0, ManagedBuffer buffer1) throws Exception {
+ ByteBuffer nio0 = buffer0.nioByteBuffer();
+ ByteBuffer nio1 = buffer1.nioByteBuffer();
+
+ int len = nio0.remaining();
+ assertEquals(nio0.remaining(), nio1.remaining());
+ for (int i = 0; i < len; i ++) {
+ assertEquals(nio0.get(), nio1.get());
+ }
+ }
+}
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java
new file mode 100644
index 0000000000..c18346f696
--- /dev/null
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import java.nio.ByteBuffer;
+import java.util.Iterator;
+import java.util.LinkedHashMap;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import com.google.common.collect.Maps;
+import io.netty.buffer.Unpooled;
+import org.junit.Test;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+
+import static org.junit.Assert.*;
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.eq;
+import static org.mockito.Mockito.*;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NettyManagedBuffer;
+import org.apache.spark.network.buffer.NioManagedBuffer;
+import org.apache.spark.network.client.ChunkReceivedCallback;
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.util.JavaUtils;
+
+public class OneForOneBlockFetcherSuite {
+ @Test
+ public void testFetchOne() {
+ LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
+ blocks.put("shuffle_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[0])));
+
+ BlockFetchingListener listener = fetchBlocks(blocks);
+
+ verify(listener).onBlockFetchSuccess("shuffle_0_0_0", blocks.get("shuffle_0_0_0"));
+ }
+
+ @Test
+ public void testFetchThree() {
+ LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
+ blocks.put("b0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12])));
+ blocks.put("b1", new NioManagedBuffer(ByteBuffer.wrap(new byte[23])));
+ blocks.put("b2", new NettyManagedBuffer(Unpooled.wrappedBuffer(new byte[23])));
+
+ BlockFetchingListener listener = fetchBlocks(blocks);
+
+ for (int i = 0; i < 3; i ++) {
+ verify(listener, times(1)).onBlockFetchSuccess("b" + i, blocks.get("b" + i));
+ }
+ }
+
+ @Test
+ public void testFailure() {
+ LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
+ blocks.put("b0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12])));
+ blocks.put("b1", null);
+ blocks.put("b2", null);
+
+ BlockFetchingListener listener = fetchBlocks(blocks);
+
+    // Each failure causes onBlockFetchFailure to be invoked for every remaining block.
+ verify(listener, times(1)).onBlockFetchSuccess("b0", blocks.get("b0"));
+ verify(listener, times(1)).onBlockFetchFailure(eq("b1"), (Throwable) any());
+ verify(listener, times(2)).onBlockFetchFailure(eq("b2"), (Throwable) any());
+ }
+
+ @Test
+ public void testFailureAndSuccess() {
+ LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
+ blocks.put("b0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12])));
+ blocks.put("b1", null);
+ blocks.put("b2", new NioManagedBuffer(ByteBuffer.wrap(new byte[21])));
+
+ BlockFetchingListener listener = fetchBlocks(blocks);
+
+ // We may call both success and failure for the same block.
+ verify(listener, times(1)).onBlockFetchSuccess("b0", blocks.get("b0"));
+ verify(listener, times(1)).onBlockFetchFailure(eq("b1"), (Throwable) any());
+ verify(listener, times(1)).onBlockFetchSuccess("b2", blocks.get("b2"));
+ verify(listener, times(1)).onBlockFetchFailure(eq("b2"), (Throwable) any());
+ }
+
+ @Test
+ public void testEmptyBlockFetch() {
+ try {
+ fetchBlocks(Maps.<String, ManagedBuffer>newLinkedHashMap());
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertEquals("Zero-sized blockIds array", e.getMessage());
+ }
+ }
+
+ /**
+ * Begins a fetch on the given set of blocks by mocking out the server side of the RPC which
+ * simply returns the given (BlockId, Block) pairs.
+ * As "blocks" is a LinkedHashMap, the blocks are guaranteed to be returned in the same order
+ * that they were inserted in.
+ *
+ * If a block's buffer is "null", an exception will be thrown instead.
+ */
+ private BlockFetchingListener fetchBlocks(final LinkedHashMap<String, ManagedBuffer> blocks) {
+ TransportClient client = mock(TransportClient.class);
+ BlockFetchingListener listener = mock(BlockFetchingListener.class);
+ String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
+ OneForOneBlockFetcher fetcher = new OneForOneBlockFetcher(client, blockIds, listener);
+
+    // Respond to the "OpenBlocks" message with an appropriate ShuffleStreamHandle using streamId 123.
+ doAnswer(new Answer<Void>() {
+ @Override
+ public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
+ String message = JavaUtils.deserialize((byte[]) invocationOnMock.getArguments()[0]);
+ RpcResponseCallback callback = (RpcResponseCallback) invocationOnMock.getArguments()[1];
+ callback.onSuccess(JavaUtils.serialize(new ShuffleStreamHandle(123, blocks.size())));
+ assertEquals("OpenZeBlocks", message);
+ return null;
+ }
+ }).when(client).sendRpc((byte[]) any(), (RpcResponseCallback) any());
+
+ // Respond to each chunk request with a single buffer from our blocks array.
+ final AtomicInteger expectedChunkIndex = new AtomicInteger(0);
+ final Iterator<ManagedBuffer> blockIterator = blocks.values().iterator();
+ doAnswer(new Answer<Void>() {
+ @Override
+ public Void answer(InvocationOnMock invocation) throws Throwable {
+ try {
+ long streamId = (Long) invocation.getArguments()[0];
+ int myChunkIndex = (Integer) invocation.getArguments()[1];
+ assertEquals(123, streamId);
+ assertEquals(expectedChunkIndex.getAndIncrement(), myChunkIndex);
+
+ ChunkReceivedCallback callback = (ChunkReceivedCallback) invocation.getArguments()[2];
+ ManagedBuffer result = blockIterator.next();
+ if (result != null) {
+ callback.onSuccess(myChunkIndex, result);
+ } else {
+ callback.onFailure(myChunkIndex, new RuntimeException("Failed " + myChunkIndex));
+ }
+ } catch (Exception e) {
+ e.printStackTrace();
+ fail("Unexpected failure");
+ }
+ return null;
+ }
+ }).when(client).fetchChunk(anyLong(), anyInt(), (ChunkReceivedCallback) any());
+
+ fetcher.start("OpenZeBlocks");
+ return listener;
+ }
+}
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ShuffleMessagesSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ShuffleMessagesSuite.java
new file mode 100644
index 0000000000..ee9482b49c
--- /dev/null
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ShuffleMessagesSuite.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import org.junit.Test;
+
+import static org.junit.Assert.*;
+
+import org.apache.spark.network.util.JavaUtils;
+
+import static org.apache.spark.network.shuffle.ExternalShuffleMessages.*;
+
+public class ShuffleMessagesSuite {
+ @Test
+ public void serializeOpenShuffleBlocks() {
+ OpenShuffleBlocks msg = new OpenShuffleBlocks("app-1", "exec-2",
+ new String[] { "block0", "block1" });
+ OpenShuffleBlocks msg2 = JavaUtils.deserialize(JavaUtils.serialize(msg));
+ assertEquals(msg, msg2);
+ }
+
+ @Test
+ public void serializeRegisterExecutor() {
+ RegisterExecutor msg = new RegisterExecutor("app-1", "exec-2", new ExecutorShuffleInfo(
+ new String[] { "/local1", "/local2" }, 32, "MyShuffleManager"));
+ RegisterExecutor msg2 = JavaUtils.deserialize(JavaUtils.serialize(msg));
+ assertEquals(msg, msg2);
+ }
+
+ @Test
+ public void serializeShuffleStreamHandle() {
+ ShuffleStreamHandle msg = new ShuffleStreamHandle(12345, 16);
+ ShuffleStreamHandle msg2 = JavaUtils.deserialize(JavaUtils.serialize(msg));
+ assertEquals(msg, msg2);
+ }
+}
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java
new file mode 100644
index 0000000000..442b756467
--- /dev/null
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import java.io.DataOutputStream;
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.OutputStream;
+
+import com.google.common.io.Files;
+
+/**
+ * Manages some sort- and hash-based shuffle data, including the creation
+ * and cleanup of directories that can be read by the {@link ExternalShuffleBlockManager}.
+ */
+public class TestShuffleDataContext {
+ private final String[] localDirs;
+ private final int subDirsPerLocalDir;
+
+ public TestShuffleDataContext(int numLocalDirs, int subDirsPerLocalDir) {
+ this.localDirs = new String[numLocalDirs];
+ this.subDirsPerLocalDir = subDirsPerLocalDir;
+ }
+
+ public void create() {
+ for (int i = 0; i < localDirs.length; i ++) {
+ localDirs[i] = Files.createTempDir().getAbsolutePath();
+
+ for (int p = 0; p < subDirsPerLocalDir; p ++) {
+ new File(localDirs[i], String.format("%02x", p)).mkdirs();
+ }
+ }
+ }
+
+ public void cleanup() {
+ for (String localDir : localDirs) {
+ deleteRecursively(new File(localDir));
+ }
+ }
+
+ /** Creates reducer blocks in a sort-based data format within our local dirs. */
+ public void insertSortShuffleData(int shuffleId, int mapId, byte[][] blocks) throws IOException {
+ String blockId = "shuffle_" + shuffleId + "_" + mapId + "_0";
+
+ OutputStream dataStream = new FileOutputStream(
+ ExternalShuffleBlockManager.getFile(localDirs, subDirsPerLocalDir, blockId + ".data"));
+ DataOutputStream indexStream = new DataOutputStream(new FileOutputStream(
+ ExternalShuffleBlockManager.getFile(localDirs, subDirsPerLocalDir, blockId + ".index")));
+
+ long offset = 0;
+ indexStream.writeLong(offset);
+ for (byte[] block : blocks) {
+ offset += block.length;
+ dataStream.write(block);
+ indexStream.writeLong(offset);
+ }
+
+ dataStream.close();
+ indexStream.close();
+ }
+
+ /** Creates reducer blocks in a hash-based data format within our local dirs. */
+ public void insertHashShuffleData(int shuffleId, int mapId, byte[][] blocks) throws IOException {
+ for (int i = 0; i < blocks.length; i ++) {
+ String blockId = "shuffle_" + shuffleId + "_" + mapId + "_" + i;
+ Files.write(blocks[i],
+ ExternalShuffleBlockManager.getFile(localDirs, subDirsPerLocalDir, blockId));
+ }
+ }
+
+ /**
+   * Creates an ExecutorShuffleInfo for the given shuffle manager, targeting this context's
+   * directories.
+ */
+ public ExecutorShuffleInfo createExecutorInfo(String shuffleManager) {
+ return new ExecutorShuffleInfo(localDirs, subDirsPerLocalDir, shuffleManager);
+ }
+
+ private static void deleteRecursively(File f) {
+ assert f != null;
+ if (f.isDirectory()) {
+ File[] children = f.listFiles();
+ if (children != null) {
+ for (File child : children) {
+ deleteRecursively(child);
+ }
+ }
+ }
+ f.delete();
+ }
+}
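
The sort-based layout written by insertSortShuffleData is the usual data/index pair: the index file holds blocks.length + 1 longs, and reducer block i occupies bytes [offset[i], offset[i+1]) of the data file. A hypothetical reader sketch for a single block (readSortShuffleBlock and its arguments are illustrative, not part of the patch):

    import java.io.DataInputStream;
    import java.io.File;
    import java.io.FileInputStream;
    import java.io.IOException;
    import java.io.RandomAccessFile;

    public class SortShuffleBlockReader {
      /** Reads reducer block `reduceId` from a sort-shuffle data/index file pair. */
      public static byte[] readSortShuffleBlock(File indexFile, File dataFile, int reduceId)
          throws IOException {
        long start;
        long end;
        DataInputStream in = new DataInputStream(new FileInputStream(indexFile));
        try {
          in.skipBytes(reduceId * 8);  // each offset is one 8-byte long
          start = in.readLong();
          end = in.readLong();
        } finally {
          in.close();
        }

        RandomAccessFile data = new RandomAccessFile(dataFile, "r");
        try {
          byte[] block = new byte[(int) (end - start)];
          data.seek(start);
          data.readFully(block);
          return block;
        } finally {
          data.close();
        }
      }
    }
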
diff --git a/pom.xml b/pom.xml
index 4c7806c416..61a508a0ea 100644
--- a/pom.xml
+++ b/pom.xml
@@ -92,6 +92,7 @@
<module>mllib</module>
<module>tools</module>
<module>network/common</module>
+ <module>network/shuffle</module>
<module>streaming</module>
<module>sql/catalyst</module>
<module>sql/core</module>
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 77083518bb..33618f5401 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -31,11 +31,12 @@ object BuildCommons {
private val buildLocation = file(".").getAbsoluteFile.getParentFile
val allProjects@Seq(bagel, catalyst, core, graphx, hive, hiveThriftServer, mllib, repl,
- sql, networkCommon, streaming, streamingFlumeSink, streamingFlume, streamingKafka, streamingMqtt,
- streamingTwitter, streamingZeromq) =
+ sql, networkCommon, networkShuffle, streaming, streamingFlumeSink, streamingFlume, streamingKafka,
+ streamingMqtt, streamingTwitter, streamingZeromq) =
Seq("bagel", "catalyst", "core", "graphx", "hive", "hive-thriftserver", "mllib", "repl",
- "sql", "network-common", "streaming", "streaming-flume-sink", "streaming-flume", "streaming-kafka",
- "streaming-mqtt", "streaming-twitter", "streaming-zeromq").map(ProjectRef(buildLocation, _))
+ "sql", "network-common", "network-shuffle", "streaming", "streaming-flume-sink",
+ "streaming-flume", "streaming-kafka", "streaming-mqtt", "streaming-twitter",
+ "streaming-zeromq").map(ProjectRef(buildLocation, _))
val optionallyEnabledProjects@Seq(yarn, yarnStable, yarnAlpha, java8Tests, sparkGangliaLgpl, sparkKinesisAsl) =
Seq("yarn", "yarn-stable", "yarn-alpha", "java8-tests", "ganglia-lgpl", "kinesis-asl")
@@ -142,7 +143,7 @@ object SparkBuild extends PomBuild {
// TODO: Add Sql to mima checks
allProjects.filterNot(x => Seq(spark, sql, hive, hiveThriftServer, catalyst, repl,
- streamingFlumeSink, networkCommon).contains(x)).foreach {
+ streamingFlumeSink, networkCommon, networkShuffle).contains(x)).foreach {
x => enable(MimaBuild.mimaSettings(sparkHome, x))(x)
}