path: root/core
author     Reynold Xin <rxin@apache.org>  2014-01-13 16:21:26 -0800
committer  Reynold Xin <rxin@apache.org>  2014-01-13 16:21:26 -0800
commit     e2d25d2dfeb1d43d1e36f169250d8efef4ac232a (patch)
tree       d911a37f5aacc89bc3a1c76d41842e1c156aec6a /core
parent     8038da232870fe016e73122a2ef110ac8e56ca1e (diff)
parent     b93f9d42f21f03163734ef97b2871db945e166da (diff)
Merge branch 'master' into graphx
Diffstat (limited to 'core')
-rw-r--r--  core/pom.xml | 10
-rw-r--r--  core/src/main/resources/org/apache/spark/log4j-defaults.properties | 5
-rw-r--r--  core/src/main/scala/org/apache/spark/Accumulators.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/Aggregator.scala | 61
-rw-r--r--  core/src/main/scala/org/apache/spark/CacheManager.scala | 4
-rw-r--r--  core/src/main/scala/org/apache/spark/HttpFileServer.scala | 6
-rw-r--r--  core/src/main/scala/org/apache/spark/Logging.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/MapOutputTracker.scala | 12
-rw-r--r--  core/src/main/scala/org/apache/spark/Partitioner.scala | 11
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkContext.scala | 169
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkEnv.scala | 18
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala | 15
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala | 12
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala | 6
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala | 6
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala | 6
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala | 54
-rw-r--r--  core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala | 7
-rw-r--r--  core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala | 4
-rw-r--r--  core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala | 8
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/Client.scala | 151
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala | 117
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala | 52
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/DriverDescription.scala | 29
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala (renamed from core/src/main/scala/org/apache/spark/deploy/client/Client.scala) | 13
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/client/AppClientListener.scala (renamed from core/src/main/scala/org/apache/spark/deploy/client/ClientListener.scala) | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala | 4
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala | 36
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/DriverState.scala | 33
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala | 17
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/Master.scala | 195
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala | 11
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala | 20
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala | 14
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala | 4
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala | 56
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala | 63
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala | 234
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala | 31
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala | 67
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala | 65
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala | 55
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala | 65
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala | 43
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala | 28
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/Executor.scala | 7
-rw-r--r--  core/src/main/scala/org/apache/spark/io/CompressionCodec.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/network/BufferMessage.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/network/Connection.scala | 6
-rw-r--r--  core/src/main/scala/org/apache/spark/network/ConnectionManager.scala | 18
-rw-r--r--  core/src/main/scala/org/apache/spark/network/Message.scala | 6
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala | 4
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala | 88
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala | 10
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala | 45
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala | 36
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala | 59
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/RDD.scala | 36
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala | 8
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala | 8
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/Pool.scala | 8
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala | 11
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala | 3
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala | 17
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/Stage.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala | 4
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala | 8
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala | 22
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala | 7
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala | 10
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala | 4
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala | 8
-rw-r--r--  core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala | 4
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockId.scala | 12
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManager.scala | 28
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala | 4
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala | 7
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala | 20
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockMessage.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala | 4
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala | 13
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/MemoryStore.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala | 4
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/StorageLevel.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/util/AkkaUtils.scala | 20
-rw-r--r--  core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala | 10
-rw-r--r--  core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/util/SizeEstimator.scala | 10
-rw-r--r--  core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala | 17
-rw-r--r--  core/src/main/scala/org/apache/spark/util/Utils.scala | 64
-rw-r--r--  core/src/main/scala/org/apache/spark/util/Vector.scala | 20
-rw-r--r--  core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala (renamed from core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala) | 118
-rw-r--r--  core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala | 350
-rw-r--r--  core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala | 101
-rw-r--r--  core/src/test/scala/org/apache/spark/LocalSparkContext.scala | 1
-rw-r--r--  core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala | 1
-rw-r--r--  core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala | 40
-rw-r--r--  core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala | 131
-rw-r--r--  core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala | 4
-rw-r--r--  core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala | 32
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala | 9
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala | 3
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala | 7
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala | 2
-rw-r--r--  core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala | 4
-rw-r--r--  core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala | 14
-rw-r--r--  core/src/test/scala/org/apache/spark/util/SizeTrackingAppendOnlyMapSuite.scala | 120
-rw-r--r--  core/src/test/scala/org/apache/spark/util/VectorSuite.scala | 44
-rw-r--r--  core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala | 2
-rw-r--r--  core/src/test/scala/org/apache/spark/util/collection/AppendOnlyMapSuite.scala (renamed from core/src/test/scala/org/apache/spark/util/AppendOnlyMapSuite.scala) | 46
-rw-r--r--  core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala | 230
122 files changed, 3167 insertions, 621 deletions
diff --git a/core/pom.xml b/core/pom.xml
index aac0a9d11e..9e5a450d57 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -99,6 +99,11 @@
<artifactId>akka-slf4j_${scala.binary.version}</artifactId>
</dependency>
<dependency>
+ <groupId>${akka.group}</groupId>
+ <artifactId>akka-testkit_${scala.binary.version}</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
<groupId>org.scala-lang</groupId>
<artifactId>scala-library</artifactId>
</dependency>
@@ -166,6 +171,11 @@
<scope>test</scope>
</dependency>
<dependency>
+ <groupId>org.mockito</groupId>
+ <artifactId>mockito-all</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
<groupId>org.scalacheck</groupId>
<artifactId>scalacheck_${scala.binary.version}</artifactId>
<scope>test</scope>
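Note: the two new test-scoped dependencies above (akka-testkit and mockito-all) back the worker tests this merge adds (e.g. DriverRunnerTest, WorkerWatcherSuite). A minimal sketch of how mockito-all is typically used from ScalaTest; the helper under test here is hypothetical:

```scala
import org.mockito.Mockito.{mock, times, verify}
import org.scalatest.FunSuite

class RetrySuite extends FunSuite {
  // Hypothetical helper under test: invoke a task a fixed number of times.
  private def runTimes(task: Runnable, n: Int) { (1 to n).foreach(_ => task.run()) }

  test("runTimes invokes the task the requested number of times") {
    val task = mock(classOf[Runnable])
    runTimes(task, 3)
    verify(task, times(3)).run()
  }
}
```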
diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults.properties b/core/src/main/resources/org/apache/spark/log4j-defaults.properties
index d72dbadc39..f7f8535594 100644
--- a/core/src/main/resources/org/apache/spark/log4j-defaults.properties
+++ b/core/src/main/resources/org/apache/spark/log4j-defaults.properties
@@ -1,8 +1,11 @@
# Set everything to be logged to the console
log4j.rootCategory=INFO, console
log4j.appender.console=org.apache.log4j.ConsoleAppender
+log4j.appender.console.target=System.err
log4j.appender.console.layout=org.apache.log4j.PatternLayout
log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n
-# Ignore messages below warning level from Jetty, because it's a bit verbose
+# Settings to quiet third party logs that are too verbose
log4j.logger.org.eclipse.jetty=WARN
+log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO
+log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO
diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala
index 5f73d234aa..e89ac28b8e 100644
--- a/core/src/main/scala/org/apache/spark/Accumulators.scala
+++ b/core/src/main/scala/org/apache/spark/Accumulators.scala
@@ -218,7 +218,7 @@ private object Accumulators {
def newId: Long = synchronized {
lastId += 1
- return lastId
+ lastId
}
def register(a: Accumulable[_, _], original: Boolean): Unit = synchronized {
diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala
index 1a2ec55876..8b30cd4bfe 100644
--- a/core/src/main/scala/org/apache/spark/Aggregator.scala
+++ b/core/src/main/scala/org/apache/spark/Aggregator.scala
@@ -17,7 +17,7 @@
package org.apache.spark
-import org.apache.spark.util.AppendOnlyMap
+import org.apache.spark.util.collection.{AppendOnlyMap, ExternalAppendOnlyMap}
/**
* A set of functions used to aggregate data.
@@ -31,30 +31,51 @@ case class Aggregator[K, V, C] (
mergeValue: (C, V) => C,
mergeCombiners: (C, C) => C) {
+ private val sparkConf = SparkEnv.get.conf
+ private val externalSorting = sparkConf.getBoolean("spark.shuffle.externalSorting", true)
+
def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]]) : Iterator[(K, C)] = {
- val combiners = new AppendOnlyMap[K, C]
- var kv: Product2[K, V] = null
- val update = (hadValue: Boolean, oldValue: C) => {
- if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
- }
- while (iter.hasNext) {
- kv = iter.next()
- combiners.changeValue(kv._1, update)
+ if (!externalSorting) {
+ val combiners = new AppendOnlyMap[K,C]
+ var kv: Product2[K, V] = null
+ val update = (hadValue: Boolean, oldValue: C) => {
+ if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
+ }
+ while (iter.hasNext) {
+ kv = iter.next()
+ combiners.changeValue(kv._1, update)
+ }
+ combiners.iterator
+ } else {
+ val combiners =
+ new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners)
+ while (iter.hasNext) {
+ val (k, v) = iter.next()
+ combiners.insert(k, v)
+ }
+ combiners.iterator
}
- combiners.iterator
}
def combineCombinersByKey(iter: Iterator[(K, C)]) : Iterator[(K, C)] = {
- val combiners = new AppendOnlyMap[K, C]
- var kc: (K, C) = null
- val update = (hadValue: Boolean, oldValue: C) => {
- if (hadValue) mergeCombiners(oldValue, kc._2) else kc._2
+ if (!externalSorting) {
+ val combiners = new AppendOnlyMap[K,C]
+ var kc: Product2[K, C] = null
+ val update = (hadValue: Boolean, oldValue: C) => {
+ if (hadValue) mergeCombiners(oldValue, kc._2) else kc._2
+ }
+ while (iter.hasNext) {
+ kc = iter.next()
+ combiners.changeValue(kc._1, update)
+ }
+ combiners.iterator
+ } else {
+ val combiners = new ExternalAppendOnlyMap[K, C, C](identity, mergeCombiners, mergeCombiners)
+ while (iter.hasNext) {
+ val (k, c) = iter.next()
+ combiners.insert(k, c)
+ }
+ combiners.iterator
}
- while (iter.hasNext) {
- kc = iter.next()
- combiners.changeValue(kc._1, update)
- }
- combiners.iterator
}
}
-
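Note: the Aggregator change above routes combiners through the new ExternalAppendOnlyMap whenever spark.shuffle.externalSorting is true (the default), falling back to the in-memory AppendOnlyMap otherwise. A minimal sketch of exercising that path through combineByKey on a local SparkContext; master, app name, and sample data are illustrative:

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._

object ExternalSortingSketch {
  def main(args: Array[String]) {
    // spark.shuffle.externalSorting is the flag introduced in this diff (default true).
    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("external-sorting-sketch")
      .set("spark.shuffle.externalSorting", "true")
    val sc = new SparkContext(conf)
    // combineByKey runs through Aggregator.combineValuesByKey, which now builds
    // a spillable ExternalAppendOnlyMap instead of an in-memory AppendOnlyMap.
    val counts = sc.parallelize(Seq("a" -> 1, "b" -> 1, "a" -> 1))
      .combineByKey((v: Int) => v, (c: Int, v: Int) => c + v, (c1: Int, c2: Int) => c1 + c2)
    println(counts.collect().toSeq)
    sc.stop()
  }
}
```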
diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala
index 519ecde50a..8e5dd8a850 100644
--- a/core/src/main/scala/org/apache/spark/CacheManager.scala
+++ b/core/src/main/scala/org/apache/spark/CacheManager.scala
@@ -38,7 +38,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
blockManager.get(key) match {
case Some(values) =>
// Partition is already materialized, so just return its values
- return new InterruptibleIterator(context, values.asInstanceOf[Iterator[T]])
+ new InterruptibleIterator(context, values.asInstanceOf[Iterator[T]])
case None =>
// Mark the split as loading (unless someone else marks it first)
@@ -74,7 +74,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
val elements = new ArrayBuffer[Any]
elements ++= computedValues
blockManager.put(key, elements, storageLevel, tellMaster = true)
- return elements.iterator.asInstanceOf[Iterator[T]]
+ elements.iterator.asInstanceOf[Iterator[T]]
} finally {
loading.synchronized {
loading.remove(key)
diff --git a/core/src/main/scala/org/apache/spark/HttpFileServer.scala b/core/src/main/scala/org/apache/spark/HttpFileServer.scala
index ad1ee20045..a885898ad4 100644
--- a/core/src/main/scala/org/apache/spark/HttpFileServer.scala
+++ b/core/src/main/scala/org/apache/spark/HttpFileServer.scala
@@ -47,17 +47,17 @@ private[spark] class HttpFileServer extends Logging {
def addFile(file: File) : String = {
addFileToDir(file, fileDir)
- return serverUri + "/files/" + file.getName
+ serverUri + "/files/" + file.getName
}
def addJar(file: File) : String = {
addFileToDir(file, jarDir)
- return serverUri + "/jars/" + file.getName
+ serverUri + "/jars/" + file.getName
}
def addFileToDir(file: File, dir: File) : String = {
Files.copy(file, new File(dir, file.getName))
- return dir + "/" + file.getName
+ dir + "/" + file.getName
}
}
diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala
index 4a34989e50..9063cae87e 100644
--- a/core/src/main/scala/org/apache/spark/Logging.scala
+++ b/core/src/main/scala/org/apache/spark/Logging.scala
@@ -41,7 +41,7 @@ trait Logging {
}
log_ = LoggerFactory.getLogger(className)
}
- return log_
+ log_
}
// Log methods that take only a String
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 77b8ca1cce..30d182b008 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -32,15 +32,16 @@ import org.apache.spark.storage.BlockManagerId
import org.apache.spark.util.{AkkaUtils, MetadataCleaner, MetadataCleanerType, TimeStampedHashMap, Utils}
private[spark] sealed trait MapOutputTrackerMessage
-private[spark] case class GetMapOutputStatuses(shuffleId: Int, requester: String)
+private[spark] case class GetMapOutputStatuses(shuffleId: Int)
extends MapOutputTrackerMessage
private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage
private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster)
extends Actor with Logging {
def receive = {
- case GetMapOutputStatuses(shuffleId: Int, requester: String) =>
- logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + requester)
+ case GetMapOutputStatuses(shuffleId: Int) =>
+ val hostPort = sender.path.address.hostPort
+ logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort)
sender ! tracker.getSerializedMapOutputStatuses(shuffleId)
case StopMapOutputTracker =>
@@ -119,11 +120,10 @@ private[spark] class MapOutputTracker(conf: SparkConf) extends Logging {
if (fetchedStatuses == null) {
// We won the race to fetch the output locs; do so
logInfo("Doing the fetch; tracker actor = " + trackerActor)
- val hostPort = Utils.localHostPort(conf)
// This try-finally prevents hangs due to timeouts:
try {
val fetchedBytes =
- askTracker(GetMapOutputStatuses(shuffleId, hostPort)).asInstanceOf[Array[Byte]]
+ askTracker(GetMapOutputStatuses(shuffleId)).asInstanceOf[Array[Byte]]
fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes)
logInfo("Got the output locations")
mapStatuses.put(shuffleId, fetchedStatuses)
@@ -139,7 +139,7 @@ private[spark] class MapOutputTracker(conf: SparkConf) extends Logging {
return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses)
}
}
- else{
+ else {
throw new FetchFailedException(null, shuffleId, -1, reduceId,
new Exception("Missing all output locations for shuffle " + shuffleId))
}
diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala
index 31b0773bfe..fc0a749882 100644
--- a/core/src/main/scala/org/apache/spark/Partitioner.scala
+++ b/core/src/main/scala/org/apache/spark/Partitioner.scala
@@ -53,15 +53,16 @@ object Partitioner {
return r.partitioner.get
}
if (rdd.context.conf.contains("spark.default.parallelism")) {
- return new HashPartitioner(rdd.context.defaultParallelism)
+ new HashPartitioner(rdd.context.defaultParallelism)
} else {
- return new HashPartitioner(bySize.head.partitions.size)
+ new HashPartitioner(bySize.head.partitions.size)
}
}
}
/**
- * A [[org.apache.spark.Partitioner]] that implements hash-based partitioning using Java's `Object.hashCode`.
+ * A [[org.apache.spark.Partitioner]] that implements hash-based partitioning using
+ * Java's `Object.hashCode`.
*
* Java arrays have hashCodes that are based on the arrays' identities rather than their contents,
* so attempting to partition an RDD[Array[_]] or RDD[(Array[_], _)] using a HashPartitioner will
@@ -84,8 +85,8 @@ class HashPartitioner(partitions: Int) extends Partitioner {
}
/**
- * A [[org.apache.spark.Partitioner]] that partitions sortable records by range into roughly equal ranges.
- * Determines the ranges by sampling the RDD passed in.
+ * A [[org.apache.spark.Partitioner]] that partitions sortable records by range into roughly
+ * equal ranges. The ranges are determined by sampling the content of the RDD passed in.
*/
class RangePartitioner[K <% Ordered[K]: ClassTag, V](
partitions: Int,
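Note: the HashPartitioner scaladoc above warns that Java arrays hash by identity rather than content. A small sketch of why that matters; equal array contents need not map to the same partition, while content-hashed keys such as strings do:

```scala
import org.apache.spark.HashPartitioner

object HashPartitionerSketch {
  def main(args: Array[String]) {
    val partitioner = new HashPartitioner(4)
    // Strings hash by content, so the same key always lands in the same partition.
    println(partitioner.getPartition("user-42"))
    // Arrays hash by identity, which is why partitioning an RDD[Array[_]] with a
    // HashPartitioner can give unexpected results: these two calls may differ.
    println(partitioner.getPartition(Array(1, 2, 3)))
    println(partitioner.getPartition(Array(1, 2, 3)))
  }
}
```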
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 0e47f4e442..55ac76bf63 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -31,9 +31,9 @@ import scala.reflect.{ClassTag, classTag}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.{ArrayWritable, BooleanWritable, BytesWritable, DoubleWritable,
-FloatWritable, IntWritable, LongWritable, NullWritable, Text, Writable}
+ FloatWritable, IntWritable, LongWritable, NullWritable, Text, Writable}
import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf, SequenceFileInputFormat,
-TextInputFormat}
+ TextInputFormat}
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, Job => NewHadoopJob}
import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat}
import org.apache.mesos.MesosNativeLibrary
@@ -49,7 +49,7 @@ import org.apache.spark.scheduler.local.LocalBackend
import org.apache.spark.storage.{BlockManagerSource, RDDInfo, StorageStatus, StorageUtils}
import org.apache.spark.ui.SparkUI
import org.apache.spark.util.{Utils, TimeStampedHashMap, MetadataCleaner, MetadataCleanerType,
-ClosureCleaner}
+ ClosureCleaner}
/**
* Main entry point for Spark functionality. A SparkContext represents the connection to a Spark
@@ -116,7 +116,7 @@ class SparkContext(
throw new SparkException("An application must be set in your configuration")
}
- if (conf.get("spark.logConf", "false").toBoolean) {
+ if (conf.getBoolean("spark.logConf", false)) {
logInfo("Spark configuration:\n" + conf.toDebugString)
}
@@ -244,6 +244,10 @@ class SparkContext(
localProperties.set(new Properties())
}
+ /**
+ * Set a local property that affects jobs submitted from this thread, such as the
+ * Spark fair scheduler pool.
+ */
def setLocalProperty(key: String, value: String) {
if (localProperties.get() == null) {
localProperties.set(new Properties())
@@ -255,6 +259,10 @@ class SparkContext(
}
}
+ /**
+ * Get a local property set in this thread, or null if it is missing. See
+ * [[org.apache.spark.SparkContext.setLocalProperty]].
+ */
def getLocalProperty(key: String): String =
Option(localProperties.get).map(_.getProperty(key)).getOrElse(null)
@@ -265,7 +273,7 @@ class SparkContext(
}
/**
- * Assigns a group id to all the jobs started by this thread until the group id is set to a
+ * Assigns a group ID to all the jobs started by this thread until the group ID is set to a
* different value or cleared.
*
* Often, a unit of execution in an application consists of multiple Spark actions or jobs.
@@ -288,7 +296,7 @@ class SparkContext(
setLocalProperty(SparkContext.SPARK_JOB_GROUP_ID, groupId)
}
- /** Clear the job group id and its description. */
+ /** Clear the current thread's job group ID and its description. */
def clearJobGroup() {
setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, null)
setLocalProperty(SparkContext.SPARK_JOB_GROUP_ID, null)
@@ -337,29 +345,42 @@ class SparkContext(
}
/**
- * Get an RDD for a Hadoop-readable dataset from a Hadoop JobConf given its InputFormat and any
- * other necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable,
- * etc).
+ * Get an RDD for a Hadoop-readable dataset from a Hadoop JobConf given its InputFormat and other
+ * necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable),
+ * using the older MapReduce API (`org.apache.hadoop.mapred`).
+ *
+ * @param conf JobConf for setting up the dataset
+ * @param inputFormatClass Class of the [[InputFormat]]
+ * @param keyClass Class of the keys
+ * @param valueClass Class of the values
+ * @param minSplits Minimum number of Hadoop Splits to generate.
+ * @param cloneRecords If true, Spark will clone the records produced by Hadoop RecordReader.
+ * Most RecordReader implementations reuse wrapper objects across multiple
+ * records, and can cause problems in RDD collect or aggregation operations.
+ * By default the records are cloned in Spark. However, application
+ * programmers can explicitly disable the cloning for better performance.
*/
- def hadoopRDD[K, V](
+ def hadoopRDD[K: ClassTag, V: ClassTag](
conf: JobConf,
inputFormatClass: Class[_ <: InputFormat[K, V]],
keyClass: Class[K],
valueClass: Class[V],
- minSplits: Int = defaultMinSplits
+ minSplits: Int = defaultMinSplits,
+ cloneRecords: Boolean = true
): RDD[(K, V)] = {
// Add necessary security credentials to the JobConf before broadcasting it.
SparkHadoopUtil.get.addCredentials(conf)
- new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minSplits)
+ new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minSplits, cloneRecords)
}
/** Get an RDD for a Hadoop file with an arbitrary InputFormat */
- def hadoopFile[K, V](
+ def hadoopFile[K: ClassTag, V: ClassTag](
path: String,
inputFormatClass: Class[_ <: InputFormat[K, V]],
keyClass: Class[K],
valueClass: Class[V],
- minSplits: Int = defaultMinSplits
+ minSplits: Int = defaultMinSplits,
+ cloneRecords: Boolean = true
): RDD[(K, V)] = {
// A Hadoop configuration can be about 10 KB, which is pretty big, so broadcast it.
val confBroadcast = broadcast(new SerializableWritable(hadoopConfiguration))
@@ -371,7 +392,8 @@ class SparkContext(
inputFormatClass,
keyClass,
valueClass,
- minSplits)
+ minSplits,
+ cloneRecords)
}
/**
@@ -382,14 +404,15 @@ class SparkContext(
* val file = sparkContext.hadoopFile[LongWritable, Text, TextInputFormat](path, minSplits)
* }}}
*/
- def hadoopFile[K, V, F <: InputFormat[K, V]](path: String, minSplits: Int)
- (implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F])
- : RDD[(K, V)] = {
+ def hadoopFile[K, V, F <: InputFormat[K, V]]
+ (path: String, minSplits: Int, cloneRecords: Boolean = true)
+ (implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F]): RDD[(K, V)] = {
hadoopFile(path,
- fm.runtimeClass.asInstanceOf[Class[F]],
- km.runtimeClass.asInstanceOf[Class[K]],
- vm.runtimeClass.asInstanceOf[Class[V]],
- minSplits)
+ fm.runtimeClass.asInstanceOf[Class[F]],
+ km.runtimeClass.asInstanceOf[Class[K]],
+ vm.runtimeClass.asInstanceOf[Class[V]],
+ minSplits,
+ cloneRecords)
}
/**
@@ -400,61 +423,67 @@ class SparkContext(
* val file = sparkContext.hadoopFile[LongWritable, Text, TextInputFormat](path)
* }}}
*/
- def hadoopFile[K, V, F <: InputFormat[K, V]](path: String)
+ def hadoopFile[K, V, F <: InputFormat[K, V]](path: String, cloneRecords: Boolean = true)
(implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F]): RDD[(K, V)] =
- hadoopFile[K, V, F](path, defaultMinSplits)
+ hadoopFile[K, V, F](path, defaultMinSplits, cloneRecords)
/** Get an RDD for a Hadoop file with an arbitrary new API InputFormat. */
- def newAPIHadoopFile[K, V, F <: NewInputFormat[K, V]](path: String)
+ def newAPIHadoopFile[K, V, F <: NewInputFormat[K, V]]
+ (path: String, cloneRecords: Boolean = true)
(implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F]): RDD[(K, V)] = {
newAPIHadoopFile(
- path,
- fm.runtimeClass.asInstanceOf[Class[F]],
- km.runtimeClass.asInstanceOf[Class[K]],
- vm.runtimeClass.asInstanceOf[Class[V]])
+ path,
+ fm.runtimeClass.asInstanceOf[Class[F]],
+ km.runtimeClass.asInstanceOf[Class[K]],
+ vm.runtimeClass.asInstanceOf[Class[V]],
+ cloneRecords = cloneRecords)
}
/**
* Get an RDD for a given Hadoop file with an arbitrary new API InputFormat
* and extra configuration options to pass to the input format.
*/
- def newAPIHadoopFile[K, V, F <: NewInputFormat[K, V]](
+ def newAPIHadoopFile[K: ClassTag, V: ClassTag, F <: NewInputFormat[K, V]](
path: String,
fClass: Class[F],
kClass: Class[K],
vClass: Class[V],
- conf: Configuration = hadoopConfiguration): RDD[(K, V)] = {
+ conf: Configuration = hadoopConfiguration,
+ cloneRecords: Boolean = true): RDD[(K, V)] = {
val job = new NewHadoopJob(conf)
NewFileInputFormat.addInputPath(job, new Path(path))
val updatedConf = job.getConfiguration
- new NewHadoopRDD(this, fClass, kClass, vClass, updatedConf)
+ new NewHadoopRDD(this, fClass, kClass, vClass, updatedConf, cloneRecords)
}
/**
* Get an RDD for a given Hadoop file with an arbitrary new API InputFormat
* and extra configuration options to pass to the input format.
*/
- def newAPIHadoopRDD[K, V, F <: NewInputFormat[K, V]](
+ def newAPIHadoopRDD[K: ClassTag, V: ClassTag, F <: NewInputFormat[K, V]](
conf: Configuration = hadoopConfiguration,
fClass: Class[F],
kClass: Class[K],
- vClass: Class[V]): RDD[(K, V)] = {
- new NewHadoopRDD(this, fClass, kClass, vClass, conf)
+ vClass: Class[V],
+ cloneRecords: Boolean = true): RDD[(K, V)] = {
+ new NewHadoopRDD(this, fClass, kClass, vClass, conf, cloneRecords)
}
/** Get an RDD for a Hadoop SequenceFile with given key and value types. */
- def sequenceFile[K, V](path: String,
+ def sequenceFile[K: ClassTag, V: ClassTag](path: String,
keyClass: Class[K],
valueClass: Class[V],
- minSplits: Int
+ minSplits: Int,
+ cloneRecords: Boolean = true
): RDD[(K, V)] = {
val inputFormatClass = classOf[SequenceFileInputFormat[K, V]]
- hadoopFile(path, inputFormatClass, keyClass, valueClass, minSplits)
+ hadoopFile(path, inputFormatClass, keyClass, valueClass, minSplits, cloneRecords)
}
/** Get an RDD for a Hadoop SequenceFile with given key and value types. */
- def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V]): RDD[(K, V)] =
- sequenceFile(path, keyClass, valueClass, defaultMinSplits)
+ def sequenceFile[K: ClassTag, V: ClassTag](path: String, keyClass: Class[K], valueClass: Class[V],
+ cloneRecords: Boolean = true): RDD[(K, V)] =
+ sequenceFile(path, keyClass, valueClass, defaultMinSplits, cloneRecords)
/**
* Version of sequenceFile() for types implicitly convertible to Writables through a
@@ -472,17 +501,18 @@ class SparkContext(
* for the appropriate type. In addition, we pass the converter a ClassTag of its type to
* allow it to figure out the Writable class to use in the subclass case.
*/
- def sequenceFile[K, V](path: String, minSplits: Int = defaultMinSplits)
- (implicit km: ClassTag[K], vm: ClassTag[V],
- kcf: () => WritableConverter[K], vcf: () => WritableConverter[V])
+ def sequenceFile[K, V]
+ (path: String, minSplits: Int = defaultMinSplits, cloneRecords: Boolean = true)
+ (implicit km: ClassTag[K], vm: ClassTag[V],
+ kcf: () => WritableConverter[K], vcf: () => WritableConverter[V])
: RDD[(K, V)] = {
val kc = kcf()
val vc = vcf()
val format = classOf[SequenceFileInputFormat[Writable, Writable]]
val writables = hadoopFile(path, format,
kc.writableClass(km).asInstanceOf[Class[Writable]],
- vc.writableClass(vm).asInstanceOf[Class[Writable]], minSplits)
- writables.map{case (k,v) => (kc.convert(k), vc.convert(v))}
+ vc.writableClass(vm).asInstanceOf[Class[Writable]], minSplits, cloneRecords)
+ writables.map { case (k, v) => (kc.convert(k), vc.convert(v)) }
}
/**
@@ -517,15 +547,15 @@ class SparkContext(
// Methods for creating shared variables
/**
- * Create an [[org.apache.spark.Accumulator]] variable of a given type, which tasks can "add" values
- * to using the `+=` method. Only the driver can access the accumulator's `value`.
+ * Create an [[org.apache.spark.Accumulator]] variable of a given type, which tasks can "add"
+ * values to using the `+=` method. Only the driver can access the accumulator's `value`.
*/
def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]) =
new Accumulator(initialValue, param)
/**
- * Create an [[org.apache.spark.Accumulable]] shared variable, to which tasks can add values with `+=`.
- * Only the driver can access the accumuable's `value`.
+ * Create an [[org.apache.spark.Accumulable]] shared variable, to which tasks can add values
+ * with `+=`. Only the driver can access the accumuable's `value`.
* @tparam T accumulator type
* @tparam R type that can be added to the accumulator
*/
@@ -538,14 +568,16 @@ class SparkContext(
* Growable and TraversableOnce are the standard APIs that guarantee += and ++=, implemented by
* standard mutable collections. So you can use this with mutable Map, Set, etc.
*/
- def accumulableCollection[R <% Growable[T] with TraversableOnce[T] with Serializable, T](initialValue: R) = {
+ def accumulableCollection[R <% Growable[T] with TraversableOnce[T] with Serializable, T]
+ (initialValue: R) = {
val param = new GrowableAccumulableParam[R,T]
new Accumulable(initialValue, param)
}
/**
- * Broadcast a read-only variable to the cluster, returning a [[org.apache.spark.broadcast.Broadcast]] object for
- * reading it in distributed functions. The variable will be sent to each cluster only once.
+ * Broadcast a read-only variable to the cluster, returning a
+ * [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions.
+ * The variable will be sent to each cluster only once.
*/
def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T](value, isLocal)
@@ -667,10 +699,10 @@ class SparkContext(
key = uri.getScheme match {
// A JAR file which exists only on the driver node
case null | "file" =>
- if (SparkHadoopUtil.get.isYarnMode()) {
- // In order for this to work on yarn the user must specify the --addjars option to
- // the client to upload the file into the distributed cache to make it show up in the
- // current working directory.
+ if (SparkHadoopUtil.get.isYarnMode() && master == "yarn-standalone") {
+ // In order for this to work in yarn standalone mode the user must specify the
+ // --addjars option to the client to upload the file into the distributed cache
+ // of the AM to make it show up in the current working directory.
val fileName = new Path(uri.getPath).getName()
try {
env.httpFileServer.addJar(new File(fileName))
@@ -754,8 +786,11 @@ class SparkContext(
private[spark] def getCallSite(): String = {
val callSite = getLocalProperty("externalCallSite")
- if (callSite == null) return Utils.formatSparkCallSite
- callSite
+ if (callSite == null) {
+ Utils.formatSparkCallSite
+ } else {
+ callSite
+ }
}
/**
@@ -905,7 +940,7 @@ class SparkContext(
*/
private[spark] def clean[F <: AnyRef](f: F): F = {
ClosureCleaner.clean(f)
- return f
+ f
}
/**
@@ -917,7 +952,7 @@ class SparkContext(
val path = new Path(dir, UUID.randomUUID().toString)
val fs = path.getFileSystem(hadoopConfiguration)
fs.mkdirs(path)
- fs.getFileStatus(path).getPath().toString
+ fs.getFileStatus(path).getPath.toString
}
}
@@ -1010,7 +1045,8 @@ object SparkContext {
implicit def stringToText(s: String) = new Text(s)
- private implicit def arrayToArrayWritable[T <% Writable: ClassTag](arr: Traversable[T]): ArrayWritable = {
+ private implicit def arrayToArrayWritable[T <% Writable: ClassTag](arr: Traversable[T])
+ : ArrayWritable = {
def anyToWritable[U <% Writable](u: U): Writable = u
new ArrayWritable(classTag[T].runtimeClass.asInstanceOf[Class[Writable]],
@@ -1033,7 +1069,9 @@ object SparkContext {
implicit def booleanWritableConverter() = simpleWritableConverter[Boolean, BooleanWritable](_.get)
- implicit def bytesWritableConverter() = simpleWritableConverter[Array[Byte], BytesWritable](_.getBytes)
+ implicit def bytesWritableConverter() = {
+ simpleWritableConverter[Array[Byte], BytesWritable](_.getBytes)
+ }
implicit def stringWritableConverter() = simpleWritableConverter[String, Text](_.toString)
@@ -1049,7 +1087,8 @@ object SparkContext {
if (uri != null) {
val uriStr = uri.toString
if (uriStr.startsWith("jar:file:")) {
- // URI will be of the form "jar:file:/path/foo.jar!/package/cls.class", so pull out the /path/foo.jar
+ // URI will be of the form "jar:file:/path/foo.jar!/package/cls.class",
+ // so pull out the /path/foo.jar
List(uriStr.substring("jar:file:".length, uriStr.indexOf('!')))
} else {
Nil
@@ -1072,7 +1111,7 @@ object SparkContext {
* parameters that are passed as the default value of null, instead of throwing an exception
* like SparkConf would.
*/
- private def updatedConf(
+ private[spark] def updatedConf(
conf: SparkConf,
master: String,
appName: String,
@@ -1203,7 +1242,7 @@ object SparkContext {
case mesosUrl @ MESOS_REGEX(_) =>
MesosNativeLibrary.load()
val scheduler = new TaskSchedulerImpl(sc)
- val coarseGrained = sc.conf.get("spark.mesos.coarse", "false").toBoolean
+ val coarseGrained = sc.conf.getBoolean("spark.mesos.coarse", false)
val url = mesosUrl.stripPrefix("mesos://") // strip scheme from raw Mesos URLs
val backend = if (coarseGrained) {
new CoarseMesosSchedulerBackend(scheduler, sc, url, appName)
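Note: the SparkContext changes above thread a new cloneRecords flag (default true) through hadoopRDD, hadoopFile, newAPIHadoopFile, newAPIHadoopRDD and sequenceFile. A minimal sketch of opting out of record cloning; the master, app name, and input path are hypothetical:

```scala
import org.apache.hadoop.io.{LongWritable, Text}
import org.apache.hadoop.mapred.TextInputFormat
import org.apache.spark.{SparkConf, SparkContext}

object CloneRecordsSketch {
  def main(args: Array[String]) {
    val sc = new SparkContext(
      new SparkConf().setMaster("local[2]").setAppName("clone-records-sketch"))
    // Default (cloneRecords = true): records are copied, so the Writable objects
    // reused by Hadoop RecordReaders are safe to cache or collect.
    val cloned = sc.hadoopFile[LongWritable, Text, TextInputFormat]("hdfs:///tmp/input", 4)
    // Opt out when records are consumed immediately and performance matters.
    val reused = sc.hadoopFile[LongWritable, Text, TextInputFormat](
      "hdfs:///tmp/input", 4, cloneRecords = false)
    println(reused.map(_._2.toString).take(1).toSeq)
    sc.stop()
  }
}
```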
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 2e36ccb9a0..ed788560e7 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -54,7 +54,11 @@ class SparkEnv private[spark] (
val httpFileServer: HttpFileServer,
val sparkFilesDir: String,
val metricsSystem: MetricsSystem,
- val conf: SparkConf) {
+ val conf: SparkConf) extends Logging {
+
+ // A mapping of thread ID to amount of memory used for shuffle in bytes
+ // All accesses should be manually synchronized
+ val shuffleMemoryMap = mutable.HashMap[Long, Long]()
private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]()
@@ -128,16 +132,6 @@ object SparkEnv extends Logging {
conf.set("spark.driver.port", boundPort.toString)
}
- // set only if unset until now.
- if (!conf.contains("spark.hostPort")) {
- if (!isDriver){
- // unexpected
- Utils.logErrorWithStack("Unexpected NOT to have spark.hostPort set")
- }
- Utils.checkHost(hostname)
- conf.set("spark.hostPort", hostname + ":" + boundPort)
- }
-
val classLoader = Thread.currentThread.getContextClassLoader
// Create an instance of the class named by the given Java system property, or by
@@ -162,7 +156,7 @@ object SparkEnv extends Logging {
actorSystem.actorOf(Props(newActor), name = name)
} else {
val driverHost: String = conf.get("spark.driver.host", "localhost")
- val driverPort: Int = conf.get("spark.driver.port", "7077").toInt
+ val driverPort: Int = conf.getInt("spark.driver.port", 7077)
Utils.checkHost(driverHost, "Expected hostname")
val url = s"akka.tcp://spark@$driverHost:$driverPort/user/$name"
val timeout = AkkaUtils.lookupTimeout(conf)
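Note: SparkEnv now carries a shuffleMemoryMap keyed by thread ID, with the comment that all accesses must be manually synchronized. A sketch of the locking discipline that implies; recordShuffleMemory is a hypothetical helper, not part of the diff:

```scala
import org.apache.spark.SparkEnv

object ShuffleMemorySketch {
  // Hypothetical helper; the diff only states that shuffleMemoryMap accesses
  // must be manually synchronized (it maps thread ID -> shuffle bytes used).
  def recordShuffleMemory(env: SparkEnv, bytes: Long) {
    val threadId = Thread.currentThread().getId
    env.shuffleMemoryMap.synchronized {
      val previous = env.shuffleMemoryMap.getOrElse(threadId, 0L)
      env.shuffleMemoryMap(threadId) = previous + bytes
    }
  }
}
```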
diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
index 618d95015f..4e63117a51 100644
--- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
+++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
@@ -134,28 +134,28 @@ class SparkHadoopWriter(@transient jobConf: JobConf)
format = conf.value.getOutputFormat()
.asInstanceOf[OutputFormat[AnyRef,AnyRef]]
}
- return format
+ format
}
private def getOutputCommitter(): OutputCommitter = {
if (committer == null) {
committer = conf.value.getOutputCommitter
}
- return committer
+ committer
}
private def getJobContext(): JobContext = {
if (jobContext == null) {
jobContext = newJobContext(conf.value, jID.value)
}
- return jobContext
+ jobContext
}
private def getTaskContext(): TaskAttemptContext = {
if (taskContext == null) {
taskContext = newTaskAttemptContext(conf.value, taID.value)
}
- return taskContext
+ taskContext
}
private def setIDs(jobid: Int, splitid: Int, attemptid: Int) {
@@ -182,19 +182,18 @@ object SparkHadoopWriter {
def createJobID(time: Date, id: Int): JobID = {
val formatter = new SimpleDateFormat("yyyyMMddHHmm")
val jobtrackerID = formatter.format(new Date())
- return new JobID(jobtrackerID, id)
+ new JobID(jobtrackerID, id)
}
def createPathFromString(path: String, conf: JobConf): Path = {
if (path == null) {
throw new IllegalArgumentException("Output path is null")
}
- var outputPath = new Path(path)
+ val outputPath = new Path(path)
val fs = outputPath.getFileSystem(conf)
if (outputPath == null || fs == null) {
throw new IllegalArgumentException("Incorrectly formatted output path")
}
- outputPath = outputPath.makeQualified(fs)
- return outputPath
+ outputPath.makeQualified(fs)
}
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
index da30cf619a..b0dedc6f4e 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
@@ -207,13 +207,13 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav
* e.g. for the array
* [1,10,20,50] the buckets are [1,10) [10,20) [20,50]
* e.g 1<=x<10 , 10<=x<20, 20<=x<50
- * And on the input of 1 and 50 we would have a histogram of 1,0,0
- *
+ * And on the input of 1 and 50 we would have a histogram of 1,0,0
+ *
* Note: if your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched
* from an O(log n) inseration to O(1) per element. (where n = # buckets) if you set evenBuckets
* to true.
* buckets must be sorted and not contain any duplicates.
- * buckets array must be at least two elements
+ * buckets array must be at least two elements
* All NaN entries are treated the same. If you have a NaN bucket it must be
* the maximum value of the last position and all NaN entries will be counted
* in that bucket.
@@ -225,6 +225,12 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav
def histogram(buckets: Array[Double], evenBuckets: Boolean): Array[Long] = {
srdd.histogram(buckets.map(_.toDouble), evenBuckets)
}
+
+ /** Assign a name to this RDD */
+ def setName(name: String): JavaDoubleRDD = {
+ srdd.setName(name)
+ this
+ }
}
object JavaDoubleRDD {
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
index 55c87450ac..0fb7e195b3 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
@@ -647,6 +647,12 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kClassTag: ClassTag[K
def countApproxDistinctByKey(relativeSD: Double, numPartitions: Int): JavaRDD[(K, Long)] = {
rdd.countApproxDistinctByKey(relativeSD, numPartitions)
}
+
+ /** Assign a name to this RDD */
+ def setName(name: String): JavaPairRDD[K, V] = {
+ rdd.setName(name)
+ this
+ }
}
object JavaPairRDD {
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
index 037cd1c774..7d48ce01cf 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
@@ -127,6 +127,12 @@ JavaRDDLike[T, JavaRDD[T]] {
wrapRDD(rdd.subtract(other, p))
override def toString = rdd.toString
+
+ /** Assign a name to this RDD */
+ def setName(name: String): JavaRDD[T] = {
+ rdd.setName(name)
+ this
+ }
}
object JavaRDD {
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
index 924d8af060..ebbbbd8806 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
@@ -245,6 +245,11 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
}
/**
+ * Return an array that contains all of the elements in this RDD.
+ */
+ def toArray(): JList[T] = collect()
+
+ /**
* Return an array that contains all of the elements in a specific partition of this RDD.
*/
def collectPartitions(partitionIds: Array[Int]): Array[JList[T]] = {
@@ -455,4 +460,5 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
*/
def countApproxDistinct(relativeSD: Double = 0.05): Long = rdd.countApproxDistinct(relativeSD)
+ def name(): String = rdd.name
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
index e93b10fd7e..7a6f044965 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
@@ -425,6 +425,51 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
def clearCallSite() {
sc.clearCallSite()
}
+
+ /**
+ * Set a local property that affects jobs submitted from this thread, such as the
+ * Spark fair scheduler pool.
+ */
+ def setLocalProperty(key: String, value: String): Unit = sc.setLocalProperty(key, value)
+
+ /**
+ * Get a local property set in this thread, or null if it is missing. See
+ * [[org.apache.spark.api.java.JavaSparkContext.setLocalProperty]].
+ */
+ def getLocalProperty(key: String): String = sc.getLocalProperty(key)
+
+ /**
+ * Assigns a group ID to all the jobs started by this thread until the group ID is set to a
+ * different value or cleared.
+ *
+ * Often, a unit of execution in an application consists of multiple Spark actions or jobs.
+ * Application programmers can use this method to group all those jobs together and give a
+ * group description. Once set, the Spark web UI will associate such jobs with this group.
+ *
+ * The application can also use [[org.apache.spark.api.java.JavaSparkContext.cancelJobGroup]]
+ * to cancel all running jobs in this group. For example,
+ * {{{
+ * // In the main thread:
+ * sc.setJobGroup("some_job_to_cancel", "some job description");
+ * rdd.map(...).count();
+ *
+ * // In a separate thread:
+ * sc.cancelJobGroup("some_job_to_cancel");
+ * }}}
+ */
+ def setJobGroup(groupId: String, description: String): Unit = sc.setJobGroup(groupId, description)
+
+ /** Clear the current thread's job group ID and its description. */
+ def clearJobGroup(): Unit = sc.clearJobGroup()
+
+ /**
+ * Cancel active jobs for the specified group. See
+ * [[org.apache.spark.api.java.JavaSparkContext.setJobGroup]] for more information.
+ */
+ def cancelJobGroup(groupId: String): Unit = sc.cancelJobGroup(groupId)
+
+ /** Cancel all jobs that have been scheduled or are running. */
+ def cancelAllJobs(): Unit = sc.cancelAllJobs()
}
object JavaSparkContext {
@@ -436,5 +481,12 @@ object JavaSparkContext {
* Find the JAR from which a given class was loaded, to make it easy for users to pass
* their JARs to SparkContext.
*/
- def jarOfClass(cls: Class[_]) = SparkContext.jarOfClass(cls).toArray
+ def jarOfClass(cls: Class[_]): Array[String] = SparkContext.jarOfClass(cls).toArray
+
+ /**
+ * Find the JAR that contains the class of a particular object, to make it easy for users
+ * to pass their JARs to SparkContext. In most cases you can call jarOfObject(this) in
+ * your driver program.
+ */
+ def jarOfObject(obj: AnyRef): Array[String] = SparkContext.jarOfObject(obj).toArray
}
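Note: the Java API gains setLocalProperty, job-group control, jarOfObject, and chainable setName/name on the RDD wrappers. A short sketch exercising these additions from Scala against a local master; names and data are illustrative:

```scala
import org.apache.spark.api.java.JavaSparkContext

object JavaApiSketch {
  def main(args: Array[String]) {
    val jsc = new JavaSparkContext("local[2]", "java-api-sketch")
    jsc.setJobGroup("nightly-report", "sketch of the new job group API")
    val numbers = jsc.parallelize(java.util.Arrays.asList(1, 2, 3, 4)).setName("numbers")
    println(numbers.name() + " has " + numbers.count() + " elements")
    jsc.cancelJobGroup("nightly-report") // no-op here; the count already finished
    println(JavaSparkContext.jarOfObject(this).toSeq) // likely empty when run from sbt
    jsc.stop()
  }
}
```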
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 32cc70e8c9..82527fe663 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -41,7 +41,7 @@ private[spark] class PythonRDD[T: ClassTag](
accumulator: Accumulator[JList[Array[Byte]]])
extends RDD[Array[Byte]](parent) {
- val bufferSize = conf.get("spark.buffer.size", "65536").toInt
+ val bufferSize = conf.getInt("spark.buffer.size", 65536)
override def getPartitions = parent.partitions
@@ -95,7 +95,7 @@ private[spark] class PythonRDD[T: ClassTag](
// Return an iterator that read lines from the process's stdout
val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
- return new Iterator[Array[Byte]] {
+ val stdoutIterator = new Iterator[Array[Byte]] {
def next(): Array[Byte] = {
val obj = _nextObj
if (hasNext) {
@@ -156,6 +156,7 @@ private[spark] class PythonRDD[T: ClassTag](
def hasNext = _nextObj.length != 0
}
+ stdoutIterator
}
val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
@@ -250,7 +251,7 @@ private class PythonAccumulatorParam(@transient serverHost: String, serverPort:
Utils.checkHost(serverHost, "Expected hostname")
- val bufferSize = SparkEnv.get.conf.get("spark.buffer.size", "65536").toInt
+ val bufferSize = SparkEnv.get.conf.getInt("spark.buffer.size", 65536)
override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList
diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
index db596d5fcc..0eacda3d7d 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
@@ -92,8 +92,8 @@ private object HttpBroadcast extends Logging {
def initialize(isDriver: Boolean, conf: SparkConf) {
synchronized {
if (!initialized) {
- bufferSize = conf.get("spark.buffer.size", "65536").toInt
- compress = conf.get("spark.broadcast.compress", "true").toBoolean
+ bufferSize = conf.getInt("spark.buffer.size", 65536)
+ compress = conf.getBoolean("spark.broadcast.compress", true)
if (isDriver) {
createServer(conf)
conf.set("spark.httpBroadcast.uri", serverUri)
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
index 9530938278..1d295c62bc 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -180,7 +180,7 @@ extends Logging {
initialized = false
}
- lazy val BLOCK_SIZE = conf.get("spark.broadcast.blockSize", "4096").toInt * 1024
+ lazy val BLOCK_SIZE = conf.getInt("spark.broadcast.blockSize", 4096) * 1024
def blockifyObject[T](obj: T): TorrentInfo = {
val byteArray = Utils.serialize[T](obj)
@@ -203,16 +203,16 @@ extends Logging {
}
bais.close()
- var tInfo = TorrentInfo(retVal, blockNum, byteArray.length)
+ val tInfo = TorrentInfo(retVal, blockNum, byteArray.length)
tInfo.hasBlocks = blockNum
- return tInfo
+ tInfo
}
def unBlockifyObject[T](arrayOfBlocks: Array[TorrentBlock],
totalBytes: Int,
totalBlocks: Int): T = {
- var retByteArray = new Array[Byte](totalBytes)
+ val retByteArray = new Array[Byte](totalBytes)
for (i <- 0 until totalBlocks) {
System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray,
i * BLOCK_SIZE, arrayOfBlocks(i).byteArray.length)
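Note: several call sites above (spark.buffer.size, spark.broadcast.compress, spark.broadcast.blockSize) move from conf.get(key, default).toInt/.toBoolean to the typed SparkConf getters. A minimal sketch of those getters; the overridden buffer size is illustrative:

```scala
import org.apache.spark.SparkConf

object TypedConfSketch {
  def main(args: Array[String]) {
    val conf = new SparkConf().set("spark.buffer.size", "131072")
    val bufferSize = conf.getInt("spark.buffer.size", 65536)           // 131072
    val compress   = conf.getBoolean("spark.broadcast.compress", true) // default: true
    val blockKb    = conf.getInt("spark.broadcast.blockSize", 4096)    // default: 4096
    println(s"buffer=$bufferSize compress=$compress blockSizeKb=$blockKb")
  }
}
```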
diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala
new file mode 100644
index 0000000000..e133893f6c
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy
+
+import scala.collection.JavaConversions._
+import scala.collection.mutable.Map
+import scala.concurrent._
+
+import akka.actor._
+import akka.pattern.ask
+import org.apache.log4j.{Level, Logger}
+
+import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.deploy.DeployMessages._
+import org.apache.spark.deploy.master.{DriverState, Master}
+import org.apache.spark.util.{AkkaUtils, Utils}
+import akka.actor.Actor.emptyBehavior
+import akka.remote.{AssociationErrorEvent, DisassociatedEvent, RemotingLifecycleEvent}
+
+/**
+ * Proxy that relays messages to the driver.
+ */
+class ClientActor(driverArgs: ClientArguments, conf: SparkConf) extends Actor with Logging {
+ var masterActor: ActorSelection = _
+ val timeout = AkkaUtils.askTimeout(conf)
+
+ override def preStart() = {
+ masterActor = context.actorSelection(Master.toAkkaUrl(driverArgs.master))
+
+ context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
+
+ println(s"Sending ${driverArgs.cmd} command to ${driverArgs.master}")
+
+ driverArgs.cmd match {
+ case "launch" =>
+ // TODO: We could add an env variable here and intercept it in `sc.addJar` that would
+ // truncate filesystem paths similar to what YARN does. For now, we just require
+ // people call `addJar` assuming the jar is in the same directory.
+ val env = Map[String, String]()
+ System.getenv().foreach{case (k, v) => env(k) = v}
+
+ val mainClass = "org.apache.spark.deploy.worker.DriverWrapper"
+ val command = new Command(mainClass, Seq("{{WORKER_URL}}", driverArgs.mainClass) ++
+ driverArgs.driverOptions, env)
+
+ val driverDescription = new DriverDescription(
+ driverArgs.jarUrl,
+ driverArgs.memory,
+ driverArgs.cores,
+ driverArgs.supervise,
+ command)
+
+ masterActor ! RequestSubmitDriver(driverDescription)
+
+ case "kill" =>
+ val driverId = driverArgs.driverId
+ val killFuture = masterActor ! RequestKillDriver(driverId)
+ }
+ }
+
+ /* Find out driver status then exit the JVM */
+ def pollAndReportStatus(driverId: String) {
+ println(s"... waiting before polling master for driver state")
+ Thread.sleep(5000)
+ println("... polling master for driver state")
+ val statusFuture = (masterActor ? RequestDriverStatus(driverId))(timeout)
+ .mapTo[DriverStatusResponse]
+ val statusResponse = Await.result(statusFuture, timeout)
+
+ statusResponse.found match {
+ case false =>
+ println(s"ERROR: Cluster master did not recognize $driverId")
+ System.exit(-1)
+ case true =>
+ println(s"State of $driverId is ${statusResponse.state.get}")
+ // Worker node, if present
+ (statusResponse.workerId, statusResponse.workerHostPort, statusResponse.state) match {
+ case (Some(id), Some(hostPort), Some(DriverState.RUNNING)) =>
+ println(s"Driver running on $hostPort ($id)")
+ case _ =>
+ }
+ // Exception, if present
+ statusResponse.exception.map { e =>
+ println(s"Exception from cluster was: $e")
+ System.exit(-1)
+ }
+ System.exit(0)
+ }
+ }
+
+ override def receive = {
+
+ case SubmitDriverResponse(success, driverId, message) =>
+ println(message)
+ if (success) pollAndReportStatus(driverId.get) else System.exit(-1)
+
+ case KillDriverResponse(driverId, success, message) =>
+ println(message)
+ if (success) pollAndReportStatus(driverId) else System.exit(-1)
+
+ case DisassociatedEvent(_, remoteAddress, _) =>
+ println(s"Error connecting to master ${driverArgs.master} ($remoteAddress), exiting.")
+ System.exit(-1)
+
+ case AssociationErrorEvent(cause, _, remoteAddress, _) =>
+ println(s"Error connecting to master ${driverArgs.master} ($remoteAddress), exiting.")
+ println(s"Cause was: $cause")
+ System.exit(-1)
+ }
+}
+
+/**
+ * Executable utility for starting and terminating drivers inside of a standalone cluster.
+ */
+object Client {
+ def main(args: Array[String]) {
+ val conf = new SparkConf()
+ val driverArgs = new ClientArguments(args)
+
+ if (!driverArgs.logLevel.isGreaterOrEqual(Level.WARN)) {
+ conf.set("spark.akka.logLifecycleEvents", "true")
+ }
+ conf.set("spark.akka.askTimeout", "10")
+ conf.set("akka.loglevel", driverArgs.logLevel.toString.replace("WARN", "WARNING"))
+ Logger.getRootLogger.setLevel(driverArgs.logLevel)
+
+ // TODO: See if we can initialize akka so return messages are sent back using the same TCP
+ // flow. Else, this (sadly) requires the DriverClient be routable from the Master.
+ val (actorSystem, _) = AkkaUtils.createActorSystem(
+ "driverClient", Utils.localHostName(), 0, false, conf)
+
+ actorSystem.actorOf(Props(classOf[ClientActor], driverArgs, conf))
+
+ actorSystem.awaitTermination()
+ }
+}
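
For reference, a small sketch (not part of the patch) of what the "launch" branch above hands to the cluster: the user's main class is wrapped in DriverWrapper, and the {{WORKER_URL}} placeholder is only filled in later by the worker's DriverRunner. The class name and options below are hypothetical.

    // Hypothetical user class and options, for illustration only.
    val userClass   = "org.example.MyApp"
    val driverOpts  = Seq("--input", "hdfs://nn:8020/data")
    val wrapperMain = "org.apache.spark.deploy.worker.DriverWrapper"
    // Arguments handed to DriverWrapper; DriverRunner.substituteVariables fills in {{WORKER_URL}}.
    val wrapperArgs = Seq("{{WORKER_URL}}", userClass) ++ driverOpts
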
diff --git a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala
new file mode 100644
index 0000000000..db67c6d1bb
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy
+
+import java.net.URI
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.log4j.Level
+
+/**
+ * Command-line parser for the driver client.
+ */
+private[spark] class ClientArguments(args: Array[String]) {
+ val defaultCores = 1
+ val defaultMemory = 512
+
+ var cmd: String = "" // 'launch' or 'kill'
+ var logLevel = Level.WARN
+
+ // launch parameters
+ var master: String = ""
+ var jarUrl: String = ""
+ var mainClass: String = ""
+ var supervise: Boolean = false
+ var memory: Int = defaultMemory
+ var cores: Int = defaultCores
+ private var _driverOptions = ListBuffer[String]()
+ def driverOptions = _driverOptions.toSeq
+
+ // kill parameters
+ var driverId: String = ""
+
+ parse(args.toList)
+
+ def parse(args: List[String]): Unit = args match {
+ case ("--cores" | "-c") :: value :: tail =>
+ cores = value.toInt
+ parse(tail)
+
+ case ("--memory" | "-m") :: value :: tail =>
+ memory = value.toInt
+ parse(tail)
+
+ case ("--supervise" | "-s") :: tail =>
+ supervise = true
+ parse(tail)
+
+ case ("--help" | "-h") :: tail =>
+ printUsageAndExit(0)
+
+ case ("--verbose" | "-v") :: tail =>
+ logLevel = Level.INFO
+ parse(tail)
+
+ case "launch" :: _master :: _jarUrl :: _mainClass :: tail =>
+ cmd = "launch"
+
+      try {
+        // java.net.URL rejects schemes without a registered protocol handler (e.g. hdfs://),
+        // so parse as a URI and only require that a scheme is present.
+        require(new URI(_jarUrl).getScheme != null)
+      } catch {
+        case e: Exception =>
+          println(s"Jar url '${_jarUrl}' is not a valid URL.")
+          println("Jar must be in URL format (e.g. hdfs://XX, file://XX)")
+          printUsageAndExit(-1)
+      }
+
+ jarUrl = _jarUrl
+ master = _master
+ mainClass = _mainClass
+ _driverOptions ++= tail
+
+ case "kill" :: _master :: _driverId :: tail =>
+ cmd = "kill"
+ master = _master
+ driverId = _driverId
+
+ case _ =>
+ printUsageAndExit(1)
+ }
+
+ /**
+ * Print usage and exit JVM with the given exit code.
+ */
+ def printUsageAndExit(exitCode: Int) {
+ // TODO: It wouldn't be too hard to allow users to submit their app and dependency jars
+ // separately similar to in the YARN client.
+ val usage =
+ s"""
+ |Usage: DriverClient [options] launch <active-master> <jar-url> <main-class> [driver options]
+ |Usage: DriverClient kill <active-master> <driver-id>
+ |
+ |Options:
+ | -c CORES, --cores CORES Number of cores to request (default: $defaultCores)
+ | -m MEMORY, --memory MEMORY Megabytes of memory to request (default: $defaultMemory)
+ | -s, --supervise Whether to restart the driver on failure
+ | -v, --verbose Print more debugging output
+ """.stripMargin
+ System.err.println(usage)
+ System.exit(exitCode)
+ }
+}
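
A small sketch of how parse() above maps a concrete argument list onto the fields, with hypothetical values (ClientArguments is package-private, so this assumes code under org.apache.spark): options must precede the launch/kill keyword, and everything after the main class is collected into driverOptions.

    // Hypothetical invocation: DriverClient -m 1024 -s launch <master> <jar-url> <main-class> appArg1
    val parsed = new ClientArguments(Array(
      "--memory", "1024", "--supervise",
      "launch", "spark://masterHost:7077", "file:///tmp/jars/app.jar", "org.example.MyApp",
      "appArg1"))
    // parsed.cmd == "launch"; parsed.memory == 1024; parsed.supervise == true
    // parsed.master == "spark://masterHost:7077"; parsed.mainClass == "org.example.MyApp"
    // parsed.driverOptions == Seq("appArg1")
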
diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
index 275331724a..5e824e1a67 100644
--- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
@@ -20,12 +20,12 @@ package org.apache.spark.deploy
import scala.collection.immutable.List
import org.apache.spark.deploy.ExecutorState.ExecutorState
-import org.apache.spark.deploy.master.{WorkerInfo, ApplicationInfo}
+import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, WorkerInfo}
+import org.apache.spark.deploy.master.DriverState.DriverState
import org.apache.spark.deploy.master.RecoveryState.MasterState
-import org.apache.spark.deploy.worker.ExecutorRunner
+import org.apache.spark.deploy.worker.{DriverRunner, ExecutorRunner}
import org.apache.spark.util.Utils
-
private[deploy] sealed trait DeployMessage extends Serializable
/** Contains messages sent between Scheduler actor nodes. */
@@ -54,7 +54,14 @@ private[deploy] object DeployMessages {
exitStatus: Option[Int])
extends DeployMessage
- case class WorkerSchedulerStateResponse(id: String, executors: List[ExecutorDescription])
+ case class DriverStateChanged(
+ driverId: String,
+ state: DriverState,
+ exception: Option[Exception])
+ extends DeployMessage
+
+ case class WorkerSchedulerStateResponse(id: String, executors: List[ExecutorDescription],
+ driverIds: Seq[String])
case class Heartbeat(workerId: String) extends DeployMessage
@@ -76,14 +83,18 @@ private[deploy] object DeployMessages {
sparkHome: String)
extends DeployMessage
- // Client to Master
+ case class LaunchDriver(driverId: String, driverDesc: DriverDescription) extends DeployMessage
+
+ case class KillDriver(driverId: String) extends DeployMessage
+
+ // AppClient to Master
case class RegisterApplication(appDescription: ApplicationDescription)
extends DeployMessage
case class MasterChangeAcknowledged(appId: String)
- // Master to Client
+ // Master to AppClient
case class RegisteredApplication(appId: String, masterUrl: String) extends DeployMessage
@@ -97,11 +108,28 @@ private[deploy] object DeployMessages {
case class ApplicationRemoved(message: String)
- // Internal message in Client
+ // DriverClient <-> Master
+
+ case class RequestSubmitDriver(driverDescription: DriverDescription) extends DeployMessage
+
+ case class SubmitDriverResponse(success: Boolean, driverId: Option[String], message: String)
+ extends DeployMessage
+
+ case class RequestKillDriver(driverId: String) extends DeployMessage
+
+ case class KillDriverResponse(driverId: String, success: Boolean, message: String)
+ extends DeployMessage
+
+ case class RequestDriverStatus(driverId: String) extends DeployMessage
+
+ case class DriverStatusResponse(found: Boolean, state: Option[DriverState],
+ workerId: Option[String], workerHostPort: Option[String], exception: Option[Exception])
+
+ // Internal message in AppClient
- case object StopClient
+ case object StopAppClient
- // Master to Worker & Client
+ // Master to Worker & AppClient
case class MasterChanged(masterUrl: String, masterWebUiUrl: String)
@@ -113,6 +141,7 @@ private[deploy] object DeployMessages {
case class MasterStateResponse(host: String, port: Int, workers: Array[WorkerInfo],
activeApps: Array[ApplicationInfo], completedApps: Array[ApplicationInfo],
+ activeDrivers: Array[DriverInfo], completedDrivers: Array[DriverInfo],
status: MasterState) {
Utils.checkHost(host, "Required hostname")
@@ -128,14 +157,15 @@ private[deploy] object DeployMessages {
// Worker to WorkerWebUI
case class WorkerStateResponse(host: String, port: Int, workerId: String,
- executors: List[ExecutorRunner], finishedExecutors: List[ExecutorRunner], masterUrl: String,
+ executors: List[ExecutorRunner], finishedExecutors: List[ExecutorRunner],
+ drivers: List[DriverRunner], finishedDrivers: List[DriverRunner], masterUrl: String,
cores: Int, memory: Int, coresUsed: Int, memoryUsed: Int, masterWebUiUrl: String) {
Utils.checkHost(host, "Required hostname")
assert (port > 0)
}
- // Actor System to Worker
+ // Liveness checks in various places
case object SendHeartbeat
}
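
A minimal sketch of the DriverClient <-> Master round trip these messages define, using the same ask pattern as pollAndReportStatus in Client.scala; the ActorSelection is assumed to point at a live Master, and the snippet assumes it sits in the org.apache.spark.deploy package since DeployMessages is package-private.

    import scala.concurrent.Await
    import scala.concurrent.duration._
    import akka.actor.ActorSelection
    import akka.pattern.ask
    import org.apache.spark.deploy.DeployMessages.{DriverStatusResponse, RequestDriverStatus}

    // Sketch: ask the Master for a driver's status and print the result.
    def reportStatus(master: ActorSelection, driverId: String): Unit = {
      val timeout = 30.seconds
      val future = (master ? RequestDriverStatus(driverId))(timeout).mapTo[DriverStatusResponse]
      val response = Await.result(future, timeout)
      if (response.found) {
        println(s"Driver $driverId is in state ${response.state.get}")
      } else {
        println(s"Master does not know about driver $driverId")
      }
    }
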
diff --git a/core/src/main/scala/org/apache/spark/deploy/DriverDescription.scala b/core/src/main/scala/org/apache/spark/deploy/DriverDescription.scala
new file mode 100644
index 0000000000..58c95dc4f9
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/DriverDescription.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy
+
+private[spark] class DriverDescription(
+ val jarUrl: String,
+ val mem: Int,
+ val cores: Int,
+ val supervise: Boolean,
+ val command: Command)
+ extends Serializable {
+
+ override def toString: String = s"DriverDescription (${command.mainClass})"
+}
diff --git a/core/src/main/scala/org/apache/spark/deploy/client/Client.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala
index 481026eaa2..1415e2f3d1 100644
--- a/core/src/main/scala/org/apache/spark/deploy/client/Client.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala
@@ -33,16 +33,17 @@ import org.apache.spark.deploy.master.Master
import org.apache.spark.util.AkkaUtils
/**
- * The main class used to talk to a Spark deploy cluster. Takes a master URL, an app description,
- * and a listener for cluster events, and calls back the listener when various events occur.
+ * Interface allowing applications to speak with a Spark deploy cluster. Takes a master URL,
+ * an app description, and a listener for cluster events, and calls back the listener when various
+ * events occur.
*
* @param masterUrls Each url should look like spark://host:port.
*/
-private[spark] class Client(
+private[spark] class AppClient(
actorSystem: ActorSystem,
masterUrls: Array[String],
appDescription: ApplicationDescription,
- listener: ClientListener,
+ listener: AppClientListener,
conf: SparkConf)
extends Logging {
@@ -155,7 +156,7 @@ private[spark] class Client(
case AssociationErrorEvent(cause, _, address, _) if isPossibleMaster(address) =>
logWarning(s"Could not connect to $address: $cause")
- case StopClient =>
+ case StopAppClient =>
markDead()
sender ! true
context.stop(self)
@@ -188,7 +189,7 @@ private[spark] class Client(
if (actor != null) {
try {
val timeout = AkkaUtils.askTimeout(conf)
- val future = actor.ask(StopClient)(timeout)
+ val future = actor.ask(StopAppClient)(timeout)
Await.result(future, timeout)
} catch {
case e: TimeoutException =>
diff --git a/core/src/main/scala/org/apache/spark/deploy/client/ClientListener.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClientListener.scala
index be7a11bd15..55d4ef1b31 100644
--- a/core/src/main/scala/org/apache/spark/deploy/client/ClientListener.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClientListener.scala
@@ -24,7 +24,7 @@ package org.apache.spark.deploy.client
*
* Users of this API should *not* block inside the callback methods.
*/
-private[spark] trait ClientListener {
+private[spark] trait AppClientListener {
def connected(appId: String): Unit
/** Disconnection may be a temporary state, as we fail over to a new Master. */
diff --git a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala
index 28ebbdc66b..ffa909c26b 100644
--- a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala
@@ -23,7 +23,7 @@ import org.apache.spark.deploy.{Command, ApplicationDescription}
private[spark] object TestClient {
- class TestListener extends ClientListener with Logging {
+ class TestListener extends AppClientListener with Logging {
def connected(id: String) {
logInfo("Connected to master, got app ID " + id)
}
@@ -51,7 +51,7 @@ private[spark] object TestClient {
"TestClient", Some(1), 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()),
"dummy-spark-home", "ignored")
val listener = new TestListener
- val client = new Client(actorSystem, Array(url), desc, listener, new SparkConf)
+ val client = new AppClient(actorSystem, Array(url), desc, listener, new SparkConf)
client.start()
actorSystem.awaitTermination()
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala
new file mode 100644
index 0000000000..33377931d6
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.master
+
+import java.util.Date
+
+import org.apache.spark.deploy.DriverDescription
+
+private[spark] class DriverInfo(
+ val startTime: Long,
+ val id: String,
+ val desc: DriverDescription,
+ val submitDate: Date)
+ extends Serializable {
+
+ @transient var state: DriverState.Value = DriverState.SUBMITTED
+ /* If we fail when launching the driver, the exception is stored here. */
+ @transient var exception: Option[Exception] = None
+ /* Most recent worker assigned to this driver */
+ @transient var worker: Option[WorkerInfo] = None
+}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/DriverState.scala b/core/src/main/scala/org/apache/spark/deploy/master/DriverState.scala
new file mode 100644
index 0000000000..26a68bade3
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/master/DriverState.scala
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.master
+
+private[spark] object DriverState extends Enumeration {
+
+ type DriverState = Value
+
+ // SUBMITTED: Submitted but not yet scheduled on a worker
+ // RUNNING: Has been allocated to a worker to run
+ // FINISHED: Previously ran and exited cleanly
+ // RELAUNCHING: Exited non-zero or due to worker failure, but has not yet started running again
+ // UNKNOWN: The state of the driver is temporarily not known due to master failure recovery
+ // KILLED: A user manually killed this driver
+ // FAILED: The driver exited non-zero and was not supervised
+ // ERROR: Unable to run or restart due to an unrecoverable error (e.g. missing jar file)
+ val SUBMITTED, RUNNING, FINISHED, RELAUNCHING, UNKNOWN, KILLED, FAILED, ERROR = Value
+}
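
Not part of the patch, but a small helper sketch showing which of these states the Master treats as terminal in its DriverStateChanged handler (see Master.scala below); DriverState is package-private, so this assumes code under org.apache.spark.

    import org.apache.spark.deploy.master.DriverState
    import org.apache.spark.deploy.master.DriverState.DriverState

    // Terminal states: the Master removes the driver when one of these is reported.
    def isTerminal(state: DriverState): Boolean = state match {
      case DriverState.ERROR | DriverState.FINISHED | DriverState.KILLED | DriverState.FAILED => true
      case _ => false
    }
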
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala
index 043945a211..74bb9ebf1d 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala
@@ -19,8 +19,6 @@ package org.apache.spark.deploy.master
import java.io._
-import scala.Serializable
-
import akka.serialization.Serialization
import org.apache.spark.Logging
@@ -47,6 +45,15 @@ private[spark] class FileSystemPersistenceEngine(
new File(dir + File.separator + "app_" + app.id).delete()
}
+ override def addDriver(driver: DriverInfo) {
+ val driverFile = new File(dir + File.separator + "driver_" + driver.id)
+ serializeIntoFile(driverFile, driver)
+ }
+
+ override def removeDriver(driver: DriverInfo) {
+ new File(dir + File.separator + "driver_" + driver.id).delete()
+ }
+
override def addWorker(worker: WorkerInfo) {
val workerFile = new File(dir + File.separator + "worker_" + worker.id)
serializeIntoFile(workerFile, worker)
@@ -56,13 +63,15 @@ private[spark] class FileSystemPersistenceEngine(
new File(dir + File.separator + "worker_" + worker.id).delete()
}
- override def readPersistedData(): (Seq[ApplicationInfo], Seq[WorkerInfo]) = {
+ override def readPersistedData(): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) = {
val sortedFiles = new File(dir).listFiles().sortBy(_.getName)
val appFiles = sortedFiles.filter(_.getName.startsWith("app_"))
val apps = appFiles.map(deserializeFromFile[ApplicationInfo])
+ val driverFiles = sortedFiles.filter(_.getName.startsWith("driver_"))
+ val drivers = driverFiles.map(deserializeFromFile[DriverInfo])
val workerFiles = sortedFiles.filter(_.getName.startsWith("worker_"))
val workers = workerFiles.map(deserializeFromFile[WorkerInfo])
- (apps, workers)
+ (apps, drivers, workers)
}
private def serializeIntoFile(file: File, value: AnyRef) {
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
index 6617b7100f..d9ea96afcf 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
@@ -23,19 +23,22 @@ import java.util.Date
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
import scala.concurrent.Await
import scala.concurrent.duration._
+import scala.util.Random
import akka.actor._
import akka.pattern.ask
import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent}
import akka.serialization.SerializationExtension
-import org.apache.spark.{SparkConf, SparkContext, Logging, SparkException}
-import org.apache.spark.deploy.{ApplicationDescription, ExecutorState}
+
+import org.apache.spark.{SparkConf, Logging, SparkException}
+import org.apache.spark.deploy.{ApplicationDescription, DriverDescription, ExecutorState}
import org.apache.spark.deploy.DeployMessages._
import org.apache.spark.deploy.master.MasterMessages._
import org.apache.spark.deploy.master.ui.MasterWebUI
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.util.{AkkaUtils, Utils}
+import org.apache.spark.deploy.master.DriverState.DriverState
private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Actor with Logging {
import context.dispatcher // to use Akka's scheduler.schedule()
@@ -43,13 +46,12 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
val conf = new SparkConf
val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs
- val WORKER_TIMEOUT = conf.get("spark.worker.timeout", "60").toLong * 1000
- val RETAINED_APPLICATIONS = conf.get("spark.deploy.retainedApplications", "200").toInt
- val REAPER_ITERATIONS = conf.get("spark.dead.worker.persistence", "15").toInt
+ val WORKER_TIMEOUT = conf.getLong("spark.worker.timeout", 60) * 1000
+ val RETAINED_APPLICATIONS = conf.getInt("spark.deploy.retainedApplications", 200)
+ val REAPER_ITERATIONS = conf.getInt("spark.dead.worker.persistence", 15)
val RECOVERY_DIR = conf.get("spark.deploy.recoveryDirectory", "")
val RECOVERY_MODE = conf.get("spark.deploy.recoveryMode", "NONE")
- var nextAppNumber = 0
val workers = new HashSet[WorkerInfo]
val idToWorker = new HashMap[String, WorkerInfo]
val actorToWorker = new HashMap[ActorRef, WorkerInfo]
@@ -59,9 +61,14 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
val idToApp = new HashMap[String, ApplicationInfo]
val actorToApp = new HashMap[ActorRef, ApplicationInfo]
val addressToApp = new HashMap[Address, ApplicationInfo]
-
val waitingApps = new ArrayBuffer[ApplicationInfo]
val completedApps = new ArrayBuffer[ApplicationInfo]
+ var nextAppNumber = 0
+
+ val drivers = new HashSet[DriverInfo]
+ val completedDrivers = new ArrayBuffer[DriverInfo]
+ val waitingDrivers = new ArrayBuffer[DriverInfo] // Drivers currently spooled for scheduling
+ var nextDriverNumber = 0
Utils.checkHost(host, "Expected hostname")
@@ -142,14 +149,14 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
override def receive = {
case ElectedLeader => {
- val (storedApps, storedWorkers) = persistenceEngine.readPersistedData()
- state = if (storedApps.isEmpty && storedWorkers.isEmpty)
+ val (storedApps, storedDrivers, storedWorkers) = persistenceEngine.readPersistedData()
+ state = if (storedApps.isEmpty && storedDrivers.isEmpty && storedWorkers.isEmpty)
RecoveryState.ALIVE
else
RecoveryState.RECOVERING
logInfo("I have been elected leader! New state: " + state)
if (state == RecoveryState.RECOVERING) {
- beginRecovery(storedApps, storedWorkers)
+ beginRecovery(storedApps, storedDrivers, storedWorkers)
context.system.scheduler.scheduleOnce(WORKER_TIMEOUT millis) { completeRecovery() }
}
}
@@ -176,6 +183,69 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
}
}
+ case RequestSubmitDriver(description) => {
+ if (state != RecoveryState.ALIVE) {
+ val msg = s"Can only accept driver submissions in ALIVE state. Current state: $state."
+ sender ! SubmitDriverResponse(false, None, msg)
+ } else {
+ logInfo("Driver submitted " + description.command.mainClass)
+ val driver = createDriver(description)
+ persistenceEngine.addDriver(driver)
+ waitingDrivers += driver
+ drivers.add(driver)
+ schedule()
+
+ // TODO: It might be good to instead have the submission client poll the master to determine
+ // the current status of the driver. For now it's simply "fire and forget".
+
+ sender ! SubmitDriverResponse(true, Some(driver.id),
+ s"Driver successfully submitted as ${driver.id}")
+ }
+ }
+
+ case RequestKillDriver(driverId) => {
+ if (state != RecoveryState.ALIVE) {
+ val msg = s"Can only kill drivers in ALIVE state. Current state: $state."
+ sender ! KillDriverResponse(driverId, success = false, msg)
+ } else {
+ logInfo("Asked to kill driver " + driverId)
+ val driver = drivers.find(_.id == driverId)
+ driver match {
+ case Some(d) =>
+ if (waitingDrivers.contains(d)) {
+ waitingDrivers -= d
+ self ! DriverStateChanged(driverId, DriverState.KILLED, None)
+ }
+ else {
+ // We just notify the worker to kill the driver here. The final bookkeeping occurs
+ // on the return path when the worker submits a state change back to the master
+ // to notify it that the driver was successfully killed.
+ d.worker.foreach { w =>
+ w.actor ! KillDriver(driverId)
+ }
+ }
+ // TODO: It would be nice for this to be a synchronous response
+ val msg = s"Kill request for $driverId submitted"
+ logInfo(msg)
+ sender ! KillDriverResponse(driverId, success = true, msg)
+ case None =>
+ val msg = s"Driver $driverId has already finished or does not exist"
+ logWarning(msg)
+ sender ! KillDriverResponse(driverId, success = false, msg)
+ }
+ }
+ }
+
+ case RequestDriverStatus(driverId) => {
+ (drivers ++ completedDrivers).find(_.id == driverId) match {
+ case Some(driver) =>
+ sender ! DriverStatusResponse(found = true, Some(driver.state),
+ driver.worker.map(_.id), driver.worker.map(_.hostPort), driver.exception)
+ case None =>
+ sender ! DriverStatusResponse(found = false, None, None, None, None)
+ }
+ }
+
case RegisterApplication(description) => {
if (state == RecoveryState.STANDBY) {
// ignore, don't send response
@@ -218,6 +288,15 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
}
}
+ case DriverStateChanged(driverId, state, exception) => {
+ state match {
+ case DriverState.ERROR | DriverState.FINISHED | DriverState.KILLED | DriverState.FAILED =>
+ removeDriver(driverId, state, exception)
+ case _ =>
+ throw new Exception(s"Received unexpected state update for driver $driverId: $state")
+ }
+ }
+
case Heartbeat(workerId) => {
idToWorker.get(workerId) match {
case Some(workerInfo) =>
@@ -239,7 +318,7 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
if (canCompleteRecovery) { completeRecovery() }
}
- case WorkerSchedulerStateResponse(workerId, executors) => {
+ case WorkerSchedulerStateResponse(workerId, executors, driverIds) => {
idToWorker.get(workerId) match {
case Some(worker) =>
logInfo("Worker has been re-registered: " + workerId)
@@ -252,6 +331,14 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
worker.addExecutor(execInfo)
execInfo.copyState(exec)
}
+
+ for (driverId <- driverIds) {
+ drivers.find(_.id == driverId).foreach { driver =>
+ driver.worker = Some(worker)
+ driver.state = DriverState.RUNNING
+ worker.drivers(driverId) = driver
+ }
+ }
case None =>
logWarning("Scheduler state from unknown worker: " + workerId)
}
@@ -269,7 +356,7 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
case RequestMasterState => {
sender ! MasterStateResponse(host, port, workers.toArray, apps.toArray, completedApps.toArray,
- state)
+ drivers.toArray, completedDrivers.toArray, state)
}
case CheckForWorkerTimeOut => {
@@ -285,7 +372,8 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
workers.count(_.state == WorkerState.UNKNOWN) == 0 &&
apps.count(_.state == ApplicationState.UNKNOWN) == 0
- def beginRecovery(storedApps: Seq[ApplicationInfo], storedWorkers: Seq[WorkerInfo]) {
+ def beginRecovery(storedApps: Seq[ApplicationInfo], storedDrivers: Seq[DriverInfo],
+ storedWorkers: Seq[WorkerInfo]) {
for (app <- storedApps) {
logInfo("Trying to recover app: " + app.id)
try {
@@ -297,6 +385,12 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
}
}
+ for (driver <- storedDrivers) {
+ // Here we just read in the list of drivers. Any drivers associated with now-lost workers
+ // will be re-launched when we detect that the worker is missing.
+ drivers += driver
+ }
+
for (worker <- storedWorkers) {
logInfo("Trying to recover worker: " + worker.id)
try {
@@ -320,6 +414,18 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
workers.filter(_.state == WorkerState.UNKNOWN).foreach(removeWorker)
apps.filter(_.state == ApplicationState.UNKNOWN).foreach(finishApplication)
+ // Reschedule drivers which were not claimed by any workers
+ drivers.filter(_.worker.isEmpty).foreach { d =>
+ logWarning(s"Driver ${d.id} was not found after master recovery")
+ if (d.desc.supervise) {
+ logWarning(s"Re-launching ${d.id}")
+ relaunchDriver(d)
+ } else {
+ removeDriver(d.id, DriverState.ERROR, None)
+ logWarning(s"Did not re-launch ${d.id} because it was not supervised")
+ }
+ }
+
state = RecoveryState.ALIVE
schedule()
logInfo("Recovery complete - resuming operations!")
@@ -340,6 +446,18 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
*/
def schedule() {
if (state != RecoveryState.ALIVE) { return }
+
+ // First schedule drivers, they take strict precedence over applications
+ val shuffledWorkers = Random.shuffle(workers) // Randomization helps balance drivers
+ for (worker <- shuffledWorkers if worker.state == WorkerState.ALIVE) {
+ for (driver <- waitingDrivers) {
+ if (worker.memoryFree >= driver.desc.mem && worker.coresFree >= driver.desc.cores) {
+ launchDriver(worker, driver)
+ waitingDrivers -= driver
+ }
+ }
+ }
+
// Right now this is a very simple FIFO scheduler. We keep trying to fit in the first app
// in the queue, then the second app, etc.
if (spreadOutApps) {
@@ -426,9 +544,25 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
exec.id, ExecutorState.LOST, Some("worker lost"), None)
exec.application.removeExecutor(exec)
}
+ for (driver <- worker.drivers.values) {
+ if (driver.desc.supervise) {
+ logInfo(s"Re-launching ${driver.id}")
+ relaunchDriver(driver)
+ } else {
+ logInfo(s"Not re-launching ${driver.id} because it was not supervised")
+ removeDriver(driver.id, DriverState.ERROR, None)
+ }
+ }
persistenceEngine.removeWorker(worker)
}
+ def relaunchDriver(driver: DriverInfo) {
+ driver.worker = None
+ driver.state = DriverState.RELAUNCHING
+ waitingDrivers += driver
+ schedule()
+ }
+
def createApplication(desc: ApplicationDescription, driver: ActorRef): ApplicationInfo = {
val now = System.currentTimeMillis()
val date = new Date(now)
@@ -508,6 +642,41 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
}
}
}
+
+ def newDriverId(submitDate: Date): String = {
+ val appId = "driver-%s-%04d".format(DATE_FORMAT.format(submitDate), nextDriverNumber)
+ nextDriverNumber += 1
+ appId
+ }
+
+ def createDriver(desc: DriverDescription): DriverInfo = {
+ val now = System.currentTimeMillis()
+ val date = new Date(now)
+ new DriverInfo(now, newDriverId(date), desc, date)
+ }
+
+ def launchDriver(worker: WorkerInfo, driver: DriverInfo) {
+ logInfo("Launching driver " + driver.id + " on worker " + worker.id)
+ worker.addDriver(driver)
+ driver.worker = Some(worker)
+ worker.actor ! LaunchDriver(driver.id, driver.desc)
+ driver.state = DriverState.RUNNING
+ }
+
+ def removeDriver(driverId: String, finalState: DriverState, exception: Option[Exception]) {
+ drivers.find(d => d.id == driverId) match {
+ case Some(driver) =>
+ logInfo(s"Removing driver: $driverId")
+ drivers -= driver
+ completedDrivers += driver
+ persistenceEngine.removeDriver(driver)
+ driver.state = finalState
+ driver.exception = exception
+ driver.worker.foreach(w => w.removeDriver(driver))
+ case None =>
+ logWarning(s"Asked to remove unknown driver: $driverId")
+ }
+ }
}
private[spark] object Master {
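
To make the driver-scheduling pass above concrete, a worked example with hypothetical numbers: drivers are placed before any executors, and a driver fits on a worker only if both its memory and core requirements fit in that worker's free resources.

    // Hypothetical worker and driver, mirroring the check in schedule() above.
    val coresFree = 4; val memoryFree = 2048      // worker.coresFree / worker.memoryFree (MB)
    val driverCores = 1; val driverMem = 512      // driver.desc.cores / driver.desc.mem
    val canLaunch = memoryFree >= driverMem && coresFree >= driverCores   // true
    // After launchDriver(), WorkerInfo.addDriver() leaves 3 cores and 1536 MB for executors.
    // newDriverId() would name such a driver e.g. "driver-20140113162126-0000" (hypothetical timestamp).
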
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala
index 94b986caf2..e3640ea4f7 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala
@@ -35,11 +35,15 @@ private[spark] trait PersistenceEngine {
def removeWorker(worker: WorkerInfo)
+ def addDriver(driver: DriverInfo)
+
+ def removeDriver(driver: DriverInfo)
+
/**
* Returns the persisted data sorted by their respective ids (which implies that they're
* sorted by time of creation).
*/
- def readPersistedData(): (Seq[ApplicationInfo], Seq[WorkerInfo])
+ def readPersistedData(): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo])
def close() {}
}
@@ -49,5 +53,8 @@ private[spark] class BlackHolePersistenceEngine extends PersistenceEngine {
override def removeApplication(app: ApplicationInfo) {}
override def addWorker(worker: WorkerInfo) {}
override def removeWorker(worker: WorkerInfo) {}
- override def readPersistedData() = (Nil, Nil)
+ override def addDriver(driver: DriverInfo) {}
+ override def removeDriver(driver: DriverInfo) {}
+
+ override def readPersistedData() = (Nil, Nil, Nil)
}
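
For illustration only, a minimal in-memory sketch of the extended trait; it shows the shape of the three-way readPersistedData() contract but offers no real fault tolerance, unlike the file-system and ZooKeeper engines. It assumes placement under the org.apache.spark package, since the trait is package-private.

    import scala.collection.mutable
    import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, PersistenceEngine, WorkerInfo}

    // Sketch: everything lives in this process, so it is lost if the Master dies.
    class InMemoryPersistenceEngine extends PersistenceEngine {
      private val apps = mutable.LinkedHashMap[String, ApplicationInfo]()
      private val drivers = mutable.LinkedHashMap[String, DriverInfo]()
      private val workers = mutable.LinkedHashMap[String, WorkerInfo]()

      override def addApplication(app: ApplicationInfo) { apps(app.id) = app }
      override def removeApplication(app: ApplicationInfo) { apps -= app.id }
      override def addDriver(driver: DriverInfo) { drivers(driver.id) = driver }
      override def removeDriver(driver: DriverInfo) { drivers -= driver.id }
      override def addWorker(worker: WorkerInfo) { workers(worker.id) = worker }
      override def removeWorker(worker: WorkerInfo) { workers -= worker.id }

      override def readPersistedData(): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) =
        (apps.values.toSeq, drivers.values.toSeq, workers.values.toSeq)
    }
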
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala
index e05f587b58..c5fa9cf7d7 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala
@@ -17,8 +17,10 @@
package org.apache.spark.deploy.master
-import akka.actor.ActorRef
import scala.collection.mutable
+
+import akka.actor.ActorRef
+
import org.apache.spark.util.Utils
private[spark] class WorkerInfo(
@@ -35,7 +37,8 @@ private[spark] class WorkerInfo(
Utils.checkHost(host, "Expected hostname")
assert (port > 0)
- @transient var executors: mutable.HashMap[String, ExecutorInfo] = _ // fullId => info
+ @transient var executors: mutable.HashMap[String, ExecutorInfo] = _ // executorId => info
+ @transient var drivers: mutable.HashMap[String, DriverInfo] = _ // driverId => info
@transient var state: WorkerState.Value = _
@transient var coresUsed: Int = _
@transient var memoryUsed: Int = _
@@ -54,6 +57,7 @@ private[spark] class WorkerInfo(
private def init() {
executors = new mutable.HashMap
+ drivers = new mutable.HashMap
state = WorkerState.ALIVE
coresUsed = 0
memoryUsed = 0
@@ -83,6 +87,18 @@ private[spark] class WorkerInfo(
executors.values.exists(_.application == app)
}
+ def addDriver(driver: DriverInfo) {
+ drivers(driver.id) = driver
+ memoryUsed += driver.desc.mem
+ coresUsed += driver.desc.cores
+ }
+
+ def removeDriver(driver: DriverInfo) {
+ drivers -= driver.id
+ memoryUsed -= driver.desc.mem
+ coresUsed -= driver.desc.cores
+ }
+
def webUiAddress : String = {
"http://" + this.publicAddress + ":" + this.webUiPort
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
index 52000d4f9c..f24f49ea8a 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
@@ -49,6 +49,14 @@ class ZooKeeperPersistenceEngine(serialization: Serialization, conf: SparkConf)
zk.delete(WORKING_DIR + "/app_" + app.id)
}
+ override def addDriver(driver: DriverInfo) {
+ serializeIntoFile(WORKING_DIR + "/driver_" + driver.id, driver)
+ }
+
+ override def removeDriver(driver: DriverInfo) {
+ zk.delete(WORKING_DIR + "/driver_" + driver.id)
+ }
+
override def addWorker(worker: WorkerInfo) {
serializeIntoFile(WORKING_DIR + "/worker_" + worker.id, worker)
}
@@ -61,13 +69,15 @@ class ZooKeeperPersistenceEngine(serialization: Serialization, conf: SparkConf)
zk.close()
}
- override def readPersistedData(): (Seq[ApplicationInfo], Seq[WorkerInfo]) = {
+ override def readPersistedData(): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) = {
val sortedFiles = zk.getChildren(WORKING_DIR).toList.sorted
val appFiles = sortedFiles.filter(_.startsWith("app_"))
val apps = appFiles.map(deserializeFromFile[ApplicationInfo])
+ val driverFiles = sortedFiles.filter(_.startsWith("driver_"))
+ val drivers = driverFiles.map(deserializeFromFile[DriverInfo])
val workerFiles = sortedFiles.filter(_.startsWith("worker_"))
val workers = workerFiles.map(deserializeFromFile[WorkerInfo])
- (apps, workers)
+ (apps, drivers, workers)
}
private def serializeIntoFile(path: String, value: AnyRef) {
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
index dbb0cb90f5..9485bfd89e 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
@@ -67,11 +67,11 @@ private[spark] class ApplicationPage(parent: MasterWebUI) {
<li><strong>User:</strong> {app.desc.user}</li>
<li><strong>Cores:</strong>
{
- if (app.desc.maxCores == Integer.MAX_VALUE) {
+ if (app.desc.maxCores == None) {
"Unlimited (%s granted)".format(app.coresGranted)
} else {
"%s (%s granted, %s left)".format(
- app.desc.maxCores, app.coresGranted, app.coresLeft)
+ app.desc.maxCores.get, app.coresGranted, app.coresLeft)
}
}
</li>
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala
index 4ef762892c..a9af8df552 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala
@@ -18,6 +18,7 @@
package org.apache.spark.deploy.master.ui
import scala.concurrent.Await
+import scala.concurrent.duration._
import scala.xml.Node
import akka.pattern.ask
@@ -26,7 +27,7 @@ import net.liftweb.json.JsonAST.JValue
import org.apache.spark.deploy.{DeployWebUI, JsonProtocol}
import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState}
-import org.apache.spark.deploy.master.{ApplicationInfo, WorkerInfo}
+import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, WorkerInfo}
import org.apache.spark.ui.UIUtils
import org.apache.spark.util.Utils
@@ -56,6 +57,16 @@ private[spark] class IndexPage(parent: MasterWebUI) {
val completedApps = state.completedApps.sortBy(_.endTime).reverse
val completedAppsTable = UIUtils.listingTable(appHeaders, appRow, completedApps)
+ val driverHeaders = Seq("ID", "Submitted Time", "Worker", "State", "Cores", "Memory", "Main Class")
+ val activeDrivers = state.activeDrivers.sortBy(_.startTime).reverse
+ val activeDriversTable = UIUtils.listingTable(driverHeaders, driverRow, activeDrivers)
+ val completedDrivers = state.completedDrivers.sortBy(_.startTime).reverse
+ val completedDriversTable = UIUtils.listingTable(driverHeaders, driverRow, completedDrivers)
+
+ // For now we only show driver information if the user has submitted drivers to the cluster.
+ // This is until we integrate the notion of drivers and applications in the UI.
+ def hasDrivers = activeDrivers.length > 0 || completedDrivers.length > 0
+
val content =
<div class="row-fluid">
<div class="span12">
@@ -70,6 +81,9 @@ private[spark] class IndexPage(parent: MasterWebUI) {
<li><strong>Applications:</strong>
{state.activeApps.size} Running,
{state.completedApps.size} Completed </li>
+ <li><strong>Drivers:</strong>
+ {state.activeDrivers.size} Running,
+ {state.completedDrivers.size} Completed </li>
</ul>
</div>
</div>
@@ -84,17 +98,39 @@ private[spark] class IndexPage(parent: MasterWebUI) {
<div class="row-fluid">
<div class="span12">
<h4> Running Applications </h4>
-
{activeAppsTable}
</div>
</div>
+ <div>
+ {if (hasDrivers)
+ <div class="row-fluid">
+ <div class="span12">
+ <h4> Running Drivers </h4>
+ {activeDriversTable}
+ </div>
+ </div>
+ }
+ </div>
+
<div class="row-fluid">
<div class="span12">
<h4> Completed Applications </h4>
{completedAppsTable}
</div>
+ </div>
+
+ <div>
+ {if (hasDrivers)
+ <div class="row-fluid">
+ <div class="span12">
+ <h4> Completed Drivers </h4>
+ {completedDriversTable}
+ </div>
+ </div>
+ }
</div>;
+
UIUtils.basicSparkPage(content, "Spark Master at " + state.uri)
}
@@ -134,4 +170,20 @@ private[spark] class IndexPage(parent: MasterWebUI) {
<td>{DeployWebUI.formatDuration(app.duration)}</td>
</tr>
}
+
+ def driverRow(driver: DriverInfo): Seq[Node] = {
+ <tr>
+ <td>{driver.id} </td>
+ <td>{driver.submitDate}</td>
+ <td>{driver.worker.map(w => <a href={w.webUiAddress}>{w.id.toString}</a>).getOrElse("None")}</td>
+ <td>{driver.state}</td>
+ <td sorttable_customkey={driver.desc.cores.toString}>
+ {driver.desc.cores}
+ </td>
+ <td sorttable_customkey={driver.desc.mem.toString}>
+ {Utils.megabytesToString(driver.desc.mem.toLong)}
+ </td>
+ <td>{driver.desc.command.arguments(1)}</td>
+ </tr>
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala
new file mode 100644
index 0000000000..7507bf8ad0
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala
@@ -0,0 +1,63 @@
+package org.apache.spark.deploy.worker
+
+import java.io.{File, FileOutputStream, IOException, InputStream}
+import java.lang.System._
+
+import org.apache.spark.Logging
+import org.apache.spark.deploy.Command
+import org.apache.spark.util.Utils
+
+/**
+ * Utilities for running commands with the Spark classpath.
+ */
+object CommandUtils extends Logging {
+ private[spark] def buildCommandSeq(command: Command, memory: Int, sparkHome: String): Seq[String] = {
+ val runner = getEnv("JAVA_HOME", command).map(_ + "/bin/java").getOrElse("java")
+
+ // SPARK-698: do not call the run.cmd script, as process.destroy()
+ // fails to kill a process tree on Windows
+ Seq(runner) ++ buildJavaOpts(command, memory, sparkHome) ++ Seq(command.mainClass) ++
+ command.arguments
+ }
+
+ private def getEnv(key: String, command: Command): Option[String] =
+ command.environment.get(key).orElse(Option(System.getenv(key)))
+
+ /**
+ * Attention: this must always be aligned with the environment variables in the run scripts and
+ * the way the JAVA_OPTS are assembled there.
+ */
+ def buildJavaOpts(command: Command, memory: Int, sparkHome: String): Seq[String] = {
+ val libraryOpts = getEnv("SPARK_LIBRARY_PATH", command)
+ .map(p => List("-Djava.library.path=" + p))
+ .getOrElse(Nil)
+ val workerLocalOpts = Option(getenv("SPARK_JAVA_OPTS")).map(Utils.splitCommandString).getOrElse(Nil)
+ val userOpts = getEnv("SPARK_JAVA_OPTS", command).map(Utils.splitCommandString).getOrElse(Nil)
+ val memoryOpts = Seq(s"-Xms${memory}M", s"-Xmx${memory}M")
+
+ // Figure out our classpath with the external compute-classpath script
+ val ext = if (System.getProperty("os.name").startsWith("Windows")) ".cmd" else ".sh"
+ val classPath = Utils.executeAndGetOutput(
+ Seq(sparkHome + "/bin/compute-classpath" + ext),
+ extraEnvironment=command.environment)
+
+ Seq("-cp", classPath) ++ libraryOpts ++ workerLocalOpts ++ userOpts ++ memoryOpts
+ }
+
+ /** Spawn a thread that will redirect a given stream to a file */
+ def redirectStream(in: InputStream, file: File) {
+ val out = new FileOutputStream(file, true)
+ // TODO: It would be nice to add a shutdown hook here that explains why the output is
+ // terminating. Otherwise if the worker dies the executor logs will silently stop.
+ new Thread("redirect output to " + file) {
+ override def run() {
+ try {
+ Utils.copyStream(in, out, true)
+ } catch {
+ case e: IOException =>
+ logInfo("Redirection to " + file + " closed: " + e.getMessage)
+ }
+ }
+ }.start()
+ }
+}
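
A rough sketch of what buildCommandSeq above produces for a hypothetical Command. It assumes the snippet runs inside the org.apache.spark.deploy packages (several of these types are package-private) and that sparkHome contains the compute-classpath script, whose output determines the actual classpath.

    import org.apache.spark.deploy.Command
    import org.apache.spark.deploy.worker.CommandUtils

    // Hypothetical command and Spark home, for illustration only.
    val cmd = Command("org.example.MyApp", Seq("arg1", "arg2"),
      Map("JAVA_HOME" -> "/usr/lib/jvm/java"))
    val javaCommand = CommandUtils.buildCommandSeq(cmd, memory = 512, sparkHome = "/opt/spark")
    // Roughly: Seq("/usr/lib/jvm/java/bin/java", "-cp", <compute-classpath output>,
    //              "-Xms512M", "-Xmx512M", "org.example.MyApp", "arg1", "arg2")
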
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
new file mode 100644
index 0000000000..b4df1a0dd4
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
@@ -0,0 +1,234 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.worker
+
+import java.io._
+
+import scala.collection.JavaConversions._
+import scala.collection.mutable.Map
+
+import akka.actor.ActorRef
+import com.google.common.base.Charsets
+import com.google.common.io.Files
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileUtil, Path}
+
+import org.apache.spark.Logging
+import org.apache.spark.deploy.{Command, DriverDescription}
+import org.apache.spark.deploy.DeployMessages.DriverStateChanged
+import org.apache.spark.deploy.master.DriverState
+import org.apache.spark.deploy.master.DriverState.DriverState
+
+/**
+ * Manages the execution of one driver, including automatically restarting the driver on failure.
+ */
+private[spark] class DriverRunner(
+ val driverId: String,
+ val workDir: File,
+ val sparkHome: File,
+ val driverDesc: DriverDescription,
+ val worker: ActorRef,
+ val workerUrl: String)
+ extends Logging {
+
+ @volatile var process: Option[Process] = None
+ @volatile var killed = false
+
+ // Populated once finished
+ var finalState: Option[DriverState] = None
+ var finalException: Option[Exception] = None
+ var finalExitCode: Option[Int] = None
+
+ // Decoupled for testing
+ private[deploy] def setClock(_clock: Clock) = clock = _clock
+ private[deploy] def setSleeper(_sleeper: Sleeper) = sleeper = _sleeper
+ private var clock = new Clock {
+ def currentTimeMillis(): Long = System.currentTimeMillis()
+ }
+ private var sleeper = new Sleeper {
+    def sleep(seconds: Int): Unit = (0 until seconds).takeWhile(_ => { Thread.sleep(1000); !killed })
+ }
+
+ /** Starts a thread to run and manage the driver. */
+ def start() = {
+ new Thread("DriverRunner for " + driverId) {
+ override def run() {
+ try {
+ val driverDir = createWorkingDirectory()
+ val localJarFilename = downloadUserJar(driverDir)
+
+ // Make sure user application jar is on the classpath
+ // TODO: If we add ability to submit multiple jars they should also be added here
+ val env = Map(driverDesc.command.environment.toSeq: _*)
+ env("SPARK_CLASSPATH") = env.getOrElse("SPARK_CLASSPATH", "") + s":$localJarFilename"
+ val newCommand = Command(driverDesc.command.mainClass,
+ driverDesc.command.arguments.map(substituteVariables), env)
+ val command = CommandUtils.buildCommandSeq(newCommand, driverDesc.mem,
+ sparkHome.getAbsolutePath)
+ launchDriver(command, env, driverDir, driverDesc.supervise)
+ }
+ catch {
+ case e: Exception => finalException = Some(e)
+ }
+
+ val state =
+ if (killed) { DriverState.KILLED }
+ else if (finalException.isDefined) { DriverState.ERROR }
+ else {
+ finalExitCode match {
+ case Some(0) => DriverState.FINISHED
+ case _ => DriverState.FAILED
+ }
+ }
+
+ finalState = Some(state)
+
+ worker ! DriverStateChanged(driverId, state, finalException)
+ }
+ }.start()
+ }
+
+ /** Terminate this driver (or prevent it from ever starting if not yet started) */
+ def kill() {
+ synchronized {
+ process.foreach(p => p.destroy())
+ killed = true
+ }
+ }
+
+ /** Replace variables in a command argument passed to us */
+ private def substituteVariables(argument: String): String = argument match {
+ case "{{WORKER_URL}}" => workerUrl
+ case other => other
+ }
+
+ /**
+ * Creates the working directory for this driver.
+ * Will throw an exception if there are errors preparing the directory.
+ */
+ private def createWorkingDirectory(): File = {
+ val driverDir = new File(workDir, driverId)
+ if (!driverDir.exists() && !driverDir.mkdirs()) {
+ throw new IOException("Failed to create directory " + driverDir)
+ }
+ driverDir
+ }
+
+ /**
+ * Download the user jar into the supplied directory and return its local path.
+ * Will throw an exception if there are errors downloading the jar.
+ */
+ private def downloadUserJar(driverDir: File): String = {
+
+ val jarPath = new Path(driverDesc.jarUrl)
+
+ val emptyConf = new Configuration()
+ val jarFileSystem = jarPath.getFileSystem(emptyConf)
+
+ val destPath = new File(driverDir.getAbsolutePath, jarPath.getName)
+ val jarFileName = jarPath.getName
+ val localJarFile = new File(driverDir, jarFileName)
+ val localJarFilename = localJarFile.getAbsolutePath
+
+ if (!localJarFile.exists()) { // May already exist if running multiple workers on one node
+ logInfo(s"Copying user jar $jarPath to $destPath")
+ FileUtil.copy(jarFileSystem, jarPath, destPath, false, emptyConf)
+ }
+
+ if (!localJarFile.exists()) { // Verify copy succeeded
+ throw new Exception(s"Did not see expected jar $jarFileName in $driverDir")
+ }
+
+ localJarFilename
+ }
+
+ private def launchDriver(command: Seq[String], envVars: Map[String, String], baseDir: File,
+ supervise: Boolean) {
+ val builder = new ProcessBuilder(command: _*).directory(baseDir)
+    envVars.foreach { case (k, v) => builder.environment().put(k, v) }
+
+ def initialize(process: Process) = {
+ // Redirect stdout and stderr to files
+ val stdout = new File(baseDir, "stdout")
+ CommandUtils.redirectStream(process.getInputStream, stdout)
+
+ val stderr = new File(baseDir, "stderr")
+ val header = "Launch Command: %s\n%s\n\n".format(
+ command.mkString("\"", "\" \"", "\""), "=" * 40)
+ Files.append(header, stderr, Charsets.UTF_8)
+ CommandUtils.redirectStream(process.getErrorStream, stderr)
+ }
+ runCommandWithRetry(ProcessBuilderLike(builder), initialize, supervise)
+ }
+
+ private[deploy] def runCommandWithRetry(command: ProcessBuilderLike, initialize: Process => Unit,
+ supervise: Boolean) {
+ // Time to wait between submission retries.
+ var waitSeconds = 1
+ // A run of this many seconds resets the exponential back-off.
+ val successfulRunDuration = 5
+
+ var keepTrying = !killed
+
+ while (keepTrying) {
+ logInfo("Launch Command: " + command.command.mkString("\"", "\" \"", "\""))
+
+ synchronized {
+ if (killed) { return }
+ process = Some(command.start())
+ initialize(process.get)
+ }
+
+ val processStart = clock.currentTimeMillis()
+ val exitCode = process.get.waitFor()
+ if (clock.currentTimeMillis() - processStart > successfulRunDuration * 1000) {
+ waitSeconds = 1
+ }
+
+ if (supervise && exitCode != 0 && !killed) {
+ logInfo(s"Command exited with status $exitCode, re-launching after $waitSeconds s.")
+ sleeper.sleep(waitSeconds)
+ waitSeconds = waitSeconds * 2 // exponential back-off
+ }
+
+ keepTrying = supervise && exitCode != 0 && !killed
+ finalExitCode = Some(exitCode)
+ }
+ }
+}
+
+private[deploy] trait Clock {
+ def currentTimeMillis(): Long
+}
+
+private[deploy] trait Sleeper {
+ def sleep(seconds: Int)
+}
+
+// Needed because ProcessBuilder is a final class and cannot be mocked
+private[deploy] trait ProcessBuilderLike {
+ def start(): Process
+ def command: Seq[String]
+}
+
+private[deploy] object ProcessBuilderLike {
+ def apply(processBuilder: ProcessBuilder) = new ProcessBuilderLike {
+ def start() = processBuilder.start()
+ def command = processBuilder.command()
+ }
+}
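
The Clock, Sleeper and ProcessBuilderLike seams above exist so that runCommandWithRetry can be exercised without forking real processes. A test-style sketch, hypothetical and assuming it sits in the org.apache.spark.deploy package since the seams are private[deploy]:

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
    import org.apache.spark.deploy.worker.{Clock, DriverRunner, ProcessBuilderLike, Sleeper}

    // Sketch: drive the retry loop with stubbed seams instead of real processes.
    def runOnce(runner: DriverRunner): Unit = {
      runner.setClock(new Clock { def currentTimeMillis(): Long = 0L })  // back-off never resets
      runner.setSleeper(new Sleeper { def sleep(seconds: Int) {} })      // skip real sleeping

      val fakeBuilder = new ProcessBuilderLike {
        def command: Seq[String] = Seq("fake-driver")
        def start(): Process = new Process {            // always "exits" with status 1
          def waitFor(): Int = 1
          def exitValue(): Int = 1
          def destroy() {}
          def getInputStream = new ByteArrayInputStream(Array[Byte]())
          def getErrorStream = new ByteArrayInputStream(Array[Byte]())
          def getOutputStream = new ByteArrayOutputStream()
        }
      }
      // supervise = false runs the command exactly once; with supervise = true the waits
      // between retries grow 1, 2, 4, ... seconds (reset after a run longer than 5 seconds).
      runner.runCommandWithRetry(fakeBuilder, _ => (), supervise = false)
    }
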
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala
new file mode 100644
index 0000000000..1640d5fee0
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala
@@ -0,0 +1,31 @@
+package org.apache.spark.deploy.worker
+
+import akka.actor._
+
+import org.apache.spark.SparkConf
+import org.apache.spark.util.{AkkaUtils, Utils}
+
+/**
+ * Utility object for launching driver programs such that they share fate with the Worker process.
+ */
+object DriverWrapper {
+ def main(args: Array[String]) {
+ args.toList match {
+ case workerUrl :: mainClass :: extraArgs =>
+ val (actorSystem, _) = AkkaUtils.createActorSystem("Driver",
+ Utils.localHostName(), 0, false, new SparkConf())
+ actorSystem.actorOf(Props(classOf[WorkerWatcher], workerUrl), name = "workerWatcher")
+
+ // Delegate to supplied main class
+      val clazz = Class.forName(mainClass)
+ val mainMethod = clazz.getMethod("main", classOf[Array[String]])
+ mainMethod.invoke(null, extraArgs.toArray[String])
+
+ actorSystem.shutdown()
+
+ case _ =>
+ System.err.println("Usage: DriverWrapper <workerUrl> <driverMainClass> [options]")
+ System.exit(-1)
+ }
+ }
+} \ No newline at end of file
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
index fff9cb60c7..18885d7ca6 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
@@ -18,17 +18,15 @@
package org.apache.spark.deploy.worker
import java.io._
-import java.lang.System.getenv
import akka.actor.ActorRef
import com.google.common.base.Charsets
import com.google.common.io.Files
-import org.apache.spark.{Logging}
-import org.apache.spark.deploy.{ExecutorState, ApplicationDescription}
+import org.apache.spark.Logging
+import org.apache.spark.deploy.{ExecutorState, ApplicationDescription, Command}
import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged
-import org.apache.spark.util.Utils
/**
* Manages the execution of one executor process.
@@ -44,16 +42,17 @@ private[spark] class ExecutorRunner(
val host: String,
val sparkHome: File,
val workDir: File,
+ val workerUrl: String,
var state: ExecutorState.Value)
extends Logging {
val fullId = appId + "/" + execId
var workerThread: Thread = null
var process: Process = null
- var shutdownHook: Thread = null
- private def getAppEnv(key: String): Option[String] =
- appDesc.command.environment.get(key).orElse(Option(getenv(key)))
+ // NOTE: This is now redundant with the automated shut-down enforced by the Executor. It might
+ // make sense to remove this in the future.
+ var shutdownHook: Thread = null
def start() {
workerThread = new Thread("ExecutorRunner for " + fullId) {
@@ -92,55 +91,17 @@ private[spark] class ExecutorRunner(
/** Replace variables such as {{EXECUTOR_ID}} and {{CORES}} in a command argument passed to us */
def substituteVariables(argument: String): String = argument match {
+ case "{{WORKER_URL}}" => workerUrl
case "{{EXECUTOR_ID}}" => execId.toString
case "{{HOSTNAME}}" => host
case "{{CORES}}" => cores.toString
case other => other
}
- def buildCommandSeq(): Seq[String] = {
- val command = appDesc.command
- val runner = getAppEnv("JAVA_HOME").map(_ + "/bin/java").getOrElse("java")
- // SPARK-698: do not call the run.cmd script, as process.destroy()
- // fails to kill a process tree on Windows
- Seq(runner) ++ buildJavaOpts() ++ Seq(command.mainClass) ++
- (command.arguments ++ Seq(appId)).map(substituteVariables)
- }
-
- /**
- * Attention: this must always be aligned with the environment variables in the run scripts and
- * the way the JAVA_OPTS are assembled there.
- */
- def buildJavaOpts(): Seq[String] = {
- val libraryOpts = getAppEnv("SPARK_LIBRARY_PATH")
- .map(p => List("-Djava.library.path=" + p))
- .getOrElse(Nil)
- val workerLocalOpts = Option(getenv("SPARK_JAVA_OPTS")).map(Utils.splitCommandString).getOrElse(Nil)
- val userOpts = getAppEnv("SPARK_JAVA_OPTS").map(Utils.splitCommandString).getOrElse(Nil)
- val memoryOpts = Seq("-Xms" + memory + "M", "-Xmx" + memory + "M")
-
- // Figure out our classpath with the external compute-classpath script
- val ext = if (System.getProperty("os.name").startsWith("Windows")) ".cmd" else ".sh"
- val classPath = Utils.executeAndGetOutput(
- Seq(sparkHome + "/bin/compute-classpath" + ext),
- extraEnvironment=appDesc.command.environment)
-
- Seq("-cp", classPath) ++ libraryOpts ++ workerLocalOpts ++ userOpts ++ memoryOpts
- }
-
- /** Spawn a thread that will redirect a given stream to a file */
- def redirectStream(in: InputStream, file: File) {
- val out = new FileOutputStream(file, true)
- new Thread("redirect output to " + file) {
- override def run() {
- try {
- Utils.copyStream(in, out, true)
- } catch {
- case e: IOException =>
- logInfo("Redirection to " + file + " closed: " + e.getMessage)
- }
- }
- }.start()
+ def getCommandSeq = {
+ val command = Command(appDesc.command.mainClass,
+ appDesc.command.arguments.map(substituteVariables) ++ Seq(appId), appDesc.command.environment)
+ CommandUtils.buildCommandSeq(command, memory, sparkHome.getAbsolutePath)
}
/**
@@ -155,7 +116,7 @@ private[spark] class ExecutorRunner(
}
// Launch the process
- val command = buildCommandSeq()
+ val command = getCommandSeq
logInfo("Launch command: " + command.mkString("\"", "\" \"", "\""))
val builder = new ProcessBuilder(command: _*).directory(executorDir)
val env = builder.environment()
@@ -172,11 +133,11 @@ private[spark] class ExecutorRunner(
// Redirect its stdout and stderr to files
val stdout = new File(executorDir, "stdout")
- redirectStream(process.getInputStream, stdout)
+ CommandUtils.redirectStream(process.getInputStream, stdout)
val stderr = new File(executorDir, "stderr")
Files.write(header, stderr, Charsets.UTF_8)
- redirectStream(process.getErrorStream, stderr)
+ CommandUtils.redirectStream(process.getErrorStream, stderr)
// Wait for it to exit; this is actually a bad thing if it happens, because we expect to run
// long-lived processes only. However, in the future, we might restart the executor a few
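The command-construction change above is easiest to see in isolation: the runner no longer assembles Java options itself, it substitutes placeholders into the application's arguments and delegates the rest to CommandUtils. A minimal, self-contained sketch of just the substitution step, with hypothetical values standing in for the constructor fields:

    // Sketch only: the placeholder names match the diff; the concrete values are invented.
    object SubstituteSketch {
      def substitute(workerUrl: String, execId: Int, host: String, cores: Int)
                    (argument: String): String = argument match {
        case "{{WORKER_URL}}" => workerUrl
        case "{{EXECUTOR_ID}}" => execId.toString
        case "{{HOSTNAME}}" => host
        case "{{CORES}}" => cores.toString
        case other => other
      }

      def main(args: Array[String]) {
        val raw = Seq("{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}", "{{WORKER_URL}}")
        val url = "akka.tcp://sparkWorker@127.0.0.1:7078/user/Worker"  // hypothetical worker URL
        println(raw.map(substitute(url, 3, "127.0.0.1", 4)).mkString(" "))
      }
    }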
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
index fcaf4e92b1..5182dcbb2a 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
@@ -26,10 +26,12 @@ import scala.concurrent.duration._
import akka.actor._
import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent}
+
import org.apache.spark.{Logging, SparkConf, SparkException}
import org.apache.spark.deploy.{ExecutorDescription, ExecutorState}
import org.apache.spark.deploy.DeployMessages._
-import org.apache.spark.deploy.master.Master
+import org.apache.spark.deploy.master.{DriverState, Master}
+import org.apache.spark.deploy.master.DriverState.DriverState
import org.apache.spark.deploy.worker.ui.WorkerWebUI
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.util.{AkkaUtils, Utils}
@@ -44,6 +46,8 @@ private[spark] class Worker(
cores: Int,
memory: Int,
masterUrls: Array[String],
+ actorSystemName: String,
+ actorName: String,
workDirPath: String = null,
val conf: SparkConf)
extends Actor with Logging {
@@ -55,7 +59,7 @@ private[spark] class Worker(
val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For worker and executor IDs
// Send a heartbeat every (heartbeat timeout) / 4 milliseconds
- val HEARTBEAT_MILLIS = conf.get("spark.worker.timeout", "60").toLong * 1000 / 4
+ val HEARTBEAT_MILLIS = conf.getLong("spark.worker.timeout", 60) * 1000 / 4
val REGISTRATION_TIMEOUT = 20.seconds
val REGISTRATION_RETRIES = 3
@@ -68,6 +72,7 @@ private[spark] class Worker(
var masterAddress: Address = null
var activeMasterUrl: String = ""
var activeMasterWebUiUrl : String = ""
+ val akkaUrl = "akka.tcp://%s@%s:%s/user/%s".format(actorSystemName, host, port, actorName)
@volatile var registered = false
@volatile var connected = false
val workerId = generateWorkerId()
@@ -75,6 +80,9 @@ private[spark] class Worker(
var workDir: File = null
val executors = new HashMap[String, ExecutorRunner]
val finishedExecutors = new HashMap[String, ExecutorRunner]
+ val drivers = new HashMap[String, DriverRunner]
+ val finishedDrivers = new HashMap[String, DriverRunner]
+
val publicAddress = {
val envVar = System.getenv("SPARK_PUBLIC_DNS")
if (envVar != null) envVar else host
@@ -185,7 +193,10 @@ private[spark] class Worker(
val execs = executors.values.
map(e => new ExecutorDescription(e.appId, e.execId, e.cores, e.state))
- sender ! WorkerSchedulerStateResponse(workerId, execs.toList)
+ sender ! WorkerSchedulerStateResponse(workerId, execs.toList, drivers.keys.toSeq)
+
+ case Heartbeat =>
+ logInfo(s"Received heartbeat from driver ${sender.path}")
case RegisterWorkerFailed(message) =>
if (!registered) {
@@ -199,7 +210,7 @@ private[spark] class Worker(
} else {
logInfo("Asked to launch executor %s/%d for %s".format(appId, execId, appDesc.name))
val manager = new ExecutorRunner(appId, execId, appDesc, cores_, memory_,
- self, workerId, host, new File(execSparkHome_), workDir, ExecutorState.RUNNING)
+ self, workerId, host, new File(execSparkHome_), workDir, akkaUrl, ExecutorState.RUNNING)
executors(appId + "/" + execId) = manager
manager.start()
coresUsed += cores_
@@ -219,8 +230,8 @@ private[spark] class Worker(
logInfo("Executor " + fullId + " finished with state " + state +
message.map(" message " + _).getOrElse("") +
exitStatus.map(" exitStatus " + _).getOrElse(""))
- finishedExecutors(fullId) = executor
executors -= fullId
+ finishedExecutors(fullId) = executor
coresUsed -= executor.cores
memoryUsed -= executor.memory
}
@@ -239,13 +250,52 @@ private[spark] class Worker(
}
}
+ case LaunchDriver(driverId, driverDesc) => {
+ logInfo(s"Asked to launch driver $driverId")
+ val driver = new DriverRunner(driverId, workDir, sparkHome, driverDesc, self, akkaUrl)
+ drivers(driverId) = driver
+ driver.start()
+
+ coresUsed += driverDesc.cores
+ memoryUsed += driverDesc.mem
+ }
+
+ case KillDriver(driverId) => {
+ logInfo(s"Asked to kill driver $driverId")
+ drivers.get(driverId) match {
+ case Some(runner) =>
+ runner.kill()
+ case None =>
+ logError(s"Asked to kill unknown driver $driverId")
+ }
+ }
+
+ case DriverStateChanged(driverId, state, exception) => {
+ state match {
+ case DriverState.ERROR =>
+ logWarning(s"Driver $driverId failed with unrecoverable exception: ${exception.get}")
+ case DriverState.FINISHED =>
+ logInfo(s"Driver $driverId exited successfully")
+ case DriverState.KILLED =>
+ logInfo(s"Driver $driverId was killed by user")
+ }
+ masterLock.synchronized {
+ master ! DriverStateChanged(driverId, state, exception)
+ }
+ val driver = drivers.remove(driverId).get
+ finishedDrivers(driverId) = driver
+ memoryUsed -= driver.driverDesc.mem
+ coresUsed -= driver.driverDesc.cores
+ }
+
case x: DisassociatedEvent if x.remoteAddress == masterAddress =>
logInfo(s"$x Disassociated !")
masterDisconnected()
case RequestWorkerState => {
sender ! WorkerStateResponse(host, port, workerId, executors.values.toList,
- finishedExecutors.values.toList, activeMasterUrl, cores, memory,
+ finishedExecutors.values.toList, drivers.values.toList,
+ finishedDrivers.values.toList, activeMasterUrl, cores, memory,
coresUsed, memoryUsed, activeMasterWebUiUrl)
}
}
@@ -282,10 +332,11 @@ private[spark] object Worker {
// The LocalSparkCluster runs multiple local sparkWorkerX actor systems
val conf = new SparkConf
val systemName = "sparkWorker" + workerNumber.map(_.toString).getOrElse("")
+ val actorName = "Worker"
val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port,
conf = conf)
actorSystem.actorOf(Props(classOf[Worker], host, boundPort, webUiPort, cores, memory,
- masterUrls, workDir, conf), name = "Worker")
+ masterUrls, systemName, actorName, workDir, conf), name = actorName)
(actorSystem, boundPort)
}
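The Worker now accounts for drivers the same way it accounts for executors: LaunchDriver reserves cores and memory, and DriverStateChanged releases them and moves the runner into the finished map. A hedged sketch of just that bookkeeping, with a stand-in case class instead of the real DriverRunner/DriverDescription types:

    import scala.collection.mutable.HashMap

    object DriverBookkeepingSketch {
      case class DriverInfo(id: String, cores: Int, mem: Int)   // stand-in type

      val drivers = new HashMap[String, DriverInfo]
      val finishedDrivers = new HashMap[String, DriverInfo]
      var coresUsed = 0
      var memoryUsed = 0

      def launch(d: DriverInfo) {
        drivers(d.id) = d
        coresUsed += d.cores
        memoryUsed += d.mem
      }

      def finished(id: String) {
        val d = drivers.remove(id).get        // as in the diff, assumes the id is known
        finishedDrivers(id) = d
        coresUsed -= d.cores
        memoryUsed -= d.mem
      }

      def main(args: Array[String]) {
        launch(DriverInfo("driver-20140113-0001", cores = 1, mem = 512))
        finished("driver-20140113-0001")
        println((coresUsed, memoryUsed))      // (0,0)
      }
    }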
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala
new file mode 100644
index 0000000000..0e0d0cd626
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala
@@ -0,0 +1,55 @@
+package org.apache.spark.deploy.worker
+
+import akka.actor.{Actor, Address, AddressFromURIString}
+import akka.remote.{AssociatedEvent, AssociationErrorEvent, AssociationEvent, DisassociatedEvent, RemotingLifecycleEvent}
+
+import org.apache.spark.Logging
+import org.apache.spark.deploy.DeployMessages.SendHeartbeat
+
+/**
+ * Actor which connects to a worker process and terminates the JVM if the connection is severed.
+ * Provides fate sharing between a worker and its associated child processes.
+ */
+private[spark] class WorkerWatcher(workerUrl: String) extends Actor
+ with Logging {
+ override def preStart() {
+ context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
+
+ logInfo(s"Connecting to worker $workerUrl")
+ val worker = context.actorSelection(workerUrl)
+ worker ! SendHeartbeat // need to send a message here to initiate connection
+ }
+
+ // Used to avoid shutting down JVM during tests
+ private[deploy] var isShutDown = false
+ private[deploy] def setTesting(testing: Boolean) = isTesting = testing
+ private var isTesting = false
+
+ // Lets us filter events only from the worker's actor system
+ private val expectedHostPort = AddressFromURIString(workerUrl).hostPort
+ private def isWorker(address: Address) = address.hostPort == expectedHostPort
+
+ def exitNonZero() = if (isTesting) isShutDown = true else System.exit(-1)
+
+ override def receive = {
+ case AssociatedEvent(localAddress, remoteAddress, inbound) if isWorker(remoteAddress) =>
+ logInfo(s"Successfully connected to $workerUrl")
+
+ case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound)
+ if isWorker(remoteAddress) =>
+ // These logs may not be seen if the worker (and associated pipe) has died
+ logError(s"Could not initialize connection to worker $workerUrl. Exiting.")
+ logError(s"Error was: $cause")
+ exitNonZero()
+
+ case DisassociatedEvent(localAddress, remoteAddress, inbound) if isWorker(remoteAddress) =>
+ // This log message will never be seen
+ logError(s"Lost connection to worker actor $workerUrl. Exiting.")
+ exitNonZero()
+
+ case e: AssociationEvent =>
+ // pass through association events relating to other remote actor systems
+
+ case e => logWarning(s"Received unexpected actor system event: $e")
+ }
+}
\ No newline at end of file
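The new actor is only useful once a child process registers it against the worker's Akka URL (the CoarseGrainedExecutorBackend change further down does exactly that). A minimal sketch of that wiring; it assumes an actor system configured for remoting and a worker actually listening at the given address, otherwise the watcher exits the JVM by design:

    // Placed in this package only because WorkerWatcher is private[spark].
    package org.apache.spark.deploy.worker

    import akka.actor.{ActorSystem, Props}

    object WorkerWatcherWiringSketch {
      def main(args: Array[String]) {
        // Hypothetical URL; the real one is built by the Worker as
        // akka.tcp://<actorSystemName>@<host>:<port>/user/<actorName>.
        val workerUrl = "akka.tcp://sparkWorker@127.0.0.1:7078/user/Worker"
        val system = ActorSystem("childProcess")
        system.actorOf(Props(classOf[WorkerWatcher], workerUrl), name = "WorkerWatcher")
        system.awaitTermination()
      }
    }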
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala
index 0d59048313..925c6fb183 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala
@@ -17,24 +17,20 @@
package org.apache.spark.deploy.worker.ui
-import javax.servlet.http.HttpServletRequest
-
-import scala.xml.Node
-
-import scala.concurrent.duration._
import scala.concurrent.Await
+import scala.xml.Node
import akka.pattern.ask
-
+import javax.servlet.http.HttpServletRequest
import net.liftweb.json.JsonAST.JValue
import org.apache.spark.deploy.JsonProtocol
import org.apache.spark.deploy.DeployMessages.{RequestWorkerState, WorkerStateResponse}
-import org.apache.spark.deploy.worker.ExecutorRunner
+import org.apache.spark.deploy.master.DriverState
+import org.apache.spark.deploy.worker.{DriverRunner, ExecutorRunner}
import org.apache.spark.ui.UIUtils
import org.apache.spark.util.Utils
-
private[spark] class IndexPage(parent: WorkerWebUI) {
val workerActor = parent.worker.self
val worker = parent.worker
@@ -56,6 +52,16 @@ private[spark] class IndexPage(parent: WorkerWebUI) {
val finishedExecutorTable =
UIUtils.listingTable(executorHeaders, executorRow, workerState.finishedExecutors)
+ val driverHeaders = Seq("DriverID", "Main Class", "State", "Cores", "Memory", "Logs", "Notes")
+ val runningDrivers = workerState.drivers.sortBy(_.driverId).reverse
+ val runningDriverTable = UIUtils.listingTable(driverHeaders, driverRow, runningDrivers)
+ val finishedDrivers = workerState.finishedDrivers.sortBy(_.driverId).reverse
+ def finishedDriverTable = UIUtils.listingTable(driverHeaders, driverRow, finishedDrivers)
+
+ // For now we only show driver information if the user has submitted drivers to the cluster.
+ // This is until we integrate the notion of drivers and applications in the UI.
+ def hasDrivers = runningDrivers.length > 0 || finishedDrivers.length > 0
+
val content =
<div class="row-fluid"> <!-- Worker Details -->
<div class="span12">
@@ -79,11 +85,33 @@ private[spark] class IndexPage(parent: WorkerWebUI) {
</div>
</div>
+ <div>
+ {if (hasDrivers)
+ <div class="row-fluid"> <!-- Running Drivers -->
+ <div class="span12">
+ <h4> Running Drivers {workerState.drivers.size} </h4>
+ {runningDriverTable}
+ </div>
+ </div>
+ }
+ </div>
+
<div class="row-fluid"> <!-- Finished Executors -->
<div class="span12">
<h4> Finished Executors </h4>
{finishedExecutorTable}
</div>
+ </div>
+
+ <div>
+ {if (hasDrivers)
+ <div class="row-fluid"> <!-- Finished Drivers -->
+ <div class="span12">
+ <h4> Finished Drivers </h4>
+ {finishedDriverTable}
+ </div>
+ </div>
+ }
</div>;
UIUtils.basicSparkPage(content, "Spark Worker at %s:%s".format(
@@ -111,6 +139,27 @@ private[spark] class IndexPage(parent: WorkerWebUI) {
.format(executor.appId, executor.execId)}>stderr</a>
</td>
</tr>
+
}
+ def driverRow(driver: DriverRunner): Seq[Node] = {
+ <tr>
+ <td>{driver.driverId}</td>
+ <td>{driver.driverDesc.command.arguments(1)}</td>
+ <td>{driver.finalState.getOrElse(DriverState.RUNNING)}</td>
+ <td sorttable_customkey={driver.driverDesc.cores.toString}>
+ {driver.driverDesc.cores.toString}
+ </td>
+ <td sorttable_customkey={driver.driverDesc.mem.toString}>
+ {Utils.megabytesToString(driver.driverDesc.mem)}
+ </td>
+ <td>
+ <a href={s"logPage?driverId=${driver.driverId}&logType=stdout"}>stdout</a>
+ <a href={s"logPage?driverId=${driver.driverId}&logType=stderr"}>stderr</a>
+ </td>
+ <td>
+ {driver.finalException.getOrElse("")}
+ </td>
+ </tr>
+ }
}
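The driver table above follows the same header/row-function pattern as the existing executor table. A small generic version of that pattern, using only scala.xml and a hypothetical Driver case class, shows how a row function plugs into a listing table:

    import scala.xml.Node

    object ListingTableSketch {
      def listingTable[T](headers: Seq[String], row: T => Seq[Node], data: Seq[T]): Seq[Node] =
        <table>
          <tr>{headers.map(h => <th>{h}</th>)}</tr>
          {data.flatMap(row)}
        </table>

      case class Driver(id: String, mainClass: String, cores: Int)   // hypothetical row type
      def driverRow(d: Driver): Seq[Node] =
        <tr><td>{d.id}</td><td>{d.mainClass}</td><td>{d.cores}</td></tr>

      def main(args: Array[String]) {
        println(listingTable[Driver](Seq("DriverID", "Main Class", "Cores"),
          driverRow, Seq(Driver("driver-1", "org.example.Main", 2))))
      }
    }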
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
index c382034c99..8daa47b2b2 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
@@ -69,30 +69,48 @@ class WorkerWebUI(val worker: Worker, val workDir: File, requestedPort: Option[I
def log(request: HttpServletRequest): String = {
val defaultBytes = 100 * 1024
- val appId = request.getParameter("appId")
- val executorId = request.getParameter("executorId")
+
+ val appId = Option(request.getParameter("appId"))
+ val executorId = Option(request.getParameter("executorId"))
+ val driverId = Option(request.getParameter("driverId"))
val logType = request.getParameter("logType")
val offset = Option(request.getParameter("offset")).map(_.toLong)
val byteLength = Option(request.getParameter("byteLength")).map(_.toInt).getOrElse(defaultBytes)
- val path = "%s/%s/%s/%s".format(workDir.getPath, appId, executorId, logType)
+
+ val path = (appId, executorId, driverId) match {
+      case (Some(a), Some(e), None) =>
+        s"${workDir.getPath}/$a/$e/$logType"
+      case (None, None, Some(d)) =>
+        s"${workDir.getPath}/$d/$logType"
+ case _ =>
+ throw new Exception("Request must specify either application or driver identifiers")
+ }
val (startByte, endByte) = getByteRange(path, offset, byteLength)
val file = new File(path)
val logLength = file.length
- val pre = "==== Bytes %s-%s of %s of %s/%s/%s ====\n"
- .format(startByte, endByte, logLength, appId, executorId, logType)
+ val pre = s"==== Bytes $startByte-$endByte of $logLength of $path ====\n"
pre + Utils.offsetBytes(path, startByte, endByte)
}
def logPage(request: HttpServletRequest): Seq[scala.xml.Node] = {
val defaultBytes = 100 * 1024
- val appId = request.getParameter("appId")
- val executorId = request.getParameter("executorId")
+ val appId = Option(request.getParameter("appId"))
+ val executorId = Option(request.getParameter("executorId"))
+ val driverId = Option(request.getParameter("driverId"))
val logType = request.getParameter("logType")
val offset = Option(request.getParameter("offset")).map(_.toLong)
val byteLength = Option(request.getParameter("byteLength")).map(_.toInt).getOrElse(defaultBytes)
- val path = "%s/%s/%s/%s".format(workDir.getPath, appId, executorId, logType)
+
+ val (path, params) = (appId, executorId, driverId) match {
+ case (Some(a), Some(e), None) =>
+ (s"${workDir.getPath}/$a/$e/$logType", s"appId=$a&executorId=$e")
+ case (None, None, Some(d)) =>
+ (s"${workDir.getPath}/$d/$logType", s"driverId=$d")
+ case _ =>
+ throw new Exception("Request must specify either application or driver identifiers")
+ }
val (startByte, endByte) = getByteRange(path, offset, byteLength)
val file = new File(path)
@@ -106,9 +124,8 @@ class WorkerWebUI(val worker: Worker, val workDir: File, requestedPort: Option[I
val backButton =
if (startByte > 0) {
- <a href={"?appId=%s&executorId=%s&logType=%s&offset=%s&byteLength=%s"
- .format(appId, executorId, logType, math.max(startByte-byteLength, 0),
- byteLength)}>
+ <a href={"?%s&logType=%s&offset=%s&byteLength=%s"
+ .format(params, logType, math.max(startByte-byteLength, 0), byteLength)}>
<button type="button" class="btn btn-default">
Previous {Utils.bytesToString(math.min(byteLength, startByte))}
</button>
@@ -122,8 +139,8 @@ class WorkerWebUI(val worker: Worker, val workDir: File, requestedPort: Option[I
val nextButton =
if (endByte < logLength) {
- <a href={"?appId=%s&executorId=%s&logType=%s&offset=%s&byteLength=%s".
- format(appId, executorId, logType, endByte, byteLength)}>
+ <a href={"?%s&logType=%s&offset=%s&byteLength=%s".
+ format(params, logType, endByte, byteLength)}>
<button type="button" class="btn btn-default">
Next {Utils.bytesToString(math.min(byteLength, logLength-endByte))}
</button>
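The parameter handling in both log endpoints reduces to one rule: a request names either an application/executor pair or a driver, never both, and the path is built accordingly. Isolated as a sketch (names are hypothetical):

    object LogPathSketch {
      def logPath(workDir: String, appId: Option[String], executorId: Option[String],
                  driverId: Option[String], logType: String): String =
        (appId, executorId, driverId) match {
          case (Some(a), Some(e), None) => s"$workDir/$a/$e/$logType"
          case (None, None, Some(d))    => s"$workDir/$d/$logType"
          case _ =>
            throw new IllegalArgumentException(
              "Request must specify either application or driver identifiers")
        }

      def main(args: Array[String]) {
        println(logPath("/tmp/work", Some("app-1"), Some("3"), None, "stdout"))  // /tmp/work/app-1/3/stdout
        println(logPath("/tmp/work", None, None, Some("driver-1"), "stderr"))    // /tmp/work/driver-1/stderr
      }
    }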
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index 53a2b94a52..45b43b403d 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -24,8 +24,9 @@ import akka.remote._
import org.apache.spark.{SparkConf, SparkContext, Logging}
import org.apache.spark.TaskState.TaskState
+import org.apache.spark.deploy.worker.WorkerWatcher
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
-import org.apache.spark.util.{Utils, AkkaUtils}
+import org.apache.spark.util.{AkkaUtils, Utils}
private[spark] class CoarseGrainedExecutorBackend(
driverUrl: String,
@@ -91,7 +92,8 @@ private[spark] class CoarseGrainedExecutorBackend(
}
private[spark] object CoarseGrainedExecutorBackend {
- def run(driverUrl: String, executorId: String, hostname: String, cores: Int) {
+ def run(driverUrl: String, executorId: String, hostname: String, cores: Int,
+ workerUrl: Option[String]) {
// Debug code
Utils.checkHost(hostname)
@@ -101,21 +103,27 @@ private[spark] object CoarseGrainedExecutorBackend {
indestructible = true, conf = new SparkConf)
// set it
val sparkHostPort = hostname + ":" + boundPort
-// conf.set("spark.hostPort", sparkHostPort)
actorSystem.actorOf(
Props(classOf[CoarseGrainedExecutorBackend], driverUrl, executorId, sparkHostPort, cores),
name = "Executor")
+ workerUrl.foreach{ url =>
+ actorSystem.actorOf(Props(classOf[WorkerWatcher], url), name = "WorkerWatcher")
+ }
actorSystem.awaitTermination()
}
def main(args: Array[String]) {
- if (args.length < 4) {
- //the reason we allow the last appid argument is to make it easy to kill rogue executors
- System.err.println(
- "Usage: CoarseGrainedExecutorBackend <driverUrl> <executorId> <hostname> <cores> " +
- "[<appid>]")
- System.exit(1)
+ args.length match {
+ case x if x < 4 =>
+ System.err.println(
+ // Worker url is used in spark standalone mode to enforce fate-sharing with worker
+ "Usage: CoarseGrainedExecutorBackend <driverUrl> <executorId> <hostname> " +
+ "<cores> [<workerUrl>]")
+ System.exit(1)
+ case 4 =>
+ run(args(0), args(1), args(2), args(3).toInt, None)
+ case x if x > 4 =>
+ run(args(0), args(1), args(2), args(3).toInt, Some(args(4)))
}
- run(args(0), args(1), args(2), args(3).toInt)
}
}
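The usage is now CoarseGrainedExecutorBackend <driverUrl> <executorId> <hostname> <cores> [<workerUrl>]; the optional fifth argument becomes an Option so standalone mode can opt into fate-sharing while other cluster managers pass nothing. The dispatch, isolated as a sketch:

    object OptionalArgSketch {
      def run(driverUrl: String, executorId: String, hostname: String, cores: Int,
              workerUrl: Option[String]) {
        println(s"driver=$driverUrl executor=$executorId host=$hostname cores=$cores worker=$workerUrl")
      }

      def main(args: Array[String]) {
        if (args.length < 4) {
          System.err.println("Usage: ... <driverUrl> <executorId> <hostname> <cores> [<workerUrl>]")
          System.exit(1)
        }
        val workerUrl = if (args.length > 4) Some(args(4)) else None
        run(args(0), args(1), args(2), args(3).toInt, workerUrl)
      }
    }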
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index e51d274d33..7f31d7e6f8 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -57,7 +57,7 @@ private[spark] class Executor(
Utils.setCustomHostname(slaveHostname)
// Set spark.* properties from executor arg
- val conf = new SparkConf(false)
+ val conf = new SparkConf(true)
conf.setAll(properties)
// If we are in yarn mode, systems can have different disk layouts so we must set it
@@ -279,6 +279,11 @@ private[spark] class Executor(
//System.exit(1)
}
} finally {
+ // TODO: Unregister shuffle memory only for ShuffleMapTask
+ val shuffleMemoryMap = env.shuffleMemoryMap
+ shuffleMemoryMap.synchronized {
+ shuffleMemoryMap.remove(Thread.currentThread().getId)
+ }
runningTasks.remove(taskId)
}
}
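The new finally block keeps the per-thread shuffle-memory map from leaking entries when a task ends, whether it returned or threw. A minimal sketch of that pattern with a plain HashMap (the real map lives on SparkEnv):

    import scala.collection.mutable.HashMap

    object ShuffleMemorySketch {
      val shuffleMemoryMap = new HashMap[Long, Long]

      def runTask(bytes: Long)(body: => Unit) {
        val tid = Thread.currentThread().getId
        shuffleMemoryMap.synchronized { shuffleMemoryMap(tid) = bytes }
        try {
          body
        } finally {
          // mirrors the cleanup added in the diff: always drop this thread's entry
          shuffleMemoryMap.synchronized { shuffleMemoryMap.remove(tid) }
        }
      }

      def main(args: Array[String]) {
        runTask(64L * 1024 * 1024) { println("task body") }
        println(shuffleMemoryMap.isEmpty)   // true: entry removed on every exit path
      }
    }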
diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala
index a1e98845f6..5980177320 100644
--- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala
+++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala
@@ -71,7 +71,7 @@ class LZFCompressionCodec(conf: SparkConf) extends CompressionCodec {
class SnappyCompressionCodec(conf: SparkConf) extends CompressionCodec {
override def compressedOutputStream(s: OutputStream): OutputStream = {
- val blockSize = conf.get("spark.io.compression.snappy.block.size", "32768").toInt
+ val blockSize = conf.getInt("spark.io.compression.snappy.block.size", 32768)
new SnappyOutputStream(s, blockSize)
}
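This is one instance of a change repeated throughout the patch: string lookups followed by .toInt/.toLong are replaced by SparkConf's typed getters, so the default is written in its natural type. A quick illustration:

    import org.apache.spark.SparkConf

    object ConfGetterSketch {
      def main(args: Array[String]) {
        val conf = new SparkConf(false)
        // falls back to the typed default when the key is unset
        println(conf.getInt("spark.io.compression.snappy.block.size", 32768))   // 32768
        conf.set("spark.io.compression.snappy.block.size", "65536")
        println(conf.getInt("spark.io.compression.snappy.block.size", 32768))   // 65536
      }
    }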
diff --git a/core/src/main/scala/org/apache/spark/network/BufferMessage.scala b/core/src/main/scala/org/apache/spark/network/BufferMessage.scala
index f736bb3713..fb4c65909a 100644
--- a/core/src/main/scala/org/apache/spark/network/BufferMessage.scala
+++ b/core/src/main/scala/org/apache/spark/network/BufferMessage.scala
@@ -46,7 +46,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId:
throw new Exception("Max chunk size is " + maxChunkSize)
}
- if (size == 0 && gotChunkForSendingOnce == false) {
+ if (size == 0 && !gotChunkForSendingOnce) {
val newChunk = new MessageChunk(
new MessageChunkHeader(typ, id, 0, 0, ackId, senderAddress), null)
gotChunkForSendingOnce = true
diff --git a/core/src/main/scala/org/apache/spark/network/Connection.scala b/core/src/main/scala/org/apache/spark/network/Connection.scala
index 95cb0206ac..cba8477ed5 100644
--- a/core/src/main/scala/org/apache/spark/network/Connection.scala
+++ b/core/src/main/scala/org/apache/spark/network/Connection.scala
@@ -330,7 +330,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
// Is highly unlikely unless there was an unclean close of socket, etc
registerInterest()
logInfo("Connected to [" + address + "], " + outbox.messages.size + " messages pending")
- return true
+ true
} catch {
case e: Exception => {
logWarning("Error finishing connection to " + address, e)
@@ -385,7 +385,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
}
}
// should not happen - to keep scala compiler happy
- return true
+ true
}
// This is a hack to determine if remote socket was closed or not.
@@ -559,7 +559,7 @@ private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : S
}
}
// should not happen - to keep scala compiler happy
- return true
+ true
}
def onReceive(callback: (Connection, Message) => Unit) {onReceiveCallback = callback}
diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
index 46c40d0a2a..e6e01783c8 100644
--- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
+++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
@@ -54,22 +54,22 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi
private val selector = SelectorProvider.provider.openSelector()
private val handleMessageExecutor = new ThreadPoolExecutor(
- conf.get("spark.core.connection.handler.threads.min", "20").toInt,
- conf.get("spark.core.connection.handler.threads.max", "60").toInt,
- conf.get("spark.core.connection.handler.threads.keepalive", "60").toInt, TimeUnit.SECONDS,
+ conf.getInt("spark.core.connection.handler.threads.min", 20),
+ conf.getInt("spark.core.connection.handler.threads.max", 60),
+ conf.getInt("spark.core.connection.handler.threads.keepalive", 60), TimeUnit.SECONDS,
new LinkedBlockingDeque[Runnable]())
private val handleReadWriteExecutor = new ThreadPoolExecutor(
- conf.get("spark.core.connection.io.threads.min", "4").toInt,
- conf.get("spark.core.connection.io.threads.max", "32").toInt,
- conf.get("spark.core.connection.io.threads.keepalive", "60").toInt, TimeUnit.SECONDS,
+ conf.getInt("spark.core.connection.io.threads.min", 4),
+ conf.getInt("spark.core.connection.io.threads.max", 32),
+ conf.getInt("spark.core.connection.io.threads.keepalive", 60), TimeUnit.SECONDS,
new LinkedBlockingDeque[Runnable]())
// Use a different, yet smaller, thread pool - infrequently used with very short lived tasks : which should be executed asap
private val handleConnectExecutor = new ThreadPoolExecutor(
- conf.get("spark.core.connection.connect.threads.min", "1").toInt,
- conf.get("spark.core.connection.connect.threads.max", "8").toInt,
- conf.get("spark.core.connection.connect.threads.keepalive", "60").toInt, TimeUnit.SECONDS,
+ conf.getInt("spark.core.connection.connect.threads.min", 1),
+ conf.getInt("spark.core.connection.connect.threads.max", 8),
+ conf.getInt("spark.core.connection.connect.threads.keepalive", 60), TimeUnit.SECONDS,
new LinkedBlockingDeque[Runnable]())
private val serverChannel = ServerSocketChannel.open()
diff --git a/core/src/main/scala/org/apache/spark/network/Message.scala b/core/src/main/scala/org/apache/spark/network/Message.scala
index f2ecc6d439..2612884bdb 100644
--- a/core/src/main/scala/org/apache/spark/network/Message.scala
+++ b/core/src/main/scala/org/apache/spark/network/Message.scala
@@ -61,7 +61,7 @@ private[spark] object Message {
if (dataBuffers.exists(_ == null)) {
throw new Exception("Attempting to create buffer message with null buffer")
}
- return new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer] ++= dataBuffers, ackId)
+ new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer] ++= dataBuffers, ackId)
}
def createBufferMessage(dataBuffers: Seq[ByteBuffer]): BufferMessage =
@@ -69,9 +69,9 @@ private[spark] object Message {
def createBufferMessage(dataBuffer: ByteBuffer, ackId: Int): BufferMessage = {
if (dataBuffer == null) {
- return createBufferMessage(Array(ByteBuffer.allocate(0)), ackId)
+ createBufferMessage(Array(ByteBuffer.allocate(0)), ackId)
} else {
- return createBufferMessage(Array(dataBuffer), ackId)
+ createBufferMessage(Array(dataBuffer), ackId)
}
}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala b/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala
index b729eb11c5..d87157e12c 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala
@@ -36,7 +36,7 @@ private[spark] class ShuffleCopier(conf: SparkConf) extends Logging {
resultCollectCallback: (BlockId, Long, ByteBuf) => Unit) {
val handler = new ShuffleCopier.ShuffleClientHandler(resultCollectCallback)
- val connectTimeout = conf.get("spark.shuffle.netty.connect.timeout", "60000").toInt
+ val connectTimeout = conf.getInt("spark.shuffle.netty.connect.timeout", 60000)
val fc = new FileClient(handler, connectTimeout)
try {
diff --git a/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala b/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala
index 546d921067..44204a8c46 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala
@@ -64,7 +64,7 @@ private[spark] object ShuffleSender {
val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
val subDir = new File(localDirs(dirId), "%02x".format(subDirId))
val file = new File(subDir, blockId.name)
- return new FileSegment(file, 0, file.length())
+ new FileSegment(file, 0, file.length())
}
}
val sender = new ShuffleSender(port, pResovler)
diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
index 6d4f46125f..83109d1a6f 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
@@ -97,7 +97,7 @@ private[spark] object CheckpointRDD extends Logging {
throw new IOException("Checkpoint failed: temporary path " +
tempOutputPath + " already exists")
}
- val bufferSize = env.conf.get("spark.buffer.size", "65536").toInt
+ val bufferSize = env.conf.getInt("spark.buffer.size", 65536)
val fileOutputStream = if (blockSize < 0) {
fs.create(tempOutputPath, false, bufferSize)
@@ -131,7 +131,7 @@ private[spark] object CheckpointRDD extends Logging {
): Iterator[T] = {
val env = SparkEnv.get
val fs = path.getFileSystem(broadcastedConf.value.value)
- val bufferSize = env.conf.get("spark.buffer.size", "65536").toInt
+ val bufferSize = env.conf.getInt("spark.buffer.size", 65536)
val fileInputStream = fs.open(path, bufferSize)
val serializer = env.serializer.newInstance()
val deserializeStream = serializer.deserializeStream(fileInputStream)
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
index 4ba4696fef..a73714abca 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -23,8 +23,7 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.{InterruptibleIterator, Partition, Partitioner, SparkEnv, TaskContext}
import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency}
-import org.apache.spark.util.AppendOnlyMap
-
+import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap}
private[spark] sealed trait CoGroupSplitDep extends Serializable
@@ -44,14 +43,12 @@ private[spark] case class NarrowCoGroupSplitDep(
private[spark] case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep
-private[spark]
-class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep])
+private[spark] class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep])
extends Partition with Serializable {
override val index: Int = idx
override def hashCode(): Int = idx
}
-
/**
 * An RDD that cogroups its parents. For each key k in parent RDDs, the resulting RDD contains a
* tuple with the list of values for that key.
@@ -62,6 +59,14 @@ class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep])
class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: Partitioner)
extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) {
+ // For example, `(k, a) cogroup (k, b)` produces k -> Seq(ArrayBuffer as, ArrayBuffer bs).
+ // Each ArrayBuffer is represented as a CoGroup, and the resulting Seq as a CoGroupCombiner.
+ // CoGroupValue is the intermediate state of each value before being merged in compute.
+ private type CoGroup = ArrayBuffer[Any]
+ private type CoGroupValue = (Any, Int) // Int is dependency number
+ private type CoGroupCombiner = Seq[CoGroup]
+
+ private val sparkConf = SparkEnv.get.conf
private var serializerClass: String = null
def setSerializer(cls: String): CoGroupedRDD[K] = {
@@ -100,37 +105,74 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
override val partitioner = Some(part)
- override def compute(s: Partition, context: TaskContext): Iterator[(K, Seq[Seq[_]])] = {
+ override def compute(s: Partition, context: TaskContext): Iterator[(K, CoGroupCombiner)] = {
+ val externalSorting = sparkConf.getBoolean("spark.shuffle.externalSorting", true)
val split = s.asInstanceOf[CoGroupPartition]
val numRdds = split.deps.size
- // e.g. for `(k, a) cogroup (k, b)`, K -> Seq(ArrayBuffer as, ArrayBuffer bs)
- val map = new AppendOnlyMap[K, Seq[ArrayBuffer[Any]]]
- val update: (Boolean, Seq[ArrayBuffer[Any]]) => Seq[ArrayBuffer[Any]] = (hadVal, oldVal) => {
- if (hadVal) oldVal else Array.fill(numRdds)(new ArrayBuffer[Any])
- }
-
- val getSeq = (k: K) => {
- map.changeValue(k, update)
- }
-
- val ser = SparkEnv.get.serializerManager.get(serializerClass, SparkEnv.get.conf)
+ // A list of (rdd iterator, dependency number) pairs
+ val rddIterators = new ArrayBuffer[(Iterator[Product2[K, Any]], Int)]
for ((dep, depNum) <- split.deps.zipWithIndex) dep match {
case NarrowCoGroupSplitDep(rdd, _, itsSplit) => {
// Read them from the parent
- rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, Any]]].foreach { kv =>
- getSeq(kv._1)(depNum) += kv._2
- }
+ val it = rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, Any]]]
+ rddIterators += ((it, depNum))
}
case ShuffleCoGroupSplitDep(shuffleId) => {
// Read map outputs of shuffle
val fetcher = SparkEnv.get.shuffleFetcher
- fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context, ser).foreach {
- kv => getSeq(kv._1)(depNum) += kv._2
+ val ser = SparkEnv.get.serializerManager.get(serializerClass, sparkConf)
+ val it = fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context, ser)
+ rddIterators += ((it, depNum))
+ }
+ }
+
+ if (!externalSorting) {
+ val map = new AppendOnlyMap[K, CoGroupCombiner]
+ val update: (Boolean, CoGroupCombiner) => CoGroupCombiner = (hadVal, oldVal) => {
+ if (hadVal) oldVal else Array.fill(numRdds)(new CoGroup)
+ }
+ val getCombiner: K => CoGroupCombiner = key => {
+ map.changeValue(key, update)
+ }
+ rddIterators.foreach { case (it, depNum) =>
+ while (it.hasNext) {
+ val kv = it.next()
+ getCombiner(kv._1)(depNum) += kv._2
}
}
+ new InterruptibleIterator(context, map.iterator)
+ } else {
+ val map = createExternalMap(numRdds)
+ rddIterators.foreach { case (it, depNum) =>
+ while (it.hasNext) {
+ val kv = it.next()
+ map.insert(kv._1, new CoGroupValue(kv._2, depNum))
+ }
+ }
+ new InterruptibleIterator(context, map.iterator)
+ }
+ }
+
+ private def createExternalMap(numRdds: Int)
+ : ExternalAppendOnlyMap[K, CoGroupValue, CoGroupCombiner] = {
+
+ val createCombiner: (CoGroupValue => CoGroupCombiner) = value => {
+ val newCombiner = Array.fill(numRdds)(new CoGroup)
+ value match { case (v, depNum) => newCombiner(depNum) += v }
+ newCombiner
}
- new InterruptibleIterator(context, map.iterator)
+ val mergeValue: (CoGroupCombiner, CoGroupValue) => CoGroupCombiner =
+ (combiner, value) => {
+ value match { case (v, depNum) => combiner(depNum) += v }
+ combiner
+ }
+ val mergeCombiners: (CoGroupCombiner, CoGroupCombiner) => CoGroupCombiner =
+ (combiner1, combiner2) => {
+ combiner1.zip(combiner2).map { case (v1, v2) => v1 ++ v2 }
+ }
+ new ExternalAppendOnlyMap[K, CoGroupValue, CoGroupCombiner](
+ createCombiner, mergeValue, mergeCombiners)
}
override def clearDependencies() {
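The three functions passed to ExternalAppendOnlyMap are the whole contract: build a combiner from the first (value, dependency number) pair, fold further values into it, and merge two spilled combiners buffer by buffer. A worked sketch using the same type aliases, with invented sample values:

    import scala.collection.mutable.ArrayBuffer

    object CoGroupCombinerSketch {
      type CoGroup = ArrayBuffer[Any]
      type CoGroupValue = (Any, Int)          // Int is the dependency number
      type CoGroupCombiner = Seq[CoGroup]

      def createCombiner(numRdds: Int)(value: CoGroupValue): CoGroupCombiner = {
        val c = Array.fill(numRdds)(new CoGroup)
        c(value._2) += value._1
        c
      }
      def mergeValue(c: CoGroupCombiner, value: CoGroupValue): CoGroupCombiner = {
        c(value._2) += value._1
        c
      }
      def mergeCombiners(c1: CoGroupCombiner, c2: CoGroupCombiner): CoGroupCombiner =
        c1.zip(c2).map { case (b1, b2) => b1 ++ b2 }

      def main(args: Array[String]) {
        val left = createCombiner(2)(("a1", 0))   // value from parent RDD 0
        val both = mergeValue(left, ("b1", 1))    // value from parent RDD 1, same key
        println(both)                             // one buffer per parent: [a1] and [b1]
      }
    }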
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
index 98da35763b..cefcc3d2d9 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
@@ -295,10 +295,10 @@ private[spark] class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanc
val prefPartActual = prefPart.get
- if (minPowerOfTwo.size + slack <= prefPartActual.size) // more imbalance than the slack allows
- return minPowerOfTwo // prefer balance over locality
- else {
- return prefPartActual // prefer locality over balance
+ if (minPowerOfTwo.size + slack <= prefPartActual.size) { // more imbalance than the slack allows
+ minPowerOfTwo // prefer balance over locality
+ } else {
+ prefPartActual // prefer locality over balance
}
}
@@ -331,7 +331,7 @@ private[spark] class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanc
*/
def run(): Array[PartitionGroup] = {
setupGroups(math.min(prev.partitions.length, maxPartitions)) // setup the groups (bins)
- throwBalls() // assign partitions (balls) to each group (bins)
+ throwBalls() // assign partitions (balls) to each group (bins)
getPartitions
}
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index 53f77a38f5..5cdb80be1d 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -19,7 +19,10 @@ package org.apache.spark.rdd
import java.io.EOFException
-import org.apache.hadoop.mapred.FileInputFormat
+import scala.reflect.ClassTag
+
+import org.apache.hadoop.conf.{Configuration, Configurable}
+import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapred.InputFormat
import org.apache.hadoop.mapred.InputSplit
import org.apache.hadoop.mapred.JobConf
@@ -31,7 +34,7 @@ import org.apache.spark._
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.util.NextIterator
-import org.apache.hadoop.conf.{Configuration, Configurable}
+import org.apache.spark.util.Utils.cloneWritables
/**
@@ -42,14 +45,14 @@ private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSp
val inputSplit = new SerializableWritable[InputSplit](s)
- override def hashCode(): Int = (41 * (41 + rddId) + idx).toInt
+ override def hashCode(): Int = 41 * (41 + rddId) + idx
override val index: Int = idx
}
/**
* An RDD that provides core functionality for reading data stored in Hadoop (e.g., files in HDFS,
- * sources in HBase, or S3).
+ * sources in HBase, or S3), using the older MapReduce API (`org.apache.hadoop.mapred`).
*
* @param sc The SparkContext to associate the RDD with.
* @param broadcastedConf A general Hadoop Configuration, or a subclass of it. If the enclosed
@@ -61,15 +64,21 @@ private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSp
* @param keyClass Class of the key associated with the inputFormatClass.
* @param valueClass Class of the value associated with the inputFormatClass.
* @param minSplits Minimum number of Hadoop Splits (HadoopRDD partitions) to generate.
+ * @param cloneRecords If true, Spark will clone the records produced by Hadoop RecordReader.
+ * Most RecordReader implementations reuse wrapper objects across multiple
+ * records, and can cause problems in RDD collect or aggregation operations.
+ * By default the records are cloned in Spark. However, application
+ * programmers can explicitly disable the cloning for better performance.
*/
-class HadoopRDD[K, V](
+class HadoopRDD[K: ClassTag, V: ClassTag](
sc: SparkContext,
broadcastedConf: Broadcast[SerializableWritable[Configuration]],
initLocalJobConfFuncOpt: Option[JobConf => Unit],
inputFormatClass: Class[_ <: InputFormat[K, V]],
keyClass: Class[K],
valueClass: Class[V],
- minSplits: Int)
+ minSplits: Int,
+ cloneRecords: Boolean)
extends RDD[(K, V)](sc, Nil) with Logging {
def this(
@@ -78,7 +87,8 @@ class HadoopRDD[K, V](
inputFormatClass: Class[_ <: InputFormat[K, V]],
keyClass: Class[K],
valueClass: Class[V],
- minSplits: Int) = {
+ minSplits: Int,
+ cloneRecords: Boolean) = {
this(
sc,
sc.broadcast(new SerializableWritable(conf))
@@ -87,7 +97,8 @@ class HadoopRDD[K, V](
inputFormatClass,
keyClass,
valueClass,
- minSplits)
+ minSplits,
+ cloneRecords)
}
protected val jobConfCacheKey = "rdd_%d_job_conf".format(id)
@@ -99,11 +110,11 @@ class HadoopRDD[K, V](
val conf: Configuration = broadcastedConf.value.value
if (conf.isInstanceOf[JobConf]) {
// A user-broadcasted JobConf was provided to the HadoopRDD, so always use it.
- return conf.asInstanceOf[JobConf]
+ conf.asInstanceOf[JobConf]
} else if (HadoopRDD.containsCachedMetadata(jobConfCacheKey)) {
// getJobConf() has been called previously, so there is already a local cache of the JobConf
// needed by this RDD.
- return HadoopRDD.getCachedMetadata(jobConfCacheKey).asInstanceOf[JobConf]
+ HadoopRDD.getCachedMetadata(jobConfCacheKey).asInstanceOf[JobConf]
} else {
// Create a JobConf that will be cached and used across this RDD's getJobConf() calls in the
// local process. The local cache is accessed through HadoopRDD.putCachedMetadata().
@@ -111,7 +122,7 @@ class HadoopRDD[K, V](
val newJobConf = new JobConf(broadcastedConf.value.value)
initLocalJobConfFuncOpt.map(f => f(newJobConf))
HadoopRDD.putCachedMetadata(jobConfCacheKey, newJobConf)
- return newJobConf
+ newJobConf
}
}
@@ -127,7 +138,7 @@ class HadoopRDD[K, V](
newInputFormat.asInstanceOf[Configurable].setConf(conf)
}
HadoopRDD.putCachedMetadata(inputFormatCacheKey, newInputFormat)
- return newInputFormat
+ newInputFormat
}
override def getPartitions: Array[Partition] = {
@@ -158,10 +169,10 @@ class HadoopRDD[K, V](
// Register an on-task-completion callback to close the input stream.
context.addOnCompleteCallback{ () => closeIfNeeded() }
-
val key: K = reader.createKey()
+ val keyCloneFunc = cloneWritables[K](jobConf)
val value: V = reader.createValue()
-
+ val valueCloneFunc = cloneWritables[V](jobConf)
override def getNext() = {
try {
finished = !reader.next(key, value)
@@ -169,7 +180,11 @@ class HadoopRDD[K, V](
case eof: EOFException =>
finished = true
}
- (key, value)
+ if (cloneRecords) {
+ (keyCloneFunc(key.asInstanceOf[Writable]), valueCloneFunc(value.asInstanceOf[Writable]))
+ } else {
+ (key, value)
+ }
}
override def close() {
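The cloneRecords flag exists because Hadoop RecordReaders reuse one Writable instance per key and per value; holding references across next() calls, as collect or aggregation buffers do, silently sees only the last record. A small demonstration with Hadoop's Text, independent of any RDD:

    import scala.collection.mutable.ArrayBuffer
    import org.apache.hadoop.io.Text

    object WritableReuseSketch {
      def main(args: Array[String]) {
        val reused = new Text()                  // stands in for the reader's single buffer
        val noCopy = new ArrayBuffer[Text]
        val withCopy = new ArrayBuffer[Text]
        for (s <- Seq("a", "b", "c")) {
          reused.set(s)
          noCopy += reused                       // keeps three references to the same object
          withCopy += new Text(reused)           // cloneRecords = true behaves like this
        }
        println(noCopy)                          // c, c, c
        println(withCopy)                        // a, b, c
      }
    }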
diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
index 73d15b9082..992bd4aa0a 100644
--- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
@@ -20,11 +20,14 @@ package org.apache.spark.rdd
import java.text.SimpleDateFormat
import java.util.Date
+import scala.reflect.ClassTag
+
import org.apache.hadoop.conf.{Configurable, Configuration}
import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapreduce._
import org.apache.spark.{InterruptibleIterator, Logging, Partition, SerializableWritable, SparkContext, TaskContext}
+import org.apache.spark.util.Utils.cloneWritables
private[spark]
@@ -33,15 +36,31 @@ class NewHadoopPartition(rddId: Int, val index: Int, @transient rawSplit: InputS
val serializableHadoopSplit = new SerializableWritable(rawSplit)
- override def hashCode(): Int = (41 * (41 + rddId) + index)
+ override def hashCode(): Int = 41 * (41 + rddId) + index
}
-class NewHadoopRDD[K, V](
+/**
+ * An RDD that provides core functionality for reading data stored in Hadoop (e.g., files in HDFS,
+ * sources in HBase, or S3), using the new MapReduce API (`org.apache.hadoop.mapreduce`).
+ *
+ * @param sc The SparkContext to associate the RDD with.
+ * @param inputFormatClass Storage format of the data to be read.
+ * @param keyClass Class of the key associated with the inputFormatClass.
+ * @param valueClass Class of the value associated with the inputFormatClass.
+ * @param conf The Hadoop configuration.
+ * @param cloneRecords If true, Spark will clone the records produced by Hadoop RecordReader.
+ * Most RecordReader implementations reuse wrapper objects across multiple
+ * records, and can cause problems in RDD collect or aggregation operations.
+ * By default the records are cloned in Spark. However, application
+ * programmers can explicitly disable the cloning for better performance.
+ */
+class NewHadoopRDD[K: ClassTag, V: ClassTag](
sc : SparkContext,
inputFormatClass: Class[_ <: InputFormat[K, V]],
keyClass: Class[K],
valueClass: Class[V],
- @transient conf: Configuration)
+ @transient conf: Configuration,
+ cloneRecords: Boolean)
extends RDD[(K, V)](sc, Nil)
with SparkHadoopMapReduceUtil
with Logging {
@@ -88,7 +107,8 @@ class NewHadoopRDD[K, V](
// Register an on-task-completion callback to close the input stream.
context.addOnCompleteCallback(() => close())
-
+ val keyCloneFunc = cloneWritables[K](conf)
+ val valueCloneFunc = cloneWritables[V](conf)
var havePair = false
var finished = false
@@ -105,7 +125,13 @@ class NewHadoopRDD[K, V](
throw new java.util.NoSuchElementException("End of stream")
}
havePair = false
- (reader.getCurrentKey, reader.getCurrentValue)
+ val key = reader.getCurrentKey
+ val value = reader.getCurrentValue
+ if (cloneRecords) {
+ (keyCloneFunc(key.asInstanceOf[Writable]), valueCloneFunc(value.asInstanceOf[Writable]))
+ } else {
+ (key, value)
+ }
}
private def close() {
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index 2bf7c5b8d6..f6719ec57c 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -18,35 +18,34 @@
package org.apache.spark.rdd
import java.nio.ByteBuffer
-import java.util.Date
import java.text.SimpleDateFormat
+import java.util.Date
import java.util.{HashMap => JHashMap}
-import scala.collection.{mutable, Map}
+import scala.collection.Map
+import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.collection.JavaConversions._
import scala.reflect.{ClassTag, classTag}
-import org.apache.hadoop.mapred._
-import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.SequenceFile.CompressionType
-import org.apache.hadoop.mapred.FileOutputFormat
-import org.apache.hadoop.mapred.OutputFormat
+import org.apache.hadoop.io.compress.CompressionCodec
+import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf, OutputFormat}
import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat}
-import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat}
-import org.apache.hadoop.mapreduce.SparkHadoopMapReduceUtil
import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob}
import org.apache.hadoop.mapreduce.{RecordWriter => NewRecordWriter}
+import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat}
import com.clearspring.analytics.stream.cardinality.HyperLogLog
+// SparkHadoopWriter and SparkHadoopMapReduceUtil are actually source files defined in Spark.
+import org.apache.hadoop.mapred.SparkHadoopWriter
+import org.apache.hadoop.mapreduce.SparkHadoopMapReduceUtil
import org.apache.spark._
import org.apache.spark.SparkContext._
import org.apache.spark.partial.{BoundedDouble, PartialResult}
-import org.apache.spark.Aggregator
-import org.apache.spark.Partitioner
import org.apache.spark.Partitioner.defaultPartitioner
import org.apache.spark.util.SerializableHyperLogLog
@@ -100,8 +99,6 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
}, preservesPartitioning = true)
} else {
// Don't apply map-side combiner.
- // A sanity check to make sure mergeCombiners is not defined.
- assert(mergeCombiners == null)
val values = new ShuffledRDD[K, V, (K, V)](self, partitioner).setSerializer(serializerClass)
values.mapPartitionsWithContext((context, iter) => {
new InterruptibleIterator(context, aggregator.combineValuesByKey(iter))
@@ -120,9 +117,9 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
}
/**
- * Merge the values for each key using an associative function and a neutral "zero value" which may
- * be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for
- * list concatenation, 0 for addition, or 1 for multiplication.).
+ * Merge the values for each key using an associative function and a neutral "zero value" which
+ * may be added to the result an arbitrary number of times, and must not change the result
+ * (e.g., Nil for list concatenation, 0 for addition, or 1 for multiplication.).
*/
def foldByKey(zeroValue: V, partitioner: Partitioner)(func: (V, V) => V): RDD[(K, V)] = {
// Serialize the zero value to a byte array so that we can get a new clone of it on each key
@@ -138,18 +135,18 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
}
/**
- * Merge the values for each key using an associative function and a neutral "zero value" which may
- * be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for
- * list concatenation, 0 for addition, or 1 for multiplication.).
+ * Merge the values for each key using an associative function and a neutral "zero value" which
+ * may be added to the result an arbitrary number of times, and must not change the result
+ * (e.g., Nil for list concatenation, 0 for addition, or 1 for multiplication.).
*/
def foldByKey(zeroValue: V, numPartitions: Int)(func: (V, V) => V): RDD[(K, V)] = {
foldByKey(zeroValue, new HashPartitioner(numPartitions))(func)
}
/**
- * Merge the values for each key using an associative function and a neutral "zero value" which may
- * be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for
- * list concatenation, 0 for addition, or 1 for multiplication.).
+ * Merge the values for each key using an associative function and a neutral "zero value" which
+ * may be added to the result an arbitrary number of times, and must not change the result
+ * (e.g., Nil for list concatenation, 0 for addition, or 1 for multiplication.).
*/
def foldByKey(zeroValue: V)(func: (V, V) => V): RDD[(K, V)] = {
foldByKey(zeroValue, defaultPartitioner(self))(func)
@@ -226,7 +223,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
}
/**
- * Return approximate number of distinct values for each key in this RDD.
+ * Return approximate number of distinct values for each key in this RDD.
* The accuracy of approximation can be controlled through the relative standard deviation
* (relativeSD) parameter, which also controls the amount of memory used. Lower values result in
    * more accurate counts but increase the memory footprint and vice versa. HashPartitions the
@@ -268,8 +265,9 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
// into a hash table, leading to more objects in the old gen.
def createCombiner(v: V) = ArrayBuffer(v)
def mergeValue(buf: ArrayBuffer[V], v: V) = buf += v
+ def mergeCombiners(c1: ArrayBuffer[V], c2: ArrayBuffer[V]) = c1 ++ c2
val bufs = combineByKey[ArrayBuffer[V]](
- createCombiner _, mergeValue _, null, partitioner, mapSideCombine=false)
+ createCombiner _, mergeValue _, mergeCombiners _, partitioner, mapSideCombine=false)
bufs.asInstanceOf[RDD[(K, Seq[V])]]
}
@@ -340,7 +338,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
* existing partitioner/parallelism level.
*/
def combineByKey[C](createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiners: (C, C) => C)
- : RDD[(K, C)] = {
+ : RDD[(K, C)] = {
combineByKey(createCombiner, mergeValue, mergeCombiners, defaultPartitioner(self))
}
@@ -579,7 +577,8 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
*/
def saveAsHadoopFile[F <: OutputFormat[K, V]](
path: String, codec: Class[_ <: CompressionCodec]) (implicit fm: ClassTag[F]) {
- saveAsHadoopFile(path, getKeyClass, getValueClass, fm.runtimeClass.asInstanceOf[Class[F]], codec)
+ val runtimeClass = fm.runtimeClass
+ saveAsHadoopFile(path, getKeyClass, getValueClass, runtimeClass.asInstanceOf[Class[F]], codec)
}
/**
@@ -599,7 +598,8 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
keyClass: Class[_],
valueClass: Class[_],
outputFormatClass: Class[_ <: NewOutputFormat[_, _]],
- conf: Configuration = self.context.hadoopConfiguration) {
+ conf: Configuration = self.context.hadoopConfiguration)
+ {
val job = new NewAPIHadoopJob(conf)
job.setOutputKeyClass(keyClass)
job.setOutputValueClass(valueClass)
@@ -668,7 +668,9 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
codec: Option[Class[_ <: CompressionCodec]] = None) {
conf.setOutputKeyClass(keyClass)
conf.setOutputValueClass(valueClass)
- // conf.setOutputFormat(outputFormatClass) // Doesn't work in Scala 2.9 due to what may be a generics bug
+ // Doesn't work in Scala 2.9 due to what may be a generics bug
+ // TODO: Should we uncomment this for Scala 2.10?
+ // conf.setOutputFormat(outputFormatClass)
conf.set("mapred.output.format.class", outputFormatClass.getName)
for (c <- codec) {
conf.setCompressMapOutput(true)
@@ -702,7 +704,8 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
throw new SparkException("Output value class not set")
}
- logInfo("Saving as hadoop file of type (" + keyClass.getSimpleName+ ", " + valueClass.getSimpleName+ ")")
+    logDebug("Saving as hadoop file of type (" + keyClass.getSimpleName + ", " +
+      valueClass.getSimpleName + ")")
val writer = new SparkHadoopWriter(conf)
writer.preSetup()
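With external sorting, spilled combiners may have to be merged, so groupByKey now supplies a real mergeCombiners (c1 ++ c2) instead of null. The same three-function shape can be exercised directly through combineByKey on a toy pair RDD; the sketch below assumes a local master:

    import scala.collection.mutable.ArrayBuffer
    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.SparkContext._   // pair RDD functions via implicit conversion

    object CombineByKeySketch {
      def main(args: Array[String]) {
        val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("sketch"))
        val pairs = sc.parallelize(Seq(("a", 1), ("a", 2), ("b", 3)))
        val grouped = pairs.combineByKey[ArrayBuffer[Int]](
          (v: Int) => ArrayBuffer(v),                               // createCombiner
          (buf: ArrayBuffer[Int], v: Int) => buf += v,              // mergeValue
          (b1: ArrayBuffer[Int], b2: ArrayBuffer[Int]) => b1 ++ b2) // mergeCombiners
        grouped.collect().foreach(println)   // e.g. (a,ArrayBuffer(1, 2)) and (b,ArrayBuffer(3))
        sc.stop()
      }
    }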
diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
index 1dbbe39898..d4f396afb5 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
@@ -96,7 +96,7 @@ class PipedRDD[T: ClassTag](
// Return an iterator that read lines from the process's stdout
val lines = Source.fromInputStream(proc.getInputStream).getLines
- return new Iterator[String] {
+ new Iterator[String] {
def next() = lines.next()
def hasNext = {
if (lines.hasNext) {
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 2142ae730e..cd90a1561a 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -23,7 +23,6 @@ import scala.collection.Map
import scala.collection.JavaConversions.mapAsScalaMap
import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
import scala.reflect.{classTag, ClassTag}
import org.apache.hadoop.io.BytesWritable
@@ -52,11 +51,13 @@ import org.apache.spark._
* partitioned collection of elements that can be operated on in parallel. This class contains the
* basic operations available on all RDDs, such as `map`, `filter`, and `persist`. In addition,
* [[org.apache.spark.rdd.PairRDDFunctions]] contains operations available only on RDDs of key-value
- * pairs, such as `groupByKey` and `join`; [[org.apache.spark.rdd.DoubleRDDFunctions]] contains
- * operations available only on RDDs of Doubles; and [[org.apache.spark.rdd.SequenceFileRDDFunctions]]
- * contains operations available on RDDs that can be saved as SequenceFiles. These operations are
- * automatically available on any RDD of the right type (e.g. RDD[(Int, Int)] through implicit
- * conversions when you `import org.apache.spark.SparkContext._`.
+ * pairs, such as `groupByKey` and `join`;
+ * [[org.apache.spark.rdd.DoubleRDDFunctions]] contains operations available only on RDDs of
+ * Doubles; and
+ * [[org.apache.spark.rdd.SequenceFileRDDFunctions]] contains operations available on RDDs that
+ * can be saved as SequenceFiles.
+ * These operations are automatically available on any RDD of the right type (e.g. RDD[(Int, Int)]
+ * through implicit conversions when you `import org.apache.spark.SparkContext._`.
*
* Internally, each RDD is characterized by five main properties:
*
@@ -235,12 +236,9 @@ abstract class RDD[T: ClassTag](
/**
* Compute an RDD partition or read it from a checkpoint if the RDD is checkpointing.
*/
- private[spark] def computeOrReadCheckpoint(split: Partition, context: TaskContext): Iterator[T] = {
- if (isCheckpointed) {
- firstParent[T].iterator(split, context)
- } else {
- compute(split, context)
- }
+ private[spark] def computeOrReadCheckpoint(split: Partition, context: TaskContext): Iterator[T] =
+ {
+ if (isCheckpointed) firstParent[T].iterator(split, context) else compute(split, context)
}
// Transformations (return a new RDD)
@@ -268,6 +266,9 @@ abstract class RDD[T: ClassTag](
def distinct(numPartitions: Int): RDD[T] =
map(x => (x, null)).reduceByKey((x, y) => x, numPartitions).map(_._1)
+ /**
+ * Return a new RDD containing the distinct elements in this RDD.
+ */
def distinct(): RDD[T] = distinct(partitions.size)
/**
@@ -280,7 +281,7 @@ abstract class RDD[T: ClassTag](
* which can avoid performing a shuffle.
*/
def repartition(numPartitions: Int): RDD[T] = {
- coalesce(numPartitions, true)
+ coalesce(numPartitions, shuffle = true)
}
/**
@@ -651,7 +652,8 @@ abstract class RDD[T: ClassTag](
}
/**
- * Reduces the elements of this RDD using the specified commutative and associative binary operator.
+ * Reduces the elements of this RDD using the specified commutative and
+ * associative binary operator.
*/
def reduce(f: (T, T) => T): T = {
val cleanF = sc.clean(f)
@@ -767,7 +769,7 @@ abstract class RDD[T: ClassTag](
val entry = iter.next()
m1.put(entry.getKey, m1.getLong(entry.getKey) + entry.getLongValue)
}
- return m1
+ m1
}
val myResult = mapPartitions(countPartition).reduce(mergeMaps)
myResult.asInstanceOf[java.util.Map[T, Long]] // Will be wrapped as a Scala mutable Map
@@ -845,7 +847,7 @@ abstract class RDD[T: ClassTag](
partsScanned += numPartsToTry
}
- return buf.toArray
+ buf.toArray
}
/**
@@ -958,7 +960,7 @@ abstract class RDD[T: ClassTag](
private var storageLevel: StorageLevel = StorageLevel.NONE
/** Record user function generating this RDD. */
- @transient private[spark] val origin = sc.getCallSite
+ @transient private[spark] val origin = sc.getCallSite()
private[spark] def elementClassTag: ClassTag[T] = classTag[T]
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 043e01dbfb..7046c06d20 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -106,7 +106,7 @@ class DAGScheduler(
// The time, in millis, to wait for fetch failure events to stop coming in after one is detected;
// this is a simplistic way to avoid resubmitting tasks in the non-fetchable map stage one by one
// as more failure events come in
- val RESUBMIT_TIMEOUT = 50.milliseconds
+ val RESUBMIT_TIMEOUT = 200.milliseconds
// The time, in millis, to wake up between polls of the completion queue in order to potentially
// resubmit failed stages
@@ -133,7 +133,8 @@ class DAGScheduler(
private[spark] val stageToInfos = new TimeStampedHashMap[Stage, StageInfo]
- private[spark] val listenerBus = new SparkListenerBus()
+ // An async scheduler event bus. The bus should be stopped when DAGScheduler is stopped.
+ private[spark] val listenerBus = new SparkListenerBus
// Contains the locations that each RDD's partitions are cached on
private val cacheLocs = new HashMap[Int, Array[Seq[TaskLocation]]]
@@ -196,7 +197,7 @@ class DAGScheduler(
*/
def receive = {
case event: DAGSchedulerEvent =>
- logDebug("Got event of type " + event.getClass.getName)
+ logTrace("Got event of type " + event.getClass.getName)
/**
* All events are forwarded to `processEvent()`, so that the event processing logic can
@@ -1121,5 +1122,6 @@ class DAGScheduler(
}
metadataCleaner.cancel()
taskSched.stop()
+ listenerBus.stop()
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala
index 90eb8a747f..cc10cc0849 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala
@@ -103,7 +103,7 @@ class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Cl
retval ++= SplitInfo.toSplitInfo(inputFormatClazz, path, split)
}
- return retval.toSet
+ retval.toSet
}
// This method does not expect failures, since validate has already passed ...
@@ -121,18 +121,18 @@ class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Cl
elem => retval ++= SplitInfo.toSplitInfo(inputFormatClazz, path, elem)
)
- return retval.toSet
+ retval.toSet
}
private def findPreferredLocations(): Set[SplitInfo] = {
logDebug("mapreduceInputFormat : " + mapreduceInputFormat + ", mapredInputFormat : " + mapredInputFormat +
", inputFormatClazz : " + inputFormatClazz)
if (mapreduceInputFormat) {
- return prefLocsFromMapreduceInputFormat()
+ prefLocsFromMapreduceInputFormat()
}
else {
assert(mapredInputFormat)
- return prefLocsFromMapredInputFormat()
+ prefLocsFromMapredInputFormat()
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala
index 1791242215..4bc13c23d9 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala
@@ -75,12 +75,12 @@ private[spark] class Pool(
return schedulableNameToSchedulable(schedulableName)
}
for (schedulable <- schedulableQueue) {
- var sched = schedulable.getSchedulableByName(schedulableName)
+ val sched = schedulable.getSchedulableByName(schedulableName)
if (sched != null) {
return sched
}
}
- return null
+ null
}
override def executorLost(executorId: String, host: String) {
@@ -92,7 +92,7 @@ private[spark] class Pool(
for (schedulable <- schedulableQueue) {
shouldRevive |= schedulable.checkSpeculatableTasks()
}
- return shouldRevive
+ shouldRevive
}
override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = {
@@ -101,7 +101,7 @@ private[spark] class Pool(
for (schedulable <- sortedSchedulableQueue) {
sortedTaskSetQueue ++= schedulable.getSortedTaskSetQueue()
}
- return sortedTaskSetQueue
+ sortedTaskSetQueue
}
def increaseRunningTasks(taskNum: Int) {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala
index 3418640b8c..5e62c8468f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala
@@ -37,9 +37,9 @@ private[spark] class FIFOSchedulingAlgorithm extends SchedulingAlgorithm {
res = math.signum(stageId1 - stageId2)
}
if (res < 0) {
- return true
+ true
} else {
- return false
+ false
}
}
}
@@ -56,7 +56,6 @@ private[spark] class FairSchedulingAlgorithm extends SchedulingAlgorithm {
val minShareRatio2 = runningTasks2.toDouble / math.max(minShare2, 1.0).toDouble
val taskToWeightRatio1 = runningTasks1.toDouble / s1.weight.toDouble
val taskToWeightRatio2 = runningTasks2.toDouble / s2.weight.toDouble
- var res:Boolean = true
var compare:Int = 0
if (s1Needy && !s2Needy) {
@@ -70,11 +69,11 @@ private[spark] class FairSchedulingAlgorithm extends SchedulingAlgorithm {
}
if (compare < 0) {
- return true
+ true
} else if (compare > 0) {
- return false
+ false
} else {
- return s1.name < s2.name
+ s1.name < s2.name
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
index 627995c826..55a40a92c9 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
@@ -43,6 +43,9 @@ case class SparkListenerJobStart(job: ActiveJob, stageIds: Array[Int], propertie
case class SparkListenerJobEnd(job: ActiveJob, jobResult: JobResult)
extends SparkListenerEvents
+/** An event used in the listener bus to shut down the listener daemon thread. */
+private[scheduler] case object SparkListenerShutdown extends SparkListenerEvents
+
trait SparkListener {
/**
* Called when a stage is completed, with information on the completed stage
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
index e7defd768b..17b1328b86 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
@@ -24,15 +24,17 @@ import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
import org.apache.spark.Logging
/** Asynchronously passes SparkListenerEvents to registered SparkListeners. */
-private[spark] class SparkListenerBus() extends Logging {
- private val sparkListeners = new ArrayBuffer[SparkListener]() with SynchronizedBuffer[SparkListener]
+private[spark] class SparkListenerBus extends Logging {
+ private val sparkListeners = new ArrayBuffer[SparkListener] with SynchronizedBuffer[SparkListener]
/* Cap the capacity of the SparkListenerEvent queue so we get an explicit error (rather than
* an OOM exception) if it's perpetually being added to more quickly than it's being drained. */
- private val EVENT_QUEUE_CAPACITY = 10000
+ private val EVENT_QUEUE_CAPACITY = 10000
private val eventQueue = new LinkedBlockingQueue[SparkListenerEvents](EVENT_QUEUE_CAPACITY)
private var queueFullErrorMessageLogged = false
+ // Create a new daemon thread to listen for events. This thread is stopped when it receives
+ // a SparkListenerShutdown event, which is posted by the stop() method.
new Thread("SparkListenerBus") {
setDaemon(true)
override def run() {
@@ -53,6 +55,9 @@ private[spark] class SparkListenerBus() extends Logging {
sparkListeners.foreach(_.onTaskGettingResult(taskGettingResult))
case taskEnd: SparkListenerTaskEnd =>
sparkListeners.foreach(_.onTaskEnd(taskEnd))
+ case SparkListenerShutdown =>
+ // Get out of the while loop and shut down the daemon thread
+ return
case _ =>
}
}
@@ -80,7 +85,7 @@ private[spark] class SparkListenerBus() extends Logging {
*/
def waitUntilEmpty(timeoutMillis: Int): Boolean = {
val finishTime = System.currentTimeMillis + timeoutMillis
- while (!eventQueue.isEmpty()) {
+ while (!eventQueue.isEmpty) {
if (System.currentTimeMillis > finishTime) {
return false
}
@@ -88,6 +93,8 @@ private[spark] class SparkListenerBus() extends Logging {
* add overhead in the general case. */
Thread.sleep(10)
}
- return true
+ true
}
+
+ def stop(): Unit = post(SparkListenerShutdown)
}
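stop() above posts a SparkListenerShutdown sentinel, and the daemon thread returns as soon as it dequeues it: a classic poison-pill shutdown for a queue-draining thread. A standalone sketch of the same pattern (not the Spark class, just the idea; all names are made up):

    import java.util.concurrent.LinkedBlockingQueue

    object PoisonPillSketch {
      sealed trait Event
      case class Message(body: String) extends Event
      case object Shutdown extends Event            // the sentinel, like SparkListenerShutdown

      private val queue = new LinkedBlockingQueue[Event](10000)

      private val drainer = new Thread("drainer") {
        setDaemon(true)
        override def run() {
          while (true) {
            queue.take() match {
              case Message(body) => println("handled " + body)
              case Shutdown      => return          // leave the loop, letting the daemon thread die
            }
          }
        }
      }
      drainer.start()

      def post(event: Event) { queue.put(event) }
      def stop() { post(Shutdown) }

      def main(args: Array[String]) {
        post(Message("stage completed"))
        stop()
        drainer.join(1000)                          // wait briefly for the drainer to exit
      }
    }

Because the queue is FIFO, the sentinel keeps the shutdown ordered: every event posted before stop() is still drained before the thread exits.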
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
index 7cb3fe46e5..c60e9896de 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -96,7 +96,7 @@ private[spark] class Stage(
def newAttemptId(): Int = {
val id = nextAttemptId
nextAttemptId += 1
- return id
+ id
}
val name = callSite.getOrElse(rdd.origin)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
index e80cc6b0f6..9d3e615826 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
@@ -74,6 +74,6 @@ class DirectTaskResult[T](var valueBytes: ByteBuffer, var accumUpdates: Map[Long
def value(): T = {
val resultSer = SparkEnv.get.serializer.newInstance()
- return resultSer.deserialize(valueBytes)
+ resultSer.deserialize(valueBytes)
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
index e22b1e53e8..35e9544718 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
@@ -31,13 +31,13 @@ import org.apache.spark.util.Utils
private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedulerImpl)
extends Logging {
- private val THREADS = sparkEnv.conf.get("spark.resultGetter.threads", "4").toInt
+ private val THREADS = sparkEnv.conf.getInt("spark.resultGetter.threads", 4)
private val getTaskResultExecutor = Utils.newDaemonFixedThreadPool(
THREADS, "Result resolver thread")
protected val serializer = new ThreadLocal[SerializerInstance] {
override def initialValue(): SerializerInstance = {
- return sparkEnv.closureSerializer.newInstance()
+ sparkEnv.closureSerializer.newInstance()
}
}
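Many of the hunks in this commit swap conf.get(key, "default").toX for the typed SparkConf accessors (getInt, getLong, getDouble, getBoolean). A small before/after sketch of the pattern (key and values are illustrative):

    import org.apache.spark.SparkConf

    object TypedConfSketch {
      def main(args: Array[String]) {
        val conf = new SparkConf().set("spark.task.maxFailures", "8")

        // Old pattern, removed throughout this patch: parse a stringly-typed default by hand.
        val oldStyle = conf.get("spark.task.maxFailures", "4").toInt

        // New pattern: the default is a real Int and the parsing lives inside SparkConf.
        val newStyle = conf.getInt("spark.task.maxFailures", 4)

        println(oldStyle + " " + newStyle)   // 8 8
      }
    }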
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index 0c8ed62759..d4f74d3e18 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -51,15 +51,15 @@ private[spark] class TaskSchedulerImpl(
isLocal: Boolean = false)
extends TaskScheduler with Logging
{
- def this(sc: SparkContext) = this(sc, sc.conf.get("spark.task.maxFailures", "4").toInt)
+ def this(sc: SparkContext) = this(sc, sc.conf.getInt("spark.task.maxFailures", 4))
val conf = sc.conf
// How often to check for speculative tasks
- val SPECULATION_INTERVAL = conf.get("spark.speculation.interval", "100").toLong
+ val SPECULATION_INTERVAL = conf.getLong("spark.speculation.interval", 100)
// Threshold above which we warn user initial TaskSet may be starved
- val STARVATION_TIMEOUT = conf.get("spark.starvation.timeout", "15000").toLong
+ val STARVATION_TIMEOUT = conf.getLong("spark.starvation.timeout", 15000)
// TaskSetManagers are not thread safe, so any access to one should be synchronized
// on this class.
@@ -125,7 +125,7 @@ private[spark] class TaskSchedulerImpl(
override def start() {
backend.start()
- if (!isLocal && conf.get("spark.speculation", "false").toBoolean) {
+ if (!isLocal && conf.getBoolean("spark.speculation", false)) {
logInfo("Starting speculative execution thread")
import sc.env.actorSystem.dispatcher
sc.env.actorSystem.scheduler.schedule(SPECULATION_INTERVAL milliseconds,
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index 6dd1469d8f..fc0ee07089 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -57,11 +57,11 @@ private[spark] class TaskSetManager(
val conf = sched.sc.conf
// CPUs to request per task
- val CPUS_PER_TASK = conf.get("spark.task.cpus", "1").toInt
+ val CPUS_PER_TASK = conf.getInt("spark.task.cpus", 1)
// Quantile of tasks at which to start speculation
- val SPECULATION_QUANTILE = conf.get("spark.speculation.quantile", "0.75").toDouble
- val SPECULATION_MULTIPLIER = conf.get("spark.speculation.multiplier", "1.5").toDouble
+ val SPECULATION_QUANTILE = conf.getDouble("spark.speculation.quantile", 0.75)
+ val SPECULATION_MULTIPLIER = conf.getDouble("spark.speculation.multiplier", 1.5)
// Serializer for closures and tasks.
val env = SparkEnv.get
@@ -116,7 +116,7 @@ private[spark] class TaskSetManager(
// How frequently to reprint duplicate exceptions in full, in milliseconds
val EXCEPTION_PRINT_INTERVAL =
- conf.get("spark.logging.exceptionPrintInterval", "10000").toLong
+ conf.getLong("spark.logging.exceptionPrintInterval", 10000)
// Map of recent exceptions (identified by string representation and top stack frame) to
// duplicate count (how many times the same exception has appeared) and time the full exception
@@ -228,7 +228,7 @@ private[spark] class TaskSetManager(
return Some(index)
}
}
- return None
+ None
}
/** Check whether a task is currently running an attempt on a given host */
@@ -291,7 +291,7 @@ private[spark] class TaskSetManager(
}
}
- return None
+ None
}
/**
@@ -332,7 +332,7 @@ private[spark] class TaskSetManager(
}
// Finally, if all else has failed, find a speculative task
- return findSpeculativeTask(execId, host, locality)
+ findSpeculativeTask(execId, host, locality)
}
/**
@@ -387,7 +387,7 @@ private[spark] class TaskSetManager(
case _ =>
}
}
- return None
+ None
}
/**
@@ -584,7 +584,7 @@ private[spark] class TaskSetManager(
}
override def getSchedulableByName(name: String): Schedulable = {
- return null
+ null
}
override def addSchedulable(schedulable: Schedulable) {}
@@ -594,7 +594,7 @@ private[spark] class TaskSetManager(
override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = {
var sortedTaskSetQueue = ArrayBuffer[TaskSetManager](this)
sortedTaskSetQueue += this
- return sortedTaskSetQueue
+ sortedTaskSetQueue
}
/** Called by TaskScheduler when an executor is lost so we can re-enqueue our tasks */
@@ -669,7 +669,7 @@ private[spark] class TaskSetManager(
}
}
}
- return foundTasks
+ foundTasks
}
private def getLocalityWait(level: TaskLocality.TaskLocality): Long = {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index 2f5bcafe40..0208388e86 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -63,7 +63,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A
context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
// Periodically revive offers to allow delay scheduling to work
- val reviveInterval = conf.get("spark.scheduler.revive.interval", "1000").toLong
+ val reviveInterval = conf.getLong("spark.scheduler.revive.interval", 1000)
import context.dispatcher
context.system.scheduler.schedule(0.millis, reviveInterval.millis, self, ReviveOffers)
}
@@ -165,7 +165,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A
override def start() {
val properties = new ArrayBuffer[(String, String)]
for ((key, value) <- scheduler.sc.conf.getAll) {
- if (key.startsWith("spark.") && !key.equals("spark.hostPort")) {
+ if (key.startsWith("spark.")) {
properties += ((key, value))
}
}
@@ -209,8 +209,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A
}
override def defaultParallelism(): Int = {
- conf.getOption("spark.default.parallelism").map(_.toInt).getOrElse(
- math.max(totalCoreCount.get(), 2))
+ conf.getInt("spark.default.parallelism", math.max(totalCoreCount.get(), 2))
}
// Called by subclasses when notified of a lost worker
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
index b44d1e43c8..d99c76117c 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
@@ -33,7 +33,7 @@ private[spark] class SimrSchedulerBackend(
val tmpPath = new Path(driverFilePath + "_tmp")
val filePath = new Path(driverFilePath)
- val maxCores = conf.get("spark.simr.executor.cores", "1").toInt
+ val maxCores = conf.getInt("spark.simr.executor.cores", 1)
override def start() {
super.start()
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
index 73fc37444e..faa6e1ebe8 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
@@ -20,7 +20,7 @@ package org.apache.spark.scheduler.cluster
import scala.collection.mutable.HashMap
import org.apache.spark.{Logging, SparkContext}
-import org.apache.spark.deploy.client.{Client, ClientListener}
+import org.apache.spark.deploy.client.{AppClient, AppClientListener}
import org.apache.spark.deploy.{Command, ApplicationDescription}
import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason, SlaveLost, TaskSchedulerImpl}
import org.apache.spark.util.Utils
@@ -31,10 +31,10 @@ private[spark] class SparkDeploySchedulerBackend(
masters: Array[String],
appName: String)
extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem)
- with ClientListener
+ with AppClientListener
with Logging {
- var client: Client = null
+ var client: AppClient = null
var stopping = false
var shutdownCallback : (SparkDeploySchedulerBackend) => Unit = _
@@ -47,14 +47,14 @@ private[spark] class SparkDeploySchedulerBackend(
val driverUrl = "akka.tcp://spark@%s:%s/user/%s".format(
conf.get("spark.driver.host"), conf.get("spark.driver.port"),
CoarseGrainedSchedulerBackend.ACTOR_NAME)
- val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}")
+ val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}", "{{WORKER_URL}}")
val command = Command(
"org.apache.spark.executor.CoarseGrainedExecutorBackend", args, sc.executorEnvs)
val sparkHome = sc.getSparkHome().getOrElse(null)
val appDesc = new ApplicationDescription(appName, maxCores, sc.executorMemory, command, sparkHome,
"http://" + sc.ui.appUIAddress)
- client = new Client(sc.env.actorSystem, masters, appDesc, this, conf)
+ client = new AppClient(sc.env.actorSystem, masters, appDesc, this, conf)
client.start()
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
index d46fceba89..c27049bdb5 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
@@ -77,7 +77,7 @@ private[spark] class CoarseMesosSchedulerBackend(
"Spark home is not set; set it through the spark.home system " +
"property, the SPARK_HOME environment variable or the SparkContext constructor"))
- val extraCoresPerSlave = conf.get("spark.mesos.extra.cores", "0").toInt
+ val extraCoresPerSlave = conf.getInt("spark.mesos.extra.cores", 0)
var nextMesosTaskId = 0
@@ -140,7 +140,7 @@ private[spark] class CoarseMesosSchedulerBackend(
.format(basename, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores))
command.addUris(CommandInfo.URI.newBuilder().setValue(uri))
}
- return command.build()
+ command.build()
}
override def offerRescinded(d: SchedulerDriver, o: OfferID) {}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
index ae8d527352..49781485d9 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
@@ -141,13 +141,13 @@ private[spark] class MesosSchedulerBackend(
// Serialize the map as an array of (String, String) pairs
execArgs = Utils.serialize(props.toArray)
}
- return execArgs
+ execArgs
}
private def setClassLoader(): ClassLoader = {
val oldClassLoader = Thread.currentThread.getContextClassLoader
Thread.currentThread.setContextClassLoader(classLoader)
- return oldClassLoader
+ oldClassLoader
}
private def restoreClassLoader(oldClassLoader: ClassLoader) {
@@ -255,7 +255,7 @@ private[spark] class MesosSchedulerBackend(
.setType(Value.Type.SCALAR)
.setScalar(Value.Scalar.newBuilder().setValue(1).build())
.build()
- return MesosTaskInfo.newBuilder()
+ MesosTaskInfo.newBuilder()
.setTaskId(taskId)
.setSlaveId(SlaveID.newBuilder().setValue(slaveId).build())
.setExecutor(createExecutorInfo(slaveId))
@@ -340,5 +340,5 @@ private[spark] class MesosSchedulerBackend(
}
// TODO: query Mesos for number of cores
- override def defaultParallelism() = sc.conf.get("spark.default.parallelism", "8").toInt
+ override def defaultParallelism() = sc.conf.getInt("spark.default.parallelism", 8)
}
diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
index a24a3b04b8..c14cd47556 100644
--- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -36,7 +36,7 @@ import org.apache.spark.storage.{GetBlock, GotBlock, PutBlock}
*/
class KryoSerializer(conf: SparkConf) extends org.apache.spark.serializer.Serializer with Logging {
private val bufferSize = {
- conf.get("spark.kryoserializer.buffer.mb", "2").toInt * 1024 * 1024
+ conf.getInt("spark.kryoserializer.buffer.mb", 2) * 1024 * 1024
}
def newKryoOutput() = new KryoOutput(bufferSize)
@@ -48,7 +48,7 @@ class KryoSerializer(conf: SparkConf) extends org.apache.spark.serializer.Serial
// Allow disabling Kryo reference tracking if user knows their object graphs don't have loops.
// Do this before we invoke the user registrator so the user registrator can override this.
- kryo.setReferences(conf.get("spark.kryo.referenceTracking", "true").toBoolean)
+ kryo.setReferences(conf.getBoolean("spark.kryo.referenceTracking", true))
for (cls <- KryoSerializer.toRegister) kryo.register(cls)
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
index 47478631a1..4fa2ab96d9 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
@@ -327,7 +327,7 @@ object BlockFetcherIterator {
fetchRequestsSync.put(request)
}
- copiers = startCopiers(conf.get("spark.shuffle.copier.threads", "6").toInt)
+ copiers = startCopiers(conf.getInt("spark.shuffle.copier.threads", 6))
logInfo("Started " + fetchRequestsSync.size + " remote gets in " +
Utils.getUsedTimeMs(startTime))
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
index 7156d855d8..301d784b35 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
@@ -17,12 +17,14 @@
package org.apache.spark.storage
+import java.util.UUID
+
/**
* Identifies a particular Block of data, usually associated with a single file.
* A Block can be uniquely identified by its filename, but each type of Block has a different
* set of keys which produce its unique name.
*
- * If your BlockId should be serializable, be sure to add it to the BlockId.fromString() method.
+ * If your BlockId should be serializable, be sure to add it to the BlockId.apply() method.
*/
private[spark] sealed abstract class BlockId {
/** A globally unique identifier for this Block. Can be used for ser/de. */
@@ -55,7 +57,8 @@ private[spark] case class BroadcastBlockId(broadcastId: Long) extends BlockId {
def name = "broadcast_" + broadcastId
}
-private[spark] case class BroadcastHelperBlockId(broadcastId: BroadcastBlockId, hType: String) extends BlockId {
+private[spark]
+case class BroadcastHelperBlockId(broadcastId: BroadcastBlockId, hType: String) extends BlockId {
def name = broadcastId.name + "_" + hType
}
@@ -67,6 +70,11 @@ private[spark] case class StreamBlockId(streamId: Int, uniqueId: Long) extends B
def name = "input-" + streamId + "-" + uniqueId
}
+/** Id associated with temporary data managed as blocks. Not serializable. */
+private[spark] case class TempBlockId(id: UUID) extends BlockId {
+ def name = "temp_" + id
+}
+
// Intended only for testing purposes
private[spark] case class TestBlockId(id: String) extends BlockId {
def name = "test_" + id
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 6d2cda97b0..6f1345c57a 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -58,8 +58,8 @@ private[spark] class BlockManager(
// If we use Netty for shuffle, start a new Netty-based shuffle sender service.
private val nettyPort: Int = {
- val useNetty = conf.get("spark.shuffle.use.netty", "false").toBoolean
- val nettyPortConfig = conf.get("spark.shuffle.sender.port", "0").toInt
+ val useNetty = conf.getBoolean("spark.shuffle.use.netty", false)
+ val nettyPortConfig = conf.getInt("spark.shuffle.sender.port", 0)
if (useNetty) diskBlockManager.startShuffleBlockSender(nettyPortConfig) else 0
}
@@ -72,19 +72,17 @@ private[spark] class BlockManager(
// Max megabytes of data to keep in flight per reducer (to avoid over-allocating memory
// for receiving shuffle outputs)
val maxBytesInFlight =
- conf.get("spark.reducer.maxMbInFlight", "48").toLong * 1024 * 1024
+ conf.getLong("spark.reducer.maxMbInFlight", 48) * 1024 * 1024
// Whether to compress broadcast variables that are stored
- val compressBroadcast = conf.get("spark.broadcast.compress", "true").toBoolean
+ val compressBroadcast = conf.getBoolean("spark.broadcast.compress", true)
// Whether to compress shuffle output that are stored
- val compressShuffle = conf.get("spark.shuffle.compress", "true").toBoolean
+ val compressShuffle = conf.getBoolean("spark.shuffle.compress", true)
// Whether to compress RDD partitions that are stored serialized
- val compressRdds = conf.get("spark.rdd.compress", "false").toBoolean
+ val compressRdds = conf.getBoolean("spark.rdd.compress", false)
val heartBeatFrequency = BlockManager.getHeartBeatFrequency(conf)
- val hostPort = Utils.localHostPort(conf)
-
val slaveActor = actorSystem.actorOf(Props(new BlockManagerSlaveActor(this)),
name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next)
@@ -159,7 +157,7 @@ private[spark] class BlockManager(
/**
* Reregister with the master and report all blocks to it. This will be called by the heart beat
- * thread if our heartbeat to the block amnager indicates that we were not registered.
+ * thread if our heartbeat to the block manager indicates that we were not registered.
*
* Note that this method must be called without any BlockInfo locks held.
*/
@@ -412,7 +410,7 @@ private[spark] class BlockManager(
logDebug("The value of block " + blockId + " is null")
}
logDebug("Block " + blockId + " not found")
- return None
+ None
}
/**
@@ -443,7 +441,7 @@ private[spark] class BlockManager(
: BlockFetcherIterator = {
val iter =
- if (conf.get("spark.shuffle.use.netty", "false").toBoolean) {
+ if (conf.getBoolean("spark.shuffle.use.netty", false)) {
new BlockFetcherIterator.NettyBlockFetcherIterator(this, blocksByAddress, serializer)
} else {
new BlockFetcherIterator.BasicBlockFetcherIterator(this, blocksByAddress, serializer)
@@ -469,7 +467,7 @@ private[spark] class BlockManager(
def getDiskWriter(blockId: BlockId, file: File, serializer: Serializer, bufferSize: Int)
: BlockObjectWriter = {
val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _)
- val syncWrites = conf.get("spark.shuffle.sync", "false").toBoolean
+ val syncWrites = conf.getBoolean("spark.shuffle.sync", false)
new DiskBlockObjectWriter(blockId, file, serializer, bufferSize, compressStream, syncWrites)
}
@@ -864,15 +862,15 @@ private[spark] object BlockManager extends Logging {
val ID_GENERATOR = new IdGenerator
def getMaxMemory(conf: SparkConf): Long = {
- val memoryFraction = conf.get("spark.storage.memoryFraction", "0.66").toDouble
+ val memoryFraction = conf.getDouble("spark.storage.memoryFraction", 0.6)
(Runtime.getRuntime.maxMemory * memoryFraction).toLong
}
def getHeartBeatFrequency(conf: SparkConf): Long =
- conf.get("spark.storage.blockManagerTimeoutIntervalMs", "60000").toLong / 4
+ conf.getLong("spark.storage.blockManagerTimeoutIntervalMs", 60000) / 4
def getDisableHeartBeatsForTesting(conf: SparkConf): Boolean =
- conf.get("spark.test.disableBlockManagerHeartBeat", "false").toBoolean
+ conf.getBoolean("spark.test.disableBlockManagerHeartBeat", false)
/**
* Attempt to clean up a ByteBuffer if it is memory-mapped. This uses an *unsafe* Sun API that
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
index 51a29ed8ef..c54e4f2664 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
@@ -30,8 +30,8 @@ import org.apache.spark.util.AkkaUtils
private[spark]
class BlockManagerMaster(var driverActor : ActorRef, conf: SparkConf) extends Logging {
- val AKKA_RETRY_ATTEMPTS: Int = conf.get("spark.akka.num.retries", "3").toInt
- val AKKA_RETRY_INTERVAL_MS: Int = conf.get("spark.akka.retry.wait", "3000").toInt
+ val AKKA_RETRY_ATTEMPTS: Int = conf.getInt("spark.akka.num.retries", 3)
+ val AKKA_RETRY_INTERVAL_MS: Int = conf.getInt("spark.akka.retry.wait", 3000)
val DRIVER_AKKA_ACTOR_NAME = "BlockManagerMaster"
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
index 58452d9657..2c1a4e2f5d 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
@@ -348,14 +348,19 @@ object BlockManagerMasterActor {
if (storageLevel.isValid) {
// isValid means it is either stored in-memory or on-disk.
- _blocks.put(blockId, BlockStatus(storageLevel, memSize, diskSize))
+ // But the memSize here indicates the data size in or dropped from memory,
+ // and the diskSize here indicates the data size in or dropped to disk.
+ // They can both be larger than 0 when a block is dropped from memory to disk.
+ // Therefore, the safe way to set BlockStatus is to record each size only under its own storage mode.
if (storageLevel.useMemory) {
+ _blocks.put(blockId, BlockStatus(storageLevel, memSize, 0))
_remainingMem -= memSize
logInfo("Added %s in memory on %s (size: %s, free: %s)".format(
blockId, blockManagerId.hostPort, Utils.bytesToString(memSize),
Utils.bytesToString(_remainingMem)))
}
if (storageLevel.useDisk) {
+ _blocks.put(blockId, BlockStatus(storageLevel, 0, diskSize))
logInfo("Added %s on disk on %s (size: %s)".format(
blockId, blockManagerId.hostPort, Utils.bytesToString(diskSize)))
}
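The comment block above is about double counting: when a block is dropped from memory to disk, the update can carry both a memory size and a disk size, so the master now records each size only under the storage mode that actually holds the block. A rough, self-contained illustration of the new bookkeeping with simplified stand-in types (the real ones are private to the storage package):

    object BlockStatusSketch {
      // Simplified stand-ins for StorageLevel and BlockStatus.
      case class Level(useMemory: Boolean, useDisk: Boolean)
      case class BlockStatus(level: Level, memSize: Long, diskSize: Long)

      def main(args: Array[String]) {
        val blocks = scala.collection.mutable.Map[String, BlockStatus]()

        // An update for a block dropped from memory to disk: both sizes reported as 100 bytes.
        val blockId = "rdd_0_0"
        val level = Level(useMemory = false, useDisk = true)
        val (memSize, diskSize) = (100L, 100L)

        // Mirrors the updated master logic: each mode records only its own size, so the
        // dropped block no longer appears to occupy 100 bytes of memory.
        if (level.useMemory) blocks(blockId) = BlockStatus(level, memSize, 0L)
        if (level.useDisk)   blocks(blockId) = BlockStatus(level, 0L, diskSize)

        println(blocks(blockId))   // BlockStatus(Level(false,true),0,100)
      }
    }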
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala
index 21f003609b..42f52d7b26 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala
@@ -42,15 +42,15 @@ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends
val blockMessages = BlockMessageArray.fromBufferMessage(bufferMessage)
logDebug("Parsed as a block message array")
val responseMessages = blockMessages.map(processBlockMessage).filter(_ != None).map(_.get)
- return Some(new BlockMessageArray(responseMessages).toBufferMessage)
+ Some(new BlockMessageArray(responseMessages).toBufferMessage)
} catch {
case e: Exception => logError("Exception handling buffer message", e)
- return None
+ None
}
}
case otherMessage: Any => {
logError("Unknown type message received: " + otherMessage)
- return None
+ None
}
}
}
@@ -61,7 +61,7 @@ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends
val pB = PutBlock(blockMessage.getId, blockMessage.getData, blockMessage.getLevel)
logDebug("Received [" + pB + "]")
putBlock(pB.id, pB.data, pB.level)
- return None
+ None
}
case BlockMessage.TYPE_GET_BLOCK => {
val gB = new GetBlock(blockMessage.getId)
@@ -70,9 +70,9 @@ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends
if (buffer == null) {
return None
}
- return Some(BlockMessage.fromGotBlock(GotBlock(gB.id, buffer)))
+ Some(BlockMessage.fromGotBlock(GotBlock(gB.id, buffer)))
}
- case _ => return None
+ case _ => None
}
}
@@ -93,7 +93,7 @@ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends
}
logDebug("GetBlock " + id + " used " + Utils.getUsedTimeMs(startTimeMs)
+ " and got buffer " + buffer)
- return buffer
+ buffer
}
}
@@ -111,7 +111,7 @@ private[spark] object BlockManagerWorker extends Logging {
val blockMessageArray = new BlockMessageArray(blockMessage)
val resultMessage = connectionManager.sendMessageReliablySync(
toConnManagerId, blockMessageArray.toBufferMessage)
- return (resultMessage != None)
+ resultMessage != None
}
def syncGetBlock(msg: GetBlock, toConnManagerId: ConnectionManagerId): ByteBuffer = {
@@ -130,8 +130,8 @@ private[spark] object BlockManagerWorker extends Logging {
return blockMessage.getData
})
}
- case None => logDebug("No response message received"); return null
+ case None => logDebug("No response message received")
}
- return null
+ null
}
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala b/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala
index 80dcb5a207..fbafcf79d2 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala
@@ -154,7 +154,7 @@ private[spark] class BlockMessage() {
println()
*/
val finishTime = System.currentTimeMillis
- return Message.createBufferMessage(buffers)
+ Message.createBufferMessage(buffers)
}
override def toString: String = {
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala b/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala
index a06f50a0ac..59329361f3 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala
@@ -96,7 +96,7 @@ class BlockMessageArray(var blockMessages: Seq[BlockMessage]) extends Seq[BlockM
println()
println()
*/
- return Message.createBufferMessage(buffers)
+ Message.createBufferMessage(buffers)
}
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
index 61e63c60d5..369a277232 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
@@ -181,4 +181,8 @@ class DiskBlockObjectWriter(
// Only valid if called after close()
override def timeWriting() = _timeWriting
+
+ def bytesWritten: Long = {
+ lastValidPosition - initialPosition
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
index 55dcb3742c..a8ef7fa8b6 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
@@ -19,7 +19,7 @@ package org.apache.spark.storage
import java.io.File
import java.text.SimpleDateFormat
-import java.util.{Date, Random}
+import java.util.{Date, Random, UUID}
import org.apache.spark.Logging
import org.apache.spark.executor.ExecutorExitCode
@@ -38,7 +38,7 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD
extends PathResolver with Logging {
private val MAX_DIR_CREATION_ATTEMPTS: Int = 10
- private val subDirsPerLocalDir = shuffleManager.conf.get("spark.diskStore.subDirectories", "64").toInt
+ private val subDirsPerLocalDir = shuffleManager.conf.getInt("spark.diskStore.subDirectories", 64)
// Create one local directory for each path mentioned in spark.local.dir; then, inside this
// directory, create multiple subdirectories that we will hash files into, in order to avoid
@@ -90,6 +90,15 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD
def getFile(blockId: BlockId): File = getFile(blockId.name)
+ /** Produces a unique block id and File suitable for intermediate results. */
+ def createTempBlock(): (TempBlockId, File) = {
+ var blockId = new TempBlockId(UUID.randomUUID())
+ while (getFile(blockId).exists()) {
+ blockId = new TempBlockId(UUID.randomUUID())
+ }
+ (blockId, getFile(blockId))
+ }
+
private def createLocalDirs(): Array[File] = {
logDebug("Creating local directories at root dirs '" + rootDirs + "'")
val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss")
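createTempBlock above keeps drawing fresh UUIDs until the derived file does not already exist. The same retry-until-unique idea in a standalone sketch against a plain temp directory (the helper name is made up):

    import java.io.File
    import java.util.UUID

    object TempNameSketch {
      // Keep generating "temp_<uuid>" names until one does not collide with an existing file.
      def createTempName(dir: File): (String, File) = {
        var name = "temp_" + UUID.randomUUID()
        while (new File(dir, name).exists()) {
          name = "temp_" + UUID.randomUUID()
        }
        (name, new File(dir, name))
      }

      def main(args: Array[String]) {
        val dir = new File(System.getProperty("java.io.tmpdir"))
        val (name, file) = createTempName(dir)
        println(name + " -> " + file.getAbsolutePath)
      }
    }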
diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
index 05f676c6e2..27f057b9f2 100644
--- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
@@ -245,7 +245,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
return false
}
}
- return true
+ true
}
override def contains(blockId: BlockId): Boolean = {
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
index 39dc7bb19a..e2b24298a5 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
@@ -64,9 +64,9 @@ class ShuffleBlockManager(blockManager: BlockManager) {
// Turning off shuffle file consolidation causes all shuffle Blocks to get their own file.
// TODO: Remove this once the shuffle file consolidation feature is stable.
val consolidateShuffleFiles =
- conf.get("spark.shuffle.consolidateFiles", "false").toBoolean
+ conf.getBoolean("spark.shuffle.consolidateFiles", false)
- private val bufferSize = conf.get("spark.shuffle.file.buffer.kb", "100").toInt * 1024
+ private val bufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024
/**
* Contains all the state related to a particular shuffle. This includes a pool of unused
diff --git a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala
index b5596dffd3..0f84810d6b 100644
--- a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala
+++ b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala
@@ -74,7 +74,7 @@ class StorageLevel private(
if (deserialized_) {
ret |= 1
}
- return ret
+ ret
}
override def writeExternal(out: ObjectOutput) {
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
index b7b87250b9..bcd2824450 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
@@ -33,7 +33,7 @@ import org.apache.spark.scheduler._
*/
private[spark] class JobProgressListener(val sc: SparkContext) extends SparkListener {
// How many stages to remember
- val RETAINED_STAGES = sc.conf.get("spark.ui.retained_stages", "1000").toInt
+ val RETAINED_STAGES = sc.conf.getInt("spark.ui.retainedStages", 1000)
val DEFAULT_POOL_NAME = "default"
val stageIdToPool = new HashMap[Int, String]()
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
index 8dcfeacb60..d1e58016be 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
@@ -171,7 +171,7 @@ private[spark] class StagePage(parent: JobProgressUI) {
summary ++
<h4>Summary Metrics for {numCompleted} Completed Tasks</h4> ++
<div>{summaryTable.getOrElse("No tasks have reported metrics yet.")}</div> ++
- <h4>Aggregated Metrics by Executors</h4> ++ executorTable.toNodeSeq() ++
+ <h4>Aggregated Metrics by Executor</h4> ++ executorTable.toNodeSeq() ++
<h4>Tasks</h4> ++ taskTable
headerSparkPage(content, parent.sc, "Details for Stage %d".format(stageId), Stages)
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
index 463d85dfd5..9ad6de3c6d 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
@@ -48,7 +48,7 @@ private[spark] class StageTable(val stages: Seq[StageInfo], val parent: JobProgr
{if (isFairScheduler) {<th>Pool Name</th>} else {}}
<th>Description</th>
<th>Submitted</th>
- <th>Task Time</th>
+ <th>Duration</th>
<th>Tasks: Succeeded/Total</th>
<th>Shuffle Read</th>
<th>Shuffle Write</th>
diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
index 3f009a8998..761d378c7f 100644
--- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
@@ -44,13 +44,13 @@ private[spark] object AkkaUtils {
def createActorSystem(name: String, host: String, port: Int, indestructible: Boolean = false,
conf: SparkConf): (ActorSystem, Int) = {
- val akkaThreads = conf.get("spark.akka.threads", "4").toInt
- val akkaBatchSize = conf.get("spark.akka.batchSize", "15").toInt
+ val akkaThreads = conf.getInt("spark.akka.threads", 4)
+ val akkaBatchSize = conf.getInt("spark.akka.batchSize", 15)
- val akkaTimeout = conf.get("spark.akka.timeout", "100").toInt
+ val akkaTimeout = conf.getInt("spark.akka.timeout", 100)
- val akkaFrameSize = conf.get("spark.akka.frameSize", "10").toInt
- val akkaLogLifecycleEvents = conf.get("spark.akka.logLifecycleEvents", "false").toBoolean
+ val akkaFrameSize = conf.getInt("spark.akka.frameSize", 10)
+ val akkaLogLifecycleEvents = conf.getBoolean("spark.akka.logLifecycleEvents", false)
val lifecycleEvents = if (akkaLogLifecycleEvents) "on" else "off"
if (!akkaLogLifecycleEvents) {
// As a workaround for Akka issue #3787, we coerce the "EndpointWriter" log to be silent.
@@ -58,12 +58,12 @@ private[spark] object AkkaUtils {
Option(Logger.getLogger("akka.remote.EndpointWriter")).map(l => l.setLevel(Level.FATAL))
}
- val logAkkaConfig = if (conf.get("spark.akka.logAkkaConfig", "false").toBoolean) "on" else "off"
+ val logAkkaConfig = if (conf.getBoolean("spark.akka.logAkkaConfig", false)) "on" else "off"
- val akkaHeartBeatPauses = conf.get("spark.akka.heartbeat.pauses", "600").toInt
+ val akkaHeartBeatPauses = conf.getInt("spark.akka.heartbeat.pauses", 600)
val akkaFailureDetector =
- conf.get("spark.akka.failure-detector.threshold", "300.0").toDouble
- val akkaHeartBeatInterval = conf.get("spark.akka.heartbeat.interval", "1000").toInt
+ conf.getDouble("spark.akka.failure-detector.threshold", 300.0)
+ val akkaHeartBeatInterval = conf.getInt("spark.akka.heartbeat.interval", 1000)
val akkaConf = ConfigFactory.parseMap(conf.getAkkaConf.toMap[String, String]).withFallback(
ConfigFactory.parseString(
@@ -103,7 +103,7 @@ private[spark] object AkkaUtils {
/** Returns the default Spark timeout to use for Akka ask operations. */
def askTimeout(conf: SparkConf): FiniteDuration = {
- Duration.create(conf.get("spark.akka.askTimeout", "30").toLong, "seconds")
+ Duration.create(conf.getLong("spark.akka.askTimeout", 30), "seconds")
}
/** Returns the default Spark timeout to use for Akka remote actor lookup. */
diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
index 7108595e3e..1df6b87fb0 100644
--- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
@@ -61,7 +61,7 @@ private[spark] object ClosureCleaner extends Logging {
return f.getType :: Nil // Stop at the first $outer that is not a closure
}
}
- return Nil
+ Nil
}
// Get a list of the outer objects for a given closure object.
@@ -74,7 +74,7 @@ private[spark] object ClosureCleaner extends Logging {
return f.get(obj) :: Nil // Stop at the first $outer that is not a closure
}
}
- return Nil
+ Nil
}
private def getInnerClasses(obj: AnyRef): List[Class[_]] = {
@@ -174,7 +174,7 @@ private[spark] object ClosureCleaner extends Logging {
field.setAccessible(true)
field.set(obj, outer)
}
- return obj
+ obj
}
}
}
@@ -182,7 +182,7 @@ private[spark] object ClosureCleaner extends Logging {
private[spark] class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends ClassVisitor(ASM4) {
override def visitMethod(access: Int, name: String, desc: String,
sig: String, exceptions: Array[String]): MethodVisitor = {
- return new MethodVisitor(ASM4) {
+ new MethodVisitor(ASM4) {
override def visitFieldInsn(op: Int, owner: String, name: String, desc: String) {
if (op == GETFIELD) {
for (cl <- output.keys if cl.getName == owner.replace('/', '.')) {
@@ -215,7 +215,7 @@ private[spark] class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisi
override def visitMethod(access: Int, name: String, desc: String,
sig: String, exceptions: Array[String]): MethodVisitor = {
- return new MethodVisitor(ASM4) {
+ new MethodVisitor(ASM4) {
override def visitMethodInsn(op: Int, owner: String, name: String,
desc: String) {
val argTypes = Type.getArgumentTypes(desc)
diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
index aa7f52cafb..ac07a55cb9 100644
--- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
@@ -74,7 +74,7 @@ object MetadataCleanerType extends Enumeration {
// initialization of StreamingContext. It's okay for users trying to configure stuff themselves.
object MetadataCleaner {
def getDelaySeconds(conf: SparkConf) = {
- conf.get("spark.cleaner.ttl", "3500").toInt
+ conf.getInt("spark.cleaner.ttl", -1)
}
def getDelaySeconds(conf: SparkConf, cleanerType: MetadataCleanerType.MetadataCleanerType): Int =
diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala
index bddb3bb735..3cf94892e9 100644
--- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala
+++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala
@@ -108,7 +108,7 @@ private[spark] object SizeEstimator extends Logging {
val bean = ManagementFactory.newPlatformMXBeanProxy(server,
hotSpotMBeanName, hotSpotMBeanClass)
// TODO: We could use reflection on the VMOption returned ?
- return getVMMethod.invoke(bean, "UseCompressedOops").toString.contains("true")
+ getVMMethod.invoke(bean, "UseCompressedOops").toString.contains("true")
} catch {
case e: Exception => {
// Guess whether they've enabled UseCompressedOops based on whether maxMemory < 32 GB
@@ -141,7 +141,7 @@ private[spark] object SizeEstimator extends Logging {
def dequeue(): AnyRef = {
val elem = stack.last
stack.trimEnd(1)
- return elem
+ elem
}
}
@@ -162,7 +162,7 @@ private[spark] object SizeEstimator extends Logging {
while (!state.isFinished) {
visitSingleObject(state.dequeue(), state)
}
- return state.size
+ state.size
}
private def visitSingleObject(obj: AnyRef, state: SearchState) {
@@ -276,11 +276,11 @@ private[spark] object SizeEstimator extends Logging {
// Create and cache a new ClassInfo
val newInfo = new ClassInfo(shellSize, pointerFields)
classInfos.put(cls, newInfo)
- return newInfo
+ newInfo
}
private def alignSize(size: Long): Long = {
val rem = size % ALIGN_SIZE
- return if (rem == 0) size else (size + ALIGN_SIZE - rem)
+ if (rem == 0) size else (size + ALIGN_SIZE - rem)
}
}
diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala
index 181ae2fd45..8e07a0f29a 100644
--- a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala
@@ -26,16 +26,23 @@ import org.apache.spark.Logging
/**
* This is a custom implementation of scala.collection.mutable.Map which stores the insertion
- * time stamp along with each key-value pair. Key-value pairs that are older than a particular
- * threshold time can them be removed using the clearOldValues method. This is intended to be a drop-in
- * replacement of scala.collection.mutable.HashMap.
+ * timestamp along with each key-value pair. If specified, the timestamp of each pair can be
+ * updated every time it is accessed. Key-value pairs whose timestamps are older than a particular
+ * threshold time can then be removed using the clearOldValues method. This is intended to
+ * be a drop-in replacement for scala.collection.mutable.HashMap.
+ * @param updateTimeStampOnGet When enabled, the timestamp of a pair will be
+ * updated when it is accessed
*/
-class TimeStampedHashMap[A, B] extends Map[A, B]() with Logging {
+class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = false)
+ extends Map[A, B]() with Logging {
val internalMap = new ConcurrentHashMap[A, (B, Long)]()
def get(key: A): Option[B] = {
val value = internalMap.get(key)
- if (value != null) Some(value._1) else None
+ if (value != null && updateTimeStampOnGet) {
+ internalMap.replace(key, value, (value._1, currentTime))
+ }
+ Option(value).map(_._1)
}
def iterator: Iterator[(A, B)] = {
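With updateTimeStampOnGet enabled, a successful read refreshes the entry's timestamp, so entries that are still being used survive clearOldValues. A minimal standalone sketch of that access-refresh pattern over a ConcurrentHashMap (not the Spark class itself; names are made up):

    import java.util.concurrent.ConcurrentHashMap

    class RefreshingMap[A, B](refreshOnGet: Boolean) {
      private val underlying = new ConcurrentHashMap[A, (B, Long)]()

      def put(key: A, value: B) { underlying.put(key, (value, System.currentTimeMillis)) }

      def get(key: A): Option[B] = {
        val entry = underlying.get(key)
        if (entry != null && refreshOnGet) {
          // Refresh atomically; a failed replace only means another writer got there first.
          underlying.replace(key, entry, (entry._1, System.currentTimeMillis))
        }
        Option(entry).map(_._1)
      }

      /** Drop every entry whose timestamp is older than threshTime. */
      def clearOldValues(threshTime: Long) {
        val it = underlying.entrySet.iterator
        while (it.hasNext) {
          if (it.next().getValue._2 < threshTime) it.remove()
        }
      }
    }

Reading a hot key just before clearOldValues keeps it alive when refreshOnGet is true; with false it is evicted like any other stale entry.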
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 5f1253100b..caa9bf4c92 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -26,37 +26,61 @@ import scala.collection.JavaConversions._
import scala.collection.Map
import scala.collection.mutable.ArrayBuffer
import scala.io.Source
-import scala.reflect.ClassTag
+import scala.reflect.{classTag, ClassTag}
import com.google.common.io.Files
import com.google.common.util.concurrent.ThreadFactoryBuilder
+import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{Path, FileSystem, FileUtil}
+import org.apache.hadoop.io._
import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance}
import org.apache.spark.deploy.SparkHadoopUtil
import java.nio.ByteBuffer
-import org.apache.spark.{SparkConf, SparkContext, SparkException, Logging}
+import org.apache.spark.{SparkConf, SparkException, Logging}
/**
* Various utility methods used by Spark.
*/
private[spark] object Utils extends Logging {
+
+ /**
+ * We try to clone the most common types of Writables directly and call WritableUtils.clone
+ * otherwise. The intention is to optimize: for NullWritable there is no need to clone, and for
+ * Long, Int and String writables, creating a new object with the value set is faster.
+ */
+ def cloneWritables[T: ClassTag](conf: Configuration): Writable => T = {
+ val cloneFunc = classTag[T] match {
+ case ClassTag(_: Text) =>
+ (w: Writable) => new Text(w.asInstanceOf[Text].getBytes).asInstanceOf[T]
+ case ClassTag(_: LongWritable) =>
+ (w: Writable) => new LongWritable(w.asInstanceOf[LongWritable].get).asInstanceOf[T]
+ case ClassTag(_: IntWritable) =>
+ (w: Writable) => new IntWritable(w.asInstanceOf[IntWritable].get).asInstanceOf[T]
+ case ClassTag(_: NullWritable) =>
+ (w: Writable) => w.asInstanceOf[T] // TODO: should we clone this ?
+ case _ =>
+ (w: Writable) => WritableUtils.clone(w, conf).asInstanceOf[T] // slower way of cloning.
+ }
+ cloneFunc
+ }
+
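Hadoop input formats typically reuse one Writable instance per record, so anything that holds on to records needs its own copy; that is the job of the cloneWritables helper added above. A hypothetical usage sketch, placed in org.apache.spark.util because Utils is private[spark]:

    package org.apache.spark.util

    import org.apache.hadoop.conf.Configuration
    import org.apache.hadoop.io.Text

    object CloneWritablesSketch {
      def main(args: Array[String]) {
        val cloneText = Utils.cloneWritables[Text](new Configuration())
        val reused = new Text("first")   // stands in for the buffer a RecordReader reuses
        val copy = cloneText(reused)     // an independent copy holding "first"
        reused.set("second")             // mutating the reused buffer...
        println(copy + " vs " + reused)  // ...leaves the clone untouched: first vs second
      }
    }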
/** Serialize an object using Java serialization */
def serialize[T](o: T): Array[Byte] = {
val bos = new ByteArrayOutputStream()
val oos = new ObjectOutputStream(bos)
oos.writeObject(o)
oos.close()
- return bos.toByteArray
+ bos.toByteArray
}
/** Deserialize an object using Java serialization */
def deserialize[T](bytes: Array[Byte]): T = {
val bis = new ByteArrayInputStream(bytes)
val ois = new ObjectInputStream(bis)
- return ois.readObject.asInstanceOf[T]
+ ois.readObject.asInstanceOf[T]
}
/** Deserialize an object using Java serialization and the given ClassLoader */
@@ -66,7 +90,7 @@ private[spark] object Utils extends Logging {
override def resolveClass(desc: ObjectStreamClass) =
Class.forName(desc.getName, false, loader)
}
- return ois.readObject.asInstanceOf[T]
+ ois.readObject.asInstanceOf[T]
}
/** Deserialize a Long value (used for {@link org.apache.spark.api.python.PythonPartitioner}) */
@@ -144,7 +168,7 @@ private[spark] object Utils extends Logging {
i += 1
}
}
- return buf
+ buf
}
private val shutdownDeletePaths = new scala.collection.mutable.HashSet[String]()
@@ -396,15 +420,6 @@ private[spark] object Utils extends Logging {
InetAddress.getByName(address).getHostName
}
- def localHostPort(conf: SparkConf): String = {
- val retval = conf.get("spark.hostPort", null)
- if (retval == null) {
- logErrorWithStack("spark.hostPort not set but invoking localHostPort")
- return localHostName()
- }
- retval
- }
-
def checkHost(host: String, message: String = "") {
assert(host.indexOf(':') == -1, message)
}
@@ -413,14 +428,6 @@ private[spark] object Utils extends Logging {
assert(hostPort.indexOf(':') != -1, message)
}
- def logErrorWithStack(msg: String) {
- try {
- throw new Exception
- } catch {
- case ex: Exception => logError(msg, ex)
- }
- }
-
// Typically, this will be of order of number of nodes in cluster
// If not, we should change it to LRUCache or something.
private val hostPortParseResults = new ConcurrentHashMap[String, (String, Int)]()
@@ -428,7 +435,7 @@ private[spark] object Utils extends Logging {
def parseHostPort(hostPort: String): (String, Int) = {
{
// Check cache first.
- var cached = hostPortParseResults.get(hostPort)
+ val cached = hostPortParseResults.get(hostPort)
if (cached != null) return cached
}
@@ -731,7 +738,7 @@ private[spark] object Utils extends Logging {
} catch {
case ise: IllegalStateException => return true
}
- return false
+ false
}
def isSpace(c: Char): Boolean = {
@@ -748,7 +755,7 @@ private[spark] object Utils extends Logging {
var inWord = false
var inSingleQuote = false
var inDoubleQuote = false
- var curWord = new StringBuilder
+ val curWord = new StringBuilder
def endWord() {
buf += curWord.toString
curWord.clear()
@@ -794,7 +801,7 @@ private[spark] object Utils extends Logging {
if (inWord || inDoubleQuote || inSingleQuote) {
endWord()
}
- return buf
+ buf
}
/* Calculates 'x' modulo 'mod', takes to consideration sign of x,
@@ -822,8 +829,7 @@ private[spark] object Utils extends Logging {
/** Returns a copy of the system properties that is thread-safe to iterator over. */
def getSystemProperties(): Map[String, String] = {
- return System.getProperties().clone()
- .asInstanceOf[java.util.Properties].toMap[String, String]
+ System.getProperties.clone().asInstanceOf[java.util.Properties].toMap[String, String]
}
/**
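
Aside on the Writable clone helper added above: Hadoop record readers reuse a single Writable instance across records, so values read from a SequenceFile must be copied before they are buffered or cached. A minimal sketch of applying such a clone function, assuming a hypothetical iterator `records` of reused Writables (not part of this patch):

    import org.apache.hadoop.conf.Configuration
    import org.apache.hadoop.io.{Writable, WritableUtils}

    // Copy each reused Writable before buffering it; without the clone, every
    // element of the result would alias the same, repeatedly mutated object.
    def cloneAll[T <: Writable](records: Iterator[T], conf: Configuration): Seq[T] =
      records.map(w => WritableUtils.clone(w, conf)).toSeq
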
diff --git a/core/src/main/scala/org/apache/spark/util/Vector.scala b/core/src/main/scala/org/apache/spark/util/Vector.scala
index fe710c58ac..fcdf848637 100644
--- a/core/src/main/scala/org/apache/spark/util/Vector.scala
+++ b/core/src/main/scala/org/apache/spark/util/Vector.scala
@@ -17,6 +17,8 @@
package org.apache.spark.util
+import scala.util.Random
+
class Vector(val elements: Array[Double]) extends Serializable {
def length = elements.length
@@ -25,7 +27,7 @@ class Vector(val elements: Array[Double]) extends Serializable {
def + (other: Vector): Vector = {
if (length != other.length)
throw new IllegalArgumentException("Vectors of different length")
- return Vector(length, i => this(i) + other(i))
+ Vector(length, i => this(i) + other(i))
}
def add(other: Vector) = this + other
@@ -33,7 +35,7 @@ class Vector(val elements: Array[Double]) extends Serializable {
def - (other: Vector): Vector = {
if (length != other.length)
throw new IllegalArgumentException("Vectors of different length")
- return Vector(length, i => this(i) - other(i))
+ Vector(length, i => this(i) - other(i))
}
def subtract(other: Vector) = this - other
@@ -47,7 +49,7 @@ class Vector(val elements: Array[Double]) extends Serializable {
ans += this(i) * other(i)
i += 1
}
- return ans
+ ans
}
/**
@@ -67,7 +69,7 @@ class Vector(val elements: Array[Double]) extends Serializable {
ans += (this(i) + plus(i)) * other(i)
i += 1
}
- return ans
+ ans
}
def += (other: Vector): Vector = {
@@ -102,7 +104,7 @@ class Vector(val elements: Array[Double]) extends Serializable {
ans += (this(i) - other(i)) * (this(i) - other(i))
i += 1
}
- return ans
+ ans
}
def dist(other: Vector): Double = math.sqrt(squaredDist(other))
@@ -117,13 +119,19 @@ object Vector {
def apply(length: Int, initializer: Int => Double): Vector = {
val elements: Array[Double] = Array.tabulate(length)(initializer)
- return new Vector(elements)
+ new Vector(elements)
}
def zeros(length: Int) = new Vector(new Array[Double](length))
def ones(length: Int) = Vector(length, _ => 1)
+ /**
+ * Creates a [[org.apache.spark.util.Vector]] of the given length containing random numbers
+ * between 0.0 and 1.0. An optional [[scala.util.Random]] number generator can be provided.
+ */
+ def random(length: Int, random: Random = new XORShiftRandom()) = Vector(length, _ => random.nextDouble())
+
class Multiplier(num: Double) {
def * (vec: Vector) = vec * num
}
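
A short usage sketch for the new Vector.random factory (illustrative only; by default it draws from the Spark-internal XORShiftRandom, and any scala.util.Random can be supplied for reproducibility):

    import scala.util.Random
    import org.apache.spark.util.Vector

    val v1 = Vector.random(5)                  // default XORShiftRandom-backed generator
    val v2 = Vector.random(5, new Random(42))  // fixed seed, reproducible across runs
    val v3 = Vector.random(5, new Random(42))  // same seed => element-wise equal to v2
    println(v2 dot v3)                         // Vector also supports +, -, dot, dist, ...
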
diff --git a/core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala b/core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala
index e9907e6c85..08b31ac64f 100644
--- a/core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala
+++ b/core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala
@@ -91,4 +91,4 @@ private[spark] object XORShiftRandom {
}
-} \ No newline at end of file
+}
diff --git a/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala
index 8bb4ee3bfa..b8c852b4ff 100644
--- a/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala
@@ -15,7 +15,9 @@
* limitations under the License.
*/
-package org.apache.spark.util
+package org.apache.spark.util.collection
+
+import java.util.{Arrays, Comparator}
/**
* A simple open hash table optimized for the append-only use case, where keys
@@ -28,14 +30,15 @@ package org.apache.spark.util
* TODO: Cache the hash values of each key? java.util.HashMap does that.
*/
private[spark]
-class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] with Serializable {
+class AppendOnlyMap[K, V](initialCapacity: Int = 64)
+  extends Iterable[(K, V)] with Serializable {
require(initialCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements")
require(initialCapacity >= 1, "Invalid initial capacity")
private var capacity = nextPowerOf2(initialCapacity)
private var mask = capacity - 1
private var curSize = 0
- private var growThreshold = LOAD_FACTOR * capacity
+ private var growThreshold = (LOAD_FACTOR * capacity).toInt
// Holds keys and values in the same array for memory locality; specifically, the order of
// elements is key0, value0, key1, value1, key2, value2, etc.
@@ -45,10 +48,15 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi
private var haveNullValue = false
private var nullValue: V = null.asInstanceOf[V]
+ // Triggered by destructiveSortedIterator; the underlying data array may no longer be used
+ private var destroyed = false
+ private val destructionMessage = "Map state is invalid from destructive sorting!"
+
private val LOAD_FACTOR = 0.7
/** Get the value for a given key */
def apply(key: K): V = {
+ assert(!destroyed, destructionMessage)
val k = key.asInstanceOf[AnyRef]
if (k.eq(null)) {
return nullValue
@@ -67,11 +75,12 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi
i += 1
}
}
- return null.asInstanceOf[V]
+ null.asInstanceOf[V]
}
/** Set the value for a key */
def update(key: K, value: V): Unit = {
+ assert(!destroyed, destructionMessage)
val k = key.asInstanceOf[AnyRef]
if (k.eq(null)) {
if (!haveNullValue) {
@@ -106,6 +115,7 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi
* for key, if any, or null otherwise. Returns the newly updated value.
*/
def changeValue(key: K, updateFunc: (Boolean, V) => V): V = {
+ assert(!destroyed, destructionMessage)
val k = key.asInstanceOf[AnyRef]
if (k.eq(null)) {
if (!haveNullValue) {
@@ -139,35 +149,38 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi
}
/** Iterator method from Iterable */
- override def iterator: Iterator[(K, V)] = new Iterator[(K, V)] {
- var pos = -1
-
- /** Get the next value we should return from next(), or null if we're finished iterating */
- def nextValue(): (K, V) = {
- if (pos == -1) { // Treat position -1 as looking at the null value
- if (haveNullValue) {
- return (null.asInstanceOf[K], nullValue)
+ override def iterator: Iterator[(K, V)] = {
+ assert(!destroyed, destructionMessage)
+ new Iterator[(K, V)] {
+ var pos = -1
+
+ /** Get the next value we should return from next(), or null if we're finished iterating */
+ def nextValue(): (K, V) = {
+ if (pos == -1) { // Treat position -1 as looking at the null value
+ if (haveNullValue) {
+ return (null.asInstanceOf[K], nullValue)
+ }
+ pos += 1
}
- pos += 1
- }
- while (pos < capacity) {
- if (!data(2 * pos).eq(null)) {
- return (data(2 * pos).asInstanceOf[K], data(2 * pos + 1).asInstanceOf[V])
+ while (pos < capacity) {
+ if (!data(2 * pos).eq(null)) {
+ return (data(2 * pos).asInstanceOf[K], data(2 * pos + 1).asInstanceOf[V])
+ }
+ pos += 1
}
- pos += 1
+ null
}
- null
- }
- override def hasNext: Boolean = nextValue() != null
+ override def hasNext: Boolean = nextValue() != null
- override def next(): (K, V) = {
- val value = nextValue()
- if (value == null) {
- throw new NoSuchElementException("End of iterator")
+ override def next(): (K, V) = {
+ val value = nextValue()
+ if (value == null) {
+ throw new NoSuchElementException("End of iterator")
+ }
+ pos += 1
+ value
}
- pos += 1
- value
}
}
@@ -190,7 +203,7 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi
}
/** Double the table's size and re-hash everything */
- private def growTable() {
+ protected def growTable() {
val newCapacity = capacity * 2
if (newCapacity >= (1 << 30)) {
// We can't make the table this big because we want an array of 2x
@@ -227,11 +240,58 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi
data = newData
capacity = newCapacity
mask = newMask
- growThreshold = LOAD_FACTOR * newCapacity
+ growThreshold = (LOAD_FACTOR * newCapacity).toInt
}
private def nextPowerOf2(n: Int): Int = {
val highBit = Integer.highestOneBit(n)
if (highBit == n) n else highBit << 1
}
+
+ /**
+ * Return an iterator of the map in sorted order. This provides a way to sort the map without
+ * using additional memory, at the expense of destroying the validity of the map.
+ */
+ def destructiveSortedIterator(cmp: Comparator[(K, V)]): Iterator[(K, V)] = {
+ destroyed = true
+ // Pack KV pairs into the front of the underlying array
+ var keyIndex, newIndex = 0
+ while (keyIndex < capacity) {
+ if (data(2 * keyIndex) != null) {
+ data(newIndex) = (data(2 * keyIndex), data(2 * keyIndex + 1))
+ newIndex += 1
+ }
+ keyIndex += 1
+ }
+ assert(curSize == newIndex + (if (haveNullValue) 1 else 0))
+
+ // Sort by the given ordering
+ val rawOrdering = new Comparator[AnyRef] {
+ def compare(x: AnyRef, y: AnyRef): Int = {
+ cmp.compare(x.asInstanceOf[(K, V)], y.asInstanceOf[(K, V)])
+ }
+ }
+ Arrays.sort(data, 0, newIndex, rawOrdering)
+
+ new Iterator[(K, V)] {
+ var i = 0
+ var nullValueReady = haveNullValue
+ def hasNext: Boolean = (i < newIndex || nullValueReady)
+ def next(): (K, V) = {
+ if (nullValueReady) {
+ nullValueReady = false
+ (null.asInstanceOf[K], nullValue)
+ } else {
+ val item = data(i).asInstanceOf[(K, V)]
+ i += 1
+ item
+ }
+ }
+ }
+ }
+
+ /**
+ * Return whether the next insert will cause the map to grow
+ */
+ def atGrowThreshold: Boolean = curSize == growThreshold
}
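
A minimal sketch of how the new destructive iterator might be used, mirroring the pattern in ExternalAppendOnlyMap.spill below (AppendOnlyMap is private[spark], so this only compiles inside Spark; the comparator is illustrative):

    import java.util.Comparator
    import org.apache.spark.util.collection.AppendOnlyMap

    val map = new AppendOnlyMap[String, Int]()
    map("b") = 2
    map("a") = 1

    // Sorts the underlying array in place; afterwards apply/update/iterator fail the
    // destruction assert, so the map must be discarded once this iterator is taken.
    val byKey = new Comparator[(String, Int)] {
      def compare(x: (String, Int), y: (String, Int)): Int = x._1.compareTo(y._1)
    }
    val sorted = map.destructiveSortedIterator(byKey).toList   // List((a,1), (b,2))
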
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
new file mode 100644
index 0000000000..e3bcd895aa
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -0,0 +1,350 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import java.io._
+import java.util.Comparator
+
+import it.unimi.dsi.fastutil.io.FastBufferedInputStream
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.{Logging, SparkEnv}
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.storage.{DiskBlockManager, DiskBlockObjectWriter}
+
+/**
+ * An append-only map that spills sorted content to disk when there is insufficient space for it
+ * to grow.
+ *
+ * This map takes two passes over the data:
+ *
+ * (1) Values are merged into combiners, which are sorted and spilled to disk as necessary
+ * (2) Combiners are read from disk and merged together
+ *
+ * The setting of the spill threshold faces the following trade-off: If the spill threshold is
+ * too high, the in-memory map may occupy more memory than is available, resulting in OOM.
+ * However, if the spill threshold is too low, we spill frequently and incur unnecessary disk
+ * writes. This may lead to a performance regression compared to the normal case of using the
+ * non-spilling AppendOnlyMap.
+ *
+ * Two parameters control the memory threshold:
+ *
+ * `spark.shuffle.memoryFraction` specifies the collective amount of memory used for storing
+ * these maps as a fraction of the executor's total memory. Since each concurrently running
+ * task maintains one map, the actual threshold for each map is this quantity divided by the
+ * number of running tasks.
+ *
+ * `spark.shuffle.safetyFraction` specifies an additional margin of safety as a fraction of
+ * this threshold, in case map size estimation is not sufficiently accurate.
+ */
+
+private[spark] class ExternalAppendOnlyMap[K, V, C](
+ createCombiner: V => C,
+ mergeValue: (C, V) => C,
+ mergeCombiners: (C, C) => C,
+ serializer: Serializer = SparkEnv.get.serializerManager.default,
+ diskBlockManager: DiskBlockManager = SparkEnv.get.blockManager.diskBlockManager)
+ extends Iterable[(K, C)] with Serializable with Logging {
+
+ import ExternalAppendOnlyMap._
+
+ private var currentMap = new SizeTrackingAppendOnlyMap[K, C]
+ private val spilledMaps = new ArrayBuffer[DiskMapIterator]
+ private val sparkConf = SparkEnv.get.conf
+
+ // Collective memory threshold shared across all running tasks
+ private val maxMemoryThreshold = {
+ val memoryFraction = sparkConf.getDouble("spark.shuffle.memoryFraction", 0.3)
+ val safetyFraction = sparkConf.getDouble("spark.shuffle.safetyFraction", 0.8)
+ (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong
+ }
+
+ // Number of pairs in the in-memory map
+ private var numPairsInMemory = 0
+
+ // Number of in-memory pairs inserted before tracking the map's shuffle memory usage
+ private val trackMemoryThreshold = 1000
+
+ // How many times we have spilled so far
+ private var spillCount = 0
+
+ private val fileBufferSize = sparkConf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024
+ private val syncWrites = sparkConf.getBoolean("spark.shuffle.sync", false)
+ private val comparator = new KCComparator[K, C]
+ private val ser = serializer.newInstance()
+
+ /**
+ * Insert the given key and value into the map.
+ *
+ * If the underlying map is about to grow, check if the global pool of shuffle memory has
+ * enough room for this to happen. If so, allocate the memory required to grow the map;
+ * otherwise, spill the in-memory map to disk.
+ *
+ * The shuffle memory usage of the first trackMemoryThreshold entries is not tracked.
+ */
+ def insert(key: K, value: V) {
+ val update: (Boolean, C) => C = (hadVal, oldVal) => {
+ if (hadVal) mergeValue(oldVal, value) else createCombiner(value)
+ }
+ if (numPairsInMemory > trackMemoryThreshold && currentMap.atGrowThreshold) {
+ val mapSize = currentMap.estimateSize()
+ var shouldSpill = false
+ val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap
+
+ // Atomically check whether there is sufficient memory in the global pool for
+ // this map to grow and, if possible, allocate the required amount
+ shuffleMemoryMap.synchronized {
+ val threadId = Thread.currentThread().getId
+ val previouslyOccupiedMemory = shuffleMemoryMap.get(threadId)
+ val availableMemory = maxMemoryThreshold -
+ (shuffleMemoryMap.values.sum - previouslyOccupiedMemory.getOrElse(0L))
+
+ // Assume map growth factor is 2x
+ shouldSpill = availableMemory < mapSize * 2
+ if (!shouldSpill) {
+ shuffleMemoryMap(threadId) = mapSize * 2
+ }
+ }
+ // Do not synchronize spills
+ if (shouldSpill) {
+ spill(mapSize)
+ }
+ }
+ currentMap.changeValue(key, update)
+ numPairsInMemory += 1
+ }
+
+ /**
+ * Sort the existing contents of the in-memory map and spill them to a temporary file on disk
+ */
+ private def spill(mapSize: Long) {
+ spillCount += 1
+ logWarning("Spilling in-memory map of %d MB to disk (%d time%s so far)"
+ .format(mapSize / (1024 * 1024), spillCount, if (spillCount > 1) "s" else ""))
+ val (blockId, file) = diskBlockManager.createTempBlock()
+ val writer =
+ new DiskBlockObjectWriter(blockId, file, serializer, fileBufferSize, identity, syncWrites)
+ try {
+ val it = currentMap.destructiveSortedIterator(comparator)
+ while (it.hasNext) {
+ val kv = it.next()
+ writer.write(kv)
+ }
+ writer.commit()
+ } finally {
+ // Partial failures cannot be tolerated; do not revert partial writes
+ writer.close()
+ }
+ currentMap = new SizeTrackingAppendOnlyMap[K, C]
+ spilledMaps.append(new DiskMapIterator(file))
+
+ // Reset the amount of shuffle memory used by this map in the global pool
+ val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap
+ shuffleMemoryMap.synchronized {
+ shuffleMemoryMap(Thread.currentThread().getId) = 0
+ }
+ numPairsInMemory = 0
+ }
+
+ /**
+ * Return an iterator that merges the in-memory map with the spilled maps.
+ * If no spill has occurred, simply return the in-memory map's iterator.
+ */
+ override def iterator: Iterator[(K, C)] = {
+ if (spilledMaps.isEmpty) {
+ currentMap.iterator
+ } else {
+ new ExternalIterator()
+ }
+ }
+
+ /**
+ * An iterator that sort-merges (K, C) pairs from the in-memory map and the spilled maps
+ */
+ private class ExternalIterator extends Iterator[(K, C)] {
+
+ // A fixed-size queue that maintains a buffer for each stream we are currently merging
+ val mergeHeap = new mutable.PriorityQueue[StreamBuffer]
+
+ // Input streams are derived both from the in-memory map and spilled maps on disk
+ // The in-memory map is sorted in place, while the spilled maps are already in sorted order
+ val sortedMap = currentMap.destructiveSortedIterator(comparator)
+ val inputStreams = Seq(sortedMap) ++ spilledMaps
+
+ inputStreams.foreach { it =>
+ val kcPairs = getMorePairs(it)
+ mergeHeap.enqueue(StreamBuffer(it, kcPairs))
+ }
+
+ /**
+ * Fetch from the given iterator until a key with a different hash is retrieved. In the
+ * event of key hash collisions, this ensures no pairs are hidden from being merged.
+ * Assumes the given iterator is in sorted order.
+ */
+ def getMorePairs(it: Iterator[(K, C)]): ArrayBuffer[(K, C)] = {
+ val kcPairs = new ArrayBuffer[(K, C)]
+ if (it.hasNext) {
+ var kc = it.next()
+ kcPairs += kc
+ val minHash = kc._1.hashCode()
+ while (it.hasNext && kc._1.hashCode() == minHash) {
+ kc = it.next()
+ kcPairs += kc
+ }
+ }
+ kcPairs
+ }
+
+ /**
+ * If the given buffer contains a value for the given key, merge that value into
+ * baseCombiner and remove the corresponding (K, C) pair from the buffer
+ */
+ def mergeIfKeyExists(key: K, baseCombiner: C, buffer: StreamBuffer): C = {
+ var i = 0
+ while (i < buffer.pairs.size) {
+ val (k, c) = buffer.pairs(i)
+ if (k == key) {
+ buffer.pairs.remove(i)
+ return mergeCombiners(baseCombiner, c)
+ }
+ i += 1
+ }
+ baseCombiner
+ }
+
+ /**
+ * Return true if there exists an input stream that still has unvisited pairs
+ */
+ override def hasNext: Boolean = mergeHeap.exists(!_.pairs.isEmpty)
+
+ /**
+ * Select a key with the minimum hash, then combine all values with the same key from all input streams.
+ */
+ override def next(): (K, C) = {
+ // Select a key from the StreamBuffer that holds the lowest key hash
+ val minBuffer = mergeHeap.dequeue()
+ val (minPairs, minHash) = (minBuffer.pairs, minBuffer.minKeyHash)
+ if (minPairs.length == 0) {
+ // Should only happen when no other stream buffers have any pairs left
+ throw new NoSuchElementException
+ }
+ var (minKey, minCombiner) = minPairs.remove(0)
+ assert(minKey.hashCode() == minHash)
+
+ // For all other streams that may have this key (i.e. have the same minimum key hash),
+ // merge in the corresponding value (if any) from that stream
+ val mergedBuffers = ArrayBuffer[StreamBuffer](minBuffer)
+ while (!mergeHeap.isEmpty && mergeHeap.head.minKeyHash == minHash) {
+ val newBuffer = mergeHeap.dequeue()
+ minCombiner = mergeIfKeyExists(minKey, minCombiner, newBuffer)
+ mergedBuffers += newBuffer
+ }
+
+ // Repopulate each visited stream buffer and add it back to the merge heap
+ mergedBuffers.foreach { buffer =>
+ if (buffer.pairs.length == 0) {
+ buffer.pairs ++= getMorePairs(buffer.iterator)
+ }
+ mergeHeap.enqueue(buffer)
+ }
+
+ (minKey, minCombiner)
+ }
+
+ /**
+ * A buffer for streaming from a map iterator (in-memory or on-disk) sorted by key hash.
+ * Each buffer maintains the lowest-ordered keys in the corresponding iterator. Due to
+ * hash collisions, it is possible for multiple keys to be "tied" for being the lowest.
+ *
+ * StreamBuffers are ordered by the minimum key hash found across all of their own pairs.
+ */
+ case class StreamBuffer(iterator: Iterator[(K, C)], pairs: ArrayBuffer[(K, C)])
+ extends Comparable[StreamBuffer] {
+
+ def minKeyHash: Int = {
+ if (pairs.length > 0) {
+ // pairs are already sorted by key hash
+ pairs(0)._1.hashCode()
+ } else {
+ Int.MaxValue
+ }
+ }
+
+ override def compareTo(other: StreamBuffer): Int = {
+ // minus sign because mutable.PriorityQueue dequeues the max, not the min
+ -minKeyHash.compareTo(other.minKeyHash)
+ }
+ }
+ }
+
+ /**
+ * An iterator that returns (K, C) pairs in sorted order from an on-disk map
+ */
+ private class DiskMapIterator(file: File) extends Iterator[(K, C)] {
+ val fileStream = new FileInputStream(file)
+ val bufferedStream = new FastBufferedInputStream(fileStream)
+ val deserializeStream = ser.deserializeStream(bufferedStream)
+ var nextItem: (K, C) = null
+ var eof = false
+
+ def readNextItem(): (K, C) = {
+ if (!eof) {
+ try {
+ return deserializeStream.readObject().asInstanceOf[(K, C)]
+ } catch {
+ case e: EOFException =>
+ eof = true
+ cleanup()
+ }
+ }
+ null
+ }
+
+ override def hasNext: Boolean = {
+ if (nextItem == null) {
+ nextItem = readNextItem()
+ }
+ nextItem != null
+ }
+
+ override def next(): (K, C) = {
+ val item = if (nextItem == null) readNextItem() else nextItem
+ if (item == null) {
+ throw new NoSuchElementException
+ }
+ nextItem = null
+ item
+ }
+
+ // TODO: Ensure this gets called even if the iterator isn't drained.
+ def cleanup() {
+ deserializeStream.close()
+ file.delete()
+ }
+ }
+}
+
+private[spark] object ExternalAppendOnlyMap {
+ private class KCComparator[K, C] extends Comparator[(K, C)] {
+ def compare(kc1: (K, C), kc2: (K, C)): Int = {
+ kc1._1.hashCode().compareTo(kc2._1.hashCode())
+ }
+ }
+}
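
For reference, a small sketch of how this map is driven (a sketch only; like AppendOnlyMap the class is private[spark], the combiner functions here are illustrative, and a live SparkEnv is assumed since the default serializer and disk block manager come from it):

    import scala.collection.mutable.ArrayBuffer
    import org.apache.spark.util.collection.ExternalAppendOnlyMap

    // Group Int values into ArrayBuffers, spilling sorted runs to disk when the
    // in-memory map would otherwise exceed its share of the shuffle memory pool.
    val map = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](
      createCombiner = v => ArrayBuffer(v),
      mergeValue = (buf, v) => buf += v,
      mergeCombiners = (b1, b2) => b1 ++= b2)

    Seq((1, 10), (2, 20), (1, 30)).foreach { case (k, v) => map.insert(k, v) }
    map.iterator.foreach { case (k, combined) => println(k + " -> " + combined) }

With the defaults above, the collective spill threshold is roughly Runtime.maxMemory * 0.3 * 0.8, shared across all running tasks; lowering spark.shuffle.memoryFraction trades memory for more frequent spills.
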
diff --git a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala
new file mode 100644
index 0000000000..204330dad4
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.util.SizeEstimator
+import org.apache.spark.util.collection.SizeTrackingAppendOnlyMap.Sample
+
+/**
+ * Append-only map that keeps track of its estimated size in bytes.
+ * We sample with a slow exponential back-off, using the SizeEstimator to amortize the time,
+ * as each call to SizeEstimator can take on the order of a few milliseconds.
+ */
+private[spark] class SizeTrackingAppendOnlyMap[K, V] extends AppendOnlyMap[K, V] {
+
+ /**
+ * Controls the base of the exponential which governs the rate of sampling.
+ * E.g., a value of 2 would mean we sample at 1, 2, 4, 8, ... elements.
+ */
+ private val SAMPLE_GROWTH_RATE = 1.1
+
+ /** All samples taken since last resetSamples(). Only the last two are used for extrapolation. */
+ private val samples = new ArrayBuffer[Sample]()
+
+ /** Total number of insertions and updates into the map since the last resetSamples(). */
+ private var numUpdates: Long = _
+
+ /** The value of 'numUpdates' at which we will take our next sample. */
+ private var nextSampleNum: Long = _
+
+ /** The average number of bytes per update between our last two samples. */
+ private var bytesPerUpdate: Double = _
+
+ resetSamples()
+
+ /** Called after the map grows in size, as this can be a dramatic change for small objects. */
+ def resetSamples() {
+ numUpdates = 1
+ nextSampleNum = 1
+ samples.clear()
+ takeSample()
+ }
+
+ override def update(key: K, value: V): Unit = {
+ super.update(key, value)
+ numUpdates += 1
+ if (nextSampleNum == numUpdates) { takeSample() }
+ }
+
+ override def changeValue(key: K, updateFunc: (Boolean, V) => V): V = {
+ val newValue = super.changeValue(key, updateFunc)
+ numUpdates += 1
+ if (nextSampleNum == numUpdates) { takeSample() }
+ newValue
+ }
+
+ /** Takes a new sample of the current map's size. */
+ def takeSample() {
+ samples += Sample(SizeEstimator.estimate(this), numUpdates)
+ // Only use the last two samples to extrapolate. If fewer than 2 samples, assume no change.
+ bytesPerUpdate = math.max(0, samples.toSeq.reverse match {
+ case latest :: previous :: tail =>
+ (latest.size - previous.size).toDouble / (latest.numUpdates - previous.numUpdates)
+ case _ =>
+ 0
+ })
+ nextSampleNum = math.ceil(numUpdates * SAMPLE_GROWTH_RATE).toLong
+ }
+
+ override protected def growTable() {
+ super.growTable()
+ resetSamples()
+ }
+
+ /** Estimates the current size of the map in bytes. O(1) time. */
+ def estimateSize(): Long = {
+ assert(samples.nonEmpty)
+ val extrapolatedDelta = bytesPerUpdate * (numUpdates - samples.last.numUpdates)
+ (samples.last.size + extrapolatedDelta).toLong
+ }
+}
+
+private object SizeTrackingAppendOnlyMap {
+ case class Sample(size: Long, numUpdates: Long)
+}
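
The size estimate above is a linear extrapolation from the last two samples; a small worked example under assumed numbers (not taken from the patch):

    // Suppose the last two samples were 1,000,000 bytes at 1,000 updates and
    // 1,200,000 bytes at 1,100 updates, and 1,150 updates have happened so far.
    val bytesPerUpdate = (1200000.0 - 1000000.0) / (1100 - 1000)      // 2000.0 bytes/update
    val estimate = (1200000 + bytesPerUpdate * (1150 - 1100)).toLong  // 1300000 bytes
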
diff --git a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala
index 8dd5786da6..3ac706110e 100644
--- a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala
+++ b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala
@@ -53,7 +53,6 @@ object LocalSparkContext {
}
// To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
System.clearProperty("spark.driver.port")
- System.clearProperty("spark.hostPort")
}
/** Runs `f` by passing in `sc` and ensures that `sc` is stopped. */
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
index afc1beff98..930c2523ca 100644
--- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -99,7 +99,6 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
val hostname = "localhost"
val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, conf = conf)
System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext
- System.setProperty("spark.hostPort", hostname + ":" + boundPort)
val masterTracker = new MapOutputTrackerMaster(conf)
masterTracker.trackerActor = actorSystem.actorOf(
diff --git a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala
index 331fa3a642..d05bbd6ff7 100644
--- a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala
@@ -25,8 +25,8 @@ import net.liftweb.json.JsonAST.JValue
import org.scalatest.FunSuite
import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, WorkerStateResponse}
-import org.apache.spark.deploy.master.{ApplicationInfo, RecoveryState, WorkerInfo}
-import org.apache.spark.deploy.worker.ExecutorRunner
+import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, RecoveryState, WorkerInfo}
+import org.apache.spark.deploy.worker.{ExecutorRunner, DriverRunner}
class JsonProtocolSuite extends FunSuite {
test("writeApplicationInfo") {
@@ -50,11 +50,13 @@ class JsonProtocolSuite extends FunSuite {
}
test("writeMasterState") {
- val workers = Array[WorkerInfo](createWorkerInfo(), createWorkerInfo())
- val activeApps = Array[ApplicationInfo](createAppInfo())
+ val workers = Array(createWorkerInfo(), createWorkerInfo())
+ val activeApps = Array(createAppInfo())
val completedApps = Array[ApplicationInfo]()
+ val activeDrivers = Array(createDriverInfo())
+ val completedDrivers = Array(createDriverInfo())
val stateResponse = new MasterStateResponse("host", 8080, workers, activeApps, completedApps,
- RecoveryState.ALIVE)
+ activeDrivers, completedDrivers, RecoveryState.ALIVE)
val output = JsonProtocol.writeMasterState(stateResponse)
assertValidJson(output)
}
@@ -62,26 +64,44 @@ class JsonProtocolSuite extends FunSuite {
test("writeWorkerState") {
val executors = List[ExecutorRunner]()
val finishedExecutors = List[ExecutorRunner](createExecutorRunner(), createExecutorRunner())
+ val drivers = List(createDriverRunner())
+ val finishedDrivers = List(createDriverRunner(), createDriverRunner())
val stateResponse = new WorkerStateResponse("host", 8080, "workerId", executors,
- finishedExecutors, "masterUrl", 4, 1234, 4, 1234, "masterWebUiUrl")
+ finishedExecutors, drivers, finishedDrivers, "masterUrl", 4, 1234, 4, 1234, "masterWebUiUrl")
val output = JsonProtocol.writeWorkerState(stateResponse)
assertValidJson(output)
}
- def createAppDesc() : ApplicationDescription = {
+ def createAppDesc(): ApplicationDescription = {
val cmd = new Command("mainClass", List("arg1", "arg2"), Map())
new ApplicationDescription("name", Some(4), 1234, cmd, "sparkHome", "appUiUrl")
}
+
def createAppInfo() : ApplicationInfo = {
new ApplicationInfo(
3, "id", createAppDesc(), new Date(123456789), null, "appUriStr", Int.MaxValue)
}
- def createWorkerInfo() : WorkerInfo = {
+
+ def createDriverCommand() = new Command(
+ "org.apache.spark.FakeClass", Seq("some arg --and-some options -g foo"),
+ Map(("K1", "V1"), ("K2", "V2"))
+ )
+
+ def createDriverDesc() = new DriverDescription("hdfs://some-dir/some.jar", 100, 3,
+ false, createDriverCommand())
+
+ def createDriverInfo(): DriverInfo = new DriverInfo(3, "driver-3", createDriverDesc(), new Date())
+
+ def createWorkerInfo(): WorkerInfo = {
new WorkerInfo("id", "host", 8080, 4, 1234, null, 80, "publicAddress")
}
- def createExecutorRunner() : ExecutorRunner = {
+ def createExecutorRunner(): ExecutorRunner = {
new ExecutorRunner("appId", 123, createAppDesc(), 4, 1234, null, "workerId", "host",
- new File("sparkHome"), new File("workDir"), ExecutorState.RUNNING)
+ new File("sparkHome"), new File("workDir"), "akka://worker", ExecutorState.RUNNING)
+ }
+ def createDriverRunner(): DriverRunner = {
+ new DriverRunner("driverId", new File("workDir"), new File("sparkHome"), createDriverDesc(),
+ null, "akka://worker")
}
def assertValidJson(json: JValue) {
diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala
new file mode 100644
index 0000000000..45dbcaffae
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala
@@ -0,0 +1,131 @@
+package org.apache.spark.deploy.worker
+
+import java.io.File
+
+import scala.collection.JavaConversions._
+
+import org.mockito.Mockito._
+import org.mockito.Matchers._
+import org.scalatest.FunSuite
+
+import org.apache.spark.deploy.{Command, DriverDescription}
+import org.mockito.stubbing.Answer
+import org.mockito.invocation.InvocationOnMock
+
+class DriverRunnerTest extends FunSuite {
+ private def createDriverRunner() = {
+ val command = new Command("mainClass", Seq(), Map())
+ val driverDescription = new DriverDescription("jarUrl", 512, 1, true, command)
+ new DriverRunner("driverId", new File("workDir"), new File("sparkHome"), driverDescription,
+ null, "akka://1.2.3.4/worker/")
+ }
+
+ private def createProcessBuilderAndProcess(): (ProcessBuilderLike, Process) = {
+ val processBuilder = mock(classOf[ProcessBuilderLike])
+ when(processBuilder.command).thenReturn(Seq("mocked", "command"))
+ val process = mock(classOf[Process])
+ when(processBuilder.start()).thenReturn(process)
+ (processBuilder, process)
+ }
+
+ test("Process succeeds instantly") {
+ val runner = createDriverRunner()
+
+ val sleeper = mock(classOf[Sleeper])
+ runner.setSleeper(sleeper)
+
+ val (processBuilder, process) = createProcessBuilderAndProcess()
+ // One failure then a successful run
+ when(process.waitFor()).thenReturn(0)
+ runner.runCommandWithRetry(processBuilder, p => (), supervise = true)
+
+ verify(process, times(1)).waitFor()
+ verify(sleeper, times(0)).sleep(anyInt())
+ }
+
+ test("Process failing several times and then succeeding") {
+ val runner = createDriverRunner()
+
+ val sleeper = mock(classOf[Sleeper])
+ runner.setSleeper(sleeper)
+
+ val (processBuilder, process) = createProcessBuilderAndProcess()
+ // fail, fail, fail, success
+ when(process.waitFor()).thenReturn(-1).thenReturn(-1).thenReturn(-1).thenReturn(0)
+ runner.runCommandWithRetry(processBuilder, p => (), supervise = true)
+
+ verify(process, times(4)).waitFor()
+ verify(sleeper, times(3)).sleep(anyInt())
+ verify(sleeper, times(1)).sleep(1)
+ verify(sleeper, times(1)).sleep(2)
+ verify(sleeper, times(1)).sleep(4)
+ }
+
+ test("Process doesn't restart if not supervised") {
+ val runner = createDriverRunner()
+
+ val sleeper = mock(classOf[Sleeper])
+ runner.setSleeper(sleeper)
+
+ val (processBuilder, process) = createProcessBuilderAndProcess()
+ when(process.waitFor()).thenReturn(-1)
+
+ runner.runCommandWithRetry(processBuilder, p => (), supervise = false)
+
+ verify(process, times(1)).waitFor()
+ verify(sleeper, times(0)).sleep(anyInt())
+ }
+
+ test("Process doesn't restart if killed") {
+ val runner = createDriverRunner()
+
+ val sleeper = mock(classOf[Sleeper])
+ runner.setSleeper(sleeper)
+
+ val (processBuilder, process) = createProcessBuilderAndProcess()
+ when(process.waitFor()).thenAnswer(new Answer[Int] {
+ def answer(invocation: InvocationOnMock): Int = {
+ runner.kill()
+ -1
+ }
+ })
+
+ runner.runCommandWithRetry(processBuilder, p => (), supervise = true)
+
+ verify(process, times(1)).waitFor()
+ verify(sleeper, times(0)).sleep(anyInt())
+ }
+
+ test("Reset of backoff counter") {
+ val runner = createDriverRunner()
+
+ val sleeper = mock(classOf[Sleeper])
+ runner.setSleeper(sleeper)
+
+ val clock = mock(classOf[Clock])
+ runner.setClock(clock)
+
+ val (processBuilder, process) = createProcessBuilderAndProcess()
+
+ when(process.waitFor())
+ .thenReturn(-1) // fail 1
+ .thenReturn(-1) // fail 2
+ .thenReturn(-1) // fail 3
+ .thenReturn(-1) // fail 4
+ .thenReturn(0) // success
+ when(clock.currentTimeMillis())
+ .thenReturn(0).thenReturn(1000) // fail 1 (short)
+ .thenReturn(1000).thenReturn(2000) // fail 2 (short)
+ .thenReturn(2000).thenReturn(10000) // fail 3 (long)
+ .thenReturn(10000).thenReturn(11000) // fail 4 (short)
+ .thenReturn(11000).thenReturn(21000) // success (long)
+
+ runner.runCommandWithRetry(processBuilder, p => (), supervise = true)
+
+ verify(sleeper, times(4)).sleep(anyInt())
+ // Expected sequence of sleeps is 1,2,1,2
+ verify(sleeper, times(2)).sleep(1)
+ verify(sleeper, times(2)).sleep(2)
+ }
+
+}
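
These tests encode an exponential backoff (sleeps of 1, 2, 4, ... seconds between supervised retries) that resets once a run survives for a while. A heavily simplified sketch of that policy, assuming a reset window of a few seconds (the actual constant lives in DriverRunner and is not shown in this diff) and ignoring the kill() path:

    // Hypothetical retry loop consistent with the expectations in the tests above.
    def runWithRetry(run: () => Int, sleep: Int => Unit, now: () => Long,
                     supervise: Boolean, resetWindowMs: Long = 5000L) {
      var waitSeconds = 1
      var keepTrying = true
      while (keepTrying) {
        val start = now()
        val exitCode = run()
        keepTrying = supervise && exitCode != 0
        if (keepTrying) {
          // A run that lasted longer than the window is treated as healthy: reset backoff.
          if (now() - start > resetWindowMs) waitSeconds = 1
          sleep(waitSeconds)
          waitSeconds *= 2
        }
      }
    }
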
diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala
index be93074b7b..a79ee690d3 100644
--- a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala
@@ -31,8 +31,8 @@ class ExecutorRunnerTest extends FunSuite {
sparkHome, "appUiUrl")
val appId = "12345-worker321-9876"
val er = new ExecutorRunner(appId, 1, appDesc, 8, 500, null, "blah", "worker321", f(sparkHome),
- f("ooga"), ExecutorState.RUNNING)
+ f("ooga"), "blah", ExecutorState.RUNNING)
- assert(er.buildCommandSeq().last === appId)
+ assert(er.getCommandSeq.last === appId)
}
}
diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala
new file mode 100644
index 0000000000..94d88d307a
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala
@@ -0,0 +1,32 @@
+package org.apache.spark.deploy.worker
+
+
+import akka.testkit.TestActorRef
+import org.scalatest.FunSuite
+import akka.remote.DisassociatedEvent
+import akka.actor.{ActorSystem, AddressFromURIString, Props}
+
+class WorkerWatcherSuite extends FunSuite {
+ test("WorkerWatcher shuts down on valid disassociation") {
+ val actorSystem = ActorSystem("test")
+ val targetWorkerUrl = "akka://1.2.3.4/user/Worker"
+ val targetWorkerAddress = AddressFromURIString(targetWorkerUrl)
+ val actorRef = TestActorRef[WorkerWatcher](Props(classOf[WorkerWatcher], targetWorkerUrl))(actorSystem)
+ val workerWatcher = actorRef.underlyingActor
+ workerWatcher.setTesting(testing = true)
+ actorRef.underlyingActor.receive(new DisassociatedEvent(null, targetWorkerAddress, false))
+ assert(actorRef.underlyingActor.isShutDown)
+ }
+
+ test("WorkerWatcher stays alive on invalid disassociation") {
+ val actorSystem = ActorSystem("test")
+ val targetWorkerUrl = "akka://1.2.3.4/user/Worker"
+ val otherAkkaURL = "akka://4.3.2.1/user/OtherActor"
+ val otherAkkaAddress = AddressFromURIString(otherAkkaURL)
+ val actorRef = TestActorRef[WorkerWatcher](Props(classOf[WorkerWatcher], targetWorkerUrl))(actorSystem)
+ val workerWatcher = actorRef.underlyingActor
+ workerWatcher.setTesting(testing = true)
+ actorRef.underlyingActor.receive(new DisassociatedEvent(null, otherAkkaAddress, false))
+ assert(!actorRef.underlyingActor.isShutDown)
+ }
+} \ No newline at end of file
diff --git a/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala
index 7bf2020fe3..235d31709a 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala
@@ -64,7 +64,7 @@ class FakeTaskSetManager(
}
override def getSchedulableByName(name: String): Schedulable = {
- return null
+ null
}
override def executorLost(executorId: String, host: String): Unit = {
@@ -79,13 +79,14 @@ class FakeTaskSetManager(
{
if (tasksSuccessful + runningTasks < numTasks) {
increaseRunningTasks(1)
- return Some(new TaskDescription(0, execId, "task 0:0", 0, null))
+ Some(new TaskDescription(0, execId, "task 0:0", 0, null))
+ } else {
+ None
}
- return None
}
override def checkSpeculatableTasks(): Boolean = {
- return true
+ true
}
def taskFinished() {
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 2aa259daf3..f0236ef1e9 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -122,7 +122,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
locations: Seq[Seq[String]] = Nil
): MyRDD = {
val maxPartition = numPartitions - 1
- return new MyRDD(sc, dependencies) {
+ val newRDD = new MyRDD(sc, dependencies) {
override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] =
throw new RuntimeException("should not be reached")
override def getPartitions = (0 to maxPartition).map(i => new Partition {
@@ -135,6 +135,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
Nil
override def toString: String = "DAGSchedulerSuiteRDD " + id
}
+ newRDD
}
/**
diff --git a/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala
index 5cc48ee00a..29102913c7 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala
@@ -42,12 +42,9 @@ class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers
def buildJobDepTest(jobID: Int, stage: Stage) = buildJobDep(jobID, stage)
}
type MyRDD = RDD[(Int, Int)]
- def makeRdd(
- numPartitions: Int,
- dependencies: List[Dependency[_]]
- ): MyRDD = {
+ def makeRdd(numPartitions: Int, dependencies: List[Dependency[_]]): MyRDD = {
val maxPartition = numPartitions - 1
- return new MyRDD(sc, dependencies) {
+ new MyRDD(sc, dependencies) {
override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] =
throw new RuntimeException("should not be reached")
override def getPartitions = (0 to maxPartition).map(i => new Partition {
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index 1eec6726f4..c9f6cc5d07 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -83,7 +83,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
private val conf = new SparkConf
- val LOCALITY_WAIT = conf.get("spark.locality.wait", "3000").toLong
+ val LOCALITY_WAIT = conf.getLong("spark.locality.wait", 3000)
val MAX_TASK_FAILURES = 4
test("TaskSet with no preferences") {
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index f60ce270c7..18aa587662 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -53,7 +53,6 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
val (actorSystem, boundPort) = AkkaUtils.createActorSystem("test", "localhost", 0, conf = conf)
this.actorSystem = actorSystem
conf.set("spark.driver.port", boundPort.toString)
- conf.set("spark.hostPort", "localhost:" + boundPort)
master = new BlockManagerMaster(
actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf))), conf)
@@ -65,13 +64,10 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
conf.set("spark.storage.disableBlockManagerHeartBeat", "true")
val initialize = PrivateMethod[Unit]('initialize)
SizeEstimator invokePrivate initialize()
- // Set some value ...
- conf.set("spark.hostPort", Utils.localHostName() + ":" + 1111)
}
after {
System.clearProperty("spark.driver.port")
- System.clearProperty("spark.hostPort")
if (store != null) {
store.stop()
diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
index 0ed366fb70..de4871d043 100644
--- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
@@ -61,8 +61,8 @@ class NonSerializable {}
object TestObject {
def run(): Int = {
var nonSer = new NonSerializable
- var x = 5
- return withSpark(new SparkContext("local", "test")) { sc =>
+ val x = 5
+ withSpark(new SparkContext("local", "test")) { sc =>
val nums = sc.parallelize(Array(1, 2, 3, 4))
nums.map(_ + x).reduce(_ + _)
}
@@ -76,7 +76,7 @@ class TestClass extends Serializable {
def run(): Int = {
var nonSer = new NonSerializable
- return withSpark(new SparkContext("local", "test")) { sc =>
+ withSpark(new SparkContext("local", "test")) { sc =>
val nums = sc.parallelize(Array(1, 2, 3, 4))
nums.map(_ + getX).reduce(_ + _)
}
@@ -88,7 +88,7 @@ class TestClassWithoutDefaultConstructor(x: Int) extends Serializable {
def run(): Int = {
var nonSer = new NonSerializable
- return withSpark(new SparkContext("local", "test")) { sc =>
+ withSpark(new SparkContext("local", "test")) { sc =>
val nums = sc.parallelize(Array(1, 2, 3, 4))
nums.map(_ + getX).reduce(_ + _)
}
@@ -103,7 +103,7 @@ class TestClassWithoutFieldAccess {
def run(): Int = {
var nonSer2 = new NonSerializable
var x = 5
- return withSpark(new SparkContext("local", "test")) { sc =>
+ withSpark(new SparkContext("local", "test")) { sc =>
val nums = sc.parallelize(Array(1, 2, 3, 4))
nums.map(_ + x).reduce(_ + _)
}
@@ -115,7 +115,7 @@ object TestObjectWithNesting {
def run(): Int = {
var nonSer = new NonSerializable
var answer = 0
- return withSpark(new SparkContext("local", "test")) { sc =>
+ withSpark(new SparkContext("local", "test")) { sc =>
val nums = sc.parallelize(Array(1, 2, 3, 4))
var y = 1
for (i <- 1 to 4) {
@@ -134,7 +134,7 @@ class TestClassWithNesting(val y: Int) extends Serializable {
def run(): Int = {
var nonSer = new NonSerializable
var answer = 0
- return withSpark(new SparkContext("local", "test")) { sc =>
+ withSpark(new SparkContext("local", "test")) { sc =>
val nums = sc.parallelize(Array(1, 2, 3, 4))
for (i <- 1 to 4) {
var nonSer2 = new NonSerializable
diff --git a/core/src/test/scala/org/apache/spark/util/SizeTrackingAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/SizeTrackingAppendOnlyMapSuite.scala
new file mode 100644
index 0000000000..93f0c6a8e6
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/SizeTrackingAppendOnlyMapSuite.scala
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import scala.util.Random
+
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+
+import org.apache.spark.util.SizeTrackingAppendOnlyMapSuite.LargeDummyClass
+import org.apache.spark.util.collection.{AppendOnlyMap, SizeTrackingAppendOnlyMap}
+
+class SizeTrackingAppendOnlyMapSuite extends FunSuite with BeforeAndAfterAll {
+ val NORMAL_ERROR = 0.20
+ val HIGH_ERROR = 0.30
+
+ test("fixed size insertions") {
+ testWith[Int, Long](10000, i => (i, i.toLong))
+ testWith[Int, (Long, Long)](10000, i => (i, (i.toLong, i.toLong)))
+ testWith[Int, LargeDummyClass](10000, i => (i, new LargeDummyClass()))
+ }
+
+ test("variable size insertions") {
+ val rand = new Random(123456789)
+ def randString(minLen: Int, maxLen: Int): String = {
+ "a" * (rand.nextInt(maxLen - minLen) + minLen)
+ }
+ testWith[Int, String](10000, i => (i, randString(0, 10)))
+ testWith[Int, String](10000, i => (i, randString(0, 100)))
+ testWith[Int, String](10000, i => (i, randString(90, 100)))
+ }
+
+ test("updates") {
+ val rand = new Random(123456789)
+ def randString(minLen: Int, maxLen: Int): String = {
+ "a" * (rand.nextInt(maxLen - minLen) + minLen)
+ }
+ testWith[String, Int](10000, i => (randString(0, 10000), i))
+ }
+
+ def testWith[K, V](numElements: Int, makeElement: (Int) => (K, V)) {
+ val map = new SizeTrackingAppendOnlyMap[K, V]()
+ for (i <- 0 until numElements) {
+ val (k, v) = makeElement(i)
+ map(k) = v
+ expectWithinError(map, map.estimateSize(), if (i < 32) HIGH_ERROR else NORMAL_ERROR)
+ }
+ }
+
+ def expectWithinError(obj: AnyRef, estimatedSize: Long, error: Double) {
+ val betterEstimatedSize = SizeEstimator.estimate(obj)
+ assert(betterEstimatedSize * (1 - error) < estimatedSize,
+ s"Estimated size $estimatedSize was less than expected size $betterEstimatedSize")
+ assert(betterEstimatedSize * (1 + 2 * error) > estimatedSize,
+ s"Estimated size $estimatedSize was greater than expected size $betterEstimatedSize")
+ }
+}
+
+object SizeTrackingAppendOnlyMapSuite {
+ // Speed test; sample results are recorded below for reference.
+ // These timings can be highly non-deterministic in general, however.
+ // Results:
+ // AppendOnlyMap: 31 ms
+ // SizeTracker: 54 ms
+ // SizeEstimator: 1500 ms
+ def main(args: Array[String]) {
+ val numElements = 100000
+
+ val baseTimes = for (i <- 0 until 10) yield time {
+ val map = new AppendOnlyMap[Int, LargeDummyClass]()
+ for (i <- 0 until numElements) {
+ map(i) = new LargeDummyClass()
+ }
+ }
+
+ val sampledTimes = for (i <- 0 until 10) yield time {
+ val map = new SizeTrackingAppendOnlyMap[Int, LargeDummyClass]()
+ for (i <- 0 until numElements) {
+ map(i) = new LargeDummyClass()
+ map.estimateSize()
+ }
+ }
+
+ val unsampledTimes = for (i <- 0 until 3) yield time {
+ val map = new AppendOnlyMap[Int, LargeDummyClass]()
+ for (i <- 0 until numElements) {
+ map(i) = new LargeDummyClass()
+ SizeEstimator.estimate(map)
+ }
+ }
+
+ println("Base: " + baseTimes)
+ println("SizeTracker (sampled): " + sampledTimes)
+ println("SizeEstimator (unsampled): " + unsampledTimes)
+ }
+
+ def time(f: => Unit): Long = {
+ val start = System.currentTimeMillis()
+ f
+ System.currentTimeMillis() - start
+ }
+
+ private class LargeDummyClass {
+ val arr = new Array[Int](100)
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/util/VectorSuite.scala b/core/src/test/scala/org/apache/spark/util/VectorSuite.scala
new file mode 100644
index 0000000000..7006571ef0
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/VectorSuite.scala
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import scala.util.Random
+
+import org.scalatest.FunSuite
+
+/**
+ * Tests org.apache.spark.util.Vector functionality
+ */
+class VectorSuite extends FunSuite {
+
+ def verifyVector(vector: Vector, expectedLength: Int) = {
+ assert(vector.length == expectedLength)
+ assert(vector.elements.min > 0.0)
+ assert(vector.elements.max < 1.0)
+ }
+
+ test("random with default random number generator") {
+ val vector100 = Vector.random(100)
+ verifyVector(vector100, 100)
+ }
+
+ test("random with given random number generator") {
+ val vector100 = Vector.random(100, new Random(100))
+ verifyVector(vector100, 100)
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala b/core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala
index b78367b6ca..f1d7b61b31 100644
--- a/core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala
@@ -73,4 +73,4 @@ class XORShiftRandomSuite extends FunSuite with ShouldMatchers {
}
-} \ No newline at end of file
+}
diff --git a/core/src/test/scala/org/apache/spark/util/AppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/AppendOnlyMapSuite.scala
index 7177919a58..f44442f1a5 100644
--- a/core/src/test/scala/org/apache/spark/util/AppendOnlyMapSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/AppendOnlyMapSuite.scala
@@ -15,11 +15,12 @@
* limitations under the License.
*/
-package org.apache.spark.util
+package org.apache.spark.util.collection
import scala.collection.mutable.HashSet
import org.scalatest.FunSuite
+import java.util.Comparator
class AppendOnlyMapSuite extends FunSuite {
test("initialization") {
@@ -151,4 +152,47 @@ class AppendOnlyMapSuite extends FunSuite {
assert(map("" + i) === "" + i)
}
}
+
+ test("destructive sort") {
+ val map = new AppendOnlyMap[String, String]()
+ for (i <- 1 to 100) {
+ map("" + i) = "" + i
+ }
+ map.update(null, "happy new year!")
+
+ try {
+ map.apply("1")
+ map.update("1", "2013")
+ map.changeValue("1", (hadValue, oldValue) => "2014")
+ map.iterator
+ } catch {
+ case e: IllegalStateException => fail()
+ }
+
+ val it = map.destructiveSortedIterator(new Comparator[(String, String)] {
+ def compare(kv1: (String, String), kv2: (String, String)): Int = {
+ val x = if (kv1 != null && kv1._1 != null) kv1._1.toInt else Int.MinValue
+ val y = if (kv2 != null && kv2._1 != null) kv2._1.toInt else Int.MinValue
+ x.compareTo(y)
+ }
+ })
+
+ // Should be sorted by key
+ assert(it.hasNext)
+ var previous = it.next()
+ assert(previous == (null, "happy new year!"))
+ previous = it.next()
+ assert(previous == ("1", "2014"))
+ while (it.hasNext) {
+ val kv = it.next()
+ assert(kv._1.toInt > previous._1.toInt)
+ previous = kv
+ }
+
+ // All subsequent calls to apply, update, changeValue and iterator should throw an exception
+ intercept[AssertionError] { map.apply("1") }
+ intercept[AssertionError] { map.update("1", "2013") }
+ intercept[AssertionError] { map.changeValue("1", (hadValue, oldValue) => "2014") }
+ intercept[AssertionError] { map.iterator }
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
new file mode 100644
index 0000000000..ef957bb0e5
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
@@ -0,0 +1,230 @@
+package org.apache.spark.util.collection
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.scalatest.{BeforeAndAfter, FunSuite}
+
+import org.apache.spark._
+import org.apache.spark.SparkContext._
+
+class ExternalAppendOnlyMapSuite extends FunSuite with BeforeAndAfter with LocalSparkContext {
+
+ override def beforeEach() {
+ val conf = new SparkConf(false)
+ conf.set("spark.shuffle.externalSorting", "true")
+ sc = new SparkContext("local", "test", conf)
+ }
+
+ val createCombiner: (Int => ArrayBuffer[Int]) = i => ArrayBuffer[Int](i)
+ val mergeValue: (ArrayBuffer[Int], Int) => ArrayBuffer[Int] = (buffer, i) => {
+ buffer += i
+ }
+ val mergeCombiners: (ArrayBuffer[Int], ArrayBuffer[Int]) => ArrayBuffer[Int] =
+ (buf1, buf2) => {
+ buf1 ++= buf2
+ }
+
+ test("simple insert") {
+ val map = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner,
+ mergeValue, mergeCombiners)
+
+ // Single insert
+ map.insert(1, 10)
+ var it = map.iterator
+ assert(it.hasNext)
+ val kv = it.next()
+ assert(kv._1 == 1 && kv._2 == ArrayBuffer[Int](10))
+ assert(!it.hasNext)
+
+ // Multiple insert
+ map.insert(2, 20)
+ map.insert(3, 30)
+ it = map.iterator
+ assert(it.hasNext)
+ assert(it.toSet == Set[(Int, ArrayBuffer[Int])](
+ (1, ArrayBuffer[Int](10)),
+ (2, ArrayBuffer[Int](20)),
+ (3, ArrayBuffer[Int](30))))
+ }
+
+ test("insert with collision") {
+ val map = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner,
+ mergeValue, mergeCombiners)
+
+ map.insert(1, 10)
+ map.insert(2, 20)
+ map.insert(3, 30)
+ map.insert(1, 100)
+ map.insert(2, 200)
+ map.insert(1, 1000)
+ val it = map.iterator
+ assert(it.hasNext)
+ val result = it.toSet[(Int, ArrayBuffer[Int])].map(kv => (kv._1, kv._2.toSet))
+ assert(result == Set[(Int, Set[Int])](
+ (1, Set[Int](10, 100, 1000)),
+ (2, Set[Int](20, 200)),
+ (3, Set[Int](30))))
+ }
+
+ test("ordering") {
+ val map1 = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner,
+ mergeValue, mergeCombiners)
+ map1.insert(1, 10)
+ map1.insert(2, 20)
+ map1.insert(3, 30)
+
+ val map2 = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner,
+ mergeValue, mergeCombiners)
+ map2.insert(2, 20)
+ map2.insert(3, 30)
+ map2.insert(1, 10)
+
+ val map3 = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner,
+ mergeValue, mergeCombiners)
+ map3.insert(3, 30)
+ map3.insert(1, 10)
+ map3.insert(2, 20)
+
+ val it1 = map1.iterator
+ val it2 = map2.iterator
+ val it3 = map3.iterator
+
+ var kv1 = it1.next()
+ var kv2 = it2.next()
+ var kv3 = it3.next()
+ assert(kv1._1 == kv2._1 && kv2._1 == kv3._1)
+ assert(kv1._2 == kv2._2 && kv2._2 == kv3._2)
+
+ kv1 = it1.next()
+ kv2 = it2.next()
+ kv3 = it3.next()
+ assert(kv1._1 == kv2._1 && kv2._1 == kv3._1)
+ assert(kv1._2 == kv2._2 && kv2._2 == kv3._2)
+
+ kv1 = it1.next()
+ kv2 = it2.next()
+ kv3 = it3.next()
+ assert(kv1._1 == kv2._1 && kv2._1 == kv3._1)
+ assert(kv1._2 == kv2._2 && kv2._2 == kv3._2)
+ }
+
+ test("null keys and values") {
+ val map = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner,
+ mergeValue, mergeCombiners)
+ map.insert(1, 5)
+ map.insert(2, 6)
+ map.insert(3, 7)
+ assert(map.size === 3)
+ assert(map.iterator.toSet == Set[(Int, Seq[Int])](
+ (1, Seq[Int](5)),
+ (2, Seq[Int](6)),
+ (3, Seq[Int](7))
+ ))
+
+    // "Null" keys: null.asInstanceOf[Int] unboxes to 0, so this exercises the zero key
+    val nullInt = null.asInstanceOf[Int]
+ map.insert(nullInt, 8)
+ assert(map.size === 4)
+ assert(map.iterator.toSet == Set[(Int, Seq[Int])](
+ (1, Seq[Int](5)),
+ (2, Seq[Int](6)),
+ (3, Seq[Int](7)),
+ (nullInt, Seq[Int](8))
+ ))
+
+    // "Null" values (these also unbox to 0)
+ map.insert(4, nullInt)
+ map.insert(nullInt, nullInt)
+ assert(map.size === 5)
+ val result = map.iterator.toSet[(Int, ArrayBuffer[Int])].map(kv => (kv._1, kv._2.toSet))
+ assert(result == Set[(Int, Set[Int])](
+ (1, Set[Int](5)),
+ (2, Set[Int](6)),
+ (3, Set[Int](7)),
+ (4, Set[Int](nullInt)),
+ (nullInt, Set[Int](nullInt, 8))
+ ))
+ }
+
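+  // The remaining tests run real shuffles; with spark.shuffle.externalSorting enabled
+  // (see beforeEach), their aggregation is expected to be backed by ExternalAppendOnlyMap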
+ test("simple aggregator") {
+ // reduceByKey
+ val rdd = sc.parallelize(1 to 10).map(i => (i%2, 1))
+ val result1 = rdd.reduceByKey(_+_).collect()
+ assert(result1.toSet == Set[(Int, Int)]((0, 5), (1, 5)))
+
+ // groupByKey
+ val result2 = rdd.groupByKey().collect()
+    assert(result2.toSet == Set[(Int, Seq[Int])](
+      (0, ArrayBuffer[Int](1, 1, 1, 1, 1)),
+      (1, ArrayBuffer[Int](1, 1, 1, 1, 1))))
+ }
+
+ test("simple cogroup") {
+ val rdd1 = sc.parallelize(1 to 4).map(i => (i, i))
+ val rdd2 = sc.parallelize(1 to 4).map(i => (i%2, i))
+ val result = rdd1.cogroup(rdd2).collect()
+
+ result.foreach { case (i, (seq1, seq2)) =>
+ i match {
+ case 0 => assert(seq1.toSet == Set[Int]() && seq2.toSet == Set[Int](2, 4))
+ case 1 => assert(seq1.toSet == Set[Int](1) && seq2.toSet == Set[Int](1, 3))
+ case 2 => assert(seq1.toSet == Set[Int](2) && seq2.toSet == Set[Int]())
+ case 3 => assert(seq1.toSet == Set[Int](3) && seq2.toSet == Set[Int]())
+ case 4 => assert(seq1.toSet == Set[Int](4) && seq2.toSet == Set[Int]())
+ }
+ }
+ }
+
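+  // Larger datasets intended to force spilling to disk. Note the TODO below: the buffer
+  // settings that would actually induce spilling are still commented out, and the
+  // "should spill exactly N times" notes presumably assume such settings.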
+ test("spilling") {
+ // TODO: Figure out correct memory parameters to actually induce spilling
+ // System.setProperty("spark.shuffle.buffer.mb", "1")
+ // System.setProperty("spark.shuffle.buffer.fraction", "0.05")
+
+ // reduceByKey - should spill exactly 6 times
+ val rddA = sc.parallelize(0 until 10000).map(i => (i/2, i))
+ val resultA = rddA.reduceByKey(math.max(_, _)).collect()
+ assert(resultA.length == 5000)
+    resultA.foreach { case (k, v) =>
+ k match {
+ case 0 => assert(v == 1)
+ case 2500 => assert(v == 5001)
+ case 4999 => assert(v == 9999)
+ case _ =>
+ }
+ }
+
+ // groupByKey - should spill exactly 11 times
+ val rddB = sc.parallelize(0 until 10000).map(i => (i/4, i))
+ val resultB = rddB.groupByKey().collect()
+ assert(resultB.length == 2500)
+    resultB.foreach { case (i, seq) =>
+ i match {
+ case 0 => assert(seq.toSet == Set[Int](0, 1, 2, 3))
+ case 1250 => assert(seq.toSet == Set[Int](5000, 5001, 5002, 5003))
+ case 2499 => assert(seq.toSet == Set[Int](9996, 9997, 9998, 9999))
+ case _ =>
+ }
+ }
+
+ // cogroup - should spill exactly 7 times
+ val rddC1 = sc.parallelize(0 until 1000).map(i => (i, i))
+ val rddC2 = sc.parallelize(0 until 1000).map(i => (i%100, i))
+ val resultC = rddC1.cogroup(rddC2).collect()
+ assert(resultC.length == 1000)
+    resultC.foreach { case (i, (seq1, seq2)) =>
+ i match {
+ case 0 =>
+ assert(seq1.toSet == Set[Int](0))
+ assert(seq2.toSet == Set[Int](0, 100, 200, 300, 400, 500, 600, 700, 800, 900))
+ case 500 =>
+ assert(seq1.toSet == Set[Int](500))
+ assert(seq2.toSet == Set[Int]())
+ case 999 =>
+ assert(seq1.toSet == Set[Int](999))
+ assert(seq2.toSet == Set[Int]())
+ case _ =>
+ }
+ }
+ }
+
+ // TODO: Test memory allocation for multiple concurrently running tasks
+}