author    Reynold Xin <rxin@cs.berkeley.edu>  2013-04-24 14:52:49 -0700
committer Reynold Xin <rxin@cs.berkeley.edu>  2013-04-24 14:52:49 -0700
commit    aa618ed2a2df209da3f93a025928366959c37d04 (patch)
tree      3063eaa58f1f42d4ef0cd012f7cc201dbd6f5096
parent    31ce6c66d6f29302d0f0f2c70e494fad0ba71e4d (diff)
Allow changing the serializer on a per shuffle basis.
-rw-r--r--  core/src/main/scala/spark/BlockStoreShuffleFetcher.scala  | 13
-rw-r--r--  core/src/main/scala/spark/Dependency.scala                 |  4
-rw-r--r--  core/src/main/scala/spark/PairRDDFunctions.scala           | 11
-rw-r--r--  core/src/main/scala/spark/ShuffleFetcher.scala             |  7
-rw-r--r--  core/src/main/scala/spark/SparkEnv.scala                   | 20
-rw-r--r--  core/src/main/scala/spark/rdd/CoGroupedRDD.scala           | 13
-rw-r--r--  core/src/main/scala/spark/rdd/ShuffledRDD.scala            | 11
-rw-r--r--  core/src/main/scala/spark/rdd/SubtractedRDD.scala          | 18
-rw-r--r--  core/src/main/scala/spark/scheduler/ShuffleMapTask.scala   |  8
-rw-r--r--  core/src/main/scala/spark/serializer/Serializer.scala      | 50
-rw-r--r--  core/src/main/scala/spark/storage/BlockManager.scala       | 42
-rw-r--r--  core/src/main/scala/spark/storage/BlockManagerWorker.scala | 18
-rw-r--r--  core/src/main/scala/spark/storage/DiskStore.scala          | 18
-rw-r--r--  core/src/main/scala/spark/storage/ThreadingTest.scala      |  2
14 files changed, 153 insertions, 82 deletions
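
The patch plumbs an optional serializer class name through every shuffle-producing code path, so a single job can mix serializers across shuffles instead of relying on one global setting. As a rough usage sketch (not part of this patch), assuming spark.KryoSerializer is on the classpath and a local SparkContext, a caller could now pick Kryo for one combineByKey while everything else keeps the default:

    import spark.SparkContext
    import spark.SparkContext._
    import spark.HashPartitioner

    val sc = new SparkContext("local", "per-shuffle-serializer-sketch")
    val pairs = sc.parallelize(Seq(("a", 1), ("b", 2), ("a", 3)))

    // The new trailing parameter names the serializer class for this shuffle only;
    // leaving it null falls back to the serializer configured in SparkEnv.
    val sums = pairs.combineByKey[Int](
      (v: Int) => v,
      (c: Int, v: Int) => c + v,
      (c1: Int, c2: Int) => c1 + c2,
      new HashPartitioner(2),
      mapSideCombine = true,
      serializerClass = "spark.KryoSerializer")
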
diff --git a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
index c27ed36406..2156efbd45 100644
--- a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
+++ b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
@@ -1,14 +1,19 @@
package spark
-import executor.{ShuffleReadMetrics, TaskMetrics}
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
+import spark.executor.{ShuffleReadMetrics, TaskMetrics}
+import spark.serializer.Serializer
import spark.storage.{DelegateBlockFetchTracker, BlockManagerId}
-import util.{CompletionIterator, TimedIterator}
+import spark.util.{CompletionIterator, TimedIterator}
+
private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging {
- override def fetch[K, V](shuffleId: Int, reduceId: Int, metrics: TaskMetrics) = {
+
+ override def fetch[K, V](
+ shuffleId: Int, reduceId: Int, metrics: TaskMetrics, serializer: Serializer) = {
+
logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
val blockManager = SparkEnv.get.blockManager
@@ -48,7 +53,7 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
}
}
- val blockFetcherItr = blockManager.getMultiple(blocksByAddress)
+ val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer)
val itr = new TimedIterator(blockFetcherItr.flatMap(unpackBlock)) with DelegateBlockFetchTracker
itr.setDelegate(blockFetcherItr)
CompletionIterator[(K,V), Iterator[(K,V)]](itr, {
diff --git a/core/src/main/scala/spark/Dependency.scala b/core/src/main/scala/spark/Dependency.scala
index 5eea907322..2af44aa383 100644
--- a/core/src/main/scala/spark/Dependency.scala
+++ b/core/src/main/scala/spark/Dependency.scala
@@ -25,10 +25,12 @@ abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) {
* @param shuffleId the shuffle id
* @param rdd the parent RDD
* @param partitioner partitioner used to partition the shuffle output
+ * @param serializerClass class name of the serializer to use
*/
class ShuffleDependency[K, V](
@transient rdd: RDD[(K, V)],
- val partitioner: Partitioner)
+ val partitioner: Partitioner,
+ val serializerClass: String = null)
extends Dependency(rdd) {
val shuffleId: Int = rdd.context.newShuffleId()
diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala
index 07efba9e8d..1b9b9d21d8 100644
--- a/core/src/main/scala/spark/PairRDDFunctions.scala
+++ b/core/src/main/scala/spark/PairRDDFunctions.scala
@@ -52,7 +52,8 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
mergeValue: (C, V) => C,
mergeCombiners: (C, C) => C,
partitioner: Partitioner,
- mapSideCombine: Boolean = true): RDD[(K, C)] = {
+ mapSideCombine: Boolean = true,
+ serializerClass: String = null): RDD[(K, C)] = {
if (getKeyClass().isArray) {
if (mapSideCombine) {
throw new SparkException("Cannot use map-side combining with array keys.")
@@ -67,13 +68,13 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
self.mapPartitions(aggregator.combineValuesByKey(_), true)
} else if (mapSideCombine) {
val mapSideCombined = self.mapPartitions(aggregator.combineValuesByKey(_), true)
- val partitioned = new ShuffledRDD[K, C](mapSideCombined, partitioner)
+ val partitioned = new ShuffledRDD[K, C](mapSideCombined, partitioner, serializerClass)
partitioned.mapPartitions(aggregator.combineCombinersByKey(_), true)
} else {
// Don't apply map-side combiner.
// A sanity check to make sure mergeCombiners is not defined.
assert(mergeCombiners == null)
- val values = new ShuffledRDD[K, V](self, partitioner)
+ val values = new ShuffledRDD[K, V](self, partitioner, serializerClass)
values.mapPartitions(aggregator.combineValuesByKey(_), true)
}
}
@@ -469,7 +470,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
/**
* Return an RDD with the pairs from `this` whose keys are not in `other`.
- *
+ *
* Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
* RDD will be <= us.
*/
@@ -644,7 +645,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
* Return an RDD with the keys of each tuple.
*/
def keys: RDD[K] = self.map(_._1)
-
+
/**
* Return an RDD with the values of each tuple.
*/
diff --git a/core/src/main/scala/spark/ShuffleFetcher.scala b/core/src/main/scala/spark/ShuffleFetcher.scala
index 442e9f0269..49addc0c10 100644
--- a/core/src/main/scala/spark/ShuffleFetcher.scala
+++ b/core/src/main/scala/spark/ShuffleFetcher.scala
@@ -1,13 +1,16 @@
package spark
-import executor.TaskMetrics
+import spark.executor.TaskMetrics
+import spark.serializer.Serializer
+
private[spark] abstract class ShuffleFetcher {
/**
* Fetch the shuffle outputs for a given ShuffleDependency.
* @return An iterator over the elements of the fetched shuffle outputs.
*/
- def fetch[K, V](shuffleId: Int, reduceId: Int, metrics: TaskMetrics) : Iterator[(K,V)]
+ def fetch[K, V](shuffleId: Int, reduceId: Int, metrics: TaskMetrics,
+ serializer: Serializer = Serializer.default): Iterator[(K,V)]
/** Stop the fetcher */
def stop() {}
diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala
index c10bedb8f6..8a751fbd6e 100644
--- a/core/src/main/scala/spark/SparkEnv.scala
+++ b/core/src/main/scala/spark/SparkEnv.scala
@@ -3,13 +3,14 @@ package spark
import akka.actor.{Actor, ActorRef, Props, ActorSystemImpl, ActorSystem}
import akka.remote.RemoteActorRefProvider
-import serializer.Serializer
import spark.broadcast.BroadcastManager
import spark.storage.BlockManager
import spark.storage.BlockManagerMaster
import spark.network.ConnectionManager
+import spark.serializer.Serializer
import spark.util.AkkaUtils
+
/**
* Holds all the runtime environment objects for a running Spark instance (either master or worker),
* including the serializer, Akka actor system, block manager, map output tracker, etc. Currently
@@ -22,7 +23,6 @@ class SparkEnv (
val actorSystem: ActorSystem,
val serializer: Serializer,
val closureSerializer: Serializer,
- val shuffleSerializer: Serializer,
val cacheManager: CacheManager,
val mapOutputTracker: MapOutputTracker,
val shuffleFetcher: ShuffleFetcher,
@@ -82,7 +82,11 @@ object SparkEnv extends Logging {
Class.forName(name, true, classLoader).newInstance().asInstanceOf[T]
}
- val serializer = instantiateClass[Serializer]("spark.serializer", "spark.JavaSerializer")
+ val serializer = Serializer.setDefault(
+ System.getProperty("spark.serializer", "spark.JavaSerializer"))
+
+ val closureSerializer = Serializer.get(
+ System.getProperty("spark.closure.serializer", "spark.JavaSerializer"))
def registerOrLookup(name: String, newActor: => Actor): ActorRef = {
if (isDriver) {
@@ -97,17 +101,10 @@ object SparkEnv extends Logging {
}
}
- val closureSerializer = instantiateClass[Serializer](
- "spark.closure.serializer", "spark.JavaSerializer")
-
- val shuffleSerializer = instantiateClass[Serializer](
- "spark.shuffle.serializer", "spark.JavaSerializer")
-
val blockManagerMaster = new BlockManagerMaster(registerOrLookup(
"BlockManagerMaster",
new spark.storage.BlockManagerMasterActor(isLocal)))
- val blockManager = new BlockManager(
- executorId, actorSystem, blockManagerMaster, serializer, shuffleSerializer)
+ val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, serializer)
val connectionManager = blockManager.connectionManager
@@ -149,7 +146,6 @@ object SparkEnv extends Logging {
actorSystem,
serializer,
closureSerializer,
- shuffleSerializer,
cacheManager,
mapOutputTracker,
shuffleFetcher,
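
With this change SparkEnv resolves both the default and the closure serializer through the new Serializer registry from system properties. A hedged driver-side sketch of selecting them (the property names come from the code above; the Kryo class name is an assumption):

    // Set before the SparkContext/SparkEnv is created so
    // SparkEnv.createFromSystemProperties picks the values up.
    System.setProperty("spark.serializer", "spark.KryoSerializer")
    System.setProperty("spark.closure.serializer", "spark.JavaSerializer")
    val sc = new spark.SparkContext("local", "serializer-properties-sketch")
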
diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
index a6235491ca..9e996e9958 100644
--- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
@@ -8,6 +8,7 @@ import scala.collection.mutable.ArrayBuffer
import spark.{Aggregator, Logging, Partition, Partitioner, RDD, SparkEnv, TaskContext}
import spark.{Dependency, OneToOneDependency, ShuffleDependency}
+import spark.serializer.Serializer
private[spark] sealed trait CoGroupSplitDep extends Serializable
@@ -54,7 +55,8 @@ private[spark] class CoGroupAggregator
class CoGroupedRDD[K](
@transient var rdds: Seq[RDD[(K, _)]],
part: Partitioner,
- val mapSideCombine: Boolean = true)
+ val mapSideCombine: Boolean = true,
+ val serializerClass: String = null)
extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) {
private val aggr = new CoGroupAggregator
@@ -68,9 +70,9 @@ class CoGroupedRDD[K](
logInfo("Adding shuffle dependency with " + rdd)
if (mapSideCombine) {
val mapSideCombinedRDD = rdd.mapPartitions(aggr.combineValuesByKey(_), true)
- new ShuffleDependency[Any, ArrayBuffer[Any]](mapSideCombinedRDD, part)
+ new ShuffleDependency[Any, ArrayBuffer[Any]](mapSideCombinedRDD, part, serializerClass)
} else {
- new ShuffleDependency[Any, Any](rdd.asInstanceOf[RDD[(Any, Any)]], part)
+ new ShuffleDependency[Any, Any](rdd.asInstanceOf[RDD[(Any, Any)]], part, serializerClass)
}
}
}
@@ -112,6 +114,7 @@ class CoGroupedRDD[K](
}
}
+ val ser = Serializer.get(serializerClass)
for ((dep, depNum) <- split.deps.zipWithIndex) dep match {
case NarrowCoGroupSplitDep(rdd, _, itsSplit) => {
// Read them from the parent
@@ -124,12 +127,12 @@ class CoGroupedRDD[K](
val fetcher = SparkEnv.get.shuffleFetcher
if (mapSideCombine) {
// With map side combine on, for each key, the shuffle fetcher returns a list of values.
- fetcher.fetch[K, Seq[Any]](shuffleId, split.index, context.taskMetrics).foreach {
+ fetcher.fetch[K, Seq[Any]](shuffleId, split.index, context.taskMetrics, ser).foreach {
case (key, values) => getSeq(key)(depNum) ++= values
}
} else {
// With map side combine off, for each key the shuffle fetcher returns a single value.
- fetcher.fetch[K, Any](shuffleId, split.index, context.taskMetrics).foreach {
+ fetcher.fetch[K, Any](shuffleId, split.index, context.taskMetrics, ser).foreach {
case (key, value) => getSeq(key)(depNum) += value
}
}
diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala
index 4e33b7dd5c..8175e23eff 100644
--- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala
+++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala
@@ -2,6 +2,8 @@ package spark.rdd
import spark.{Partitioner, RDD, SparkEnv, ShuffleDependency, Partition, TaskContext}
import spark.SparkContext._
+import spark.serializer.Serializer
+
private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition {
override val index = idx
@@ -12,13 +14,15 @@ private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition {
* The resulting RDD from a shuffle (e.g. repartitioning of data).
* @param prev the parent RDD.
* @param part the partitioner used to partition the RDD
+ * @param serializerClass class name of the serializer to use.
* @tparam K the key class.
* @tparam V the value class.
*/
class ShuffledRDD[K, V](
@transient prev: RDD[(K, V)],
- part: Partitioner)
- extends RDD[(K, V)](prev.context, List(new ShuffleDependency(prev, part))) {
+ part: Partitioner,
+ serializerClass: String = null)
+ extends RDD[(K, V)](prev.context, List(new ShuffleDependency(prev, part, serializerClass))) {
override val partitioner = Some(part)
@@ -28,6 +32,7 @@ class ShuffledRDD[K, V](
override def compute(split: Partition, context: TaskContext): Iterator[(K, V)] = {
val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId
- SparkEnv.get.shuffleFetcher.fetch[K, V](shuffledId, split.index, context.taskMetrics)
+ SparkEnv.get.shuffleFetcher.fetch[K, V](
+ shuffledId, split.index, context.taskMetrics, Serializer.get(serializerClass))
}
}
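
With the extra constructor argument, a ShuffledRDD carries its serializer class name into the ShuffleDependency it creates and passes Serializer.get(serializerClass) to the fetcher at compute time. A minimal construction sketch, assuming an existing pairs: RDD[(String, Int)] and that spark.KryoSerializer is available:

    import spark.HashPartitioner
    import spark.rdd.ShuffledRDD

    // The third argument is the new optional serializer class name; null means
    // "use the default serializer from SparkEnv".
    val repartitioned =
      new ShuffledRDD[String, Int](pairs, new HashPartitioner(4), "spark.KryoSerializer")
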
diff --git a/core/src/main/scala/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/spark/rdd/SubtractedRDD.scala
index 481e03b349..f60c35c38e 100644
--- a/core/src/main/scala/spark/rdd/SubtractedRDD.scala
+++ b/core/src/main/scala/spark/rdd/SubtractedRDD.scala
@@ -11,6 +11,7 @@ import spark.Partition
import spark.SparkEnv
import spark.ShuffleDependency
import spark.OneToOneDependency
+import spark.serializer.Serializer
/**
* An optimized version of cogroup for set difference/subtraction.
@@ -31,7 +32,9 @@ import spark.OneToOneDependency
private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest, W: ClassManifest](
@transient var rdd1: RDD[(K, V)],
@transient var rdd2: RDD[(K, W)],
- part: Partitioner) extends RDD[(K, V)](rdd1.context, Nil) {
+ part: Partitioner,
+ val serializerClass: String = null)
+ extends RDD[(K, V)](rdd1.context, Nil) {
override def getDependencies: Seq[Dependency[_]] = {
Seq(rdd1, rdd2).map { rdd =>
@@ -40,7 +43,7 @@ private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest, W: ClassM
new OneToOneDependency(rdd)
} else {
logInfo("Adding shuffle dependency with " + rdd)
- new ShuffleDependency(rdd.asInstanceOf[RDD[(K, Any)]], part)
+ new ShuffleDependency(rdd.asInstanceOf[RDD[(K, Any)]], part, serializerClass)
}
}
}
@@ -65,6 +68,7 @@ private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest, W: ClassM
override def compute(p: Partition, context: TaskContext): Iterator[(K, V)] = {
val partition = p.asInstanceOf[CoGroupPartition]
+ val serializer = Serializer.get(serializerClass)
val map = new JHashMap[K, ArrayBuffer[V]]
def getSeq(k: K): ArrayBuffer[V] = {
val seq = map.get(k)
@@ -77,12 +81,16 @@ private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest, W: ClassM
}
}
def integrate(dep: CoGroupSplitDep, op: ((K, V)) => Unit) = dep match {
- case NarrowCoGroupSplitDep(rdd, _, itsSplit) =>
+ case NarrowCoGroupSplitDep(rdd, _, itsSplit) => {
for (t <- rdd.iterator(itsSplit, context))
op(t.asInstanceOf[(K, V)])
- case ShuffleCoGroupSplitDep(shuffleId) =>
- for (t <- SparkEnv.get.shuffleFetcher.fetch(shuffleId, partition.index, context.taskMetrics))
+ }
+ case ShuffleCoGroupSplitDep(shuffleId) => {
+ val iter = SparkEnv.get.shuffleFetcher.fetch(shuffleId, partition.index,
+ context.taskMetrics, serializer)
+ for (t <- iter)
op(t.asInstanceOf[(K, V)])
+ }
}
// the first dep is rdd1; add all values to the map
integrate(partition.deps(0), t => getSeq(t._1) += t._2)
diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
index 97b668cd58..d9b26c9db9 100644
--- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
@@ -13,9 +13,11 @@ import com.ning.compress.lzf.LZFInputStream
import com.ning.compress.lzf.LZFOutputStream
import spark._
-import executor.ShuffleWriteMetrics
+import spark.executor.ShuffleWriteMetrics
+import spark.serializer.Serializer
import spark.storage._
-import util.{TimeStampedHashMap, MetadataCleaner}
+import spark.util.{TimeStampedHashMap, MetadataCleaner}
+
private[spark] object ShuffleMapTask {
@@ -126,7 +128,7 @@ private[spark] class ShuffleMapTask(
val blockManager = SparkEnv.get.blockManager
val buckets = Array.tabulate[BlockObjectWriter](numOutputSplits) { bucketId =>
val blockId = "shuffle_" + dep.shuffleId + "_" + partition + "_" + bucketId
- blockManager.getBlockWriter(blockId)
+ blockManager.getBlockWriter(blockId, Serializer.get(dep.serializerClass))
}
// Write the map output to its associated buckets.
diff --git a/core/src/main/scala/spark/serializer/Serializer.scala b/core/src/main/scala/spark/serializer/Serializer.scala
index aca86ab6f0..77b1a1a434 100644
--- a/core/src/main/scala/spark/serializer/Serializer.scala
+++ b/core/src/main/scala/spark/serializer/Serializer.scala
@@ -1,10 +1,14 @@
package spark.serializer
-import java.nio.ByteBuffer
import java.io.{EOFException, InputStream, OutputStream}
+import java.nio.ByteBuffer
+import java.util.concurrent.ConcurrentHashMap
+
import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
+
import spark.util.ByteBufferInputStream
+
/**
* A serializer. Because some serialization libraries are not thread safe, this class is used to
* create [[spark.serializer.SerializerInstance]] objects that do the actual serialization and are
@@ -14,6 +18,48 @@ trait Serializer {
def newInstance(): SerializerInstance
}
+
+/**
+ * A singleton object that can be used to fetch serializer objects based on the serializer
+ * class name. If a previous instance of the serializer object has been created, the get
+ * method returns that instead of creating a new one.
+ */
+object Serializer {
+
+ private val serializers = new ConcurrentHashMap[String, Serializer]
+ private var _default: Serializer = _
+
+ def default = _default
+
+ def setDefault(clsName: String): Serializer = {
+ _default = get(clsName)
+ _default
+ }
+
+ def get(clsName: String): Serializer = {
+ if (clsName == null) {
+ default
+ } else {
+ var serializer = serializers.get(clsName)
+ if (serializer != null) {
+ // If the serializer has been created previously, reuse that.
+ serializer
+ } else this.synchronized {
+ // Otherwise, create a new one. But make sure no other thread has attempted
+ // to create another new one at the same time.
+ serializer = serializers.get(clsName)
+ if (serializer == null) {
+ val clsLoader = Thread.currentThread.getContextClassLoader
+ serializer = Class.forName(clsName, true, clsLoader).newInstance().asInstanceOf[Serializer]
+ serializers.put(clsName, serializer)
+ }
+ serializer
+ }
+ }
+ }
+}
+
+
/**
* An instance of a serializer, for use by one thread at a time.
*/
@@ -45,6 +91,7 @@ trait SerializerInstance {
}
}
+
/**
* A stream for writing serialized objects.
*/
@@ -61,6 +108,7 @@ trait SerializationStream {
}
}
+
/**
* A stream for reading serialized objects.
*/
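
The new Serializer companion object is a small registry: it instantiates each serializer class at most once (double-checked against a ConcurrentHashMap) and treats a null class name as "use the default". A sketch of the expected behavior:

    import spark.serializer.Serializer

    // Installs and caches the default serializer from its class name.
    val default = Serializer.setDefault("spark.JavaSerializer")

    // A null class name resolves to the default instance.
    assert(Serializer.get(null) eq default)

    // Repeated lookups of the same class name reuse the cached instance.
    val kryo1 = Serializer.get("spark.KryoSerializer")  // class name assumed available
    val kryo2 = Serializer.get("spark.KryoSerializer")
    assert(kryo1 eq kryo2)
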
diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala
index 2f97bad916..9f7985e2e8 100644
--- a/core/src/main/scala/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/spark/storage/BlockManager.scala
@@ -29,8 +29,7 @@ class BlockManager(
executorId: String,
actorSystem: ActorSystem,
val master: BlockManagerMaster,
- val serializer: Serializer,
- val shuffleSerializer: Serializer,
+ val defaultSerializer: Serializer,
maxMemory: Long)
extends Logging {
@@ -123,17 +122,8 @@ class BlockManager(
* Construct a BlockManager with a memory limit set based on system properties.
*/
def this(execId: String, actorSystem: ActorSystem, master: BlockManagerMaster,
- serializer: Serializer, shuffleSerializer: Serializer) = {
- this(execId, actorSystem, master, serializer, shuffleSerializer,
- BlockManager.getMaxMemoryFromSystemProperties)
- }
-
- /**
- * Construct a BlockManager with a memory limit set based on system properties.
- */
- def this(execId: String, actorSystem: ActorSystem, master: BlockManagerMaster,
- serializer: Serializer, maxMemory: Long) = {
- this(execId, actorSystem, master, serializer, serializer, maxMemory)
+ serializer: Serializer) = {
+ this(execId, actorSystem, master, serializer, BlockManager.getMaxMemoryFromSystemProperties)
}
/**
@@ -479,9 +469,10 @@ class BlockManager(
* fashion as they're received. Expects a size in bytes to be provided for each block fetched,
* so that we can control the maxMegabytesInFlight for the fetch.
*/
- def getMultiple(blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])])
+ def getMultiple(
+ blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], serializer: Serializer)
: BlockFetcherIterator = {
- return new BlockFetcherIterator(this, blocksByAddress)
+ return new BlockFetcherIterator(this, blocksByAddress, serializer)
}
def put(blockId: String, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean)
@@ -495,8 +486,8 @@ class BlockManager(
* A short circuited method to get a block writer that can write data directly to disk.
* This is currently used for writing shuffle files out.
*/
- def getBlockWriter(blockId: String): BlockObjectWriter = {
- val writer = diskStore.getBlockWriter(blockId)
+ def getBlockWriter(blockId: String, serializer: Serializer): BlockObjectWriter = {
+ val writer = diskStore.getBlockWriter(blockId, serializer)
writer.registerCloseEventHandler(() => {
// TODO(rxin): This doesn't handle error cases.
val myInfo = new BlockInfo(StorageLevel.DISK_ONLY, false)
@@ -850,7 +841,10 @@ class BlockManager(
if (shouldCompress(blockId)) new LZFInputStream(s) else s
}
- def dataSerialize(blockId: String, values: Iterator[Any]): ByteBuffer = {
+ def dataSerialize(
+ blockId: String,
+ values: Iterator[Any],
+ serializer: Serializer = defaultSerializer): ByteBuffer = {
val byteStream = new FastByteArrayOutputStream(4096)
val ser = serializer.newInstance()
ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()
@@ -862,7 +856,10 @@ class BlockManager(
* Deserializes a ByteBuffer into an iterator of values and disposes of it when the end of
* the iterator is reached.
*/
- def dataDeserialize(blockId: String, bytes: ByteBuffer): Iterator[Any] = {
+ def dataDeserialize(
+ blockId: String,
+ bytes: ByteBuffer,
+ serializer: Serializer = defaultSerializer): Iterator[Any] = {
bytes.rewind()
val stream = wrapForCompression(blockId, new ByteBufferInputStream(bytes, true))
serializer.newInstance().deserializeStream(stream).asIterator
@@ -916,7 +913,8 @@ object BlockManager extends Logging {
class BlockFetcherIterator(
private val blockManager: BlockManager,
- val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])]
+ val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])],
+ serializer: Serializer
) extends Iterator[(String, Option[Iterator[Any]])] with Logging with BlockFetchTracker {
import blockManager._
@@ -979,8 +977,8 @@ class BlockFetcherIterator(
"Unexpected message " + blockMessage.getType + " received from " + cmId)
}
val blockId = blockMessage.getId
- results.put(new FetchResult(
- blockId, sizeMap(blockId), () => dataDeserialize(blockId, blockMessage.getData)))
+ results.put(new FetchResult(blockId, sizeMap(blockId),
+ () => dataDeserialize(blockId, blockMessage.getData, serializer)))
_remoteBytesRead += req.size
logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
}
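
Because dataSerialize and dataDeserialize now accept an optional serializer (defaulting to the BlockManager's defaultSerializer), a shuffle block can be decoded with the same serializer it was written with. A hedged round-trip sketch, assuming a running BlockManager (for example SparkEnv.get.blockManager) and the Kryo class name as before:

    val blockManager = spark.SparkEnv.get.blockManager
    val ser = spark.serializer.Serializer.get("spark.KryoSerializer")  // class name assumed

    // Serialize and deserialize with an explicit serializer rather than the default.
    val bytes = blockManager.dataSerialize("shuffle_0_1_2", Iterator(("a", 1), ("b", 2)), ser)
    val values = blockManager.dataDeserialize("shuffle_0_1_2", bytes, ser)
    values.foreach(println)
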
diff --git a/core/src/main/scala/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/spark/storage/BlockManagerWorker.scala
index d2985559c1..15225f93a6 100644
--- a/core/src/main/scala/spark/storage/BlockManagerWorker.scala
+++ b/core/src/main/scala/spark/storage/BlockManagerWorker.scala
@@ -19,7 +19,7 @@ import spark.network._
*/
private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends Logging {
initLogging()
-
+
blockManager.connectionManager.onReceiveMessage(onBlockMessageReceive)
def onBlockMessageReceive(msg: Message, id: ConnectionManagerId): Option[Message] = {
@@ -51,7 +51,7 @@ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends
logDebug("Received [" + pB + "]")
putBlock(pB.id, pB.data, pB.level)
return None
- }
+ }
case BlockMessage.TYPE_GET_BLOCK => {
val gB = new GetBlock(blockMessage.getId)
logDebug("Received [" + gB + "]")
@@ -90,28 +90,26 @@ private[spark] object BlockManagerWorker extends Logging {
private var blockManagerWorker: BlockManagerWorker = null
private val DATA_TRANSFER_TIME_OUT_MS: Long = 500
private val REQUEST_RETRY_INTERVAL_MS: Long = 1000
-
+
initLogging()
-
+
def startBlockManagerWorker(manager: BlockManager) {
blockManagerWorker = new BlockManagerWorker(manager)
}
-
+
def syncPutBlock(msg: PutBlock, toConnManagerId: ConnectionManagerId): Boolean = {
val blockManager = blockManagerWorker.blockManager
- val connectionManager = blockManager.connectionManager
- val serializer = blockManager.serializer
+ val connectionManager = blockManager.connectionManager
val blockMessage = BlockMessage.fromPutBlock(msg)
val blockMessageArray = new BlockMessageArray(blockMessage)
val resultMessage = connectionManager.sendMessageReliablySync(
toConnManagerId, blockMessageArray.toBufferMessage)
return (resultMessage != None)
}
-
+
def syncGetBlock(msg: GetBlock, toConnManagerId: ConnectionManagerId): ByteBuffer = {
val blockManager = blockManagerWorker.blockManager
- val connectionManager = blockManager.connectionManager
- val serializer = blockManager.serializer
+ val connectionManager = blockManager.connectionManager
val blockMessage = BlockMessage.fromGetBlock(msg)
val blockMessageArray = new BlockMessageArray(blockMessage)
val responseMessage = connectionManager.sendMessageReliablySync(
diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala
index 493936fdbe..70ad887c3b 100644
--- a/core/src/main/scala/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/spark/storage/DiskStore.scala
@@ -1,17 +1,18 @@
package spark.storage
-import java.nio.ByteBuffer
import java.io.{File, FileOutputStream, OutputStream, RandomAccessFile}
+import java.nio.ByteBuffer
import java.nio.channels.FileChannel.MapMode
import java.util.{Random, Date}
import java.text.SimpleDateFormat
-import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
-
import scala.collection.mutable.ArrayBuffer
+import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
+
import spark.Utils
import spark.executor.ExecutorExitCode
+import spark.serializer.Serializer
/**
@@ -20,12 +21,13 @@ import spark.executor.ExecutorExitCode
private class DiskStore(blockManager: BlockManager, rootDirs: String)
extends BlockStore(blockManager) {
- class DiskBlockObjectWriter(blockId: String) extends BlockObjectWriter(blockId) {
+ class DiskBlockObjectWriter(blockId: String, serializer: Serializer)
+ extends BlockObjectWriter(blockId) {
private val f: File = createFile(blockId /*, allowAppendExisting */)
private val bs: OutputStream = blockManager.wrapForCompression(blockId,
new FastBufferedOutputStream(new FileOutputStream(f)))
- private val objOut = blockManager.shuffleSerializer.newInstance().serializeStream(bs)
+ private val objOut = serializer.newInstance().serializeStream(bs)
private var _size: Long = -1L
@@ -58,8 +60,8 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
addShutdownHook()
- def getBlockWriter(blockId: String): BlockObjectWriter = {
- new DiskBlockObjectWriter(blockId)
+ def getBlockWriter(blockId: String, serializer: Serializer): BlockObjectWriter = {
+ new DiskBlockObjectWriter(blockId, serializer)
}
override def getSize(blockId: String): Long = {
@@ -92,7 +94,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
val file = createFile(blockId)
val fileOut = blockManager.wrapForCompression(blockId,
new FastBufferedOutputStream(new FileOutputStream(file)))
- val objOut = blockManager.serializer.newInstance().serializeStream(fileOut)
+ val objOut = blockManager.defaultSerializer.newInstance().serializeStream(fileOut)
objOut.writeAll(values.iterator)
objOut.close()
val length = file.length()
diff --git a/core/src/main/scala/spark/storage/ThreadingTest.scala b/core/src/main/scala/spark/storage/ThreadingTest.scala
index 3875e7459e..5c406e68cb 100644
--- a/core/src/main/scala/spark/storage/ThreadingTest.scala
+++ b/core/src/main/scala/spark/storage/ThreadingTest.scala
@@ -78,7 +78,7 @@ private[spark] object ThreadingTest {
val blockManagerMaster = new BlockManagerMaster(
actorSystem.actorOf(Props(new BlockManagerMasterActor(true))))
val blockManager = new BlockManager(
- "<driver>", actorSystem, blockManagerMaster, serializer, serializer, 1024 * 1024)
+ "<driver>", actorSystem, blockManagerMaster, serializer, 1024 * 1024)
val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i))
val consumers = producers.map(p => new ConsumerThread(blockManager, p.queue))
producers.foreach(_.start)