| author | Josh Rosen <joshrosen@eecs.berkeley.edu> | 2012-12-29 16:00:51 -0800 |
| --- | --- | --- |
| committer | Josh Rosen <joshrosen@eecs.berkeley.edu> | 2012-12-29 16:00:51 -0800 |
| commit | c5cee53f2092ee2825095a1831ca47f1c41afc2f (patch) | |
| tree | 29c36f14668e67493b2c8ed98cb4d4124baf841e /core | |
| parent | 26186e2d259f3aa2db9c8594097fd342107ce147 (diff) | |
| parent | 3f74f729a190924b7634e08a381232af36aeb328 (diff) | |
| download | spark-c5cee53f2092ee2825095a1831ca47f1c41afc2f.tar.gz spark-c5cee53f2092ee2825095a1831ca47f1c41afc2f.tar.bz2 spark-c5cee53f2092ee2825095a1831ca47f1c41afc2f.zip | |
Merge remote-tracking branch 'origin/master' into python-api
Conflicts:
docs/quick-start.md
Diffstat (limited to 'core')
97 files changed, 2892 insertions(+), 1531 deletions(-)
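The most invasive change this merge pulls in is the `Aggregator` rewrite (see `core/src/main/scala/spark/Aggregator.scala` below): the `mapSideCombine` flag disappears from `Aggregator` and `ShuffleDependency`, and combining becomes explicit `mapPartitions` passes around a `ShuffledRDD`. Below is a minimal standalone sketch of the per-partition combine logic as it appears in the diff, in the Scala-2.9-era style the codebase uses; the word-count demo at the end is a hypothetical illustration, not code from this commit.

```scala
import java.util.{HashMap => JHashMap}
import scala.collection.JavaConversions._

// Per-partition combining, as in the merged Aggregator: one hash map per
// partition, createCombiner on first sight of a key, mergeValue afterwards.
case class Aggregator[K, V, C](
    createCombiner: V => C,
    mergeValue: (C, V) => C,
    mergeCombiners: (C, C) => C) {

  def combineValuesByKey(iter: Iterator[(K, V)]): Iterator[(K, C)] = {
    val combiners = new JHashMap[K, C]
    for ((k, v) <- iter) {
      val oldC = combiners.get(k)
      if (oldC == null) {
        combiners.put(k, createCombiner(v))
      } else {
        combiners.put(k, mergeValue(oldC, v))
      }
    }
    combiners.iterator // JavaConversions exposes the map as (K, C) pairs
  }
}

// Hypothetical usage: word counting within a single partition.
object AggregatorDemo extends App {
  val agg = Aggregator[String, Int, Int](v => v, _ + _, _ + _)
  val partition = Iterator("a" -> 1, "b" -> 1, "a" -> 1)
  println(agg.combineValuesByKey(partition).toList) // e.g. List((a,2), (b,1))
}
```

In `PairRDDFunctions.combineByKey`, this runs once before the shuffle (`combineValuesByKey`) and once after (`combineCombinersByKey`) when map-side combining is enabled; with it disabled, only the reduce-side pass runs.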
diff --git a/core/pom.xml b/core/pom.xml new file mode 100644 index 0000000000..ae52c20657 --- /dev/null +++ b/core/pom.xml @@ -0,0 +1,270 @@ +<?xml version="1.0" encoding="UTF-8"?> +<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> + <modelVersion>4.0.0</modelVersion> + <parent> + <groupId>org.spark-project</groupId> + <artifactId>parent</artifactId> + <version>0.7.0-SNAPSHOT</version> + <relativePath>../pom.xml</relativePath> + </parent> + + <groupId>org.spark-project</groupId> + <artifactId>spark-core</artifactId> + <packaging>jar</packaging> + <name>Spark Project Core</name> + <url>http://spark-project.org/</url> + + <dependencies> + <dependency> + <groupId>org.eclipse.jetty</groupId> + <artifactId>jetty-server</artifactId> + </dependency> + <dependency> + <groupId>com.google.guava</groupId> + <artifactId>guava</artifactId> + </dependency> + <dependency> + <groupId>org.slf4j</groupId> + <artifactId>slf4j-api</artifactId> + </dependency> + <dependency> + <groupId>com.ning</groupId> + <artifactId>compress-lzf</artifactId> + </dependency> + <dependency> + <groupId>asm</groupId> + <artifactId>asm-all</artifactId> + </dependency> + <dependency> + <groupId>com.google.protobuf</groupId> + <artifactId>protobuf-java</artifactId> + </dependency> + <dependency> + <groupId>de.javakaffee</groupId> + <artifactId>kryo-serializers</artifactId> + </dependency> + <dependency> + <groupId>com.typesafe.akka</groupId> + <artifactId>akka-actor</artifactId> + </dependency> + <dependency> + <groupId>com.typesafe.akka</groupId> + <artifactId>akka-remote</artifactId> + </dependency> + <dependency> + <groupId>com.typesafe.akka</groupId> + <artifactId>akka-slf4j</artifactId> + </dependency> + <dependency> + <groupId>it.unimi.dsi</groupId> + <artifactId>fastutil</artifactId> + </dependency> + <dependency> + <groupId>colt</groupId> + <artifactId>colt</artifactId> + </dependency> + <dependency> + <groupId>cc.spray</groupId> + <artifactId>spray-can</artifactId> + </dependency> + <dependency> + <groupId>cc.spray</groupId> + <artifactId>spray-server</artifactId> + </dependency> + <dependency> + <groupId>org.tomdz.twirl</groupId> + <artifactId>twirl-api</artifactId> + </dependency> + <dependency> + <groupId>com.github.scala-incubator.io</groupId> + <artifactId>scala-io-file_${scala.version}</artifactId> + </dependency> + <dependency> + <groupId>org.apache.mesos</groupId> + <artifactId>mesos</artifactId> + </dependency> + + <dependency> + <groupId>org.scalatest</groupId> + <artifactId>scalatest_${scala.version}</artifactId> + <scope>test</scope> + </dependency> + <dependency> + <groupId>org.scalacheck</groupId> + <artifactId>scalacheck_${scala.version}</artifactId> + <scope>test</scope> + </dependency> + <dependency> + <groupId>com.novocode</groupId> + <artifactId>junit-interface</artifactId> + <scope>test</scope> + </dependency> + <dependency> + <groupId>org.slf4j</groupId> + <artifactId>slf4j-log4j12</artifactId> + <scope>test</scope> + </dependency> + </dependencies> + <build> + <outputDirectory>target/scala-${scala.version}/classes</outputDirectory> + <testOutputDirectory>target/scala-${scala.version}/test-classes</testOutputDirectory> + <plugins> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-antrun-plugin</artifactId> + <executions> + <execution> + <phase>test</phase> + <goals> + <goal>run</goal> + </goals> + <configuration> + 
<exportAntProperties>true</exportAntProperties> + <tasks> + <property name="spark.classpath" refid="maven.test.classpath"/> + <property environment="env"/> + <fail message="Please set the SCALA_HOME (or SCALA_LIBRARY_PATH if scala is on the path) environment variables and retry."> + <condition> + <not> + <or> + <isset property="env.SCALA_HOME"/> + <isset property="env.SCALA_LIBRARY_PATH"/> + </or> + </not> + </condition> + </fail> + </tasks> + </configuration> + </execution> + </executions> + </plugin> + <plugin> + <groupId>org.scalatest</groupId> + <artifactId>scalatest-maven-plugin</artifactId> + <configuration> + <environmentVariables> + <SPARK_HOME>${basedir}/..</SPARK_HOME> + <SPARK_TESTING>1</SPARK_TESTING> + <SPARK_CLASSPATH>${spark.classpath}</SPARK_CLASSPATH> + </environmentVariables> + </configuration> + </plugin> + <plugin> + <groupId>org.tomdz.twirl</groupId> + <artifactId>twirl-maven-plugin</artifactId> + </plugin> + </plugins> + </build> + + <profiles> + <profile> + <id>hadoop1</id> + <dependencies> + <dependency> + <groupId>org.apache.hadoop</groupId> + <artifactId>hadoop-core</artifactId> + <scope>provided</scope> + </dependency> + </dependencies> + <build> + <plugins> + <plugin> + <groupId>org.codehaus.mojo</groupId> + <artifactId>build-helper-maven-plugin</artifactId> + <executions> + <execution> + <id>add-source</id> + <phase>generate-sources</phase> + <goals> + <goal>add-source</goal> + </goals> + <configuration> + <sources> + <source>src/main/scala</source> + <source>src/hadoop1/scala</source> + </sources> + </configuration> + </execution> + <execution> + <id>add-scala-test-sources</id> + <phase>generate-test-sources</phase> + <goals> + <goal>add-test-source</goal> + </goals> + <configuration> + <sources> + <source>src/test/scala</source> + </sources> + </configuration> + </execution> + </executions> + </plugin> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-jar-plugin</artifactId> + <configuration> + <classifier>hadoop1</classifier> + </configuration> + </plugin> + </plugins> + </build> + </profile> + <profile> + <id>hadoop2</id> + <dependencies> + <dependency> + <groupId>org.apache.hadoop</groupId> + <artifactId>hadoop-core</artifactId> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>org.apache.hadoop</groupId> + <artifactId>hadoop-client</artifactId> + <scope>provided</scope> + </dependency> + </dependencies> + <build> + <plugins> + <plugin> + <groupId>org.codehaus.mojo</groupId> + <artifactId>build-helper-maven-plugin</artifactId> + <executions> + <execution> + <id>add-source</id> + <phase>generate-sources</phase> + <goals> + <goal>add-source</goal> + </goals> + <configuration> + <sources> + <source>src/main/scala</source> + <source>src/hadoop2/scala</source> + </sources> + </configuration> + </execution> + <execution> + <id>add-scala-test-sources</id> + <phase>generate-test-sources</phase> + <goals> + <goal>add-test-source</goal> + </goals> + <configuration> + <sources> + <source>src/test/scala</source> + </sources> + </configuration> + </execution> + </executions> + </plugin> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-jar-plugin</artifactId> + <configuration> + <classifier>hadoop2</classifier> + </configuration> + </plugin> + </plugins> + </build> + </profile> + </profiles> +</project>
\ No newline at end of file diff --git a/core/src/hadoop1/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala b/core/src/hadoop1/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala new file mode 100644 index 0000000000..ca9f7219de --- /dev/null +++ b/core/src/hadoop1/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala @@ -0,0 +1,7 @@ +package org.apache.hadoop.mapred + +trait HadoopMapRedUtil { + def newJobContext(conf: JobConf, jobId: JobID): JobContext = new JobContext(conf, jobId) + + def newTaskAttemptContext(conf: JobConf, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContext(conf, attemptId) +} diff --git a/core/src/hadoop1/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala b/core/src/hadoop1/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala new file mode 100644 index 0000000000..de7b0f81e3 --- /dev/null +++ b/core/src/hadoop1/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala @@ -0,0 +1,9 @@ +package org.apache.hadoop.mapreduce + +import org.apache.hadoop.conf.Configuration + +trait HadoopMapReduceUtil { + def newJobContext(conf: Configuration, jobId: JobID): JobContext = new JobContext(conf, jobId) + + def newTaskAttemptContext(conf: Configuration, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContext(conf, attemptId) +} diff --git a/core/src/hadoop2/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala b/core/src/hadoop2/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala new file mode 100644 index 0000000000..35300cea58 --- /dev/null +++ b/core/src/hadoop2/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala @@ -0,0 +1,7 @@ +package org.apache.hadoop.mapred + +trait HadoopMapRedUtil { + def newJobContext(conf: JobConf, jobId: JobID): JobContext = new JobContextImpl(conf, jobId) + + def newTaskAttemptContext(conf: JobConf, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContextImpl(conf, attemptId) +} diff --git a/core/src/hadoop2/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala b/core/src/hadoop2/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala new file mode 100644 index 0000000000..7afdbff320 --- /dev/null +++ b/core/src/hadoop2/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala @@ -0,0 +1,10 @@ +package org.apache.hadoop.mapreduce + +import org.apache.hadoop.conf.Configuration +import task.{TaskAttemptContextImpl, JobContextImpl} + +trait HadoopMapReduceUtil { + def newJobContext(conf: Configuration, jobId: JobID): JobContext = new JobContextImpl(conf, jobId) + + def newTaskAttemptContext(conf: Configuration, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContextImpl(conf, attemptId) +} diff --git a/core/src/main/scala/spark/Aggregator.scala b/core/src/main/scala/spark/Aggregator.scala index b0daa70cfd..df8ce9c054 100644 --- a/core/src/main/scala/spark/Aggregator.scala +++ b/core/src/main/scala/spark/Aggregator.scala @@ -1,17 +1,44 @@ package spark +import java.util.{HashMap => JHashMap} + +import scala.collection.JavaConversions._ + /** A set of functions used to aggregate data. * * @param createCombiner function to create the initial value of the aggregation. * @param mergeValue function to merge a new value into the aggregation result. * @param mergeCombiners function to merge outputs from multiple mergeValue function. - * @param mapSideCombine whether to apply combiners on map partitions, also - * known as map-side aggregations. When set to false, - * mergeCombiners function is not used. 
*/ case class Aggregator[K, V, C] ( val createCombiner: V => C, val mergeValue: (C, V) => C, - val mergeCombiners: (C, C) => C, - val mapSideCombine: Boolean = true) + val mergeCombiners: (C, C) => C) { + + def combineValuesByKey(iter: Iterator[(K, V)]) : Iterator[(K, C)] = { + val combiners = new JHashMap[K, C] + for ((k, v) <- iter) { + val oldC = combiners.get(k) + if (oldC == null) { + combiners.put(k, createCombiner(v)) + } else { + combiners.put(k, mergeValue(oldC, v)) + } + } + combiners.iterator + } + + def combineCombinersByKey(iter: Iterator[(K, C)]) : Iterator[(K, C)] = { + val combiners = new JHashMap[K, C] + for ((k, c) <- iter) { + val oldC = combiners.get(k) + if (oldC == null) { + combiners.put(k, c) + } else { + combiners.put(k, mergeCombiners(oldC, c)) + } + } + combiners.iterator + } +} diff --git a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala index 4554db2249..86432d0127 100644 --- a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala @@ -1,18 +1,12 @@ package spark -import java.io.EOFException -import java.net.URL - import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap -import spark.storage.BlockException import spark.storage.BlockManagerId -import it.unimi.dsi.fastutil.io.FastBufferedInputStream - private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging { - def fetch[K, V](shuffleId: Int, reduceId: Int, func: (K, V) => Unit) { + override def fetch[K, V](shuffleId: Int, reduceId: Int) = { logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId)) val blockManager = SparkEnv.get.blockManager @@ -31,14 +25,12 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin (address, splits.map(s => ("shuffle_%d_%d_%d".format(shuffleId, s._1, reduceId), s._2))) } - for ((blockId, blockOption) <- blockManager.getMultiple(blocksByAddress)) { + def unpackBlock(blockPair: (String, Option[Iterator[Any]])) : Iterator[(K, V)] = { + val blockId = blockPair._1 + val blockOption = blockPair._2 blockOption match { case Some(block) => { - val values = block - for(value <- values) { - val v = value.asInstanceOf[(K, V)] - func(v._1, v._2) - } + block.asInstanceOf[Iterator[(K, V)]] } case None => { val regex = "shuffle_([0-9]*)_([0-9]*)_([0-9]*)".r @@ -53,8 +45,6 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin } } } - - logDebug("Fetching and merging outputs of shuffle %d, reduce %d took %d ms".format( - shuffleId, reduceId, System.currentTimeMillis - startTime)) + blockManager.getMultiple(blocksByAddress).flatMap(unpackBlock) } } diff --git a/core/src/main/scala/spark/CacheTracker.scala b/core/src/main/scala/spark/CacheTracker.scala index c5db6ce63a..3d79078733 100644 --- a/core/src/main/scala/spark/CacheTracker.scala +++ b/core/src/main/scala/spark/CacheTracker.scala @@ -1,5 +1,9 @@ package spark +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap +import scala.collection.mutable.HashSet + import akka.actor._ import akka.dispatch._ import akka.pattern.ask @@ -8,10 +12,6 @@ import akka.util.Duration import akka.util.Timeout import akka.util.duration._ -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.collection.mutable.HashSet - import spark.storage.BlockManager import spark.storage.StorageLevel @@ -41,7 +41,7 @@ private[spark] class 
CacheTrackerActor extends Actor with Logging { private def getCacheUsage(host: String): Long = slaveUsage.getOrElse(host, 0L) private def getCacheCapacity(host: String): Long = slaveCapacity.getOrElse(host, 0L) private def getCacheAvailable(host: String): Long = getCacheCapacity(host) - getCacheUsage(host) - + def receive = { case SlaveCacheStarted(host: String, size: Long) => slaveCapacity.put(host, size) @@ -92,14 +92,14 @@ private[spark] class CacheTrackerActor extends Actor with Logging { private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, blockManager: BlockManager) extends Logging { - + // Tracker actor on the master, or remote reference to it on workers val ip: String = System.getProperty("spark.master.host", "localhost") val port: Int = System.getProperty("spark.master.port", "7077").toInt val actorName: String = "CacheTracker" val timeout = 10.seconds - + var trackerActor: ActorRef = if (isMaster) { val actor = actorSystem.actorOf(Props[CacheTrackerActor], name = actorName) logInfo("Registered CacheTrackerActor actor") @@ -132,7 +132,7 @@ private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, b throw new SparkException("Error reply received from CacheTracker") } } - + // Registers an RDD (on master only) def registerRDD(rddId: Int, numPartitions: Int) { registeredRddIds.synchronized { @@ -143,7 +143,7 @@ private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, b } } } - + // For BlockManager.scala only def cacheLost(host: String) { communicate(MemoryCacheLost(host)) @@ -155,19 +155,20 @@ private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, b def getCacheStatus(): Seq[(String, Long, Long)] = { askTracker(GetCacheStatus).asInstanceOf[Seq[(String, Long, Long)]] } - + // For BlockManager.scala only def notifyFromBlockManager(t: AddedToCache) { communicate(t) } - + // Get a snapshot of the currently known locations def getLocationsSnapshot(): HashMap[Int, Array[List[String]]] = { askTracker(GetCacheLocations).asInstanceOf[HashMap[Int, Array[List[String]]]] } - + // Gets or computes an RDD split - def getOrCompute[T](rdd: RDD[T], split: Split, storageLevel: StorageLevel): Iterator[T] = { + def getOrCompute[T](rdd: RDD[T], split: Split, context: TaskContext, storageLevel: StorageLevel) + : Iterator[T] = { val key = "rdd_%d_%d".format(rdd.id, split.index) logInfo("Cache key is " + key) blockManager.get(key) match { @@ -209,7 +210,7 @@ private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, b // TODO: also register a listener for when it unloads logInfo("Computing partition " + split) val elements = new ArrayBuffer[Any] - elements ++= rdd.compute(split) + elements ++= rdd.compute(split, context) try { // Try to put this block in the blockManager blockManager.put(key, elements, storageLevel, true) diff --git a/core/src/main/scala/spark/Dependency.scala b/core/src/main/scala/spark/Dependency.scala index dfc7e292b7..b85d2732db 100644 --- a/core/src/main/scala/spark/Dependency.scala +++ b/core/src/main/scala/spark/Dependency.scala @@ -22,12 +22,10 @@ abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) { * Represents a dependency on the output of a shuffle stage. 
* @param shuffleId the shuffle id * @param rdd the parent RDD - * @param aggregator optional aggregator; this allows for map-side combining * @param partitioner partitioner used to partition the shuffle output */ -class ShuffleDependency[K, V, C]( +class ShuffleDependency[K, V]( @transient rdd: RDD[(K, V)], - val aggregator: Option[Aggregator[K, V, C]], val partitioner: Partitioner) extends Dependency(rdd) { diff --git a/core/src/main/scala/spark/HadoopWriter.scala b/core/src/main/scala/spark/HadoopWriter.scala index ffe0f3c4a1..afcf9f6db4 100644 --- a/core/src/main/scala/spark/HadoopWriter.scala +++ b/core/src/main/scala/spark/HadoopWriter.scala @@ -23,7 +23,7 @@ import spark.SerializableWritable * Saves the RDD using a JobConf, which should contain an output key class, an output value class, * a filename to write to, etc, exactly like in a Hadoop MapReduce job. */ -class HadoopWriter(@transient jobConf: JobConf) extends Logging with Serializable { +class HadoopWriter(@transient jobConf: JobConf) extends Logging with HadoopMapRedUtil with Serializable { private val now = new Date() private val conf = new SerializableWritable(jobConf) @@ -129,14 +129,14 @@ class HadoopWriter(@transient jobConf: JobConf) extends Logging with Serializabl private def getJobContext(): JobContext = { if (jobContext == null) { - jobContext = new JobContext(conf.value, jID.value) + jobContext = newJobContext(conf.value, jID.value) } return jobContext } private def getTaskContext(): TaskAttemptContext = { if (taskContext == null) { - taskContext = new TaskAttemptContext(conf.value, taID.value) + taskContext = newTaskAttemptContext(conf.value, taID.value) } return taskContext } diff --git a/core/src/main/scala/spark/KryoSerializer.scala b/core/src/main/scala/spark/KryoSerializer.scala index 44b630e478..93d7327324 100644 --- a/core/src/main/scala/spark/KryoSerializer.scala +++ b/core/src/main/scala/spark/KryoSerializer.scala @@ -9,153 +9,80 @@ import scala.collection.mutable import com.esotericsoftware.kryo._ import com.esotericsoftware.kryo.{Serializer => KSerializer} -import com.esotericsoftware.kryo.serialize.ClassSerializer -import com.esotericsoftware.kryo.serialize.SerializableSerializer +import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput} +import com.esotericsoftware.kryo.serializers.{JavaSerializer => KryoJavaSerializer} import de.javakaffee.kryoserializers.KryoReflectionFactorySupport import serializer.{SerializerInstance, DeserializationStream, SerializationStream} import spark.broadcast._ import spark.storage._ -/** - * Zig-zag encoder used to write object sizes to serialization streams. - * Based on Kryo's integer encoder. 
- */ -private[spark] object ZigZag { - def writeInt(n: Int, out: OutputStream) { - var value = n - if ((value & ~0x7F) == 0) { - out.write(value) - return - } - out.write(((value & 0x7F) | 0x80)) - value >>>= 7 - if ((value & ~0x7F) == 0) { - out.write(value) - return - } - out.write(((value & 0x7F) | 0x80)) - value >>>= 7 - if ((value & ~0x7F) == 0) { - out.write(value) - return - } - out.write(((value & 0x7F) | 0x80)) - value >>>= 7 - if ((value & ~0x7F) == 0) { - out.write(value) - return - } - out.write(((value & 0x7F) | 0x80)) - value >>>= 7 - out.write(value) - } +private[spark] +class KryoSerializationStream(kryo: Kryo, outStream: OutputStream) extends SerializationStream { - def readInt(in: InputStream): Int = { - var offset = 0 - var result = 0 - while (offset < 32) { - val b = in.read() - if (b == -1) { - throw new EOFException("End of stream") - } - result |= ((b & 0x7F) << offset) - if ((b & 0x80) == 0) { - return result - } - offset += 7 - } - throw new SparkException("Malformed zigzag-encoded integer") - } -} - -private[spark] -class KryoSerializationStream(kryo: Kryo, threadBuffer: ByteBuffer, out: OutputStream) -extends SerializationStream { - val channel = Channels.newChannel(out) + val output = new KryoOutput(outStream) def writeObject[T](t: T): SerializationStream = { - kryo.writeClassAndObject(threadBuffer, t) - ZigZag.writeInt(threadBuffer.position(), out) - threadBuffer.flip() - channel.write(threadBuffer) - threadBuffer.clear() + kryo.writeClassAndObject(output, t) this } - def flush() { out.flush() } - def close() { out.close() } + def flush() { output.flush() } + def close() { output.close() } } -private[spark] -class KryoDeserializationStream(objectBuffer: ObjectBuffer, in: InputStream) -extends DeserializationStream { +private[spark] +class KryoDeserializationStream(kryo: Kryo, inStream: InputStream) extends DeserializationStream { + + val input = new KryoInput(inStream) + def readObject[T](): T = { - val len = ZigZag.readInt(in) - objectBuffer.readClassAndObject(in, len).asInstanceOf[T] + try { + kryo.readClassAndObject(input).asInstanceOf[T] + } catch { + // DeserializationStream uses the EOF exception to indicate stopping condition. + case e: com.esotericsoftware.kryo.KryoException => throw new java.io.EOFException + } } - def close() { in.close() } + def close() { + // Kryo's Input automatically closes the input stream it is using. 
+ input.close() + } } private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends SerializerInstance { - val kryo = ks.kryo - val threadBuffer = ks.threadBuffer.get() - val objectBuffer = ks.objectBuffer.get() + + val kryo = ks.kryo.get() + val output = ks.output.get() + val input = ks.input.get() def serialize[T](t: T): ByteBuffer = { - // Write it to our thread-local scratch buffer first to figure out the size, then return a new - // ByteBuffer of the appropriate size - threadBuffer.clear() - kryo.writeClassAndObject(threadBuffer, t) - val newBuf = ByteBuffer.allocate(threadBuffer.position) - threadBuffer.flip() - newBuf.put(threadBuffer) - newBuf.flip() - newBuf + output.clear() + kryo.writeClassAndObject(output, t) + ByteBuffer.wrap(output.toBytes) } def deserialize[T](bytes: ByteBuffer): T = { - kryo.readClassAndObject(bytes).asInstanceOf[T] + input.setBuffer(bytes.array) + kryo.readClassAndObject(input).asInstanceOf[T] } def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = { val oldClassLoader = kryo.getClassLoader kryo.setClassLoader(loader) - val obj = kryo.readClassAndObject(bytes).asInstanceOf[T] + input.setBuffer(bytes.array) + val obj = kryo.readClassAndObject(input).asInstanceOf[T] kryo.setClassLoader(oldClassLoader) obj } def serializeStream(s: OutputStream): SerializationStream = { - threadBuffer.clear() - new KryoSerializationStream(kryo, threadBuffer, s) + new KryoSerializationStream(kryo, s) } def deserializeStream(s: InputStream): DeserializationStream = { - new KryoDeserializationStream(objectBuffer, s) - } - - override def serializeMany[T](iterator: Iterator[T]): ByteBuffer = { - threadBuffer.clear() - while (iterator.hasNext) { - val element = iterator.next() - // TODO: Do we also want to write the object's size? Doesn't seem necessary. - kryo.writeClassAndObject(threadBuffer, element) - } - val newBuf = ByteBuffer.allocate(threadBuffer.position) - threadBuffer.flip() - newBuf.put(threadBuffer) - newBuf.flip() - newBuf - } - - override def deserializeMany(buffer: ByteBuffer): Iterator[Any] = { - buffer.rewind() - new Iterator[Any] { - override def hasNext: Boolean = buffer.remaining > 0 - override def next(): Any = kryo.readClassAndObject(buffer) - } + new KryoDeserializationStream(kryo, s) } } @@ -171,18 +98,19 @@ trait KryoRegistrator { * A Spark serializer that uses the [[http://code.google.com/p/kryo/wiki/V1Documentation Kryo 1.x library]]. */ class KryoSerializer extends spark.serializer.Serializer with Logging { - // Make this lazy so that it only gets called once we receive our first task on each executor, - // so we can pull out any custom Kryo registrator from the user's JARs. 
- lazy val kryo = createKryo() - val bufferSize = System.getProperty("spark.kryoserializer.buffer.mb", "32").toInt * 1024 * 1024 + val bufferSize = System.getProperty("spark.kryoserializer.buffer.mb", "2").toInt * 1024 * 1024 - val objectBuffer = new ThreadLocal[ObjectBuffer] { - override def initialValue = new ObjectBuffer(kryo, bufferSize) + val kryo = new ThreadLocal[Kryo] { + override def initialValue = createKryo() } - val threadBuffer = new ThreadLocal[ByteBuffer] { - override def initialValue = ByteBuffer.allocate(bufferSize) + val output = new ThreadLocal[KryoOutput] { + override def initialValue = new KryoOutput(bufferSize) + } + + val input = new ThreadLocal[KryoInput] { + override def initialValue = new KryoInput(bufferSize) } def createKryo(): Kryo = { @@ -213,41 +141,44 @@ class KryoSerializer extends spark.serializer.Serializer with Logging { kryo.register(obj.getClass) } - // Register the following classes for passing closures. - kryo.register(classOf[Class[_]], new ClassSerializer(kryo)) - kryo.setRegistrationOptional(true) - // Allow sending SerializableWritable - kryo.register(classOf[SerializableWritable[_]], new SerializableSerializer()) - kryo.register(classOf[HttpBroadcast[_]], new SerializableSerializer()) + kryo.register(classOf[SerializableWritable[_]], new KryoJavaSerializer()) + kryo.register(classOf[HttpBroadcast[_]], new KryoJavaSerializer()) // Register some commonly used Scala singleton objects. Because these // are singletons, we must return the exact same local object when we // deserialize rather than returning a clone as FieldSerializer would. - class SingletonSerializer(obj: AnyRef) extends KSerializer { - override def writeObjectData(buf: ByteBuffer, obj: AnyRef) {} - override def readObjectData[T](buf: ByteBuffer, cls: Class[T]): T = obj.asInstanceOf[T] + class SingletonSerializer[T](obj: T) extends KSerializer[T] { + override def write(kryo: Kryo, output: KryoOutput, obj: T) {} + override def read(kryo: Kryo, input: KryoInput, cls: java.lang.Class[T]): T = obj } - kryo.register(None.getClass, new SingletonSerializer(None)) - kryo.register(Nil.getClass, new SingletonSerializer(Nil)) + kryo.register(None.getClass, new SingletonSerializer[AnyRef](None)) + kryo.register(Nil.getClass, new SingletonSerializer[AnyRef](Nil)) // Register maps with a special serializer since they have complex internal structure class ScalaMapSerializer(buildMap: Array[(Any, Any)] => scala.collection.Map[Any, Any]) - extends KSerializer { - override def writeObjectData(buf: ByteBuffer, obj: AnyRef) { + extends KSerializer[Array[(Any, Any)] => scala.collection.Map[Any, Any]] { + override def write( + kryo: Kryo, + output: KryoOutput, + obj: Array[(Any, Any)] => scala.collection.Map[Any, Any]) { val map = obj.asInstanceOf[scala.collection.Map[Any, Any]] - kryo.writeObject(buf, map.size.asInstanceOf[java.lang.Integer]) + kryo.writeObject(output, map.size.asInstanceOf[java.lang.Integer]) for ((k, v) <- map) { - kryo.writeClassAndObject(buf, k) - kryo.writeClassAndObject(buf, v) + kryo.writeClassAndObject(output, k) + kryo.writeClassAndObject(output, v) } } - override def readObjectData[T](buf: ByteBuffer, cls: Class[T]): T = { - val size = kryo.readObject(buf, classOf[java.lang.Integer]).intValue + override def read ( + kryo: Kryo, + input: KryoInput, + cls: Class[Array[(Any, Any)] => scala.collection.Map[Any, Any]]) + : Array[(Any, Any)] => scala.collection.Map[Any, Any] = { + val size = kryo.readObject(input, classOf[java.lang.Integer]).intValue val elems = new Array[(Any, 
Any)](size) for (i <- 0 until size) - elems(i) = (kryo.readClassAndObject(buf), kryo.readClassAndObject(buf)) - buildMap(elems).asInstanceOf[T] + elems(i) = (kryo.readClassAndObject(input), kryo.readClassAndObject(input)) + buildMap(elems).asInstanceOf[Array[(Any, Any)] => scala.collection.Map[Any, Any]] } } kryo.register(mutable.HashMap().getClass, new ScalaMapSerializer(mutable.HashMap() ++ _)) diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala index 45441aa5e5..70eb9f702e 100644 --- a/core/src/main/scala/spark/MapOutputTracker.scala +++ b/core/src/main/scala/spark/MapOutputTracker.scala @@ -2,6 +2,10 @@ package spark import java.io._ import java.util.concurrent.ConcurrentHashMap +import java.util.zip.{GZIPInputStream, GZIPOutputStream} + +import scala.collection.mutable.HashMap +import scala.collection.mutable.HashSet import akka.actor._ import akka.dispatch._ @@ -11,16 +15,13 @@ import akka.util.Duration import akka.util.Timeout import akka.util.duration._ -import scala.collection.mutable.HashMap -import scala.collection.mutable.HashSet - -import scheduler.MapStatus +import spark.scheduler.MapStatus import spark.storage.BlockManagerId -import java.util.zip.{GZIPInputStream, GZIPOutputStream} + private[spark] sealed trait MapOutputTrackerMessage private[spark] case class GetMapOutputStatuses(shuffleId: Int, requester: String) - extends MapOutputTrackerMessage + extends MapOutputTrackerMessage private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage private[spark] class MapOutputTrackerActor(tracker: MapOutputTracker) extends Actor with Logging { @@ -88,14 +89,14 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea } mapStatuses.put(shuffleId, new Array[MapStatus](numMaps)) } - + def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) { var array = mapStatuses.get(shuffleId) array.synchronized { array(mapId) = status } } - + def registerMapOutputs( shuffleId: Int, statuses: Array[MapStatus], @@ -110,7 +111,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea var array = mapStatuses.get(shuffleId) if (array != null) { array.synchronized { - if (array(mapId).address == bmAddress) { + if (array(mapId) != null && array(mapId).address == bmAddress) { array(mapId) = null } } @@ -119,10 +120,10 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID") } } - + // Remembers which map output locations are currently being fetched on a worker val fetching = new HashSet[Int] - + // Called on possibly remote nodes to get the server URIs and output sizes for a given shuffle def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = { val statuses = mapStatuses.get(shuffleId) @@ -147,14 +148,23 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea // We won the race to fetch the output locs; do so logInfo("Doing the fetch; tracker actor = " + trackerActor) val host = System.getProperty("spark.hostname", Utils.localHostName) - val fetchedBytes = askTracker(GetMapOutputStatuses(shuffleId, host)).asInstanceOf[Array[Byte]] - val fetchedStatuses = deserializeStatuses(fetchedBytes) - - logInfo("Got the output locations") - mapStatuses.put(shuffleId, fetchedStatuses) - fetching.synchronized { - fetching -= shuffleId - fetching.notifyAll() + // This try-finally prevents hangs due to 
timeouts: + var fetchedStatuses: Array[MapStatus] = null + try { + val fetchedBytes = + askTracker(GetMapOutputStatuses(shuffleId, host)).asInstanceOf[Array[Byte]] + fetchedStatuses = deserializeStatuses(fetchedBytes) + logInfo("Got the output locations") + mapStatuses.put(shuffleId, fetchedStatuses) + if (fetchedStatuses.contains(null)) { + throw new FetchFailedException(null, shuffleId, -1, reduceId, + new Exception("Missing an output location for shuffle " + shuffleId)) + } + } finally { + fetching.synchronized { + fetching -= shuffleId + fetching.notifyAll() + } } return fetchedStatuses.map(s => (s.address, MapOutputTracker.decompressSize(s.compressedSizes(reduceId)))) @@ -254,8 +264,10 @@ private[spark] object MapOutputTracker { * sizes up to 35 GB with at most 10% error. */ def compressSize(size: Long): Byte = { - if (size <= 1L) { + if (size == 0) { 0 + } else if (size <= 1L) { + 1 } else { math.min(255, math.ceil(math.log(size) / math.log(LOG_BASE)).toInt).toByte } @@ -266,7 +278,7 @@ private[spark] object MapOutputTracker { */ def decompressSize(compressedSize: Byte): Long = { if (compressedSize == 0) { - 1 + 0 } else { math.pow(LOG_BASE, (compressedSize & 0xFF)).toLong } diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index 0240fd95c7..d3e206b353 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -1,10 +1,6 @@ package spark -import java.io.EOFException -import java.io.ObjectInputStream -import java.net.URL import java.util.{Date, HashMap => JHashMap} -import java.util.concurrent.atomic.AtomicLong import java.text.SimpleDateFormat import scala.collection.Map @@ -14,25 +10,14 @@ import scala.collection.JavaConversions._ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.hadoop.io.BytesWritable -import org.apache.hadoop.io.NullWritable -import org.apache.hadoop.io.Text -import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred.FileOutputCommitter import org.apache.hadoop.mapred.FileOutputFormat import org.apache.hadoop.mapred.HadoopWriter import org.apache.hadoop.mapred.JobConf -import org.apache.hadoop.mapred.OutputCommitter import org.apache.hadoop.mapred.OutputFormat -import org.apache.hadoop.mapred.SequenceFileOutputFormat -import org.apache.hadoop.mapred.TextOutputFormat import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat} -import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat} -import org.apache.hadoop.mapreduce.{RecordWriter => NewRecordWriter} -import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob} -import org.apache.hadoop.mapreduce.TaskAttemptID -import org.apache.hadoop.mapreduce.TaskAttemptContext +import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter, Job => NewAPIHadoopJob, HadoopMapReduceUtil, TaskAttemptID, TaskAttemptContext} import spark.partial.BoundedDouble import spark.partial.PartialResult @@ -46,14 +31,15 @@ import spark.SparkContext._ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( self: RDD[(K, V)]) extends Logging + with HadoopMapReduceUtil with Serializable { /** - * Generic function to combine the elements for each key using a custom set of aggregation + * Generic function to combine the elements for each key using a custom set of aggregation * functions. 
Turns an RDD[(K, V)] into a result of type RDD[(K, C)], for a "combined type" C * Note that V and C can be different -- for example, one might group an RDD of type * (Int, Int) into an RDD of type (Int, Seq[Int]). Users provide three functions: - * + * * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) * - `mergeCombiners`, to combine two C's into a single one. @@ -67,15 +53,18 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( partitioner: Partitioner, mapSideCombine: Boolean = true): RDD[(K, C)] = { val aggregator = - if (mapSideCombine) { - new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners) - } else { - // Don't apply map-side combiner. - // A sanity check to make sure mergeCombiners is not defined. - assert(mergeCombiners == null) - new Aggregator[K, V, C](createCombiner, mergeValue, null, false) - } - new ShuffledAggregatedRDD(self, aggregator, partitioner) + new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners) + if (mapSideCombine) { + val mapSideCombined = self.mapPartitions(aggregator.combineValuesByKey(_), true) + val partitioned = new ShuffledRDD[K, C](mapSideCombined, partitioner) + partitioned.mapPartitions(aggregator.combineCombinersByKey(_), true) + } else { + // Don't apply map-side combiner. + // A sanity check to make sure mergeCombiners is not defined. + assert(mergeCombiners == null) + val values = new ShuffledRDD[K, V](self, partitioner) + values.mapPartitions(aggregator.combineValuesByKey(_), true) + } } /** @@ -129,7 +118,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( /** Count the number of elements for each key, and return the result to the master as a Map. */ def countByKey(): Map[K, Long] = self.map(_._1).countByValue() - /** + /** * (Experimental) Approximate version of countByKey that can return a partial result if it does * not finish within a timeout. */ @@ -184,7 +173,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( createCombiner _, mergeValue _, mergeCombiners _, partitioner) bufs.flatMapValues(buf => buf) } else { - new RepartitionShuffledRDD(self, partitioner) + new ShuffledRDD[K, V](self, partitioner) } } @@ -235,7 +224,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( } } - /** + /** * Simplified version of combineByKey that hash-partitions the resulting RDD using the default * parallelism level. */ @@ -449,7 +438,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( val res = self.context.runJob(self, process _, Array(index), false) res(0) case None => - throw new UnsupportedOperationException("lookup() called on an RDD without a partitioner") + self.filter(_._1 == key).map(_._2).collect } } @@ -506,7 +495,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( /* "reduce task" <split #> <attempt # = spark task #> */ val attemptId = new TaskAttemptID(jobtrackerID, stageId, false, context.splitId, attemptNumber) - val hadoopContext = new TaskAttemptContext(wrappedConf.value, attemptId) + val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId) val format = outputFormatClass.newInstance val committer = format.getOutputCommitter(hadoopContext) committer.setupTask(hadoopContext) @@ -525,7 +514,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( * setupJob/commitJob, so we just use a dummy "map" task. 
*/ val jobAttemptId = new TaskAttemptID(jobtrackerID, stageId, true, 0, 0) - val jobTaskContext = new TaskAttemptContext(wrappedConf.value, jobAttemptId) + val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId) val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext) jobCommitter.setupJob(jobTaskContext) val count = self.context.runJob(self, writeShard _).sum @@ -621,7 +610,16 @@ class OrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest]( * order of the keys). */ def sortByKey(ascending: Boolean = true, numSplits: Int = self.splits.size): RDD[(K,V)] = { - new ShuffledSortedRDD(self, ascending, numSplits) + val shuffled = + new ShuffledRDD[K, V](self, new RangePartitioner(numSplits, self, ascending)) + shuffled.mapPartitions(iter => { + val buf = iter.toArray + if (ascending) { + buf.sortWith((x, y) => x._1 < y._1).iterator + } else { + buf.sortWith((x, y) => x._1 > y._1).iterator + } + }, true) } } @@ -630,7 +628,8 @@ class MappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => U) extends RDD[(K, U)] override def splits = prev.splits override val dependencies = List(new OneToOneDependency(prev)) override val partitioner = prev.partitioner - override def compute(split: Split) = prev.iterator(split).map{case (k, v) => (k, f(v))} + override def compute(split: Split, taskContext: TaskContext) = + prev.iterator(split, taskContext).map{case (k, v) => (k, f(v))} } private[spark] @@ -641,8 +640,8 @@ class FlatMappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => TraversableOnce[U] override val dependencies = List(new OneToOneDependency(prev)) override val partitioner = prev.partitioner - override def compute(split: Split) = { - prev.iterator(split).flatMap { case (k, v) => f(v).map(x => (k, x)) } + override def compute(split: Split, taskContext: TaskContext) = { + prev.iterator(split, taskContext).flatMap { case (k, v) => f(v).map(x => (k, x)) } } } diff --git a/core/src/main/scala/spark/ParallelCollection.scala b/core/src/main/scala/spark/ParallelCollection.scala index 9b57ae3b4f..a27f766e31 100644 --- a/core/src/main/scala/spark/ParallelCollection.scala +++ b/core/src/main/scala/spark/ParallelCollection.scala @@ -8,8 +8,8 @@ private[spark] class ParallelCollectionSplit[T: ClassManifest]( val slice: Int, values: Seq[T]) extends Split with Serializable { - - def iterator(): Iterator[T] = values.iterator + + def iterator: Iterator[T] = values.iterator override def hashCode(): Int = (41 * (41 + rddId) + slice).toInt @@ -22,7 +22,7 @@ private[spark] class ParallelCollectionSplit[T: ClassManifest]( } private[spark] class ParallelCollection[T: ClassManifest]( - sc: SparkContext, + sc: SparkContext, @transient data: Seq[T], numSlices: Int) extends RDD[T](sc) { @@ -38,17 +38,18 @@ private[spark] class ParallelCollection[T: ClassManifest]( override def splits = splits_.asInstanceOf[Array[Split]] - override def compute(s: Split) = s.asInstanceOf[ParallelCollectionSplit[T]].iterator - + override def compute(s: Split, taskContext: TaskContext) = + s.asInstanceOf[ParallelCollectionSplit[T]].iterator + override def preferredLocations(s: Split): Seq[String] = Nil - + override val dependencies: List[Dependency[_]] = Nil } private object ParallelCollection { /** * Slice a collection into numSlices sub-collections. One extra thing we do here is to treat Range - * collections specially, encoding the slices as other Ranges to minimize memory cost. This makes + * collections specially, encoding the slices as other Ranges to minimize memory cost. 
This makes * it efficient to run Spark over RDDs representing large sets of numbers. */ def slice[T: ClassManifest](seq: Seq[T], numSlices: Int): Seq[Seq[T]] = { @@ -58,7 +59,7 @@ private object ParallelCollection { seq match { case r: Range.Inclusive => { val sign = if (r.step < 0) { - -1 + -1 } else { 1 } diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index ddb420efff..d15c6f7396 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -1,17 +1,17 @@ package spark import java.io.EOFException -import java.net.URL import java.io.ObjectInputStream -import java.util.concurrent.atomic.AtomicLong +import java.net.URL import java.util.Random import java.util.Date import java.util.{HashMap => JHashMap} +import java.util.concurrent.atomic.AtomicLong -import scala.collection.mutable.ArrayBuffer import scala.collection.Map -import scala.collection.mutable.HashMap import scala.collection.JavaConversions.mapAsScalaMap +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap import org.apache.hadoop.io.BytesWritable import org.apache.hadoop.io.NullWritable @@ -42,12 +42,13 @@ import spark.rdd.MapPartitionsWithSplitRDD import spark.rdd.PipedRDD import spark.rdd.SampledRDD import spark.rdd.UnionRDD +import spark.rdd.ZippedRDD import spark.storage.StorageLevel import SparkContext._ /** - * A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Represents an immutable, + * A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Represents an immutable, * partitioned collection of elements that can be operated on in parallel. This class contains the * basic operations available on all RDDs, such as `map`, `filter`, and `persist`. In addition, * [[spark.PairRDDFunctions]] contains operations available only on RDDs of key-value pairs, such @@ -80,34 +81,34 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial def splits: Array[Split] /** Function for computing a given partition. */ - def compute(split: Split): Iterator[T] + def compute(split: Split, context: TaskContext): Iterator[T] /** How this RDD depends on any parent RDDs. */ @transient val dependencies: List[Dependency[_]] // Methods available on all RDDs: - + /** Record user function generating this RDD. */ private[spark] val origin = Utils.getSparkCallSite - + /** Optionally overridden by subclasses to specify how they are partitioned. */ val partitioner: Option[Partitioner] = None /** Optionally overridden by subclasses to specify placement preferences. */ def preferredLocations(split: Split): Seq[String] = Nil - + /** The [[spark.SparkContext]] that this RDD was created on. */ def context = sc private[spark] def elementClassManifest: ClassManifest[T] = classManifest[T] - + /** A unique ID for this RDD (within its SparkContext). */ val id = sc.newRddId() - + // Variables relating to persistence private var storageLevel: StorageLevel = StorageLevel.NONE - - /** + + /** * Set this RDD's storage level to persist its values across operations after the first time * it is computed. Can only be called once on each RDD. */ @@ -123,47 +124,47 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */ def persist(): RDD[T] = persist(StorageLevel.MEMORY_ONLY) - + /** Persist this RDD with the default storage level (`MEMORY_ONLY`). 
*/ def cache(): RDD[T] = persist() /** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */ def getStorageLevel = storageLevel - + private[spark] def checkpoint(level: StorageLevel = StorageLevel.MEMORY_AND_DISK_2): RDD[T] = { if (!level.useDisk && level.replication < 2) { throw new Exception("Cannot checkpoint without using disk or replication (level requested was " + level + ")") - } - + } + // This is a hack. Ideally this should re-use the code used by the CacheTracker // to generate the key. def getSplitKey(split: Split) = "rdd_%d_%d".format(this.id, split.index) - + persist(level) sc.runJob(this, (iter: Iterator[T]) => {} ) - + val p = this.partitioner - + new BlockRDD[T](sc, splits.map(getSplitKey).toArray) { - override val partitioner = p + override val partitioner = p } } - + /** * Internal method to this RDD; will read from cache if applicable, or otherwise compute it. * This should ''not'' be called by users directly, but is available for implementors of custom * subclasses of RDD. */ - final def iterator(split: Split): Iterator[T] = { + final def iterator(split: Split, context: TaskContext): Iterator[T] = { if (storageLevel != StorageLevel.NONE) { - SparkEnv.get.cacheTracker.getOrCompute[T](this, split, storageLevel) + SparkEnv.get.cacheTracker.getOrCompute[T](this, split, context, storageLevel) } else { - compute(split) + compute(split, context) } } - + // Transformations (return a new RDD) - + /** * Return a new RDD by applying a function to all elements of this RDD. */ @@ -184,9 +185,11 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial /** * Return a new RDD containing the distinct elements in this RDD. */ - def distinct(numSplits: Int = splits.size): RDD[T] = + def distinct(numSplits: Int): RDD[T] = map(x => (x, null)).reduceByKey((x, y) => x, numSplits).map(_._1) + def distinct(): RDD[T] = distinct(splits.size) + /** * Return a sampled subset of this RDD. */ @@ -199,13 +202,13 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial var multiplier = 3.0 var initialCount = count() var maxSelected = 0 - + if (initialCount > Integer.MAX_VALUE - 1) { maxSelected = Integer.MAX_VALUE - 1 } else { maxSelected = initialCount.toInt } - + if (num > initialCount) { total = maxSelected fraction = math.min(multiplier * (maxSelected + 1) / initialCount, 1.0) @@ -215,14 +218,14 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial fraction = math.min(multiplier * (num + 1) / initialCount, 1.0) total = num } - + val rand = new Random(seed) var samples = this.sample(withReplacement, fraction, rand.nextInt).collect() - + while (samples.length < total) { samples = this.sample(withReplacement, fraction, rand.nextInt).collect() } - + Utils.randomizeInPlace(samples, rand).take(total) } @@ -282,15 +285,26 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial /** * Return a new RDD by applying a function to each partition of this RDD. */ - def mapPartitions[U: ClassManifest](f: Iterator[T] => Iterator[U]): RDD[U] = - new MapPartitionsRDD(this, sc.clean(f)) + def mapPartitions[U: ClassManifest](f: Iterator[T] => Iterator[U], + preservesPartitioning: Boolean = false): RDD[U] = + new MapPartitionsRDD(this, sc.clean(f), preservesPartitioning) /** * Return a new RDD by applying a function to each partition of this RDD, while tracking the index * of the original partition. 
*/ - def mapPartitionsWithSplit[U: ClassManifest](f: (Int, Iterator[T]) => Iterator[U]): RDD[U] = - new MapPartitionsWithSplitRDD(this, sc.clean(f)) + def mapPartitionsWithSplit[U: ClassManifest]( + f: (Int, Iterator[T]) => Iterator[U], + preservesPartitioning: Boolean = false): RDD[U] = + new MapPartitionsWithSplitRDD(this, sc.clean(f), preservesPartitioning) + + /** + * Zips this RDD with another one, returning key-value pairs with the first element in each RDD, + * second element in each RDD, etc. Assumes that the two RDDs have the *same number of + * partitions* and the *same number of elements in each partition* (e.g. one was made through + * a map on the other). + */ + def zip[U: ClassManifest](other: RDD[U]): RDD[(T, U)] = new ZippedRDD(sc, this, other) // Actions (launch a job to return a value to the user program) @@ -341,7 +355,7 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial /** * Aggregate the elements of each partition, and then the results for all the partitions, using a - * given associative function and a neutral "zero value". The function op(t1, t2) is allowed to + * given associative function and a neutral "zero value". The function op(t1, t2) is allowed to * modify t1 and return it as its result value to avoid object allocation; however, it should not * modify t2. */ @@ -442,7 +456,7 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial val evaluator = new GroupedCountEvaluator[T](splits.size, confidence) sc.runApproximateJob(this, countPartition, evaluator, timeout) } - + /** * Take the first num elements of the RDD. This currently scans the partitions *one by one*, so * it will be slow if a lot of partitions are required. In that case, use collect() to get the diff --git a/core/src/main/scala/spark/ShuffleFetcher.scala b/core/src/main/scala/spark/ShuffleFetcher.scala index daa35fe7f2..d9a94d4021 100644 --- a/core/src/main/scala/spark/ShuffleFetcher.scala +++ b/core/src/main/scala/spark/ShuffleFetcher.scala @@ -1,10 +1,12 @@ package spark private[spark] abstract class ShuffleFetcher { - // Fetch the shuffle outputs for a given ShuffleDependency, calling func exactly - // once on each key-value pair obtained. - def fetch[K, V](shuffleId: Int, reduceId: Int, func: (K, V) => Unit) + /** + * Fetch the shuffle outputs for a given ShuffleDependency. + * @return An iterator over the elements of the fetched shuffle outputs. + */ + def fetch[K, V](shuffleId: Int, reduceId: Int) : Iterator[(K, V)] - // Stop the fetcher + /** Stop the fetcher */ def stop() {} } diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index becf737597..4fd81bc63b 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -45,7 +45,6 @@ import spark.scheduler.TaskScheduler import spark.scheduler.local.LocalScheduler import spark.scheduler.cluster.{SparkDeploySchedulerBackend, SchedulerBackend, ClusterScheduler} import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} -import spark.storage.BlockManagerMaster /** * Main entry point for Spark functionality. 
A SparkContext represents the connection to a Spark @@ -87,7 +86,7 @@ class SparkContext( // Set Spark master host and port system properties if (System.getProperty("spark.master.host") == null) { - System.setProperty("spark.master.host", Utils.localIpAddress()) + System.setProperty("spark.master.host", Utils.localIpAddress) } if (System.getProperty("spark.master.port") == null) { System.setProperty("spark.master.port", "0") @@ -174,10 +173,11 @@ class SparkContext( MesosNativeLibrary.load() val scheduler = new ClusterScheduler(this) val coarseGrained = System.getProperty("spark.mesos.coarse", "false").toBoolean + val masterWithoutProtocol = master.replaceFirst("^mesos://", "") // Strip initial mesos:// val backend = if (coarseGrained) { - new CoarseMesosSchedulerBackend(scheduler, this, master, jobName) + new CoarseMesosSchedulerBackend(scheduler, this, masterWithoutProtocol, jobName) } else { - new MesosSchedulerBackend(scheduler, this, master, jobName) + new MesosSchedulerBackend(scheduler, this, masterWithoutProtocol, jobName) } scheduler.initialize(backend) scheduler @@ -199,7 +199,7 @@ class SparkContext( parallelize(seq, numSlices) } - /** + /** * Read a text file from HDFS, a local file system (available on all nodes), or any * Hadoop-supported file system URI, and return it as an RDD of Strings. */ @@ -400,7 +400,7 @@ class SparkContext( new Accumulable(initialValue, param) } - /** + /** * Broadcast a read-only variable to the cluster, returning a [[spark.Broadcast]] object for * reading it in distributed functions. The variable will be sent to each cluster only once. */ @@ -419,19 +419,29 @@ class SparkContext( } addedFiles(key) = System.currentTimeMillis - // Fetch the file locally in case the task is executed locally - val filename = new File(path.split("/").last) + // Fetch the file locally in case a job is executed locally. + // Jobs that run through LocalScheduler will already fetch the required dependencies, + // but jobs run in DAGScheduler.runLocally() will not so we must fetch the files here. Utils.fetchFile(path, new File(".")) logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key)) } /** - * Clear the job's list of files added by `addFile` so that they do not get donwloaded to + * Return a map from the slave to the max memory available for caching and the remaining + * memory available for caching. + */ + def getSlavesMemoryStatus: Map[String, (Long, Long)] = { + env.blockManager.master.getMemoryStatus.map { case(blockManagerId, mem) => + (blockManagerId.ip + ":" + blockManagerId.port, mem) + } + } + + /** + * Clear the job's list of files added by `addFile` so that they do not get downloaded to * any new nodes. */ def clearFiles() { - addedFiles.keySet.map(_.split("/").last).foreach { k => new File(k).delete() } addedFiles.clear() } @@ -455,7 +465,6 @@ class SparkContext( * any new nodes. 
*/ def clearJars() { - addedJars.keySet.map(_.split("/").last).foreach { k => new File(k).delete() } addedJars.clear() } diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala index 4c6ec6cc6e..41441720a7 100644 --- a/core/src/main/scala/spark/SparkEnv.scala +++ b/core/src/main/scala/spark/SparkEnv.scala @@ -68,7 +68,6 @@ object SparkEnv extends Logging { isMaster: Boolean, isLocal: Boolean ) : SparkEnv = { - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, port) // Bit of a hack: If this is the master and our port was 0 (meaning bind to any free port), @@ -87,10 +86,13 @@ object SparkEnv extends Logging { } val serializer = instantiateClass[Serializer]("spark.serializer", "spark.JavaSerializer") - - val blockManagerMaster = new BlockManagerMaster(actorSystem, isMaster, isLocal) - val blockManager = new BlockManager(blockManagerMaster, serializer) - + + val masterIp: String = System.getProperty("spark.master.host", "localhost") + val masterPort: Int = System.getProperty("spark.master.port", "7077").toInt + val blockManagerMaster = new BlockManagerMaster( + actorSystem, isMaster, isLocal, masterIp, masterPort) + val blockManager = new BlockManager(actorSystem, blockManagerMaster, serializer) + val connectionManager = blockManager.connectionManager val broadcastManager = new BroadcastManager(isMaster) @@ -105,7 +107,7 @@ object SparkEnv extends Logging { val shuffleFetcher = instantiateClass[ShuffleFetcher]( "spark.shuffle.fetcher", "spark.BlockStoreShuffleFetcher") - + val httpFileServer = new HttpFileServer() httpFileServer.initialize() System.setProperty("spark.fileserver.uri", httpFileServer.serverUri) diff --git a/core/src/main/scala/spark/TaskContext.scala b/core/src/main/scala/spark/TaskContext.scala index c14377d17b..d2746b26b3 100644 --- a/core/src/main/scala/spark/TaskContext.scala +++ b/core/src/main/scala/spark/TaskContext.scala @@ -1,3 +1,20 @@ package spark -class TaskContext(val stageId: Int, val splitId: Int, val attemptId: Long) extends Serializable +import scala.collection.mutable.ArrayBuffer + + +class TaskContext(val stageId: Int, val splitId: Int, val attemptId: Long) extends Serializable { + + @transient + val onCompleteCallbacks = new ArrayBuffer[() => Unit] + + // Add a callback function to be executed on task completion. An example use + // is for HadoopRDD to register a callback to close the input stream. + def addOnCompleteCallback(f: () => Unit) { + onCompleteCallbacks += f + } + + def executeOnCompleteCallbacks() { + onCompleteCallbacks.foreach{_()} + } +} diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index 567c4b1475..0e7007459d 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -1,13 +1,15 @@ package spark import java.io._ -import java.net.{InetAddress, URL, URI} +import java.net.{NetworkInterface, InetAddress, URL, URI} import java.util.{Locale, Random, UUID} import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{Path, FileSystem, FileUtil} import scala.collection.mutable.ArrayBuffer +import scala.collection.JavaConversions._ import scala.io.Source +import com.google.common.io.Files /** * Various utility methods used by Spark. @@ -126,31 +128,53 @@ private object Utils extends Logging { /** * Download a file requested by the executor. 
Supports fetching the file in a variety of ways, * including HTTP, HDFS and files on a standard filesystem, based on the URL parameter. + * + * Throws SparkException if the target file already exists and has different contents than + * the requested file. */ def fetchFile(url: String, targetDir: File) { val filename = url.split("/").last + val tempDir = System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")) + val tempFile = File.createTempFile("fetchFileTemp", null, new File(tempDir)) val targetFile = new File(targetDir, filename) val uri = new URI(url) uri.getScheme match { case "http" | "https" | "ftp" => - logInfo("Fetching " + url + " to " + targetFile) + logInfo("Fetching " + url + " to " + tempFile) val in = new URL(url).openStream() - val out = new FileOutputStream(targetFile) + val out = new FileOutputStream(tempFile) Utils.copyStream(in, out, true) + if (targetFile.exists && !Files.equal(tempFile, targetFile)) { + tempFile.delete() + throw new SparkException("File " + targetFile + " exists and does not match contents of" + + " " + url) + } else { + Files.move(tempFile, targetFile) + } case "file" | null => - // Remove the file if it already exists - targetFile.delete() - // Symlink the file locally. - if (uri.isAbsolute) { - // url is absolute, i.e. it starts with "file:///". Extract the source - // file's absolute path from the url. - val sourceFile = new File(uri) - logInfo("Symlinking " + sourceFile.getAbsolutePath + " to " + targetFile.getAbsolutePath) - FileUtil.symLink(sourceFile.getAbsolutePath, targetFile.getAbsolutePath) + val sourceFile = if (uri.isAbsolute) { + new File(uri) } else { - // url is not absolute, i.e. itself is the path to the source file. - logInfo("Symlinking " + url + " to " + targetFile.getAbsolutePath) - FileUtil.symLink(url, targetFile.getAbsolutePath) + new File(url) + } + if (targetFile.exists && !Files.equal(sourceFile, targetFile)) { + throw new SparkException("File " + targetFile + " exists and does not match contents of" + + " " + url) + } else { + // Remove the file if it already exists + targetFile.delete() + // Symlink the file locally. + if (uri.isAbsolute) { + // url is absolute, i.e. it starts with "file:///". Extract the source + // file's absolute path from the url. + val sourceFile = new File(uri) + logInfo("Symlinking " + sourceFile.getAbsolutePath + " to " + targetFile.getAbsolutePath) + FileUtil.symLink(sourceFile.getAbsolutePath, targetFile.getAbsolutePath) + } else { + // url is not absolute, i.e. itself is the path to the source file. + logInfo("Symlinking " + url + " to " + targetFile.getAbsolutePath) + FileUtil.symLink(url, targetFile.getAbsolutePath) + } } case _ => // Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others @@ -158,8 +182,15 @@ private object Utils extends Logging { val conf = new Configuration() val fs = FileSystem.get(uri, conf) val in = fs.open(new Path(uri)) - val out = new FileOutputStream(targetFile) + val out = new FileOutputStream(tempFile) Utils.copyStream(in, out, true) + if (targetFile.exists && !Files.equal(tempFile, targetFile)) { + tempFile.delete() + throw new SparkException("File " + targetFile + " exists and does not match contents of" + + " " + url) + } else { + Files.move(tempFile, targetFile) + } } // Decompress the file if it's a .tar or .tar.gz if (filename.endsWith(".tar.gz") || filename.endsWith(".tgz")) { @@ -199,7 +230,35 @@ private object Utils extends Logging { /** * Get the local host's IP address in dotted-quad format (e.g. 
1.2.3.4). */ - def localIpAddress(): String = InetAddress.getLocalHost.getHostAddress + lazy val localIpAddress: String = findLocalIpAddress() + + private def findLocalIpAddress(): String = { + val defaultIpOverride = System.getenv("SPARK_LOCAL_IP") + if (defaultIpOverride != null) { + defaultIpOverride + } else { + val address = InetAddress.getLocalHost + if (address.isLoopbackAddress) { + // Address resolves to something like 127.0.1.1, which happens on Debian; try to find + // a better address using the local network interfaces + for (ni <- NetworkInterface.getNetworkInterfaces) { + for (addr <- ni.getInetAddresses if !addr.isLinkLocalAddress && !addr.isLoopbackAddress) { + // We've found an address that looks reasonable! + logWarning("Your hostname, " + InetAddress.getLocalHost.getHostName + " resolves to" + + " a loopback address: " + address.getHostAddress + "; using " + addr.getHostAddress + + " instead (on interface " + ni.getName + ")") + logWarning("Set SPARK_LOCAL_IP if you need to bind to another address") + return addr.getHostAddress + } + } + logWarning("Your hostname, " + InetAddress.getLocalHost.getHostName + " resolves to" + + " a loopback address: " + address.getHostAddress + ", but we couldn't find any" + + " external IP address!") + logWarning("Set SPARK_LOCAL_IP if you need to bind to another address") + } + address.getHostAddress + } + } private var customHostname: Option[String] = None diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala index 13fcee1004..81d3a94466 100644 --- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/spark/api/java/JavaRDDLike.scala @@ -1,16 +1,15 @@ package spark.api.java -import spark.{SparkContext, Split, RDD} +import java.util.{List => JList} +import scala.Tuple2 +import scala.collection.JavaConversions._ + +import spark.{SparkContext, Split, RDD, TaskContext} import spark.api.java.JavaPairRDD._ import spark.api.java.function.{Function2 => JFunction2, Function => JFunction, _} import spark.partial.{PartialResult, BoundedDouble} import spark.storage.StorageLevel -import java.util.{List => JList} - -import scala.collection.JavaConversions._ -import java.{util, lang} -import scala.Tuple2 trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def wrapRDD(rdd: RDD[T]): This @@ -24,7 +23,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { /** The [[spark.SparkContext]] that this RDD was created on. */ def context: SparkContext = rdd.context - + /** A unique ID for this RDD (within its SparkContext). */ def id: Int = rdd.id @@ -36,7 +35,8 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * This should ''not'' be called by users directly, but is available for implementors of custom * subclasses of RDD. */ - def iterator(split: Split): java.util.Iterator[T] = asJavaIterator(rdd.iterator(split)) + def iterator(split: Split, taskContext: TaskContext): java.util.Iterator[T] = + asJavaIterator(rdd.iterator(split, taskContext)) // Transformations (return a new RDD) @@ -99,7 +99,6 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { JavaRDD.fromRDD(rdd.mapPartitions(fn)(f.elementType()))(f.elementType()) } - /** * Return a new RDD by applying a function to each partition of this RDD. 
*/ @@ -172,8 +171,18 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def pipe(command: JList[String], env: java.util.Map[String, String]): JavaRDD[String] = rdd.pipe(asScalaBuffer(command), mapAsScalaMap(env)) + /** + * Zips this RDD with another one, returning key-value pairs with the first element in each RDD, + * second element in each RDD, etc. Assumes that the two RDDs have the *same number of + * partitions* and the *same number of elements in each partition* (e.g. one was made through + * a map on the other). + */ + def zip[U](other: JavaRDDLike[U, _]): JavaPairRDD[T, U] = { + JavaPairRDD.fromRDD(rdd.zip(other.rdd)(other.classManifest))(classManifest, other.classManifest) + } + // Actions (launch a job to return a value to the user program) - + /** * Applies a function f to all elements of this RDD. */ @@ -190,7 +199,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { val arr: java.util.Collection[T] = rdd.collect().toSeq new java.util.ArrayList(arr) } - + /** * Reduces the elements of this RDD using the specified associative binary operator. */ @@ -198,7 +207,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { /** * Aggregate the elements of each partition, and then the results for all the partitions, using a - * given associative function and a neutral "zero value". The function op(t1, t2) is allowed to + * given associative function and a neutral "zero value". The function op(t1, t2) is allowed to * modify t1 and return it as its result value to avoid object allocation; however, it should not * modify t2. */ @@ -241,7 +250,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * combine step happens locally on the master, equivalent to running a single reduce task. */ def countByValue(): java.util.Map[T, java.lang.Long] = - mapAsJavaMap(rdd.countByValue().map((x => (x._1, new lang.Long(x._2))))) + mapAsJavaMap(rdd.countByValue().map((x => (x._1, new java.lang.Long(x._2))))) /** * (Experimental) Approximate version of countByValue(). diff --git a/core/src/main/scala/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/spark/api/java/JavaSparkContext.scala index edbb187b1b..b7725313c4 100644 --- a/core/src/main/scala/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/spark/api/java/JavaSparkContext.scala @@ -301,6 +301,40 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork * (in that order of preference). If neither of these is set, return None. */ def getSparkHome(): Option[String] = sc.getSparkHome() + + /** + * Add a file to be downloaded into the working directory of this Spark job on every node. + * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported + * filesystems), or an HTTP, HTTPS or FTP URI. + */ + def addFile(path: String) { + sc.addFile(path) + } + + /** + * Adds a JAR dependency for all tasks to be executed on this SparkContext in the future. + * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported + * filesystems), or an HTTP, HTTPS or FTP URI. + */ + def addJar(path: String) { + sc.addJar(path) + } + + /** + * Clear the job's list of JARs added by `addJar` so that they do not get downloaded to + * any new nodes. + */ + def clearJars() { + sc.clearJars() + } + + /** + * Clear the job's list of files added by `addFile` so that they do not get downloaded to + * any new nodes. 
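[Editor's note: the zip wrapper above (and the Scala RDD.zip it delegates to) pairs elements purely by position, which is why both RDDs need identical partition counts and per-partition sizes. A small sketch of the semantics, assuming a SparkContext named sc:

    val nums  = sc.parallelize(1 to 4, 2)
    val chars = sc.parallelize(Seq("a", "b", "c", "d"), 2)
    // Two partitions of two elements on each side, so zip lines them up positionally.
    nums.zip(chars).collect()  // Array((1,a), (2,b), (3,c), (4,d))
]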
+ */ + def clearFiles() { + sc.clearFiles() + } } object JavaSparkContext { diff --git a/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala b/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala index ef27bbb502..386f505f2a 100644 --- a/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala +++ b/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala @@ -48,7 +48,7 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: // Used only in Workers @transient var ttGuide: TalkToGuide = null - @transient var hostAddress = Utils.localIpAddress() + @transient var hostAddress = Utils.localIpAddress @transient var listenPort = -1 @transient var guidePort = -1 diff --git a/core/src/main/scala/spark/broadcast/TreeBroadcast.scala b/core/src/main/scala/spark/broadcast/TreeBroadcast.scala index fa676e9064..f573512835 100644 --- a/core/src/main/scala/spark/broadcast/TreeBroadcast.scala +++ b/core/src/main/scala/spark/broadcast/TreeBroadcast.scala @@ -36,7 +36,7 @@ extends Broadcast[T](id) with Logging with Serializable { @transient var serveMR: ServeMultipleRequests = null @transient var guideMR: GuideMultipleRequests = null - @transient var hostAddress = Utils.localIpAddress() + @transient var hostAddress = Utils.localIpAddress @transient var listenPort = -1 @transient var guidePort = -1 @@ -138,7 +138,7 @@ extends Broadcast[T](id) with Logging with Serializable { serveMR = null - hostAddress = Utils.localIpAddress() + hostAddress = Utils.localIpAddress listenPort = -1 stopBroadcast = false diff --git a/core/src/main/scala/spark/deploy/DeployMessage.scala b/core/src/main/scala/spark/deploy/DeployMessage.scala index d2b63d6e0d..457122745b 100644 --- a/core/src/main/scala/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/spark/deploy/DeployMessage.scala @@ -11,8 +11,15 @@ private[spark] sealed trait DeployMessage extends Serializable // Worker to Master -private[spark] -case class RegisterWorker(id: String, host: String, port: Int, cores: Int, memory: Int, webUiPort: Int) +private[spark] +case class RegisterWorker( + id: String, + host: String, + port: Int, + cores: Int, + memory: Int, + webUiPort: Int, + publicAddress: String) extends DeployMessage private[spark] @@ -20,7 +27,8 @@ case class ExecutorStateChanged( jobId: String, execId: Int, state: ExecutorState, - message: Option[String]) + message: Option[String], + exitStatus: Option[Int]) extends DeployMessage // Master to Worker @@ -51,7 +59,8 @@ private[spark] case class ExecutorAdded(id: Int, workerId: String, host: String, cores: Int, memory: Int) private[spark] -case class ExecutorUpdated(id: Int, state: ExecutorState, message: Option[String]) +case class ExecutorUpdated(id: Int, state: ExecutorState, message: Option[String], + exitStatus: Option[Int]) private[spark] case class JobKilled(message: String) @@ -67,8 +76,8 @@ private[spark] case object RequestMasterState // Master to MasterWebUI private[spark] -case class MasterState(uri : String, workers: List[WorkerInfo], activeJobs: List[JobInfo], - completedJobs: List[JobInfo]) +case class MasterState(uri: String, workers: Array[WorkerInfo], activeJobs: Array[JobInfo], + completedJobs: Array[JobInfo]) // WorkerWebUI to Worker private[spark] case object RequestWorkerState @@ -78,4 +87,4 @@ private[spark] case object RequestWorkerState private[spark] case class WorkerState(uri: String, workerId: String, executors: List[ExecutorRunner], finishedExecutors: List[ExecutorRunner], masterUrl: String, cores: Int, memory: Int, - 
coresUsed: Int, memoryUsed: Int, masterWebUiUrl: String)
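[Editor's note: RegisterWorker now carries a publicAddress so that links rendered by the master's web UI stay reachable from outside the cluster. A sketch of the worker-side registration, mirroring the SPARK_PUBLIC_DNS fallback this commit adds to Master and Worker; master, workerId, ip, port, cores, memory and webUiPort are assumed to be in scope:

    // Prefer an externally visible DNS name when one is configured.
    val publicAddress = Option(System.getenv("SPARK_PUBLIC_DNS")).getOrElse(ip)
    master ! RegisterWorker(workerId, ip, port, cores, memory, webUiPort, publicAddress)
]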
\ No newline at end of file
+ coresUsed: Int, memoryUsed: Int, masterWebUiUrl: String)
diff --git a/core/src/main/scala/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/spark/deploy/LocalSparkCluster.scala
index 8b2a71add5..4211d80596 100644
--- a/core/src/main/scala/spark/deploy/LocalSparkCluster.scala
+++ b/core/src/main/scala/spark/deploy/LocalSparkCluster.scala
@@ -35,11 +35,15 @@ class LocalSparkCluster(numSlaves: Int, coresPerSlave: Int, memoryPerSlave: Int)
/* Start the Slaves */
for (slaveNum <- 1 to numSlaves) {
+ /* We can pretend to test distributed stuff by giving the slaves distinct hostnames.
+ All of 127/8 should be loopback addresses; we use 127.100.*.* in hopes that it is
+ sufficiently distinctive. */
+ val slaveIpAddress = "127.100.0." + (slaveNum % 256)
val (actorSystem, boundPort) =
- AkkaUtils.createActorSystem("sparkWorker" + slaveNum, localIpAddress, 0)
+ AkkaUtils.createActorSystem("sparkWorker" + slaveNum, slaveIpAddress, 0)
slaveActorSystems += actorSystem
val actor = actorSystem.actorOf(
- Props(new Worker(localIpAddress, boundPort, 0, coresPerSlave, memoryPerSlave, masterUrl)),
+ Props(new Worker(slaveIpAddress, boundPort, 0, coresPerSlave, memoryPerSlave, masterUrl)),
name = "Worker")
slaveActors += actor
}
diff --git a/core/src/main/scala/spark/deploy/WebUI.scala b/core/src/main/scala/spark/deploy/WebUI.scala
new file mode 100644
index 0000000000..ad1a1092b2
--- /dev/null
+++ b/core/src/main/scala/spark/deploy/WebUI.scala
@@ -0,0 +1,30 @@
+package spark.deploy
+
+import java.text.SimpleDateFormat
+import java.util.Date
+
+/**
+ * Utilities used throughout the web UI.
+ */
+private[spark] object WebUI {
+ val DATE_FORMAT = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss")
+
+ def formatDate(date: Date): String = DATE_FORMAT.format(date)
+
+ def formatDate(timestamp: Long): String = DATE_FORMAT.format(new Date(timestamp))
+
+ def formatDuration(milliseconds: Long): String = {
+ val seconds = milliseconds.toDouble / 1000
+ if (seconds < 60) {
+ return "%.0f s".format(seconds)
+ }
+ val minutes = seconds / 60
+ if (minutes < 10) {
+ return "%.1f min".format(minutes)
+ } else if (minutes < 60) {
+ return "%.0f min".format(minutes)
+ }
+ val hours = minutes / 60
+ return "%.1f h".format(hours)
+ }
+}
diff --git a/core/src/main/scala/spark/deploy/client/Client.scala b/core/src/main/scala/spark/deploy/client/Client.scala
index e51b0c5c15..90fe9508cd 100644
--- a/core/src/main/scala/spark/deploy/client/Client.scala
+++ b/core/src/main/scala/spark/deploy/client/Client.scala
@@ -35,6 +35,7 @@ private[spark] class Client(
class ClientActor extends Actor with Logging {
var master: ActorRef = null
+ var masterAddress: Address = null
var alreadyDisconnected = false // To avoid calling listener.disconnected() multiple times
override def preStart() {
@@ -43,6 +44,7 @@ private[spark] class Client(
val akkaUrl = "akka://spark@%s:%s/user/Master".format(masterHost, masterPort)
try {
master = context.actorFor(akkaUrl)
+ masterAddress = master.path.address
master !
RegisterJob(jobDescription) context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) context.watch(master) // Doesn't work with remote actors, but useful for testing @@ -64,15 +66,25 @@ private[spark] class Client( logInfo("Executor added: %s on %s (%s) with %d cores".format(fullId, workerId, host, cores)) listener.executorAdded(fullId, workerId, host, cores, memory) - case ExecutorUpdated(id, state, message) => + case ExecutorUpdated(id, state, message, exitStatus) => val fullId = jobId + "/" + id val messageText = message.map(s => " (" + s + ")").getOrElse("") logInfo("Executor updated: %s is now %s%s".format(fullId, state, messageText)) if (ExecutorState.isFinished(state)) { - listener.executorRemoved(fullId, message.getOrElse("")) + listener.executorRemoved(fullId, message.getOrElse(""), exitStatus) } - case Terminated(_) | RemoteClientDisconnected(_, _) | RemoteClientShutdown(_, _) => + case Terminated(actor_) if actor_ == master => + logError("Connection to master failed; stopping client") + markDisconnected() + context.stop(self) + + case RemoteClientDisconnected(transport, address) if address == masterAddress => + logError("Connection to master failed; stopping client") + markDisconnected() + context.stop(self) + + case RemoteClientShutdown(transport, address) if address == masterAddress => logError("Connection to master failed; stopping client") markDisconnected() context.stop(self) diff --git a/core/src/main/scala/spark/deploy/client/ClientListener.scala b/core/src/main/scala/spark/deploy/client/ClientListener.scala index a8fa982085..da6abcc9c2 100644 --- a/core/src/main/scala/spark/deploy/client/ClientListener.scala +++ b/core/src/main/scala/spark/deploy/client/ClientListener.scala @@ -14,5 +14,5 @@ private[spark] trait ClientListener { def executorAdded(id: String, workerId: String, host: String, cores: Int, memory: Int): Unit - def executorRemoved(id: String, message: String): Unit + def executorRemoved(id: String, message: String, exitStatus: Option[Int]): Unit } diff --git a/core/src/main/scala/spark/deploy/client/TestClient.scala b/core/src/main/scala/spark/deploy/client/TestClient.scala index bf0e7428ba..57a7e123b7 100644 --- a/core/src/main/scala/spark/deploy/client/TestClient.scala +++ b/core/src/main/scala/spark/deploy/client/TestClient.scala @@ -18,12 +18,12 @@ private[spark] object TestClient { def executorAdded(id: String, workerId: String, host: String, cores: Int, memory: Int) {} - def executorRemoved(id: String, message: String) {} + def executorRemoved(id: String, message: String, exitStatus: Option[Int]) {} } def main(args: Array[String]) { val url = args(0) - val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress(), 0) + val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0) val desc = new JobDescription( "TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map())) val listener = new TestListener diff --git a/core/src/main/scala/spark/deploy/master/JobInfo.scala b/core/src/main/scala/spark/deploy/master/JobInfo.scala index 8795c09cc1..130b031a2a 100644 --- a/core/src/main/scala/spark/deploy/master/JobInfo.scala +++ b/core/src/main/scala/spark/deploy/master/JobInfo.scala @@ -5,11 +5,17 @@ import java.util.Date import akka.actor.ActorRef import scala.collection.mutable -private[spark] -class JobInfo(val id: String, val desc: JobDescription, val submitDate: Date, val actor: ActorRef) { +private[spark] class JobInfo( + val startTime: Long, + val id: String, + 
val desc: JobDescription, + val submitDate: Date, + val actor: ActorRef) +{ var state = JobState.WAITING var executors = new mutable.HashMap[Int, ExecutorInfo] var coresGranted = 0 + var endTime = -1L private var nextExecutorId = 0 @@ -41,4 +47,17 @@ class JobInfo(val id: String, val desc: JobDescription, val submitDate: Date, va _retryCount += 1 _retryCount } + + def markFinished(endState: JobState.Value) { + state = endState + endTime = System.currentTimeMillis() + } + + def duration: Long = { + if (endTime != -1) { + endTime - startTime + } else { + System.currentTimeMillis() - startTime + } + } } diff --git a/core/src/main/scala/spark/deploy/master/Master.scala b/core/src/main/scala/spark/deploy/master/Master.scala index 6010f7cff2..6ecebe626a 100644 --- a/core/src/main/scala/spark/deploy/master/Master.scala +++ b/core/src/main/scala/spark/deploy/master/Master.scala @@ -31,6 +31,16 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor val waitingJobs = new ArrayBuffer[JobInfo] val completedJobs = new ArrayBuffer[JobInfo] + val masterPublicAddress = { + val envVar = System.getenv("SPARK_PUBLIC_DNS") + if (envVar != null) envVar else ip + } + + // As a temporary workaround before better ways of configuring memory, we allow users to set + // a flag that will perform round-robin scheduling across the nodes (spreading out each job + // among all the nodes) instead of trying to consolidate each job onto a small # of nodes. + val spreadOutJobs = System.getProperty("spark.deploy.spreadOut", "false").toBoolean + override def preStart() { logInfo("Starting Spark master at spark://" + ip + ":" + port) // Listen for remote client disconnection events, since they don't go through Akka's watch() @@ -50,15 +60,15 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor } override def receive = { - case RegisterWorker(id, host, workerPort, cores, memory, worker_webUiPort) => { + case RegisterWorker(id, host, workerPort, cores, memory, worker_webUiPort, publicAddress) => { logInfo("Registering worker %s:%d with %d cores, %s RAM".format( host, workerPort, cores, Utils.memoryMegabytesToString(memory))) if (idToWorker.contains(id)) { sender ! RegisterWorkerFailed("Duplicate worker ID") } else { - addWorker(id, host, workerPort, cores, memory, worker_webUiPort) + addWorker(id, host, workerPort, cores, memory, worker_webUiPort, publicAddress) context.watch(sender) // This doesn't work with remote actors but helps for testing - sender ! RegisteredWorker("http://" + ip + ":" + webUiPort) + sender ! RegisteredWorker("http://" + masterPublicAddress + ":" + webUiPort) schedule() } } @@ -73,12 +83,12 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor schedule() } - case ExecutorStateChanged(jobId, execId, state, message) => { + case ExecutorStateChanged(jobId, execId, state, message, exitStatus) => { val execOption = idToJob.get(jobId).flatMap(job => job.executors.get(execId)) execOption match { case Some(exec) => { exec.state = state - exec.job.actor ! ExecutorUpdated(execId, state, message) + exec.job.actor ! ExecutorUpdated(execId, state, message, exitStatus) if (ExecutorState.isFinished(state)) { val jobInfo = idToJob(jobId) // Remove this executor from the worker and job @@ -123,28 +133,63 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor } case RequestMasterState => { - sender ! MasterState(ip + ":" + port, workers.toList, jobs.toList, completedJobs.toList) + sender ! 
MasterState(ip + ":" + port, workers.toArray, jobs.toArray, completedJobs.toArray) } } /** + * Can a job use the given worker? True if the worker has enough memory and we haven't already + * launched an executor for the job on it (right now the standalone backend doesn't like having + * two executors on the same worker). + */ + def canUse(job: JobInfo, worker: WorkerInfo): Boolean = { + worker.memoryFree >= job.desc.memoryPerSlave && !worker.hasExecutor(job) + } + + /** * Schedule the currently available resources among waiting jobs. This method will be called * every time a new job joins or resource availability changes. */ def schedule() { - // Right now this is a very simple FIFO scheduler. We keep looking through the jobs - // in order of submission time and launching the first one that fits on each node. - for (worker <- workers if worker.coresFree > 0) { - for (job <- waitingJobs.clone()) { - val jobMemory = job.desc.memoryPerSlave - if (worker.memoryFree >= jobMemory) { - val coresToUse = math.min(worker.coresFree, job.coresLeft) - val exec = job.addExecutor(worker, coresToUse) - launchExecutor(worker, exec) + // Right now this is a very simple FIFO scheduler. We keep trying to fit in the first job + // in the queue, then the second job, etc. + if (spreadOutJobs) { + // Try to spread out each job among all the nodes, until it has all its cores + for (job <- waitingJobs if job.coresLeft > 0) { + val usableWorkers = workers.toArray.filter(_.state == WorkerState.ALIVE) + .filter(canUse(job, _)).sortBy(_.coresFree).reverse + val numUsable = usableWorkers.length + val assigned = new Array[Int](numUsable) // Number of cores to give on each node + var toAssign = math.min(job.coresLeft, usableWorkers.map(_.coresFree).sum) + var pos = 0 + while (toAssign > 0) { + if (usableWorkers(pos).coresFree - assigned(pos) > 0) { + toAssign -= 1 + assigned(pos) += 1 + } + pos = (pos + 1) % numUsable } - if (job.coresLeft == 0) { - waitingJobs -= job - job.state = JobState.RUNNING + // Now that we've decided how many cores to give on each node, let's actually give them + for (pos <- 0 until numUsable) { + if (assigned(pos) > 0) { + val exec = job.addExecutor(usableWorkers(pos), assigned(pos)) + launchExecutor(usableWorkers(pos), exec) + job.state = JobState.RUNNING + } + } + } + } else { + // Pack each job into as few nodes as possible until we've assigned all its cores + for (worker <- workers if worker.coresFree > 0) { + for (job <- waitingJobs if job.coresLeft > 0) { + if (canUse(job, worker)) { + val coresToUse = math.min(worker.coresFree, job.coresLeft) + if (coresToUse > 0) { + val exec = job.addExecutor(worker, coresToUse) + launchExecutor(worker, exec) + job.state = JobState.RUNNING + } + } } } } @@ -157,8 +202,11 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor exec.job.actor ! ExecutorAdded(exec.id, worker.id, worker.host, exec.cores, exec.memory) } - def addWorker(id: String, host: String, port: Int, cores: Int, memory: Int, webUiPort: Int): WorkerInfo = { - val worker = new WorkerInfo(id, host, port, cores, memory, sender, webUiPort) + def addWorker(id: String, host: String, port: Int, cores: Int, memory: Int, webUiPort: Int, + publicAddress: String): WorkerInfo = { + // There may be one or more refs to dead workers on this same node (w/ different ID's), remove them. 
+ workers.filter(w => (w.host == host) && (w.state == WorkerState.DEAD)).foreach(workers -= _) + val worker = new WorkerInfo(id, host, port, cores, memory, sender, webUiPort, publicAddress) workers += worker idToWorker(worker.id) = worker actorToWorker(sender) = worker @@ -168,19 +216,20 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor def removeWorker(worker: WorkerInfo) { logInfo("Removing worker " + worker.id + " on " + worker.host + ":" + worker.port) - workers -= worker + worker.setState(WorkerState.DEAD) idToWorker -= worker.id actorToWorker -= worker.actor addressToWorker -= worker.actor.path.address for (exec <- worker.executors.values) { - exec.job.actor ! ExecutorStateChanged(exec.job.id, exec.id, ExecutorState.LOST, None) + exec.job.actor ! ExecutorStateChanged(exec.job.id, exec.id, ExecutorState.LOST, None, None) exec.job.executors -= exec.id } } def addJob(desc: JobDescription, actor: ActorRef): JobInfo = { - val date = new Date - val job = new JobInfo(newJobId(date), desc, date, actor) + val now = System.currentTimeMillis() + val date = new Date(now) + val job = new JobInfo(now, newJobId(date), desc, date, actor) jobs += job idToJob(job.id) = job actorToJob(sender) = job @@ -189,19 +238,21 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor } def removeJob(job: JobInfo) { - logInfo("Removing job " + job.id) - jobs -= job - idToJob -= job.id - actorToJob -= job.actor - addressToWorker -= job.actor.path.address - completedJobs += job // Remember it in our history - waitingJobs -= job - for (exec <- job.executors.values) { - exec.worker.removeExecutor(exec) - exec.worker.actor ! KillExecutor(exec.job.id, exec.id) + if (jobs.contains(job)) { + logInfo("Removing job " + job.id) + jobs -= job + idToJob -= job.id + actorToJob -= job.actor + addressToWorker -= job.actor.path.address + completedJobs += job // Remember it in our history + waitingJobs -= job + for (exec <- job.executors.values) { + exec.worker.removeExecutor(exec) + exec.worker.actor ! KillExecutor(exec.job.id, exec.id) + } + job.markFinished(JobState.FINISHED) // TODO: Mark it as FAILED if it failed + schedule() } - job.state = JobState.FINISHED - schedule() } /** Generate a new job ID given a job's submission date */ diff --git a/core/src/main/scala/spark/deploy/master/MasterArguments.scala b/core/src/main/scala/spark/deploy/master/MasterArguments.scala index 1b1c3dd0ad..4ceab3fc03 100644 --- a/core/src/main/scala/spark/deploy/master/MasterArguments.scala +++ b/core/src/main/scala/spark/deploy/master/MasterArguments.scala @@ -7,7 +7,7 @@ import spark.Utils * Command-line parser for the master. */ private[spark] class MasterArguments(args: Array[String]) { - var ip = Utils.localIpAddress() + var ip = Utils.localHostName() var port = 7077 var webUiPort = 8080 @@ -59,4 +59,4 @@ private[spark] class MasterArguments(args: Array[String]) { " --webui-port PORT Port for web UI (default: 8080)") System.exit(exitCode) } -}
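[Editor's note: the spread-out branch of schedule() above boils down to a round-robin core assignment. A self-contained sketch of just that loop, with a worked example; the helper name is illustrative:

    // Hand out one core at a time across the usable workers, bounded by each
    // worker's free cores, until the job's remaining cores are exhausted.
    def assignCores(coresLeft: Int, coresFree: Array[Int]): Array[Int] = {
      val assigned = new Array[Int](coresFree.length)
      var toAssign = math.min(coresLeft, coresFree.sum)
      var pos = 0
      while (toAssign > 0) {
        if (coresFree(pos) - assigned(pos) > 0) {
          assigned(pos) += 1
          toAssign -= 1
        }
        pos = (pos + 1) % coresFree.length
      }
      assigned
    }

    assignCores(5, Array(4, 2, 2))  // Array(2, 2, 1): the job is spread over all three workers
]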
\ No newline at end of file
+}
diff --git a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala
index 700a41c770..3cdd3721f5 100644
--- a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala
+++ b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala
@@ -36,7 +36,7 @@ class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Direct
// A bit ugly and inefficient, but we won't have a number of jobs
// so large that it will make a significant difference.
- (masterState.activeJobs ::: masterState.completedJobs).find(_.id == jobId) match {
+ (masterState.activeJobs ++ masterState.completedJobs).find(_.id == jobId) match {
case Some(job) => spark.deploy.master.html.job_details.render(job)
case _ => null
}
diff --git a/core/src/main/scala/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/spark/deploy/master/WorkerInfo.scala
index 16b3f9b653..5a7f5fef8a 100644
--- a/core/src/main/scala/spark/deploy/master/WorkerInfo.scala
+++ b/core/src/main/scala/spark/deploy/master/WorkerInfo.scala
@@ -10,10 +10,11 @@ private[spark] class WorkerInfo(
val cores: Int,
val memory: Int,
val actor: ActorRef,
- val webUiPort: Int) {
+ val webUiPort: Int,
+ val publicAddress: String) {
var executors = new mutable.HashMap[String, ExecutorInfo] // fullId => info
-
+ var state: WorkerState.Value = WorkerState.ALIVE
var coresUsed = 0
var memoryUsed = 0
@@ -33,8 +34,16 @@ private[spark] class WorkerInfo(
memoryUsed -= exec.memory
}
}
-
+
+ def hasExecutor(job: JobInfo): Boolean = {
+ executors.values.exists(_.job == job)
+ }
+
def webUiAddress : String = {
- "http://" + this.host + ":" + this.webUiPort
+ "http://" + this.publicAddress + ":" + this.webUiPort
+ }
+
+ def setState(state: WorkerState.Value) = {
+ this.state = state
}
}
diff --git a/core/src/main/scala/spark/deploy/master/WorkerState.scala b/core/src/main/scala/spark/deploy/master/WorkerState.scala
new file mode 100644
index 0000000000..0bf35014c8
--- /dev/null
+++ b/core/src/main/scala/spark/deploy/master/WorkerState.scala
@@ -0,0 +1,7 @@
+package spark.deploy.master
+
+private[spark] object WorkerState extends Enumeration("ALIVE", "DEAD", "DECOMMISSIONED") {
+ type WorkerState = Value
+
+ val ALIVE, DEAD, DECOMMISSIONED = Value
+}
diff --git a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
index 07ae7bca78..beceb55ecd 100644
--- a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
+++ b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
@@ -60,7 +60,7 @@ private[spark] class ExecutorRunner(
process.destroy()
process.waitFor()
}
- worker ! ExecutorStateChanged(jobId, execId, ExecutorState.KILLED, None)
+ worker ! ExecutorStateChanged(jobId, execId, ExecutorState.KILLED, None, None)
Runtime.getRuntime.removeShutdownHook(shutdownHook)
}
}
@@ -134,7 +134,8 @@ private[spark] class ExecutorRunner(
// times on the same machine.
val exitCode = process.waitFor()
val message = "Command exited with code " + exitCode
- worker ! ExecutorStateChanged(jobId, execId, ExecutorState.FAILED, Some(message))
+ worker ! ExecutorStateChanged(jobId, execId, ExecutorState.FAILED, Some(message),
+ Some(exitCode))
} catch {
case interrupted: InterruptedException =>
logInfo("Runner thread for executor " + fullId + " interrupted")
@@ -145,7 +146,7 @@ private[spark] class ExecutorRunner(
process.destroy()
}
val message = e.getClass + ": " + e.getMessage
- worker !
ExecutorStateChanged(jobId, execId, ExecutorState.FAILED, Some(message)) + worker ! ExecutorStateChanged(jobId, execId, ExecutorState.FAILED, Some(message), None) } } } diff --git a/core/src/main/scala/spark/deploy/worker/Worker.scala b/core/src/main/scala/spark/deploy/worker/Worker.scala index 474c9364fd..7c9e588ea2 100644 --- a/core/src/main/scala/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/spark/deploy/worker/Worker.scala @@ -36,6 +36,10 @@ private[spark] class Worker( var workDir: File = null val executors = new HashMap[String, ExecutorRunner] val finishedExecutors = new HashMap[String, ExecutorRunner] + val publicAddress = { + val envVar = System.getenv("SPARK_PUBLIC_DNS") + if (envVar != null) envVar else ip + } var coresUsed = 0 var memoryUsed = 0 @@ -79,7 +83,7 @@ private[spark] class Worker( val akkaUrl = "akka://spark@%s:%s/user/Master".format(masterHost, masterPort) try { master = context.actorFor(akkaUrl) - master ! RegisterWorker(workerId, ip, port, cores, memory, webUiPort) + master ! RegisterWorker(workerId, ip, port, cores, memory, webUiPort, publicAddress) context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) context.watch(master) // Doesn't work with remote actors, but useful for testing } catch { @@ -123,10 +127,10 @@ private[spark] class Worker( manager.start() coresUsed += cores_ memoryUsed += memory_ - master ! ExecutorStateChanged(jobId, execId, ExecutorState.LOADING, None) + master ! ExecutorStateChanged(jobId, execId, ExecutorState.RUNNING, None, None) - case ExecutorStateChanged(jobId, execId, state, message) => - master ! ExecutorStateChanged(jobId, execId, state, message) + case ExecutorStateChanged(jobId, execId, state, message, exitStatus) => + master ! ExecutorStateChanged(jobId, execId, state, message, exitStatus) val fullId = jobId + "/" + execId if (ExecutorState.isFinished(state)) { val executor = executors(fullId) diff --git a/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala index 60dc107a4c..340920025b 100644 --- a/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala +++ b/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala @@ -9,7 +9,7 @@ import java.lang.management.ManagementFactory * Command-line parser for the master. */ private[spark] class WorkerArguments(args: Array[String]) { - var ip = Utils.localIpAddress() + var ip = Utils.localHostName() var port = 0 var webUiPort = 8081 var cores = inferDefaultCores() @@ -110,4 +110,4 @@ private[spark] class WorkerArguments(args: Array[String]) { // Leave out 1 GB for the operating system, but don't return a negative memory size math.max(totalMb - 1024, 512) } -}
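[Editor's note: the default-memory rule at the end of WorkerArguments deserves a worked example: reserve roughly 1 GB for the operating system, but never report less than 512 MB.

    // Standalone restatement of the formula above.
    def defaultMemoryMb(totalMb: Int): Int = math.max(totalMb - 1024, 512)

    defaultMemoryMb(8192)  // 7168: an 8 GB box leaves 1 GB for the OS
    defaultMemoryMb(1024)  // 512: small boxes bottom out at the floor
]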
\ No newline at end of file +} diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala index dfdb22024e..2552958d27 100644 --- a/core/src/main/scala/spark/executor/Executor.scala +++ b/core/src/main/scala/spark/executor/Executor.scala @@ -43,6 +43,26 @@ private[spark] class Executor extends Logging { urlClassLoader = createClassLoader() Thread.currentThread.setContextClassLoader(urlClassLoader) + // Make any thread terminations due to uncaught exceptions kill the entire + // executor process to avoid surprising stalls. + Thread.setDefaultUncaughtExceptionHandler( + new Thread.UncaughtExceptionHandler { + override def uncaughtException(thread: Thread, exception: Throwable) { + try { + logError("Uncaught exception in thread " + thread, exception) + if (exception.isInstanceOf[OutOfMemoryError]) { + System.exit(ExecutorExitCode.OOM) + } else { + System.exit(ExecutorExitCode.UNCAUGHT_EXCEPTION) + } + } catch { + case oom: OutOfMemoryError => System.exit(ExecutorExitCode.OOM) + case t: Throwable => System.exit(ExecutorExitCode.UNCAUGHT_EXCEPTION_TWICE) + } + } + } + ) + // Initialize Spark environment (using system properties read above) env = SparkEnv.createFromSystemProperties(slaveHostname, 0, false, false) SparkEnv.set(env) diff --git a/core/src/main/scala/spark/executor/ExecutorExitCode.scala b/core/src/main/scala/spark/executor/ExecutorExitCode.scala new file mode 100644 index 0000000000..fd76029cb3 --- /dev/null +++ b/core/src/main/scala/spark/executor/ExecutorExitCode.scala @@ -0,0 +1,43 @@ +package spark.executor + +/** + * These are exit codes that executors should use to provide the master with information about + * executor failures assuming that cluster management framework can capture the exit codes (but + * perhaps not log files). The exit code constants here are chosen to be unlikely to conflict + * with "natural" exit statuses that may be caused by the JVM or user code. In particular, + * exit codes 128+ arise on some Unix-likes as a result of signals, and it appears that the + * OpenJDK JVM may use exit code 1 in some of its own "last chance" code. + */ +private[spark] +object ExecutorExitCode { + /** The default uncaught exception handler was reached. */ + val UNCAUGHT_EXCEPTION = 50 + + /** The default uncaught exception handler was called and an exception was encountered while + logging the exception. */ + val UNCAUGHT_EXCEPTION_TWICE = 51 + + /** The default uncaught exception handler was reached, and the uncaught exception was an + OutOfMemoryError. */ + val OOM = 52 + + /** DiskStore failed to create a local temporary directory after many attempts. 
*/ + val DISK_STORE_FAILED_TO_CREATE_DIR = 53 + + def explainExitCode(exitCode: Int): String = { + exitCode match { + case UNCAUGHT_EXCEPTION => "Uncaught exception" + case UNCAUGHT_EXCEPTION_TWICE => "Uncaught exception, and logging the exception failed" + case OOM => "OutOfMemoryError" + case DISK_STORE_FAILED_TO_CREATE_DIR => + "Failed to create local directory (bad spark.local.dir?)" + case _ => + "Unknown executor exit code (" + exitCode + ")" + ( + if (exitCode > 128) + " (died from signal " + (exitCode - 128) + "?)" + else + "" + ) + } + } +} diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala index da39108164..642fa4b525 100644 --- a/core/src/main/scala/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/spark/network/ConnectionManager.scala @@ -304,7 +304,8 @@ private[spark] class ConnectionManager(port: Int) extends Logging { connectionRequests += newConnection newConnection } - val connection = connectionsById.getOrElse(connectionManagerId, startNewConnection()) + val lookupKey = ConnectionManagerId.fromSocketAddress(connectionManagerId.toSocketAddress) + val connection = connectionsById.getOrElse(lookupKey, startNewConnection()) message.senderAddress = id.toSocketAddress() logDebug("Sending [" + message + "] to [" + connectionManagerId + "]") /*connection.send(message)*/ diff --git a/core/src/main/scala/spark/rdd/BlockRDD.scala b/core/src/main/scala/spark/rdd/BlockRDD.scala index cb73976aed..f98528a183 100644 --- a/core/src/main/scala/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/spark/rdd/BlockRDD.scala @@ -2,11 +2,8 @@ package spark.rdd import scala.collection.mutable.HashMap -import spark.Dependency -import spark.RDD -import spark.SparkContext -import spark.SparkEnv -import spark.Split +import spark.{Dependency, RDD, SparkContext, SparkEnv, Split, TaskContext} + private[spark] class BlockRDDSplit(val blockId: String, idx: Int) extends Split { val index = idx @@ -19,29 +16,29 @@ class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[St @transient val splits_ = (0 until blockIds.size).map(i => { new BlockRDDSplit(blockIds(i), i).asInstanceOf[Split] - }).toArray - - @transient + }).toArray + + @transient lazy val locations_ = { - val blockManager = SparkEnv.get.blockManager + val blockManager = SparkEnv.get.blockManager /*val locations = blockIds.map(id => blockManager.getLocations(id))*/ - val locations = blockManager.getLocations(blockIds) + val locations = blockManager.getLocations(blockIds) HashMap(blockIds.zip(locations):_*) } override def splits = splits_ - override def compute(split: Split): Iterator[T] = { - val blockManager = SparkEnv.get.blockManager + override def compute(split: Split, context: TaskContext): Iterator[T] = { + val blockManager = SparkEnv.get.blockManager val blockId = split.asInstanceOf[BlockRDDSplit].blockId blockManager.get(blockId) match { case Some(block) => block.asInstanceOf[Iterator[T]] - case None => + case None => throw new Exception("Could not compute split, block " + blockId + " not found") } } - override def preferredLocations(split: Split) = + override def preferredLocations(split: Split) = locations_(split.asInstanceOf[BlockRDDSplit].blockId) override val dependencies: List[Dependency[_]] = Nil diff --git a/core/src/main/scala/spark/rdd/CartesianRDD.scala b/core/src/main/scala/spark/rdd/CartesianRDD.scala index 7c354b6b2e..4a7e5f3d06 100644 --- a/core/src/main/scala/spark/rdd/CartesianRDD.scala +++ 
b/core/src/main/scala/spark/rdd/CartesianRDD.scala @@ -1,9 +1,7 @@ package spark.rdd -import spark.NarrowDependency -import spark.RDD -import spark.SparkContext -import spark.Split +import spark.{NarrowDependency, RDD, SparkContext, Split, TaskContext} + private[spark] class CartesianSplit(idx: Int, val s1: Split, val s2: Split) extends Split with Serializable { @@ -17,9 +15,9 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( rdd2: RDD[U]) extends RDD[Pair[T, U]](sc) with Serializable { - + val numSplitsInRdd2 = rdd2.splits.size - + @transient val splits_ = { // create the cross product split @@ -38,11 +36,12 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2) } - override def compute(split: Split) = { + override def compute(split: Split, context: TaskContext) = { val currSplit = split.asInstanceOf[CartesianSplit] - for (x <- rdd1.iterator(currSplit.s1); y <- rdd2.iterator(currSplit.s2)) yield (x, y) + for (x <- rdd1.iterator(currSplit.s1, context); + y <- rdd2.iterator(currSplit.s2, context)) yield (x, y) } - + override val dependencies = List( new NarrowDependency(rdd1) { def getParents(id: Int): Seq[Int] = List(id / numSplitsInRdd2) diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala index ace2500627..de0d9fad88 100644 --- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala @@ -1,27 +1,17 @@ package spark.rdd -import java.net.URL -import java.io.EOFException -import java.io.ObjectInputStream - import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap -import spark.Aggregator -import spark.Dependency -import spark.Logging -import spark.OneToOneDependency -import spark.Partitioner -import spark.RDD -import spark.ShuffleDependency -import spark.SparkEnv -import spark.Split +import spark.{Aggregator, Logging, Partitioner, RDD, SparkEnv, Split, TaskContext} +import spark.{Dependency, OneToOneDependency, ShuffleDependency} + private[spark] sealed trait CoGroupSplitDep extends Serializable private[spark] case class NarrowCoGroupSplitDep(rdd: RDD[_], split: Split) extends CoGroupSplitDep private[spark] case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep -private[spark] +private[spark] class CoGroupSplit(idx: Int, val deps: Seq[CoGroupSplitDep]) extends Split with Serializable { override val index: Int = idx override def hashCode(): Int = idx @@ -36,24 +26,25 @@ private[spark] class CoGroupAggregator class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner) extends RDD[(K, Seq[Seq[_]])](rdds.head.context) with Logging { - + val aggr = new CoGroupAggregator - + @transient override val dependencies = { val deps = new ArrayBuffer[Dependency[_]] for ((rdd, index) <- rdds.zipWithIndex) { - if (rdd.partitioner == Some(part)) { - logInfo("Adding one-to-one dependency with " + rdd) - deps += new OneToOneDependency(rdd) + val mapSideCombinedRDD = rdd.mapPartitions(aggr.combineValuesByKey(_), true) + if (mapSideCombinedRDD.partitioner == Some(part)) { + logInfo("Adding one-to-one dependency with " + mapSideCombinedRDD) + deps += new OneToOneDependency(mapSideCombinedRDD) } else { logInfo("Adding shuffle dependency with " + rdd) - deps += new ShuffleDependency[Any, Any, ArrayBuffer[Any]](rdd, Some(aggr), part) + deps += new ShuffleDependency[Any, ArrayBuffer[Any]](mapSideCombinedRDD, part) } } deps.toList } - + @transient val splits_ : 
Array[Split] = { val firstRdd = rdds.head @@ -61,7 +52,7 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner) for (i <- 0 until array.size) { array(i) = new CoGroupSplit(i, rdds.zipWithIndex.map { case (r, j) => dependencies(j) match { - case s: ShuffleDependency[_, _, _] => + case s: ShuffleDependency[_, _] => new ShuffleCoGroupSplitDep(s.shuffleId): CoGroupSplitDep case _ => new NarrowCoGroupSplitDep(r, r.splits(i)): CoGroupSplitDep @@ -72,12 +63,12 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner) } override def splits = splits_ - + override val partitioner = Some(part) - + override def preferredLocations(s: Split) = Nil - - override def compute(s: Split): Iterator[(K, Seq[Seq[_]])] = { + + override def compute(s: Split, context: TaskContext): Iterator[(K, Seq[Seq[_]])] = { val split = s.asInstanceOf[CoGroupSplit] val numRdds = split.deps.size val map = new HashMap[K, Seq[ArrayBuffer[Any]]] @@ -87,19 +78,19 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner) for ((dep, depNum) <- split.deps.zipWithIndex) dep match { case NarrowCoGroupSplitDep(rdd, itsSplit) => { // Read them from the parent - for ((k, v) <- rdd.iterator(itsSplit)) { + for ((k, v) <- rdd.iterator(itsSplit, context)) { getSeq(k.asInstanceOf[K])(depNum) += v } } case ShuffleCoGroupSplitDep(shuffleId) => { // Read map outputs of shuffle - def mergePair(k: K, vs: Seq[Any]) { - val mySeq = getSeq(k) - for (v <- vs) + def mergePair(pair: (K, Seq[Any])) { + val mySeq = getSeq(pair._1) + for (v <- pair._2) mySeq(depNum) += v } val fetcher = SparkEnv.get.shuffleFetcher - fetcher.fetch[K, Seq[Any]](shuffleId, split.index, mergePair) + fetcher.fetch[K, Seq[Any]](shuffleId, split.index).foreach(mergePair) } } map.iterator diff --git a/core/src/main/scala/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/spark/rdd/CoalescedRDD.scala index 0967f4f5df..1affe0e0ef 100644 --- a/core/src/main/scala/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoalescedRDD.scala @@ -1,8 +1,7 @@ package spark.rdd -import spark.NarrowDependency -import spark.RDD -import spark.Split +import spark.{NarrowDependency, RDD, Split, TaskContext} + private class CoalescedRDDSplit(val index: Int, val parents: Array[Split]) extends Split @@ -32,9 +31,9 @@ class CoalescedRDD[T: ClassManifest](prev: RDD[T], maxPartitions: Int) override def splits = splits_ - override def compute(split: Split): Iterator[T] = { + override def compute(split: Split, context: TaskContext): Iterator[T] = { split.asInstanceOf[CoalescedRDDSplit].parents.iterator.flatMap { - parentSplit => prev.iterator(parentSplit) + parentSplit => prev.iterator(parentSplit, context) } } diff --git a/core/src/main/scala/spark/rdd/FilteredRDD.scala b/core/src/main/scala/spark/rdd/FilteredRDD.scala index dfe9dc73f3..b148da28de 100644 --- a/core/src/main/scala/spark/rdd/FilteredRDD.scala +++ b/core/src/main/scala/spark/rdd/FilteredRDD.scala @@ -1,12 +1,11 @@ package spark.rdd -import spark.OneToOneDependency -import spark.RDD -import spark.Split +import spark.{OneToOneDependency, RDD, Split, TaskContext} + private[spark] class FilteredRDD[T: ClassManifest](prev: RDD[T], f: T => Boolean) extends RDD[T](prev.context) { override def splits = prev.splits override val dependencies = List(new OneToOneDependency(prev)) - override def compute(split: Split) = prev.iterator(split).filter(f) + override def compute(split: Split, context: TaskContext) = prev.iterator(split, context).filter(f) }
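[Editor's note: every compute implementation now receives the TaskContext introduced earlier, so an RDD that holds a per-partition resource can release it through an on-completion callback instead of relying on the iterator being fully drained (HadoopRDD below uses exactly this for its record reader). A minimal custom-RDD sketch against the internal API shown in this diff; the class and its single-split file-reading behavior are illustrative:

    import spark.{Dependency, RDD, SparkContext, Split, TaskContext}

    class SingleFileRDD(sc: SparkContext, path: String) extends RDD[String](sc) {
      override def splits: Array[Split] = Array(new Split { override val index = 0 })
      override val dependencies: List[Dependency[_]] = Nil

      override def compute(split: Split, context: TaskContext): Iterator[String] = {
        val source = scala.io.Source.fromFile(path)
        // Close the file when the task finishes, even if the iterator is abandoned early.
        context.addOnCompleteCallback(() => source.close())
        source.getLines()
      }
    }
]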
\ No newline at end of file diff --git a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala index 3534dc8057..785662b2da 100644 --- a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala +++ b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala @@ -1,16 +1,16 @@ package spark.rdd -import spark.OneToOneDependency -import spark.RDD -import spark.Split +import spark.{OneToOneDependency, RDD, Split, TaskContext} private[spark] class FlatMappedRDD[U: ClassManifest, T: ClassManifest]( prev: RDD[T], f: T => TraversableOnce[U]) extends RDD[U](prev.context) { - + override def splits = prev.splits override val dependencies = List(new OneToOneDependency(prev)) - override def compute(split: Split) = prev.iterator(split).flatMap(f) + + override def compute(split: Split, context: TaskContext) = + prev.iterator(split, context).flatMap(f) } diff --git a/core/src/main/scala/spark/rdd/GlommedRDD.scala b/core/src/main/scala/spark/rdd/GlommedRDD.scala index e30564f2da..fac8ffb4cb 100644 --- a/core/src/main/scala/spark/rdd/GlommedRDD.scala +++ b/core/src/main/scala/spark/rdd/GlommedRDD.scala @@ -1,12 +1,12 @@ package spark.rdd -import spark.OneToOneDependency -import spark.RDD -import spark.Split +import spark.{OneToOneDependency, RDD, Split, TaskContext} + private[spark] class GlommedRDD[T: ClassManifest](prev: RDD[T]) extends RDD[Array[T]](prev.context) { override def splits = prev.splits override val dependencies = List(new OneToOneDependency(prev)) - override def compute(split: Split) = Array(prev.iterator(split).toArray).iterator + override def compute(split: Split, context: TaskContext) = + Array(prev.iterator(split, context).toArray).iterator }
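[Editor's note: GlommedRDD backs glom(), which collapses each partition into a single array. A quick sketch of the observable behavior, assuming a SparkContext named sc:

    val rdd = sc.parallelize(1 to 6, 2)
    rdd.glom().collect()  // Array(Array(1, 2, 3), Array(4, 5, 6)): one array per partition
]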
\ No newline at end of file diff --git a/core/src/main/scala/spark/rdd/HadoopRDD.scala b/core/src/main/scala/spark/rdd/HadoopRDD.scala index bf29a1f075..ab163f569b 100644 --- a/core/src/main/scala/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/spark/rdd/HadoopRDD.scala @@ -15,19 +15,16 @@ import org.apache.hadoop.mapred.RecordReader import org.apache.hadoop.mapred.Reporter import org.apache.hadoop.util.ReflectionUtils -import spark.Dependency -import spark.RDD -import spark.SerializableWritable -import spark.SparkContext -import spark.Split +import spark.{Dependency, RDD, SerializableWritable, SparkContext, Split, TaskContext} -/** + +/** * A Spark split class that wraps around a Hadoop InputSplit. */ private[spark] class HadoopSplit(rddId: Int, idx: Int, @transient s: InputSplit) extends Split with Serializable { - + val inputSplit = new SerializableWritable[InputSplit](s) override def hashCode(): Int = (41 * (41 + rddId) + idx).toInt @@ -47,10 +44,10 @@ class HadoopRDD[K, V]( valueClass: Class[V], minSplits: Int) extends RDD[(K, V)](sc) { - + // A Hadoop JobConf can be about 10 KB, which is pretty big, so broadcast it val confBroadcast = sc.broadcast(new SerializableWritable(conf)) - + @transient val splits_ : Array[Split] = { val inputFormat = createInputFormat(conf) @@ -69,7 +66,7 @@ class HadoopRDD[K, V]( override def splits = splits_ - override def compute(theSplit: Split) = new Iterator[(K, V)] { + override def compute(theSplit: Split, context: TaskContext) = new Iterator[(K, V)] { val split = theSplit.asInstanceOf[HadoopSplit] var reader: RecordReader[K, V] = null @@ -77,6 +74,9 @@ class HadoopRDD[K, V]( val fmt = createInputFormat(conf) reader = fmt.getRecordReader(split.inputSplit.value, conf, Reporter.NULL) + // Register an on-task-completion callback to close the input stream. + context.addOnCompleteCallback(() => reader.close()) + val key: K = reader.createKey() val value: V = reader.createValue() var gotNext = false @@ -115,6 +115,6 @@ class HadoopRDD[K, V]( val hadoopSplit = split.asInstanceOf[HadoopSplit] hadoopSplit.inputSplit.value.getLocations.filter(_ != "localhost") } - + override val dependencies: List[Dependency[_]] = Nil } diff --git a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala index b2c7a1cb9e..c764505345 100644 --- a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala +++ b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala @@ -1,16 +1,18 @@ package spark.rdd -import spark.OneToOneDependency -import spark.RDD -import spark.Split +import spark.{OneToOneDependency, RDD, Split, TaskContext} + private[spark] class MapPartitionsRDD[U: ClassManifest, T: ClassManifest]( prev: RDD[T], - f: Iterator[T] => Iterator[U]) + f: Iterator[T] => Iterator[U], + preservesPartitioning: Boolean = false) extends RDD[U](prev.context) { - + + override val partitioner = if (preservesPartitioning) prev.partitioner else None + override def splits = prev.splits override val dependencies = List(new OneToOneDependency(prev)) - override def compute(split: Split) = f(prev.iterator(split)) + override def compute(split: Split, context: TaskContext) = f(prev.iterator(split, context)) }
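[Editor's note: the new preservesPartitioning flag on MapPartitionsRDD lets a per-partition transformation promise that it leaves keys, and therefore the parent's partitioner, intact; that is what allows CoGroupedRDD's map-side combine above to keep a one-to-one dependency instead of forcing a shuffle. A sketch, assuming pairs is an RDD[(String, Int)] already partitioned by key:

    // Values change but keys do not, so it is safe to keep the partitioner;
    // the second argument is the new preservesPartitioning flag.
    val doubled = pairs.mapPartitions(
      iter => iter.map { case (k, v) => (k, v * 2) },
      true)
]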
\ No newline at end of file diff --git a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala index adc541694e..3d9888bd34 100644 --- a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala +++ b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala @@ -1,8 +1,6 @@ package spark.rdd -import spark.OneToOneDependency -import spark.RDD -import spark.Split +import spark.{OneToOneDependency, RDD, Split, TaskContext} /** * A variant of the MapPartitionsRDD that passes the split index into the @@ -12,10 +10,13 @@ import spark.Split private[spark] class MapPartitionsWithSplitRDD[U: ClassManifest, T: ClassManifest]( prev: RDD[T], - f: (Int, Iterator[T]) => Iterator[U]) + f: (Int, Iterator[T]) => Iterator[U], + preservesPartitioning: Boolean) extends RDD[U](prev.context) { + override val partitioner = if (preservesPartitioning) prev.partitioner else None override def splits = prev.splits override val dependencies = List(new OneToOneDependency(prev)) - override def compute(split: Split) = f(split.index, prev.iterator(split)) + override def compute(split: Split, context: TaskContext) = + f(split.index, prev.iterator(split, context)) }
\ No newline at end of file diff --git a/core/src/main/scala/spark/rdd/MappedRDD.scala b/core/src/main/scala/spark/rdd/MappedRDD.scala index 59bedad8ef..70fa8f4497 100644 --- a/core/src/main/scala/spark/rdd/MappedRDD.scala +++ b/core/src/main/scala/spark/rdd/MappedRDD.scala @@ -1,16 +1,14 @@ package spark.rdd -import spark.OneToOneDependency -import spark.RDD -import spark.Split +import spark.{OneToOneDependency, RDD, Split, TaskContext} private[spark] class MappedRDD[U: ClassManifest, T: ClassManifest]( prev: RDD[T], f: T => U) extends RDD[U](prev.context) { - + override def splits = prev.splits override val dependencies = List(new OneToOneDependency(prev)) - override def compute(split: Split) = prev.iterator(split).map(f) + override def compute(split: Split, context: TaskContext) = prev.iterator(split, context).map(f) }
\ No newline at end of file diff --git a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala index dcbceab246..197ed5ea17 100644 --- a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala @@ -1,28 +1,19 @@ package spark.rdd +import java.text.SimpleDateFormat +import java.util.Date + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io.Writable -import org.apache.hadoop.mapreduce.InputFormat -import org.apache.hadoop.mapreduce.InputSplit -import org.apache.hadoop.mapreduce.JobContext -import org.apache.hadoop.mapreduce.JobID -import org.apache.hadoop.mapreduce.RecordReader -import org.apache.hadoop.mapreduce.TaskAttemptContext -import org.apache.hadoop.mapreduce.TaskAttemptID +import org.apache.hadoop.mapreduce._ -import java.util.Date -import java.text.SimpleDateFormat +import spark.{Dependency, RDD, SerializableWritable, SparkContext, Split, TaskContext} -import spark.Dependency -import spark.RDD -import spark.SerializableWritable -import spark.SparkContext -import spark.Split -private[spark] +private[spark] class NewHadoopSplit(rddId: Int, val index: Int, @transient rawSplit: InputSplit with Writable) extends Split { - + val serializableHadoopSplit = new SerializableWritable(rawSplit) override def hashCode(): Int = (41 * (41 + rddId) + index) @@ -33,8 +24,9 @@ class NewHadoopRDD[K, V]( inputFormatClass: Class[_ <: InputFormat[K, V]], keyClass: Class[K], valueClass: Class[V], @transient conf: Configuration) - extends RDD[(K, V)](sc) { - + extends RDD[(K, V)](sc) + with HadoopMapReduceUtil { + // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it val confBroadcast = sc.broadcast(new SerializableWritable(conf)) // private val serializableConf = new SerializableWritable(conf) @@ -50,7 +42,7 @@ class NewHadoopRDD[K, V]( @transient private val splits_ : Array[Split] = { val inputFormat = inputFormatClass.newInstance - val jobContext = new JobContext(conf, jobId) + val jobContext = newJobContext(conf, jobId) val rawSplits = inputFormat.getSplits(jobContext).toArray val result = new Array[Split](rawSplits.size) for (i <- 0 until rawSplits.size) { @@ -61,15 +53,19 @@ class NewHadoopRDD[K, V]( override def splits = splits_ - override def compute(theSplit: Split) = new Iterator[(K, V)] { + override def compute(theSplit: Split, context: TaskContext) = new Iterator[(K, V)] { val split = theSplit.asInstanceOf[NewHadoopSplit] val conf = confBroadcast.value.value val attemptId = new TaskAttemptID(jobtrackerId, id, true, split.index, 0) - val context = new TaskAttemptContext(conf, attemptId) + val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId) val format = inputFormatClass.newInstance - val reader = format.createRecordReader(split.serializableHadoopSplit.value, context) - reader.initialize(split.serializableHadoopSplit.value, context) - + val reader = format.createRecordReader( + split.serializableHadoopSplit.value, hadoopAttemptContext) + reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) + + // Register an on-task-completion callback to close the input stream. 
+ context.addOnCompleteCallback(() => reader.close()) + var havePair = false var finished = false @@ -77,9 +73,6 @@ class NewHadoopRDD[K, V]( if (!finished && !havePair) { finished = !reader.nextKeyValue havePair = !finished - if (finished) { - reader.close() - } } !finished } diff --git a/core/src/main/scala/spark/rdd/PipedRDD.scala b/core/src/main/scala/spark/rdd/PipedRDD.scala index 98ea0c92d6..336e193217 100644 --- a/core/src/main/scala/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/spark/rdd/PipedRDD.scala @@ -8,10 +8,7 @@ import scala.collection.JavaConversions._ import scala.collection.mutable.ArrayBuffer import scala.io.Source -import spark.OneToOneDependency -import spark.RDD -import spark.SparkEnv -import spark.Split +import spark.{OneToOneDependency, RDD, SparkEnv, Split, TaskContext} /** @@ -32,12 +29,12 @@ class PipedRDD[T: ClassManifest]( override val dependencies = List(new OneToOneDependency(parent)) - override def compute(split: Split): Iterator[String] = { + override def compute(split: Split, context: TaskContext): Iterator[String] = { val pb = new ProcessBuilder(command) // Add the environmental variables to the process. val currentEnvVars = pb.environment() envVars.foreach { case (variable, value) => currentEnvVars.put(variable, value) } - + val proc = pb.start() val env = SparkEnv.get @@ -55,7 +52,7 @@ class PipedRDD[T: ClassManifest]( override def run() { SparkEnv.set(env) val out = new PrintWriter(proc.getOutputStream) - for (elem <- parent.iterator(split)) { + for (elem <- parent.iterator(split, context)) { out.println(elem) } out.close() diff --git a/core/src/main/scala/spark/rdd/SampledRDD.scala b/core/src/main/scala/spark/rdd/SampledRDD.scala index 87a5268f27..6e4797aabb 100644 --- a/core/src/main/scala/spark/rdd/SampledRDD.scala +++ b/core/src/main/scala/spark/rdd/SampledRDD.scala @@ -4,9 +4,8 @@ import java.util.Random import cern.jet.random.Poisson import cern.jet.random.engine.DRand -import spark.RDD -import spark.OneToOneDependency -import spark.Split +import spark.{OneToOneDependency, RDD, Split, TaskContext} + private[spark] class SampledRDDSplit(val prev: Split, val seed: Int) extends Split with Serializable { @@ -15,7 +14,7 @@ class SampledRDDSplit(val prev: Split, val seed: Int) extends Split with Seriali class SampledRDD[T: ClassManifest]( prev: RDD[T], - withReplacement: Boolean, + withReplacement: Boolean, frac: Double, seed: Int) extends RDD[T](prev.context) { @@ -29,17 +28,17 @@ class SampledRDD[T: ClassManifest]( override def splits = splits_.asInstanceOf[Array[Split]] override val dependencies = List(new OneToOneDependency(prev)) - + override def preferredLocations(split: Split) = prev.preferredLocations(split.asInstanceOf[SampledRDDSplit].prev) - override def compute(splitIn: Split) = { + override def compute(splitIn: Split, context: TaskContext) = { val split = splitIn.asInstanceOf[SampledRDDSplit] if (withReplacement) { // For large datasets, the expected number of occurrences of each element in a sample with // replacement is Poisson(frac). We use that to get a count for each element. 
val poisson = new Poisson(frac, new DRand(split.seed)) - prev.iterator(split.prev).flatMap { element => + prev.iterator(split.prev, context).flatMap { element => val count = poisson.nextInt() if (count == 0) { Iterator.empty // Avoid object allocation when we return 0 items, which is quite often @@ -49,7 +48,7 @@ class SampledRDD[T: ClassManifest]( } } else { // Sampling without replacement val rand = new Random(split.seed) - prev.iterator(split.prev).filter(x => (rand.nextDouble <= frac)) + prev.iterator(split.prev, context).filter(x => (rand.nextDouble <= frac)) } } } diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala index be120acc71..f832633646 100644 --- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala @@ -1,30 +1,23 @@ package spark.rdd -import scala.collection.mutable.ArrayBuffer -import java.util.{HashMap => JHashMap} +import spark.{OneToOneDependency, Partitioner, RDD, SparkEnv, ShuffleDependency, Split, TaskContext} -import spark.Aggregator -import spark.Partitioner -import spark.RangePartitioner -import spark.RDD -import spark.ShuffleDependency -import spark.SparkEnv -import spark.Split private[spark] class ShuffledRDDSplit(val idx: Int) extends Split { override val index = idx override def hashCode(): Int = idx } - /** * The resulting RDD from a shuffle (e.g. repartitioning of data). + * @param parent the parent RDD. + * @param part the partitioner used to partition the RDD + * @tparam K the key class. + * @tparam V the value class. */ -abstract class ShuffledRDD[K, V, C]( +class ShuffledRDD[K, V]( @transient parent: RDD[(K, V)], - aggregator: Option[Aggregator[K, V, C]], - part: Partitioner) - extends RDD[(K, C)](parent.context) { + part: Partitioner) extends RDD[(K, V)](parent.context) { override val partitioner = Some(part) @@ -35,108 +28,10 @@ abstract class ShuffledRDD[K, V, C]( override def preferredLocations(split: Split) = Nil - val dep = new ShuffleDependency(parent, aggregator, part) + val dep = new ShuffleDependency(parent, part) override val dependencies = List(dep) -} - - -/** - * Repartition a key-value pair RDD. - */ -class RepartitionShuffledRDD[K, V]( - @transient parent: RDD[(K, V)], - part: Partitioner) - extends ShuffledRDD[K, V, V]( - parent, - None, - part) { - - override def compute(split: Split): Iterator[(K, V)] = { - val buf = new ArrayBuffer[(K, V)] - val fetcher = SparkEnv.get.shuffleFetcher - def addTupleToBuffer(k: K, v: V) = { buf += Tuple(k, v) } - fetcher.fetch[K, V](dep.shuffleId, split.index, addTupleToBuffer) - buf.iterator - } -} - - -/** - * A sort-based shuffle (that doesn't apply aggregation). It does so by first - * repartitioning the RDD by range, and then sort within each range. - */ -class ShuffledSortedRDD[K <% Ordered[K]: ClassManifest, V]( - @transient parent: RDD[(K, V)], - ascending: Boolean, - numSplits: Int) - extends RepartitionShuffledRDD[K, V]( - parent, - new RangePartitioner(numSplits, parent, ascending)) { - - override def compute(split: Split): Iterator[(K, V)] = { - // By separating this from RepartitionShuffledRDD, we avoided a - // buf.iterator.toArray call, thus avoiding building up the buffer twice. 
- val buf = new ArrayBuffer[(K, V)] - def addTupleToBuffer(k: K, v: V) { buf += ((k, v)) } - SparkEnv.get.shuffleFetcher.fetch[K, V](dep.shuffleId, split.index, addTupleToBuffer) - if (ascending) { - buf.sortWith((x, y) => x._1 < y._1).iterator - } else { - buf.sortWith((x, y) => x._1 > y._1).iterator - } - } -} - - -/** - * The resulting RDD from shuffle and running (hash-based) aggregation. - */ -class ShuffledAggregatedRDD[K, V, C]( - @transient parent: RDD[(K, V)], - aggregator: Aggregator[K, V, C], - part : Partitioner) - extends ShuffledRDD[K, V, C](parent, Some(aggregator), part) { - - override def compute(split: Split): Iterator[(K, C)] = { - val combiners = new JHashMap[K, C] - val fetcher = SparkEnv.get.shuffleFetcher - - if (aggregator.mapSideCombine) { - // Apply combiners on map partitions. In this case, post-shuffle we get a - // list of outputs from the combiners and merge them using mergeCombiners. - def mergePairWithMapSideCombiners(k: K, c: C) { - val oldC = combiners.get(k) - if (oldC == null) { - combiners.put(k, c) - } else { - combiners.put(k, aggregator.mergeCombiners(oldC, c)) - } - } - fetcher.fetch[K, C](dep.shuffleId, split.index, mergePairWithMapSideCombiners) - } else { - // Do not apply combiners on map partitions (i.e. map side aggregation is - // turned off). Post-shuffle we get a list of values and we use mergeValue - // to merge them. - def mergePairWithoutMapSideCombiners(k: K, v: V) { - val oldC = combiners.get(k) - if (oldC == null) { - combiners.put(k, aggregator.createCombiner(v)) - } else { - combiners.put(k, aggregator.mergeValue(oldC, v)) - } - } - fetcher.fetch[K, V](dep.shuffleId, split.index, mergePairWithoutMapSideCombiners) - } - - return new Iterator[(K, C)] { - var iter = combiners.entrySet().iterator() - - def hasNext: Boolean = iter.hasNext() - def next(): (K, C) = { - val entry = iter.next() - (entry.getKey, entry.getValue) - } - } + override def compute(split: Split, context: TaskContext): Iterator[(K, V)] = { + SparkEnv.get.shuffleFetcher.fetch[K, V](dep.shuffleId, split.index) } } diff --git a/core/src/main/scala/spark/rdd/UnionRDD.scala b/core/src/main/scala/spark/rdd/UnionRDD.scala index f0b9225f7c..a08473f7be 100644 --- a/core/src/main/scala/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/spark/rdd/UnionRDD.scala @@ -2,20 +2,17 @@ package spark.rdd import scala.collection.mutable.ArrayBuffer -import spark.Dependency -import spark.RangeDependency -import spark.RDD -import spark.SparkContext -import spark.Split +import spark.{Dependency, RangeDependency, RDD, SparkContext, Split, TaskContext} + private[spark] class UnionSplit[T: ClassManifest]( - idx: Int, + idx: Int, rdd: RDD[T], split: Split) extends Split with Serializable { - - def iterator() = rdd.iterator(split) + + def iterator(context: TaskContext) = rdd.iterator(split, context) def preferredLocations() = rdd.preferredLocations(split) override val index: Int = idx } @@ -25,7 +22,7 @@ class UnionRDD[T: ClassManifest]( @transient rdds: Seq[RDD[T]]) extends RDD[T](sc) with Serializable { - + @transient val splits_ : Array[Split] = { val array = new Array[Split](rdds.map(_.splits.size).sum) @@ -44,13 +41,14 @@ class UnionRDD[T: ClassManifest]( val deps = new ArrayBuffer[Dependency[_]] var pos = 0 for (rdd <- rdds) { - deps += new RangeDependency(rdd, 0, pos, rdd.splits.size) + deps += new RangeDependency(rdd, 0, pos, rdd.splits.size) pos += rdd.splits.size } deps.toList } - - override def compute(s: Split): Iterator[T] = s.asInstanceOf[UnionSplit[T]].iterator() + + override def 
compute(s: Split, context: TaskContext): Iterator[T] = + s.asInstanceOf[UnionSplit[T]].iterator(context) override def preferredLocations(s: Split): Seq[String] = s.asInstanceOf[UnionSplit[T]].preferredLocations() diff --git a/core/src/main/scala/spark/rdd/ZippedRDD.scala b/core/src/main/scala/spark/rdd/ZippedRDD.scala new file mode 100644 index 0000000000..92d667ff1e --- /dev/null +++ b/core/src/main/scala/spark/rdd/ZippedRDD.scala @@ -0,0 +1,53 @@ +package spark.rdd + +import spark.{OneToOneDependency, RDD, SparkContext, Split, TaskContext} + + +private[spark] class ZippedSplit[T: ClassManifest, U: ClassManifest]( + idx: Int, + rdd1: RDD[T], + rdd2: RDD[U], + split1: Split, + split2: Split) + extends Split + with Serializable { + + def iterator(context: TaskContext): Iterator[(T, U)] = + rdd1.iterator(split1, context).zip(rdd2.iterator(split2, context)) + + def preferredLocations(): Seq[String] = + rdd1.preferredLocations(split1).intersect(rdd2.preferredLocations(split2)) + + override val index: Int = idx +} + +class ZippedRDD[T: ClassManifest, U: ClassManifest]( + sc: SparkContext, + @transient rdd1: RDD[T], + @transient rdd2: RDD[U]) + extends RDD[(T, U)](sc) + with Serializable { + + @transient + val splits_ : Array[Split] = { + if (rdd1.splits.size != rdd2.splits.size) { + throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions") + } + val array = new Array[Split](rdd1.splits.size) + for (i <- 0 until rdd1.splits.size) { + array(i) = new ZippedSplit(i, rdd1, rdd2, rdd1.splits(i), rdd2.splits(i)) + } + array + } + + override def splits = splits_ + + @transient + override val dependencies = List(new OneToOneDependency(rdd1), new OneToOneDependency(rdd2)) + + override def compute(s: Split, context: TaskContext): Iterator[(T, U)] = + s.asInstanceOf[ZippedSplit[T, U]].iterator(context) + + override def preferredLocations(s: Split): Seq[String] = + s.asInstanceOf[ZippedSplit[T, U]].preferredLocations() +} diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 6f4c6bffd7..29757b1178 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -16,8 +16,8 @@ import spark.storage.BlockManagerMaster import spark.storage.BlockManagerId /** - * A Scheduler subclass that implements stage-oriented scheduling. It computes a DAG of stages for - * each job, keeps track of which RDDs and stage outputs are materialized, and computes a minimal + * A Scheduler subclass that implements stage-oriented scheduling. It computes a DAG of stages for + * each job, keeps track of which RDDs and stage outputs are materialized, and computes a minimal * schedule to run the job. Subclasses only need to implement the code to send a task to the cluster * and to report fetch failures (the submitTasks method, and code to add CompletionEvents). 
*/ @@ -73,7 +73,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val deadHosts = new HashSet[String] // TODO: The code currently assumes these can't come back; // that's not going to be a realistic assumption in general - + val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done val running = new HashSet[Stage] // Stages we are running right now val failed = new HashSet[Stage] // Stages that must be resubmitted due to fetch failures @@ -94,7 +94,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with def getCacheLocs(rdd: RDD[_]): Array[List[String]] = { cacheLocs(rdd.id) } - + def updateCacheLocs() { cacheLocs = cacheTracker.getLocationsSnapshot() } @@ -104,7 +104,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with * The priority value passed in will be used if the stage doesn't already exist with * a lower priority (we assume that priorities always increase across jobs for now). */ - def getShuffleMapStage(shuffleDep: ShuffleDependency[_,_,_], priority: Int): Stage = { + def getShuffleMapStage(shuffleDep: ShuffleDependency[_,_], priority: Int): Stage = { shuffleToMapStage.get(shuffleDep.shuffleId) match { case Some(stage) => stage case None => @@ -119,7 +119,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with * as a result stage for the final RDD used directly in an action. The stage will also be given * the provided priority. */ - def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_,_]], priority: Int): Stage = { + def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_]], priority: Int): Stage = { // Kind of ugly: need to register RDDs with the cache and map output tracker here // since we can't do it in the RDD constructor because # of splits is unknown logInfo("Registering RDD " + rdd.id + " (" + rdd.origin + ")") @@ -149,7 +149,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with cacheTracker.registerRDD(r.id, r.splits.size) for (dep <- r.dependencies) { dep match { - case shufDep: ShuffleDependency[_,_,_] => + case shufDep: ShuffleDependency[_,_] => parents += getShuffleMapStage(shufDep, priority) case _ => visit(dep.rdd) @@ -172,7 +172,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with if (locs(p) == Nil) { for (dep <- rdd.dependencies) { dep match { - case shufDep: ShuffleDependency[_,_,_] => + case shufDep: ShuffleDependency[_,_] => val mapStage = getShuffleMapStage(shufDep, stage.priority) if (!mapStage.isAvailable) { missing += mapStage @@ -326,7 +326,8 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val rdd = job.finalStage.rdd val split = rdd.splits(job.partitions(0)) val taskContext = new TaskContext(job.finalStage.id, job.partitions(0), 0) - val result = job.func(taskContext, rdd.iterator(split)) + val result = job.func(taskContext, rdd.iterator(split, taskContext)) + taskContext.executeOnCompleteCallbacks() job.listener.taskSucceeded(0, result) } catch { case e: Exception => @@ -353,7 +354,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with } } } - + def submitMissingTasks(stage: Stage) { logDebug("submitMissingTasks(" + stage + ")") // Get our pending tasks and remember them in our pendingTasks entry @@ -395,7 +396,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val task = event.task val stage = idToStage(task.stageId) 
event.reason match { - case Success => + case Success => logInfo("Completed " + task) if (event.accumUpdates != null) { Accumulators.add(event.accumUpdates) // TODO: do this only if task wasn't resubmitted @@ -479,8 +480,10 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with ") for resubmision due to a fetch failure") // Mark the map whose fetch failed as broken in the map stage val mapStage = shuffleToMapStage(shuffleId) - mapStage.removeOutputLoc(mapId, bmAddress) - mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress) + if (mapId != -1) { + mapStage.removeOutputLoc(mapId, bmAddress) + mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress) + } logInfo("The failed fetch was from " + mapStage + " (" + mapStage.origin + "); marking it for resubmission") failed += mapStage @@ -517,7 +520,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with updateCacheLocs() } } - + /** * Aborts all jobs depending on a particular Stage. This is called in response to a task set * being cancelled by the TaskScheduler. Use taskSetFailed() to inject this event from outside. @@ -549,7 +552,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with visitedRdds += rdd for (dep <- rdd.dependencies) { dep match { - case shufDep: ShuffleDependency[_,_,_] => + case shufDep: ShuffleDependency[_,_] => val mapStage = getShuffleMapStage(shufDep, stage.priority) if (!mapStage.isAvailable) { visitedStages += mapStage diff --git a/core/src/main/scala/spark/scheduler/ResultTask.scala b/core/src/main/scala/spark/scheduler/ResultTask.scala index 2ebd4075a2..e492279b4e 100644 --- a/core/src/main/scala/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/spark/scheduler/ResultTask.scala @@ -10,12 +10,14 @@ private[spark] class ResultTask[T, U]( @transient locs: Seq[String], val outputId: Int) extends Task[U](stageId) { - + val split = rdd.splits(partition) override def run(attemptId: Long): U = { val context = new TaskContext(stageId, partition, attemptId) - func(context, rdd.iterator(split)) + val result = func(context, rdd.iterator(split, context)) + context.executeOnCompleteCallbacks() + result } override def preferredLocations: Seq[String] = locs diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala index 86796d3677..bd1911fce2 100644 --- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala @@ -22,7 +22,7 @@ private[spark] object ShuffleMapTask { // expensive on the master node if it needs to launch thousands of tasks. 
val serializedInfoCache = new JHashMap[Int, Array[Byte]] - def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_,_]): Array[Byte] = { + def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_]): Array[Byte] = { synchronized { val old = serializedInfoCache.get(stageId) if (old != null) { @@ -41,14 +41,14 @@ private[spark] object ShuffleMapTask { } } - def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], ShuffleDependency[_,_,_]) = { + def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], ShuffleDependency[_,_]) = { synchronized { val loader = Thread.currentThread.getContextClassLoader val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) val ser = SparkEnv.get.closureSerializer.newInstance val objIn = ser.deserializeStream(in) val rdd = objIn.readObject().asInstanceOf[RDD[_]] - val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_,_,_]] + val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_,_]] return (rdd, dep) } } @@ -70,19 +70,19 @@ private[spark] object ShuffleMapTask { private[spark] class ShuffleMapTask( stageId: Int, - var rdd: RDD[_], - var dep: ShuffleDependency[_,_,_], - var partition: Int, + var rdd: RDD[_], + var dep: ShuffleDependency[_,_], + var partition: Int, @transient var locs: Seq[String]) extends Task[MapStatus](stageId) with Externalizable with Logging { def this() = this(0, null, null, 0, null) - + var split = if (rdd == null) { - null - } else { + null + } else { rdd.splits(partition) } @@ -113,33 +113,16 @@ private[spark] class ShuffleMapTask( val numOutputSplits = dep.partitioner.numPartitions val partitioner = dep.partitioner - val bucketIterators = - if (dep.aggregator.isDefined && dep.aggregator.get.mapSideCombine) { - val aggregator = dep.aggregator.get.asInstanceOf[Aggregator[Any, Any, Any]] - // Apply combiners (map-side aggregation) to the map output. - val buckets = Array.tabulate(numOutputSplits)(_ => new JHashMap[Any, Any]) - for (elem <- rdd.iterator(split)) { - val (k, v) = elem.asInstanceOf[(Any, Any)] - val bucketId = partitioner.getPartition(k) - val bucket = buckets(bucketId) - val existing = bucket.get(k) - if (existing == null) { - bucket.put(k, aggregator.createCombiner(v)) - } else { - bucket.put(k, aggregator.mergeValue(existing, v)) - } - } - buckets.map(_.iterator) - } else { - // No combiners (no map-side aggregation). Simply partition the map output. - val buckets = Array.fill(numOutputSplits)(new ArrayBuffer[(Any, Any)]) - for (elem <- rdd.iterator(split)) { - val pair = elem.asInstanceOf[(Any, Any)] - val bucketId = partitioner.getPartition(pair._1) - buckets(bucketId) += pair - } - buckets.map(_.iterator) - } + val taskContext = new TaskContext(stageId, partition, attemptId) + + // Partition the map output. + val buckets = Array.fill(numOutputSplits)(new ArrayBuffer[(Any, Any)]) + for (elem <- rdd.iterator(split, taskContext)) { + val pair = elem.asInstanceOf[(Any, Any)] + val bucketId = partitioner.getPartition(pair._1) + buckets(bucketId) += pair + } + val bucketIterators = buckets.map(_.iterator) val compressedSizes = new Array[Byte](numOutputSplits) @@ -152,6 +135,9 @@ private[spark] class ShuffleMapTask( compressedSizes(i) = MapOutputTracker.compressSize(size) } + // Execute the callbacks on task completion. 
+ taskContext.executeOnCompleteCallbacks() + return new MapStatus(blockManager.blockManagerId, compressedSizes) } diff --git a/core/src/main/scala/spark/scheduler/Stage.scala b/core/src/main/scala/spark/scheduler/Stage.scala index 1149c00a23..4846b66729 100644 --- a/core/src/main/scala/spark/scheduler/Stage.scala +++ b/core/src/main/scala/spark/scheduler/Stage.scala @@ -22,7 +22,7 @@ import spark.storage.BlockManagerId private[spark] class Stage( val id: Int, val rdd: RDD[_], - val shuffleDep: Option[ShuffleDependency[_,_,_]], // Output shuffle if stage is a map stage + val shuffleDep: Option[ShuffleDependency[_,_]], // Output shuffle if stage is a map stage val parents: List[Stage], val priority: Int) extends Logging { diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala index f5e852d203..20f6e65020 100644 --- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala @@ -249,15 +249,22 @@ private[spark] class ClusterScheduler(val sc: SparkContext) } } - def slaveLost(slaveId: String) { + def slaveLost(slaveId: String, reason: ExecutorLossReason) { var failedHost: Option[String] = None synchronized { val host = slaveIdToHost(slaveId) if (hostsAlive.contains(host)) { + logError("Lost an executor on " + host + ": " + reason) slaveIdsWithExecutors -= slaveId hostsAlive -= host activeTaskSetsQueue.foreach(_.hostLost(host)) failedHost = Some(host) + } else { + // We may get multiple slaveLost() calls with different loss reasons. For example, one + // may be triggered by a dropped connection from the slave while another may be a report + // of executor termination from Mesos. We produce log messages for both so we eventually + // report the termination reason. + logError("Lost an executor on " + host + " (already removed): " + reason) } } if (failedHost != None) { diff --git a/core/src/main/scala/spark/scheduler/cluster/ExecutorLossReason.scala b/core/src/main/scala/spark/scheduler/cluster/ExecutorLossReason.scala new file mode 100644 index 0000000000..bba7de6a65 --- /dev/null +++ b/core/src/main/scala/spark/scheduler/cluster/ExecutorLossReason.scala @@ -0,0 +1,21 @@ +package spark.scheduler.cluster + +import spark.executor.ExecutorExitCode + +/** + * Represents an explanation for an executor or whole slave failing or exiting. 
+ */ +private[spark] +class ExecutorLossReason(val message: String) { + override def toString: String = message +} + +private[spark] +case class ExecutorExited(val exitCode: Int) + extends ExecutorLossReason(ExecutorExitCode.explainExitCode(exitCode)) { +} + +private[spark] +case class SlaveLost(_message: String = "Slave lost") + extends ExecutorLossReason(_message) { +} diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index 7aba7324ab..e2301347e5 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -19,6 +19,7 @@ private[spark] class SparkDeploySchedulerBackend( var shutdownCallback : (SparkDeploySchedulerBackend) => Unit = _ val maxCores = System.getProperty("spark.cores.max", Int.MaxValue.toString).toInt + val executorIdToSlaveId = new HashMap[String, String] // Memory used by each executor (in megabytes) val executorMemory = { @@ -65,9 +66,23 @@ private[spark] class SparkDeploySchedulerBackend( } def executorAdded(id: String, workerId: String, host: String, cores: Int, memory: Int) { + executorIdToSlaveId += id -> workerId logInfo("Granted executor ID %s on host %s with %d cores, %s RAM".format( id, host, cores, Utils.memoryMegabytesToString(memory))) } - def executorRemoved(id: String, message: String) {} + def executorRemoved(id: String, message: String, exitStatus: Option[Int]) { + val reason: ExecutorLossReason = exitStatus match { + case Some(code) => ExecutorExited(code) + case None => SlaveLost(message) + } + logInfo("Executor %s removed: %s".format(id, message)) + executorIdToSlaveId.get(id) match { + case Some(slaveId) => + executorIdToSlaveId.remove(id) + scheduler.slaveLost(slaveId, reason) + case None => + logInfo("No slave ID known for executor %s".format(id)) + } + } } diff --git a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index d2cce0dc05..eeaae23dc8 100644 --- a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -69,13 +69,13 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor context.stop(self) case Terminated(actor) => - actorToSlaveId.get(actor).foreach(removeSlave) + actorToSlaveId.get(actor).foreach(removeSlave(_, "Akka actor terminated")) case RemoteClientDisconnected(transport, address) => - addressToSlaveId.get(address).foreach(removeSlave) + addressToSlaveId.get(address).foreach(removeSlave(_, "remote Akka client disconnected")) case RemoteClientShutdown(transport, address) => - addressToSlaveId.get(address).foreach(removeSlave) + addressToSlaveId.get(address).foreach(removeSlave(_, "remote Akka client shutdown")) } // Make fake resource offers on all slaves @@ -99,7 +99,7 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor } // Remove a disconnected slave from the cluster - def removeSlave(slaveId: String) { + def removeSlave(slaveId: String, reason: String) { logInfo("Slave " + slaveId + " disconnected, so removing it") val numCores = freeCores(slaveId) actorToSlaveId -= slaveActor(slaveId) @@ -109,7 +109,7 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor freeCores -= slaveId slaveHost -= slaveId 
totalCoreCount.addAndGet(-numCores) - scheduler.slaveLost(slaveId) + scheduler.slaveLost(slaveId, SlaveLost(reason)) } } diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala index b84b4dc2ed..2593c0e3a0 100644 --- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala @@ -30,12 +30,12 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon val currentJars: HashMap[String, Long] = new HashMap[String, Long]() val classLoader = new ExecutorURLClassLoader(Array(), Thread.currentThread.getContextClassLoader) - + // TODO: Need to take into account stage priority in scheduling override def start() { } - override def setListener(listener: TaskSchedulerListener) { + override def setListener(listener: TaskSchedulerListener) { this.listener = listener } @@ -78,7 +78,8 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon // on in development (so when users move their local Spark programs // to the cluster, they don't get surprised by serialization errors). val resultToReturn = ser.deserialize[Any](ser.serialize(result)) - val accumUpdates = Accumulators.values + val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]]( + ser.serialize(Accumulators.values)) logInfo("Finished task " + idInJob) listener.taskEnded(task, Success, resultToReturn, accumUpdates) } catch { @@ -107,26 +108,28 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon * SparkContext. Also adds any new JARs we fetched to the class loader. */ private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) { - // Fetch missing dependencies - for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) { - logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(".")) - currentFiles(name) = timestamp - } - for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) { - logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(".")) - currentJars(name) = timestamp - // Add it to our class loader - val localName = name.split("/").last - val url = new File(".", localName).toURI.toURL - if (!classLoader.getURLs.contains(url)) { - logInfo("Adding " + url + " to class loader") - classLoader.addURL(url) + synchronized { + // Fetch missing dependencies + for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) { + logInfo("Fetching " + name + " with timestamp " + timestamp) + Utils.fetchFile(name, new File(".")) + currentFiles(name) = timestamp + } + for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) { + logInfo("Fetching " + name + " with timestamp " + timestamp) + Utils.fetchFile(name, new File(".")) + currentJars(name) = timestamp + // Add it to our class loader + val localName = name.split("/").last + val url = new File(".", localName).toURI.toURL + if (!classLoader.getURLs.contains(url)) { + logInfo("Adding " + url + " to class loader") + classLoader.addURL(url) + } } } } - + override def stop() { threadPool.shutdownNow() } diff --git a/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala index cdfe1f2563..8c7a1dfbc0 100644 --- 
a/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala @@ -267,17 +267,23 @@ private[spark] class MesosSchedulerBackend( override def frameworkMessage(d: SchedulerDriver, e: ExecutorID, s: SlaveID, b: Array[Byte]) {} - override def slaveLost(d: SchedulerDriver, slaveId: SlaveID) { + private def recordSlaveLost(d: SchedulerDriver, slaveId: SlaveID, reason: ExecutorLossReason) { logInfo("Mesos slave lost: " + slaveId.getValue) synchronized { slaveIdsWithExecutors -= slaveId.getValue } - scheduler.slaveLost(slaveId.toString) + scheduler.slaveLost(slaveId.getValue, reason) + } + + override def slaveLost(d: SchedulerDriver, slaveId: SlaveID) { + recordSlaveLost(d, slaveId, SlaveLost()) } - override def executorLost(d: SchedulerDriver, e: ExecutorID, s: SlaveID, status: Int) { - logInfo("Executor lost: %s, marking slave %s as lost".format(e.getValue, s.getValue)) - slaveLost(d, s) + override def executorLost(d: SchedulerDriver, executorId: ExecutorID, + slaveId: SlaveID, status: Int) { + logInfo("Executor lost: %s, marking slave %s as lost".format(executorId.getValue, + slaveId.getValue)) + recordSlaveLost(d, slaveId, ExecutorExited(status)) } // TODO: query Mesos for number of cores diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index bd9155ef29..7a8ac10cdd 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -1,67 +1,39 @@ package spark.storage -import akka.dispatch.{Await, Future} -import akka.util.Duration - -import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream - -import java.io.{InputStream, OutputStream, Externalizable, ObjectInput, ObjectOutput} -import java.nio.{MappedByteBuffer, ByteBuffer} +import java.io.{InputStream, OutputStream} +import java.nio.{ByteBuffer, MappedByteBuffer} import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue} import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue} import scala.collection.JavaConversions._ -import spark.{CacheTracker, Logging, SizeEstimator, SparkException, Utils} -import spark.network._ -import spark.serializer.Serializer -import spark.util.ByteBufferInputStream -import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream} -import sun.nio.ch.DirectBuffer - - -private[spark] class BlockManagerId(var ip: String, var port: Int) extends Externalizable { - def this() = this(null, 0) // For deserialization only - - def this(in: ObjectInput) = this(in.readUTF(), in.readInt()) - - override def writeExternal(out: ObjectOutput) { - out.writeUTF(ip) - out.writeInt(port) - } +import akka.actor.{ActorSystem, Cancellable, Props} +import akka.dispatch.{Await, Future} +import akka.util.Duration +import akka.util.duration._ - override def readExternal(in: ObjectInput) { - ip = in.readUTF() - port = in.readInt() - } +import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream} - override def toString = "BlockManagerId(" + ip + ", " + port + ")" +import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream - override def hashCode = ip.hashCode * 41 + port +import spark.{CacheTracker, Logging, SizeEstimator, SparkEnv, SparkException, Utils} +import spark.network._ +import spark.serializer.Serializer +import spark.util.{ByteBufferInputStream, IdGenerator, MetadataCleaner, TimeStampedHashMap} - override def equals(that: Any) = that match { - case id: BlockManagerId => port == id.port && ip == 
id.ip - case _ => false - } -} +import sun.nio.ch.DirectBuffer -private[spark] +private[spark] case class BlockException(blockId: String, message: String, ex: Exception = null) extends Exception(message) - -private[spark] class BlockLocker(numLockers: Int) { - private val hashLocker = Array.fill(numLockers)(new Object()) - - def getLock(blockId: String): Object = { - return hashLocker(math.abs(blockId.hashCode % numLockers)) - } -} - - private[spark] -class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, maxMemory: Long) +class BlockManager( + actorSystem: ActorSystem, + val master: BlockManagerMaster, + val serializer: Serializer, + maxMemory: Long) extends Logging { class BlockInfo(val level: StorageLevel, val tellMaster: Boolean) { @@ -87,10 +59,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m } } - private val NUM_LOCKS = 337 - private val locker = new BlockLocker(NUM_LOCKS) - - private val blockInfo = new ConcurrentHashMap[String, BlockInfo]() + private val blockInfo = new TimeStampedHashMap[String, BlockInfo] private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory) private[storage] val diskStore: BlockStore = @@ -110,20 +79,38 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m val maxBytesInFlight = System.getProperty("spark.reducer.maxMbInFlight", "48").toLong * 1024 * 1024 + // Whether to compress broadcast variables that are stored val compressBroadcast = System.getProperty("spark.broadcast.compress", "true").toBoolean + // Whether to compress shuffle output that are stored val compressShuffle = System.getProperty("spark.shuffle.compress", "true").toBoolean // Whether to compress RDD partitions that are stored serialized val compressRdds = System.getProperty("spark.rdd.compress", "false").toBoolean + val heartBeatFrequency = BlockManager.getHeartBeatFrequencyFromSystemProperties + val host = System.getProperty("spark.hostname", Utils.localHostName()) + val slaveActor = master.actorSystem.actorOf(Props(new BlockManagerSlaveActor(this)), + name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next) + + @volatile private var shuttingDown = false + + private def heartBeat() { + if (!master.sendHeartBeat(blockManagerId)) { + reregister() + } + } + + var heartBeatTask: Cancellable = null + + val metadataCleaner = new MetadataCleaner("BlockManager", this.dropOldBlocks) initialize() /** * Construct a BlockManager with a memory limit set based on system properties. */ - def this(master: BlockManagerMaster, serializer: Serializer) = { - this(master, serializer, BlockManager.getMaxMemoryFromSystemProperties) + def this(actorSystem: ActorSystem, master: BlockManagerMaster, serializer: Serializer) = { + this(actorSystem, master, serializer, BlockManager.getMaxMemoryFromSystemProperties) } /** @@ -131,55 +118,100 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m * BlockManagerWorker actor. */ private def initialize() { - master.mustRegisterBlockManager( - RegisterBlockManager(blockManagerId, maxMemory)) + master.registerBlockManager(blockManagerId, maxMemory, slaveActor) BlockManagerWorker.startBlockManagerWorker(this) + if (!BlockManager.getDisableHeartBeatsForTesting) { + heartBeatTask = actorSystem.scheduler.schedule(0.seconds, heartBeatFrequency.milliseconds) { + heartBeat() + } + } } /** - * Get storage level of local block. If no info exists for the block, then returns null. + * Report all blocks to the BlockManager again. 
This may be necessary if we are dropped + * by the BlockManager and come back or if we become capable of recovering blocks on disk after + * an executor crash. + * + * This function deliberately fails silently if the master returns false (indicating that + * the slave needs to reregister). The error condition will be detected again by the next + * heart beat attempt or new block registration and another try to reregister all blocks + * will be made then. */ - def getLevel(blockId: String): StorageLevel = { - val info = blockInfo.get(blockId) - if (info != null) info.level else null + private def reportAllBlocks() { + logInfo("Reporting " + blockInfo.size + " blocks to the master.") + for ((blockId, info) <- blockInfo) { + if (!tryToReportBlockStatus(blockId, info)) { + logError("Failed to report " + blockId + " to master; giving up.") + return + } + } } /** - * Tell the master about the current storage status of a block. This will send a heartbeat + * Reregister with the master and report all blocks to it. This will be called by the heart beat + * thread if our heartbeat to the block manager indicates that we were not registered. + */ + def reregister() { + // TODO: We might need to rate limit reregistering. + logInfo("BlockManager reregistering with master") + master.registerBlockManager(blockManagerId, maxMemory, slaveActor) + reportAllBlocks() + } + + /** + * Get storage level of local block. If no info exists for the block, then returns null. + */ + def getLevel(blockId: String): StorageLevel = blockInfo.get(blockId).map(_.level).orNull + + /** + * Tell the master about the current storage status of a block. This will send a block update * message reflecting the current status, *not* the desired storage level in its block info. * For example, a block with MEMORY_AND_DISK set might have fallen out to be only on disk. */ - def reportBlockStatus(blockId: String) { - locker.getLock(blockId).synchronized { - val curLevel = blockInfo.get(blockId) match { + def reportBlockStatus(blockId: String, info: BlockInfo) { + val needReregister = !tryToReportBlockStatus(blockId, info) + if (needReregister) { + logInfo("Got told to reregister updating block " + blockId) + // Reregistering will report our new block for free. + reregister() + } + logDebug("Told master about block " + blockId) + } + + /** + * Actually send an UpdateBlockInfo message. Returns the master's response, + * which will be true if the block was successfully recorded and false if + * the slave needs to re-register. 
+ */ + private def tryToReportBlockStatus(blockId: String, info: BlockInfo): Boolean = { + val (curLevel, inMemSize, onDiskSize, tellMaster) = info.synchronized { + info.level match { case null => - StorageLevel.NONE - case info => - info.level match { - case null => - StorageLevel.NONE - case level => - val inMem = level.useMemory && memoryStore.contains(blockId) - val onDisk = level.useDisk && diskStore.contains(blockId) - new StorageLevel(onDisk, inMem, level.deserialized, level.replication) - } + (StorageLevel.NONE, 0L, 0L, false) + case level => + val inMem = level.useMemory && memoryStore.contains(blockId) + val onDisk = level.useDisk && diskStore.contains(blockId) + val storageLevel = new StorageLevel(onDisk, inMem, level.deserialized, level.replication) + val memSize = if (inMem) memoryStore.getSize(blockId) else 0L + val diskSize = if (onDisk) diskStore.getSize(blockId) else 0L + (storageLevel, memSize, diskSize, info.tellMaster) } - master.mustHeartBeat(HeartBeat( - blockManagerId, - blockId, - curLevel, - if (curLevel.useMemory) memoryStore.getSize(blockId) else 0L, - if (curLevel.useDisk) diskStore.getSize(blockId) else 0L)) - logDebug("Told master about block " + blockId) + } + + if (tellMaster) { + master.updateBlockInfo(blockManagerId, blockId, curLevel, inMemSize, onDiskSize) + } else { + true } } + /** * Get locations of the block. */ def getLocations(blockId: String): Seq[String] = { val startTimeMs = System.currentTimeMillis - var managers = master.mustGetLocations(GetLocations(blockId)) + var managers = master.getLocations(blockId) val locations = managers.map(_.ip) logDebug("Get block locations in " + Utils.getUsedTimeMs(startTimeMs)) return locations @@ -190,8 +222,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m */ def getLocations(blockIds: Array[String]): Array[Seq[String]] = { val startTimeMs = System.currentTimeMillis - val locations = master.mustGetLocationsMultipleBlockIds( - GetLocationsMultipleBlockIds(blockIds)).map(_.map(_.ip).toSeq).toArray + val locations = master.getLocations(blockIds).map(_.map(_.ip).toSeq).toArray logDebug("Get multiple block location in " + Utils.getUsedTimeMs(startTimeMs)) return locations } @@ -213,9 +244,9 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m } } - locker.getLock(blockId).synchronized { - val info = blockInfo.get(blockId) - if (info != null) { + val info = blockInfo.get(blockId).orNull + if (info != null) { + info.synchronized { info.waitForReady() // In case the block is still being put() by another thread val level = info.level logDebug("Level for block " + blockId + " is " + level) @@ -273,9 +304,9 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m } } } - } else { - logDebug("Block " + blockId + " not registered locally") } + } else { + logDebug("Block " + blockId + " not registered locally") } return None } @@ -298,9 +329,9 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m } } - locker.getLock(blockId).synchronized { - val info = blockInfo.get(blockId) - if (info != null) { + val info = blockInfo.get(blockId).orNull + if (info != null) { + info.synchronized { info.waitForReady() // In case the block is still being put() by another thread val level = info.level logDebug("Level for block " + blockId + " is " + level) @@ -338,9 +369,9 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m throw new Exception("Block " + blockId + " not found on disk, 
though it should be") } } - } else { - logDebug("Block " + blockId + " not registered locally") } + } else { + logDebug("Block " + blockId + " not registered locally") } return None } @@ -354,7 +385,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m } logDebug("Getting remote block " + blockId) // Get locations of block - val locations = master.mustGetLocations(GetLocations(blockId)) + val locations = master.getLocations(blockId) // Get block from remote locations for (loc <- locations) { @@ -556,7 +587,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m throw new IllegalArgumentException("Storage level is null or invalid") } - val oldBlock = blockInfo.get(blockId) + val oldBlock = blockInfo.get(blockId).orNull if (oldBlock != null) { logWarning("Block " + blockId + " already exists on this machine; not re-adding it") oldBlock.waitForReady() @@ -583,7 +614,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m // Size of the block in bytes (to return to caller) var size = 0L - locker.getLock(blockId).synchronized { + myInfo.synchronized { logDebug("Put for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs) + " to get into synchronized block") @@ -611,7 +642,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m // and tell the master about it. myInfo.markReady(size) if (tellMaster) { - reportBlockStatus(blockId) + reportBlockStatus(blockId, myInfo) } } logDebug("Put block " + blockId + " locally took " + Utils.getUsedTimeMs(startTimeMs)) @@ -657,7 +688,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m throw new IllegalArgumentException("Storage level is null or invalid") } - if (blockInfo.containsKey(blockId)) { + if (blockInfo.contains(blockId)) { logWarning("Block " + blockId + " already exists on this machine; not re-adding it") return } @@ -681,7 +712,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m null } - locker.getLock(blockId).synchronized { + myInfo.synchronized { logDebug("PutBytes for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs) + " to get into synchronized block") @@ -698,7 +729,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m // and tell the master about it. 
myInfo.markReady(bytes.limit) if (tellMaster) { - reportBlockStatus(blockId) + reportBlockStatus(blockId, myInfo) } } @@ -732,7 +763,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m val tLevel: StorageLevel = new StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1) if (cachedPeers == null) { - cachedPeers = master.mustGetPeers(GetPeers(blockManagerId, level.replication - 1)) + cachedPeers = master.getPeers(blockManagerId, level.replication - 1) } for (peer: BlockManagerId <- cachedPeers) { val start = System.nanoTime @@ -779,25 +810,79 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m */ def dropFromMemory(blockId: String, data: Either[ArrayBuffer[Any], ByteBuffer]) { logInfo("Dropping block " + blockId + " from memory") - locker.getLock(blockId).synchronized { - val info = blockInfo.get(blockId) - val level = info.level - if (level.useDisk && !diskStore.contains(blockId)) { - logInfo("Writing block " + blockId + " to disk") - data match { - case Left(elements) => - diskStore.putValues(blockId, elements, level, false) - case Right(bytes) => - diskStore.putBytes(blockId, bytes, level) + val info = blockInfo.get(blockId).orNull + if (info != null) { + info.synchronized { + val level = info.level + if (level.useDisk && !diskStore.contains(blockId)) { + logInfo("Writing block " + blockId + " to disk") + data match { + case Left(elements) => + diskStore.putValues(blockId, elements, level, false) + case Right(bytes) => + diskStore.putBytes(blockId, bytes, level) + } + } + val blockWasRemoved = memoryStore.remove(blockId) + if (!blockWasRemoved) { + logWarning("Block " + blockId + " could not be dropped from memory as it does not exist") + } + if (info.tellMaster) { + reportBlockStatus(blockId, info) + } + if (!level.useDisk) { + // The block is completely gone from this node; forget it so we can put() it again later. + blockInfo.remove(blockId) } } - memoryStore.remove(blockId) + } else { + // The block has already been dropped + } + } + + /** + * Remove a block from both memory and disk. + */ + def removeBlock(blockId: String) { + logInfo("Removing block " + blockId) + val info = blockInfo.get(blockId).orNull + if (info != null) info.synchronized { + // Removals are idempotent in disk store and memory store. At worst, we get a warning. + val removedFromMemory = memoryStore.remove(blockId) + val removedFromDisk = diskStore.remove(blockId) + if (!removedFromMemory && !removedFromDisk) { + logWarning("Block " + blockId + " could not be removed as it was not found in either " + + "the disk or memory store") + } + blockInfo.remove(blockId) if (info.tellMaster) { - reportBlockStatus(blockId) + reportBlockStatus(blockId, info) } - if (!level.useDisk) { - // The block is completely gone from this node; forget it so we can put() it again later. - blockInfo.remove(blockId) + } else { + // The block has already been removed; do nothing. 
+ logWarning("Asked to remove block " + blockId + ", which does not exist") + } + } + + def dropOldBlocks(cleanupTime: Long) { + logInfo("Dropping blocks older than " + cleanupTime) + val iterator = blockInfo.internalMap.entrySet().iterator() + while (iterator.hasNext) { + val entry = iterator.next() + val (id, info, time) = (entry.getKey, entry.getValue._1, entry.getValue._2) + if (time < cleanupTime) { + info.synchronized { + val level = info.level + if (level.useMemory) { + memoryStore.remove(id) + } + if (level.useDisk) { + diskStore.remove(id) + } + iterator.remove() + logInfo("Dropped block " + id) + } + reportBlockStatus(id, info) } } } @@ -847,7 +932,11 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m } def stop() { + if (heartBeatTask != null) { + heartBeatTask.cancel() + } connectionManager.stop() + master.actorSystem.stop(slaveActor) blockInfo.clear() memoryStore.clear() diskStore.clear() @@ -857,11 +946,20 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m private[spark] object BlockManager extends Logging { + + val ID_GENERATOR = new IdGenerator + def getMaxMemoryFromSystemProperties: Long = { val memoryFraction = System.getProperty("spark.storage.memoryFraction", "0.66").toDouble (Runtime.getRuntime.maxMemory * memoryFraction).toLong } + def getHeartBeatFrequencyFromSystemProperties: Long = + System.getProperty("spark.storage.blockManagerHeartBeatMs", "5000").toLong + + def getDisableHeartBeatsForTesting: Boolean = + System.getProperty("spark.test.disableBlockManagerHeartBeat", "false").toBoolean + /** * Attempt to clean up a ByteBuffer if it is memory-mapped. This uses an *unsafe* Sun API that * might cause errors if one attempts to read from the unmapped buffer, but it's better than diff --git a/core/src/main/scala/spark/storage/BlockManagerId.scala b/core/src/main/scala/spark/storage/BlockManagerId.scala new file mode 100644 index 0000000000..488679f049 --- /dev/null +++ b/core/src/main/scala/spark/storage/BlockManagerId.scala @@ -0,0 +1,48 @@ +package spark.storage + +import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} +import java.util.concurrent.ConcurrentHashMap + + +private[spark] class BlockManagerId(var ip: String, var port: Int) extends Externalizable { + def this() = this(null, 0) // For deserialization only + + def this(in: ObjectInput) = this(in.readUTF(), in.readInt()) + + override def writeExternal(out: ObjectOutput) { + out.writeUTF(ip) + out.writeInt(port) + } + + override def readExternal(in: ObjectInput) { + ip = in.readUTF() + port = in.readInt() + } + + @throws(classOf[IOException]) + private def readResolve(): Object = BlockManagerId.getCachedBlockManagerId(this) + + override def toString = "BlockManagerId(" + ip + ", " + port + ")" + + override def hashCode = ip.hashCode * 41 + port + + override def equals(that: Any) = that match { + case id: BlockManagerId => port == id.port && ip == id.ip + case _ => false + } +} + + +private[spark] object BlockManagerId { + + val blockManagerIdCache = new ConcurrentHashMap[BlockManagerId, BlockManagerId]() + + def getCachedBlockManagerId(id: BlockManagerId): BlockManagerId = { + if (blockManagerIdCache.containsKey(id)) { + blockManagerIdCache.get(id) + } else { + blockManagerIdCache.put(id, id) + id + } + } +} diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala index 7bfa31ac3d..a3d8671834 100644 --- 
a/core/src/main/scala/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMaster.scala @@ -1,548 +1,167 @@ package spark.storage -import java.io._ -import java.util.{HashMap => JHashMap} - -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} +import scala.collection.mutable.ArrayBuffer import scala.util.Random -import akka.actor._ -import akka.dispatch._ +import akka.actor.{Actor, ActorRef, ActorSystem, Props} +import akka.dispatch.Await import akka.pattern.ask -import akka.remote._ import akka.util.{Duration, Timeout} import akka.util.duration._ import spark.{Logging, SparkException, Utils} -private[spark] -sealed trait ToBlockManagerMaster - -private[spark] -case class RegisterBlockManager( - blockManagerId: BlockManagerId, - maxMemSize: Long) - extends ToBlockManagerMaster - -private[spark] -class HeartBeat( - var blockManagerId: BlockManagerId, - var blockId: String, - var storageLevel: StorageLevel, - var memSize: Long, - var diskSize: Long) - extends ToBlockManagerMaster - with Externalizable { +private[spark] class BlockManagerMaster( + val actorSystem: ActorSystem, + isMaster: Boolean, + isLocal: Boolean, + masterIp: String, + masterPort: Int) + extends Logging { - def this() = this(null, null, null, 0, 0) // For deserialization only + val AKKA_RETRY_ATTEMPS: Int = System.getProperty("spark.akka.num.retries", "3").toInt + val AKKA_RETRY_INTERVAL_MS: Int = System.getProperty("spark.akka.retry.wait", "3000").toInt - override def writeExternal(out: ObjectOutput) { - blockManagerId.writeExternal(out) - out.writeUTF(blockId) - storageLevel.writeExternal(out) - out.writeInt(memSize.toInt) - out.writeInt(diskSize.toInt) - } + val MASTER_AKKA_ACTOR_NAME = "BlockMasterManager" + val SLAVE_AKKA_ACTOR_NAME = "BlockSlaveManager" + val DEFAULT_MANAGER_IP: String = Utils.localHostName() - override def readExternal(in: ObjectInput) { - blockManagerId = new BlockManagerId() - blockManagerId.readExternal(in) - blockId = in.readUTF() - storageLevel = new StorageLevel() - storageLevel.readExternal(in) - memSize = in.readInt() - diskSize = in.readInt() + val timeout = 10.seconds + var masterActor: ActorRef = { + if (isMaster) { + val masterActor = actorSystem.actorOf(Props(new BlockManagerMasterActor(isLocal)), + name = MASTER_AKKA_ACTOR_NAME) + logInfo("Registered BlockManagerMaster Actor") + masterActor + } else { + val url = "akka://spark@%s:%s/user/%s".format(masterIp, masterPort, MASTER_AKKA_ACTOR_NAME) + logInfo("Connecting to BlockManagerMaster: " + url) + actorSystem.actorFor(url) + } } -} -private[spark] -object HeartBeat { - def apply(blockManagerId: BlockManagerId, - blockId: String, - storageLevel: StorageLevel, - memSize: Long, - diskSize: Long): HeartBeat = { - new HeartBeat(blockManagerId, blockId, storageLevel, memSize, diskSize) + /** Remove a dead host from the master actor. This is only called on the master side. */ + def notifyADeadHost(host: String) { + tell(RemoveHost(host)) + logInfo("Removed " + host + " successfully in notifyADeadHost") } - // For pattern-matching - def unapply(h: HeartBeat): Option[(BlockManagerId, String, StorageLevel, Long, Long)] = { - Some((h.blockManagerId, h.blockId, h.storageLevel, h.memSize, h.diskSize)) + /** + * Send the master actor a heart beat from the slave. Returns true if everything works out, + * false if the master does not know about the given block manager, which means the block + * manager should re-register. 
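A minimal sketch of the slave-side reaction, assuming a hypothetical heartBeat() wrapper; sendHeartBeat and registerBlockManager are from this patch, the surrounding method and maxMemory value are illustrative:

// Hedged sketch, not quoted from this diff.
def heartBeat() {
  if (!master.sendHeartBeat(blockManagerId)) {
    // The master no longer knows this block manager (e.g. it restarted): register again.
    master.registerBlockManager(blockManagerId, maxMemory, slaveActor)
  }
}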
+ */ + def sendHeartBeat(blockManagerId: BlockManagerId): Boolean = { + askMasterWithRetry[Boolean](HeartBeat(blockManagerId)) } -} - -private[spark] -case class GetLocations(blockId: String) extends ToBlockManagerMaster - -private[spark] -case class GetLocationsMultipleBlockIds(blockIds: Array[String]) extends ToBlockManagerMaster - -private[spark] -case class GetPeers(blockManagerId: BlockManagerId, size: Int) extends ToBlockManagerMaster - -private[spark] -case class RemoveHost(host: String) extends ToBlockManagerMaster - -private[spark] -case object StopBlockManagerMaster extends ToBlockManagerMaster - -private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { - - class BlockManagerInfo( - val blockManagerId: BlockManagerId, - timeMs: Long, - val maxMem: Long) { - private var lastSeenMs = timeMs - private var remainingMem = maxMem - private val blocks = new JHashMap[String, StorageLevel] - - logInfo("Registering block manager %s:%d with %s RAM".format( - blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(maxMem))) - - def updateLastSeenMs() { - lastSeenMs = System.currentTimeMillis() / 1000 - } - - def updateBlockInfo(blockId: String, storageLevel: StorageLevel, memSize: Long, diskSize: Long) - : Unit = synchronized { - - updateLastSeenMs() - - if (blocks.containsKey(blockId)) { - // The block exists on the slave already. - val originalLevel: StorageLevel = blocks.get(blockId) - - if (originalLevel.useMemory) { - remainingMem += memSize - } - } - - if (storageLevel.isValid) { - // isValid means it is either stored in-memory or on-disk. - blocks.put(blockId, storageLevel) - if (storageLevel.useMemory) { - remainingMem -= memSize - logInfo("Added %s in memory on %s:%d (size: %s, free: %s)".format( - blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(memSize), - Utils.memoryBytesToString(remainingMem))) - } - if (storageLevel.useDisk) { - logInfo("Added %s on disk on %s:%d (size: %s)".format( - blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(diskSize))) - } - } else if (blocks.containsKey(blockId)) { - // If isValid is not true, drop the block. - val originalLevel: StorageLevel = blocks.get(blockId) - blocks.remove(blockId) - if (originalLevel.useMemory) { - remainingMem += memSize - logInfo("Removed %s on %s:%d in memory (size: %s, free: %s)".format( - blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(memSize), - Utils.memoryBytesToString(remainingMem))) - } - if (originalLevel.useDisk) { - logInfo("Removed %s on %s:%d on disk (size: %s)".format( - blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(diskSize))) - } - } - } - - def getLastSeenMs: Long = { - return lastSeenMs - } - - def getRemainedMem: Long = { - return remainingMem - } - - override def toString: String = { - return "BlockManagerInfo " + timeMs + " " + remainingMem - } - - def clear() { - blocks.clear() - } - } - - private val blockManagerInfo = new HashMap[BlockManagerId, BlockManagerInfo] - private val blockInfo = new JHashMap[String, Pair[Int, HashSet[BlockManagerId]]] - - initLogging() - - def removeHost(host: String) { - logInfo("Trying to remove the host: " + host + " from BlockManagerMaster.") - logInfo("Previous hosts: " + blockManagerInfo.keySet.toSeq) - val ip = host.split(":")(0) - val port = host.split(":")(1) - blockManagerInfo.remove(new BlockManagerId(ip, port.toInt)) - logInfo("Current hosts: " + blockManagerInfo.keySet.toSeq) - sender ! 
true + /** Register the BlockManager's id with the master. */ + def registerBlockManager( + blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { + logInfo("Trying to register BlockManager") + tell(RegisterBlockManager(blockManagerId, maxMemSize, slaveActor)) + logInfo("Registered BlockManager") } - def receive = { - case RegisterBlockManager(blockManagerId, maxMemSize) => - register(blockManagerId, maxMemSize) - - case HeartBeat(blockManagerId, blockId, storageLevel, deserializedSize, size) => - heartBeat(blockManagerId, blockId, storageLevel, deserializedSize, size) - - case GetLocations(blockId) => - getLocations(blockId) - - case GetLocationsMultipleBlockIds(blockIds) => - getLocationsMultipleBlockIds(blockIds) - - case GetPeers(blockManagerId, size) => - getPeersDeterministic(blockManagerId, size) - /*getPeers(blockManagerId, size)*/ - - case RemoveHost(host) => - removeHost(host) - sender ! true - - case StopBlockManagerMaster => - logInfo("Stopping BlockManagerMaster") - sender ! true - context.stop(self) - - case other => - logInfo("Got unknown message: " + other) - } - - private def register(blockManagerId: BlockManagerId, maxMemSize: Long) { - val startTimeMs = System.currentTimeMillis() - val tmp = " " + blockManagerId + " " - logDebug("Got in register 0" + tmp + Utils.getUsedTimeMs(startTimeMs)) - if (blockManagerId.ip == Utils.localHostName() && !isLocal) { - logInfo("Got Register Msg from master node, don't register it") - } else { - blockManagerInfo += (blockManagerId -> new BlockManagerInfo( - blockManagerId, System.currentTimeMillis() / 1000, maxMemSize)) - } - logDebug("Got in register 1" + tmp + Utils.getUsedTimeMs(startTimeMs)) - sender ! true - } - - private def heartBeat( + def updateBlockInfo( blockManagerId: BlockManagerId, blockId: String, storageLevel: StorageLevel, memSize: Long, - diskSize: Long) { - - val startTimeMs = System.currentTimeMillis() - val tmp = " " + blockManagerId + " " + blockId + " " - - if (blockId == null) { - blockManagerInfo(blockManagerId).updateLastSeenMs() - logDebug("Got in heartBeat 1" + tmp + " used " + Utils.getUsedTimeMs(startTimeMs)) - sender ! true - } - - blockManagerInfo(blockManagerId).updateBlockInfo(blockId, storageLevel, memSize, diskSize) - - var locations: HashSet[BlockManagerId] = null - if (blockInfo.containsKey(blockId)) { - locations = blockInfo.get(blockId)._2 - } else { - locations = new HashSet[BlockManagerId] - blockInfo.put(blockId, (storageLevel.replication, locations)) - } - - if (storageLevel.isValid) { - locations += blockManagerId - } else { - locations.remove(blockManagerId) - } - - if (locations.size == 0) { - blockInfo.remove(blockId) - } - sender ! true - } - - private def getLocations(blockId: String) { - val startTimeMs = System.currentTimeMillis() - val tmp = " " + blockId + " " - logDebug("Got in getLocations 0" + tmp + Utils.getUsedTimeMs(startTimeMs)) - if (blockInfo.containsKey(blockId)) { - var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] - res.appendAll(blockInfo.get(blockId)._2) - logDebug("Got in getLocations 1" + tmp + " as "+ res.toSeq + " at " - + Utils.getUsedTimeMs(startTimeMs)) - sender ! res.toSeq - } else { - logDebug("Got in getLocations 2" + tmp + Utils.getUsedTimeMs(startTimeMs)) - var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] - sender ! 
res - } - } - - private def getLocationsMultipleBlockIds(blockIds: Array[String]) { - def getLocations(blockId: String): Seq[BlockManagerId] = { - val tmp = blockId - logDebug("Got in getLocationsMultipleBlockIds Sub 0 " + tmp) - if (blockInfo.containsKey(blockId)) { - var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] - res.appendAll(blockInfo.get(blockId)._2) - logDebug("Got in getLocationsMultipleBlockIds Sub 1 " + tmp + " " + res.toSeq) - return res.toSeq - } else { - logDebug("Got in getLocationsMultipleBlockIds Sub 2 " + tmp) - var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] - return res.toSeq - } - } - - logDebug("Got in getLocationsMultipleBlockIds " + blockIds.toSeq) - var res: ArrayBuffer[Seq[BlockManagerId]] = new ArrayBuffer[Seq[BlockManagerId]] - for (blockId <- blockIds) { - res.append(getLocations(blockId)) - } - logDebug("Got in getLocationsMultipleBlockIds " + blockIds.toSeq + " : " + res.toSeq) - sender ! res.toSeq + diskSize: Long): Boolean = { + val res = askMasterWithRetry[Boolean]( + UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize)) + logInfo("Updated info of block " + blockId) + res } - private def getPeers(blockManagerId: BlockManagerId, size: Int) { - var peers: Array[BlockManagerId] = blockManagerInfo.keySet.toArray - var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] - res.appendAll(peers) - res -= blockManagerId - val rand = new Random(System.currentTimeMillis()) - while (res.length > size) { - res.remove(rand.nextInt(res.length)) - } - sender ! res.toSeq + /** Get locations of the blockId from the master */ + def getLocations(blockId: String): Seq[BlockManagerId] = { + askMasterWithRetry[Seq[BlockManagerId]](GetLocations(blockId)) } - - private def getPeersDeterministic(blockManagerId: BlockManagerId, size: Int) { - var peers: Array[BlockManagerId] = blockManagerInfo.keySet.toArray - var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] - val peersWithIndices = peers.zipWithIndex - val selfIndex = peersWithIndices.find(_._1 == blockManagerId).map(_._2).getOrElse(-1) - if (selfIndex == -1) { - throw new Exception("Self index for " + blockManagerId + " not found") - } + /** Get locations of multiple blockIds from the master */ + def getLocations(blockIds: Array[String]): Seq[Seq[BlockManagerId]] = { + askMasterWithRetry[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds)) + } - var index = selfIndex - while (res.size < size) { - index += 1 - if (index == selfIndex) { - throw new Exception("More peer expected than available") - } - res += peers(index % peers.size) + /** Get ids of other nodes in the cluster from the master */ + def getPeers(blockManagerId: BlockManagerId, numPeers: Int): Seq[BlockManagerId] = { + val result = askMasterWithRetry[Seq[BlockManagerId]](GetPeers(blockManagerId, numPeers)) + if (result.length != numPeers) { + throw new SparkException( + "Error getting peers, only got " + result.size + " instead of " + numPeers) } - sender ! 
res.toSeq + result } -} - -private[spark] class BlockManagerMaster(actorSystem: ActorSystem, isMaster: Boolean, isLocal: Boolean) - extends Logging { - - val AKKA_ACTOR_NAME: String = "BlockMasterManager" - val REQUEST_RETRY_INTERVAL_MS = 100 - val DEFAULT_MASTER_IP: String = System.getProperty("spark.master.host", "localhost") - val DEFAULT_MASTER_PORT: Int = System.getProperty("spark.master.port", "7077").toInt - val DEFAULT_MANAGER_IP: String = Utils.localHostName() - val DEFAULT_MANAGER_PORT: String = "10902" - val timeout = 10.seconds - var masterActor: ActorRef = null + /** + * Remove a block from the slaves that have it. This can only be used to remove + * blocks that the master knows about. + */ + def removeBlock(blockId: String) { + askMasterWithRetry(RemoveBlock(blockId)) + } - if (isMaster) { - masterActor = actorSystem.actorOf( - Props(new BlockManagerMasterActor(isLocal)), name = AKKA_ACTOR_NAME) - logInfo("Registered BlockManagerMaster Actor") - } else { - val url = "akka://spark@%s:%s/user/%s".format( - DEFAULT_MASTER_IP, DEFAULT_MASTER_PORT, AKKA_ACTOR_NAME) - logInfo("Connecting to BlockManagerMaster: " + url) - masterActor = actorSystem.actorFor(url) + /** + * Return the memory status for each block manager, in the form of a map from + * the block manager's id to two long values. The first value is the maximum + * amount of memory allocated for the block manager, while the second is the + * amount of remaining memory. + */ + def getMemoryStatus: Map[BlockManagerId, (Long, Long)] = { + askMasterWithRetry[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus) } - + + /** Stop the master actor, called only on the Spark master node */ def stop() { if (masterActor != null) { - communicate(StopBlockManagerMaster) + tell(StopBlockManagerMaster) masterActor = null logInfo("BlockManagerMaster stopped") } } - // Send a message to the master actor and get its result within a default timeout, or - // throw a SparkException if this fails. - def askMaster(message: Any): Any = { - try { - val future = masterActor.ask(message)(timeout) - return Await.result(future, timeout) - } catch { - case e: Exception => - throw new SparkException("Error communicating with BlockManagerMaster", e) - } - } - - // Send a one-way message to the master actor, to which we expect it to reply with true. - def communicate(message: Any) { - if (askMaster(message) != true) { - throw new SparkException("Error reply received from BlockManagerMaster") + /** Send a one-way message to the master actor, to which we expect it to reply with true. */ + private def tell(message: Any) { + if (!askMasterWithRetry[Boolean](message)) { + throw new SparkException("BlockManagerMasterActor returned false, expected true.") } } - - def notifyADeadHost(host: String) { - communicate(RemoveHost(host + ":" + DEFAULT_MANAGER_PORT)) - logInfo("Removed " + host + " successfully in notifyADeadHost") - } - - def mustRegisterBlockManager(msg: RegisterBlockManager) { - logInfo("Trying to register BlockManager") - while (! 
syncRegisterBlockManager(msg)) { - logWarning("Failed to register " + msg) - Thread.sleep(REQUEST_RETRY_INTERVAL_MS) - } - logInfo("Done registering BlockManager") - } - def syncRegisterBlockManager(msg: RegisterBlockManager): Boolean = { - //val masterActor = RemoteActor.select(node, name) - val startTimeMs = System.currentTimeMillis() - val tmp = " msg " + msg + " " - logDebug("Got in syncRegisterBlockManager 0 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - - try { - communicate(msg) - logInfo("BlockManager registered successfully @ syncRegisterBlockManager") - logDebug("Got in syncRegisterBlockManager 1 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - return true - } catch { - case e: Exception => - logError("Failed in syncRegisterBlockManager", e) - return false - } - } - - def mustHeartBeat(msg: HeartBeat) { - while (! syncHeartBeat(msg)) { - logWarning("Failed to send heartbeat" + msg) - Thread.sleep(REQUEST_RETRY_INTERVAL_MS) - } - } - - def syncHeartBeat(msg: HeartBeat): Boolean = { - val startTimeMs = System.currentTimeMillis() - val tmp = " msg " + msg + " " - logDebug("Got in syncHeartBeat " + tmp + " 0 " + Utils.getUsedTimeMs(startTimeMs)) - - try { - communicate(msg) - logDebug("Heartbeat sent successfully") - logDebug("Got in syncHeartBeat 1 " + tmp + " 1 " + Utils.getUsedTimeMs(startTimeMs)) - return true - } catch { - case e: Exception => - logError("Failed in syncHeartBeat", e) - return false - } - } - - def mustGetLocations(msg: GetLocations): Seq[BlockManagerId] = { - var res = syncGetLocations(msg) - while (res == null) { - logInfo("Failed to get locations " + msg) - Thread.sleep(REQUEST_RETRY_INTERVAL_MS) - res = syncGetLocations(msg) + /** + * Send a message to the master actor and get its result within a default timeout, or + * throw a SparkException if this fails. 
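As a hedged tuning note: both retry knobs are plain system properties, read once when the BlockManagerMaster is constructed, so they must be set before it is created. The property names and defaults below are taken from this patch; the chosen values are illustrative.

System.setProperty("spark.akka.num.retries", "5")    // retry attempts, default 3
System.setProperty("spark.akka.retry.wait", "1000")  // ms between attempts, default 3000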
+ */ + private def askMasterWithRetry[T](message: Any): T = { + // TODO: Consider removing multiple attempts + if (masterActor == null) { + throw new SparkException("Error sending message to BlockManager as masterActor is null " + + "[message = " + message + "]") } - return res - } - - def syncGetLocations(msg: GetLocations): Seq[BlockManagerId] = { - val startTimeMs = System.currentTimeMillis() - val tmp = " msg " + msg + " " - logDebug("Got in syncGetLocations 0 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - - try { - val answer = askMaster(msg).asInstanceOf[ArrayBuffer[BlockManagerId]] - if (answer != null) { - logDebug("GetLocations successful") - logDebug("Got in syncGetLocations 1 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - return answer - } else { - logError("Master replied null in response to GetLocations") - return null + var attempts = 0 + var lastException: Exception = null + while (attempts < AKKA_RETRY_ATTEMPS) { + attempts += 1 + try { + val future = masterActor.ask(message)(timeout) + val result = Await.result(future, timeout) + if (result == null) { + throw new Exception("BlockManagerMaster returned null") + } + return result.asInstanceOf[T] + } catch { + case ie: InterruptedException => throw ie + case e: Exception => + lastException = e + logWarning("Error sending message to BlockManagerMaster in " + attempts + " attempts", e) } - } catch { - case e: Exception => - logError("GetLocations failed", e) - return null + Thread.sleep(AKKA_RETRY_INTERVAL_MS) } - } - def mustGetLocationsMultipleBlockIds(msg: GetLocationsMultipleBlockIds): - Seq[Seq[BlockManagerId]] = { - var res: Seq[Seq[BlockManagerId]] = syncGetLocationsMultipleBlockIds(msg) - while (res == null) { - logWarning("Failed to GetLocationsMultipleBlockIds " + msg) - Thread.sleep(REQUEST_RETRY_INTERVAL_MS) - res = syncGetLocationsMultipleBlockIds(msg) - } - return res + throw new SparkException( + "Error sending message to BlockManagerMaster [message = " + message + "]", lastException) } - - def syncGetLocationsMultipleBlockIds(msg: GetLocationsMultipleBlockIds): - Seq[Seq[BlockManagerId]] = { - val startTimeMs = System.currentTimeMillis - val tmp = " msg " + msg + " " - logDebug("Got in syncGetLocationsMultipleBlockIds 0 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - - try { - val answer = askMaster(msg).asInstanceOf[Seq[Seq[BlockManagerId]]] - if (answer != null) { - logDebug("GetLocationsMultipleBlockIds successful") - logDebug("Got in syncGetLocationsMultipleBlockIds 1 " + tmp + - Utils.getUsedTimeMs(startTimeMs)) - return answer - } else { - logError("Master replied null in response to GetLocationsMultipleBlockIds") - return null - } - } catch { - case e: Exception => - logError("GetLocationsMultipleBlockIds failed", e) - return null - } - } - - def mustGetPeers(msg: GetPeers): Seq[BlockManagerId] = { - var res = syncGetPeers(msg) - while ((res == null) || (res.length != msg.size)) { - logInfo("Failed to get peers " + msg) - Thread.sleep(REQUEST_RETRY_INTERVAL_MS) - res = syncGetPeers(msg) - } - - return res - } - - def syncGetPeers(msg: GetPeers): Seq[BlockManagerId] = { - val startTimeMs = System.currentTimeMillis - val tmp = " msg " + msg + " " - logDebug("Got in syncGetPeers 0 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - try { - val answer = askMaster(msg).asInstanceOf[Seq[BlockManagerId]] - if (answer != null) { - logDebug("GetPeers successful") - logDebug("Got in syncGetPeers 1 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - return answer - } else { - logError("Master replied null in response to GetPeers") 
- return null - } - } catch { - case e: Exception => - logError("GetPeers failed", e) - return null - } - } } diff --git a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala new file mode 100644 index 0000000000..f4d026da33 --- /dev/null +++ b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala @@ -0,0 +1,401 @@ +package spark.storage + +import java.util.{HashMap => JHashMap} + +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} +import scala.collection.JavaConversions._ +import scala.util.Random + +import akka.actor.{Actor, ActorRef, Cancellable} +import akka.util.{Duration, Timeout} +import akka.util.duration._ + +import spark.{Logging, Utils} + +/** + * BlockManagerMasterActor is an actor on the master node to track statuses of + * all slaves' block managers. + */ +private[spark] +class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { + + // Mapping from block manager id to the block manager's information. + private val blockManagerInfo = + new HashMap[BlockManagerId, BlockManagerMasterActor.BlockManagerInfo] + + // Mapping from host name to block manager id. We allow multiple block managers + // on the same host name (ip). + private val blockManagerIdByHost = new HashMap[String, ArrayBuffer[BlockManagerId]] + + // Mapping from block id to the set of block managers that have the block. + private val blockLocations = new JHashMap[String, Pair[Int, HashSet[BlockManagerId]]] + + initLogging() + + val slaveTimeout = System.getProperty("spark.storage.blockManagerSlaveTimeoutMs", + "" + (BlockManager.getHeartBeatFrequencyFromSystemProperties * 3)).toLong + + val checkTimeoutInterval = System.getProperty("spark.storage.blockManagerTimeoutIntervalMs", + "5000").toLong + + var timeoutCheckingTask: Cancellable = null + + override def preStart() { + if (!BlockManager.getDisableHeartBeatsForTesting) { + timeoutCheckingTask = context.system.scheduler.schedule( + 0.seconds, checkTimeoutInterval.milliseconds, self, ExpireDeadHosts) + } + super.preStart() + } + + def receive = { + case RegisterBlockManager(blockManagerId, maxMemSize, slaveActor) => + register(blockManagerId, maxMemSize, slaveActor) + + case UpdateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size) => + updateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size) + + case GetLocations(blockId) => + getLocations(blockId) + + case GetLocationsMultipleBlockIds(blockIds) => + getLocationsMultipleBlockIds(blockIds) + + case GetPeers(blockManagerId, size) => + getPeersDeterministic(blockManagerId, size) + /*getPeers(blockManagerId, size)*/ + + case GetMemoryStatus => + getMemoryStatus + + case RemoveBlock(blockId) => + removeBlock(blockId) + + case RemoveHost(host) => + removeHost(host) + sender ! true + + case StopBlockManagerMaster => + logInfo("Stopping BlockManagerMaster") + sender ! true + if (timeoutCheckingTask != null) { + timeoutCheckingTask.cancel + } + context.stop(self) + + case ExpireDeadHosts => + expireDeadHosts() + + case HeartBeat(blockManagerId) => + heartBeat(blockManagerId) + + case other => + logInfo("Got unknown message: " + other) + } + + def removeBlockManager(blockManagerId: BlockManagerId) { + val info = blockManagerInfo(blockManagerId) + + // Remove the block manager from blockManagerIdByHost. If the list of block + // managers belonging to the IP is empty, remove the entry from the hash map. 
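For illustration only (hypothetical addresses and ports), with two block managers registered from the same host the map could hold:

// blockManagerIdByHost, illustrative contents:
//   "10.0.0.1" -> ArrayBuffer(BlockManagerId("10.0.0.1", 7071),
//                             BlockManagerId("10.0.0.1", 7072))
// Removal therefore deletes one id from the buffer, and the key itself only
// once the buffer is empty.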
+ blockManagerIdByHost.get(blockManagerId.ip).foreach { managers: ArrayBuffer[BlockManagerId] => + managers -= blockManagerId + if (managers.size == 0) blockManagerIdByHost.remove(blockManagerId.ip) + } + + // Remove it from blockManagerInfo and remove all the blocks. + blockManagerInfo.remove(blockManagerId) + var iterator = info.blocks.keySet.iterator + while (iterator.hasNext) { + val blockId = iterator.next + val locations = blockLocations.get(blockId)._2 + locations -= blockManagerId + if (locations.size == 0) { + blockLocations.remove(locations) + } + } + } + + def expireDeadHosts() { + logDebug("Checking for hosts with no recent heart beats in BlockManagerMaster.") + val now = System.currentTimeMillis() + val minSeenTime = now - slaveTimeout + val toRemove = new HashSet[BlockManagerId] + for (info <- blockManagerInfo.values) { + if (info.lastSeenMs < minSeenTime) { + logWarning("Removing BlockManager " + info.blockManagerId + " with no recent heart beats") + toRemove += info.blockManagerId + } + } + toRemove.foreach(removeBlockManager) + } + + def removeHost(host: String) { + logInfo("Trying to remove the host: " + host + " from BlockManagerMaster.") + logInfo("Previous hosts: " + blockManagerInfo.keySet.toSeq) + blockManagerIdByHost.get(host).foreach(_.foreach(removeBlockManager)) + logInfo("Current hosts: " + blockManagerInfo.keySet.toSeq) + sender ! true + } + + def heartBeat(blockManagerId: BlockManagerId) { + if (!blockManagerInfo.contains(blockManagerId)) { + if (blockManagerId.ip == Utils.localHostName() && !isLocal) { + sender ! true + } else { + sender ! false + } + } else { + blockManagerInfo(blockManagerId).updateLastSeenMs() + sender ! true + } + } + + // Remove a block from the slaves that have it. This can only be used to remove + // blocks that the master knows about. + private def removeBlock(blockId: String) { + val block = blockLocations.get(blockId) + if (block != null) { + block._2.foreach { blockManagerId: BlockManagerId => + val blockManager = blockManagerInfo.get(blockManagerId) + if (blockManager.isDefined) { + // Remove the block from the slave's BlockManager. + // Doesn't actually wait for a confirmation and the message might get lost. + // If message loss becomes frequent, we should add retry logic here. + blockManager.get.slaveActor ! RemoveBlock(blockId) + } + } + } + sender ! true + } + + // Return a map from the block manager id to max memory and remaining memory. + private def getMemoryStatus() { + val res = blockManagerInfo.map { case(blockManagerId, info) => + (blockManagerId, (info.maxMem, info.remainingMem)) + }.toMap + sender ! res + } + + private def register(blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { + val startTimeMs = System.currentTimeMillis() + val tmp = " " + blockManagerId + " " + + if (blockManagerId.ip == Utils.localHostName() && !isLocal) { + logInfo("Got Register Msg from master node, don't register it") + } else { + blockManagerIdByHost.get(blockManagerId.ip) match { + case Some(managers) => + // A block manager of the same host name already exists. + logInfo("Got another registration for host " + blockManagerId) + managers += blockManagerId + case None => + blockManagerIdByHost += (blockManagerId.ip -> ArrayBuffer(blockManagerId)) + } + + blockManagerInfo += (blockManagerId -> new BlockManagerMasterActor.BlockManagerInfo( + blockManagerId, System.currentTimeMillis(), maxMemSize, slaveActor)) + } + sender ! 
true + } + + private def updateBlockInfo( + blockManagerId: BlockManagerId, + blockId: String, + storageLevel: StorageLevel, + memSize: Long, + diskSize: Long) { + + val startTimeMs = System.currentTimeMillis() + val tmp = " " + blockManagerId + " " + blockId + " " + + if (!blockManagerInfo.contains(blockManagerId)) { + if (blockManagerId.ip == Utils.localHostName() && !isLocal) { + // We intentionally do not register the master (except in local mode), + // so we should not indicate failure. + sender ! true + } else { + sender ! false + } + return + } + + if (blockId == null) { + blockManagerInfo(blockManagerId).updateLastSeenMs() + sender ! true + return + } + + blockManagerInfo(blockManagerId).updateBlockInfo(blockId, storageLevel, memSize, diskSize) + + var locations: HashSet[BlockManagerId] = null + if (blockLocations.containsKey(blockId)) { + locations = blockLocations.get(blockId)._2 + } else { + locations = new HashSet[BlockManagerId] + blockLocations.put(blockId, (storageLevel.replication, locations)) + } + + if (storageLevel.isValid) { + locations.add(blockManagerId) + } else { + locations.remove(blockManagerId) + } + + // Remove the block from master tracking if it has been removed on all slaves. + if (locations.size == 0) { + blockLocations.remove(blockId) + } + sender ! true + } + + private def getLocations(blockId: String) { + val startTimeMs = System.currentTimeMillis() + val tmp = " " + blockId + " " + if (blockLocations.containsKey(blockId)) { + var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] + res.appendAll(blockLocations.get(blockId)._2) + sender ! res.toSeq + } else { + var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] + sender ! res + } + } + + private def getLocationsMultipleBlockIds(blockIds: Array[String]) { + def getLocations(blockId: String): Seq[BlockManagerId] = { + val tmp = blockId + if (blockLocations.containsKey(blockId)) { + var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] + res.appendAll(blockLocations.get(blockId)._2) + return res.toSeq + } else { + var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] + return res.toSeq + } + } + + var res: ArrayBuffer[Seq[BlockManagerId]] = new ArrayBuffer[Seq[BlockManagerId]] + for (blockId <- blockIds) { + res.append(getLocations(blockId)) + } + sender ! res.toSeq + } + + private def getPeers(blockManagerId: BlockManagerId, size: Int) { + var peers: Array[BlockManagerId] = blockManagerInfo.keySet.toArray + var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] + res.appendAll(peers) + res -= blockManagerId + val rand = new Random(System.currentTimeMillis()) + while (res.length > size) { + res.remove(rand.nextInt(res.length)) + } + sender ! res.toSeq + } + + private def getPeersDeterministic(blockManagerId: BlockManagerId, size: Int) { + var peers: Array[BlockManagerId] = blockManagerInfo.keySet.toArray + var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] + + val selfIndex = peers.indexOf(blockManagerId) + if (selfIndex == -1) { + throw new Exception("Self index for " + blockManagerId + " not found") + } + + // Note that this logic will select the same node multiple times if there aren't enough peers + var index = selfIndex + while (res.size < size) { + index += 1 + if (index == selfIndex) { + throw new Exception("More peer expected than available") + } + res += peers(index % peers.size) + } + sender ! 
res.toSeq + } +} + + +private[spark] +object BlockManagerMasterActor { + + case class BlockStatus(storageLevel: StorageLevel, memSize: Long, diskSize: Long) + + class BlockManagerInfo( + val blockManagerId: BlockManagerId, + timeMs: Long, + val maxMem: Long, + val slaveActor: ActorRef) + extends Logging { + + private var _lastSeenMs: Long = timeMs + private var _remainingMem: Long = maxMem + + // Mapping from block id to its status. + private val _blocks = new JHashMap[String, BlockStatus] + + logInfo("Registering block manager %s:%d with %s RAM".format( + blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(maxMem))) + + def updateLastSeenMs() { + _lastSeenMs = System.currentTimeMillis() + } + + def updateBlockInfo(blockId: String, storageLevel: StorageLevel, memSize: Long, diskSize: Long) + : Unit = synchronized { + + updateLastSeenMs() + + if (_blocks.containsKey(blockId)) { + // The block exists on the slave already. + val originalLevel: StorageLevel = _blocks.get(blockId).storageLevel + + if (originalLevel.useMemory) { + _remainingMem += memSize + } + } + + if (storageLevel.isValid) { + // isValid means it is either stored in-memory or on-disk. + _blocks.put(blockId, BlockStatus(storageLevel, memSize, diskSize)) + if (storageLevel.useMemory) { + _remainingMem -= memSize + logInfo("Added %s in memory on %s:%d (size: %s, free: %s)".format( + blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(memSize), + Utils.memoryBytesToString(_remainingMem))) + } + if (storageLevel.useDisk) { + logInfo("Added %s on disk on %s:%d (size: %s)".format( + blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(diskSize))) + } + } else if (_blocks.containsKey(blockId)) { + // If isValid is not true, drop the block. + val blockStatus: BlockStatus = _blocks.get(blockId) + _blocks.remove(blockId) + if (blockStatus.storageLevel.useMemory) { + _remainingMem += blockStatus.memSize + logInfo("Removed %s on %s:%d in memory (size: %s, free: %s)".format( + blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(memSize), + Utils.memoryBytesToString(_remainingMem))) + } + if (blockStatus.storageLevel.useDisk) { + logInfo("Removed %s on %s:%d on disk (size: %s)".format( + blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(diskSize))) + } + } + } + + def remainingMem: Long = _remainingMem + + def lastSeenMs: Long = _lastSeenMs + + def blocks: JHashMap[String, BlockStatus] = _blocks + + override def toString: String = "BlockManagerInfo " + timeMs + " " + _remainingMem + + def clear() { + _blocks.clear() + } + } +} diff --git a/core/src/main/scala/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/spark/storage/BlockManagerMessages.scala new file mode 100644 index 0000000000..d73a9b790f --- /dev/null +++ b/core/src/main/scala/spark/storage/BlockManagerMessages.scala @@ -0,0 +1,102 @@ +package spark.storage + +import java.io.{Externalizable, ObjectInput, ObjectOutput} + +import akka.actor.ActorRef + + +////////////////////////////////////////////////////////////////////////////////// +// Messages from the master to slaves. +////////////////////////////////////////////////////////////////////////////////// +private[spark] +sealed trait ToBlockManagerSlave + +// Remove a block from the slaves that have it. This can only be used to remove +// blocks that the master knows about. 
+private[spark]
+case class RemoveBlock(blockId: String) extends ToBlockManagerSlave
+
+
+//////////////////////////////////////////////////////////////////////////////////
+// Messages from slaves to the master.
+//////////////////////////////////////////////////////////////////////////////////
+private[spark]
+sealed trait ToBlockManagerMaster
+
+private[spark]
+case class RegisterBlockManager(
+    blockManagerId: BlockManagerId,
+    maxMemSize: Long,
+    sender: ActorRef)
+  extends ToBlockManagerMaster
+
+private[spark]
+case class HeartBeat(blockManagerId: BlockManagerId) extends ToBlockManagerMaster
+
+private[spark]
+class UpdateBlockInfo(
+    var blockManagerId: BlockManagerId,
+    var blockId: String,
+    var storageLevel: StorageLevel,
+    var memSize: Long,
+    var diskSize: Long)
+  extends ToBlockManagerMaster
+  with Externalizable {
+
+  def this() = this(null, null, null, 0, 0)  // For deserialization only
+
+  override def writeExternal(out: ObjectOutput) {
+    blockManagerId.writeExternal(out)
+    out.writeUTF(blockId)
+    storageLevel.writeExternal(out)
+    out.writeInt(memSize.toInt)
+    out.writeInt(diskSize.toInt)
+  }
+
+  override def readExternal(in: ObjectInput) {
+    blockManagerId = new BlockManagerId()
+    blockManagerId.readExternal(in)
+    blockId = in.readUTF()
+    storageLevel = new StorageLevel()
+    storageLevel.readExternal(in)
+    memSize = in.readInt()
+    diskSize = in.readInt()
+  }
+}
+
+private[spark]
+object UpdateBlockInfo {
+  def apply(blockManagerId: BlockManagerId,
+      blockId: String,
+      storageLevel: StorageLevel,
+      memSize: Long,
+      diskSize: Long): UpdateBlockInfo = {
+    new UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize)
+  }
+
+  // For pattern-matching
+  def unapply(h: UpdateBlockInfo): Option[(BlockManagerId, String, StorageLevel, Long, Long)] = {
+    Some((h.blockManagerId, h.blockId, h.storageLevel, h.memSize, h.diskSize))
+  }
+}
+
+private[spark]
+case class GetLocations(blockId: String) extends ToBlockManagerMaster
+
+private[spark]
+case class GetLocationsMultipleBlockIds(blockIds: Array[String]) extends ToBlockManagerMaster
+
+private[spark]
+case class GetPeers(blockManagerId: BlockManagerId, size: Int) extends ToBlockManagerMaster
+
+private[spark]
+case class RemoveHost(host: String) extends ToBlockManagerMaster
+
+private[spark]
+case object StopBlockManagerMaster extends ToBlockManagerMaster
+
+private[spark]
+case object GetMemoryStatus extends ToBlockManagerMaster
+
+private[spark]
+case object ExpireDeadHosts extends ToBlockManagerMaster
diff --git a/core/src/main/scala/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/spark/storage/BlockManagerSlaveActor.scala
new file mode 100644
index 0000000000..f570cdc52d
--- /dev/null
+++ b/core/src/main/scala/spark/storage/BlockManagerSlaveActor.scala
@@ -0,0 +1,16 @@
+package spark.storage
+
+import akka.actor.Actor
+
+import spark.{Logging, SparkException, Utils}
+
+
+/**
+ * An actor to take commands from the master to execute operations. For example,
+ * this is used to remove blocks from the slave's BlockManager.
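A hedged sketch of the end-to-end removal path, using only names from this patch (someInfo stands for a BlockManagerInfo entry):

// Master side, inside BlockManagerMasterActor.removeBlock:
someInfo.slaveActor ! RemoveBlock(blockId)
// Slave side, inside BlockManagerSlaveActor.receive:
//   case RemoveBlock(blockId) => blockManager.removeBlock(blockId)
// which drops the block from both the memory and disk stores.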
+ */ +class BlockManagerSlaveActor(blockManager: BlockManager) extends Actor { + override def receive = { + case RemoveBlock(blockId) => blockManager.removeBlock(blockId) + } +} diff --git a/core/src/main/scala/spark/storage/BlockStore.scala b/core/src/main/scala/spark/storage/BlockStore.scala index 096bf8bdd9..8188d3595e 100644 --- a/core/src/main/scala/spark/storage/BlockStore.scala +++ b/core/src/main/scala/spark/storage/BlockStore.scala @@ -31,7 +31,12 @@ abstract class BlockStore(val blockManager: BlockManager) extends Logging { def getValues(blockId: String): Option[Iterator[Any]] - def remove(blockId: String) + /** + * Remove a block, if it exists. + * @param blockId the block to remove. + * @return True if the block was found and removed, False otherwise. + */ + def remove(blockId: String): Boolean def contains(blockId: String): Boolean diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala index 8ba64e4b76..7e5b820cbb 100644 --- a/core/src/main/scala/spark/storage/DiskStore.scala +++ b/core/src/main/scala/spark/storage/DiskStore.scala @@ -10,6 +10,8 @@ import it.unimi.dsi.fastutil.io.FastBufferedOutputStream import scala.collection.mutable.ArrayBuffer +import spark.executor.ExecutorExitCode + import spark.Utils /** @@ -90,10 +92,13 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes)) } - override def remove(blockId: String) { + override def remove(blockId: String): Boolean = { val file = getFile(blockId) if (file.exists()) { file.delete() + true + } else { + false } } @@ -162,7 +167,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) if (!foundLocalDir) { logError("Failed " + MAX_DIR_CREATION_ATTEMPTS + " attempts to create local dir in " + rootDir) - System.exit(1) + System.exit(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR) } logInfo("Created local directory at " + localDir) localDir diff --git a/core/src/main/scala/spark/storage/MemoryStore.scala b/core/src/main/scala/spark/storage/MemoryStore.scala index 773970446a..00e32f753c 100644 --- a/core/src/main/scala/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/spark/storage/MemoryStore.scala @@ -18,12 +18,16 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) private val entries = new LinkedHashMap[String, Entry](32, 0.75f, true) private var currentMemory = 0L + // Object used to ensure that only one thread is putting blocks and if necessary, dropping + // blocks from the memory store. 
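An illustrative interleaving that putLock rules out (block names and sizes are hypothetical):

// Without putLock, two concurrent puts could both claim the same freed space:
//   thread A: ensureFreeSpace("X", 100 MB)  -> evicts blocks, returns true
//   thread B: ensureFreeSpace("Y", 100 MB)  -> sees the same free space, returns true
//   both threads then insert their entries  -> memory is oversubscribed
// Serializing put and eviction under one lock prevents the double claim.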
+  private val putLock = new Object()
+
   logInfo("MemoryStore started with capacity %s.".format(Utils.memoryBytesToString(maxMemory)))

   def freeMemory: Long = maxMemory - currentMemory

   override def getSize(blockId: String): Long = {
-    synchronized {
+    entries.synchronized {
       entries.get(blockId).size
     }
   }
@@ -37,9 +41,6 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
       val sizeEstimate = SizeEstimator.estimate(elements.asInstanceOf[AnyRef])
       tryToPut(blockId, elements, sizeEstimate, true)
     } else {
-      val entry = new Entry(bytes, bytes.limit, false)
-      ensureFreeSpace(blockId, bytes.limit)
-      synchronized { entries.put(blockId, entry) }
       tryToPut(blockId, bytes, bytes.limit, false)
     }
   }
@@ -63,7 +64,7 @@
   }

   override def getBytes(blockId: String): Option[ByteBuffer] = {
-    val entry = synchronized {
+    val entry = entries.synchronized {
       entries.get(blockId)
     }
     if (entry == null) {
@@ -76,7 +77,7 @@
   }

   override def getValues(blockId: String): Option[Iterator[Any]] = {
-    val entry = synchronized {
+    val entry = entries.synchronized {
       entries.get(blockId)
     }
     if (entry == null) {
@@ -89,22 +90,23 @@
     }
   }

-  override def remove(blockId: String) {
-    synchronized {
+  override def remove(blockId: String): Boolean = {
+    entries.synchronized {
       val entry = entries.get(blockId)
       if (entry != null) {
         entries.remove(blockId)
         currentMemory -= entry.size
         logInfo("Block %s of size %d dropped from memory (free %d)".format(
           blockId, entry.size, freeMemory))
+        true
       } else {
-        logWarning("Block " + blockId + " could not be removed as it does not exist")
+        false
       }
     }
   }

   override def clear() {
-    synchronized {
+    entries.synchronized {
       entries.clear()
     }
     logInfo("MemoryStore cleared")
@@ -125,12 +127,22 @@
    * Try to put in a set of values, if we can free up enough space. The value should either be
    * an ArrayBuffer if deserialized is true or a ByteBuffer otherwise. Its (possibly estimated)
    * size must also be passed by the caller.
+   *
+   * Locks on the object putLock to ensure that all the put requests and their associated block
+   * dropping is done by only one thread at a time. Otherwise, while one thread is dropping
+   * blocks to free memory for one block, another thread may use up the freed space for
+   * another block.
    */
   private def tryToPut(blockId: String, value: Any, size: Long, deserialized: Boolean): Boolean = {
-    synchronized {
+    // TODO: It's possible to optimize the locking by locking entries only when selecting blocks
+    // to be dropped. Once the to-be-dropped blocks have been selected, and the lock on entries
+    // has been released, it must be ensured that those to-be-dropped blocks are not double
+    // counted for freeing up more space for another block that needs to be put. Only then can
+    // the actual dropping of blocks (and writing to disk if necessary) proceed in parallel.
+ putLock.synchronized { if (ensureFreeSpace(blockId, size)) { val entry = new Entry(value, size, deserialized) - entries.put(blockId, entry) + entries.synchronized { entries.put(blockId, entry) } currentMemory += size if (deserialized) { logInfo("Block %s stored as values to memory (estimated size %s, free %s)".format( @@ -160,10 +172,11 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) * block from the same RDD (which leads to a wasteful cyclic replacement pattern for RDDs that * don't fit into memory that we want to avoid). * - * Assumes that a lock on the MemoryStore is held by the caller. (Otherwise, the freed space - * might fill up before the caller puts in their new value.) + * Assumes that a lock is held by the caller to ensure only one thread is dropping blocks. + * Otherwise, the freed space may fill up before the caller puts in their new value. */ private def ensureFreeSpace(blockIdToAdd: String, space: Long): Boolean = { + logInfo("ensureFreeSpace(%d) called with curMem=%d, maxMem=%d".format( space, currentMemory, maxMemory)) @@ -172,36 +185,44 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) return false } - // TODO: This should relinquish the lock on the MemoryStore while flushing out old blocks - // in order to allow parallelism in writing to disk if (maxMemory - currentMemory < space) { val rddToAdd = getRddId(blockIdToAdd) val selectedBlocks = new ArrayBuffer[String]() var selectedMemory = 0L - val iterator = entries.entrySet().iterator() - while (maxMemory - (currentMemory - selectedMemory) < space && iterator.hasNext) { - val pair = iterator.next() - val blockId = pair.getKey - if (rddToAdd != null && rddToAdd == getRddId(blockId)) { - logInfo("Will not store " + blockIdToAdd + " as it would require dropping another " + - "block from the same RDD") - return false + // This is synchronized to ensure that the set of entries is not changed + // (because of getValue or getBytes) while traversing the iterator, as that + // can lead to exceptions. + entries.synchronized { + val iterator = entries.entrySet().iterator() + while (maxMemory - (currentMemory - selectedMemory) < space && iterator.hasNext) { + val pair = iterator.next() + val blockId = pair.getKey + if (rddToAdd != null && rddToAdd == getRddId(blockId)) { + logInfo("Will not store " + blockIdToAdd + " as it would require dropping another " + + "block from the same RDD") + return false + } + selectedBlocks += blockId + selectedMemory += pair.getValue.size } - selectedBlocks += blockId - selectedMemory += pair.getValue.size } if (maxMemory - (currentMemory - selectedMemory) >= space) { logInfo(selectedBlocks.size + " blocks selected for dropping") for (blockId <- selectedBlocks) { - val entry = entries.get(blockId) - val data = if (entry.deserialized) { - Left(entry.value.asInstanceOf[ArrayBuffer[Any]]) - } else { - Right(entry.value.asInstanceOf[ByteBuffer].duplicate()) + val entry = entries.synchronized { entries.get(blockId) } + // This should never be null as only one thread should be dropping + // blocks and removing entries. However the check is still here for + // future safety. 
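For orientation, a hedged note on the call below: dropFromMemory takes an Either[ArrayBuffer[Any], ByteBuffer], with Left carrying deserialized values and Right carrying serialized bytes. The duplicate() gives dropFromMemory an independent position and limit over the same underlying bytes, so concurrent readers of the original buffer are not disturbed.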
+ if (entry != null) { + val data = if (entry.deserialized) { + Left(entry.value.asInstanceOf[ArrayBuffer[Any]]) + } else { + Right(entry.value.asInstanceOf[ByteBuffer].duplicate()) + } + blockManager.dropFromMemory(blockId, data) } - blockManager.dropFromMemory(blockId, data) } return true } else { @@ -212,7 +233,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) } override def contains(blockId: String): Boolean = { - synchronized { entries.containsKey(blockId) } + entries.synchronized { entries.containsKey(blockId) } } } diff --git a/core/src/main/scala/spark/storage/StorageLevel.scala b/core/src/main/scala/spark/storage/StorageLevel.scala index c497f03e0c..e3544e5aae 100644 --- a/core/src/main/scala/spark/storage/StorageLevel.scala +++ b/core/src/main/scala/spark/storage/StorageLevel.scala @@ -1,6 +1,6 @@ package spark.storage -import java.io.{Externalizable, ObjectInput, ObjectOutput} +import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} /** * Flags for controlling the storage of an RDD. Each StorageLevel records whether to use memory, @@ -10,14 +10,16 @@ import java.io.{Externalizable, ObjectInput, ObjectOutput} * commonly useful storage levels. */ class StorageLevel( - var useDisk: Boolean, + var useDisk: Boolean, var useMemory: Boolean, var deserialized: Boolean, var replication: Int = 1) extends Externalizable { // TODO: Also add fields for caching priority, dataset ID, and flushing. - + + assert(replication < 40, "Replication restricted to be less than 40 for calculating hashcodes") + def this(flags: Int, replication: Int) { this((flags & 4) != 0, (flags & 2) != 0, (flags & 1) != 0, replication) } @@ -29,14 +31,14 @@ class StorageLevel( override def equals(other: Any): Boolean = other match { case s: StorageLevel => - s.useDisk == useDisk && + s.useDisk == useDisk && s.useMemory == useMemory && s.deserialized == deserialized && - s.replication == replication + s.replication == replication case _ => false } - + def isValid = ((useMemory || useDisk) && (replication > 0)) def toInt: Int = { @@ -66,10 +68,16 @@ class StorageLevel( replication = in.readByte() } + @throws(classOf[IOException]) + private def readResolve(): Object = StorageLevel.getCachedStorageLevel(this) + override def toString: String = "StorageLevel(%b, %b, %b, %d)".format(useDisk, useMemory, deserialized, replication) + + override def hashCode(): Int = toInt * 41 + replication } + object StorageLevel { val NONE = new StorageLevel(false, false, false) val DISK_ONLY = new StorageLevel(true, false, false) @@ -82,4 +90,16 @@ object StorageLevel { val MEMORY_AND_DISK_2 = new StorageLevel(true, true, true, 2) val MEMORY_AND_DISK_SER = new StorageLevel(true, true, false) val MEMORY_AND_DISK_SER_2 = new StorageLevel(true, true, false, 2) + + private[spark] + val storageLevelCache = new java.util.concurrent.ConcurrentHashMap[StorageLevel, StorageLevel]() + + private[spark] def getCachedStorageLevel(level: StorageLevel): StorageLevel = { + if (storageLevelCache.containsKey(level)) { + storageLevelCache.get(level) + } else { + storageLevelCache.put(level, level) + level + } + } } diff --git a/core/src/main/scala/spark/storage/ThreadingTest.scala b/core/src/main/scala/spark/storage/ThreadingTest.scala new file mode 100644 index 0000000000..689f07b969 --- /dev/null +++ b/core/src/main/scala/spark/storage/ThreadingTest.scala @@ -0,0 +1,96 @@ +package spark.storage + +import akka.actor._ + +import spark.KryoSerializer +import java.util.concurrent.ArrayBlockingQueue +import util.Random + 
+/**
+ * This class tests the BlockManager and MemoryStore for thread safety and
+ * deadlocks. It spawns a number of producer and consumer threads. Producer
+ * threads continuously push blocks into the BlockManager and consumer
+ * threads continuously retrieve the blocks from the BlockManager and test
+ * whether the blocks are correct or not.
+ */
+private[spark] object ThreadingTest {
+
+  val numProducers = 5
+  val numBlocksPerProducer = 20000
+
+  private[spark] class ProducerThread(manager: BlockManager, id: Int) extends Thread {
+    val queue = new ArrayBlockingQueue[(String, Seq[Int])](100)
+
+    override def run() {
+      for (i <- 1 to numBlocksPerProducer) {
+        val blockId = "b-" + id + "-" + i
+        val blockSize = Random.nextInt(1000)
+        val block = (1 to blockSize).map(_ => Random.nextInt())
+        val level = randomLevel()
+        val startTime = System.currentTimeMillis()
+        manager.put(blockId, block.iterator, level, true)
+        println("Pushed block " + blockId + " in " + (System.currentTimeMillis - startTime) + " ms")
+        queue.add((blockId, block))
+      }
+      println("Producer thread " + id + " terminated")
+    }
+
+    def randomLevel(): StorageLevel = {
+      math.abs(Random.nextInt()) % 4 match {
+        case 0 => StorageLevel.MEMORY_ONLY
+        case 1 => StorageLevel.MEMORY_ONLY_SER
+        case 2 => StorageLevel.MEMORY_AND_DISK
+        case 3 => StorageLevel.MEMORY_AND_DISK_SER
+      }
+    }
+  }
+
+  private[spark] class ConsumerThread(
+      manager: BlockManager,
+      queue: ArrayBlockingQueue[(String, Seq[Int])]
+    ) extends Thread {
+    var numBlockConsumed = 0
+
+    override def run() {
+      println("Consumer thread started")
+      while(numBlockConsumed < numBlocksPerProducer) {
+        val (blockId, block) = queue.take()
+        val startTime = System.currentTimeMillis()
+        manager.get(blockId) match {
+          case Some(retrievedBlock) =>
+            assert(retrievedBlock.toList.asInstanceOf[List[Int]] == block.toList,
+              "Block " + blockId + " did not match")
+            println("Got block " + blockId + " in " +
+              (System.currentTimeMillis - startTime) + " ms")
+          case None =>
+            assert(false, "Block " + blockId + " could not be retrieved")
+        }
+        numBlockConsumed += 1
+      }
+      println("Consumer thread terminated")
+    }
+  }
+
+  def main(args: Array[String]) {
+    System.setProperty("spark.kryoserializer.buffer.mb", "1")
+    val actorSystem = ActorSystem("test")
+    val serializer = new KryoSerializer
+    val masterIp: String = System.getProperty("spark.master.host", "localhost")
+    val masterPort: Int = System.getProperty("spark.master.port", "7077").toInt
+    val blockManagerMaster = new BlockManagerMaster(actorSystem, true, true, masterIp, masterPort)
+    val blockManager = new BlockManager(actorSystem, blockManagerMaster, serializer, 1024 * 1024)
+    val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i))
+    val consumers = producers.map(p => new ConsumerThread(blockManager, p.queue))
+    producers.foreach(_.start)
+    consumers.foreach(_.start)
+    producers.foreach(_.join)
+    consumers.foreach(_.join)
+    blockManager.stop()
+    blockManagerMaster.stop()
+    actorSystem.shutdown()
+    actorSystem.awaitTermination()
+    println("Everything stopped.")
+    println(
+      "It will take some time for the JVM to clean all temporary files and shutdown.
Sit tight.") + } +} diff --git a/core/src/main/scala/spark/util/AkkaUtils.scala b/core/src/main/scala/spark/util/AkkaUtils.scala index b466b5239c..e67cb0336d 100644 --- a/core/src/main/scala/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/spark/util/AkkaUtils.scala @@ -25,6 +25,8 @@ private[spark] object AkkaUtils { def createActorSystem(name: String, host: String, port: Int): (ActorSystem, Int) = { val akkaThreads = System.getProperty("spark.akka.threads", "4").toInt val akkaBatchSize = System.getProperty("spark.akka.batchSize", "15").toInt + val akkaTimeout = System.getProperty("spark.akka.timeout", "20").toInt + val akkaFrameSize = System.getProperty("spark.akka.frameSize", "10").toInt val akkaConf = ConfigFactory.parseString(""" akka.daemonic = on akka.event-handlers = ["akka.event.slf4j.Slf4jEventHandler"] @@ -32,10 +34,11 @@ private[spark] object AkkaUtils { akka.remote.transport = "akka.remote.netty.NettyRemoteTransport" akka.remote.netty.hostname = "%s" akka.remote.netty.port = %d - akka.remote.netty.connection-timeout = 1s + akka.remote.netty.connection-timeout = %ds + akka.remote.netty.message-frame-size = %d MiB akka.remote.netty.execution-pool-size = %d akka.actor.default-dispatcher.throughput = %d - """.format(host, port, akkaThreads, akkaBatchSize)) + """.format(host, port, akkaTimeout, akkaFrameSize, akkaThreads, akkaBatchSize)) val actorSystem = ActorSystem("spark", akkaConf, getClass.getClassLoader) diff --git a/core/src/main/scala/spark/util/IdGenerator.scala b/core/src/main/scala/spark/util/IdGenerator.scala new file mode 100644 index 0000000000..b6e309fe1a --- /dev/null +++ b/core/src/main/scala/spark/util/IdGenerator.scala @@ -0,0 +1,14 @@ +package spark.util + +import java.util.concurrent.atomic.AtomicInteger + +/** + * A util used to get a unique generation ID. This is a wrapper around Java's + * AtomicInteger. An example usage is in BlockManager, where each BlockManager + * instance would start an Akka actor and we use this utility to assign the Akka + * actors unique names. 
+ */
+private[spark] class IdGenerator {
+  private var id = new AtomicInteger
+  def next: Int = id.incrementAndGet
+}
diff --git a/core/src/main/scala/spark/util/MetadataCleaner.scala b/core/src/main/scala/spark/util/MetadataCleaner.scala
new file mode 100644
index 0000000000..19e67acd0c
--- /dev/null
+++ b/core/src/main/scala/spark/util/MetadataCleaner.scala
@@ -0,0 +1,35 @@
+package spark.util
+
+import java.util.concurrent.{TimeUnit, ScheduledFuture, Executors}
+import java.util.{TimerTask, Timer}
+import spark.Logging
+
+class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging {
+
+  val delaySeconds = (System.getProperty("spark.cleanup.delay", "-100").toDouble * 60).toInt
+  val periodSeconds = math.max(10, delaySeconds / 10)
+  val timer = new Timer(name + " cleanup timer", true)
+
+  val task = new TimerTask {
+    def run() {
+      try {
+        if (delaySeconds > 0) {
+          cleanupFunc(System.currentTimeMillis() - (delaySeconds * 1000))
+          logInfo("Ran metadata cleaner for " + name)
+        }
+      } catch {
+        case e: Exception => logError("Error running cleanup task for " + name, e)
+      }
+    }
+  }
+  if (periodSeconds > 0) {
+    logInfo(
+      "Starting metadata cleaner for " + name + " with delay of " + delaySeconds + " seconds and "
+      + "period of " + periodSeconds + " seconds")
+    timer.schedule(task, periodSeconds * 1000, periodSeconds * 1000)
+  }
+
+  def cancel() {
+    timer.cancel()
+  }
+}
diff --git a/core/src/main/scala/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/spark/util/TimeStampedHashMap.scala
new file mode 100644
index 0000000000..070ee19ac0
--- /dev/null
+++ b/core/src/main/scala/spark/util/TimeStampedHashMap.scala
@@ -0,0 +1,87 @@
+package spark.util
+
+import java.util.concurrent.ConcurrentHashMap
+import scala.collection.JavaConversions._
+import scala.collection.mutable.{HashMap, Map}
+
+/**
+ * This is a custom implementation of scala.collection.mutable.Map which stores the insertion
+ * time stamp along with each key-value pair. Key-value pairs that are older than a particular
+ * threshold time can then be removed using the cleanup method. This is intended to be a drop-in
+ * replacement for scala.collection.mutable.HashMap.
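A hedged usage sketch with a hypothetical ten-minute threshold:

// Entries carry their insertion time; cleanup drops the stale ones.
val map = new TimeStampedHashMap[String, Int]()
map("answer") = 42
map.cleanup(System.currentTimeMillis() - 10 * 60 * 1000)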
diff --git a/core/src/main/scala/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/spark/util/TimeStampedHashMap.scala
new file mode 100644
index 0000000000..070ee19ac0
--- /dev/null
+++ b/core/src/main/scala/spark/util/TimeStampedHashMap.scala
@@ -0,0 +1,87 @@
+package spark.util
+
+import java.util.concurrent.ConcurrentHashMap
+import scala.collection.JavaConversions._
+import scala.collection.mutable.{HashMap, Map}
+
+/**
+ * This is a custom implementation of scala.collection.mutable.Map which stores the insertion
+ * timestamp along with each key-value pair. Key-value pairs that are older than a particular
+ * threshold time can then be removed using the cleanup method. This is intended to be a drop-in
+ * replacement of scala.collection.mutable.HashMap.
+ */
+class TimeStampedHashMap[A, B] extends Map[A, B]() {
+  val internalMap = new ConcurrentHashMap[A, (B, Long)]()
+
+  def get(key: A): Option[B] = {
+    val value = internalMap.get(key)
+    if (value != null) Some(value._1) else None
+  }
+
+  def iterator: Iterator[(A, B)] = {
+    val jIterator = internalMap.entrySet().iterator()
+    jIterator.map(kv => (kv.getKey, kv.getValue._1))
+  }
+
+  override def + [B1 >: B](kv: (A, B1)): Map[A, B1] = {
+    val newMap = new TimeStampedHashMap[A, B1]
+    newMap.internalMap.putAll(this.internalMap)
+    newMap.internalMap.put(kv._1, (kv._2, currentTime))
+    newMap
+  }
+
+  override def - (key: A): Map[A, B] = {
+    internalMap.remove(key)
+    this
+  }
+
+  override def += (kv: (A, B)): this.type = {
+    internalMap.put(kv._1, (kv._2, currentTime))
+    this
+  }
+
+  override def -= (key: A): this.type = {
+    internalMap.remove(key)
+    this
+  }
+
+  override def update(key: A, value: B) {
+    this += ((key, value))
+  }
+
+  override def apply(key: A): B = {
+    val value = internalMap.get(key)
+    if (value == null) throw new NoSuchElementException()
+    value._1
+  }
+
+  override def filter(p: ((A, B)) => Boolean): Map[A, B] = {
+    internalMap.map(kv => (kv._1, kv._2._1)).filter(p)
+  }
+
+  override def empty: Map[A, B] = new TimeStampedHashMap[A, B]()
+
+  override def size(): Int = internalMap.size()
+
+  override def foreach[U](f: ((A, B)) => U): Unit = {
+    val iterator = internalMap.entrySet().iterator()
+    while(iterator.hasNext) {
+      val entry = iterator.next()
+      val kv = (entry.getKey, entry.getValue._1)
+      f(kv)
+    }
+  }
+
+  def cleanup(threshTime: Long) {
+    val iterator = internalMap.entrySet().iterator()
+    while(iterator.hasNext) {
+      val entry = iterator.next()
+      if (entry.getValue._2 < threshTime) {
+        iterator.remove()
+      }
+    }
+  }
+
+  private def currentTime: Long = System.currentTimeMillis()
+
+}
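A quick sketch of the intended pattern, using only the API shown above (illustrative, not part of the patch):

    // Behaves like a regular mutable.Map, but every put records a timestamp.
    val cache = new TimeStampedHashMap[String, Int]()
    cache("block-1") = 42
    assert(cache.get("block-1") == Some(42))

    // Typically driven by the MetadataCleaner above: drop entries older than an hour.
    cache.cleanup(System.currentTimeMillis() - 3600 * 1000)

Pairing it with MetadataCleaner gives the periodic eviction that keeps long-running drivers from accumulating stale metadata.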
diff --git a/core/src/main/scala/spark/util/Vector.scala b/core/src/main/scala/spark/util/Vector.scala
index 4e95ac2ac6..03559751bc 100644
--- a/core/src/main/scala/spark/util/Vector.scala
+++ b/core/src/main/scala/spark/util/Vector.scala
@@ -49,7 +49,7 @@ class Vector(val elements: Array[Double]) extends Serializable {
     return ans
   }
 
-  def +=(other: Vector) {
+  def += (other: Vector): Vector = {
     if (length != other.length)
       throw new IllegalArgumentException("Vectors of different length")
     var ans = 0.0
@@ -58,6 +58,7 @@ class Vector(val elements: Array[Double]) extends Serializable {
       elements(i) += other(i)
       i += 1
     }
+    this
   }
 
   def * (scale: Double): Vector = Vector(length, i => this(i) * scale)
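The signature change to += is what makes in-place accumulation composable: since the method now returns the mutated vector, it can sit directly in a fold. A hedged sketch (the Vector(length, fn) factory is inferred from the context line above; the data is made up):

    import spark.util.Vector

    // Sum equal-length vectors without allocating a new vector per step.
    val vectors = Seq(Vector(3, i => i.toDouble), Vector(3, i => 2.0 * i))
    val total = vectors.foldLeft(Vector(3, _ => 0.0))(_ += _)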
diff --git a/core/src/main/twirl/spark/deploy/master/index.scala.html b/core/src/main/twirl/spark/deploy/master/index.scala.html
index 7562076b00..18c32e5a1f 100644
--- a/core/src/main/twirl/spark/deploy/master/index.scala.html
+++ b/core/src/main/twirl/spark/deploy/master/index.scala.html
@@ -1,5 +1,6 @@
 @(state: spark.deploy.MasterState)
 @import spark.deploy.master._
+@import spark.Utils
 
 @spark.deploy.common.html.layout(title = "Spark Master on " + state.uri) {
 
@@ -8,9 +9,11 @@
     <div class="span12">
       <ul class="unstyled">
         <li><strong>URL:</strong> spark://@(state.uri)</li>
-        <li><strong>Number of Workers:</strong> @state.workers.size </li>
-        <li><strong>Cores:</strong> @state.workers.map(_.cores).sum Total, @state.workers.map(_.coresUsed).sum Used</li>
-        <li><strong>Memory:</strong> @state.workers.map(_.memory).sum Total, @state.workers.map(_.memoryUsed).sum Used</li>
+        <li><strong>Workers:</strong> @state.workers.size </li>
+        <li><strong>Cores:</strong> @{state.workers.map(_.cores).sum} Total,
+          @{state.workers.map(_.coresUsed).sum} Used</li>
+        <li><strong>Memory:</strong> @{Utils.memoryMegabytesToString(state.workers.map(_.memory).sum)} Total,
+          @{Utils.memoryMegabytesToString(state.workers.map(_.memoryUsed).sum)} Used</li>
         <li><strong>Jobs:</strong> @state.activeJobs.size Running, @state.completedJobs.size Completed </li>
       </ul>
     </div>
@@ -21,7 +24,7 @@
     <div class="span12">
       <h3> Cluster Summary </h3>
       <br/>
-      @worker_table(state.workers)
+      @worker_table(state.workers.sortBy(_.id))
     </div>
   </div>
 
@@ -32,7 +35,7 @@
     <div class="span12">
       <h3> Running Jobs </h3>
       <br/>
-      @job_table(state.activeJobs)
+      @job_table(state.activeJobs.sortBy(_.startTime).reverse)
     </div>
   </div>
 
@@ -43,7 +46,7 @@
     <div class="span12">
       <h3> Completed Jobs </h3>
       <br/>
-      @job_table(state.completedJobs)
+      @job_table(state.completedJobs.sortBy(_.endTime).reverse)
     </div>
   </div>
diff --git a/core/src/main/twirl/spark/deploy/master/job_row.scala.html b/core/src/main/twirl/spark/deploy/master/job_row.scala.html
index 7c4865bb6e..7c466a6a2c 100644
--- a/core/src/main/twirl/spark/deploy/master/job_row.scala.html
+++ b/core/src/main/twirl/spark/deploy/master/job_row.scala.html
@@ -1,20 +1,20 @@
 @(job: spark.deploy.master.JobInfo)
 
+@import spark.Utils
+@import spark.deploy.WebUI.formatDate
+@import spark.deploy.WebUI.formatDuration
+
 <tr>
   <td>
     <a href="job?jobId=@(job.id)">@job.id</a>
   </td>
   <td>@job.desc.name</td>
   <td>
-    @job.coresGranted Granted
-    @if(job.desc.cores == Integer.MAX_VALUE) {
-
-    } else {
-      , @job.coresLeft
-    }
+    @job.coresGranted
   </td>
-  <td>@job.desc.memoryPerSlave</td>
-  <td>@job.submitDate</td>
+  <td>@Utils.memoryMegabytesToString(job.desc.memoryPerSlave)</td>
+  <td>@formatDate(job.submitDate)</td>
   <td>@job.desc.user</td>
   <td>@job.state.toString()</td>
-</tr>
\ No newline at end of file
+  <td>@formatDuration(job.duration)</td>
+</tr>
diff --git a/core/src/main/twirl/spark/deploy/master/job_table.scala.html b/core/src/main/twirl/spark/deploy/master/job_table.scala.html
index 52bad6c4b8..d267d6e85e 100644
--- a/core/src/main/twirl/spark/deploy/master/job_table.scala.html
+++ b/core/src/main/twirl/spark/deploy/master/job_table.scala.html
@@ -1,4 +1,4 @@
-@(jobs: List[spark.deploy.master.JobInfo])
+@(jobs: Array[spark.deploy.master.JobInfo])
 
 <table class="table table-bordered table-striped table-condensed sortable">
   <thead>
@@ -6,10 +6,11 @@
       <th>JobID</th>
       <th>Description</th>
       <th>Cores</th>
-      <th>Memory per Slave</th>
-      <th>Submit Date</th>
+      <th>Memory per Node</th>
+      <th>Submit Time</th>
       <th>User</th>
       <th>State</th>
+      <th>Duration</th>
     </tr>
   </thead>
   <tbody>
@@ -17,4 +18,4 @@
     @job_row(j)
   }
 </tbody>
-</table>
\ No newline at end of file
+</table>
diff --git a/core/src/main/twirl/spark/deploy/master/worker_row.scala.html b/core/src/main/twirl/spark/deploy/master/worker_row.scala.html
index 017cc4859e..be69e9bf02 100644
--- a/core/src/main/twirl/spark/deploy/master/worker_row.scala.html
+++ b/core/src/main/twirl/spark/deploy/master/worker_row.scala.html
@@ -1,11 +1,14 @@
 @(worker: spark.deploy.master.WorkerInfo)
 
+@import spark.Utils
+
 <tr>
   <td>
-    <a href="http://@worker.host:@worker.webUiPort">@worker.id</href>
+    <a href="@worker.webUiAddress">@worker.id</a>
   </td>
   <td>@{worker.host}:@{worker.port}</td>
+  <td>@worker.state</td>
   <td>@worker.cores (@worker.coresUsed Used)</td>
-  <td>@{spark.Utils.memoryMegabytesToString(worker.memory)}
-    (@{spark.Utils.memoryMegabytesToString(worker.memoryUsed)} Used)</td>
+  <td>@{Utils.memoryMegabytesToString(worker.memory)}
+    (@{Utils.memoryMegabytesToString(worker.memoryUsed)} Used)</td>
 </tr>
diff --git a/core/src/main/twirl/spark/deploy/master/worker_table.scala.html b/core/src/main/twirl/spark/deploy/master/worker_table.scala.html
index 2028842297..b249411a62 100644
--- a/core/src/main/twirl/spark/deploy/master/worker_table.scala.html
+++ b/core/src/main/twirl/spark/deploy/master/worker_table.scala.html
@@ -1,10 +1,11 @@
-@(workers: List[spark.deploy.master.WorkerInfo])
+@(workers: Array[spark.deploy.master.WorkerInfo])
 
 <table class="table table-bordered table-striped table-condensed sortable">
   <thead>
     <tr>
       <th>ID</th>
       <th>Address</th>
+      <th>State</th>
       <th>Cores</th>
       <th>Memory</th>
     </tr>
@@ -14,4 +15,4 @@
     @worker_row(w)
   }
 </tbody>
-</table>
\ No newline at end of file
+</table>
diff --git a/core/src/main/twirl/spark/deploy/worker/executor_row.scala.html b/core/src/main/twirl/spark/deploy/worker/executor_row.scala.html
index c3842dbf85..ea9542461e 100644
--- a/core/src/main/twirl/spark/deploy/worker/executor_row.scala.html
+++ b/core/src/main/twirl/spark/deploy/worker/executor_row.scala.html
@@ -1,20 +1,20 @@
 @(executor: spark.deploy.worker.ExecutorRunner)
 
+@import spark.Utils
+
 <tr>
   <td>@executor.execId</td>
   <td>@executor.cores</td>
-  <td>@executor.memory</td>
+  <td>@Utils.memoryMegabytesToString(executor.memory)</td>
   <td>
     <ul class="unstyled">
       <li><strong>ID:</strong> @executor.jobId</li>
      <li><strong>Name:</strong> @executor.jobDesc.name</li>
       <li><strong>User:</strong> @executor.jobDesc.user</li>
-      <li><strong>Cores:</strong> @executor.jobDesc.cores </li>
-      <li><strong>Memory per Slave:</strong> @executor.jobDesc.memoryPerSlave</li>
     </ul>
   </td>
   <td>
     <a href="log?jobId=@(executor.jobId)&executorId=@(executor.execId)&logType=stdout">stdout</a>
     <a href="log?jobId=@(executor.jobId)&executorId=@(executor.execId)&logType=stderr">stderr</a>
   </td>
-</tr>
\ No newline at end of file
+</tr>
diff --git a/core/src/main/twirl/spark/deploy/worker/index.scala.html b/core/src/main/twirl/spark/deploy/worker/index.scala.html
index 69746ed02c..b247307dab 100644
--- a/core/src/main/twirl/spark/deploy/worker/index.scala.html
+++ b/core/src/main/twirl/spark/deploy/worker/index.scala.html
@@ -1,5 +1,7 @@
 @(worker: spark.deploy.WorkerState)
 
+@import spark.Utils
+
 @spark.deploy.common.html.layout(title = "Spark Worker on " + worker.uri) {
 
   <!-- Worker Details -->
@@ -12,8 +14,8 @@
           (WebUI at <a href="@worker.masterWebUiUrl">@worker.masterWebUiUrl</a>)
         </li>
         <li><strong>Cores:</strong> @worker.cores (@worker.coresUsed Used)</li>
-        <li><strong>Memory:</strong> @{spark.Utils.memoryMegabytesToString(worker.memory)}
-          (@{spark.Utils.memoryMegabytesToString(worker.memoryUsed)} Used)</li>
+        <li><strong>Memory:</strong> @{Utils.memoryMegabytesToString(worker.memory)}
+          (@{Utils.memoryMegabytesToString(worker.memoryUsed)} Used)</li>
       </ul>
     </div>
   </div>
diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java
index 5875506179..33d5fc2d89 100644
--- a/core/src/test/scala/spark/JavaAPISuite.java
+++ b/core/src/test/scala/spark/JavaAPISuite.java
@@ -1,5 +1,12 @@
 package spark;
 
+import java.io.File;
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.*;
+
+import scala.Tuple2;
+
 import com.google.common.base.Charsets;
 import com.google.common.io.Files;
 import org.apache.hadoop.io.IntWritable;
@@ -12,8 +19,6 @@ import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
 
-import scala.Tuple2;
-
 import spark.api.java.JavaDoubleRDD;
 import spark.api.java.JavaPairRDD;
 import spark.api.java.JavaRDD;
@@ -24,10 +29,6 @@ import spark.partial.PartialResult;
 import spark.storage.StorageLevel;
 import spark.util.StatCounter;
 
-import java.io.File;
-import java.io.IOException;
-import java.io.Serializable;
-import java.util.*;
 
 // The test suite itself is Serializable so that anonymous Function implementations can be
 // serialized, as an alternative to converting these anonymous classes to static inner classes;
@@ -44,6 +45,8 @@ public class JavaAPISuite implements Serializable {
   public void tearDown() {
     sc.stop();
     sc = null;
+    // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
+    System.clearProperty("spark.master.port");
   }
 
   static class ReverseIntComparator implements Comparator<Integer>, Serializable {
@@ -128,6 +131,17 @@ public class JavaAPISuite implements Serializable {
   }
 
   @Test
+  public void lookup() {
+    JavaPairRDD<String, String> categories = sc.parallelizePairs(Arrays.asList(
+      new Tuple2<String, String>("Apples", "Fruit"),
+      new Tuple2<String, String>("Oranges", "Fruit"),
+      new Tuple2<String, String>("Oranges", "Citrus")
+      ));
+    Assert.assertEquals(2, categories.lookup("Oranges").size());
+    Assert.assertEquals(2, categories.groupByKey().lookup("Oranges").get(0).size());
+  }
+
+  @Test
   public void groupBy() {
     JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 1, 2, 3, 5, 8, 13));
     Function<Integer, Boolean> isOdd = new Function<Integer, Boolean>() {
@@ -381,7 +395,8 @@ public class JavaAPISuite implements Serializable {
   @Test
   public void iterator() {
     JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2);
-    Assert.assertEquals(1, rdd.iterator(rdd.splits().get(0)).next().intValue());
+    TaskContext context = new TaskContext(0, 0, 0);
+    Assert.assertEquals(1, rdd.iterator(rdd.splits().get(0), context).next().intValue());
   }
 
   @Test
@@ -553,4 +568,17 @@ public class JavaAPISuite implements Serializable {
       }
     }).collect().toString());
   }
+
+  @Test
+  public void zip() {
+    JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
+    JavaDoubleRDD doubles = rdd.map(new DoubleFunction<Integer>() {
+      @Override
+      public Double call(Integer x) {
+        return 1.0 * x;
+      }
+    });
+    JavaPairRDD<Integer, Double> zipped = rdd.zip(doubles);
+    zipped.count();
+  }
 }
diff --git a/core/src/test/scala/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/spark/MapOutputTrackerSuite.scala
index 4e9717d871..5b4b198960 100644
--- a/core/src/test/scala/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/spark/MapOutputTrackerSuite.scala
@@ -2,10 +2,14 @@ package spark
 
 import org.scalatest.FunSuite
 
+import akka.actor._
+import spark.scheduler.MapStatus
+import spark.storage.BlockManagerId
+
 class MapOutputTrackerSuite extends FunSuite {
   test("compressSize") {
     assert(MapOutputTracker.compressSize(0L) === 0)
-    assert(MapOutputTracker.compressSize(1L) === 0)
+    assert(MapOutputTracker.compressSize(1L) === 1)
     assert(MapOutputTracker.compressSize(2L) === 8)
     assert(MapOutputTracker.compressSize(10L) === 25)
     assert((MapOutputTracker.compressSize(1000000L) & 0xFF) === 145)
@@ -15,11 +19,58 @@ class MapOutputTrackerSuite extends FunSuite {
   }
 
   test("decompressSize") {
-    assert(MapOutputTracker.decompressSize(0) === 1)
+    assert(MapOutputTracker.decompressSize(0) === 0)
     for (size <- Seq(2L, 10L, 100L, 50000L, 1000000L, 1000000000L)) {
       val size2 = MapOutputTracker.decompressSize(MapOutputTracker.compressSize(size))
       assert(size2 >= 0.99 * size && size2 <= 1.11 * size,
         "size " + size + " decompressed to " + size2 + ", which is out of range")
     }
   }
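The corrected expectations follow from the one-byte log encoding that MapOutputTracker uses for map output sizes. A hedged sketch of the scheme these asserts imply (the actual implementation is not part of this diff, so details may differ):

    // Encode a byte count into a single byte: 0 stays 0, 1 maps to 1, and larger
    // sizes become roughly ceil(log base 1.1 of size), so a decoded size
    // overshoots the true size by at most about 10%.
    def compress(size: Long): Byte = {
      if (size == 0) 0.toByte
      else if (size <= 1L) 1.toByte
      else math.min(255, math.ceil(math.log(size) / math.log(1.1)).toInt).toByte
    }

    def decompress(b: Byte): Long =
      if (b == 0) 0L else math.pow(1.1, b & 0xFF).toLong

    // Worked check against the asserts above:
    //   compress(2)        == ceil(ln 2 / ln 1.1)   == ceil(7.27)   == 8
    //   compress(10)       == ceil(ln 10 / ln 1.1)  == ceil(24.16)  == 25
    //   compress(1000000L) == ceil(ln 1e6 / ln 1.1) == ceil(144.96) == 145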
+
+  test("master start and stop") {
+    val actorSystem = ActorSystem("test")
+    val tracker = new MapOutputTracker(actorSystem, true)
+    tracker.stop()
+  }
+
+  test("master register and fetch") {
+    val actorSystem = ActorSystem("test")
+    val tracker = new MapOutputTracker(actorSystem, true)
+    tracker.registerShuffle(10, 2)
+    val compressedSize1000 = MapOutputTracker.compressSize(1000L)
+    val compressedSize10000 = MapOutputTracker.compressSize(10000L)
+    val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
+    val size10000 = MapOutputTracker.decompressSize(compressedSize10000)
+    tracker.registerMapOutput(10, 0, new MapStatus(new BlockManagerId("hostA", 1000),
+        Array(compressedSize1000, compressedSize10000)))
+    tracker.registerMapOutput(10, 1, new MapStatus(new BlockManagerId("hostB", 1000),
+        Array(compressedSize10000, compressedSize1000)))
+    val statuses = tracker.getServerStatuses(10, 0)
+    assert(statuses.toSeq === Seq((new BlockManagerId("hostA", 1000), size1000),
+                                  (new BlockManagerId("hostB", 1000), size10000)))
+    tracker.stop()
+  }
+
+  test("master register and unregister and fetch") {
+    val actorSystem = ActorSystem("test")
+    val tracker = new MapOutputTracker(actorSystem, true)
+    tracker.registerShuffle(10, 2)
+    val compressedSize1000 = MapOutputTracker.compressSize(1000L)
+    val compressedSize10000 = MapOutputTracker.compressSize(10000L)
+    val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
+    val size10000 = MapOutputTracker.decompressSize(compressedSize10000)
+    tracker.registerMapOutput(10, 0, new MapStatus(new BlockManagerId("hostA", 1000),
+        Array(compressedSize1000, compressedSize1000, compressedSize1000)))
+    tracker.registerMapOutput(10, 1, new MapStatus(new BlockManagerId("hostB", 1000),
+        Array(compressedSize10000, compressedSize1000, compressedSize1000)))
+
+    // As if we had two simultaneous fetch failures
+    tracker.unregisterMapOutput(10, 0, new BlockManagerId("hostA", 1000))
+    tracker.unregisterMapOutput(10, 0, new BlockManagerId("hostA", 1000))
+
+    // The remaining reduce task might try to grab the output despite the shuffle failure;
+    // this should cause it to fail, and the scheduler will ignore the failure due to the
+    // stage already being aborted.
+    intercept[Exception] { tracker.getServerStatuses(10, 1) }
+  }
 }
diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala
index 37a0ff0947..08da9a1c4d 100644
--- a/core/src/test/scala/spark/RDDSuite.scala
+++ b/core/src/test/scala/spark/RDDSuite.scala
@@ -8,9 +8,9 @@ import spark.rdd.CoalescedRDD
 import SparkContext._
 
 class RDDSuite extends FunSuite with BeforeAndAfter {
-  
+
   var sc: SparkContext = _
-  
+
   after {
     if (sc != null) {
       sc.stop()
@@ -19,11 +19,15 @@ class RDDSuite extends FunSuite with BeforeAndAfter {
     // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
     System.clearProperty("spark.master.port")
   }
-  
+
   test("basic operations") {
     sc = new SparkContext("local", "test")
     val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
     assert(nums.collect().toList === List(1, 2, 3, 4))
+    val dups = sc.makeRDD(Array(1, 1, 2, 2, 3, 3, 4, 4), 2)
+    assert(dups.distinct.count === 4)
+    assert(dups.distinct().collect === dups.distinct.collect)
+    assert(dups.distinct(2).collect === dups.distinct.collect)
     assert(nums.reduce(_ + _) === 10)
     assert(nums.fold(0)(_ + _) === 10)
     assert(nums.map(_.toString).collect().toList === List("1", "2", "3", "4"))
@@ -114,4 +118,16 @@ class RDDSuite extends FunSuite with BeforeAndAfter {
     assert(coalesced4.glom().collect().map(_.toList).toList ===
       (1 to 10).map(x => List(x)).toList)
   }
+
+  test("zipped RDDs") {
+    sc = new SparkContext("local", "test")
+    val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
+    val zipped = nums.zip(nums.map(_ + 1.0))
+    assert(zipped.glom().map(_.toList).collect().toList ===
+      List(List((1, 2.0), (2, 3.0)), List((3, 4.0), (4, 5.0))))
+
+    intercept[IllegalArgumentException] {
+      nums.zip(sc.parallelize(1 to 4, 1)).collect()
+    }
+  }
 }
diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala
index 7f8ec5d48f..8170100f1d 100644
--- a/core/src/test/scala/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/spark/ShuffleSuite.scala
@@ -12,8 +12,8 @@ import org.scalacheck.Prop._
 
 import com.google.common.io.Files
 
-import spark.rdd.ShuffledAggregatedRDD
-import SparkContext._
+import spark.rdd.ShuffledRDD
+import spark.SparkContext._
 
 class ShuffleSuite extends FunSuite with ShouldMatchers with BeforeAndAfter {
 
@@ -216,41 +216,6 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with BeforeAndAfter {
     // Test that a shuffle on the file works, because this used to be a bug
     assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil)
   }
-
-  test("map-side combine") {
-    sc = new SparkContext("local", "test")
-    val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1), (1, 1)), 2)
-
-    // Test with map-side combine on.
-    val sums = pairs.reduceByKey(_+_).collect()
-    assert(sums.toSet === Set((1, 8), (2, 1)))
-
-    // Turn off map-side combine and test the results.
-    val aggregator = new Aggregator[Int, Int, Int](
-      (v: Int) => v,
-      _+_,
-      _+_,
-      false)
-    val shuffledRdd = new ShuffledAggregatedRDD(
-      pairs, aggregator, new HashPartitioner(2))
-    assert(shuffledRdd.collect().toSet === Set((1, 8), (2, 1)))
-
-    // Turn map-side combine off and pass a wrong mergeCombine function. Should
-    // not see an exception because mergeCombine should not have been called.
-    val aggregatorWithException = new Aggregator[Int, Int, Int](
-      (v: Int) => v, _+_, ShuffleSuite.mergeCombineException, false)
-    val shuffledRdd1 = new ShuffledAggregatedRDD(
-      pairs, aggregatorWithException, new HashPartitioner(2))
-    assert(shuffledRdd1.collect().toSet === Set((1, 8), (2, 1)))
-
-    // Now run the same mergeCombine function with map-side combine on. We
-    // expect to see an exception thrown.
-    val aggregatorWithException1 = new Aggregator[Int, Int, Int](
-      (v: Int) => v, _+_, ShuffleSuite.mergeCombineException)
-    val shuffledRdd2 = new ShuffledAggregatedRDD(
-      pairs, aggregatorWithException1, new HashPartitioner(2))
-    evaluating { shuffledRdd2.collect() } should produce [SparkException]
-  }
 }
 
 object ShuffleSuite {
diff --git a/core/src/test/scala/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/spark/storage/BlockManagerSuite.scala
index b9c19e61cd..8f86e3170e 100644
--- a/core/src/test/scala/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/spark/storage/BlockManagerSuite.scala
@@ -7,6 +7,10 @@ import akka.actor._
 import org.scalatest.FunSuite
 import org.scalatest.BeforeAndAfter
 import org.scalatest.PrivateMethodTester
+import org.scalatest.concurrent.Eventually._
+import org.scalatest.concurrent.Timeouts._
+import org.scalatest.matchers.ShouldMatchers._
+import org.scalatest.time.SpanSugar._
 
 import spark.KryoSerializer
 import spark.SizeEstimator
@@ -14,21 +18,25 @@ import spark.util.ByteBufferInputStream
 
 class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodTester {
   var store: BlockManager = null
+  var store2: BlockManager = null
   var actorSystem: ActorSystem = null
   var master: BlockManagerMaster = null
   var oldArch: String = null
   var oldOops: String = null
-
-  // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test
+  var oldHeartBeat: String = null
+
+  // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test
+  System.setProperty("spark.kryoserializer.buffer.mb", "1")
   val serializer = new KryoSerializer
 
   before {
     actorSystem = ActorSystem("test")
-    master = new BlockManagerMaster(actorSystem, true, true)
+    master = new BlockManagerMaster(actorSystem, true, true, "localhost", 7077)
 
-    // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case
+    // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case
     oldArch = System.setProperty("os.arch", "amd64")
     oldOops = System.setProperty("spark.test.useCompressedOops", "true")
+    oldHeartBeat = System.setProperty("spark.storage.disableBlockManagerHeartBeat", "true")
     val initialize = PrivateMethod[Unit]('initialize)
     SizeEstimator invokePrivate initialize()
   }
@@ -36,6 +44,11 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
   after {
     if (store != null) {
      store.stop()
+      store = null
+    }
+    if (store2 != null) {
+      store2.stop()
+      store2 = null
     }
     actorSystem.shutdown()
     actorSystem.awaitTermination()
@@ -55,8 +68,34 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
     }
   }
 
-  test("manager-master interaction") {
-    store = new BlockManager(master, serializer, 2000)
+  test("StorageLevel object caching") {
+    val level1 = new StorageLevel(false, false, false, 3)
+    val level2 = new StorageLevel(false, false, false, 3)
+    val bytes1 = spark.Utils.serialize(level1)
+    val level1_ = spark.Utils.deserialize[StorageLevel](bytes1)
+    val bytes2 = spark.Utils.serialize(level2)
+    val level2_ = spark.Utils.deserialize[StorageLevel](bytes2)
+    assert(level1_ === level1, "Deserialized level1 not same as original level1")
+    assert(level2_ === level2, "Deserialized level2 not same as original level2")
+    assert(level1_ === level2_, "Deserialized level1 not same as deserialized level2")
+    assert(level2_.eq(level1_), "Deserialized level2 not the same object as deserialized level1")
+  }
+
+  test("BlockManagerId object caching") {
+    val id1 = new BlockManagerId("localhost", 1)
+    val id2 = new BlockManagerId("localhost", 1)
+    val bytes1 = spark.Utils.serialize(id1)
+    val id1_ = spark.Utils.deserialize[BlockManagerId](bytes1)
+    val bytes2 = spark.Utils.serialize(id2)
+    val id2_ = spark.Utils.deserialize[BlockManagerId](bytes2)
+    assert(id1_ === id1, "Deserialized id1 not same as original id1")
+    assert(id2_ === id2, "Deserialized id2 not same as original id2")
+    assert(id1_ === id2_, "Deserialized id1 not same as deserialized id2")
+    assert(id2_.eq(id1_), "Deserialized id2 not the same object as deserialized id1")
+  }
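The two `.eq` assertions check reference identity after a serialize/deserialize round trip, i.e. that deserialization canonicalizes these small descriptor objects through a cache. A hedged sketch of the usual Java-serialization idiom behind such behavior (the real classes' internals are not shown in this patch; all names below are illustrative):

    import java.util.concurrent.ConcurrentHashMap

    // A serializable descriptor that canonicalizes itself when it is read
    // back, so equal instances deserialize to one shared object.
    class Descriptor(val ip: String, val port: Int) extends Serializable {
      override def equals(other: Any): Boolean = other match {
        case d: Descriptor => d.ip == ip && d.port == port
        case _ => false
      }
      override def hashCode: Int = ip.hashCode * 41 + port

      // Invoked by Java serialization after readObject; returns the cached twin.
      private def readResolve(): Object = Descriptor.getCachedDescriptor(this)
    }

    object Descriptor {
      private val cache = new ConcurrentHashMap[Descriptor, Descriptor]()
      def getCachedDescriptor(d: Descriptor): Descriptor = {
        val old = cache.putIfAbsent(d, d)
        if (old != null) old else d
      }
    }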
+
+  test("master + 1 manager interaction") {
+    store = new BlockManager(actorSystem, master, serializer, 2000)
     val a1 = new Array[Byte](400)
     val a2 = new Array[Byte](400)
     val a3 = new Array[Byte](400)
@@ -66,27 +105,126 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
     store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY)
     store.putSingle("a3", a3, StorageLevel.MEMORY_ONLY, false)
 
-    // Checking whether blocks are in memory 
+    // Checking whether blocks are in memory
     assert(store.getSingle("a1") != None, "a1 was not in store")
     assert(store.getSingle("a2") != None, "a2 was not in store")
     assert(store.getSingle("a3") != None, "a3 was not in store")
 
     // Checking whether master knows about the blocks or not
-    assert(master.mustGetLocations(GetLocations("a1")).size > 0, "master was not told about a1")
-    assert(master.mustGetLocations(GetLocations("a2")).size > 0, "master was not told about a2")
-    assert(master.mustGetLocations(GetLocations("a3")).size === 0, "master was told about a3")
-    
+    assert(master.getLocations("a1").size > 0, "master was not told about a1")
+    assert(master.getLocations("a2").size > 0, "master was not told about a2")
+    assert(master.getLocations("a3").size === 0, "master was told about a3")
+
     // Drop a1 and a2 from memory; this should be reported back to the master
     store.dropFromMemory("a1", null)
     store.dropFromMemory("a2", null)
     assert(store.getSingle("a1") === None, "a1 not removed from store")
     assert(store.getSingle("a2") === None, "a2 not removed from store")
-    assert(master.mustGetLocations(GetLocations("a1")).size === 0, "master did not remove a1")
-    assert(master.mustGetLocations(GetLocations("a2")).size === 0, "master did not remove a2")
+    assert(master.getLocations("a1").size === 0, "master did not remove a1")
+    assert(master.getLocations("a2").size === 0, "master did not remove a2")
+  }
+
+  test("master + 2 managers interaction") {
+    store = new BlockManager(actorSystem, master, serializer, 2000)
+    store2 = new BlockManager(actorSystem, master, new KryoSerializer, 2000)
+
+    val peers = master.getPeers(store.blockManagerId, 1)
+    assert(peers.size === 1, "master did not return the other manager as a peer")
+    assert(peers.head === store2.blockManagerId, "peer returned by master is not the other manager")
+
+    val a1 = new Array[Byte](400)
+    val a2 = new Array[Byte](400)
+    store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY_2)
+    store2.putSingle("a2", a2, StorageLevel.MEMORY_ONLY_2)
+    assert(master.getLocations("a1").size === 2, "master did not report 2 locations for a1")
+    assert(master.getLocations("a2").size === 2, "master did not report 2 locations for a2")
+  }
+
+  test("removing block") {
+    store = new BlockManager(actorSystem, master, serializer, 2000)
+    val a1 = new Array[Byte](400)
+    val a2 = new Array[Byte](400)
+    val a3 = new Array[Byte](400)
+
+    // Putting a1, a2 and a3 in memory and telling master only about a1 and a2
+    store.putSingle("a1-to-remove", a1, StorageLevel.MEMORY_ONLY)
+    store.putSingle("a2-to-remove", a2, StorageLevel.MEMORY_ONLY)
+    store.putSingle("a3-to-remove", a3, StorageLevel.MEMORY_ONLY, false)
+
+    // Checking whether blocks are in memory and memory size
+    val memStatus = master.getMemoryStatus.head._2
+    assert(memStatus._1 == 2000L, "total memory " + memStatus._1 + " should equal 2000")
+    assert(memStatus._2 <= 1200L, "remaining memory " + memStatus._2 + " should be <= 1200")
+    assert(store.getSingle("a1-to-remove") != None, "a1 was not in store")
+    assert(store.getSingle("a2-to-remove") != None, "a2 was not in store")
+    assert(store.getSingle("a3-to-remove") != None, "a3 was not in store")
+
+    // Checking whether master knows about the blocks or not
+    assert(master.getLocations("a1-to-remove").size > 0, "master was not told about a1")
+    assert(master.getLocations("a2-to-remove").size > 0, "master was not told about a2")
+    assert(master.getLocations("a3-to-remove").size === 0, "master was told about a3")
+
+    // Remove a1, a2 and a3; this should be a no-op for a3, which the master was never told about.
+    master.removeBlock("a1-to-remove")
+    master.removeBlock("a2-to-remove")
+    master.removeBlock("a3-to-remove")
+
+    eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
+      store.getSingle("a1-to-remove") should be (None)
+      master.getLocations("a1-to-remove") should have size 0
+    }
+    eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
+      store.getSingle("a2-to-remove") should be (None)
+      master.getLocations("a2-to-remove") should have size 0
+    }
+    eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
+      store.getSingle("a3-to-remove") should not be (None)
+      master.getLocations("a3-to-remove") should have size 0
+    }
+    eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
+      val memStatus = master.getMemoryStatus.head._2
+      memStatus._1 should equal (2000L)
+      memStatus._2 should equal (2000L)
+    }
+  }
+
+  test("reregistration on heart beat") {
+    val heartBeat = PrivateMethod[Unit]('heartBeat)
+    store = new BlockManager(actorSystem, master, serializer, 2000)
+    val a1 = new Array[Byte](400)
+
+    store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY)
+
+    assert(store.getSingle("a1") != None, "a1 was not in store")
+    assert(master.getLocations("a1").size > 0, "master was not told about a1")
+
+    master.notifyADeadHost(store.blockManagerId.ip)
+    assert(master.getLocations("a1").size == 0, "a1 was not removed from master")
+
+    store invokePrivate heartBeat()
+    assert(master.getLocations("a1").size > 0, "a1 was not reregistered with master")
+  }
+
+  test("reregistration on block update") {
+    store = new BlockManager(actorSystem, master, serializer, 2000)
+    val a1 = new Array[Byte](400)
+    val a2 = new Array[Byte](400)
+
+    store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY)
+
+    assert(master.getLocations("a1").size > 0, "master was not told about a1")
+
+    master.notifyADeadHost(store.blockManagerId.ip)
+    assert(master.getLocations("a1").size == 0, "a1 was not removed from master")
+
+    store.putSingle("a2", a1, StorageLevel.MEMORY_ONLY)
+
+    assert(master.getLocations("a1").size > 0, "a1 was not reregistered with master")
+    assert(master.getLocations("a2").size > 0, "master was not told about a2")
   }
 
   test("in-memory LRU storage") {
-    store = new BlockManager(master, serializer, 1200)
+    store = new BlockManager(actorSystem, master, serializer, 1200)
     val a1 = new Array[Byte](400)
     val a2 = new Array[Byte](400)
     val a3 = new Array[Byte](400)
@@ -103,9 +241,9 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
     assert(store.getSingle("a2") != None, "a2 was not in store")
     assert(store.getSingle("a3") === None, "a3 was in store")
   }
-  
+
   test("in-memory LRU storage with serialization") {
-    store = new BlockManager(master, serializer, 1200)
+    store = new BlockManager(actorSystem, master, serializer, 1200)
     val a1 = new Array[Byte](400)
     val a2 = new Array[Byte](400)
     val a3 = new Array[Byte](400)
@@ -124,7 +262,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
   }
 
   test("in-memory LRU for partitions of same RDD") {
-    store = new BlockManager(master, serializer, 1200)
+    store = new BlockManager(actorSystem, master, serializer, 1200)
     val a1 = new Array[Byte](400)
     val a2 = new Array[Byte](400)
     val a3 = new Array[Byte](400)
@@ -143,7 +281,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
   }
 
   test("in-memory LRU for partitions of multiple RDDs") {
-    store = new BlockManager(master, serializer, 1200)
+    store = new BlockManager(actorSystem, master, serializer, 1200)
     store.putSingle("rdd_0_1", new Array[Byte](400), StorageLevel.MEMORY_ONLY)
     store.putSingle("rdd_0_2", new Array[Byte](400), StorageLevel.MEMORY_ONLY)
     store.putSingle("rdd_1_1", new Array[Byte](400), StorageLevel.MEMORY_ONLY)
@@ -166,7 +304,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
   }
 
   test("on-disk storage") {
-    store = new BlockManager(master, serializer, 1200)
+    store = new BlockManager(actorSystem, master, serializer, 1200)
     val a1 = new Array[Byte](400)
     val a2 = new Array[Byte](400)
     val a3 = new Array[Byte](400)
@@ -179,7 +317,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
   }
 
   test("disk and memory storage") {
-    store = new BlockManager(master, serializer, 1200)
+    store = new BlockManager(actorSystem, master, serializer, 1200)
     val a1 = new Array[Byte](400)
     val a2 = new Array[Byte](400)
     val a3 = new Array[Byte](400)
@@ -194,7 +332,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
   }
 
   test("disk and memory storage with getLocalBytes") {
-    store = new BlockManager(master, serializer, 1200)
+    store = new BlockManager(actorSystem, master, serializer, 1200)
     val a1 = new Array[Byte](400)
     val a2 = new Array[Byte](400)
     val a3 = new Array[Byte](400)
@@ -209,7 +347,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
   }
 
   test("disk and memory storage with serialization") {
-    store = new BlockManager(master, serializer, 1200)
+    store = new BlockManager(actorSystem, master, serializer, 1200)
     val a1 = new Array[Byte](400)
     val a2 = new Array[Byte](400)
     val a3 = new Array[Byte](400)
@@ -224,7 +362,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
   }
 
   test("disk and memory storage with serialization and getLocalBytes") {
-    store = new BlockManager(master, serializer, 1200)
+    store = new BlockManager(actorSystem, master, serializer, 1200)
     val a1 = new Array[Byte](400)
     val a2 = new Array[Byte](400)
     val a3 = new Array[Byte](400)
@@ -239,7 +377,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
   }
 
   test("LRU with mixed storage levels") {
-    store = new BlockManager(master, serializer, 1200)
+    store = new BlockManager(actorSystem, master, serializer, 1200)
     val a1 = new Array[Byte](400)
     val a2 = new Array[Byte](400)
     val a3 = new Array[Byte](400)
@@ -264,7 +402,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
   }
 
   test("in-memory LRU with streams") {
-    store = new BlockManager(master, serializer, 1200)
+    store = new BlockManager(actorSystem, master, serializer, 1200)
     val list1 = List(new Array[Byte](200), new Array[Byte](200))
     val list2 = List(new Array[Byte](200), new Array[Byte](200))
     val list3 = List(new Array[Byte](200), new Array[Byte](200))
@@ -288,7 +426,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
   }
 
   test("LRU with mixed storage levels and streams") {
-    store = new BlockManager(master, serializer, 1200)
+    store = new BlockManager(actorSystem, master, serializer, 1200)
     val list1 = List(new Array[Byte](200), new Array[Byte](200))
     val list2 = List(new Array[Byte](200), new Array[Byte](200))
     val list3 = List(new Array[Byte](200), new Array[Byte](200))
@@ -334,7 +472,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
   }
 
   test("overly large block") {
-    store = new BlockManager(master, serializer, 500)
+    store = new BlockManager(actorSystem, master, serializer, 500)
     store.putSingle("a1", new Array[Byte](1000), StorageLevel.MEMORY_ONLY)
     assert(store.getSingle("a1") === None, "a1 was in store")
     store.putSingle("a2", new Array[Byte](1000), StorageLevel.MEMORY_AND_DISK)
@@ -345,49 +483,49 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
   test("block compression") {
     try {
       System.setProperty("spark.shuffle.compress", "true")
-      store = new BlockManager(master, serializer, 2000)
+      store = new BlockManager(actorSystem, master, serializer, 2000)
       store.putSingle("shuffle_0_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
       assert(store.memoryStore.getSize("shuffle_0_0_0") <= 100, "shuffle_0_0_0 was not compressed")
       store.stop()
       store = null
 
       System.setProperty("spark.shuffle.compress", "false")
-      store = new BlockManager(master, serializer, 2000)
+      store = new BlockManager(actorSystem, master, serializer, 2000)
       store.putSingle("shuffle_0_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
       assert(store.memoryStore.getSize("shuffle_0_0_0") >= 1000, "shuffle_0_0_0 was compressed")
       store.stop()
       store = null
 
       System.setProperty("spark.broadcast.compress", "true")
-      store = new BlockManager(master, serializer, 2000)
+      store = new BlockManager(actorSystem, master, serializer, 2000)
       store.putSingle("broadcast_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
       assert(store.memoryStore.getSize("broadcast_0") <= 100, "broadcast_0 was not compressed")
       store.stop()
       store = null
 
      System.setProperty("spark.broadcast.compress", "false")
-      store = new BlockManager(master, serializer, 2000)
+      store = new BlockManager(actorSystem, master, serializer, 2000)
       store.putSingle("broadcast_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
       assert(store.memoryStore.getSize("broadcast_0") >= 1000, "broadcast_0 was compressed")
       store.stop()
       store = null
 
       System.setProperty("spark.rdd.compress", "true")
-      store = new BlockManager(master, serializer, 2000)
+      store = new BlockManager(actorSystem, master, serializer, 2000)
       store.putSingle("rdd_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
       assert(store.memoryStore.getSize("rdd_0_0") <= 100, "rdd_0_0 was not compressed")
       store.stop()
       store = null
 
       System.setProperty("spark.rdd.compress", "false")
-      store = new BlockManager(master, serializer, 2000)
+      store = new BlockManager(actorSystem, master, serializer, 2000)
       store.putSingle("rdd_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
      assert(store.memoryStore.getSize("rdd_0_0") >= 1000, "rdd_0_0 was compressed")
       store.stop()
       store = null
 
       // Check that any other block types are also kept uncompressed
-      store = new BlockManager(master, serializer, 2000)
+      store = new BlockManager(actorSystem, master, serializer, 2000)
       store.putSingle("other_block", new Array[Byte](1000), StorageLevel.MEMORY_ONLY)
       assert(store.memoryStore.getSize("other_block") >= 1000, "other_block was compressed")
       store.stop()
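The test above toggles three independent compression switches, one per block family (shuffle outputs, broadcast variables, cached RDD partitions), keyed off the block ID prefix. A hedged sketch of the dispatch these assertions imply (the real BlockManager logic is not shown in this diff, and the defaults below are assumptions):

    // Illustrative only: decide whether a block should be compressed based on
    // its ID prefix and the corresponding system property.
    def shouldCompress(blockId: String): Boolean = {
      if (blockId.startsWith("shuffle_")) {
        System.getProperty("spark.shuffle.compress", "true").toBoolean   // default assumed
      } else if (blockId.startsWith("broadcast_")) {
        System.getProperty("spark.broadcast.compress", "true").toBoolean // default assumed
      } else if (blockId.startsWith("rdd_")) {
        System.getProperty("spark.rdd.compress", "false").toBoolean      // default assumed
      } else {
        false  // all other block types stay uncompressed
      }
    }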