author    Mosharaf Chowdhury <mosharaf@cs.berkeley.edu>  2012-10-08 16:19:13 -0700
committer Mosharaf Chowdhury <mosharaf@cs.berkeley.edu>  2012-10-08 16:19:13 -0700
commit    edc67bfba8b7875bb751f1a8c84af7135a1d3d39 (patch)
tree      5525103ebabe637451409a755ae77d13a321f3a4 /core/src/main
parent    119e50c7b9e50a388648ca9434ee1ace5c22867c (diff)
parent    46c389983656c31b6503662547b65ef6a0ab7fac (diff)
Merge branch 'dev' into bc-fix-dev
Diffstat (limited to 'core/src/main')
-rw-r--r--  core/src/main/scala/spark/Accumulators.scala | 12
-rw-r--r--  core/src/main/scala/spark/BlockStoreShuffleFetcher.scala | 18
-rw-r--r--  core/src/main/scala/spark/CacheTracker.scala | 2
-rw-r--r--  core/src/main/scala/spark/Dependency.scala | 37
-rw-r--r--  core/src/main/scala/spark/HadoopWriter.scala | 2
-rw-r--r--  core/src/main/scala/spark/JavaSerializer.scala | 2
-rw-r--r--  core/src/main/scala/spark/Logging.scala | 24
-rw-r--r--  core/src/main/scala/spark/MapOutputTracker.scala | 150
-rw-r--r--  core/src/main/scala/spark/RDD.scala | 150
-rw-r--r--  core/src/main/scala/spark/Serializer.scala | 6
-rw-r--r--  core/src/main/scala/spark/SparkContext.scala | 148
-rw-r--r--  core/src/main/scala/spark/SparkEnv.scala | 14
-rw-r--r--  core/src/main/scala/spark/Utils.scala | 27
-rw-r--r--  core/src/main/scala/spark/api/java/JavaPairRDD.scala | 19
-rw-r--r--  core/src/main/scala/spark/deploy/client/Client.scala | 11
-rw-r--r--  core/src/main/scala/spark/network/ConnectionManager.scala | 11
-rw-r--r--  core/src/main/scala/spark/package.scala | 15
-rw-r--r--  core/src/main/scala/spark/partial/BoundedDouble.scala | 1
-rw-r--r--  core/src/main/scala/spark/partial/PartialResult.scala | 2
-rw-r--r--  core/src/main/scala/spark/rdd/BlockRDD.scala (renamed from core/src/main/scala/spark/BlockRDD.scala) | 10
-rw-r--r--  core/src/main/scala/spark/rdd/CartesianRDD.scala (renamed from core/src/main/scala/spark/CartesianRDD.scala) | 10
-rw-r--r--  core/src/main/scala/spark/rdd/CoGroupedRDD.scala (renamed from core/src/main/scala/spark/CoGroupedRDD.scala) | 15
-rw-r--r--  core/src/main/scala/spark/rdd/CoalescedRDD.scala (renamed from core/src/main/scala/spark/CoalescedRDD.scala) | 6
-rw-r--r--  core/src/main/scala/spark/rdd/DoubleRDDFunctions.scala (renamed from core/src/main/scala/spark/DoubleRDDFunctions.scala) | 5
-rw-r--r--  core/src/main/scala/spark/rdd/FilteredRDD.scala | 12
-rw-r--r--  core/src/main/scala/spark/rdd/FlatMappedRDD.scala | 16
-rw-r--r--  core/src/main/scala/spark/rdd/GlommedRDD.scala | 12
-rw-r--r--  core/src/main/scala/spark/rdd/HadoopRDD.scala (renamed from core/src/main/scala/spark/HadoopRDD.scala) | 8
-rw-r--r--  core/src/main/scala/spark/rdd/MapPartitionsRDD.scala | 16
-rw-r--r--  core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala | 21
-rw-r--r--  core/src/main/scala/spark/rdd/MappedRDD.scala | 16
-rw-r--r--  core/src/main/scala/spark/rdd/NewHadoopRDD.scala (renamed from core/src/main/scala/spark/NewHadoopRDD.scala) | 8
-rw-r--r--  core/src/main/scala/spark/rdd/PairRDDFunctions.scala (renamed from core/src/main/scala/spark/PairRDDFunctions.scala) | 15
-rw-r--r--  core/src/main/scala/spark/rdd/PipedRDD.scala (renamed from core/src/main/scala/spark/PipedRDD.scala) | 8
-rw-r--r--  core/src/main/scala/spark/rdd/SampledRDD.scala (renamed from core/src/main/scala/spark/SampledRDD.scala) | 6
-rw-r--r--  core/src/main/scala/spark/rdd/SequenceFileRDDFunctions.scala (renamed from core/src/main/scala/spark/SequenceFileRDDFunctions.scala) | 6
-rw-r--r--  core/src/main/scala/spark/rdd/ShuffledRDD.scala (renamed from core/src/main/scala/spark/ShuffledRDD.scala) | 15
-rw-r--r--  core/src/main/scala/spark/rdd/UnionRDD.scala (renamed from core/src/main/scala/spark/UnionRDD.scala) | 10
-rw-r--r--  core/src/main/scala/spark/scheduler/DAGScheduler.scala | 6
-rw-r--r--  core/src/main/scala/spark/scheduler/MapStatus.scala | 27
-rw-r--r--  core/src/main/scala/spark/scheduler/ShuffleMapTask.scala | 19
-rw-r--r--  core/src/main/scala/spark/scheduler/Stage.scala | 12
-rw-r--r--  core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala | 3
-rw-r--r--  core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala | 2
-rw-r--r--  core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala | 3
-rw-r--r--  core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala | 3
-rw-r--r--  core/src/main/scala/spark/storage/BlockManager.scala | 280
-rw-r--r--  core/src/main/scala/spark/storage/BlockManagerWorker.scala | 5
-rw-r--r--  core/src/main/scala/spark/storage/BlockStore.scala | 7
-rw-r--r--  core/src/main/scala/spark/storage/DiskStore.scala | 16
-rw-r--r--  core/src/main/scala/spark/storage/MemoryStore.scala | 21
-rw-r--r--  core/src/main/scala/spark/storage/PutResult.scala | 9
-rw-r--r--  core/src/main/scala/spark/storage/StorageLevel.scala | 4
-rw-r--r--  core/src/main/scala/spark/util/AkkaUtils.scala | 8
54 files changed, 874 insertions, 417 deletions
diff --git a/core/src/main/scala/spark/Accumulators.scala b/core/src/main/scala/spark/Accumulators.scala
index c157cc8feb..62186de80d 100644
--- a/core/src/main/scala/spark/Accumulators.scala
+++ b/core/src/main/scala/spark/Accumulators.scala
@@ -49,7 +49,16 @@ class Accumulable[R, T] (
else throw new UnsupportedOperationException("Can't read accumulator value in task")
}
- private[spark] def localValue = value_
+ /**
+ * Get the current value of this accumulator from within a task.
+ *
+ * This is NOT the global value of the accumulator. To get the global value after a
+ * completed operation on the dataset, call `value`.
+ *
+ * The typical use of this method is to directly mutate the local value, e.g., to add
+ * an element to a Set.
+ */
+ def localValue = value_
def value_= (r: R) {
if (!deserialized) value_ = r
@@ -93,6 +102,7 @@ trait AccumulableParam[R, T] extends Serializable {
def zero(initialValue: R): R
}
+private[spark]
class GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializable, T]
extends AccumulableParam[R,T] {
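The doc added above notes that `localValue` exposes the task-local value of an accumulator, while the global result is read through `value` on the master. A minimal usage sketch (not part of this patch), assuming the `SparkContext.accumulator` helper and the `local[N]` master URL documented later in this diff:

```scala
import spark.SparkContext
import spark.SparkContext._   // brings in the implicit AccumulatorParam[Int]

object AccumulatorSketch {
  def main(args: Array[String]) {
    val sc = new SparkContext("local[2]", "AccumulatorSketch")
    val multiplesOfSeven = sc.accumulator(0)
    sc.parallelize(1 to 100).foreach { i =>
      if (i % 7 == 0) multiplesOfSeven += 1   // tasks can only add to the accumulator
    }
    println("count = " + multiplesOfSeven.value)  // the global value is read on the master
    sc.stop()
  }
}
```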
diff --git a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
index fb65ba421a..4554db2249 100644
--- a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
+++ b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
@@ -17,18 +17,18 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
val blockManager = SparkEnv.get.blockManager
val startTime = System.currentTimeMillis
- val addresses = SparkEnv.get.mapOutputTracker.getServerAddresses(shuffleId)
+ val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId)
logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format(
shuffleId, reduceId, System.currentTimeMillis - startTime))
- val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[Int]]
- for ((address, index) <- addresses.zipWithIndex) {
- splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += index
+ val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(Int, Long)]]
+ for (((address, size), index) <- statuses.zipWithIndex) {
+ splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size))
}
- val blocksByAddress: Seq[(BlockManagerId, Seq[String])] = splitsByAddress.toSeq.map {
+ val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])] = splitsByAddress.toSeq.map {
case (address, splits) =>
- (address, splits.map(i => "shuffle_%d_%d_%d".format(shuffleId, i, reduceId)))
+ (address, splits.map(s => ("shuffle_%d_%d_%d".format(shuffleId, s._1, reduceId), s._2)))
}
for ((blockId, blockOption) <- blockManager.getMultiple(blocksByAddress)) {
@@ -43,9 +43,9 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
case None => {
val regex = "shuffle_([0-9]*)_([0-9]*)_([0-9]*)".r
blockId match {
- case regex(shufId, mapId, reduceId) =>
- val addr = addresses(mapId.toInt)
- throw new FetchFailedException(addr, shufId.toInt, mapId.toInt, reduceId.toInt, null)
+ case regex(shufId, mapId, _) =>
+ val address = statuses(mapId.toInt)._1
+ throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, null)
case _ =>
throw new SparkException(
"Failed to get block " + blockId + ", which is not a shuffle block")
diff --git a/core/src/main/scala/spark/CacheTracker.scala b/core/src/main/scala/spark/CacheTracker.scala
index 9a23f9e7cc..d9cbe3730a 100644
--- a/core/src/main/scala/spark/CacheTracker.scala
+++ b/core/src/main/scala/spark/CacheTracker.scala
@@ -157,7 +157,7 @@ private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, b
}
// For BlockManager.scala only
- def notifyTheCacheTrackerFromBlockManager(t: AddedToCache) {
+ def notifyFromBlockManager(t: AddedToCache) {
communicate(t)
}
diff --git a/core/src/main/scala/spark/Dependency.scala b/core/src/main/scala/spark/Dependency.scala
index c0ff94acc6..19a51dd5b8 100644
--- a/core/src/main/scala/spark/Dependency.scala
+++ b/core/src/main/scala/spark/Dependency.scala
@@ -1,22 +1,51 @@
package spark
-abstract class Dependency[T](val rdd: RDD[T], val isShuffle: Boolean) extends Serializable
+/**
+ * Base class for dependencies.
+ */
+abstract class Dependency[T](val rdd: RDD[T]) extends Serializable
-abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd, false) {
+/**
+ * Base class for dependencies where each partition of the parent RDD is used by at most one
+ * partition of the child RDD. Narrow dependencies allow for pipelined execution.
+ */
+abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) {
+ /**
+ * Get the parent partitions for a child partition.
+ * @param outputPartition a partition of the child RDD
+ * @return the partitions of the parent RDD that the child partition depends upon
+ */
def getParents(outputPartition: Int): Seq[Int]
}
+/**
+ * Represents a dependency on the output of a shuffle stage.
+ * @param shuffleId the shuffle id
+ * @param rdd the parent RDD
+ * @param aggregator optional aggregator; this allows for map-side combining
+ * @param partitioner partitioner used to partition the shuffle output
+ */
class ShuffleDependency[K, V, C](
val shuffleId: Int,
@transient rdd: RDD[(K, V)],
- val aggregator: Aggregator[K, V, C],
+ val aggregator: Option[Aggregator[K, V, C]],
val partitioner: Partitioner)
- extends Dependency(rdd, true)
+ extends Dependency(rdd)
+/**
+ * Represents a one-to-one dependency between partitions of the parent and child RDDs.
+ */
class OneToOneDependency[T](rdd: RDD[T]) extends NarrowDependency[T](rdd) {
override def getParents(partitionId: Int) = List(partitionId)
}
+/**
+ * Represents a one-to-one dependency between ranges of partitions in the parent and child RDDs.
+ * @param rdd the parent RDD
+ * @param inStart the start of the range in the parent RDD
+ * @param outStart the start of the range in the child RDD
+ * @param length the length of the range
+ */
class RangeDependency[T](rdd: RDD[T], inStart: Int, outStart: Int, length: Int)
extends NarrowDependency[T](rdd) {
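With the newly documented `NarrowDependency.getParents` contract, a custom dependency only has to say which parent partitions feed each child partition. A purely hypothetical subclass, for illustration, in which child partition `i` reads parent partitions `2*i` and `2*i + 1`:

```scala
import spark.{NarrowDependency, RDD}

// Hypothetical example dependency; not part of this patch.
class PairwiseDependency[T](rdd: RDD[T]) extends NarrowDependency[T](rdd) {
  override def getParents(outputPartition: Int): Seq[Int] =
    List(2 * outputPartition, 2 * outputPartition + 1)
}
```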
diff --git a/core/src/main/scala/spark/HadoopWriter.scala b/core/src/main/scala/spark/HadoopWriter.scala
index 12b6a0954c..ebb51607e6 100644
--- a/core/src/main/scala/spark/HadoopWriter.scala
+++ b/core/src/main/scala/spark/HadoopWriter.scala
@@ -42,7 +42,7 @@ class HadoopWriter(@transient jobConf: JobConf) extends Logging with Serializabl
setConfParams()
val jCtxt = getJobContext()
- getOutputCommitter().setupJob(jCtxt)
+ getOutputCommitter().setupJob(jCtxt)
}
diff --git a/core/src/main/scala/spark/JavaSerializer.scala b/core/src/main/scala/spark/JavaSerializer.scala
index 39d554b6a5..863d00eeb5 100644
--- a/core/src/main/scala/spark/JavaSerializer.scala
+++ b/core/src/main/scala/spark/JavaSerializer.scala
@@ -57,6 +57,6 @@ private[spark] class JavaSerializerInstance extends SerializerInstance {
}
}
-private[spark] class JavaSerializer extends Serializer {
+class JavaSerializer extends Serializer {
def newInstance(): SerializerInstance = new JavaSerializerInstance
}
diff --git a/core/src/main/scala/spark/Logging.scala b/core/src/main/scala/spark/Logging.scala
index 69935b86de..90bae26202 100644
--- a/core/src/main/scala/spark/Logging.scala
+++ b/core/src/main/scala/spark/Logging.scala
@@ -15,7 +15,7 @@ trait Logging {
private var log_ : Logger = null
// Method to get or create the logger for this object
- def log: Logger = {
+ protected def log: Logger = {
if (log_ == null) {
var className = this.getClass.getName
// Ignore trailing $'s in the class names for Scala objects
@@ -28,48 +28,48 @@ trait Logging {
}
// Log methods that take only a String
- def logInfo(msg: => String) {
+ protected def logInfo(msg: => String) {
if (log.isInfoEnabled) log.info(msg)
}
- def logDebug(msg: => String) {
+ protected def logDebug(msg: => String) {
if (log.isDebugEnabled) log.debug(msg)
}
- def logTrace(msg: => String) {
+ protected def logTrace(msg: => String) {
if (log.isTraceEnabled) log.trace(msg)
}
- def logWarning(msg: => String) {
+ protected def logWarning(msg: => String) {
if (log.isWarnEnabled) log.warn(msg)
}
- def logError(msg: => String) {
+ protected def logError(msg: => String) {
if (log.isErrorEnabled) log.error(msg)
}
// Log methods that take Throwables (Exceptions/Errors) too
- def logInfo(msg: => String, throwable: Throwable) {
+ protected def logInfo(msg: => String, throwable: Throwable) {
if (log.isInfoEnabled) log.info(msg, throwable)
}
- def logDebug(msg: => String, throwable: Throwable) {
+ protected def logDebug(msg: => String, throwable: Throwable) {
if (log.isDebugEnabled) log.debug(msg, throwable)
}
- def logTrace(msg: => String, throwable: Throwable) {
+ protected def logTrace(msg: => String, throwable: Throwable) {
if (log.isTraceEnabled) log.trace(msg, throwable)
}
- def logWarning(msg: => String, throwable: Throwable) {
+ protected def logWarning(msg: => String, throwable: Throwable) {
if (log.isWarnEnabled) log.warn(msg, throwable)
}
- def logError(msg: => String, throwable: Throwable) {
+ protected def logError(msg: => String, throwable: Throwable) {
if (log.isErrorEnabled) log.error(msg, throwable)
}
// Method for ensuring that logging is initialized, to avoid having multiple
// threads do it concurrently (as SLF4J initialization is not thread safe).
- def initLogging() { log }
+ protected def initLogging() { log }
}
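Since the logging helpers are now `protected`, they are meant to be called from inside classes that mix in the trait rather than from outside callers. A small sketch of such a class; the messages are by-name parameters, so the strings are only built when the corresponding log level is enabled:

```scala
import spark.Logging

// Hypothetical class for illustration, not part of this patch.
class BlockProcessor extends Logging {
  def process(blockId: String) {
    logDebug("Processing block " + blockId)   // only concatenated if DEBUG is enabled
    try {
      // ... do the actual work ...
    } catch {
      case e: Exception => logError("Failed to process block " + blockId, e)
    }
  }
}
```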
diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala
index 116d526854..45441aa5e5 100644
--- a/core/src/main/scala/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/spark/MapOutputTracker.scala
@@ -1,6 +1,6 @@
package spark
-import java.io.{DataInputStream, DataOutputStream, ByteArrayOutputStream, ByteArrayInputStream}
+import java.io._
import java.util.concurrent.ConcurrentHashMap
import akka.actor._
@@ -14,16 +14,19 @@ import akka.util.duration._
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
+import scheduler.MapStatus
import spark.storage.BlockManagerId
+import java.util.zip.{GZIPInputStream, GZIPOutputStream}
private[spark] sealed trait MapOutputTrackerMessage
-private[spark] case class GetMapOutputLocations(shuffleId: Int) extends MapOutputTrackerMessage
+private[spark] case class GetMapOutputStatuses(shuffleId: Int, requester: String)
+ extends MapOutputTrackerMessage
private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage
private[spark] class MapOutputTrackerActor(tracker: MapOutputTracker) extends Actor with Logging {
def receive = {
- case GetMapOutputLocations(shuffleId: Int) =>
- logInfo("Asked to get map output locations for shuffle " + shuffleId)
+ case GetMapOutputStatuses(shuffleId: Int, requester: String) =>
+ logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + requester)
sender ! tracker.getSerializedLocations(shuffleId)
case StopMapOutputTracker =>
@@ -40,16 +43,16 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
val timeout = 10.seconds
- var bmAddresses = new ConcurrentHashMap[Int, Array[BlockManagerId]]
+ var mapStatuses = new ConcurrentHashMap[Int, Array[MapStatus]]
// Incremented every time a fetch fails so that client nodes know to clear
// their cache of map output locations if this happens.
private var generation: Long = 0
- private var generationLock = new java.lang.Object
+ private val generationLock = new java.lang.Object
- // Cache a serialized version of the output locations for each shuffle to send them out faster
+ // Cache a serialized version of the output statuses for each shuffle to send them out faster
var cacheGeneration = generation
- val cachedSerializedLocs = new HashMap[Int, Array[Byte]]
+ val cachedSerializedStatuses = new HashMap[Int, Array[Byte]]
var trackerActor: ActorRef = if (isMaster) {
val actor = actorSystem.actorOf(Props(new MapOutputTrackerActor(this)), name = actorName)
@@ -80,31 +83,34 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
}
def registerShuffle(shuffleId: Int, numMaps: Int) {
- if (bmAddresses.get(shuffleId) != null) {
+ if (mapStatuses.get(shuffleId) != null) {
throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
}
- bmAddresses.put(shuffleId, new Array[BlockManagerId](numMaps))
+ mapStatuses.put(shuffleId, new Array[MapStatus](numMaps))
}
- def registerMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
- var array = bmAddresses.get(shuffleId)
+ def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) {
+ var array = mapStatuses.get(shuffleId)
array.synchronized {
- array(mapId) = bmAddress
+ array(mapId) = status
}
}
- def registerMapOutputs(shuffleId: Int, locs: Array[BlockManagerId], changeGeneration: Boolean = false) {
- bmAddresses.put(shuffleId, Array[BlockManagerId]() ++ locs)
+ def registerMapOutputs(
+ shuffleId: Int,
+ statuses: Array[MapStatus],
+ changeGeneration: Boolean = false) {
+ mapStatuses.put(shuffleId, Array[MapStatus]() ++ statuses)
if (changeGeneration) {
incrementGeneration()
}
}
def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
- var array = bmAddresses.get(shuffleId)
+ var array = mapStatuses.get(shuffleId)
if (array != null) {
array.synchronized {
- if (array(mapId) == bmAddress) {
+ if (array(mapId).address == bmAddress) {
array(mapId) = null
}
}
@@ -117,10 +123,10 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
// Remembers which map output locations are currently being fetched on a worker
val fetching = new HashSet[Int]
- // Called on possibly remote nodes to get the server URIs for a given shuffle
- def getServerAddresses(shuffleId: Int): Array[BlockManagerId] = {
- val locs = bmAddresses.get(shuffleId)
- if (locs == null) {
+ // Called on possibly remote nodes to get the server URIs and output sizes for a given shuffle
+ def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = {
+ val statuses = mapStatuses.get(shuffleId)
+ if (statuses == null) {
logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
fetching.synchronized {
if (fetching.contains(shuffleId)) {
@@ -129,34 +135,38 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
try {
fetching.wait()
} catch {
- case _ =>
+ case e: InterruptedException =>
}
}
- return bmAddresses.get(shuffleId)
+ return mapStatuses.get(shuffleId).map(status =>
+ (status.address, MapOutputTracker.decompressSize(status.compressedSizes(reduceId))))
} else {
fetching += shuffleId
}
}
// We won the race to fetch the output locs; do so
logInfo("Doing the fetch; tracker actor = " + trackerActor)
- val fetchedBytes = askTracker(GetMapOutputLocations(shuffleId)).asInstanceOf[Array[Byte]]
- val fetchedLocs = deserializeLocations(fetchedBytes)
+ val host = System.getProperty("spark.hostname", Utils.localHostName)
+ val fetchedBytes = askTracker(GetMapOutputStatuses(shuffleId, host)).asInstanceOf[Array[Byte]]
+ val fetchedStatuses = deserializeStatuses(fetchedBytes)
logInfo("Got the output locations")
- bmAddresses.put(shuffleId, fetchedLocs)
+ mapStatuses.put(shuffleId, fetchedStatuses)
fetching.synchronized {
fetching -= shuffleId
fetching.notifyAll()
}
- return fetchedLocs
+ return fetchedStatuses.map(s =>
+ (s.address, MapOutputTracker.decompressSize(s.compressedSizes(reduceId))))
} else {
- return locs
+ return statuses.map(s =>
+ (s.address, MapOutputTracker.decompressSize(s.compressedSizes(reduceId))))
}
}
def stop() {
communicate(StopMapOutputTracker)
- bmAddresses.clear()
+ mapStatuses.clear()
trackerActor = null
}
@@ -182,75 +192,83 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
generationLock.synchronized {
if (newGen > generation) {
logInfo("Updating generation to " + newGen + " and clearing cache")
- bmAddresses = new ConcurrentHashMap[Int, Array[BlockManagerId]]
+ mapStatuses = new ConcurrentHashMap[Int, Array[MapStatus]]
generation = newGen
}
}
}
def getSerializedLocations(shuffleId: Int): Array[Byte] = {
- var locs: Array[BlockManagerId] = null
+ var statuses: Array[MapStatus] = null
var generationGotten: Long = -1
generationLock.synchronized {
if (generation > cacheGeneration) {
- cachedSerializedLocs.clear()
+ cachedSerializedStatuses.clear()
cacheGeneration = generation
}
- cachedSerializedLocs.get(shuffleId) match {
+ cachedSerializedStatuses.get(shuffleId) match {
case Some(bytes) =>
return bytes
case None =>
- locs = bmAddresses.get(shuffleId)
+ statuses = mapStatuses.get(shuffleId)
generationGotten = generation
}
}
// If we got here, we failed to find the serialized locations in the cache, so we pulled
// out a snapshot of the locations as "locs"; let's serialize and return that
- val bytes = serializeLocations(locs)
+ val bytes = serializeStatuses(statuses)
+ logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length))
// Add them into the table only if the generation hasn't changed while we were working
generationLock.synchronized {
if (generation == generationGotten) {
- cachedSerializedLocs(shuffleId) = bytes
+ cachedSerializedStatuses(shuffleId) = bytes
}
}
return bytes
}
// Serialize an array of map output locations into an efficient byte format so that we can send
- // it to reduce tasks. We do this by grouping together the locations by block manager ID.
- def serializeLocations(locs: Array[BlockManagerId]): Array[Byte] = {
+ // it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will
+ // generally be pretty compressible because many map outputs will be on the same hostname.
+ def serializeStatuses(statuses: Array[MapStatus]): Array[Byte] = {
val out = new ByteArrayOutputStream
- val dataOut = new DataOutputStream(out)
- dataOut.writeInt(locs.length)
- val grouped = locs.zipWithIndex.groupBy(_._1)
- dataOut.writeInt(grouped.size)
- for ((id, pairs) <- grouped if id != null) {
- dataOut.writeUTF(id.ip)
- dataOut.writeInt(id.port)
- dataOut.writeInt(pairs.length)
- for ((_, blockIndex) <- pairs) {
- dataOut.writeInt(blockIndex)
- }
- }
- dataOut.close()
+ val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
+ objOut.writeObject(statuses)
+ objOut.close()
out.toByteArray
}
- // Opposite of serializeLocations.
- def deserializeLocations(bytes: Array[Byte]): Array[BlockManagerId] = {
- val dataIn = new DataInputStream(new ByteArrayInputStream(bytes))
- val length = dataIn.readInt()
- val array = new Array[BlockManagerId](length)
- val numGroups = dataIn.readInt()
- for (i <- 0 until numGroups) {
- val ip = dataIn.readUTF()
- val port = dataIn.readInt()
- val id = new BlockManagerId(ip, port)
- val numBlocks = dataIn.readInt()
- for (j <- 0 until numBlocks) {
- array(dataIn.readInt()) = id
- }
+ // Opposite of serializeStatuses.
+ def deserializeStatuses(bytes: Array[Byte]): Array[MapStatus] = {
+ val objIn = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes)))
+ objIn.readObject().asInstanceOf[Array[MapStatus]]
+ }
+}
+
+private[spark] object MapOutputTracker {
+ private val LOG_BASE = 1.1
+
+ /**
+ * Compress a size in bytes to 8 bits for efficient reporting of map output sizes.
+ * We do this by encoding the log base 1.1 of the size as an integer, which can support
+ * sizes up to 35 GB with at most 10% error.
+ */
+ def compressSize(size: Long): Byte = {
+ if (size <= 1L) {
+ 0
+ } else {
+ math.min(255, math.ceil(math.log(size) / math.log(LOG_BASE)).toInt).toByte
+ }
+ }
+
+ /**
+ * Decompress an 8-bit encoded block size, using the reverse operation of compressSize.
+ */
+ def decompressSize(compressedSize: Byte): Long = {
+ if (compressedSize == 0) {
+ 1
+ } else {
+ math.pow(LOG_BASE, (compressedSize & 0xFF)).toLong
}
- array
}
}
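A standalone check of the one-byte size encoding added above: a size is stored as the ceiling of its log base 1.1, so the decoded value overestimates the original by at most about 10% while fitting sizes up to roughly 35 GB into a single byte. The helpers below simply mirror the `compressSize`/`decompressSize` code in the patch:

```scala
object CompressSizeSketch {
  private val LOG_BASE = 1.1

  def compressSize(size: Long): Byte = {
    if (size <= 1L) 0
    else math.min(255, math.ceil(math.log(size) / math.log(LOG_BASE)).toInt).toByte
  }

  def decompressSize(compressedSize: Byte): Long = {
    if (compressedSize == 0) 1 else math.pow(LOG_BASE, compressedSize & 0xFF).toLong
  }

  def main(args: Array[String]) {
    for (size <- Seq(1L, 1500L, 1L << 20, 10L * 1024 * 1024 * 1024)) {
      val roundTrip = decompressSize(compressSize(size))
      println("%d bytes -> %d bytes (%.1f%% high)".format(
        size, roundTrip, 100.0 * (roundTrip - size) / size))
    }
  }
}
```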
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index 351c3d9d0b..f32ff475da 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -31,56 +31,86 @@ import spark.partial.BoundedDouble
import spark.partial.CountEvaluator
import spark.partial.GroupedCountEvaluator
import spark.partial.PartialResult
+import spark.rdd.BlockRDD
+import spark.rdd.CartesianRDD
+import spark.rdd.FilteredRDD
+import spark.rdd.FlatMappedRDD
+import spark.rdd.GlommedRDD
+import spark.rdd.MappedRDD
+import spark.rdd.MapPartitionsRDD
+import spark.rdd.MapPartitionsWithSplitRDD
+import spark.rdd.PipedRDD
+import spark.rdd.SampledRDD
+import spark.rdd.UnionRDD
import spark.storage.StorageLevel
import SparkContext._
/**
* A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Represents an immutable,
- * partitioned collection of elements that can be operated on in parallel.
+ * partitioned collection of elements that can be operated on in parallel. This class contains the
+ * basic operations available on all RDDs, such as `map`, `filter`, and `persist`. In addition,
+ * [[spark.PairRDDFunctions]] contains operations available only on RDDs of key-value pairs, such
+ * as `groupByKey` and `join`; [[spark.DoubleRDDFunctions]] contains operations available only on
+ * RDDs of Doubles; and [[spark.SequenceFileRDDFunctions]] contains operations available on RDDs
+ * that can be saved as SequenceFiles. These operations are automatically available on any RDD of
+ * the right type (e.g. RDD[(Int, Int)] through implicit conversions when you
+ * `import spark.SparkContext._`.
*
- * Each RDD is characterized by five main properties:
- * - A list of splits (partitions)
- * - A function for computing each split
- * - A list of dependencies on other RDDs
- * - Optionally, a Partitioner for key-value RDDs (e.g. to say that the RDD is hash-partitioned)
- * - Optionally, a list of preferred locations to compute each split on (e.g. block locations for
- * HDFS)
+ * Internally, each RDD is characterized by five main properties:
*
- * All the scheduling and execution in Spark is done based on these methods, allowing each RDD to
- * implement its own way of computing itself.
+ * - A list of splits (partitions)
+ * - A function for computing each split
+ * - A list of dependencies on other RDDs
+ * - Optionally, a Partitioner for key-value RDDs (e.g. to say that the RDD is hash-partitioned)
+ * - Optionally, a list of preferred locations to compute each split on (e.g. block locations for
+ * an HDFS file)
*
- * This class also contains transformation methods available on all RDDs (e.g. map and filter). In
- * addition, PairRDDFunctions contains extra methods available on RDDs of key-value pairs, and
- * SequenceFileRDDFunctions contains extra methods for saving RDDs to Hadoop SequenceFiles.
+ * All of the scheduling and execution in Spark is done based on these methods, allowing each RDD
+ * to implement its own way of computing itself. Indeed, users can implement custom RDDs (e.g. for
+ * reading data from a new storage system) by overriding these functions. Please refer to the
+ * [[http://www.cs.berkeley.edu/~matei/papers/2012/nsdi_spark.pdf Spark paper]] for more details
+ * on RDD internals.
*/
abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serializable {
- // Methods that must be implemented by subclasses
+ // Methods that must be implemented by subclasses:
+
+ /** Set of partitions in this RDD. */
def splits: Array[Split]
+
+ /** Function for computing a given partition. */
def compute(split: Split): Iterator[T]
+
+ /** How this RDD depends on any parent RDDs. */
@transient val dependencies: List[Dependency[_]]
+
+ // Methods available on all RDDs:
- // Record user function generating this RDD
- val origin = Utils.getSparkCallSite
+ /** Record user function generating this RDD. */
+ private[spark] val origin = Utils.getSparkCallSite
- // Optionally overridden by subclasses to specify how they are partitioned
+ /** Optionally overridden by subclasses to specify how they are partitioned. */
val partitioner: Option[Partitioner] = None
- // Optionally overridden by subclasses to specify placement preferences
+ /** Optionally overridden by subclasses to specify placement preferences. */
def preferredLocations(split: Split): Seq[String] = Nil
+ /** The [[spark.SparkContext]] that this RDD was created on. */
def context = sc
- def elementClassManifest: ClassManifest[T] = classManifest[T]
+ private[spark] def elementClassManifest: ClassManifest[T] = classManifest[T]
- // Get a unique ID for this RDD
+ /** A unique ID for this RDD (within its SparkContext). */
val id = sc.newRddId()
// Variables relating to persistence
private var storageLevel: StorageLevel = StorageLevel.NONE
- // Change this RDD's storage level
+ /**
+ * Set this RDD's storage level to persist its values across operations after the first time
+ * it is computed. Can only be called once on each RDD.
+ */
def persist(newLevel: StorageLevel): RDD[T] = {
// TODO: Handle changes of StorageLevel
if (storageLevel != StorageLevel.NONE && newLevel != storageLevel) {
@@ -91,12 +121,13 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
this
}
- // Turn on the default caching level for this RDD
+ /** Persist this RDD with the default storage level (MEMORY_ONLY). */
def persist(): RDD[T] = persist(StorageLevel.MEMORY_ONLY)
- // Turn on the default caching level for this RDD
+ /** Persist this RDD with the default storage level (MEMORY_ONLY). */
def cache(): RDD[T] = persist()
+ /** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */
def getStorageLevel = storageLevel
private[spark] def checkpoint(level: StorageLevel = StorageLevel.MEMORY_AND_DISK_2): RDD[T] = {
@@ -118,7 +149,11 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
}
}
- // Read this RDD; will read from cache if applicable, or otherwise compute
+ /**
+ * Internal method to this RDD; will read from cache if applicable, or otherwise compute it.
+ * This should ''not'' be called by users directly, but is available for implementors of custom
+ * subclasses of RDD.
+ */
final def iterator(split: Split): Iterator[T] = {
if (storageLevel != StorageLevel.NONE) {
SparkEnv.get.cacheTracker.getOrCompute[T](this, split, storageLevel)
@@ -175,8 +210,16 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
Utils.randomizeInPlace(samples, rand).take(total)
}
+ /**
+ * Return the union of this RDD and another one. Any identical elements will appear multiple
+ * times (use `.distinct()` to eliminate them).
+ */
def union(other: RDD[T]): RDD[T] = new UnionRDD(sc, Array(this, other))
+ /**
+ * Return the union of this RDD and another one. Any identical elements will appear multiple
+ * times (use `.distinct()` to eliminate them).
+ */
def ++(other: RDD[T]): RDD[T] = this.union(other)
def glom(): RDD[Array[T]] = new GlommedRDD(this)
@@ -372,7 +415,7 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
}
def saveAsObjectFile(path: String) {
- this.glom
+ this.mapPartitions(iter => iter.grouped(10).map(_.toArray))
.map(x => (NullWritable.get(), new BytesWritable(Utils.serialize(x))))
.saveAsSequenceFile(path)
}
@@ -382,60 +425,3 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
sc.runJob(this, (iter: Iterator[T]) => iter.toArray)
}
}
-
-class MappedRDD[U: ClassManifest, T: ClassManifest](
- prev: RDD[T],
- f: T => U)
- extends RDD[U](prev.context) {
-
- override def splits = prev.splits
- override val dependencies = List(new OneToOneDependency(prev))
- override def compute(split: Split) = prev.iterator(split).map(f)
-}
-
-class FlatMappedRDD[U: ClassManifest, T: ClassManifest](
- prev: RDD[T],
- f: T => TraversableOnce[U])
- extends RDD[U](prev.context) {
-
- override def splits = prev.splits
- override val dependencies = List(new OneToOneDependency(prev))
- override def compute(split: Split) = prev.iterator(split).flatMap(f)
-}
-
-class FilteredRDD[T: ClassManifest](prev: RDD[T], f: T => Boolean) extends RDD[T](prev.context) {
- override def splits = prev.splits
- override val dependencies = List(new OneToOneDependency(prev))
- override def compute(split: Split) = prev.iterator(split).filter(f)
-}
-
-class GlommedRDD[T: ClassManifest](prev: RDD[T]) extends RDD[Array[T]](prev.context) {
- override def splits = prev.splits
- override val dependencies = List(new OneToOneDependency(prev))
- override def compute(split: Split) = Array(prev.iterator(split).toArray).iterator
-}
-
-class MapPartitionsRDD[U: ClassManifest, T: ClassManifest](
- prev: RDD[T],
- f: Iterator[T] => Iterator[U])
- extends RDD[U](prev.context) {
-
- override def splits = prev.splits
- override val dependencies = List(new OneToOneDependency(prev))
- override def compute(split: Split) = f(prev.iterator(split))
-}
-
-/**
- * A variant of the MapPartitionsRDD that passes the split index into the
- * closure. This can be used to generate or collect partition specific
- * information such as the number of tuples in a partition.
- */
-class MapPartitionsWithSplitRDD[U: ClassManifest, T: ClassManifest](
- prev: RDD[T],
- f: (Int, Iterator[T]) => Iterator[U])
- extends RDD[U](prev.context) {
-
- override def splits = prev.splits
- override val dependencies = List(new OneToOneDependency(prev))
- override def compute(split: Split) = f(split.index, prev.iterator(split))
-}
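The concrete wrapper RDDs deleted above now live under `spark.rdd`, and the new scaladoc points out that users can implement their own RDD by providing `splits`, `compute`, and `dependencies`. A hypothetical one-to-one RDD in that style, written against the signatures shown in this diff:

```scala
import spark.{Dependency, OneToOneDependency, RDD, Split}

// Hypothetical example: upper-cases every string of the parent RDD, partition by partition.
class UpperCaseRDD(prev: RDD[String]) extends RDD[String](prev.context) {
  override def splits: Array[Split] = prev.splits
  override val dependencies: List[Dependency[_]] = List(new OneToOneDependency(prev))
  override def compute(split: Split): Iterator[String] =
    prev.iterator(split).map(_.toUpperCase)
}
```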
diff --git a/core/src/main/scala/spark/Serializer.scala b/core/src/main/scala/spark/Serializer.scala
index c0e08289d8..d8bcf6326a 100644
--- a/core/src/main/scala/spark/Serializer.scala
+++ b/core/src/main/scala/spark/Serializer.scala
@@ -9,10 +9,10 @@ import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
import spark.util.ByteBufferInputStream
/**
- * A serializer. Because some serialization libraries are not thread safe, this class is used to
+ * A serializer. Because some serialization libraries are not thread safe, this class is used to
* create SerializerInstances that do the actual serialization.
*/
-private[spark] trait Serializer {
+trait Serializer {
def newInstance(): SerializerInstance
}
@@ -88,7 +88,7 @@ private[spark] trait DeserializationStream {
}
gotNext = true
}
-
+
override def hasNext: Boolean = {
if (!gotNext) {
getNext()
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index 83c1b49203..84fc541f82 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -4,12 +4,11 @@ import java.io._
import java.util.concurrent.atomic.AtomicInteger
import java.net.{URI, URLClassLoader}
-import akka.actor.Actor
-import akka.actor.Actor._
-
import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.collection.generic.Growable
+import akka.actor.Actor
+import akka.actor.Actor._
import org.apache.hadoop.fs.{FileUtil, Path}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapred.InputFormat
@@ -27,20 +26,22 @@ import org.apache.hadoop.io.Text
import org.apache.hadoop.mapred.FileInputFormat
import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.mapred.TextInputFormat
-
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat}
import org.apache.hadoop.mapreduce.{Job => NewHadoopJob}
-
import org.apache.mesos.{Scheduler, MesosNativeLibrary}
import spark.broadcast._
-
import spark.deploy.LocalSparkCluster
-
import spark.partial.ApproximateEvaluator
import spark.partial.PartialResult
-
+import spark.rdd.DoubleRDDFunctions
+import spark.rdd.HadoopRDD
+import spark.rdd.NewHadoopRDD
+import spark.rdd.OrderedRDDFunctions
+import spark.rdd.PairRDDFunctions
+import spark.rdd.SequenceFileRDDFunctions
+import spark.rdd.UnionRDD
import spark.scheduler.ShuffleMapTask
import spark.scheduler.DAGScheduler
import spark.scheduler.TaskScheduler
@@ -49,14 +50,20 @@ import spark.scheduler.cluster.{SparkDeploySchedulerBackend, SchedulerBackend, C
import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
import spark.storage.BlockManagerMaster
-class SparkContext(
- master: String,
- frameworkName: String,
- val sparkHome: String,
- val jars: Seq[String])
+/**
+ * Main entry point for Spark functionality. A SparkContext represents the connection to a Spark
+ * cluster, and can be used to create RDDs, accumulators and broadcast variables on that cluster.
+ *
+ * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
+ * @param jobName A name for your job, to display on the cluster web UI
+ * @param sparkHome Location where Spark is installed on cluster nodes
+ * @param jars Collection of JARs to send to the cluster. These can be paths on the local file
+ * system or HDFS, HTTP, HTTPS, or FTP URLs.
+ */
+class SparkContext(master: String, jobName: String, val sparkHome: String, val jars: Seq[String])
extends Logging {
- def this(master: String, frameworkName: String) = this(master, frameworkName, null, Nil)
+ def this(master: String, jobName: String) = this(master, jobName, null, Nil)
// Ensure logging is initialized before we spawn any threads
initLogging()
@@ -72,7 +79,7 @@ class SparkContext(
private val isLocal = (master == "local" || master.startsWith("local["))
// Create the Spark execution environment (cache, map output tracker, etc)
- val env = SparkEnv.createFromSystemProperties(
+ private[spark] val env = SparkEnv.createFromSystemProperties(
System.getProperty("spark.master.host"),
System.getProperty("spark.master.port").toInt,
true,
@@ -80,8 +87,8 @@ class SparkContext(
SparkEnv.set(env)
// Used to store a URL for each static file/jar together with the file's local timestamp
- val addedFiles = HashMap[String, Long]()
- val addedJars = HashMap[String, Long]()
+ private[spark] val addedFiles = HashMap[String, Long]()
+ private[spark] val addedJars = HashMap[String, Long]()
// Add each JAR given through the constructor
jars.foreach { addJar(_) }
@@ -109,7 +116,7 @@ class SparkContext(
case SPARK_REGEX(sparkUrl) =>
val scheduler = new ClusterScheduler(this)
- val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, frameworkName)
+ val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, jobName)
scheduler.initialize(backend)
scheduler
@@ -128,7 +135,7 @@ class SparkContext(
val localCluster = new LocalSparkCluster(
numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt)
val sparkUrl = localCluster.start()
- val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, frameworkName)
+ val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, jobName)
scheduler.initialize(backend)
backend.shutdownCallback = (backend: SparkDeploySchedulerBackend) => {
localCluster.stop()
@@ -140,9 +147,9 @@ class SparkContext(
val scheduler = new ClusterScheduler(this)
val coarseGrained = System.getProperty("spark.mesos.coarse", "false").toBoolean
val backend = if (coarseGrained) {
- new CoarseMesosSchedulerBackend(scheduler, this, master, frameworkName)
+ new CoarseMesosSchedulerBackend(scheduler, this, master, jobName)
} else {
- new MesosSchedulerBackend(scheduler, this, master, frameworkName)
+ new MesosSchedulerBackend(scheduler, this, master, jobName)
}
scheduler.initialize(backend)
scheduler
@@ -154,14 +161,20 @@ class SparkContext(
// Methods for creating RDDs
- def parallelize[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism ): RDD[T] = {
+ /** Distribute a local Scala collection to form an RDD. */
+ def parallelize[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = {
new ParallelCollection[T](this, seq, numSlices)
}
- def makeRDD[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism ): RDD[T] = {
+ /** Distribute a local Scala collection to form an RDD. */
+ def makeRDD[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = {
parallelize(seq, numSlices)
}
+ /**
+ * Read a text file from HDFS, a local file system (available on all nodes), or any
+ * Hadoop-supported file system URI, and return it as an RDD of Strings.
+ */
def textFile(path: String, minSplits: Int = defaultMinSplits): RDD[String] = {
hadoopFile(path, classOf[TextInputFormat], classOf[LongWritable], classOf[Text], minSplits)
.map(pair => pair._2.toString)
@@ -199,7 +212,11 @@ class SparkContext(
/**
* Smarter version of hadoopFile() that uses class manifests to figure out the classes of keys,
- * values and the InputFormat so that users don't need to pass them directly.
+ * values and the InputFormat so that users don't need to pass them directly. Instead, callers
+ * can just write, for example,
+ * {{{
+ * val file = sparkContext.hadoopFile[LongWritable, Text, TextInputFormat](path, minSplits)
+ * }}}
*/
def hadoopFile[K, V, F <: InputFormat[K, V]](path: String, minSplits: Int)
(implicit km: ClassManifest[K], vm: ClassManifest[V], fm: ClassManifest[F])
@@ -211,6 +228,14 @@ class SparkContext(
minSplits)
}
+ /**
+ * Smarter version of hadoopFile() that uses class manifests to figure out the classes of keys,
+ * values and the InputFormat so that users don't need to pass them directly. Instead, callers
+ * can just write, for example,
+ * {{{
+ * val file = sparkContext.hadoopFile[LongWritable, Text, TextInputFormat](path)
+ * }}}
+ */
def hadoopFile[K, V, F <: InputFormat[K, V]](path: String)
(implicit km: ClassManifest[K], vm: ClassManifest[V], fm: ClassManifest[F]): RDD[(K, V)] =
hadoopFile[K, V, F](path, defaultMinSplits)
@@ -254,7 +279,7 @@ class SparkContext(
new NewHadoopRDD(this, fClass, kClass, vClass, conf)
}
- /** Get an RDD for a Hadoop SequenceFile with given key and value types */
+ /** Get an RDD for a Hadoop SequenceFile with given key and value types. */
def sequenceFile[K, V](path: String,
keyClass: Class[K],
valueClass: Class[V],
@@ -264,12 +289,17 @@ class SparkContext(
hadoopFile(path, inputFormatClass, keyClass, valueClass, minSplits)
}
+ /** Get an RDD for a Hadoop SequenceFile with given key and value types. */
def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V]): RDD[(K, V)] =
sequenceFile(path, keyClass, valueClass, defaultMinSplits)
/**
* Version of sequenceFile() for types implicitly convertible to Writables through a
- * WritableConverter.
+ * WritableConverter. For example, to access a SequenceFile where the keys are Text and the
+ * values are IntWritable, you could simply write
+ * {{{
+ * sparkContext.sequenceFile[String, Int](path, ...)
+ * }}}
*
* WritableConverters are provided in a somewhat strange way (by an implicit function) to support
* both subclasses of Writable and types for which we define a converter (e.g. Int to
@@ -310,17 +340,21 @@ class SparkContext(
/** Build the union of a list of RDDs. */
def union[T: ClassManifest](rdds: Seq[RDD[T]]): RDD[T] = new UnionRDD(this, rdds)
- /** Build the union of a list of RDDs. */
+ /** Build the union of a list of RDDs passed as variable-length arguments. */
def union[T: ClassManifest](first: RDD[T], rest: RDD[T]*): RDD[T] =
new UnionRDD(this, Seq(first) ++ rest)
// Methods for creating shared variables
+ /**
+ * Create an [[spark.Accumulator]] variable of a given type, which tasks can "add" values
+ * to using the `+=` method. Only the master can access the accumulator's `value`.
+ */
def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]) =
new Accumulator(initialValue, param)
/**
- * Create an accumulable shared variable, with a `+=` method
+ * Create an [[spark.Accumulable]] shared variable, with a `+=` method
* @tparam T accumulator type
* @tparam R type that can be added to the accumulator
*/
@@ -338,10 +372,17 @@ class SparkContext(
new Accumulable(initialValue, param)
}
- // Keep around a weak hash map of values to Cached versions?
- def broadcast[T](value: T) = SparkEnv.get.broadcastManager.newBroadcast[T] (value, isLocal)
+ /**
+ * Broadcast a read-only variable to the cluster, returning a [[spark.Broadcast]] object for
+ * reading it in distributed functions. The variable will be sent to each cluster only once.
+ */
+ def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T] (value, isLocal)
- // Adds a file dependency to all Tasks executed in the future.
+ /**
+ * Add a file to be downloaded into the working directory of this Spark job on every node.
+ * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported
+ * filesystems), or an HTTP, HTTPS or FTP URI.
+ */
def addFile(path: String) {
val uri = new URI(path)
val key = uri.getScheme match {
@@ -357,12 +398,20 @@ class SparkContext(
logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key))
}
+ /**
+ * Clear the job's list of files added by `addFile` so that they do not get downloaded to
+ * any new nodes.
+ */
def clearFiles() {
addedFiles.keySet.map(_.split("/").last).foreach { k => new File(k).delete() }
addedFiles.clear()
}
- // Adds a jar dependency to all Tasks executed in the future.
+ /**
+ * Adds a JAR dependency for all tasks to be executed on this SparkContext in the future.
+ * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported
+ * filesystems), or an HTTP, HTTPS or FTP URI.
+ */
def addJar(path: String) {
val uri = new URI(path)
val key = uri.getScheme match {
@@ -373,12 +422,16 @@ class SparkContext(
logInfo("Added JAR " + path + " at " + key + " with timestamp " + addedJars(key))
}
+ /**
+ * Clear the job's list of JARs added by `addJar` so that they do not get downloaded to
+ * any new nodes.
+ */
def clearJars() {
addedJars.keySet.map(_.split("/").last).foreach { k => new File(k).delete() }
addedJars.clear()
}
- // Stop the SparkContext
+ /** Shut down the SparkContext. */
def stop() {
dagScheduler.stop()
dagScheduler = null
@@ -393,10 +446,12 @@ class SparkContext(
logInfo("Successfully stopped SparkContext")
}
- // Get Spark's home location from either a value set through the constructor,
- // or the spark.home Java property, or the SPARK_HOME environment variable
- // (in that order of preference). If neither of these is set, return None.
- def getSparkHome(): Option[String] = {
+ /**
+ * Get Spark's home location from either a value set through the constructor,
+ * or the spark.home Java property, or the SPARK_HOME environment variable
+ * (in that order of preference). If neither of these is set, return None.
+ */
+ private[spark] def getSparkHome(): Option[String] = {
if (sparkHome != null) {
Some(sparkHome)
} else if (System.getProperty("spark.home") != null) {
@@ -428,6 +483,10 @@ class SparkContext(
result
}
+ /**
+ * Run a job on a given set of partitions of an RDD, but take a function of type
+ * `Iterator[T] => U` instead of `(TaskContext, Iterator[T]) => U`.
+ */
def runJob[T, U: ClassManifest](
rdd: RDD[T],
func: Iterator[T] => U,
@@ -444,6 +503,9 @@ class SparkContext(
runJob(rdd, func, 0 until rdd.splits.size, false)
}
+ /**
+ * Run a job on all partitions in an RDD and return the results in an array.
+ */
def runJob[T, U: ClassManifest](rdd: RDD[T], func: Iterator[T] => U): Array[U] = {
runJob(rdd, func, 0 until rdd.splits.size, false)
}
@@ -465,17 +527,19 @@ class SparkContext(
result
}
- // Clean a closure to make it ready to serialized and send to tasks
- // (removes unreferenced variables in $outer's, updates REPL variables)
+ /**
+ * Clean a closure to make it ready to serialized and send to tasks
+ * (removes unreferenced variables in $outer's, updates REPL variables)
+ */
private[spark] def clean[F <: AnyRef](f: F): F = {
ClosureCleaner.clean(f)
return f
}
- // Default level of parallelism to use when not given by user (e.g. for reduce tasks)
+ /** Default level of parallelism to use when not given by user (e.g. for reduce tasks) */
def defaultParallelism: Int = taskScheduler.defaultParallelism
- // Default min number of splits for Hadoop RDDs when not given by user
+ /** Default min number of splits for Hadoop RDDs when not given by user */
def defaultMinSplits: Int = math.min(defaultParallelism, 2)
private var nextShuffleId = new AtomicInteger(0)
@@ -486,7 +550,7 @@ class SparkContext(
private var nextRddId = new AtomicInteger(0)
- // Register a new RDD, returning its RDD ID
+ /** Register a new RDD, returning its RDD ID */
private[spark] def newRddId(): Int = {
nextRddId.getAndIncrement()
}
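A short end-to-end sketch against the constructor and entry points documented above (`master`, `jobName`, `textFile`, and the pair-RDD operations pulled in by `import spark.SparkContext._`). The input path is a placeholder:

```scala
import spark.SparkContext
import spark.SparkContext._   // enables reduceByKey via the rddToPairRDDFunctions implicit

object WordCountSketch {
  def main(args: Array[String]) {
    val sc = new SparkContext("local[2]", "WordCountSketch")
    val lines = sc.textFile("hdfs://namenode:9000/logs/input.txt")  // hypothetical path
    val counts = lines.flatMap(_.split(" "))
                      .map(word => (word, 1))
                      .reduceByKey(_ + _)
    counts.take(10).foreach(println)
    sc.stop()
  }
}
```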
diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala
index f2a52ab356..6a006e0697 100644
--- a/core/src/main/scala/spark/SparkEnv.scala
+++ b/core/src/main/scala/spark/SparkEnv.scala
@@ -44,15 +44,13 @@ class SparkEnv (
blockManager.stop()
blockManager.master.stop()
actorSystem.shutdown()
- // Akka's awaitTermination doesn't actually wait until the port is unbound, so sleep a bit
- Thread.sleep(100)
+ // Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut
+ // down, but let's call it anyway in case it gets fixed in a later release
actorSystem.awaitTermination()
- // Akka's awaitTermination doesn't actually wait until the port is unbound, so sleep a bit
- Thread.sleep(100)
}
}
-object SparkEnv {
+object SparkEnv extends Logging {
private val env = new ThreadLocal[SparkEnv]
def set(e: SparkEnv) {
@@ -111,6 +109,12 @@ object SparkEnv {
httpFileServer.initialize()
System.setProperty("spark.fileserver.uri", httpFileServer.serverUri)
+ // Warn about deprecated spark.cache.class property
+ if (System.getProperty("spark.cache.class") != null) {
+ logWarning("The spark.cache.class property is no longer being used! Specify storage " +
+ "levels using the RDD.persist() method instead.")
+ }
+
new SparkEnv(
actorSystem,
serializer,
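The warning added above points users away from the old `spark.cache.class` property and toward per-RDD storage levels. A minimal sketch of that replacement, assuming the `StorageLevel.MEMORY_ONLY` constant referenced elsewhere in this diff:

```scala
import spark.SparkContext
import spark.storage.StorageLevel

object PersistSketch {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "PersistSketch")
    val data = sc.parallelize(1 to 1000000)
    data.persist(StorageLevel.MEMORY_ONLY)  // equivalent to data.cache()
    println(data.count())                   // first action computes and caches the partitions
    println(data.count())                   // second action is served from the cache
    sc.stop()
  }
}
```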
diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala
index a480fe046d..567c4b1475 100644
--- a/core/src/main/scala/spark/Utils.scala
+++ b/core/src/main/scala/spark/Utils.scala
@@ -71,7 +71,7 @@ private object Utils extends Logging {
while (dir == null) {
attempts += 1
if (attempts > maxAttempts) {
- throw new IOException("Failed to create a temp directory after " + maxAttempts +
+ throw new IOException("Failed to create a temp directory after " + maxAttempts +
" attempts!")
}
try {
@@ -122,7 +122,7 @@ private object Utils extends Logging {
val out = new FileOutputStream(localPath)
Utils.copyStream(in, out, true)
}
-
+
/**
* Download a file requested by the executor. Supports fetching the file in a variety of ways,
* including HTTP, HDFS and files on a standard filesystem, based on the URL parameter.
@@ -140,9 +140,18 @@ private object Utils extends Logging {
case "file" | null =>
// Remove the file if it already exists
targetFile.delete()
- // Symlink the file locally
- logInfo("Symlinking " + url + " to " + targetFile)
- FileUtil.symLink(url, targetFile.toString)
+ // Symlink the file locally.
+ if (uri.isAbsolute) {
+ // url is absolute, i.e. it starts with "file:///". Extract the source
+ // file's absolute path from the url.
+ val sourceFile = new File(uri)
+ logInfo("Symlinking " + sourceFile.getAbsolutePath + " to " + targetFile.getAbsolutePath)
+ FileUtil.symLink(sourceFile.getAbsolutePath, targetFile.getAbsolutePath)
+ } else {
+ // url is not absolute, i.e. itself is the path to the source file.
+ logInfo("Symlinking " + url + " to " + targetFile.getAbsolutePath)
+ FileUtil.symLink(url, targetFile.getAbsolutePath)
+ }
case _ =>
// Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others
val uri = new URI(url)
@@ -208,7 +217,7 @@ private object Utils extends Logging {
def localHostName(): String = {
customHostname.getOrElse(InetAddress.getLocalHost.getHostName)
}
-
+
/**
* Returns a standard ThreadFactory except all threads are daemons.
*/
@@ -232,10 +241,10 @@ private object Utils extends Logging {
return threadPool
}
-
+
/**
- * Return the string to tell how long has passed in seconds. The passing parameter should be in
- * millisecond.
+ * Return the string to tell how long has passed in seconds. The passing parameter should be in
+ * millisecond.
*/
def getUsedTimeMs(startTimeMs: Long): String = {
return " " + (System.currentTimeMillis - startTimeMs) + " ms "
diff --git a/core/src/main/scala/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/spark/api/java/JavaPairRDD.scala
index 84ec386ce4..3c4399493c 100644
--- a/core/src/main/scala/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/spark/api/java/JavaPairRDD.scala
@@ -1,13 +1,5 @@
package spark.api.java
-import spark.SparkContext.rddToPairRDDFunctions
-import spark.api.java.function.{Function2 => JFunction2}
-import spark.api.java.function.{Function => JFunction}
-import spark.partial.BoundedDouble
-import spark.partial.PartialResult
-import spark.storage.StorageLevel
-import spark._
-
import java.util.{List => JList}
import java.util.Comparator
@@ -19,6 +11,17 @@ import org.apache.hadoop.mapred.OutputFormat
import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat}
import org.apache.hadoop.conf.Configuration
+import spark.api.java.function.{Function2 => JFunction2}
+import spark.api.java.function.{Function => JFunction}
+import spark.partial.BoundedDouble
+import spark.partial.PartialResult
+import spark.rdd.OrderedRDDFunctions
+import spark.storage.StorageLevel
+import spark.HashPartitioner
+import spark.Partitioner
+import spark.RDD
+import spark.SparkContext.rddToPairRDDFunctions
+
class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManifest[K],
implicit val vManifest: ClassManifest[V]) extends JavaRDDLike[(K, V), JavaPairRDD[K, V]] {
diff --git a/core/src/main/scala/spark/deploy/client/Client.scala b/core/src/main/scala/spark/deploy/client/Client.scala
index b1b72a3a1f..e51b0c5c15 100644
--- a/core/src/main/scala/spark/deploy/client/Client.scala
+++ b/core/src/main/scala/spark/deploy/client/Client.scala
@@ -4,6 +4,7 @@ import spark.deploy._
import akka.actor._
import akka.pattern.ask
import akka.util.duration._
+import akka.pattern.AskTimeoutException
import spark.{SparkException, Logging}
import akka.remote.RemoteClientLifeCycleEvent
import akka.remote.RemoteClientShutdown
@@ -100,9 +101,13 @@ private[spark] class Client(
def stop() {
if (actor != null) {
- val timeout = 1.seconds
- val future = actor.ask(StopClient)(timeout)
- Await.result(future, timeout)
+ try {
+ val timeout = 1.seconds
+ val future = actor.ask(StopClient)(timeout)
+ Await.result(future, timeout)
+ } catch {
+ case e: AskTimeoutException => // Ignore it, maybe master went away
+ }
actor = null
}
}
diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala
index dec0df25b4..da39108164 100644
--- a/core/src/main/scala/spark/network/ConnectionManager.scala
+++ b/core/src/main/scala/spark/network/ConnectionManager.scala
@@ -113,7 +113,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
val selectedKeysCount = selector.select()
if (selectedKeysCount == 0) {
- logInfo("Selector selected " + selectedKeysCount + " of " + selector.keys.size + " keys")
+ logDebug("Selector selected " + selectedKeysCount + " of " + selector.keys.size + " keys")
}
if (selectorThread.isInterrupted) {
logInfo("Selector thread was interrupted!")
@@ -167,7 +167,6 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
}
def removeConnection(connection: Connection) {
- /*logInfo("Removing connection")*/
connectionsByKey -= connection.key
if (connection.isInstanceOf[SendingConnection]) {
val sendingConnection = connection.asInstanceOf[SendingConnection]
@@ -235,7 +234,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
def receiveMessage(connection: Connection, message: Message) {
val connectionManagerId = ConnectionManagerId.fromSocketAddress(message.senderAddress)
- logInfo("Received [" + message + "] from [" + connectionManagerId + "]")
+ logDebug("Received [" + message + "] from [" + connectionManagerId + "]")
val runnable = new Runnable() {
val creationTime = System.currentTimeMillis
def run() {
@@ -276,15 +275,15 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
logDebug("Calling back")
onReceiveCallback(bufferMessage, connectionManagerId)
} else {
- logWarning("Not calling back as callback is null")
+ logDebug("Not calling back as callback is null")
None
}
if (ackMessage.isDefined) {
if (!ackMessage.get.isInstanceOf[BufferMessage]) {
- logWarning("Response to " + bufferMessage + " is not a buffer message, it is of type " + ackMessage.get.getClass())
+ logDebug("Response to " + bufferMessage + " is not a buffer message, it is of type " + ackMessage.get.getClass())
} else if (!ackMessage.get.asInstanceOf[BufferMessage].hasAckId) {
- logWarning("Response to " + bufferMessage + " does not have ack id set")
+ logDebug("Response to " + bufferMessage + " does not have ack id set")
ackMessage.get.asInstanceOf[BufferMessage].ackId = bufferMessage.id
}
}
diff --git a/core/src/main/scala/spark/package.scala b/core/src/main/scala/spark/package.scala
new file mode 100644
index 0000000000..389ec4da3e
--- /dev/null
+++ b/core/src/main/scala/spark/package.scala
@@ -0,0 +1,15 @@
+/**
+ * Core Spark functionality. [[spark.SparkContext]] serves as the main entry point to Spark, while
+ * [[spark.RDD]] is the data type representing a distributed collection, and provides most
+ * parallel operations.
+ *
+ * In addition, [[spark.PairRDDFunctions]] contains operations available only on RDDs of key-value
+ * pairs, such as `groupByKey` and `join`; [[spark.DoubleRDDFunctions]] contains operations
+ * available only on RDDs of Doubles; and [[spark.SequenceFileRDDFunctions]] contains operations
+ * available on RDDs that can be saved as SequenceFiles. These operations are automatically
+ * available on any RDD of the right type (e.g. RDD[(Int, Int)]) through implicit conversions when
+ * you `import spark.SparkContext._`.
+ */
+package object spark {
+ // For package docs only
+}
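
The package doc above leans on implicit conversions; a minimal usage sketch of what it describes, assuming an already constructed SparkContext named sc (the extra methods come from the conversions pulled in by the import):

    import spark.SparkContext._

    val pairs = sc.parallelize(Seq(("a", 1), ("b", 2), ("a", 3)))
    val grouped = pairs.groupByKey()                 // via PairRDDFunctions
    val doubles = sc.parallelize(Seq(1.0, 2.0, 3.0))
    println(doubles.mean())                          // via DoubleRDDFunctions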
diff --git a/core/src/main/scala/spark/partial/BoundedDouble.scala b/core/src/main/scala/spark/partial/BoundedDouble.scala
index 8bedd75182..463c33d6e2 100644
--- a/core/src/main/scala/spark/partial/BoundedDouble.scala
+++ b/core/src/main/scala/spark/partial/BoundedDouble.scala
@@ -3,7 +3,6 @@ package spark.partial
/**
* A Double with error bars on it.
*/
-private[spark]
class BoundedDouble(val mean: Double, val confidence: Double, val low: Double, val high: Double) {
override def toString(): String = "[%.3f, %.3f]".format(low, high)
}
diff --git a/core/src/main/scala/spark/partial/PartialResult.scala b/core/src/main/scala/spark/partial/PartialResult.scala
index beafbf67c3..200ed4ea1e 100644
--- a/core/src/main/scala/spark/partial/PartialResult.scala
+++ b/core/src/main/scala/spark/partial/PartialResult.scala
@@ -1,6 +1,6 @@
package spark.partial
-private[spark] class PartialResult[R](initialVal: R, isFinal: Boolean) {
+class PartialResult[R](initialVal: R, isFinal: Boolean) {
private var finalValue: Option[R] = if (isFinal) Some(initialVal) else None
private var failure: Option[Exception] = None
private var completionHandler: Option[R => Unit] = None
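
Dropping private[spark] from BoundedDouble and PartialResult makes the approximate-result API usable from application code. A hedged usage sketch, assuming a SparkContext named sc and that the approximate actions wired to these classes (such as countApprox) are available on the RDD:

    import spark.partial.{BoundedDouble, PartialResult}

    val nums = sc.parallelize(1 to 1000000, 10)
    // Ask for a count within 5 seconds at 95% confidence
    val approx: PartialResult[BoundedDouble] = nums.countApprox(5000L, 0.95)
    val bound = approx.getFinalValue()               // blocks until the final value is ready
    println("count is in [%.1f, %.1f]".format(bound.low, bound.high))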
diff --git a/core/src/main/scala/spark/BlockRDD.scala b/core/src/main/scala/spark/rdd/BlockRDD.scala
index faa99fe3e9..cb73976aed 100644
--- a/core/src/main/scala/spark/BlockRDD.scala
+++ b/core/src/main/scala/spark/rdd/BlockRDD.scala
@@ -1,12 +1,18 @@
-package spark
+package spark.rdd
import scala.collection.mutable.HashMap
+import spark.Dependency
+import spark.RDD
+import spark.SparkContext
+import spark.SparkEnv
+import spark.Split
+
private[spark] class BlockRDDSplit(val blockId: String, idx: Int) extends Split {
val index = idx
}
-
+private[spark]
class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[String])
extends RDD[T](sc) {
diff --git a/core/src/main/scala/spark/CartesianRDD.scala b/core/src/main/scala/spark/rdd/CartesianRDD.scala
index 83db2d2934..7c354b6b2e 100644
--- a/core/src/main/scala/spark/CartesianRDD.scala
+++ b/core/src/main/scala/spark/rdd/CartesianRDD.scala
@@ -1,10 +1,16 @@
-package spark
+package spark.rdd
+
+import spark.NarrowDependency
+import spark.RDD
+import spark.SparkContext
+import spark.Split
private[spark]
class CartesianSplit(idx: Int, val s1: Split, val s2: Split) extends Split with Serializable {
override val index: Int = idx
}
+private[spark]
class CartesianRDD[T: ClassManifest, U:ClassManifest](
sc: SparkContext,
rdd1: RDD[T],
@@ -45,4 +51,4 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest](
def getParents(id: Int): Seq[Int] = List(id % numSplitsInRdd2)
}
)
-} \ No newline at end of file
+}
diff --git a/core/src/main/scala/spark/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
index daba719b14..f1defbe492 100644
--- a/core/src/main/scala/spark/CoGroupedRDD.scala
+++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
@@ -1,11 +1,22 @@
-package spark
+package spark.rdd
import java.net.URL
import java.io.EOFException
import java.io.ObjectInputStream
+
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
+import spark.Aggregator
+import spark.Dependency
+import spark.Logging
+import spark.OneToOneDependency
+import spark.Partitioner
+import spark.RDD
+import spark.ShuffleDependency
+import spark.SparkEnv
+import spark.Split
+
private[spark] sealed trait CoGroupSplitDep extends Serializable
private[spark] case class NarrowCoGroupSplitDep(rdd: RDD[_], split: Split) extends CoGroupSplitDep
private[spark] case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep
@@ -38,7 +49,7 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner)
} else {
logInfo("Adding shuffle dependency with " + rdd)
deps += new ShuffleDependency[Any, Any, ArrayBuffer[Any]](
- context.newShuffleId, rdd, aggr, part)
+ context.newShuffleId, rdd, Some(aggr), part)
}
}
deps.toList
diff --git a/core/src/main/scala/spark/CoalescedRDD.scala b/core/src/main/scala/spark/rdd/CoalescedRDD.scala
index f1ae346a44..0967f4f5df 100644
--- a/core/src/main/scala/spark/CoalescedRDD.scala
+++ b/core/src/main/scala/spark/rdd/CoalescedRDD.scala
@@ -1,4 +1,8 @@
-package spark
+package spark.rdd
+
+import spark.NarrowDependency
+import spark.RDD
+import spark.Split
private class CoalescedRDDSplit(val index: Int, val parents: Array[Split]) extends Split
diff --git a/core/src/main/scala/spark/DoubleRDDFunctions.scala b/core/src/main/scala/spark/rdd/DoubleRDDFunctions.scala
index 1fbf66b7de..d232ddeb7c 100644
--- a/core/src/main/scala/spark/DoubleRDDFunctions.scala
+++ b/core/src/main/scala/spark/rdd/DoubleRDDFunctions.scala
@@ -1,10 +1,13 @@
-package spark
+package spark.rdd
import spark.partial.BoundedDouble
import spark.partial.MeanEvaluator
import spark.partial.PartialResult
import spark.partial.SumEvaluator
+import spark.Logging
+import spark.RDD
+import spark.TaskContext
import spark.util.StatCounter
/**
diff --git a/core/src/main/scala/spark/rdd/FilteredRDD.scala b/core/src/main/scala/spark/rdd/FilteredRDD.scala
new file mode 100644
index 0000000000..dfe9dc73f3
--- /dev/null
+++ b/core/src/main/scala/spark/rdd/FilteredRDD.scala
@@ -0,0 +1,12 @@
+package spark.rdd
+
+import spark.OneToOneDependency
+import spark.RDD
+import spark.Split
+
+private[spark]
+class FilteredRDD[T: ClassManifest](prev: RDD[T], f: T => Boolean) extends RDD[T](prev.context) {
+ override def splits = prev.splits
+ override val dependencies = List(new OneToOneDependency(prev))
+ override def compute(split: Split) = prev.iterator(split).filter(f)
+} \ No newline at end of file
diff --git a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala
new file mode 100644
index 0000000000..3534dc8057
--- /dev/null
+++ b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala
@@ -0,0 +1,16 @@
+package spark.rdd
+
+import spark.OneToOneDependency
+import spark.RDD
+import spark.Split
+
+private[spark]
+class FlatMappedRDD[U: ClassManifest, T: ClassManifest](
+ prev: RDD[T],
+ f: T => TraversableOnce[U])
+ extends RDD[U](prev.context) {
+
+ override def splits = prev.splits
+ override val dependencies = List(new OneToOneDependency(prev))
+ override def compute(split: Split) = prev.iterator(split).flatMap(f)
+}
diff --git a/core/src/main/scala/spark/rdd/GlommedRDD.scala b/core/src/main/scala/spark/rdd/GlommedRDD.scala
new file mode 100644
index 0000000000..e30564f2da
--- /dev/null
+++ b/core/src/main/scala/spark/rdd/GlommedRDD.scala
@@ -0,0 +1,12 @@
+package spark.rdd
+
+import spark.OneToOneDependency
+import spark.RDD
+import spark.Split
+
+private[spark]
+class GlommedRDD[T: ClassManifest](prev: RDD[T]) extends RDD[Array[T]](prev.context) {
+ override def splits = prev.splits
+ override val dependencies = List(new OneToOneDependency(prev))
+ override def compute(split: Split) = Array(prev.iterator(split).toArray).iterator
+} \ No newline at end of file
diff --git a/core/src/main/scala/spark/HadoopRDD.scala b/core/src/main/scala/spark/rdd/HadoopRDD.scala
index 6d448116a9..bf29a1f075 100644
--- a/core/src/main/scala/spark/HadoopRDD.scala
+++ b/core/src/main/scala/spark/rdd/HadoopRDD.scala
@@ -1,4 +1,4 @@
-package spark
+package spark.rdd
import java.io.EOFException
import java.util.NoSuchElementException
@@ -15,6 +15,12 @@ import org.apache.hadoop.mapred.RecordReader
import org.apache.hadoop.mapred.Reporter
import org.apache.hadoop.util.ReflectionUtils
+import spark.Dependency
+import spark.RDD
+import spark.SerializableWritable
+import spark.SparkContext
+import spark.Split
+
/**
* A Spark split class that wraps around a Hadoop InputSplit.
*/
diff --git a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala
new file mode 100644
index 0000000000..b2c7a1cb9e
--- /dev/null
+++ b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala
@@ -0,0 +1,16 @@
+package spark.rdd
+
+import spark.OneToOneDependency
+import spark.RDD
+import spark.Split
+
+private[spark]
+class MapPartitionsRDD[U: ClassManifest, T: ClassManifest](
+ prev: RDD[T],
+ f: Iterator[T] => Iterator[U])
+ extends RDD[U](prev.context) {
+
+ override def splits = prev.splits
+ override val dependencies = List(new OneToOneDependency(prev))
+ override def compute(split: Split) = f(prev.iterator(split))
+} \ No newline at end of file
diff --git a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala
new file mode 100644
index 0000000000..adc541694e
--- /dev/null
+++ b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala
@@ -0,0 +1,21 @@
+package spark.rdd
+
+import spark.OneToOneDependency
+import spark.RDD
+import spark.Split
+
+/**
+ * A variant of the MapPartitionsRDD that passes the split index into the
+ * closure. This can be used to generate or collect partition specific
+ * information such as the number of tuples in a partition.
+ */
+private[spark]
+class MapPartitionsWithSplitRDD[U: ClassManifest, T: ClassManifest](
+ prev: RDD[T],
+ f: (Int, Iterator[T]) => Iterator[U])
+ extends RDD[U](prev.context) {
+
+ override def splits = prev.splits
+ override val dependencies = List(new OneToOneDependency(prev))
+ override def compute(split: Split) = f(split.index, prev.iterator(split))
+} \ No newline at end of file
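
The doc comment on MapPartitionsWithSplitRDD mentions collecting partition-specific information; a short sketch of the kind of closure it enables, assuming the corresponding RDD.mapPartitionsWithSplit wrapper method and an existing RDD[String] named data:

    // Count how many records live in each partition (split)
    val countsPerSplit = data.mapPartitionsWithSplit { (index, iter) =>
      Iterator((index, iter.size))
    }.collect()

    // Or tag each record with the partition it came from
    val tagged = data.mapPartitionsWithSplit { (index, iter) => iter.map(x => (index, x)) }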
diff --git a/core/src/main/scala/spark/rdd/MappedRDD.scala b/core/src/main/scala/spark/rdd/MappedRDD.scala
new file mode 100644
index 0000000000..59bedad8ef
--- /dev/null
+++ b/core/src/main/scala/spark/rdd/MappedRDD.scala
@@ -0,0 +1,16 @@
+package spark.rdd
+
+import spark.OneToOneDependency
+import spark.RDD
+import spark.Split
+
+private[spark]
+class MappedRDD[U: ClassManifest, T: ClassManifest](
+ prev: RDD[T],
+ f: T => U)
+ extends RDD[U](prev.context) {
+
+ override def splits = prev.splits
+ override val dependencies = List(new OneToOneDependency(prev))
+ override def compute(split: Split) = prev.iterator(split).map(f)
+} \ No newline at end of file
diff --git a/core/src/main/scala/spark/NewHadoopRDD.scala b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
index 9072698357..dcbceab246 100644
--- a/core/src/main/scala/spark/NewHadoopRDD.scala
+++ b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
@@ -1,4 +1,4 @@
-package spark
+package spark.rdd
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.Writable
@@ -13,6 +13,12 @@ import org.apache.hadoop.mapreduce.TaskAttemptID
import java.util.Date
import java.text.SimpleDateFormat
+import spark.Dependency
+import spark.RDD
+import spark.SerializableWritable
+import spark.SparkContext
+import spark.Split
+
private[spark]
class NewHadoopSplit(rddId: Int, val index: Int, @transient rawSplit: InputSplit with Writable)
extends Split {
diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/rdd/PairRDDFunctions.scala
index 80d62caf25..2a94ea263a 100644
--- a/core/src/main/scala/spark/PairRDDFunctions.scala
+++ b/core/src/main/scala/spark/rdd/PairRDDFunctions.scala
@@ -1,4 +1,4 @@
-package spark
+package spark.rdd
import java.io.EOFException
import java.io.ObjectInputStream
@@ -34,9 +34,20 @@ import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob}
import org.apache.hadoop.mapreduce.TaskAttemptID
import org.apache.hadoop.mapreduce.TaskAttemptContext
-import spark.SparkContext._
import spark.partial.BoundedDouble
import spark.partial.PartialResult
+import spark.Aggregator
+import spark.HashPartitioner
+import spark.Logging
+import spark.OneToOneDependency
+import spark.Partitioner
+import spark.RangePartitioner
+import spark.RDD
+import spark.SerializableWritable
+import spark.SparkContext._
+import spark.SparkException
+import spark.Split
+import spark.TaskContext
/**
* Extra functions available on RDDs of (key, value) pairs through an implicit conversion.
diff --git a/core/src/main/scala/spark/PipedRDD.scala b/core/src/main/scala/spark/rdd/PipedRDD.scala
index 3103d7889b..98ea0c92d6 100644
--- a/core/src/main/scala/spark/PipedRDD.scala
+++ b/core/src/main/scala/spark/rdd/PipedRDD.scala
@@ -1,4 +1,4 @@
-package spark
+package spark.rdd
import java.io.PrintWriter
import java.util.StringTokenizer
@@ -8,6 +8,12 @@ import scala.collection.JavaConversions._
import scala.collection.mutable.ArrayBuffer
import scala.io.Source
+import spark.OneToOneDependency
+import spark.RDD
+import spark.SparkEnv
+import spark.Split
+
+
/**
* An RDD that pipes the contents of each parent partition through an external command
* (printing them one per line) and returns the output as a collection of strings.
diff --git a/core/src/main/scala/spark/SampledRDD.scala b/core/src/main/scala/spark/rdd/SampledRDD.scala
index ac10aed477..87a5268f27 100644
--- a/core/src/main/scala/spark/SampledRDD.scala
+++ b/core/src/main/scala/spark/rdd/SampledRDD.scala
@@ -1,9 +1,13 @@
-package spark
+package spark.rdd
import java.util.Random
import cern.jet.random.Poisson
import cern.jet.random.engine.DRand
+import spark.RDD
+import spark.OneToOneDependency
+import spark.Split
+
private[spark]
class SampledRDDSplit(val prev: Split, val seed: Int) extends Split with Serializable {
override val index: Int = prev.index
diff --git a/core/src/main/scala/spark/SequenceFileRDDFunctions.scala b/core/src/main/scala/spark/rdd/SequenceFileRDDFunctions.scala
index ea7171d3a1..24c731fa92 100644
--- a/core/src/main/scala/spark/SequenceFileRDDFunctions.scala
+++ b/core/src/main/scala/spark/rdd/SequenceFileRDDFunctions.scala
@@ -1,4 +1,4 @@
-package spark
+package spark.rdd
import java.io.EOFException
import java.net.URL
@@ -23,7 +23,9 @@ import org.apache.hadoop.io.NullWritable
import org.apache.hadoop.io.BytesWritable
import org.apache.hadoop.io.Text
-import SparkContext._
+import spark.Logging
+import spark.RDD
+import spark.SparkContext._
/**
* Extra functions available on RDDs of (key, value) pairs to create a Hadoop SequenceFile,
diff --git a/core/src/main/scala/spark/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala
index 1a9f4cfec3..7577909b83 100644
--- a/core/src/main/scala/spark/ShuffledRDD.scala
+++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala
@@ -1,8 +1,15 @@
-package spark
+package spark.rdd
import scala.collection.mutable.ArrayBuffer
import java.util.{HashMap => JHashMap}
+import spark.Aggregator
+import spark.Partitioner
+import spark.RangePartitioner
+import spark.RDD
+import spark.ShuffleDependency
+import spark.SparkEnv
+import spark.Split
private[spark] class ShuffledRDDSplit(val idx: Int) extends Split {
override val index = idx
@@ -15,7 +22,7 @@ private[spark] class ShuffledRDDSplit(val idx: Int) extends Split {
*/
abstract class ShuffledRDD[K, V, C](
@transient parent: RDD[(K, V)],
- aggregator: Aggregator[K, V, C],
+ aggregator: Option[Aggregator[K, V, C]],
part: Partitioner)
extends RDD[(K, C)](parent.context) {
@@ -41,7 +48,7 @@ class RepartitionShuffledRDD[K, V](
part: Partitioner)
extends ShuffledRDD[K, V, V](
parent,
- Aggregator[K, V, V](null, null, null, false),
+ None,
part) {
override def compute(split: Split): Iterator[(K, V)] = {
@@ -88,7 +95,7 @@ class ShuffledAggregatedRDD[K, V, C](
@transient parent: RDD[(K, V)],
aggregator: Aggregator[K, V, C],
part : Partitioner)
- extends ShuffledRDD[K, V, C](parent, aggregator, part) {
+ extends ShuffledRDD[K, V, C](parent, Some(aggregator), part) {
override def compute(split: Split): Iterator[(K, C)] = {
val combiners = new JHashMap[K, C]
diff --git a/core/src/main/scala/spark/UnionRDD.scala b/core/src/main/scala/spark/rdd/UnionRDD.scala
index 3e795ea2a2..f0b9225f7c 100644
--- a/core/src/main/scala/spark/UnionRDD.scala
+++ b/core/src/main/scala/spark/rdd/UnionRDD.scala
@@ -1,7 +1,13 @@
-package spark
+package spark.rdd
import scala.collection.mutable.ArrayBuffer
+import spark.Dependency
+import spark.RangeDependency
+import spark.RDD
+import spark.SparkContext
+import spark.Split
+
private[spark] class UnionSplit[T: ClassManifest](
idx: Int,
rdd: RDD[T],
@@ -37,7 +43,7 @@ class UnionRDD[T: ClassManifest](
override val dependencies = {
val deps = new ArrayBuffer[Dependency[_]]
var pos = 0
- for ((rdd, index) <- rdds.zipWithIndex) {
+ for (rdd <- rdds) {
deps += new RangeDependency(rdd, 0, pos, rdd.splits.size)
pos += rdd.splits.size
}
diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
index 9b666ed181..6f4c6bffd7 100644
--- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
@@ -422,11 +422,11 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
case smt: ShuffleMapTask =>
val stage = idToStage(smt.stageId)
- val bmAddress = event.result.asInstanceOf[BlockManagerId]
- val host = bmAddress.ip
+ val status = event.result.asInstanceOf[MapStatus]
+ val host = status.address.ip
logInfo("ShuffleMapTask finished with host " + host)
if (!deadHosts.contains(host)) { // TODO: Make sure hostnames are consistent with Mesos
- stage.addOutputLoc(smt.partition, bmAddress)
+ stage.addOutputLoc(smt.partition, status)
}
if (running.contains(stage) && pendingTasks(stage).isEmpty) {
logInfo(stage + " (" + stage.origin + ") finished; looking for newly runnable stages")
diff --git a/core/src/main/scala/spark/scheduler/MapStatus.scala b/core/src/main/scala/spark/scheduler/MapStatus.scala
new file mode 100644
index 0000000000..4532d9497f
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/MapStatus.scala
@@ -0,0 +1,27 @@
+package spark.scheduler
+
+import spark.storage.BlockManagerId
+import java.io.{ObjectOutput, ObjectInput, Externalizable}
+
+/**
+ * Result returned by a ShuffleMapTask to a scheduler. Includes the block manager address that the
+ * task ran on as well as the sizes of outputs for each reducer, for passing on to the reduce tasks.
+ * The map output sizes are compressed using MapOutputTracker.compressSize.
+ */
+private[spark] class MapStatus(var address: BlockManagerId, var compressedSizes: Array[Byte])
+ extends Externalizable {
+
+ def this() = this(null, null) // For deserialization only
+
+ def writeExternal(out: ObjectOutput) {
+ address.writeExternal(out)
+ out.writeInt(compressedSizes.length)
+ out.write(compressedSizes)
+ }
+
+ def readExternal(in: ObjectInput) {
+ address = new BlockManagerId(in)
+ compressedSizes = new Array[Byte](in.readInt())
+ in.readFully(compressedSizes)
+ }
+}
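
Each output size is squeezed into a single byte by MapOutputTracker.compressSize before being shipped in a MapStatus. The exact encoding lives in MapOutputTracker; the sketch below is only an assumed log-scale scheme to illustrate why one byte is enough to tell a reducer roughly how much data to expect:

    // Illustrative only: map 0 and 1 to themselves, otherwise store ceil(log_1.1(size)) in one byte.
    def compressSizeSketch(size: Long): Byte =
      if (size <= 1L) size.toByte
      else math.min(255, math.ceil(math.log(size) / math.log(1.1)).toInt).toByte

    def decompressSizeSketch(compressed: Byte): Long =
      if (compressed == 0) 0L
      else math.pow(1.1, compressed & 0xFF).toLong

With a base of 1.1, the 255 available codes cover sizes up to tens of gigabytes while keeping every MapStatus small enough to ship cheaply to the scheduler.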
diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
index 966a5e173a..86796d3677 100644
--- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
@@ -74,7 +74,7 @@ private[spark] class ShuffleMapTask(
var dep: ShuffleDependency[_,_,_],
var partition: Int,
@transient var locs: Seq[String])
- extends Task[BlockManagerId](stageId)
+ extends Task[MapStatus](stageId)
with Externalizable
with Logging {
@@ -109,13 +109,13 @@ private[spark] class ShuffleMapTask(
split = in.readObject().asInstanceOf[Split]
}
- override def run(attemptId: Long): BlockManagerId = {
+ override def run(attemptId: Long): MapStatus = {
val numOutputSplits = dep.partitioner.numPartitions
- val aggregator = dep.aggregator.asInstanceOf[Aggregator[Any, Any, Any]]
val partitioner = dep.partitioner
val bucketIterators =
- if (aggregator.mapSideCombine) {
+ if (dep.aggregator.isDefined && dep.aggregator.get.mapSideCombine) {
+ val aggregator = dep.aggregator.get.asInstanceOf[Aggregator[Any, Any, Any]]
// Apply combiners (map-side aggregation) to the map output.
val buckets = Array.tabulate(numOutputSplits)(_ => new JHashMap[Any, Any])
for (elem <- rdd.iterator(split)) {
@@ -141,17 +141,18 @@ private[spark] class ShuffleMapTask(
buckets.map(_.iterator)
}
- val ser = SparkEnv.get.serializer.newInstance()
+ val compressedSizes = new Array[Byte](numOutputSplits)
+
val blockManager = SparkEnv.get.blockManager
for (i <- 0 until numOutputSplits) {
val blockId = "shuffle_" + dep.shuffleId + "_" + partition + "_" + i
- // Get a scala iterator from java map
+ // Get a Scala iterator from Java map
val iter: Iterator[(Any, Any)] = bucketIterators(i)
- // TODO: This should probably be DISK_ONLY
- blockManager.put(blockId, iter, StorageLevel.DISK_ONLY, false)
+ val size = blockManager.put(blockId, iter, StorageLevel.DISK_ONLY, false)
+ compressedSizes(i) = MapOutputTracker.compressSize(size)
}
- return SparkEnv.get.blockManager.blockManagerId
+ return new MapStatus(blockManager.blockManagerId, compressedSizes)
}
override def preferredLocations: Seq[String] = locs
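
With the aggregator now an Option, the shuffle write path above only combines map-side when a combiner is present and asks for it. A self-contained sketch of that decision, using simplified stand-in types rather than the real ShuffleDependency and Aggregator:

    // Hypothetical stand-in type for this sketch only
    case class Agg[V, C](createCombiner: V => C, mergeValue: (C, V) => C, mapSideCombine: Boolean)

    def bucketValues[K, V, C](records: Seq[(K, V)], aggregator: Option[Agg[V, C]]): Seq[(K, Any)] =
      aggregator match {
        case Some(agg) if agg.mapSideCombine =>
          // Combine values per key before writing shuffle output
          records.groupBy(_._1).toSeq.map { case (k, kvs) =>
            val values = kvs.map(_._2)
            (k, values.tail.foldLeft(agg.createCombiner(values.head))(agg.mergeValue))
          }
        case _ =>
          // No combiner, or no map-side combine: write the raw pairs through unchanged
          records
      }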
diff --git a/core/src/main/scala/spark/scheduler/Stage.scala b/core/src/main/scala/spark/scheduler/Stage.scala
index 803dd1b97d..1149c00a23 100644
--- a/core/src/main/scala/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/spark/scheduler/Stage.scala
@@ -29,29 +29,29 @@ private[spark] class Stage(
val isShuffleMap = shuffleDep != None
val numPartitions = rdd.splits.size
- val outputLocs = Array.fill[List[BlockManagerId]](numPartitions)(Nil)
+ val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil)
var numAvailableOutputs = 0
private var nextAttemptId = 0
def isAvailable: Boolean = {
- if (/*parents.size == 0 &&*/ !isShuffleMap) {
+ if (!isShuffleMap) {
true
} else {
numAvailableOutputs == numPartitions
}
}
- def addOutputLoc(partition: Int, bmAddress: BlockManagerId) {
+ def addOutputLoc(partition: Int, status: MapStatus) {
val prevList = outputLocs(partition)
- outputLocs(partition) = bmAddress :: prevList
+ outputLocs(partition) = status :: prevList
if (prevList == Nil)
numAvailableOutputs += 1
}
def removeOutputLoc(partition: Int, bmAddress: BlockManagerId) {
val prevList = outputLocs(partition)
- val newList = prevList.filterNot(_ == bmAddress)
+ val newList = prevList.filterNot(_.address == bmAddress)
outputLocs(partition) = newList
if (prevList != Nil && newList == Nil) {
numAvailableOutputs -= 1
@@ -62,7 +62,7 @@ private[spark] class Stage(
var becameUnavailable = false
for (partition <- 0 until numPartitions) {
val prevList = outputLocs(partition)
- val newList = prevList.filterNot(_.ip == host)
+ val newList = prevList.filterNot(_.address.ip == host)
outputLocs(partition) = newList
if (prevList != Nil && newList == Nil) {
becameUnavailable = true
diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
index 0043dbeb10..88cb114544 100644
--- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
@@ -25,7 +25,8 @@ private[spark] class SparkDeploySchedulerBackend(
"SPARK_MEM",
"SPARK_CLASSPATH",
"SPARK_LIBRARY_PATH",
- "SPARK_JAVA_OPTS"
+ "SPARK_JAVA_OPTS",
+ "SPARK_TESTING"
)
// Memory used by each executor (in megabytes)
diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
index 9bb88ad6a1..cf4aae03a7 100644
--- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
@@ -341,7 +341,7 @@ private[spark] class TaskSetManager(
def error(message: String) {
// Save the error message
- abort("Mesos error: " + message)
+ abort("Error: " + message)
}
def abort(message: String) {
diff --git a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala
index 9737c6b63e..e6d8b9d822 100644
--- a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala
+++ b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala
@@ -38,7 +38,8 @@ private[spark] class CoarseMesosSchedulerBackend(
"SPARK_MEM",
"SPARK_CLASSPATH",
"SPARK_LIBRARY_PATH",
- "SPARK_JAVA_OPTS"
+ "SPARK_JAVA_OPTS",
+ "SPARK_TESTING"
)
val MAX_SLAVE_FAILURES = 2 // Blacklist a slave after this many failures
diff --git a/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala
index e85e4ef318..6f01c8c09d 100644
--- a/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala
+++ b/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala
@@ -34,7 +34,8 @@ private[spark] class MesosSchedulerBackend(
"SPARK_MEM",
"SPARK_CLASSPATH",
"SPARK_LIBRARY_PATH",
- "SPARK_JAVA_OPTS"
+ "SPARK_JAVA_OPTS",
+ "SPARK_TESTING"
)
// Memory used by each executor (in megabytes)
diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala
index 21a2901548..91b7bebfb3 100644
--- a/core/src/main/scala/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/spark/storage/BlockManager.scala
@@ -20,7 +20,9 @@ import sun.nio.ch.DirectBuffer
private[spark] class BlockManagerId(var ip: String, var port: Int) extends Externalizable {
- def this() = this(null, 0)
+ def this() = this(null, 0) // For deserialization only
+
+ def this(in: ObjectInput) = this(in.readUTF(), in.readInt())
override def writeExternal(out: ObjectOutput) {
out.writeUTF(ip)
@@ -61,7 +63,11 @@ private[spark]
class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, maxMemory: Long)
extends Logging {
- class BlockInfo(val level: StorageLevel, val tellMaster: Boolean, var pending: Boolean = true) {
+ class BlockInfo(val level: StorageLevel, val tellMaster: Boolean) {
+ var pending: Boolean = true
+ var size: Long = -1L
+
+ /** Wait for this BlockInfo to be marked as ready (i.e. block is finished writing) */
def waitForReady() {
if (pending) {
synchronized {
@@ -70,8 +76,10 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m
}
}
- def markReady() {
+ /** Mark this BlockInfo as ready (i.e. block is finished writing) */
+ def markReady(sizeInBytes: Long) {
pending = false
+ size = sizeInBytes
synchronized {
this.notifyAll()
}
@@ -96,8 +104,17 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m
// TODO: This will be removed after cacheTracker is removed from the code base.
var cacheTracker: CacheTracker = null
- val numParallelFetches = BlockManager.getNumParallelFetchesFromSystemProperties
- val compress = System.getProperty("spark.blockManager.compress", "false").toBoolean
+ // Max megabytes of data to keep in flight per reducer (to avoid over-allocating memory
+ // for receiving shuffle outputs)
+ val maxBytesInFlight =
+ System.getProperty("spark.reducer.maxMbInFlight", "48").toLong * 1024 * 1024
+
+ val compressBroadcast = System.getProperty("spark.broadcast.compress", "true").toBoolean
+ val compressShuffle = System.getProperty("spark.shuffle.compress", "true").toBoolean
+ // Whether to compress RDD partitions that are stored serialized
+ val compressRdds = System.getProperty("spark.rdd.compress", "false").toBoolean
+
+ val host = System.getProperty("spark.hostname", Utils.localHostName())
initialize()
@@ -183,6 +200,18 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m
*/
def getLocal(blockId: String): Option[Iterator[Any]] = {
logDebug("Getting local block " + blockId)
+
+ // As an optimization for map output fetches, if the block is for a shuffle, return it
+ // without acquiring a lock; the disk store never deletes (recent) items so this should work
+ if (blockId.startsWith("shuffle_")) {
+ return diskStore.getValues(blockId) match {
+ case Some(iterator) =>
+ Some(iterator)
+ case None =>
+ throw new Exception("Block " + blockId + " not found on disk, though it should be")
+ }
+ }
+
locker.getLock(blockId).synchronized {
val info = blockInfo.get(blockId)
if (info != null) {
@@ -208,7 +237,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m
diskStore.getValues(blockId) match {
case Some(iterator) =>
// Put the block back in memory before returning it
- memoryStore.putValues(blockId, iterator, level, true) match {
+ memoryStore.putValues(blockId, iterator, level, true).data match {
case Left(iterator2) =>
return Some(iterator2)
case _ =>
@@ -227,7 +256,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m
copyForMemory.put(bytes)
memoryStore.putBytes(blockId, copyForMemory, level)
bytes.rewind()
- return Some(dataDeserialize(bytes))
+ return Some(dataDeserialize(blockId, bytes))
case None =>
throw new Exception("Block " + blockId + " not found on disk, though it should be")
}
@@ -253,6 +282,18 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m
def getLocalBytes(blockId: String): Option[ByteBuffer] = {
// TODO: This whole thing is very similar to getLocal; we need to refactor it somehow
logDebug("Getting local block " + blockId + " as bytes")
+
+ // As an optimization for map output fetches, if the block is for a shuffle, return it
+ // without acquiring a lock; the disk store never deletes (recent) items so this should work
+ if (blockId.startsWith("shuffle_")) {
+ return diskStore.getBytes(blockId) match {
+ case Some(bytes) =>
+ Some(bytes)
+ case None =>
+ throw new Exception("Block " + blockId + " not found on disk, though it should be")
+ }
+ }
+
locker.getLock(blockId).synchronized {
val info = blockInfo.get(blockId)
if (info != null) {
@@ -318,7 +359,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m
GetBlock(blockId), ConnectionManagerId(loc.ip, loc.port))
if (data != null) {
logDebug("Data is not null: " + data)
- return Some(dataDeserialize(data))
+ return Some(dataDeserialize(blockId, data))
}
logDebug("Data is null")
}
@@ -336,9 +377,10 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m
/**
* Get multiple blocks from local and remote block manager using their BlockManagerIds. Returns
* an Iterator of (block ID, value) pairs so that clients may handle blocks in a pipelined
- * fashion as they're received.
+ * fashion as they're received. Expects a size in bytes to be provided for each block fetched,
+   * so that we can enforce the maxBytesInFlight limit during the fetch.
*/
- def getMultiple(blocksByAddress: Seq[(BlockManagerId, Seq[String])])
+ def getMultiple(blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])])
: Iterator[(String, Option[Iterator[Any]])] = {
if (blocksByAddress == null) {
@@ -350,17 +392,37 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m
val localBlockIds = new ArrayBuffer[String]()
val remoteBlockIds = new HashSet[String]()
- // A queue to hold our results. Because we want all the deserializing the happen in the
- // caller's thread, this will actually hold functions to produce the Iterator for each block.
- // For local blocks we'll have an iterator already, while for remote ones we'll deserialize.
- val results = new LinkedBlockingQueue[(String, Option[() => Iterator[Any]])]
+ // A result of a fetch. Includes the block ID, size in bytes, and a function to deserialize
+    // the block (since we want all deserialization to happen in the calling thread); can also
+ // represent a fetch failure if size == -1.
+ class FetchResult(val blockId: String, val size: Long, val deserialize: () => Iterator[Any]) {
+ def failed: Boolean = size == -1
+ }
+
+ // A queue to hold our results.
+ val results = new LinkedBlockingQueue[FetchResult]
+
+ // A request to fetch one or more blocks, complete with their sizes
+ class FetchRequest(val address: BlockManagerId, val blocks: Seq[(String, Long)]) {
+ val size = blocks.map(_._2).sum
+ }
+
+ // Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
+ // the number of bytes in flight is limited to maxBytesInFlight
+ val fetchRequests = new Queue[FetchRequest]
- // Bound the number and memory usage of fetched remote blocks.
- val blocksToRequest = new Queue[(BlockManagerId, BlockMessage)]
+ // Current bytes in flight from our requests
+ var bytesInFlight = 0L
- def sendRequest(bmId: BlockManagerId, blockMessages: Seq[BlockMessage]) {
- val cmId = new ConnectionManagerId(bmId.ip, bmId.port)
- val blockMessageArray = new BlockMessageArray(blockMessages)
+ def sendRequest(req: FetchRequest) {
+ logDebug("Sending request for %d blocks (%s) from %s".format(
+ req.blocks.size, Utils.memoryBytesToString(req.size), req.address.ip))
+ val cmId = new ConnectionManagerId(req.address.ip, req.address.port)
+ val blockMessageArray = new BlockMessageArray(req.blocks.map {
+ case (blockId, size) => BlockMessage.fromGetBlock(GetBlock(blockId))
+ })
+ bytesInFlight += req.size
+ val sizeMap = req.blocks.toMap // so we can look up the size of each blockID
val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage)
future.onSuccess {
case Some(message) => {
@@ -372,58 +434,73 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m
"Unexpected message " + blockMessage.getType + " received from " + cmId)
}
val blockId = blockMessage.getId
- results.put((blockId, Some(() => dataDeserialize(blockMessage.getData))))
+ results.put(new FetchResult(
+ blockId, sizeMap(blockId), () => dataDeserialize(blockId, blockMessage.getData)))
logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
}
}
case None => {
logError("Could not get block(s) from " + cmId)
- for (blockMessage <- blockMessages) {
- results.put((blockMessage.getId, None))
+ for ((blockId, size) <- req.blocks) {
+ results.put(new FetchResult(blockId, -1, null))
}
}
}
}
- // Split local and remote blocks. Remote blocks are further split into ones that will
- // be requested initially and ones that will be added to a queue of blocks to request.
- val initialRequestBlocks = new HashMap[BlockManagerId, ArrayBuffer[BlockMessage]]()
- var initialRequests = 0
- val blocksToGetLater = new ArrayBuffer[(BlockManagerId, BlockMessage)]
- for ((address, blockIds) <- Utils.randomize(blocksByAddress)) {
+ // Split local and remote blocks. Remote blocks are further split into FetchRequests of size
+ // at most maxBytesInFlight in order to limit the amount of data in flight.
+ val remoteRequests = new ArrayBuffer[FetchRequest]
+ for ((address, blockInfos) <- blocksByAddress) {
if (address == blockManagerId) {
- localBlockIds ++= blockIds
+ localBlockIds ++= blockInfos.map(_._1)
} else {
- remoteBlockIds ++= blockIds
- for (blockId <- blockIds) {
- val blockMessage = BlockMessage.fromGetBlock(GetBlock(blockId))
- if (initialRequests < numParallelFetches) {
- initialRequestBlocks.getOrElseUpdate(address, new ArrayBuffer[BlockMessage])
- .append(blockMessage)
- initialRequests += 1
- } else {
- blocksToGetLater.append((address, blockMessage))
+ remoteBlockIds ++= blockInfos.map(_._1)
+ // Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them
+ // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
+ // nodes, rather than blocking on reading output from one node.
+ val minRequestSize = math.max(maxBytesInFlight / 5, 1L)
+ logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize)
+ val iterator = blockInfos.iterator
+ var curRequestSize = 0L
+ var curBlocks = new ArrayBuffer[(String, Long)]
+ while (iterator.hasNext) {
+ val (blockId, size) = iterator.next()
+ curBlocks += ((blockId, size))
+ curRequestSize += size
+ if (curRequestSize >= minRequestSize) {
+ // Add this FetchRequest
+ remoteRequests += new FetchRequest(address, curBlocks)
+ curRequestSize = 0
+ curBlocks = new ArrayBuffer[(String, Long)]
}
}
+ // Add in the final request
+ if (!curBlocks.isEmpty) {
+ remoteRequests += new FetchRequest(address, curBlocks)
+ }
}
}
- // Add the remaining blocks into a queue to pull later in a random order
- blocksToRequest ++= Utils.randomize(blocksToGetLater)
+ // Add the remote requests into our queue in a random order
+ fetchRequests ++= Utils.randomize(remoteRequests)
- // Send out initial request(s) for 'numParallelFetches' blocks.
- for ((bmId, blockMessages) <- initialRequestBlocks) {
- sendRequest(bmId, blockMessages)
+ // Send out initial requests for blocks, up to our maxBytesInFlight
+ while (!fetchRequests.isEmpty &&
+ (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
+ sendRequest(fetchRequests.dequeue())
}
- logDebug("Started remote gets for " + numParallelFetches + " blocks in " +
- Utils.getUsedTimeMs(startTime) + " ms")
+ val numGets = remoteBlockIds.size - fetchRequests.size
+ logInfo("Started " + numGets + " remote gets in " + Utils.getUsedTimeMs(startTime))
- // Get the local blocks while remote blocks are being fetched.
+ // Get the local blocks while remote blocks are being fetched. Note that it's okay to do
+ // these all at once because they will just memory-map some files, so they won't consume
+ // any memory that might exceed our maxBytesInFlight
startTime = System.currentTimeMillis
for (id <- localBlockIds) {
getLocal(id) match {
- case Some(block) => {
- results.put((id, Some(() => block)))
+ case Some(iter) => {
+ results.put(new FetchResult(id, 0, () => iter)) // Pass 0 as size since it's not in flight
logDebug("Got local block " + id)
}
case None => {
@@ -441,20 +518,23 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m
def next(): (String, Option[Iterator[Any]]) = {
resultsGotten += 1
- val (blockId, functionOption) = results.take()
- if (remoteBlockIds.contains(blockId) && !blocksToRequest.isEmpty) {
- val (bmId, blockMessage) = blocksToRequest.dequeue()
- sendRequest(bmId, Seq(blockMessage))
+ val result = results.take()
+ bytesInFlight -= result.size
+ if (!fetchRequests.isEmpty &&
+ (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
+ sendRequest(fetchRequests.dequeue())
}
- (blockId, functionOption.map(_.apply()))
+ (result.blockId, if (result.failed) None else Some(result.deserialize()))
}
}
}
/**
- * Put a new block of values to the block manager.
+ * Put a new block of values to the block manager. Returns its (estimated) size in bytes.
*/
- def put(blockId: String, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean = true) {
+ def put(blockId: String, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean = true)
+ : Long = {
+
if (blockId == null) {
throw new IllegalArgumentException("Block Id is null")
}
@@ -465,9 +545,11 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m
throw new IllegalArgumentException("Storage level is null or invalid")
}
- if (blockInfo.containsKey(blockId)) {
+ val oldBlock = blockInfo.get(blockId)
+ if (oldBlock != null) {
logWarning("Block " + blockId + " already exists on this machine; not re-adding it")
- return
+ oldBlock.waitForReady()
+ return oldBlock.size
}
// Remember the block's storage level so that we can correctly drop it to disk if it needs
@@ -477,14 +559,19 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m
blockInfo.put(blockId, myInfo)
val startTimeMs = System.currentTimeMillis
- var bytes: ByteBuffer = null
// If we need to replicate the data, we'll want access to the values, but because our
// put will read the whole iterator, there will be no values left. For the case where
- // the put serializes data, we'll remember the bytes, above; but for the case where
- // it doesn't, such as MEMORY_ONLY_DESER, let's rely on the put returning an Iterator.
+ // the put serializes data, we'll remember the bytes, above; but for the case where it
+ // doesn't, such as deserialized storage, let's rely on the put returning an Iterator.
var valuesAfterPut: Iterator[Any] = null
+ // Ditto for the bytes after the put
+ var bytesAfterPut: ByteBuffer = null
+
+ // Size of the block in bytes (to return to caller)
+ var size = 0L
+
locker.getLock(blockId).synchronized {
logDebug("Put for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs)
+ " to get into synchronized block")
@@ -492,22 +579,26 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m
if (level.useMemory) {
// Save it just to memory first, even if it also has useDisk set to true; we will later
// drop it to disk if the memory store can't hold it.
- memoryStore.putValues(blockId, values, level, true) match {
- case Right(newBytes) => bytes = newBytes
+ val res = memoryStore.putValues(blockId, values, level, true)
+ size = res.size
+ res.data match {
+ case Right(newBytes) => bytesAfterPut = newBytes
case Left(newIterator) => valuesAfterPut = newIterator
}
} else {
// Save directly to disk.
val askForBytes = level.replication > 1 // Don't get back the bytes unless we replicate them
- diskStore.putValues(blockId, values, level, askForBytes) match {
- case Right(newBytes) => bytes = newBytes
+ val res = diskStore.putValues(blockId, values, level, askForBytes)
+ size = res.size
+ res.data match {
+ case Right(newBytes) => bytesAfterPut = newBytes
case _ =>
}
}
// Now that the block is in either the memory or disk store, let other threads read it,
// and tell the master about it.
- myInfo.markReady()
+ myInfo.markReady(size)
if (tellMaster) {
reportBlockStatus(blockId)
}
@@ -517,23 +608,25 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m
// Replicate block if required
if (level.replication > 1) {
// Serialize the block if not already done
- if (bytes == null) {
+ if (bytesAfterPut == null) {
if (valuesAfterPut == null) {
throw new SparkException(
"Underlying put returned neither an Iterator nor bytes! This shouldn't happen.")
}
- bytes = dataSerialize(valuesAfterPut)
+ bytesAfterPut = dataSerialize(blockId, valuesAfterPut)
}
- replicate(blockId, bytes, level)
+ replicate(blockId, bytesAfterPut, level)
}
- BlockManager.dispose(bytes)
+ BlockManager.dispose(bytesAfterPut)
// TODO: This code will be removed when CacheTracker is gone.
if (blockId.startsWith("rdd")) {
- notifyTheCacheTracker(blockId)
+ notifyCacheTracker(blockId)
}
logDebug("Put block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs))
+
+ return size
}
@@ -592,7 +685,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m
// Now that the block is in either the memory or disk store, let other threads read it,
// and tell the master about it.
- myInfo.markReady()
+ myInfo.markReady(bytes.limit)
if (tellMaster) {
reportBlockStatus(blockId)
}
@@ -600,7 +693,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m
// TODO: This code will be removed when CacheTracker is gone.
if (blockId.startsWith("rdd")) {
- notifyTheCacheTracker(blockId)
+ notifyCacheTracker(blockId)
}
// If replication had started, then wait for it to finish
@@ -646,13 +739,12 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m
}
// TODO: This code will be removed when CacheTracker is gone.
- private def notifyTheCacheTracker(key: String) {
+ private def notifyCacheTracker(key: String) {
if (cacheTracker != null) {
val rddInfo = key.split("_")
val rddId: Int = rddInfo(1).toInt
val partition: Int = rddInfo(2).toInt
- val host = System.getProperty("spark.hostname", Utils.localHostName())
- cacheTracker.notifyTheCacheTrackerFromBlockManager(spark.AddedToCache(rddId, partition, host))
+ cacheTracker.notifyFromBlockManager(spark.AddedToCache(rddId, partition, host))
}
}
@@ -699,24 +791,36 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m
}
}
+ def shouldCompress(blockId: String): Boolean = {
+ if (blockId.startsWith("shuffle_")) {
+ compressShuffle
+ } else if (blockId.startsWith("broadcast_")) {
+ compressBroadcast
+ } else if (blockId.startsWith("rdd_")) {
+ compressRdds
+ } else {
+ false // Won't happen in a real cluster, but it can in tests
+ }
+ }
+
/**
- * Wrap an output stream for compression if block compression is enabled
+ * Wrap an output stream for compression if block compression is enabled for its block type
*/
- def wrapForCompression(s: OutputStream): OutputStream = {
- if (compress) new LZFOutputStream(s) else s
+ def wrapForCompression(blockId: String, s: OutputStream): OutputStream = {
+ if (shouldCompress(blockId)) new LZFOutputStream(s) else s
}
/**
- * Wrap an input stream for compression if block compression is enabled
+ * Wrap an input stream for compression if block compression is enabled for its block type
*/
- def wrapForCompression(s: InputStream): InputStream = {
- if (compress) new LZFInputStream(s) else s
+ def wrapForCompression(blockId: String, s: InputStream): InputStream = {
+ if (shouldCompress(blockId)) new LZFInputStream(s) else s
}
- def dataSerialize(values: Iterator[Any]): ByteBuffer = {
+ def dataSerialize(blockId: String, values: Iterator[Any]): ByteBuffer = {
val byteStream = new FastByteArrayOutputStream(4096)
val ser = serializer.newInstance()
- ser.serializeStream(wrapForCompression(byteStream)).writeAll(values).close()
+ ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()
byteStream.trim()
ByteBuffer.wrap(byteStream.array)
}
@@ -725,10 +829,10 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m
* Deserializes a ByteBuffer into an iterator of values and disposes of it when the end of
* the iterator is reached.
*/
- def dataDeserialize(bytes: ByteBuffer): Iterator[Any] = {
+ def dataDeserialize(blockId: String, bytes: ByteBuffer): Iterator[Any] = {
bytes.rewind()
- val ser = serializer.newInstance()
- ser.deserializeStream(wrapForCompression(new ByteBufferInputStream(bytes, true))).asIterator
+ val stream = wrapForCompression(blockId, new ByteBufferInputStream(bytes, true))
+ serializer.newInstance().deserializeStream(stream).asIterator
}
def stop() {
@@ -742,10 +846,6 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m
private[spark]
object BlockManager extends Logging {
- def getNumParallelFetchesFromSystemProperties: Int = {
- System.getProperty("spark.blockManager.parallelFetches", "4").toInt
- }
-
def getMaxMemoryFromSystemProperties: Long = {
val memoryFraction = System.getProperty("spark.storage.memoryFraction", "0.66").toDouble
(Runtime.getRuntime.maxMemory * memoryFraction).toLong
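
The new fetch path above caps the data in flight at maxBytesInFlight and cuts each remote node's blocks into requests of at least maxBytesInFlight / 5 bytes, so up to five nodes can be read in parallel instead of draining one node at a time. A stripped-down sketch of just that grouping step, in plain Scala with no Spark types:

    import scala.collection.mutable.ArrayBuffer

    // Group (blockId, size) pairs from one node into requests of roughly minRequestSize bytes each
    def splitIntoRequests(blocks: Seq[(String, Long)], maxBytesInFlight: Long): Seq[Seq[(String, Long)]] = {
      val minRequestSize = math.max(maxBytesInFlight / 5, 1L)
      val requests = new ArrayBuffer[Seq[(String, Long)]]
      var cur = new ArrayBuffer[(String, Long)]
      var curSize = 0L
      for ((id, size) <- blocks) {
        cur += ((id, size))
        curSize += size
        if (curSize >= minRequestSize) {      // close the request once it is big enough
          requests += cur
          cur = new ArrayBuffer[(String, Long)]
          curSize = 0L
        }
      }
      if (cur.nonEmpty) requests += cur       // final, possibly undersized request
      requests
    }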
diff --git a/core/src/main/scala/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/spark/storage/BlockManagerWorker.scala
index f72079e267..d2985559c1 100644
--- a/core/src/main/scala/spark/storage/BlockManagerWorker.scala
+++ b/core/src/main/scala/spark/storage/BlockManagerWorker.scala
@@ -31,7 +31,6 @@ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends
val blockMessages = BlockMessageArray.fromBufferMessage(bufferMessage)
logDebug("Parsed as a block message array")
val responseMessages = blockMessages.map(processBlockMessage).filter(_ != None).map(_.get)
- /*logDebug("Processed block messages")*/
return Some(new BlockMessageArray(responseMessages).toBufferMessage)
} catch {
case e: Exception => logError("Exception handling buffer message", e)
@@ -49,13 +48,13 @@ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends
blockMessage.getType match {
case BlockMessage.TYPE_PUT_BLOCK => {
val pB = PutBlock(blockMessage.getId, blockMessage.getData, blockMessage.getLevel)
- logInfo("Received [" + pB + "]")
+ logDebug("Received [" + pB + "]")
putBlock(pB.id, pB.data, pB.level)
return None
}
case BlockMessage.TYPE_GET_BLOCK => {
val gB = new GetBlock(blockMessage.getId)
- logInfo("Received [" + gB + "]")
+ logDebug("Received [" + gB + "]")
val buffer = getBlock(gB.id)
if (buffer == null) {
return None
diff --git a/core/src/main/scala/spark/storage/BlockStore.scala b/core/src/main/scala/spark/storage/BlockStore.scala
index ff482ff66b..1286600cd1 100644
--- a/core/src/main/scala/spark/storage/BlockStore.scala
+++ b/core/src/main/scala/spark/storage/BlockStore.scala
@@ -15,13 +15,14 @@ abstract class BlockStore(val blockManager: BlockManager) extends Logging {
* Put in a block and, possibly, also return its content as either bytes or another Iterator.
* This is used to efficiently write the values to multiple locations (e.g. for replication).
*
- * @return the values put if returnValues is true, or null otherwise
+ * @return a PutResult that contains the size of the data, as well as the values put if
+ * returnValues is true (if not, the result's data field can be null)
*/
def putValues(blockId: String, values: Iterator[Any], level: StorageLevel, returnValues: Boolean)
- : Either[Iterator[Any], ByteBuffer]
+ : PutResult
/**
- * Return the size of a block.
+ * Return the size of a block in bytes.
*/
def getSize(blockId: String): Long
diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala
index d0c592ccb1..fd92a3dc67 100644
--- a/core/src/main/scala/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/spark/storage/DiskStore.scala
@@ -48,28 +48,28 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
values: Iterator[Any],
level: StorageLevel,
returnValues: Boolean)
- : Either[Iterator[Any], ByteBuffer] = {
+ : PutResult = {
logDebug("Attempting to write values for block " + blockId)
val startTime = System.currentTimeMillis
val file = createFile(blockId)
- val fileOut = blockManager.wrapForCompression(
+ val fileOut = blockManager.wrapForCompression(blockId,
new FastBufferedOutputStream(new FileOutputStream(file)))
val objOut = blockManager.serializer.newInstance().serializeStream(fileOut)
objOut.writeAll(values)
objOut.close()
- val finishTime = System.currentTimeMillis
+ val length = file.length()
logDebug("Block %s stored as %s file on disk in %d ms".format(
- blockId, Utils.memoryBytesToString(file.length()), (finishTime - startTime)))
+ blockId, Utils.memoryBytesToString(length), (System.currentTimeMillis - startTime)))
if (returnValues) {
// Return a byte buffer for the contents of the file
val channel = new RandomAccessFile(file, "r").getChannel()
- val buffer = channel.map(MapMode.READ_ONLY, 0, channel.size())
+ val buffer = channel.map(MapMode.READ_ONLY, 0, length)
channel.close()
- Right(buffer)
+ PutResult(length, Right(buffer))
} else {
- null
+ PutResult(length, null)
}
}
@@ -83,7 +83,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
}
override def getValues(blockId: String): Option[Iterator[Any]] = {
- getBytes(blockId).map(blockManager.dataDeserialize(_))
+ getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes))
}
override def remove(blockId: String) {
diff --git a/core/src/main/scala/spark/storage/MemoryStore.scala b/core/src/main/scala/spark/storage/MemoryStore.scala
index 74ef326038..e9288fdf43 100644
--- a/core/src/main/scala/spark/storage/MemoryStore.scala
+++ b/core/src/main/scala/spark/storage/MemoryStore.scala
@@ -31,7 +31,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
override def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) {
if (level.deserialized) {
bytes.rewind()
- val values = blockManager.dataDeserialize(bytes)
+ val values = blockManager.dataDeserialize(blockId, bytes)
val elements = new ArrayBuffer[Any]
elements ++= values
val sizeEstimate = SizeEstimator.estimate(elements.asInstanceOf[AnyRef])
@@ -49,18 +49,18 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
values: Iterator[Any],
level: StorageLevel,
returnValues: Boolean)
- : Either[Iterator[Any], ByteBuffer] = {
+ : PutResult = {
if (level.deserialized) {
val elements = new ArrayBuffer[Any]
elements ++= values
val sizeEstimate = SizeEstimator.estimate(elements.asInstanceOf[AnyRef])
tryToPut(blockId, elements, sizeEstimate, true)
- Left(elements.iterator)
+ PutResult(sizeEstimate, Left(elements.iterator))
} else {
- val bytes = blockManager.dataSerialize(values)
+ val bytes = blockManager.dataSerialize(blockId, values)
tryToPut(blockId, bytes, bytes.limit, false)
- Right(bytes)
+ PutResult(bytes.limit(), Right(bytes))
}
}
@@ -71,7 +71,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
if (entry == null) {
None
} else if (entry.deserialized) {
- Some(blockManager.dataSerialize(entry.value.asInstanceOf[ArrayBuffer[Any]].iterator))
+ Some(blockManager.dataSerialize(blockId, entry.value.asInstanceOf[ArrayBuffer[Any]].iterator))
} else {
Some(entry.value.asInstanceOf[ByteBuffer].duplicate()) // Doesn't actually copy the data
}
@@ -87,7 +87,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
Some(entry.value.asInstanceOf[ArrayBuffer[Any]].iterator)
} else {
val buffer = entry.value.asInstanceOf[ByteBuffer].duplicate() // Doesn't actually copy data
- Some(blockManager.dataDeserialize(buffer))
+ Some(blockManager.dataDeserialize(blockId, buffer))
}
}
@@ -162,7 +162,8 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
* block from the same RDD (which leads to a wasteful cyclic replacement pattern for RDDs that
* don't fit into memory that we want to avoid).
*
- * Assumes that a lock on entries is held by the caller.
+ * Assumes that a lock on the MemoryStore is held by the caller. (Otherwise, the freed space
+ * might fill up before the caller puts in their new value.)
*/
private def ensureFreeSpace(blockIdToAdd: String, space: Long): Boolean = {
logInfo("ensureFreeSpace(%d) called with curMem=%d, maxMem=%d".format(
@@ -172,7 +173,9 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
logInfo("Will not store " + blockIdToAdd + " as it is larger than our memory limit")
return false
}
-
+
+ // TODO: This should relinquish the lock on the MemoryStore while flushing out old blocks
+ // in order to allow parallelism in writing to disk
if (maxMemory - currentMemory < space) {
val rddToAdd = getRddId(blockIdToAdd)
val selectedBlocks = new ArrayBuffer[String]()
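
The clarified comment on ensureFreeSpace is about a check-then-act race: eviction and the subsequent put must happen under the same lock, or another thread can claim the freed space in between. A minimal sketch of that pattern, using placeholder types rather than the MemoryStore internals:

    import scala.collection.mutable.LinkedHashMap

    class BoundedStore(maxBytes: Long) {
      private val entries = LinkedHashMap[String, Array[Byte]]()
      private var currentBytes = 0L

      // Evict oldest entries and insert atomically, so freed space cannot be stolen in between
      def putIfFits(id: String, data: Array[Byte]): Boolean = synchronized {
        if (data.length > maxBytes) {
          false
        } else {
          while (currentBytes + data.length > maxBytes) {
            val (oldId, oldData) = entries.head
            entries -= oldId
            currentBytes -= oldData.length
          }
          entries(id) = data
          currentBytes += data.length
          true
        }
      }
    }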
diff --git a/core/src/main/scala/spark/storage/PutResult.scala b/core/src/main/scala/spark/storage/PutResult.scala
new file mode 100644
index 0000000000..76f236057b
--- /dev/null
+++ b/core/src/main/scala/spark/storage/PutResult.scala
@@ -0,0 +1,9 @@
+package spark.storage
+
+import java.nio.ByteBuffer
+
+/**
+ * Result of adding a block into a BlockStore. Contains its estimated size, and possibly the
+ * values put if the caller asked for them to be returned (e.g. for chaining replication)
+ */
+private[spark] case class PutResult(size: Long, data: Either[Iterator[_], ByteBuffer])
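
Callers of putValues now get back the size plus, optionally, the data in one of two shapes; a hedged sketch of the typical consumer, which pattern-matches on the data field and tolerates null since stores may skip returning values (written as if it lived inside spark.storage, because PutResult is private[spark]):

    def describe(res: PutResult): String = res.data match {
      case Left(_)      => "stored %d bytes; values handed back as an iterator".format(res.size)
      case Right(bytes) => "stored %d bytes; serialized form is %d bytes".format(res.size, bytes.limit())
      case null         => "stored %d bytes; nothing returned to the caller".format(res.size)
    }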
diff --git a/core/src/main/scala/spark/storage/StorageLevel.scala b/core/src/main/scala/spark/storage/StorageLevel.scala
index 2237ce92b3..2d52fac1ef 100644
--- a/core/src/main/scala/spark/storage/StorageLevel.scala
+++ b/core/src/main/scala/spark/storage/StorageLevel.scala
@@ -2,7 +2,7 @@ package spark.storage
import java.io.{Externalizable, ObjectInput, ObjectOutput}
-private[spark] class StorageLevel(
+class StorageLevel(
var useDisk: Boolean,
var useMemory: Boolean,
var deserialized: Boolean,
@@ -63,7 +63,7 @@ private[spark] class StorageLevel(
"StorageLevel(%b, %b, %b, %d)".format(useDisk, useMemory, deserialized, replication)
}
-private[spark] object StorageLevel {
+object StorageLevel {
val NONE = new StorageLevel(false, false, false)
val DISK_ONLY = new StorageLevel(true, false, false)
val DISK_ONLY_2 = new StorageLevel(true, false, false, 2)
diff --git a/core/src/main/scala/spark/util/AkkaUtils.scala b/core/src/main/scala/spark/util/AkkaUtils.scala
index f670ccb709..b466b5239c 100644
--- a/core/src/main/scala/spark/util/AkkaUtils.scala
+++ b/core/src/main/scala/spark/util/AkkaUtils.scala
@@ -23,6 +23,8 @@ private[spark] object AkkaUtils {
* ActorSystem itself and its port (which is hard to get from Akka).
*/
def createActorSystem(name: String, host: String, port: Int): (ActorSystem, Int) = {
+ val akkaThreads = System.getProperty("spark.akka.threads", "4").toInt
+ val akkaBatchSize = System.getProperty("spark.akka.batchSize", "15").toInt
val akkaConf = ConfigFactory.parseString("""
akka.daemonic = on
akka.event-handlers = ["akka.event.slf4j.Slf4jEventHandler"]
@@ -31,9 +33,9 @@ private[spark] object AkkaUtils {
akka.remote.netty.hostname = "%s"
akka.remote.netty.port = %d
akka.remote.netty.connection-timeout = 1s
- akka.remote.netty.execution-pool-size = 8
- akka.actor.default-dispatcher.throughput = 30
- """.format(host, port))
+ akka.remote.netty.execution-pool-size = %d
+ akka.actor.default-dispatcher.throughput = %d
+ """.format(host, port, akkaThreads, akkaBatchSize))
val actorSystem = ActorSystem("spark", akkaConf, getClass.getClassLoader)