diff options
51 files changed, 2918 insertions, 983 deletions
diff --git a/core/pom.xml b/core/pom.xml index d8687bf991..88f0ed70f3 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -32,8 +32,8 @@ <artifactId>compress-lzf</artifactId> </dependency> <dependency> - <groupId>asm</groupId> - <artifactId>asm-all</artifactId> + <groupId>org.ow2.asm</groupId> + <artifactId>asm</artifactId> </dependency> <dependency> <groupId>com.google.protobuf</groupId> diff --git a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala index e1fb02157a..3239f4c385 100644 --- a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala @@ -58,6 +58,7 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin CompletionIterator[(K,V), Iterator[(K,V)]](itr, { val shuffleMetrics = new ShuffleReadMetrics + shuffleMetrics.shuffleFinishTime = System.currentTimeMillis shuffleMetrics.remoteFetchTime = blockFetcherItr.remoteFetchTime shuffleMetrics.fetchWaitTime = blockFetcherItr.fetchWaitTime shuffleMetrics.remoteBytesRead = blockFetcherItr.remoteBytesRead diff --git a/core/src/main/scala/spark/ClosureCleaner.scala b/core/src/main/scala/spark/ClosureCleaner.scala index 50d6a1c5c9..d5e7132ff9 100644 --- a/core/src/main/scala/spark/ClosureCleaner.scala +++ b/core/src/main/scala/spark/ClosureCleaner.scala @@ -5,8 +5,7 @@ import java.lang.reflect.Field import scala.collection.mutable.Map import scala.collection.mutable.Set -import org.objectweb.asm.{ClassReader, MethodVisitor, Type} -import org.objectweb.asm.commons.EmptyVisitor +import org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Type} import org.objectweb.asm.Opcodes._ import java.io.{InputStream, IOException, ByteArrayOutputStream, ByteArrayInputStream, BufferedInputStream} @@ -162,10 +161,10 @@ private[spark] object ClosureCleaner extends Logging { } } -private[spark] class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends EmptyVisitor 
{ +private[spark] class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends ClassVisitor(ASM4) { override def visitMethod(access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): MethodVisitor = { - return new EmptyVisitor { + return new MethodVisitor(ASM4) { override def visitFieldInsn(op: Int, owner: String, name: String, desc: String) { if (op == GETFIELD) { for (cl <- output.keys if cl.getName == owner.replace('/', '.')) { @@ -188,7 +187,7 @@ private[spark] class FieldAccessFinder(output: Map[Class[_], Set[String]]) exten } } -private[spark] class InnerClosureFinder(output: Set[Class[_]]) extends EmptyVisitor { +private[spark] class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM4) { var myName: String = null override def visit(version: Int, access: Int, name: String, sig: String, @@ -198,7 +197,7 @@ private[spark] class InnerClosureFinder(output: Set[Class[_]]) extends EmptyVisi override def visitMethod(access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): MethodVisitor = { - return new EmptyVisitor { + return new MethodVisitor(ASM4) { override def visitMethodInsn(op: Int, owner: String, name: String, desc: String) { val argTypes = Type.getArgumentTypes(desc) diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index 2b0e697337..fa4bbfc76f 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -10,6 +10,8 @@ import scala.collection.JavaConversions._ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.apache.hadoop.io.compress.CompressionCodec +import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.mapred.FileOutputCommitter import org.apache.hadoop.mapred.FileOutputFormat import org.apache.hadoop.mapred.HadoopWriter @@ -17,7 +19,7 @@ import org.apache.hadoop.mapred.JobConf import 
org.apache.hadoop.mapred.OutputFormat import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat} -import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter, Job => NewAPIHadoopJob, HadoopMapReduceUtil, TaskAttemptID, TaskAttemptContext} +import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter, Job => NewAPIHadoopJob, HadoopMapReduceUtil} import spark.partial.BoundedDouble import spark.partial.PartialResult @@ -185,11 +187,13 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( * partitioning of the resulting key-value pair RDD by passing a Partitioner. */ def groupByKey(partitioner: Partitioner): RDD[(K, Seq[V])] = { + // groupByKey shouldn't use map side combine because map side combine does not + // reduce the amount of data shuffled and requires all map side data be inserted + // into a hash table, leading to more objects in the old gen. def createCombiner(v: V) = ArrayBuffer(v) def mergeValue(buf: ArrayBuffer[V], v: V) = buf += v - def mergeCombiners(b1: ArrayBuffer[V], b2: ArrayBuffer[V]) = b1 ++= b2 val bufs = combineByKey[ArrayBuffer[V]]( - createCombiner _, mergeValue _, mergeCombiners _, partitioner) + createCombiner _, mergeValue _, null, partitioner, mapSideCombine=false) bufs.asInstanceOf[RDD[(K, Seq[V])]] } @@ -516,6 +520,16 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( } /** + * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class + * supporting the key and value types K and V in this RDD. Compress the result with the + * supplied codec. 
+ */ + def saveAsHadoopFile[F <: OutputFormat[K, V]]( + path: String, codec: Class[_ <: CompressionCodec]) (implicit fm: ClassManifest[F]) { + saveAsHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]], codec) + } + + /** * Output the RDD to any Hadoop-supported file system, using a new Hadoop API `OutputFormat` * (mapreduce.OutputFormat) object supporting the key and value types K and V in this RDD. */ @@ -576,6 +590,20 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( /** * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class + * supporting the key and value types K and V in this RDD. Compress with the supplied codec. + */ + def saveAsHadoopFile( + path: String, + keyClass: Class[_], + valueClass: Class[_], + outputFormatClass: Class[_ <: OutputFormat[_, _]], + codec: Class[_ <: CompressionCodec]) { + saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass, + new JobConf(self.context.hadoopConfiguration), Some(codec)) + } + + /** + * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class * supporting the key and value types K and V in this RDD. 
*/ def saveAsHadoopFile( @@ -583,11 +611,19 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( keyClass: Class[_], valueClass: Class[_], outputFormatClass: Class[_ <: OutputFormat[_, _]], - conf: JobConf = new JobConf(self.context.hadoopConfiguration)) { + conf: JobConf = new JobConf(self.context.hadoopConfiguration), + codec: Option[Class[_ <: CompressionCodec]] = None) { conf.setOutputKeyClass(keyClass) conf.setOutputValueClass(valueClass) // conf.setOutputFormat(outputFormatClass) // Doesn't work in Scala 2.9 due to what may be a generics bug conf.set("mapred.output.format.class", outputFormatClass.getName) + for (c <- codec) { + conf.setCompressMapOutput(true) + conf.set("mapred.output.compress", "true") + conf.setMapOutputCompressorClass(c) + conf.set("mapred.output.compression.codec", c.getCanonicalName) + conf.set("mapred.output.compression.type", CompressionType.BLOCK.toString) + } conf.setOutputCommitter(classOf[FileOutputCommitter]) FileOutputFormat.setOutputPath(conf, HadoopWriter.createPathFromString(path, conf)) saveAsHadoopDataset(conf) diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index e6c0438d76..f336c2ea1e 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -7,12 +7,14 @@ import scala.collection.JavaConversions.mapAsScalaMap import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.io.BytesWritable +import org.apache.hadoop.io.compress.CompressionCodec import org.apache.hadoop.io.NullWritable import org.apache.hadoop.io.Text import org.apache.hadoop.mapred.TextOutputFormat import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap} +import spark.broadcast.Broadcast import spark.Partitioner._ import spark.partial.BoundedDouble import spark.partial.CountEvaluator @@ -35,6 +37,7 @@ import spark.rdd.ZippedPartitionsRDD2 import spark.rdd.ZippedPartitionsRDD3 import spark.rdd.ZippedPartitionsRDD4 import spark.storage.StorageLevel +import 
spark.util.BoundedPriorityQueue import SparkContext._ @@ -114,6 +117,14 @@ abstract class RDD[T: ClassManifest]( this } + /** User-defined generator of this RDD*/ + var generator = Utils.getCallSiteInfo.firstUserClass + + /** Reset generator*/ + def setGenerator(_generator: String) = { + generator = _generator + } + /** * Set this RDD's storage level to persist its values across operations after the first time * it is computed. This can only be used to assign a new storage level if the RDD does not @@ -352,13 +363,36 @@ abstract class RDD[T: ClassManifest]( /** * Return an RDD created by piping elements to a forked external process. */ - def pipe(command: Seq[String]): RDD[String] = new PipedRDD(this, command) + def pipe(command: String, env: Map[String, String]): RDD[String] = + new PipedRDD(this, command, env) + /** * Return an RDD created by piping elements to a forked external process. - */ - def pipe(command: Seq[String], env: Map[String, String]): RDD[String] = - new PipedRDD(this, command, env) + * The print behavior can be customized by providing two functions. + * + * @param command command to run in forked process. + * @param env environment variables to set. + * @param printPipeContext Before piping elements, this function is called as an opportunity + * to pipe context data. Print line function (like out.println) will be + * passed as printPipeContext's parameter. + * @param printRDDElement Use this function to customize how to pipe elements. This function + * will be called with each RDD element as the 1st parameter, and the + * print line function (like out.println()) as the 2nd parameter. 
+ * An example of piping the RDD data of groupBy() in a streaming way, + * instead of constructing a huge String to concat all the elements: + * def printRDDElement(record:(String, Seq[String]), f:String=>Unit) = + * for (e <- record._2){f(e)} + * @return the result RDD + */ + def pipe( + command: Seq[String], + env: Map[String, String] = Map(), + printPipeContext: (String => Unit) => Unit = null, + printRDDElement: (T, String => Unit) => Unit = null): RDD[String] = + new PipedRDD(this, command, env, + if (printPipeContext ne null) sc.clean(printPipeContext) else null, + if (printRDDElement ne null) sc.clean(printRDDElement) else null) /** * Return a new RDD by applying a function to each partition of this RDD. @@ -723,6 +757,24 @@ abstract class RDD[T: ClassManifest]( } /** + * Returns the top K elements from this RDD as defined by + * the specified implicit Ordering[T]. + * @param num the number of top elements to return + * @param ord the implicit ordering for T + * @return an array of top elements + */ + def top(num: Int)(implicit ord: Ordering[T]): Array[T] = { + mapPartitions { items => + val queue = new BoundedPriorityQueue[T](num) + queue ++= items + Iterator.single(queue) + }.reduce { (queue1, queue2) => + queue1 ++= queue2 + queue1 + }.toArray + } + + /** * Save this RDD as a text file, using string representations of elements. */ def saveAsTextFile(path: String) { @@ -731,6 +783,14 @@ abstract class RDD[T: ClassManifest]( } /** + * Save this RDD as a compressed text file, using string representations of elements. + */ + def saveAsTextFile(path: String, codec: Class[_ <: CompressionCodec]) { + this.map(x => (NullWritable.get(), new Text(x.toString))) + .saveAsHadoopFile[TextOutputFormat[NullWritable, Text]](path, codec) + } + + /** * Save this RDD as a SequenceFile of serialized objects. 
*/ def saveAsObjectFile(path: String) { @@ -788,7 +848,7 @@ abstract class RDD[T: ClassManifest]( private var storageLevel: StorageLevel = StorageLevel.NONE /** Record user function generating this RDD. */ - private[spark] val origin = Utils.getSparkCallSite + private[spark] val origin = Utils.formatSparkCallSite private[spark] def elementClassManifest: ClassManifest[T] = classManifest[T] diff --git a/core/src/main/scala/spark/SequenceFileRDDFunctions.scala b/core/src/main/scala/spark/SequenceFileRDDFunctions.scala index 518034e07b..2911f9036e 100644 --- a/core/src/main/scala/spark/SequenceFileRDDFunctions.scala +++ b/core/src/main/scala/spark/SequenceFileRDDFunctions.scala @@ -18,6 +18,7 @@ import org.apache.hadoop.mapred.TextOutputFormat import org.apache.hadoop.mapred.SequenceFileOutputFormat import org.apache.hadoop.mapred.OutputCommitter import org.apache.hadoop.mapred.FileOutputCommitter +import org.apache.hadoop.io.compress.CompressionCodec import org.apache.hadoop.io.Writable import org.apache.hadoop.io.NullWritable import org.apache.hadoop.io.BytesWritable @@ -62,7 +63,7 @@ class SequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable : Cla * byte arrays to BytesWritable, and Strings to Text. The `path` can be on any Hadoop-supported * file system. 
*/ - def saveAsSequenceFile(path: String) { + def saveAsSequenceFile(path: String, codec: Option[Class[_ <: CompressionCodec]] = None) { def anyToWritable[U <% Writable](u: U): Writable = u val keyClass = getWritableClass[K] @@ -72,14 +73,18 @@ class SequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable : Cla logInfo("Saving as sequence file of type (" + keyClass.getSimpleName + "," + valueClass.getSimpleName + ")" ) val format = classOf[SequenceFileOutputFormat[Writable, Writable]] + val jobConf = new JobConf(self.context.hadoopConfiguration) if (!convertKey && !convertValue) { - self.saveAsHadoopFile(path, keyClass, valueClass, format) + self.saveAsHadoopFile(path, keyClass, valueClass, format, jobConf, codec) } else if (!convertKey && convertValue) { - self.map(x => (x._1,anyToWritable(x._2))).saveAsHadoopFile(path, keyClass, valueClass, format) + self.map(x => (x._1,anyToWritable(x._2))).saveAsHadoopFile( + path, keyClass, valueClass, format, jobConf, codec) } else if (convertKey && !convertValue) { - self.map(x => (anyToWritable(x._1),x._2)).saveAsHadoopFile(path, keyClass, valueClass, format) + self.map(x => (anyToWritable(x._1),x._2)).saveAsHadoopFile( + path, keyClass, valueClass, format, jobConf, codec) } else if (convertKey && convertValue) { - self.map(x => (anyToWritable(x._1),anyToWritable(x._2))).saveAsHadoopFile(path, keyClass, valueClass, format) + self.map(x => (anyToWritable(x._1),anyToWritable(x._2))).saveAsHadoopFile( + path, keyClass, valueClass, format, jobConf, codec) } } } diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index bc05d08fd6..70a9d7698c 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -49,7 +49,6 @@ import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend import spark.storage.{BlockManagerUI, StorageStatus, StorageUtils, RDDInfo} import spark.util.{MetadataCleaner, 
TimeStampedHashMap} - /** * Main entry point for Spark functionality. A SparkContext represents the connection to a Spark * cluster, and can be used to create RDDs, accumulators and broadcast variables on that cluster. @@ -630,7 +629,7 @@ class SparkContext( partitions: Seq[Int], allowLocal: Boolean, resultHandler: (Int, U) => Unit) { - val callSite = Utils.getSparkCallSite + val callSite = Utils.formatSparkCallSite logInfo("Starting job: " + callSite) val start = System.nanoTime val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal, resultHandler, localProperties.value) @@ -713,7 +712,7 @@ class SparkContext( func: (TaskContext, Iterator[T]) => U, evaluator: ApproximateEvaluator[U, R], timeout: Long): PartialResult[R] = { - val callSite = Utils.getSparkCallSite + val callSite = Utils.formatSparkCallSite logInfo("Starting job: " + callSite) val start = System.nanoTime val result = dagScheduler.runApproximateJob(rdd, func, evaluator, callSite, timeout, localProperties.value) diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala index be1a04d619..7ccde2e818 100644 --- a/core/src/main/scala/spark/SparkEnv.scala +++ b/core/src/main/scala/spark/SparkEnv.scala @@ -1,5 +1,8 @@ package spark +import collection.mutable +import serializer.Serializer + import akka.actor.{Actor, ActorRef, Props, ActorSystemImpl, ActorSystem} import akka.remote.RemoteActorRefProvider @@ -9,6 +12,7 @@ import spark.storage.BlockManagerMaster import spark.network.ConnectionManager import spark.serializer.{Serializer, SerializerManager} import spark.util.AkkaUtils +import spark.api.python.PythonWorkerFactory /** @@ -37,7 +41,10 @@ class SparkEnv ( // If executorId is NOT found, return defaultHostPort var executorIdToHostPort: Option[(String, String) => String]) { + private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]() + def stop() { + pythonWorkers.foreach { case(key, worker) => 
worker.stop() } httpFileServer.stop() mapOutputTracker.stop() shuffleFetcher.stop() @@ -50,6 +57,11 @@ class SparkEnv ( actorSystem.awaitTermination() } + def createPythonWorker(pythonExec: String, envVars: Map[String, String]): java.net.Socket = { + synchronized { + pythonWorkers.getOrElseUpdate((pythonExec, envVars), new PythonWorkerFactory(pythonExec, envVars)).create() + } + } def resolveExecutorIdToHostPort(executorId: String, defaultHostPort: String): String = { val env = SparkEnv.get diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index ec15326014..f3621c6bee 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -116,8 +116,8 @@ private object Utils extends Logging { while (dir == null) { attempts += 1 if (attempts > maxAttempts) { - throw new IOException("Failed to create a temp directory after " + maxAttempts + - " attempts!") + throw new IOException("Failed to create a temp directory (under " + root + ") after " + + maxAttempts + " attempts!") } try { dir = new File(root, "spark-" + UUID.randomUUID.toString) @@ -522,13 +522,14 @@ private object Utils extends Logging { execute(command, new File(".")) } - + private[spark] class CallSiteInfo(val lastSparkMethod: String, val firstUserFile: String, + val firstUserLine: Int, val firstUserClass: String) /** * When called inside a class in the spark package, returns the name of the user code class * (outside the spark package) that called into Spark, as well as which Spark method they called. * This is used, for example, to tell users where in their code each RDD got created. 
*/ - def getSparkCallSite: String = { + def getCallSiteInfo: CallSiteInfo = { val trace = Thread.currentThread.getStackTrace().filter( el => (!el.getMethodName.contains("getStackTrace"))) @@ -540,6 +541,7 @@ private object Utils extends Logging { var firstUserFile = "<unknown>" var firstUserLine = 0 var finished = false + var firstUserClass = "<unknown>" for (el <- trace) { if (!finished) { @@ -554,13 +556,19 @@ private object Utils extends Logging { else { firstUserLine = el.getLineNumber firstUserFile = el.getFileName + firstUserClass = el.getClassName finished = true } } } - "%s at %s:%s".format(lastSparkMethod, firstUserFile, firstUserLine) + new CallSiteInfo(lastSparkMethod, firstUserFile, firstUserLine, firstUserClass) } + def formatSparkCallSite = { + val callSiteInfo = getCallSiteInfo + "%s at %s:%s".format(callSiteInfo.lastSparkMethod, callSiteInfo.firstUserFile, + callSiteInfo.firstUserLine) + } /** * Try to find a free port to bind to on the local host. This should ideally never be needed, * except that, unfortunately, some of the networking libraries we currently rely on (e.g. Spray) diff --git a/core/src/main/scala/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/spark/api/java/JavaPairRDD.scala index 30084df4e2..76051597b6 100644 --- a/core/src/main/scala/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/spark/api/java/JavaPairRDD.scala @@ -6,6 +6,7 @@ import java.util.Comparator import scala.Tuple2 import scala.collection.JavaConversions._ +import org.apache.hadoop.io.compress.CompressionCodec import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.mapred.OutputFormat import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat} @@ -459,6 +460,16 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif rdd.saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass) } + /** Output the RDD to any Hadoop-supported file system, compressing with the supplied codec. 
*/ + def saveAsHadoopFile[F <: OutputFormat[_, _]]( + path: String, + keyClass: Class[_], + valueClass: Class[_], + outputFormatClass: Class[F], + codec: Class[_ <: CompressionCodec]) { + rdd.saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass, codec) + } + /** Output the RDD to any Hadoop-supported file system. */ def saveAsNewAPIHadoopFile[F <: NewOutputFormat[_, _]]( path: String, diff --git a/core/src/main/scala/spark/api/java/JavaRDD.scala b/core/src/main/scala/spark/api/java/JavaRDD.scala index eb81ed64cd..626b499454 100644 --- a/core/src/main/scala/spark/api/java/JavaRDD.scala +++ b/core/src/main/scala/spark/api/java/JavaRDD.scala @@ -86,7 +86,6 @@ JavaRDDLike[T, JavaRDD[T]] { */ def subtract(other: JavaRDD[T], p: Partitioner): JavaRDD[T] = wrapRDD(rdd.subtract(other, p)) - } object JavaRDD { diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala index 9b74d1226f..b555f2030a 100644 --- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/spark/api/java/JavaRDDLike.scala @@ -1,9 +1,10 @@ package spark.api.java -import java.util.{List => JList} +import java.util.{List => JList, Comparator} import scala.Tuple2 import scala.collection.JavaConversions._ +import org.apache.hadoop.io.compress.CompressionCodec import spark.{SparkContext, Partition, RDD, TaskContext} import spark.api.java.JavaPairRDD._ import spark.api.java.function.{Function2 => JFunction2, Function => JFunction, _} @@ -310,6 +311,13 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def saveAsTextFile(path: String) = rdd.saveAsTextFile(path) + + /** + * Save this RDD as a compressed text file, using string representations of elements. + */ + def saveAsTextFile(path: String, codec: Class[_ <: CompressionCodec]) = + rdd.saveAsTextFile(path, codec) + /** * Save this RDD as a SequenceFile of serialized objects. 
*/ @@ -351,4 +359,29 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def toDebugString(): String = { rdd.toDebugString } + + /** + * Returns the top K elements from this RDD as defined by + * the specified Comparator[T]. + * @param num the number of top elements to return + * @param comp the comparator that defines the order + * @return a list of top elements + */ + def top(num: Int, comp: Comparator[T]): JList[T] = { + import scala.collection.JavaConversions._ + val topElems = rdd.top(num)(Ordering.comparatorToOrdering(comp)) + val arr: java.util.Collection[T] = topElems.toSeq + new java.util.ArrayList(arr) + } + + /** + * Returns the top K elements from this RDD using the + * natural ordering for T. + * @param num the number of top elements to return + * @return a list of top elements + */ + def top(num: Int): JList[T] = { + val comp = com.google.common.collect.Ordering.natural().asInstanceOf[Comparator[T]] + top(num, comp) + } } diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 807119ca8c..63140cf37f 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -2,10 +2,9 @@ package spark.api.python import java.io._ import java.net._ -import java.util.{List => JList, ArrayList => JArrayList, Collections} +import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections} import scala.collection.JavaConversions._ -import scala.io.Source import spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} import spark.broadcast.Broadcast @@ -16,7 +15,7 @@ import spark.rdd.PipedRDD private[spark] class PythonRDD[T: ClassManifest]( parent: RDD[T], command: Seq[String], - envVars: java.util.Map[String, String], + envVars: JMap[String, String], preservePartitoning: Boolean, pythonExec: String, broadcastVars: JList[Broadcast[Array[Byte]]], @@ -25,7 +24,7 @@ private[spark] class 
PythonRDD[T: ClassManifest]( // Similar to Runtime.exec(), if we are given a single string, split it into words // using a standard StringTokenizer (i.e. by spaces) - def this(parent: RDD[T], command: String, envVars: java.util.Map[String, String], + def this(parent: RDD[T], command: String, envVars: JMap[String, String], preservePartitoning: Boolean, pythonExec: String, broadcastVars: JList[Broadcast[Array[Byte]]], accumulator: Accumulator[JList[Array[Byte]]]) = @@ -36,35 +35,18 @@ private[spark] class PythonRDD[T: ClassManifest]( override val partitioner = if (preservePartitoning) parent.partitioner else None - override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { - val SPARK_HOME = new ProcessBuilder().environment().get("SPARK_HOME") - - val pb = new ProcessBuilder(Seq(pythonExec, SPARK_HOME + "/python/pyspark/worker.py")) - // Add the environmental variables to the process. - val currentEnvVars = pb.environment() - - for ((variable, value) <- envVars) { - currentEnvVars.put(variable, value) - } - val proc = pb.start() + override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { + val startTime = System.currentTimeMillis val env = SparkEnv.get - - // Start a thread to print the process's stderr to ours - new Thread("stderr reader for " + pythonExec) { - override def run() { - for (line <- Source.fromInputStream(proc.getErrorStream).getLines) { - System.err.println(line) - } - } - }.start() + val worker = env.createPythonWorker(pythonExec, envVars.toMap) // Start a thread to feed the process input from our parent's iterator new Thread("stdin writer for " + pythonExec) { override def run() { SparkEnv.set(env) - val out = new PrintWriter(proc.getOutputStream) - val dOut = new DataOutputStream(proc.getOutputStream) + val out = new PrintWriter(worker.getOutputStream) + val dOut = new DataOutputStream(worker.getOutputStream) // Partition index dOut.writeInt(split.index) // sparkFilesDir @@ -88,16 +70,21 @@ 
private[spark] class PythonRDD[T: ClassManifest]( } dOut.flush() out.flush() - proc.getOutputStream.close() + worker.shutdownOutput() } }.start() // Return an iterator that read lines from the process's stdout - val stream = new DataInputStream(proc.getInputStream) + val stream = new DataInputStream(worker.getInputStream) return new Iterator[Array[Byte]] { def next(): Array[Byte] = { val obj = _nextObj - _nextObj = read() + if (hasNext) { + // FIXME: can deadlock if worker is waiting for us to + // respond to current message (currently irrelevant because + // output is shutdown before we read any input) + _nextObj = read() + } obj } @@ -108,6 +95,17 @@ private[spark] class PythonRDD[T: ClassManifest]( val obj = new Array[Byte](length) stream.readFully(obj) obj + case -3 => + // Timing data from worker + val bootTime = stream.readLong() + val initTime = stream.readLong() + val finishTime = stream.readLong() + val boot = bootTime - startTime + val init = initTime - bootTime + val finish = finishTime - initTime + val total = finishTime - startTime + logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot, init, finish)) + read case -2 => // Signals that an exception has been thrown in python val exLength = stream.readInt() @@ -115,23 +113,21 @@ private[spark] class PythonRDD[T: ClassManifest]( stream.readFully(obj) throw new PythonException(new String(obj)) case -1 => - // We've finished the data section of the output, but we can still read some - // accumulator updates; let's do that, breaking when we get EOFException - while (true) { - val len2 = stream.readInt() + // We've finished the data section of the output, but we can still + // read some accumulator updates; let's do that, breaking when we + // get a negative length record. 
+ var len2 = stream.readInt() + while (len2 >= 0) { val update = new Array[Byte](len2) stream.readFully(update) accumulator += Collections.singletonList(update) + len2 = stream.readInt() } new Array[Byte](0) } } catch { case eof: EOFException => { - val exitStatus = proc.waitFor() - if (exitStatus != 0) { - throw new Exception("Subprocess exited with status " + exitStatus) - } - new Array[Byte](0) + throw new SparkException("Python worker exited unexpectedly (crashed)", eof) } case e => throw e } @@ -159,7 +155,7 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends override def compute(split: Partition, context: TaskContext) = prev.iterator(split, context).grouped(2).map { case Seq(a, b) => (a, b) - case x => throw new Exception("PairwiseRDD: unexpected value: " + x) + case x => throw new SparkException("PairwiseRDD: unexpected value: " + x) } val asJavaPairRDD : JavaPairRDD[Array[Byte], Array[Byte]] = JavaPairRDD.fromRDD(this) } @@ -215,7 +211,7 @@ private[spark] object PythonRDD { dOut.write(s) dOut.writeByte(Pickle.STOP) } else { - throw new Exception("Unexpected RDD type") + throw new SparkException("Unexpected RDD type") } } diff --git a/core/src/main/scala/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/spark/api/python/PythonWorkerFactory.scala new file mode 100644 index 0000000000..8844411d73 --- /dev/null +++ b/core/src/main/scala/spark/api/python/PythonWorkerFactory.scala @@ -0,0 +1,95 @@ +package spark.api.python + +import java.io.{DataInputStream, IOException} +import java.net.{Socket, SocketException, InetAddress} + +import scala.collection.JavaConversions._ + +import spark._ + +private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String, String]) + extends Logging { + var daemon: Process = null + val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1)) + var daemonPort: Int = 0 + + def create(): Socket = { + synchronized { + // Start the daemon if it hasn't been started + startDaemon() + + // Attempt 
to connect, restart and retry once if it fails + try { + new Socket(daemonHost, daemonPort) + } catch { + case exc: SocketException => { + logWarning("Python daemon unexpectedly quit, attempting to restart") + stopDaemon() + startDaemon() + new Socket(daemonHost, daemonPort) + } + case e => throw e + } + } + } + + def stop() { + stopDaemon() + } + + private def startDaemon() { + synchronized { + // Is it already running? + if (daemon != null) { + return + } + + try { + // Create and start the daemon + val sparkHome = new ProcessBuilder().environment().get("SPARK_HOME") + val pb = new ProcessBuilder(Seq(pythonExec, sparkHome + "/python/pyspark/daemon.py")) + val workerEnv = pb.environment() + workerEnv.putAll(envVars) + daemon = pb.start() + daemonPort = new DataInputStream(daemon.getInputStream).readInt() + + // Redirect the stderr to ours + new Thread("stderr reader for " + pythonExec) { + override def run() { + scala.util.control.Exception.ignoring(classOf[IOException]) { + // FIXME HACK: We copy the stream on the level of bytes to + // attempt to dodge encoding problems. + val in = daemon.getErrorStream + var buf = new Array[Byte](1024) + var len = in.read(buf) + while (len != -1) { + System.err.write(buf, 0, len) + len = in.read(buf) + } + } + } + }.start() + } catch { + case e => { + stopDaemon() + throw e + } + } + + // Important: don't close daemon's stdin (daemon.getOutputStream) so it can correctly + // detect our disappearance. 
+ } + } + + private def stopDaemon() { + synchronized { + // Request shutdown of existing daemon by sending SIGTERM + if (daemon != null) { + daemon.destroy() + } + + daemon = null + daemonPort = 0 + } + } +} diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala index 890938d48b..8bebfafce4 100644 --- a/core/src/main/scala/spark/executor/Executor.scala +++ b/core/src/main/scala/spark/executor/Executor.scala @@ -104,6 +104,7 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert val value = task.run(taskId.toInt) val taskFinish = System.currentTimeMillis() task.metrics.foreach{ m => + m.hostname = Utils.localHostName m.executorDeserializeTime = (taskStart - startTime).toInt m.executorRunTime = (taskFinish - taskStart).toInt } diff --git a/core/src/main/scala/spark/executor/TaskMetrics.scala b/core/src/main/scala/spark/executor/TaskMetrics.scala index a7c56c2371..1dc13754f9 100644 --- a/core/src/main/scala/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/spark/executor/TaskMetrics.scala @@ -2,6 +2,11 @@ package spark.executor class TaskMetrics extends Serializable { /** + * Host's name the task runs on + */ + var hostname: String = _ + + /** * Time taken on the executor to deserialize this task */ var executorDeserializeTime: Int = _ @@ -34,9 +39,14 @@ object TaskMetrics { class ShuffleReadMetrics extends Serializable { /** + * Time when shuffle finishs + */ + var shuffleFinishTime: Long = _ + + /** * Total number of blocks fetched in a shuffle (remote or local) */ - var totalBlocksFetched : Int = _ + var totalBlocksFetched: Int = _ /** * Number of remote blocks fetched in a shuffle diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala index 7599ba1a02..8966f9f86e 100644 --- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala @@ -6,7 +6,7 @@ import 
java.util.{HashMap => JHashMap} import scala.collection.JavaConversions import scala.collection.mutable.ArrayBuffer -import spark.{Aggregator, Logging, Partition, Partitioner, RDD, SparkEnv, TaskContext} +import spark.{Aggregator, Partition, Partitioner, RDD, SparkEnv, TaskContext} import spark.{Dependency, OneToOneDependency, ShuffleDependency} @@ -49,12 +49,16 @@ private[spark] class CoGroupAggregator * * @param rdds parent RDDs. * @param part partitioner used to partition the shuffle output. - * @param mapSideCombine flag indicating whether to merge values before shuffle step. + * @param mapSideCombine flag indicating whether to merge values before shuffle step. If the flag + * is on, Spark does an extra pass over the data on the map side to merge + * all values belonging to the same key together. This can reduce the amount + * of data shuffled if and only if the number of distinct keys is very small, + * and the ratio of key size to value size is also very small. */ class CoGroupedRDD[K]( @transient var rdds: Seq[RDD[(K, _)]], part: Partitioner, - val mapSideCombine: Boolean = true, + val mapSideCombine: Boolean = false, val serializerClass: String = null) extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) { diff --git a/core/src/main/scala/spark/rdd/PipedRDD.scala b/core/src/main/scala/spark/rdd/PipedRDD.scala index 962a1b21ad..c0baf43d43 100644 --- a/core/src/main/scala/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/spark/rdd/PipedRDD.scala @@ -9,6 +9,7 @@ import scala.collection.mutable.ArrayBuffer import scala.io.Source import spark.{RDD, SparkEnv, Partition, TaskContext} +import spark.broadcast.Broadcast /** @@ -18,14 +19,21 @@ import spark.{RDD, SparkEnv, Partition, TaskContext} class PipedRDD[T: ClassManifest]( prev: RDD[T], command: Seq[String], - envVars: Map[String, String]) + envVars: Map[String, String], + printPipeContext: (String => Unit) => Unit, + printRDDElement: (T, String => Unit) => Unit) extends RDD[String](prev) { - def this(prev: 
RDD[T], command: Seq[String]) = this(prev, command, Map()) - // Similar to Runtime.exec(), if we are given a single string, split it into words // using a standard StringTokenizer (i.e. by spaces) - def this(prev: RDD[T], command: String) = this(prev, PipedRDD.tokenize(command)) + def this( + prev: RDD[T], + command: String, + envVars: Map[String, String] = Map(), + printPipeContext: (String => Unit) => Unit = null, + printRDDElement: (T, String => Unit) => Unit = null) = + this(prev, PipedRDD.tokenize(command), envVars, printPipeContext, printRDDElement) + override def getPartitions: Array[Partition] = firstParent[T].partitions @@ -52,8 +60,17 @@ class PipedRDD[T: ClassManifest]( override def run() { SparkEnv.set(env) val out = new PrintWriter(proc.getOutputStream) + + // input the pipe context firstly + if (printPipeContext != null) { + printPipeContext(out.println(_)) + } for (elem <- firstParent[T].iterator(split, context)) { - out.println(elem) + if (printRDDElement != null) { + printRDDElement(elem, out.println(_)) + } else { + out.println(elem) + } } out.close() } diff --git a/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala index dd9f3c2680..b234428ab2 100644 --- a/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala +++ b/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala @@ -53,14 +53,10 @@ abstract class ZippedPartitionsBaseRDD[V: ClassManifest]( val exactMatchLocations = exactMatchPreferredLocations.reduce((x, y) => x.intersect(y)) // Remove exact match and then do host local match. 
- val otherNodePreferredLocations = rddSplitZip.map(x => { - x._1.preferredLocations(x._2).map(hostPort => { - val host = Utils.parseHostPort(hostPort)._1 - - if (exactMatchLocations.contains(host)) null else host - }).filter(_ != null) - }) - val otherNodeLocalLocations = otherNodePreferredLocations.reduce((x, y) => x.intersect(y)) + val exactMatchHosts = exactMatchLocations.map(Utils.parseHostPort(_)._1) + val matchPreferredHosts = exactMatchPreferredLocations.map(locs => locs.map(Utils.parseHostPort(_)._1)) + .reduce((x, y) => x.intersect(y)) + val otherNodeLocalLocations = matchPreferredHosts.filter { s => !exactMatchHosts.contains(s) } otherNodeLocalLocations ++ exactMatchLocations } diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 7feeb97542..f7d60be5db 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -298,6 +298,7 @@ class DAGScheduler( // Compute very short actions like first() or take() with no parent stages locally. 
runLocally(job) } else { + sparkListeners.foreach(_.onJobStart(SparkListenerJobStart(job, properties))) idToActiveJob(runId) = job activeJobs += job resultStageToJob(finalStage) = job @@ -311,6 +312,8 @@ class DAGScheduler( handleExecutorLost(execId) case completion: CompletionEvent => + sparkListeners.foreach(_.onTaskEnd(SparkListenerTaskEnd(completion.task, + completion.reason, completion.taskInfo, completion.taskMetrics))) handleTaskCompletion(completion) case TaskSetFailed(taskSet, reason) => @@ -321,6 +324,7 @@ class DAGScheduler( for (job <- activeJobs) { val error = new SparkException("Job cancelled because SparkContext was shut down") job.listener.jobFailed(error) + sparkListeners.foreach(_.onJobEnd(SparkListenerJobEnd(job, JobFailed(error)))) } return true } @@ -468,6 +472,7 @@ class DAGScheduler( } } if (tasks.size > 0) { + sparkListeners.foreach(_.onStageSubmitted(SparkListenerStageSubmitted(stage, tasks.size))) logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")") myPending ++= tasks logDebug("New pending tasks: " + myPending) @@ -522,6 +527,7 @@ class DAGScheduler( activeJobs -= job resultStageToJob -= stage markStageAsFinished(stage) + sparkListeners.foreach(_.onJobEnd(SparkListenerJobEnd(job, JobSucceeded))) } job.listener.taskSucceeded(rt.outputId, event.result) } @@ -662,7 +668,9 @@ class DAGScheduler( val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq for (resultStage <- dependentStages) { val job = resultStageToJob(resultStage) - job.listener.jobFailed(new SparkException("Job failed: " + reason)) + val error = new SparkException("Job failed: " + reason) + job.listener.jobFailed(error) + sparkListeners.foreach(_.onJobEnd(SparkListenerJobEnd(job, JobFailed(error)))) activeJobs -= job resultStageToJob -= resultStage } diff --git a/core/src/main/scala/spark/scheduler/JobLogger.scala b/core/src/main/scala/spark/scheduler/JobLogger.scala new file mode 100644 index 
0000000000..178bfaba3d --- /dev/null +++ b/core/src/main/scala/spark/scheduler/JobLogger.scala @@ -0,0 +1,306 @@ +package spark.scheduler
+
+import java.io.PrintWriter
+import java.io.File
+import java.io.FileNotFoundException
+import java.text.SimpleDateFormat
+import java.util.{Date, Properties}
+import java.util.concurrent.LinkedBlockingQueue
+import scala.collection.mutable.{Map, HashMap, ListBuffer}
+import scala.io.Source
+import spark._
+import spark.executor.TaskMetrics
+import spark.scheduler.cluster.TaskInfo
+
+/**
+ * A SparkListener that records runtime information about each job in a per-job log file: the
+ * job annotation, the stage/RDD dependency graph, stage submissions and completions, task
+ * results (including shuffle metrics) and the final job status.
+ *
+ * Log files are written to `<log root>/<logDirName>/<jobID>`, where the log root is the
+ * SPARK_LOG_DIR environment variable when it is set and /tmp/spark otherwise.
+ *
+ * The listener callbacks only put the event on a queue; a daemon thread named "JobLogger"
+ * takes events off that queue and does the actual logging.
+ *
+ * @param logDirName name of the directory, under the log root, that holds this logger's files
+ */
+class JobLogger(val logDirName: String) extends SparkListener with Logging {
+  // Root directory for all job logs.
+  private val logDir = Option(System.getenv("SPARK_LOG_DIR")).getOrElse("/tmp/spark")
+  // Open log file of every job that has started and not yet ended.
+  private val jobIDToPrintWriter = new HashMap[Int, PrintWriter]
+  // Reverse index used to route stage and task events to the log of the owning job.
+  private val stageIDToJobID = new HashMap[Int, Int]
+  // Stages recorded for each job by buildJobDep.
+  private val jobIDToStages = new HashMap[Int, ListBuffer[Stage]]
+  private val DATE_FORMAT = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss")
+  private val eventQueue = new LinkedBlockingQueue[SparkListenerEvents]
+
+  createLogDir()
+
+  /** Creates a logger whose directory name is the current time in milliseconds. */
+  def this() = this(String.valueOf(System.currentTimeMillis()))
+
+  def getLogDir: String = logDir
+  def getJobIDtoPrintWriter: HashMap[Int, PrintWriter] = jobIDToPrintWriter
+  def getStageIDToJobID: HashMap[Int, Int] = stageIDToJobID
+  def getJobIDToStages: HashMap[Int, ListBuffer[Stage]] = jobIDToStages
+  def getEventQueue: LinkedBlockingQueue[SparkListenerEvents] = eventQueue
+
+  // Background consumer of the event queue.
+  new Thread("JobLogger") {
+    setDaemon(true)
+    override def run(): Unit = {
+      while (true) {
+        // take() is deliberately outside the try below so that interrupting the thread
+        // still terminates it.
+        val event = eventQueue.take
+        logDebug("Got event of type " + event.getClass.getName)
+        try {
+          handleEvent(event)
+        } catch {
+          // One event that cannot be logged must not stop the logging of all later events
+          // (the queue is unbounded and would otherwise grow without being drained).
+          case e: Exception =>
+            logError("JobLogger failed to process event of type " +
+              event.getClass.getName + ": " + e)
+        }
+      }
+    }
+  }.start()
+
+  // Dispatches one queued event to the matching process*Event method.
+  private def handleEvent(event: SparkListenerEvents): Unit = {
+    event match {
+      case SparkListenerJobStart(job, properties) =>
+        processJobStartEvent(job, properties)
+      case SparkListenerStageSubmitted(stage, taskSize) =>
+        processStageSubmittedEvent(stage, taskSize)
+      case StageCompleted(stageInfo) =>
+        processStageCompletedEvent(stageInfo)
+      case SparkListenerJobEnd(job, result) =>
+        processJobEndEvent(job, result)
+      case SparkListenerTaskEnd(task, reason, taskInfo, taskMetrics) =>
+        processTaskEndEvent(task, reason, taskInfo, taskMetrics)
+      case _ =>
+    }
+  }
+
+  // Create the folder for this logger's files (logDir/logDirName) if it does not exist yet.
+  protected def createLogDir(): Unit = {
+    val dir = new File(logDir + "/" + logDirName + "/")
+    if (!dir.exists() && !dir.mkdirs()) {
+      logError("create log directory error:" + logDir + "/" + logDirName + "/")
+    }
+  }
+
+  // Create a log file for one job, the file name is the jobID
+  protected def createLogWriter(jobID: Int): Unit = {
+    try {
+      val fileWriter = new PrintWriter(logDir + "/" + logDirName + "/" + jobID)
+      jobIDToPrintWriter += (jobID -> fileWriter)
+    } catch {
+      // Without a writer, jobLogInfo silently skips this job; report it through the
+      // regular logger like the other errors of this class.
+      case e: FileNotFoundException =>
+        logError("Cannot create log file for job " + jobID + ": " + e)
+    }
+  }
+
+  // Close the job's log file (if one was opened) and forget the job's stages.
+  // The stage bookkeeping is cleaned up even when no writer exists, so a job whose log file
+  // could not be created does not leave entries behind.
+  protected def closeLogWriter(jobID: Int): Unit = {
+    jobIDToPrintWriter.remove(jobID).foreach(_.close())
+    jobIDToStages.remove(jobID).foreach(_.foreach(stage => stageIDToJobID -= stage.id))
+  }
+
+  // Write one line to the job's log file; the withTime parameter controls whether the line
+  // is prefixed with a time stamp. Nothing is written when the job has no open log file.
+  protected def jobLogInfo(jobID: Int, info: String, withTime: Boolean = true): Unit = {
+    val writeInfo =
+      if (withTime) DATE_FORMAT.format(new Date(System.currentTimeMillis())) + ": " + info
+      else info
+    jobIDToPrintWriter.get(jobID).foreach(_.println(writeInfo))
+  }
+
+  // Write one line to the log of the job that owns the given stage, if that job is known.
+  protected def stageLogInfo(stageID: Int, info: String, withTime: Boolean = true): Unit =
+    stageIDToJobID.get(stageID).foreach(jobID => jobLogInfo(jobID, info, withTime))
+
+  // Register the stage and, recursively, its parents under the job. Only stages whose
+  // priority equals the jobID are registered; the walk stops at any other stage.
+  protected def buildJobDep(jobID: Int, stage: Stage): Unit = {
+    if (stage.priority == jobID) {
+      jobIDToStages.getOrElseUpdate(jobID, new ListBuffer[Stage]) += stage
+      stageIDToJobID += (stage.id -> jobID)
+      stage.parents.foreach(buildJobDep(jobID, _))
+    }
+  }
+
+  // For every stage registered for the job, log the IDs of the RDDs reachable from the
+  // stage's RDD without crossing a shuffle dependency, plus its parent stages as
+  // (stageId,shuffleId) pairs.
+  protected def recordStageDep(jobID: Int): Unit = {
+    def getRddsInStage(rdd: RDD[_]): ListBuffer[RDD[_]] = {
+      val rddList = new ListBuffer[RDD[_]]
+      rddList += rdd
+      rdd.dependencies.foreach {
+        // A shuffle dependency is where the walk stops.
+        case shufDep: ShuffleDependency[_,_] =>
+        case dep => rddList ++= getRddsInStage(dep.rdd)
+      }
+      rddList
+    }
+    jobIDToStages.get(jobID).foreach(_.foreach { stage =>
+      val depRddDesc = getRddsInStage(stage.rdd).map(_.id).mkString(",")
+      val depStageDesc = stage.parents.map { parent =>
+        "(" + parent.id + "," + parent.shuffleDep.get.shuffleId + ")"
+      }.mkString
+      jobLogInfo(jobID, "STAGE_ID=" + stage.id + " RDD_DEP=(" + depRddDesc + ")" +
+        " STAGE_DEP=" + depStageDesc, false)
+    })
+  }
+
+  // Generate indents and convert to String
+  protected def indentString(indent: Int): String = " " * indent
+
+  // The RDD's name when it has one, otherwise its class name.
+  protected def getRddName(rdd: RDD[_]): String =
+    Option(rdd.name).getOrElse(rdd.getClass.getName)
+
+  // Log the RDD and, one indent level deeper, everything it depends on. A shuffle dependency
+  // is logged as its shuffle ID and not followed any further.
+  protected def recordRddInStageGraph(jobID: Int, rdd: RDD[_], indent: Int): Unit = {
+    val rddInfo = "RDD_ID=" + rdd.id + "(" + getRddName(rdd) + "," + rdd.generator + ")"
+    jobLogInfo(jobID, indentString(indent) + rddInfo, false)
+    rdd.dependencies.foreach {
+      case shufDep: ShuffleDependency[_,_] =>
+        val depInfo = "SHUFFLE_ID=" + shufDep.shuffleId
+        jobLogInfo(jobID, indentString(indent + 1) + depInfo, false)
+      case dep => recordRddInStageGraph(jobID, dep.rdd, indent + 1)
+    }
+  }
+
+  // Log the stage graph rooted at the given stage. A stage whose priority differs from the
+  // jobID is logged as a single line carrying that priority as JOB_ID and is not expanded.
+  protected def recordStageDepGraph(jobID: Int, stage: Stage, indent: Int = 0): Unit = {
+    val stageInfo =
+      if (stage.isShuffleMap) {
+        "STAGE_ID=" + stage.id + " MAP_STAGE SHUFFLE_ID=" + stage.shuffleDep.get.shuffleId
+      } else {
+        "STAGE_ID=" + stage.id + " RESULT_STAGE"
+      }
+    if (stage.priority == jobID) {
+      jobLogInfo(jobID, indentString(indent) + stageInfo, false)
+      recordRddInStageGraph(jobID, stage.rdd, indent)
+      stage.parents.foreach(recordStageDepGraph(jobID, _, indent + 2))
+    } else {
+      jobLogInfo(jobID, indentString(indent) + stageInfo + " JOB_ID=" + stage.priority, false)
+    }
+  }
+
+  // Record task metrics into job log files
+  protected def recordTaskMetrics(stageID: Int, status: String,
+      taskInfo: TaskInfo, taskMetrics: TaskMetrics): Unit = {
+    val info = " TID=" + taskInfo.taskId + " STAGE_ID=" + stageID +
+      " START_TIME=" + taskInfo.launchTime + " FINISH_TIME=" + taskInfo.finishTime +
+      " EXECUTOR_ID=" + taskInfo.executorId + " HOST=" + taskMetrics.hostname
+    val executorRunTime = " EXECUTOR_RUN_TIME=" + taskMetrics.executorRunTime
+    // Shuffle read/write sections are only present when the task reported those metrics.
+    val readMetrics = taskMetrics.shuffleReadMetrics.map { metrics =>
+      " SHUFFLE_FINISH_TIME=" + metrics.shuffleFinishTime +
+        " BLOCK_FETCHED_TOTAL=" + metrics.totalBlocksFetched +
+        " BLOCK_FETCHED_LOCAL=" + metrics.localBlocksFetched +
+        " BLOCK_FETCHED_REMOTE=" + metrics.remoteBlocksFetched +
+        " REMOTE_FETCH_WAIT_TIME=" + metrics.fetchWaitTime +
+        " REMOTE_FETCH_TIME=" + metrics.remoteFetchTime +
+        " REMOTE_BYTES_READ=" + metrics.remoteBytesRead
+    }.getOrElse("")
+    val writeMetrics = taskMetrics.shuffleWriteMetrics.map { metrics =>
+      " SHUFFLE_BYTES_WRITTEN=" + metrics.shuffleBytesWritten
+    }.getOrElse("")
+    stageLogInfo(stageID, status + info + executorRunTime + readMetrics + writeMetrics)
+  }
+
+  override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = {
+    eventQueue.put(stageSubmitted)
+  }
+
+  protected def processStageSubmittedEvent(stage: Stage, taskSize: Int): Unit = {
+    stageLogInfo(stage.id, "STAGE_ID=" + stage.id + " STATUS=SUBMITTED" + " TASK_SIZE=" + taskSize)
+  }
+
+  override def onStageCompleted(stageCompleted: StageCompleted): Unit = {
+    eventQueue.put(stageCompleted)
+  }
+
+  protected def processStageCompletedEvent(stageInfo: StageInfo): Unit = {
+    stageLogInfo(stageInfo.stage.id, "STAGE_ID=" + stageInfo.stage.id + " STATUS=COMPLETED")
+  }
+
+  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
+    eventQueue.put(taskEnd)
+  }
+
+  // Log the outcome of one task. Successful tasks are logged with their metrics; resubmitted
+  // and failed tasks with the identifiers relevant to the failure. Other end reasons are
+  // not logged.
+  protected def processTaskEndEvent(task: Task[_], reason: TaskEndReason,
+      taskInfo: TaskInfo, taskMetrics: TaskMetrics): Unit = {
+    val taskType = task match {
+      case resultTask: ResultTask[_, _] => "TASK_TYPE=RESULT_TASK"
+      case shuffleMapTask: ShuffleMapTask => "TASK_TYPE=SHUFFLE_MAP_TASK"
+      // Any other Task subclass is still logged instead of failing with a MatchError.
+      case _ => "TASK_TYPE=UNKNOWN"
+    }
+    reason match {
+      case Success =>
+        recordTaskMetrics(task.stageId, taskType + " STATUS=SUCCESS", taskInfo, taskMetrics)
+      case Resubmitted =>
+        stageLogInfo(task.stageId, taskType + " STATUS=RESUBMITTED TID=" + taskInfo.taskId +
+          " STAGE_ID=" + task.stageId)
+      case FetchFailed(bmAddress, shuffleId, mapId, reduceId) =>
+        stageLogInfo(task.stageId, taskType + " STATUS=FETCHFAILED TID=" + taskInfo.taskId +
+          " STAGE_ID=" + task.stageId + " SHUFFLE_ID=" + shuffleId + " MAP_ID=" +
+          mapId + " REDUCE_ID=" + reduceId)
+      case OtherFailure(message) =>
+        stageLogInfo(task.stageId, taskType + " STATUS=FAILURE TID=" + taskInfo.taskId +
+          " STAGE_ID=" + task.stageId + " INFO=" + message)
+      case _ =>
+    }
+  }
+
+  override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
+    eventQueue.put(jobEnd)
+  }
+
+  // Log the final (upper-cased) status line of the job and close its log file.
+  protected def processJobEndEvent(job: ActiveJob, reason: JobResult): Unit = {
+    val status = reason match {
+      case JobSucceeded => " STATUS=SUCCESS"
+      case JobFailed(exception) =>
+        // The failure message is logged as one token: whitespace runs become "_".
+        // Joining with mkString leaves no trailing separator, so nothing has to be cut off
+        // afterwards (cutting the last character also truncated the success line).
+        val message = Option(exception.getMessage).getOrElse("")
+        " STATUS=FAILED REASON=" + message.split("\\s+").mkString("_")
+      case _ => ""
+    }
+    jobLogInfo(job.runId, ("JOB_ID=" + job.runId + status).toUpperCase)
+    closeLogWriter(job.runId)
+  }
+
+  // Log the job's "spark.job.annotation" property (an empty line when the property is not
+  // set); nothing is logged when the job has no properties at all.
+  protected def recordJobProperties(jobID: Int, properties: Properties): Unit = {
+    if (properties != null) {
+      val annotation = properties.getProperty("spark.job.annotation", "")
+      jobLogInfo(jobID, annotation, false)
+    }
+  }
+
+  override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+    eventQueue.put(jobStart)
+  }
+
+  // Open the job's log file and write its header: annotation, stage dependencies, the
+  // stage/RDD graph and the STARTED status line.
+  protected def processJobStartEvent(job: ActiveJob, properties: Properties): Unit = {
+    createLogWriter(job.runId)
+    recordJobProperties(job.runId, properties)
+    buildJobDep(job.runId, job.finalStage)
+    recordStageDep(job.runId)
+    recordStageDepGraph(job.runId, job.finalStage)
+    jobLogInfo(job.runId, "JOB_ID=" + job.runId + " STATUS=STARTED")
+  }
+}
diff --git a/core/src/main/scala/spark/scheduler/SparkListener.scala b/core/src/main/scala/spark/scheduler/SparkListener.scala index a65140b145..bac984b5c9 100644 --- a/core/src/main/scala/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/spark/scheduler/SparkListener.scala @@ -1,27 +1,59 @@ package spark.scheduler +import java.util.Properties import spark.scheduler.cluster.TaskInfo import spark.util.Distribution -import spark.{Utils, Logging} +import spark.{Logging, SparkContext, TaskEndReason, Utils} import spark.executor.TaskMetrics -trait SparkListener { - /** - * called when a stage is completed, with information on the completed stage - */ - def onStageCompleted(stageCompleted: StageCompleted) -} - sealed trait SparkListenerEvents +case class SparkListenerStageSubmitted(stage: Stage, taskSize: Int) extends SparkListenerEvents + case class StageCompleted(val stageInfo: StageInfo) extends SparkListenerEvents +case class SparkListenerTaskEnd(task: Task[_], reason: TaskEndReason, taskInfo: TaskInfo, + taskMetrics: TaskMetrics) extends SparkListenerEvents + +case class SparkListenerJobStart(job: ActiveJob, properties: Properties = null) + extends SparkListenerEvents + +case class SparkListenerJobEnd(job: ActiveJob, jobResult: JobResult) + extends SparkListenerEvents + +trait SparkListener { + /** + * Called when a stage is completed, with information on the completed stage + */ + def onStageCompleted(stageCompleted: StageCompleted) { } + + /** + * Called when a stage is submitted + */ + def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) { } + + /** + * Called when a task ends + */ + def onTaskEnd(taskEnd: SparkListenerTaskEnd) { } + + /** + * Called when a job starts + */ + def onJobStart(jobStart: SparkListenerJobStart) { } + + /** + * Called when a job ends + */ + def onJobEnd(jobEnd: SparkListenerJobEnd) { } + +} /** * Simple SparkListener that logs a few summary statistics when each stage completes */ class StatsReportListener 
extends SparkListener with Logging { - def onStageCompleted(stageCompleted: StageCompleted) { + override def onStageCompleted(stageCompleted: StageCompleted) { import spark.scheduler.StatsReportListener._ implicit val sc = stageCompleted this.logInfo("Finished stage: " + stageCompleted.stageInfo) diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala index 053d4b8e4a..3a0c29b27f 100644 --- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala @@ -177,7 +177,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) val tasks = taskSet.tasks logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks") this.synchronized { - val manager = new TaskSetManager(this, taskSet) + val manager = new ClusterTaskSetManager(this, taskSet) activeTaskSets(taskSet.id) = manager schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties) taskSetTaskIds(taskSet.id) = new HashSet[Long]() diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala new file mode 100644 index 0000000000..d72b0bfc9f --- /dev/null +++ b/core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala @@ -0,0 +1,747 @@ +package spark.scheduler.cluster + +import java.util.{HashMap => JHashMap, NoSuchElementException, Arrays} + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap +import scala.collection.mutable.HashSet +import scala.math.max +import scala.math.min + +import spark._ +import spark.scheduler._ +import spark.TaskState.TaskState +import java.nio.ByteBuffer + +private[spark] object TaskLocality extends Enumeration("PROCESS_LOCAL", "NODE_LOCAL", "RACK_LOCAL", "ANY") with Logging { + + // process local is expected to be used ONLY within 
tasksetmanager for now. + val PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY = Value + + type TaskLocality = Value + + def isAllowed(constraint: TaskLocality, condition: TaskLocality): Boolean = { + + // Must not be the constraint. + assert (constraint != TaskLocality.PROCESS_LOCAL) + + constraint match { + case TaskLocality.NODE_LOCAL => condition == TaskLocality.NODE_LOCAL + case TaskLocality.RACK_LOCAL => condition == TaskLocality.NODE_LOCAL || condition == TaskLocality.RACK_LOCAL + // For anything else, allow + case _ => true + } + } + + def parse(str: String): TaskLocality = { + // better way to do this ? + try { + val retval = TaskLocality.withName(str) + // Must not specify PROCESS_LOCAL ! + assert (retval != TaskLocality.PROCESS_LOCAL) + + retval + } catch { + case nEx: NoSuchElementException => { + logWarning("Invalid task locality specified '" + str + "', defaulting to NODE_LOCAL"); + // default to preserve earlier behavior + NODE_LOCAL + } + } + } +} + +/** + * Schedules the tasks within a single TaskSet in the ClusterScheduler. + */ +private[spark] class ClusterTaskSetManager( + sched: ClusterScheduler, + val taskSet: TaskSet) + extends TaskSetManager + with Logging { + + // Maximum time to wait to run a task in a preferred location (in ms) + val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong + + // CPUs to request per task + val CPUS_PER_TASK = System.getProperty("spark.task.cpus", "1").toDouble + + // Maximum times a task is allowed to fail before failing the job + val MAX_TASK_FAILURES = 4 + + // Quantile of tasks at which to start speculation + val SPECULATION_QUANTILE = System.getProperty("spark.speculation.quantile", "0.75").toDouble + val SPECULATION_MULTIPLIER = System.getProperty("spark.speculation.multiplier", "1.5").toDouble + + // Serializer for closures and tasks. 
+ val ser = SparkEnv.get.closureSerializer.newInstance() + + val tasks = taskSet.tasks + val numTasks = tasks.length + val copiesRunning = new Array[Int](numTasks) + val finished = new Array[Boolean](numTasks) + val numFailures = new Array[Int](numTasks) + val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil) + var tasksFinished = 0 + + var weight = 1 + var minShare = 0 + var runningTasks = 0 + var priority = taskSet.priority + var stageId = taskSet.stageId + var name = "TaskSet_"+taskSet.stageId.toString + var parent:Schedulable = null + + // Last time when we launched a preferred task (for delay scheduling) + var lastPreferredLaunchTime = System.currentTimeMillis + + // List of pending tasks for each node (process local to container). These collections are actually + // treated as stacks, in which new tasks are added to the end of the + // ArrayBuffer and removed from the end. This makes it faster to detect + // tasks that repeatedly fail because whenever a task failed, it is put + // back at the head of the stack. They are also only cleaned up lazily; + // when a task is launched, it remains in all the pending lists except + // the one that it was launched from, but gets removed from them later. + private val pendingTasksForHostPort = new HashMap[String, ArrayBuffer[Int]] + + // List of pending tasks for each node. + // Essentially, similar to pendingTasksForHostPort, except at host level + private val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]] + + // List of pending tasks for each node based on rack locality. + // Essentially, similar to pendingTasksForHost, except at rack level + private val pendingRackLocalTasksForHost = new HashMap[String, ArrayBuffer[Int]] + + // List containing pending tasks with no locality preferences + val pendingTasksWithNoPrefs = new ArrayBuffer[Int] + + // List containing all pending tasks (also used as a stack, as above) + val allPendingTasks = new ArrayBuffer[Int] + + // Tasks that can be speculated. 
Since these will be a small fraction of total + // tasks, we'll just hold them in a HashSet. + val speculatableTasks = new HashSet[Int] + + // Task index, start and finish time for each task attempt (indexed by task ID) + val taskInfos = new HashMap[Long, TaskInfo] + + // Did the job fail? + var failed = false + var causeOfFailure = "" + + // How frequently to reprint duplicate exceptions in full, in milliseconds + val EXCEPTION_PRINT_INTERVAL = + System.getProperty("spark.logging.exceptionPrintInterval", "10000").toLong + // Map of recent exceptions (identified by string representation and + // top stack frame) to duplicate count (how many times the same + // exception has appeared) and time the full exception was + // printed. This should ideally be an LRU map that can drop old + // exceptions automatically. + val recentExceptions = HashMap[String, (Int, Long)]() + + // Figure out the current map output tracker generation and set it on all tasks + val generation = sched.mapOutputTracker.getGeneration + logDebug("Generation for " + taskSet.id + ": " + generation) + for (t <- tasks) { + t.generation = generation + } + + // Add all our tasks to the pending lists. We do this in reverse order + // of task index so that tasks with low indices get launched first. + for (i <- (0 until numTasks).reverse) { + addPendingTask(i) + } + + // Note that it follows the hierarchy. + // if we search for NODE_LOCAL, the output will include PROCESS_LOCAL and + // if we search for RACK_LOCAL, it will include PROCESS_LOCAL & NODE_LOCAL + private def findPreferredLocations(_taskPreferredLocations: Seq[String], scheduler: ClusterScheduler, + taskLocality: TaskLocality.TaskLocality): HashSet[String] = { + + if (TaskLocality.PROCESS_LOCAL == taskLocality) { + // straight forward comparison ! Special case it. 
+ val retval = new HashSet[String]() + scheduler.synchronized { + for (location <- _taskPreferredLocations) { + if (scheduler.isExecutorAliveOnHostPort(location)) { + retval += location + } + } + } + + return retval + } + + val taskPreferredLocations = + if (TaskLocality.NODE_LOCAL == taskLocality) { + _taskPreferredLocations + } else { + assert (TaskLocality.RACK_LOCAL == taskLocality) + // Expand set to include all 'seen' rack local hosts. + // This works since container allocation/management happens within master - so any rack locality information is updated in msater. + // Best case effort, and maybe sort of kludge for now ... rework it later ? + val hosts = new HashSet[String] + _taskPreferredLocations.foreach(h => { + val rackOpt = scheduler.getRackForHost(h) + if (rackOpt.isDefined) { + val hostsOpt = scheduler.getCachedHostsForRack(rackOpt.get) + if (hostsOpt.isDefined) { + hosts ++= hostsOpt.get + } + } + + // Ensure that irrespective of what scheduler says, host is always added ! + hosts += h + }) + + hosts + } + + val retval = new HashSet[String] + scheduler.synchronized { + for (prefLocation <- taskPreferredLocations) { + val aliveLocationsOpt = scheduler.getExecutorsAliveOnHost(Utils.parseHostPort(prefLocation)._1) + if (aliveLocationsOpt.isDefined) { + retval ++= aliveLocationsOpt.get + } + } + } + + retval + } + + // Add a task to all the pending-task lists that it should be on. + private def addPendingTask(index: Int) { + // We can infer hostLocalLocations from rackLocalLocations by joining it against tasks(index).preferredLocations (with appropriate + // hostPort <-> host conversion). But not doing it for simplicity sake. If this becomes a performance issue, modify it. 
+ val processLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.PROCESS_LOCAL) + val hostLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.NODE_LOCAL) + val rackLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.RACK_LOCAL) + + if (rackLocalLocations.size == 0) { + // Current impl ensures this. + assert (processLocalLocations.size == 0) + assert (hostLocalLocations.size == 0) + pendingTasksWithNoPrefs += index + } else { + + // process local locality + for (hostPort <- processLocalLocations) { + // DEBUG Code + Utils.checkHostPort(hostPort) + + val hostPortList = pendingTasksForHostPort.getOrElseUpdate(hostPort, ArrayBuffer()) + hostPortList += index + } + + // host locality (includes process local) + for (hostPort <- hostLocalLocations) { + // DEBUG Code + Utils.checkHostPort(hostPort) + + val host = Utils.parseHostPort(hostPort)._1 + val hostList = pendingTasksForHost.getOrElseUpdate(host, ArrayBuffer()) + hostList += index + } + + // rack locality (includes process local and host local) + for (rackLocalHostPort <- rackLocalLocations) { + // DEBUG Code + Utils.checkHostPort(rackLocalHostPort) + + val rackLocalHost = Utils.parseHostPort(rackLocalHostPort)._1 + val list = pendingRackLocalTasksForHost.getOrElseUpdate(rackLocalHost, ArrayBuffer()) + list += index + } + } + + allPendingTasks += index + } + + // Return the pending tasks list for a given host port (process local), or an empty list if + // there is no map entry for that host + private def getPendingTasksForHostPort(hostPort: String): ArrayBuffer[Int] = { + // DEBUG Code + Utils.checkHostPort(hostPort) + pendingTasksForHostPort.getOrElse(hostPort, ArrayBuffer()) + } + + // Return the pending tasks list for a given host, or an empty list if + // there is no map entry for that host + private def getPendingTasksForHost(hostPort: String): ArrayBuffer[Int] = { + val host = 
Utils.parseHostPort(hostPort)._1 + pendingTasksForHost.getOrElse(host, ArrayBuffer()) + } + + // Return the pending tasks (rack level) list for a given host, or an empty list if + // there is no map entry for that host + private def getRackLocalPendingTasksForHost(hostPort: String): ArrayBuffer[Int] = { + val host = Utils.parseHostPort(hostPort)._1 + pendingRackLocalTasksForHost.getOrElse(host, ArrayBuffer()) + } + + // Number of pending tasks for a given host Port (which would be process local) + def numPendingTasksForHostPort(hostPort: String): Int = { + getPendingTasksForHostPort(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) ) + } + + // Number of pending tasks for a given host (which would be data local) + def numPendingTasksForHost(hostPort: String): Int = { + getPendingTasksForHost(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) ) + } + + // Number of pending rack local tasks for a given host + def numRackLocalPendingTasksForHost(hostPort: String): Int = { + getRackLocalPendingTasksForHost(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) ) + } + + + // Dequeue a pending task from the given list and return its index. + // Return None if the list is empty. + // This method also cleans up any tasks in the list that have already + // been launched, since we want that to happen lazily. + private def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = { + while (!list.isEmpty) { + val index = list.last + list.trimEnd(1) + if (copiesRunning(index) == 0 && !finished(index)) { + return Some(index) + } + } + return None + } + + // Return a speculative task for a given host if any are available. The task should not have an + // attempt running on this host, in case the host is slow. In addition, if locality is set, the + // task must have a preference for this host/rack/no preferred locations at all. 
private def findSpeculativeTask(hostPort: String, locality: TaskLocality.TaskLocality): Option[Int] = {
  assert (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL))

  // Forget speculatable entries whose task has finished in the meantime.
  speculatableTasks.retain(index => !finished(index))

  // True when no recorded attempt of this task was launched on the offered host:port.
  def untriedHere(index: Int): Boolean =
    !taskAttempts(index).exists(_.hostPort == hostPort)

  // Locations that findPreferredLocations reports for the task at the given locality level.
  def locationsAt(index: Int, level: TaskLocality.TaskLocality) =
    findPreferredLocations(tasks(index).preferredLocations, sched, level)

  // Node level: the task has no reported locations at all, or the offered host:port is one of them.
  def nodeLocalCandidate: Option[Int] =
    speculatableTasks.find { index =>
      val locations = locationsAt(index, TaskLocality.NODE_LOCAL)
      (locations.isEmpty || locations.contains(hostPort)) && untriedHere(index)
    }

  // Rack level: only consulted when the locality constraint allows RACK_LOCAL.
  def rackLocalCandidate: Option[Int] =
    if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
      speculatableTasks.find { index =>
        locationsAt(index, TaskLocality.RACK_LOCAL).contains(hostPort) && untriedHere(index)
      }
    } else {
      None
    }

  // Any level: only consulted when the locality constraint allows ANY; the sole requirement
  // left is that the task has not already been attempted on this host:port.
  def anyCandidate: Option[Int] =
    if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
      speculatableTasks.find(index => untriedHere(index))
    } else {
      None
    }

  if (speculatableTasks.isEmpty) {
    None
  } else {
    // orElse takes its argument by name, so a wider level is searched only when every
    // narrower level came up empty.
    val chosen = nodeLocalCandidate.orElse(rackLocalCandidate).orElse(anyCandidate)
    // A task handed out for speculation is removed so it is not offered a second time.
    chosen.foreach(index => speculatableTasks -= index)
    chosen
  }
}

// Dequeue a pending task for a given node and return its index.
// If localOnly is set to false, allow non-local tasks as well.
private def findTask(hostPort: String, locality: TaskLocality.TaskLocality): Option[Int] = {
  // Every step below is a def, and orElse takes its argument by name, so a queue is only
  // touched (and lazily trimmed by findTaskFromList) when all earlier steps returned None.

  // Tasks queued for exactly this host:port.
  def processLocalTask: Option[Int] = findTaskFromList(getPendingTasksForHostPort(hostPort))

  // Tasks queued for the host part of this host:port.
  def hostLocalTask: Option[Int] = findTaskFromList(getPendingTasksForHost(hostPort))

  // Rack-level queue, gated by TaskLocality.isAllowed(locality, RACK_LOCAL).
  def rackLocalTask: Option[Int] =
    if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
      findTaskFromList(getRackLocalPendingTasksForHost(hostPort))
    } else {
      None
    }

  // Look for no pref tasks AFTER rack local tasks - this has side effect that we will get to
  // failed tasks later rather than sooner.
  // TODO: That code path needs to be revisited (adding to no prefs list when host:port goes down).
  def noPrefTask: Option[Int] = findTaskFromList(pendingTasksWithNoPrefs)

  // Catch-all queue, gated by TaskLocality.isAllowed(locality, ANY).
  def nonLocalTask: Option[Int] =
    if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
      findTaskFromList(allPendingTasks)
    } else {
      None
    }

  // Last resort once every pending queue has come up empty.
  def speculativeTask: Option[Int] = findSpeculativeTask(hostPort, locality)

  processLocalTask
    .orElse(hostLocalTask)
    .orElse(rackLocalTask)
    .orElse(noPrefTask)
    .orElse(nonLocalTask)
    .orElse(speculativeTask)
}

// Is the offered host:port itself listed among the task's preferred locations?
private def isProcessLocalLocation(task: Task[_], hostPort: String): Boolean = {
  // Called before the lookup, as in the other host:port code paths of this class
  // (presumably a format sanity check - confirm against Utils).
  Utils.checkHostPort(hostPort)
  task.preferredLocations.contains(hostPort)
}

// Does any preferred location of the task share its host part with the offered host:port?
private def isHostLocalLocation(task: Task[_], hostPort: String): Boolean = {
  val preferred = task.preferredLocations

  if (preferred.isEmpty) {
    // If no preference, consider it as host local
    true
  } else {
    val offeredHost = Utils.parseHostPort(hostPort)._1
    preferred.exists(loc => Utils.parseHostPort(loc)._1 == offeredHost)
  }
}

// Does a host count as a rack local preferred location for a task? (assumes host is NOT preferred location).
+ // This is true if either the task has preferred locations and this host is one, or it has + // no preferred locations (in which we still count the launch as preferred). + private def isRackLocalLocation(task: Task[_], hostPort: String): Boolean = { + + val locs = task.preferredLocations + + val preferredRacks = new HashSet[String]() + for (preferredHost <- locs) { + val rack = sched.getRackForHost(preferredHost) + if (None != rack) preferredRacks += rack.get + } + + if (preferredRacks.isEmpty) return false + + val hostRack = sched.getRackForHost(hostPort) + + return None != hostRack && preferredRacks.contains(hostRack.get) + } + + // Respond to an offer of a single slave from the scheduler by finding a task + def slaveOffer(execId: String, hostPort: String, availableCpus: Double, overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription] = { + + if (tasksFinished < numTasks && availableCpus >= CPUS_PER_TASK) { + // If explicitly specified, use that + val locality = if (overrideLocality != null) overrideLocality else { + // expand only if we have waited for more than LOCALITY_WAIT for a host local task ... 
+ val time = System.currentTimeMillis + if (time - lastPreferredLaunchTime < LOCALITY_WAIT) TaskLocality.NODE_LOCAL else TaskLocality.ANY + } + + findTask(hostPort, locality) match { + case Some(index) => { + // Found a task; do some bookkeeping and return a Mesos task for it + val task = tasks(index) + val taskId = sched.newTaskId() + // Figure out whether this should count as a preferred launch + val taskLocality = + if (isProcessLocalLocation(task, hostPort)) TaskLocality.PROCESS_LOCAL else + if (isHostLocalLocation(task, hostPort)) TaskLocality.NODE_LOCAL else + if (isRackLocalLocation(task, hostPort)) TaskLocality.RACK_LOCAL else + TaskLocality.ANY + val prefStr = taskLocality.toString + logInfo("Starting task %s:%d as TID %s on slave %s: %s (%s)".format( + taskSet.id, index, taskId, execId, hostPort, prefStr)) + // Do various bookkeeping + copiesRunning(index) += 1 + val time = System.currentTimeMillis + val info = new TaskInfo(taskId, index, time, execId, hostPort, taskLocality) + taskInfos(taskId) = info + taskAttempts(index) = info :: taskAttempts(index) + if (TaskLocality.NODE_LOCAL == taskLocality) { + lastPreferredLaunchTime = time + } + // Serialize and return the task + val startTime = System.currentTimeMillis + val serializedTask = Task.serializeWithDependencies( + task, sched.sc.addedFiles, sched.sc.addedJars, ser) + val timeTaken = System.currentTimeMillis - startTime + increaseRunningTasks(1) + logInfo("Serialized task %s:%d as %d bytes in %d ms".format( + taskSet.id, index, serializedTask.limit, timeTaken)) + val taskName = "task %s:%d".format(taskSet.id, index) + return Some(new TaskDescription(taskId, execId, taskName, serializedTask)) + } + case _ => + } + } + return None + } + + def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) { + state match { + case TaskState.FINISHED => + taskFinished(tid, state, serializedData) + case TaskState.LOST => + taskLost(tid, state, serializedData) + case TaskState.FAILED => + 
taskLost(tid, state, serializedData) + case TaskState.KILLED => + taskLost(tid, state, serializedData) + case _ => + } + } + + def taskFinished(tid: Long, state: TaskState, serializedData: ByteBuffer) { + val info = taskInfos(tid) + if (info.failed) { + // We might get two task-lost messages for the same task in coarse-grained Mesos mode, + // or even from Mesos itself when acks get delayed. + return + } + val index = info.index + info.markSuccessful() + decreaseRunningTasks(1) + if (!finished(index)) { + tasksFinished += 1 + logInfo("Finished TID %s in %d ms (progress: %d/%d)".format( + tid, info.duration, tasksFinished, numTasks)) + // Deserialize task result and pass it to the scheduler + try { + val result = ser.deserialize[TaskResult[_]](serializedData) + result.metrics.resultSize = serializedData.limit() + sched.listener.taskEnded(tasks(index), Success, result.value, result.accumUpdates, info, result.metrics) + } catch { + case cnf: ClassNotFoundException => + val loader = Thread.currentThread().getContextClassLoader + throw new SparkException("ClassNotFound with classloader: " + loader, cnf) + case ex => throw ex + } + // Mark finished and stop if we've finished all the tasks + finished(index) = true + if (tasksFinished == numTasks) { + sched.taskSetFinished(this) + } + } else { + logInfo("Ignoring task-finished event for TID " + tid + + " because task " + index + " is already finished") + } + } + + def taskLost(tid: Long, state: TaskState, serializedData: ByteBuffer) { + val info = taskInfos(tid) + if (info.failed) { + // We might get two task-lost messages for the same task in coarse-grained Mesos mode, + // or even from Mesos itself when acks get delayed. + return + } + val index = info.index + info.markFailed() + decreaseRunningTasks(1) + if (!finished(index)) { + logInfo("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index)) + copiesRunning(index) -= 1 + // Check if the problem is a map output fetch failure. 
In that case, this + // task will never succeed on any node, so tell the scheduler about it. + if (serializedData != null && serializedData.limit() > 0) { + val reason = ser.deserialize[TaskEndReason](serializedData, getClass.getClassLoader) + reason match { + case fetchFailed: FetchFailed => + logInfo("Loss was due to fetch failure from " + fetchFailed.bmAddress) + sched.listener.taskEnded(tasks(index), fetchFailed, null, null, info, null) + finished(index) = true + tasksFinished += 1 + sched.taskSetFinished(this) + decreaseRunningTasks(runningTasks) + return + + case taskResultTooBig: TaskResultTooBigFailure => + logInfo("Loss was due to task %s result exceeding Akka frame size; " + + "aborting job".format(tid)) + abort("Task %s result exceeded Akka frame size".format(tid)) + return + + case ef: ExceptionFailure => + val key = ef.description + val now = System.currentTimeMillis + val (printFull, dupCount) = { + if (recentExceptions.contains(key)) { + val (dupCount, printTime) = recentExceptions(key) + if (now - printTime > EXCEPTION_PRINT_INTERVAL) { + recentExceptions(key) = (0, now) + (true, 0) + } else { + recentExceptions(key) = (dupCount + 1, printTime) + (false, dupCount + 1) + } + } else { + recentExceptions(key) = (0, now) + (true, 0) + } + } + if (printFull) { + val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString)) + logInfo("Loss was due to %s\n%s\n%s".format( + ef.className, ef.description, locs.mkString("\n"))) + } else { + logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount)) + } + + case _ => {} + } + } + // On non-fetch failures, re-enqueue the task as pending for a max number of retries + addPendingTask(index) + // Count failed attempts only on FAILED and LOST state (not on KILLED) + if (state == TaskState.FAILED || state == TaskState.LOST) { + numFailures(index) += 1 + if (numFailures(index) > MAX_TASK_FAILURES) { + logError("Task %s:%d failed more than %d times; aborting job".format( + taskSet.id, index, 
MAX_TASK_FAILURES)) + abort("Task %s:%d failed more than %d times".format(taskSet.id, index, MAX_TASK_FAILURES)) + } + } + } else { + logInfo("Ignoring task-lost event for TID " + tid + + " because task " + index + " is already finished") + } + } + + def error(message: String) { + // Save the error message + abort("Error: " + message) + } + + def abort(message: String) { + failed = true + causeOfFailure = message + // TODO: Kill running tasks if we were not terminated due to a Mesos error + sched.listener.taskSetFailed(taskSet, message) + decreaseRunningTasks(runningTasks) + sched.taskSetFinished(this) + } + + override def increaseRunningTasks(taskNum: Int) { + runningTasks += taskNum + if (parent != null) { + parent.increaseRunningTasks(taskNum) + } + } + + override def decreaseRunningTasks(taskNum: Int) { + runningTasks -= taskNum + if (parent != null) { + parent.decreaseRunningTasks(taskNum) + } + } + + //TODO: for now we just find Pool not TaskSetManager, we can extend this function in future if needed + override def getSchedulableByName(name: String): Schedulable = { + return null + } + + override def addSchedulable(schedulable:Schedulable) { + //nothing + } + + override def removeSchedulable(schedulable:Schedulable) { + //nothing + } + + override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = { + var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager] + sortedTaskSetQueue += this + return sortedTaskSetQueue + } + + override def executorLost(execId: String, hostPort: String) { + logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id) + + // If some task has preferred locations only on hostname, and there are no more executors there, + // put it in the no-prefs list to avoid the wait from delay scheduling + + // host local tasks - should we push this to rack local or no pref list ? For now, preserving behavior and moving to + // no prefs list. Note, this was done due to impliations related to 'waiting' for data local tasks, etc. 
+ // Note: NOT checking process local list - since host local list is super set of that. We need to ad to no prefs only if + // there is no host local node for the task (not if there is no process local node for the task) + for (index <- getPendingTasksForHost(Utils.parseHostPort(hostPort)._1)) { + // val newLocs = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.RACK_LOCAL) + val newLocs = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.NODE_LOCAL) + if (newLocs.isEmpty) { + pendingTasksWithNoPrefs += index + } + } + + // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage + if (tasks(0).isInstanceOf[ShuffleMapTask]) { + for ((tid, info) <- taskInfos if info.executorId == execId) { + val index = taskInfos(tid).index + if (finished(index)) { + finished(index) = false + copiesRunning(index) -= 1 + tasksFinished -= 1 + addPendingTask(index) + // Tell the DAGScheduler that this task was resubmitted so that it doesn't think our + // stage finishes when a total of tasks.size tasks finish. + sched.listener.taskEnded(tasks(index), Resubmitted, null, null, info, null) + } + } + } + // Also re-enqueue any tasks that were running on the node + for ((tid, info) <- taskInfos if info.running && info.executorId == execId) { + taskLost(tid, TaskState.KILLED, null) + } + } + + /** + * Check for tasks to be speculated and return true if there are any. This is called periodically + * by the ClusterScheduler. + * + * TODO: To make this scale to large jobs, we need to maintain a list of running tasks, so that + * we don't scan the whole task set. It might also help to make this sorted by launch time. + */ + override def checkSpeculatableTasks(): Boolean = { + // Can't speculate if we only have one task, or if all tasks have finished. 
+ if (numTasks == 1 || tasksFinished == numTasks) { + return false + } + var foundTasks = false + val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt + logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation) + if (tasksFinished >= minFinishedForSpeculation) { + val time = System.currentTimeMillis() + val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray + Arrays.sort(durations) + val medianDuration = durations(min((0.5 * numTasks).round.toInt, durations.size - 1)) + val threshold = max(SPECULATION_MULTIPLIER * medianDuration, 100) + // TODO: Threshold should also look at standard deviation of task durations and have a lower + // bound based on that. + logDebug("Task length threshold for speculation: " + threshold) + for ((tid, info) <- taskInfos) { + val index = info.index + if (!finished(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold && + !speculatableTasks.contains(index)) { + logInfo( + "Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format( + taskSet.id, index, info.hostPort, threshold)) + speculatableTasks += index + foundTasks = true + } + } + } + return foundTasks + } + + override def hasPendingTasks(): Boolean = { + numTasks > 0 && tasksFinished < numTasks + } +} diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala index f1c6266bac..b4dd75d90f 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala @@ -1,747 +1,17 @@ package spark.scheduler.cluster -import java.util.{HashMap => JHashMap, NoSuchElementException, Arrays} - import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.collection.mutable.HashSet -import scala.math.max -import scala.math.min - -import spark._ import spark.scheduler._ import 
spark.TaskState.TaskState import java.nio.ByteBuffer -private[spark] object TaskLocality extends Enumeration("PROCESS_LOCAL", "NODE_LOCAL", "RACK_LOCAL", "ANY") with Logging { - - // process local is expected to be used ONLY within tasksetmanager for now. - val PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY = Value - - type TaskLocality = Value - - def isAllowed(constraint: TaskLocality, condition: TaskLocality): Boolean = { - - // Must not be the constraint. - assert (constraint != TaskLocality.PROCESS_LOCAL) - - constraint match { - case TaskLocality.NODE_LOCAL => condition == TaskLocality.NODE_LOCAL - case TaskLocality.RACK_LOCAL => condition == TaskLocality.NODE_LOCAL || condition == TaskLocality.RACK_LOCAL - // For anything else, allow - case _ => true - } - } - - def parse(str: String): TaskLocality = { - // better way to do this ? - try { - val retval = TaskLocality.withName(str) - // Must not specify PROCESS_LOCAL ! - assert (retval != TaskLocality.PROCESS_LOCAL) - - retval - } catch { - case nEx: NoSuchElementException => { - logWarning("Invalid task locality specified '" + str + "', defaulting to NODE_LOCAL"); - // default to preserve earlier behavior - NODE_LOCAL - } - } - } -} - -/** - * Schedules the tasks within a single TaskSet in the ClusterScheduler. 
- */ -private[spark] class TaskSetManager( - sched: ClusterScheduler, - val taskSet: TaskSet) - extends Schedulable - with Logging { - - // Maximum time to wait to run a task in a preferred location (in ms) - val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong - - // CPUs to request per task - val CPUS_PER_TASK = System.getProperty("spark.task.cpus", "1").toDouble - - // Maximum times a task is allowed to fail before failing the job - val MAX_TASK_FAILURES = 4 - - // Quantile of tasks at which to start speculation - val SPECULATION_QUANTILE = System.getProperty("spark.speculation.quantile", "0.75").toDouble - val SPECULATION_MULTIPLIER = System.getProperty("spark.speculation.multiplier", "1.5").toDouble - - // Serializer for closures and tasks. - val ser = SparkEnv.get.closureSerializer.newInstance() - - val tasks = taskSet.tasks - val numTasks = tasks.length - val copiesRunning = new Array[Int](numTasks) - val finished = new Array[Boolean](numTasks) - val numFailures = new Array[Int](numTasks) - val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil) - var tasksFinished = 0 - - var weight = 1 - var minShare = 0 - var runningTasks = 0 - var priority = taskSet.priority - var stageId = taskSet.stageId - var name = "TaskSet_"+taskSet.stageId.toString - var parent:Schedulable = null - - // Last time when we launched a preferred task (for delay scheduling) - var lastPreferredLaunchTime = System.currentTimeMillis - - // List of pending tasks for each node (process local to container). These collections are actually - // treated as stacks, in which new tasks are added to the end of the - // ArrayBuffer and removed from the end. This makes it faster to detect - // tasks that repeatedly fail because whenever a task failed, it is put - // back at the head of the stack. 
They are also only cleaned up lazily; - // when a task is launched, it remains in all the pending lists except - // the one that it was launched from, but gets removed from them later. - private val pendingTasksForHostPort = new HashMap[String, ArrayBuffer[Int]] - - // List of pending tasks for each node. - // Essentially, similar to pendingTasksForHostPort, except at host level - private val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]] - - // List of pending tasks for each node based on rack locality. - // Essentially, similar to pendingTasksForHost, except at rack level - private val pendingRackLocalTasksForHost = new HashMap[String, ArrayBuffer[Int]] - - // List containing pending tasks with no locality preferences - val pendingTasksWithNoPrefs = new ArrayBuffer[Int] - - // List containing all pending tasks (also used as a stack, as above) - val allPendingTasks = new ArrayBuffer[Int] - - // Tasks that can be speculated. Since these will be a small fraction of total - // tasks, we'll just hold them in a HashSet. - val speculatableTasks = new HashSet[Int] - - // Task index, start and finish time for each task attempt (indexed by task ID) - val taskInfos = new HashMap[Long, TaskInfo] - - // Did the job fail? - var failed = false - var causeOfFailure = "" - - // How frequently to reprint duplicate exceptions in full, in milliseconds - val EXCEPTION_PRINT_INTERVAL = - System.getProperty("spark.logging.exceptionPrintInterval", "10000").toLong - // Map of recent exceptions (identified by string representation and - // top stack frame) to duplicate count (how many times the same - // exception has appeared) and time the full exception was - // printed. This should ideally be an LRU map that can drop old - // exceptions automatically. 
- val recentExceptions = HashMap[String, (Int, Long)]() - - // Figure out the current map output tracker generation and set it on all tasks - val generation = sched.mapOutputTracker.getGeneration - logDebug("Generation for " + taskSet.id + ": " + generation) - for (t <- tasks) { - t.generation = generation - } - - // Add all our tasks to the pending lists. We do this in reverse order - // of task index so that tasks with low indices get launched first. - for (i <- (0 until numTasks).reverse) { - addPendingTask(i) - } - - // Note that it follows the hierarchy. - // if we search for NODE_LOCAL, the output will include PROCESS_LOCAL and - // if we search for RACK_LOCAL, it will include PROCESS_LOCAL & NODE_LOCAL - private def findPreferredLocations(_taskPreferredLocations: Seq[String], scheduler: ClusterScheduler, - taskLocality: TaskLocality.TaskLocality): HashSet[String] = { - - if (TaskLocality.PROCESS_LOCAL == taskLocality) { - // straight forward comparison ! Special case it. - val retval = new HashSet[String]() - scheduler.synchronized { - for (location <- _taskPreferredLocations) { - if (scheduler.isExecutorAliveOnHostPort(location)) { - retval += location - } - } - } - - return retval - } - - val taskPreferredLocations = - if (TaskLocality.NODE_LOCAL == taskLocality) { - _taskPreferredLocations - } else { - assert (TaskLocality.RACK_LOCAL == taskLocality) - // Expand set to include all 'seen' rack local hosts. - // This works since container allocation/management happens within master - so any rack locality information is updated in msater. - // Best case effort, and maybe sort of kludge for now ... rework it later ? - val hosts = new HashSet[String] - _taskPreferredLocations.foreach(h => { - val rackOpt = scheduler.getRackForHost(h) - if (rackOpt.isDefined) { - val hostsOpt = scheduler.getCachedHostsForRack(rackOpt.get) - if (hostsOpt.isDefined) { - hosts ++= hostsOpt.get - } - } - - // Ensure that irrespective of what scheduler says, host is always added ! 
- hosts += h - }) - - hosts - } - - val retval = new HashSet[String] - scheduler.synchronized { - for (prefLocation <- taskPreferredLocations) { - val aliveLocationsOpt = scheduler.getExecutorsAliveOnHost(Utils.parseHostPort(prefLocation)._1) - if (aliveLocationsOpt.isDefined) { - retval ++= aliveLocationsOpt.get - } - } - } - - retval - } - - // Add a task to all the pending-task lists that it should be on. - private def addPendingTask(index: Int) { - // We can infer hostLocalLocations from rackLocalLocations by joining it against tasks(index).preferredLocations (with appropriate - // hostPort <-> host conversion). But not doing it for simplicity sake. If this becomes a performance issue, modify it. - val processLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.PROCESS_LOCAL) - val hostLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.NODE_LOCAL) - val rackLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.RACK_LOCAL) - - if (rackLocalLocations.size == 0) { - // Current impl ensures this. 
- assert (processLocalLocations.size == 0) - assert (hostLocalLocations.size == 0) - pendingTasksWithNoPrefs += index - } else { - - // process local locality - for (hostPort <- processLocalLocations) { - // DEBUG Code - Utils.checkHostPort(hostPort) - - val hostPortList = pendingTasksForHostPort.getOrElseUpdate(hostPort, ArrayBuffer()) - hostPortList += index - } - - // host locality (includes process local) - for (hostPort <- hostLocalLocations) { - // DEBUG Code - Utils.checkHostPort(hostPort) - - val host = Utils.parseHostPort(hostPort)._1 - val hostList = pendingTasksForHost.getOrElseUpdate(host, ArrayBuffer()) - hostList += index - } - - // rack locality (includes process local and host local) - for (rackLocalHostPort <- rackLocalLocations) { - // DEBUG Code - Utils.checkHostPort(rackLocalHostPort) - - val rackLocalHost = Utils.parseHostPort(rackLocalHostPort)._1 - val list = pendingRackLocalTasksForHost.getOrElseUpdate(rackLocalHost, ArrayBuffer()) - list += index - } - } - - allPendingTasks += index - } - - // Return the pending tasks list for a given host port (process local), or an empty list if - // there is no map entry for that host - private def getPendingTasksForHostPort(hostPort: String): ArrayBuffer[Int] = { - // DEBUG Code - Utils.checkHostPort(hostPort) - pendingTasksForHostPort.getOrElse(hostPort, ArrayBuffer()) - } - - // Return the pending tasks list for a given host, or an empty list if - // there is no map entry for that host - private def getPendingTasksForHost(hostPort: String): ArrayBuffer[Int] = { - val host = Utils.parseHostPort(hostPort)._1 - pendingTasksForHost.getOrElse(host, ArrayBuffer()) - } - - // Return the pending tasks (rack level) list for a given host, or an empty list if - // there is no map entry for that host - private def getRackLocalPendingTasksForHost(hostPort: String): ArrayBuffer[Int] = { - val host = Utils.parseHostPort(hostPort)._1 - pendingRackLocalTasksForHost.getOrElse(host, ArrayBuffer()) - } - - // Number of 
pending tasks for a given host Port (which would be process local) - def numPendingTasksForHostPort(hostPort: String): Int = { - getPendingTasksForHostPort(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) ) - } - - // Number of pending tasks for a given host (which would be data local) - def numPendingTasksForHost(hostPort: String): Int = { - getPendingTasksForHost(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) ) - } - - // Number of pending rack local tasks for a given host - def numRackLocalPendingTasksForHost(hostPort: String): Int = { - getRackLocalPendingTasksForHost(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) ) - } - - - // Dequeue a pending task from the given list and return its index. - // Return None if the list is empty. - // This method also cleans up any tasks in the list that have already - // been launched, since we want that to happen lazily. - private def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = { - while (!list.isEmpty) { - val index = list.last - list.trimEnd(1) - if (copiesRunning(index) == 0 && !finished(index)) { - return Some(index) - } - } - return None - } - - // Return a speculative task for a given host if any are available. The task should not have an - // attempt running on this host, in case the host is slow. In addition, if locality is set, the - // task must have a preference for this host/rack/no preferred locations at all. 
- private def findSpeculativeTask(hostPort: String, locality: TaskLocality.TaskLocality): Option[Int] = { - - assert (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) - speculatableTasks.retain(index => !finished(index)) // Remove finished tasks from set - - if (speculatableTasks.size > 0) { - val localTask = speculatableTasks.find { - index => - val locations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.NODE_LOCAL) - val attemptLocs = taskAttempts(index).map(_.hostPort) - (locations.size == 0 || locations.contains(hostPort)) && !attemptLocs.contains(hostPort) - } - - if (localTask != None) { - speculatableTasks -= localTask.get - return localTask - } - - // check for rack locality - if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) { - val rackTask = speculatableTasks.find { - index => - val locations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.RACK_LOCAL) - val attemptLocs = taskAttempts(index).map(_.hostPort) - locations.contains(hostPort) && !attemptLocs.contains(hostPort) - } - - if (rackTask != None) { - speculatableTasks -= rackTask.get - return rackTask - } - } - - // Any task ... - if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) { - // Check for attemptLocs also ? - val nonLocalTask = speculatableTasks.find(i => !taskAttempts(i).map(_.hostPort).contains(hostPort)) - if (nonLocalTask != None) { - speculatableTasks -= nonLocalTask.get - return nonLocalTask - } - } - } - return None - } - - // Dequeue a pending task for a given node and return its index. - // If localOnly is set to false, allow non-local tasks as well. 
- private def findTask(hostPort: String, locality: TaskLocality.TaskLocality): Option[Int] = { - val processLocalTask = findTaskFromList(getPendingTasksForHostPort(hostPort)) - if (processLocalTask != None) { - return processLocalTask - } - - val localTask = findTaskFromList(getPendingTasksForHost(hostPort)) - if (localTask != None) { - return localTask - } - - if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) { - val rackLocalTask = findTaskFromList(getRackLocalPendingTasksForHost(hostPort)) - if (rackLocalTask != None) { - return rackLocalTask - } - } - - // Look for no pref tasks AFTER rack local tasks - this has side effect that we will get to failed tasks later rather than sooner. - // TODO: That code path needs to be revisited (adding to no prefs list when host:port goes down). - val noPrefTask = findTaskFromList(pendingTasksWithNoPrefs) - if (noPrefTask != None) { - return noPrefTask - } - - if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) { - val nonLocalTask = findTaskFromList(allPendingTasks) - if (nonLocalTask != None) { - return nonLocalTask - } - } - - // Finally, if all else has failed, find a speculative task - return findSpeculativeTask(hostPort, locality) - } - - private def isProcessLocalLocation(task: Task[_], hostPort: String): Boolean = { - Utils.checkHostPort(hostPort) - - val locs = task.preferredLocations - - locs.contains(hostPort) - } - - private def isHostLocalLocation(task: Task[_], hostPort: String): Boolean = { - val locs = task.preferredLocations - - // If no preference, consider it as host local - if (locs.isEmpty) return true - - val host = Utils.parseHostPort(hostPort)._1 - locs.find(h => Utils.parseHostPort(h)._1 == host).isDefined - } - - // Does a host count as a rack local preferred location for a task? (assumes host is NOT preferred location). 
- // This is true if either the task has preferred locations and this host is one, or it has - // no preferred locations (in which we still count the launch as preferred). - private def isRackLocalLocation(task: Task[_], hostPort: String): Boolean = { - - val locs = task.preferredLocations - - val preferredRacks = new HashSet[String]() - for (preferredHost <- locs) { - val rack = sched.getRackForHost(preferredHost) - if (None != rack) preferredRacks += rack.get - } - - if (preferredRacks.isEmpty) return false - - val hostRack = sched.getRackForHost(hostPort) - - return None != hostRack && preferredRacks.contains(hostRack.get) - } - - // Respond to an offer of a single slave from the scheduler by finding a task - def slaveOffer(execId: String, hostPort: String, availableCpus: Double, overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription] = { - - if (tasksFinished < numTasks && availableCpus >= CPUS_PER_TASK) { - // If explicitly specified, use that - val locality = if (overrideLocality != null) overrideLocality else { - // expand only if we have waited for more than LOCALITY_WAIT for a host local task ... 
- val time = System.currentTimeMillis - if (time - lastPreferredLaunchTime < LOCALITY_WAIT) TaskLocality.NODE_LOCAL else TaskLocality.ANY - } - - findTask(hostPort, locality) match { - case Some(index) => { - // Found a task; do some bookkeeping and return a Mesos task for it - val task = tasks(index) - val taskId = sched.newTaskId() - // Figure out whether this should count as a preferred launch - val taskLocality = - if (isProcessLocalLocation(task, hostPort)) TaskLocality.PROCESS_LOCAL else - if (isHostLocalLocation(task, hostPort)) TaskLocality.NODE_LOCAL else - if (isRackLocalLocation(task, hostPort)) TaskLocality.RACK_LOCAL else - TaskLocality.ANY - val prefStr = taskLocality.toString - logInfo("Starting task %s:%d as TID %s on slave %s: %s (%s)".format( - taskSet.id, index, taskId, execId, hostPort, prefStr)) - // Do various bookkeeping - copiesRunning(index) += 1 - val time = System.currentTimeMillis - val info = new TaskInfo(taskId, index, time, execId, hostPort, taskLocality) - taskInfos(taskId) = info - taskAttempts(index) = info :: taskAttempts(index) - if (TaskLocality.NODE_LOCAL == taskLocality) { - lastPreferredLaunchTime = time - } - // Serialize and return the task - val startTime = System.currentTimeMillis - val serializedTask = Task.serializeWithDependencies( - task, sched.sc.addedFiles, sched.sc.addedJars, ser) - val timeTaken = System.currentTimeMillis - startTime - increaseRunningTasks(1) - logInfo("Serialized task %s:%d as %d bytes in %d ms".format( - taskSet.id, index, serializedTask.limit, timeTaken)) - val taskName = "task %s:%d".format(taskSet.id, index) - return Some(new TaskDescription(taskId, execId, taskName, serializedTask)) - } - case _ => - } - } - return None - } - - def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) { - state match { - case TaskState.FINISHED => - taskFinished(tid, state, serializedData) - case TaskState.LOST => - taskLost(tid, state, serializedData) - case TaskState.FAILED => - 
taskLost(tid, state, serializedData) - case TaskState.KILLED => - taskLost(tid, state, serializedData) - case _ => - } - } - - def taskFinished(tid: Long, state: TaskState, serializedData: ByteBuffer) { - val info = taskInfos(tid) - if (info.failed) { - // We might get two task-lost messages for the same task in coarse-grained Mesos mode, - // or even from Mesos itself when acks get delayed. - return - } - val index = info.index - info.markSuccessful() - decreaseRunningTasks(1) - if (!finished(index)) { - tasksFinished += 1 - logInfo("Finished TID %s in %d ms (progress: %d/%d)".format( - tid, info.duration, tasksFinished, numTasks)) - // Deserialize task result and pass it to the scheduler - try { - val result = ser.deserialize[TaskResult[_]](serializedData) - result.metrics.resultSize = serializedData.limit() - sched.listener.taskEnded(tasks(index), Success, result.value, result.accumUpdates, info, result.metrics) - } catch { - case cnf: ClassNotFoundException => - val loader = Thread.currentThread().getContextClassLoader - throw new SparkException("ClassNotFound with classloader: " + loader, cnf) - case ex => throw ex - } - // Mark finished and stop if we've finished all the tasks - finished(index) = true - if (tasksFinished == numTasks) { - sched.taskSetFinished(this) - } - } else { - logInfo("Ignoring task-finished event for TID " + tid + - " because task " + index + " is already finished") - } - } - - def taskLost(tid: Long, state: TaskState, serializedData: ByteBuffer) { - val info = taskInfos(tid) - if (info.failed) { - // We might get two task-lost messages for the same task in coarse-grained Mesos mode, - // or even from Mesos itself when acks get delayed. - return - } - val index = info.index - info.markFailed() - decreaseRunningTasks(1) - if (!finished(index)) { - logInfo("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index)) - copiesRunning(index) -= 1 - // Check if the problem is a map output fetch failure. 
In that case, this - // task will never succeed on any node, so tell the scheduler about it. - if (serializedData != null && serializedData.limit() > 0) { - val reason = ser.deserialize[TaskEndReason](serializedData, getClass.getClassLoader) - reason match { - case fetchFailed: FetchFailed => - logInfo("Loss was due to fetch failure from " + fetchFailed.bmAddress) - sched.listener.taskEnded(tasks(index), fetchFailed, null, null, info, null) - finished(index) = true - tasksFinished += 1 - sched.taskSetFinished(this) - decreaseRunningTasks(runningTasks) - return - - case taskResultTooBig: TaskResultTooBigFailure => - logInfo("Loss was due to task %s result exceeding Akka frame size;" + - "aborting job".format(tid)) - abort("Task %s result exceeded Akka frame size".format(tid)) - return - - case ef: ExceptionFailure => - val key = ef.description - val now = System.currentTimeMillis - val (printFull, dupCount) = { - if (recentExceptions.contains(key)) { - val (dupCount, printTime) = recentExceptions(key) - if (now - printTime > EXCEPTION_PRINT_INTERVAL) { - recentExceptions(key) = (0, now) - (true, 0) - } else { - recentExceptions(key) = (dupCount + 1, printTime) - (false, dupCount + 1) - } - } else { - recentExceptions(key) = (0, now) - (true, 0) - } - } - if (printFull) { - val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString)) - logInfo("Loss was due to %s\n%s\n%s".format( - ef.className, ef.description, locs.mkString("\n"))) - } else { - logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount)) - } - - case _ => {} - } - } - // On non-fetch failures, re-enqueue the task as pending for a max number of retries - addPendingTask(index) - // Count failed attempts only on FAILED and LOST state (not on KILLED) - if (state == TaskState.FAILED || state == TaskState.LOST) { - numFailures(index) += 1 - if (numFailures(index) > MAX_TASK_FAILURES) { - logError("Task %s:%d failed more than %d times; aborting job".format( - taskSet.id, index, 
MAX_TASK_FAILURES)) - abort("Task %s:%d failed more than %d times".format(taskSet.id, index, MAX_TASK_FAILURES)) - } - } - } else { - logInfo("Ignoring task-lost event for TID " + tid + - " because task " + index + " is already finished") - } - } - - def error(message: String) { - // Save the error message - abort("Error: " + message) - } - - def abort(message: String) { - failed = true - causeOfFailure = message - // TODO: Kill running tasks if we were not terminated due to a Mesos error - sched.listener.taskSetFailed(taskSet, message) - decreaseRunningTasks(runningTasks) - sched.taskSetFinished(this) - } - - override def increaseRunningTasks(taskNum: Int) { - runningTasks += taskNum - if (parent != null) { - parent.increaseRunningTasks(taskNum) - } - } - - override def decreaseRunningTasks(taskNum: Int) { - runningTasks -= taskNum - if (parent != null) { - parent.decreaseRunningTasks(taskNum) - } - } - - //TODO: for now we just find Pool not TaskSetManager, we can extend this function in future if needed - override def getSchedulableByName(name: String): Schedulable = { - return null - } - - override def addSchedulable(schedulable:Schedulable) { - //nothing - } - - override def removeSchedulable(schedulable:Schedulable) { - //nothing - } - - override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = { - var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager] - sortedTaskSetQueue += this - return sortedTaskSetQueue - } - - override def executorLost(execId: String, hostPort: String) { - logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id) - - // If some task has preferred locations only on hostname, and there are no more executors there, - // put it in the no-prefs list to avoid the wait from delay scheduling - - // host local tasks - should we push this to rack local or no pref list ? For now, preserving behavior and moving to - // no prefs list. Note, this was done due to impliations related to 'waiting' for data local tasks, etc. 
- // Note: NOT checking process local list - since host local list is super set of that. We need to ad to no prefs only if - // there is no host local node for the task (not if there is no process local node for the task) - for (index <- getPendingTasksForHost(Utils.parseHostPort(hostPort)._1)) { - // val newLocs = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.RACK_LOCAL) - val newLocs = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.NODE_LOCAL) - if (newLocs.isEmpty) { - pendingTasksWithNoPrefs += index - } - } - - // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage - if (tasks(0).isInstanceOf[ShuffleMapTask]) { - for ((tid, info) <- taskInfos if info.executorId == execId) { - val index = taskInfos(tid).index - if (finished(index)) { - finished(index) = false - copiesRunning(index) -= 1 - tasksFinished -= 1 - addPendingTask(index) - // Tell the DAGScheduler that this task was resubmitted so that it doesn't think our - // stage finishes when a total of tasks.size tasks finish. - sched.listener.taskEnded(tasks(index), Resubmitted, null, null, info, null) - } - } - } - // Also re-enqueue any tasks that were running on the node - for ((tid, info) <- taskInfos if info.running && info.executorId == execId) { - taskLost(tid, TaskState.KILLED, null) - } - } - - /** - * Check for tasks to be speculated and return true if there are any. This is called periodically - * by the ClusterScheduler. - * - * TODO: To make this scale to large jobs, we need to maintain a list of running tasks, so that - * we don't scan the whole task set. It might also help to make this sorted by launch time. - */ - override def checkSpeculatableTasks(): Boolean = { - // Can't speculate if we only have one task, or if all tasks have finished. 
- if (numTasks == 1 || tasksFinished == numTasks) { - return false - } - var foundTasks = false - val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt - logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation) - if (tasksFinished >= minFinishedForSpeculation) { - val time = System.currentTimeMillis() - val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray - Arrays.sort(durations) - val medianDuration = durations(min((0.5 * numTasks).round.toInt, durations.size - 1)) - val threshold = max(SPECULATION_MULTIPLIER * medianDuration, 100) - // TODO: Threshold should also look at standard deviation of task durations and have a lower - // bound based on that. - logDebug("Task length threshold for speculation: " + threshold) - for ((tid, info) <- taskInfos) { - val index = info.index - if (!finished(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold && - !speculatableTasks.contains(index)) { - logInfo( - "Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format( - taskSet.id, index, info.hostPort, threshold)) - speculatableTasks += index - foundTasks = true - } - } - } - return foundTasks - } - - override def hasPendingTasks(): Boolean = { - numTasks > 0 && tasksFinished < numTasks - } +private[spark] trait TaskSetManager extends Schedulable { + def taskSet: TaskSet + def slaveOffer(execId: String, hostPort: String, availableCpus: Double, + overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription] + def numPendingTasksForHostPort(hostPort: String): Int + def numRackLocalPendingTasksForHost(hostPort :String): Int + def numPendingTasksForHost(hostPort: String): Int + def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) + def error(message: String) } diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala index 
37a67f9b1b..93d4318b29 100644 --- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala @@ -2,19 +2,50 @@ package spark.scheduler.local import java.io.File import java.util.concurrent.atomic.AtomicInteger +import java.nio.ByteBuffer +import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap +import scala.collection.mutable.HashSet import spark._ +import spark.TaskState.TaskState import spark.executor.ExecutorURLClassLoader import spark.scheduler._ -import spark.scheduler.cluster.{TaskLocality, TaskInfo} +import spark.scheduler.cluster._ +import akka.actor._ /** - * A simple TaskScheduler implementation that runs tasks locally in a thread pool. Optionally + * A FIFO or Fair TaskScheduler implementation that runs tasks locally in a thread pool. Optionally * the scheduler also allows each task to fail up to maxFailures times, which is useful for * testing fault recovery. */ -private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkContext) + +private[spark] case class LocalReviveOffers() +private[spark] case class LocalStatusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) + +private[spark] class LocalActor(localScheduler: LocalScheduler, var freeCores: Int) extends Actor with Logging { + def receive = { + case LocalReviveOffers => + launchTask(localScheduler.resourceOffer(freeCores)) + case LocalStatusUpdate(taskId, state, serializeData) => + freeCores += 1 + localScheduler.statusUpdate(taskId, state, serializeData) + launchTask(localScheduler.resourceOffer(freeCores)) + } + + def launchTask(tasks : Seq[TaskDescription]) { + for (task <- tasks) { + freeCores -= 1 + localScheduler.threadPool.submit(new Runnable { + def run() { + localScheduler.runTask(task.taskId,task.serializedTask) + } + }) + } + } +} + +private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: SparkContext) extends TaskScheduler with 
Logging { @@ -30,89 +61,127 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon val classLoader = new ExecutorURLClassLoader(Array(), Thread.currentThread.getContextClassLoader) - // TODO: Need to take into account stage priority in scheduling + var schedulableBuilder: SchedulableBuilder = null + var rootPool: Pool = null + val activeTaskSets = new HashMap[String, TaskSetManager] + val taskIdToTaskSetId = new HashMap[Long, String] + val taskSetTaskIds = new HashMap[String, HashSet[Long]] + + var localActor: ActorRef = null + + override def start() { + //default scheduler is FIFO + val schedulingMode = System.getProperty("spark.cluster.schedulingmode", "FIFO") + //temporarily set rootPool name to empty + rootPool = new Pool("", SchedulingMode.withName(schedulingMode), 0, 0) + schedulableBuilder = { + schedulingMode match { + case "FIFO" => + new FIFOSchedulableBuilder(rootPool) + case "FAIR" => + new FairSchedulableBuilder(rootPool) + } + } + schedulableBuilder.buildPools() - override def start() { } + localActor = env.actorSystem.actorOf(Props(new LocalActor(this, threads)), "Test") + } override def setListener(listener: TaskSchedulerListener) { this.listener = listener } override def submitTasks(taskSet: TaskSet) { - val tasks = taskSet.tasks - val failCount = new Array[Int](tasks.size) - - def submitTask(task: Task[_], idInJob: Int) { - val myAttemptId = attemptId.getAndIncrement() - threadPool.submit(new Runnable { - def run() { - runTask(task, idInJob, myAttemptId) - } - }) + synchronized { + var manager = new LocalTaskSetManager(this, taskSet) + schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties) + activeTaskSets(taskSet.id) = manager + taskSetTaskIds(taskSet.id) = new HashSet[Long]() + localActor ! 
LocalReviveOffers } + } + + def resourceOffer(freeCores: Int): Seq[TaskDescription] = { + synchronized { + var freeCpuCores = freeCores + val tasks = new ArrayBuffer[TaskDescription](freeCores) + val sortedTaskSetQueue = rootPool.getSortedTaskSetQueue() + for (manager <- sortedTaskSetQueue) { + logDebug("parentName:%s,name:%s,runningTasks:%s".format(manager.parent.name, manager.name, manager.runningTasks)) + } - def runTask(task: Task[_], idInJob: Int, attemptId: Int) { - logInfo("Running " + task) - val info = new TaskInfo(attemptId, idInJob, System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL) - // Set the Spark execution environment for the worker thread - SparkEnv.set(env) - try { - Accumulators.clear() - Thread.currentThread().setContextClassLoader(classLoader) - - // Serialize and deserialize the task so that accumulators are changed to thread-local ones; - // this adds a bit of unnecessary overhead but matches how the Mesos Executor works. - val ser = SparkEnv.get.closureSerializer.newInstance() - val bytes = Task.serializeWithDependencies(task, sc.addedFiles, sc.addedJars, ser) - logInfo("Size of task " + idInJob + " is " + bytes.limit + " bytes") - val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(bytes) - updateDependencies(taskFiles, taskJars) // Download any files added with addFile - val deserStart = System.currentTimeMillis() - val deserializedTask = ser.deserialize[Task[_]]( - taskBytes, Thread.currentThread.getContextClassLoader) - val deserTime = System.currentTimeMillis() - deserStart - - // Run it - val result: Any = deserializedTask.run(attemptId) - - // Serialize and deserialize the result to emulate what the Mesos - // executor does. This is useful to catch serialization errors early - // on in development (so when users move their local Spark programs - // to the cluster, they don't get surprised by serialization errors). 
- val serResult = ser.serialize(result) - deserializedTask.metrics.get.resultSize = serResult.limit() - val resultToReturn = ser.deserialize[Any](serResult) - val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]]( - ser.serialize(Accumulators.values)) - logInfo("Finished " + task) - info.markSuccessful() - deserializedTask.metrics.get.executorRunTime = info.duration.toInt //close enough - deserializedTask.metrics.get.executorDeserializeTime = deserTime.toInt - - // If the threadpool has not already been shutdown, notify DAGScheduler - if (!Thread.currentThread().isInterrupted) - listener.taskEnded(task, Success, resultToReturn, accumUpdates, info, deserializedTask.metrics.getOrElse(null)) - } catch { - case t: Throwable => { - logError("Exception in task " + idInJob, t) - failCount.synchronized { - failCount(idInJob) += 1 - if (failCount(idInJob) <= maxFailures) { - submitTask(task, idInJob) - } else { - // TODO: Do something nicer here to return all the way to the user - if (!Thread.currentThread().isInterrupted) { - val failure = new ExceptionFailure(t.getClass.getName, t.toString, t.getStackTrace) - listener.taskEnded(task, failure, null, null, info, null) - } + var launchTask = false + for (manager <- sortedTaskSetQueue) { + do { + launchTask = false + manager.slaveOffer(null,null,freeCpuCores) match { + case Some(task) => + tasks += task + taskIdToTaskSetId(task.taskId) = manager.taskSet.id + taskSetTaskIds(manager.taskSet.id) += task.taskId + freeCpuCores -= 1 + launchTask = true + case None => {} } - } - } + } while(launchTask) } + return tasks } + } - for ((task, i) <- tasks.zipWithIndex) { - submitTask(task, i) + def taskSetFinished(manager: TaskSetManager) { + synchronized { + activeTaskSets -= manager.taskSet.id + manager.parent.removeSchedulable(manager) + logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name)) + taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id) + taskSetTaskIds -= 
manager.taskSet.id + } + } + + def runTask(taskId: Long, bytes: ByteBuffer) { + logInfo("Running " + taskId) + val info = new TaskInfo(taskId, 0, System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL) + // Set the Spark execution environment for the worker thread + SparkEnv.set(env) + val ser = SparkEnv.get.closureSerializer.newInstance() + try { + Accumulators.clear() + Thread.currentThread().setContextClassLoader(classLoader) + + // Serialize and deserialize the task so that accumulators are changed to thread-local ones; + // this adds a bit of unnecessary overhead but matches how the Mesos Executor works. + val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(bytes) + updateDependencies(taskFiles, taskJars) // Download any files added with addFile + val deserStart = System.currentTimeMillis() + val deserializedTask = ser.deserialize[Task[_]]( + taskBytes, Thread.currentThread.getContextClassLoader) + val deserTime = System.currentTimeMillis() - deserStart + + // Run it + val result: Any = deserializedTask.run(taskId) + + // Serialize and deserialize the result to emulate what the Mesos + // executor does. This is useful to catch serialization errors early + // on in development (so when users move their local Spark programs + // to the cluster, they don't get surprised by serialization errors). + val serResult = ser.serialize(result) + deserializedTask.metrics.get.resultSize = serResult.limit() + val resultToReturn = ser.deserialize[Any](serResult) + val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]]( + ser.serialize(Accumulators.values)) + logInfo("Finished " + taskId) + deserializedTask.metrics.get.executorRunTime = deserTime.toInt//info.duration.toInt //close enough + deserializedTask.metrics.get.executorDeserializeTime = deserTime.toInt + + val taskResult = new TaskResult(result, accumUpdates, deserializedTask.metrics.getOrElse(null)) + val serializedResult = ser.serialize(taskResult) + localActor ! 
LocalStatusUpdate(taskId, TaskState.FINISHED, serializedResult) + } catch { + case t: Throwable => { + val failure = new ExceptionFailure(t.getClass.getName, t.toString, t.getStackTrace) + localActor ! LocalStatusUpdate(taskId, TaskState.FAILED, ser.serialize(failure)) + } } } @@ -128,6 +197,7 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) currentFiles(name) = timestamp } + for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) { logInfo("Fetching " + name + " with timestamp " + timestamp) Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) @@ -143,7 +213,16 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon } } - override def stop() { + def statusUpdate(taskId :Long, state: TaskState, serializedData: ByteBuffer) { + synchronized { + val taskSetId = taskIdToTaskSetId(taskId) + val taskSetManager = activeTaskSets(taskSetId) + taskSetTaskIds(taskSetId) -= taskId + taskSetManager.statusUpdate(taskId, state, serializedData) + } + } + + override def stop() { threadPool.shutdownNow() } diff --git a/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala b/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala new file mode 100644 index 0000000000..70b69bb26f --- /dev/null +++ b/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala @@ -0,0 +1,172 @@ +package spark.scheduler.local + +import java.io.File +import java.util.concurrent.atomic.AtomicInteger +import java.nio.ByteBuffer +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap +import scala.collection.mutable.HashSet + +import spark._ +import spark.TaskState.TaskState +import spark.scheduler._ +import spark.scheduler.cluster._ + +private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: TaskSet) extends TaskSetManager with Logging { + var parent: Schedulable = 
null + var weight: Int = 1 + var minShare: Int = 0 + var runningTasks: Int = 0 + var priority: Int = taskSet.priority + var stageId: Int = taskSet.stageId + var name: String = "TaskSet_"+taskSet.stageId.toString + + + var failCount = new Array[Int](taskSet.tasks.size) + val taskInfos = new HashMap[Long, TaskInfo] + val numTasks = taskSet.tasks.size + var numFinished = 0 + val ser = SparkEnv.get.closureSerializer.newInstance() + val copiesRunning = new Array[Int](numTasks) + val finished = new Array[Boolean](numTasks) + val numFailures = new Array[Int](numTasks) + val MAX_TASK_FAILURES = sched.maxFailures + + def increaseRunningTasks(taskNum: Int): Unit = { + runningTasks += taskNum + if (parent != null) { + parent.increaseRunningTasks(taskNum) + } + } + + def decreaseRunningTasks(taskNum: Int): Unit = { + runningTasks -= taskNum + if (parent != null) { + parent.decreaseRunningTasks(taskNum) + } + } + + def addSchedulable(schedulable: Schedulable): Unit = { + //nothing + } + + def removeSchedulable(schedulable: Schedulable): Unit = { + //nothing + } + + def getSchedulableByName(name: String): Schedulable = { + return null + } + + def executorLost(executorId: String, host: String): Unit = { + //nothing + } + + def checkSpeculatableTasks(): Boolean = { + return true + } + + def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = { + var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager] + sortedTaskSetQueue += this + return sortedTaskSetQueue + } + + def hasPendingTasks(): Boolean = { + return true + } + + def findTask(): Option[Int] = { + for (i <- 0 to numTasks-1) { + if (copiesRunning(i) == 0 && !finished(i)) { + return Some(i) + } + } + return None + } + + def slaveOffer(execId: String, hostPort: String, availableCpus: Double, overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription] = { + SparkEnv.set(sched.env) + logDebug("availableCpus:%d,numFinished:%d,numTasks:%d".format(availableCpus.toInt, numFinished, numTasks)) + if (availableCpus 
> 0 && numFinished < numTasks) { + findTask() match { + case Some(index) => + val taskId = sched.attemptId.getAndIncrement() + val task = taskSet.tasks(index) + val info = new TaskInfo(taskId, index, System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL) + taskInfos(taskId) = info + val bytes = Task.serializeWithDependencies(task, sched.sc.addedFiles, sched.sc.addedJars, ser) + logInfo("Size of task " + taskId + " is " + bytes.limit + " bytes") + val taskName = "task %s:%d".format(taskSet.id, index) + copiesRunning(index) += 1 + increaseRunningTasks(1) + return Some(new TaskDescription(taskId, null, taskName, bytes)) + case None => {} + } + } + return None + } + + def numPendingTasksForHostPort(hostPort: String): Int = { + return 0 + } + + def numRackLocalPendingTasksForHost(hostPort :String): Int = { + return 0 + } + + def numPendingTasksForHost(hostPort: String): Int = { + return 0 + } + + def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) { + state match { + case TaskState.FINISHED => + taskEnded(tid, state, serializedData) + case TaskState.FAILED => + taskFailed(tid, state, serializedData) + case _ => {} + } + } + + def taskEnded(tid: Long, state: TaskState, serializedData: ByteBuffer) { + val info = taskInfos(tid) + val index = info.index + val task = taskSet.tasks(index) + info.markSuccessful() + val result = ser.deserialize[TaskResult[_]](serializedData, getClass.getClassLoader) + result.metrics.resultSize = serializedData.limit() + sched.listener.taskEnded(task, Success, result.value, result.accumUpdates, info, result.metrics) + numFinished += 1 + decreaseRunningTasks(1) + finished(index) = true + if (numFinished == numTasks) { + sched.taskSetFinished(this) + } + } + + def taskFailed(tid: Long, state: TaskState, serializedData: ByteBuffer) { + val info = taskInfos(tid) + val index = info.index + val task = taskSet.tasks(index) + info.markFailed() + decreaseRunningTasks(1) + val reason: ExceptionFailure = 
ser.deserialize[ExceptionFailure](serializedData, getClass.getClassLoader) + if (!finished(index)) { + copiesRunning(index) -= 1 + numFailures(index) += 1 + val locs = reason.stackTrace.map(loc => "\tat %s".format(loc.toString)) + logInfo("Loss was due to %s\n%s\n%s".format(reason.className, reason.description, locs.mkString("\n"))) + if (numFailures(index) > MAX_TASK_FAILURES) { + val errorMessage = "Task %s:%d failed more than %d times; aborting job %s".format(taskSet.id, index, 4, reason.description) + decreaseRunningTasks(runningTasks) + sched.listener.taskSetFailed(taskSet, errorMessage) + // need to delete failed Taskset from schedule queue + sched.taskSetFinished(this) + } + } + } + + def error(message: String) { + } +} diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala index 15ab840155..da859eebcb 100644 --- a/core/src/main/scala/spark/storage/DiskStore.scala +++ b/core/src/main/scala/spark/storage/DiskStore.scala @@ -96,15 +96,15 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) override def size(): Long = lastValidPosition } - val MAX_DIR_CREATION_ATTEMPTS: Int = 10 - val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt + private val MAX_DIR_CREATION_ATTEMPTS: Int = 10 + private val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt - var shuffleSender : ShuffleSender = null + private var shuffleSender : ShuffleSender = null // Create one local directory for each path mentioned in spark.local.dir; then, inside this // directory, create multiple subdirectories that we will hash files into, in order to avoid // having really large inodes at the top level. 
- val localDirs = createLocalDirs() - val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir)) + private val localDirs: Array[File] = createLocalDirs() + private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir)) addShutdownHook() @@ -113,7 +113,6 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) new DiskBlockObjectWriter(blockId, serializer, bufferSize) } - override def getSize(blockId: String): Long = { getFile(blockId).length() } @@ -249,8 +248,8 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) private def createLocalDirs(): Array[File] = { logDebug("Creating local directories at root dirs '" + rootDirs + "'") val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss") - rootDirs.split(",").map(rootDir => { - var foundLocalDir: Boolean = false + rootDirs.split(",").map { rootDir => + var foundLocalDir = false var localDir: File = null var localDirId: String = null var tries = 0 @@ -265,7 +264,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) } } catch { case e: Exception => - logWarning("Attempt " + tries + " to create local dir failed", e) + logWarning("Attempt " + tries + " to create local dir " + localDir + " failed", e) } } if (!foundLocalDir) { @@ -275,7 +274,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) } logInfo("Created local directory at " + localDir) localDir - }) + } } private def addShutdownHook() { @@ -283,15 +282,16 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") { override def run() { logDebug("Shutdown hook called") - try { - localDirs.foreach { localDir => + localDirs.foreach { localDir => + try { if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir) + } catch { + case t: Throwable => + logError("Exception while deleting local spark dir: " + localDir, t) } - if 
(shuffleSender != null) { - shuffleSender.stop - } - } catch { - case t: Throwable => logError("Exception while deleting local spark dirs", t) + } + if (shuffleSender != null) { + shuffleSender.stop } } }) diff --git a/core/src/main/scala/spark/util/BoundedPriorityQueue.scala b/core/src/main/scala/spark/util/BoundedPriorityQueue.scala new file mode 100644 index 0000000000..4bc5db8bb7 --- /dev/null +++ b/core/src/main/scala/spark/util/BoundedPriorityQueue.scala @@ -0,0 +1,45 @@ +package spark.util + +import java.io.Serializable +import java.util.{PriorityQueue => JPriorityQueue} +import scala.collection.generic.Growable +import scala.collection.JavaConverters._ + +/** + * Bounded priority queue. This class wraps the original PriorityQueue + * class and modifies it such that only the top K elements are retained. + * The top K elements are defined by an implicit Ordering[A]. + */ +class BoundedPriorityQueue[A](maxSize: Int)(implicit ord: Ordering[A]) + extends Iterable[A] with Growable[A] with Serializable { + + private val underlying = new JPriorityQueue[A](maxSize, ord) + + override def iterator: Iterator[A] = underlying.iterator.asScala + + override def ++=(xs: TraversableOnce[A]): this.type = { + xs.foreach { this += _ } + this + } + + override def +=(elem: A): this.type = { + if (size < maxSize) underlying.offer(elem) + else maybeReplaceLowest(elem) + this + } + + override def +=(elem1: A, elem2: A, elems: A*): this.type = { + this += elem1 += elem2 ++= elems + } + + override def clear() { underlying.clear() } + + private def maybeReplaceLowest(a: A): Boolean = { + val head = underlying.peek() + if (head != null && ord.gt(a, head)) { + underlying.poll() + underlying.offer(a) + } else false + } +} + diff --git a/core/src/main/scala/spark/util/StatCounter.scala b/core/src/main/scala/spark/util/StatCounter.scala index 5f80180339..2b980340b7 100644 --- a/core/src/main/scala/spark/util/StatCounter.scala +++ b/core/src/main/scala/spark/util/StatCounter.scala @@ -37,17 
+37,23 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable { if (other == this) { merge(other.copy()) // Avoid overwriting fields in a weird order } else { - val delta = other.mu - mu - if (other.n * 10 < n) { - mu = mu + (delta * other.n) / (n + other.n) - } else if (n * 10 < other.n) { - mu = other.mu - (delta * n) / (n + other.n) - } else { - mu = (mu * n + other.mu * other.n) / (n + other.n) + if (n == 0) { + mu = other.mu + m2 = other.m2 + n = other.n + } else if (other.n != 0) { + val delta = other.mu - mu + if (other.n * 10 < n) { + mu = mu + (delta * other.n) / (n + other.n) + } else if (n * 10 < other.n) { + mu = other.mu - (delta * n) / (n + other.n) + } else { + mu = (mu * n + other.mu * other.n) / (n + other.n) + } + m2 += other.m2 + (delta * delta * n * other.n) / (n + other.n) + n += other.n } - m2 += other.m2 + (delta * delta * n * other.n) / (n + other.n) - n += other.n - this + this } } diff --git a/core/src/test/scala/spark/FileSuite.scala b/core/src/test/scala/spark/FileSuite.scala index 91b48c7456..e61ff7793d 100644 --- a/core/src/test/scala/spark/FileSuite.scala +++ b/core/src/test/scala/spark/FileSuite.scala @@ -7,6 +7,8 @@ import scala.io.Source import com.google.common.io.Files import org.scalatest.FunSuite import org.apache.hadoop.io._ +import org.apache.hadoop.io.compress.{DefaultCodec, CompressionCodec, GzipCodec} + import SparkContext._ @@ -26,6 +28,28 @@ class FileSuite extends FunSuite with LocalSparkContext { assert(sc.textFile(outputDir).collect().toList === List("1", "2", "3", "4")) } + test("text files (compressed)") { + sc = new SparkContext("local", "test") + val tempDir = Files.createTempDir() + val normalDir = new File(tempDir, "output_normal").getAbsolutePath + val compressedOutputDir = new File(tempDir, "output_compressed").getAbsolutePath + val codec = new DefaultCodec() + + val data = sc.parallelize("a" * 10000, 1) + data.saveAsTextFile(normalDir) + data.saveAsTextFile(compressedOutputDir, 
classOf[DefaultCodec]) + + val normalFile = new File(normalDir, "part-00000") + val normalContent = sc.textFile(normalDir).collect + assert(normalContent === Array.fill(10000)("a")) + + val compressedFile = new File(compressedOutputDir, "part-00000" + codec.getDefaultExtension) + val compressedContent = sc.textFile(compressedOutputDir).collect + assert(compressedContent === Array.fill(10000)("a")) + + assert(compressedFile.length < normalFile.length) + } + test("SequenceFiles") { sc = new SparkContext("local", "test") val tempDir = Files.createTempDir() @@ -37,6 +61,28 @@ class FileSuite extends FunSuite with LocalSparkContext { assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)")) } + test("SequenceFile (compressed)") { + sc = new SparkContext("local", "test") + val tempDir = Files.createTempDir() + val normalDir = new File(tempDir, "output_normal").getAbsolutePath + val compressedOutputDir = new File(tempDir, "output_compressed").getAbsolutePath + val codec = new DefaultCodec() + + val data = sc.parallelize(Seq.fill(100)("abc"), 1).map(x => (x, x)) + data.saveAsSequenceFile(normalDir) + data.saveAsSequenceFile(compressedOutputDir, Some(classOf[DefaultCodec])) + + val normalFile = new File(normalDir, "part-00000") + val normalContent = sc.sequenceFile[String, String](normalDir).collect + assert(normalContent === Array.fill(100)("abc", "abc")) + + val compressedFile = new File(compressedOutputDir, "part-00000" + codec.getDefaultExtension) + val compressedContent = sc.sequenceFile[String, String](compressedOutputDir).collect + assert(compressedContent === Array.fill(100)("abc", "abc")) + + assert(compressedFile.length < normalFile.length) + } + test("SequenceFile with writable key") { sc = new SparkContext("local", "test") val tempDir = Files.createTempDir() diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java index 93bb69b41c..d306124fca 100644 --- 
a/core/src/test/scala/spark/JavaAPISuite.java +++ b/core/src/test/scala/spark/JavaAPISuite.java @@ -8,6 +8,7 @@ import java.util.*; import scala.Tuple2; import com.google.common.base.Charsets; +import org.apache.hadoop.io.compress.DefaultCodec; import com.google.common.io.Files; import org.apache.hadoop.io.IntWritable; import org.apache.hadoop.io.Text; @@ -474,6 +475,19 @@ public class JavaAPISuite implements Serializable { } @Test + public void textFilesCompressed() throws IOException { + File tempDir = Files.createTempDir(); + String outputDir = new File(tempDir, "output").getAbsolutePath(); + JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4)); + rdd.saveAsTextFile(outputDir, DefaultCodec.class); + + // Try reading it in as a text file RDD + List<String> expected = Arrays.asList("1", "2", "3", "4"); + JavaRDD<String> readRDD = sc.textFile(outputDir); + Assert.assertEquals(expected, readRDD.collect()); + } + + @Test public void sequenceFile() { File tempDir = Files.createTempDir(); String outputDir = new File(tempDir, "output").getAbsolutePath(); @@ -620,6 +634,37 @@ public class JavaAPISuite implements Serializable { } @Test + public void hadoopFileCompressed() { + File tempDir = Files.createTempDir(); + String outputDir = new File(tempDir, "output_compressed").getAbsolutePath(); + List<Tuple2<Integer, String>> pairs = Arrays.asList( + new Tuple2<Integer, String>(1, "a"), + new Tuple2<Integer, String>(2, "aa"), + new Tuple2<Integer, String>(3, "aaa") + ); + JavaPairRDD<Integer, String> rdd = sc.parallelizePairs(pairs); + + rdd.map(new PairFunction<Tuple2<Integer, String>, IntWritable, Text>() { + @Override + public Tuple2<IntWritable, Text> call(Tuple2<Integer, String> pair) { + return new Tuple2<IntWritable, Text>(new IntWritable(pair._1()), new Text(pair._2())); + } + }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class, + DefaultCodec.class); + + JavaPairRDD<IntWritable, Text> output = 
sc.hadoopFile(outputDir, + SequenceFileInputFormat.class, IntWritable.class, Text.class); + + Assert.assertEquals(pairs.toString(), output.map(new Function<Tuple2<IntWritable, Text>, + String>() { + @Override + public String call(Tuple2<IntWritable, Text> x) { + return x.toString(); + } + }).collect().toString()); + } + + @Test public void zip() { JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); JavaDoubleRDD doubles = rdd.map(new DoubleFunction<Integer>() { diff --git a/core/src/test/scala/spark/PartitioningSuite.scala b/core/src/test/scala/spark/PartitioningSuite.scala index 60db759c25..16f93e71a3 100644 --- a/core/src/test/scala/spark/PartitioningSuite.scala +++ b/core/src/test/scala/spark/PartitioningSuite.scala @@ -1,10 +1,10 @@ package spark import org.scalatest.FunSuite - import scala.collection.mutable.ArrayBuffer - import SparkContext._ +import spark.util.StatCounter +import scala.math.abs class PartitioningSuite extends FunSuite with LocalSparkContext { @@ -120,4 +120,21 @@ class PartitioningSuite extends FunSuite with LocalSparkContext { assert(intercept[SparkException]{ arrPairs.reduceByKeyLocally(_ + _) }.getMessage.contains("array")) assert(intercept[SparkException]{ arrPairs.reduceByKey(_ + _) }.getMessage.contains("array")) } + + test("Zero-length partitions should be correctly handled") { + // Create RDD with some consecutive empty partitions (including the "first" one) + sc = new SparkContext("local", "test") + val rdd: RDD[Double] = sc + .parallelize(Array(-1.0, -1.0, -1.0, -1.0, 2.0, 4.0, -1.0, -1.0), 8) + .filter(_ >= 0.0) + + // Run the partitions, including the consecutive empty ones, through StatCounter + val stats: StatCounter = rdd.stats(); + assert(abs(6.0 - stats.sum) < 0.01); + assert(abs(6.0/2 - rdd.mean) < 0.01); + assert(abs(1.0 - rdd.variance) < 0.01); + assert(abs(1.0 - rdd.stdev) < 0.01); + + // Add other tests here for classes that should be able to handle empty partitions correctly + } } diff --git 
a/core/src/test/scala/spark/PipedRDDSuite.scala b/core/src/test/scala/spark/PipedRDDSuite.scala index a6344edf8f..ed075f93ec 100644 --- a/core/src/test/scala/spark/PipedRDDSuite.scala +++ b/core/src/test/scala/spark/PipedRDDSuite.scala @@ -19,6 +19,45 @@ class PipedRDDSuite extends FunSuite with LocalSparkContext { assert(c(3) === "4") } + test("advanced pipe") { + sc = new SparkContext("local", "test") + val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + val bl = sc.broadcast(List("0")) + + val piped = nums.pipe(Seq("cat"), + Map[String, String](), + (f: String => Unit) => {bl.value.map(f(_));f("\u0001")}, + (i:Int, f: String=> Unit) => f(i + "_")) + + val c = piped.collect() + + assert(c.size === 8) + assert(c(0) === "0") + assert(c(1) === "\u0001") + assert(c(2) === "1_") + assert(c(3) === "2_") + assert(c(4) === "0") + assert(c(5) === "\u0001") + assert(c(6) === "3_") + assert(c(7) === "4_") + + val nums1 = sc.makeRDD(Array("a\t1", "b\t2", "a\t3", "b\t4"), 2) + val d = nums1.groupBy(str=>str.split("\t")(0)). 
+ pipe(Seq("cat"), + Map[String, String](), + (f: String => Unit) => {bl.value.map(f(_));f("\u0001")}, + (i:Tuple2[String, Seq[String]], f: String=> Unit) => {for (e <- i._2){ f(e + "_")}}).collect() + assert(d.size === 8) + assert(d(0) === "0") + assert(d(1) === "\u0001") + assert(d(2) === "b\t2_") + assert(d(3) === "b\t4_") + assert(d(4) === "0") + assert(d(5) === "\u0001") + assert(d(6) === "a\t1_") + assert(d(7) === "a\t3_") + } + test("pipe with env variable") { sc = new SparkContext("local", "test") val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala index 3f69e99780..67f3332d44 100644 --- a/core/src/test/scala/spark/RDDSuite.scala +++ b/core/src/test/scala/spark/RDDSuite.scala @@ -317,4 +317,23 @@ class RDDSuite extends FunSuite with LocalSparkContext { assert(sample.size === checkSample.size) for (i <- 0 until sample.size) assert(sample(i) === checkSample(i)) } + + test("top with predefined ordering") { + sc = new SparkContext("local", "test") + val nums = Array.range(1, 100000) + val ints = sc.makeRDD(scala.util.Random.shuffle(nums), 2) + val topK = ints.top(5) + assert(topK.size === 5) + assert(topK.sorted === nums.sorted.takeRight(5)) + } + + test("top with custom ordering") { + sc = new SparkContext("local", "test") + val words = Vector("a", "b", "c", "d") + implicit val ord = implicitly[Ordering[String]].reverse + val rdd = sc.makeRDD(words, 2) + val topK = rdd.top(2) + assert(topK.size === 2) + assert(topK.sorted === Array("b", "a")) + } } diff --git a/core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala index c861597c6b..8e1ad27e14 100644 --- a/core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala +++ b/core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala @@ -16,7 +16,7 @@ class DummyTaskSetManager( initNumTasks: Int, clusterScheduler: ClusterScheduler, taskSet: 
TaskSet) - extends TaskSetManager(clusterScheduler,taskSet) { + extends ClusterTaskSetManager(clusterScheduler,taskSet) { parent = null weight = 1 diff --git a/core/src/test/scala/spark/scheduler/JobLoggerSuite.scala b/core/src/test/scala/spark/scheduler/JobLoggerSuite.scala new file mode 100644 index 0000000000..4000c4d520 --- /dev/null +++ b/core/src/test/scala/spark/scheduler/JobLoggerSuite.scala @@ -0,0 +1,105 @@ +package spark.scheduler + +import java.util.Properties +import java.util.concurrent.LinkedBlockingQueue +import org.scalatest.FunSuite +import org.scalatest.matchers.ShouldMatchers +import scala.collection.mutable +import spark._ +import spark.SparkContext._ + + +class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers { + + test("inner method") { + sc = new SparkContext("local", "joblogger") + val joblogger = new JobLogger { + def createLogWriterTest(jobID: Int) = createLogWriter(jobID) + def closeLogWriterTest(jobID: Int) = closeLogWriter(jobID) + def getRddNameTest(rdd: RDD[_]) = getRddName(rdd) + def buildJobDepTest(jobID: Int, stage: Stage) = buildJobDep(jobID, stage) + } + type MyRDD = RDD[(Int, Int)] + def makeRdd( + numPartitions: Int, + dependencies: List[Dependency[_]] + ): MyRDD = { + val maxPartition = numPartitions - 1 + return new MyRDD(sc, dependencies) { + override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = + throw new RuntimeException("should not be reached") + override def getPartitions = (0 to maxPartition).map(i => new Partition { + override def index = i + }).toArray + } + } + val jobID = 5 + val parentRdd = makeRdd(4, Nil) + val shuffleDep = new ShuffleDependency(parentRdd, null) + val rootRdd = makeRdd(4, List(shuffleDep)) + val shuffleMapStage = new Stage(1, parentRdd, Some(shuffleDep), Nil, jobID) + val rootStage = new Stage(0, rootRdd, None, List(shuffleMapStage), jobID) + + joblogger.onStageSubmitted(SparkListenerStageSubmitted(rootStage, 4)) + 
joblogger.getEventQueue.size should be (1) + joblogger.getRddNameTest(parentRdd) should be (parentRdd.getClass.getName) + parentRdd.setName("MyRDD") + joblogger.getRddNameTest(parentRdd) should be ("MyRDD") + joblogger.createLogWriterTest(jobID) + joblogger.getJobIDtoPrintWriter.size should be (1) + joblogger.buildJobDepTest(jobID, rootStage) + joblogger.getJobIDToStages.get(jobID).get.size should be (2) + joblogger.getStageIDToJobID.get(0) should be (Some(jobID)) + joblogger.getStageIDToJobID.get(1) should be (Some(jobID)) + joblogger.closeLogWriterTest(jobID) + joblogger.getStageIDToJobID.size should be (0) + joblogger.getJobIDToStages.size should be (0) + joblogger.getJobIDtoPrintWriter.size should be (0) + } + + test("inner variables") { + sc = new SparkContext("local[4]", "joblogger") + val joblogger = new JobLogger { + override protected def closeLogWriter(jobID: Int) = + getJobIDtoPrintWriter.get(jobID).foreach { fileWriter => + fileWriter.close() + } + } + sc.addSparkListener(joblogger) + val rdd = sc.parallelize(1 to 1e2.toInt, 4).map{ i => (i % 12, 2 * i) } + rdd.reduceByKey(_+_).collect() + + joblogger.getLogDir should be ("/tmp/spark") + joblogger.getJobIDtoPrintWriter.size should be (1) + joblogger.getStageIDToJobID.size should be (2) + joblogger.getStageIDToJobID.get(0) should be (Some(0)) + joblogger.getStageIDToJobID.get(1) should be (Some(0)) + joblogger.getJobIDToStages.size should be (1) + } + + + test("interface functions") { + sc = new SparkContext("local[4]", "joblogger") + val joblogger = new JobLogger { + var onTaskEndCount = 0 + var onJobEndCount = 0 + var onJobStartCount = 0 + var onStageCompletedCount = 0 + var onStageSubmittedCount = 0 + override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = onTaskEndCount += 1 + override def onJobEnd(jobEnd: SparkListenerJobEnd) = onJobEndCount += 1 + override def onJobStart(jobStart: SparkListenerJobStart) = onJobStartCount += 1 + override def onStageCompleted(stageCompleted: StageCompleted) = 
onStageCompletedCount += 1 + override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) = onStageSubmittedCount += 1 + } + sc.addSparkListener(joblogger) + val rdd = sc.parallelize(1 to 1e2.toInt, 4).map{ i => (i % 12, 2 * i) } + rdd.reduceByKey(_+_).collect() + + joblogger.onJobStartCount should be (1) + joblogger.onJobEndCount should be (1) + joblogger.onTaskEndCount should be (8) + joblogger.onStageSubmittedCount should be (2) + joblogger.onStageCompletedCount should be (2) + } +} diff --git a/core/src/test/scala/spark/scheduler/LocalSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/LocalSchedulerSuite.scala new file mode 100644 index 0000000000..8bd813fd14 --- /dev/null +++ b/core/src/test/scala/spark/scheduler/LocalSchedulerSuite.scala @@ -0,0 +1,206 @@ +package spark.scheduler + +import org.scalatest.FunSuite +import org.scalatest.BeforeAndAfter + +import spark._ +import spark.scheduler._ +import spark.scheduler.cluster._ +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.{ConcurrentMap, HashMap} +import java.util.concurrent.Semaphore +import java.util.concurrent.CountDownLatch +import java.util.Properties + +class Lock() { + var finished = false + def jobWait() = { + synchronized { + while(!finished) { + this.wait() + } + } + } + + def jobFinished() = { + synchronized { + finished = true + this.notifyAll() + } + } +} + +object TaskThreadInfo { + val threadToLock = HashMap[Int, Lock]() + val threadToRunning = HashMap[Int, Boolean]() + val threadToStarted = HashMap[Int, CountDownLatch]() +} + +/* + * 1. each thread contains one job. + * 2. each job contains one stage. + * 3. each stage only contains one task. + * 4. each task(launched) must be lanched orderly(using threadToStarted) to make sure + * it will get cpu core resource, and will wait to finished after user manually + * release "Lock" and then cluster will contain another free cpu cores. + * 5. 
each task(pending) must use "sleep" to make sure it has been added to taskSetManager queue, + * thus it will be scheduled later when cluster has free cpu cores. + */ +class LocalSchedulerSuite extends FunSuite with LocalSparkContext { + + def createThread(threadIndex: Int, poolName: String, sc: SparkContext, sem: Semaphore) { + + TaskThreadInfo.threadToRunning(threadIndex) = false + val nums = sc.parallelize(threadIndex to threadIndex, 1) + TaskThreadInfo.threadToLock(threadIndex) = new Lock() + TaskThreadInfo.threadToStarted(threadIndex) = new CountDownLatch(1) + new Thread { + if (poolName != null) { + sc.addLocalProperties("spark.scheduler.cluster.fair.pool",poolName) + } + override def run() { + val ans = nums.map(number => { + TaskThreadInfo.threadToRunning(number) = true + TaskThreadInfo.threadToStarted(number).countDown() + TaskThreadInfo.threadToLock(number).jobWait() + TaskThreadInfo.threadToRunning(number) = false + number + }).collect() + assert(ans.toList === List(threadIndex)) + sem.release() + } + }.start() + } + + test("Local FIFO scheduler end-to-end test") { + System.setProperty("spark.cluster.schedulingmode", "FIFO") + sc = new SparkContext("local[4]", "test") + val sem = new Semaphore(0) + + createThread(1,null,sc,sem) + TaskThreadInfo.threadToStarted(1).await() + createThread(2,null,sc,sem) + TaskThreadInfo.threadToStarted(2).await() + createThread(3,null,sc,sem) + TaskThreadInfo.threadToStarted(3).await() + createThread(4,null,sc,sem) + TaskThreadInfo.threadToStarted(4).await() + // thread 5 and 6 (stage pending)must meet following two points + // 1. stages (taskSetManager) of jobs in thread 5 and 6 should be add to taskSetManager + // queue before executing TaskThreadInfo.threadToLock(1).jobFinished() + // 2. priority of stage in thread 5 should be prior to priority of stage in thread 6 + // So I just use "sleep" 1s here for each thread. + // TODO: any better solution? 
+ createThread(5,null,sc,sem) + Thread.sleep(1000) + createThread(6,null,sc,sem) + Thread.sleep(1000) + + assert(TaskThreadInfo.threadToRunning(1) === true) + assert(TaskThreadInfo.threadToRunning(2) === true) + assert(TaskThreadInfo.threadToRunning(3) === true) + assert(TaskThreadInfo.threadToRunning(4) === true) + assert(TaskThreadInfo.threadToRunning(5) === false) + assert(TaskThreadInfo.threadToRunning(6) === false) + + TaskThreadInfo.threadToLock(1).jobFinished() + TaskThreadInfo.threadToStarted(5).await() + + assert(TaskThreadInfo.threadToRunning(1) === false) + assert(TaskThreadInfo.threadToRunning(2) === true) + assert(TaskThreadInfo.threadToRunning(3) === true) + assert(TaskThreadInfo.threadToRunning(4) === true) + assert(TaskThreadInfo.threadToRunning(5) === true) + assert(TaskThreadInfo.threadToRunning(6) === false) + + TaskThreadInfo.threadToLock(3).jobFinished() + TaskThreadInfo.threadToStarted(6).await() + + assert(TaskThreadInfo.threadToRunning(1) === false) + assert(TaskThreadInfo.threadToRunning(2) === true) + assert(TaskThreadInfo.threadToRunning(3) === false) + assert(TaskThreadInfo.threadToRunning(4) === true) + assert(TaskThreadInfo.threadToRunning(5) === true) + assert(TaskThreadInfo.threadToRunning(6) === true) + + TaskThreadInfo.threadToLock(2).jobFinished() + TaskThreadInfo.threadToLock(4).jobFinished() + TaskThreadInfo.threadToLock(5).jobFinished() + TaskThreadInfo.threadToLock(6).jobFinished() + sem.acquire(6) + } + + test("Local fair scheduler end-to-end test") { + sc = new SparkContext("local[8]", "LocalSchedulerSuite") + val sem = new Semaphore(0) + System.setProperty("spark.cluster.schedulingmode", "FAIR") + val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile() + System.setProperty("spark.fairscheduler.allocation.file", xmlPath) + + createThread(10,"1",sc,sem) + TaskThreadInfo.threadToStarted(10).await() + createThread(20,"2",sc,sem) + TaskThreadInfo.threadToStarted(20).await() + createThread(30,"3",sc,sem) 
+ TaskThreadInfo.threadToStarted(30).await() + + assert(TaskThreadInfo.threadToRunning(10) === true) + assert(TaskThreadInfo.threadToRunning(20) === true) + assert(TaskThreadInfo.threadToRunning(30) === true) + + createThread(11,"1",sc,sem) + TaskThreadInfo.threadToStarted(11).await() + createThread(21,"2",sc,sem) + TaskThreadInfo.threadToStarted(21).await() + createThread(31,"3",sc,sem) + TaskThreadInfo.threadToStarted(31).await() + + assert(TaskThreadInfo.threadToRunning(11) === true) + assert(TaskThreadInfo.threadToRunning(21) === true) + assert(TaskThreadInfo.threadToRunning(31) === true) + + createThread(12,"1",sc,sem) + TaskThreadInfo.threadToStarted(12).await() + createThread(22,"2",sc,sem) + TaskThreadInfo.threadToStarted(22).await() + createThread(32,"3",sc,sem) + + assert(TaskThreadInfo.threadToRunning(12) === true) + assert(TaskThreadInfo.threadToRunning(22) === true) + assert(TaskThreadInfo.threadToRunning(32) === false) + + TaskThreadInfo.threadToLock(10).jobFinished() + TaskThreadInfo.threadToStarted(32).await() + + assert(TaskThreadInfo.threadToRunning(32) === true) + + //1. Similar with above scenario, sleep 1s for stage of 23 and 33 to be added to taskSetManager + // queue so that cluster will assign free cpu core to stage 23 after stage 11 finished. + //2. priority of 23 and 33 will be meaningless as using fair scheduler here. 
+ createThread(23,"2",sc,sem) + createThread(33,"3",sc,sem) + Thread.sleep(1000) + + TaskThreadInfo.threadToLock(11).jobFinished() + TaskThreadInfo.threadToStarted(23).await() + + assert(TaskThreadInfo.threadToRunning(23) === true) + assert(TaskThreadInfo.threadToRunning(33) === false) + + TaskThreadInfo.threadToLock(12).jobFinished() + TaskThreadInfo.threadToStarted(33).await() + + assert(TaskThreadInfo.threadToRunning(33) === true) + + TaskThreadInfo.threadToLock(20).jobFinished() + TaskThreadInfo.threadToLock(21).jobFinished() + TaskThreadInfo.threadToLock(22).jobFinished() + TaskThreadInfo.threadToLock(23).jobFinished() + TaskThreadInfo.threadToLock(30).jobFinished() + TaskThreadInfo.threadToLock(31).jobFinished() + TaskThreadInfo.threadToLock(32).jobFinished() + TaskThreadInfo.threadToLock(33).jobFinished() + + sem.acquire(11) + } +} diff --git a/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala index 42a87d8b90..48aa67c543 100644 --- a/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala @@ -77,7 +77,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc class SaveStageInfo extends SparkListener { val stageInfos = mutable.Buffer[StageInfo]() - def onStageCompleted(stage: StageCompleted) { + override def onStageCompleted(stage: StageCompleted) { stageInfos += stage.stageInfo } } diff --git a/examples/pom.xml b/examples/pom.xml index c42d2bcdb9..3e5271ec2f 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -34,6 +34,41 @@ <artifactId>scalacheck_${scala.version}</artifactId> <scope>test</scope> </dependency> + <dependency> + <groupId>org.apache.cassandra</groupId> + <artifactId>cassandra-all</artifactId> + <version>1.2.5</version> + <exclusions> + <exclusion> + <groupId>com.google.guava</groupId> + <artifactId>guava</artifactId> + </exclusion> + <exclusion> + 
<groupId>com.googlecode.concurrentlinkedhashmap</groupId> + <artifactId>concurrentlinkedhashmap-lru</artifactId> + </exclusion> + <exclusion> + <groupId>com.ning</groupId> + <artifactId>compress-lzf</artifactId> + </exclusion> + <exclusion> + <groupId>io.netty</groupId> + <artifactId>netty</artifactId> + </exclusion> + <exclusion> + <groupId>jline</groupId> + <artifactId>jline</artifactId> + </exclusion> + <exclusion> + <groupId>log4j</groupId> + <artifactId>log4j</artifactId> + </exclusion> + <exclusion> + <groupId>org.apache.cassandra.deps</groupId> + <artifactId>avro</artifactId> + </exclusion> + </exclusions> + </dependency> </dependencies> <build> <outputDirectory>target/scala-${scala.version}/classes</outputDirectory> @@ -67,6 +102,11 @@ <artifactId>hadoop-core</artifactId> <scope>provided</scope> </dependency> + <dependency> + <groupId>org.apache.hbase</groupId> + <artifactId>hbase</artifactId> + <version>0.94.6</version> + </dependency> </dependencies> <build> <plugins> @@ -105,6 +145,11 @@ <artifactId>hadoop-client</artifactId> <scope>provided</scope> </dependency> + <dependency> + <groupId>org.apache.hbase</groupId> + <artifactId>hbase</artifactId> + <version>0.94.6</version> + </dependency> </dependencies> <build> <plugins> diff --git a/examples/src/main/scala/spark/examples/CassandraTest.scala b/examples/src/main/scala/spark/examples/CassandraTest.scala new file mode 100644 index 0000000000..0fe1833e83 --- /dev/null +++ b/examples/src/main/scala/spark/examples/CassandraTest.scala @@ -0,0 +1,196 @@ +package spark.examples + +import org.apache.hadoop.mapreduce.Job +import org.apache.cassandra.hadoop.ColumnFamilyOutputFormat +import org.apache.cassandra.hadoop.ConfigHelper +import org.apache.cassandra.hadoop.ColumnFamilyInputFormat +import org.apache.cassandra.thrift._ +import spark.SparkContext +import spark.SparkContext._ +import java.nio.ByteBuffer +import java.util.SortedMap +import org.apache.cassandra.db.IColumn +import 
org.apache.cassandra.utils.ByteBufferUtil +import scala.collection.JavaConversions._ + + +/* + * This example demonstrates using Spark with Cassandra with the New Hadoop API and Cassandra + * support for Hadoop. + * + * To run this example, run this file with the following command params - + * <spark_master> <cassandra_node> <cassandra_port> + * + * So if you want to run this on localhost this will be, + * local[3] localhost 9160 + * + * The example makes some assumptions: + * 1. You have already created a keyspace called casDemo and it has a column family named Words + * 2. There are column family has a column named "para" which has test content. + * + * You can create the content by running the following script at the bottom of this file with + * cassandra-cli. + * + */ +object CassandraTest { + + def main(args: Array[String]) { + + // Get a SparkContext + val sc = new SparkContext(args(0), "casDemo") + + // Build the job configuration with ConfigHelper provided by Cassandra + val job = new Job() + job.setInputFormatClass(classOf[ColumnFamilyInputFormat]) + + val host: String = args(1) + val port: String = args(2) + + ConfigHelper.setInputInitialAddress(job.getConfiguration(), host) + ConfigHelper.setInputRpcPort(job.getConfiguration(), port) + ConfigHelper.setOutputInitialAddress(job.getConfiguration(), host) + ConfigHelper.setOutputRpcPort(job.getConfiguration(), port) + ConfigHelper.setInputColumnFamily(job.getConfiguration(), "casDemo", "Words") + ConfigHelper.setOutputColumnFamily(job.getConfiguration(), "casDemo", "WordCount") + + val predicate = new SlicePredicate() + val sliceRange = new SliceRange() + sliceRange.setStart(Array.empty[Byte]) + sliceRange.setFinish(Array.empty[Byte]) + predicate.setSlice_range(sliceRange) + ConfigHelper.setInputSlicePredicate(job.getConfiguration(), predicate) + + ConfigHelper.setInputPartitioner(job.getConfiguration(), "Murmur3Partitioner") + ConfigHelper.setOutputPartitioner(job.getConfiguration(), "Murmur3Partitioner") + 
+ // Make a new Hadoop RDD + val casRdd = sc.newAPIHadoopRDD( + job.getConfiguration(), + classOf[ColumnFamilyInputFormat], + classOf[ByteBuffer], + classOf[SortedMap[ByteBuffer, IColumn]]) + + // Let us first get all the paragraphs from the retrieved rows + val paraRdd = casRdd.map { + case (key, value) => { + ByteBufferUtil.string(value.get(ByteBufferUtil.bytes("para")).value()) + } + } + + // Lets get the word count in paras + val counts = paraRdd.flatMap(p => p.split(" ")).map(word => (word, 1)).reduceByKey(_ + _) + + counts.collect().foreach { + case (word, count) => println(word + ":" + count) + } + + counts.map { + case (word, count) => { + val colWord = new org.apache.cassandra.thrift.Column() + colWord.setName(ByteBufferUtil.bytes("word")) + colWord.setValue(ByteBufferUtil.bytes(word)) + colWord.setTimestamp(System.currentTimeMillis) + + val colCount = new org.apache.cassandra.thrift.Column() + colCount.setName(ByteBufferUtil.bytes("wcount")) + colCount.setValue(ByteBufferUtil.bytes(count.toLong)) + colCount.setTimestamp(System.currentTimeMillis) + + val outputkey = ByteBufferUtil.bytes(word + "-COUNT-" + System.currentTimeMillis) + + val mutations: java.util.List[Mutation] = new Mutation() :: new Mutation() :: Nil + mutations.get(0).setColumn_or_supercolumn(new ColumnOrSuperColumn()) + mutations.get(0).column_or_supercolumn.setColumn(colWord) + mutations.get(1).setColumn_or_supercolumn(new ColumnOrSuperColumn()) + mutations.get(1).column_or_supercolumn.setColumn(colCount) + (outputkey, mutations) + } + }.saveAsNewAPIHadoopFile("casDemo", classOf[ByteBuffer], classOf[List[Mutation]], + classOf[ColumnFamilyOutputFormat], job.getConfiguration) + } +} + +/* +create keyspace casDemo; +use casDemo; + +create column family WordCount with comparator = UTF8Type; +update column family WordCount with column_metadata = + [{column_name: word, validation_class: UTF8Type}, + {column_name: wcount, validation_class: LongType}]; + +create column family Words with 
comparator = UTF8Type; +update column family Words with column_metadata = + [{column_name: book, validation_class: UTF8Type}, + {column_name: para, validation_class: UTF8Type}]; + +assume Words keys as utf8; + +set Words['3musk001']['book'] = 'The Three Musketeers'; +set Words['3musk001']['para'] = 'On the first Monday of the month of April, 1625, the market + town of Meung, in which the author of ROMANCE OF THE ROSE was born, appeared to + be in as perfect a state of revolution as if the Huguenots had just made + a second La Rochelle of it. Many citizens, seeing the women flying + toward the High Street, leaving their children crying at the open doors, + hastened to don the cuirass, and supporting their somewhat uncertain + courage with a musket or a partisan, directed their steps toward the + hostelry of the Jolly Miller, before which was gathered, increasing + every minute, a compact group, vociferous and full of curiosity.'; + +set Words['3musk002']['book'] = 'The Three Musketeers'; +set Words['3musk002']['para'] = 'In those times panics were common, and few days passed without + some city or other registering in its archives an event of this kind. There were + nobles, who made war against each other; there was the king, who made + war against the cardinal; there was Spain, which made war against the + king. Then, in addition to these concealed or public, secret or open + wars, there were robbers, mendicants, Huguenots, wolves, and scoundrels, + who made war upon everybody. The citizens always took up arms readily + against thieves, wolves or scoundrels, often against nobles or + Huguenots, sometimes against the king, but never against cardinal or + Spain. It resulted, then, from this habit that on the said first Monday + of April, 1625, the citizens, on hearing the clamor, and seeing neither + the red-and-yellow standard nor the livery of the Duc de Richelieu, + rushed toward the hostel of the Jolly Miller. 
When arrived there, the + cause of the hubbub was apparent to all'; + +set Words['3musk003']['book'] = 'The Three Musketeers'; +set Words['3musk003']['para'] = 'You ought, I say, then, to husband the means you have, however + large the sum may be; but you ought also to endeavor to perfect yourself in + the exercises becoming a gentleman. I will write a letter today to the + Director of the Royal Academy, and tomorrow he will admit you without + any expense to yourself. Do not refuse this little service. Our + best-born and richest gentlemen sometimes solicit it without being able + to obtain it. You will learn horsemanship, swordsmanship in all its + branches, and dancing. You will make some desirable acquaintances; and + from time to time you can call upon me, just to tell me how you are + getting on, and to say whether I can be of further service to you.'; + + +set Words['thelostworld001']['book'] = 'The Lost World'; +set Words['thelostworld001']['para'] = 'She sat with that proud, delicate profile of hers outlined + against the red curtain. How beautiful she was! And yet how aloof! We had been + friends, quite good friends; but never could I get beyond the same + comradeship which I might have established with one of my + fellow-reporters upon the Gazette,--perfectly frank, perfectly kindly, + and perfectly unsexual. My instincts are all against a woman being too + frank and at her ease with me. It is no compliment to a man. Where + the real sex feeling begins, timidity and distrust are its companions, + heritage from old wicked days when love and violence went often hand in + hand. The bent head, the averted eye, the faltering voice, the wincing + figure--these, and not the unshrinking gaze and frank reply, are the + true signals of passion. 
Even in my short life I had learned as much + as that--or had inherited it in that race memory which we call instinct.'; + +set Words['thelostworld002']['book'] = 'The Lost World'; +set Words['thelostworld002']['para'] = 'I always liked McArdle, the crabbed, old, round-backed, + red-headed news editor, and I rather hoped that he liked me. Of course, Beaumont was + the real boss; but he lived in the rarefied atmosphere of some Olympian + height from which he could distinguish nothing smaller than an + international crisis or a split in the Cabinet. Sometimes we saw him + passing in lonely majesty to his inner sanctum, with his eyes staring + vaguely and his mind hovering over the Balkans or the Persian Gulf. He + was above and beyond us. But McArdle was his first lieutenant, and it + was he that we knew. The old man nodded as I entered the room, and he + pushed his spectacles far up on his bald forehead.'; + +*/ diff --git a/examples/src/main/scala/spark/examples/HBaseTest.scala b/examples/src/main/scala/spark/examples/HBaseTest.scala new file mode 100644 index 0000000000..6e910154d4 --- /dev/null +++ b/examples/src/main/scala/spark/examples/HBaseTest.scala @@ -0,0 +1,35 @@ +package spark.examples + +import spark._ +import spark.rdd.NewHadoopRDD +import org.apache.hadoop.hbase.{HBaseConfiguration, HTableDescriptor} +import org.apache.hadoop.hbase.client.HBaseAdmin +import org.apache.hadoop.hbase.mapreduce.TableInputFormat + +object HBaseTest { + def main(args: Array[String]) { + val sc = new SparkContext(args(0), "HBaseTest", + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + + val conf = HBaseConfiguration.create() + + // Other options for configuring scan behavior are available. 
More information available at + // http://hbase.apache.org/apidocs/org/apache/hadoop/hbase/mapreduce/TableInputFormat.html + conf.set(TableInputFormat.INPUT_TABLE, args(1)) + + // Initialize hBase table if necessary + val admin = new HBaseAdmin(conf) + if(!admin.isTableAvailable(args(1))) { + val tableDesc = new HTableDescriptor(args(1)) + admin.createTable(tableDesc) + } + + val hBaseRDD = sc.newAPIHadoopRDD(conf, classOf[TableInputFormat], + classOf[org.apache.hadoop.hbase.io.ImmutableBytesWritable], + classOf[org.apache.hadoop.hbase.client.Result]) + + hBaseRDD.count() + + System.exit(0) + } +}
\ No newline at end of file @@ -60,7 +60,7 @@ <cdh.version>4.1.2</cdh.version> <log4j.version>1.2.17</log4j.version> - <PermGen>0m</PermGen> + <PermGen>64m</PermGen> <MaxPermGen>512m</MaxPermGen> </properties> @@ -190,9 +190,9 @@ <version>0.8.4</version> </dependency> <dependency> - <groupId>asm</groupId> - <artifactId>asm-all</artifactId> - <version>3.3.1</version> + <groupId>org.ow2.asm</groupId> + <artifactId>asm</artifactId> + <version>4.0</version> </dependency> <dependency> <groupId>com.google.protobuf</groupId> @@ -395,10 +395,8 @@ <jvmArgs> <jvmArg>-Xms64m</jvmArg> <jvmArg>-Xmx1024m</jvmArg> - <jvmArg>-XX:PermSize</jvmArg> - <jvmArg>${PermGen}</jvmArg> - <jvmArg>-XX:MaxPermSize</jvmArg> - <jvmArg>${MaxPermGen}</jvmArg> + <jvmArg>-XX:PermSize=${PermGen}</jvmArg> + <jvmArg>-XX:MaxPermSize=${MaxPermGen}</jvmArg> </jvmArgs> <javacArgs> <javacArg>-source</javacArg> diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 0ea23b446f..faf6e2ae8e 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -54,7 +54,7 @@ object SparkBuild extends Build { // Fork new JVMs for tests and set Java options for those fork := true, - javaOptions += "-Xmx2g", + javaOptions += "-Xmx2500m", // Only allow one test at a time, even across projects, since they run in the same JVM concurrentRestrictions in Global += Tags.limit(Tags.Test, 1), @@ -125,12 +125,13 @@ object SparkBuild extends Build { publishMavenStyle in MavenCompile := true, publishLocal in MavenCompile <<= publishTask(publishLocalConfiguration in MavenCompile, deliverLocal), publishLocalBoth <<= Seq(publishLocal in MavenCompile, publishLocal).dependOn - ) + ) ++ net.virtualvoid.sbt.graph.Plugin.graphSettings - val slf4jVersion = "1.6.1" + val slf4jVersion = "1.7.2" val excludeJackson = ExclusionRule(organization = "org.codehaus.jackson") val excludeNetty = ExclusionRule(organization = "org.jboss.netty") + val excludeAsm = ExclusionRule(organization = "asm") def coreSettings = sharedSettings 
++ Seq( name := "spark-core", @@ -148,7 +149,7 @@ object SparkBuild extends Build { "org.slf4j" % "slf4j-log4j12" % slf4jVersion, "commons-daemon" % "commons-daemon" % "1.0.10", "com.ning" % "compress-lzf" % "0.8.4", - "asm" % "asm-all" % "3.3.1", + "org.ow2.asm" % "asm" % "4.0", "com.google.protobuf" % "protobuf-java" % "2.4.1", "de.javakaffee" % "kryo-serializers" % "0.22", "com.typesafe.akka" % "akka-actor" % "2.0.3" excludeAll(excludeNetty), @@ -201,7 +202,20 @@ object SparkBuild extends Build { def examplesSettings = sharedSettings ++ Seq( name := "spark-examples", - libraryDependencies ++= Seq("com.twitter" % "algebird-core_2.9.2" % "0.1.11") + libraryDependencies ++= Seq( + "com.twitter" % "algebird-core_2.9.2" % "0.1.11", + + "org.apache.hbase" % "hbase" % "0.94.6" excludeAll(excludeNetty, excludeAsm), + + "org.apache.cassandra" % "cassandra-all" % "1.2.5" + exclude("com.google.guava", "guava") + exclude("com.googlecode.concurrentlinkedhashmap", "concurrentlinkedhashmap-lru") + exclude("com.ning","compress-lzf") + exclude("io.netty", "netty") + exclude("jline","jline") + exclude("log4j","log4j") + exclude("org.apache.cassandra.deps", "avro") + ) ) def bagelSettings = sharedSettings ++ Seq(name := "spark-bagel") @@ -210,7 +224,7 @@ object SparkBuild extends Build { name := "spark-streaming", libraryDependencies ++= Seq( "org.apache.flume" % "flume-ng-sdk" % "1.2.0" % "compile" excludeAll(excludeNetty), - "com.github.sgroschupf" % "zkclient" % "0.1", + "com.github.sgroschupf" % "zkclient" % "0.1" excludeAll(excludeNetty), "org.twitter4j" % "twitter4j-stream" % "3.0.3" excludeAll(excludeNetty), "com.typesafe.akka" % "akka-zeromq" % "2.0.3" excludeAll(excludeNetty) ) diff --git a/project/plugins.sbt b/project/plugins.sbt index d4f2442872..f806e66481 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -16,3 +16,5 @@ addSbtPlugin("io.spray" %% "sbt-twirl" % "0.6.1") //resolvers += Resolver.url("sbt-plugin-releases", new 
URL("http://scalasbt.artifactoryonline.com/scalasbt/sbt-plugin-releases/"))(Resolver.ivyStylePatterns) //addSbtPlugin("com.jsuereth" % "xsbt-gpg-plugin" % "0.6") + +addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.7.3") diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py new file mode 100644 index 0000000000..78a2da1e18 --- /dev/null +++ b/python/pyspark/daemon.py @@ -0,0 +1,158 @@ +import os +import sys +import multiprocessing +from ctypes import c_bool +from errno import EINTR, ECHILD +from socket import socket, AF_INET, SOCK_STREAM, SOMAXCONN +from signal import signal, SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN +from pyspark.worker import main as worker_main +from pyspark.serializers import write_int + +try: + POOLSIZE = multiprocessing.cpu_count() +except NotImplementedError: + POOLSIZE = 4 + +exit_flag = multiprocessing.Value(c_bool, False) + + +def should_exit(): + global exit_flag + return exit_flag.value + + +def compute_real_exit_code(exit_code): + # SystemExit's code can be integer or string, but os._exit only accepts integers + import numbers + if isinstance(exit_code, numbers.Integral): + return exit_code + else: + return 1 + + +def worker(listen_sock): + # Redirect stdout to stderr + os.dup2(2, 1) + + # Manager sends SIGHUP to request termination of workers in the pool + def handle_sighup(*args): + assert should_exit() + signal(SIGHUP, handle_sighup) + + # Cleanup zombie children + def handle_sigchld(*args): + pid = status = None + try: + while (pid, status) != (0, 0): + pid, status = os.waitpid(0, os.WNOHANG) + except EnvironmentError as err: + if err.errno == EINTR: + # retry + handle_sigchld() + elif err.errno != ECHILD: + raise + signal(SIGCHLD, handle_sigchld) + + # Handle clients + while not should_exit(): + # Wait until a client arrives or we have to exit + sock = None + while not should_exit() and sock is None: + try: + sock, addr = listen_sock.accept() + except EnvironmentError as err: + if err.errno != EINTR: + 
raise + + if sock is not None: + # Fork a child to handle the client. + # The client is handled in the child so that the manager + # never receives SIGCHLD unless a worker crashes. + if os.fork() == 0: + # Leave the worker pool + signal(SIGHUP, SIG_DFL) + listen_sock.close() + # Handle the client then exit + sockfile = sock.makefile() + exit_code = 0 + try: + worker_main(sockfile, sockfile) + except SystemExit as exc: + exit_code = exc.code + finally: + sockfile.close() + sock.close() + os._exit(compute_real_exit_code(exit_code)) + else: + sock.close() + + +def launch_worker(listen_sock): + if os.fork() == 0: + try: + worker(listen_sock) + except Exception as err: + import traceback + traceback.print_exc() + os._exit(1) + else: + assert should_exit() + os._exit(0) + + +def manager(): + # Create a new process group to corral our children + os.setpgid(0, 0) + + # Create a listening socket on the AF_INET loopback interface + listen_sock = socket(AF_INET, SOCK_STREAM) + listen_sock.bind(('127.0.0.1', 0)) + listen_sock.listen(max(1024, 2 * POOLSIZE, SOMAXCONN)) + listen_host, listen_port = listen_sock.getsockname() + write_int(listen_port, sys.stdout) + + # Launch initial worker pool + for idx in range(POOLSIZE): + launch_worker(listen_sock) + listen_sock.close() + + def shutdown(): + global exit_flag + exit_flag.value = True + + # Gracefully exit on SIGTERM, don't die on SIGHUP + signal(SIGTERM, lambda signum, frame: shutdown()) + signal(SIGHUP, SIG_IGN) + + # Cleanup zombie children + def handle_sigchld(*args): + try: + pid, status = os.waitpid(0, os.WNOHANG) + if status != 0 and not should_exit(): + raise RuntimeError("worker crashed: %s, %s" % (pid, status)) + except EnvironmentError as err: + if err.errno not in (ECHILD, EINTR): + raise + signal(SIGCHLD, handle_sigchld) + + # Initialization complete + sys.stdout.close() + try: + while not should_exit(): + try: + # Spark tells us to exit by closing stdin + if os.read(0, 512) == '': + shutdown() + except 
EnvironmentError as err: + if err.errno != EINTR: + shutdown() + raise + finally: + signal(SIGTERM, SIG_DFL) + exit_flag.value = True + # Send SIGHUP to notify workers of shutdown + os.kill(0, SIGHUP) + + +if __name__ == '__main__': + manager() diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 115cf28cc2..5a95144983 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -46,6 +46,10 @@ def read_long(stream): return struct.unpack("!q", length)[0] +def write_long(value, stream): + stream.write(struct.pack("!q", value)) + + def read_int(stream): length = stream.read(4) if length == "": diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 6a1962d267..1e34d47365 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -12,6 +12,7 @@ import unittest from pyspark.context import SparkContext from pyspark.files import SparkFiles from pyspark.java_gateway import SPARK_HOME +from pyspark.serializers import read_int class PySparkTestCase(unittest.TestCase): @@ -117,5 +118,47 @@ class TestIO(PySparkTestCase): self.sc.parallelize([1]).foreach(func) +class TestDaemon(unittest.TestCase): + def connect(self, port): + from socket import socket, AF_INET, SOCK_STREAM + sock = socket(AF_INET, SOCK_STREAM) + sock.connect(('127.0.0.1', port)) + # send a split index of -1 to shutdown the worker + sock.send("\xFF\xFF\xFF\xFF") + sock.close() + return True + + def do_termination_test(self, terminator): + from subprocess import Popen, PIPE + from errno import ECONNREFUSED + + # start daemon + daemon_path = os.path.join(os.path.dirname(__file__), "daemon.py") + daemon = Popen([sys.executable, daemon_path], stdin=PIPE, stdout=PIPE) + + # read the port number + port = read_int(daemon.stdout) + + # daemon should accept connections + self.assertTrue(self.connect(port)) + + # request shutdown + terminator(daemon) + time.sleep(1) + + # daemon should no longer accept connections + with 
self.assertRaises(EnvironmentError) as trap: + self.connect(port) + self.assertEqual(trap.exception.errno, ECONNREFUSED) + + def test_termination_stdin(self): + """Ensure that daemon and workers terminate when stdin is closed.""" + self.do_termination_test(lambda daemon: daemon.stdin.close()) + + def test_termination_sigterm(self): + """Ensure that daemon and workers terminate on SIGTERM.""" + from signal import SIGTERM + self.do_termination_test(lambda daemon: os.kill(daemon.pid, SIGTERM)) + if __name__ == "__main__": unittest.main() diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 812e7a9da5..379bbfd4c2 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -3,6 +3,7 @@ Worker that receives input from Piped RDD. """ import os import sys +import time import traceback from base64 import standard_b64decode # CloudPickler needs to be imported so that depicklers are registered using the @@ -12,48 +13,60 @@ from pyspark.broadcast import Broadcast, _broadcastRegistry from pyspark.cloudpickle import CloudPickler from pyspark.files import SparkFiles from pyspark.serializers import write_with_length, read_with_length, write_int, \ - read_long, read_int, dump_pickle, load_pickle, read_from_pickle_file + read_long, write_long, read_int, dump_pickle, load_pickle, read_from_pickle_file -# Redirect stdout to stderr so that users must return values from functions. 
-old_stdout = os.fdopen(os.dup(1), 'w') -os.dup2(2, 1) +def load_obj(infile): + return load_pickle(standard_b64decode(infile.readline().strip())) -def load_obj(): - return load_pickle(standard_b64decode(sys.stdin.readline().strip())) +def report_times(outfile, boot, init, finish): + write_int(-3, outfile) + write_long(1000 * boot, outfile) + write_long(1000 * init, outfile) + write_long(1000 * finish, outfile) -def main(): - split_index = read_int(sys.stdin) - spark_files_dir = load_pickle(read_with_length(sys.stdin)) +def main(infile, outfile): + boot_time = time.time() + split_index = read_int(infile) + if split_index == -1: # for unit tests + return + spark_files_dir = load_pickle(read_with_length(infile)) SparkFiles._root_directory = spark_files_dir SparkFiles._is_running_on_worker = True sys.path.append(spark_files_dir) - num_broadcast_variables = read_int(sys.stdin) + num_broadcast_variables = read_int(infile) for _ in range(num_broadcast_variables): - bid = read_long(sys.stdin) - value = read_with_length(sys.stdin) + bid = read_long(infile) + value = read_with_length(infile) _broadcastRegistry[bid] = Broadcast(bid, load_pickle(value)) - func = load_obj() - bypassSerializer = load_obj() + func = load_obj(infile) + bypassSerializer = load_obj(infile) if bypassSerializer: dumps = lambda x: x else: dumps = dump_pickle - iterator = read_from_pickle_file(sys.stdin) + init_time = time.time() + iterator = read_from_pickle_file(infile) try: for obj in func(split_index, iterator): - write_with_length(dumps(obj), old_stdout) + write_with_length(dumps(obj), outfile) except Exception as e: - write_int(-2, old_stdout) - write_with_length(traceback.format_exc(), old_stdout) + write_int(-2, outfile) + write_with_length(traceback.format_exc(), outfile) sys.exit(-1) + finish_time = time.time() + report_times(outfile, boot_time, init_time, finish_time) # Mark the beginning of the accumulators section of the output - write_int(-1, old_stdout) + write_int(-1, outfile) for aid, 
accum in _accumulatorRegistry.items(): - write_with_length(dump_pickle((aid, accum._value)), old_stdout) + write_with_length(dump_pickle((aid, accum._value)), outfile) + write_int(-1, outfile) if __name__ == '__main__': - main() + # Redirect stdout to stderr so that users must return values from functions. + old_stdout = os.fdopen(os.dup(1), 'w') + os.dup2(2, 1) + main(sys.stdin, old_stdout) diff --git a/repl/src/main/scala/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/spark/repl/ExecutorClassLoader.scala index 13d81ec1cf..0e9aa863b5 100644 --- a/repl/src/main/scala/spark/repl/ExecutorClassLoader.scala +++ b/repl/src/main/scala/spark/repl/ExecutorClassLoader.scala @@ -8,7 +8,6 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.objectweb.asm._ -import org.objectweb.asm.commons.EmptyVisitor import org.objectweb.asm.Opcodes._ @@ -83,7 +82,7 @@ extends ClassLoader(parent) { } class ConstructorCleaner(className: String, cv: ClassVisitor) -extends ClassAdapter(cv) { +extends ClassVisitor(ASM4, cv) { override def visitMethod(access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): MethodVisitor = { val mv = cv.visitMethod(access, name, desc, sig, exceptions) @@ -132,10 +132,14 @@ if [ -e "$FWDIR/lib_managed" ]; then CLASSPATH="$CLASSPATH:$FWDIR/lib_managed/bundles/*" fi CLASSPATH="$CLASSPATH:$REPL_DIR/lib/*" +# Add the shaded JAR for Maven builds if [ -e $REPL_BIN_DIR/target ]; then for jar in `find "$REPL_BIN_DIR/target" -name 'spark-repl-*-shaded-hadoop*.jar'`; do CLASSPATH="$CLASSPATH:$jar" done + # The shaded JAR doesn't contain examples, so include those separately + EXAMPLES_JAR=`ls "$EXAMPLES_DIR/target/spark-examples"*[0-9T].jar` + CLASSPATH+=":$EXAMPLES_JAR" fi CLASSPATH="$CLASSPATH:$BAGEL_DIR/target/scala-$SCALA_VERSION/classes" for jar in `find $PYSPARK_DIR/lib -name '*jar'`; do @@ -148,9 +152,9 @@ if [ -e 
"$EXAMPLES_DIR/target/scala-$SCALA_VERSION/spark-examples"*[0-9T].jar ]; # Use the JAR from the SBT build export SPARK_EXAMPLES_JAR=`ls "$EXAMPLES_DIR/target/scala-$SCALA_VERSION/spark-examples"*[0-9T].jar` fi -if [ -e "$EXAMPLES_DIR/target/spark-examples-"*hadoop[12].jar ]; then +if [ -e "$EXAMPLES_DIR/target/spark-examples"*[0-9T].jar ]; then # Use the JAR from the Maven build - export SPARK_EXAMPLES_JAR=`ls "$EXAMPLES_DIR/target/spark-examples-"*hadoop[12].jar` + export SPARK_EXAMPLES_JAR=`ls "$EXAMPLES_DIR/target/spark-examples"*[0-9T].jar` fi # Add hadoop conf dir - else FileSystem.*, etc fail ! |