Diffstat (limited to 'core/src/main')
-rw-r--r--  core/src/main/scala/spark/Accumulators.scala | 88
-rw-r--r--  core/src/main/scala/spark/Aggregator.scala | 43
-rw-r--r--  core/src/main/scala/spark/BlockStoreShuffleFetcher.scala | 45
-rw-r--r--  core/src/main/scala/spark/BoundedMemoryCache.scala | 6
-rw-r--r--  core/src/main/scala/spark/Cache.scala | 10
-rw-r--r--  core/src/main/scala/spark/CacheTracker.scala | 57
-rw-r--r--  core/src/main/scala/spark/ClosureCleaner.scala | 6
-rw-r--r--  core/src/main/scala/spark/Dependency.scala | 41
-rw-r--r--  core/src/main/scala/spark/DoubleRDDFunctions.scala | 18
-rw-r--r--  core/src/main/scala/spark/FetchFailedException.scala | 2
-rw-r--r--  core/src/main/scala/spark/HadoopWriter.scala | 17
-rw-r--r--  core/src/main/scala/spark/HttpFileServer.scala | 47
-rw-r--r--  core/src/main/scala/spark/HttpServer.scala | 4
-rw-r--r--  core/src/main/scala/spark/JavaSerializer.scala | 12
-rw-r--r--  core/src/main/scala/spark/KryoSerializer.scala | 31
-rw-r--r--  core/src/main/scala/spark/Logging.scala | 24
-rw-r--r--  core/src/main/scala/spark/MapOutputTracker.scala | 158
-rw-r--r--  core/src/main/scala/spark/PairRDDFunctions.scala | 328
-rw-r--r--  core/src/main/scala/spark/ParallelCollection.scala | 4
-rw-r--r--  core/src/main/scala/spark/Partitioner.scala | 15
-rw-r--r--  core/src/main/scala/spark/RDD.scala | 254
-rw-r--r--  core/src/main/scala/spark/SequenceFileRDDFunctions.scala | 13
-rw-r--r--  core/src/main/scala/spark/ShuffleFetcher.scala | 12
-rw-r--r--  core/src/main/scala/spark/ShuffleManager.scala | 98
-rw-r--r--  core/src/main/scala/spark/ShuffledRDD.scala | 51
-rw-r--r--  core/src/main/scala/spark/SizeEstimator.scala | 28
-rw-r--r--  core/src/main/scala/spark/SoftReferenceCache.scala | 2
-rw-r--r--  core/src/main/scala/spark/SparkContext.scala | 337
-rw-r--r--  core/src/main/scala/spark/SparkEnv.scala | 86
-rw-r--r--  core/src/main/scala/spark/TaskEndReason.scala | 14
-rw-r--r--  core/src/main/scala/spark/TaskState.scala | 2
-rw-r--r--  core/src/main/scala/spark/Utils.scala | 139
-rw-r--r--  core/src/main/scala/spark/api/java/JavaDoubleRDD.scala | 34
-rw-r--r--  core/src/main/scala/spark/api/java/JavaPairRDD.scala | 243
-rw-r--r--  core/src/main/scala/spark/api/java/JavaRDD.scala | 25
-rw-r--r--  core/src/main/scala/spark/api/java/JavaRDDLike.scala | 128
-rw-r--r--  core/src/main/scala/spark/api/java/JavaSparkContext.scala | 117
-rw-r--r--  core/src/main/scala/spark/api/java/StorageLevels.java | 20
-rw-r--r--  core/src/main/scala/spark/api/java/function/DoubleFlatMapFunction.java | 3
-rw-r--r--  core/src/main/scala/spark/api/java/function/DoubleFunction.java | 3
-rw-r--r--  core/src/main/scala/spark/api/java/function/FlatMapFunction.scala | 3
-rw-r--r--  core/src/main/scala/spark/api/java/function/Function.java | 5
-rw-r--r--  core/src/main/scala/spark/api/java/function/Function2.java | 3
-rw-r--r--  core/src/main/scala/spark/api/java/function/PairFlatMapFunction.java | 4
-rw-r--r--  core/src/main/scala/spark/api/java/function/PairFunction.java | 3
-rw-r--r--  core/src/main/scala/spark/api/java/function/VoidFunction.scala | 3
-rw-r--r--  core/src/main/scala/spark/api/java/function/WrappedFunction1.scala | 2
-rw-r--r--  core/src/main/scala/spark/api/java/function/WrappedFunction2.scala | 2
-rw-r--r--  core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala | 79
-rw-r--r--  core/src/main/scala/spark/broadcast/Broadcast.scala | 23
-rw-r--r--  core/src/main/scala/spark/broadcast/BroadcastFactory.scala | 4
-rw-r--r--  core/src/main/scala/spark/broadcast/HttpBroadcast.scala | 52
-rw-r--r--  core/src/main/scala/spark/broadcast/MultiTracker.scala | 53
-rw-r--r--  core/src/main/scala/spark/broadcast/SourceInfo.scala | 9
-rw-r--r--  core/src/main/scala/spark/broadcast/TreeBroadcast.scala | 75
-rw-r--r--  core/src/main/scala/spark/deploy/Command.scala | 2
-rw-r--r--  core/src/main/scala/spark/deploy/DeployMessage.scala | 29
-rw-r--r--  core/src/main/scala/spark/deploy/ExecutorState.scala | 2
-rw-r--r--  core/src/main/scala/spark/deploy/JobDescription.scala | 2
-rw-r--r--  core/src/main/scala/spark/deploy/LocalSparkCluster.scala | 58
-rw-r--r--  core/src/main/scala/spark/deploy/client/Client.scala | 14
-rw-r--r--  core/src/main/scala/spark/deploy/client/ClientListener.scala | 2
-rw-r--r--  core/src/main/scala/spark/deploy/client/TestClient.scala | 2
-rw-r--r--  core/src/main/scala/spark/deploy/client/TestExecutor.scala | 2
-rw-r--r--  core/src/main/scala/spark/deploy/master/ExecutorInfo.scala | 2
-rw-r--r--  core/src/main/scala/spark/deploy/master/JobInfo.scala | 10
-rw-r--r--  core/src/main/scala/spark/deploy/master/JobState.scala | 4
-rw-r--r--  core/src/main/scala/spark/deploy/master/Master.scala | 41
-rw-r--r--  core/src/main/scala/spark/deploy/master/MasterArguments.scala | 4
-rw-r--r--  core/src/main/scala/spark/deploy/master/MasterWebUI.scala | 5
-rw-r--r--  core/src/main/scala/spark/deploy/master/WorkerInfo.scala | 2
-rw-r--r--  core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala | 62
-rw-r--r--  core/src/main/scala/spark/deploy/worker/Worker.scala | 24
-rw-r--r--  core/src/main/scala/spark/deploy/worker/WorkerArguments.scala | 13
-rw-r--r--  core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala | 3
-rw-r--r--  core/src/main/scala/spark/executor/Executor.scala | 84
-rw-r--r--  core/src/main/scala/spark/executor/ExecutorBackend.scala | 2
-rw-r--r--  core/src/main/scala/spark/executor/ExecutorURLClassLoader.scala | 14
-rw-r--r--  core/src/main/scala/spark/executor/MesosExecutorBackend.scala | 4
-rw-r--r--  core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala | 4
-rw-r--r--  core/src/main/scala/spark/network/Connection.scala | 32
-rw-r--r--  core/src/main/scala/spark/network/ConnectionManager.scala | 19
-rw-r--r--  core/src/main/scala/spark/network/ConnectionManagerTest.scala | 2
-rw-r--r--  core/src/main/scala/spark/network/Message.scala | 19
-rw-r--r--  core/src/main/scala/spark/network/ReceiverTest.scala | 2
-rw-r--r--  core/src/main/scala/spark/network/SenderTest.scala | 2
-rw-r--r--  core/src/main/scala/spark/package.scala | 15
-rw-r--r--  core/src/main/scala/spark/partial/ApproximateActionListener.scala | 2
-rw-r--r--  core/src/main/scala/spark/partial/ApproximateEvaluator.scala | 2
-rw-r--r--  core/src/main/scala/spark/partial/CountEvaluator.scala | 2
-rw-r--r--  core/src/main/scala/spark/partial/GroupedCountEvaluator.scala | 2
-rw-r--r--  core/src/main/scala/spark/partial/GroupedMeanEvaluator.scala | 2
-rw-r--r--  core/src/main/scala/spark/partial/GroupedSumEvaluator.scala | 2
-rw-r--r--  core/src/main/scala/spark/partial/MeanEvaluator.scala | 2
-rw-r--r--  core/src/main/scala/spark/partial/StudentTCacher.scala | 2
-rw-r--r--  core/src/main/scala/spark/partial/SumEvaluator.scala | 2
-rw-r--r--  core/src/main/scala/spark/rdd/BlockRDD.scala (renamed from core/src/main/scala/spark/BlockRDD.scala) | 12
-rw-r--r--  core/src/main/scala/spark/rdd/CartesianRDD.scala (renamed from core/src/main/scala/spark/CartesianRDD.scala) | 11
-rw-r--r--  core/src/main/scala/spark/rdd/CoGroupedRDD.scala (renamed from core/src/main/scala/spark/CoGroupedRDD.scala) | 44
-rw-r--r--  core/src/main/scala/spark/rdd/CoalescedRDD.scala | 47
-rw-r--r--  core/src/main/scala/spark/rdd/FilteredRDD.scala | 12
-rw-r--r--  core/src/main/scala/spark/rdd/FlatMappedRDD.scala | 16
-rw-r--r--  core/src/main/scala/spark/rdd/GlommedRDD.scala | 12
-rw-r--r--  core/src/main/scala/spark/rdd/HadoopRDD.scala (renamed from core/src/main/scala/spark/HadoopRDD.scala) | 10
-rw-r--r--  core/src/main/scala/spark/rdd/MapPartitionsRDD.scala | 19
-rw-r--r--  core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala | 21
-rw-r--r--  core/src/main/scala/spark/rdd/MappedRDD.scala | 16
-rw-r--r--  core/src/main/scala/spark/rdd/NewHadoopRDD.scala (renamed from core/src/main/scala/spark/NewHadoopRDD.scala) | 24
-rw-r--r--  core/src/main/scala/spark/rdd/PipedRDD.scala (renamed from core/src/main/scala/spark/PipedRDD.scala) | 8
-rw-r--r--  core/src/main/scala/spark/rdd/SampledRDD.scala (renamed from core/src/main/scala/spark/SampledRDD.scala) | 31
-rw-r--r--  core/src/main/scala/spark/rdd/ShuffledRDD.scala | 40
-rw-r--r--  core/src/main/scala/spark/rdd/UnionRDD.scala (renamed from core/src/main/scala/spark/UnionRDD.scala) | 12
-rw-r--r--  core/src/main/scala/spark/scheduler/ActiveJob.scala | 3
-rw-r--r--  core/src/main/scala/spark/scheduler/DAGScheduler.scala | 127
-rw-r--r--  core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala | 13
-rw-r--r--  core/src/main/scala/spark/scheduler/JobListener.scala | 2
-rw-r--r--  core/src/main/scala/spark/scheduler/JobResult.scala | 6
-rw-r--r--  core/src/main/scala/spark/scheduler/JobWaiter.scala | 2
-rw-r--r--  core/src/main/scala/spark/scheduler/MapStatus.scala | 27
-rw-r--r--  core/src/main/scala/spark/scheduler/ResultTask.scala | 2
-rw-r--r--  core/src/main/scala/spark/scheduler/ShuffleMapTask.scala | 90
-rw-r--r--  core/src/main/scala/spark/scheduler/Stage.scala | 18
-rw-r--r--  core/src/main/scala/spark/scheduler/Task.scala | 86
-rw-r--r--  core/src/main/scala/spark/scheduler/TaskResult.scala | 1
-rw-r--r--  core/src/main/scala/spark/scheduler/TaskScheduler.scala | 2
-rw-r--r--  core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala | 5
-rw-r--r--  core/src/main/scala/spark/scheduler/TaskSet.scala | 4
-rw-r--r--  core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala | 30
-rw-r--r--  core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala | 2
-rw-r--r--  core/src/main/scala/spark/scheduler/cluster/SlaveResources.scala | 1
-rw-r--r--  core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala | 22
-rw-r--r--  core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala | 14
-rw-r--r--  core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala | 5
-rw-r--r--  core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala | 2
-rw-r--r--  core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala | 1
-rw-r--r--  core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala | 13
-rw-r--r--  core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala | 1
-rw-r--r--  core/src/main/scala/spark/scheduler/local/LocalScheduler.scala | 63
-rw-r--r--  core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala | 22
-rw-r--r--  core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala | 22
-rw-r--r--  core/src/main/scala/spark/serializer/Serializer.scala (renamed from core/src/main/scala/spark/Serializer.scala) | 22
-rw-r--r--  core/src/main/scala/spark/storage/BlockManager.scala | 729
-rw-r--r--  core/src/main/scala/spark/storage/BlockManagerMaster.scala | 250
-rw-r--r--  core/src/main/scala/spark/storage/BlockManagerWorker.scala | 44
-rw-r--r--  core/src/main/scala/spark/storage/BlockMessage.scala | 14
-rw-r--r--  core/src/main/scala/spark/storage/BlockMessageArray.scala | 8
-rw-r--r--  core/src/main/scala/spark/storage/BlockStore.scala | 313
-rw-r--r--  core/src/main/scala/spark/storage/DiskStore.scala | 180
-rw-r--r--  core/src/main/scala/spark/storage/MemoryStore.scala | 218
-rw-r--r--  core/src/main/scala/spark/storage/PutResult.scala | 9
-rw-r--r--  core/src/main/scala/spark/storage/StorageLevel.scala | 25
-rw-r--r--  core/src/main/scala/spark/util/AkkaUtils.scala | 10
-rw-r--r--  core/src/main/scala/spark/util/ByteBufferInputStream.scala | 41
-rw-r--r--  core/src/main/scala/spark/util/IntParam.scala | 2
-rw-r--r--  core/src/main/scala/spark/util/MemoryParam.scala | 2
-rw-r--r--  core/src/main/scala/spark/util/SerializableBuffer.scala | 1
-rw-r--r--  core/src/main/scala/spark/util/StatCounter.scala | 23
-rw-r--r--  core/src/main/twirl/spark/deploy/common/layout.scala.html (renamed from core/src/main/twirl/common/layout.scala.html) | 0
-rw-r--r--  core/src/main/twirl/spark/deploy/master/executor_row.scala.html (renamed from core/src/main/twirl/masterui/executor_row.scala.html) | 0
-rw-r--r--  core/src/main/twirl/spark/deploy/master/executors_table.scala.html (renamed from core/src/main/twirl/masterui/executors_table.scala.html) | 0
-rw-r--r--  core/src/main/twirl/spark/deploy/master/index.scala.html (renamed from core/src/main/twirl/masterui/index.scala.html) | 6
-rw-r--r--  core/src/main/twirl/spark/deploy/master/job_details.scala.html (renamed from core/src/main/twirl/masterui/job_details.scala.html) | 4
-rw-r--r--  core/src/main/twirl/spark/deploy/master/job_row.scala.html (renamed from core/src/main/twirl/masterui/job_row.scala.html) | 0
-rw-r--r--  core/src/main/twirl/spark/deploy/master/job_table.scala.html (renamed from core/src/main/twirl/masterui/job_table.scala.html) | 0
-rw-r--r--  core/src/main/twirl/spark/deploy/master/worker_row.scala.html (renamed from core/src/main/twirl/masterui/worker_row.scala.html) | 7
-rw-r--r--  core/src/main/twirl/spark/deploy/master/worker_table.scala.html (renamed from core/src/main/twirl/masterui/worker_table.scala.html) | 0
-rw-r--r--  core/src/main/twirl/spark/deploy/worker/executor_row.scala.html (renamed from core/src/main/twirl/workerui/executor_row.scala.html) | 0
-rw-r--r--  core/src/main/twirl/spark/deploy/worker/executors_table.scala.html (renamed from core/src/main/twirl/workerui/executors_table.scala.html) | 0
-rw-r--r--  core/src/main/twirl/spark/deploy/worker/index.scala.html (renamed from core/src/main/twirl/workerui/index.scala.html) | 7
169 files changed, 4227 insertions, 1994 deletions
diff --git a/core/src/main/scala/spark/Accumulators.scala b/core/src/main/scala/spark/Accumulators.scala
index d764ffc29d..bacd0ace37 100644
--- a/core/src/main/scala/spark/Accumulators.scala
+++ b/core/src/main/scala/spark/Accumulators.scala
@@ -3,18 +3,20 @@ package spark
import java.io._
import scala.collection.mutable.Map
+import scala.collection.generic.Growable
/**
- * A datatype that can be accumulated, i.e. has an commutative and associative +.
+ * A datatype that can be accumulated, i.e. has a commutative and associative "add" operation,
+ * but where the result type, `R`, may be different from the element type being added, `T`.
*
- * You must define how to add data, and how to merge two of these together. For some datatypes, these might be
- * the same operation (eg., a counter). In that case, you might want to use [[spark.AccumulatorParam]]. They won't
- * always be the same, though -- eg., imagine you are accumulating a set. You will add items to the set, and you
- * will union two sets together.
+ * You must define how to add data, and how to merge two of these together. For some datatypes,
+ * such as a counter, these might be the same operation. In that case, you can use the simpler
+ * [[spark.Accumulator]]. They won't always be the same, though -- e.g., imagine you are
+ * accumulating a set. You will add items to the set, and you will union two sets together.
*
* @param initialValue initial value of accumulator
- * @param param helper object defining how to add elements of type `T`
- * @tparam R the full accumulated data
+ * @param param helper object defining how to add elements of type `R` and `T`
+ * @tparam R the full accumulated data (result type)
* @tparam T partial data that can be added in
*/
class Accumulable[R, T] (
@@ -43,13 +45,29 @@ class Accumulable[R, T] (
* @param term the other Accumulable that will get merged with this
*/
def ++= (term: R) { value_ = param.addInPlace(value_, term)}
+
+ /**
+ * Access the accumulator's current value; only allowed on master.
+ */
def value = {
if (!deserialized) value_
else throw new UnsupportedOperationException("Can't read accumulator value in task")
}
- private[spark] def localValue = value_
+ /**
+ * Get the current value of this accumulator from within a task.
+ *
+ * This is NOT the global value of the accumulator. To get the global value after a
+ * completed operation on the dataset, call `value`.
+ *
+ * The typical use of this method is to directly mutate the local value, e.g., to add
+ * an element to a Set.
+ */
+ def localValue = value_
+ /**
+ * Set the accumulator's value; only allowed on master.
+ */
def value_= (r: R) {
if (!deserialized) value_ = r
else throw new UnsupportedOperationException("Can't assign accumulator value in task")
@@ -67,31 +85,64 @@ class Accumulable[R, T] (
}
/**
- * Helper object defining how to accumulate values of a particular type.
+ * Helper object defining how to accumulate values of a particular type. An implicit
+ * AccumulableParam needs to be available when you create Accumulables of a specific type.
*
- * @tparam R the full accumulated data
+ * @tparam R the full accumulated data (result type)
* @tparam T partial data that can be added in
*/
trait AccumulableParam[R, T] extends Serializable {
/**
- * Add additional data to the accumulator value.
+ * Add additional data to the accumulator value. Is allowed to modify and return `r`
+ * for efficiency (to avoid allocating objects).
+ *
* @param r the current value of the accumulator
* @param t the data to be added to the accumulator
* @return the new value of the accumulator
*/
- def addAccumulator(r: R, t: T) : R
+ def addAccumulator(r: R, t: T): R
/**
- * Merge two accumulated values together
+ * Merge two accumulated values together. Is allowed to modify and return the first value
+ * for efficiency (to avoid allocating objects).
+ *
* @param r1 one set of accumulated data
* @param r2 another set of accumulated data
* @return both data sets merged together
*/
def addInPlace(r1: R, r2: R): R
+ /**
+ * Return the "zero" (identity) value for an accumulator type, given its initial value. For
+ * example, if R was a vector of N dimensions, this would return a vector of N zeroes.
+ */
def zero(initialValue: R): R
}
+private[spark]
+class GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializable, T]
+ extends AccumulableParam[R,T] {
+
+ def addAccumulator(growable: R, elem: T): R = {
+ growable += elem
+ growable
+ }
+
+ def addInPlace(t1: R, t2: R): R = {
+ t1 ++= t2
+ t1
+ }
+
+ def zero(initialValue: R): R = {
+ // We need to clone initialValue, but it's hard to specify that R should also be Cloneable.
+ // Instead we'll serialize it to a buffer and load it back.
+ val ser = (new spark.JavaSerializer).newInstance()
+ val copy = ser.deserialize[R](ser.serialize(initialValue))
+ copy.clear() // In case it contained stuff
+ copy
+ }
+}
+
/**
* A simpler value of [[spark.Accumulable]] where the result type being accumulated is the same
* as the types of elements being merged.
@@ -100,17 +151,18 @@ trait AccumulableParam[R, T] extends Serializable {
* @param param helper object defining how to add elements of type `T`
* @tparam T result type
*/
-class Accumulator[T](
- @transient initialValue: T,
- param: AccumulatorParam[T]) extends Accumulable[T,T](initialValue, param)
+class Accumulator[T](@transient initialValue: T, param: AccumulatorParam[T])
+ extends Accumulable[T,T](initialValue, param)
/**
* A simpler version of [[spark.AccumulableParam]] where the only datatype you can add in is the same type
- * as the accumulated value
+ * as the accumulated value. An implicit AccumulatorParam object needs to be available when you create
+ * Accumulators of a specific type.
+ *
* @tparam T type of value to accumulate
*/
trait AccumulatorParam[T] extends AccumulableParam[T, T] {
- def addAccumulator(t1: T, t2: T) : T = {
+ def addAccumulator(t1: T, t2: T): T = {
addInPlace(t1, t2)
}
}
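A minimal usage sketch of the accumulator API documented above (illustrative only, not part of the patch). It assumes a local SparkContext and the implicit AccumulatorParam[Int] pulled in by `spark.SparkContext._`; as the comments state, tasks may only add to the accumulator, and only the master may read `value`.

    import spark.SparkContext
    import spark.SparkContext._

    val sc = new SparkContext("local", "AccumulatorSketch")
    val sum = sc.accumulator(0)                        // Accumulator[Int] via an implicit AccumulatorParam
    sc.parallelize(1 to 100).foreach(x => sum += x)    // tasks can only add with +=
    println(sum.value)                                 // reading value is only allowed on the master
    sc.stop()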
diff --git a/core/src/main/scala/spark/Aggregator.scala b/core/src/main/scala/spark/Aggregator.scala
index 6f99270b1e..df8ce9c054 100644
--- a/core/src/main/scala/spark/Aggregator.scala
+++ b/core/src/main/scala/spark/Aggregator.scala
@@ -1,7 +1,44 @@
package spark
-class Aggregator[K, V, C] (
+import java.util.{HashMap => JHashMap}
+
+import scala.collection.JavaConversions._
+
+/** A set of functions used to aggregate data.
+ *
+ * @param createCombiner function to create the initial value of the aggregation.
+ * @param mergeValue function to merge a new value into the aggregation result.
+ * @param mergeCombiners function to merge outputs from multiple mergeValue functions.
+ */
+case class Aggregator[K, V, C] (
val createCombiner: V => C,
val mergeValue: (C, V) => C,
- val mergeCombiners: (C, C) => C)
- extends Serializable
+ val mergeCombiners: (C, C) => C) {
+
+ def combineValuesByKey(iter: Iterator[(K, V)]) : Iterator[(K, C)] = {
+ val combiners = new JHashMap[K, C]
+ for ((k, v) <- iter) {
+ val oldC = combiners.get(k)
+ if (oldC == null) {
+ combiners.put(k, createCombiner(v))
+ } else {
+ combiners.put(k, mergeValue(oldC, v))
+ }
+ }
+ combiners.iterator
+ }
+
+ def combineCombinersByKey(iter: Iterator[(K, C)]) : Iterator[(K, C)] = {
+ val combiners = new JHashMap[K, C]
+ for ((k, c) <- iter) {
+ val oldC = combiners.get(k)
+ if (oldC == null) {
+ combiners.put(k, c)
+ } else {
+ combiners.put(k, mergeCombiners(oldC, c))
+ }
+ }
+ combiners.iterator
+ }
+}
+
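To make the new Aggregator contract concrete, here is a small standalone sketch (not part of the patch; the keys and values are made up) that builds per-key (sum, count) pairs with combineValuesByKey, the same shape of computation the shuffle path drives:

    val agg = new spark.Aggregator[String, Double, (Double, Int)](
      v => (v, 1),                                    // createCombiner
      (c, v) => (c._1 + v, c._2 + 1),                 // mergeValue
      (c1, c2) => (c1._1 + c2._1, c1._2 + c2._2))     // mergeCombiners

    val combined = agg.combineValuesByKey(Iterator(("a", 1.0), ("a", 3.0), ("b", 2.0)))
    combined.foreach { case (k, (sum, count)) => println(k + " -> " + sum / count) }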
diff --git a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
index 0bbdb4e432..86432d0127 100644
--- a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
+++ b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
@@ -1,52 +1,43 @@
package spark
-import java.io.EOFException
-import java.net.URL
-
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
-import spark.storage.BlockException
import spark.storage.BlockManagerId
-import it.unimi.dsi.fastutil.io.FastBufferedInputStream
-
-
-class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging {
- def fetch[K, V](shuffleId: Int, reduceId: Int, func: (K, V) => Unit) {
+private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging {
+ override def fetch[K, V](shuffleId: Int, reduceId: Int) = {
logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
val blockManager = SparkEnv.get.blockManager
val startTime = System.currentTimeMillis
- val addresses = SparkEnv.get.mapOutputTracker.getServerAddresses(shuffleId)
+ val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId)
logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format(
shuffleId, reduceId, System.currentTimeMillis - startTime))
- val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[Int]]
- for ((address, index) <- addresses.zipWithIndex) {
- splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += index
+ val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(Int, Long)]]
+ for (((address, size), index) <- statuses.zipWithIndex) {
+ splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size))
}
- val blocksByAddress: Seq[(BlockManagerId, Seq[String])] = splitsByAddress.toSeq.map {
+ val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])] = splitsByAddress.toSeq.map {
case (address, splits) =>
- (address, splits.map(i => "shuffleid_%d_%d_%d".format(shuffleId, i, reduceId)))
+ (address, splits.map(s => ("shuffle_%d_%d_%d".format(shuffleId, s._1, reduceId), s._2)))
}
- for ((blockId, blockOption) <- blockManager.getMultiple(blocksByAddress)) {
+ def unpackBlock(blockPair: (String, Option[Iterator[Any]])) : Iterator[(K, V)] = {
+ val blockId = blockPair._1
+ val blockOption = blockPair._2
blockOption match {
case Some(block) => {
- val values = block
- for(value <- values) {
- val v = value.asInstanceOf[(K, V)]
- func(v._1, v._2)
- }
+ block.asInstanceOf[Iterator[(K, V)]]
}
case None => {
- val regex = "shuffleid_([0-9]*)_([0-9]*)_([0-9]*)".r
+ val regex = "shuffle_([0-9]*)_([0-9]*)_([0-9]*)".r
blockId match {
- case regex(shufId, mapId, reduceId) =>
- val addr = addresses(mapId.toInt)
- throw new FetchFailedException(addr, shufId.toInt, mapId.toInt, reduceId.toInt, null)
+ case regex(shufId, mapId, _) =>
+ val address = statuses(mapId.toInt)._1
+ throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, null)
case _ =>
throw new SparkException(
"Failed to get block " + blockId + ", which is not a shuffle block")
@@ -54,8 +45,6 @@ class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging {
}
}
}
-
- logDebug("Fetching and merging outputs of shuffle %d, reduce %d took %d ms".format(
- shuffleId, reduceId, System.currentTimeMillis - startTime))
+ blockManager.getMultiple(blocksByAddress).flatMap(unpackBlock)
}
}
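For reference, the shuffle block naming scheme the new fetcher parses is "shuffle_<shuffleId>_<mapId>_<reduceId>"; a tiny standalone sketch of that parsing (the block id below is made up):

    val shuffleBlock = "shuffle_([0-9]*)_([0-9]*)_([0-9]*)".r
    "shuffle_2_5_0" match {
      case shuffleBlock(shuffleId, mapId, reduceId) =>
        println("shuffle " + shuffleId + ", map " + mapId + ", reduce " + reduceId)
      case _ =>
        println("not a shuffle block")
    }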
diff --git a/core/src/main/scala/spark/BoundedMemoryCache.scala b/core/src/main/scala/spark/BoundedMemoryCache.scala
index 6fe0b94297..e8392a194f 100644
--- a/core/src/main/scala/spark/BoundedMemoryCache.scala
+++ b/core/src/main/scala/spark/BoundedMemoryCache.scala
@@ -9,7 +9,7 @@ import java.util.LinkedHashMap
* some cache entries have pointers to a shared object. Nonetheless, this Cache should work well
* when most of the space is used by arrays of primitives or of simple classes.
*/
-class BoundedMemoryCache(maxBytes: Long) extends Cache with Logging {
+private[spark] class BoundedMemoryCache(maxBytes: Long) extends Cache with Logging {
logInfo("BoundedMemoryCache.maxBytes = " + maxBytes)
def this() {
@@ -104,9 +104,9 @@ class BoundedMemoryCache(maxBytes: Long) extends Cache with Logging {
}
// An entry in our map; stores a cached object and its size in bytes
-case class Entry(value: Any, size: Long)
+private[spark] case class Entry(value: Any, size: Long)
-object BoundedMemoryCache {
+private[spark] object BoundedMemoryCache {
/**
* Get maximum cache capacity from system configuration
*/
diff --git a/core/src/main/scala/spark/Cache.scala b/core/src/main/scala/spark/Cache.scala
index 150fe14e2c..20d677a854 100644
--- a/core/src/main/scala/spark/Cache.scala
+++ b/core/src/main/scala/spark/Cache.scala
@@ -2,9 +2,9 @@ package spark
import java.util.concurrent.atomic.AtomicInteger
-sealed trait CachePutResponse
-case class CachePutSuccess(size: Long) extends CachePutResponse
-case class CachePutFailure() extends CachePutResponse
+private[spark] sealed trait CachePutResponse
+private[spark] case class CachePutSuccess(size: Long) extends CachePutResponse
+private[spark] case class CachePutFailure() extends CachePutResponse
/**
* An interface for caches in Spark, to allow for multiple implementations. Caches are used to store
@@ -22,7 +22,7 @@ case class CachePutFailure() extends CachePutResponse
* This abstract class handles the creation of key spaces, so that subclasses need only deal with
* keys that are unique across modules.
*/
-abstract class Cache {
+private[spark] abstract class Cache {
private val nextKeySpaceId = new AtomicInteger(0)
private def newKeySpaceId() = nextKeySpaceId.getAndIncrement()
@@ -52,7 +52,7 @@ abstract class Cache {
/**
* A key namespace in a Cache.
*/
-class KeySpace(cache: Cache, val keySpaceId: Int) {
+private[spark] class KeySpace(cache: Cache, val keySpaceId: Int) {
def get(datasetId: Any, partition: Int): Any =
cache.get((keySpaceId, datasetId), partition)
diff --git a/core/src/main/scala/spark/CacheTracker.scala b/core/src/main/scala/spark/CacheTracker.scala
index 22110832f8..c5db6ce63a 100644
--- a/core/src/main/scala/spark/CacheTracker.scala
+++ b/core/src/main/scala/spark/CacheTracker.scala
@@ -15,19 +15,20 @@ import scala.collection.mutable.HashSet
import spark.storage.BlockManager
import spark.storage.StorageLevel
-sealed trait CacheTrackerMessage
-case class AddedToCache(rddId: Int, partition: Int, host: String, size: Long = 0L)
+private[spark] sealed trait CacheTrackerMessage
+
+private[spark] case class AddedToCache(rddId: Int, partition: Int, host: String, size: Long = 0L)
extends CacheTrackerMessage
-case class DroppedFromCache(rddId: Int, partition: Int, host: String, size: Long = 0L)
+private[spark] case class DroppedFromCache(rddId: Int, partition: Int, host: String, size: Long = 0L)
extends CacheTrackerMessage
-case class MemoryCacheLost(host: String) extends CacheTrackerMessage
-case class RegisterRDD(rddId: Int, numPartitions: Int) extends CacheTrackerMessage
-case class SlaveCacheStarted(host: String, size: Long) extends CacheTrackerMessage
-case object GetCacheStatus extends CacheTrackerMessage
-case object GetCacheLocations extends CacheTrackerMessage
-case object StopCacheTracker extends CacheTrackerMessage
-
-class CacheTrackerActor extends Actor with Logging {
+private[spark] case class MemoryCacheLost(host: String) extends CacheTrackerMessage
+private[spark] case class RegisterRDD(rddId: Int, numPartitions: Int) extends CacheTrackerMessage
+private[spark] case class SlaveCacheStarted(host: String, size: Long) extends CacheTrackerMessage
+private[spark] case object GetCacheStatus extends CacheTrackerMessage
+private[spark] case object GetCacheLocations extends CacheTrackerMessage
+private[spark] case object StopCacheTracker extends CacheTrackerMessage
+
+private[spark] class CacheTrackerActor extends Actor with Logging {
// TODO: Should probably store (String, CacheType) tuples
private val locs = new HashMap[Int, Array[List[String]]]
@@ -43,8 +44,6 @@ class CacheTrackerActor extends Actor with Logging {
def receive = {
case SlaveCacheStarted(host: String, size: Long) =>
- logInfo("Started slave cache (size %s) on %s".format(
- Utils.memoryBytesToString(size), host))
slaveCapacity.put(host, size)
slaveUsage.put(host, 0)
sender ! true
@@ -56,22 +55,12 @@ class CacheTrackerActor extends Actor with Logging {
case AddedToCache(rddId, partition, host, size) =>
slaveUsage.put(host, getCacheUsage(host) + size)
- logInfo("Cache entry added: (%s, %s) on %s (size added: %s, available: %s)".format(
- rddId, partition, host, Utils.memoryBytesToString(size),
- Utils.memoryBytesToString(getCacheAvailable(host))))
locs(rddId)(partition) = host :: locs(rddId)(partition)
sender ! true
case DroppedFromCache(rddId, partition, host, size) =>
- logInfo("Cache entry removed: (%s, %s) on %s (size dropped: %s, available: %s)".format(
- rddId, partition, host, Utils.memoryBytesToString(size),
- Utils.memoryBytesToString(getCacheAvailable(host))))
slaveUsage.put(host, getCacheUsage(host) - size)
// Do a sanity check to make sure usage is greater than 0.
- val usage = getCacheUsage(host)
- if (usage < 0) {
- logError("Cache usage on %s is negative (%d)".format(host, usage))
- }
locs(rddId)(partition) = locs(rddId)(partition).filterNot(_ == host)
sender ! true
@@ -101,7 +90,7 @@ class CacheTrackerActor extends Actor with Logging {
}
}
-class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, blockManager: BlockManager)
+private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, blockManager: BlockManager)
extends Logging {
// Tracker actor on the master, or remote reference to it on workers
@@ -151,7 +140,6 @@ class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, blockManager: Bl
logInfo("Registering RDD ID " + rddId + " with cache")
registeredRddIds += rddId
communicate(RegisterRDD(rddId, numPartitions))
- logInfo(RegisterRDD(rddId, numPartitions) + " successful")
}
}
}
@@ -169,9 +157,8 @@ class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, blockManager: Bl
}
// For BlockManager.scala only
- def notifyTheCacheTrackerFromBlockManager(t: AddedToCache) {
+ def notifyFromBlockManager(t: AddedToCache) {
communicate(t)
- logInfo("notifyTheCacheTrackerFromBlockManager successful")
}
// Get a snapshot of the currently known locations
@@ -181,7 +168,7 @@ class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, blockManager: Bl
// Gets or computes an RDD split
def getOrCompute[T](rdd: RDD[T], split: Split, storageLevel: StorageLevel): Iterator[T] = {
- val key = "rdd:%d:%d".format(rdd.id, split.index)
+ val key = "rdd_%d_%d".format(rdd.id, split.index)
logInfo("Cache key is " + key)
blockManager.get(key) match {
case Some(cachedValues) =>
@@ -221,23 +208,19 @@ class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, blockManager: Bl
// TODO: fetch any remote copy of the split that may be available
// TODO: also register a listener for when it unloads
logInfo("Computing partition " + split)
+ val elements = new ArrayBuffer[Any]
+ elements ++= rdd.compute(split)
try {
- // BlockManager will iterate over results from compute to create RDD
- blockManager.put(key, rdd.compute(split), storageLevel, false)
+ // Try to put this block in the blockManager
+ blockManager.put(key, elements, storageLevel, true)
//future.apply() // Wait for the reply from the cache tracker
- blockManager.get(key) match {
- case Some(values) =>
- return values.asInstanceOf[Iterator[T]]
- case None =>
- logWarning("loading partition failed after computing it " + key)
- return null
- }
} finally {
loading.synchronized {
loading.remove(key)
loading.notifyAll()
}
}
+ return elements.iterator.asInstanceOf[Iterator[T]]
}
}
diff --git a/core/src/main/scala/spark/ClosureCleaner.scala b/core/src/main/scala/spark/ClosureCleaner.scala
index 3b83d23a13..98525b99c8 100644
--- a/core/src/main/scala/spark/ClosureCleaner.scala
+++ b/core/src/main/scala/spark/ClosureCleaner.scala
@@ -9,7 +9,7 @@ import org.objectweb.asm.{ClassReader, MethodVisitor, Type}
import org.objectweb.asm.commons.EmptyVisitor
import org.objectweb.asm.Opcodes._
-object ClosureCleaner extends Logging {
+private[spark] object ClosureCleaner extends Logging {
// Get an ASM class reader for a given class from the JAR that loaded it
private def getClassReader(cls: Class[_]): ClassReader = {
new ClassReader(cls.getResourceAsStream(
@@ -154,7 +154,7 @@ object ClosureCleaner extends Logging {
}
}
-class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends EmptyVisitor {
+private[spark] class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends EmptyVisitor {
override def visitMethod(access: Int, name: String, desc: String,
sig: String, exceptions: Array[String]): MethodVisitor = {
return new EmptyVisitor {
@@ -180,7 +180,7 @@ class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends EmptyVisitor
}
}
-class InnerClosureFinder(output: Set[Class[_]]) extends EmptyVisitor {
+private[spark] class InnerClosureFinder(output: Set[Class[_]]) extends EmptyVisitor {
var myName: String = null
override def visit(version: Int, access: Int, name: String, sig: String,
diff --git a/core/src/main/scala/spark/Dependency.scala b/core/src/main/scala/spark/Dependency.scala
index c0ff94acc6..b85d2732db 100644
--- a/core/src/main/scala/spark/Dependency.scala
+++ b/core/src/main/scala/spark/Dependency.scala
@@ -1,22 +1,51 @@
package spark
-abstract class Dependency[T](val rdd: RDD[T], val isShuffle: Boolean) extends Serializable
+/**
+ * Base class for dependencies.
+ */
+abstract class Dependency[T](val rdd: RDD[T]) extends Serializable
-abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd, false) {
+/**
+ * Base class for dependencies where each partition of the parent RDD is used by at most one
+ * partition of the child RDD. Narrow dependencies allow for pipelined execution.
+ */
+abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) {
+ /**
+ * Get the parent partitions for a child partition.
+ * @param outputPartition a partition of the child RDD
+ * @return the partitions of the parent RDD that the child partition depends upon
+ */
def getParents(outputPartition: Int): Seq[Int]
}
-class ShuffleDependency[K, V, C](
- val shuffleId: Int,
+/**
+ * Represents a dependency on the output of a shuffle stage.
+ * @param shuffleId the shuffle id
+ * @param rdd the parent RDD
+ * @param partitioner partitioner used to partition the shuffle output
+ */
+class ShuffleDependency[K, V](
@transient rdd: RDD[(K, V)],
- val aggregator: Aggregator[K, V, C],
val partitioner: Partitioner)
- extends Dependency(rdd, true)
+ extends Dependency(rdd) {
+ val shuffleId: Int = rdd.context.newShuffleId()
+}
+
+/**
+ * Represents a one-to-one dependency between partitions of the parent and child RDDs.
+ */
class OneToOneDependency[T](rdd: RDD[T]) extends NarrowDependency[T](rdd) {
override def getParents(partitionId: Int) = List(partitionId)
}
+/**
+ * Represents a one-to-one dependency between ranges of partitions in the parent and child RDDs.
+ * @param rdd the parent RDD
+ * @param inStart the start of the range in the parent RDD
+ * @param outStart the start of the range in the child RDD
+ * @param length the length of the range
+ */
class RangeDependency[T](rdd: RDD[T], inStart: Int, outStart: Int, length: Int)
extends NarrowDependency[T](rdd) {
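As an illustration of the NarrowDependency contract documented above, a hypothetical dependency in which every child partition reads only the first parent partition could be written as follows (sketch only, not part of the patch):

    class FirstPartitionDependency[T](rdd: spark.RDD[T]) extends spark.NarrowDependency[T](rdd) {
      override def getParents(outputPartition: Int): Seq[Int] = List(0)
    }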
diff --git a/core/src/main/scala/spark/DoubleRDDFunctions.scala b/core/src/main/scala/spark/DoubleRDDFunctions.scala
index 1fbf66b7de..b2a0e2b631 100644
--- a/core/src/main/scala/spark/DoubleRDDFunctions.scala
+++ b/core/src/main/scala/spark/DoubleRDDFunctions.scala
@@ -4,33 +4,49 @@ import spark.partial.BoundedDouble
import spark.partial.MeanEvaluator
import spark.partial.PartialResult
import spark.partial.SumEvaluator
-
import spark.util.StatCounter
/**
* Extra functions available on RDDs of Doubles through an implicit conversion.
+ * Import `spark.SparkContext._` at the top of your program to use these functions.
*/
class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable {
+ /** Add up the elements in this RDD. */
def sum(): Double = {
self.reduce(_ + _)
}
+ /**
+ * Return a [[spark.util.StatCounter]] object that captures the mean, variance and count
+ * of the RDD's elements in one operation.
+ */
def stats(): StatCounter = {
self.mapPartitions(nums => Iterator(StatCounter(nums))).reduce((a, b) => a.merge(b))
}
+ /** Compute the mean of this RDD's elements. */
def mean(): Double = stats().mean
+ /** Compute the variance of this RDD's elements. */
def variance(): Double = stats().variance
+ /** Compute the standard deviation of this RDD's elements. */
def stdev(): Double = stats().stdev
+ /**
+ * Compute the sample standard deviation of this RDD's elements (which corrects for bias in
+ * estimating the standard deviation by dividing by N-1 instead of N).
+ */
+ def sampleStdev(): Double = stats().stdev
+
+ /** (Experimental) Approximate operation to return the mean within a timeout. */
def meanApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = {
val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns)
val evaluator = new MeanEvaluator(self.splits.size, confidence)
self.context.runApproximateJob(self, processPartition, evaluator, timeout)
}
+ /** (Experimental) Approximate operation to return the sum within a timeout. */
def sumApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = {
val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns)
val evaluator = new SumEvaluator(self.splits.size, confidence)
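A short sketch of the double functions documented above (illustrative only); it assumes a SparkContext `sc` is already in scope, with the implicit conversion provided by `import spark.SparkContext._` as the class comment notes:

    import spark.SparkContext._

    val nums = sc.parallelize(Seq(1.0, 2.0, 3.0, 4.0))
    val st = nums.stats()                     // mean, variance and count in one pass
    println(st.mean + " " + st.variance)
    println(nums.sum() + " " + nums.stdev())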
diff --git a/core/src/main/scala/spark/FetchFailedException.scala b/core/src/main/scala/spark/FetchFailedException.scala
index 55512f4481..a953081d24 100644
--- a/core/src/main/scala/spark/FetchFailedException.scala
+++ b/core/src/main/scala/spark/FetchFailedException.scala
@@ -2,7 +2,7 @@ package spark
import spark.storage.BlockManagerId
-class FetchFailedException(
+private[spark] class FetchFailedException(
val bmAddress: BlockManagerId,
val shuffleId: Int,
val mapId: Int,
diff --git a/core/src/main/scala/spark/HadoopWriter.scala b/core/src/main/scala/spark/HadoopWriter.scala
index 12b6a0954c..afcf9f6db4 100644
--- a/core/src/main/scala/spark/HadoopWriter.scala
+++ b/core/src/main/scala/spark/HadoopWriter.scala
@@ -16,11 +16,14 @@ import spark.Logging
import spark.SerializableWritable
/**
- * Saves an RDD using a Hadoop OutputFormat as specified by a JobConf. The JobConf should also
- * contain an output key class, an output value class, a filename to write to, etc exactly like in
- * a Hadoop job.
+ * Internal helper class that saves an RDD using a Hadoop OutputFormat. This is only public
+ * because we need to access this class from the `spark` package to use some package-private Hadoop
+ * functions, but this class should not be used directly by users.
+ *
+ * Saves the RDD using a JobConf, which should contain an output key class, an output value class,
+ * a filename to write to, etc, exactly like in a Hadoop MapReduce job.
*/
-class HadoopWriter(@transient jobConf: JobConf) extends Logging with Serializable {
+class HadoopWriter(@transient jobConf: JobConf) extends Logging with HadoopMapRedUtil with Serializable {
private val now = new Date()
private val conf = new SerializableWritable(jobConf)
@@ -42,7 +45,7 @@ class HadoopWriter(@transient jobConf: JobConf) extends Logging with Serializabl
setConfParams()
val jCtxt = getJobContext()
- getOutputCommitter().setupJob(jCtxt)
+ getOutputCommitter().setupJob(jCtxt)
}
@@ -126,14 +129,14 @@ class HadoopWriter(@transient jobConf: JobConf) extends Logging with Serializabl
private def getJobContext(): JobContext = {
if (jobContext == null) {
- jobContext = new JobContext(conf.value, jID.value)
+ jobContext = newJobContext(conf.value, jID.value)
}
return jobContext
}
private def getTaskContext(): TaskAttemptContext = {
if (taskContext == null) {
- taskContext = new TaskAttemptContext(conf.value, taID.value)
+ taskContext = newTaskAttemptContext(conf.value, taID.value)
}
return taskContext
}
diff --git a/core/src/main/scala/spark/HttpFileServer.scala b/core/src/main/scala/spark/HttpFileServer.scala
new file mode 100644
index 0000000000..659d17718f
--- /dev/null
+++ b/core/src/main/scala/spark/HttpFileServer.scala
@@ -0,0 +1,47 @@
+package spark
+
+import java.io.{File, PrintWriter}
+import java.net.URL
+import scala.collection.mutable.HashMap
+import org.apache.hadoop.fs.FileUtil
+
+private[spark] class HttpFileServer extends Logging {
+
+ var baseDir : File = null
+ var fileDir : File = null
+ var jarDir : File = null
+ var httpServer : HttpServer = null
+ var serverUri : String = null
+
+ def initialize() {
+ baseDir = Utils.createTempDir()
+ fileDir = new File(baseDir, "files")
+ jarDir = new File(baseDir, "jars")
+ fileDir.mkdir()
+ jarDir.mkdir()
+ logInfo("HTTP File server directory is " + baseDir)
+ httpServer = new HttpServer(baseDir)
+ httpServer.start()
+ serverUri = httpServer.uri
+ }
+
+ def stop() {
+ httpServer.stop()
+ }
+
+ def addFile(file: File) : String = {
+ addFileToDir(file, fileDir)
+ return serverUri + "/files/" + file.getName
+ }
+
+ def addJar(file: File) : String = {
+ addFileToDir(file, jarDir)
+ return serverUri + "/jars/" + file.getName
+ }
+
+ def addFileToDir(file: File, dir: File) : String = {
+ Utils.copyFile(file, new File(dir, file.getName))
+ return dir + "/" + file.getName
+ }
+
+}
\ No newline at end of file
diff --git a/core/src/main/scala/spark/HttpServer.scala b/core/src/main/scala/spark/HttpServer.scala
index 855f2c752f..0196595ba1 100644
--- a/core/src/main/scala/spark/HttpServer.scala
+++ b/core/src/main/scala/spark/HttpServer.scala
@@ -12,14 +12,14 @@ import org.eclipse.jetty.util.thread.QueuedThreadPool
/**
* Exception type thrown by HttpServer when it is in the wrong state for an operation.
*/
-class ServerStateException(message: String) extends Exception(message)
+private[spark] class ServerStateException(message: String) extends Exception(message)
/**
* An HTTP server for static content used to allow worker nodes to access JARs added to SparkContext
* as well as classes created by the interpreter when the user types in code. This is just a wrapper
* around a Jetty server.
*/
-class HttpServer(resourceBase: File) extends Logging {
+private[spark] class HttpServer(resourceBase: File) extends Logging {
private var server: Server = null
private var port: Int = -1
diff --git a/core/src/main/scala/spark/JavaSerializer.scala b/core/src/main/scala/spark/JavaSerializer.scala
index d11ba5167d..b04a27d073 100644
--- a/core/src/main/scala/spark/JavaSerializer.scala
+++ b/core/src/main/scala/spark/JavaSerializer.scala
@@ -3,16 +3,17 @@ package spark
import java.io._
import java.nio.ByteBuffer
+import serializer.{Serializer, SerializerInstance, DeserializationStream, SerializationStream}
import spark.util.ByteBufferInputStream
-class JavaSerializationStream(out: OutputStream) extends SerializationStream {
+private[spark] class JavaSerializationStream(out: OutputStream) extends SerializationStream {
val objOut = new ObjectOutputStream(out)
- def writeObject[T](t: T) { objOut.writeObject(t) }
+ def writeObject[T](t: T): SerializationStream = { objOut.writeObject(t); this }
def flush() { objOut.flush() }
def close() { objOut.close() }
}
-class JavaDeserializationStream(in: InputStream, loader: ClassLoader)
+private[spark] class JavaDeserializationStream(in: InputStream, loader: ClassLoader)
extends DeserializationStream {
val objIn = new ObjectInputStream(in) {
override def resolveClass(desc: ObjectStreamClass) =
@@ -23,7 +24,7 @@ extends DeserializationStream {
def close() { objIn.close() }
}
-class JavaSerializerInstance extends SerializerInstance {
+private[spark] class JavaSerializerInstance extends SerializerInstance {
def serialize[T](t: T): ByteBuffer = {
val bos = new ByteArrayOutputStream()
val out = serializeStream(bos)
@@ -57,6 +58,9 @@ class JavaSerializerInstance extends SerializerInstance {
}
}
+/**
+ * A Spark serializer that uses Java's built-in serialization.
+ */
class JavaSerializer extends Serializer {
def newInstance(): SerializerInstance = new JavaSerializerInstance
}
diff --git a/core/src/main/scala/spark/KryoSerializer.scala b/core/src/main/scala/spark/KryoSerializer.scala
index 8a3f565071..44b630e478 100644
--- a/core/src/main/scala/spark/KryoSerializer.scala
+++ b/core/src/main/scala/spark/KryoSerializer.scala
@@ -13,6 +13,7 @@ import com.esotericsoftware.kryo.serialize.ClassSerializer
import com.esotericsoftware.kryo.serialize.SerializableSerializer
import de.javakaffee.kryoserializers.KryoReflectionFactorySupport
+import serializer.{SerializerInstance, DeserializationStream, SerializationStream}
import spark.broadcast._
import spark.storage._
@@ -20,7 +21,7 @@ import spark.storage._
* Zig-zag encoder used to write object sizes to serialization streams.
* Based on Kryo's integer encoder.
*/
-object ZigZag {
+private[spark] object ZigZag {
def writeInt(n: Int, out: OutputStream) {
var value = n
if ((value & ~0x7F) == 0) {
@@ -68,22 +69,25 @@ object ZigZag {
}
}
+private[spark]
class KryoSerializationStream(kryo: Kryo, threadBuffer: ByteBuffer, out: OutputStream)
extends SerializationStream {
val channel = Channels.newChannel(out)
- def writeObject[T](t: T) {
+ def writeObject[T](t: T): SerializationStream = {
kryo.writeClassAndObject(threadBuffer, t)
ZigZag.writeInt(threadBuffer.position(), out)
threadBuffer.flip()
channel.write(threadBuffer)
threadBuffer.clear()
+ this
}
def flush() { out.flush() }
def close() { out.close() }
}
+private[spark]
class KryoDeserializationStream(objectBuffer: ObjectBuffer, in: InputStream)
extends DeserializationStream {
def readObject[T](): T = {
@@ -94,7 +98,7 @@ extends DeserializationStream {
def close() { in.close() }
}
-class KryoSerializerInstance(ks: KryoSerializer) extends SerializerInstance {
+private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends SerializerInstance {
val kryo = ks.kryo
val threadBuffer = ks.threadBuffer.get()
val objectBuffer = ks.objectBuffer.get()
@@ -155,13 +159,21 @@ class KryoSerializerInstance(ks: KryoSerializer) extends SerializerInstance {
}
}
-// Used by clients to register their own classes
+/**
+ * Interface implemented by clients to register their classes with Kryo when using Kryo
+ * serialization.
+ */
trait KryoRegistrator {
def registerClasses(kryo: Kryo): Unit
}
-class KryoSerializer extends Serializer with Logging {
- val kryo = createKryo()
+/**
+ * A Spark serializer that uses the [[http://code.google.com/p/kryo/wiki/V1Documentation Kryo 1.x library]].
+ */
+class KryoSerializer extends spark.serializer.Serializer with Logging {
+ // Make this lazy so that it only gets called once we receive our first task on each executor,
+ // so we can pull out any custom Kryo registrator from the user's JARs.
+ lazy val kryo = createKryo()
val bufferSize = System.getProperty("spark.kryoserializer.buffer.mb", "32").toInt * 1024 * 1024
@@ -192,8 +204,8 @@ class KryoSerializer extends Serializer with Logging {
(1, 1, 1), (1, 1, 1, 1), (1, 1, 1, 1, 1),
None,
ByteBuffer.allocate(1),
- StorageLevel.MEMORY_ONLY_DESER,
- PutBlock("1", ByteBuffer.allocate(1), StorageLevel.MEMORY_ONLY_DESER),
+ StorageLevel.MEMORY_ONLY,
+ PutBlock("1", ByteBuffer.allocate(1), StorageLevel.MEMORY_ONLY),
GotBlock("1", ByteBuffer.allocate(1)),
GetBlock("1")
)
@@ -256,7 +268,8 @@ class KryoSerializer extends Serializer with Logging {
val regCls = System.getProperty("spark.kryo.registrator")
if (regCls != null) {
logInfo("Running user registrator: " + regCls)
- val reg = Class.forName(regCls).newInstance().asInstanceOf[KryoRegistrator]
+ val classLoader = Thread.currentThread.getContextClassLoader
+ val reg = Class.forName(regCls, true, classLoader).newInstance().asInstanceOf[KryoRegistrator]
reg.registerClasses(kryo)
}
kryo
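To show how the registrator hook above is meant to be used, here is a hypothetical setup sketch: the record type, package and class names are invented; only the `spark.kryo.registrator` property comes from this patch, and `spark.serializer` is assumed to be how the rest of the codebase selects a serializer implementation.

    import com.esotericsoftware.kryo.Kryo

    case class MyRecord(id: Int, name: String)          // example application class

    class MyRegistrator extends spark.KryoRegistrator {
      override def registerClasses(kryo: Kryo) {
        kryo.register(classOf[MyRecord])                // Kryo 1.x register-by-class call
      }
    }

    // Before creating the SparkContext (use the registrator's fully qualified name):
    System.setProperty("spark.serializer", "spark.KryoSerializer")            // assumed selection property
    System.setProperty("spark.kryo.registrator", "mypackage.MyRegistrator")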
diff --git a/core/src/main/scala/spark/Logging.scala b/core/src/main/scala/spark/Logging.scala
index 69935b86de..90bae26202 100644
--- a/core/src/main/scala/spark/Logging.scala
+++ b/core/src/main/scala/spark/Logging.scala
@@ -15,7 +15,7 @@ trait Logging {
private var log_ : Logger = null
// Method to get or create the logger for this object
- def log: Logger = {
+ protected def log: Logger = {
if (log_ == null) {
var className = this.getClass.getName
// Ignore trailing $'s in the class names for Scala objects
@@ -28,48 +28,48 @@ trait Logging {
}
// Log methods that take only a String
- def logInfo(msg: => String) {
+ protected def logInfo(msg: => String) {
if (log.isInfoEnabled) log.info(msg)
}
- def logDebug(msg: => String) {
+ protected def logDebug(msg: => String) {
if (log.isDebugEnabled) log.debug(msg)
}
- def logTrace(msg: => String) {
+ protected def logTrace(msg: => String) {
if (log.isTraceEnabled) log.trace(msg)
}
- def logWarning(msg: => String) {
+ protected def logWarning(msg: => String) {
if (log.isWarnEnabled) log.warn(msg)
}
- def logError(msg: => String) {
+ protected def logError(msg: => String) {
if (log.isErrorEnabled) log.error(msg)
}
// Log methods that take Throwables (Exceptions/Errors) too
- def logInfo(msg: => String, throwable: Throwable) {
+ protected def logInfo(msg: => String, throwable: Throwable) {
if (log.isInfoEnabled) log.info(msg, throwable)
}
- def logDebug(msg: => String, throwable: Throwable) {
+ protected def logDebug(msg: => String, throwable: Throwable) {
if (log.isDebugEnabled) log.debug(msg, throwable)
}
- def logTrace(msg: => String, throwable: Throwable) {
+ protected def logTrace(msg: => String, throwable: Throwable) {
if (log.isTraceEnabled) log.trace(msg, throwable)
}
- def logWarning(msg: => String, throwable: Throwable) {
+ protected def logWarning(msg: => String, throwable: Throwable) {
if (log.isWarnEnabled) log.warn(msg, throwable)
}
- def logError(msg: => String, throwable: Throwable) {
+ protected def logError(msg: => String, throwable: Throwable) {
if (log.isErrorEnabled) log.error(msg, throwable)
}
// Method for ensuring that logging is initialized, to avoid having multiple
// threads do it concurrently (as SLF4J initialization is not thread safe).
- def initLogging() { log }
+ protected def initLogging() { log }
}
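Since the log methods above are now protected, they can only be invoked from inside a class that mixes the trait in; a minimal sketch (class name invented):

    class MyComponent extends spark.Logging {
      def run() {
        logInfo("starting")    // allowed: called from within the mixing class
      }
    }
    // (new MyComponent).logInfo("x") would no longer compile from outside the class.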
diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala
index 82c1391345..45441aa5e5 100644
--- a/core/src/main/scala/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/spark/MapOutputTracker.scala
@@ -1,6 +1,6 @@
package spark
-import java.io.{DataInputStream, DataOutputStream, ByteArrayOutputStream, ByteArrayInputStream}
+import java.io._
import java.util.concurrent.ConcurrentHashMap
import akka.actor._
@@ -14,16 +14,19 @@ import akka.util.duration._
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
+import scheduler.MapStatus
import spark.storage.BlockManagerId
+import java.util.zip.{GZIPInputStream, GZIPOutputStream}
-sealed trait MapOutputTrackerMessage
-case class GetMapOutputLocations(shuffleId: Int) extends MapOutputTrackerMessage
-case object StopMapOutputTracker extends MapOutputTrackerMessage
+private[spark] sealed trait MapOutputTrackerMessage
+private[spark] case class GetMapOutputStatuses(shuffleId: Int, requester: String)
+ extends MapOutputTrackerMessage
+private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage
-class MapOutputTrackerActor(tracker: MapOutputTracker) extends Actor with Logging {
+private[spark] class MapOutputTrackerActor(tracker: MapOutputTracker) extends Actor with Logging {
def receive = {
- case GetMapOutputLocations(shuffleId: Int) =>
- logInfo("Asked to get map output locations for shuffle " + shuffleId)
+ case GetMapOutputStatuses(shuffleId: Int, requester: String) =>
+ logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + requester)
sender ! tracker.getSerializedLocations(shuffleId)
case StopMapOutputTracker =>
@@ -33,23 +36,23 @@ class MapOutputTrackerActor(tracker: MapOutputTracker) extends Actor with Loggin
}
}
-class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logging {
+private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logging {
val ip: String = System.getProperty("spark.master.host", "localhost")
val port: Int = System.getProperty("spark.master.port", "7077").toInt
val actorName: String = "MapOutputTracker"
val timeout = 10.seconds
- var bmAddresses = new ConcurrentHashMap[Int, Array[BlockManagerId]]
+ var mapStatuses = new ConcurrentHashMap[Int, Array[MapStatus]]
// Incremented every time a fetch fails so that client nodes know to clear
// their cache of map output locations if this happens.
private var generation: Long = 0
- private var generationLock = new java.lang.Object
+ private val generationLock = new java.lang.Object
- // Cache a serialized version of the output locations for each shuffle to send them out faster
+ // Cache a serialized version of the output statuses for each shuffle to send them out faster
var cacheGeneration = generation
- val cachedSerializedLocs = new HashMap[Int, Array[Byte]]
+ val cachedSerializedStatuses = new HashMap[Int, Array[Byte]]
var trackerActor: ActorRef = if (isMaster) {
val actor = actorSystem.actorOf(Props(new MapOutputTrackerActor(this)), name = actorName)
@@ -80,31 +83,34 @@ class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logg
}
def registerShuffle(shuffleId: Int, numMaps: Int) {
- if (bmAddresses.get(shuffleId) != null) {
+ if (mapStatuses.get(shuffleId) != null) {
throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
}
- bmAddresses.put(shuffleId, new Array[BlockManagerId](numMaps))
+ mapStatuses.put(shuffleId, new Array[MapStatus](numMaps))
}
- def registerMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
- var array = bmAddresses.get(shuffleId)
+ def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) {
+ var array = mapStatuses.get(shuffleId)
array.synchronized {
- array(mapId) = bmAddress
+ array(mapId) = status
}
}
- def registerMapOutputs(shuffleId: Int, locs: Array[BlockManagerId], changeGeneration: Boolean = false) {
- bmAddresses.put(shuffleId, Array[BlockManagerId]() ++ locs)
+ def registerMapOutputs(
+ shuffleId: Int,
+ statuses: Array[MapStatus],
+ changeGeneration: Boolean = false) {
+ mapStatuses.put(shuffleId, Array[MapStatus]() ++ statuses)
if (changeGeneration) {
incrementGeneration()
}
}
def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
- var array = bmAddresses.get(shuffleId)
+ var array = mapStatuses.get(shuffleId)
if (array != null) {
array.synchronized {
- if (array(mapId) == bmAddress) {
+ if (array(mapId).address == bmAddress) {
array(mapId) = null
}
}
@@ -117,10 +123,10 @@ class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logg
// Remembers which map output locations are currently being fetched on a worker
val fetching = new HashSet[Int]
- // Called on possibly remote nodes to get the server URIs for a given shuffle
- def getServerAddresses(shuffleId: Int): Array[BlockManagerId] = {
- val locs = bmAddresses.get(shuffleId)
- if (locs == null) {
+ // Called on possibly remote nodes to get the server URIs and output sizes for a given shuffle
+ def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = {
+ val statuses = mapStatuses.get(shuffleId)
+ if (statuses == null) {
logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
fetching.synchronized {
if (fetching.contains(shuffleId)) {
@@ -129,34 +135,38 @@ class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logg
try {
fetching.wait()
} catch {
- case _ =>
+ case e: InterruptedException =>
}
}
- return bmAddresses.get(shuffleId)
+ return mapStatuses.get(shuffleId).map(status =>
+ (status.address, MapOutputTracker.decompressSize(status.compressedSizes(reduceId))))
} else {
fetching += shuffleId
}
}
// We won the race to fetch the output locs; do so
logInfo("Doing the fetch; tracker actor = " + trackerActor)
- val fetchedBytes = askTracker(GetMapOutputLocations(shuffleId)).asInstanceOf[Array[Byte]]
- val fetchedLocs = deserializeLocations(fetchedBytes)
+ val host = System.getProperty("spark.hostname", Utils.localHostName)
+ val fetchedBytes = askTracker(GetMapOutputStatuses(shuffleId, host)).asInstanceOf[Array[Byte]]
+ val fetchedStatuses = deserializeStatuses(fetchedBytes)
logInfo("Got the output locations")
- bmAddresses.put(shuffleId, fetchedLocs)
+ mapStatuses.put(shuffleId, fetchedStatuses)
fetching.synchronized {
fetching -= shuffleId
fetching.notifyAll()
}
- return fetchedLocs
+ return fetchedStatuses.map(s =>
+ (s.address, MapOutputTracker.decompressSize(s.compressedSizes(reduceId))))
} else {
- return locs
+ return statuses.map(s =>
+ (s.address, MapOutputTracker.decompressSize(s.compressedSizes(reduceId))))
}
}
def stop() {
communicate(StopMapOutputTracker)
- bmAddresses.clear()
+ mapStatuses.clear()
trackerActor = null
}
@@ -182,75 +192,83 @@ class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logg
generationLock.synchronized {
if (newGen > generation) {
logInfo("Updating generation to " + newGen + " and clearing cache")
- bmAddresses = new ConcurrentHashMap[Int, Array[BlockManagerId]]
+ mapStatuses = new ConcurrentHashMap[Int, Array[MapStatus]]
generation = newGen
}
}
}
def getSerializedLocations(shuffleId: Int): Array[Byte] = {
- var locs: Array[BlockManagerId] = null
+ var statuses: Array[MapStatus] = null
var generationGotten: Long = -1
generationLock.synchronized {
if (generation > cacheGeneration) {
- cachedSerializedLocs.clear()
+ cachedSerializedStatuses.clear()
cacheGeneration = generation
}
- cachedSerializedLocs.get(shuffleId) match {
+ cachedSerializedStatuses.get(shuffleId) match {
case Some(bytes) =>
return bytes
case None =>
- locs = bmAddresses.get(shuffleId)
+ statuses = mapStatuses.get(shuffleId)
generationGotten = generation
}
}
// If we got here, we failed to find the serialized locations in the cache, so we pulled
// out a snapshot of the locations as "locs"; let's serialize and return that
- val bytes = serializeLocations(locs)
+ val bytes = serializeStatuses(statuses)
+ logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length))
// Add them into the table only if the generation hasn't changed while we were working
generationLock.synchronized {
if (generation == generationGotten) {
- cachedSerializedLocs(shuffleId) = bytes
+ cachedSerializedStatuses(shuffleId) = bytes
}
}
return bytes
}
// Serialize an array of map output locations into an efficient byte format so that we can send
- // it to reduce tasks. We do this by grouping together the locations by block manager ID.
- def serializeLocations(locs: Array[BlockManagerId]): Array[Byte] = {
+ // it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will
+ // generally be pretty compressible because many map outputs will be on the same hostname.
+ def serializeStatuses(statuses: Array[MapStatus]): Array[Byte] = {
val out = new ByteArrayOutputStream
- val dataOut = new DataOutputStream(out)
- dataOut.writeInt(locs.length)
- val grouped = locs.zipWithIndex.groupBy(_._1)
- dataOut.writeInt(grouped.size)
- for ((id, pairs) <- grouped if id != null) {
- dataOut.writeUTF(id.ip)
- dataOut.writeInt(id.port)
- dataOut.writeInt(pairs.length)
- for ((_, blockIndex) <- pairs) {
- dataOut.writeInt(blockIndex)
- }
- }
- dataOut.close()
+ val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
+ objOut.writeObject(statuses)
+ objOut.close()
out.toByteArray
}
- // Opposite of serializeLocations.
- def deserializeLocations(bytes: Array[Byte]): Array[BlockManagerId] = {
- val dataIn = new DataInputStream(new ByteArrayInputStream(bytes))
- val length = dataIn.readInt()
- val array = new Array[BlockManagerId](length)
- val numGroups = dataIn.readInt()
- for (i <- 0 until numGroups) {
- val ip = dataIn.readUTF()
- val port = dataIn.readInt()
- val id = new BlockManagerId(ip, port)
- val numBlocks = dataIn.readInt()
- for (j <- 0 until numBlocks) {
- array(dataIn.readInt()) = id
- }
+ // Opposite of serializeStatuses.
+ def deserializeStatuses(bytes: Array[Byte]): Array[MapStatus] = {
+ val objIn = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes)))
+ objIn.readObject().asInstanceOf[Array[MapStatus]]
+ }
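(Editor's note: the GZIP-plus-Java-serialization round trip used by serializeStatuses/deserializeStatuses above can be exercised on its own. A minimal, self-contained Scala sketch — object and variable names are illustrative, not the tracker's code path — showing the same java.io / java.util.zip API usage; repetitive host strings compress well, which is the rationale given in the comment above.)

    import java.io._
    import java.util.zip.{GZIPInputStream, GZIPOutputStream}

    object GzipSerializationSketch {
      def serialize(obj: AnyRef): Array[Byte] = {
        val out = new ByteArrayOutputStream
        val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
        objOut.writeObject(obj)
        objOut.close()    // also finishes the GZIP stream
        out.toByteArray
      }

      def deserialize[T](bytes: Array[Byte]): T = {
        val objIn = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes)))
        objIn.readObject().asInstanceOf[T]
      }

      def main(args: Array[String]) {
        // Highly repetitive, like map output hosts in a cluster
        val statuses = Array.fill(1000)("host-1:7077")
        val bytes = serialize(statuses)
        println("compressed to " + bytes.length + " bytes")
        println(deserialize[Array[String]](bytes).length)
      }
    }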
+}
+
+private[spark] object MapOutputTracker {
+ private val LOG_BASE = 1.1
+
+ /**
+ * Compress a size in bytes to 8 bits for efficient reporting of map output sizes.
+ * We do this by encoding the log base 1.1 of the size as an integer, which can support
+ * sizes up to 35 GB with at most 10% error.
+ */
+ def compressSize(size: Long): Byte = {
+ if (size <= 1L) {
+ 0
+ } else {
+ math.min(255, math.ceil(math.log(size) / math.log(LOG_BASE)).toInt).toByte
+ }
+ }
+
+ /**
+ * Decompress an 8-bit encoded block size, using the reverse operation of compressSize.
+ */
+ def decompressSize(compressedSize: Byte): Long = {
+ if (compressedSize == 0) {
+ 1
+ } else {
+ math.pow(LOG_BASE, (compressedSize & 0xFF)).toLong
}
- array
}
}
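(Editor's note: the 8-bit size encoding introduced above is easy to verify outside Spark. A minimal standalone Scala sketch — the object name and test sizes are illustrative, not part of the commit — that round-trips a few sizes and prints the relative error, which stays within the 10% bound claimed in the scaladoc.)

    object SizeCompressionSketch {
      private val LogBase = 1.1

      // Encode the size as ceil(log_1.1(size)), capped at 255, exactly as above
      def compressSize(size: Long): Byte =
        if (size <= 1L) 0
        else math.min(255, math.ceil(math.log(size) / math.log(LogBase)).toInt).toByte

      // Reverse operation: 1.1^code, so the estimate overshoots by at most ~10%
      def decompressSize(compressed: Byte): Long =
        if (compressed == 0) 1L
        else math.pow(LogBase, (compressed & 0xFF)).toLong

      def main(args: Array[String]) {
        for (size <- Seq(1L, 4096L, 1048576L, 10L * 1024 * 1024 * 1024)) {
          val approx = decompressSize(compressSize(size))
          val err = (approx - size).toDouble / size
          println(size + " -> " + approx + " (relative error " + err + ")")
        }
      }
    }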
diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala
index 64018f8c6b..e5bb639cfd 100644
--- a/core/src/main/scala/spark/PairRDDFunctions.scala
+++ b/core/src/main/scala/spark/PairRDDFunctions.scala
@@ -1,11 +1,6 @@
package spark
-import java.io.EOFException
-import java.net.URL
-import java.io.ObjectInputStream
-import java.util.concurrent.atomic.AtomicLong
-import java.util.{HashMap => JHashMap}
-import java.util.Date
+import java.util.{Date, HashMap => JHashMap}
import java.text.SimpleDateFormat
import scala.collection.Map
@@ -15,46 +10,66 @@ import scala.collection.JavaConversions._
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
-import org.apache.hadoop.io.BytesWritable
-import org.apache.hadoop.io.NullWritable
-import org.apache.hadoop.io.Text
-import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapred.FileOutputCommitter
import org.apache.hadoop.mapred.FileOutputFormat
import org.apache.hadoop.mapred.HadoopWriter
import org.apache.hadoop.mapred.JobConf
-import org.apache.hadoop.mapred.OutputCommitter
import org.apache.hadoop.mapred.OutputFormat
-import org.apache.hadoop.mapred.SequenceFileOutputFormat
-import org.apache.hadoop.mapred.TextOutputFormat
import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat}
-import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat}
-import org.apache.hadoop.mapreduce.{RecordWriter => NewRecordWriter}
-import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob}
-import org.apache.hadoop.mapreduce.TaskAttemptID
-import org.apache.hadoop.mapreduce.TaskAttemptContext
+import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter, Job => NewAPIHadoopJob, HadoopMapReduceUtil, TaskAttemptID, TaskAttemptContext}
-import spark.SparkContext._
import spark.partial.BoundedDouble
import spark.partial.PartialResult
+import spark.rdd._
+import spark.SparkContext._
/**
* Extra functions available on RDDs of (key, value) pairs through an implicit conversion.
+ * Import `spark.SparkContext._` at the top of your program to use these functions.
*/
class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
self: RDD[(K, V)])
extends Logging
+ with HadoopMapReduceUtil
with Serializable {
+ /**
+ * Generic function to combine the elements for each key using a custom set of aggregation
+ * functions. Turns an RDD[(K, V)] into a result of type RDD[(K, C)], for a "combined type" C.
+ * Note that V and C can be different -- for example, one might group an RDD of type
+ * (Int, Int) into an RDD of type (Int, Seq[Int]). Users provide three functions:
+ *
+ * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list)
+ * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list)
+ * - `mergeCombiners`, to combine two C's into a single one.
+ *
+ * In addition, users can control the partitioning of the output RDD, and whether to perform
+ * map-side aggregation (if a mapper can produce multiple items with the same key).
+ */
def combineByKey[C](createCombiner: V => C,
mergeValue: (C, V) => C,
mergeCombiners: (C, C) => C,
- partitioner: Partitioner): RDD[(K, C)] = {
- val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
- new ShuffledRDD(self, aggregator, partitioner)
+ partitioner: Partitioner,
+ mapSideCombine: Boolean = true): RDD[(K, C)] = {
+ val aggregator =
+ new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
+ if (mapSideCombine) {
+ val mapSideCombined = self.mapPartitions(aggregator.combineValuesByKey(_), true)
+ val partitioned = new ShuffledRDD[K, C](mapSideCombined, partitioner)
+ partitioned.mapPartitions(aggregator.combineCombinersByKey(_), true)
+ } else {
+ // Don't apply map-side combiner.
+ // A sanity check to make sure mergeCombiners is not defined.
+ assert(mergeCombiners == null)
+ val values = new ShuffledRDD[K, V](self, partitioner)
+ values.mapPartitions(aggregator.combineValuesByKey(_), true)
+ }
}
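(Editor's note: the createCombiner / mergeValue / mergeCombiners contract documented above can be illustrated without Spark. A plain-Scala sketch — a hypothetical helper, not the Aggregator class — that simulates a map-side combine per "partition" followed by a reduce-side merge of the per-partition combiners.)

    object CombineByKeySketch {
      def combineByKey[K, V, C](
          records: Seq[(K, V)],
          createCombiner: V => C,
          mergeValue: (C, V) => C,
          mergeCombiners: (C, C) => C): Map[K, C] = {
        // "Map side": build one combiner per key within a partition
        def combinePartition(part: Seq[(K, V)]): Map[K, C] =
          part.foldLeft(Map.empty[K, C]) { case (acc, (k, v)) =>
            acc + (k -> acc.get(k).map(c => mergeValue(c, v)).getOrElse(createCombiner(v)))
          }
        // "Reduce side": merge the per-partition combiners for each key
        val partitions = records.grouped(math.max(1, records.size / 2)).toSeq
        partitions.map(combinePartition).reduceOption { (m1, m2) =>
          m2.foldLeft(m1) { case (acc, (k, c)) =>
            acc + (k -> acc.get(k).map(mergeCombiners(_, c)).getOrElse(c))
          }
        }.getOrElse(Map.empty[K, C])
      }

      def main(args: Array[String]) {
        val data = Seq(("a", 1), ("b", 2), ("a", 3), ("b", 4), ("a", 5))
        // Group values into lists, the same shape groupByKey produces
        val grouped =
          combineByKey[String, Int, List[Int]](data, v => List(v), (c, v) => v :: c, _ ::: _)
        println(grouped)
      }
    }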
+ /**
+ * Simplified version of combineByKey that hash-partitions the output RDD.
+ */
def combineByKey[C](createCombiner: V => C,
mergeValue: (C, V) => C,
mergeCombiners: (C, C) => C,
@@ -62,10 +77,20 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
combineByKey(createCombiner, mergeValue, mergeCombiners, new HashPartitioner(numSplits))
}
+ /**
+ * Merge the values for each key using an associative reduce function. This will also perform
+ * the merging locally on each mapper before sending results to a reducer, similarly to a
+ * "combiner" in MapReduce.
+ */
def reduceByKey(partitioner: Partitioner, func: (V, V) => V): RDD[(K, V)] = {
combineByKey[V]((v: V) => v, func, func, partitioner)
}
-
+
+ /**
+ * Merge the values for each key using an associative reduce function, but return the results
+ * immediately to the master as a Map. This will also perform the merging locally on each mapper
+ * before sending results to a reducer, similarly to a "combiner" in MapReduce.
+ */
def reduceByKeyLocally(func: (V, V) => V): Map[K, V] = {
def reducePartition(iter: Iterator[(K, V)]): Iterator[JHashMap[K, V]] = {
val map = new JHashMap[K, V]
@@ -87,22 +112,34 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
self.mapPartitions(reducePartition).reduce(mergeMaps)
}
- // Alias for backwards compatibility
+ /** Alias for reduceByKeyLocally */
def reduceByKeyToDriver(func: (V, V) => V): Map[K, V] = reduceByKeyLocally(func)
- // TODO: This should probably be a distributed version
+ /** Count the number of elements for each key, and return the result to the master as a Map. */
def countByKey(): Map[K, Long] = self.map(_._1).countByValue()
- // TODO: This should probably be a distributed version
+ /**
+ * (Experimental) Approximate version of countByKey that can return a partial result if it does
+ * not finish within a timeout.
+ */
def countByKeyApprox(timeout: Long, confidence: Double = 0.95)
: PartialResult[Map[K, BoundedDouble]] = {
self.map(_._1).countByValueApprox(timeout, confidence)
}
+ /**
+ * Merge the values for each key using an associative reduce function. This will also perform
+ * the merging locally on each mapper before sending results to a reducer, similarly to a
+ * "combiner" in MapReduce. Output will be hash-partitioned with numSplits splits.
+ */
def reduceByKey(func: (V, V) => V, numSplits: Int): RDD[(K, V)] = {
reduceByKey(new HashPartitioner(numSplits), func)
}
+ /**
+ * Group the values for each key in the RDD into a single sequence. Allows controlling the
+ * partitioning of the resulting key-value pair RDD by passing a Partitioner.
+ */
def groupByKey(partitioner: Partitioner): RDD[(K, Seq[V])] = {
def createCombiner(v: V) = ArrayBuffer(v)
def mergeValue(buf: ArrayBuffer[V], v: V) = buf += v
@@ -112,19 +149,39 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
bufs.asInstanceOf[RDD[(K, Seq[V])]]
}
+ /**
+ * Group the values for each key in the RDD into a single sequence. Hash-partitions the
+ * resulting RDD into `numSplits` partitions.
+ */
def groupByKey(numSplits: Int): RDD[(K, Seq[V])] = {
groupByKey(new HashPartitioner(numSplits))
}
- def partitionBy(partitioner: Partitioner): RDD[(K, V)] = {
- def createCombiner(v: V) = ArrayBuffer(v)
- def mergeValue(buf: ArrayBuffer[V], v: V) = buf += v
- def mergeCombiners(b1: ArrayBuffer[V], b2: ArrayBuffer[V]) = b1 ++= b2
- val bufs = combineByKey[ArrayBuffer[V]](
- createCombiner _, mergeValue _, mergeCombiners _, partitioner)
- bufs.flatMapValues(buf => buf)
+ /**
+ * Return a copy of the RDD partitioned using the specified partitioner. If `mapSideCombine`
+ * is true, Spark will group values of the same key together on the map side before the
+ * repartitioning, to only send each key over the network once. If a large number of
+ * duplicate keys is expected and the size of the keys is large, `mapSideCombine` should
+ * be set to true.
+ */
+ def partitionBy(partitioner: Partitioner, mapSideCombine: Boolean = false): RDD[(K, V)] = {
+ if (mapSideCombine) {
+ def createCombiner(v: V) = ArrayBuffer(v)
+ def mergeValue(buf: ArrayBuffer[V], v: V) = buf += v
+ def mergeCombiners(b1: ArrayBuffer[V], b2: ArrayBuffer[V]) = b1 ++= b2
+ val bufs = combineByKey[ArrayBuffer[V]](
+ createCombiner _, mergeValue _, mergeCombiners _, partitioner)
+ bufs.flatMapValues(buf => buf)
+ } else {
+ new ShuffledRDD[K, V](self, partitioner)
+ }
}
+ /**
+ * Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each
+ * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and
+ * (k, v2) is in `other`. Uses the given Partitioner to partition the output RDD.
+ */
def join[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (V, W))] = {
this.cogroup(other, partitioner).flatMapValues {
case (vs, ws) =>
@@ -132,6 +189,12 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
}
}
+ /**
+ * Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the
+ * resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the
+ * pair (k, (v, None)) if no elements in `other` have key k. Uses the given Partitioner to
+ * partition the output RDD.
+ */
def leftOuterJoin[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (V, Option[W]))] = {
this.cogroup(other, partitioner).flatMapValues {
case (vs, ws) =>
@@ -143,6 +206,12 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
}
}
+ /**
+ * Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the
+ * resulting RDD will either contain all pairs (k, (Some(v), w)) for v in `this`, or the
+ * pair (k, (None, w)) if no elements in `this` have key k. Uses the given Partitioner to
+ * partition the output RDD.
+ */
def rightOuterJoin[W](other: RDD[(K, W)], partitioner: Partitioner)
: RDD[(K, (Option[V], W))] = {
this.cogroup(other, partitioner).flatMapValues {
@@ -155,56 +224,117 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
}
}
- def combineByKey[C](createCombiner: V => C,
- mergeValue: (C, V) => C,
- mergeCombiners: (C, C) => C) : RDD[(K, C)] = {
+ /**
+ * Simplified version of combineByKey that hash-partitions the resulting RDD using the default
+ * parallelism level.
+ */
+ def combineByKey[C](createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiners: (C, C) => C)
+ : RDD[(K, C)] = {
combineByKey(createCombiner, mergeValue, mergeCombiners, defaultPartitioner(self))
}
+ /**
+ * Merge the values for each key using an associative reduce function. This will also perform
+ * the merging locally on each mapper before sending results to a reducer, similarly to a
+ * "combiner" in MapReduce. Output will be hash-partitioned with the default parallelism level.
+ */
def reduceByKey(func: (V, V) => V): RDD[(K, V)] = {
reduceByKey(defaultPartitioner(self), func)
}
+ /**
+ * Group the values for each key in the RDD into a single sequence. Hash-partitions the
+ * resulting RDD with the default parallelism level.
+ */
def groupByKey(): RDD[(K, Seq[V])] = {
groupByKey(defaultPartitioner(self))
}
+ /**
+ * Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each
+ * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and
+ * (k, v2) is in `other`. Performs a hash join across the cluster.
+ */
def join[W](other: RDD[(K, W)]): RDD[(K, (V, W))] = {
join(other, defaultPartitioner(self, other))
}
+ /**
+ * Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each
+ * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and
+ * (k, v2) is in `other`. Performs a hash join across the cluster.
+ */
def join[W](other: RDD[(K, W)], numSplits: Int): RDD[(K, (V, W))] = {
join(other, new HashPartitioner(numSplits))
}
+ /**
+ * Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the
+ * resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the
+ * pair (k, (v, None)) if no elements in `other` have key k. Hash-partitions the output
+ * using the default level of parallelism.
+ */
def leftOuterJoin[W](other: RDD[(K, W)]): RDD[(K, (V, Option[W]))] = {
leftOuterJoin(other, defaultPartitioner(self, other))
}
+ /**
+ * Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the
+ * resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the
+ * pair (k, (v, None)) if no elements in `other` have key k. Hash-partitions the output
+ * into `numSplits` partitions.
+ */
def leftOuterJoin[W](other: RDD[(K, W)], numSplits: Int): RDD[(K, (V, Option[W]))] = {
leftOuterJoin(other, new HashPartitioner(numSplits))
}
+ /**
+ * Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the
+ * resulting RDD will either contain all pairs (k, (Some(v), w)) for v in `this`, or the
+ * pair (k, (None, w)) if no elements in `this` have key k. Hash-partitions the resulting
+ * RDD using the default parallelism level.
+ */
def rightOuterJoin[W](other: RDD[(K, W)]): RDD[(K, (Option[V], W))] = {
rightOuterJoin(other, defaultPartitioner(self, other))
}
+ /**
+ * Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the
+ * resulting RDD will either contain all pairs (k, (Some(v), w)) for v in `this`, or the
+ * pair (k, (None, w)) if no elements in `this` have key k. Hash-partitions the resulting
+ * RDD into the given number of partitions.
+ */
def rightOuterJoin[W](other: RDD[(K, W)], numSplits: Int): RDD[(K, (Option[V], W))] = {
rightOuterJoin(other, new HashPartitioner(numSplits))
}
+ /**
+ * Return the key-value pairs in this RDD to the master as a Map.
+ */
def collectAsMap(): Map[K, V] = HashMap(self.collect(): _*)
-
+
+ /**
+ * Pass each value in the key-value pair RDD through a map function without changing the keys;
+ * this also retains the original RDD's partitioning.
+ */
def mapValues[U](f: V => U): RDD[(K, U)] = {
val cleanF = self.context.clean(f)
new MappedValuesRDD(self, cleanF)
}
-
+
+ /**
+ * Pass each value in the key-value pair RDD through a flatMap function without changing the
+ * keys; this also retains the original RDD's partitioning.
+ */
def flatMapValues[U](f: V => TraversableOnce[U]): RDD[(K, U)] = {
val cleanF = self.context.clean(f)
new FlatMappedValuesRDD(self, cleanF)
}
-
+
+ /**
+ * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the
+ * list of values for that key in `this` as well as `other`.
+ */
def cogroup[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (Seq[V], Seq[W]))] = {
val cg = new CoGroupedRDD[K](
Seq(self.asInstanceOf[RDD[(_, _)]], other.asInstanceOf[RDD[(_, _)]]),
@@ -215,12 +345,16 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
(vs.asInstanceOf[Seq[V]], ws.asInstanceOf[Seq[W]])
}
}
-
+
+ /**
+ * For each key k in `this` or `other1` or `other2`, return a resulting RDD that contains a
+ * tuple with the list of values for that key in `this`, `other1` and `other2`.
+ */
def cogroup[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)], partitioner: Partitioner)
: RDD[(K, (Seq[V], Seq[W1], Seq[W2]))] = {
val cg = new CoGroupedRDD[K](
Seq(self.asInstanceOf[RDD[(_, _)]],
- other1.asInstanceOf[RDD[(_, _)]],
+ other1.asInstanceOf[RDD[(_, _)]],
other2.asInstanceOf[RDD[(_, _)]]),
partitioner)
val prfs = new PairRDDFunctions[K, Seq[Seq[_]]](cg)(classManifest[K], Manifests.seqSeqManifest)
@@ -230,28 +364,46 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
}
}
+ /**
+ * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the
+ * list of values for that key in `this` as well as `other`.
+ */
def cogroup[W](other: RDD[(K, W)]): RDD[(K, (Seq[V], Seq[W]))] = {
cogroup(other, defaultPartitioner(self, other))
}
+ /**
+ * For each key k in `this` or `other1` or `other2`, return a resulting RDD that contains a
+ * tuple with the list of values for that key in `this`, `other1` and `other2`.
+ */
def cogroup[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)])
: RDD[(K, (Seq[V], Seq[W1], Seq[W2]))] = {
cogroup(other1, other2, defaultPartitioner(self, other1, other2))
}
+ /**
+ * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the
+ * list of values for that key in `this` as well as `other`.
+ */
def cogroup[W](other: RDD[(K, W)], numSplits: Int): RDD[(K, (Seq[V], Seq[W]))] = {
cogroup(other, new HashPartitioner(numSplits))
}
+ /**
+ * For each key k in `this` or `other1` or `other2`, return a resulting RDD that contains a
+ * tuple with the list of values for that key in `this`, `other1` and `other2`.
+ */
def cogroup[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)], numSplits: Int)
: RDD[(K, (Seq[V], Seq[W1], Seq[W2]))] = {
cogroup(other1, other2, new HashPartitioner(numSplits))
}
+ /** Alias for cogroup. */
def groupWith[W](other: RDD[(K, W)]): RDD[(K, (Seq[V], Seq[W]))] = {
cogroup(other, defaultPartitioner(self, other))
}
+ /** Alias for cogroup. */
def groupWith[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)])
: RDD[(K, (Seq[V], Seq[W1], Seq[W2]))] = {
cogroup(other1, other2, defaultPartitioner(self, other1, other2))
@@ -268,6 +420,10 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
return new HashPartitioner(self.context.defaultParallelism)
}
+ /**
+ * Return the list of values in the RDD for key `key`. This operation is done efficiently if the
+ * RDD has a known partitioner by only searching the partition that the key maps to.
+ */
def lookup(key: K): Seq[V] = {
self.partitioner match {
case Some(p) =>
@@ -286,14 +442,26 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
}
}
+ /**
+ * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class
+ * supporting the key and value types K and V in this RDD.
+ */
def saveAsHadoopFile[F <: OutputFormat[K, V]](path: String)(implicit fm: ClassManifest[F]) {
saveAsHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]])
}
-
+
+ /**
+ * Output the RDD to any Hadoop-supported file system, using a new Hadoop API `OutputFormat`
+ * (mapreduce.OutputFormat) object supporting the key and value types K and V in this RDD.
+ */
def saveAsNewAPIHadoopFile[F <: NewOutputFormat[K, V]](path: String)(implicit fm: ClassManifest[F]) {
saveAsNewAPIHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]])
}
+ /**
+ * Output the RDD to any Hadoop-supported file system, using a new Hadoop API `OutputFormat`
+ * (mapreduce.OutputFormat) object supporting the key and value types K and V in this RDD.
+ */
def saveAsNewAPIHadoopFile(
path: String,
keyClass: Class[_],
@@ -302,6 +470,10 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
saveAsNewAPIHadoopFile(path, keyClass, valueClass, outputFormatClass, new Configuration)
}
+ /**
+ * Output the RDD to any Hadoop-supported file system, using a new Hadoop API `OutputFormat`
+ * (mapreduce.OutputFormat) object supporting the key and value types K and V in this RDD.
+ */
def saveAsNewAPIHadoopFile(
path: String,
keyClass: Class[_],
@@ -323,7 +495,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
/* "reduce task" <split #> <attempt # = spark task #> */
val attemptId = new TaskAttemptID(jobtrackerID,
stageId, false, context.splitId, attemptNumber)
- val hadoopContext = new TaskAttemptContext(wrappedConf.value, attemptId)
+ val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId)
val format = outputFormatClass.newInstance
val committer = format.getOutputCommitter(hadoopContext)
committer.setupTask(hadoopContext)
@@ -342,13 +514,17 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
* setupJob/commitJob, so we just use a dummy "map" task.
*/
val jobAttemptId = new TaskAttemptID(jobtrackerID, stageId, true, 0, 0)
- val jobTaskContext = new TaskAttemptContext(wrappedConf.value, jobAttemptId)
+ val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId)
val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext)
jobCommitter.setupJob(jobTaskContext)
val count = self.context.runJob(self, writeShard _).sum
jobCommitter.cleanupJob(jobTaskContext)
}
+ /**
+ * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class
+ * supporting the key and value types K and V in this RDD.
+ */
def saveAsHadoopFile(
path: String,
keyClass: Class[_],
@@ -363,7 +539,13 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
FileOutputFormat.setOutputPath(conf, HadoopWriter.createPathFromString(path, conf))
saveAsHadoopDataset(conf)
}
-
+
+ /**
+ * Output the RDD to any Hadoop-supported storage system, using a Hadoop JobConf object for
+ * that storage system. The JobConf should set an OutputFormat and any output paths required
+ * (e.g. a table name to write to) in the same way as it would be configured for a Hadoop
+ * MapReduce job.
+ */
def saveAsHadoopDataset(conf: JobConf) {
val outputFormatClass = conf.getOutputFormat
val keyClass = conf.getOutputKeyClass
@@ -377,7 +559,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
if (valueClass == null) {
throw new SparkException("Output value class not set")
}
-
+
logInfo("Saving as hadoop file of type (" + keyClass.getSimpleName+ ", " + valueClass.getSimpleName+ ")")
val writer = new HadoopWriter(conf)
@@ -390,14 +572,14 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
writer.setup(context.stageId, context.splitId, attemptNumber)
writer.open()
-
+
var count = 0
while(iter.hasNext) {
val record = iter.next
count += 1
writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef])
}
-
+
writer.close()
writer.commit()
}
@@ -406,35 +588,42 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
writer.cleanup()
}
- def getKeyClass() = implicitly[ClassManifest[K]].erasure
+ private[spark] def getKeyClass() = implicitly[ClassManifest[K]].erasure
- def getValueClass() = implicitly[ClassManifest[V]].erasure
+ private[spark] def getValueClass() = implicitly[ClassManifest[V]].erasure
}
+/**
+ * Extra functions available on RDDs of (key, value) pairs where the key is sortable through
+ * an implicit conversion. Import `spark.SparkContext._` at the top of your program to use these
+ * functions. They will work with any key type that has a `scala.math.Ordered` implementation.
+ */
class OrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest](
self: RDD[(K, V)])
- extends Logging
+ extends Logging
with Serializable {
- def sortByKey(ascending: Boolean = true): RDD[(K,V)] = {
- val rangePartitionedRDD = self.partitionBy(new RangePartitioner(self.splits.size, self, ascending))
- new SortedRDD(rangePartitionedRDD, ascending)
+ /**
+ * Sort the RDD by key, so that each partition contains a sorted range of the elements. Calling
+ * `collect` or `save` on the resulting RDD will return or output an ordered list of records
+ * (in the `save` case, they will be written to multiple `part-X` files in the filesystem, in
+ * order of the keys).
+ */
+ def sortByKey(ascending: Boolean = true, numSplits: Int = self.splits.size): RDD[(K,V)] = {
+ val shuffled =
+ new ShuffledRDD[K, V](self, new RangePartitioner(numSplits, self, ascending))
+ shuffled.mapPartitions(iter => {
+ val buf = iter.toArray
+ if (ascending) {
+ buf.sortWith((x, y) => x._1 < y._1).iterator
+ } else {
+ buf.sortWith((x, y) => x._1 > y._1).iterator
+ }
+ }, true)
}
}
-class SortedRDD[K <% Ordered[K], V](prev: RDD[(K, V)], ascending: Boolean)
- extends RDD[(K, V)](prev.context) {
-
- override def splits = prev.splits
- override val partitioner = prev.partitioner
- override val dependencies = List(new OneToOneDependency(prev))
-
- override def compute(split: Split) = {
- prev.iterator(split).toArray
- .sortWith((x, y) => if (ascending) x._1 < y._1 else x._1 > y._1).iterator
- }
-}
-
+private[spark]
class MappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => U) extends RDD[(K, U)](prev.context) {
override def splits = prev.splits
override val dependencies = List(new OneToOneDependency(prev))
@@ -442,9 +631,10 @@ class MappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => U) extends RDD[(K, U)]
override def compute(split: Split) = prev.iterator(split).map{case (k, v) => (k, f(v))}
}
+private[spark]
class FlatMappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => TraversableOnce[U])
extends RDD[(K, U)](prev.context) {
-
+
override def splits = prev.splits
override val dependencies = List(new OneToOneDependency(prev))
override val partitioner = prev.partitioner
@@ -454,6 +644,6 @@ class FlatMappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => TraversableOnce[U]
}
}
-object Manifests {
+private[spark] object Manifests {
val seqSeqManifest = classManifest[Seq[Seq[_]]]
}
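(Editor's note: the new sortByKey relies on the idea that range-partitioning by key and then sorting within each partition yields a globally ordered result when partitions are read in order. A toy, Spark-free sketch of that shape — the bound selection here is deliberately simplistic and only stands in for RangePartitioner.)

    object SortByKeySketch {
      def sortByKey[V](records: Seq[(Int, V)], numSplits: Int): Seq[(Int, V)] = {
        // Pick range bounds from the (here: full) key sample
        val sampleKeys = records.map(_._1).sorted
        val bounds =
          if (sampleKeys.isEmpty) Array.empty[Int]
          else (1 until numSplits).map(i => sampleKeys((i * sampleKeys.size) / numSplits)).toArray
        def partitionOf(key: Int): Int = bounds.count(_ <= key)
        // Bucket by key range, sort each bucket, concatenate buckets in order
        val buckets = records.groupBy { case (k, _) => partitionOf(k) }
        (0 until numSplits).flatMap(p => buckets.getOrElse(p, Seq.empty).sortBy(_._1))
      }

      def main(args: Array[String]) {
        val data = Seq((5, "e"), (1, "a"), (9, "i"), (3, "c"), (7, "g"))
        println(sortByKey(data, numSplits = 2)) // (1,a), (3,c), (5,e), (7,g), (9,i)
      }
    }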
diff --git a/core/src/main/scala/spark/ParallelCollection.scala b/core/src/main/scala/spark/ParallelCollection.scala
index d79007ab40..9b57ae3b4f 100644
--- a/core/src/main/scala/spark/ParallelCollection.scala
+++ b/core/src/main/scala/spark/ParallelCollection.scala
@@ -3,7 +3,7 @@ package spark
import scala.collection.immutable.NumericRange
import scala.collection.mutable.ArrayBuffer
-class ParallelCollectionSplit[T: ClassManifest](
+private[spark] class ParallelCollectionSplit[T: ClassManifest](
val rddId: Long,
val slice: Int,
values: Seq[T])
@@ -21,7 +21,7 @@ class ParallelCollectionSplit[T: ClassManifest](
override val index: Int = slice
}
-class ParallelCollection[T: ClassManifest](
+private[spark] class ParallelCollection[T: ClassManifest](
sc: SparkContext,
@transient data: Seq[T],
numSlices: Int)
diff --git a/core/src/main/scala/spark/Partitioner.scala b/core/src/main/scala/spark/Partitioner.scala
index 643541429f..b71021a082 100644
--- a/core/src/main/scala/spark/Partitioner.scala
+++ b/core/src/main/scala/spark/Partitioner.scala
@@ -1,10 +1,17 @@
package spark
+/**
+ * An object that defines how the elements in a key-value pair RDD are partitioned by key.
+ * Maps each key to a partition ID, from 0 to `numPartitions - 1`.
+ */
abstract class Partitioner extends Serializable {
def numPartitions: Int
def getPartition(key: Any): Int
}
+/**
+ * A [[spark.Partitioner]] that implements hash-based partitioning using Java's `Object.hashCode`.
+ */
class HashPartitioner(partitions: Int) extends Partitioner {
def numPartitions = partitions
@@ -29,6 +36,10 @@ class HashPartitioner(partitions: Int) extends Partitioner {
}
}
+/**
+ * A [[spark.Partitioner]] that partitions sortable records by range into roughly equal ranges.
+ * Determines the ranges by sampling the RDD passed in.
+ */
class RangePartitioner[K <% Ordered[K]: ClassManifest, V](
partitions: Int,
@transient rdd: RDD[(K,V)],
@@ -41,9 +52,9 @@ class RangePartitioner[K <% Ordered[K]: ClassManifest, V](
Array()
} else {
val rddSize = rdd.count()
- val maxSampleSize = partitions * 10.0
+ val maxSampleSize = partitions * 20.0
val frac = math.min(maxSampleSize / math.max(rddSize, 1), 1.0)
- val rddSample = rdd.sample(true, frac, 1).map(_._1).collect().sortWith(_ < _)
+ val rddSample = rdd.sample(false, frac, 1).map(_._1).collect().sortWith(_ < _)
if (rddSample.length == 0) {
Array()
} else {
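(Editor's note: the HashPartitioner described in the new scaladoc maps a key to `key.hashCode % numPartitions`, with the modulus made non-negative since hashCode can be negative. A minimal sketch of that idea — an illustrative class, not Spark's implementation.)

    class SimpleHashPartitioner(partitions: Int) {
      def getPartition(key: Any): Int = {
        val rawMod = key.hashCode % partitions
        if (rawMod < 0) rawMod + partitions else rawMod  // hashCode may be negative
      }
    }

    object SimpleHashPartitionerDemo {
      def main(args: Array[String]) {
        val p = new SimpleHashPartitioner(4)
        Seq("alpha", "beta", -17, 42).foreach(k => println("" + k + " -> " + p.getPartition(k)))
      }
    }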
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index d28f3593fe..338dff4061 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -31,51 +31,86 @@ import spark.partial.BoundedDouble
import spark.partial.CountEvaluator
import spark.partial.GroupedCountEvaluator
import spark.partial.PartialResult
+import spark.rdd.BlockRDD
+import spark.rdd.CartesianRDD
+import spark.rdd.FilteredRDD
+import spark.rdd.FlatMappedRDD
+import spark.rdd.GlommedRDD
+import spark.rdd.MappedRDD
+import spark.rdd.MapPartitionsRDD
+import spark.rdd.MapPartitionsWithSplitRDD
+import spark.rdd.PipedRDD
+import spark.rdd.SampledRDD
+import spark.rdd.UnionRDD
import spark.storage.StorageLevel
import SparkContext._
/**
* A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Represents an immutable,
- * partitioned collection of elements that can be operated on in parallel.
+ * partitioned collection of elements that can be operated on in parallel. This class contains the
+ * basic operations available on all RDDs, such as `map`, `filter`, and `persist`. In addition,
+ * [[spark.PairRDDFunctions]] contains operations available only on RDDs of key-value pairs, such
+ * as `groupByKey` and `join`; [[spark.DoubleRDDFunctions]] contains operations available only on
+ * RDDs of Doubles; and [[spark.SequenceFileRDDFunctions]] contains operations available on RDDs
+ * that can be saved as SequenceFiles. These operations are automatically available on any RDD of
+ * the right type (e.g. RDD[(Int, Int)] through implicit conversions when you
+ * `import spark.SparkContext._`.
*
- * Each RDD is characterized by five main properties:
- * - A list of splits (partitions)
- * - A function for computing each split
- * - A list of dependencies on other RDDs
- * - Optionally, a Partitioner for key-value RDDs (e.g. to say that the RDD is hash-partitioned)
- * - Optionally, a list of preferred locations to compute each split on (e.g. block locations for
- * HDFS)
+ * Internally, each RDD is characterized by five main properties:
*
- * All the scheduling and execution in Spark is done based on these methods, allowing each RDD to
- * implement its own way of computing itself.
+ * - A list of splits (partitions)
+ * - A function for computing each split
+ * - A list of dependencies on other RDDs
+ * - Optionally, a Partitioner for key-value RDDs (e.g. to say that the RDD is hash-partitioned)
+ * - Optionally, a list of preferred locations to compute each split on (e.g. block locations for
+ * an HDFS file)
*
- * This class also contains transformation methods available on all RDDs (e.g. map and filter). In
- * addition, PairRDDFunctions contains extra methods available on RDDs of key-value pairs, and
- * SequenceFileRDDFunctions contains extra methods for saving RDDs to Hadoop SequenceFiles.
+ * All of the scheduling and execution in Spark is done based on these methods, allowing each RDD
+ * to implement its own way of computing itself. Indeed, users can implement custom RDDs (e.g. for
+ * reading data from a new storage system) by overriding these functions. Please refer to the
+ * [[http://www.cs.berkeley.edu/~matei/papers/2012/nsdi_spark.pdf Spark paper]] for more details
+ * on RDD internals.
*/
abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serializable {
- // Methods that must be implemented by subclasses
+ // Methods that must be implemented by subclasses:
+
+ /** Set of partitions in this RDD. */
def splits: Array[Split]
+
+ /** Function for computing a given partition. */
def compute(split: Split): Iterator[T]
+
+ /** How this RDD depends on any parent RDDs. */
@transient val dependencies: List[Dependency[_]]
+
+ // Methods available on all RDDs:
+
+ /** Record user function generating this RDD. */
+ private[spark] val origin = Utils.getSparkCallSite
- // Optionally overridden by subclasses to specify how they are partitioned
+ /** Optionally overridden by subclasses to specify how they are partitioned. */
val partitioner: Option[Partitioner] = None
- // Optionally overridden by subclasses to specify placement preferences
+ /** Optionally overridden by subclasses to specify placement preferences. */
def preferredLocations(split: Split): Seq[String] = Nil
+ /** The [[spark.SparkContext]] that this RDD was created on. */
def context = sc
+
+ private[spark] def elementClassManifest: ClassManifest[T] = classManifest[T]
- // Get a unique ID for this RDD
+ /** A unique ID for this RDD (within its SparkContext). */
val id = sc.newRddId()
// Variables relating to persistence
private var storageLevel: StorageLevel = StorageLevel.NONE
- // Change this RDD's storage level
+ /**
+ * Set this RDD's storage level to persist its values across operations after the first time
+ * it is computed. Can only be called once on each RDD.
+ */
def persist(newLevel: StorageLevel): RDD[T] = {
// TODO: Handle changes of StorageLevel
if (storageLevel != StorageLevel.NONE && newLevel != storageLevel) {
@@ -86,22 +121,23 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
this
}
- // Turn on the default caching level for this RDD
- def persist(): RDD[T] = persist(StorageLevel.MEMORY_ONLY_DESER)
+ /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
+ def persist(): RDD[T] = persist(StorageLevel.MEMORY_ONLY)
- // Turn on the default caching level for this RDD
+ /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
def cache(): RDD[T] = persist()
+ /** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */
def getStorageLevel = storageLevel
- def checkpoint(level: StorageLevel = StorageLevel.DISK_AND_MEMORY_DESER_2): RDD[T] = {
+ private[spark] def checkpoint(level: StorageLevel = StorageLevel.MEMORY_AND_DISK_2): RDD[T] = {
if (!level.useDisk && level.replication < 2) {
throw new Exception("Cannot checkpoint without using disk or replication (level requested was " + level + ")")
}
// This is a hack. Ideally this should re-use the code used by the CacheTracker
// to generate the key.
- def getSplitKey(split: Split) = "rdd:%d:%d".format(this.id, split.index)
+ def getSplitKey(split: Split) = "rdd_%d_%d".format(this.id, split.index)
persist(level)
sc.runJob(this, (iter: Iterator[T]) => {} )
@@ -113,7 +149,11 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
}
}
- // Read this RDD; will read from cache if applicable, or otherwise compute
+ /**
+ * Internal method to this RDD; will read from cache if applicable, or otherwise compute it.
+ * This should ''not'' be called by users directly, but is available for implementors of custom
+ * subclasses of RDD.
+ */
final def iterator(split: Split): Iterator[T] = {
if (storageLevel != StorageLevel.NONE) {
SparkEnv.get.cacheTracker.getOrCompute[T](this, split, storageLevel)
@@ -124,15 +164,32 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
// Transformations (return a new RDD)
+ /**
+ * Return a new RDD by applying a function to all elements of this RDD.
+ */
def map[U: ClassManifest](f: T => U): RDD[U] = new MappedRDD(this, sc.clean(f))
-
+
+ /**
+ * Return a new RDD by first applying a function to all elements of this
+ * RDD, and then flattening the results.
+ */
def flatMap[U: ClassManifest](f: T => TraversableOnce[U]): RDD[U] =
new FlatMappedRDD(this, sc.clean(f))
-
+
+ /**
+ * Return a new RDD containing only the elements that satisfy a predicate.
+ */
def filter(f: T => Boolean): RDD[T] = new FilteredRDD(this, sc.clean(f))
- def distinct(): RDD[T] = map(x => (x, "")).reduceByKey((x, y) => x).map(_._1)
+ /**
+ * Return a new RDD containing the distinct elements in this RDD.
+ */
+ def distinct(numSplits: Int = splits.size): RDD[T] =
+ map(x => (x, null)).reduceByKey((x, y) => x, numSplits).map(_._1)
+ /**
+ * Return a sampled subset of this RDD.
+ */
def sample(withReplacement: Boolean, fraction: Double, seed: Int): RDD[T] =
new SampledRDD(this, withReplacement, fraction, seed)
@@ -143,8 +200,8 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
var initialCount = count()
var maxSelected = 0
- if (initialCount > Integer.MAX_VALUE) {
- maxSelected = Integer.MAX_VALUE
+ if (initialCount > Integer.MAX_VALUE - 1) {
+ maxSelected = Integer.MAX_VALUE - 1
} else {
maxSelected = initialCount.toInt
}
@@ -159,56 +216,109 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
total = num
}
- var samples = this.sample(withReplacement, fraction, seed).collect()
+ val rand = new Random(seed)
+ var samples = this.sample(withReplacement, fraction, rand.nextInt).collect()
while (samples.length < total) {
- samples = this.sample(withReplacement, fraction, seed).collect()
+ samples = this.sample(withReplacement, fraction, rand.nextInt).collect()
}
- val arr = samples.take(total)
-
- return arr
+ Utils.randomizeInPlace(samples, rand).take(total)
}
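(Editor's note: takeSample now reshuffles the oversampled array before taking `total` elements. A sketch of the Fisher-Yates shuffle that a helper like `Utils.randomizeInPlace` is presumed to perform — an assumption, not code copied from Spark's Utils.)

    import java.util.Random

    object RandomizeInPlaceSketch {
      def randomizeInPlace[T](arr: Array[T], rand: Random): Array[T] = {
        var i = arr.length - 1
        while (i > 0) {
          val j = rand.nextInt(i + 1)            // pick an index in [0, i]
          val tmp = arr(i); arr(i) = arr(j); arr(j) = tmp
          i -= 1
        }
        arr
      }

      def main(args: Array[String]) {
        val arr = Array(1, 2, 3, 4, 5)
        println(randomizeInPlace(arr, new Random(42)).mkString(", "))
      }
    }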
+ /**
+ * Return the union of this RDD and another one. Any identical elements will appear multiple
+ * times (use `.distinct()` to eliminate them).
+ */
def union(other: RDD[T]): RDD[T] = new UnionRDD(sc, Array(this, other))
+ /**
+ * Return the union of this RDD and another one. Any identical elements will appear multiple
+ * times (use `.distinct()` to eliminate them).
+ */
def ++(other: RDD[T]): RDD[T] = this.union(other)
+ /**
+ * Return an RDD created by coalescing all elements within each partition into an array.
+ */
def glom(): RDD[Array[T]] = new GlommedRDD(this)
+ /**
+ * Return the Cartesian product of this RDD and another one, that is, the RDD of all pairs of
+ * elements (a, b) where a is in `this` and b is in `other`.
+ */
def cartesian[U: ClassManifest](other: RDD[U]): RDD[(T, U)] = new CartesianRDD(sc, this, other)
+ /**
+ * Return an RDD of grouped elements. Each group consists of a key and a sequence of elements
+ * mapping to that key.
+ */
def groupBy[K: ClassManifest](f: T => K, numSplits: Int): RDD[(K, Seq[T])] = {
val cleanF = sc.clean(f)
this.map(t => (cleanF(t), t)).groupByKey(numSplits)
}
+ /**
+ * Return an RDD of grouped items.
+ */
def groupBy[K: ClassManifest](f: T => K): RDD[(K, Seq[T])] = groupBy[K](f, sc.defaultParallelism)
+ /**
+ * Return an RDD created by piping elements to a forked external process.
+ */
def pipe(command: String): RDD[String] = new PipedRDD(this, command)
+ /**
+ * Return an RDD created by piping elements to a forked external process.
+ */
def pipe(command: Seq[String]): RDD[String] = new PipedRDD(this, command)
+ /**
+ * Return an RDD created by piping elements to a forked external process.
+ */
def pipe(command: Seq[String], env: Map[String, String]): RDD[String] =
new PipedRDD(this, command, env)
- def mapPartitions[U: ClassManifest](f: Iterator[T] => Iterator[U]): RDD[U] =
- new MapPartitionsRDD(this, sc.clean(f))
+ /**
+ * Return a new RDD by applying a function to each partition of this RDD.
+ */
+ def mapPartitions[U: ClassManifest](f: Iterator[T] => Iterator[U],
+ preservesPartitioning: Boolean = false): RDD[U] =
+ new MapPartitionsRDD(this, sc.clean(f), preservesPartitioning)
+
+ /**
+ * Return a new RDD by applying a function to each partition of this RDD, while tracking the index
+ * of the original partition.
+ */
+ def mapPartitionsWithSplit[U: ClassManifest](f: (Int, Iterator[T]) => Iterator[U]): RDD[U] =
+ new MapPartitionsWithSplitRDD(this, sc.clean(f))
// Actions (launch a job to return a value to the user program)
-
+
+ /**
+ * Applies a function f to all elements of this RDD.
+ */
def foreach(f: T => Unit) {
val cleanF = sc.clean(f)
sc.runJob(this, (iter: Iterator[T]) => iter.foreach(cleanF))
}
+ /**
+ * Return an array that contains all of the elements in this RDD.
+ */
def collect(): Array[T] = {
val results = sc.runJob(this, (iter: Iterator[T]) => iter.toArray)
Array.concat(results: _*)
}
+ /**
+ * Return an array that contains all of the elements in this RDD.
+ */
def toArray(): Array[T] = collect()
+ /**
+ * Reduces the elements of this RDD using the specified associative binary operator.
+ */
def reduce(f: (T, T) => T): T = {
val cleanF = sc.clean(f)
val reducePartition: Iterator[T] => Option[T] = iter => {
@@ -257,7 +367,10 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
(iter: Iterator[T]) => iter.aggregate(zeroValue)(cleanSeqOp, cleanCombOp))
return results.fold(zeroValue)(cleanCombOp)
}
-
+
+ /**
+ * Return the number of elements in the RDD.
+ */
def count(): Long = {
sc.runJob(this, (iter: Iterator[T]) => {
var result = 0L
@@ -270,7 +383,8 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
}
/**
- * Approximate version of count() that returns a potentially incomplete result after a timeout.
+ * (Experimental) Approximate version of count() that returns a potentially incomplete result
+ * within a timeout, even if not all tasks have finished.
*/
def countApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = {
val countElements: (TaskContext, Iterator[T]) => Long = { (ctx, iter) =>
@@ -286,12 +400,11 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
}
/**
- * Count elements equal to each value, returning a map of (value, count) pairs. The final combine
- * step happens locally on the master, equivalent to running a single reduce task.
- *
- * TODO: This should perhaps be distributed by default.
+ * Return the count of each unique value in this RDD as a map of (value, count) pairs. The final
+ * combine step happens locally on the master, equivalent to running a single reduce task.
*/
def countByValue(): Map[T, Long] = {
+ // TODO: This should perhaps be distributed by default.
def countPartition(iter: Iterator[T]): Iterator[OLMap[T]] = {
val map = new OLMap[T]
while (iter.hasNext) {
@@ -313,7 +426,7 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
}
/**
- * Approximate version of countByValue().
+ * (Experimental) Approximate version of countByValue().
*/
def countByValueApprox(
timeout: Long,
@@ -353,18 +466,27 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
return buf.toArray
}
+ /**
+ * Return the first element in this RDD.
+ */
def first(): T = take(1) match {
case Array(t) => t
case _ => throw new UnsupportedOperationException("empty collection")
}
+ /**
+ * Save this RDD as a text file, using string representations of elements.
+ */
def saveAsTextFile(path: String) {
this.map(x => (NullWritable.get(), new Text(x.toString)))
.saveAsHadoopFile[TextOutputFormat[NullWritable, Text]](path)
}
+ /**
+ * Save this RDD as a SequenceFile of serialized objects.
+ */
def saveAsObjectFile(path: String) {
- this.glom
+ this.mapPartitions(iter => iter.grouped(10).map(_.toArray))
.map(x => (NullWritable.get(), new BytesWritable(Utils.serialize(x))))
.saveAsSequenceFile(path)
}
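(Editor's note: saveAsObjectFile now batches each partition into arrays of up to 10 elements before serializing, instead of one array per partition. A plain-Scala illustration of that batching step using `Iterator.grouped`.)

    object GroupedBatchingSketch {
      def main(args: Array[String]) {
        // Each chunk holds at most 10 elements, so each serialized record stays small
        val batches = (1 to 25).iterator.grouped(10).map(_.toArray).toList
        batches.foreach(b => println(b.mkString("[", ",", "]")))
        // [1,...,10]  [11,...,20]  [21,...,25]
      }
    }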
@@ -374,45 +496,3 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
sc.runJob(this, (iter: Iterator[T]) => iter.toArray)
}
}
-
-class MappedRDD[U: ClassManifest, T: ClassManifest](
- prev: RDD[T],
- f: T => U)
- extends RDD[U](prev.context) {
-
- override def splits = prev.splits
- override val dependencies = List(new OneToOneDependency(prev))
- override def compute(split: Split) = prev.iterator(split).map(f)
-}
-
-class FlatMappedRDD[U: ClassManifest, T: ClassManifest](
- prev: RDD[T],
- f: T => TraversableOnce[U])
- extends RDD[U](prev.context) {
-
- override def splits = prev.splits
- override val dependencies = List(new OneToOneDependency(prev))
- override def compute(split: Split) = prev.iterator(split).flatMap(f)
-}
-
-class FilteredRDD[T: ClassManifest](prev: RDD[T], f: T => Boolean) extends RDD[T](prev.context) {
- override def splits = prev.splits
- override val dependencies = List(new OneToOneDependency(prev))
- override def compute(split: Split) = prev.iterator(split).filter(f)
-}
-
-class GlommedRDD[T: ClassManifest](prev: RDD[T]) extends RDD[Array[T]](prev.context) {
- override def splits = prev.splits
- override val dependencies = List(new OneToOneDependency(prev))
- override def compute(split: Split) = Array(prev.iterator(split).toArray).iterator
-}
-
-class MapPartitionsRDD[U: ClassManifest, T: ClassManifest](
- prev: RDD[T],
- f: Iterator[T] => Iterator[U])
- extends RDD[U](prev.context) {
-
- override def splits = prev.splits
- override val dependencies = List(new OneToOneDependency(prev))
- override def compute(split: Split) = f(prev.iterator(split))
-}
diff --git a/core/src/main/scala/spark/SequenceFileRDDFunctions.scala b/core/src/main/scala/spark/SequenceFileRDDFunctions.scala
index ea7171d3a1..a34aee69c1 100644
--- a/core/src/main/scala/spark/SequenceFileRDDFunctions.scala
+++ b/core/src/main/scala/spark/SequenceFileRDDFunctions.scala
@@ -23,19 +23,21 @@ import org.apache.hadoop.io.NullWritable
import org.apache.hadoop.io.BytesWritable
import org.apache.hadoop.io.Text
-import SparkContext._
+import spark.SparkContext._
/**
* Extra functions available on RDDs of (key, value) pairs to create a Hadoop SequenceFile,
* through an implicit conversion. Note that this can't be part of PairRDDFunctions because
* we need more implicit parameters to convert our keys and values to Writable.
+ *
+ * Users should import `spark.SparkContext._` at the top of their program to use these functions.
*/
class SequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable : ClassManifest](
self: RDD[(K, V)])
extends Logging
with Serializable {
- def getWritableClass[T <% Writable: ClassManifest](): Class[_ <: Writable] = {
+ private def getWritableClass[T <% Writable: ClassManifest](): Class[_ <: Writable] = {
val c = {
if (classOf[Writable].isAssignableFrom(classManifest[T].erasure)) {
classManifest[T].erasure
@@ -47,6 +49,13 @@ class SequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable : Cla
c.asInstanceOf[Class[_ <: Writable]]
}
+ /**
+ * Output the RDD as a Hadoop SequenceFile using the Writable types we infer from the RDD's key
+ * and value types. If the key or value are Writable, then we use their classes directly;
+ * otherwise we map primitive types such as Int and Double to IntWritable, DoubleWritable, etc,
+ * byte arrays to BytesWritable, and Strings to Text. The `path` can be on any Hadoop-supported
+ * file system.
+ */
def saveAsSequenceFile(path: String) {
def anyToWritable[U <% Writable](u: U): Writable = u
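(Editor's note: the `<% Writable` view bounds in this class mean a plain Scala key or value is implicitly converted whenever a Writable is required. A pure-Scala analogue of that mechanism with a made-up `Wrapped` trait — nothing below is Hadoop API.)

    object ViewBoundSketch {
      trait Wrapped { def bytes: Array[Byte] }

      // Implicit view: an Int can be treated as a Wrapped where one is expected
      implicit def intToWrapped(i: Int): Wrapped = new Wrapped {
        def bytes = BigInt(i).toByteArray
      }

      def save[U <% Wrapped](u: U): Int = u.bytes.length  // the view is applied here

      def main(args: Array[String]) {
        println(save(123456))  // Int accepted because an implicit view to Wrapped exists
      }
    }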
diff --git a/core/src/main/scala/spark/ShuffleFetcher.scala b/core/src/main/scala/spark/ShuffleFetcher.scala
index 4f8d98f7d0..d9a94d4021 100644
--- a/core/src/main/scala/spark/ShuffleFetcher.scala
+++ b/core/src/main/scala/spark/ShuffleFetcher.scala
@@ -1,10 +1,12 @@
package spark
-abstract class ShuffleFetcher {
- // Fetch the shuffle outputs for a given ShuffleDependency, calling func exactly
- // once on each key-value pair obtained.
- def fetch[K, V](shuffleId: Int, reduceId: Int, func: (K, V) => Unit)
+private[spark] abstract class ShuffleFetcher {
+ /**
+ * Fetch the shuffle outputs for a given ShuffleDependency.
+ * @return An iterator over the elements of the fetched shuffle outputs.
+ */
+ def fetch[K, V](shuffleId: Int, reduceId: Int) : Iterator[(K, V)]
- // Stop the fetcher
+ /** Stop the fetcher */
def stop() {}
}
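(Editor's note: the fetch interface now returns an iterator rather than invoking a callback per key-value pair, so callers can transform the fetched stream lazily. A toy sketch of that shape — a hypothetical fetcher, not BlockStoreShuffleFetcher.)

    object IteratorFetchSketch {
      def fetch(reduceId: Int): Iterator[(String, Int)] =
        Iterator(("a", 1), ("b", 2), ("a", 3))   // stand-in for remotely fetched blocks

      def main(args: Array[String]) {
        // Downstream code can wrap the stream without materializing it
        val doubled = fetch(0).map { case (k, v) => (k, v * 2) }
        doubled.foreach(println)
      }
    }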
diff --git a/core/src/main/scala/spark/ShuffleManager.scala b/core/src/main/scala/spark/ShuffleManager.scala
deleted file mode 100644
index 24af7f3a08..0000000000
--- a/core/src/main/scala/spark/ShuffleManager.scala
+++ /dev/null
@@ -1,98 +0,0 @@
-package spark
-
-import java.io._
-import java.net.URL
-import java.util.UUID
-import java.util.concurrent.atomic.AtomicLong
-
-import scala.collection.mutable.{ArrayBuffer, HashMap}
-
-import spark._
-
-class ShuffleManager extends Logging {
- private var nextShuffleId = new AtomicLong(0)
-
- private var shuffleDir: File = null
- private var server: HttpServer = null
- private var serverUri: String = null
-
- initialize()
-
- private def initialize() {
- // TODO: localDir should be created by some mechanism common to Spark
- // so that it can be shared among shuffle, broadcast, etc
- val localDirRoot = System.getProperty("spark.local.dir", "/tmp")
- var tries = 0
- var foundLocalDir = false
- var localDir: File = null
- var localDirUuid: UUID = null
- while (!foundLocalDir && tries < 10) {
- tries += 1
- try {
- localDirUuid = UUID.randomUUID
- localDir = new File(localDirRoot, "spark-local-" + localDirUuid)
- if (!localDir.exists) {
- localDir.mkdirs()
- foundLocalDir = true
- }
- } catch {
- case e: Exception =>
- logWarning("Attempt " + tries + " to create local dir failed", e)
- }
- }
- if (!foundLocalDir) {
- logError("Failed 10 attempts to create local dir in " + localDirRoot)
- System.exit(1)
- }
- shuffleDir = new File(localDir, "shuffle")
- shuffleDir.mkdirs()
- logInfo("Shuffle dir: " + shuffleDir)
-
- // Add a shutdown hook to delete the local dir
- Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dir") {
- override def run() {
- Utils.deleteRecursively(localDir)
- }
- })
-
- val extServerPort = System.getProperty(
- "spark.localFileShuffle.external.server.port", "-1").toInt
- if (extServerPort != -1) {
- // We're using an external HTTP server; set URI relative to its root
- var extServerPath = System.getProperty(
- "spark.localFileShuffle.external.server.path", "")
- if (extServerPath != "" && !extServerPath.endsWith("/")) {
- extServerPath += "/"
- }
- serverUri = "http://%s:%d/%s/spark-local-%s".format(
- Utils.localIpAddress, extServerPort, extServerPath, localDirUuid)
- } else {
- // Create our own server
- server = new HttpServer(localDir)
- server.start()
- serverUri = server.uri
- }
- logInfo("Local URI: " + serverUri)
- }
-
- def stop() {
- if (server != null) {
- server.stop()
- }
- }
-
- def getOutputFile(shuffleId: Long, inputId: Int, outputId: Int): File = {
- val dir = new File(shuffleDir, shuffleId + "/" + inputId)
- dir.mkdirs()
- val file = new File(dir, "" + outputId)
- return file
- }
-
- def getServerUri(): String = {
- serverUri
- }
-
- def newShuffleId(): Long = {
- nextShuffleId.getAndIncrement()
- }
-}
diff --git a/core/src/main/scala/spark/ShuffledRDD.scala b/core/src/main/scala/spark/ShuffledRDD.scala
deleted file mode 100644
index 594dbd235f..0000000000
--- a/core/src/main/scala/spark/ShuffledRDD.scala
+++ /dev/null
@@ -1,51 +0,0 @@
-package spark
-
-import java.util.{HashMap => JHashMap}
-
-class ShuffledRDDSplit(val idx: Int) extends Split {
- override val index = idx
- override def hashCode(): Int = idx
-}
-
-class ShuffledRDD[K, V, C](
- @transient parent: RDD[(K, V)],
- aggregator: Aggregator[K, V, C],
- part : Partitioner)
- extends RDD[(K, C)](parent.context) {
- //override val partitioner = Some(part)
- override val partitioner = Some(part)
-
- @transient
- val splits_ = Array.tabulate[Split](part.numPartitions)(i => new ShuffledRDDSplit(i))
-
- override def splits = splits_
-
- override def preferredLocations(split: Split) = Nil
-
- val dep = new ShuffleDependency(context.newShuffleId, parent, aggregator, part)
- override val dependencies = List(dep)
-
- override def compute(split: Split): Iterator[(K, C)] = {
- val combiners = new JHashMap[K, C]
- def mergePair(k: K, c: C) {
- val oldC = combiners.get(k)
- if (oldC == null) {
- combiners.put(k, c)
- } else {
- combiners.put(k, aggregator.mergeCombiners(oldC, c))
- }
- }
- val fetcher = SparkEnv.get.shuffleFetcher
- fetcher.fetch[K, C](dep.shuffleId, split.index, mergePair)
- return new Iterator[(K, C)] {
- var iter = combiners.entrySet().iterator()
-
- def hasNext: Boolean = iter.hasNext()
-
- def next(): (K, C) = {
- val entry = iter.next()
- (entry.getKey, entry.getValue)
- }
- }
- }
-}
diff --git a/core/src/main/scala/spark/SizeEstimator.scala b/core/src/main/scala/spark/SizeEstimator.scala
index e5ad8b52dc..7c3e8640e9 100644
--- a/core/src/main/scala/spark/SizeEstimator.scala
+++ b/core/src/main/scala/spark/SizeEstimator.scala
@@ -22,7 +22,7 @@ import it.unimi.dsi.fastutil.ints.IntOpenHashSet
* Based on the following JavaWorld article:
* http://www.javaworld.com/javaworld/javaqa/2003-12/02-qa-1226-sizeof.html
*/
-object SizeEstimator extends Logging {
+private[spark] object SizeEstimator extends Logging {
// Sizes of primitive types
private val BYTE_SIZE = 1
@@ -77,22 +77,18 @@ object SizeEstimator extends Logging {
return System.getProperty("spark.test.useCompressedOops").toBoolean
}
try {
- val hotSpotMBeanName = "com.sun.management:type=HotSpotDiagnostic";
- val server = ManagementFactory.getPlatformMBeanServer();
+ val hotSpotMBeanName = "com.sun.management:type=HotSpotDiagnostic"
+ val server = ManagementFactory.getPlatformMBeanServer()
val bean = ManagementFactory.newPlatformMXBeanProxy(server,
- hotSpotMBeanName, classOf[HotSpotDiagnosticMXBean]);
+ hotSpotMBeanName, classOf[HotSpotDiagnosticMXBean])
return bean.getVMOption("UseCompressedOops").getValue.toBoolean
} catch {
- case e: IllegalArgumentException => {
- logWarning("Exception while trying to check if compressed oops is enabled", e)
- // Fall back to checking if maxMemory < 32GB
- return Runtime.getRuntime.maxMemory < (32L*1024*1024*1024)
- }
-
- case e: SecurityException => {
- logWarning("No permission to create MBeanServer", e)
- // Fall back to checking if maxMemory < 32GB
- return Runtime.getRuntime.maxMemory < (32L*1024*1024*1024)
+ case e: Exception => {
+ // Guess whether they've enabled UseCompressedOops based on whether maxMemory < 32 GB
+ val guess = Runtime.getRuntime.maxMemory < (32L*1024*1024*1024)
+        val guessInWords = if (guess) "yes" else "no"
+ logWarning("Failed to check whether UseCompressedOops is set; assuming " + guessInWords)
+ return guess
}
}
}
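
For readers skimming the diff: the hunk above folds the two exception cases into one and falls back to a heap-size guess. A minimal standalone sketch of the same probe (not part of the commit; it assumes a HotSpot JVM, and the helper name is hypothetical):

{{{
import java.lang.management.ManagementFactory
import com.sun.management.HotSpotDiagnosticMXBean

// Ask the HotSpot diagnostic MXBean whether UseCompressedOops is on; if anything
// goes wrong, guess from the heap size, as the hunk above does.
def compressedOopsEnabled(): Boolean = {
  try {
    val server = ManagementFactory.getPlatformMBeanServer()
    val bean = ManagementFactory.newPlatformMXBeanProxy(
      server, "com.sun.management:type=HotSpotDiagnostic", classOf[HotSpotDiagnosticMXBean])
    bean.getVMOption("UseCompressedOops").getValue.toBoolean
  } catch {
    case e: Exception => Runtime.getRuntime.maxMemory < (32L * 1024 * 1024 * 1024)
  }
}
}}}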
@@ -146,6 +142,10 @@ object SizeEstimator extends Logging {
val cls = obj.getClass
if (cls.isArray) {
visitArray(obj, cls, state)
+ } else if (obj.isInstanceOf[ClassLoader] || obj.isInstanceOf[Class[_]]) {
+ // Hadoop JobConfs created in the interpreter have a ClassLoader, which greatly confuses
+ // the size estimator since it references the whole REPL. Do nothing in this case. In
+ // general all ClassLoaders and Classes will be shared between objects anyway.
} else {
val classInfo = getClassInfo(cls)
state.size += classInfo.shellSize
diff --git a/core/src/main/scala/spark/SoftReferenceCache.scala b/core/src/main/scala/spark/SoftReferenceCache.scala
index ce9370c5d7..3dd0a4b1f9 100644
--- a/core/src/main/scala/spark/SoftReferenceCache.scala
+++ b/core/src/main/scala/spark/SoftReferenceCache.scala
@@ -5,7 +5,7 @@ import com.google.common.collect.MapMaker
/**
* An implementation of Cache that uses soft references.
*/
-class SoftReferenceCache extends Cache {
+private[spark] class SoftReferenceCache extends Cache {
val map = new MapMaker().softValues().makeMap[Any, Any]()
override def get(datasetId: Any, partition: Int): Any =
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index 1d5131ad13..0d37075ef3 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -2,13 +2,15 @@ package spark
import java.io._
import java.util.concurrent.atomic.AtomicInteger
+import java.net.{URI, URLClassLoader}
+
+import scala.collection.Map
+import scala.collection.generic.Growable
+import scala.collection.mutable.{ArrayBuffer, HashMap}
import akka.actor.Actor
import akka.actor.Actor._
-
-import scala.collection.mutable.ArrayBuffer
-
-import org.apache.hadoop.fs.Path
+import org.apache.hadoop.fs.{FileUtil, Path}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapred.InputFormat
import org.apache.hadoop.mapred.SequenceFileInputFormat
@@ -25,34 +27,59 @@ import org.apache.hadoop.io.Text
import org.apache.hadoop.mapred.FileInputFormat
import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.mapred.TextInputFormat
-
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat}
import org.apache.hadoop.mapreduce.{Job => NewHadoopJob}
-
import org.apache.mesos.{Scheduler, MesosNativeLibrary}
import spark.broadcast._
-
+import spark.deploy.LocalSparkCluster
import spark.partial.ApproximateEvaluator
import spark.partial.PartialResult
-
+import spark.rdd.HadoopRDD
+import spark.rdd.NewHadoopRDD
+import spark.rdd.UnionRDD
import spark.scheduler.ShuffleMapTask
import spark.scheduler.DAGScheduler
import spark.scheduler.TaskScheduler
import spark.scheduler.local.LocalScheduler
import spark.scheduler.cluster.{SparkDeploySchedulerBackend, SchedulerBackend, ClusterScheduler}
import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
-import spark.storage.BlockManagerMaster
+/**
+ * Main entry point for Spark functionality. A SparkContext represents the connection to a Spark
+ * cluster, and can be used to create RDDs, accumulators and broadcast variables on that cluster.
+ *
+ * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
+ * @param jobName A name for your job, to display on the cluster web UI.
+ * @param sparkHome Location where Spark is installed on cluster nodes.
+ * @param jars Collection of JARs to send to the cluster. These can be paths on the local file
+ * system or HDFS, HTTP, HTTPS, or FTP URLs.
+ * @param environment Environment variables to set on worker nodes.
+ */
class SparkContext(
val master: String,
- val frameworkName: String,
+ val jobName: String,
val sparkHome: String,
- val jars: Seq[String])
+ val jars: Seq[String],
+ environment: Map[String, String])
extends Logging {
-
- def this(master: String, frameworkName: String) = this(master, frameworkName, null, Nil)
+
+ /**
+ * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
+ * @param jobName A name for your job, to display on the cluster web UI
+ * @param sparkHome Location where Spark is installed on cluster nodes.
+ * @param jars Collection of JARs to send to the cluster. These can be paths on the local file
+ * system or HDFS, HTTP, HTTPS, or FTP URLs.
+ */
+ def this(master: String, jobName: String, sparkHome: String, jars: Seq[String]) =
+ this(master, jobName, sparkHome, jars, Map())
+
+ /**
+ * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
+ * @param jobName A name for your job, to display on the cluster web UI
+ */
+ def this(master: String, jobName: String) = this(master, jobName, null, Nil, Map())
// Ensure logging is initialized before we spawn any threads
initLogging()
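
A hedged usage sketch of the constructor overloads documented above (not part of the commit; the master string, paths and environment values are illustrative):

{{{
import spark.SparkContext

val sc = new SparkContext(
  "local[4]",                    // master: run locally with 4 threads
  "example-job",                 // jobName, shown on the cluster web UI
  "/opt/spark",                  // sparkHome on cluster nodes (hypothetical path)
  Seq("target/example.jar"),     // JARs to ship to the cluster (hypothetical path)
  Map("SPARK_MEM" -> "512m"))    // extra environment variables for executors

// The shorter overloads default sparkHome, jars and environment
// (shown only for the signature; one SparkContext should be active at a time):
val scLocal = new SparkContext("local", "example-job")
}}}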
@@ -68,46 +95,89 @@ class SparkContext(
private val isLocal = (master == "local" || master.startsWith("local["))
// Create the Spark execution environment (cache, map output tracker, etc)
- val env = SparkEnv.createFromSystemProperties(
+ private[spark] val env = SparkEnv.createFromSystemProperties(
System.getProperty("spark.master.host"),
System.getProperty("spark.master.port").toInt,
true,
isLocal)
SparkEnv.set(env)
+ // Used to store a URL for each static file/jar together with the file's local timestamp
+ private[spark] val addedFiles = HashMap[String, Long]()
+ private[spark] val addedJars = HashMap[String, Long]()
+
+ // Add each JAR given through the constructor
+ jars.foreach { addJar(_) }
+
+ // Environment variables to pass to our executors
+ private[spark] val executorEnvs = HashMap[String, String]()
+ for (key <- Seq("SPARK_MEM", "SPARK_CLASSPATH", "SPARK_LIBRARY_PATH", "SPARK_JAVA_OPTS",
+ "SPARK_TESTING")) {
+ val value = System.getenv(key)
+ if (value != null) {
+ executorEnvs(key) = value
+ }
+ }
+ executorEnvs ++= environment
+
// Create and start the scheduler
private var taskScheduler: TaskScheduler = {
// Regular expression used for local[N] master format
val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r
// Regular expression for local[N, maxRetries], used in tests with failing tasks
- val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+),([0-9]+)\]""".r
+ val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+)\s*,\s*([0-9]+)\]""".r
+ // Regular expression for simulating a Spark cluster of [N, cores, memory] locally
+ val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r
// Regular expression for connecting to Spark deploy clusters
val SPARK_REGEX = """(spark://.*)""".r
master match {
- case "local" =>
- new LocalScheduler(1, 0)
+ case "local" =>
+ new LocalScheduler(1, 0, this)
- case LOCAL_N_REGEX(threads) =>
- new LocalScheduler(threads.toInt, 0)
+ case LOCAL_N_REGEX(threads) =>
+ new LocalScheduler(threads.toInt, 0, this)
case LOCAL_N_FAILURES_REGEX(threads, maxFailures) =>
- new LocalScheduler(threads.toInt, maxFailures.toInt)
+ new LocalScheduler(threads.toInt, maxFailures.toInt, this)
case SPARK_REGEX(sparkUrl) =>
val scheduler = new ClusterScheduler(this)
- val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, frameworkName)
+ val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, jobName)
scheduler.initialize(backend)
scheduler
+ case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerSlave) =>
+ // Check to make sure SPARK_MEM <= memoryPerSlave. Otherwise Spark will just hang.
+ val memoryPerSlaveInt = memoryPerSlave.toInt
+ val sparkMemEnv = System.getenv("SPARK_MEM")
+ val sparkMemEnvInt = if (sparkMemEnv != null) Utils.memoryStringToMb(sparkMemEnv) else 512
+ if (sparkMemEnvInt > memoryPerSlaveInt) {
+ throw new SparkException(
+ "Slave memory (%d MB) cannot be smaller than SPARK_MEM (%d MB)".format(
+ memoryPerSlaveInt, sparkMemEnvInt))
+ }
+
+ val scheduler = new ClusterScheduler(this)
+ val localCluster = new LocalSparkCluster(
+ numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt)
+ val sparkUrl = localCluster.start()
+ val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, jobName)
+ scheduler.initialize(backend)
+ backend.shutdownCallback = (backend: SparkDeploySchedulerBackend) => {
+ localCluster.stop()
+ }
+ scheduler
+
case _ =>
MesosNativeLibrary.load()
val scheduler = new ClusterScheduler(this)
val coarseGrained = System.getProperty("spark.mesos.coarse", "false").toBoolean
+ val masterWithoutProtocol = master.replaceFirst("^mesos://", "") // Strip initial mesos://
val backend = if (coarseGrained) {
- new CoarseMesosSchedulerBackend(scheduler, this, master, frameworkName)
+ new CoarseMesosSchedulerBackend(scheduler, this, masterWithoutProtocol, jobName)
} else {
- new MesosSchedulerBackend(scheduler, this, master, frameworkName)
+ new MesosSchedulerBackend(scheduler, this, masterWithoutProtocol, jobName)
}
scheduler.initialize(backend)
scheduler
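
For reference, a sketch of the master strings the matcher above accepts (not part of the commit; hosts and ports are hypothetical):

{{{
// Each entry is one accepted form of the `master` argument:
val exampleMasters = Seq(
  "local",                      // single thread, no task retries
  "local[8]",                   // 8 threads
  "local[4, 2]",                // 4 threads, up to 2 failures per task (used in tests)
  "local-cluster[2, 1, 512]",   // simulate 2 slaves with 1 core and 512 MB each, in-process
  "spark://master-host:7077",   // standalone deploy cluster
  "mesos://mesos-host:5050")    // Mesos cluster (coarse-grained if spark.mesos.coarse=true)
}}}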
@@ -119,14 +189,20 @@ class SparkContext(
// Methods for creating RDDs
- def parallelize[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism ): RDD[T] = {
+ /** Distribute a local Scala collection to form an RDD. */
+ def parallelize[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = {
new ParallelCollection[T](this, seq, numSlices)
}
-
- def makeRDD[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism ): RDD[T] = {
+
+ /** Distribute a local Scala collection to form an RDD. */
+ def makeRDD[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = {
parallelize(seq, numSlices)
}
+ /**
+ * Read a text file from HDFS, a local file system (available on all nodes), or any
+ * Hadoop-supported file system URI, and return it as an RDD of Strings.
+ */
def textFile(path: String, minSplits: Int = defaultMinSplits): RDD[String] = {
hadoopFile(path, classOf[TextInputFormat], classOf[LongWritable], classOf[Text], minSplits)
.map(pair => pair._2.toString)
@@ -163,19 +239,31 @@ class SparkContext(
}
/**
- * Smarter version of hadoopFile() that uses class manifests to figure out the classes of keys,
- * values and the InputFormat so that users don't need to pass them directly.
+ * Smarter version of hadoopFile() that uses class manifests to figure out the classes of keys,
+ * values and the InputFormat so that users don't need to pass them directly. Instead, callers
+ * can just write, for example,
+ * {{{
+ * val file = sparkContext.hadoopFile[LongWritable, Text, TextInputFormat](path, minSplits)
+ * }}}
*/
def hadoopFile[K, V, F <: InputFormat[K, V]](path: String, minSplits: Int)
(implicit km: ClassManifest[K], vm: ClassManifest[V], fm: ClassManifest[F])
: RDD[(K, V)] = {
hadoopFile(path,
- fm.erasure.asInstanceOf[Class[F]],
+ fm.erasure.asInstanceOf[Class[F]],
km.erasure.asInstanceOf[Class[K]],
vm.erasure.asInstanceOf[Class[V]],
minSplits)
}
+ /**
+ * Smarter version of hadoopFile() that uses class manifests to figure out the classes of keys,
+ * values and the InputFormat so that users don't need to pass them directly. Instead, callers
+ * can just write, for example,
+ * {{{
+ * val file = sparkContext.hadoopFile[LongWritable, Text, TextInputFormat](path)
+ * }}}
+ */
def hadoopFile[K, V, F <: InputFormat[K, V]](path: String)
(implicit km: ClassManifest[K], vm: ClassManifest[V], fm: ClassManifest[F]): RDD[(K, V)] =
hadoopFile[K, V, F](path, defaultMinSplits)
@@ -191,7 +279,7 @@ class SparkContext(
new Configuration)
}
- /**
+ /**
* Get an RDD for a given Hadoop file with an arbitrary new API InputFormat
* and extra configuration options to pass to the input format.
*/
@@ -207,7 +295,7 @@ class SparkContext(
new NewHadoopRDD(this, fClass, kClass, vClass, updatedConf)
}
- /**
+ /**
* Get an RDD for a given Hadoop file with an arbitrary new API InputFormat
* and extra configuration options to pass to the input format.
*/
@@ -219,7 +307,7 @@ class SparkContext(
new NewHadoopRDD(this, fClass, kClass, vClass, conf)
}
- /** Get an RDD for a Hadoop SequenceFile with given key and value types */
+ /** Get an RDD for a Hadoop SequenceFile with given key and value types. */
def sequenceFile[K, V](path: String,
keyClass: Class[K],
valueClass: Class[V],
@@ -229,18 +317,23 @@ class SparkContext(
hadoopFile(path, inputFormatClass, keyClass, valueClass, minSplits)
}
+ /** Get an RDD for a Hadoop SequenceFile with given key and value types. */
def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V]): RDD[(K, V)] =
sequenceFile(path, keyClass, valueClass, defaultMinSplits)
/**
- * Version of sequenceFile() for types implicitly convertible to Writables through a
- * WritableConverter.
+ * Version of sequenceFile() for types implicitly convertible to Writables through a
+ * WritableConverter. For example, to access a SequenceFile where the keys are Text and the
+ * values are IntWritable, you could simply write
+ * {{{
+ * sparkContext.sequenceFile[String, Int](path, ...)
+ * }}}
*
* WritableConverters are provided in a somewhat strange way (by an implicit function) to support
- * both subclasses of Writable and types for which we define a converter (e.g. Int to
+ * both subclasses of Writable and types for which we define a converter (e.g. Int to
* IntWritable). The most natural thing would've been to have implicit objects for the
* converters, but then we couldn't have an object for every subclass of Writable (you can't
- * have a parameterized singleton object). We use functions instead to create a new converter
+ * have a parameterized singleton object). We use functions instead to create a new converter
* for the appropriate type. In addition, we pass the converter a ClassManifest of its type to
* allow it to figure out the Writable class to use in the subclass case.
*/
@@ -265,7 +358,7 @@ class SparkContext(
* that there's very little effort required to save arbitrary objects.
*/
def objectFile[T: ClassManifest](
- path: String,
+ path: String,
minSplits: Int = defaultMinSplits
): RDD[T] = {
sequenceFile(path, classOf[NullWritable], classOf[BytesWritable], minSplits)
@@ -275,43 +368,128 @@ class SparkContext(
/** Build the union of a list of RDDs. */
def union[T: ClassManifest](rdds: Seq[RDD[T]]): RDD[T] = new UnionRDD(this, rdds)
- /** Build the union of a list of RDDs. */
+ /** Build the union of a list of RDDs passed as variable-length arguments. */
def union[T: ClassManifest](first: RDD[T], rest: RDD[T]*): RDD[T] =
new UnionRDD(this, Seq(first) ++ rest)
// Methods for creating shared variables
+ /**
+ * Create an [[spark.Accumulator]] variable of a given type, which tasks can "add" values
+ * to using the `+=` method. Only the master can access the accumulator's `value`.
+ */
def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]) =
new Accumulator(initialValue, param)
/**
- * Create an accumulable shared variable, with a `+=` method
+ * Create an [[spark.Accumulable]] shared variable, with a `+=` method
* @tparam T accumulator type
* @tparam R type that can be added to the accumulator
*/
def accumulable[T,R](initialValue: T)(implicit param: AccumulableParam[T,R]) =
new Accumulable(initialValue, param)
+ /**
+ * Create an accumulator from a "mutable collection" type.
+ *
+ * Growable and TraversableOnce are the standard APIs that guarantee += and ++=, implemented by
+ * standard mutable collections. So you can use this with mutable Map, Set, etc.
+ */
+ def accumulableCollection[R <% Growable[T] with TraversableOnce[T] with Serializable, T](initialValue: R) = {
+ val param = new GrowableAccumulableParam[R,T]
+ new Accumulable(initialValue, param)
+ }
+
+ /**
+ * Broadcast a read-only variable to the cluster, returning a [[spark.Broadcast]] object for
+   * reading it in distributed functions. The variable will be sent to each node only once.
+ */
+ def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T] (value, isLocal)
+
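
A hedged sketch of the shared-variable APIs documented above (not part of the commit; it assumes the implicit Int accumulator param from object SparkContext, and the data is illustrative):

{{{
import spark.SparkContext._   // brings the implicit AccumulatorParam instances into scope

val missing = sc.accumulator(0)                      // tasks may only add; the master reads .value
val lookup = sc.broadcast(Map("a" -> 1, "b" -> 2))   // read-only copy shipped to each node once

sc.parallelize(Seq("a", "b", "c")).foreach { key =>
  if (!lookup.value.contains(key)) missing += 1
}
println(missing.value)                               // prints 1 for the illustrative data
}}}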
+ /**
+ * Add a file to be downloaded into the working directory of this Spark job on every node.
+ * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported
+ * filesystems), or an HTTP, HTTPS or FTP URI.
+ */
+ def addFile(path: String) {
+ val uri = new URI(path)
+ val key = uri.getScheme match {
+ case null | "file" => env.httpFileServer.addFile(new File(uri.getPath))
+ case _ => path
+ }
+ addedFiles(key) = System.currentTimeMillis
- // Keep around a weak hash map of values to Cached versions?
- def broadcast[T](value: T) = SparkEnv.get.broadcastManager.newBroadcast[T] (value, isLocal)
+ // Fetch the file locally in case the task is executed locally
+ val filename = new File(path.split("/").last)
+ Utils.fetchFile(path, new File("."))
- // Stop the SparkContext
+ logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key))
+ }
+
+ /**
+   * Return a map from each slave to a pair of the maximum memory available to it for caching
+   * and the memory currently remaining for caching.
+ */
+ def getSlavesMemoryStatus: Map[String, (Long, Long)] = {
+ env.blockManager.master.getMemoryStatus.map { case(blockManagerId, mem) =>
+ (blockManagerId.ip + ":" + blockManagerId.port, mem)
+ }
+ }
+
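
A small sketch of how the memory-status map above might be inspected (not part of the commit):

{{{
for ((slave, memory) <- sc.getSlavesMemoryStatus) {
  val (maxMem, remainingMem) = memory
  println(slave + ": " + (remainingMem >> 20) + " MB free of " + (maxMem >> 20) + " MB for caching")
}
}}}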
+ /**
+   * Clear the job's list of files added by `addFile` so that they do not get downloaded to
+ * any new nodes.
+ */
+ def clearFiles() {
+ addedFiles.keySet.map(_.split("/").last).foreach { k => new File(k).delete() }
+ addedFiles.clear()
+ }
+
+ /**
+ * Adds a JAR dependency for all tasks to be executed on this SparkContext in the future.
+ * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported
+ * filesystems), or an HTTP, HTTPS or FTP URI.
+ */
+ def addJar(path: String) {
+ val uri = new URI(path)
+ val key = uri.getScheme match {
+ case null | "file" => env.httpFileServer.addJar(new File(uri.getPath))
+ case _ => path
+ }
+ addedJars(key) = System.currentTimeMillis
+ logInfo("Added JAR " + path + " at " + key + " with timestamp " + addedJars(key))
+ }
+
+ /**
+ * Clear the job's list of JARs added by `addJar` so that they do not get downloaded to
+ * any new nodes.
+ */
+ def clearJars() {
+ addedJars.keySet.map(_.split("/").last).foreach { k => new File(k).delete() }
+ addedJars.clear()
+ }
+
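
A hedged usage sketch for the file and JAR distribution methods above (not part of the commit; the paths and URLs are hypothetical):

{{{
sc.addFile("hdfs://namenode:9000/data/lookup.txt")  // fetched into each task's working directory,
                                                    // so tasks can open it as new File("lookup.txt")
sc.addJar("target/extra-algos.jar")                 // shipped to executors for all future tasks

// ... run jobs that depend on the file and the JAR ...

sc.clearFiles()   // stop shipping the file to nodes added later
sc.clearJars()    // likewise for the JAR
}}}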
+ /** Shut down the SparkContext. */
def stop() {
dagScheduler.stop()
dagScheduler = null
taskScheduler = null
// TODO: Cache.stop()?
env.stop()
+ // Clean up locally linked files
+ clearFiles()
+ clearJars()
SparkEnv.set(null)
ShuffleMapTask.clearCache()
logInfo("Successfully stopped SparkContext")
}
- // Get Spark's home location from either a value set through the constructor,
- // or the spark.home Java property, or the SPARK_HOME environment variable
- // (in that order of preference). If neither of these is set, return None.
- def getSparkHome(): Option[String] = {
+ /**
+ * Get Spark's home location from either a value set through the constructor,
+ * or the spark.home Java property, or the SPARK_HOME environment variable
+   * (in that order of preference). If none of these is set, return None.
+ */
+ private[spark] def getSparkHome(): Option[String] = {
if (sparkHome != null) {
Some(sparkHome)
} else if (System.getProperty("spark.home") != null) {
@@ -326,7 +504,7 @@ class SparkContext(
/**
* Run a function on a given set of partitions in an RDD and return the results. This is the main
* entry point to the scheduler, by which all actions get launched. The allowLocal flag specifies
- * whether the scheduler can run the computation on the master rather than shipping it out to the
+ * whether the scheduler can run the computation on the master rather than shipping it out to the
* cluster, for short actions like first().
*/
def runJob[T, U: ClassManifest](
@@ -335,22 +513,27 @@ class SparkContext(
partitions: Seq[Int],
allowLocal: Boolean
): Array[U] = {
- logInfo("Starting job...")
+ val callSite = Utils.getSparkCallSite
+ logInfo("Starting job: " + callSite)
val start = System.nanoTime
- val result = dagScheduler.runJob(rdd, func, partitions, allowLocal)
- logInfo("Job finished in " + (System.nanoTime - start) / 1e9 + " s")
+ val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal)
+ logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s")
result
}
+ /**
+ * Run a job on a given set of partitions of an RDD, but take a function of type
+ * `Iterator[T] => U` instead of `(TaskContext, Iterator[T]) => U`.
+ */
def runJob[T, U: ClassManifest](
rdd: RDD[T],
- func: Iterator[T] => U,
+ func: Iterator[T] => U,
partitions: Seq[Int],
allowLocal: Boolean
): Array[U] = {
runJob(rdd, (context: TaskContext, iter: Iterator[T]) => func(iter), partitions, allowLocal)
}
-
+
/**
* Run a job on all partitions in an RDD and return the results in an array.
*/
@@ -358,6 +541,9 @@ class SparkContext(
runJob(rdd, func, 0 until rdd.splits.size, false)
}
+ /**
+ * Run a job on all partitions in an RDD and return the results in an array.
+ */
def runJob[T, U: ClassManifest](rdd: RDD[T], func: Iterator[T] => U): Array[U] = {
runJob(rdd, func, 0 until rdd.splits.size, false)
}
@@ -371,38 +557,37 @@ class SparkContext(
evaluator: ApproximateEvaluator[U, R],
timeout: Long
): PartialResult[R] = {
- logInfo("Starting job...")
+ val callSite = Utils.getSparkCallSite
+ logInfo("Starting job: " + callSite)
val start = System.nanoTime
- val result = dagScheduler.runApproximateJob(rdd, func, evaluator, timeout)
- logInfo("Job finished in " + (System.nanoTime - start) / 1e9 + " s")
+ val result = dagScheduler.runApproximateJob(rdd, func, evaluator, callSite, timeout)
+ logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s")
result
}
- // Clean a closure to make it ready to serialized and send to tasks
- // (removes unreferenced variables in $outer's, updates REPL variables)
+ /**
+   * Clean a closure to make it ready to be serialized and sent to tasks
+ * (removes unreferenced variables in $outer's, updates REPL variables)
+ */
private[spark] def clean[F <: AnyRef](f: F): F = {
ClosureCleaner.clean(f)
return f
}
- // Default level of parallelism to use when not given by user (e.g. for reduce tasks)
+ /** Default level of parallelism to use when not given by user (e.g. for reduce tasks) */
def defaultParallelism: Int = taskScheduler.defaultParallelism
- // Default min number of splits for Hadoop RDDs when not given by user
+ /** Default min number of splits for Hadoop RDDs when not given by user */
def defaultMinSplits: Int = math.min(defaultParallelism, 2)
private var nextShuffleId = new AtomicInteger(0)
- private[spark] def newShuffleId(): Int = {
- nextShuffleId.getAndIncrement()
- }
-
+ private[spark] def newShuffleId(): Int = nextShuffleId.getAndIncrement()
+
private var nextRddId = new AtomicInteger(0)
- // Register a new RDD, returning its RDD ID
- private[spark] def newRddId(): Int = {
- nextRddId.getAndIncrement()
- }
+ /** Register a new RDD, returning its RDD ID */
+ private[spark] def newRddId(): Int = nextRddId.getAndIncrement()
}
/**
@@ -429,7 +614,7 @@ object SparkContext {
implicit def rddToPairRDDFunctions[K: ClassManifest, V: ClassManifest](rdd: RDD[(K, V)]) =
new PairRDDFunctions(rdd)
-
+
implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable: ClassManifest](
rdd: RDD[(K, V)]) =
new SequenceFileRDDFunctions(rdd)
@@ -450,7 +635,7 @@ object SparkContext {
implicit def longToLongWritable(l: Long) = new LongWritable(l)
implicit def floatToFloatWritable(f: Float) = new FloatWritable(f)
-
+
implicit def doubleToDoubleWritable(d: Double) = new DoubleWritable(d)
implicit def boolToBoolWritable (b: Boolean) = new BooleanWritable(b)
@@ -461,7 +646,7 @@ object SparkContext {
private implicit def arrayToArrayWritable[T <% Writable: ClassManifest](arr: Traversable[T]): ArrayWritable = {
def anyToWritable[U <% Writable](u: U): Writable = u
-
+
new ArrayWritable(classManifest[T].erasure.asInstanceOf[Class[Writable]],
arr.map(x => anyToWritable(x)).toArray)
}
@@ -489,8 +674,10 @@ object SparkContext {
implicit def writableWritableConverter[T <: Writable]() =
new WritableConverter[T](_.erasure.asInstanceOf[Class[T]], _.asInstanceOf[T])
- // Find the JAR from which a given class was loaded, to make it easy for users to pass
- // their JARs to SparkContext
+ /**
+ * Find the JAR from which a given class was loaded, to make it easy for users to pass
+ * their JARs to SparkContext
+ */
def jarOfClass(cls: Class[_]): Seq[String] = {
val uri = cls.getResource("/" + cls.getName.replace('.', '/') + ".class")
if (uri != null) {
@@ -505,8 +692,8 @@ object SparkContext {
Nil
}
}
-
- // Find the JAR that contains the class of a particular object
+
+ /** Find the JAR that contains the class of a particular object */
def jarOfObject(obj: AnyRef): Seq[String] = jarOfClass(obj.getClass)
}
@@ -518,7 +705,7 @@ object SparkContext {
* that doesn't know the type of T when it is created. This sounds strange but is necessary to
* support converting subclasses of Writable to themselves (writableWritableConverter).
*/
-class WritableConverter[T](
+private[spark] class WritableConverter[T](
val writableClass: ClassManifest[T] => Class[_ <: Writable],
val convert: Writable => T)
extends Serializable
diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala
index add8fcec51..4c6ec6cc6e 100644
--- a/core/src/main/scala/spark/SparkEnv.scala
+++ b/core/src/main/scala/spark/SparkEnv.scala
@@ -1,46 +1,57 @@
package spark
import akka.actor.ActorSystem
+import akka.actor.ActorSystemImpl
+import akka.remote.RemoteActorRefProvider
+import serializer.Serializer
import spark.broadcast.BroadcastManager
import spark.storage.BlockManager
import spark.storage.BlockManagerMaster
import spark.network.ConnectionManager
import spark.util.AkkaUtils
+/**
+ * Holds all the runtime environment objects for a running Spark instance (either master or worker),
+ * including the serializer, Akka actor system, block manager, map output tracker, etc. Currently
+ * Spark code finds the SparkEnv through a thread-local variable, so each thread that accesses these
+ * objects needs to have the right SparkEnv set. You can get the current environment with
+ * SparkEnv.get (e.g. after creating a SparkContext) and set it with SparkEnv.set.
+ */
class SparkEnv (
val actorSystem: ActorSystem,
- val cache: Cache,
val serializer: Serializer,
val closureSerializer: Serializer,
val cacheTracker: CacheTracker,
val mapOutputTracker: MapOutputTracker,
val shuffleFetcher: ShuffleFetcher,
- val shuffleManager: ShuffleManager,
val broadcastManager: BroadcastManager,
val blockManager: BlockManager,
- val connectionManager: ConnectionManager
+ val connectionManager: ConnectionManager,
+ val httpFileServer: HttpFileServer
) {
/** No-parameter constructor for unit tests. */
def this() = {
- this(null, null, new JavaSerializer, new JavaSerializer, null, null, null, null, null, null, null)
+ this(null, new JavaSerializer, new JavaSerializer, null, null, null, null, null, null, null)
}
def stop() {
+ httpFileServer.stop()
mapOutputTracker.stop()
cacheTracker.stop()
shuffleFetcher.stop()
- shuffleManager.stop()
broadcastManager.stop()
blockManager.stop()
blockManager.master.stop()
actorSystem.shutdown()
+ // Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut
+ // down, but let's call it anyway in case it gets fixed in a later release
actorSystem.awaitTermination()
}
}
-object SparkEnv {
+object SparkEnv extends Logging {
private val env = new ThreadLocal[SparkEnv]
def set(e: SparkEnv) {
@@ -66,66 +77,55 @@ object SparkEnv {
System.setProperty("spark.master.port", boundPort.toString)
}
- val serializerClass = System.getProperty("spark.serializer", "spark.KryoSerializer")
- val serializer = Class.forName(serializerClass).newInstance().asInstanceOf[Serializer]
+ val classLoader = Thread.currentThread.getContextClassLoader
+
+ // Create an instance of the class named by the given Java system property, or by
+ // defaultClassName if the property is not set, and return it as a T
+ def instantiateClass[T](propertyName: String, defaultClassName: String): T = {
+ val name = System.getProperty(propertyName, defaultClassName)
+ Class.forName(name, true, classLoader).newInstance().asInstanceOf[T]
+ }
+
+ val serializer = instantiateClass[Serializer]("spark.serializer", "spark.JavaSerializer")
val blockManagerMaster = new BlockManagerMaster(actorSystem, isMaster, isLocal)
-
val blockManager = new BlockManager(blockManagerMaster, serializer)
- val connectionManager = blockManager.connectionManager
-
- val shuffleManager = new ShuffleManager()
+ val connectionManager = blockManager.connectionManager
val broadcastManager = new BroadcastManager(isMaster)
- val closureSerializerClass =
- System.getProperty("spark.closure.serializer", "spark.JavaSerializer")
- val closureSerializer =
- Class.forName(closureSerializerClass).newInstance().asInstanceOf[Serializer]
- val cacheClass = System.getProperty("spark.cache.class", "spark.BoundedMemoryCache")
- val cache = Class.forName(cacheClass).newInstance().asInstanceOf[Cache]
+ val closureSerializer = instantiateClass[Serializer](
+ "spark.closure.serializer", "spark.JavaSerializer")
val cacheTracker = new CacheTracker(actorSystem, isMaster, blockManager)
blockManager.cacheTracker = cacheTracker
val mapOutputTracker = new MapOutputTracker(actorSystem, isMaster)
- val shuffleFetcherClass =
- System.getProperty("spark.shuffle.fetcher", "spark.BlockStoreShuffleFetcher")
- val shuffleFetcher =
- Class.forName(shuffleFetcherClass).newInstance().asInstanceOf[ShuffleFetcher]
-
- /*
- if (System.getProperty("spark.stream.distributed", "false") == "true") {
- val blockManagerClass = classOf[spark.storage.BlockManager].asInstanceOf[Class[_]]
- if (isLocal || !isMaster) {
- (new Thread() {
- override def run() {
- println("Wait started")
- Thread.sleep(60000)
- println("Wait ended")
- val receiverClass = Class.forName("spark.stream.TestStreamReceiver4")
- val constructor = receiverClass.getConstructor(blockManagerClass)
- val receiver = constructor.newInstance(blockManager)
- receiver.asInstanceOf[Thread].start()
- }
- }).start()
- }
+ val shuffleFetcher = instantiateClass[ShuffleFetcher](
+ "spark.shuffle.fetcher", "spark.BlockStoreShuffleFetcher")
+
+ val httpFileServer = new HttpFileServer()
+ httpFileServer.initialize()
+ System.setProperty("spark.fileserver.uri", httpFileServer.serverUri)
+
+ // Warn about deprecated spark.cache.class property
+ if (System.getProperty("spark.cache.class") != null) {
+ logWarning("The spark.cache.class property is no longer being used! Specify storage " +
+ "levels using the RDD.persist() method instead.")
}
- */
new SparkEnv(
actorSystem,
- cache,
serializer,
closureSerializer,
cacheTracker,
mapOutputTracker,
shuffleFetcher,
- shuffleManager,
broadcastManager,
blockManager,
- connectionManager)
+ connectionManager,
+ httpFileServer)
}
}
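
Since the hunk above moves class selection behind system properties, a brief sketch of how a deployment might override the defaults (not part of the commit; spark.KryoSerializer is one of the serializers in this tree):

{{{
// Must be set before SparkEnv.createFromSystemProperties runs (i.e. before the SparkContext
// is created); instantiateClass then loads these classes reflectively by name.
System.setProperty("spark.serializer", "spark.KryoSerializer")
System.setProperty("spark.closure.serializer", "spark.JavaSerializer")
System.setProperty("spark.shuffle.fetcher", "spark.BlockStoreShuffleFetcher")
}}}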
diff --git a/core/src/main/scala/spark/TaskEndReason.scala b/core/src/main/scala/spark/TaskEndReason.scala
index 6e4eb25ed4..420c54bc9a 100644
--- a/core/src/main/scala/spark/TaskEndReason.scala
+++ b/core/src/main/scala/spark/TaskEndReason.scala
@@ -7,10 +7,16 @@ import spark.storage.BlockManagerId
* tasks several times for "ephemeral" failures, and only report back failures that require some
* old stages to be resubmitted, such as shuffle map fetch failures.
*/
-sealed trait TaskEndReason
+private[spark] sealed trait TaskEndReason
-case object Success extends TaskEndReason
+private[spark] case object Success extends TaskEndReason
+
+private[spark]
case object Resubmitted extends TaskEndReason // Task was finished earlier but we've now lost it
+
+private[spark]
case class FetchFailed(bmAddress: BlockManagerId, shuffleId: Int, mapId: Int, reduceId: Int) extends TaskEndReason
-case class ExceptionFailure(exception: Throwable) extends TaskEndReason
-case class OtherFailure(message: String) extends TaskEndReason
+
+private[spark] case class ExceptionFailure(exception: Throwable) extends TaskEndReason
+
+private[spark] case class OtherFailure(message: String) extends TaskEndReason
diff --git a/core/src/main/scala/spark/TaskState.scala b/core/src/main/scala/spark/TaskState.scala
index 9566b52432..78eb33a628 100644
--- a/core/src/main/scala/spark/TaskState.scala
+++ b/core/src/main/scala/spark/TaskState.scala
@@ -2,7 +2,7 @@ package spark
import org.apache.mesos.Protos.{TaskState => MesosTaskState}
-object TaskState
+private[spark] object TaskState
extends Enumeration("LAUNCHING", "RUNNING", "FINISHED", "FAILED", "KILLED", "LOST") {
val LAUNCHING, RUNNING, FINISHED, FAILED, KILLED, LOST = Value
diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala
index 1d33f7d6b3..1bdde25896 100644
--- a/core/src/main/scala/spark/Utils.scala
+++ b/core/src/main/scala/spark/Utils.scala
@@ -1,18 +1,18 @@
package spark
import java.io._
-import java.net.InetAddress
+import java.net.{InetAddress, URL, URI}
+import java.util.{Locale, Random, UUID}
import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
-
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{Path, FileSystem, FileUtil}
import scala.collection.mutable.ArrayBuffer
-import scala.util.Random
-import java.util.{Locale, UUID}
import scala.io.Source
/**
* Various utility methods used by Spark.
*/
-object Utils {
+private object Utils extends Logging {
/** Serialize an object using Java serialization */
def serialize[T](o: T): Array[Byte] = {
val bos = new ByteArrayOutputStream()
@@ -71,7 +71,7 @@ object Utils {
while (dir == null) {
attempts += 1
if (attempts > maxAttempts) {
- throw new IOException("Failed to create a temp directory after " + maxAttempts +
+ throw new IOException("Failed to create a temp directory after " + maxAttempts +
" attempts!")
}
try {
@@ -116,22 +116,84 @@ object Utils {
copyStream(in, out, true)
}
+ /** Download a file from a given URL to the local filesystem */
+ def downloadFile(url: URL, localPath: String) {
+ val in = url.openStream()
+ val out = new FileOutputStream(localPath)
+ Utils.copyStream(in, out, true)
+ }
+
+ /**
+ * Download a file requested by the executor. Supports fetching the file in a variety of ways,
+ * including HTTP, HDFS and files on a standard filesystem, based on the URL parameter.
+ */
+ def fetchFile(url: String, targetDir: File) {
+ val filename = url.split("/").last
+ val targetFile = new File(targetDir, filename)
+ val uri = new URI(url)
+ uri.getScheme match {
+ case "http" | "https" | "ftp" =>
+ logInfo("Fetching " + url + " to " + targetFile)
+ val in = new URL(url).openStream()
+ val out = new FileOutputStream(targetFile)
+ Utils.copyStream(in, out, true)
+ case "file" | null =>
+ // Remove the file if it already exists
+ targetFile.delete()
+ // Symlink the file locally.
+ if (uri.isAbsolute) {
+ // url is absolute, i.e. it starts with "file:///". Extract the source
+ // file's absolute path from the url.
+ val sourceFile = new File(uri)
+ logInfo("Symlinking " + sourceFile.getAbsolutePath + " to " + targetFile.getAbsolutePath)
+ FileUtil.symLink(sourceFile.getAbsolutePath, targetFile.getAbsolutePath)
+ } else {
+          // url is not absolute, i.e. the url itself is the path to the source file.
+ logInfo("Symlinking " + url + " to " + targetFile.getAbsolutePath)
+ FileUtil.symLink(url, targetFile.getAbsolutePath)
+ }
+ case _ =>
+ // Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others
+ val uri = new URI(url)
+ val conf = new Configuration()
+ val fs = FileSystem.get(uri, conf)
+ val in = fs.open(new Path(uri))
+ val out = new FileOutputStream(targetFile)
+ Utils.copyStream(in, out, true)
+ }
+ // Decompress the file if it's a .tar or .tar.gz
+ if (filename.endsWith(".tar.gz") || filename.endsWith(".tgz")) {
+ logInfo("Untarring " + filename)
+ Utils.execute(Seq("tar", "-xzf", filename), targetDir)
+ } else if (filename.endsWith(".tar")) {
+ logInfo("Untarring " + filename)
+ Utils.execute(Seq("tar", "-xf", filename), targetDir)
+ }
+    // Make the file executable, which is necessary for scripts
+ FileUtil.chmod(filename, "a+x")
+ }
+
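
A hedged sketch of calling the fetcher above from inside the spark package, since Utils is now package-private (not part of the commit; the URL and target directory are illustrative):

{{{
import java.io.File

// http/https/ftp URLs are streamed down, file: URLs are symlinked, and any other scheme
// (hdfs://, s3://, ...) goes through the Hadoop FileSystem API; .tar/.tgz archives are
// then untarred into the target directory.
Utils.fetchFile("http://repo.example.com/jobs/deps.tgz", new File("/tmp/spark-work"))
}}}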
/**
* Shuffle the elements of a collection into a random order, returning the
* result in a new collection. Unlike scala.util.Random.shuffle, this method
* uses a local random number generator, avoiding inter-thread contention.
*/
- def randomize[T](seq: TraversableOnce[T]): Seq[T] = {
- val buf = new ArrayBuffer[T]()
- buf ++= seq
- val rand = new Random()
- for (i <- (buf.size - 1) to 1 by -1) {
+ def randomize[T: ClassManifest](seq: TraversableOnce[T]): Seq[T] = {
+ randomizeInPlace(seq.toArray)
+ }
+
+ /**
+ * Shuffle the elements of an array into a random order, modifying the
+ * original array. Returns the original array.
+ */
+ def randomizeInPlace[T](arr: Array[T], rand: Random = new Random): Array[T] = {
+ for (i <- (arr.length - 1) to 1 by -1) {
val j = rand.nextInt(i)
- val tmp = buf(j)
- buf(j) = buf(i)
- buf(i) = tmp
+ val tmp = arr(j)
+ arr(j) = arr(i)
+ arr(i) = tmp
}
- buf
+ arr
}
/**
@@ -155,7 +217,7 @@ object Utils {
def localHostName(): String = {
customHostname.getOrElse(InetAddress.getLocalHost.getHostName)
}
-
+
/**
* Returns a standard ThreadFactory except all threads are daemons.
*/
@@ -179,10 +241,10 @@ object Utils {
return threadPool
}
-
+
/**
- * Return the string to tell how long has passed in seconds. The passing parameter should be in
- * millisecond.
+   * Return a string describing how much time has passed (in milliseconds) since the given
+   * startTimeMs.
*/
def getUsedTimeMs(startTimeMs: Long): String = {
return " " + (System.currentTimeMillis - startTimeMs) + " ms"
@@ -294,4 +356,43 @@ object Utils {
def execute(command: Seq[String]) {
execute(command, new File("."))
}
+
+
+ /**
+ * When called inside a class in the spark package, returns the name of the user code class
+ * (outside the spark package) that called into Spark, as well as which Spark method they called.
+ * This is used, for example, to tell users where in their code each RDD got created.
+ */
+ def getSparkCallSite: String = {
+ val trace = Thread.currentThread.getStackTrace().filter( el =>
+ (!el.getMethodName.contains("getStackTrace")))
+
+ // Keep crawling up the stack trace until we find the first function not inside of the spark
+ // package. We track the last (shallowest) contiguous Spark method. This might be an RDD
+ // transformation, a SparkContext function (such as parallelize), or anything else that leads
+ // to instantiation of an RDD. We also track the first (deepest) user method, file, and line.
+ var lastSparkMethod = "<unknown>"
+ var firstUserFile = "<unknown>"
+ var firstUserLine = 0
+ var finished = false
+
+ for (el <- trace) {
+ if (!finished) {
+ if (el.getClassName.startsWith("spark.") && !el.getClassName.startsWith("spark.examples.")) {
+ lastSparkMethod = if (el.getMethodName == "<init>") {
+ // Spark method is a constructor; get its class name
+ el.getClassName.substring(el.getClassName.lastIndexOf('.') + 1)
+ } else {
+ el.getMethodName
+ }
+ }
+ else {
+ firstUserLine = el.getLineNumber
+ firstUserFile = el.getFileName
+ finished = true
+ }
+ }
+ }
+ "%s at %s:%s".format(lastSparkMethod, firstUserFile, firstUserLine)
+ }
}
diff --git a/core/src/main/scala/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/spark/api/java/JavaDoubleRDD.scala
index 7c0b17c45e..843e1bd18b 100644
--- a/core/src/main/scala/spark/api/java/JavaDoubleRDD.scala
+++ b/core/src/main/scala/spark/api/java/JavaDoubleRDD.scala
@@ -22,8 +22,13 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav
import JavaDoubleRDD.fromRDD
+ /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
def cache(): JavaDoubleRDD = fromRDD(srdd.cache())
+ /**
+ * Set this RDD's storage level to persist its values across operations after the first time
+ * it is computed. Can only be called once on each RDD.
+ */
def persist(newLevel: StorageLevel): JavaDoubleRDD = fromRDD(srdd.persist(newLevel))
   // first() has to be overridden here in order for its return type to be Double instead of Object.
@@ -31,36 +36,63 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav
// Transformations (return a new RDD)
+ /**
+ * Return a new RDD containing the distinct elements in this RDD.
+ */
def distinct(): JavaDoubleRDD = fromRDD(srdd.distinct())
+ /**
+ * Return a new RDD containing the distinct elements in this RDD.
+ */
+ def distinct(numSplits: Int): JavaDoubleRDD = fromRDD(srdd.distinct(numSplits))
+
+ /**
+ * Return a new RDD containing only the elements that satisfy a predicate.
+ */
def filter(f: JFunction[Double, java.lang.Boolean]): JavaDoubleRDD =
fromRDD(srdd.filter(x => f(x).booleanValue()))
+ /**
+ * Return a sampled subset of this RDD.
+ */
def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaDoubleRDD =
fromRDD(srdd.sample(withReplacement, fraction, seed))
+ /**
+ * Return the union of this RDD and another one. Any identical elements will appear multiple
+ * times (use `.distinct()` to eliminate them).
+ */
def union(other: JavaDoubleRDD): JavaDoubleRDD = fromRDD(srdd.union(other.srdd))
// Double RDD functions
+ /** Return the sum of the elements in this RDD. */
def sum(): Double = srdd.sum()
+ /** Return a [[spark.StatCounter]] describing the elements in this RDD. */
def stats(): StatCounter = srdd.stats()
+ /** Return the mean of the elements in this RDD. */
def mean(): Double = srdd.mean()
+ /** Return the variance of the elements in this RDD. */
def variance(): Double = srdd.variance()
+ /** Return the standard deviation of the elements in this RDD. */
def stdev(): Double = srdd.stdev()
+ /** Return the approximate mean of the elements in this RDD. */
def meanApprox(timeout: Long, confidence: Double): PartialResult[BoundedDouble] =
srdd.meanApprox(timeout, confidence)
+ /** Return the approximate mean of the elements in this RDD. */
def meanApprox(timeout: Long): PartialResult[BoundedDouble] = srdd.meanApprox(timeout)
+ /** Return the approximate sum of the elements in this RDD. */
def sumApprox(timeout: Long, confidence: Double): PartialResult[BoundedDouble] =
srdd.sumApprox(timeout, confidence)
-
+
+ /** Return the approximate sum of the elements in this RDD. */
def sumApprox(timeout: Long): PartialResult[BoundedDouble] = srdd.sumApprox(timeout)
}
diff --git a/core/src/main/scala/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/spark/api/java/JavaPairRDD.scala
index c28a13b061..5c2be534ff 100644
--- a/core/src/main/scala/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/spark/api/java/JavaPairRDD.scala
@@ -1,13 +1,5 @@
package spark.api.java
-import spark.SparkContext.rddToPairRDDFunctions
-import spark.api.java.function.{Function2 => JFunction2}
-import spark.api.java.function.{Function => JFunction}
-import spark.partial.BoundedDouble
-import spark.partial.PartialResult
-import spark.storage.StorageLevel
-import spark._
-
import java.util.{List => JList}
import java.util.Comparator
@@ -19,6 +11,17 @@ import org.apache.hadoop.mapred.OutputFormat
import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat}
import org.apache.hadoop.conf.Configuration
+import spark.api.java.function.{Function2 => JFunction2}
+import spark.api.java.function.{Function => JFunction}
+import spark.partial.BoundedDouble
+import spark.partial.PartialResult
+import spark.OrderedRDDFunctions
+import spark.storage.StorageLevel
+import spark.HashPartitioner
+import spark.Partitioner
+import spark.RDD
+import spark.SparkContext.rddToPairRDDFunctions
+
class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManifest[K],
implicit val vManifest: ClassManifest[V]) extends JavaRDDLike[(K, V), JavaPairRDD[K, V]] {
@@ -31,21 +34,44 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
// Common RDD functions
+ /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
def cache(): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.cache())
+ /**
+ * Set this RDD's storage level to persist its values across operations after the first time
+ * it is computed. Can only be called once on each RDD.
+ */
def persist(newLevel: StorageLevel): JavaPairRDD[K, V] =
new JavaPairRDD[K, V](rdd.persist(newLevel))
// Transformations (return a new RDD)
+ /**
+ * Return a new RDD containing the distinct elements in this RDD.
+ */
def distinct(): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.distinct())
+ /**
+ * Return a new RDD containing the distinct elements in this RDD.
+ */
+ def distinct(numSplits: Int): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.distinct(numSplits))
+
+ /**
+ * Return a new RDD containing only the elements that satisfy a predicate.
+ */
def filter(f: Function[(K, V), java.lang.Boolean]): JavaPairRDD[K, V] =
new JavaPairRDD[K, V](rdd.filter(x => f(x).booleanValue()))
+ /**
+ * Return a sampled subset of this RDD.
+ */
def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaPairRDD[K, V] =
new JavaPairRDD[K, V](rdd.sample(withReplacement, fraction, seed))
+ /**
+ * Return the union of this RDD and another one. Any identical elements will appear multiple
+ * times (use `.distinct()` to eliminate them).
+ */
def union(other: JavaPairRDD[K, V]): JavaPairRDD[K, V] =
new JavaPairRDD[K, V](rdd.union(other.rdd))
@@ -56,7 +82,21 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
override def first(): (K, V) = rdd.first()
// Pair RDD functions
-
+
+ /**
+ * Generic function to combine the elements for each key using a custom set of aggregation
+   * functions. Turns a JavaPairRDD[K, V] into a result of type JavaPairRDD[K, C], for a
+   * "combined type" C. Note that V and C can be different; for example, one might group an
+ * RDD of type (Int, Int) into an RDD of type (Int, List[Int]). Users provide three
+ * functions:
+ *
+ * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list)
+ * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list)
+ * - `mergeCombiners`, to combine two C's into a single one.
+ *
+ * In addition, users can control the partitioning of the output RDD, and whether to perform
+ * map-side aggregation (if a mapper can produce multiple items with the same key).
+ */
def combineByKey[C](createCombiner: Function[V, C],
mergeValue: JFunction2[C, V, C],
mergeCombiners: JFunction2[C, C, C],
@@ -71,50 +111,113 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
))
}
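
To make the three functions above concrete, a hedged sketch against the underlying Scala pair-RDD API (not part of the commit; the data and split count are illustrative):

{{{
import spark.SparkContext._   // rddToPairRDDFunctions

val pairs = sc.parallelize(Seq(("a", 1), ("a", 2), ("b", 3)))
val grouped = pairs.combineByKey(
  (v: Int) => List(v),                          // createCombiner: start a one-element list
  (c: List[Int], v: Int) => v :: c,             // mergeValue: add a value to a combiner
  (c1: List[Int], c2: List[Int]) => c1 ::: c2,  // mergeCombiners: merge two partial lists
  2)                                            // numSplits for the hash-partitioned output
// grouped is an RDD[(String, List[Int])], e.g. ("a", List(2, 1)) and ("b", List(3))
}}}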
+ /**
+ * Simplified version of combineByKey that hash-partitions the output RDD.
+ */
def combineByKey[C](createCombiner: JFunction[V, C],
mergeValue: JFunction2[C, V, C],
mergeCombiners: JFunction2[C, C, C],
numSplits: Int): JavaPairRDD[K, C] =
combineByKey(createCombiner, mergeValue, mergeCombiners, new HashPartitioner(numSplits))
+ /**
+ * Merge the values for each key using an associative reduce function. This will also perform
+ * the merging locally on each mapper before sending results to a reducer, similarly to a
+ * "combiner" in MapReduce.
+ */
def reduceByKey(partitioner: Partitioner, func: JFunction2[V, V, V]): JavaPairRDD[K, V] =
fromRDD(rdd.reduceByKey(partitioner, func))
+ /**
+ * Merge the values for each key using an associative reduce function, but return the results
+ * immediately to the master as a Map. This will also perform the merging locally on each mapper
+ * before sending results to a reducer, similarly to a "combiner" in MapReduce.
+ */
def reduceByKeyLocally(func: JFunction2[V, V, V]): java.util.Map[K, V] =
mapAsJavaMap(rdd.reduceByKeyLocally(func))
+ /** Count the number of elements for each key, and return the result to the master as a Map. */
def countByKey(): java.util.Map[K, Long] = mapAsJavaMap(rdd.countByKey())
+ /**
+ * (Experimental) Approximate version of countByKey that can return a partial result if it does
+ * not finish within a timeout.
+ */
def countByKeyApprox(timeout: Long): PartialResult[java.util.Map[K, BoundedDouble]] =
rdd.countByKeyApprox(timeout).map(mapAsJavaMap)
+ /**
+ * (Experimental) Approximate version of countByKey that can return a partial result if it does
+ * not finish within a timeout.
+ */
def countByKeyApprox(timeout: Long, confidence: Double = 0.95)
: PartialResult[java.util.Map[K, BoundedDouble]] =
rdd.countByKeyApprox(timeout, confidence).map(mapAsJavaMap)
+ /**
+ * Merge the values for each key using an associative reduce function. This will also perform
+ * the merging locally on each mapper before sending results to a reducer, similarly to a
+ * "combiner" in MapReduce. Output will be hash-partitioned with numSplits splits.
+ */
def reduceByKey(func: JFunction2[V, V, V], numSplits: Int): JavaPairRDD[K, V] =
fromRDD(rdd.reduceByKey(func, numSplits))
+ /**
+ * Group the values for each key in the RDD into a single sequence. Allows controlling the
+ * partitioning of the resulting key-value pair RDD by passing a Partitioner.
+ */
def groupByKey(partitioner: Partitioner): JavaPairRDD[K, JList[V]] =
fromRDD(groupByResultToJava(rdd.groupByKey(partitioner)))
+ /**
+ * Group the values for each key in the RDD into a single sequence. Hash-partitions the
+   * resulting RDD into `numSplits` partitions.
+ */
def groupByKey(numSplits: Int): JavaPairRDD[K, JList[V]] =
fromRDD(groupByResultToJava(rdd.groupByKey(numSplits)))
+ /**
+ * Return a copy of the RDD partitioned using the specified partitioner. If `mapSideCombine`
+ * is true, Spark will group values of the same key together on the map side before the
+   * repartitioning, so that each key is sent over the network only once. If many duplicated
+   * keys are expected and the keys are large, `mapSideCombine` should be set to true.
+ */
def partitionBy(partitioner: Partitioner): JavaPairRDD[K, V] =
fromRDD(rdd.partitionBy(partitioner))
+ /**
+   * Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each
+   * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and
+   * (k, v2) is in `other`. Uses the given Partitioner to partition the output RDD.
+ */
def join[W](other: JavaPairRDD[K, W], partitioner: Partitioner): JavaPairRDD[K, (V, W)] =
fromRDD(rdd.join(other, partitioner))
+ /**
+ * Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the
+ * resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the
+ * pair (k, (v, None)) if no elements in `other` have key k. Uses the given Partitioner to
+ * partition the output RDD.
+ */
def leftOuterJoin[W](other: JavaPairRDD[K, W], partitioner: Partitioner)
: JavaPairRDD[K, (V, Option[W])] =
fromRDD(rdd.leftOuterJoin(other, partitioner))
+ /**
+ * Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the
+ * resulting RDD will either contain all pairs (k, (Some(v), w)) for v in `this`, or the
+ * pair (k, (None, w)) if no elements in `this` have key k. Uses the given Partitioner to
+ * partition the output RDD.
+ */
def rightOuterJoin[W](other: JavaPairRDD[K, W], partitioner: Partitioner)
: JavaPairRDD[K, (Option[V], W)] =
fromRDD(rdd.rightOuterJoin(other, partitioner))
+ /**
+ * Simplified version of combineByKey that hash-partitions the resulting RDD using the default
+ * parallelism level.
+ */
def combineByKey[C](createCombiner: JFunction[V, C],
mergeValue: JFunction2[C, V, C],
mergeCombiners: JFunction2[C, C, C]): JavaPairRDD[K, C] = {
@@ -123,40 +226,94 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
fromRDD(combineByKey(createCombiner, mergeValue, mergeCombiners))
}
+ /**
+ * Merge the values for each key using an associative reduce function. This will also perform
+ * the merging locally on each mapper before sending results to a reducer, similarly to a
+ * "combiner" in MapReduce. Output will be hash-partitioned with the default parallelism level.
+ */
def reduceByKey(func: JFunction2[V, V, V]): JavaPairRDD[K, V] = {
val partitioner = rdd.defaultPartitioner(rdd)
fromRDD(reduceByKey(partitioner, func))
}
+ /**
+ * Group the values for each key in the RDD into a single sequence. Hash-partitions the
+ * resulting RDD with the default parallelism level.
+ */
def groupByKey(): JavaPairRDD[K, JList[V]] =
fromRDD(groupByResultToJava(rdd.groupByKey()))
+ /**
+ * Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each
+ * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and
+ * (k, v2) is in `other`. Performs a hash join across the cluster.
+ */
def join[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (V, W)] =
fromRDD(rdd.join(other))
+ /**
+ * Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each
+ * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and
+ * (k, v2) is in `other`. Performs a hash join across the cluster.
+ */
def join[W](other: JavaPairRDD[K, W], numSplits: Int): JavaPairRDD[K, (V, W)] =
fromRDD(rdd.join(other, numSplits))
+ /**
+ * Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the
+ * resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the
+ * pair (k, (v, None)) if no elements in `other` have key k. Hash-partitions the output
+ * using the default level of parallelism.
+ */
def leftOuterJoin[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (V, Option[W])] =
fromRDD(rdd.leftOuterJoin(other))
+ /**
+ * Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the
+ * resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the
+ * pair (k, (v, None)) if no elements in `other` have key k. Hash-partitions the output
+ * into `numSplits` partitions.
+ */
def leftOuterJoin[W](other: JavaPairRDD[K, W], numSplits: Int): JavaPairRDD[K, (V, Option[W])] =
fromRDD(rdd.leftOuterJoin(other, numSplits))
+ /**
+ * Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the
+ * resulting RDD will either contain all pairs (k, (Some(v), w)) for v in `this`, or the
+ * pair (k, (None, w)) if no elements in `this` have key k. Hash-partitions the resulting
+ * RDD using the default parallelism level.
+ */
def rightOuterJoin[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (Option[V], W)] =
fromRDD(rdd.rightOuterJoin(other))
+ /**
+ * Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the
+ * resulting RDD will either contain all pairs (k, (Some(v), w)) for v in `this`, or the
+ * pair (k, (None, w)) if no elements in `this` have key k. Hash-partitions the resulting
+ * RDD into the given number of partitions.
+ */
def rightOuterJoin[W](other: JavaPairRDD[K, W], numSplits: Int): JavaPairRDD[K, (Option[V], W)] =
fromRDD(rdd.rightOuterJoin(other, numSplits))
+ /**
+ * Return the key-value pairs in this RDD to the master as a Map.
+ */
def collectAsMap(): java.util.Map[K, V] = mapAsJavaMap(rdd.collectAsMap())
+ /**
+ * Pass each value in the key-value pair RDD through a map function without changing the keys;
+ * this also retains the original RDD's partitioning.
+ */
def mapValues[U](f: Function[V, U]): JavaPairRDD[K, U] = {
implicit val cm: ClassManifest[U] =
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[U]]
fromRDD(rdd.mapValues(f))
}
+ /**
+ * Pass each value in the key-value pair RDD through a flatMap function without changing the
+ * keys; this also retains the original RDD's partitioning.
+ */
def flatMapValues[U](f: JFunction[V, java.lang.Iterable[U]]): JavaPairRDD[K, U] = {
import scala.collection.JavaConverters._
def fn = (x: V) => f.apply(x).asScala
@@ -165,37 +322,68 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
fromRDD(rdd.flatMapValues(fn))
}
+ /**
+ * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the
+ * list of values for that key in `this` as well as `other`.
+ */
def cogroup[W](other: JavaPairRDD[K, W], partitioner: Partitioner)
: JavaPairRDD[K, (JList[V], JList[W])] =
fromRDD(cogroupResultToJava(rdd.cogroup(other, partitioner)))
+ /**
+ * For each key k in `this` or `other1` or `other2`, return a resulting RDD that contains a
+ * tuple with the list of values for that key in `this`, `other1` and `other2`.
+ */
def cogroup[W1, W2](other1: JavaPairRDD[K, W1], other2: JavaPairRDD[K, W2], partitioner: Partitioner)
: JavaPairRDD[K, (JList[V], JList[W1], JList[W2])] =
fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2, partitioner)))
+ /**
+ * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the
+ * list of values for that key in `this` as well as `other`.
+ */
def cogroup[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (JList[V], JList[W])] =
fromRDD(cogroupResultToJava(rdd.cogroup(other)))
+ /**
+ * For each key k in `this` or `other1` or `other2`, return a resulting RDD that contains a
+ * tuple with the list of values for that key in `this`, `other1` and `other2`.
+ */
def cogroup[W1, W2](other1: JavaPairRDD[K, W1], other2: JavaPairRDD[K, W2])
: JavaPairRDD[K, (JList[V], JList[W1], JList[W2])] =
fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2)))
+ /**
+ * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the
+ * list of values for that key in `this` as well as `other`.
+ */
def cogroup[W](other: JavaPairRDD[K, W], numSplits: Int): JavaPairRDD[K, (JList[V], JList[W])]
= fromRDD(cogroupResultToJava(rdd.cogroup(other, numSplits)))
+ /**
+ * For each key k in `this` or `other1` or `other2`, return a resulting RDD that contains a
+ * tuple with the list of values for that key in `this`, `other1` and `other2`.
+ */
def cogroup[W1, W2](other1: JavaPairRDD[K, W1], other2: JavaPairRDD[K, W2], numSplits: Int)
: JavaPairRDD[K, (JList[V], JList[W1], JList[W2])] =
fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2, numSplits)))
+ /** Alias for cogroup. */
def groupWith[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (JList[V], JList[W])] =
fromRDD(cogroupResultToJava(rdd.groupWith(other)))
+ /** Alias for cogroup. */
def groupWith[W1, W2](other1: JavaPairRDD[K, W1], other2: JavaPairRDD[K, W2])
: JavaPairRDD[K, (JList[V], JList[W1], JList[W2])] =
fromRDD(cogroupResult2ToJava(rdd.groupWith(other1, other2)))
+ /**
+ * Return the list of values in the RDD for key `key`. This operation is done efficiently if the
+ * RDD has a known partitioner by only searching the partition that the key maps to.
+ */
def lookup(key: K): JList[V] = seqAsJavaList(rdd.lookup(key))
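Again for illustration only (hypothetical class and data): a sketch that exercises mapValues, cogroup, collectAsMap and lookup as documented above.

    import java.util.Arrays;
    import java.util.List;
    import java.util.Map;
    import scala.Tuple2;
    import spark.api.java.JavaPairRDD;
    import spark.api.java.JavaSparkContext;
    import spark.api.java.function.Function;

    // Hypothetical example of the per-key operations on JavaPairRDD.
    public class PairOpsExample {
      public static void main(String[] args) {
        JavaSparkContext sc = new JavaSparkContext("local", "PairOpsExample");
        JavaPairRDD<String, Integer> a = sc.parallelizePairs(Arrays.asList(
            new Tuple2<String, Integer>("x", 1), new Tuple2<String, Integer>("y", 2)));
        JavaPairRDD<String, Integer> b = sc.parallelizePairs(Arrays.asList(
            new Tuple2<String, Integer>("x", 10)));
        // Transform only the values; keys and partitioning are preserved.
        JavaPairRDD<String, Integer> doubled = a.mapValues(new Function<Integer, Integer>() {
          public Integer call(Integer v) { return v * 2; }
        });
        // cogroup returns, for every key, the list of values from each side.
        JavaPairRDD<String, Tuple2<List<Integer>, List<Integer>>> grouped = a.cogroup(b);
        Map<String, Integer> asMap = doubled.collectAsMap();  // {x=2, y=4}
        List<Integer> valuesOfX = a.lookup("x");              // [1]
        System.out.println(asMap + " " + grouped.collect() + " " + valuesOfX);
        sc.stop();
      }
    }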
+ /** Output the RDD to any Hadoop-supported file system. */
def saveAsHadoopFile[F <: OutputFormat[_, _]](
path: String,
keyClass: Class[_],
@@ -205,6 +393,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
rdd.saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass, conf)
}
+ /** Output the RDD to any Hadoop-supported file system. */
def saveAsHadoopFile[F <: OutputFormat[_, _]](
path: String,
keyClass: Class[_],
@@ -213,6 +402,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
rdd.saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass)
}
+ /** Output the RDD to any Hadoop-supported file system. */
def saveAsNewAPIHadoopFile[F <: NewOutputFormat[_, _]](
path: String,
keyClass: Class[_],
@@ -222,6 +412,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
rdd.saveAsNewAPIHadoopFile(path, keyClass, valueClass, outputFormatClass, conf)
}
+ /** Output the RDD to any Hadoop-supported file system. */
def saveAsNewAPIHadoopFile[F <: NewOutputFormat[_, _]](
path: String,
keyClass: Class[_],
@@ -230,21 +421,49 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
rdd.saveAsNewAPIHadoopFile(path, keyClass, valueClass, outputFormatClass)
}
+ /**
+ * Output the RDD to any Hadoop-supported storage system, using a Hadoop JobConf object for
+ * that storage system. The JobConf should set an OutputFormat and any output paths required
+ * (e.g. a table name to write to) in the same way as it would be configured for a Hadoop
+ * MapReduce job.
+ */
def saveAsHadoopDataset(conf: JobConf) {
rdd.saveAsHadoopDataset(conf)
}
-
- // Ordered RDD Functions
+ /**
+ * Sort the RDD by key, so that each partition contains a sorted range of the elements in
+ * ascending order. Calling `collect` or `save` on the resulting RDD will return or output an
+ * ordered list of records (in the `save` case, they will be written to multiple `part-X` files
+ * in the filesystem, in order of the keys).
+ */
def sortByKey(): JavaPairRDD[K, V] = sortByKey(true)
+ /**
+ * Sort the RDD by key, so that each partition contains a sorted range of the elements. Calling
+ * `collect` or `save` on the resulting RDD will return or output an ordered list of records
+ * (in the `save` case, they will be written to multiple `part-X` files in the filesystem, in
+ * order of the keys).
+ */
def sortByKey(ascending: Boolean): JavaPairRDD[K, V] = {
val comp = com.google.common.collect.Ordering.natural().asInstanceOf[Comparator[K]]
sortByKey(comp, ascending)
}
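A sortByKey sketch for illustration (hypothetical class and data): String keys are sorted with their natural ordering, and the boolean overload flips the direction.

    import java.util.Arrays;
    import scala.Tuple2;
    import spark.api.java.JavaPairRDD;
    import spark.api.java.JavaSparkContext;

    // Hypothetical example: sort a pair RDD by its String keys.
    public class SortExample {
      public static void main(String[] args) {
        JavaSparkContext sc = new JavaSparkContext("local", "SortExample");
        JavaPairRDD<String, Integer> pairs = sc.parallelizePairs(Arrays.asList(
            new Tuple2<String, Integer>("b", 2), new Tuple2<String, Integer>("a", 1)));
        // Each partition of the result holds a sorted, contiguous range of keys.
        System.out.println(pairs.sortByKey().collect());       // ascending: [(a,1), (b,2)]
        System.out.println(pairs.sortByKey(false).collect());  // descending: [(b,2), (a,1)]
        sc.stop();
      }
    }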
+ /**
+ * Sort the RDD by key, so that each partition contains a sorted range of the elements. Calling
+ * `collect` or `save` on the resulting RDD will return or output an ordered list of records
+ * (in the `save` case, they will be written to multiple `part-X` files in the filesystem, in
+ * order of the keys).
+ */
def sortByKey(comp: Comparator[K]): JavaPairRDD[K, V] = sortByKey(comp, true)
+ /**
+ * Sort the RDD by key, so that each partition contains a sorted range of the elements. Calling
+ * `collect` or `save` on the resulting RDD will return or output an ordered list of records
+ * (in the `save` case, they will be written to multiple `part-X` files in the filesystem, in
+ * order of the keys).
+ */
def sortByKey(comp: Comparator[K], ascending: Boolean): JavaPairRDD[K, V] = {
class KeyOrdering(val a: K) extends Ordered[K] {
override def compare(b: K) = comp.compare(a, b)
@@ -274,4 +493,4 @@ object JavaPairRDD {
new JavaPairRDD[K, V](rdd)
implicit def toRDD[K, V](rdd: JavaPairRDD[K, V]): RDD[(K, V)] = rdd.rdd
-} \ No newline at end of file
+}
diff --git a/core/src/main/scala/spark/api/java/JavaRDD.scala b/core/src/main/scala/spark/api/java/JavaRDD.scala
index 541aa1e60b..ac31350ec3 100644
--- a/core/src/main/scala/spark/api/java/JavaRDD.scala
+++ b/core/src/main/scala/spark/api/java/JavaRDD.scala
@@ -11,20 +11,43 @@ JavaRDDLike[T, JavaRDD[T]] {
// Common RDD functions
+ /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
def cache(): JavaRDD[T] = wrapRDD(rdd.cache())
+ /**
+ * Set this RDD's storage level to persist its values across operations after the first time
+ * it is computed. Can only be called once on each RDD.
+ */
def persist(newLevel: StorageLevel): JavaRDD[T] = wrapRDD(rdd.persist(newLevel))
// Transformations (return a new RDD)
+ /**
+ * Return a new RDD containing the distinct elements in this RDD.
+ */
def distinct(): JavaRDD[T] = wrapRDD(rdd.distinct())
+ /**
+ * Return a new RDD containing the distinct elements in this RDD.
+ */
+ def distinct(numSplits: Int): JavaRDD[T] = wrapRDD(rdd.distinct(numSplits))
+
+ /**
+ * Return a new RDD containing only the elements that satisfy a predicate.
+ */
def filter(f: JFunction[T, java.lang.Boolean]): JavaRDD[T] =
wrapRDD(rdd.filter((x => f(x).booleanValue())))
+ /**
+ * Return a sampled subset of this RDD.
+ */
def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaRDD[T] =
wrapRDD(rdd.sample(withReplacement, fraction, seed))
-
+
+ /**
+ * Return the union of this RDD and another one. Any identical elements will appear multiple
+ * times (use `.distinct()` to eliminate them).
+ */
def union(other: JavaRDD[T]): JavaRDD[T] = wrapRDD(rdd.union(other.rdd))
}
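For illustration (hypothetical class and data), a sketch of the JavaRDD transformations documented above; filter takes a spark.api.java.function.Function returning java.lang.Boolean.

    import java.util.Arrays;
    import spark.api.java.JavaRDD;
    import spark.api.java.JavaSparkContext;
    import spark.api.java.function.Function;

    // Hypothetical example of filter, distinct, union and sample on a JavaRDD.
    public class TransformExample {
      public static void main(String[] args) {
        JavaSparkContext sc = new JavaSparkContext("local", "TransformExample");
        JavaRDD<Integer> nums = sc.parallelize(Arrays.asList(1, 2, 2, 3, 4));
        JavaRDD<Integer> evens = nums.filter(new Function<Integer, Boolean>() {
          public Boolean call(Integer x) { return x % 2 == 0; }
        });
        JavaRDD<Integer> unique = nums.distinct();               // drops the duplicate 2
        JavaRDD<Integer> both = evens.union(unique);             // union keeps duplicates
        JavaRDD<Integer> sampled = nums.sample(false, 0.5, 42);  // ~50% sample, fixed seed
        System.out.println(both.collect() + " " + sampled.collect());
        sc.stop();
      }
    }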
diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala
index 785dd96394..13fcee1004 100644
--- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/spark/api/java/JavaRDDLike.scala
@@ -19,41 +19,71 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
def rdd: RDD[T]
+ /** Set of partitions in this RDD. */
def splits: JList[Split] = new java.util.ArrayList(rdd.splits.toSeq)
+ /** The [[spark.SparkContext]] that this RDD was created on. */
def context: SparkContext = rdd.context
-
+
+ /** A unique ID for this RDD (within its SparkContext). */
def id: Int = rdd.id
+ /** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */
def getStorageLevel: StorageLevel = rdd.getStorageLevel
+ /**
+ * Internal method to this RDD; will read from cache if applicable, or otherwise compute it.
+ * This should ''not'' be called by users directly, but is available for implementors of custom
+ * subclasses of RDD.
+ */
def iterator(split: Split): java.util.Iterator[T] = asJavaIterator(rdd.iterator(split))
// Transformations (return a new RDD)
+ /**
+ * Return a new RDD by applying a function to all elements of this RDD.
+ */
def map[R](f: JFunction[T, R]): JavaRDD[R] =
new JavaRDD(rdd.map(f)(f.returnType()))(f.returnType())
+ /**
+ * Return a new RDD by applying a function to all elements of this RDD.
+ */
def map[R](f: DoubleFunction[T]): JavaDoubleRDD =
new JavaDoubleRDD(rdd.map(x => f(x).doubleValue()))
+ /**
+ * Return a new RDD by applying a function to all elements of this RDD.
+ */
def map[K2, V2](f: PairFunction[T, K2, V2]): JavaPairRDD[K2, V2] = {
def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K2, V2]]]
new JavaPairRDD(rdd.map(f)(cm))(f.keyType(), f.valueType())
}
+ /**
+ * Return a new RDD by first applying a function to all elements of this
+ * RDD, and then flattening the results.
+ */
def flatMap[U](f: FlatMapFunction[T, U]): JavaRDD[U] = {
import scala.collection.JavaConverters._
def fn = (x: T) => f.apply(x).asScala
JavaRDD.fromRDD(rdd.flatMap(fn)(f.elementType()))(f.elementType())
}
+ /**
+ * Return a new RDD by first applying a function to all elements of this
+ * RDD, and then flattening the results.
+ */
def flatMap(f: DoubleFlatMapFunction[T]): JavaDoubleRDD = {
import scala.collection.JavaConverters._
def fn = (x: T) => f.apply(x).asScala
new JavaDoubleRDD(rdd.flatMap(fn).map((x: java.lang.Double) => x.doubleValue()))
}
+ /**
+ * Return a new RDD by first applying a function to all elements of this
+ * RDD, and then flattening the results.
+ */
def flatMap[K, V](f: PairFlatMapFunction[T, K, V]): JavaPairRDD[K, V] = {
import scala.collection.JavaConverters._
def fn = (x: T) => f.apply(x).asScala
@@ -61,29 +91,50 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
JavaPairRDD.fromRDD(rdd.flatMap(fn)(cm))(f.keyType(), f.valueType())
}
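To make the map/flatMap overloads above concrete, an illustrative sketch (hypothetical class and data): a FlatMapFunction splits lines into words, and the PairFunction overload of map builds a JavaPairRDD of (word, 1) pairs.

    import java.util.Arrays;
    import scala.Tuple2;
    import spark.api.java.JavaPairRDD;
    import spark.api.java.JavaRDD;
    import spark.api.java.JavaSparkContext;
    import spark.api.java.function.FlatMapFunction;
    import spark.api.java.function.PairFunction;

    // Hypothetical example: the first half of a word count in the Java API.
    public class MapExample {
      public static void main(String[] args) {
        JavaSparkContext sc = new JavaSparkContext("local", "MapExample");
        JavaRDD<String> lines = sc.parallelize(Arrays.asList("to be", "or not to be"));
        // flatMap: one line in, zero or more words out.
        JavaRDD<String> words = lines.flatMap(new FlatMapFunction<String, String>() {
          public Iterable<String> call(String line) { return Arrays.asList(line.split(" ")); }
        });
        // The PairFunction overload of map yields a JavaPairRDD instead of a JavaRDD.
        JavaPairRDD<String, Integer> ones = words.map(new PairFunction<String, String, Integer>() {
          public Tuple2<String, Integer> call(String w) { return new Tuple2<String, Integer>(w, 1); }
        });
        System.out.println(ones.collect());
        sc.stop();
      }
    }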
+ /**
+ * Return a new RDD by applying a function to each partition of this RDD.
+ */
def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U]): JavaRDD[U] = {
def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator())
JavaRDD.fromRDD(rdd.mapPartitions(fn)(f.elementType()))(f.elementType())
}
+
+ /**
+ * Return a new RDD by applying a function to each partition of this RDD.
+ */
def mapPartitions(f: DoubleFlatMapFunction[java.util.Iterator[T]]): JavaDoubleRDD = {
def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator())
new JavaDoubleRDD(rdd.mapPartitions(fn).map((x: java.lang.Double) => x.doubleValue()))
}
+ /**
+ * Return a new RDD by applying a function to each partition of this RDD.
+ */
def mapPartitions[K, V](f: PairFlatMapFunction[java.util.Iterator[T], K, V]):
JavaPairRDD[K, V] = {
def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator())
JavaPairRDD.fromRDD(rdd.mapPartitions(fn))(f.keyType(), f.valueType())
}
+ /**
+ * Return an RDD created by coalescing all elements within each partition into an array.
+ */
def glom(): JavaRDD[JList[T]] =
new JavaRDD(rdd.glom().map(x => new java.util.ArrayList[T](x.toSeq)))
+ /**
+ * Return the Cartesian product of this RDD and another one, that is, the RDD of all pairs of
+ * elements (a, b) where a is in `this` and b is in `other`.
+ */
def cartesian[U](other: JavaRDDLike[U, _]): JavaPairRDD[T, U] =
JavaPairRDD.fromRDD(rdd.cartesian(other.rdd)(other.classManifest))(classManifest,
other.classManifest)
+ /**
+ * Return an RDD of grouped elements. Each group consists of a key and a sequence of elements
+ * mapping to that key.
+ */
def groupBy[K](f: JFunction[T, K]): JavaPairRDD[K, JList[T]] = {
implicit val kcm: ClassManifest[K] =
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]]
@@ -92,6 +143,10 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f)(f.returnType)))(kcm, vcm)
}
+ /**
+ * Return an RDD of grouped elements. Each group consists of a key and a sequence of elements
+ * mapping to that key.
+ */
def groupBy[K](f: JFunction[T, K], numSplits: Int): JavaPairRDD[K, JList[T]] = {
implicit val kcm: ClassManifest[K] =
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]]
@@ -100,56 +155,114 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f, numSplits)(f.returnType)))(kcm, vcm)
}
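An illustrative groupBy sketch (hypothetical class and data): the key returned by the function becomes the key of the resulting pair RDD, and the matching elements are collected into a java.util.List.

    import java.util.Arrays;
    import java.util.List;
    import spark.api.java.JavaPairRDD;
    import spark.api.java.JavaRDD;
    import spark.api.java.JavaSparkContext;
    import spark.api.java.function.Function;

    // Hypothetical example: group numbers by parity.
    public class GroupByExample {
      public static void main(String[] args) {
        JavaSparkContext sc = new JavaSparkContext("local", "GroupByExample");
        JavaRDD<Integer> nums = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
        JavaPairRDD<Boolean, List<Integer>> byParity = nums.groupBy(new Function<Integer, Boolean>() {
          public Boolean call(Integer x) { return x % 2 == 0; }
        });
        System.out.println(byParity.collect());  // e.g. [(false,[1,3,5]), (true,[2,4])]
        sc.stop();
      }
    }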
+ /**
+ * Return an RDD created by piping elements to a forked external process.
+ */
def pipe(command: String): JavaRDD[String] = rdd.pipe(command)
+ /**
+ * Return an RDD created by piping elements to a forked external process.
+ */
def pipe(command: JList[String]): JavaRDD[String] =
rdd.pipe(asScalaBuffer(command))
+ /**
+ * Return an RDD created by piping elements to a forked external process.
+ */
def pipe(command: JList[String], env: java.util.Map[String, String]): JavaRDD[String] =
rdd.pipe(asScalaBuffer(command), mapAsScalaMap(env))
// Actions (launch a job to return a value to the user program)
-
+
+ /**
+ * Applies a function f to all elements of this RDD.
+ */
def foreach(f: VoidFunction[T]) {
val cleanF = rdd.context.clean(f)
rdd.foreach(cleanF)
}
+ /**
+ * Return an array that contains all of the elements in this RDD.
+ */
def collect(): JList[T] = {
import scala.collection.JavaConversions._
val arr: java.util.Collection[T] = rdd.collect().toSeq
new java.util.ArrayList(arr)
}
-
+
+ /**
+ * Reduces the elements of this RDD using the specified associative binary operator.
+ */
def reduce(f: JFunction2[T, T, T]): T = rdd.reduce(f)
+ /**
+ * Aggregate the elements of each partition, and then the results for all the partitions, using a
+ * given associative function and a neutral "zero value". The function op(t1, t2) is allowed to
+ * modify t1 and return it as its result value to avoid object allocation; however, it should not
+ * modify t2.
+ */
def fold(zeroValue: T)(f: JFunction2[T, T, T]): T =
rdd.fold(zeroValue)(f)
+ /**
+ * Aggregate the elements of each partition, and then the results for all the partitions, using
+ * given combine functions and a neutral "zero value". This function can return a different result
+ * type, U, than the type of this RDD, T. Thus, we need one operation for merging a T into an U
+ * and one operation for merging two U's, as in scala.TraversableOnce. Both of these functions are
+ * allowed to modify and return their first argument instead of creating a new U to avoid memory
+ * allocation.
+ */
def aggregate[U](zeroValue: U)(seqOp: JFunction2[U, T, U],
combOp: JFunction2[U, U, U]): U =
rdd.aggregate(zeroValue)(seqOp, combOp)(seqOp.returnType)
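A reduce/fold sketch for illustration (hypothetical class and data). Note that fold is declared with two Scala parameter lists, which appear as a single flattened parameter list when the method is called from Java.

    import java.util.Arrays;
    import spark.api.java.JavaRDD;
    import spark.api.java.JavaSparkContext;
    import spark.api.java.function.Function2;

    // Hypothetical example: aggregate actions with an associative addition function.
    public class ReduceExample {
      public static void main(String[] args) {
        JavaSparkContext sc = new JavaSparkContext("local", "ReduceExample");
        JavaRDD<Integer> nums = sc.parallelize(Arrays.asList(1, 2, 3, 4));
        Function2<Integer, Integer, Integer> add = new Function2<Integer, Integer, Integer>() {
          public Integer call(Integer a, Integer b) { return a + b; }
        };
        System.out.println(nums.reduce(add));   // 10
        System.out.println(nums.fold(0, add));  // 10, with 0 as the neutral "zero value"
        sc.stop();
      }
    }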
+ /**
+ * Return the number of elements in the RDD.
+ */
def count(): Long = rdd.count()
+ /**
+ * (Experimental) Approximate version of count() that returns a potentially incomplete result
+ * within a timeout, even if not all tasks have finished.
+ */
def countApprox(timeout: Long, confidence: Double): PartialResult[BoundedDouble] =
rdd.countApprox(timeout, confidence)
+ /**
+ * (Experimental) Approximate version of count() that returns a potentially incomplete result
+ * within a timeout, even if not all tasks have finished.
+ */
def countApprox(timeout: Long): PartialResult[BoundedDouble] =
rdd.countApprox(timeout)
+ /**
+ * Return the count of each unique value in this RDD as a map of (value, count) pairs. The final
+ * combine step happens locally on the master, equivalent to running a single reduce task.
+ */
def countByValue(): java.util.Map[T, java.lang.Long] =
mapAsJavaMap(rdd.countByValue().map((x => (x._1, new lang.Long(x._2)))))
+ /**
+ * (Experimental) Approximate version of countByValue().
+ */
def countByValueApprox(
timeout: Long,
confidence: Double
): PartialResult[java.util.Map[T, BoundedDouble]] =
rdd.countByValueApprox(timeout, confidence).map(mapAsJavaMap)
+ /**
+ * (Experimental) Approximate version of countByValue().
+ */
def countByValueApprox(timeout: Long): PartialResult[java.util.Map[T, BoundedDouble]] =
rdd.countByValueApprox(timeout).map(mapAsJavaMap)
+ /**
+ * Take the first num elements of the RDD. This currently scans the partitions *one by one*, so
+ * it will be slow if a lot of partitions are required. In that case, use collect() to get the
+ * whole RDD instead.
+ */
def take(num: Int): JList[T] = {
import scala.collection.JavaConversions._
val arr: java.util.Collection[T] = rdd.take(num).toSeq
@@ -162,9 +275,18 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
new java.util.ArrayList(arr)
}
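Illustrative sketch of the counting actions and take (hypothetical class and data):

    import java.util.Arrays;
    import spark.api.java.JavaRDD;
    import spark.api.java.JavaSparkContext;

    // Hypothetical example: count, countByValue and take.
    public class CountExample {
      public static void main(String[] args) {
        JavaSparkContext sc = new JavaSparkContext("local", "CountExample");
        JavaRDD<String> words = sc.parallelize(Arrays.asList("a", "b", "a", "c"));
        System.out.println(words.count());         // 4
        System.out.println(words.countByValue());  // {a=2, b=1, c=1}
        System.out.println(words.take(2));         // first two elements, scanning partitions in order
        sc.stop();
      }
    }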
+ /**
+ * Return the first element in this RDD.
+ */
def first(): T = rdd.first()
+ /**
+ * Save this RDD as a text file, using string representations of elements.
+ */
def saveAsTextFile(path: String) = rdd.saveAsTextFile(path)
+ /**
+ * Save this RDD as a SequenceFile of serialized objects.
+ */
def saveAsObjectFile(path: String) = rdd.saveAsObjectFile(path)
}
diff --git a/core/src/main/scala/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/spark/api/java/JavaSparkContext.scala
index 4a7d945a8d..edbb187b1b 100644
--- a/core/src/main/scala/spark/api/java/JavaSparkContext.scala
+++ b/core/src/main/scala/spark/api/java/JavaSparkContext.scala
@@ -1,43 +1,78 @@
package spark.api.java
-import spark.{Accumulator, AccumulatorParam, RDD, SparkContext}
-import spark.SparkContext.IntAccumulatorParam
-import spark.SparkContext.DoubleAccumulatorParam
-import spark.broadcast.Broadcast
+import java.util.{Map => JMap}
+import scala.collection.JavaConversions
import scala.collection.JavaConversions._
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapred.InputFormat
import org.apache.hadoop.mapred.JobConf
-
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
+import spark.{Accumulator, AccumulatorParam, RDD, SparkContext}
+import spark.SparkContext.IntAccumulatorParam
+import spark.SparkContext.DoubleAccumulatorParam
+import spark.broadcast.Broadcast
-import scala.collection.JavaConversions
-
+/**
+ * A Java-friendly version of [[spark.SparkContext]] that returns [[spark.api.java.JavaRDD]]s and
+ * works with Java collections instead of Scala ones.
+ */
class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWorkaround {
- def this(master: String, frameworkName: String) = this(new SparkContext(master, frameworkName))
+ /**
+ * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
+ * @param jobName A name for your job, to display on the cluster web UI
+ */
+ def this(master: String, jobName: String) = this(new SparkContext(master, jobName))
+
+ /**
+ * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
+ * @param jobName A name for your job, to display on the cluster web UI
+ * @param sparkHome The SPARK_HOME directory on the slave nodes
+ * @param jarFile JAR file to send to the cluster. This can be a path on the local file
+ * system or an HDFS, HTTP, HTTPS, or FTP URL.
+ */
+ def this(master: String, jobName: String, sparkHome: String, jarFile: String) =
+ this(new SparkContext(master, jobName, sparkHome, Seq(jarFile)))
- def this(master: String, frameworkName: String, sparkHome: String, jarFile: String) =
- this(new SparkContext(master, frameworkName, sparkHome, Seq(jarFile)))
+ /**
+ * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
+ * @param jobName A name for your job, to display on the cluster web UI
+ * @param sparkHome The SPARK_HOME directory on the slave nodes
+ * @param jars Collection of JARs to send to the cluster. These can be paths on the local file
+ * system or HDFS, HTTP, HTTPS, or FTP URLs.
+ */
+ def this(master: String, jobName: String, sparkHome: String, jars: Array[String]) =
+ this(new SparkContext(master, jobName, sparkHome, jars.toSeq))
- def this(master: String, frameworkName: String, sparkHome: String, jars: Array[String]) =
- this(new SparkContext(master, frameworkName, sparkHome, jars.toSeq))
+ /**
+ * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
+ * @param jobName A name for your job, to display on the cluster web UI
+ * @param sparkHome The SPARK_HOME directory on the slave nodes
+ * @param jars Collection of JARs to send to the cluster. These can be paths on the local file
+ * system or HDFS, HTTP, HTTPS, or FTP URLs.
+ * @param environment Environment variables to set on worker nodes
+ */
+ def this(master: String, jobName: String, sparkHome: String, jars: Array[String],
+ environment: JMap[String, String]) =
+ this(new SparkContext(master, jobName, sparkHome, jars.toSeq, environment))
- val env = sc.env
+ private[spark] val env = sc.env
+ /** Distribute a local Java collection to form an RDD. */
def parallelize[T](list: java.util.List[T], numSlices: Int): JavaRDD[T] = {
implicit val cm: ClassManifest[T] =
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
sc.parallelize(JavaConversions.asScalaBuffer(list), numSlices)
}
+ /** Distribute a local Java collection to form an RDD. */
def parallelize[T](list: java.util.List[T]): JavaRDD[T] =
parallelize(list, sc.defaultParallelism)
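For illustration (hypothetical class and data), a sketch of constructing a JavaSparkContext against a local master and distributing a java.util.List:

    import java.util.Arrays;
    import spark.api.java.JavaRDD;
    import spark.api.java.JavaSparkContext;

    // Hypothetical example: create a context and parallelize a local collection.
    public class ContextExample {
      public static void main(String[] args) {
        JavaSparkContext sc = new JavaSparkContext("local", "ContextExample");
        // Two slices requested explicitly; the one-argument overload uses defaultParallelism.
        JavaRDD<Integer> nums = sc.parallelize(Arrays.asList(1, 2, 3, 4), 2);
        System.out.println(nums.count());  // 4
        sc.stop();
      }
    }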
-
+ /** Distribute a local Java collection to form an RDD. */
def parallelizePairs[K, V](list: java.util.List[Tuple2[K, V]], numSlices: Int)
: JavaPairRDD[K, V] = {
implicit val kcm: ClassManifest[K] =
@@ -47,21 +82,32 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
JavaPairRDD.fromRDD(sc.parallelize(JavaConversions.asScalaBuffer(list), numSlices))
}
+ /** Distribute a local Java collection to form an RDD. */
def parallelizePairs[K, V](list: java.util.List[Tuple2[K, V]]): JavaPairRDD[K, V] =
parallelizePairs(list, sc.defaultParallelism)
+ /** Distribute a local Java collection to form an RDD. */
def parallelizeDoubles(list: java.util.List[java.lang.Double], numSlices: Int): JavaDoubleRDD =
JavaDoubleRDD.fromRDD(sc.parallelize(JavaConversions.asScalaBuffer(list).map(_.doubleValue()),
numSlices))
+ /** Distribute a local Java collection to form an RDD. */
def parallelizeDoubles(list: java.util.List[java.lang.Double]): JavaDoubleRDD =
parallelizeDoubles(list, sc.defaultParallelism)
+ /**
+ * Read a text file from HDFS, a local file system (available on all nodes), or any
+ * Hadoop-supported file system URI, and return it as an RDD of Strings.
+ */
def textFile(path: String): JavaRDD[String] = sc.textFile(path)
+ /**
+ * Read a text file from HDFS, a local file system (available on all nodes), or any
+ * Hadoop-supported file system URI, and return it as an RDD of Strings.
+ */
def textFile(path: String, minSplits: Int): JavaRDD[String] = sc.textFile(path, minSplits)
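An illustrative sketch of textFile and saveAsTextFile; the class name and the /tmp paths are made up, and the input file must already exist for the job to succeed.

    import spark.api.java.JavaRDD;
    import spark.api.java.JavaSparkContext;

    // Hypothetical example: read a text file line by line and write it back out.
    public class TextFileExample {
      public static void main(String[] args) {
        JavaSparkContext sc = new JavaSparkContext("local", "TextFileExample");
        JavaRDD<String> lines = sc.textFile("/tmp/input.txt");  // one record per line
        lines.saveAsTextFile("/tmp/output");                    // writes part-* files under /tmp/output
        sc.stop();
      }
    }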
- /**Get an RDD for a Hadoop SequenceFile with given key and value types */
+ /** Get an RDD for a Hadoop SequenceFile with given key and value types. */
def sequenceFile[K, V](path: String,
keyClass: Class[K],
valueClass: Class[V],
@@ -72,6 +118,7 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
new JavaPairRDD(sc.sequenceFile(path, keyClass, valueClass, minSplits))
}
+ /** Get an RDD for a Hadoop SequenceFile. */
def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V]):
JavaPairRDD[K, V] = {
implicit val kcm = ClassManifest.fromClass(keyClass)
@@ -92,6 +139,13 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
sc.objectFile(path, minSplits)(cm)
}
+ /**
+ * Load an RDD saved as a SequenceFile containing serialized objects, with NullWritable keys and
+ * BytesWritable values that contain a serialized partition. This is still an experimental storage
+ * format and may not be supported exactly as is in future Spark releases. It will also be pretty
+ * slow if you use the default serializer (Java serialization), though the nice thing about it is
+ * that there's very little effort required to save arbitrary objects.
+ */
def objectFile[T](path: String): JavaRDD[T] = {
implicit val cm: ClassManifest[T] =
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
@@ -115,6 +169,11 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
new JavaPairRDD(sc.hadoopRDD(conf, inputFormatClass, keyClass, valueClass, minSplits))
}
+ /**
+ * Get an RDD for a Hadoop-readable dataset from a Hadoop JobConf giving its InputFormat and any
+ * other necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable,
+ * etc).
+ */
def hadoopRDD[K, V, F <: InputFormat[K, V]](
conf: JobConf,
inputFormatClass: Class[F],
@@ -126,7 +185,7 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
new JavaPairRDD(sc.hadoopRDD(conf, inputFormatClass, keyClass, valueClass))
}
- /**Get an RDD for a Hadoop file with an arbitrary InputFormat */
+ /** Get an RDD for a Hadoop file with an arbitrary InputFormat. */
def hadoopFile[K, V, F <: InputFormat[K, V]](
path: String,
inputFormatClass: Class[F],
@@ -139,6 +198,7 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
new JavaPairRDD(sc.hadoopFile(path, inputFormatClass, keyClass, valueClass, minSplits))
}
+ /** Get an RDD for a Hadoop file with an arbitrary InputFormat. */
def hadoopFile[K, V, F <: InputFormat[K, V]](
path: String,
inputFormatClass: Class[F],
@@ -180,12 +240,14 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
new JavaPairRDD(sc.newAPIHadoopRDD(conf, fClass, kClass, vClass))
}
+ /** Build the union of two or more RDDs. */
override def union[T](first: JavaRDD[T], rest: java.util.List[JavaRDD[T]]): JavaRDD[T] = {
val rdds: Seq[RDD[T]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.rdd)
implicit val cm: ClassManifest[T] = first.classManifest
sc.union(rdds)(cm)
}
+ /** Build the union of two or more RDDs. */
override def union[K, V](first: JavaPairRDD[K, V], rest: java.util.List[JavaPairRDD[K, V]])
: JavaPairRDD[K, V] = {
val rdds: Seq[RDD[(K, V)]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.rdd)
@@ -195,26 +257,49 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
new JavaPairRDD(sc.union(rdds)(cm))(kcm, vcm)
}
+ /** Build the union of two or more RDDs. */
override def union(first: JavaDoubleRDD, rest: java.util.List[JavaDoubleRDD]): JavaDoubleRDD = {
val rdds: Seq[RDD[Double]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.srdd)
new JavaDoubleRDD(sc.union(rdds))
}
+ /**
+ * Create an [[spark.Accumulator]] integer variable, which tasks can "add" values
+ * to using the `+=` method. Only the master can access the accumulator's `value`.
+ */
def intAccumulator(initialValue: Int): Accumulator[Int] =
sc.accumulator(initialValue)(IntAccumulatorParam)
+ /**
+ * Create an [[spark.Accumulator]] double variable, which tasks can "add" values
+ * to using the `+=` method. Only the master can access the accumulator's `value`.
+ */
def doubleAccumulator(initialValue: Double): Accumulator[Double] =
sc.accumulator(initialValue)(DoubleAccumulatorParam)
+ /**
+ * Create an [[spark.Accumulator]] variable of a given type, which tasks can "add" values
+ * to using the `+=` method. Only the master can access the accumulator's `value`.
+ */
def accumulator[T](initialValue: T, accumulatorParam: AccumulatorParam[T]): Accumulator[T] =
sc.accumulator(initialValue)(accumulatorParam)
+ /**
+ * Broadcast a read-only variable to the cluster, returning a [[spark.broadcast.Broadcast]] object
+ * for reading it in distributed functions. The variable will be sent to each node only once.
+ */
def broadcast[T](value: T): Broadcast[T] = sc.broadcast(value)
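An illustrative broadcast sketch (hypothetical class and data): a small read-only list is shipped to the workers once and consulted from a filter function.

    import java.util.Arrays;
    import java.util.List;
    import spark.api.java.JavaRDD;
    import spark.api.java.JavaSparkContext;
    import spark.api.java.function.Function;
    import spark.broadcast.Broadcast;

    // Hypothetical example: use a broadcast variable inside a distributed function.
    public class BroadcastExample {
      public static void main(String[] args) {
        JavaSparkContext sc = new JavaSparkContext("local", "BroadcastExample");
        final Broadcast<List<String>> stopWords = sc.broadcast(Arrays.asList("the", "a"));
        JavaRDD<String> words = sc.parallelize(Arrays.asList("the", "quick", "a", "fox"));
        JavaRDD<String> kept = words.filter(new Function<String, Boolean>() {
          public Boolean call(String w) { return !stopWords.value().contains(w); }
        });
        System.out.println(kept.collect());  // [quick, fox]
        sc.stop();
      }
    }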
+ /** Shut down the SparkContext. */
def stop() {
sc.stop()
}
+ /**
+ * Get Spark's home location from either a value set through the constructor,
+ * or the spark.home Java property, or the SPARK_HOME environment variable
+ * (in that order of preference). If none of these is set, return None.
+ */
def getSparkHome(): Option[String] = sc.getSparkHome()
}
diff --git a/core/src/main/scala/spark/api/java/StorageLevels.java b/core/src/main/scala/spark/api/java/StorageLevels.java
new file mode 100644
index 0000000000..722af3c06c
--- /dev/null
+++ b/core/src/main/scala/spark/api/java/StorageLevels.java
@@ -0,0 +1,20 @@
+package spark.api.java;
+
+import spark.storage.StorageLevel;
+
+/**
+ * Expose some commonly useful storage level constants.
+ */
+public class StorageLevels {
+ public static final StorageLevel NONE = new StorageLevel(false, false, false, 1);
+ public static final StorageLevel DISK_ONLY = new StorageLevel(true, false, false, 1);
+ public static final StorageLevel DISK_ONLY_2 = new StorageLevel(true, false, false, 2);
+ public static final StorageLevel MEMORY_ONLY = new StorageLevel(false, true, true, 1);
+ public static final StorageLevel MEMORY_ONLY_2 = new StorageLevel(false, true, true, 2);
+ public static final StorageLevel MEMORY_ONLY_SER = new StorageLevel(false, true, false, 1);
+ public static final StorageLevel MEMORY_ONLY_SER_2 = new StorageLevel(false, true, false, 2);
+ public static final StorageLevel MEMORY_AND_DISK = new StorageLevel(true, true, true, 1);
+ public static final StorageLevel MEMORY_AND_DISK_2 = new StorageLevel(true, true, true, 2);
+ public static final StorageLevel MEMORY_AND_DISK_SER = new StorageLevel(true, true, false, 1);
+ public static final StorageLevel MEMORY_AND_DISK_SER_2 = new StorageLevel(true, true, false, 2);
+}
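For illustration (hypothetical class and data), a sketch of picking an explicit storage level from this new class instead of relying on cache():

    import java.util.Arrays;
    import spark.api.java.JavaRDD;
    import spark.api.java.JavaSparkContext;
    import spark.api.java.StorageLevels;

    // Hypothetical example: persist with MEMORY_AND_DISK so partitions spill to disk
    // instead of being recomputed when they do not fit in memory.
    public class PersistExample {
      public static void main(String[] args) {
        JavaSparkContext sc = new JavaSparkContext("local", "PersistExample");
        JavaRDD<Integer> nums = sc.parallelize(Arrays.asList(1, 2, 3));
        nums.persist(StorageLevels.MEMORY_AND_DISK);
        System.out.println(nums.count() + " " + nums.count());  // second count reads the persisted copy
        sc.stop();
      }
    }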
diff --git a/core/src/main/scala/spark/api/java/function/DoubleFlatMapFunction.java b/core/src/main/scala/spark/api/java/function/DoubleFlatMapFunction.java
index 7b6478c2cd..3a8192be3a 100644
--- a/core/src/main/scala/spark/api/java/function/DoubleFlatMapFunction.java
+++ b/core/src/main/scala/spark/api/java/function/DoubleFlatMapFunction.java
@@ -5,6 +5,9 @@ import scala.runtime.AbstractFunction1;
import java.io.Serializable;
+/**
+ * A function that returns zero or more records of type Double from each input record.
+ */
// DoubleFlatMapFunction does not extend FlatMapFunction because flatMap is
// overloaded for both FlatMapFunction and DoubleFlatMapFunction.
public abstract class DoubleFlatMapFunction<T> extends AbstractFunction1<T, Iterable<Double>>
diff --git a/core/src/main/scala/spark/api/java/function/DoubleFunction.java b/core/src/main/scala/spark/api/java/function/DoubleFunction.java
index a03a72c835..c6ef76d088 100644
--- a/core/src/main/scala/spark/api/java/function/DoubleFunction.java
+++ b/core/src/main/scala/spark/api/java/function/DoubleFunction.java
@@ -5,6 +5,9 @@ import scala.runtime.AbstractFunction1;
import java.io.Serializable;
+/**
+ * A function that returns Doubles, and can be used to construct DoubleRDDs.
+ */
// DoubleFunction does not extend Function because some UDF functions, like map,
// are overloaded for both Function and DoubleFunction.
public abstract class DoubleFunction<T> extends WrappedFunction1<T, Double>
diff --git a/core/src/main/scala/spark/api/java/function/FlatMapFunction.scala b/core/src/main/scala/spark/api/java/function/FlatMapFunction.scala
index bcba38c569..e027cdacd3 100644
--- a/core/src/main/scala/spark/api/java/function/FlatMapFunction.scala
+++ b/core/src/main/scala/spark/api/java/function/FlatMapFunction.scala
@@ -1,5 +1,8 @@
package spark.api.java.function
+/**
+ * A function that returns zero or more output records from each input record.
+ */
abstract class FlatMapFunction[T, R] extends Function[T, java.lang.Iterable[R]] {
@throws(classOf[Exception])
def call(x: T) : java.lang.Iterable[R]
diff --git a/core/src/main/scala/spark/api/java/function/Function.java b/core/src/main/scala/spark/api/java/function/Function.java
index f6f2e5fd76..dae8295f21 100644
--- a/core/src/main/scala/spark/api/java/function/Function.java
+++ b/core/src/main/scala/spark/api/java/function/Function.java
@@ -8,8 +8,9 @@ import java.io.Serializable;
/**
- * Base class for functions whose return types do not have special RDDs; DoubleFunction is
- * handled separately, to allow DoubleRDDs to be constructed when mapping RDDs to doubles.
+ * Base class for functions whose return types do not create special RDDs. PairFunction and
+ * DoubleFunction are handled separately, to allow PairRDDs and DoubleRDDs to be constructed
+ * when mapping RDDs of other types.
*/
public abstract class Function<T, R> extends WrappedFunction1<T, R> implements Serializable {
public abstract R call(T t) throws Exception;
diff --git a/core/src/main/scala/spark/api/java/function/Function2.java b/core/src/main/scala/spark/api/java/function/Function2.java
index be48b173b8..69bf12c8c9 100644
--- a/core/src/main/scala/spark/api/java/function/Function2.java
+++ b/core/src/main/scala/spark/api/java/function/Function2.java
@@ -6,6 +6,9 @@ import scala.runtime.AbstractFunction2;
import java.io.Serializable;
+/**
+ * A two-argument function that takes arguments of type T1 and T2 and returns an R.
+ */
public abstract class Function2<T1, T2, R> extends WrappedFunction2<T1, T2, R>
implements Serializable {
diff --git a/core/src/main/scala/spark/api/java/function/PairFlatMapFunction.java b/core/src/main/scala/spark/api/java/function/PairFlatMapFunction.java
index c074b9c717..b3cc4df6aa 100644
--- a/core/src/main/scala/spark/api/java/function/PairFlatMapFunction.java
+++ b/core/src/main/scala/spark/api/java/function/PairFlatMapFunction.java
@@ -7,6 +7,10 @@ import scala.runtime.AbstractFunction1;
import java.io.Serializable;
+/**
+ * A function that returns zero or more key-value pair records from each input record. The
+ * key-value pairs are represented as scala.Tuple2 objects.
+ */
// PairFlatMapFunction does not extend FlatMapFunction because flatMap is
// overloaded for both FlatMapFunction and PairFlatMapFunction.
public abstract class PairFlatMapFunction<T, K, V>
diff --git a/core/src/main/scala/spark/api/java/function/PairFunction.java b/core/src/main/scala/spark/api/java/function/PairFunction.java
index 7f5bb7de13..9fc6df4b88 100644
--- a/core/src/main/scala/spark/api/java/function/PairFunction.java
+++ b/core/src/main/scala/spark/api/java/function/PairFunction.java
@@ -7,6 +7,9 @@ import scala.runtime.AbstractFunction1;
import java.io.Serializable;
+/**
+ * A function that returns key-value pairs (Tuple2<K, V>), and can be used to construct PairRDDs.
+ */
// PairFunction does not extend Function because some UDF functions, like map,
// are overloaded for both Function and PairFunction.
public abstract class PairFunction<T, K, V>
diff --git a/core/src/main/scala/spark/api/java/function/VoidFunction.scala b/core/src/main/scala/spark/api/java/function/VoidFunction.scala
index 0eefe337e8..b0096cf2bf 100644
--- a/core/src/main/scala/spark/api/java/function/VoidFunction.scala
+++ b/core/src/main/scala/spark/api/java/function/VoidFunction.scala
@@ -1,5 +1,8 @@
package spark.api.java.function
+/**
+ * A function with no return value.
+ */
// This allows Java users to write void methods without having to return Unit.
abstract class VoidFunction[T] extends Serializable {
@throws(classOf[Exception])
diff --git a/core/src/main/scala/spark/api/java/function/WrappedFunction1.scala b/core/src/main/scala/spark/api/java/function/WrappedFunction1.scala
index d08e1e9fbf..923f5cdf4f 100644
--- a/core/src/main/scala/spark/api/java/function/WrappedFunction1.scala
+++ b/core/src/main/scala/spark/api/java/function/WrappedFunction1.scala
@@ -7,7 +7,7 @@ import scala.runtime.AbstractFunction1
* apply() method as call() and declare that it can throw Exception (since AbstractFunction1.apply
* isn't marked to allow that).
*/
-abstract class WrappedFunction1[T, R] extends AbstractFunction1[T, R] {
+private[spark] abstract class WrappedFunction1[T, R] extends AbstractFunction1[T, R] {
@throws(classOf[Exception])
def call(t: T): R
diff --git a/core/src/main/scala/spark/api/java/function/WrappedFunction2.scala b/core/src/main/scala/spark/api/java/function/WrappedFunction2.scala
index c9d67d9771..2c6e9b1571 100644
--- a/core/src/main/scala/spark/api/java/function/WrappedFunction2.scala
+++ b/core/src/main/scala/spark/api/java/function/WrappedFunction2.scala
@@ -7,7 +7,7 @@ import scala.runtime.AbstractFunction2
* apply() method as call() and declare that it can throw Exception (since AbstractFunction2.apply
* isn't marked to allow that).
*/
-abstract class WrappedFunction2[T1, T2, R] extends AbstractFunction2[T1, T2, R] {
+private[spark] abstract class WrappedFunction2[T1, T2, R] extends AbstractFunction2[T1, T2, R] {
@throws(classOf[Exception])
def call(t1: T1, t2: T2): R
diff --git a/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala b/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala
index 016dc00fb0..ef27bbb502 100644
--- a/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala
+++ b/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala
@@ -11,14 +11,17 @@ import scala.math
import spark._
import spark.storage.StorageLevel
-class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean)
-extends Broadcast[T] with Logging with Serializable {
+private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
+ extends Broadcast[T](id)
+ with Logging
+ with Serializable {
def value = value_
+ def blockId: String = "broadcast_" + id
+
MultiTracker.synchronized {
- SparkEnv.get.blockManager.putSingle(
- uuid.toString, value_, StorageLevel.MEMORY_ONLY_DESER, false)
+ SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
}
@transient var arrayOfBlocks: Array[BroadcastBlock] = null
@@ -45,7 +48,7 @@ extends Broadcast[T] with Logging with Serializable {
// Used only in Workers
@transient var ttGuide: TalkToGuide = null
- @transient var hostAddress = Utils.localIpAddress
+ @transient var hostAddress = Utils.localIpAddress()
@transient var listenPort = -1
@transient var guidePort = -1
@@ -53,7 +56,7 @@ extends Broadcast[T] with Logging with Serializable {
// Must call this after all the variables have been created/initialized
if (!isLocal) {
- sendBroadcast
+ sendBroadcast()
}
def sendBroadcast() {
@@ -106,20 +109,22 @@ extends Broadcast[T] with Logging with Serializable {
listOfSources += masterSource
// Register with the Tracker
- MultiTracker.registerBroadcast(uuid,
+ MultiTracker.registerBroadcast(id,
SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes))
}
private def readObject(in: ObjectInputStream) {
in.defaultReadObject()
MultiTracker.synchronized {
- SparkEnv.get.blockManager.getSingle(uuid.toString) match {
- case Some(x) => x.asInstanceOf[T]
- case None => {
- logInfo("Started reading broadcast variable " + uuid)
+ SparkEnv.get.blockManager.getSingle(blockId) match {
+ case Some(x) =>
+ value_ = x.asInstanceOf[T]
+
+ case None =>
+ logInfo("Started reading broadcast variable " + id)
// Initializing everything because Master will only send null/0 values
// Only the 1st worker in a node can be here. Others will get from cache
- initializeWorkerVariables
+ initializeWorkerVariables()
logInfo("Local host address: " + hostAddress)
@@ -131,18 +136,17 @@ extends Broadcast[T] with Logging with Serializable {
val start = System.nanoTime
- val receptionSucceeded = receiveBroadcast(uuid)
+ val receptionSucceeded = receiveBroadcast(id)
if (receptionSucceeded) {
value_ = MultiTracker.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
SparkEnv.get.blockManager.putSingle(
- uuid.toString, value_, StorageLevel.MEMORY_ONLY_DESER, false)
+ blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
} else {
- logError("Reading Broadcasted variable " + uuid + " failed")
+ logError("Reading broadcast variable " + id + " failed")
}
val time = (System.nanoTime - start) / 1e9
- logInfo("Reading Broadcasted variable " + uuid + " took " + time + " s")
- }
+ logInfo("Reading broadcast variable " + id + " took " + time + " s")
}
}
}
@@ -254,8 +258,8 @@ extends Broadcast[T] with Logging with Serializable {
}
}
- def receiveBroadcast(variableUUID: UUID): Boolean = {
- val gInfo = MultiTracker.getGuideInfo(variableUUID)
+ def receiveBroadcast(variableID: Long): Boolean = {
+ val gInfo = MultiTracker.getGuideInfo(variableID)
if (gInfo.listenPort == SourceInfo.TxOverGoToDefault) {
return false
@@ -307,9 +311,11 @@ extends Broadcast[T] with Logging with Serializable {
var threadPool = Utils.newDaemonFixedThreadPool(MultiTracker.MaxChatSlots)
while (hasBlocks.get < totalBlocks) {
- var numThreadsToCreate =
- math.min(listOfSources.size, MultiTracker.MaxChatSlots) -
+ var numThreadsToCreate = 0
+ listOfSources.synchronized {
+ numThreadsToCreate = math.min(listOfSources.size, MultiTracker.MaxChatSlots) -
threadPool.getActiveCount
+ }
while (hasBlocks.get < totalBlocks && numThreadsToCreate > 0) {
var peerToTalkTo = pickPeerToTalkToRandom
@@ -722,7 +728,6 @@ extends Broadcast[T] with Logging with Serializable {
guidePortLock.synchronized { guidePortLock.notifyAll() }
try {
- // Don't stop until there is a copy in HDFS
while (!stopBroadcast) {
var clientSocket: Socket = null
try {
@@ -730,14 +735,17 @@ extends Broadcast[T] with Logging with Serializable {
clientSocket = serverSocket.accept()
} catch {
case e: Exception => {
- logError("GuideMultipleRequests Timeout.")
-
// Stop broadcast if at least one worker has connected and
// everyone connected so far are done. Comparing with
// listOfSources.size - 1, because it includes the Guide itself
- if (listOfSources.size > 1 &&
- setOfCompletedSources.size == listOfSources.size - 1) {
- stopBroadcast = true
+ listOfSources.synchronized {
+ setOfCompletedSources.synchronized {
+ if (listOfSources.size > 1 &&
+ setOfCompletedSources.size == listOfSources.size - 1) {
+ stopBroadcast = true
+ logInfo("GuideMultipleRequests Timeout. stopBroadcast == true.")
+ }
+ }
}
}
}
@@ -760,7 +768,7 @@ extends Broadcast[T] with Logging with Serializable {
logInfo("Sending stopBroadcast notifications...")
sendStopBroadcastNotifications
- MultiTracker.unregisterBroadcast(uuid)
+ MultiTracker.unregisterBroadcast(id)
} finally {
if (serverSocket != null) {
logInfo("GuideMultipleRequests now stopping...")
@@ -918,9 +926,7 @@ extends Broadcast[T] with Logging with Serializable {
serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout)
clientSocket = serverSocket.accept()
} catch {
- case e: Exception => {
- logError("ServeMultipleRequests Timeout.")
- }
+ case e: Exception => { }
}
if (clientSocket != null) {
logDebug("Serve: Accepted new client connection:" + clientSocket)
@@ -1023,9 +1029,12 @@ extends Broadcast[T] with Logging with Serializable {
}
}
-class BitTorrentBroadcastFactory
+private[spark] class BitTorrentBroadcastFactory
extends BroadcastFactory {
- def initialize(isMaster: Boolean) = MultiTracker.initialize(isMaster)
- def newBroadcast[T](value_ : T, isLocal: Boolean) = new BitTorrentBroadcast[T](value_, isLocal)
- def stop() = MultiTracker.stop
+ def initialize(isMaster: Boolean) { MultiTracker.initialize(isMaster) }
+
+ def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
+ new BitTorrentBroadcast[T](value_, isLocal, id)
+
+ def stop() { MultiTracker.stop() }
}
diff --git a/core/src/main/scala/spark/broadcast/Broadcast.scala b/core/src/main/scala/spark/broadcast/Broadcast.scala
index d68e56a114..6055bfd045 100644
--- a/core/src/main/scala/spark/broadcast/Broadcast.scala
+++ b/core/src/main/scala/spark/broadcast/Broadcast.scala
@@ -1,25 +1,20 @@
package spark.broadcast
import java.io._
-import java.net._
-import java.util.{BitSet, UUID}
-import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
-
-import scala.collection.mutable.Map
+import java.util.concurrent.atomic.AtomicLong
import spark._
-trait Broadcast[T] extends Serializable {
- val uuid = UUID.randomUUID
-
+abstract class Broadcast[T](id: Long) extends Serializable {
def value: T
// We cannot have an abstract readObject here due to some weird issues with
// readObject having to be 'private' in sub-classes.
- override def toString = "spark.Broadcast(" + uuid + ")"
+ override def toString = "spark.Broadcast(" + id + ")"
}
+private[spark]
class BroadcastManager(val isMaster_ : Boolean) extends Logging with Serializable {
private var initialized = false
@@ -49,14 +44,10 @@ class BroadcastManager(val isMaster_ : Boolean) extends Logging with Serializabl
broadcastFactory.stop()
}
- private def getBroadcastFactory: BroadcastFactory = {
- if (broadcastFactory == null) {
- throw new SparkException ("Broadcast.getBroadcastFactory called before initialize")
- }
- broadcastFactory
- }
+ private val nextBroadcastId = new AtomicLong(0)
- def newBroadcast[T](value_ : T, isLocal: Boolean) = broadcastFactory.newBroadcast[T](value_, isLocal)
+ def newBroadcast[T](value_ : T, isLocal: Boolean) =
+ broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
def isMaster = isMaster_
}
diff --git a/core/src/main/scala/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/spark/broadcast/BroadcastFactory.scala
index e341d556bf..ab6d302827 100644
--- a/core/src/main/scala/spark/broadcast/BroadcastFactory.scala
+++ b/core/src/main/scala/spark/broadcast/BroadcastFactory.scala
@@ -6,8 +6,8 @@ package spark.broadcast
* BroadcastFactory implementation to instantiate a particular broadcast for the
* entire Spark job.
*/
-trait BroadcastFactory {
+private[spark] trait BroadcastFactory {
def initialize(isMaster: Boolean): Unit
- def newBroadcast[T](value_ : T, isLocal: Boolean): Broadcast[T]
+ def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long): Broadcast[T]
def stop(): Unit
}
diff --git a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala
index 03986ea756..7eb4ddb74f 100644
--- a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala
+++ b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala
@@ -12,44 +12,47 @@ import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
import spark._
import spark.storage.StorageLevel
-class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean)
-extends Broadcast[T] with Logging with Serializable {
+private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
+extends Broadcast[T](id) with Logging with Serializable {
def value = value_
+ def blockId: String = "broadcast_" + id
+
HttpBroadcast.synchronized {
- SparkEnv.get.blockManager.putSingle(
- uuid.toString, value_, StorageLevel.MEMORY_ONLY_DESER, false)
+ SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
}
if (!isLocal) {
- HttpBroadcast.write(uuid, value_)
+ HttpBroadcast.write(id, value_)
}
// Called by JVM when deserializing an object
private def readObject(in: ObjectInputStream) {
in.defaultReadObject()
HttpBroadcast.synchronized {
- SparkEnv.get.blockManager.getSingle(uuid.toString) match {
+ SparkEnv.get.blockManager.getSingle(blockId) match {
case Some(x) => value_ = x.asInstanceOf[T]
case None => {
- logInfo("Started reading broadcast variable " + uuid)
+ logInfo("Started reading broadcast variable " + id)
val start = System.nanoTime
- value_ = HttpBroadcast.read[T](uuid)
- SparkEnv.get.blockManager.putSingle(
- uuid.toString, value_, StorageLevel.MEMORY_ONLY_DESER, false)
+ value_ = HttpBroadcast.read[T](id)
+ SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
val time = (System.nanoTime - start) / 1e9
- logInfo("Reading broadcast variable " + uuid + " took " + time + " s")
+ logInfo("Reading broadcast variable " + id + " took " + time + " s")
}
}
}
}
}
-class HttpBroadcastFactory extends BroadcastFactory {
- def initialize(isMaster: Boolean) = HttpBroadcast.initialize(isMaster)
- def newBroadcast[T](value_ : T, isLocal: Boolean) = new HttpBroadcast[T](value_, isLocal)
- def stop() = HttpBroadcast.stop()
+private[spark] class HttpBroadcastFactory extends BroadcastFactory {
+ def initialize(isMaster: Boolean) { HttpBroadcast.initialize(isMaster) }
+
+ def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
+ new HttpBroadcast[T](value_, isLocal, id)
+
+ def stop() { HttpBroadcast.stop() }
}
private object HttpBroadcast extends Logging {
@@ -65,7 +68,7 @@ private object HttpBroadcast extends Logging {
synchronized {
if (!initialized) {
bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
- compress = System.getProperty("spark.compress", "false").toBoolean
+ compress = System.getProperty("spark.broadcast.compress", "true").toBoolean
if (isMaster) {
createServer()
}
@@ -76,9 +79,12 @@ private object HttpBroadcast extends Logging {
}
def stop() {
- if (server != null) {
- server.stop()
- server = null
+ synchronized {
+ if (server != null) {
+ server.stop()
+ server = null
+ }
+ initialized = false
}
}
@@ -91,8 +97,8 @@ private object HttpBroadcast extends Logging {
logInfo("Broadcast server started at " + serverUri)
}
- def write(uuid: UUID, value: Any) {
- val file = new File(broadcastDir, "broadcast-" + uuid)
+ def write(id: Long, value: Any) {
+ val file = new File(broadcastDir, "broadcast-" + id)
val out: OutputStream = if (compress) {
new LZFOutputStream(new FileOutputStream(file)) // Does its own buffering
} else {
@@ -104,8 +110,8 @@ private object HttpBroadcast extends Logging {
serOut.close()
}
- def read[T](uuid: UUID): T = {
- val url = serverUri + "/broadcast-" + uuid
+ def read[T](id: Long): T = {
+ val url = serverUri + "/broadcast-" + id
var in = if (compress) {
new LZFInputStream(new URL(url).openStream()) // Does its own buffering
} else {
diff --git a/core/src/main/scala/spark/broadcast/MultiTracker.scala b/core/src/main/scala/spark/broadcast/MultiTracker.scala
index d5f5b22461..5e76dedb94 100644
--- a/core/src/main/scala/spark/broadcast/MultiTracker.scala
+++ b/core/src/main/scala/spark/broadcast/MultiTracker.scala
@@ -2,8 +2,7 @@ package spark.broadcast
import java.io._
import java.net._
-import java.util.{UUID, Random}
-import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
+import java.util.Random
import scala.collection.mutable.Map
@@ -18,7 +17,7 @@ extends Logging {
val FIND_BROADCAST_TRACKER = 2
// Map to keep track of guides of ongoing broadcasts
- var valueToGuideMap = Map[UUID, SourceInfo]()
+ var valueToGuideMap = Map[Long, SourceInfo]()
// Random number generator
var ranGen = new Random
@@ -154,44 +153,44 @@ extends Logging {
val messageType = ois.readObject.asInstanceOf[Int]
if (messageType == REGISTER_BROADCAST_TRACKER) {
- // Receive UUID
- val uuid = ois.readObject.asInstanceOf[UUID]
+ // Receive Long
+ val id = ois.readObject.asInstanceOf[Long]
// Receive hostAddress and listenPort
val gInfo = ois.readObject.asInstanceOf[SourceInfo]
// Add to the map
valueToGuideMap.synchronized {
- valueToGuideMap += (uuid -> gInfo)
+ valueToGuideMap += (id -> gInfo)
}
- logInfo ("New broadcast " + uuid + " registered with TrackMultipleValues. Ongoing ones: " + valueToGuideMap)
+ logInfo ("New broadcast " + id + " registered with TrackMultipleValues. Ongoing ones: " + valueToGuideMap)
// Send dummy ACK
oos.writeObject(-1)
oos.flush()
} else if (messageType == UNREGISTER_BROADCAST_TRACKER) {
- // Receive UUID
- val uuid = ois.readObject.asInstanceOf[UUID]
+ // Receive Long
+ val id = ois.readObject.asInstanceOf[Long]
// Remove from the map
valueToGuideMap.synchronized {
- valueToGuideMap(uuid) = SourceInfo("", SourceInfo.TxOverGoToDefault)
+ valueToGuideMap(id) = SourceInfo("", SourceInfo.TxOverGoToDefault)
}
- logInfo ("Broadcast " + uuid + " unregistered from TrackMultipleValues. Ongoing ones: " + valueToGuideMap)
+ logInfo ("Broadcast " + id + " unregistered from TrackMultipleValues. Ongoing ones: " + valueToGuideMap)
// Send dummy ACK
oos.writeObject(-1)
oos.flush()
} else if (messageType == FIND_BROADCAST_TRACKER) {
- // Receive UUID
- val uuid = ois.readObject.asInstanceOf[UUID]
+ // Receive Long
+ val id = ois.readObject.asInstanceOf[Long]
var gInfo =
- if (valueToGuideMap.contains(uuid)) valueToGuideMap(uuid)
+ if (valueToGuideMap.contains(id)) valueToGuideMap(id)
else SourceInfo("", SourceInfo.TxNotStartedRetry)
- logDebug("Got new request: " + clientSocket + " for " + uuid + " : " + gInfo.listenPort)
+ logDebug("Got new request: " + clientSocket + " for " + id + " : " + gInfo.listenPort)
// Send reply back
oos.writeObject(gInfo)
@@ -224,12 +223,12 @@ extends Logging {
}
}
- def getGuideInfo(variableUUID: UUID): SourceInfo = {
+ def getGuideInfo(variableLong: Long): SourceInfo = {
var clientSocketToTracker: Socket = null
var oosTracker: ObjectOutputStream = null
var oisTracker: ObjectInputStream = null
- var gInfo: SourceInfo = SourceInfo("", SourceInfo.TxOverGoToDefault)
+ var gInfo: SourceInfo = SourceInfo("", SourceInfo.TxNotStartedRetry)
var retriesLeft = MultiTracker.MaxRetryCount
do {
@@ -247,8 +246,8 @@ extends Logging {
oosTracker.writeObject(MultiTracker.FIND_BROADCAST_TRACKER)
oosTracker.flush()
- // Send UUID and receive GuideInfo
- oosTracker.writeObject(variableUUID)
+ // Send Long and receive GuideInfo
+ oosTracker.writeObject(variableLong)
oosTracker.flush()
gInfo = oisTracker.readObject.asInstanceOf[SourceInfo]
} catch {
@@ -276,7 +275,7 @@ extends Logging {
return gInfo
}
- def registerBroadcast(uuid: UUID, gInfo: SourceInfo) {
+ def registerBroadcast(id: Long, gInfo: SourceInfo) {
val socket = new Socket(MultiTracker.MasterHostAddress, MasterTrackerPort)
val oosST = new ObjectOutputStream(socket.getOutputStream)
oosST.flush()
@@ -286,8 +285,8 @@ extends Logging {
oosST.writeObject(REGISTER_BROADCAST_TRACKER)
oosST.flush()
- // Send UUID of this broadcast
- oosST.writeObject(uuid)
+ // Send Long of this broadcast
+ oosST.writeObject(id)
oosST.flush()
// Send this tracker's information
@@ -303,7 +302,7 @@ extends Logging {
socket.close()
}
- def unregisterBroadcast(uuid: UUID) {
+ def unregisterBroadcast(id: Long) {
val socket = new Socket(MultiTracker.MasterHostAddress, MasterTrackerPort)
val oosST = new ObjectOutputStream(socket.getOutputStream)
oosST.flush()
@@ -313,8 +312,8 @@ extends Logging {
oosST.writeObject(UNREGISTER_BROADCAST_TRACKER)
oosST.flush()
- // Send UUID of this broadcast
- oosST.writeObject(uuid)
+ // Send Long of this broadcast
+ oosST.writeObject(id)
oosST.flush()
// Receive ACK and throw it away
@@ -383,10 +382,10 @@ extends Logging {
}
}
-case class BroadcastBlock(blockID: Int, byteArray: Array[Byte])
+private[spark] case class BroadcastBlock(blockID: Int, byteArray: Array[Byte])
extends Serializable
-case class VariableInfo(@transient arrayOfBlocks : Array[BroadcastBlock],
+private[spark] case class VariableInfo(@transient arrayOfBlocks : Array[BroadcastBlock],
totalBlocks: Int,
totalBytes: Int)
extends Serializable {
diff --git a/core/src/main/scala/spark/broadcast/SourceInfo.scala b/core/src/main/scala/spark/broadcast/SourceInfo.scala
index f90385fd47..c79bb93c38 100644
--- a/core/src/main/scala/spark/broadcast/SourceInfo.scala
+++ b/core/src/main/scala/spark/broadcast/SourceInfo.scala
@@ -7,7 +7,7 @@ import spark._
/**
* Used to keep and pass around information of peers involved in a broadcast
*/
-case class SourceInfo (hostAddress: String,
+private[spark] case class SourceInfo (hostAddress: String,
listenPort: Int,
totalBlocks: Int = SourceInfo.UnusedParam,
totalBytes: Int = SourceInfo.UnusedParam)
@@ -26,10 +26,11 @@ extends Comparable[SourceInfo] with Logging {
/**
* Helper Object of SourceInfo for its constants
*/
-object SourceInfo {
- // Constants for special values of listenPort
+private[spark] object SourceInfo {
+ // Broadcast has not started yet! Should never happen.
val TxNotStartedRetry = -1
- val TxOverGoToDefault = 0
+ // Broadcast has already finished. Try default mechanism.
+ val TxOverGoToDefault = -3
// Other constants
val StopBroadcast = -2
val UnusedParam = 0
diff --git a/core/src/main/scala/spark/broadcast/TreeBroadcast.scala b/core/src/main/scala/spark/broadcast/TreeBroadcast.scala
index c9e1e67d87..fa676e9064 100644
--- a/core/src/main/scala/spark/broadcast/TreeBroadcast.scala
+++ b/core/src/main/scala/spark/broadcast/TreeBroadcast.scala
@@ -10,14 +10,15 @@ import scala.math
import spark._
import spark.storage.StorageLevel
-class TreeBroadcast[T](@transient var value_ : T, isLocal: Boolean)
-extends Broadcast[T] with Logging with Serializable {
+private[spark] class TreeBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
+extends Broadcast[T](id) with Logging with Serializable {
def value = value_
+ def blockId = "broadcast_" + id
+
MultiTracker.synchronized {
- SparkEnv.get.blockManager.putSingle(
- uuid.toString, value_, StorageLevel.MEMORY_ONLY_DESER, false)
+ SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
}
@transient var arrayOfBlocks: Array[BroadcastBlock] = null
@@ -35,7 +36,7 @@ extends Broadcast[T] with Logging with Serializable {
@transient var serveMR: ServeMultipleRequests = null
@transient var guideMR: GuideMultipleRequests = null
- @transient var hostAddress = Utils.localIpAddress
+ @transient var hostAddress = Utils.localIpAddress()
@transient var listenPort = -1
@transient var guidePort = -1
@@ -43,7 +44,7 @@ extends Broadcast[T] with Logging with Serializable {
// Must call this after all the variables have been created/initialized
if (!isLocal) {
- sendBroadcast
+ sendBroadcast()
}
def sendBroadcast() {
@@ -84,20 +85,22 @@ extends Broadcast[T] with Logging with Serializable {
listOfSources += masterSource
// Register with the Tracker
- MultiTracker.registerBroadcast(uuid,
+ MultiTracker.registerBroadcast(id,
SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes))
}
private def readObject(in: ObjectInputStream) {
in.defaultReadObject()
MultiTracker.synchronized {
- SparkEnv.get.blockManager.getSingle(uuid.toString) match {
- case Some(x) => x.asInstanceOf[T]
- case None => {
- logInfo("Started reading broadcast variable " + uuid)
+ SparkEnv.get.blockManager.getSingle(blockId) match {
+ case Some(x) =>
+ value_ = x.asInstanceOf[T]
+
+ case None =>
+ logInfo("Started reading broadcast variable " + id)
// Initializing everything because Master will only send null/0 values
// Only the 1st worker in a node can be here. Others will get from cache
- initializeWorkerVariables
+ initializeWorkerVariables()
logInfo("Local host address: " + hostAddress)
@@ -108,18 +111,17 @@ extends Broadcast[T] with Logging with Serializable {
val start = System.nanoTime
- val receptionSucceeded = receiveBroadcast(uuid)
+ val receptionSucceeded = receiveBroadcast(id)
if (receptionSucceeded) {
value_ = MultiTracker.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
SparkEnv.get.blockManager.putSingle(
- uuid.toString, value_, StorageLevel.MEMORY_ONLY_DESER, false)
+ blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
} else {
- logError("Reading Broadcasted variable " + uuid + " failed")
+ logError("Reading broadcast variable " + id + " failed")
}
val time = (System.nanoTime - start) / 1e9
- logInfo("Reading Broadcasted variable " + uuid + " took " + time + " s")
- }
+ logInfo("Reading broadcast variable " + id + " took " + time + " s")
}
}
}
@@ -136,14 +138,14 @@ extends Broadcast[T] with Logging with Serializable {
serveMR = null
- hostAddress = Utils.localIpAddress
+ hostAddress = Utils.localIpAddress()
listenPort = -1
stopBroadcast = false
}
- def receiveBroadcast(variableUUID: UUID): Boolean = {
- val gInfo = MultiTracker.getGuideInfo(variableUUID)
+ def receiveBroadcast(variableID: Long): Boolean = {
+ val gInfo = MultiTracker.getGuideInfo(variableID)
if (gInfo.listenPort == SourceInfo.TxOverGoToDefault) {
return false
@@ -290,15 +292,17 @@ extends Broadcast[T] with Logging with Serializable {
clientSocket = serverSocket.accept
} catch {
case e: Exception => {
- logError("GuideMultipleRequests Timeout.")
-
// Stop broadcast if at least one worker has connected and
- // everyone connected so far are done.
- // Comparing with listOfSources.size - 1, because the Guide itself
- // is included
- if (listOfSources.size > 1 &&
- setOfCompletedSources.size == listOfSources.size - 1) {
- stopBroadcast = true
+ // everyone connected so far is done. Comparing with
+ // listOfSources.size - 1, because it includes the Guide itself
+ listOfSources.synchronized {
+ setOfCompletedSources.synchronized {
+ if (listOfSources.size > 1 &&
+ setOfCompletedSources.size == listOfSources.size - 1) {
+ stopBroadcast = true
+ logInfo("GuideMultipleRequests Timeout. stopBroadcast == true.")
+ }
+ }
}
}
}
@@ -316,7 +320,7 @@ extends Broadcast[T] with Logging with Serializable {
logInfo("Sending stopBroadcast notifications...")
sendStopBroadcastNotifications
- MultiTracker.unregisterBroadcast(uuid)
+ MultiTracker.unregisterBroadcast(id)
} finally {
if (serverSocket != null) {
logInfo("GuideMultipleRequests now stopping...")
@@ -490,7 +494,7 @@ extends Broadcast[T] with Logging with Serializable {
serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout)
clientSocket = serverSocket.accept
} catch {
- case e: Exception => logError("ServeMultipleRequests Timeout.")
+ case e: Exception => { }
}
if (clientSocket != null) {
@@ -570,9 +574,12 @@ extends Broadcast[T] with Logging with Serializable {
}
}
-class TreeBroadcastFactory
+private[spark] class TreeBroadcastFactory
extends BroadcastFactory {
- def initialize(isMaster: Boolean) = MultiTracker.initialize(isMaster)
- def newBroadcast[T](value_ : T, isLocal: Boolean) = new TreeBroadcast[T](value_, isLocal)
- def stop() = MultiTracker.stop
+ def initialize(isMaster: Boolean) { MultiTracker.initialize(isMaster) }
+
+ def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
+ new TreeBroadcast[T](value_, isLocal, id)
+
+ def stop() { MultiTracker.stop() }
}
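
The readObject change above gives each broadcast a stable block name ("broadcast_" + id) and turns worker-side deserialization into a get-or-fetch against the block manager: reuse the locally cached value if it is there, otherwise receive the blocks, reassemble the value, and cache it under the same key so later tasks on the node skip the fetch. A hedged, self-contained sketch of that shape, with a plain map standing in for the real BlockManager:

import scala.collection.mutable

// Sketch only: a mutable map plays the role of the block manager cache.
class BroadcastCacheSketch(fetchFromPeers: Long => AnyRef) {
  private val store = new mutable.HashMap[String, AnyRef]

  private def blockId(id: Long): String = "broadcast_" + id

  def getOrFetch(id: Long): AnyRef = synchronized {
    store.get(blockId(id)) match {
      case Some(cached) => cached          // the first worker on this node already fetched it
      case None =>
        val value = fetchFromPeers(id)     // receive the blocks and unblockify them
        store(blockId(id)) = value         // cache locally so later tasks reuse it
        value
    }
  }
}
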
diff --git a/core/src/main/scala/spark/deploy/Command.scala b/core/src/main/scala/spark/deploy/Command.scala
index 344888919a..577101e3c3 100644
--- a/core/src/main/scala/spark/deploy/Command.scala
+++ b/core/src/main/scala/spark/deploy/Command.scala
@@ -2,7 +2,7 @@ package spark.deploy
import scala.collection.Map
-case class Command(
+private[spark] case class Command(
mainClass: String,
arguments: Seq[String],
environment: Map[String, String]) {
diff --git a/core/src/main/scala/spark/deploy/DeployMessage.scala b/core/src/main/scala/spark/deploy/DeployMessage.scala
index 141bbe4d57..d2b63d6e0d 100644
--- a/core/src/main/scala/spark/deploy/DeployMessage.scala
+++ b/core/src/main/scala/spark/deploy/DeployMessage.scala
@@ -7,13 +7,15 @@ import scala.collection.immutable.List
import scala.collection.mutable.HashMap
-sealed trait DeployMessage extends Serializable
+private[spark] sealed trait DeployMessage extends Serializable
// Worker to Master
+private[spark]
case class RegisterWorker(id: String, host: String, port: Int, cores: Int, memory: Int, webUiPort: Int)
extends DeployMessage
+private[spark]
case class ExecutorStateChanged(
jobId: String,
execId: Int,
@@ -23,11 +25,11 @@ case class ExecutorStateChanged(
// Master to Worker
-case class RegisteredWorker(masterWebUiUrl: String) extends DeployMessage
-case class RegisterWorkerFailed(message: String) extends DeployMessage
-case class KillExecutor(jobId: String, execId: Int) extends DeployMessage
+private[spark] case class RegisteredWorker(masterWebUiUrl: String) extends DeployMessage
+private[spark] case class RegisterWorkerFailed(message: String) extends DeployMessage
+private[spark] case class KillExecutor(jobId: String, execId: Int) extends DeployMessage
-case class LaunchExecutor(
+private[spark] case class LaunchExecutor(
jobId: String,
execId: Int,
jobDesc: JobDescription,
@@ -38,33 +40,42 @@ case class LaunchExecutor(
// Client to Master
-case class RegisterJob(jobDescription: JobDescription) extends DeployMessage
+private[spark] case class RegisterJob(jobDescription: JobDescription) extends DeployMessage
// Master to Client
+private[spark]
case class RegisteredJob(jobId: String) extends DeployMessage
+
+private[spark]
case class ExecutorAdded(id: Int, workerId: String, host: String, cores: Int, memory: Int)
+
+private[spark]
case class ExecutorUpdated(id: Int, state: ExecutorState, message: Option[String])
+
+private[spark]
case class JobKilled(message: String)
// Internal message in Client
-case object StopClient
+private[spark] case object StopClient
// MasterWebUI To Master
-case object RequestMasterState
+private[spark] case object RequestMasterState
// Master to MasterWebUI
+private[spark]
case class MasterState(uri : String, workers: List[WorkerInfo], activeJobs: List[JobInfo],
completedJobs: List[JobInfo])
// WorkerWebUI to Worker
-case object RequestWorkerState
+private[spark] case object RequestWorkerState
// Worker to WorkerWebUI
+private[spark]
case class WorkerState(uri: String, workerId: String, executors: List[ExecutorRunner],
finishedExecutors: List[ExecutorRunner], masterUrl: String, cores: Int, memory: Int,
coresUsed: Int, memoryUsed: Int, masterWebUiUrl: String) \ No newline at end of file
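
The deploy messages above form a sealed hierarchy, so each endpoint handles its side of the protocol with a single exhaustive pattern match in its actor. A self-contained miniature of that pattern (a trimmed-down, hypothetical message set, not the full protocol):

// Miniature of the DeployMessage style: sealed trait plus case classes.
sealed trait MiniDeployMessage extends Serializable
case class MiniRegisteredWorker(masterWebUiUrl: String) extends MiniDeployMessage
case class MiniRegisterWorkerFailed(message: String) extends MiniDeployMessage
case class MiniKillExecutor(jobId: String, execId: Int) extends MiniDeployMessage

object WorkerSideSketch {
  // Roughly what a worker-side handler for master-to-worker messages looks like.
  def handle(msg: MiniDeployMessage): String = msg match {
    case MiniRegisteredWorker(url)        => "registered; master web UI at " + url
    case MiniRegisterWorkerFailed(reason) => "registration failed: " + reason
    case MiniKillExecutor(jobId, execId)  => "killing executor " + jobId + "/" + execId
  }
}
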
diff --git a/core/src/main/scala/spark/deploy/ExecutorState.scala b/core/src/main/scala/spark/deploy/ExecutorState.scala
index d6ff1c54ca..5dc0c54552 100644
--- a/core/src/main/scala/spark/deploy/ExecutorState.scala
+++ b/core/src/main/scala/spark/deploy/ExecutorState.scala
@@ -1,6 +1,6 @@
package spark.deploy
-object ExecutorState
+private[spark] object ExecutorState
extends Enumeration("LAUNCHING", "LOADING", "RUNNING", "KILLED", "FAILED", "LOST") {
val LAUNCHING, LOADING, RUNNING, KILLED, FAILED, LOST = Value
diff --git a/core/src/main/scala/spark/deploy/JobDescription.scala b/core/src/main/scala/spark/deploy/JobDescription.scala
index 8ae77b1038..20879c5f11 100644
--- a/core/src/main/scala/spark/deploy/JobDescription.scala
+++ b/core/src/main/scala/spark/deploy/JobDescription.scala
@@ -1,6 +1,6 @@
package spark.deploy
-class JobDescription(
+private[spark] class JobDescription(
val name: String,
val cores: Int,
val memoryPerSlave: Int,
diff --git a/core/src/main/scala/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/spark/deploy/LocalSparkCluster.scala
new file mode 100644
index 0000000000..8b2a71add5
--- /dev/null
+++ b/core/src/main/scala/spark/deploy/LocalSparkCluster.scala
@@ -0,0 +1,58 @@
+package spark.deploy
+
+import akka.actor.{ActorRef, Props, Actor, ActorSystem, Terminated}
+
+import spark.deploy.worker.Worker
+import spark.deploy.master.Master
+import spark.util.AkkaUtils
+import spark.{Logging, Utils}
+
+import scala.collection.mutable.ArrayBuffer
+
+private[spark]
+class LocalSparkCluster(numSlaves: Int, coresPerSlave: Int, memoryPerSlave: Int) extends Logging {
+
+ val localIpAddress = Utils.localIpAddress
+
+ var masterActor : ActorRef = _
+ var masterActorSystem : ActorSystem = _
+ var masterPort : Int = _
+ var masterUrl : String = _
+
+ val slaveActorSystems = ArrayBuffer[ActorSystem]()
+ val slaveActors = ArrayBuffer[ActorRef]()
+
+ def start() : String = {
+ logInfo("Starting a local Spark cluster with " + numSlaves + " slaves.")
+
+ /* Start the Master */
+ val (actorSystem, masterPort) = AkkaUtils.createActorSystem("sparkMaster", localIpAddress, 0)
+ masterActorSystem = actorSystem
+ masterUrl = "spark://" + localIpAddress + ":" + masterPort
+ val actor = masterActorSystem.actorOf(
+ Props(new Master(localIpAddress, masterPort, 0)), name = "Master")
+ masterActor = actor
+
+ /* Start the Slaves */
+ for (slaveNum <- 1 to numSlaves) {
+ val (actorSystem, boundPort) =
+ AkkaUtils.createActorSystem("sparkWorker" + slaveNum, localIpAddress, 0)
+ slaveActorSystems += actorSystem
+ val actor = actorSystem.actorOf(
+ Props(new Worker(localIpAddress, boundPort, 0, coresPerSlave, memoryPerSlave, masterUrl)),
+ name = "Worker")
+ slaveActors += actor
+ }
+
+ return masterUrl
+ }
+
+ def stop() {
+ logInfo("Shutting down local Spark cluster.")
+ // Stop the slaves before the master so they don't get upset that it disconnected
+ slaveActorSystems.foreach(_.shutdown())
+ slaveActorSystems.foreach(_.awaitTermination())
+ masterActorSystem.shutdown()
+ masterActorSystem.awaitTermination()
+ }
+}
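
LocalSparkCluster gives tests an in-process standalone cluster: one master actor system plus numSlaves worker actor systems, all bound to the local IP, with start() returning the spark:// URL to point a client at and stop() tearing the workers down before the master. A hedged usage sketch; because the class is private[spark], the caller is placed in the same package here:

package spark.deploy

object LocalClusterSketch {
  def main(args: Array[String]) {
    // Two workers, each advertising 1 core and 512 MB.
    val cluster = new LocalSparkCluster(2, 1, 512)
    val masterUrl = cluster.start()   // something like spark://<local-ip>:<port>
    println("Local cluster is up at " + masterUrl)
    // ... point a SparkContext or a test client at masterUrl here ...
    cluster.stop()                    // slaves first, then the master
  }
}
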
diff --git a/core/src/main/scala/spark/deploy/client/Client.scala b/core/src/main/scala/spark/deploy/client/Client.scala
index c7fa8a3874..e51b0c5c15 100644
--- a/core/src/main/scala/spark/deploy/client/Client.scala
+++ b/core/src/main/scala/spark/deploy/client/Client.scala
@@ -4,6 +4,7 @@ import spark.deploy._
import akka.actor._
import akka.pattern.ask
import akka.util.duration._
+import akka.pattern.AskTimeoutException
import spark.{SparkException, Logging}
import akka.remote.RemoteClientLifeCycleEvent
import akka.remote.RemoteClientShutdown
@@ -16,7 +17,7 @@ import akka.dispatch.Await
* The main class used to talk to a Spark deploy cluster. Takes a master URL, a job description,
* and a listener for job events, and calls back the listener when various events occur.
*/
-class Client(
+private[spark] class Client(
actorSystem: ActorSystem,
masterUrl: String,
jobDescription: JobDescription,
@@ -42,7 +43,6 @@ class Client(
val akkaUrl = "akka://spark@%s:%s/user/Master".format(masterHost, masterPort)
try {
master = context.actorFor(akkaUrl)
- //master ! RegisterWorker(ip, port, cores, memory)
master ! RegisterJob(jobDescription)
context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
context.watch(master) // Doesn't work with remote actors, but useful for testing
@@ -101,9 +101,13 @@ class Client(
def stop() {
if (actor != null) {
- val timeout = 1.seconds
- val future = actor.ask(StopClient)(timeout)
- Await.result(future, timeout)
+ try {
+ val timeout = 1.seconds
+ val future = actor.ask(StopClient)(timeout)
+ Await.result(future, timeout)
+ } catch {
+ case e: AskTimeoutException => // Ignore it, maybe master went away
+ }
actor = null
}
}
diff --git a/core/src/main/scala/spark/deploy/client/ClientListener.scala b/core/src/main/scala/spark/deploy/client/ClientListener.scala
index 7d23baff32..a8fa982085 100644
--- a/core/src/main/scala/spark/deploy/client/ClientListener.scala
+++ b/core/src/main/scala/spark/deploy/client/ClientListener.scala
@@ -7,7 +7,7 @@ package spark.deploy.client
*
* Users of this API should *not* block inside the callback methods.
*/
-trait ClientListener {
+private[spark] trait ClientListener {
def connected(jobId: String): Unit
def disconnected(): Unit
diff --git a/core/src/main/scala/spark/deploy/client/TestClient.scala b/core/src/main/scala/spark/deploy/client/TestClient.scala
index df9a36c7fe..bf0e7428ba 100644
--- a/core/src/main/scala/spark/deploy/client/TestClient.scala
+++ b/core/src/main/scala/spark/deploy/client/TestClient.scala
@@ -4,7 +4,7 @@ import spark.util.AkkaUtils
import spark.{Logging, Utils}
import spark.deploy.{Command, JobDescription}
-object TestClient {
+private[spark] object TestClient {
class TestListener extends ClientListener with Logging {
def connected(id: String) {
diff --git a/core/src/main/scala/spark/deploy/client/TestExecutor.scala b/core/src/main/scala/spark/deploy/client/TestExecutor.scala
index 2e40e10d18..0e46db2272 100644
--- a/core/src/main/scala/spark/deploy/client/TestExecutor.scala
+++ b/core/src/main/scala/spark/deploy/client/TestExecutor.scala
@@ -1,6 +1,6 @@
package spark.deploy.client
-object TestExecutor {
+private[spark] object TestExecutor {
def main(args: Array[String]) {
println("Hello world!")
while (true) {
diff --git a/core/src/main/scala/spark/deploy/master/ExecutorInfo.scala b/core/src/main/scala/spark/deploy/master/ExecutorInfo.scala
index 335e00958c..1db2c32633 100644
--- a/core/src/main/scala/spark/deploy/master/ExecutorInfo.scala
+++ b/core/src/main/scala/spark/deploy/master/ExecutorInfo.scala
@@ -2,7 +2,7 @@ package spark.deploy.master
import spark.deploy.ExecutorState
-class ExecutorInfo(
+private[spark] class ExecutorInfo(
val id: Int,
val job: JobInfo,
val worker: WorkerInfo,
diff --git a/core/src/main/scala/spark/deploy/master/JobInfo.scala b/core/src/main/scala/spark/deploy/master/JobInfo.scala
index 31d48b82b9..8795c09cc1 100644
--- a/core/src/main/scala/spark/deploy/master/JobInfo.scala
+++ b/core/src/main/scala/spark/deploy/master/JobInfo.scala
@@ -5,6 +5,7 @@ import java.util.Date
import akka.actor.ActorRef
import scala.collection.mutable
+private[spark]
class JobInfo(val id: String, val desc: JobDescription, val submitDate: Date, val actor: ActorRef) {
var state = JobState.WAITING
var executors = new mutable.HashMap[Int, ExecutorInfo]
@@ -31,4 +32,13 @@ class JobInfo(val id: String, val desc: JobDescription, val submitDate: Date, va
}
def coresLeft: Int = desc.cores - coresGranted
+
+ private var _retryCount = 0
+
+ def retryCount = _retryCount
+
+ def incrementRetryCount = {
+ _retryCount += 1
+ _retryCount
+ }
}
diff --git a/core/src/main/scala/spark/deploy/master/JobState.scala b/core/src/main/scala/spark/deploy/master/JobState.scala
index 50b0c6f95b..2b70cf0191 100644
--- a/core/src/main/scala/spark/deploy/master/JobState.scala
+++ b/core/src/main/scala/spark/deploy/master/JobState.scala
@@ -1,7 +1,9 @@
package spark.deploy.master
-object JobState extends Enumeration("WAITING", "RUNNING", "FINISHED", "FAILED") {
+private[spark] object JobState extends Enumeration("WAITING", "RUNNING", "FINISHED", "FAILED") {
type JobState = Value
val WAITING, RUNNING, FINISHED, FAILED = Value
+
+ val MAX_NUM_RETRY = 10
}
diff --git a/core/src/main/scala/spark/deploy/master/Master.scala b/core/src/main/scala/spark/deploy/master/Master.scala
index c98dddea7b..6010f7cff2 100644
--- a/core/src/main/scala/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/spark/deploy/master/Master.scala
@@ -1,21 +1,20 @@
package spark.deploy.master
-import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
-
import akka.actor._
-import spark.{Logging, Utils}
-import spark.util.AkkaUtils
+import akka.actor.Terminated
+import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientDisconnected, RemoteClientShutdown}
+
import java.text.SimpleDateFormat
import java.util.Date
-import akka.remote.RemoteClientLifeCycleEvent
+
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
+
import spark.deploy._
-import akka.remote.RemoteClientShutdown
-import akka.remote.RemoteClientDisconnected
-import spark.deploy.RegisterWorker
-import spark.deploy.RegisterWorkerFailed
-import akka.actor.Terminated
+import spark.{Logging, SparkException, Utils}
+import spark.util.AkkaUtils
-class Master(ip: String, port: Int, webUiPort: Int) extends Actor with Logging {
+
+private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor with Logging {
val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For job IDs
var nextJobNumber = 0
@@ -81,12 +80,22 @@ class Master(ip: String, port: Int, webUiPort: Int) extends Actor with Logging {
exec.state = state
exec.job.actor ! ExecutorUpdated(execId, state, message)
if (ExecutorState.isFinished(state)) {
+ val jobInfo = idToJob(jobId)
// Remove this executor from the worker and job
logInfo("Removing executor " + exec.fullId + " because it is " + state)
- idToJob(jobId).removeExecutor(exec)
+ jobInfo.removeExecutor(exec)
exec.worker.removeExecutor(exec)
- // TODO: the worker would probably want to restart the executor a few times
- schedule()
+
+ // Only retry certain number of times so we don't go into an infinite loop.
+ if (jobInfo.incrementRetryCount <= JobState.MAX_NUM_RETRY) {
+ schedule()
+ } else {
+ val e = new SparkException("Job %s with ID %s failed %d times.".format(
+ jobInfo.desc.name, jobInfo.id, jobInfo.retryCount))
+ logError(e.getMessage, e)
+ throw e
+ //System.exit(1)
+ }
}
}
case None =>
@@ -112,7 +121,7 @@ class Master(ip: String, port: Int, webUiPort: Int) extends Actor with Logging {
addressToWorker.get(address).foreach(removeWorker)
addressToJob.get(address).foreach(removeJob)
}
-
+
case RequestMasterState => {
sender ! MasterState(ip + ":" + port, workers.toList, jobs.toList, completedJobs.toList)
}
@@ -203,7 +212,7 @@ class Master(ip: String, port: Int, webUiPort: Int) extends Actor with Logging {
}
}
-object Master {
+private[spark] object Master {
def main(argStrings: Array[String]) {
val args = new MasterArguments(argStrings)
val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", args.ip, args.port)
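
The scheduling change above caps how often a failing job gets rescheduled: every executor that finishes abnormally bumps the job's retry counter, and once it passes JobState.MAX_NUM_RETRY (10) the master raises an error instead of looping forever. A small standalone sketch of that guard (the names are illustrative, not the Master's actual fields):

// Sketch of the bounded-retry guard applied when an executor finishes abnormally.
class RetryGuardSketch(maxRetries: Int) {
  private var retryCount = 0

  def incrementRetryCount(): Int = {
    retryCount += 1
    retryCount
  }

  // True if the job should be rescheduled, false once it has failed for good.
  def shouldReschedule(): Boolean = incrementRetryCount() <= maxRetries
}

object RetryGuardSketchExample {
  def main(args: Array[String]) {
    val guard = new RetryGuardSketch(3)
    // Four consecutive failures: the first three reschedule, the fourth gives up.
    for (i <- 1 to 4) {
      println("failure " + i + " -> reschedule? " + guard.shouldReschedule())
    }
  }
}
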
diff --git a/core/src/main/scala/spark/deploy/master/MasterArguments.scala b/core/src/main/scala/spark/deploy/master/MasterArguments.scala
index 0f7a92bdd0..1b1c3dd0ad 100644
--- a/core/src/main/scala/spark/deploy/master/MasterArguments.scala
+++ b/core/src/main/scala/spark/deploy/master/MasterArguments.scala
@@ -6,7 +6,7 @@ import spark.Utils
/**
* Command-line parser for the master.
*/
-class MasterArguments(args: Array[String]) {
+private[spark] class MasterArguments(args: Array[String]) {
var ip = Utils.localIpAddress()
var port = 7077
var webUiPort = 8080
@@ -51,7 +51,7 @@ class MasterArguments(args: Array[String]) {
*/
def printUsageAndExit(exitCode: Int) {
System.err.println(
- "Usage: spark-master [options]\n" +
+ "Usage: Master [options]\n" +
"\n" +
"Options:\n" +
" -i IP, --ip IP IP address or DNS name to listen on\n" +
diff --git a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala
index f03c0a0229..700a41c770 100644
--- a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala
+++ b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala
@@ -10,6 +10,7 @@ import cc.spray.directives._
import cc.spray.typeconversion.TwirlSupport._
import spark.deploy._
+private[spark]
class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Directives {
val RESOURCE_DIR = "spark/deploy/master/webui"
val STATIC_RESOURCE_DIR = "spark/deploy/static"
@@ -22,7 +23,7 @@ class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Direct
completeWith {
val future = master ? RequestMasterState
future.map {
- masterState => masterui.html.index.render(masterState.asInstanceOf[MasterState])
+ masterState => spark.deploy.master.html.index.render(masterState.asInstanceOf[MasterState])
}
}
} ~
@@ -36,7 +37,7 @@ class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Direct
// A bit ugly and inefficient, but we won't have a number of jobs
// so large that it will make a significant difference.
(masterState.activeJobs ::: masterState.completedJobs).find(_.id == jobId) match {
- case Some(job) => masterui.html.job_details.render(job)
+ case Some(job) => spark.deploy.master.html.job_details.render(job)
case _ => null
}
}
diff --git a/core/src/main/scala/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/spark/deploy/master/WorkerInfo.scala
index 59474a0945..16b3f9b653 100644
--- a/core/src/main/scala/spark/deploy/master/WorkerInfo.scala
+++ b/core/src/main/scala/spark/deploy/master/WorkerInfo.scala
@@ -3,7 +3,7 @@ package spark.deploy.master
import akka.actor.ActorRef
import scala.collection.mutable
-class WorkerInfo(
+private[spark] class WorkerInfo(
val id: String,
val host: String,
val port: Int,
diff --git a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
index 3e24380810..07ae7bca78 100644
--- a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
+++ b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
@@ -13,7 +13,7 @@ import spark.deploy.ExecutorStateChanged
/**
* Manages the execution of one executor process.
*/
-class ExecutorRunner(
+private[spark] class ExecutorRunner(
val jobId: String,
val execId: Int,
val jobDesc: JobDescription,
@@ -29,12 +29,25 @@ class ExecutorRunner(
val fullId = jobId + "/" + execId
var workerThread: Thread = null
var process: Process = null
+ var shutdownHook: Thread = null
def start() {
workerThread = new Thread("ExecutorRunner for " + fullId) {
override def run() { fetchAndRunExecutor() }
}
workerThread.start()
+
+ // Shutdown hook that kills actors on shutdown.
+ shutdownHook = new Thread() {
+ override def run() {
+ if (process != null) {
+ logInfo("Shutdown hook killing child process.")
+ process.destroy()
+ process.waitFor()
+ }
+ }
+ }
+ Runtime.getRuntime.addShutdownHook(shutdownHook)
}
/** Stop this executor runner, including killing the process it launched */
@@ -45,40 +58,10 @@ class ExecutorRunner(
if (process != null) {
logInfo("Killing process!")
process.destroy()
+ process.waitFor()
}
worker ! ExecutorStateChanged(jobId, execId, ExecutorState.KILLED, None)
- }
- }
-
- /**
- * Download a file requested by the executor. Supports fetching the file in a variety of ways,
- * including HTTP, HDFS and files on a standard filesystem, based on the URL parameter.
- */
- def fetchFile(url: String, targetDir: File) {
- val filename = url.split("/").last
- val targetFile = new File(targetDir, filename)
- if (url.startsWith("http://") || url.startsWith("https://") || url.startsWith("ftp://")) {
- // Use the java.net library to fetch it
- logInfo("Fetching " + url + " to " + targetFile)
- val in = new URL(url).openStream()
- val out = new FileOutputStream(targetFile)
- Utils.copyStream(in, out, true)
- } else {
- // Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others
- val uri = new URI(url)
- val conf = new Configuration()
- val fs = FileSystem.get(uri, conf)
- val in = fs.open(new Path(uri))
- val out = new FileOutputStream(targetFile)
- Utils.copyStream(in, out, true)
- }
- // Decompress the file if it's a .tar or .tar.gz
- if (filename.endsWith(".tar.gz") || filename.endsWith(".tgz")) {
- logInfo("Untarring " + filename)
- Utils.execute(Seq("tar", "-xzf", filename), targetDir)
- } else if (filename.endsWith(".tar")) {
- logInfo("Untarring " + filename)
- Utils.execute(Seq("tar", "-xf", filename), targetDir)
+ Runtime.getRuntime.removeShutdownHook(shutdownHook)
}
}
@@ -92,7 +75,8 @@ class ExecutorRunner(
def buildCommandSeq(): Seq[String] = {
val command = jobDesc.command
- val runScript = new File(sparkHome, "run").getCanonicalPath
+ val script = if (System.getProperty("os.name").startsWith("Windows")) "run.cmd" else "run";
+ val runScript = new File(sparkHome, script).getCanonicalPath
Seq(runScript, command.mainClass) ++ command.arguments.map(substituteVariables)
}
@@ -101,7 +85,12 @@ class ExecutorRunner(
val out = new FileOutputStream(file)
new Thread("redirect output to " + file) {
override def run() {
- Utils.copyStream(in, out, true)
+ try {
+ Utils.copyStream(in, out, true)
+ } catch {
+ case e: IOException =>
+ logInfo("Redirection to " + file + " closed: " + e.getMessage)
+ }
}
}.start()
}
@@ -131,6 +120,9 @@ class ExecutorRunner(
}
env.put("SPARK_CORES", cores.toString)
env.put("SPARK_MEMORY", memory.toString)
+ // In case we are running this from within the Spark Shell, avoid creating a "scala"
+ // parent process for the executor command
+ env.put("SPARK_LAUNCH_WITH_SCALA", "0")
process = builder.start()
// Redirect its stdout and stderr to files
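
Two of the ExecutorRunner changes above are about not leaking child processes: stop() now waits for the process after destroying it, and a JVM shutdown hook kills the child if the worker dies first. A generic, hedged sketch of that shutdown-hook pattern, with a throwaway Unix sleep command standing in for the executor process:

object ShutdownHookSketch {
  def main(args: Array[String]) {
    // Launch a long-running child process (stand-in for the real executor command).
    val process = new ProcessBuilder("sleep", "600").start()

    // If this JVM exits for any reason, make sure the child does not outlive it.
    val hook = new Thread() {
      override def run() {
        process.destroy()
        process.waitFor()   // reap it so it does not linger
      }
    }
    Runtime.getRuntime.addShutdownHook(hook)

    // Normal teardown: kill the child ourselves and remove the now-unneeded hook.
    process.destroy()
    process.waitFor()
    Runtime.getRuntime.removeShutdownHook(hook)
  }
}
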
diff --git a/core/src/main/scala/spark/deploy/worker/Worker.scala b/core/src/main/scala/spark/deploy/worker/Worker.scala
index 0a80463c0b..474c9364fd 100644
--- a/core/src/main/scala/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/spark/deploy/worker/Worker.scala
@@ -16,7 +16,14 @@ import spark.deploy.RegisterWorkerFailed
import akka.actor.Terminated
import java.io.File
-class Worker(ip: String, port: Int, webUiPort: Int, cores: Int, memory: Int, masterUrl: String)
+private[spark] class Worker(
+ ip: String,
+ port: Int,
+ webUiPort: Int,
+ cores: Int,
+ memory: Int,
+ masterUrl: String,
+ workDirPath: String = null)
extends Actor with Logging {
val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For worker and executor IDs
@@ -37,7 +44,11 @@ class Worker(ip: String, port: Int, webUiPort: Int, cores: Int, memory: Int, mas
def memoryFree: Int = memory - memoryUsed
def createWorkDir() {
- workDir = new File(sparkHome, "work")
+ workDir = if (workDirPath != null) {
+ new File(workDirPath)
+ } else {
+ new File(sparkHome, "work")
+ }
try {
if (!workDir.exists() && !workDir.mkdirs()) {
logError("Failed to create work directory " + workDir)
@@ -153,14 +164,19 @@ class Worker(ip: String, port: Int, webUiPort: Int, cores: Int, memory: Int, mas
def generateWorkerId(): String = {
"worker-%s-%s-%d".format(DATE_FORMAT.format(new Date), ip, port)
}
+
+ override def postStop() {
+ executors.values.foreach(_.kill())
+ }
}
-object Worker {
+private[spark] object Worker {
def main(argStrings: Array[String]) {
val args = new WorkerArguments(argStrings)
val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", args.ip, args.port)
val actor = actorSystem.actorOf(
- Props(new Worker(args.ip, boundPort, args.webUiPort, args.cores, args.memory, args.master)),
+ Props(new Worker(args.ip, boundPort, args.webUiPort, args.cores, args.memory,
+ args.master, args.workDir)),
name = "Worker")
actorSystem.awaitTermination()
}
diff --git a/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala
index 1efe8304ea..60dc107a4c 100644
--- a/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala
+++ b/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala
@@ -8,13 +8,14 @@ import java.lang.management.ManagementFactory
/**
* Command-line parser for the worker.
*/
-class WorkerArguments(args: Array[String]) {
+private[spark] class WorkerArguments(args: Array[String]) {
var ip = Utils.localIpAddress()
var port = 0
var webUiPort = 8081
var cores = inferDefaultCores()
var memory = inferDefaultMemory()
var master: String = null
+ var workDir: String = null
// Check for settings in environment variables
if (System.getenv("SPARK_WORKER_PORT") != null) {
@@ -29,6 +30,9 @@ class WorkerArguments(args: Array[String]) {
if (System.getenv("SPARK_WORKER_WEBUI_PORT") != null) {
webUiPort = System.getenv("SPARK_WORKER_WEBUI_PORT").toInt
}
+ if (System.getenv("SPARK_WORKER_DIR") != null) {
+ workDir = System.getenv("SPARK_WORKER_DIR")
+ }
parse(args.toList)
@@ -49,6 +53,10 @@ class WorkerArguments(args: Array[String]) {
memory = value
parse(tail)
+ case ("--work-dir" | "-d") :: value :: tail =>
+ workDir = value
+ parse(tail)
+
case "--webui-port" :: IntParam(value) :: tail =>
webUiPort = value
parse(tail)
@@ -77,13 +85,14 @@ class WorkerArguments(args: Array[String]) {
*/
def printUsageAndExit(exitCode: Int) {
System.err.println(
- "Usage: spark-worker [options] <master>\n" +
+ "Usage: Worker [options] <master>\n" +
"\n" +
"Master must be a URL of the form spark://hostname:port\n" +
"\n" +
"Options:\n" +
" -c CORES, --cores CORES Number of cores to use\n" +
" -m MEM, --memory MEM Amount of memory to use (e.g. 1000M, 2G)\n" +
+ " -d DIR, --work-dir DIR Directory to run jobs in (default: SPARK_HOME/work)\n" +
" -i IP, --ip IP IP address or DNS name to listen on\n" +
" -p PORT, --port PORT Port to listen on (default: random)\n" +
" --webui-port PORT Port for web UI (default: 8081)")
diff --git a/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala b/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala
index 58a05e1a38..d06f4884ee 100644
--- a/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala
+++ b/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala
@@ -9,6 +9,7 @@ import cc.spray.Directives
import cc.spray.typeconversion.TwirlSupport._
import spark.deploy.{WorkerState, RequestWorkerState}
+private[spark]
class WorkerWebUI(val actorSystem: ActorSystem, worker: ActorRef) extends Directives {
val RESOURCE_DIR = "spark/deploy/worker/webui"
val STATIC_RESOURCE_DIR = "spark/deploy/static"
@@ -21,7 +22,7 @@ class WorkerWebUI(val actorSystem: ActorSystem, worker: ActorRef) extends Direct
completeWith{
val future = worker ? RequestWorkerState
future.map { workerState =>
- workerui.html.index(workerState.asInstanceOf[WorkerState])
+ spark.deploy.worker.html.index(workerState.asInstanceOf[WorkerState])
}
}
} ~
diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala
index dba209ac27..dfdb22024e 100644
--- a/core/src/main/scala/spark/executor/Executor.scala
+++ b/core/src/main/scala/spark/executor/Executor.scala
@@ -1,10 +1,12 @@
package spark.executor
import java.io.{File, FileOutputStream}
-import java.net.{URL, URLClassLoader}
+import java.net.{URI, URL, URLClassLoader}
import java.util.concurrent._
-import scala.collection.mutable.ArrayBuffer
+import org.apache.hadoop.fs.FileUtil
+
+import scala.collection.mutable.{ArrayBuffer, Map, HashMap}
import spark.broadcast._
import spark.scheduler._
@@ -14,11 +16,16 @@ import java.nio.ByteBuffer
/**
* The Mesos executor for Spark.
*/
-class Executor extends Logging {
- var classLoader: ClassLoader = null
+private[spark] class Executor extends Logging {
+ var urlClassLoader : ExecutorURLClassLoader = null
var threadPool: ExecutorService = null
var env: SparkEnv = null
+ // Application dependencies (added through SparkContext) that we've fetched so far on this node.
+ // Each map holds the master's timestamp for the version of that file or JAR we got.
+ val currentFiles: HashMap[String, Long] = new HashMap[String, Long]()
+ val currentJars: HashMap[String, Long] = new HashMap[String, Long]()
+
val EMPTY_BYTE_BUFFER = ByteBuffer.wrap(new Array[Byte](0))
initLogging()
@@ -32,14 +39,14 @@ class Executor extends Logging {
System.setProperty(key, value)
}
+ // Create our ClassLoader and set it on this thread
+ urlClassLoader = createClassLoader()
+ Thread.currentThread.setContextClassLoader(urlClassLoader)
+
// Initialize Spark environment (using system properties read above)
env = SparkEnv.createFromSystemProperties(slaveHostname, 0, false, false)
SparkEnv.set(env)
- // Create our ClassLoader (using spark properties) and set it on this thread
- classLoader = createClassLoader()
- Thread.currentThread.setContextClassLoader(classLoader)
-
// Start worker thread pool
threadPool = new ThreadPoolExecutor(
1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable])
@@ -54,15 +61,16 @@ class Executor extends Logging {
override def run() {
SparkEnv.set(env)
- Thread.currentThread.setContextClassLoader(classLoader)
+ Thread.currentThread.setContextClassLoader(urlClassLoader)
val ser = SparkEnv.get.closureSerializer.newInstance()
logInfo("Running task ID " + taskId)
context.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
try {
SparkEnv.set(env)
- Thread.currentThread.setContextClassLoader(classLoader)
Accumulators.clear()
- val task = ser.deserialize[Task[Any]](serializedTask, classLoader)
+ val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
+ updateDependencies(taskFiles, taskJars)
+ val task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
logInfo("Its generation is " + task.generation)
env.mapOutputTracker.updateGeneration(task.generation)
val value = task.run(taskId.toInt)
@@ -96,25 +104,15 @@ class Executor extends Logging {
* Create a ClassLoader for use in tasks, adding any JARs specified by the user or any classes
* created by the interpreter to the search path
*/
- private def createClassLoader(): ClassLoader = {
+ private def createClassLoader(): ExecutorURLClassLoader = {
var loader = this.getClass.getClassLoader
- // If any JAR URIs are given through spark.jar.uris, fetch them to the
- // current directory and put them all on the classpath. We assume that
- // each URL has a unique file name so that no local filenames will clash
- // in this process. This is guaranteed by ClusterScheduler.
- val uris = System.getProperty("spark.jar.uris", "")
- val localFiles = ArrayBuffer[String]()
- for (uri <- uris.split(",").filter(_.size > 0)) {
- val url = new URL(uri)
- val filename = url.getPath.split("/").last
- downloadFile(url, filename)
- localFiles += filename
- }
- if (localFiles.size > 0) {
- val urls = localFiles.map(f => new File(f).toURI.toURL).toArray
- loader = new URLClassLoader(urls, loader)
- }
+ // For each of the jars in currentJars, add them to the class loader.
+ // We assume each of the files has already been fetched.
+ val urls = currentJars.keySet.map { uri =>
+ new File(uri.split("/").last).toURI.toURL
+ }.toArray
+ loader = new URLClassLoader(urls, loader)
// If the REPL is in use, add another ClassLoader that will read
// new classes defined by the REPL as the user types code
@@ -133,13 +131,31 @@ class Executor extends Logging {
}
}
- return loader
+ return new ExecutorURLClassLoader(Array(), loader)
}
- // Download a file from a given URL to the local filesystem
- private def downloadFile(url: URL, localPath: String) {
- val in = url.openStream()
- val out = new FileOutputStream(localPath)
- Utils.copyStream(in, out, true)
+ /**
+ * Download any missing dependencies if we receive a new set of files and JARs from the
+ * SparkContext. Also adds any new JARs we fetched to the class loader.
+ */
+ private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) {
+ // Fetch missing dependencies
+ for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
+ logInfo("Fetching " + name + " with timestamp " + timestamp)
+ Utils.fetchFile(name, new File("."))
+ currentFiles(name) = timestamp
+ }
+ for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
+ logInfo("Fetching " + name + " with timestamp " + timestamp)
+ Utils.fetchFile(name, new File("."))
+ currentJars(name) = timestamp
+ // Add it to our class loader
+ val localName = name.split("/").last
+ val url = new File(".", localName).toURI.toURL
+ if (!urlClassLoader.getURLs.contains(url)) {
+ logInfo("Adding " + url + " to class loader")
+ urlClassLoader.addURL(url)
+ }
+ }
}
}
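
The dependency bookkeeping above is a timestamp cache: the driver ships (name, timestamp) pairs with each task, and the executor refetches a file or JAR only when the incoming timestamp is newer than the one it already has. A stripped-down sketch of that check, with the actual fetch abstracted behind a function:

import scala.collection.mutable.HashMap

// Sketch: decide which shipped dependencies actually need refetching.
class DependencyCacheSketch(fetch: String => Unit) {
  private val fetchedAt = new HashMap[String, Long]

  def update(incoming: Map[String, Long]) {
    for ((name, timestamp) <- incoming if fetchedAt.getOrElse(name, -1L) < timestamp) {
      fetch(name)                   // e.g. download into the working directory
      fetchedAt(name) = timestamp   // remember which version is now local
    }
  }
}

object DependencyCacheSketchExample {
  def main(args: Array[String]) {
    val cache = new DependencyCacheSketch(name => println("fetching " + name))
    cache.update(Map("http://host/app.jar" -> 1L))  // fetched
    cache.update(Map("http://host/app.jar" -> 1L))  // same timestamp: skipped
    cache.update(Map("http://host/app.jar" -> 2L))  // newer: fetched again
  }
}
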
diff --git a/core/src/main/scala/spark/executor/ExecutorBackend.scala b/core/src/main/scala/spark/executor/ExecutorBackend.scala
index 24c8776f31..e97e509700 100644
--- a/core/src/main/scala/spark/executor/ExecutorBackend.scala
+++ b/core/src/main/scala/spark/executor/ExecutorBackend.scala
@@ -6,6 +6,6 @@ import spark.TaskState.TaskState
/**
* A pluggable interface used by the Executor to send updates to the cluster scheduler.
*/
-trait ExecutorBackend {
+private[spark] trait ExecutorBackend {
def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer)
}
diff --git a/core/src/main/scala/spark/executor/ExecutorURLClassLoader.scala b/core/src/main/scala/spark/executor/ExecutorURLClassLoader.scala
new file mode 100644
index 0000000000..5beb4d049e
--- /dev/null
+++ b/core/src/main/scala/spark/executor/ExecutorURLClassLoader.scala
@@ -0,0 +1,14 @@
+package spark.executor
+
+import java.net.{URLClassLoader, URL}
+
+/**
+ * The addURL method in URLClassLoader is protected. We subclass it to make this accessible.
+ */
+private[spark] class ExecutorURLClassLoader(urls: Array[URL], parent: ClassLoader)
+ extends URLClassLoader(urls, parent) {
+
+ override def addURL(url: URL) {
+ super.addURL(url)
+ }
+}
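
URLClassLoader.addURL is protected, and this subclass merely widens its visibility; that is what lets the executor bolt newly fetched JARs onto an already-running class loader. A hedged usage sketch in which the JAR path and class name are placeholders:

import java.io.File
import java.net.{URL, URLClassLoader}

// Same trick as ExecutorURLClassLoader above: expose the protected addURL.
class MutableURLClassLoader(urls: Array[URL], parent: ClassLoader)
  extends URLClassLoader(urls, parent) {
  override def addURL(url: URL) {
    super.addURL(url)
  }
}

object AddJarSketch {
  def main(args: Array[String]) {
    val loader = new MutableURLClassLoader(Array[URL](), getClass.getClassLoader)
    // Later, after fetching a dependency, widen the classpath on the fly.
    val jarUrl = new File("some-fetched-dependency.jar").toURI.toURL // placeholder path
    loader.addURL(jarUrl)
    // Classes from that JAR are now visible through this loader, for example:
    // loader.loadClass("com.example.SomeTaskClass")                 // placeholder name
  }
}
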
diff --git a/core/src/main/scala/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/spark/executor/MesosExecutorBackend.scala
index 50f4e41ede..eeab3959c6 100644
--- a/core/src/main/scala/spark/executor/MesosExecutorBackend.scala
+++ b/core/src/main/scala/spark/executor/MesosExecutorBackend.scala
@@ -8,7 +8,7 @@ import com.google.protobuf.ByteString
import spark.{Utils, Logging}
import spark.TaskState
-class MesosExecutorBackend(executor: Executor)
+private[spark] class MesosExecutorBackend(executor: Executor)
extends MesosExecutor
with ExecutorBackend
with Logging {
@@ -59,7 +59,7 @@ class MesosExecutorBackend(executor: Executor)
/**
* Entry point for Mesos executor.
*/
-object MesosExecutorBackend {
+private[spark] object MesosExecutorBackend {
def main(args: Array[String]) {
MesosNativeLibrary.load()
// Create a new Executor and start it running
diff --git a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala
index 26b163de0a..915f71ba9f 100644
--- a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala
+++ b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala
@@ -14,7 +14,7 @@ import spark.scheduler.cluster.RegisterSlaveFailed
import spark.scheduler.cluster.RegisterSlave
-class StandaloneExecutorBackend(
+private[spark] class StandaloneExecutorBackend(
executor: Executor,
masterUrl: String,
slaveId: String,
@@ -62,7 +62,7 @@ class StandaloneExecutorBackend(
}
}
-object StandaloneExecutorBackend {
+private[spark] object StandaloneExecutorBackend {
def run(masterUrl: String, slaveId: String, hostname: String, cores: Int) {
// Create a new ActorSystem to run the backend, because we can't create a SparkEnv / Executor
// before getting started with all our system properties, etc
diff --git a/core/src/main/scala/spark/network/Connection.scala b/core/src/main/scala/spark/network/Connection.scala
index da8aff9dd5..80262ab7b4 100644
--- a/core/src/main/scala/spark/network/Connection.scala
+++ b/core/src/main/scala/spark/network/Connection.scala
@@ -11,6 +11,7 @@ import java.nio.channels.spi._
import java.net._
+private[spark]
abstract class Connection(val channel: SocketChannel, val selector: Selector) extends Logging {
channel.configureBlocking(false)
@@ -23,8 +24,8 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector) ex
var onExceptionCallback: (Connection, Exception) => Unit = null
var onKeyInterestChangeCallback: (Connection, Int) => Unit = null
- lazy val remoteAddress = getRemoteAddress()
- lazy val remoteConnectionManagerId = ConnectionManagerId.fromSocketAddress(remoteAddress)
+ val remoteAddress = getRemoteAddress()
+ val remoteConnectionManagerId = ConnectionManagerId.fromSocketAddress(remoteAddress)
def key() = channel.keyFor(selector)
@@ -39,7 +40,10 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector) ex
}
def close() {
- key.cancel()
+ val k = key()
+ if (k != null) {
+ k.cancel()
+ }
channel.close()
callOnCloseCallback()
}
@@ -99,7 +103,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector) ex
}
-class SendingConnection(val address: InetSocketAddress, selector_ : Selector)
+private[spark] class SendingConnection(val address: InetSocketAddress, selector_ : Selector)
extends Connection(SocketChannel.open, selector_) {
class Outbox(fair: Int = 0) {
@@ -134,9 +138,12 @@ extends Connection(SocketChannel.open, selector_) {
if (!message.started) logDebug("Starting to send [" + message + "]")
message.started = true
return chunk
+ } else {
+ /*logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "]")*/
+ message.finishTime = System.currentTimeMillis
+ logDebug("Finished sending [" + message + "] to [" + remoteConnectionManagerId +
+ "] in " + message.timeTaken )
}
- /*logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "]")*/
- logDebug("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "] in " + message.timeTaken )
}
}
None
@@ -159,10 +166,11 @@ extends Connection(SocketChannel.open, selector_) {
}
logTrace("Sending chunk from [" + message+ "] to [" + remoteConnectionManagerId + "]")
return chunk
- }
- /*messages -= message*/
- message.finishTime = System.currentTimeMillis
- logDebug("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "] in " + message.timeTaken )
+ } else {
+ message.finishTime = System.currentTimeMillis
+ logDebug("Finished sending [" + message + "] to [" + remoteConnectionManagerId +
+ "] in " + message.timeTaken )
+ }
}
}
None
@@ -216,7 +224,7 @@ extends Connection(SocketChannel.open, selector_) {
while(true) {
if (currentBuffers.size == 0) {
outbox.synchronized {
- outbox.getChunk match {
+ outbox.getChunk() match {
case Some(chunk) => {
currentBuffers ++= chunk.buffers
}
@@ -252,7 +260,7 @@ extends Connection(SocketChannel.open, selector_) {
}
-class ReceivingConnection(channel_ : SocketChannel, selector_ : Selector)
+private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : Selector)
extends Connection(channel_, selector_) {
class Inbox() {
diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala
index 2bb5f5fc6b..da39108164 100644
--- a/core/src/main/scala/spark/network/ConnectionManager.scala
+++ b/core/src/main/scala/spark/network/ConnectionManager.scala
@@ -18,17 +18,17 @@ import akka.dispatch.{Await, Promise, ExecutionContext, Future}
import akka.util.Duration
import akka.util.duration._
-case class ConnectionManagerId(host: String, port: Int) {
+private[spark] case class ConnectionManagerId(host: String, port: Int) {
def toSocketAddress() = new InetSocketAddress(host, port)
}
-object ConnectionManagerId {
+private[spark] object ConnectionManagerId {
def fromSocketAddress(socketAddress: InetSocketAddress): ConnectionManagerId = {
new ConnectionManagerId(socketAddress.getHostName(), socketAddress.getPort())
}
}
-class ConnectionManager(port: Int) extends Logging {
+private[spark] class ConnectionManager(port: Int) extends Logging {
class MessageStatus(
val message: Message,
@@ -113,7 +113,7 @@ class ConnectionManager(port: Int) extends Logging {
val selectedKeysCount = selector.select()
if (selectedKeysCount == 0) {
- logInfo("Selector selected " + selectedKeysCount + " of " + selector.keys.size + " keys")
+ logDebug("Selector selected " + selectedKeysCount + " of " + selector.keys.size + " keys")
}
if (selectorThread.isInterrupted) {
logInfo("Selector thread was interrupted!")
@@ -167,7 +167,6 @@ class ConnectionManager(port: Int) extends Logging {
}
def removeConnection(connection: Connection) {
- /*logInfo("Removing connection")*/
connectionsByKey -= connection.key
if (connection.isInstanceOf[SendingConnection]) {
val sendingConnection = connection.asInstanceOf[SendingConnection]
@@ -235,7 +234,7 @@ class ConnectionManager(port: Int) extends Logging {
def receiveMessage(connection: Connection, message: Message) {
val connectionManagerId = ConnectionManagerId.fromSocketAddress(message.senderAddress)
- logInfo("Received [" + message + "] from [" + connectionManagerId + "]")
+ logDebug("Received [" + message + "] from [" + connectionManagerId + "]")
val runnable = new Runnable() {
val creationTime = System.currentTimeMillis
def run() {
@@ -276,15 +275,15 @@ class ConnectionManager(port: Int) extends Logging {
logDebug("Calling back")
onReceiveCallback(bufferMessage, connectionManagerId)
} else {
- logWarning("Not calling back as callback is null")
+ logDebug("Not calling back as callback is null")
None
}
if (ackMessage.isDefined) {
if (!ackMessage.get.isInstanceOf[BufferMessage]) {
- logWarning("Response to " + bufferMessage + " is not a buffer message, it is of type " + ackMessage.get.getClass())
+ logDebug("Response to " + bufferMessage + " is not a buffer message, it is of type " + ackMessage.get.getClass())
} else if (!ackMessage.get.asInstanceOf[BufferMessage].hasAckId) {
- logWarning("Response to " + bufferMessage + " does not have ack id set")
+ logDebug("Response to " + bufferMessage + " does not have ack id set")
ackMessage.get.asInstanceOf[BufferMessage].ackId = bufferMessage.id
}
}
@@ -349,7 +348,7 @@ class ConnectionManager(port: Int) extends Logging {
}
-object ConnectionManager {
+private[spark] object ConnectionManager {
def main(args: Array[String]) {
diff --git a/core/src/main/scala/spark/network/ConnectionManagerTest.scala b/core/src/main/scala/spark/network/ConnectionManagerTest.scala
index 555b3454ee..47ceaf3c07 100644
--- a/core/src/main/scala/spark/network/ConnectionManagerTest.scala
+++ b/core/src/main/scala/spark/network/ConnectionManagerTest.scala
@@ -11,7 +11,7 @@ import java.net.InetAddress
import akka.dispatch.Await
import akka.util.duration._
-object ConnectionManagerTest extends Logging{
+private[spark] object ConnectionManagerTest extends Logging{
def main(args: Array[String]) {
if (args.length < 2) {
println("Usage: ConnectionManagerTest <mesos cluster> <slaves file>")
diff --git a/core/src/main/scala/spark/network/Message.scala b/core/src/main/scala/spark/network/Message.scala
index 2e85803679..525751b5bf 100644
--- a/core/src/main/scala/spark/network/Message.scala
+++ b/core/src/main/scala/spark/network/Message.scala
@@ -7,8 +7,9 @@ import scala.collection.mutable.ArrayBuffer
import java.nio.ByteBuffer
import java.net.InetAddress
import java.net.InetSocketAddress
+import storage.BlockManager
-class MessageChunkHeader(
+private[spark] class MessageChunkHeader(
val typ: Long,
val id: Int,
val totalSize: Int,
@@ -36,7 +37,7 @@ class MessageChunkHeader(
" and sizes " + totalSize + " / " + chunkSize + " bytes"
}
-class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) {
+private[spark] class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) {
val size = if (buffer == null) 0 else buffer.remaining
lazy val buffers = {
val ab = new ArrayBuffer[ByteBuffer]()
@@ -50,7 +51,7 @@ class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) {
override def toString = "" + this.getClass.getSimpleName + " (id = " + header.id + ", size = " + size + ")"
}
-abstract class Message(val typ: Long, val id: Int) {
+private[spark] abstract class Message(val typ: Long, val id: Int) {
var senderAddress: InetSocketAddress = null
var started = false
var startTime = -1L
@@ -64,10 +65,10 @@ abstract class Message(val typ: Long, val id: Int) {
def timeTaken(): String = (finishTime - startTime).toString + " ms"
- override def toString = "" + this.getClass.getSimpleName + "(id = " + id + ", size = " + size + ")"
+ override def toString = this.getClass.getSimpleName + "(id = " + id + ", size = " + size + ")"
}
-class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: Int)
+private[spark] class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: Int)
extends Message(Message.BUFFER_MESSAGE, id_) {
val initialSize = currentSize()
@@ -97,10 +98,11 @@ extends Message(Message.BUFFER_MESSAGE, id_) {
while(!buffers.isEmpty) {
val buffer = buffers(0)
if (buffer.remaining == 0) {
+ BlockManager.dispose(buffer)
buffers -= buffer
} else {
val newBuffer = if (buffer.remaining <= maxChunkSize) {
- buffer.duplicate
+ buffer.duplicate()
} else {
buffer.slice().limit(maxChunkSize).asInstanceOf[ByteBuffer]
}
@@ -147,11 +149,10 @@ extends Message(Message.BUFFER_MESSAGE, id_) {
} else {
"BufferMessage(id = " + id + ", size = " + size + ")"
}
-
}
}
-object MessageChunkHeader {
+private[spark] object MessageChunkHeader {
val HEADER_SIZE = 40
def create(buffer: ByteBuffer): MessageChunkHeader = {
@@ -172,7 +173,7 @@ object MessageChunkHeader {
}
}
-object Message {
+private[spark] object Message {
val BUFFER_MESSAGE = 1111111111L
var lastId = 1
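
The BufferMessage change above concerns how a message body is cut into MessageChunks: a buffer that already fits within maxChunkSize is simply duplicated, anything larger is sliced maxChunkSize bytes at a time. A standalone sketch of that slicing, without the header bookkeeping or the new BlockManager.dispose call:

import java.nio.ByteBuffer
import scala.collection.mutable.ArrayBuffer

object ChunkingSketch {
  // Split a buffer into views of at most maxChunkSize bytes each (content is shared).
  def chunk(buffer: ByteBuffer, maxChunkSize: Int): Seq[ByteBuffer] = {
    val chunks = new ArrayBuffer[ByteBuffer]
    while (buffer.remaining > 0) {
      val size = math.min(buffer.remaining, maxChunkSize)
      val slice = buffer.slice()                // shares content, independent position/limit
      slice.limit(size)                         // cap this chunk at maxChunkSize bytes
      chunks += slice
      buffer.position(buffer.position + size)   // advance past the bytes just handed out
    }
    chunks
  }

  def main(args: Array[String]) {
    val data = ByteBuffer.wrap(new Array[Byte](10))
    println(chunk(data, 4).map(_.remaining))    // 4, 4, 2
  }
}
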
diff --git a/core/src/main/scala/spark/network/ReceiverTest.scala b/core/src/main/scala/spark/network/ReceiverTest.scala
index e1ba7c06c0..a174d5f403 100644
--- a/core/src/main/scala/spark/network/ReceiverTest.scala
+++ b/core/src/main/scala/spark/network/ReceiverTest.scala
@@ -3,7 +3,7 @@ package spark.network
import java.nio.ByteBuffer
import java.net.InetAddress
-object ReceiverTest {
+private[spark] object ReceiverTest {
def main(args: Array[String]) {
val manager = new ConnectionManager(9999)
diff --git a/core/src/main/scala/spark/network/SenderTest.scala b/core/src/main/scala/spark/network/SenderTest.scala
index 4ab6dd3414..a4ff69e4d2 100644
--- a/core/src/main/scala/spark/network/SenderTest.scala
+++ b/core/src/main/scala/spark/network/SenderTest.scala
@@ -3,7 +3,7 @@ package spark.network
import java.nio.ByteBuffer
import java.net.InetAddress
-object SenderTest {
+private[spark] object SenderTest {
def main(args: Array[String]) {
diff --git a/core/src/main/scala/spark/package.scala b/core/src/main/scala/spark/package.scala
new file mode 100644
index 0000000000..389ec4da3e
--- /dev/null
+++ b/core/src/main/scala/spark/package.scala
@@ -0,0 +1,15 @@
+/**
+ * Core Spark functionality. [[spark.SparkContext]] serves as the main entry point to Spark, while
+ * [[spark.RDD]] is the data type representing a distributed collection, and provides most
+ * parallel operations.
+ *
+ * In addition, [[spark.PairRDDFunctions]] contains operations available only on RDDs of key-value
+ * pairs, such as `groupByKey` and `join`; [[spark.DoubleRDDFunctions]] contains operations
+ * available only on RDDs of Doubles; and [[spark.SequenceFileRDDFunctions]] contains operations
+ * available on RDDs that can be saved as SequenceFiles. These operations are automatically
+ * available on any RDD of the right type (e.g. RDD[(Int, Int)]) through implicit conversions when
+ * you `import spark.SparkContext._`.
+ */
+package object spark {
+ // For package docs only
+}
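
Since the package doc mentions the implicit conversions pulled in by import spark.SparkContext._, here is a minimal example of what that buys you on an RDD of pairs. It is a sketch against the local mode of this era's API rather than a tested program:

import spark.SparkContext
import spark.SparkContext._   // brings the RDD-to-PairRDDFunctions implicits into scope

object PairRDDSketch {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "PairRDDSketch")
    val pairs = sc.parallelize(Seq(("a", 1), ("b", 2), ("a", 3)))
    // groupByKey comes from PairRDDFunctions, available only via the import above.
    val grouped = pairs.groupByKey()
    grouped.collect().foreach(println)  // prints each key with its grouped values
    sc.stop()
  }
}
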
diff --git a/core/src/main/scala/spark/partial/ApproximateActionListener.scala b/core/src/main/scala/spark/partial/ApproximateActionListener.scala
index e6535836ab..42f46e06ed 100644
--- a/core/src/main/scala/spark/partial/ApproximateActionListener.scala
+++ b/core/src/main/scala/spark/partial/ApproximateActionListener.scala
@@ -12,7 +12,7 @@ import spark.scheduler.JobListener
* a result of type U for each partition, and that the action returns a partial or complete result
* of type R. Note that the type R must *include* any error bars on it (e.g. see BoundedInt).
*/
-class ApproximateActionListener[T, U, R](
+private[spark] class ApproximateActionListener[T, U, R](
rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
evaluator: ApproximateEvaluator[U, R],
diff --git a/core/src/main/scala/spark/partial/ApproximateEvaluator.scala b/core/src/main/scala/spark/partial/ApproximateEvaluator.scala
index 4772e43ef0..75713b2eaa 100644
--- a/core/src/main/scala/spark/partial/ApproximateEvaluator.scala
+++ b/core/src/main/scala/spark/partial/ApproximateEvaluator.scala
@@ -4,7 +4,7 @@ package spark.partial
* An object that computes a function incrementally by merging in results of type U from multiple
* tasks. Allows partial evaluation at any point by calling currentResult().
*/
-trait ApproximateEvaluator[U, R] {
+private[spark] trait ApproximateEvaluator[U, R] {
def merge(outputId: Int, taskResult: U): Unit
def currentResult(): R
}
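
To make the contract concrete: an evaluator folds task results in through merge and can be asked for its best current answer at any time through currentResult. Below is a toy, self-contained evaluator that extrapolates a total count from the partitions merged so far; it is modelled loosely on CountEvaluator but skips the confidence bounds entirely:

// Toy evaluator with the same merge/currentResult shape as ApproximateEvaluator.
trait SimpleEvaluator[U, R] {
  def merge(outputId: Int, taskResult: U): Unit
  def currentResult(): R
}

// Extrapolates a total count from however many partitions have reported so far.
class NaiveCountEvaluator(totalOutputs: Int) extends SimpleEvaluator[Long, Double] {
  private var outputsMerged = 0
  private var sumSoFar = 0L

  def merge(outputId: Int, taskResult: Long) {
    outputsMerged += 1
    sumSoFar += taskResult
  }

  def currentResult(): Double =
    if (outputsMerged == 0) 0.0
    else sumSoFar.toDouble * totalOutputs / outputsMerged
}

object NaiveCountEvaluatorExample {
  def main(args: Array[String]) {
    val eval = new NaiveCountEvaluator(4)
    eval.merge(0, 100L)
    eval.merge(1, 120L)
    println(eval.currentResult())   // 440.0: 220 items seen across half the partitions
  }
}
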
diff --git a/core/src/main/scala/spark/partial/CountEvaluator.scala b/core/src/main/scala/spark/partial/CountEvaluator.scala
index 1bc90d6b39..daf2c5170c 100644
--- a/core/src/main/scala/spark/partial/CountEvaluator.scala
+++ b/core/src/main/scala/spark/partial/CountEvaluator.scala
@@ -8,7 +8,7 @@ import cern.jet.stat.Probability
* TODO: There's currently a lot of shared code between this and GroupedCountEvaluator. It might
* be best to make this a special case of GroupedCountEvaluator with one group.
*/
-class CountEvaluator(totalOutputs: Int, confidence: Double)
+private[spark] class CountEvaluator(totalOutputs: Int, confidence: Double)
extends ApproximateEvaluator[Long, BoundedDouble] {
var outputsMerged = 0
diff --git a/core/src/main/scala/spark/partial/GroupedCountEvaluator.scala b/core/src/main/scala/spark/partial/GroupedCountEvaluator.scala
index 3e631c0efc..01fbb8a11b 100644
--- a/core/src/main/scala/spark/partial/GroupedCountEvaluator.scala
+++ b/core/src/main/scala/spark/partial/GroupedCountEvaluator.scala
@@ -14,7 +14,7 @@ import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap}
/**
* An ApproximateEvaluator for counts by key. Returns a map of key to confidence interval.
*/
-class GroupedCountEvaluator[T](totalOutputs: Int, confidence: Double)
+private[spark] class GroupedCountEvaluator[T](totalOutputs: Int, confidence: Double)
extends ApproximateEvaluator[OLMap[T], Map[T, BoundedDouble]] {
var outputsMerged = 0
diff --git a/core/src/main/scala/spark/partial/GroupedMeanEvaluator.scala b/core/src/main/scala/spark/partial/GroupedMeanEvaluator.scala
index 2a9ccba205..c622df5220 100644
--- a/core/src/main/scala/spark/partial/GroupedMeanEvaluator.scala
+++ b/core/src/main/scala/spark/partial/GroupedMeanEvaluator.scala
@@ -12,7 +12,7 @@ import spark.util.StatCounter
/**
* An ApproximateEvaluator for means by key. Returns a map of key to confidence interval.
*/
-class GroupedMeanEvaluator[T](totalOutputs: Int, confidence: Double)
+private[spark] class GroupedMeanEvaluator[T](totalOutputs: Int, confidence: Double)
extends ApproximateEvaluator[JHashMap[T, StatCounter], Map[T, BoundedDouble]] {
var outputsMerged = 0
diff --git a/core/src/main/scala/spark/partial/GroupedSumEvaluator.scala b/core/src/main/scala/spark/partial/GroupedSumEvaluator.scala
index 6a2ec7a7bd..20fa55cff2 100644
--- a/core/src/main/scala/spark/partial/GroupedSumEvaluator.scala
+++ b/core/src/main/scala/spark/partial/GroupedSumEvaluator.scala
@@ -12,7 +12,7 @@ import spark.util.StatCounter
/**
* An ApproximateEvaluator for sums by key. Returns a map of key to confidence interval.
*/
-class GroupedSumEvaluator[T](totalOutputs: Int, confidence: Double)
+private[spark] class GroupedSumEvaluator[T](totalOutputs: Int, confidence: Double)
extends ApproximateEvaluator[JHashMap[T, StatCounter], Map[T, BoundedDouble]] {
var outputsMerged = 0
diff --git a/core/src/main/scala/spark/partial/MeanEvaluator.scala b/core/src/main/scala/spark/partial/MeanEvaluator.scala
index b8c7cb8863..762c85400d 100644
--- a/core/src/main/scala/spark/partial/MeanEvaluator.scala
+++ b/core/src/main/scala/spark/partial/MeanEvaluator.scala
@@ -7,7 +7,7 @@ import spark.util.StatCounter
/**
* An ApproximateEvaluator for means.
*/
-class MeanEvaluator(totalOutputs: Int, confidence: Double)
+private[spark] class MeanEvaluator(totalOutputs: Int, confidence: Double)
extends ApproximateEvaluator[StatCounter, BoundedDouble] {
var outputsMerged = 0
diff --git a/core/src/main/scala/spark/partial/StudentTCacher.scala b/core/src/main/scala/spark/partial/StudentTCacher.scala
index 6263ee3518..443abba5cd 100644
--- a/core/src/main/scala/spark/partial/StudentTCacher.scala
+++ b/core/src/main/scala/spark/partial/StudentTCacher.scala
@@ -7,7 +7,7 @@ import cern.jet.stat.Probability
* and various sample sizes. This is used by the MeanEvaluator to efficiently calculate
* confidence intervals for many keys.
*/
-class StudentTCacher(confidence: Double) {
+private[spark] class StudentTCacher(confidence: Double) {
val NORMAL_APPROX_SAMPLE_SIZE = 100 // For samples bigger than this, use Gaussian approximation
val normalApprox = Probability.normalInverse(1 - (1 - confidence) / 2)
val cache = Array.fill[Double](NORMAL_APPROX_SAMPLE_SIZE)(-1.0)
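For context, the accessor on such a cache (not shown in this hunk) would fall back to the precomputed Gaussian quantile for large samples and memoize the Student-t quantile otherwise. A minimal sketch of that lookup, assuming a get(sampleSize) method and Colt's Probability.studentTInverse; illustrative only, not necessarily the file's actual code:

    // Belongs inside the class above; purely a sketch of the caching pattern.
    def get(sampleSize: Int): Double = {
      if (sampleSize >= NORMAL_APPROX_SAMPLE_SIZE) {
        normalApprox                       // large samples: Gaussian approximation
      } else {
        if (cache(sampleSize) < 0) {       // memoize the exact t quantile
          cache(sampleSize) = Probability.studentTInverse(1 - confidence, sampleSize - 1)
        }
        cache(sampleSize)
      }
    }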
diff --git a/core/src/main/scala/spark/partial/SumEvaluator.scala b/core/src/main/scala/spark/partial/SumEvaluator.scala
index 0357a6bff8..58fb60f441 100644
--- a/core/src/main/scala/spark/partial/SumEvaluator.scala
+++ b/core/src/main/scala/spark/partial/SumEvaluator.scala
@@ -9,7 +9,7 @@ import spark.util.StatCounter
* together, then uses the formula for the variance of two independent random variables to get
* a variance for the result and compute a confidence interval.
*/
-class SumEvaluator(totalOutputs: Int, confidence: Double)
+private[spark] class SumEvaluator(totalOutputs: Int, confidence: Double)
extends ApproximateEvaluator[StatCounter, BoundedDouble] {
var outputsMerged = 0
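The identity the class comment above relies on: for independent random variables X and Y, Var(X + Y) = Var(X) + Var(Y), so the variance of (sum observed so far) + (estimated sum of the unmerged partitions) is simply the sum of the two variances, which then gives the width of the confidence interval.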
diff --git a/core/src/main/scala/spark/BlockRDD.scala b/core/src/main/scala/spark/rdd/BlockRDD.scala
index daabc0d566..cb73976aed 100644
--- a/core/src/main/scala/spark/BlockRDD.scala
+++ b/core/src/main/scala/spark/rdd/BlockRDD.scala
@@ -1,12 +1,18 @@
-package spark
+package spark.rdd
import scala.collection.mutable.HashMap
-class BlockRDDSplit(val blockId: String, idx: Int) extends Split {
+import spark.Dependency
+import spark.RDD
+import spark.SparkContext
+import spark.SparkEnv
+import spark.Split
+
+private[spark] class BlockRDDSplit(val blockId: String, idx: Int) extends Split {
val index = idx
}
-
+private[spark]
class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[String])
extends RDD[T](sc) {
diff --git a/core/src/main/scala/spark/CartesianRDD.scala b/core/src/main/scala/spark/rdd/CartesianRDD.scala
index e26041555a..7c354b6b2e 100644
--- a/core/src/main/scala/spark/CartesianRDD.scala
+++ b/core/src/main/scala/spark/rdd/CartesianRDD.scala
@@ -1,9 +1,16 @@
-package spark
+package spark.rdd
+import spark.NarrowDependency
+import spark.RDD
+import spark.SparkContext
+import spark.Split
+
+private[spark]
class CartesianSplit(idx: Int, val s1: Split, val s2: Split) extends Split with Serializable {
override val index: Int = idx
}
+private[spark]
class CartesianRDD[T: ClassManifest, U:ClassManifest](
sc: SparkContext,
rdd1: RDD[T],
@@ -44,4 +51,4 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest](
def getParents(id: Int): Seq[Int] = List(id % numSplitsInRdd2)
}
)
-}
\ No newline at end of file
+}
diff --git a/core/src/main/scala/spark/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
index 6959917d14..50bec9e63b 100644
--- a/core/src/main/scala/spark/CoGroupedRDD.scala
+++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
@@ -1,21 +1,29 @@
-package spark
+package spark.rdd
-import java.net.URL
-import java.io.EOFException
-import java.io.ObjectInputStream
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
-sealed trait CoGroupSplitDep extends Serializable
-case class NarrowCoGroupSplitDep(rdd: RDD[_], split: Split) extends CoGroupSplitDep
-case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep
+import spark.Aggregator
+import spark.Dependency
+import spark.Logging
+import spark.OneToOneDependency
+import spark.Partitioner
+import spark.RDD
+import spark.ShuffleDependency
+import spark.SparkEnv
+import spark.Split
+private[spark] sealed trait CoGroupSplitDep extends Serializable
+private[spark] case class NarrowCoGroupSplitDep(rdd: RDD[_], split: Split) extends CoGroupSplitDep
+private[spark] case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep
+
+private[spark]
class CoGroupSplit(idx: Int, val deps: Seq[CoGroupSplitDep]) extends Split with Serializable {
override val index: Int = idx
override def hashCode(): Int = idx
}
-class CoGroupAggregator
+private[spark] class CoGroupAggregator
extends Aggregator[Any, Any, ArrayBuffer[Any]](
{ x => ArrayBuffer(x) },
{ (b, x) => b += x },
@@ -31,13 +39,13 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner)
override val dependencies = {
val deps = new ArrayBuffer[Dependency[_]]
for ((rdd, index) <- rdds.zipWithIndex) {
- if (rdd.partitioner == Some(part)) {
- logInfo("Adding one-to-one dependency with " + rdd)
- deps += new OneToOneDependency(rdd)
+ val mapSideCombinedRDD = rdd.mapPartitions(aggr.combineValuesByKey(_), true)
+ if (mapSideCombinedRDD.partitioner == Some(part)) {
+ logInfo("Adding one-to-one dependency with " + mapSideCombinedRDD)
+ deps += new OneToOneDependency(mapSideCombinedRDD)
} else {
logInfo("Adding shuffle dependency with " + rdd)
- deps += new ShuffleDependency[Any, Any, ArrayBuffer[Any]](
- context.newShuffleId, rdd, aggr, part)
+ deps += new ShuffleDependency[Any, ArrayBuffer[Any]](mapSideCombinedRDD, part)
}
}
deps.toList
@@ -50,7 +58,7 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner)
for (i <- 0 until array.size) {
array(i) = new CoGroupSplit(i, rdds.zipWithIndex.map { case (r, j) =>
dependencies(j) match {
- case s: ShuffleDependency[_, _, _] =>
+ case s: ShuffleDependency[_, _] =>
new ShuffleCoGroupSplitDep(s.shuffleId): CoGroupSplitDep
case _ =>
new NarrowCoGroupSplitDep(r, r.splits(i)): CoGroupSplitDep
@@ -82,13 +90,13 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner)
}
case ShuffleCoGroupSplitDep(shuffleId) => {
// Read map outputs of shuffle
- def mergePair(k: K, vs: Seq[Any]) {
- val mySeq = getSeq(k)
- for (v <- vs)
+ def mergePair(pair: (K, Seq[Any])) {
+ val mySeq = getSeq(pair._1)
+ for (v <- pair._2)
mySeq(depNum) += v
}
val fetcher = SparkEnv.get.shuffleFetcher
- fetcher.fetch[K, Seq[Any]](shuffleId, split.index, mergePair)
+ fetcher.fetch[K, Seq[Any]](shuffleId, split.index).foreach(mergePair)
}
}
map.iterator
diff --git a/core/src/main/scala/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/spark/rdd/CoalescedRDD.scala
new file mode 100644
index 0000000000..0967f4f5df
--- /dev/null
+++ b/core/src/main/scala/spark/rdd/CoalescedRDD.scala
@@ -0,0 +1,47 @@
+package spark.rdd
+
+import spark.NarrowDependency
+import spark.RDD
+import spark.Split
+
+private class CoalescedRDDSplit(val index: Int, val parents: Array[Split]) extends Split
+
+/**
+ * Coalesce the partitions of a parent RDD (`prev`) into fewer partitions, so that each partition of
+ * this RDD computes one or more of the parent ones. Will produce exactly `maxPartitions` if the
+ * parent had more than this many partitions, or fewer if the parent had fewer.
+ *
+ * This transformation is useful when an RDD with many partitions gets filtered into a smaller one,
+ * or to avoid having a large number of small tasks when processing a directory with many files.
+ */
+class CoalescedRDD[T: ClassManifest](prev: RDD[T], maxPartitions: Int)
+ extends RDD[T](prev.context) {
+
+ @transient val splits_ : Array[Split] = {
+ val prevSplits = prev.splits
+ if (prevSplits.length < maxPartitions) {
+ prevSplits.zipWithIndex.map{ case (s, idx) => new CoalescedRDDSplit(idx, Array(s)) }
+ } else {
+ (0 until maxPartitions).map { i =>
+ val rangeStart = (i * prevSplits.length) / maxPartitions
+ val rangeEnd = ((i + 1) * prevSplits.length) / maxPartitions
+ new CoalescedRDDSplit(i, prevSplits.slice(rangeStart, rangeEnd))
+ }.toArray
+ }
+ }
+
+ override def splits = splits_
+
+ override def compute(split: Split): Iterator[T] = {
+ split.asInstanceOf[CoalescedRDDSplit].parents.iterator.flatMap {
+ parentSplit => prev.iterator(parentSplit)
+ }
+ }
+
+ val dependencies = List(
+ new NarrowDependency(prev) {
+ def getParents(id: Int): Seq[Int] =
+ splits(id).asInstanceOf[CoalescedRDDSplit].parents.map(_.index)
+ }
+ )
+}
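The rangeStart/rangeEnd arithmetic above splits the parent's partitions into maxPartitions contiguous, near-equal groups. A standalone illustration of the same index math in plain Scala (no Spark types involved):

    // Mirrors the grouping formula used by CoalescedRDD above.
    object CoalesceGrouping {
      def group(numParentSplits: Int, maxPartitions: Int): Seq[Range] =
        (0 until maxPartitions).map { i =>
          val start = (i * numParentSplits) / maxPartitions
          val end = ((i + 1) * numParentSplits) / maxPartitions
          start until end
        }

      def main(args: Array[String]) {
        // 10 parent splits coalesced to 3 partitions:
        // prints List(0, 1, 2), List(3, 4, 5), List(6, 7, 8, 9)
        group(10, 3).foreach(r => println(r.toList))
      }
    }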
diff --git a/core/src/main/scala/spark/rdd/FilteredRDD.scala b/core/src/main/scala/spark/rdd/FilteredRDD.scala
new file mode 100644
index 0000000000..dfe9dc73f3
--- /dev/null
+++ b/core/src/main/scala/spark/rdd/FilteredRDD.scala
@@ -0,0 +1,12 @@
+package spark.rdd
+
+import spark.OneToOneDependency
+import spark.RDD
+import spark.Split
+
+private[spark]
+class FilteredRDD[T: ClassManifest](prev: RDD[T], f: T => Boolean) extends RDD[T](prev.context) {
+ override def splits = prev.splits
+ override val dependencies = List(new OneToOneDependency(prev))
+ override def compute(split: Split) = prev.iterator(split).filter(f)
+}
\ No newline at end of file
diff --git a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala
new file mode 100644
index 0000000000..3534dc8057
--- /dev/null
+++ b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala
@@ -0,0 +1,16 @@
+package spark.rdd
+
+import spark.OneToOneDependency
+import spark.RDD
+import spark.Split
+
+private[spark]
+class FlatMappedRDD[U: ClassManifest, T: ClassManifest](
+ prev: RDD[T],
+ f: T => TraversableOnce[U])
+ extends RDD[U](prev.context) {
+
+ override def splits = prev.splits
+ override val dependencies = List(new OneToOneDependency(prev))
+ override def compute(split: Split) = prev.iterator(split).flatMap(f)
+}
diff --git a/core/src/main/scala/spark/rdd/GlommedRDD.scala b/core/src/main/scala/spark/rdd/GlommedRDD.scala
new file mode 100644
index 0000000000..e30564f2da
--- /dev/null
+++ b/core/src/main/scala/spark/rdd/GlommedRDD.scala
@@ -0,0 +1,12 @@
+package spark.rdd
+
+import spark.OneToOneDependency
+import spark.RDD
+import spark.Split
+
+private[spark]
+class GlommedRDD[T: ClassManifest](prev: RDD[T]) extends RDD[Array[T]](prev.context) {
+ override def splits = prev.splits
+ override val dependencies = List(new OneToOneDependency(prev))
+ override def compute(split: Split) = Array(prev.iterator(split).toArray).iterator
+}
\ No newline at end of file
diff --git a/core/src/main/scala/spark/HadoopRDD.scala b/core/src/main/scala/spark/rdd/HadoopRDD.scala
index 0befca582d..bf29a1f075 100644
--- a/core/src/main/scala/spark/HadoopRDD.scala
+++ b/core/src/main/scala/spark/rdd/HadoopRDD.scala
@@ -1,4 +1,4 @@
-package spark
+package spark.rdd
import java.io.EOFException
import java.util.NoSuchElementException
@@ -15,10 +15,16 @@ import org.apache.hadoop.mapred.RecordReader
import org.apache.hadoop.mapred.Reporter
import org.apache.hadoop.util.ReflectionUtils
+import spark.Dependency
+import spark.RDD
+import spark.SerializableWritable
+import spark.SparkContext
+import spark.Split
+
/**
* A Spark split class that wraps around a Hadoop InputSplit.
*/
-class HadoopSplit(rddId: Int, idx: Int, @transient s: InputSplit)
+private[spark] class HadoopSplit(rddId: Int, idx: Int, @transient s: InputSplit)
extends Split
with Serializable {
diff --git a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala
new file mode 100644
index 0000000000..a904ef62c3
--- /dev/null
+++ b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala
@@ -0,0 +1,19 @@
+package spark.rdd
+
+import spark.OneToOneDependency
+import spark.RDD
+import spark.Split
+
+private[spark]
+class MapPartitionsRDD[U: ClassManifest, T: ClassManifest](
+ prev: RDD[T],
+ f: Iterator[T] => Iterator[U],
+ preservesPartitioning: Boolean = false)
+ extends RDD[U](prev.context) {
+
+ override val partitioner = if (preservesPartitioning) prev.partitioner else None
+
+ override def splits = prev.splits
+ override val dependencies = List(new OneToOneDependency(prev))
+ override def compute(split: Split) = f(prev.iterator(split))
+}
\ No newline at end of file
diff --git a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala
new file mode 100644
index 0000000000..adc541694e
--- /dev/null
+++ b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala
@@ -0,0 +1,21 @@
+package spark.rdd
+
+import spark.OneToOneDependency
+import spark.RDD
+import spark.Split
+
+/**
+ * A variant of the MapPartitionsRDD that passes the split index into the
+ * closure. This can be used to generate or collect partition specific
+ * information such as the number of tuples in a partition.
+ */
+private[spark]
+class MapPartitionsWithSplitRDD[U: ClassManifest, T: ClassManifest](
+ prev: RDD[T],
+ f: (Int, Iterator[T]) => Iterator[U])
+ extends RDD[U](prev.context) {
+
+ override def splits = prev.splits
+ override val dependencies = List(new OneToOneDependency(prev))
+ override def compute(split: Split) = f(split.index, prev.iterator(split))
+}
\ No newline at end of file
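As a hypothetical usage of the class above (through whatever RDD.mapPartitionsWithSplit entry point the rest of this patch wires up, which is not shown here), counting the tuples in each partition:

    // Assumes sc is an existing SparkContext; names are illustrative only.
    val counts = sc.parallelize(1 to 1000, 4)
      .mapPartitionsWithSplit((splitIndex, iter) => Iterator((splitIndex, iter.size)))
      .collect()
    // e.g. Array((0,250), (1,250), (2,250), (3,250))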
diff --git a/core/src/main/scala/spark/rdd/MappedRDD.scala b/core/src/main/scala/spark/rdd/MappedRDD.scala
new file mode 100644
index 0000000000..59bedad8ef
--- /dev/null
+++ b/core/src/main/scala/spark/rdd/MappedRDD.scala
@@ -0,0 +1,16 @@
+package spark.rdd
+
+import spark.OneToOneDependency
+import spark.RDD
+import spark.Split
+
+private[spark]
+class MappedRDD[U: ClassManifest, T: ClassManifest](
+ prev: RDD[T],
+ f: T => U)
+ extends RDD[U](prev.context) {
+
+ override def splits = prev.splits
+ override val dependencies = List(new OneToOneDependency(prev))
+ override def compute(split: Split) = prev.iterator(split).map(f)
+}
\ No newline at end of file
diff --git a/core/src/main/scala/spark/NewHadoopRDD.scala b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
index 14f708a3f8..7a1a0fb87d 100644
--- a/core/src/main/scala/spark/NewHadoopRDD.scala
+++ b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
@@ -1,18 +1,19 @@
-package spark
+package spark.rdd
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.Writable
-import org.apache.hadoop.mapreduce.InputFormat
-import org.apache.hadoop.mapreduce.InputSplit
-import org.apache.hadoop.mapreduce.JobContext
-import org.apache.hadoop.mapreduce.JobID
-import org.apache.hadoop.mapreduce.RecordReader
-import org.apache.hadoop.mapreduce.TaskAttemptContext
-import org.apache.hadoop.mapreduce.TaskAttemptID
+import org.apache.hadoop.mapreduce._
import java.util.Date
import java.text.SimpleDateFormat
+import spark.Dependency
+import spark.RDD
+import spark.SerializableWritable
+import spark.SparkContext
+import spark.Split
+
+private[spark]
class NewHadoopSplit(rddId: Int, val index: Int, @transient rawSplit: InputSplit with Writable)
extends Split {
@@ -26,7 +27,8 @@ class NewHadoopRDD[K, V](
inputFormatClass: Class[_ <: InputFormat[K, V]],
keyClass: Class[K], valueClass: Class[V],
@transient conf: Configuration)
- extends RDD[(K, V)](sc) {
+ extends RDD[(K, V)](sc)
+ with HadoopMapReduceUtil {
// A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it
val confBroadcast = sc.broadcast(new SerializableWritable(conf))
@@ -43,7 +45,7 @@ class NewHadoopRDD[K, V](
@transient
private val splits_ : Array[Split] = {
val inputFormat = inputFormatClass.newInstance
- val jobContext = new JobContext(conf, jobId)
+ val jobContext = newJobContext(conf, jobId)
val rawSplits = inputFormat.getSplits(jobContext).toArray
val result = new Array[Split](rawSplits.size)
for (i <- 0 until rawSplits.size) {
@@ -58,7 +60,7 @@ class NewHadoopRDD[K, V](
val split = theSplit.asInstanceOf[NewHadoopSplit]
val conf = confBroadcast.value.value
val attemptId = new TaskAttemptID(jobtrackerId, id, true, split.index, 0)
- val context = new TaskAttemptContext(conf, attemptId)
+ val context = newTaskAttemptContext(conf, attemptId)
val format = inputFormatClass.newInstance
val reader = format.createRecordReader(split.serializableHadoopSplit.value, context)
reader.initialize(split.serializableHadoopSplit.value, context)
diff --git a/core/src/main/scala/spark/PipedRDD.scala b/core/src/main/scala/spark/rdd/PipedRDD.scala
index 3103d7889b..98ea0c92d6 100644
--- a/core/src/main/scala/spark/PipedRDD.scala
+++ b/core/src/main/scala/spark/rdd/PipedRDD.scala
@@ -1,4 +1,4 @@
-package spark
+package spark.rdd
import java.io.PrintWriter
import java.util.StringTokenizer
@@ -8,6 +8,12 @@ import scala.collection.JavaConversions._
import scala.collection.mutable.ArrayBuffer
import scala.io.Source
+import spark.OneToOneDependency
+import spark.RDD
+import spark.SparkEnv
+import spark.Split
+
+
/**
* An RDD that pipes the contents of each parent partition through an external command
* (printing them one per line) and returns the output as a collection of strings.
diff --git a/core/src/main/scala/spark/SampledRDD.scala b/core/src/main/scala/spark/rdd/SampledRDD.scala
index 8ef40d8d9e..87a5268f27 100644
--- a/core/src/main/scala/spark/SampledRDD.scala
+++ b/core/src/main/scala/spark/rdd/SampledRDD.scala
@@ -1,7 +1,14 @@
-package spark
+package spark.rdd
import java.util.Random
+import cern.jet.random.Poisson
+import cern.jet.random.engine.DRand
+import spark.RDD
+import spark.OneToOneDependency
+import spark.Split
+
+private[spark]
class SampledRDDSplit(val prev: Split, val seed: Int) extends Split with Serializable {
override val index: Int = prev.index
}
@@ -28,19 +35,21 @@ class SampledRDD[T: ClassManifest](
override def compute(splitIn: Split) = {
val split = splitIn.asInstanceOf[SampledRDDSplit]
- val rg = new Random(split.seed)
- // Sampling with replacement (TODO: use reservoir sampling to make this more efficient?)
if (withReplacement) {
- val oldData = prev.iterator(split.prev).toArray
- val sampleSize = (oldData.size * frac).ceil.toInt
- val sampledData = {
- // all of oldData's indices are candidates, even if sampleSize < oldData.size
- for (i <- 1 to sampleSize)
- yield oldData(rg.nextInt(oldData.size))
+ // For large datasets, the expected number of occurrences of each element in a sample with
+ // replacement is Poisson(frac). We use that to get a count for each element.
+ val poisson = new Poisson(frac, new DRand(split.seed))
+ prev.iterator(split.prev).flatMap { element =>
+ val count = poisson.nextInt()
+ if (count == 0) {
+ Iterator.empty // Avoid object allocation when we return 0 items, which happens quite often
+ } else {
+ Iterator.fill(count)(element)
+ }
}
- sampledData.iterator
} else { // Sampling without replacement
- prev.iterator(split.prev).filter(x => (rg.nextDouble <= frac))
+ val rand = new Random(split.seed)
+ prev.iterator(split.prev).filter(x => (rand.nextDouble <= frac))
}
}
}
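The Poisson trick above lets the sample be produced in one streaming pass instead of materializing the whole partition: for large data, the number of times each element appears in a with-replacement sample of fraction frac is approximately Poisson(frac). A standalone sketch of the same idea with Colt, outside of any RDD:

    import cern.jet.random.Poisson
    import cern.jet.random.engine.DRand

    // Streaming with-replacement sampling; mirrors the flatMap logic above.
    def sampleWithReplacement[T](data: Iterator[T], frac: Double, seed: Int): Iterator[T] = {
      val poisson = new Poisson(frac, new DRand(seed))
      data.flatMap { element =>
        val count = poisson.nextInt()
        if (count == 0) Iterator.empty else Iterator.fill(count)(element)
      }
    }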
diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala
new file mode 100644
index 0000000000..145e419c53
--- /dev/null
+++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala
@@ -0,0 +1,40 @@
+package spark.rdd
+
+import spark.Partitioner
+import spark.RDD
+import spark.ShuffleDependency
+import spark.SparkEnv
+import spark.Split
+
+private[spark] class ShuffledRDDSplit(val idx: Int) extends Split {
+ override val index = idx
+ override def hashCode(): Int = idx
+}
+
+/**
+ * The resulting RDD from a shuffle (e.g. repartitioning of data).
+ * @param parent the parent RDD.
+ * @param part the partitioner used to partition the RDD
+ * @tparam K the key class.
+ * @tparam V the value class.
+ */
+class ShuffledRDD[K, V](
+ @transient parent: RDD[(K, V)],
+ part: Partitioner) extends RDD[(K, V)](parent.context) {
+
+ override val partitioner = Some(part)
+
+ @transient
+ val splits_ = Array.tabulate[Split](part.numPartitions)(i => new ShuffledRDDSplit(i))
+
+ override def splits = splits_
+
+ override def preferredLocations(split: Split) = Nil
+
+ val dep = new ShuffleDependency(parent, part)
+ override val dependencies = List(dep)
+
+ override def compute(split: Split): Iterator[(K, V)] = {
+ SparkEnv.get.shuffleFetcher.fetch[K, V](dep.shuffleId, split.index)
+ }
+}
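A hypothetical usage of the class above, repartitioning a pair RDD so that all values for a key land in the same partition (sc and spark.HashPartitioner are assumed, as elsewhere in this tree):

    import spark.HashPartitioner

    // Illustrative only.
    val pairs = sc.parallelize(Seq(("a", 1), ("b", 2), ("a", 3)))
    val repartitioned = new ShuffledRDD(pairs, new HashPartitioner(2))
    // repartitioned has 2 partitions, and each key's pairs are co-located.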
diff --git a/core/src/main/scala/spark/UnionRDD.scala b/core/src/main/scala/spark/rdd/UnionRDD.scala
index 0e8164d6ab..f0b9225f7c 100644
--- a/core/src/main/scala/spark/UnionRDD.scala
+++ b/core/src/main/scala/spark/rdd/UnionRDD.scala
@@ -1,8 +1,14 @@
-package spark
+package spark.rdd
import scala.collection.mutable.ArrayBuffer
-class UnionSplit[T: ClassManifest](
+import spark.Dependency
+import spark.RangeDependency
+import spark.RDD
+import spark.SparkContext
+import spark.Split
+
+private[spark] class UnionSplit[T: ClassManifest](
idx: Int,
rdd: RDD[T],
split: Split)
@@ -37,7 +43,7 @@ class UnionRDD[T: ClassManifest](
override val dependencies = {
val deps = new ArrayBuffer[Dependency[_]]
var pos = 0
- for ((rdd, index) <- rdds.zipWithIndex) {
+ for (rdd <- rdds) {
deps += new RangeDependency(rdd, 0, pos, rdd.splits.size)
pos += rdd.splits.size
}
diff --git a/core/src/main/scala/spark/scheduler/ActiveJob.scala b/core/src/main/scala/spark/scheduler/ActiveJob.scala
index 0ecff9ce77..5a4e9a582d 100644
--- a/core/src/main/scala/spark/scheduler/ActiveJob.scala
+++ b/core/src/main/scala/spark/scheduler/ActiveJob.scala
@@ -5,11 +5,12 @@ import spark.TaskContext
/**
* Tracks information about an active job in the DAGScheduler.
*/
-class ActiveJob(
+private[spark] class ActiveJob(
val runId: Int,
val finalStage: Stage,
val func: (TaskContext, Iterator[_]) => _,
val partitions: Array[Int],
+ val callSite: String,
val listener: JobListener) {
val numPartitions = partitions.length
diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
index f7472971b5..aaaed59c4a 100644
--- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
@@ -21,6 +21,7 @@ import spark.storage.BlockManagerId
* schedule to run the job. Subclasses only need to implement the code to send a task to the cluster
* and to report fetch failures (the submitTasks method, and code to add CompletionEvents).
*/
+private[spark]
class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with Logging {
taskSched.setListener(this)
@@ -38,6 +39,11 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
eventQueue.put(HostLost(host))
}
+ // Called by TaskScheduler to cancel an entire TaskSet due to repeated failures.
+ override def taskSetFailed(taskSet: TaskSet, reason: String) {
+ eventQueue.put(TaskSetFailed(taskSet, reason))
+ }
+
// The time, in millis, to wait for fetch failure events to stop coming in after one is detected;
// this is a simplistic way to avoid resubmitting tasks in the non-fetchable map stage one by one
// as more failure events come in
@@ -98,7 +104,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
* The priority value passed in will be used if the stage doesn't already exist with
* a lower priority (we assume that priorities always increase across jobs for now).
*/
- def getShuffleMapStage(shuffleDep: ShuffleDependency[_,_,_], priority: Int): Stage = {
+ def getShuffleMapStage(shuffleDep: ShuffleDependency[_,_], priority: Int): Stage = {
shuffleToMapStage.get(shuffleDep.shuffleId) match {
case Some(stage) => stage
case None =>
@@ -113,10 +119,10 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
* as a result stage for the final RDD used directly in an action. The stage will also be given
* the provided priority.
*/
- def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_,_]], priority: Int): Stage = {
+ def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_]], priority: Int): Stage = {
// Kind of ugly: need to register RDDs with the cache and map output tracker here
// since we can't do it in the RDD constructor because # of splits is unknown
- logInfo("Registering RDD " + rdd.id + ": " + rdd)
+ logInfo("Registering RDD " + rdd.id + " (" + rdd.origin + ")")
cacheTracker.registerRDD(rdd.id, rdd.splits.size)
if (shuffleDep != None) {
mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.splits.size)
@@ -139,11 +145,11 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
visited += r
// Kind of ugly: need to register RDDs with the cache here since
// we can't do it in its constructor because # of splits is unknown
- logInfo("Registering parent RDD " + r.id + ": " + r)
+ logInfo("Registering parent RDD " + r.id + " (" + r.origin + ")")
cacheTracker.registerRDD(r.id, r.splits.size)
for (dep <- r.dependencies) {
dep match {
- case shufDep: ShuffleDependency[_,_,_] =>
+ case shufDep: ShuffleDependency[_,_] =>
parents += getShuffleMapStage(shufDep, priority)
case _ =>
visit(dep.rdd)
@@ -166,7 +172,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
if (locs(p) == Nil) {
for (dep <- rdd.dependencies) {
dep match {
- case shufDep: ShuffleDependency[_,_,_] =>
+ case shufDep: ShuffleDependency[_,_] =>
val mapStage = getShuffleMapStage(shufDep, stage.priority)
if (!mapStage.isAvailable) {
missing += mapStage
@@ -183,23 +189,25 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
missing.toList
}
- def runJob[T, U](
+ def runJob[T, U: ClassManifest](
finalRdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
partitions: Seq[Int],
+ callSite: String,
allowLocal: Boolean)
- (implicit m: ClassManifest[U]): Array[U] =
+ : Array[U] =
{
if (partitions.size == 0) {
return new Array[U](0)
}
val waiter = new JobWaiter(partitions.size)
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
- eventQueue.put(JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, waiter))
+ eventQueue.put(JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter))
waiter.getResult() match {
case JobSucceeded(results: Seq[_]) =>
return results.asInstanceOf[Seq[U]].toArray
case JobFailed(exception: Exception) =>
+ logInfo("Failed to run " + callSite)
throw exception
}
}
@@ -208,13 +216,14 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
evaluator: ApproximateEvaluator[U, R],
- timeout: Long
- ): PartialResult[R] =
+ callSite: String,
+ timeout: Long)
+ : PartialResult[R] =
{
val listener = new ApproximateActionListener(rdd, func, evaluator, timeout)
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
val partitions = (0 until rdd.splits.size).toArray
- eventQueue.put(JobSubmitted(rdd, func2, partitions, false, listener))
+ eventQueue.put(JobSubmitted(rdd, func2, partitions, false, callSite, listener))
return listener.getResult() // Will throw an exception if the job fails
}
@@ -234,13 +243,14 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
}
event match {
- case JobSubmitted(finalRDD, func, partitions, allowLocal, listener) =>
+ case JobSubmitted(finalRDD, func, partitions, allowLocal, callSite, listener) =>
val runId = nextRunId.getAndIncrement()
val finalStage = newStage(finalRDD, None, runId)
- val job = new ActiveJob(runId, finalStage, func, partitions, listener)
+ val job = new ActiveJob(runId, finalStage, func, partitions, callSite, listener)
updateCacheLocs()
- logInfo("Got job " + job.runId + " with " + partitions.length + " output partitions")
- logInfo("Final stage: " + finalStage)
+ logInfo("Got job " + job.runId + " (" + callSite + ") with " + partitions.length +
+ " output partitions")
+ logInfo("Final stage: " + finalStage + " (" + finalStage.origin + ")")
logInfo("Parents of final stage: " + finalStage.parents)
logInfo("Missing parents: " + getMissingParentStages(finalStage))
if (allowLocal && finalStage.parents.size == 0 && partitions.length == 1) {
@@ -258,6 +268,9 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
case completion: CompletionEvent =>
handleTaskCompletion(completion)
+ case TaskSetFailed(taskSet, reason) =>
+ abortStage(idToStage(taskSet.stageId), reason)
+
case StopDAGScheduler =>
// Cancel any active jobs
for (job <- activeJobs) {
@@ -329,7 +342,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
val missing = getMissingParentStages(stage).sortBy(_.id)
logDebug("missing: " + missing)
if (missing == Nil) {
- logInfo("Submitting " + stage + ", which has no missing parents")
+ logInfo("Submitting " + stage + " (" + stage.origin + "), which has no missing parents")
submitMissingTasks(stage)
running += stage
} else {
@@ -409,14 +422,14 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
case smt: ShuffleMapTask =>
val stage = idToStage(smt.stageId)
- val bmAddress = event.result.asInstanceOf[BlockManagerId]
- val host = bmAddress.ip
+ val status = event.result.asInstanceOf[MapStatus]
+ val host = status.address.ip
logInfo("ShuffleMapTask finished with host " + host)
if (!deadHosts.contains(host)) { // TODO: Make sure hostnames are consistent with Mesos
- stage.addOutputLoc(smt.partition, bmAddress)
+ stage.addOutputLoc(smt.partition, status)
}
if (running.contains(stage) && pendingTasks(stage).isEmpty) {
- logInfo(stage + " finished; looking for newly runnable stages")
+ logInfo(stage + " (" + stage.origin + ") finished; looking for newly runnable stages")
running -= stage
logInfo("running: " + running)
logInfo("waiting: " + waiting)
@@ -430,7 +443,8 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
if (stage.outputLocs.count(_ == Nil) != 0) {
// Some tasks had failed; let's resubmit this stage
// TODO: Lower-level scheduler should also deal with this
- logInfo("Resubmitting " + stage + " because some of its tasks had failed: " +
+ logInfo("Resubmitting " + stage + " (" + stage.origin +
+ ") because some of its tasks had failed: " +
stage.outputLocs.zipWithIndex.filter(_._1 == Nil).map(_._2).mkString(", "))
submitStage(stage)
} else {
@@ -444,6 +458,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
waiting --= newlyRunnable
running ++= newlyRunnable
for (stage <- newlyRunnable.sortBy(_.id)) {
+ logInfo("Submitting " + stage + " (" + stage.origin + "), which is now runnable")
submitMissingTasks(stage)
}
}
@@ -460,12 +475,14 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
running -= failedStage
failed += failedStage
// TODO: Cancel running tasks in the stage
- logInfo("Marking " + failedStage + " for resubmision due to a fetch failure")
+ logInfo("Marking " + failedStage + " (" + failedStage.origin +
+ ") for resubmision due to a fetch failure")
// Mark the map whose fetch failed as broken in the map stage
val mapStage = shuffleToMapStage(shuffleId)
mapStage.removeOutputLoc(mapId, bmAddress)
mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
- logInfo("The failed fetch was from " + mapStage + "; marking it for resubmission")
+ logInfo("The failed fetch was from " + mapStage + " (" + mapStage.origin +
+ "); marking it for resubmission")
failed += mapStage
// Remember that a fetch failed now; this is used to resubmit the broken
// stages later, after a small wait (to give other tasks the chance to fail)
@@ -475,18 +492,9 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
handleHostLost(bmAddress.ip)
}
- case _ =>
- // Non-fetch failure -- probably a bug in the job, so bail out
- // TODO: Cancel all tasks that are still running
- resultStageToJob.get(stage) match {
- case Some(job) =>
- val error = new SparkException("Task failed: " + task + ", reason: " + event.reason)
- job.listener.jobFailed(error)
- activeJobs -= job
- resultStageToJob -= stage
- case None =>
- logInfo("Ignoring result from " + task + " because its job has finished")
- }
+ case other =>
+ // Non-fetch failure -- probably a bug in user code; abort all jobs depending on this stage
+ abortStage(idToStage(task.stageId), task + " failed: " + other)
}
}
@@ -509,6 +517,53 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
updateCacheLocs()
}
}
+
+ /**
+ * Aborts all jobs depending on a particular Stage. This is called in response to a task set
+ * being cancelled by the TaskScheduler. Use taskSetFailed() to inject this event from outside.
+ */
+ def abortStage(failedStage: Stage, reason: String) {
+ val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq
+ for (resultStage <- dependentStages) {
+ val job = resultStageToJob(resultStage)
+ job.listener.jobFailed(new SparkException("Job failed: " + reason))
+ activeJobs -= job
+ resultStageToJob -= resultStage
+ }
+ if (dependentStages.isEmpty) {
+ logInfo("Ignoring failure of " + failedStage + " because all jobs depending on it are done")
+ }
+ }
+
+ /**
+ * Return true if one of stage's ancestors is target.
+ */
+ def stageDependsOn(stage: Stage, target: Stage): Boolean = {
+ if (stage == target) {
+ return true
+ }
+ val visitedRdds = new HashSet[RDD[_]]
+ val visitedStages = new HashSet[Stage]
+ def visit(rdd: RDD[_]) {
+ if (!visitedRdds(rdd)) {
+ visitedRdds += rdd
+ for (dep <- rdd.dependencies) {
+ dep match {
+ case shufDep: ShuffleDependency[_,_] =>
+ val mapStage = getShuffleMapStage(shufDep, stage.priority)
+ if (!mapStage.isAvailable) {
+ visitedStages += mapStage
+ visit(mapStage.rdd)
+ } // Otherwise there's no need to follow the dependency back
+ case narrowDep: NarrowDependency[_] =>
+ visit(narrowDep.rdd)
+ }
+ }
+ }
+ }
+ visit(stage.rdd)
+ visitedRdds.contains(target.rdd)
+ }
def getPreferredLocs(rdd: RDD[_], partition: Int): List[String] = {
// If the partition is cached, return the cache locations
diff --git a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala
index 0fc73059c3..3422a21d9d 100644
--- a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala
+++ b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala
@@ -10,23 +10,26 @@ import spark._
* submitted) but there is a single "logic" thread that reads these events and takes decisions.
* This greatly simplifies synchronization.
*/
-sealed trait DAGSchedulerEvent
+private[spark] sealed trait DAGSchedulerEvent
-case class JobSubmitted(
+private[spark] case class JobSubmitted(
finalRDD: RDD[_],
func: (TaskContext, Iterator[_]) => _,
partitions: Array[Int],
allowLocal: Boolean,
+ callSite: String,
listener: JobListener)
extends DAGSchedulerEvent
-case class CompletionEvent(
+private[spark] case class CompletionEvent(
task: Task[_],
reason: TaskEndReason,
result: Any,
accumUpdates: Map[Long, Any])
extends DAGSchedulerEvent
-case class HostLost(host: String) extends DAGSchedulerEvent
+private[spark] case class HostLost(host: String) extends DAGSchedulerEvent
-case object StopDAGScheduler extends DAGSchedulerEvent
+private[spark] case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent
+
+private[spark] case object StopDAGScheduler extends DAGSchedulerEvent
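The threading model described in the comment above (many submitting threads, one decision-making thread) is the classic single-consumer event loop; a generic sketch, not the DAGScheduler's actual loop:

    import java.util.concurrent.LinkedBlockingQueue

    // Any thread may post(); only run()'s loop acts on events, so the handler
    // needs no additional synchronization.
    class EventLoop[E](handle: E => Unit) {
      private val queue = new LinkedBlockingQueue[E]
      def post(event: E) { queue.put(event) }
      def run() {
        while (true) {
          handle(queue.take())   // blocks until an event is available
        }
      }
    }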
diff --git a/core/src/main/scala/spark/scheduler/JobListener.scala b/core/src/main/scala/spark/scheduler/JobListener.scala
index d4dd536a7d..f46b9d551d 100644
--- a/core/src/main/scala/spark/scheduler/JobListener.scala
+++ b/core/src/main/scala/spark/scheduler/JobListener.scala
@@ -5,7 +5,7 @@ package spark.scheduler
* DAGScheduler. The listener is notified each time a task succeeds, as well as if the whole
* job fails (and no further taskSucceeded events will happen).
*/
-trait JobListener {
+private[spark] trait JobListener {
def taskSucceeded(index: Int, result: Any)
def jobFailed(exception: Exception)
}
diff --git a/core/src/main/scala/spark/scheduler/JobResult.scala b/core/src/main/scala/spark/scheduler/JobResult.scala
index 62b458eccb..c4a74e526f 100644
--- a/core/src/main/scala/spark/scheduler/JobResult.scala
+++ b/core/src/main/scala/spark/scheduler/JobResult.scala
@@ -3,7 +3,7 @@ package spark.scheduler
/**
* A result of a job in the DAGScheduler.
*/
-sealed trait JobResult
+private[spark] sealed trait JobResult
-case class JobSucceeded(results: Seq[_]) extends JobResult
-case class JobFailed(exception: Exception) extends JobResult
+private[spark] case class JobSucceeded(results: Seq[_]) extends JobResult
+private[spark] case class JobFailed(exception: Exception) extends JobResult
diff --git a/core/src/main/scala/spark/scheduler/JobWaiter.scala b/core/src/main/scala/spark/scheduler/JobWaiter.scala
index 4c2ae23051..b3d4feebe5 100644
--- a/core/src/main/scala/spark/scheduler/JobWaiter.scala
+++ b/core/src/main/scala/spark/scheduler/JobWaiter.scala
@@ -5,7 +5,7 @@ import scala.collection.mutable.ArrayBuffer
/**
* An object that waits for a DAGScheduler job to complete.
*/
-class JobWaiter(totalTasks: Int) extends JobListener {
+private[spark] class JobWaiter(totalTasks: Int) extends JobListener {
private val taskResults = ArrayBuffer.fill[Any](totalTasks)(null)
private var finishedTasks = 0
diff --git a/core/src/main/scala/spark/scheduler/MapStatus.scala b/core/src/main/scala/spark/scheduler/MapStatus.scala
new file mode 100644
index 0000000000..4532d9497f
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/MapStatus.scala
@@ -0,0 +1,27 @@
+package spark.scheduler
+
+import spark.storage.BlockManagerId
+import java.io.{ObjectOutput, ObjectInput, Externalizable}
+
+/**
+ * Result returned by a ShuffleMapTask to a scheduler. Includes the block manager address that the
+ * task ran on as well as the sizes of outputs for each reducer, for passing on to the reduce tasks.
+ * The map output sizes are compressed using MapOutputTracker.compressSize.
+ */
+private[spark] class MapStatus(var address: BlockManagerId, var compressedSizes: Array[Byte])
+ extends Externalizable {
+
+ def this() = this(null, null) // For deserialization only
+
+ def writeExternal(out: ObjectOutput) {
+ address.writeExternal(out)
+ out.writeInt(compressedSizes.length)
+ out.write(compressedSizes)
+ }
+
+ def readExternal(in: ObjectInput) {
+ address = new BlockManagerId(in)
+ compressedSizes = new Array[Byte](in.readInt())
+ in.readFully(compressedSizes)
+ }
+}
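MapOutputTracker.compressSize is defined elsewhere in the tree; the point of storing one byte per reducer is that a log-scale encoding covers a huge size range with modest relative error. A purely hypothetical sketch of such an encoding (not the project's actual implementation):

    // Hypothetical byte-sized, log-scale size codec; assumes roughly 10% relative
    // error per step is acceptable for scheduling decisions.
    object SizeCodec {
      private val LOG_BASE = 1.1

      def compress(size: Long): Byte =
        if (size <= 0L) 0
        else math.min(255, math.max(1, math.ceil(math.log(size) / math.log(LOG_BASE)).toInt)).toByte

      def decompress(compressed: Byte): Long =
        if (compressed == 0) 0L
        else math.pow(LOG_BASE, compressed & 0xFF).toLong
    }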
diff --git a/core/src/main/scala/spark/scheduler/ResultTask.scala b/core/src/main/scala/spark/scheduler/ResultTask.scala
index 090ced9d76..2ebd4075a2 100644
--- a/core/src/main/scala/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/spark/scheduler/ResultTask.scala
@@ -2,7 +2,7 @@ package spark.scheduler
import spark._
-class ResultTask[T, U](
+private[spark] class ResultTask[T, U](
stageId: Int,
rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
index f1eae9bc88..60105c42b6 100644
--- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
@@ -1,10 +1,10 @@
package spark.scheduler
import java.io._
-import java.util.HashMap
+import java.util.{HashMap => JHashMap}
import java.util.zip.{GZIPInputStream, GZIPOutputStream}
-import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.collection.JavaConversions._
import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
@@ -15,11 +15,14 @@ import com.ning.compress.lzf.LZFOutputStream
import spark._
import spark.storage._
-object ShuffleMapTask {
- val serializedInfoCache = new HashMap[Int, Array[Byte]]
- val deserializedInfoCache = new HashMap[Int, (RDD[_], ShuffleDependency[_,_,_])]
+private[spark] object ShuffleMapTask {
- def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_,_]): Array[Byte] = {
+ // A simple map from a stage id to the serialized byte array of its task.
+ // It serves as a cache for task serialization, because serialization can be
+ // expensive on the master node if it needs to launch thousands of tasks.
+ val serializedInfoCache = new JHashMap[Int, Array[Byte]]
+
+ def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_]): Array[Byte] = {
synchronized {
val old = serializedInfoCache.get(stageId)
if (old != null) {
@@ -38,40 +41,40 @@ object ShuffleMapTask {
}
}
- def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], ShuffleDependency[_,_,_]) = {
+ def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], ShuffleDependency[_,_]) = {
synchronized {
- val old = deserializedInfoCache.get(stageId)
- if (old != null) {
- return old
- } else {
- val loader = Thread.currentThread.getContextClassLoader
- val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
- val ser = SparkEnv.get.closureSerializer.newInstance
- val objIn = ser.deserializeStream(in)
- val rdd = objIn.readObject().asInstanceOf[RDD[_]]
- val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_,_,_]]
- val tuple = (rdd, dep)
- deserializedInfoCache.put(stageId, tuple)
- return tuple
- }
+ val loader = Thread.currentThread.getContextClassLoader
+ val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
+ val ser = SparkEnv.get.closureSerializer.newInstance
+ val objIn = ser.deserializeStream(in)
+ val rdd = objIn.readObject().asInstanceOf[RDD[_]]
+ val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_,_]]
+ return (rdd, dep)
}
}
+ // Since both the JarSet and FileSet have the same format, this is used for both.
+ def deserializeFileSet(bytes: Array[Byte]) : HashMap[String, Long] = {
+ val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
+ val objIn = new ObjectInputStream(in)
+ val set = objIn.readObject().asInstanceOf[Array[(String, Long)]].toMap
+ return (HashMap(set.toSeq: _*))
+ }
+
def clearCache() {
synchronized {
serializedInfoCache.clear()
- deserializedInfoCache.clear()
}
}
}
-class ShuffleMapTask(
+private[spark] class ShuffleMapTask(
stageId: Int,
var rdd: RDD[_],
- var dep: ShuffleDependency[_,_,_],
+ var dep: ShuffleDependency[_,_],
var partition: Int,
@transient var locs: Seq[String])
- extends Task[BlockManagerId](stageId)
+ extends Task[MapStatus](stageId)
with Externalizable
with Logging {
@@ -106,32 +109,31 @@ class ShuffleMapTask(
split = in.readObject().asInstanceOf[Split]
}
- override def run(attemptId: Long): BlockManagerId = {
+ override def run(attemptId: Long): MapStatus = {
val numOutputSplits = dep.partitioner.numPartitions
- val aggregator = dep.aggregator.asInstanceOf[Aggregator[Any, Any, Any]]
val partitioner = dep.partitioner
- val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[Any, Any])
+
+ // Partition the map output.
+ val buckets = Array.fill(numOutputSplits)(new ArrayBuffer[(Any, Any)])
for (elem <- rdd.iterator(split)) {
- val (k, v) = elem.asInstanceOf[(Any, Any)]
- var bucketId = partitioner.getPartition(k)
- val bucket = buckets(bucketId)
- var existing = bucket.get(k)
- if (existing == null) {
- bucket.put(k, aggregator.createCombiner(v))
- } else {
- bucket.put(k, aggregator.mergeValue(existing, v))
- }
+ val pair = elem.asInstanceOf[(Any, Any)]
+ val bucketId = partitioner.getPartition(pair._1)
+ buckets(bucketId) += pair
}
- val ser = SparkEnv.get.serializer.newInstance()
+ val bucketIterators = buckets.map(_.iterator)
+
+ val compressedSizes = new Array[Byte](numOutputSplits)
+
val blockManager = SparkEnv.get.blockManager
for (i <- 0 until numOutputSplits) {
- val blockId = "shuffleid_" + dep.shuffleId + "_" + partition + "_" + i
- // Get a scala iterator from java map
- val iter: Iterator[(Any, Any)] = buckets(i).iterator
- // TODO: This should probably be DISK_ONLY
- blockManager.put(blockId, iter, StorageLevel.MEMORY_ONLY, false)
+ val blockId = "shuffle_" + dep.shuffleId + "_" + partition + "_" + i
+ // Get an iterator over the pairs accumulated in this bucket
+ val iter: Iterator[(Any, Any)] = bucketIterators(i)
+ val size = blockManager.put(blockId, iter, StorageLevel.DISK_ONLY, false)
+ compressedSizes(i) = MapOutputTracker.compressSize(size)
}
- return SparkEnv.get.blockManager.blockManagerId
+
+ return new MapStatus(blockManager.blockManagerId, compressedSizes)
}
override def preferredLocations: Seq[String] = locs
diff --git a/core/src/main/scala/spark/scheduler/Stage.scala b/core/src/main/scala/spark/scheduler/Stage.scala
index cd660c9085..4846b66729 100644
--- a/core/src/main/scala/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/spark/scheduler/Stage.scala
@@ -19,39 +19,39 @@ import spark.storage.BlockManagerId
* Each Stage also has a priority, which is (by default) based on the job it was submitted in.
* This allows Stages from earlier jobs to be computed first or recovered faster on failure.
*/
-class Stage(
+private[spark] class Stage(
val id: Int,
val rdd: RDD[_],
- val shuffleDep: Option[ShuffleDependency[_,_,_]], // Output shuffle if stage is a map stage
+ val shuffleDep: Option[ShuffleDependency[_,_]], // Output shuffle if stage is a map stage
val parents: List[Stage],
val priority: Int)
extends Logging {
val isShuffleMap = shuffleDep != None
val numPartitions = rdd.splits.size
- val outputLocs = Array.fill[List[BlockManagerId]](numPartitions)(Nil)
+ val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil)
var numAvailableOutputs = 0
private var nextAttemptId = 0
def isAvailable: Boolean = {
- if (/*parents.size == 0 &&*/ !isShuffleMap) {
+ if (!isShuffleMap) {
true
} else {
numAvailableOutputs == numPartitions
}
}
- def addOutputLoc(partition: Int, bmAddress: BlockManagerId) {
+ def addOutputLoc(partition: Int, status: MapStatus) {
val prevList = outputLocs(partition)
- outputLocs(partition) = bmAddress :: prevList
+ outputLocs(partition) = status :: prevList
if (prevList == Nil)
numAvailableOutputs += 1
}
def removeOutputLoc(partition: Int, bmAddress: BlockManagerId) {
val prevList = outputLocs(partition)
- val newList = prevList.filterNot(_ == bmAddress)
+ val newList = prevList.filterNot(_.address == bmAddress)
outputLocs(partition) = newList
if (prevList != Nil && newList == Nil) {
numAvailableOutputs -= 1
@@ -62,7 +62,7 @@ class Stage(
var becameUnavailable = false
for (partition <- 0 until numPartitions) {
val prevList = outputLocs(partition)
- val newList = prevList.filterNot(_.ip == host)
+ val newList = prevList.filterNot(_.address.ip == host)
outputLocs(partition) = newList
if (prevList != Nil && newList == Nil) {
becameUnavailable = true
@@ -80,6 +80,8 @@ class Stage(
return id
}
+ def origin: String = rdd.origin
+
override def toString = "Stage " + id // + ": [RDD = " + rdd.id + ", isShuffle = " + isShuffleMap + "]"
override def hashCode(): Int = id
diff --git a/core/src/main/scala/spark/scheduler/Task.scala b/core/src/main/scala/spark/scheduler/Task.scala
index f84d8d9c4f..ef987fdeb6 100644
--- a/core/src/main/scala/spark/scheduler/Task.scala
+++ b/core/src/main/scala/spark/scheduler/Task.scala
@@ -1,11 +1,95 @@
package spark.scheduler
+import scala.collection.mutable.HashMap
+import spark.serializer.{SerializerInstance, Serializer}
+import java.io.{DataInputStream, DataOutputStream}
+import java.nio.ByteBuffer
+import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
+import spark.util.ByteBufferInputStream
+import scala.collection.mutable.HashMap
+
/**
* A task to execute on a worker node.
*/
-abstract class Task[T](val stageId: Int) extends Serializable {
+private[spark] abstract class Task[T](val stageId: Int) extends Serializable {
def run(attemptId: Long): T
def preferredLocations: Seq[String] = Nil
var generation: Long = -1 // Map output tracker generation. Will be set by TaskScheduler.
}
+
+/**
+ * Handles transmission of tasks and their dependencies, because this can be slightly tricky. We
+ * need to send the list of JARs and files added to the SparkContext with each task to ensure that
+ * worker nodes find out about it, but we can't make it part of the Task because the user's code in
+ * the task might depend on one of the JARs. Thus we serialize each task as multiple objects, by
+ * first writing out its dependencies.
+ */
+private[spark] object Task {
+ /**
+ * Serialize a task and the current app dependencies (files and JARs added to the SparkContext)
+ */
+ def serializeWithDependencies(
+ task: Task[_],
+ currentFiles: HashMap[String, Long],
+ currentJars: HashMap[String, Long],
+ serializer: SerializerInstance)
+ : ByteBuffer = {
+
+ val out = new FastByteArrayOutputStream(4096)
+ val dataOut = new DataOutputStream(out)
+
+ // Write currentFiles
+ dataOut.writeInt(currentFiles.size)
+ for ((name, timestamp) <- currentFiles) {
+ dataOut.writeUTF(name)
+ dataOut.writeLong(timestamp)
+ }
+
+ // Write currentJars
+ dataOut.writeInt(currentJars.size)
+ for ((name, timestamp) <- currentJars) {
+ dataOut.writeUTF(name)
+ dataOut.writeLong(timestamp)
+ }
+
+ // Write the task itself and finish
+ dataOut.flush()
+ val taskBytes = serializer.serialize(task).array()
+ out.write(taskBytes)
+ out.trim()
+ ByteBuffer.wrap(out.array)
+ }
+
+ /**
+ * Deserialize the list of dependencies in a task serialized with serializeWithDependencies,
+ * and return the task itself as a serialized ByteBuffer. The caller can then update its
+ * ClassLoaders and deserialize the task.
+ *
+ * @return (taskFiles, taskJars, taskBytes)
+ */
+ def deserializeWithDependencies(serializedTask: ByteBuffer)
+ : (HashMap[String, Long], HashMap[String, Long], ByteBuffer) = {
+
+ val in = new ByteBufferInputStream(serializedTask)
+ val dataIn = new DataInputStream(in)
+
+ // Read task's files
+ val taskFiles = new HashMap[String, Long]()
+ val numFiles = dataIn.readInt()
+ for (i <- 0 until numFiles) {
+ taskFiles(dataIn.readUTF()) = dataIn.readLong()
+ }
+
+ // Read task's JARs
+ val taskJars = new HashMap[String, Long]()
+ val numJars = dataIn.readInt()
+ for (i <- 0 until numJars) {
+ taskJars(dataIn.readUTF()) = dataIn.readLong()
+ }
+
+ // Create a sub-buffer for the rest of the data, which is the serialized Task object
+ val subBuffer = serializedTask.slice() // ByteBufferInputStream will have read just up to the task data
+ (taskFiles, taskJars, subBuffer)
+ }
+}
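A hypothetical round trip through the two helpers above (someTask stands in for any Task[_]; SparkEnv is assumed to be initialized, as in the rest of the scheduler):

    import scala.collection.mutable.HashMap
    import spark.SparkEnv

    val files = HashMap("data.txt" -> 1234L)   // name -> timestamp
    val jars  = HashMap("app.jar"  -> 5678L)
    val ser = SparkEnv.get.closureSerializer.newInstance

    val buffer = Task.serializeWithDependencies(someTask, files, jars, ser)
    val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(buffer)
    // taskFiles and taskJars carry the dependency metadata; taskBytes holds the
    // serialized Task itself, ready to deserialize once class loaders are updated.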
diff --git a/core/src/main/scala/spark/scheduler/TaskResult.scala b/core/src/main/scala/spark/scheduler/TaskResult.scala
index 868ddb237c..9a54d0e854 100644
--- a/core/src/main/scala/spark/scheduler/TaskResult.scala
+++ b/core/src/main/scala/spark/scheduler/TaskResult.scala
@@ -7,6 +7,7 @@ import scala.collection.mutable.Map
// Task result. Also contains updates to accumulator variables.
// TODO: Use of distributed cache to return result is a hack to get around
// what seems to be a bug with messages over 60KB in libprocess; fix it
+private[spark]
class TaskResult[T](var value: T, var accumUpdates: Map[Long, Any]) extends Externalizable {
def this() = this(null.asInstanceOf[T], null)
diff --git a/core/src/main/scala/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/spark/scheduler/TaskScheduler.scala
index c35633d53c..d549b184b0 100644
--- a/core/src/main/scala/spark/scheduler/TaskScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/TaskScheduler.scala
@@ -7,7 +7,7 @@ package spark.scheduler
* are failures, and mitigating stragglers. They return events to the DAGScheduler through
* the TaskSchedulerListener interface.
*/
-trait TaskScheduler {
+private[spark] trait TaskScheduler {
def start(): Unit
// Disconnect from the cluster.
diff --git a/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala b/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala
index a647eec9e4..fa4de15d0d 100644
--- a/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala
+++ b/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala
@@ -7,10 +7,13 @@ import spark.TaskEndReason
/**
* Interface for getting events back from the TaskScheduler.
*/
-trait TaskSchedulerListener {
+private[spark] trait TaskSchedulerListener {
// A task has finished or failed.
def taskEnded(task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any]): Unit
// A node was lost from the cluster.
def hostLost(host: String): Unit
+
+ // The TaskScheduler wants to abort an entire task set.
+ def taskSetFailed(taskSet: TaskSet, reason: String): Unit
}
diff --git a/core/src/main/scala/spark/scheduler/TaskSet.scala b/core/src/main/scala/spark/scheduler/TaskSet.scala
index 6f29dd2e9d..a3002ca477 100644
--- a/core/src/main/scala/spark/scheduler/TaskSet.scala
+++ b/core/src/main/scala/spark/scheduler/TaskSet.scala
@@ -4,6 +4,8 @@ package spark.scheduler
* A set of tasks submitted together to the low-level TaskScheduler, usually representing
* missing partitions of a particular stage.
*/
-class TaskSet(val tasks: Array[Task[_]], val stageId: Int, val attempt: Int, val priority: Int) {
+private[spark] class TaskSet(val tasks: Array[Task[_]], val stageId: Int, val attempt: Int, val priority: Int) {
val id: String = stageId + "." + attempt
+
+ override def toString: String = "TaskSet " + id
}
diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
index 20c82ad0fa..f5e852d203 100644
--- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
@@ -16,7 +16,7 @@ import java.util.concurrent.atomic.AtomicLong
* The main TaskScheduler implementation, for running tasks on a cluster. Clients should first call
* start(), then submit task sets through the runTasks method.
*/
-class ClusterScheduler(sc: SparkContext)
+private[spark] class ClusterScheduler(val sc: SparkContext)
extends TaskScheduler
with Logging {
@@ -60,7 +60,6 @@ class ClusterScheduler(sc: SparkContext)
def initialize(context: SchedulerBackend) {
backend = context
- createJarServer()
}
def newTaskId(): Long = nextTaskId.getAndIncrement()
@@ -236,32 +235,7 @@ class ClusterScheduler(sc: SparkContext)
}
override def defaultParallelism() = backend.defaultParallelism()
-
- // Create a server for all the JARs added by the user to SparkContext.
- // We first copy the JARs to a temp directory for easier server setup.
- private def createJarServer() {
- val jarDir = Utils.createTempDir()
- logInfo("Temp directory for JARs: " + jarDir)
- val filenames = ArrayBuffer[String]()
- // Copy each JAR to a unique filename in the jarDir
- for ((path, index) <- sc.jars.zipWithIndex) {
- val file = new File(path)
- if (file.exists) {
- val filename = index + "_" + file.getName
- Utils.copyFile(file, new File(jarDir, filename))
- filenames += filename
- }
- }
- // Create the server
- jarServer = new HttpServer(jarDir)
- jarServer.start()
- // Build up the jar URI list
- val serverUri = jarServer.uri
- jarUris = filenames.map(f => serverUri + "/" + f).mkString(",")
- System.setProperty("spark.jar.uris", jarUris)
- logInfo("JAR server started at " + serverUri)
- }
-
+
// Check for speculatable tasks in all our active jobs.
def checkSpeculatableTasks() {
var shouldRevive = false
diff --git a/core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala
index 897976c3f9..ddcd64d7c6 100644
--- a/core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala
@@ -5,7 +5,7 @@ package spark.scheduler.cluster
* ClusterScheduler. We assume a Mesos-like model where the application gets resource offers as
* machines become available and can launch tasks on them.
*/
-trait SchedulerBackend {
+private[spark] trait SchedulerBackend {
def start(): Unit
def stop(): Unit
def reviveOffers(): Unit
diff --git a/core/src/main/scala/spark/scheduler/cluster/SlaveResources.scala b/core/src/main/scala/spark/scheduler/cluster/SlaveResources.scala
index e15d577a8b..96ebaa4601 100644
--- a/core/src/main/scala/spark/scheduler/cluster/SlaveResources.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/SlaveResources.scala
@@ -1,3 +1,4 @@
package spark.scheduler.cluster
+private[spark]
class SlaveResources(val slaveId: String, val hostname: String, val coresFree: Int) {}
diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
index 0bd2d15479..7aba7324ab 100644
--- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
@@ -5,7 +5,7 @@ import spark.deploy.client.{Client, ClientListener}
import spark.deploy.{Command, JobDescription}
import scala.collection.mutable.HashMap
-class SparkDeploySchedulerBackend(
+private[spark] class SparkDeploySchedulerBackend(
scheduler: ClusterScheduler,
sc: SparkContext,
master: String,
@@ -16,17 +16,10 @@ class SparkDeploySchedulerBackend(
var client: Client = null
var stopping = false
+ var shutdownCallback : (SparkDeploySchedulerBackend) => Unit = _
val maxCores = System.getProperty("spark.cores.max", Int.MaxValue.toString).toInt
- // Environment variables to pass to our executors
- val ENV_VARS_TO_SEND_TO_EXECUTORS = Array(
- "SPARK_MEM",
- "SPARK_CLASSPATH",
- "SPARK_LIBRARY_PATH",
- "SPARK_JAVA_OPTS"
- )
-
// Memory used by each executor (in megabytes)
val executorMemory = {
if (System.getenv("SPARK_MEM") != null) {
@@ -40,17 +33,11 @@ class SparkDeploySchedulerBackend(
override def start() {
super.start()
- val environment = new HashMap[String, String]
- for (key <- ENV_VARS_TO_SEND_TO_EXECUTORS) {
- if (System.getenv(key) != null) {
- environment(key) = System.getenv(key)
- }
- }
val masterUrl = "akka://spark@%s:%s/user/%s".format(
System.getProperty("spark.master.host"), System.getProperty("spark.master.port"),
StandaloneSchedulerBackend.ACTOR_NAME)
val args = Seq(masterUrl, "{{SLAVEID}}", "{{HOSTNAME}}", "{{CORES}}")
- val command = Command("spark.executor.StandaloneExecutorBackend", args, environment)
+ val command = Command("spark.executor.StandaloneExecutorBackend", args, sc.executorEnvs)
val jobDesc = new JobDescription(jobName, maxCores, executorMemory, command)
client = new Client(sc.env.actorSystem, master, jobDesc, this)
@@ -61,6 +48,9 @@ class SparkDeploySchedulerBackend(
stopping = true;
super.stop()
client.stop()
+ if (shutdownCallback != null) {
+ shutdownCallback(this)
+ }
}
def connected(jobId: String) {
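The new shutdownCallback field above gives whoever constructs the backend a hook that runs exactly once when stop() completes. A minimal standalone sketch of that pattern (class and method names here are illustrative, not from the patch):

class StoppableService {
  // Callback registered by the owner of this service; null until one is set.
  var shutdownCallback: (StoppableService) => Unit = _

  def stop() {
    // ... release this service's own resources first ...
    if (shutdownCallback != null) {
      shutdownCallback(this)
    }
  }
}

// Usage: the creator wires in cleanup before handing the service out.
// val service = new StoppableService
// service.shutdownCallback = (s: StoppableService) => println("stopped " + s)
// service.stop()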
diff --git a/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala b/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala
index 80e8733671..1386cd9d44 100644
--- a/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala
@@ -4,19 +4,27 @@ import spark.TaskState.TaskState
import java.nio.ByteBuffer
import spark.util.SerializableBuffer
-sealed trait StandaloneClusterMessage extends Serializable
+private[spark] sealed trait StandaloneClusterMessage extends Serializable
// Master to slaves
+private[spark]
case class LaunchTask(task: TaskDescription) extends StandaloneClusterMessage
+
+private[spark]
case class RegisteredSlave(sparkProperties: Seq[(String, String)]) extends StandaloneClusterMessage
+
+private[spark]
case class RegisterSlaveFailed(message: String) extends StandaloneClusterMessage
// Slaves to master
+private[spark]
case class RegisterSlave(slaveId: String, host: String, cores: Int) extends StandaloneClusterMessage
+private[spark]
case class StatusUpdate(slaveId: String, taskId: Long, state: TaskState, data: SerializableBuffer)
extends StandaloneClusterMessage
+private[spark]
object StatusUpdate {
/** Alternate factory method that takes a ByteBuffer directly for the data field */
def apply(slaveId: String, taskId: Long, state: TaskState, data: ByteBuffer): StatusUpdate = {
@@ -25,5 +33,5 @@ object StatusUpdate {
}
// Internal messages in master
-case object ReviveOffers extends StandaloneClusterMessage
-case object StopMaster extends StandaloneClusterMessage
+private[spark] case object ReviveOffers extends StandaloneClusterMessage
+private[spark] case object StopMaster extends StandaloneClusterMessage
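Most of this patch is the repeated private[spark] annotation; as a reminder of what that buys, a tiny standalone sketch (illustrative names) showing that such classes stay fully usable anywhere under the spark package while disappearing from the public API:

package spark.scheduler.cluster

// Illustrative only: hidden from user code, but visible to everything in the spark package.
private[spark] sealed trait ExampleMessage extends Serializable
private[spark] case class ExampleLaunch(taskId: Long) extends ExampleMessage
private[spark] case object ExampleStop extends ExampleMessage

private[spark] object ExampleHandler {
  def handle(msg: ExampleMessage): String = msg match {
    case ExampleLaunch(id) => "launching task " + id
    case ExampleStop       => "stopping"
  }
}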
diff --git a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala
index 978b4f2676..d2cce0dc05 100644
--- a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala
@@ -16,6 +16,7 @@ import akka.remote.{RemoteClientShutdown, RemoteClientDisconnected, RemoteClient
* Akka. These may be executed in a variety of ways, such as Mesos tasks for the coarse-grained
* Mesos mode or standalone processes for Spark's standalone deploy mode (spark.deploy.*).
*/
+private[spark]
class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: ActorSystem)
extends SchedulerBackend with Logging {
@@ -99,7 +100,7 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
// Remove a disconnected slave from the cluster
def removeSlave(slaveId: String) {
- logWarning("Slave " + slaveId + " disconnected, so removing it")
+ logInfo("Slave " + slaveId + " disconnected, so removing it")
val numCores = freeCores(slaveId)
actorToSlaveId -= slaveActor(slaveId)
addressToSlaveId -= slaveAddress(slaveId)
@@ -149,6 +150,6 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
def defaultParallelism(): Int = math.max(totalCoreCount.get(), 2)
}
-object StandaloneSchedulerBackend {
+private[spark] object StandaloneSchedulerBackend {
val ACTOR_NAME = "StandaloneScheduler"
}
diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala b/core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala
index f9a1b74fa5..aa097fd3a2 100644
--- a/core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala
@@ -3,7 +3,7 @@ package spark.scheduler.cluster
import java.nio.ByteBuffer
import spark.util.SerializableBuffer
-class TaskDescription(
+private[spark] class TaskDescription(
val taskId: Long,
val slaveId: String,
val name: String,
diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala b/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala
index 65e59841a9..ca84503780 100644
--- a/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala
@@ -3,6 +3,7 @@ package spark.scheduler.cluster
/**
* Information about a running task attempt inside a TaskSet.
*/
+private[spark]
class TaskInfo(val taskId: Long, val index: Int, val launchTime: Long, val host: String) {
var finishTime: Long = 0
var failed = false
diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
index 5a7df6040c..cf4aae03a7 100644
--- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
@@ -17,7 +17,7 @@ import java.nio.ByteBuffer
/**
* Schedules the tasks within a single TaskSet in the ClusterScheduler.
*/
-class TaskSetManager(
+private[spark] class TaskSetManager(
sched: ClusterScheduler,
val taskSet: TaskSet)
extends Logging {
@@ -214,7 +214,8 @@ class TaskSetManager(
}
// Serialize and return the task
val startTime = System.currentTimeMillis
- val serializedTask = ser.serialize(task)
+ val serializedTask = Task.serializeWithDependencies(
+ task, sched.sc.addedFiles, sched.sc.addedJars, ser)
val timeTaken = System.currentTimeMillis - startTime
logInfo("Serialized task %s:%d as %d bytes in %d ms".format(
taskSet.id, index, serializedTask.limit, timeTaken))
@@ -243,6 +244,11 @@ class TaskSetManager(
def taskFinished(tid: Long, state: TaskState, serializedData: ByteBuffer) {
val info = taskInfos(tid)
+ if (info.failed) {
+ // We might get two task-lost messages for the same task in coarse-grained Mesos mode,
+ // or even from Mesos itself when acks get delayed.
+ return
+ }
val index = info.index
info.markSuccessful()
if (!finished(index)) {
@@ -335,13 +341,14 @@ class TaskSetManager(
def error(message: String) {
// Save the error message
- abort("Mesos error: " + message)
+ abort("Error: " + message)
}
def abort(message: String) {
failed = true
causeOfFailure = message
// TODO: Kill running tasks if we were not terminated due to a Mesos error
+ sched.listener.taskSetFailed(taskSet, message)
sched.taskSetFinished(this)
}
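Task.serializeWithDependencies itself is not part of this hunk. As a rough, self-contained sketch of the idea it implements, bundling the added-files and added-JARs timestamp maps alongside the serialized task so executors can fetch whatever they are missing, the names and wire format below are assumptions, not Spark's actual implementation:

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
import java.nio.ByteBuffer
import scala.collection.mutable.HashMap

// Hypothetical sketch of "serialize a task together with its dependency maps".
object DependencySerializationSketch {
  private def writeMap(out: DataOutputStream, m: HashMap[String, Long]) {
    out.writeInt(m.size)
    for ((name, timestamp) <- m) {
      out.writeUTF(name)
      out.writeLong(timestamp)
    }
  }

  private def readMap(in: DataInputStream): HashMap[String, Long] = {
    val m = new HashMap[String, Long]
    for (i <- 0 until in.readInt()) {
      m(in.readUTF()) = in.readLong()
    }
    m
  }

  def serializeWithDependencies(
      taskBytes: Array[Byte],
      files: HashMap[String, Long],
      jars: HashMap[String, Long]): ByteBuffer = {
    val bos = new ByteArrayOutputStream()
    val out = new DataOutputStream(bos)
    writeMap(out, files)          // file name -> timestamp
    writeMap(out, jars)           // jar name -> timestamp
    out.writeInt(taskBytes.length)
    out.write(taskBytes)          // the task payload itself
    out.flush()
    ByteBuffer.wrap(bos.toByteArray)
  }

  def deserializeWithDependencies(buf: ByteBuffer)
      : (HashMap[String, Long], HashMap[String, Long], Array[Byte]) = {
    val in = new DataInputStream(new ByteArrayInputStream(
      buf.array, buf.arrayOffset + buf.position, buf.remaining))
    val files = readMap(in)
    val jars = readMap(in)
    val taskBytes = new Array[Byte](in.readInt())
    in.readFully(taskBytes)
    (files, jars, taskBytes)
  }
}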
diff --git a/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala b/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala
index 1e83f103e7..6b919d68b2 100644
--- a/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala
@@ -3,5 +3,6 @@ package spark.scheduler.cluster
/**
* Represents free resources available on a worker node.
*/
+private[spark]
class WorkerOffer(val slaveId: String, val hostname: String, val cores: Int) {
}
diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
index eb47988f0c..eb20fe41b2 100644
--- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
@@ -1,9 +1,13 @@
package spark.scheduler.local
+import java.io.File
+import java.net.URLClassLoader
import java.util.concurrent.Executors
import java.util.concurrent.atomic.AtomicInteger
+import scala.collection.mutable.HashMap
import spark._
+import executor.ExecutorURLClassLoader
import spark.scheduler._
/**
@@ -11,17 +15,27 @@ import spark.scheduler._
* the scheduler also allows each task to fail up to maxFailures times, which is useful for
* testing fault recovery.
*/
-class LocalScheduler(threads: Int, maxFailures: Int) extends TaskScheduler with Logging {
+private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkContext)
+ extends TaskScheduler
+ with Logging {
+
var attemptId = new AtomicInteger(0)
var threadPool = Executors.newFixedThreadPool(threads, DaemonThreadFactory)
val env = SparkEnv.get
var listener: TaskSchedulerListener = null
+ // Application dependencies (added through SparkContext) that we've fetched so far on this node.
+ // Each map holds the master's timestamp for the version of that file or JAR we got.
+ val currentFiles: HashMap[String, Long] = new HashMap[String, Long]()
+ val currentJars: HashMap[String, Long] = new HashMap[String, Long]()
+
+ val classLoader = new ExecutorURLClassLoader(Array(), Thread.currentThread.getContextClassLoader)
+
// TODO: Need to take into account stage priority in scheduling
- override def start() {}
+ override def start() { }
- override def setListener(listener: TaskSchedulerListener) {
+ override def setListener(listener: TaskSchedulerListener) {
this.listener = listener
}
@@ -43,21 +57,29 @@ class LocalScheduler(threads: Int, maxFailures: Int) extends TaskScheduler with
// Set the Spark execution environment for the worker thread
SparkEnv.set(env)
try {
+ Accumulators.clear()
+ Thread.currentThread().setContextClassLoader(classLoader)
+
// Serialize and deserialize the task so that accumulators are changed to thread-local ones;
// this adds a bit of unnecessary overhead but matches how the Mesos Executor works.
- Accumulators.clear
val ser = SparkEnv.get.closureSerializer.newInstance()
- val bytes = ser.serialize(task)
+ val bytes = Task.serializeWithDependencies(task, sc.addedFiles, sc.addedJars, ser)
logInfo("Size of task " + idInJob + " is " + bytes.limit + " bytes")
+ val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(bytes)
+ updateDependencies(taskFiles, taskJars) // Download any files added with addFile
val deserializedTask = ser.deserialize[Task[_]](
- bytes, Thread.currentThread.getContextClassLoader)
+ taskBytes, Thread.currentThread.getContextClassLoader)
+
+ // Run it
val result: Any = deserializedTask.run(attemptId)
+
// Serialize and deserialize the result to emulate what the Mesos
// executor does. This is useful to catch serialization errors early
// on in development (so when users move their local Spark programs
// to the cluster, they don't get surprised by serialization errors).
val resultToReturn = ser.deserialize[Any](ser.serialize(result))
- val accumUpdates = Accumulators.values
+ val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]](
+ ser.serialize(Accumulators.values))
logInfo("Finished task " + idInJob)
listener.taskEnded(task, Success, resultToReturn, accumUpdates)
} catch {
@@ -80,7 +102,32 @@ class LocalScheduler(threads: Int, maxFailures: Int) extends TaskScheduler with
submitTask(task, i)
}
}
-
+
+ /**
+ * Download any missing dependencies if we receive a new set of files and JARs from the
+ * SparkContext. Also adds any new JARs we fetched to the class loader.
+ */
+ private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) {
+ // Fetch missing dependencies
+ for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
+ logInfo("Fetching " + name + " with timestamp " + timestamp)
+ Utils.fetchFile(name, new File("."))
+ currentFiles(name) = timestamp
+ }
+ for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
+ logInfo("Fetching " + name + " with timestamp " + timestamp)
+ Utils.fetchFile(name, new File("."))
+ currentJars(name) = timestamp
+ // Add it to our class loader
+ val localName = name.split("/").last
+ val url = new File(".", localName).toURI.toURL
+ if (!classLoader.getURLs.contains(url)) {
+ logInfo("Adding " + url + " to class loader")
+ classLoader.addURL(url)
+ }
+ }
+ }
+
override def stop() {
threadPool.shutdownNow()
}
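ExecutorURLClassLoader's definition is not shown in this diff; the sketch below assumes it is essentially a URLClassLoader with a public addURL, which is all updateDependencies needs. An illustrative, self-contained version of the "track timestamps, add newer JARs to the loader" bookkeeping:

import java.io.File
import java.net.{URL, URLClassLoader}
import scala.collection.mutable.HashMap

// Assumed stand-in for ExecutorURLClassLoader: a URLClassLoader exposing addURL.
class MutableURLClassLoader(urls: Array[URL], parent: ClassLoader)
  extends URLClassLoader(urls, parent) {
  override def addURL(url: URL) { super.addURL(url) }
}

class DependencyTrackerSketch(loader: MutableURLClassLoader) {
  val currentJars = new HashMap[String, Long]   // jar name -> timestamp last fetched

  /** Record a newly fetched JAR and expose it to the class loader if it is newer. */
  def addJar(name: String, timestamp: Long, localDir: File) {
    if (currentJars.getOrElse(name, -1L) < timestamp) {
      currentJars(name) = timestamp
      val localName = name.split("/").last
      val url = new File(localDir, localName).toURI.toURL
      if (!loader.getURLs.contains(url)) {
        loader.addURL(url)
      }
    }
  }
}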
diff --git a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala
index fdf007ffb2..c45c7df69c 100644
--- a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala
+++ b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala
@@ -24,7 +24,7 @@ import spark.TaskState
* Unfortunately this has a bit of duplication from MesosSchedulerBackend, but it seems hard to
* remove this.
*/
-class CoarseMesosSchedulerBackend(
+private[spark] class CoarseMesosSchedulerBackend(
scheduler: ClusterScheduler,
sc: SparkContext,
master: String,
@@ -33,14 +33,6 @@ class CoarseMesosSchedulerBackend(
with MScheduler
with Logging {
- // Environment variables to pass to our executors
- val ENV_VARS_TO_SEND_TO_EXECUTORS = Array(
- "SPARK_MEM",
- "SPARK_CLASSPATH",
- "SPARK_LIBRARY_PATH",
- "SPARK_JAVA_OPTS"
- )
-
val MAX_SLAVE_FAILURES = 2 // Blacklist a slave after this many failures
// Memory used by each executor (in megabytes)
@@ -122,13 +114,11 @@ class CoarseMesosSchedulerBackend(
val command = "\"%s\" spark.executor.StandaloneExecutorBackend %s %s %s %d".format(
runScript, masterUrl, offer.getSlaveId.getValue, offer.getHostname, numCores)
val environment = Environment.newBuilder()
- for (key <- ENV_VARS_TO_SEND_TO_EXECUTORS) {
- if (System.getenv(key) != null) {
- environment.addVariables(Environment.Variable.newBuilder()
- .setName(key)
- .setValue(System.getenv(key))
- .build())
- }
+ sc.executorEnvs.foreach { case (key, value) =>
+ environment.addVariables(Environment.Variable.newBuilder()
+ .setName(key)
+ .setValue(value)
+ .build())
}
return CommandInfo.newBuilder().setValue(command).setEnvironment(environment).build()
}
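sc.executorEnvs is populated elsewhere (not in this diff). Judging by the allowlist deleted from the backends here, it presumably ends up holding roughly the following; treat this as an assumption about SparkContext, not code from the patch:

import scala.collection.mutable.HashMap

object ExecutorEnvSketch {
  // Hypothetical reconstruction: collect the same environment variables once,
  // centrally, instead of in every scheduler backend.
  val executorEnvs = new HashMap[String, String]
  for (key <- Seq("SPARK_MEM", "SPARK_CLASSPATH", "SPARK_LIBRARY_PATH", "SPARK_JAVA_OPTS")) {
    val value = System.getenv(key)
    if (value != null) {
      executorEnvs(key) = value
    }
  }
}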
diff --git a/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala
index 44eda93dd1..cdfe1f2563 100644
--- a/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala
+++ b/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala
@@ -20,7 +20,7 @@ import spark.TaskState
* separate Mesos task, allowing multiple applications to share cluster nodes both in space (tasks
* from multiple apps can run on different cores) and in time (a core can switch ownership).
*/
-class MesosSchedulerBackend(
+private[spark] class MesosSchedulerBackend(
scheduler: ClusterScheduler,
sc: SparkContext,
master: String,
@@ -29,14 +29,6 @@ class MesosSchedulerBackend(
with MScheduler
with Logging {
- // Environment variables to pass to our executors
- val ENV_VARS_TO_SEND_TO_EXECUTORS = Array(
- "SPARK_MEM",
- "SPARK_CLASSPATH",
- "SPARK_LIBRARY_PATH",
- "SPARK_JAVA_OPTS"
- )
-
// Memory used by each executor (in megabytes)
val EXECUTOR_MEMORY = {
if (System.getenv("SPARK_MEM") != null) {
@@ -93,13 +85,11 @@ class MesosSchedulerBackend(
}
val execScript = new File(sparkHome, "spark-executor").getCanonicalPath
val environment = Environment.newBuilder()
- for (key <- ENV_VARS_TO_SEND_TO_EXECUTORS) {
- if (System.getenv(key) != null) {
- environment.addVariables(Environment.Variable.newBuilder()
- .setName(key)
- .setValue(System.getenv(key))
- .build())
- }
+ sc.executorEnvs.foreach { case (key, value) =>
+ environment.addVariables(Environment.Variable.newBuilder()
+ .setName(key)
+ .setValue(value)
+ .build())
}
val memory = Resource.newBuilder()
.setName("mem")
diff --git a/core/src/main/scala/spark/Serializer.scala b/core/src/main/scala/spark/serializer/Serializer.scala
index 61a70beaf1..50b086125a 100644
--- a/core/src/main/scala/spark/Serializer.scala
+++ b/core/src/main/scala/spark/serializer/Serializer.scala
@@ -1,23 +1,21 @@
-package spark
+package spark.serializer
-import java.io.{EOFException, InputStream, OutputStream}
import java.nio.ByteBuffer
-import java.nio.channels.Channels
-
+import java.io.{EOFException, InputStream, OutputStream}
import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
-
import spark.util.ByteBufferInputStream
/**
- * A serializer. Because some serialization libraries are not thread safe, this class is used to
- * create SerializerInstances that do the actual serialization.
+ * A serializer. Because some serialization libraries are not thread safe, this class is used to
+ * create [[spark.serializer.SerializerInstance]] objects that do the actual serialization and are
+ * guaranteed to only be called from one thread at a time.
*/
trait Serializer {
def newInstance(): SerializerInstance
}
/**
- * An instance of the serializer, for use by one thread at a time.
+ * An instance of a serializer, for use by one thread at a time.
*/
trait SerializerInstance {
def serialize[T](t: T): ByteBuffer
@@ -43,7 +41,7 @@ trait SerializerInstance {
def deserializeMany(buffer: ByteBuffer): Iterator[Any] = {
// Default implementation uses deserializeStream
buffer.rewind()
- deserializeStream(new ByteBufferInputStream(buffer)).toIterator
+ deserializeStream(new ByteBufferInputStream(buffer)).asIterator
}
}
@@ -51,7 +49,7 @@ trait SerializerInstance {
* A stream for writing serialized objects.
*/
trait SerializationStream {
- def writeObject[T](t: T): Unit
+ def writeObject[T](t: T): SerializationStream
def flush(): Unit
def close(): Unit
@@ -74,7 +72,7 @@ trait DeserializationStream {
* Read the elements of this stream through an iterator. This can only be called once, as
* reading each element will consume data from the input source.
*/
- def toIterator: Iterator[Any] = new Iterator[Any] {
+ def asIterator: Iterator[Any] = new Iterator[Any] {
var gotNext = false
var finished = false
var nextValue: Any = null
@@ -88,7 +86,7 @@ trait DeserializationStream {
}
gotNext = true
}
-
+
override def hasNext: Boolean = {
if (!gotNext) {
getNext()
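Two small interface changes above are easy to miss: writeObject now returns the stream (so writes can be chained) and the iterator view was renamed to asIterator. A minimal Java-serialization-backed sketch of both behaviors, with illustrative class names rather than Spark's actual implementations:

import java.io.{EOFException, InputStream, ObjectInputStream, ObjectOutputStream, OutputStream}

class JavaSerializationStreamSketch(out: OutputStream) {
  private val objOut = new ObjectOutputStream(out)
  // Returning the stream allows chained writes: s.writeObject(a).writeObject(b).close()
  def writeObject[T](t: T): JavaSerializationStreamSketch = { objOut.writeObject(t); this }
  def flush() { objOut.flush() }
  def close() { objOut.close() }
}

class JavaDeserializationStreamSketch(in: InputStream) {
  private val objIn = new ObjectInputStream(in)
  def readObject[T](): T = objIn.readObject().asInstanceOf[T]
  def close() { objIn.close() }

  /** Read until EOF, mirroring DeserializationStream.asIterator. */
  def asIterator: Iterator[Any] = new Iterator[Any] {
    private var nextValue: Option[Any] = fetch()
    private def fetch(): Option[Any] =
      try { Some(readObject[Any]()) } catch { case e: EOFException => close(); None }
    def hasNext = nextValue.isDefined
    def next(): Any = {
      val v = nextValue.get
      nextValue = fetch()
      v
    }
  }
}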
diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala
index 4cdb9710ec..bd9155ef29 100644
--- a/core/src/main/scala/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/spark/storage/BlockManager.scala
@@ -1,34 +1,29 @@
package spark.storage
-import java.io._
-import java.nio._
-import java.nio.channels.FileChannel.MapMode
-import java.util.{HashMap => JHashMap}
-import java.util.LinkedHashMap
-import java.util.concurrent.ConcurrentHashMap
-import java.util.concurrent.LinkedBlockingQueue
-import java.util.Collections
-
import akka.dispatch.{Await, Future}
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
-import scala.collection.JavaConversions._
+import akka.util.Duration
-import it.unimi.dsi.fastutil.io._
+import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
-import spark.CacheTracker
-import spark.Logging
-import spark.Serializer
-import spark.SizeEstimator
-import spark.SparkEnv
-import spark.SparkException
-import spark.Utils
-import spark.util.ByteBufferInputStream
+import java.io.{InputStream, OutputStream, Externalizable, ObjectInput, ObjectOutput}
+import java.nio.{MappedByteBuffer, ByteBuffer}
+import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue}
+
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue}
+import scala.collection.JavaConversions._
+
+import spark.{CacheTracker, Logging, SizeEstimator, SparkException, Utils}
import spark.network._
-import akka.util.Duration
+import spark.serializer.Serializer
+import spark.util.ByteBufferInputStream
+import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream}
+import sun.nio.ch.DirectBuffer
+
-class BlockManagerId(var ip: String, var port: Int) extends Externalizable {
- def this() = this(null, 0)
+private[spark] class BlockManagerId(var ip: String, var port: Int) extends Externalizable {
+ def this() = this(null, 0) // For deserialization only
+
+ def this(in: ObjectInput) = this(in.readUTF(), in.readInt())
override def writeExternal(out: ObjectOutput) {
out.writeUTF(ip)
@@ -51,41 +46,76 @@ class BlockManagerId(var ip: String, var port: Int) extends Externalizable {
}
-case class BlockException(blockId: String, message: String, ex: Exception = null) extends Exception(message)
+private[spark]
+case class BlockException(blockId: String, message: String, ex: Exception = null)
+extends Exception(message)
-class BlockLocker(numLockers: Int) {
+private[spark] class BlockLocker(numLockers: Int) {
private val hashLocker = Array.fill(numLockers)(new Object())
-
+
def getLock(blockId: String): Object = {
return hashLocker(math.abs(blockId.hashCode % numLockers))
}
}
+private[spark]
class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, maxMemory: Long)
extends Logging {
- case class BlockInfo(level: StorageLevel, tellMaster: Boolean)
+ class BlockInfo(val level: StorageLevel, val tellMaster: Boolean) {
+ var pending: Boolean = true
+ var size: Long = -1L
+
+ /** Wait for this BlockInfo to be marked as ready (i.e. block is finished writing) */
+ def waitForReady() {
+ if (pending) {
+ synchronized {
+ while (pending) this.wait()
+ }
+ }
+ }
+
+ /** Mark this BlockInfo as ready (i.e. block is finished writing) */
+ def markReady(sizeInBytes: Long) {
+ pending = false
+ size = sizeInBytes
+ synchronized {
+ this.notifyAll()
+ }
+ }
+ }
private val NUM_LOCKS = 337
private val locker = new BlockLocker(NUM_LOCKS)
private val blockInfo = new ConcurrentHashMap[String, BlockInfo]()
- private val memoryStore: BlockStore = new MemoryStore(this, maxMemory)
- private val diskStore: BlockStore = new DiskStore(this,
- System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")))
-
+
+ private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory)
+ private[storage] val diskStore: BlockStore =
+ new DiskStore(this, System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")))
+
val connectionManager = new ConnectionManager(0)
implicit val futureExecContext = connectionManager.futureExecContext
-
+
val connectionManagerId = connectionManager.id
val blockManagerId = new BlockManagerId(connectionManagerId.host, connectionManagerId.port)
-
+
// TODO: This will be removed after cacheTracker is removed from the code base.
var cacheTracker: CacheTracker = null
- initLogging()
+ // Max megabytes of data to keep in flight per reducer (to avoid over-allocating memory
+ // for receiving shuffle outputs)
+ val maxBytesInFlight =
+ System.getProperty("spark.reducer.maxMbInFlight", "48").toLong * 1024 * 1024
+
+ val compressBroadcast = System.getProperty("spark.broadcast.compress", "true").toBoolean
+ val compressShuffle = System.getProperty("spark.shuffle.compress", "true").toBoolean
+ // Whether to compress RDD partitions that are stored serialized
+ val compressRdds = System.getProperty("spark.rdd.compress", "false").toBoolean
+
+ val host = System.getProperty("spark.hostname", Utils.localHostName())
initialize()
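The BlockInfo ready/pending handshake above is a classic wait/notify latch: the writing thread publishes the block's size and wakes readers, and readers block until that happens. A standalone sketch of the same idea:

// One thread calls markReady(bytesWritten); any number of readers call
// waitForReady() and block until the value has been published.
class ReadyLatchSketch {
  private var pending = true
  private var size = -1L

  def waitForReady(): Long = synchronized {
    while (pending) {
      wait()
    }
    size
  }

  def markReady(sizeInBytes: Long): Unit = synchronized {
    size = sizeInBytes
    pending = false
    notifyAll()
  }
}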
@@ -102,7 +132,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m
*/
private def initialize() {
master.mustRegisterBlockManager(
- RegisterBlockManager(blockManagerId, maxMemory, maxMemory))
+ RegisterBlockManager(blockManagerId, maxMemory))
BlockManagerWorker.startBlockManagerWorker(this)
}
@@ -115,36 +145,32 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m
}
/**
- * Change storage level for a local block and tell master is necesary.
- * If new level is invalid, then block info (if it exists) will be silently removed.
+ * Tell the master about the current storage status of a block. This will send a heartbeat
+ * message reflecting the current status, *not* the desired storage level in its block info.
+ * For example, a block with MEMORY_AND_DISK set might have fallen out to be only on disk.
*/
- def setLevel(blockId: String, level: StorageLevel, tellMaster: Boolean = true) {
- if (level == null) {
- throw new IllegalArgumentException("Storage level is null")
- }
-
- // If there was earlier info about the block, then use earlier tellMaster
- val oldInfo = blockInfo.get(blockId)
- val newTellMaster = if (oldInfo != null) oldInfo.tellMaster else tellMaster
- if (oldInfo != null && oldInfo.tellMaster != tellMaster) {
- logWarning("Ignoring tellMaster setting as it is different from earlier setting")
- }
-
- // If level is valid, store the block info, else remove the block info
- if (level.isValid) {
- blockInfo.put(blockId, new BlockInfo(level, newTellMaster))
- logDebug("Info for block " + blockId + " updated with new level as " + level)
- } else {
- blockInfo.remove(blockId)
- logDebug("Info for block " + blockId + " removed as new level is null or invalid")
- }
-
- // Tell master if necessary
- if (newTellMaster) {
+ def reportBlockStatus(blockId: String) {
+ locker.getLock(blockId).synchronized {
+ val curLevel = blockInfo.get(blockId) match {
+ case null =>
+ StorageLevel.NONE
+ case info =>
+ info.level match {
+ case null =>
+ StorageLevel.NONE
+ case level =>
+ val inMem = level.useMemory && memoryStore.contains(blockId)
+ val onDisk = level.useDisk && diskStore.contains(blockId)
+ new StorageLevel(onDisk, inMem, level.deserialized, level.replication)
+ }
+ }
+ master.mustHeartBeat(HeartBeat(
+ blockManagerId,
+ blockId,
+ curLevel,
+ if (curLevel.useMemory) memoryStore.getSize(blockId) else 0L,
+ if (curLevel.useDisk) diskStore.getSize(blockId) else 0L))
logDebug("Told master about block " + blockId)
- notifyMaster(HeartBeat(blockManagerId, blockId, level, 0, 0))
- } else {
- logDebug("Did not tell master about block " + blockId)
}
}
@@ -174,55 +200,149 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m
* Get block from local block manager.
*/
def getLocal(blockId: String): Option[Iterator[Any]] = {
- if (blockId == null) {
- throw new IllegalArgumentException("Block Id is null")
- }
logDebug("Getting local block " + blockId)
+
+ // As an optimization for map output fetches, if the block is for a shuffle, return it
+ // without acquiring a lock; the disk store never deletes (recent) items so this should work
+ if (blockId.startsWith("shuffle_")) {
+ return diskStore.getValues(blockId) match {
+ case Some(iterator) =>
+ Some(iterator)
+ case None =>
+ throw new Exception("Block " + blockId + " not found on disk, though it should be")
+ }
+ }
+
locker.getLock(blockId).synchronized {
-
- // Check storage level of block
- val level = getLevel(blockId)
- if (level != null) {
- logDebug("Level for block " + blockId + " is " + level + " on local machine")
-
+ val info = blockInfo.get(blockId)
+ if (info != null) {
+ info.waitForReady() // In case the block is still being put() by another thread
+ val level = info.level
+ logDebug("Level for block " + blockId + " is " + level)
+
// Look for the block in memory
if (level.useMemory) {
logDebug("Getting block " + blockId + " from memory")
memoryStore.getValues(blockId) match {
- case Some(iterator) => {
- logDebug("Block " + blockId + " found in memory")
+ case Some(iterator) =>
return Some(iterator)
- }
- case None => {
+ case None =>
logDebug("Block " + blockId + " not found in memory")
- }
}
- } else {
- logDebug("Not getting block " + blockId + " from memory")
}
- // Look for block in disk
+ // Look for block on disk, potentially loading it back into memory if required
if (level.useDisk) {
logDebug("Getting block " + blockId + " from disk")
- diskStore.getValues(blockId) match {
- case Some(iterator) => {
- logDebug("Block " + blockId + " found in disk")
- return Some(iterator)
+ if (level.useMemory && level.deserialized) {
+ diskStore.getValues(blockId) match {
+ case Some(iterator) =>
+ // Put the block back in memory before returning it
+ // TODO: Consider creating a putValues that also takes in an iterator?
+ val elements = new ArrayBuffer[Any]
+ elements ++= iterator
+ memoryStore.putValues(blockId, elements, level, true).data match {
+ case Left(iterator2) =>
+ return Some(iterator2)
+ case _ =>
+ throw new Exception("Memory store did not return back an iterator")
+ }
+ case None =>
+ throw new Exception("Block " + blockId + " not found on disk, though it should be")
+ }
+ } else if (level.useMemory && !level.deserialized) {
+ // Read it as a byte buffer into memory first, then return it
+ diskStore.getBytes(blockId) match {
+ case Some(bytes) =>
+ // Put a copy of the block back in memory before returning it. Note that we can't
+ // put the ByteBuffer returned by the disk store as that's a memory-mapped file.
+ val copyForMemory = ByteBuffer.allocate(bytes.limit)
+ copyForMemory.put(bytes)
+ memoryStore.putBytes(blockId, copyForMemory, level)
+ bytes.rewind()
+ return Some(dataDeserialize(blockId, bytes))
+ case None =>
+ throw new Exception("Block " + blockId + " not found on disk, though it should be")
}
- case None => {
- throw new Exception("Block " + blockId + " not found in disk")
- return None
+ } else {
+ diskStore.getValues(blockId) match {
+ case Some(iterator) =>
+ return Some(iterator)
+ case None =>
+ throw new Exception("Block " + blockId + " not found on disk, though it should be")
}
}
- } else {
- logDebug("Not getting block " + blockId + " from disk")
}
+ } else {
+ logDebug("Block " + blockId + " not registered locally")
+ }
+ }
+ return None
+ }
+
+ /**
+ * Get block from the local block manager as serialized bytes.
+ */
+ def getLocalBytes(blockId: String): Option[ByteBuffer] = {
+ // TODO: This whole thing is very similar to getLocal; we need to refactor it somehow
+ logDebug("Getting local block " + blockId + " as bytes")
+
+ // As an optimization for map output fetches, if the block is for a shuffle, return it
+ // without acquiring a lock; the disk store never deletes (recent) items so this should work
+ if (blockId.startsWith("shuffle_")) {
+ return diskStore.getBytes(blockId) match {
+ case Some(bytes) =>
+ Some(bytes)
+ case None =>
+ throw new Exception("Block " + blockId + " not found on disk, though it should be")
+ }
+ }
+
+ locker.getLock(blockId).synchronized {
+ val info = blockInfo.get(blockId)
+ if (info != null) {
+ info.waitForReady() // In case the block is still being put() by another thread
+ val level = info.level
+ logDebug("Level for block " + blockId + " is " + level)
+ // Look for the block in memory
+ if (level.useMemory) {
+ logDebug("Getting block " + blockId + " from memory")
+ memoryStore.getBytes(blockId) match {
+ case Some(bytes) =>
+ return Some(bytes)
+ case None =>
+ logDebug("Block " + blockId + " not found in memory")
+ }
+ }
+
+ // Look for block on disk
+ if (level.useDisk) {
+ // Read it as a byte buffer into memory first, then return it
+ diskStore.getBytes(blockId) match {
+ case Some(bytes) =>
+ if (level.useMemory) {
+ if (level.deserialized) {
+ memoryStore.putBytes(blockId, bytes, level)
+ } else {
+ // The memory store will hang onto the ByteBuffer, so give it a copy instead of
+ // the memory-mapped file buffer we got from the disk store
+ val copyForMemory = ByteBuffer.allocate(bytes.limit)
+ copyForMemory.put(bytes)
+ memoryStore.putBytes(blockId, copyForMemory, level)
+ }
+ }
+ bytes.rewind()
+ return Some(bytes)
+ case None =>
+ throw new Exception("Block " + blockId + " not found on disk, though it should be")
+ }
+ }
} else {
- logDebug("Level for block " + blockId + " not found")
+ logDebug("Block " + blockId + " not registered locally")
}
- }
- return None
+ }
+ return None
}
/**
@@ -243,7 +363,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m
GetBlock(blockId), ConnectionManagerId(loc.ip, loc.port))
if (data != null) {
logDebug("Data is not null: " + data)
- return Some(dataDeserialize(data))
+ return Some(dataDeserialize(blockId, data))
}
logDebug("Data is null")
}
@@ -261,9 +381,10 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m
/**
* Get multiple blocks from local and remote block manager using their BlockManagerIds. Returns
* an Iterator of (block ID, value) pairs so that clients may handle blocks in a pipelined
- * fashion as they're received.
+ * fashion as they're received. Expects a size in bytes to be provided for each block fetched,
+ * so that we can keep the total data in flight under maxBytesInFlight.
*/
- def getMultiple(blocksByAddress: Seq[(BlockManagerId, Seq[String])])
+ def getMultiple(blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])])
: Iterator[(String, Option[Iterator[Any]])] = {
if (blocksByAddress == null) {
@@ -273,71 +394,127 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m
logDebug("Getting " + totalBlocks + " blocks")
var startTime = System.currentTimeMillis
val localBlockIds = new ArrayBuffer[String]()
- val remoteBlockIds = new ArrayBuffer[String]()
- val remoteBlockIdsPerLocation = new HashMap[BlockManagerId, Seq[String]]()
+ val remoteBlockIds = new HashSet[String]()
- // A queue to hold our results. Because we want all the deserializing the happen in the
- // caller's thread, this will actually hold functions to produce the Iterator for each block.
- // For local blocks we'll have an iterator already, while for remote ones we'll deserialize.
- val results = new LinkedBlockingQueue[(String, Option[() => Iterator[Any]])]
+ // A result of a fetch. Includes the block ID, size in bytes, and a function to deserialize
+ // the block (since we want all deserialization to happen in the calling thread); can also
+ // represent a fetch failure if size == -1.
+ class FetchResult(val blockId: String, val size: Long, val deserialize: () => Iterator[Any]) {
+ def failed: Boolean = size == -1
+ }
- // Split local and remote blocks
- for ((address, blockIds) <- blocksByAddress) {
- if (address == blockManagerId) {
- localBlockIds ++= blockIds
- } else {
- remoteBlockIds ++= blockIds
- remoteBlockIdsPerLocation(address) = blockIds
- }
+ // A queue to hold our results.
+ val results = new LinkedBlockingQueue[FetchResult]
+
+ // A request to fetch one or more blocks, complete with their sizes
+ class FetchRequest(val address: BlockManagerId, val blocks: Seq[(String, Long)]) {
+ val size = blocks.map(_._2).sum
}
-
- // Start getting remote blocks
- for ((bmId, bIds) <- remoteBlockIdsPerLocation) {
- val cmId = ConnectionManagerId(bmId.ip, bmId.port)
- val blockMessages = bIds.map(bId => BlockMessage.fromGetBlock(GetBlock(bId)))
- val blockMessageArray = new BlockMessageArray(blockMessages)
+
+ // Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
+ // the number of bytes in flight is limited to maxBytesInFlight
+ val fetchRequests = new Queue[FetchRequest]
+
+ // Current bytes in flight from our requests
+ var bytesInFlight = 0L
+
+ def sendRequest(req: FetchRequest) {
+ logDebug("Sending request for %d blocks (%s) from %s".format(
+ req.blocks.size, Utils.memoryBytesToString(req.size), req.address.ip))
+ val cmId = new ConnectionManagerId(req.address.ip, req.address.port)
+ val blockMessageArray = new BlockMessageArray(req.blocks.map {
+ case (blockId, size) => BlockMessage.fromGetBlock(GetBlock(blockId))
+ })
+ bytesInFlight += req.size
+ val sizeMap = req.blocks.toMap // so we can look up the size of each blockID
val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage)
future.onSuccess {
case Some(message) => {
val bufferMessage = message.asInstanceOf[BufferMessage]
val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage)
- blockMessageArray.foreach(blockMessage => {
+ for (blockMessage <- blockMessageArray) {
if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) {
throw new SparkException(
"Unexpected message " + blockMessage.getType + " received from " + cmId)
}
val blockId = blockMessage.getId
- results.put((blockId, Some(() => dataDeserialize(blockMessage.getData))))
+ results.put(new FetchResult(
+ blockId, sizeMap(blockId), () => dataDeserialize(blockId, blockMessage.getData)))
logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
- })
+ }
}
case None => {
- logError("Could not get blocks from " + cmId)
- for (blockId <- bIds) {
- results.put((blockId, None))
+ logError("Could not get block(s) from " + cmId)
+ for ((blockId, size) <- req.blocks) {
+ results.put(new FetchResult(blockId, -1, null))
+ }
+ }
+ }
+ }
+
+ // Split local and remote blocks. Remote blocks are further split into FetchRequests of size
+ // at most maxBytesInFlight in order to limit the amount of data in flight.
+ val remoteRequests = new ArrayBuffer[FetchRequest]
+ for ((address, blockInfos) <- blocksByAddress) {
+ if (address == blockManagerId) {
+ localBlockIds ++= blockInfos.map(_._1)
+ } else {
+ remoteBlockIds ++= blockInfos.map(_._1)
+ // Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them
+ // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
+ // nodes, rather than blocking on reading output from one node.
+ val minRequestSize = math.max(maxBytesInFlight / 5, 1L)
+ logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize)
+ val iterator = blockInfos.iterator
+ var curRequestSize = 0L
+ var curBlocks = new ArrayBuffer[(String, Long)]
+ while (iterator.hasNext) {
+ val (blockId, size) = iterator.next()
+ curBlocks += ((blockId, size))
+ curRequestSize += size
+ if (curRequestSize >= minRequestSize) {
+ // Add this FetchRequest
+ remoteRequests += new FetchRequest(address, curBlocks)
+ curRequestSize = 0
+ curBlocks = new ArrayBuffer[(String, Long)]
}
}
+ // Add in the final request
+ if (!curBlocks.isEmpty) {
+ remoteRequests += new FetchRequest(address, curBlocks)
+ }
}
}
- logDebug("Started remote gets for " + remoteBlockIds.size + " blocks in " +
- Utils.getUsedTimeMs(startTime) + " ms")
+ // Add the remote requests into our queue in a random order
+ fetchRequests ++= Utils.randomize(remoteRequests)
- // Get the local blocks while remote blocks are being fetched
+ // Send out initial requests for blocks, up to our maxBytesInFlight
+ while (!fetchRequests.isEmpty &&
+ (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
+ sendRequest(fetchRequests.dequeue())
+ }
+
+ val numGets = remoteBlockIds.size - fetchRequests.size
+ logInfo("Started " + numGets + " remote gets in " + Utils.getUsedTimeMs(startTime))
+
+ // Get the local blocks while remote blocks are being fetched. Note that it's okay to do
+ // these all at once because they will just memory-map some files, so they won't consume
+ // any memory that might exceed our maxBytesInFlight
startTime = System.currentTimeMillis
- localBlockIds.foreach(id => {
+ for (id <- localBlockIds) {
getLocal(id) match {
- case Some(block) => {
- results.put((id, Some(() => block)))
+ case Some(iter) => {
+ results.put(new FetchResult(id, 0, () => iter)) // Pass 0 as size since it's not in flight
logDebug("Got local block " + id)
}
case None => {
throw new BlockException(id, "Could not get block " + id + " from local machine")
}
}
- })
+ }
logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
- // Return an iterator that will read fetched blocks off the queue as they arrive
+ // Return an iterator that will read fetched blocks off the queue as they arrive.
return new Iterator[(String, Option[Iterator[Any]])] {
var resultsGotten = 0
@@ -345,16 +522,30 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m
def next(): (String, Option[Iterator[Any]]) = {
resultsGotten += 1
- val (blockId, functionOption) = results.take()
- (blockId, functionOption.map(_.apply()))
+ val result = results.take()
+ bytesInFlight -= result.size
+ if (!fetchRequests.isEmpty &&
+ (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
+ sendRequest(fetchRequests.dequeue())
+ }
+ (result.blockId, if (result.failed) None else Some(result.deserialize()))
}
}
}
+ def put(blockId: String, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean)
+ : Long = {
+ val elements = new ArrayBuffer[Any]
+ elements ++= values
+ put(blockId, elements, level, tellMaster)
+ }
+
/**
- * Put a new block of values to the block manager.
+ * Put a new block of values to the block manager. Returns its (estimated) size in bytes.
*/
- def put(blockId: String, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean = true) {
+ def put(blockId: String, values: ArrayBuffer[Any], level: StorageLevel,
+ tellMaster: Boolean = true) : Long = {
+
if (blockId == null) {
throw new IllegalArgumentException("Block Id is null")
}
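The bookkeeping in getMultiple above, splitting each remote host's blocks into requests of roughly maxBytesInFlight / 5 and only issuing new requests while the in-flight total stays under the cap, can be exercised on its own. A simplified, self-contained sketch with a println standing in for the network send:

import scala.collection.mutable.{ArrayBuffer, Queue}

object FetchThrottlingSketch {
  case class FetchRequest(blocks: Seq[(String, Long)]) {
    val size = blocks.map(_._2).sum
  }

  // Group (blockId, size) pairs into requests of at least maxBytesInFlight / 5 bytes,
  // so up to ~5 hosts can be fetched from in parallel without exceeding the cap.
  def splitIntoRequests(blocks: Seq[(String, Long)], maxBytesInFlight: Long): Seq[FetchRequest] = {
    val minRequestSize = math.max(maxBytesInFlight / 5, 1L)
    val requests = new ArrayBuffer[FetchRequest]
    var current = new ArrayBuffer[(String, Long)]
    var currentSize = 0L
    for ((id, size) <- blocks) {
      current += ((id, size))
      currentSize += size
      if (currentSize >= minRequestSize) {
        requests += FetchRequest(current)
        current = new ArrayBuffer[(String, Long)]
        currentSize = 0L
      }
    }
    if (current.nonEmpty) {
      requests += FetchRequest(current)
    }
    requests
  }

  def main(args: Array[String]) {
    val maxBytesInFlight = 48L * 1024 * 1024
    val queue = Queue(splitIntoRequests(
      (1 to 100).map(i => ("block_" + i, 5L * 1024 * 1024)), maxBytesInFlight): _*)
    var bytesInFlight = 0L
    // Issue requests only while the cap allows; the rest would be sent as results arrive.
    while (queue.nonEmpty &&
        (bytesInFlight == 0 || bytesInFlight + queue.front.size <= maxBytesInFlight)) {
      val req = queue.dequeue()
      bytesInFlight += req.size
      println("sending request for " + req.blocks.size + " blocks, " + req.size + " bytes")
    }
  }
}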
@@ -365,81 +556,97 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m
throw new IllegalArgumentException("Storage level is null or invalid")
}
- val startTimeMs = System.currentTimeMillis
- var bytes: ByteBuffer = null
+ val oldBlock = blockInfo.get(blockId)
+ if (oldBlock != null) {
+ logWarning("Block " + blockId + " already exists on this machine; not re-adding it")
+ oldBlock.waitForReady()
+ return oldBlock.size
+ }
+
+ // Remember the block's storage level so that we can correctly drop it to disk if it needs
+ // to be dropped right after it got put into memory. Note, however, that other threads will
+ // not be able to get() this block until we call markReady on its BlockInfo.
+ val myInfo = new BlockInfo(level, tellMaster)
+ blockInfo.put(blockId, myInfo)
+
+ val startTimeMs = System.currentTimeMillis
// If we need to replicate the data, we'll want access to the values, but because our
// put will read the whole iterator, there will be no values left. For the case where
- // the put serializes data, we'll remember the bytes, above; but for the case where
- // it doesn't, such as MEMORY_ONLY_DESER, let's rely on the put returning an Iterator.
+ // the put serializes data, we'll remember the bytes, above; but for the case where it
+ // doesn't, such as deserialized storage, let's rely on the put returning an Iterator.
var valuesAfterPut: Iterator[Any] = null
-
+
+ // Ditto for the bytes after the put
+ var bytesAfterPut: ByteBuffer = null
+
+ // Size of the block in bytes (to return to caller)
+ var size = 0L
+
locker.getLock(blockId).synchronized {
logDebug("Put for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs)
+ " to get into synchronized block")
-
- // Check and warn if block with same id already exists
- if (getLevel(blockId) != null) {
- logWarning("Block " + blockId + " already exists in local machine")
- return
- }
- if (level.useMemory && level.useDisk) {
- // If saving to both memory and disk, then serialize only once
- memoryStore.putValues(blockId, values, level) match {
- case Left(newValues) =>
- diskStore.putValues(blockId, newValues, level) match {
- case Right(newBytes) => bytes = newBytes
- case _ => throw new Exception("Unexpected return value")
- }
- case Right(newBytes) =>
- bytes = newBytes
- diskStore.putBytes(blockId, newBytes, level)
- }
- } else if (level.useMemory) {
- // If only save to memory
- memoryStore.putValues(blockId, values, level) match {
- case Right(newBytes) => bytes = newBytes
+ if (level.useMemory) {
+ // Save it just to memory first, even if it also has useDisk set to true; we will later
+ // drop it to disk if the memory store can't hold it.
+ val res = memoryStore.putValues(blockId, values, level, true)
+ size = res.size
+ res.data match {
+ case Right(newBytes) => bytesAfterPut = newBytes
case Left(newIterator) => valuesAfterPut = newIterator
}
} else {
- // If only save to disk
- diskStore.putValues(blockId, values, level) match {
- case Right(newBytes) => bytes = newBytes
- case _ => throw new Exception("Unexpected return value")
+ // Save directly to disk.
+ val askForBytes = level.replication > 1 // Don't get back the bytes unless we replicate them
+ val res = diskStore.putValues(blockId, values, level, askForBytes)
+ size = res.size
+ res.data match {
+ case Right(newBytes) => bytesAfterPut = newBytes
+ case _ =>
}
}
-
- // Store the storage level
- setLevel(blockId, level, tellMaster)
+
+ // Now that the block is in either the memory or disk store, let other threads read it,
+ // and tell the master about it.
+ myInfo.markReady(size)
+ if (tellMaster) {
+ reportBlockStatus(blockId)
+ }
}
logDebug("Put block " + blockId + " locally took " + Utils.getUsedTimeMs(startTimeMs))
- // Replicate block if required
+ // Replicate block if required
if (level.replication > 1) {
// Serialize the block if not already done
- if (bytes == null) {
+ if (bytesAfterPut == null) {
if (valuesAfterPut == null) {
throw new SparkException(
"Underlying put returned neither an Iterator nor bytes! This shouldn't happen.")
}
- bytes = dataSerialize(valuesAfterPut)
+ bytesAfterPut = dataSerialize(blockId, valuesAfterPut)
}
- replicate(blockId, bytes, level)
+ replicate(blockId, bytesAfterPut, level)
}
+ BlockManager.dispose(bytesAfterPut)
+
// TODO: This code will be removed when CacheTracker is gone.
if (blockId.startsWith("rdd")) {
- notifyTheCacheTracker(blockId)
+ notifyCacheTracker(blockId)
}
logDebug("Put block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs))
+
+ return size
}
/**
* Put a new block of serialized bytes to the block manager.
*/
- def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel, tellMaster: Boolean = true) {
+ def putBytes(
+ blockId: String, bytes: ByteBuffer, level: StorageLevel, tellMaster: Boolean = true) {
+
if (blockId == null) {
throw new IllegalArgumentException("Block Id is null")
}
@@ -449,10 +656,21 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m
if (level == null || !level.isValid) {
throw new IllegalArgumentException("Storage level is null or invalid")
}
-
- val startTimeMs = System.currentTimeMillis
-
- // Initiate the replication before storing it locally. This is faster as
+
+ if (blockInfo.containsKey(blockId)) {
+ logWarning("Block " + blockId + " already exists on this machine; not re-adding it")
+ return
+ }
+
+ // Remember the block's storage level so that we can correctly drop it to disk if it needs
+ // to be dropped right after it got put into memory. Note, however, that other threads will
+ // not be able to get() this block until we call markReady on its BlockInfo.
+ val myInfo = new BlockInfo(level, tellMaster)
+ blockInfo.put(blockId, myInfo)
+
+ val startTimeMs = System.currentTimeMillis
+
+ // Initiate the replication before storing it locally. This is faster as
// data is already serialized and ready for sending
val replicationFuture = if (level.replication > 1) {
val bufferView = bytes.duplicate() // Doesn't copy the bytes, just creates a wrapper
@@ -466,27 +684,29 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m
locker.getLock(blockId).synchronized {
logDebug("PutBytes for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs)
+ " to get into synchronized block")
- if (getLevel(blockId) != null) {
- logWarning("Block " + blockId + " already exists")
- return
- }
if (level.useMemory) {
+ // Store it only in memory at first, even if useDisk is also set to true
+ bytes.rewind()
memoryStore.putBytes(blockId, bytes, level)
- }
- if (level.useDisk) {
+ } else {
+ bytes.rewind()
diskStore.putBytes(blockId, bytes, level)
}
- // Store the storage level
- setLevel(blockId, level, tellMaster)
+ // Now that the block is in either the memory or disk store, let other threads read it,
+ // and tell the master about it.
+ myInfo.markReady(bytes.limit)
+ if (tellMaster) {
+ reportBlockStatus(blockId)
+ }
}
// TODO: This code will be removed when CacheTracker is gone.
if (blockId.startsWith("rdd")) {
- notifyTheCacheTracker(blockId)
+ notifyCacheTracker(blockId)
}
-
+
// If replication had started, then wait for it to finish
if (level.replication > 1) {
if (replicationFuture == null) {
@@ -495,12 +715,11 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m
Await.ready(replicationFuture, Duration.Inf)
}
- val finishTime = System.currentTimeMillis
if (level.replication > 1) {
- logDebug("PutBytes for block " + blockId + " with replication took " +
+ logDebug("PutBytes for block " + blockId + " with replication took " +
Utils.getUsedTimeMs(startTimeMs))
} else {
- logDebug("PutBytes for block " + blockId + " without replication took " +
+ logDebug("PutBytes for block " + blockId + " without replication took " +
Utils.getUsedTimeMs(startTimeMs))
}
}
@@ -508,17 +727,14 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m
/**
* Replicate block to another node.
*/
-
- var firstTime = true
- var peers : Seq[BlockManagerId] = null
+ var cachedPeers: Seq[BlockManagerId] = null
private def replicate(blockId: String, data: ByteBuffer, level: StorageLevel) {
val tLevel: StorageLevel =
new StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1)
- if (firstTime) {
- peers = master.mustGetPeers(GetPeers(blockManagerId, level.replication - 1))
- firstTime = false;
- }
- for (peer: BlockManagerId <- peers) {
+ if (cachedPeers == null) {
+ cachedPeers = master.mustGetPeers(GetPeers(blockManagerId, level.replication - 1))
+ }
+ for (peer: BlockManagerId <- cachedPeers) {
val start = System.nanoTime
data.rewind()
logDebug("Try to replicate BlockId " + blockId + " once; The size of the data is "
@@ -534,19 +750,20 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m
}
// TODO: This code will be removed when CacheTracker is gone.
- private def notifyTheCacheTracker(key: String) {
- val rddInfo = key.split(":")
- val rddId: Int = rddInfo(1).toInt
- val splitIndex: Int = rddInfo(2).toInt
- val host = System.getProperty("spark.hostname", Utils.localHostName)
- cacheTracker.notifyTheCacheTrackerFromBlockManager(spark.AddedToCache(rddId, splitIndex, host))
+ private def notifyCacheTracker(key: String) {
+ if (cacheTracker != null) {
+ val rddInfo = key.split("_")
+ val rddId: Int = rddInfo(1).toInt
+ val partition: Int = rddInfo(2).toInt
+ cacheTracker.notifyFromBlockManager(spark.AddedToCache(rddId, partition, host))
+ }
}
/**
* Read a block consisting of a single object.
*/
def getSingle(blockId: String): Option[Any] = {
- get(blockId).map(_.next)
+ get(blockId).map(_.next())
}
/**
@@ -557,42 +774,76 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m
}
/**
- * Drop block from memory (called when memory store has reached it limit)
+ * Drop a block from memory, possibly putting it on disk if applicable. Called when the memory
+ * store reaches its limit and needs to free up space.
*/
- def dropFromMemory(blockId: String) {
+ def dropFromMemory(blockId: String, data: Either[ArrayBuffer[Any], ByteBuffer]) {
+ logInfo("Dropping block " + blockId + " from memory")
locker.getLock(blockId).synchronized {
- val level = getLevel(blockId)
- if (level == null) {
- logWarning("Block " + blockId + " cannot be removed from memory as it does not exist")
- return
+ val info = blockInfo.get(blockId)
+ val level = info.level
+ if (level.useDisk && !diskStore.contains(blockId)) {
+ logInfo("Writing block " + blockId + " to disk")
+ data match {
+ case Left(elements) =>
+ diskStore.putValues(blockId, elements, level, false)
+ case Right(bytes) =>
+ diskStore.putBytes(blockId, bytes, level)
+ }
+ }
+ memoryStore.remove(blockId)
+ if (info.tellMaster) {
+ reportBlockStatus(blockId)
}
- if (!level.useMemory) {
- logWarning("Block " + blockId + " cannot be removed from memory as it is not in memory")
- return
+ if (!level.useDisk) {
+ // The block is completely gone from this node; forget it so we can put() it again later.
+ blockInfo.remove(blockId)
}
- memoryStore.remove(blockId)
- val newLevel = new StorageLevel(level.useDisk, false, level.deserialized, level.replication)
- setLevel(blockId, newLevel)
}
}
- def dataSerialize(values: Iterator[Any]): ByteBuffer = {
- /*serializer.newInstance().serializeMany(values)*/
+ def shouldCompress(blockId: String): Boolean = {
+ if (blockId.startsWith("shuffle_")) {
+ compressShuffle
+ } else if (blockId.startsWith("broadcast_")) {
+ compressBroadcast
+ } else if (blockId.startsWith("rdd_")) {
+ compressRdds
+ } else {
+ false // Won't happen in a real cluster, but it can in tests
+ }
+ }
+
+ /**
+ * Wrap an output stream for compression if block compression is enabled for its block type
+ */
+ def wrapForCompression(blockId: String, s: OutputStream): OutputStream = {
+ if (shouldCompress(blockId)) new LZFOutputStream(s) else s
+ }
+
+ /**
+ * Wrap an input stream for compression if block compression is enabled for its block type
+ */
+ def wrapForCompression(blockId: String, s: InputStream): InputStream = {
+ if (shouldCompress(blockId)) new LZFInputStream(s) else s
+ }
+
+ def dataSerialize(blockId: String, values: Iterator[Any]): ByteBuffer = {
val byteStream = new FastByteArrayOutputStream(4096)
- serializer.newInstance().serializeStream(byteStream).writeAll(values).close()
+ val ser = serializer.newInstance()
+ ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()
byteStream.trim()
ByteBuffer.wrap(byteStream.array)
}
- def dataDeserialize(bytes: ByteBuffer): Iterator[Any] = {
- /*serializer.newInstance().deserializeMany(bytes)*/
- val ser = serializer.newInstance()
+ /**
+ * Deserializes a ByteBuffer into an iterator of values and disposes of it when the end of
+ * the iterator is reached.
+ */
+ def dataDeserialize(blockId: String, bytes: ByteBuffer): Iterator[Any] = {
bytes.rewind()
- return ser.deserializeStream(new ByteBufferInputStream(bytes)).toIterator
- }
-
- private def notifyMaster(heartBeat: HeartBeat) {
- master.mustHeartBeat(heartBeat)
+ val stream = wrapForCompression(blockId, new ByteBufferInputStream(bytes, true))
+ serializer.newInstance().deserializeStream(stream).asIterator
}
def stop() {
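dataSerialize and dataDeserialize above layer the serializer's stream on top of an optional LZF stream. A self-contained round trip using the same com.ning.compress.lzf classes imported by this patch, with Java serialization standing in for Spark's pluggable Serializer and strings as the payload:

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream}

object CompressionRoundTripSketch {
  def serialize(values: Seq[String], compress: Boolean): Array[Byte] = {
    val byteStream = new ByteArrayOutputStream()
    val wrapped = if (compress) new LZFOutputStream(byteStream) else byteStream
    val out = new ObjectOutputStream(wrapped)
    values.foreach(v => out.writeObject(v))
    out.close()  // also finishes the LZF frame
    byteStream.toByteArray
  }

  def deserialize(bytes: Array[Byte], compress: Boolean, count: Int): Seq[String] = {
    val byteStream = new ByteArrayInputStream(bytes)
    val wrapped = if (compress) new LZFInputStream(byteStream) else byteStream
    val in = new ObjectInputStream(wrapped)
    val result = (1 to count).map(i => in.readObject().asInstanceOf[String])
    in.close()
    result
  }
}

In the patch itself the compress flag is derived per block-ID prefix (shuffle_, broadcast_, rdd_) from the spark.shuffle.compress, spark.broadcast.compress and spark.rdd.compress properties.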
@@ -604,9 +855,25 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m
}
}
-object BlockManager {
- def getMaxMemoryFromSystemProperties(): Long = {
+private[spark]
+object BlockManager extends Logging {
+ def getMaxMemoryFromSystemProperties: Long = {
val memoryFraction = System.getProperty("spark.storage.memoryFraction", "0.66").toDouble
(Runtime.getRuntime.maxMemory * memoryFraction).toLong
}
+
+ /**
+ * Attempt to clean up a ByteBuffer if it is memory-mapped. This uses an *unsafe* Sun API that
+ * might cause errors if one attempts to read from the unmapped buffer, but it's better than
+ * waiting for the GC to find it because that could lead to huge numbers of open files. There's
+ * unfortunately no standard API to do this.
+ */
+ def dispose(buffer: ByteBuffer) {
+ if (buffer != null && buffer.isInstanceOf[MappedByteBuffer]) {
+ logDebug("Unmapping " + buffer)
+ if (buffer.asInstanceOf[DirectBuffer].cleaner() != null) {
+ buffer.asInstanceOf[DirectBuffer].cleaner().clean()
+ }
+ }
+ }
}
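BlockManager.dispose exists because a MappedByteBuffer otherwise keeps its file mapping (and file handle) alive until the GC collects it. A usage-style sketch, placed in the spark.storage package so it can see the private[spark] object; the helper name is illustrative:

package spark.storage

import java.io.{File, RandomAccessFile}
import java.nio.MappedByteBuffer
import java.nio.channels.FileChannel.MapMode

object MappedReadSketch {
  // Memory-map a file, consume it, then unmap it eagerly via BlockManager.dispose
  // instead of waiting for the GC to release the mapping.
  def readAndDispose(file: File): Int = {
    val channel = new RandomAccessFile(file, "r").getChannel
    try {
      val buffer: MappedByteBuffer = channel.map(MapMode.READ_ONLY, 0, file.length)
      var sum = 0
      while (buffer.hasRemaining) {
        sum += buffer.get()
      }
      BlockManager.dispose(buffer)  // unsafe Sun API underneath; see the caveats above
      sum
    } finally {
      channel.close()
    }
  }
}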
diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala
index 9f03c5a32c..b3345623b3 100644
--- a/core/src/main/scala/spark/storage/BlockManagerMaster.scala
+++ b/core/src/main/scala/spark/storage/BlockManagerMaster.scala
@@ -3,37 +3,35 @@ package spark.storage
import java.io._
import java.util.{HashMap => JHashMap}
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
-import scala.collection.mutable.HashSet
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
import scala.util.Random
import akka.actor._
import akka.dispatch._
import akka.pattern.ask
import akka.remote._
-import akka.util.Duration
-import akka.util.Timeout
+import akka.util.{Duration, Timeout}
import akka.util.duration._
-import spark.Logging
-import spark.SparkException
-import spark.Utils
+import spark.{Logging, SparkException, Utils}
+
+private[spark]
sealed trait ToBlockManagerMaster
+private[spark]
case class RegisterBlockManager(
blockManagerId: BlockManagerId,
- maxMemSize: Long,
- maxDiskSize: Long)
+ maxMemSize: Long)
extends ToBlockManagerMaster
-
+
+private[spark]
class HeartBeat(
var blockManagerId: BlockManagerId,
var blockId: String,
var storageLevel: StorageLevel,
- var deserializedSize: Long,
- var size: Long)
+ var memSize: Long,
+ var diskSize: Long)
extends ToBlockManagerMaster
with Externalizable {
@@ -43,8 +41,8 @@ class HeartBeat(
blockManagerId.writeExternal(out)
out.writeUTF(blockId)
storageLevel.writeExternal(out)
- out.writeInt(deserializedSize.toInt)
- out.writeInt(size.toInt)
+ out.writeInt(memSize.toInt)
+ out.writeInt(diskSize.toInt)
}
override def readExternal(in: ObjectInput) {
@@ -53,105 +51,115 @@ class HeartBeat(
blockId = in.readUTF()
storageLevel = new StorageLevel()
storageLevel.readExternal(in)
- deserializedSize = in.readInt()
- size = in.readInt()
+ memSize = in.readInt()
+ diskSize = in.readInt()
}
}
+private[spark]
object HeartBeat {
def apply(blockManagerId: BlockManagerId,
blockId: String,
storageLevel: StorageLevel,
- deserializedSize: Long,
- size: Long): HeartBeat = {
- new HeartBeat(blockManagerId, blockId, storageLevel, deserializedSize, size)
+ memSize: Long,
+ diskSize: Long): HeartBeat = {
+ new HeartBeat(blockManagerId, blockId, storageLevel, memSize, diskSize)
}
-
// For pattern-matching
def unapply(h: HeartBeat): Option[(BlockManagerId, String, StorageLevel, Long, Long)] = {
- Some((h.blockManagerId, h.blockId, h.storageLevel, h.deserializedSize, h.size))
+ Some((h.blockManagerId, h.blockId, h.storageLevel, h.memSize, h.diskSize))
}
}
-
+
+private[spark]
case class GetLocations(blockId: String) extends ToBlockManagerMaster
+private[spark]
case class GetLocationsMultipleBlockIds(blockIds: Array[String]) extends ToBlockManagerMaster
-
+
+private[spark]
case class GetPeers(blockManagerId: BlockManagerId, size: Int) extends ToBlockManagerMaster
-
+
+private[spark]
case class RemoveHost(host: String) extends ToBlockManagerMaster
+private[spark]
case object StopBlockManagerMaster extends ToBlockManagerMaster
+private[spark]
+case object GetMemoryStatus extends ToBlockManagerMaster
+
+
+private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
-class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
-
class BlockManagerInfo(
+ val blockManagerId: BlockManagerId,
timeMs: Long,
- maxMem: Long,
- maxDisk: Long) {
- private var lastSeenMs = timeMs
- private var remainedMem = maxMem
- private var remainedDisk = maxDisk
- private val blocks = new JHashMap[String, StorageLevel]
-
+ val maxMem: Long) {
+ private var _lastSeenMs = timeMs
+ private var _remainingMem = maxMem
+ private val _blocks = new JHashMap[String, StorageLevel]
+
+ logInfo("Registering block manager %s:%d with %s RAM".format(
+ blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(maxMem)))
+
def updateLastSeenMs() {
- lastSeenMs = System.currentTimeMillis() / 1000
+ _lastSeenMs = System.currentTimeMillis() / 1000
}
-
- def addBlock(blockId: String, storageLevel: StorageLevel, deserializedSize: Long, size: Long) =
- synchronized {
+
+ def updateBlockInfo(blockId: String, storageLevel: StorageLevel, memSize: Long, diskSize: Long)
+ : Unit = synchronized {
+
updateLastSeenMs()
-
- if (blocks.containsKey(blockId)) {
- val oriLevel: StorageLevel = blocks.get(blockId)
-
- if (oriLevel.deserialized) {
- remainedMem += deserializedSize
- }
- if (oriLevel.useMemory) {
- remainedMem += size
- }
- if (oriLevel.useDisk) {
- remainedDisk += size
+
+ if (_blocks.containsKey(blockId)) {
+ // The block exists on the slave already.
+ val originalLevel: StorageLevel = _blocks.get(blockId)
+
+ if (originalLevel.useMemory) {
+ _remainingMem += memSize
}
}
-
- if (storageLevel.isValid) {
- blocks.put(blockId, storageLevel)
- if (storageLevel.deserialized) {
- remainedMem -= deserializedSize
- }
+
+ if (storageLevel.isValid) {
+ // isValid means it is either stored in-memory or on-disk.
+ _blocks.put(blockId, storageLevel)
if (storageLevel.useMemory) {
- remainedMem -= size
+ _remainingMem -= memSize
+ logInfo("Added %s in memory on %s:%d (size: %s, free: %s)".format(
+ blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(memSize),
+ Utils.memoryBytesToString(_remainingMem)))
}
if (storageLevel.useDisk) {
- remainedDisk -= size
+ logInfo("Added %s on disk on %s:%d (size: %s)".format(
+ blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(diskSize)))
+ }
+ } else if (_blocks.containsKey(blockId)) {
+ // If isValid is not true, drop the block.
+ val originalLevel: StorageLevel = _blocks.get(blockId)
+ _blocks.remove(blockId)
+ if (originalLevel.useMemory) {
+ _remainingMem += memSize
+ logInfo("Removed %s on %s:%d in memory (size: %s, free: %s)".format(
+ blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(memSize),
+ Utils.memoryBytesToString(_remainingMem)))
+ }
+ if (originalLevel.useDisk) {
+ logInfo("Removed %s on %s:%d on disk (size: %s)".format(
+ blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(diskSize)))
}
- } else {
- blocks.remove(blockId)
}
}
- def getLastSeenMs: Long = {
- return lastSeenMs
- }
-
- def getRemainedMem: Long = {
- return remainedMem
- }
-
- def getRemainedDisk: Long = {
- return remainedDisk
- }
+ def remainingMem: Long = _remainingMem
- override def toString: String = {
- return "BlockManagerInfo " + timeMs + " " + remainedMem + " " + remainedDisk
- }
+ def lastSeenMs: Long = _lastSeenMs
+
+ override def toString: String = "BlockManagerInfo " + timeMs + " " + _remainingMem
def clear() {
- blocks.clear()
+ _blocks.clear()
}
}
@@ -159,7 +167,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
private val blockInfo = new JHashMap[String, Pair[Int, HashSet[BlockManagerId]]]
initLogging()
-
+
def removeHost(host: String) {
logInfo("Trying to remove the host: " + host + " from BlockManagerMaster.")
logInfo("Previous hosts: " + blockManagerInfo.keySet.toSeq)
@@ -171,8 +179,8 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
}
def receive = {
- case RegisterBlockManager(blockManagerId, maxMemSize, maxDiskSize) =>
- register(blockManagerId, maxMemSize, maxDiskSize)
+ case RegisterBlockManager(blockManagerId, maxMemSize) =>
+ register(blockManagerId, maxMemSize)
case HeartBeat(blockManagerId, blockId, storageLevel, deserializedSize, size) =>
heartBeat(blockManagerId, blockId, storageLevel, deserializedSize, size)
@@ -186,7 +194,10 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
case GetPeers(blockManagerId, size) =>
getPeersDeterministic(blockManagerId, size)
/*getPeers(blockManagerId, size)*/
-
+
+ case GetMemoryStatus =>
+ getMemoryStatus
+
case RemoveHost(host) =>
removeHost(host)
sender ! true
@@ -196,43 +207,50 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
sender ! true
context.stop(self)
- case other =>
+ case other =>
logInfo("Got unknown message: " + other)
}
-
- private def register(blockManagerId: BlockManagerId, maxMemSize: Long, maxDiskSize: Long) {
+
+ // Return a map from the block manager id to max memory and remaining memory.
+ private def getMemoryStatus() {
+ val res = blockManagerInfo.map { case(blockManagerId, info) =>
+ (blockManagerId, (info.maxMem, info.remainingMem))
+ }.toMap
+ sender ! res
+ }
+
+ private def register(blockManagerId: BlockManagerId, maxMemSize: Long) {
val startTimeMs = System.currentTimeMillis()
val tmp = " " + blockManagerId + " "
logDebug("Got in register 0" + tmp + Utils.getUsedTimeMs(startTimeMs))
- logInfo("Got Register Msg from " + blockManagerId)
if (blockManagerId.ip == Utils.localHostName() && !isLocal) {
logInfo("Got Register Msg from master node, don't register it")
} else {
blockManagerInfo += (blockManagerId -> new BlockManagerInfo(
- System.currentTimeMillis() / 1000, maxMemSize, maxDiskSize))
+ blockManagerId, System.currentTimeMillis() / 1000, maxMemSize))
}
logDebug("Got in register 1" + tmp + Utils.getUsedTimeMs(startTimeMs))
sender ! true
}
-
+
private def heartBeat(
blockManagerId: BlockManagerId,
blockId: String,
storageLevel: StorageLevel,
- deserializedSize: Long,
- size: Long) {
-
+ memSize: Long,
+ diskSize: Long) {
+
val startTimeMs = System.currentTimeMillis()
val tmp = " " + blockManagerId + " " + blockId + " "
-
+
if (blockId == null) {
blockManagerInfo(blockManagerId).updateLastSeenMs()
logDebug("Got in heartBeat 1" + tmp + " used " + Utils.getUsedTimeMs(startTimeMs))
sender ! true
}
-
- blockManagerInfo(blockManagerId).addBlock(blockId, storageLevel, deserializedSize, size)
-
+
+ blockManagerInfo(blockManagerId).updateBlockInfo(blockId, storageLevel, memSize, diskSize)
+
var locations: HashSet[BlockManagerId] = null
if (blockInfo.containsKey(blockId)) {
locations = blockInfo.get(blockId)._2
@@ -240,19 +258,19 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
locations = new HashSet[BlockManagerId]
blockInfo.put(blockId, (storageLevel.replication, locations))
}
-
+
if (storageLevel.isValid) {
locations += blockManagerId
} else {
locations.remove(blockManagerId)
}
-
+
if (locations.size == 0) {
blockInfo.remove(blockId)
}
sender ! true
}
-
+
private def getLocations(blockId: String) {
val startTimeMs = System.currentTimeMillis()
val tmp = " " + blockId + " "
@@ -260,7 +278,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
if (blockInfo.containsKey(blockId)) {
var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
res.appendAll(blockInfo.get(blockId)._2)
- logDebug("Got in getLocations 1" + tmp + " as "+ res.toSeq + " at "
+ logDebug("Got in getLocations 1" + tmp + " as "+ res.toSeq + " at "
+ Utils.getUsedTimeMs(startTimeMs))
sender ! res.toSeq
} else {
@@ -269,7 +287,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
sender ! res
}
}
-
+
private def getLocationsMultipleBlockIds(blockIds: Array[String]) {
def getLocations(blockId: String): Seq[BlockManagerId] = {
val tmp = blockId
@@ -285,7 +303,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
return res.toSeq
}
}
-
+
logDebug("Got in getLocationsMultipleBlockIds " + blockIds.toSeq)
var res: ArrayBuffer[Seq[BlockManagerId]] = new ArrayBuffer[Seq[BlockManagerId]]
for (blockId <- blockIds) {
@@ -306,7 +324,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
}
sender ! res.toSeq
}
-
+
private def getPeersDeterministic(blockManagerId: BlockManagerId, size: Int) {
var peers: Array[BlockManagerId] = blockManagerInfo.keySet.toArray
var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
@@ -329,7 +347,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
}
}
-class BlockManagerMaster(actorSystem: ActorSystem, isMaster: Boolean, isLocal: Boolean)
+private[spark] class BlockManagerMaster(actorSystem: ActorSystem, isMaster: Boolean, isLocal: Boolean)
extends Logging {
val AKKA_ACTOR_NAME: String = "BlockMasterManager"
@@ -352,7 +370,7 @@ class BlockManagerMaster(actorSystem: ActorSystem, isMaster: Boolean, isLocal: B
logInfo("Connecting to BlockManagerMaster: " + url)
masterActor = actorSystem.actorFor(url)
}
-
+
def stop() {
if (masterActor != null) {
communicate(StopBlockManagerMaster)
@@ -379,17 +397,19 @@ class BlockManagerMaster(actorSystem: ActorSystem, isMaster: Boolean, isLocal: B
throw new SparkException("Error reply received from BlockManagerMaster")
}
}
-
+
def notifyADeadHost(host: String) {
communicate(RemoveHost(host + ":" + DEFAULT_MANAGER_PORT))
logInfo("Removed " + host + " successfully in notifyADeadHost")
}
def mustRegisterBlockManager(msg: RegisterBlockManager) {
+ logInfo("Trying to register BlockManager")
while (! syncRegisterBlockManager(msg)) {
logWarning("Failed to register " + msg)
Thread.sleep(REQUEST_RETRY_INTERVAL_MS)
}
+ logInfo("Done registering BlockManager")
}
def syncRegisterBlockManager(msg: RegisterBlockManager): Boolean = {
@@ -397,7 +417,7 @@ class BlockManagerMaster(actorSystem: ActorSystem, isMaster: Boolean, isLocal: B
val startTimeMs = System.currentTimeMillis()
val tmp = " msg " + msg + " "
logDebug("Got in syncRegisterBlockManager 0 " + tmp + Utils.getUsedTimeMs(startTimeMs))
-
+
try {
communicate(msg)
logInfo("BlockManager registered successfully @ syncRegisterBlockManager")
@@ -409,19 +429,19 @@ class BlockManagerMaster(actorSystem: ActorSystem, isMaster: Boolean, isLocal: B
return false
}
}
-
+
def mustHeartBeat(msg: HeartBeat) {
while (! syncHeartBeat(msg)) {
logWarning("Failed to send heartbeat" + msg)
Thread.sleep(REQUEST_RETRY_INTERVAL_MS)
}
}
-
+
def syncHeartBeat(msg: HeartBeat): Boolean = {
val startTimeMs = System.currentTimeMillis()
val tmp = " msg " + msg + " "
logDebug("Got in syncHeartBeat " + tmp + " 0 " + Utils.getUsedTimeMs(startTimeMs))
-
+
try {
communicate(msg)
logDebug("Heartbeat sent successfully")
@@ -433,7 +453,7 @@ class BlockManagerMaster(actorSystem: ActorSystem, isMaster: Boolean, isLocal: B
return false
}
}
-
+
def mustGetLocations(msg: GetLocations): Seq[BlockManagerId] = {
var res = syncGetLocations(msg)
while (res == null) {
@@ -443,7 +463,7 @@ class BlockManagerMaster(actorSystem: ActorSystem, isMaster: Boolean, isLocal: B
}
return res
}
-
+
def syncGetLocations(msg: GetLocations): Seq[BlockManagerId] = {
val startTimeMs = System.currentTimeMillis()
val tmp = " msg " + msg + " "
@@ -476,13 +496,13 @@ class BlockManagerMaster(actorSystem: ActorSystem, isMaster: Boolean, isLocal: B
}
return res
}
-
+
def syncGetLocationsMultipleBlockIds(msg: GetLocationsMultipleBlockIds):
Seq[Seq[BlockManagerId]] = {
val startTimeMs = System.currentTimeMillis
val tmp = " msg " + msg + " "
logDebug("Got in syncGetLocationsMultipleBlockIds 0 " + tmp + Utils.getUsedTimeMs(startTimeMs))
-
+
try {
val answer = askMaster(msg).asInstanceOf[Seq[Seq[BlockManagerId]]]
if (answer != null) {
@@ -500,7 +520,7 @@ class BlockManagerMaster(actorSystem: ActorSystem, isMaster: Boolean, isLocal: B
return null
}
}
-
+
def mustGetPeers(msg: GetPeers): Seq[BlockManagerId] = {
var res = syncGetPeers(msg)
while ((res == null) || (res.length != msg.size)) {
@@ -508,10 +528,10 @@ class BlockManagerMaster(actorSystem: ActorSystem, isMaster: Boolean, isLocal: B
Thread.sleep(REQUEST_RETRY_INTERVAL_MS)
res = syncGetPeers(msg)
}
-
+
return res
}
-
+
def syncGetPeers(msg: GetPeers): Seq[BlockManagerId] = {
val startTimeMs = System.currentTimeMillis
val tmp = " msg " + msg + " "
@@ -533,4 +553,8 @@ class BlockManagerMaster(actorSystem: ActorSystem, isMaster: Boolean, isLocal: B
return null
}
}
+
+ def getMemoryStatus: Map[BlockManagerId, (Long, Long)] = {
+ askMaster(GetMemoryStatus).asInstanceOf[Map[BlockManagerId, (Long, Long)]]
+ }
}
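
As a rough usage sketch (illustrative only, not part of this patch), the new GetMemoryStatus path lets the driver ask every block manager for its max and remaining memory. The snippet assumes it runs where SparkEnv.get.blockManager is reachable (an assumption, not shown in this diff) and sits in package spark.storage to reach the private[spark] classes.

package spark.storage

import spark.{SparkEnv, Utils}

object MemoryStatusExample {
  def printMemoryStatus() {
    val master = SparkEnv.get.blockManager.master
    for ((blockManagerId, (maxMem, remainingMem)) <- master.getMemoryStatus) {
      println("%s:%d -> %s free of %s".format(
        blockManagerId.ip, blockManagerId.port,
        Utils.memoryBytesToString(remainingMem),
        Utils.memoryBytesToString(maxMem)))
    }
  }
}
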
diff --git a/core/src/main/scala/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/spark/storage/BlockManagerWorker.scala
index 0658a57187..d2985559c1 100644
--- a/core/src/main/scala/spark/storage/BlockManagerWorker.scala
+++ b/core/src/main/scala/spark/storage/BlockManagerWorker.scala
@@ -1,25 +1,23 @@
package spark.storage
-import java.nio._
+import java.nio.ByteBuffer
import scala.actors._
import scala.actors.Actor._
import scala.actors.remote._
-
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
-import scala.collection.mutable.HashSet
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
import scala.util.Random
-import spark.Logging
-import spark.Utils
-import spark.SparkEnv
+import spark.{Logging, Utils, SparkEnv}
import spark.network._
/**
- * This should be changed to use event model late.
+ * A network interface for BlockManager. Each slave should have one
+ * BlockManagerWorker.
+ *
+ * TODO: Use event model.
*/
-class BlockManagerWorker(val blockManager: BlockManager) extends Logging {
+private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends Logging {
initLogging()
blockManager.connectionManager.onReceiveMessage(onBlockMessageReceive)
@@ -32,11 +30,10 @@ class BlockManagerWorker(val blockManager: BlockManager) extends Logging {
logDebug("Handling as a buffer message " + bufferMessage)
val blockMessages = BlockMessageArray.fromBufferMessage(bufferMessage)
logDebug("Parsed as a block message array")
- val responseMessages = blockMessages.map(processBlockMessage _).filter(_ != None).map(_.get)
- /*logDebug("Processed block messages")*/
+ val responseMessages = blockMessages.map(processBlockMessage).filter(_ != None).map(_.get)
return Some(new BlockMessageArray(responseMessages).toBufferMessage)
} catch {
- case e: Exception => logError("Exception handling buffer message: " + e.getMessage)
+ case e: Exception => logError("Exception handling buffer message", e)
return None
}
}
@@ -51,13 +48,13 @@ class BlockManagerWorker(val blockManager: BlockManager) extends Logging {
blockMessage.getType match {
case BlockMessage.TYPE_PUT_BLOCK => {
val pB = PutBlock(blockMessage.getId, blockMessage.getData, blockMessage.getLevel)
- logInfo("Received [" + pB + "]")
+ logDebug("Received [" + pB + "]")
putBlock(pB.id, pB.data, pB.level)
return None
}
case BlockMessage.TYPE_GET_BLOCK => {
val gB = new GetBlock(blockMessage.getId)
- logInfo("Received [" + gB + "]")
+ logDebug("Received [" + gB + "]")
val buffer = getBlock(gB.id)
if (buffer == null) {
return None
@@ -78,17 +75,10 @@ class BlockManagerWorker(val blockManager: BlockManager) extends Logging {
private def getBlock(id: String): ByteBuffer = {
val startTimeMs = System.currentTimeMillis()
- logDebug("Getblock " + id + " started from " + startTimeMs)
- val block = blockManager.getLocal(id)
- val buffer = block match {
- case Some(tValues) => {
- val values = tValues
- val buffer = blockManager.dataSerialize(values)
- buffer
- }
- case None => {
- null
- }
+ logDebug("GetBlock " + id + " started from " + startTimeMs)
+ val buffer = blockManager.getLocalBytes(id) match {
+ case Some(bytes) => bytes
+ case None => null
}
logDebug("GetBlock " + id + " used " + Utils.getUsedTimeMs(startTimeMs)
+ " and got buffer " + buffer)
@@ -96,7 +86,7 @@ class BlockManagerWorker(val blockManager: BlockManager) extends Logging {
}
}
-object BlockManagerWorker extends Logging {
+private[spark] object BlockManagerWorker extends Logging {
private var blockManagerWorker: BlockManagerWorker = null
private val DATA_TRANSFER_TIME_OUT_MS: Long = 500
private val REQUEST_RETRY_INTERVAL_MS: Long = 1000
diff --git a/core/src/main/scala/spark/storage/BlockMessage.scala b/core/src/main/scala/spark/storage/BlockMessage.scala
index 607633c6df..3f234df654 100644
--- a/core/src/main/scala/spark/storage/BlockMessage.scala
+++ b/core/src/main/scala/spark/storage/BlockMessage.scala
@@ -1,6 +1,6 @@
package spark.storage
-import java.nio._
+import java.nio.ByteBuffer
import scala.collection.mutable.StringBuilder
import scala.collection.mutable.ArrayBuffer
@@ -8,11 +8,11 @@ import scala.collection.mutable.ArrayBuffer
import spark._
import spark.network._
-case class GetBlock(id: String)
-case class GotBlock(id: String, data: ByteBuffer)
-case class PutBlock(id: String, data: ByteBuffer, level: StorageLevel)
+private[spark] case class GetBlock(id: String)
+private[spark] case class GotBlock(id: String, data: ByteBuffer)
+private[spark] case class PutBlock(id: String, data: ByteBuffer, level: StorageLevel)
-class BlockMessage() {
+private[spark] class BlockMessage() {
// Un-initialized: typ = 0
// GetBlock: typ = 1
// GotBlock: typ = 2
@@ -158,7 +158,7 @@ class BlockMessage() {
}
}
-object BlockMessage {
+private[spark] object BlockMessage {
val TYPE_NON_INITIALIZED: Int = 0
val TYPE_GET_BLOCK: Int = 1
val TYPE_GOT_BLOCK: Int = 2
@@ -196,7 +196,7 @@ object BlockMessage {
def main(args: Array[String]) {
val B = new BlockMessage()
- B.set(new PutBlock("ABC", ByteBuffer.allocate(10), StorageLevel.DISK_AND_MEMORY_2))
+ B.set(new PutBlock("ABC", ByteBuffer.allocate(10), StorageLevel.MEMORY_AND_DISK_SER_2))
val bMsg = B.toBufferMessage
val C = new BlockMessage()
C.set(bMsg)
diff --git a/core/src/main/scala/spark/storage/BlockMessageArray.scala b/core/src/main/scala/spark/storage/BlockMessageArray.scala
index 497a19856e..a25decb123 100644
--- a/core/src/main/scala/spark/storage/BlockMessageArray.scala
+++ b/core/src/main/scala/spark/storage/BlockMessageArray.scala
@@ -1,5 +1,6 @@
package spark.storage
-import java.nio._
+
+import java.nio.ByteBuffer
import scala.collection.mutable.StringBuilder
import scala.collection.mutable.ArrayBuffer
@@ -7,6 +8,7 @@ import scala.collection.mutable.ArrayBuffer
import spark._
import spark.network._
+private[spark]
class BlockMessageArray(var blockMessages: Seq[BlockMessage]) extends Seq[BlockMessage] with Logging {
def this(bm: BlockMessage) = this(Array(bm))
@@ -84,7 +86,7 @@ class BlockMessageArray(var blockMessages: Seq[BlockMessage]) extends Seq[BlockM
}
}
-object BlockMessageArray {
+private[spark] object BlockMessageArray {
def fromBufferMessage(bufferMessage: BufferMessage): BlockMessageArray = {
val newBlockMessageArray = new BlockMessageArray()
@@ -98,7 +100,7 @@ object BlockMessageArray {
if (i % 2 == 0) {
val buffer = ByteBuffer.allocate(100)
buffer.clear
- BlockMessage.fromPutBlock(PutBlock(i.toString, buffer, StorageLevel.MEMORY_ONLY))
+ BlockMessage.fromPutBlock(PutBlock(i.toString, buffer, StorageLevel.MEMORY_ONLY_SER))
} else {
BlockMessage.fromGetBlock(GetBlock(i.toString))
}
diff --git a/core/src/main/scala/spark/storage/BlockStore.scala b/core/src/main/scala/spark/storage/BlockStore.scala
index 77e0ed84c5..096bf8bdd9 100644
--- a/core/src/main/scala/spark/storage/BlockStore.scala
+++ b/core/src/main/scala/spark/storage/BlockStore.scala
@@ -1,25 +1,31 @@
package spark.storage
-import spark.{Utils, Logging, Serializer, SizeEstimator}
-import scala.collection.mutable.ArrayBuffer
-import java.io.{File, RandomAccessFile}
import java.nio.ByteBuffer
-import java.nio.channels.FileChannel.MapMode
-import java.util.{UUID, LinkedHashMap}
-import java.util.concurrent.Executors
-import java.util.concurrent.ConcurrentHashMap
-import it.unimi.dsi.fastutil.io._
-import java.util.concurrent.ArrayBlockingQueue
+import scala.collection.mutable.ArrayBuffer
+
+import spark.Logging
/**
* Abstract class to store blocks
*/
-abstract class BlockStore(blockManager: BlockManager) extends Logging {
- initLogging()
-
- def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel)
-
- def putValues(blockId: String, values: Iterator[Any], level: StorageLevel): Either[Iterator[Any], ByteBuffer]
+private[spark]
+abstract class BlockStore(val blockManager: BlockManager) extends Logging {
+ def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel)
+
+ /**
+ * Put in a block and, possibly, also return its content as either bytes or another Iterator.
+ * This is used to efficiently write the values to multiple locations (e.g. for replication).
+ *
+ * @return a PutResult that contains the size of the data, as well as the values put if
+ * returnValues is true (if not, the result's data field can be null)
+ */
+ def putValues(blockId: String, values: ArrayBuffer[Any], level: StorageLevel,
+ returnValues: Boolean) : PutResult
+
+ /**
+ * Return the size of a block in bytes.
+ */
+ def getSize(blockId: String): Long
def getBytes(blockId: String): Option[ByteBuffer]
@@ -27,282 +33,7 @@ abstract class BlockStore(blockManager: BlockManager) extends Logging {
def remove(blockId: String)
- def dataSerialize(values: Iterator[Any]): ByteBuffer = blockManager.dataSerialize(values)
-
- def dataDeserialize(bytes: ByteBuffer): Iterator[Any] = blockManager.dataDeserialize(bytes)
+ def contains(blockId: String): Boolean
def clear() { }
}
-
-/**
- * Class to store blocks in memory
- */
-class MemoryStore(blockManager: BlockManager, maxMemory: Long)
- extends BlockStore(blockManager) {
-
- case class Entry(value: Any, size: Long, deserialized: Boolean, var dropPending: Boolean = false)
-
- private val memoryStore = new LinkedHashMap[String, Entry](32, 0.75f, true)
- private var currentMemory = 0L
-
- //private val blockDropper = Executors.newSingleThreadExecutor()
- private val blocksToDrop = new ArrayBlockingQueue[String](10000, true)
- private val blockDropper = new Thread("memory store - block dropper") {
- override def run() {
- try{
- while (true) {
- val blockId = blocksToDrop.take()
- logDebug("Block " + blockId + " ready to be dropped")
- blockManager.dropFromMemory(blockId)
- }
- } catch {
- case ie: InterruptedException =>
- logInfo("Shutting down block dropper")
- }
- }
- }
- blockDropper.start()
-
- def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) {
- if (level.deserialized) {
- bytes.rewind()
- val values = dataDeserialize(bytes)
- val elements = new ArrayBuffer[Any]
- elements ++= values
- val sizeEstimate = SizeEstimator.estimate(elements.asInstanceOf[AnyRef])
- ensureFreeSpace(sizeEstimate)
- val entry = new Entry(elements, sizeEstimate, true)
- memoryStore.synchronized { memoryStore.put(blockId, entry) }
- currentMemory += sizeEstimate
- logDebug("Block " + blockId + " stored as values to memory")
- } else {
- val entry = new Entry(bytes, bytes.limit, false)
- ensureFreeSpace(bytes.limit)
- memoryStore.synchronized { memoryStore.put(blockId, entry) }
- currentMemory += bytes.limit
- logDebug("Block " + blockId + " stored as " + bytes.limit + " bytes to memory")
- }
- }
-
- def putValues(blockId: String, values: Iterator[Any], level: StorageLevel): Either[Iterator[Any], ByteBuffer] = {
- if (level.deserialized) {
- val elements = new ArrayBuffer[Any]
- elements ++= values
- val sizeEstimate = SizeEstimator.estimate(elements.asInstanceOf[AnyRef])
- ensureFreeSpace(sizeEstimate)
- val entry = new Entry(elements, sizeEstimate, true)
- memoryStore.synchronized { memoryStore.put(blockId, entry) }
- currentMemory += sizeEstimate
- logDebug("Block " + blockId + " stored as values to memory")
- return Left(elements.iterator)
- } else {
- val bytes = dataSerialize(values)
- ensureFreeSpace(bytes.limit)
- val entry = new Entry(bytes, bytes.limit, false)
- memoryStore.synchronized { memoryStore.put(blockId, entry) }
- currentMemory += bytes.limit
- logDebug("Block " + blockId + " stored as " + bytes.limit + " bytes to memory")
- return Right(bytes)
- }
- }
-
- def getBytes(blockId: String): Option[ByteBuffer] = {
- throw new UnsupportedOperationException("Not implemented")
- }
-
- def getValues(blockId: String): Option[Iterator[Any]] = {
- val entry = memoryStore.synchronized { memoryStore.get(blockId) }
- if (entry == null) {
- return None
- }
- if (entry.deserialized) {
- return Some(entry.value.asInstanceOf[ArrayBuffer[Any]].toIterator)
- } else {
- return Some(dataDeserialize(entry.value.asInstanceOf[ByteBuffer].duplicate()))
- }
- }
-
- def remove(blockId: String) {
- memoryStore.synchronized {
- val entry = memoryStore.get(blockId)
- if (entry != null) {
- memoryStore.remove(blockId)
- currentMemory -= entry.size
- logDebug("Block " + blockId + " of size " + entry.size + " dropped from memory")
- } else {
- logWarning("Block " + blockId + " could not be removed as it doesnt exist")
- }
- }
- }
-
- override def clear() {
- memoryStore.synchronized {
- memoryStore.clear()
- }
- //blockDropper.shutdown()
- blockDropper.interrupt()
- logInfo("MemoryStore cleared")
- }
-
- private def ensureFreeSpace(space: Long) {
- logInfo("ensureFreeSpace(%d) called with curMem=%d, maxMem=%d".format(
- space, currentMemory, maxMemory))
-
- if (maxMemory - currentMemory < space) {
-
- val selectedBlocks = new ArrayBuffer[String]()
- var selectedMemory = 0L
-
- memoryStore.synchronized {
- val iter = memoryStore.entrySet().iterator()
- while (maxMemory - (currentMemory - selectedMemory) < space && iter.hasNext) {
- val pair = iter.next()
- val blockId = pair.getKey
- val entry = pair.getValue()
- if (!entry.dropPending) {
- selectedBlocks += blockId
- entry.dropPending = true
- }
- selectedMemory += pair.getValue.size
- logDebug("Block " + blockId + " selected for dropping")
- }
- }
-
- logDebug("" + selectedBlocks.size + " new blocks selected for dropping, " +
- blocksToDrop.size + " blocks pending")
- var i = 0
- while (i < selectedBlocks.size) {
- blocksToDrop.add(selectedBlocks(i))
- i += 1
- }
- selectedBlocks.clear()
- }
- }
-}
-
-
-/**
- * Class to store blocks in disk
- */
-class DiskStore(blockManager: BlockManager, rootDirs: String)
- extends BlockStore(blockManager) {
-
- val MAX_DIR_CREATION_ATTEMPTS: Int = 10
- val localDirs = createLocalDirs()
- var lastLocalDirUsed = 0
-
- addShutdownHook()
-
- def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) {
- logDebug("Attempting to put block " + blockId)
- val startTime = System.currentTimeMillis
- val file = createFile(blockId)
- if (file != null) {
- val channel = new RandomAccessFile(file, "rw").getChannel()
- val buffer = channel.map(MapMode.READ_WRITE, 0, bytes.limit)
- buffer.put(bytes)
- channel.close()
- val finishTime = System.currentTimeMillis
- logDebug("Block " + blockId + " stored to file of " + bytes.limit + " bytes to disk in " + (finishTime - startTime) + " ms")
- } else {
- logError("File not created for block " + blockId)
- }
- }
-
- def putValues(blockId: String, values: Iterator[Any], level: StorageLevel): Either[Iterator[Any], ByteBuffer] = {
- val bytes = dataSerialize(values)
- logDebug("Converted block " + blockId + " to " + bytes.limit + " bytes")
- putBytes(blockId, bytes, level)
- return Right(bytes)
- }
-
- def getBytes(blockId: String): Option[ByteBuffer] = {
- val file = getFile(blockId)
- val length = file.length().toInt
- val channel = new RandomAccessFile(file, "r").getChannel()
- Some(channel.map(MapMode.READ_WRITE, 0, length))
- }
-
- def getValues(blockId: String): Option[Iterator[Any]] = {
- val file = getFile(blockId)
- val length = file.length().toInt
- val channel = new RandomAccessFile(file, "r").getChannel()
- val bytes = channel.map(MapMode.READ_ONLY, 0, length)
- val buffer = dataDeserialize(bytes)
- channel.close()
- return Some(buffer)
- }
-
- def remove(blockId: String) {
- throw new UnsupportedOperationException("Not implemented")
- }
-
- private def createFile(blockId: String): File = {
- val file = getFile(blockId)
- if (file == null) {
- lastLocalDirUsed = (lastLocalDirUsed + 1) % localDirs.size
- val newFile = new File(localDirs(lastLocalDirUsed), blockId)
- newFile.getParentFile.mkdirs()
- return newFile
- } else {
- logError("File for block " + blockId + " already exists on disk, " + file)
- return null
- }
- }
-
- private def getFile(blockId: String): File = {
- logDebug("Getting file for block " + blockId)
- // Search for the file in all the local directories, only one of them should have the file
- val files = localDirs.map(localDir => new File(localDir, blockId)).filter(_.exists)
- if (files.size > 1) {
- throw new Exception("Multiple files for same block " + blockId + " exists: " +
- files.map(_.toString).reduceLeft(_ + ", " + _))
- return null
- } else if (files.size == 0) {
- return null
- } else {
- logDebug("Got file " + files(0) + " of size " + files(0).length + " bytes")
- return files(0)
- }
- }
-
- private def createLocalDirs(): Seq[File] = {
- logDebug("Creating local directories at root dirs '" + rootDirs + "'")
- rootDirs.split("[;,:]").map(rootDir => {
- var foundLocalDir: Boolean = false
- var localDir: File = null
- var localDirUuid: UUID = null
- var tries = 0
- while (!foundLocalDir && tries < MAX_DIR_CREATION_ATTEMPTS) {
- tries += 1
- try {
- localDirUuid = UUID.randomUUID()
- localDir = new File(rootDir, "spark-local-" + localDirUuid)
- if (!localDir.exists) {
- localDir.mkdirs()
- foundLocalDir = true
- }
- } catch {
- case e: Exception =>
- logWarning("Attempt " + tries + " to create local dir failed", e)
- }
- }
- if (!foundLocalDir) {
- logError("Failed " + MAX_DIR_CREATION_ATTEMPTS +
- " attempts to create local dir in " + rootDir)
- System.exit(1)
- }
- logDebug("Created local directory at " + localDir)
- localDir
- })
- }
-
- private def addShutdownHook() {
- Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") {
- override def run() {
- logDebug("Shutdown hook called")
- localDirs.foreach(localDir => Utils.deleteRecursively(localDir))
- }
- })
- }
-}
diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala
new file mode 100644
index 0000000000..8ba64e4b76
--- /dev/null
+++ b/core/src/main/scala/spark/storage/DiskStore.scala
@@ -0,0 +1,180 @@
+package spark.storage
+
+import java.nio.ByteBuffer
+import java.io.{File, FileOutputStream, RandomAccessFile}
+import java.nio.channels.FileChannel.MapMode
+import java.util.{Random, Date}
+import java.text.SimpleDateFormat
+
+import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
+
+import scala.collection.mutable.ArrayBuffer
+
+import spark.Utils
+
+/**
+ * Stores BlockManager blocks on disk.
+ */
+private class DiskStore(blockManager: BlockManager, rootDirs: String)
+ extends BlockStore(blockManager) {
+
+ val MAX_DIR_CREATION_ATTEMPTS: Int = 10
+ val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt
+
+ // Create one local directory for each path mentioned in spark.local.dir; then, inside this
+ // directory, create multiple subdirectories that we will hash files into, in order to avoid
+ // having really large inodes at the top level.
+ val localDirs = createLocalDirs()
+ val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir))
+
+ addShutdownHook()
+
+ override def getSize(blockId: String): Long = {
+ getFile(blockId).length()
+ }
+
+ override def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) {
+ logDebug("Attempting to put block " + blockId)
+ val startTime = System.currentTimeMillis
+ val file = createFile(blockId)
+ val channel = new RandomAccessFile(file, "rw").getChannel()
+ while (bytes.remaining > 0) {
+ channel.write(bytes)
+ }
+ channel.close()
+ val finishTime = System.currentTimeMillis
+ logDebug("Block %s stored as %s file on disk in %d ms".format(
+ blockId, Utils.memoryBytesToString(bytes.limit), (finishTime - startTime)))
+ }
+
+ override def putValues(
+ blockId: String,
+ values: ArrayBuffer[Any],
+ level: StorageLevel,
+ returnValues: Boolean)
+ : PutResult = {
+
+ logDebug("Attempting to write values for block " + blockId)
+ val startTime = System.currentTimeMillis
+ val file = createFile(blockId)
+ val fileOut = blockManager.wrapForCompression(blockId,
+ new FastBufferedOutputStream(new FileOutputStream(file)))
+ val objOut = blockManager.serializer.newInstance().serializeStream(fileOut)
+ objOut.writeAll(values.iterator)
+ objOut.close()
+ val length = file.length()
+ logDebug("Block %s stored as %s file on disk in %d ms".format(
+ blockId, Utils.memoryBytesToString(length), (System.currentTimeMillis - startTime)))
+
+ if (returnValues) {
+ // Return a byte buffer for the contents of the file
+ val channel = new RandomAccessFile(file, "r").getChannel()
+ val buffer = channel.map(MapMode.READ_ONLY, 0, length)
+ channel.close()
+ PutResult(length, Right(buffer))
+ } else {
+ PutResult(length, null)
+ }
+ }
+
+ override def getBytes(blockId: String): Option[ByteBuffer] = {
+ val file = getFile(blockId)
+ val length = file.length().toInt
+ val channel = new RandomAccessFile(file, "r").getChannel()
+ val bytes = channel.map(MapMode.READ_ONLY, 0, length)
+ channel.close()
+ Some(bytes)
+ }
+
+ override def getValues(blockId: String): Option[Iterator[Any]] = {
+ getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes))
+ }
+
+ override def remove(blockId: String) {
+ val file = getFile(blockId)
+ if (file.exists()) {
+ file.delete()
+ }
+ }
+
+ override def contains(blockId: String): Boolean = {
+ getFile(blockId).exists()
+ }
+
+ private def createFile(blockId: String): File = {
+ val file = getFile(blockId)
+ if (file.exists()) {
+ throw new Exception("File for block " + blockId + " already exists on disk: " + file)
+ }
+ file
+ }
+
+ private def getFile(blockId: String): File = {
+ logDebug("Getting file for block " + blockId)
+
+ // Figure out which local directory it hashes to, and which subdirectory in that
+ val hash = math.abs(blockId.hashCode)
+ val dirId = hash % localDirs.length
+ val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
+
+ // Create the subdirectory if it doesn't already exist
+ var subDir = subDirs(dirId)(subDirId)
+ if (subDir == null) {
+ subDir = subDirs(dirId).synchronized {
+ val old = subDirs(dirId)(subDirId)
+ if (old != null) {
+ old
+ } else {
+ val newDir = new File(localDirs(dirId), "%02x".format(subDirId))
+ newDir.mkdir()
+ subDirs(dirId)(subDirId) = newDir
+ newDir
+ }
+ }
+ }
+
+ new File(subDir, blockId)
+ }
+
+ private def createLocalDirs(): Array[File] = {
+ logDebug("Creating local directories at root dirs '" + rootDirs + "'")
+ val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss")
+ rootDirs.split(",").map(rootDir => {
+ var foundLocalDir: Boolean = false
+ var localDir: File = null
+ var localDirId: String = null
+ var tries = 0
+ val rand = new Random()
+ while (!foundLocalDir && tries < MAX_DIR_CREATION_ATTEMPTS) {
+ tries += 1
+ try {
+ localDirId = "%s-%04x".format(dateFormat.format(new Date), rand.nextInt(65536))
+ localDir = new File(rootDir, "spark-local-" + localDirId)
+ if (!localDir.exists) {
+ localDir.mkdirs()
+ foundLocalDir = true
+ }
+ } catch {
+ case e: Exception =>
+ logWarning("Attempt " + tries + " to create local dir failed", e)
+ }
+ }
+ if (!foundLocalDir) {
+ logError("Failed " + MAX_DIR_CREATION_ATTEMPTS +
+ " attempts to create local dir in " + rootDir)
+ System.exit(1)
+ }
+ logInfo("Created local directory at " + localDir)
+ localDir
+ })
+ }
+
+ private def addShutdownHook() {
+ Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") {
+ override def run() {
+ logDebug("Shutdown hook called")
+ localDirs.foreach(localDir => Utils.deleteRecursively(localDir))
+ }
+ })
+ }
+}
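
To make the file layout concrete, here is a small self-contained sketch of the same hash arithmetic that getFile() uses above to spread block files across the spark.local.dir roots and their "%02x" subdirectories. The block id and directory counts are made-up example values.

object DiskLayoutExample {
  def main(args: Array[String]) {
    val localDirCount = 2        // e.g. two paths listed in spark.local.dir
    val subDirsPerLocalDir = 64  // default for spark.diskStore.subDirectories
    val blockId = "rdd_3_10"
    val hash = math.abs(blockId.hashCode)
    val dirId = hash % localDirCount
    val subDirId = (hash / localDirCount) % subDirsPerLocalDir
    println("block %s -> local dir %d, subdirectory %s".format(
      blockId, dirId, "%02x".format(subDirId)))
  }
}
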
diff --git a/core/src/main/scala/spark/storage/MemoryStore.scala b/core/src/main/scala/spark/storage/MemoryStore.scala
new file mode 100644
index 0000000000..773970446a
--- /dev/null
+++ b/core/src/main/scala/spark/storage/MemoryStore.scala
@@ -0,0 +1,218 @@
+package spark.storage
+
+import java.util.LinkedHashMap
+import java.util.concurrent.ArrayBlockingQueue
+import spark.{SizeEstimator, Utils}
+import java.nio.ByteBuffer
+import collection.mutable.ArrayBuffer
+
+/**
+ * Stores blocks in memory, either as ArrayBuffers of deserialized Java objects or as
+ * serialized ByteBuffers.
+ */
+private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
+ extends BlockStore(blockManager) {
+
+ case class Entry(value: Any, size: Long, deserialized: Boolean, var dropPending: Boolean = false)
+
+ private val entries = new LinkedHashMap[String, Entry](32, 0.75f, true)
+ private var currentMemory = 0L
+
+ logInfo("MemoryStore started with capacity %s.".format(Utils.memoryBytesToString(maxMemory)))
+
+ def freeMemory: Long = maxMemory - currentMemory
+
+ override def getSize(blockId: String): Long = {
+ synchronized {
+ entries.get(blockId).size
+ }
+ }
+
+ override def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) {
+ if (level.deserialized) {
+ bytes.rewind()
+ val values = blockManager.dataDeserialize(blockId, bytes)
+ val elements = new ArrayBuffer[Any]
+ elements ++= values
+ val sizeEstimate = SizeEstimator.estimate(elements.asInstanceOf[AnyRef])
+ tryToPut(blockId, elements, sizeEstimate, true)
+ } else {
+ val entry = new Entry(bytes, bytes.limit, false)
+ ensureFreeSpace(blockId, bytes.limit)
+ synchronized { entries.put(blockId, entry) }
+ tryToPut(blockId, bytes, bytes.limit, false)
+ }
+ }
+
+ override def putValues(
+ blockId: String,
+ values: ArrayBuffer[Any],
+ level: StorageLevel,
+ returnValues: Boolean)
+ : PutResult = {
+
+ if (level.deserialized) {
+ val sizeEstimate = SizeEstimator.estimate(values.asInstanceOf[AnyRef])
+ tryToPut(blockId, values, sizeEstimate, true)
+ PutResult(sizeEstimate, Left(values.iterator))
+ } else {
+ val bytes = blockManager.dataSerialize(blockId, values.iterator)
+ tryToPut(blockId, bytes, bytes.limit, false)
+ PutResult(bytes.limit(), Right(bytes))
+ }
+ }
+
+ override def getBytes(blockId: String): Option[ByteBuffer] = {
+ val entry = synchronized {
+ entries.get(blockId)
+ }
+ if (entry == null) {
+ None
+ } else if (entry.deserialized) {
+ Some(blockManager.dataSerialize(blockId, entry.value.asInstanceOf[ArrayBuffer[Any]].iterator))
+ } else {
+ Some(entry.value.asInstanceOf[ByteBuffer].duplicate()) // Doesn't actually copy the data
+ }
+ }
+
+ override def getValues(blockId: String): Option[Iterator[Any]] = {
+ val entry = synchronized {
+ entries.get(blockId)
+ }
+ if (entry == null) {
+ None
+ } else if (entry.deserialized) {
+ Some(entry.value.asInstanceOf[ArrayBuffer[Any]].iterator)
+ } else {
+ val buffer = entry.value.asInstanceOf[ByteBuffer].duplicate() // Doesn't actually copy data
+ Some(blockManager.dataDeserialize(blockId, buffer))
+ }
+ }
+
+ override def remove(blockId: String) {
+ synchronized {
+ val entry = entries.get(blockId)
+ if (entry != null) {
+ entries.remove(blockId)
+ currentMemory -= entry.size
+ logInfo("Block %s of size %d dropped from memory (free %d)".format(
+ blockId, entry.size, freeMemory))
+ } else {
+ logWarning("Block " + blockId + " could not be removed as it does not exist")
+ }
+ }
+ }
+
+ override def clear() {
+ synchronized {
+ entries.clear()
+ }
+ logInfo("MemoryStore cleared")
+ }
+
+ /**
+ * Return the RDD ID that a given block ID is from, or null if it is not an RDD block.
+ */
+ private def getRddId(blockId: String): String = {
+ if (blockId.startsWith("rdd_")) {
+ blockId.split('_')(1)
+ } else {
+ null
+ }
+ }
+
+ /**
+ * Try to put in a set of values, if we can free up enough space. The value should either be
+ * an ArrayBuffer if deserialized is true or a ByteBuffer otherwise. Its (possibly estimated)
+ * size must also be passed by the caller.
+ */
+ private def tryToPut(blockId: String, value: Any, size: Long, deserialized: Boolean): Boolean = {
+ synchronized {
+ if (ensureFreeSpace(blockId, size)) {
+ val entry = new Entry(value, size, deserialized)
+ entries.put(blockId, entry)
+ currentMemory += size
+ if (deserialized) {
+ logInfo("Block %s stored as values to memory (estimated size %s, free %s)".format(
+ blockId, Utils.memoryBytesToString(size), Utils.memoryBytesToString(freeMemory)))
+ } else {
+ logInfo("Block %s stored as bytes to memory (size %s, free %s)".format(
+ blockId, Utils.memoryBytesToString(size), Utils.memoryBytesToString(freeMemory)))
+ }
+ true
+ } else {
+ // Tell the block manager that we couldn't put it in memory so that it can drop it to
+ // disk if the block allows disk storage.
+ val data = if (deserialized) {
+ Left(value.asInstanceOf[ArrayBuffer[Any]])
+ } else {
+ Right(value.asInstanceOf[ByteBuffer].duplicate())
+ }
+ blockManager.dropFromMemory(blockId, data)
+ false
+ }
+ }
+ }
+
+ /**
+ * Tries to free up the given amount of space to store a particular block. This can fail and
+ * return false if the block is bigger than our memory limit, or if making room would require
+ * dropping another block from the same RDD (which would lead to a wasteful cyclic replacement
+ * pattern for RDDs that don't fit in memory, something we want to avoid).
+ *
+ * Assumes that a lock on the MemoryStore is held by the caller. (Otherwise, the freed space
+ * might fill up before the caller puts in their new value.)
+ */
+ private def ensureFreeSpace(blockIdToAdd: String, space: Long): Boolean = {
+ logInfo("ensureFreeSpace(%d) called with curMem=%d, maxMem=%d".format(
+ space, currentMemory, maxMemory))
+
+ if (space > maxMemory) {
+ logInfo("Will not store " + blockIdToAdd + " as it is larger than our memory limit")
+ return false
+ }
+
+ // TODO: This should relinquish the lock on the MemoryStore while flushing out old blocks
+ // in order to allow parallelism in writing to disk
+ if (maxMemory - currentMemory < space) {
+ val rddToAdd = getRddId(blockIdToAdd)
+ val selectedBlocks = new ArrayBuffer[String]()
+ var selectedMemory = 0L
+
+ val iterator = entries.entrySet().iterator()
+ while (maxMemory - (currentMemory - selectedMemory) < space && iterator.hasNext) {
+ val pair = iterator.next()
+ val blockId = pair.getKey
+ if (rddToAdd != null && rddToAdd == getRddId(blockId)) {
+ logInfo("Will not store " + blockIdToAdd + " as it would require dropping another " +
+ "block from the same RDD")
+ return false
+ }
+ selectedBlocks += blockId
+ selectedMemory += pair.getValue.size
+ }
+
+ if (maxMemory - (currentMemory - selectedMemory) >= space) {
+ logInfo(selectedBlocks.size + " blocks selected for dropping")
+ for (blockId <- selectedBlocks) {
+ val entry = entries.get(blockId)
+ val data = if (entry.deserialized) {
+ Left(entry.value.asInstanceOf[ArrayBuffer[Any]])
+ } else {
+ Right(entry.value.asInstanceOf[ByteBuffer].duplicate())
+ }
+ blockManager.dropFromMemory(blockId, data)
+ }
+ return true
+ } else {
+ return false
+ }
+ }
+ return true
+ }
+
+ override def contains(blockId: String): Boolean = {
+ synchronized { entries.containsKey(blockId) }
+ }
+}
+
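
A short sketch (illustrative only) of the block-id convention that getRddId() and the eviction guard above rely on; the ids below are made-up examples.

object RddIdExample {
  def main(args: Array[String]) {
    // RDD blocks are named "rdd_<rddId>_<splitIndex>"; anything else returns null and
    // is therefore never protected by the same-RDD check in ensureFreeSpace.
    def getRddId(blockId: String): String =
      if (blockId.startsWith("rdd_")) blockId.split('_')(1) else null

    println(getRddId("rdd_12_3"))      // prints 12
    println(getRddId("shuffle_1_2_3")) // prints null
  }
}
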
diff --git a/core/src/main/scala/spark/storage/PutResult.scala b/core/src/main/scala/spark/storage/PutResult.scala
new file mode 100644
index 0000000000..76f236057b
--- /dev/null
+++ b/core/src/main/scala/spark/storage/PutResult.scala
@@ -0,0 +1,9 @@
+package spark.storage
+
+import java.nio.ByteBuffer
+
+/**
+ * Result of adding a block into a BlockStore. Contains its estimated size, and possibly the
+ * values put if the caller asked for them to be returned (e.g. for chaining replication)
+ */
+private[spark] case class PutResult(size: Long, data: Either[Iterator[_], ByteBuffer])
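
A minimal sketch (illustrative only) of consuming the contract above: data is a Left when the values are handed back as an iterator, a Right when they come back as serialized bytes, and may be null when the caller did not ask for the values to be returned.

package spark.storage

object PutResultExample {
  def describe(result: PutResult): String = result.data match {
    case Left(values) => "iterator returned, estimated size " + result.size
    case Right(bytes) => bytes.limit + " serialized bytes returned"
    case null         => "nothing returned, size " + result.size
  }
}
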
diff --git a/core/src/main/scala/spark/storage/StorageLevel.scala b/core/src/main/scala/spark/storage/StorageLevel.scala
index a64393eba7..c497f03e0c 100644
--- a/core/src/main/scala/spark/storage/StorageLevel.scala
+++ b/core/src/main/scala/spark/storage/StorageLevel.scala
@@ -1,7 +1,14 @@
package spark.storage
-import java.io._
+import java.io.{Externalizable, ObjectInput, ObjectOutput}
+/**
+ * Flags for controlling the storage of an RDD. Each StorageLevel records whether to use memory,
+ * whether to drop the RDD to disk if it falls out of memory, whether to keep the data in memory
+ * in a serialized format, and whether to replicate the RDD partitions on multiple nodes.
+ * The [[spark.storage.StorageLevel$]] singleton object contains some static constants for
+ * commonly useful storage levels.
+ */
class StorageLevel(
var useDisk: Boolean,
var useMemory: Boolean,
@@ -67,12 +74,12 @@ object StorageLevel {
val NONE = new StorageLevel(false, false, false)
val DISK_ONLY = new StorageLevel(true, false, false)
val DISK_ONLY_2 = new StorageLevel(true, false, false, 2)
- val MEMORY_ONLY = new StorageLevel(false, true, false)
- val MEMORY_ONLY_2 = new StorageLevel(false, true, false, 2)
- val MEMORY_ONLY_DESER = new StorageLevel(false, true, true)
- val MEMORY_ONLY_DESER_2 = new StorageLevel(false, true, true, 2)
- val DISK_AND_MEMORY = new StorageLevel(true, true, false)
- val DISK_AND_MEMORY_2 = new StorageLevel(true, true, false, 2)
- val DISK_AND_MEMORY_DESER = new StorageLevel(true, true, true)
- val DISK_AND_MEMORY_DESER_2 = new StorageLevel(true, true, true, 2)
+ val MEMORY_ONLY = new StorageLevel(false, true, true)
+ val MEMORY_ONLY_2 = new StorageLevel(false, true, true, 2)
+ val MEMORY_ONLY_SER = new StorageLevel(false, true, false)
+ val MEMORY_ONLY_SER_2 = new StorageLevel(false, true, false, 2)
+ val MEMORY_AND_DISK = new StorageLevel(true, true, true)
+ val MEMORY_AND_DISK_2 = new StorageLevel(true, true, true, 2)
+ val MEMORY_AND_DISK_SER = new StorageLevel(true, true, false)
+ val MEMORY_AND_DISK_SER_2 = new StorageLevel(true, true, false, 2)
}
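
For reference, a small sketch (not part of the patch) printing what the renamed constants expand to. Note that MEMORY_ONLY now stores deserialized objects (the old MEMORY_ONLY_DESER) and the serialized variants carry the _SER suffix.

import spark.storage.StorageLevel

object StorageLevelExample {
  def main(args: Array[String]) {
    val levels = Seq(
      "MEMORY_ONLY"           -> StorageLevel.MEMORY_ONLY,
      "MEMORY_ONLY_SER"       -> StorageLevel.MEMORY_ONLY_SER,
      "MEMORY_AND_DISK_SER_2" -> StorageLevel.MEMORY_AND_DISK_SER_2)
    for ((name, level) <- levels) {
      println("%s: useDisk=%s useMemory=%s deserialized=%s replication=%d".format(
        name, level.useDisk, level.useMemory, level.deserialized, level.replication))
    }
  }
}
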
diff --git a/core/src/main/scala/spark/util/AkkaUtils.scala b/core/src/main/scala/spark/util/AkkaUtils.scala
index df4e23bfd6..b466b5239c 100644
--- a/core/src/main/scala/spark/util/AkkaUtils.scala
+++ b/core/src/main/scala/spark/util/AkkaUtils.scala
@@ -17,12 +17,14 @@ import java.util.concurrent.TimeoutException
/**
* Various utility classes for working with Akka.
*/
-object AkkaUtils {
+private[spark] object AkkaUtils {
/**
* Creates an ActorSystem ready for remoting, with various Spark features. Returns both the
* ActorSystem itself and its port (which is hard to get from Akka).
*/
def createActorSystem(name: String, host: String, port: Int): (ActorSystem, Int) = {
+ val akkaThreads = System.getProperty("spark.akka.threads", "4").toInt
+ val akkaBatchSize = System.getProperty("spark.akka.batchSize", "15").toInt
val akkaConf = ConfigFactory.parseString("""
akka.daemonic = on
akka.event-handlers = ["akka.event.slf4j.Slf4jEventHandler"]
@@ -31,9 +33,9 @@ object AkkaUtils {
akka.remote.netty.hostname = "%s"
akka.remote.netty.port = %d
akka.remote.netty.connection-timeout = 1s
- akka.remote.netty.execution-pool-size = 8
- akka.actor.default-dispatcher.throughput = 30
- """.format(host, port))
+ akka.remote.netty.execution-pool-size = %d
+ akka.actor.default-dispatcher.throughput = %d
+ """.format(host, port, akkaThreads, akkaBatchSize))
val actorSystem = ActorSystem("spark", akkaConf, getClass.getClassLoader)
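
Since the new pool size and dispatcher throughput are read from plain system properties, they can be tuned without code changes; a tiny sketch (illustrative only) of setting them before the actor system is created:

object AkkaTuningExample {
  def main(args: Array[String]) {
    // Must be set before AkkaUtils.createActorSystem runs, i.e. before SparkEnv is created.
    System.setProperty("spark.akka.threads", "8")     // akka.remote.netty.execution-pool-size
    System.setProperty("spark.akka.batchSize", "30")  // akka.actor.default-dispatcher.throughput
  }
}
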
diff --git a/core/src/main/scala/spark/util/ByteBufferInputStream.scala b/core/src/main/scala/spark/util/ByteBufferInputStream.scala
index 0ce255105a..d7e67497fe 100644
--- a/core/src/main/scala/spark/util/ByteBufferInputStream.scala
+++ b/core/src/main/scala/spark/util/ByteBufferInputStream.scala
@@ -2,10 +2,19 @@ package spark.util
import java.io.InputStream
import java.nio.ByteBuffer
+import spark.storage.BlockManager
+
+/**
+ * Reads data from a ByteBuffer, and optionally cleans it up using BlockManager.dispose()
+ * at the end of the stream (e.g. to close a memory-mapped file).
+ */
+private[spark]
+class ByteBufferInputStream(private var buffer: ByteBuffer, dispose: Boolean = false)
+ extends InputStream {
-class ByteBufferInputStream(buffer: ByteBuffer) extends InputStream {
override def read(): Int = {
- if (buffer.remaining() == 0) {
+ if (buffer == null || buffer.remaining() == 0) {
+ cleanUp()
-1
} else {
buffer.get() & 0xFF
@@ -17,7 +26,8 @@ class ByteBufferInputStream(buffer: ByteBuffer) extends InputStream {
}
override def read(dest: Array[Byte], offset: Int, length: Int): Int = {
- if (buffer.remaining() == 0) {
+ if (buffer == null || buffer.remaining() == 0) {
+ cleanUp()
-1
} else {
val amountToGet = math.min(buffer.remaining(), length)
@@ -27,8 +37,27 @@ class ByteBufferInputStream(buffer: ByteBuffer) extends InputStream {
}
override def skip(bytes: Long): Long = {
- val amountToSkip = math.min(bytes, buffer.remaining).toInt
- buffer.position(buffer.position + amountToSkip)
- return amountToSkip
+ if (buffer != null) {
+ val amountToSkip = math.min(bytes, buffer.remaining).toInt
+ buffer.position(buffer.position + amountToSkip)
+ if (buffer.remaining() == 0) {
+ cleanUp()
+ }
+ amountToSkip
+ } else {
+ 0L
+ }
+ }
+
+ /**
+ * Clean up the buffer, and potentially dispose of it using BlockManager.dispose().
+ */
+ private def cleanUp() {
+ if (buffer != null) {
+ if (dispose) {
+ BlockManager.dispose(buffer)
+ }
+ buffer = null
+ }
}
}
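
A brief sketch (not part of the patch) of the intended use of the dispose flag: wrap a memory-mapped buffer so it is unmapped as soon as the stream hits end-of-file. The file path is made up, and the object sits in package spark.util only to reach the private[spark] class.

package spark.util

import java.io.RandomAccessFile
import java.nio.channels.FileChannel.MapMode

object ByteBufferInputStreamExample {
  def main(args: Array[String]) {
    val channel = new RandomAccessFile("/tmp/some-block-file", "r").getChannel()
    val mapped = channel.map(MapMode.READ_ONLY, 0, channel.size())
    channel.close()
    val in = new ByteBufferInputStream(mapped, true) // dispose = true
    while (in.read() != -1) {}                       // cleanUp() unmaps the buffer at EOF
  }
}
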
diff --git a/core/src/main/scala/spark/util/IntParam.scala b/core/src/main/scala/spark/util/IntParam.scala
index c3ff063569..0427646747 100644
--- a/core/src/main/scala/spark/util/IntParam.scala
+++ b/core/src/main/scala/spark/util/IntParam.scala
@@ -3,7 +3,7 @@ package spark.util
/**
* An extractor object for parsing strings into integers.
*/
-object IntParam {
+private[spark] object IntParam {
def unapply(str: String): Option[Int] = {
try {
Some(str.toInt)
diff --git a/core/src/main/scala/spark/util/MemoryParam.scala b/core/src/main/scala/spark/util/MemoryParam.scala
index 4fba914afe..3726738842 100644
--- a/core/src/main/scala/spark/util/MemoryParam.scala
+++ b/core/src/main/scala/spark/util/MemoryParam.scala
@@ -6,7 +6,7 @@ import spark.Utils
* An extractor object for parsing JVM memory strings, such as "10g", into an Int representing
* the number of megabytes. Supports the same formats as Utils.memoryStringToMb.
*/
-object MemoryParam {
+private[spark] object MemoryParam {
def unapply(str: String): Option[Int] = {
try {
Some(Utils.memoryStringToMb(str))
diff --git a/core/src/main/scala/spark/util/SerializableBuffer.scala b/core/src/main/scala/spark/util/SerializableBuffer.scala
index 0830843a77..09d588fe1c 100644
--- a/core/src/main/scala/spark/util/SerializableBuffer.scala
+++ b/core/src/main/scala/spark/util/SerializableBuffer.scala
@@ -8,6 +8,7 @@ import java.nio.channels.Channels
* A wrapper around a java.nio.ByteBuffer that is serializable through Java serialization, to make
* it easier to pass ByteBuffers in case class messages.
*/
+private[spark]
class SerializableBuffer(@transient var buffer: ByteBuffer) extends Serializable {
def value = buffer
diff --git a/core/src/main/scala/spark/util/StatCounter.scala b/core/src/main/scala/spark/util/StatCounter.scala
index 11d7939204..5f80180339 100644
--- a/core/src/main/scala/spark/util/StatCounter.scala
+++ b/core/src/main/scala/spark/util/StatCounter.scala
@@ -2,8 +2,10 @@ package spark.util
/**
* A class for tracking the statistics of a set of numbers (count, mean and variance) in a
- * numerically robust way. Includes support for merging two StatCounters. Based on Welford and
- * Chan's algorithms described at http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance.
+ * numerically robust way. Includes support for merging two StatCounters. Based on
+ * [[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance Welford and Chan's algorithms for running variance]].
+ *
+ * @constructor Initialize the StatCounter with the given values.
*/
class StatCounter(values: TraversableOnce[Double]) extends Serializable {
private var n: Long = 0 // Running count of our values
@@ -12,8 +14,10 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable {
merge(values)
+ /** Initialize the StatCounter with no values. */
def this() = this(Nil)
+ /** Add a value into this StatCounter, updating the internal statistics. */
def merge(value: Double): StatCounter = {
val delta = value - mu
n += 1
@@ -22,11 +26,13 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable {
this
}
+ /** Add multiple values into this StatCounter, updating the internal statistics. */
def merge(values: TraversableOnce[Double]): StatCounter = {
values.foreach(v => merge(v))
this
}
+ /** Merge another StatCounter into this one, adding up the internal statistics. */
def merge(other: StatCounter): StatCounter = {
if (other == this) {
merge(other.copy()) // Avoid overwriting fields in a weird order
@@ -45,6 +51,7 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable {
}
}
+ /** Clone this StatCounter */
def copy(): StatCounter = {
val other = new StatCounter
other.n = n
@@ -59,6 +66,7 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable {
def sum: Double = n * mu
+ /** Return the variance of the values. */
def variance: Double = {
if (n == 0)
Double.NaN
@@ -66,6 +74,10 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable {
m2 / n
}
+ /**
+ * Return the sample variance, which corrects for bias in estimating the variance by dividing
+ * by N-1 instead of N.
+ */
def sampleVariance: Double = {
if (n <= 1)
Double.NaN
@@ -73,8 +85,13 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable {
m2 / (n - 1)
}
+ /** Return the standard deviation of the values. */
def stdev: Double = math.sqrt(variance)
+ /**
+ * Return the sample standard deviation of the values, which corrects for bias in estimating the
+ * variance by dividing by N-1 instead of N.
+ */
def sampleStdev: Double = math.sqrt(sampleVariance)
override def toString: String = {
@@ -83,7 +100,9 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable {
}
object StatCounter {
+ /** Build a StatCounter from a list of values. */
def apply(values: TraversableOnce[Double]) = new StatCounter(values)
+ /** Build a StatCounter from a list of values passed as variable-length arguments. */
def apply(values: Double*) = new StatCounter(values)
}
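
As a usage sketch (illustrative only), merging two StatCounters yields the same statistics as feeding all the values into a single counter, which is what makes the class suitable for per-partition aggregation followed by a reduce. The count and mean accessors used below are existing members not shown in this hunk.

import spark.util.StatCounter

object StatCounterExample {
  def main(args: Array[String]) {
    val left  = StatCounter(1.0, 2.0, 3.0)
    val right = StatCounter(10.0, 20.0)
    val merged = left.merge(right)   // mutates `left` and returns it
    println("count=%d mean=%f stdev=%f".format(merged.count, merged.mean, merged.stdev))
    // Equivalent to StatCounter(1.0, 2.0, 3.0, 10.0, 20.0)
  }
}
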
diff --git a/core/src/main/twirl/common/layout.scala.html b/core/src/main/twirl/spark/deploy/common/layout.scala.html
index b9192060aa..b9192060aa 100644
--- a/core/src/main/twirl/common/layout.scala.html
+++ b/core/src/main/twirl/spark/deploy/common/layout.scala.html
diff --git a/core/src/main/twirl/masterui/executor_row.scala.html b/core/src/main/twirl/spark/deploy/master/executor_row.scala.html
index 784d692fc2..784d692fc2 100644
--- a/core/src/main/twirl/masterui/executor_row.scala.html
+++ b/core/src/main/twirl/spark/deploy/master/executor_row.scala.html
diff --git a/core/src/main/twirl/masterui/executors_table.scala.html b/core/src/main/twirl/spark/deploy/master/executors_table.scala.html
index cafc42c80e..cafc42c80e 100644
--- a/core/src/main/twirl/masterui/executors_table.scala.html
+++ b/core/src/main/twirl/spark/deploy/master/executors_table.scala.html
diff --git a/core/src/main/twirl/masterui/index.scala.html b/core/src/main/twirl/spark/deploy/master/index.scala.html
index 31ca8f4132..7562076b00 100644
--- a/core/src/main/twirl/masterui/index.scala.html
+++ b/core/src/main/twirl/spark/deploy/master/index.scala.html
@@ -1,13 +1,13 @@
@(state: spark.deploy.MasterState)
@import spark.deploy.master._
-@common.html.layout(title = "Spark Master on " + state.uri) {
+@spark.deploy.common.html.layout(title = "Spark Master on " + state.uri) {
<!-- Cluster Details -->
<div class="row">
<div class="span12">
<ul class="unstyled">
- <li><strong>URI:</strong> spark://@(state.uri)</li>
+ <li><strong>URL:</strong> spark://@(state.uri)</li>
<li><strong>Number of Workers:</strong> @state.workers.size </li>
<li><strong>Cores:</strong> @state.workers.map(_.cores).sum Total, @state.workers.map(_.coresUsed).sum Used</li>
<li><strong>Memory:</strong> @state.workers.map(_.memory).sum Total, @state.workers.map(_.memoryUsed).sum Used</li>
@@ -47,4 +47,4 @@
</div>
</div>
-}
\ No newline at end of file
+}
diff --git a/core/src/main/twirl/masterui/job_details.scala.html b/core/src/main/twirl/spark/deploy/master/job_details.scala.html
index 73cefb8269..dcf41c28f2 100644
--- a/core/src/main/twirl/masterui/job_details.scala.html
+++ b/core/src/main/twirl/spark/deploy/master/job_details.scala.html
@@ -1,6 +1,6 @@
@(job: spark.deploy.master.JobInfo)
-@common.html.layout(title = "Job Details") {
+@spark.deploy.common.html.layout(title = "Job Details") {
<!-- Job Details -->
<div class="row">
@@ -37,4 +37,4 @@
</div>
</div>
-}
\ No newline at end of file
+}
diff --git a/core/src/main/twirl/masterui/job_row.scala.html b/core/src/main/twirl/spark/deploy/master/job_row.scala.html
index 7c4865bb6e..7c4865bb6e 100644
--- a/core/src/main/twirl/masterui/job_row.scala.html
+++ b/core/src/main/twirl/spark/deploy/master/job_row.scala.html
diff --git a/core/src/main/twirl/masterui/job_table.scala.html b/core/src/main/twirl/spark/deploy/master/job_table.scala.html
index 52bad6c4b8..52bad6c4b8 100644
--- a/core/src/main/twirl/masterui/job_table.scala.html
+++ b/core/src/main/twirl/spark/deploy/master/job_table.scala.html
diff --git a/core/src/main/twirl/masterui/worker_row.scala.html b/core/src/main/twirl/spark/deploy/master/worker_row.scala.html
index b21bd9c977..017cc4859e 100644
--- a/core/src/main/twirl/masterui/worker_row.scala.html
+++ b/core/src/main/twirl/spark/deploy/master/worker_row.scala.html
@@ -4,7 +4,8 @@
<td>
      <a href="http://@worker.host:@worker.webUiPort">@worker.id</a>
</td>
- <td>@worker.host:@worker.port</td>
+ <td>@{worker.host}:@{worker.port}</td>
<td>@worker.cores (@worker.coresUsed Used)</td>
- <td>@worker.memory (@worker.memoryUsed Used)</td>
-</tr>
\ No newline at end of file
+ <td>@{spark.Utils.memoryMegabytesToString(worker.memory)}
+ (@{spark.Utils.memoryMegabytesToString(worker.memoryUsed)} Used)</td>
+</tr>
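
This hunk (and the worker index page below) now renders memory through spark.Utils.memoryMegabytesToString instead of printing the raw megabyte count. As a rough, hypothetical illustration of what such a formatter does (assumed behaviour, not the actual Utils implementation):

    // Hypothetical stand-in for spark.Utils.memoryMegabytesToString (assumed
    // behaviour, not the actual Utils code): render a raw megabyte count as a
    // human-readable string.
    object MemoryFormatSketch {
      def memoryMegabytesToString(megabytes: Long): String =
        if (megabytes >= 1024) "%.1f GB".format(megabytes / 1024.0)
        else "%d MB".format(megabytes)

      def main(args: Array[String]): Unit = {
        println(memoryMegabytesToString(512L))   // "512 MB"
        println(memoryMegabytesToString(6144L))  // "6.0 GB"
      }
    }
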
diff --git a/core/src/main/twirl/masterui/worker_table.scala.html b/core/src/main/twirl/spark/deploy/master/worker_table.scala.html
index 2028842297..2028842297 100644
--- a/core/src/main/twirl/masterui/worker_table.scala.html
+++ b/core/src/main/twirl/spark/deploy/master/worker_table.scala.html
diff --git a/core/src/main/twirl/workerui/executor_row.scala.html b/core/src/main/twirl/spark/deploy/worker/executor_row.scala.html
index c3842dbf85..c3842dbf85 100644
--- a/core/src/main/twirl/workerui/executor_row.scala.html
+++ b/core/src/main/twirl/spark/deploy/worker/executor_row.scala.html
diff --git a/core/src/main/twirl/workerui/executors_table.scala.html b/core/src/main/twirl/spark/deploy/worker/executors_table.scala.html
index 327a2399c7..327a2399c7 100644
--- a/core/src/main/twirl/workerui/executors_table.scala.html
+++ b/core/src/main/twirl/spark/deploy/worker/executors_table.scala.html
diff --git a/core/src/main/twirl/workerui/index.scala.html b/core/src/main/twirl/spark/deploy/worker/index.scala.html
index edd82e02f2..69746ed02c 100644
--- a/core/src/main/twirl/workerui/index.scala.html
+++ b/core/src/main/twirl/spark/deploy/worker/index.scala.html
@@ -1,6 +1,6 @@
@(worker: spark.deploy.WorkerState)
-@common.html.layout(title = "Spark Worker on " + worker.uri) {
+@spark.deploy.common.html.layout(title = "Spark Worker on " + worker.uri) {
<!-- Worker Details -->
<div class="row">
@@ -12,7 +12,8 @@
(WebUI at <a href="@worker.masterWebUiUrl">@worker.masterWebUiUrl</a>)
</li>
<li><strong>Cores:</strong> @worker.cores (@worker.coresUsed Used)</li>
- <li><strong>Memory:</strong> @worker.memory (@worker.memoryUsed Used)</li>
+ <li><strong>Memory:</strong> @{spark.Utils.memoryMegabytesToString(worker.memory)}
+ (@{spark.Utils.memoryMegabytesToString(worker.memoryUsed)} Used)</li>
</ul>
</div>
</div>
@@ -39,4 +40,4 @@
</div>
</div>
-}
\ No newline at end of file
+}