author    Mridul Muralidharan <mridul@gmail.com>  2013-05-01 20:56:05 +0530
committer Mridul Muralidharan <mridul@gmail.com>  2013-05-01 20:56:05 +0530
commit    27764a00f40391b94fa05abb11484c442607f6f7 (patch)
tree      47ab45d3a1666f3f79d40f1083dc73f45a48eb6e
parent    d960e7e0f83385d8f43129d53c189b3036936daf (diff)
Fix some NPEs introduced accidentally
 core/src/main/scala/spark/scheduler/DAGScheduler.scala |  2 +-
 core/src/main/scala/spark/storage/BlockManager.scala   | 30 +++++++++++++-------
 2 files changed, 23 insertions(+), 9 deletions(-)
diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
index 8072c60bb7..b18248d2b5 100644
--- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
@@ -117,7 +117,7 @@ class DAGScheduler(
   private def getCacheLocs(rdd: RDD[_]): Array[List[String]] = {
     if (!cacheLocs.contains(rdd.id)) {
       val blockIds = rdd.partitions.indices.map(index=> "rdd_%d_%d".format(rdd.id, index)).toArray
-      val locs = BlockManager.blockIdsToExecutorLocations(blockIds, env)
+      val locs = BlockManager.blockIdsToExecutorLocations(blockIds, env, blockManagerMaster)
       cacheLocs(rdd.id) = blockIds.map(locs.getOrElse(_, Nil))
     }
     cacheLocs(rdd.id)
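
For context on the NPE this commit fixes: the previous blockIdsToExecutorLocations(blockIds, env) dereferenced env.blockManager unconditionally, so callers that supply a null SparkEnv (the test-only path named in the new comment in the BlockManager diff below) failed immediately. A minimal sketch of the failing pattern, with an illustrative call site:

    // Sketch of the pre-patch failure mode (illustrative, not from the repo):
    // the method body began with `val blockManager = env.blockManager`,
    // so a null env threw a NullPointerException before any lookup ran.
    val blockIds = Array("rdd_0_0", "rdd_0_1")
    BlockManager.blockIdsToExecutorLocations(blockIds, null)  // NPE before this patch
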
diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala
index 7a0d6ced3e..040082e600 100644
--- a/core/src/main/scala/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/spark/storage/BlockManager.scala
@@ -937,10 +937,16 @@ object BlockManager extends Logging {
     }
   }
 
-  def blockIdsToExecutorLocations(blockIds: Array[String], env: SparkEnv): HashMap[String, List[String]] = {
-    val blockManager = env.blockManager
-    /*val locations = blockIds.map(id => blockManager.getLocations(id))*/
-    val locationBlockIds = blockManager.getLocationBlockIds(blockIds)
+  def blockIdsToExecutorLocations(blockIds: Array[String], env: SparkEnv, blockManagerMaster: BlockManagerMaster = null): HashMap[String, List[String]] = {
+    // env == null and blockManagerMaster != null is used in tests
+    assert (env != null || blockManagerMaster != null)
+    val locationBlockIds: Seq[Seq[BlockManagerId]] =
+      if (env != null) {
+        val blockManager = env.blockManager
+        blockManager.getLocationBlockIds(blockIds)
+      } else {
+        blockManagerMaster.getLocations(blockIds)
+      }
 
     // Convert from block master locations to executor locations (we need that for task scheduling)
     val executorLocations = new HashMap[String, List[String]]()
@@ -950,10 +956,18 @@ object BlockManager extends Logging {
       val executors = new HashSet[String]()
 
-      for (bkLocation <- blockLocations) {
-        val executorHostPort = env.resolveExecutorIdToHostPort(bkLocation.executorId, bkLocation.host)
-        executors += executorHostPort
-        // logInfo("bkLocation = " + bkLocation + ", executorHostPort = " + executorHostPort)
+      if (env != null) {
+        for (bkLocation <- blockLocations) {
+          val executorHostPort = env.resolveExecutorIdToHostPort(bkLocation.executorId, bkLocation.host)
+          executors += executorHostPort
+          // logInfo("bkLocation = " + bkLocation + ", executorHostPort = " + executorHostPort)
+        }
+      } else {
+        // Typically while testing, etc - revert to simply using host.
+        for (bkLocation <- blockLocations) {
+          executors += bkLocation.host
+          // logInfo("bkLocation = " + bkLocation + ", executorHostPort = " + executorHostPort)
+        }
       }
 
       executorLocations.put(blockId, executors.toSeq.toList)
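
With the patch applied, the method can be exercised without a full SparkEnv: pass null for env and a BlockManagerMaster instead, and the executor strings fall back to plain hosts. A minimal test-style sketch, assuming a BlockManagerMaster instance named master supplied by a test harness:

    // Hypothetical usage of the patched signature: with env = null, locations
    // come from blockManagerMaster.getLocations(blockIds) and each executor
    // entry is just bkLocation.host rather than a resolved host:port.
    val blockIds = Array("rdd_0_0", "rdd_0_1")
    val locs = BlockManager.blockIdsToExecutorLocations(blockIds, null, master)
    // locs maps each block id to the list of hosts currently storing it.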