author     Reynold Xin <rxin@cs.berkeley.edu>   2013-04-29 15:44:18 -0700
committer  Reynold Xin <rxin@cs.berkeley.edu>   2013-04-29 15:44:18 -0700
commit     d3586ef43870334ee62a683d3ae090bef782615f (patch)
tree       b43276f1455690482aec56fa86e4b51e19feff0a /core
parent     0f45347c7b7243dbf54569f057a3605f96d614af (diff)
parent     ba6ffa6a5f39765e1652735d1b16b54c2fc78674 (diff)
Merge branch 'blockmanager' of github.com:rxin/spark into blockmanager
Conflicts: core/src/main/scala/spark/storage/DiskStore.scala
Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/spark/BlockStoreShuffleFetcher.scala    13
-rw-r--r--  core/src/main/scala/spark/Dependency.scala                   4
-rw-r--r--  core/src/main/scala/spark/PairRDDFunctions.scala            11
-rw-r--r--  core/src/main/scala/spark/ShuffleFetcher.scala               7
-rw-r--r--  core/src/main/scala/spark/SparkEnv.scala                    16
-rw-r--r--  core/src/main/scala/spark/rdd/CoGroupedRDD.scala            13
-rw-r--r--  core/src/main/scala/spark/rdd/ShuffledRDD.scala             11
-rw-r--r--  core/src/main/scala/spark/rdd/SubtractedRDD.scala           18
-rw-r--r--  core/src/main/scala/spark/scheduler/ShuffleMapTask.scala    36
-rw-r--r--  core/src/main/scala/spark/serializer/Serializer.scala       50
-rw-r--r--  core/src/main/scala/spark/storage/BlockException.scala       5
-rw-r--r--  core/src/main/scala/spark/storage/BlockManager.scala        73
-rw-r--r--  core/src/main/scala/spark/storage/BlockManagerWorker.scala  18
-rw-r--r--  core/src/main/scala/spark/storage/BlockObjectWriter.scala   27
-rw-r--r--  core/src/main/scala/spark/storage/DiskStore.scala           59
15 files changed, 269 insertions, 92 deletions
diff --git a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
index c27ed36406..2156efbd45 100644
--- a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
+++ b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
@@ -1,14 +1,19 @@
package spark
-import executor.{ShuffleReadMetrics, TaskMetrics}
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
+import spark.executor.{ShuffleReadMetrics, TaskMetrics}
+import spark.serializer.Serializer
import spark.storage.{DelegateBlockFetchTracker, BlockManagerId}
-import util.{CompletionIterator, TimedIterator}
+import spark.util.{CompletionIterator, TimedIterator}
+
private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging {
- override def fetch[K, V](shuffleId: Int, reduceId: Int, metrics: TaskMetrics) = {
+
+ override def fetch[K, V](
+ shuffleId: Int, reduceId: Int, metrics: TaskMetrics, serializer: Serializer) = {
+
logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
val blockManager = SparkEnv.get.blockManager
@@ -48,7 +53,7 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
}
}
- val blockFetcherItr = blockManager.getMultiple(blocksByAddress)
+ val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer)
val itr = new TimedIterator(blockFetcherItr.flatMap(unpackBlock)) with DelegateBlockFetchTracker
itr.setDelegate(blockFetcherItr)
CompletionIterator[(K,V), Iterator[(K,V)]](itr, {
diff --git a/core/src/main/scala/spark/Dependency.scala b/core/src/main/scala/spark/Dependency.scala
index 5eea907322..2af44aa383 100644
--- a/core/src/main/scala/spark/Dependency.scala
+++ b/core/src/main/scala/spark/Dependency.scala
@@ -25,10 +25,12 @@ abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) {
* @param shuffleId the shuffle id
* @param rdd the parent RDD
* @param partitioner partitioner used to partition the shuffle output
+ * @param serializerClass class name of the serializer to use
*/
class ShuffleDependency[K, V](
@transient rdd: RDD[(K, V)],
- val partitioner: Partitioner)
+ val partitioner: Partitioner,
+ val serializerClass: String = null)
extends Dependency(rdd) {
val shuffleId: Int = rdd.context.newShuffleId()
diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala
index 67fd1c1a8f..2b0e697337 100644
--- a/core/src/main/scala/spark/PairRDDFunctions.scala
+++ b/core/src/main/scala/spark/PairRDDFunctions.scala
@@ -52,7 +52,8 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
mergeValue: (C, V) => C,
mergeCombiners: (C, C) => C,
partitioner: Partitioner,
- mapSideCombine: Boolean = true): RDD[(K, C)] = {
+ mapSideCombine: Boolean = true,
+ serializerClass: String = null): RDD[(K, C)] = {
if (getKeyClass().isArray) {
if (mapSideCombine) {
throw new SparkException("Cannot use map-side combining with array keys.")
@@ -67,13 +68,13 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
self.mapPartitions(aggregator.combineValuesByKey(_), true)
} else if (mapSideCombine) {
val mapSideCombined = self.mapPartitions(aggregator.combineValuesByKey(_), true)
- val partitioned = new ShuffledRDD[K, C](mapSideCombined, partitioner)
+ val partitioned = new ShuffledRDD[K, C](mapSideCombined, partitioner, serializerClass)
partitioned.mapPartitions(aggregator.combineCombinersByKey(_), true)
} else {
// Don't apply map-side combiner.
// A sanity check to make sure mergeCombiners is not defined.
assert(mergeCombiners == null)
- val values = new ShuffledRDD[K, V](self, partitioner)
+ val values = new ShuffledRDD[K, V](self, partitioner, serializerClass)
values.mapPartitions(aggregator.combineValuesByKey(_), true)
}
}
@@ -469,7 +470,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
/**
* Return an RDD with the pairs from `this` whose keys are not in `other`.
- *
+ *
* Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
* RDD will be <= us.
*/
@@ -645,7 +646,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
* Return an RDD with the keys of each tuple.
*/
def keys: RDD[K] = self.map(_._1)
-
+
/**
* Return an RDD with the values of each tuple.
*/
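
A minimal driver sketch of the widened combineByKey signature above. The serializer class name "spark.KryoSerializer" and the local master string are illustrative assumptions; passing null for serializerClass keeps the default serializer.

import spark.SparkContext
import spark.SparkContext._   // implicit conversion to PairRDDFunctions
import spark.HashPartitioner

object CombineByKeyExample {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "CombineByKeyExample")
    val pairs = sc.parallelize(Seq(("a", 1), ("b", 2), ("a", 3)))
    // Matches the signature above: the trailing serializerClass argument names the
    // serializer to use for this shuffle only (assumed here to be spark.KryoSerializer).
    val sums = pairs.combineByKey[Int](
      (v: Int) => v,                  // createCombiner
      (c: Int, v: Int) => c + v,      // mergeValue
      (c1: Int, c2: Int) => c1 + c2,  // mergeCombiners
      new HashPartitioner(2),
      true,                           // mapSideCombine
      "spark.KryoSerializer")         // serializerClass
    sums.collect().foreach(println)   // (a,4) and (b,2)
    sc.stop()
  }
}
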
diff --git a/core/src/main/scala/spark/ShuffleFetcher.scala b/core/src/main/scala/spark/ShuffleFetcher.scala
index 442e9f0269..49addc0c10 100644
--- a/core/src/main/scala/spark/ShuffleFetcher.scala
+++ b/core/src/main/scala/spark/ShuffleFetcher.scala
@@ -1,13 +1,16 @@
package spark
-import executor.TaskMetrics
+import spark.executor.TaskMetrics
+import spark.serializer.Serializer
+
private[spark] abstract class ShuffleFetcher {
/**
* Fetch the shuffle outputs for a given ShuffleDependency.
* @return An iterator over the elements of the fetched shuffle outputs.
*/
- def fetch[K, V](shuffleId: Int, reduceId: Int, metrics: TaskMetrics) : Iterator[(K,V)]
+ def fetch[K, V](shuffleId: Int, reduceId: Int, metrics: TaskMetrics,
+ serializer: Serializer = Serializer.default): Iterator[(K,V)]
/** Stop the fetcher */
def stop() {}
diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala
index ffb40bab3a..8ba52245fa 100644
--- a/core/src/main/scala/spark/SparkEnv.scala
+++ b/core/src/main/scala/spark/SparkEnv.scala
@@ -3,13 +3,14 @@ package spark
import akka.actor.{Actor, ActorRef, Props, ActorSystemImpl, ActorSystem}
import akka.remote.RemoteActorRefProvider
-import serializer.Serializer
import spark.broadcast.BroadcastManager
import spark.storage.BlockManager
import spark.storage.BlockManagerMaster
import spark.network.ConnectionManager
+import spark.serializer.Serializer
import spark.util.AkkaUtils
+
/**
* Holds all the runtime environment objects for a running Spark instance (either master or worker),
* including the serializer, Akka actor system, block manager, map output tracker, etc. Currently
@@ -91,8 +92,12 @@ object SparkEnv extends Logging {
Class.forName(name, true, classLoader).newInstance().asInstanceOf[T]
}
- val serializer = instantiateClass[Serializer]("spark.serializer", "spark.JavaSerializer")
-
+ val serializer = Serializer.setDefault(
+ System.getProperty("spark.serializer", "spark.JavaSerializer"))
+
+ val closureSerializer = Serializer.get(
+ System.getProperty("spark.closure.serializer", "spark.JavaSerializer"))
+
def registerOrLookup(name: String, newActor: => Actor): ActorRef = {
if (isDriver) {
logInfo("Registering " + name)
@@ -116,9 +121,6 @@ object SparkEnv extends Logging {
val broadcastManager = new BroadcastManager(isDriver)
- val closureSerializer = instantiateClass[Serializer](
- "spark.closure.serializer", "spark.JavaSerializer")
-
val cacheManager = new CacheManager(blockManager)
// Have to assign trackerActor after initialization as MapOutputTrackerActor
@@ -164,5 +166,5 @@ object SparkEnv extends Logging {
httpFileServer,
sparkFilesDir)
}
-
+
}
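
Operationally nothing changes for users: both serializers are still selected through system properties, now routed through the new Serializer registry instead of instantiateClass. A hedged sketch of how a driver would configure them before the SparkContext (and hence SparkEnv) is created; KryoSerializer as the data serializer is an assumed choice.

// Set before the SparkContext is created.
System.setProperty("spark.serializer", "spark.KryoSerializer")          // shuffle data
System.setProperty("spark.closure.serializer", "spark.JavaSerializer")  // task closures

val sc = new spark.SparkContext("local", "SerializerConfigExample")
// When SparkEnv is built it calls Serializer.setDefault(...) and Serializer.get(...)
// with exactly these property values, as shown in the hunk above.
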
diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
index a6235491ca..9e996e9958 100644
--- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
@@ -8,6 +8,7 @@ import scala.collection.mutable.ArrayBuffer
import spark.{Aggregator, Logging, Partition, Partitioner, RDD, SparkEnv, TaskContext}
import spark.{Dependency, OneToOneDependency, ShuffleDependency}
+import spark.serializer.Serializer
private[spark] sealed trait CoGroupSplitDep extends Serializable
@@ -54,7 +55,8 @@ private[spark] class CoGroupAggregator
class CoGroupedRDD[K](
@transient var rdds: Seq[RDD[(K, _)]],
part: Partitioner,
- val mapSideCombine: Boolean = true)
+ val mapSideCombine: Boolean = true,
+ val serializerClass: String = null)
extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) {
private val aggr = new CoGroupAggregator
@@ -68,9 +70,9 @@ class CoGroupedRDD[K](
logInfo("Adding shuffle dependency with " + rdd)
if (mapSideCombine) {
val mapSideCombinedRDD = rdd.mapPartitions(aggr.combineValuesByKey(_), true)
- new ShuffleDependency[Any, ArrayBuffer[Any]](mapSideCombinedRDD, part)
+ new ShuffleDependency[Any, ArrayBuffer[Any]](mapSideCombinedRDD, part, serializerClass)
} else {
- new ShuffleDependency[Any, Any](rdd.asInstanceOf[RDD[(Any, Any)]], part)
+ new ShuffleDependency[Any, Any](rdd.asInstanceOf[RDD[(Any, Any)]], part, serializerClass)
}
}
}
@@ -112,6 +114,7 @@ class CoGroupedRDD[K](
}
}
+ val ser = Serializer.get(serializerClass)
for ((dep, depNum) <- split.deps.zipWithIndex) dep match {
case NarrowCoGroupSplitDep(rdd, _, itsSplit) => {
// Read them from the parent
@@ -124,12 +127,12 @@ class CoGroupedRDD[K](
val fetcher = SparkEnv.get.shuffleFetcher
if (mapSideCombine) {
// With map side combine on, for each key, the shuffle fetcher returns a list of values.
- fetcher.fetch[K, Seq[Any]](shuffleId, split.index, context.taskMetrics).foreach {
+ fetcher.fetch[K, Seq[Any]](shuffleId, split.index, context.taskMetrics, ser).foreach {
case (key, values) => getSeq(key)(depNum) ++= values
}
} else {
// With map side combine off, for each key the shuffle fetcher returns a single value.
- fetcher.fetch[K, Any](shuffleId, split.index, context.taskMetrics).foreach {
+ fetcher.fetch[K, Any](shuffleId, split.index, context.taskMetrics, ser).foreach {
case (key, value) => getSeq(key)(depNum) += value
}
}
diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala
index 4e33b7dd5c..8175e23eff 100644
--- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala
+++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala
@@ -2,6 +2,8 @@ package spark.rdd
import spark.{Partitioner, RDD, SparkEnv, ShuffleDependency, Partition, TaskContext}
import spark.SparkContext._
+import spark.serializer.Serializer
+
private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition {
override val index = idx
@@ -12,13 +14,15 @@ private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition {
* The resulting RDD from a shuffle (e.g. repartitioning of data).
* @param prev the parent RDD.
* @param part the partitioner used to partition the RDD
+ * @param serializerClass class name of the serializer to use.
* @tparam K the key class.
* @tparam V the value class.
*/
class ShuffledRDD[K, V](
@transient prev: RDD[(K, V)],
- part: Partitioner)
- extends RDD[(K, V)](prev.context, List(new ShuffleDependency(prev, part))) {
+ part: Partitioner,
+ serializerClass: String = null)
+ extends RDD[(K, V)](prev.context, List(new ShuffleDependency(prev, part, serializerClass))) {
override val partitioner = Some(part)
@@ -28,6 +32,7 @@ class ShuffledRDD[K, V](
override def compute(split: Partition, context: TaskContext): Iterator[(K, V)] = {
val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId
- SparkEnv.get.shuffleFetcher.fetch[K, V](shuffledId, split.index, context.taskMetrics)
+ SparkEnv.get.shuffleFetcher.fetch[K, V](
+ shuffledId, split.index, context.taskMetrics, Serializer.get(serializerClass))
}
}
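
A small usage sketch of the extended three-argument ShuffledRDD constructor; the serializer class name is illustrative, and null falls back to the default serializer.

import spark.{HashPartitioner, SparkContext}
import spark.SparkContext._
import spark.rdd.ShuffledRDD

object ShuffledRDDExample {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "ShuffledRDDExample")
    val pairs = sc.parallelize(1 to 100).map(i => (i % 10, i))
    // Repartition into 4 partitions, serializing shuffle data with the named
    // serializer class (assumed to be spark.KryoSerializer).
    val shuffled = new ShuffledRDD[Int, Int](pairs, new HashPartitioner(4), "spark.KryoSerializer")
    println(shuffled.count())   // 100 records, now hash-partitioned
    sc.stop()
  }
}
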
diff --git a/core/src/main/scala/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/spark/rdd/SubtractedRDD.scala
index 481e03b349..f60c35c38e 100644
--- a/core/src/main/scala/spark/rdd/SubtractedRDD.scala
+++ b/core/src/main/scala/spark/rdd/SubtractedRDD.scala
@@ -11,6 +11,7 @@ import spark.Partition
import spark.SparkEnv
import spark.ShuffleDependency
import spark.OneToOneDependency
+import spark.serializer.Serializer
/**
* An optimized version of cogroup for set difference/subtraction.
@@ -31,7 +32,9 @@ import spark.OneToOneDependency
private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest, W: ClassManifest](
@transient var rdd1: RDD[(K, V)],
@transient var rdd2: RDD[(K, W)],
- part: Partitioner) extends RDD[(K, V)](rdd1.context, Nil) {
+ part: Partitioner,
+ val serializerClass: String = null)
+ extends RDD[(K, V)](rdd1.context, Nil) {
override def getDependencies: Seq[Dependency[_]] = {
Seq(rdd1, rdd2).map { rdd =>
@@ -40,7 +43,7 @@ private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest, W: ClassM
new OneToOneDependency(rdd)
} else {
logInfo("Adding shuffle dependency with " + rdd)
- new ShuffleDependency(rdd.asInstanceOf[RDD[(K, Any)]], part)
+ new ShuffleDependency(rdd.asInstanceOf[RDD[(K, Any)]], part, serializerClass)
}
}
}
@@ -65,6 +68,7 @@ private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest, W: ClassM
override def compute(p: Partition, context: TaskContext): Iterator[(K, V)] = {
val partition = p.asInstanceOf[CoGroupPartition]
+ val serializer = Serializer.get(serializerClass)
val map = new JHashMap[K, ArrayBuffer[V]]
def getSeq(k: K): ArrayBuffer[V] = {
val seq = map.get(k)
@@ -77,12 +81,16 @@ private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest, W: ClassM
}
}
def integrate(dep: CoGroupSplitDep, op: ((K, V)) => Unit) = dep match {
- case NarrowCoGroupSplitDep(rdd, _, itsSplit) =>
+ case NarrowCoGroupSplitDep(rdd, _, itsSplit) => {
for (t <- rdd.iterator(itsSplit, context))
op(t.asInstanceOf[(K, V)])
- case ShuffleCoGroupSplitDep(shuffleId) =>
- for (t <- SparkEnv.get.shuffleFetcher.fetch(shuffleId, partition.index, context.taskMetrics))
+ }
+ case ShuffleCoGroupSplitDep(shuffleId) => {
+ val iter = SparkEnv.get.shuffleFetcher.fetch(shuffleId, partition.index,
+ context.taskMetrics, serializer)
+ for (t <- iter)
op(t.asInstanceOf[(K, V)])
+ }
}
// the first dep is rdd1; add all values to the map
integrate(partition.deps(0), t => getSeq(t._1) += t._2)
diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
index 7dc6da4573..51ec89eb74 100644
--- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
@@ -13,9 +13,11 @@ import com.ning.compress.lzf.LZFInputStream
import com.ning.compress.lzf.LZFOutputStream
import spark._
-import executor.ShuffleWriteMetrics
+import spark.executor.ShuffleWriteMetrics
+import spark.serializer.Serializer
import spark.storage._
-import util.{TimeStampedHashMap, MetadataCleaner}
+import spark.util.{TimeStampedHashMap, MetadataCleaner}
+
private[spark] object ShuffleMapTask {
@@ -130,27 +132,33 @@ private[spark] class ShuffleMapTask(
val taskContext = new TaskContext(stageId, partition, attemptId)
metrics = Some(taskContext.taskMetrics)
try {
- // Partition the map output.
- val buckets = Array.fill(numOutputSplits)(new ArrayBuffer[(Any, Any)])
+ // Obtain all the block writers for shuffle blocks.
+ val blockManager = SparkEnv.get.blockManager
+ val buckets = Array.tabulate[BlockObjectWriter](numOutputSplits) { bucketId =>
+ val blockId = "shuffle_" + dep.shuffleId + "_" + partition + "_" + bucketId
+ blockManager.getDiskBlockWriter(blockId, Serializer.get(dep.serializerClass))
+ }
+
+ // Write the map output to its associated buckets.
for (elem <- rdd.iterator(split, taskContext)) {
val pair = elem.asInstanceOf[(Any, Any)]
val bucketId = dep.partitioner.getPartition(pair._1)
- buckets(bucketId) += pair
+ buckets(bucketId).write(pair)
}
+ // Close the bucket writers and get the sizes of each block.
val compressedSizes = new Array[Byte](numOutputSplits)
-
- var totalBytes = 0l
-
- val blockManager = SparkEnv.get.blockManager
- for (i <- 0 until numOutputSplits) {
- val blockId = "shuffle_" + dep.shuffleId + "_" + partition + "_" + i
- // Get a Scala iterator from Java map
- val iter: Iterator[(Any, Any)] = buckets(i).iterator
- val size = blockManager.put(blockId, iter, StorageLevel.DISK_ONLY, false)
+ var i = 0
+ var totalBytes = 0L
+ while (i < numOutputSplits) {
+ buckets(i).close()
+ val size = buckets(i).size()
totalBytes += size
compressedSizes(i) = MapOutputTracker.compressSize(size)
+ i += 1
}
+
+ // Update shuffle metrics.
val shuffleMetrics = new ShuffleWriteMetrics
shuffleMetrics.shuffleBytesWritten = totalBytes
metrics.get.shuffleWriteMetrics = Some(shuffleMetrics)
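
The key change above is architectural: map output is no longer buffered per bucket in ArrayBuffers and handed to blockManager.put; instead each record is streamed through one disk-backed writer per reduce partition, and the per-bucket sizes are read back after close. A generic, self-contained sketch of that shape, with java.io stand-ins replacing BlockObjectWriter and illustrative names throughout:

import java.io.{File, FileOutputStream, ObjectOutputStream}

object ShuffleWriteSketch {
  // One writer per reduce bucket, records streamed as they are produced,
  // per-bucket sizes collected after close.
  def writeShuffleOutput(
      records: Iterator[(Any, Any)],
      numOutputSplits: Int,
      getPartition: Any => Int,
      dir: File): Array[Long] = {
    val files = Array.tabulate(numOutputSplits)(i => new File(dir, "bucket_" + i))
    val writers = files.map(f => new ObjectOutputStream(new FileOutputStream(f)))
    for ((key, value) <- records) {
      writers(getPartition(key)).writeObject((key, value))
    }
    writers.foreach(_.close())
    files.map(_.length)   // analogous to reading BlockObjectWriter.size() per bucket
  }
}
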
diff --git a/core/src/main/scala/spark/serializer/Serializer.scala b/core/src/main/scala/spark/serializer/Serializer.scala
index aca86ab6f0..77b1a1a434 100644
--- a/core/src/main/scala/spark/serializer/Serializer.scala
+++ b/core/src/main/scala/spark/serializer/Serializer.scala
@@ -1,10 +1,14 @@
package spark.serializer
-import java.nio.ByteBuffer
import java.io.{EOFException, InputStream, OutputStream}
+import java.nio.ByteBuffer
+import java.util.concurrent.ConcurrentHashMap
+
import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
+
import spark.util.ByteBufferInputStream
+
/**
* A serializer. Because some serialization libraries are not thread safe, this class is used to
* create [[spark.serializer.SerializerInstance]] objects that do the actual serialization and are
@@ -14,6 +18,48 @@ trait Serializer {
def newInstance(): SerializerInstance
}
+
+/**
+ * A singleton object that can be used to fetch serializer objects based on the serializer
+ * class name. If a previous instance of the serializer object has been created, the get
+ * method returns that instead of creating a new one.
+ */
+object Serializer {
+
+ private val serializers = new ConcurrentHashMap[String, Serializer]
+ private var _default: Serializer = _
+
+ def default = _default
+
+ def setDefault(clsName: String): Serializer = {
+ _default = get(clsName)
+ _default
+ }
+
+ def get(clsName: String): Serializer = {
+ if (clsName == null) {
+ default
+ } else {
+ var serializer = serializers.get(clsName)
+ if (serializer != null) {
+ // If the serializer has been created previously, reuse that.
+ serializer
+ } else this.synchronized {
+ // Otherwise, create a new one. But make sure no other thread has attempted
+ // to create another new one at the same time.
+ serializer = serializers.get(clsName)
+ if (serializer == null) {
+ val clsLoader = Thread.currentThread.getContextClassLoader
+ serializer = Class.forName(clsName, true, clsLoader).newInstance().asInstanceOf[Serializer]
+ serializers.put(clsName, serializer)
+ }
+ serializer
+ }
+ }
+ }
+}
+
+
/**
* An instance of a serializer, for use by one thread at a time.
*/
@@ -45,6 +91,7 @@ trait SerializerInstance {
}
}
+
/**
* A stream for writing serialized objects.
*/
@@ -61,6 +108,7 @@ trait SerializationStream {
}
}
+
/**
* A stream for reading serialized objects.
*/
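
The new Serializer companion object above is essentially a by-class-name instance cache with double-checked creation: the fast path is a lock-free ConcurrentHashMap read, and only a first-time miss takes the lock and re-checks before instantiating. A generic restatement of the pattern (names here are illustrative, not part of the patch):

import java.util.concurrent.ConcurrentHashMap

object InstanceRegistry {
  private val instances = new ConcurrentHashMap[String, AnyRef]

  def get(clsName: String): AnyRef = {
    val existing = instances.get(clsName)   // lock-free fast path
    if (existing != null) {
      existing
    } else this.synchronized {
      // Re-check under the lock so two racing callers end up sharing one instance.
      var created = instances.get(clsName)
      if (created == null) {
        val loader = Thread.currentThread.getContextClassLoader
        created = Class.forName(clsName, true, loader).newInstance().asInstanceOf[AnyRef]
        instances.put(clsName, created)
      }
      created
    }
  }
}

ConcurrentHashMap.putIfAbsent alone would also be safe, but it could construct a throwaway instance when two threads race; taking the lock only on the rare first load avoids that.
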
diff --git a/core/src/main/scala/spark/storage/BlockException.scala b/core/src/main/scala/spark/storage/BlockException.scala
new file mode 100644
index 0000000000..f275d476df
--- /dev/null
+++ b/core/src/main/scala/spark/storage/BlockException.scala
@@ -0,0 +1,5 @@
+package spark.storage
+
+private[spark]
+case class BlockException(blockId: String, message: String) extends Exception(message)
+
diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala
index 6e861ac734..9190c96c71 100644
--- a/core/src/main/scala/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/spark/storage/BlockManager.scala
@@ -25,15 +25,11 @@ import sun.nio.ch.DirectBuffer
private[spark]
-case class BlockException(blockId: String, message: String, ex: Exception = null)
-extends Exception(message)
-
-private[spark]
class BlockManager(
executorId: String,
actorSystem: ActorSystem,
val master: BlockManagerMaster,
- val serializer: Serializer,
+ val defaultSerializer: Serializer,
maxMemory: Long)
extends Logging {
@@ -95,7 +91,7 @@ class BlockManager(
private val blockInfo = new TimeStampedHashMap[String, BlockInfo]
private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory)
- private[storage] val diskStore: BlockStore =
+ private[storage] val diskStore: DiskStore =
new DiskStore(this, System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")))
val connectionManager = new ConnectionManager(0)
@@ -293,22 +289,23 @@ class BlockManager(
}
/**
+ * A short-circuited method to get blocks directly from disk. This is used for getting
+ * shuffle blocks. It is safe to do so without a lock on block info since disk store
+ * never deletes (recent) items.
+ */
+ def getLocalFromDisk(blockId: String, serializer: Serializer): Option[Iterator[Any]] = {
+ diskStore.getValues(blockId, serializer) match {
+ case Some(iterator) => Some(iterator)
+ case None =>
+ throw new Exception("Block " + blockId + " not found on disk, though it should be")
+ }
+ }
+
+ /**
* Get block from local block manager.
*/
def getLocal(blockId: String): Option[Iterator[Any]] = {
logDebug("Getting local block " + blockId)
-
- // As an optimization for map output fetches, if the block is for a shuffle, return it
- // without acquiring a lock; the disk store never deletes (recent) items so this should work
- if (blockId.startsWith("shuffle_")) {
- return diskStore.getValues(blockId) match {
- case Some(iterator) =>
- Some(iterator)
- case None =>
- throw new Exception("Block " + blockId + " not found on disk, though it should be")
- }
- }
-
val info = blockInfo.get(blockId).orNull
if (info != null) {
info.synchronized {
@@ -496,9 +493,10 @@ class BlockManager(
* fashion as they're received. Expects a size in bytes to be provided for each block fetched,
* so that we can control the maxMegabytesInFlight for the fetch.
*/
- def getMultiple(blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])])
+ def getMultiple(
+ blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], serializer: Serializer)
: BlockFetcherIterator = {
- return new BlockFetcherIterator(this, blocksByAddress)
+ return new BlockFetcherIterator(this, blocksByAddress, serializer)
}
def put(blockId: String, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean)
@@ -509,6 +507,21 @@ class BlockManager(
}
/**
+ * A short circuited method to get a block writer that can write data directly to disk.
+ * This is currently used for writing shuffle files out.
+ */
+ def getDiskBlockWriter(blockId: String, serializer: Serializer): BlockObjectWriter = {
+ val writer = diskStore.getBlockWriter(blockId, serializer)
+ writer.registerCloseEventHandler(() => {
+ // TODO(rxin): This doesn't handle error cases.
+ val myInfo = new BlockInfo(StorageLevel.DISK_ONLY, false)
+ blockInfo.put(blockId, myInfo)
+ myInfo.markReady(writer.size())
+ })
+ writer
+ }
+
+ /**
* Put a new block of values to the block manager. Returns its (estimated) size in bytes.
*/
def put(blockId: String, values: ArrayBuffer[Any], level: StorageLevel,
@@ -607,7 +620,6 @@ class BlockManager(
}
logDebug("Put block " + blockId + " locally took " + Utils.getUsedTimeMs(startTimeMs))
-
// Replicate block if required
if (level.replication > 1) {
val remoteStartTime = System.currentTimeMillis
@@ -885,7 +897,10 @@ class BlockManager(
if (shouldCompress(blockId)) new LZFInputStream(s) else s
}
- def dataSerialize(blockId: String, values: Iterator[Any]): ByteBuffer = {
+ def dataSerialize(
+ blockId: String,
+ values: Iterator[Any],
+ serializer: Serializer = defaultSerializer): ByteBuffer = {
val byteStream = new FastByteArrayOutputStream(4096)
val ser = serializer.newInstance()
ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()
@@ -897,7 +912,10 @@ class BlockManager(
* Deserializes a ByteBuffer into an iterator of values and disposes of it when the end of
* the iterator is reached.
*/
- def dataDeserialize(blockId: String, bytes: ByteBuffer): Iterator[Any] = {
+ def dataDeserialize(
+ blockId: String,
+ bytes: ByteBuffer,
+ serializer: Serializer = defaultSerializer): Iterator[Any] = {
bytes.rewind()
val stream = wrapForCompression(blockId, new ByteBufferInputStream(bytes, true))
serializer.newInstance().deserializeStream(stream).asIterator
@@ -951,7 +969,8 @@ object BlockManager extends Logging {
class BlockFetcherIterator(
private val blockManager: BlockManager,
- val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])]
+ val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])],
+ serializer: Serializer
) extends Iterator[(String, Option[Iterator[Any]])] with Logging with BlockFetchTracker {
import blockManager._
@@ -1014,8 +1033,8 @@ class BlockFetcherIterator(
"Unexpected message " + blockMessage.getType + " received from " + cmId)
}
val blockId = blockMessage.getId
- results.put(new FetchResult(
- blockId, sizeMap(blockId), () => dataDeserialize(blockId, blockMessage.getData)))
+ results.put(new FetchResult(blockId, sizeMap(blockId),
+ () => dataDeserialize(blockId, blockMessage.getData, serializer)))
_remoteBytesRead += req.size
logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
}
@@ -1079,7 +1098,7 @@ class BlockFetcherIterator(
// any memory that might exceed our maxBytesInFlight
startTime = System.currentTimeMillis
for (id <- localBlockIds) {
- getLocal(id) match {
+ getLocalFromDisk(id, serializer) match {
case Some(iter) => {
results.put(new FetchResult(id, 0, () => iter)) // Pass 0 as size since it's not in flight
logDebug("Got local block " + id)
diff --git a/core/src/main/scala/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/spark/storage/BlockManagerWorker.scala
index d2985559c1..15225f93a6 100644
--- a/core/src/main/scala/spark/storage/BlockManagerWorker.scala
+++ b/core/src/main/scala/spark/storage/BlockManagerWorker.scala
@@ -19,7 +19,7 @@ import spark.network._
*/
private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends Logging {
initLogging()
-
+
blockManager.connectionManager.onReceiveMessage(onBlockMessageReceive)
def onBlockMessageReceive(msg: Message, id: ConnectionManagerId): Option[Message] = {
@@ -51,7 +51,7 @@ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends
logDebug("Received [" + pB + "]")
putBlock(pB.id, pB.data, pB.level)
return None
- }
+ }
case BlockMessage.TYPE_GET_BLOCK => {
val gB = new GetBlock(blockMessage.getId)
logDebug("Received [" + gB + "]")
@@ -90,28 +90,26 @@ private[spark] object BlockManagerWorker extends Logging {
private var blockManagerWorker: BlockManagerWorker = null
private val DATA_TRANSFER_TIME_OUT_MS: Long = 500
private val REQUEST_RETRY_INTERVAL_MS: Long = 1000
-
+
initLogging()
-
+
def startBlockManagerWorker(manager: BlockManager) {
blockManagerWorker = new BlockManagerWorker(manager)
}
-
+
def syncPutBlock(msg: PutBlock, toConnManagerId: ConnectionManagerId): Boolean = {
val blockManager = blockManagerWorker.blockManager
- val connectionManager = blockManager.connectionManager
- val serializer = blockManager.serializer
+ val connectionManager = blockManager.connectionManager
val blockMessage = BlockMessage.fromPutBlock(msg)
val blockMessageArray = new BlockMessageArray(blockMessage)
val resultMessage = connectionManager.sendMessageReliablySync(
toConnManagerId, blockMessageArray.toBufferMessage)
return (resultMessage != None)
}
-
+
def syncGetBlock(msg: GetBlock, toConnManagerId: ConnectionManagerId): ByteBuffer = {
val blockManager = blockManagerWorker.blockManager
- val connectionManager = blockManager.connectionManager
- val serializer = blockManager.serializer
+ val connectionManager = blockManager.connectionManager
val blockMessage = BlockMessage.fromGetBlock(msg)
val blockMessageArray = new BlockMessageArray(blockMessage)
val responseMessage = connectionManager.sendMessageReliablySync(
diff --git a/core/src/main/scala/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/spark/storage/BlockObjectWriter.scala
new file mode 100644
index 0000000000..657a7e9143
--- /dev/null
+++ b/core/src/main/scala/spark/storage/BlockObjectWriter.scala
@@ -0,0 +1,27 @@
+package spark.storage
+
+import java.nio.ByteBuffer
+
+
+abstract class BlockObjectWriter(val blockId: String) {
+
+ // TODO(rxin): What if there is an exception when the block is being written out?
+
+ var closeEventHandler: () => Unit = _
+
+ def registerCloseEventHandler(handler: () => Unit) {
+ closeEventHandler = handler
+ }
+
+ def write(value: Any)
+
+ def writeAll(value: Iterator[Any]) {
+ value.foreach(write)
+ }
+
+ def close() {
+ closeEventHandler()
+ }
+
+ def size(): Long
+}
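
BlockObjectWriter is deliberately small: subclasses supply write and size, writeAll and close come for free, and close notifies whoever registered a handler (BlockManager uses that hook to record the finished block). A throwaway subclass, not part of the patch, to illustrate the contract:

// Counts bytes instead of writing anything to disk; illustrative only.
class CountingBlockObjectWriter(blockId: String) extends BlockObjectWriter(blockId) {
  private var written = 0L

  override def write(value: Any) {
    written += value.toString.getBytes("UTF-8").length
  }

  override def size(): Long = written
}

val w = new CountingBlockObjectWriter("example_block")
w.registerCloseEventHandler(() => println("closed " + w.blockId))
w.writeAll(Iterator("a", "bb", "ccc"))
w.close()          // runs the handler registered above
println(w.size())  // 6
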
diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala
index c9553d2e0f..b527a3c708 100644
--- a/core/src/main/scala/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/spark/storage/DiskStore.scala
@@ -1,18 +1,19 @@
package spark.storage
+import java.io.{File, FileOutputStream, OutputStream, RandomAccessFile}
import java.nio.ByteBuffer
-import java.io.{File, FileOutputStream, RandomAccessFile}
import java.nio.channels.FileChannel.MapMode
import java.util.{Random, Date}
import java.text.SimpleDateFormat
-import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
-
import scala.collection.mutable.ArrayBuffer
-import spark.executor.ExecutorExitCode
+import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
import spark.Utils
+import spark.executor.ExecutorExitCode
+import spark.serializer.Serializer
+
/**
* Stores BlockManager blocks on disk.
@@ -23,6 +24,34 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
private val mapMode = MapMode.READ_ONLY
private var mapOpenMode = "r"
+ class DiskBlockObjectWriter(blockId: String, serializer: Serializer)
+ extends BlockObjectWriter(blockId) {
+
+ private val f: File = createFile(blockId /*, allowAppendExisting */)
+ private val bs: OutputStream = blockManager.wrapForCompression(blockId,
+ new FastBufferedOutputStream(new FileOutputStream(f)))
+ private val objOut = serializer.newInstance().serializeStream(bs)
+
+ private var _size: Long = -1L
+
+ override def write(value: Any) {
+ objOut.writeObject(value)
+ }
+
+ override def close() {
+ objOut.close()
+ bs.close()
+ super.close()
+ }
+
+ override def size(): Long = {
+ if (_size < 0) {
+ _size = f.length()
+ }
+ _size
+ }
+ }
+
val MAX_DIR_CREATION_ATTEMPTS: Int = 10
val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt
@@ -34,6 +63,10 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
addShutdownHook()
+ def getBlockWriter(blockId: String, serializer: Serializer): BlockObjectWriter = {
+ new DiskBlockObjectWriter(blockId, serializer)
+ }
+
override def getSize(blockId: String): Long = {
getFile(blockId).length()
}
@@ -79,12 +112,14 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
val file = createFile(blockId)
val fileOut = blockManager.wrapForCompression(blockId,
new FastBufferedOutputStream(new FileOutputStream(file)))
- val objOut = blockManager.serializer.newInstance().serializeStream(fileOut)
+ val objOut = blockManager.defaultSerializer.newInstance().serializeStream(fileOut)
objOut.writeAll(values.iterator)
objOut.close()
val length = file.length()
+
+ val timeTaken = System.currentTimeMillis - startTime
logDebug("Block %s stored as %s file on disk in %d ms".format(
- blockId, Utils.memoryBytesToString(length), (System.currentTimeMillis - startTime)))
+ blockId, Utils.memoryBytesToString(length), timeTaken))
if (returnValues) {
// Return a byte buffer for the contents of the file
@@ -105,6 +140,14 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes))
}
+ /**
+ * A version of getValues that allows a custom serializer. This is used as part of the
+ * shuffle short-circuit code.
+ */
+ def getValues(blockId: String, serializer: Serializer): Option[Iterator[Any]] = {
+ getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes, serializer))
+ }
+
override def remove(blockId: String): Boolean = {
val file = getFile(blockId)
if (file.exists()) {
@@ -118,9 +161,9 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
getFile(blockId).exists()
}
- private def createFile(blockId: String): File = {
+ private def createFile(blockId: String, allowAppendExisting: Boolean = false): File = {
val file = getFile(blockId)
- if (file.exists()) {
+ if (!allowAppendExisting && file.exists()) {
throw new Exception("File for block " + blockId + " already exists on disk: " + file)
}
file
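
The DiskBlockObjectWriter added near the top of this file's diff is mostly stream plumbing: a file stream, buffering, optional compression via blockManager.wrapForCompression, and the serializer's serializeStream on top, with size() read lazily from the file after close. A self-contained approximation using java.io stand-ins for the Spark-specific pieces:

import java.io.{BufferedOutputStream, File, FileOutputStream, ObjectOutputStream}

object DiskWriterSketch {
  def main(args: Array[String]) {
    val file = File.createTempFile("block", ".bin")
    // file -> buffering -> serialization stream; the real writer additionally
    // inserts a compression stream and uses the pluggable Serializer.
    val bs = new BufferedOutputStream(new FileOutputStream(file))
    val objOut = new ObjectOutputStream(bs)
    objOut.writeObject(("key", 42))
    objOut.close()            // closes the wrapped streams as well
    println(file.length())    // what DiskBlockObjectWriter.size() reports
  }
}
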