From 27764a00f40391b94fa05abb11484c442607f6f7 Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Wed, 1 May 2013 20:56:05 +0530 Subject: Fix some npe introduced accidentally --- .../main/scala/spark/scheduler/DAGScheduler.scala | 2 +- .../main/scala/spark/storage/BlockManager.scala | 30 ++++++++++++++++------ 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 8072c60bb7..b18248d2b5 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -117,7 +117,7 @@ class DAGScheduler( private def getCacheLocs(rdd: RDD[_]): Array[List[String]] = { if (!cacheLocs.contains(rdd.id)) { val blockIds = rdd.partitions.indices.map(index=> "rdd_%d_%d".format(rdd.id, index)).toArray - val locs = BlockManager.blockIdsToExecutorLocations(blockIds, env) + val locs = BlockManager.blockIdsToExecutorLocations(blockIds, env, blockManagerMaster) cacheLocs(rdd.id) = blockIds.map(locs.getOrElse(_, Nil)) } cacheLocs(rdd.id) diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index 7a0d6ced3e..040082e600 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -937,10 +937,16 @@ object BlockManager extends Logging { } } - def blockIdsToExecutorLocations(blockIds: Array[String], env: SparkEnv): HashMap[String, List[String]] = { - val blockManager = env.blockManager - /*val locations = blockIds.map(id => blockManager.getLocations(id))*/ - val locationBlockIds = blockManager.getLocationBlockIds(blockIds) + def blockIdsToExecutorLocations(blockIds: Array[String], env: SparkEnv, blockManagerMaster: BlockManagerMaster = null): HashMap[String, List[String]] = { + // env == null and blockManagerMaster != null is used in tests + assert (env != null || blockManagerMaster != null) + val locationBlockIds: Seq[Seq[BlockManagerId]] = + if (env != null) { + val blockManager = env.blockManager + blockManager.getLocationBlockIds(blockIds) + } else { + blockManagerMaster.getLocations(blockIds) + } // Convert from block master locations to executor locations (we need that for task scheduling) val executorLocations = new HashMap[String, List[String]]() @@ -950,10 +956,18 @@ object BlockManager extends Logging { val executors = new HashSet[String]() - for (bkLocation <- blockLocations) { - val executorHostPort = env.resolveExecutorIdToHostPort(bkLocation.executorId, bkLocation.host) - executors += executorHostPort - // logInfo("bkLocation = " + bkLocation + ", executorHostPort = " + executorHostPort) + if (env != null) { + for (bkLocation <- blockLocations) { + val executorHostPort = env.resolveExecutorIdToHostPort(bkLocation.executorId, bkLocation.host) + executors += executorHostPort + // logInfo("bkLocation = " + bkLocation + ", executorHostPort = " + executorHostPort) + } + } else { + // Typically while testing, etc - revert to simply using host. + for (bkLocation <- blockLocations) { + executors += bkLocation.host + // logInfo("bkLocation = " + bkLocation + ", executorHostPort = " + executorHostPort) + } } executorLocations.put(blockId, executors.toSeq.toList) -- cgit v1.2.3