author    Reynold Xin <rxin@cs.berkeley.edu>    2013-04-23 17:48:59 -0700
committer Reynold Xin <rxin@cs.berkeley.edu>    2013-04-23 17:48:59 -0700
commit    31ce6c66d6f29302d0f0f2c70e494fad0ba71e4d (patch)
tree      a954356b53e837f9aa166dd2eaecbf8785be7861 /core/src
parent    17e076de800ea0d4c55f2bd657348641f6f9c55b (diff)
Added a BlockObjectWriter interface in the block manager so ShuffleMapTask
doesn't need to build up an array buffer for each shuffle bucket.
Diffstat (limited to 'core/src')
 core/src/main/scala/spark/SparkEnv.scala                   | 18
 core/src/main/scala/spark/scheduler/ShuffleMapTask.scala   | 30
 core/src/main/scala/spark/storage/BlockException.scala     |  5
 core/src/main/scala/spark/storage/BlockManager.scala       | 36
 core/src/main/scala/spark/storage/BlockObjectWriter.scala  | 27
 core/src/main/scala/spark/storage/DiskStore.scala          | 43
 core/src/main/scala/spark/storage/ThreadingTest.scala      |  2
7 files changed, 129 insertions(+), 32 deletions(-)
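
Before the per-file diffs, a minimal sketch of the idea behind this commit, using the names that appear in the diff below: instead of accumulating every (key, value) pair for a bucket in an ArrayBuffer and handing the whole buffer to the block manager at the end of the task, ShuffleMapTask now obtains one BlockObjectWriter per reduce bucket and streams each pair straight to its shuffle file. Error handling and metrics are omitted here; this is an illustration, not the exact task code.

    // One writer per reduce bucket, obtained up front from the block manager.
    val buckets = Array.tabulate[BlockObjectWriter](numOutputSplits) { bucketId =>
      blockManager.getBlockWriter("shuffle_" + dep.shuffleId + "_" + partition + "_" + bucketId)
    }
    // Each record is serialized to its bucket's file as soon as it is produced,
    // so no per-bucket buffer is held in memory.
    for (elem <- rdd.iterator(split, taskContext)) {
      val pair = elem.asInstanceOf[(Any, Any)]
      buckets(dep.partitioner.getPartition(pair._1)).write(pair)
    }
    buckets.foreach(_.close())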
diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala
index 7157fd2688..c10bedb8f6 100644
--- a/core/src/main/scala/spark/SparkEnv.scala
+++ b/core/src/main/scala/spark/SparkEnv.scala
@@ -22,6 +22,7 @@ class SparkEnv (
val actorSystem: ActorSystem,
val serializer: Serializer,
val closureSerializer: Serializer,
+ val shuffleSerializer: Serializer,
val cacheManager: CacheManager,
val mapOutputTracker: MapOutputTracker,
val shuffleFetcher: ShuffleFetcher,
@@ -82,7 +83,7 @@ object SparkEnv extends Logging {
}
val serializer = instantiateClass[Serializer]("spark.serializer", "spark.JavaSerializer")
-
+
def registerOrLookup(name: String, newActor: => Actor): ActorRef = {
if (isDriver) {
logInfo("Registering " + name)
@@ -96,18 +97,22 @@ object SparkEnv extends Logging {
}
}
+ val closureSerializer = instantiateClass[Serializer](
+ "spark.closure.serializer", "spark.JavaSerializer")
+
+ val shuffleSerializer = instantiateClass[Serializer](
+ "spark.shuffle.serializer", "spark.JavaSerializer")
+
val blockManagerMaster = new BlockManagerMaster(registerOrLookup(
"BlockManagerMaster",
new spark.storage.BlockManagerMasterActor(isLocal)))
- val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, serializer)
+ val blockManager = new BlockManager(
+ executorId, actorSystem, blockManagerMaster, serializer, shuffleSerializer)
val connectionManager = blockManager.connectionManager
val broadcastManager = new BroadcastManager(isDriver)
- val closureSerializer = instantiateClass[Serializer](
- "spark.closure.serializer", "spark.JavaSerializer")
-
val cacheManager = new CacheManager(blockManager)
// Have to assign trackerActor after initialization as MapOutputTrackerActor
@@ -144,6 +149,7 @@ object SparkEnv extends Logging {
actorSystem,
serializer,
closureSerializer,
+ shuffleSerializer,
cacheManager,
mapOutputTracker,
shuffleFetcher,
@@ -153,5 +159,5 @@ object SparkEnv extends Logging {
httpFileServer,
sparkFilesDir)
}
-
+
}
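
A side effect of the SparkEnv change above is that the serializer used for shuffle data is now configurable independently of the general and closure serializers. A hedged usage sketch follows; the property names come from the diff, while spark.KryoSerializer is an assumption about the Kryo serializer class name in this Spark version (the default stays spark.JavaSerializer):

    // Set before the SparkContext (and therefore SparkEnv) is created.
    System.setProperty("spark.shuffle.serializer", "spark.KryoSerializer")
    // Closures can keep using Java serialization independently of the shuffle path.
    System.setProperty("spark.closure.serializer", "spark.JavaSerializer")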
diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
index 36d087a4d0..97b668cd58 100644
--- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
@@ -122,27 +122,33 @@ private[spark] class ShuffleMapTask(
val taskContext = new TaskContext(stageId, partition, attemptId)
metrics = Some(taskContext.taskMetrics)
try {
- // Partition the map output.
- val buckets = Array.fill(numOutputSplits)(new ArrayBuffer[(Any, Any)])
+ // Obtain all the block writers for shuffle blocks.
+ val blockManager = SparkEnv.get.blockManager
+ val buckets = Array.tabulate[BlockObjectWriter](numOutputSplits) { bucketId =>
+ val blockId = "shuffle_" + dep.shuffleId + "_" + partition + "_" + bucketId
+ blockManager.getBlockWriter(blockId)
+ }
+
+ // Write the map output to its associated buckets.
for (elem <- rdd.iterator(split, taskContext)) {
val pair = elem.asInstanceOf[(Any, Any)]
val bucketId = dep.partitioner.getPartition(pair._1)
- buckets(bucketId) += pair
+ buckets(bucketId).write(pair)
}
+ // Close the bucket writers and get the sizes of each block.
val compressedSizes = new Array[Byte](numOutputSplits)
-
- var totalBytes = 0l
-
- val blockManager = SparkEnv.get.blockManager
- for (i <- 0 until numOutputSplits) {
- val blockId = "shuffle_" + dep.shuffleId + "_" + partition + "_" + i
- // Get a Scala iterator from Java map
- val iter: Iterator[(Any, Any)] = buckets(i).iterator
- val size = blockManager.put(blockId, iter, StorageLevel.DISK_ONLY, false)
+ var i = 0
+ var totalBytes = 0L
+ while (i < numOutputSplits) {
+ buckets(i).close()
+ val size = buckets(i).size()
totalBytes += size
compressedSizes(i) = MapOutputTracker.compressSize(size)
+ i += 1
}
+
+ // Update shuffle metrics.
val shuffleMetrics = new ShuffleWriteMetrics
shuffleMetrics.shuffleBytesWritten = totalBytes
metrics.get.shuffleWriteMetrics = Some(shuffleMetrics)
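
The shuffle block naming used in the loop above is the convention both the map side and the fetch side must agree on. A hypothetical helper, not part of this commit, just to make the format explicit:

    // Mirrors the string built inline in ShuffleMapTask:
    //   "shuffle_<shuffleId>_<mapPartition>_<reduceBucketId>"
    def shuffleBlockId(shuffleId: Int, mapPartition: Int, bucketId: Int): String =
      "shuffle_" + shuffleId + "_" + mapPartition + "_" + bucketId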
diff --git a/core/src/main/scala/spark/storage/BlockException.scala b/core/src/main/scala/spark/storage/BlockException.scala
new file mode 100644
index 0000000000..f275d476df
--- /dev/null
+++ b/core/src/main/scala/spark/storage/BlockException.scala
@@ -0,0 +1,5 @@
+package spark.storage
+
+private[spark]
+case class BlockException(blockId: String, message: String) extends Exception(message)
+
diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala
index 210061e972..2f97bad916 100644
--- a/core/src/main/scala/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/spark/storage/BlockManager.scala
@@ -25,15 +25,12 @@ import sun.nio.ch.DirectBuffer
private[spark]
-case class BlockException(blockId: String, message: String, ex: Exception = null)
-extends Exception(message)
-
-private[spark]
class BlockManager(
executorId: String,
actorSystem: ActorSystem,
val master: BlockManagerMaster,
val serializer: Serializer,
+ val shuffleSerializer: Serializer,
maxMemory: Long)
extends Logging {
@@ -78,7 +75,7 @@ class BlockManager(
private val blockInfo = new TimeStampedHashMap[String, BlockInfo]
private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory)
- private[storage] val diskStore: BlockStore =
+ private[storage] val diskStore: DiskStore =
new DiskStore(this, System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")))
val connectionManager = new ConnectionManager(0)
@@ -126,8 +123,17 @@ class BlockManager(
* Construct a BlockManager with a memory limit set based on system properties.
*/
def this(execId: String, actorSystem: ActorSystem, master: BlockManagerMaster,
- serializer: Serializer) = {
- this(execId, actorSystem, master, serializer, BlockManager.getMaxMemoryFromSystemProperties)
+ serializer: Serializer, shuffleSerializer: Serializer) = {
+ this(execId, actorSystem, master, serializer, shuffleSerializer,
+ BlockManager.getMaxMemoryFromSystemProperties)
+ }
+
+ /**
+ * Construct a BlockManager with a memory limit set based on system properties.
+ */
+ def this(execId: String, actorSystem: ActorSystem, master: BlockManagerMaster,
+ serializer: Serializer, maxMemory: Long) = {
+ this(execId, actorSystem, master, serializer, serializer, maxMemory)
}
/**
@@ -486,6 +492,21 @@ class BlockManager(
}
/**
+ * A short circuited method to get a block writer that can write data directly to disk.
+ * This is currently used for writing shuffle files out.
+ */
+ def getBlockWriter(blockId: String): BlockObjectWriter = {
+ val writer = diskStore.getBlockWriter(blockId)
+ writer.registerCloseEventHandler(() => {
+ // TODO(rxin): This doesn't handle error cases.
+ val myInfo = new BlockInfo(StorageLevel.DISK_ONLY, false)
+ blockInfo.put(blockId, myInfo)
+ myInfo.markReady(writer.size())
+ })
+ writer
+ }
+
+ /**
* Put a new block of values to the block manager. Returns its (estimated) size in bytes.
*/
def put(blockId: String, values: ArrayBuffer[Any], level: StorageLevel,
@@ -574,7 +595,6 @@ class BlockManager(
}
logDebug("Put block " + blockId + " locally took " + Utils.getUsedTimeMs(startTimeMs))
-
// Replicate block if required
if (level.replication > 1) {
val remoteStartTime = System.currentTimeMillis
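
From the caller's point of view, the new short-circuit write path added to BlockManager looks roughly like the sketch below (identifiers match the diff; as the TODO notes, failures during writing are not yet handled):

    val writer = blockManager.getBlockWriter("shuffle_1_2_3")
    writer.write(("someKey", "someValue"))  // serialized straight to the on-disk shuffle file
    writer.close()                          // close handler registers the block as DISK_ONLY
    val bytesWritten = writer.size()        // file length, later compressed for MapOutputTracker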
diff --git a/core/src/main/scala/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/spark/storage/BlockObjectWriter.scala
new file mode 100644
index 0000000000..657a7e9143
--- /dev/null
+++ b/core/src/main/scala/spark/storage/BlockObjectWriter.scala
@@ -0,0 +1,27 @@
+package spark.storage
+
+import java.nio.ByteBuffer
+
+
+abstract class BlockObjectWriter(val blockId: String) {
+
+ // TODO(rxin): What if there is an exception when the block is being written out?
+
+ var closeEventHandler: () => Unit = _
+
+ def registerCloseEventHandler(handler: () => Unit) {
+ closeEventHandler = handler
+ }
+
+ def write(value: Any)
+
+ def writeAll(value: Iterator[Any]) {
+ value.foreach(write)
+ }
+
+ def close() {
+ closeEventHandler()
+ }
+
+ def size(): Long
+}
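
To make the contract of the abstract class concrete: subclasses only need to supply write and size, while close in the base class invokes whatever handler was registered. A minimal in-memory implementation, purely illustrative and not part of the commit:

    class InMemoryBlockObjectWriter(blockId: String) extends BlockObjectWriter(blockId) {
      private val buffer = new scala.collection.mutable.ArrayBuffer[Any]
      override def write(value: Any) { buffer += value }  // just collect, no serialization
      override def size(): Long = buffer.length.toLong    // "size" here is a record count
    }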
diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala
index ddbf8821ad..493936fdbe 100644
--- a/core/src/main/scala/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/spark/storage/DiskStore.scala
@@ -1,7 +1,7 @@
package spark.storage
import java.nio.ByteBuffer
-import java.io.{File, FileOutputStream, RandomAccessFile}
+import java.io.{File, FileOutputStream, OutputStream, RandomAccessFile}
import java.nio.channels.FileChannel.MapMode
import java.util.{Random, Date}
import java.text.SimpleDateFormat
@@ -10,9 +10,9 @@ import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
import scala.collection.mutable.ArrayBuffer
+import spark.Utils
import spark.executor.ExecutorExitCode
-import spark.Utils
/**
* Stores BlockManager blocks on disk.
@@ -20,6 +20,33 @@ import spark.Utils
private class DiskStore(blockManager: BlockManager, rootDirs: String)
extends BlockStore(blockManager) {
+ class DiskBlockObjectWriter(blockId: String) extends BlockObjectWriter(blockId) {
+
+ private val f: File = createFile(blockId /*, allowAppendExisting */)
+ private val bs: OutputStream = blockManager.wrapForCompression(blockId,
+ new FastBufferedOutputStream(new FileOutputStream(f)))
+ private val objOut = blockManager.shuffleSerializer.newInstance().serializeStream(bs)
+
+ private var _size: Long = -1L
+
+ override def write(value: Any) {
+ objOut.writeObject(value)
+ }
+
+ override def close() {
+ objOut.close()
+ bs.close()
+ super.close()
+ }
+
+ override def size(): Long = {
+ if (_size < 0) {
+ _size = f.length()
+ }
+ _size
+ }
+ }
+
val MAX_DIR_CREATION_ATTEMPTS: Int = 10
val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt
@@ -31,6 +58,10 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
addShutdownHook()
+ def getBlockWriter(blockId: String): BlockObjectWriter = {
+ new DiskBlockObjectWriter(blockId)
+ }
+
override def getSize(blockId: String): Long = {
getFile(blockId).length()
}
@@ -65,8 +96,10 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
objOut.writeAll(values.iterator)
objOut.close()
val length = file.length()
+
+ val timeTaken = System.currentTimeMillis - startTime
logDebug("Block %s stored as %s file on disk in %d ms".format(
- blockId, Utils.memoryBytesToString(length), (System.currentTimeMillis - startTime)))
+ blockId, Utils.memoryBytesToString(length), timeTaken))
if (returnValues) {
// Return a byte buffer for the contents of the file
@@ -106,9 +139,9 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
getFile(blockId).exists()
}
- private def createFile(blockId: String): File = {
+ private def createFile(blockId: String, allowAppendExisting: Boolean = false): File = {
val file = getFile(blockId)
- if (file.exists()) {
+ if (!allowAppendExisting && file.exists()) {
throw new Exception("File for block " + blockId + " already exists on disk: " + file)
}
file
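
The stream stack that DiskBlockObjectWriter assembles (FileOutputStream wrapped in FastBufferedOutputStream, optionally compressed via blockManager.wrapForCompression, then exposed as the shuffle serializer's serializeStream) is roughly analogous to the plain-JDK sketch below; the real code uses the pluggable shuffle serializer rather than Java serialization:

    import java.io.{BufferedOutputStream, FileOutputStream, ObjectOutputStream}

    val out = new ObjectOutputStream(new BufferedOutputStream(new FileOutputStream("shuffle_1_2_3")))
    out.writeObject(("someKey", "someValue"))  // one record per writeObject call
    out.close()                                // flushes the buffer and closes the file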
diff --git a/core/src/main/scala/spark/storage/ThreadingTest.scala b/core/src/main/scala/spark/storage/ThreadingTest.scala
index 5c406e68cb..3875e7459e 100644
--- a/core/src/main/scala/spark/storage/ThreadingTest.scala
+++ b/core/src/main/scala/spark/storage/ThreadingTest.scala
@@ -78,7 +78,7 @@ private[spark] object ThreadingTest {
val blockManagerMaster = new BlockManagerMaster(
actorSystem.actorOf(Props(new BlockManagerMasterActor(true))))
val blockManager = new BlockManager(
- "<driver>", actorSystem, blockManagerMaster, serializer, 1024 * 1024)
+ "<driver>", actorSystem, blockManagerMaster, serializer, serializer, 1024 * 1024)
val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i))
val consumers = producers.map(p => new ConsumerThread(blockManager, p.queue))
producers.foreach(_.start)