diff options
Diffstat (limited to 'core/src/main/scala/org/apache')
11 files changed, 194 insertions, 354 deletions
diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala deleted file mode 100644 index 2b456facd9..0000000000 --- a/core/src/main/scala/org/apache/spark/CacheManager.scala +++ /dev/null @@ -1,179 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark - -import scala.collection.mutable - -import org.apache.spark.rdd.RDD -import org.apache.spark.storage._ -import org.apache.spark.util.CompletionIterator - -/** - * Spark class responsible for passing RDDs partition contents to the BlockManager and making - * sure a node doesn't load two copies of an RDD at once. - */ -private[spark] class CacheManager(blockManager: BlockManager) extends Logging { - - /** Keys of RDD partitions that are being computed/loaded. */ - private val loading = new mutable.HashSet[RDDBlockId] - - /** Gets or computes an RDD partition. Used by RDD.iterator() when an RDD is cached. */ - def getOrCompute[T]( - rdd: RDD[T], - partition: Partition, - context: TaskContext, - storageLevel: StorageLevel): Iterator[T] = { - - val key = RDDBlockId(rdd.id, partition.index) - logDebug(s"Looking for partition $key") - blockManager.get(key) match { - case Some(blockResult) => - // Partition is already materialized, so just return its values - val existingMetrics = context.taskMetrics().registerInputMetrics(blockResult.readMethod) - existingMetrics.incBytesReadInternal(blockResult.bytes) - - val iter = blockResult.data.asInstanceOf[Iterator[T]] - - new InterruptibleIterator[T](context, iter) { - override def next(): T = { - existingMetrics.incRecordsReadInternal(1) - delegate.next() - } - } - case None => - // Acquire a lock for loading this partition - // If another thread already holds the lock, wait for it to finish return its results - val storedValues = acquireLockForPartition[T](key) - if (storedValues.isDefined) { - return new InterruptibleIterator[T](context, storedValues.get) - } - - // Otherwise, we have to load the partition ourselves - try { - logInfo(s"Partition $key not found, computing it") - val computedValues = rdd.computeOrReadCheckpoint(partition, context) - val cachedValues = putInBlockManager(key, computedValues, storageLevel) - new InterruptibleIterator(context, cachedValues) - } finally { - loading.synchronized { - loading.remove(key) - loading.notifyAll() - } - } - } - } - - /** - * Acquire a loading lock for the partition identified by the given block ID. - * - * If the lock is free, just acquire it and return None. Otherwise, another thread is already - * loading the partition, so we wait for it to finish and return the values loaded by the thread. - */ - private def acquireLockForPartition[T](id: RDDBlockId): Option[Iterator[T]] = { - loading.synchronized { - if (!loading.contains(id)) { - // If the partition is free, acquire its lock to compute its value - loading.add(id) - None - } else { - // Otherwise, wait for another thread to finish and return its result - logInfo(s"Another thread is loading $id, waiting for it to finish...") - while (loading.contains(id)) { - try { - loading.wait() - } catch { - case e: Exception => - logWarning(s"Exception while waiting for another thread to load $id", e) - } - } - logInfo(s"Finished waiting for $id") - val values = blockManager.get(id) - if (!values.isDefined) { - /* The block is not guaranteed to exist even after the other thread has finished. - * For instance, the block could be evicted after it was put, but before our get. - * In this case, we still need to load the partition ourselves. */ - logInfo(s"Whoever was loading $id failed; we'll try it ourselves") - loading.add(id) - } - values.map(_.data.asInstanceOf[Iterator[T]]) - } - } - } - - /** - * Cache the values of a partition, keeping track of any updates in the storage statuses of - * other blocks along the way. - * - * The effective storage level refers to the level that actually specifies BlockManager put - * behavior, not the level originally specified by the user. This is mainly for forcing a - * MEMORY_AND_DISK partition to disk if there is not enough room to unroll the partition, - * while preserving the original semantics of the RDD as specified by the application. - */ - private def putInBlockManager[T]( - key: BlockId, - values: Iterator[T], - level: StorageLevel, - effectiveStorageLevel: Option[StorageLevel] = None): Iterator[T] = { - - val putLevel = effectiveStorageLevel.getOrElse(level) - if (!putLevel.useMemory) { - /* - * This RDD is not to be cached in memory, so we can just pass the computed values as an - * iterator directly to the BlockManager rather than first fully unrolling it in memory. - */ - blockManager.putIterator(key, values, level, tellMaster = true, effectiveStorageLevel) - blockManager.get(key) match { - case Some(v) => v.data.asInstanceOf[Iterator[T]] - case None => - logInfo(s"Failure to store $key") - throw new BlockException(key, s"Block manager failed to return cached value for $key!") - } - } else { - /* - * This RDD is to be cached in memory. In this case we cannot pass the computed values - * to the BlockManager as an iterator and expect to read it back later. This is because - * we may end up dropping a partition from memory store before getting it back. - * - * In addition, we must be careful to not unroll the entire partition in memory at once. - * Otherwise, we may cause an OOM exception if the JVM does not have enough space for this - * single partition. Instead, we unroll the values cautiously, potentially aborting and - * dropping the partition to disk if applicable. - */ - blockManager.memoryStore.unrollSafely(key, values) match { - case Left(arr) => - // We have successfully unrolled the entire partition, so cache it in memory - blockManager.putArray(key, arr, level, tellMaster = true, effectiveStorageLevel) - CompletionIterator[T, Iterator[T]]( - arr.iterator.asInstanceOf[Iterator[T]], - blockManager.releaseLock(key)) - case Right(it) => - // There is not enough space to cache this partition in memory - val returnValues = it.asInstanceOf[Iterator[T]] - if (putLevel.useDisk) { - logWarning(s"Persisting partition $key to disk instead.") - val diskOnlyLevel = StorageLevel(useDisk = true, useMemory = false, - useOffHeap = false, deserialized = false, putLevel.replication) - putInBlockManager[T](key, returnValues, level, Some(diskOnlyLevel)) - } else { - returnValues - } - } - } - } - -} diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 204f7356f7..b3b3729625 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -56,7 +56,6 @@ class SparkEnv ( private[spark] val rpcEnv: RpcEnv, val serializer: Serializer, val closureSerializer: Serializer, - val cacheManager: CacheManager, val mapOutputTracker: MapOutputTracker, val shuffleManager: ShuffleManager, val broadcastManager: BroadcastManager, @@ -333,8 +332,6 @@ object SparkEnv extends Logging { val broadcastManager = new BroadcastManager(isDriver, conf, securityManager) - val cacheManager = new CacheManager(blockManager) - val metricsSystem = if (isDriver) { // Don't start metrics system right now for Driver. // We need to wait for the task scheduler to give us an app ID. @@ -371,7 +368,6 @@ object SparkEnv extends Logging { rpcEnv, serializer, closureSerializer, - cacheManager, mapOutputTracker, shuffleManager, broadcastManager, diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index c08f87a8b4..dabc810018 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -99,18 +99,14 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) // Store a copy of the broadcast variable in the driver so that tasks run on the driver // do not create a duplicate copy of the broadcast variable's value. val blockManager = SparkEnv.get.blockManager - if (blockManager.putSingle(broadcastId, value, MEMORY_AND_DISK, tellMaster = false)) { - blockManager.releaseLock(broadcastId) - } else { + if (!blockManager.putSingle(broadcastId, value, MEMORY_AND_DISK, tellMaster = false)) { throw new SparkException(s"Failed to store $broadcastId in BlockManager") } val blocks = TorrentBroadcast.blockifyObject(value, blockSize, SparkEnv.get.serializer, compressionCodec) blocks.zipWithIndex.foreach { case (block, i) => val pieceId = BroadcastBlockId(id, "piece" + i) - if (blockManager.putBytes(pieceId, block, MEMORY_AND_DISK_SER, tellMaster = true)) { - blockManager.releaseLock(pieceId) - } else { + if (!blockManager.putBytes(pieceId, block, MEMORY_AND_DISK_SER, tellMaster = true)) { throw new SparkException(s"Failed to store $pieceId of $broadcastId in local BlockManager") } } @@ -130,22 +126,24 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) // First try getLocalBytes because there is a chance that previous attempts to fetch the // broadcast blocks have already fetched some of the blocks. In that case, some blocks // would be available locally (on this executor). - def getLocal: Option[ByteBuffer] = bm.getLocalBytes(pieceId) - def getRemote: Option[ByteBuffer] = bm.getRemoteBytes(pieceId).map { block => - // If we found the block from remote executors/driver's BlockManager, put the block - // in this executor's BlockManager. - if (!bm.putBytes(pieceId, block, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true)) { - throw new SparkException( - s"Failed to store $pieceId of $broadcastId in local BlockManager") - } - block + bm.getLocalBytes(pieceId) match { + case Some(block) => + blocks(pid) = block + releaseLock(pieceId) + case None => + bm.getRemoteBytes(pieceId) match { + case Some(b) => + // We found the block from remote executors/driver's BlockManager, so put the block + // in this executor's BlockManager. + if (!bm.putBytes(pieceId, b, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true)) { + throw new SparkException( + s"Failed to store $pieceId of $broadcastId in local BlockManager") + } + blocks(pid) = b + case None => + throw new SparkException(s"Failed to get $pieceId of $broadcastId") + } } - val block: ByteBuffer = getLocal.orElse(getRemote).getOrElse( - throw new SparkException(s"Failed to get $pieceId of $broadcastId")) - // At this point we are guaranteed to hold a read lock, since we either got the block locally - // or stored the remotely-fetched block and automatically downgraded the write lock. - blocks(pid) = block - releaseLock(pieceId) } blocks } @@ -191,9 +189,7 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) // Store the merged copy in BlockManager so other tasks on this executor don't // need to re-fetch it. val storageLevel = StorageLevel.MEMORY_AND_DISK - if (blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) { - releaseLock(broadcastId) - } else { + if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) { throw new SparkException(s"Failed to store $broadcastId in BlockManager") } obj diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index a959f200d4..e88d6cd089 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -292,11 +292,8 @@ private[spark] class Executor( ser.serialize(new IndirectTaskResult[Any](TaskResultBlockId(taskId), resultSize)) } else if (resultSize >= maxRpcMessageSize) { val blockId = TaskResultBlockId(taskId) - val putSucceeded = env.blockManager.putBytes( + env.blockManager.putBytes( blockId, serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER) - if (putSucceeded) { - env.blockManager.releaseLock(blockId) - } logInfo( s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)") ser.serialize(new IndirectTaskResult[Any](blockId, resultSize)) diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala index e4246df83a..e86933b948 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala @@ -66,10 +66,7 @@ class NettyBlockRpcServer( serializer.newInstance().deserialize(ByteBuffer.wrap(uploadBlock.metadata)) val data = new NioManagedBuffer(ByteBuffer.wrap(uploadBlock.blockData)) val blockId = BlockId(uploadBlock.blockId) - val putSucceeded = blockManager.putBlockData(blockId, data, level) - if (putSucceeded) { - blockManager.releaseLock(blockId) - } + blockManager.putBlockData(blockId, data, level) responseContext.onSuccess(ByteBuffer.allocate(0)) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 6a6ad2d75a..e5fdebc65d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -37,7 +37,7 @@ import org.apache.spark.partial.BoundedDouble import org.apache.spark.partial.CountEvaluator import org.apache.spark.partial.GroupedCountEvaluator import org.apache.spark.partial.PartialResult -import org.apache.spark.storage.StorageLevel +import org.apache.spark.storage.{RDDBlockId, StorageLevel} import org.apache.spark.util.{BoundedPriorityQueue, Utils} import org.apache.spark.util.collection.OpenHashMap import org.apache.spark.util.random.{BernoulliCellSampler, BernoulliSampler, PoissonSampler, @@ -272,7 +272,7 @@ abstract class RDD[T: ClassTag]( */ final def iterator(split: Partition, context: TaskContext): Iterator[T] = { if (storageLevel != StorageLevel.NONE) { - SparkEnv.get.cacheManager.getOrCompute(this, split, context, storageLevel) + getOrCompute(split, context) } else { computeOrReadCheckpoint(split, context) } @@ -315,6 +315,35 @@ abstract class RDD[T: ClassTag]( } /** + * Gets or computes an RDD partition. Used by RDD.iterator() when an RDD is cached. + */ + private[spark] def getOrCompute(partition: Partition, context: TaskContext): Iterator[T] = { + val blockId = RDDBlockId(id, partition.index) + var readCachedBlock = true + // This method is called on executors, so we need call SparkEnv.get instead of sc.env. + SparkEnv.get.blockManager.getOrElseUpdate(blockId, storageLevel, () => { + readCachedBlock = false + computeOrReadCheckpoint(partition, context) + }) match { + case Left(blockResult) => + if (readCachedBlock) { + val existingMetrics = context.taskMetrics().registerInputMetrics(blockResult.readMethod) + existingMetrics.incBytesReadInternal(blockResult.bytes) + new InterruptibleIterator[T](context, blockResult.data.asInstanceOf[Iterator[T]]) { + override def next(): T = { + existingMetrics.incRecordsReadInternal(1) + delegate.next() + } + } + } else { + new InterruptibleIterator(context, blockResult.data.asInstanceOf[Iterator[T]]) + } + case Right(iter) => + new InterruptibleIterator(context, iter.asInstanceOf[Iterator[T]]) + } + } + + /** * Execute a block of code in a scope such that all new RDDs created in this body will * be part of the same scope. For more detail, see {{org.apache.spark.rdd.RDDOperationScope}}. * diff --git a/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala index 0eda97e58d..b23244ad51 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala @@ -71,27 +71,13 @@ private[storage] class BlockInfo(val level: StorageLevel, val tellMaster: Boolea _writerTask = t checkInvariants() } - private[this] var _writerTask: Long = 0 - - /** - * True if this block has been removed from the BlockManager and false otherwise. - * This field is used to communicate block deletion to blocked readers / writers (see its usage - * in [[BlockInfoManager]]). - */ - def removed: Boolean = _removed - def removed_=(r: Boolean): Unit = { - _removed = r - checkInvariants() - } - private[this] var _removed: Boolean = false + private[this] var _writerTask: Long = BlockInfo.NO_WRITER private def checkInvariants(): Unit = { // A block's reader count must be non-negative: assert(_readerCount >= 0) // A block is either locked for reading or for writing, but not for both at the same time: assert(_readerCount == 0 || _writerTask == BlockInfo.NO_WRITER) - // If a block is removed then it is not locked: - assert(!_removed || (_readerCount == 0 && _writerTask == BlockInfo.NO_WRITER)) } checkInvariants() @@ -195,16 +181,22 @@ private[storage] class BlockInfoManager extends Logging { blockId: BlockId, blocking: Boolean = true): Option[BlockInfo] = synchronized { logTrace(s"Task $currentTaskAttemptId trying to acquire read lock for $blockId") - infos.get(blockId).map { info => - while (info.writerTask != BlockInfo.NO_WRITER) { - if (blocking) wait() else return None + do { + infos.get(blockId) match { + case None => return None + case Some(info) => + if (info.writerTask == BlockInfo.NO_WRITER) { + info.readerCount += 1 + readLocksByTask(currentTaskAttemptId).add(blockId) + logTrace(s"Task $currentTaskAttemptId acquired read lock for $blockId") + return Some(info) + } } - if (info.removed) return None - info.readerCount += 1 - readLocksByTask(currentTaskAttemptId).add(blockId) - logTrace(s"Task $currentTaskAttemptId acquired read lock for $blockId") - info - } + if (blocking) { + wait() + } + } while (blocking) + None } /** @@ -226,21 +218,25 @@ private[storage] class BlockInfoManager extends Logging { blockId: BlockId, blocking: Boolean = true): Option[BlockInfo] = synchronized { logTrace(s"Task $currentTaskAttemptId trying to acquire write lock for $blockId") - infos.get(blockId).map { info => - if (info.writerTask == currentTaskAttemptId) { - throw new IllegalStateException( - s"Task $currentTaskAttemptId has already locked $blockId for writing") - } else { - while (info.writerTask != BlockInfo.NO_WRITER || info.readerCount != 0) { - if (blocking) wait() else return None - } - if (info.removed) return None + do { + infos.get(blockId) match { + case None => return None + case Some(info) => + if (info.writerTask == currentTaskAttemptId) { + throw new IllegalStateException( + s"Task $currentTaskAttemptId has already locked $blockId for writing") + } else if (info.writerTask == BlockInfo.NO_WRITER && info.readerCount == 0) { + info.writerTask = currentTaskAttemptId + writeLocksByTask.addBinding(currentTaskAttemptId, blockId) + logTrace(s"Task $currentTaskAttemptId acquired write lock for $blockId") + return Some(info) + } } - info.writerTask = currentTaskAttemptId - writeLocksByTask.addBinding(currentTaskAttemptId, blockId) - logTrace(s"Task $currentTaskAttemptId acquired write lock for $blockId") - info - } + if (blocking) { + wait() + } + } while (blocking) + None } /** @@ -306,29 +302,30 @@ private[storage] class BlockInfoManager extends Logging { } /** - * Atomically create metadata for a block and acquire a write lock for it, if it doesn't already - * exist. + * Attempt to acquire the appropriate lock for writing a new block. + * + * This enforces the first-writer-wins semantics. If we are the first to write the block, + * then just go ahead and acquire the write lock. Otherwise, if another thread is already + * writing the block, then we wait for the write to finish before acquiring the read lock. * - * @param blockId the block id. - * @param newBlockInfo the block info for the new block. * @return true if the block did not already exist, false otherwise. If this returns false, then - * no new locks are acquired. If this returns true, a write lock on the new block will - * be held. + * a read lock on the existing block will be held. If this returns true, a write lock on + * the new block will be held. */ def lockNewBlockForWriting( blockId: BlockId, newBlockInfo: BlockInfo): Boolean = synchronized { logTrace(s"Task $currentTaskAttemptId trying to put $blockId") - if (!infos.contains(blockId)) { - infos(blockId) = newBlockInfo - newBlockInfo.writerTask = currentTaskAttemptId - writeLocksByTask.addBinding(currentTaskAttemptId, blockId) - logTrace(s"Task $currentTaskAttemptId successfully locked new block $blockId") - true - } else { - logTrace(s"Task $currentTaskAttemptId did not create and lock block $blockId " + - s"because that block already exists") - false + lockForReading(blockId) match { + case Some(info) => + // Block already exists. This could happen if another thread races with us to compute + // the same block. In this case, just keep the read lock and return. + false + case None => + // Block does not yet exist or is removed, so we are free to acquire the write lock + infos(blockId) = newBlockInfo + lockForWriting(blockId) + true } } @@ -418,7 +415,6 @@ private[storage] class BlockInfoManager extends Logging { infos.remove(blockId) blockInfo.readerCount = 0 blockInfo.writerTask = BlockInfo.NO_WRITER - blockInfo.removed = true } case None => throw new IllegalArgumentException( @@ -434,7 +430,6 @@ private[storage] class BlockInfoManager extends Logging { infos.valuesIterator.foreach { blockInfo => blockInfo.readerCount = 0 blockInfo.writerTask = BlockInfo.NO_WRITER - blockInfo.removed = true } infos.clear() readLocksByTask.clear() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 29124b368e..b59191b291 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -44,8 +44,7 @@ import org.apache.spark.util._ private[spark] sealed trait BlockValues private[spark] case class ByteBufferValues(buffer: ByteBuffer) extends BlockValues -private[spark] case class IteratorValues(iterator: Iterator[Any]) extends BlockValues -private[spark] case class ArrayValues(buffer: Array[Any]) extends BlockValues +private[spark] case class IteratorValues(iterator: () => Iterator[Any]) extends BlockValues /* Class for returning a fetched block and associated metrics. */ private[spark] class BlockResult( @@ -648,8 +647,38 @@ private[spark] class BlockManager( } /** - * @return true if the block was stored or false if the block was already stored or an - * error occurred. + * Retrieve the given block if it exists, otherwise call the provided `makeIterator` method + * to compute the block, persist it, and return its values. + * + * @return either a BlockResult if the block was successfully cached, or an iterator if the block + * could not be cached. + */ + def getOrElseUpdate( + blockId: BlockId, + level: StorageLevel, + makeIterator: () => Iterator[Any]): Either[BlockResult, Iterator[Any]] = { + // Initially we hold no locks on this block. + doPut(blockId, IteratorValues(makeIterator), level, keepReadLock = true) match { + case None => + // doPut() didn't hand work back to us, so the block already existed or was successfully + // stored. Therefore, we now hold a read lock on the block. + val blockResult = get(blockId).getOrElse { + // Since we held a read lock between the doPut() and get() calls, the block should not + // have been evicted, so get() not returning the block indicates some internal error. + releaseLock(blockId) + throw new SparkException(s"get() failed for block $blockId even though we held a lock") + } + Left(blockResult) + case Some(failedPutResult) => + // The put failed, likely because the data was too large to fit in memory and could not be + // dropped to disk. Therefore, we need to pass the input iterator back to the caller so + // that they can decide what to do with the values (e.g. process them without caching). + Right(failedPutResult.data.left.get) + } + } + + /** + * @return true if the block was stored or false if an error occurred. */ def putIterator( blockId: BlockId, @@ -658,7 +687,7 @@ private[spark] class BlockManager( tellMaster: Boolean = true, effectiveStorageLevel: Option[StorageLevel] = None): Boolean = { require(values != null, "Values is null") - doPut(blockId, IteratorValues(values), level, tellMaster, effectiveStorageLevel) + doPut(blockId, IteratorValues(() => values), level, tellMaster, effectiveStorageLevel).isEmpty } /** @@ -679,26 +708,9 @@ private[spark] class BlockManager( } /** - * Put a new block of values to the block manager. - * - * @return true if the block was stored or false if the block was already stored or an - * error occurred. - */ - def putArray( - blockId: BlockId, - values: Array[Any], - level: StorageLevel, - tellMaster: Boolean = true, - effectiveStorageLevel: Option[StorageLevel] = None): Boolean = { - require(values != null, "Values is null") - doPut(blockId, ArrayValues(values), level, tellMaster, effectiveStorageLevel) - } - - /** * Put a new block of serialized bytes to the block manager. * - * @return true if the block was stored or false if the block was already stored or an - * error occurred. + * @return true if the block was stored or false if an error occurred. */ def putBytes( blockId: BlockId, @@ -707,26 +719,32 @@ private[spark] class BlockManager( tellMaster: Boolean = true, effectiveStorageLevel: Option[StorageLevel] = None): Boolean = { require(bytes != null, "Bytes is null") - doPut(blockId, ByteBufferValues(bytes), level, tellMaster, effectiveStorageLevel) + doPut(blockId, ByteBufferValues(bytes), level, tellMaster, effectiveStorageLevel).isEmpty } /** * Put the given block according to the given level in one of the block stores, replicating * the values if necessary. * - * The effective storage level refers to the level according to which the block will actually be - * handled. This allows the caller to specify an alternate behavior of doPut while preserving - * the original level specified by the user. + * If the block already exists, this method will not overwrite it. * - * @return true if the block was stored or false if the block was already stored or an - * error occurred. + * @param effectiveStorageLevel the level according to which the block will actually be handled. + * This allows the caller to specify an alternate behavior of doPut + * while preserving the original level specified by the user. + * @param keepReadLock if true, this method will hold the read lock when it returns (even if the + * block already exists). If false, this method will hold no locks when it + * returns. + * @return `Some(PutResult)` if the block did not exist and could not be successfully cached, + * or None if the block already existed or was successfully stored (fully consuming + * the input data / input iterator). */ private def doPut( blockId: BlockId, data: BlockValues, level: StorageLevel, tellMaster: Boolean = true, - effectiveStorageLevel: Option[StorageLevel] = None): Boolean = { + effectiveStorageLevel: Option[StorageLevel] = None, + keepReadLock: Boolean = false): Option[PutResult] = { require(blockId != null, "BlockId is null") require(level != null && level.isValid, "StorageLevel is null or invalid") @@ -743,7 +761,11 @@ private[spark] class BlockManager( newInfo } else { logWarning(s"Block $blockId already exists on this machine; not re-adding it") - return false + if (!keepReadLock) { + // lockNewBlockForWriting returned a read lock on the existing block, so we must free it: + releaseLock(blockId) + } + return None } } @@ -779,6 +801,7 @@ private[spark] class BlockManager( } var blockWasSuccessfullyStored = false + var result: PutResult = null putBlockInfo.synchronized { logTrace("Put for block %s took %s to get into synchronized block" @@ -803,11 +826,9 @@ private[spark] class BlockManager( } // Actually put the values - val result = data match { + result = data match { case IteratorValues(iterator) => - blockStore.putIterator(blockId, iterator, putLevel, returnValues) - case ArrayValues(array) => - blockStore.putArray(blockId, array, putLevel, returnValues) + blockStore.putIterator(blockId, iterator(), putLevel, returnValues) case ByteBufferValues(bytes) => bytes.rewind() blockStore.putBytes(blockId, bytes, putLevel) @@ -834,7 +855,11 @@ private[spark] class BlockManager( } } finally { if (blockWasSuccessfullyStored) { - blockInfoManager.downgradeLock(blockId) + if (keepReadLock) { + blockInfoManager.downgradeLock(blockId) + } else { + blockInfoManager.unlock(blockId) + } } else { blockInfoManager.removeBlock(blockId) logWarning(s"Putting block $blockId failed") @@ -852,18 +877,20 @@ private[spark] class BlockManager( Await.ready(replicationFuture, Duration.Inf) } case _ => - val remoteStartTime = System.currentTimeMillis - // Serialize the block if not already done - if (bytesAfterPut == null) { - if (valuesAfterPut == null) { - throw new SparkException( - "Underlying put returned neither an Iterator nor bytes! This shouldn't happen.") + if (blockWasSuccessfullyStored) { + val remoteStartTime = System.currentTimeMillis + // Serialize the block if not already done + if (bytesAfterPut == null) { + if (valuesAfterPut == null) { + throw new SparkException( + "Underlying put returned neither an Iterator nor bytes! This shouldn't happen.") + } + bytesAfterPut = dataSerialize(blockId, valuesAfterPut) } - bytesAfterPut = dataSerialize(blockId, valuesAfterPut) + replicate(blockId, bytesAfterPut, putLevel) + logDebug("Put block %s remotely took %s" + .format(blockId, Utils.getUsedTimeMs(remoteStartTime))) } - replicate(blockId, bytesAfterPut, putLevel) - logDebug("Put block %s remotely took %s" - .format(blockId, Utils.getUsedTimeMs(remoteStartTime))) } } @@ -877,7 +904,11 @@ private[spark] class BlockManager( .format(blockId, Utils.getUsedTimeMs(startTimeMs))) } - blockWasSuccessfullyStored + if (blockWasSuccessfullyStored) { + None + } else { + Some(result) + } } /** @@ -1033,7 +1064,7 @@ private[spark] class BlockManager( logInfo(s"Writing block $blockId to disk") data() match { case Left(elements) => - diskStore.putArray(blockId, elements, level, returnValues = false) + diskStore.putIterator(blockId, elements.toIterator, level, returnValues = false) case Right(bytes) => diskStore.putBytes(blockId, bytes, level) } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockStore.scala b/core/src/main/scala/org/apache/spark/storage/BlockStore.scala index 6f6a6773ba..d3af50d974 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockStore.scala @@ -19,8 +19,6 @@ package org.apache.spark.storage import java.nio.ByteBuffer -import scala.collection.mutable.ArrayBuffer - import org.apache.spark.Logging /** @@ -43,12 +41,6 @@ private[spark] abstract class BlockStore(val blockManager: BlockManager) extends level: StorageLevel, returnValues: Boolean): PutResult - def putArray( - blockId: BlockId, - values: Array[Any], - level: StorageLevel, - returnValues: Boolean): PutResult - /** * Return the size of a block in bytes. */ diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index 1f3f193f2f..bfa6560a72 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -58,14 +58,6 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc PutResult(bytes.limit(), Right(bytes.duplicate())) } - override def putArray( - blockId: BlockId, - values: Array[Any], - level: StorageLevel, - returnValues: Boolean): PutResult = { - putIterator(blockId, values.toIterator, level, returnValues) - } - override def putIterator( blockId: BlockId, values: Iterator[Any], diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala index 2f16c8f3d8..317d73abba 100644 --- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala @@ -120,22 +120,6 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo PutResult(size, data) } - override def putArray( - blockId: BlockId, - values: Array[Any], - level: StorageLevel, - returnValues: Boolean): PutResult = { - if (level.deserialized) { - val sizeEstimate = SizeEstimator.estimate(values.asInstanceOf[AnyRef]) - tryToPut(blockId, () => values, sizeEstimate, deserialized = true) - PutResult(sizeEstimate, Left(values.iterator)) - } else { - val bytes = blockManager.dataSerialize(blockId, values.iterator) - tryToPut(blockId, () => bytes, bytes.limit, deserialized = false) - PutResult(bytes.limit(), Right(bytes.duplicate())) - } - } - override def putIterator( blockId: BlockId, values: Iterator[Any], @@ -166,7 +150,17 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo unrolledValues match { case Left(arrayValues) => // Values are fully unrolled in memory, so store them as an array - val res = putArray(blockId, arrayValues, level, returnValues) + val res = { + if (level.deserialized) { + val sizeEstimate = SizeEstimator.estimate(arrayValues.asInstanceOf[AnyRef]) + tryToPut(blockId, () => arrayValues, sizeEstimate, deserialized = true) + PutResult(sizeEstimate, Left(arrayValues.iterator)) + } else { + val bytes = blockManager.dataSerialize(blockId, arrayValues.iterator) + tryToPut(blockId, () => bytes, bytes.limit, deserialized = false) + PutResult(bytes.limit(), Right(bytes.duplicate())) + } + } PutResult(res.size, res.data) case Right(iteratorValues) => // Not enough space to unroll this block; drop to disk if applicable |