From 010b03ed52f35fd4d426d522f8a9927ddc579209 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 18 Aug 2015 22:30:13 -0700 Subject: [SPARK-9952] Fix N^2 loop when DAGScheduler.getPreferredLocsInternal accesses cacheLocs In Scala, `Seq.fill` always seems to return a List. Accessing a list by index is an O(N) operation. Thus, the following code will be really slow (~10 seconds on my machine): ```scala val numItems = 100000 val s = Seq.fill(numItems)(1) for (i <- 0 until numItems) s(i) ``` It turns out that we had a loop like this in DAGScheduler code, although it's a little tricky to spot. In `getPreferredLocsInternal`, there's a call to `getCacheLocs(rdd)(partition)`. The `getCacheLocs` call returns a Seq. If this Seq is a List and the RDD contains many partitions, then indexing into this list will cost O(partitions). Thus, when we loop over our tasks to compute their individual preferred locations we implicitly perform an N^2 loop, reducing scheduling throughput. This patch fixes this by replacing `Seq` with `Array`. Author: Josh Rosen Closes #8178 from JoshRosen/dagscheduler-perf. --- .../org/apache/spark/scheduler/DAGScheduler.scala | 22 +++++++++++----------- .../apache/spark/storage/BlockManagerMaster.scala | 5 +++-- .../spark/storage/BlockManagerMasterEndpoint.scala | 3 ++- .../apache/spark/scheduler/DAGSchedulerSuite.scala | 4 ++-- 4 files changed, 18 insertions(+), 16 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index dadf83a382..684db66467 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -111,7 +111,7 @@ class DAGScheduler( * * All accesses to this map should be guarded by synchronizing on it (see SPARK-4454). */ - private val cacheLocs = new HashMap[Int, Seq[Seq[TaskLocation]]] + private val cacheLocs = new HashMap[Int, IndexedSeq[Seq[TaskLocation]]] // For tracking failed nodes, we use the MapOutputTracker's epoch number, which is sent with // every task. When we detect a node failing, we note the current epoch number and failed @@ -205,12 +205,12 @@ class DAGScheduler( } private[scheduler] - def getCacheLocs(rdd: RDD[_]): Seq[Seq[TaskLocation]] = cacheLocs.synchronized { + def getCacheLocs(rdd: RDD[_]): IndexedSeq[Seq[TaskLocation]] = cacheLocs.synchronized { // Note: this doesn't use `getOrElse()` because this method is called O(num tasks) times if (!cacheLocs.contains(rdd.id)) { // Note: if the storage level is NONE, we don't need to get locations from block manager. - val locs: Seq[Seq[TaskLocation]] = if (rdd.getStorageLevel == StorageLevel.NONE) { - Seq.fill(rdd.partitions.size)(Nil) + val locs: IndexedSeq[Seq[TaskLocation]] = if (rdd.getStorageLevel == StorageLevel.NONE) { + IndexedSeq.fill(rdd.partitions.length)(Nil) } else { val blockIds = rdd.partitions.indices.map(index => RDDBlockId(rdd.id, index)).toArray[BlockId] @@ -302,12 +302,12 @@ class DAGScheduler( shuffleDep: ShuffleDependency[_, _, _], firstJobId: Int): ShuffleMapStage = { val rdd = shuffleDep.rdd - val numTasks = rdd.partitions.size + val numTasks = rdd.partitions.length val stage = newShuffleMapStage(rdd, numTasks, shuffleDep, firstJobId, rdd.creationSite) if (mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) { val serLocs = mapOutputTracker.getSerializedMapOutputStatuses(shuffleDep.shuffleId) val locs = MapOutputTracker.deserializeMapStatuses(serLocs) - for (i <- 0 until locs.size) { + for (i <- 0 until locs.length) { stage.outputLocs(i) = Option(locs(i)).toList // locs(i) will be null if missing } stage.numAvailableOutputs = locs.count(_ != null) @@ -315,7 +315,7 @@ class DAGScheduler( // Kind of ugly: need to register RDDs with the cache and map output tracker here // since we can't do it in the RDD constructor because # of partitions is unknown logInfo("Registering RDD " + rdd.id + " (" + rdd.getCreationSite + ")") - mapOutputTracker.registerShuffle(shuffleDep.shuffleId, rdd.partitions.size) + mapOutputTracker.registerShuffle(shuffleDep.shuffleId, rdd.partitions.length) } stage } @@ -566,7 +566,7 @@ class DAGScheduler( properties: Properties): PartialResult[R] = { val listener = new ApproximateActionListener(rdd, func, evaluator, timeout) val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] - val partitions = (0 until rdd.partitions.size).toArray + val partitions = (0 until rdd.partitions.length).toArray val jobId = nextJobId.getAndIncrement() eventProcessLoop.post(JobSubmitted( jobId, rdd, func2, partitions, callSite, listener, SerializationUtils.clone(properties))) @@ -718,7 +718,7 @@ class DAGScheduler( try { // New stage creation may throw an exception if, for example, jobs are run on a // HadoopRDD whose underlying HDFS files have been deleted. - finalStage = newResultStage(finalRDD, partitions.size, jobId, callSite) + finalStage = newResultStage(finalRDD, partitions.length, jobId, callSite) } catch { case e: Exception => logWarning("Creating new stage failed due to exception - job: " + jobId, e) @@ -1039,7 +1039,7 @@ class DAGScheduler( // we registered these map outputs. mapOutputTracker.registerMapOutputs( shuffleStage.shuffleDep.shuffleId, - shuffleStage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray, + shuffleStage.outputLocs.map(list => if (list.isEmpty) null else list.head), changeEpoch = true) clearCacheLocs() @@ -1169,7 +1169,7 @@ class DAGScheduler( // TODO: This will be really slow if we keep accumulating shuffle map stages for ((shuffleId, stage) <- shuffleToMapStage) { stage.removeOutputsOnExecutor(execId) - val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray + val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head) mapOutputTracker.registerMapOutputs(shuffleId, locs, changeEpoch = true) } if (shuffleToMapStage.isEmpty) { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index 2a11f371b9..f45bff34d4 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -69,8 +69,9 @@ class BlockManagerMaster( } /** Get locations of multiple blockIds from the driver */ - def getLocations(blockIds: Array[BlockId]): Seq[Seq[BlockManagerId]] = { - driverEndpoint.askWithRetry[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds)) + def getLocations(blockIds: Array[BlockId]): IndexedSeq[Seq[BlockManagerId]] = { + driverEndpoint.askWithRetry[IndexedSeq[Seq[BlockManagerId]]]( + GetLocationsMultipleBlockIds(blockIds)) } /** diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 5dc0c537cb..6fec524070 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -372,7 +372,8 @@ class BlockManagerMasterEndpoint( if (blockLocations.containsKey(blockId)) blockLocations.get(blockId).toSeq else Seq.empty } - private def getLocationsMultipleBlockIds(blockIds: Array[BlockId]): Seq[Seq[BlockManagerId]] = { + private def getLocationsMultipleBlockIds( + blockIds: Array[BlockId]): IndexedSeq[Seq[BlockManagerId]] = { blockIds.map(blockId => getLocations(blockId)) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index a063596d3e..2e8688cf41 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -133,11 +133,11 @@ class DAGSchedulerSuite val cacheLocations = new HashMap[(Int, Int), Seq[BlockManagerId]] // stub out BlockManagerMaster.getLocations to use our cacheLocations val blockManagerMaster = new BlockManagerMaster(null, conf, true) { - override def getLocations(blockIds: Array[BlockId]): Seq[Seq[BlockManagerId]] = { + override def getLocations(blockIds: Array[BlockId]): IndexedSeq[Seq[BlockManagerId]] = { blockIds.map { _.asRDDId.map(id => (id.rddId -> id.splitIndex)).flatMap(key => cacheLocations.get(key)). getOrElse(Seq()) - }.toSeq + }.toIndexedSeq } override def removeExecutor(execId: String) { // don't need to propagate to the driver, which we don't have -- cgit v1.2.3