author    Shubham Chopra <schopra31@bloomberg.net>    2016-09-30 18:24:39 -0700
committer    Reynold Xin <rxin@databricks.com>    2016-09-30 18:24:39 -0700
commit    a26afd52198523dbd51dc94053424494638c7de5 (patch)
tree    d86b827d9c5da5246597479d9428f9c3b5dea657
parent    81455a9cd963098613bad10182e3fafc83a6e352 (diff)
[SPARK-15353][CORE] Making peer selection for block replication pluggable
## What changes were proposed in this pull request?

This PR makes block replication strategies pluggable. It provides two traits that can be implemented: one that maps a host to its topology and is used in the master, and a second that prioritizes a list of peers for block replication and runs in the executors. This patch contains default implementations of these traits that keep current Spark behavior unchanged.

## How was this patch tested?

This patch should not change Spark behavior in any way, and was tested with unit tests for storage.

Author: Shubham Chopra <schopra31@bloomberg.net>

Closes #13152 from shubhamchopra/RackAwareBlockReplication.
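For orientation before the diff: both extension points are wired up through plain SparkConf keys and instantiated reflectively, spark.storage.replication.policy by the executors' BlockManager and spark.storage.replication.topologyMapper by the master endpoint. A minimal Scala sketch of setting them (class names are taken from this patch; the properties-file path is illustrative):

import org.apache.spark.SparkConf
import org.apache.spark.storage.{FileBasedTopologyMapper, RandomBlockReplicationPolicy}

val conf = new SparkConf()
  // Executor side: how candidate peers are prioritized for replication.
  // RandomBlockReplicationPolicy is the default, set explicitly here.
  .set("spark.storage.replication.policy",
    classOf[RandomBlockReplicationPolicy].getName)
  // Master side: how a host name is mapped to topology (e.g. rack) info.
  // DefaultTopologyMapper is the default; the file-based mapper is opt-in.
  .set("spark.storage.replication.topologyMapper",
    classOf[FileBasedTopologyMapper].getName)
  // Read only by FileBasedTopologyMapper; the path here is made up.
  .set("spark.storage.replication.topologyFile", "/path/to/topology.properties")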
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockManager.scala167
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala34
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala16
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala32
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala112
-rw-r--r--core/src/main/scala/org/apache/spark/storage/TopologyMapper.scala86
-rw-r--r--core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala2
-rw-r--r--core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala74
-rw-r--r--core/src/test/scala/org/apache/spark/storage/TopologyMapperSuite.scala68
9 files changed, 492 insertions, 99 deletions
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index aa29acfd70..982b83324e 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -20,7 +20,8 @@ package org.apache.spark.storage
import java.io._
import java.nio.ByteBuffer
-import scala.collection.mutable.{ArrayBuffer, HashMap}
+import scala.collection.mutable
+import scala.collection.mutable.HashMap
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration._
import scala.reflect.ClassTag
@@ -44,6 +45,7 @@ import org.apache.spark.unsafe.Platform
import org.apache.spark.util._
import org.apache.spark.util.io.ChunkedByteBuffer
+
/* Class for returning a fetched block and associated metrics. */
private[spark] class BlockResult(
val data: Iterator[Any],
@@ -147,6 +149,8 @@ private[spark] class BlockManager(
private val peerFetchLock = new Object
private var lastPeerFetchTime = 0L
+ private var blockReplicationPolicy: BlockReplicationPolicy = _
+
/**
* Initializes the BlockManager with the given appId. This is not performed in the constructor as
* the appId may not be known at BlockManager instantiation time (in particular for the driver,
@@ -160,8 +164,24 @@ private[spark] class BlockManager(
blockTransferService.init(this)
shuffleClient.init(appId)
- blockManagerId = BlockManagerId(
- executorId, blockTransferService.hostName, blockTransferService.port)
+ blockReplicationPolicy = {
+ val priorityClass = conf.get(
+ "spark.storage.replication.policy", classOf[RandomBlockReplicationPolicy].getName)
+ val clazz = Utils.classForName(priorityClass)
+ val ret = clazz.newInstance.asInstanceOf[BlockReplicationPolicy]
+ logInfo(s"Using $priorityClass for block replication policy")
+ ret
+ }
+
+ val id =
+ BlockManagerId(executorId, blockTransferService.hostName, blockTransferService.port, None)
+
+ val idFromMaster = master.registerBlockManager(
+ id,
+ maxMemory,
+ slaveEndpoint)
+
+ blockManagerId = if (idFromMaster != null) idFromMaster else id
shuffleServerId = if (externalShuffleServiceEnabled) {
logInfo(s"external shuffle service port = $externalShuffleServicePort")
@@ -170,12 +190,12 @@ private[spark] class BlockManager(
blockManagerId
}
- master.registerBlockManager(blockManagerId, maxMemory, slaveEndpoint)
-
// Register Executors' configuration with the local shuffle service, if one should exist.
if (externalShuffleServiceEnabled && !blockManagerId.isDriver) {
registerWithExternalShuffleServer()
}
+
+ logInfo(s"Initialized BlockManager: $blockManagerId")
}
private def registerWithExternalShuffleServer() {
@@ -1111,7 +1131,7 @@ private[spark] class BlockManager(
}
/**
- * Replicate block to another node. Not that this is a blocking call that returns after
+ * Replicate block to another node. Note that this is a blocking call that returns after
* the block has been replicated.
*/
private def replicate(
@@ -1119,101 +1139,78 @@ private[spark] class BlockManager(
data: ChunkedByteBuffer,
level: StorageLevel,
classTag: ClassTag[_]): Unit = {
+
val maxReplicationFailures = conf.getInt("spark.storage.maxReplicationFailures", 1)
- val numPeersToReplicateTo = level.replication - 1
- val peersForReplication = new ArrayBuffer[BlockManagerId]
- val peersReplicatedTo = new ArrayBuffer[BlockManagerId]
- val peersFailedToReplicateTo = new ArrayBuffer[BlockManagerId]
val tLevel = StorageLevel(
useDisk = level.useDisk,
useMemory = level.useMemory,
useOffHeap = level.useOffHeap,
deserialized = level.deserialized,
replication = 1)
- val startTime = System.currentTimeMillis
- val random = new Random(blockId.hashCode)
-
- var replicationFailed = false
- var failures = 0
- var done = false
-
- // Get cached list of peers
- peersForReplication ++= getPeers(forceFetch = false)
-
- // Get a random peer. Note that this selection of a peer is deterministic on the block id.
- // So assuming the list of peers does not change and no replication failures,
- // if there are multiple attempts in the same node to replicate the same block,
- // the same set of peers will be selected.
- def getRandomPeer(): Option[BlockManagerId] = {
- // If replication had failed, then force update the cached list of peers and remove the peers
- // that have been already used
- if (replicationFailed) {
- peersForReplication.clear()
- peersForReplication ++= getPeers(forceFetch = true)
- peersForReplication --= peersReplicatedTo
- peersForReplication --= peersFailedToReplicateTo
- }
- if (!peersForReplication.isEmpty) {
- Some(peersForReplication(random.nextInt(peersForReplication.size)))
- } else {
- None
- }
- }
- // One by one choose a random peer and try uploading the block to it
- // If replication fails (e.g., target peer is down), force the list of cached peers
- // to be re-fetched from driver and then pick another random peer for replication. Also
- // temporarily black list the peer for which replication failed.
- //
- // This selection of a peer and replication is continued in a loop until one of the
- // following 3 conditions is fulfilled:
- // (i) specified number of peers have been replicated to
- // (ii) too many failures in replicating to peers
- // (iii) no peer left to replicate to
- //
- while (!done) {
- getRandomPeer() match {
- case Some(peer) =>
- try {
- val onePeerStartTime = System.currentTimeMillis
- logTrace(s"Trying to replicate $blockId of ${data.size} bytes to $peer")
- blockTransferService.uploadBlockSync(
- peer.host,
- peer.port,
- peer.executorId,
- blockId,
- new NettyManagedBuffer(data.toNetty),
- tLevel,
- classTag)
- logTrace(s"Replicated $blockId of ${data.size} bytes to $peer in %s ms"
- .format(System.currentTimeMillis - onePeerStartTime))
- peersReplicatedTo += peer
- peersForReplication -= peer
- replicationFailed = false
- if (peersReplicatedTo.size == numPeersToReplicateTo) {
- done = true // specified number of peers have been replicated to
- }
- } catch {
- case NonFatal(e) =>
- logWarning(s"Failed to replicate $blockId to $peer, failure #$failures", e)
- failures += 1
- replicationFailed = true
- peersFailedToReplicateTo += peer
- if (failures > maxReplicationFailures) { // too many failures in replicating to peers
- done = true
- }
+ val numPeersToReplicateTo = level.replication - 1
+
+ val startTime = System.nanoTime
+
+ var peersReplicatedTo = mutable.HashSet.empty[BlockManagerId]
+ var peersFailedToReplicateTo = mutable.HashSet.empty[BlockManagerId]
+ var numFailures = 0
+
+ var peersForReplication = blockReplicationPolicy.prioritize(
+ blockManagerId,
+ getPeers(false),
+ mutable.HashSet.empty,
+ blockId,
+ numPeersToReplicateTo)
+
+ while (numFailures <= maxReplicationFailures &&
+ !peersForReplication.isEmpty &&
+ peersReplicatedTo.size != numPeersToReplicateTo) {
+ val peer = peersForReplication.head
+ try {
+ val onePeerStartTime = System.nanoTime
+ logTrace(s"Trying to replicate $blockId of ${data.size} bytes to $peer")
+ blockTransferService.uploadBlockSync(
+ peer.host,
+ peer.port,
+ peer.executorId,
+ blockId,
+ new NettyManagedBuffer(data.toNetty),
+ tLevel,
+ classTag)
+ logTrace(s"Replicated $blockId of ${data.size} bytes to $peer" +
+ s" in ${(System.nanoTime - onePeerStartTime).toDouble / 1e6} ms")
+ peersForReplication = peersForReplication.tail
+ peersReplicatedTo += peer
+ } catch {
+ case NonFatal(e) =>
+ logWarning(s"Failed to replicate $blockId to $peer, failure #$numFailures", e)
+ peersFailedToReplicateTo += peer
+ // we have a failed replication, so we get the list of peers again
+ // we don't want peers we have already replicated to, or the ones that
+ // have failed previously
+ val filteredPeers = getPeers(true).filter { p =>
+ !peersFailedToReplicateTo.contains(p) && !peersReplicatedTo.contains(p)
}
- case None => // no peer left to replicate to
- done = true
+
+ numFailures += 1
+ peersForReplication = blockReplicationPolicy.prioritize(
+ blockManagerId,
+ filteredPeers,
+ peersReplicatedTo,
+ blockId,
+ numPeersToReplicateTo - peersReplicatedTo.size)
}
}
- val timeTakeMs = (System.currentTimeMillis - startTime)
+
logDebug(s"Replicating $blockId of ${data.size} bytes to " +
- s"${peersReplicatedTo.size} peer(s) took $timeTakeMs ms")
+ s"${peersReplicatedTo.size} peer(s) took ${(System.nanoTime - startTime) / 1e6} ms")
if (peersReplicatedTo.size < numPeersToReplicateTo) {
logWarning(s"Block $blockId replicated to only " +
s"${peersReplicatedTo.size} peer(s) instead of $numPeersToReplicateTo peers")
}
+
+ logDebug(s"block $blockId replicated to ${peersReplicatedTo.mkString(", ")}")
}
/**
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
index f255f5be63..c37a3604d2 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
@@ -37,10 +37,11 @@ import org.apache.spark.util.Utils
class BlockManagerId private (
private var executorId_ : String,
private var host_ : String,
- private var port_ : Int)
+ private var port_ : Int,
+ private var topologyInfo_ : Option[String])
extends Externalizable {
- private def this() = this(null, null, 0) // For deserialization only
+ private def this() = this(null, null, 0, None) // For deserialization only
def executorId: String = executorId_
@@ -60,6 +61,8 @@ class BlockManagerId private (
def port: Int = port_
+ def topologyInfo: Option[String] = topologyInfo_
+
def isDriver: Boolean = {
executorId == SparkContext.DRIVER_IDENTIFIER ||
executorId == SparkContext.LEGACY_DRIVER_IDENTIFIER
@@ -69,24 +72,33 @@ class BlockManagerId private (
out.writeUTF(executorId_)
out.writeUTF(host_)
out.writeInt(port_)
+ out.writeBoolean(topologyInfo_.isDefined)
+ // we only write topologyInfo if we have it
+ topologyInfo.foreach(out.writeUTF(_))
}
override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
executorId_ = in.readUTF()
host_ = in.readUTF()
port_ = in.readInt()
+ val isTopologyInfoAvailable = in.readBoolean()
+ topologyInfo_ = if (isTopologyInfoAvailable) Option(in.readUTF()) else None
}
@throws(classOf[IOException])
private def readResolve(): Object = BlockManagerId.getCachedBlockManagerId(this)
- override def toString: String = s"BlockManagerId($executorId, $host, $port)"
+ override def toString: String = s"BlockManagerId($executorId, $host, $port, $topologyInfo)"
- override def hashCode: Int = (executorId.hashCode * 41 + host.hashCode) * 41 + port
+ override def hashCode: Int =
+ ((executorId.hashCode * 41 + host.hashCode) * 41 + port) * 41 + topologyInfo.hashCode
override def equals(that: Any): Boolean = that match {
case id: BlockManagerId =>
- executorId == id.executorId && port == id.port && host == id.host
+ executorId == id.executorId &&
+ port == id.port &&
+ host == id.host &&
+ topologyInfo == id.topologyInfo
case _ =>
false
}
@@ -101,10 +113,18 @@ private[spark] object BlockManagerId {
* @param execId ID of the executor.
* @param host Host name of the block manager.
* @param port Port of the block manager.
+ * @param topologyInfo topology information for the block manager, if available
+ * This can be network topology information for use while choosing peers
+ * while replicating data blocks. More information available here:
+ * [[org.apache.spark.storage.TopologyMapper]]
* @return A new [[org.apache.spark.storage.BlockManagerId]].
*/
- def apply(execId: String, host: String, port: Int): BlockManagerId =
- getCachedBlockManagerId(new BlockManagerId(execId, host, port))
+ def apply(
+ execId: String,
+ host: String,
+ port: Int,
+ topologyInfo: Option[String] = None): BlockManagerId =
+ getCachedBlockManagerId(new BlockManagerId(execId, host, port, topologyInfo))
def apply(in: ObjectInput): BlockManagerId = {
val obj = new BlockManagerId()
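A small sketch of the widened BlockManagerId factory above (the executor id, host, port, and rack string are made-up values; the companion object is private[spark], so the sketch sits in the storage package):

package org.apache.spark.storage

object BlockManagerIdTopologyExample {
  def main(args: Array[String]): Unit = {
    // topologyInfo defaults to None, so existing call sites keep compiling.
    val plainId = BlockManagerId("exec-1", "host-1", 7077)
    // An id carrying topology info, e.g. a rack path from a TopologyMapper.
    val rackId = BlockManagerId("exec-1", "host-1", 7077, Some("/rack-1"))
    // topologyInfo now participates in equals and hashCode, so these differ.
    assert(plainId != rackId)
  }
}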
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
index 8655cf10fc..7a60006891 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
@@ -50,12 +50,20 @@ class BlockManagerMaster(
logInfo("Removal of executor " + execId + " requested")
}
- /** Register the BlockManager's id with the driver. */
+ /**
+ * Register the BlockManager's id with the driver. The input BlockManagerId does not contain
+ * topology information. The master obtains this information and responds with an updated
+ * BlockManagerId that includes it.
+ */
def registerBlockManager(
- blockManagerId: BlockManagerId, maxMemSize: Long, slaveEndpoint: RpcEndpointRef): Unit = {
+ blockManagerId: BlockManagerId,
+ maxMemSize: Long,
+ slaveEndpoint: RpcEndpointRef): BlockManagerId = {
logInfo(s"Registering BlockManager $blockManagerId")
- tell(RegisterBlockManager(blockManagerId, maxMemSize, slaveEndpoint))
- logInfo(s"Registered BlockManager $blockManagerId")
+ val updatedId = driverEndpoint.askWithRetry[BlockManagerId](
+ RegisterBlockManager(blockManagerId, maxMemSize, slaveEndpoint))
+ logInfo(s"Registered BlockManager $updatedId")
+ updatedId
}
def updateBlockInfo(
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
index 8fa1215011..145c434a4f 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
@@ -55,10 +55,21 @@ class BlockManagerMasterEndpoint(
private val askThreadPool = ThreadUtils.newDaemonCachedThreadPool("block-manager-ask-thread-pool")
private implicit val askExecutionContext = ExecutionContext.fromExecutorService(askThreadPool)
+ private val topologyMapper = {
+ val topologyMapperClassName = conf.get(
+ "spark.storage.replication.topologyMapper", classOf[DefaultTopologyMapper].getName)
+ val clazz = Utils.classForName(topologyMapperClassName)
+ val mapper =
+ clazz.getConstructor(classOf[SparkConf]).newInstance(conf).asInstanceOf[TopologyMapper]
+ logInfo(s"Using $topologyMapperClassName for getting topology information")
+ mapper
+ }
+
+ logInfo("BlockManagerMasterEndpoint up")
+
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
case RegisterBlockManager(blockManagerId, maxMemSize, slaveEndpoint) =>
- register(blockManagerId, maxMemSize, slaveEndpoint)
- context.reply(true)
+ context.reply(register(blockManagerId, maxMemSize, slaveEndpoint))
case _updateBlockInfo @
UpdateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size) =>
@@ -298,7 +309,21 @@ class BlockManagerMasterEndpoint(
).map(_.flatten.toSeq)
}
- private def register(id: BlockManagerId, maxMemSize: Long, slaveEndpoint: RpcEndpointRef) {
+ /**
+ * Returns the BlockManagerId with topology information populated, if available.
+ */
+ private def register(
+ idWithoutTopologyInfo: BlockManagerId,
+ maxMemSize: Long,
+ slaveEndpoint: RpcEndpointRef): BlockManagerId = {
+ // the incoming id is not expected to contain topology information.
+ // we look that up here and respond with a block manager id that includes it
+ val id = BlockManagerId(
+ idWithoutTopologyInfo.executorId,
+ idWithoutTopologyInfo.host,
+ idWithoutTopologyInfo.port,
+ topologyMapper.getTopologyForHost(idWithoutTopologyInfo.host))
+
val time = System.currentTimeMillis()
if (!blockManagerInfo.contains(id)) {
blockManagerIdByExecutor.get(id.executorId) match {
@@ -318,6 +343,7 @@ class BlockManagerMasterEndpoint(
id, System.currentTimeMillis(), maxMemSize, slaveEndpoint)
}
listenerBus.post(SparkListenerBlockManagerAdded(time, id, maxMemSize))
+ id
}
private def updateBlockInfo(
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala b/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala
new file mode 100644
index 0000000000..bf087af16a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import scala.collection.mutable
+import scala.util.Random
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.internal.Logging
+
+/**
+ * ::DeveloperApi::
+ * BlockReplicationPolicy provides logic for prioritizing a sequence of peers for
+ * replicating blocks. BlockManager will replicate to each peer returned in order until the
+ * desired replication factor is reached. If a replication fails, prioritize() will be called
+ * again to get a fresh prioritization.
+ */
+@DeveloperApi
+trait BlockReplicationPolicy {
+
+ /**
+ * Method to prioritize a set of candidate peers of a block
+ *
+ * @param blockManagerId Id of the current BlockManager for self identification
+ * @param peers A list of peers of a BlockManager
+ * @param peersReplicatedTo Set of peers already replicated to
+ * @param blockId BlockId of the block being replicated. This can be used as a source of
+ * randomness if needed.
+ * @param numReplicas Number of peers we need to replicate to
+ * @return A prioritized list of peers. The lower the index of a peer, the higher its
+ * priority. This returns a list of size at most `numReplicas`.
+ */
+ def prioritize(
+ blockManagerId: BlockManagerId,
+ peers: Seq[BlockManagerId],
+ peersReplicatedTo: mutable.HashSet[BlockManagerId],
+ blockId: BlockId,
+ numReplicas: Int): List[BlockManagerId]
+}
+
+@DeveloperApi
+class RandomBlockReplicationPolicy
+ extends BlockReplicationPolicy
+ with Logging {
+
+ /**
+ * Method to prioritize a set of candidate peers of a block. This is a basic implementation
+ * that just makes sure we put blocks on different hosts, if possible.
+ *
+ * @param blockManagerId Id of the current BlockManager for self identification
+ * @param peers A list of peers of a BlockManager
+ * @param peersReplicatedTo Set of peers already replicated to
+ * @param blockId BlockId of the block being replicated. This can be used as a source of
+ * randomness if needed.
+ * @return A prioritized list of peers. The lower the index of a peer, the higher its priority.
+ */
+ override def prioritize(
+ blockManagerId: BlockManagerId,
+ peers: Seq[BlockManagerId],
+ peersReplicatedTo: mutable.HashSet[BlockManagerId],
+ blockId: BlockId,
+ numReplicas: Int): List[BlockManagerId] = {
+ val random = new Random(blockId.hashCode)
+ logDebug(s"Input peers : ${peers.mkString(", ")}")
+ val prioritizedPeers = if (peers.size > numReplicas) {
+ getSampleIds(peers.size, numReplicas, random).map(peers(_))
+ } else {
+ if (peers.size < numReplicas) {
+ logWarning(s"Expecting ${numReplicas} replicas with only ${peers.size} peer/s.")
+ }
+ random.shuffle(peers).toList
+ }
+ logDebug(s"Prioritized peers : ${prioritizedPeers.mkString(", ")}")
+ prioritizedPeers
+ }
+
+ /**
+ * Uses Robert Floyd's sampling algorithm to find a random sample of m indices in O(m)
+ * time while minimizing space usage.
+ * [[http://math.stackexchange.com/questions/178690/
+ * whats-the-proof-of-correctness-for-robert-floyds-algorithm-for-selecting-a-sin]]
+ *
+ * @param n total number of indices
+ * @param m number of samples needed
+ * @param r random number generator
+ * @return list of m random unique indices
+ */
+ private def getSampleIds(n: Int, m: Int, r: Random): List[Int] = {
+ val indices = (n - m + 1 to n).foldLeft(Set.empty[Int]) { case (set, i) =>
+ // Floyd's trick: pick t uniformly from 1..i; if t was already picked, keep i itself
+ val t = r.nextInt(i) + 1
+ if (set.contains(t)) set + i else set + t
+ }
+ // we shuffle the result to ensure a random arrangement within the sample
+ // to avoid any bias from set implementations
+ r.shuffle(indices.map(_ - 1).toList)
+ }
+}
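To illustrate the extension point the trait above defines, here is a hypothetical rack-spreading policy (not part of this patch; the class name and heuristic are invented): it puts peers whose topologyInfo differs from the local one first, so replicas cross racks when topology information is available.

import scala.collection.mutable
import scala.util.Random

import org.apache.spark.storage.{BlockId, BlockManagerId, BlockReplicationPolicy}

class RackSpreadingReplicationPolicy extends BlockReplicationPolicy {
  override def prioritize(
      blockManagerId: BlockManagerId,
      peers: Seq[BlockManagerId],
      peersReplicatedTo: mutable.HashSet[BlockManagerId],
      blockId: BlockId,
      numReplicas: Int): List[BlockManagerId] = {
    // Deterministic on the block id, mirroring RandomBlockReplicationPolicy.
    val random = new Random(blockId.hashCode)
    val (offRack, sameRack) =
      peers.partition(_.topologyInfo != blockManagerId.topologyInfo)
    // BlockManager uploads to the returned peers in order, so position in
    // the list encodes priority: off-rack peers first.
    (random.shuffle(offRack) ++ random.shuffle(sameRack))
      .take(numReplicas)
      .toList
  }
}

Such a policy needs a no-arg constructor, since BlockManager creates the configured class via newInstance; it would be enabled by setting spark.storage.replication.policy to its fully qualified name.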
diff --git a/core/src/main/scala/org/apache/spark/storage/TopologyMapper.scala b/core/src/main/scala/org/apache/spark/storage/TopologyMapper.scala
new file mode 100644
index 0000000000..a0f0fdef8e
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/TopologyMapper.scala
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import org.apache.spark.SparkConf
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.internal.Logging
+import org.apache.spark.util.Utils
+
+/**
+ * ::DeveloperApi::
+ * TopologyMapper provides topology information for a given host
+ * @param conf SparkConf to get required properties, if needed
+ */
+@DeveloperApi
+abstract class TopologyMapper(conf: SparkConf) {
+ /**
+ * Gets the topology information given the host name
+ *
+ * @param hostname Hostname
+ * @return topology information for the given hostname. One can use a 'topology delimiter'
+ * to make this topology information nested.
+ * For example: ‘/myrack/myhost’, where ‘/’ is the topology delimiter,
+ * ‘myrack’ is the topology identifier, and ‘myhost’ is the individual host.
+ * This function only returns the topology information without the hostname.
+ * This information can be used when choosing executors for block replication
+ * to discern executors from a different rack than a candidate executor, for example.
+ *
+ * An implementation can choose to use empty strings or None in case topology info
+ * is not available. This would imply that all such executors belong to the same rack.
+ */
+ def getTopologyForHost(hostname: String): Option[String]
+}
+
+/**
+ * A TopologyMapper that assumes all nodes are in the same rack
+ */
+@DeveloperApi
+class DefaultTopologyMapper(conf: SparkConf) extends TopologyMapper(conf) with Logging {
+ override def getTopologyForHost(hostname: String): Option[String] = {
+ logDebug(s"Got a request for $hostname")
+ None
+ }
+}
+
+/**
+ * A simple file based topology mapper. This expects topology information provided as a
+ * [[java.util.Properties]] file. The name of the file is obtained from SparkConf property
+ * `spark.storage.replication.topologyFile`. To use this topology mapper, set the
+ * `spark.storage.replication.topologyMapper` property to
+ * [[org.apache.spark.storage.FileBasedTopologyMapper]]
+ * @param conf SparkConf object
+ */
+@DeveloperApi
+class FileBasedTopologyMapper(conf: SparkConf) extends TopologyMapper(conf) with Logging {
+ val topologyFile = conf.getOption("spark.storage.replication.topologyFile")
+ require(topologyFile.isDefined, "Please specify topology file via " +
+ "spark.storage.replication.topologyFile for FileBasedTopologyMapper.")
+ val topologyMap = Utils.getPropertiesFromFile(topologyFile.get)
+
+ override def getTopologyForHost(hostname: String): Option[String] = {
+ val topology = topologyMap.get(hostname)
+ if (topology.isDefined) {
+ logDebug(s"$hostname -> ${topology.get}")
+ } else {
+ logWarning(s"$hostname does not have any topology information")
+ }
+ topology
+ }
+}
+
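To make the file-based mapper concrete, a sketch of its input and lookup behavior (the file path and host/rack names are made up; in practice the master constructs the mapper reflectively, direct construction here is only for illustration):

import org.apache.spark.SparkConf
import org.apache.spark.storage.FileBasedTopologyMapper

// Assumes /path/to/topology.properties exists and contains
// java.util.Properties entries such as:
//   host-1=/rack-1
//   host-2=/rack-2
val conf = new SparkConf(false)
  .set("spark.storage.replication.topologyFile", "/path/to/topology.properties")

val mapper = new FileBasedTopologyMapper(conf)
assert(mapper.getTopologyForHost("host-1").contains("/rack-1"))
assert(mapper.getTopologyForHost("unknown-host").isEmpty) // logs a warning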
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
index e1c1787cbd..f4bfdc2fd6 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
@@ -346,6 +346,8 @@ class BlockManagerReplicationSuite extends SparkFunSuite
}
}
+
+
/**
* Test replication of blocks with different storage levels (various combinations of
* memory, disk & serialization). For each storage level, this function tests every store
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala
new file mode 100644
index 0000000000..800c3899f1
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import scala.collection.mutable
+
+import org.scalatest.{BeforeAndAfter, Matchers}
+
+import org.apache.spark.{LocalSparkContext, SparkFunSuite}
+
+class BlockReplicationPolicySuite extends SparkFunSuite
+ with Matchers
+ with BeforeAndAfter
+ with LocalSparkContext {
+
+ // Implicitly convert strings to BlockIds for test clarity.
+ private implicit def StringToBlockId(value: String): BlockId = new TestBlockId(value)
+
+ /**
+ * Test if we get the required number of peers when using random sampling from
+ * RandomBlockReplicationPolicy
+ */
+ test(s"block replication - random block replication policy") {
+ val numBlockManagers = 10
+ val storeSize = 1000
+ val blockManagers = (1 to numBlockManagers).map { i =>
+ BlockManagerId(s"store-$i", "localhost", 1000 + i, None)
+ }
+ val candidateBlockManager = BlockManagerId("test-store", "localhost", 1000, None)
+ val replicationPolicy = new RandomBlockReplicationPolicy
+ val blockId = "test-block"
+
+ (1 to 10).foreach { numReplicas =>
+ logDebug(s"Num replicas : $numReplicas")
+ val randomPeers = replicationPolicy.prioritize(
+ candidateBlockManager,
+ blockManagers,
+ mutable.HashSet.empty[BlockManagerId],
+ blockId,
+ numReplicas
+ )
+ logDebug(s"Random peers : ${randomPeers.mkString(", ")}")
+ assert(randomPeers.toSet.size === numReplicas)
+
+ // choosing n peers out of n
+ val secondPass = replicationPolicy.prioritize(
+ candidateBlockManager,
+ randomPeers,
+ mutable.HashSet.empty[BlockManagerId],
+ blockId,
+ numReplicas
+ )
+ logDebug(s"Random peers : ${secondPass.mkString(", ")}")
+ assert(secondPass.toSet.size === numReplicas)
+ }
+
+ }
+
+}
diff --git a/core/src/test/scala/org/apache/spark/storage/TopologyMapperSuite.scala b/core/src/test/scala/org/apache/spark/storage/TopologyMapperSuite.scala
new file mode 100644
index 0000000000..bbd252d7be
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/storage/TopologyMapperSuite.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.io.{File, FileOutputStream}
+
+import org.scalatest.{BeforeAndAfter, Matchers}
+
+import org.apache.spark._
+import org.apache.spark.util.Utils
+
+class TopologyMapperSuite extends SparkFunSuite
+ with Matchers
+ with BeforeAndAfter
+ with LocalSparkContext {
+
+ test("File based Topology Mapper") {
+ val numHosts = 100
+ val numRacks = 4
+ val props = (1 to numHosts).map { i => s"host-$i" -> s"rack-${i % numRacks}" }.toMap
+ val propsFile = createPropertiesFile(props)
+
+ val sparkConf = (new SparkConf(false))
+ sparkConf.set("spark.storage.replication.topologyFile", propsFile.getAbsolutePath)
+ val topologyMapper = new FileBasedTopologyMapper(sparkConf)
+
+ props.foreach { case (host, topology) =>
+ val obtainedTopology = topologyMapper.getTopologyForHost(host)
+ assert(obtainedTopology.isDefined)
+ assert(obtainedTopology.get === topology)
+ }
+
+ // we get None for hosts not in the file
+ assert(topologyMapper.getTopologyForHost("host").isEmpty)
+
+ cleanup(propsFile)
+ }
+
+ def createPropertiesFile(props: Map[String, String]): File = {
+ val testFile = new File(Utils.createTempDir(), "TopologyMapperSuite-test").getAbsoluteFile
+ val fileOS = new FileOutputStream(testFile)
+ props.foreach { case (k, v) => fileOS.write(s"$k=$v\n".getBytes) }
+ fileOS.close()
+ testFile
+ }
+
+ def cleanup(testFile: File): Unit = {
+ testFile.getParentFile.listFiles.filter { file =>
+ file.getName.startsWith(testFile.getName)
+ }.foreach { _.delete() }
+ }
+
+}