diff options
15 files changed, 600 insertions, 205 deletions
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 1647d904a2..139048d5c7 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1086,7 +1086,7 @@ object SparkContext { * parameters that are passed as the default value of null, instead of throwing an exception * like SparkConf would. */ - private def updatedConf( + private[spark] def updatedConf( conf: SparkConf, master: String, appName: String, diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala index 181ae2fd45..8e07a0f29a 100644 --- a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala @@ -26,16 +26,23 @@ import org.apache.spark.Logging /** * This is a custom implementation of scala.collection.mutable.Map which stores the insertion - * time stamp along with each key-value pair. Key-value pairs that are older than a particular - * threshold time can them be removed using the clearOldValues method. This is intended to be a drop-in - * replacement of scala.collection.mutable.HashMap. + * timestamp along with each key-value pair. If specified, the timestamp of each pair can be + * updated every time it is accessed. Key-value pairs whose timestamp are older than a particular + * threshold time can then be removed using the clearOldValues method. This is intended to + * be a drop-in replacement of scala.collection.mutable.HashMap. + * @param updateTimeStampOnGet When enabled, the timestamp of a pair will be + * updated when it is accessed */ -class TimeStampedHashMap[A, B] extends Map[A, B]() with Logging { +class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = false) + extends Map[A, B]() with Logging { val internalMap = new ConcurrentHashMap[A, (B, Long)]() def get(key: A): Option[B] = { val value = internalMap.get(key) - if (value != null) Some(value._1) else None + if (value != null && updateTimeStampOnGet) { + internalMap.replace(key, value, (value._1, currentTime)) + } + Option(value).map(_._1) } def iterator: Iterator[(A, B)] = { diff --git a/examples/src/main/java/org/apache/spark/streaming/examples/JavaNetworkWordCount.java b/examples/src/main/java/org/apache/spark/streaming/examples/JavaNetworkWordCount.java index 2e616b1ab2..349d826ab5 100644 --- a/examples/src/main/java/org/apache/spark/streaming/examples/JavaNetworkWordCount.java +++ b/examples/src/main/java/org/apache/spark/streaming/examples/JavaNetworkWordCount.java @@ -48,7 +48,7 @@ public final class JavaNetworkWordCount { public static void main(String[] args) { if (args.length < 3) { - System.err.println("Usage: NetworkWordCount <master> <hostname> <port>\n" + + System.err.println("Usage: JavaNetworkWordCount <master> <hostname> <port>\n" + "In local mode, <master> should be 'local[n]' with n > 1"); System.exit(1); } @@ -56,12 +56,12 @@ public final class JavaNetworkWordCount { StreamingExamples.setStreamingLogLevels(); // Create the context with a 1 second batch size - JavaStreamingContext ssc = new JavaStreamingContext(args[0], "NetworkWordCount", + JavaStreamingContext ssc = new JavaStreamingContext(args[0], "JavaNetworkWordCount", new Duration(1000), System.getenv("SPARK_HOME"), JavaStreamingContext.jarOfClass(JavaNetworkWordCount.class)); // Create a NetworkInputDStream on target ip:port and count the - // words in input stream of \n delimited test (eg. generated by 'nc') + // words in input stream of \n delimited text (eg. generated by 'nc') JavaDStream<String> lines = ssc.socketTextStream(args[1], Integer.parseInt(args[2])); JavaDStream<String> words = lines.flatMap(new FlatMapFunction<String, String>() { @Override @@ -84,6 +84,5 @@ public final class JavaNetworkWordCount { wordCounts.print(); ssc.start(); - } } diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/NetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/NetworkWordCount.scala index c12139b3ec..25f7013307 100644 --- a/examples/src/main/scala/org/apache/spark/streaming/examples/NetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/NetworkWordCount.scala @@ -21,7 +21,8 @@ import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.streaming.StreamingContext._ /** - * Counts words in UTF8 encoded, '\n' delimited text received from the network every second. + * Counts words in text encoded with UTF8 received from the network every second. + * * Usage: NetworkWordCount <master> <hostname> <port> * <master> is the Spark master URL. In local mode, <master> should be 'local[n]' with n > 1. * <hostname> and <port> describe the TCP server that Spark Streaming would connect to receive data. @@ -46,7 +47,7 @@ object NetworkWordCount { System.getenv("SPARK_HOME"), StreamingContext.jarOfClass(this.getClass)) // Create a NetworkInputDStream on target ip:port and count the - // words in input stream of \n delimited test (eg. generated by 'nc') + // words in input stream of \n delimited text (eg. generated by 'nc') val lines = ssc.socketTextStream(args(1), args(2).toInt) val words = lines.flatMap(_.split(" ")) val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/RecoverableNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/RecoverableNetworkWordCount.scala new file mode 100644 index 0000000000..d51e6e9418 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/RecoverableNetworkWordCount.scala @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.examples + +import org.apache.spark.streaming.{Time, Seconds, StreamingContext} +import org.apache.spark.streaming.StreamingContext._ +import org.apache.spark.util.IntParam +import java.io.File +import org.apache.spark.rdd.RDD +import com.google.common.io.Files +import java.nio.charset.Charset + +/** + * Counts words in text encoded with UTF8 received from the network every second. + * + * Usage: NetworkWordCount <master> <hostname> <port> <checkpoint-directory> <output-file> + * <master> is the Spark master URL. In local mode, <master> should be 'local[n]' with n > 1. + * <hostname> and <port> describe the TCP server that Spark Streaming would connect to receive data. + * <checkpoint-directory> directory to HDFS-compatible file system which checkpoint data + * <output-file> file to which the word counts will be appended + * + * In local mode, <master> should be 'local[n]' with n > 1 + * <checkpoint-directory> and <output-file> must be absolute paths + * + * + * To run this on your local machine, you need to first run a Netcat server + * + * `$ nc -lk 9999` + * + * and run the example as + * + * `$ ./run-example org.apache.spark.streaming.examples.RecoverableNetworkWordCount \ + * local[2] localhost 9999 ~/checkpoint/ ~/out` + * + * If the directory ~/checkpoint/ does not exist (e.g. running for the first time), it will create + * a new StreamingContext (will print "Creating new context" to the console). Otherwise, if + * checkpoint data exists in ~/checkpoint/, then it will create StreamingContext from + * the checkpoint data. + * + * To run this example in a local standalone cluster with automatic driver recovery, + * + * `$ ./spark-class org.apache.spark.deploy.Client -s launch <cluster-url> <path-to-examples-jar> \ + * org.apache.spark.streaming.examples.RecoverableNetworkWordCount <cluster-url> \ + * localhost 9999 ~/checkpoint ~/out` + * + * <path-to-examples-jar> would typically be <spark-dir>/examples/target/scala-XX/spark-examples....jar + * + * Refer to the online documentation for more details. + */ + +object RecoverableNetworkWordCount { + + def createContext(master: String, ip: String, port: Int, outputPath: String) = { + + // If you do not see this printed, that means the StreamingContext has been loaded + // from the new checkpoint + println("Creating new context") + val outputFile = new File(outputPath) + if (outputFile.exists()) outputFile.delete() + + // Create the context with a 1 second batch size + val ssc = new StreamingContext(master, "RecoverableNetworkWordCount", Seconds(1), + System.getenv("SPARK_HOME"), StreamingContext.jarOfClass(this.getClass)) + + // Create a NetworkInputDStream on target ip:port and count the + // words in input stream of \n delimited text (eg. generated by 'nc') + val lines = ssc.socketTextStream(ip, port) + val words = lines.flatMap(_.split(" ")) + val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) + wordCounts.foreach((rdd: RDD[(String, Int)], time: Time) => { + val counts = "Counts at time " + time + " " + rdd.collect().mkString("[", ", ", "]") + println(counts) + println("Appending to " + outputFile.getAbsolutePath) + Files.append(counts + "\n", outputFile, Charset.defaultCharset()) + }) + ssc + } + + def main(args: Array[String]) { + if (args.length != 5) { + System.err.println("You arguments were " + args.mkString("[", ", ", "]")) + System.err.println( + """ + |Usage: RecoverableNetworkWordCount <master> <hostname> <port> <checkpoint-directory> <output-file> + | <master> is the Spark master URL. In local mode, <master> should be 'local[n]' with n > 1. + | <hostname> and <port> describe the TCP server that Spark Streaming would connect to receive data. + | <checkpoint-directory> directory to HDFS-compatible file system which checkpoint data + | <output-file> file to which the word counts will be appended + | + |In local mode, <master> should be 'local[n]' with n > 1 + |Both <checkpoint-directory> and <output-file> must be absolute paths + """.stripMargin + ) + System.exit(1) + } + val Array(master, ip, IntParam(port), checkpointDirectory, outputPath) = args + val ssc = StreamingContext.getOrCreate(checkpointDirectory, + () => { + createContext(master, ip, port, outputPath) + }) + ssc.start() + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index ca0115f90e..1249ef4c3d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -24,10 +24,10 @@ import java.util.concurrent.RejectedExecutionException import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.conf.Configuration -import org.apache.spark.{SparkConf, Logging} +import org.apache.spark.{SparkException, SparkConf, Logging} import org.apache.spark.io.CompressionCodec import org.apache.spark.util.MetadataCleaner -import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.streaming.scheduler.JobGenerator private[streaming] @@ -44,6 +44,10 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) val delaySeconds = MetadataCleaner.getDelaySeconds(ssc.conf) val sparkConf = ssc.conf + // These should be unset when a checkpoint is deserialized, + // otherwise the SparkContext won't initialize correctly. + sparkConf.remove("spark.hostPort").remove("spark.driver.host").remove("spark.driver.port") + def validate() { assert(master != null, "Checkpoint.master is null") assert(framework != null, "Checkpoint.framework is null") @@ -53,59 +57,119 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) } } +private[streaming] +object Checkpoint extends Logging { + val PREFIX = "checkpoint-" + val REGEX = (PREFIX + """([\d]+)([\w\.]*)""").r + + /** Get the checkpoint file for the given checkpoint time */ + def checkpointFile(checkpointDir: String, checkpointTime: Time) = { + new Path(checkpointDir, PREFIX + checkpointTime.milliseconds) + } + + /** Get the checkpoint backup file for the given checkpoint time */ + def checkpointBackupFile(checkpointDir: String, checkpointTime: Time) = { + new Path(checkpointDir, PREFIX + checkpointTime.milliseconds + ".bk") + } + + /** Get checkpoint files present in the give directory, ordered by oldest-first */ + def getCheckpointFiles(checkpointDir: String, fs: FileSystem): Seq[Path] = { + def sortFunc(path1: Path, path2: Path): Boolean = { + val (time1, bk1) = path1.getName match { case REGEX(x, y) => (x.toLong, !y.isEmpty) } + val (time2, bk2) = path2.getName match { case REGEX(x, y) => (x.toLong, !y.isEmpty) } + (time1 < time2) || (time1 == time2 && bk1) + } + + val path = new Path(checkpointDir) + if (fs.exists(path)) { + val statuses = fs.listStatus(path) + if (statuses != null) { + val paths = statuses.map(_.getPath) + val filtered = paths.filter(p => REGEX.findFirstIn(p.toString).nonEmpty) + filtered.sortWith(sortFunc) + } else { + logWarning("Listing " + path + " returned null") + Seq.empty + } + } else { + logInfo("Checkpoint directory " + path + " does not exist") + Seq.empty + } + } +} + /** * Convenience class to handle the writing of graph checkpoint to file */ private[streaming] -class CheckpointWriter(conf: SparkConf, checkpointDir: String, hadoopConf: Configuration) - extends Logging -{ - val file = new Path(checkpointDir, "graph") +class CheckpointWriter( + jobGenerator: JobGenerator, + conf: SparkConf, + checkpointDir: String, + hadoopConf: Configuration + ) extends Logging { val MAX_ATTEMPTS = 3 val executor = Executors.newFixedThreadPool(1) val compressionCodec = CompressionCodec.createCodec(conf) - // The file to which we actually write - and then "move" to file - val writeFile = new Path(file.getParent, file.getName + ".next") - // The file to which existing checkpoint is backed up (i.e. "moved") - val bakFile = new Path(file.getParent, file.getName + ".bk") - private var stopped = false private var fs_ : FileSystem = _ - // Removed code which validates whether there is only one CheckpointWriter per path 'file' since - // I did not notice any errors - reintroduce it ? class CheckpointWriteHandler(checkpointTime: Time, bytes: Array[Byte]) extends Runnable { def run() { var attempts = 0 val startTime = System.currentTimeMillis() + val tempFile = new Path(checkpointDir, "temp") + val checkpointFile = Checkpoint.checkpointFile(checkpointDir, checkpointTime) + val backupFile = Checkpoint.checkpointBackupFile(checkpointDir, checkpointTime) + while (attempts < MAX_ATTEMPTS && !stopped) { attempts += 1 try { - logDebug("Saving checkpoint for time " + checkpointTime + " to file '" + file + "'") - // This is inherently thread unsafe, so alleviating it by writing to '.new' and - // then moving it to the final file - val fos = fs.create(writeFile) + logInfo("Saving checkpoint for time " + checkpointTime + " to file '" + checkpointFile + "'") + + // Write checkpoint to temp file + fs.delete(tempFile, true) // just in case it exists + val fos = fs.create(tempFile) fos.write(bytes) fos.close() - if (fs.exists(file) && fs.rename(file, bakFile)) { - logDebug("Moved existing checkpoint file to " + bakFile) + + // If the checkpoint file exists, back it up + // If the backup exists as well, just delete it, otherwise rename will fail + if (fs.exists(checkpointFile)) { + fs.delete(backupFile, true) // just in case it exists + if (!fs.rename(checkpointFile, backupFile)) { + logWarning("Could not rename " + checkpointFile + " to " + backupFile) + } + } + + // Rename temp file to the final checkpoint file + if (!fs.rename(tempFile, checkpointFile)) { + logWarning("Could not rename " + tempFile + " to " + checkpointFile) + } + + // Delete old checkpoint files + val allCheckpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, fs) + if (allCheckpointFiles.size > 4) { + allCheckpointFiles.take(allCheckpointFiles.size - 10).foreach(file => { + logInfo("Deleting " + file) + fs.delete(file, true) + }) } - // paranoia - fs.delete(file, false) - fs.rename(writeFile, file) + // All done, print success val finishTime = System.currentTimeMillis() - logInfo("Checkpoint for time " + checkpointTime + " saved to file '" + file + - "', took " + bytes.length + " bytes and " + (finishTime - startTime) + " milliseconds") + logInfo("Checkpoint for time " + checkpointTime + " saved to file '" + checkpointFile + + "', took " + bytes.length + " bytes and " + (finishTime - startTime) + " ms") + jobGenerator.onCheckpointCompletion(checkpointTime) return } catch { case ioe: IOException => - logWarning("Error writing checkpoint to file in " + attempts + " attempts", ioe) + logWarning("Error in attempt " + attempts + " of writing checkpoint to " + checkpointFile, ioe) reset() } } - logError("Could not write checkpoint for time " + checkpointTime + " to file '" + file + "'") + logWarning("Could not write checkpoint for time " + checkpointTime + " to file " + checkpointFile + "'") } } @@ -118,6 +182,7 @@ class CheckpointWriter(conf: SparkConf, checkpointDir: String, hadoopConf: Confi bos.close() try { executor.execute(new CheckpointWriteHandler(checkpoint.checkpointTime, bos.toByteArray)) + logDebug("Submitted checkpoint of time " + checkpoint.checkpointTime + " writer queue") } catch { case rej: RejectedExecutionException => logError("Could not submit checkpoint task to the thread pool executor", rej) @@ -140,7 +205,7 @@ class CheckpointWriter(conf: SparkConf, checkpointDir: String, hadoopConf: Confi } private def fs = synchronized { - if (fs_ == null) fs_ = file.getFileSystem(hadoopConf) + if (fs_ == null) fs_ = new Path(checkpointDir).getFileSystem(hadoopConf) fs_ } @@ -153,43 +218,46 @@ class CheckpointWriter(conf: SparkConf, checkpointDir: String, hadoopConf: Confi private[streaming] object CheckpointReader extends Logging { - def read(conf: SparkConf, path: String): Checkpoint = { - val fs = new Path(path).getFileSystem(new Configuration()) - val attempts = Seq(new Path(path, "graph"), new Path(path, "graph.bk"), - new Path(path), new Path(path + ".bk")) + def read(checkpointDir: String, conf: SparkConf, hadoopConf: Configuration): Option[Checkpoint] = { + val checkpointPath = new Path(checkpointDir) + def fs = checkpointPath.getFileSystem(hadoopConf) + + // Try to find the checkpoint files + val checkpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, fs).reverse + if (checkpointFiles.isEmpty) { + return None + } + // Try to read the checkpoint files in the order + logInfo("Checkpoint files found: " + checkpointFiles.mkString(",")) val compressionCodec = CompressionCodec.createCodec(conf) - - attempts.foreach(file => { - if (fs.exists(file)) { - logInfo("Attempting to load checkpoint from file '" + file + "'") - try { - val fis = fs.open(file) - // ObjectInputStream uses the last defined user-defined class loader in the stack - // to find classes, which maybe the wrong class loader. Hence, a inherited version - // of ObjectInputStream is used to explicitly use the current thread's default class - // loader to find and load classes. This is a well know Java issue and has popped up - // in other places (e.g., http://jira.codehaus.org/browse/GROOVY-1627) - val zis = compressionCodec.compressedInputStream(fis) - val ois = new ObjectInputStreamWithLoader(zis, - Thread.currentThread().getContextClassLoader) - val cp = ois.readObject.asInstanceOf[Checkpoint] - ois.close() - fs.close() - cp.validate() - logInfo("Checkpoint successfully loaded from file '" + file + "'") - logInfo("Checkpoint was generated at time " + cp.checkpointTime) - return cp - } catch { - case e: Exception => - logError("Error loading checkpoint from file '" + file + "'", e) - } - } else { - logWarning("Could not read checkpoint from file '" + file + "' as it does not exist") + checkpointFiles.foreach(file => { + logInfo("Attempting to load checkpoint from file " + file) + try { + val fis = fs.open(file) + // ObjectInputStream uses the last defined user-defined class loader in the stack + // to find classes, which maybe the wrong class loader. Hence, a inherited version + // of ObjectInputStream is used to explicitly use the current thread's default class + // loader to find and load classes. This is a well know Java issue and has popped up + // in other places (e.g., http://jira.codehaus.org/browse/GROOVY-1627) + val zis = compressionCodec.compressedInputStream(fis) + val ois = new ObjectInputStreamWithLoader(zis, + Thread.currentThread().getContextClassLoader) + val cp = ois.readObject.asInstanceOf[Checkpoint] + ois.close() + fs.close() + cp.validate() + logInfo("Checkpoint successfully loaded from file " + file) + logInfo("Checkpoint was generated at time " + cp.checkpointTime) + return Some(cp) + } catch { + case e: Exception => + logWarning("Error reading checkpoint from file " + file, e) } - }) - throw new Exception("Could not read checkpoint from path '" + path + "'") + + // If none of checkpoint files could be read, then throw exception + throw new SparkException("Failed to read checkpoint from directory " + checkpointPath) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala index 837f1ea1d8..b98f4a5101 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala @@ -329,13 +329,12 @@ abstract class DStream[T: ClassTag] ( * implementation clears the old generated RDDs. Subclasses of DStream may override * this to clear their own metadata along with the generated RDDs. */ - protected[streaming] def clearOldMetadata(time: Time) { - var numForgotten = 0 + protected[streaming] def clearMetadata(time: Time) { val oldRDDs = generatedRDDs.filter(_._1 <= (time - rememberDuration)) generatedRDDs --= oldRDDs.keys logDebug("Cleared " + oldRDDs.size + " RDDs that were older than " + (time - rememberDuration) + ": " + oldRDDs.keys.mkString(", ")) - dependencies.foreach(_.clearOldMetadata(time)) + dependencies.foreach(_.clearMetadata(time)) } /* Adds metadata to the Stream while it is running. @@ -356,12 +355,18 @@ abstract class DStream[T: ClassTag] ( */ protected[streaming] def updateCheckpointData(currentTime: Time) { logInfo("Updating checkpoint data for time " + currentTime) - checkpointData.update() + checkpointData.update(currentTime) dependencies.foreach(_.updateCheckpointData(currentTime)) - checkpointData.cleanup() logDebug("Updated checkpoint data for time " + currentTime + ": " + checkpointData) } + protected[streaming] def clearCheckpointData(time: Time) { + logInfo("Clearing checkpoint data") + checkpointData.cleanup(time) + dependencies.foreach(_.clearCheckpointData(time)) + logInfo("Cleared checkpoint data") + } + /** * Restore the RDDs in generatedRDDs from the checkpointData. This is an internal method * that should not be called directly. This is a default implementation that recreates RDDs diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamCheckpointData.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamCheckpointData.scala index 3fd5d52403..671f7bbce7 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamCheckpointData.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamCheckpointData.scala @@ -17,77 +17,86 @@ package org.apache.spark.streaming +import scala.collection.mutable.{HashMap, HashSet} +import scala.reflect.ClassTag + import org.apache.hadoop.fs.Path import org.apache.hadoop.fs.FileSystem -import org.apache.hadoop.conf.Configuration -import collection.mutable.HashMap import org.apache.spark.Logging -import scala.collection.mutable.HashMap -import scala.reflect.ClassTag - +import java.io.{ObjectInputStream, IOException} private[streaming] class DStreamCheckpointData[T: ClassTag] (dstream: DStream[T]) extends Serializable with Logging { protected val data = new HashMap[Time, AnyRef]() - @transient private var fileSystem : FileSystem = null - @transient private var lastCheckpointFiles: HashMap[Time, String] = null + // Mapping of the batch time to the checkpointed RDD file of that time + @transient private var timeToCheckpointFile = new HashMap[Time, String] + // Mapping of the batch time to the time of the oldest checkpointed RDD + // in that batch's checkpoint data + @transient private var timeToOldestCheckpointFileTime = new HashMap[Time, Time] - protected[streaming] def checkpointFiles = data.asInstanceOf[HashMap[Time, String]] + @transient private var fileSystem : FileSystem = null + protected[streaming] def currentCheckpointFiles = data.asInstanceOf[HashMap[Time, String]] /** * Updates the checkpoint data of the DStream. This gets called every time * the graph checkpoint is initiated. Default implementation records the * checkpoint files to which the generate RDDs of the DStream has been saved. */ - def update() { + def update(time: Time) { // Get the checkpointed RDDs from the generated RDDs - val newCheckpointFiles = dstream.generatedRDDs.filter(_._2.getCheckpointFile.isDefined) + val checkpointFiles = dstream.generatedRDDs.filter(_._2.getCheckpointFile.isDefined) .map(x => (x._1, x._2.getCheckpointFile.get)) - - // Make a copy of the existing checkpoint data (checkpointed RDDs) - lastCheckpointFiles = checkpointFiles.clone() - - // If the new checkpoint data has checkpoints then replace existing with the new one - if (newCheckpointFiles.size > 0) { - checkpointFiles.clear() - checkpointFiles ++= newCheckpointFiles - } - - // TODO: remove this, this is just for debugging - newCheckpointFiles.foreach { - case (time, data) => { logInfo("Added checkpointed RDD for time " + time + " to stream checkpoint") } + logDebug("Current checkpoint files:\n" + checkpointFiles.toSeq.mkString("\n")) + + // Add the checkpoint files to the data to be serialized + if (!checkpointFiles.isEmpty) { + currentCheckpointFiles.clear() + currentCheckpointFiles ++= checkpointFiles + // Add the current checkpoint files to the map of all checkpoint files + // This will be used to delete old checkpoint files + timeToCheckpointFile ++= currentCheckpointFiles + // Remember the time of the oldest checkpoint RDD in current state + timeToOldestCheckpointFileTime(time) = currentCheckpointFiles.keys.min(Time.ordering) } } /** - * Cleanup old checkpoint data. This gets called every time the graph - * checkpoint is initiated, but after `update` is called. Default - * implementation, cleans up old checkpoint files. + * Cleanup old checkpoint data. This gets called after a checkpoint of `time` has been + * written to the checkpoint directory. */ - def cleanup() { - // If there is at least on checkpoint file in the current checkpoint files, - // then delete the old checkpoint files. - if (checkpointFiles.size > 0 && lastCheckpointFiles != null) { - (lastCheckpointFiles -- checkpointFiles.keySet).foreach { - case (time, file) => { - try { - val path = new Path(file) - if (fileSystem == null) { - fileSystem = path.getFileSystem(new Configuration()) + def cleanup(time: Time) { + // Get the time of the oldest checkpointed RDD that was written as part of the + // checkpoint of `time` + timeToOldestCheckpointFileTime.remove(time) match { + case Some(lastCheckpointFileTime) => + // Find all the checkpointed RDDs (i.e. files) that are older than `lastCheckpointFileTime` + // This is because checkpointed RDDs older than this are not going to be needed + // even after master fails, as the checkpoint data of `time` does not refer to those files + val filesToDelete = timeToCheckpointFile.filter(_._1 < lastCheckpointFileTime) + logDebug("Files to delete:\n" + filesToDelete.mkString(",")) + filesToDelete.foreach { + case (time, file) => + try { + val path = new Path(file) + if (fileSystem == null) { + fileSystem = path.getFileSystem(dstream.ssc.sparkContext.hadoopConfiguration) + } + fileSystem.delete(path, true) + timeToCheckpointFile -= time + logInfo("Deleted checkpoint file '" + file + "' for time " + time) + } catch { + case e: Exception => + logWarning("Error deleting old checkpoint file '" + file + "' for time " + time, e) + fileSystem = null } - fileSystem.delete(path, true) - logInfo("Deleted checkpoint file '" + file + "' for time " + time) - } catch { - case e: Exception => - logWarning("Error deleting old checkpoint file '" + file + "' for time " + time, e) - } } - } + case None => + logInfo("Nothing to delete") } } @@ -98,7 +107,7 @@ class DStreamCheckpointData[T: ClassTag] (dstream: DStream[T]) */ def restore() { // Create RDDs from the checkpoint data - checkpointFiles.foreach { + currentCheckpointFiles.foreach { case(time, file) => { logInfo("Restoring checkpointed RDD for time " + time + " from file '" + file + "'") dstream.generatedRDDs += ((time, dstream.context.sparkContext.checkpointFile[T](file))) @@ -107,6 +116,13 @@ class DStreamCheckpointData[T: ClassTag] (dstream: DStream[T]) } override def toString() = { - "[\n" + checkpointFiles.size + " checkpoint files \n" + checkpointFiles.mkString("\n") + "\n]" + "[\n" + currentCheckpointFiles.size + " checkpoint files \n" + currentCheckpointFiles.mkString("\n") + "\n]" + } + + @throws(classOf[IOException]) + private def readObject(ois: ObjectInputStream) { + ois.defaultReadObject() + timeToOldestCheckpointFileTime = new HashMap[Time, Time] + timeToCheckpointFile = new HashMap[Time, String] } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala index 62d07b22c6..eee9591ffc 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala @@ -104,36 +104,44 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { def getOutputStreams() = this.synchronized { outputStreams.toArray } def generateJobs(time: Time): Seq[Job] = { - this.synchronized { - logDebug("Generating jobs for time " + time) - val jobs = outputStreams.flatMap(outputStream => outputStream.generateJob(time)) - logDebug("Generated " + jobs.length + " jobs for time " + time) - jobs + logDebug("Generating jobs for time " + time) + val jobs = this.synchronized { + outputStreams.flatMap(outputStream => outputStream.generateJob(time)) } + logDebug("Generated " + jobs.length + " jobs for time " + time) + jobs } - def clearOldMetadata(time: Time) { + def clearMetadata(time: Time) { + logDebug("Clearing metadata for time " + time) this.synchronized { - logDebug("Clearing old metadata for time " + time) - outputStreams.foreach(_.clearOldMetadata(time)) - logDebug("Cleared old metadata for time " + time) + outputStreams.foreach(_.clearMetadata(time)) } + logDebug("Cleared old metadata for time " + time) } def updateCheckpointData(time: Time) { + logInfo("Updating checkpoint data for time " + time) this.synchronized { - logInfo("Updating checkpoint data for time " + time) outputStreams.foreach(_.updateCheckpointData(time)) - logInfo("Updated checkpoint data for time " + time) } + logInfo("Updated checkpoint data for time " + time) + } + + def clearCheckpointData(time: Time) { + logInfo("Clearing checkpoint data for time " + time) + this.synchronized { + outputStreams.foreach(_.clearCheckpointData(time)) + } + logInfo("Cleared checkpoint data for time " + time) } def restoreCheckpointData() { + logInfo("Restoring checkpoint data") this.synchronized { - logInfo("Restoring checkpoint data") outputStreams.foreach(_.restoreCheckpointData()) - logInfo("Restored checkpoint data") } + logInfo("Restored checkpoint data") } def validate() { @@ -146,8 +154,8 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream) { + logDebug("DStreamGraph.writeObject used") this.synchronized { - logDebug("DStreamGraph.writeObject used") checkpointInProgress = true oos.defaultWriteObject() checkpointInProgress = false @@ -156,8 +164,8 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { @throws(classOf[IOException]) private def readObject(ois: ObjectInputStream) { + logDebug("DStreamGraph.readObject used") this.synchronized { - logDebug("DStreamGraph.readObject used") checkpointInProgress = true ois.defaultReadObject() checkpointInProgress = false diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 693cb7fc30..dd34f6f4f2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -39,6 +39,7 @@ import org.apache.spark.util.MetadataCleaner import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.receivers._ import org.apache.spark.streaming.scheduler._ +import org.apache.hadoop.conf.Configuration /** * A StreamingContext is the main entry point for Spark Streaming functionality. Besides the basic @@ -88,10 +89,12 @@ class StreamingContext private ( /** * Re-create a StreamingContext from a checkpoint file. - * @param path Path either to the directory that was specified as the checkpoint directory, or - * to the checkpoint file 'graph' or 'graph.bk'. + * @param path Path to the directory that was specified as the checkpoint directory + * @param hadoopConf Optional, configuration object if necessary for reading from + * HDFS compatible filesystems */ - def this(path: String) = this(null, CheckpointReader.read(new SparkConf(), path), null) + def this(path: String, hadoopConf: Configuration = new Configuration) = + this(null, CheckpointReader.read(path, new SparkConf(), hadoopConf).get, null) if (sc_ == null && cp_ == null) { throw new Exception("Spark Streaming cannot be initialized with " + @@ -171,8 +174,9 @@ class StreamingContext private ( /** * Set the context to periodically checkpoint the DStream operations for master - * fault-tolerance. The graph will be checkpointed every batch interval. - * @param directory HDFS-compatible directory where the checkpoint data will be reliably stored + * fault-tolerance. + * @param directory HDFS-compatible directory where the checkpoint data will be reliably stored. + * Note that this must be a fault-tolerant file system like HDFS for */ def checkpoint(directory: String) { if (directory != null) { @@ -461,26 +465,64 @@ class StreamingContext private ( } } +/** + * StreamingContext object contains a number of utility functions related to the + * StreamingContext class. + */ -object StreamingContext { +object StreamingContext extends Logging { implicit def toPairDStreamFunctions[K: ClassTag, V: ClassTag](stream: DStream[(K,V)]) = { new PairDStreamFunctions[K, V](stream) } /** + * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. + * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be + * recreated from the checkpoint data. If the data does not exist, then the StreamingContext + * will be created by called the provided `creatingFunc`. + * + * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program + * @param creatingFunc Function to create a new StreamingContext + * @param hadoopConf Optional Hadoop configuration if necessary for reading from the + * file system + * @param createOnError Optional, whether to create a new StreamingContext if there is an + * error in reading checkpoint data. By default, an exception will be + * thrown on error. + */ + def getOrCreate( + checkpointPath: String, + creatingFunc: () => StreamingContext, + hadoopConf: Configuration = new Configuration(), + createOnError: Boolean = false + ): StreamingContext = { + val checkpointOption = try { + CheckpointReader.read(checkpointPath, new SparkConf(), hadoopConf) + } catch { + case e: Exception => + if (createOnError) { + None + } else { + throw e + } + } + checkpointOption.map(new StreamingContext(null, _, null)).getOrElse(creatingFunc()) + } + + /** * Find the JAR from which a given class was loaded, to make it easy for users to pass - * their JARs to SparkContext. + * their JARs to StreamingContext. */ def jarOfClass(cls: Class[_]) = SparkContext.jarOfClass(cls) + protected[streaming] def createNewSparkContext(conf: SparkConf): SparkContext = { // Set the default cleaner delay to an hour if not already set. // This should be sufficient for even 1 second batch intervals. - val sc = new SparkContext(conf) - if (MetadataCleaner.getDelaySeconds(sc.conf) < 0) { - MetadataCleaner.setDelaySeconds(sc.conf, 3600) + if (MetadataCleaner.getDelaySeconds(conf) < 0) { + MetadataCleaner.setDelaySeconds(conf, 3600) } + val sc = new SparkContext(conf) sc } @@ -489,14 +531,17 @@ object StreamingContext { appName: String, sparkHome: String, jars: Seq[String], - environment: Map[String, String]): SparkContext = - { - val sc = new SparkContext(master, appName, sparkHome, jars, environment) + environment: Map[String, String] + ): SparkContext = { + + val conf = SparkContext.updatedConf( + new SparkConf(), master, appName, sparkHome, jars, environment) // Set the default cleaner delay to an hour if not already set. // This should be sufficient for even 1 second batch intervals. - if (MetadataCleaner.getDelaySeconds(sc.conf) < 0) { - MetadataCleaner.setDelaySeconds(sc.conf, 3600) + if (MetadataCleaner.getDelaySeconds(conf) < 0) { + MetadataCleaner.setDelaySeconds(conf, 3600) } + val sc = new SparkContext(master, appName, sparkHome, jars, environment) sc } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index 7068f32517..523173d45a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -35,6 +35,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming._ import org.apache.spark.streaming.scheduler.StreamingListener +import org.apache.hadoop.conf.Configuration /** * A StreamingContext is the main entry point for Spark Streaming functionality. Besides the basic @@ -128,10 +129,16 @@ class JavaStreamingContext(val ssc: StreamingContext) { /** * Re-creates a StreamingContext from a checkpoint file. - * @param path Path either to the directory that was specified as the checkpoint directory, or - * to the checkpoint file 'graph' or 'graph.bk'. + * @param path Path to the directory that was specified as the checkpoint directory */ - def this(path: String) = this (new StreamingContext(path)) + def this(path: String) = this(new StreamingContext(path)) + + /** + * Re-creates a StreamingContext from a checkpoint file. + * @param path Path to the directory that was specified as the checkpoint directory + * + */ + def this(path: String, hadoopConf: Configuration) = this(new StreamingContext(path, hadoopConf)) /** The underlying SparkContext */ val sc: JavaSparkContext = new JavaSparkContext(ssc.sc) @@ -471,20 +478,97 @@ class JavaStreamingContext(val ssc: StreamingContext) { } /** - * Starts the execution of the streams. + * Start the execution of the streams. */ def start() = ssc.start() /** - * Sstops the execution of the streams. + * Stop the execution of the streams. */ def stop() = ssc.stop() } +/** + * JavaStreamingContext object contains a number of utility functions. + */ object JavaStreamingContext { + + /** + * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. + * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be + * recreated from the checkpoint data. If the data does not exist, then the provided factory + * will be used to create a JavaStreamingContext. + * + * @param checkpointPath Checkpoint directory used in an earlier JavaStreamingContext program + * @param factory JavaStreamingContextFactory object to create a new JavaStreamingContext + */ + def getOrCreate( + checkpointPath: String, + factory: JavaStreamingContextFactory + ): JavaStreamingContext = { + val ssc = StreamingContext.getOrCreate(checkpointPath, () => { + factory.create.ssc + }) + new JavaStreamingContext(ssc) + } + + /** + * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. + * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be + * recreated from the checkpoint data. If the data does not exist, then the provided factory + * will be used to create a JavaStreamingContext. + * + * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program + * @param factory JavaStreamingContextFactory object to create a new JavaStreamingContext + * @param hadoopConf Hadoop configuration if necessary for reading from any HDFS compatible + * file system + */ + def getOrCreate( + checkpointPath: String, + hadoopConf: Configuration, + factory: JavaStreamingContextFactory + ): JavaStreamingContext = { + val ssc = StreamingContext.getOrCreate(checkpointPath, () => { + factory.create.ssc + }, hadoopConf) + new JavaStreamingContext(ssc) + } + + /** + * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. + * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be + * recreated from the checkpoint data. If the data does not exist, then the provided factory + * will be used to create a JavaStreamingContext. + * + * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program + * @param factory JavaStreamingContextFactory object to create a new JavaStreamingContext + * @param hadoopConf Hadoop configuration if necessary for reading from any HDFS compatible + * file system + * @param createOnError Whether to create a new JavaStreamingContext if there is an + * error in reading checkpoint data. + */ + def getOrCreate( + checkpointPath: String, + hadoopConf: Configuration, + factory: JavaStreamingContextFactory, + createOnError: Boolean + ): JavaStreamingContext = { + val ssc = StreamingContext.getOrCreate(checkpointPath, () => { + factory.create.ssc + }, hadoopConf, createOnError) + new JavaStreamingContext(ssc) + } + /** * Find the JAR from which a given class was loaded, to make it easy for users to pass - * their JARs to SparkContext. + * their JARs to StreamingContext. */ def jarOfClass(cls: Class[_]) = SparkContext.jarOfClass(cls).toArray } + +/** + * Factory interface for creating a new JavaStreamingContext + */ +trait JavaStreamingContextFactory { + def create(): JavaStreamingContext +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala index fb9eda8996..1f0f31c4b1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala @@ -23,10 +23,10 @@ import scala.reflect.ClassTag import org.apache.hadoop.fs.{FileSystem, Path, PathFilter} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} -import org.apache.spark.SparkException import org.apache.spark.rdd.RDD import org.apache.spark.rdd.UnionRDD import org.apache.spark.streaming.{DStreamCheckpointData, StreamingContext, Time} +import org.apache.spark.util.TimeStampedHashMap private[streaming] @@ -46,6 +46,8 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas @transient private var path_ : Path = null @transient private var fs_ : FileSystem = null @transient private[streaming] var files = new HashMap[Time, Array[String]] + @transient private var fileModTimes = new TimeStampedHashMap[String, Long](true) + @transient private var lastNewFileFindingTime = 0L override def start() { if (newFilesOnly) { @@ -88,14 +90,16 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas } /** Clear the old time-to-files mappings along with old RDDs */ - protected[streaming] override def clearOldMetadata(time: Time) { - super.clearOldMetadata(time) + protected[streaming] override def clearMetadata(time: Time) { + super.clearMetadata(time) val oldFiles = files.filter(_._1 <= (time - rememberDuration)) files --= oldFiles.keys logInfo("Cleared " + oldFiles.size + " old files that were older than " + (time - rememberDuration) + ": " + oldFiles.keys.mkString(", ")) logDebug("Cleared files are:\n" + oldFiles.map(p => (p._1, p._2.mkString(", "))).mkString("\n")) + // Delete file mod times that weren't accessed in the last round of getting new files + fileModTimes.clearOldValues(lastNewFileFindingTime - 1) } /** @@ -104,8 +108,19 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas */ private def findNewFiles(currentTime: Long): (Seq[String], Long, Seq[String]) = { logDebug("Trying to get new files for time " + currentTime) + lastNewFileFindingTime = System.currentTimeMillis val filter = new CustomPathFilter(currentTime) - val newFiles = fs.listStatus(path, filter).map(_.getPath.toString) + val newFiles = fs.listStatus(directoryPath, filter).map(_.getPath.toString) + val timeTaken = System.currentTimeMillis - lastNewFileFindingTime + logInfo("Finding new files took " + timeTaken + " ms") + logDebug("# cached file times = " + fileModTimes.size) + if (timeTaken > slideDuration.milliseconds) { + logWarning( + "Time taken to find new files exceeds the batch size. " + + "Consider increasing the batch size or reduceing the number of " + + "files in the monitored directory." + ) + } (newFiles, filter.latestModTime, filter.latestModTimeFiles.toSeq) } @@ -122,16 +137,21 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas new UnionRDD(context.sparkContext, fileRDDs) } - private def path: Path = { + private def directoryPath: Path = { if (path_ == null) path_ = new Path(directory) path_ } private def fs: FileSystem = { - if (fs_ == null) fs_ = path.getFileSystem(new Configuration()) + if (fs_ == null) fs_ = directoryPath.getFileSystem(new Configuration()) fs_ } + private def getFileModTime(path: Path) = { + // Get file mod time from cache or fetch it from the file system + fileModTimes.getOrElseUpdate(path.toString, fs.getFileStatus(path).getModificationTime()) + } + private def reset() { fs_ = null } @@ -142,6 +162,7 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas ois.defaultReadObject() generatedRDDs = new HashMap[Time, RDD[(K,V)]] () files = new HashMap[Time, Array[String]] + fileModTimes = new TimeStampedHashMap[String, Long](true) } /** @@ -153,15 +174,15 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas def hadoopFiles = data.asInstanceOf[HashMap[Time, Array[String]]] - override def update() { + override def update(time: Time) { hadoopFiles.clear() hadoopFiles ++= files } - override def cleanup() { } + override def cleanup(time: Time) { } override def restore() { - hadoopFiles.foreach { + hadoopFiles.toSeq.sortBy(_._1)(Time.ordering).foreach { case (t, f) => { // Restore the metadata in both files and generatedRDDs logInfo("Restoring files for time " + t + " - " + @@ -187,14 +208,13 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas // Latest file mod time seen in this round of fetching files and its corresponding files var latestModTime = 0L val latestModTimeFiles = new HashSet[String]() - def accept(path: Path): Boolean = { try { if (!filter(path)) { // Reject file if it does not satisfy filter logDebug("Rejected by filter " + path) return false } - val modTime = fs.getFileStatus(path).getModificationTime() + val modTime = getFileModTime(path) logDebug("Mod time for " + path + " is " + modTime) if (modTime < prevModTime) { logDebug("Mod time less than last mod time") diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index 3c624e8199..2fa6853ae0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -26,8 +26,9 @@ import org.apache.spark.streaming.util.{ManualClock, RecurringTimer, Clock} /** Event classes for JobGenerator */ private[scheduler] sealed trait JobGeneratorEvent private[scheduler] case class GenerateJobs(time: Time) extends JobGeneratorEvent -private[scheduler] case class ClearOldMetadata(time: Time) extends JobGeneratorEvent +private[scheduler] case class ClearMetadata(time: Time) extends JobGeneratorEvent private[scheduler] case class DoCheckpoint(time: Time) extends JobGeneratorEvent +private[scheduler] case class ClearCheckpointData(time: Time) extends JobGeneratorEvent /** * This class generates jobs from DStreams as well as drives checkpointing and cleaning @@ -53,7 +54,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { val timer = new RecurringTimer(clock, ssc.graph.batchDuration.milliseconds, longTime => eventProcessorActor ! GenerateJobs(new Time(longTime))) lazy val checkpointWriter = if (ssc.checkpointDuration != null && ssc.checkpointDir != null) { - new CheckpointWriter(ssc.conf, ssc.checkpointDir, ssc.sparkContext.hadoopConfiguration) + new CheckpointWriter(this, ssc.conf, ssc.checkpointDir, ssc.sparkContext.hadoopConfiguration) } else { null } @@ -77,15 +78,20 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { * On batch completion, clear old metadata and checkpoint computation. */ private[scheduler] def onBatchCompletion(time: Time) { - eventProcessorActor ! ClearOldMetadata(time) + eventProcessorActor ! ClearMetadata(time) + } + + private[streaming] def onCheckpointCompletion(time: Time) { + eventProcessorActor ! ClearCheckpointData(time) } /** Processes all events */ private def processEvent(event: JobGeneratorEvent) { event match { case GenerateJobs(time) => generateJobs(time) - case ClearOldMetadata(time) => clearOldMetadata(time) + case ClearMetadata(time) => clearMetadata(time) case DoCheckpoint(time) => doCheckpoint(time) + case ClearCheckpointData(time) => clearCheckpointData(time) } } @@ -115,14 +121,14 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { val checkpointTime = ssc.initialCheckpoint.checkpointTime val restartTime = new Time(timer.getRestartTime(graph.zeroTime.milliseconds)) val downTimes = checkpointTime.until(restartTime, batchDuration) - logInfo("Batches during down time: " + downTimes.mkString(", ")) + logInfo("Batches during down time (" + downTimes.size + " batches): " + downTimes.mkString(", ")) // Batches that were unprocessed before failure - val pendingTimes = ssc.initialCheckpoint.pendingTimes - logInfo("Batches pending processing: " + pendingTimes.mkString(", ")) + val pendingTimes = ssc.initialCheckpoint.pendingTimes.sorted(Time.ordering) + logInfo("Batches pending processing (" + pendingTimes.size + " batches): " + pendingTimes.mkString(", ")) // Reschedule jobs for these times val timesToReschedule = (pendingTimes ++ downTimes).distinct.sorted(Time.ordering) - logInfo("Batches to reschedule: " + timesToReschedule.mkString(", ")) + logInfo("Batches to reschedule (" + timesToReschedule.size + " batches): " + timesToReschedule.mkString(", ")) timesToReschedule.foreach(time => jobScheduler.runJobs(time, graph.generateJobs(time)) ) @@ -141,11 +147,16 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { } /** Clear DStream metadata for the given `time`. */ - private def clearOldMetadata(time: Time) { - ssc.graph.clearOldMetadata(time) + private def clearMetadata(time: Time) { + ssc.graph.clearMetadata(time) eventProcessorActor ! DoCheckpoint(time) } + /** Clear DStream checkpoint data for the given `time`. */ + private def clearCheckpointData(time: Time) { + ssc.graph.clearCheckpointData(time) + } + /** Perform checkpoint for the give `time`. */ private def doCheckpoint(time: Time) = synchronized { if (checkpointWriter != null && (time - graph.zeroTime).isMultipleOf(ssc.checkpointDuration)) { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/MasterFailureTest.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/MasterFailureTest.scala index 1559f7a9f7..162b19d7f0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/MasterFailureTest.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/MasterFailureTest.scala @@ -42,6 +42,7 @@ object MasterFailureTest extends Logging { @volatile var killed = false @volatile var killCount = 0 + @volatile var setupCalled = false def main(args: Array[String]) { if (args.size < 2) { @@ -131,8 +132,26 @@ object MasterFailureTest extends Logging { // Just making sure that the expected output does not have duplicates assert(expectedOutput.distinct.toSet == expectedOutput.toSet) + // Reset all state + reset() + + // Create the directories for this test + val uuid = UUID.randomUUID().toString + val rootDir = new Path(directory, uuid) + val fs = rootDir.getFileSystem(new Configuration()) + val checkpointDir = new Path(rootDir, "checkpoint") + val testDir = new Path(rootDir, "test") + fs.mkdirs(checkpointDir) + fs.mkdirs(testDir) + // Setup the stream computation with the given operation - val (ssc, checkpointDir, testDir) = setupStreams(directory, batchDuration, operation) + val ssc = StreamingContext.getOrCreate(checkpointDir.toString, () => { + setupStreams(batchDuration, operation, checkpointDir, testDir) + }) + + // Check if setupStream was called to create StreamingContext + // (and not created from checkpoint file) + assert(setupCalled, "Setup was not called in the first call to StreamingContext.getOrCreate") // Start generating files in the a different thread val fileGeneratingThread = new FileGeneratingThread(input, testDir, batchDuration.milliseconds) @@ -144,9 +163,7 @@ object MasterFailureTest extends Logging { val maxTimeToRun = expectedOutput.size * batchDuration.milliseconds * 2 val mergedOutput = runStreams(ssc, lastExpectedOutput, maxTimeToRun) - // Delete directories fileGeneratingThread.join() - val fs = checkpointDir.getFileSystem(new Configuration()) fs.delete(checkpointDir, true) fs.delete(testDir, true) logInfo("Finished test after " + killCount + " failures") @@ -159,32 +176,24 @@ object MasterFailureTest extends Logging { * files should be written for testing. */ private def setupStreams[T: ClassTag]( - directory: String, batchDuration: Duration, - operation: DStream[String] => DStream[T] - ): (StreamingContext, Path, Path) = { - // Reset all state - reset() - - // Create the directories for this test - val uuid = UUID.randomUUID().toString - val rootDir = new Path(directory, uuid) - val fs = rootDir.getFileSystem(new Configuration()) - val checkpointDir = new Path(rootDir, "checkpoint") - val testDir = new Path(rootDir, "test") - fs.mkdirs(checkpointDir) - fs.mkdirs(testDir) + operation: DStream[String] => DStream[T], + checkpointDir: Path, + testDir: Path + ): StreamingContext = { + // Mark that setup was called + setupCalled = true // Setup the streaming computation with the given operation System.clearProperty("spark.driver.port") System.clearProperty("spark.hostPort") - var ssc = new StreamingContext("local[4]", "MasterFailureTest", batchDuration, null, Nil, Map()) + val ssc = new StreamingContext("local[4]", "MasterFailureTest", batchDuration, null, Nil, Map()) ssc.checkpoint(checkpointDir.toString) val inputStream = ssc.textFileStream(testDir.toString) val operatedStream = operation(inputStream) val outputStream = new TestOutputStream(operatedStream) ssc.registerOutputStream(outputStream) - (ssc, checkpointDir, testDir) + ssc } @@ -204,7 +213,7 @@ object MasterFailureTest extends Logging { var isTimedOut = false val mergedOutput = new ArrayBuffer[T]() val checkpointDir = ssc.checkpointDir - var batchDuration = ssc.graph.batchDuration + val batchDuration = ssc.graph.batchDuration while(!isLastOutputGenerated && !isTimedOut) { // Get the output buffer @@ -261,7 +270,10 @@ object MasterFailureTest extends Logging { ) Thread.sleep(sleepTime) // Recreate the streaming context from checkpoint - ssc = new StreamingContext(checkpointDir) + ssc = StreamingContext.getOrCreate(checkpointDir, () => { + throw new Exception("Trying to create new context when it " + + "should be reading from checkpoint file") + }) } } mergedOutput @@ -297,6 +309,7 @@ object MasterFailureTest extends Logging { private def reset() { killed = false killCount = 0 + setupCalled = false } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index 8dc80ac2ed..6499de98c9 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -84,9 +84,9 @@ class CheckpointSuite extends TestSuiteBase { ssc.start() advanceTimeWithRealDelay(ssc, firstNumBatches) logInfo("Checkpoint data of state stream = \n" + stateStream.checkpointData) - assert(!stateStream.checkpointData.checkpointFiles.isEmpty, + assert(!stateStream.checkpointData.currentCheckpointFiles.isEmpty, "No checkpointed RDDs in state stream before first failure") - stateStream.checkpointData.checkpointFiles.foreach { + stateStream.checkpointData.currentCheckpointFiles.foreach { case (time, file) => { assert(fs.exists(new Path(file)), "Checkpoint file '" + file +"' for time " + time + " for state stream before first failure does not exist") @@ -95,7 +95,7 @@ class CheckpointSuite extends TestSuiteBase { // Run till a further time such that previous checkpoint files in the stream would be deleted // and check whether the earlier checkpoint files are deleted - val checkpointFiles = stateStream.checkpointData.checkpointFiles.map(x => new File(x._2)) + val checkpointFiles = stateStream.checkpointData.currentCheckpointFiles.map(x => new File(x._2)) advanceTimeWithRealDelay(ssc, secondNumBatches) checkpointFiles.foreach(file => assert(!file.exists, "Checkpoint file '" + file + "' was not deleted")) @@ -114,9 +114,9 @@ class CheckpointSuite extends TestSuiteBase { // is present in the checkpoint data or not ssc.start() advanceTimeWithRealDelay(ssc, 1) - assert(!stateStream.checkpointData.checkpointFiles.isEmpty, + assert(!stateStream.checkpointData.currentCheckpointFiles.isEmpty, "No checkpointed RDDs in state stream before second failure") - stateStream.checkpointData.checkpointFiles.foreach { + stateStream.checkpointData.currentCheckpointFiles.foreach { case (time, file) => { assert(fs.exists(new Path(file)), "Checkpoint file '" + file +"' for time " + time + " for state stream before seconds failure does not exist") |