aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--core/pom.xml4
-rw-r--r--core/src/main/scala/spark/BlockStoreShuffleFetcher.scala1
-rw-r--r--core/src/main/scala/spark/ClosureCleaner.scala11
-rw-r--r--core/src/main/scala/spark/PairRDDFunctions.scala44
-rw-r--r--core/src/main/scala/spark/RDD.scala70
-rw-r--r--core/src/main/scala/spark/SequenceFileRDDFunctions.scala15
-rw-r--r--core/src/main/scala/spark/SparkContext.scala5
-rw-r--r--core/src/main/scala/spark/SparkEnv.scala12
-rw-r--r--core/src/main/scala/spark/Utils.scala18
-rw-r--r--core/src/main/scala/spark/api/java/JavaPairRDD.scala11
-rw-r--r--core/src/main/scala/spark/api/java/JavaRDD.scala1
-rw-r--r--core/src/main/scala/spark/api/java/JavaRDDLike.scala35
-rw-r--r--core/src/main/scala/spark/api/python/PythonRDD.scala76
-rw-r--r--core/src/main/scala/spark/api/python/PythonWorkerFactory.scala95
-rw-r--r--core/src/main/scala/spark/executor/Executor.scala1
-rw-r--r--core/src/main/scala/spark/executor/TaskMetrics.scala12
-rw-r--r--core/src/main/scala/spark/rdd/CoGroupedRDD.scala10
-rw-r--r--core/src/main/scala/spark/rdd/PipedRDD.scala27
-rw-r--r--core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala12
-rw-r--r--core/src/main/scala/spark/scheduler/DAGScheduler.scala10
-rw-r--r--core/src/main/scala/spark/scheduler/JobLogger.scala306
-rw-r--r--core/src/main/scala/spark/scheduler/SparkListener.scala50
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala2
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala747
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala748
-rw-r--r--core/src/main/scala/spark/scheduler/local/LocalScheduler.scala227
-rw-r--r--core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala172
-rw-r--r--core/src/main/scala/spark/storage/DiskStore.scala34
-rw-r--r--core/src/main/scala/spark/util/BoundedPriorityQueue.scala45
-rw-r--r--core/src/main/scala/spark/util/StatCounter.scala26
-rw-r--r--core/src/test/scala/spark/FileSuite.scala46
-rw-r--r--core/src/test/scala/spark/JavaAPISuite.java45
-rw-r--r--core/src/test/scala/spark/PartitioningSuite.scala21
-rw-r--r--core/src/test/scala/spark/PipedRDDSuite.scala39
-rw-r--r--core/src/test/scala/spark/RDDSuite.scala19
-rw-r--r--core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala2
-rw-r--r--core/src/test/scala/spark/scheduler/JobLoggerSuite.scala105
-rw-r--r--core/src/test/scala/spark/scheduler/LocalSchedulerSuite.scala206
-rw-r--r--core/src/test/scala/spark/scheduler/SparkListenerSuite.scala2
-rw-r--r--examples/pom.xml45
-rw-r--r--examples/src/main/scala/spark/examples/CassandraTest.scala196
-rw-r--r--examples/src/main/scala/spark/examples/HBaseTest.scala35
-rw-r--r--pom.xml14
-rw-r--r--project/SparkBuild.scala26
-rw-r--r--project/plugins.sbt2
-rw-r--r--python/pyspark/daemon.py158
-rw-r--r--python/pyspark/serializers.py4
-rw-r--r--python/pyspark/tests.py43
-rw-r--r--python/pyspark/worker.py55
-rw-r--r--repl/src/main/scala/spark/repl/ExecutorClassLoader.scala3
-rwxr-xr-xrun8
51 files changed, 2918 insertions, 983 deletions
diff --git a/core/pom.xml b/core/pom.xml
index d8687bf991..88f0ed70f3 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -32,8 +32,8 @@
<artifactId>compress-lzf</artifactId>
</dependency>
<dependency>
- <groupId>asm</groupId>
- <artifactId>asm-all</artifactId>
+ <groupId>org.ow2.asm</groupId>
+ <artifactId>asm</artifactId>
</dependency>
<dependency>
<groupId>com.google.protobuf</groupId>
diff --git a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
index e1fb02157a..3239f4c385 100644
--- a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
+++ b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
@@ -58,6 +58,7 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
CompletionIterator[(K,V), Iterator[(K,V)]](itr, {
val shuffleMetrics = new ShuffleReadMetrics
+ shuffleMetrics.shuffleFinishTime = System.currentTimeMillis
shuffleMetrics.remoteFetchTime = blockFetcherItr.remoteFetchTime
shuffleMetrics.fetchWaitTime = blockFetcherItr.fetchWaitTime
shuffleMetrics.remoteBytesRead = blockFetcherItr.remoteBytesRead
diff --git a/core/src/main/scala/spark/ClosureCleaner.scala b/core/src/main/scala/spark/ClosureCleaner.scala
index 50d6a1c5c9..d5e7132ff9 100644
--- a/core/src/main/scala/spark/ClosureCleaner.scala
+++ b/core/src/main/scala/spark/ClosureCleaner.scala
@@ -5,8 +5,7 @@ import java.lang.reflect.Field
import scala.collection.mutable.Map
import scala.collection.mutable.Set
-import org.objectweb.asm.{ClassReader, MethodVisitor, Type}
-import org.objectweb.asm.commons.EmptyVisitor
+import org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Type}
import org.objectweb.asm.Opcodes._
import java.io.{InputStream, IOException, ByteArrayOutputStream, ByteArrayInputStream, BufferedInputStream}
@@ -162,10 +161,10 @@ private[spark] object ClosureCleaner extends Logging {
}
}
-private[spark] class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends EmptyVisitor {
+private[spark] class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends ClassVisitor(ASM4) {
override def visitMethod(access: Int, name: String, desc: String,
sig: String, exceptions: Array[String]): MethodVisitor = {
- return new EmptyVisitor {
+ return new MethodVisitor(ASM4) {
override def visitFieldInsn(op: Int, owner: String, name: String, desc: String) {
if (op == GETFIELD) {
for (cl <- output.keys if cl.getName == owner.replace('/', '.')) {
@@ -188,7 +187,7 @@ private[spark] class FieldAccessFinder(output: Map[Class[_], Set[String]]) exten
}
}
-private[spark] class InnerClosureFinder(output: Set[Class[_]]) extends EmptyVisitor {
+private[spark] class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM4) {
var myName: String = null
override def visit(version: Int, access: Int, name: String, sig: String,
@@ -198,7 +197,7 @@ private[spark] class InnerClosureFinder(output: Set[Class[_]]) extends EmptyVisi
override def visitMethod(access: Int, name: String, desc: String,
sig: String, exceptions: Array[String]): MethodVisitor = {
- return new EmptyVisitor {
+ return new MethodVisitor(ASM4) {
override def visitMethodInsn(op: Int, owner: String, name: String,
desc: String) {
val argTypes = Type.getArgumentTypes(desc)
diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala
index 2b0e697337..fa4bbfc76f 100644
--- a/core/src/main/scala/spark/PairRDDFunctions.scala
+++ b/core/src/main/scala/spark/PairRDDFunctions.scala
@@ -10,6 +10,8 @@ import scala.collection.JavaConversions._
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
+import org.apache.hadoop.io.compress.CompressionCodec
+import org.apache.hadoop.io.SequenceFile.CompressionType
import org.apache.hadoop.mapred.FileOutputCommitter
import org.apache.hadoop.mapred.FileOutputFormat
import org.apache.hadoop.mapred.HadoopWriter
@@ -17,7 +19,7 @@ import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.mapred.OutputFormat
import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat}
-import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter, Job => NewAPIHadoopJob, HadoopMapReduceUtil, TaskAttemptID, TaskAttemptContext}
+import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter, Job => NewAPIHadoopJob, HadoopMapReduceUtil}
import spark.partial.BoundedDouble
import spark.partial.PartialResult
@@ -185,11 +187,13 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
* partitioning of the resulting key-value pair RDD by passing a Partitioner.
*/
def groupByKey(partitioner: Partitioner): RDD[(K, Seq[V])] = {
+ // groupByKey shouldn't use map side combine because map side combine does not
+ // reduce the amount of data shuffled and requires all map side data be inserted
+ // into a hash table, leading to more objects in the old gen.
def createCombiner(v: V) = ArrayBuffer(v)
def mergeValue(buf: ArrayBuffer[V], v: V) = buf += v
- def mergeCombiners(b1: ArrayBuffer[V], b2: ArrayBuffer[V]) = b1 ++= b2
val bufs = combineByKey[ArrayBuffer[V]](
- createCombiner _, mergeValue _, mergeCombiners _, partitioner)
+ createCombiner _, mergeValue _, null, partitioner, mapSideCombine=false)
bufs.asInstanceOf[RDD[(K, Seq[V])]]
}
@@ -516,6 +520,16 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
}
/**
+ * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class
+ * supporting the key and value types K and V in this RDD. Compress the result with the
+ * supplied codec.
+ */
+ def saveAsHadoopFile[F <: OutputFormat[K, V]](
+ path: String, codec: Class[_ <: CompressionCodec]) (implicit fm: ClassManifest[F]) {
+ saveAsHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]], codec)
+ }
+
+ /**
* Output the RDD to any Hadoop-supported file system, using a new Hadoop API `OutputFormat`
* (mapreduce.OutputFormat) object supporting the key and value types K and V in this RDD.
*/
@@ -576,6 +590,20 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
/**
* Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class
+ * supporting the key and value types K and V in this RDD. Compress with the supplied codec.
+ */
+ def saveAsHadoopFile(
+ path: String,
+ keyClass: Class[_],
+ valueClass: Class[_],
+ outputFormatClass: Class[_ <: OutputFormat[_, _]],
+ codec: Class[_ <: CompressionCodec]) {
+ saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass,
+ new JobConf(self.context.hadoopConfiguration), Some(codec))
+ }
+
+ /**
+ * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class
* supporting the key and value types K and V in this RDD.
*/
def saveAsHadoopFile(
@@ -583,11 +611,19 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
keyClass: Class[_],
valueClass: Class[_],
outputFormatClass: Class[_ <: OutputFormat[_, _]],
- conf: JobConf = new JobConf(self.context.hadoopConfiguration)) {
+ conf: JobConf = new JobConf(self.context.hadoopConfiguration),
+ codec: Option[Class[_ <: CompressionCodec]] = None) {
conf.setOutputKeyClass(keyClass)
conf.setOutputValueClass(valueClass)
// conf.setOutputFormat(outputFormatClass) // Doesn't work in Scala 2.9 due to what may be a generics bug
conf.set("mapred.output.format.class", outputFormatClass.getName)
+ for (c <- codec) {
+ conf.setCompressMapOutput(true)
+ conf.set("mapred.output.compress", "true")
+ conf.setMapOutputCompressorClass(c)
+ conf.set("mapred.output.compression.codec", c.getCanonicalName)
+ conf.set("mapred.output.compression.type", CompressionType.BLOCK.toString)
+ }
conf.setOutputCommitter(classOf[FileOutputCommitter])
FileOutputFormat.setOutputPath(conf, HadoopWriter.createPathFromString(path, conf))
saveAsHadoopDataset(conf)
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index e6c0438d76..f336c2ea1e 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -7,12 +7,14 @@ import scala.collection.JavaConversions.mapAsScalaMap
import scala.collection.mutable.ArrayBuffer
import org.apache.hadoop.io.BytesWritable
+import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.hadoop.io.NullWritable
import org.apache.hadoop.io.Text
import org.apache.hadoop.mapred.TextOutputFormat
import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap}
+import spark.broadcast.Broadcast
import spark.Partitioner._
import spark.partial.BoundedDouble
import spark.partial.CountEvaluator
@@ -35,6 +37,7 @@ import spark.rdd.ZippedPartitionsRDD2
import spark.rdd.ZippedPartitionsRDD3
import spark.rdd.ZippedPartitionsRDD4
import spark.storage.StorageLevel
+import spark.util.BoundedPriorityQueue
import SparkContext._
@@ -114,6 +117,14 @@ abstract class RDD[T: ClassManifest](
this
}
+ /** Name of the user class that created this RDD; defaults to the first user class at the call site. */
+ var generator = Utils.getCallSiteInfo.firstUserClass
+
+ /** Override the generator name recorded for this RDD. */
+ def setGenerator(_generator: String) = {
+ generator = _generator
+ }
+
/**
* Set this RDD's storage level to persist its values across operations after the first time
* it is computed. This can only be used to assign a new storage level if the RDD does not
@@ -352,13 +363,36 @@ abstract class RDD[T: ClassManifest](
/**
* Return an RDD created by piping elements to a forked external process.
*/
- def pipe(command: Seq[String]): RDD[String] = new PipedRDD(this, command)
+ def pipe(command: String, env: Map[String, String]): RDD[String] =
+ new PipedRDD(this, command, env)
+
/**
* Return an RDD created by piping elements to a forked external process.
- */
- def pipe(command: Seq[String], env: Map[String, String]): RDD[String] =
- new PipedRDD(this, command, env)
+ * The print behavior can be customized by providing two functions.
+ *
+ * @param command command to run in forked process.
+ * @param env environment variables to set.
+ * @param printPipeContext Before piping elements, this function is called as an opportunity
+ * to pipe context data. Print line function (like out.println) will be
+ * passed as printPipeContext's parameter.
+ * @param printRDDElement Use this function to customize how to pipe elements. This function
+ * will be called with each RDD element as the 1st parameter, and the
+ * print line function (like out.println()) as the 2nd parameter.
+ * An example of piping the RDD data of groupBy() in a streaming way,
+ * instead of constructing a huge String to concat all the elements:
+ * def printRDDElement(record:(String, Seq[String]), f:String=>Unit) =
+ * for (e <- record._2){f(e)}
+ * @return the result RDD
+ */
+ def pipe(
+ command: Seq[String],
+ env: Map[String, String] = Map(),
+ printPipeContext: (String => Unit) => Unit = null,
+ printRDDElement: (T, String => Unit) => Unit = null): RDD[String] =
+ new PipedRDD(this, command, env,
+ if (printPipeContext ne null) sc.clean(printPipeContext) else null,
+ if (printRDDElement ne null) sc.clean(printRDDElement) else null)
/**
* Return a new RDD by applying a function to each partition of this RDD.
@@ -723,6 +757,24 @@ abstract class RDD[T: ClassManifest](
}
/**
+ * Returns the top K elements from this RDD as defined by
+ * the specified implicit Ordering[T].
+ * @param num the number of top elements to return
+ * @param ord the implicit ordering for T
+ * @return an array of top elements
+ */
+ def top(num: Int)(implicit ord: Ordering[T]): Array[T] = {
+ mapPartitions { items =>
+ val queue = new BoundedPriorityQueue[T](num)
+ queue ++= items
+ Iterator.single(queue)
+ }.reduce { (queue1, queue2) =>
+ queue1 ++= queue2
+ queue1
+ }.toArray
+ }
+
+ /**
* Save this RDD as a text file, using string representations of elements.
*/
def saveAsTextFile(path: String) {
@@ -731,6 +783,14 @@ abstract class RDD[T: ClassManifest](
}
/**
+ * Save this RDD as a compressed text file, using string representations of elements.
+ */
+ def saveAsTextFile(path: String, codec: Class[_ <: CompressionCodec]) {
+ this.map(x => (NullWritable.get(), new Text(x.toString)))
+ .saveAsHadoopFile[TextOutputFormat[NullWritable, Text]](path, codec)
+ }
+
+ /**
* Save this RDD as a SequenceFile of serialized objects.
*/
def saveAsObjectFile(path: String) {
@@ -788,7 +848,7 @@ abstract class RDD[T: ClassManifest](
private var storageLevel: StorageLevel = StorageLevel.NONE
/** Record user function generating this RDD. */
- private[spark] val origin = Utils.getSparkCallSite
+ private[spark] val origin = Utils.formatSparkCallSite
private[spark] def elementClassManifest: ClassManifest[T] = classManifest[T]
diff --git a/core/src/main/scala/spark/SequenceFileRDDFunctions.scala b/core/src/main/scala/spark/SequenceFileRDDFunctions.scala
index 518034e07b..2911f9036e 100644
--- a/core/src/main/scala/spark/SequenceFileRDDFunctions.scala
+++ b/core/src/main/scala/spark/SequenceFileRDDFunctions.scala
@@ -18,6 +18,7 @@ import org.apache.hadoop.mapred.TextOutputFormat
import org.apache.hadoop.mapred.SequenceFileOutputFormat
import org.apache.hadoop.mapred.OutputCommitter
import org.apache.hadoop.mapred.FileOutputCommitter
+import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.hadoop.io.Writable
import org.apache.hadoop.io.NullWritable
import org.apache.hadoop.io.BytesWritable
@@ -62,7 +63,7 @@ class SequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable : Cla
* byte arrays to BytesWritable, and Strings to Text. The `path` can be on any Hadoop-supported
* file system.
*/
- def saveAsSequenceFile(path: String) {
+ def saveAsSequenceFile(path: String, codec: Option[Class[_ <: CompressionCodec]] = None) {
def anyToWritable[U <% Writable](u: U): Writable = u
val keyClass = getWritableClass[K]
@@ -72,14 +73,18 @@ class SequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable : Cla
logInfo("Saving as sequence file of type (" + keyClass.getSimpleName + "," + valueClass.getSimpleName + ")" )
val format = classOf[SequenceFileOutputFormat[Writable, Writable]]
+ val jobConf = new JobConf(self.context.hadoopConfiguration)
if (!convertKey && !convertValue) {
- self.saveAsHadoopFile(path, keyClass, valueClass, format)
+ self.saveAsHadoopFile(path, keyClass, valueClass, format, jobConf, codec)
} else if (!convertKey && convertValue) {
- self.map(x => (x._1,anyToWritable(x._2))).saveAsHadoopFile(path, keyClass, valueClass, format)
+ self.map(x => (x._1,anyToWritable(x._2))).saveAsHadoopFile(
+ path, keyClass, valueClass, format, jobConf, codec)
} else if (convertKey && !convertValue) {
- self.map(x => (anyToWritable(x._1),x._2)).saveAsHadoopFile(path, keyClass, valueClass, format)
+ self.map(x => (anyToWritable(x._1),x._2)).saveAsHadoopFile(
+ path, keyClass, valueClass, format, jobConf, codec)
} else if (convertKey && convertValue) {
- self.map(x => (anyToWritable(x._1),anyToWritable(x._2))).saveAsHadoopFile(path, keyClass, valueClass, format)
+ self.map(x => (anyToWritable(x._1),anyToWritable(x._2))).saveAsHadoopFile(
+ path, keyClass, valueClass, format, jobConf, codec)
}
}
}
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index bc05d08fd6..70a9d7698c 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -49,7 +49,6 @@ import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend
import spark.storage.{BlockManagerUI, StorageStatus, StorageUtils, RDDInfo}
import spark.util.{MetadataCleaner, TimeStampedHashMap}
-
/**
* Main entry point for Spark functionality. A SparkContext represents the connection to a Spark
* cluster, and can be used to create RDDs, accumulators and broadcast variables on that cluster.
@@ -630,7 +629,7 @@ class SparkContext(
partitions: Seq[Int],
allowLocal: Boolean,
resultHandler: (Int, U) => Unit) {
- val callSite = Utils.getSparkCallSite
+ val callSite = Utils.formatSparkCallSite
logInfo("Starting job: " + callSite)
val start = System.nanoTime
val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal, resultHandler, localProperties.value)
@@ -713,7 +712,7 @@ class SparkContext(
func: (TaskContext, Iterator[T]) => U,
evaluator: ApproximateEvaluator[U, R],
timeout: Long): PartialResult[R] = {
- val callSite = Utils.getSparkCallSite
+ val callSite = Utils.formatSparkCallSite
logInfo("Starting job: " + callSite)
val start = System.nanoTime
val result = dagScheduler.runApproximateJob(rdd, func, evaluator, callSite, timeout, localProperties.value)
diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala
index be1a04d619..7ccde2e818 100644
--- a/core/src/main/scala/spark/SparkEnv.scala
+++ b/core/src/main/scala/spark/SparkEnv.scala
@@ -1,5 +1,8 @@
package spark
+import collection.mutable
+import serializer.Serializer
+
import akka.actor.{Actor, ActorRef, Props, ActorSystemImpl, ActorSystem}
import akka.remote.RemoteActorRefProvider
@@ -9,6 +12,7 @@ import spark.storage.BlockManagerMaster
import spark.network.ConnectionManager
import spark.serializer.{Serializer, SerializerManager}
import spark.util.AkkaUtils
+import spark.api.python.PythonWorkerFactory
/**
@@ -37,7 +41,10 @@ class SparkEnv (
// If executorId is NOT found, return defaultHostPort
var executorIdToHostPort: Option[(String, String) => String]) {
+ private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]()
+
def stop() {
+ pythonWorkers.foreach { case(key, worker) => worker.stop() }
httpFileServer.stop()
mapOutputTracker.stop()
shuffleFetcher.stop()
@@ -50,6 +57,11 @@ class SparkEnv (
actorSystem.awaitTermination()
}
+ def createPythonWorker(pythonExec: String, envVars: Map[String, String]): java.net.Socket = {
+ synchronized {
+ pythonWorkers.getOrElseUpdate((pythonExec, envVars), new PythonWorkerFactory(pythonExec, envVars)).create()
+ }
+ }
def resolveExecutorIdToHostPort(executorId: String, defaultHostPort: String): String = {
val env = SparkEnv.get
diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala
index ec15326014..f3621c6bee 100644
--- a/core/src/main/scala/spark/Utils.scala
+++ b/core/src/main/scala/spark/Utils.scala
@@ -116,8 +116,8 @@ private object Utils extends Logging {
while (dir == null) {
attempts += 1
if (attempts > maxAttempts) {
- throw new IOException("Failed to create a temp directory after " + maxAttempts +
- " attempts!")
+ throw new IOException("Failed to create a temp directory (under " + root + ") after " +
+ maxAttempts + " attempts!")
}
try {
dir = new File(root, "spark-" + UUID.randomUUID.toString)
@@ -522,13 +522,14 @@ private object Utils extends Logging {
execute(command, new File("."))
}
-
+ private[spark] class CallSiteInfo(val lastSparkMethod: String, val firstUserFile: String,
+ val firstUserLine: Int, val firstUserClass: String)
/**
* When called inside a class in the spark package, returns the name of the user code class
* (outside the spark package) that called into Spark, as well as which Spark method they called.
* This is used, for example, to tell users where in their code each RDD got created.
*/
- def getSparkCallSite: String = {
+ def getCallSiteInfo: CallSiteInfo = {
val trace = Thread.currentThread.getStackTrace().filter( el =>
(!el.getMethodName.contains("getStackTrace")))
@@ -540,6 +541,7 @@ private object Utils extends Logging {
var firstUserFile = "<unknown>"
var firstUserLine = 0
var finished = false
+ var firstUserClass = "<unknown>"
for (el <- trace) {
if (!finished) {
@@ -554,13 +556,19 @@ private object Utils extends Logging {
else {
firstUserLine = el.getLineNumber
firstUserFile = el.getFileName
+ firstUserClass = el.getClassName
finished = true
}
}
}
- "%s at %s:%s".format(lastSparkMethod, firstUserFile, firstUserLine)
+ new CallSiteInfo(lastSparkMethod, firstUserFile, firstUserLine, firstUserClass)
}
+ def formatSparkCallSite = {
+ val callSiteInfo = getCallSiteInfo
+ "%s at %s:%s".format(callSiteInfo.lastSparkMethod, callSiteInfo.firstUserFile,
+ callSiteInfo.firstUserLine)
+ }
/**
* Try to find a free port to bind to on the local host. This should ideally never be needed,
* except that, unfortunately, some of the networking libraries we currently rely on (e.g. Spray)
diff --git a/core/src/main/scala/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/spark/api/java/JavaPairRDD.scala
index 30084df4e2..76051597b6 100644
--- a/core/src/main/scala/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/spark/api/java/JavaPairRDD.scala
@@ -6,6 +6,7 @@ import java.util.Comparator
import scala.Tuple2
import scala.collection.JavaConversions._
+import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.mapred.OutputFormat
import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat}
@@ -459,6 +460,16 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
rdd.saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass)
}
+ /** Output the RDD to any Hadoop-supported file system, compressing with the supplied codec. */
+ def saveAsHadoopFile[F <: OutputFormat[_, _]](
+ path: String,
+ keyClass: Class[_],
+ valueClass: Class[_],
+ outputFormatClass: Class[F],
+ codec: Class[_ <: CompressionCodec]) {
+ rdd.saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass, codec)
+ }
+
/** Output the RDD to any Hadoop-supported file system. */
def saveAsNewAPIHadoopFile[F <: NewOutputFormat[_, _]](
path: String,
diff --git a/core/src/main/scala/spark/api/java/JavaRDD.scala b/core/src/main/scala/spark/api/java/JavaRDD.scala
index eb81ed64cd..626b499454 100644
--- a/core/src/main/scala/spark/api/java/JavaRDD.scala
+++ b/core/src/main/scala/spark/api/java/JavaRDD.scala
@@ -86,7 +86,6 @@ JavaRDDLike[T, JavaRDD[T]] {
*/
def subtract(other: JavaRDD[T], p: Partitioner): JavaRDD[T] =
wrapRDD(rdd.subtract(other, p))
-
}
object JavaRDD {
diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala
index 9b74d1226f..b555f2030a 100644
--- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/spark/api/java/JavaRDDLike.scala
@@ -1,9 +1,10 @@
package spark.api.java
-import java.util.{List => JList}
+import java.util.{List => JList, Comparator}
import scala.Tuple2
import scala.collection.JavaConversions._
+import org.apache.hadoop.io.compress.CompressionCodec
import spark.{SparkContext, Partition, RDD, TaskContext}
import spark.api.java.JavaPairRDD._
import spark.api.java.function.{Function2 => JFunction2, Function => JFunction, _}
@@ -310,6 +311,13 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
*/
def saveAsTextFile(path: String) = rdd.saveAsTextFile(path)
+
+ /**
+ * Save this RDD as a compressed text file, using string representations of elements.
+ */
+ def saveAsTextFile(path: String, codec: Class[_ <: CompressionCodec]) =
+ rdd.saveAsTextFile(path, codec)
+
/**
* Save this RDD as a SequenceFile of serialized objects.
*/
@@ -351,4 +359,29 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
def toDebugString(): String = {
rdd.toDebugString
}
+
+ /**
+ * Returns the top K elements from this RDD as defined by
+ * the specified Comparator[T].
+ * @param num the number of top elements to return
+ * @param comp the comparator that defines the order
+ * @return a list of top elements
+ */
+ def top(num: Int, comp: Comparator[T]): JList[T] = {
+ import scala.collection.JavaConversions._
+ val topElems = rdd.top(num)(Ordering.comparatorToOrdering(comp))
+ val arr: java.util.Collection[T] = topElems.toSeq
+ new java.util.ArrayList(arr)
+ }
+
+ /**
+ * Returns the top K elements from this RDD using the
+ * natural ordering for T.
+ * @param num the number of top elements to return
+ * @return a list of top elements
+ */
+ def top(num: Int): JList[T] = {
+ val comp = com.google.common.collect.Ordering.natural().asInstanceOf[Comparator[T]]
+ top(num, comp)
+ }
}
diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala
index 807119ca8c..63140cf37f 100644
--- a/core/src/main/scala/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/spark/api/python/PythonRDD.scala
@@ -2,10 +2,9 @@ package spark.api.python
import java.io._
import java.net._
-import java.util.{List => JList, ArrayList => JArrayList, Collections}
+import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections}
import scala.collection.JavaConversions._
-import scala.io.Source
import spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
import spark.broadcast.Broadcast
@@ -16,7 +15,7 @@ import spark.rdd.PipedRDD
private[spark] class PythonRDD[T: ClassManifest](
parent: RDD[T],
command: Seq[String],
- envVars: java.util.Map[String, String],
+ envVars: JMap[String, String],
preservePartitoning: Boolean,
pythonExec: String,
broadcastVars: JList[Broadcast[Array[Byte]]],
@@ -25,7 +24,7 @@ private[spark] class PythonRDD[T: ClassManifest](
// Similar to Runtime.exec(), if we are given a single string, split it into words
// using a standard StringTokenizer (i.e. by spaces)
- def this(parent: RDD[T], command: String, envVars: java.util.Map[String, String],
+ def this(parent: RDD[T], command: String, envVars: JMap[String, String],
preservePartitoning: Boolean, pythonExec: String,
broadcastVars: JList[Broadcast[Array[Byte]]],
accumulator: Accumulator[JList[Array[Byte]]]) =
@@ -36,35 +35,18 @@ private[spark] class PythonRDD[T: ClassManifest](
override val partitioner = if (preservePartitoning) parent.partitioner else None
- override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
- val SPARK_HOME = new ProcessBuilder().environment().get("SPARK_HOME")
-
- val pb = new ProcessBuilder(Seq(pythonExec, SPARK_HOME + "/python/pyspark/worker.py"))
- // Add the environmental variables to the process.
- val currentEnvVars = pb.environment()
-
- for ((variable, value) <- envVars) {
- currentEnvVars.put(variable, value)
- }
- val proc = pb.start()
+ override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
+ val startTime = System.currentTimeMillis
val env = SparkEnv.get
-
- // Start a thread to print the process's stderr to ours
- new Thread("stderr reader for " + pythonExec) {
- override def run() {
- for (line <- Source.fromInputStream(proc.getErrorStream).getLines) {
- System.err.println(line)
- }
- }
- }.start()
+ val worker = env.createPythonWorker(pythonExec, envVars.toMap)
// Start a thread to feed the process input from our parent's iterator
new Thread("stdin writer for " + pythonExec) {
override def run() {
SparkEnv.set(env)
- val out = new PrintWriter(proc.getOutputStream)
- val dOut = new DataOutputStream(proc.getOutputStream)
+ val out = new PrintWriter(worker.getOutputStream)
+ val dOut = new DataOutputStream(worker.getOutputStream)
// Partition index
dOut.writeInt(split.index)
// sparkFilesDir
@@ -88,16 +70,21 @@ private[spark] class PythonRDD[T: ClassManifest](
}
dOut.flush()
out.flush()
- proc.getOutputStream.close()
+ worker.shutdownOutput()
}
}.start()
// Return an iterator that read lines from the process's stdout
- val stream = new DataInputStream(proc.getInputStream)
+ val stream = new DataInputStream(worker.getInputStream)
return new Iterator[Array[Byte]] {
def next(): Array[Byte] = {
val obj = _nextObj
- _nextObj = read()
+ if (hasNext) {
+ // FIXME: can deadlock if worker is waiting for us to
+ // respond to current message (currently irrelevant because
+ // output is shutdown before we read any input)
+ _nextObj = read()
+ }
obj
}
@@ -108,6 +95,17 @@ private[spark] class PythonRDD[T: ClassManifest](
val obj = new Array[Byte](length)
stream.readFully(obj)
obj
+ case -3 =>
+ // Timing data from worker
+ val bootTime = stream.readLong()
+ val initTime = stream.readLong()
+ val finishTime = stream.readLong()
+ val boot = bootTime - startTime
+ val init = initTime - bootTime
+ val finish = finishTime - initTime
+ val total = finishTime - startTime
+ logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot, init, finish))
+ read
case -2 =>
// Signals that an exception has been thrown in python
val exLength = stream.readInt()
@@ -115,23 +113,21 @@ private[spark] class PythonRDD[T: ClassManifest](
stream.readFully(obj)
throw new PythonException(new String(obj))
case -1 =>
- // We've finished the data section of the output, but we can still read some
- // accumulator updates; let's do that, breaking when we get EOFException
- while (true) {
- val len2 = stream.readInt()
+ // We've finished the data section of the output, but we can still
+ // read some accumulator updates; let's do that, breaking when we
+ // get a negative length record.
+ var len2 = stream.readInt()
+ while (len2 >= 0) {
val update = new Array[Byte](len2)
stream.readFully(update)
accumulator += Collections.singletonList(update)
+ len2 = stream.readInt()
}
new Array[Byte](0)
}
} catch {
case eof: EOFException => {
- val exitStatus = proc.waitFor()
- if (exitStatus != 0) {
- throw new Exception("Subprocess exited with status " + exitStatus)
- }
- new Array[Byte](0)
+ throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
}
case e => throw e
}
@@ -159,7 +155,7 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends
override def compute(split: Partition, context: TaskContext) =
prev.iterator(split, context).grouped(2).map {
case Seq(a, b) => (a, b)
- case x => throw new Exception("PairwiseRDD: unexpected value: " + x)
+ case x => throw new SparkException("PairwiseRDD: unexpected value: " + x)
}
val asJavaPairRDD : JavaPairRDD[Array[Byte], Array[Byte]] = JavaPairRDD.fromRDD(this)
}
@@ -215,7 +211,7 @@ private[spark] object PythonRDD {
dOut.write(s)
dOut.writeByte(Pickle.STOP)
} else {
- throw new Exception("Unexpected RDD type")
+ throw new SparkException("Unexpected RDD type")
}
}
diff --git a/core/src/main/scala/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/spark/api/python/PythonWorkerFactory.scala
new file mode 100644
index 0000000000..8844411d73
--- /dev/null
+++ b/core/src/main/scala/spark/api/python/PythonWorkerFactory.scala
@@ -0,0 +1,95 @@
+package spark.api.python
+
+import java.io.{DataInputStream, IOException}
+import java.net.{Socket, SocketException, InetAddress}
+
+import scala.collection.JavaConversions._
+
+import spark._
+
+/**
+ * Starts a Python daemon process on demand and hands out sockets connected to it.
+ *
+ * The daemon is launched as `pythonExec $SPARK_HOME/python/pyspark/daemon.py` with `envVars`
+ * merged into its environment. The first four bytes the daemon writes to its stdout are read
+ * as the port it is listening on (on 127.0.0.1).
+ *
+ * @param pythonExec Python executable used to launch the daemon
+ * @param envVars    extra environment variables for the daemon process
+ */
+private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String, String])
+  extends Logging {
+  var daemon: Process = null
+  val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1))
+  var daemonPort: Int = 0
+
+  /**
+   * Opens a new socket to the daemon, starting the daemon first if it is not running.
+   * If the connection attempt fails with a SocketException the daemon is restarted and the
+   * connection is retried exactly once; a second failure propagates to the caller.
+   */
+  def create(): Socket = {
+    synchronized {
+      // Start the daemon if it hasn't been started
+      startDaemon()
+
+      // Attempt to connect, restart and retry once if it fails
+      try {
+        new Socket(daemonHost, daemonPort)
+      } catch {
+        case exc: SocketException => {
+          logWarning("Python daemon unexpectedly quit, attempting to restart")
+          stopDaemon()
+          startDaemon()
+          new Socket(daemonHost, daemonPort)
+        }
+      }
+    }
+  }
+
+  /** Stops the daemon process, if one is running. */
+  def stop() {
+    stopDaemon()
+  }
+
+  /**
+   * Launches the daemon unless one is already running. On any failure the partially started
+   * daemon is torn down again and the original error is rethrown.
+   */
+  private def startDaemon() {
+    synchronized {
+      // Is it already running?
+      if (daemon != null) {
+        return
+      }
+
+      try {
+        // Create and start the daemon
+        val sparkHome = new ProcessBuilder().environment().get("SPARK_HOME")
+        val pb = new ProcessBuilder(Seq(pythonExec, sparkHome + "/python/pyspark/daemon.py"))
+        val workerEnv = pb.environment()
+        workerEnv.putAll(envVars)
+        daemon = pb.start()
+
+        // Grab the stream here rather than inside the thread: `daemon` is a mutable field
+        // that stopDaemon() may reset to null before the reader thread gets to run.
+        val errorStream = daemon.getErrorStream
+
+        // Redirect the stderr to ours. This is started BEFORE we block on reading the port
+        // so that anything the daemon prints while failing to start is still visible.
+        new Thread("stderr reader for " + pythonExec) {
+          // Never keep the JVM alive just to forward the daemon's stderr
+          setDaemon(true)
+          override def run() {
+            scala.util.control.Exception.ignoring(classOf[IOException]) {
+              // FIXME HACK: We copy the stream on the level of bytes to
+              // attempt to dodge encoding problems.
+              val buf = new Array[Byte](1024)
+              var len = errorStream.read(buf)
+              while (len != -1) {
+                System.err.write(buf, 0, len)
+                len = errorStream.read(buf)
+              }
+            }
+          }
+        }.start()
+
+        // The daemon announces the port it listens on as the first int on its stdout
+        daemonPort = new DataInputStream(daemon.getInputStream).readInt()
+      } catch {
+        case e: Throwable => {
+          stopDaemon()
+          throw e
+        }
+      }
+
+      // Important: don't close daemon's stdin (daemon.getOutputStream) so it can correctly
+      // detect our disappearance.
+    }
+  }
+
+  /** Destroys the daemon process (if any) and resets the recorded port. */
+  private def stopDaemon() {
+    synchronized {
+      // Request shutdown of existing daemon by sending SIGTERM
+      if (daemon != null) {
+        daemon.destroy()
+      }
+
+      daemon = null
+      daemonPort = 0
+    }
+  }
+}
diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala
index 890938d48b..8bebfafce4 100644
--- a/core/src/main/scala/spark/executor/Executor.scala
+++ b/core/src/main/scala/spark/executor/Executor.scala
@@ -104,6 +104,7 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert
val value = task.run(taskId.toInt)
val taskFinish = System.currentTimeMillis()
task.metrics.foreach{ m =>
+ m.hostname = Utils.localHostName
m.executorDeserializeTime = (taskStart - startTime).toInt
m.executorRunTime = (taskFinish - taskStart).toInt
}
diff --git a/core/src/main/scala/spark/executor/TaskMetrics.scala b/core/src/main/scala/spark/executor/TaskMetrics.scala
index a7c56c2371..1dc13754f9 100644
--- a/core/src/main/scala/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/spark/executor/TaskMetrics.scala
@@ -2,6 +2,11 @@ package spark.executor
class TaskMetrics extends Serializable {
/**
+ * Name of the host the task runs on
+ */
+ var hostname: String = _
+
+ /**
* Time taken on the executor to deserialize this task
*/
var executorDeserializeTime: Int = _
@@ -34,9 +39,14 @@ object TaskMetrics {
class ShuffleReadMetrics extends Serializable {
/**
+ * Time when the shuffle finishes
+ */
+ var shuffleFinishTime: Long = _
+
+ /**
* Total number of blocks fetched in a shuffle (remote or local)
*/
- var totalBlocksFetched : Int = _
+ var totalBlocksFetched: Int = _
/**
* Number of remote blocks fetched in a shuffle
diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
index 7599ba1a02..8966f9f86e 100644
--- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
@@ -6,7 +6,7 @@ import java.util.{HashMap => JHashMap}
import scala.collection.JavaConversions
import scala.collection.mutable.ArrayBuffer
-import spark.{Aggregator, Logging, Partition, Partitioner, RDD, SparkEnv, TaskContext}
+import spark.{Aggregator, Partition, Partitioner, RDD, SparkEnv, TaskContext}
import spark.{Dependency, OneToOneDependency, ShuffleDependency}
@@ -49,12 +49,16 @@ private[spark] class CoGroupAggregator
*
* @param rdds parent RDDs.
* @param part partitioner used to partition the shuffle output.
- * @param mapSideCombine flag indicating whether to merge values before shuffle step.
+ * @param mapSideCombine flag indicating whether to merge values before shuffle step. If the flag
+ * is on, Spark does an extra pass over the data on the map side to merge
+ * all values belonging to the same key together. This can reduce the amount
+ * of data shuffled if and only if the number of distinct keys is very small,
+ * and the ratio of key size to value size is also very small.
*/
class CoGroupedRDD[K](
@transient var rdds: Seq[RDD[(K, _)]],
part: Partitioner,
- val mapSideCombine: Boolean = true,
+ val mapSideCombine: Boolean = false,
val serializerClass: String = null)
extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) {
diff --git a/core/src/main/scala/spark/rdd/PipedRDD.scala b/core/src/main/scala/spark/rdd/PipedRDD.scala
index 962a1b21ad..c0baf43d43 100644
--- a/core/src/main/scala/spark/rdd/PipedRDD.scala
+++ b/core/src/main/scala/spark/rdd/PipedRDD.scala
@@ -9,6 +9,7 @@ import scala.collection.mutable.ArrayBuffer
import scala.io.Source
import spark.{RDD, SparkEnv, Partition, TaskContext}
+import spark.broadcast.Broadcast
/**
@@ -18,14 +19,21 @@ import spark.{RDD, SparkEnv, Partition, TaskContext}
class PipedRDD[T: ClassManifest](
prev: RDD[T],
command: Seq[String],
- envVars: Map[String, String])
+ envVars: Map[String, String],
+ printPipeContext: (String => Unit) => Unit,
+ printRDDElement: (T, String => Unit) => Unit)
extends RDD[String](prev) {
- def this(prev: RDD[T], command: Seq[String]) = this(prev, command, Map())
-
// Similar to Runtime.exec(), if we are given a single string, split it into words
// using a standard StringTokenizer (i.e. by spaces)
- def this(prev: RDD[T], command: String) = this(prev, PipedRDD.tokenize(command))
+ def this(
+ prev: RDD[T],
+ command: String,
+ envVars: Map[String, String] = Map(),
+ printPipeContext: (String => Unit) => Unit = null,
+ printRDDElement: (T, String => Unit) => Unit = null) =
+ this(prev, PipedRDD.tokenize(command), envVars, printPipeContext, printRDDElement)
+
override def getPartitions: Array[Partition] = firstParent[T].partitions
@@ -52,8 +60,17 @@ class PipedRDD[T: ClassManifest](
override def run() {
SparkEnv.set(env)
val out = new PrintWriter(proc.getOutputStream)
+
+ // Feed the pipe context to the process first
+ if (printPipeContext != null) {
+ printPipeContext(out.println(_))
+ }
for (elem <- firstParent[T].iterator(split, context)) {
- out.println(elem)
+ if (printRDDElement != null) {
+ printRDDElement(elem, out.println(_))
+ } else {
+ out.println(elem)
+ }
}
out.close()
}
diff --git a/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala
index dd9f3c2680..b234428ab2 100644
--- a/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala
+++ b/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala
@@ -53,14 +53,10 @@ abstract class ZippedPartitionsBaseRDD[V: ClassManifest](
val exactMatchLocations = exactMatchPreferredLocations.reduce((x, y) => x.intersect(y))
// Remove exact match and then do host local match.
- val otherNodePreferredLocations = rddSplitZip.map(x => {
- x._1.preferredLocations(x._2).map(hostPort => {
- val host = Utils.parseHostPort(hostPort)._1
-
- if (exactMatchLocations.contains(host)) null else host
- }).filter(_ != null)
- })
- val otherNodeLocalLocations = otherNodePreferredLocations.reduce((x, y) => x.intersect(y))
+ val exactMatchHosts = exactMatchLocations.map(Utils.parseHostPort(_)._1)
+ val matchPreferredHosts = exactMatchPreferredLocations.map(locs => locs.map(Utils.parseHostPort(_)._1))
+ .reduce((x, y) => x.intersect(y))
+ val otherNodeLocalLocations = matchPreferredHosts.filter { s => !exactMatchHosts.contains(s) }
otherNodeLocalLocations ++ exactMatchLocations
}
diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
index 7feeb97542..f7d60be5db 100644
--- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
@@ -298,6 +298,7 @@ class DAGScheduler(
// Compute very short actions like first() or take() with no parent stages locally.
runLocally(job)
} else {
+ sparkListeners.foreach(_.onJobStart(SparkListenerJobStart(job, properties)))
idToActiveJob(runId) = job
activeJobs += job
resultStageToJob(finalStage) = job
@@ -311,6 +312,8 @@ class DAGScheduler(
handleExecutorLost(execId)
case completion: CompletionEvent =>
+ sparkListeners.foreach(_.onTaskEnd(SparkListenerTaskEnd(completion.task,
+ completion.reason, completion.taskInfo, completion.taskMetrics)))
handleTaskCompletion(completion)
case TaskSetFailed(taskSet, reason) =>
@@ -321,6 +324,7 @@ class DAGScheduler(
for (job <- activeJobs) {
val error = new SparkException("Job cancelled because SparkContext was shut down")
job.listener.jobFailed(error)
+ sparkListeners.foreach(_.onJobEnd(SparkListenerJobEnd(job, JobFailed(error))))
}
return true
}
@@ -468,6 +472,7 @@ class DAGScheduler(
}
}
if (tasks.size > 0) {
+ sparkListeners.foreach(_.onStageSubmitted(SparkListenerStageSubmitted(stage, tasks.size)))
logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")")
myPending ++= tasks
logDebug("New pending tasks: " + myPending)
@@ -522,6 +527,7 @@ class DAGScheduler(
activeJobs -= job
resultStageToJob -= stage
markStageAsFinished(stage)
+ sparkListeners.foreach(_.onJobEnd(SparkListenerJobEnd(job, JobSucceeded)))
}
job.listener.taskSucceeded(rt.outputId, event.result)
}
@@ -662,7 +668,9 @@ class DAGScheduler(
val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq
for (resultStage <- dependentStages) {
val job = resultStageToJob(resultStage)
- job.listener.jobFailed(new SparkException("Job failed: " + reason))
+ val error = new SparkException("Job failed: " + reason)
+ job.listener.jobFailed(error)
+ sparkListeners.foreach(_.onJobEnd(SparkListenerJobEnd(job, JobFailed(error))))
activeJobs -= job
resultStageToJob -= resultStage
}
diff --git a/core/src/main/scala/spark/scheduler/JobLogger.scala b/core/src/main/scala/spark/scheduler/JobLogger.scala
new file mode 100644
index 0000000000..178bfaba3d
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/JobLogger.scala
@@ -0,0 +1,306 @@
+package spark.scheduler
+
+import java.io.PrintWriter
+import java.io.File
+import java.io.FileNotFoundException
+import java.text.SimpleDateFormat
+import java.util.{Date, Properties}
+import java.util.concurrent.LinkedBlockingQueue
+import scala.collection.mutable.{Map, HashMap, ListBuffer}
+import scala.io.Source
+import spark._
+import spark.executor.TaskMetrics
+import spark.scheduler.cluster.TaskInfo
+
+/**
+ * A SparkListener that records runtime information for each job into a plain-text file:
+ * the stage/RDD dependency graph, stage submission/completion, per-task results and metrics,
+ * and the final job status.
+ *
+ * Listener callbacks only enqueue the event; a daemon thread named "JobLogger" drains the
+ * queue and performs all bookkeeping and file writes. One file per job is created at
+ * `<logDir>/<logDirName>/<jobID>`, where `logDir` is the SPARK_LOG_DIR environment variable
+ * or "/tmp/spark" when that variable is unset.
+ *
+ * @param logDirName name of the sub-directory (under `logDir`) holding this logger's files
+ */
+class JobLogger(val logDirName: String) extends SparkListener with Logging {
+  private val logDir = Option(System.getenv("SPARK_LOG_DIR")).getOrElse("/tmp/spark")
+  private val jobIDToPrintWriter = new HashMap[Int, PrintWriter]
+  private val stageIDToJobID = new HashMap[Int, Int]
+  private val jobIDToStages = new HashMap[Int, ListBuffer[Stage]]
+  private val DATE_FORMAT = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss")
+  private val eventQueue = new LinkedBlockingQueue[SparkListenerEvents]
+
+  createLogDir()
+
+  /** Creates a logger whose directory name is the current time in epoch milliseconds. */
+  def this() = this(String.valueOf(System.currentTimeMillis()))
+
+  def getLogDir = logDir
+  def getJobIDtoPrintWriter = jobIDToPrintWriter
+  def getStageIDToJobID = stageIDToJobID
+  def getJobIDToStages = jobIDToStages
+  def getEventQueue = eventQueue
+
+  // Consumer side of the event queue: blocks on take() and dispatches each event to the
+  // matching process* method. Events of any other type are dropped.
+  new Thread("JobLogger") {
+    setDaemon(true)
+    override def run() {
+      while (true) {
+        val event = eventQueue.take
+        logDebug("Got event of type " + event.getClass.getName)
+        event match {
+          case SparkListenerJobStart(job, properties) =>
+            processJobStartEvent(job, properties)
+          case SparkListenerStageSubmitted(stage, taskSize) =>
+            processStageSubmittedEvent(stage, taskSize)
+          case StageCompleted(stageInfo) =>
+            processStageCompletedEvent(stageInfo)
+          case SparkListenerJobEnd(job, result) =>
+            processJobEndEvent(job, result)
+          case SparkListenerTaskEnd(task, reason, taskInfo, taskMetrics) =>
+            processTaskEndEvent(task, reason, taskInfo, taskMetrics)
+          case _ =>
+        }
+      }
+    }
+  }.start()
+
+  /** Creates `<logDir>/<logDirName>/` if it does not exist yet; logs an error on failure. */
+  protected def createLogDir() {
+    val dir = new File(logDir + "/" + logDirName + "/")
+    if (!dir.exists() && !dir.mkdirs()) {
+      logError("create log directory error:" + logDir + "/" + logDirName + "/")
+    }
+  }
+
+  /** Opens the log file for one job (the file name is the job ID) and registers its writer. */
+  protected def createLogWriter(jobID: Int) {
+    try {
+      val fileWriter = new PrintWriter(logDir + "/" + logDirName + "/" + jobID)
+      jobIDToPrintWriter += (jobID -> fileWriter)
+    } catch {
+      case e: FileNotFoundException => e.printStackTrace()
+    }
+  }
+
+  /**
+   * Closes the job's log file and forgets everything recorded for the job, including the
+   * stage-to-job entries of its stages. Does nothing when the job has no open writer.
+   */
+  protected def closeLogWriter(jobID: Int) =
+    jobIDToPrintWriter.get(jobID).foreach { fileWriter =>
+      fileWriter.close()
+      for (stages <- jobIDToStages.get(jobID); stage <- stages) {
+        stageIDToJobID -= stage.id
+      }
+      jobIDToPrintWriter -= jobID
+      jobIDToStages -= jobID
+    }
+
+  /**
+   * Writes one line to the job's log file; silently ignored when the job has no open writer.
+   *
+   * @param withTime whether to prefix the line with the current timestamp
+   */
+  protected def jobLogInfo(jobID: Int, info: String, withTime: Boolean = true) {
+    val writeInfo =
+      if (withTime) DATE_FORMAT.format(new Date(System.currentTimeMillis())) + ": " + info
+      else info
+    jobIDToPrintWriter.get(jobID).foreach(_.println(writeInfo))
+  }
+
+  /** Writes one line to the log of the job owning `stageID`, if that stage is known. */
+  protected def stageLogInfo(stageID: Int, info: String, withTime: Boolean = true) =
+    stageIDToJobID.get(stageID).foreach(jobID => jobLogInfo(jobID, info, withTime))
+
+  /**
+   * Registers `stage` and, recursively, its parents with the job. Only stages whose
+   * priority equals the job ID are registered; the recursion stops at any other stage.
+   */
+  protected def buildJobDep(jobID: Int, stage: Stage) {
+    if (stage.priority == jobID) {
+      jobIDToStages.getOrElseUpdate(jobID, new ListBuffer[Stage]) += stage
+      stageIDToJobID += (stage.id -> jobID)
+      stage.parents.foreach(buildJobDep(jobID, _))
+    }
+  }
+
+  /** Logs, for every registered stage of the job, the RDDs it contains and its parent stages. */
+  protected def recordStageDep(jobID: Int) {
+    // Collects `rdd` plus everything reachable from it without crossing a shuffle dependency
+    def getRddsInStage(rdd: RDD[_]): ListBuffer[RDD[_]] = {
+      val rddList = new ListBuffer[RDD[_]]
+      rddList += rdd
+      rdd.dependencies.foreach {
+        case shufDep: ShuffleDependency[_, _] =>
+        case dep => rddList ++= getRddsInStage(dep.rdd)
+      }
+      rddList
+    }
+    for (stages <- jobIDToStages.get(jobID); stage <- stages) {
+      val depRddDesc = getRddsInStage(stage.rdd).map(rdd => rdd.id).mkString(",")
+      val depStageDesc = stage.parents.map { parent =>
+        "(" + parent.id + "," + parent.shuffleDep.get.shuffleId + ")"
+      }.mkString
+      jobLogInfo(jobID, "STAGE_ID=" + stage.id + " RDD_DEP=(" + depRddDesc + ")" +
+        " STAGE_DEP=" + depStageDesc, false)
+    }
+  }
+
+  /** Returns a string of `indent` spaces (empty when `indent` is not positive). */
+  protected def indentString(indent: Int) = " " * indent
+
+  /** Returns the RDD's user-assigned name, or its class name when no name was set. */
+  protected def getRddName(rdd: RDD[_]) =
+    if (rdd.name != null) rdd.name else rdd.getClass.getName
+
+  /**
+   * Logs `rdd` and, one level deeper per step, its dependencies. A shuffle dependency is
+   * logged as a SHUFFLE_ID line and not followed any further.
+   */
+  protected def recordRddInStageGraph(jobID: Int, rdd: RDD[_], indent: Int) {
+    val rddInfo = "RDD_ID=" + rdd.id + "(" + getRddName(rdd) + "," + rdd.generator + ")"
+    jobLogInfo(jobID, indentString(indent) + rddInfo, false)
+    rdd.dependencies.foreach {
+      case shufDep: ShuffleDependency[_, _] =>
+        val depInfo = "SHUFFLE_ID=" + shufDep.shuffleId
+        jobLogInfo(jobID, indentString(indent + 1) + depInfo, false)
+      case dep => recordRddInStageGraph(jobID, dep.rdd, indent + 1)
+    }
+  }
+
+  /**
+   * Logs the stage graph rooted at `stage`. Stages whose priority equals the job ID are
+   * expanded (their RDD graph and parent stages follow); any other stage is logged as a
+   * single line naming the priority it carries.
+   */
+  protected def recordStageDepGraph(jobID: Int, stage: Stage, indent: Int = 0) {
+    val stageInfo =
+      if (stage.isShuffleMap) {
+        "STAGE_ID=" + stage.id + " MAP_STAGE SHUFFLE_ID=" + stage.shuffleDep.get.shuffleId
+      } else {
+        "STAGE_ID=" + stage.id + " RESULT_STAGE"
+      }
+    if (stage.priority == jobID) {
+      jobLogInfo(jobID, indentString(indent) + stageInfo, false)
+      recordRddInStageGraph(jobID, stage.rdd, indent)
+      stage.parents.foreach(recordStageDepGraph(jobID, _, indent + 2))
+    } else {
+      jobLogInfo(jobID, indentString(indent) + stageInfo + " JOB_ID=" + stage.priority, false)
+    }
+  }
+
+  /** Logs one task's timing, placement and (when present) shuffle read/write metrics. */
+  protected def recordTaskMetrics(stageID: Int, status: String,
+                                  taskInfo: TaskInfo, taskMetrics: TaskMetrics) {
+    val info = " TID=" + taskInfo.taskId + " STAGE_ID=" + stageID +
+      " START_TIME=" + taskInfo.launchTime + " FINISH_TIME=" + taskInfo.finishTime +
+      " EXECUTOR_ID=" + taskInfo.executorId + " HOST=" + taskMetrics.hostname
+    val executorRunTime = " EXECUTOR_RUN_TIME=" + taskMetrics.executorRunTime
+    val readMetrics = taskMetrics.shuffleReadMetrics.map { metrics =>
+      " SHUFFLE_FINISH_TIME=" + metrics.shuffleFinishTime +
+        " BLOCK_FETCHED_TOTAL=" + metrics.totalBlocksFetched +
+        " BLOCK_FETCHED_LOCAL=" + metrics.localBlocksFetched +
+        " BLOCK_FETCHED_REMOTE=" + metrics.remoteBlocksFetched +
+        " REMOTE_FETCH_WAIT_TIME=" + metrics.fetchWaitTime +
+        " REMOTE_FETCH_TIME=" + metrics.remoteFetchTime +
+        " REMOTE_BYTES_READ=" + metrics.remoteBytesRead
+    }.getOrElse("")
+    val writeMetrics = taskMetrics.shuffleWriteMetrics.map { metrics =>
+      " SHUFFLE_BYTES_WRITTEN=" + metrics.shuffleBytesWritten
+    }.getOrElse("")
+    stageLogInfo(stageID, status + info + executorRunTime + readMetrics + writeMetrics)
+  }
+
+  override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) {
+    eventQueue.put(stageSubmitted)
+  }
+
+  protected def processStageSubmittedEvent(stage: Stage, taskSize: Int) {
+    stageLogInfo(stage.id, "STAGE_ID=" + stage.id + " STATUS=SUBMITTED" + " TASK_SIZE=" + taskSize)
+  }
+
+  override def onStageCompleted(stageCompleted: StageCompleted) {
+    eventQueue.put(stageCompleted)
+  }
+
+  protected def processStageCompletedEvent(stageInfo: StageInfo) {
+    stageLogInfo(stageInfo.stage.id, "STAGE_ID=" + stageInfo.stage.id + " STATUS=COMPLETED")
+  }
+
+  override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
+    eventQueue.put(taskEnd)
+  }
+
+  /**
+   * Logs the outcome of one task. Successful tasks get a full metrics line; resubmitted and
+   * failed tasks get a short status line; any other end reason is not logged.
+   */
+  protected def processTaskEndEvent(task: Task[_], reason: TaskEndReason,
+                                    taskInfo: TaskInfo, taskMetrics: TaskMetrics) {
+    val taskType = task match {
+      case resultTask: ResultTask[_, _] => "TASK_TYPE=RESULT_TASK"
+      case shuffleMapTask: ShuffleMapTask => "TASK_TYPE=SHUFFLE_MAP_TASK"
+    }
+    reason match {
+      case Success =>
+        recordTaskMetrics(task.stageId, taskType + " STATUS=SUCCESS", taskInfo, taskMetrics)
+      case Resubmitted =>
+        stageLogInfo(task.stageId, taskType + " STATUS=RESUBMITTED TID=" + taskInfo.taskId +
+          " STAGE_ID=" + task.stageId)
+      case FetchFailed(bmAddress, shuffleId, mapId, reduceId) =>
+        stageLogInfo(task.stageId, taskType + " STATUS=FETCHFAILED TID=" + taskInfo.taskId +
+          " STAGE_ID=" + task.stageId + " SHUFFLE_ID=" + shuffleId + " MAP_ID=" +
+          mapId + " REDUCE_ID=" + reduceId)
+      case OtherFailure(message) =>
+        stageLogInfo(task.stageId, taskType + " STATUS=FAILURE TID=" + taskInfo.taskId +
+          " STAGE_ID=" + task.stageId + " INFO=" + message)
+      case _ =>
+    }
+  }
+
+  override def onJobEnd(jobEnd: SparkListenerJobEnd) {
+    eventQueue.put(jobEnd)
+  }
+
+  /**
+   * Logs the final (upper-cased) status line of a job and closes its log file.
+   *
+   * For a failed job the exception message is appended with every run of whitespace
+   * replaced by a single underscore; a missing (null) message yields an empty reason.
+   */
+  protected def processJobEndEvent(job: ActiveJob, reason: JobResult) {
+    val jobHeader = "JOB_ID=" + job.runId
+    // Build the separator-joined reason directly instead of appending a trailing "_" and
+    // chopping the last character afterwards: the chop also hit lines that never had the
+    // trailing separator (it turned "STATUS=SUCCESS" into "STATUS=SUCCES").
+    val info = reason match {
+      case JobSucceeded => jobHeader + " STATUS=SUCCESS"
+      case JobFailed(exception) =>
+        val message = Option(exception.getMessage).getOrElse("")
+        jobHeader + " STATUS=FAILED REASON=" + message.split("\\s+").mkString("_")
+      case _ => jobHeader
+    }
+    jobLogInfo(job.runId, info.toUpperCase)
+    closeLogWriter(job.runId)
+  }
+
+  /** Logs the "spark.job.annotation" property (empty line if unset) when properties exist. */
+  protected def recordJobProperties(jobID: Int, properties: Properties) {
+    if (properties != null) {
+      val annotation = properties.getProperty("spark.job.annotation", "")
+      jobLogInfo(jobID, annotation, false)
+    }
+  }
+
+  override def onJobStart(jobStart: SparkListenerJobStart) {
+    eventQueue.put(jobStart)
+  }
+
+  /** Opens the job's log file and writes its annotation, stage dependencies and graph. */
+  protected def processJobStartEvent(job: ActiveJob, properties: Properties) {
+    createLogWriter(job.runId)
+    recordJobProperties(job.runId, properties)
+    buildJobDep(job.runId, job.finalStage)
+    recordStageDep(job.runId)
+    recordStageDepGraph(job.runId, job.finalStage)
+    jobLogInfo(job.runId, "JOB_ID=" + job.runId + " STATUS=STARTED")
+  }
+}
diff --git a/core/src/main/scala/spark/scheduler/SparkListener.scala b/core/src/main/scala/spark/scheduler/SparkListener.scala
index a65140b145..bac984b5c9 100644
--- a/core/src/main/scala/spark/scheduler/SparkListener.scala
+++ b/core/src/main/scala/spark/scheduler/SparkListener.scala
@@ -1,27 +1,59 @@
package spark.scheduler
+import java.util.Properties
import spark.scheduler.cluster.TaskInfo
import spark.util.Distribution
-import spark.{Utils, Logging}
+import spark.{Logging, SparkContext, TaskEndReason, Utils}
import spark.executor.TaskMetrics
-trait SparkListener {
- /**
- * called when a stage is completed, with information on the completed stage
- */
- def onStageCompleted(stageCompleted: StageCompleted)
-}
-
sealed trait SparkListenerEvents
+case class SparkListenerStageSubmitted(stage: Stage, taskSize: Int) extends SparkListenerEvents
+
case class StageCompleted(val stageInfo: StageInfo) extends SparkListenerEvents
+case class SparkListenerTaskEnd(task: Task[_], reason: TaskEndReason, taskInfo: TaskInfo,
+ taskMetrics: TaskMetrics) extends SparkListenerEvents
+
+case class SparkListenerJobStart(job: ActiveJob, properties: Properties = null)
+ extends SparkListenerEvents
+
+case class SparkListenerJobEnd(job: ActiveJob, jobResult: JobResult)
+ extends SparkListenerEvents
+
+trait SparkListener {
+  /** Invoked with information about a stage once that stage has completed. No-op by default. */
+  def onStageCompleted(stageCompleted: StageCompleted): Unit = {}
+
+  /** Invoked when a stage is submitted. No-op by default. */
+  def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = {}
+
+  /** Invoked when a task ends. No-op by default. */
+  def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {}
+
+  /** Invoked when a job starts. No-op by default. */
+  def onJobStart(jobStart: SparkListenerJobStart): Unit = {}
+
+  /** Invoked when a job ends. No-op by default. */
+  def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {}
+}
/**
* Simple SparkListener that logs a few summary statistics when each stage completes
*/
class StatsReportListener extends SparkListener with Logging {
- def onStageCompleted(stageCompleted: StageCompleted) {
+ override def onStageCompleted(stageCompleted: StageCompleted) {
import spark.scheduler.StatsReportListener._
implicit val sc = stageCompleted
this.logInfo("Finished stage: " + stageCompleted.stageInfo)
diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
index 053d4b8e4a..3a0c29b27f 100644
--- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
@@ -177,7 +177,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
val tasks = taskSet.tasks
logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
this.synchronized {
- val manager = new TaskSetManager(this, taskSet)
+ val manager = new ClusterTaskSetManager(this, taskSet)
activeTaskSets(taskSet.id) = manager
schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
taskSetTaskIds(taskSet.id) = new HashSet[Long]()
diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala
new file mode 100644
index 0000000000..d72b0bfc9f
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala
@@ -0,0 +1,747 @@
+package spark.scheduler.cluster
+
+import java.util.{HashMap => JHashMap, NoSuchElementException, Arrays}
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.HashSet
+import scala.math.max
+import scala.math.min
+
+import spark._
+import spark.scheduler._
+import spark.TaskState.TaskState
+import java.nio.ByteBuffer
+
+/**
+ * Locality levels used when matching tasks to offers:
+ * PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL and ANY.
+ */
+private[spark] object TaskLocality extends Enumeration("PROCESS_LOCAL", "NODE_LOCAL", "RACK_LOCAL", "ANY") with Logging {
+
+  // PROCESS_LOCAL is expected to be used ONLY inside the task set manager for now.
+  val PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY = Value
+
+  type TaskLocality = Value
+
+  /**
+   * Tells whether a launch at level `condition` is acceptable under `constraint`.
+   *
+   * A NODE_LOCAL constraint accepts only NODE_LOCAL; a RACK_LOCAL constraint accepts
+   * NODE_LOCAL and RACK_LOCAL; any other constraint accepts everything.
+   * PROCESS_LOCAL is rejected as a constraint by assertion.
+   */
+  def isAllowed(constraint: TaskLocality, condition: TaskLocality): Boolean = {
+    // PROCESS_LOCAL must never be used as the constraint.
+    assert(constraint != PROCESS_LOCAL)
+
+    if (constraint == NODE_LOCAL) {
+      condition == NODE_LOCAL
+    } else if (constraint == RACK_LOCAL) {
+      condition == NODE_LOCAL || condition == RACK_LOCAL
+    } else {
+      // Anything else places no restriction.
+      true
+    }
+  }
+
+  /**
+   * Converts a level name into its enumeration value.
+   *
+   * A lookup that fails with NoSuchElementException is logged and mapped to NODE_LOCAL
+   * (the earlier default behaviour). PROCESS_LOCAL is rejected by assertion.
+   */
+  def parse(str: String): TaskLocality = {
+    try {
+      val parsed = withName(str)
+      // Callers must not specify PROCESS_LOCAL.
+      assert(parsed != PROCESS_LOCAL)
+      parsed
+    } catch {
+      case _: NoSuchElementException =>
+        logWarning("Invalid task locality specified '" + str + "', defaulting to NODE_LOCAL")
+        NODE_LOCAL
+    }
+  }
+}
+
+/**
+ * Schedules the tasks within a single TaskSet in the ClusterScheduler.
+ */
+private[spark] class ClusterTaskSetManager(
+ sched: ClusterScheduler,
+ val taskSet: TaskSet)
+ extends TaskSetManager
+ with Logging {
+
+ // Maximum time to wait to run a task in a preferred location (in ms)
+ val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong
+
+ // CPUs to request per task
+ val CPUS_PER_TASK = System.getProperty("spark.task.cpus", "1").toDouble
+
+ // Maximum times a task is allowed to fail before failing the job
+ val MAX_TASK_FAILURES = 4
+
+ // Quantile of tasks at which to start speculation
+ val SPECULATION_QUANTILE = System.getProperty("spark.speculation.quantile", "0.75").toDouble
+ val SPECULATION_MULTIPLIER = System.getProperty("spark.speculation.multiplier", "1.5").toDouble
+
+ // Serializer for closures and tasks.
+ val ser = SparkEnv.get.closureSerializer.newInstance()
+
+ val tasks = taskSet.tasks
+ val numTasks = tasks.length
+ val copiesRunning = new Array[Int](numTasks)
+ val finished = new Array[Boolean](numTasks)
+ val numFailures = new Array[Int](numTasks)
+ val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil)
+ var tasksFinished = 0
+
+ var weight = 1
+ var minShare = 0
+ var runningTasks = 0
+ var priority = taskSet.priority
+ var stageId = taskSet.stageId
+ var name = "TaskSet_"+taskSet.stageId.toString
+ var parent:Schedulable = null
+
+ // Last time when we launched a preferred task (for delay scheduling)
+ var lastPreferredLaunchTime = System.currentTimeMillis
+
+ // List of pending tasks for each node (process local to container). These collections are actually
+ // treated as stacks, in which new tasks are added to the end of the
+ // ArrayBuffer and removed from the end. This makes it faster to detect
+ // tasks that repeatedly fail because whenever a task failed, it is put
+ // back at the head of the stack. They are also only cleaned up lazily;
+ // when a task is launched, it remains in all the pending lists except
+ // the one that it was launched from, but gets removed from them later.
+ private val pendingTasksForHostPort = new HashMap[String, ArrayBuffer[Int]]
+
+ // List of pending tasks for each node.
+ // Essentially, similar to pendingTasksForHostPort, except at host level
+ private val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]]
+
+ // List of pending tasks for each node based on rack locality.
+ // Essentially, similar to pendingTasksForHost, except at rack level
+ private val pendingRackLocalTasksForHost = new HashMap[String, ArrayBuffer[Int]]
+
+ // List containing pending tasks with no locality preferences
+ val pendingTasksWithNoPrefs = new ArrayBuffer[Int]
+
+ // List containing all pending tasks (also used as a stack, as above)
+ val allPendingTasks = new ArrayBuffer[Int]
+
+ // Tasks that can be speculated. Since these will be a small fraction of total
+ // tasks, we'll just hold them in a HashSet.
+ val speculatableTasks = new HashSet[Int]
+
+ // Task index, start and finish time for each task attempt (indexed by task ID)
+ val taskInfos = new HashMap[Long, TaskInfo]
+
+ // Did the job fail?
+ var failed = false
+ var causeOfFailure = ""
+
+ // How frequently to reprint duplicate exceptions in full, in milliseconds
+ val EXCEPTION_PRINT_INTERVAL =
+ System.getProperty("spark.logging.exceptionPrintInterval", "10000").toLong
+ // Map of recent exceptions (identified by string representation and
+ // top stack frame) to duplicate count (how many times the same
+ // exception has appeared) and time the full exception was
+ // printed. This should ideally be an LRU map that can drop old
+ // exceptions automatically.
+ val recentExceptions = HashMap[String, (Int, Long)]()
+
+ // Figure out the current map output tracker generation and set it on all tasks
+ val generation = sched.mapOutputTracker.getGeneration
+ logDebug("Generation for " + taskSet.id + ": " + generation)
+ for (t <- tasks) {
+ t.generation = generation
+ }
+
+ // Add all our tasks to the pending lists. We do this in reverse order
+ // of task index so that tasks with low indices get launched first.
+ for (i <- (0 until numTasks).reverse) {
+ addPendingTask(i)
+ }
+
+  /**
+   * Resolves a task's preferred locations into the executors currently alive for them.
+   *
+   * The levels are hierarchical: a NODE_LOCAL search also covers process-local executors,
+   * and a RACK_LOCAL search covers process-local and node-local ones as well.
+   *
+   *  - PROCESS_LOCAL: the preferred locations for which the scheduler reports a live executor.
+   *  - NODE_LOCAL: everything the scheduler reports alive on the hosts of the preferred locations.
+   *  - RACK_LOCAL: as NODE_LOCAL, after widening the preferred locations with every host the
+   *    scheduler has cached for the same racks.
+   *
+   * Any other level fails the assertion below.
+   */
+  private def findPreferredLocations(_taskPreferredLocations: Seq[String], scheduler: ClusterScheduler,
+    taskLocality: TaskLocality.TaskLocality): HashSet[String] = {
+
+    if (TaskLocality.PROCESS_LOCAL == taskLocality) {
+      // Straightforward comparison against live executors; handled as a special case.
+      val aliveHostPorts = new HashSet[String]()
+      scheduler.synchronized {
+        _taskPreferredLocations.foreach { location =>
+          if (scheduler.isExecutorAliveOnHostPort(location)) {
+            aliveHostPorts += location
+          }
+        }
+      }
+      aliveHostPorts
+    } else {
+      val candidateLocations: Iterable[String] =
+        if (TaskLocality.NODE_LOCAL == taskLocality) {
+          _taskPreferredLocations
+        } else {
+          assert(TaskLocality.RACK_LOCAL == taskLocality)
+          // Widen the set with all 'seen' hosts on the same racks. This works because container
+          // allocation/management happens within the master, so rack information is kept up to
+          // date there. Best effort, and somewhat of a kludge for now.
+          val expanded = new HashSet[String]
+          for (preferred <- _taskPreferredLocations) {
+            for (rack <- scheduler.getRackForHost(preferred);
+                 rackHosts <- scheduler.getCachedHostsForRack(rack)) {
+              expanded ++= rackHosts
+            }
+            // Whatever the scheduler knows about racks, the preferred host itself always counts.
+            expanded += preferred
+          }
+          expanded
+        }
+
+      val aliveExecutors = new HashSet[String]
+      scheduler.synchronized {
+        for (prefLocation <- candidateLocations) {
+          val host = Utils.parseHostPort(prefLocation)._1
+          scheduler.getExecutorsAliveOnHost(host).foreach { alive => aliveExecutors ++= alive }
+        }
+      }
+      aliveExecutors
+    }
+  }
+
+  /**
+   * Registers task `index` on every pending list it belongs to.
+   *
+   * A task with no rack-local executor goes on the no-preference list; otherwise it is added
+   * to the per-host:port, per-host and per-rack-host lists. It is always appended to
+   * `allPendingTasks`.
+   */
+  private def addPendingTask(index: Int) {
+    // The host-local set could be derived from the rack-local one by joining against the task's
+    // preferred locations (with hostPort <-> host conversion); three lookups are kept for
+    // simplicity. Revisit if this becomes a performance issue.
+    val prefs = tasks(index).preferredLocations
+    val processLocal = findPreferredLocations(prefs, sched, TaskLocality.PROCESS_LOCAL)
+    val hostLocal = findPreferredLocations(prefs, sched, TaskLocality.NODE_LOCAL)
+    val rackLocal = findPreferredLocations(prefs, sched, TaskLocality.RACK_LOCAL)
+
+    if (rackLocal.isEmpty) {
+      // The current implementation guarantees the narrower sets are empty too.
+      assert(processLocal.isEmpty)
+      assert(hostLocal.isEmpty)
+      pendingTasksWithNoPrefs += index
+    } else {
+      // Process-local level, keyed by host:port.
+      processLocal.foreach { hostPort =>
+        Utils.checkHostPort(hostPort) // debug check
+        pendingTasksForHostPort.getOrElseUpdate(hostPort, ArrayBuffer()) += index
+      }
+
+      // Host level (includes process local), keyed by host.
+      hostLocal.foreach { hostPort =>
+        Utils.checkHostPort(hostPort) // debug check
+        val host = Utils.parseHostPort(hostPort)._1
+        pendingTasksForHost.getOrElseUpdate(host, ArrayBuffer()) += index
+      }
+
+      // Rack level (includes process local and host local), keyed by host.
+      rackLocal.foreach { rackLocalHostPort =>
+        Utils.checkHostPort(rackLocalHostPort) // debug check
+        val rackLocalHost = Utils.parseHostPort(rackLocalHostPort)._1
+        pendingRackLocalTasksForHost.getOrElseUpdate(rackLocalHost, ArrayBuffer()) += index
+      }
+    }
+
+    allPendingTasks += index
+  }
+
+  // Pending (process local) tasks registered for the given host:port. When the map has no
+  // entry, a fresh empty buffer is returned; that buffer is not stored in the map.
+  private def getPendingTasksForHostPort(hostPort: String): ArrayBuffer[Int] = {
+    Utils.checkHostPort(hostPort) // debug check
+    pendingTasksForHostPort.get(hostPort) match {
+      case Some(pending) => pending
+      case None => ArrayBuffer[Int]()
+    }
+  }
+
+  // Pending tasks registered for the host part of `hostPort`. When the map has no entry,
+  // a fresh empty buffer is returned; that buffer is not stored in the map.
+  private def getPendingTasksForHost(hostPort: String): ArrayBuffer[Int] = {
+    val (host, _) = Utils.parseHostPort(hostPort)
+    pendingTasksForHost.get(host) match {
+      case Some(pending) => pending
+      case None => ArrayBuffer[Int]()
+    }
+  }
+
+  // Rack-level pending tasks registered for the host part of `hostPort`. When the map has no
+  // entry, a fresh empty buffer is returned; that buffer is not stored in the map.
+  private def getRackLocalPendingTasksForHost(hostPort: String): ArrayBuffer[Int] = {
+    val (host, _) = Utils.parseHostPort(hostPort)
+    pendingRackLocalTasksForHost.get(host) match {
+      case Some(pending) => pending
+      case None => ArrayBuffer[Int]()
+    }
+  }
+
+  // Number of process-local tasks for the given host:port that are neither finished nor
+  // currently running (stale entries left in the lazily-cleaned list are not counted).
+  def numPendingTasksForHostPort(hostPort: String): Int = {
+    val candidates = getPendingTasksForHostPort(hostPort)
+    candidates.count(i => !finished(i) && copiesRunning(i) == 0)
+  }
+
+  // Number of host-level (data local) tasks for the given host that are neither finished nor
+  // currently running (stale entries left in the lazily-cleaned list are not counted).
+  def numPendingTasksForHost(hostPort: String): Int = {
+    val candidates = getPendingTasksForHost(hostPort)
+    candidates.count(i => !finished(i) && copiesRunning(i) == 0)
+  }
+
+  // Number of rack-local tasks for the given host that are neither finished nor currently
+  // running (stale entries left in the lazily-cleaned list are not counted).
+  def numRackLocalPendingTasksForHost(hostPort: String): Int = {
+    val candidates = getRackLocalPendingTasksForHost(hostPort)
+    candidates.count(i => !finished(i) && copiesRunning(i) == 0)
+  }
+
+
+  // Pops entries off the end of `list` until one is found that is neither running nor
+  // finished, and returns its index; returns None once the list is exhausted.
+  // Entries for tasks that were already launched are discarded along the way, which is how
+  // the pending lists get cleaned up lazily.
+  private def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = {
+    var found: Option[Int] = None
+    while (found.isEmpty && list.nonEmpty) {
+      val candidate = list.remove(list.length - 1)
+      if (copiesRunning(candidate) == 0 && !finished(candidate)) {
+        found = Some(candidate)
+      }
+    }
+    found
+  }
+
+ // Return a speculative task for a given host:port if any are available, removing it from
+ // speculatableTasks. A task is never chosen if one of its recorded attempts (taskAttempts)
+ // used this same host:port, in case that location is slow.
+ // Candidates are tried in three passes:
+ //  1. tasks with no alive node-local location, or whose node-local locations contain hostPort;
+ //  2. if the locality constraint allows RACK_LOCAL, tasks whose rack-local locations contain hostPort;
+ //  3. if the locality constraint allows ANY, any remaining task.
+ // Returns None when nothing qualifies.
+ private def findSpeculativeTask(hostPort: String, locality: TaskLocality.TaskLocality): Option[Int] = {
+
+ // isAllowed(_, NODE_LOCAL) is true for every accepted constraint, so this effectively only
+ // asserts that locality is not PROCESS_LOCAL.
+ assert (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL))
+ speculatableTasks.retain(index => !finished(index)) // Remove finished tasks from set
+
+ if (speculatableTasks.size > 0) {
+ // Pass 1: node-local (or no alive preferred location at all)
+ val localTask = speculatableTasks.find {
+ index =>
+ val locations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.NODE_LOCAL)
+ val attemptLocs = taskAttempts(index).map(_.hostPort)
+ (locations.size == 0 || locations.contains(hostPort)) && !attemptLocs.contains(hostPort)
+ }
+
+ if (localTask != None) {
+ speculatableTasks -= localTask.get
+ return localTask
+ }
+
+ // check for rack locality
+ if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
+ val rackTask = speculatableTasks.find {
+ index =>
+ val locations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.RACK_LOCAL)
+ val attemptLocs = taskAttempts(index).map(_.hostPort)
+ locations.contains(hostPort) && !attemptLocs.contains(hostPort)
+ }
+
+ if (rackTask != None) {
+ speculatableTasks -= rackTask.get
+ return rackTask
+ }
+ }
+
+ // Any task ...
+ if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
+ // Only requirement here: no previous attempt of the task used this host:port.
+ val nonLocalTask = speculatableTasks.find(i => !taskAttempts(i).map(_.hostPort).contains(hostPort))
+ if (nonLocalTask != None) {
+ speculatableTasks -= nonLocalTask.get
+ return nonLocalTask
+ }
+ }
+ }
+ return None
+ }
+
+ // Dequeue a pending task for a given host:port and return its index, or None if nothing fits.
+ // Lists are consulted from most to least local: process local, host local, rack local
+ // (only if the locality constraint allows RACK_LOCAL), tasks with no preferences, then all
+ // pending tasks (only if the constraint allows ANY). If every list comes up empty, a
+ // speculative task is looked for.
+ // Note: TaskLocality.isAllowed asserts that locality is not PROCESS_LOCAL.
+ private def findTask(hostPort: String, locality: TaskLocality.TaskLocality): Option[Int] = {
+ val processLocalTask = findTaskFromList(getPendingTasksForHostPort(hostPort))
+ if (processLocalTask != None) {
+ return processLocalTask
+ }
+
+ val localTask = findTaskFromList(getPendingTasksForHost(hostPort))
+ if (localTask != None) {
+ return localTask
+ }
+
+ if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
+ val rackLocalTask = findTaskFromList(getRackLocalPendingTasksForHost(hostPort))
+ if (rackLocalTask != None) {
+ return rackLocalTask
+ }
+ }
+
+ // Look for no-pref tasks AFTER rack local tasks - this has the side effect that we will get to failed tasks later rather than sooner.
+ // TODO: That code path needs to be revisited (adding to no prefs list when host:port goes down).
+ val noPrefTask = findTaskFromList(pendingTasksWithNoPrefs)
+ if (noPrefTask != None) {
+ return noPrefTask
+ }
+
+ if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
+ val nonLocalTask = findTaskFromList(allPendingTasks)
+ if (nonLocalTask != None) {
+ return nonLocalTask
+ }
+ }
+
+ // Finally, if all else has failed, find a speculative task
+ return findSpeculativeTask(hostPort, locality)
+ }
+
+  // True when `hostPort` itself appears among the task's preferred locations.
+  private def isProcessLocalLocation(task: Task[_], hostPort: String): Boolean = {
+    Utils.checkHostPort(hostPort)
+    task.preferredLocations.contains(hostPort)
+  }
+
+  // True when some preferred location of the task is on the same host as `hostPort`.
+  // A task without any preferred location is treated as host local everywhere.
+  private def isHostLocalLocation(task: Task[_], hostPort: String): Boolean = {
+    val preferred = task.preferredLocations
+    if (preferred.isEmpty) {
+      true
+    } else {
+      val offeredHost = Utils.parseHostPort(hostPort)._1
+      preferred.exists(loc => Utils.parseHostPort(loc)._1 == offeredHost)
+    }
+  }
+
+ // Does a host count as a rack local preferred location for a task? (assumes host is NOT preferred location).
+ // This is true only if the scheduler knows the rack of hostPort and that rack is also the
+ // rack of at least one of the task's preferred locations.
+ // If no rack can be resolved for any preferred location (including the case of a task with
+ // no preferred locations), the answer is false.
+ private def isRackLocalLocation(task: Task[_], hostPort: String): Boolean = {
+
+ val locs = task.preferredLocations
+
+ // Collect the racks of all preferred locations for which the scheduler has rack information.
+ val preferredRacks = new HashSet[String]()
+ for (preferredHost <- locs) {
+ val rack = sched.getRackForHost(preferredHost)
+ if (None != rack) preferredRacks += rack.get
+ }
+
+ if (preferredRacks.isEmpty) return false
+
+ val hostRack = sched.getRackForHost(hostPort)
+
+ return None != hostRack && preferredRacks.contains(hostRack.get)
+ }
+
+ // Respond to an offer of a single slave from the scheduler by finding a task.
+ // Returns a TaskDescription carrying the serialized task, or None when all tasks are
+ // finished, the offer has fewer than CPUS_PER_TASK cpus, or no task fits the locality level.
+ // overrideLocality, when non-null, is used as the locality constraint; otherwise the
+ // constraint is NODE_LOCAL until LOCALITY_WAIT ms have passed since lastPreferredLaunchTime,
+ // and ANY after that.
+ def slaveOffer(execId: String, hostPort: String, availableCpus: Double, overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription] = {
+
+ if (tasksFinished < numTasks && availableCpus >= CPUS_PER_TASK) {
+ // If explicitly specified, use that
+ val locality = if (overrideLocality != null) overrideLocality else {
+ // expand only if we have waited for more than LOCALITY_WAIT for a host local task ...
+ val time = System.currentTimeMillis
+ if (time - lastPreferredLaunchTime < LOCALITY_WAIT) TaskLocality.NODE_LOCAL else TaskLocality.ANY
+ }
+
+ findTask(hostPort, locality) match {
+ case Some(index) => {
+ // Found a task; do some bookkeeping and return a task description for it
+ val task = tasks(index)
+ val taskId = sched.newTaskId()
+ // Figure out the locality level actually achieved by this launch (most local level first)
+ val taskLocality =
+ if (isProcessLocalLocation(task, hostPort)) TaskLocality.PROCESS_LOCAL else
+ if (isHostLocalLocation(task, hostPort)) TaskLocality.NODE_LOCAL else
+ if (isRackLocalLocation(task, hostPort)) TaskLocality.RACK_LOCAL else
+ TaskLocality.ANY
+ val prefStr = taskLocality.toString
+ logInfo("Starting task %s:%d as TID %s on slave %s: %s (%s)".format(
+ taskSet.id, index, taskId, execId, hostPort, prefStr))
+ // Do various bookkeeping
+ copiesRunning(index) += 1
+ val time = System.currentTimeMillis
+ val info = new TaskInfo(taskId, index, time, execId, hostPort, taskLocality)
+ taskInfos(taskId) = info
+ taskAttempts(index) = info :: taskAttempts(index)
+ // NOTE(review): only NODE_LOCAL launches refresh the delay-scheduling timer; a
+ // PROCESS_LOCAL launch does not. Confirm whether that is intended.
+ if (TaskLocality.NODE_LOCAL == taskLocality) {
+ lastPreferredLaunchTime = time
+ }
+ // Serialize and return the task
+ val startTime = System.currentTimeMillis
+ val serializedTask = Task.serializeWithDependencies(
+ task, sched.sc.addedFiles, sched.sc.addedJars, ser)
+ val timeTaken = System.currentTimeMillis - startTime
+ increaseRunningTasks(1)
+ logInfo("Serialized task %s:%d as %d bytes in %d ms".format(
+ taskSet.id, index, serializedTask.limit, timeTaken))
+ val taskName = "task %s:%d".format(taskSet.id, index)
+ return Some(new TaskDescription(taskId, execId, taskName, serializedTask))
+ }
+ case _ =>
+ }
+ }
+ return None
+ }
+
+  /**
+   * Dispatches a status change for task attempt `tid`: FINISHED is handled by taskFinished;
+   * LOST, FAILED and KILLED are handled by taskLost; every other state is ignored.
+   */
+  def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
+    state match {
+      case TaskState.FINISHED =>
+        taskFinished(tid, state, serializedData)
+      case TaskState.LOST | TaskState.FAILED | TaskState.KILLED =>
+        taskLost(tid, state, serializedData)
+      case _ =>
+    }
+  }
+
+  /**
+   * Handles a FINISHED status update for task attempt `tid`.
+   *
+   * Updates for attempts already marked failed are dropped. Otherwise the attempt is marked
+   * successful and the running count is decremented. For the first attempt of a task to
+   * finish, the result is deserialized and reported to the scheduler's listener, and the
+   * scheduler is told when the whole task set is done. Later attempts of an already finished
+   * task are only logged.
+   *
+   * A ClassNotFoundException raised while delivering the result is rethrown as a
+   * SparkException naming the context class loader.
+   */
+  def taskFinished(tid: Long, state: TaskState, serializedData: ByteBuffer) {
+    val info = taskInfos(tid)
+    // Duplicate/late messages can arrive (e.g. in coarse-grained Mesos mode, or from Mesos
+    // itself when acks get delayed); an attempt already marked failed is left alone.
+    if (!info.failed) {
+      val index = info.index
+      info.markSuccessful()
+      decreaseRunningTasks(1)
+      if (finished(index)) {
+        logInfo("Ignoring task-finished event for TID " + tid +
+          " because task " + index + " is already finished")
+      } else {
+        tasksFinished += 1
+        logInfo("Finished TID %s in %d ms (progress: %d/%d)".format(
+          tid, info.duration, tasksFinished, numTasks))
+        // Deserialize the task result and hand it to the scheduler's listener.
+        try {
+          val result = ser.deserialize[TaskResult[_]](serializedData)
+          result.metrics.resultSize = serializedData.limit()
+          sched.listener.taskEnded(tasks(index), Success, result.value, result.accumUpdates, info, result.metrics)
+        } catch {
+          case cnf: ClassNotFoundException =>
+            val loader = Thread.currentThread().getContextClassLoader
+            throw new SparkException("ClassNotFound with classloader: " + loader, cnf)
+        }
+        // Mark finished, and stop once every task in the set has finished.
+        finished(index) = true
+        if (tasksFinished == numTasks) {
+          sched.taskSetFinished(this)
+        }
+      }
+    }
+  }
+
+  /**
+   * Handles a LOST, FAILED or KILLED status update for task attempt `tid`.
+   *
+   * Updates for attempts already marked failed are dropped. Otherwise the attempt is marked
+   * failed and the running count is decremented. If the task has not finished through another
+   * attempt, the failure reason (when `serializedData` is non-empty) decides what happens:
+   *
+   *  - FetchFailed: reported to the listener and the task set is finished immediately.
+   *  - TaskResultTooBigFailure: the task set is aborted.
+   *  - ExceptionFailure: logged, with repeats of the same description inside
+   *    EXCEPTION_PRINT_INTERVAL logged in short form.
+   *
+   * In every other case the task is re-queued; FAILED and LOST (but not KILLED) count towards
+   * MAX_TASK_FAILURES, and exceeding it aborts the task set.
+   *
+   * @param tid            id of the lost task attempt
+   * @param state          the reported state (LOST, FAILED or KILLED)
+   * @param serializedData serialized TaskEndReason; may be null or empty
+   */
+  def taskLost(tid: Long, state: TaskState, serializedData: ByteBuffer) {
+    val info = taskInfos(tid)
+    if (info.failed) {
+      // We might get two task-lost messages for the same task in coarse-grained Mesos mode,
+      // or even from Mesos itself when acks get delayed.
+      return
+    }
+    val index = info.index
+    info.markFailed()
+    decreaseRunningTasks(1)
+    if (!finished(index)) {
+      logInfo("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index))
+      copiesRunning(index) -= 1
+      // Check if the problem is a map output fetch failure. In that case, this
+      // task will never succeed on any node, so tell the scheduler about it.
+      if (serializedData != null && serializedData.limit() > 0) {
+        val reason = ser.deserialize[TaskEndReason](serializedData, getClass.getClassLoader)
+        reason match {
+          case fetchFailed: FetchFailed =>
+            logInfo("Loss was due to fetch failure from " + fetchFailed.bmAddress)
+            sched.listener.taskEnded(tasks(index), fetchFailed, null, null, info, null)
+            finished(index) = true
+            tasksFinished += 1
+            sched.taskSetFinished(this)
+            decreaseRunningTasks(runningTasks)
+            return
+
+          case taskResultTooBig: TaskResultTooBigFailure =>
+            // The concatenation is parenthesised so that format() applies to the whole
+            // message; a method call binds tighter than `+`, so without the parentheses
+            // the %s placeholder would be logged unsubstituted.
+            logInfo(("Loss was due to task %s result exceeding Akka frame size; " +
+              "aborting job").format(tid))
+            abort("Task %s result exceeded Akka frame size".format(tid))
+            return
+
+          case ef: ExceptionFailure =>
+            val key = ef.description
+            val now = System.currentTimeMillis
+            // Print the full trace the first time a description is seen and again once
+            // EXCEPTION_PRINT_INTERVAL has elapsed; in between, only count duplicates.
+            val (printFull, dupCount) = recentExceptions.get(key) match {
+              case Some((count, printTime)) if now - printTime <= EXCEPTION_PRINT_INTERVAL =>
+                recentExceptions(key) = (count + 1, printTime)
+                (false, count + 1)
+              case _ =>
+                recentExceptions(key) = (0, now)
+                (true, 0)
+            }
+            if (printFull) {
+              val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString))
+              logInfo("Loss was due to %s\n%s\n%s".format(
+                ef.className, ef.description, locs.mkString("\n")))
+            } else {
+              logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount))
+            }
+
+          case _ => {}
+        }
+      }
+      // On non-fetch failures, re-enqueue the task as pending for a max number of retries
+      addPendingTask(index)
+      // Count failed attempts only on FAILED and LOST state (not on KILLED)
+      if (state == TaskState.FAILED || state == TaskState.LOST) {
+        numFailures(index) += 1
+        if (numFailures(index) > MAX_TASK_FAILURES) {
+          logError("Task %s:%d failed more than %d times; aborting job".format(
+            taskSet.id, index, MAX_TASK_FAILURES))
+          abort("Task %s:%d failed more than %d times".format(taskSet.id, index, MAX_TASK_FAILURES))
+        }
+      }
+    } else {
+      logInfo("Ignoring task-lost event for TID " + tid +
+        " because task " + index + " is already finished")
+    }
+  }
+
+ // Report an error for this task set: aborts it with the message prefixed by "Error: "
+ def error(message: String) {
+ // abort() records the message as the cause of failure
+ abort("Error: " + message)
+ }
+
+ // Abort the task set: record the failure and its cause, notify the scheduler's listener,
+ // release every running-task slot still counted for this manager (and its parent), and
+ // finally tell the scheduler that this task set is finished.
+ def abort(message: String) {
+ failed = true
+ causeOfFailure = message
+ // TODO: Kill running tasks if we were not terminated due to a Mesos error
+ sched.listener.taskSetFailed(taskSet, message)
+ // Subtracting the current count brings runningTasks back to zero
+ decreaseRunningTasks(runningTasks)
+ sched.taskSetFinished(this)
+ }
+
+  // Adds taskNum to this manager's running-task count and propagates the same increase to
+  // the parent schedulable, when one is set.
+  override def increaseRunningTasks(taskNum: Int) {
+    runningTasks += taskNum
+    Option(parent).foreach(_.increaseRunningTasks(taskNum))
+  }
+
+  // Subtracts taskNum from this manager's running-task count and propagates the same
+  // decrease to the parent schedulable, when one is set.
+  override def decreaseRunningTasks(taskNum: Int) {
+    runningTasks -= taskNum
+    Option(parent).foreach(_.decreaseRunningTasks(taskNum))
+  }
+
+  // TODO: for now lookups only find Pools, not TaskSetManagers; this can be extended in the
+  // future if needed. Always returns null.
+  override def getSchedulableByName(name: String): Schedulable = null
+
+  // Intentionally does nothing.
+  override def addSchedulable(schedulable: Schedulable) {}
+
+  // Intentionally does nothing.
+  override def removeSchedulable(schedulable: Schedulable) {}
+
+  // Returns a new single-element queue containing only this manager.
+  override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = {
+    ArrayBuffer[TaskSetManager](this)
+  }
+
+ // Called when executor execId on hostPort is lost. Does three things, in order:
+ //  1. tasks pending for that host which no longer have any alive node-local location are
+ //     also added to the no-prefs list;
+ //  2. for a shuffle map stage, tasks that already finished on the lost executor are marked
+ //     unfinished, re-queued and reported to the listener as Resubmitted;
+ //  3. attempts still running on the lost executor are treated as lost, with state KILLED.
+ override def executorLost(execId: String, hostPort: String) {
+ logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id)
+
+ // If some task has preferred locations only on hostname, and there are no more executors there,
+ // put it in the no-prefs list to avoid the wait from delay scheduling
+
+ // host local tasks - should we push this to rack local or no pref list ? For now, preserving behavior and moving to
+ // no prefs list. Note, this was done due to implications related to 'waiting' for data local tasks, etc.
+ // Note: NOT checking process local list - since host local list is a superset of that. We need to add to no prefs only if
+ // there is no host local node for the task (not if there is no process local node for the task)
+ // (getPendingTasksForHost parses its argument again, so the host is run through parseHostPort twice here.)
+ for (index <- getPendingTasksForHost(Utils.parseHostPort(hostPort)._1)) {
+ // val newLocs = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.RACK_LOCAL)
+ val newLocs = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.NODE_LOCAL)
+ if (newLocs.isEmpty) {
+ pendingTasksWithNoPrefs += index
+ }
+ }
+
+ // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage
+ // (the stage type is decided by looking at the first task only)
+ if (tasks(0).isInstanceOf[ShuffleMapTask]) {
+ for ((tid, info) <- taskInfos if info.executorId == execId) {
+ val index = taskInfos(tid).index
+ if (finished(index)) {
+ finished(index) = false
+ copiesRunning(index) -= 1
+ tasksFinished -= 1
+ addPendingTask(index)
+ // Tell the DAGScheduler that this task was resubmitted so that it doesn't think our
+ // stage finishes when a total of tasks.size tasks finish.
+ sched.listener.taskEnded(tasks(index), Resubmitted, null, null, info, null)
+ }
+ }
+ }
+ // Also re-enqueue any tasks that were running on the node
+ for ((tid, info) <- taskInfos if info.running && info.executorId == execId) {
+ taskLost(tid, TaskState.KILLED, null)
+ }
+ }
+
+  /**
+   * Check for tasks to be speculated and return true if there are any. This is called periodically
+   * by the ClusterScheduler.
+   *
+   * Speculation is only considered once at least SPECULATION_QUANTILE of the tasks have
+   * finished and at least one successful duration has been recorded. A task is marked
+   * speculatable when it is unfinished, has exactly one copy running, and an attempt of it
+   * has been running longer than max(SPECULATION_MULTIPLIER * median successful duration, 100) ms.
+   *
+   * TODO: To make this scale to large jobs, we need to maintain a list of running tasks, so that
+   * we don't scan the whole task set. It might also help to make this sorted by launch time.
+   *
+   * @return true if at least one task was newly added to speculatableTasks
+   */
+  override def checkSpeculatableTasks(): Boolean = {
+    // Can't speculate if we only have one task, or if all tasks have finished.
+    if (numTasks == 1 || tasksFinished == numTasks) {
+      return false
+    }
+    var foundTasks = false
+    val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt
+    logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation)
+    if (tasksFinished >= minFinishedForSpeculation) {
+      val time = System.currentTimeMillis()
+      val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray
+      // minFinishedForSpeculation can round down to 0, in which case there may be no
+      // successful attempt yet; without a duration sample there is no threshold to compare to.
+      if (durations.length > 0) {
+        Arrays.sort(durations)
+        // Median of the recorded successful durations. The index is derived from the sample
+        // size itself (not from numTasks), so it is always in range and is not skewed
+        // towards the slowest attempts while the task set is still running.
+        val medianDuration = durations(durations.length / 2)
+        val threshold = max(SPECULATION_MULTIPLIER * medianDuration, 100)
+        // TODO: Threshold should also look at standard deviation of task durations and have a lower
+        // bound based on that.
+        logDebug("Task length threshold for speculation: " + threshold)
+        for ((tid, info) <- taskInfos) {
+          val index = info.index
+          if (!finished(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold &&
+            !speculatableTasks.contains(index)) {
+            logInfo(
+              "Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format(
+                taskSet.id, index, info.hostPort, threshold))
+            speculatableTasks += index
+            foundTasks = true
+          }
+        }
+      }
+    }
+    foundTasks
+  }
+
+  // True while the task set is non-empty and not all of its tasks have finished.
+  override def hasPendingTasks(): Boolean = {
+    val allDone = tasksFinished >= numTasks
+    numTasks > 0 && !allDone
+  }
+}
diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
index f1c6266bac..b4dd75d90f 100644
--- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
@@ -1,747 +1,17 @@
package spark.scheduler.cluster
-import java.util.{HashMap => JHashMap, NoSuchElementException, Arrays}
-
import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
-import scala.collection.mutable.HashSet
-import scala.math.max
-import scala.math.min
-
-import spark._
import spark.scheduler._
import spark.TaskState.TaskState
import java.nio.ByteBuffer
-private[spark] object TaskLocality extends Enumeration("PROCESS_LOCAL", "NODE_LOCAL", "RACK_LOCAL", "ANY") with Logging {
-
- // process local is expected to be used ONLY within tasksetmanager for now.
- val PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY = Value
-
- type TaskLocality = Value
-
- def isAllowed(constraint: TaskLocality, condition: TaskLocality): Boolean = {
-
- // Must not be the constraint.
- assert (constraint != TaskLocality.PROCESS_LOCAL)
-
- constraint match {
- case TaskLocality.NODE_LOCAL => condition == TaskLocality.NODE_LOCAL
- case TaskLocality.RACK_LOCAL => condition == TaskLocality.NODE_LOCAL || condition == TaskLocality.RACK_LOCAL
- // For anything else, allow
- case _ => true
- }
- }
-
- def parse(str: String): TaskLocality = {
- // better way to do this ?
- try {
- val retval = TaskLocality.withName(str)
- // Must not specify PROCESS_LOCAL !
- assert (retval != TaskLocality.PROCESS_LOCAL)
-
- retval
- } catch {
- case nEx: NoSuchElementException => {
- logWarning("Invalid task locality specified '" + str + "', defaulting to NODE_LOCAL");
- // default to preserve earlier behavior
- NODE_LOCAL
- }
- }
- }
-}
-
-/**
- * Schedules the tasks within a single TaskSet in the ClusterScheduler.
- */
-private[spark] class TaskSetManager(
- sched: ClusterScheduler,
- val taskSet: TaskSet)
- extends Schedulable
- with Logging {
-
- // Maximum time to wait to run a task in a preferred location (in ms)
- val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong
-
- // CPUs to request per task
- val CPUS_PER_TASK = System.getProperty("spark.task.cpus", "1").toDouble
-
- // Maximum times a task is allowed to fail before failing the job
- val MAX_TASK_FAILURES = 4
-
- // Quantile of tasks at which to start speculation
- val SPECULATION_QUANTILE = System.getProperty("spark.speculation.quantile", "0.75").toDouble
- val SPECULATION_MULTIPLIER = System.getProperty("spark.speculation.multiplier", "1.5").toDouble
-
- // Serializer for closures and tasks.
- val ser = SparkEnv.get.closureSerializer.newInstance()
-
- val tasks = taskSet.tasks
- val numTasks = tasks.length
- val copiesRunning = new Array[Int](numTasks)
- val finished = new Array[Boolean](numTasks)
- val numFailures = new Array[Int](numTasks)
- val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil)
- var tasksFinished = 0
-
- var weight = 1
- var minShare = 0
- var runningTasks = 0
- var priority = taskSet.priority
- var stageId = taskSet.stageId
- var name = "TaskSet_"+taskSet.stageId.toString
- var parent:Schedulable = null
-
- // Last time when we launched a preferred task (for delay scheduling)
- var lastPreferredLaunchTime = System.currentTimeMillis
-
- // List of pending tasks for each node (process local to container). These collections are actually
- // treated as stacks, in which new tasks are added to the end of the
- // ArrayBuffer and removed from the end. This makes it faster to detect
- // tasks that repeatedly fail because whenever a task failed, it is put
- // back at the head of the stack. They are also only cleaned up lazily;
- // when a task is launched, it remains in all the pending lists except
- // the one that it was launched from, but gets removed from them later.
- private val pendingTasksForHostPort = new HashMap[String, ArrayBuffer[Int]]
-
- // List of pending tasks for each node.
- // Essentially, similar to pendingTasksForHostPort, except at host level
- private val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]]
-
- // List of pending tasks for each node based on rack locality.
- // Essentially, similar to pendingTasksForHost, except at rack level
- private val pendingRackLocalTasksForHost = new HashMap[String, ArrayBuffer[Int]]
-
- // List containing pending tasks with no locality preferences
- val pendingTasksWithNoPrefs = new ArrayBuffer[Int]
-
- // List containing all pending tasks (also used as a stack, as above)
- val allPendingTasks = new ArrayBuffer[Int]
-
- // Tasks that can be speculated. Since these will be a small fraction of total
- // tasks, we'll just hold them in a HashSet.
- val speculatableTasks = new HashSet[Int]
-
- // Task index, start and finish time for each task attempt (indexed by task ID)
- val taskInfos = new HashMap[Long, TaskInfo]
-
- // Did the job fail?
- var failed = false
- var causeOfFailure = ""
-
- // How frequently to reprint duplicate exceptions in full, in milliseconds
- val EXCEPTION_PRINT_INTERVAL =
- System.getProperty("spark.logging.exceptionPrintInterval", "10000").toLong
- // Map of recent exceptions (identified by string representation and
- // top stack frame) to duplicate count (how many times the same
- // exception has appeared) and time the full exception was
- // printed. This should ideally be an LRU map that can drop old
- // exceptions automatically.
- val recentExceptions = HashMap[String, (Int, Long)]()
-
- // Figure out the current map output tracker generation and set it on all tasks
- val generation = sched.mapOutputTracker.getGeneration
- logDebug("Generation for " + taskSet.id + ": " + generation)
- for (t <- tasks) {
- t.generation = generation
- }
-
- // Add all our tasks to the pending lists. We do this in reverse order
- // of task index so that tasks with low indices get launched first.
- for (i <- (0 until numTasks).reverse) {
- addPendingTask(i)
- }
-
- // Note that it follows the hierarchy.
- // if we search for NODE_LOCAL, the output will include PROCESS_LOCAL and
- // if we search for RACK_LOCAL, it will include PROCESS_LOCAL & NODE_LOCAL
- private def findPreferredLocations(_taskPreferredLocations: Seq[String], scheduler: ClusterScheduler,
- taskLocality: TaskLocality.TaskLocality): HashSet[String] = {
-
- if (TaskLocality.PROCESS_LOCAL == taskLocality) {
- // straight forward comparison ! Special case it.
- val retval = new HashSet[String]()
- scheduler.synchronized {
- for (location <- _taskPreferredLocations) {
- if (scheduler.isExecutorAliveOnHostPort(location)) {
- retval += location
- }
- }
- }
-
- return retval
- }
-
- val taskPreferredLocations =
- if (TaskLocality.NODE_LOCAL == taskLocality) {
- _taskPreferredLocations
- } else {
- assert (TaskLocality.RACK_LOCAL == taskLocality)
- // Expand set to include all 'seen' rack local hosts.
- // This works since container allocation/management happens within master - so any rack locality information is updated in msater.
- // Best case effort, and maybe sort of kludge for now ... rework it later ?
- val hosts = new HashSet[String]
- _taskPreferredLocations.foreach(h => {
- val rackOpt = scheduler.getRackForHost(h)
- if (rackOpt.isDefined) {
- val hostsOpt = scheduler.getCachedHostsForRack(rackOpt.get)
- if (hostsOpt.isDefined) {
- hosts ++= hostsOpt.get
- }
- }
-
- // Ensure that irrespective of what scheduler says, host is always added !
- hosts += h
- })
-
- hosts
- }
-
- val retval = new HashSet[String]
- scheduler.synchronized {
- for (prefLocation <- taskPreferredLocations) {
- val aliveLocationsOpt = scheduler.getExecutorsAliveOnHost(Utils.parseHostPort(prefLocation)._1)
- if (aliveLocationsOpt.isDefined) {
- retval ++= aliveLocationsOpt.get
- }
- }
- }
-
- retval
- }
-
- // Add a task to all the pending-task lists that it should be on.
- private def addPendingTask(index: Int) {
- // We can infer hostLocalLocations from rackLocalLocations by joining it against tasks(index).preferredLocations (with appropriate
- // hostPort <-> host conversion). But not doing it for simplicity sake. If this becomes a performance issue, modify it.
- val processLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.PROCESS_LOCAL)
- val hostLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.NODE_LOCAL)
- val rackLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.RACK_LOCAL)
-
- if (rackLocalLocations.size == 0) {
- // Current impl ensures this.
- assert (processLocalLocations.size == 0)
- assert (hostLocalLocations.size == 0)
- pendingTasksWithNoPrefs += index
- } else {
-
- // process local locality
- for (hostPort <- processLocalLocations) {
- // DEBUG Code
- Utils.checkHostPort(hostPort)
-
- val hostPortList = pendingTasksForHostPort.getOrElseUpdate(hostPort, ArrayBuffer())
- hostPortList += index
- }
-
- // host locality (includes process local)
- for (hostPort <- hostLocalLocations) {
- // DEBUG Code
- Utils.checkHostPort(hostPort)
-
- val host = Utils.parseHostPort(hostPort)._1
- val hostList = pendingTasksForHost.getOrElseUpdate(host, ArrayBuffer())
- hostList += index
- }
-
- // rack locality (includes process local and host local)
- for (rackLocalHostPort <- rackLocalLocations) {
- // DEBUG Code
- Utils.checkHostPort(rackLocalHostPort)
-
- val rackLocalHost = Utils.parseHostPort(rackLocalHostPort)._1
- val list = pendingRackLocalTasksForHost.getOrElseUpdate(rackLocalHost, ArrayBuffer())
- list += index
- }
- }
-
- allPendingTasks += index
- }
-
- // Return the pending tasks list for a given host port (process local), or an empty list if
- // there is no map entry for that host
- private def getPendingTasksForHostPort(hostPort: String): ArrayBuffer[Int] = {
- // DEBUG Code
- Utils.checkHostPort(hostPort)
- pendingTasksForHostPort.getOrElse(hostPort, ArrayBuffer())
- }
-
- // Return the pending tasks list for a given host, or an empty list if
- // there is no map entry for that host
- private def getPendingTasksForHost(hostPort: String): ArrayBuffer[Int] = {
- val host = Utils.parseHostPort(hostPort)._1
- pendingTasksForHost.getOrElse(host, ArrayBuffer())
- }
-
- // Return the pending tasks (rack level) list for a given host, or an empty list if
- // there is no map entry for that host
- private def getRackLocalPendingTasksForHost(hostPort: String): ArrayBuffer[Int] = {
- val host = Utils.parseHostPort(hostPort)._1
- pendingRackLocalTasksForHost.getOrElse(host, ArrayBuffer())
- }
-
- // Number of pending tasks for a given host Port (which would be process local)
- def numPendingTasksForHostPort(hostPort: String): Int = {
- getPendingTasksForHostPort(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) )
- }
-
- // Number of pending tasks for a given host (which would be data local)
- def numPendingTasksForHost(hostPort: String): Int = {
- getPendingTasksForHost(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) )
- }
-
- // Number of pending rack local tasks for a given host
- def numRackLocalPendingTasksForHost(hostPort: String): Int = {
- getRackLocalPendingTasksForHost(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) )
- }
-
-
- // Dequeue a pending task from the given list and return its index.
- // Return None if the list is empty.
- // This method also cleans up any tasks in the list that have already
- // been launched, since we want that to happen lazily.
- private def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = {
- while (!list.isEmpty) {
- val index = list.last
- list.trimEnd(1)
- if (copiesRunning(index) == 0 && !finished(index)) {
- return Some(index)
- }
- }
- return None
- }
-
- // Return a speculative task for a given host if any are available. The task should not have an
- // attempt running on this host, in case the host is slow. In addition, if locality is set, the
- // task must have a preference for this host/rack/no preferred locations at all.
- private def findSpeculativeTask(hostPort: String, locality: TaskLocality.TaskLocality): Option[Int] = {
-
- assert (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL))
- speculatableTasks.retain(index => !finished(index)) // Remove finished tasks from set
-
- if (speculatableTasks.size > 0) {
- val localTask = speculatableTasks.find {
- index =>
- val locations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.NODE_LOCAL)
- val attemptLocs = taskAttempts(index).map(_.hostPort)
- (locations.size == 0 || locations.contains(hostPort)) && !attemptLocs.contains(hostPort)
- }
-
- if (localTask != None) {
- speculatableTasks -= localTask.get
- return localTask
- }
-
- // check for rack locality
- if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
- val rackTask = speculatableTasks.find {
- index =>
- val locations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.RACK_LOCAL)
- val attemptLocs = taskAttempts(index).map(_.hostPort)
- locations.contains(hostPort) && !attemptLocs.contains(hostPort)
- }
-
- if (rackTask != None) {
- speculatableTasks -= rackTask.get
- return rackTask
- }
- }
-
- // Any task ...
- if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
- // Check for attemptLocs also ?
- val nonLocalTask = speculatableTasks.find(i => !taskAttempts(i).map(_.hostPort).contains(hostPort))
- if (nonLocalTask != None) {
- speculatableTasks -= nonLocalTask.get
- return nonLocalTask
- }
- }
- }
- return None
- }
-
- // Dequeue a pending task for a given node and return its index.
- // If localOnly is set to false, allow non-local tasks as well.
- private def findTask(hostPort: String, locality: TaskLocality.TaskLocality): Option[Int] = {
- val processLocalTask = findTaskFromList(getPendingTasksForHostPort(hostPort))
- if (processLocalTask != None) {
- return processLocalTask
- }
-
- val localTask = findTaskFromList(getPendingTasksForHost(hostPort))
- if (localTask != None) {
- return localTask
- }
-
- if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
- val rackLocalTask = findTaskFromList(getRackLocalPendingTasksForHost(hostPort))
- if (rackLocalTask != None) {
- return rackLocalTask
- }
- }
-
- // Look for no pref tasks AFTER rack local tasks - this has side effect that we will get to failed tasks later rather than sooner.
- // TODO: That code path needs to be revisited (adding to no prefs list when host:port goes down).
- val noPrefTask = findTaskFromList(pendingTasksWithNoPrefs)
- if (noPrefTask != None) {
- return noPrefTask
- }
-
- if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
- val nonLocalTask = findTaskFromList(allPendingTasks)
- if (nonLocalTask != None) {
- return nonLocalTask
- }
- }
-
- // Finally, if all else has failed, find a speculative task
- return findSpeculativeTask(hostPort, locality)
- }
-
- private def isProcessLocalLocation(task: Task[_], hostPort: String): Boolean = {
- Utils.checkHostPort(hostPort)
-
- val locs = task.preferredLocations
-
- locs.contains(hostPort)
- }
-
- private def isHostLocalLocation(task: Task[_], hostPort: String): Boolean = {
- val locs = task.preferredLocations
-
- // If no preference, consider it as host local
- if (locs.isEmpty) return true
-
- val host = Utils.parseHostPort(hostPort)._1
- locs.find(h => Utils.parseHostPort(h)._1 == host).isDefined
- }
-
- // Does a host count as a rack local preferred location for a task? (assumes host is NOT preferred location).
- // This is true if either the task has preferred locations and this host is one, or it has
- // no preferred locations (in which we still count the launch as preferred).
- private def isRackLocalLocation(task: Task[_], hostPort: String): Boolean = {
-
- val locs = task.preferredLocations
-
- val preferredRacks = new HashSet[String]()
- for (preferredHost <- locs) {
- val rack = sched.getRackForHost(preferredHost)
- if (None != rack) preferredRacks += rack.get
- }
-
- if (preferredRacks.isEmpty) return false
-
- val hostRack = sched.getRackForHost(hostPort)
-
- return None != hostRack && preferredRacks.contains(hostRack.get)
- }
-
- // Respond to an offer of a single slave from the scheduler by finding a task
- def slaveOffer(execId: String, hostPort: String, availableCpus: Double, overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription] = {
-
- if (tasksFinished < numTasks && availableCpus >= CPUS_PER_TASK) {
- // If explicitly specified, use that
- val locality = if (overrideLocality != null) overrideLocality else {
- // expand only if we have waited for more than LOCALITY_WAIT for a host local task ...
- val time = System.currentTimeMillis
- if (time - lastPreferredLaunchTime < LOCALITY_WAIT) TaskLocality.NODE_LOCAL else TaskLocality.ANY
- }
-
- findTask(hostPort, locality) match {
- case Some(index) => {
- // Found a task; do some bookkeeping and return a Mesos task for it
- val task = tasks(index)
- val taskId = sched.newTaskId()
- // Figure out whether this should count as a preferred launch
- val taskLocality =
- if (isProcessLocalLocation(task, hostPort)) TaskLocality.PROCESS_LOCAL else
- if (isHostLocalLocation(task, hostPort)) TaskLocality.NODE_LOCAL else
- if (isRackLocalLocation(task, hostPort)) TaskLocality.RACK_LOCAL else
- TaskLocality.ANY
- val prefStr = taskLocality.toString
- logInfo("Starting task %s:%d as TID %s on slave %s: %s (%s)".format(
- taskSet.id, index, taskId, execId, hostPort, prefStr))
- // Do various bookkeeping
- copiesRunning(index) += 1
- val time = System.currentTimeMillis
- val info = new TaskInfo(taskId, index, time, execId, hostPort, taskLocality)
- taskInfos(taskId) = info
- taskAttempts(index) = info :: taskAttempts(index)
- if (TaskLocality.NODE_LOCAL == taskLocality) {
- lastPreferredLaunchTime = time
- }
- // Serialize and return the task
- val startTime = System.currentTimeMillis
- val serializedTask = Task.serializeWithDependencies(
- task, sched.sc.addedFiles, sched.sc.addedJars, ser)
- val timeTaken = System.currentTimeMillis - startTime
- increaseRunningTasks(1)
- logInfo("Serialized task %s:%d as %d bytes in %d ms".format(
- taskSet.id, index, serializedTask.limit, timeTaken))
- val taskName = "task %s:%d".format(taskSet.id, index)
- return Some(new TaskDescription(taskId, execId, taskName, serializedTask))
- }
- case _ =>
- }
- }
- return None
- }
-
- def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
- state match {
- case TaskState.FINISHED =>
- taskFinished(tid, state, serializedData)
- case TaskState.LOST =>
- taskLost(tid, state, serializedData)
- case TaskState.FAILED =>
- taskLost(tid, state, serializedData)
- case TaskState.KILLED =>
- taskLost(tid, state, serializedData)
- case _ =>
- }
- }
-
- def taskFinished(tid: Long, state: TaskState, serializedData: ByteBuffer) {
- val info = taskInfos(tid)
- if (info.failed) {
- // We might get two task-lost messages for the same task in coarse-grained Mesos mode,
- // or even from Mesos itself when acks get delayed.
- return
- }
- val index = info.index
- info.markSuccessful()
- decreaseRunningTasks(1)
- if (!finished(index)) {
- tasksFinished += 1
- logInfo("Finished TID %s in %d ms (progress: %d/%d)".format(
- tid, info.duration, tasksFinished, numTasks))
- // Deserialize task result and pass it to the scheduler
- try {
- val result = ser.deserialize[TaskResult[_]](serializedData)
- result.metrics.resultSize = serializedData.limit()
- sched.listener.taskEnded(tasks(index), Success, result.value, result.accumUpdates, info, result.metrics)
- } catch {
- case cnf: ClassNotFoundException =>
- val loader = Thread.currentThread().getContextClassLoader
- throw new SparkException("ClassNotFound with classloader: " + loader, cnf)
- case ex => throw ex
- }
- // Mark finished and stop if we've finished all the tasks
- finished(index) = true
- if (tasksFinished == numTasks) {
- sched.taskSetFinished(this)
- }
- } else {
- logInfo("Ignoring task-finished event for TID " + tid +
- " because task " + index + " is already finished")
- }
- }
-
- def taskLost(tid: Long, state: TaskState, serializedData: ByteBuffer) {
- val info = taskInfos(tid)
- if (info.failed) {
- // We might get two task-lost messages for the same task in coarse-grained Mesos mode,
- // or even from Mesos itself when acks get delayed.
- return
- }
- val index = info.index
- info.markFailed()
- decreaseRunningTasks(1)
- if (!finished(index)) {
- logInfo("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index))
- copiesRunning(index) -= 1
- // Check if the problem is a map output fetch failure. In that case, this
- // task will never succeed on any node, so tell the scheduler about it.
- if (serializedData != null && serializedData.limit() > 0) {
- val reason = ser.deserialize[TaskEndReason](serializedData, getClass.getClassLoader)
- reason match {
- case fetchFailed: FetchFailed =>
- logInfo("Loss was due to fetch failure from " + fetchFailed.bmAddress)
- sched.listener.taskEnded(tasks(index), fetchFailed, null, null, info, null)
- finished(index) = true
- tasksFinished += 1
- sched.taskSetFinished(this)
- decreaseRunningTasks(runningTasks)
- return
-
- case taskResultTooBig: TaskResultTooBigFailure =>
- logInfo("Loss was due to task %s result exceeding Akka frame size;" +
- "aborting job".format(tid))
- abort("Task %s result exceeded Akka frame size".format(tid))
- return
-
- case ef: ExceptionFailure =>
- val key = ef.description
- val now = System.currentTimeMillis
- val (printFull, dupCount) = {
- if (recentExceptions.contains(key)) {
- val (dupCount, printTime) = recentExceptions(key)
- if (now - printTime > EXCEPTION_PRINT_INTERVAL) {
- recentExceptions(key) = (0, now)
- (true, 0)
- } else {
- recentExceptions(key) = (dupCount + 1, printTime)
- (false, dupCount + 1)
- }
- } else {
- recentExceptions(key) = (0, now)
- (true, 0)
- }
- }
- if (printFull) {
- val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString))
- logInfo("Loss was due to %s\n%s\n%s".format(
- ef.className, ef.description, locs.mkString("\n")))
- } else {
- logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount))
- }
-
- case _ => {}
- }
- }
- // On non-fetch failures, re-enqueue the task as pending for a max number of retries
- addPendingTask(index)
- // Count failed attempts only on FAILED and LOST state (not on KILLED)
- if (state == TaskState.FAILED || state == TaskState.LOST) {
- numFailures(index) += 1
- if (numFailures(index) > MAX_TASK_FAILURES) {
- logError("Task %s:%d failed more than %d times; aborting job".format(
- taskSet.id, index, MAX_TASK_FAILURES))
- abort("Task %s:%d failed more than %d times".format(taskSet.id, index, MAX_TASK_FAILURES))
- }
- }
- } else {
- logInfo("Ignoring task-lost event for TID " + tid +
- " because task " + index + " is already finished")
- }
- }
-
- def error(message: String) {
- // Save the error message
- abort("Error: " + message)
- }
-
- def abort(message: String) {
- failed = true
- causeOfFailure = message
- // TODO: Kill running tasks if we were not terminated due to a Mesos error
- sched.listener.taskSetFailed(taskSet, message)
- decreaseRunningTasks(runningTasks)
- sched.taskSetFinished(this)
- }
-
- override def increaseRunningTasks(taskNum: Int) {
- runningTasks += taskNum
- if (parent != null) {
- parent.increaseRunningTasks(taskNum)
- }
- }
-
- override def decreaseRunningTasks(taskNum: Int) {
- runningTasks -= taskNum
- if (parent != null) {
- parent.decreaseRunningTasks(taskNum)
- }
- }
-
- //TODO: for now we just find Pool not TaskSetManager, we can extend this function in future if needed
- override def getSchedulableByName(name: String): Schedulable = {
- return null
- }
-
- override def addSchedulable(schedulable:Schedulable) {
- //nothing
- }
-
- override def removeSchedulable(schedulable:Schedulable) {
- //nothing
- }
-
- override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = {
- var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager]
- sortedTaskSetQueue += this
- return sortedTaskSetQueue
- }
-
- override def executorLost(execId: String, hostPort: String) {
- logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id)
-
- // If some task has preferred locations only on hostname, and there are no more executors there,
- // put it in the no-prefs list to avoid the wait from delay scheduling
-
- // host local tasks - should we push this to rack local or no pref list ? For now, preserving behavior and moving to
- // no prefs list. Note, this was done due to impliations related to 'waiting' for data local tasks, etc.
- // Note: NOT checking process local list - since host local list is super set of that. We need to ad to no prefs only if
- // there is no host local node for the task (not if there is no process local node for the task)
- for (index <- getPendingTasksForHost(Utils.parseHostPort(hostPort)._1)) {
- // val newLocs = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.RACK_LOCAL)
- val newLocs = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.NODE_LOCAL)
- if (newLocs.isEmpty) {
- pendingTasksWithNoPrefs += index
- }
- }
-
- // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage
- if (tasks(0).isInstanceOf[ShuffleMapTask]) {
- for ((tid, info) <- taskInfos if info.executorId == execId) {
- val index = taskInfos(tid).index
- if (finished(index)) {
- finished(index) = false
- copiesRunning(index) -= 1
- tasksFinished -= 1
- addPendingTask(index)
- // Tell the DAGScheduler that this task was resubmitted so that it doesn't think our
- // stage finishes when a total of tasks.size tasks finish.
- sched.listener.taskEnded(tasks(index), Resubmitted, null, null, info, null)
- }
- }
- }
- // Also re-enqueue any tasks that were running on the node
- for ((tid, info) <- taskInfos if info.running && info.executorId == execId) {
- taskLost(tid, TaskState.KILLED, null)
- }
- }
-
- /**
- * Check for tasks to be speculated and return true if there are any. This is called periodically
- * by the ClusterScheduler.
- *
- * TODO: To make this scale to large jobs, we need to maintain a list of running tasks, so that
- * we don't scan the whole task set. It might also help to make this sorted by launch time.
- */
- override def checkSpeculatableTasks(): Boolean = {
- // Can't speculate if we only have one task, or if all tasks have finished.
- if (numTasks == 1 || tasksFinished == numTasks) {
- return false
- }
- var foundTasks = false
- val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt
- logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation)
- if (tasksFinished >= minFinishedForSpeculation) {
- val time = System.currentTimeMillis()
- val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray
- Arrays.sort(durations)
- val medianDuration = durations(min((0.5 * numTasks).round.toInt, durations.size - 1))
- val threshold = max(SPECULATION_MULTIPLIER * medianDuration, 100)
- // TODO: Threshold should also look at standard deviation of task durations and have a lower
- // bound based on that.
- logDebug("Task length threshold for speculation: " + threshold)
- for ((tid, info) <- taskInfos) {
- val index = info.index
- if (!finished(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold &&
- !speculatableTasks.contains(index)) {
- logInfo(
- "Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format(
- taskSet.id, index, info.hostPort, threshold))
- speculatableTasks += index
- foundTasks = true
- }
- }
- }
- return foundTasks
- }
-
- override def hasPendingTasks(): Boolean = {
- numTasks > 0 && tasksFinished < numTasks
- }
+/**
+ * Scheduling interface for the tasks of a single TaskSet. It extends Schedulable, so a manager
+ * also exposes everything that trait requires.
+ *
+ * NOTE(review): the per-method descriptions below state the contract suggested by the signatures
+ * and by how the schedulers call them; concrete managers define the exact semantics.
+ */
+private[spark] trait TaskSetManager extends Schedulable {
+  /** The TaskSet this manager schedules. */
+  def taskSet: TaskSet
+
+  /**
+   * Answers a resource offer for one executor.
+   *
+   * @param execId id of the executor making the offer
+   * @param hostPort location of that executor
+   * @param availableCpus CPUs available in the offer
+   * @param overrideLocality locality level requested by the caller; `null` (the default) means
+   *                         no level is requested
+   * @return a description of the task to launch, or None when nothing is launched
+   */
+  def slaveOffer(execId: String, hostPort: String, availableCpus: Double,
+      overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription]
+
+  /** Number of pending tasks for the given host:port. */
+  def numPendingTasksForHostPort(hostPort: String): Int
+
+  /** Number of rack-local pending tasks for the given host. */
+  def numRackLocalPendingTasksForHost(hostPort: String): Int
+
+  /** Number of pending tasks for the given host. */
+  def numPendingTasksForHost(hostPort: String): Int
+
+  /**
+   * Delivers a state change for task `tid`, together with the serialized payload that
+   * accompanied the update.
+   */
+  def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer): Unit
+
+  /** Reports an error message to this manager. */
+  def error(message: String): Unit
}
diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
index 37a67f9b1b..93d4318b29 100644
--- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
@@ -2,19 +2,50 @@ package spark.scheduler.local
import java.io.File
import java.util.concurrent.atomic.AtomicInteger
+import java.nio.ByteBuffer
+import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
+import scala.collection.mutable.HashSet
import spark._
+import spark.TaskState.TaskState
import spark.executor.ExecutorURLClassLoader
import spark.scheduler._
-import spark.scheduler.cluster.{TaskLocality, TaskInfo}
+import spark.scheduler.cluster._
+import akka.actor._
/**
- * A simple TaskScheduler implementation that runs tasks locally in a thread pool. Optionally
+ * A FIFO or Fair TaskScheduler implementation that runs tasks locally in a thread pool. Optionally
* the scheduler also allows each task to fail up to maxFailures times, which is useful for
* testing fault recovery.
*/
-private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkContext)
+
+// Message handled by LocalActor: request a resource offer from the scheduler using the actor's
+// current free-core count, then launch whatever tasks come back.
+// NOTE(review): LocalActor's pattern and the sender in LocalScheduler.submitTasks both use the
+// bare name `LocalReviveOffers`, i.e. this class's companion object rather than an instance.
+// The two agree with each other, but an instance built with `LocalReviveOffers()` would not
+// match that pattern.
+private[spark] case class LocalReviveOffers()
+// Message handled by LocalActor: task `taskId` reported `state`, with `serializedData` as the
+// payload that accompanies the update. LocalActor forwards all three fields to
+// LocalScheduler.statusUpdate.
+private[spark] case class LocalStatusUpdate(
+    taskId: Long,
+    state: TaskState,
+    serializedData: ByteBuffer)
+
+/**
+ * Actor that keeps the count of free cores for a LocalScheduler and launches tasks on the
+ * scheduler's thread pool. Within this class, `freeCores` is incremented once per status update
+ * and decremented once per launched task.
+ */
+private[spark] class LocalActor(localScheduler: LocalScheduler, var freeCores: Int)
+  extends Actor with Logging {
+
+  def receive = {
+    // Matches the `LocalReviveOffers` companion object (the bare name, not an instance).
+    case LocalReviveOffers =>
+      offerAndLaunch()
+    case LocalStatusUpdate(taskId, state, serializedData) =>
+      // One core is handed back before the update is forwarded to the scheduler.
+      freeCores += 1
+      localScheduler.statusUpdate(taskId, state, serializedData)
+      offerAndLaunch()
+  }
+
+  // Offers the currently free cores to the scheduler and launches the tasks it returns.
+  private def offerAndLaunch() {
+    launchTask(localScheduler.resourceOffer(freeCores))
+  }
+
+  /** Submits each task to the scheduler's thread pool, taking one free core per task. */
+  def launchTask(tasks : Seq[TaskDescription]) {
+    tasks.foreach { task =>
+      freeCores -= 1
+      val job = new Runnable {
+        def run() {
+          localScheduler.runTask(task.taskId,task.serializedTask)
+        }
+      }
+      localScheduler.threadPool.submit(job)
+    }
+  }
+}
+
+private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: SparkContext)
extends TaskScheduler
with Logging {
@@ -30,89 +61,127 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
val classLoader = new ExecutorURLClassLoader(Array(), Thread.currentThread.getContextClassLoader)
- // TODO: Need to take into account stage priority in scheduling
+ var schedulableBuilder: SchedulableBuilder = null
+ var rootPool: Pool = null
+ val activeTaskSets = new HashMap[String, TaskSetManager]
+ val taskIdToTaskSetId = new HashMap[Long, String]
+ val taskSetTaskIds = new HashMap[String, HashSet[Long]]
+
+ var localActor: ActorRef = null
+
+ override def start() { // Builds the root pool and its schedulable builder, then creates the LocalActor
+ // Scheduling mode is read from spark.cluster.schedulingmode; the default scheduler is FIFO
+ val schedulingMode = System.getProperty("spark.cluster.schedulingmode", "FIFO")
+ // Temporarily set rootPool name to empty
+ rootPool = new Pool("", SchedulingMode.withName(schedulingMode), 0, 0)
+ schedulableBuilder = {
+ schedulingMode match { // NOTE(review): only "FIFO" and "FAIR" have a case here; confirm no other mode can reach this match
+ case "FIFO" =>
+ new FIFOSchedulableBuilder(rootPool)
+ case "FAIR" =>
+ new FairSchedulableBuilder(rootPool)
+ }
+ }
+ schedulableBuilder.buildPools()
-override def start() { }
+ localActor = env.actorSystem.actorOf(Props(new LocalActor(this, threads)), "Test") // the actor starts with `threads` free cores and is registered under the fixed name "Test"
+ }
override def setListener(listener: TaskSchedulerListener) {
this.listener = listener
}
override def submitTasks(taskSet: TaskSet) {
- val tasks = taskSet.tasks
- val failCount = new Array[Int](tasks.size)
-
- def submitTask(task: Task[_], idInJob: Int) {
- val myAttemptId = attemptId.getAndIncrement()
- threadPool.submit(new Runnable {
- def run() {
- runTask(task, idInJob, myAttemptId)
- }
- })
+ synchronized {
+ var manager = new LocalTaskSetManager(this, taskSet)
+ schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
+ activeTaskSets(taskSet.id) = manager
+ taskSetTaskIds(taskSet.id) = new HashSet[Long]()
+ localActor ! LocalReviveOffers
}
+ }
+
+ def resourceOffer(freeCores: Int): Seq[TaskDescription] = {
+ synchronized {
+ var freeCpuCores = freeCores
+ val tasks = new ArrayBuffer[TaskDescription](freeCores)
+ val sortedTaskSetQueue = rootPool.getSortedTaskSetQueue()
+ for (manager <- sortedTaskSetQueue) {
+ logDebug("parentName:%s,name:%s,runningTasks:%s".format(manager.parent.name, manager.name, manager.runningTasks))
+ }
- def runTask(task: Task[_], idInJob: Int, attemptId: Int) {
- logInfo("Running " + task)
- val info = new TaskInfo(attemptId, idInJob, System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL)
- // Set the Spark execution environment for the worker thread
- SparkEnv.set(env)
- try {
- Accumulators.clear()
- Thread.currentThread().setContextClassLoader(classLoader)
-
- // Serialize and deserialize the task so that accumulators are changed to thread-local ones;
- // this adds a bit of unnecessary overhead but matches how the Mesos Executor works.
- val ser = SparkEnv.get.closureSerializer.newInstance()
- val bytes = Task.serializeWithDependencies(task, sc.addedFiles, sc.addedJars, ser)
- logInfo("Size of task " + idInJob + " is " + bytes.limit + " bytes")
- val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(bytes)
- updateDependencies(taskFiles, taskJars) // Download any files added with addFile
- val deserStart = System.currentTimeMillis()
- val deserializedTask = ser.deserialize[Task[_]](
- taskBytes, Thread.currentThread.getContextClassLoader)
- val deserTime = System.currentTimeMillis() - deserStart
-
- // Run it
- val result: Any = deserializedTask.run(attemptId)
-
- // Serialize and deserialize the result to emulate what the Mesos
- // executor does. This is useful to catch serialization errors early
- // on in development (so when users move their local Spark programs
- // to the cluster, they don't get surprised by serialization errors).
- val serResult = ser.serialize(result)
- deserializedTask.metrics.get.resultSize = serResult.limit()
- val resultToReturn = ser.deserialize[Any](serResult)
- val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]](
- ser.serialize(Accumulators.values))
- logInfo("Finished " + task)
- info.markSuccessful()
- deserializedTask.metrics.get.executorRunTime = info.duration.toInt //close enough
- deserializedTask.metrics.get.executorDeserializeTime = deserTime.toInt
-
- // If the threadpool has not already been shutdown, notify DAGScheduler
- if (!Thread.currentThread().isInterrupted)
- listener.taskEnded(task, Success, resultToReturn, accumUpdates, info, deserializedTask.metrics.getOrElse(null))
- } catch {
- case t: Throwable => {
- logError("Exception in task " + idInJob, t)
- failCount.synchronized {
- failCount(idInJob) += 1
- if (failCount(idInJob) <= maxFailures) {
- submitTask(task, idInJob)
- } else {
- // TODO: Do something nicer here to return all the way to the user
- if (!Thread.currentThread().isInterrupted) {
- val failure = new ExceptionFailure(t.getClass.getName, t.toString, t.getStackTrace)
- listener.taskEnded(task, failure, null, null, info, null)
- }
+ var launchTask = false
+ for (manager <- sortedTaskSetQueue) {
+ do {
+ launchTask = false
+ manager.slaveOffer(null,null,freeCpuCores) match {
+ case Some(task) =>
+ tasks += task
+ taskIdToTaskSetId(task.taskId) = manager.taskSet.id
+ taskSetTaskIds(manager.taskSet.id) += task.taskId
+ freeCpuCores -= 1
+ launchTask = true
+ case None => {}
}
- }
- }
+ } while(launchTask)
}
+ return tasks
}
+ }
- for ((task, i) <- tasks.zipWithIndex) {
- submitTask(task, i)
+  /**
+   * Drops every piece of bookkeeping this scheduler holds for the task set owned by `manager`:
+   * the active-task-set entry, the manager's slot in its parent pool, and the task-id mappings
+   * of the tasks still recorded for that task set.
+   */
+  def taskSetFinished(manager: TaskSetManager) {
+    synchronized {
+      val finishedId = manager.taskSet.id
+      activeTaskSets -= finishedId
+      manager.parent.removeSchedulable(manager)
+      logInfo("Remove TaskSet %s from pool %s".format(finishedId, manager.parent.name))
+      // Forget the tasks still recorded for this task set before dropping the set itself.
+      val remainingTaskIds = taskSetTaskIds(finishedId)
+      taskIdToTaskSetId --= remainingTaskIds
+      taskSetTaskIds -= finishedId
+    }
+  }
+
+  /**
+   * Runs one serialized task on the calling thread and reports the outcome to `localActor`.
+   *
+   * On success a `LocalStatusUpdate` with `TaskState.FINISHED` and the serialized `TaskResult`
+   * is sent; any `Throwable` is turned into an `ExceptionFailure` and sent with
+   * `TaskState.FAILED`.
+   *
+   * @param taskId id of the task attempt, also passed to `Task.run`
+   * @param bytes task serialized together with its file and jar dependencies
+   */
+  def runTask(taskId: Long, bytes: ByteBuffer) {
+    logInfo("Running " + taskId)
+    // Set the Spark execution environment for the worker thread
+    SparkEnv.set(env)
+    val ser = SparkEnv.get.closureSerializer.newInstance()
+    // Start of the interval reported as executorRunTime below.
+    val start = System.currentTimeMillis()
+    try {
+      Accumulators.clear()
+      Thread.currentThread().setContextClassLoader(classLoader)
+
+      // The task arrives already serialized; deserializing it here gives this thread its own
+      // copy, which matches how the Mesos Executor works.
+      val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(bytes)
+      updateDependencies(taskFiles, taskJars) // Download any files added with addFile
+      val deserStart = System.currentTimeMillis()
+      val deserializedTask = ser.deserialize[Task[_]](
+        taskBytes, Thread.currentThread.getContextClassLoader)
+      val deserTime = System.currentTimeMillis() - deserStart
+
+      // Run it
+      val result: Any = deserializedTask.run(taskId)
+
+      // Serialize and deserialize the result to emulate what the Mesos
+      // executor does. This is useful to catch serialization errors early
+      // on in development (so when users move their local Spark programs
+      // to the cluster, they don't get surprised by serialization errors).
+      val serResult = ser.serialize(result)
+      deserializedTask.metrics.get.resultSize = serResult.limit()
+      val resultToReturn = ser.deserialize[Any](serResult)
+      val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]](
+        ser.serialize(Accumulators.values))
+      val serviceTime = System.currentTimeMillis() - start
+      logInfo("Finished " + taskId)
+      // Report the elapsed time of the whole attempt, not a second copy of the deserialize time.
+      deserializedTask.metrics.get.executorRunTime = serviceTime.toInt
+      deserializedTask.metrics.get.executorDeserializeTime = deserTime.toInt
+
+      val taskResult = new TaskResult(result, accumUpdates, deserializedTask.metrics.getOrElse(null))
+      val serializedResult = ser.serialize(taskResult)
+      localActor ! LocalStatusUpdate(taskId, TaskState.FINISHED, serializedResult)
+    } catch {
+      // Deliberately broad: every failure must reach the actor so the core is released.
+      case t: Throwable => {
+        val failure = new ExceptionFailure(t.getClass.getName, t.toString, t.getStackTrace)
+        localActor ! LocalStatusUpdate(taskId, TaskState.FAILED, ser.serialize(failure))
+      }
     }
   }
@@ -128,6 +197,7 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
currentFiles(name) = timestamp
}
+
for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
logInfo("Fetching " + name + " with timestamp " + timestamp)
Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
@@ -143,7 +213,16 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
}
}
- override def stop() {
+  /**
+   * Routes a task status report to the TaskSetManager that launched the task.
+   *
+   * A report can arrive after its task set was removed by `taskSetFinished` (for example a task
+   * that was still running when the set was aborted). Such reports are logged and ignored rather
+   * than failing on a missing map entry.
+   *
+   * @param taskId id of the task the report is about
+   * @param state state the task reached
+   * @param serializedData serialized result or failure accompanying the report
+   */
+  def statusUpdate(taskId :Long, state: TaskState, serializedData: ByteBuffer) {
+    synchronized {
+      val owner = for {
+        taskSetId <- taskIdToTaskSetId.get(taskId)
+        manager <- activeTaskSets.get(taskSetId)
+      } yield (taskSetId, manager)
+      owner match {
+        case Some((taskSetId, manager)) =>
+          taskSetTaskIds.get(taskSetId).foreach(_ -= taskId)
+          manager.statusUpdate(taskId, state, serializedData)
+        case None =>
+          logInfo("Ignoring update for task %d: its TaskSet is no longer active".format(taskId))
+      }
+    }
+  }
+
+ override def stop() { // Calls shutdownNow() on threadPool. NOTE(review): the localActor created in start() is not stopped here -- confirm that is intended
threadPool.shutdownNow()
}
diff --git a/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala b/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala
new file mode 100644
index 0000000000..70b69bb26f
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala
@@ -0,0 +1,172 @@
+package spark.scheduler.local
+
+import java.io.File
+import java.util.concurrent.atomic.AtomicInteger
+import java.nio.ByteBuffer
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.HashSet
+
+import spark._
+import spark.TaskState.TaskState
+import spark.scheduler._
+import spark.scheduler.cluster._
+
+/**
+ * TaskSetManager used by [[LocalScheduler]]: hands out the tasks of one TaskSet and tracks
+ * their completion and failures.
+ *
+ * A task is offered when no copy of it is running and it has not finished. A failed task becomes
+ * eligible again until it has failed more than `MAX_TASK_FAILURES` times; at that point the
+ * TaskSet is reported as failed to the listener and removed from the scheduler.
+ *
+ * The pool-management, locality and speculation members are stubs in this implementation.
+ *
+ * @param sched scheduler that owns this manager
+ * @param taskSet tasks managed by this instance
+ */
+private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: TaskSet)
+  extends TaskSetManager with Logging {
+
+  var parent: Schedulable = null
+  var weight: Int = 1
+  var minShare: Int = 0
+  var runningTasks: Int = 0
+  var priority: Int = taskSet.priority
+  var stageId: Int = taskSet.stageId
+  var name: String = "TaskSet_"+taskSet.stageId.toString
+
+  // Not read anywhere in this class; failures are counted in numFailures.
+  var failCount = new Array[Int](taskSet.tasks.size)
+  val taskInfos = new HashMap[Long, TaskInfo]
+  val numTasks = taskSet.tasks.size
+  var numFinished = 0
+  val ser = SparkEnv.get.closureSerializer.newInstance()
+  // Per task index: number of attempts currently running, completion flag, failure count.
+  val copiesRunning = new Array[Int](numTasks)
+  val finished = new Array[Boolean](numTasks)
+  val numFailures = new Array[Int](numTasks)
+  val MAX_TASK_FAILURES = sched.maxFailures
+
+  /** Adds `taskNum` to the running-task count here and, if present, in the parent. */
+  def increaseRunningTasks(taskNum: Int): Unit = {
+    runningTasks += taskNum
+    if (parent != null) {
+      parent.increaseRunningTasks(taskNum)
+    }
+  }
+
+  /** Subtracts `taskNum` from the running-task count here and, if present, in the parent. */
+  def decreaseRunningTasks(taskNum: Int): Unit = {
+    runningTasks -= taskNum
+    if (parent != null) {
+      parent.decreaseRunningTasks(taskNum)
+    }
+  }
+
+  def addSchedulable(schedulable: Schedulable): Unit = {
+    // nothing
+  }
+
+  def removeSchedulable(schedulable: Schedulable): Unit = {
+    // nothing
+  }
+
+  def getSchedulableByName(name: String): Schedulable = null
+
+  def executorLost(executorId: String, host: String): Unit = {
+    // nothing
+  }
+
+  def checkSpeculatableTasks(): Boolean = true
+
+  /** Returns a fresh buffer containing only this manager. */
+  def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = {
+    val sortedTaskSetQueue = new ArrayBuffer[TaskSetManager]
+    sortedTaskSetQueue += this
+    sortedTaskSetQueue
+  }
+
+  def hasPendingTasks(): Boolean = true
+
+  /** Returns the lowest task index that has no running copy and has not finished, if any. */
+  def findTask(): Option[Int] =
+    (0 until numTasks).find(i => copiesRunning(i) == 0 && !finished(i))
+
+  /**
+   * Offers this manager a slot. When CPUs are available and unfinished tasks remain, the next
+   * eligible task is serialized with its dependencies and returned as a TaskDescription;
+   * otherwise None is returned. `execId`, `hostPort` and `overrideLocality` are not used.
+   */
+  def slaveOffer(execId: String, hostPort: String, availableCpus: Double, overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription] = {
+    SparkEnv.set(sched.env)
+    logDebug("availableCpus:%d,numFinished:%d,numTasks:%d".format(availableCpus.toInt, numFinished, numTasks))
+    if (availableCpus > 0 && numFinished < numTasks) {
+      findTask().map { index =>
+        val taskId = sched.attemptId.getAndIncrement()
+        val task = taskSet.tasks(index)
+        val info = new TaskInfo(taskId, index, System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL)
+        taskInfos(taskId) = info
+        val bytes = Task.serializeWithDependencies(task, sched.sc.addedFiles, sched.sc.addedJars, ser)
+        logInfo("Size of task " + taskId + " is " + bytes.limit + " bytes")
+        val taskName = "task %s:%d".format(taskSet.id, index)
+        copiesRunning(index) += 1
+        increaseRunningTasks(1)
+        new TaskDescription(taskId, null, taskName, bytes)
+      }
+    } else {
+      None
+    }
+  }
+
+  def numPendingTasksForHostPort(hostPort: String): Int = 0
+
+  def numRackLocalPendingTasksForHost(hostPort :String): Int = 0
+
+  def numPendingTasksForHost(hostPort: String): Int = 0
+
+  /** Dispatches FINISHED and FAILED reports; every other state is ignored. */
+  def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
+    state match {
+      case TaskState.FINISHED =>
+        taskEnded(tid, state, serializedData)
+      case TaskState.FAILED =>
+        taskFailed(tid, state, serializedData)
+      case _ => {}
+    }
+  }
+
+  /**
+   * Records a successful task: deserializes its TaskResult, notifies the listener and, once
+   * every task has finished, tells the scheduler that the TaskSet is done.
+   */
+  def taskEnded(tid: Long, state: TaskState, serializedData: ByteBuffer) {
+    val info = taskInfos(tid)
+    val index = info.index
+    val task = taskSet.tasks(index)
+    info.markSuccessful()
+    val result = ser.deserialize[TaskResult[_]](serializedData, getClass.getClassLoader)
+    result.metrics.resultSize = serializedData.limit()
+    sched.listener.taskEnded(task, Success, result.value, result.accumUpdates, info, result.metrics)
+    numFinished += 1
+    decreaseRunningTasks(1)
+    finished(index) = true
+    if (numFinished == numTasks) {
+      sched.taskSetFinished(this)
+    }
+  }
+
+  /**
+   * Records a failed attempt. The task becomes eligible to run again unless it has now failed
+   * more than MAX_TASK_FAILURES times, in which case the whole TaskSet is aborted.
+   */
+  def taskFailed(tid: Long, state: TaskState, serializedData: ByteBuffer) {
+    val info = taskInfos(tid)
+    val index = info.index
+    info.markFailed()
+    decreaseRunningTasks(1)
+    val reason: ExceptionFailure =
+      ser.deserialize[ExceptionFailure](serializedData, getClass.getClassLoader)
+    if (!finished(index)) {
+      copiesRunning(index) -= 1
+      numFailures(index) += 1
+      val locs = reason.stackTrace.map(loc => "\tat %s".format(loc.toString))
+      logInfo("Loss was due to %s\n%s\n%s".format(reason.className, reason.description, locs.mkString("\n")))
+      if (numFailures(index) > MAX_TASK_FAILURES) {
+        // Report the configured limit rather than a hard-coded number.
+        val errorMessage = "Task %s:%d failed more than %d times; aborting job %s".format(
+          taskSet.id, index, MAX_TASK_FAILURES, reason.description)
+        decreaseRunningTasks(runningTasks)
+        sched.listener.taskSetFailed(taskSet, errorMessage)
+        // need to delete failed Taskset from schedule queue
+        sched.taskSetFinished(this)
+      }
+    }
+  }
+
+  def error(message: String) {
+  }
+}
diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala
index 15ab840155..da859eebcb 100644
--- a/core/src/main/scala/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/spark/storage/DiskStore.scala
@@ -96,15 +96,15 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
override def size(): Long = lastValidPosition
}
- val MAX_DIR_CREATION_ATTEMPTS: Int = 10
- val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt
+ private val MAX_DIR_CREATION_ATTEMPTS: Int = 10
+ private val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt
- var shuffleSender : ShuffleSender = null
+ private var shuffleSender : ShuffleSender = null
// Create one local directory for each path mentioned in spark.local.dir; then, inside this
// directory, create multiple subdirectories that we will hash files into, in order to avoid
// having really large inodes at the top level.
- val localDirs = createLocalDirs()
- val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir))
+ private val localDirs: Array[File] = createLocalDirs()
+ private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir))
addShutdownHook()
@@ -113,7 +113,6 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
new DiskBlockObjectWriter(blockId, serializer, bufferSize)
}
-
override def getSize(blockId: String): Long = {
getFile(blockId).length()
}
@@ -249,8 +248,8 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
private def createLocalDirs(): Array[File] = {
logDebug("Creating local directories at root dirs '" + rootDirs + "'")
val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss")
- rootDirs.split(",").map(rootDir => {
- var foundLocalDir: Boolean = false
+ rootDirs.split(",").map { rootDir =>
+ var foundLocalDir = false
var localDir: File = null
var localDirId: String = null
var tries = 0
@@ -265,7 +264,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
}
} catch {
case e: Exception =>
- logWarning("Attempt " + tries + " to create local dir failed", e)
+ logWarning("Attempt " + tries + " to create local dir " + localDir + " failed", e)
}
}
if (!foundLocalDir) {
@@ -275,7 +274,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
}
logInfo("Created local directory at " + localDir)
localDir
- })
+ }
}
private def addShutdownHook() {
@@ -283,15 +282,16 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") {
override def run() {
logDebug("Shutdown hook called")
- try {
- localDirs.foreach { localDir =>
+ localDirs.foreach { localDir =>
+ try {
if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir)
+ } catch {
+ case t: Throwable =>
+ logError("Exception while deleting local spark dir: " + localDir, t)
}
- if (shuffleSender != null) {
- shuffleSender.stop
- }
- } catch {
- case t: Throwable => logError("Exception while deleting local spark dirs", t)
+ }
+ if (shuffleSender != null) {
+ shuffleSender.stop
}
}
})
diff --git a/core/src/main/scala/spark/util/BoundedPriorityQueue.scala b/core/src/main/scala/spark/util/BoundedPriorityQueue.scala
new file mode 100644
index 0000000000..4bc5db8bb7
--- /dev/null
+++ b/core/src/main/scala/spark/util/BoundedPriorityQueue.scala
@@ -0,0 +1,45 @@
+package spark.util
+
+import java.io.Serializable
+import java.util.{PriorityQueue => JPriorityQueue}
+import scala.collection.generic.Growable
+import scala.collection.JavaConverters._
+
+/**
+ * Bounded priority queue. This class wraps the original PriorityQueue
+ * class and modifies it such that only the top K elements are retained.
+ * The top K elements are defined by an implicit Ordering[A].
+ */
+class BoundedPriorityQueue[A](maxSize: Int)(implicit ord: Ordering[A])
+  extends Iterable[A] with Growable[A] with Serializable {
+
+  // Min-heap under `ord`: its head is the smallest retained element, i.e. the eviction candidate.
+  private val underlying = new JPriorityQueue[A](maxSize, ord)
+
+  /** Iterates over the retained elements in the backing queue's iteration order. */
+  override def iterator: Iterator[A] = underlying.iterator.asScala
+
+  // Answer from the backing queue directly. The inherited Iterable.size walks the whole
+  // iterator, and += consults size on every insertion.
+  override def size: Int = underlying.size
+
+  override def ++=(xs: TraversableOnce[A]): this.type = {
+    xs.foreach { this += _ }
+    this
+  }
+
+  /** Adds `elem`; once `maxSize` elements are held it only displaces a smaller head. */
+  override def +=(elem: A): this.type = {
+    if (size < maxSize) underlying.offer(elem)
+    else maybeReplaceLowest(elem)
+    this
+  }
+
+  override def +=(elem1: A, elem2: A, elems: A*): this.type = {
+    this += elem1 += elem2 ++= elems
+  }
+
+  override def clear() { underlying.clear() }
+
+  /** Replaces the current head with `a` when `a` is greater under `ord`; returns whether it did. */
+  private def maybeReplaceLowest(a: A): Boolean = {
+    val head = underlying.peek()
+    if (head != null && ord.gt(a, head)) {
+      underlying.poll()
+      underlying.offer(a)
+    } else false
+  }
+}
+
diff --git a/core/src/main/scala/spark/util/StatCounter.scala b/core/src/main/scala/spark/util/StatCounter.scala
index 5f80180339..2b980340b7 100644
--- a/core/src/main/scala/spark/util/StatCounter.scala
+++ b/core/src/main/scala/spark/util/StatCounter.scala
@@ -37,17 +37,23 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable {
if (other == this) {
merge(other.copy()) // Avoid overwriting fields in a weird order
} else {
- val delta = other.mu - mu
- if (other.n * 10 < n) {
- mu = mu + (delta * other.n) / (n + other.n)
- } else if (n * 10 < other.n) {
- mu = other.mu - (delta * n) / (n + other.n)
- } else {
- mu = (mu * n + other.mu * other.n) / (n + other.n)
+ if (n == 0) {
+ mu = other.mu
+ m2 = other.m2
+ n = other.n
+ } else if (other.n != 0) {
+ val delta = other.mu - mu
+ if (other.n * 10 < n) {
+ mu = mu + (delta * other.n) / (n + other.n)
+ } else if (n * 10 < other.n) {
+ mu = other.mu - (delta * n) / (n + other.n)
+ } else {
+ mu = (mu * n + other.mu * other.n) / (n + other.n)
+ }
+ m2 += other.m2 + (delta * delta * n * other.n) / (n + other.n)
+ n += other.n
}
- m2 += other.m2 + (delta * delta * n * other.n) / (n + other.n)
- n += other.n
- this
+ this
}
}
diff --git a/core/src/test/scala/spark/FileSuite.scala b/core/src/test/scala/spark/FileSuite.scala
index 91b48c7456..e61ff7793d 100644
--- a/core/src/test/scala/spark/FileSuite.scala
+++ b/core/src/test/scala/spark/FileSuite.scala
@@ -7,6 +7,8 @@ import scala.io.Source
import com.google.common.io.Files
import org.scalatest.FunSuite
import org.apache.hadoop.io._
+import org.apache.hadoop.io.compress.{DefaultCodec, CompressionCodec, GzipCodec}
+
import SparkContext._
@@ -26,6 +28,28 @@ class FileSuite extends FunSuite with LocalSparkContext {
assert(sc.textFile(outputDir).collect().toList === List("1", "2", "3", "4"))
}
+  test("text files (compressed)") {
+    sc = new SparkContext("local", "test")
+    val scratchDir = Files.createTempDir()
+    val plainPath = new File(scratchDir, "output_normal").getAbsolutePath
+    val compressedPath = new File(scratchDir, "output_compressed").getAbsolutePath
+    val codec = new DefaultCodec()
+
+    // Write the same single-partition RDD once uncompressed and once with the codec.
+    val letters = sc.parallelize("a" * 10000, 1)
+    letters.saveAsTextFile(plainPath)
+    letters.saveAsTextFile(compressedPath, classOf[DefaultCodec])
+
+    // Both outputs must read back as the same 10000 lines.
+    val plainLines = sc.textFile(plainPath).collect
+    assert(plainLines === Array.fill(10000)("a"))
+    val compressedLines = sc.textFile(compressedPath).collect
+    assert(compressedLines === Array.fill(10000)("a"))
+
+    // The compressed part file carries the codec's extension and must be the smaller one.
+    val plainPart = new File(plainPath, "part-00000")
+    val compressedPart = new File(compressedPath, "part-00000" + codec.getDefaultExtension)
+    assert(compressedPart.length < plainPart.length)
+  }
+
test("SequenceFiles") {
sc = new SparkContext("local", "test")
val tempDir = Files.createTempDir()
@@ -37,6 +61,28 @@ class FileSuite extends FunSuite with LocalSparkContext {
assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)"))
}
+  test("SequenceFile (compressed)") {
+    sc = new SparkContext("local", "test")
+    val scratchDir = Files.createTempDir()
+    val plainPath = new File(scratchDir, "output_normal").getAbsolutePath
+    val compressedPath = new File(scratchDir, "output_compressed").getAbsolutePath
+    val codec = new DefaultCodec()
+
+    // Write the same single-partition pair RDD once uncompressed and once with the codec.
+    val pairs = sc.parallelize(Seq.fill(100)("abc"), 1).map(s => (s, s))
+    pairs.saveAsSequenceFile(plainPath)
+    pairs.saveAsSequenceFile(compressedPath, Some(classOf[DefaultCodec]))
+
+    // Both outputs must read back as the same 100 pairs.
+    val plainPairs = sc.sequenceFile[String, String](plainPath).collect
+    assert(plainPairs === Array.fill(100)(("abc", "abc")))
+    val compressedPairs = sc.sequenceFile[String, String](compressedPath).collect
+    assert(compressedPairs === Array.fill(100)(("abc", "abc")))
+
+    // The compressed part file carries the codec's extension and must be the smaller one.
+    val plainPart = new File(plainPath, "part-00000")
+    val compressedPart = new File(compressedPath, "part-00000" + codec.getDefaultExtension)
+    assert(compressedPart.length < plainPart.length)
+  }
+
test("SequenceFile with writable key") {
sc = new SparkContext("local", "test")
val tempDir = Files.createTempDir()
diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java
index 93bb69b41c..d306124fca 100644
--- a/core/src/test/scala/spark/JavaAPISuite.java
+++ b/core/src/test/scala/spark/JavaAPISuite.java
@@ -8,6 +8,7 @@ import java.util.*;
import scala.Tuple2;
import com.google.common.base.Charsets;
+import org.apache.hadoop.io.compress.DefaultCodec;
import com.google.common.io.Files;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
@@ -474,6 +475,19 @@ public class JavaAPISuite implements Serializable {
}
@Test
+  public void textFilesCompressed() throws IOException {
+    File scratchDir = Files.createTempDir();
+    String compressedDir = new File(scratchDir, "output").getAbsolutePath();
+
+    // Save the numbers as text compressed with DefaultCodec
+    JavaRDD<Integer> numbers = sc.parallelize(Arrays.asList(1, 2, 3, 4));
+    numbers.saveAsTextFile(compressedDir, DefaultCodec.class);
+
+    // Reading the directory back as a text file RDD must give the original lines
+    JavaRDD<String> lines = sc.textFile(compressedDir);
+    List<String> wanted = Arrays.asList("1", "2", "3", "4");
+    Assert.assertEquals(wanted, lines.collect());
+  }
+
+ @Test
public void sequenceFile() {
File tempDir = Files.createTempDir();
String outputDir = new File(tempDir, "output").getAbsolutePath();
@@ -620,6 +634,37 @@ public class JavaAPISuite implements Serializable {
}
@Test
+  public void hadoopFileCompressed() {
+    File scratchDir = Files.createTempDir();
+    String compressedDir = new File(scratchDir, "output_compressed").getAbsolutePath();
+    List<Tuple2<Integer, String>> input = Arrays.asList(
+      new Tuple2<Integer, String>(1, "a"),
+      new Tuple2<Integer, String>(2, "aa"),
+      new Tuple2<Integer, String>(3, "aaa")
+    );
+    JavaPairRDD<Integer, String> source = sc.parallelizePairs(input);
+
+    // Convert each pair to Writables so it can be stored in a SequenceFile
+    PairFunction<Tuple2<Integer, String>, IntWritable, Text> toWritables =
+      new PairFunction<Tuple2<Integer, String>, IntWritable, Text>() {
+        @Override
+        public Tuple2<IntWritable, Text> call(Tuple2<Integer, String> pair) {
+          return new Tuple2<IntWritable, Text>(new IntWritable(pair._1()), new Text(pair._2()));
+        }
+      };
+    source.map(toWritables).saveAsHadoopFile(compressedDir, IntWritable.class, Text.class,
+      SequenceFileOutputFormat.class, DefaultCodec.class);
+
+    // Read the compressed output back and compare its string form with the input's
+    JavaPairRDD<IntWritable, Text> readBack = sc.hadoopFile(compressedDir,
+      SequenceFileInputFormat.class, IntWritable.class, Text.class);
+    Function<Tuple2<IntWritable, Text>, String> describe =
+      new Function<Tuple2<IntWritable, Text>, String>() {
+        @Override
+        public String call(Tuple2<IntWritable, Text> x) {
+          return x.toString();
+        }
+      };
+
+    Assert.assertEquals(input.toString(), readBack.map(describe).collect().toString());
+  }
+
+ @Test
public void zip() {
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
JavaDoubleRDD doubles = rdd.map(new DoubleFunction<Integer>() {
diff --git a/core/src/test/scala/spark/PartitioningSuite.scala b/core/src/test/scala/spark/PartitioningSuite.scala
index 60db759c25..16f93e71a3 100644
--- a/core/src/test/scala/spark/PartitioningSuite.scala
+++ b/core/src/test/scala/spark/PartitioningSuite.scala
@@ -1,10 +1,10 @@
package spark
import org.scalatest.FunSuite
-
import scala.collection.mutable.ArrayBuffer
-
import SparkContext._
+import spark.util.StatCounter
+import scala.math.abs
class PartitioningSuite extends FunSuite with LocalSparkContext {
@@ -120,4 +120,21 @@ class PartitioningSuite extends FunSuite with LocalSparkContext {
assert(intercept[SparkException]{ arrPairs.reduceByKeyLocally(_ + _) }.getMessage.contains("array"))
assert(intercept[SparkException]{ arrPairs.reduceByKey(_ + _) }.getMessage.contains("array"))
}
+
+  test("Zero-length partitions should be correctly handled") {
+    // Filtering out the negatives leaves several consecutive empty partitions,
+    // including the "first" one
+    sc = new SparkContext("local", "test")
+    val tolerance = 0.01
+    val nonNegative: RDD[Double] = sc
+      .parallelize(Array(-1.0, -1.0, -1.0, -1.0, 2.0, 4.0, -1.0, -1.0), 8)
+      .filter(_ >= 0.0)
+
+    // Merge every partition, empty or not, through StatCounter
+    val summary: StatCounter = nonNegative.stats()
+    assert(abs(6.0 - summary.sum) < tolerance)
+    assert(abs(6.0/2 - nonNegative.mean) < tolerance)
+    assert(abs(1.0 - nonNegative.variance) < tolerance)
+    assert(abs(1.0 - nonNegative.stdev) < tolerance)
+
+    // Add other tests here for classes that should be able to handle empty partitions correctly
+  }
}
diff --git a/core/src/test/scala/spark/PipedRDDSuite.scala b/core/src/test/scala/spark/PipedRDDSuite.scala
index a6344edf8f..ed075f93ec 100644
--- a/core/src/test/scala/spark/PipedRDDSuite.scala
+++ b/core/src/test/scala/spark/PipedRDDSuite.scala
@@ -19,6 +19,45 @@ class PipedRDDSuite extends FunSuite with LocalSparkContext {
assert(c(3) === "4")
}
+  test("advanced pipe") {
+    sc = new SparkContext("local", "test")
+    val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
+    val bl = sc.broadcast(List("0"))
+
+    // Element-wise check that `actual` is exactly `expected`.
+    def assertLines(actual: Array[String], expected: Seq[String]) {
+      assert(actual.size === expected.size)
+      for ((line, i) <- expected.zipWithIndex) {
+        assert(actual(i) === line)
+      }
+    }
+
+    // The expected output below shows the broadcast value and the separator once per partition.
+    val piped = nums.pipe(Seq("cat"),
+      Map[String, String](),
+      (out: String => Unit) => { bl.value.foreach(v => out(v)); out("\u0001") },
+      (elem: Int, out: String => Unit) => out(elem + "_"))
+
+    assertLines(piped.collect(), Seq("0", "\u0001", "1_", "2_", "0", "\u0001", "3_", "4_"))
+
+    // Same pipeline over grouped data: every value of a group is printed on its own line.
+    val nums1 = sc.makeRDD(Array("a\t1", "b\t2", "a\t3", "b\t4"), 2)
+    val grouped = nums1.groupBy(line => line.split("\t")(0))
+    val d = grouped.pipe(Seq("cat"),
+      Map[String, String](),
+      (out: String => Unit) => { bl.value.foreach(v => out(v)); out("\u0001") },
+      (group: Tuple2[String, Seq[String]], out: String => Unit) => {
+        group._2.foreach(value => out(value + "_"))
+      }).collect()
+
+    assertLines(d, Seq("0", "\u0001", "b\t2_", "b\t4_", "0", "\u0001", "a\t1_", "a\t3_"))
+  }
+
test("pipe with env variable") {
sc = new SparkContext("local", "test")
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala
index 3f69e99780..67f3332d44 100644
--- a/core/src/test/scala/spark/RDDSuite.scala
+++ b/core/src/test/scala/spark/RDDSuite.scala
@@ -317,4 +317,23 @@ class RDDSuite extends FunSuite with LocalSparkContext {
assert(sample.size === checkSample.size)
for (i <- 0 until sample.size) assert(sample(i) === checkSample(i))
}
+
+  test("top with predefined ordering") {
+    sc = new SparkContext("local", "test")
+    val k = 5
+    val nums = Array.range(1, 100000)
+    // Shuffle first so the result cannot depend on the input already being ordered.
+    val shuffled = sc.makeRDD(scala.util.Random.shuffle(nums), 2)
+    val largest = shuffled.top(k)
+    assert(largest.size === 5)
+    assert(largest.sorted === nums.sorted.takeRight(5))
+  }
+
+  test("top with custom ordering") {
+    sc = new SparkContext("local", "test")
+    // Reversed string ordering; left without a type annotation so the right-hand side still
+    // resolves to the default Ordering[String].
+    implicit val reversed = implicitly[Ordering[String]].reverse
+    val letters = sc.makeRDD(Vector("a", "b", "c", "d"), 2)
+    val firstTwo = letters.top(2)
+    assert(firstTwo.size === 2)
+    assert(firstTwo.sorted === Array("b", "a"))
+  }
}
diff --git a/core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala
index c861597c6b..8e1ad27e14 100644
--- a/core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala
+++ b/core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala
@@ -16,7 +16,7 @@ class DummyTaskSetManager(
initNumTasks: Int,
clusterScheduler: ClusterScheduler,
taskSet: TaskSet)
- extends TaskSetManager(clusterScheduler,taskSet) {
+ extends ClusterTaskSetManager(clusterScheduler,taskSet) {
parent = null
weight = 1
diff --git a/core/src/test/scala/spark/scheduler/JobLoggerSuite.scala b/core/src/test/scala/spark/scheduler/JobLoggerSuite.scala
new file mode 100644
index 0000000000..4000c4d520
--- /dev/null
+++ b/core/src/test/scala/spark/scheduler/JobLoggerSuite.scala
@@ -0,0 +1,105 @@
+package spark.scheduler
+
+import java.util.Properties
+import java.util.concurrent.LinkedBlockingQueue
+import org.scalatest.FunSuite
+import org.scalatest.matchers.ShouldMatchers
+import scala.collection.mutable
+import spark._
+import spark.SparkContext._
+
+
+/**
+ * Tests for JobLogger: its internal helper methods, the bookkeeping maps it maintains while
+ * a real job runs, and the SparkListener callbacks it receives.
+ */
+class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers {
+
+  // Drives JobLogger's helpers directly with a hand-built two-stage dependency graph.
+  test("inner method") {
+    sc = new SparkContext("local", "joblogger")
+    // Anonymous subclass that forwards to JobLogger helpers under *Test names so the test can
+    // call them (presumably the helpers are not public — confirm in JobLogger).
+    val joblogger = new JobLogger {
+      def createLogWriterTest(jobID: Int) = createLogWriter(jobID)
+      def closeLogWriterTest(jobID: Int) = closeLogWriter(jobID)
+      def getRddNameTest(rdd: RDD[_]) = getRddName(rdd)
+      def buildJobDepTest(jobID: Int, stage: Stage) = buildJobDep(jobID, stage)
+    }
+    type MyRDD = RDD[(Int, Int)]
+    // Builds a stub RDD with the given partition count and dependencies; compute() always
+    // throws because the test never runs a task on it.
+    def makeRdd(
+        numPartitions: Int,
+        dependencies: List[Dependency[_]]
+      ): MyRDD = {
+      val maxPartition = numPartitions - 1
+      return new MyRDD(sc, dependencies) {
+        override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] =
+          throw new RuntimeException("should not be reached")
+        override def getPartitions = (0 to maxPartition).map(i => new Partition {
+          override def index = i
+        }).toArray
+      }
+    }
+    // Stage 0 (rootStage) has stage 1 (shuffleMapStage) as its only parent; both use jobID 5.
+    val jobID = 5
+    val parentRdd = makeRdd(4, Nil)
+    val shuffleDep = new ShuffleDependency(parentRdd, null)
+    val rootRdd = makeRdd(4, List(shuffleDep))
+    val shuffleMapStage = new Stage(1, parentRdd, Some(shuffleDep), Nil, jobID)
+    val rootStage = new Stage(0, rootRdd, None, List(shuffleMapStage), jobID)
+
+    // A submitted-stage event is expected to be queued by the logger.
+    joblogger.onStageSubmitted(SparkListenerStageSubmitted(rootStage, 4))
+    joblogger.getEventQueue.size should be (1)
+    // getRddName is expected to give the class name until setName() has been called.
+    joblogger.getRddNameTest(parentRdd) should be (parentRdd.getClass.getName)
+    parentRdd.setName("MyRDD")
+    joblogger.getRddNameTest(parentRdd) should be ("MyRDD")
+    joblogger.createLogWriterTest(jobID)
+    joblogger.getJobIDtoPrintWriter.size should be (1)
+    // Both stages of the graph are expected to be registered under the job.
+    joblogger.buildJobDepTest(jobID, rootStage)
+    joblogger.getJobIDToStages.get(jobID).get.size should be (2)
+    joblogger.getStageIDToJobID.get(0) should be (Some(jobID))
+    joblogger.getStageIDToJobID.get(1) should be (Some(jobID))
+    // Closing the writer is expected to clear every per-job map.
+    joblogger.closeLogWriterTest(jobID)
+    joblogger.getStageIDToJobID.size should be (0)
+    joblogger.getJobIDToStages.size should be (0)
+    joblogger.getJobIDtoPrintWriter.size should be (0)
+  }
+
+  // Runs a real two-stage job and inspects the logger's maps afterwards.
+  test("inner variables") {
+    sc = new SparkContext("local[4]", "joblogger")
+    val joblogger = new JobLogger {
+      // Only close the file here, so the maps are still populated when the assertions below
+      // run (presumably the stock closeLogWriter also clears them, as the "inner method"
+      // test expects — confirm in JobLogger).
+      override protected def closeLogWriter(jobID: Int) =
+        getJobIDtoPrintWriter.get(jobID).foreach { fileWriter =>
+          fileWriter.close()
+        }
+    }
+    sc.addSparkListener(joblogger)
+    val rdd = sc.parallelize(1 to 1e2.toInt, 4).map{ i => (i % 12, 2 * i) }
+    rdd.reduceByKey(_+_).collect()
+
+    // NOTE(review): "/tmp/spark" is hard-coded here; presumably it is JobLogger's default
+    // log directory — confirm.
+    joblogger.getLogDir should be ("/tmp/spark")
+    joblogger.getJobIDtoPrintWriter.size should be (1)
+    joblogger.getStageIDToJobID.size should be (2)
+    joblogger.getStageIDToJobID.get(0) should be (Some(0))
+    joblogger.getStageIDToJobID.get(1) should be (Some(0))
+    joblogger.getJobIDToStages.size should be (1)
+  }
+
+
+  // Replaces every listener callback with a counter and checks how often each one fires for
+  // a single job.
+  test("interface functions") {
+    sc = new SparkContext("local[4]", "joblogger")
+    val joblogger = new JobLogger {
+      var onTaskEndCount = 0
+      var onJobEndCount = 0
+      var onJobStartCount = 0
+      var onStageCompletedCount = 0
+      var onStageSubmittedCount = 0
+      override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = onTaskEndCount += 1
+      override def onJobEnd(jobEnd: SparkListenerJobEnd) = onJobEndCount += 1
+      override def onJobStart(jobStart: SparkListenerJobStart) = onJobStartCount += 1
+      override def onStageCompleted(stageCompleted: StageCompleted) = onStageCompletedCount += 1
+      override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) = onStageSubmittedCount += 1
+    }
+    sc.addSparkListener(joblogger)
+    val rdd = sc.parallelize(1 to 1e2.toInt, 4).map{ i => (i % 12, 2 * i) }
+    rdd.reduceByKey(_+_).collect()
+
+    // Expected counts: one job, two stages and eight task-end events.
+    joblogger.onJobStartCount should be (1)
+    joblogger.onJobEndCount should be (1)
+    joblogger.onTaskEndCount should be (8)
+    joblogger.onStageSubmittedCount should be (2)
+    joblogger.onStageCompletedCount should be (2)
+  }
+}
diff --git a/core/src/test/scala/spark/scheduler/LocalSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/LocalSchedulerSuite.scala
new file mode 100644
index 0000000000..8bd813fd14
--- /dev/null
+++ b/core/src/test/scala/spark/scheduler/LocalSchedulerSuite.scala
@@ -0,0 +1,206 @@
+package spark.scheduler
+
+import org.scalatest.FunSuite
+import org.scalatest.BeforeAndAfter
+
+import spark._
+import spark.scheduler._
+import spark.scheduler.cluster._
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.{ConcurrentMap, HashMap}
+import java.util.concurrent.Semaphore
+import java.util.concurrent.CountDownLatch
+import java.util.Properties
+
+/**
+ * A one-shot gate built on this object's monitor: jobWait() blocks until jobFinished()
+ * has been called.
+ */
+class Lock() {
+  // Becomes true once jobFinished() has run; the methods below access it under the monitor.
+  var finished = false
+
+  /** Blocks the calling thread until `finished` is true. */
+  def jobWait() = this.synchronized {
+    // Loop so that a wake-up without `finished` being set just waits again.
+    while (!finished) {
+      wait()
+    }
+  }
+
+  /** Sets `finished` and wakes every thread blocked in jobWait(). */
+  def jobFinished() = this.synchronized {
+    finished = true
+    notifyAll()
+  }
+}
+
+/**
+ * Per-job bookkeeping shared between the test thread and the tasks it launches, keyed by
+ * the index passed to createThread.
+ */
+object TaskThreadInfo {
+  // Gate that keeps the task blocked until the test releases it.
+  val threadToLock = new HashMap[Int, Lock]
+  // True while the task is inside its map function.
+  val threadToRunning = new HashMap[Int, Boolean]
+  // Counted down by the task as soon as it starts executing.
+  val threadToStarted = new HashMap[Int, CountDownLatch]
+}
+
+/*
+ * 1. Each thread contains one job.
+ * 2. Each job contains one stage.
+ * 3. Each stage contains only one task.
+ * 4. Each launched task must be launched in order (using threadToStarted) to make sure
+ *    it gets a CPU core; it then waits until the user manually releases its "Lock",
+ *    after which the cluster has another free CPU core.
+ * 5. Each pending task must use "sleep" to make sure it has been added to the taskSetManager
+ *    queue, so that it is scheduled later when the cluster has free CPU cores.
+ */
+class LocalSchedulerSuite extends FunSuite with LocalSparkContext {
+
+  /**
+   * Starts a thread that runs a single-partition, single-element job. The job's only task
+   * marks itself as running in TaskThreadInfo, counts down its "started" latch and then
+   * blocks until the test calls TaskThreadInfo.threadToLock(threadIndex).jobFinished().
+   *
+   * @param threadIndex key into the TaskThreadInfo maps; also the job's only element
+   * @param poolName value for "spark.scheduler.cluster.fair.pool", or null to set nothing
+   * @param sc context used to build and run the job
+   * @param sem released once after the job has completed and its result has been checked
+   */
+  def createThread(threadIndex: Int, poolName: String, sc: SparkContext, sem: Semaphore) {
+
+    TaskThreadInfo.threadToRunning(threadIndex) = false
+    val nums = sc.parallelize(threadIndex to threadIndex, 1)
+    TaskThreadInfo.threadToLock(threadIndex) = new Lock()
+    TaskThreadInfo.threadToStarted(threadIndex) = new CountDownLatch(1)
+    new Thread {
+      // NOTE(review): this is constructor code of the anonymous Thread subclass, so it runs
+      // on the thread that calls createThread, not on the new thread — confirm that this is
+      // what addLocalProperties expects.
+      if (poolName != null) {
+        sc.addLocalProperties("spark.scheduler.cluster.fair.pool",poolName)
+      }
+      override def run() {
+        val ans = nums.map(number => {
+          TaskThreadInfo.threadToRunning(number) = true
+          TaskThreadInfo.threadToStarted(number).countDown()
+          // Hold the CPU core until the test releases this job.
+          TaskThreadInfo.threadToLock(number).jobWait()
+          TaskThreadInfo.threadToRunning(number) = false
+          number
+        }).collect()
+        assert(ans.toList === List(threadIndex))
+        sem.release()
+      }
+    }.start()
+  }
+
+  test("Local FIFO scheduler end-to-end test") {
+    System.setProperty("spark.cluster.schedulingmode", "FIFO")
+    sc = new SparkContext("local[4]", "test")
+    val sem = new Semaphore(0)
+
+    // Occupy all four cores with jobs 1-4, waiting for each task to start before
+    // submitting the next one.
+    createThread(1,null,sc,sem)
+    TaskThreadInfo.threadToStarted(1).await()
+    createThread(2,null,sc,sem)
+    TaskThreadInfo.threadToStarted(2).await()
+    createThread(3,null,sc,sem)
+    TaskThreadInfo.threadToStarted(3).await()
+    createThread(4,null,sc,sem)
+    TaskThreadInfo.threadToStarted(4).await()
+    // Threads 5 and 6 (stages pending) must satisfy the following two points:
+    // 1. The stages (taskSetManagers) of the jobs in threads 5 and 6 must be added to the
+    //    taskSetManager queue before TaskThreadInfo.threadToLock(1).jobFinished() runs.
+    // 2. The stage in thread 5 must have priority over the stage in thread 6.
+    // So we simply sleep for 1s after starting each thread.
+    // TODO: any better solution?
+    createThread(5,null,sc,sem)
+    Thread.sleep(1000)
+    createThread(6,null,sc,sem)
+    Thread.sleep(1000)
+
+    assert(TaskThreadInfo.threadToRunning(1) === true)
+    assert(TaskThreadInfo.threadToRunning(2) === true)
+    assert(TaskThreadInfo.threadToRunning(3) === true)
+    assert(TaskThreadInfo.threadToRunning(4) === true)
+    assert(TaskThreadInfo.threadToRunning(5) === false)
+    assert(TaskThreadInfo.threadToRunning(6) === false)
+
+    // Freeing one core must start job 5 (submitted first), not job 6.
+    TaskThreadInfo.threadToLock(1).jobFinished()
+    TaskThreadInfo.threadToStarted(5).await()
+
+    assert(TaskThreadInfo.threadToRunning(1) === false)
+    assert(TaskThreadInfo.threadToRunning(2) === true)
+    assert(TaskThreadInfo.threadToRunning(3) === true)
+    assert(TaskThreadInfo.threadToRunning(4) === true)
+    assert(TaskThreadInfo.threadToRunning(5) === true)
+    assert(TaskThreadInfo.threadToRunning(6) === false)
+
+    // Freeing another core lets job 6 start.
+    TaskThreadInfo.threadToLock(3).jobFinished()
+    TaskThreadInfo.threadToStarted(6).await()
+
+    assert(TaskThreadInfo.threadToRunning(1) === false)
+    assert(TaskThreadInfo.threadToRunning(2) === true)
+    assert(TaskThreadInfo.threadToRunning(3) === false)
+    assert(TaskThreadInfo.threadToRunning(4) === true)
+    assert(TaskThreadInfo.threadToRunning(5) === true)
+    assert(TaskThreadInfo.threadToRunning(6) === true)
+
+    // Release the remaining jobs and wait until all six threads have finished.
+    TaskThreadInfo.threadToLock(2).jobFinished()
+    TaskThreadInfo.threadToLock(4).jobFinished()
+    TaskThreadInfo.threadToLock(5).jobFinished()
+    TaskThreadInfo.threadToLock(6).jobFinished()
+    sem.acquire(6)
+  }
+
+  test("Local fair scheduler end-to-end test") {
+    // Configure the scheduler before the SparkContext is created, exactly as the FIFO test
+    // above does, so the settings are already in place when the context is constructed.
+    System.setProperty("spark.cluster.schedulingmode", "FAIR")
+    val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile()
+    System.setProperty("spark.fairscheduler.allocation.file", xmlPath)
+
+    sc = new SparkContext("local[8]", "LocalSchedulerSuite")
+    val sem = new Semaphore(0)
+
+    createThread(10,"1",sc,sem)
+    TaskThreadInfo.threadToStarted(10).await()
+    createThread(20,"2",sc,sem)
+    TaskThreadInfo.threadToStarted(20).await()
+    createThread(30,"3",sc,sem)
+    TaskThreadInfo.threadToStarted(30).await()
+
+    assert(TaskThreadInfo.threadToRunning(10) === true)
+    assert(TaskThreadInfo.threadToRunning(20) === true)
+    assert(TaskThreadInfo.threadToRunning(30) === true)
+
+    createThread(11,"1",sc,sem)
+    TaskThreadInfo.threadToStarted(11).await()
+    createThread(21,"2",sc,sem)
+    TaskThreadInfo.threadToStarted(21).await()
+    createThread(31,"3",sc,sem)
+    TaskThreadInfo.threadToStarted(31).await()
+
+    assert(TaskThreadInfo.threadToRunning(11) === true)
+    assert(TaskThreadInfo.threadToRunning(21) === true)
+    assert(TaskThreadInfo.threadToRunning(31) === true)
+
+    // Jobs 12 and 22 take the last two of the eight cores; job 32 has to stay pending.
+    createThread(12,"1",sc,sem)
+    TaskThreadInfo.threadToStarted(12).await()
+    createThread(22,"2",sc,sem)
+    TaskThreadInfo.threadToStarted(22).await()
+    createThread(32,"3",sc,sem)
+
+    assert(TaskThreadInfo.threadToRunning(12) === true)
+    assert(TaskThreadInfo.threadToRunning(22) === true)
+    assert(TaskThreadInfo.threadToRunning(32) === false)
+
+    TaskThreadInfo.threadToLock(10).jobFinished()
+    TaskThreadInfo.threadToStarted(32).await()
+
+    assert(TaskThreadInfo.threadToRunning(32) === true)
+
+    // 1. As in the scenario above, sleep 1s so that the stages of 23 and 33 are added to the
+    //    taskSetManager queue; the cluster then assigns the free CPU core to stage 23 once
+    //    stage 11 has finished.
+    // 2. The submission order of 23 and 33 is meaningless because the fair scheduler is used.
+    createThread(23,"2",sc,sem)
+    createThread(33,"3",sc,sem)
+    Thread.sleep(1000)
+
+    TaskThreadInfo.threadToLock(11).jobFinished()
+    TaskThreadInfo.threadToStarted(23).await()
+
+    assert(TaskThreadInfo.threadToRunning(23) === true)
+    assert(TaskThreadInfo.threadToRunning(33) === false)
+
+    TaskThreadInfo.threadToLock(12).jobFinished()
+    TaskThreadInfo.threadToStarted(33).await()
+
+    assert(TaskThreadInfo.threadToRunning(33) === true)
+
+    // Release every job that is still blocked and wait for all eleven threads.
+    TaskThreadInfo.threadToLock(20).jobFinished()
+    TaskThreadInfo.threadToLock(21).jobFinished()
+    TaskThreadInfo.threadToLock(22).jobFinished()
+    TaskThreadInfo.threadToLock(23).jobFinished()
+    TaskThreadInfo.threadToLock(30).jobFinished()
+    TaskThreadInfo.threadToLock(31).jobFinished()
+    TaskThreadInfo.threadToLock(32).jobFinished()
+    TaskThreadInfo.threadToLock(33).jobFinished()
+
+    sem.acquire(11)
+  }
+}
diff --git a/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala
index 42a87d8b90..48aa67c543 100644
--- a/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala
+++ b/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala
@@ -77,7 +77,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
class SaveStageInfo extends SparkListener {
val stageInfos = mutable.Buffer[StageInfo]()
- def onStageCompleted(stage: StageCompleted) {
+ override def onStageCompleted(stage: StageCompleted) {
stageInfos += stage.stageInfo
}
}
diff --git a/examples/pom.xml b/examples/pom.xml
index c42d2bcdb9..3e5271ec2f 100644
--- a/examples/pom.xml
+++ b/examples/pom.xml
@@ -34,6 +34,41 @@
<artifactId>scalacheck_${scala.version}</artifactId>
<scope>test</scope>
</dependency>
+ <dependency>
+ <groupId>org.apache.cassandra</groupId>
+ <artifactId>cassandra-all</artifactId>
+ <version>1.2.5</version>
+ <exclusions>
+ <exclusion>
+ <groupId>com.google.guava</groupId>
+ <artifactId>guava</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>com.googlecode.concurrentlinkedhashmap</groupId>
+ <artifactId>concurrentlinkedhashmap-lru</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>com.ning</groupId>
+ <artifactId>compress-lzf</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>io.netty</groupId>
+ <artifactId>netty</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>jline</groupId>
+ <artifactId>jline</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>log4j</groupId>
+ <artifactId>log4j</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>org.apache.cassandra.deps</groupId>
+ <artifactId>avro</artifactId>
+ </exclusion>
+ </exclusions>
+ </dependency>
</dependencies>
<build>
<outputDirectory>target/scala-${scala.version}/classes</outputDirectory>
@@ -67,6 +102,11 @@
<artifactId>hadoop-core</artifactId>
<scope>provided</scope>
</dependency>
+ <dependency>
+ <groupId>org.apache.hbase</groupId>
+ <artifactId>hbase</artifactId>
+ <version>0.94.6</version>
+ </dependency>
</dependencies>
<build>
<plugins>
@@ -105,6 +145,11 @@
<artifactId>hadoop-client</artifactId>
<scope>provided</scope>
</dependency>
+ <dependency>
+ <groupId>org.apache.hbase</groupId>
+ <artifactId>hbase</artifactId>
+ <version>0.94.6</version>
+ </dependency>
</dependencies>
<build>
<plugins>
diff --git a/examples/src/main/scala/spark/examples/CassandraTest.scala b/examples/src/main/scala/spark/examples/CassandraTest.scala
new file mode 100644
index 0000000000..0fe1833e83
--- /dev/null
+++ b/examples/src/main/scala/spark/examples/CassandraTest.scala
@@ -0,0 +1,196 @@
+package spark.examples
+
+import org.apache.hadoop.mapreduce.Job
+import org.apache.cassandra.hadoop.ColumnFamilyOutputFormat
+import org.apache.cassandra.hadoop.ConfigHelper
+import org.apache.cassandra.hadoop.ColumnFamilyInputFormat
+import org.apache.cassandra.thrift._
+import spark.SparkContext
+import spark.SparkContext._
+import java.nio.ByteBuffer
+import java.util.SortedMap
+import org.apache.cassandra.db.IColumn
+import org.apache.cassandra.utils.ByteBufferUtil
+import scala.collection.JavaConversions._
+
+
+/*
+ * This example demonstrates using Spark with Cassandra with the New Hadoop API and Cassandra
+ * support for Hadoop.
+ *
+ * To run this example, run this file with the following command params -
+ * <spark_master> <cassandra_node> <cassandra_port>
+ *
+ * So if you want to run this on localhost this will be,
+ * local[3] localhost 9160
+ *
+ * The example makes some assumptions:
+ * 1. You have already created a keyspace called casDemo and it has a column family named Words
+ * 2. The column family has a column named "para" which holds the test content.
+ *
+ * You can create the content by running the following script at the bottom of this file with
+ * cassandra-cli.
+ *
+ */
+object CassandraTest {
+
+  /**
+   * Reads the "para" column of every row in casDemo.Words, counts the words, prints the
+   * counts and writes one (word, wcount) pair of columns per word to casDemo.WordCount.
+   *
+   * @param args <spark_master> <cassandra_node> <cassandra_port>
+   */
+  def main(args: Array[String]) {
+
+    // Fail with a usage message rather than an ArrayIndexOutOfBoundsException
+    if (args.length < 3) {
+      System.err.println("Usage: CassandraTest <spark_master> <cassandra_node> <cassandra_port>")
+      System.exit(1)
+    }
+
+    // Get a SparkContext
+    val sc = new SparkContext(args(0), "casDemo")
+
+    // Build the job configuration with ConfigHelper provided by Cassandra
+    val job = new Job()
+    job.setInputFormatClass(classOf[ColumnFamilyInputFormat])
+    // Every setting below goes into the job's own configuration object
+    val conf = job.getConfiguration()
+
+    val host: String = args(1)
+    val port: String = args(2)
+
+    ConfigHelper.setInputInitialAddress(conf, host)
+    ConfigHelper.setInputRpcPort(conf, port)
+    ConfigHelper.setOutputInitialAddress(conf, host)
+    ConfigHelper.setOutputRpcPort(conf, port)
+    ConfigHelper.setInputColumnFamily(conf, "casDemo", "Words")
+    ConfigHelper.setOutputColumnFamily(conf, "casDemo", "WordCount")
+
+    // Empty start and finish bounds: the slice range is not restricted
+    val predicate = new SlicePredicate()
+    val sliceRange = new SliceRange()
+    sliceRange.setStart(Array.empty[Byte])
+    sliceRange.setFinish(Array.empty[Byte])
+    predicate.setSlice_range(sliceRange)
+    ConfigHelper.setInputSlicePredicate(conf, predicate)
+
+    ConfigHelper.setInputPartitioner(conf, "Murmur3Partitioner")
+    ConfigHelper.setOutputPartitioner(conf, "Murmur3Partitioner")
+
+    // Make a new Hadoop RDD
+    val casRdd = sc.newAPIHadoopRDD(
+      conf,
+      classOf[ColumnFamilyInputFormat],
+      classOf[ByteBuffer],
+      classOf[SortedMap[ByteBuffer, IColumn]])
+
+    // Let us first get all the paragraphs from the retrieved rows
+    val paraRdd = casRdd.map {
+      case (key, value) => {
+        ByteBufferUtil.string(value.get(ByteBufferUtil.bytes("para")).value())
+      }
+    }
+
+    // Lets get the word count in paras
+    val counts = paraRdd.flatMap(p => p.split(" ")).map(word => (word, 1)).reduceByKey(_ + _)
+
+    counts.collect().foreach {
+      case (word, count) => println(word + ":" + count)
+    }
+
+    // Write each count back as a row holding a "word" column and a "wcount" column
+    counts.map {
+      case (word, count) => {
+        val colWord = new org.apache.cassandra.thrift.Column()
+        colWord.setName(ByteBufferUtil.bytes("word"))
+        colWord.setValue(ByteBufferUtil.bytes(word))
+        colWord.setTimestamp(System.currentTimeMillis)
+
+        val colCount = new org.apache.cassandra.thrift.Column()
+        colCount.setName(ByteBufferUtil.bytes("wcount"))
+        colCount.setValue(ByteBufferUtil.bytes(count.toLong))
+        colCount.setTimestamp(System.currentTimeMillis)
+
+        // Row key: the word plus the current time in milliseconds
+        val outputkey = ByteBufferUtil.bytes(word + "-COUNT-" + System.currentTimeMillis)
+
+        // The Scala list is converted to java.util.List by the JavaConversions import
+        val mutations: java.util.List[Mutation] = new Mutation() :: new Mutation() :: Nil
+        mutations.get(0).setColumn_or_supercolumn(new ColumnOrSuperColumn())
+        mutations.get(0).column_or_supercolumn.setColumn(colWord)
+        mutations.get(1).setColumn_or_supercolumn(new ColumnOrSuperColumn())
+        mutations.get(1).column_or_supercolumn.setColumn(colCount)
+        (outputkey, mutations)
+      }
+    }.saveAsNewAPIHadoopFile("casDemo", classOf[ByteBuffer], classOf[List[Mutation]],
+      classOf[ColumnFamilyOutputFormat], conf)
+
+    // Shut the context down once the results have been written
+    sc.stop()
+  }
+}
+
+/*
+create keyspace casDemo;
+use casDemo;
+
+create column family WordCount with comparator = UTF8Type;
+update column family WordCount with column_metadata =
+ [{column_name: word, validation_class: UTF8Type},
+ {column_name: wcount, validation_class: LongType}];
+
+create column family Words with comparator = UTF8Type;
+update column family Words with column_metadata =
+ [{column_name: book, validation_class: UTF8Type},
+ {column_name: para, validation_class: UTF8Type}];
+
+assume Words keys as utf8;
+
+set Words['3musk001']['book'] = 'The Three Musketeers';
+set Words['3musk001']['para'] = 'On the first Monday of the month of April, 1625, the market
+ town of Meung, in which the author of ROMANCE OF THE ROSE was born, appeared to
+ be in as perfect a state of revolution as if the Huguenots had just made
+ a second La Rochelle of it. Many citizens, seeing the women flying
+ toward the High Street, leaving their children crying at the open doors,
+ hastened to don the cuirass, and supporting their somewhat uncertain
+ courage with a musket or a partisan, directed their steps toward the
+ hostelry of the Jolly Miller, before which was gathered, increasing
+ every minute, a compact group, vociferous and full of curiosity.';
+
+set Words['3musk002']['book'] = 'The Three Musketeers';
+set Words['3musk002']['para'] = 'In those times panics were common, and few days passed without
+ some city or other registering in its archives an event of this kind. There were
+ nobles, who made war against each other; there was the king, who made
+ war against the cardinal; there was Spain, which made war against the
+ king. Then, in addition to these concealed or public, secret or open
+ wars, there were robbers, mendicants, Huguenots, wolves, and scoundrels,
+ who made war upon everybody. The citizens always took up arms readily
+ against thieves, wolves or scoundrels, often against nobles or
+ Huguenots, sometimes against the king, but never against cardinal or
+ Spain. It resulted, then, from this habit that on the said first Monday
+ of April, 1625, the citizens, on hearing the clamor, and seeing neither
+ the red-and-yellow standard nor the livery of the Duc de Richelieu,
+ rushed toward the hostel of the Jolly Miller. When arrived there, the
+ cause of the hubbub was apparent to all';
+
+set Words['3musk003']['book'] = 'The Three Musketeers';
+set Words['3musk003']['para'] = 'You ought, I say, then, to husband the means you have, however
+ large the sum may be; but you ought also to endeavor to perfect yourself in
+ the exercises becoming a gentleman. I will write a letter today to the
+ Director of the Royal Academy, and tomorrow he will admit you without
+ any expense to yourself. Do not refuse this little service. Our
+ best-born and richest gentlemen sometimes solicit it without being able
+ to obtain it. You will learn horsemanship, swordsmanship in all its
+ branches, and dancing. You will make some desirable acquaintances; and
+ from time to time you can call upon me, just to tell me how you are
+ getting on, and to say whether I can be of further service to you.';
+
+
+set Words['thelostworld001']['book'] = 'The Lost World';
+set Words['thelostworld001']['para'] = 'She sat with that proud, delicate profile of hers outlined
+ against the red curtain. How beautiful she was! And yet how aloof! We had been
+ friends, quite good friends; but never could I get beyond the same
+ comradeship which I might have established with one of my
+ fellow-reporters upon the Gazette,--perfectly frank, perfectly kindly,
+ and perfectly unsexual. My instincts are all against a woman being too
+ frank and at her ease with me. It is no compliment to a man. Where
+ the real sex feeling begins, timidity and distrust are its companions,
+ heritage from old wicked days when love and violence went often hand in
+ hand. The bent head, the averted eye, the faltering voice, the wincing
+ figure--these, and not the unshrinking gaze and frank reply, are the
+ true signals of passion. Even in my short life I had learned as much
+ as that--or had inherited it in that race memory which we call instinct.';
+
+set Words['thelostworld002']['book'] = 'The Lost World';
+set Words['thelostworld002']['para'] = 'I always liked McArdle, the crabbed, old, round-backed,
+ red-headed news editor, and I rather hoped that he liked me. Of course, Beaumont was
+ the real boss; but he lived in the rarefied atmosphere of some Olympian
+ height from which he could distinguish nothing smaller than an
+ international crisis or a split in the Cabinet. Sometimes we saw him
+ passing in lonely majesty to his inner sanctum, with his eyes staring
+ vaguely and his mind hovering over the Balkans or the Persian Gulf. He
+ was above and beyond us. But McArdle was his first lieutenant, and it
+ was he that we knew. The old man nodded as I entered the room, and he
+ pushed his spectacles far up on his bald forehead.';
+
+*/
diff --git a/examples/src/main/scala/spark/examples/HBaseTest.scala b/examples/src/main/scala/spark/examples/HBaseTest.scala
new file mode 100644
index 0000000000..6e910154d4
--- /dev/null
+++ b/examples/src/main/scala/spark/examples/HBaseTest.scala
@@ -0,0 +1,35 @@
+package spark.examples
+
+import spark._
+import spark.rdd.NewHadoopRDD
+import org.apache.hadoop.hbase.{HBaseConfiguration, HTableDescriptor}
+import org.apache.hadoop.hbase.client.HBaseAdmin
+import org.apache.hadoop.hbase.mapreduce.TableInputFormat
+
+object HBaseTest {
+  /**
+   * Counts the rows of an HBase table through Spark, creating the table first when it is
+   * not available.
+   *
+   * @param args <spark_master> <table_name>
+   */
+  def main(args: Array[String]) {
+    // Fail with a usage message rather than an ArrayIndexOutOfBoundsException
+    if (args.length < 2) {
+      System.err.println("Usage: HBaseTest <spark_master> <table_name>")
+      System.exit(1)
+    }
+
+    val sc = new SparkContext(args(0), "HBaseTest",
+      System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR")))
+
+    val conf = HBaseConfiguration.create()
+
+    // Other options for configuring scan behavior are available. More information available at
+    // http://hbase.apache.org/apidocs/org/apache/hadoop/hbase/mapreduce/TableInputFormat.html
+    conf.set(TableInputFormat.INPUT_TABLE, args(1))
+
+    val admin = new HBaseAdmin(conf)
+    try {
+      // Initialize hBase table if necessary
+      if(!admin.isTableAvailable(args(1))) {
+        val tableDesc = new HTableDescriptor(args(1))
+        admin.createTable(tableDesc)
+      }
+
+      val hBaseRDD = sc.newAPIHadoopRDD(conf, classOf[TableInputFormat],
+        classOf[org.apache.hadoop.hbase.io.ImmutableBytesWritable],
+        classOf[org.apache.hadoop.hbase.client.Result])
+
+      hBaseRDD.count()
+    } finally {
+      // Release the admin connection and the Spark context even if the job fails
+      admin.close()
+      sc.stop()
+    }
+
+    System.exit(0)
+  }
+} \ No newline at end of file
diff --git a/pom.xml b/pom.xml
index ce77ba37c6..3bcb2a3f34 100644
--- a/pom.xml
+++ b/pom.xml
@@ -60,7 +60,7 @@
<cdh.version>4.1.2</cdh.version>
<log4j.version>1.2.17</log4j.version>
- <PermGen>0m</PermGen>
+ <PermGen>64m</PermGen>
<MaxPermGen>512m</MaxPermGen>
</properties>
@@ -190,9 +190,9 @@
<version>0.8.4</version>
</dependency>
<dependency>
- <groupId>asm</groupId>
- <artifactId>asm-all</artifactId>
- <version>3.3.1</version>
+ <groupId>org.ow2.asm</groupId>
+ <artifactId>asm</artifactId>
+ <version>4.0</version>
</dependency>
<dependency>
<groupId>com.google.protobuf</groupId>
@@ -395,10 +395,8 @@
<jvmArgs>
<jvmArg>-Xms64m</jvmArg>
<jvmArg>-Xmx1024m</jvmArg>
- <jvmArg>-XX:PermSize</jvmArg>
- <jvmArg>${PermGen}</jvmArg>
- <jvmArg>-XX:MaxPermSize</jvmArg>
- <jvmArg>${MaxPermGen}</jvmArg>
+ <jvmArg>-XX:PermSize=${PermGen}</jvmArg>
+ <jvmArg>-XX:MaxPermSize=${MaxPermGen}</jvmArg>
</jvmArgs>
<javacArgs>
<javacArg>-source</javacArg>
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 0ea23b446f..faf6e2ae8e 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -54,7 +54,7 @@ object SparkBuild extends Build {
// Fork new JVMs for tests and set Java options for those
fork := true,
- javaOptions += "-Xmx2g",
+ javaOptions += "-Xmx2500m",
// Only allow one test at a time, even across projects, since they run in the same JVM
concurrentRestrictions in Global += Tags.limit(Tags.Test, 1),
@@ -125,12 +125,13 @@ object SparkBuild extends Build {
publishMavenStyle in MavenCompile := true,
publishLocal in MavenCompile <<= publishTask(publishLocalConfiguration in MavenCompile, deliverLocal),
publishLocalBoth <<= Seq(publishLocal in MavenCompile, publishLocal).dependOn
- )
+ ) ++ net.virtualvoid.sbt.graph.Plugin.graphSettings
- val slf4jVersion = "1.6.1"
+ val slf4jVersion = "1.7.2"
val excludeJackson = ExclusionRule(organization = "org.codehaus.jackson")
val excludeNetty = ExclusionRule(organization = "org.jboss.netty")
+ val excludeAsm = ExclusionRule(organization = "asm")
def coreSettings = sharedSettings ++ Seq(
name := "spark-core",
@@ -148,7 +149,7 @@ object SparkBuild extends Build {
"org.slf4j" % "slf4j-log4j12" % slf4jVersion,
"commons-daemon" % "commons-daemon" % "1.0.10",
"com.ning" % "compress-lzf" % "0.8.4",
- "asm" % "asm-all" % "3.3.1",
+ "org.ow2.asm" % "asm" % "4.0",
"com.google.protobuf" % "protobuf-java" % "2.4.1",
"de.javakaffee" % "kryo-serializers" % "0.22",
"com.typesafe.akka" % "akka-actor" % "2.0.3" excludeAll(excludeNetty),
@@ -201,7 +202,20 @@ object SparkBuild extends Build {
def examplesSettings = sharedSettings ++ Seq(
name := "spark-examples",
- libraryDependencies ++= Seq("com.twitter" % "algebird-core_2.9.2" % "0.1.11")
+ libraryDependencies ++= Seq(
+ "com.twitter" % "algebird-core_2.9.2" % "0.1.11",
+
+ "org.apache.hbase" % "hbase" % "0.94.6" excludeAll(excludeNetty, excludeAsm),
+
+ "org.apache.cassandra" % "cassandra-all" % "1.2.5"
+ exclude("com.google.guava", "guava")
+ exclude("com.googlecode.concurrentlinkedhashmap", "concurrentlinkedhashmap-lru")
+ exclude("com.ning","compress-lzf")
+ exclude("io.netty", "netty")
+ exclude("jline","jline")
+ exclude("log4j","log4j")
+ exclude("org.apache.cassandra.deps", "avro")
+ )
)
def bagelSettings = sharedSettings ++ Seq(name := "spark-bagel")
@@ -210,7 +224,7 @@ object SparkBuild extends Build {
name := "spark-streaming",
libraryDependencies ++= Seq(
"org.apache.flume" % "flume-ng-sdk" % "1.2.0" % "compile" excludeAll(excludeNetty),
- "com.github.sgroschupf" % "zkclient" % "0.1",
+ "com.github.sgroschupf" % "zkclient" % "0.1" excludeAll(excludeNetty),
"org.twitter4j" % "twitter4j-stream" % "3.0.3" excludeAll(excludeNetty),
"com.typesafe.akka" % "akka-zeromq" % "2.0.3" excludeAll(excludeNetty)
)
diff --git a/project/plugins.sbt b/project/plugins.sbt
index d4f2442872..f806e66481 100644
--- a/project/plugins.sbt
+++ b/project/plugins.sbt
@@ -16,3 +16,5 @@ addSbtPlugin("io.spray" %% "sbt-twirl" % "0.6.1")
//resolvers += Resolver.url("sbt-plugin-releases", new URL("http://scalasbt.artifactoryonline.com/scalasbt/sbt-plugin-releases/"))(Resolver.ivyStylePatterns)
//addSbtPlugin("com.jsuereth" % "xsbt-gpg-plugin" % "0.6")
+
+addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.7.3")
diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py
new file mode 100644
index 0000000000..78a2da1e18
--- /dev/null
+++ b/python/pyspark/daemon.py
@@ -0,0 +1,158 @@
+import os
+import sys
+import multiprocessing
+from ctypes import c_bool
+from errno import EINTR, ECHILD
+from socket import socket, AF_INET, SOCK_STREAM, SOMAXCONN
+from signal import signal, SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN
+from pyspark.worker import main as worker_main
+from pyspark.serializers import write_int
+
+# Size of the worker pool: one worker per CPU, or 4 when the CPU count is unavailable.
+try:
+    POOLSIZE = multiprocessing.cpu_count()
+except NotImplementedError:
+    POOLSIZE = 4
+
+# Boolean shutdown flag held in a multiprocessing.Value; read through should_exit().
+exit_flag = multiprocessing.Value(c_bool, False)
+
+
+def should_exit():
+ global exit_flag
+ return exit_flag.value
+
+
+def compute_real_exit_code(exit_code):
+ # SystemExit's code can be integer or string, but os._exit only accepts integers
+ import numbers
+ if isinstance(exit_code, numbers.Integral):
+ return exit_code
+ else:
+ return 1
+
+
+def worker(listen_sock):
+    """Accept clients on listen_sock until shutdown is requested.
+
+    Each accepted connection is handed to a freshly forked child, which runs
+    worker_main on it and then exits via os._exit.
+    """
+    # Redirect stdout to stderr
+    os.dup2(2, 1)
+
+    # Manager sends SIGHUP to request termination of workers in the pool
+    def handle_sighup(*args):
+        # The handler only checks the flag; the loops below notice should_exit()
+        # once the interrupted accept() raises EINTR.
+        assert should_exit()
+    signal(SIGHUP, handle_sighup)
+
+    # Cleanup zombie children
+    def handle_sigchld(*args):
+        pid = status = None
+        try:
+            # Keep reaping until waitpid reports (0, 0)
+            while (pid, status) != (0, 0):
+                pid, status = os.waitpid(0, os.WNOHANG)
+        except EnvironmentError as err:
+            if err.errno == EINTR:
+                # retry
+                handle_sigchld()
+            elif err.errno != ECHILD:
+                # ECHILD is ignored; anything else is propagated
+                raise
+    signal(SIGCHLD, handle_sigchld)
+
+    # Handle clients
+    while not should_exit():
+        # Wait until a client arrives or we have to exit
+        sock = None
+        while not should_exit() and sock is None:
+            try:
+                sock, addr = listen_sock.accept()
+            except EnvironmentError as err:
+                # EINTR: go round the loop again and re-check should_exit()
+                if err.errno != EINTR:
+                    raise
+
+        if sock is not None:
+            # Fork a child to handle the client.
+            # The client is handled in the child so that the manager
+            # never receives SIGCHLD unless a worker crashes.
+            if os.fork() == 0:
+                # Leave the worker pool
+                signal(SIGHUP, SIG_DFL)
+                listen_sock.close()
+                # Handle the client then exit
+                sockfile = sock.makefile()
+                exit_code = 0
+                try:
+                    worker_main(sockfile, sockfile)
+                except SystemExit as exc:
+                    exit_code = exc.code
+                finally:
+                    sockfile.close()
+                    sock.close()
+                    # os._exit in the finally block: the child never returns into
+                    # the accept loop above
+                    os._exit(compute_real_exit_code(exit_code))
+            else:
+                # Parent: the child owns the connection now
+                sock.close()
+
+
+def launch_worker(listen_sock):
+    """Fork one pool worker that serves listen_sock; the parent returns immediately."""
+    if os.fork() == 0:
+        try:
+            worker(listen_sock)
+        except Exception as err:
+            # Print the traceback, then exit with status 1
+            import traceback
+            traceback.print_exc()
+            os._exit(1)
+        else:
+            # worker() returns normally only after its should_exit() loop has ended
+            assert should_exit()
+            os._exit(0)
+
+
+def manager():
+    """Run the daemon's manager process.
+
+    Binds a loopback listening socket, writes its port to stdout, forks POOLSIZE
+    workers, then waits until stdin reaches end-of-file or SIGTERM arrives and
+    finally sends SIGHUP to its process group.
+    """
+    # Create a new process group to corral our children
+    os.setpgid(0, 0)
+
+    # Create a listening socket on the AF_INET loopback interface
+    listen_sock = socket(AF_INET, SOCK_STREAM)
+    # Port 0: let the OS choose a free port
+    listen_sock.bind(('127.0.0.1', 0))
+    listen_sock.listen(max(1024, 2 * POOLSIZE, SOMAXCONN))
+    listen_host, listen_port = listen_sock.getsockname()
+    # Report the chosen port to whoever reads our stdout
+    write_int(listen_port, sys.stdout)
+
+    # Launch initial worker pool
+    for idx in range(POOLSIZE):
+        launch_worker(listen_sock)
+    # Close the manager's own handle on the listening socket
+    listen_sock.close()
+
+    def shutdown():
+        global exit_flag
+        exit_flag.value = True
+
+    # Gracefully exit on SIGTERM, don't die on SIGHUP
+    signal(SIGTERM, lambda signum, frame: shutdown())
+    signal(SIGHUP, SIG_IGN)
+
+    # Cleanup zombie children
+    def handle_sigchld(*args):
+        try:
+            # NOTE(review): a single waitpid call, so at most one child is reaped per
+            # invocation — confirm that is sufficient.
+            pid, status = os.waitpid(0, os.WNOHANG)
+            if status != 0 and not should_exit():
+                raise RuntimeError("worker crashed: %s, %s" % (pid, status))
+        except EnvironmentError as err:
+            if err.errno not in (ECHILD, EINTR):
+                raise
+    signal(SIGCHLD, handle_sigchld)
+
+    # Initialization complete
+    sys.stdout.close()
+    try:
+        while not should_exit():
+            try:
+                # Spark tells us to exit by closing stdin
+                if os.read(0, 512) == '':
+                    shutdown()
+            except EnvironmentError as err:
+                # EINTR (a signal interrupted the read): loop and re-check the flag
+                if err.errno != EINTR:
+                    shutdown()
+                    raise
+    finally:
+        # Restore the default SIGTERM action and make sure the flag is set before
+        # signalling the workers
+        signal(SIGTERM, SIG_DFL)
+        exit_flag.value = True
+        # Send SIGHUP to notify workers of shutdown
+        os.kill(0, SIGHUP)
+
+
+if __name__ == '__main__':
+ manager()
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 115cf28cc2..5a95144983 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -46,6 +46,10 @@ def read_long(stream):
return struct.unpack("!q", length)[0]
+def write_long(value, stream):
+    """Write value to stream as an 8-byte big-endian signed integer."""
+    packed = struct.pack("!q", value)
+    stream.write(packed)
+
+
def read_int(stream):
length = stream.read(4)
if length == "":
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 6a1962d267..1e34d47365 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -12,6 +12,7 @@ import unittest
from pyspark.context import SparkContext
from pyspark.files import SparkFiles
from pyspark.java_gateway import SPARK_HOME
+from pyspark.serializers import read_int
class PySparkTestCase(unittest.TestCase):
@@ -117,5 +118,47 @@ class TestIO(PySparkTestCase):
self.sc.parallelize([1]).foreach(func)
+class TestDaemon(unittest.TestCase):
+    """Checks that daemon.py stops accepting connections after shutdown."""
+
+    def connect(self, port):
+        """
+        Connect to the given port on the loopback interface, send a split
+        index of -1 and return True.  Connection errors propagate to the
+        caller; the socket is closed in either case.
+        """
+        from socket import socket, AF_INET, SOCK_STREAM
+        sock = socket(AF_INET, SOCK_STREAM)
+        try:
+            sock.connect(('127.0.0.1', port))
+            # send a split index of -1 to shut down the worker
+            sock.sendall("\xFF\xFF\xFF\xFF")
+        finally:
+            # Release the descriptor even when the connection is refused,
+            # which is the expected outcome after the daemon has stopped.
+            sock.close()
+        return True
+
+    def do_termination_test(self, terminator):
+        """
+        Start daemon.py, check that it accepts a connection, call
+        terminator(daemon) to request shutdown, and check that a new
+        connection attempt then fails with ECONNREFUSED.
+        """
+        from subprocess import Popen, PIPE
+        from errno import ECONNREFUSED
+
+        # start daemon
+        daemon_path = os.path.join(os.path.dirname(__file__), "daemon.py")
+        daemon = Popen([sys.executable, daemon_path], stdin=PIPE, stdout=PIPE)
+
+        # read the port number
+        port = read_int(daemon.stdout)
+
+        # daemon should accept connections
+        self.assertTrue(self.connect(port))
+
+        # request shutdown
+        terminator(daemon)
+        time.sleep(1)
+
+        # daemon should no longer accept connections
+        with self.assertRaises(EnvironmentError) as trap:
+            self.connect(port)
+        self.assertEqual(trap.exception.errno, ECONNREFUSED)
+
+    def test_termination_stdin(self):
+        """Ensure that daemon and workers terminate when stdin is closed."""
+        self.do_termination_test(lambda daemon: daemon.stdin.close())
+
+    def test_termination_sigterm(self):
+        """Ensure that daemon and workers terminate on SIGTERM."""
+        from signal import SIGTERM
+        self.do_termination_test(lambda daemon: os.kill(daemon.pid, SIGTERM))
+
if __name__ == "__main__":
unittest.main()
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 812e7a9da5..379bbfd4c2 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -3,6 +3,7 @@ Worker that receives input from Piped RDD.
"""
import os
import sys
+import time
import traceback
from base64 import standard_b64decode
# CloudPickler needs to be imported so that depicklers are registered using the
@@ -12,48 +13,60 @@ from pyspark.broadcast import Broadcast, _broadcastRegistry
from pyspark.cloudpickle import CloudPickler
from pyspark.files import SparkFiles
from pyspark.serializers import write_with_length, read_with_length, write_int, \
- read_long, read_int, dump_pickle, load_pickle, read_from_pickle_file
+ read_long, write_long, read_int, dump_pickle, load_pickle, read_from_pickle_file
-# Redirect stdout to stderr so that users must return values from functions.
-old_stdout = os.fdopen(os.dup(1), 'w')
-os.dup2(2, 1)
+def load_obj(infile):
+    """Read one line from infile, base64-decode it and unpickle the result."""
+    encoded = infile.readline().strip()
+    return load_pickle(standard_b64decode(encoded))
-def load_obj():
- return load_pickle(standard_b64decode(sys.stdin.readline().strip()))
+def report_times(outfile, boot, init, finish):
+    """
+    Write the timing record to outfile: the int marker -3 followed by the
+    boot, init and finish timestamps as longs, in milliseconds.
+
+    The timestamps are given in seconds (main() passes time.time() values).
+    They are truncated to whole milliseconds explicitly because struct's
+    integer formats are not meant for floats: Python 2.7 accepts them only
+    with a DeprecationWarning and Python 3 raises struct.error.
+    """
+    write_int(-3, outfile)
+    for seconds in (boot, init, finish):
+        write_long(int(1000 * seconds), outfile)
-def main():
- split_index = read_int(sys.stdin)
- spark_files_dir = load_pickle(read_with_length(sys.stdin))
+def main(infile, outfile):
+ """Run one task: read the split index, SparkFiles directory, broadcast variables, function and serializer flag from infile, then write the results, a timing record and the accumulator values to outfile."""
+ boot_time = time.time()
+ split_index = read_int(infile)
+ if split_index == -1: # for unit tests
+ return
+ spark_files_dir = load_pickle(read_with_length(infile))
SparkFiles._root_directory = spark_files_dir
SparkFiles._is_running_on_worker = True
sys.path.append(spark_files_dir)
- num_broadcast_variables = read_int(sys.stdin)
+ num_broadcast_variables = read_int(infile)
for _ in range(num_broadcast_variables):
- bid = read_long(sys.stdin)
- value = read_with_length(sys.stdin)
+ bid = read_long(infile)
+ value = read_with_length(infile)
_broadcastRegistry[bid] = Broadcast(bid, load_pickle(value))
- func = load_obj()
- bypassSerializer = load_obj()
+ func = load_obj(infile)
+ bypassSerializer = load_obj(infile)
if bypassSerializer:
dumps = lambda x: x
else:
dumps = dump_pickle
- iterator = read_from_pickle_file(sys.stdin)
+ # init_time marks the end of setup, just before the input records are read
+ init_time = time.time()
+ iterator = read_from_pickle_file(infile)
try:
for obj in func(split_index, iterator):
- write_with_length(dumps(obj), old_stdout)
+ write_with_length(dumps(obj), outfile)
except Exception as e:
- write_int(-2, old_stdout)
- write_with_length(traceback.format_exc(), old_stdout)
+ # On failure, write the marker -2 followed by the formatted traceback
+ write_int(-2, outfile)
+ write_with_length(traceback.format_exc(), outfile)
sys.exit(-1)
+ finish_time = time.time()
+ report_times(outfile, boot_time, init_time, finish_time)
# Mark the beginning of the accumulators section of the output
- write_int(-1, old_stdout)
+ write_int(-1, outfile)
for aid, accum in _accumulatorRegistry.items():
- write_with_length(dump_pickle((aid, accum._value)), old_stdout)
+ write_with_length(dump_pickle((aid, accum._value)), outfile)
+ # A closing -1 is written after the accumulator values
+ write_int(-1, outfile)
if __name__ == '__main__':
- main()
+ # Redirect stdout to stderr so that users must return values from functions.
+ old_stdout = os.fdopen(os.dup(1), 'w')
+ os.dup2(2, 1)
+ main(sys.stdin, old_stdout)
diff --git a/repl/src/main/scala/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/spark/repl/ExecutorClassLoader.scala
index 13d81ec1cf..0e9aa863b5 100644
--- a/repl/src/main/scala/spark/repl/ExecutorClassLoader.scala
+++ b/repl/src/main/scala/spark/repl/ExecutorClassLoader.scala
@@ -8,7 +8,6 @@ import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.objectweb.asm._
-import org.objectweb.asm.commons.EmptyVisitor
import org.objectweb.asm.Opcodes._
@@ -83,7 +82,7 @@ extends ClassLoader(parent) {
}
class ConstructorCleaner(className: String, cv: ClassVisitor)
-extends ClassAdapter(cv) {
+extends ClassVisitor(ASM4, cv) {
override def visitMethod(access: Int, name: String, desc: String,
sig: String, exceptions: Array[String]): MethodVisitor = {
val mv = cv.visitMethod(access, name, desc, sig, exceptions)
diff --git a/run b/run
index c0065c53f1..e656e38ccf 100755
--- a/run
+++ b/run
@@ -132,10 +132,14 @@ if [ -e "$FWDIR/lib_managed" ]; then
CLASSPATH="$CLASSPATH:$FWDIR/lib_managed/bundles/*"
fi
CLASSPATH="$CLASSPATH:$REPL_DIR/lib/*"
+# Add the shaded JAR for Maven builds
if [ -e $REPL_BIN_DIR/target ]; then
for jar in `find "$REPL_BIN_DIR/target" -name 'spark-repl-*-shaded-hadoop*.jar'`; do
CLASSPATH="$CLASSPATH:$jar"
done
+ # The shaded JAR doesn't contain examples, so include those separately
+ EXAMPLES_JAR=`ls "$EXAMPLES_DIR/target/spark-examples"*[0-9T].jar`
+ CLASSPATH+=":$EXAMPLES_JAR"
fi
CLASSPATH="$CLASSPATH:$BAGEL_DIR/target/scala-$SCALA_VERSION/classes"
for jar in `find $PYSPARK_DIR/lib -name '*jar'`; do
@@ -148,9 +152,9 @@ if [ -e "$EXAMPLES_DIR/target/scala-$SCALA_VERSION/spark-examples"*[0-9T].jar ];
# Use the JAR from the SBT build
export SPARK_EXAMPLES_JAR=`ls "$EXAMPLES_DIR/target/scala-$SCALA_VERSION/spark-examples"*[0-9T].jar`
fi
-if [ -e "$EXAMPLES_DIR/target/spark-examples-"*hadoop[12].jar ]; then
+if [ -e "$EXAMPLES_DIR/target/spark-examples"*[0-9T].jar ]; then
# Use the JAR from the Maven build
- export SPARK_EXAMPLES_JAR=`ls "$EXAMPLES_DIR/target/spark-examples-"*hadoop[12].jar`
+ export SPARK_EXAMPLES_JAR=`ls "$EXAMPLES_DIR/target/spark-examples"*[0-9T].jar`
fi
# Add hadoop conf dir - else FileSystem.*, etc fail !