| author | Mridul Muralidharan <mridul@gmail.com> | 2013-05-01 20:56:05 +0530 |
|---|---|---|
| committer | Mridul Muralidharan <mridul@gmail.com> | 2013-05-01 20:56:05 +0530 |
| commit | 27764a00f40391b94fa05abb11484c442607f6f7 (patch) | |
| tree | 47ab45d3a1666f3f79d40f1083dc73f45a48eb6e | |
| parent | d960e7e0f83385d8f43129d53c189b3036936daf (diff) | |
Fix some npe introduced accidentally
| -rw-r--r-- | core/src/main/scala/spark/scheduler/DAGScheduler.scala | 2 |
| -rw-r--r-- | core/src/main/scala/spark/storage/BlockManager.scala | 30 |
2 files changed, 23 insertions, 9 deletions
diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
index 8072c60bb7..b18248d2b5 100644
--- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
@@ -117,7 +117,7 @@ class DAGScheduler(
   private def getCacheLocs(rdd: RDD[_]): Array[List[String]] = {
     if (!cacheLocs.contains(rdd.id)) {
       val blockIds = rdd.partitions.indices.map(index=> "rdd_%d_%d".format(rdd.id, index)).toArray
-      val locs = BlockManager.blockIdsToExecutorLocations(blockIds, env)
+      val locs = BlockManager.blockIdsToExecutorLocations(blockIds, env, blockManagerMaster)
       cacheLocs(rdd.id) = blockIds.map(locs.getOrElse(_, Nil))
     }
     cacheLocs(rdd.id)
diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala
index 7a0d6ced3e..040082e600 100644
--- a/core/src/main/scala/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/spark/storage/BlockManager.scala
@@ -937,10 +937,16 @@ object BlockManager extends Logging {
     }
   }
 
-  def blockIdsToExecutorLocations(blockIds: Array[String], env: SparkEnv): HashMap[String, List[String]] = {
-    val blockManager = env.blockManager
-    /*val locations = blockIds.map(id => blockManager.getLocations(id))*/
-    val locationBlockIds = blockManager.getLocationBlockIds(blockIds)
+  def blockIdsToExecutorLocations(blockIds: Array[String], env: SparkEnv, blockManagerMaster: BlockManagerMaster = null): HashMap[String, List[String]] = {
+    // env == null and blockManagerMaster != null is used in tests
+    assert (env != null || blockManagerMaster != null)
+    val locationBlockIds: Seq[Seq[BlockManagerId]] =
+      if (env != null) {
+        val blockManager = env.blockManager
+        blockManager.getLocationBlockIds(blockIds)
+      } else {
+        blockManagerMaster.getLocations(blockIds)
+      }
 
     // Convert from block master locations to executor locations (we need that for task scheduling)
     val executorLocations = new HashMap[String, List[String]]()
@@ -950,10 +956,18 @@ object BlockManager extends Logging {
 
       val executors = new HashSet[String]()
 
-      for (bkLocation <- blockLocations) {
-        val executorHostPort = env.resolveExecutorIdToHostPort(bkLocation.executorId, bkLocation.host)
-        executors += executorHostPort
-        // logInfo("bkLocation = " + bkLocation + ", executorHostPort = " + executorHostPort)
+      if (env != null) {
+        for (bkLocation <- blockLocations) {
+          val executorHostPort = env.resolveExecutorIdToHostPort(bkLocation.executorId, bkLocation.host)
+          executors += executorHostPort
+          // logInfo("bkLocation = " + bkLocation + ", executorHostPort = " + executorHostPort)
+        }
+      } else {
+        // Typically while testing, etc - revert to simply using host.
+        for (bkLocation <- blockLocations) {
+          executors += bkLocation.host
+          // logInfo("bkLocation = " + bkLocation + ", executorHostPort = " + executorHostPort)
+        }
       }
 
       executorLocations.put(blockId, executors.toSeq.toList)
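For context, below is a minimal self-contained sketch of the pattern this patch introduces: an optional master-style fallback parameter defaulting to `null`, with an assertion that fails fast instead of dereferencing a null `env`. The `Location`, `FakeEnv`, `FakeMaster`, and `NpeFixSketch` names are invented stand-ins for illustration only, not the real Spark classes.

```scala
import scala.collection.mutable.{HashMap, HashSet}

// Hypothetical stand-ins for the Spark types involved; simplified for this sketch.
final case class Location(executorId: String, host: String)

final class FakeEnv(data: Map[String, Seq[Location]]) {
  def getLocationBlockIds(blockIds: Array[String]): Seq[Seq[Location]] =
    blockIds.toSeq.map(id => data.getOrElse(id, Nil))
  def resolveExecutorIdToHostPort(executorId: String, host: String): String =
    executorId + "@" + host
}

final class FakeMaster(data: Map[String, Seq[Location]]) {
  def getLocations(blockIds: Array[String]): Seq[Seq[Location]] =
    blockIds.toSeq.map(id => data.getOrElse(id, Nil))
}

object NpeFixSketch {
  // Same shape as the patched method: the master is an optional second source,
  // and the assert replaces an implicit NPE when env is null.
  def blockIdsToExecutorLocations(
      blockIds: Array[String],
      env: FakeEnv,
      master: FakeMaster = null): HashMap[String, List[String]] = {
    assert(env != null || master != null)
    val located: Seq[Seq[Location]] =
      if (env != null) env.getLocationBlockIds(blockIds) else master.getLocations(blockIds)

    val executorLocations = new HashMap[String, List[String]]()
    for ((blockId, locs) <- blockIds.toSeq.zip(located)) {
      val executors = new HashSet[String]()
      for (loc <- locs) {
        // With a live env we resolve executorId -> host:port; in the env-less
        // (test-style) path we fall back to the bare host, as the patch does.
        executors += (
          if (env != null) env.resolveExecutorIdToHostPort(loc.executorId, loc.host)
          else loc.host)
      }
      executorLocations.put(blockId, executors.toList)
    }
    executorLocations
  }

  def main(args: Array[String]): Unit = {
    val data = Map("rdd_0_0" -> Seq(Location("exec-1", "host-a")))
    // Without the fallback, a null env would have been dereferenced; now the
    // master path is taken instead.
    println(blockIdsToExecutorLocations(Array("rdd_0_0"), null, new FakeMaster(data)))
    println(blockIdsToExecutorLocations(Array("rdd_0_0"), new FakeEnv(data)))
  }
}
```

Making the new parameter default to `null` keeps the existing driver-side call site source-compatible while letting env-less callers (such as tests) supply only the master, which is what the added `assert` and the `if (env != null)` branches guard.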