Diffstat (limited to 'core')
-rw-r--r-- core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java | 19
-rwxr-xr-x core/src/main/java/org/apache/spark/network/netty/PathResolver.java | 11
-rw-r--r-- core/src/main/scala/org/apache/spark/MapOutputTracker.scala | 168
-rw-r--r-- core/src/main/scala/org/apache/spark/SparkEnv.scala | 8
-rw-r--r-- core/src/main/scala/org/apache/spark/api/java/function/DoubleFlatMapFunction.java | 10
-rw-r--r-- core/src/main/scala/org/apache/spark/api/java/function/DoubleFunction.java | 5
-rw-r--r-- core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala | 3
-rw-r--r-- core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala | 3
-rw-r--r-- core/src/main/scala/org/apache/spark/api/java/function/Function.java | 3
-rw-r--r-- core/src/main/scala/org/apache/spark/api/java/function/Function2.java | 3
-rw-r--r-- core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.java | 3
-rw-r--r-- core/src/main/scala/org/apache/spark/api/java/function/PairFunction.java | 6
-rw-r--r-- core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala | 7
-rw-r--r-- core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala | 29
-rw-r--r-- core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala | 8
-rw-r--r-- core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala | 4
-rw-r--r-- core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala | 10
-rw-r--r-- core/src/main/scala/org/apache/spark/scheduler/Stage.scala | 6
-rw-r--r-- core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala | 15
-rw-r--r-- core/src/main/scala/org/apache/spark/scheduler/Task.scala | 2
-rw-r--r-- core/src/main/scala/org/apache/spark/storage/BlockManager.scala | 34
-rw-r--r-- core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala | 128
-rw-r--r-- core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala | 184
-rw-r--r-- core/src/main/scala/org/apache/spark/storage/DiskStore.scala | 301
-rw-r--r-- core/src/main/scala/org/apache/spark/storage/FileSegment.scala | 28
-rw-r--r-- core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala (renamed from core/src/main/scala/spark/storage/StoragePerfTester.scala) | 10
-rw-r--r-- core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala | 2
-rw-r--r-- core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala | 97
-rw-r--r-- core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala | 8
-rw-r--r-- core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala | 12
-rw-r--r-- core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala | 26
-rw-r--r-- core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala | 5
-rw-r--r-- core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala | 20
-rw-r--r-- core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala | 6
-rw-r--r-- core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala | 11
-rw-r--r-- core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala | 63
37 files changed, 743 insertions(+), 569 deletions(-)
diff --git a/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java b/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java
index cfd8132891..172c6e4b1c 100644
--- a/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java
+++ b/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java
@@ -25,6 +25,7 @@ import io.netty.channel.ChannelInboundMessageHandlerAdapter;
import io.netty.channel.DefaultFileRegion;
import org.apache.spark.storage.BlockId;
+import org.apache.spark.storage.FileSegment;
class FileServerHandler extends ChannelInboundMessageHandlerAdapter<String> {
@@ -37,40 +38,34 @@ class FileServerHandler extends ChannelInboundMessageHandlerAdapter<String> {
@Override
public void messageReceived(ChannelHandlerContext ctx, String blockIdString) {
BlockId blockId = BlockId.apply(blockIdString);
- String path = pResolver.getAbsolutePath(blockId.name());
- // if getFilePath returns null, close the channel
- if (path == null) {
+ FileSegment fileSegment = pResolver.getBlockLocation(blockId);
+ // if getBlockLocation returns null, close the channel
+ if (fileSegment == null) {
//ctx.close();
return;
}
- File file = new File(path);
+ File file = fileSegment.file();
if (file.exists()) {
if (!file.isFile()) {
- //logger.info("Not a file : " + file.getAbsolutePath());
ctx.write(new FileHeader(0, blockId).buffer());
ctx.flush();
return;
}
- long length = file.length();
+ long length = fileSegment.length();
if (length > Integer.MAX_VALUE || length <= 0) {
- //logger.info("too large file : " + file.getAbsolutePath() + " of size "+ length);
ctx.write(new FileHeader(0, blockId).buffer());
ctx.flush();
return;
}
int len = new Long(length).intValue();
- //logger.info("Sending block "+blockId+" filelen = "+len);
- //logger.info("header = "+ (new FileHeader(len, blockId)).buffer());
ctx.write((new FileHeader(len, blockId)).buffer());
try {
ctx.sendFile(new DefaultFileRegion(new FileInputStream(file)
- .getChannel(), 0, file.length()));
+ .getChannel(), fileSegment.offset(), fileSegment.length()));
} catch (Exception e) {
- //logger.warning("Exception when sending file : " + file.getAbsolutePath());
e.printStackTrace();
}
} else {
- //logger.warning("File not found: " + file.getAbsolutePath());
ctx.write(new FileHeader(0, blockId).buffer());
}
ctx.flush();
diff --git a/core/src/main/java/org/apache/spark/network/netty/PathResolver.java b/core/src/main/java/org/apache/spark/network/netty/PathResolver.java
index 94c034cad0..9f7ced44cf 100755
--- a/core/src/main/java/org/apache/spark/network/netty/PathResolver.java
+++ b/core/src/main/java/org/apache/spark/network/netty/PathResolver.java
@@ -17,13 +17,10 @@
package org.apache.spark.network.netty;
+import org.apache.spark.storage.BlockId;
+import org.apache.spark.storage.FileSegment;
public interface PathResolver {
- /**
- * Get the absolute path of the file
- *
- * @param fileId
- * @return the absolute path of file
- */
- public String getAbsolutePath(String fileId);
+ /** Get the file segment in which the given block resides. */
+ public FileSegment getBlockLocation(BlockId blockId);
}
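
The ShuffleSender change further down implements this new interface inline; for reference, a minimal stand-alone sketch of a resolver that maps every block to a whole file might look like the following (the class name WholeFileResolver and the rootDir layout are invented for illustration and are not part of this patch):

    import java.io.File

    import org.apache.spark.network.netty.PathResolver
    import org.apache.spark.storage.{BlockId, FileSegment}

    class WholeFileResolver(rootDir: File) extends PathResolver {
      override def getBlockLocation(blockId: BlockId): FileSegment = {
        val file = new File(rootDir, blockId.name)
        // Offset 0 and the full file length mean "the block is the entire file".
        new FileSegment(file, 0, file.length())
      }
    }

Returning a FileSegment instead of a path string is what lets FileServerHandler above send only the (offset, length) slice of a larger, shared file.
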
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 1e3f1ebfaf..5e465fa22c 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -20,13 +20,11 @@ package org.apache.spark
import java.io._
import java.util.zip.{GZIPInputStream, GZIPOutputStream}
-import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
import akka.actor._
import akka.dispatch._
import akka.pattern.ask
-import akka.remote._
import akka.util.Duration
@@ -40,11 +38,12 @@ private[spark] case class GetMapOutputStatuses(shuffleId: Int, requester: String
extends MapOutputTrackerMessage
private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage
-private[spark] class MapOutputTrackerActor(tracker: MapOutputTracker) extends Actor with Logging {
+private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster)
+ extends Actor with Logging {
def receive = {
case GetMapOutputStatuses(shuffleId: Int, requester: String) =>
logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + requester)
- sender ! tracker.getSerializedLocations(shuffleId)
+ sender ! tracker.getSerializedMapOutputStatuses(shuffleId)
case StopMapOutputTracker =>
logInfo("MapOutputTrackerActor stopped!")
@@ -60,22 +59,19 @@ private[spark] class MapOutputTracker extends Logging {
// Set to the MapOutputTrackerActor living on the driver
var trackerActor: ActorRef = _
- private var mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
+ protected val mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
// Incremented every time a fetch fails so that client nodes know to clear
// their cache of map output locations if this happens.
- private var epoch: Long = 0
- private val epochLock = new java.lang.Object
+ protected var epoch: Long = 0
+ protected val epochLock = new java.lang.Object
- // Cache a serialized version of the output statuses for each shuffle to send them out faster
- var cacheEpoch = epoch
- private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]
-
- val metadataCleaner = new MetadataCleaner(MetadataCleanerType.MAP_OUTPUT_TRACKER, this.cleanup)
+ private val metadataCleaner =
+ new MetadataCleaner(MetadataCleanerType.MAP_OUTPUT_TRACKER, this.cleanup)
// Send a message to the trackerActor and get its result within a default timeout, or
// throw a SparkException if this fails.
- def askTracker(message: Any): Any = {
+ private def askTracker(message: Any): Any = {
try {
val future = trackerActor.ask(message)(timeout)
return Await.result(future, timeout)
@@ -86,50 +82,12 @@ private[spark] class MapOutputTracker extends Logging {
}
// Send a one-way message to the trackerActor, to which we expect it to reply with true.
- def communicate(message: Any) {
+ private def communicate(message: Any) {
if (askTracker(message) != true) {
throw new SparkException("Error reply received from MapOutputTracker")
}
}
- def registerShuffle(shuffleId: Int, numMaps: Int) {
- if (mapStatuses.putIfAbsent(shuffleId, new Array[MapStatus](numMaps)).isDefined) {
- throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
- }
- }
-
- def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) {
- var array = mapStatuses(shuffleId)
- array.synchronized {
- array(mapId) = status
- }
- }
-
- def registerMapOutputs(
- shuffleId: Int,
- statuses: Array[MapStatus],
- changeEpoch: Boolean = false) {
- mapStatuses.put(shuffleId, Array[MapStatus]() ++ statuses)
- if (changeEpoch) {
- incrementEpoch()
- }
- }
-
- def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
- var arrayOpt = mapStatuses.get(shuffleId)
- if (arrayOpt.isDefined && arrayOpt.get != null) {
- var array = arrayOpt.get
- array.synchronized {
- if (array(mapId) != null && array(mapId).location == bmAddress) {
- array(mapId) = null
- }
- }
- incrementEpoch()
- } else {
- throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID")
- }
- }
-
// Remembers which map output locations are currently being fetched on a worker
private val fetching = new HashSet[Int]
@@ -168,7 +126,7 @@ private[spark] class MapOutputTracker extends Logging {
try {
val fetchedBytes =
askTracker(GetMapOutputStatuses(shuffleId, hostPort)).asInstanceOf[Array[Byte]]
- fetchedStatuses = deserializeStatuses(fetchedBytes)
+ fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes)
logInfo("Got the output locations")
mapStatuses.put(shuffleId, fetchedStatuses)
} finally {
@@ -194,9 +152,8 @@ private[spark] class MapOutputTracker extends Logging {
}
}
- private def cleanup(cleanupTime: Long) {
+ protected def cleanup(cleanupTime: Long) {
mapStatuses.clearOldValues(cleanupTime)
- cachedSerializedStatuses.clearOldValues(cleanupTime)
}
def stop() {
@@ -206,15 +163,7 @@ private[spark] class MapOutputTracker extends Logging {
trackerActor = null
}
- // Called on master to increment the epoch number
- def incrementEpoch() {
- epochLock.synchronized {
- epoch += 1
- logDebug("Increasing epoch to " + epoch)
- }
- }
-
- // Called on master or workers to get current epoch number
+ // Called to get current epoch number
def getEpoch: Long = {
epochLock.synchronized {
return epoch
@@ -228,14 +177,62 @@ private[spark] class MapOutputTracker extends Logging {
epochLock.synchronized {
if (newEpoch > epoch) {
logInfo("Updating epoch to " + newEpoch + " and clearing cache")
- // mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
- mapStatuses.clear()
epoch = newEpoch
+ mapStatuses.clear()
+ }
+ }
+ }
+}
+
+private[spark] class MapOutputTrackerMaster extends MapOutputTracker {
+
+ // Cache a serialized version of the output statuses for each shuffle to send them out faster
+ private var cacheEpoch = epoch
+ private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]
+
+ def registerShuffle(shuffleId: Int, numMaps: Int) {
+ if (mapStatuses.putIfAbsent(shuffleId, new Array[MapStatus](numMaps)).isDefined) {
+ throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
+ }
+ }
+
+ def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) {
+ val array = mapStatuses(shuffleId)
+ array.synchronized {
+ array(mapId) = status
+ }
+ }
+
+ def registerMapOutputs(shuffleId: Int, statuses: Array[MapStatus], changeEpoch: Boolean = false) {
+ mapStatuses.put(shuffleId, Array[MapStatus]() ++ statuses)
+ if (changeEpoch) {
+ incrementEpoch()
+ }
+ }
+
+ def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
+ val arrayOpt = mapStatuses.get(shuffleId)
+ if (arrayOpt.isDefined && arrayOpt.get != null) {
+ val array = arrayOpt.get
+ array.synchronized {
+ if (array(mapId) != null && array(mapId).location == bmAddress) {
+ array(mapId) = null
+ }
}
+ incrementEpoch()
+ } else {
+ throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID")
}
}
- def getSerializedLocations(shuffleId: Int): Array[Byte] = {
+ def incrementEpoch() {
+ epochLock.synchronized {
+ epoch += 1
+ logDebug("Increasing epoch to " + epoch)
+ }
+ }
+
+ def getSerializedMapOutputStatuses(shuffleId: Int): Array[Byte] = {
var statuses: Array[MapStatus] = null
var epochGotten: Long = -1
epochLock.synchronized {
@@ -253,7 +250,7 @@ private[spark] class MapOutputTracker extends Logging {
}
// If we got here, we failed to find the serialized locations in the cache, so we pulled
// out a snapshot of the locations as "locs"; let's serialize and return that
- val bytes = serializeStatuses(statuses)
+ val bytes = MapOutputTracker.serializeMapStatuses(statuses)
logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length))
// Add them into the table only if the epoch hasn't changed while we were working
epochLock.synchronized {
@@ -261,13 +258,31 @@ private[spark] class MapOutputTracker extends Logging {
cachedSerializedStatuses(shuffleId) = bytes
}
}
- return bytes
+ bytes
+ }
+
+ protected override def cleanup(cleanupTime: Long) {
+ super.cleanup(cleanupTime)
+ cachedSerializedStatuses.clearOldValues(cleanupTime)
}
+ override def stop() {
+ super.stop()
+ cachedSerializedStatuses.clear()
+ }
+
+ override def updateEpoch(newEpoch: Long) {
+ // This might be called on the MapOutputTrackerMaster if we're running in local mode.
+ }
+}
+
+private[spark] object MapOutputTracker {
+ private val LOG_BASE = 1.1
+
// Serialize an array of map output locations into an efficient byte format so that we can send
// it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will
// generally be pretty compressible because many map outputs will be on the same hostname.
- private def serializeStatuses(statuses: Array[MapStatus]): Array[Byte] = {
+ def serializeMapStatuses(statuses: Array[MapStatus]): Array[Byte] = {
val out = new ByteArrayOutputStream
val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
// Since statuses can be modified in parallel, sync on it
@@ -278,18 +293,11 @@ private[spark] class MapOutputTracker extends Logging {
out.toByteArray
}
- // Opposite of serializeStatuses.
- def deserializeStatuses(bytes: Array[Byte]): Array[MapStatus] = {
+ // Opposite of serializeMapStatuses.
+ def deserializeMapStatuses(bytes: Array[Byte]): Array[MapStatus] = {
val objIn = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes)))
- objIn.readObject().
- // // drop all null's from status - not sure why they are occuring though. Causes NPE downstream in slave if present
- // comment this out - nulls could be due to missing location ?
- asInstanceOf[Array[MapStatus]] // .filter( _ != null )
+ objIn.readObject().asInstanceOf[Array[MapStatus]]
}
-}
-
-private[spark] object MapOutputTracker {
- private val LOG_BASE = 1.1
// Convert an array of MapStatuses to locations and sizes for a given reduce ID. If
// any of the statuses is null (indicating a missing location due to a failed mapper),
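
To make the intent of the master/worker split concrete, a rough usage sketch under the same assumptions as the SparkEnv wiring just below (isDriver, shuffleId and numMaps are placeholders; both classes are private[spark], so code like this would live inside the org.apache.spark package):

    import org.apache.spark.{MapOutputTracker, MapOutputTrackerMaster}

    val isDriver = true  // placeholder flag, normally decided by SparkEnv
    val tracker: MapOutputTracker =
      if (isDriver) new MapOutputTrackerMaster() else new MapOutputTracker()

    tracker match {
      case master: MapOutputTrackerMaster =>
        // Registration, epoch bumps and serialization are now driver-only operations.
        master.registerShuffle(shuffleId = 0, numMaps = 4)
        master.incrementEpoch()
      case worker =>
        // Executors only fetch statuses; a newer epoch arriving with a task
        // makes them clear their cached map output locations.
        worker.updateEpoch(newEpoch = 1L)
    }
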
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 29968c273c..aaab717bcf 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -187,10 +187,14 @@ object SparkEnv extends Logging {
// Have to assign trackerActor after initialization as MapOutputTrackerActor
// requires the MapOutputTracker itself
- val mapOutputTracker = new MapOutputTracker()
+ val mapOutputTracker = if (isDriver) {
+ new MapOutputTrackerMaster()
+ } else {
+ new MapOutputTracker()
+ }
mapOutputTracker.trackerActor = registerOrLookup(
"MapOutputTracker",
- new MapOutputTrackerActor(mapOutputTracker))
+ new MapOutputTrackerMasterActor(mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]))
val shuffleFetcher = instantiateClass[ShuffleFetcher](
"spark.shuffle.fetcher", "org.apache.spark.BlockStoreShuffleFetcher")
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/DoubleFlatMapFunction.java b/core/src/main/scala/org/apache/spark/api/java/function/DoubleFlatMapFunction.java
index 4830067f7a..3e85052cd0 100644
--- a/core/src/main/scala/org/apache/spark/api/java/function/DoubleFlatMapFunction.java
+++ b/core/src/main/scala/org/apache/spark/api/java/function/DoubleFlatMapFunction.java
@@ -18,8 +18,6 @@
package org.apache.spark.api.java.function;
-import scala.runtime.AbstractFunction1;
-
import java.io.Serializable;
/**
@@ -27,11 +25,7 @@ import java.io.Serializable;
*/
// DoubleFlatMapFunction does not extend FlatMapFunction because flatMap is
// overloaded for both FlatMapFunction and DoubleFlatMapFunction.
-public abstract class DoubleFlatMapFunction<T> extends AbstractFunction1<T, Iterable<Double>>
+public abstract class DoubleFlatMapFunction<T> extends WrappedFunction1<T, Iterable<Double>>
implements Serializable {
-
- public abstract Iterable<Double> call(T t);
-
- @Override
- public final Iterable<Double> apply(T t) { return call(t); }
+ // Intentionally left blank
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/DoubleFunction.java b/core/src/main/scala/org/apache/spark/api/java/function/DoubleFunction.java
index db34cd190a..5e9b8c48b8 100644
--- a/core/src/main/scala/org/apache/spark/api/java/function/DoubleFunction.java
+++ b/core/src/main/scala/org/apache/spark/api/java/function/DoubleFunction.java
@@ -18,8 +18,6 @@
package org.apache.spark.api.java.function;
-import scala.runtime.AbstractFunction1;
-
import java.io.Serializable;
/**
@@ -29,6 +27,5 @@ import java.io.Serializable;
// are overloaded for both Function and DoubleFunction.
public abstract class DoubleFunction<T> extends WrappedFunction1<T, Double>
implements Serializable {
-
- public abstract Double call(T t) throws Exception;
+ // Intentionally left blank
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala b/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala
index 158539a846..2dfda8b09a 100644
--- a/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala
@@ -21,8 +21,5 @@ package org.apache.spark.api.java.function
* A function that returns zero or more output records from each input record.
*/
abstract class FlatMapFunction[T, R] extends Function[T, java.lang.Iterable[R]] {
- @throws(classOf[Exception])
- def call(x: T) : java.lang.Iterable[R]
-
def elementType() : ClassManifest[R] = ClassManifest.Any.asInstanceOf[ClassManifest[R]]
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala b/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala
index 5ef6a814f5..528e1c0a7c 100644
--- a/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala
@@ -21,8 +21,5 @@ package org.apache.spark.api.java.function
* A function that takes two inputs and returns zero or more output records.
*/
abstract class FlatMapFunction2[A, B, C] extends Function2[A, B, java.lang.Iterable[C]] {
- @throws(classOf[Exception])
- def call(a: A, b:B) : java.lang.Iterable[C]
-
def elementType() : ClassManifest[C] = ClassManifest.Any.asInstanceOf[ClassManifest[C]]
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/Function.java b/core/src/main/scala/org/apache/spark/api/java/function/Function.java
index b9070cfd83..ce368ee01b 100644
--- a/core/src/main/scala/org/apache/spark/api/java/function/Function.java
+++ b/core/src/main/scala/org/apache/spark/api/java/function/Function.java
@@ -19,7 +19,6 @@ package org.apache.spark.api.java.function;
import scala.reflect.ClassManifest;
import scala.reflect.ClassManifest$;
-import scala.runtime.AbstractFunction1;
import java.io.Serializable;
@@ -30,8 +29,6 @@ import java.io.Serializable;
* when mapping RDDs of other types.
*/
public abstract class Function<T, R> extends WrappedFunction1<T, R> implements Serializable {
- public abstract R call(T t) throws Exception;
-
public ClassManifest<R> returnType() {
return (ClassManifest<R>) ClassManifest$.MODULE$.fromClass(Object.class);
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/Function2.java b/core/src/main/scala/org/apache/spark/api/java/function/Function2.java
index d4c9154869..44ad559d48 100644
--- a/core/src/main/scala/org/apache/spark/api/java/function/Function2.java
+++ b/core/src/main/scala/org/apache/spark/api/java/function/Function2.java
@@ -19,7 +19,6 @@ package org.apache.spark.api.java.function;
import scala.reflect.ClassManifest;
import scala.reflect.ClassManifest$;
-import scala.runtime.AbstractFunction2;
import java.io.Serializable;
@@ -29,8 +28,6 @@ import java.io.Serializable;
public abstract class Function2<T1, T2, R> extends WrappedFunction2<T1, T2, R>
implements Serializable {
- public abstract R call(T1 t1, T2 t2) throws Exception;
-
public ClassManifest<R> returnType() {
return (ClassManifest<R>) ClassManifest$.MODULE$.fromClass(Object.class);
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.java b/core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.java
index c0e5544b7d..6d76a8f970 100644
--- a/core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.java
+++ b/core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.java
@@ -20,7 +20,6 @@ package org.apache.spark.api.java.function;
import scala.Tuple2;
import scala.reflect.ClassManifest;
import scala.reflect.ClassManifest$;
-import scala.runtime.AbstractFunction1;
import java.io.Serializable;
@@ -34,8 +33,6 @@ public abstract class PairFlatMapFunction<T, K, V>
extends WrappedFunction1<T, Iterable<Tuple2<K, V>>>
implements Serializable {
- public abstract Iterable<Tuple2<K, V>> call(T t) throws Exception;
-
public ClassManifest<K> keyType() {
return (ClassManifest<K>) ClassManifest$.MODULE$.fromClass(Object.class);
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/PairFunction.java b/core/src/main/scala/org/apache/spark/api/java/function/PairFunction.java
index 40480fe8e8..ede7ceefb5 100644
--- a/core/src/main/scala/org/apache/spark/api/java/function/PairFunction.java
+++ b/core/src/main/scala/org/apache/spark/api/java/function/PairFunction.java
@@ -20,7 +20,6 @@ package org.apache.spark.api.java.function;
import scala.Tuple2;
import scala.reflect.ClassManifest;
import scala.reflect.ClassManifest$;
-import scala.runtime.AbstractFunction1;
import java.io.Serializable;
@@ -29,12 +28,9 @@ import java.io.Serializable;
*/
// PairFunction does not extend Function because some UDF functions, like map,
// are overloaded for both Function and PairFunction.
-public abstract class PairFunction<T, K, V>
- extends WrappedFunction1<T, Tuple2<K, V>>
+public abstract class PairFunction<T, K, V> extends WrappedFunction1<T, Tuple2<K, V>>
implements Serializable {
- public abstract Tuple2<K, V> call(T t) throws Exception;
-
public ClassManifest<K> keyType() {
return (ClassManifest<K>) ClassManifest$.MODULE$.fromClass(Object.class);
}
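
The call/apply boilerplate deleted from each of the Java function classes above is concentrated in the shared WrappedFunction* base classes. Those bases are not shown in this diff; judging from the removed lines, WrappedFunction1 plausibly looks something like this reconstruction (treat it as a sketch, not the actual source):

    import scala.runtime.AbstractFunction1

    // Hypothetical reconstruction of the shared wrapper the subclasses extend.
    private[spark] abstract class WrappedFunction1[T, R] extends AbstractFunction1[T, R] {
      @throws(classOf[Exception])
      def call(t: T): R

      // Scala's Function1.apply delegates to the Java-facing call(), which is
      // exactly the pair of members removed from each subclass above.
      final def apply(t: T): R = call(t)
    }
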
diff --git a/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala b/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala
index 1586dff254..546d921067 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala
@@ -21,7 +21,7 @@ import java.io.File
import org.apache.spark.Logging
import org.apache.spark.util.Utils
-import org.apache.spark.storage.BlockId
+import org.apache.spark.storage.{BlockId, FileSegment}
private[spark] class ShuffleSender(portIn: Int, val pResolver: PathResolver) extends Logging {
@@ -54,8 +54,7 @@ private[spark] object ShuffleSender {
val localDirs = args.drop(2).map(new File(_))
val pResovler = new PathResolver {
- override def getAbsolutePath(blockIdString: String): String = {
- val blockId = BlockId(blockIdString)
+ override def getBlockLocation(blockId: BlockId): FileSegment = {
if (!blockId.isShuffle) {
throw new Exception("Block " + blockId + " is not a shuffle block")
}
@@ -65,7 +64,7 @@ private[spark] object ShuffleSender {
val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
val subDir = new File(localDirs(dirId), "%02x".format(subDirId))
val file = new File(subDir, blockId.name)
- return file.getAbsolutePath
+ return new FileSegment(file, 0, file.length())
}
}
val sender = new ShuffleSender(port, pResovler)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index d84f5968df..dda194d953 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -52,13 +52,14 @@ import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedH
private[spark]
class DAGScheduler(
taskSched: TaskScheduler,
- mapOutputTracker: MapOutputTracker,
+ mapOutputTracker: MapOutputTrackerMaster,
blockManagerMaster: BlockManagerMaster,
env: SparkEnv)
extends Logging {
def this(taskSched: TaskScheduler) {
- this(taskSched, SparkEnv.get.mapOutputTracker, SparkEnv.get.blockManager.master, SparkEnv.get)
+ this(taskSched, SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster],
+ SparkEnv.get.blockManager.master, SparkEnv.get)
}
taskSched.setDAGScheduler(this)
@@ -182,7 +183,7 @@ class DAGScheduler(
shuffleToMapStage.get(shuffleDep.shuffleId) match {
case Some(stage) => stage
case None =>
- val stage = newStage(shuffleDep.rdd, Some(shuffleDep), jobId)
+ val stage = newStage(shuffleDep.rdd, shuffleDep.rdd.partitions.size, Some(shuffleDep), jobId)
shuffleToMapStage(shuffleDep.shuffleId) = stage
stage
}
@@ -195,6 +196,7 @@ class DAGScheduler(
*/
private def newStage(
rdd: RDD[_],
+ numTasks: Int,
shuffleDep: Option[ShuffleDependency[_,_]],
jobId: Int,
callSite: Option[String] = None)
@@ -207,9 +209,10 @@ class DAGScheduler(
mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.partitions.size)
}
val id = nextStageId.getAndIncrement()
- val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd, jobId), jobId, callSite)
+ val stage =
+ new Stage(id, rdd, numTasks, shuffleDep, getParentStages(rdd, jobId), jobId, callSite)
stageIdToStage(id) = stage
- stageToInfos(stage) = StageInfo(stage)
+ stageToInfos(stage) = new StageInfo(stage)
stage
}
@@ -361,7 +364,7 @@ class DAGScheduler(
private[scheduler] def processEvent(event: DAGSchedulerEvent): Boolean = {
event match {
case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) =>
- val finalStage = newStage(rdd, None, jobId, Some(callSite))
+ val finalStage = newStage(rdd, partitions.size, None, jobId, Some(callSite))
val job = new ActiveJob(jobId, finalStage, func, partitions, callSite, listener, properties)
clearCacheLocs()
logInfo("Got job " + job.jobId + " (" + callSite + ") with " + partitions.length +
@@ -584,7 +587,7 @@ class DAGScheduler(
// must be run listener before possible NotSerializableException
// should be "StageSubmitted" first and then "JobEnded"
- listenerBus.post(SparkListenerStageSubmitted(stage, tasks.size, properties))
+ listenerBus.post(SparkListenerStageSubmitted(stageToInfos(stage), properties))
if (tasks.size > 0) {
// Preemptively serialize a task to make sure it can be serialized. We are catching this
@@ -605,9 +608,7 @@ class DAGScheduler(
logDebug("New pending tasks: " + myPending)
taskSched.submitTasks(
new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties))
- if (!stage.submissionTime.isDefined) {
- stage.submissionTime = Some(System.currentTimeMillis())
- }
+ stageToInfos(stage).submissionTime = Some(System.currentTimeMillis())
} else {
logDebug("Stage " + stage + " is actually done; %b %d %d".format(
stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions))
@@ -629,12 +630,12 @@ class DAGScheduler(
val stage = stageIdToStage(task.stageId)
def markStageAsFinished(stage: Stage) = {
- val serviceTime = stage.submissionTime match {
+ val serviceTime = stageToInfos(stage).submissionTime match {
case Some(t) => "%.03f".format((System.currentTimeMillis() - t) / 1000.0)
- case _ => "Unkown"
+ case _ => "Unknown"
}
logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime))
- stage.completionTime = Some(System.currentTimeMillis)
+ stageToInfos(stage).completionTime = Some(System.currentTimeMillis())
listenerBus.post(StageCompleted(stageToInfos(stage)))
running -= stage
}
@@ -804,7 +805,7 @@ class DAGScheduler(
*/
private def abortStage(failedStage: Stage, reason: String) {
val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq
- failedStage.completionTime = Some(System.currentTimeMillis())
+ stageToInfos(failedStage).completionTime = Some(System.currentTimeMillis())
for (resultStage <- dependentStages) {
val job = resultStageToJob(resultStage)
val error = new SparkException("Job aborted: " + reason)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
index 19c0251690..12b0d74fb5 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
@@ -207,13 +207,13 @@ class JobLogger(val logDirName: String) extends SparkListener with Logging {
}
override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) {
- stageLogInfo(stageSubmitted.stage.id, "STAGE_ID=%d STATUS=SUBMITTED TASK_SIZE=%d".format(
- stageSubmitted.stage.id, stageSubmitted.taskSize))
+ stageLogInfo(stageSubmitted.stage.stageId,"STAGE_ID=%d STATUS=SUBMITTED TASK_SIZE=%d".format(
+ stageSubmitted.stage.stageId, stageSubmitted.stage.numTasks))
}
override def onStageCompleted(stageCompleted: StageCompleted) {
- stageLogInfo(stageCompleted.stageInfo.stage.id,
- "STAGE_ID=%d STATUS=COMPLETED".format(stageCompleted.stageInfo.stage.id))
+ stageLogInfo(stageCompleted.stage.stageId, "STAGE_ID=%d STATUS=COMPLETED".format(
+ stageCompleted.stage.stageId))
}
override def onTaskStart(taskStart: SparkListenerTaskStart) { }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index 40baea69e8..24d97da6eb 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -167,8 +167,7 @@ private[spark] class ShuffleMapTask(
var totalTime = 0L
val compressedSizes: Array[Byte] = buckets.writers.map { writer: BlockObjectWriter =>
writer.commit()
- writer.close()
- val size = writer.size()
+ val size = writer.fileSegment().length
totalBytes += size
totalTime += writer.timeWriting()
MapOutputTracker.compressSize(size)
@@ -191,6 +190,7 @@ private[spark] class ShuffleMapTask(
} finally {
// Release the writers back to the shuffle block manager.
if (shuffle != null && buckets != null) {
+ buckets.writers.foreach(_.close())
shuffle.releaseWriters(buckets)
}
// Execute the callbacks on task completion.
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
index 466baf9913..324cd639b0 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
@@ -24,10 +24,10 @@ import org.apache.spark.executor.TaskMetrics
sealed trait SparkListenerEvents
-case class SparkListenerStageSubmitted(stage: Stage, taskSize: Int, properties: Properties)
+case class SparkListenerStageSubmitted(stage: StageInfo, properties: Properties)
extends SparkListenerEvents
-case class StageCompleted(val stageInfo: StageInfo) extends SparkListenerEvents
+case class StageCompleted(val stage: StageInfo) extends SparkListenerEvents
case class SparkListenerTaskStart(task: Task[_], taskInfo: TaskInfo) extends SparkListenerEvents
@@ -80,7 +80,7 @@ class StatsReportListener extends SparkListener with Logging {
override def onStageCompleted(stageCompleted: StageCompleted) {
import org.apache.spark.scheduler.StatsReportListener._
implicit val sc = stageCompleted
- this.logInfo("Finished stage: " + stageCompleted.stageInfo)
+ this.logInfo("Finished stage: " + stageCompleted.stage)
showMillisDistribution("task runtime:", (info, _) => Some(info.duration))
//shuffle write
@@ -93,7 +93,7 @@ class StatsReportListener extends SparkListener with Logging {
//runtime breakdown
- val runtimePcts = stageCompleted.stageInfo.taskInfos.map{
+ val runtimePcts = stageCompleted.stage.taskInfos.map{
case (info, metrics) => RuntimePercentage(info.duration, metrics)
}
showDistribution("executor (non-fetch) time pct: ", Distribution(runtimePcts.map{_.executorPct * 100}), "%2.0f %%")
@@ -111,7 +111,7 @@ object StatsReportListener extends Logging {
val percentilesHeader = "\t" + percentiles.mkString("%\t") + "%"
def extractDoubleDistribution(stage:StageCompleted, getMetric: (TaskInfo,TaskMetrics) => Option[Double]): Option[Distribution] = {
- Distribution(stage.stageInfo.taskInfos.flatMap{
+ Distribution(stage.stage.taskInfos.flatMap {
case ((info,metric)) => getMetric(info, metric)})
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
index aa293dc6b3..7cb3fe46e5 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -39,6 +39,7 @@ import org.apache.spark.storage.BlockManagerId
private[spark] class Stage(
val id: Int,
val rdd: RDD[_],
+ val numTasks: Int,
val shuffleDep: Option[ShuffleDependency[_,_]], // Output shuffle if stage is a map stage
val parents: List[Stage],
val jobId: Int,
@@ -49,11 +50,6 @@ private[spark] class Stage(
val numPartitions = rdd.partitions.size
val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil)
var numAvailableOutputs = 0
-
- /** When first task was submitted to scheduler. */
- var submissionTime: Option[Long] = None
- var completionTime: Option[Long] = None
-
private var nextAttemptId = 0
def isAvailable: Boolean = {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
index b6f11969e5..37bb0891f6 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
@@ -21,9 +21,18 @@ import scala.collection._
import org.apache.spark.executor.TaskMetrics
-case class StageInfo(
- val stage: Stage,
+class StageInfo(
+ stage: Stage,
val taskInfos: mutable.Buffer[(TaskInfo, TaskMetrics)] = mutable.Buffer[(TaskInfo, TaskMetrics)]()
) {
- override def toString = stage.rdd.toString
+ val stageId = stage.id
+ /** When this stage was submitted from the DAGScheduler to a TaskScheduler. */
+ var submissionTime: Option[Long] = None
+ var completionTime: Option[Long] = None
+ val rddName = stage.rdd.name
+ val name = stage.name
+ val numPartitions = stage.numPartitions
+ val numTasks = stage.numTasks
+
+ override def toString = rddName
}
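
Since SparkListener events now carry StageInfo rather than Stage, a listener can read stage metadata and timing without touching scheduler internals. A minimal sketch against the fields defined above (the listener name and println output are illustrative only):

    import org.apache.spark.scheduler.{SparkListener, SparkListenerStageSubmitted, StageCompleted}

    class StageTimingListener extends SparkListener {
      override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) {
        val info = stageSubmitted.stage
        println("Submitted stage %d (%s) with %d tasks".format(info.stageId, info.name, info.numTasks))
      }

      override def onStageCompleted(stageCompleted: StageCompleted) {
        val info = stageCompleted.stage
        // submissionTime and completionTime now live on StageInfo, not Stage.
        for (start <- info.submissionTime; end <- info.completionTime) {
          println("Stage %d ran for %d ms".format(info.stageId, end - start))
        }
      }
    }
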
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 1fe0d0e4e2..69b42e86ea 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -45,7 +45,7 @@ import org.apache.spark.util.ByteBufferInputStream
*/
private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable {
- def run(attemptId: Long): T = {
+ final def run(attemptId: Long): T = {
context = new TaskContext(stageId, partitionId, attemptId, runningLocally = false)
if (_killed) {
kill()
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 768e5a647f..e6329cbd47 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -28,7 +28,7 @@ import akka.dispatch.{Await, Future}
import akka.util.Duration
import akka.util.duration._
-import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
+import it.unimi.dsi.fastutil.io.{FastBufferedOutputStream, FastByteArrayOutputStream}
import org.apache.spark.{Logging, SparkEnv, SparkException}
import org.apache.spark.io.CompressionCodec
@@ -102,18 +102,19 @@ private[spark] class BlockManager(
}
val shuffleBlockManager = new ShuffleBlockManager(this)
+ val diskBlockManager = new DiskBlockManager(
+ System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")))
private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo]
private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory)
- private[storage] val diskStore: DiskStore =
- new DiskStore(this, System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")))
+ private[storage] val diskStore = new DiskStore(this, diskBlockManager)
// If we use Netty for shuffle, start a new Netty-based shuffle sender service.
private val nettyPort: Int = {
val useNetty = System.getProperty("spark.shuffle.use.netty", "false").toBoolean
val nettyPortConfig = System.getProperty("spark.shuffle.sender.port", "0").toInt
- if (useNetty) diskStore.startShuffleBlockSender(nettyPortConfig) else 0
+ if (useNetty) diskBlockManager.startShuffleBlockSender(nettyPortConfig) else 0
}
val connectionManager = new ConnectionManager(0)
@@ -512,16 +513,20 @@ private[spark] class BlockManager(
/**
* A short circuited method to get a block writer that can write data directly to disk.
+ * The Block will be appended to the File specified by filename.
* This is currently used for writing shuffle files out. Callers should handle error
* cases.
*/
- def getDiskBlockWriter(blockId: BlockId, serializer: Serializer, bufferSize: Int)
+ def getDiskWriter(blockId: BlockId, filename: String, serializer: Serializer, bufferSize: Int)
: BlockObjectWriter = {
- val writer = diskStore.getBlockWriter(blockId, serializer, bufferSize)
+ val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _)
+ val file = diskBlockManager.createBlockFile(blockId, filename, allowAppending = true)
+ val writer = new DiskBlockObjectWriter(blockId, file, serializer, bufferSize, compressStream)
writer.registerCloseEventHandler(() => {
+ diskBlockManager.mapBlockToFileSegment(blockId, writer.fileSegment())
val myInfo = new BlockInfo(StorageLevel.DISK_ONLY, false)
blockInfo.put(blockId, myInfo)
- myInfo.markReady(writer.size())
+ myInfo.markReady(writer.fileSegment().length)
})
writer
}
@@ -862,13 +867,24 @@ private[spark] class BlockManager(
if (shouldCompress(blockId)) compressionCodec.compressedInputStream(s) else s
}
+ /** Serializes into a stream. */
+ def dataSerializeStream(
+ blockId: BlockId,
+ outputStream: OutputStream,
+ values: Iterator[Any],
+ serializer: Serializer = defaultSerializer) {
+ val byteStream = new FastBufferedOutputStream(outputStream)
+ val ser = serializer.newInstance()
+ ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()
+ }
+
+ /** Serializes into a byte buffer. */
def dataSerialize(
blockId: BlockId,
values: Iterator[Any],
serializer: Serializer = defaultSerializer): ByteBuffer = {
val byteStream = new FastByteArrayOutputStream(4096)
- val ser = serializer.newInstance()
- ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()
+ dataSerializeStream(blockId, byteStream, values, serializer)
byteStream.trim()
ByteBuffer.wrap(byteStream.array)
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
index 76c92cefd8..32d2dd0694 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
@@ -17,6 +17,13 @@
package org.apache.spark.storage
+import java.io.{FileOutputStream, File, OutputStream}
+import java.nio.channels.FileChannel
+
+import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
+
+import org.apache.spark.Logging
+import org.apache.spark.serializer.{SerializationStream, Serializer}
/**
* An interface for writing JVM objects to some underlying storage. This interface allows
@@ -59,12 +66,129 @@ abstract class BlockObjectWriter(val blockId: BlockId) {
def write(value: Any)
/**
- * Size of the valid writes, in bytes.
+ * Returns the file segment of committed data that this Writer has written.
*/
- def size(): Long
+ def fileSegment(): FileSegment
/**
* Cumulative time spent performing blocking writes, in ns.
*/
def timeWriting(): Long
}
+
+/** BlockObjectWriter which writes directly to a file on disk. Appends to the given file. */
+class DiskBlockObjectWriter(
+ blockId: BlockId,
+ file: File,
+ serializer: Serializer,
+ bufferSize: Int,
+ compressStream: OutputStream => OutputStream)
+ extends BlockObjectWriter(blockId)
+ with Logging
+{
+
+ /** Intercepts write calls and tracks total time spent writing. Not thread safe. */
+ private class TimeTrackingOutputStream(out: OutputStream) extends OutputStream {
+ def timeWriting = _timeWriting
+ private var _timeWriting = 0L
+
+ private def callWithTiming(f: => Unit) = {
+ val start = System.nanoTime()
+ f
+ _timeWriting += (System.nanoTime() - start)
+ }
+
+ def write(i: Int): Unit = callWithTiming(out.write(i))
+ override def write(b: Array[Byte]) = callWithTiming(out.write(b))
+ override def write(b: Array[Byte], off: Int, len: Int) = callWithTiming(out.write(b, off, len))
+ }
+
+ private val syncWrites = System.getProperty("spark.shuffle.sync", "false").toBoolean
+
+ /** The file channel, used for repositioning / truncating the file. */
+ private var channel: FileChannel = null
+ private var bs: OutputStream = null
+ private var fos: FileOutputStream = null
+ private var ts: TimeTrackingOutputStream = null
+ private var objOut: SerializationStream = null
+ private var initialPosition = 0L
+ private var lastValidPosition = 0L
+ private var initialized = false
+ private var _timeWriting = 0L
+
+ override def open(): BlockObjectWriter = {
+ fos = new FileOutputStream(file, true)
+ ts = new TimeTrackingOutputStream(fos)
+ channel = fos.getChannel()
+ initialPosition = channel.position
+ lastValidPosition = initialPosition
+ bs = compressStream(new FastBufferedOutputStream(ts, bufferSize))
+ objOut = serializer.newInstance().serializeStream(bs)
+ initialized = true
+ this
+ }
+
+ override def close() {
+ if (initialized) {
+ if (syncWrites) {
+ // Force outstanding writes to disk and track how long it takes
+ objOut.flush()
+ val start = System.nanoTime()
+ fos.getFD.sync()
+ _timeWriting += System.nanoTime() - start
+ }
+ objOut.close()
+
+ _timeWriting += ts.timeWriting
+
+ channel = null
+ bs = null
+ fos = null
+ ts = null
+ objOut = null
+ }
+ // Invoke the close callback handler.
+ super.close()
+ }
+
+ override def isOpen: Boolean = objOut != null
+
+ override def commit(): Long = {
+ if (initialized) {
+ // NOTE: Flush the serializer first and then the compressed/buffered output stream
+ objOut.flush()
+ bs.flush()
+ val prevPos = lastValidPosition
+ lastValidPosition = channel.position()
+ lastValidPosition - prevPos
+ } else {
+ // lastValidPosition is zero if stream is uninitialized
+ lastValidPosition
+ }
+ }
+
+ override def revertPartialWrites() {
+ if (initialized) {
+ // Discard current writes. We do this by flushing the outstanding writes and
+ // truncate the file to the last valid position.
+ objOut.flush()
+ bs.flush()
+ channel.truncate(lastValidPosition)
+ }
+ }
+
+ override def write(value: Any) {
+ if (!initialized) {
+ open()
+ }
+ objOut.writeObject(value)
+ }
+
+ override def fileSegment(): FileSegment = {
+ val bytesWritten = lastValidPosition - initialPosition
+ new FileSegment(file, initialPosition, bytesWritten)
+ }
+
+ // Only valid if called after close()
+ override def timeWriting() = _timeWriting
+}
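
Putting the pieces together, the writer lifecycle implied by the ShuffleMapTask hunk above is roughly the following (blockManager, blockId, serializer and records are placeholders; the buffer size is arbitrary):

    // open() happens lazily on the first write; commit() flushes and records
    // the current file position; close() fires the registered close handler.
    val writer = blockManager.getDiskWriter(blockId, blockId.name, serializer, 100 * 1024)
    records.foreach(r => writer.write(r))
    val bytesCommitted = writer.commit()
    val segment = writer.fileSegment()   // FileSegment(file, offset, committed length)
    writer.close()
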
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
new file mode 100644
index 0000000000..bcb58ad946
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
@@ -0,0 +1,184 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.io.File
+import java.text.SimpleDateFormat
+import java.util.{Date, Random}
+import java.util.concurrent.ConcurrentHashMap
+
+import org.apache.spark.Logging
+import org.apache.spark.executor.ExecutorExitCode
+import org.apache.spark.network.netty.{PathResolver, ShuffleSender}
+import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap, Utils}
+
+/**
+ * Creates and maintains the logical mapping between logical blocks and physical on-disk
+ * locations. By default, one block is mapped to one file with a name given by its BlockId.
+ * However, it is also possible to have a block map to only a segment of a file, by calling
+ * mapBlockToFileSegment().
+ *
+ * @param rootDirs The directories to use for storing block files. Data will be hashed among these.
+ */
+private[spark] class DiskBlockManager(rootDirs: String) extends PathResolver with Logging {
+
+ private val MAX_DIR_CREATION_ATTEMPTS: Int = 10
+ private val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt
+
+ // Create one local directory for each path mentioned in spark.local.dir; then, inside this
+ // directory, create multiple subdirectories that we will hash files into, in order to avoid
+ // having really large inodes at the top level.
+ private val localDirs: Array[File] = createLocalDirs()
+ private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir))
+ private var shuffleSender : ShuffleSender = null
+
+ // Stores only Blocks which have been specifically mapped to segments of files
+ // (rather than the default, which maps a Block to a whole file).
+ // This keeps our bookkeeping down, since the file system itself tracks the standalone Blocks.
+ private val blockToFileSegmentMap = new TimeStampedHashMap[BlockId, FileSegment]
+
+ val metadataCleaner = new MetadataCleaner(MetadataCleanerType.DISK_BLOCK_MANAGER, this.cleanup)
+
+ addShutdownHook()
+
+ /**
+ * Creates a logical mapping from the given BlockId to a segment of a file.
+ * This will cause any accesses of the logical BlockId to be directed to the specified
+ * physical location.
+ */
+ def mapBlockToFileSegment(blockId: BlockId, fileSegment: FileSegment) {
+ blockToFileSegmentMap.put(blockId, fileSegment)
+ }
+
+ /**
+ * Returns the physical file segment in which the given BlockId is located.
+ * If the BlockId has been mapped to a specific FileSegment, that will be returned.
+ * Otherwise, we assume the Block is mapped to a whole file identified by the BlockId directly.
+ */
+ def getBlockLocation(blockId: BlockId): FileSegment = {
+ if (blockToFileSegmentMap.internalMap.containsKey(blockId)) {
+ blockToFileSegmentMap.get(blockId).get
+ } else {
+ val file = getFile(blockId.name)
+ new FileSegment(file, 0, file.length())
+ }
+ }
+
+ /**
+ * Simply returns a File to place the given Block into. This does not physically create the file.
+ * If filename is given, that file will be used. Otherwise, we will use the BlockId to get
+ * a unique filename.
+ */
+ def createBlockFile(blockId: BlockId, filename: String = "", allowAppending: Boolean): File = {
+ val actualFilename = if (filename == "") blockId.name else filename
+ val file = getFile(actualFilename)
+ if (!allowAppending && file.exists()) {
+ throw new IllegalStateException(
+ "Attempted to create file that already exists: " + actualFilename)
+ }
+ file
+ }
+
+ private def getFile(filename: String): File = {
+ // Figure out which local directory it hashes to, and which subdirectory in that
+ val hash = Utils.nonNegativeHash(filename)
+ val dirId = hash % localDirs.length
+ val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
+
+ // Create the subdirectory if it doesn't already exist
+ var subDir = subDirs(dirId)(subDirId)
+ if (subDir == null) {
+ subDir = subDirs(dirId).synchronized {
+ val old = subDirs(dirId)(subDirId)
+ if (old != null) {
+ old
+ } else {
+ val newDir = new File(localDirs(dirId), "%02x".format(subDirId))
+ newDir.mkdir()
+ subDirs(dirId)(subDirId) = newDir
+ newDir
+ }
+ }
+ }
+
+ new File(subDir, filename)
+ }
+
+ private def createLocalDirs(): Array[File] = {
+ logDebug("Creating local directories at root dirs '" + rootDirs + "'")
+ val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss")
+ rootDirs.split(",").map { rootDir =>
+ var foundLocalDir = false
+ var localDir: File = null
+ var localDirId: String = null
+ var tries = 0
+ val rand = new Random()
+ while (!foundLocalDir && tries < MAX_DIR_CREATION_ATTEMPTS) {
+ tries += 1
+ try {
+ localDirId = "%s-%04x".format(dateFormat.format(new Date), rand.nextInt(65536))
+ localDir = new File(rootDir, "spark-local-" + localDirId)
+ if (!localDir.exists) {
+ foundLocalDir = localDir.mkdirs()
+ }
+ } catch {
+ case e: Exception =>
+ logWarning("Attempt " + tries + " to create local dir " + localDir + " failed", e)
+ }
+ }
+ if (!foundLocalDir) {
+ logError("Failed " + MAX_DIR_CREATION_ATTEMPTS +
+ " attempts to create local dir in " + rootDir)
+ System.exit(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR)
+ }
+ logInfo("Created local directory at " + localDir)
+ localDir
+ }
+ }
+
+ private def cleanup(cleanupTime: Long) {
+ blockToFileSegmentMap.clearOldValues(cleanupTime)
+ }
+
+ private def addShutdownHook() {
+ localDirs.foreach(localDir => Utils.registerShutdownDeleteDir(localDir))
+ Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") {
+ override def run() {
+ logDebug("Shutdown hook called")
+ localDirs.foreach { localDir =>
+ try {
+ if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir)
+ } catch {
+ case t: Throwable =>
+ logError("Exception while deleting local spark dir: " + localDir, t)
+ }
+ }
+
+ if (shuffleSender != null) {
+ shuffleSender.stop()
+ }
+ }
+ })
+ }
+
+ private[storage] def startShuffleBlockSender(port: Int): Int = {
+ shuffleSender = new ShuffleSender(port, this)
+ logInfo("Created ShuffleSender binding to port : " + shuffleSender.port)
+ shuffleSender.port
+ }
+}
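
A short sketch of the two resolution paths described in the class comment above (diskBlockManager, blockId and consolidatedFile are placeholders):

    // Without an explicit mapping, the block resolves to a whole file named
    // after the BlockId.
    val whole = diskBlockManager.getBlockLocation(blockId)

    // After mapping, the same BlockId resolves to a slice of a shared file,
    // here bytes [4096, 4096 + 1024) of consolidatedFile.
    diskBlockManager.mapBlockToFileSegment(blockId, new FileSegment(consolidatedFile, 4096L, 1024L))
    val mapped = diskBlockManager.getBlockLocation(blockId)
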
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
index 2a9a3f61bd..a3c496f9e0 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
@@ -17,158 +17,25 @@
package org.apache.spark.storage
-import java.io.{File, FileOutputStream, OutputStream, RandomAccessFile}
+import java.io.{FileOutputStream, RandomAccessFile}
import java.nio.ByteBuffer
-import java.nio.channels.FileChannel
import java.nio.channels.FileChannel.MapMode
-import java.util.{Random, Date}
-import java.text.SimpleDateFormat
import scala.collection.mutable.ArrayBuffer
-import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
-
-import org.apache.spark.executor.ExecutorExitCode
-import org.apache.spark.serializer.{Serializer, SerializationStream}
import org.apache.spark.Logging
-import org.apache.spark.network.netty.ShuffleSender
-import org.apache.spark.network.netty.PathResolver
+import org.apache.spark.serializer.Serializer
import org.apache.spark.util.Utils
/**
* Stores BlockManager blocks on disk.
*/
-private class DiskStore(blockManager: BlockManager, rootDirs: String)
+private class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManager)
extends BlockStore(blockManager) with Logging {
- class DiskBlockObjectWriter(blockId: BlockId, serializer: Serializer, bufferSize: Int)
- extends BlockObjectWriter(blockId) {
-
- /** Intercepts write calls and tracks total time spent writing. Not thread safe. */
- private class TimeTrackingOutputStream(out: OutputStream) extends OutputStream {
- def timeWriting = _timeWriting
- private var _timeWriting = 0L
-
- private def callWithTiming(f: => Unit) = {
- val start = System.nanoTime()
- f
- _timeWriting += (System.nanoTime() - start)
- }
-
- def write(i: Int): Unit = callWithTiming(out.write(i))
- override def write(b: Array[Byte]) = callWithTiming(out.write(b))
- override def write(b: Array[Byte], off: Int, len: Int) = callWithTiming(out.write(b, off, len))
- }
-
- private val f: File = createFile(blockId /*, allowAppendExisting */)
- private val syncWrites = System.getProperty("spark.shuffle.sync", "false").toBoolean
-
- // The file channel, used for repositioning / truncating the file.
- private var channel: FileChannel = null
- private var bs: OutputStream = null
- private var fos: FileOutputStream = null
- private var ts: TimeTrackingOutputStream = null
- private var objOut: SerializationStream = null
- private var lastValidPosition = 0L
- private var initialized = false
- private var _timeWriting = 0L
-
- override def open(): DiskBlockObjectWriter = {
- fos = new FileOutputStream(f, true)
- ts = new TimeTrackingOutputStream(fos)
- channel = fos.getChannel()
- bs = blockManager.wrapForCompression(blockId, new FastBufferedOutputStream(ts, bufferSize))
- objOut = serializer.newInstance().serializeStream(bs)
- initialized = true
- this
- }
-
- override def close() {
- if (initialized) {
- if (syncWrites) {
- // Force outstanding writes to disk and track how long it takes
- objOut.flush()
- val start = System.nanoTime()
- fos.getFD.sync()
- _timeWriting += System.nanoTime() - start
- objOut.close()
- } else {
- objOut.close()
- }
-
- _timeWriting += ts.timeWriting
-
- channel = null
- bs = null
- fos = null
- ts = null
- objOut = null
- }
- // Invoke the close callback handler.
- super.close()
- }
-
- override def isOpen: Boolean = objOut != null
-
- // Flush the partial writes, and set valid length to be the length of the entire file.
- // Return the number of bytes written for this commit.
- override def commit(): Long = {
- if (initialized) {
- // NOTE: Flush the serializer first and then the compressed/buffered output stream
- objOut.flush()
- bs.flush()
- val prevPos = lastValidPosition
- lastValidPosition = channel.position()
- lastValidPosition - prevPos
- } else {
- // lastValidPosition is zero if stream is uninitialized
- lastValidPosition
- }
- }
-
- override def revertPartialWrites() {
- if (initialized) {
- // Discard current writes. We do this by flushing the outstanding writes and
- // truncate the file to the last valid position.
- objOut.flush()
- bs.flush()
- channel.truncate(lastValidPosition)
- }
- }
-
- override def write(value: Any) {
- if (!initialized) {
- open()
- }
- objOut.writeObject(value)
- }
-
- override def size(): Long = lastValidPosition
-
- // Only valid if called after close()
- override def timeWriting = _timeWriting
- }
-
- private val MAX_DIR_CREATION_ATTEMPTS: Int = 10
- private val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt
-
- private var shuffleSender : ShuffleSender = null
- // Create one local directory for each path mentioned in spark.local.dir; then, inside this
- // directory, create multiple subdirectories that we will hash files into, in order to avoid
- // having really large inodes at the top level.
- private val localDirs: Array[File] = createLocalDirs()
- private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir))
-
- addShutdownHook()
-
- def getBlockWriter(blockId: BlockId, serializer: Serializer, bufferSize: Int)
- : BlockObjectWriter = {
- new DiskBlockObjectWriter(blockId, serializer, bufferSize)
- }
-
override def getSize(blockId: BlockId): Long = {
- getFile(blockId).length()
+ diskManager.getBlockLocation(blockId).length
}
override def putBytes(blockId: BlockId, _bytes: ByteBuffer, level: StorageLevel) {
@@ -177,27 +44,15 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
val bytes = _bytes.duplicate()
logDebug("Attempting to put block " + blockId)
val startTime = System.currentTimeMillis
- val file = createFile(blockId)
- val channel = new RandomAccessFile(file, "rw").getChannel()
+ val file = diskManager.createBlockFile(blockId, allowAppending = false)
+ val channel = new FileOutputStream(file).getChannel()
while (bytes.remaining > 0) {
channel.write(bytes)
}
channel.close()
val finishTime = System.currentTimeMillis
logDebug("Block %s stored as %s file on disk in %d ms".format(
- blockId, Utils.bytesToString(bytes.limit), (finishTime - startTime)))
- }
-
- private def getFileBytes(file: File): ByteBuffer = {
- val length = file.length()
- val channel = new RandomAccessFile(file, "r").getChannel()
- val buffer = try {
- channel.map(MapMode.READ_ONLY, 0, length)
- } finally {
- channel.close()
- }
-
- buffer
+ file.getName, Utils.bytesToString(bytes.limit), (finishTime - startTime)))
}
override def putValues(
@@ -209,21 +64,18 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
logDebug("Attempting to write values for block " + blockId)
val startTime = System.currentTimeMillis
- val file = createFile(blockId)
- val fileOut = blockManager.wrapForCompression(blockId,
- new FastBufferedOutputStream(new FileOutputStream(file)))
- val objOut = blockManager.defaultSerializer.newInstance().serializeStream(fileOut)
- objOut.writeAll(values.iterator)
- objOut.close()
- val length = file.length()
+ val file = diskManager.createBlockFile(blockId, allowAppending = false)
+ val outputStream = new FileOutputStream(file)
+ blockManager.dataSerializeStream(blockId, outputStream, values.iterator)
+ val length = file.length
val timeTaken = System.currentTimeMillis - startTime
logDebug("Block %s stored as %s file on disk in %d ms".format(
- blockId, Utils.bytesToString(length), timeTaken))
+ file.getName, Utils.bytesToString(length), timeTaken))
if (returnValues) {
// Return a byte buffer for the contents of the file
- val buffer = getFileBytes(file)
+ val buffer = getBytes(blockId).get
PutResult(length, Right(buffer))
} else {
PutResult(length, null)
@@ -231,13 +83,18 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
}
override def getBytes(blockId: BlockId): Option[ByteBuffer] = {
- val file = getFile(blockId)
- val bytes = getFileBytes(file)
- Some(bytes)
+ val segment = diskManager.getBlockLocation(blockId)
+ val channel = new RandomAccessFile(segment.file, "r").getChannel()
+ val buffer = try {
+ channel.map(MapMode.READ_ONLY, segment.offset, segment.length)
+ } finally {
+ channel.close()
+ }
+ Some(buffer)
}
override def getValues(blockId: BlockId): Option[Iterator[Any]] = {
- getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes))
+ getBytes(blockId).map(buffer => blockManager.dataDeserialize(blockId, buffer))
}
/**
@@ -249,118 +106,20 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
}
override def remove(blockId: BlockId): Boolean = {
- val file = getFile(blockId)
- if (file.exists()) {
+ val fileSegment = diskManager.getBlockLocation(blockId)
+ val file = fileSegment.file
+ if (file.exists() && file.length() == fileSegment.length) {
file.delete()
} else {
+ if (fileSegment.length < file.length()) {
+ logWarning("Could not delete block associated with only a part of a file: " + blockId)
+ }
false
}
}
override def contains(blockId: BlockId): Boolean = {
- getFile(blockId).exists()
- }
-
- private def createFile(blockId: BlockId, allowAppendExisting: Boolean = false): File = {
- val file = getFile(blockId)
- if (!allowAppendExisting && file.exists()) {
- // NOTE(shivaram): Delete the file if it exists. This might happen if a ShuffleMap task
- // was rescheduled on the same machine as the old task.
- logWarning("File for block " + blockId + " already exists on disk: " + file + ". Deleting")
- file.delete()
- }
- file
- }
-
- private def getFile(blockId: BlockId): File = {
- logDebug("Getting file for block " + blockId)
-
- // Figure out which local directory it hashes to, and which subdirectory in that
- val hash = Utils.nonNegativeHash(blockId)
- val dirId = hash % localDirs.length
- val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
-
- // Create the subdirectory if it doesn't already exist
- var subDir = subDirs(dirId)(subDirId)
- if (subDir == null) {
- subDir = subDirs(dirId).synchronized {
- val old = subDirs(dirId)(subDirId)
- if (old != null) {
- old
- } else {
- val newDir = new File(localDirs(dirId), "%02x".format(subDirId))
- newDir.mkdir()
- subDirs(dirId)(subDirId) = newDir
- newDir
- }
- }
- }
-
- new File(subDir, blockId.name)
- }
-
- private def createLocalDirs(): Array[File] = {
- logDebug("Creating local directories at root dirs '" + rootDirs + "'")
- val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss")
- rootDirs.split(",").map { rootDir =>
- var foundLocalDir = false
- var localDir: File = null
- var localDirId: String = null
- var tries = 0
- val rand = new Random()
- while (!foundLocalDir && tries < MAX_DIR_CREATION_ATTEMPTS) {
- tries += 1
- try {
- localDirId = "%s-%04x".format(dateFormat.format(new Date), rand.nextInt(65536))
- localDir = new File(rootDir, "spark-local-" + localDirId)
- if (!localDir.exists) {
- foundLocalDir = localDir.mkdirs()
- }
- } catch {
- case e: Exception =>
- logWarning("Attempt " + tries + " to create local dir " + localDir + " failed", e)
- }
- }
- if (!foundLocalDir) {
- logError("Failed " + MAX_DIR_CREATION_ATTEMPTS +
- " attempts to create local dir in " + rootDir)
- System.exit(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR)
- }
- logInfo("Created local directory at " + localDir)
- localDir
- }
- }
-
- private def addShutdownHook() {
- localDirs.foreach(localDir => Utils.registerShutdownDeleteDir(localDir))
- Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") {
- override def run() {
- logDebug("Shutdown hook called")
- localDirs.foreach { localDir =>
- try {
- if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir)
- } catch {
- case t: Throwable =>
- logError("Exception while deleting local spark dir: " + localDir, t)
- }
- }
- if (shuffleSender != null) {
- shuffleSender.stop()
- }
- }
- })
- }
-
- private[storage] def startShuffleBlockSender(port: Int): Int = {
- val pResolver = new PathResolver {
- override def getAbsolutePath(blockIdString: String): String = {
- val blockId = BlockId(blockIdString)
- if (!blockId.isShuffle) null
- else DiskStore.this.getFile(blockId).getAbsolutePath
- }
- }
- shuffleSender = new ShuffleSender(port, pResolver)
- logInfo("Created ShuffleSender binding to port : "+ shuffleSender.port)
- shuffleSender.port
+ val file = diskManager.getBlockLocation(blockId).file
+ file.exists()
}
}
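The rewritten putBytes above no longer goes through RandomAccessFile: it asks DiskBlockManager for a fresh block file and then drains the caller's ByteBuffer into a FileOutputStream channel. Below is a minimal sketch of that write loop; the object and method names are invented for illustration and are not part of this patch. The loop is needed because FileChannel.write may consume fewer bytes than remain in the buffer.

    import java.io.{File, FileOutputStream}
    import java.nio.ByteBuffer

    object WriteFullySketch {
      // Drain a ByteBuffer into a file, mirroring the loop in putBytes above.
      // Duplicating the buffer leaves the caller's position and limit untouched.
      def writeFully(file: File, bytes: ByteBuffer): Unit = {
        val copy = bytes.duplicate()
        val channel = new FileOutputStream(file).getChannel()
        try {
          while (copy.remaining() > 0) {
            channel.write(copy)
          }
        } finally {
          channel.close()
        }
      }
    }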
diff --git a/core/src/main/scala/org/apache/spark/storage/FileSegment.scala b/core/src/main/scala/org/apache/spark/storage/FileSegment.scala
new file mode 100644
index 0000000000..555486830a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/FileSegment.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.io.File
+
+/**
+ * References a particular segment of a file (potentially the entire file),
+ * based on an offset and a length.
+ */
+private[spark] class FileSegment(val file: File, val offset: Long, val length: Long) {
+ override def toString = "(name=%s, offset=%d, length=%d)".format(file.getName, offset, length)
+}
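A FileSegment is just a (file, offset, length) triple; its value is that readers such as DiskStore.getBytes can memory-map only the region a block occupies, even when several blocks share one physical file. A hedged sketch of that read path follows, with the object and method names invented for illustration:

    import java.io.{File, RandomAccessFile}
    import java.nio.ByteBuffer
    import java.nio.channels.FileChannel.MapMode

    object SegmentReadSketch {
      // Map only the [offset, offset + length) region of the file, as getBytes
      // does with the FileSegment returned by DiskBlockManager.getBlockLocation.
      def mapSegment(file: File, offset: Long, length: Long): ByteBuffer = {
        val channel = new RandomAccessFile(file, "r").getChannel()
        try {
          channel.map(MapMode.READ_ONLY, offset, length)
        } finally {
          channel.close()
        }
      }
    }

Closing the channel right after mapping is safe: the mapping stays valid until the buffer itself is garbage collected, which is the same pattern the DiskStore code above relies on.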
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
index f39fcd87fb..229178c095 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
@@ -17,12 +17,13 @@
package org.apache.spark.storage
-import org.apache.spark.serializer.Serializer
+import java.util.concurrent.ConcurrentLinkedQueue
+import java.util.concurrent.atomic.AtomicInteger
+import org.apache.spark.serializer.Serializer
private[spark]
-class ShuffleWriterGroup(val id: Int, val writers: Array[BlockObjectWriter])
-
+class ShuffleWriterGroup(val id: Int, val fileId: Int, val writers: Array[BlockObjectWriter])
private[spark]
trait ShuffleBlocks {
@@ -30,24 +31,61 @@ trait ShuffleBlocks {
def releaseWriters(group: ShuffleWriterGroup)
}
+/**
+ * Manages assigning disk-based block writers to shuffle tasks. Each shuffle task gets one writer
+ * per reducer.
+ *
+ * As an optimization to reduce the number of physical shuffle files produced, multiple shuffle
+ * blocks are aggregated into the same file. There is one "combined shuffle file" per reducer
+ * per concurrently executing shuffle task. As soon as a task finishes writing to its shuffle files,
+ * it releases them for another task.
+ * Regarding the implementation of this feature, shuffle files are identified by a 3-tuple:
+ * - shuffleId: The unique id given to the entire shuffle stage.
+ * - bucketId: The id of the output partition (i.e., reducer id)
+ * - fileId: The unique id identifying a group of "combined shuffle files." Only one task at a
+ * time owns a particular fileId, and this id is returned to a pool when the task finishes.
+ */
private[spark]
class ShuffleBlockManager(blockManager: BlockManager) {
+  // Turning off shuffle file consolidation causes each shuffle block to get its own file.
+ // TODO: Remove this once the shuffle file consolidation feature is stable.
+ val consolidateShuffleFiles =
+ System.getProperty("spark.shuffle.consolidateFiles", "true").toBoolean
+
+ var nextFileId = new AtomicInteger(0)
+ val unusedFileIds = new ConcurrentLinkedQueue[java.lang.Integer]()
- def forShuffle(shuffleId: Int, numBuckets: Int, serializer: Serializer): ShuffleBlocks = {
+ def forShuffle(shuffleId: Int, numBuckets: Int, serializer: Serializer) = {
new ShuffleBlocks {
// Get a group of writers for a map task.
override def acquireWriters(mapId: Int): ShuffleWriterGroup = {
val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024
+ val fileId = getUnusedFileId()
val writers = Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
- blockManager.getDiskBlockWriter(blockId, serializer, bufferSize)
+ val filename = physicalFileName(shuffleId, bucketId, fileId)
+ blockManager.getDiskWriter(blockId, filename, serializer, bufferSize)
}
- new ShuffleWriterGroup(mapId, writers)
+ new ShuffleWriterGroup(mapId, fileId, writers)
}
- override def releaseWriters(group: ShuffleWriterGroup) = {
- // Nothing really to release here.
+ override def releaseWriters(group: ShuffleWriterGroup) {
+ recycleFileId(group.fileId)
}
}
}
+
+ private def getUnusedFileId(): Int = {
+ val fileId = unusedFileIds.poll()
+ if (fileId == null) nextFileId.getAndIncrement() else fileId
+ }
+
+ private def recycleFileId(fileId: Int) {
+ if (!consolidateShuffleFiles) { return } // ensures we always generate new file id
+ unusedFileIds.add(fileId)
+ }
+
+ private def physicalFileName(shuffleId: Int, bucketId: Int, fileId: Int) = {
+ "merged_shuffle_%d_%d_%d".format(shuffleId, bucketId, fileId)
+ }
}
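The consolidation scheme described in the scaladoc above stands or falls with the fileId pool: ids are handed out monotonically and recycled when a task releases its writers, so the number of "merged_shuffle_*" files per bucket tracks the number of concurrently running shuffle tasks rather than the total number of map tasks. The following standalone sketch shows that pool behaviour; the object and method names are illustrative only, not part of the patch.

    import java.util.concurrent.ConcurrentLinkedQueue
    import java.util.concurrent.atomic.AtomicInteger

    object FileIdPoolSketch {
      private val nextFileId = new AtomicInteger(0)
      private val unusedFileIds = new ConcurrentLinkedQueue[java.lang.Integer]()

      // Reuse a released id if one is available, otherwise mint a new one.
      def acquire(): Int = {
        val recycled = unusedFileIds.poll()
        if (recycled == null) nextFileId.getAndIncrement() else recycled
      }

      // Return an id to the pool once the owning task has finished writing.
      def release(fileId: Int): Unit = unusedFileIds.add(fileId)

      def main(args: Array[String]): Unit = {
        val a = acquire()          // first running task  -> id 0
        val b = acquire()          // second running task -> id 1
        release(a)                 // first task finishes
        val c = acquire()          // next task reuses id 0 instead of minting 2
        println(s"a=$a b=$b c=$c") // prints: a=0 b=1 c=0
      }
    }

Note that recycleFileId in the patch returns early when consolidation is disabled, so in that mode every task mints a fresh id and therefore writes its own set of files.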
diff --git a/core/src/main/scala/spark/storage/StoragePerfTester.scala b/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala
index 5f30383fd0..7dcadc3805 100644
--- a/core/src/main/scala/spark/storage/StoragePerfTester.scala
+++ b/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala
@@ -7,9 +7,11 @@ import org.apache.spark.serializer.KryoSerializer
import org.apache.spark.SparkContext
import org.apache.spark.util.Utils
-/** Utility for micro-benchmarking shuffle write performance.
- *
- * Writes simulated shuffle output from several threads and records the observed throughput*/
+/**
+ * Utility for micro-benchmarking shuffle write performance.
+ *
+ * Writes simulated shuffle output from several threads and records the observed throughput.
+ */
object StoragePerfTester {
def main(args: Array[String]) = {
/** Total amount of data to generate. Distributed evenly amongst maps and reduce splits. */
@@ -44,7 +46,7 @@ object StoragePerfTester {
}
buckets.writers.map {w =>
w.commit()
- total.addAndGet(w.size())
+ total.addAndGet(w.fileSegment().length)
w.close()
}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala
index b39c0e9769..ca5a28625b 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala
@@ -38,7 +38,7 @@ private[spark] class IndexPage(parent: JobProgressUI) {
val now = System.currentTimeMillis()
var activeTime = 0L
- for (tasks <- listener.stageToTasksActive.values; t <- tasks) {
+ for (tasks <- listener.stageIdToTasksActive.values; t <- tasks) {
activeTime += t.timeRunning(now)
}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
index eb3b4e8522..9bb8a13ec4 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
@@ -36,52 +36,52 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList
val RETAINED_STAGES = System.getProperty("spark.ui.retained_stages", "1000").toInt
val DEFAULT_POOL_NAME = "default"
- val stageToPool = new HashMap[Stage, String]()
- val stageToDescription = new HashMap[Stage, String]()
- val poolToActiveStages = new HashMap[String, HashSet[Stage]]()
+ val stageIdToPool = new HashMap[Int, String]()
+ val stageIdToDescription = new HashMap[Int, String]()
+ val poolToActiveStages = new HashMap[String, HashSet[StageInfo]]()
- val activeStages = HashSet[Stage]()
- val completedStages = ListBuffer[Stage]()
- val failedStages = ListBuffer[Stage]()
+ val activeStages = HashSet[StageInfo]()
+ val completedStages = ListBuffer[StageInfo]()
+ val failedStages = ListBuffer[StageInfo]()
// Total metrics reflect metrics only for completed tasks
var totalTime = 0L
var totalShuffleRead = 0L
var totalShuffleWrite = 0L
- val stageToTime = HashMap[Int, Long]()
- val stageToShuffleRead = HashMap[Int, Long]()
- val stageToShuffleWrite = HashMap[Int, Long]()
- val stageToTasksActive = HashMap[Int, HashSet[TaskInfo]]()
- val stageToTasksComplete = HashMap[Int, Int]()
- val stageToTasksFailed = HashMap[Int, Int]()
- val stageToTaskInfos =
+ val stageIdToTime = HashMap[Int, Long]()
+ val stageIdToShuffleRead = HashMap[Int, Long]()
+ val stageIdToShuffleWrite = HashMap[Int, Long]()
+ val stageIdToTasksActive = HashMap[Int, HashSet[TaskInfo]]()
+ val stageIdToTasksComplete = HashMap[Int, Int]()
+ val stageIdToTasksFailed = HashMap[Int, Int]()
+ val stageIdToTaskInfos =
HashMap[Int, HashSet[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]]()
override def onJobStart(jobStart: SparkListenerJobStart) {}
override def onStageCompleted(stageCompleted: StageCompleted) = synchronized {
- val stage = stageCompleted.stageInfo.stage
- poolToActiveStages(stageToPool(stage)) -= stage
+ val stage = stageCompleted.stage
+ poolToActiveStages(stageIdToPool(stage.stageId)) -= stage
activeStages -= stage
completedStages += stage
trimIfNecessary(completedStages)
}
/** If stages is too large, remove and garbage collect old stages */
- def trimIfNecessary(stages: ListBuffer[Stage]) = synchronized {
+ def trimIfNecessary(stages: ListBuffer[StageInfo]) = synchronized {
if (stages.size > RETAINED_STAGES) {
val toRemove = RETAINED_STAGES / 10
stages.takeRight(toRemove).foreach( s => {
- stageToTaskInfos.remove(s.id)
- stageToTime.remove(s.id)
- stageToShuffleRead.remove(s.id)
- stageToShuffleWrite.remove(s.id)
- stageToTasksActive.remove(s.id)
- stageToTasksComplete.remove(s.id)
- stageToTasksFailed.remove(s.id)
- stageToPool.remove(s)
- if (stageToDescription.contains(s)) {stageToDescription.remove(s)}
+ stageIdToTaskInfos.remove(s.stageId)
+ stageIdToTime.remove(s.stageId)
+ stageIdToShuffleRead.remove(s.stageId)
+ stageIdToShuffleWrite.remove(s.stageId)
+ stageIdToTasksActive.remove(s.stageId)
+ stageIdToTasksComplete.remove(s.stageId)
+ stageIdToTasksFailed.remove(s.stageId)
+ stageIdToPool.remove(s.stageId)
+ if (stageIdToDescription.contains(s.stageId)) {stageIdToDescription.remove(s.stageId)}
})
stages.trimEnd(toRemove)
}
@@ -95,63 +95,63 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList
val poolName = Option(stageSubmitted.properties).map {
p => p.getProperty("spark.scheduler.pool", DEFAULT_POOL_NAME)
}.getOrElse(DEFAULT_POOL_NAME)
- stageToPool(stage) = poolName
+ stageIdToPool(stage.stageId) = poolName
val description = Option(stageSubmitted.properties).flatMap {
p => Option(p.getProperty(SparkContext.SPARK_JOB_DESCRIPTION))
}
- description.map(d => stageToDescription(stage) = d)
+ description.map(d => stageIdToDescription(stage.stageId) = d)
- val stages = poolToActiveStages.getOrElseUpdate(poolName, new HashSet[Stage]())
+ val stages = poolToActiveStages.getOrElseUpdate(poolName, new HashSet[StageInfo]())
stages += stage
}
override def onTaskStart(taskStart: SparkListenerTaskStart) = synchronized {
val sid = taskStart.task.stageId
- val tasksActive = stageToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]())
+ val tasksActive = stageIdToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]())
tasksActive += taskStart.taskInfo
- val taskList = stageToTaskInfos.getOrElse(
+ val taskList = stageIdToTaskInfos.getOrElse(
sid, HashSet[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]())
taskList += ((taskStart.taskInfo, None, None))
- stageToTaskInfos(sid) = taskList
+ stageIdToTaskInfos(sid) = taskList
}
override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized {
val sid = taskEnd.task.stageId
- val tasksActive = stageToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]())
+ val tasksActive = stageIdToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]())
tasksActive -= taskEnd.taskInfo
val (failureInfo, metrics): (Option[ExceptionFailure], Option[TaskMetrics]) =
taskEnd.reason match {
case e: ExceptionFailure =>
- stageToTasksFailed(sid) = stageToTasksFailed.getOrElse(sid, 0) + 1
+ stageIdToTasksFailed(sid) = stageIdToTasksFailed.getOrElse(sid, 0) + 1
(Some(e), e.metrics)
case _ =>
- stageToTasksComplete(sid) = stageToTasksComplete.getOrElse(sid, 0) + 1
+ stageIdToTasksComplete(sid) = stageIdToTasksComplete.getOrElse(sid, 0) + 1
(None, Option(taskEnd.taskMetrics))
}
- stageToTime.getOrElseUpdate(sid, 0L)
+ stageIdToTime.getOrElseUpdate(sid, 0L)
val time = metrics.map(m => m.executorRunTime).getOrElse(0)
- stageToTime(sid) += time
+ stageIdToTime(sid) += time
totalTime += time
- stageToShuffleRead.getOrElseUpdate(sid, 0L)
+ stageIdToShuffleRead.getOrElseUpdate(sid, 0L)
val shuffleRead = metrics.flatMap(m => m.shuffleReadMetrics).map(s =>
s.remoteBytesRead).getOrElse(0L)
- stageToShuffleRead(sid) += shuffleRead
+ stageIdToShuffleRead(sid) += shuffleRead
totalShuffleRead += shuffleRead
- stageToShuffleWrite.getOrElseUpdate(sid, 0L)
+ stageIdToShuffleWrite.getOrElseUpdate(sid, 0L)
val shuffleWrite = metrics.flatMap(m => m.shuffleWriteMetrics).map(s =>
s.shuffleBytesWritten).getOrElse(0L)
- stageToShuffleWrite(sid) += shuffleWrite
+ stageIdToShuffleWrite(sid) += shuffleWrite
totalShuffleWrite += shuffleWrite
- val taskList = stageToTaskInfos.getOrElse(
+ val taskList = stageIdToTaskInfos.getOrElse(
sid, HashSet[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]())
taskList -= ((taskEnd.taskInfo, None, None))
taskList += ((taskEnd.taskInfo, metrics, failureInfo))
- stageToTaskInfos(sid) = taskList
+ stageIdToTaskInfos(sid) = taskList
}
override def onJobEnd(jobEnd: SparkListenerJobEnd) = synchronized {
@@ -159,10 +159,15 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList
case end: SparkListenerJobEnd =>
end.jobResult match {
case JobFailed(ex, Some(stage)) =>
- activeStages -= stage
- poolToActiveStages(stageToPool(stage)) -= stage
- failedStages += stage
- trimIfNecessary(failedStages)
+ /* If two jobs share a stage we could get this failure message twice. So we first
+ * check whether we've already retired this stage. */
+ val stageInfo = activeStages.filter(s => s.stageId == stage.id).headOption
+ stageInfo.foreach {s =>
+ activeStages -= s
+ poolToActiveStages(stageIdToPool(stage.id)) -= s
+ failedStages += s
+ trimIfNecessary(failedStages)
+ }
case _ =>
}
case _ =>
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala
index 06810d8dbc..cfeeccda41 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala
@@ -21,13 +21,13 @@ import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
import scala.xml.Node
-import org.apache.spark.scheduler.{Schedulable, Stage}
+import org.apache.spark.scheduler.{Schedulable, StageInfo}
import org.apache.spark.ui.UIUtils
/** Table showing list of pools */
private[spark] class PoolTable(pools: Seq[Schedulable], listener: JobProgressListener) {
- var poolToActiveStages: HashMap[String, HashSet[Stage]] = listener.poolToActiveStages
+ var poolToActiveStages: HashMap[String, HashSet[StageInfo]] = listener.poolToActiveStages
def toNodeSeq(): Seq[Node] = {
listener.synchronized {
@@ -35,7 +35,7 @@ private[spark] class PoolTable(pools: Seq[Schedulable], listener: JobProgressLis
}
}
- private def poolTable(makeRow: (Schedulable, HashMap[String, HashSet[Stage]]) => Seq[Node],
+ private def poolTable(makeRow: (Schedulable, HashMap[String, HashSet[StageInfo]]) => Seq[Node],
rows: Seq[Schedulable]
): Seq[Node] = {
<table class="table table-bordered table-striped table-condensed sortable table-fixed">
@@ -53,7 +53,7 @@ private[spark] class PoolTable(pools: Seq[Schedulable], listener: JobProgressLis
</table>
}
- private def poolRow(p: Schedulable, poolToActiveStages: HashMap[String, HashSet[Stage]])
+ private def poolRow(p: Schedulable, poolToActiveStages: HashMap[String, HashSet[StageInfo]])
: Seq[Node] = {
val activeStages = poolToActiveStages.get(p.name) match {
case Some(stages) => stages.size
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
index b7c81d091c..0dd57235e3 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
@@ -40,7 +40,7 @@ private[spark] class StagePage(parent: JobProgressUI) {
val stageId = request.getParameter("id").toInt
val now = System.currentTimeMillis()
- if (!listener.stageToTaskInfos.contains(stageId)) {
+ if (!listener.stageIdToTaskInfos.contains(stageId)) {
val content =
<div>
<h4>Summary Metrics</h4> No tasks have started yet
@@ -49,23 +49,23 @@ private[spark] class StagePage(parent: JobProgressUI) {
return headerSparkPage(content, parent.sc, "Details for Stage %s".format(stageId), Stages)
}
- val tasks = listener.stageToTaskInfos(stageId).toSeq.sortBy(_._1.launchTime)
+ val tasks = listener.stageIdToTaskInfos(stageId).toSeq.sortBy(_._1.launchTime)
val numCompleted = tasks.count(_._1.finished)
- val shuffleReadBytes = listener.stageToShuffleRead.getOrElse(stageId, 0L)
+ val shuffleReadBytes = listener.stageIdToShuffleRead.getOrElse(stageId, 0L)
val hasShuffleRead = shuffleReadBytes > 0
- val shuffleWriteBytes = listener.stageToShuffleWrite.getOrElse(stageId, 0L)
+ val shuffleWriteBytes = listener.stageIdToShuffleWrite.getOrElse(stageId, 0L)
val hasShuffleWrite = shuffleWriteBytes > 0
var activeTime = 0L
- listener.stageToTasksActive(stageId).foreach(activeTime += _.timeRunning(now))
+ listener.stageIdToTasksActive(stageId).foreach(activeTime += _.timeRunning(now))
val summary =
<div>
<ul class="unstyled">
<li>
<strong>CPU time: </strong>
- {parent.formatDuration(listener.stageToTime.getOrElse(stageId, 0L) + activeTime)}
+ {parent.formatDuration(listener.stageIdToTime.getOrElse(stageId, 0L) + activeTime)}
</li>
{if (hasShuffleRead)
<li>
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
index 07db8622da..2c29ea7d44 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
@@ -22,13 +22,13 @@ import java.util.Date
import scala.xml.Node
import scala.collection.mutable.HashSet
-import org.apache.spark.scheduler.{SchedulingMode, Stage, TaskInfo}
+import org.apache.spark.scheduler.{SchedulingMode, StageInfo, TaskInfo}
import org.apache.spark.ui.UIUtils
import org.apache.spark.util.Utils
/** Page showing list of all ongoing and recently finished stages */
-private[spark] class StageTable(val stages: Seq[Stage], val parent: JobProgressUI) {
+private[spark] class StageTable(val stages: Seq[StageInfo], val parent: JobProgressUI) {
val listener = parent.listener
val dateFmt = parent.dateFmt
@@ -73,40 +73,40 @@ private[spark] class StageTable(val stages: Seq[Stage], val parent: JobProgressU
}
- private def stageRow(s: Stage): Seq[Node] = {
+ private def stageRow(s: StageInfo): Seq[Node] = {
val submissionTime = s.submissionTime match {
case Some(t) => dateFmt.format(new Date(t))
case None => "Unknown"
}
- val shuffleRead = listener.stageToShuffleRead.getOrElse(s.id, 0L) match {
+ val shuffleRead = listener.stageIdToShuffleRead.getOrElse(s.stageId, 0L) match {
case 0 => ""
case b => Utils.bytesToString(b)
}
- val shuffleWrite = listener.stageToShuffleWrite.getOrElse(s.id, 0L) match {
+ val shuffleWrite = listener.stageIdToShuffleWrite.getOrElse(s.stageId, 0L) match {
case 0 => ""
case b => Utils.bytesToString(b)
}
- val startedTasks = listener.stageToTasksActive.getOrElse(s.id, HashSet[TaskInfo]()).size
- val completedTasks = listener.stageToTasksComplete.getOrElse(s.id, 0)
- val failedTasks = listener.stageToTasksFailed.getOrElse(s.id, 0) match {
+ val startedTasks = listener.stageIdToTasksActive.getOrElse(s.stageId, HashSet[TaskInfo]()).size
+ val completedTasks = listener.stageIdToTasksComplete.getOrElse(s.stageId, 0)
+ val failedTasks = listener.stageIdToTasksFailed.getOrElse(s.stageId, 0) match {
case f if f > 0 => "(%s failed)".format(f)
case _ => ""
}
- val totalTasks = s.numPartitions
+ val totalTasks = s.numTasks
- val poolName = listener.stageToPool.get(s)
+ val poolName = listener.stageIdToPool.get(s.stageId)
val nameLink =
- <a href={"%s/stages/stage?id=%s".format(UIUtils.prependBaseUri(),s.id)}>{s.name}</a>
- val description = listener.stageToDescription.get(s)
+ <a href={"%s/stages/stage?id=%s".format(UIUtils.prependBaseUri(),s.stageId)}>{s.toString}</a>
+ val description = listener.stageIdToDescription.get(s.stageId)
.map(d => <div><em>{d}</em></div><div>{nameLink}</div>).getOrElse(nameLink)
val finishTime = s.completionTime.getOrElse(System.currentTimeMillis())
val duration = s.submissionTime.map(t => finishTime - t)
<tr>
- <td>{s.id}</td>
+ <td>{s.stageId}</td>
{if (isFairScheduler) {
<td><a href={"%s/stages/pool?poolname=%s".format(UIUtils.prependBaseUri(),poolName.get)}>
{poolName.get}</a></td>}
diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
index 0ce1394c77..3f963727d9 100644
--- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
@@ -56,9 +56,10 @@ class MetadataCleaner(cleanerType: MetadataCleanerType.MetadataCleanerType, clea
}
object MetadataCleanerType extends Enumeration("MapOutputTracker", "SparkContext", "HttpBroadcast", "DagScheduler", "ResultTask",
- "ShuffleMapTask", "BlockManager", "BroadcastVars") {
+ "ShuffleMapTask", "BlockManager", "DiskBlockManager", "BroadcastVars") {
- val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, DAG_SCHEDULER, RESULT_TASK, SHUFFLE_MAP_TASK, BLOCK_MANAGER, BROADCAST_VARS = Value
+ val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, DAG_SCHEDULER, RESULT_TASK,
+ SHUFFLE_MAP_TASK, BLOCK_MANAGER, DISK_BLOCK_MANAGER, BROADCAST_VARS = Value
type MetadataCleanerType = Value
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
index 6013320eaa..b7eb268bd5 100644
--- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -48,15 +48,15 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
test("master start and stop") {
val actorSystem = ActorSystem("test")
- val tracker = new MapOutputTracker()
- tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker)))
+ val tracker = new MapOutputTrackerMaster()
+ tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker)))
tracker.stop()
}
test("master register and fetch") {
val actorSystem = ActorSystem("test")
- val tracker = new MapOutputTracker()
- tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker)))
+ val tracker = new MapOutputTrackerMaster()
+ tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker)))
tracker.registerShuffle(10, 2)
val compressedSize1000 = MapOutputTracker.compressSize(1000L)
val compressedSize10000 = MapOutputTracker.compressSize(10000L)
@@ -74,19 +74,17 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
test("master register and unregister and fetch") {
val actorSystem = ActorSystem("test")
- val tracker = new MapOutputTracker()
- tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker)))
+ val tracker = new MapOutputTrackerMaster()
+ tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker)))
tracker.registerShuffle(10, 2)
val compressedSize1000 = MapOutputTracker.compressSize(1000L)
val compressedSize10000 = MapOutputTracker.compressSize(10000L)
- val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
- val size10000 = MapOutputTracker.decompressSize(compressedSize10000)
tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000, 0),
Array(compressedSize1000, compressedSize1000, compressedSize1000)))
tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000, 0),
Array(compressedSize10000, compressedSize1000, compressedSize1000)))
- // As if we had two simulatenous fetch failures
+ // As if we had two simultaneous fetch failures
tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0))
tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0))
@@ -102,9 +100,9 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext
System.setProperty("spark.hostPort", hostname + ":" + boundPort)
- val masterTracker = new MapOutputTracker()
+ val masterTracker = new MapOutputTrackerMaster()
masterTracker.trackerActor = actorSystem.actorOf(
- Props(new MapOutputTrackerActor(masterTracker)), "MapOutputTracker")
+ Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker")
val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0)
val slaveTracker = new MapOutputTracker()
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 2a2f828be6..00f2fdd657 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -23,7 +23,7 @@ import org.scalatest.FunSuite
import org.scalatest.BeforeAndAfter
import org.apache.spark.LocalSparkContext
-import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTrackerMaster
import org.apache.spark.SparkContext
import org.apache.spark.Partition
import org.apache.spark.TaskContext
@@ -64,7 +64,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
override def defaultParallelism() = 2
}
- var mapOutputTracker: MapOutputTracker = null
+ var mapOutputTracker: MapOutputTrackerMaster = null
var scheduler: DAGScheduler = null
/**
@@ -99,7 +99,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
taskSets.clear()
cacheLocations.clear()
results.clear()
- mapOutputTracker = new MapOutputTracker()
+ mapOutputTracker = new MapOutputTrackerMaster()
scheduler = new DAGScheduler(taskScheduler, mapOutputTracker, blockManagerMaster, null) {
override def runLocally(job: ActiveJob) {
// don't bother with the thread while unit testing
diff --git a/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala
index cece60dda7..8406093246 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala
@@ -58,10 +58,13 @@ class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers
val parentRdd = makeRdd(4, Nil)
val shuffleDep = new ShuffleDependency(parentRdd, null)
val rootRdd = makeRdd(4, List(shuffleDep))
- val shuffleMapStage = new Stage(1, parentRdd, Some(shuffleDep), Nil, jobID, None)
- val rootStage = new Stage(0, rootRdd, None, List(shuffleMapStage), jobID, None)
-
- joblogger.onStageSubmitted(SparkListenerStageSubmitted(rootStage, 4, null))
+ val shuffleMapStage =
+ new Stage(1, parentRdd, parentRdd.partitions.size, Some(shuffleDep), Nil, jobID, None)
+ val rootStage =
+ new Stage(0, rootRdd, rootRdd.partitions.size, None, List(shuffleMapStage), jobID, None)
+ val rootStageInfo = new StageInfo(rootStage)
+
+ joblogger.onStageSubmitted(SparkListenerStageSubmitted(rootStageInfo, null))
joblogger.getRddNameTest(parentRdd) should be (parentRdd.getClass.getName)
parentRdd.setName("MyRDD")
joblogger.getRddNameTest(parentRdd) should be ("MyRDD")
diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
index a549417a47..42ca988f7a 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
@@ -17,16 +17,57 @@
package org.apache.spark.scheduler
-import org.scalatest.FunSuite
-import org.apache.spark.{SparkContext, LocalSparkContext}
+import org.scalatest.{BeforeAndAfter, FunSuite}
+import org.apache.spark.{LocalSparkContext, SparkContext}
import scala.collection.mutable
import org.scalatest.matchers.ShouldMatchers
import org.apache.spark.SparkContext._
-class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatchers {
+class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatchers
+ with BeforeAndAfter {
+ /** Length of time to wait while draining listener events. */
+ val WAIT_TIMEOUT_MILLIS = 10000
+
+ before {
+    sc = new SparkContext("local", "SparkListenerSuite")
+ }
+
+ test("basic creation of StageInfo") {
+ val listener = new SaveStageInfo
+ sc.addSparkListener(listener)
+ val rdd1 = sc.parallelize(1 to 100, 4)
+ val rdd2 = rdd1.map(x => x.toString)
+ rdd2.setName("Target RDD")
+ rdd2.count
+
+ assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
+
+ listener.stageInfos.size should be {1}
+ val first = listener.stageInfos.head
+ first.rddName should be {"Target RDD"}
+ first.numTasks should be {4}
+ first.numPartitions should be {4}
+ first.submissionTime should be ('defined)
+ first.completionTime should be ('defined)
+ first.taskInfos.length should be {4}
+ }
+
+ test("StageInfo with fewer tasks than partitions") {
+ val listener = new SaveStageInfo
+ sc.addSparkListener(listener)
+ val rdd1 = sc.parallelize(1 to 100, 4)
+ val rdd2 = rdd1.map(x => x.toString)
+ sc.runJob(rdd2, (items: Iterator[String]) => items.size, Seq(0, 1), true)
+
+ assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
+
+ listener.stageInfos.size should be {1}
+ val first = listener.stageInfos.head
+ first.numTasks should be {2}
+ first.numPartitions should be {4}
+ }
test("local metrics") {
- sc = new SparkContext("local[4]", "test")
val listener = new SaveStageInfo
sc.addSparkListener(listener)
sc.addSparkListener(new StatsReportListener)
@@ -39,7 +80,6 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
val d = sc.parallelize(1 to 1e4.toInt, 64).map{i => w(i)}
d.count()
- val WAIT_TIMEOUT_MILLIS = 10000
assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
listener.stageInfos.size should be (1)
@@ -64,7 +104,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
checkNonZeroAvg(
stageInfo.taskInfos.map{_._2.executorDeserializeTime.toLong},
stageInfo + " executorDeserializeTime")
- if (stageInfo.stage.rdd.name == d4.name) {
+ if (stageInfo.rddName == d4.name) {
checkNonZeroAvg(
stageInfo.taskInfos.map{_._2.shuffleReadMetrics.get.fetchWaitTime},
stageInfo + " fetchWaitTime")
@@ -72,11 +112,11 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
stageInfo.taskInfos.foreach { case (taskInfo, taskMetrics) =>
taskMetrics.resultSize should be > (0l)
- if (isStage(stageInfo, Set(d2.name, d3.name), Set(d4.name))) {
+ if (stageInfo.rddName == d2.name || stageInfo.rddName == d3.name) {
taskMetrics.shuffleWriteMetrics should be ('defined)
taskMetrics.shuffleWriteMetrics.get.shuffleBytesWritten should be > (0l)
}
- if (stageInfo.stage.rdd.name == d4.name) {
+ if (stageInfo.rddName == d4.name) {
taskMetrics.shuffleReadMetrics should be ('defined)
val sm = taskMetrics.shuffleReadMetrics.get
sm.totalBlocksFetched should be > (0)
@@ -93,15 +133,10 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
assert(m.sum / m.size.toDouble > 0.0, msg)
}
- def isStage(stageInfo: StageInfo, rddNames: Set[String], excludedNames: Set[String]) = {
- val names = Set(stageInfo.stage.rdd.name) ++ stageInfo.stage.rdd.dependencies.map{_.rdd.name}
- !names.intersect(rddNames).isEmpty && names.intersect(excludedNames).isEmpty
- }
-
class SaveStageInfo extends SparkListener {
val stageInfos = mutable.Buffer[StageInfo]()
override def onStageCompleted(stage: StageCompleted) {
- stageInfos += stage.stageInfo
+ stageInfos += stage.stage
}
}