aboutsummaryrefslogtreecommitdiff
path: root/core/src/main/scala
diff options
context:
space:
mode:
Diffstat (limited to 'core/src/main/scala')
-rw-r--r--core/src/main/scala/org/apache/spark/Accumulators.scala8
-rw-r--r--core/src/main/scala/org/apache/spark/FutureAction.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/HttpServer.scala1
-rw-r--r--core/src/main/scala/org/apache/spark/Logging.scala43
-rw-r--r--core/src/main/scala/org/apache/spark/MapOutputTracker.scala37
-rw-r--r--core/src/main/scala/org/apache/spark/Partitioner.scala12
-rw-r--r--core/src/main/scala/org/apache/spark/SparkConf.scala198
-rw-r--r--core/src/main/scala/org/apache/spark/SparkContext.scala598
-rw-r--r--core/src/main/scala/org/apache/spark/SparkEnv.scala71
-rw-r--r--core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/TaskEndReason.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/TaskState.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala49
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala113
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala9
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala54
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala116
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/JavaSparkContextVarargsWorkaround.java1
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/function/Function.java8
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/function/Function2.java8
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/function/Function3.java8
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.java12
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/function/PairFunction.java12
-rw-r--r--core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala165
-rw-r--r--core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala8
-rw-r--r--core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala49
-rw-r--r--core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala45
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala20
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala14
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/client/Client.scala61
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala10
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala7
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/master/Master.scala122
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala11
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/master/RecoveryState.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/master/SparkZooKeeperSession.scala16
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/master/WorkerState.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala11
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala14
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala13
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala16
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala10
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala60
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala9
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala23
-rw-r--r--core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala27
-rw-r--r--core/src/main/scala/org/apache/spark/executor/Executor.scala61
-rw-r--r--core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala28
-rw-r--r--core/src/main/scala/org/apache/spark/io/CompressionCodec.scala19
-rw-r--r--core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala1
-rw-r--r--core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala11
-rw-r--r--core/src/main/scala/org/apache/spark/network/ConnectionManager.scala30
-rw-r--r--core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/network/ReceiverTest.scala12
-rw-r--r--core/src/main/scala/org/apache/spark/network/SenderTest.scala16
-rw-r--r--core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala10
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala41
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala127
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/EmptyRDD.scala5
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/FilteredRDD.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/FlatMappedRDD.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/GlommedRDD.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala12
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/MappedRDD.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala10
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala82
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala8
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala110
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/RDD.scala130
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala19
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala5
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala11
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala8
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala7
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala7
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala9
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/ZippedRDD.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala427
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala5
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala (renamed from core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorLossReason.scala)2
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala14
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala1
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/Pool.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala1
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala (renamed from core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala)7
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/SchedulingMode.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala29
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala1
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/TaskLocality.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala20
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala (renamed from core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala)16
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala11
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala (renamed from core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala)76
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala691
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala (renamed from core/src/main/scala/org/apache/spark/scheduler/cluster/WorkerOffer.scala)2
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala712
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala51
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala10
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala16
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala25
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala19
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala109
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala219
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala191
-rw-r--r--core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala14
-rw-r--r--core/src/main/scala/org/apache/spark/serializer/Serializer.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala23
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockManager.scala65
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala20
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala26
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala5
-rw-r--r--core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala35
-rw-r--r--core/src/main/scala/org/apache/spark/storage/StorageLevel.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala19
-rw-r--r--core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala8
-rw-r--r--core/src/main/scala/org/apache/spark/ui/SparkUI.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala21
-rw-r--r--core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala15
-rw-r--r--core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala31
-rw-r--r--core/src/main/scala/org/apache/spark/ui/jobs/ExecutorSummary.scala (renamed from core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala)30
-rw-r--r--core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala90
-rw-r--r--core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala37
-rw-r--r--core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala80
-rw-r--r--core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala5
-rw-r--r--core/src/main/scala/org/apache/spark/util/AkkaUtils.scala111
-rw-r--r--core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/util/IndestructibleActorSystem.scala68
-rw-r--r--core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala38
-rw-r--r--core/src/main/scala/org/apache/spark/util/SerializableHyperLogLog.scala50
-rw-r--r--core/src/main/scala/org/apache/spark/util/SizeEstimator.scala14
-rw-r--r--core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala17
-rw-r--r--core/src/main/scala/org/apache/spark/util/Utils.scala30
-rw-r--r--core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala117
-rw-r--r--core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala10
-rw-r--r--core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala4
167 files changed, 3865 insertions, 2797 deletions
diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala
index 6e922a612a..5f73d234aa 100644
--- a/core/src/main/scala/org/apache/spark/Accumulators.scala
+++ b/core/src/main/scala/org/apache/spark/Accumulators.scala
@@ -41,7 +41,7 @@ class Accumulable[R, T] (
@transient initialValue: R,
param: AccumulableParam[R, T])
extends Serializable {
-
+
val id = Accumulators.newId
@transient private var value_ = initialValue // Current value on master
val zero = param.zero(initialValue) // Zero value to be passed to workers
@@ -113,7 +113,7 @@ class Accumulable[R, T] (
def setValue(newValue: R) {
this.value = newValue
}
-
+
// Called by Java when deserializing an object
private def readObject(in: ObjectInputStream) {
in.defaultReadObject()
@@ -177,7 +177,7 @@ class GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Ser
def zero(initialValue: R): R = {
// We need to clone initialValue, but it's hard to specify that R should also be Cloneable.
// Instead we'll serialize it to a buffer and load it back.
- val ser = new JavaSerializer().newInstance()
+ val ser = new JavaSerializer(new SparkConf(false)).newInstance()
val copy = ser.deserialize[R](ser.serialize(initialValue))
copy.clear() // In case it contained stuff
copy
@@ -215,7 +215,7 @@ private object Accumulators {
val originals = Map[Long, Accumulable[_, _]]()
val localAccums = Map[Thread, Map[Long, Accumulable[_, _]]]()
var lastId: Long = 0
-
+
def newId: Long = synchronized {
lastId += 1
return lastId
diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala
index 1ad9240cfa..c6b4ac5192 100644
--- a/core/src/main/scala/org/apache/spark/FutureAction.scala
+++ b/core/src/main/scala/org/apache/spark/FutureAction.scala
@@ -99,7 +99,7 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc:
override def ready(atMost: Duration)(implicit permit: CanAwait): SimpleFutureAction.this.type = {
if (!atMost.isFinite()) {
awaitResult()
- } else {
+ } else jobWaiter.synchronized {
val finishTime = System.currentTimeMillis() + atMost.toMillis
while (!isCompleted) {
val time = System.currentTimeMillis()
diff --git a/core/src/main/scala/org/apache/spark/HttpServer.scala b/core/src/main/scala/org/apache/spark/HttpServer.scala
index cdfc9dd54e..69a738dc44 100644
--- a/core/src/main/scala/org/apache/spark/HttpServer.scala
+++ b/core/src/main/scala/org/apache/spark/HttpServer.scala
@@ -46,6 +46,7 @@ private[spark] class HttpServer(resourceBase: File) extends Logging {
if (server != null) {
throw new ServerStateException("Server is already started")
} else {
+ logInfo("Starting HTTP Server")
server = new Server()
val connector = new SocketConnector
connector.setMaxIdleTime(60*1000)
diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala
index 6a973ea495..4a34989e50 100644
--- a/core/src/main/scala/org/apache/spark/Logging.scala
+++ b/core/src/main/scala/org/apache/spark/Logging.scala
@@ -17,8 +17,8 @@
package org.apache.spark
-import org.slf4j.Logger
-import org.slf4j.LoggerFactory
+import org.apache.log4j.{LogManager, PropertyConfigurator}
+import org.slf4j.{Logger, LoggerFactory}
/**
* Utility trait for classes that want to log data. Creates a SLF4J logger for the class and allows
@@ -33,6 +33,7 @@ trait Logging {
// Method to get or create the logger for this object
protected def log: Logger = {
if (log_ == null) {
+ initializeIfNecessary()
var className = this.getClass.getName
// Ignore trailing $'s in the class names for Scala objects
if (className.endsWith("$")) {
@@ -89,7 +90,39 @@ trait Logging {
log.isTraceEnabled
}
- // Method for ensuring that logging is initialized, to avoid having multiple
- // threads do it concurrently (as SLF4J initialization is not thread safe).
- protected def initLogging() { log }
+ private def initializeIfNecessary() {
+ if (!Logging.initialized) {
+ Logging.initLock.synchronized {
+ if (!Logging.initialized) {
+ initializeLogging()
+ }
+ }
+ }
+ }
+
+ private def initializeLogging() {
+ // If Log4j doesn't seem initialized, load a default properties file
+ val log4jInitialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements
+ if (!log4jInitialized) {
+ val defaultLogProps = "org/apache/spark/log4j-defaults.properties"
+ val classLoader = this.getClass.getClassLoader
+ Option(classLoader.getResource(defaultLogProps)) match {
+ case Some(url) =>
+ PropertyConfigurator.configure(url)
+ log.info(s"Using Spark's default log4j profile: $defaultLogProps")
+ case None =>
+ System.err.println(s"Spark was unable to load $defaultLogProps")
+ }
+ }
+ Logging.initialized = true
+
+ // Force a call into slf4j to initialize it. Avoids this happening from mutliple threads
+ // and triggering this: http://mailman.qos.ch/pipermail/slf4j-dev/2010-April/002956.html
+ log
+ }
+}
+
+object Logging {
+ @volatile private var initialized = false
+ val initLock = new Object()
}
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 5e465fa22c..77b8ca1cce 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -21,17 +21,15 @@ import java.io._
import java.util.zip.{GZIPInputStream, GZIPOutputStream}
import scala.collection.mutable.HashSet
+import scala.concurrent.Await
+import scala.concurrent.duration._
import akka.actor._
-import akka.dispatch._
import akka.pattern.ask
-import akka.util.Duration
-
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.storage.BlockManagerId
-import org.apache.spark.util.{MetadataCleanerType, Utils, MetadataCleaner, TimeStampedHashMap}
-
+import org.apache.spark.util.{AkkaUtils, MetadataCleaner, MetadataCleanerType, TimeStampedHashMap, Utils}
private[spark] sealed trait MapOutputTrackerMessage
private[spark] case class GetMapOutputStatuses(shuffleId: Int, requester: String)
@@ -52,10 +50,10 @@ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster
}
}
-private[spark] class MapOutputTracker extends Logging {
+private[spark] class MapOutputTracker(conf: SparkConf) extends Logging {
+
+ private val timeout = AkkaUtils.askTimeout(conf)
- private val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
-
// Set to the MapOutputTrackerActor living on the driver
var trackerActor: ActorRef = _
@@ -67,14 +65,14 @@ private[spark] class MapOutputTracker extends Logging {
protected val epochLock = new java.lang.Object
private val metadataCleaner =
- new MetadataCleaner(MetadataCleanerType.MAP_OUTPUT_TRACKER, this.cleanup)
+ new MetadataCleaner(MetadataCleanerType.MAP_OUTPUT_TRACKER, this.cleanup, conf)
// Send a message to the trackerActor and get its result within a default timeout, or
// throw a SparkException if this fails.
private def askTracker(message: Any): Any = {
try {
val future = trackerActor.ask(message)(timeout)
- return Await.result(future, timeout)
+ Await.result(future, timeout)
} catch {
case e: Exception =>
throw new SparkException("Error communicating with MapOutputTracker", e)
@@ -117,11 +115,11 @@ private[spark] class MapOutputTracker extends Logging {
fetching += shuffleId
}
}
-
+
if (fetchedStatuses == null) {
// We won the race to fetch the output locs; do so
logInfo("Doing the fetch; tracker actor = " + trackerActor)
- val hostPort = Utils.localHostPort()
+ val hostPort = Utils.localHostPort(conf)
// This try-finally prevents hangs due to timeouts:
try {
val fetchedBytes =
@@ -144,7 +142,7 @@ private[spark] class MapOutputTracker extends Logging {
else{
throw new FetchFailedException(null, shuffleId, -1, reduceId,
new Exception("Missing all output locations for shuffle " + shuffleId))
- }
+ }
} else {
statuses.synchronized {
return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses)
@@ -184,7 +182,8 @@ private[spark] class MapOutputTracker extends Logging {
}
}
-private[spark] class MapOutputTrackerMaster extends MapOutputTracker {
+private[spark] class MapOutputTrackerMaster(conf: SparkConf)
+ extends MapOutputTracker(conf) {
// Cache a serialized version of the output statuses for each shuffle to send them out faster
private var cacheEpoch = epoch
@@ -244,12 +243,12 @@ private[spark] class MapOutputTrackerMaster extends MapOutputTracker {
case Some(bytes) =>
return bytes
case None =>
- statuses = mapStatuses(shuffleId)
+ statuses = mapStatuses.getOrElse(shuffleId, Array[MapStatus]())
epochGotten = epoch
}
}
// If we got here, we failed to find the serialized locations in the cache, so we pulled
- // out a snapshot of the locations as "locs"; let's serialize and return that
+ // out a snapshot of the locations as "statuses"; let's serialize and return that
val bytes = MapOutputTracker.serializeMapStatuses(statuses)
logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length))
// Add them into the table only if the epoch hasn't changed while we were working
@@ -274,6 +273,10 @@ private[spark] class MapOutputTrackerMaster extends MapOutputTracker {
override def updateEpoch(newEpoch: Long) {
// This might be called on the MapOutputTrackerMaster if we're running in local mode.
}
+
+ def has(shuffleId: Int): Boolean = {
+ cachedSerializedStatuses.get(shuffleId).isDefined || mapStatuses.contains(shuffleId)
+ }
}
private[spark] object MapOutputTracker {
@@ -308,7 +311,7 @@ private[spark] object MapOutputTracker {
statuses: Array[MapStatus]): Array[(BlockManagerId, Long)] = {
assert (statuses != null)
statuses.map {
- status =>
+ status =>
if (status == null) {
throw new FetchFailedException(null, shuffleId, -1, reduceId,
new Exception("Missing an output location for shuffle " + shuffleId))
diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala
index 0e2c987a59..31b0773bfe 100644
--- a/core/src/main/scala/org/apache/spark/Partitioner.scala
+++ b/core/src/main/scala/org/apache/spark/Partitioner.scala
@@ -17,8 +17,10 @@
package org.apache.spark
-import org.apache.spark.util.Utils
+import scala.reflect.ClassTag
+
import org.apache.spark.rdd.RDD
+import org.apache.spark.util.Utils
/**
* An object that defines how the elements in a key-value pair RDD are partitioned by key.
@@ -50,7 +52,7 @@ object Partitioner {
for (r <- bySize if r.partitioner != None) {
return r.partitioner.get
}
- if (System.getProperty("spark.default.parallelism") != null) {
+ if (rdd.context.conf.contains("spark.default.parallelism")) {
return new HashPartitioner(rdd.context.defaultParallelism)
} else {
return new HashPartitioner(bySize.head.partitions.size)
@@ -72,7 +74,7 @@ class HashPartitioner(partitions: Int) extends Partitioner {
case null => 0
case _ => Utils.nonNegativeMod(key.hashCode, numPartitions)
}
-
+
override def equals(other: Any): Boolean = other match {
case h: HashPartitioner =>
h.numPartitions == numPartitions
@@ -85,10 +87,10 @@ class HashPartitioner(partitions: Int) extends Partitioner {
* A [[org.apache.spark.Partitioner]] that partitions sortable records by range into roughly equal ranges.
* Determines the ranges by sampling the RDD passed in.
*/
-class RangePartitioner[K <% Ordered[K]: ClassManifest, V](
+class RangePartitioner[K <% Ordered[K]: ClassTag, V](
partitions: Int,
@transient rdd: RDD[_ <: Product2[K,V]],
- private val ascending: Boolean = true)
+ private val ascending: Boolean = true)
extends Partitioner {
// An array of upper bounds for the first (partitions - 1) partitions
diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala
new file mode 100644
index 0000000000..2de32231e8
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/SparkConf.scala
@@ -0,0 +1,198 @@
+package org.apache.spark
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable.HashMap
+
+import com.typesafe.config.ConfigFactory
+
+/**
+ * Configuration for a Spark application. Used to set various Spark parameters as key-value pairs.
+ *
+ * Most of the time, you would create a SparkConf object with `new SparkConf()`, which will load
+ * values from both the `spark.*` Java system properties and any `spark.conf` on your application's
+ * classpath (if it has one). In this case, system properties take priority over `spark.conf`, and
+ * any parameters you set directly on the `SparkConf` object take priority over both of those.
+ *
+ * For unit tests, you can also call `new SparkConf(false)` to skip loading external settings and
+ * get the same configuration no matter what is on the classpath.
+ *
+ * All setter methods in this class support chaining. For example, you can write
+ * `new SparkConf().setMaster("local").setAppName("My app")`.
+ *
+ * Note that once a SparkConf object is passed to Spark, it is cloned and can no longer be modified
+ * by the user. Spark does not support modifying the configuration at runtime.
+ *
+ * @param loadDefaults whether to load values from the system properties and classpath
+ */
+class SparkConf(loadDefaults: Boolean) extends Serializable with Cloneable with Logging {
+
+ /** Create a SparkConf that loads defaults from system properties and the classpath */
+ def this() = this(true)
+
+ private val settings = new HashMap[String, String]()
+
+ if (loadDefaults) {
+ ConfigFactory.invalidateCaches()
+ val typesafeConfig = ConfigFactory.systemProperties()
+ .withFallback(ConfigFactory.parseResources("spark.conf"))
+ for (e <- typesafeConfig.entrySet().asScala if e.getKey.startsWith("spark.")) {
+ settings(e.getKey) = e.getValue.unwrapped.toString
+ }
+ }
+
+ /** Set a configuration variable. */
+ def set(key: String, value: String): SparkConf = {
+ if (key == null) {
+ throw new NullPointerException("null key")
+ }
+ if (value == null) {
+ throw new NullPointerException("null value")
+ }
+ settings(key) = value
+ this
+ }
+
+ /**
+ * The master URL to connect to, such as "local" to run locally with one thread, "local[4]" to
+ * run locally with 4 cores, or "spark://master:7077" to run on a Spark standalone cluster.
+ */
+ def setMaster(master: String): SparkConf = {
+ set("spark.master", master)
+ }
+
+ /** Set a name for your application. Shown in the Spark web UI. */
+ def setAppName(name: String): SparkConf = {
+ set("spark.app.name", name)
+ }
+
+ /** Set JAR files to distribute to the cluster. */
+ def setJars(jars: Seq[String]): SparkConf = {
+ for (jar <- jars if (jar == null)) logWarning("null jar passed to SparkContext constructor")
+ set("spark.jars", jars.filter(_ != null).mkString(","))
+ }
+
+ /** Set JAR files to distribute to the cluster. (Java-friendly version.) */
+ def setJars(jars: Array[String]): SparkConf = {
+ setJars(jars.toSeq)
+ }
+
+ /**
+ * Set an environment variable to be used when launching executors for this application.
+ * These variables are stored as properties of the form spark.executorEnv.VAR_NAME
+ * (for example spark.executorEnv.PATH) but this method makes them easier to set.
+ */
+ def setExecutorEnv(variable: String, value: String): SparkConf = {
+ set("spark.executorEnv." + variable, value)
+ }
+
+ /**
+ * Set multiple environment variables to be used when launching executors.
+ * These variables are stored as properties of the form spark.executorEnv.VAR_NAME
+ * (for example spark.executorEnv.PATH) but this method makes them easier to set.
+ */
+ def setExecutorEnv(variables: Seq[(String, String)]): SparkConf = {
+ for ((k, v) <- variables) {
+ setExecutorEnv(k, v)
+ }
+ this
+ }
+
+ /**
+ * Set multiple environment variables to be used when launching executors.
+ * (Java-friendly version.)
+ */
+ def setExecutorEnv(variables: Array[(String, String)]): SparkConf = {
+ setExecutorEnv(variables.toSeq)
+ }
+
+ /**
+ * Set the location where Spark is installed on worker nodes.
+ */
+ def setSparkHome(home: String): SparkConf = {
+ set("spark.home", home)
+ }
+
+ /** Set multiple parameters together */
+ def setAll(settings: Traversable[(String, String)]) = {
+ this.settings ++= settings
+ this
+ }
+
+ /** Set a parameter if it isn't already configured */
+ def setIfMissing(key: String, value: String): SparkConf = {
+ if (!settings.contains(key)) {
+ settings(key) = value
+ }
+ this
+ }
+
+ /** Remove a parameter from the configuration */
+ def remove(key: String): SparkConf = {
+ settings.remove(key)
+ this
+ }
+
+ /** Get a parameter; throws a NoSuchElementException if it's not set */
+ def get(key: String): String = {
+ settings.getOrElse(key, throw new NoSuchElementException(key))
+ }
+
+ /** Get a parameter, falling back to a default if not set */
+ def get(key: String, defaultValue: String): String = {
+ settings.getOrElse(key, defaultValue)
+ }
+
+ /** Get a parameter as an Option */
+ def getOption(key: String): Option[String] = {
+ settings.get(key)
+ }
+
+ /** Get all parameters as a list of pairs */
+ def getAll: Array[(String, String)] = settings.clone().toArray
+
+ /** Get a parameter as an integer, falling back to a default if not set */
+ def getInt(key: String, defaultValue: Int): Int = {
+ getOption(key).map(_.toInt).getOrElse(defaultValue)
+ }
+
+ /** Get a parameter as a long, falling back to a default if not set */
+ def getLong(key: String, defaultValue: Long): Long = {
+ getOption(key).map(_.toLong).getOrElse(defaultValue)
+ }
+
+ /** Get a parameter as a double, falling back to a default if not set */
+ def getDouble(key: String, defaultValue: Double): Double = {
+ getOption(key).map(_.toDouble).getOrElse(defaultValue)
+ }
+
+ /** Get a parameter as a boolean, falling back to a default if not set */
+ def getBoolean(key: String, defaultValue: Boolean): Boolean = {
+ getOption(key).map(_.toBoolean).getOrElse(defaultValue)
+ }
+
+ /** Get all executor environment variables set on this SparkConf */
+ def getExecutorEnv: Seq[(String, String)] = {
+ val prefix = "spark.executorEnv."
+ getAll.filter{case (k, v) => k.startsWith(prefix)}
+ .map{case (k, v) => (k.substring(prefix.length), v)}
+ }
+
+ /** Get all akka conf variables set on this SparkConf */
+ def getAkkaConf: Seq[(String, String)] = getAll.filter {case (k, v) => k.startsWith("akka.")}
+
+ /** Does the configuration contain a given parameter? */
+ def contains(key: String): Boolean = settings.contains(key)
+
+ /** Copy this object */
+ override def clone: SparkConf = {
+ new SparkConf(false).setAll(settings)
+ }
+
+ /**
+ * Return a string listing all keys and values, one per line. This is useful to print the
+ * configuration out for debugging.
+ */
+ def toDebugString: String = {
+ settings.toArray.sorted.map{case (k, v) => k + "=" + v}.mkString("\n")
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 42b2985b50..0e47f4e442 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -19,35 +19,23 @@ package org.apache.spark
import java.io._
import java.net.URI
-import java.util.Properties
+import java.util.{UUID, Properties}
import java.util.concurrent.atomic.AtomicInteger
-import scala.collection.Map
+import scala.collection.{Map, Set}
import scala.collection.generic.Growable
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
+
+import scala.collection.mutable.{ArrayBuffer, HashMap}
+import scala.reflect.{ClassTag, classTag}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
-import org.apache.hadoop.io.ArrayWritable
-import org.apache.hadoop.io.BooleanWritable
-import org.apache.hadoop.io.BytesWritable
-import org.apache.hadoop.io.DoubleWritable
-import org.apache.hadoop.io.FloatWritable
-import org.apache.hadoop.io.IntWritable
-import org.apache.hadoop.io.LongWritable
-import org.apache.hadoop.io.NullWritable
-import org.apache.hadoop.io.Text
-import org.apache.hadoop.io.Writable
-import org.apache.hadoop.mapred.FileInputFormat
-import org.apache.hadoop.mapred.InputFormat
-import org.apache.hadoop.mapred.JobConf
-import org.apache.hadoop.mapred.SequenceFileInputFormat
-import org.apache.hadoop.mapred.TextInputFormat
-import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
-import org.apache.hadoop.mapreduce.{Job => NewHadoopJob}
+import org.apache.hadoop.io.{ArrayWritable, BooleanWritable, BytesWritable, DoubleWritable,
+FloatWritable, IntWritable, LongWritable, NullWritable, Text, Writable}
+import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf, SequenceFileInputFormat,
+TextInputFormat}
+import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, Job => NewHadoopJob}
import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat}
-
import org.apache.mesos.MesosNativeLibrary
import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil}
@@ -55,59 +43,106 @@ import org.apache.spark.partial.{ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd._
import org.apache.spark.scheduler._
import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend,
- SparkDeploySchedulerBackend, ClusterScheduler, SimrSchedulerBackend}
+ SparkDeploySchedulerBackend, SimrSchedulerBackend}
import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
-import org.apache.spark.scheduler.local.LocalScheduler
-import org.apache.spark.scheduler.StageInfo
+import org.apache.spark.scheduler.local.LocalBackend
import org.apache.spark.storage.{BlockManagerSource, RDDInfo, StorageStatus, StorageUtils}
import org.apache.spark.ui.SparkUI
-import org.apache.spark.util.{ClosureCleaner, MetadataCleaner, MetadataCleanerType,
- TimeStampedHashMap, Utils}
+import org.apache.spark.util.{Utils, TimeStampedHashMap, MetadataCleaner, MetadataCleanerType,
+ClosureCleaner}
/**
* Main entry point for Spark functionality. A SparkContext represents the connection to a Spark
* cluster, and can be used to create RDDs, accumulators and broadcast variables on that cluster.
*
- * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
- * @param appName A name for your application, to display on the cluster web UI.
- * @param sparkHome Location where Spark is installed on cluster nodes.
- * @param jars Collection of JARs to send to the cluster. These can be paths on the local file
- * system or HDFS, HTTP, HTTPS, or FTP URLs.
- * @param environment Environment variables to set on worker nodes.
+ * @param config a Spark Config object describing the application configuration. Any settings in
+ * this config overrides the default configs as well as system properties.
+ * @param preferredNodeLocationData used in YARN mode to select nodes to launch containers on. Can
+ * be generated using [[org.apache.spark.scheduler.InputFormatInfo.computePreferredLocations]]
+ * from a list of input files or InputFormats for the application.
*/
class SparkContext(
- val master: String,
- val appName: String,
- val sparkHome: String = null,
- val jars: Seq[String] = Nil,
- val environment: Map[String, String] = Map(),
- // This is used only by yarn for now, but should be relevant to other cluster types (mesos, etc)
- // too. This is typically generated from InputFormatInfo.computePreferredLocations .. host, set
- // of data-local splits on host
- val preferredNodeLocationData: scala.collection.Map[String, scala.collection.Set[SplitInfo]] =
- scala.collection.immutable.Map())
+ config: SparkConf,
+ // This is used only by YARN for now, but should be relevant to other cluster types (Mesos, etc)
+ // too. This is typically generated from InputFormatInfo.computePreferredLocations. It contains
+ // a map from hostname to a list of input format splits on the host.
+ val preferredNodeLocationData: Map[String, Set[SplitInfo]] = Map())
extends Logging {
- // Ensure logging is initialized before we spawn any threads
- initLogging()
+ /**
+ * Alternative constructor that allows setting common Spark properties directly
+ *
+ * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
+ * @param appName A name for your application, to display on the cluster web UI
+ * @param conf a [[org.apache.spark.SparkConf]] object specifying other Spark parameters
+ */
+ def this(master: String, appName: String, conf: SparkConf) =
+ this(SparkContext.updatedConf(conf, master, appName))
- // Set Spark driver host and port system properties
- if (System.getProperty("spark.driver.host") == null) {
- System.setProperty("spark.driver.host", Utils.localHostName())
+ /**
+ * Alternative constructor that allows setting common Spark properties directly
+ *
+ * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
+ * @param appName A name for your application, to display on the cluster web UI.
+ * @param sparkHome Location where Spark is installed on cluster nodes.
+ * @param jars Collection of JARs to send to the cluster. These can be paths on the local file
+ * system or HDFS, HTTP, HTTPS, or FTP URLs.
+ * @param environment Environment variables to set on worker nodes.
+ */
+ def this(
+ master: String,
+ appName: String,
+ sparkHome: String = null,
+ jars: Seq[String] = Nil,
+ environment: Map[String, String] = Map(),
+ preferredNodeLocationData: Map[String, Set[SplitInfo]] = Map()) =
+ {
+ this(SparkContext.updatedConf(new SparkConf(), master, appName, sparkHome, jars, environment),
+ preferredNodeLocationData)
}
- if (System.getProperty("spark.driver.port") == null) {
- System.setProperty("spark.driver.port", "0")
+
+ private[spark] val conf = config.clone()
+
+ /**
+ * Return a copy of this SparkContext's configuration. The configuration ''cannot'' be
+ * changed at runtime.
+ */
+ def getConf: SparkConf = conf.clone()
+
+ if (!conf.contains("spark.master")) {
+ throw new SparkException("A master URL must be set in your configuration")
+ }
+ if (!conf.contains("spark.app.name")) {
+ throw new SparkException("An application must be set in your configuration")
+ }
+
+ if (conf.get("spark.logConf", "false").toBoolean) {
+ logInfo("Spark configuration:\n" + conf.toDebugString)
}
+ // Set Spark driver host and port system properties
+ conf.setIfMissing("spark.driver.host", Utils.localHostName())
+ conf.setIfMissing("spark.driver.port", "0")
+
+ val jars: Seq[String] = if (conf.contains("spark.jars")) {
+ conf.get("spark.jars").split(",").filter(_.size != 0)
+ } else {
+ null
+ }
+
+ val master = conf.get("spark.master")
+ val appName = conf.get("spark.app.name")
+
val isLocal = (master == "local" || master.startsWith("local["))
// Create the Spark execution environment (cache, map output tracker, etc)
- private[spark] val env = SparkEnv.createFromSystemProperties(
+ private[spark] val env = SparkEnv.create(
+ conf,
"<driver>",
- System.getProperty("spark.driver.host"),
- System.getProperty("spark.driver.port").toInt,
- true,
- isLocal)
+ conf.get("spark.driver.host"),
+ conf.get("spark.driver.port").toInt,
+ isDriver = true,
+ isLocal = isLocal)
SparkEnv.set(env)
// Used to store a URL for each static file/jar together with the file's local timestamp
@@ -116,7 +151,8 @@ class SparkContext(
// Keeps track of all persisted RDDs
private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]]
- private[spark] val metadataCleaner = new MetadataCleaner(MetadataCleanerType.SPARK_CONTEXT, this.cleanup)
+ private[spark] val metadataCleaner =
+ new MetadataCleaner(MetadataCleanerType.SPARK_CONTEXT, this.cleanup, conf)
// Initialize the Spark UI
private[spark] val ui = new SparkUI(this)
@@ -126,23 +162,30 @@ class SparkContext(
// Add each JAR given through the constructor
if (jars != null) {
- jars.foreach { addJar(_) }
+ jars.foreach(addJar)
}
+ private[spark] val executorMemory = conf.getOption("spark.executor.memory")
+ .orElse(Option(System.getenv("SPARK_MEM")))
+ .map(Utils.memoryStringToMb)
+ .getOrElse(512)
+
// Environment variables to pass to our executors
private[spark] val executorEnvs = HashMap[String, String]()
// Note: SPARK_MEM is included for Mesos, but overwritten for standalone mode in ExecutorRunner
- for (key <- Seq("SPARK_CLASSPATH", "SPARK_LIBRARY_PATH", "SPARK_JAVA_OPTS", "SPARK_TESTING")) {
- val value = System.getenv(key)
- if (value != null) {
- executorEnvs(key) = value
- }
+ for (key <- Seq("SPARK_CLASSPATH", "SPARK_LIBRARY_PATH", "SPARK_JAVA_OPTS");
+ value <- Option(System.getenv(key))) {
+ executorEnvs(key) = value
}
- // Since memory can be set with a system property too, use that
- executorEnvs("SPARK_MEM") = SparkContext.executorMemoryRequested + "m"
- if (environment != null) {
- executorEnvs ++= environment
+ // Convert java options to env vars as a work around
+ // since we can't set env vars directly in sbt.
+ for { (envKey, propKey) <- Seq(("SPARK_HOME", "spark.home"), ("SPARK_TESTING", "spark.testing"))
+ value <- Option(System.getenv(envKey)).orElse(Option(System.getProperty(propKey)))} {
+ executorEnvs(envKey) = value
}
+ // Since memory can be set with a system property too, use that
+ executorEnvs("SPARK_MEM") = executorMemory + "m"
+ executorEnvs ++= conf.getExecutorEnv
// Set SPARK_USER for user who is running SparkContext.
val sparkUser = Option {
@@ -153,122 +196,35 @@ class SparkContext(
executorEnvs("SPARK_USER") = sparkUser
// Create and start the scheduler
- private[spark] var taskScheduler: TaskScheduler = {
- // Regular expression used for local[N] master format
- val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r
- // Regular expression for local[N, maxRetries], used in tests with failing tasks
- val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+)\s*,\s*([0-9]+)\]""".r
- // Regular expression for simulating a Spark cluster of [N, cores, memory] locally
- val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r
- // Regular expression for connecting to Spark deploy clusters
- val SPARK_REGEX = """spark://(.*)""".r
- // Regular expression for connection to Mesos cluster
- val MESOS_REGEX = """mesos://(.*)""".r
- // Regular expression for connection to Simr cluster
- val SIMR_REGEX = """simr://(.*)""".r
-
- master match {
- case "local" =>
- new LocalScheduler(1, 0, this)
-
- case LOCAL_N_REGEX(threads) =>
- new LocalScheduler(threads.toInt, 0, this)
-
- case LOCAL_N_FAILURES_REGEX(threads, maxFailures) =>
- new LocalScheduler(threads.toInt, maxFailures.toInt, this)
-
- case SPARK_REGEX(sparkUrl) =>
- val scheduler = new ClusterScheduler(this)
- val masterUrls = sparkUrl.split(",").map("spark://" + _)
- val backend = new SparkDeploySchedulerBackend(scheduler, this, masterUrls, appName)
- scheduler.initialize(backend)
- scheduler
-
- case SIMR_REGEX(simrUrl) =>
- val scheduler = new ClusterScheduler(this)
- val backend = new SimrSchedulerBackend(scheduler, this, simrUrl)
- scheduler.initialize(backend)
- scheduler
-
- case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerSlave) =>
- // Check to make sure memory requested <= memoryPerSlave. Otherwise Spark will just hang.
- val memoryPerSlaveInt = memoryPerSlave.toInt
- if (SparkContext.executorMemoryRequested > memoryPerSlaveInt) {
- throw new SparkException(
- "Asked to launch cluster with %d MB RAM / worker but requested %d MB/worker".format(
- memoryPerSlaveInt, SparkContext.executorMemoryRequested))
- }
-
- val scheduler = new ClusterScheduler(this)
- val localCluster = new LocalSparkCluster(
- numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt)
- val masterUrls = localCluster.start()
- val backend = new SparkDeploySchedulerBackend(scheduler, this, masterUrls, appName)
- scheduler.initialize(backend)
- backend.shutdownCallback = (backend: SparkDeploySchedulerBackend) => {
- localCluster.stop()
- }
- scheduler
-
- case "yarn-standalone" =>
- val scheduler = try {
- val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClusterScheduler")
- val cons = clazz.getConstructor(classOf[SparkContext])
- cons.newInstance(this).asInstanceOf[ClusterScheduler]
- } catch {
- // TODO: Enumerate the exact reasons why it can fail
- // But irrespective of it, it means we cannot proceed !
- case th: Throwable => {
- throw new SparkException("YARN mode not available ?", th)
- }
- }
- val backend = new CoarseGrainedSchedulerBackend(scheduler, this.env.actorSystem)
- scheduler.initialize(backend)
- scheduler
-
- case MESOS_REGEX(mesosUrl) =>
- MesosNativeLibrary.load()
- val scheduler = new ClusterScheduler(this)
- val coarseGrained = System.getProperty("spark.mesos.coarse", "false").toBoolean
- val backend = if (coarseGrained) {
- new CoarseMesosSchedulerBackend(scheduler, this, mesosUrl, appName)
- } else {
- new MesosSchedulerBackend(scheduler, this, mesosUrl, appName)
- }
- scheduler.initialize(backend)
- scheduler
-
- case _ =>
- throw new SparkException("Could not parse Master URL: '" + master + "'")
- }
- }
+ private[spark] var taskScheduler = SparkContext.createTaskScheduler(this, master, appName)
taskScheduler.start()
@volatile private[spark] var dagScheduler = new DAGScheduler(taskScheduler)
+ dagScheduler.start()
ui.start()
/** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */
val hadoopConfiguration = {
val env = SparkEnv.get
- val conf = SparkHadoopUtil.get.newConfiguration()
+ val hadoopConf = SparkHadoopUtil.get.newConfiguration()
// Explicitly check for S3 environment variables
if (System.getenv("AWS_ACCESS_KEY_ID") != null &&
System.getenv("AWS_SECRET_ACCESS_KEY") != null) {
- conf.set("fs.s3.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID"))
- conf.set("fs.s3n.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID"))
- conf.set("fs.s3.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY"))
- conf.set("fs.s3n.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY"))
+ hadoopConf.set("fs.s3.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID"))
+ hadoopConf.set("fs.s3n.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID"))
+ hadoopConf.set("fs.s3.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY"))
+ hadoopConf.set("fs.s3n.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY"))
}
// Copy any "spark.hadoop.foo=bar" system properties into conf as "foo=bar"
- Utils.getSystemProperties.foreach { case (key, value) =>
+ conf.getAll.foreach { case (key, value) =>
if (key.startsWith("spark.hadoop.")) {
- conf.set(key.substring("spark.hadoop.".length), value)
+ hadoopConf.set(key.substring("spark.hadoop.".length), value)
}
}
- val bufferSize = System.getProperty("spark.buffer.size", "65536")
- conf.set("io.file.buffer.size", bufferSize)
- conf
+ val bufferSize = conf.get("spark.buffer.size", "65536")
+ hadoopConf.set("io.file.buffer.size", bufferSize)
+ hadoopConf
}
private[spark] var checkpointDir: Option[String] = None
@@ -278,7 +234,7 @@ class SparkContext(
override protected def childValue(parent: Properties): Properties = new Properties(parent)
}
- private[spark] def getLocalProperties(): Properties = localProperties.get()
+ private[spark] def getLocalProperties: Properties = localProperties.get()
private[spark] def setLocalProperties(props: Properties) {
localProperties.set(props)
@@ -354,19 +310,19 @@ class SparkContext(
// Methods for creating RDDs
/** Distribute a local Scala collection to form an RDD. */
- def parallelize[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = {
+ def parallelize[T: ClassTag](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = {
new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]())
}
/** Distribute a local Scala collection to form an RDD. */
- def makeRDD[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = {
+ def makeRDD[T: ClassTag](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = {
parallelize(seq, numSlices)
}
/** Distribute a local Scala collection to form an RDD, with one or more
* location preferences (hostnames of Spark nodes) for each object.
* Create a new partition for each collection item. */
- def makeRDD[T: ClassManifest](seq: Seq[(T, Seq[String])]): RDD[T] = {
+ def makeRDD[T: ClassTag](seq: Seq[(T, Seq[String])]): RDD[T] = {
val indexToPrefs = seq.zipWithIndex.map(t => (t._2, t._1._2)).toMap
new ParallelCollectionRDD[T](this, seq.map(_._1), seq.size, indexToPrefs)
}
@@ -419,7 +375,7 @@ class SparkContext(
}
/**
- * Smarter version of hadoopFile() that uses class manifests to figure out the classes of keys,
+ * Smarter version of hadoopFile() that uses class tags to figure out the classes of keys,
* values and the InputFormat so that users don't need to pass them directly. Instead, callers
* can just write, for example,
* {{{
@@ -427,17 +383,17 @@ class SparkContext(
* }}}
*/
def hadoopFile[K, V, F <: InputFormat[K, V]](path: String, minSplits: Int)
- (implicit km: ClassManifest[K], vm: ClassManifest[V], fm: ClassManifest[F])
+ (implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F])
: RDD[(K, V)] = {
hadoopFile(path,
- fm.erasure.asInstanceOf[Class[F]],
- km.erasure.asInstanceOf[Class[K]],
- vm.erasure.asInstanceOf[Class[V]],
+ fm.runtimeClass.asInstanceOf[Class[F]],
+ km.runtimeClass.asInstanceOf[Class[K]],
+ vm.runtimeClass.asInstanceOf[Class[V]],
minSplits)
}
/**
- * Smarter version of hadoopFile() that uses class manifests to figure out the classes of keys,
+ * Smarter version of hadoopFile() that uses class tags to figure out the classes of keys,
* values and the InputFormat so that users don't need to pass them directly. Instead, callers
* can just write, for example,
* {{{
@@ -445,17 +401,17 @@ class SparkContext(
* }}}
*/
def hadoopFile[K, V, F <: InputFormat[K, V]](path: String)
- (implicit km: ClassManifest[K], vm: ClassManifest[V], fm: ClassManifest[F]): RDD[(K, V)] =
+ (implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F]): RDD[(K, V)] =
hadoopFile[K, V, F](path, defaultMinSplits)
/** Get an RDD for a Hadoop file with an arbitrary new API InputFormat. */
def newAPIHadoopFile[K, V, F <: NewInputFormat[K, V]](path: String)
- (implicit km: ClassManifest[K], vm: ClassManifest[V], fm: ClassManifest[F]): RDD[(K, V)] = {
+ (implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F]): RDD[(K, V)] = {
newAPIHadoopFile(
path,
- fm.erasure.asInstanceOf[Class[F]],
- km.erasure.asInstanceOf[Class[K]],
- vm.erasure.asInstanceOf[Class[V]])
+ fm.runtimeClass.asInstanceOf[Class[F]],
+ km.runtimeClass.asInstanceOf[Class[K]],
+ vm.runtimeClass.asInstanceOf[Class[V]])
}
/**
@@ -513,11 +469,11 @@ class SparkContext(
* IntWritable). The most natural thing would've been to have implicit objects for the
* converters, but then we couldn't have an object for every subclass of Writable (you can't
* have a parameterized singleton object). We use functions instead to create a new converter
- * for the appropriate type. In addition, we pass the converter a ClassManifest of its type to
+ * for the appropriate type. In addition, we pass the converter a ClassTag of its type to
* allow it to figure out the Writable class to use in the subclass case.
*/
def sequenceFile[K, V](path: String, minSplits: Int = defaultMinSplits)
- (implicit km: ClassManifest[K], vm: ClassManifest[V],
+ (implicit km: ClassTag[K], vm: ClassTag[V],
kcf: () => WritableConverter[K], vcf: () => WritableConverter[V])
: RDD[(K, V)] = {
val kc = kcf()
@@ -536,7 +492,7 @@ class SparkContext(
* slow if you use the default serializer (Java serialization), though the nice thing about it is
* that there's very little effort required to save arbitrary objects.
*/
- def objectFile[T: ClassManifest](
+ def objectFile[T: ClassTag](
path: String,
minSplits: Int = defaultMinSplits
): RDD[T] = {
@@ -545,17 +501,17 @@ class SparkContext(
}
- protected[spark] def checkpointFile[T: ClassManifest](
+ protected[spark] def checkpointFile[T: ClassTag](
path: String
): RDD[T] = {
new CheckpointRDD[T](this, path)
}
/** Build the union of a list of RDDs. */
- def union[T: ClassManifest](rdds: Seq[RDD[T]]): RDD[T] = new UnionRDD(this, rdds)
+ def union[T: ClassTag](rdds: Seq[RDD[T]]): RDD[T] = new UnionRDD(this, rdds)
/** Build the union of a list of RDDs passed as variable-length arguments. */
- def union[T: ClassManifest](first: RDD[T], rest: RDD[T]*): RDD[T] =
+ def union[T: ClassTag](first: RDD[T], rest: RDD[T]*): RDD[T] =
new UnionRDD(this, Seq(first) ++ rest)
// Methods for creating shared variables
@@ -608,10 +564,8 @@ class SparkContext(
}
addedFiles(key) = System.currentTimeMillis
- // Fetch the file locally in case a job is executed locally.
- // Jobs that run through LocalScheduler will already fetch the required dependencies,
- // but jobs run in DAGScheduler.runLocally() will not so we must fetch the files here.
- Utils.fetchFile(path, new File(SparkFiles.getRootDirectory))
+ // Fetch the file locally in case a job is executed using DAGScheduler.runLocally().
+ Utils.fetchFile(path, new File(SparkFiles.getRootDirectory), conf)
logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key))
}
@@ -781,15 +735,27 @@ class SparkContext(
* (in that order of preference). If neither of these is set, return None.
*/
private[spark] def getSparkHome(): Option[String] = {
- if (sparkHome != null) {
- Some(sparkHome)
- } else if (System.getProperty("spark.home") != null) {
- Some(System.getProperty("spark.home"))
- } else if (System.getenv("SPARK_HOME") != null) {
- Some(System.getenv("SPARK_HOME"))
- } else {
- None
- }
+ conf.getOption("spark.home").orElse(Option(System.getenv("SPARK_HOME")))
+ }
+
+ /**
+ * Support function for API backtraces.
+ */
+ def setCallSite(site: String) {
+ setLocalProperty("externalCallSite", site)
+ }
+
+ /**
+ * Support function for API backtraces.
+ */
+ def clearCallSite() {
+ setLocalProperty("externalCallSite", null)
+ }
+
+ private[spark] def getCallSite(): String = {
+ val callSite = getLocalProperty("externalCallSite")
+ if (callSite == null) return Utils.formatSparkCallSite
+ callSite
}
/**
@@ -798,13 +764,13 @@ class SparkContext(
* flag specifies whether the scheduler can run the computation on the driver rather than
* shipping it out to the cluster, for short actions like first().
*/
- def runJob[T, U: ClassManifest](
+ def runJob[T, U: ClassTag](
rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
partitions: Seq[Int],
allowLocal: Boolean,
resultHandler: (Int, U) => Unit) {
- val callSite = Utils.formatSparkCallSite
+ val callSite = getCallSite
val cleanedFunc = clean(func)
logInfo("Starting job: " + callSite)
val start = System.nanoTime
@@ -819,7 +785,7 @@ class SparkContext(
* allowLocal flag specifies whether the scheduler can run the computation on the driver rather
* than shipping it out to the cluster, for short actions like first().
*/
- def runJob[T, U: ClassManifest](
+ def runJob[T, U: ClassTag](
rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
partitions: Seq[Int],
@@ -834,7 +800,7 @@ class SparkContext(
* Run a job on a given set of partitions of an RDD, but take a function of type
* `Iterator[T] => U` instead of `(TaskContext, Iterator[T]) => U`.
*/
- def runJob[T, U: ClassManifest](
+ def runJob[T, U: ClassTag](
rdd: RDD[T],
func: Iterator[T] => U,
partitions: Seq[Int],
@@ -846,21 +812,21 @@ class SparkContext(
/**
* Run a job on all partitions in an RDD and return the results in an array.
*/
- def runJob[T, U: ClassManifest](rdd: RDD[T], func: (TaskContext, Iterator[T]) => U): Array[U] = {
+ def runJob[T, U: ClassTag](rdd: RDD[T], func: (TaskContext, Iterator[T]) => U): Array[U] = {
runJob(rdd, func, 0 until rdd.partitions.size, false)
}
/**
* Run a job on all partitions in an RDD and return the results in an array.
*/
- def runJob[T, U: ClassManifest](rdd: RDD[T], func: Iterator[T] => U): Array[U] = {
+ def runJob[T, U: ClassTag](rdd: RDD[T], func: Iterator[T] => U): Array[U] = {
runJob(rdd, func, 0 until rdd.partitions.size, false)
}
/**
* Run a job on all partitions in an RDD and pass the results to a handler function.
*/
- def runJob[T, U: ClassManifest](
+ def runJob[T, U: ClassTag](
rdd: RDD[T],
processPartition: (TaskContext, Iterator[T]) => U,
resultHandler: (Int, U) => Unit)
@@ -871,7 +837,7 @@ class SparkContext(
/**
* Run a job on all partitions in an RDD and pass the results to a handler function.
*/
- def runJob[T, U: ClassManifest](
+ def runJob[T, U: ClassTag](
rdd: RDD[T],
processPartition: Iterator[T] => U,
resultHandler: (Int, U) => Unit)
@@ -888,7 +854,7 @@ class SparkContext(
func: (TaskContext, Iterator[T]) => U,
evaluator: ApproximateEvaluator[U, R],
timeout: Long): PartialResult[R] = {
- val callSite = Utils.formatSparkCallSite
+ val callSite = getCallSite
logInfo("Starting job: " + callSite)
val start = System.nanoTime
val result = dagScheduler.runApproximateJob(rdd, func, evaluator, callSite, timeout,
@@ -908,7 +874,7 @@ class SparkContext(
resultFunc: => R): SimpleFutureAction[R] =
{
val cleanF = clean(processPartition)
- val callSite = Utils.formatSparkCallSite
+ val callSite = getCallSite
val waiter = dagScheduler.submitJob(
rdd,
(context: TaskContext, iter: Iterator[T]) => cleanF(iter),
@@ -944,22 +910,15 @@ class SparkContext(
/**
* Set the directory under which RDDs are going to be checkpointed. The directory must
- * be a HDFS path if running on a cluster. If the directory does not exist, it will
- * be created. If the directory exists and useExisting is set to true, then the
- * exisiting directory will be used. Otherwise an exception will be thrown to
- * prevent accidental overriding of checkpoint files in the existing directory.
+ * be a HDFS path if running on a cluster.
*/
- def setCheckpointDir(dir: String, useExisting: Boolean = false) {
- val path = new Path(dir)
- val fs = path.getFileSystem(SparkHadoopUtil.get.newConfiguration())
- if (!useExisting) {
- if (fs.exists(path)) {
- throw new Exception("Checkpoint directory '" + path + "' already exists.")
- } else {
- fs.mkdirs(path)
- }
+ def setCheckpointDir(directory: String) {
+ checkpointDir = Option(directory).map { dir =>
+ val path = new Path(dir, UUID.randomUUID().toString)
+ val fs = path.getFileSystem(hadoopConfiguration)
+ fs.mkdirs(path)
+ fs.getFileStatus(path).getPath().toString
}
- checkpointDir = Some(dir)
}
/** Default level of parallelism to use when not given by user (e.g. parallelize and makeRDD). */
@@ -1017,16 +976,16 @@ object SparkContext {
// TODO: Add AccumulatorParams for other types, e.g. lists and strings
- implicit def rddToPairRDDFunctions[K: ClassManifest, V: ClassManifest](rdd: RDD[(K, V)]) =
+ implicit def rddToPairRDDFunctions[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)]) =
new PairRDDFunctions(rdd)
- implicit def rddToAsyncRDDActions[T: ClassManifest](rdd: RDD[T]) = new AsyncRDDActions(rdd)
+ implicit def rddToAsyncRDDActions[T: ClassTag](rdd: RDD[T]) = new AsyncRDDActions(rdd)
- implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable: ClassManifest](
+ implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable: ClassTag](
rdd: RDD[(K, V)]) =
new SequenceFileRDDFunctions(rdd)
- implicit def rddToOrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest](
+ implicit def rddToOrderedRDDFunctions[K <% Ordered[K]: ClassTag, V: ClassTag](
rdd: RDD[(K, V)]) =
new OrderedRDDFunctions[K, V, (K, V)](rdd)
@@ -1051,16 +1010,16 @@ object SparkContext {
implicit def stringToText(s: String) = new Text(s)
- private implicit def arrayToArrayWritable[T <% Writable: ClassManifest](arr: Traversable[T]): ArrayWritable = {
+ private implicit def arrayToArrayWritable[T <% Writable: ClassTag](arr: Traversable[T]): ArrayWritable = {
def anyToWritable[U <% Writable](u: U): Writable = u
- new ArrayWritable(classManifest[T].erasure.asInstanceOf[Class[Writable]],
+ new ArrayWritable(classTag[T].runtimeClass.asInstanceOf[Class[Writable]],
arr.map(x => anyToWritable(x)).toArray)
}
// Helper objects for converting common types to Writable
- private def simpleWritableConverter[T, W <: Writable: ClassManifest](convert: W => T) = {
- val wClass = classManifest[W].erasure.asInstanceOf[Class[W]]
+ private def simpleWritableConverter[T, W <: Writable: ClassTag](convert: W => T) = {
+ val wClass = classTag[W].runtimeClass.asInstanceOf[Class[W]]
new WritableConverter[T](_ => wClass, x => convert(x.asInstanceOf[W]))
}
@@ -1079,11 +1038,11 @@ object SparkContext {
implicit def stringWritableConverter() = simpleWritableConverter[String, Text](_.toString)
implicit def writableWritableConverter[T <: Writable]() =
- new WritableConverter[T](_.erasure.asInstanceOf[Class[T]], _.asInstanceOf[T])
+ new WritableConverter[T](_.runtimeClass.asInstanceOf[Class[T]], _.asInstanceOf[T])
/**
* Find the JAR from which a given class was loaded, to make it easy for users to pass
- * their JARs to SparkContext
+ * their JARs to SparkContext.
*/
def jarOfClass(cls: Class[_]): Seq[String] = {
val uri = cls.getResource("/" + cls.getName.replace('.', '/') + ".class")
@@ -1100,28 +1059,181 @@ object SparkContext {
}
}
- /** Find the JAR that contains the class of a particular object */
+ /**
+ * Find the JAR that contains the class of a particular object, to make it easy for users
+ * to pass their JARs to SparkContext. In most cases you can call jarOfObject(this) in
+ * your driver program.
+ */
def jarOfObject(obj: AnyRef): Seq[String] = jarOfClass(obj.getClass)
- /** Get the amount of memory per executor requested through system properties or SPARK_MEM */
- private[spark] val executorMemoryRequested = {
- // TODO: Might need to add some extra memory for the non-heap parts of the JVM
- Option(System.getProperty("spark.executor.memory"))
- .orElse(Option(System.getenv("SPARK_MEM")))
- .map(Utils.memoryStringToMb)
- .getOrElse(512)
+ /**
+ * Creates a modified version of a SparkConf with the parameters that can be passed separately
+ * to SparkContext, to make it easier to write SparkContext's constructors. This ignores
+ * parameters that are passed as the default value of null, instead of throwing an exception
+ * like SparkConf would.
+ */
+ private def updatedConf(
+ conf: SparkConf,
+ master: String,
+ appName: String,
+ sparkHome: String = null,
+ jars: Seq[String] = Nil,
+ environment: Map[String, String] = Map()): SparkConf =
+ {
+ val res = conf.clone()
+ res.setMaster(master)
+ res.setAppName(appName)
+ if (sparkHome != null) {
+ res.setSparkHome(sparkHome)
+ }
+ if (!jars.isEmpty) {
+ res.setJars(jars)
+ }
+ res.setExecutorEnv(environment.toSeq)
+ res
+ }
+
+ /** Creates a task scheduler based on a given master URL. Extracted for testing. */
+ private def createTaskScheduler(sc: SparkContext, master: String, appName: String)
+ : TaskScheduler =
+ {
+ // Regular expression used for local[N] master format
+ val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r
+ // Regular expression for local[N, maxRetries], used in tests with failing tasks
+ val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+)\s*,\s*([0-9]+)\]""".r
+ // Regular expression for simulating a Spark cluster of [N, cores, memory] locally
+ val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r
+ // Regular expression for connecting to Spark deploy clusters
+ val SPARK_REGEX = """spark://(.*)""".r
+ // Regular expression for connection to Mesos cluster by mesos:// or zk:// url
+ val MESOS_REGEX = """(mesos|zk)://.*""".r
+ // Regular expression for connection to Simr cluster
+ val SIMR_REGEX = """simr://(.*)""".r
+
+ // When running locally, don't try to re-execute tasks on failure.
+ val MAX_LOCAL_TASK_FAILURES = 1
+
+ master match {
+ case "local" =>
+ val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true)
+ val backend = new LocalBackend(scheduler, 1)
+ scheduler.initialize(backend)
+ scheduler
+
+ case LOCAL_N_REGEX(threads) =>
+ val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true)
+ val backend = new LocalBackend(scheduler, threads.toInt)
+ scheduler.initialize(backend)
+ scheduler
+
+ case LOCAL_N_FAILURES_REGEX(threads, maxFailures) =>
+ val scheduler = new TaskSchedulerImpl(sc, maxFailures.toInt, isLocal = true)
+ val backend = new LocalBackend(scheduler, threads.toInt)
+ scheduler.initialize(backend)
+ scheduler
+
+ case SPARK_REGEX(sparkUrl) =>
+ val scheduler = new TaskSchedulerImpl(sc)
+ val masterUrls = sparkUrl.split(",").map("spark://" + _)
+ val backend = new SparkDeploySchedulerBackend(scheduler, sc, masterUrls, appName)
+ scheduler.initialize(backend)
+ scheduler
+
+ case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerSlave) =>
+ // Check to make sure memory requested <= memoryPerSlave. Otherwise Spark will just hang.
+ val memoryPerSlaveInt = memoryPerSlave.toInt
+ if (sc.executorMemory > memoryPerSlaveInt) {
+ throw new SparkException(
+ "Asked to launch cluster with %d MB RAM / worker but requested %d MB/worker".format(
+ memoryPerSlaveInt, sc.executorMemory))
+ }
+
+ val scheduler = new TaskSchedulerImpl(sc)
+ val localCluster = new LocalSparkCluster(
+ numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt)
+ val masterUrls = localCluster.start()
+ val backend = new SparkDeploySchedulerBackend(scheduler, sc, masterUrls, appName)
+ scheduler.initialize(backend)
+ backend.shutdownCallback = (backend: SparkDeploySchedulerBackend) => {
+ localCluster.stop()
+ }
+ scheduler
+
+ case "yarn-standalone" =>
+ val scheduler = try {
+ val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClusterScheduler")
+ val cons = clazz.getConstructor(classOf[SparkContext])
+ cons.newInstance(sc).asInstanceOf[TaskSchedulerImpl]
+ } catch {
+ // TODO: Enumerate the exact reasons why it can fail
+ // But irrespective of it, it means we cannot proceed !
+ case th: Throwable => {
+ throw new SparkException("YARN mode not available ?", th)
+ }
+ }
+ val backend = new CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem)
+ scheduler.initialize(backend)
+ scheduler
+
+ case "yarn-client" =>
+ val scheduler = try {
+ val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClientClusterScheduler")
+ val cons = clazz.getConstructor(classOf[SparkContext])
+ cons.newInstance(sc).asInstanceOf[TaskSchedulerImpl]
+
+ } catch {
+ case th: Throwable => {
+ throw new SparkException("YARN mode not available ?", th)
+ }
+ }
+
+ val backend = try {
+ val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClientSchedulerBackend")
+ val cons = clazz.getConstructor(classOf[TaskSchedulerImpl], classOf[SparkContext])
+ cons.newInstance(scheduler, sc).asInstanceOf[CoarseGrainedSchedulerBackend]
+ } catch {
+ case th: Throwable => {
+ throw new SparkException("YARN mode not available ?", th)
+ }
+ }
+
+ scheduler.initialize(backend)
+ scheduler
+
+ case mesosUrl @ MESOS_REGEX(_) =>
+ MesosNativeLibrary.load()
+ val scheduler = new TaskSchedulerImpl(sc)
+ val coarseGrained = sc.conf.get("spark.mesos.coarse", "false").toBoolean
+ val url = mesosUrl.stripPrefix("mesos://") // strip scheme from raw Mesos URLs
+ val backend = if (coarseGrained) {
+ new CoarseMesosSchedulerBackend(scheduler, sc, url, appName)
+ } else {
+ new MesosSchedulerBackend(scheduler, sc, url, appName)
+ }
+ scheduler.initialize(backend)
+ scheduler
+
+ case SIMR_REGEX(simrUrl) =>
+ val scheduler = new TaskSchedulerImpl(sc)
+ val backend = new SimrSchedulerBackend(scheduler, sc, simrUrl)
+ scheduler.initialize(backend)
+ scheduler
+
+ case _ =>
+ throw new SparkException("Could not parse Master URL: '" + master + "'")
+ }
}
}
/**
* A class encapsulating how to convert some type T to Writable. It stores both the Writable class
* corresponding to T (e.g. IntWritable for Int) and a function for doing the conversion.
- * The getter for the writable class takes a ClassManifest[T] in case this is a generic object
+ * The getter for the writable class takes a ClassTag[T] in case this is a generic object
* that doesn't know the type of T when it is created. This sounds strange but is necessary to
* support converting subclasses of Writable to themselves (writableWritableConverter).
*/
private[spark] class WritableConverter[T](
- val writableClass: ClassManifest[T] => Class[_ <: Writable],
+ val writableClass: ClassTag[T] => Class[_ <: Writable],
val convert: Writable => T)
extends Serializable
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index ff2df8fb6a..2e36ccb9a0 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -17,11 +17,10 @@
package org.apache.spark
-import collection.mutable
-import serializer.Serializer
+import scala.collection.mutable
+import scala.concurrent.Await
-import akka.actor.{Actor, ActorRef, Props, ActorSystemImpl, ActorSystem}
-import akka.remote.RemoteActorRefProvider
+import akka.actor._
import org.apache.spark.broadcast.BroadcastManager
import org.apache.spark.metrics.MetricsSystem
@@ -40,7 +39,7 @@ import com.google.common.collect.MapMaker
* objects needs to have the right SparkEnv set. You can get the current environment with
* SparkEnv.get (e.g. after creating a SparkContext) and set it with SparkEnv.set.
*/
-class SparkEnv (
+class SparkEnv private[spark] (
val executorId: String,
val actorSystem: ActorSystem,
val serializerManager: SerializerManager,
@@ -54,7 +53,8 @@ class SparkEnv (
val connectionManager: ConnectionManager,
val httpFileServer: HttpFileServer,
val sparkFilesDir: String,
- val metricsSystem: MetricsSystem) {
+ val metricsSystem: MetricsSystem,
+ val conf: SparkConf) {
private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]()
@@ -62,7 +62,7 @@ class SparkEnv (
// (e.g., HadoopFileRDD uses this to cache JobConfs and InputFormats).
private[spark] val hadoopJobMetadata = new MapMaker().softValues().makeMap[String, Any]()
- def stop() {
+ private[spark] def stop() {
pythonWorkers.foreach { case(key, worker) => worker.stop() }
httpFileServer.stop()
mapOutputTracker.stop()
@@ -74,9 +74,11 @@ class SparkEnv (
actorSystem.shutdown()
// Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut
// down, but let's call it anyway in case it gets fixed in a later release
- actorSystem.awaitTermination()
+ // UPDATE: In Akka 2.1.x, this hangs if there are remote actors, so we can't call it.
+ //actorSystem.awaitTermination()
}
+ private[spark]
def createPythonWorker(pythonExec: String, envVars: Map[String, String]): java.net.Socket = {
synchronized {
val key = (pythonExec, envVars)
@@ -105,33 +107,35 @@ object SparkEnv extends Logging {
/**
* Returns the ThreadLocal SparkEnv.
*/
- def getThreadLocal : SparkEnv = {
+ def getThreadLocal: SparkEnv = {
env.get()
}
- def createFromSystemProperties(
+ private[spark] def create(
+ conf: SparkConf,
executorId: String,
hostname: String,
port: Int,
isDriver: Boolean,
isLocal: Boolean): SparkEnv = {
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, port)
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, port,
+ conf = conf)
// Bit of a hack: If this is the driver and our port was 0 (meaning bind to any free port),
// figure out which port number Akka actually bound to and set spark.driver.port to it.
if (isDriver && port == 0) {
- System.setProperty("spark.driver.port", boundPort.toString)
+ conf.set("spark.driver.port", boundPort.toString)
}
// set only if unset until now.
- if (System.getProperty("spark.hostPort", null) == null) {
+ if (!conf.contains("spark.hostPort")) {
if (!isDriver){
// unexpected
Utils.logErrorWithStack("Unexpected NOT to have spark.hostPort set")
}
Utils.checkHost(hostname)
- System.setProperty("spark.hostPort", hostname + ":" + boundPort)
+ conf.set("spark.hostPort", hostname + ":" + boundPort)
}
val classLoader = Thread.currentThread.getContextClassLoader
@@ -139,49 +143,51 @@ object SparkEnv extends Logging {
// Create an instance of the class named by the given Java system property, or by
// defaultClassName if the property is not set, and return it as a T
def instantiateClass[T](propertyName: String, defaultClassName: String): T = {
- val name = System.getProperty(propertyName, defaultClassName)
+ val name = conf.get(propertyName, defaultClassName)
Class.forName(name, true, classLoader).newInstance().asInstanceOf[T]
}
val serializerManager = new SerializerManager
val serializer = serializerManager.setDefault(
- System.getProperty("spark.serializer", "org.apache.spark.serializer.JavaSerializer"))
+ conf.get("spark.serializer", "org.apache.spark.serializer.JavaSerializer"), conf)
val closureSerializer = serializerManager.get(
- System.getProperty("spark.closure.serializer", "org.apache.spark.serializer.JavaSerializer"))
+ conf.get("spark.closure.serializer", "org.apache.spark.serializer.JavaSerializer"),
+ conf)
def registerOrLookup(name: String, newActor: => Actor): ActorRef = {
if (isDriver) {
logInfo("Registering " + name)
actorSystem.actorOf(Props(newActor), name = name)
} else {
- val driverHost: String = System.getProperty("spark.driver.host", "localhost")
- val driverPort: Int = System.getProperty("spark.driver.port", "7077").toInt
+ val driverHost: String = conf.get("spark.driver.host", "localhost")
+ val driverPort: Int = conf.get("spark.driver.port", "7077").toInt
Utils.checkHost(driverHost, "Expected hostname")
- val url = "akka://spark@%s:%s/user/%s".format(driverHost, driverPort, name)
- logInfo("Connecting to " + name + ": " + url)
- actorSystem.actorFor(url)
+ val url = s"akka.tcp://spark@$driverHost:$driverPort/user/$name"
+ val timeout = AkkaUtils.lookupTimeout(conf)
+ logInfo(s"Connecting to $name: $url")
+ Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout)
}
}
val blockManagerMaster = new BlockManagerMaster(registerOrLookup(
"BlockManagerMaster",
- new BlockManagerMasterActor(isLocal)))
- val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, serializer)
+ new BlockManagerMasterActor(isLocal, conf)), conf)
+ val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, serializer, conf)
val connectionManager = blockManager.connectionManager
- val broadcastManager = new BroadcastManager(isDriver)
+ val broadcastManager = new BroadcastManager(isDriver, conf)
val cacheManager = new CacheManager(blockManager)
// Have to assign trackerActor after initialization as MapOutputTrackerActor
// requires the MapOutputTracker itself
val mapOutputTracker = if (isDriver) {
- new MapOutputTrackerMaster()
+ new MapOutputTrackerMaster(conf)
} else {
- new MapOutputTracker()
+ new MapOutputTracker(conf)
}
mapOutputTracker.trackerActor = registerOrLookup(
"MapOutputTracker",
@@ -192,12 +198,12 @@ object SparkEnv extends Logging {
val httpFileServer = new HttpFileServer()
httpFileServer.initialize()
- System.setProperty("spark.fileserver.uri", httpFileServer.serverUri)
+ conf.set("spark.fileserver.uri", httpFileServer.serverUri)
val metricsSystem = if (isDriver) {
- MetricsSystem.createMetricsSystem("driver")
+ MetricsSystem.createMetricsSystem("driver", conf)
} else {
- MetricsSystem.createMetricsSystem("executor")
+ MetricsSystem.createMetricsSystem("executor", conf)
}
metricsSystem.start()
@@ -211,7 +217,7 @@ object SparkEnv extends Logging {
}
// Warn about deprecated spark.cache.class property
- if (System.getProperty("spark.cache.class") != null) {
+ if (conf.contains("spark.cache.class")) {
logWarning("The spark.cache.class property is no longer being used! Specify storage " +
"levels using the RDD.persist() method instead.")
}
@@ -230,6 +236,7 @@ object SparkEnv extends Logging {
connectionManager,
httpFileServer,
sparkFilesDir,
- metricsSystem)
+ metricsSystem,
+ conf)
}
}
diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
index 103a1c2051..618d95015f 100644
--- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
+++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
@@ -127,10 +127,6 @@ class SparkHadoopWriter(@transient jobConf: JobConf)
cmtr.commitJob(getJobContext())
}
- def cleanup() {
- getOutputCommitter().cleanupJob(getJobContext())
- }
-
// ********* Private Functions *********
private def getOutputFormat(): OutputFormat[AnyRef,AnyRef] = {
diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
index c1e5e04b31..faf6dcd618 100644
--- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala
+++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
@@ -53,5 +53,3 @@ private[spark] case class ExceptionFailure(
private[spark] case object TaskResultLost extends TaskEndReason
private[spark] case object TaskKilled extends TaskEndReason
-
-private[spark] case class OtherFailure(message: String) extends TaskEndReason
diff --git a/core/src/main/scala/org/apache/spark/TaskState.scala b/core/src/main/scala/org/apache/spark/TaskState.scala
index 19ce8369d9..0bf1e4a5e2 100644
--- a/core/src/main/scala/org/apache/spark/TaskState.scala
+++ b/core/src/main/scala/org/apache/spark/TaskState.scala
@@ -19,8 +19,7 @@ package org.apache.spark
import org.apache.mesos.Protos.{TaskState => MesosTaskState}
-private[spark] object TaskState
- extends Enumeration("LAUNCHING", "RUNNING", "FINISHED", "FAILED", "KILLED", "LOST") {
+private[spark] object TaskState extends Enumeration {
val LAUNCHING, RUNNING, FINISHED, FAILED, KILLED, LOST = Value
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
index 043cb183ba..da30cf619a 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
@@ -17,18 +17,23 @@
package org.apache.spark.api.java
+import scala.reflect.ClassTag
+
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext.doubleRDDToDoubleRDDFunctions
import org.apache.spark.api.java.function.{Function => JFunction}
import org.apache.spark.util.StatCounter
import org.apache.spark.partial.{BoundedDouble, PartialResult}
import org.apache.spark.storage.StorageLevel
+
import java.lang.Double
import org.apache.spark.Partitioner
+import scala.collection.JavaConverters._
+
class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, JavaDoubleRDD] {
- override val classManifest: ClassManifest[Double] = implicitly[ClassManifest[Double]]
+ override val classTag: ClassTag[Double] = implicitly[ClassTag[Double]]
override val rdd: RDD[Double] = srdd.map(x => Double.valueOf(x))
@@ -42,7 +47,7 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav
/** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
def cache(): JavaDoubleRDD = fromRDD(srdd.cache())
- /**
+ /**
* Set this RDD's storage level to persist its values across operations after the first time
* it is computed. Can only be called once on each RDD.
*/
@@ -106,7 +111,7 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav
/**
* Return an RDD with the elements from `this` that are not in `other`.
- *
+ *
* Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
* RDD will be <= us.
*/
@@ -182,6 +187,44 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav
/** (Experimental) Approximate operation to return the sum within a timeout. */
def sumApprox(timeout: Long): PartialResult[BoundedDouble] = srdd.sumApprox(timeout)
+
+ /**
+ * Compute a histogram of the data using bucketCount number of buckets evenly
+ * spaced between the minimum and maximum of the RDD. For example if the min
+ * value is 0 and the max is 100 and there are two buckets the resulting
+ * buckets will be [0,50) [50,100]. bucketCount must be at least 1
+ * If the RDD contains infinity, NaN throws an exception
+ * If the elements in RDD do not vary (max == min) always returns a single bucket.
+ */
+ def histogram(bucketCount: Int): Pair[Array[scala.Double], Array[Long]] = {
+ val result = srdd.histogram(bucketCount)
+ (result._1, result._2)
+ }
+
+ /**
+ * Compute a histogram using the provided buckets. The buckets are all open
+ * to the left except for the last which is closed
+ * e.g. for the array
+ * [1,10,20,50] the buckets are [1,10) [10,20) [20,50]
+ * e.g 1<=x<10 , 10<=x<20, 20<=x<50
+ * And on the input of 1 and 50 we would have a histogram of 1,0,0
+ *
+ * Note: if your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched
+ * from an O(log n) inseration to O(1) per element. (where n = # buckets) if you set evenBuckets
+ * to true.
+ * buckets must be sorted and not contain any duplicates.
+ * buckets array must be at least two elements
+ * All NaN entries are treated the same. If you have a NaN bucket it must be
+ * the maximum value of the last position and all NaN entries will be counted
+ * in that bucket.
+ */
+ def histogram(buckets: Array[scala.Double]): Array[Long] = {
+ srdd.histogram(buckets, false)
+ }
+
+ def histogram(buckets: Array[Double], evenBuckets: Boolean): Array[Long] = {
+ srdd.histogram(buckets.map(_.toDouble), evenBuckets)
+ }
}
object JavaDoubleRDD {
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
index 2142fd7327..55c87450ac 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
@@ -22,6 +22,7 @@ import java.util.Comparator
import scala.Tuple2
import scala.collection.JavaConversions._
+import scala.reflect.ClassTag
import com.google.common.base.Optional
import org.apache.hadoop.io.compress.CompressionCodec
@@ -43,13 +44,13 @@ import org.apache.spark.rdd.OrderedRDDFunctions
import org.apache.spark.storage.StorageLevel
-class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManifest[K],
- implicit val vManifest: ClassManifest[V]) extends JavaRDDLike[(K, V), JavaPairRDD[K, V]] {
+class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kClassTag: ClassTag[K],
+ implicit val vClassTag: ClassTag[V]) extends JavaRDDLike[(K, V), JavaPairRDD[K, V]] {
override def wrapRDD(rdd: RDD[(K, V)]): JavaPairRDD[K, V] = JavaPairRDD.fromRDD(rdd)
- override val classManifest: ClassManifest[(K, V)] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K, V]]]
+ override val classTag: ClassTag[(K, V)] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[Tuple2[K, V]]]
import JavaPairRDD._
@@ -58,7 +59,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
/** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
def cache(): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.cache())
- /**
+ /**
* Set this RDD's storage level to persist its values across operations after the first time
* it is computed. Can only be called once on each RDD.
*/
@@ -138,14 +139,14 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
override def first(): (K, V) = rdd.first()
// Pair RDD functions
-
+
/**
- * Generic function to combine the elements for each key using a custom set of aggregation
- * functions. Turns a JavaPairRDD[(K, V)] into a result of type JavaPairRDD[(K, C)], for a
- * "combined type" C * Note that V and C can be different -- for example, one might group an
- * RDD of type (Int, Int) into an RDD of type (Int, List[Int]). Users provide three
+ * Generic function to combine the elements for each key using a custom set of aggregation
+ * functions. Turns a JavaPairRDD[(K, V)] into a result of type JavaPairRDD[(K, C)], for a
+ * "combined type" C * Note that V and C can be different -- for example, one might group an
+ * RDD of type (Int, Int) into an RDD of type (Int, List[Int]). Users provide three
* functions:
- *
+ *
* - `createCombiner`, which turns a V into a C (e.g., creates a one-element list)
* - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list)
* - `mergeCombiners`, to combine two C's into a single one.
@@ -157,8 +158,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
mergeValue: JFunction2[C, V, C],
mergeCombiners: JFunction2[C, C, C],
partitioner: Partitioner): JavaPairRDD[K, C] = {
- implicit val cm: ClassManifest[C] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[C]]
+ implicit val cm: ClassTag[C] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[C]]
fromRDD(rdd.combineByKey(
createCombiner,
mergeValue,
@@ -195,14 +195,14 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
/** Count the number of elements for each key, and return the result to the master as a Map. */
def countByKey(): java.util.Map[K, Long] = mapAsJavaMap(rdd.countByKey())
- /**
+ /**
* (Experimental) Approximate version of countByKey that can return a partial result if it does
* not finish within a timeout.
*/
def countByKeyApprox(timeout: Long): PartialResult[java.util.Map[K, BoundedDouble]] =
rdd.countByKeyApprox(timeout).map(mapAsJavaMap)
- /**
+ /**
* (Experimental) Approximate version of countByKey that can return a partial result if it does
* not finish within a timeout.
*/
@@ -258,7 +258,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
/**
* Return an RDD with the elements from `this` that are not in `other`.
- *
+ *
* Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
* RDD will be <= us.
*/
@@ -315,15 +315,14 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
fromRDD(joinResult.mapValues{case (v, w) => (JavaUtils.optionToOptional(v), w)})
}
- /**
+ /**
* Simplified version of combineByKey that hash-partitions the resulting RDD using the existing
* partitioner/parallelism level.
*/
def combineByKey[C](createCombiner: JFunction[V, C],
mergeValue: JFunction2[C, V, C],
mergeCombiners: JFunction2[C, C, C]): JavaPairRDD[K, C] = {
- implicit val cm: ClassManifest[C] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[C]]
+ implicit val cm: ClassTag[C] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[C]]
fromRDD(combineByKey(createCombiner, mergeValue, mergeCombiners, defaultPartitioner(rdd)))
}
@@ -414,8 +413,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
* this also retains the original RDD's partitioning.
*/
def mapValues[U](f: JFunction[V, U]): JavaPairRDD[K, U] = {
- implicit val cm: ClassManifest[U] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[U]]
+ implicit val cm: ClassTag[U] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[U]]
fromRDD(rdd.mapValues(f))
}
@@ -426,8 +424,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
def flatMapValues[U](f: JFunction[V, java.lang.Iterable[U]]): JavaPairRDD[K, U] = {
import scala.collection.JavaConverters._
def fn = (x: V) => f.apply(x).asScala
- implicit val cm: ClassManifest[U] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[U]]
+ implicit val cm: ClassTag[U] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[U]]
fromRDD(rdd.flatMapValues(fn))
}
@@ -592,6 +589,20 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
}
/**
+ * Sort the RDD by key, so that each partition contains a sorted range of the elements. Calling
+ * `collect` or `save` on the resulting RDD will return or output an ordered list of records
+ * (in the `save` case, they will be written to multiple `part-X` files in the filesystem, in
+ * order of the keys).
+ */
+ def sortByKey(comp: Comparator[K], ascending: Boolean, numPartitions: Int): JavaPairRDD[K, V] = {
+ class KeyOrdering(val a: K) extends Ordered[K] {
+ override def compare(b: K) = comp.compare(a, b)
+ }
+ implicit def toOrdered(x: K): Ordered[K] = new KeyOrdering(x)
+ fromRDD(new OrderedRDDFunctions[K, V, (K, V)](rdd).sortByKey(ascending, numPartitions))
+ }
+
+ /**
* Return an RDD with the keys of each tuple.
*/
def keys(): JavaRDD[K] = JavaRDD.fromRDD[K](rdd.map(_._1))
@@ -600,25 +611,61 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
* Return an RDD with the values of each tuple.
*/
def values(): JavaRDD[V] = JavaRDD.fromRDD[V](rdd.map(_._2))
+
+ /**
+ * Return approximate number of distinct values for each key in this RDD.
+ * The accuracy of approximation can be controlled through the relative standard deviation
+ * (relativeSD) parameter, which also controls the amount of memory used. Lower values result in
+ * more accurate counts but increase the memory footprint and vise versa. Uses the provided
+ * Partitioner to partition the output RDD.
+ */
+ def countApproxDistinctByKey(relativeSD: Double, partitioner: Partitioner): JavaRDD[(K, Long)] = {
+ rdd.countApproxDistinctByKey(relativeSD, partitioner)
+ }
+
+ /**
+ * Return approximate number of distinct values for each key this RDD.
+ * The accuracy of approximation can be controlled through the relative standard deviation
+ * (relativeSD) parameter, which also controls the amount of memory used. Lower values result in
+ * more accurate counts but increase the memory footprint and vise versa. The default value of
+ * relativeSD is 0.05. Hash-partitions the output RDD using the existing partitioner/parallelism
+ * level.
+ */
+ def countApproxDistinctByKey(relativeSD: Double = 0.05): JavaRDD[(K, Long)] = {
+ rdd.countApproxDistinctByKey(relativeSD)
+ }
+
+
+ /**
+ * Return approximate number of distinct values for each key in this RDD.
+ * The accuracy of approximation can be controlled through the relative standard deviation
+ * (relativeSD) parameter, which also controls the amount of memory used. Lower values result in
+ * more accurate counts but increase the memory footprint and vise versa. HashPartitions the
+ * output RDD into numPartitions.
+ *
+ */
+ def countApproxDistinctByKey(relativeSD: Double, numPartitions: Int): JavaRDD[(K, Long)] = {
+ rdd.countApproxDistinctByKey(relativeSD, numPartitions)
+ }
}
object JavaPairRDD {
- def groupByResultToJava[K, T](rdd: RDD[(K, Seq[T])])(implicit kcm: ClassManifest[K],
- vcm: ClassManifest[T]): RDD[(K, JList[T])] =
+ def groupByResultToJava[K, T](rdd: RDD[(K, Seq[T])])(implicit kcm: ClassTag[K],
+ vcm: ClassTag[T]): RDD[(K, JList[T])] =
rddToPairRDDFunctions(rdd).mapValues(seqAsJavaList _)
- def cogroupResultToJava[W, K, V](rdd: RDD[(K, (Seq[V], Seq[W]))])(implicit kcm: ClassManifest[K],
- vcm: ClassManifest[V]): RDD[(K, (JList[V], JList[W]))] = rddToPairRDDFunctions(rdd).mapValues((x: (Seq[V],
- Seq[W])) => (seqAsJavaList(x._1), seqAsJavaList(x._2)))
+ def cogroupResultToJava[W, K, V](rdd: RDD[(K, (Seq[V], Seq[W]))])(implicit kcm: ClassTag[K],
+ vcm: ClassTag[V]): RDD[(K, (JList[V], JList[W]))] = rddToPairRDDFunctions(rdd)
+ .mapValues((x: (Seq[V], Seq[W])) => (seqAsJavaList(x._1), seqAsJavaList(x._2)))
def cogroupResult2ToJava[W1, W2, K, V](rdd: RDD[(K, (Seq[V], Seq[W1],
- Seq[W2]))])(implicit kcm: ClassManifest[K]) : RDD[(K, (JList[V], JList[W1],
+ Seq[W2]))])(implicit kcm: ClassTag[K]) : RDD[(K, (JList[V], JList[W1],
JList[W2]))] = rddToPairRDDFunctions(rdd).mapValues(
(x: (Seq[V], Seq[W1], Seq[W2])) => (seqAsJavaList(x._1),
seqAsJavaList(x._2),
seqAsJavaList(x._3)))
- def fromRDD[K: ClassManifest, V: ClassManifest](rdd: RDD[(K, V)]): JavaPairRDD[K, V] =
+ def fromRDD[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)]): JavaPairRDD[K, V] =
new JavaPairRDD[K, V](rdd)
implicit def toRDD[K, V](rdd: JavaPairRDD[K, V]): RDD[(K, V)] = rdd.rdd
@@ -626,10 +673,8 @@ object JavaPairRDD {
/** Convert a JavaRDD of key-value pairs to JavaPairRDD. */
def fromJavaRDD[K, V](rdd: JavaRDD[(K, V)]): JavaPairRDD[K, V] = {
- implicit val cmk: ClassManifest[K] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]]
- implicit val cmv: ClassManifest[V] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[V]]
+ implicit val cmk: ClassTag[K] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K]]
+ implicit val cmv: ClassTag[V] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V]]
new JavaPairRDD[K, V](rdd.rdd)
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
index 3b359a8fd6..037cd1c774 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
@@ -17,12 +17,14 @@
package org.apache.spark.api.java
+import scala.reflect.ClassTag
+
import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.api.java.function.{Function => JFunction}
import org.apache.spark.storage.StorageLevel
-class JavaRDD[T](val rdd: RDD[T])(implicit val classManifest: ClassManifest[T]) extends
+class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) extends
JavaRDDLike[T, JavaRDD[T]] {
override def wrapRDD(rdd: RDD[T]): JavaRDD[T] = JavaRDD.fromRDD(rdd)
@@ -123,12 +125,13 @@ JavaRDDLike[T, JavaRDD[T]] {
*/
def subtract(other: JavaRDD[T], p: Partitioner): JavaRDD[T] =
wrapRDD(rdd.subtract(other, p))
+
+ override def toString = rdd.toString
}
object JavaRDD {
- implicit def fromRDD[T: ClassManifest](rdd: RDD[T]): JavaRDD[T] = new JavaRDD[T](rdd)
+ implicit def fromRDD[T: ClassTag](rdd: RDD[T]): JavaRDD[T] = new JavaRDD[T](rdd)
implicit def toRDD[T](rdd: JavaRDD[T]): RDD[T] = rdd.rdd
}
-
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
index 7a3568c5ef..924d8af060 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
@@ -20,6 +20,7 @@ package org.apache.spark.api.java
import java.util.{List => JList, Comparator}
import scala.Tuple2
import scala.collection.JavaConversions._
+import scala.reflect.ClassTag
import com.google.common.base.Optional
import org.apache.hadoop.io.compress.CompressionCodec
@@ -35,7 +36,7 @@ import org.apache.spark.storage.StorageLevel
trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
def wrapRDD(rdd: RDD[T]): This
- implicit val classManifest: ClassManifest[T]
+ implicit val classTag: ClassTag[T]
def rdd: RDD[T]
@@ -71,7 +72,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
* Return a new RDD by applying a function to each partition of this RDD, while tracking the index
* of the original partition.
*/
- def mapPartitionsWithIndex[R: ClassManifest](
+ def mapPartitionsWithIndex[R: ClassTag](
f: JFunction2[Int, java.util.Iterator[T], java.util.Iterator[R]],
preservesPartitioning: Boolean = false): JavaRDD[R] =
new JavaRDD(rdd.mapPartitionsWithIndex(((a,b) => f(a,asJavaIterator(b))),
@@ -87,7 +88,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
* Return a new RDD by applying a function to all elements of this RDD.
*/
def map[K2, V2](f: PairFunction[T, K2, V2]): JavaPairRDD[K2, V2] = {
- def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K2, V2]]]
+ def cm = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[Tuple2[K2, V2]]]
new JavaPairRDD(rdd.map(f)(cm))(f.keyType(), f.valueType())
}
@@ -118,7 +119,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
def flatMap[K2, V2](f: PairFlatMapFunction[T, K2, V2]): JavaPairRDD[K2, V2] = {
import scala.collection.JavaConverters._
def fn = (x: T) => f.apply(x).asScala
- def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K2, V2]]]
+ def cm = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[Tuple2[K2, V2]]]
JavaPairRDD.fromRDD(rdd.flatMap(fn)(cm))(f.keyType(), f.valueType())
}
@@ -158,18 +159,16 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
* elements (a, b) where a is in `this` and b is in `other`.
*/
def cartesian[U](other: JavaRDDLike[U, _]): JavaPairRDD[T, U] =
- JavaPairRDD.fromRDD(rdd.cartesian(other.rdd)(other.classManifest))(classManifest,
- other.classManifest)
+ JavaPairRDD.fromRDD(rdd.cartesian(other.rdd)(other.classTag))(classTag, other.classTag)
/**
* Return an RDD of grouped elements. Each group consists of a key and a sequence of elements
* mapping to that key.
*/
def groupBy[K](f: JFunction[T, K]): JavaPairRDD[K, JList[T]] = {
- implicit val kcm: ClassManifest[K] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]]
- implicit val vcm: ClassManifest[JList[T]] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[JList[T]]]
+ implicit val kcm: ClassTag[K] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K]]
+ implicit val vcm: ClassTag[JList[T]] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[JList[T]]]
JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f)(f.returnType)))(kcm, vcm)
}
@@ -178,10 +177,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
* mapping to that key.
*/
def groupBy[K](f: JFunction[T, K], numPartitions: Int): JavaPairRDD[K, JList[T]] = {
- implicit val kcm: ClassManifest[K] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]]
- implicit val vcm: ClassManifest[JList[T]] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[JList[T]]]
+ implicit val kcm: ClassTag[K] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K]]
+ implicit val vcm: ClassTag[JList[T]] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[JList[T]]]
JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f, numPartitions)(f.returnType)))(kcm, vcm)
}
@@ -209,7 +207,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
* a map on the other).
*/
def zip[U](other: JavaRDDLike[U, _]): JavaPairRDD[T, U] = {
- JavaPairRDD.fromRDD(rdd.zip(other.rdd)(other.classManifest))(classManifest, other.classManifest)
+ JavaPairRDD.fromRDD(rdd.zip(other.rdd)(other.classTag))(classTag, other.classTag)
}
/**
@@ -224,7 +222,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
def fn = (x: Iterator[T], y: Iterator[U]) => asScalaIterator(
f.apply(asJavaIterator(x), asJavaIterator(y)).iterator())
JavaRDD.fromRDD(
- rdd.zipPartitions(other.rdd)(fn)(other.classManifest, f.elementType()))(f.elementType())
+ rdd.zipPartitions(other.rdd)(fn)(other.classTag, f.elementType()))(f.elementType())
}
// Actions (launch a job to return a value to the user program)
@@ -247,6 +245,17 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
}
/**
+ * Return an array that contains all of the elements in a specific partition of this RDD.
+ */
+ def collectPartitions(partitionIds: Array[Int]): Array[JList[T]] = {
+ // This is useful for implementing `take` from other language frontends
+ // like Python where the data is serialized.
+ import scala.collection.JavaConversions._
+ val res = context.runJob(rdd, (it: Iterator[T]) => it.toArray, partitionIds, true)
+ res.map(x => new java.util.ArrayList(x.toSeq)).toArray
+ }
+
+ /**
* Reduces the elements of this RDD using the specified commutative and associative binary operator.
*/
def reduce(f: JFunction2[T, T, T]): T = rdd.reduce(f)
@@ -356,7 +365,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
* Creates tuples of the elements in this RDD by applying `f`.
*/
def keyBy[K](f: JFunction[T, K]): JavaPairRDD[K, T] = {
- implicit val kcm: ClassManifest[K] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]]
+ implicit val kcm: ClassTag[K] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K]]
JavaPairRDD.fromRDD(rdd.keyBy(f))
}
@@ -435,4 +444,15 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
val comp = com.google.common.collect.Ordering.natural().asInstanceOf[Comparator[T]]
takeOrdered(num, comp)
}
+
+ /**
+ * Return approximate number of distinct elements in the RDD.
+ *
+ * The accuracy of approximation can be controlled through the relative standard deviation
+ * (relativeSD) parameter, which also controls the amount of memory used. Lower values result in
+ * more accurate counts but increase the memory footprint and vise versa. The default value of
+ * relativeSD is 0.05.
+ */
+ def countApproxDistinct(relativeSD: Double = 0.05): Long = rdd.countApproxDistinct(relativeSD)
+
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
index 8869e072bf..e93b10fd7e 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
@@ -21,6 +21,7 @@ import java.util.{Map => JMap}
import scala.collection.JavaConversions
import scala.collection.JavaConversions._
+import scala.reflect.ClassTag
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapred.InputFormat
@@ -28,17 +29,22 @@ import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
import com.google.common.base.Optional
-import org.apache.spark.{Accumulable, AccumulableParam, Accumulator, AccumulatorParam, SparkContext}
+import org.apache.spark._
import org.apache.spark.SparkContext.IntAccumulatorParam
import org.apache.spark.SparkContext.DoubleAccumulatorParam
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
+import scala.Tuple2
/**
* A Java-friendly version of [[org.apache.spark.SparkContext]] that returns [[org.apache.spark.api.java.JavaRDD]]s and
* works with Java collections instead of Scala ones.
*/
class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWorkaround {
+ /**
+ * @param conf a [[org.apache.spark.SparkConf]] object specifying Spark parameters
+ */
+ def this(conf: SparkConf) = this(new SparkContext(conf))
/**
* @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
@@ -49,6 +55,14 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
/**
* @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
* @param appName A name for your application, to display on the cluster web UI
+ * @param conf a [[org.apache.spark.SparkConf]] object specifying other Spark parameters
+ */
+ def this(master: String, appName: String, conf: SparkConf) =
+ this(conf.setMaster(master).setAppName(appName))
+
+ /**
+ * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
+ * @param appName A name for your application, to display on the cluster web UI
* @param sparkHome The SPARK_HOME directory on the slave nodes
* @param jarFile JAR file to send to the cluster. This can be a path on the local file system
* or an HDFS, HTTP, HTTPS, or FTP URL.
@@ -82,8 +96,7 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
/** Distribute a local Scala collection to form an RDD. */
def parallelize[T](list: java.util.List[T], numSlices: Int): JavaRDD[T] = {
- implicit val cm: ClassManifest[T] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
+ implicit val cm: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]]
sc.parallelize(JavaConversions.asScalaBuffer(list), numSlices)
}
@@ -94,10 +107,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
/** Distribute a local Scala collection to form an RDD. */
def parallelizePairs[K, V](list: java.util.List[Tuple2[K, V]], numSlices: Int)
: JavaPairRDD[K, V] = {
- implicit val kcm: ClassManifest[K] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]]
- implicit val vcm: ClassManifest[V] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[V]]
+ implicit val kcm: ClassTag[K] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K]]
+ implicit val vcm: ClassTag[V] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V]]
JavaPairRDD.fromRDD(sc.parallelize(JavaConversions.asScalaBuffer(list), numSlices))
}
@@ -132,16 +143,16 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
valueClass: Class[V],
minSplits: Int
): JavaPairRDD[K, V] = {
- implicit val kcm = ClassManifest.fromClass(keyClass)
- implicit val vcm = ClassManifest.fromClass(valueClass)
+ implicit val kcm: ClassTag[K] = ClassTag(keyClass)
+ implicit val vcm: ClassTag[V] = ClassTag(valueClass)
new JavaPairRDD(sc.sequenceFile(path, keyClass, valueClass, minSplits))
}
/**Get an RDD for a Hadoop SequenceFile. */
def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V]):
JavaPairRDD[K, V] = {
- implicit val kcm = ClassManifest.fromClass(keyClass)
- implicit val vcm = ClassManifest.fromClass(valueClass)
+ implicit val kcm: ClassTag[K] = ClassTag(keyClass)
+ implicit val vcm: ClassTag[V] = ClassTag(valueClass)
new JavaPairRDD(sc.sequenceFile(path, keyClass, valueClass))
}
@@ -153,8 +164,7 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
* that there's very little effort required to save arbitrary objects.
*/
def objectFile[T](path: String, minSplits: Int): JavaRDD[T] = {
- implicit val cm: ClassManifest[T] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
+ implicit val cm: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]]
sc.objectFile(path, minSplits)(cm)
}
@@ -166,8 +176,7 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
* that there's very little effort required to save arbitrary objects.
*/
def objectFile[T](path: String): JavaRDD[T] = {
- implicit val cm: ClassManifest[T] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
+ implicit val cm: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]]
sc.objectFile(path)(cm)
}
@@ -183,8 +192,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
valueClass: Class[V],
minSplits: Int
): JavaPairRDD[K, V] = {
- implicit val kcm = ClassManifest.fromClass(keyClass)
- implicit val vcm = ClassManifest.fromClass(valueClass)
+ implicit val kcm: ClassTag[K] = ClassTag(keyClass)
+ implicit val vcm: ClassTag[V] = ClassTag(valueClass)
new JavaPairRDD(sc.hadoopRDD(conf, inputFormatClass, keyClass, valueClass, minSplits))
}
@@ -199,8 +208,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
keyClass: Class[K],
valueClass: Class[V]
): JavaPairRDD[K, V] = {
- implicit val kcm = ClassManifest.fromClass(keyClass)
- implicit val vcm = ClassManifest.fromClass(valueClass)
+ implicit val kcm: ClassTag[K] = ClassTag(keyClass)
+ implicit val vcm: ClassTag[V] = ClassTag(valueClass)
new JavaPairRDD(sc.hadoopRDD(conf, inputFormatClass, keyClass, valueClass))
}
@@ -212,8 +221,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
valueClass: Class[V],
minSplits: Int
): JavaPairRDD[K, V] = {
- implicit val kcm = ClassManifest.fromClass(keyClass)
- implicit val vcm = ClassManifest.fromClass(valueClass)
+ implicit val kcm: ClassTag[K] = ClassTag(keyClass)
+ implicit val vcm: ClassTag[V] = ClassTag(valueClass)
new JavaPairRDD(sc.hadoopFile(path, inputFormatClass, keyClass, valueClass, minSplits))
}
@@ -224,8 +233,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
keyClass: Class[K],
valueClass: Class[V]
): JavaPairRDD[K, V] = {
- implicit val kcm = ClassManifest.fromClass(keyClass)
- implicit val vcm = ClassManifest.fromClass(valueClass)
+ implicit val kcm: ClassTag[K] = ClassTag(keyClass)
+ implicit val vcm: ClassTag[V] = ClassTag(valueClass)
new JavaPairRDD(sc.hadoopFile(path,
inputFormatClass, keyClass, valueClass))
}
@@ -240,8 +249,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
kClass: Class[K],
vClass: Class[V],
conf: Configuration): JavaPairRDD[K, V] = {
- implicit val kcm = ClassManifest.fromClass(kClass)
- implicit val vcm = ClassManifest.fromClass(vClass)
+ implicit val kcm: ClassTag[K] = ClassTag(kClass)
+ implicit val vcm: ClassTag[V] = ClassTag(vClass)
new JavaPairRDD(sc.newAPIHadoopFile(path, fClass, kClass, vClass, conf))
}
@@ -254,15 +263,15 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
fClass: Class[F],
kClass: Class[K],
vClass: Class[V]): JavaPairRDD[K, V] = {
- implicit val kcm = ClassManifest.fromClass(kClass)
- implicit val vcm = ClassManifest.fromClass(vClass)
+ implicit val kcm: ClassTag[K] = ClassTag(kClass)
+ implicit val vcm: ClassTag[V] = ClassTag(vClass)
new JavaPairRDD(sc.newAPIHadoopRDD(conf, fClass, kClass, vClass))
}
/** Build the union of two or more RDDs. */
override def union[T](first: JavaRDD[T], rest: java.util.List[JavaRDD[T]]): JavaRDD[T] = {
val rdds: Seq[RDD[T]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.rdd)
- implicit val cm: ClassManifest[T] = first.classManifest
+ implicit val cm: ClassTag[T] = first.classTag
sc.union(rdds)(cm)
}
@@ -270,9 +279,9 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
override def union[K, V](first: JavaPairRDD[K, V], rest: java.util.List[JavaPairRDD[K, V]])
: JavaPairRDD[K, V] = {
val rdds: Seq[RDD[(K, V)]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.rdd)
- implicit val cm: ClassManifest[(K, V)] = first.classManifest
- implicit val kcm: ClassManifest[K] = first.kManifest
- implicit val vcm: ClassManifest[V] = first.vManifest
+ implicit val cm: ClassTag[(K, V)] = first.classTag
+ implicit val kcm: ClassTag[K] = first.kClassTag
+ implicit val vcm: ClassTag[V] = first.vClassTag
new JavaPairRDD(sc.union(rdds)(cm))(kcm, vcm)
}
@@ -385,34 +394,47 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
/**
* Set the directory under which RDDs are going to be checkpointed. The directory must
- * be a HDFS path if running on a cluster. If the directory does not exist, it will
- * be created. If the directory exists and useExisting is set to true, then the
- * exisiting directory will be used. Otherwise an exception will be thrown to
- * prevent accidental overriding of checkpoint files in the existing directory.
- */
- def setCheckpointDir(dir: String, useExisting: Boolean) {
- sc.setCheckpointDir(dir, useExisting)
- }
-
- /**
- * Set the directory under which RDDs are going to be checkpointed. The directory must
- * be a HDFS path if running on a cluster. If the directory does not exist, it will
- * be created. If the directory exists, an exception will be thrown to prevent accidental
- * overriding of checkpoint files.
+ * be a HDFS path if running on a cluster.
*/
def setCheckpointDir(dir: String) {
sc.setCheckpointDir(dir)
}
protected def checkpointFile[T](path: String): JavaRDD[T] = {
- implicit val cm: ClassManifest[T] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
+ implicit val cm: ClassTag[T] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]]
new JavaRDD(sc.checkpointFile(path))
}
+
+ /**
+ * Return a copy of this JavaSparkContext's configuration. The configuration ''cannot'' be
+ * changed at runtime.
+ */
+ def getConf: SparkConf = sc.getConf
+
+ /**
+ * Pass-through to SparkContext.setCallSite. For API support only.
+ */
+ def setCallSite(site: String) {
+ sc.setCallSite(site)
+ }
+
+ /**
+ * Pass-through to SparkContext.setCallSite. For API support only.
+ */
+ def clearCallSite() {
+ sc.clearCallSite()
+ }
}
object JavaSparkContext {
implicit def fromSparkContext(sc: SparkContext): JavaSparkContext = new JavaSparkContext(sc)
implicit def toSparkContext(jsc: JavaSparkContext): SparkContext = jsc.sc
+
+ /**
+ * Find the JAR from which a given class was loaded, to make it easy for users to pass
+ * their JARs to SparkContext.
+ */
+ def jarOfClass(cls: Class[_]) = SparkContext.jarOfClass(cls).toArray
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContextVarargsWorkaround.java b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContextVarargsWorkaround.java
index c9cbce5624..2090efd3b9 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContextVarargsWorkaround.java
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContextVarargsWorkaround.java
@@ -17,7 +17,6 @@
package org.apache.spark.api.java;
-import java.util.Arrays;
import java.util.ArrayList;
import java.util.List;
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala b/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala
index 2dfda8b09a..bdb01f7670 100644
--- a/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala
@@ -17,9 +17,11 @@
package org.apache.spark.api.java.function
+import scala.reflect.ClassTag
+
/**
* A function that returns zero or more output records from each input record.
*/
abstract class FlatMapFunction[T, R] extends Function[T, java.lang.Iterable[R]] {
- def elementType() : ClassManifest[R] = ClassManifest.Any.asInstanceOf[ClassManifest[R]]
+ def elementType(): ClassTag[R] = ClassTag.Any.asInstanceOf[ClassTag[R]]
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala b/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala
index 528e1c0a7c..aae1349c5e 100644
--- a/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala
@@ -17,9 +17,11 @@
package org.apache.spark.api.java.function
+import scala.reflect.ClassTag
+
/**
* A function that takes two inputs and returns zero or more output records.
*/
abstract class FlatMapFunction2[A, B, C] extends Function2[A, B, java.lang.Iterable[C]] {
- def elementType() : ClassManifest[C] = ClassManifest.Any.asInstanceOf[ClassManifest[C]]
+ def elementType() : ClassTag[C] = ClassTag.Any.asInstanceOf[ClassTag[C]]
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/Function.java b/core/src/main/scala/org/apache/spark/api/java/function/Function.java
index ce368ee01b..537439ef53 100644
--- a/core/src/main/scala/org/apache/spark/api/java/function/Function.java
+++ b/core/src/main/scala/org/apache/spark/api/java/function/Function.java
@@ -17,8 +17,8 @@
package org.apache.spark.api.java.function;
-import scala.reflect.ClassManifest;
-import scala.reflect.ClassManifest$;
+import scala.reflect.ClassTag;
+import scala.reflect.ClassTag$;
import java.io.Serializable;
@@ -29,8 +29,8 @@ import java.io.Serializable;
* when mapping RDDs of other types.
*/
public abstract class Function<T, R> extends WrappedFunction1<T, R> implements Serializable {
- public ClassManifest<R> returnType() {
- return (ClassManifest<R>) ClassManifest$.MODULE$.fromClass(Object.class);
+ public ClassTag<R> returnType() {
+ return ClassTag$.MODULE$.apply(Object.class);
}
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/Function2.java b/core/src/main/scala/org/apache/spark/api/java/function/Function2.java
index 44ad559d48..a2d1214fb4 100644
--- a/core/src/main/scala/org/apache/spark/api/java/function/Function2.java
+++ b/core/src/main/scala/org/apache/spark/api/java/function/Function2.java
@@ -17,8 +17,8 @@
package org.apache.spark.api.java.function;
-import scala.reflect.ClassManifest;
-import scala.reflect.ClassManifest$;
+import scala.reflect.ClassTag;
+import scala.reflect.ClassTag$;
import java.io.Serializable;
@@ -28,8 +28,8 @@ import java.io.Serializable;
public abstract class Function2<T1, T2, R> extends WrappedFunction2<T1, T2, R>
implements Serializable {
- public ClassManifest<R> returnType() {
- return (ClassManifest<R>) ClassManifest$.MODULE$.fromClass(Object.class);
+ public ClassTag<R> returnType() {
+ return (ClassTag<R>) ClassTag$.MODULE$.apply(Object.class);
}
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/Function3.java b/core/src/main/scala/org/apache/spark/api/java/function/Function3.java
index ac6178924a..fb1deceab5 100644
--- a/core/src/main/scala/org/apache/spark/api/java/function/Function3.java
+++ b/core/src/main/scala/org/apache/spark/api/java/function/Function3.java
@@ -17,8 +17,8 @@
package org.apache.spark.api.java.function;
-import scala.reflect.ClassManifest;
-import scala.reflect.ClassManifest$;
+import scala.reflect.ClassTag;
+import scala.reflect.ClassTag$;
import scala.runtime.AbstractFunction2;
import java.io.Serializable;
@@ -29,8 +29,8 @@ import java.io.Serializable;
public abstract class Function3<T1, T2, T3, R> extends WrappedFunction3<T1, T2, T3, R>
implements Serializable {
- public ClassManifest<R> returnType() {
- return (ClassManifest<R>) ClassManifest$.MODULE$.fromClass(Object.class);
+ public ClassTag<R> returnType() {
+ return (ClassTag<R>) ClassTag$.MODULE$.apply(Object.class);
}
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.java b/core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.java
index 6d76a8f970..ca485b3cc2 100644
--- a/core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.java
+++ b/core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.java
@@ -18,8 +18,8 @@
package org.apache.spark.api.java.function;
import scala.Tuple2;
-import scala.reflect.ClassManifest;
-import scala.reflect.ClassManifest$;
+import scala.reflect.ClassTag;
+import scala.reflect.ClassTag$;
import java.io.Serializable;
@@ -33,11 +33,11 @@ public abstract class PairFlatMapFunction<T, K, V>
extends WrappedFunction1<T, Iterable<Tuple2<K, V>>>
implements Serializable {
- public ClassManifest<K> keyType() {
- return (ClassManifest<K>) ClassManifest$.MODULE$.fromClass(Object.class);
+ public ClassTag<K> keyType() {
+ return (ClassTag<K>) ClassTag$.MODULE$.apply(Object.class);
}
- public ClassManifest<V> valueType() {
- return (ClassManifest<V>) ClassManifest$.MODULE$.fromClass(Object.class);
+ public ClassTag<V> valueType() {
+ return (ClassTag<V>) ClassTag$.MODULE$.apply(Object.class);
}
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/PairFunction.java b/core/src/main/scala/org/apache/spark/api/java/function/PairFunction.java
index ede7ceefb5..cbe2306026 100644
--- a/core/src/main/scala/org/apache/spark/api/java/function/PairFunction.java
+++ b/core/src/main/scala/org/apache/spark/api/java/function/PairFunction.java
@@ -18,8 +18,8 @@
package org.apache.spark.api.java.function;
import scala.Tuple2;
-import scala.reflect.ClassManifest;
-import scala.reflect.ClassManifest$;
+import scala.reflect.ClassTag;
+import scala.reflect.ClassTag$;
import java.io.Serializable;
@@ -31,11 +31,11 @@ import java.io.Serializable;
public abstract class PairFunction<T, K, V> extends WrappedFunction1<T, Tuple2<K, V>>
implements Serializable {
- public ClassManifest<K> keyType() {
- return (ClassManifest<K>) ClassManifest$.MODULE$.fromClass(Object.class);
+ public ClassTag<K> keyType() {
+ return (ClassTag<K>) ClassTag$.MODULE$.apply(Object.class);
}
- public ClassManifest<V> valueType() {
- return (ClassManifest<V>) ClassManifest$.MODULE$.fromClass(Object.class);
+ public ClassTag<V> valueType() {
+ return (ClassTag<V>) ClassTag$.MODULE$.apply(Object.class);
}
}
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 12b4d94a56..32cc70e8c9 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -22,18 +22,17 @@ import java.net._
import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections}
import scala.collection.JavaConversions._
+import scala.reflect.ClassTag
import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark._
import org.apache.spark.rdd.RDD
-import org.apache.spark.rdd.PipedRDD
import org.apache.spark.util.Utils
-
-private[spark] class PythonRDD[T: ClassManifest](
+private[spark] class PythonRDD[T: ClassTag](
parent: RDD[T],
- command: Seq[String],
+ command: Array[Byte],
envVars: JMap[String, String],
pythonIncludes: JList[String],
preservePartitoning: Boolean,
@@ -42,23 +41,12 @@ private[spark] class PythonRDD[T: ClassManifest](
accumulator: Accumulator[JList[Array[Byte]]])
extends RDD[Array[Byte]](parent) {
- val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
-
- // Similar to Runtime.exec(), if we are given a single string, split it into words
- // using a standard StringTokenizer (i.e. by spaces)
- def this(parent: RDD[T], command: String, envVars: JMap[String, String],
- pythonIncludes: JList[String],
- preservePartitoning: Boolean, pythonExec: String,
- broadcastVars: JList[Broadcast[Array[Byte]]],
- accumulator: Accumulator[JList[Array[Byte]]]) =
- this(parent, PipedRDD.tokenize(command), envVars, pythonIncludes, preservePartitoning, pythonExec,
- broadcastVars, accumulator)
+ val bufferSize = conf.get("spark.buffer.size", "65536").toInt
override def getPartitions = parent.partitions
override val partitioner = if (preservePartitoning) parent.partitioner else None
-
override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
val startTime = System.currentTimeMillis
val env = SparkEnv.get
@@ -71,11 +59,10 @@ private[spark] class PythonRDD[T: ClassManifest](
SparkEnv.set(env)
val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
val dataOut = new DataOutputStream(stream)
- val printOut = new PrintWriter(stream)
// Partition index
dataOut.writeInt(split.index)
// sparkFilesDir
- PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dataOut)
+ dataOut.writeUTF(SparkFiles.getRootDirectory)
// Broadcast variables
dataOut.writeInt(broadcastVars.length)
for (broadcast <- broadcastVars) {
@@ -85,21 +72,16 @@ private[spark] class PythonRDD[T: ClassManifest](
}
// Python includes (*.zip and *.egg files)
dataOut.writeInt(pythonIncludes.length)
- for (f <- pythonIncludes) {
- PythonRDD.writeAsPickle(f, dataOut)
- }
+ pythonIncludes.foreach(dataOut.writeUTF)
dataOut.flush()
- // Serialized user code
- for (elem <- command) {
- printOut.println(elem)
- }
- printOut.flush()
+ // Serialized command:
+ dataOut.writeInt(command.length)
+ dataOut.write(command)
// Data values
for (elem <- parent.iterator(split, context)) {
- PythonRDD.writeAsPickle(elem, dataOut)
+ PythonRDD.writeToStream(elem, dataOut)
}
dataOut.flush()
- printOut.flush()
worker.shutdownOutput()
} catch {
case e: IOException =>
@@ -132,7 +114,7 @@ private[spark] class PythonRDD[T: ClassManifest](
val obj = new Array[Byte](length)
stream.readFully(obj)
obj
- case -3 =>
+ case SpecialLengths.TIMING_DATA =>
// Timing data from worker
val bootTime = stream.readLong()
val initTime = stream.readLong()
@@ -143,30 +125,30 @@ private[spark] class PythonRDD[T: ClassManifest](
val total = finishTime - startTime
logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot, init, finish))
read
- case -2 =>
+ case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
// Signals that an exception has been thrown in python
val exLength = stream.readInt()
val obj = new Array[Byte](exLength)
stream.readFully(obj)
throw new PythonException(new String(obj))
- case -1 =>
+ case SpecialLengths.END_OF_DATA_SECTION =>
// We've finished the data section of the output, but we can still
- // read some accumulator updates; let's do that, breaking when we
- // get a negative length record.
- var len2 = stream.readInt()
- while (len2 >= 0) {
- val update = new Array[Byte](len2)
+ // read some accumulator updates:
+ val numAccumulatorUpdates = stream.readInt()
+ (1 to numAccumulatorUpdates).foreach { _ =>
+ val updateLen = stream.readInt()
+ val update = new Array[Byte](updateLen)
stream.readFully(update)
accumulator += Collections.singletonList(update)
- len2 = stream.readInt()
+
}
- new Array[Byte](0)
+ Array.empty[Byte]
}
} catch {
case eof: EOFException => {
throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
}
- case e => throw e
+ case e: Throwable => throw e
}
}
@@ -197,62 +179,15 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends
val asJavaPairRDD : JavaPairRDD[Long, Array[Byte]] = JavaPairRDD.fromRDD(this)
}
-private[spark] object PythonRDD {
-
- /** Strips the pickle PROTO and STOP opcodes from the start and end of a pickle */
- def stripPickle(arr: Array[Byte]) : Array[Byte] = {
- arr.slice(2, arr.length - 1)
- }
+private object SpecialLengths {
+ val END_OF_DATA_SECTION = -1
+ val PYTHON_EXCEPTION_THROWN = -2
+ val TIMING_DATA = -3
+}
- /**
- * Write strings, pickled Python objects, or pairs of pickled objects to a data output stream.
- * The data format is a 32-bit integer representing the pickled object's length (in bytes),
- * followed by the pickled data.
- *
- * Pickle module:
- *
- * http://docs.python.org/2/library/pickle.html
- *
- * The pickle protocol is documented in the source of the `pickle` and `pickletools` modules:
- *
- * http://hg.python.org/cpython/file/2.6/Lib/pickle.py
- * http://hg.python.org/cpython/file/2.6/Lib/pickletools.py
- *
- * @param elem the object to write
- * @param dOut a data output stream
- */
- def writeAsPickle(elem: Any, dOut: DataOutputStream) {
- if (elem.isInstanceOf[Array[Byte]]) {
- val arr = elem.asInstanceOf[Array[Byte]]
- dOut.writeInt(arr.length)
- dOut.write(arr)
- } else if (elem.isInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]) {
- val t = elem.asInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]
- val length = t._1.length + t._2.length - 3 - 3 + 4 // stripPickle() removes 3 bytes
- dOut.writeInt(length)
- dOut.writeByte(Pickle.PROTO)
- dOut.writeByte(Pickle.TWO)
- dOut.write(PythonRDD.stripPickle(t._1))
- dOut.write(PythonRDD.stripPickle(t._2))
- dOut.writeByte(Pickle.TUPLE2)
- dOut.writeByte(Pickle.STOP)
- } else if (elem.isInstanceOf[String]) {
- // For uniformity, strings are wrapped into Pickles.
- val s = elem.asInstanceOf[String].getBytes("UTF-8")
- val length = 2 + 1 + 4 + s.length + 1
- dOut.writeInt(length)
- dOut.writeByte(Pickle.PROTO)
- dOut.writeByte(Pickle.TWO)
- dOut.write(Pickle.BINUNICODE)
- dOut.writeInt(Integer.reverseBytes(s.length))
- dOut.write(s)
- dOut.writeByte(Pickle.STOP)
- } else {
- throw new SparkException("Unexpected RDD type")
- }
- }
+private[spark] object PythonRDD {
- def readRDDFromPickleFile(sc: JavaSparkContext, filename: String, parallelism: Int) :
+ def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
JavaRDD[Array[Byte]] = {
val file = new DataInputStream(new FileInputStream(filename))
val objs = new collection.mutable.ArrayBuffer[Array[Byte]]
@@ -265,39 +200,41 @@ private[spark] object PythonRDD {
}
} catch {
case eof: EOFException => {}
- case e => throw e
+ case e: Throwable => throw e
}
JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
}
- def writeIteratorToPickleFile[T](items: java.util.Iterator[T], filename: String) {
+ def writeToStream(elem: Any, dataOut: DataOutputStream) {
+ elem match {
+ case bytes: Array[Byte] =>
+ dataOut.writeInt(bytes.length)
+ dataOut.write(bytes)
+ case pair: (Array[Byte], Array[Byte]) =>
+ dataOut.writeInt(pair._1.length)
+ dataOut.write(pair._1)
+ dataOut.writeInt(pair._2.length)
+ dataOut.write(pair._2)
+ case str: String =>
+ dataOut.writeUTF(str)
+ case other =>
+ throw new SparkException("Unexpected element type " + other.getClass)
+ }
+ }
+
+ def writeToFile[T](items: java.util.Iterator[T], filename: String) {
import scala.collection.JavaConverters._
- writeIteratorToPickleFile(items.asScala, filename)
+ writeToFile(items.asScala, filename)
}
- def writeIteratorToPickleFile[T](items: Iterator[T], filename: String) {
+ def writeToFile[T](items: Iterator[T], filename: String) {
val file = new DataOutputStream(new FileOutputStream(filename))
for (item <- items) {
- writeAsPickle(item, file)
+ writeToStream(item, file)
}
file.close()
}
- def takePartition[T](rdd: RDD[T], partition: Int): Iterator[T] = {
- implicit val cm : ClassManifest[T] = rdd.elementClassManifest
- rdd.context.runJob(rdd, ((x: Iterator[T]) => x.toArray), Seq(partition), true).head.iterator
- }
-}
-
-private object Pickle {
- val PROTO: Byte = 0x80.toByte
- val TWO: Byte = 0x02.toByte
- val BINUNICODE: Byte = 'X'
- val STOP: Byte = '.'
- val TUPLE2: Byte = 0x86.toByte
- val EMPTY_LIST: Byte = ']'
- val MARK: Byte = '('
- val APPENDS: Byte = 'e'
}
private class BytesToString extends org.apache.spark.api.java.function.Function[Array[Byte], String] {
@@ -313,7 +250,7 @@ private class PythonAccumulatorParam(@transient serverHost: String, serverPort:
Utils.checkHost(serverHost, "Expected hostname")
- val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
+ val bufferSize = SparkEnv.get.conf.get("spark.buffer.size", "65536").toInt
override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
index 67d45723ba..f291266fcf 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
@@ -64,7 +64,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
startDaemon()
new Socket(daemonHost, daemonPort)
}
- case e => throw e
+ case e: Throwable => throw e
}
}
}
@@ -198,7 +198,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
}
}.start()
} catch {
- case e => {
+ case e: Throwable => {
stopDaemon()
throw e
}
diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
index 43c18294c5..0fc478a419 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
@@ -31,8 +31,8 @@ abstract class Broadcast[T](private[spark] val id: Long) extends Serializable {
override def toString = "Broadcast(" + id + ")"
}
-private[spark]
-class BroadcastManager(val _isDriver: Boolean) extends Logging with Serializable {
+private[spark]
+class BroadcastManager(val _isDriver: Boolean, conf: SparkConf) extends Logging with Serializable {
private var initialized = false
private var broadcastFactory: BroadcastFactory = null
@@ -43,14 +43,14 @@ class BroadcastManager(val _isDriver: Boolean) extends Logging with Serializable
private def initialize() {
synchronized {
if (!initialized) {
- val broadcastFactoryClass = System.getProperty(
+ val broadcastFactoryClass = conf.get(
"spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")
broadcastFactory =
Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory]
// Initialize appropriate BroadcastFactory and BroadcastObject
- broadcastFactory.initialize(isDriver)
+ broadcastFactory.initialize(isDriver, conf)
initialized = true
}
diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
index 68bff75b90..fb161ce69d 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
@@ -17,6 +17,8 @@
package org.apache.spark.broadcast
+import org.apache.spark.SparkConf
+
/**
* An interface for all the broadcast implementations in Spark (to allow
* multiple broadcast implementations). SparkContext uses a user-specified
@@ -24,7 +26,7 @@ package org.apache.spark.broadcast
* entire Spark job.
*/
private[spark] trait BroadcastFactory {
- def initialize(isDriver: Boolean): Unit
+ def initialize(isDriver: Boolean, conf: SparkConf): Unit
def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T]
def stop(): Unit
}
diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
index 609464e38d..db596d5fcc 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
@@ -19,18 +19,19 @@ package org.apache.spark.broadcast
import java.io.{File, FileOutputStream, ObjectInputStream, OutputStream}
import java.net.URL
+import java.util.concurrent.TimeUnit
import it.unimi.dsi.fastutil.io.FastBufferedInputStream
import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
-import org.apache.spark.{HttpServer, Logging, SparkEnv}
+import org.apache.spark.{SparkConf, HttpServer, Logging, SparkEnv}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashSet, Utils}
private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
extends Broadcast[T](id) with Logging with Serializable {
-
+
def value = value_
def blockId = BroadcastBlockId(id)
@@ -39,7 +40,7 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea
SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
}
- if (!isLocal) {
+ if (!isLocal) {
HttpBroadcast.write(id, value_)
}
@@ -63,7 +64,7 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea
}
private[spark] class HttpBroadcastFactory extends BroadcastFactory {
- def initialize(isDriver: Boolean) { HttpBroadcast.initialize(isDriver) }
+ def initialize(isDriver: Boolean, conf: SparkConf) { HttpBroadcast.initialize(isDriver, conf) }
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
new HttpBroadcast[T](value_, isLocal, id)
@@ -80,42 +81,51 @@ private object HttpBroadcast extends Logging {
private var serverUri: String = null
private var server: HttpServer = null
+ // TODO: This shouldn't be a global variable so that multiple SparkContexts can coexist
private val files = new TimeStampedHashSet[String]
- private val cleaner = new MetadataCleaner(MetadataCleanerType.HTTP_BROADCAST, cleanup)
+ private var cleaner: MetadataCleaner = null
+
+ private val httpReadTimeout = TimeUnit.MILLISECONDS.convert(5, TimeUnit.MINUTES).toInt
- private lazy val compressionCodec = CompressionCodec.createCodec()
+ private var compressionCodec: CompressionCodec = null
- def initialize(isDriver: Boolean) {
+ def initialize(isDriver: Boolean, conf: SparkConf) {
synchronized {
if (!initialized) {
- bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
- compress = System.getProperty("spark.broadcast.compress", "true").toBoolean
+ bufferSize = conf.get("spark.buffer.size", "65536").toInt
+ compress = conf.get("spark.broadcast.compress", "true").toBoolean
if (isDriver) {
- createServer()
+ createServer(conf)
+ conf.set("spark.httpBroadcast.uri", serverUri)
}
- serverUri = System.getProperty("spark.httpBroadcast.uri")
+ serverUri = conf.get("spark.httpBroadcast.uri")
+ cleaner = new MetadataCleaner(MetadataCleanerType.HTTP_BROADCAST, cleanup, conf)
+ compressionCodec = CompressionCodec.createCodec(conf)
initialized = true
}
}
}
-
+
def stop() {
synchronized {
if (server != null) {
server.stop()
server = null
}
+ if (cleaner != null) {
+ cleaner.cancel()
+ cleaner = null
+ }
+ compressionCodec = null
initialized = false
- cleaner.cancel()
}
}
- private def createServer() {
- broadcastDir = Utils.createTempDir(Utils.getLocalDir)
+ private def createServer(conf: SparkConf) {
+ broadcastDir = Utils.createTempDir(Utils.getLocalDir(conf))
server = new HttpServer(broadcastDir)
server.start()
serverUri = server.uri
- System.setProperty("spark.httpBroadcast.uri", serverUri)
logInfo("Broadcast server started at " + serverUri)
}
@@ -138,10 +148,13 @@ private object HttpBroadcast extends Logging {
def read[T](id: Long): T = {
val url = serverUri + "/" + BroadcastBlockId(id).name
val in = {
+ val httpConnection = new URL(url).openConnection()
+ httpConnection.setReadTimeout(httpReadTimeout)
+ val inputStream = httpConnection.getInputStream
if (compress) {
- compressionCodec.compressedInputStream(new URL(url).openStream())
+ compressionCodec.compressedInputStream(inputStream)
} else {
- new FastBufferedInputStream(new URL(url).openStream(), bufferSize)
+ new FastBufferedInputStream(inputStream, bufferSize)
}
}
val ser = SparkEnv.get.serializer.newInstance()
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
index 073a0a5029..9530938278 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -83,13 +83,13 @@ extends Broadcast[T](id) with Logging with Serializable {
case None =>
val start = System.nanoTime
logInfo("Started reading broadcast variable " + id)
-
+
// Initialize @transient variables that will receive garbage values from the master.
resetWorkerVariables()
if (receiveBroadcast(id)) {
value_ = TorrentBroadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
-
+
// Store the merged copy in cache so that the next worker doesn't need to rebuild it.
// This creates a tradeoff between memory usage and latency.
// Storing copy doubles the memory footprint; not storing doubles deserialization cost.
@@ -122,14 +122,14 @@ extends Broadcast[T](id) with Logging with Serializable {
while (attemptId > 0 && totalBlocks == -1) {
TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.getSingle(metaId) match {
- case Some(x) =>
+ case Some(x) =>
val tInfo = x.asInstanceOf[TorrentInfo]
totalBlocks = tInfo.totalBlocks
totalBytes = tInfo.totalBytes
arrayOfBlocks = new Array[TorrentBlock](totalBlocks)
hasBlocks = 0
-
- case None =>
+
+ case None =>
Thread.sleep(500)
}
}
@@ -145,13 +145,13 @@ extends Broadcast[T](id) with Logging with Serializable {
val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + pid)
TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.getSingle(pieceId) match {
- case Some(x) =>
+ case Some(x) =>
arrayOfBlocks(pid) = x.asInstanceOf[TorrentBlock]
hasBlocks += 1
SparkEnv.get.blockManager.putSingle(
pieceId, arrayOfBlocks(pid), StorageLevel.MEMORY_AND_DISK, true)
-
- case None =>
+
+ case None =>
throw new SparkException("Failed to get " + pieceId + " of " + broadcastId)
}
}
@@ -166,21 +166,22 @@ private object TorrentBroadcast
extends Logging {
private var initialized = false
-
- def initialize(_isDriver: Boolean) {
+ private var conf: SparkConf = null
+ def initialize(_isDriver: Boolean, conf: SparkConf) {
+ TorrentBroadcast.conf = conf //TODO: we might have to fix it in tests
synchronized {
if (!initialized) {
initialized = true
}
}
}
-
+
def stop() {
initialized = false
}
- val BLOCK_SIZE = System.getProperty("spark.broadcast.blockSize", "4096").toInt * 1024
-
+ lazy val BLOCK_SIZE = conf.get("spark.broadcast.blockSize", "4096").toInt * 1024
+
def blockifyObject[T](obj: T): TorrentInfo = {
val byteArray = Utils.serialize[T](obj)
val bais = new ByteArrayInputStream(byteArray)
@@ -209,7 +210,7 @@ extends Logging {
}
def unBlockifyObject[T](arrayOfBlocks: Array[TorrentBlock],
- totalBytes: Int,
+ totalBytes: Int,
totalBlocks: Int): T = {
var retByteArray = new Array[Byte](totalBytes)
for (i <- 0 until totalBlocks) {
@@ -222,23 +223,23 @@ extends Logging {
}
private[spark] case class TorrentBlock(
- blockID: Int,
- byteArray: Array[Byte])
+ blockID: Int,
+ byteArray: Array[Byte])
extends Serializable
private[spark] case class TorrentInfo(
@transient arrayOfBlocks : Array[TorrentBlock],
- totalBlocks: Int,
- totalBytes: Int)
+ totalBlocks: Int,
+ totalBytes: Int)
extends Serializable {
-
- @transient var hasBlocks = 0
+
+ @transient var hasBlocks = 0
}
private[spark] class TorrentBroadcastFactory
extends BroadcastFactory {
-
- def initialize(isDriver: Boolean) { TorrentBroadcast.initialize(isDriver) }
+
+ def initialize(isDriver: Boolean, conf: SparkConf) { TorrentBroadcast.initialize(isDriver, conf) }
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
new TorrentBroadcast[T](value_, isLocal, id)
diff --git a/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala b/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala
index 19d393a0db..e38459b883 100644
--- a/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala
@@ -19,7 +19,7 @@ package org.apache.spark.deploy
private[spark] class ApplicationDescription(
val name: String,
- val maxCores: Int, /* Integer.MAX_VALUE denotes an unlimited number of cores */
+ val maxCores: Option[Int],
val memoryPerSlave: Int,
val command: Command,
val sparkHome: String,
diff --git a/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala b/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala
index fcfea96ad6..37dfa7fec0 100644
--- a/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala
@@ -17,8 +17,7 @@
package org.apache.spark.deploy
-private[spark] object ExecutorState
- extends Enumeration("LAUNCHING", "LOADING", "RUNNING", "KILLED", "FAILED", "LOST") {
+private[spark] object ExecutorState extends Enumeration {
val LAUNCHING, LOADING, RUNNING, KILLED, FAILED, LOST = Value
diff --git a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala
index 0aa8852649..4dfb19ed8a 100644
--- a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala
@@ -190,7 +190,7 @@ private[spark] object FaultToleranceTest extends App with Logging {
/** Creates a SparkContext, which constructs a Client to interact with our cluster. */
def createClient() = {
if (sc != null) { sc.stop() }
- // Counter-hack: Because of a hack in SparkEnv#createFromSystemProperties() that changes this
+ // Counter-hack: Because of a hack in SparkEnv#create() that changes this
// property, we need to reset it.
System.setProperty("spark.driver.port", "0")
sc = new SparkContext(getMasterUrls(masters), "fault-tolerance", containerSparkHome)
@@ -417,4 +417,4 @@ private[spark] object Docker extends Logging {
"docker ps -l -q".!(ProcessLogger(line => id = line))
new DockerId(id)
}
-} \ No newline at end of file
+}
diff --git a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala
index a724900943..ffc0cb0903 100644
--- a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala
@@ -22,7 +22,7 @@ import akka.actor.ActorSystem
import org.apache.spark.deploy.worker.Worker
import org.apache.spark.deploy.master.Master
import org.apache.spark.util.Utils
-import org.apache.spark.Logging
+import org.apache.spark.{SparkConf, Logging}
import scala.collection.mutable.ArrayBuffer
@@ -34,16 +34,17 @@ import scala.collection.mutable.ArrayBuffer
*/
private[spark]
class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: Int) extends Logging {
-
+
private val localHostname = Utils.localHostName()
private val masterActorSystems = ArrayBuffer[ActorSystem]()
private val workerActorSystems = ArrayBuffer[ActorSystem]()
-
+
def start(): Array[String] = {
logInfo("Starting a local Spark cluster with " + numWorkers + " workers.")
/* Start the Master */
- val (masterSystem, masterPort, _) = Master.startSystemAndActor(localHostname, 0, 0)
+ val conf = new SparkConf(false)
+ val (masterSystem, masterPort, _) = Master.startSystemAndActor(localHostname, 0, 0, conf)
masterActorSystems += masterSystem
val masterUrl = "spark://" + localHostname + ":" + masterPort
val masters = Array(masterUrl)
@@ -55,16 +56,19 @@ class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: I
workerActorSystems += workerSystem
}
- return masters
+ masters
}
def stop() {
logInfo("Shutting down local Spark cluster.")
// Stop the workers before the master so they don't get upset that it disconnected
+ // TODO: In Akka 2.1.x, ActorSystem.awaitTermination hangs when you have remote actors!
+ // This is unfortunate, but for now we just comment it out.
workerActorSystems.foreach(_.shutdown())
- workerActorSystems.foreach(_.awaitTermination())
-
+ //workerActorSystems.foreach(_.awaitTermination())
masterActorSystems.foreach(_.shutdown())
- masterActorSystems.foreach(_.awaitTermination())
+ //masterActorSystems.foreach(_.awaitTermination())
+ masterActorSystems.clear()
+ workerActorSystems.clear()
}
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
index fc1537f796..27dc42bf7e 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
@@ -34,10 +34,10 @@ class SparkHadoopUtil {
UserGroupInformation.setConfiguration(conf)
def runAsUser(user: String)(func: () => Unit) {
- // if we are already running as the user intended there is no reason to do the doAs. It
+ // if we are already running as the user intended there is no reason to do the doAs. It
// will actually break secure HDFS access as it doesn't fill in the credentials. Also if
- // the user is UNKNOWN then we shouldn't be creating a remote unknown user
- // (this is actually the path spark on yarn takes) since SPARK_USER is initialized only
+ // the user is UNKNOWN then we shouldn't be creating a remote unknown user
+ // (this is actually the path spark on yarn takes) since SPARK_USER is initialized only
// in SparkContext.
val currentUser = Option(System.getProperty("user.name")).
getOrElse(SparkContext.SPARK_UNKNOWN_USER)
@@ -67,11 +67,15 @@ class SparkHadoopUtil {
}
object SparkHadoopUtil {
+
private val hadoop = {
- val yarnMode = java.lang.Boolean.valueOf(System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE")))
+ val yarnMode = java.lang.Boolean.valueOf(
+ System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE")))
if (yarnMode) {
try {
- Class.forName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil").newInstance.asInstanceOf[SparkHadoopUtil]
+ Class.forName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil")
+ .newInstance()
+ .asInstanceOf[SparkHadoopUtil]
} catch {
case th: Throwable => throw new SparkException("Unable to load YARN support", th)
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/client/Client.scala b/core/src/main/scala/org/apache/spark/deploy/client/Client.scala
index 77422f61ec..481026eaa2 100644
--- a/core/src/main/scala/org/apache/spark/deploy/client/Client.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/client/Client.scala
@@ -19,21 +19,18 @@ package org.apache.spark.deploy.client
import java.util.concurrent.TimeoutException
+import scala.concurrent.Await
+import scala.concurrent.duration._
+
import akka.actor._
-import akka.actor.Terminated
import akka.pattern.ask
-import akka.util.Duration
-import akka.util.duration._
-import akka.remote.RemoteClientDisconnected
-import akka.remote.RemoteClientLifeCycleEvent
-import akka.remote.RemoteClientShutdown
-import akka.dispatch.Await
-
-import org.apache.spark.Logging
+import akka.remote.{AssociationErrorEvent, DisassociatedEvent, RemotingLifecycleEvent}
+
+import org.apache.spark.{Logging, SparkConf, SparkException}
import org.apache.spark.deploy.{ApplicationDescription, ExecutorState}
import org.apache.spark.deploy.DeployMessages._
import org.apache.spark.deploy.master.Master
-
+import org.apache.spark.util.AkkaUtils
/**
* The main class used to talk to a Spark deploy cluster. Takes a master URL, an app description,
@@ -45,24 +42,26 @@ private[spark] class Client(
actorSystem: ActorSystem,
masterUrls: Array[String],
appDescription: ApplicationDescription,
- listener: ClientListener)
+ listener: ClientListener,
+ conf: SparkConf)
extends Logging {
val REGISTRATION_TIMEOUT = 20.seconds
val REGISTRATION_RETRIES = 3
+ var masterAddress: Address = null
var actor: ActorRef = null
var appId: String = null
var registered = false
var activeMasterUrl: String = null
class ClientActor extends Actor with Logging {
- var master: ActorRef = null
- var masterAddress: Address = null
+ var master: ActorSelection = null
var alreadyDisconnected = false // To avoid calling listener.disconnected() multiple times
var alreadyDead = false // To avoid calling listener.dead() multiple times
override def preStart() {
+ context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
try {
registerWithMaster()
} catch {
@@ -76,7 +75,7 @@ private[spark] class Client(
def tryRegisterAllMasters() {
for (masterUrl <- masterUrls) {
logInfo("Connecting to master " + masterUrl + "...")
- val actor = context.actorFor(Master.toAkkaUrl(masterUrl))
+ val actor = context.actorSelection(Master.toAkkaUrl(masterUrl))
actor ! RegisterApplication(appDescription)
}
}
@@ -84,6 +83,7 @@ private[spark] class Client(
def registerWithMaster() {
tryRegisterAllMasters()
+ import context.dispatcher
var retries = 0
lazy val retryTimer: Cancellable =
context.system.scheduler.schedule(REGISTRATION_TIMEOUT, REGISTRATION_TIMEOUT) {
@@ -102,10 +102,19 @@ private[spark] class Client(
def changeMaster(url: String) {
activeMasterUrl = url
- master = context.actorFor(Master.toAkkaUrl(url))
- masterAddress = master.path.address
- context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
- context.watch(master) // Doesn't work with remote actors, but useful for testing
+ master = context.actorSelection(Master.toAkkaUrl(activeMasterUrl))
+ masterAddress = activeMasterUrl match {
+ case Master.sparkUrlRegex(host, port) =>
+ Address("akka.tcp", Master.systemName, host, port.toInt)
+ case x =>
+ throw new SparkException("Invalid spark URL: " + x)
+ }
+ }
+
+ private def isPossibleMaster(remoteUrl: Address) = {
+ masterUrls.map(s => Master.toAkkaUrl(s))
+ .map(u => AddressFromURIString(u).hostPort)
+ .contains(remoteUrl.hostPort)
}
override def receive = {
@@ -135,22 +144,16 @@ private[spark] class Client(
case MasterChanged(masterUrl, masterWebUiUrl) =>
logInfo("Master has changed, new master is at " + masterUrl)
- context.unwatch(master)
changeMaster(masterUrl)
alreadyDisconnected = false
sender ! MasterChangeAcknowledged(appId)
- case Terminated(actor_) if actor_ == master =>
- logWarning("Connection to master failed; waiting for master to reconnect...")
- markDisconnected()
-
- case RemoteClientDisconnected(transport, address) if address == masterAddress =>
- logWarning("Connection to master failed; waiting for master to reconnect...")
+ case DisassociatedEvent(_, address, _) if address == masterAddress =>
+ logWarning(s"Connection to $address failed; waiting for master to reconnect...")
markDisconnected()
- case RemoteClientShutdown(transport, address) if address == masterAddress =>
- logWarning("Connection to master failed; waiting for master to reconnect...")
- markDisconnected()
+ case AssociationErrorEvent(cause, _, address, _) if isPossibleMaster(address) =>
+ logWarning(s"Could not connect to $address: $cause")
case StopClient =>
markDead()
@@ -184,7 +187,7 @@ private[spark] class Client(
def stop() {
if (actor != null) {
try {
- val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
+ val timeout = AkkaUtils.askTimeout(conf)
val future = actor.ask(StopClient)(timeout)
Await.result(future, timeout)
} catch {
diff --git a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala
index 5b62d3ba6c..28ebbdc66b 100644
--- a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala
@@ -18,7 +18,7 @@
package org.apache.spark.deploy.client
import org.apache.spark.util.{Utils, AkkaUtils}
-import org.apache.spark.{Logging}
+import org.apache.spark.{SparkConf, SparkContext, Logging}
import org.apache.spark.deploy.{Command, ApplicationDescription}
private[spark] object TestClient {
@@ -45,11 +45,13 @@ private[spark] object TestClient {
def main(args: Array[String]) {
val url = args(0)
- val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0)
+ val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0,
+ conf = new SparkConf)
val desc = new ApplicationDescription(
- "TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()), "dummy-spark-home", "ignored")
+ "TestClient", Some(1), 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()),
+ "dummy-spark-home", "ignored")
val listener = new TestListener
- val client = new Client(actorSystem, Array(url), desc, listener)
+ val client = new Client(actorSystem, Array(url), desc, listener, new SparkConf)
client.start()
actorSystem.awaitTermination()
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
index 5150b7c7de..3e26379166 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
@@ -28,7 +28,8 @@ private[spark] class ApplicationInfo(
val desc: ApplicationDescription,
val submitDate: Date,
val driver: ActorRef,
- val appUiUrl: String)
+ val appUiUrl: String,
+ defaultCores: Int)
extends Serializable {
@transient var state: ApplicationState.Value = _
@@ -81,7 +82,9 @@ private[spark] class ApplicationInfo(
}
}
- def coresLeft: Int = desc.maxCores - coresGranted
+ private val myMaxCores = desc.maxCores.getOrElse(defaultCores)
+
+ def coresLeft: Int = myMaxCores - coresGranted
private var _retryCount = 0
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala
index fedf879eff..67e6c5d66a 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala
@@ -17,8 +17,7 @@
package org.apache.spark.deploy.master
-private[spark] object ApplicationState
- extends Enumeration("WAITING", "RUNNING", "FINISHED", "FAILED", "UNKNOWN") {
+private[spark] object ApplicationState extends Enumeration {
type ApplicationState = Value
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala
index c0849ef324..043945a211 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala
@@ -65,7 +65,7 @@ private[spark] class FileSystemPersistenceEngine(
(apps, workers)
}
- private def serializeIntoFile(file: File, value: Serializable) {
+ private def serializeIntoFile(file: File, value: AnyRef) {
val created = file.createNewFile()
if (!created) { throw new IllegalStateException("Could not create file: " + file) }
@@ -77,13 +77,13 @@ private[spark] class FileSystemPersistenceEngine(
out.close()
}
- def deserializeFromFile[T <: Serializable](file: File)(implicit m: Manifest[T]): T = {
+ def deserializeFromFile[T](file: File)(implicit m: Manifest[T]): T = {
val fileData = new Array[Byte](file.length().asInstanceOf[Int])
val dis = new DataInputStream(new FileInputStream(file))
dis.readFully(fileData)
dis.close()
- val clazz = m.erasure.asInstanceOf[Class[T]]
+ val clazz = m.runtimeClass.asInstanceOf[Class[T]]
val serializer = serialization.serializerFor(clazz)
serializer.fromBinary(fileData).asInstanceOf[T]
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
index cd916672ac..6617b7100f 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
@@ -17,21 +17,19 @@
package org.apache.spark.deploy.master
-import java.util.Date
import java.text.SimpleDateFormat
+import java.util.Date
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
+import scala.concurrent.Await
+import scala.concurrent.duration._
import akka.actor._
-import akka.actor.Terminated
-import akka.dispatch.Await
import akka.pattern.ask
-import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientDisconnected, RemoteClientShutdown}
+import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent}
import akka.serialization.SerializationExtension
-import akka.util.duration._
-import akka.util.{Duration, Timeout}
-import org.apache.spark.{Logging, SparkException}
+import org.apache.spark.{SparkConf, SparkContext, Logging, SparkException}
import org.apache.spark.deploy.{ApplicationDescription, ExecutorState}
import org.apache.spark.deploy.DeployMessages._
import org.apache.spark.deploy.master.MasterMessages._
@@ -40,12 +38,16 @@ import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.util.{AkkaUtils, Utils}
private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Actor with Logging {
+ import context.dispatcher // to use Akka's scheduler.schedule()
+
+ val conf = new SparkConf
+
val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs
- val WORKER_TIMEOUT = System.getProperty("spark.worker.timeout", "60").toLong * 1000
- val RETAINED_APPLICATIONS = System.getProperty("spark.deploy.retainedApplications", "200").toInt
- val REAPER_ITERATIONS = System.getProperty("spark.dead.worker.persistence", "15").toInt
- val RECOVERY_DIR = System.getProperty("spark.deploy.recoveryDirectory", "")
- val RECOVERY_MODE = System.getProperty("spark.deploy.recoveryMode", "NONE")
+ val WORKER_TIMEOUT = conf.get("spark.worker.timeout", "60").toLong * 1000
+ val RETAINED_APPLICATIONS = conf.get("spark.deploy.retainedApplications", "200").toInt
+ val REAPER_ITERATIONS = conf.get("spark.dead.worker.persistence", "15").toInt
+ val RECOVERY_DIR = conf.get("spark.deploy.recoveryDirectory", "")
+ val RECOVERY_MODE = conf.get("spark.deploy.recoveryMode", "NONE")
var nextAppNumber = 0
val workers = new HashSet[WorkerInfo]
@@ -61,12 +63,10 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
val waitingApps = new ArrayBuffer[ApplicationInfo]
val completedApps = new ArrayBuffer[ApplicationInfo]
- var firstApp: Option[ApplicationInfo] = None
-
Utils.checkHost(host, "Expected hostname")
- val masterMetricsSystem = MetricsSystem.createMetricsSystem("master")
- val applicationMetricsSystem = MetricsSystem.createMetricsSystem("applications")
+ val masterMetricsSystem = MetricsSystem.createMetricsSystem("master", conf)
+ val applicationMetricsSystem = MetricsSystem.createMetricsSystem("applications", conf)
val masterSource = new MasterSource(this)
val webUi = new MasterWebUI(this, webUiPort)
@@ -88,12 +88,18 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
// As a temporary workaround before better ways of configuring memory, we allow users to set
// a flag that will perform round-robin scheduling across the nodes (spreading out each app
// among all the nodes) instead of trying to consolidate each app onto a small # of nodes.
- val spreadOutApps = System.getProperty("spark.deploy.spreadOut", "true").toBoolean
+ val spreadOutApps = conf.getBoolean("spark.deploy.spreadOut", true)
+
+ // Default maxCores for applications that don't specify it (i.e. pass Int.MaxValue)
+ val defaultCores = conf.getInt("spark.deploy.defaultCores", Int.MaxValue)
+ if (defaultCores < 1) {
+ throw new SparkException("spark.deploy.defaultCores must be positive")
+ }
override def preStart() {
logInfo("Starting Spark master at " + masterUrl)
// Listen for remote client disconnection events, since they don't go through Akka's watch()
- context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
+ context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
webUi.start()
masterWebUiUrl = "http://" + masterPublicAddress + ":" + webUi.boundPort.get
context.system.scheduler.schedule(0 millis, WORKER_TIMEOUT millis, self, CheckForWorkerTimeOut)
@@ -105,7 +111,7 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
persistenceEngine = RECOVERY_MODE match {
case "ZOOKEEPER" =>
logInfo("Persisting recovery state to ZooKeeper")
- new ZooKeeperPersistenceEngine(SerializationExtension(context.system))
+ new ZooKeeperPersistenceEngine(SerializationExtension(context.system), conf)
case "FILESYSTEM" =>
logInfo("Persisting recovery state to directory: " + RECOVERY_DIR)
new FileSystemPersistenceEngine(RECOVERY_DIR, SerializationExtension(context.system))
@@ -113,13 +119,12 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
new BlackHolePersistenceEngine()
}
- leaderElectionAgent = context.actorOf(Props(
- RECOVERY_MODE match {
+ leaderElectionAgent = RECOVERY_MODE match {
case "ZOOKEEPER" =>
- new ZooKeeperLeaderElectionAgent(self, masterUrl)
+ context.actorOf(Props(classOf[ZooKeeperLeaderElectionAgent], self, masterUrl, conf))
case _ =>
- new MonarchyLeaderAgent(self)
- }))
+ context.actorOf(Props(classOf[MonarchyLeaderAgent], self))
+ }
}
override def preRestart(reason: Throwable, message: Option[Any]) {
@@ -142,9 +147,7 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
RecoveryState.ALIVE
else
RecoveryState.RECOVERING
-
logInfo("I have been elected leader! New state: " + state)
-
if (state == RecoveryState.RECOVERING) {
beginRecovery(storedApps, storedWorkers)
context.system.scheduler.scheduleOnce(WORKER_TIMEOUT millis) { completeRecovery() }
@@ -156,7 +159,7 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
System.exit(0)
}
- case RegisterWorker(id, host, workerPort, cores, memory, webUiPort, publicAddress) => {
+ case RegisterWorker(id, workerHost, workerPort, cores, memory, workerWebUiPort, publicAddress) => {
logInfo("Registering worker %s:%d with %d cores, %s RAM".format(
host, workerPort, cores, Utils.megabytesToString(memory)))
if (state == RecoveryState.STANDBY) {
@@ -164,9 +167,9 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
} else if (idToWorker.contains(id)) {
sender ! RegisterWorkerFailed("Duplicate worker ID")
} else {
- val worker = new WorkerInfo(id, host, port, cores, memory, sender, webUiPort, publicAddress)
+ val worker = new WorkerInfo(id, workerHost, workerPort, cores, memory,
+ sender, workerWebUiPort, publicAddress)
registerWorker(worker)
- context.watch(sender) // This doesn't work with remote actors but helps for testing
persistenceEngine.addWorker(worker)
sender ! RegisteredWorker(masterUrl, masterWebUiUrl)
schedule()
@@ -181,7 +184,6 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
val app = createApplication(description, sender)
registerApplication(app)
logInfo("Registered app " + description.name + " with ID " + app.id)
- context.watch(sender) // This doesn't work with remote actors but helps for testing
persistenceEngine.addApplication(app)
sender ! RegisteredApplication(app.id, masterUrl)
schedule()
@@ -257,23 +259,9 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
if (canCompleteRecovery) { completeRecovery() }
}
- case Terminated(actor) => {
- // The disconnected actor could've been either a worker or an app; remove whichever of
- // those we have an entry for in the corresponding actor hashmap
- actorToWorker.get(actor).foreach(removeWorker)
- actorToApp.get(actor).foreach(finishApplication)
- if (state == RecoveryState.RECOVERING && canCompleteRecovery) { completeRecovery() }
- }
-
- case RemoteClientDisconnected(transport, address) => {
- // The disconnected client could've been either a worker or an app; remove whichever it was
- addressToWorker.get(address).foreach(removeWorker)
- addressToApp.get(address).foreach(finishApplication)
- if (state == RecoveryState.RECOVERING && canCompleteRecovery) { completeRecovery() }
- }
-
- case RemoteClientShutdown(transport, address) => {
+ case DisassociatedEvent(_, address, _) => {
// The disconnected client could've been either a worker or an app; remove whichever it was
+ logInfo(s"$address got disassociated, removing it.")
addressToWorker.get(address).foreach(removeWorker)
addressToApp.get(address).foreach(finishApplication)
if (state == RecoveryState.RECOVERING && canCompleteRecovery) { completeRecovery() }
@@ -444,7 +432,8 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
def createApplication(desc: ApplicationDescription, driver: ActorRef): ApplicationInfo = {
val now = System.currentTimeMillis()
val date = new Date(now)
- new ApplicationInfo(now, newApplicationId(date), desc, date, driver, desc.appUiUrl)
+ new ApplicationInfo(
+ now, newApplicationId(date), desc, date, driver, desc.appUiUrl, defaultCores)
}
def registerApplication(app: ApplicationInfo): Unit = {
@@ -459,14 +448,6 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
idToApp(app.id) = app
actorToApp(app.driver) = app
addressToApp(appAddress) = app
- if (firstApp == None) {
- firstApp = Some(app)
- }
- // TODO: What is firstApp?? Can we remove it?
- val workersAlive = workers.filter(_.state == WorkerState.ALIVE).toArray
- if (workersAlive.size > 0 && !workersAlive.exists(_.memoryFree >= app.desc.memoryPerSlave)) {
- logWarning("Could not find any workers with enough memory for " + firstApp.get.id)
- }
waitingApps += app
}
@@ -523,41 +504,42 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
removeWorker(worker)
} else {
if (worker.lastHeartbeat < currentTime - ((REAPER_ITERATIONS + 1) * WORKER_TIMEOUT))
- workers -= worker // we've seen this DEAD worker in the UI, etc. for long enough; cull it
+ workers -= worker // we've seen this DEAD worker in the UI, etc. for long enough; cull it
}
}
}
}
private[spark] object Master {
- private val systemName = "sparkMaster"
+ val systemName = "sparkMaster"
private val actorName = "Master"
- private val sparkUrlRegex = "spark://([^:]+):([0-9]+)".r
+ val sparkUrlRegex = "spark://([^:]+):([0-9]+)".r
def main(argStrings: Array[String]) {
- val args = new MasterArguments(argStrings)
- val (actorSystem, _, _) = startSystemAndActor(args.host, args.port, args.webUiPort)
+ val conf = new SparkConf
+ val args = new MasterArguments(argStrings, conf)
+ val (actorSystem, _, _) = startSystemAndActor(args.host, args.port, args.webUiPort, conf)
actorSystem.awaitTermination()
}
- /** Returns an `akka://...` URL for the Master actor given a sparkUrl `spark://host:ip`. */
+ /** Returns an `akka.tcp://...` URL for the Master actor given a sparkUrl `spark://host:ip`. */
def toAkkaUrl(sparkUrl: String): String = {
sparkUrl match {
case sparkUrlRegex(host, port) =>
- "akka://%s@%s:%s/user/%s".format(systemName, host, port, actorName)
+ "akka.tcp://%s@%s:%s/user/%s".format(systemName, host, port, actorName)
case _ =>
throw new SparkException("Invalid master URL: " + sparkUrl)
}
}
- def startSystemAndActor(host: String, port: Int, webUiPort: Int): (ActorSystem, Int, Int) = {
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port)
- val actor = actorSystem.actorOf(Props(new Master(host, boundPort, webUiPort)), name = actorName)
- val timeoutDuration = Duration.create(
- System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
- implicit val timeout = Timeout(timeoutDuration)
- val respFuture = actor ? RequestWebUIPort // ask pattern
- val resp = Await.result(respFuture, timeoutDuration).asInstanceOf[WebUIPortResponse]
+ def startSystemAndActor(host: String, port: Int, webUiPort: Int, conf: SparkConf)
+ : (ActorSystem, Int, Int) =
+ {
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, conf = conf)
+ val actor = actorSystem.actorOf(Props(classOf[Master], host, boundPort, webUiPort), actorName)
+ val timeout = AkkaUtils.askTimeout(conf)
+ val respFuture = actor.ask(RequestWebUIPort)(timeout)
+ val resp = Await.result(respFuture, timeout).asInstanceOf[WebUIPortResponse]
(actorSystem, boundPort, resp.webUIBoundPort)
}
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala
index 9d89b455fb..e7f3224091 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala
@@ -18,16 +18,17 @@
package org.apache.spark.deploy.master
import org.apache.spark.util.{Utils, IntParam}
+import org.apache.spark.SparkConf
/**
* Command-line parser for the master.
*/
-private[spark] class MasterArguments(args: Array[String]) {
+private[spark] class MasterArguments(args: Array[String], conf: SparkConf) {
var host = Utils.localHostName()
var port = 7077
var webUiPort = 8080
-
- // Check for settings in environment variables
+
+ // Check for settings in environment variables
if (System.getenv("SPARK_MASTER_HOST") != null) {
host = System.getenv("SPARK_MASTER_HOST")
}
@@ -37,8 +38,8 @@ private[spark] class MasterArguments(args: Array[String]) {
if (System.getenv("SPARK_MASTER_WEBUI_PORT") != null) {
webUiPort = System.getenv("SPARK_MASTER_WEBUI_PORT").toInt
}
- if (System.getProperty("master.ui.port") != null) {
- webUiPort = System.getProperty("master.ui.port").toInt
+ if (conf.contains("master.ui.port")) {
+ webUiPort = conf.get("master.ui.port").toInt
}
parse(args.toList)
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryState.scala b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryState.scala
index b91be821f0..256a5a7c28 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryState.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryState.scala
@@ -17,9 +17,7 @@
package org.apache.spark.deploy.master
-private[spark] object RecoveryState
- extends Enumeration("STANDBY", "ALIVE", "RECOVERING", "COMPLETING_RECOVERY") {
-
+private[spark] object RecoveryState extends Enumeration {
type MasterState = Value
val STANDBY, ALIVE, RECOVERING, COMPLETING_RECOVERY = Value
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/SparkZooKeeperSession.scala b/core/src/main/scala/org/apache/spark/deploy/master/SparkZooKeeperSession.scala
index 81e15c534f..999090ad74 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/SparkZooKeeperSession.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/SparkZooKeeperSession.scala
@@ -18,12 +18,12 @@
package org.apache.spark.deploy.master
import scala.collection.JavaConversions._
-import scala.concurrent.ops._
-import org.apache.spark.Logging
import org.apache.zookeeper._
-import org.apache.zookeeper.data.Stat
import org.apache.zookeeper.Watcher.Event.KeeperState
+import org.apache.zookeeper.data.Stat
+
+import org.apache.spark.{SparkConf, Logging}
/**
* Provides a Scala-side interface to the standard ZooKeeper client, with the addition of retry
@@ -33,10 +33,11 @@ import org.apache.zookeeper.Watcher.Event.KeeperState
* informed via zkDown().
*
* Additionally, all commands sent to ZooKeeper will be retried until they either fail too many
- * times or a semantic exception is thrown (e.g.., "node already exists").
+ * times or a semantic exception is thrown (e.g., "node already exists").
*/
-private[spark] class SparkZooKeeperSession(zkWatcher: SparkZooKeeperWatcher) extends Logging {
- val ZK_URL = System.getProperty("spark.deploy.zookeeper.url", "")
+private[spark] class SparkZooKeeperSession(zkWatcher: SparkZooKeeperWatcher,
+ conf: SparkConf) extends Logging {
+ val ZK_URL = conf.get("spark.deploy.zookeeper.url", "")
val ZK_ACL = ZooDefs.Ids.OPEN_ACL_UNSAFE
val ZK_TIMEOUT_MILLIS = 30000
@@ -103,6 +104,7 @@ private[spark] class SparkZooKeeperSession(zkWatcher: SparkZooKeeperWatcher) ext
connectToZooKeeper()
case KeeperState.Disconnected =>
logWarning("ZooKeeper disconnected, will retry...")
+ case s => // Do nothing
}
}
}
@@ -179,7 +181,7 @@ private[spark] class SparkZooKeeperSession(zkWatcher: SparkZooKeeperWatcher) ext
} catch {
case e: KeeperException.NoNodeException => throw e
case e: KeeperException.NodeExistsException => throw e
- case e if n > 0 =>
+ case e: Exception if n > 0 =>
logError("ZooKeeper exception, " + n + " more retries...", e)
Thread.sleep(RETRY_WAIT_MILLIS)
retry(fn, n-1)
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerState.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerState.scala
index c8d34f25e2..0b36ef6005 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerState.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerState.scala
@@ -17,9 +17,7 @@
package org.apache.spark.deploy.master
-private[spark] object WorkerState
- extends Enumeration("ALIVE", "DEAD", "DECOMMISSIONED", "UNKNOWN") {
-
+private[spark] object WorkerState extends Enumeration {
type WorkerState = Value
val ALIVE, DEAD, DECOMMISSIONED, UNKNOWN = Value
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala
index 7809013e83..77c23fb9fb 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala
@@ -21,16 +21,17 @@ import akka.actor.ActorRef
import org.apache.zookeeper._
import org.apache.zookeeper.Watcher.Event.EventType
+import org.apache.spark.{SparkConf, Logging}
import org.apache.spark.deploy.master.MasterMessages._
-import org.apache.spark.Logging
-private[spark] class ZooKeeperLeaderElectionAgent(val masterActor: ActorRef, masterUrl: String)
+private[spark] class ZooKeeperLeaderElectionAgent(val masterActor: ActorRef,
+ masterUrl: String, conf: SparkConf)
extends LeaderElectionAgent with SparkZooKeeperWatcher with Logging {
- val WORKING_DIR = System.getProperty("spark.deploy.zookeeper.dir", "/spark") + "/leader_election"
+ val WORKING_DIR = conf.get("spark.deploy.zookeeper.dir", "/spark") + "/leader_election"
private val watcher = new ZooKeeperWatcher()
- private val zk = new SparkZooKeeperSession(this)
+ private val zk = new SparkZooKeeperSession(this, conf)
private var status = LeadershipStatus.NOT_LEADER
private var myLeaderFile: String = _
private var leaderUrl: String = _
@@ -105,7 +106,7 @@ private[spark] class ZooKeeperLeaderElectionAgent(val masterActor: ActorRef, mas
// We found a different master file pointing to this process.
// This can happen in the following two cases:
// (1) The master process was restarted on the same node.
- // (2) The ZK server died between creating the node and returning the name of the node.
+ // (2) The ZK server died between creating the file and returning the name of the file.
// For this case, we will end up creating a second file, and MUST explicitly delete the
// first one, since our ZK session is still open.
// Note that this deletion will cause a NodeDeleted event to be fired so we check again for
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
index a0233a7271..52000d4f9c 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
@@ -17,19 +17,19 @@
package org.apache.spark.deploy.master
-import org.apache.spark.Logging
+import org.apache.spark.{SparkConf, Logging}
import org.apache.zookeeper._
import akka.serialization.Serialization
-class ZooKeeperPersistenceEngine(serialization: Serialization)
+class ZooKeeperPersistenceEngine(serialization: Serialization, conf: SparkConf)
extends PersistenceEngine
with SparkZooKeeperWatcher
with Logging
{
- val WORKING_DIR = System.getProperty("spark.deploy.zookeeper.dir", "/spark") + "/master_status"
+ val WORKING_DIR = conf.get("spark.deploy.zookeeper.dir", "/spark") + "/master_status"
- val zk = new SparkZooKeeperSession(this)
+ val zk = new SparkZooKeeperSession(this, conf)
zk.connect()
@@ -70,15 +70,15 @@ class ZooKeeperPersistenceEngine(serialization: Serialization)
(apps, workers)
}
- private def serializeIntoFile(path: String, value: Serializable) {
+ private def serializeIntoFile(path: String, value: AnyRef) {
val serializer = serialization.findSerializerFor(value)
val serialized = serializer.toBinary(value)
zk.create(path, serialized, CreateMode.PERSISTENT)
}
- def deserializeFromFile[T <: Serializable](filename: String)(implicit m: Manifest[T]): T = {
+ def deserializeFromFile[T](filename: String)(implicit m: Manifest[T]): T = {
val fileData = zk.getData("/spark/master_status/" + filename)
- val clazz = m.erasure.asInstanceOf[Class[T]]
+ val clazz = m.runtimeClass.asInstanceOf[Class[T]]
val serializer = serialization.serializerFor(clazz)
serializer.fromBinary(fileData).asInstanceOf[T]
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
index f4e574d15d..dbb0cb90f5 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
@@ -17,31 +17,28 @@
package org.apache.spark.deploy.master.ui
+import scala.concurrent.Await
import scala.xml.Node
-import akka.dispatch.Await
import akka.pattern.ask
-import akka.util.duration._
-
import javax.servlet.http.HttpServletRequest
-
import net.liftweb.json.JsonAST.JValue
-import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState}
import org.apache.spark.deploy.JsonProtocol
+import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState}
import org.apache.spark.deploy.master.ExecutorInfo
import org.apache.spark.ui.UIUtils
import org.apache.spark.util.Utils
private[spark] class ApplicationPage(parent: MasterWebUI) {
val master = parent.masterActorRef
- implicit val timeout = parent.timeout
+ val timeout = parent.timeout
/** Executor details for a particular application */
def renderJson(request: HttpServletRequest): JValue = {
val appId = request.getParameter("appId")
val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse]
- val state = Await.result(stateFuture, 30 seconds)
+ val state = Await.result(stateFuture, timeout)
val app = state.activeApps.find(_.id == appId).getOrElse({
state.completedApps.find(_.id == appId).getOrElse(null)
})
@@ -52,7 +49,7 @@ private[spark] class ApplicationPage(parent: MasterWebUI) {
def render(request: HttpServletRequest): Seq[Node] = {
val appId = request.getParameter("appId")
val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse]
- val state = Await.result(stateFuture, 30 seconds)
+ val state = Await.result(stateFuture, timeout)
val app = state.activeApps.find(_.id == appId).getOrElse({
state.completedApps.find(_.id == appId).getOrElse(null)
})
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala
index d7a57229b0..4ef762892c 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala
@@ -17,37 +17,33 @@
package org.apache.spark.deploy.master.ui
-import javax.servlet.http.HttpServletRequest
-
+import scala.concurrent.Await
import scala.xml.Node
-import akka.dispatch.Await
import akka.pattern.ask
-import akka.util.duration._
-
+import javax.servlet.http.HttpServletRequest
import net.liftweb.json.JsonAST.JValue
-import org.apache.spark.deploy.DeployWebUI
+import org.apache.spark.deploy.{DeployWebUI, JsonProtocol}
import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState}
-import org.apache.spark.deploy.JsonProtocol
import org.apache.spark.deploy.master.{ApplicationInfo, WorkerInfo}
import org.apache.spark.ui.UIUtils
import org.apache.spark.util.Utils
private[spark] class IndexPage(parent: MasterWebUI) {
val master = parent.masterActorRef
- implicit val timeout = parent.timeout
+ val timeout = parent.timeout
def renderJson(request: HttpServletRequest): JValue = {
val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse]
- val state = Await.result(stateFuture, 30 seconds)
+ val state = Await.result(stateFuture, timeout)
JsonProtocol.writeMasterState(state)
}
/** Index view listing applications and executors */
def render(request: HttpServletRequest): Seq[Node] = {
val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse]
- val state = Await.result(stateFuture, 30 seconds)
+ val state = Await.result(stateFuture, timeout)
val workerHeaders = Seq("Id", "Address", "State", "Cores", "Memory")
val workers = state.workers.sortBy(_.id)
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
index f4df729e87..ead35662fc 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
@@ -17,25 +17,21 @@
package org.apache.spark.deploy.master.ui
-import akka.util.Duration
-
import javax.servlet.http.HttpServletRequest
-
import org.eclipse.jetty.server.{Handler, Server}
-import org.apache.spark.{Logging}
+import org.apache.spark.Logging
import org.apache.spark.deploy.master.Master
import org.apache.spark.ui.JettyUtils
import org.apache.spark.ui.JettyUtils._
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{AkkaUtils, Utils}
/**
* Web UI server for the standalone master.
*/
private[spark]
class MasterWebUI(val master: Master, requestedPort: Int) extends Logging {
- implicit val timeout = Duration.create(
- System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
+ val timeout = AkkaUtils.askTimeout(master.conf)
val host = Utils.localHostName()
val port = requestedPort
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
index 216d9d44ac..fcaf4e92b1 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
@@ -17,23 +17,22 @@
package org.apache.spark.deploy.worker
+import java.io.File
import java.text.SimpleDateFormat
import java.util.Date
-import java.io.File
import scala.collection.mutable.HashMap
+import scala.concurrent.duration._
import akka.actor._
-import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientShutdown, RemoteClientDisconnected}
-import akka.util.duration._
-
-import org.apache.spark.Logging
+import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent}
+import org.apache.spark.{Logging, SparkConf, SparkException}
import org.apache.spark.deploy.{ExecutorDescription, ExecutorState}
import org.apache.spark.deploy.DeployMessages._
import org.apache.spark.deploy.master.Master
import org.apache.spark.deploy.worker.ui.WorkerWebUI
import org.apache.spark.metrics.MetricsSystem
-import org.apache.spark.util.{Utils, AkkaUtils}
+import org.apache.spark.util.{AkkaUtils, Utils}
/**
* @param masterUrls Each url should look like spark://host:port.
@@ -45,8 +44,10 @@ private[spark] class Worker(
cores: Int,
memory: Int,
masterUrls: Array[String],
- workDirPath: String = null)
+ workDirPath: String = null,
+ val conf: SparkConf)
extends Actor with Logging {
+ import context.dispatcher
Utils.checkHost(host, "Expected hostname")
assert (port > 0)
@@ -54,7 +55,7 @@ private[spark] class Worker(
val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For worker and executor IDs
// Send a heartbeat every (heartbeat timeout) / 4 milliseconds
- val HEARTBEAT_MILLIS = System.getProperty("spark.worker.timeout", "60").toLong * 1000 / 4
+ val HEARTBEAT_MILLIS = conf.get("spark.worker.timeout", "60").toLong * 1000 / 4
val REGISTRATION_TIMEOUT = 20.seconds
val REGISTRATION_RETRIES = 3
@@ -63,7 +64,8 @@ private[spark] class Worker(
var masterIndex = 0
val masterLock: Object = new Object()
- var master: ActorRef = null
+ var master: ActorSelection = null
+ var masterAddress: Address = null
var activeMasterUrl: String = ""
var activeMasterWebUiUrl : String = ""
@volatile var registered = false
@@ -82,7 +84,7 @@ private[spark] class Worker(
var coresUsed = 0
var memoryUsed = 0
- val metricsSystem = MetricsSystem.createMetricsSystem("worker")
+ val metricsSystem = MetricsSystem.createMetricsSystem("worker", conf)
val workerSource = new WorkerSource(this)
def coresFree: Int = cores - coresUsed
@@ -114,7 +116,7 @@ private[spark] class Worker(
logInfo("Spark home: " + sparkHome)
createWorkDir()
webUi = new WorkerWebUI(this, workDir, Some(webUiPort))
-
+ context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
webUi.start()
registerWithMaster()
@@ -126,9 +128,13 @@ private[spark] class Worker(
masterLock.synchronized {
activeMasterUrl = url
activeMasterWebUiUrl = uiUrl
- master = context.actorFor(Master.toAkkaUrl(activeMasterUrl))
- context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
- context.watch(master) // Doesn't work with remote actors, but useful for testing
+ master = context.actorSelection(Master.toAkkaUrl(activeMasterUrl))
+ masterAddress = activeMasterUrl match {
+ case Master.sparkUrlRegex(_host, _port) =>
+ Address("akka.tcp", Master.systemName, _host, _port.toInt)
+ case x =>
+ throw new SparkException("Invalid spark URL: " + x)
+ }
connected = true
}
}
@@ -136,7 +142,7 @@ private[spark] class Worker(
def tryRegisterAllMasters() {
for (masterUrl <- masterUrls) {
logInfo("Connecting to master " + masterUrl + "...")
- val actor = context.actorFor(Master.toAkkaUrl(masterUrl))
+ val actor = context.actorSelection(Master.toAkkaUrl(masterUrl))
actor ! RegisterWorker(workerId, host, port, cores, memory, webUi.boundPort.get,
publicAddress)
}
@@ -175,7 +181,6 @@ private[spark] class Worker(
case MasterChanged(masterUrl, masterWebUiUrl) =>
logInfo("Master has changed, new master is at " + masterUrl)
- context.unwatch(master)
changeMaster(masterUrl, masterWebUiUrl)
val execs = executors.values.
@@ -234,13 +239,8 @@ private[spark] class Worker(
}
}
- case Terminated(actor_) if actor_ == master =>
- masterDisconnected()
-
- case RemoteClientDisconnected(transport, address) if address == master.path.address =>
- masterDisconnected()
-
- case RemoteClientShutdown(transport, address) if address == master.path.address =>
+ case x: DisassociatedEvent if x.remoteAddress == masterAddress =>
+ logInfo(s"$x Disassociated !")
masterDisconnected()
case RequestWorkerState => {
@@ -267,6 +267,7 @@ private[spark] class Worker(
}
private[spark] object Worker {
+
def main(argStrings: Array[String]) {
val args = new WorkerArguments(argStrings)
val (actorSystem, _) = startSystemAndActor(args.host, args.port, args.webUiPort, args.cores,
@@ -275,13 +276,16 @@ private[spark] object Worker {
}
def startSystemAndActor(host: String, port: Int, webUiPort: Int, cores: Int, memory: Int,
- masterUrls: Array[String], workDir: String, workerNumber: Option[Int] = None)
- : (ActorSystem, Int) = {
+ masterUrls: Array[String], workDir: String, workerNumber: Option[Int] = None)
+ : (ActorSystem, Int) =
+ {
// The LocalSparkCluster runs multiple local sparkWorkerX actor systems
+ val conf = new SparkConf
val systemName = "sparkWorker" + workerNumber.map(_.toString).getOrElse("")
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port)
- val actor = actorSystem.actorOf(Props(new Worker(host, boundPort, webUiPort, cores, memory,
- masterUrls, workDir)), name = "Worker")
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port,
+ conf = conf)
+ actorSystem.actorOf(Props(classOf[Worker], host, boundPort, webUiPort, cores, memory,
+ masterUrls, workDir, conf), name = "Worker")
(actorSystem, boundPort)
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala
index d2d3617498..0d59048313 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala
@@ -21,9 +21,10 @@ import javax.servlet.http.HttpServletRequest
import scala.xml.Node
-import akka.dispatch.Await
+import scala.concurrent.duration._
+import scala.concurrent.Await
+
import akka.pattern.ask
-import akka.util.duration._
import net.liftweb.json.JsonAST.JValue
@@ -41,13 +42,13 @@ private[spark] class IndexPage(parent: WorkerWebUI) {
def renderJson(request: HttpServletRequest): JValue = {
val stateFuture = (workerActor ? RequestWorkerState)(timeout).mapTo[WorkerStateResponse]
- val workerState = Await.result(stateFuture, 30 seconds)
+ val workerState = Await.result(stateFuture, timeout)
JsonProtocol.writeWorkerState(workerState)
}
def render(request: HttpServletRequest): Seq[Node] = {
val stateFuture = (workerActor ? RequestWorkerState)(timeout).mapTo[WorkerStateResponse]
- val workerState = Await.result(stateFuture, 30 seconds)
+ val workerState = Await.result(stateFuture, timeout)
val executorHeaders = Seq("ExecutorID", "Cores", "Memory", "Job Details", "Logs")
val runningExecutorTable =
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
index 800f1cafcc..c382034c99 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
@@ -17,20 +17,16 @@
package org.apache.spark.deploy.worker.ui
-import akka.util.{Duration, Timeout}
-
-import java.io.{FileInputStream, File}
+import java.io.File
import javax.servlet.http.HttpServletRequest
-
import org.eclipse.jetty.server.{Handler, Server}
+import org.apache.spark.{Logging, SparkConf}
import org.apache.spark.deploy.worker.Worker
-import org.apache.spark.{Logging}
-import org.apache.spark.ui.JettyUtils
+import org.apache.spark.ui.{JettyUtils, UIUtils}
import org.apache.spark.ui.JettyUtils._
-import org.apache.spark.ui.UIUtils
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{AkkaUtils, Utils}
/**
* Web UI server for the standalone worker.
@@ -38,11 +34,10 @@ import org.apache.spark.util.Utils
private[spark]
class WorkerWebUI(val worker: Worker, val workDir: File, requestedPort: Option[Int] = None)
extends Logging {
- implicit val timeout = Timeout(
- Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds"))
+ val timeout = AkkaUtils.askTimeout(worker.conf)
val host = Utils.localHostName()
val port = requestedPort.getOrElse(
- System.getProperty("worker.ui.port", WorkerWebUI.DEFAULT_PORT).toInt)
+ worker.conf.get("worker.ui.port", WorkerWebUI.DEFAULT_PORT).toInt)
var server: Option[Server] = None
var boundPort: Option[Int] = None
@@ -145,12 +140,12 @@ class WorkerWebUI(val worker: Worker, val workDir: File, requestedPort: Option[I
<body>
{linkToMaster}
<div>
- <div style="float:left;width:40%">{backButton}</div>
+ <div style="float:left; margin-right:10px">{backButton}</div>
<div style="float:left;">{range}</div>
- <div style="float:right;">{nextButton}</div>
+ <div style="float:right; margin-left:10px">{nextButton}</div>
</div>
<br />
- <div style="height:500px;overflow:auto;padding:5px;">
+ <div style="height:500px; overflow:auto; padding:5px;">
<pre>{logText}</pre>
</div>
</body>
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index 8332631838..53a2b94a52 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -19,15 +19,14 @@ package org.apache.spark.executor
import java.nio.ByteBuffer
-import akka.actor.{ActorRef, Actor, Props, Terminated}
-import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientShutdown, RemoteClientDisconnected}
+import akka.actor._
+import akka.remote._
-import org.apache.spark.Logging
+import org.apache.spark.{SparkConf, SparkContext, Logging}
import org.apache.spark.TaskState.TaskState
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
import org.apache.spark.util.{Utils, AkkaUtils}
-
private[spark] class CoarseGrainedExecutorBackend(
driverUrl: String,
executorId: String,
@@ -40,14 +39,13 @@ private[spark] class CoarseGrainedExecutorBackend(
Utils.checkHostPort(hostPort, "Expected hostport")
var executor: Executor = null
- var driver: ActorRef = null
+ var driver: ActorSelection = null
override def preStart() {
logInfo("Connecting to driver: " + driverUrl)
- driver = context.actorFor(driverUrl)
+ driver = context.actorSelection(driverUrl)
driver ! RegisterExecutor(executorId, hostPort, cores)
- context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
- context.watch(driver) // Doesn't work with remote actors, but useful for testing
+ context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
}
override def receive = {
@@ -77,8 +75,8 @@ private[spark] class CoarseGrainedExecutorBackend(
executor.killTask(taskId)
}
- case Terminated(_) | RemoteClientDisconnected(_, _) | RemoteClientShutdown(_, _) =>
- logError("Driver terminated or disconnected! Shutting down.")
+ case x: DisassociatedEvent =>
+ logError(s"Driver $x disassociated! Shutting down.")
System.exit(1)
case StopExecutor =>
@@ -99,12 +97,13 @@ private[spark] object CoarseGrainedExecutorBackend {
// Create a new ActorSystem to run the backend, because we can't create a SparkEnv / Executor
// before getting started with all our system properties, etc
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem("sparkExecutor", hostname, 0)
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem("sparkExecutor", hostname, 0,
+ indestructible = true, conf = new SparkConf)
// set it
val sparkHostPort = hostname + ":" + boundPort
- System.setProperty("spark.hostPort", sparkHostPort)
- val actor = actorSystem.actorOf(
- Props(new CoarseGrainedExecutorBackend(driverUrl, executorId, sparkHostPort, cores)),
+// conf.set("spark.hostPort", sparkHostPort)
+ actorSystem.actorOf(
+ Props(classOf[CoarseGrainedExecutorBackend], driverUrl, executorId, sparkHostPort, cores),
name = "Executor")
actorSystem.awaitTermination()
}
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 5c9bb9db1c..e51d274d33 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -48,8 +48,6 @@ private[spark] class Executor(
private val EMPTY_BYTE_BUFFER = ByteBuffer.wrap(new Array[Byte](0))
- initLogging()
-
// No ip or host:port - just hostname
Utils.checkHost(slaveHostname, "Expected executed slave to be a hostname")
// must not have port specified.
@@ -58,16 +56,17 @@ private[spark] class Executor(
// Make sure the local hostname we report matches the cluster scheduler's name for this host
Utils.setCustomHostname(slaveHostname)
- // Set spark.* system properties from executor arg
- for ((key, value) <- properties) {
- System.setProperty(key, value)
- }
+ // Set spark.* properties from executor arg
+ val conf = new SparkConf(false)
+ conf.setAll(properties)
// If we are in yarn mode, systems can have different disk layouts so we must set it
// to what Yarn on this system said was available. This will be used later when SparkEnv
// created.
- if (java.lang.Boolean.valueOf(System.getenv("SPARK_YARN_MODE"))) {
- System.setProperty("spark.local.dir", getYarnLocalDirs())
+ if (java.lang.Boolean.valueOf(
+ System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE"))))
+ {
+ conf.set("spark.local.dir", getYarnLocalDirs())
}
// Create our ClassLoader and set it on this thread
@@ -108,7 +107,7 @@ private[spark] class Executor(
// Initialize Spark environment (using system properties read above)
private val env = {
if (!isLocal) {
- val _env = SparkEnv.createFromSystemProperties(executorId, slaveHostname, 0,
+ val _env = SparkEnv.create(conf, executorId, slaveHostname, 0,
isDriver = false, isLocal = false)
SparkEnv.set(_env)
_env.metricsSystem.registerSource(executorSource)
@@ -121,7 +120,7 @@ private[spark] class Executor(
// Akka's message frame size. If task result is bigger than this, we use the block manager
// to send the result back.
private val akkaFrameSize = {
- env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size")
+ env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size")
}
// Start worker thread pool
@@ -142,11 +141,6 @@ private[spark] class Executor(
val tr = runningTasks.get(taskId)
if (tr != null) {
tr.kill()
- // We remove the task also in the finally block in TaskRunner.run.
- // The reason we need to remove it here is because killTask might be called before the task
- // is even launched, and never reaching that finally block. ConcurrentHashMap's remove is
- // idempotent.
- runningTasks.remove(taskId)
}
}
@@ -168,6 +162,8 @@ private[spark] class Executor(
class TaskRunner(execBackend: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer)
extends Runnable {
+ object TaskKilledException extends Exception
+
@volatile private var killed = false
@volatile private var task: Task[Any] = _
@@ -201,9 +197,11 @@ private[spark] class Executor(
// If this task has been killed before we deserialized it, let's quit now. Otherwise,
// continue executing the task.
if (killed) {
- logInfo("Executor killed task " + taskId)
- execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
- return
+ // Throw an exception rather than returning, because returning within a try{} block
+ // causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl
+ // exception will be caught by the catch block, leading to an incorrect ExceptionFailure
+ // for the task.
+ throw TaskKilledException
}
attemptedTask = Some(task)
@@ -217,23 +215,25 @@ private[spark] class Executor(
// If the task has been killed, let's fail it.
if (task.killed) {
- logInfo("Executor killed task " + taskId)
- execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
- return
+ throw TaskKilledException
}
+ val resultSer = SparkEnv.get.serializer.newInstance()
+ val beforeSerialization = System.currentTimeMillis()
+ val valueBytes = resultSer.serialize(value)
+ val afterSerialization = System.currentTimeMillis()
+
for (m <- task.metrics) {
m.hostname = Utils.localHostName()
m.executorDeserializeTime = (taskStart - startTime).toInt
m.executorRunTime = (taskFinish - taskStart).toInt
m.jvmGCTime = gcTime - startGCTime
+ m.resultSerializationTime = (afterSerialization - beforeSerialization).toInt
}
- // TODO I'd also like to track the time it takes to serialize the task results, but that is
- // huge headache, b/c we need to serialize the task metrics first. If TaskMetrics had a
- // custom serialized format, we could just change the relevants bytes in the byte buffer
+
val accumUpdates = Accumulators.values
- val directResult = new DirectTaskResult(value, accumUpdates, task.metrics.getOrElse(null))
+ val directResult = new DirectTaskResult(valueBytes, accumUpdates, task.metrics.getOrElse(null))
val serializedDirectResult = ser.serialize(directResult)
logInfo("Serialized size of result for " + taskId + " is " + serializedDirectResult.limit)
val serializedResult = {
@@ -257,6 +257,11 @@ private[spark] class Executor(
execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
}
+ case TaskKilledException => {
+ logInfo("Executor killed task " + taskId)
+ execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
+ }
+
case t: Throwable => {
val serviceTime = (System.currentTimeMillis() - taskStart).toInt
val metrics = attemptedTask.flatMap(t => t.metrics)
@@ -299,7 +304,7 @@ private[spark] class Executor(
* new classes defined by the REPL as the user types code
*/
private def addReplClassLoaderIfNeeded(parent: ClassLoader): ClassLoader = {
- val classUri = System.getProperty("spark.repl.class.uri")
+ val classUri = conf.get("spark.repl.class.uri", null)
if (classUri != null) {
logInfo("Using REPL class URI: " + classUri)
try {
@@ -327,12 +332,12 @@ private[spark] class Executor(
// Fetch missing dependencies
for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
logInfo("Fetching " + name + " with timestamp " + timestamp)
- Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
+ Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf)
currentFiles(name) = timestamp
}
for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
logInfo("Fetching " + name + " with timestamp " + timestamp)
- Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
+ Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf)
currentJars(name) = timestamp
// Add it to our class loader
val localName = name.split("/").last
diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
index 0b4892f98f..bb1471d9ee 100644
--- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -44,6 +44,11 @@ class TaskMetrics extends Serializable {
var jvmGCTime: Long = _
/**
+ * Amount of time spent serializing the task result
+ */
+ var resultSerializationTime: Long = _
+
+ /**
* If this task reads from shuffle output, metrics on getting shuffle data will be collected here
*/
var shuffleReadMetrics: Option[ShuffleReadMetrics] = None
@@ -61,50 +66,53 @@ object TaskMetrics {
class ShuffleReadMetrics extends Serializable {
/**
- * Time when shuffle finishs
+ * Absolute time when this task finished reading shuffle data
*/
var shuffleFinishTime: Long = _
/**
- * Total number of blocks fetched in a shuffle (remote or local)
+ * Number of blocks fetched in this shuffle by this task (remote or local)
*/
var totalBlocksFetched: Int = _
/**
- * Number of remote blocks fetched in a shuffle
+ * Number of remote blocks fetched in this shuffle by this task
*/
var remoteBlocksFetched: Int = _
/**
- * Local blocks fetched in a shuffle
+ * Number of local blocks fetched in this shuffle by this task
*/
var localBlocksFetched: Int = _
/**
- * Total time that is spent blocked waiting for shuffle to fetch data
+ * Time the task spent waiting for remote shuffle blocks. This only includes the time
+ * blocking on shuffle input data. For instance if block B is being fetched while the task is
+ * still not finished processing block A, it is not considered to be blocking on block B.
*/
var fetchWaitTime: Long = _
/**
- * The total amount of time for all the shuffle fetches. This adds up time from overlapping
- * shuffles, so can be longer than task time
+ * Total time spent fetching remote shuffle blocks. This aggregates the time spent fetching all
+ * input blocks. Since block fetches are both pipelined and parallelized, this can
+ * exceed fetchWaitTime and executorRunTime.
*/
var remoteFetchTime: Long = _
/**
- * Total number of remote bytes read from a shuffle
+ * Total number of remote bytes read from the shuffle by this task
*/
var remoteBytesRead: Long = _
}
class ShuffleWriteMetrics extends Serializable {
/**
- * Number of bytes written for a shuffle
+ * Number of bytes written for the shuffle by this task
*/
var shuffleBytesWritten: Long = _
/**
- * Time spent blocking on writes to disk or buffer cache, in nanoseconds.
+ * Time the task spent blocking on writes to disk or buffer cache, in nanoseconds
*/
var shuffleWriteTime: Long = _
}
diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala
index 570a979b56..a1e98845f6 100644
--- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala
+++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala
@@ -22,6 +22,7 @@ import java.io.{InputStream, OutputStream}
import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream}
import org.xerial.snappy.{SnappyInputStream, SnappyOutputStream}
+import org.apache.spark.{SparkEnv, SparkConf}
/**
@@ -37,15 +38,15 @@ trait CompressionCodec {
private[spark] object CompressionCodec {
-
- def createCodec(): CompressionCodec = {
- createCodec(System.getProperty(
+ def createCodec(conf: SparkConf): CompressionCodec = {
+ createCodec(conf, conf.get(
"spark.io.compression.codec", classOf[LZFCompressionCodec].getName))
}
- def createCodec(codecName: String): CompressionCodec = {
- Class.forName(codecName, true, Thread.currentThread.getContextClassLoader)
- .newInstance().asInstanceOf[CompressionCodec]
+ def createCodec(conf: SparkConf, codecName: String): CompressionCodec = {
+ val ctor = Class.forName(codecName, true, Thread.currentThread.getContextClassLoader)
+ .getConstructor(classOf[SparkConf])
+ ctor.newInstance(conf).asInstanceOf[CompressionCodec]
}
}
@@ -53,7 +54,7 @@ private[spark] object CompressionCodec {
/**
* LZF implementation of [[org.apache.spark.io.CompressionCodec]].
*/
-class LZFCompressionCodec extends CompressionCodec {
+class LZFCompressionCodec(conf: SparkConf) extends CompressionCodec {
override def compressedOutputStream(s: OutputStream): OutputStream = {
new LZFOutputStream(s).setFinishBlockOnFlush(true)
@@ -67,10 +68,10 @@ class LZFCompressionCodec extends CompressionCodec {
* Snappy implementation of [[org.apache.spark.io.CompressionCodec]].
* Block size can be configured by spark.io.compression.snappy.block.size.
*/
-class SnappyCompressionCodec extends CompressionCodec {
+class SnappyCompressionCodec(conf: SparkConf) extends CompressionCodec {
override def compressedOutputStream(s: OutputStream): OutputStream = {
- val blockSize = System.getProperty("spark.io.compression.snappy.block.size", "32768").toInt
+ val blockSize = conf.get("spark.io.compression.snappy.block.size", "32768").toInt
new SnappyOutputStream(s, blockSize)
}
diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala
index caab748d60..6f9f29969e 100644
--- a/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala
@@ -26,7 +26,6 @@ import scala.util.matching.Regex
import org.apache.spark.Logging
private[spark] class MetricsConfig(val configFile: Option[String]) extends Logging {
- initLogging()
val DEFAULT_PREFIX = "*"
val INSTANCE_REGEX = "^(\\*|[a-zA-Z]+)\\.(.+)".r
diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
index bec0c83be8..9930537b34 100644
--- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
@@ -24,7 +24,7 @@ import java.util.concurrent.TimeUnit
import scala.collection.mutable
-import org.apache.spark.Logging
+import org.apache.spark.{SparkConf, Logging}
import org.apache.spark.metrics.sink.{MetricsServlet, Sink}
import org.apache.spark.metrics.source.Source
@@ -62,10 +62,10 @@ import org.apache.spark.metrics.source.Source
*
* [options] is the specific property of this source or sink.
*/
-private[spark] class MetricsSystem private (val instance: String) extends Logging {
- initLogging()
+private[spark] class MetricsSystem private (val instance: String,
+ conf: SparkConf) extends Logging {
- val confFile = System.getProperty("spark.metrics.conf")
+ val confFile = conf.get("spark.metrics.conf", null)
val metricsConfig = new MetricsConfig(Option(confFile))
val sinks = new mutable.ArrayBuffer[Sink]
@@ -159,5 +159,6 @@ private[spark] object MetricsSystem {
}
}
- def createMetricsSystem(instance: String): MetricsSystem = new MetricsSystem(instance)
+ def createMetricsSystem(instance: String, conf: SparkConf): MetricsSystem =
+ new MetricsSystem(instance, conf)
}
diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
index 9c2fee4023..46c40d0a2a 100644
--- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
+++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
@@ -31,13 +31,13 @@ import scala.collection.mutable.SynchronizedMap
import scala.collection.mutable.SynchronizedQueue
import scala.collection.mutable.ArrayBuffer
-import akka.dispatch.{Await, Promise, ExecutionContext, Future}
-import akka.util.Duration
-import akka.util.duration._
-import org.apache.spark.util.Utils
+import scala.concurrent.{Await, Promise, ExecutionContext, Future}
+import scala.concurrent.duration.Duration
+import scala.concurrent.duration._
+import org.apache.spark.util.Utils
-private[spark] class ConnectionManager(port: Int) extends Logging {
+private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Logging {
class MessageStatus(
val message: Message,
@@ -54,22 +54,22 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
private val selector = SelectorProvider.provider.openSelector()
private val handleMessageExecutor = new ThreadPoolExecutor(
- System.getProperty("spark.core.connection.handler.threads.min","20").toInt,
- System.getProperty("spark.core.connection.handler.threads.max","60").toInt,
- System.getProperty("spark.core.connection.handler.threads.keepalive","60").toInt, TimeUnit.SECONDS,
+ conf.get("spark.core.connection.handler.threads.min", "20").toInt,
+ conf.get("spark.core.connection.handler.threads.max", "60").toInt,
+ conf.get("spark.core.connection.handler.threads.keepalive", "60").toInt, TimeUnit.SECONDS,
new LinkedBlockingDeque[Runnable]())
private val handleReadWriteExecutor = new ThreadPoolExecutor(
- System.getProperty("spark.core.connection.io.threads.min","4").toInt,
- System.getProperty("spark.core.connection.io.threads.max","32").toInt,
- System.getProperty("spark.core.connection.io.threads.keepalive","60").toInt, TimeUnit.SECONDS,
+ conf.get("spark.core.connection.io.threads.min", "4").toInt,
+ conf.get("spark.core.connection.io.threads.max", "32").toInt,
+ conf.get("spark.core.connection.io.threads.keepalive", "60").toInt, TimeUnit.SECONDS,
new LinkedBlockingDeque[Runnable]())
// Use a different, yet smaller, thread pool - infrequently used with very short lived tasks : which should be executed asap
private val handleConnectExecutor = new ThreadPoolExecutor(
- System.getProperty("spark.core.connection.connect.threads.min","1").toInt,
- System.getProperty("spark.core.connection.connect.threads.max","8").toInt,
- System.getProperty("spark.core.connection.connect.threads.keepalive","60").toInt, TimeUnit.SECONDS,
+ conf.get("spark.core.connection.connect.threads.min", "1").toInt,
+ conf.get("spark.core.connection.connect.threads.max", "8").toInt,
+ conf.get("spark.core.connection.connect.threads.keepalive", "60").toInt, TimeUnit.SECONDS,
new LinkedBlockingDeque[Runnable]())
private val serverChannel = ServerSocketChannel.open()
@@ -594,7 +594,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
private[spark] object ConnectionManager {
def main(args: Array[String]) {
- val manager = new ConnectionManager(9999)
+ val manager = new ConnectionManager(9999, new SparkConf)
manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
println("Received [" + msg + "] from [" + id + "]")
None
diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala
index 8d9ad9604d..4f5742d29b 100644
--- a/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala
+++ b/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala
@@ -25,8 +25,8 @@ import scala.io.Source
import java.nio.ByteBuffer
import java.net.InetAddress
-import akka.dispatch.Await
-import akka.util.duration._
+import scala.concurrent.Await
+import scala.concurrent.duration._
private[spark] object ConnectionManagerTest extends Logging{
def main(args: Array[String]) {
diff --git a/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala b/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala
index 781715108b..1c9d6030d6 100644
--- a/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala
+++ b/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala
@@ -19,19 +19,19 @@ package org.apache.spark.network
import java.nio.ByteBuffer
import java.net.InetAddress
+import org.apache.spark.SparkConf
private[spark] object ReceiverTest {
-
def main(args: Array[String]) {
- val manager = new ConnectionManager(9999)
+ val manager = new ConnectionManager(9999, new SparkConf)
println("Started connection manager with id = " + manager.id)
-
- manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+
+ manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
/*println("Received [" + msg + "] from [" + id + "] at " + System.currentTimeMillis)*/
- val buffer = ByteBuffer.wrap("response".getBytes())
+ val buffer = ByteBuffer.wrap("response".getBytes)
Some(Message.createBufferMessage(buffer, msg.id))
})
- Thread.currentThread.join()
+ Thread.currentThread.join()
}
}
diff --git a/core/src/main/scala/org/apache/spark/network/SenderTest.scala b/core/src/main/scala/org/apache/spark/network/SenderTest.scala
index 777574980f..dcbd183c88 100644
--- a/core/src/main/scala/org/apache/spark/network/SenderTest.scala
+++ b/core/src/main/scala/org/apache/spark/network/SenderTest.scala
@@ -19,29 +19,29 @@ package org.apache.spark.network
import java.nio.ByteBuffer
import java.net.InetAddress
+import org.apache.spark.SparkConf
private[spark] object SenderTest {
-
def main(args: Array[String]) {
-
+
if (args.length < 2) {
println("Usage: SenderTest <target host> <target port>")
System.exit(1)
}
-
+
val targetHost = args(0)
val targetPort = args(1).toInt
val targetConnectionManagerId = new ConnectionManagerId(targetHost, targetPort)
- val manager = new ConnectionManager(0)
+ val manager = new ConnectionManager(0, new SparkConf)
println("Started connection manager with id = " + manager.id)
- manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
println("Received [" + msg + "] from [" + id + "]")
None
})
-
- val size = 100 * 1024 * 1024
+
+ val size = 100 * 1024 * 1024
val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
buffer.flip
@@ -50,7 +50,7 @@ private[spark] object SenderTest {
val count = 100
(0 until count).foreach(i => {
val dataMessage = Message.createBufferMessage(buffer.duplicate)
- val startTime = System.currentTimeMillis
+ val startTime = System.currentTimeMillis
/*println("Started timer at " + startTime)*/
val responseStr = manager.sendMessageReliablySync(targetConnectionManagerId, dataMessage) match {
case Some(response) =>
diff --git a/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala b/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala
index b1e1576dad..b729eb11c5 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala
@@ -23,20 +23,20 @@ import io.netty.buffer.ByteBuf
import io.netty.channel.ChannelHandlerContext
import io.netty.util.CharsetUtil
-import org.apache.spark.Logging
+import org.apache.spark.{SparkContext, SparkConf, Logging}
import org.apache.spark.network.ConnectionManagerId
import scala.collection.JavaConverters._
import org.apache.spark.storage.BlockId
-private[spark] class ShuffleCopier extends Logging {
+private[spark] class ShuffleCopier(conf: SparkConf) extends Logging {
def getBlock(host: String, port: Int, blockId: BlockId,
resultCollectCallback: (BlockId, Long, ByteBuf) => Unit) {
val handler = new ShuffleCopier.ShuffleClientHandler(resultCollectCallback)
- val connectTimeout = System.getProperty("spark.shuffle.netty.connect.timeout", "60000").toInt
+ val connectTimeout = conf.get("spark.shuffle.netty.connect.timeout", "60000").toInt
val fc = new FileClient(handler, connectTimeout)
try {
@@ -104,10 +104,10 @@ private[spark] object ShuffleCopier extends Logging {
val threads = if (args.length > 3) args(3).toInt else 10
val copiers = Executors.newFixedThreadPool(80)
- val tasks = (for (i <- Range(0, threads)) yield {
+ val tasks = (for (i <- Range(0, threads)) yield {
Executors.callable(new Runnable() {
def run() {
- val copier = new ShuffleCopier()
+ val copier = new ShuffleCopier(new SparkConf)
copier.getBlock(host, port, blockId, echoResultCollectCallBack)
}
})
diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
index faaf837be0..d1c74a5063 100644
--- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
@@ -21,6 +21,7 @@ import java.util.concurrent.atomic.AtomicLong
import scala.collection.mutable.ArrayBuffer
import scala.concurrent.ExecutionContext.Implicits.global
+import scala.reflect.ClassTag
import org.apache.spark.{ComplexFutureAction, FutureAction, Logging}
@@ -28,7 +29,7 @@ import org.apache.spark.{ComplexFutureAction, FutureAction, Logging}
* A set of asynchronous RDD actions available through an implicit conversion.
* Import `org.apache.spark.SparkContext._` at the top of your program to use these functions.
*/
-class AsyncRDDActions[T: ClassManifest](self: RDD[T]) extends Serializable with Logging {
+class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Logging {
/**
* Returns a future for counting the number of elements in the RDD.
diff --git a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
index 44ea573a7c..424354ae16 100644
--- a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
@@ -17,6 +17,8 @@
package org.apache.spark.rdd
+import scala.reflect.ClassTag
+
import org.apache.spark.{SparkContext, SparkEnv, Partition, TaskContext}
import org.apache.spark.storage.{BlockId, BlockManager}
@@ -25,7 +27,7 @@ private[spark] class BlockRDDPartition(val blockId: BlockId, idx: Int) extends P
}
private[spark]
-class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[BlockId])
+class BlockRDD[T: ClassTag](sc: SparkContext, @transient blockIds: Array[BlockId])
extends RDD[T](sc, Nil) {
@transient lazy val locations_ = BlockManager.blockIdsToHosts(blockIds, SparkEnv.get)
diff --git a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala
index 0de22f0e06..87b950ba43 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala
@@ -18,6 +18,7 @@
package org.apache.spark.rdd
import java.io.{ObjectOutputStream, IOException}
+import scala.reflect.ClassTag
import org.apache.spark._
@@ -43,7 +44,7 @@ class CartesianPartition(
}
private[spark]
-class CartesianRDD[T: ClassManifest, U:ClassManifest](
+class CartesianRDD[T: ClassTag, U: ClassTag](
sc: SparkContext,
var rdd1 : RDD[T],
var rdd2 : RDD[U])
diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
index d3033ea4a6..6d4f46125f 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
@@ -17,15 +17,13 @@
package org.apache.spark.rdd
+import java.io.IOException
+import scala.reflect.ClassTag
import org.apache.spark._
+import org.apache.spark.broadcast.Broadcast
import org.apache.spark.deploy.SparkHadoopUtil
-import org.apache.hadoop.mapred.{FileInputFormat, SequenceFileInputFormat, JobConf, Reporter}
import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.io.{NullWritable, BytesWritable}
-import org.apache.hadoop.util.ReflectionUtils
import org.apache.hadoop.fs.Path
-import java.io.{File, IOException, EOFException}
-import java.text.NumberFormat
private[spark] class CheckpointRDDPartition(val index: Int) extends Partition {}
@@ -33,9 +31,11 @@ private[spark] class CheckpointRDDPartition(val index: Int) extends Partition {}
* This RDD represents a RDD checkpoint file (similar to HadoopRDD).
*/
private[spark]
-class CheckpointRDD[T: ClassManifest](sc: SparkContext, val checkpointPath: String)
+class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String)
extends RDD[T](sc, Nil) {
+ val broadcastedConf = sc.broadcast(new SerializableWritable(sc.hadoopConfiguration))
+
@transient val fs = new Path(checkpointPath).getFileSystem(sc.hadoopConfiguration)
override def getPartitions: Array[Partition] = {
@@ -67,7 +67,7 @@ class CheckpointRDD[T: ClassManifest](sc: SparkContext, val checkpointPath: Stri
override def compute(split: Partition, context: TaskContext): Iterator[T] = {
val file = new Path(checkpointPath, CheckpointRDD.splitIdToFile(split.index))
- CheckpointRDD.readFromFile(file, context)
+ CheckpointRDD.readFromFile(file, broadcastedConf, context)
}
override def checkpoint() {
@@ -76,15 +76,18 @@ class CheckpointRDD[T: ClassManifest](sc: SparkContext, val checkpointPath: Stri
}
private[spark] object CheckpointRDD extends Logging {
-
def splitIdToFile(splitId: Int): String = {
"part-%05d".format(splitId)
}
- def writeToFile[T](path: String, blockSize: Int = -1)(ctx: TaskContext, iterator: Iterator[T]) {
+ def writeToFile[T](
+ path: String,
+ broadcastedConf: Broadcast[SerializableWritable[Configuration]],
+ blockSize: Int = -1
+ )(ctx: TaskContext, iterator: Iterator[T]) {
val env = SparkEnv.get
val outputDir = new Path(path)
- val fs = outputDir.getFileSystem(SparkHadoopUtil.get.newConfiguration())
+ val fs = outputDir.getFileSystem(broadcastedConf.value.value)
val finalOutputName = splitIdToFile(ctx.partitionId)
val finalOutputPath = new Path(outputDir, finalOutputName)
@@ -94,7 +97,7 @@ private[spark] object CheckpointRDD extends Logging {
throw new IOException("Checkpoint failed: temporary path " +
tempOutputPath + " already exists")
}
- val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
+ val bufferSize = env.conf.get("spark.buffer.size", "65536").toInt
val fileOutputStream = if (blockSize < 0) {
fs.create(tempOutputPath, false, bufferSize)
@@ -121,10 +124,14 @@ private[spark] object CheckpointRDD extends Logging {
}
}
- def readFromFile[T](path: Path, context: TaskContext): Iterator[T] = {
+ def readFromFile[T](
+ path: Path,
+ broadcastedConf: Broadcast[SerializableWritable[Configuration]],
+ context: TaskContext
+ ): Iterator[T] = {
val env = SparkEnv.get
- val fs = path.getFileSystem(SparkHadoopUtil.get.newConfiguration())
- val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
+ val fs = path.getFileSystem(broadcastedConf.value.value)
+ val bufferSize = env.conf.get("spark.buffer.size", "65536").toInt
val fileInputStream = fs.open(path, bufferSize)
val serializer = env.serializer.newInstance()
val deserializeStream = serializer.deserializeStream(fileInputStream)
@@ -146,8 +153,10 @@ private[spark] object CheckpointRDD extends Logging {
val sc = new SparkContext(cluster, "CheckpointRDD Test")
val rdd = sc.makeRDD(1 to 10, 10).flatMap(x => 1 to 10000)
val path = new Path(hdfsPath, "temp")
- val fs = path.getFileSystem(SparkHadoopUtil.get.newConfiguration())
- sc.runJob(rdd, CheckpointRDD.writeToFile(path.toString, 1024) _)
+ val conf = SparkHadoopUtil.get.newConfiguration()
+ val fs = path.getFileSystem(conf)
+ val broadcastedConf = sc.broadcast(new SerializableWritable(conf))
+ sc.runJob(rdd, CheckpointRDD.writeToFile(path.toString, broadcastedConf, 1024) _)
val cpRDD = new CheckpointRDD[Int](sc, path.toString)
assert(cpRDD.partitions.length == rdd.partitions.length, "Number of partitions is not the same")
assert(cpRDD.collect.toList == rdd.collect.toList, "Data of partitions not the same")
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
index 911a002884..4ba4696fef 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -114,7 +114,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
map.changeValue(k, update)
}
- val ser = SparkEnv.get.serializerManager.get(serializerClass)
+ val ser = SparkEnv.get.serializerManager.get(serializerClass, SparkEnv.get.conf)
for ((dep, depNum) <- split.deps.zipWithIndex) dep match {
case NarrowCoGroupSplitDep(rdd, _, itsSplit) => {
// Read them from the parent
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
index c5de6362a9..98da35763b 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
@@ -22,6 +22,7 @@ import java.io.{ObjectOutputStream, IOException}
import scala.collection.mutable
import scala.Some
import scala.collection.mutable.ArrayBuffer
+import scala.reflect.ClassTag
/**
* Class that captures a coalesced RDD by essentially keeping track of parent partitions
@@ -68,7 +69,7 @@ case class CoalescedRDDPartition(
* @param maxPartitions number of desired partitions in the coalesced RDD
* @param balanceSlack used to trade-off balance and locality. 1.0 is all locality, 0 is all balance
*/
-class CoalescedRDD[T: ClassManifest](
+class CoalescedRDD[T: ClassTag](
@transient var prev: RDD[T],
maxPartitions: Int,
balanceSlack: Double = 0.10)
diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala
index a4bec41752..688c310ee9 100644
--- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala
@@ -24,6 +24,8 @@ import org.apache.spark.partial.SumEvaluator
import org.apache.spark.util.StatCounter
import org.apache.spark.{TaskContext, Logging}
+import scala.collection.immutable.NumericRange
+
/**
* Extra functions available on RDDs of Doubles through an implicit conversion.
* Import `org.apache.spark.SparkContext._` at the top of your program to use these functions.
@@ -76,4 +78,129 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable {
val evaluator = new SumEvaluator(self.partitions.size, confidence)
self.context.runApproximateJob(self, processPartition, evaluator, timeout)
}
+
+ /**
+ * Compute a histogram of the data using bucketCount number of buckets evenly
+ * spaced between the minimum and maximum of the RDD. For example if the min
+ * value is 0 and the max is 100 and there are two buckets the resulting
+ * buckets will be [0, 50) [50, 100]. bucketCount must be at least 1
+ * If the RDD contains infinity, NaN throws an exception
+ * If the elements in RDD do not vary (max == min) always returns a single bucket.
+ */
+ def histogram(bucketCount: Int): Pair[Array[Double], Array[Long]] = {
+ // Compute the minimum and the maxium
+ val (max: Double, min: Double) = self.mapPartitions { items =>
+ Iterator(items.foldRight(Double.NegativeInfinity,
+ Double.PositiveInfinity)((e: Double, x: Pair[Double, Double]) =>
+ (x._1.max(e), x._2.min(e))))
+ }.reduce { (maxmin1, maxmin2) =>
+ (maxmin1._1.max(maxmin2._1), maxmin1._2.min(maxmin2._2))
+ }
+ if (min.isNaN || max.isNaN || max.isInfinity || min.isInfinity ) {
+ throw new UnsupportedOperationException(
+ "Histogram on either an empty RDD or RDD containing +/-infinity or NaN")
+ }
+ val increment = (max-min)/bucketCount.toDouble
+ val range = if (increment != 0) {
+ Range.Double.inclusive(min, max, increment)
+ } else {
+ List(min, min)
+ }
+ val buckets = range.toArray
+ (buckets, histogram(buckets, true))
+ }
+
+ /**
+ * Compute a histogram using the provided buckets. The buckets are all open
+ * to the left except for the last which is closed
+ * e.g. for the array
+ * [1, 10, 20, 50] the buckets are [1, 10) [10, 20) [20, 50]
+ * e.g 1<=x<10 , 10<=x<20, 20<=x<50
+ * And on the input of 1 and 50 we would have a histogram of 1, 0, 0
+ *
+ * Note: if your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched
+ * from an O(log n) inseration to O(1) per element. (where n = # buckets) if you set evenBuckets
+ * to true.
+ * buckets must be sorted and not contain any duplicates.
+ * buckets array must be at least two elements
+ * All NaN entries are treated the same. If you have a NaN bucket it must be
+ * the maximum value of the last position and all NaN entries will be counted
+ * in that bucket.
+ */
+ def histogram(buckets: Array[Double], evenBuckets: Boolean = false): Array[Long] = {
+ if (buckets.length < 2) {
+ throw new IllegalArgumentException("buckets array must have at least two elements")
+ }
+ // The histogramPartition function computes the partail histogram for a given
+ // partition. The provided bucketFunction determines which bucket in the array
+ // to increment or returns None if there is no bucket. This is done so we can
+ // specialize for uniformly distributed buckets and save the O(log n) binary
+ // search cost.
+ def histogramPartition(bucketFunction: (Double) => Option[Int])(iter: Iterator[Double]):
+ Iterator[Array[Long]] = {
+ val counters = new Array[Long](buckets.length - 1)
+ while (iter.hasNext) {
+ bucketFunction(iter.next()) match {
+ case Some(x: Int) => {counters(x) += 1}
+ case _ => {}
+ }
+ }
+ Iterator(counters)
+ }
+ // Merge the counters.
+ def mergeCounters(a1: Array[Long], a2: Array[Long]): Array[Long] = {
+ a1.indices.foreach(i => a1(i) += a2(i))
+ a1
+ }
+ // Basic bucket function. This works using Java's built in Array
+ // binary search. Takes log(size(buckets))
+ def basicBucketFunction(e: Double): Option[Int] = {
+ val location = java.util.Arrays.binarySearch(buckets, e)
+ if (location < 0) {
+ // If the location is less than 0 then the insertion point in the array
+ // to keep it sorted is -location-1
+ val insertionPoint = -location-1
+ // If we have to insert before the first element or after the last one
+ // its out of bounds.
+ // We do this rather than buckets.lengthCompare(insertionPoint)
+ // because Array[Double] fails to override it (for now).
+ if (insertionPoint > 0 && insertionPoint < buckets.length) {
+ Some(insertionPoint-1)
+ } else {
+ None
+ }
+ } else if (location < buckets.length - 1) {
+ // Exact match, just insert here
+ Some(location)
+ } else {
+ // Exact match to the last element
+ Some(location - 1)
+ }
+ }
+ // Determine the bucket function in constant time. Requires that buckets are evenly spaced
+ def fastBucketFunction(min: Double, increment: Double, count: Int)(e: Double): Option[Int] = {
+ // If our input is not a number unless the increment is also NaN then we fail fast
+ if (e.isNaN()) {
+ return None
+ }
+ val bucketNumber = (e - min)/(increment)
+ // We do this rather than buckets.lengthCompare(bucketNumber)
+ // because Array[Double] fails to override it (for now).
+ if (bucketNumber > count || bucketNumber < 0) {
+ None
+ } else {
+ Some(bucketNumber.toInt.min(count - 1))
+ }
+ }
+ // Decide which bucket function to pass to histogramPartition. We decide here
+ // rather than having a general function so that the decission need only be made
+ // once rather than once per shard
+ val bucketFunction = if (evenBuckets) {
+ fastBucketFunction(buckets(0), buckets(1)-buckets(0), buckets.length-1) _
+ } else {
+ basicBucketFunction _
+ }
+ self.mapPartitions(histogramPartition(bucketFunction)).reduce(mergeCounters)
+ }
+
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/EmptyRDD.scala b/core/src/main/scala/org/apache/spark/rdd/EmptyRDD.scala
index c8900d1a93..a84e5f9fd8 100644
--- a/core/src/main/scala/org/apache/spark/rdd/EmptyRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/EmptyRDD.scala
@@ -17,13 +17,14 @@
package org.apache.spark.rdd
-import org.apache.spark.{SparkContext, SparkEnv, Partition, TaskContext}
+import scala.reflect.ClassTag
+import org.apache.spark.{Partition, SparkContext, TaskContext}
/**
* An RDD that is empty, i.e. has no element in it.
*/
-class EmptyRDD[T: ClassManifest](sc: SparkContext) extends RDD[T](sc, Nil) {
+class EmptyRDD[T: ClassTag](sc: SparkContext) extends RDD[T](sc, Nil) {
override def getPartitions: Array[Partition] = Array.empty
diff --git a/core/src/main/scala/org/apache/spark/rdd/FilteredRDD.scala b/core/src/main/scala/org/apache/spark/rdd/FilteredRDD.scala
index 5312dc0b59..e74c83b90b 100644
--- a/core/src/main/scala/org/apache/spark/rdd/FilteredRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/FilteredRDD.scala
@@ -18,8 +18,9 @@
package org.apache.spark.rdd
import org.apache.spark.{OneToOneDependency, Partition, TaskContext}
+import scala.reflect.ClassTag
-private[spark] class FilteredRDD[T: ClassManifest](
+private[spark] class FilteredRDD[T: ClassTag](
prev: RDD[T],
f: T => Boolean)
extends RDD[T](prev) {
diff --git a/core/src/main/scala/org/apache/spark/rdd/FlatMappedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/FlatMappedRDD.scala
index cbdf6d84c0..4d1878fc14 100644
--- a/core/src/main/scala/org/apache/spark/rdd/FlatMappedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/FlatMappedRDD.scala
@@ -18,10 +18,11 @@
package org.apache.spark.rdd
import org.apache.spark.{Partition, TaskContext}
+import scala.reflect.ClassTag
private[spark]
-class FlatMappedRDD[U: ClassManifest, T: ClassManifest](
+class FlatMappedRDD[U: ClassTag, T: ClassTag](
prev: RDD[T],
f: T => TraversableOnce[U])
extends RDD[U](prev) {
diff --git a/core/src/main/scala/org/apache/spark/rdd/GlommedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/GlommedRDD.scala
index 829545d7b0..1a694475f6 100644
--- a/core/src/main/scala/org/apache/spark/rdd/GlommedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/GlommedRDD.scala
@@ -18,8 +18,9 @@
package org.apache.spark.rdd
import org.apache.spark.{Partition, TaskContext}
+import scala.reflect.ClassTag
-private[spark] class GlommedRDD[T: ClassManifest](prev: RDD[T])
+private[spark] class GlommedRDD[T: ClassTag](prev: RDD[T])
extends RDD[Array[T]](prev) {
override def getPartitions: Array[Partition] = firstParent[T].partitions
diff --git a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
index aca0146884..8df8718f3b 100644
--- a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
@@ -19,6 +19,8 @@ package org.apache.spark.rdd
import java.sql.{Connection, ResultSet}
+import scala.reflect.ClassTag
+
import org.apache.spark.{Logging, Partition, SparkContext, TaskContext}
import org.apache.spark.util.NextIterator
@@ -45,7 +47,7 @@ private[spark] class JdbcPartition(idx: Int, val lower: Long, val upper: Long) e
* This should only call getInt, getString, etc; the RDD takes care of calling next.
* The default maps a ResultSet to an array of Object.
*/
-class JdbcRDD[T: ClassManifest](
+class JdbcRDD[T: ClassTag](
sc: SparkContext,
getConnection: () => Connection,
sql: String,
diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
index 203179c4ea..db15baf503 100644
--- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
@@ -18,20 +18,18 @@
package org.apache.spark.rdd
import org.apache.spark.{Partition, TaskContext}
+import scala.reflect.ClassTag
-
-private[spark]
-class MapPartitionsRDD[U: ClassManifest, T: ClassManifest](
+private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag](
prev: RDD[T],
- f: Iterator[T] => Iterator[U],
+ f: (TaskContext, Int, Iterator[T]) => Iterator[U], // (TaskContext, partition index, iterator)
preservesPartitioning: Boolean = false)
extends RDD[U](prev) {
- override val partitioner =
- if (preservesPartitioning) firstParent[T].partitioner else None
+ override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None
override def getPartitions: Array[Partition] = firstParent[T].partitions
override def compute(split: Partition, context: TaskContext) =
- f(firstParent[T].iterator(split, context))
+ f(context, split.index, firstParent[T].iterator(split, context))
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/MappedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MappedRDD.scala
index e8be1c4816..8d7c288593 100644
--- a/core/src/main/scala/org/apache/spark/rdd/MappedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/MappedRDD.scala
@@ -17,10 +17,12 @@
package org.apache.spark.rdd
+import scala.reflect.ClassTag
+
import org.apache.spark.{Partition, TaskContext}
private[spark]
-class MappedRDD[U: ClassManifest, T: ClassManifest](prev: RDD[T], f: T => U)
+class MappedRDD[U: ClassTag, T: ClassTag](prev: RDD[T], f: T => U)
extends RDD[U](prev) {
override def getPartitions: Array[Partition] = firstParent[T].partitions
diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
index 2662d48c84..73d15b9082 100644
--- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
@@ -76,7 +76,7 @@ class NewHadoopRDD[K, V](
val split = theSplit.asInstanceOf[NewHadoopPartition]
logInfo("Input split: " + split.serializableHadoopSplit)
val conf = confBroadcast.value.value
- val attemptId = newTaskAttemptID(jobtrackerId, id, true, split.index, 0)
+ val attemptId = newTaskAttemptID(jobtrackerId, id, isMap = true, split.index, 0)
val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId)
val format = inputFormatClass.newInstance
if (format.isInstanceOf[Configurable]) {
diff --git a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
index 697be8b997..d5691f2267 100644
--- a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
@@ -17,7 +17,9 @@
package org.apache.spark.rdd
-import org.apache.spark.{RangePartitioner, Logging}
+import scala.reflect.ClassTag
+
+import org.apache.spark.{Logging, RangePartitioner}
/**
* Extra functions available on RDDs of (key, value) pairs where the key is sortable through
@@ -25,9 +27,9 @@ import org.apache.spark.{RangePartitioner, Logging}
* use these functions. They will work with any key type that has a `scala.math.Ordered`
* implementation.
*/
-class OrderedRDDFunctions[K <% Ordered[K]: ClassManifest,
- V: ClassManifest,
- P <: Product2[K, V] : ClassManifest](
+class OrderedRDDFunctions[K <% Ordered[K]: ClassTag,
+ V: ClassTag,
+ P <: Product2[K, V] : ClassTag](
self: RDD[P])
extends Logging with Serializable {
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index 53184a6b8a..2bf7c5b8d6 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -25,6 +25,7 @@ import java.util.{HashMap => JHashMap}
import scala.collection.{mutable, Map}
import scala.collection.mutable.ArrayBuffer
import scala.collection.JavaConversions._
+import scala.reflect.{ClassTag, classTag}
import org.apache.hadoop.mapred._
import org.apache.hadoop.io.compress.CompressionCodec
@@ -39,18 +40,21 @@ import org.apache.hadoop.mapreduce.SparkHadoopMapReduceUtil
import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob}
import org.apache.hadoop.mapreduce.{RecordWriter => NewRecordWriter}
+import com.clearspring.analytics.stream.cardinality.HyperLogLog
+
import org.apache.spark._
import org.apache.spark.SparkContext._
import org.apache.spark.partial.{BoundedDouble, PartialResult}
import org.apache.spark.Aggregator
import org.apache.spark.Partitioner
import org.apache.spark.Partitioner.defaultPartitioner
+import org.apache.spark.util.SerializableHyperLogLog
/**
* Extra functions available on RDDs of (key, value) pairs through an implicit conversion.
* Import `org.apache.spark.SparkContext._` at the top of your program to use these functions.
*/
-class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)])
+class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
extends Logging
with SparkHadoopMapReduceUtil
with Serializable {
@@ -207,6 +211,45 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)])
}
/**
+ * Return approximate number of distinct values for each key in this RDD.
+ * The accuracy of approximation can be controlled through the relative standard deviation
+ * (relativeSD) parameter, which also controls the amount of memory used. Lower values result in
+ * more accurate counts but increase the memory footprint and vise versa. Uses the provided
+ * Partitioner to partition the output RDD.
+ */
+ def countApproxDistinctByKey(relativeSD: Double, partitioner: Partitioner): RDD[(K, Long)] = {
+ val createHLL = (v: V) => new SerializableHyperLogLog(new HyperLogLog(relativeSD)).add(v)
+ val mergeValueHLL = (hll: SerializableHyperLogLog, v: V) => hll.add(v)
+ val mergeHLL = (h1: SerializableHyperLogLog, h2: SerializableHyperLogLog) => h1.merge(h2)
+
+ combineByKey(createHLL, mergeValueHLL, mergeHLL, partitioner).mapValues(_.value.cardinality())
+ }
+
+ /**
+ * Return approximate number of distinct values for each key in this RDD.
+ * The accuracy of approximation can be controlled through the relative standard deviation
+ * (relativeSD) parameter, which also controls the amount of memory used. Lower values result in
+ * more accurate counts but increase the memory footprint and vise versa. HashPartitions the
+ * output RDD into numPartitions.
+ *
+ */
+ def countApproxDistinctByKey(relativeSD: Double, numPartitions: Int): RDD[(K, Long)] = {
+ countApproxDistinctByKey(relativeSD, new HashPartitioner(numPartitions))
+ }
+
+ /**
+ * Return approximate number of distinct values for each key this RDD.
+ * The accuracy of approximation can be controlled through the relative standard deviation
+ * (relativeSD) parameter, which also controls the amount of memory used. Lower values result in
+ * more accurate counts but increase the memory footprint and vise versa. The default value of
+ * relativeSD is 0.05. Hash-partitions the output RDD using the existing partitioner/parallelism
+ * level.
+ */
+ def countApproxDistinctByKey(relativeSD: Double = 0.05): RDD[(K, Long)] = {
+ countApproxDistinctByKey(relativeSD, defaultPartitioner(self))
+ }
+
+ /**
* Merge the values for each key using an associative reduce function. This will also perform
* the merging locally on each mapper before sending results to a reducer, similarly to a
* "combiner" in MapReduce. Output will be hash-partitioned with numPartitions partitions.
@@ -415,7 +458,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)])
throw new SparkException("Default partitioner cannot partition array keys.")
}
val cg = new CoGroupedRDD[K](Seq(self, other), partitioner)
- val prfs = new PairRDDFunctions[K, Seq[Seq[_]]](cg)(classManifest[K], Manifests.seqSeqManifest)
+ val prfs = new PairRDDFunctions[K, Seq[Seq[_]]](cg)(classTag[K], ClassTags.seqSeqClassTag)
prfs.mapValues { case Seq(vs, ws) =>
(vs.asInstanceOf[Seq[V]], ws.asInstanceOf[Seq[W]])
}
@@ -431,7 +474,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)])
throw new SparkException("Default partitioner cannot partition array keys.")
}
val cg = new CoGroupedRDD[K](Seq(self, other1, other2), partitioner)
- val prfs = new PairRDDFunctions[K, Seq[Seq[_]]](cg)(classManifest[K], Manifests.seqSeqManifest)
+ val prfs = new PairRDDFunctions[K, Seq[Seq[_]]](cg)(classTag[K], ClassTags.seqSeqClassTag)
prfs.mapValues { case Seq(vs, w1s, w2s) =>
(vs.asInstanceOf[Seq[V]], w1s.asInstanceOf[Seq[W1]], w2s.asInstanceOf[Seq[W2]])
}
@@ -488,15 +531,15 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)])
* Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
* RDD will be <= us.
*/
- def subtractByKey[W: ClassManifest](other: RDD[(K, W)]): RDD[(K, V)] =
+ def subtractByKey[W: ClassTag](other: RDD[(K, W)]): RDD[(K, V)] =
subtractByKey(other, self.partitioner.getOrElse(new HashPartitioner(self.partitions.size)))
/** Return an RDD with the pairs from `this` whose keys are not in `other`. */
- def subtractByKey[W: ClassManifest](other: RDD[(K, W)], numPartitions: Int): RDD[(K, V)] =
+ def subtractByKey[W: ClassTag](other: RDD[(K, W)], numPartitions: Int): RDD[(K, V)] =
subtractByKey(other, new HashPartitioner(numPartitions))
/** Return an RDD with the pairs from `this` whose keys are not in `other`. */
- def subtractByKey[W: ClassManifest](other: RDD[(K, W)], p: Partitioner): RDD[(K, V)] =
+ def subtractByKey[W: ClassTag](other: RDD[(K, W)], p: Partitioner): RDD[(K, V)] =
new SubtractedRDD[K, V, W](self, other, p)
/**
@@ -525,8 +568,8 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)])
* Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class
* supporting the key and value types K and V in this RDD.
*/
- def saveAsHadoopFile[F <: OutputFormat[K, V]](path: String)(implicit fm: ClassManifest[F]) {
- saveAsHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]])
+ def saveAsHadoopFile[F <: OutputFormat[K, V]](path: String)(implicit fm: ClassTag[F]) {
+ saveAsHadoopFile(path, getKeyClass, getValueClass, fm.runtimeClass.asInstanceOf[Class[F]])
}
/**
@@ -535,16 +578,16 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)])
* supplied codec.
*/
def saveAsHadoopFile[F <: OutputFormat[K, V]](
- path: String, codec: Class[_ <: CompressionCodec]) (implicit fm: ClassManifest[F]) {
- saveAsHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]], codec)
+ path: String, codec: Class[_ <: CompressionCodec]) (implicit fm: ClassTag[F]) {
+ saveAsHadoopFile(path, getKeyClass, getValueClass, fm.runtimeClass.asInstanceOf[Class[F]], codec)
}
/**
* Output the RDD to any Hadoop-supported file system, using a new Hadoop API `OutputFormat`
* (mapreduce.OutputFormat) object supporting the key and value types K and V in this RDD.
*/
- def saveAsNewAPIHadoopFile[F <: NewOutputFormat[K, V]](path: String)(implicit fm: ClassManifest[F]) {
- saveAsNewAPIHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]])
+ def saveAsNewAPIHadoopFile[F <: NewOutputFormat[K, V]](path: String)(implicit fm: ClassTag[F]) {
+ saveAsNewAPIHadoopFile(path, getKeyClass, getValueClass, fm.runtimeClass.asInstanceOf[Class[F]])
}
/**
@@ -570,7 +613,8 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)])
// around by taking a mod. We expect that no task will be attempted 2 billion times.
val attemptNumber = (context.attemptId % Int.MaxValue).toInt
/* "reduce task" <split #> <attempt # = spark task #> */
- val attemptId = newTaskAttemptID(jobtrackerID, stageId, false, context.partitionId, attemptNumber)
+ val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId,
+ attemptNumber)
val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId)
val format = outputFormatClass.newInstance
val committer = format.getOutputCommitter(hadoopContext)
@@ -589,13 +633,12 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)])
* however we're only going to use this local OutputCommitter for
* setupJob/commitJob, so we just use a dummy "map" task.
*/
- val jobAttemptId = newTaskAttemptID(jobtrackerID, stageId, true, 0, 0)
+ val jobAttemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = true, 0, 0)
val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId)
val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext)
jobCommitter.setupJob(jobTaskContext)
val count = self.context.runJob(self, writeShard _).sum
jobCommitter.commitJob(jobTaskContext)
- jobCommitter.cleanupJob(jobTaskContext)
}
/**
@@ -685,7 +728,6 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)])
self.context.runJob(self, writeToFile _)
writer.commitJob()
- writer.cleanup()
}
/**
@@ -698,11 +740,11 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)])
*/
def values: RDD[V] = self.map(_._2)
- private[spark] def getKeyClass() = implicitly[ClassManifest[K]].erasure
+ private[spark] def getKeyClass() = implicitly[ClassTag[K]].runtimeClass
- private[spark] def getValueClass() = implicitly[ClassManifest[V]].erasure
+ private[spark] def getValueClass() = implicitly[ClassTag[V]].runtimeClass
}
-private[spark] object Manifests {
- val seqSeqManifest = classManifest[Seq[Seq[_]]]
+private[spark] object ClassTags {
+ val seqSeqClassTag = classTag[Seq[Seq[_]]]
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
index cd96250389..09d0a8189d 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
@@ -20,13 +20,15 @@ package org.apache.spark.rdd
import scala.collection.immutable.NumericRange
import scala.collection.mutable.ArrayBuffer
import scala.collection.Map
+import scala.reflect.ClassTag
+
import org.apache.spark._
import java.io._
import scala.Serializable
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.util.Utils
-private[spark] class ParallelCollectionPartition[T: ClassManifest](
+private[spark] class ParallelCollectionPartition[T: ClassTag](
var rddId: Long,
var slice: Int,
var values: Seq[T])
@@ -78,7 +80,7 @@ private[spark] class ParallelCollectionPartition[T: ClassManifest](
}
}
-private[spark] class ParallelCollectionRDD[T: ClassManifest](
+private[spark] class ParallelCollectionRDD[T: ClassTag](
@transient sc: SparkContext,
@transient data: Seq[T],
numSlices: Int,
@@ -109,7 +111,7 @@ private object ParallelCollectionRDD {
* collections specially, encoding the slices as other Ranges to minimize memory cost. This makes
* it efficient to run Spark over RDDs representing large sets of numbers.
*/
- def slice[T: ClassManifest](seq: Seq[T], numSlices: Int): Seq[Seq[T]] = {
+ def slice[T: ClassTag](seq: Seq[T], numSlices: Int): Seq[Seq[T]] = {
if (numSlices < 1) {
throw new IllegalArgumentException("Positive number of slices required")
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala
index 574dd4233f..ea8885b36e 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala
@@ -17,6 +17,8 @@
package org.apache.spark.rdd
+import scala.reflect.ClassTag
+
import org.apache.spark.{NarrowDependency, SparkEnv, Partition, TaskContext}
@@ -49,7 +51,7 @@ class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boo
* and the execution DAG has a filter on the key, we can avoid launching tasks
* on partitions that don't have the range covering the key.
*/
-class PartitionPruningRDD[T: ClassManifest](
+class PartitionPruningRDD[T: ClassTag](
@transient prev: RDD[T],
@transient partitionFilterFunc: Int => Boolean)
extends RDD[T](prev.context, List(new PruneDependency(prev, partitionFilterFunc))) {
@@ -69,6 +71,6 @@ object PartitionPruningRDD {
* when its type T is not known at compile time.
*/
def create[T](rdd: RDD[T], partitionFilterFunc: Int => Boolean) = {
- new PartitionPruningRDD[T](rdd, partitionFilterFunc)(rdd.elementClassManifest)
+ new PartitionPruningRDD[T](rdd, partitionFilterFunc)(rdd.elementClassTag)
}
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala
new file mode 100644
index 0000000000..4c625d062e
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import scala.reflect.ClassTag
+import java.io.{ObjectOutputStream, IOException}
+import org.apache.spark.{TaskContext, OneToOneDependency, SparkContext, Partition}
+
+
+/**
+ * Class representing partitions of PartitionerAwareUnionRDD, which maintains the list of corresponding partitions
+ * of parent RDDs.
+ */
+private[spark]
+class PartitionerAwareUnionRDDPartition(
+ @transient val rdds: Seq[RDD[_]],
+ val idx: Int
+ ) extends Partition {
+ var parents = rdds.map(_.partitions(idx)).toArray
+
+ override val index = idx
+ override def hashCode(): Int = idx
+
+ @throws(classOf[IOException])
+ private def writeObject(oos: ObjectOutputStream) {
+ // Update the reference to parent partition at the time of task serialization
+ parents = rdds.map(_.partitions(index)).toArray
+ oos.defaultWriteObject()
+ }
+}
+
+/**
+ * Class representing an RDD that can take multiple RDDs partitioned by the same partitioner and
+ * unify them into a single RDD while preserving the partitioner. So m RDDs with p partitions each
+ * will be unified to a single RDD with p partitions and the same partitioner. The preferred
+ * location for each partition of the unified RDD will be the most common preferred location
+ * of the corresponding partitions of the parent RDDs. For example, location of partition 0
+ * of the unified RDD will be where most of partition 0 of the parent RDDs are located.
+ */
+private[spark]
+class PartitionerAwareUnionRDD[T: ClassTag](
+ sc: SparkContext,
+ var rdds: Seq[RDD[T]]
+ ) extends RDD[T](sc, rdds.map(x => new OneToOneDependency(x))) {
+ require(rdds.length > 0)
+ require(rdds.flatMap(_.partitioner).toSet.size == 1,
+ "Parent RDDs have different partitioners: " + rdds.flatMap(_.partitioner))
+
+ override val partitioner = rdds.head.partitioner
+
+ override def getPartitions: Array[Partition] = {
+ val numPartitions = partitioner.get.numPartitions
+ (0 until numPartitions).map(index => {
+ new PartitionerAwareUnionRDDPartition(rdds, index)
+ }).toArray
+ }
+
+ // Get the location where most of the partitions of parent RDDs are located
+ override def getPreferredLocations(s: Partition): Seq[String] = {
+ logDebug("Finding preferred location for " + this + ", partition " + s.index)
+ val parentPartitions = s.asInstanceOf[PartitionerAwareUnionRDDPartition].parents
+ val locations = rdds.zip(parentPartitions).flatMap {
+ case (rdd, part) => {
+ val parentLocations = currPrefLocs(rdd, part)
+ logDebug("Location of " + rdd + " partition " + part.index + " = " + parentLocations)
+ parentLocations
+ }
+ }
+ val location = if (locations.isEmpty) {
+ None
+ } else {
+ // Find the location that maximum number of parent partitions prefer
+ Some(locations.groupBy(x => x).maxBy(_._2.length)._1)
+ }
+ logDebug("Selected location for " + this + ", partition " + s.index + " = " + location)
+ location.toSeq
+ }
+
+ override def compute(s: Partition, context: TaskContext): Iterator[T] = {
+ val parentPartitions = s.asInstanceOf[PartitionerAwareUnionRDDPartition].parents
+ rdds.zip(parentPartitions).iterator.flatMap {
+ case (rdd, p) => rdd.iterator(p, context)
+ }
+ }
+
+ override def clearDependencies() {
+ super.clearDependencies()
+ rdds = null
+ }
+
+ // Get the *current* preferred locations from the DAGScheduler (as opposed to the static ones)
+ private def currPrefLocs(rdd: RDD[_], part: Partition): Seq[String] = {
+ rdd.context.getPreferredLocs(rdd, part.index).map(tl => tl.host)
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
index d5304ab0ae..1dbbe39898 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
@@ -24,6 +24,7 @@ import scala.collection.Map
import scala.collection.JavaConversions._
import scala.collection.mutable.ArrayBuffer
import scala.io.Source
+import scala.reflect.ClassTag
import org.apache.spark.{SparkEnv, Partition, TaskContext}
import org.apache.spark.broadcast.Broadcast
@@ -33,7 +34,7 @@ import org.apache.spark.broadcast.Broadcast
* An RDD that pipes the contents of each parent partition through an external command
* (printing them one per line) and returns the output as a collection of strings.
*/
-class PipedRDD[T: ClassManifest](
+class PipedRDD[T: ClassTag](
prev: RDD[T],
command: Seq[String],
envVars: Map[String, String],
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 7623c44d88..3f41b66279 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -23,6 +23,9 @@ import scala.collection.Map
import scala.collection.JavaConversions.mapAsScalaMap
import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+import scala.reflect.{classTag, ClassTag}
+
import org.apache.hadoop.io.BytesWritable
import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.hadoop.io.NullWritable
@@ -30,6 +33,7 @@ import org.apache.hadoop.io.Text
import org.apache.hadoop.mapred.TextOutputFormat
import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap}
+import com.clearspring.analytics.stream.cardinality.HyperLogLog
import org.apache.spark.Partitioner._
import org.apache.spark.api.java.JavaRDD
@@ -38,7 +42,7 @@ import org.apache.spark.partial.CountEvaluator
import org.apache.spark.partial.GroupedCountEvaluator
import org.apache.spark.partial.PartialResult
import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.{Utils, BoundedPriorityQueue}
+import org.apache.spark.util.{Utils, BoundedPriorityQueue, SerializableHyperLogLog}
import org.apache.spark.SparkContext._
import org.apache.spark._
@@ -69,7 +73,7 @@ import org.apache.spark._
* [[http://www.cs.berkeley.edu/~matei/papers/2012/nsdi_spark.pdf Spark paper]] for more details
* on RDD internals.
*/
-abstract class RDD[T: ClassManifest](
+abstract class RDD[T: ClassTag](
@transient private var sc: SparkContext,
@transient private var deps: Seq[Dependency[_]]
) extends Serializable with Logging {
@@ -78,6 +82,7 @@ abstract class RDD[T: ClassManifest](
def this(@transient oneParent: RDD[_]) =
this(oneParent.context , List(new OneToOneDependency(oneParent)))
+ private[spark] def conf = sc.conf
// =======================================================================
// Methods that should be implemented by subclasses of RDD
// =======================================================================
@@ -101,7 +106,7 @@ abstract class RDD[T: ClassManifest](
protected def getPreferredLocations(split: Partition): Seq[String] = Nil
/** Optionally overridden by subclasses to specify how they are partitioned. */
- val partitioner: Option[Partitioner] = None
+ @transient val partitioner: Option[Partitioner] = None
// =======================================================================
// Methods and fields available on all RDDs
@@ -114,7 +119,7 @@ abstract class RDD[T: ClassManifest](
val id: Int = sc.newRddId()
/** A friendly name for this RDD */
- var name: String = null
+ @transient var name: String = null
/** Assign a name to this RDD */
def setName(_name: String) = {
@@ -123,7 +128,7 @@ abstract class RDD[T: ClassManifest](
}
/** User-defined generator of this RDD*/
- var generator = Utils.getCallSiteInfo.firstUserClass
+ @transient var generator = Utils.getCallSiteInfo.firstUserClass
/** Reset generator*/
def setGenerator(_generator: String) = {
@@ -243,13 +248,13 @@ abstract class RDD[T: ClassManifest](
/**
* Return a new RDD by applying a function to all elements of this RDD.
*/
- def map[U: ClassManifest](f: T => U): RDD[U] = new MappedRDD(this, sc.clean(f))
+ def map[U: ClassTag](f: T => U): RDD[U] = new MappedRDD(this, sc.clean(f))
/**
* Return a new RDD by first applying a function to all elements of this
* RDD, and then flattening the results.
*/
- def flatMap[U: ClassManifest](f: T => TraversableOnce[U]): RDD[U] =
+ def flatMap[U: ClassTag](f: T => TraversableOnce[U]): RDD[U] =
new FlatMappedRDD(this, sc.clean(f))
/**
@@ -374,25 +379,25 @@ abstract class RDD[T: ClassManifest](
* Return the Cartesian product of this RDD and another one, that is, the RDD of all pairs of
* elements (a, b) where a is in `this` and b is in `other`.
*/
- def cartesian[U: ClassManifest](other: RDD[U]): RDD[(T, U)] = new CartesianRDD(sc, this, other)
+ def cartesian[U: ClassTag](other: RDD[U]): RDD[(T, U)] = new CartesianRDD(sc, this, other)
/**
* Return an RDD of grouped items.
*/
- def groupBy[K: ClassManifest](f: T => K): RDD[(K, Seq[T])] =
+ def groupBy[K: ClassTag](f: T => K): RDD[(K, Seq[T])] =
groupBy[K](f, defaultPartitioner(this))
/**
* Return an RDD of grouped elements. Each group consists of a key and a sequence of elements
* mapping to that key.
*/
- def groupBy[K: ClassManifest](f: T => K, numPartitions: Int): RDD[(K, Seq[T])] =
+ def groupBy[K: ClassTag](f: T => K, numPartitions: Int): RDD[(K, Seq[T])] =
groupBy(f, new HashPartitioner(numPartitions))
/**
* Return an RDD of grouped items.
*/
- def groupBy[K: ClassManifest](f: T => K, p: Partitioner): RDD[(K, Seq[T])] = {
+ def groupBy[K: ClassTag](f: T => K, p: Partitioner): RDD[(K, Seq[T])] = {
val cleanF = sc.clean(f)
this.map(t => (cleanF(t), t)).groupByKey(p)
}
@@ -408,7 +413,6 @@ abstract class RDD[T: ClassManifest](
def pipe(command: String, env: Map[String, String]): RDD[String] =
new PipedRDD(this, command, env)
-
/**
* Return an RDD created by piping elements to a forked external process.
* The print behavior can be customized by providing two functions.
@@ -440,29 +444,31 @@ abstract class RDD[T: ClassManifest](
/**
* Return a new RDD by applying a function to each partition of this RDD.
*/
- def mapPartitions[U: ClassManifest](
+ def mapPartitions[U: ClassTag](
f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = {
- new MapPartitionsRDD(this, sc.clean(f), preservesPartitioning)
+ val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(iter)
+ new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning)
}
/**
* Return a new RDD by applying a function to each partition of this RDD, while tracking the index
* of the original partition.
*/
- def mapPartitionsWithIndex[U: ClassManifest](
+ def mapPartitionsWithIndex[U: ClassTag](
f: (Int, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = {
- val func = (context: TaskContext, iter: Iterator[T]) => f(context.partitionId, iter)
- new MapPartitionsWithContextRDD(this, sc.clean(func), preservesPartitioning)
+ val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(index, iter)
+ new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning)
}
/**
* Return a new RDD by applying a function to each partition of this RDD. This is a variant of
* mapPartitions that also passes the TaskContext into the closure.
*/
- def mapPartitionsWithContext[U: ClassManifest](
+ def mapPartitionsWithContext[U: ClassTag](
f: (TaskContext, Iterator[T]) => Iterator[U],
preservesPartitioning: Boolean = false): RDD[U] = {
- new MapPartitionsWithContextRDD(this, sc.clean(f), preservesPartitioning)
+ val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(context, iter)
+ new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning)
}
/**
@@ -470,7 +476,7 @@ abstract class RDD[T: ClassManifest](
* of the original partition.
*/
@deprecated("use mapPartitionsWithIndex", "0.7.0")
- def mapPartitionsWithSplit[U: ClassManifest](
+ def mapPartitionsWithSplit[U: ClassTag](
f: (Int, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = {
mapPartitionsWithIndex(f, preservesPartitioning)
}
@@ -480,14 +486,13 @@ abstract class RDD[T: ClassManifest](
* additional parameter is produced by constructA, which is called in each
* partition with the index of that partition.
*/
- def mapWith[A: ClassManifest, U: ClassManifest]
+ def mapWith[A: ClassTag, U: ClassTag]
(constructA: Int => A, preservesPartitioning: Boolean = false)
(f: (T, A) => U): RDD[U] = {
- def iterF(context: TaskContext, iter: Iterator[T]): Iterator[U] = {
- val a = constructA(context.partitionId)
+ mapPartitionsWithIndex((index, iter) => {
+ val a = constructA(index)
iter.map(t => f(t, a))
- }
- new MapPartitionsWithContextRDD(this, sc.clean(iterF _), preservesPartitioning)
+ }, preservesPartitioning)
}
/**
@@ -495,14 +500,13 @@ abstract class RDD[T: ClassManifest](
* additional parameter is produced by constructA, which is called in each
* partition with the index of that partition.
*/
- def flatMapWith[A: ClassManifest, U: ClassManifest]
+ def flatMapWith[A: ClassTag, U: ClassTag]
(constructA: Int => A, preservesPartitioning: Boolean = false)
(f: (T, A) => Seq[U]): RDD[U] = {
- def iterF(context: TaskContext, iter: Iterator[T]): Iterator[U] = {
- val a = constructA(context.partitionId)
+ mapPartitionsWithIndex((index, iter) => {
+ val a = constructA(index)
iter.flatMap(t => f(t, a))
- }
- new MapPartitionsWithContextRDD(this, sc.clean(iterF _), preservesPartitioning)
+ }, preservesPartitioning)
}
/**
@@ -510,12 +514,11 @@ abstract class RDD[T: ClassManifest](
* This additional parameter is produced by constructA, which is called in each
* partition with the index of that partition.
*/
- def foreachWith[A: ClassManifest](constructA: Int => A)(f: (T, A) => Unit) {
- def iterF(context: TaskContext, iter: Iterator[T]): Iterator[T] = {
- val a = constructA(context.partitionId)
+ def foreachWith[A: ClassTag](constructA: Int => A)(f: (T, A) => Unit) {
+ mapPartitionsWithIndex { (index, iter) =>
+ val a = constructA(index)
iter.map(t => {f(t, a); t})
- }
- new MapPartitionsWithContextRDD(this, sc.clean(iterF _), true).foreach(_ => {})
+ }.foreach(_ => {})
}
/**
@@ -523,12 +526,11 @@ abstract class RDD[T: ClassManifest](
* additional parameter is produced by constructA, which is called in each
* partition with the index of that partition.
*/
- def filterWith[A: ClassManifest](constructA: Int => A)(p: (T, A) => Boolean): RDD[T] = {
- def iterF(context: TaskContext, iter: Iterator[T]): Iterator[T] = {
- val a = constructA(context.partitionId)
+ def filterWith[A: ClassTag](constructA: Int => A)(p: (T, A) => Boolean): RDD[T] = {
+ mapPartitionsWithIndex((index, iter) => {
+ val a = constructA(index)
iter.filter(t => p(t, a))
- }
- new MapPartitionsWithContextRDD(this, sc.clean(iterF _), true)
+ }, preservesPartitioning = true)
}
/**
@@ -537,7 +539,7 @@ abstract class RDD[T: ClassManifest](
* partitions* and the *same number of elements in each partition* (e.g. one was made through
* a map on the other).
*/
- def zip[U: ClassManifest](other: RDD[U]): RDD[(T, U)] = new ZippedRDD(sc, this, other)
+ def zip[U: ClassTag](other: RDD[U]): RDD[(T, U)] = new ZippedRDD(sc, this, other)
/**
* Zip this RDD's partitions with one (or more) RDD(s) and return a new RDD by
@@ -545,32 +547,27 @@ abstract class RDD[T: ClassManifest](
* *same number of partitions*, but does *not* require them to have the same number
* of elements in each partition.
*/
- def zipPartitions[B: ClassManifest, V: ClassManifest]
- (rdd2: RDD[B], preservesPartitioning: Boolean)
- (f: (Iterator[T], Iterator[B]) => Iterator[V]): RDD[V] =
- new ZippedPartitionsRDD2(sc, sc.clean(f), this, rdd2, preservesPartitioning)
-
- def zipPartitions[B: ClassManifest, V: ClassManifest]
+ def zipPartitions[B: ClassTag, V: ClassTag]
(rdd2: RDD[B])
(f: (Iterator[T], Iterator[B]) => Iterator[V]): RDD[V] =
new ZippedPartitionsRDD2(sc, sc.clean(f), this, rdd2, false)
- def zipPartitions[B: ClassManifest, C: ClassManifest, V: ClassManifest]
+ def zipPartitions[B: ClassTag, C: ClassTag, V: ClassTag]
(rdd2: RDD[B], rdd3: RDD[C], preservesPartitioning: Boolean)
(f: (Iterator[T], Iterator[B], Iterator[C]) => Iterator[V]): RDD[V] =
new ZippedPartitionsRDD3(sc, sc.clean(f), this, rdd2, rdd3, preservesPartitioning)
- def zipPartitions[B: ClassManifest, C: ClassManifest, V: ClassManifest]
+ def zipPartitions[B: ClassTag, C: ClassTag, V: ClassTag]
(rdd2: RDD[B], rdd3: RDD[C])
(f: (Iterator[T], Iterator[B], Iterator[C]) => Iterator[V]): RDD[V] =
new ZippedPartitionsRDD3(sc, sc.clean(f), this, rdd2, rdd3, false)
- def zipPartitions[B: ClassManifest, C: ClassManifest, D: ClassManifest, V: ClassManifest]
+ def zipPartitions[B: ClassTag, C: ClassTag, D: ClassTag, V: ClassTag]
(rdd2: RDD[B], rdd3: RDD[C], rdd4: RDD[D], preservesPartitioning: Boolean)
(f: (Iterator[T], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V]): RDD[V] =
new ZippedPartitionsRDD4(sc, sc.clean(f), this, rdd2, rdd3, rdd4, preservesPartitioning)
- def zipPartitions[B: ClassManifest, C: ClassManifest, D: ClassManifest, V: ClassManifest]
+ def zipPartitions[B: ClassTag, C: ClassTag, D: ClassTag, V: ClassTag]
(rdd2: RDD[B], rdd3: RDD[C], rdd4: RDD[D])
(f: (Iterator[T], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V]): RDD[V] =
new ZippedPartitionsRDD4(sc, sc.clean(f), this, rdd2, rdd3, rdd4, false)
@@ -608,7 +605,7 @@ abstract class RDD[T: ClassManifest](
/**
* Return an RDD that contains all matching values by applying `f`.
*/
- def collect[U: ClassManifest](f: PartialFunction[T, U]): RDD[U] = {
+ def collect[U: ClassTag](f: PartialFunction[T, U]): RDD[U] = {
filter(f.isDefinedAt).map(f)
}
@@ -698,7 +695,7 @@ abstract class RDD[T: ClassManifest](
* allowed to modify and return their first argument instead of creating a new U to avoid memory
* allocation.
*/
- def aggregate[U: ClassManifest](zeroValue: U)(seqOp: (U, T) => U, combOp: (U, U) => U): U = {
+ def aggregate[U: ClassTag](zeroValue: U)(seqOp: (U, T) => U, combOp: (U, U) => U): U = {
// Clone the zero value since we will also be serializing it as part of tasks
var jobResult = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance())
val cleanSeqOp = sc.clean(seqOp)
@@ -747,7 +744,7 @@ abstract class RDD[T: ClassManifest](
* combine step happens locally on the master, equivalent to running a single reduce task.
*/
def countByValue(): Map[T, Long] = {
- if (elementClassManifest.erasure.isArray) {
+ if (elementClassTag.runtimeClass.isArray) {
throw new SparkException("countByValue() does not support arrays")
}
// TODO: This should perhaps be distributed by default.
@@ -778,7 +775,7 @@ abstract class RDD[T: ClassManifest](
timeout: Long,
confidence: Double = 0.95
): PartialResult[Map[T, BoundedDouble]] = {
- if (elementClassManifest.erasure.isArray) {
+ if (elementClassTag.runtimeClass.isArray) {
throw new SparkException("countByValueApprox() does not support arrays")
}
val countPartition: (TaskContext, Iterator[T]) => OLMap[T] = { (ctx, iter) =>
@@ -794,6 +791,19 @@ abstract class RDD[T: ClassManifest](
}
/**
+ * Return approximate number of distinct elements in the RDD.
+ *
+ * The accuracy of approximation can be controlled through the relative standard deviation
+ * (relativeSD) parameter, which also controls the amount of memory used. Lower values result in
+ * more accurate counts but increase the memory footprint and vise versa. The default value of
+ * relativeSD is 0.05.
+ */
+ def countApproxDistinct(relativeSD: Double = 0.05): Long = {
+ val zeroCounter = new SerializableHyperLogLog(new HyperLogLog(relativeSD))
+ aggregate(zeroCounter)(_.add(_), _.merge(_)).value.cardinality()
+ }
+
+ /**
* Take the first num elements of the RDD. It works by first scanning one partition, and use the
* results from that partition to estimate the number of additional partitions needed to satisfy
* the limit.
@@ -943,14 +953,14 @@ abstract class RDD[T: ClassManifest](
private var storageLevel: StorageLevel = StorageLevel.NONE
/** Record user function generating this RDD. */
- private[spark] val origin = Utils.formatSparkCallSite
+ @transient private[spark] val origin = sc.getCallSite
- private[spark] def elementClassManifest: ClassManifest[T] = classManifest[T]
+ private[spark] def elementClassTag: ClassTag[T] = classTag[T]
private[spark] var checkpointData: Option[RDDCheckpointData[T]] = None
/** Returns the first parent RDD */
- protected[spark] def firstParent[U: ClassManifest] = {
+ protected[spark] def firstParent[U: ClassTag] = {
dependencies.head.rdd.asInstanceOf[RDD[U]]
}
@@ -958,7 +968,7 @@ abstract class RDD[T: ClassManifest](
def context = sc
// Avoid handling doCheckpoint multiple times to prevent excessive recursion
- private var doCheckpointCalled = false
+ @transient private var doCheckpointCalled = false
/**
* Performs the checkpointing of this RDD by saving this. It is called by the DAGScheduler
@@ -1012,7 +1022,7 @@ abstract class RDD[T: ClassManifest](
origin)
def toJavaRDD() : JavaRDD[T] = {
- new JavaRDD(this)(elementClassManifest)
+ new JavaRDD(this)(elementClassTag)
}
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
index 6009a41570..bc688110f4 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
@@ -17,10 +17,12 @@
package org.apache.spark.rdd
+import scala.reflect.ClassTag
+
import org.apache.hadoop.fs.Path
import org.apache.hadoop.conf.Configuration
-import org.apache.spark.{Partition, SparkException, Logging}
+import org.apache.spark.{SerializableWritable, Partition, SparkException, Logging}
import org.apache.spark.scheduler.{ResultTask, ShuffleMapTask}
/**
@@ -38,7 +40,7 @@ private[spark] object CheckpointState extends Enumeration {
* manages the post-checkpoint state by providing the updated partitions, iterator and preferred locations
* of the checkpointed RDD.
*/
-private[spark] class RDDCheckpointData[T: ClassManifest](rdd: RDD[T])
+private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
extends Logging with Serializable {
import CheckpointState._
@@ -83,14 +85,21 @@ private[spark] class RDDCheckpointData[T: ClassManifest](rdd: RDD[T])
// Create the output path for the checkpoint
val path = new Path(rdd.context.checkpointDir.get, "rdd-" + rdd.id)
- val fs = path.getFileSystem(new Configuration())
+ val fs = path.getFileSystem(rdd.context.hadoopConfiguration)
if (!fs.mkdirs(path)) {
throw new SparkException("Failed to create checkpoint path " + path)
}
// Save to file, and reload it as an RDD
- rdd.context.runJob(rdd, CheckpointRDD.writeToFile(path.toString) _)
+ val broadcastedConf = rdd.context.broadcast(
+ new SerializableWritable(rdd.context.hadoopConfiguration))
+ rdd.context.runJob(rdd, CheckpointRDD.writeToFile(path.toString, broadcastedConf) _)
val newRDD = new CheckpointRDD[T](rdd.context, path.toString)
+ if (newRDD.partitions.size != rdd.partitions.size) {
+ throw new SparkException(
+ "Checkpoint RDD " + newRDD + "("+ newRDD.partitions.size + ") has different " +
+ "number of partitions than original RDD " + rdd + "(" + rdd.partitions.size + ")")
+ }
// Change the dependencies and partitions of the RDD
RDDCheckpointData.synchronized {
@@ -99,8 +108,8 @@ private[spark] class RDDCheckpointData[T: ClassManifest](rdd: RDD[T])
rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and partitions
cpState = Checkpointed
RDDCheckpointData.clearTaskCaches()
- logInfo("Done checkpointing RDD " + rdd.id + ", new parent is RDD " + newRDD.id)
}
+ logInfo("Done checkpointing RDD " + rdd.id + " to " + path + ", new parent is RDD " + newRDD.id)
}
// Get preferred location of a split after checkpointing
diff --git a/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala
index 2c5253ae30..d433670cc2 100644
--- a/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala
@@ -17,6 +17,7 @@
package org.apache.spark.rdd
+import scala.reflect.ClassTag
import java.util.Random
import cern.jet.random.Poisson
@@ -29,9 +30,9 @@ class SampledRDDPartition(val prev: Partition, val seed: Int) extends Partition
override val index: Int = prev.index
}
-class SampledRDD[T: ClassManifest](
+class SampledRDD[T: ClassTag](
prev: RDD[T],
- withReplacement: Boolean,
+ withReplacement: Boolean,
frac: Double,
seed: Int)
extends RDD[T](prev) {
diff --git a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala
index 5fe4676029..2d1bd5b481 100644
--- a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala
@@ -14,9 +14,10 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-
package org.apache.spark.rdd
+import scala.reflect.{ ClassTag, classTag}
+
import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.mapred.SequenceFileOutputFormat
import org.apache.hadoop.io.compress.CompressionCodec
@@ -32,15 +33,15 @@ import org.apache.spark.Logging
*
* Import `org.apache.spark.SparkContext._` at the top of their program to use these functions.
*/
-class SequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable : ClassManifest](
+class SequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable : ClassTag](
self: RDD[(K, V)])
extends Logging
with Serializable {
- private def getWritableClass[T <% Writable: ClassManifest](): Class[_ <: Writable] = {
+ private def getWritableClass[T <% Writable: ClassTag](): Class[_ <: Writable] = {
val c = {
- if (classOf[Writable].isAssignableFrom(classManifest[T].erasure)) {
- classManifest[T].erasure
+ if (classOf[Writable].isAssignableFrom(classTag[T].runtimeClass)) {
+ classTag[T].runtimeClass
} else {
// We get the type of the Writable class by looking at the apply method which converts
// from T to Writable. Since we have two apply methods we filter out the one which
diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
index a5d751a7bd..0ccb309d0d 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
@@ -17,8 +17,10 @@
package org.apache.spark.rdd
-import org.apache.spark.{Dependency, Partitioner, SparkEnv, ShuffleDependency, Partition, TaskContext}
+import scala.reflect.ClassTag
+import org.apache.spark.{Dependency, Partition, Partitioner, ShuffleDependency,
+ SparkEnv, TaskContext}
private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition {
override val index = idx
@@ -32,7 +34,7 @@ private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition {
* @tparam K the key class.
* @tparam V the value class.
*/
-class ShuffledRDD[K, V, P <: Product2[K, V] : ClassManifest](
+class ShuffledRDD[K, V, P <: Product2[K, V] : ClassTag](
@transient var prev: RDD[P],
part: Partitioner)
extends RDD[P](prev.context, Nil) {
@@ -57,7 +59,7 @@ class ShuffledRDD[K, V, P <: Product2[K, V] : ClassManifest](
override def compute(split: Partition, context: TaskContext): Iterator[P] = {
val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId
SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context,
- SparkEnv.get.serializerManager.get(serializerClass))
+ SparkEnv.get.serializerManager.get(serializerClass, SparkEnv.get.conf))
}
override def clearDependencies() {
diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
index 7af4d803e7..4f90c7d3d6 100644
--- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
@@ -18,8 +18,11 @@
package org.apache.spark.rdd
import java.util.{HashMap => JHashMap}
+
import scala.collection.JavaConversions._
import scala.collection.mutable.ArrayBuffer
+import scala.reflect.ClassTag
+
import org.apache.spark.Partitioner
import org.apache.spark.Dependency
import org.apache.spark.TaskContext
@@ -45,7 +48,7 @@ import org.apache.spark.OneToOneDependency
* you can use `rdd1`'s partitioner/partition size and not worry about running
* out of memory because of the size of `rdd2`.
*/
-private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest, W: ClassManifest](
+private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
@transient var rdd1: RDD[_ <: Product2[K, V]],
@transient var rdd2: RDD[_ <: Product2[K, W]],
part: Partitioner)
@@ -90,7 +93,7 @@ private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest, W: ClassM
override def compute(p: Partition, context: TaskContext): Iterator[(K, V)] = {
val partition = p.asInstanceOf[CoGroupPartition]
- val serializer = SparkEnv.get.serializerManager.get(serializerClass)
+ val serializer = SparkEnv.get.serializerManager.get(serializerClass, SparkEnv.get.conf)
val map = new JHashMap[K, ArrayBuffer[V]]
def getSeq(k: K): ArrayBuffer[V] = {
val seq = map.get(k)
diff --git a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
index ae8a9f36a6..08a41ac558 100644
--- a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
@@ -18,10 +18,13 @@
package org.apache.spark.rdd
import scala.collection.mutable.ArrayBuffer
+import scala.reflect.ClassTag
+
import org.apache.spark.{Dependency, RangeDependency, SparkContext, Partition, TaskContext}
+
import java.io.{ObjectOutputStream, IOException}
-private[spark] class UnionPartition[T: ClassManifest](idx: Int, rdd: RDD[T], splitIndex: Int)
+private[spark] class UnionPartition[T: ClassTag](idx: Int, rdd: RDD[T], splitIndex: Int)
extends Partition {
var split: Partition = rdd.partitions(splitIndex)
@@ -40,7 +43,7 @@ private[spark] class UnionPartition[T: ClassManifest](idx: Int, rdd: RDD[T], spl
}
}
-class UnionRDD[T: ClassManifest](
+class UnionRDD[T: ClassTag](
sc: SparkContext,
@transient var rdds: Seq[RDD[T]])
extends RDD[T](sc, Nil) { // Nil since we implement getDependencies
diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala
index a97d2a01c8..83be3c6eb4 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala
@@ -19,6 +19,7 @@ package org.apache.spark.rdd
import org.apache.spark.{OneToOneDependency, SparkContext, Partition, TaskContext}
import java.io.{ObjectOutputStream, IOException}
+import scala.reflect.ClassTag
private[spark] class ZippedPartitionsPartition(
idx: Int,
@@ -38,7 +39,7 @@ private[spark] class ZippedPartitionsPartition(
}
}
-abstract class ZippedPartitionsBaseRDD[V: ClassManifest](
+abstract class ZippedPartitionsBaseRDD[V: ClassTag](
sc: SparkContext,
var rdds: Seq[RDD[_]],
preservesPartitioning: Boolean = false)
@@ -71,7 +72,7 @@ abstract class ZippedPartitionsBaseRDD[V: ClassManifest](
}
}
-class ZippedPartitionsRDD2[A: ClassManifest, B: ClassManifest, V: ClassManifest](
+class ZippedPartitionsRDD2[A: ClassTag, B: ClassTag, V: ClassTag](
sc: SparkContext,
f: (Iterator[A], Iterator[B]) => Iterator[V],
var rdd1: RDD[A],
@@ -92,7 +93,7 @@ class ZippedPartitionsRDD2[A: ClassManifest, B: ClassManifest, V: ClassManifest]
}
class ZippedPartitionsRDD3
- [A: ClassManifest, B: ClassManifest, C: ClassManifest, V: ClassManifest](
+ [A: ClassTag, B: ClassTag, C: ClassTag, V: ClassTag](
sc: SparkContext,
f: (Iterator[A], Iterator[B], Iterator[C]) => Iterator[V],
var rdd1: RDD[A],
@@ -117,7 +118,7 @@ class ZippedPartitionsRDD3
}
class ZippedPartitionsRDD4
- [A: ClassManifest, B: ClassManifest, C: ClassManifest, D:ClassManifest, V: ClassManifest](
+ [A: ClassTag, B: ClassTag, C: ClassTag, D:ClassTag, V: ClassTag](
sc: SparkContext,
f: (Iterator[A], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V],
var rdd1: RDD[A],
diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedRDD.scala
index 567b67dfee..fb5b070c18 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ZippedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ZippedRDD.scala
@@ -18,10 +18,12 @@
package org.apache.spark.rdd
import org.apache.spark.{OneToOneDependency, SparkContext, Partition, TaskContext}
+
import java.io.{ObjectOutputStream, IOException}
+import scala.reflect.ClassTag
-private[spark] class ZippedPartition[T: ClassManifest, U: ClassManifest](
+private[spark] class ZippedPartition[T: ClassTag, U: ClassTag](
idx: Int,
@transient rdd1: RDD[T],
@transient rdd2: RDD[U]
@@ -42,7 +44,7 @@ private[spark] class ZippedPartition[T: ClassManifest, U: ClassManifest](
}
}
-class ZippedRDD[T: ClassManifest, U: ClassManifest](
+class ZippedRDD[T: ClassTag, U: ClassTag](
sc: SparkContext,
var rdd1: RDD[T],
var rdd2: RDD[U])
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 42bb3884c8..043e01dbfb 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -21,9 +21,11 @@ import java.io.NotSerializableException
import java.util.Properties
import java.util.concurrent.atomic.AtomicInteger
-import akka.actor._
-import akka.util.duration._
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
+import scala.concurrent.duration._
+import scala.reflect.ClassTag
+
+import akka.actor._
import org.apache.spark._
import org.apache.spark.rdd.RDD
@@ -104,36 +106,16 @@ class DAGScheduler(
// The time, in millis, to wait for fetch failure events to stop coming in after one is detected;
// this is a simplistic way to avoid resubmitting tasks in the non-fetchable map stage one by one
// as more failure events come in
- val RESUBMIT_TIMEOUT = 50L
+ val RESUBMIT_TIMEOUT = 50.milliseconds
// The time, in millis, to wake up between polls of the completion queue in order to potentially
// resubmit failed stages
val POLL_TIMEOUT = 10L
- private val eventProcessActor: ActorRef = env.actorSystem.actorOf(Props(new Actor {
- override def preStart() {
- context.system.scheduler.schedule(RESUBMIT_TIMEOUT milliseconds, RESUBMIT_TIMEOUT milliseconds) {
- if (failed.size > 0) {
- resubmitFailedStages()
- }
- }
- }
-
- /**
- * The main event loop of the DAG scheduler, which waits for new-job / task-finished / failure
- * events and responds by launching tasks. This runs in a dedicated thread and receives events
- * via the eventQueue.
- */
- def receive = {
- case event: DAGSchedulerEvent =>
- logDebug("Got event of type " + event.getClass.getName)
+ // Warns the user if a stage contains a task with size greater than this value (in KB)
+ val TASK_SIZE_TO_WARN = 100
- if (!processEvent(event))
- submitWaitingStages()
- else
- context.stop(self)
- }
- }))
+ private var eventProcessActor: ActorRef = _
private[scheduler] val nextJobId = new AtomicInteger(0)
@@ -141,9 +123,13 @@ class DAGScheduler(
private val nextStageId = new AtomicInteger(0)
- private val stageIdToStage = new TimeStampedHashMap[Int, Stage]
+ private[scheduler] val jobIdToStageIds = new TimeStampedHashMap[Int, HashSet[Int]]
- private val shuffleToMapStage = new TimeStampedHashMap[Int, Stage]
+ private[scheduler] val stageIdToJobIds = new TimeStampedHashMap[Int, HashSet[Int]]
+
+ private[scheduler] val stageIdToStage = new TimeStampedHashMap[Int, Stage]
+
+ private[scheduler] val shuffleToMapStage = new TimeStampedHashMap[Int, Stage]
private[spark] val stageToInfos = new TimeStampedHashMap[Stage, StageInfo]
@@ -166,13 +152,66 @@ class DAGScheduler(
val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done
val running = new HashSet[Stage] // Stages we are running right now
val failed = new HashSet[Stage] // Stages that must be resubmitted due to fetch failures
- val pendingTasks = new TimeStampedHashMap[Stage, HashSet[Task[_]]] // Missing tasks from each stage
+ // Missing tasks from each stage
+ val pendingTasks = new TimeStampedHashMap[Stage, HashSet[Task[_]]]
var lastFetchFailureTime: Long = 0 // Used to wait a bit to avoid repeated resubmits
val activeJobs = new HashSet[ActiveJob]
val resultStageToJob = new HashMap[Stage, ActiveJob]
- val metadataCleaner = new MetadataCleaner(MetadataCleanerType.DAG_SCHEDULER, this.cleanup)
+ val metadataCleaner = new MetadataCleaner(
+ MetadataCleanerType.DAG_SCHEDULER, this.cleanup, env.conf)
+
+ /**
+ * Starts the event processing actor. The actor has two responsibilities:
+ *
+ * 1. Waits for events like job submission, task finished, task failure etc., and calls
+ * [[org.apache.spark.scheduler.DAGScheduler.processEvent()]] to process them.
+ * 2. Schedules a periodical task to resubmit failed stages.
+ *
+ * NOTE: the actor cannot be started in the constructor, because the periodical task references
+ * some internal states of the enclosing [[org.apache.spark.scheduler.DAGScheduler]] object, thus
+ * cannot be scheduled until the [[org.apache.spark.scheduler.DAGScheduler]] is fully constructed.
+ */
+ def start() {
+ eventProcessActor = env.actorSystem.actorOf(Props(new Actor {
+ /**
+ * A handle to the periodical task, used to cancel the task when the actor is stopped.
+ */
+ var resubmissionTask: Cancellable = _
+
+ override def preStart() {
+ import context.dispatcher
+ /**
+ * A message is sent to the actor itself periodically to remind the actor to resubmit failed
+ * stages. In this way, stage resubmission can be done within the same thread context of
+ * other event processing logic to avoid unnecessary synchronization overhead.
+ */
+ resubmissionTask = context.system.scheduler.schedule(
+ RESUBMIT_TIMEOUT, RESUBMIT_TIMEOUT, self, ResubmitFailedStages)
+ }
+
+ /**
+ * The main event loop of the DAG scheduler.
+ */
+ def receive = {
+ case event: DAGSchedulerEvent =>
+ logDebug("Got event of type " + event.getClass.getName)
+
+ /**
+ * All events are forwarded to `processEvent()`, so that the event processing logic can
+ * easily tested without starting a dedicated actor. Please refer to `DAGSchedulerSuite`
+ * for details.
+ */
+ if (!processEvent(event)) {
+ submitWaitingStages()
+ } else {
+ resubmissionTask.cancel()
+ context.stop(self)
+ }
+ }
+ }))
+ }
def addSparkListener(listener: SparkListener) {
listenerBus.addListener(listener)
@@ -202,16 +241,18 @@ class DAGScheduler(
shuffleToMapStage.get(shuffleDep.shuffleId) match {
case Some(stage) => stage
case None =>
- val stage = newStage(shuffleDep.rdd, shuffleDep.rdd.partitions.size, Some(shuffleDep), jobId)
+ val stage =
+ newOrUsedStage(shuffleDep.rdd, shuffleDep.rdd.partitions.size, shuffleDep, jobId)
shuffleToMapStage(shuffleDep.shuffleId) = stage
stage
}
}
/**
- * Create a Stage for the given RDD, either as a shuffle map stage (for a ShuffleDependency) or
- * as a result stage for the final RDD used directly in an action. The stage will also be
- * associated with the provided jobId.
+ * Create a Stage -- either directly for use as a result stage, or as part of the (re)-creation
+ * of a shuffle map stage in newOrUsedStage. The stage will be associated with the provided
+ * jobId. Production of shuffle map stages should always use newOrUsedStage, not newStage
+ * directly.
*/
private def newStage(
rdd: RDD[_],
@@ -221,21 +262,45 @@ class DAGScheduler(
callSite: Option[String] = None)
: Stage =
{
- if (shuffleDep != None) {
- // Kind of ugly: need to register RDDs with the cache and map output tracker here
- // since we can't do it in the RDD constructor because # of partitions is unknown
- logInfo("Registering RDD " + rdd.id + " (" + rdd.origin + ")")
- mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.partitions.size)
- }
val id = nextStageId.getAndIncrement()
val stage =
new Stage(id, rdd, numTasks, shuffleDep, getParentStages(rdd, jobId), jobId, callSite)
stageIdToStage(id) = stage
+ updateJobIdStageIdMaps(jobId, stage)
stageToInfos(stage) = new StageInfo(stage)
stage
}
/**
+ * Create a shuffle map Stage for the given RDD. The stage will also be associated with the
+ * provided jobId. If a stage for the shuffleId existed previously so that the shuffleId is
+ * present in the MapOutputTracker, then the number and location of available outputs are
+ * recovered from the MapOutputTracker
+ */
+ private def newOrUsedStage(
+ rdd: RDD[_],
+ numTasks: Int,
+ shuffleDep: ShuffleDependency[_,_],
+ jobId: Int,
+ callSite: Option[String] = None)
+ : Stage =
+ {
+ val stage = newStage(rdd, numTasks, Some(shuffleDep), jobId, callSite)
+ if (mapOutputTracker.has(shuffleDep.shuffleId)) {
+ val serLocs = mapOutputTracker.getSerializedMapOutputStatuses(shuffleDep.shuffleId)
+ val locs = MapOutputTracker.deserializeMapStatuses(serLocs)
+ for (i <- 0 until locs.size) stage.outputLocs(i) = List(locs(i))
+ stage.numAvailableOutputs = locs.size
+ } else {
+ // Kind of ugly: need to register RDDs with the cache and map output tracker here
+ // since we can't do it in the RDD constructor because # of partitions is unknown
+ logInfo("Registering RDD " + rdd.id + " (" + rdd.origin + ")")
+ mapOutputTracker.registerShuffle(shuffleDep.shuffleId, rdd.partitions.size)
+ }
+ stage
+ }
+
+ /**
* Get or create the list of parent stages for a given RDD. The stages will be assigned the
* provided jobId if they haven't already been created with a lower jobId.
*/
@@ -287,6 +352,94 @@ class DAGScheduler(
}
/**
+ * Registers the given jobId among the jobs that need the given stage and
+ * all of that stage's ancestors.
+ */
+ private def updateJobIdStageIdMaps(jobId: Int, stage: Stage) {
+ def updateJobIdStageIdMapsList(stages: List[Stage]) {
+ if (!stages.isEmpty) {
+ val s = stages.head
+ stageIdToJobIds.getOrElseUpdate(s.id, new HashSet[Int]()) += jobId
+ jobIdToStageIds.getOrElseUpdate(jobId, new HashSet[Int]()) += s.id
+ val parents = getParentStages(s.rdd, jobId)
+ val parentsWithoutThisJobId = parents.filter(p =>
+ !stageIdToJobIds.get(p.id).exists(_.contains(jobId)))
+ updateJobIdStageIdMapsList(parentsWithoutThisJobId ++ stages.tail)
+ }
+ }
+ updateJobIdStageIdMapsList(List(stage))
+ }
+
+ /**
+ * Removes job and any stages that are not needed by any other job. Returns the set of ids for
+ * stages that were removed. The associated tasks for those stages need to be cancelled if we
+ * got here via job cancellation.
+ */
+ private def removeJobAndIndependentStages(jobId: Int): Set[Int] = {
+ val registeredStages = jobIdToStageIds(jobId)
+ val independentStages = new HashSet[Int]()
+ if (registeredStages.isEmpty) {
+ logError("No stages registered for job " + jobId)
+ } else {
+ stageIdToJobIds.filterKeys(stageId => registeredStages.contains(stageId)).foreach {
+ case (stageId, jobSet) =>
+ if (!jobSet.contains(jobId)) {
+ logError(
+ "Job %d not registered for stage %d even though that stage was registered for the job"
+ .format(jobId, stageId))
+ } else {
+ def removeStage(stageId: Int) {
+ // data structures based on Stage
+ stageIdToStage.get(stageId).foreach { s =>
+ if (running.contains(s)) {
+ logDebug("Removing running stage %d".format(stageId))
+ running -= s
+ }
+ stageToInfos -= s
+ shuffleToMapStage.keys.filter(shuffleToMapStage(_) == s).foreach(shuffleId =>
+ shuffleToMapStage.remove(shuffleId))
+ if (pendingTasks.contains(s) && !pendingTasks(s).isEmpty) {
+ logDebug("Removing pending status for stage %d".format(stageId))
+ }
+ pendingTasks -= s
+ if (waiting.contains(s)) {
+ logDebug("Removing stage %d from waiting set.".format(stageId))
+ waiting -= s
+ }
+ if (failed.contains(s)) {
+ logDebug("Removing stage %d from failed set.".format(stageId))
+ failed -= s
+ }
+ }
+ // data structures based on StageId
+ stageIdToStage -= stageId
+ stageIdToJobIds -= stageId
+
+ logDebug("After removal of stage %d, remaining stages = %d"
+ .format(stageId, stageIdToStage.size))
+ }
+
+ jobSet -= jobId
+ if (jobSet.isEmpty) { // no other job needs this stage
+ independentStages += stageId
+ removeStage(stageId)
+ }
+ }
+ }
+ }
+ independentStages.toSet
+ }
+
+ private def jobIdToStageIdsRemove(jobId: Int) {
+ if (!jobIdToStageIds.contains(jobId)) {
+ logDebug("Trying to remove unregistered job " + jobId)
+ } else {
+ removeJobAndIndependentStages(jobId)
+ jobIdToStageIds -= jobId
+ }
+ }
+
+ /**
* Submit a job to the job scheduler and get a JobWaiter object back. The JobWaiter object
* can be used to block until the the job finishes executing or can be used to cancel the job.
*/
@@ -315,11 +468,12 @@ class DAGScheduler(
assert(partitions.size > 0)
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler)
- eventProcessActor ! JobSubmitted(jobId, rdd, func2, partitions.toArray, allowLocal, callSite, waiter, properties)
+ eventProcessActor ! JobSubmitted(
+ jobId, rdd, func2, partitions.toArray, allowLocal, callSite, waiter, properties)
waiter
}
- def runJob[T, U: ClassManifest](
+ def runJob[T, U: ClassTag](
rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
partitions: Seq[Int],
@@ -350,7 +504,8 @@ class DAGScheduler(
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
val partitions = (0 until rdd.partitions.size).toArray
val jobId = nextJobId.getAndIncrement()
- eventProcessActor ! JobSubmitted(jobId, rdd, func2, partitions, allowLocal = false, callSite, listener, properties)
+ eventProcessActor ! JobSubmitted(
+ jobId, rdd, func2, partitions, allowLocal = false, callSite, listener, properties)
listener.awaitResult() // Will throw an exception if the job fails
}
@@ -375,13 +530,25 @@ class DAGScheduler(
}
/**
- * Process one event retrieved from the event queue.
- * Returns true if we should stop the event loop.
+ * Process one event retrieved from the event processing actor.
+ *
+ * @param event The event to be processed.
+ * @return `true` if we should stop the event loop.
*/
private[scheduler] def processEvent(event: DAGSchedulerEvent): Boolean = {
event match {
case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) =>
- val finalStage = newStage(rdd, partitions.size, None, jobId, Some(callSite))
+ var finalStage: Stage = null
+ try {
+ // New stage creation may throw an exception if, for example, jobs are run on a HadoopRDD
+ // whose underlying HDFS files have been deleted.
+ finalStage = newStage(rdd, partitions.size, None, jobId, Some(callSite))
+ } catch {
+ case e: Exception =>
+ logWarning("Creating new stage failed due to exception - job: " + jobId, e)
+ listener.jobFailed(e)
+ return false
+ }
val job = new ActiveJob(jobId, finalStage, func, partitions, callSite, listener, properties)
clearCacheLocs()
logInfo("Got job " + job.jobId + " (" + callSite + ") with " + partitions.length +
@@ -391,37 +558,32 @@ class DAGScheduler(
logInfo("Missing parents: " + getMissingParentStages(finalStage))
if (allowLocal && finalStage.parents.size == 0 && partitions.length == 1) {
// Compute very short actions like first() or take() with no parent stages locally.
+ listenerBus.post(SparkListenerJobStart(job, Array(), properties))
runLocally(job)
} else {
- listenerBus.post(SparkListenerJobStart(job, properties))
idToActiveJob(jobId) = job
activeJobs += job
resultStageToJob(finalStage) = job
+ listenerBus.post(SparkListenerJobStart(job, jobIdToStageIds(jobId).toArray, properties))
submitStage(finalStage)
}
case JobCancelled(jobId) =>
- // Cancel a job: find all the running stages that are linked to this job, and cancel them.
- running.filter(_.jobId == jobId).foreach { stage =>
- taskSched.cancelTasks(stage.id)
- }
+ handleJobCancellation(jobId)
case JobGroupCancelled(groupId) =>
// Cancel all jobs belonging to this job group.
// First finds all active jobs with this group id, and then kill stages for them.
- val jobIds = activeJobs.filter(groupId == _.properties.get(SparkContext.SPARK_JOB_GROUP_ID))
- .map(_.jobId)
- if (!jobIds.isEmpty) {
- running.filter(stage => jobIds.contains(stage.jobId)).foreach { stage =>
- taskSched.cancelTasks(stage.id)
- }
- }
+ val activeInGroup = activeJobs.filter(activeJob =>
+ groupId == activeJob.properties.get(SparkContext.SPARK_JOB_GROUP_ID))
+ val jobIds = activeInGroup.map(_.jobId)
+ jobIds.foreach { handleJobCancellation }
case AllJobsCancelled =>
// Cancel all running jobs.
- running.foreach { stage =>
- taskSched.cancelTasks(stage.id)
- }
+ running.map(_.jobId).foreach { handleJobCancellation }
+ activeJobs.clear() // These should already be empty by this point,
+ idToActiveJob.clear() // but just in case we lost track of some jobs...
case ExecutorGained(execId, host) =>
handleExecutorGained(execId, host)
@@ -430,6 +592,19 @@ class DAGScheduler(
handleExecutorLost(execId)
case BeginEvent(task, taskInfo) =>
+ for (
+ job <- idToActiveJob.get(task.stageId);
+ stage <- stageIdToStage.get(task.stageId);
+ stageInfo <- stageToInfos.get(stage)
+ ) {
+ if (taskInfo.serializedSize > TASK_SIZE_TO_WARN * 1024 &&
+ !stageInfo.emittedTaskSizeWarning) {
+ stageInfo.emittedTaskSizeWarning = true
+ logWarning(("Stage %d (%s) contains a task of very large " +
+ "size (%d KB). The maximum recommended task size is %d KB.").format(
+ task.stageId, stageInfo.name, taskInfo.serializedSize / 1024, TASK_SIZE_TO_WARN))
+ }
+ }
listenerBus.post(SparkListenerTaskStart(task, taskInfo))
case GettingResultEvent(task, taskInfo) =>
@@ -440,7 +615,12 @@ class DAGScheduler(
handleTaskCompletion(completion)
case TaskSetFailed(taskSet, reason) =>
- abortStage(stageIdToStage(taskSet.stageId), reason)
+ stageIdToStage.get(taskSet.stageId).foreach { abortStage(_, reason) }
+
+ case ResubmitFailedStages =>
+ if (failed.size > 0) {
+ resubmitFailedStages()
+ }
case StopDAGScheduler =>
// Cancel any active jobs
@@ -502,6 +682,7 @@ class DAGScheduler(
// Broken out for easier testing in DAGSchedulerSuite.
protected def runLocallyWithinThread(job: ActiveJob) {
+ var jobResult: JobResult = JobSucceeded
try {
SparkEnv.set(env)
val rdd = job.finalStage.rdd
@@ -516,31 +697,59 @@ class DAGScheduler(
}
} catch {
case e: Exception =>
+ jobResult = JobFailed(e, Some(job.finalStage))
job.listener.jobFailed(e)
+ } finally {
+ val s = job.finalStage
+ stageIdToJobIds -= s.id // clean up data structures that were populated for a local job,
+ stageIdToStage -= s.id // but that won't get cleaned up via the normal paths through
+ stageToInfos -= s // completion events or stage abort
+ jobIdToStageIds -= job.jobId
+ listenerBus.post(SparkListenerJobEnd(job, jobResult))
+ }
+ }
+
+ /** Finds the earliest-created active job that needs the stage */
+ // TODO: Probably should actually find among the active jobs that need this
+ // stage the one with the highest priority (highest-priority pool, earliest created).
+ // That should take care of at least part of the priority inversion problem with
+ // cross-job dependencies.
+ private def activeJobForStage(stage: Stage): Option[Int] = {
+ if (stageIdToJobIds.contains(stage.id)) {
+ val jobsThatUseStage: Array[Int] = stageIdToJobIds(stage.id).toArray.sorted
+ jobsThatUseStage.find(idToActiveJob.contains(_))
+ } else {
+ None
}
}
/** Submits stage, but first recursively submits any missing parents. */
private def submitStage(stage: Stage) {
- logDebug("submitStage(" + stage + ")")
- if (!waiting(stage) && !running(stage) && !failed(stage)) {
- val missing = getMissingParentStages(stage).sortBy(_.id)
- logDebug("missing: " + missing)
- if (missing == Nil) {
- logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents")
- submitMissingTasks(stage)
- running += stage
- } else {
- for (parent <- missing) {
- submitStage(parent)
+ val jobId = activeJobForStage(stage)
+ if (jobId.isDefined) {
+ logDebug("submitStage(" + stage + ")")
+ if (!waiting(stage) && !running(stage) && !failed(stage)) {
+ val missing = getMissingParentStages(stage).sortBy(_.id)
+ logDebug("missing: " + missing)
+ if (missing == Nil) {
+ logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents")
+ submitMissingTasks(stage, jobId.get)
+ running += stage
+ } else {
+ for (parent <- missing) {
+ submitStage(parent)
+ }
+ waiting += stage
}
- waiting += stage
}
+ } else {
+ abortStage(stage, "No active job for stage " + stage.id)
}
}
+
/** Called when stage's parents are available and we can now do its task. */
- private def submitMissingTasks(stage: Stage) {
+ private def submitMissingTasks(stage: Stage, jobId: Int) {
logDebug("submitMissingTasks(" + stage + ")")
// Get our pending tasks and remember them in our pendingTasks entry
val myPending = pendingTasks.getOrElseUpdate(stage, new HashSet)
@@ -561,7 +770,7 @@ class DAGScheduler(
}
}
- val properties = if (idToActiveJob.contains(stage.jobId)) {
+ val properties = if (idToActiveJob.contains(jobId)) {
idToActiveJob(stage.jobId).properties
} else {
//this stage will be assigned to "default" pool
@@ -619,7 +828,7 @@ class DAGScheduler(
}
logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime))
stageToInfos(stage).completionTime = Some(System.currentTimeMillis())
- listenerBus.post(StageCompleted(stageToInfos(stage)))
+ listenerBus.post(SparkListenerStageCompleted(stageToInfos(stage)))
running -= stage
}
event.reason match {
@@ -643,6 +852,7 @@ class DAGScheduler(
activeJobs -= job
resultStageToJob -= stage
markStageAsFinished(stage)
+ jobIdToStageIdsRemove(job.jobId)
listenerBus.post(SparkListenerJobEnd(job, JobSucceeded))
}
job.listener.taskSucceeded(rt.outputId, event.result)
@@ -679,7 +889,7 @@ class DAGScheduler(
changeEpoch = true)
}
clearCacheLocs()
- if (stage.outputLocs.count(_ == Nil) != 0) {
+ if (stage.outputLocs.exists(_ == Nil)) {
// Some tasks had failed; let's resubmit this stage
// TODO: Lower-level scheduler should also deal with this
logInfo("Resubmitting " + stage + " (" + stage.name +
@@ -696,9 +906,12 @@ class DAGScheduler(
}
waiting --= newlyRunnable
running ++= newlyRunnable
- for (stage <- newlyRunnable.sortBy(_.id)) {
+ for {
+ stage <- newlyRunnable.sortBy(_.id)
+ jobId <- activeJobForStage(stage)
+ } {
logInfo("Submitting " + stage + " (" + stage.rdd + "), which is now runnable")
- submitMissingTasks(stage)
+ submitMissingTasks(stage, jobId)
}
}
}
@@ -782,21 +995,42 @@ class DAGScheduler(
}
}
+ private def handleJobCancellation(jobId: Int) {
+ if (!jobIdToStageIds.contains(jobId)) {
+ logDebug("Trying to cancel unregistered job " + jobId)
+ } else {
+ val independentStages = removeJobAndIndependentStages(jobId)
+ independentStages.foreach { taskSched.cancelTasks }
+ val error = new SparkException("Job %d cancelled".format(jobId))
+ val job = idToActiveJob(jobId)
+ job.listener.jobFailed(error)
+ jobIdToStageIds -= jobId
+ activeJobs -= job
+ idToActiveJob -= jobId
+ listenerBus.post(SparkListenerJobEnd(job, JobFailed(error, Some(job.finalStage))))
+ }
+ }
+
/**
* Aborts all jobs depending on a particular Stage. This is called in response to a task set
* being canceled by the TaskScheduler. Use taskSetFailed() to inject this event from outside.
*/
private def abortStage(failedStage: Stage, reason: String) {
+ if (!stageIdToStage.contains(failedStage.id)) {
+ // Skip all the actions if the stage has been removed.
+ return
+ }
val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq
stageToInfos(failedStage).completionTime = Some(System.currentTimeMillis())
for (resultStage <- dependentStages) {
val job = resultStageToJob(resultStage)
val error = new SparkException("Job aborted: " + reason)
job.listener.jobFailed(error)
- listenerBus.post(SparkListenerJobEnd(job, JobFailed(error, Some(failedStage))))
+ jobIdToStageIdsRemove(job.jobId)
idToActiveJob -= resultStage.jobId
activeJobs -= job
resultStageToJob -= resultStage
+ listenerBus.post(SparkListenerJobEnd(job, JobFailed(error, Some(failedStage))))
}
if (dependentStages.isEmpty) {
logInfo("Ignoring failure of " + failedStage + " because all jobs depending on it are done")
@@ -867,25 +1101,24 @@ class DAGScheduler(
}
private def cleanup(cleanupTime: Long) {
- var sizeBefore = stageIdToStage.size
- stageIdToStage.clearOldValues(cleanupTime)
- logInfo("stageIdToStage " + sizeBefore + " --> " + stageIdToStage.size)
-
- sizeBefore = shuffleToMapStage.size
- shuffleToMapStage.clearOldValues(cleanupTime)
- logInfo("shuffleToMapStage " + sizeBefore + " --> " + shuffleToMapStage.size)
-
- sizeBefore = pendingTasks.size
- pendingTasks.clearOldValues(cleanupTime)
- logInfo("pendingTasks " + sizeBefore + " --> " + pendingTasks.size)
-
- sizeBefore = stageToInfos.size
- stageToInfos.clearOldValues(cleanupTime)
- logInfo("stageToInfos " + sizeBefore + " --> " + stageToInfos.size)
+ Map(
+ "stageIdToStage" -> stageIdToStage,
+ "shuffleToMapStage" -> shuffleToMapStage,
+ "pendingTasks" -> pendingTasks,
+ "stageToInfos" -> stageToInfos,
+ "jobIdToStageIds" -> jobIdToStageIds,
+ "stageIdToJobIds" -> stageIdToJobIds).
+ foreach { case(s, t) => {
+ val sizeBefore = t.size
+ t.clearOldValues(cleanupTime)
+ logInfo("%s %d --> %d".format(s, sizeBefore, t.size))
+ }}
}
def stop() {
- eventProcessActor ! StopDAGScheduler
+ if (eventProcessActor != null) {
+ eventProcessActor ! StopDAGScheduler
+ }
metadataCleaner.cancel()
taskSched.stop()
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
index 708d221d60..add1187613 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
@@ -65,12 +65,13 @@ private[scheduler] case class CompletionEvent(
taskMetrics: TaskMetrics)
extends DAGSchedulerEvent
-private[scheduler]
-case class ExecutorGained(execId: String, host: String) extends DAGSchedulerEvent
+private[scheduler] case class ExecutorGained(execId: String, host: String) extends DAGSchedulerEvent
private[scheduler] case class ExecutorLost(execId: String) extends DAGSchedulerEvent
private[scheduler]
case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent
+private[scheduler] case object ResubmitFailedStages extends DAGSchedulerEvent
+
private[scheduler] case object StopDAGScheduler extends DAGSchedulerEvent
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorLossReason.scala b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala
index 5077b2b48b..2bc43a9186 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorLossReason.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.scheduler.cluster
+package org.apache.spark.scheduler
import org.apache.spark.executor.ExecutorExitCode
diff --git a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala
index 1791ee660d..90eb8a747f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala
@@ -32,7 +32,7 @@ import scala.collection.JavaConversions._
/**
* Parses and holds information about inputFormat (and files) specified as a parameter.
*/
-class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Class[_],
+class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Class[_],
val path: String) extends Logging {
var mapreduceInputFormat: Boolean = false
@@ -40,7 +40,7 @@ class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Cl
validate()
- override def toString(): String = {
+ override def toString: String = {
"InputFormatInfo " + super.toString + " .. inputFormatClazz " + inputFormatClazz + ", path : " + path
}
@@ -125,7 +125,7 @@ class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Cl
}
private def findPreferredLocations(): Set[SplitInfo] = {
- logDebug("mapreduceInputFormat : " + mapreduceInputFormat + ", mapredInputFormat : " + mapredInputFormat +
+ logDebug("mapreduceInputFormat : " + mapreduceInputFormat + ", mapredInputFormat : " + mapredInputFormat +
", inputFormatClazz : " + inputFormatClazz)
if (mapreduceInputFormat) {
return prefLocsFromMapreduceInputFormat()
@@ -143,14 +143,14 @@ class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Cl
object InputFormatInfo {
/**
Computes the preferred locations based on input(s) and returned a location to block map.
- Typical use of this method for allocation would follow some algo like this
- (which is what we currently do in YARN branch) :
+ Typical use of this method for allocation would follow some algo like this:
+
a) For each host, count number of splits hosted on that host.
b) Decrement the currently allocated containers on that host.
c) Compute rack info for each host and update rack -> count map based on (b).
d) Allocate nodes based on (c)
- e) On the allocation result, ensure that we dont allocate "too many" jobs on a single node
- (even if data locality on that is very high) : this is to prevent fragility of job if a single
+ e) On the allocation result, ensure that we dont allocate "too many" jobs on a single node
+ (even if data locality on that is very high) : this is to prevent fragility of job if a single
(or small set of) hosts go down.
go to (a) until required nodes are allocated.
diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
index 60927831a1..f8fa5a9f7a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
@@ -297,7 +297,7 @@ class JobLogger(val user: String, val logDirName: String)
* When stage is completed, record stage completion status
* @param stageCompleted Stage completed event
*/
- override def onStageCompleted(stageCompleted: StageCompleted) {
+ override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) {
stageLogInfo(stageCompleted.stage.stageId, "STAGE_ID=%d STATUS=COMPLETED".format(
stageCompleted.stage.stageId))
}
@@ -328,10 +328,6 @@ class JobLogger(val user: String, val logDirName: String)
task.stageId + " SHUFFLE_ID=" + shuffleId + " MAP_ID=" +
mapId + " REDUCE_ID=" + reduceId
stageLogInfo(task.stageId, taskStatus)
- case OtherFailure(message) =>
- taskStatus += " STATUS=FAILURE TID=" + taskInfo.taskId +
- " STAGE_ID=" + task.stageId + " INFO=" + message
- stageLogInfo(task.stageId, taskStatus)
case _ =>
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala
index 58f238d8cf..b026f860a8 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala
@@ -31,6 +31,7 @@ private[spark] class JobWaiter[T](
private var finishedTasks = 0
// Is the job as a whole finished (succeeded or failed)?
+ @volatile
private var _jobFinished = totalTasks == 0
def jobFinished = _jobFinished
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala
index 596f9adde9..1791242215 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala
@@ -117,8 +117,4 @@ private[spark] class Pool(
parent.decreaseRunningTasks(taskNum)
}
}
-
- override def hasPendingTasks(): Boolean = {
- schedulableQueue.exists(_.hasPendingTasks())
- }
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
index 310ec62ca8..28f3ba53b8 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -32,7 +32,9 @@ private[spark] object ResultTask {
// expensive on the master node if it needs to launch thousands of tasks.
val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]]
- val metadataCleaner = new MetadataCleaner(MetadataCleanerType.RESULT_TASK, serializedInfoCache.clearOldValues)
+ // TODO: This object shouldn't have global variables
+ val metadataCleaner = new MetadataCleaner(
+ MetadataCleanerType.RESULT_TASK, serializedInfoCache.clearOldValues, new SparkConf)
def serializeInfo(stageId: Int, rdd: RDD[_], func: (TaskContext, Iterator[_]) => _): Array[Byte] = {
synchronized {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala b/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala
index 1c7ea2dccc..d573e125a3 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala
@@ -42,5 +42,4 @@ private[spark] trait Schedulable {
def executorLost(executorId: String, host: String): Unit
def checkSpeculatableTasks(): Boolean
def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager]
- def hasPendingTasks(): Boolean
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala
index 356fe56bf3..3cf995ea74 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala
@@ -20,7 +20,7 @@ package org.apache.spark.scheduler
import java.io.{FileInputStream, InputStream}
import java.util.{NoSuchElementException, Properties}
-import org.apache.spark.Logging
+import org.apache.spark.{SparkConf, Logging}
import scala.xml.XML
@@ -49,10 +49,10 @@ private[spark] class FIFOSchedulableBuilder(val rootPool: Pool)
}
}
-private[spark] class FairSchedulableBuilder(val rootPool: Pool)
+private[spark] class FairSchedulableBuilder(val rootPool: Pool, conf: SparkConf)
extends SchedulableBuilder with Logging {
- val schedulerAllocFile = Option(System.getProperty("spark.scheduler.allocation.file"))
+ val schedulerAllocFile = conf.getOption("spark.scheduler.allocation.file")
val DEFAULT_SCHEDULER_FILE = "fairscheduler.xml"
val FAIR_SCHEDULER_PROPERTIES = "spark.scheduler.pool"
val DEFAULT_POOL_NAME = "default"
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala
index 5367218faa..02bdbba825 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala
@@ -15,12 +15,12 @@
* limitations under the License.
*/
-package org.apache.spark.scheduler.cluster
+package org.apache.spark.scheduler
import org.apache.spark.SparkContext
/**
- * A backend interface for cluster scheduling systems that allows plugging in different ones under
+ * A backend interface for scheduling systems that allows plugging in different ones under
* ClusterScheduler. We assume a Mesos-like model where the application gets resource offers as
* machines become available and can launch tasks on them.
*/
@@ -31,7 +31,4 @@ private[spark] trait SchedulerBackend {
def defaultParallelism(): Int
def killTask(taskId: Long, executorId: String): Unit = throw new UnsupportedOperationException
-
- // Memory used by each executor (in megabytes)
- protected val executorMemory: Int = SparkContext.executorMemoryRequested
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulingMode.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulingMode.scala
index 0a786deb16..3832ee7ff6 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SchedulingMode.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulingMode.scala
@@ -22,7 +22,7 @@ package org.apache.spark.scheduler
* to order tasks amongst a Schedulable's sub-queues
* "NONE" is used when the a Schedulable has no sub-queues.
*/
-object SchedulingMode extends Enumeration("FAIR", "FIFO", "NONE") {
+object SchedulingMode extends Enumeration {
type SchedulingMode = Value
val FAIR,FIFO,NONE = Value
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index 0f2deb4bcb..a37ead5632 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -37,7 +37,9 @@ private[spark] object ShuffleMapTask {
// expensive on the master node if it needs to launch thousands of tasks.
val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]]
- val metadataCleaner = new MetadataCleaner(MetadataCleanerType.SHUFFLE_MAP_TASK, serializedInfoCache.clearOldValues)
+ // TODO: This object shouldn't have global variables
+ val metadataCleaner = new MetadataCleaner(
+ MetadataCleanerType.SHUFFLE_MAP_TASK, serializedInfoCache.clearOldValues, new SparkConf)
def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_]): Array[Byte] = {
synchronized {
@@ -152,7 +154,7 @@ private[spark] class ShuffleMapTask(
try {
// Obtain all the block writers for shuffle blocks.
- val ser = SparkEnv.get.serializerManager.get(dep.serializerClass)
+ val ser = SparkEnv.get.serializerManager.get(dep.serializerClass, SparkEnv.get.conf)
shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, partitionId, numOutputSplits, ser)
// Write the map output to its associated buckets.
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
index a35081f7b1..627995c826 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
@@ -27,7 +27,7 @@ sealed trait SparkListenerEvents
case class SparkListenerStageSubmitted(stage: StageInfo, properties: Properties)
extends SparkListenerEvents
-case class StageCompleted(val stage: StageInfo) extends SparkListenerEvents
+case class SparkListenerStageCompleted(val stage: StageInfo) extends SparkListenerEvents
case class SparkListenerTaskStart(task: Task[_], taskInfo: TaskInfo) extends SparkListenerEvents
@@ -37,7 +37,7 @@ case class SparkListenerTaskGettingResult(
case class SparkListenerTaskEnd(task: Task[_], reason: TaskEndReason, taskInfo: TaskInfo,
taskMetrics: TaskMetrics) extends SparkListenerEvents
-case class SparkListenerJobStart(job: ActiveJob, properties: Properties = null)
+case class SparkListenerJobStart(job: ActiveJob, stageIds: Array[Int], properties: Properties = null)
extends SparkListenerEvents
case class SparkListenerJobEnd(job: ActiveJob, jobResult: JobResult)
@@ -47,7 +47,7 @@ trait SparkListener {
/**
* Called when a stage is completed, with information on the completed stage
*/
- def onStageCompleted(stageCompleted: StageCompleted) { }
+ def onStageCompleted(stageCompleted: SparkListenerStageCompleted) { }
/**
* Called when a stage is submitted
@@ -63,7 +63,7 @@ trait SparkListener {
* Called when a task begins remotely fetching its result (will not be called for tasks that do
* not need to fetch the result remotely).
*/
- def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) { }
+ def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) { }
/**
* Called when a task ends
@@ -86,7 +86,7 @@ trait SparkListener {
* Simple SparkListener that logs a few summary statistics when each stage completes
*/
class StatsReportListener extends SparkListener with Logging {
- override def onStageCompleted(stageCompleted: StageCompleted) {
+ override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) {
import org.apache.spark.scheduler.StatsReportListener._
implicit val sc = stageCompleted
this.logInfo("Finished stage: " + stageCompleted.stage)
@@ -119,20 +119,24 @@ object StatsReportListener extends Logging {
val probabilities = percentiles.map{_ / 100.0}
val percentilesHeader = "\t" + percentiles.mkString("%\t") + "%"
- def extractDoubleDistribution(stage:StageCompleted, getMetric: (TaskInfo,TaskMetrics) => Option[Double]): Option[Distribution] = {
+ def extractDoubleDistribution(stage: SparkListenerStageCompleted,
+ getMetric: (TaskInfo,TaskMetrics) => Option[Double])
+ : Option[Distribution] = {
Distribution(stage.stage.taskInfos.flatMap {
case ((info,metric)) => getMetric(info, metric)})
}
//is there some way to setup the types that I can get rid of this completely?
- def extractLongDistribution(stage:StageCompleted, getMetric: (TaskInfo,TaskMetrics) => Option[Long]): Option[Distribution] = {
+ def extractLongDistribution(stage: SparkListenerStageCompleted,
+ getMetric: (TaskInfo,TaskMetrics) => Option[Long])
+ : Option[Distribution] = {
extractDoubleDistribution(stage, (info, metric) => getMetric(info,metric).map{_.toDouble})
}
def showDistribution(heading: String, d: Distribution, formatNumber: Double => String) {
val stats = d.statCounter
- logInfo(heading + stats)
val quantiles = d.getQuantiles(probabilities).map{formatNumber}
+ logInfo(heading + stats)
logInfo(percentilesHeader)
logInfo("\t" + quantiles.mkString("\t"))
}
@@ -147,12 +151,12 @@ object StatsReportListener extends Logging {
}
def showDistribution(heading:String, format: String, getMetric: (TaskInfo,TaskMetrics) => Option[Double])
- (implicit stage: StageCompleted) {
+ (implicit stage: SparkListenerStageCompleted) {
showDistribution(heading, extractDoubleDistribution(stage, getMetric), format)
}
def showBytesDistribution(heading:String, getMetric: (TaskInfo,TaskMetrics) => Option[Long])
- (implicit stage: StageCompleted) {
+ (implicit stage: SparkListenerStageCompleted) {
showBytesDistribution(heading, extractLongDistribution(stage, getMetric))
}
@@ -169,12 +173,10 @@ object StatsReportListener extends Logging {
}
def showMillisDistribution(heading: String, getMetric: (TaskInfo, TaskMetrics) => Option[Long])
- (implicit stage: StageCompleted) {
+ (implicit stage: SparkListenerStageCompleted) {
showMillisDistribution(heading, extractLongDistribution(stage, getMetric))
}
-
-
val seconds = 1000L
val minutes = seconds * 60
val hours = minutes * 60
@@ -198,7 +200,6 @@ object StatsReportListener extends Logging {
}
-
case class RuntimePercentage(executorPct: Double, fetchPct: Option[Double], other: Double)
object RuntimePercentage {
def apply(totalTime: Long, metrics: TaskMetrics): RuntimePercentage = {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
index d5824e7954..e7defd768b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
@@ -41,7 +41,7 @@ private[spark] class SparkListenerBus() extends Logging {
event match {
case stageSubmitted: SparkListenerStageSubmitted =>
sparkListeners.foreach(_.onStageSubmitted(stageSubmitted))
- case stageCompleted: StageCompleted =>
+ case stageCompleted: SparkListenerStageCompleted =>
sparkListeners.foreach(_.onStageCompleted(stageCompleted))
case jobStart: SparkListenerJobStart =>
sparkListeners.foreach(_.onJobStart(jobStart))
@@ -91,4 +91,3 @@ private[spark] class SparkListenerBus() extends Logging {
return true
}
}
-
diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
index 93599dfdc8..e9f2198a00 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
@@ -33,4 +33,5 @@ class StageInfo(
val name = stage.name
val numPartitions = stage.numPartitions
val numTasks = stage.numTasks
+ var emittedTaskSizeWarning = false
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
index 4bae26f3a6..3c22edd524 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
@@ -46,6 +46,8 @@ class TaskInfo(
var failed = false
+ var serializedSize: Int = 0
+
def markGettingResult(time: Long = System.currentTimeMillis) {
gettingResultTime = time
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskLocality.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskLocality.scala
index 47b0f387aa..35de13c385 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskLocality.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskLocality.scala
@@ -18,9 +18,7 @@
package org.apache.spark.scheduler
-private[spark] object TaskLocality
- extends Enumeration("PROCESS_LOCAL", "NODE_LOCAL", "RACK_LOCAL", "ANY")
-{
+private[spark] object TaskLocality extends Enumeration {
// process local is expected to be used ONLY within tasksetmanager for now.
val PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY = Value
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
index 7e468d0d67..e80cc6b0f6 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
@@ -35,18 +35,15 @@ case class IndirectTaskResult[T](blockId: BlockId) extends TaskResult[T] with Se
/** A TaskResult that contains the task's return value and accumulator updates. */
private[spark]
-class DirectTaskResult[T](var value: T, var accumUpdates: Map[Long, Any], var metrics: TaskMetrics)
+class DirectTaskResult[T](var valueBytes: ByteBuffer, var accumUpdates: Map[Long, Any], var metrics: TaskMetrics)
extends TaskResult[T] with Externalizable {
- def this() = this(null.asInstanceOf[T], null, null)
+ def this() = this(null.asInstanceOf[ByteBuffer], null, null)
override def writeExternal(out: ObjectOutput) {
- val objectSer = SparkEnv.get.serializer.newInstance()
- val bb = objectSer.serialize(value)
-
- out.writeInt(bb.remaining())
- Utils.writeByteBuffer(bb, out)
+ out.writeInt(valueBytes.remaining);
+ Utils.writeByteBuffer(valueBytes, out)
out.writeInt(accumUpdates.size)
for ((key, value) <- accumUpdates) {
@@ -58,12 +55,10 @@ class DirectTaskResult[T](var value: T, var accumUpdates: Map[Long, Any], var me
override def readExternal(in: ObjectInput) {
- val objectSer = SparkEnv.get.serializer.newInstance()
-
val blen = in.readInt()
val byteVal = new Array[Byte](blen)
in.readFully(byteVal)
- value = objectSer.deserialize(ByteBuffer.wrap(byteVal))
+ valueBytes = ByteBuffer.wrap(byteVal)
val numUpdates = in.readInt
if (numUpdates == 0) {
@@ -76,4 +71,9 @@ class DirectTaskResult[T](var value: T, var accumUpdates: Map[Long, Any], var me
}
metrics = in.readObject().asInstanceOf[TaskMetrics]
}
+
+ def value(): T = {
+ val resultSer = SparkEnv.get.serializer.newInstance()
+ return resultSer.deserialize(valueBytes)
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
index 2064d97b49..e22b1e53e8 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
@@ -15,23 +15,23 @@
* limitations under the License.
*/
-package org.apache.spark.scheduler.cluster
+package org.apache.spark.scheduler
import java.nio.ByteBuffer
import java.util.concurrent.{LinkedBlockingDeque, ThreadFactory, ThreadPoolExecutor, TimeUnit}
import org.apache.spark._
import org.apache.spark.TaskState.TaskState
-import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, TaskResult}
import org.apache.spark.serializer.SerializerInstance
import org.apache.spark.util.Utils
/**
* Runs a thread pool that deserializes and remotely fetches (if necessary) task results.
*/
-private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterScheduler)
+private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedulerImpl)
extends Logging {
- private val THREADS = System.getProperty("spark.resultGetter.threads", "4").toInt
+
+ private val THREADS = sparkEnv.conf.get("spark.resultGetter.threads", "4").toInt
private val getTaskResultExecutor = Utils.newDaemonFixedThreadPool(
THREADS, "Result resolver thread")
@@ -42,7 +42,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterSche
}
def enqueueSuccessfulTask(
- taskSetManager: ClusterTaskSetManager, tid: Long, serializedData: ByteBuffer) {
+ taskSetManager: TaskSetManager, tid: Long, serializedData: ByteBuffer) {
getTaskResultExecutor.execute(new Runnable {
override def run() {
try {
@@ -71,14 +71,14 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterSche
case cnf: ClassNotFoundException =>
val loader = Thread.currentThread.getContextClassLoader
taskSetManager.abort("ClassNotFound with classloader: " + loader)
- case ex =>
+ case ex: Throwable =>
taskSetManager.abort("Exception while deserializing and fetching task: %s".format(ex))
}
}
})
}
- def enqueueFailedTask(taskSetManager: ClusterTaskSetManager, tid: Long, taskState: TaskState,
+ def enqueueFailedTask(taskSetManager: TaskSetManager, tid: Long, taskState: TaskState,
serializedData: ByteBuffer) {
var reason: Option[TaskEndReason] = None
getTaskResultExecutor.execute(new Runnable {
@@ -95,7 +95,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterSche
val loader = Thread.currentThread.getContextClassLoader
logError(
"Could not deserialize TaskEndReason: ClassNotFound with classloader " + loader)
- case ex => {}
+ case ex: Throwable => {}
}
scheduler.handleFailedTask(taskSetManager, tid, taskState, reason)
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
index 10e0478108..17b6d97e90 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
@@ -20,11 +20,12 @@ package org.apache.spark.scheduler
import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
/**
- * Low-level task scheduler interface, implemented by both ClusterScheduler and LocalScheduler.
- * Each TaskScheduler schedulers task for a single SparkContext.
- * These schedulers get sets of tasks submitted to them from the DAGScheduler for each stage,
- * and are responsible for sending the tasks to the cluster, running them, retrying if there
- * are failures, and mitigating stragglers. They return events to the DAGScheduler.
+ * Low-level task scheduler interface, currently implemented exclusively by the ClusterScheduler.
+ * This interface allows plugging in different task schedulers. Each TaskScheduler schedulers tasks
+ * for a single SparkContext. These schedulers get sets of tasks submitted to them from the
+ * DAGScheduler for each stage, and are responsible for sending the tasks to the cluster, running
+ * them, retrying if there are failures, and mitigating stragglers. They return events to the
+ * DAGScheduler.
*/
private[spark] trait TaskScheduler {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index c1e65a3c48..0c8ed62759 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.scheduler.cluster
+package org.apache.spark.scheduler
import java.nio.ByteBuffer
import java.util.concurrent.atomic.AtomicLong
@@ -24,41 +24,46 @@ import java.util.{TimerTask, Timer}
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
-
-import akka.util.duration._
+import scala.concurrent.duration._
import org.apache.spark._
import org.apache.spark.TaskState.TaskState
-import org.apache.spark.scheduler._
import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
/**
- * The main TaskScheduler implementation, for running tasks on a cluster. Clients should first call
- * initialize() and start(), then submit task sets through the runTasks method.
- *
- * This class can work with multiple types of clusters by acting through a SchedulerBackend.
+ * Schedules tasks for multiple types of clusters by acting through a SchedulerBackend.
+ * It can also work with a local setup by using a LocalBackend and setting isLocal to true.
* It handles common logic, like determining a scheduling order across jobs, waking up to launch
* speculative tasks, etc.
*
+ * Clients should first call initialize() and start(), then submit task sets through the
+ * runTasks method.
+ *
* THREADING: SchedulerBackends and task-submitting clients can call this class from multiple
* threads, so it needs locks in public API methods to maintain its state. In addition, some
* SchedulerBackends sycnchronize on themselves when they want to send events here, and then
* acquire a lock on us, so we need to make sure that we don't try to lock the backend while
* we are holding a lock on ourselves.
*/
-private[spark] class ClusterScheduler(val sc: SparkContext)
- extends TaskScheduler
- with Logging
+private[spark] class TaskSchedulerImpl(
+ val sc: SparkContext,
+ val maxTaskFailures: Int,
+ isLocal: Boolean = false)
+ extends TaskScheduler with Logging
{
+ def this(sc: SparkContext) = this(sc, sc.conf.get("spark.task.maxFailures", "4").toInt)
+
+ val conf = sc.conf
+
// How often to check for speculative tasks
- val SPECULATION_INTERVAL = System.getProperty("spark.speculation.interval", "100").toLong
+ val SPECULATION_INTERVAL = conf.get("spark.speculation.interval", "100").toLong
// Threshold above which we warn user initial TaskSet may be starved
- val STARVATION_TIMEOUT = System.getProperty("spark.starvation.timeout", "15000").toLong
+ val STARVATION_TIMEOUT = conf.get("spark.starvation.timeout", "15000").toLong
- // ClusterTaskSetManagers are not thread safe, so any access to one should be synchronized
+ // TaskSetManagers are not thread safe, so any access to one should be synchronized
// on this class.
- val activeTaskSets = new HashMap[String, ClusterTaskSetManager]
+ val activeTaskSets = new HashMap[String, TaskSetManager]
val taskIdToTaskSetId = new HashMap[Long, String]
val taskIdToExecutorId = new HashMap[Long, String]
@@ -91,7 +96,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
var rootPool: Pool = null
// default scheduler is FIFO
val schedulingMode: SchedulingMode = SchedulingMode.withName(
- System.getProperty("spark.scheduler.mode", "FIFO"))
+ conf.get("spark.scheduler.mode", "FIFO"))
// This is a var so that we can reset it for testing purposes.
private[spark] var taskResultGetter = new TaskResultGetter(sc.env, this)
@@ -100,8 +105,8 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
this.dagScheduler = dagScheduler
}
- def initialize(context: SchedulerBackend) {
- backend = context
+ def initialize(backend: SchedulerBackend) {
+ this.backend = backend
// temporarily set rootPool name to empty
rootPool = new Pool("", schedulingMode, 0, 0)
schedulableBuilder = {
@@ -109,7 +114,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
case SchedulingMode.FIFO =>
new FIFOSchedulableBuilder(rootPool)
case SchedulingMode.FAIR =>
- new FairSchedulableBuilder(rootPool)
+ new FairSchedulableBuilder(rootPool, conf)
}
}
schedulableBuilder.buildPools()
@@ -120,9 +125,9 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
override def start() {
backend.start()
- if (System.getProperty("spark.speculation", "false").toBoolean) {
+ if (!isLocal && conf.get("spark.speculation", "false").toBoolean) {
logInfo("Starting speculative execution thread")
-
+ import sc.env.actorSystem.dispatcher
sc.env.actorSystem.scheduler.schedule(SPECULATION_INTERVAL milliseconds,
SPECULATION_INTERVAL milliseconds) {
checkSpeculatableTasks()
@@ -134,12 +139,12 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
val tasks = taskSet.tasks
logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
this.synchronized {
- val manager = new ClusterTaskSetManager(this, taskSet)
+ val manager = new TaskSetManager(this, taskSet, maxTaskFailures)
activeTaskSets(taskSet.id) = manager
schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
taskSetTaskIds(taskSet.id) = new HashSet[Long]()
- if (!hasReceivedTask) {
+ if (!isLocal && !hasReceivedTask) {
starvationTimer.scheduleAtFixedRate(new TimerTask() {
override def run() {
if (!hasLaunchedTask) {
@@ -173,7 +178,9 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
backend.killTask(tid, execId)
}
}
- tsm.error("Stage %d was cancelled".format(stageId))
+ logInfo("Stage %d was cancelled".format(stageId))
+ tsm.removeAllRunningTasks()
+ taskSetFinished(tsm)
}
}
@@ -278,7 +285,8 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
}
}
case None =>
- logInfo("Ignoring update from TID " + tid + " because its task set is gone")
+ logInfo("Ignoring update with state %s from TID %s because its task set is gone"
+ .format(state, tid))
}
} catch {
case e: Exception => logError("Exception in statusUpdate", e)
@@ -291,19 +299,19 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
}
}
- def handleTaskGettingResult(taskSetManager: ClusterTaskSetManager, tid: Long) {
+ def handleTaskGettingResult(taskSetManager: TaskSetManager, tid: Long) {
taskSetManager.handleTaskGettingResult(tid)
}
def handleSuccessfulTask(
- taskSetManager: ClusterTaskSetManager,
+ taskSetManager: TaskSetManager,
tid: Long,
taskResult: DirectTaskResult[_]) = synchronized {
taskSetManager.handleSuccessfulTask(tid, taskResult)
}
def handleFailedTask(
- taskSetManager: ClusterTaskSetManager,
+ taskSetManager: TaskSetManager,
tid: Long,
taskState: TaskState,
reason: Option[TaskEndReason]) = synchronized {
@@ -321,7 +329,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
// Have each task set throw a SparkException with the error
for ((taskSetId, manager) <- activeTaskSets) {
try {
- manager.error(message)
+ manager.abort(message)
} catch {
case e: Exception => logError("Exception in error callback", e)
}
@@ -351,7 +359,6 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
override def defaultParallelism() = backend.defaultParallelism()
-
// Check for speculatable tasks in all our active jobs.
def checkSpeculatableTasks() {
var shouldRevive = false
@@ -363,13 +370,6 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
}
}
- // Check for pending tasks in all our active jobs.
- def hasPendingTasks: Boolean = {
- synchronized {
- rootPool.hasPendingTasks()
- }
- }
-
def executorLost(executorId: String, reason: ExecutorLossReason) {
var failedExecutor: Option[String] = None
@@ -428,7 +428,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
}
-object ClusterScheduler {
+private[spark] object TaskSchedulerImpl {
/**
* Used to balance containers across hosts.
*
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index 90f6bcefac..6dd1469d8f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -17,32 +17,693 @@
package org.apache.spark.scheduler
-import java.nio.ByteBuffer
+import java.io.NotSerializableException
+import java.util.Arrays
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.HashSet
+import scala.math.max
+import scala.math.min
+
+import org.apache.spark.{ExceptionFailure, FetchFailed, Logging, Resubmitted, SparkEnv,
+ Success, TaskEndReason, TaskKilled, TaskResultLost, TaskState}
import org.apache.spark.TaskState.TaskState
+import org.apache.spark.util.{Clock, SystemClock}
+
/**
- * Tracks and schedules the tasks within a single TaskSet. This class keeps track of the status of
- * each task and is responsible for retries on failure and locality. The main interfaces to it
- * are resourceOffer, which asks the TaskSet whether it wants to run a task on one node, and
- * statusUpdate, which tells it that one of its tasks changed state (e.g. finished).
+ * Schedules the tasks within a single TaskSet in the ClusterScheduler. This class keeps track of
+ * each task, retries tasks if they fail (up to a limited number of times), and
+ * handles locality-aware scheduling for this TaskSet via delay scheduling. The main interfaces
+ * to it are resourceOffer, which asks the TaskSet whether it wants to run a task on one node,
+ * and statusUpdate, which tells it that one of its tasks changed state (e.g. finished).
+ *
+ * THREADING: This class is designed to only be called from code with a lock on the
+ * TaskScheduler (e.g. its event handlers). It should not be called from other threads.
*
- * THREADING: This class is designed to only be called from code with a lock on the TaskScheduler
- * (e.g. its event handlers). It should not be called from other threads.
+ * @param sched the ClusterScheduler associated with the TaskSetManager
+ * @param taskSet the TaskSet to manage scheduling for
+ * @param maxTaskFailures if any particular task fails more than this number of times, the entire
+ * task set will be aborted
*/
-private[spark] trait TaskSetManager extends Schedulable {
- def schedulableQueue = null
-
- def schedulingMode = SchedulingMode.NONE
-
- def taskSet: TaskSet
+private[spark] class TaskSetManager(
+ sched: TaskSchedulerImpl,
+ val taskSet: TaskSet,
+ val maxTaskFailures: Int,
+ clock: Clock = SystemClock)
+ extends Schedulable with Logging
+{
+ val conf = sched.sc.conf
+
+ // CPUs to request per task
+ val CPUS_PER_TASK = conf.get("spark.task.cpus", "1").toInt
+
+ // Quantile of tasks at which to start speculation
+ val SPECULATION_QUANTILE = conf.get("spark.speculation.quantile", "0.75").toDouble
+ val SPECULATION_MULTIPLIER = conf.get("spark.speculation.multiplier", "1.5").toDouble
+
+ // Serializer for closures and tasks.
+ val env = SparkEnv.get
+ val ser = env.closureSerializer.newInstance()
+
+ val tasks = taskSet.tasks
+ val numTasks = tasks.length
+ val copiesRunning = new Array[Int](numTasks)
+ val successful = new Array[Boolean](numTasks)
+ val numFailures = new Array[Int](numTasks)
+ val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil)
+ var tasksSuccessful = 0
+
+ var weight = 1
+ var minShare = 0
+ var priority = taskSet.priority
+ var stageId = taskSet.stageId
+ var name = "TaskSet_"+taskSet.stageId.toString
+ var parent: Pool = null
+
+ var runningTasks = 0
+ private val runningTasksSet = new HashSet[Long]
+
+ // Set of pending tasks for each executor. These collections are actually
+ // treated as stacks, in which new tasks are added to the end of the
+ // ArrayBuffer and removed from the end. This makes it faster to detect
+ // tasks that repeatedly fail because whenever a task failed, it is put
+ // back at the head of the stack. They are also only cleaned up lazily;
+ // when a task is launched, it remains in all the pending lists except
+ // the one that it was launched from, but gets removed from them later.
+ private val pendingTasksForExecutor = new HashMap[String, ArrayBuffer[Int]]
+
+ // Set of pending tasks for each host. Similar to pendingTasksForExecutor,
+ // but at host level.
+ private val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]]
+
+ // Set of pending tasks for each rack -- similar to the above.
+ private val pendingTasksForRack = new HashMap[String, ArrayBuffer[Int]]
+
+ // Set containing pending tasks with no locality preferences.
+ val pendingTasksWithNoPrefs = new ArrayBuffer[Int]
+
+ // Set containing all pending tasks (also used as a stack, as above).
+ val allPendingTasks = new ArrayBuffer[Int]
+
+ // Tasks that can be speculated. Since these will be a small fraction of total
+ // tasks, we'll just hold them in a HashSet.
+ val speculatableTasks = new HashSet[Int]
+
+ // Task index, start and finish time for each task attempt (indexed by task ID)
+ val taskInfos = new HashMap[Long, TaskInfo]
+
+ // How frequently to reprint duplicate exceptions in full, in milliseconds
+ val EXCEPTION_PRINT_INTERVAL =
+ conf.get("spark.logging.exceptionPrintInterval", "10000").toLong
+
+ // Map of recent exceptions (identified by string representation and top stack frame) to
+ // duplicate count (how many times the same exception has appeared) and time the full exception
+ // was printed. This should ideally be an LRU map that can drop old exceptions automatically.
+ val recentExceptions = HashMap[String, (Int, Long)]()
+
+ // Figure out the current map output tracker epoch and set it on all tasks
+ val epoch = sched.mapOutputTracker.getEpoch
+ logDebug("Epoch for " + taskSet + ": " + epoch)
+ for (t <- tasks) {
+ t.epoch = epoch
+ }
+
+ // Add all our tasks to the pending lists. We do this in reverse order
+ // of task index so that tasks with low indices get launched first.
+ for (i <- (0 until numTasks).reverse) {
+ addPendingTask(i)
+ }
+
+ // Figure out which locality levels we have in our TaskSet, so we can do delay scheduling
+ val myLocalityLevels = computeValidLocalityLevels()
+ val localityWaits = myLocalityLevels.map(getLocalityWait) // Time to wait at each level
+
+ // Delay scheduling variables: we keep track of our current locality level and the time we
+ // last launched a task at that level, and move up a level when localityWaits[curLevel] expires.
+ // We then move down if we manage to launch a "more local" task.
+ var currentLocalityIndex = 0 // Index of our current locality level in validLocalityLevels
+ var lastLaunchTime = clock.getTime() // Time we last launched a task at this level
+
+ override def schedulableQueue = null
+
+ override def schedulingMode = SchedulingMode.NONE
+
+ /**
+ * Add a task to all the pending-task lists that it should be on. If readding is set, we are
+ * re-adding the task so only include it in each list if it's not already there.
+ */
+ private def addPendingTask(index: Int, readding: Boolean = false) {
+ // Utility method that adds `index` to a list only if readding=false or it's not already there
+ def addTo(list: ArrayBuffer[Int]) {
+ if (!readding || !list.contains(index)) {
+ list += index
+ }
+ }
+
+ var hadAliveLocations = false
+ for (loc <- tasks(index).preferredLocations) {
+ for (execId <- loc.executorId) {
+ if (sched.isExecutorAlive(execId)) {
+ addTo(pendingTasksForExecutor.getOrElseUpdate(execId, new ArrayBuffer))
+ hadAliveLocations = true
+ }
+ }
+ if (sched.hasExecutorsAliveOnHost(loc.host)) {
+ addTo(pendingTasksForHost.getOrElseUpdate(loc.host, new ArrayBuffer))
+ for (rack <- sched.getRackForHost(loc.host)) {
+ addTo(pendingTasksForRack.getOrElseUpdate(rack, new ArrayBuffer))
+ }
+ hadAliveLocations = true
+ }
+ }
+
+ if (!hadAliveLocations) {
+ // Even though the task might've had preferred locations, all of those hosts or executors
+ // are dead; put it in the no-prefs list so we can schedule it elsewhere right away.
+ addTo(pendingTasksWithNoPrefs)
+ }
+
+ if (!readding) {
+ allPendingTasks += index // No point scanning this whole list to find the old task there
+ }
+ }
+
+ /**
+ * Return the pending tasks list for a given executor ID, or an empty list if
+ * there is no map entry for that host
+ */
+ private def getPendingTasksForExecutor(executorId: String): ArrayBuffer[Int] = {
+ pendingTasksForExecutor.getOrElse(executorId, ArrayBuffer())
+ }
+
+ /**
+ * Return the pending tasks list for a given host, or an empty list if
+ * there is no map entry for that host
+ */
+ private def getPendingTasksForHost(host: String): ArrayBuffer[Int] = {
+ pendingTasksForHost.getOrElse(host, ArrayBuffer())
+ }
+
+ /**
+ * Return the pending rack-local task list for a given rack, or an empty list if
+ * there is no map entry for that rack
+ */
+ private def getPendingTasksForRack(rack: String): ArrayBuffer[Int] = {
+ pendingTasksForRack.getOrElse(rack, ArrayBuffer())
+ }
+
+ /**
+ * Dequeue a pending task from the given list and return its index.
+ * Return None if the list is empty.
+ * This method also cleans up any tasks in the list that have already
+ * been launched, since we want that to happen lazily.
+ */
+ private def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = {
+ while (!list.isEmpty) {
+ val index = list.last
+ list.trimEnd(1)
+ if (copiesRunning(index) == 0 && !successful(index)) {
+ return Some(index)
+ }
+ }
+ return None
+ }
+
+ /** Check whether a task is currently running an attempt on a given host */
+ private def hasAttemptOnHost(taskIndex: Int, host: String): Boolean = {
+ !taskAttempts(taskIndex).exists(_.host == host)
+ }
+
+ /**
+ * Return a speculative task for a given executor if any are available. The task should not have
+ * an attempt running on this host, in case the host is slow. In addition, the task should meet
+ * the given locality constraint.
+ */
+ private def findSpeculativeTask(execId: String, host: String, locality: TaskLocality.Value)
+ : Option[(Int, TaskLocality.Value)] =
+ {
+ speculatableTasks.retain(index => !successful(index)) // Remove finished tasks from set
+
+ if (!speculatableTasks.isEmpty) {
+ // Check for process-local or preference-less tasks; note that tasks can be process-local
+ // on multiple nodes when we replicate cached blocks, as in Spark Streaming
+ for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) {
+ val prefs = tasks(index).preferredLocations
+ val executors = prefs.flatMap(_.executorId)
+ if (prefs.size == 0 || executors.contains(execId)) {
+ speculatableTasks -= index
+ return Some((index, TaskLocality.PROCESS_LOCAL))
+ }
+ }
+
+ // Check for node-local tasks
+ if (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) {
+ for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) {
+ val locations = tasks(index).preferredLocations.map(_.host)
+ if (locations.contains(host)) {
+ speculatableTasks -= index
+ return Some((index, TaskLocality.NODE_LOCAL))
+ }
+ }
+ }
+ // Check for rack-local tasks
+ if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
+ for (rack <- sched.getRackForHost(host)) {
+ for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) {
+ val racks = tasks(index).preferredLocations.map(_.host).map(sched.getRackForHost)
+ if (racks.contains(rack)) {
+ speculatableTasks -= index
+ return Some((index, TaskLocality.RACK_LOCAL))
+ }
+ }
+ }
+ }
+
+ // Check for non-local tasks
+ if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
+ for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) {
+ speculatableTasks -= index
+ return Some((index, TaskLocality.ANY))
+ }
+ }
+ }
+
+ return None
+ }
+
+ /**
+ * Dequeue a pending task for a given node and return its index and locality level.
+ * Only search for tasks matching the given locality constraint.
+ */
+ private def findTask(execId: String, host: String, locality: TaskLocality.Value)
+ : Option[(Int, TaskLocality.Value)] =
+ {
+ for (index <- findTaskFromList(getPendingTasksForExecutor(execId))) {
+ return Some((index, TaskLocality.PROCESS_LOCAL))
+ }
+
+ if (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) {
+ for (index <- findTaskFromList(getPendingTasksForHost(host))) {
+ return Some((index, TaskLocality.NODE_LOCAL))
+ }
+ }
+
+ if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
+ for {
+ rack <- sched.getRackForHost(host)
+ index <- findTaskFromList(getPendingTasksForRack(rack))
+ } {
+ return Some((index, TaskLocality.RACK_LOCAL))
+ }
+ }
+
+ // Look for no-pref tasks after rack-local tasks since they can run anywhere.
+ for (index <- findTaskFromList(pendingTasksWithNoPrefs)) {
+ return Some((index, TaskLocality.PROCESS_LOCAL))
+ }
+
+ if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
+ for (index <- findTaskFromList(allPendingTasks)) {
+ return Some((index, TaskLocality.ANY))
+ }
+ }
+
+ // Finally, if all else has failed, find a speculative task
+ return findSpeculativeTask(execId, host, locality)
+ }
+
+ /**
+ * Respond to an offer of a single executor from the scheduler by finding a task
+ */
def resourceOffer(
execId: String,
host: String,
availableCpus: Int,
maxLocality: TaskLocality.TaskLocality)
- : Option[TaskDescription]
+ : Option[TaskDescription] =
+ {
+ if (tasksSuccessful < numTasks && availableCpus >= CPUS_PER_TASK) {
+ val curTime = clock.getTime()
+
+ var allowedLocality = getAllowedLocalityLevel(curTime)
+ if (allowedLocality > maxLocality) {
+ allowedLocality = maxLocality // We're not allowed to search for farther-away tasks
+ }
+
+ findTask(execId, host, allowedLocality) match {
+ case Some((index, taskLocality)) => {
+ // Found a task; do some bookkeeping and return a task description
+ val task = tasks(index)
+ val taskId = sched.newTaskId()
+ // Figure out whether this should count as a preferred launch
+ logInfo("Starting task %s:%d as TID %s on executor %s: %s (%s)".format(
+ taskSet.id, index, taskId, execId, host, taskLocality))
+ // Do various bookkeeping
+ copiesRunning(index) += 1
+ val info = new TaskInfo(taskId, index, curTime, execId, host, taskLocality)
+ taskInfos(taskId) = info
+ taskAttempts(index) = info :: taskAttempts(index)
+ // Update our locality level for delay scheduling
+ currentLocalityIndex = getLocalityIndex(taskLocality)
+ lastLaunchTime = curTime
+ // Serialize and return the task
+ val startTime = clock.getTime()
+ // We rely on the DAGScheduler to catch non-serializable closures and RDDs, so in here
+ // we assume the task can be serialized without exceptions.
+ val serializedTask = Task.serializeWithDependencies(
+ task, sched.sc.addedFiles, sched.sc.addedJars, ser)
+ val timeTaken = clock.getTime() - startTime
+ addRunningTask(taskId)
+ logInfo("Serialized task %s:%d as %d bytes in %d ms".format(
+ taskSet.id, index, serializedTask.limit, timeTaken))
+ val taskName = "task %s:%d".format(taskSet.id, index)
+ if (taskAttempts(index).size == 1)
+ taskStarted(task,info)
+ return Some(new TaskDescription(taskId, execId, taskName, index, serializedTask))
+ }
+ case _ =>
+ }
+ }
+ return None
+ }
+
+ /**
+ * Get the level we can launch tasks according to delay scheduling, based on current wait time.
+ */
+ private def getAllowedLocalityLevel(curTime: Long): TaskLocality.TaskLocality = {
+ while (curTime - lastLaunchTime >= localityWaits(currentLocalityIndex) &&
+ currentLocalityIndex < myLocalityLevels.length - 1)
+ {
+ // Jump to the next locality level, and remove our waiting time for the current one since
+ // we don't want to count it again on the next one
+ lastLaunchTime += localityWaits(currentLocalityIndex)
+ currentLocalityIndex += 1
+ }
+ myLocalityLevels(currentLocalityIndex)
+ }
+
+ /**
+ * Find the index in myLocalityLevels for a given locality. This is also designed to work with
+ * localities that are not in myLocalityLevels (in case we somehow get those) by returning the
+ * next-biggest level we have. Uses the fact that the last value in myLocalityLevels is ANY.
+ */
+ def getLocalityIndex(locality: TaskLocality.TaskLocality): Int = {
+ var index = 0
+ while (locality > myLocalityLevels(index)) {
+ index += 1
+ }
+ index
+ }
+
+ private def taskStarted(task: Task[_], info: TaskInfo) {
+ sched.dagScheduler.taskStarted(task, info)
+ }
+
+ def handleTaskGettingResult(tid: Long) = {
+ val info = taskInfos(tid)
+ info.markGettingResult()
+ sched.dagScheduler.taskGettingResult(tasks(info.index), info)
+ }
+
+ /**
+ * Marks the task as successful and notifies the DAGScheduler that a task has ended.
+ */
+ def handleSuccessfulTask(tid: Long, result: DirectTaskResult[_]) = {
+ val info = taskInfos(tid)
+ val index = info.index
+ info.markSuccessful()
+ removeRunningTask(tid)
+ if (!successful(index)) {
+ logInfo("Finished TID %s in %d ms on %s (progress: %d/%d)".format(
+ tid, info.duration, info.host, tasksSuccessful, numTasks))
+ sched.dagScheduler.taskEnded(
+ tasks(index), Success, result.value, result.accumUpdates, info, result.metrics)
+
+ // Mark successful and stop if all the tasks have succeeded.
+ tasksSuccessful += 1
+ successful(index) = true
+ if (tasksSuccessful == numTasks) {
+ sched.taskSetFinished(this)
+ }
+ } else {
+ logInfo("Ignorning task-finished event for TID " + tid + " because task " +
+ index + " has already completed successfully")
+ }
+ }
+
+ /**
+ * Marks the task as failed, re-adds it to the list of pending tasks, and notifies the
+ * DAG Scheduler.
+ */
+ def handleFailedTask(tid: Long, state: TaskState, reason: Option[TaskEndReason]) {
+ val info = taskInfos(tid)
+ if (info.failed) {
+ return
+ }
+ removeRunningTask(tid)
+ val index = info.index
+ info.markFailed()
+ var failureReason = "unknown"
+ if (!successful(index)) {
+ logWarning("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index))
+ copiesRunning(index) -= 1
+ // Check if the problem is a map output fetch failure. In that case, this
+ // task will never succeed on any node, so tell the scheduler about it.
+ reason.foreach {
+ case fetchFailed: FetchFailed =>
+ logWarning("Loss was due to fetch failure from " + fetchFailed.bmAddress)
+ sched.dagScheduler.taskEnded(tasks(index), fetchFailed, null, null, info, null)
+ successful(index) = true
+ tasksSuccessful += 1
+ sched.taskSetFinished(this)
+ removeAllRunningTasks()
+ return
+
+ case TaskKilled =>
+ logWarning("Task %d was killed.".format(tid))
+ sched.dagScheduler.taskEnded(tasks(index), reason.get, null, null, info, null)
+ return
+
+ case ef: ExceptionFailure =>
+ sched.dagScheduler.taskEnded(
+ tasks(index), ef, null, null, info, ef.metrics.getOrElse(null))
+ if (ef.className == classOf[NotSerializableException].getName()) {
+ // If the task result wasn't rerializable, there's no point in trying to re-execute it.
+ logError("Task %s:%s had a not serializable result: %s; not retrying".format(
+ taskSet.id, index, ef.description))
+ abort("Task %s:%s had a not serializable result: %s".format(
+ taskSet.id, index, ef.description))
+ return
+ }
+ val key = ef.description
+ failureReason = "Exception failure: %s".format(ef.description)
+ val now = clock.getTime()
+ val (printFull, dupCount) = {
+ if (recentExceptions.contains(key)) {
+ val (dupCount, printTime) = recentExceptions(key)
+ if (now - printTime > EXCEPTION_PRINT_INTERVAL) {
+ recentExceptions(key) = (0, now)
+ (true, 0)
+ } else {
+ recentExceptions(key) = (dupCount + 1, printTime)
+ (false, dupCount + 1)
+ }
+ } else {
+ recentExceptions(key) = (0, now)
+ (true, 0)
+ }
+ }
+ if (printFull) {
+ val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString))
+ logWarning("Loss was due to %s\n%s\n%s".format(
+ ef.className, ef.description, locs.mkString("\n")))
+ } else {
+ logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount))
+ }
+
+ case TaskResultLost =>
+ failureReason = "Lost result for TID %s on host %s".format(tid, info.host)
+ logWarning(failureReason)
+ sched.dagScheduler.taskEnded(tasks(index), TaskResultLost, null, null, info, null)
+
+ case _ => {}
+ }
+ // On non-fetch failures, re-enqueue the task as pending for a max number of retries
+ addPendingTask(index)
+ if (state != TaskState.KILLED) {
+ numFailures(index) += 1
+ if (numFailures(index) >= maxTaskFailures) {
+ logError("Task %s:%d failed %d times; aborting job".format(
+ taskSet.id, index, maxTaskFailures))
+ abort("Task %s:%d failed %d times (most recent failure: %s)".format(
+ taskSet.id, index, maxTaskFailures, failureReason))
+ }
+ }
+ } else {
+ logInfo("Ignoring task-lost event for TID " + tid +
+ " because task " + index + " is already finished")
+ }
+ }
+
+ def abort(message: String) {
+ // TODO: Kill running tasks if we were not terminated due to a Mesos error
+ sched.dagScheduler.taskSetFailed(taskSet, message)
+ removeAllRunningTasks()
+ sched.taskSetFinished(this)
+ }
+
+ /** If the given task ID is not in the set of running tasks, adds it.
+ *
+ * Used to keep track of the number of running tasks, for enforcing scheduling policies.
+ */
+ def addRunningTask(tid: Long) {
+ if (runningTasksSet.add(tid) && parent != null) {
+ parent.increaseRunningTasks(1)
+ }
+ runningTasks = runningTasksSet.size
+ }
+
+ /** If the given task ID is in the set of running tasks, removes it. */
+ def removeRunningTask(tid: Long) {
+ if (runningTasksSet.remove(tid) && parent != null) {
+ parent.decreaseRunningTasks(1)
+ }
+ runningTasks = runningTasksSet.size
+ }
+
+ private[scheduler] def removeAllRunningTasks() {
+ val numRunningTasks = runningTasksSet.size
+ runningTasksSet.clear()
+ if (parent != null) {
+ parent.decreaseRunningTasks(numRunningTasks)
+ }
+ runningTasks = 0
+ }
+
+ override def getSchedulableByName(name: String): Schedulable = {
+ return null
+ }
+
+ override def addSchedulable(schedulable: Schedulable) {}
+
+ override def removeSchedulable(schedulable: Schedulable) {}
+
+ override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = {
+ var sortedTaskSetQueue = ArrayBuffer[TaskSetManager](this)
+ sortedTaskSetQueue += this
+ return sortedTaskSetQueue
+ }
+
+ /** Called by TaskScheduler when an executor is lost so we can re-enqueue our tasks */
+ override def executorLost(execId: String, host: String) {
+ logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id)
+
+ // Re-enqueue pending tasks for this host based on the status of the cluster -- for example, a
+ // task that used to have locations on only this host might now go to the no-prefs list. Note
+ // that it's okay if we add a task to the same queue twice (if it had multiple preferred
+ // locations), because findTaskFromList will skip already-running tasks.
+ for (index <- getPendingTasksForExecutor(execId)) {
+ addPendingTask(index, readding=true)
+ }
+ for (index <- getPendingTasksForHost(host)) {
+ addPendingTask(index, readding=true)
+ }
+
+ // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage
+ if (tasks(0).isInstanceOf[ShuffleMapTask]) {
+ for ((tid, info) <- taskInfos if info.executorId == execId) {
+ val index = taskInfos(tid).index
+ if (successful(index)) {
+ successful(index) = false
+ copiesRunning(index) -= 1
+ tasksSuccessful -= 1
+ addPendingTask(index)
+ // Tell the DAGScheduler that this task was resubmitted so that it doesn't think our
+ // stage finishes when a total of tasks.size tasks finish.
+ sched.dagScheduler.taskEnded(tasks(index), Resubmitted, null, null, info, null)
+ }
+ }
+ }
+ // Also re-enqueue any tasks that were running on the node
+ for ((tid, info) <- taskInfos if info.running && info.executorId == execId) {
+ handleFailedTask(tid, TaskState.KILLED, None)
+ }
+ }
+
+ /**
+ * Check for tasks to be speculated and return true if there are any. This is called periodically
+ * by the TaskScheduler.
+ *
+ * TODO: To make this scale to large jobs, we need to maintain a list of running tasks, so that
+ * we don't scan the whole task set. It might also help to make this sorted by launch time.
+ */
+ override def checkSpeculatableTasks(): Boolean = {
+ // Can't speculate if we only have one task, or if all tasks have finished.
+ if (numTasks == 1 || tasksSuccessful == numTasks) {
+ return false
+ }
+ var foundTasks = false
+ val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt
+ logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation)
+ if (tasksSuccessful >= minFinishedForSpeculation && tasksSuccessful > 0) {
+ val time = clock.getTime()
+ val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray
+ Arrays.sort(durations)
+ val medianDuration = durations(min((0.5 * tasksSuccessful).round.toInt, durations.size - 1))
+ val threshold = max(SPECULATION_MULTIPLIER * medianDuration, 100)
+ // TODO: Threshold should also look at standard deviation of task durations and have a lower
+ // bound based on that.
+ logDebug("Task length threshold for speculation: " + threshold)
+ for ((tid, info) <- taskInfos) {
+ val index = info.index
+ if (!successful(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold &&
+ !speculatableTasks.contains(index)) {
+ logInfo(
+ "Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format(
+ taskSet.id, index, info.host, threshold))
+ speculatableTasks += index
+ foundTasks = true
+ }
+ }
+ }
+ return foundTasks
+ }
+
+ private def getLocalityWait(level: TaskLocality.TaskLocality): Long = {
+ val defaultWait = conf.get("spark.locality.wait", "3000")
+ level match {
+ case TaskLocality.PROCESS_LOCAL =>
+ conf.get("spark.locality.wait.process", defaultWait).toLong
+ case TaskLocality.NODE_LOCAL =>
+ conf.get("spark.locality.wait.node", defaultWait).toLong
+ case TaskLocality.RACK_LOCAL =>
+ conf.get("spark.locality.wait.rack", defaultWait).toLong
+ case TaskLocality.ANY =>
+ 0L
+ }
+ }
- def error(message: String)
+ /**
+ * Compute the locality levels used in this TaskSet. Assumes that all tasks have already been
+ * added to queues using addPendingTask.
+ */
+ private def computeValidLocalityLevels(): Array[TaskLocality.TaskLocality] = {
+ import TaskLocality.{PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY}
+ val levels = new ArrayBuffer[TaskLocality.TaskLocality]
+ if (!pendingTasksForExecutor.isEmpty && getLocalityWait(PROCESS_LOCAL) != 0) {
+ levels += PROCESS_LOCAL
+ }
+ if (!pendingTasksForHost.isEmpty && getLocalityWait(NODE_LOCAL) != 0) {
+ levels += NODE_LOCAL
+ }
+ if (!pendingTasksForRack.isEmpty && getLocalityWait(RACK_LOCAL) != 0) {
+ levels += RACK_LOCAL
+ }
+ levels += ANY
+ logDebug("Valid locality levels for " + taskSet + ": " + levels.mkString(", "))
+ levels.toArray
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/WorkerOffer.scala b/core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala
index 938f62883a..ba6bab3f91 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/WorkerOffer.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.scheduler.cluster
+package org.apache.spark.scheduler
/**
* Represents free resources available on an executor.
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
deleted file mode 100644
index 4c5eca8537..0000000000
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
+++ /dev/null
@@ -1,712 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.scheduler.cluster
-
-import java.io.NotSerializableException
-import java.util.Arrays
-
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
-import scala.collection.mutable.HashSet
-import scala.math.max
-import scala.math.min
-
-import org.apache.spark.{ExceptionFailure, FetchFailed, Logging, Resubmitted, SparkEnv,
- Success, TaskEndReason, TaskKilled, TaskResultLost, TaskState}
-import org.apache.spark.TaskState.TaskState
-import org.apache.spark.scheduler._
-import org.apache.spark.util.{SystemClock, Clock}
-
-
-/**
- * Schedules the tasks within a single TaskSet in the ClusterScheduler. This class keeps track of
- * the status of each task, retries tasks if they fail (up to a limited number of times), and
- * handles locality-aware scheduling for this TaskSet via delay scheduling. The main interfaces
- * to it are resourceOffer, which asks the TaskSet whether it wants to run a task on one node,
- * and statusUpdate, which tells it that one of its tasks changed state (e.g. finished).
- *
- * THREADING: This class is designed to only be called from code with a lock on the
- * ClusterScheduler (e.g. its event handlers). It should not be called from other threads.
- */
-private[spark] class ClusterTaskSetManager(
- sched: ClusterScheduler,
- val taskSet: TaskSet,
- clock: Clock = SystemClock)
- extends TaskSetManager
- with Logging
-{
- // CPUs to request per task
- val CPUS_PER_TASK = System.getProperty("spark.task.cpus", "1").toInt
-
- // Maximum times a task is allowed to fail before failing the job
- val MAX_TASK_FAILURES = System.getProperty("spark.task.maxFailures", "4").toInt
-
- // Quantile of tasks at which to start speculation
- val SPECULATION_QUANTILE = System.getProperty("spark.speculation.quantile", "0.75").toDouble
- val SPECULATION_MULTIPLIER = System.getProperty("spark.speculation.multiplier", "1.5").toDouble
-
- // Serializer for closures and tasks.
- val env = SparkEnv.get
- val ser = env.closureSerializer.newInstance()
-
- val tasks = taskSet.tasks
- val numTasks = tasks.length
- val copiesRunning = new Array[Int](numTasks)
- val successful = new Array[Boolean](numTasks)
- val numFailures = new Array[Int](numTasks)
- val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil)
- var tasksSuccessful = 0
-
- var weight = 1
- var minShare = 0
- var priority = taskSet.priority
- var stageId = taskSet.stageId
- var name = "TaskSet_"+taskSet.stageId.toString
- var parent: Pool = null
-
- var runningTasks = 0
- private val runningTasksSet = new HashSet[Long]
-
- // Set of pending tasks for each executor. These collections are actually
- // treated as stacks, in which new tasks are added to the end of the
- // ArrayBuffer and removed from the end. This makes it faster to detect
- // tasks that repeatedly fail because whenever a task failed, it is put
- // back at the head of the stack. They are also only cleaned up lazily;
- // when a task is launched, it remains in all the pending lists except
- // the one that it was launched from, but gets removed from them later.
- private val pendingTasksForExecutor = new HashMap[String, ArrayBuffer[Int]]
-
- // Set of pending tasks for each host. Similar to pendingTasksForExecutor,
- // but at host level.
- private val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]]
-
- // Set of pending tasks for each rack -- similar to the above.
- private val pendingTasksForRack = new HashMap[String, ArrayBuffer[Int]]
-
- // Set containing pending tasks with no locality preferences.
- val pendingTasksWithNoPrefs = new ArrayBuffer[Int]
-
- // Set containing all pending tasks (also used as a stack, as above).
- val allPendingTasks = new ArrayBuffer[Int]
-
- // Tasks that can be speculated. Since these will be a small fraction of total
- // tasks, we'll just hold them in a HashSet.
- val speculatableTasks = new HashSet[Int]
-
- // Task index, start and finish time for each task attempt (indexed by task ID)
- val taskInfos = new HashMap[Long, TaskInfo]
-
- // Did the TaskSet fail?
- var failed = false
- var causeOfFailure = ""
-
- // How frequently to reprint duplicate exceptions in full, in milliseconds
- val EXCEPTION_PRINT_INTERVAL =
- System.getProperty("spark.logging.exceptionPrintInterval", "10000").toLong
-
- // Map of recent exceptions (identified by string representation and top stack frame) to
- // duplicate count (how many times the same exception has appeared) and time the full exception
- // was printed. This should ideally be an LRU map that can drop old exceptions automatically.
- val recentExceptions = HashMap[String, (Int, Long)]()
-
- // Figure out the current map output tracker epoch and set it on all tasks
- val epoch = sched.mapOutputTracker.getEpoch
- logDebug("Epoch for " + taskSet + ": " + epoch)
- for (t <- tasks) {
- t.epoch = epoch
- }
-
- // Add all our tasks to the pending lists. We do this in reverse order
- // of task index so that tasks with low indices get launched first.
- for (i <- (0 until numTasks).reverse) {
- addPendingTask(i)
- }
-
- // Figure out which locality levels we have in our TaskSet, so we can do delay scheduling
- val myLocalityLevels = computeValidLocalityLevels()
- val localityWaits = myLocalityLevels.map(getLocalityWait) // Time to wait at each level
-
- // Delay scheduling variables: we keep track of our current locality level and the time we
- // last launched a task at that level, and move up a level when localityWaits[curLevel] expires.
- // We then move down if we manage to launch a "more local" task.
- var currentLocalityIndex = 0 // Index of our current locality level in validLocalityLevels
- var lastLaunchTime = clock.getTime() // Time we last launched a task at this level
-
- /**
- * Add a task to all the pending-task lists that it should be on. If readding is set, we are
- * re-adding the task so only include it in each list if it's not already there.
- */
- private def addPendingTask(index: Int, readding: Boolean = false) {
- // Utility method that adds `index` to a list only if readding=false or it's not already there
- def addTo(list: ArrayBuffer[Int]) {
- if (!readding || !list.contains(index)) {
- list += index
- }
- }
-
- var hadAliveLocations = false
- for (loc <- tasks(index).preferredLocations) {
- for (execId <- loc.executorId) {
- if (sched.isExecutorAlive(execId)) {
- addTo(pendingTasksForExecutor.getOrElseUpdate(execId, new ArrayBuffer))
- hadAliveLocations = true
- }
- }
- if (sched.hasExecutorsAliveOnHost(loc.host)) {
- addTo(pendingTasksForHost.getOrElseUpdate(loc.host, new ArrayBuffer))
- for (rack <- sched.getRackForHost(loc.host)) {
- addTo(pendingTasksForRack.getOrElseUpdate(rack, new ArrayBuffer))
- }
- hadAliveLocations = true
- }
- }
-
- if (!hadAliveLocations) {
- // Even though the task might've had preferred locations, all of those hosts or executors
- // are dead; put it in the no-prefs list so we can schedule it elsewhere right away.
- addTo(pendingTasksWithNoPrefs)
- }
-
- if (!readding) {
- allPendingTasks += index // No point scanning this whole list to find the old task there
- }
- }
-
- /**
- * Return the pending tasks list for a given executor ID, or an empty list if
- * there is no map entry for that host
- */
- private def getPendingTasksForExecutor(executorId: String): ArrayBuffer[Int] = {
- pendingTasksForExecutor.getOrElse(executorId, ArrayBuffer())
- }
-
- /**
- * Return the pending tasks list for a given host, or an empty list if
- * there is no map entry for that host
- */
- private def getPendingTasksForHost(host: String): ArrayBuffer[Int] = {
- pendingTasksForHost.getOrElse(host, ArrayBuffer())
- }
-
- /**
- * Return the pending rack-local task list for a given rack, or an empty list if
- * there is no map entry for that rack
- */
- private def getPendingTasksForRack(rack: String): ArrayBuffer[Int] = {
- pendingTasksForRack.getOrElse(rack, ArrayBuffer())
- }
-
- /**
- * Dequeue a pending task from the given list and return its index.
- * Return None if the list is empty.
- * This method also cleans up any tasks in the list that have already
- * been launched, since we want that to happen lazily.
- */
- private def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = {
- while (!list.isEmpty) {
- val index = list.last
- list.trimEnd(1)
- if (copiesRunning(index) == 0 && !successful(index)) {
- return Some(index)
- }
- }
- return None
- }
-
- /** Check whether a task is currently running an attempt on a given host */
- private def hasAttemptOnHost(taskIndex: Int, host: String): Boolean = {
- !taskAttempts(taskIndex).exists(_.host == host)
- }
-
- /**
- * Return a speculative task for a given executor if any are available. The task should not have
- * an attempt running on this host, in case the host is slow. In addition, the task should meet
- * the given locality constraint.
- */
- private def findSpeculativeTask(execId: String, host: String, locality: TaskLocality.Value)
- : Option[(Int, TaskLocality.Value)] =
- {
- speculatableTasks.retain(index => !successful(index)) // Remove finished tasks from set
-
- if (!speculatableTasks.isEmpty) {
- // Check for process-local or preference-less tasks; note that tasks can be process-local
- // on multiple nodes when we replicate cached blocks, as in Spark Streaming
- for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) {
- val prefs = tasks(index).preferredLocations
- val executors = prefs.flatMap(_.executorId)
- if (prefs.size == 0 || executors.contains(execId)) {
- speculatableTasks -= index
- return Some((index, TaskLocality.PROCESS_LOCAL))
- }
- }
-
- // Check for node-local tasks
- if (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) {
- for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) {
- val locations = tasks(index).preferredLocations.map(_.host)
- if (locations.contains(host)) {
- speculatableTasks -= index
- return Some((index, TaskLocality.NODE_LOCAL))
- }
- }
- }
-
- // Check for rack-local tasks
- if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
- for (rack <- sched.getRackForHost(host)) {
- for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) {
- val racks = tasks(index).preferredLocations.map(_.host).map(sched.getRackForHost)
- if (racks.contains(rack)) {
- speculatableTasks -= index
- return Some((index, TaskLocality.RACK_LOCAL))
- }
- }
- }
- }
-
- // Check for non-local tasks
- if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
- for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) {
- speculatableTasks -= index
- return Some((index, TaskLocality.ANY))
- }
- }
- }
-
- return None
- }
-
- /**
- * Dequeue a pending task for a given node and return its index and locality level.
- * Only search for tasks matching the given locality constraint.
- */
- private def findTask(execId: String, host: String, locality: TaskLocality.Value)
- : Option[(Int, TaskLocality.Value)] =
- {
- for (index <- findTaskFromList(getPendingTasksForExecutor(execId))) {
- return Some((index, TaskLocality.PROCESS_LOCAL))
- }
-
- if (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) {
- for (index <- findTaskFromList(getPendingTasksForHost(host))) {
- return Some((index, TaskLocality.NODE_LOCAL))
- }
- }
-
- if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
- for {
- rack <- sched.getRackForHost(host)
- index <- findTaskFromList(getPendingTasksForRack(rack))
- } {
- return Some((index, TaskLocality.RACK_LOCAL))
- }
- }
-
- // Look for no-pref tasks after rack-local tasks since they can run anywhere.
- for (index <- findTaskFromList(pendingTasksWithNoPrefs)) {
- return Some((index, TaskLocality.PROCESS_LOCAL))
- }
-
- if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
- for (index <- findTaskFromList(allPendingTasks)) {
- return Some((index, TaskLocality.ANY))
- }
- }
-
- // Finally, if all else has failed, find a speculative task
- return findSpeculativeTask(execId, host, locality)
- }
-
- /**
- * Respond to an offer of a single executor from the scheduler by finding a task
- */
- override def resourceOffer(
- execId: String,
- host: String,
- availableCpus: Int,
- maxLocality: TaskLocality.TaskLocality)
- : Option[TaskDescription] =
- {
- if (tasksSuccessful < numTasks && availableCpus >= CPUS_PER_TASK) {
- val curTime = clock.getTime()
-
- var allowedLocality = getAllowedLocalityLevel(curTime)
- if (allowedLocality > maxLocality) {
- allowedLocality = maxLocality // We're not allowed to search for farther-away tasks
- }
-
- findTask(execId, host, allowedLocality) match {
- case Some((index, taskLocality)) => {
- // Found a task; do some bookkeeping and return a task description
- val task = tasks(index)
- val taskId = sched.newTaskId()
- // Figure out whether this should count as a preferred launch
- logInfo("Starting task %s:%d as TID %s on executor %s: %s (%s)".format(
- taskSet.id, index, taskId, execId, host, taskLocality))
- // Do various bookkeeping
- copiesRunning(index) += 1
- val info = new TaskInfo(taskId, index, curTime, execId, host, taskLocality)
- taskInfos(taskId) = info
- taskAttempts(index) = info :: taskAttempts(index)
- // Update our locality level for delay scheduling
- currentLocalityIndex = getLocalityIndex(taskLocality)
- lastLaunchTime = curTime
- // Serialize and return the task
- val startTime = clock.getTime()
- // We rely on the DAGScheduler to catch non-serializable closures and RDDs, so in here
- // we assume the task can be serialized without exceptions.
- val serializedTask = Task.serializeWithDependencies(
- task, sched.sc.addedFiles, sched.sc.addedJars, ser)
- val timeTaken = clock.getTime() - startTime
- addRunningTask(taskId)
- logInfo("Serialized task %s:%d as %d bytes in %d ms".format(
- taskSet.id, index, serializedTask.limit, timeTaken))
- val taskName = "task %s:%d".format(taskSet.id, index)
- if (taskAttempts(index).size == 1)
- taskStarted(task,info)
- return Some(new TaskDescription(taskId, execId, taskName, index, serializedTask))
- }
- case _ =>
- }
- }
- return None
- }
-
- /**
- * Get the level we can launch tasks according to delay scheduling, based on current wait time.
- */
- private def getAllowedLocalityLevel(curTime: Long): TaskLocality.TaskLocality = {
- while (curTime - lastLaunchTime >= localityWaits(currentLocalityIndex) &&
- currentLocalityIndex < myLocalityLevels.length - 1)
- {
- // Jump to the next locality level, and remove our waiting time for the current one since
- // we don't want to count it again on the next one
- lastLaunchTime += localityWaits(currentLocalityIndex)
- currentLocalityIndex += 1
- }
- myLocalityLevels(currentLocalityIndex)
- }
-
- /**
- * Find the index in myLocalityLevels for a given locality. This is also designed to work with
- * localities that are not in myLocalityLevels (in case we somehow get those) by returning the
- * next-biggest level we have. Uses the fact that the last value in myLocalityLevels is ANY.
- */
- def getLocalityIndex(locality: TaskLocality.TaskLocality): Int = {
- var index = 0
- while (locality > myLocalityLevels(index)) {
- index += 1
- }
- index
- }
-
- private def taskStarted(task: Task[_], info: TaskInfo) {
- sched.dagScheduler.taskStarted(task, info)
- }
-
- def handleTaskGettingResult(tid: Long) = {
- val info = taskInfos(tid)
- info.markGettingResult()
- sched.dagScheduler.taskGettingResult(tasks(info.index), info)
- }
-
- /**
- * Marks the task as successful and notifies the DAGScheduler that a task has ended.
- */
- def handleSuccessfulTask(tid: Long, result: DirectTaskResult[_]) = {
- val info = taskInfos(tid)
- val index = info.index
- info.markSuccessful()
- removeRunningTask(tid)
- if (!successful(index)) {
- logInfo("Finished TID %s in %d ms on %s (progress: %d/%d)".format(
- tid, info.duration, info.host, tasksSuccessful, numTasks))
- sched.dagScheduler.taskEnded(
- tasks(index), Success, result.value, result.accumUpdates, info, result.metrics)
-
- // Mark successful and stop if all the tasks have succeeded.
- tasksSuccessful += 1
- successful(index) = true
- if (tasksSuccessful == numTasks) {
- sched.taskSetFinished(this)
- }
- } else {
- logInfo("Ignorning task-finished event for TID " + tid + " because task " +
- index + " has already completed successfully")
- }
- }
-
- /**
- * Marks the task as failed, re-adds it to the list of pending tasks, and notifies the
- * DAG Scheduler.
- */
- def handleFailedTask(tid: Long, state: TaskState, reason: Option[TaskEndReason]) {
- val info = taskInfos(tid)
- if (info.failed) {
- return
- }
- removeRunningTask(tid)
- val index = info.index
- info.markFailed()
- if (!successful(index)) {
- logWarning("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index))
- copiesRunning(index) -= 1
- // Check if the problem is a map output fetch failure. In that case, this
- // task will never succeed on any node, so tell the scheduler about it.
- reason.foreach {
- case fetchFailed: FetchFailed =>
- logWarning("Loss was due to fetch failure from " + fetchFailed.bmAddress)
- sched.dagScheduler.taskEnded(tasks(index), fetchFailed, null, null, info, null)
- successful(index) = true
- tasksSuccessful += 1
- sched.taskSetFinished(this)
- removeAllRunningTasks()
- return
-
- case TaskKilled =>
- logWarning("Task %d was killed.".format(tid))
- sched.dagScheduler.taskEnded(tasks(index), reason.get, null, null, info, null)
- return
-
- case ef: ExceptionFailure =>
- sched.dagScheduler.taskEnded(tasks(index), ef, null, null, info, ef.metrics.getOrElse(null))
- if (ef.className == classOf[NotSerializableException].getName()) {
- // If the task result wasn't serializable, there's no point in trying to re-execute it.
- logError("Task %s:%s had a not serializable result: %s; not retrying".format(
- taskSet.id, index, ef.description))
- abort("Task %s:%s had a not serializable result: %s".format(
- taskSet.id, index, ef.description))
- return
- }
- val key = ef.description
- val now = clock.getTime()
- val (printFull, dupCount) = {
- if (recentExceptions.contains(key)) {
- val (dupCount, printTime) = recentExceptions(key)
- if (now - printTime > EXCEPTION_PRINT_INTERVAL) {
- recentExceptions(key) = (0, now)
- (true, 0)
- } else {
- recentExceptions(key) = (dupCount + 1, printTime)
- (false, dupCount + 1)
- }
- } else {
- recentExceptions(key) = (0, now)
- (true, 0)
- }
- }
- if (printFull) {
- val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString))
- logWarning("Loss was due to %s\n%s\n%s".format(
- ef.className, ef.description, locs.mkString("\n")))
- } else {
- logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount))
- }
-
- case TaskResultLost =>
- logWarning("Lost result for TID %s on host %s".format(tid, info.host))
- sched.dagScheduler.taskEnded(tasks(index), TaskResultLost, null, null, info, null)
-
- case _ => {}
- }
- // On non-fetch failures, re-enqueue the task as pending for a max number of retries
- addPendingTask(index)
- if (state != TaskState.KILLED) {
- numFailures(index) += 1
- if (numFailures(index) > MAX_TASK_FAILURES) {
- logError("Task %s:%d failed more than %d times; aborting job".format(
- taskSet.id, index, MAX_TASK_FAILURES))
- abort("Task %s:%d failed more than %d times".format(taskSet.id, index, MAX_TASK_FAILURES))
- }
- }
- } else {
- logInfo("Ignoring task-lost event for TID " + tid +
- " because task " + index + " is already finished")
- }
- }
-
- override def error(message: String) {
- // Save the error message
- abort("Error: " + message)
- }
-
- def abort(message: String) {
- failed = true
- causeOfFailure = message
- // TODO: Kill running tasks if we were not terminated due to a Mesos error
- sched.dagScheduler.taskSetFailed(taskSet, message)
- removeAllRunningTasks()
- sched.taskSetFinished(this)
- }
-
- /** If the given task ID is not in the set of running tasks, adds it.
- *
- * Used to keep track of the number of running tasks, for enforcing scheduling policies.
- */
- def addRunningTask(tid: Long) {
- if (runningTasksSet.add(tid) && parent != null) {
- parent.increaseRunningTasks(1)
- }
- runningTasks = runningTasksSet.size
- }
-
- /** If the given task ID is in the set of running tasks, removes it. */
- def removeRunningTask(tid: Long) {
- if (runningTasksSet.remove(tid) && parent != null) {
- parent.decreaseRunningTasks(1)
- }
- runningTasks = runningTasksSet.size
- }
-
- private def removeAllRunningTasks() {
- val numRunningTasks = runningTasksSet.size
- runningTasksSet.clear()
- if (parent != null) {
- parent.decreaseRunningTasks(numRunningTasks)
- }
- runningTasks = 0
- }
-
- override def getSchedulableByName(name: String): Schedulable = {
- return null
- }
-
- override def addSchedulable(schedulable: Schedulable) {}
-
- override def removeSchedulable(schedulable: Schedulable) {}
-
- override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = {
- var sortedTaskSetQueue = ArrayBuffer[TaskSetManager](this)
- sortedTaskSetQueue += this
- return sortedTaskSetQueue
- }
-
- /** Called by cluster scheduler when an executor is lost so we can re-enqueue our tasks */
- override def executorLost(execId: String, host: String) {
- logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id)
-
- // Re-enqueue pending tasks for this host based on the status of the cluster -- for example, a
- // task that used to have locations on only this host might now go to the no-prefs list. Note
- // that it's okay if we add a task to the same queue twice (if it had multiple preferred
- // locations), because findTaskFromList will skip already-running tasks.
- for (index <- getPendingTasksForExecutor(execId)) {
- addPendingTask(index, readding=true)
- }
- for (index <- getPendingTasksForHost(host)) {
- addPendingTask(index, readding=true)
- }
-
- // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage
- if (tasks(0).isInstanceOf[ShuffleMapTask]) {
- for ((tid, info) <- taskInfos if info.executorId == execId) {
- val index = taskInfos(tid).index
- if (successful(index)) {
- successful(index) = false
- copiesRunning(index) -= 1
- tasksSuccessful -= 1
- addPendingTask(index)
- // Tell the DAGScheduler that this task was resubmitted so that it doesn't think our
- // stage finishes when a total of tasks.size tasks finish.
- sched.dagScheduler.taskEnded(tasks(index), Resubmitted, null, null, info, null)
- }
- }
- }
- // Also re-enqueue any tasks that were running on the node
- for ((tid, info) <- taskInfos if info.running && info.executorId == execId) {
- handleFailedTask(tid, TaskState.KILLED, None)
- }
- }
-
- /**
- * Check for tasks to be speculated and return true if there are any. This is called periodically
- * by the ClusterScheduler.
- *
- * TODO: To make this scale to large jobs, we need to maintain a list of running tasks, so that
- * we don't scan the whole task set. It might also help to make this sorted by launch time.
- */
- override def checkSpeculatableTasks(): Boolean = {
- // Can't speculate if we only have one task, or if all tasks have finished.
- if (numTasks == 1 || tasksSuccessful == numTasks) {
- return false
- }
- var foundTasks = false
- val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt
- logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation)
- if (tasksSuccessful >= minFinishedForSpeculation && tasksSuccessful > 0) {
- val time = clock.getTime()
- val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray
- Arrays.sort(durations)
- val medianDuration = durations(min((0.5 * tasksSuccessful).round.toInt, durations.size - 1))
- val threshold = max(SPECULATION_MULTIPLIER * medianDuration, 100)
- // TODO: Threshold should also look at standard deviation of task durations and have a lower
- // bound based on that.
- logDebug("Task length threshold for speculation: " + threshold)
- for ((tid, info) <- taskInfos) {
- val index = info.index
- if (!successful(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold &&
- !speculatableTasks.contains(index)) {
- logInfo(
- "Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format(
- taskSet.id, index, info.host, threshold))
- speculatableTasks += index
- foundTasks = true
- }
- }
- }
- return foundTasks
- }
-
- override def hasPendingTasks(): Boolean = {
- numTasks > 0 && tasksSuccessful < numTasks
- }
-
- private def getLocalityWait(level: TaskLocality.TaskLocality): Long = {
- val defaultWait = System.getProperty("spark.locality.wait", "3000")
- level match {
- case TaskLocality.PROCESS_LOCAL =>
- System.getProperty("spark.locality.wait.process", defaultWait).toLong
- case TaskLocality.NODE_LOCAL =>
- System.getProperty("spark.locality.wait.node", defaultWait).toLong
- case TaskLocality.RACK_LOCAL =>
- System.getProperty("spark.locality.wait.rack", defaultWait).toLong
- case TaskLocality.ANY =>
- 0L
- }
- }
-
- /**
- * Compute the locality levels used in this TaskSet. Assumes that all tasks have already been
- * added to queues using addPendingTask.
- */
- private def computeValidLocalityLevels(): Array[TaskLocality.TaskLocality] = {
- import TaskLocality.{PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY}
- val levels = new ArrayBuffer[TaskLocality.TaskLocality]
- if (!pendingTasksForExecutor.isEmpty && getLocalityWait(PROCESS_LOCAL) != 0) {
- levels += PROCESS_LOCAL
- }
- if (!pendingTasksForHost.isEmpty && getLocalityWait(NODE_LOCAL) != 0) {
- levels += NODE_LOCAL
- }
- if (!pendingTasksForRack.isEmpty && getLocalityWait(RACK_LOCAL) != 0) {
- levels += RACK_LOCAL
- }
- levels += ANY
- logDebug("Valid locality levels for " + taskSet + ": " + levels.mkString(", "))
- levels.toArray
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index d0ba5bf55d..2f5bcafe40 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -20,18 +20,19 @@ package org.apache.spark.scheduler.cluster
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
+import scala.concurrent.Await
+import scala.concurrent.duration._
import akka.actor._
-import akka.dispatch.Await
import akka.pattern.ask
-import akka.remote.{RemoteClientShutdown, RemoteClientDisconnected, RemoteClientLifeCycleEvent}
-import akka.util.Duration
-import akka.util.duration._
+import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent}
import org.apache.spark.{SparkException, Logging, TaskState}
-import org.apache.spark.scheduler.TaskDescription
+import org.apache.spark.{Logging, SparkException, TaskState}
+import org.apache.spark.scheduler.{TaskSchedulerImpl, SchedulerBackend, SlaveLost, TaskDescription,
+ WorkerOffer}
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{AkkaUtils, Utils}
/**
* A scheduler backend that waits for coarse grained executors to connect to it through Akka.
@@ -42,26 +43,28 @@ import org.apache.spark.util.Utils
* (spark.deploy.*).
*/
private[spark]
-class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: ActorSystem)
+class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: ActorSystem)
extends SchedulerBackend with Logging
{
// Use an atomic variable to track total number of cores in the cluster for simplicity and speed
var totalCoreCount = new AtomicInteger(0)
+ val conf = scheduler.sc.conf
+ private val timeout = AkkaUtils.askTimeout(conf)
class DriverActor(sparkProperties: Seq[(String, String)]) extends Actor {
private val executorActor = new HashMap[String, ActorRef]
private val executorAddress = new HashMap[String, Address]
private val executorHost = new HashMap[String, String]
private val freeCores = new HashMap[String, Int]
- private val actorToExecutorId = new HashMap[ActorRef, String]
private val addressToExecutorId = new HashMap[Address, String]
override def preStart() {
// Listen for remote client disconnection events, since they don't go through Akka's watch()
- context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
+ context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
// Periodically revive offers to allow delay scheduling to work
- val reviveInterval = System.getProperty("spark.scheduler.revive.interval", "1000").toLong
+ val reviveInterval = conf.get("spark.scheduler.revive.interval", "1000").toLong
+ import context.dispatcher
context.system.scheduler.schedule(0.millis, reviveInterval.millis, self, ReviveOffers)
}
@@ -73,12 +76,10 @@ class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Ac
} else {
logInfo("Registered executor: " + sender + " with ID " + executorId)
sender ! RegisteredExecutor(sparkProperties)
- context.watch(sender)
executorActor(executorId) = sender
executorHost(executorId) = Utils.parseHostPort(hostPort)._1
freeCores(executorId) = cores
executorAddress(executorId) = sender.path.address
- actorToExecutorId(sender) = executorId
addressToExecutorId(sender.path.address) = executorId
totalCoreCount.addAndGet(cores)
makeOffers()
@@ -118,14 +119,9 @@ class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Ac
removeExecutor(executorId, reason)
sender ! true
- case Terminated(actor) =>
- actorToExecutorId.get(actor).foreach(removeExecutor(_, "Akka actor terminated"))
+ case DisassociatedEvent(_, address, _) =>
+ addressToExecutorId.get(address).foreach(removeExecutor(_, "remote Akka client disassociated"))
- case RemoteClientDisconnected(transport, address) =>
- addressToExecutorId.get(address).foreach(removeExecutor(_, "remote Akka client disconnected"))
-
- case RemoteClientShutdown(transport, address) =>
- addressToExecutorId.get(address).foreach(removeExecutor(_, "remote Akka client shutdown"))
}
// Make fake resource offers on all executors
@@ -153,7 +149,6 @@ class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Ac
if (executorActor.contains(executorId)) {
logInfo("Executor " + executorId + " disconnected, so removing it")
val numCores = freeCores(executorId)
- actorToExecutorId -= executorActor(executorId)
addressToExecutorId -= executorAddress(executorId)
executorActor -= executorId
executorHost -= executorId
@@ -169,22 +164,16 @@ class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Ac
override def start() {
val properties = new ArrayBuffer[(String, String)]
- val iterator = System.getProperties.entrySet.iterator
- while (iterator.hasNext) {
- val entry = iterator.next
- val (key, value) = (entry.getKey.toString, entry.getValue.toString)
+ for ((key, value) <- scheduler.sc.conf.getAll) {
if (key.startsWith("spark.") && !key.equals("spark.hostPort")) {
properties += ((key, value))
}
}
+ //TODO (prashant) send conf instead of properties
driverActor = actorSystem.actorOf(
Props(new DriverActor(properties)), name = CoarseGrainedSchedulerBackend.ACTOR_NAME)
}
- private val timeout = {
- Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
- }
-
def stopExecutors() {
try {
if (driverActor != null) {
@@ -219,8 +208,10 @@ class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Ac
driverActor ! KillTask(taskId, executorId)
}
- override def defaultParallelism() = Option(System.getProperty("spark.default.parallelism"))
- .map(_.toInt).getOrElse(math.max(totalCoreCount.get(), 2))
+ override def defaultParallelism(): Int = {
+ conf.getOption("spark.default.parallelism").map(_.toInt).getOrElse(
+ math.max(totalCoreCount.get(), 2))
+ }
// Called by subclasses when notified of a lost worker
def removeExecutor(executorId: String, reason: String) {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
index e000531a26..b44d1e43c8 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
@@ -19,10 +19,12 @@ package org.apache.spark.scheduler.cluster
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{Path, FileSystem}
+
import org.apache.spark.{Logging, SparkContext}
+import org.apache.spark.scheduler.TaskSchedulerImpl
private[spark] class SimrSchedulerBackend(
- scheduler: ClusterScheduler,
+ scheduler: TaskSchedulerImpl,
sc: SparkContext,
driverFilePath: String)
extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem)
@@ -31,13 +33,13 @@ private[spark] class SimrSchedulerBackend(
val tmpPath = new Path(driverFilePath + "_tmp")
val filePath = new Path(driverFilePath)
- val maxCores = System.getProperty("spark.simr.executor.cores", "1").toInt
+ val maxCores = conf.get("spark.simr.executor.cores", "1").toInt
override def start() {
super.start()
- val driverUrl = "akka://spark@%s:%s/user/%s".format(
- System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"),
+ val driverUrl = "akka.tcp://spark@%s:%s/user/%s".format(
+ sc.conf.get("spark.driver.host"), sc.conf.get("spark.driver.port"),
CoarseGrainedSchedulerBackend.ACTOR_NAME)
val conf = new Configuration()
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
index cefa970bb9..73fc37444e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
@@ -17,14 +17,16 @@
package org.apache.spark.scheduler.cluster
+import scala.collection.mutable.HashMap
+
import org.apache.spark.{Logging, SparkContext}
import org.apache.spark.deploy.client.{Client, ClientListener}
import org.apache.spark.deploy.{Command, ApplicationDescription}
-import scala.collection.mutable.HashMap
+import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason, SlaveLost, TaskSchedulerImpl}
import org.apache.spark.util.Utils
private[spark] class SparkDeploySchedulerBackend(
- scheduler: ClusterScheduler,
+ scheduler: TaskSchedulerImpl,
sc: SparkContext,
masters: Array[String],
appName: String)
@@ -36,23 +38,23 @@ private[spark] class SparkDeploySchedulerBackend(
var stopping = false
var shutdownCallback : (SparkDeploySchedulerBackend) => Unit = _
- val maxCores = System.getProperty("spark.cores.max", Int.MaxValue.toString).toInt
+ val maxCores = conf.getOption("spark.cores.max").map(_.toInt)
override def start() {
super.start()
// The endpoint for executors to talk to us
- val driverUrl = "akka://spark@%s:%s/user/%s".format(
- System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"),
+ val driverUrl = "akka.tcp://spark@%s:%s/user/%s".format(
+ conf.get("spark.driver.host"), conf.get("spark.driver.port"),
CoarseGrainedSchedulerBackend.ACTOR_NAME)
val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}")
val command = Command(
"org.apache.spark.executor.CoarseGrainedExecutorBackend", args, sc.executorEnvs)
val sparkHome = sc.getSparkHome().getOrElse(null)
- val appDesc = new ApplicationDescription(appName, maxCores, executorMemory, command, sparkHome,
+ val appDesc = new ApplicationDescription(appName, maxCores, sc.executorMemory, command, sparkHome,
"http://" + sc.ui.appUIAddress)
- client = new Client(sc.env.actorSystem, masters, appDesc, this)
+ client = new Client(sc.env.actorSystem, masters, appDesc, this, conf)
client.start()
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
index cd521e0f2b..d46fceba89 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
@@ -30,7 +30,8 @@ import org.apache.mesos._
import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _}
import org.apache.spark.{SparkException, Logging, SparkContext, TaskState}
-import org.apache.spark.scheduler.cluster.{ClusterScheduler, CoarseGrainedSchedulerBackend}
+import org.apache.spark.scheduler.TaskSchedulerImpl
+import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend
/**
* A SchedulerBackend that runs tasks on Mesos, but uses "coarse-grained" tasks, where it holds
@@ -43,7 +44,7 @@ import org.apache.spark.scheduler.cluster.{ClusterScheduler, CoarseGrainedSchedu
* remove this.
*/
private[spark] class CoarseMesosSchedulerBackend(
- scheduler: ClusterScheduler,
+ scheduler: TaskSchedulerImpl,
sc: SparkContext,
master: String,
appName: String)
@@ -61,7 +62,7 @@ private[spark] class CoarseMesosSchedulerBackend(
var driver: SchedulerDriver = null
// Maximum number of cores to acquire (TODO: we'll need more flexible controls here)
- val maxCores = System.getProperty("spark.cores.max", Int.MaxValue.toString).toInt
+ val maxCores = conf.get("spark.cores.max", Int.MaxValue.toString).toInt
// Cores we have acquired with each Mesos task ID
val coresByTaskId = new HashMap[Int, Int]
@@ -76,7 +77,7 @@ private[spark] class CoarseMesosSchedulerBackend(
"Spark home is not set; set it through the spark.home system " +
"property, the SPARK_HOME environment variable or the SparkContext constructor"))
- val extraCoresPerSlave = System.getProperty("spark.mesos.extra.cores", "0").toInt
+ val extraCoresPerSlave = conf.get("spark.mesos.extra.cores", "0").toInt
var nextMesosTaskId = 0
@@ -120,13 +121,13 @@ private[spark] class CoarseMesosSchedulerBackend(
}
val command = CommandInfo.newBuilder()
.setEnvironment(environment)
- val driverUrl = "akka://spark@%s:%s/user/%s".format(
- System.getProperty("spark.driver.host"),
- System.getProperty("spark.driver.port"),
+ val driverUrl = "akka.tcp://spark@%s:%s/user/%s".format(
+ conf.get("spark.driver.host"),
+ conf.get("spark.driver.port"),
CoarseGrainedSchedulerBackend.ACTOR_NAME)
- val uri = System.getProperty("spark.executor.uri")
+ val uri = conf.get("spark.executor.uri", null)
if (uri == null) {
- val runScript = new File(sparkHome, "spark-class").getCanonicalPath
+ val runScript = new File(sparkHome, "./bin/spark-class").getCanonicalPath
command.setValue(
"\"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d".format(
runScript, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores))
@@ -135,7 +136,7 @@ private[spark] class CoarseMesosSchedulerBackend(
// glob the directory "correctly".
val basename = uri.split('/').last.split('.').head
command.setValue(
- "cd %s*; ./spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d"
+ "cd %s*; ./bin/spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d"
.format(basename, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores))
command.addUris(CommandInfo.URI.newBuilder().setValue(uri))
}
@@ -176,7 +177,7 @@ private[spark] class CoarseMesosSchedulerBackend(
val slaveId = offer.getSlaveId.toString
val mem = getResource(offer.getResourcesList, "mem")
val cpus = getResource(offer.getResourcesList, "cpus").toInt
- if (totalCoresAcquired < maxCores && mem >= executorMemory && cpus >= 1 &&
+ if (totalCoresAcquired < maxCores && mem >= sc.executorMemory && cpus >= 1 &&
failuresBySlaveId.getOrElse(slaveId, 0) < MAX_SLAVE_FAILURES &&
!slaveIdsWithExecutors.contains(slaveId)) {
// Launch an executor on the slave
@@ -192,7 +193,7 @@ private[spark] class CoarseMesosSchedulerBackend(
.setCommand(createCommand(offer, cpusToUse + extraCoresPerSlave))
.setName("Task " + taskId)
.addResources(createResource("cpus", cpusToUse))
- .addResources(createResource("mem", executorMemory))
+ .addResources(createResource("mem", sc.executorMemory))
.build()
d.launchTasks(offer.getId, Collections.singletonList(task), filters)
} else {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
index 50cbc2ca92..ae8d527352 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
@@ -30,9 +30,8 @@ import org.apache.mesos._
import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _}
import org.apache.spark.{Logging, SparkException, SparkContext, TaskState}
-import org.apache.spark.scheduler.TaskDescription
-import org.apache.spark.scheduler.cluster.{ClusterScheduler, ExecutorExited, ExecutorLossReason}
-import org.apache.spark.scheduler.cluster.{SchedulerBackend, SlaveLost, WorkerOffer}
+import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason, SchedulerBackend, SlaveLost,
+ TaskDescription, TaskSchedulerImpl, WorkerOffer}
import org.apache.spark.util.Utils
/**
@@ -41,7 +40,7 @@ import org.apache.spark.util.Utils
* from multiple apps can run on different cores) and in time (a core can switch ownership).
*/
private[spark] class MesosSchedulerBackend(
- scheduler: ClusterScheduler,
+ scheduler: TaskSchedulerImpl,
sc: SparkContext,
master: String,
appName: String)
@@ -101,20 +100,20 @@ private[spark] class MesosSchedulerBackend(
}
val command = CommandInfo.newBuilder()
.setEnvironment(environment)
- val uri = System.getProperty("spark.executor.uri")
+ val uri = sc.conf.get("spark.executor.uri", null)
if (uri == null) {
- command.setValue(new File(sparkHome, "spark-executor").getCanonicalPath)
+ command.setValue(new File(sparkHome, "/sbin/spark-executor").getCanonicalPath)
} else {
// Grab everything to the first '.'. We'll use that and '*' to
// glob the directory "correctly".
val basename = uri.split('/').last.split('.').head
- command.setValue("cd %s*; ./spark-executor".format(basename))
+ command.setValue("cd %s*; ./sbin/spark-executor".format(basename))
command.addUris(CommandInfo.URI.newBuilder().setValue(uri))
}
val memory = Resource.newBuilder()
.setName("mem")
.setType(Value.Type.SCALAR)
- .setScalar(Value.Scalar.newBuilder().setValue(executorMemory).build())
+ .setScalar(Value.Scalar.newBuilder().setValue(sc.executorMemory).build())
.build()
ExecutorInfo.newBuilder()
.setExecutorId(ExecutorID.newBuilder().setValue(execId).build())
@@ -199,7 +198,7 @@ private[spark] class MesosSchedulerBackend(
def enoughMemory(o: Offer) = {
val mem = getResource(o.getResourcesList, "mem")
val slaveId = o.getSlaveId.getValue
- mem >= executorMemory || slaveIdsWithExecutors.contains(slaveId)
+ mem >= sc.executorMemory || slaveIdsWithExecutors.contains(slaveId)
}
for ((offer, index) <- offers.zipWithIndex if enoughMemory(offer)) {
@@ -341,5 +340,5 @@ private[spark] class MesosSchedulerBackend(
}
// TODO: query Mesos for number of cores
- override def defaultParallelism() = System.getProperty("spark.default.parallelism", "8").toInt
+ override def defaultParallelism() = sc.conf.get("spark.default.parallelism", "8").toInt
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
new file mode 100644
index 0000000000..897d47a9ad
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.local
+
+import java.nio.ByteBuffer
+
+import akka.actor.{Actor, ActorRef, Props}
+
+import org.apache.spark.{Logging, SparkContext, SparkEnv, TaskState}
+import org.apache.spark.TaskState.TaskState
+import org.apache.spark.executor.{Executor, ExecutorBackend}
+import org.apache.spark.scheduler.{SchedulerBackend, TaskSchedulerImpl, WorkerOffer}
+
+private case class ReviveOffers()
+
+private case class StatusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer)
+
+private case class KillTask(taskId: Long)
+
+/**
+ * Calls to LocalBackend are all serialized through LocalActor. Using an actor makes the calls on
+ * LocalBackend asynchronous, which is necessary to prevent deadlock between LocalBackend
+ * and the ClusterScheduler.
+ */
+private[spark] class LocalActor(
+ scheduler: TaskSchedulerImpl,
+ executorBackend: LocalBackend,
+ private val totalCores: Int) extends Actor with Logging {
+
+ private var freeCores = totalCores
+
+ private val localExecutorId = "localhost"
+ private val localExecutorHostname = "localhost"
+
+ val executor = new Executor(
+ localExecutorId, localExecutorHostname, scheduler.conf.getAll, isLocal = true)
+
+ def receive = {
+ case ReviveOffers =>
+ reviveOffers()
+
+ case StatusUpdate(taskId, state, serializedData) =>
+ scheduler.statusUpdate(taskId, state, serializedData)
+ if (TaskState.isFinished(state)) {
+ freeCores += 1
+ reviveOffers()
+ }
+
+ case KillTask(taskId) =>
+ executor.killTask(taskId)
+ }
+
+ def reviveOffers() {
+ val offers = Seq(new WorkerOffer(localExecutorId, localExecutorHostname, freeCores))
+ for (task <- scheduler.resourceOffers(offers).flatten) {
+ freeCores -= 1
+ executor.launchTask(executorBackend, task.taskId, task.serializedTask)
+ }
+ }
+}
+
+/**
+ * LocalBackend is used when running a local version of Spark where the executor, backend, and
+ * master all run in the same JVM. It sits behind a ClusterScheduler and handles launching tasks
+ * on a single Executor (created by the LocalBackend) running locally.
+ */
+private[spark] class LocalBackend(scheduler: TaskSchedulerImpl, val totalCores: Int)
+ extends SchedulerBackend with ExecutorBackend {
+
+ var localActor: ActorRef = null
+
+ override def start() {
+ localActor = SparkEnv.get.actorSystem.actorOf(
+ Props(new LocalActor(scheduler, this, totalCores)),
+ "LocalBackendActor")
+ }
+
+ override def stop() {
+ }
+
+ override def reviveOffers() {
+ localActor ! ReviveOffers
+ }
+
+ override def defaultParallelism() = totalCores
+
+ override def killTask(taskId: Long, executorId: String) {
+ localActor ! KillTask(taskId)
+ }
+
+ override def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) {
+ localActor ! StatusUpdate(taskId, state, serializedData)
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala
deleted file mode 100644
index 2699f0b33e..0000000000
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala
+++ /dev/null
@@ -1,219 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.scheduler.local
-
-import java.nio.ByteBuffer
-import java.util.concurrent.atomic.AtomicInteger
-
-import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
-
-import akka.actor._
-
-import org.apache.spark._
-import org.apache.spark.TaskState.TaskState
-import org.apache.spark.executor.{Executor, ExecutorBackend}
-import org.apache.spark.scheduler._
-import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
-
-
-/**
- * A FIFO or Fair TaskScheduler implementation that runs tasks locally in a thread pool. Optionally
- * the scheduler also allows each task to fail up to maxFailures times, which is useful for
- * testing fault recovery.
- */
-
-private[local]
-case class LocalReviveOffers()
-
-private[local]
-case class LocalStatusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer)
-
-private[local]
-case class KillTask(taskId: Long)
-
-private[spark]
-class LocalActor(localScheduler: LocalScheduler, private var freeCores: Int)
- extends Actor with Logging {
-
- val executor = new Executor("localhost", "localhost", Seq.empty, isLocal = true)
-
- def receive = {
- case LocalReviveOffers =>
- launchTask(localScheduler.resourceOffer(freeCores))
-
- case LocalStatusUpdate(taskId, state, serializeData) =>
- if (TaskState.isFinished(state)) {
- freeCores += 1
- launchTask(localScheduler.resourceOffer(freeCores))
- }
-
- case KillTask(taskId) =>
- executor.killTask(taskId)
- }
-
- private def launchTask(tasks: Seq[TaskDescription]) {
- for (task <- tasks) {
- freeCores -= 1
- executor.launchTask(localScheduler, task.taskId, task.serializedTask)
- }
- }
-}
-
-private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: SparkContext)
- extends TaskScheduler
- with ExecutorBackend
- with Logging {
-
- val env = SparkEnv.get
- val attemptId = new AtomicInteger
- var dagScheduler: DAGScheduler = null
-
- // Application dependencies (added through SparkContext) that we've fetched so far on this node.
- // Each map holds the master's timestamp for the version of that file or JAR we got.
- val currentFiles: HashMap[String, Long] = new HashMap[String, Long]()
- val currentJars: HashMap[String, Long] = new HashMap[String, Long]()
-
- var schedulableBuilder: SchedulableBuilder = null
- var rootPool: Pool = null
- val schedulingMode: SchedulingMode = SchedulingMode.withName(
- System.getProperty("spark.scheduler.mode", "FIFO"))
- val activeTaskSets = new HashMap[String, LocalTaskSetManager]
- val taskIdToTaskSetId = new HashMap[Long, String]
- val taskSetTaskIds = new HashMap[String, HashSet[Long]]
-
- var localActor: ActorRef = null
-
- override def start() {
- // temporarily set rootPool name to empty
- rootPool = new Pool("", schedulingMode, 0, 0)
- schedulableBuilder = {
- schedulingMode match {
- case SchedulingMode.FIFO =>
- new FIFOSchedulableBuilder(rootPool)
- case SchedulingMode.FAIR =>
- new FairSchedulableBuilder(rootPool)
- }
- }
- schedulableBuilder.buildPools()
-
- localActor = env.actorSystem.actorOf(Props(new LocalActor(this, threads)), "Test")
- }
-
- override def setDAGScheduler(dagScheduler: DAGScheduler) {
- this.dagScheduler = dagScheduler
- }
-
- override def submitTasks(taskSet: TaskSet) {
- synchronized {
- val manager = new LocalTaskSetManager(this, taskSet)
- schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
- activeTaskSets(taskSet.id) = manager
- taskSetTaskIds(taskSet.id) = new HashSet[Long]()
- localActor ! LocalReviveOffers
- }
- }
-
- override def cancelTasks(stageId: Int): Unit = synchronized {
- logInfo("Cancelling stage " + stageId)
- logInfo("Cancelling stage " + activeTaskSets.map(_._2.stageId))
- activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) =>
- // There are two possible cases here:
- // 1. The task set manager has been created and some tasks have been scheduled.
- // In this case, send a kill signal to the executors to kill the task and then abort
- // the stage.
- // 2. The task set manager has been created but no tasks has been scheduled. In this case,
- // simply abort the stage.
- val taskIds = taskSetTaskIds(tsm.taskSet.id)
- if (taskIds.size > 0) {
- taskIds.foreach { tid =>
- localActor ! KillTask(tid)
- }
- }
- tsm.error("Stage %d was cancelled".format(stageId))
- }
- }
-
- def resourceOffer(freeCores: Int): Seq[TaskDescription] = {
- synchronized {
- var freeCpuCores = freeCores
- val tasks = new ArrayBuffer[TaskDescription](freeCores)
- val sortedTaskSetQueue = rootPool.getSortedTaskSetQueue()
- for (manager <- sortedTaskSetQueue) {
- logDebug("parentName:%s,name:%s,runningTasks:%s".format(
- manager.parent.name, manager.name, manager.runningTasks))
- }
-
- var launchTask = false
- for (manager <- sortedTaskSetQueue) {
- do {
- launchTask = false
- manager.resourceOffer(null, null, freeCpuCores, null) match {
- case Some(task) =>
- tasks += task
- taskIdToTaskSetId(task.taskId) = manager.taskSet.id
- taskSetTaskIds(manager.taskSet.id) += task.taskId
- freeCpuCores -= 1
- launchTask = true
- case None => {}
- }
- } while(launchTask)
- }
- return tasks
- }
- }
-
- def taskSetFinished(manager: TaskSetManager) {
- synchronized {
- activeTaskSets -= manager.taskSet.id
- manager.parent.removeSchedulable(manager)
- logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name))
- taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id)
- taskSetTaskIds -= manager.taskSet.id
- }
- }
-
- override def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) {
- if (TaskState.isFinished(state)) {
- synchronized {
- taskIdToTaskSetId.get(taskId) match {
- case Some(taskSetId) =>
- val taskSetManager = activeTaskSets(taskSetId)
- taskSetTaskIds(taskSetId) -= taskId
-
- state match {
- case TaskState.FINISHED =>
- taskSetManager.taskEnded(taskId, state, serializedData)
- case TaskState.FAILED =>
- taskSetManager.taskFailed(taskId, state, serializedData)
- case TaskState.KILLED =>
- taskSetManager.error("Task %d was killed".format(taskId))
- case _ => {}
- }
- case None =>
- logInfo("Ignoring update from TID " + taskId + " because its task set is gone")
- }
- }
- localActor ! LocalStatusUpdate(taskId, state, serializedData)
- }
- }
-
- override def stop() {
- }
-
- override def defaultParallelism() = threads
-}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala
deleted file mode 100644
index 53bf78267e..0000000000
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala
+++ /dev/null
@@ -1,191 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.scheduler.local
-
-import java.nio.ByteBuffer
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
-
-import org.apache.spark.{ExceptionFailure, Logging, SparkEnv, SparkException, Success, TaskState}
-import org.apache.spark.TaskState.TaskState
-import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Pool, Schedulable, Task,
- TaskDescription, TaskInfo, TaskLocality, TaskResult, TaskSet, TaskSetManager}
-
-
-private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: TaskSet)
- extends TaskSetManager with Logging {
-
- var parent: Pool = null
- var weight: Int = 1
- var minShare: Int = 0
- var runningTasks: Int = 0
- var priority: Int = taskSet.priority
- var stageId: Int = taskSet.stageId
- var name: String = "TaskSet_" + taskSet.stageId.toString
-
- var failCount = new Array[Int](taskSet.tasks.size)
- val taskInfos = new HashMap[Long, TaskInfo]
- val numTasks = taskSet.tasks.size
- var numFinished = 0
- val env = SparkEnv.get
- val ser = env.closureSerializer.newInstance()
- val copiesRunning = new Array[Int](numTasks)
- val finished = new Array[Boolean](numTasks)
- val numFailures = new Array[Int](numTasks)
- val MAX_TASK_FAILURES = sched.maxFailures
-
- def increaseRunningTasks(taskNum: Int): Unit = {
- runningTasks += taskNum
- if (parent != null) {
- parent.increaseRunningTasks(taskNum)
- }
- }
-
- def decreaseRunningTasks(taskNum: Int): Unit = {
- runningTasks -= taskNum
- if (parent != null) {
- parent.decreaseRunningTasks(taskNum)
- }
- }
-
- override def addSchedulable(schedulable: Schedulable): Unit = {
- // nothing
- }
-
- override def removeSchedulable(schedulable: Schedulable): Unit = {
- // nothing
- }
-
- override def getSchedulableByName(name: String): Schedulable = {
- return null
- }
-
- override def executorLost(executorId: String, host: String): Unit = {
- // nothing
- }
-
- override def checkSpeculatableTasks() = true
-
- override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = {
- var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager]
- sortedTaskSetQueue += this
- return sortedTaskSetQueue
- }
-
- override def hasPendingTasks() = true
-
- def findTask(): Option[Int] = {
- for (i <- 0 to numTasks-1) {
- if (copiesRunning(i) == 0 && !finished(i)) {
- return Some(i)
- }
- }
- return None
- }
-
- override def resourceOffer(
- execId: String,
- host: String,
- availableCpus: Int,
- maxLocality: TaskLocality.TaskLocality)
- : Option[TaskDescription] =
- {
- SparkEnv.set(sched.env)
- logDebug("availableCpus:%d, numFinished:%d, numTasks:%d".format(
- availableCpus.toInt, numFinished, numTasks))
- if (availableCpus > 0 && numFinished < numTasks) {
- findTask() match {
- case Some(index) =>
- val taskId = sched.attemptId.getAndIncrement()
- val task = taskSet.tasks(index)
- val info = new TaskInfo(taskId, index, System.currentTimeMillis(), "local", "local:1",
- TaskLocality.NODE_LOCAL)
- taskInfos(taskId) = info
- // We rely on the DAGScheduler to catch non-serializable closures and RDDs, so in here
- // we assume the task can be serialized without exceptions.
- val bytes = Task.serializeWithDependencies(
- task, sched.sc.addedFiles, sched.sc.addedJars, ser)
- logInfo("Size of task " + taskId + " is " + bytes.limit + " bytes")
- val taskName = "task %s:%d".format(taskSet.id, index)
- copiesRunning(index) += 1
- increaseRunningTasks(1)
- taskStarted(task, info)
- return Some(new TaskDescription(taskId, null, taskName, index, bytes))
- case None => {}
- }
- }
- return None
- }
-
- def taskStarted(task: Task[_], info: TaskInfo) {
- sched.dagScheduler.taskStarted(task, info)
- }
-
- def taskEnded(tid: Long, state: TaskState, serializedData: ByteBuffer) {
- val info = taskInfos(tid)
- val index = info.index
- val task = taskSet.tasks(index)
- info.markSuccessful()
- val result = ser.deserialize[TaskResult[_]](serializedData, getClass.getClassLoader) match {
- case directResult: DirectTaskResult[_] => directResult
- case IndirectTaskResult(blockId) => {
- throw new SparkException("Expect only DirectTaskResults when using LocalScheduler")
- }
- }
- result.metrics.resultSize = serializedData.limit()
- sched.dagScheduler.taskEnded(task, Success, result.value, result.accumUpdates, info,
- result.metrics)
- numFinished += 1
- decreaseRunningTasks(1)
- finished(index) = true
- if (numFinished == numTasks) {
- sched.taskSetFinished(this)
- }
- }
-
- def taskFailed(tid: Long, state: TaskState, serializedData: ByteBuffer) {
- val info = taskInfos(tid)
- val index = info.index
- val task = taskSet.tasks(index)
- info.markFailed()
- decreaseRunningTasks(1)
- val reason: ExceptionFailure = ser.deserialize[ExceptionFailure](
- serializedData, getClass.getClassLoader)
- sched.dagScheduler.taskEnded(task, reason, null, null, info, reason.metrics.getOrElse(null))
- if (!finished(index)) {
- copiesRunning(index) -= 1
- numFailures(index) += 1
- val locs = reason.stackTrace.map(loc => "\tat %s".format(loc.toString))
- logInfo("Loss was due to %s\n%s\n%s".format(
- reason.className, reason.description, locs.mkString("\n")))
- if (numFailures(index) > MAX_TASK_FAILURES) {
- val errorMessage = "Task %s:%d failed more than %d times; aborting job %s".format(
- taskSet.id, index, MAX_TASK_FAILURES, reason.description)
- decreaseRunningTasks(runningTasks)
- sched.dagScheduler.taskSetFailed(taskSet, errorMessage)
- // need to delete failed Taskset from schedule queue
- sched.taskSetFinished(this)
- }
- }
- }
-
- override def error(message: String) {
- sched.dagScheduler.taskSetFailed(taskSet, message)
- sched.taskSetFinished(this)
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
index 4de81617b1..5d3d43623d 100644
--- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
@@ -21,6 +21,7 @@ import java.io._
import java.nio.ByteBuffer
import org.apache.spark.util.ByteBufferInputStream
+import org.apache.spark.SparkConf
private[spark] class JavaSerializationStream(out: OutputStream) extends SerializationStream {
val objOut = new ObjectOutputStream(out)
@@ -77,6 +78,6 @@ private[spark] class JavaSerializerInstance extends SerializerInstance {
/**
* A Spark serializer that uses Java's built-in serialization.
*/
-class JavaSerializer extends Serializer {
+class JavaSerializer(conf: SparkConf) extends Serializer {
def newInstance(): SerializerInstance = new JavaSerializerInstance
}
diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
index e748c2275d..a24a3b04b8 100644
--- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -25,18 +25,18 @@ import com.esotericsoftware.kryo.{KryoException, Kryo}
import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput}
import com.twitter.chill.{EmptyScalaKryoInstantiator, AllScalaRegistrar}
-import org.apache.spark.{SerializableWritable, Logging}
+import org.apache.spark._
import org.apache.spark.broadcast.HttpBroadcast
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.storage._
+import org.apache.spark.storage.{GetBlock, GotBlock, PutBlock}
/**
* A Spark serializer that uses the [[https://code.google.com/p/kryo/ Kryo serialization library]].
*/
-class KryoSerializer extends org.apache.spark.serializer.Serializer with Logging {
-
+class KryoSerializer(conf: SparkConf) extends org.apache.spark.serializer.Serializer with Logging {
private val bufferSize = {
- System.getProperty("spark.kryoserializer.buffer.mb", "2").toInt * 1024 * 1024
+ conf.get("spark.kryoserializer.buffer.mb", "2").toInt * 1024 * 1024
}
def newKryoOutput() = new KryoOutput(bufferSize)
@@ -48,7 +48,7 @@ class KryoSerializer extends org.apache.spark.serializer.Serializer with Logging
// Allow disabling Kryo reference tracking if user knows their object graphs don't have loops.
// Do this before we invoke the user registrator so the user registrator can override this.
- kryo.setReferences(System.getProperty("spark.kryo.referenceTracking", "true").toBoolean)
+ kryo.setReferences(conf.get("spark.kryo.referenceTracking", "true").toBoolean)
for (cls <- KryoSerializer.toRegister) kryo.register(cls)
@@ -58,13 +58,13 @@ class KryoSerializer extends org.apache.spark.serializer.Serializer with Logging
// Allow the user to register their own classes by setting spark.kryo.registrator
try {
- Option(System.getProperty("spark.kryo.registrator")).foreach { regCls =>
+ for (regCls <- conf.getOption("spark.kryo.registrator")) {
logDebug("Running user registrator: " + regCls)
val reg = Class.forName(regCls, true, classLoader).newInstance().asInstanceOf[KryoRegistrator]
reg.registerClasses(kryo)
}
} catch {
- case _: Exception => println("Failed to register spark.kryo.registrator")
+ case e: Exception => logError("Failed to run spark.kryo.registrator", e)
}
// Register Chill's classes; we do this after our ranges and the user's own classes to let
diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
index 160cca4d6c..9a5e3cb77e 100644
--- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
@@ -29,6 +29,9 @@ import org.apache.spark.util.{NextIterator, ByteBufferInputStream}
* A serializer. Because some serialization libraries are not thread safe, this class is used to
* create [[org.apache.spark.serializer.SerializerInstance]] objects that do the actual serialization and are
* guaranteed to only be called from one thread at a time.
+ *
+ * Implementations of this trait should have a zero-arg constructor or a constructor that accepts a
+ * [[org.apache.spark.SparkConf]] as parameter. If both constructors are defined, the latter takes precedence.
*/
trait Serializer {
def newInstance(): SerializerInstance
diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
index 2955986fec..36a37af4f8 100644
--- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
@@ -18,6 +18,7 @@
package org.apache.spark.serializer
import java.util.concurrent.ConcurrentHashMap
+import org.apache.spark.SparkConf
/**
@@ -26,18 +27,19 @@ import java.util.concurrent.ConcurrentHashMap
* creating a new one.
*/
private[spark] class SerializerManager {
+ // TODO: Consider moving this into SparkConf itself to remove the global singleton.
private val serializers = new ConcurrentHashMap[String, Serializer]
private var _default: Serializer = _
def default = _default
- def setDefault(clsName: String): Serializer = {
- _default = get(clsName)
+ def setDefault(clsName: String, conf: SparkConf): Serializer = {
+ _default = get(clsName, conf)
_default
}
- def get(clsName: String): Serializer = {
+ def get(clsName: String, conf: SparkConf): Serializer = {
if (clsName == null) {
default
} else {
@@ -51,8 +53,19 @@ private[spark] class SerializerManager {
serializer = serializers.get(clsName)
if (serializer == null) {
val clsLoader = Thread.currentThread.getContextClassLoader
- serializer =
- Class.forName(clsName, true, clsLoader).newInstance().asInstanceOf[Serializer]
+ val cls = Class.forName(clsName, true, clsLoader)
+
+ // First try with the constructor that takes SparkConf. If we can't find one,
+ // use a no-arg constructor instead.
+ try {
+ val constructor = cls.getConstructor(classOf[SparkConf])
+ serializer = constructor.newInstance(conf).asInstanceOf[Serializer]
+ } catch {
+ case _: NoSuchMethodException =>
+ val constructor = cls.getConstructor()
+ serializer = constructor.newInstance().asInstanceOf[Serializer]
+ }
+
serializers.put(clsName, serializer)
}
serializer
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
index e51c5b30a3..47478631a1 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
@@ -312,7 +312,7 @@ object BlockFetcherIterator {
logDebug("Sending request for %d blocks (%s) from %s".format(
req.blocks.size, Utils.bytesToString(req.size), req.address.host))
val cmId = new ConnectionManagerId(req.address.host, req.address.nettyPort)
- val cpier = new ShuffleCopier
+ val cpier = new ShuffleCopier(blockManager.conf)
cpier.getBlocks(cmId, req.blocks, putResult)
logDebug("Sent request for remote blocks " + req.blocks + " from " + req.address.host )
}
@@ -327,7 +327,7 @@ object BlockFetcherIterator {
fetchRequestsSync.put(request)
}
- copiers = startCopiers(System.getProperty("spark.shuffle.copier.threads", "6").toInt)
+ copiers = startCopiers(conf.get("spark.shuffle.copier.threads", "6").toInt)
logInfo("Started " + fetchRequestsSync.size + " remote gets in " +
Utils.getUsedTimeMs(startTime))
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 702aca8323..6d2cda97b0 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -24,13 +24,13 @@ import scala.collection.mutable.{HashMap, ArrayBuffer}
import scala.util.Random
import akka.actor.{ActorSystem, Cancellable, Props}
-import akka.dispatch.{Await, Future}
-import akka.util.Duration
-import akka.util.duration._
+import scala.concurrent.{Await, Future}
+import scala.concurrent.duration.Duration
+import scala.concurrent.duration._
import it.unimi.dsi.fastutil.io.{FastBufferedOutputStream, FastByteArrayOutputStream}
-import org.apache.spark.{Logging, SparkEnv, SparkException}
+import org.apache.spark.{SparkConf, Logging, SparkEnv, SparkException}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.network._
import org.apache.spark.serializer.Serializer
@@ -43,12 +43,13 @@ private[spark] class BlockManager(
actorSystem: ActorSystem,
val master: BlockManagerMaster,
val defaultSerializer: Serializer,
- maxMemory: Long)
+ maxMemory: Long,
+ val conf: SparkConf)
extends Logging {
val shuffleBlockManager = new ShuffleBlockManager(this)
val diskBlockManager = new DiskBlockManager(shuffleBlockManager,
- System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")))
+ conf.get("spark.local.dir", System.getProperty("java.io.tmpdir")))
private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo]
@@ -57,12 +58,12 @@ private[spark] class BlockManager(
// If we use Netty for shuffle, start a new Netty-based shuffle sender service.
private val nettyPort: Int = {
- val useNetty = System.getProperty("spark.shuffle.use.netty", "false").toBoolean
- val nettyPortConfig = System.getProperty("spark.shuffle.sender.port", "0").toInt
+ val useNetty = conf.get("spark.shuffle.use.netty", "false").toBoolean
+ val nettyPortConfig = conf.get("spark.shuffle.sender.port", "0").toInt
if (useNetty) diskBlockManager.startShuffleBlockSender(nettyPortConfig) else 0
}
- val connectionManager = new ConnectionManager(0)
+ val connectionManager = new ConnectionManager(0, conf)
implicit val futureExecContext = connectionManager.futureExecContext
val blockManagerId = BlockManagerId(
@@ -71,18 +72,18 @@ private[spark] class BlockManager(
// Max megabytes of data to keep in flight per reducer (to avoid over-allocating memory
// for receiving shuffle outputs)
val maxBytesInFlight =
- System.getProperty("spark.reducer.maxMbInFlight", "48").toLong * 1024 * 1024
+ conf.get("spark.reducer.maxMbInFlight", "48").toLong * 1024 * 1024
// Whether to compress broadcast variables that are stored
- val compressBroadcast = System.getProperty("spark.broadcast.compress", "true").toBoolean
+ val compressBroadcast = conf.get("spark.broadcast.compress", "true").toBoolean
// Whether to compress shuffle output that are stored
- val compressShuffle = System.getProperty("spark.shuffle.compress", "true").toBoolean
+ val compressShuffle = conf.get("spark.shuffle.compress", "true").toBoolean
// Whether to compress RDD partitions that are stored serialized
- val compressRdds = System.getProperty("spark.rdd.compress", "false").toBoolean
+ val compressRdds = conf.get("spark.rdd.compress", "false").toBoolean
- val heartBeatFrequency = BlockManager.getHeartBeatFrequencyFromSystemProperties
+ val heartBeatFrequency = BlockManager.getHeartBeatFrequency(conf)
- val hostPort = Utils.localHostPort()
+ val hostPort = Utils.localHostPort(conf)
val slaveActor = actorSystem.actorOf(Props(new BlockManagerSlaveActor(this)),
name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next)
@@ -100,8 +101,11 @@ private[spark] class BlockManager(
var heartBeatTask: Cancellable = null
- private val metadataCleaner = new MetadataCleaner(MetadataCleanerType.BLOCK_MANAGER, this.dropOldNonBroadcastBlocks)
- private val broadcastCleaner = new MetadataCleaner(MetadataCleanerType.BROADCAST_VARS, this.dropOldBroadcastBlocks)
+ private val metadataCleaner = new MetadataCleaner(
+ MetadataCleanerType.BLOCK_MANAGER, this.dropOldNonBroadcastBlocks, conf)
+ private val broadcastCleaner = new MetadataCleaner(
+ MetadataCleanerType.BROADCAST_VARS, this.dropOldBroadcastBlocks, conf)
+
initialize()
// The compression codec to use. Note that the "lazy" val is necessary because we want to delay
@@ -109,14 +113,14 @@ private[spark] class BlockManager(
// program could be using a user-defined codec in a third party jar, which is loaded in
// Executor.updateDependencies. When the BlockManager is initialized, user level jars hasn't been
// loaded yet.
- private lazy val compressionCodec: CompressionCodec = CompressionCodec.createCodec()
+ private lazy val compressionCodec: CompressionCodec = CompressionCodec.createCodec(conf)
/**
* Construct a BlockManager with a memory limit set based on system properties.
*/
def this(execId: String, actorSystem: ActorSystem, master: BlockManagerMaster,
- serializer: Serializer) = {
- this(execId, actorSystem, master, serializer, BlockManager.getMaxMemoryFromSystemProperties)
+ serializer: Serializer, conf: SparkConf) = {
+ this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf), conf)
}
/**
@@ -126,7 +130,7 @@ private[spark] class BlockManager(
private def initialize() {
master.registerBlockManager(blockManagerId, maxMemory, slaveActor)
BlockManagerWorker.startBlockManagerWorker(this)
- if (!BlockManager.getDisableHeartBeatsForTesting) {
+ if (!BlockManager.getDisableHeartBeatsForTesting(conf)) {
heartBeatTask = actorSystem.scheduler.schedule(0.seconds, heartBeatFrequency.milliseconds) {
heartBeat()
}
@@ -439,7 +443,7 @@ private[spark] class BlockManager(
: BlockFetcherIterator = {
val iter =
- if (System.getProperty("spark.shuffle.use.netty", "false").toBoolean) {
+ if (conf.get("spark.shuffle.use.netty", "false").toBoolean) {
new BlockFetcherIterator.NettyBlockFetcherIterator(this, blocksByAddress, serializer)
} else {
new BlockFetcherIterator.BasicBlockFetcherIterator(this, blocksByAddress, serializer)
@@ -465,7 +469,8 @@ private[spark] class BlockManager(
def getDiskWriter(blockId: BlockId, file: File, serializer: Serializer, bufferSize: Int)
: BlockObjectWriter = {
val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _)
- new DiskBlockObjectWriter(blockId, file, serializer, bufferSize, compressStream)
+ val syncWrites = conf.get("spark.shuffle.sync", "false").toBoolean
+ new DiskBlockObjectWriter(blockId, file, serializer, bufferSize, compressStream, syncWrites)
}
/**
@@ -856,19 +861,18 @@ private[spark] class BlockManager(
private[spark] object BlockManager extends Logging {
-
val ID_GENERATOR = new IdGenerator
- def getMaxMemoryFromSystemProperties: Long = {
- val memoryFraction = System.getProperty("spark.storage.memoryFraction", "0.66").toDouble
+ def getMaxMemory(conf: SparkConf): Long = {
+ val memoryFraction = conf.get("spark.storage.memoryFraction", "0.66").toDouble
(Runtime.getRuntime.maxMemory * memoryFraction).toLong
}
- def getHeartBeatFrequencyFromSystemProperties: Long =
- System.getProperty("spark.storage.blockManagerTimeoutIntervalMs", "60000").toLong / 4
+ def getHeartBeatFrequency(conf: SparkConf): Long =
+ conf.get("spark.storage.blockManagerTimeoutIntervalMs", "60000").toLong / 4
- def getDisableHeartBeatsForTesting: Boolean =
- System.getProperty("spark.test.disableBlockManagerHeartBeat", "false").toBoolean
+ def getDisableHeartBeatsForTesting(conf: SparkConf): Boolean =
+ conf.get("spark.test.disableBlockManagerHeartBeat", "false").toBoolean
/**
* Attempt to clean up a ByteBuffer if it is memory-mapped. This uses an *unsafe* Sun API that
@@ -924,4 +928,3 @@ private[spark] object BlockManager extends Logging {
blockIdsToBlockManagers(blockIds, env, blockManagerMaster).mapValues(s => s.map(_.host))
}
}
-
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
index 94038649b3..51a29ed8ef 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
@@ -17,23 +17,25 @@
package org.apache.spark.storage
-import akka.actor.ActorRef
-import akka.dispatch.{Await, Future}
+import scala.concurrent.{Await, Future}
+import scala.concurrent.ExecutionContext.Implicits.global
+
+import akka.actor._
import akka.pattern.ask
-import akka.util.Duration
-import org.apache.spark.{Logging, SparkException}
+import org.apache.spark.{SparkConf, Logging, SparkException}
import org.apache.spark.storage.BlockManagerMessages._
+import org.apache.spark.util.AkkaUtils
+private[spark]
+class BlockManagerMaster(var driverActor : ActorRef, conf: SparkConf) extends Logging {
-private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Logging {
-
- val AKKA_RETRY_ATTEMPTS: Int = System.getProperty("spark.akka.num.retries", "3").toInt
- val AKKA_RETRY_INTERVAL_MS: Int = System.getProperty("spark.akka.retry.wait", "3000").toInt
+ val AKKA_RETRY_ATTEMPTS: Int = conf.get("spark.akka.num.retries", "3").toInt
+ val AKKA_RETRY_INTERVAL_MS: Int = conf.get("spark.akka.retry.wait", "3000").toInt
val DRIVER_AKKA_ACTOR_NAME = "BlockManagerMaster"
- val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
+ val timeout = AkkaUtils.askTimeout(conf)
/** Remove a dead executor from the driver actor. This is only called on the driver side. */
def removeExecutor(execId: String) {
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
index f8cf14b503..58452d9657 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
@@ -21,24 +21,22 @@ import java.util.{HashMap => JHashMap}
import scala.collection.mutable
import scala.collection.JavaConversions._
+import scala.concurrent.Future
+import scala.concurrent.duration._
import akka.actor.{Actor, ActorRef, Cancellable}
-import akka.dispatch.Future
import akka.pattern.ask
-import akka.util.Duration
-import akka.util.duration._
-import org.apache.spark.{Logging, SparkException}
+import org.apache.spark.{SparkConf, Logging, SparkException}
import org.apache.spark.storage.BlockManagerMessages._
-import org.apache.spark.util.Utils
-
+import org.apache.spark.util.{AkkaUtils, Utils}
/**
* BlockManagerMasterActor is an actor on the master node to track statuses of
* all slaves' block managers.
*/
private[spark]
-class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
+class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf) extends Actor with Logging {
// Mapping from block manager id to the block manager's information.
private val blockManagerInfo =
@@ -50,21 +48,19 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
// Mapping from block id to the set of block managers that have the block.
private val blockLocations = new JHashMap[BlockId, mutable.HashSet[BlockManagerId]]
- val akkaTimeout = Duration.create(
- System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
-
- initLogging()
+ private val akkaTimeout = AkkaUtils.askTimeout(conf)
- val slaveTimeout = System.getProperty("spark.storage.blockManagerSlaveTimeoutMs",
- "" + (BlockManager.getHeartBeatFrequencyFromSystemProperties * 3)).toLong
+ val slaveTimeout = conf.get("spark.storage.blockManagerSlaveTimeoutMs",
+ "" + (BlockManager.getHeartBeatFrequency(conf) * 3)).toLong
- val checkTimeoutInterval = System.getProperty("spark.storage.blockManagerTimeoutIntervalMs",
+ val checkTimeoutInterval = conf.get("spark.storage.blockManagerTimeoutIntervalMs",
"60000").toLong
var timeoutCheckingTask: Cancellable = null
override def preStart() {
- if (!BlockManager.getDisableHeartBeatsForTesting) {
+ if (!BlockManager.getDisableHeartBeatsForTesting(conf)) {
+ import context.dispatcher
timeoutCheckingTask = context.system.scheduler.schedule(
0.seconds, checkTimeoutInterval.milliseconds, self, ExpireDeadHosts)
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala
index 0c66addf9d..21f003609b 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala
@@ -30,7 +30,6 @@ import org.apache.spark.util.Utils
* TODO: Use event model.
*/
private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends Logging {
- initLogging()
blockManager.connectionManager.onReceiveMessage(onBlockMessageReceive)
@@ -101,8 +100,6 @@ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends
private[spark] object BlockManagerWorker extends Logging {
private var blockManagerWorker: BlockManagerWorker = null
- initLogging()
-
def startBlockManagerWorker(manager: BlockManager) {
blockManagerWorker = new BlockManagerWorker(manager)
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala b/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala
index 6ce9127c74..a06f50a0ac 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala
@@ -37,8 +37,6 @@ class BlockMessageArray(var blockMessages: Seq[BlockMessage]) extends Seq[BlockM
def length = blockMessages.length
- initLogging()
-
def set(bufferMessage: BufferMessage) {
val startTime = System.currentTimeMillis
val newBlockMessages = new ArrayBuffer[BlockMessage]()
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
index b4451fc7b8..61e63c60d5 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
@@ -74,7 +74,8 @@ class DiskBlockObjectWriter(
file: File,
serializer: Serializer,
bufferSize: Int,
- compressStream: OutputStream => OutputStream)
+ compressStream: OutputStream => OutputStream,
+ syncWrites: Boolean)
extends BlockObjectWriter(blockId)
with Logging
{
@@ -97,8 +98,6 @@ class DiskBlockObjectWriter(
override def flush() = out.flush()
}
- private val syncWrites = System.getProperty("spark.shuffle.sync", "false").toBoolean
-
/** The file channel, used for repositioning / truncating the file. */
private var channel: FileChannel = null
private var bs: OutputStream = null
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
index fcd2e97982..55dcb3742c 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
@@ -38,7 +38,7 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD
extends PathResolver with Logging {
private val MAX_DIR_CREATION_ATTEMPTS: Int = 10
- private val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt
+ private val subDirsPerLocalDir = shuffleManager.conf.get("spark.diskStore.subDirectories", "64").toInt
// Create one local directory for each path mentioned in spark.local.dir; then, inside this
// directory, create multiple subdirectories that we will hash files into, in order to avoid
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
index 2f1b049ce4..39dc7bb19a 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
@@ -59,32 +59,40 @@ private[spark] trait ShuffleWriterGroup {
*/
private[spark]
class ShuffleBlockManager(blockManager: BlockManager) {
+ def conf = blockManager.conf
+
// Turning off shuffle file consolidation causes all shuffle Blocks to get their own file.
// TODO: Remove this once the shuffle file consolidation feature is stable.
val consolidateShuffleFiles =
- System.getProperty("spark.shuffle.consolidateFiles", "true").toBoolean
+ conf.get("spark.shuffle.consolidateFiles", "false").toBoolean
- private val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024
+ private val bufferSize = conf.get("spark.shuffle.file.buffer.kb", "100").toInt * 1024
/**
* Contains all the state related to a particular shuffle. This includes a pool of unused
* ShuffleFileGroups, as well as all ShuffleFileGroups that have been created for the shuffle.
*/
- private class ShuffleState() {
+ private class ShuffleState(val numBuckets: Int) {
val nextFileId = new AtomicInteger(0)
val unusedFileGroups = new ConcurrentLinkedQueue[ShuffleFileGroup]()
val allFileGroups = new ConcurrentLinkedQueue[ShuffleFileGroup]()
+
+ /**
+ * The mapIds of all map tasks completed on this Executor for this shuffle.
+ * NB: This is only populated if consolidateShuffleFiles is FALSE. We don't need it otherwise.
+ */
+ val completedMapTasks = new ConcurrentLinkedQueue[Int]()
}
type ShuffleId = Int
private val shuffleStates = new TimeStampedHashMap[ShuffleId, ShuffleState]
- private
- val metadataCleaner = new MetadataCleaner(MetadataCleanerType.SHUFFLE_BLOCK_MANAGER, this.cleanup)
+ private val metadataCleaner =
+ new MetadataCleaner(MetadataCleanerType.SHUFFLE_BLOCK_MANAGER, this.cleanup, conf)
def forMapTask(shuffleId: Int, mapId: Int, numBuckets: Int, serializer: Serializer) = {
new ShuffleWriterGroup {
- shuffleStates.putIfAbsent(shuffleId, new ShuffleState())
+ shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numBuckets))
private val shuffleState = shuffleStates(shuffleId)
private var fileGroup: ShuffleFileGroup = null
@@ -109,6 +117,8 @@ class ShuffleBlockManager(blockManager: BlockManager) {
fileGroup.recordMapOutput(mapId, offsets)
}
recycleFileGroup(fileGroup)
+ } else {
+ shuffleState.completedMapTasks.add(mapId)
}
}
@@ -154,7 +164,18 @@ class ShuffleBlockManager(blockManager: BlockManager) {
}
private def cleanup(cleanupTime: Long) {
- shuffleStates.clearOldValues(cleanupTime)
+ shuffleStates.clearOldValues(cleanupTime, (shuffleId, state) => {
+ if (consolidateShuffleFiles) {
+ for (fileGroup <- state.allFileGroups; file <- fileGroup.files) {
+ file.delete()
+ }
+ } else {
+ for (mapId <- state.completedMapTasks; reduceId <- 0 until state.numBuckets) {
+ val blockId = new ShuffleBlockId(shuffleId, mapId, reduceId)
+ blockManager.diskBlockManager.getFile(blockId).delete()
+ }
+ }
+ })
}
}
diff --git a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala
index 632ff047d1..b5596dffd3 100644
--- a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala
+++ b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala
@@ -101,7 +101,7 @@ class StorageLevel private(
var result = ""
result += (if (useDisk) "Disk " else "")
result += (if (useMemory) "Memory " else "")
- result += (if (deserialized) "Deserialized " else "Serialized")
+ result += (if (deserialized) "Deserialized " else "Serialized ")
result += "%sx Replicated".format(replication)
result
}
diff --git a/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala b/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala
index 1e4db4f66b..40734aab49 100644
--- a/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala
+++ b/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala
@@ -1,3 +1,20 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
package org.apache.spark.storage
import java.util.concurrent.atomic.AtomicLong
@@ -39,7 +56,7 @@ object StoragePerfTester {
def writeOutputBytes(mapId: Int, total: AtomicLong) = {
val shuffle = blockManager.shuffleBlockManager.forMapTask(1, mapId, numOutputSplits,
- new KryoSerializer())
+ new KryoSerializer(sc.conf))
val writers = shuffle.writers
for (i <- 1 to recordsPerMap) {
writers(i % numOutputSplits).write(writeData)
diff --git a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala
index 860e680576..729ba2c550 100644
--- a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala
@@ -22,6 +22,7 @@ import akka.actor._
import java.util.concurrent.ArrayBlockingQueue
import util.Random
import org.apache.spark.serializer.KryoSerializer
+import org.apache.spark.{SparkConf, SparkContext}
/**
* This class tests the BlockManager and MemoryStore for thread safety and
@@ -91,11 +92,12 @@ private[spark] object ThreadingTest {
def main(args: Array[String]) {
System.setProperty("spark.kryoserializer.buffer.mb", "1")
val actorSystem = ActorSystem("test")
- val serializer = new KryoSerializer
+ val conf = new SparkConf()
+ val serializer = new KryoSerializer(conf)
val blockManagerMaster = new BlockManagerMaster(
- actorSystem.actorOf(Props(new BlockManagerMasterActor(true))))
+ actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf))), conf)
val blockManager = new BlockManager(
- "<driver>", actorSystem, blockManagerMaster, serializer, 1024 * 1024)
+ "<driver>", actorSystem, blockManagerMaster, serializer, 1024 * 1024, conf)
val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i))
val consumers = producers.map(p => new ConsumerThread(blockManager, p.queue))
producers.foreach(_.start)
diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
index f1d86c0221..50dfdbdf5a 100644
--- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
@@ -1,4 +1,4 @@
-/*
+/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
@@ -32,7 +32,7 @@ import org.apache.spark.util.Utils
/** Top level user interface for Spark */
private[spark] class SparkUI(sc: SparkContext) extends Logging {
val host = Option(System.getenv("SPARK_PUBLIC_DNS")).getOrElse(Utils.localHostName())
- val port = Option(System.getProperty("spark.ui.port")).getOrElse(SparkUI.DEFAULT_PORT).toInt
+ val port = sc.conf.get("spark.ui.port", SparkUI.DEFAULT_PORT).toInt
var boundPort: Option[Int] = None
var server: Option[Server] = None
diff --git a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala
index fcd1b518d0..6ba15187d9 100644
--- a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala
+++ b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ui
import scala.util.Random
-import org.apache.spark.SparkContext
+import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._
import org.apache.spark.scheduler.SchedulingMode
@@ -27,25 +27,26 @@ import org.apache.spark.scheduler.SchedulingMode
/**
* Continuously generates jobs that expose various features of the WebUI (internal testing tool).
*
- * Usage: ./run spark.ui.UIWorkloadGenerator [master]
+ * Usage: ./bin/spark-class org.apache.spark.ui.UIWorkloadGenerator [master] [FIFO|FAIR]
*/
private[spark] object UIWorkloadGenerator {
+
val NUM_PARTITIONS = 100
val INTER_JOB_WAIT_MS = 5000
def main(args: Array[String]) {
if (args.length < 2) {
- println("usage: ./spark-class org.apache.spark.ui.UIWorkloadGenerator [master] [FIFO|FAIR]")
+ println("usage: ./bin/spark-class org.apache.spark.ui.UIWorkloadGenerator [master] [FIFO|FAIR]")
System.exit(1)
}
- val master = args(0)
- val schedulingMode = SchedulingMode.withName(args(1))
- val appName = "Spark UI Tester"
+ val conf = new SparkConf().setMaster(args(0)).setAppName("Spark UI tester")
+
+ val schedulingMode = SchedulingMode.withName(args(1))
if (schedulingMode == SchedulingMode.FAIR) {
- System.setProperty("spark.scheduler.mode", "FAIR")
+ conf.set("spark.scheduler.mode", "FAIR")
}
- val sc = new SparkContext(master, appName)
+ val sc = new SparkContext(conf)
def setProperties(s: String) = {
if(schedulingMode == SchedulingMode.FAIR) {
@@ -55,11 +56,11 @@ private[spark] object UIWorkloadGenerator {
}
val baseData = sc.makeRDD(1 to NUM_PARTITIONS * 10, NUM_PARTITIONS)
- def nextFloat() = (new Random()).nextFloat()
+ def nextFloat() = new Random().nextFloat()
val jobs = Seq[(String, () => Long)](
("Count", baseData.count),
- ("Cache and Count", baseData.map(x => x).cache.count),
+ ("Cache and Count", baseData.map(x => x).cache().count),
("Single Shuffle", baseData.map(x => (x % 10, x)).reduceByKey(_ + _).count),
("Entirely failed phase", baseData.map(x => throw new Exception).count),
("Partially failed phase", {
diff --git a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala
index c5bf2acc9e..88f41be8d3 100644
--- a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala
@@ -48,12 +48,15 @@ private[spark] class EnvironmentUI(sc: SparkContext) {
def jvmTable =
UIUtils.listingTable(Seq("Name", "Value"), jvmRow, jvmInformation, fixedWidth = true)
- val properties = System.getProperties.iterator.toSeq
- val classPathProperty = properties.find { case (k, v) =>
- k.contains("java.class.path")
+ val sparkProperties = sc.conf.getAll.sorted
+
+ val systemProperties = System.getProperties.iterator.toSeq
+ val classPathProperty = systemProperties.find { case (k, v) =>
+ k == "java.class.path"
}.getOrElse(("", ""))
- val sparkProperties = properties.filter(_._1.startsWith("spark")).sorted
- val otherProperties = properties.diff(sparkProperties :+ classPathProperty).sorted
+ val otherProperties = systemProperties.filter { case (k, v) =>
+ k != "java.class.path" && !k.startsWith("spark.")
+ }.sorted
val propertyHeaders = Seq("Name", "Value")
def propertyRow(kv: (String, String)) = <tr><td>{kv._1}</td><td>{kv._2}</td></tr>
@@ -63,7 +66,7 @@ private[spark] class EnvironmentUI(sc: SparkContext) {
UIUtils.listingTable(propertyHeaders, propertyRow, otherProperties, fixedWidth = true)
val classPathEntries = classPathProperty._2
- .split(System.getProperty("path.separator", ":"))
+ .split(sc.conf.get("path.separator", ":"))
.filterNot(e => e.isEmpty)
.map(e => (e, "System Classpath"))
val addedJars = sc.addedJars.iterator.toSeq.map{case (path, time) => (path, "Added By User")}
diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala
index e596690bc3..a31a7e1d58 100644
--- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala
@@ -56,7 +56,8 @@ private[spark] class ExecutorsUI(val sc: SparkContext) {
val diskSpaceUsed = storageStatusList.flatMap(_.blocks.values.map(_.diskSize)).fold(0L)(_+_)
val execHead = Seq("Executor ID", "Address", "RDD blocks", "Memory used", "Disk used",
- "Active tasks", "Failed tasks", "Complete tasks", "Total tasks")
+ "Active tasks", "Failed tasks", "Complete tasks", "Total tasks", "Task Time", "Shuffle Read",
+ "Shuffle Write")
def execRow(kv: Seq[String]) = {
<tr>
@@ -73,6 +74,9 @@ private[spark] class ExecutorsUI(val sc: SparkContext) {
<td>{kv(7)}</td>
<td>{kv(8)}</td>
<td>{kv(9)}</td>
+ <td>{Utils.msDurationToString(kv(10).toLong)}</td>
+ <td>{Utils.bytesToString(kv(11).toLong)}</td>
+ <td>{Utils.bytesToString(kv(12).toLong)}</td>
</tr>
}
@@ -111,6 +115,9 @@ private[spark] class ExecutorsUI(val sc: SparkContext) {
val failedTasks = listener.executorToTasksFailed.getOrElse(execId, 0)
val completedTasks = listener.executorToTasksComplete.getOrElse(execId, 0)
val totalTasks = activeTasks + failedTasks + completedTasks
+ val totalDuration = listener.executorToDuration.getOrElse(execId, 0)
+ val totalShuffleRead = listener.executorToShuffleRead.getOrElse(execId, 0)
+ val totalShuffleWrite = listener.executorToShuffleWrite.getOrElse(execId, 0)
Seq(
execId,
@@ -122,7 +129,10 @@ private[spark] class ExecutorsUI(val sc: SparkContext) {
activeTasks.toString,
failedTasks.toString,
completedTasks.toString,
- totalTasks.toString
+ totalTasks.toString,
+ totalDuration.toString,
+ totalShuffleRead.toString,
+ totalShuffleWrite.toString
)
}
@@ -130,6 +140,9 @@ private[spark] class ExecutorsUI(val sc: SparkContext) {
val executorToTasksActive = HashMap[String, HashSet[TaskInfo]]()
val executorToTasksComplete = HashMap[String, Int]()
val executorToTasksFailed = HashMap[String, Int]()
+ val executorToDuration = HashMap[String, Long]()
+ val executorToShuffleRead = HashMap[String, Long]()
+ val executorToShuffleWrite = HashMap[String, Long]()
override def onTaskStart(taskStart: SparkListenerTaskStart) {
val eid = taskStart.taskInfo.executorId
@@ -140,6 +153,9 @@ private[spark] class ExecutorsUI(val sc: SparkContext) {
override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
val eid = taskEnd.taskInfo.executorId
val activeTasks = executorToTasksActive.getOrElseUpdate(eid, new HashSet[TaskInfo]())
+ val newDuration = executorToDuration.getOrElse(eid, 0L) + taskEnd.taskInfo.duration
+ executorToDuration.put(eid, newDuration)
+
activeTasks -= taskEnd.taskInfo
val (failureInfo, metrics): (Option[ExceptionFailure], Option[TaskMetrics]) =
taskEnd.reason match {
@@ -150,6 +166,17 @@ private[spark] class ExecutorsUI(val sc: SparkContext) {
executorToTasksComplete(eid) = executorToTasksComplete.getOrElse(eid, 0) + 1
(None, Option(taskEnd.taskMetrics))
}
+
+ // update shuffle read/write
+ if (null != taskEnd.taskMetrics) {
+ taskEnd.taskMetrics.shuffleReadMetrics.foreach(shuffleRead =>
+ executorToShuffleRead.put(eid, executorToShuffleRead.getOrElse(eid, 0L) +
+ shuffleRead.remoteBytesRead))
+
+ taskEnd.taskMetrics.shuffleWriteMetrics.foreach(shuffleWrite =>
+ executorToShuffleWrite.put(eid, executorToShuffleWrite.getOrElse(eid, 0L) +
+ shuffleWrite.shuffleBytesWritten))
+ }
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorSummary.scala
index aea08ff81b..3c53e88380 100644
--- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorSummary.scala
@@ -15,27 +15,13 @@
* limitations under the License.
*/
-package org.apache.spark.rdd
+package org.apache.spark.ui.jobs
-import org.apache.spark.{Partition, TaskContext}
-
-
-/**
- * A variant of the MapPartitionsRDD that passes the TaskContext into the closure. From the
- * TaskContext, the closure can either get access to the interruptible flag or get the index
- * of the partition in the RDD.
- */
-private[spark]
-class MapPartitionsWithContextRDD[U: ClassManifest, T: ClassManifest](
- prev: RDD[T],
- f: (TaskContext, Iterator[T]) => Iterator[U],
- preservesPartitioning: Boolean
- ) extends RDD[U](prev) {
-
- override def getPartitions: Array[Partition] = firstParent[T].partitions
-
- override val partitioner = if (preservesPartitioning) prev.partitioner else None
-
- override def compute(split: Partition, context: TaskContext) =
- f(context, firstParent[T].iterator(split, context))
+/** class for reporting aggregated metrics for each executors in stageUI */
+private[spark] class ExecutorSummary {
+ var taskTime : Long = 0
+ var failedTasks : Int = 0
+ var succeededTasks : Int = 0
+ var shuffleRead : Long = 0
+ var shuffleWrite : Long = 0
}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala
new file mode 100644
index 0000000000..0dd876480a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui.jobs
+
+import scala.xml.Node
+
+import org.apache.spark.scheduler.SchedulingMode
+import org.apache.spark.util.Utils
+import scala.collection.mutable
+
+/** Page showing executor summary */
+private[spark] class ExecutorTable(val parent: JobProgressUI, val stageId: Int) {
+
+ val listener = parent.listener
+ val dateFmt = parent.dateFmt
+ val isFairScheduler = listener.sc.getSchedulingMode == SchedulingMode.FAIR
+
+ def toNodeSeq(): Seq[Node] = {
+ listener.synchronized {
+ executorTable()
+ }
+ }
+
+ /** Special table which merges two header cells. */
+ private def executorTable[T](): Seq[Node] = {
+ <table class="table table-bordered table-striped table-condensed sortable">
+ <thead>
+ <th>Executor ID</th>
+ <th>Address</th>
+ <th>Task Time</th>
+ <th>Total Tasks</th>
+ <th>Failed Tasks</th>
+ <th>Succeeded Tasks</th>
+ <th>Shuffle Read</th>
+ <th>Shuffle Write</th>
+ </thead>
+ <tbody>
+ {createExecutorTable()}
+ </tbody>
+ </table>
+ }
+
+ private def createExecutorTable() : Seq[Node] = {
+ // make a executor-id -> address map
+ val executorIdToAddress = mutable.HashMap[String, String]()
+ val storageStatusList = parent.sc.getExecutorStorageStatus
+ for (statusId <- 0 until storageStatusList.size) {
+ val blockManagerId = parent.sc.getExecutorStorageStatus(statusId).blockManagerId
+ val address = blockManagerId.hostPort
+ val executorId = blockManagerId.executorId
+ executorIdToAddress.put(executorId, address)
+ }
+
+ val executorIdToSummary = listener.stageIdToExecutorSummaries.get(stageId)
+ executorIdToSummary match {
+ case Some(x) => {
+ x.toSeq.sortBy(_._1).map{
+ case (k,v) => {
+ <tr>
+ <td>{k}</td>
+ <td>{executorIdToAddress.getOrElse(k, "CANNOT FIND ADDRESS")}</td>
+ <td>{parent.formatDuration(v.taskTime)}</td>
+ <td>{v.failedTasks + v.succeededTasks}</td>
+ <td>{v.failedTasks}</td>
+ <td>{v.succeededTasks}</td>
+ <td>{Utils.bytesToString(v.shuffleRead)}</td>
+ <td>{Utils.bytesToString(v.shuffleWrite)}</td>
+ </tr>
+ }
+ }
+ }
+ case _ => { Seq[Node]() }
+ }
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
index 6b854740d6..b7b87250b9 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
@@ -33,7 +33,7 @@ import org.apache.spark.scheduler._
*/
private[spark] class JobProgressListener(val sc: SparkContext) extends SparkListener {
// How many stages to remember
- val RETAINED_STAGES = System.getProperty("spark.ui.retained_stages", "1000").toInt
+ val RETAINED_STAGES = sc.conf.get("spark.ui.retained_stages", "1000").toInt
val DEFAULT_POOL_NAME = "default"
val stageIdToPool = new HashMap[Int, String]()
@@ -57,10 +57,11 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList
val stageIdToTasksFailed = HashMap[Int, Int]()
val stageIdToTaskInfos =
HashMap[Int, HashSet[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]]()
+ val stageIdToExecutorSummaries = HashMap[Int, HashMap[String, ExecutorSummary]]()
override def onJobStart(jobStart: SparkListenerJobStart) {}
- override def onStageCompleted(stageCompleted: StageCompleted) = synchronized {
+ override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) = synchronized {
val stage = stageCompleted.stage
poolToActiveStages(stageIdToPool(stage.stageId)) -= stage
activeStages -= stage
@@ -105,7 +106,7 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList
val stages = poolToActiveStages.getOrElseUpdate(poolName, new HashSet[StageInfo]())
stages += stage
}
-
+
override def onTaskStart(taskStart: SparkListenerTaskStart) = synchronized {
val sid = taskStart.task.stageId
val tasksActive = stageIdToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]())
@@ -124,8 +125,38 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList
override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized {
val sid = taskEnd.task.stageId
+
+ // create executor summary map if necessary
+ val executorSummaryMap = stageIdToExecutorSummaries.getOrElseUpdate(key = sid,
+ op = new HashMap[String, ExecutorSummary]())
+ executorSummaryMap.getOrElseUpdate(key = taskEnd.taskInfo.executorId,
+ op = new ExecutorSummary())
+
+ val executorSummary = executorSummaryMap.get(taskEnd.taskInfo.executorId)
+ executorSummary match {
+ case Some(y) => {
+ // first update failed-task, succeed-task
+ taskEnd.reason match {
+ case Success =>
+ y.succeededTasks += 1
+ case _ =>
+ y.failedTasks += 1
+ }
+
+ // update duration
+ y.taskTime += taskEnd.taskInfo.duration
+
+ Option(taskEnd.taskMetrics).foreach { taskMetrics =>
+ taskMetrics.shuffleReadMetrics.foreach { y.shuffleRead += _.remoteBytesRead }
+ taskMetrics.shuffleWriteMetrics.foreach { y.shuffleWrite += _.shuffleBytesWritten }
+ }
+ }
+ case _ => {}
+ }
+
val tasksActive = stageIdToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]())
tasksActive -= taskEnd.taskInfo
+
val (failureInfo, metrics): (Option[ExceptionFailure], Option[TaskMetrics]) =
taskEnd.reason match {
case e: ExceptionFailure =>
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala
index e7eab374ad..c1ee2f3d00 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala
@@ -17,7 +17,7 @@
package org.apache.spark.ui.jobs
-import akka.util.Duration
+import scala.concurrent.duration._
import java.text.SimpleDateFormat
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
index fbd822867f..8dcfeacb60 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
@@ -60,11 +60,13 @@ private[spark] class StagePage(parent: JobProgressUI) {
var activeTime = 0L
listener.stageIdToTasksActive(stageId).foreach(activeTime += _.timeRunning(now))
+ val finishedTasks = listener.stageIdToTaskInfos(stageId).filter(_._1.finished)
+
val summary =
<div>
<ul class="unstyled">
<li>
- <strong>CPU time: </strong>
+ <strong>Total task time across all tasks: </strong>
{parent.formatDuration(listener.stageIdToTime.getOrElse(stageId, 0L) + activeTime)}
</li>
{if (hasShuffleRead)
@@ -84,7 +86,7 @@ private[spark] class StagePage(parent: JobProgressUI) {
val taskHeaders: Seq[String] =
Seq("Task Index", "Task ID", "Status", "Locality Level", "Executor", "Launch Time") ++
- Seq("Duration", "GC Time") ++
+ Seq("Duration", "GC Time", "Result Ser Time") ++
{if (hasShuffleRead) Seq("Shuffle Read") else Nil} ++
{if (hasShuffleWrite) Seq("Write Time", "Shuffle Write") else Nil} ++
Seq("Errors")
@@ -99,11 +101,43 @@ private[spark] class StagePage(parent: JobProgressUI) {
None
}
else {
+ val serializationTimes = validTasks.map{case (info, metrics, exception) =>
+ metrics.get.resultSerializationTime.toDouble}
+ val serializationQuantiles = "Result serialization time" +: Distribution(serializationTimes).get.getQuantiles().map(
+ ms => parent.formatDuration(ms.toLong))
+
val serviceTimes = validTasks.map{case (info, metrics, exception) =>
metrics.get.executorRunTime.toDouble}
val serviceQuantiles = "Duration" +: Distribution(serviceTimes).get.getQuantiles().map(
ms => parent.formatDuration(ms.toLong))
+ val gettingResultTimes = validTasks.map{case (info, metrics, exception) =>
+ if (info.gettingResultTime > 0) {
+ (info.finishTime - info.gettingResultTime).toDouble
+ } else {
+ 0.0
+ }
+ }
+ val gettingResultQuantiles = ("Time spent fetching task results" +:
+ Distribution(gettingResultTimes).get.getQuantiles().map(
+ millis => parent.formatDuration(millis.toLong)))
+ // The scheduler delay includes the network delay to send the task to the worker
+ // machine and to send back the result (but not the time to fetch the task result,
+ // if it needed to be fetched from the block manager on the worker).
+ val schedulerDelays = validTasks.map{case (info, metrics, exception) =>
+ val totalExecutionTime = {
+ if (info.gettingResultTime > 0) {
+ (info.gettingResultTime - info.launchTime).toDouble
+ } else {
+ (info.finishTime - info.launchTime).toDouble
+ }
+ }
+ totalExecutionTime - metrics.get.executorRunTime
+ }
+ val schedulerDelayQuantiles = ("Scheduler delay" +:
+ Distribution(schedulerDelays).get.getQuantiles().map(
+ millis => parent.formatDuration(millis.toLong)))
+
def getQuantileCols(data: Seq[Double]) =
Distribution(data).get.getQuantiles().map(d => Utils.bytesToString(d.toLong))
@@ -119,7 +153,11 @@ private[spark] class StagePage(parent: JobProgressUI) {
}
val shuffleWriteQuantiles = "Shuffle Write" +: getQuantileCols(shuffleWriteSizes)
- val listings: Seq[Seq[String]] = Seq(serviceQuantiles,
+ val listings: Seq[Seq[String]] = Seq(
+ serializationQuantiles,
+ serviceQuantiles,
+ gettingResultQuantiles,
+ schedulerDelayQuantiles,
if (hasShuffleRead) shuffleReadQuantiles else Nil,
if (hasShuffleWrite) shuffleWriteQuantiles else Nil)
@@ -128,11 +166,12 @@ private[spark] class StagePage(parent: JobProgressUI) {
def quantileRow(data: Seq[String]): Seq[Node] = <tr> {data.map(d => <td>{d}</td>)} </tr>
Some(listingTable(quantileHeaders, quantileRow, listings, fixedWidth = true))
}
-
+ val executorTable = new ExecutorTable(parent, stageId)
val content =
summary ++
<h4>Summary Metrics for {numCompleted} Completed Tasks</h4> ++
<div>{summaryTable.getOrElse("No tasks have reported metrics yet.")}</div> ++
+ <h4>Aggregated Metrics by Executors</h4> ++ executorTable.toNodeSeq() ++
<h4>Tasks</h4> ++ taskTable
headerSparkPage(content, parent.sc, "Details for Stage %d".format(stageId), Stages)
@@ -151,22 +190,20 @@ private[spark] class StagePage(parent: JobProgressUI) {
val formatDuration = if (info.status == "RUNNING") parent.formatDuration(duration)
else metrics.map(m => parent.formatDuration(m.executorRunTime)).getOrElse("")
val gcTime = metrics.map(m => m.jvmGCTime).getOrElse(0L)
+ val serializationTime = metrics.map(m => m.resultSerializationTime).getOrElse(0L)
- var shuffleReadSortable: String = ""
- var shuffleReadReadable: String = ""
- if (shuffleRead) {
- shuffleReadSortable = metrics.flatMap{m => m.shuffleReadMetrics}.map{s => s.remoteBytesRead}.toString()
- shuffleReadReadable = metrics.flatMap{m => m.shuffleReadMetrics}.map{s =>
- Utils.bytesToString(s.remoteBytesRead)}.getOrElse("")
- }
+ val maybeShuffleRead = metrics.flatMap{m => m.shuffleReadMetrics}.map{s => s.remoteBytesRead}
+ val shuffleReadSortable = maybeShuffleRead.map(_.toString).getOrElse("")
+ val shuffleReadReadable = maybeShuffleRead.map{Utils.bytesToString(_)}.getOrElse("")
- var shuffleWriteSortable: String = ""
- var shuffleWriteReadable: String = ""
- if (shuffleWrite) {
- shuffleWriteSortable = metrics.flatMap{m => m.shuffleWriteMetrics}.map{s => s.shuffleBytesWritten}.toString()
- shuffleWriteReadable = metrics.flatMap{m => m.shuffleWriteMetrics}.map{s =>
- Utils.bytesToString(s.shuffleBytesWritten)}.getOrElse("")
- }
+ val maybeShuffleWrite = metrics.flatMap{m => m.shuffleWriteMetrics}.map{s => s.shuffleBytesWritten}
+ val shuffleWriteSortable = maybeShuffleWrite.map(_.toString).getOrElse("")
+ val shuffleWriteReadable = maybeShuffleWrite.map{Utils.bytesToString(_)}.getOrElse("")
+
+ val maybeWriteTime = metrics.flatMap{m => m.shuffleWriteMetrics}.map{s => s.shuffleWriteTime}
+ val writeTimeSortable = maybeWriteTime.map(_.toString).getOrElse("")
+ val writeTimeReadable = maybeWriteTime.map{ t => t / (1000 * 1000)}.map{ ms =>
+ if (ms == 0) "" else parent.formatDuration(ms)}.getOrElse("")
<tr>
<td>{info.index}</td>
@@ -181,14 +218,17 @@ private[spark] class StagePage(parent: JobProgressUI) {
<td sorttable_customkey={gcTime.toString}>
{if (gcTime > 0) parent.formatDuration(gcTime) else ""}
</td>
+ <td sorttable_customkey={serializationTime.toString}>
+ {if (serializationTime > 0) parent.formatDuration(serializationTime) else ""}
+ </td>
{if (shuffleRead) {
<td sorttable_customkey={shuffleReadSortable}>
{shuffleReadReadable}
</td>
}}
{if (shuffleWrite) {
- <td>{metrics.flatMap{m => m.shuffleWriteMetrics}.map{s =>
- parent.formatDuration(s.shuffleWriteTime / (1000 * 1000))}.getOrElse("")}
+ <td sorttable_customkey={writeTimeSortable}>
+ {writeTimeReadable}
</td>
<td sorttable_customkey={shuffleWriteSortable}>
{shuffleWriteReadable}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
index 9ad6de3c6d..463d85dfd5 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
@@ -48,7 +48,7 @@ private[spark] class StageTable(val stages: Seq[StageInfo], val parent: JobProgr
{if (isFairScheduler) {<th>Pool Name</th>} else {}}
<th>Description</th>
<th>Submitted</th>
- <th>Duration</th>
+ <th>Task Time</th>
<th>Tasks: Succeeded/Total</th>
<th>Shuffle Read</th>
<th>Shuffle Write</th>
diff --git a/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala b/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala
index 1d633d374a..39f422dd6b 100644
--- a/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala
@@ -17,7 +17,7 @@
package org.apache.spark.ui.storage
-import akka.util.Duration
+import scala.concurrent.duration._
import javax.servlet.http.HttpServletRequest
@@ -28,9 +28,6 @@ import org.apache.spark.ui.JettyUtils._
/** Web UI showing storage status of all RDD's in the given SparkContext. */
private[spark] class BlockManagerUI(val sc: SparkContext) extends Logging {
- implicit val timeout = Duration.create(
- System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
-
val indexPage = new IndexPage(this)
val rddPage = new RDDPage(this)
diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
index d4c5065c3f..3f009a8998 100644
--- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
@@ -17,11 +17,14 @@
package org.apache.spark.util
-import akka.actor.{ActorSystem, ExtendedActorSystem}
+import scala.collection.JavaConversions.mapAsJavaMap
+import scala.concurrent.duration.{Duration, FiniteDuration}
+
+import akka.actor.{ActorSystem, ExtendedActorSystem, IndestructibleActorSystem}
import com.typesafe.config.ConfigFactory
-import akka.util.duration._
-import akka.remote.RemoteActorRefProvider
+import org.apache.log4j.{Level, Logger}
+import org.apache.spark.SparkConf
/**
* Various utility classes for working with Akka.
@@ -34,39 +37,77 @@ private[spark] object AkkaUtils {
*
* Note: the `name` parameter is important, as even if a client sends a message to right
* host + port, if the system name is incorrect, Akka will drop the message.
+ *
+ * If indestructible is set to true, the Actor System will continue running in the event
+ * of a fatal exception. This is used by [[org.apache.spark.executor.Executor]].
*/
- def createActorSystem(name: String, host: String, port: Int): (ActorSystem, Int) = {
- val akkaThreads = System.getProperty("spark.akka.threads", "4").toInt
- val akkaBatchSize = System.getProperty("spark.akka.batchSize", "15").toInt
- val akkaTimeout = System.getProperty("spark.akka.timeout", "60").toInt
- val akkaFrameSize = System.getProperty("spark.akka.frameSize", "10").toInt
- val lifecycleEvents = if (System.getProperty("spark.akka.logLifecycleEvents", "false").toBoolean) "on" else "off"
- // 10 seconds is the default akka timeout, but in a cluster, we need higher by default.
- val akkaWriteTimeout = System.getProperty("spark.akka.writeTimeout", "30").toInt
-
- val akkaConf = ConfigFactory.parseString("""
- akka.daemonic = on
- akka.event-handlers = ["akka.event.slf4j.Slf4jEventHandler"]
- akka.stdout-loglevel = "ERROR"
- akka.actor.provider = "akka.remote.RemoteActorRefProvider"
- akka.remote.transport = "akka.remote.netty.NettyRemoteTransport"
- akka.remote.netty.hostname = "%s"
- akka.remote.netty.port = %d
- akka.remote.netty.connection-timeout = %ds
- akka.remote.netty.message-frame-size = %d MiB
- akka.remote.netty.execution-pool-size = %d
- akka.actor.default-dispatcher.throughput = %d
- akka.remote.log-remote-lifecycle-events = %s
- akka.remote.netty.write-timeout = %ds
- """.format(host, port, akkaTimeout, akkaFrameSize, akkaThreads, akkaBatchSize,
- lifecycleEvents, akkaWriteTimeout))
-
- val actorSystem = ActorSystem(name, akkaConf)
-
- // Figure out the port number we bound to, in case port was passed as 0. This is a bit of a
- // hack because Akka doesn't let you figure out the port through the public API yet.
+ def createActorSystem(name: String, host: String, port: Int, indestructible: Boolean = false,
+ conf: SparkConf): (ActorSystem, Int) = {
+
+ val akkaThreads = conf.get("spark.akka.threads", "4").toInt
+ val akkaBatchSize = conf.get("spark.akka.batchSize", "15").toInt
+
+ val akkaTimeout = conf.get("spark.akka.timeout", "100").toInt
+
+ val akkaFrameSize = conf.get("spark.akka.frameSize", "10").toInt
+ val akkaLogLifecycleEvents = conf.get("spark.akka.logLifecycleEvents", "false").toBoolean
+ val lifecycleEvents = if (akkaLogLifecycleEvents) "on" else "off"
+ if (!akkaLogLifecycleEvents) {
+ // As a workaround for Akka issue #3787, we coerce the "EndpointWriter" log to be silent.
+ // See: https://www.assembla.com/spaces/akka/tickets/3787#/
+ Option(Logger.getLogger("akka.remote.EndpointWriter")).map(l => l.setLevel(Level.FATAL))
+ }
+
+ val logAkkaConfig = if (conf.get("spark.akka.logAkkaConfig", "false").toBoolean) "on" else "off"
+
+ val akkaHeartBeatPauses = conf.get("spark.akka.heartbeat.pauses", "600").toInt
+ val akkaFailureDetector =
+ conf.get("spark.akka.failure-detector.threshold", "300.0").toDouble
+ val akkaHeartBeatInterval = conf.get("spark.akka.heartbeat.interval", "1000").toInt
+
+ val akkaConf = ConfigFactory.parseMap(conf.getAkkaConf.toMap[String, String]).withFallback(
+ ConfigFactory.parseString(
+ s"""
+ |akka.daemonic = on
+ |akka.loggers = [""akka.event.slf4j.Slf4jLogger""]
+ |akka.stdout-loglevel = "ERROR"
+ |akka.jvm-exit-on-fatal-error = off
+ |akka.remote.transport-failure-detector.heartbeat-interval = $akkaHeartBeatInterval s
+ |akka.remote.transport-failure-detector.acceptable-heartbeat-pause = $akkaHeartBeatPauses s
+ |akka.remote.transport-failure-detector.threshold = $akkaFailureDetector
+ |akka.actor.provider = "akka.remote.RemoteActorRefProvider"
+ |akka.remote.netty.tcp.transport-class = "akka.remote.transport.netty.NettyTransport"
+ |akka.remote.netty.tcp.hostname = "$host"
+ |akka.remote.netty.tcp.port = $port
+ |akka.remote.netty.tcp.tcp-nodelay = on
+ |akka.remote.netty.tcp.connection-timeout = $akkaTimeout s
+ |akka.remote.netty.tcp.maximum-frame-size = ${akkaFrameSize}MiB
+ |akka.remote.netty.tcp.execution-pool-size = $akkaThreads
+ |akka.actor.default-dispatcher.throughput = $akkaBatchSize
+ |akka.log-config-on-start = $logAkkaConfig
+ |akka.remote.log-remote-lifecycle-events = $lifecycleEvents
+ |akka.log-dead-letters = $lifecycleEvents
+ |akka.log-dead-letters-during-shutdown = $lifecycleEvents
+ """.stripMargin))
+
+ val actorSystem = if (indestructible) {
+ IndestructibleActorSystem(name, akkaConf)
+ } else {
+ ActorSystem(name, akkaConf)
+ }
+
val provider = actorSystem.asInstanceOf[ExtendedActorSystem].provider
- val boundPort = provider.asInstanceOf[RemoteActorRefProvider].transport.address.port.get
- return (actorSystem, boundPort)
+ val boundPort = provider.getDefaultAddress.port.get
+ (actorSystem, boundPort)
+ }
+
+ /** Returns the default Spark timeout to use for Akka ask operations. */
+ def askTimeout(conf: SparkConf): FiniteDuration = {
+ Duration.create(conf.get("spark.akka.askTimeout", "30").toLong, "seconds")
+ }
+
+ /** Returns the default Spark timeout to use for Akka remote actor lookup. */
+ def lookupTimeout(conf: SparkConf): FiniteDuration = {
+ Duration.create(conf.get("spark.akka.lookupTimeout", "30").toLong, "seconds")
}
}
diff --git a/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala b/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala
index 0b51c23f7b..a38329df03 100644
--- a/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala
+++ b/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala
@@ -34,6 +34,8 @@ class BoundedPriorityQueue[A](maxSize: Int)(implicit ord: Ordering[A])
override def iterator: Iterator[A] = underlying.iterator.asScala
+ override def size: Int = underlying.size
+
override def ++=(xs: TraversableOnce[A]): this.type = {
xs.foreach { this += _ }
this
diff --git a/core/src/main/scala/org/apache/spark/util/IndestructibleActorSystem.scala b/core/src/main/scala/org/apache/spark/util/IndestructibleActorSystem.scala
new file mode 100644
index 0000000000..bf71882ef7
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/IndestructibleActorSystem.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Must be in akka.actor package as ActorSystemImpl is protected[akka].
+package akka.actor
+
+import scala.util.control.{ControlThrowable, NonFatal}
+
+import com.typesafe.config.Config
+
+/**
+ * An [[akka.actor.ActorSystem]] which refuses to shut down in the event of a fatal exception.
+ * This is necessary as Spark Executors are allowed to recover from fatal exceptions
+ * (see [[org.apache.spark.executor.Executor]]).
+ */
+object IndestructibleActorSystem {
+ def apply(name: String, config: Config): ActorSystem =
+ apply(name, config, ActorSystem.findClassLoader())
+
+ def apply(name: String, config: Config, classLoader: ClassLoader): ActorSystem =
+ new IndestructibleActorSystemImpl(name, config, classLoader).start()
+}
+
+private[akka] class IndestructibleActorSystemImpl(
+ override val name: String,
+ applicationConfig: Config,
+ classLoader: ClassLoader)
+ extends ActorSystemImpl(name, applicationConfig, classLoader) {
+
+ protected override def uncaughtExceptionHandler: Thread.UncaughtExceptionHandler = {
+ val fallbackHandler = super.uncaughtExceptionHandler
+
+ new Thread.UncaughtExceptionHandler() {
+ def uncaughtException(thread: Thread, cause: Throwable): Unit = {
+ if (isFatalError(cause) && !settings.JvmExitOnFatalError) {
+ log.error(cause, "Uncaught fatal error from thread [{}] not shutting down " +
+ "ActorSystem [{}] tolerating and continuing.... ", thread.getName, name)
+ //shutdown() //TODO make it configurable
+ } else {
+ fallbackHandler.uncaughtException(thread, cause)
+ }
+ }
+ }
+ }
+
+ def isFatalError(e: Throwable): Boolean = {
+ e match {
+ case NonFatal(_) | _: InterruptedException | _: NotImplementedError | _: ControlThrowable =>
+ false
+ case _ =>
+ true
+ }
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
index 67a7f87a5c..aa7f52cafb 100644
--- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
@@ -18,16 +18,21 @@
package org.apache.spark.util
import java.util.{TimerTask, Timer}
-import org.apache.spark.Logging
+import org.apache.spark.{SparkConf, SparkContext, Logging}
/**
* Runs a timer task to periodically clean up metadata (e.g. old files or hashtable entries)
*/
-class MetadataCleaner(cleanerType: MetadataCleanerType.MetadataCleanerType, cleanupFunc: (Long) => Unit) extends Logging {
+class MetadataCleaner(
+ cleanerType: MetadataCleanerType.MetadataCleanerType,
+ cleanupFunc: (Long) => Unit,
+ conf: SparkConf)
+ extends Logging
+{
val name = cleanerType.toString
- private val delaySeconds = MetadataCleaner.getDelaySeconds
+ private val delaySeconds = MetadataCleaner.getDelaySeconds(conf, cleanerType)
private val periodSeconds = math.max(10, delaySeconds / 10)
private val timer = new Timer(name + " cleanup timer", true)
@@ -55,8 +60,7 @@ class MetadataCleaner(cleanerType: MetadataCleanerType.MetadataCleanerType, clea
}
}
-object MetadataCleanerType extends Enumeration("MapOutputTracker", "SparkContext", "HttpBroadcast", "DagScheduler", "ResultTask",
- "ShuffleMapTask", "BlockManager", "DiskBlockManager", "BroadcastVars") {
+object MetadataCleanerType extends Enumeration {
val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, DAG_SCHEDULER, RESULT_TASK,
SHUFFLE_MAP_TASK, BLOCK_MANAGER, SHUFFLE_BLOCK_MANAGER, BROADCAST_VARS = Value
@@ -66,22 +70,28 @@ object MetadataCleanerType extends Enumeration("MapOutputTracker", "SparkContext
def systemProperty(which: MetadataCleanerType.MetadataCleanerType) = "spark.cleaner.ttl." + which.toString
}
+// TODO: This mutates a Conf to set properties right now, which is kind of ugly when used in the
+// initialization of StreamingContext. It's okay for users trying to configure stuff themselves.
object MetadataCleaner {
+ def getDelaySeconds(conf: SparkConf) = {
+ conf.get("spark.cleaner.ttl", "3500").toInt
+ }
- // using only sys props for now : so that workers can also get to it while preserving earlier behavior.
- def getDelaySeconds = System.getProperty("spark.cleaner.ttl", "-1").toInt
-
- def getDelaySeconds(cleanerType: MetadataCleanerType.MetadataCleanerType): Int = {
- System.getProperty(MetadataCleanerType.systemProperty(cleanerType), getDelaySeconds.toString).toInt
+ def getDelaySeconds(conf: SparkConf, cleanerType: MetadataCleanerType.MetadataCleanerType): Int =
+ {
+ conf.get(MetadataCleanerType.systemProperty(cleanerType), getDelaySeconds(conf).toString)
+ .toInt
}
- def setDelaySeconds(cleanerType: MetadataCleanerType.MetadataCleanerType, delay: Int) {
- System.setProperty(MetadataCleanerType.systemProperty(cleanerType), delay.toString)
+ def setDelaySeconds(conf: SparkConf, cleanerType: MetadataCleanerType.MetadataCleanerType,
+ delay: Int)
+ {
+ conf.set(MetadataCleanerType.systemProperty(cleanerType), delay.toString)
}
- def setDelaySeconds(delay: Int, resetAll: Boolean = true) {
+ def setDelaySeconds(conf: SparkConf, delay: Int, resetAll: Boolean = true) {
// override for all ?
- System.setProperty("spark.cleaner.ttl", delay.toString)
+ conf.set("spark.cleaner.ttl", delay.toString)
if (resetAll) {
for (cleanerType <- MetadataCleanerType.values) {
System.clearProperty(MetadataCleanerType.systemProperty(cleanerType))
diff --git a/core/src/main/scala/org/apache/spark/util/SerializableHyperLogLog.scala b/core/src/main/scala/org/apache/spark/util/SerializableHyperLogLog.scala
new file mode 100644
index 0000000000..8b4e7c104c
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/SerializableHyperLogLog.scala
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.io.{Externalizable, ObjectOutput, ObjectInput}
+import com.clearspring.analytics.stream.cardinality.{ICardinality, HyperLogLog}
+
+/**
+ * A wrapper around [[com.clearspring.analytics.stream.cardinality.HyperLogLog]] that is serializable.
+ */
+private[spark]
+class SerializableHyperLogLog(var value: ICardinality) extends Externalizable {
+
+ def this() = this(null) // For deserialization
+
+ def merge(other: SerializableHyperLogLog) = new SerializableHyperLogLog(value.merge(other.value))
+
+ def add[T](elem: T) = {
+ this.value.offer(elem)
+ this
+ }
+
+ def readExternal(in: ObjectInput) {
+ val byteLength = in.readInt()
+ val bytes = new Array[Byte](byteLength)
+ in.readFully(bytes)
+ value = HyperLogLog.Builder.build(bytes)
+ }
+
+ def writeExternal(out: ObjectOutput) {
+ val bytes = value.getBytes()
+ out.writeInt(bytes.length)
+ out.write(bytes)
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala
index a25b37a2a9..bddb3bb735 100644
--- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala
+++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala
@@ -30,10 +30,10 @@ import java.lang.management.ManagementFactory
import scala.collection.mutable.ArrayBuffer
import it.unimi.dsi.fastutil.ints.IntOpenHashSet
-import org.apache.spark.Logging
+import org.apache.spark.{SparkEnv, SparkConf, SparkContext, Logging}
/**
- * Estimates the sizes of Java objects (number of bytes of memory they occupy), for use in
+ * Estimates the sizes of Java objects (number of bytes of memory they occupy), for use in
* memory-aware caches.
*
* Based on the following JavaWorld article:
@@ -89,9 +89,11 @@ private[spark] object SizeEstimator extends Logging {
classInfos.put(classOf[Object], new ClassInfo(objectSize, Nil))
}
- private def getIsCompressedOops : Boolean = {
+ private def getIsCompressedOops: Boolean = {
+ // This is only used by tests to override the detection of compressed oops. The test
+ // actually uses a system property instead of a SparkConf, so we'll stick with that.
if (System.getProperty("spark.test.useCompressedOops") != null) {
- return System.getProperty("spark.test.useCompressedOops").toBoolean
+ return System.getProperty("spark.test.useCompressedOops").toBoolean
}
try {
@@ -103,7 +105,7 @@ private[spark] object SizeEstimator extends Logging {
val getVMMethod = hotSpotMBeanClass.getDeclaredMethod("getVMOption",
Class.forName("java.lang.String"))
- val bean = ManagementFactory.newPlatformMXBeanProxy(server,
+ val bean = ManagementFactory.newPlatformMXBeanProxy(server,
hotSpotMBeanName, hotSpotMBeanClass)
// TODO: We could use reflection on the VMOption returned ?
return getVMMethod.invoke(bean, "UseCompressedOops").toString.contains("true")
@@ -251,7 +253,7 @@ private[spark] object SizeEstimator extends Logging {
if (info != null) {
return info
}
-
+
val parent = getClassInfo(cls.getSuperclass)
var shellSize = parent.shellSize
var pointerFields = parent.pointerFields
diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala
index 277de2f8a6..181ae2fd45 100644
--- a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala
@@ -85,7 +85,7 @@ class TimeStampedHashMap[A, B] extends Map[A, B]() with Logging {
}
override def filter(p: ((A, B)) => Boolean): Map[A, B] = {
- JavaConversions.asScalaConcurrentMap(internalMap).map(kv => (kv._1, kv._2._1)).filter(p)
+ JavaConversions.mapAsScalaConcurrentMap(internalMap).map(kv => (kv._1, kv._2._1)).filter(p)
}
override def empty: Map[A, B] = new TimeStampedHashMap[A, B]()
@@ -104,19 +104,28 @@ class TimeStampedHashMap[A, B] extends Map[A, B]() with Logging {
def toMap: immutable.Map[A, B] = iterator.toMap
/**
- * Removes old key-value pairs that have timestamp earlier than `threshTime`
+ * Removes old key-value pairs that have timestamp earlier than `threshTime`,
+ * calling the supplied function on each such entry before removing.
*/
- def clearOldValues(threshTime: Long) {
+ def clearOldValues(threshTime: Long, f: (A, B) => Unit) {
val iterator = internalMap.entrySet().iterator()
- while(iterator.hasNext) {
+ while (iterator.hasNext) {
val entry = iterator.next()
if (entry.getValue._2 < threshTime) {
+ f(entry.getKey, entry.getValue._1)
logDebug("Removing key " + entry.getKey)
iterator.remove()
}
}
}
+ /**
+ * Removes old key-value pairs that have timestamp earlier than `threshTime`
+ */
+ def clearOldValues(threshTime: Long) {
+ clearOldValues(threshTime, (_, _) => ())
+ }
+
private def currentTime: Long = System.currentTimeMillis()
}
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index a79e64e810..5f1253100b 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -22,10 +22,11 @@ import java.net.{InetAddress, URL, URI, NetworkInterface, Inet4Address}
import java.util.{Locale, Random, UUID}
import java.util.concurrent.{ConcurrentHashMap, Executors, ThreadPoolExecutor}
+import scala.collection.JavaConversions._
import scala.collection.Map
import scala.collection.mutable.ArrayBuffer
-import scala.collection.JavaConversions._
import scala.io.Source
+import scala.reflect.ClassTag
import com.google.common.io.Files
import com.google.common.util.concurrent.ThreadFactoryBuilder
@@ -35,14 +36,13 @@ import org.apache.hadoop.fs.{Path, FileSystem, FileUtil}
import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance}
import org.apache.spark.deploy.SparkHadoopUtil
import java.nio.ByteBuffer
-import org.apache.spark.{SparkException, Logging}
+import org.apache.spark.{SparkConf, SparkContext, SparkException, Logging}
/**
* Various utility methods used by Spark.
*/
private[spark] object Utils extends Logging {
-
/** Serialize an object using Java serialization */
def serialize[T](o: T): Array[Byte] = {
val bos = new ByteArrayOutputStream()
@@ -238,9 +238,9 @@ private[spark] object Utils extends Logging {
* Throws SparkException if the target file already exists and has different contents than
* the requested file.
*/
- def fetchFile(url: String, targetDir: File) {
+ def fetchFile(url: String, targetDir: File, conf: SparkConf) {
val filename = url.split("/").last
- val tempDir = getLocalDir
+ val tempDir = getLocalDir(conf)
val tempFile = File.createTempFile("fetchFileTemp", null, new File(tempDir))
val targetFile = new File(targetDir, filename)
val uri = new URI(url)
@@ -310,8 +310,8 @@ private[spark] object Utils extends Logging {
* return a single directory, even though the spark.local.dir property might be a list of
* multiple paths.
*/
- def getLocalDir: String = {
- System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")).split(',')(0)
+ def getLocalDir(conf: SparkConf): String = {
+ conf.get("spark.local.dir", System.getProperty("java.io.tmpdir")).split(',')(0)
}
/**
@@ -319,7 +319,7 @@ private[spark] object Utils extends Logging {
* result in a new collection. Unlike scala.util.Random.shuffle, this method
* uses a local random number generator, avoiding inter-thread contention.
*/
- def randomize[T: ClassManifest](seq: TraversableOnce[T]): Seq[T] = {
+ def randomize[T: ClassTag](seq: TraversableOnce[T]): Seq[T] = {
randomizeInPlace(seq.toArray)
}
@@ -396,13 +396,12 @@ private[spark] object Utils extends Logging {
InetAddress.getByName(address).getHostName
}
- def localHostPort(): String = {
- val retval = System.getProperty("spark.hostPort", null)
+ def localHostPort(conf: SparkConf): String = {
+ val retval = conf.get("spark.hostPort", null)
if (retval == null) {
logErrorWithStack("spark.hostPort not set but invoking localHostPort")
return localHostName()
}
-
retval
}
@@ -414,9 +413,12 @@ private[spark] object Utils extends Logging {
assert(hostPort.indexOf(':') != -1, message)
}
- // Used by DEBUG code : remove when all testing done
def logErrorWithStack(msg: String) {
- try { throw new Exception } catch { case ex: Exception => { logError(msg, ex) } }
+ try {
+ throw new Exception
+ } catch {
+ case ex: Exception => logError(msg, ex)
+ }
}
// Typically, this will be of order of number of nodes in cluster
@@ -836,7 +838,7 @@ private[spark] object Utils extends Logging {
}
}
- /**
+ /**
* Timing method based on iterations that permit JVM JIT optimization.
* @param numIters number of iterations
* @param f function to be executed
diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala
index 1e9faaa5a0..a7a6635dec 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala
@@ -17,6 +17,7 @@
package org.apache.spark.util.collection
+import scala.reflect.ClassTag
/**
* A fast hash map implementation for nullable keys. This hash map supports insertions and updates,
@@ -26,7 +27,7 @@ package org.apache.spark.util.collection
* Under the hood, it uses our OpenHashSet implementation.
*/
private[spark]
-class OpenHashMap[K >: Null : ClassManifest, @specialized(Long, Int, Double) V: ClassManifest](
+class OpenHashMap[K >: Null : ClassTag, @specialized(Long, Int, Double) V: ClassTag](
val keySet: OpenHashSet[K], var _values: Array[V])
extends Iterable[(K, V)]
with Serializable {
@@ -169,4 +170,3 @@ class OpenHashMap[K >: Null : ClassManifest, @specialized(Long, Int, Double) V:
_values(newPos) = _oldValues(oldPos)
}
}
-
diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
index 36e2a05b9c..c908512b0f 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
@@ -17,6 +17,7 @@
package org.apache.spark.util.collection
+import scala.reflect._
/**
* A simple, fast hash set optimized for non-null insertion-only use case, where keys are never
@@ -36,7 +37,7 @@ package org.apache.spark.util.collection
* to explore all spaces for each key (see http://en.wikipedia.org/wiki/Quadratic_probing).
*/
private[spark]
-class OpenHashSet[@specialized(Long, Int) T: ClassManifest](
+class OpenHashSet[@specialized(Long, Int) T: ClassTag](
initialCapacity: Int,
loadFactor: Double)
extends Serializable {
@@ -62,14 +63,14 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest](
// throws:
// scala.tools.nsc.symtab.Types$TypeError: type mismatch;
// found : scala.reflect.AnyValManifest[Long]
- // required: scala.reflect.ClassManifest[Int]
+ // required: scala.reflect.ClassTag[Int]
// at scala.tools.nsc.typechecker.Contexts$Context.error(Contexts.scala:298)
// at scala.tools.nsc.typechecker.Infer$Inferencer.error(Infer.scala:207)
// ...
- val mt = classManifest[T]
- if (mt == ClassManifest.Long) {
+ val mt = classTag[T]
+ if (mt == ClassTag.Long) {
(new LongHasher).asInstanceOf[Hasher[T]]
- } else if (mt == ClassManifest.Int) {
+ } else if (mt == ClassTag.Int) {
(new IntHasher).asInstanceOf[Hasher[T]]
} else {
new Hasher[T]
@@ -79,6 +80,7 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest](
protected var _capacity = nextPowerOf2(initialCapacity)
protected var _mask = _capacity - 1
protected var _size = 0
+ protected var _growThreshold = (loadFactor * _capacity).toInt
protected var _bitset = new BitSet(_capacity)
@@ -117,7 +119,29 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest](
* @return The position where the key is placed, plus the highest order bit is set if the key
* exists previously.
*/
- def addWithoutResize(k: T): Int = putInto(_bitset, _data, k)
+ def addWithoutResize(k: T): Int = {
+ var pos = hashcode(hasher.hash(k)) & _mask
+ var i = 1
+ while (true) {
+ if (!_bitset.get(pos)) {
+ // This is a new key.
+ _data(pos) = k
+ _bitset.set(pos)
+ _size += 1
+ return pos | NONEXISTENCE_MASK
+ } else if (_data(pos) == k) {
+ // Found an existing key.
+ return pos
+ } else {
+ val delta = i
+ pos = (pos + delta) & _mask
+ i += 1
+ }
+ }
+ // Never reached here
+ assert(INVALID_POS != INVALID_POS)
+ INVALID_POS
+ }
/**
* Rehash the set if it is overloaded.
@@ -128,7 +152,7 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest](
* to a new position (in the new data array).
*/
def rehashIfNeeded(k: T, allocateFunc: (Int) => Unit, moveFunc: (Int, Int) => Unit) {
- if (_size > loadFactor * _capacity) {
+ if (_size > _growThreshold) {
rehash(k, allocateFunc, moveFunc)
}
}
@@ -181,37 +205,6 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest](
def nextPos(fromPos: Int): Int = _bitset.nextSetBit(fromPos)
/**
- * Put an entry into the set. Return the position where the key is placed. In addition, the
- * highest bit in the returned position is set if the key exists prior to this put.
- *
- * This function assumes the data array has at least one empty slot.
- */
- private def putInto(bitset: BitSet, data: Array[T], k: T): Int = {
- val mask = data.length - 1
- var pos = hashcode(hasher.hash(k)) & mask
- var i = 1
- while (true) {
- if (!bitset.get(pos)) {
- // This is a new key.
- data(pos) = k
- bitset.set(pos)
- _size += 1
- return pos | NONEXISTENCE_MASK
- } else if (data(pos) == k) {
- // Found an existing key.
- return pos
- } else {
- val delta = i
- pos = (pos + delta) & mask
- i += 1
- }
- }
- // Never reached here
- assert(INVALID_POS != INVALID_POS)
- INVALID_POS
- }
-
- /**
* Double the table's size and re-hash everything. We are not really using k, but it is declared
* so Scala compiler can specialize this method (which leads to calling the specialized version
* of putInto).
@@ -224,33 +217,49 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest](
*/
private def rehash(k: T, allocateFunc: (Int) => Unit, moveFunc: (Int, Int) => Unit) {
val newCapacity = _capacity * 2
- require(newCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements")
-
allocateFunc(newCapacity)
- val newData = new Array[T](newCapacity)
val newBitset = new BitSet(newCapacity)
- var pos = 0
- _size = 0
- while (pos < _capacity) {
- if (_bitset.get(pos)) {
- val newPos = putInto(newBitset, newData, _data(pos))
- moveFunc(pos, newPos & POSITION_MASK)
+ val newData = new Array[T](newCapacity)
+ val newMask = newCapacity - 1
+
+ var oldPos = 0
+ while (oldPos < capacity) {
+ if (_bitset.get(oldPos)) {
+ val key = _data(oldPos)
+ var newPos = hashcode(hasher.hash(key)) & newMask
+ var i = 1
+ var keepGoing = true
+ // No need to check for equality here when we insert so this has one less if branch than
+ // the similar code path in addWithoutResize.
+ while (keepGoing) {
+ if (!newBitset.get(newPos)) {
+ // Inserting the key at newPos
+ newData(newPos) = key
+ newBitset.set(newPos)
+ moveFunc(oldPos, newPos)
+ keepGoing = false
+ } else {
+ val delta = i
+ newPos = (newPos + delta) & newMask
+ i += 1
+ }
+ }
}
- pos += 1
+ oldPos += 1
}
+
_bitset = newBitset
_data = newData
_capacity = newCapacity
- _mask = newCapacity - 1
+ _mask = newMask
+ _growThreshold = (loadFactor * newCapacity).toInt
}
/**
- * Re-hash a value to deal better with hash functions that don't differ
- * in the lower bits, similar to java.util.HashMap
+ * Re-hash a value to deal better with hash functions that don't differ in the lower bits.
+ * We use the Murmur Hash 3 finalization step that's also used in fastutil.
*/
- private def hashcode(h: Int): Int = {
- it.unimi.dsi.fastutil.HashCommon.murmurHash3(h)
- }
+ private def hashcode(h: Int): Int = it.unimi.dsi.fastutil.HashCommon.murmurHash3(h)
private def nextPowerOf2(n: Int): Int = {
val highBit = Integer.highestOneBit(n)
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala
index d6a3cdb405..1dc9f744e1 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala
@@ -17,6 +17,7 @@
package org.apache.spark.util.collection
+import scala.reflect._
/**
* A fast hash map implementation for primitive, non-null keys. This hash map supports
@@ -26,8 +27,8 @@ package org.apache.spark.util.collection
* Under the hood, it uses our OpenHashSet implementation.
*/
private[spark]
-class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassManifest,
- @specialized(Long, Int, Double) V: ClassManifest](
+class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag,
+ @specialized(Long, Int, Double) V: ClassTag](
val keySet: OpenHashSet[K], var _values: Array[V])
extends Iterable[(K, V)]
with Serializable {
@@ -37,20 +38,19 @@ class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassManifest,
*/
def this(initialCapacity: Int) =
this(new OpenHashSet[K](initialCapacity), new Array[V](initialCapacity))
-
+
/**
* Allocate an OpenHashMap with a default initial capacity, providing a true
* no-argument constructor.
*/
def this() = this(64)
-
/**
* Allocate an OpenHashMap with a fixed initial capacity
*/
def this(keySet: OpenHashSet[K]) = this(keySet, new Array[V](keySet.capacity))
- require(classManifest[K] == classManifest[Long] || classManifest[K] == classManifest[Int])
+ require(classTag[K] == classTag[Long] || classTag[K] == classTag[Int])
private var _oldValues: Array[V] = null
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala
index 20554f0aab..b84eb65c62 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala
@@ -17,11 +17,13 @@
package org.apache.spark.util.collection
+import scala.reflect.ClassTag
+
/**
* An append-only, non-threadsafe, array-backed vector that is optimized for primitive types.
*/
private[spark]
-class PrimitiveVector[@specialized(Long, Int, Double) V: ClassManifest](initialSize: Int = 64) {
+class PrimitiveVector[@specialized(Long, Int, Double) V: ClassTag](initialSize: Int = 64) {
private var _numElements = 0
private var _array: Array[V] = _