diff options
author | Reynold Xin <rxin@apache.org> | 2014-01-13 16:21:26 -0800 |
---|---|---|
committer | Reynold Xin <rxin@apache.org> | 2014-01-13 16:21:26 -0800 |
commit | e2d25d2dfeb1d43d1e36f169250d8efef4ac232a (patch) | |
tree | d911a37f5aacc89bc3a1c76d41842e1c156aec6a /core | |
parent | 8038da232870fe016e73122a2ef110ac8e56ca1e (diff) | |
parent | b93f9d42f21f03163734ef97b2871db945e166da (diff) | |
download | spark-e2d25d2dfeb1d43d1e36f169250d8efef4ac232a.tar.gz spark-e2d25d2dfeb1d43d1e36f169250d8efef4ac232a.tar.bz2 spark-e2d25d2dfeb1d43d1e36f169250d8efef4ac232a.zip |
Merge branch 'master' into graphx
Diffstat (limited to 'core')
122 files changed, 3167 insertions, 621 deletions
diff --git a/core/pom.xml b/core/pom.xml index aac0a9d11e..9e5a450d57 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -99,6 +99,11 @@ <artifactId>akka-slf4j_${scala.binary.version}</artifactId> </dependency> <dependency> + <groupId>${akka.group}</groupId> + <artifactId>akka-testkit_${scala.binary.version}</artifactId> + <scope>test</scope> + </dependency> + <dependency> <groupId>org.scala-lang</groupId> <artifactId>scala-library</artifactId> </dependency> @@ -166,6 +171,11 @@ <scope>test</scope> </dependency> <dependency> + <groupId>org.mockito</groupId> + <artifactId>mockito-all</artifactId> + <scope>test</scope> + </dependency> + <dependency> <groupId>org.scalacheck</groupId> <artifactId>scalacheck_${scala.binary.version}</artifactId> <scope>test</scope> diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults.properties b/core/src/main/resources/org/apache/spark/log4j-defaults.properties index d72dbadc39..f7f8535594 100644 --- a/core/src/main/resources/org/apache/spark/log4j-defaults.properties +++ b/core/src/main/resources/org/apache/spark/log4j-defaults.properties @@ -1,8 +1,11 @@ # Set everything to be logged to the console log4j.rootCategory=INFO, console log4j.appender.console=org.apache.log4j.ConsoleAppender +log4j.appender.console.target=System.err log4j.appender.console.layout=org.apache.log4j.PatternLayout log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n -# Ignore messages below warning level from Jetty, because it's a bit verbose +# Settings to quiet third party logs that are too verbose log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO +log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala index 5f73d234aa..e89ac28b8e 100644 --- a/core/src/main/scala/org/apache/spark/Accumulators.scala +++ b/core/src/main/scala/org/apache/spark/Accumulators.scala @@ -218,7 +218,7 @@ private object Accumulators { def newId: Long = synchronized { lastId += 1 - return lastId + lastId } def register(a: Accumulable[_, _], original: Boolean): Unit = synchronized { diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala index 1a2ec55876..8b30cd4bfe 100644 --- a/core/src/main/scala/org/apache/spark/Aggregator.scala +++ b/core/src/main/scala/org/apache/spark/Aggregator.scala @@ -17,7 +17,7 @@ package org.apache.spark -import org.apache.spark.util.AppendOnlyMap +import org.apache.spark.util.collection.{AppendOnlyMap, ExternalAppendOnlyMap} /** * A set of functions used to aggregate data. @@ -31,30 +31,51 @@ case class Aggregator[K, V, C] ( mergeValue: (C, V) => C, mergeCombiners: (C, C) => C) { + private val sparkConf = SparkEnv.get.conf + private val externalSorting = sparkConf.getBoolean("spark.shuffle.externalSorting", true) + def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]]) : Iterator[(K, C)] = { - val combiners = new AppendOnlyMap[K, C] - var kv: Product2[K, V] = null - val update = (hadValue: Boolean, oldValue: C) => { - if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2) - } - while (iter.hasNext) { - kv = iter.next() - combiners.changeValue(kv._1, update) + if (!externalSorting) { + val combiners = new AppendOnlyMap[K,C] + var kv: Product2[K, V] = null + val update = (hadValue: Boolean, oldValue: C) => { + if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2) + } + while (iter.hasNext) { + kv = iter.next() + combiners.changeValue(kv._1, update) + } + combiners.iterator + } else { + val combiners = + new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners) + while (iter.hasNext) { + val (k, v) = iter.next() + combiners.insert(k, v) + } + combiners.iterator } - combiners.iterator } def combineCombinersByKey(iter: Iterator[(K, C)]) : Iterator[(K, C)] = { - val combiners = new AppendOnlyMap[K, C] - var kc: (K, C) = null - val update = (hadValue: Boolean, oldValue: C) => { - if (hadValue) mergeCombiners(oldValue, kc._2) else kc._2 + if (!externalSorting) { + val combiners = new AppendOnlyMap[K,C] + var kc: Product2[K, C] = null + val update = (hadValue: Boolean, oldValue: C) => { + if (hadValue) mergeCombiners(oldValue, kc._2) else kc._2 + } + while (iter.hasNext) { + kc = iter.next() + combiners.changeValue(kc._1, update) + } + combiners.iterator + } else { + val combiners = new ExternalAppendOnlyMap[K, C, C](identity, mergeCombiners, mergeCombiners) + while (iter.hasNext) { + val (k, c) = iter.next() + combiners.insert(k, c) + } + combiners.iterator } - while (iter.hasNext) { - kc = iter.next() - combiners.changeValue(kc._1, update) - } - combiners.iterator } } - diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala index 519ecde50a..8e5dd8a850 100644 --- a/core/src/main/scala/org/apache/spark/CacheManager.scala +++ b/core/src/main/scala/org/apache/spark/CacheManager.scala @@ -38,7 +38,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { blockManager.get(key) match { case Some(values) => // Partition is already materialized, so just return its values - return new InterruptibleIterator(context, values.asInstanceOf[Iterator[T]]) + new InterruptibleIterator(context, values.asInstanceOf[Iterator[T]]) case None => // Mark the split as loading (unless someone else marks it first) @@ -74,7 +74,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { val elements = new ArrayBuffer[Any] elements ++= computedValues blockManager.put(key, elements, storageLevel, tellMaster = true) - return elements.iterator.asInstanceOf[Iterator[T]] + elements.iterator.asInstanceOf[Iterator[T]] } finally { loading.synchronized { loading.remove(key) diff --git a/core/src/main/scala/org/apache/spark/HttpFileServer.scala b/core/src/main/scala/org/apache/spark/HttpFileServer.scala index ad1ee20045..a885898ad4 100644 --- a/core/src/main/scala/org/apache/spark/HttpFileServer.scala +++ b/core/src/main/scala/org/apache/spark/HttpFileServer.scala @@ -47,17 +47,17 @@ private[spark] class HttpFileServer extends Logging { def addFile(file: File) : String = { addFileToDir(file, fileDir) - return serverUri + "/files/" + file.getName + serverUri + "/files/" + file.getName } def addJar(file: File) : String = { addFileToDir(file, jarDir) - return serverUri + "/jars/" + file.getName + serverUri + "/jars/" + file.getName } def addFileToDir(file: File, dir: File) : String = { Files.copy(file, new File(dir, file.getName)) - return dir + "/" + file.getName + dir + "/" + file.getName } } diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala index 4a34989e50..9063cae87e 100644 --- a/core/src/main/scala/org/apache/spark/Logging.scala +++ b/core/src/main/scala/org/apache/spark/Logging.scala @@ -41,7 +41,7 @@ trait Logging { } log_ = LoggerFactory.getLogger(className) } - return log_ + log_ } // Log methods that take only a String diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 77b8ca1cce..30d182b008 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -32,15 +32,16 @@ import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.{AkkaUtils, MetadataCleaner, MetadataCleanerType, TimeStampedHashMap, Utils} private[spark] sealed trait MapOutputTrackerMessage -private[spark] case class GetMapOutputStatuses(shuffleId: Int, requester: String) +private[spark] case class GetMapOutputStatuses(shuffleId: Int) extends MapOutputTrackerMessage private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster) extends Actor with Logging { def receive = { - case GetMapOutputStatuses(shuffleId: Int, requester: String) => - logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + requester) + case GetMapOutputStatuses(shuffleId: Int) => + val hostPort = sender.path.address.hostPort + logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort) sender ! tracker.getSerializedMapOutputStatuses(shuffleId) case StopMapOutputTracker => @@ -119,11 +120,10 @@ private[spark] class MapOutputTracker(conf: SparkConf) extends Logging { if (fetchedStatuses == null) { // We won the race to fetch the output locs; do so logInfo("Doing the fetch; tracker actor = " + trackerActor) - val hostPort = Utils.localHostPort(conf) // This try-finally prevents hangs due to timeouts: try { val fetchedBytes = - askTracker(GetMapOutputStatuses(shuffleId, hostPort)).asInstanceOf[Array[Byte]] + askTracker(GetMapOutputStatuses(shuffleId)).asInstanceOf[Array[Byte]] fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes) logInfo("Got the output locations") mapStatuses.put(shuffleId, fetchedStatuses) @@ -139,7 +139,7 @@ private[spark] class MapOutputTracker(conf: SparkConf) extends Logging { return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses) } } - else{ + else { throw new FetchFailedException(null, shuffleId, -1, reduceId, new Exception("Missing all output locations for shuffle " + shuffleId)) } diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index 31b0773bfe..fc0a749882 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -53,15 +53,16 @@ object Partitioner { return r.partitioner.get } if (rdd.context.conf.contains("spark.default.parallelism")) { - return new HashPartitioner(rdd.context.defaultParallelism) + new HashPartitioner(rdd.context.defaultParallelism) } else { - return new HashPartitioner(bySize.head.partitions.size) + new HashPartitioner(bySize.head.partitions.size) } } } /** - * A [[org.apache.spark.Partitioner]] that implements hash-based partitioning using Java's `Object.hashCode`. + * A [[org.apache.spark.Partitioner]] that implements hash-based partitioning using + * Java's `Object.hashCode`. * * Java arrays have hashCodes that are based on the arrays' identities rather than their contents, * so attempting to partition an RDD[Array[_]] or RDD[(Array[_], _)] using a HashPartitioner will @@ -84,8 +85,8 @@ class HashPartitioner(partitions: Int) extends Partitioner { } /** - * A [[org.apache.spark.Partitioner]] that partitions sortable records by range into roughly equal ranges. - * Determines the ranges by sampling the RDD passed in. + * A [[org.apache.spark.Partitioner]] that partitions sortable records by range into roughly + * equal ranges. The ranges are determined by sampling the content of the RDD passed in. */ class RangePartitioner[K <% Ordered[K]: ClassTag, V]( partitions: Int, diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 0e47f4e442..55ac76bf63 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -31,9 +31,9 @@ import scala.reflect.{ClassTag, classTag} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.io.{ArrayWritable, BooleanWritable, BytesWritable, DoubleWritable, -FloatWritable, IntWritable, LongWritable, NullWritable, Text, Writable} + FloatWritable, IntWritable, LongWritable, NullWritable, Text, Writable} import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf, SequenceFileInputFormat, -TextInputFormat} + TextInputFormat} import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, Job => NewHadoopJob} import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat} import org.apache.mesos.MesosNativeLibrary @@ -49,7 +49,7 @@ import org.apache.spark.scheduler.local.LocalBackend import org.apache.spark.storage.{BlockManagerSource, RDDInfo, StorageStatus, StorageUtils} import org.apache.spark.ui.SparkUI import org.apache.spark.util.{Utils, TimeStampedHashMap, MetadataCleaner, MetadataCleanerType, -ClosureCleaner} + ClosureCleaner} /** * Main entry point for Spark functionality. A SparkContext represents the connection to a Spark @@ -116,7 +116,7 @@ class SparkContext( throw new SparkException("An application must be set in your configuration") } - if (conf.get("spark.logConf", "false").toBoolean) { + if (conf.getBoolean("spark.logConf", false)) { logInfo("Spark configuration:\n" + conf.toDebugString) } @@ -244,6 +244,10 @@ class SparkContext( localProperties.set(new Properties()) } + /** + * Set a local property that affects jobs submitted from this thread, such as the + * Spark fair scheduler pool. + */ def setLocalProperty(key: String, value: String) { if (localProperties.get() == null) { localProperties.set(new Properties()) @@ -255,6 +259,10 @@ class SparkContext( } } + /** + * Get a local property set in this thread, or null if it is missing. See + * [[org.apache.spark.SparkContext.setLocalProperty]]. + */ def getLocalProperty(key: String): String = Option(localProperties.get).map(_.getProperty(key)).getOrElse(null) @@ -265,7 +273,7 @@ class SparkContext( } /** - * Assigns a group id to all the jobs started by this thread until the group id is set to a + * Assigns a group ID to all the jobs started by this thread until the group ID is set to a * different value or cleared. * * Often, a unit of execution in an application consists of multiple Spark actions or jobs. @@ -288,7 +296,7 @@ class SparkContext( setLocalProperty(SparkContext.SPARK_JOB_GROUP_ID, groupId) } - /** Clear the job group id and its description. */ + /** Clear the current thread's job group ID and its description. */ def clearJobGroup() { setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, null) setLocalProperty(SparkContext.SPARK_JOB_GROUP_ID, null) @@ -337,29 +345,42 @@ class SparkContext( } /** - * Get an RDD for a Hadoop-readable dataset from a Hadoop JobConf given its InputFormat and any - * other necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable, - * etc). + * Get an RDD for a Hadoop-readable dataset from a Hadoop JobConf given its InputFormat and other + * necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable), + * using the older MapReduce API (`org.apache.hadoop.mapred`). + * + * @param conf JobConf for setting up the dataset + * @param inputFormatClass Class of the [[InputFormat]] + * @param keyClass Class of the keys + * @param valueClass Class of the values + * @param minSplits Minimum number of Hadoop Splits to generate. + * @param cloneRecords If true, Spark will clone the records produced by Hadoop RecordReader. + * Most RecordReader implementations reuse wrapper objects across multiple + * records, and can cause problems in RDD collect or aggregation operations. + * By default the records are cloned in Spark. However, application + * programmers can explicitly disable the cloning for better performance. */ - def hadoopRDD[K, V]( + def hadoopRDD[K: ClassTag, V: ClassTag]( conf: JobConf, inputFormatClass: Class[_ <: InputFormat[K, V]], keyClass: Class[K], valueClass: Class[V], - minSplits: Int = defaultMinSplits + minSplits: Int = defaultMinSplits, + cloneRecords: Boolean = true ): RDD[(K, V)] = { // Add necessary security credentials to the JobConf before broadcasting it. SparkHadoopUtil.get.addCredentials(conf) - new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minSplits) + new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minSplits, cloneRecords) } /** Get an RDD for a Hadoop file with an arbitrary InputFormat */ - def hadoopFile[K, V]( + def hadoopFile[K: ClassTag, V: ClassTag]( path: String, inputFormatClass: Class[_ <: InputFormat[K, V]], keyClass: Class[K], valueClass: Class[V], - minSplits: Int = defaultMinSplits + minSplits: Int = defaultMinSplits, + cloneRecords: Boolean = true ): RDD[(K, V)] = { // A Hadoop configuration can be about 10 KB, which is pretty big, so broadcast it. val confBroadcast = broadcast(new SerializableWritable(hadoopConfiguration)) @@ -371,7 +392,8 @@ class SparkContext( inputFormatClass, keyClass, valueClass, - minSplits) + minSplits, + cloneRecords) } /** @@ -382,14 +404,15 @@ class SparkContext( * val file = sparkContext.hadoopFile[LongWritable, Text, TextInputFormat](path, minSplits) * }}} */ - def hadoopFile[K, V, F <: InputFormat[K, V]](path: String, minSplits: Int) - (implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F]) - : RDD[(K, V)] = { + def hadoopFile[K, V, F <: InputFormat[K, V]] + (path: String, minSplits: Int, cloneRecords: Boolean = true) + (implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F]): RDD[(K, V)] = { hadoopFile(path, - fm.runtimeClass.asInstanceOf[Class[F]], - km.runtimeClass.asInstanceOf[Class[K]], - vm.runtimeClass.asInstanceOf[Class[V]], - minSplits) + fm.runtimeClass.asInstanceOf[Class[F]], + km.runtimeClass.asInstanceOf[Class[K]], + vm.runtimeClass.asInstanceOf[Class[V]], + minSplits, + cloneRecords) } /** @@ -400,61 +423,67 @@ class SparkContext( * val file = sparkContext.hadoopFile[LongWritable, Text, TextInputFormat](path) * }}} */ - def hadoopFile[K, V, F <: InputFormat[K, V]](path: String) + def hadoopFile[K, V, F <: InputFormat[K, V]](path: String, cloneRecords: Boolean = true) (implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F]): RDD[(K, V)] = - hadoopFile[K, V, F](path, defaultMinSplits) + hadoopFile[K, V, F](path, defaultMinSplits, cloneRecords) /** Get an RDD for a Hadoop file with an arbitrary new API InputFormat. */ - def newAPIHadoopFile[K, V, F <: NewInputFormat[K, V]](path: String) + def newAPIHadoopFile[K, V, F <: NewInputFormat[K, V]] + (path: String, cloneRecords: Boolean = true) (implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F]): RDD[(K, V)] = { newAPIHadoopFile( - path, - fm.runtimeClass.asInstanceOf[Class[F]], - km.runtimeClass.asInstanceOf[Class[K]], - vm.runtimeClass.asInstanceOf[Class[V]]) + path, + fm.runtimeClass.asInstanceOf[Class[F]], + km.runtimeClass.asInstanceOf[Class[K]], + vm.runtimeClass.asInstanceOf[Class[V]], + cloneRecords = cloneRecords) } /** * Get an RDD for a given Hadoop file with an arbitrary new API InputFormat * and extra configuration options to pass to the input format. */ - def newAPIHadoopFile[K, V, F <: NewInputFormat[K, V]]( + def newAPIHadoopFile[K: ClassTag, V: ClassTag, F <: NewInputFormat[K, V]]( path: String, fClass: Class[F], kClass: Class[K], vClass: Class[V], - conf: Configuration = hadoopConfiguration): RDD[(K, V)] = { + conf: Configuration = hadoopConfiguration, + cloneRecords: Boolean = true): RDD[(K, V)] = { val job = new NewHadoopJob(conf) NewFileInputFormat.addInputPath(job, new Path(path)) val updatedConf = job.getConfiguration - new NewHadoopRDD(this, fClass, kClass, vClass, updatedConf) + new NewHadoopRDD(this, fClass, kClass, vClass, updatedConf, cloneRecords) } /** * Get an RDD for a given Hadoop file with an arbitrary new API InputFormat * and extra configuration options to pass to the input format. */ - def newAPIHadoopRDD[K, V, F <: NewInputFormat[K, V]]( + def newAPIHadoopRDD[K: ClassTag, V: ClassTag, F <: NewInputFormat[K, V]]( conf: Configuration = hadoopConfiguration, fClass: Class[F], kClass: Class[K], - vClass: Class[V]): RDD[(K, V)] = { - new NewHadoopRDD(this, fClass, kClass, vClass, conf) + vClass: Class[V], + cloneRecords: Boolean = true): RDD[(K, V)] = { + new NewHadoopRDD(this, fClass, kClass, vClass, conf, cloneRecords) } /** Get an RDD for a Hadoop SequenceFile with given key and value types. */ - def sequenceFile[K, V](path: String, + def sequenceFile[K: ClassTag, V: ClassTag](path: String, keyClass: Class[K], valueClass: Class[V], - minSplits: Int + minSplits: Int, + cloneRecords: Boolean = true ): RDD[(K, V)] = { val inputFormatClass = classOf[SequenceFileInputFormat[K, V]] - hadoopFile(path, inputFormatClass, keyClass, valueClass, minSplits) + hadoopFile(path, inputFormatClass, keyClass, valueClass, minSplits, cloneRecords) } /** Get an RDD for a Hadoop SequenceFile with given key and value types. */ - def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V]): RDD[(K, V)] = - sequenceFile(path, keyClass, valueClass, defaultMinSplits) + def sequenceFile[K: ClassTag, V: ClassTag](path: String, keyClass: Class[K], valueClass: Class[V], + cloneRecords: Boolean = true): RDD[(K, V)] = + sequenceFile(path, keyClass, valueClass, defaultMinSplits, cloneRecords) /** * Version of sequenceFile() for types implicitly convertible to Writables through a @@ -472,17 +501,18 @@ class SparkContext( * for the appropriate type. In addition, we pass the converter a ClassTag of its type to * allow it to figure out the Writable class to use in the subclass case. */ - def sequenceFile[K, V](path: String, minSplits: Int = defaultMinSplits) - (implicit km: ClassTag[K], vm: ClassTag[V], - kcf: () => WritableConverter[K], vcf: () => WritableConverter[V]) + def sequenceFile[K, V] + (path: String, minSplits: Int = defaultMinSplits, cloneRecords: Boolean = true) + (implicit km: ClassTag[K], vm: ClassTag[V], + kcf: () => WritableConverter[K], vcf: () => WritableConverter[V]) : RDD[(K, V)] = { val kc = kcf() val vc = vcf() val format = classOf[SequenceFileInputFormat[Writable, Writable]] val writables = hadoopFile(path, format, kc.writableClass(km).asInstanceOf[Class[Writable]], - vc.writableClass(vm).asInstanceOf[Class[Writable]], minSplits) - writables.map{case (k,v) => (kc.convert(k), vc.convert(v))} + vc.writableClass(vm).asInstanceOf[Class[Writable]], minSplits, cloneRecords) + writables.map { case (k, v) => (kc.convert(k), vc.convert(v)) } } /** @@ -517,15 +547,15 @@ class SparkContext( // Methods for creating shared variables /** - * Create an [[org.apache.spark.Accumulator]] variable of a given type, which tasks can "add" values - * to using the `+=` method. Only the driver can access the accumulator's `value`. + * Create an [[org.apache.spark.Accumulator]] variable of a given type, which tasks can "add" + * values to using the `+=` method. Only the driver can access the accumulator's `value`. */ def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]) = new Accumulator(initialValue, param) /** - * Create an [[org.apache.spark.Accumulable]] shared variable, to which tasks can add values with `+=`. - * Only the driver can access the accumuable's `value`. + * Create an [[org.apache.spark.Accumulable]] shared variable, to which tasks can add values + * with `+=`. Only the driver can access the accumuable's `value`. * @tparam T accumulator type * @tparam R type that can be added to the accumulator */ @@ -538,14 +568,16 @@ class SparkContext( * Growable and TraversableOnce are the standard APIs that guarantee += and ++=, implemented by * standard mutable collections. So you can use this with mutable Map, Set, etc. */ - def accumulableCollection[R <% Growable[T] with TraversableOnce[T] with Serializable, T](initialValue: R) = { + def accumulableCollection[R <% Growable[T] with TraversableOnce[T] with Serializable, T] + (initialValue: R) = { val param = new GrowableAccumulableParam[R,T] new Accumulable(initialValue, param) } /** - * Broadcast a read-only variable to the cluster, returning a [[org.apache.spark.broadcast.Broadcast]] object for - * reading it in distributed functions. The variable will be sent to each cluster only once. + * Broadcast a read-only variable to the cluster, returning a + * [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions. + * The variable will be sent to each cluster only once. */ def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T](value, isLocal) @@ -667,10 +699,10 @@ class SparkContext( key = uri.getScheme match { // A JAR file which exists only on the driver node case null | "file" => - if (SparkHadoopUtil.get.isYarnMode()) { - // In order for this to work on yarn the user must specify the --addjars option to - // the client to upload the file into the distributed cache to make it show up in the - // current working directory. + if (SparkHadoopUtil.get.isYarnMode() && master == "yarn-standalone") { + // In order for this to work in yarn standalone mode the user must specify the + // --addjars option to the client to upload the file into the distributed cache + // of the AM to make it show up in the current working directory. val fileName = new Path(uri.getPath).getName() try { env.httpFileServer.addJar(new File(fileName)) @@ -754,8 +786,11 @@ class SparkContext( private[spark] def getCallSite(): String = { val callSite = getLocalProperty("externalCallSite") - if (callSite == null) return Utils.formatSparkCallSite - callSite + if (callSite == null) { + Utils.formatSparkCallSite + } else { + callSite + } } /** @@ -905,7 +940,7 @@ class SparkContext( */ private[spark] def clean[F <: AnyRef](f: F): F = { ClosureCleaner.clean(f) - return f + f } /** @@ -917,7 +952,7 @@ class SparkContext( val path = new Path(dir, UUID.randomUUID().toString) val fs = path.getFileSystem(hadoopConfiguration) fs.mkdirs(path) - fs.getFileStatus(path).getPath().toString + fs.getFileStatus(path).getPath.toString } } @@ -1010,7 +1045,8 @@ object SparkContext { implicit def stringToText(s: String) = new Text(s) - private implicit def arrayToArrayWritable[T <% Writable: ClassTag](arr: Traversable[T]): ArrayWritable = { + private implicit def arrayToArrayWritable[T <% Writable: ClassTag](arr: Traversable[T]) + : ArrayWritable = { def anyToWritable[U <% Writable](u: U): Writable = u new ArrayWritable(classTag[T].runtimeClass.asInstanceOf[Class[Writable]], @@ -1033,7 +1069,9 @@ object SparkContext { implicit def booleanWritableConverter() = simpleWritableConverter[Boolean, BooleanWritable](_.get) - implicit def bytesWritableConverter() = simpleWritableConverter[Array[Byte], BytesWritable](_.getBytes) + implicit def bytesWritableConverter() = { + simpleWritableConverter[Array[Byte], BytesWritable](_.getBytes) + } implicit def stringWritableConverter() = simpleWritableConverter[String, Text](_.toString) @@ -1049,7 +1087,8 @@ object SparkContext { if (uri != null) { val uriStr = uri.toString if (uriStr.startsWith("jar:file:")) { - // URI will be of the form "jar:file:/path/foo.jar!/package/cls.class", so pull out the /path/foo.jar + // URI will be of the form "jar:file:/path/foo.jar!/package/cls.class", + // so pull out the /path/foo.jar List(uriStr.substring("jar:file:".length, uriStr.indexOf('!'))) } else { Nil @@ -1072,7 +1111,7 @@ object SparkContext { * parameters that are passed as the default value of null, instead of throwing an exception * like SparkConf would. */ - private def updatedConf( + private[spark] def updatedConf( conf: SparkConf, master: String, appName: String, @@ -1203,7 +1242,7 @@ object SparkContext { case mesosUrl @ MESOS_REGEX(_) => MesosNativeLibrary.load() val scheduler = new TaskSchedulerImpl(sc) - val coarseGrained = sc.conf.get("spark.mesos.coarse", "false").toBoolean + val coarseGrained = sc.conf.getBoolean("spark.mesos.coarse", false) val url = mesosUrl.stripPrefix("mesos://") // strip scheme from raw Mesos URLs val backend = if (coarseGrained) { new CoarseMesosSchedulerBackend(scheduler, sc, url, appName) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 2e36ccb9a0..ed788560e7 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -54,7 +54,11 @@ class SparkEnv private[spark] ( val httpFileServer: HttpFileServer, val sparkFilesDir: String, val metricsSystem: MetricsSystem, - val conf: SparkConf) { + val conf: SparkConf) extends Logging { + + // A mapping of thread ID to amount of memory used for shuffle in bytes + // All accesses should be manually synchronized + val shuffleMemoryMap = mutable.HashMap[Long, Long]() private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]() @@ -128,16 +132,6 @@ object SparkEnv extends Logging { conf.set("spark.driver.port", boundPort.toString) } - // set only if unset until now. - if (!conf.contains("spark.hostPort")) { - if (!isDriver){ - // unexpected - Utils.logErrorWithStack("Unexpected NOT to have spark.hostPort set") - } - Utils.checkHost(hostname) - conf.set("spark.hostPort", hostname + ":" + boundPort) - } - val classLoader = Thread.currentThread.getContextClassLoader // Create an instance of the class named by the given Java system property, or by @@ -162,7 +156,7 @@ object SparkEnv extends Logging { actorSystem.actorOf(Props(newActor), name = name) } else { val driverHost: String = conf.get("spark.driver.host", "localhost") - val driverPort: Int = conf.get("spark.driver.port", "7077").toInt + val driverPort: Int = conf.getInt("spark.driver.port", 7077) Utils.checkHost(driverHost, "Expected hostname") val url = s"akka.tcp://spark@$driverHost:$driverPort/user/$name" val timeout = AkkaUtils.lookupTimeout(conf) diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala index 618d95015f..4e63117a51 100644 --- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala @@ -134,28 +134,28 @@ class SparkHadoopWriter(@transient jobConf: JobConf) format = conf.value.getOutputFormat() .asInstanceOf[OutputFormat[AnyRef,AnyRef]] } - return format + format } private def getOutputCommitter(): OutputCommitter = { if (committer == null) { committer = conf.value.getOutputCommitter } - return committer + committer } private def getJobContext(): JobContext = { if (jobContext == null) { jobContext = newJobContext(conf.value, jID.value) } - return jobContext + jobContext } private def getTaskContext(): TaskAttemptContext = { if (taskContext == null) { taskContext = newTaskAttemptContext(conf.value, taID.value) } - return taskContext + taskContext } private def setIDs(jobid: Int, splitid: Int, attemptid: Int) { @@ -182,19 +182,18 @@ object SparkHadoopWriter { def createJobID(time: Date, id: Int): JobID = { val formatter = new SimpleDateFormat("yyyyMMddHHmm") val jobtrackerID = formatter.format(new Date()) - return new JobID(jobtrackerID, id) + new JobID(jobtrackerID, id) } def createPathFromString(path: String, conf: JobConf): Path = { if (path == null) { throw new IllegalArgumentException("Output path is null") } - var outputPath = new Path(path) + val outputPath = new Path(path) val fs = outputPath.getFileSystem(conf) if (outputPath == null || fs == null) { throw new IllegalArgumentException("Incorrectly formatted output path") } - outputPath = outputPath.makeQualified(fs) - return outputPath + outputPath.makeQualified(fs) } } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala index da30cf619a..b0dedc6f4e 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala @@ -207,13 +207,13 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav * e.g. for the array * [1,10,20,50] the buckets are [1,10) [10,20) [20,50] * e.g 1<=x<10 , 10<=x<20, 20<=x<50 - * And on the input of 1 and 50 we would have a histogram of 1,0,0 - * + * And on the input of 1 and 50 we would have a histogram of 1,0,0 + * * Note: if your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched * from an O(log n) inseration to O(1) per element. (where n = # buckets) if you set evenBuckets * to true. * buckets must be sorted and not contain any duplicates. - * buckets array must be at least two elements + * buckets array must be at least two elements * All NaN entries are treated the same. If you have a NaN bucket it must be * the maximum value of the last position and all NaN entries will be counted * in that bucket. @@ -225,6 +225,12 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav def histogram(buckets: Array[Double], evenBuckets: Boolean): Array[Long] = { srdd.histogram(buckets.map(_.toDouble), evenBuckets) } + + /** Assign a name to this RDD */ + def setName(name: String): JavaDoubleRDD = { + srdd.setName(name) + this + } } object JavaDoubleRDD { diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 55c87450ac..0fb7e195b3 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -647,6 +647,12 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kClassTag: ClassTag[K def countApproxDistinctByKey(relativeSD: Double, numPartitions: Int): JavaRDD[(K, Long)] = { rdd.countApproxDistinctByKey(relativeSD, numPartitions) } + + /** Assign a name to this RDD */ + def setName(name: String): JavaPairRDD[K, V] = { + rdd.setName(name) + this + } } object JavaPairRDD { diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala index 037cd1c774..7d48ce01cf 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala @@ -127,6 +127,12 @@ JavaRDDLike[T, JavaRDD[T]] { wrapRDD(rdd.subtract(other, p)) override def toString = rdd.toString + + /** Assign a name to this RDD */ + def setName(name: String): JavaRDD[T] = { + rdd.setName(name) + this + } } object JavaRDD { diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index 924d8af060..ebbbbd8806 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -245,6 +245,11 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { } /** + * Return an array that contains all of the elements in this RDD. + */ + def toArray(): JList[T] = collect() + + /** * Return an array that contains all of the elements in a specific partition of this RDD. */ def collectPartitions(partitionIds: Array[Int]): Array[JList[T]] = { @@ -455,4 +460,5 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def countApproxDistinct(relativeSD: Double = 0.05): Long = rdd.countApproxDistinct(relativeSD) + def name(): String = rdd.name } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index e93b10fd7e..7a6f044965 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -425,6 +425,51 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork def clearCallSite() { sc.clearCallSite() } + + /** + * Set a local property that affects jobs submitted from this thread, such as the + * Spark fair scheduler pool. + */ + def setLocalProperty(key: String, value: String): Unit = sc.setLocalProperty(key, value) + + /** + * Get a local property set in this thread, or null if it is missing. See + * [[org.apache.spark.api.java.JavaSparkContext.setLocalProperty]]. + */ + def getLocalProperty(key: String): String = sc.getLocalProperty(key) + + /** + * Assigns a group ID to all the jobs started by this thread until the group ID is set to a + * different value or cleared. + * + * Often, a unit of execution in an application consists of multiple Spark actions or jobs. + * Application programmers can use this method to group all those jobs together and give a + * group description. Once set, the Spark web UI will associate such jobs with this group. + * + * The application can also use [[org.apache.spark.api.java.JavaSparkContext.cancelJobGroup]] + * to cancel all running jobs in this group. For example, + * {{{ + * // In the main thread: + * sc.setJobGroup("some_job_to_cancel", "some job description"); + * rdd.map(...).count(); + * + * // In a separate thread: + * sc.cancelJobGroup("some_job_to_cancel"); + * }}} + */ + def setJobGroup(groupId: String, description: String): Unit = sc.setJobGroup(groupId, description) + + /** Clear the current thread's job group ID and its description. */ + def clearJobGroup(): Unit = sc.clearJobGroup() + + /** + * Cancel active jobs for the specified group. See + * [[org.apache.spark.api.java.JavaSparkContext.setJobGroup]] for more information. + */ + def cancelJobGroup(groupId: String): Unit = sc.cancelJobGroup(groupId) + + /** Cancel all jobs that have been scheduled or are running. */ + def cancelAllJobs(): Unit = sc.cancelAllJobs() } object JavaSparkContext { @@ -436,5 +481,12 @@ object JavaSparkContext { * Find the JAR from which a given class was loaded, to make it easy for users to pass * their JARs to SparkContext. */ - def jarOfClass(cls: Class[_]) = SparkContext.jarOfClass(cls).toArray + def jarOfClass(cls: Class[_]): Array[String] = SparkContext.jarOfClass(cls).toArray + + /** + * Find the JAR that contains the class of a particular object, to make it easy for users + * to pass their JARs to SparkContext. In most cases you can call jarOfObject(this) in + * your driver program. + */ + def jarOfObject(obj: AnyRef): Array[String] = SparkContext.jarOfObject(obj).toArray } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 32cc70e8c9..82527fe663 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -41,7 +41,7 @@ private[spark] class PythonRDD[T: ClassTag]( accumulator: Accumulator[JList[Array[Byte]]]) extends RDD[Array[Byte]](parent) { - val bufferSize = conf.get("spark.buffer.size", "65536").toInt + val bufferSize = conf.getInt("spark.buffer.size", 65536) override def getPartitions = parent.partitions @@ -95,7 +95,7 @@ private[spark] class PythonRDD[T: ClassTag]( // Return an iterator that read lines from the process's stdout val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize)) - return new Iterator[Array[Byte]] { + val stdoutIterator = new Iterator[Array[Byte]] { def next(): Array[Byte] = { val obj = _nextObj if (hasNext) { @@ -156,6 +156,7 @@ private[spark] class PythonRDD[T: ClassTag]( def hasNext = _nextObj.length != 0 } + stdoutIterator } val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) @@ -250,7 +251,7 @@ private class PythonAccumulatorParam(@transient serverHost: String, serverPort: Utils.checkHost(serverHost, "Expected hostname") - val bufferSize = SparkEnv.get.conf.get("spark.buffer.size", "65536").toInt + val bufferSize = SparkEnv.get.conf.getInt("spark.buffer.size", 65536) override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index db596d5fcc..0eacda3d7d 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -92,8 +92,8 @@ private object HttpBroadcast extends Logging { def initialize(isDriver: Boolean, conf: SparkConf) { synchronized { if (!initialized) { - bufferSize = conf.get("spark.buffer.size", "65536").toInt - compress = conf.get("spark.broadcast.compress", "true").toBoolean + bufferSize = conf.getInt("spark.buffer.size", 65536) + compress = conf.getBoolean("spark.broadcast.compress", true) if (isDriver) { createServer(conf) conf.set("spark.httpBroadcast.uri", serverUri) diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 9530938278..1d295c62bc 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -180,7 +180,7 @@ extends Logging { initialized = false } - lazy val BLOCK_SIZE = conf.get("spark.broadcast.blockSize", "4096").toInt * 1024 + lazy val BLOCK_SIZE = conf.getInt("spark.broadcast.blockSize", 4096) * 1024 def blockifyObject[T](obj: T): TorrentInfo = { val byteArray = Utils.serialize[T](obj) @@ -203,16 +203,16 @@ extends Logging { } bais.close() - var tInfo = TorrentInfo(retVal, blockNum, byteArray.length) + val tInfo = TorrentInfo(retVal, blockNum, byteArray.length) tInfo.hasBlocks = blockNum - return tInfo + tInfo } def unBlockifyObject[T](arrayOfBlocks: Array[TorrentBlock], totalBytes: Int, totalBlocks: Int): T = { - var retByteArray = new Array[Byte](totalBytes) + val retByteArray = new Array[Byte](totalBytes) for (i <- 0 until totalBlocks) { System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray, i * BLOCK_SIZE, arrayOfBlocks(i).byteArray.length) diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala new file mode 100644 index 0000000000..e133893f6c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy + +import scala.collection.JavaConversions._ +import scala.collection.mutable.Map +import scala.concurrent._ + +import akka.actor._ +import akka.pattern.ask +import org.apache.log4j.{Level, Logger} + +import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.deploy.DeployMessages._ +import org.apache.spark.deploy.master.{DriverState, Master} +import org.apache.spark.util.{AkkaUtils, Utils} +import akka.actor.Actor.emptyBehavior +import akka.remote.{AssociationErrorEvent, DisassociatedEvent, RemotingLifecycleEvent} + +/** + * Proxy that relays messages to the driver. + */ +class ClientActor(driverArgs: ClientArguments, conf: SparkConf) extends Actor with Logging { + var masterActor: ActorSelection = _ + val timeout = AkkaUtils.askTimeout(conf) + + override def preStart() = { + masterActor = context.actorSelection(Master.toAkkaUrl(driverArgs.master)) + + context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + + println(s"Sending ${driverArgs.cmd} command to ${driverArgs.master}") + + driverArgs.cmd match { + case "launch" => + // TODO: We could add an env variable here and intercept it in `sc.addJar` that would + // truncate filesystem paths similar to what YARN does. For now, we just require + // people call `addJar` assuming the jar is in the same directory. + val env = Map[String, String]() + System.getenv().foreach{case (k, v) => env(k) = v} + + val mainClass = "org.apache.spark.deploy.worker.DriverWrapper" + val command = new Command(mainClass, Seq("{{WORKER_URL}}", driverArgs.mainClass) ++ + driverArgs.driverOptions, env) + + val driverDescription = new DriverDescription( + driverArgs.jarUrl, + driverArgs.memory, + driverArgs.cores, + driverArgs.supervise, + command) + + masterActor ! RequestSubmitDriver(driverDescription) + + case "kill" => + val driverId = driverArgs.driverId + val killFuture = masterActor ! RequestKillDriver(driverId) + } + } + + /* Find out driver status then exit the JVM */ + def pollAndReportStatus(driverId: String) { + println(s"... waiting before polling master for driver state") + Thread.sleep(5000) + println("... polling master for driver state") + val statusFuture = (masterActor ? RequestDriverStatus(driverId))(timeout) + .mapTo[DriverStatusResponse] + val statusResponse = Await.result(statusFuture, timeout) + + statusResponse.found match { + case false => + println(s"ERROR: Cluster master did not recognize $driverId") + System.exit(-1) + case true => + println(s"State of $driverId is ${statusResponse.state.get}") + // Worker node, if present + (statusResponse.workerId, statusResponse.workerHostPort, statusResponse.state) match { + case (Some(id), Some(hostPort), Some(DriverState.RUNNING)) => + println(s"Driver running on $hostPort ($id)") + case _ => + } + // Exception, if present + statusResponse.exception.map { e => + println(s"Exception from cluster was: $e") + System.exit(-1) + } + System.exit(0) + } + } + + override def receive = { + + case SubmitDriverResponse(success, driverId, message) => + println(message) + if (success) pollAndReportStatus(driverId.get) else System.exit(-1) + + case KillDriverResponse(driverId, success, message) => + println(message) + if (success) pollAndReportStatus(driverId) else System.exit(-1) + + case DisassociatedEvent(_, remoteAddress, _) => + println(s"Error connecting to master ${driverArgs.master} ($remoteAddress), exiting.") + System.exit(-1) + + case AssociationErrorEvent(cause, _, remoteAddress, _) => + println(s"Error connecting to master ${driverArgs.master} ($remoteAddress), exiting.") + println(s"Cause was: $cause") + System.exit(-1) + } +} + +/** + * Executable utility for starting and terminating drivers inside of a standalone cluster. + */ +object Client { + def main(args: Array[String]) { + val conf = new SparkConf() + val driverArgs = new ClientArguments(args) + + if (!driverArgs.logLevel.isGreaterOrEqual(Level.WARN)) { + conf.set("spark.akka.logLifecycleEvents", "true") + } + conf.set("spark.akka.askTimeout", "10") + conf.set("akka.loglevel", driverArgs.logLevel.toString.replace("WARN", "WARNING")) + Logger.getRootLogger.setLevel(driverArgs.logLevel) + + // TODO: See if we can initialize akka so return messages are sent back using the same TCP + // flow. Else, this (sadly) requires the DriverClient be routable from the Master. + val (actorSystem, _) = AkkaUtils.createActorSystem( + "driverClient", Utils.localHostName(), 0, false, conf) + + actorSystem.actorOf(Props(classOf[ClientActor], driverArgs, conf)) + + actorSystem.awaitTermination() + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala new file mode 100644 index 0000000000..db67c6d1bb --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy + +import java.net.URL + +import scala.collection.mutable.ListBuffer + +import org.apache.log4j.Level + +/** + * Command-line parser for the driver client. + */ +private[spark] class ClientArguments(args: Array[String]) { + val defaultCores = 1 + val defaultMemory = 512 + + var cmd: String = "" // 'launch' or 'kill' + var logLevel = Level.WARN + + // launch parameters + var master: String = "" + var jarUrl: String = "" + var mainClass: String = "" + var supervise: Boolean = false + var memory: Int = defaultMemory + var cores: Int = defaultCores + private var _driverOptions = ListBuffer[String]() + def driverOptions = _driverOptions.toSeq + + // kill parameters + var driverId: String = "" + + parse(args.toList) + + def parse(args: List[String]): Unit = args match { + case ("--cores" | "-c") :: value :: tail => + cores = value.toInt + parse(tail) + + case ("--memory" | "-m") :: value :: tail => + memory = value.toInt + parse(tail) + + case ("--supervise" | "-s") :: tail => + supervise = true + parse(tail) + + case ("--help" | "-h") :: tail => + printUsageAndExit(0) + + case ("--verbose" | "-v") :: tail => + logLevel = Level.INFO + parse(tail) + + case "launch" :: _master :: _jarUrl :: _mainClass :: tail => + cmd = "launch" + + try { + new URL(_jarUrl) + } catch { + case e: Exception => + println(s"Jar url '${_jarUrl}' is not a valid URL.") + println(s"Jar must be in URL format (e.g. hdfs://XX, file://XX)") + printUsageAndExit(-1) + } + + jarUrl = _jarUrl + master = _master + mainClass = _mainClass + _driverOptions ++= tail + + case "kill" :: _master :: _driverId :: tail => + cmd = "kill" + master = _master + driverId = _driverId + + case _ => + printUsageAndExit(1) + } + + /** + * Print usage and exit JVM with the given exit code. + */ + def printUsageAndExit(exitCode: Int) { + // TODO: It wouldn't be too hard to allow users to submit their app and dependency jars + // separately similar to in the YARN client. + val usage = + s""" + |Usage: DriverClient [options] launch <active-master> <jar-url> <main-class> [driver options] + |Usage: DriverClient kill <active-master> <driver-id> + | + |Options: + | -c CORES, --cores CORES Number of cores to request (default: $defaultCores) + | -m MEMORY, --memory MEMORY Megabytes of memory to request (default: $defaultMemory) + | -s, --supervise Whether to restart the driver on failure + | -v, --verbose Print more debugging output + """.stripMargin + System.err.println(usage) + System.exit(exitCode) + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala index 275331724a..5e824e1a67 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala @@ -20,12 +20,12 @@ package org.apache.spark.deploy import scala.collection.immutable.List import org.apache.spark.deploy.ExecutorState.ExecutorState -import org.apache.spark.deploy.master.{WorkerInfo, ApplicationInfo} +import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, WorkerInfo} +import org.apache.spark.deploy.master.DriverState.DriverState import org.apache.spark.deploy.master.RecoveryState.MasterState -import org.apache.spark.deploy.worker.ExecutorRunner +import org.apache.spark.deploy.worker.{DriverRunner, ExecutorRunner} import org.apache.spark.util.Utils - private[deploy] sealed trait DeployMessage extends Serializable /** Contains messages sent between Scheduler actor nodes. */ @@ -54,7 +54,14 @@ private[deploy] object DeployMessages { exitStatus: Option[Int]) extends DeployMessage - case class WorkerSchedulerStateResponse(id: String, executors: List[ExecutorDescription]) + case class DriverStateChanged( + driverId: String, + state: DriverState, + exception: Option[Exception]) + extends DeployMessage + + case class WorkerSchedulerStateResponse(id: String, executors: List[ExecutorDescription], + driverIds: Seq[String]) case class Heartbeat(workerId: String) extends DeployMessage @@ -76,14 +83,18 @@ private[deploy] object DeployMessages { sparkHome: String) extends DeployMessage - // Client to Master + case class LaunchDriver(driverId: String, driverDesc: DriverDescription) extends DeployMessage + + case class KillDriver(driverId: String) extends DeployMessage + + // AppClient to Master case class RegisterApplication(appDescription: ApplicationDescription) extends DeployMessage case class MasterChangeAcknowledged(appId: String) - // Master to Client + // Master to AppClient case class RegisteredApplication(appId: String, masterUrl: String) extends DeployMessage @@ -97,11 +108,28 @@ private[deploy] object DeployMessages { case class ApplicationRemoved(message: String) - // Internal message in Client + // DriverClient <-> Master + + case class RequestSubmitDriver(driverDescription: DriverDescription) extends DeployMessage + + case class SubmitDriverResponse(success: Boolean, driverId: Option[String], message: String) + extends DeployMessage + + case class RequestKillDriver(driverId: String) extends DeployMessage + + case class KillDriverResponse(driverId: String, success: Boolean, message: String) + extends DeployMessage + + case class RequestDriverStatus(driverId: String) extends DeployMessage + + case class DriverStatusResponse(found: Boolean, state: Option[DriverState], + workerId: Option[String], workerHostPort: Option[String], exception: Option[Exception]) + + // Internal message in AppClient - case object StopClient + case object StopAppClient - // Master to Worker & Client + // Master to Worker & AppClient case class MasterChanged(masterUrl: String, masterWebUiUrl: String) @@ -113,6 +141,7 @@ private[deploy] object DeployMessages { case class MasterStateResponse(host: String, port: Int, workers: Array[WorkerInfo], activeApps: Array[ApplicationInfo], completedApps: Array[ApplicationInfo], + activeDrivers: Array[DriverInfo], completedDrivers: Array[DriverInfo], status: MasterState) { Utils.checkHost(host, "Required hostname") @@ -128,14 +157,15 @@ private[deploy] object DeployMessages { // Worker to WorkerWebUI case class WorkerStateResponse(host: String, port: Int, workerId: String, - executors: List[ExecutorRunner], finishedExecutors: List[ExecutorRunner], masterUrl: String, + executors: List[ExecutorRunner], finishedExecutors: List[ExecutorRunner], + drivers: List[DriverRunner], finishedDrivers: List[DriverRunner], masterUrl: String, cores: Int, memory: Int, coresUsed: Int, memoryUsed: Int, masterWebUiUrl: String) { Utils.checkHost(host, "Required hostname") assert (port > 0) } - // Actor System to Worker + // Liveness checks in various places case object SendHeartbeat } diff --git a/core/src/main/scala/org/apache/spark/deploy/DriverDescription.scala b/core/src/main/scala/org/apache/spark/deploy/DriverDescription.scala new file mode 100644 index 0000000000..58c95dc4f9 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/DriverDescription.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy + +private[spark] class DriverDescription( + val jarUrl: String, + val mem: Int, + val cores: Int, + val supervise: Boolean, + val command: Command) + extends Serializable { + + override def toString: String = s"DriverDescription (${command.mainClass})" +} diff --git a/core/src/main/scala/org/apache/spark/deploy/client/Client.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index 481026eaa2..1415e2f3d1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -33,16 +33,17 @@ import org.apache.spark.deploy.master.Master import org.apache.spark.util.AkkaUtils /** - * The main class used to talk to a Spark deploy cluster. Takes a master URL, an app description, - * and a listener for cluster events, and calls back the listener when various events occur. + * Interface allowing applications to speak with a Spark deploy cluster. Takes a master URL, + * an app description, and a listener for cluster events, and calls back the listener when various + * events occur. * * @param masterUrls Each url should look like spark://host:port. */ -private[spark] class Client( +private[spark] class AppClient( actorSystem: ActorSystem, masterUrls: Array[String], appDescription: ApplicationDescription, - listener: ClientListener, + listener: AppClientListener, conf: SparkConf) extends Logging { @@ -155,7 +156,7 @@ private[spark] class Client( case AssociationErrorEvent(cause, _, address, _) if isPossibleMaster(address) => logWarning(s"Could not connect to $address: $cause") - case StopClient => + case StopAppClient => markDead() sender ! true context.stop(self) @@ -188,7 +189,7 @@ private[spark] class Client( if (actor != null) { try { val timeout = AkkaUtils.askTimeout(conf) - val future = actor.ask(StopClient)(timeout) + val future = actor.ask(StopAppClient)(timeout) Await.result(future, timeout) } catch { case e: TimeoutException => diff --git a/core/src/main/scala/org/apache/spark/deploy/client/ClientListener.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClientListener.scala index be7a11bd15..55d4ef1b31 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/ClientListener.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClientListener.scala @@ -24,7 +24,7 @@ package org.apache.spark.deploy.client * * Users of this API should *not* block inside the callback methods. */ -private[spark] trait ClientListener { +private[spark] trait AppClientListener { def connected(appId: String): Unit /** Disconnection may be a temporary state, as we fail over to a new Master. */ diff --git a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala index 28ebbdc66b..ffa909c26b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala @@ -23,7 +23,7 @@ import org.apache.spark.deploy.{Command, ApplicationDescription} private[spark] object TestClient { - class TestListener extends ClientListener with Logging { + class TestListener extends AppClientListener with Logging { def connected(id: String) { logInfo("Connected to master, got app ID " + id) } @@ -51,7 +51,7 @@ private[spark] object TestClient { "TestClient", Some(1), 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()), "dummy-spark-home", "ignored") val listener = new TestListener - val client = new Client(actorSystem, Array(url), desc, listener, new SparkConf) + val client = new AppClient(actorSystem, Array(url), desc, listener, new SparkConf) client.start() actorSystem.awaitTermination() } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala new file mode 100644 index 0000000000..33377931d6 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.master + +import java.util.Date + +import org.apache.spark.deploy.DriverDescription + +private[spark] class DriverInfo( + val startTime: Long, + val id: String, + val desc: DriverDescription, + val submitDate: Date) + extends Serializable { + + @transient var state: DriverState.Value = DriverState.SUBMITTED + /* If we fail when launching the driver, the exception is stored here. */ + @transient var exception: Option[Exception] = None + /* Most recent worker assigned to this driver */ + @transient var worker: Option[WorkerInfo] = None +} diff --git a/core/src/main/scala/org/apache/spark/deploy/master/DriverState.scala b/core/src/main/scala/org/apache/spark/deploy/master/DriverState.scala new file mode 100644 index 0000000000..26a68bade3 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/master/DriverState.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.master + +private[spark] object DriverState extends Enumeration { + + type DriverState = Value + + // SUBMITTED: Submitted but not yet scheduled on a worker + // RUNNING: Has been allocated to a worker to run + // FINISHED: Previously ran and exited cleanly + // RELAUNCHING: Exited non-zero or due to worker failure, but has not yet started running again + // UNKNOWN: The state of the driver is temporarily not known due to master failure recovery + // KILLED: A user manually killed this driver + // FAILED: The driver exited non-zero and was not supervised + // ERROR: Unable to run or restart due to an unrecoverable error (e.g. missing jar file) + val SUBMITTED, RUNNING, FINISHED, RELAUNCHING, UNKNOWN, KILLED, FAILED, ERROR = Value +} diff --git a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala index 043945a211..74bb9ebf1d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala @@ -19,8 +19,6 @@ package org.apache.spark.deploy.master import java.io._ -import scala.Serializable - import akka.serialization.Serialization import org.apache.spark.Logging @@ -47,6 +45,15 @@ private[spark] class FileSystemPersistenceEngine( new File(dir + File.separator + "app_" + app.id).delete() } + override def addDriver(driver: DriverInfo) { + val driverFile = new File(dir + File.separator + "driver_" + driver.id) + serializeIntoFile(driverFile, driver) + } + + override def removeDriver(driver: DriverInfo) { + new File(dir + File.separator + "driver_" + driver.id).delete() + } + override def addWorker(worker: WorkerInfo) { val workerFile = new File(dir + File.separator + "worker_" + worker.id) serializeIntoFile(workerFile, worker) @@ -56,13 +63,15 @@ private[spark] class FileSystemPersistenceEngine( new File(dir + File.separator + "worker_" + worker.id).delete() } - override def readPersistedData(): (Seq[ApplicationInfo], Seq[WorkerInfo]) = { + override def readPersistedData(): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) = { val sortedFiles = new File(dir).listFiles().sortBy(_.getName) val appFiles = sortedFiles.filter(_.getName.startsWith("app_")) val apps = appFiles.map(deserializeFromFile[ApplicationInfo]) + val driverFiles = sortedFiles.filter(_.getName.startsWith("driver_")) + val drivers = driverFiles.map(deserializeFromFile[DriverInfo]) val workerFiles = sortedFiles.filter(_.getName.startsWith("worker_")) val workers = workerFiles.map(deserializeFromFile[WorkerInfo]) - (apps, workers) + (apps, drivers, workers) } private def serializeIntoFile(file: File, value: AnyRef) { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 6617b7100f..d9ea96afcf 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -23,19 +23,22 @@ import java.util.Date import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import scala.concurrent.Await import scala.concurrent.duration._ +import scala.util.Random import akka.actor._ import akka.pattern.ask import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} import akka.serialization.SerializationExtension -import org.apache.spark.{SparkConf, SparkContext, Logging, SparkException} -import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} + +import org.apache.spark.{SparkConf, Logging, SparkException} +import org.apache.spark.deploy.{ApplicationDescription, DriverDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.MasterMessages._ import org.apache.spark.deploy.master.ui.MasterWebUI import org.apache.spark.metrics.MetricsSystem import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.deploy.master.DriverState.DriverState private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Actor with Logging { import context.dispatcher // to use Akka's scheduler.schedule() @@ -43,13 +46,12 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act val conf = new SparkConf val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs - val WORKER_TIMEOUT = conf.get("spark.worker.timeout", "60").toLong * 1000 - val RETAINED_APPLICATIONS = conf.get("spark.deploy.retainedApplications", "200").toInt - val REAPER_ITERATIONS = conf.get("spark.dead.worker.persistence", "15").toInt + val WORKER_TIMEOUT = conf.getLong("spark.worker.timeout", 60) * 1000 + val RETAINED_APPLICATIONS = conf.getInt("spark.deploy.retainedApplications", 200) + val REAPER_ITERATIONS = conf.getInt("spark.dead.worker.persistence", 15) val RECOVERY_DIR = conf.get("spark.deploy.recoveryDirectory", "") val RECOVERY_MODE = conf.get("spark.deploy.recoveryMode", "NONE") - var nextAppNumber = 0 val workers = new HashSet[WorkerInfo] val idToWorker = new HashMap[String, WorkerInfo] val actorToWorker = new HashMap[ActorRef, WorkerInfo] @@ -59,9 +61,14 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act val idToApp = new HashMap[String, ApplicationInfo] val actorToApp = new HashMap[ActorRef, ApplicationInfo] val addressToApp = new HashMap[Address, ApplicationInfo] - val waitingApps = new ArrayBuffer[ApplicationInfo] val completedApps = new ArrayBuffer[ApplicationInfo] + var nextAppNumber = 0 + + val drivers = new HashSet[DriverInfo] + val completedDrivers = new ArrayBuffer[DriverInfo] + val waitingDrivers = new ArrayBuffer[DriverInfo] // Drivers currently spooled for scheduling + var nextDriverNumber = 0 Utils.checkHost(host, "Expected hostname") @@ -142,14 +149,14 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act override def receive = { case ElectedLeader => { - val (storedApps, storedWorkers) = persistenceEngine.readPersistedData() - state = if (storedApps.isEmpty && storedWorkers.isEmpty) + val (storedApps, storedDrivers, storedWorkers) = persistenceEngine.readPersistedData() + state = if (storedApps.isEmpty && storedDrivers.isEmpty && storedWorkers.isEmpty) RecoveryState.ALIVE else RecoveryState.RECOVERING logInfo("I have been elected leader! New state: " + state) if (state == RecoveryState.RECOVERING) { - beginRecovery(storedApps, storedWorkers) + beginRecovery(storedApps, storedDrivers, storedWorkers) context.system.scheduler.scheduleOnce(WORKER_TIMEOUT millis) { completeRecovery() } } } @@ -176,6 +183,69 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act } } + case RequestSubmitDriver(description) => { + if (state != RecoveryState.ALIVE) { + val msg = s"Can only accept driver submissions in ALIVE state. Current state: $state." + sender ! SubmitDriverResponse(false, None, msg) + } else { + logInfo("Driver submitted " + description.command.mainClass) + val driver = createDriver(description) + persistenceEngine.addDriver(driver) + waitingDrivers += driver + drivers.add(driver) + schedule() + + // TODO: It might be good to instead have the submission client poll the master to determine + // the current status of the driver. For now it's simply "fire and forget". + + sender ! SubmitDriverResponse(true, Some(driver.id), + s"Driver successfully submitted as ${driver.id}") + } + } + + case RequestKillDriver(driverId) => { + if (state != RecoveryState.ALIVE) { + val msg = s"Can only kill drivers in ALIVE state. Current state: $state." + sender ! KillDriverResponse(driverId, success = false, msg) + } else { + logInfo("Asked to kill driver " + driverId) + val driver = drivers.find(_.id == driverId) + driver match { + case Some(d) => + if (waitingDrivers.contains(d)) { + waitingDrivers -= d + self ! DriverStateChanged(driverId, DriverState.KILLED, None) + } + else { + // We just notify the worker to kill the driver here. The final bookkeeping occurs + // on the return path when the worker submits a state change back to the master + // to notify it that the driver was successfully killed. + d.worker.foreach { w => + w.actor ! KillDriver(driverId) + } + } + // TODO: It would be nice for this to be a synchronous response + val msg = s"Kill request for $driverId submitted" + logInfo(msg) + sender ! KillDriverResponse(driverId, success = true, msg) + case None => + val msg = s"Driver $driverId has already finished or does not exist" + logWarning(msg) + sender ! KillDriverResponse(driverId, success = false, msg) + } + } + } + + case RequestDriverStatus(driverId) => { + (drivers ++ completedDrivers).find(_.id == driverId) match { + case Some(driver) => + sender ! DriverStatusResponse(found = true, Some(driver.state), + driver.worker.map(_.id), driver.worker.map(_.hostPort), driver.exception) + case None => + sender ! DriverStatusResponse(found = false, None, None, None, None) + } + } + case RegisterApplication(description) => { if (state == RecoveryState.STANDBY) { // ignore, don't send response @@ -218,6 +288,15 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act } } + case DriverStateChanged(driverId, state, exception) => { + state match { + case DriverState.ERROR | DriverState.FINISHED | DriverState.KILLED | DriverState.FAILED => + removeDriver(driverId, state, exception) + case _ => + throw new Exception(s"Received unexpected state update for driver $driverId: $state") + } + } + case Heartbeat(workerId) => { idToWorker.get(workerId) match { case Some(workerInfo) => @@ -239,7 +318,7 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act if (canCompleteRecovery) { completeRecovery() } } - case WorkerSchedulerStateResponse(workerId, executors) => { + case WorkerSchedulerStateResponse(workerId, executors, driverIds) => { idToWorker.get(workerId) match { case Some(worker) => logInfo("Worker has been re-registered: " + workerId) @@ -252,6 +331,14 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act worker.addExecutor(execInfo) execInfo.copyState(exec) } + + for (driverId <- driverIds) { + drivers.find(_.id == driverId).foreach { driver => + driver.worker = Some(worker) + driver.state = DriverState.RUNNING + worker.drivers(driverId) = driver + } + } case None => logWarning("Scheduler state from unknown worker: " + workerId) } @@ -269,7 +356,7 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act case RequestMasterState => { sender ! MasterStateResponse(host, port, workers.toArray, apps.toArray, completedApps.toArray, - state) + drivers.toArray, completedDrivers.toArray, state) } case CheckForWorkerTimeOut => { @@ -285,7 +372,8 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act workers.count(_.state == WorkerState.UNKNOWN) == 0 && apps.count(_.state == ApplicationState.UNKNOWN) == 0 - def beginRecovery(storedApps: Seq[ApplicationInfo], storedWorkers: Seq[WorkerInfo]) { + def beginRecovery(storedApps: Seq[ApplicationInfo], storedDrivers: Seq[DriverInfo], + storedWorkers: Seq[WorkerInfo]) { for (app <- storedApps) { logInfo("Trying to recover app: " + app.id) try { @@ -297,6 +385,12 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act } } + for (driver <- storedDrivers) { + // Here we just read in the list of drivers. Any drivers associated with now-lost workers + // will be re-launched when we detect that the worker is missing. + drivers += driver + } + for (worker <- storedWorkers) { logInfo("Trying to recover worker: " + worker.id) try { @@ -320,6 +414,18 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act workers.filter(_.state == WorkerState.UNKNOWN).foreach(removeWorker) apps.filter(_.state == ApplicationState.UNKNOWN).foreach(finishApplication) + // Reschedule drivers which were not claimed by any workers + drivers.filter(_.worker.isEmpty).foreach { d => + logWarning(s"Driver ${d.id} was not found after master recovery") + if (d.desc.supervise) { + logWarning(s"Re-launching ${d.id}") + relaunchDriver(d) + } else { + removeDriver(d.id, DriverState.ERROR, None) + logWarning(s"Did not re-launch ${d.id} because it was not supervised") + } + } + state = RecoveryState.ALIVE schedule() logInfo("Recovery complete - resuming operations!") @@ -340,6 +446,18 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act */ def schedule() { if (state != RecoveryState.ALIVE) { return } + + // First schedule drivers, they take strict precedence over applications + val shuffledWorkers = Random.shuffle(workers) // Randomization helps balance drivers + for (worker <- shuffledWorkers if worker.state == WorkerState.ALIVE) { + for (driver <- waitingDrivers) { + if (worker.memoryFree >= driver.desc.mem && worker.coresFree >= driver.desc.cores) { + launchDriver(worker, driver) + waitingDrivers -= driver + } + } + } + // Right now this is a very simple FIFO scheduler. We keep trying to fit in the first app // in the queue, then the second app, etc. if (spreadOutApps) { @@ -426,9 +544,25 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act exec.id, ExecutorState.LOST, Some("worker lost"), None) exec.application.removeExecutor(exec) } + for (driver <- worker.drivers.values) { + if (driver.desc.supervise) { + logInfo(s"Re-launching ${driver.id}") + relaunchDriver(driver) + } else { + logInfo(s"Not re-launching ${driver.id} because it was not supervised") + removeDriver(driver.id, DriverState.ERROR, None) + } + } persistenceEngine.removeWorker(worker) } + def relaunchDriver(driver: DriverInfo) { + driver.worker = None + driver.state = DriverState.RELAUNCHING + waitingDrivers += driver + schedule() + } + def createApplication(desc: ApplicationDescription, driver: ActorRef): ApplicationInfo = { val now = System.currentTimeMillis() val date = new Date(now) @@ -508,6 +642,41 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act } } } + + def newDriverId(submitDate: Date): String = { + val appId = "driver-%s-%04d".format(DATE_FORMAT.format(submitDate), nextDriverNumber) + nextDriverNumber += 1 + appId + } + + def createDriver(desc: DriverDescription): DriverInfo = { + val now = System.currentTimeMillis() + val date = new Date(now) + new DriverInfo(now, newDriverId(date), desc, date) + } + + def launchDriver(worker: WorkerInfo, driver: DriverInfo) { + logInfo("Launching driver " + driver.id + " on worker " + worker.id) + worker.addDriver(driver) + driver.worker = Some(worker) + worker.actor ! LaunchDriver(driver.id, driver.desc) + driver.state = DriverState.RUNNING + } + + def removeDriver(driverId: String, finalState: DriverState, exception: Option[Exception]) { + drivers.find(d => d.id == driverId) match { + case Some(driver) => + logInfo(s"Removing driver: $driverId") + drivers -= driver + completedDrivers += driver + persistenceEngine.removeDriver(driver) + driver.state = finalState + driver.exception = exception + driver.worker.foreach(w => w.removeDriver(driver)) + case None => + logWarning(s"Asked to remove unknown driver: $driverId") + } + } } private[spark] object Master { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala index 94b986caf2..e3640ea4f7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala @@ -35,11 +35,15 @@ private[spark] trait PersistenceEngine { def removeWorker(worker: WorkerInfo) + def addDriver(driver: DriverInfo) + + def removeDriver(driver: DriverInfo) + /** * Returns the persisted data sorted by their respective ids (which implies that they're * sorted by time of creation). */ - def readPersistedData(): (Seq[ApplicationInfo], Seq[WorkerInfo]) + def readPersistedData(): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) def close() {} } @@ -49,5 +53,8 @@ private[spark] class BlackHolePersistenceEngine extends PersistenceEngine { override def removeApplication(app: ApplicationInfo) {} override def addWorker(worker: WorkerInfo) {} override def removeWorker(worker: WorkerInfo) {} - override def readPersistedData() = (Nil, Nil) + override def addDriver(driver: DriverInfo) {} + override def removeDriver(driver: DriverInfo) {} + + override def readPersistedData() = (Nil, Nil, Nil) } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala index e05f587b58..c5fa9cf7d7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala @@ -17,8 +17,10 @@ package org.apache.spark.deploy.master -import akka.actor.ActorRef import scala.collection.mutable + +import akka.actor.ActorRef + import org.apache.spark.util.Utils private[spark] class WorkerInfo( @@ -35,7 +37,8 @@ private[spark] class WorkerInfo( Utils.checkHost(host, "Expected hostname") assert (port > 0) - @transient var executors: mutable.HashMap[String, ExecutorInfo] = _ // fullId => info + @transient var executors: mutable.HashMap[String, ExecutorInfo] = _ // executorId => info + @transient var drivers: mutable.HashMap[String, DriverInfo] = _ // driverId => info @transient var state: WorkerState.Value = _ @transient var coresUsed: Int = _ @transient var memoryUsed: Int = _ @@ -54,6 +57,7 @@ private[spark] class WorkerInfo( private def init() { executors = new mutable.HashMap + drivers = new mutable.HashMap state = WorkerState.ALIVE coresUsed = 0 memoryUsed = 0 @@ -83,6 +87,18 @@ private[spark] class WorkerInfo( executors.values.exists(_.application == app) } + def addDriver(driver: DriverInfo) { + drivers(driver.id) = driver + memoryUsed += driver.desc.mem + coresUsed += driver.desc.cores + } + + def removeDriver(driver: DriverInfo) { + drivers -= driver.id + memoryUsed -= driver.desc.mem + coresUsed -= driver.desc.cores + } + def webUiAddress : String = { "http://" + this.publicAddress + ":" + this.webUiPort } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala index 52000d4f9c..f24f49ea8a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala @@ -49,6 +49,14 @@ class ZooKeeperPersistenceEngine(serialization: Serialization, conf: SparkConf) zk.delete(WORKING_DIR + "/app_" + app.id) } + override def addDriver(driver: DriverInfo) { + serializeIntoFile(WORKING_DIR + "/driver_" + driver.id, driver) + } + + override def removeDriver(driver: DriverInfo) { + zk.delete(WORKING_DIR + "/driver_" + driver.id) + } + override def addWorker(worker: WorkerInfo) { serializeIntoFile(WORKING_DIR + "/worker_" + worker.id, worker) } @@ -61,13 +69,15 @@ class ZooKeeperPersistenceEngine(serialization: Serialization, conf: SparkConf) zk.close() } - override def readPersistedData(): (Seq[ApplicationInfo], Seq[WorkerInfo]) = { + override def readPersistedData(): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) = { val sortedFiles = zk.getChildren(WORKING_DIR).toList.sorted val appFiles = sortedFiles.filter(_.startsWith("app_")) val apps = appFiles.map(deserializeFromFile[ApplicationInfo]) + val driverFiles = sortedFiles.filter(_.startsWith("driver_")) + val drivers = driverFiles.map(deserializeFromFile[DriverInfo]) val workerFiles = sortedFiles.filter(_.startsWith("worker_")) val workers = workerFiles.map(deserializeFromFile[WorkerInfo]) - (apps, workers) + (apps, drivers, workers) } private def serializeIntoFile(path: String, value: AnyRef) { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala index dbb0cb90f5..9485bfd89e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala @@ -67,11 +67,11 @@ private[spark] class ApplicationPage(parent: MasterWebUI) { <li><strong>User:</strong> {app.desc.user}</li> <li><strong>Cores:</strong> { - if (app.desc.maxCores == Integer.MAX_VALUE) { + if (app.desc.maxCores == None) { "Unlimited (%s granted)".format(app.coresGranted) } else { "%s (%s granted, %s left)".format( - app.desc.maxCores, app.coresGranted, app.coresLeft) + app.desc.maxCores.get, app.coresGranted, app.coresLeft) } } </li> diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala index 4ef762892c..a9af8df552 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala @@ -18,6 +18,7 @@ package org.apache.spark.deploy.master.ui import scala.concurrent.Await +import scala.concurrent.duration._ import scala.xml.Node import akka.pattern.ask @@ -26,7 +27,7 @@ import net.liftweb.json.JsonAST.JValue import org.apache.spark.deploy.{DeployWebUI, JsonProtocol} import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState} -import org.apache.spark.deploy.master.{ApplicationInfo, WorkerInfo} +import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, WorkerInfo} import org.apache.spark.ui.UIUtils import org.apache.spark.util.Utils @@ -56,6 +57,16 @@ private[spark] class IndexPage(parent: MasterWebUI) { val completedApps = state.completedApps.sortBy(_.endTime).reverse val completedAppsTable = UIUtils.listingTable(appHeaders, appRow, completedApps) + val driverHeaders = Seq("ID", "Submitted Time", "Worker", "State", "Cores", "Memory", "Main Class") + val activeDrivers = state.activeDrivers.sortBy(_.startTime).reverse + val activeDriversTable = UIUtils.listingTable(driverHeaders, driverRow, activeDrivers) + val completedDrivers = state.completedDrivers.sortBy(_.startTime).reverse + val completedDriversTable = UIUtils.listingTable(driverHeaders, driverRow, completedDrivers) + + // For now we only show driver information if the user has submitted drivers to the cluster. + // This is until we integrate the notion of drivers and applications in the UI. + def hasDrivers = activeDrivers.length > 0 || completedDrivers.length > 0 + val content = <div class="row-fluid"> <div class="span12"> @@ -70,6 +81,9 @@ private[spark] class IndexPage(parent: MasterWebUI) { <li><strong>Applications:</strong> {state.activeApps.size} Running, {state.completedApps.size} Completed </li> + <li><strong>Drivers:</strong> + {state.activeDrivers.size} Running, + {state.completedDrivers.size} Completed </li> </ul> </div> </div> @@ -84,17 +98,39 @@ private[spark] class IndexPage(parent: MasterWebUI) { <div class="row-fluid"> <div class="span12"> <h4> Running Applications </h4> - {activeAppsTable} </div> </div> + <div> + {if (hasDrivers) + <div class="row-fluid"> + <div class="span12"> + <h4> Running Drivers </h4> + {activeDriversTable} + </div> + </div> + } + </div> + <div class="row-fluid"> <div class="span12"> <h4> Completed Applications </h4> {completedAppsTable} </div> + </div> + + <div> + {if (hasDrivers) + <div class="row-fluid"> + <div class="span12"> + <h4> Completed Drivers </h4> + {completedDriversTable} + </div> + </div> + } </div>; + UIUtils.basicSparkPage(content, "Spark Master at " + state.uri) } @@ -134,4 +170,20 @@ private[spark] class IndexPage(parent: MasterWebUI) { <td>{DeployWebUI.formatDuration(app.duration)}</td> </tr> } + + def driverRow(driver: DriverInfo): Seq[Node] = { + <tr> + <td>{driver.id} </td> + <td>{driver.submitDate}</td> + <td>{driver.worker.map(w => <a href={w.webUiAddress}>{w.id.toString}</a>).getOrElse("None")}</td> + <td>{driver.state}</td> + <td sorttable_customkey={driver.desc.cores.toString}> + {driver.desc.cores} + </td> + <td sorttable_customkey={driver.desc.mem.toString}> + {Utils.megabytesToString(driver.desc.mem.toLong)} + </td> + <td>{driver.desc.command.arguments(1)}</td> + </tr> + } } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala new file mode 100644 index 0000000000..7507bf8ad0 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala @@ -0,0 +1,63 @@ +package org.apache.spark.deploy.worker + +import java.io.{File, FileOutputStream, IOException, InputStream} +import java.lang.System._ + +import org.apache.spark.Logging +import org.apache.spark.deploy.Command +import org.apache.spark.util.Utils + +/** + ** Utilities for running commands with the spark classpath. + */ +object CommandUtils extends Logging { + private[spark] def buildCommandSeq(command: Command, memory: Int, sparkHome: String): Seq[String] = { + val runner = getEnv("JAVA_HOME", command).map(_ + "/bin/java").getOrElse("java") + + // SPARK-698: do not call the run.cmd script, as process.destroy() + // fails to kill a process tree on Windows + Seq(runner) ++ buildJavaOpts(command, memory, sparkHome) ++ Seq(command.mainClass) ++ + command.arguments + } + + private def getEnv(key: String, command: Command): Option[String] = + command.environment.get(key).orElse(Option(System.getenv(key))) + + /** + * Attention: this must always be aligned with the environment variables in the run scripts and + * the way the JAVA_OPTS are assembled there. + */ + def buildJavaOpts(command: Command, memory: Int, sparkHome: String): Seq[String] = { + val libraryOpts = getEnv("SPARK_LIBRARY_PATH", command) + .map(p => List("-Djava.library.path=" + p)) + .getOrElse(Nil) + val workerLocalOpts = Option(getenv("SPARK_JAVA_OPTS")).map(Utils.splitCommandString).getOrElse(Nil) + val userOpts = getEnv("SPARK_JAVA_OPTS", command).map(Utils.splitCommandString).getOrElse(Nil) + val memoryOpts = Seq(s"-Xms${memory}M", s"-Xmx${memory}M") + + // Figure out our classpath with the external compute-classpath script + val ext = if (System.getProperty("os.name").startsWith("Windows")) ".cmd" else ".sh" + val classPath = Utils.executeAndGetOutput( + Seq(sparkHome + "/bin/compute-classpath" + ext), + extraEnvironment=command.environment) + + Seq("-cp", classPath) ++ libraryOpts ++ workerLocalOpts ++ userOpts ++ memoryOpts + } + + /** Spawn a thread that will redirect a given stream to a file */ + def redirectStream(in: InputStream, file: File) { + val out = new FileOutputStream(file, true) + // TODO: It would be nice to add a shutdown hook here that explains why the output is + // terminating. Otherwise if the worker dies the executor logs will silently stop. + new Thread("redirect output to " + file) { + override def run() { + try { + Utils.copyStream(in, out, true) + } catch { + case e: IOException => + logInfo("Redirection to " + file + " closed: " + e.getMessage) + } + } + }.start() + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala new file mode 100644 index 0000000000..b4df1a0dd4 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala @@ -0,0 +1,234 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.worker + +import java.io._ + +import scala.collection.JavaConversions._ +import scala.collection.mutable.Map + +import akka.actor.ActorRef +import com.google.common.base.Charsets +import com.google.common.io.Files +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileUtil, Path} + +import org.apache.spark.Logging +import org.apache.spark.deploy.{Command, DriverDescription} +import org.apache.spark.deploy.DeployMessages.DriverStateChanged +import org.apache.spark.deploy.master.DriverState +import org.apache.spark.deploy.master.DriverState.DriverState + +/** + * Manages the execution of one driver, including automatically restarting the driver on failure. + */ +private[spark] class DriverRunner( + val driverId: String, + val workDir: File, + val sparkHome: File, + val driverDesc: DriverDescription, + val worker: ActorRef, + val workerUrl: String) + extends Logging { + + @volatile var process: Option[Process] = None + @volatile var killed = false + + // Populated once finished + var finalState: Option[DriverState] = None + var finalException: Option[Exception] = None + var finalExitCode: Option[Int] = None + + // Decoupled for testing + private[deploy] def setClock(_clock: Clock) = clock = _clock + private[deploy] def setSleeper(_sleeper: Sleeper) = sleeper = _sleeper + private var clock = new Clock { + def currentTimeMillis(): Long = System.currentTimeMillis() + } + private var sleeper = new Sleeper { + def sleep(seconds: Int): Unit = (0 until seconds).takeWhile(f => {Thread.sleep(1000); !killed}) + } + + /** Starts a thread to run and manage the driver. */ + def start() = { + new Thread("DriverRunner for " + driverId) { + override def run() { + try { + val driverDir = createWorkingDirectory() + val localJarFilename = downloadUserJar(driverDir) + + // Make sure user application jar is on the classpath + // TODO: If we add ability to submit multiple jars they should also be added here + val env = Map(driverDesc.command.environment.toSeq: _*) + env("SPARK_CLASSPATH") = env.getOrElse("SPARK_CLASSPATH", "") + s":$localJarFilename" + val newCommand = Command(driverDesc.command.mainClass, + driverDesc.command.arguments.map(substituteVariables), env) + val command = CommandUtils.buildCommandSeq(newCommand, driverDesc.mem, + sparkHome.getAbsolutePath) + launchDriver(command, env, driverDir, driverDesc.supervise) + } + catch { + case e: Exception => finalException = Some(e) + } + + val state = + if (killed) { DriverState.KILLED } + else if (finalException.isDefined) { DriverState.ERROR } + else { + finalExitCode match { + case Some(0) => DriverState.FINISHED + case _ => DriverState.FAILED + } + } + + finalState = Some(state) + + worker ! DriverStateChanged(driverId, state, finalException) + } + }.start() + } + + /** Terminate this driver (or prevent it from ever starting if not yet started) */ + def kill() { + synchronized { + process.foreach(p => p.destroy()) + killed = true + } + } + + /** Replace variables in a command argument passed to us */ + private def substituteVariables(argument: String): String = argument match { + case "{{WORKER_URL}}" => workerUrl + case other => other + } + + /** + * Creates the working directory for this driver. + * Will throw an exception if there are errors preparing the directory. + */ + private def createWorkingDirectory(): File = { + val driverDir = new File(workDir, driverId) + if (!driverDir.exists() && !driverDir.mkdirs()) { + throw new IOException("Failed to create directory " + driverDir) + } + driverDir + } + + /** + * Download the user jar into the supplied directory and return its local path. + * Will throw an exception if there are errors downloading the jar. + */ + private def downloadUserJar(driverDir: File): String = { + + val jarPath = new Path(driverDesc.jarUrl) + + val emptyConf = new Configuration() + val jarFileSystem = jarPath.getFileSystem(emptyConf) + + val destPath = new File(driverDir.getAbsolutePath, jarPath.getName) + val jarFileName = jarPath.getName + val localJarFile = new File(driverDir, jarFileName) + val localJarFilename = localJarFile.getAbsolutePath + + if (!localJarFile.exists()) { // May already exist if running multiple workers on one node + logInfo(s"Copying user jar $jarPath to $destPath") + FileUtil.copy(jarFileSystem, jarPath, destPath, false, emptyConf) + } + + if (!localJarFile.exists()) { // Verify copy succeeded + throw new Exception(s"Did not see expected jar $jarFileName in $driverDir") + } + + localJarFilename + } + + private def launchDriver(command: Seq[String], envVars: Map[String, String], baseDir: File, + supervise: Boolean) { + val builder = new ProcessBuilder(command: _*).directory(baseDir) + envVars.map{ case(k,v) => builder.environment().put(k, v) } + + def initialize(process: Process) = { + // Redirect stdout and stderr to files + val stdout = new File(baseDir, "stdout") + CommandUtils.redirectStream(process.getInputStream, stdout) + + val stderr = new File(baseDir, "stderr") + val header = "Launch Command: %s\n%s\n\n".format( + command.mkString("\"", "\" \"", "\""), "=" * 40) + Files.append(header, stderr, Charsets.UTF_8) + CommandUtils.redirectStream(process.getErrorStream, stderr) + } + runCommandWithRetry(ProcessBuilderLike(builder), initialize, supervise) + } + + private[deploy] def runCommandWithRetry(command: ProcessBuilderLike, initialize: Process => Unit, + supervise: Boolean) { + // Time to wait between submission retries. + var waitSeconds = 1 + // A run of this many seconds resets the exponential back-off. + val successfulRunDuration = 5 + + var keepTrying = !killed + + while (keepTrying) { + logInfo("Launch Command: " + command.command.mkString("\"", "\" \"", "\"")) + + synchronized { + if (killed) { return } + process = Some(command.start()) + initialize(process.get) + } + + val processStart = clock.currentTimeMillis() + val exitCode = process.get.waitFor() + if (clock.currentTimeMillis() - processStart > successfulRunDuration * 1000) { + waitSeconds = 1 + } + + if (supervise && exitCode != 0 && !killed) { + logInfo(s"Command exited with status $exitCode, re-launching after $waitSeconds s.") + sleeper.sleep(waitSeconds) + waitSeconds = waitSeconds * 2 // exponential back-off + } + + keepTrying = supervise && exitCode != 0 && !killed + finalExitCode = Some(exitCode) + } + } +} + +private[deploy] trait Clock { + def currentTimeMillis(): Long +} + +private[deploy] trait Sleeper { + def sleep(seconds: Int) +} + +// Needed because ProcessBuilder is a final class and cannot be mocked +private[deploy] trait ProcessBuilderLike { + def start(): Process + def command: Seq[String] +} + +private[deploy] object ProcessBuilderLike { + def apply(processBuilder: ProcessBuilder) = new ProcessBuilderLike { + def start() = processBuilder.start() + def command = processBuilder.command() + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala new file mode 100644 index 0000000000..1640d5fee0 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala @@ -0,0 +1,31 @@ +package org.apache.spark.deploy.worker + +import akka.actor._ + +import org.apache.spark.SparkConf +import org.apache.spark.util.{AkkaUtils, Utils} + +/** + * Utility object for launching driver programs such that they share fate with the Worker process. + */ +object DriverWrapper { + def main(args: Array[String]) { + args.toList match { + case workerUrl :: mainClass :: extraArgs => + val (actorSystem, _) = AkkaUtils.createActorSystem("Driver", + Utils.localHostName(), 0, false, new SparkConf()) + actorSystem.actorOf(Props(classOf[WorkerWatcher], workerUrl), name = "workerWatcher") + + // Delegate to supplied main class + val clazz = Class.forName(args(1)) + val mainMethod = clazz.getMethod("main", classOf[Array[String]]) + mainMethod.invoke(null, extraArgs.toArray[String]) + + actorSystem.shutdown() + + case _ => + System.err.println("Usage: DriverWrapper <workerUrl> <driverMainClass> [options]") + System.exit(-1) + } + } +}
\ No newline at end of file diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index fff9cb60c7..18885d7ca6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -18,17 +18,15 @@ package org.apache.spark.deploy.worker import java.io._ -import java.lang.System.getenv import akka.actor.ActorRef import com.google.common.base.Charsets import com.google.common.io.Files -import org.apache.spark.{Logging} -import org.apache.spark.deploy.{ExecutorState, ApplicationDescription} +import org.apache.spark.Logging +import org.apache.spark.deploy.{ExecutorState, ApplicationDescription, Command} import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged -import org.apache.spark.util.Utils /** * Manages the execution of one executor process. @@ -44,16 +42,17 @@ private[spark] class ExecutorRunner( val host: String, val sparkHome: File, val workDir: File, + val workerUrl: String, var state: ExecutorState.Value) extends Logging { val fullId = appId + "/" + execId var workerThread: Thread = null var process: Process = null - var shutdownHook: Thread = null - private def getAppEnv(key: String): Option[String] = - appDesc.command.environment.get(key).orElse(Option(getenv(key))) + // NOTE: This is now redundant with the automated shut-down enforced by the Executor. It might + // make sense to remove this in the future. + var shutdownHook: Thread = null def start() { workerThread = new Thread("ExecutorRunner for " + fullId) { @@ -92,55 +91,17 @@ private[spark] class ExecutorRunner( /** Replace variables such as {{EXECUTOR_ID}} and {{CORES}} in a command argument passed to us */ def substituteVariables(argument: String): String = argument match { + case "{{WORKER_URL}}" => workerUrl case "{{EXECUTOR_ID}}" => execId.toString case "{{HOSTNAME}}" => host case "{{CORES}}" => cores.toString case other => other } - def buildCommandSeq(): Seq[String] = { - val command = appDesc.command - val runner = getAppEnv("JAVA_HOME").map(_ + "/bin/java").getOrElse("java") - // SPARK-698: do not call the run.cmd script, as process.destroy() - // fails to kill a process tree on Windows - Seq(runner) ++ buildJavaOpts() ++ Seq(command.mainClass) ++ - (command.arguments ++ Seq(appId)).map(substituteVariables) - } - - /** - * Attention: this must always be aligned with the environment variables in the run scripts and - * the way the JAVA_OPTS are assembled there. - */ - def buildJavaOpts(): Seq[String] = { - val libraryOpts = getAppEnv("SPARK_LIBRARY_PATH") - .map(p => List("-Djava.library.path=" + p)) - .getOrElse(Nil) - val workerLocalOpts = Option(getenv("SPARK_JAVA_OPTS")).map(Utils.splitCommandString).getOrElse(Nil) - val userOpts = getAppEnv("SPARK_JAVA_OPTS").map(Utils.splitCommandString).getOrElse(Nil) - val memoryOpts = Seq("-Xms" + memory + "M", "-Xmx" + memory + "M") - - // Figure out our classpath with the external compute-classpath script - val ext = if (System.getProperty("os.name").startsWith("Windows")) ".cmd" else ".sh" - val classPath = Utils.executeAndGetOutput( - Seq(sparkHome + "/bin/compute-classpath" + ext), - extraEnvironment=appDesc.command.environment) - - Seq("-cp", classPath) ++ libraryOpts ++ workerLocalOpts ++ userOpts ++ memoryOpts - } - - /** Spawn a thread that will redirect a given stream to a file */ - def redirectStream(in: InputStream, file: File) { - val out = new FileOutputStream(file, true) - new Thread("redirect output to " + file) { - override def run() { - try { - Utils.copyStream(in, out, true) - } catch { - case e: IOException => - logInfo("Redirection to " + file + " closed: " + e.getMessage) - } - } - }.start() + def getCommandSeq = { + val command = Command(appDesc.command.mainClass, + appDesc.command.arguments.map(substituteVariables) ++ Seq(appId), appDesc.command.environment) + CommandUtils.buildCommandSeq(command, memory, sparkHome.getAbsolutePath) } /** @@ -155,7 +116,7 @@ private[spark] class ExecutorRunner( } // Launch the process - val command = buildCommandSeq() + val command = getCommandSeq logInfo("Launch command: " + command.mkString("\"", "\" \"", "\"")) val builder = new ProcessBuilder(command: _*).directory(executorDir) val env = builder.environment() @@ -172,11 +133,11 @@ private[spark] class ExecutorRunner( // Redirect its stdout and stderr to files val stdout = new File(executorDir, "stdout") - redirectStream(process.getInputStream, stdout) + CommandUtils.redirectStream(process.getInputStream, stdout) val stderr = new File(executorDir, "stderr") Files.write(header, stderr, Charsets.UTF_8) - redirectStream(process.getErrorStream, stderr) + CommandUtils.redirectStream(process.getErrorStream, stderr) // Wait for it to exit; this is actually a bad thing if it happens, because we expect to run // long-lived processes only. However, in the future, we might restart the executor a few diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index fcaf4e92b1..5182dcbb2a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -26,10 +26,12 @@ import scala.concurrent.duration._ import akka.actor._ import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} + import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.deploy.{ExecutorDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ -import org.apache.spark.deploy.master.Master +import org.apache.spark.deploy.master.{DriverState, Master} +import org.apache.spark.deploy.master.DriverState.DriverState import org.apache.spark.deploy.worker.ui.WorkerWebUI import org.apache.spark.metrics.MetricsSystem import org.apache.spark.util.{AkkaUtils, Utils} @@ -44,6 +46,8 @@ private[spark] class Worker( cores: Int, memory: Int, masterUrls: Array[String], + actorSystemName: String, + actorName: String, workDirPath: String = null, val conf: SparkConf) extends Actor with Logging { @@ -55,7 +59,7 @@ private[spark] class Worker( val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For worker and executor IDs // Send a heartbeat every (heartbeat timeout) / 4 milliseconds - val HEARTBEAT_MILLIS = conf.get("spark.worker.timeout", "60").toLong * 1000 / 4 + val HEARTBEAT_MILLIS = conf.getLong("spark.worker.timeout", 60) * 1000 / 4 val REGISTRATION_TIMEOUT = 20.seconds val REGISTRATION_RETRIES = 3 @@ -68,6 +72,7 @@ private[spark] class Worker( var masterAddress: Address = null var activeMasterUrl: String = "" var activeMasterWebUiUrl : String = "" + val akkaUrl = "akka.tcp://%s@%s:%s/user/%s".format(actorSystemName, host, port, actorName) @volatile var registered = false @volatile var connected = false val workerId = generateWorkerId() @@ -75,6 +80,9 @@ private[spark] class Worker( var workDir: File = null val executors = new HashMap[String, ExecutorRunner] val finishedExecutors = new HashMap[String, ExecutorRunner] + val drivers = new HashMap[String, DriverRunner] + val finishedDrivers = new HashMap[String, DriverRunner] + val publicAddress = { val envVar = System.getenv("SPARK_PUBLIC_DNS") if (envVar != null) envVar else host @@ -185,7 +193,10 @@ private[spark] class Worker( val execs = executors.values. map(e => new ExecutorDescription(e.appId, e.execId, e.cores, e.state)) - sender ! WorkerSchedulerStateResponse(workerId, execs.toList) + sender ! WorkerSchedulerStateResponse(workerId, execs.toList, drivers.keys.toSeq) + + case Heartbeat => + logInfo(s"Received heartbeat from driver ${sender.path}") case RegisterWorkerFailed(message) => if (!registered) { @@ -199,7 +210,7 @@ private[spark] class Worker( } else { logInfo("Asked to launch executor %s/%d for %s".format(appId, execId, appDesc.name)) val manager = new ExecutorRunner(appId, execId, appDesc, cores_, memory_, - self, workerId, host, new File(execSparkHome_), workDir, ExecutorState.RUNNING) + self, workerId, host, new File(execSparkHome_), workDir, akkaUrl, ExecutorState.RUNNING) executors(appId + "/" + execId) = manager manager.start() coresUsed += cores_ @@ -219,8 +230,8 @@ private[spark] class Worker( logInfo("Executor " + fullId + " finished with state " + state + message.map(" message " + _).getOrElse("") + exitStatus.map(" exitStatus " + _).getOrElse("")) - finishedExecutors(fullId) = executor executors -= fullId + finishedExecutors(fullId) = executor coresUsed -= executor.cores memoryUsed -= executor.memory } @@ -239,13 +250,52 @@ private[spark] class Worker( } } + case LaunchDriver(driverId, driverDesc) => { + logInfo(s"Asked to launch driver $driverId") + val driver = new DriverRunner(driverId, workDir, sparkHome, driverDesc, self, akkaUrl) + drivers(driverId) = driver + driver.start() + + coresUsed += driverDesc.cores + memoryUsed += driverDesc.mem + } + + case KillDriver(driverId) => { + logInfo(s"Asked to kill driver $driverId") + drivers.get(driverId) match { + case Some(runner) => + runner.kill() + case None => + logError(s"Asked to kill unknown driver $driverId") + } + } + + case DriverStateChanged(driverId, state, exception) => { + state match { + case DriverState.ERROR => + logWarning(s"Driver $driverId failed with unrecoverable exception: ${exception.get}") + case DriverState.FINISHED => + logInfo(s"Driver $driverId exited successfully") + case DriverState.KILLED => + logInfo(s"Driver $driverId was killed by user") + } + masterLock.synchronized { + master ! DriverStateChanged(driverId, state, exception) + } + val driver = drivers.remove(driverId).get + finishedDrivers(driverId) = driver + memoryUsed -= driver.driverDesc.mem + coresUsed -= driver.driverDesc.cores + } + case x: DisassociatedEvent if x.remoteAddress == masterAddress => logInfo(s"$x Disassociated !") masterDisconnected() case RequestWorkerState => { sender ! WorkerStateResponse(host, port, workerId, executors.values.toList, - finishedExecutors.values.toList, activeMasterUrl, cores, memory, + finishedExecutors.values.toList, drivers.values.toList, + finishedDrivers.values.toList, activeMasterUrl, cores, memory, coresUsed, memoryUsed, activeMasterWebUiUrl) } } @@ -282,10 +332,11 @@ private[spark] object Worker { // The LocalSparkCluster runs multiple local sparkWorkerX actor systems val conf = new SparkConf val systemName = "sparkWorker" + workerNumber.map(_.toString).getOrElse("") + val actorName = "Worker" val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, conf = conf) actorSystem.actorOf(Props(classOf[Worker], host, boundPort, webUiPort, cores, memory, - masterUrls, workDir, conf), name = "Worker") + masterUrls, systemName, actorName, workDir, conf), name = actorName) (actorSystem, boundPort) } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala new file mode 100644 index 0000000000..0e0d0cd626 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala @@ -0,0 +1,55 @@ +package org.apache.spark.deploy.worker + +import akka.actor.{Actor, Address, AddressFromURIString} +import akka.remote.{AssociatedEvent, AssociationErrorEvent, AssociationEvent, DisassociatedEvent, RemotingLifecycleEvent} + +import org.apache.spark.Logging +import org.apache.spark.deploy.DeployMessages.SendHeartbeat + +/** + * Actor which connects to a worker process and terminates the JVM if the connection is severed. + * Provides fate sharing between a worker and its associated child processes. + */ +private[spark] class WorkerWatcher(workerUrl: String) extends Actor + with Logging { + override def preStart() { + context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + + logInfo(s"Connecting to worker $workerUrl") + val worker = context.actorSelection(workerUrl) + worker ! SendHeartbeat // need to send a message here to initiate connection + } + + // Used to avoid shutting down JVM during tests + private[deploy] var isShutDown = false + private[deploy] def setTesting(testing: Boolean) = isTesting = testing + private var isTesting = false + + // Lets us filter events only from the worker's actor system + private val expectedHostPort = AddressFromURIString(workerUrl).hostPort + private def isWorker(address: Address) = address.hostPort == expectedHostPort + + def exitNonZero() = if (isTesting) isShutDown = true else System.exit(-1) + + override def receive = { + case AssociatedEvent(localAddress, remoteAddress, inbound) if isWorker(remoteAddress) => + logInfo(s"Successfully connected to $workerUrl") + + case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound) + if isWorker(remoteAddress) => + // These logs may not be seen if the worker (and associated pipe) has died + logError(s"Could not initialize connection to worker $workerUrl. Exiting.") + logError(s"Error was: $cause") + exitNonZero() + + case DisassociatedEvent(localAddress, remoteAddress, inbound) if isWorker(remoteAddress) => + // This log message will never be seen + logError(s"Lost connection to worker actor $workerUrl. Exiting.") + exitNonZero() + + case e: AssociationEvent => + // pass through association events relating to other remote actor systems + + case e => logWarning(s"Received unexpected actor system event: $e") + } +}
\ No newline at end of file diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala index 0d59048313..925c6fb183 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala @@ -17,24 +17,20 @@ package org.apache.spark.deploy.worker.ui -import javax.servlet.http.HttpServletRequest - -import scala.xml.Node - -import scala.concurrent.duration._ import scala.concurrent.Await +import scala.xml.Node import akka.pattern.ask - +import javax.servlet.http.HttpServletRequest import net.liftweb.json.JsonAST.JValue import org.apache.spark.deploy.JsonProtocol import org.apache.spark.deploy.DeployMessages.{RequestWorkerState, WorkerStateResponse} -import org.apache.spark.deploy.worker.ExecutorRunner +import org.apache.spark.deploy.master.DriverState +import org.apache.spark.deploy.worker.{DriverRunner, ExecutorRunner} import org.apache.spark.ui.UIUtils import org.apache.spark.util.Utils - private[spark] class IndexPage(parent: WorkerWebUI) { val workerActor = parent.worker.self val worker = parent.worker @@ -56,6 +52,16 @@ private[spark] class IndexPage(parent: WorkerWebUI) { val finishedExecutorTable = UIUtils.listingTable(executorHeaders, executorRow, workerState.finishedExecutors) + val driverHeaders = Seq("DriverID", "Main Class", "State", "Cores", "Memory", "Logs", "Notes") + val runningDrivers = workerState.drivers.sortBy(_.driverId).reverse + val runningDriverTable = UIUtils.listingTable(driverHeaders, driverRow, runningDrivers) + val finishedDrivers = workerState.finishedDrivers.sortBy(_.driverId).reverse + def finishedDriverTable = UIUtils.listingTable(driverHeaders, driverRow, finishedDrivers) + + // For now we only show driver information if the user has submitted drivers to the cluster. + // This is until we integrate the notion of drivers and applications in the UI. + def hasDrivers = runningDrivers.length > 0 || finishedDrivers.length > 0 + val content = <div class="row-fluid"> <!-- Worker Details --> <div class="span12"> @@ -79,11 +85,33 @@ private[spark] class IndexPage(parent: WorkerWebUI) { </div> </div> + <div> + {if (hasDrivers) + <div class="row-fluid"> <!-- Running Drivers --> + <div class="span12"> + <h4> Running Drivers {workerState.drivers.size} </h4> + {runningDriverTable} + </div> + </div> + } + </div> + <div class="row-fluid"> <!-- Finished Executors --> <div class="span12"> <h4> Finished Executors </h4> {finishedExecutorTable} </div> + </div> + + <div> + {if (hasDrivers) + <div class="row-fluid"> <!-- Finished Drivers --> + <div class="span12"> + <h4> Finished Drivers </h4> + {finishedDriverTable} + </div> + </div> + } </div>; UIUtils.basicSparkPage(content, "Spark Worker at %s:%s".format( @@ -111,6 +139,27 @@ private[spark] class IndexPage(parent: WorkerWebUI) { .format(executor.appId, executor.execId)}>stderr</a> </td> </tr> + } + def driverRow(driver: DriverRunner): Seq[Node] = { + <tr> + <td>{driver.driverId}</td> + <td>{driver.driverDesc.command.arguments(1)}</td> + <td>{driver.finalState.getOrElse(DriverState.RUNNING)}</td> + <td sorttable_customkey={driver.driverDesc.cores.toString}> + {driver.driverDesc.cores.toString} + </td> + <td sorttable_customkey={driver.driverDesc.mem.toString}> + {Utils.megabytesToString(driver.driverDesc.mem)} + </td> + <td> + <a href={s"logPage?driverId=${driver.driverId}&logType=stdout"}>stdout</a> + <a href={s"logPage?driverId=${driver.driverId}&logType=stderr"}>stderr</a> + </td> + <td> + {driver.finalException.getOrElse("")} + </td> + </tr> + } } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala index c382034c99..8daa47b2b2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala @@ -69,30 +69,48 @@ class WorkerWebUI(val worker: Worker, val workDir: File, requestedPort: Option[I def log(request: HttpServletRequest): String = { val defaultBytes = 100 * 1024 - val appId = request.getParameter("appId") - val executorId = request.getParameter("executorId") + + val appId = Option(request.getParameter("appId")) + val executorId = Option(request.getParameter("executorId")) + val driverId = Option(request.getParameter("driverId")) val logType = request.getParameter("logType") val offset = Option(request.getParameter("offset")).map(_.toLong) val byteLength = Option(request.getParameter("byteLength")).map(_.toInt).getOrElse(defaultBytes) - val path = "%s/%s/%s/%s".format(workDir.getPath, appId, executorId, logType) + + val path = (appId, executorId, driverId) match { + case (Some(a), Some(e), None) => + s"${workDir.getPath}/$appId/$executorId/$logType" + case (None, None, Some(d)) => + s"${workDir.getPath}/$driverId/$logType" + case _ => + throw new Exception("Request must specify either application or driver identifiers") + } val (startByte, endByte) = getByteRange(path, offset, byteLength) val file = new File(path) val logLength = file.length - val pre = "==== Bytes %s-%s of %s of %s/%s/%s ====\n" - .format(startByte, endByte, logLength, appId, executorId, logType) + val pre = s"==== Bytes $startByte-$endByte of $logLength of $path ====\n" pre + Utils.offsetBytes(path, startByte, endByte) } def logPage(request: HttpServletRequest): Seq[scala.xml.Node] = { val defaultBytes = 100 * 1024 - val appId = request.getParameter("appId") - val executorId = request.getParameter("executorId") + val appId = Option(request.getParameter("appId")) + val executorId = Option(request.getParameter("executorId")) + val driverId = Option(request.getParameter("driverId")) val logType = request.getParameter("logType") val offset = Option(request.getParameter("offset")).map(_.toLong) val byteLength = Option(request.getParameter("byteLength")).map(_.toInt).getOrElse(defaultBytes) - val path = "%s/%s/%s/%s".format(workDir.getPath, appId, executorId, logType) + + val (path, params) = (appId, executorId, driverId) match { + case (Some(a), Some(e), None) => + (s"${workDir.getPath}/$a/$e/$logType", s"appId=$a&executorId=$e") + case (None, None, Some(d)) => + (s"${workDir.getPath}/$d/$logType", s"driverId=$d") + case _ => + throw new Exception("Request must specify either application or driver identifiers") + } val (startByte, endByte) = getByteRange(path, offset, byteLength) val file = new File(path) @@ -106,9 +124,8 @@ class WorkerWebUI(val worker: Worker, val workDir: File, requestedPort: Option[I val backButton = if (startByte > 0) { - <a href={"?appId=%s&executorId=%s&logType=%s&offset=%s&byteLength=%s" - .format(appId, executorId, logType, math.max(startByte-byteLength, 0), - byteLength)}> + <a href={"?%s&logType=%s&offset=%s&byteLength=%s" + .format(params, logType, math.max(startByte-byteLength, 0), byteLength)}> <button type="button" class="btn btn-default"> Previous {Utils.bytesToString(math.min(byteLength, startByte))} </button> @@ -122,8 +139,8 @@ class WorkerWebUI(val worker: Worker, val workDir: File, requestedPort: Option[I val nextButton = if (endByte < logLength) { - <a href={"?appId=%s&executorId=%s&logType=%s&offset=%s&byteLength=%s". - format(appId, executorId, logType, endByte, byteLength)}> + <a href={"?%s&logType=%s&offset=%s&byteLength=%s". + format(params, logType, endByte, byteLength)}> <button type="button" class="btn btn-default"> Next {Utils.bytesToString(math.min(byteLength, logLength-endByte))} </button> diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 53a2b94a52..45b43b403d 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -24,8 +24,9 @@ import akka.remote._ import org.apache.spark.{SparkConf, SparkContext, Logging} import org.apache.spark.TaskState.TaskState +import org.apache.spark.deploy.worker.WorkerWatcher import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ -import org.apache.spark.util.{Utils, AkkaUtils} +import org.apache.spark.util.{AkkaUtils, Utils} private[spark] class CoarseGrainedExecutorBackend( driverUrl: String, @@ -91,7 +92,8 @@ private[spark] class CoarseGrainedExecutorBackend( } private[spark] object CoarseGrainedExecutorBackend { - def run(driverUrl: String, executorId: String, hostname: String, cores: Int) { + def run(driverUrl: String, executorId: String, hostname: String, cores: Int, + workerUrl: Option[String]) { // Debug code Utils.checkHost(hostname) @@ -101,21 +103,27 @@ private[spark] object CoarseGrainedExecutorBackend { indestructible = true, conf = new SparkConf) // set it val sparkHostPort = hostname + ":" + boundPort -// conf.set("spark.hostPort", sparkHostPort) actorSystem.actorOf( Props(classOf[CoarseGrainedExecutorBackend], driverUrl, executorId, sparkHostPort, cores), name = "Executor") + workerUrl.foreach{ url => + actorSystem.actorOf(Props(classOf[WorkerWatcher], url), name = "WorkerWatcher") + } actorSystem.awaitTermination() } def main(args: Array[String]) { - if (args.length < 4) { - //the reason we allow the last appid argument is to make it easy to kill rogue executors - System.err.println( - "Usage: CoarseGrainedExecutorBackend <driverUrl> <executorId> <hostname> <cores> " + - "[<appid>]") - System.exit(1) + args.length match { + case x if x < 4 => + System.err.println( + // Worker url is used in spark standalone mode to enforce fate-sharing with worker + "Usage: CoarseGrainedExecutorBackend <driverUrl> <executorId> <hostname> " + + "<cores> [<workerUrl>]") + System.exit(1) + case 4 => + run(args(0), args(1), args(2), args(3).toInt, None) + case x if x > 4 => + run(args(0), args(1), args(2), args(3).toInt, Some(args(4))) } - run(args(0), args(1), args(2), args(3).toInt) } } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index e51d274d33..7f31d7e6f8 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -57,7 +57,7 @@ private[spark] class Executor( Utils.setCustomHostname(slaveHostname) // Set spark.* properties from executor arg - val conf = new SparkConf(false) + val conf = new SparkConf(true) conf.setAll(properties) // If we are in yarn mode, systems can have different disk layouts so we must set it @@ -279,6 +279,11 @@ private[spark] class Executor( //System.exit(1) } } finally { + // TODO: Unregister shuffle memory only for ShuffleMapTask + val shuffleMemoryMap = env.shuffleMemoryMap + shuffleMemoryMap.synchronized { + shuffleMemoryMap.remove(Thread.currentThread().getId) + } runningTasks.remove(taskId) } } diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index a1e98845f6..5980177320 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -71,7 +71,7 @@ class LZFCompressionCodec(conf: SparkConf) extends CompressionCodec { class SnappyCompressionCodec(conf: SparkConf) extends CompressionCodec { override def compressedOutputStream(s: OutputStream): OutputStream = { - val blockSize = conf.get("spark.io.compression.snappy.block.size", "32768").toInt + val blockSize = conf.getInt("spark.io.compression.snappy.block.size", 32768) new SnappyOutputStream(s, blockSize) } diff --git a/core/src/main/scala/org/apache/spark/network/BufferMessage.scala b/core/src/main/scala/org/apache/spark/network/BufferMessage.scala index f736bb3713..fb4c65909a 100644 --- a/core/src/main/scala/org/apache/spark/network/BufferMessage.scala +++ b/core/src/main/scala/org/apache/spark/network/BufferMessage.scala @@ -46,7 +46,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: throw new Exception("Max chunk size is " + maxChunkSize) } - if (size == 0 && gotChunkForSendingOnce == false) { + if (size == 0 && !gotChunkForSendingOnce) { val newChunk = new MessageChunk( new MessageChunkHeader(typ, id, 0, 0, ackId, senderAddress), null) gotChunkForSendingOnce = true diff --git a/core/src/main/scala/org/apache/spark/network/Connection.scala b/core/src/main/scala/org/apache/spark/network/Connection.scala index 95cb0206ac..cba8477ed5 100644 --- a/core/src/main/scala/org/apache/spark/network/Connection.scala +++ b/core/src/main/scala/org/apache/spark/network/Connection.scala @@ -330,7 +330,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, // Is highly unlikely unless there was an unclean close of socket, etc registerInterest() logInfo("Connected to [" + address + "], " + outbox.messages.size + " messages pending") - return true + true } catch { case e: Exception => { logWarning("Error finishing connection to " + address, e) @@ -385,7 +385,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, } } // should not happen - to keep scala compiler happy - return true + true } // This is a hack to determine if remote socket was closed or not. @@ -559,7 +559,7 @@ private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : S } } // should not happen - to keep scala compiler happy - return true + true } def onReceive(callback: (Connection, Message) => Unit) {onReceiveCallback = callback} diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala index 46c40d0a2a..e6e01783c8 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala @@ -54,22 +54,22 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi private val selector = SelectorProvider.provider.openSelector() private val handleMessageExecutor = new ThreadPoolExecutor( - conf.get("spark.core.connection.handler.threads.min", "20").toInt, - conf.get("spark.core.connection.handler.threads.max", "60").toInt, - conf.get("spark.core.connection.handler.threads.keepalive", "60").toInt, TimeUnit.SECONDS, + conf.getInt("spark.core.connection.handler.threads.min", 20), + conf.getInt("spark.core.connection.handler.threads.max", 60), + conf.getInt("spark.core.connection.handler.threads.keepalive", 60), TimeUnit.SECONDS, new LinkedBlockingDeque[Runnable]()) private val handleReadWriteExecutor = new ThreadPoolExecutor( - conf.get("spark.core.connection.io.threads.min", "4").toInt, - conf.get("spark.core.connection.io.threads.max", "32").toInt, - conf.get("spark.core.connection.io.threads.keepalive", "60").toInt, TimeUnit.SECONDS, + conf.getInt("spark.core.connection.io.threads.min", 4), + conf.getInt("spark.core.connection.io.threads.max", 32), + conf.getInt("spark.core.connection.io.threads.keepalive", 60), TimeUnit.SECONDS, new LinkedBlockingDeque[Runnable]()) // Use a different, yet smaller, thread pool - infrequently used with very short lived tasks : which should be executed asap private val handleConnectExecutor = new ThreadPoolExecutor( - conf.get("spark.core.connection.connect.threads.min", "1").toInt, - conf.get("spark.core.connection.connect.threads.max", "8").toInt, - conf.get("spark.core.connection.connect.threads.keepalive", "60").toInt, TimeUnit.SECONDS, + conf.getInt("spark.core.connection.connect.threads.min", 1), + conf.getInt("spark.core.connection.connect.threads.max", 8), + conf.getInt("spark.core.connection.connect.threads.keepalive", 60), TimeUnit.SECONDS, new LinkedBlockingDeque[Runnable]()) private val serverChannel = ServerSocketChannel.open() diff --git a/core/src/main/scala/org/apache/spark/network/Message.scala b/core/src/main/scala/org/apache/spark/network/Message.scala index f2ecc6d439..2612884bdb 100644 --- a/core/src/main/scala/org/apache/spark/network/Message.scala +++ b/core/src/main/scala/org/apache/spark/network/Message.scala @@ -61,7 +61,7 @@ private[spark] object Message { if (dataBuffers.exists(_ == null)) { throw new Exception("Attempting to create buffer message with null buffer") } - return new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer] ++= dataBuffers, ackId) + new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer] ++= dataBuffers, ackId) } def createBufferMessage(dataBuffers: Seq[ByteBuffer]): BufferMessage = @@ -69,9 +69,9 @@ private[spark] object Message { def createBufferMessage(dataBuffer: ByteBuffer, ackId: Int): BufferMessage = { if (dataBuffer == null) { - return createBufferMessage(Array(ByteBuffer.allocate(0)), ackId) + createBufferMessage(Array(ByteBuffer.allocate(0)), ackId) } else { - return createBufferMessage(Array(dataBuffer), ackId) + createBufferMessage(Array(dataBuffer), ackId) } } diff --git a/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala b/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala index b729eb11c5..d87157e12c 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala @@ -36,7 +36,7 @@ private[spark] class ShuffleCopier(conf: SparkConf) extends Logging { resultCollectCallback: (BlockId, Long, ByteBuf) => Unit) { val handler = new ShuffleCopier.ShuffleClientHandler(resultCollectCallback) - val connectTimeout = conf.get("spark.shuffle.netty.connect.timeout", "60000").toInt + val connectTimeout = conf.getInt("spark.shuffle.netty.connect.timeout", 60000) val fc = new FileClient(handler, connectTimeout) try { diff --git a/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala b/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala index 546d921067..44204a8c46 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala @@ -64,7 +64,7 @@ private[spark] object ShuffleSender { val subDirId = (hash / localDirs.length) % subDirsPerLocalDir val subDir = new File(localDirs(dirId), "%02x".format(subDirId)) val file = new File(subDir, blockId.name) - return new FileSegment(file, 0, file.length()) + new FileSegment(file, 0, file.length()) } } val sender = new ShuffleSender(port, pResovler) diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala index 6d4f46125f..83109d1a6f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala @@ -97,7 +97,7 @@ private[spark] object CheckpointRDD extends Logging { throw new IOException("Checkpoint failed: temporary path " + tempOutputPath + " already exists") } - val bufferSize = env.conf.get("spark.buffer.size", "65536").toInt + val bufferSize = env.conf.getInt("spark.buffer.size", 65536) val fileOutputStream = if (blockSize < 0) { fs.create(tempOutputPath, false, bufferSize) @@ -131,7 +131,7 @@ private[spark] object CheckpointRDD extends Logging { ): Iterator[T] = { val env = SparkEnv.get val fs = path.getFileSystem(broadcastedConf.value.value) - val bufferSize = env.conf.get("spark.buffer.size", "65536").toInt + val bufferSize = env.conf.getInt("spark.buffer.size", 65536) val fileInputStream = fs.open(path, bufferSize) val serializer = env.serializer.newInstance() val deserializeStream = serializer.deserializeStream(fileInputStream) diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 4ba4696fef..a73714abca 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -23,8 +23,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.{InterruptibleIterator, Partition, Partitioner, SparkEnv, TaskContext} import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency} -import org.apache.spark.util.AppendOnlyMap - +import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap} private[spark] sealed trait CoGroupSplitDep extends Serializable @@ -44,14 +43,12 @@ private[spark] case class NarrowCoGroupSplitDep( private[spark] case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep -private[spark] -class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep]) +private[spark] class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep]) extends Partition with Serializable { override val index: Int = idx override def hashCode(): Int = idx } - /** * A RDD that cogroups its parents. For each key k in parent RDDs, the resulting RDD contains a * tuple with the list of values for that key. @@ -62,6 +59,14 @@ class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep]) class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: Partitioner) extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) { + // For example, `(k, a) cogroup (k, b)` produces k -> Seq(ArrayBuffer as, ArrayBuffer bs). + // Each ArrayBuffer is represented as a CoGroup, and the resulting Seq as a CoGroupCombiner. + // CoGroupValue is the intermediate state of each value before being merged in compute. + private type CoGroup = ArrayBuffer[Any] + private type CoGroupValue = (Any, Int) // Int is dependency number + private type CoGroupCombiner = Seq[CoGroup] + + private val sparkConf = SparkEnv.get.conf private var serializerClass: String = null def setSerializer(cls: String): CoGroupedRDD[K] = { @@ -100,37 +105,74 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: override val partitioner = Some(part) - override def compute(s: Partition, context: TaskContext): Iterator[(K, Seq[Seq[_]])] = { + override def compute(s: Partition, context: TaskContext): Iterator[(K, CoGroupCombiner)] = { + val externalSorting = sparkConf.getBoolean("spark.shuffle.externalSorting", true) val split = s.asInstanceOf[CoGroupPartition] val numRdds = split.deps.size - // e.g. for `(k, a) cogroup (k, b)`, K -> Seq(ArrayBuffer as, ArrayBuffer bs) - val map = new AppendOnlyMap[K, Seq[ArrayBuffer[Any]]] - val update: (Boolean, Seq[ArrayBuffer[Any]]) => Seq[ArrayBuffer[Any]] = (hadVal, oldVal) => { - if (hadVal) oldVal else Array.fill(numRdds)(new ArrayBuffer[Any]) - } - - val getSeq = (k: K) => { - map.changeValue(k, update) - } - - val ser = SparkEnv.get.serializerManager.get(serializerClass, SparkEnv.get.conf) + // A list of (rdd iterator, dependency number) pairs + val rddIterators = new ArrayBuffer[(Iterator[Product2[K, Any]], Int)] for ((dep, depNum) <- split.deps.zipWithIndex) dep match { case NarrowCoGroupSplitDep(rdd, _, itsSplit) => { // Read them from the parent - rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, Any]]].foreach { kv => - getSeq(kv._1)(depNum) += kv._2 - } + val it = rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, Any]]] + rddIterators += ((it, depNum)) } case ShuffleCoGroupSplitDep(shuffleId) => { // Read map outputs of shuffle val fetcher = SparkEnv.get.shuffleFetcher - fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context, ser).foreach { - kv => getSeq(kv._1)(depNum) += kv._2 + val ser = SparkEnv.get.serializerManager.get(serializerClass, sparkConf) + val it = fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context, ser) + rddIterators += ((it, depNum)) + } + } + + if (!externalSorting) { + val map = new AppendOnlyMap[K, CoGroupCombiner] + val update: (Boolean, CoGroupCombiner) => CoGroupCombiner = (hadVal, oldVal) => { + if (hadVal) oldVal else Array.fill(numRdds)(new CoGroup) + } + val getCombiner: K => CoGroupCombiner = key => { + map.changeValue(key, update) + } + rddIterators.foreach { case (it, depNum) => + while (it.hasNext) { + val kv = it.next() + getCombiner(kv._1)(depNum) += kv._2 } } + new InterruptibleIterator(context, map.iterator) + } else { + val map = createExternalMap(numRdds) + rddIterators.foreach { case (it, depNum) => + while (it.hasNext) { + val kv = it.next() + map.insert(kv._1, new CoGroupValue(kv._2, depNum)) + } + } + new InterruptibleIterator(context, map.iterator) + } + } + + private def createExternalMap(numRdds: Int) + : ExternalAppendOnlyMap[K, CoGroupValue, CoGroupCombiner] = { + + val createCombiner: (CoGroupValue => CoGroupCombiner) = value => { + val newCombiner = Array.fill(numRdds)(new CoGroup) + value match { case (v, depNum) => newCombiner(depNum) += v } + newCombiner } - new InterruptibleIterator(context, map.iterator) + val mergeValue: (CoGroupCombiner, CoGroupValue) => CoGroupCombiner = + (combiner, value) => { + value match { case (v, depNum) => combiner(depNum) += v } + combiner + } + val mergeCombiners: (CoGroupCombiner, CoGroupCombiner) => CoGroupCombiner = + (combiner1, combiner2) => { + combiner1.zip(combiner2).map { case (v1, v2) => v1 ++ v2 } + } + new ExternalAppendOnlyMap[K, CoGroupValue, CoGroupCombiner]( + createCombiner, mergeValue, mergeCombiners) } override def clearDependencies() { diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala index 98da35763b..cefcc3d2d9 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala @@ -295,10 +295,10 @@ private[spark] class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanc val prefPartActual = prefPart.get - if (minPowerOfTwo.size + slack <= prefPartActual.size) // more imbalance than the slack allows - return minPowerOfTwo // prefer balance over locality - else { - return prefPartActual // prefer locality over balance + if (minPowerOfTwo.size + slack <= prefPartActual.size) { // more imbalance than the slack allows + minPowerOfTwo // prefer balance over locality + } else { + prefPartActual // prefer locality over balance } } @@ -331,7 +331,7 @@ private[spark] class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanc */ def run(): Array[PartitionGroup] = { setupGroups(math.min(prev.partitions.length, maxPartitions)) // setup the groups (bins) - throwBalls() // assign partitions (balls) to each group (bins) + throwBalls() // assign partitions (balls) to each group (bins) getPartitions } } diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 53f77a38f5..5cdb80be1d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -19,7 +19,10 @@ package org.apache.spark.rdd import java.io.EOFException -import org.apache.hadoop.mapred.FileInputFormat +import scala.reflect.ClassTag + +import org.apache.hadoop.conf.{Configuration, Configurable} +import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred.InputFormat import org.apache.hadoop.mapred.InputSplit import org.apache.hadoop.mapred.JobConf @@ -31,7 +34,7 @@ import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.util.NextIterator -import org.apache.hadoop.conf.{Configuration, Configurable} +import org.apache.spark.util.Utils.cloneWritables /** @@ -42,14 +45,14 @@ private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSp val inputSplit = new SerializableWritable[InputSplit](s) - override def hashCode(): Int = (41 * (41 + rddId) + idx).toInt + override def hashCode(): Int = 41 * (41 + rddId) + idx override val index: Int = idx } /** * An RDD that provides core functionality for reading data stored in Hadoop (e.g., files in HDFS, - * sources in HBase, or S3). + * sources in HBase, or S3), using the older MapReduce API (`org.apache.hadoop.mapred`). * * @param sc The SparkContext to associate the RDD with. * @param broadcastedConf A general Hadoop Configuration, or a subclass of it. If the enclosed @@ -61,15 +64,21 @@ private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSp * @param keyClass Class of the key associated with the inputFormatClass. * @param valueClass Class of the value associated with the inputFormatClass. * @param minSplits Minimum number of Hadoop Splits (HadoopRDD partitions) to generate. + * @param cloneRecords If true, Spark will clone the records produced by Hadoop RecordReader. + * Most RecordReader implementations reuse wrapper objects across multiple + * records, and can cause problems in RDD collect or aggregation operations. + * By default the records are cloned in Spark. However, application + * programmers can explicitly disable the cloning for better performance. */ -class HadoopRDD[K, V]( +class HadoopRDD[K: ClassTag, V: ClassTag]( sc: SparkContext, broadcastedConf: Broadcast[SerializableWritable[Configuration]], initLocalJobConfFuncOpt: Option[JobConf => Unit], inputFormatClass: Class[_ <: InputFormat[K, V]], keyClass: Class[K], valueClass: Class[V], - minSplits: Int) + minSplits: Int, + cloneRecords: Boolean) extends RDD[(K, V)](sc, Nil) with Logging { def this( @@ -78,7 +87,8 @@ class HadoopRDD[K, V]( inputFormatClass: Class[_ <: InputFormat[K, V]], keyClass: Class[K], valueClass: Class[V], - minSplits: Int) = { + minSplits: Int, + cloneRecords: Boolean) = { this( sc, sc.broadcast(new SerializableWritable(conf)) @@ -87,7 +97,8 @@ class HadoopRDD[K, V]( inputFormatClass, keyClass, valueClass, - minSplits) + minSplits, + cloneRecords) } protected val jobConfCacheKey = "rdd_%d_job_conf".format(id) @@ -99,11 +110,11 @@ class HadoopRDD[K, V]( val conf: Configuration = broadcastedConf.value.value if (conf.isInstanceOf[JobConf]) { // A user-broadcasted JobConf was provided to the HadoopRDD, so always use it. - return conf.asInstanceOf[JobConf] + conf.asInstanceOf[JobConf] } else if (HadoopRDD.containsCachedMetadata(jobConfCacheKey)) { // getJobConf() has been called previously, so there is already a local cache of the JobConf // needed by this RDD. - return HadoopRDD.getCachedMetadata(jobConfCacheKey).asInstanceOf[JobConf] + HadoopRDD.getCachedMetadata(jobConfCacheKey).asInstanceOf[JobConf] } else { // Create a JobConf that will be cached and used across this RDD's getJobConf() calls in the // local process. The local cache is accessed through HadoopRDD.putCachedMetadata(). @@ -111,7 +122,7 @@ class HadoopRDD[K, V]( val newJobConf = new JobConf(broadcastedConf.value.value) initLocalJobConfFuncOpt.map(f => f(newJobConf)) HadoopRDD.putCachedMetadata(jobConfCacheKey, newJobConf) - return newJobConf + newJobConf } } @@ -127,7 +138,7 @@ class HadoopRDD[K, V]( newInputFormat.asInstanceOf[Configurable].setConf(conf) } HadoopRDD.putCachedMetadata(inputFormatCacheKey, newInputFormat) - return newInputFormat + newInputFormat } override def getPartitions: Array[Partition] = { @@ -158,10 +169,10 @@ class HadoopRDD[K, V]( // Register an on-task-completion callback to close the input stream. context.addOnCompleteCallback{ () => closeIfNeeded() } - val key: K = reader.createKey() + val keyCloneFunc = cloneWritables[K](jobConf) val value: V = reader.createValue() - + val valueCloneFunc = cloneWritables[V](jobConf) override def getNext() = { try { finished = !reader.next(key, value) @@ -169,7 +180,11 @@ class HadoopRDD[K, V]( case eof: EOFException => finished = true } - (key, value) + if (cloneRecords) { + (keyCloneFunc(key.asInstanceOf[Writable]), valueCloneFunc(value.asInstanceOf[Writable])) + } else { + (key, value) + } } override def close() { diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 73d15b9082..992bd4aa0a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -20,11 +20,14 @@ package org.apache.spark.rdd import java.text.SimpleDateFormat import java.util.Date +import scala.reflect.ClassTag + import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ import org.apache.spark.{InterruptibleIterator, Logging, Partition, SerializableWritable, SparkContext, TaskContext} +import org.apache.spark.util.Utils.cloneWritables private[spark] @@ -33,15 +36,31 @@ class NewHadoopPartition(rddId: Int, val index: Int, @transient rawSplit: InputS val serializableHadoopSplit = new SerializableWritable(rawSplit) - override def hashCode(): Int = (41 * (41 + rddId) + index) + override def hashCode(): Int = 41 * (41 + rddId) + index } -class NewHadoopRDD[K, V]( +/** + * An RDD that provides core functionality for reading data stored in Hadoop (e.g., files in HDFS, + * sources in HBase, or S3), using the new MapReduce API (`org.apache.hadoop.mapreduce`). + * + * @param sc The SparkContext to associate the RDD with. + * @param inputFormatClass Storage format of the data to be read. + * @param keyClass Class of the key associated with the inputFormatClass. + * @param valueClass Class of the value associated with the inputFormatClass. + * @param conf The Hadoop configuration. + * @param cloneRecords If true, Spark will clone the records produced by Hadoop RecordReader. + * Most RecordReader implementations reuse wrapper objects across multiple + * records, and can cause problems in RDD collect or aggregation operations. + * By default the records are cloned in Spark. However, application + * programmers can explicitly disable the cloning for better performance. + */ +class NewHadoopRDD[K: ClassTag, V: ClassTag]( sc : SparkContext, inputFormatClass: Class[_ <: InputFormat[K, V]], keyClass: Class[K], valueClass: Class[V], - @transient conf: Configuration) + @transient conf: Configuration, + cloneRecords: Boolean) extends RDD[(K, V)](sc, Nil) with SparkHadoopMapReduceUtil with Logging { @@ -88,7 +107,8 @@ class NewHadoopRDD[K, V]( // Register an on-task-completion callback to close the input stream. context.addOnCompleteCallback(() => close()) - + val keyCloneFunc = cloneWritables[K](conf) + val valueCloneFunc = cloneWritables[V](conf) var havePair = false var finished = false @@ -105,7 +125,13 @@ class NewHadoopRDD[K, V]( throw new java.util.NoSuchElementException("End of stream") } havePair = false - (reader.getCurrentKey, reader.getCurrentValue) + val key = reader.getCurrentKey + val value = reader.getCurrentValue + if (cloneRecords) { + (keyCloneFunc(key.asInstanceOf[Writable]), valueCloneFunc(value.asInstanceOf[Writable])) + } else { + (key, value) + } } private def close() { diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 2bf7c5b8d6..f6719ec57c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -18,35 +18,34 @@ package org.apache.spark.rdd import java.nio.ByteBuffer -import java.util.Date import java.text.SimpleDateFormat +import java.util.Date import java.util.{HashMap => JHashMap} -import scala.collection.{mutable, Map} +import scala.collection.Map +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.collection.JavaConversions._ import scala.reflect.{ClassTag, classTag} -import org.apache.hadoop.mapred._ -import org.apache.hadoop.io.compress.CompressionCodec import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.io.SequenceFile.CompressionType -import org.apache.hadoop.mapred.FileOutputFormat -import org.apache.hadoop.mapred.OutputFormat +import org.apache.hadoop.io.compress.CompressionCodec +import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf, OutputFormat} import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat} -import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat} -import org.apache.hadoop.mapreduce.SparkHadoopMapReduceUtil import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob} import org.apache.hadoop.mapreduce.{RecordWriter => NewRecordWriter} +import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat} import com.clearspring.analytics.stream.cardinality.HyperLogLog +// SparkHadoopWriter and SparkHadoopMapReduceUtil are actually source files defined in Spark. +import org.apache.hadoop.mapred.SparkHadoopWriter +import org.apache.hadoop.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark._ import org.apache.spark.SparkContext._ import org.apache.spark.partial.{BoundedDouble, PartialResult} -import org.apache.spark.Aggregator -import org.apache.spark.Partitioner import org.apache.spark.Partitioner.defaultPartitioner import org.apache.spark.util.SerializableHyperLogLog @@ -100,8 +99,6 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) }, preservesPartitioning = true) } else { // Don't apply map-side combiner. - // A sanity check to make sure mergeCombiners is not defined. - assert(mergeCombiners == null) val values = new ShuffledRDD[K, V, (K, V)](self, partitioner).setSerializer(serializerClass) values.mapPartitionsWithContext((context, iter) => { new InterruptibleIterator(context, aggregator.combineValuesByKey(iter)) @@ -120,9 +117,9 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) } /** - * Merge the values for each key using an associative function and a neutral "zero value" which may - * be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for - * list concatenation, 0 for addition, or 1 for multiplication.). + * Merge the values for each key using an associative function and a neutral "zero value" which + * may be added to the result an arbitrary number of times, and must not change the result + * (e.g., Nil for list concatenation, 0 for addition, or 1 for multiplication.). */ def foldByKey(zeroValue: V, partitioner: Partitioner)(func: (V, V) => V): RDD[(K, V)] = { // Serialize the zero value to a byte array so that we can get a new clone of it on each key @@ -138,18 +135,18 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) } /** - * Merge the values for each key using an associative function and a neutral "zero value" which may - * be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for - * list concatenation, 0 for addition, or 1 for multiplication.). + * Merge the values for each key using an associative function and a neutral "zero value" which + * may be added to the result an arbitrary number of times, and must not change the result + * (e.g., Nil for list concatenation, 0 for addition, or 1 for multiplication.). */ def foldByKey(zeroValue: V, numPartitions: Int)(func: (V, V) => V): RDD[(K, V)] = { foldByKey(zeroValue, new HashPartitioner(numPartitions))(func) } /** - * Merge the values for each key using an associative function and a neutral "zero value" which may - * be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for - * list concatenation, 0 for addition, or 1 for multiplication.). + * Merge the values for each key using an associative function and a neutral "zero value" which + * may be added to the result an arbitrary number of times, and must not change the result + * (e.g., Nil for list concatenation, 0 for addition, or 1 for multiplication.). */ def foldByKey(zeroValue: V)(func: (V, V) => V): RDD[(K, V)] = { foldByKey(zeroValue, defaultPartitioner(self))(func) @@ -226,7 +223,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) } /** - * Return approximate number of distinct values for each key in this RDD. + * Return approximate number of distinct values for each key in this RDD. * The accuracy of approximation can be controlled through the relative standard deviation * (relativeSD) parameter, which also controls the amount of memory used. Lower values result in * more accurate counts but increase the memory footprint and vise versa. HashPartitions the @@ -268,8 +265,9 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) // into a hash table, leading to more objects in the old gen. def createCombiner(v: V) = ArrayBuffer(v) def mergeValue(buf: ArrayBuffer[V], v: V) = buf += v + def mergeCombiners(c1: ArrayBuffer[V], c2: ArrayBuffer[V]) = c1 ++ c2 val bufs = combineByKey[ArrayBuffer[V]]( - createCombiner _, mergeValue _, null, partitioner, mapSideCombine=false) + createCombiner _, mergeValue _, mergeCombiners _, partitioner, mapSideCombine=false) bufs.asInstanceOf[RDD[(K, Seq[V])]] } @@ -340,7 +338,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) * existing partitioner/parallelism level. */ def combineByKey[C](createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiners: (C, C) => C) - : RDD[(K, C)] = { + : RDD[(K, C)] = { combineByKey(createCombiner, mergeValue, mergeCombiners, defaultPartitioner(self)) } @@ -579,7 +577,8 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) */ def saveAsHadoopFile[F <: OutputFormat[K, V]]( path: String, codec: Class[_ <: CompressionCodec]) (implicit fm: ClassTag[F]) { - saveAsHadoopFile(path, getKeyClass, getValueClass, fm.runtimeClass.asInstanceOf[Class[F]], codec) + val runtimeClass = fm.runtimeClass + saveAsHadoopFile(path, getKeyClass, getValueClass, runtimeClass.asInstanceOf[Class[F]], codec) } /** @@ -599,7 +598,8 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) keyClass: Class[_], valueClass: Class[_], outputFormatClass: Class[_ <: NewOutputFormat[_, _]], - conf: Configuration = self.context.hadoopConfiguration) { + conf: Configuration = self.context.hadoopConfiguration) + { val job = new NewAPIHadoopJob(conf) job.setOutputKeyClass(keyClass) job.setOutputValueClass(valueClass) @@ -668,7 +668,9 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) codec: Option[Class[_ <: CompressionCodec]] = None) { conf.setOutputKeyClass(keyClass) conf.setOutputValueClass(valueClass) - // conf.setOutputFormat(outputFormatClass) // Doesn't work in Scala 2.9 due to what may be a generics bug + // Doesn't work in Scala 2.9 due to what may be a generics bug + // TODO: Should we uncomment this for Scala 2.10? + // conf.setOutputFormat(outputFormatClass) conf.set("mapred.output.format.class", outputFormatClass.getName) for (c <- codec) { conf.setCompressMapOutput(true) @@ -702,7 +704,8 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) throw new SparkException("Output value class not set") } - logInfo("Saving as hadoop file of type (" + keyClass.getSimpleName+ ", " + valueClass.getSimpleName+ ")") + logDebug("Saving as hadoop file of type (" + keyClass.getSimpleName + ", " + + valueClass.getSimpleName+ ")") val writer = new SparkHadoopWriter(conf) writer.preSetup() diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala index 1dbbe39898..d4f396afb5 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala @@ -96,7 +96,7 @@ class PipedRDD[T: ClassTag]( // Return an iterator that read lines from the process's stdout val lines = Source.fromInputStream(proc.getInputStream).getLines - return new Iterator[String] { + new Iterator[String] { def next() = lines.next() def hasNext = { if (lines.hasNext) { diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 2142ae730e..cd90a1561a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -23,7 +23,6 @@ import scala.collection.Map import scala.collection.JavaConversions.mapAsScalaMap import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap import scala.reflect.{classTag, ClassTag} import org.apache.hadoop.io.BytesWritable @@ -52,11 +51,13 @@ import org.apache.spark._ * partitioned collection of elements that can be operated on in parallel. This class contains the * basic operations available on all RDDs, such as `map`, `filter`, and `persist`. In addition, * [[org.apache.spark.rdd.PairRDDFunctions]] contains operations available only on RDDs of key-value - * pairs, such as `groupByKey` and `join`; [[org.apache.spark.rdd.DoubleRDDFunctions]] contains - * operations available only on RDDs of Doubles; and [[org.apache.spark.rdd.SequenceFileRDDFunctions]] - * contains operations available on RDDs that can be saved as SequenceFiles. These operations are - * automatically available on any RDD of the right type (e.g. RDD[(Int, Int)] through implicit - * conversions when you `import org.apache.spark.SparkContext._`. + * pairs, such as `groupByKey` and `join`; + * [[org.apache.spark.rdd.DoubleRDDFunctions]] contains operations available only on RDDs of + * Doubles; and + * [[org.apache.spark.rdd.SequenceFileRDDFunctions]] contains operations available on RDDs that + * can be saved as SequenceFiles. + * These operations are automatically available on any RDD of the right type (e.g. RDD[(Int, Int)] + * through implicit conversions when you `import org.apache.spark.SparkContext._`. * * Internally, each RDD is characterized by five main properties: * @@ -235,12 +236,9 @@ abstract class RDD[T: ClassTag]( /** * Compute an RDD partition or read it from a checkpoint if the RDD is checkpointing. */ - private[spark] def computeOrReadCheckpoint(split: Partition, context: TaskContext): Iterator[T] = { - if (isCheckpointed) { - firstParent[T].iterator(split, context) - } else { - compute(split, context) - } + private[spark] def computeOrReadCheckpoint(split: Partition, context: TaskContext): Iterator[T] = + { + if (isCheckpointed) firstParent[T].iterator(split, context) else compute(split, context) } // Transformations (return a new RDD) @@ -268,6 +266,9 @@ abstract class RDD[T: ClassTag]( def distinct(numPartitions: Int): RDD[T] = map(x => (x, null)).reduceByKey((x, y) => x, numPartitions).map(_._1) + /** + * Return a new RDD containing the distinct elements in this RDD. + */ def distinct(): RDD[T] = distinct(partitions.size) /** @@ -280,7 +281,7 @@ abstract class RDD[T: ClassTag]( * which can avoid performing a shuffle. */ def repartition(numPartitions: Int): RDD[T] = { - coalesce(numPartitions, true) + coalesce(numPartitions, shuffle = true) } /** @@ -651,7 +652,8 @@ abstract class RDD[T: ClassTag]( } /** - * Reduces the elements of this RDD using the specified commutative and associative binary operator. + * Reduces the elements of this RDD using the specified commutative and + * associative binary operator. */ def reduce(f: (T, T) => T): T = { val cleanF = sc.clean(f) @@ -767,7 +769,7 @@ abstract class RDD[T: ClassTag]( val entry = iter.next() m1.put(entry.getKey, m1.getLong(entry.getKey) + entry.getLongValue) } - return m1 + m1 } val myResult = mapPartitions(countPartition).reduce(mergeMaps) myResult.asInstanceOf[java.util.Map[T, Long]] // Will be wrapped as a Scala mutable Map @@ -845,7 +847,7 @@ abstract class RDD[T: ClassTag]( partsScanned += numPartsToTry } - return buf.toArray + buf.toArray } /** @@ -958,7 +960,7 @@ abstract class RDD[T: ClassTag]( private var storageLevel: StorageLevel = StorageLevel.NONE /** Record user function generating this RDD. */ - @transient private[spark] val origin = sc.getCallSite + @transient private[spark] val origin = sc.getCallSite() private[spark] def elementClassTag: ClassTag[T] = classTag[T] diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 043e01dbfb..7046c06d20 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -106,7 +106,7 @@ class DAGScheduler( // The time, in millis, to wait for fetch failure events to stop coming in after one is detected; // this is a simplistic way to avoid resubmitting tasks in the non-fetchable map stage one by one // as more failure events come in - val RESUBMIT_TIMEOUT = 50.milliseconds + val RESUBMIT_TIMEOUT = 200.milliseconds // The time, in millis, to wake up between polls of the completion queue in order to potentially // resubmit failed stages @@ -133,7 +133,8 @@ class DAGScheduler( private[spark] val stageToInfos = new TimeStampedHashMap[Stage, StageInfo] - private[spark] val listenerBus = new SparkListenerBus() + // An async scheduler event bus. The bus should be stopped when DAGSCheduler is stopped. + private[spark] val listenerBus = new SparkListenerBus // Contains the locations that each RDD's partitions are cached on private val cacheLocs = new HashMap[Int, Array[Seq[TaskLocation]]] @@ -196,7 +197,7 @@ class DAGScheduler( */ def receive = { case event: DAGSchedulerEvent => - logDebug("Got event of type " + event.getClass.getName) + logTrace("Got event of type " + event.getClass.getName) /** * All events are forwarded to `processEvent()`, so that the event processing logic can @@ -1121,5 +1122,6 @@ class DAGScheduler( } metadataCleaner.cancel() taskSched.stop() + listenerBus.stop() } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala index 90eb8a747f..cc10cc0849 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala @@ -103,7 +103,7 @@ class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Cl retval ++= SplitInfo.toSplitInfo(inputFormatClazz, path, split) } - return retval.toSet + retval.toSet } // This method does not expect failures, since validate has already passed ... @@ -121,18 +121,18 @@ class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Cl elem => retval ++= SplitInfo.toSplitInfo(inputFormatClazz, path, elem) ) - return retval.toSet + retval.toSet } private def findPreferredLocations(): Set[SplitInfo] = { logDebug("mapreduceInputFormat : " + mapreduceInputFormat + ", mapredInputFormat : " + mapredInputFormat + ", inputFormatClazz : " + inputFormatClazz) if (mapreduceInputFormat) { - return prefLocsFromMapreduceInputFormat() + prefLocsFromMapreduceInputFormat() } else { assert(mapredInputFormat) - return prefLocsFromMapredInputFormat() + prefLocsFromMapredInputFormat() } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala index 1791242215..4bc13c23d9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala @@ -75,12 +75,12 @@ private[spark] class Pool( return schedulableNameToSchedulable(schedulableName) } for (schedulable <- schedulableQueue) { - var sched = schedulable.getSchedulableByName(schedulableName) + val sched = schedulable.getSchedulableByName(schedulableName) if (sched != null) { return sched } } - return null + null } override def executorLost(executorId: String, host: String) { @@ -92,7 +92,7 @@ private[spark] class Pool( for (schedulable <- schedulableQueue) { shouldRevive |= schedulable.checkSpeculatableTasks() } - return shouldRevive + shouldRevive } override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = { @@ -101,7 +101,7 @@ private[spark] class Pool( for (schedulable <- sortedSchedulableQueue) { sortedTaskSetQueue ++= schedulable.getSortedTaskSetQueue() } - return sortedTaskSetQueue + sortedTaskSetQueue } def increaseRunningTasks(taskNum: Int) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala index 3418640b8c..5e62c8468f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala @@ -37,9 +37,9 @@ private[spark] class FIFOSchedulingAlgorithm extends SchedulingAlgorithm { res = math.signum(stageId1 - stageId2) } if (res < 0) { - return true + true } else { - return false + false } } } @@ -56,7 +56,6 @@ private[spark] class FairSchedulingAlgorithm extends SchedulingAlgorithm { val minShareRatio2 = runningTasks2.toDouble / math.max(minShare2, 1.0).toDouble val taskToWeightRatio1 = runningTasks1.toDouble / s1.weight.toDouble val taskToWeightRatio2 = runningTasks2.toDouble / s2.weight.toDouble - var res:Boolean = true var compare:Int = 0 if (s1Needy && !s2Needy) { @@ -70,11 +69,11 @@ private[spark] class FairSchedulingAlgorithm extends SchedulingAlgorithm { } if (compare < 0) { - return true + true } else if (compare > 0) { - return false + false } else { - return s1.name < s2.name + s1.name < s2.name } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 627995c826..55a40a92c9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -43,6 +43,9 @@ case class SparkListenerJobStart(job: ActiveJob, stageIds: Array[Int], propertie case class SparkListenerJobEnd(job: ActiveJob, jobResult: JobResult) extends SparkListenerEvents +/** An event used in the listener to shutdown the listener daemon thread. */ +private[scheduler] case object SparkListenerShutdown extends SparkListenerEvents + trait SparkListener { /** * Called when a stage is completed, with information on the completed stage diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala index e7defd768b..17b1328b86 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala @@ -24,15 +24,17 @@ import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import org.apache.spark.Logging
/** Asynchronously passes SparkListenerEvents to registered SparkListeners. */
-private[spark] class SparkListenerBus() extends Logging {
- private val sparkListeners = new ArrayBuffer[SparkListener]() with SynchronizedBuffer[SparkListener]
+private[spark] class SparkListenerBus extends Logging {
+ private val sparkListeners = new ArrayBuffer[SparkListener] with SynchronizedBuffer[SparkListener]
/* Cap the capacity of the SparkListenerEvent queue so we get an explicit error (rather than
* an OOM exception) if it's perpetually being added to more quickly than it's being drained. */
- private val EVENT_QUEUE_CAPACITY = 10000
+ private val EVENT_QUEUE_CAPACITY = 10000
private val eventQueue = new LinkedBlockingQueue[SparkListenerEvents](EVENT_QUEUE_CAPACITY)
private var queueFullErrorMessageLogged = false
+ // Create a new daemon thread to listen for events. This thread is stopped when it receives
+ // a SparkListenerShutdown event, using the stop method.
new Thread("SparkListenerBus") {
setDaemon(true)
override def run() {
@@ -53,6 +55,9 @@ private[spark] class SparkListenerBus() extends Logging { sparkListeners.foreach(_.onTaskGettingResult(taskGettingResult))
case taskEnd: SparkListenerTaskEnd =>
sparkListeners.foreach(_.onTaskEnd(taskEnd))
+ case SparkListenerShutdown =>
+ // Get out of the while loop and shutdown the daemon thread
+ return
case _ =>
}
}
@@ -80,7 +85,7 @@ private[spark] class SparkListenerBus() extends Logging { */
def waitUntilEmpty(timeoutMillis: Int): Boolean = {
val finishTime = System.currentTimeMillis + timeoutMillis
- while (!eventQueue.isEmpty()) {
+ while (!eventQueue.isEmpty) {
if (System.currentTimeMillis > finishTime) {
return false
}
@@ -88,6 +93,8 @@ private[spark] class SparkListenerBus() extends Logging { * add overhead in the general case. */
Thread.sleep(10)
}
- return true
+ true
}
+
+ def stop(): Unit = post(SparkListenerShutdown)
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index 7cb3fe46e5..c60e9896de 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -96,7 +96,7 @@ private[spark] class Stage( def newAttemptId(): Int = { val id = nextAttemptId nextAttemptId += 1 - return id + id } val name = callSite.getOrElse(rdd.origin) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala index e80cc6b0f6..9d3e615826 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala @@ -74,6 +74,6 @@ class DirectTaskResult[T](var valueBytes: ByteBuffer, var accumUpdates: Map[Long def value(): T = { val resultSer = SparkEnv.get.serializer.newInstance() - return resultSer.deserialize(valueBytes) + resultSer.deserialize(valueBytes) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala index e22b1e53e8..35e9544718 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala @@ -31,13 +31,13 @@ import org.apache.spark.util.Utils private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedulerImpl) extends Logging { - private val THREADS = sparkEnv.conf.get("spark.resultGetter.threads", "4").toInt + private val THREADS = sparkEnv.conf.getInt("spark.resultGetter.threads", 4) private val getTaskResultExecutor = Utils.newDaemonFixedThreadPool( THREADS, "Result resolver thread") protected val serializer = new ThreadLocal[SerializerInstance] { override def initialValue(): SerializerInstance = { - return sparkEnv.closureSerializer.newInstance() + sparkEnv.closureSerializer.newInstance() } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 0c8ed62759..d4f74d3e18 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -51,15 +51,15 @@ private[spark] class TaskSchedulerImpl( isLocal: Boolean = false) extends TaskScheduler with Logging { - def this(sc: SparkContext) = this(sc, sc.conf.get("spark.task.maxFailures", "4").toInt) + def this(sc: SparkContext) = this(sc, sc.conf.getInt("spark.task.maxFailures", 4)) val conf = sc.conf // How often to check for speculative tasks - val SPECULATION_INTERVAL = conf.get("spark.speculation.interval", "100").toLong + val SPECULATION_INTERVAL = conf.getLong("spark.speculation.interval", 100) // Threshold above which we warn user initial TaskSet may be starved - val STARVATION_TIMEOUT = conf.get("spark.starvation.timeout", "15000").toLong + val STARVATION_TIMEOUT = conf.getLong("spark.starvation.timeout", 15000) // TaskSetManagers are not thread safe, so any access to one should be synchronized // on this class. @@ -125,7 +125,7 @@ private[spark] class TaskSchedulerImpl( override def start() { backend.start() - if (!isLocal && conf.get("spark.speculation", "false").toBoolean) { + if (!isLocal && conf.getBoolean("spark.speculation", false)) { logInfo("Starting speculative execution thread") import sc.env.actorSystem.dispatcher sc.env.actorSystem.scheduler.schedule(SPECULATION_INTERVAL milliseconds, diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 6dd1469d8f..fc0ee07089 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -57,11 +57,11 @@ private[spark] class TaskSetManager( val conf = sched.sc.conf // CPUs to request per task - val CPUS_PER_TASK = conf.get("spark.task.cpus", "1").toInt + val CPUS_PER_TASK = conf.getInt("spark.task.cpus", 1) // Quantile of tasks at which to start speculation - val SPECULATION_QUANTILE = conf.get("spark.speculation.quantile", "0.75").toDouble - val SPECULATION_MULTIPLIER = conf.get("spark.speculation.multiplier", "1.5").toDouble + val SPECULATION_QUANTILE = conf.getDouble("spark.speculation.quantile", 0.75) + val SPECULATION_MULTIPLIER = conf.getDouble("spark.speculation.multiplier", 1.5) // Serializer for closures and tasks. val env = SparkEnv.get @@ -116,7 +116,7 @@ private[spark] class TaskSetManager( // How frequently to reprint duplicate exceptions in full, in milliseconds val EXCEPTION_PRINT_INTERVAL = - conf.get("spark.logging.exceptionPrintInterval", "10000").toLong + conf.getLong("spark.logging.exceptionPrintInterval", 10000) // Map of recent exceptions (identified by string representation and top stack frame) to // duplicate count (how many times the same exception has appeared) and time the full exception @@ -228,7 +228,7 @@ private[spark] class TaskSetManager( return Some(index) } } - return None + None } /** Check whether a task is currently running an attempt on a given host */ @@ -291,7 +291,7 @@ private[spark] class TaskSetManager( } } - return None + None } /** @@ -332,7 +332,7 @@ private[spark] class TaskSetManager( } // Finally, if all else has failed, find a speculative task - return findSpeculativeTask(execId, host, locality) + findSpeculativeTask(execId, host, locality) } /** @@ -387,7 +387,7 @@ private[spark] class TaskSetManager( case _ => } } - return None + None } /** @@ -584,7 +584,7 @@ private[spark] class TaskSetManager( } override def getSchedulableByName(name: String): Schedulable = { - return null + null } override def addSchedulable(schedulable: Schedulable) {} @@ -594,7 +594,7 @@ private[spark] class TaskSetManager( override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = { var sortedTaskSetQueue = ArrayBuffer[TaskSetManager](this) sortedTaskSetQueue += this - return sortedTaskSetQueue + sortedTaskSetQueue } /** Called by TaskScheduler when an executor is lost so we can re-enqueue our tasks */ @@ -669,7 +669,7 @@ private[spark] class TaskSetManager( } } } - return foundTasks + foundTasks } private def getLocalityWait(level: TaskLocality.TaskLocality): Long = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 2f5bcafe40..0208388e86 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -63,7 +63,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) // Periodically revive offers to allow delay scheduling to work - val reviveInterval = conf.get("spark.scheduler.revive.interval", "1000").toLong + val reviveInterval = conf.getLong("spark.scheduler.revive.interval", 1000) import context.dispatcher context.system.scheduler.schedule(0.millis, reviveInterval.millis, self, ReviveOffers) } @@ -165,7 +165,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A override def start() { val properties = new ArrayBuffer[(String, String)] for ((key, value) <- scheduler.sc.conf.getAll) { - if (key.startsWith("spark.") && !key.equals("spark.hostPort")) { + if (key.startsWith("spark.")) { properties += ((key, value)) } } @@ -209,8 +209,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A } override def defaultParallelism(): Int = { - conf.getOption("spark.default.parallelism").map(_.toInt).getOrElse( - math.max(totalCoreCount.get(), 2)) + conf.getInt("spark.default.parallelism", math.max(totalCoreCount.get(), 2)) } // Called by subclasses when notified of a lost worker diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala index b44d1e43c8..d99c76117c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala @@ -33,7 +33,7 @@ private[spark] class SimrSchedulerBackend( val tmpPath = new Path(driverFilePath + "_tmp") val filePath = new Path(driverFilePath) - val maxCores = conf.get("spark.simr.executor.cores", "1").toInt + val maxCores = conf.getInt("spark.simr.executor.cores", 1) override def start() { super.start() diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index 73fc37444e..faa6e1ebe8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -20,7 +20,7 @@ package org.apache.spark.scheduler.cluster import scala.collection.mutable.HashMap import org.apache.spark.{Logging, SparkContext} -import org.apache.spark.deploy.client.{Client, ClientListener} +import org.apache.spark.deploy.client.{AppClient, AppClientListener} import org.apache.spark.deploy.{Command, ApplicationDescription} import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason, SlaveLost, TaskSchedulerImpl} import org.apache.spark.util.Utils @@ -31,10 +31,10 @@ private[spark] class SparkDeploySchedulerBackend( masters: Array[String], appName: String) extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) - with ClientListener + with AppClientListener with Logging { - var client: Client = null + var client: AppClient = null var stopping = false var shutdownCallback : (SparkDeploySchedulerBackend) => Unit = _ @@ -47,14 +47,14 @@ private[spark] class SparkDeploySchedulerBackend( val driverUrl = "akka.tcp://spark@%s:%s/user/%s".format( conf.get("spark.driver.host"), conf.get("spark.driver.port"), CoarseGrainedSchedulerBackend.ACTOR_NAME) - val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}") + val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}", "{{WORKER_URL}}") val command = Command( "org.apache.spark.executor.CoarseGrainedExecutorBackend", args, sc.executorEnvs) val sparkHome = sc.getSparkHome().getOrElse(null) val appDesc = new ApplicationDescription(appName, maxCores, sc.executorMemory, command, sparkHome, "http://" + sc.ui.appUIAddress) - client = new Client(sc.env.actorSystem, masters, appDesc, this, conf) + client = new AppClient(sc.env.actorSystem, masters, appDesc, this, conf) client.start() } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index d46fceba89..c27049bdb5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -77,7 +77,7 @@ private[spark] class CoarseMesosSchedulerBackend( "Spark home is not set; set it through the spark.home system " + "property, the SPARK_HOME environment variable or the SparkContext constructor")) - val extraCoresPerSlave = conf.get("spark.mesos.extra.cores", "0").toInt + val extraCoresPerSlave = conf.getInt("spark.mesos.extra.cores", 0) var nextMesosTaskId = 0 @@ -140,7 +140,7 @@ private[spark] class CoarseMesosSchedulerBackend( .format(basename, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores)) command.addUris(CommandInfo.URI.newBuilder().setValue(uri)) } - return command.build() + command.build() } override def offerRescinded(d: SchedulerDriver, o: OfferID) {} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index ae8d527352..49781485d9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -141,13 +141,13 @@ private[spark] class MesosSchedulerBackend( // Serialize the map as an array of (String, String) pairs execArgs = Utils.serialize(props.toArray) } - return execArgs + execArgs } private def setClassLoader(): ClassLoader = { val oldClassLoader = Thread.currentThread.getContextClassLoader Thread.currentThread.setContextClassLoader(classLoader) - return oldClassLoader + oldClassLoader } private def restoreClassLoader(oldClassLoader: ClassLoader) { @@ -255,7 +255,7 @@ private[spark] class MesosSchedulerBackend( .setType(Value.Type.SCALAR) .setScalar(Value.Scalar.newBuilder().setValue(1).build()) .build() - return MesosTaskInfo.newBuilder() + MesosTaskInfo.newBuilder() .setTaskId(taskId) .setSlaveId(SlaveID.newBuilder().setValue(slaveId).build()) .setExecutor(createExecutorInfo(slaveId)) @@ -340,5 +340,5 @@ private[spark] class MesosSchedulerBackend( } // TODO: query Mesos for number of cores - override def defaultParallelism() = sc.conf.get("spark.default.parallelism", "8").toInt + override def defaultParallelism() = sc.conf.getInt("spark.default.parallelism", 8) } diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index a24a3b04b8..c14cd47556 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -36,7 +36,7 @@ import org.apache.spark.storage.{GetBlock, GotBlock, PutBlock} */ class KryoSerializer(conf: SparkConf) extends org.apache.spark.serializer.Serializer with Logging { private val bufferSize = { - conf.get("spark.kryoserializer.buffer.mb", "2").toInt * 1024 * 1024 + conf.getInt("spark.kryoserializer.buffer.mb", 2) * 1024 * 1024 } def newKryoOutput() = new KryoOutput(bufferSize) @@ -48,7 +48,7 @@ class KryoSerializer(conf: SparkConf) extends org.apache.spark.serializer.Serial // Allow disabling Kryo reference tracking if user knows their object graphs don't have loops. // Do this before we invoke the user registrator so the user registrator can override this. - kryo.setReferences(conf.get("spark.kryo.referenceTracking", "true").toBoolean) + kryo.setReferences(conf.getBoolean("spark.kryo.referenceTracking", true)) for (cls <- KryoSerializer.toRegister) kryo.register(cls) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala index 47478631a1..4fa2ab96d9 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala @@ -327,7 +327,7 @@ object BlockFetcherIterator { fetchRequestsSync.put(request) } - copiers = startCopiers(conf.get("spark.shuffle.copier.threads", "6").toInt) + copiers = startCopiers(conf.getInt("spark.shuffle.copier.threads", 6)) logInfo("Started " + fetchRequestsSync.size + " remote gets in " + Utils.getUsedTimeMs(startTime)) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala index 7156d855d8..301d784b35 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala @@ -17,12 +17,14 @@ package org.apache.spark.storage +import java.util.UUID + /** * Identifies a particular Block of data, usually associated with a single file. * A Block can be uniquely identified by its filename, but each type of Block has a different * set of keys which produce its unique name. * - * If your BlockId should be serializable, be sure to add it to the BlockId.fromString() method. + * If your BlockId should be serializable, be sure to add it to the BlockId.apply() method. */ private[spark] sealed abstract class BlockId { /** A globally unique identifier for this Block. Can be used for ser/de. */ @@ -55,7 +57,8 @@ private[spark] case class BroadcastBlockId(broadcastId: Long) extends BlockId { def name = "broadcast_" + broadcastId } -private[spark] case class BroadcastHelperBlockId(broadcastId: BroadcastBlockId, hType: String) extends BlockId { +private[spark] +case class BroadcastHelperBlockId(broadcastId: BroadcastBlockId, hType: String) extends BlockId { def name = broadcastId.name + "_" + hType } @@ -67,6 +70,11 @@ private[spark] case class StreamBlockId(streamId: Int, uniqueId: Long) extends B def name = "input-" + streamId + "-" + uniqueId } +/** Id associated with temporary data managed as blocks. Not serializable. */ +private[spark] case class TempBlockId(id: UUID) extends BlockId { + def name = "temp_" + id +} + // Intended only for testing purposes private[spark] case class TestBlockId(id: String) extends BlockId { def name = "test_" + id diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 6d2cda97b0..6f1345c57a 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -58,8 +58,8 @@ private[spark] class BlockManager( // If we use Netty for shuffle, start a new Netty-based shuffle sender service. private val nettyPort: Int = { - val useNetty = conf.get("spark.shuffle.use.netty", "false").toBoolean - val nettyPortConfig = conf.get("spark.shuffle.sender.port", "0").toInt + val useNetty = conf.getBoolean("spark.shuffle.use.netty", false) + val nettyPortConfig = conf.getInt("spark.shuffle.sender.port", 0) if (useNetty) diskBlockManager.startShuffleBlockSender(nettyPortConfig) else 0 } @@ -72,19 +72,17 @@ private[spark] class BlockManager( // Max megabytes of data to keep in flight per reducer (to avoid over-allocating memory // for receiving shuffle outputs) val maxBytesInFlight = - conf.get("spark.reducer.maxMbInFlight", "48").toLong * 1024 * 1024 + conf.getLong("spark.reducer.maxMbInFlight", 48) * 1024 * 1024 // Whether to compress broadcast variables that are stored - val compressBroadcast = conf.get("spark.broadcast.compress", "true").toBoolean + val compressBroadcast = conf.getBoolean("spark.broadcast.compress", true) // Whether to compress shuffle output that are stored - val compressShuffle = conf.get("spark.shuffle.compress", "true").toBoolean + val compressShuffle = conf.getBoolean("spark.shuffle.compress", true) // Whether to compress RDD partitions that are stored serialized - val compressRdds = conf.get("spark.rdd.compress", "false").toBoolean + val compressRdds = conf.getBoolean("spark.rdd.compress", false) val heartBeatFrequency = BlockManager.getHeartBeatFrequency(conf) - val hostPort = Utils.localHostPort(conf) - val slaveActor = actorSystem.actorOf(Props(new BlockManagerSlaveActor(this)), name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next) @@ -159,7 +157,7 @@ private[spark] class BlockManager( /** * Reregister with the master and report all blocks to it. This will be called by the heart beat - * thread if our heartbeat to the block amnager indicates that we were not registered. + * thread if our heartbeat to the block manager indicates that we were not registered. * * Note that this method must be called without any BlockInfo locks held. */ @@ -412,7 +410,7 @@ private[spark] class BlockManager( logDebug("The value of block " + blockId + " is null") } logDebug("Block " + blockId + " not found") - return None + None } /** @@ -443,7 +441,7 @@ private[spark] class BlockManager( : BlockFetcherIterator = { val iter = - if (conf.get("spark.shuffle.use.netty", "false").toBoolean) { + if (conf.getBoolean("spark.shuffle.use.netty", false)) { new BlockFetcherIterator.NettyBlockFetcherIterator(this, blocksByAddress, serializer) } else { new BlockFetcherIterator.BasicBlockFetcherIterator(this, blocksByAddress, serializer) @@ -469,7 +467,7 @@ private[spark] class BlockManager( def getDiskWriter(blockId: BlockId, file: File, serializer: Serializer, bufferSize: Int) : BlockObjectWriter = { val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _) - val syncWrites = conf.get("spark.shuffle.sync", "false").toBoolean + val syncWrites = conf.getBoolean("spark.shuffle.sync", false) new DiskBlockObjectWriter(blockId, file, serializer, bufferSize, compressStream, syncWrites) } @@ -864,15 +862,15 @@ private[spark] object BlockManager extends Logging { val ID_GENERATOR = new IdGenerator def getMaxMemory(conf: SparkConf): Long = { - val memoryFraction = conf.get("spark.storage.memoryFraction", "0.66").toDouble + val memoryFraction = conf.getDouble("spark.storage.memoryFraction", 0.6) (Runtime.getRuntime.maxMemory * memoryFraction).toLong } def getHeartBeatFrequency(conf: SparkConf): Long = - conf.get("spark.storage.blockManagerTimeoutIntervalMs", "60000").toLong / 4 + conf.getLong("spark.storage.blockManagerTimeoutIntervalMs", 60000) / 4 def getDisableHeartBeatsForTesting(conf: SparkConf): Boolean = - conf.get("spark.test.disableBlockManagerHeartBeat", "false").toBoolean + conf.getBoolean("spark.test.disableBlockManagerHeartBeat", false) /** * Attempt to clean up a ByteBuffer if it is memory-mapped. This uses an *unsafe* Sun API that diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index 51a29ed8ef..c54e4f2664 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -30,8 +30,8 @@ import org.apache.spark.util.AkkaUtils private[spark] class BlockManagerMaster(var driverActor : ActorRef, conf: SparkConf) extends Logging { - val AKKA_RETRY_ATTEMPTS: Int = conf.get("spark.akka.num.retries", "3").toInt - val AKKA_RETRY_INTERVAL_MS: Int = conf.get("spark.akka.retry.wait", "3000").toInt + val AKKA_RETRY_ATTEMPTS: Int = conf.getInt("spark.akka.num.retries", 3) + val AKKA_RETRY_INTERVAL_MS: Int = conf.getInt("spark.akka.retry.wait", 3000) val DRIVER_AKKA_ACTOR_NAME = "BlockManagerMaster" diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index 58452d9657..2c1a4e2f5d 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -348,14 +348,19 @@ object BlockManagerMasterActor { if (storageLevel.isValid) { // isValid means it is either stored in-memory or on-disk. - _blocks.put(blockId, BlockStatus(storageLevel, memSize, diskSize)) + // But the memSize here indicates the data size in or dropped from memory, + // and the diskSize here indicates the data size in or dropped to disk. + // They can be both larger than 0, when a block is dropped from memory to disk. + // Therefore, a safe way to set BlockStatus is to set its info in accurate modes. if (storageLevel.useMemory) { + _blocks.put(blockId, BlockStatus(storageLevel, memSize, 0)) _remainingMem -= memSize logInfo("Added %s in memory on %s (size: %s, free: %s)".format( blockId, blockManagerId.hostPort, Utils.bytesToString(memSize), Utils.bytesToString(_remainingMem))) } if (storageLevel.useDisk) { + _blocks.put(blockId, BlockStatus(storageLevel, 0, diskSize)) logInfo("Added %s on disk on %s (size: %s)".format( blockId, blockManagerId.hostPort, Utils.bytesToString(diskSize))) } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala index 21f003609b..42f52d7b26 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala @@ -42,15 +42,15 @@ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends val blockMessages = BlockMessageArray.fromBufferMessage(bufferMessage) logDebug("Parsed as a block message array") val responseMessages = blockMessages.map(processBlockMessage).filter(_ != None).map(_.get) - return Some(new BlockMessageArray(responseMessages).toBufferMessage) + Some(new BlockMessageArray(responseMessages).toBufferMessage) } catch { case e: Exception => logError("Exception handling buffer message", e) - return None + None } } case otherMessage: Any => { logError("Unknown type message received: " + otherMessage) - return None + None } } } @@ -61,7 +61,7 @@ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends val pB = PutBlock(blockMessage.getId, blockMessage.getData, blockMessage.getLevel) logDebug("Received [" + pB + "]") putBlock(pB.id, pB.data, pB.level) - return None + None } case BlockMessage.TYPE_GET_BLOCK => { val gB = new GetBlock(blockMessage.getId) @@ -70,9 +70,9 @@ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends if (buffer == null) { return None } - return Some(BlockMessage.fromGotBlock(GotBlock(gB.id, buffer))) + Some(BlockMessage.fromGotBlock(GotBlock(gB.id, buffer))) } - case _ => return None + case _ => None } } @@ -93,7 +93,7 @@ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends } logDebug("GetBlock " + id + " used " + Utils.getUsedTimeMs(startTimeMs) + " and got buffer " + buffer) - return buffer + buffer } } @@ -111,7 +111,7 @@ private[spark] object BlockManagerWorker extends Logging { val blockMessageArray = new BlockMessageArray(blockMessage) val resultMessage = connectionManager.sendMessageReliablySync( toConnManagerId, blockMessageArray.toBufferMessage) - return (resultMessage != None) + resultMessage != None } def syncGetBlock(msg: GetBlock, toConnManagerId: ConnectionManagerId): ByteBuffer = { @@ -130,8 +130,8 @@ private[spark] object BlockManagerWorker extends Logging { return blockMessage.getData }) } - case None => logDebug("No response message received"); return null + case None => logDebug("No response message received") } - return null + null } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala b/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala index 80dcb5a207..fbafcf79d2 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala @@ -154,7 +154,7 @@ private[spark] class BlockMessage() { println() */ val finishTime = System.currentTimeMillis - return Message.createBufferMessage(buffers) + Message.createBufferMessage(buffers) } override def toString: String = { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala b/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala index a06f50a0ac..59329361f3 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala @@ -96,7 +96,7 @@ class BlockMessageArray(var blockMessages: Seq[BlockMessage]) extends Seq[BlockM println() println() */ - return Message.createBufferMessage(buffers) + Message.createBufferMessage(buffers) } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala index 61e63c60d5..369a277232 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala @@ -181,4 +181,8 @@ class DiskBlockObjectWriter( // Only valid if called after close() override def timeWriting() = _timeWriting + + def bytesWritten: Long = { + lastValidPosition - initialPosition + } } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 55dcb3742c..a8ef7fa8b6 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -19,7 +19,7 @@ package org.apache.spark.storage import java.io.File import java.text.SimpleDateFormat -import java.util.{Date, Random} +import java.util.{Date, Random, UUID} import org.apache.spark.Logging import org.apache.spark.executor.ExecutorExitCode @@ -38,7 +38,7 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD extends PathResolver with Logging { private val MAX_DIR_CREATION_ATTEMPTS: Int = 10 - private val subDirsPerLocalDir = shuffleManager.conf.get("spark.diskStore.subDirectories", "64").toInt + private val subDirsPerLocalDir = shuffleManager.conf.getInt("spark.diskStore.subDirectories", 64) // Create one local directory for each path mentioned in spark.local.dir; then, inside this // directory, create multiple subdirectories that we will hash files into, in order to avoid @@ -90,6 +90,15 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD def getFile(blockId: BlockId): File = getFile(blockId.name) + /** Produces a unique block id and File suitable for intermediate results. */ + def createTempBlock(): (TempBlockId, File) = { + var blockId = new TempBlockId(UUID.randomUUID()) + while (getFile(blockId).exists()) { + blockId = new TempBlockId(UUID.randomUUID()) + } + (blockId, getFile(blockId)) + } + private def createLocalDirs(): Array[File] = { logDebug("Creating local directories at root dirs '" + rootDirs + "'") val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss") diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala index 05f676c6e2..27f057b9f2 100644 --- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala @@ -245,7 +245,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) return false } } - return true + true } override def contains(blockId: BlockId): Boolean = { diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala index 39dc7bb19a..e2b24298a5 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala @@ -64,9 +64,9 @@ class ShuffleBlockManager(blockManager: BlockManager) { // Turning off shuffle file consolidation causes all shuffle Blocks to get their own file. // TODO: Remove this once the shuffle file consolidation feature is stable. val consolidateShuffleFiles = - conf.get("spark.shuffle.consolidateFiles", "false").toBoolean + conf.getBoolean("spark.shuffle.consolidateFiles", false) - private val bufferSize = conf.get("spark.shuffle.file.buffer.kb", "100").toInt * 1024 + private val bufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024 /** * Contains all the state related to a particular shuffle. This includes a pool of unused diff --git a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala index b5596dffd3..0f84810d6b 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala @@ -74,7 +74,7 @@ class StorageLevel private( if (deserialized_) { ret |= 1 } - return ret + ret } override def writeExternal(out: ObjectOutput) { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index b7b87250b9..bcd2824450 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -33,7 +33,7 @@ import org.apache.spark.scheduler._ */ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkListener { // How many stages to remember - val RETAINED_STAGES = sc.conf.get("spark.ui.retained_stages", "1000").toInt + val RETAINED_STAGES = sc.conf.getInt("spark.ui.retainedStages", 1000) val DEFAULT_POOL_NAME = "default" val stageIdToPool = new HashMap[Int, String]() diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 8dcfeacb60..d1e58016be 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -171,7 +171,7 @@ private[spark] class StagePage(parent: JobProgressUI) { summary ++ <h4>Summary Metrics for {numCompleted} Completed Tasks</h4> ++ <div>{summaryTable.getOrElse("No tasks have reported metrics yet.")}</div> ++ - <h4>Aggregated Metrics by Executors</h4> ++ executorTable.toNodeSeq() ++ + <h4>Aggregated Metrics by Executor</h4> ++ executorTable.toNodeSeq() ++ <h4>Tasks</h4> ++ taskTable headerSparkPage(content, parent.sc, "Details for Stage %d".format(stageId), Stages) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 463d85dfd5..9ad6de3c6d 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -48,7 +48,7 @@ private[spark] class StageTable(val stages: Seq[StageInfo], val parent: JobProgr {if (isFairScheduler) {<th>Pool Name</th>} else {}} <th>Description</th> <th>Submitted</th> - <th>Task Time</th> + <th>Duration</th> <th>Tasks: Succeeded/Total</th> <th>Shuffle Read</th> <th>Shuffle Write</th> diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index 3f009a8998..761d378c7f 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -44,13 +44,13 @@ private[spark] object AkkaUtils { def createActorSystem(name: String, host: String, port: Int, indestructible: Boolean = false, conf: SparkConf): (ActorSystem, Int) = { - val akkaThreads = conf.get("spark.akka.threads", "4").toInt - val akkaBatchSize = conf.get("spark.akka.batchSize", "15").toInt + val akkaThreads = conf.getInt("spark.akka.threads", 4) + val akkaBatchSize = conf.getInt("spark.akka.batchSize", 15) - val akkaTimeout = conf.get("spark.akka.timeout", "100").toInt + val akkaTimeout = conf.getInt("spark.akka.timeout", 100) - val akkaFrameSize = conf.get("spark.akka.frameSize", "10").toInt - val akkaLogLifecycleEvents = conf.get("spark.akka.logLifecycleEvents", "false").toBoolean + val akkaFrameSize = conf.getInt("spark.akka.frameSize", 10) + val akkaLogLifecycleEvents = conf.getBoolean("spark.akka.logLifecycleEvents", false) val lifecycleEvents = if (akkaLogLifecycleEvents) "on" else "off" if (!akkaLogLifecycleEvents) { // As a workaround for Akka issue #3787, we coerce the "EndpointWriter" log to be silent. @@ -58,12 +58,12 @@ private[spark] object AkkaUtils { Option(Logger.getLogger("akka.remote.EndpointWriter")).map(l => l.setLevel(Level.FATAL)) } - val logAkkaConfig = if (conf.get("spark.akka.logAkkaConfig", "false").toBoolean) "on" else "off" + val logAkkaConfig = if (conf.getBoolean("spark.akka.logAkkaConfig", false)) "on" else "off" - val akkaHeartBeatPauses = conf.get("spark.akka.heartbeat.pauses", "600").toInt + val akkaHeartBeatPauses = conf.getInt("spark.akka.heartbeat.pauses", 600) val akkaFailureDetector = - conf.get("spark.akka.failure-detector.threshold", "300.0").toDouble - val akkaHeartBeatInterval = conf.get("spark.akka.heartbeat.interval", "1000").toInt + conf.getDouble("spark.akka.failure-detector.threshold", 300.0) + val akkaHeartBeatInterval = conf.getInt("spark.akka.heartbeat.interval", 1000) val akkaConf = ConfigFactory.parseMap(conf.getAkkaConf.toMap[String, String]).withFallback( ConfigFactory.parseString( @@ -103,7 +103,7 @@ private[spark] object AkkaUtils { /** Returns the default Spark timeout to use for Akka ask operations. */ def askTimeout(conf: SparkConf): FiniteDuration = { - Duration.create(conf.get("spark.akka.askTimeout", "30").toLong, "seconds") + Duration.create(conf.getLong("spark.akka.askTimeout", 30), "seconds") } /** Returns the default Spark timeout to use for Akka remote actor lookup. */ diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index 7108595e3e..1df6b87fb0 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -61,7 +61,7 @@ private[spark] object ClosureCleaner extends Logging { return f.getType :: Nil // Stop at the first $outer that is not a closure } } - return Nil + Nil } // Get a list of the outer objects for a given closure object. @@ -74,7 +74,7 @@ private[spark] object ClosureCleaner extends Logging { return f.get(obj) :: Nil // Stop at the first $outer that is not a closure } } - return Nil + Nil } private def getInnerClasses(obj: AnyRef): List[Class[_]] = { @@ -174,7 +174,7 @@ private[spark] object ClosureCleaner extends Logging { field.setAccessible(true) field.set(obj, outer) } - return obj + obj } } } @@ -182,7 +182,7 @@ private[spark] object ClosureCleaner extends Logging { private[spark] class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends ClassVisitor(ASM4) { override def visitMethod(access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): MethodVisitor = { - return new MethodVisitor(ASM4) { + new MethodVisitor(ASM4) { override def visitFieldInsn(op: Int, owner: String, name: String, desc: String) { if (op == GETFIELD) { for (cl <- output.keys if cl.getName == owner.replace('/', '.')) { @@ -215,7 +215,7 @@ private[spark] class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisi override def visitMethod(access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): MethodVisitor = { - return new MethodVisitor(ASM4) { + new MethodVisitor(ASM4) { override def visitMethodInsn(op: Int, owner: String, name: String, desc: String) { val argTypes = Type.getArgumentTypes(desc) diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala index aa7f52cafb..ac07a55cb9 100644 --- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala @@ -74,7 +74,7 @@ object MetadataCleanerType extends Enumeration { // initialization of StreamingContext. It's okay for users trying to configure stuff themselves. object MetadataCleaner { def getDelaySeconds(conf: SparkConf) = { - conf.get("spark.cleaner.ttl", "3500").toInt + conf.getInt("spark.cleaner.ttl", -1) } def getDelaySeconds(conf: SparkConf, cleanerType: MetadataCleanerType.MetadataCleanerType): Int = diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala index bddb3bb735..3cf94892e9 100644 --- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala +++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala @@ -108,7 +108,7 @@ private[spark] object SizeEstimator extends Logging { val bean = ManagementFactory.newPlatformMXBeanProxy(server, hotSpotMBeanName, hotSpotMBeanClass) // TODO: We could use reflection on the VMOption returned ? - return getVMMethod.invoke(bean, "UseCompressedOops").toString.contains("true") + getVMMethod.invoke(bean, "UseCompressedOops").toString.contains("true") } catch { case e: Exception => { // Guess whether they've enabled UseCompressedOops based on whether maxMemory < 32 GB @@ -141,7 +141,7 @@ private[spark] object SizeEstimator extends Logging { def dequeue(): AnyRef = { val elem = stack.last stack.trimEnd(1) - return elem + elem } } @@ -162,7 +162,7 @@ private[spark] object SizeEstimator extends Logging { while (!state.isFinished) { visitSingleObject(state.dequeue(), state) } - return state.size + state.size } private def visitSingleObject(obj: AnyRef, state: SearchState) { @@ -276,11 +276,11 @@ private[spark] object SizeEstimator extends Logging { // Create and cache a new ClassInfo val newInfo = new ClassInfo(shellSize, pointerFields) classInfos.put(cls, newInfo) - return newInfo + newInfo } private def alignSize(size: Long): Long = { val rem = size % ALIGN_SIZE - return if (rem == 0) size else (size + ALIGN_SIZE - rem) + if (rem == 0) size else (size + ALIGN_SIZE - rem) } } diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala index 181ae2fd45..8e07a0f29a 100644 --- a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala @@ -26,16 +26,23 @@ import org.apache.spark.Logging /** * This is a custom implementation of scala.collection.mutable.Map which stores the insertion - * time stamp along with each key-value pair. Key-value pairs that are older than a particular - * threshold time can them be removed using the clearOldValues method. This is intended to be a drop-in - * replacement of scala.collection.mutable.HashMap. + * timestamp along with each key-value pair. If specified, the timestamp of each pair can be + * updated every time it is accessed. Key-value pairs whose timestamp are older than a particular + * threshold time can then be removed using the clearOldValues method. This is intended to + * be a drop-in replacement of scala.collection.mutable.HashMap. + * @param updateTimeStampOnGet When enabled, the timestamp of a pair will be + * updated when it is accessed */ -class TimeStampedHashMap[A, B] extends Map[A, B]() with Logging { +class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = false) + extends Map[A, B]() with Logging { val internalMap = new ConcurrentHashMap[A, (B, Long)]() def get(key: A): Option[B] = { val value = internalMap.get(key) - if (value != null) Some(value._1) else None + if (value != null && updateTimeStampOnGet) { + internalMap.replace(key, value, (value._1, currentTime)) + } + Option(value).map(_._1) } def iterator: Iterator[(A, B)] = { diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 5f1253100b..caa9bf4c92 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -26,37 +26,61 @@ import scala.collection.JavaConversions._ import scala.collection.Map import scala.collection.mutable.ArrayBuffer import scala.io.Source -import scala.reflect.ClassTag +import scala.reflect.{classTag, ClassTag} import com.google.common.io.Files import com.google.common.util.concurrent.ThreadFactoryBuilder +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{Path, FileSystem, FileUtil} +import org.apache.hadoop.io._ import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} import org.apache.spark.deploy.SparkHadoopUtil import java.nio.ByteBuffer -import org.apache.spark.{SparkConf, SparkContext, SparkException, Logging} +import org.apache.spark.{SparkConf, SparkException, Logging} /** * Various utility methods used by Spark. */ private[spark] object Utils extends Logging { + + /** + * We try to clone for most common types of writables and we call WritableUtils.clone otherwise + * intention is to optimize, for example for NullWritable there is no need and for Long, int and + * String creating a new object with value set would be faster. + */ + def cloneWritables[T: ClassTag](conf: Configuration): Writable => T = { + val cloneFunc = classTag[T] match { + case ClassTag(_: Text) => + (w: Writable) => new Text(w.asInstanceOf[Text].getBytes).asInstanceOf[T] + case ClassTag(_: LongWritable) => + (w: Writable) => new LongWritable(w.asInstanceOf[LongWritable].get).asInstanceOf[T] + case ClassTag(_: IntWritable) => + (w: Writable) => new IntWritable(w.asInstanceOf[IntWritable].get).asInstanceOf[T] + case ClassTag(_: NullWritable) => + (w: Writable) => w.asInstanceOf[T] // TODO: should we clone this ? + case _ => + (w: Writable) => WritableUtils.clone(w, conf).asInstanceOf[T] // slower way of cloning. + } + cloneFunc + } + /** Serialize an object using Java serialization */ def serialize[T](o: T): Array[Byte] = { val bos = new ByteArrayOutputStream() val oos = new ObjectOutputStream(bos) oos.writeObject(o) oos.close() - return bos.toByteArray + bos.toByteArray } /** Deserialize an object using Java serialization */ def deserialize[T](bytes: Array[Byte]): T = { val bis = new ByteArrayInputStream(bytes) val ois = new ObjectInputStream(bis) - return ois.readObject.asInstanceOf[T] + ois.readObject.asInstanceOf[T] } /** Deserialize an object using Java serialization and the given ClassLoader */ @@ -66,7 +90,7 @@ private[spark] object Utils extends Logging { override def resolveClass(desc: ObjectStreamClass) = Class.forName(desc.getName, false, loader) } - return ois.readObject.asInstanceOf[T] + ois.readObject.asInstanceOf[T] } /** Deserialize a Long value (used for {@link org.apache.spark.api.python.PythonPartitioner}) */ @@ -144,7 +168,7 @@ private[spark] object Utils extends Logging { i += 1 } } - return buf + buf } private val shutdownDeletePaths = new scala.collection.mutable.HashSet[String]() @@ -396,15 +420,6 @@ private[spark] object Utils extends Logging { InetAddress.getByName(address).getHostName } - def localHostPort(conf: SparkConf): String = { - val retval = conf.get("spark.hostPort", null) - if (retval == null) { - logErrorWithStack("spark.hostPort not set but invoking localHostPort") - return localHostName() - } - retval - } - def checkHost(host: String, message: String = "") { assert(host.indexOf(':') == -1, message) } @@ -413,14 +428,6 @@ private[spark] object Utils extends Logging { assert(hostPort.indexOf(':') != -1, message) } - def logErrorWithStack(msg: String) { - try { - throw new Exception - } catch { - case ex: Exception => logError(msg, ex) - } - } - // Typically, this will be of order of number of nodes in cluster // If not, we should change it to LRUCache or something. private val hostPortParseResults = new ConcurrentHashMap[String, (String, Int)]() @@ -428,7 +435,7 @@ private[spark] object Utils extends Logging { def parseHostPort(hostPort: String): (String, Int) = { { // Check cache first. - var cached = hostPortParseResults.get(hostPort) + val cached = hostPortParseResults.get(hostPort) if (cached != null) return cached } @@ -731,7 +738,7 @@ private[spark] object Utils extends Logging { } catch { case ise: IllegalStateException => return true } - return false + false } def isSpace(c: Char): Boolean = { @@ -748,7 +755,7 @@ private[spark] object Utils extends Logging { var inWord = false var inSingleQuote = false var inDoubleQuote = false - var curWord = new StringBuilder + val curWord = new StringBuilder def endWord() { buf += curWord.toString curWord.clear() @@ -794,7 +801,7 @@ private[spark] object Utils extends Logging { if (inWord || inDoubleQuote || inSingleQuote) { endWord() } - return buf + buf } /* Calculates 'x' modulo 'mod', takes to consideration sign of x, @@ -822,8 +829,7 @@ private[spark] object Utils extends Logging { /** Returns a copy of the system properties that is thread-safe to iterator over. */ def getSystemProperties(): Map[String, String] = { - return System.getProperties().clone() - .asInstanceOf[java.util.Properties].toMap[String, String] + System.getProperties.clone().asInstanceOf[java.util.Properties].toMap[String, String] } /** diff --git a/core/src/main/scala/org/apache/spark/util/Vector.scala b/core/src/main/scala/org/apache/spark/util/Vector.scala index fe710c58ac..fcdf848637 100644 --- a/core/src/main/scala/org/apache/spark/util/Vector.scala +++ b/core/src/main/scala/org/apache/spark/util/Vector.scala @@ -17,6 +17,8 @@ package org.apache.spark.util +import scala.util.Random + class Vector(val elements: Array[Double]) extends Serializable { def length = elements.length @@ -25,7 +27,7 @@ class Vector(val elements: Array[Double]) extends Serializable { def + (other: Vector): Vector = { if (length != other.length) throw new IllegalArgumentException("Vectors of different length") - return Vector(length, i => this(i) + other(i)) + Vector(length, i => this(i) + other(i)) } def add(other: Vector) = this + other @@ -33,7 +35,7 @@ class Vector(val elements: Array[Double]) extends Serializable { def - (other: Vector): Vector = { if (length != other.length) throw new IllegalArgumentException("Vectors of different length") - return Vector(length, i => this(i) - other(i)) + Vector(length, i => this(i) - other(i)) } def subtract(other: Vector) = this - other @@ -47,7 +49,7 @@ class Vector(val elements: Array[Double]) extends Serializable { ans += this(i) * other(i) i += 1 } - return ans + ans } /** @@ -67,7 +69,7 @@ class Vector(val elements: Array[Double]) extends Serializable { ans += (this(i) + plus(i)) * other(i) i += 1 } - return ans + ans } def += (other: Vector): Vector = { @@ -102,7 +104,7 @@ class Vector(val elements: Array[Double]) extends Serializable { ans += (this(i) - other(i)) * (this(i) - other(i)) i += 1 } - return ans + ans } def dist(other: Vector): Double = math.sqrt(squaredDist(other)) @@ -117,13 +119,19 @@ object Vector { def apply(length: Int, initializer: Int => Double): Vector = { val elements: Array[Double] = Array.tabulate(length)(initializer) - return new Vector(elements) + new Vector(elements) } def zeros(length: Int) = new Vector(new Array[Double](length)) def ones(length: Int) = Vector(length, _ => 1) + /** + * Creates this [[org.apache.spark.util.Vector]] of given length containing random numbers + * between 0.0 and 1.0. Optional [[scala.util.Random]] number generator can be provided. + */ + def random(length: Int, random: Random = new XORShiftRandom()) = Vector(length, _ => random.nextDouble()) + class Multiplier(num: Double) { def * (vec: Vector) = vec * num } diff --git a/core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala b/core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala index e9907e6c85..08b31ac64f 100644 --- a/core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala +++ b/core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala @@ -91,4 +91,4 @@ private[spark] object XORShiftRandom { } -}
\ No newline at end of file +} diff --git a/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala index 8bb4ee3bfa..b8c852b4ff 100644 --- a/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala @@ -15,7 +15,9 @@ * limitations under the License. */ -package org.apache.spark.util +package org.apache.spark.util.collection + +import java.util.{Arrays, Comparator} /** * A simple open hash table optimized for the append-only use case, where keys @@ -28,14 +30,15 @@ package org.apache.spark.util * TODO: Cache the hash values of each key? java.util.HashMap does that. */ private[spark] -class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] with Serializable { +class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, + V)] with Serializable { require(initialCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements") require(initialCapacity >= 1, "Invalid initial capacity") private var capacity = nextPowerOf2(initialCapacity) private var mask = capacity - 1 private var curSize = 0 - private var growThreshold = LOAD_FACTOR * capacity + private var growThreshold = (LOAD_FACTOR * capacity).toInt // Holds keys and values in the same array for memory locality; specifically, the order of // elements is key0, value0, key1, value1, key2, value2, etc. @@ -45,10 +48,15 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi private var haveNullValue = false private var nullValue: V = null.asInstanceOf[V] + // Triggered by destructiveSortedIterator; the underlying data array may no longer be used + private var destroyed = false + private val destructionMessage = "Map state is invalid from destructive sorting!" + private val LOAD_FACTOR = 0.7 /** Get the value for a given key */ def apply(key: K): V = { + assert(!destroyed, destructionMessage) val k = key.asInstanceOf[AnyRef] if (k.eq(null)) { return nullValue @@ -67,11 +75,12 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi i += 1 } } - return null.asInstanceOf[V] + null.asInstanceOf[V] } /** Set the value for a key */ def update(key: K, value: V): Unit = { + assert(!destroyed, destructionMessage) val k = key.asInstanceOf[AnyRef] if (k.eq(null)) { if (!haveNullValue) { @@ -106,6 +115,7 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi * for key, if any, or null otherwise. Returns the newly updated value. */ def changeValue(key: K, updateFunc: (Boolean, V) => V): V = { + assert(!destroyed, destructionMessage) val k = key.asInstanceOf[AnyRef] if (k.eq(null)) { if (!haveNullValue) { @@ -139,35 +149,38 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi } /** Iterator method from Iterable */ - override def iterator: Iterator[(K, V)] = new Iterator[(K, V)] { - var pos = -1 - - /** Get the next value we should return from next(), or null if we're finished iterating */ - def nextValue(): (K, V) = { - if (pos == -1) { // Treat position -1 as looking at the null value - if (haveNullValue) { - return (null.asInstanceOf[K], nullValue) + override def iterator: Iterator[(K, V)] = { + assert(!destroyed, destructionMessage) + new Iterator[(K, V)] { + var pos = -1 + + /** Get the next value we should return from next(), or null if we're finished iterating */ + def nextValue(): (K, V) = { + if (pos == -1) { // Treat position -1 as looking at the null value + if (haveNullValue) { + return (null.asInstanceOf[K], nullValue) + } + pos += 1 } - pos += 1 - } - while (pos < capacity) { - if (!data(2 * pos).eq(null)) { - return (data(2 * pos).asInstanceOf[K], data(2 * pos + 1).asInstanceOf[V]) + while (pos < capacity) { + if (!data(2 * pos).eq(null)) { + return (data(2 * pos).asInstanceOf[K], data(2 * pos + 1).asInstanceOf[V]) + } + pos += 1 } - pos += 1 + null } - null - } - override def hasNext: Boolean = nextValue() != null + override def hasNext: Boolean = nextValue() != null - override def next(): (K, V) = { - val value = nextValue() - if (value == null) { - throw new NoSuchElementException("End of iterator") + override def next(): (K, V) = { + val value = nextValue() + if (value == null) { + throw new NoSuchElementException("End of iterator") + } + pos += 1 + value } - pos += 1 - value } } @@ -190,7 +203,7 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi } /** Double the table's size and re-hash everything */ - private def growTable() { + protected def growTable() { val newCapacity = capacity * 2 if (newCapacity >= (1 << 30)) { // We can't make the table this big because we want an array of 2x @@ -227,11 +240,58 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi data = newData capacity = newCapacity mask = newMask - growThreshold = LOAD_FACTOR * newCapacity + growThreshold = (LOAD_FACTOR * newCapacity).toInt } private def nextPowerOf2(n: Int): Int = { val highBit = Integer.highestOneBit(n) if (highBit == n) n else highBit << 1 } + + /** + * Return an iterator of the map in sorted order. This provides a way to sort the map without + * using additional memory, at the expense of destroying the validity of the map. + */ + def destructiveSortedIterator(cmp: Comparator[(K, V)]): Iterator[(K, V)] = { + destroyed = true + // Pack KV pairs into the front of the underlying array + var keyIndex, newIndex = 0 + while (keyIndex < capacity) { + if (data(2 * keyIndex) != null) { + data(newIndex) = (data(2 * keyIndex), data(2 * keyIndex + 1)) + newIndex += 1 + } + keyIndex += 1 + } + assert(curSize == newIndex + (if (haveNullValue) 1 else 0)) + + // Sort by the given ordering + val rawOrdering = new Comparator[AnyRef] { + def compare(x: AnyRef, y: AnyRef): Int = { + cmp.compare(x.asInstanceOf[(K, V)], y.asInstanceOf[(K, V)]) + } + } + Arrays.sort(data, 0, newIndex, rawOrdering) + + new Iterator[(K, V)] { + var i = 0 + var nullValueReady = haveNullValue + def hasNext: Boolean = (i < newIndex || nullValueReady) + def next(): (K, V) = { + if (nullValueReady) { + nullValueReady = false + (null.asInstanceOf[K], nullValue) + } else { + val item = data(i).asInstanceOf[(K, V)] + i += 1 + item + } + } + } + } + + /** + * Return whether the next insert will cause the map to grow + */ + def atGrowThreshold: Boolean = curSize == growThreshold } diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala new file mode 100644 index 0000000000..e3bcd895aa --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -0,0 +1,350 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import java.io._ +import java.util.Comparator + +import it.unimi.dsi.fastutil.io.FastBufferedInputStream + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.{Logging, SparkEnv} +import org.apache.spark.serializer.Serializer +import org.apache.spark.storage.{DiskBlockManager, DiskBlockObjectWriter} + +/** + * An append-only map that spills sorted content to disk when there is insufficient space for it + * to grow. + * + * This map takes two passes over the data: + * + * (1) Values are merged into combiners, which are sorted and spilled to disk as necessary + * (2) Combiners are read from disk and merged together + * + * The setting of the spill threshold faces the following trade-off: If the spill threshold is + * too high, the in-memory map may occupy more memory than is available, resulting in OOM. + * However, if the spill threshold is too low, we spill frequently and incur unnecessary disk + * writes. This may lead to a performance regression compared to the normal case of using the + * non-spilling AppendOnlyMap. + * + * Two parameters control the memory threshold: + * + * `spark.shuffle.memoryFraction` specifies the collective amount of memory used for storing + * these maps as a fraction of the executor's total memory. Since each concurrently running + * task maintains one map, the actual threshold for each map is this quantity divided by the + * number of running tasks. + * + * `spark.shuffle.safetyFraction` specifies an additional margin of safety as a fraction of + * this threshold, in case map size estimation is not sufficiently accurate. + */ + +private[spark] class ExternalAppendOnlyMap[K, V, C]( + createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiners: (C, C) => C, + serializer: Serializer = SparkEnv.get.serializerManager.default, + diskBlockManager: DiskBlockManager = SparkEnv.get.blockManager.diskBlockManager) + extends Iterable[(K, C)] with Serializable with Logging { + + import ExternalAppendOnlyMap._ + + private var currentMap = new SizeTrackingAppendOnlyMap[K, C] + private val spilledMaps = new ArrayBuffer[DiskMapIterator] + private val sparkConf = SparkEnv.get.conf + + // Collective memory threshold shared across all running tasks + private val maxMemoryThreshold = { + val memoryFraction = sparkConf.getDouble("spark.shuffle.memoryFraction", 0.3) + val safetyFraction = sparkConf.getDouble("spark.shuffle.safetyFraction", 0.8) + (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong + } + + // Number of pairs in the in-memory map + private var numPairsInMemory = 0 + + // Number of in-memory pairs inserted before tracking the map's shuffle memory usage + private val trackMemoryThreshold = 1000 + + // How many times we have spilled so far + private var spillCount = 0 + + private val fileBufferSize = sparkConf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024 + private val syncWrites = sparkConf.getBoolean("spark.shuffle.sync", false) + private val comparator = new KCComparator[K, C] + private val ser = serializer.newInstance() + + /** + * Insert the given key and value into the map. + * + * If the underlying map is about to grow, check if the global pool of shuffle memory has + * enough room for this to happen. If so, allocate the memory required to grow the map; + * otherwise, spill the in-memory map to disk. + * + * The shuffle memory usage of the first trackMemoryThreshold entries is not tracked. + */ + def insert(key: K, value: V) { + val update: (Boolean, C) => C = (hadVal, oldVal) => { + if (hadVal) mergeValue(oldVal, value) else createCombiner(value) + } + if (numPairsInMemory > trackMemoryThreshold && currentMap.atGrowThreshold) { + val mapSize = currentMap.estimateSize() + var shouldSpill = false + val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap + + // Atomically check whether there is sufficient memory in the global pool for + // this map to grow and, if possible, allocate the required amount + shuffleMemoryMap.synchronized { + val threadId = Thread.currentThread().getId + val previouslyOccupiedMemory = shuffleMemoryMap.get(threadId) + val availableMemory = maxMemoryThreshold - + (shuffleMemoryMap.values.sum - previouslyOccupiedMemory.getOrElse(0L)) + + // Assume map growth factor is 2x + shouldSpill = availableMemory < mapSize * 2 + if (!shouldSpill) { + shuffleMemoryMap(threadId) = mapSize * 2 + } + } + // Do not synchronize spills + if (shouldSpill) { + spill(mapSize) + } + } + currentMap.changeValue(key, update) + numPairsInMemory += 1 + } + + /** + * Sort the existing contents of the in-memory map and spill them to a temporary file on disk + */ + private def spill(mapSize: Long) { + spillCount += 1 + logWarning("Spilling in-memory map of %d MB to disk (%d time%s so far)" + .format(mapSize / (1024 * 1024), spillCount, if (spillCount > 1) "s" else "")) + val (blockId, file) = diskBlockManager.createTempBlock() + val writer = + new DiskBlockObjectWriter(blockId, file, serializer, fileBufferSize, identity, syncWrites) + try { + val it = currentMap.destructiveSortedIterator(comparator) + while (it.hasNext) { + val kv = it.next() + writer.write(kv) + } + writer.commit() + } finally { + // Partial failures cannot be tolerated; do not revert partial writes + writer.close() + } + currentMap = new SizeTrackingAppendOnlyMap[K, C] + spilledMaps.append(new DiskMapIterator(file)) + + // Reset the amount of shuffle memory used by this map in the global pool + val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap + shuffleMemoryMap.synchronized { + shuffleMemoryMap(Thread.currentThread().getId) = 0 + } + numPairsInMemory = 0 + } + + /** + * Return an iterator that merges the in-memory map with the spilled maps. + * If no spill has occurred, simply return the in-memory map's iterator. + */ + override def iterator: Iterator[(K, C)] = { + if (spilledMaps.isEmpty) { + currentMap.iterator + } else { + new ExternalIterator() + } + } + + /** + * An iterator that sort-merges (K, C) pairs from the in-memory map and the spilled maps + */ + private class ExternalIterator extends Iterator[(K, C)] { + + // A fixed-size queue that maintains a buffer for each stream we are currently merging + val mergeHeap = new mutable.PriorityQueue[StreamBuffer] + + // Input streams are derived both from the in-memory map and spilled maps on disk + // The in-memory map is sorted in place, while the spilled maps are already in sorted order + val sortedMap = currentMap.destructiveSortedIterator(comparator) + val inputStreams = Seq(sortedMap) ++ spilledMaps + + inputStreams.foreach { it => + val kcPairs = getMorePairs(it) + mergeHeap.enqueue(StreamBuffer(it, kcPairs)) + } + + /** + * Fetch from the given iterator until a key of different hash is retrieved. In the + * event of key hash collisions, this ensures no pairs are hidden from being merged. + * Assume the given iterator is in sorted order. + */ + def getMorePairs(it: Iterator[(K, C)]): ArrayBuffer[(K, C)] = { + val kcPairs = new ArrayBuffer[(K, C)] + if (it.hasNext) { + var kc = it.next() + kcPairs += kc + val minHash = kc._1.hashCode() + while (it.hasNext && kc._1.hashCode() == minHash) { + kc = it.next() + kcPairs += kc + } + } + kcPairs + } + + /** + * If the given buffer contains a value for the given key, merge that value into + * baseCombiner and remove the corresponding (K, C) pair from the buffer + */ + def mergeIfKeyExists(key: K, baseCombiner: C, buffer: StreamBuffer): C = { + var i = 0 + while (i < buffer.pairs.size) { + val (k, c) = buffer.pairs(i) + if (k == key) { + buffer.pairs.remove(i) + return mergeCombiners(baseCombiner, c) + } + i += 1 + } + baseCombiner + } + + /** + * Return true if there exists an input stream that still has unvisited pairs + */ + override def hasNext: Boolean = mergeHeap.exists(!_.pairs.isEmpty) + + /** + * Select a key with the minimum hash, then combine all values with the same key from all input streams. + */ + override def next(): (K, C) = { + // Select a key from the StreamBuffer that holds the lowest key hash + val minBuffer = mergeHeap.dequeue() + val (minPairs, minHash) = (minBuffer.pairs, minBuffer.minKeyHash) + if (minPairs.length == 0) { + // Should only happen when no other stream buffers have any pairs left + throw new NoSuchElementException + } + var (minKey, minCombiner) = minPairs.remove(0) + assert(minKey.hashCode() == minHash) + + // For all other streams that may have this key (i.e. have the same minimum key hash), + // merge in the corresponding value (if any) from that stream + val mergedBuffers = ArrayBuffer[StreamBuffer](minBuffer) + while (!mergeHeap.isEmpty && mergeHeap.head.minKeyHash == minHash) { + val newBuffer = mergeHeap.dequeue() + minCombiner = mergeIfKeyExists(minKey, minCombiner, newBuffer) + mergedBuffers += newBuffer + } + + // Repopulate each visited stream buffer and add it back to the merge heap + mergedBuffers.foreach { buffer => + if (buffer.pairs.length == 0) { + buffer.pairs ++= getMorePairs(buffer.iterator) + } + mergeHeap.enqueue(buffer) + } + + (minKey, minCombiner) + } + + /** + * A buffer for streaming from a map iterator (in-memory or on-disk) sorted by key hash. + * Each buffer maintains the lowest-ordered keys in the corresponding iterator. Due to + * hash collisions, it is possible for multiple keys to be "tied" for being the lowest. + * + * StreamBuffers are ordered by the minimum key hash found across all of their own pairs. + */ + case class StreamBuffer(iterator: Iterator[(K, C)], pairs: ArrayBuffer[(K, C)]) + extends Comparable[StreamBuffer] { + + def minKeyHash: Int = { + if (pairs.length > 0){ + // pairs are already sorted by key hash + pairs(0)._1.hashCode() + } else { + Int.MaxValue + } + } + + override def compareTo(other: StreamBuffer): Int = { + // minus sign because mutable.PriorityQueue dequeues the max, not the min + -minKeyHash.compareTo(other.minKeyHash) + } + } + } + + /** + * An iterator that returns (K, C) pairs in sorted order from an on-disk map + */ + private class DiskMapIterator(file: File) extends Iterator[(K, C)] { + val fileStream = new FileInputStream(file) + val bufferedStream = new FastBufferedInputStream(fileStream) + val deserializeStream = ser.deserializeStream(bufferedStream) + var nextItem: (K, C) = null + var eof = false + + def readNextItem(): (K, C) = { + if (!eof) { + try { + return deserializeStream.readObject().asInstanceOf[(K, C)] + } catch { + case e: EOFException => + eof = true + cleanup() + } + } + null + } + + override def hasNext: Boolean = { + if (nextItem == null) { + nextItem = readNextItem() + } + nextItem != null + } + + override def next(): (K, C) = { + val item = if (nextItem == null) readNextItem() else nextItem + if (item == null) { + throw new NoSuchElementException + } + nextItem = null + item + } + + // TODO: Ensure this gets called even if the iterator isn't drained. + def cleanup() { + deserializeStream.close() + file.delete() + } + } +} + +private[spark] object ExternalAppendOnlyMap { + private class KCComparator[K, C] extends Comparator[(K, C)] { + def compare(kc1: (K, C), kc2: (K, C)): Int = { + kc1._1.hashCode().compareTo(kc2._1.hashCode()) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala new file mode 100644 index 0000000000..204330dad4 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.util.SizeEstimator +import org.apache.spark.util.collection.SizeTrackingAppendOnlyMap.Sample + +/** + * Append-only map that keeps track of its estimated size in bytes. + * We sample with a slow exponential back-off using the SizeEstimator to amortize the time, + * as each call to SizeEstimator can take a sizable amount of time (order of a few milliseconds). + */ +private[spark] class SizeTrackingAppendOnlyMap[K, V] extends AppendOnlyMap[K, V] { + + /** + * Controls the base of the exponential which governs the rate of sampling. + * E.g., a value of 2 would mean we sample at 1, 2, 4, 8, ... elements. + */ + private val SAMPLE_GROWTH_RATE = 1.1 + + /** All samples taken since last resetSamples(). Only the last two are used for extrapolation. */ + private val samples = new ArrayBuffer[Sample]() + + /** Total number of insertions and updates into the map since the last resetSamples(). */ + private var numUpdates: Long = _ + + /** The value of 'numUpdates' at which we will take our next sample. */ + private var nextSampleNum: Long = _ + + /** The average number of bytes per update between our last two samples. */ + private var bytesPerUpdate: Double = _ + + resetSamples() + + /** Called after the map grows in size, as this can be a dramatic change for small objects. */ + def resetSamples() { + numUpdates = 1 + nextSampleNum = 1 + samples.clear() + takeSample() + } + + override def update(key: K, value: V): Unit = { + super.update(key, value) + numUpdates += 1 + if (nextSampleNum == numUpdates) { takeSample() } + } + + override def changeValue(key: K, updateFunc: (Boolean, V) => V): V = { + val newValue = super.changeValue(key, updateFunc) + numUpdates += 1 + if (nextSampleNum == numUpdates) { takeSample() } + newValue + } + + /** Takes a new sample of the current map's size. */ + def takeSample() { + samples += Sample(SizeEstimator.estimate(this), numUpdates) + // Only use the last two samples to extrapolate. If fewer than 2 samples, assume no change. + bytesPerUpdate = math.max(0, samples.toSeq.reverse match { + case latest :: previous :: tail => + (latest.size - previous.size).toDouble / (latest.numUpdates - previous.numUpdates) + case _ => + 0 + }) + nextSampleNum = math.ceil(numUpdates * SAMPLE_GROWTH_RATE).toLong + } + + override protected def growTable() { + super.growTable() + resetSamples() + } + + /** Estimates the current size of the map in bytes. O(1) time. */ + def estimateSize(): Long = { + assert(samples.nonEmpty) + val extrapolatedDelta = bytesPerUpdate * (numUpdates - samples.last.numUpdates) + (samples.last.size + extrapolatedDelta).toLong + } +} + +private object SizeTrackingAppendOnlyMap { + case class Sample(size: Long, numUpdates: Long) +} diff --git a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala index 8dd5786da6..3ac706110e 100644 --- a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala +++ b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala @@ -53,7 +53,6 @@ object LocalSparkContext { } // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown System.clearProperty("spark.driver.port") - System.clearProperty("spark.hostPort") } /** Runs `f` by passing in `sc` and ensures that `sc` is stopped. */ diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index afc1beff98..930c2523ca 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -99,7 +99,6 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { val hostname = "localhost" val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, conf = conf) System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext - System.setProperty("spark.hostPort", hostname + ":" + boundPort) val masterTracker = new MapOutputTrackerMaster(conf) masterTracker.trackerActor = actorSystem.actorOf( diff --git a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala index 331fa3a642..d05bbd6ff7 100644 --- a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala @@ -25,8 +25,8 @@ import net.liftweb.json.JsonAST.JValue import org.scalatest.FunSuite import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, WorkerStateResponse} -import org.apache.spark.deploy.master.{ApplicationInfo, RecoveryState, WorkerInfo} -import org.apache.spark.deploy.worker.ExecutorRunner +import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, RecoveryState, WorkerInfo} +import org.apache.spark.deploy.worker.{ExecutorRunner, DriverRunner} class JsonProtocolSuite extends FunSuite { test("writeApplicationInfo") { @@ -50,11 +50,13 @@ class JsonProtocolSuite extends FunSuite { } test("writeMasterState") { - val workers = Array[WorkerInfo](createWorkerInfo(), createWorkerInfo()) - val activeApps = Array[ApplicationInfo](createAppInfo()) + val workers = Array(createWorkerInfo(), createWorkerInfo()) + val activeApps = Array(createAppInfo()) val completedApps = Array[ApplicationInfo]() + val activeDrivers = Array(createDriverInfo()) + val completedDrivers = Array(createDriverInfo()) val stateResponse = new MasterStateResponse("host", 8080, workers, activeApps, completedApps, - RecoveryState.ALIVE) + activeDrivers, completedDrivers, RecoveryState.ALIVE) val output = JsonProtocol.writeMasterState(stateResponse) assertValidJson(output) } @@ -62,26 +64,44 @@ class JsonProtocolSuite extends FunSuite { test("writeWorkerState") { val executors = List[ExecutorRunner]() val finishedExecutors = List[ExecutorRunner](createExecutorRunner(), createExecutorRunner()) + val drivers = List(createDriverRunner()) + val finishedDrivers = List(createDriverRunner(), createDriverRunner()) val stateResponse = new WorkerStateResponse("host", 8080, "workerId", executors, - finishedExecutors, "masterUrl", 4, 1234, 4, 1234, "masterWebUiUrl") + finishedExecutors, drivers, finishedDrivers, "masterUrl", 4, 1234, 4, 1234, "masterWebUiUrl") val output = JsonProtocol.writeWorkerState(stateResponse) assertValidJson(output) } - def createAppDesc() : ApplicationDescription = { + def createAppDesc(): ApplicationDescription = { val cmd = new Command("mainClass", List("arg1", "arg2"), Map()) new ApplicationDescription("name", Some(4), 1234, cmd, "sparkHome", "appUiUrl") } + def createAppInfo() : ApplicationInfo = { new ApplicationInfo( 3, "id", createAppDesc(), new Date(123456789), null, "appUriStr", Int.MaxValue) } - def createWorkerInfo() : WorkerInfo = { + + def createDriverCommand() = new Command( + "org.apache.spark.FakeClass", Seq("some arg --and-some options -g foo"), + Map(("K1", "V1"), ("K2", "V2")) + ) + + def createDriverDesc() = new DriverDescription("hdfs://some-dir/some.jar", 100, 3, + false, createDriverCommand()) + + def createDriverInfo(): DriverInfo = new DriverInfo(3, "driver-3", createDriverDesc(), new Date()) + + def createWorkerInfo(): WorkerInfo = { new WorkerInfo("id", "host", 8080, 4, 1234, null, 80, "publicAddress") } - def createExecutorRunner() : ExecutorRunner = { + def createExecutorRunner(): ExecutorRunner = { new ExecutorRunner("appId", 123, createAppDesc(), 4, 1234, null, "workerId", "host", - new File("sparkHome"), new File("workDir"), ExecutorState.RUNNING) + new File("sparkHome"), new File("workDir"), "akka://worker", ExecutorState.RUNNING) + } + def createDriverRunner(): DriverRunner = { + new DriverRunner("driverId", new File("workDir"), new File("sparkHome"), createDriverDesc(), + null, "akka://worker") } def assertValidJson(json: JValue) { diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala new file mode 100644 index 0000000000..45dbcaffae --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala @@ -0,0 +1,131 @@ +package org.apache.spark.deploy.worker + +import java.io.File + +import scala.collection.JavaConversions._ + +import org.mockito.Mockito._ +import org.mockito.Matchers._ +import org.scalatest.FunSuite + +import org.apache.spark.deploy.{Command, DriverDescription} +import org.mockito.stubbing.Answer +import org.mockito.invocation.InvocationOnMock + +class DriverRunnerTest extends FunSuite { + private def createDriverRunner() = { + val command = new Command("mainClass", Seq(), Map()) + val driverDescription = new DriverDescription("jarUrl", 512, 1, true, command) + new DriverRunner("driverId", new File("workDir"), new File("sparkHome"), driverDescription, + null, "akka://1.2.3.4/worker/") + } + + private def createProcessBuilderAndProcess(): (ProcessBuilderLike, Process) = { + val processBuilder = mock(classOf[ProcessBuilderLike]) + when(processBuilder.command).thenReturn(Seq("mocked", "command")) + val process = mock(classOf[Process]) + when(processBuilder.start()).thenReturn(process) + (processBuilder, process) + } + + test("Process succeeds instantly") { + val runner = createDriverRunner() + + val sleeper = mock(classOf[Sleeper]) + runner.setSleeper(sleeper) + + val (processBuilder, process) = createProcessBuilderAndProcess() + // One failure then a successful run + when(process.waitFor()).thenReturn(0) + runner.runCommandWithRetry(processBuilder, p => (), supervise = true) + + verify(process, times(1)).waitFor() + verify(sleeper, times(0)).sleep(anyInt()) + } + + test("Process failing several times and then succeeding") { + val runner = createDriverRunner() + + val sleeper = mock(classOf[Sleeper]) + runner.setSleeper(sleeper) + + val (processBuilder, process) = createProcessBuilderAndProcess() + // fail, fail, fail, success + when(process.waitFor()).thenReturn(-1).thenReturn(-1).thenReturn(-1).thenReturn(0) + runner.runCommandWithRetry(processBuilder, p => (), supervise = true) + + verify(process, times(4)).waitFor() + verify(sleeper, times(3)).sleep(anyInt()) + verify(sleeper, times(1)).sleep(1) + verify(sleeper, times(1)).sleep(2) + verify(sleeper, times(1)).sleep(4) + } + + test("Process doesn't restart if not supervised") { + val runner = createDriverRunner() + + val sleeper = mock(classOf[Sleeper]) + runner.setSleeper(sleeper) + + val (processBuilder, process) = createProcessBuilderAndProcess() + when(process.waitFor()).thenReturn(-1) + + runner.runCommandWithRetry(processBuilder, p => (), supervise = false) + + verify(process, times(1)).waitFor() + verify(sleeper, times(0)).sleep(anyInt()) + } + + test("Process doesn't restart if killed") { + val runner = createDriverRunner() + + val sleeper = mock(classOf[Sleeper]) + runner.setSleeper(sleeper) + + val (processBuilder, process) = createProcessBuilderAndProcess() + when(process.waitFor()).thenAnswer(new Answer[Int] { + def answer(invocation: InvocationOnMock): Int = { + runner.kill() + -1 + } + }) + + runner.runCommandWithRetry(processBuilder, p => (), supervise = true) + + verify(process, times(1)).waitFor() + verify(sleeper, times(0)).sleep(anyInt()) + } + + test("Reset of backoff counter") { + val runner = createDriverRunner() + + val sleeper = mock(classOf[Sleeper]) + runner.setSleeper(sleeper) + + val clock = mock(classOf[Clock]) + runner.setClock(clock) + + val (processBuilder, process) = createProcessBuilderAndProcess() + + when(process.waitFor()) + .thenReturn(-1) // fail 1 + .thenReturn(-1) // fail 2 + .thenReturn(-1) // fail 3 + .thenReturn(-1) // fail 4 + .thenReturn(0) // success + when(clock.currentTimeMillis()) + .thenReturn(0).thenReturn(1000) // fail 1 (short) + .thenReturn(1000).thenReturn(2000) // fail 2 (short) + .thenReturn(2000).thenReturn(10000) // fail 3 (long) + .thenReturn(10000).thenReturn(11000) // fail 4 (short) + .thenReturn(11000).thenReturn(21000) // success (long) + + runner.runCommandWithRetry(processBuilder, p => (), supervise = true) + + verify(sleeper, times(4)).sleep(anyInt()) + // Expected sequence of sleeps is 1,2,1,2 + verify(sleeper, times(2)).sleep(1) + verify(sleeper, times(2)).sleep(2) + } + +} diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala index be93074b7b..a79ee690d3 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala @@ -31,8 +31,8 @@ class ExecutorRunnerTest extends FunSuite { sparkHome, "appUiUrl") val appId = "12345-worker321-9876" val er = new ExecutorRunner(appId, 1, appDesc, 8, 500, null, "blah", "worker321", f(sparkHome), - f("ooga"), ExecutorState.RUNNING) + f("ooga"), "blah", ExecutorState.RUNNING) - assert(er.buildCommandSeq().last === appId) + assert(er.getCommandSeq.last === appId) } } diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala new file mode 100644 index 0000000000..94d88d307a --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala @@ -0,0 +1,32 @@ +package org.apache.spark.deploy.worker + + +import akka.testkit.TestActorRef +import org.scalatest.FunSuite +import akka.remote.DisassociatedEvent +import akka.actor.{ActorSystem, AddressFromURIString, Props} + +class WorkerWatcherSuite extends FunSuite { + test("WorkerWatcher shuts down on valid disassociation") { + val actorSystem = ActorSystem("test") + val targetWorkerUrl = "akka://1.2.3.4/user/Worker" + val targetWorkerAddress = AddressFromURIString(targetWorkerUrl) + val actorRef = TestActorRef[WorkerWatcher](Props(classOf[WorkerWatcher], targetWorkerUrl))(actorSystem) + val workerWatcher = actorRef.underlyingActor + workerWatcher.setTesting(testing = true) + actorRef.underlyingActor.receive(new DisassociatedEvent(null, targetWorkerAddress, false)) + assert(actorRef.underlyingActor.isShutDown) + } + + test("WorkerWatcher stays alive on invalid disassociation") { + val actorSystem = ActorSystem("test") + val targetWorkerUrl = "akka://1.2.3.4/user/Worker" + val otherAkkaURL = "akka://4.3.2.1/user/OtherActor" + val otherAkkaAddress = AddressFromURIString(otherAkkaURL) + val actorRef = TestActorRef[WorkerWatcher](Props(classOf[WorkerWatcher], targetWorkerUrl))(actorSystem) + val workerWatcher = actorRef.underlyingActor + workerWatcher.setTesting(testing = true) + actorRef.underlyingActor.receive(new DisassociatedEvent(null, otherAkkaAddress, false)) + assert(!actorRef.underlyingActor.isShutDown) + } +}
\ No newline at end of file diff --git a/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala index 7bf2020fe3..235d31709a 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala @@ -64,7 +64,7 @@ class FakeTaskSetManager( } override def getSchedulableByName(name: String): Schedulable = { - return null + null } override def executorLost(executorId: String, host: String): Unit = { @@ -79,13 +79,14 @@ class FakeTaskSetManager( { if (tasksSuccessful + runningTasks < numTasks) { increaseRunningTasks(1) - return Some(new TaskDescription(0, execId, "task 0:0", 0, null)) + Some(new TaskDescription(0, execId, "task 0:0", 0, null)) + } else { + None } - return None } override def checkSpeculatableTasks(): Boolean = { - return true + true } def taskFinished() { diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 2aa259daf3..f0236ef1e9 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -122,7 +122,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont locations: Seq[Seq[String]] = Nil ): MyRDD = { val maxPartition = numPartitions - 1 - return new MyRDD(sc, dependencies) { + val newRDD = new MyRDD(sc, dependencies) { override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = throw new RuntimeException("should not be reached") override def getPartitions = (0 to maxPartition).map(i => new Partition { @@ -135,6 +135,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont Nil override def toString: String = "DAGSchedulerSuiteRDD " + id } + newRDD } /** diff --git a/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala index 5cc48ee00a..29102913c7 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala @@ -42,12 +42,9 @@ class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers def buildJobDepTest(jobID: Int, stage: Stage) = buildJobDep(jobID, stage) } type MyRDD = RDD[(Int, Int)] - def makeRdd( - numPartitions: Int, - dependencies: List[Dependency[_]] - ): MyRDD = { + def makeRdd(numPartitions: Int, dependencies: List[Dependency[_]]): MyRDD = { val maxPartition = numPartitions - 1 - return new MyRDD(sc, dependencies) { + new MyRDD(sc, dependencies) { override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = throw new RuntimeException("should not be reached") override def getPartitions = (0 to maxPartition).map(i => new Partition { diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 1eec6726f4..c9f6cc5d07 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -83,7 +83,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { private val conf = new SparkConf - val LOCALITY_WAIT = conf.get("spark.locality.wait", "3000").toLong + val LOCALITY_WAIT = conf.getLong("spark.locality.wait", 3000) val MAX_TASK_FAILURES = 4 test("TaskSet with no preferences") { diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index f60ce270c7..18aa587662 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -53,7 +53,6 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT val (actorSystem, boundPort) = AkkaUtils.createActorSystem("test", "localhost", 0, conf = conf) this.actorSystem = actorSystem conf.set("spark.driver.port", boundPort.toString) - conf.set("spark.hostPort", "localhost:" + boundPort) master = new BlockManagerMaster( actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf))), conf) @@ -65,13 +64,10 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT conf.set("spark.storage.disableBlockManagerHeartBeat", "true") val initialize = PrivateMethod[Unit]('initialize) SizeEstimator invokePrivate initialize() - // Set some value ... - conf.set("spark.hostPort", Utils.localHostName() + ":" + 1111) } after { System.clearProperty("spark.driver.port") - System.clearProperty("spark.hostPort") if (store != null) { store.stop() diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala index 0ed366fb70..de4871d043 100644 --- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala @@ -61,8 +61,8 @@ class NonSerializable {} object TestObject { def run(): Int = { var nonSer = new NonSerializable - var x = 5 - return withSpark(new SparkContext("local", "test")) { sc => + val x = 5 + withSpark(new SparkContext("local", "test")) { sc => val nums = sc.parallelize(Array(1, 2, 3, 4)) nums.map(_ + x).reduce(_ + _) } @@ -76,7 +76,7 @@ class TestClass extends Serializable { def run(): Int = { var nonSer = new NonSerializable - return withSpark(new SparkContext("local", "test")) { sc => + withSpark(new SparkContext("local", "test")) { sc => val nums = sc.parallelize(Array(1, 2, 3, 4)) nums.map(_ + getX).reduce(_ + _) } @@ -88,7 +88,7 @@ class TestClassWithoutDefaultConstructor(x: Int) extends Serializable { def run(): Int = { var nonSer = new NonSerializable - return withSpark(new SparkContext("local", "test")) { sc => + withSpark(new SparkContext("local", "test")) { sc => val nums = sc.parallelize(Array(1, 2, 3, 4)) nums.map(_ + getX).reduce(_ + _) } @@ -103,7 +103,7 @@ class TestClassWithoutFieldAccess { def run(): Int = { var nonSer2 = new NonSerializable var x = 5 - return withSpark(new SparkContext("local", "test")) { sc => + withSpark(new SparkContext("local", "test")) { sc => val nums = sc.parallelize(Array(1, 2, 3, 4)) nums.map(_ + x).reduce(_ + _) } @@ -115,7 +115,7 @@ object TestObjectWithNesting { def run(): Int = { var nonSer = new NonSerializable var answer = 0 - return withSpark(new SparkContext("local", "test")) { sc => + withSpark(new SparkContext("local", "test")) { sc => val nums = sc.parallelize(Array(1, 2, 3, 4)) var y = 1 for (i <- 1 to 4) { @@ -134,7 +134,7 @@ class TestClassWithNesting(val y: Int) extends Serializable { def run(): Int = { var nonSer = new NonSerializable var answer = 0 - return withSpark(new SparkContext("local", "test")) { sc => + withSpark(new SparkContext("local", "test")) { sc => val nums = sc.parallelize(Array(1, 2, 3, 4)) for (i <- 1 to 4) { var nonSer2 = new NonSerializable diff --git a/core/src/test/scala/org/apache/spark/util/SizeTrackingAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/SizeTrackingAppendOnlyMapSuite.scala new file mode 100644 index 0000000000..93f0c6a8e6 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/SizeTrackingAppendOnlyMapSuite.scala @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import scala.util.Random + +import org.scalatest.{BeforeAndAfterAll, FunSuite} + +import org.apache.spark.util.SizeTrackingAppendOnlyMapSuite.LargeDummyClass +import org.apache.spark.util.collection.{AppendOnlyMap, SizeTrackingAppendOnlyMap} + +class SizeTrackingAppendOnlyMapSuite extends FunSuite with BeforeAndAfterAll { + val NORMAL_ERROR = 0.20 + val HIGH_ERROR = 0.30 + + test("fixed size insertions") { + testWith[Int, Long](10000, i => (i, i.toLong)) + testWith[Int, (Long, Long)](10000, i => (i, (i.toLong, i.toLong))) + testWith[Int, LargeDummyClass](10000, i => (i, new LargeDummyClass())) + } + + test("variable size insertions") { + val rand = new Random(123456789) + def randString(minLen: Int, maxLen: Int): String = { + "a" * (rand.nextInt(maxLen - minLen) + minLen) + } + testWith[Int, String](10000, i => (i, randString(0, 10))) + testWith[Int, String](10000, i => (i, randString(0, 100))) + testWith[Int, String](10000, i => (i, randString(90, 100))) + } + + test("updates") { + val rand = new Random(123456789) + def randString(minLen: Int, maxLen: Int): String = { + "a" * (rand.nextInt(maxLen - minLen) + minLen) + } + testWith[String, Int](10000, i => (randString(0, 10000), i)) + } + + def testWith[K, V](numElements: Int, makeElement: (Int) => (K, V)) { + val map = new SizeTrackingAppendOnlyMap[K, V]() + for (i <- 0 until numElements) { + val (k, v) = makeElement(i) + map(k) = v + expectWithinError(map, map.estimateSize(), if (i < 32) HIGH_ERROR else NORMAL_ERROR) + } + } + + def expectWithinError(obj: AnyRef, estimatedSize: Long, error: Double) { + val betterEstimatedSize = SizeEstimator.estimate(obj) + assert(betterEstimatedSize * (1 - error) < estimatedSize, + s"Estimated size $estimatedSize was less than expected size $betterEstimatedSize") + assert(betterEstimatedSize * (1 + 2 * error) > estimatedSize, + s"Estimated size $estimatedSize was greater than expected size $betterEstimatedSize") + } +} + +object SizeTrackingAppendOnlyMapSuite { + // Speed test, for reproducibility of results. + // These could be highly non-deterministic in general, however. + // Results: + // AppendOnlyMap: 31 ms + // SizeTracker: 54 ms + // SizeEstimator: 1500 ms + def main(args: Array[String]) { + val numElements = 100000 + + val baseTimes = for (i <- 0 until 10) yield time { + val map = new AppendOnlyMap[Int, LargeDummyClass]() + for (i <- 0 until numElements) { + map(i) = new LargeDummyClass() + } + } + + val sampledTimes = for (i <- 0 until 10) yield time { + val map = new SizeTrackingAppendOnlyMap[Int, LargeDummyClass]() + for (i <- 0 until numElements) { + map(i) = new LargeDummyClass() + map.estimateSize() + } + } + + val unsampledTimes = for (i <- 0 until 3) yield time { + val map = new AppendOnlyMap[Int, LargeDummyClass]() + for (i <- 0 until numElements) { + map(i) = new LargeDummyClass() + SizeEstimator.estimate(map) + } + } + + println("Base: " + baseTimes) + println("SizeTracker (sampled): " + sampledTimes) + println("SizeEstimator (unsampled): " + unsampledTimes) + } + + def time(f: => Unit): Long = { + val start = System.currentTimeMillis() + f + System.currentTimeMillis() - start + } + + private class LargeDummyClass { + val arr = new Array[Int](100) + } +} diff --git a/core/src/test/scala/org/apache/spark/util/VectorSuite.scala b/core/src/test/scala/org/apache/spark/util/VectorSuite.scala new file mode 100644 index 0000000000..7006571ef0 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/VectorSuite.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import scala.util.Random + +import org.scalatest.FunSuite + +/** + * Tests org.apache.spark.util.Vector functionality + */ +class VectorSuite extends FunSuite { + + def verifyVector(vector: Vector, expectedLength: Int) = { + assert(vector.length == expectedLength) + assert(vector.elements.min > 0.0) + assert(vector.elements.max < 1.0) + } + + test("random with default random number generator") { + val vector100 = Vector.random(100) + verifyVector(vector100, 100) + } + + test("random with given random number generator") { + val vector100 = Vector.random(100, new Random(100)) + verifyVector(vector100, 100) + } +} diff --git a/core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala b/core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala index b78367b6ca..f1d7b61b31 100644 --- a/core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala @@ -73,4 +73,4 @@ class XORShiftRandomSuite extends FunSuite with ShouldMatchers { } -}
\ No newline at end of file +} diff --git a/core/src/test/scala/org/apache/spark/util/AppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/AppendOnlyMapSuite.scala index 7177919a58..f44442f1a5 100644 --- a/core/src/test/scala/org/apache/spark/util/AppendOnlyMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/AppendOnlyMapSuite.scala @@ -15,11 +15,12 @@ * limitations under the License. */ -package org.apache.spark.util +package org.apache.spark.util.collection import scala.collection.mutable.HashSet import org.scalatest.FunSuite +import java.util.Comparator class AppendOnlyMapSuite extends FunSuite { test("initialization") { @@ -151,4 +152,47 @@ class AppendOnlyMapSuite extends FunSuite { assert(map("" + i) === "" + i) } } + + test("destructive sort") { + val map = new AppendOnlyMap[String, String]() + for (i <- 1 to 100) { + map("" + i) = "" + i + } + map.update(null, "happy new year!") + + try { + map.apply("1") + map.update("1", "2013") + map.changeValue("1", (hadValue, oldValue) => "2014") + map.iterator + } catch { + case e: IllegalStateException => fail() + } + + val it = map.destructiveSortedIterator(new Comparator[(String, String)] { + def compare(kv1: (String, String), kv2: (String, String)): Int = { + val x = if (kv1 != null && kv1._1 != null) kv1._1.toInt else Int.MinValue + val y = if (kv2 != null && kv2._1 != null) kv2._1.toInt else Int.MinValue + x.compareTo(y) + } + }) + + // Should be sorted by key + assert(it.hasNext) + var previous = it.next() + assert(previous == (null, "happy new year!")) + previous = it.next() + assert(previous == ("1", "2014")) + while (it.hasNext) { + val kv = it.next() + assert(kv._1.toInt > previous._1.toInt) + previous = kv + } + + // All subsequent calls to apply, update, changeValue and iterator should throw exception + intercept[AssertionError] { map.apply("1") } + intercept[AssertionError] { map.update("1", "2013") } + intercept[AssertionError] { map.changeValue("1", (hadValue, oldValue) => "2014") } + intercept[AssertionError] { map.iterator } + } } diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala new file mode 100644 index 0000000000..ef957bb0e5 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala @@ -0,0 +1,230 @@ +package org.apache.spark.util.collection + +import scala.collection.mutable.ArrayBuffer + +import org.scalatest.{BeforeAndAfter, FunSuite} + +import org.apache.spark._ +import org.apache.spark.SparkContext._ + +class ExternalAppendOnlyMapSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { + + override def beforeEach() { + val conf = new SparkConf(false) + conf.set("spark.shuffle.externalSorting", "true") + sc = new SparkContext("local", "test", conf) + } + + val createCombiner: (Int => ArrayBuffer[Int]) = i => ArrayBuffer[Int](i) + val mergeValue: (ArrayBuffer[Int], Int) => ArrayBuffer[Int] = (buffer, i) => { + buffer += i + } + val mergeCombiners: (ArrayBuffer[Int], ArrayBuffer[Int]) => ArrayBuffer[Int] = + (buf1, buf2) => { + buf1 ++= buf2 + } + + test("simple insert") { + val map = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner, + mergeValue, mergeCombiners) + + // Single insert + map.insert(1, 10) + var it = map.iterator + assert(it.hasNext) + val kv = it.next() + assert(kv._1 == 1 && kv._2 == ArrayBuffer[Int](10)) + assert(!it.hasNext) + + // Multiple insert + map.insert(2, 20) + map.insert(3, 30) + it = map.iterator + assert(it.hasNext) + assert(it.toSet == Set[(Int, ArrayBuffer[Int])]( + (1, ArrayBuffer[Int](10)), + (2, ArrayBuffer[Int](20)), + (3, ArrayBuffer[Int](30)))) + } + + test("insert with collision") { + val map = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner, + mergeValue, mergeCombiners) + + map.insert(1, 10) + map.insert(2, 20) + map.insert(3, 30) + map.insert(1, 100) + map.insert(2, 200) + map.insert(1, 1000) + val it = map.iterator + assert(it.hasNext) + val result = it.toSet[(Int, ArrayBuffer[Int])].map(kv => (kv._1, kv._2.toSet)) + assert(result == Set[(Int, Set[Int])]( + (1, Set[Int](10, 100, 1000)), + (2, Set[Int](20, 200)), + (3, Set[Int](30)))) + } + + test("ordering") { + val map1 = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner, + mergeValue, mergeCombiners) + map1.insert(1, 10) + map1.insert(2, 20) + map1.insert(3, 30) + + val map2 = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner, + mergeValue, mergeCombiners) + map2.insert(2, 20) + map2.insert(3, 30) + map2.insert(1, 10) + + val map3 = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner, + mergeValue, mergeCombiners) + map3.insert(3, 30) + map3.insert(1, 10) + map3.insert(2, 20) + + val it1 = map1.iterator + val it2 = map2.iterator + val it3 = map3.iterator + + var kv1 = it1.next() + var kv2 = it2.next() + var kv3 = it3.next() + assert(kv1._1 == kv2._1 && kv2._1 == kv3._1) + assert(kv1._2 == kv2._2 && kv2._2 == kv3._2) + + kv1 = it1.next() + kv2 = it2.next() + kv3 = it3.next() + assert(kv1._1 == kv2._1 && kv2._1 == kv3._1) + assert(kv1._2 == kv2._2 && kv2._2 == kv3._2) + + kv1 = it1.next() + kv2 = it2.next() + kv3 = it3.next() + assert(kv1._1 == kv2._1 && kv2._1 == kv3._1) + assert(kv1._2 == kv2._2 && kv2._2 == kv3._2) + } + + test("null keys and values") { + val map = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner, + mergeValue, mergeCombiners) + map.insert(1, 5) + map.insert(2, 6) + map.insert(3, 7) + assert(map.size === 3) + assert(map.iterator.toSet == Set[(Int, Seq[Int])]( + (1, Seq[Int](5)), + (2, Seq[Int](6)), + (3, Seq[Int](7)) + )) + + // Null keys + val nullInt = null.asInstanceOf[Int] + map.insert(nullInt, 8) + assert(map.size === 4) + assert(map.iterator.toSet == Set[(Int, Seq[Int])]( + (1, Seq[Int](5)), + (2, Seq[Int](6)), + (3, Seq[Int](7)), + (nullInt, Seq[Int](8)) + )) + + // Null values + map.insert(4, nullInt) + map.insert(nullInt, nullInt) + assert(map.size === 5) + val result = map.iterator.toSet[(Int, ArrayBuffer[Int])].map(kv => (kv._1, kv._2.toSet)) + assert(result == Set[(Int, Set[Int])]( + (1, Set[Int](5)), + (2, Set[Int](6)), + (3, Set[Int](7)), + (4, Set[Int](nullInt)), + (nullInt, Set[Int](nullInt, 8)) + )) + } + + test("simple aggregator") { + // reduceByKey + val rdd = sc.parallelize(1 to 10).map(i => (i%2, 1)) + val result1 = rdd.reduceByKey(_+_).collect() + assert(result1.toSet == Set[(Int, Int)]((0, 5), (1, 5))) + + // groupByKey + val result2 = rdd.groupByKey().collect() + assert(result2.toSet == Set[(Int, Seq[Int])] + ((0, ArrayBuffer[Int](1, 1, 1, 1, 1)), (1, ArrayBuffer[Int](1, 1, 1, 1, 1)))) + } + + test("simple cogroup") { + val rdd1 = sc.parallelize(1 to 4).map(i => (i, i)) + val rdd2 = sc.parallelize(1 to 4).map(i => (i%2, i)) + val result = rdd1.cogroup(rdd2).collect() + + result.foreach { case (i, (seq1, seq2)) => + i match { + case 0 => assert(seq1.toSet == Set[Int]() && seq2.toSet == Set[Int](2, 4)) + case 1 => assert(seq1.toSet == Set[Int](1) && seq2.toSet == Set[Int](1, 3)) + case 2 => assert(seq1.toSet == Set[Int](2) && seq2.toSet == Set[Int]()) + case 3 => assert(seq1.toSet == Set[Int](3) && seq2.toSet == Set[Int]()) + case 4 => assert(seq1.toSet == Set[Int](4) && seq2.toSet == Set[Int]()) + } + } + } + + test("spilling") { + // TODO: Figure out correct memory parameters to actually induce spilling + // System.setProperty("spark.shuffle.buffer.mb", "1") + // System.setProperty("spark.shuffle.buffer.fraction", "0.05") + + // reduceByKey - should spill exactly 6 times + val rddA = sc.parallelize(0 until 10000).map(i => (i/2, i)) + val resultA = rddA.reduceByKey(math.max(_, _)).collect() + assert(resultA.length == 5000) + resultA.foreach { case(k, v) => + k match { + case 0 => assert(v == 1) + case 2500 => assert(v == 5001) + case 4999 => assert(v == 9999) + case _ => + } + } + + // groupByKey - should spill exactly 11 times + val rddB = sc.parallelize(0 until 10000).map(i => (i/4, i)) + val resultB = rddB.groupByKey().collect() + assert(resultB.length == 2500) + resultB.foreach { case(i, seq) => + i match { + case 0 => assert(seq.toSet == Set[Int](0, 1, 2, 3)) + case 1250 => assert(seq.toSet == Set[Int](5000, 5001, 5002, 5003)) + case 2499 => assert(seq.toSet == Set[Int](9996, 9997, 9998, 9999)) + case _ => + } + } + + // cogroup - should spill exactly 7 times + val rddC1 = sc.parallelize(0 until 1000).map(i => (i, i)) + val rddC2 = sc.parallelize(0 until 1000).map(i => (i%100, i)) + val resultC = rddC1.cogroup(rddC2).collect() + assert(resultC.length == 1000) + resultC.foreach { case(i, (seq1, seq2)) => + i match { + case 0 => + assert(seq1.toSet == Set[Int](0)) + assert(seq2.toSet == Set[Int](0, 100, 200, 300, 400, 500, 600, 700, 800, 900)) + case 500 => + assert(seq1.toSet == Set[Int](500)) + assert(seq2.toSet == Set[Int]()) + case 999 => + assert(seq1.toSet == Set[Int](999)) + assert(seq2.toSet == Set[Int]()) + case _ => + } + } + } + + // TODO: Test memory allocation for multiple concurrently running tasks +} |