author     Reynold Xin <rxin@cs.berkeley.edu>    2013-05-03 01:02:16 -0700
committer  Reynold Xin <rxin@cs.berkeley.edu>    2013-05-03 01:02:16 -0700
commit     2bc895a829caa459e032e12e1d117994dd510a5c (patch)
tree       4afa03e5bf96a4c3dea0dfe9f882dca04647ecc8
parent     dd7bef31472e8c7dedc93bc1519be5900784c736 (diff)
Updated according to Matei's code review comment.
-rw-r--r--  core/src/main/scala/spark/ShuffleFetcher.scala                |  2
-rw-r--r--  core/src/main/scala/spark/SparkEnv.scala                      | 10
-rw-r--r--  core/src/main/scala/spark/rdd/CoGroupedRDD.scala              |  3
-rw-r--r--  core/src/main/scala/spark/rdd/ShuffledRDD.scala               |  5
-rw-r--r--  core/src/main/scala/spark/rdd/SubtractedRDD.scala             |  4
-rw-r--r--  core/src/main/scala/spark/scheduler/ShuffleMapTask.scala      |  5
-rw-r--r--  core/src/main/scala/spark/serializer/Serializer.scala         | 42
-rw-r--r--  core/src/main/scala/spark/serializer/SerializerManager.scala  | 45
-rw-r--r--  core/src/main/scala/spark/storage/DiskStore.scala             | 36
-rw-r--r--  core/src/main/scala/spark/storage/ShuffleBlockManager.scala   | 34
10 files changed, 98 insertions, 88 deletions
diff --git a/core/src/main/scala/spark/ShuffleFetcher.scala b/core/src/main/scala/spark/ShuffleFetcher.scala
index 49addc0c10..9513a00126 100644
--- a/core/src/main/scala/spark/ShuffleFetcher.scala
+++ b/core/src/main/scala/spark/ShuffleFetcher.scala
@@ -10,7 +10,7 @@ private[spark] abstract class ShuffleFetcher {
* @return An iterator over the elements of the fetched shuffle outputs.
*/
def fetch[K, V](shuffleId: Int, reduceId: Int, metrics: TaskMetrics,
- serializer: Serializer = Serializer.default): Iterator[(K,V)]
+ serializer: Serializer = SparkEnv.get.serializerManager.default): Iterator[(K,V)]
/** Stop the fetcher */
def stop() {}
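
[Note, not part of the commit: callers that omit the serializer argument now pick up the serializer registered on the current SparkEnv rather than a global singleton. A minimal sketch, assuming shuffleId, reduceId and a TaskContext named context are in scope:]

    // Hypothetical call site; shuffleId, reduceId and context are assumed to exist.
    val iter: Iterator[(String, Int)] =
      SparkEnv.get.shuffleFetcher.fetch[String, Int](shuffleId, reduceId, context.taskMetrics)
    // Equivalent to passing SparkEnv.get.serializerManager.default explicitly.
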
diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala
index 8ba52245fa..2fa97cd829 100644
--- a/core/src/main/scala/spark/SparkEnv.scala
+++ b/core/src/main/scala/spark/SparkEnv.scala
@@ -7,7 +7,7 @@ import spark.broadcast.BroadcastManager
import spark.storage.BlockManager
import spark.storage.BlockManagerMaster
import spark.network.ConnectionManager
-import spark.serializer.Serializer
+import spark.serializer.{Serializer, SerializerManager}
import spark.util.AkkaUtils
@@ -21,6 +21,7 @@ import spark.util.AkkaUtils
class SparkEnv (
val executorId: String,
val actorSystem: ActorSystem,
+ val serializerManager: SerializerManager,
val serializer: Serializer,
val closureSerializer: Serializer,
val cacheManager: CacheManager,
@@ -92,10 +93,12 @@ object SparkEnv extends Logging {
Class.forName(name, true, classLoader).newInstance().asInstanceOf[T]
}
- val serializer = Serializer.setDefault(
+ val serializerManager = new SerializerManager
+
+ val serializer = serializerManager.setDefault(
System.getProperty("spark.serializer", "spark.JavaSerializer"))
- val closureSerializer = Serializer.get(
+ val closureSerializer = serializerManager.get(
System.getProperty("spark.closure.serializer", "spark.JavaSerializer"))
def registerOrLookup(name: String, newActor: => Actor): ActorRef = {
@@ -155,6 +158,7 @@ object SparkEnv extends Logging {
new SparkEnv(
executorId,
actorSystem,
+ serializerManager,
serializer,
closureSerializer,
cacheManager,
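
[Note, not part of the commit: the manager now lives on SparkEnv, so lookups that previously went through the Serializer companion object go through the environment instead. A minimal sketch, assuming a running SparkEnv and that spark.KryoSerializer is on the classpath:]

    val manager = SparkEnv.get.serializerManager

    val default = manager.default                       // set from spark.serializer at env creation
    val kryo    = manager.get("spark.KryoSerializer")   // instantiated once, then cached
    val again   = manager.get(null)                     // a null class name falls back to the default
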
diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
index 9e996e9958..7599ba1a02 100644
--- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
@@ -8,7 +8,6 @@ import scala.collection.mutable.ArrayBuffer
import spark.{Aggregator, Logging, Partition, Partitioner, RDD, SparkEnv, TaskContext}
import spark.{Dependency, OneToOneDependency, ShuffleDependency}
-import spark.serializer.Serializer
private[spark] sealed trait CoGroupSplitDep extends Serializable
@@ -114,7 +113,7 @@ class CoGroupedRDD[K](
}
}
- val ser = Serializer.get(serializerClass)
+ val ser = SparkEnv.get.serializerManager.get(serializerClass)
for ((dep, depNum) <- split.deps.zipWithIndex) dep match {
case NarrowCoGroupSplitDep(rdd, _, itsSplit) => {
// Read them from the parent
diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala
index 8175e23eff..c7d1926b83 100644
--- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala
+++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala
@@ -2,7 +2,6 @@ package spark.rdd
import spark.{Partitioner, RDD, SparkEnv, ShuffleDependency, Partition, TaskContext}
import spark.SparkContext._
-import spark.serializer.Serializer
private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition {
@@ -32,7 +31,7 @@ class ShuffledRDD[K, V](
override def compute(split: Partition, context: TaskContext): Iterator[(K, V)] = {
val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId
- SparkEnv.get.shuffleFetcher.fetch[K, V](
- shuffledId, split.index, context.taskMetrics, Serializer.get(serializerClass))
+ SparkEnv.get.shuffleFetcher.fetch[K, V](shuffledId, split.index, context.taskMetrics,
+ SparkEnv.get.serializerManager.get(serializerClass))
}
}
diff --git a/core/src/main/scala/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/spark/rdd/SubtractedRDD.scala
index f60c35c38e..8a9efc5da2 100644
--- a/core/src/main/scala/spark/rdd/SubtractedRDD.scala
+++ b/core/src/main/scala/spark/rdd/SubtractedRDD.scala
@@ -11,7 +11,7 @@ import spark.Partition
import spark.SparkEnv
import spark.ShuffleDependency
import spark.OneToOneDependency
-import spark.serializer.Serializer
+
/**
* An optimized version of cogroup for set difference/subtraction.
@@ -68,7 +68,7 @@ private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest, W: ClassM
override def compute(p: Partition, context: TaskContext): Iterator[(K, V)] = {
val partition = p.asInstanceOf[CoGroupPartition]
- val serializer = Serializer.get(serializerClass)
+ val serializer = SparkEnv.get.serializerManager.get(serializerClass)
val map = new JHashMap[K, ArrayBuffer[V]]
def getSeq(k: K): ArrayBuffer[V] = {
val seq = map.get(k)
diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
index 124d2d7e26..f097213ab5 100644
--- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
@@ -14,7 +14,6 @@ import com.ning.compress.lzf.LZFOutputStream
import spark._
import spark.executor.ShuffleWriteMetrics
-import spark.serializer.Serializer
import spark.storage._
import spark.util.{TimeStampedHashMap, MetadataCleaner}
@@ -139,12 +138,12 @@ private[spark] class ShuffleMapTask(
metrics = Some(taskContext.taskMetrics)
val blockManager = SparkEnv.get.blockManager
- var shuffle: ShuffleBlockManager#Shuffle = null
+ var shuffle: ShuffleBlocks = null
var buckets: ShuffleWriterGroup = null
try {
// Obtain all the block writers for shuffle blocks.
- val ser = Serializer.get(dep.serializerClass)
+ val ser = SparkEnv.get.serializerManager.get(dep.serializerClass)
shuffle = blockManager.shuffleBlockManager.forShuffle(dep.shuffleId, numOutputSplits, ser)
buckets = shuffle.acquireWriters(partition)
diff --git a/core/src/main/scala/spark/serializer/Serializer.scala b/core/src/main/scala/spark/serializer/Serializer.scala
index 77b1a1a434..2ad73b711d 100644
--- a/core/src/main/scala/spark/serializer/Serializer.scala
+++ b/core/src/main/scala/spark/serializer/Serializer.scala
@@ -2,7 +2,6 @@ package spark.serializer
import java.io.{EOFException, InputStream, OutputStream}
import java.nio.ByteBuffer
-import java.util.concurrent.ConcurrentHashMap
import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
@@ -20,47 +19,6 @@ trait Serializer {
/**
- * A singleton object that can be used to fetch serializer objects based on the serializer
- * class name. If a previous instance of the serializer object has been created, the get
- * method returns that instead of creating a new one.
- */
-object Serializer {
-
- private val serializers = new ConcurrentHashMap[String, Serializer]
- private var _default: Serializer = _
-
- def default = _default
-
- def setDefault(clsName: String): Serializer = {
- _default = get(clsName)
- _default
- }
-
- def get(clsName: String): Serializer = {
- if (clsName == null) {
- default
- } else {
- var serializer = serializers.get(clsName)
- if (serializer != null) {
- // If the serializer has been created previously, reuse that.
- serializer
- } else this.synchronized {
- // Otherwise, create a new one. But make sure no other thread has attempted
- // to create another new one at the same time.
- serializer = serializers.get(clsName)
- if (serializer == null) {
- val clsLoader = Thread.currentThread.getContextClassLoader
- serializer = Class.forName(clsName, true, clsLoader).newInstance().asInstanceOf[Serializer]
- serializers.put(clsName, serializer)
- }
- serializer
- }
- }
- }
-}
-
-
-/**
* An instance of a serializer, for use by one thread at a time.
*/
trait SerializerInstance {
diff --git a/core/src/main/scala/spark/serializer/SerializerManager.scala b/core/src/main/scala/spark/serializer/SerializerManager.scala
new file mode 100644
index 0000000000..60b2aac797
--- /dev/null
+++ b/core/src/main/scala/spark/serializer/SerializerManager.scala
@@ -0,0 +1,45 @@
+package spark.serializer
+
+import java.util.concurrent.ConcurrentHashMap
+
+
+/**
+ * A service that returns a serializer object given the serializer's class name. If a previous
+ * instance of the serializer object has been created, the get method returns that instead of
+ * creating a new one.
+ */
+private[spark] class SerializerManager {
+
+ private val serializers = new ConcurrentHashMap[String, Serializer]
+ private var _default: Serializer = _
+
+ def default = _default
+
+ def setDefault(clsName: String): Serializer = {
+ _default = get(clsName)
+ _default
+ }
+
+ def get(clsName: String): Serializer = {
+ if (clsName == null) {
+ default
+ } else {
+ var serializer = serializers.get(clsName)
+ if (serializer != null) {
+ // If the serializer has been created previously, reuse that.
+ serializer
+ } else this.synchronized {
+ // Otherwise, create a new one. But make sure no other thread has attempted
+ // to create another new one at the same time.
+ serializer = serializers.get(clsName)
+ if (serializer == null) {
+ val clsLoader = Thread.currentThread.getContextClassLoader
+ serializer =
+ Class.forName(clsName, true, clsLoader).newInstance().asInstanceOf[Serializer]
+ serializers.put(clsName, serializer)
+ }
+ serializer
+ }
+ }
+ }
+}
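
[Note, not part of the commit: a usage sketch of the caching behaviour the new class implements, using the stock Java serializer class name the diff already references:]

    val manager = new SerializerManager
    val default = manager.setDefault("spark.JavaSerializer")

    // get() caches instances, so a repeated lookup returns the same object,
    // and a null class name resolves to the default.
    assert(manager.get("spark.JavaSerializer") eq default)
    assert(manager.get(null) eq default)
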
diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala
index 4cddcc86fc..498bc9eeb6 100644
--- a/core/src/main/scala/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/spark/storage/DiskStore.scala
@@ -2,6 +2,7 @@ package spark.storage
import java.io.{File, FileOutputStream, OutputStream, RandomAccessFile}
import java.nio.ByteBuffer
+import java.nio.channels.FileChannel
import java.nio.channels.FileChannel.MapMode
import java.util.{Random, Date}
import java.text.SimpleDateFormat
@@ -26,14 +27,16 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
private val f: File = createFile(blockId /*, allowAppendExisting */)
- private var repositionableStream: FastBufferedOutputStream = null
+ // The file channel, used for repositioning / truncating the file.
+ private var channel: FileChannel = null
private var bs: OutputStream = null
private var objOut: SerializationStream = null
- private var validLength = 0L
+ private var lastValidPosition = 0L
override def open(): DiskBlockObjectWriter = {
- repositionableStream = new FastBufferedOutputStream(new FileOutputStream(f))
- bs = blockManager.wrapForCompression(blockId, repositionableStream)
+ val fos = new FileOutputStream(f, true)
+ channel = fos.getChannel()
+ bs = blockManager.wrapForCompression(blockId, new FastBufferedOutputStream(fos))
objOut = serializer.newInstance().serializeStream(bs)
this
}
@@ -41,9 +44,9 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
override def close() {
objOut.close()
bs.close()
- objOut = null
+ channel = null
bs = null
- repositionableStream = null
+ objOut = null
// Invoke the close callback handler.
super.close()
}
@@ -54,25 +57,23 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
// Return the number of bytes written for this commit.
override def commit(): Long = {
bs.flush()
- validLength = repositionableStream.position()
- validLength
+ val prevPos = lastValidPosition
+ lastValidPosition = channel.position()
+ lastValidPosition - prevPos
}
override def revertPartialWrites() {
- // Flush the outstanding writes and delete the file.
- objOut.close()
- bs.close()
- objOut = null
- bs = null
- repositionableStream = null
- f.delete()
+ // Discard current writes. We do this by flushing the outstanding writes and
+ // truncate the file to the last valid position.
+ bs.flush()
+ channel.truncate(lastValidPosition)
}
override def write(value: Any) {
objOut.writeObject(value)
}
- override def size(): Long = validLength
+ override def size(): Long = lastValidPosition
}
val MAX_DIR_CREATION_ATTEMPTS: Int = 10
@@ -86,7 +87,8 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
addShutdownHook()
- def getBlockWriter(blockId: String, serializer: Serializer, bufferSize: Int): BlockObjectWriter = {
+ def getBlockWriter(blockId: String, serializer: Serializer, bufferSize: Int)
+ : BlockObjectWriter = {
new DiskBlockObjectWriter(blockId, serializer, bufferSize)
}
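
[Note, not part of the commit: with this change commit() flushes and remembers the channel position, returning the bytes written since the previous commit, while revertPartialWrites() flushes and truncates back to that position instead of deleting the file. A minimal sketch of the write/commit/revert cycle, assuming a DiskStore instance, a serializer and a block id are in scope:]

    val writer = diskStore.getBlockWriter(blockId, serializer, 100 * 1024).open()

    writer.write(("a", 1))
    val committed = writer.commit()    // flush, record the channel position, return the delta

    writer.write(("b", 2))
    writer.revertPartialWrites()       // flush, then truncate back to the last commit

    writer.close()
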
diff --git a/core/src/main/scala/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/spark/storage/ShuffleBlockManager.scala
index 1903df0817..49eabfb0d2 100644
--- a/core/src/main/scala/spark/storage/ShuffleBlockManager.scala
+++ b/core/src/main/scala/spark/storage/ShuffleBlockManager.scala
@@ -8,26 +8,30 @@ class ShuffleWriterGroup(val id: Int, val writers: Array[BlockObjectWriter])
private[spark]
-class ShuffleBlockManager(blockManager: BlockManager) {
+trait ShuffleBlocks {
+ def acquireWriters(mapId: Int): ShuffleWriterGroup
+ def releaseWriters(group: ShuffleWriterGroup)
+}
- def forShuffle(shuffleId: Int, numBuckets: Int, serializer: Serializer): Shuffle = {
- new Shuffle(shuffleId, numBuckets, serializer)
- }
- class Shuffle(shuffleId: Int, numBuckets: Int, serializer: Serializer) {
+private[spark]
+class ShuffleBlockManager(blockManager: BlockManager) {
- // Get a group of writers for a map task.
- def acquireWriters(mapId: Int): ShuffleWriterGroup = {
- val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024
- val writers = Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
- val blockId = ShuffleBlockManager.blockId(shuffleId, bucketId, mapId)
- blockManager.getDiskBlockWriter(blockId, serializer, bufferSize).open()
+ def forShuffle(shuffleId: Int, numBuckets: Int, serializer: Serializer): ShuffleBlocks = {
+ new ShuffleBlocks {
+ // Get a group of writers for a map task.
+ override def acquireWriters(mapId: Int): ShuffleWriterGroup = {
+ val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024
+ val writers = Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
+ val blockId = ShuffleBlockManager.blockId(shuffleId, bucketId, mapId)
+ blockManager.getDiskBlockWriter(blockId, serializer, bufferSize).open()
+ }
+ new ShuffleWriterGroup(mapId, writers)
}
- new ShuffleWriterGroup(mapId, writers)
- }
- def releaseWriters(group: ShuffleWriterGroup) = {
- // Nothing really to release here.
+ override def releaseWriters(group: ShuffleWriterGroup) = {
+ // Nothing really to release here.
+ }
}
}
}
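
[Note, not part of the commit: the cycle ShuffleMapTask now drives against the new ShuffleBlocks trait, per the ShuffleMapTask hunk above; dep, numOutputSplits, partition and blockManager are assumed to be in scope:]

    val ser = SparkEnv.get.serializerManager.get(dep.serializerClass)
    val shuffle: ShuffleBlocks =
      blockManager.shuffleBlockManager.forShuffle(dep.shuffleId, numOutputSplits, ser)

    val buckets = shuffle.acquireWriters(partition)
    // ... write each (key, value) pair to buckets.writers(bucket) and commit ...
    shuffle.releaseWriters(buckets)
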