diff options
author | Reynold Xin <rxin@cs.berkeley.edu> | 2013-04-29 15:44:18 -0700 |
---|---|---|
committer | Reynold Xin <rxin@cs.berkeley.edu> | 2013-04-29 15:44:18 -0700 |
commit | d3586ef43870334ee62a683d3ae090bef782615f (patch) | |
tree | b43276f1455690482aec56fa86e4b51e19feff0a /core | |
parent | 0f45347c7b7243dbf54569f057a3605f96d614af (diff) | |
parent | ba6ffa6a5f39765e1652735d1b16b54c2fc78674 (diff) | |
download | spark-d3586ef43870334ee62a683d3ae090bef782615f.tar.gz spark-d3586ef43870334ee62a683d3ae090bef782615f.tar.bz2 spark-d3586ef43870334ee62a683d3ae090bef782615f.zip |
Merge branch 'blockmanager' of github.com:rxin/spark into blockmanager
Conflicts:
core/src/main/scala/spark/storage/DiskStore.scala
Diffstat (limited to 'core')
-rw-r--r-- | core/src/main/scala/spark/BlockStoreShuffleFetcher.scala | 13 | ||||
-rw-r--r-- | core/src/main/scala/spark/Dependency.scala | 4 | ||||
-rw-r--r-- | core/src/main/scala/spark/PairRDDFunctions.scala | 11 | ||||
-rw-r--r-- | core/src/main/scala/spark/ShuffleFetcher.scala | 7 | ||||
-rw-r--r-- | core/src/main/scala/spark/SparkEnv.scala | 16 | ||||
-rw-r--r-- | core/src/main/scala/spark/rdd/CoGroupedRDD.scala | 13 | ||||
-rw-r--r-- | core/src/main/scala/spark/rdd/ShuffledRDD.scala | 11 | ||||
-rw-r--r-- | core/src/main/scala/spark/rdd/SubtractedRDD.scala | 18 | ||||
-rw-r--r-- | core/src/main/scala/spark/scheduler/ShuffleMapTask.scala | 36 | ||||
-rw-r--r-- | core/src/main/scala/spark/serializer/Serializer.scala | 50 | ||||
-rw-r--r-- | core/src/main/scala/spark/storage/BlockException.scala | 5 | ||||
-rw-r--r-- | core/src/main/scala/spark/storage/BlockManager.scala | 73 | ||||
-rw-r--r-- | core/src/main/scala/spark/storage/BlockManagerWorker.scala | 18 | ||||
-rw-r--r-- | core/src/main/scala/spark/storage/BlockObjectWriter.scala | 27 | ||||
-rw-r--r-- | core/src/main/scala/spark/storage/DiskStore.scala | 59 |
15 files changed, 269 insertions, 92 deletions
diff --git a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala index c27ed36406..2156efbd45 100644 --- a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala @@ -1,14 +1,19 @@ package spark -import executor.{ShuffleReadMetrics, TaskMetrics} import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap +import spark.executor.{ShuffleReadMetrics, TaskMetrics} +import spark.serializer.Serializer import spark.storage.{DelegateBlockFetchTracker, BlockManagerId} -import util.{CompletionIterator, TimedIterator} +import spark.util.{CompletionIterator, TimedIterator} + private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging { - override def fetch[K, V](shuffleId: Int, reduceId: Int, metrics: TaskMetrics) = { + + override def fetch[K, V]( + shuffleId: Int, reduceId: Int, metrics: TaskMetrics, serializer: Serializer) = { + logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId)) val blockManager = SparkEnv.get.blockManager @@ -48,7 +53,7 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin } } - val blockFetcherItr = blockManager.getMultiple(blocksByAddress) + val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer) val itr = new TimedIterator(blockFetcherItr.flatMap(unpackBlock)) with DelegateBlockFetchTracker itr.setDelegate(blockFetcherItr) CompletionIterator[(K,V), Iterator[(K,V)]](itr, { diff --git a/core/src/main/scala/spark/Dependency.scala b/core/src/main/scala/spark/Dependency.scala index 5eea907322..2af44aa383 100644 --- a/core/src/main/scala/spark/Dependency.scala +++ b/core/src/main/scala/spark/Dependency.scala @@ -25,10 +25,12 @@ abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) { * @param shuffleId the shuffle id * @param rdd the parent RDD * @param partitioner partitioner used to partition the shuffle output + * @param serializerClass class name of the serializer to use */ class ShuffleDependency[K, V]( @transient rdd: RDD[(K, V)], - val partitioner: Partitioner) + val partitioner: Partitioner, + val serializerClass: String = null) extends Dependency(rdd) { val shuffleId: Int = rdd.context.newShuffleId() diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index 67fd1c1a8f..2b0e697337 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -52,7 +52,8 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( mergeValue: (C, V) => C, mergeCombiners: (C, C) => C, partitioner: Partitioner, - mapSideCombine: Boolean = true): RDD[(K, C)] = { + mapSideCombine: Boolean = true, + serializerClass: String = null): RDD[(K, C)] = { if (getKeyClass().isArray) { if (mapSideCombine) { throw new SparkException("Cannot use map-side combining with array keys.") @@ -67,13 +68,13 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( self.mapPartitions(aggregator.combineValuesByKey(_), true) } else if (mapSideCombine) { val mapSideCombined = self.mapPartitions(aggregator.combineValuesByKey(_), true) - val partitioned = new ShuffledRDD[K, C](mapSideCombined, partitioner) + val partitioned = new ShuffledRDD[K, C](mapSideCombined, partitioner, serializerClass) partitioned.mapPartitions(aggregator.combineCombinersByKey(_), true) } else { // Don't apply map-side combiner. // A sanity check to make sure mergeCombiners is not defined. assert(mergeCombiners == null) - val values = new ShuffledRDD[K, V](self, partitioner) + val values = new ShuffledRDD[K, V](self, partitioner, serializerClass) values.mapPartitions(aggregator.combineValuesByKey(_), true) } } @@ -469,7 +470,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( /** * Return an RDD with the pairs from `this` whose keys are not in `other`. - * + * * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting * RDD will be <= us. */ @@ -645,7 +646,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( * Return an RDD with the keys of each tuple. */ def keys: RDD[K] = self.map(_._1) - + /** * Return an RDD with the values of each tuple. */ diff --git a/core/src/main/scala/spark/ShuffleFetcher.scala b/core/src/main/scala/spark/ShuffleFetcher.scala index 442e9f0269..49addc0c10 100644 --- a/core/src/main/scala/spark/ShuffleFetcher.scala +++ b/core/src/main/scala/spark/ShuffleFetcher.scala @@ -1,13 +1,16 @@ package spark -import executor.TaskMetrics +import spark.executor.TaskMetrics +import spark.serializer.Serializer + private[spark] abstract class ShuffleFetcher { /** * Fetch the shuffle outputs for a given ShuffleDependency. * @return An iterator over the elements of the fetched shuffle outputs. */ - def fetch[K, V](shuffleId: Int, reduceId: Int, metrics: TaskMetrics) : Iterator[(K,V)] + def fetch[K, V](shuffleId: Int, reduceId: Int, metrics: TaskMetrics, + serializer: Serializer = Serializer.default): Iterator[(K,V)] /** Stop the fetcher */ def stop() {} diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala index ffb40bab3a..8ba52245fa 100644 --- a/core/src/main/scala/spark/SparkEnv.scala +++ b/core/src/main/scala/spark/SparkEnv.scala @@ -3,13 +3,14 @@ package spark import akka.actor.{Actor, ActorRef, Props, ActorSystemImpl, ActorSystem} import akka.remote.RemoteActorRefProvider -import serializer.Serializer import spark.broadcast.BroadcastManager import spark.storage.BlockManager import spark.storage.BlockManagerMaster import spark.network.ConnectionManager +import spark.serializer.Serializer import spark.util.AkkaUtils + /** * Holds all the runtime environment objects for a running Spark instance (either master or worker), * including the serializer, Akka actor system, block manager, map output tracker, etc. Currently @@ -91,8 +92,12 @@ object SparkEnv extends Logging { Class.forName(name, true, classLoader).newInstance().asInstanceOf[T] } - val serializer = instantiateClass[Serializer]("spark.serializer", "spark.JavaSerializer") - + val serializer = Serializer.setDefault( + System.getProperty("spark.serializer", "spark.JavaSerializer")) + + val closureSerializer = Serializer.get( + System.getProperty("spark.closure.serializer", "spark.JavaSerializer")) + def registerOrLookup(name: String, newActor: => Actor): ActorRef = { if (isDriver) { logInfo("Registering " + name) @@ -116,9 +121,6 @@ object SparkEnv extends Logging { val broadcastManager = new BroadcastManager(isDriver) - val closureSerializer = instantiateClass[Serializer]( - "spark.closure.serializer", "spark.JavaSerializer") - val cacheManager = new CacheManager(blockManager) // Have to assign trackerActor after initialization as MapOutputTrackerActor @@ -164,5 +166,5 @@ object SparkEnv extends Logging { httpFileServer, sparkFilesDir) } - + } diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala index a6235491ca..9e996e9958 100644 --- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala @@ -8,6 +8,7 @@ import scala.collection.mutable.ArrayBuffer import spark.{Aggregator, Logging, Partition, Partitioner, RDD, SparkEnv, TaskContext} import spark.{Dependency, OneToOneDependency, ShuffleDependency} +import spark.serializer.Serializer private[spark] sealed trait CoGroupSplitDep extends Serializable @@ -54,7 +55,8 @@ private[spark] class CoGroupAggregator class CoGroupedRDD[K]( @transient var rdds: Seq[RDD[(K, _)]], part: Partitioner, - val mapSideCombine: Boolean = true) + val mapSideCombine: Boolean = true, + val serializerClass: String = null) extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) { private val aggr = new CoGroupAggregator @@ -68,9 +70,9 @@ class CoGroupedRDD[K]( logInfo("Adding shuffle dependency with " + rdd) if (mapSideCombine) { val mapSideCombinedRDD = rdd.mapPartitions(aggr.combineValuesByKey(_), true) - new ShuffleDependency[Any, ArrayBuffer[Any]](mapSideCombinedRDD, part) + new ShuffleDependency[Any, ArrayBuffer[Any]](mapSideCombinedRDD, part, serializerClass) } else { - new ShuffleDependency[Any, Any](rdd.asInstanceOf[RDD[(Any, Any)]], part) + new ShuffleDependency[Any, Any](rdd.asInstanceOf[RDD[(Any, Any)]], part, serializerClass) } } } @@ -112,6 +114,7 @@ class CoGroupedRDD[K]( } } + val ser = Serializer.get(serializerClass) for ((dep, depNum) <- split.deps.zipWithIndex) dep match { case NarrowCoGroupSplitDep(rdd, _, itsSplit) => { // Read them from the parent @@ -124,12 +127,12 @@ class CoGroupedRDD[K]( val fetcher = SparkEnv.get.shuffleFetcher if (mapSideCombine) { // With map side combine on, for each key, the shuffle fetcher returns a list of values. - fetcher.fetch[K, Seq[Any]](shuffleId, split.index, context.taskMetrics).foreach { + fetcher.fetch[K, Seq[Any]](shuffleId, split.index, context.taskMetrics, ser).foreach { case (key, values) => getSeq(key)(depNum) ++= values } } else { // With map side combine off, for each key the shuffle fetcher returns a single value. - fetcher.fetch[K, Any](shuffleId, split.index, context.taskMetrics).foreach { + fetcher.fetch[K, Any](shuffleId, split.index, context.taskMetrics, ser).foreach { case (key, value) => getSeq(key)(depNum) += value } } diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala index 4e33b7dd5c..8175e23eff 100644 --- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala @@ -2,6 +2,8 @@ package spark.rdd import spark.{Partitioner, RDD, SparkEnv, ShuffleDependency, Partition, TaskContext} import spark.SparkContext._ +import spark.serializer.Serializer + private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition { override val index = idx @@ -12,13 +14,15 @@ private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition { * The resulting RDD from a shuffle (e.g. repartitioning of data). * @param prev the parent RDD. * @param part the partitioner used to partition the RDD + * @param serializerClass class name of the serializer to use. * @tparam K the key class. * @tparam V the value class. */ class ShuffledRDD[K, V]( @transient prev: RDD[(K, V)], - part: Partitioner) - extends RDD[(K, V)](prev.context, List(new ShuffleDependency(prev, part))) { + part: Partitioner, + serializerClass: String = null) + extends RDD[(K, V)](prev.context, List(new ShuffleDependency(prev, part, serializerClass))) { override val partitioner = Some(part) @@ -28,6 +32,7 @@ class ShuffledRDD[K, V]( override def compute(split: Partition, context: TaskContext): Iterator[(K, V)] = { val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId - SparkEnv.get.shuffleFetcher.fetch[K, V](shuffledId, split.index, context.taskMetrics) + SparkEnv.get.shuffleFetcher.fetch[K, V]( + shuffledId, split.index, context.taskMetrics, Serializer.get(serializerClass)) } } diff --git a/core/src/main/scala/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/spark/rdd/SubtractedRDD.scala index 481e03b349..f60c35c38e 100644 --- a/core/src/main/scala/spark/rdd/SubtractedRDD.scala +++ b/core/src/main/scala/spark/rdd/SubtractedRDD.scala @@ -11,6 +11,7 @@ import spark.Partition import spark.SparkEnv import spark.ShuffleDependency import spark.OneToOneDependency +import spark.serializer.Serializer /** * An optimized version of cogroup for set difference/subtraction. @@ -31,7 +32,9 @@ import spark.OneToOneDependency private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest, W: ClassManifest]( @transient var rdd1: RDD[(K, V)], @transient var rdd2: RDD[(K, W)], - part: Partitioner) extends RDD[(K, V)](rdd1.context, Nil) { + part: Partitioner, + val serializerClass: String = null) + extends RDD[(K, V)](rdd1.context, Nil) { override def getDependencies: Seq[Dependency[_]] = { Seq(rdd1, rdd2).map { rdd => @@ -40,7 +43,7 @@ private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest, W: ClassM new OneToOneDependency(rdd) } else { logInfo("Adding shuffle dependency with " + rdd) - new ShuffleDependency(rdd.asInstanceOf[RDD[(K, Any)]], part) + new ShuffleDependency(rdd.asInstanceOf[RDD[(K, Any)]], part, serializerClass) } } } @@ -65,6 +68,7 @@ private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest, W: ClassM override def compute(p: Partition, context: TaskContext): Iterator[(K, V)] = { val partition = p.asInstanceOf[CoGroupPartition] + val serializer = Serializer.get(serializerClass) val map = new JHashMap[K, ArrayBuffer[V]] def getSeq(k: K): ArrayBuffer[V] = { val seq = map.get(k) @@ -77,12 +81,16 @@ private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest, W: ClassM } } def integrate(dep: CoGroupSplitDep, op: ((K, V)) => Unit) = dep match { - case NarrowCoGroupSplitDep(rdd, _, itsSplit) => + case NarrowCoGroupSplitDep(rdd, _, itsSplit) => { for (t <- rdd.iterator(itsSplit, context)) op(t.asInstanceOf[(K, V)]) - case ShuffleCoGroupSplitDep(shuffleId) => - for (t <- SparkEnv.get.shuffleFetcher.fetch(shuffleId, partition.index, context.taskMetrics)) + } + case ShuffleCoGroupSplitDep(shuffleId) => { + val iter = SparkEnv.get.shuffleFetcher.fetch(shuffleId, partition.index, + context.taskMetrics, serializer) + for (t <- iter) op(t.asInstanceOf[(K, V)]) + } } // the first dep is rdd1; add all values to the map integrate(partition.deps(0), t => getSeq(t._1) += t._2) diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala index 7dc6da4573..51ec89eb74 100644 --- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala @@ -13,9 +13,11 @@ import com.ning.compress.lzf.LZFInputStream import com.ning.compress.lzf.LZFOutputStream import spark._ -import executor.ShuffleWriteMetrics +import spark.executor.ShuffleWriteMetrics +import spark.serializer.Serializer import spark.storage._ -import util.{TimeStampedHashMap, MetadataCleaner} +import spark.util.{TimeStampedHashMap, MetadataCleaner} + private[spark] object ShuffleMapTask { @@ -130,27 +132,33 @@ private[spark] class ShuffleMapTask( val taskContext = new TaskContext(stageId, partition, attemptId) metrics = Some(taskContext.taskMetrics) try { - // Partition the map output. - val buckets = Array.fill(numOutputSplits)(new ArrayBuffer[(Any, Any)]) + // Obtain all the block writers for shuffle blocks. + val blockManager = SparkEnv.get.blockManager + val buckets = Array.tabulate[BlockObjectWriter](numOutputSplits) { bucketId => + val blockId = "shuffle_" + dep.shuffleId + "_" + partition + "_" + bucketId + blockManager.getDiskBlockWriter(blockId, Serializer.get(dep.serializerClass)) + } + + // Write the map output to its associated buckets. for (elem <- rdd.iterator(split, taskContext)) { val pair = elem.asInstanceOf[(Any, Any)] val bucketId = dep.partitioner.getPartition(pair._1) - buckets(bucketId) += pair + buckets(bucketId).write(pair) } + // Close the bucket writers and get the sizes of each block. val compressedSizes = new Array[Byte](numOutputSplits) - - var totalBytes = 0l - - val blockManager = SparkEnv.get.blockManager - for (i <- 0 until numOutputSplits) { - val blockId = "shuffle_" + dep.shuffleId + "_" + partition + "_" + i - // Get a Scala iterator from Java map - val iter: Iterator[(Any, Any)] = buckets(i).iterator - val size = blockManager.put(blockId, iter, StorageLevel.DISK_ONLY, false) + var i = 0 + var totalBytes = 0L + while (i < numOutputSplits) { + buckets(i).close() + val size = buckets(i).size() totalBytes += size compressedSizes(i) = MapOutputTracker.compressSize(size) + i += 1 } + + // Update shuffle metrics. val shuffleMetrics = new ShuffleWriteMetrics shuffleMetrics.shuffleBytesWritten = totalBytes metrics.get.shuffleWriteMetrics = Some(shuffleMetrics) diff --git a/core/src/main/scala/spark/serializer/Serializer.scala b/core/src/main/scala/spark/serializer/Serializer.scala index aca86ab6f0..77b1a1a434 100644 --- a/core/src/main/scala/spark/serializer/Serializer.scala +++ b/core/src/main/scala/spark/serializer/Serializer.scala @@ -1,10 +1,14 @@ package spark.serializer -import java.nio.ByteBuffer import java.io.{EOFException, InputStream, OutputStream} +import java.nio.ByteBuffer +import java.util.concurrent.ConcurrentHashMap + import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream + import spark.util.ByteBufferInputStream + /** * A serializer. Because some serialization libraries are not thread safe, this class is used to * create [[spark.serializer.SerializerInstance]] objects that do the actual serialization and are @@ -14,6 +18,48 @@ trait Serializer { def newInstance(): SerializerInstance } + +/** + * A singleton object that can be used to fetch serializer objects based on the serializer + * class name. If a previous instance of the serializer object has been created, the get + * method returns that instead of creating a new one. + */ +object Serializer { + + private val serializers = new ConcurrentHashMap[String, Serializer] + private var _default: Serializer = _ + + def default = _default + + def setDefault(clsName: String): Serializer = { + _default = get(clsName) + _default + } + + def get(clsName: String): Serializer = { + if (clsName == null) { + default + } else { + var serializer = serializers.get(clsName) + if (serializer != null) { + // If the serializer has been created previously, reuse that. + serializer + } else this.synchronized { + // Otherwise, create a new one. But make sure no other thread has attempted + // to create another new one at the same time. + serializer = serializers.get(clsName) + if (serializer == null) { + val clsLoader = Thread.currentThread.getContextClassLoader + serializer = Class.forName(clsName, true, clsLoader).newInstance().asInstanceOf[Serializer] + serializers.put(clsName, serializer) + } + serializer + } + } + } +} + + /** * An instance of a serializer, for use by one thread at a time. */ @@ -45,6 +91,7 @@ trait SerializerInstance { } } + /** * A stream for writing serialized objects. */ @@ -61,6 +108,7 @@ trait SerializationStream { } } + /** * A stream for reading serialized objects. */ diff --git a/core/src/main/scala/spark/storage/BlockException.scala b/core/src/main/scala/spark/storage/BlockException.scala new file mode 100644 index 0000000000..f275d476df --- /dev/null +++ b/core/src/main/scala/spark/storage/BlockException.scala @@ -0,0 +1,5 @@ +package spark.storage + +private[spark] +case class BlockException(blockId: String, message: String) extends Exception(message) + diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index 6e861ac734..9190c96c71 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -25,15 +25,11 @@ import sun.nio.ch.DirectBuffer private[spark] -case class BlockException(blockId: String, message: String, ex: Exception = null) -extends Exception(message) - -private[spark] class BlockManager( executorId: String, actorSystem: ActorSystem, val master: BlockManagerMaster, - val serializer: Serializer, + val defaultSerializer: Serializer, maxMemory: Long) extends Logging { @@ -95,7 +91,7 @@ class BlockManager( private val blockInfo = new TimeStampedHashMap[String, BlockInfo] private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory) - private[storage] val diskStore: BlockStore = + private[storage] val diskStore: DiskStore = new DiskStore(this, System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir"))) val connectionManager = new ConnectionManager(0) @@ -293,22 +289,23 @@ class BlockManager( } /** + * A short-circuited method to get blocks directly from disk. This is used for getting + * shuffle blocks. It is safe to do so without a lock on block info since disk store + * never deletes (recent) items. + */ + def getLocalFromDisk(blockId: String, serializer: Serializer): Option[Iterator[Any]] = { + diskStore.getValues(blockId, serializer) match { + case Some(iterator) => Some(iterator) + case None => + throw new Exception("Block " + blockId + " not found on disk, though it should be") + } + } + + /** * Get block from local block manager. */ def getLocal(blockId: String): Option[Iterator[Any]] = { logDebug("Getting local block " + blockId) - - // As an optimization for map output fetches, if the block is for a shuffle, return it - // without acquiring a lock; the disk store never deletes (recent) items so this should work - if (blockId.startsWith("shuffle_")) { - return diskStore.getValues(blockId) match { - case Some(iterator) => - Some(iterator) - case None => - throw new Exception("Block " + blockId + " not found on disk, though it should be") - } - } - val info = blockInfo.get(blockId).orNull if (info != null) { info.synchronized { @@ -496,9 +493,10 @@ class BlockManager( * fashion as they're received. Expects a size in bytes to be provided for each block fetched, * so that we can control the maxMegabytesInFlight for the fetch. */ - def getMultiple(blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])]) + def getMultiple( + blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], serializer: Serializer) : BlockFetcherIterator = { - return new BlockFetcherIterator(this, blocksByAddress) + return new BlockFetcherIterator(this, blocksByAddress, serializer) } def put(blockId: String, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean) @@ -509,6 +507,21 @@ class BlockManager( } /** + * A short circuited method to get a block writer that can write data directly to disk. + * This is currently used for writing shuffle files out. + */ + def getDiskBlockWriter(blockId: String, serializer: Serializer): BlockObjectWriter = { + val writer = diskStore.getBlockWriter(blockId, serializer) + writer.registerCloseEventHandler(() => { + // TODO(rxin): This doesn't handle error cases. + val myInfo = new BlockInfo(StorageLevel.DISK_ONLY, false) + blockInfo.put(blockId, myInfo) + myInfo.markReady(writer.size()) + }) + writer + } + + /** * Put a new block of values to the block manager. Returns its (estimated) size in bytes. */ def put(blockId: String, values: ArrayBuffer[Any], level: StorageLevel, @@ -607,7 +620,6 @@ class BlockManager( } logDebug("Put block " + blockId + " locally took " + Utils.getUsedTimeMs(startTimeMs)) - // Replicate block if required if (level.replication > 1) { val remoteStartTime = System.currentTimeMillis @@ -885,7 +897,10 @@ class BlockManager( if (shouldCompress(blockId)) new LZFInputStream(s) else s } - def dataSerialize(blockId: String, values: Iterator[Any]): ByteBuffer = { + def dataSerialize( + blockId: String, + values: Iterator[Any], + serializer: Serializer = defaultSerializer): ByteBuffer = { val byteStream = new FastByteArrayOutputStream(4096) val ser = serializer.newInstance() ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close() @@ -897,7 +912,10 @@ class BlockManager( * Deserializes a ByteBuffer into an iterator of values and disposes of it when the end of * the iterator is reached. */ - def dataDeserialize(blockId: String, bytes: ByteBuffer): Iterator[Any] = { + def dataDeserialize( + blockId: String, + bytes: ByteBuffer, + serializer: Serializer = defaultSerializer): Iterator[Any] = { bytes.rewind() val stream = wrapForCompression(blockId, new ByteBufferInputStream(bytes, true)) serializer.newInstance().deserializeStream(stream).asIterator @@ -951,7 +969,8 @@ object BlockManager extends Logging { class BlockFetcherIterator( private val blockManager: BlockManager, - val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])] + val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], + serializer: Serializer ) extends Iterator[(String, Option[Iterator[Any]])] with Logging with BlockFetchTracker { import blockManager._ @@ -1014,8 +1033,8 @@ class BlockFetcherIterator( "Unexpected message " + blockMessage.getType + " received from " + cmId) } val blockId = blockMessage.getId - results.put(new FetchResult( - blockId, sizeMap(blockId), () => dataDeserialize(blockId, blockMessage.getData))) + results.put(new FetchResult(blockId, sizeMap(blockId), + () => dataDeserialize(blockId, blockMessage.getData, serializer))) _remoteBytesRead += req.size logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) } @@ -1079,7 +1098,7 @@ class BlockFetcherIterator( // any memory that might exceed our maxBytesInFlight startTime = System.currentTimeMillis for (id <- localBlockIds) { - getLocal(id) match { + getLocalFromDisk(id, serializer) match { case Some(iter) => { results.put(new FetchResult(id, 0, () => iter)) // Pass 0 as size since it's not in flight logDebug("Got local block " + id) diff --git a/core/src/main/scala/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/spark/storage/BlockManagerWorker.scala index d2985559c1..15225f93a6 100644 --- a/core/src/main/scala/spark/storage/BlockManagerWorker.scala +++ b/core/src/main/scala/spark/storage/BlockManagerWorker.scala @@ -19,7 +19,7 @@ import spark.network._ */ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends Logging { initLogging() - + blockManager.connectionManager.onReceiveMessage(onBlockMessageReceive) def onBlockMessageReceive(msg: Message, id: ConnectionManagerId): Option[Message] = { @@ -51,7 +51,7 @@ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends logDebug("Received [" + pB + "]") putBlock(pB.id, pB.data, pB.level) return None - } + } case BlockMessage.TYPE_GET_BLOCK => { val gB = new GetBlock(blockMessage.getId) logDebug("Received [" + gB + "]") @@ -90,28 +90,26 @@ private[spark] object BlockManagerWorker extends Logging { private var blockManagerWorker: BlockManagerWorker = null private val DATA_TRANSFER_TIME_OUT_MS: Long = 500 private val REQUEST_RETRY_INTERVAL_MS: Long = 1000 - + initLogging() - + def startBlockManagerWorker(manager: BlockManager) { blockManagerWorker = new BlockManagerWorker(manager) } - + def syncPutBlock(msg: PutBlock, toConnManagerId: ConnectionManagerId): Boolean = { val blockManager = blockManagerWorker.blockManager - val connectionManager = blockManager.connectionManager - val serializer = blockManager.serializer + val connectionManager = blockManager.connectionManager val blockMessage = BlockMessage.fromPutBlock(msg) val blockMessageArray = new BlockMessageArray(blockMessage) val resultMessage = connectionManager.sendMessageReliablySync( toConnManagerId, blockMessageArray.toBufferMessage) return (resultMessage != None) } - + def syncGetBlock(msg: GetBlock, toConnManagerId: ConnectionManagerId): ByteBuffer = { val blockManager = blockManagerWorker.blockManager - val connectionManager = blockManager.connectionManager - val serializer = blockManager.serializer + val connectionManager = blockManager.connectionManager val blockMessage = BlockMessage.fromGetBlock(msg) val blockMessageArray = new BlockMessageArray(blockMessage) val responseMessage = connectionManager.sendMessageReliablySync( diff --git a/core/src/main/scala/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/spark/storage/BlockObjectWriter.scala new file mode 100644 index 0000000000..657a7e9143 --- /dev/null +++ b/core/src/main/scala/spark/storage/BlockObjectWriter.scala @@ -0,0 +1,27 @@ +package spark.storage + +import java.nio.ByteBuffer + + +abstract class BlockObjectWriter(val blockId: String) { + + // TODO(rxin): What if there is an exception when the block is being written out? + + var closeEventHandler: () => Unit = _ + + def registerCloseEventHandler(handler: () => Unit) { + closeEventHandler = handler + } + + def write(value: Any) + + def writeAll(value: Iterator[Any]) { + value.foreach(write) + } + + def close() { + closeEventHandler() + } + + def size(): Long +} diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala index c9553d2e0f..b527a3c708 100644 --- a/core/src/main/scala/spark/storage/DiskStore.scala +++ b/core/src/main/scala/spark/storage/DiskStore.scala @@ -1,18 +1,19 @@ package spark.storage +import java.io.{File, FileOutputStream, OutputStream, RandomAccessFile} import java.nio.ByteBuffer -import java.io.{File, FileOutputStream, RandomAccessFile} import java.nio.channels.FileChannel.MapMode import java.util.{Random, Date} import java.text.SimpleDateFormat -import it.unimi.dsi.fastutil.io.FastBufferedOutputStream - import scala.collection.mutable.ArrayBuffer -import spark.executor.ExecutorExitCode +import it.unimi.dsi.fastutil.io.FastBufferedOutputStream import spark.Utils +import spark.executor.ExecutorExitCode +import spark.serializer.Serializer + /** * Stores BlockManager blocks on disk. @@ -23,6 +24,34 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) private val mapMode = MapMode.READ_ONLY private var mapOpenMode = "r" + class DiskBlockObjectWriter(blockId: String, serializer: Serializer) + extends BlockObjectWriter(blockId) { + + private val f: File = createFile(blockId /*, allowAppendExisting */) + private val bs: OutputStream = blockManager.wrapForCompression(blockId, + new FastBufferedOutputStream(new FileOutputStream(f))) + private val objOut = serializer.newInstance().serializeStream(bs) + + private var _size: Long = -1L + + override def write(value: Any) { + objOut.writeObject(value) + } + + override def close() { + objOut.close() + bs.close() + super.close() + } + + override def size(): Long = { + if (_size < 0) { + _size = f.length() + } + _size + } + } + val MAX_DIR_CREATION_ATTEMPTS: Int = 10 val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt @@ -34,6 +63,10 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) addShutdownHook() + def getBlockWriter(blockId: String, serializer: Serializer): BlockObjectWriter = { + new DiskBlockObjectWriter(blockId, serializer) + } + override def getSize(blockId: String): Long = { getFile(blockId).length() } @@ -79,12 +112,14 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) val file = createFile(blockId) val fileOut = blockManager.wrapForCompression(blockId, new FastBufferedOutputStream(new FileOutputStream(file))) - val objOut = blockManager.serializer.newInstance().serializeStream(fileOut) + val objOut = blockManager.defaultSerializer.newInstance().serializeStream(fileOut) objOut.writeAll(values.iterator) objOut.close() val length = file.length() + + val timeTaken = System.currentTimeMillis - startTime logDebug("Block %s stored as %s file on disk in %d ms".format( - blockId, Utils.memoryBytesToString(length), (System.currentTimeMillis - startTime))) + blockId, Utils.memoryBytesToString(length), timeTaken)) if (returnValues) { // Return a byte buffer for the contents of the file @@ -105,6 +140,14 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes)) } + /** + * A version of getValues that allows a custom serializer. This is used as part of the + * shuffle short-circuit code. + */ + def getValues(blockId: String, serializer: Serializer): Option[Iterator[Any]] = { + getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes, serializer)) + } + override def remove(blockId: String): Boolean = { val file = getFile(blockId) if (file.exists()) { @@ -118,9 +161,9 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) getFile(blockId).exists() } - private def createFile(blockId: String): File = { + private def createFile(blockId: String, allowAppendExisting: Boolean = false): File = { val file = getFile(blockId) - if (file.exists()) { + if (!allowAppendExisting && file.exists()) { throw new Exception("File for block " + blockId + " already exists on disk: " + file) } file |