author    Tathagata Das <tathagata.das1565@gmail.com>  2014-01-09 13:42:04 -0800
committer Tathagata Das <tathagata.das1565@gmail.com>  2014-01-09 13:42:04 -0800
commit    6f713e2a3e56185368b66fb087637dec112a1f5d (patch)
tree      201400e576fb2dd27ff5362e91de23df4401f69d
parent    a17cc602ac79b22457ed457023493fe82e9d39df (diff)
Changed the way StreamingContext finds and reads checkpoint files, and added JavaStreamingContext.getOrCreate.
-rw-r--r--  conf/slaves                                                                           6
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkContext.scala                               2
-rw-r--r--  examples/src/main/java/org/apache/spark/streaming/examples/JavaNetworkWordCount.java  7
-rw-r--r--  examples/src/main/scala/org/apache/spark/streaming/examples/NetworkWordCount.scala    2
-rw-r--r--  examples/src/main/scala/org/apache/spark/streaming/examples/RecoverableNetworkWordCount.scala  43
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala                 98
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/DStreamCheckpointData.scala      57
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala                4
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala           64
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala  96
10 files changed, 254 insertions, 125 deletions
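
The headline change is the new StreamingContext.getOrCreate (and the JavaStreamingContext counterpart further down), which recovers a context from existing checkpoint data and otherwise builds a fresh one. A minimal sketch of the intended driver-recovery pattern, matching the signatures added below; the checkpoint directory, master URL, and app name are hypothetical:

    import org.apache.spark.streaming.{Seconds, StreamingContext}

    object GetOrCreateSketch {
      // Hypothetical checkpoint directory; in production this must be a
      // fault-tolerant file system such as HDFS.
      val checkpointDir = "hdfs:///tmp/checkpoint"

      // Builds the full DStream graph; invoked only when no usable checkpoint exists.
      def createContext(): StreamingContext = {
        val ssc = new StreamingContext("local[2]", "GetOrCreateSketch", Seconds(1))
        ssc.checkpoint(checkpointDir)  // the creating function sets the checkpoint dir itself
        // ... define input streams and transformations here ...
        ssc
      }

      def main(args: Array[String]) {
        // Recover from checkpointDir if checkpoint data is found there,
        // otherwise fall back to createContext().
        val ssc = StreamingContext.getOrCreate(checkpointDir, createContext _)
        ssc.start()
      }
    }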
diff --git a/conf/slaves b/conf/slaves
index 30ea300e07..2fbb50c4a8 100644
--- a/conf/slaves
+++ b/conf/slaves
@@ -1,5 +1 @@
-ec2-54-221-59-252.compute-1.amazonaws.com
-ec2-67-202-26-243.compute-1.amazonaws.com
-ec2-23-22-220-97.compute-1.amazonaws.com
-ec2-50-16-98-100.compute-1.amazonaws.com
-ec2-54-234-164-206.compute-1.amazonaws.com
+localhost
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 7514ce58fb..304e85f1c0 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -29,7 +29,7 @@ import scala.collection.mutable.HashMap
import scala.reflect.{ClassTag, classTag}
import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.Path
+import org.apache.hadoop.fs.{LocalFileSystem, Path}
import org.apache.hadoop.io.ArrayWritable
import org.apache.hadoop.io.BooleanWritable
import org.apache.hadoop.io.BytesWritable
diff --git a/examples/src/main/java/org/apache/spark/streaming/examples/JavaNetworkWordCount.java b/examples/src/main/java/org/apache/spark/streaming/examples/JavaNetworkWordCount.java
index def87c199b..d8d6046914 100644
--- a/examples/src/main/java/org/apache/spark/streaming/examples/JavaNetworkWordCount.java
+++ b/examples/src/main/java/org/apache/spark/streaming/examples/JavaNetworkWordCount.java
@@ -41,17 +41,17 @@ import org.apache.spark.streaming.api.java.JavaStreamingContext;
public class JavaNetworkWordCount {
public static void main(String[] args) {
if (args.length < 3) {
- System.err.println("Usage: NetworkWordCount <master> <hostname> <port>\n" +
+ System.err.println("Usage: JavaNetworkWordCount <master> <hostname> <port>\n" +
"In local mode, <master> should be 'local[n]' with n > 1");
System.exit(1);
}
// Create the context with a 1 second batch size
- JavaStreamingContext ssc = new JavaStreamingContext(args[0], "NetworkWordCount",
+ JavaStreamingContext ssc = new JavaStreamingContext(args[0], "JavaNetworkWordCount",
new Duration(1000), System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR"));
// Create a NetworkInputDStream on target ip:port and count the
- // words in input stream of \n delimited test (eg. generated by 'nc')
+ // words in input stream of \n delimited text (eg. generated by 'nc')
JavaDStream<String> lines = ssc.socketTextStream(args[1], Integer.parseInt(args[2]));
JavaDStream<String> words = lines.flatMap(new FlatMapFunction<String, String>() {
@Override
@@ -74,6 +74,5 @@ public class JavaNetworkWordCount {
wordCounts.print();
ssc.start();
-
}
}
diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/NetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/NetworkWordCount.scala
index e2487dca5f..5ad4875980 100644
--- a/examples/src/main/scala/org/apache/spark/streaming/examples/NetworkWordCount.scala
+++ b/examples/src/main/scala/org/apache/spark/streaming/examples/NetworkWordCount.scala
@@ -44,7 +44,7 @@ object NetworkWordCount {
System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR")))
// Create a NetworkInputDStream on target ip:port and count the
- // words in input stream of \n delimited test (eg. generated by 'nc')
+ // words in input stream of \n delimited text (eg. generated by 'nc')
val lines = ssc.socketTextStream(args(1), args(2).toInt)
val words = lines.flatMap(_.split(" "))
val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _)
diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/RecoverableNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/RecoverableNetworkWordCount.scala
index 0e5f39f772..739f805e87 100644
--- a/examples/src/main/scala/org/apache/spark/streaming/examples/RecoverableNetworkWordCount.scala
+++ b/examples/src/main/scala/org/apache/spark/streaming/examples/RecoverableNetworkWordCount.scala
@@ -1,3 +1,20 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
package org.apache.spark.streaming.examples
import org.apache.spark.streaming.{Time, Seconds, StreamingContext}
@@ -8,20 +25,37 @@ import org.apache.spark.rdd.RDD
import com.google.common.io.Files
import java.nio.charset.Charset
+/**
+ * Counts words in UTF8 encoded, '\n' delimited text received from the network every second.
+ * Usage: RecoverableNetworkWordCount <master> <hostname> <port> <checkpoint-directory> <output-file>
+ * <master> is the Spark master URL. In local mode, <master> should be 'local[n]' with n > 1.
+ * <hostname> and <port> describe the TCP server that Spark Streaming would connect to receive data.
+ * <checkpoint-directory> directory in a Hadoop compatible file system to which checkpoint
+ * data will be saved; this must be a fault-tolerant file system
+ * like HDFS for the system to recover from driver failures
+ * <output-file> file to which the word counts will be appended
+ * To run this on your local machine, you need to first run a Netcat server
+ * `$ nc -lk 9999`
+ * and then run the example
+ * `$ ./run-example org.apache.spark.streaming.examples.RecoverableNetworkWordCount local[2] localhost 9999 <checkpoint-directory> <output-file>`
+ */
+
object RecoverableNetworkWordCount {
def createContext(master: String, ip: String, port: Int, outputPath: String) = {
+ // If you do not see this printed, that means the StreamingContext has been loaded
+ // from the new checkpoint
+ println("Creating new context")
val outputFile = new File(outputPath)
if (outputFile.exists()) outputFile.delete()
// Create the context with a 1 second batch size
- println("Creating new context")
val ssc = new StreamingContext(master, "RecoverableNetworkWordCount", Seconds(1),
System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR")))
// Create a NetworkInputDStream on target ip:port and count the
- // words in input stream of \n delimited test (eg. generated by 'nc')
+ // words in input stream of \n delimited text (eg. generated by 'nc')
val lines = ssc.socketTextStream(ip, port)
val words = lines.flatMap(_.split(" "))
val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _)
@@ -39,10 +73,10 @@ object RecoverableNetworkWordCount {
System.err.println("You arguments were " + args.mkString("[", ", ", "]"))
System.err.println(
"""
- |Usage: RecoverableNetworkWordCount <master> <hostname> <port> <checkpoint-directory> <output-directory>
+ |Usage: RecoverableNetworkWordCount <master> <hostname> <port> <checkpoint-directory> <output-file>
|
|In local mode, <master> should be 'local[n]' with n > 1
- |Both <checkpoint-directory> and <output-directory> should be full paths
+ |Both <checkpoint-directory> and <output-file> should be full paths
""".stripMargin
)
System.exit(1)
@@ -53,6 +87,5 @@ object RecoverableNetworkWordCount {
createContext(master, ip, port, outputPath)
})
ssc.start()
-
}
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
index 155d5bc02e..a32e4852c5 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
@@ -82,22 +82,28 @@ class CheckpointWriter(jobGenerator: JobGenerator, checkpointDir: String, hadoop
attempts += 1
try {
logInfo("Saving checkpoint for time " + checkpointTime + " to file '" + file + "'")
- // This is inherently thread unsafe, so alleviating it by writing to '.new' and
+ // This is inherently thread unsafe, so alleviating it by writing to '.next' and
// then moving it to the final file
val fos = fs.create(writeFile)
fos.write(bytes)
fos.close()
+
+ // Back up existing checkpoint if it exists
if (fs.exists(file) && fs.rename(file, bakFile)) {
logDebug("Moved existing checkpoint file to " + bakFile)
}
- // paranoia
- fs.delete(file, false)
- fs.rename(writeFile, file)
-
- val finishTime = System.currentTimeMillis()
- logInfo("Checkpoint for time " + checkpointTime + " saved to file '" + file +
- "', took " + bytes.length + " bytes and " + (finishTime - startTime) + " milliseconds")
- jobGenerator.onCheckpointCompletion(checkpointTime)
+ fs.delete(file, false) // paranoia
+
+ // Rename temp written file to the right location
+ if (fs.rename(writeFile, file)) {
+ val finishTime = System.currentTimeMillis()
+ logInfo("Checkpoint for time " + checkpointTime + " saved to file '" + file +
+ "', took " + bytes.length + " bytes and " + (finishTime - startTime) + " ms")
+ jobGenerator.onCheckpointCompletion(checkpointTime)
+ } else {
+ throw new SparkException("Failed to rename checkpoint file from "
+ + writeFile + " to " + file)
+ }
return
} catch {
case ioe: IOException =>
@@ -154,47 +160,47 @@ class CheckpointWriter(jobGenerator: JobGenerator, checkpointDir: String, hadoop
private[streaming]
object CheckpointReader extends Logging {
- def doesCheckpointExist(path: String): Boolean = {
- val attempts = Seq(new Path(path, "graph"), new Path(path, "graph.bk"))
- val fs = new Path(path).getFileSystem(new Configuration())
- (attempts.count(p => fs.exists(p)) > 1)
- }
+ private val graphFileNames = Seq("graph", "graph.bk")
+
+ def read(checkpointDir: String, hadoopConf: Configuration): Option[Checkpoint] = {
+ val checkpointPath = new Path(checkpointDir)
+ def fs = checkpointPath.getFileSystem(hadoopConf)
+ val existingFiles = graphFileNames.map(new Path(checkpointPath, _)).filter(fs.exists)
+
+ // Log the file listing if graph checkpoint file was not found
+ if (existingFiles.isEmpty) {
+ logInfo("Could not find graph file in " + checkpointDir + ", which contains the files:\n" +
+ fs.listStatus(checkpointPath).mkString("\n"))
+ return None
+ }
+ logInfo("Checkpoint files found: " + existingFiles.mkString(","))
- def read(path: String): Checkpoint = {
- val fs = new Path(path).getFileSystem(new Configuration())
- val attempts = Seq(new Path(path, "graph"), new Path(path, "graph.bk"))
val compressionCodec = CompressionCodec.createCodec()
-
- attempts.foreach(file => {
- if (fs.exists(file)) {
- logInfo("Attempting to load checkpoint from file '" + file + "'")
- try {
- val fis = fs.open(file)
- // ObjectInputStream uses the last defined user-defined class loader in the stack
- // to find classes, which maybe the wrong class loader. Hence, a inherited version
- // of ObjectInputStream is used to explicitly use the current thread's default class
- // loader to find and load classes. This is a well know Java issue and has popped up
- // in other places (e.g., http://jira.codehaus.org/browse/GROOVY-1627)
- val zis = compressionCodec.compressedInputStream(fis)
- val ois = new ObjectInputStreamWithLoader(zis,
- Thread.currentThread().getContextClassLoader)
- val cp = ois.readObject.asInstanceOf[Checkpoint]
- ois.close()
- fs.close()
- cp.validate()
- logInfo("Checkpoint successfully loaded from file '" + file + "'")
- logInfo("Checkpoint was generated at time " + cp.checkpointTime)
- return cp
- } catch {
- case e: Exception =>
- logError("Error loading checkpoint from file '" + file + "'", e)
- }
- } else {
- logWarning("Could not read checkpoint from file '" + file + "' as it does not exist")
+ existingFiles.foreach(file => {
+ logInfo("Attempting to load checkpoint from file '" + file + "'")
+ try {
+ val fis = fs.open(file)
+ // ObjectInputStream uses the last defined user-defined class loader in the stack
+ // to find classes, which may be the wrong class loader. Hence, an inherited version
+ // of ObjectInputStream is used to explicitly use the current thread's default class
+ // loader to find and load classes. This is a well known Java issue and has popped up
+ // in other places (e.g., http://jira.codehaus.org/browse/GROOVY-1627)
+ val zis = compressionCodec.compressedInputStream(fis)
+ val ois = new ObjectInputStreamWithLoader(zis,
+ Thread.currentThread().getContextClassLoader)
+ val cp = ois.readObject.asInstanceOf[Checkpoint]
+ ois.close()
+ fs.close()
+ cp.validate()
+ logInfo("Checkpoint successfully loaded from file '" + file + "'")
+ logInfo("Checkpoint was generated at time " + cp.checkpointTime)
+ return Some(cp)
+ } catch {
+ case e: Exception =>
+ logWarning("Error reading checkpoint from file '" + file + "'", e)
}
-
})
- throw new SparkException("Could not read checkpoint from path '" + path + "'")
+ throw new SparkException("Failed to read checkpoint from directory '" + checkpointDir + "'")
}
}
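
The writer above now follows a write-then-rename protocol: serialize the checkpoint to a temporary file, back up any existing 'graph' file to 'graph.bk', then rename the temporary file into place, failing loudly if the rename does not succeed. A condensed sketch of that sequence with hypothetical file names (the temp-file suffix is taken from the comment; the real CheckpointWriter also retries on IOException):

    import org.apache.hadoop.conf.Configuration
    import org.apache.hadoop.fs.Path

    // Hypothetical helper mirroring CheckpointWriter's rename-based update.
    def writeCheckpointFile(dir: Path, bytes: Array[Byte]) {
      val file = new Path(dir, "graph")            // final checkpoint file
      val writeFile = new Path(dir, "graph.next")  // temporary file, written first
      val bakFile = new Path(dir, "graph.bk")      // backup of the previous checkpoint
      val fs = dir.getFileSystem(new Configuration())

      val out = fs.create(writeFile)
      out.write(bytes)
      out.close()

      // Preserve the previous checkpoint, then move the new one into place.
      if (fs.exists(file) && fs.rename(file, bakFile)) {
        // previous checkpoint is now graph.bk
      }
      fs.delete(file, false)  // paranoia, as in the original
      if (!fs.rename(writeFile, file)) {
        throw new Exception("Failed to rename " + writeFile + " to " + file)
      }
    }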
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamCheckpointData.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamCheckpointData.scala
index e0567a1c19..1081d3c807 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamCheckpointData.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamCheckpointData.scala
@@ -27,18 +27,16 @@ import org.apache.spark.Logging
import java.io.{ObjectInputStream, IOException}
-
private[streaming]
class DStreamCheckpointData[T: ClassTag] (dstream: DStream[T])
extends Serializable with Logging {
protected val data = new HashMap[Time, AnyRef]()
- @transient private var allCheckpointFiles = new HashMap[Time, String]
- @transient private var timeToLastCheckpointFileTime = new HashMap[Time, Time]
+ // Mapping of the batch time to the checkpointed RDD file of that time
+ @transient private var timeToCheckpointFile = new HashMap[Time, String]
+ // Mapping of the batch time to the time of the oldest checkpointed RDD in that batch's checkpoint data
+ @transient private var timeToOldestCheckpointFileTime = new HashMap[Time, Time]
@transient private var fileSystem : FileSystem = null
-
- //@transient private var lastCheckpointFiles: HashMap[Time, String] = null
-
protected[streaming] def currentCheckpointFiles = data.asInstanceOf[HashMap[Time, String]]
/**
@@ -51,17 +49,14 @@ class DStreamCheckpointData[T: ClassTag] (dstream: DStream[T])
// Get the checkpointed RDDs from the generated RDDs
val checkpointFiles = dstream.generatedRDDs.filter(_._2.getCheckpointFile.isDefined)
.map(x => (x._1, x._2.getCheckpointFile.get))
+ logDebug("Current checkpoint files:\n" + checkpointFiles.toSeq.mkString("\n"))
- logInfo("Current checkpoint files:\n" + checkpointFiles.toSeq.mkString("\n"))
- // Make a copy of the existing checkpoint data (checkpointed RDDs)
- // lastCheckpointFiles = checkpointFiles.clone()
-
- // If the new checkpoint data has checkpoints then replace existing with the new one
+ // Add the checkpoint files to the data to be serialized
if (!currentCheckpointFiles.isEmpty) {
currentCheckpointFiles.clear()
currentCheckpointFiles ++= checkpointFiles
- allCheckpointFiles ++= currentCheckpointFiles
- timeToLastCheckpointFileTime(time) = currentCheckpointFiles.keys.min(Time.ordering)
+ timeToCheckpointFile ++= currentCheckpointFiles
+ timeToOldestCheckpointFileTime(time) = currentCheckpointFiles.keys.min(Time.ordering)
}
}
@@ -71,32 +66,10 @@ class DStreamCheckpointData[T: ClassTag] (dstream: DStream[T])
* implementation, cleans up old checkpoint files.
*/
def cleanup(time: Time) {
- /*
- // If there is at least on checkpoint file in the current checkpoint files,
- // then delete the old checkpoint files.
- if (checkpointFiles.size > 0 && lastCheckpointFiles != null) {
- (lastCheckpointFiles -- checkpointFiles.keySet).foreach {
- case (time, file) => {
- try {
- val path = new Path(file)
- if (fileSystem == null) {
- fileSystem = path.getFileSystem(new Configuration())
- }
- fileSystem.delete(path, true)
- logInfo("Deleted checkpoint file '" + file + "' for time " + time)
- } catch {
- case e: Exception =>
- logWarning("Error deleting old checkpoint file '" + file + "' for time " + time, e)
- }
- }
- }
- }
- */
- timeToLastCheckpointFileTime.remove(time) match {
+ timeToOldestCheckpointFileTime.remove(time) match {
case Some(lastCheckpointFileTime) =>
- logInfo("Deleting all files before " + time)
- val filesToDelete = allCheckpointFiles.filter(_._1 < lastCheckpointFileTime)
- logInfo("Files to delete:\n" + filesToDelete.mkString(","))
+ val filesToDelete = timeToCheckpointFile.filter(_._1 < lastCheckpointFileTime)
+ logDebug("Files to delete:\n" + filesToDelete.mkString(","))
filesToDelete.foreach {
case (time, file) =>
try {
@@ -105,11 +78,12 @@ class DStreamCheckpointData[T: ClassTag] (dstream: DStream[T])
fileSystem = path.getFileSystem(dstream.ssc.sparkContext.hadoopConfiguration)
}
fileSystem.delete(path, true)
- allCheckpointFiles -= time
+ timeToCheckpointFile -= time
logInfo("Deleted checkpoint file '" + file + "' for time " + time)
} catch {
case e: Exception =>
logWarning("Error deleting old checkpoint file '" + file + "' for time " + time, e)
+ fileSystem = null
}
}
case None =>
@@ -138,7 +112,8 @@ class DStreamCheckpointData[T: ClassTag] (dstream: DStream[T])
@throws(classOf[IOException])
private def readObject(ois: ObjectInputStream) {
- timeToLastCheckpointFileTime = new HashMap[Time, Time]
- allCheckpointFiles = new HashMap[Time, String]
+ ois.defaultReadObject()
+ timeToOldestCheckpointFileTime = new HashMap[Time, Time]
+ timeToCheckpointFile = new HashMap[Time, String]
}
}
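
The cleanup bookkeeping above works by remembering, for each batch, the time of the oldest RDD checkpoint that batch's data still references; cleanup(t) then deletes every remembered file older than that. A simplified, self-contained sketch of the same idea, using Long timestamps in place of Spark's Time and an injected delete function:

    import scala.collection.mutable.HashMap

    // batch time -> checkpoint file written for that batch
    val timeToCheckpointFile = new HashMap[Long, String]()
    // batch time -> time of the oldest checkpoint file still referenced by that batch
    val timeToOldestCheckpointFileTime = new HashMap[Long, Long]()

    def cleanup(time: Long, deleteFile: String => Unit) {
      timeToOldestCheckpointFileTime.remove(time) match {
        case Some(oldest) =>
          // Anything older than the oldest still-referenced checkpoint is dead.
          val filesToDelete = timeToCheckpointFile.filter(_._1 < oldest)
          filesToDelete.foreach { case (t, file) =>
            deleteFile(file)
            timeToCheckpointFile -= t
          }
        case None => // nothing was recorded for this batch
      }
    }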
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
index bfedef2e4e..34919d315c 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
@@ -130,11 +130,11 @@ final private[streaming] class DStreamGraph extends Serializable with Logging {
}
def clearCheckpointData(time: Time) {
- logInfo("Restoring checkpoint data")
+ logInfo("Clearing checkpoint data for time " + time)
this.synchronized {
outputStreams.foreach(_.clearCheckpointData(time))
}
- logInfo("Restored checkpoint data")
+ logInfo("Cleared checkpoint data for time " + time)
}
def restoreCheckpointData() {
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
index 59d2d546e6..30deba417e 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
@@ -45,10 +45,11 @@ import org.apache.hadoop.io.LongWritable
import org.apache.hadoop.io.Text
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat
-import org.apache.hadoop.fs.Path
+import org.apache.hadoop.fs.{LocalFileSystem, Path}
import twitter4j.Status
import twitter4j.auth.Authorization
+import org.apache.hadoop.conf.Configuration
/**
* A StreamingContext is the main entry point for Spark Streaming functionality. Besides the basic
@@ -89,10 +90,12 @@ class StreamingContext private (
/**
* Re-create a StreamingContext from a checkpoint file.
- * @param path Path either to the directory that was specified as the checkpoint directory, or
- * to the checkpoint file 'graph' or 'graph.bk'.
+ * @param path Path to the directory that was specified as the checkpoint directory
+ * @param hadoopConf Optional Hadoop configuration, if necessary for reading from
+ * HDFS compatible filesystems
*/
- def this(path: String) = this(null, CheckpointReader.read(path), null)
+ def this(path: String, hadoopConf: Configuration = new Configuration) =
+ this(null, CheckpointReader.read(path, hadoopConf).get, null)
initLogging()
@@ -170,8 +173,9 @@ class StreamingContext private (
/**
* Set the context to periodically checkpoint the DStream operations for master
- * fault-tolerance. The graph will be checkpointed every batch interval.
- * @param directory HDFS-compatible directory where the checkpoint data will be reliably stored
+ * fault-tolerance.
+ * @param directory HDFS-compatible directory where the checkpoint data will be reliably stored.
+ * Note that this must be a fault-tolerant file system like HDFS for
+ * the system to recover from driver failures.
*/
def checkpoint(directory: String) {
if (directory != null) {
@@ -577,6 +581,10 @@ class StreamingContext private (
}
}
+/**
+ * The StreamingContext object contains a number of utility functions related to the
+ * StreamingContext class.
+ */
object StreamingContext extends Logging {
@@ -584,19 +592,45 @@ object StreamingContext extends Logging {
new PairDStreamFunctions[K, V](stream)
}
+ /**
+ * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext.
+ * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be
+ * recreated from the checkpoint data. If the data does not exist, then the StreamingContext
+ * will be created by calling the provided `creatingFunc`.
+ *
+ * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program
+ * @param creatingFunc Function to create a new StreamingContext
+ * @param hadoopConf Optional Hadoop configuration if necessary for reading from the
+ * file system
+ * @param createOnError Optional, whether to create a new StreamingContext if there is an
+ * error in reading checkpoint data. By default, an exception will be
+ * thrown on error.
+ */
def getOrCreate(
checkpointPath: String,
creatingFunc: () => StreamingContext,
- createOnCheckpointError: Boolean = false
+ hadoopConf: Configuration = new Configuration(),
+ createOnError: Boolean = false
): StreamingContext = {
- if (CheckpointReader.doesCheckpointExist(checkpointPath)) {
- logInfo("Creating streaming context from checkpoint file")
- new StreamingContext(checkpointPath)
- } else {
- logInfo("Creating new streaming context")
- val ssc = creatingFunc()
- ssc.checkpoint(checkpointPath)
- ssc
+
+ try {
+ CheckpointReader.read(checkpointPath, hadoopConf) match {
+ case Some(checkpoint) =>
+ return new StreamingContext(null, checkpoint, null)
+ case None =>
+ logInfo("Creating new StreamingContext")
+ return creatingFunc()
+ }
+ } catch {
+ case e: Exception =>
+ if (createOnError) {
+ logWarning("Error reading checkpoint", e)
+ logInfo("Creating new StreamingContext")
+ return creatingFunc()
+ } else {
+ logError("Error reading checkpoint", e)
+ throw e
+ }
}
}
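
Note the semantics of the createOnError flag added above: by default a corrupt or unreadable checkpoint rethrows the exception, while createOnError = true logs a warning and falls back to the creating function. A short hedged example of the fallback form (paths and names hypothetical):

    import org.apache.hadoop.conf.Configuration
    import org.apache.spark.streaming.{Seconds, StreamingContext}

    val checkpointDir = "hdfs:///tmp/checkpoint"  // hypothetical path
    def createContext() = {
      val ssc = new StreamingContext("local[2]", "FallbackSketch", Seconds(1))
      ssc.checkpoint(checkpointDir)
      ssc
    }

    // With createOnError = true, an unreadable checkpoint is logged as a warning
    // and createContext() is used instead of the exception being rethrown.
    val ssc = StreamingContext.getOrCreate(
      checkpointDir, createContext _, new Configuration(), createOnError = true)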
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
index aad0d931e7..f38d145317 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
@@ -40,6 +40,7 @@ import org.apache.spark.api.java.{JavaPairRDD, JavaSparkContext, JavaRDD}
import org.apache.spark.streaming._
import org.apache.spark.streaming.dstream._
import org.apache.spark.streaming.scheduler.StreamingListener
+import org.apache.hadoop.conf.Configuration
/**
* A StreamingContext is the main entry point for Spark Streaming functionality. Besides the basic
@@ -125,10 +126,16 @@ class JavaStreamingContext(val ssc: StreamingContext) {
/**
* Re-creates a StreamingContext from a checkpoint file.
- * @param path Path either to the directory that was specified as the checkpoint directory, or
- * to the checkpoint file 'graph' or 'graph.bk'.
+ * @param path Path to the directory that was specified as the checkpoint directory
*/
- def this(path: String) = this (new StreamingContext(path))
+ def this(path: String) = this(new StreamingContext(path))
+
+ /**
+ * Re-creates a StreamingContext from checkpoint data in the given directory.
+ * @param path Path to the directory that was specified as the checkpoint directory
+ * @param hadoopConf Hadoop configuration, if necessary for reading from HDFS compatible filesystems
+ */
+ def this(path: String, hadoopConf: Configuration) = this(new StreamingContext(path, hadoopConf))
/** The underlying SparkContext */
val sc: JavaSparkContext = new JavaSparkContext(ssc.sc)
@@ -699,13 +706,92 @@ class JavaStreamingContext(val ssc: StreamingContext) {
}
/**
- * Starts the execution of the streams.
+ * Start the execution of the streams.
*/
def start() = ssc.start()
/**
- * Sstops the execution of the streams.
+ * Stop the execution of the streams.
*/
def stop() = ssc.stop()
+}
+
+/**
+ * The JavaStreamingContext object contains a number of static utility functions.
+ */
+object JavaStreamingContext {
+
+ /**
+ * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext.
+ * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be
+ * recreated from the checkpoint data. If the data does not exist, then the StreamingContext
+ * will be created by calling the provided `factory`.
+ *
+ * @param checkpointPath Checkpoint directory used in an earlier JavaStreamingContext program
+ * @param factory JavaStreamingContextFactory object to create a new JavaStreamingContext
+ */
+ def getOrCreate(
+ checkpointPath: String,
+ factory: JavaStreamingContextFactory
+ ): JavaStreamingContext = {
+ val ssc = StreamingContext.getOrCreate(checkpointPath, () => {
+ factory.create.ssc
+ })
+ new JavaStreamingContext(ssc)
+ }
+
+ /**
+ * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext.
+ * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be
+ * recreated from the checkpoint data. If the data does not exist, then the StreamingContext
+ * will be created by calling the provided `factory`.
+ *
+ * @param checkpointPath Checkpoint directory used in an earlier JavaStreamingContext program
+ * @param hadoopConf Hadoop configuration if necessary for reading from any HDFS compatible
+ * file system
+ * @param factory JavaStreamingContextFactory object to create a new JavaStreamingContext
+ */
+ def getOrCreate(
+ checkpointPath: String,
+ hadoopConf: Configuration,
+ factory: JavaStreamingContextFactory
+ ): JavaStreamingContext = {
+ val ssc = StreamingContext.getOrCreate(checkpointPath, () => {
+ factory.create.ssc
+ }, hadoopConf)
+ new JavaStreamingContext(ssc)
+ }
+
+ /**
+ * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext.
+ * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be
+ * recreated from the checkpoint data. If the data does not exist, then the StreamingContext
+ * will be created by calling the provided `factory`.
+ *
+ * @param checkpointPath Checkpoint directory used in an earlier JavaStreamingContext program
+ * @param hadoopConf Hadoop configuration if necessary for reading from any HDFS compatible
+ * file system
+ * @param factory JavaStreamingContextFactory object to create a new JavaStreamingContext
+ * @param createOnError Whether to create a new JavaStreamingContext if there is an
+ * error in reading checkpoint data.
+ */
+ def getOrCreate(
+ checkpointPath: String,
+ hadoopConf: Configuration,
+ factory: JavaStreamingContextFactory,
+ createOnError: Boolean
+ ): JavaStreamingContext = {
+ val ssc = StreamingContext.getOrCreate(checkpointPath, () => {
+ factory.create.ssc
+ }, hadoopConf, createOnError)
+ new JavaStreamingContext(ssc)
+ }
+}
+
+/**
+ * Factory interface for creating a new JavaStreamingContext
+ */
+trait JavaStreamingContextFactory {
+ def create(): JavaStreamingContext
}
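
For Java callers the factory interface replaces the Scala function argument. A sketch of the pattern, written in Scala for consistency with the other sketches (a Java caller would pass an anonymous class implementing JavaStreamingContextFactory the same way; directory and names are hypothetical):

    import org.apache.spark.streaming.Duration
    import org.apache.spark.streaming.api.java.{JavaStreamingContext, JavaStreamingContextFactory}

    val checkpointDir = "hdfs:///tmp/checkpoint"  // hypothetical path

    val factory = new JavaStreamingContextFactory {
      def create(): JavaStreamingContext = {
        // Build the complete DStream graph here; invoked only when no
        // usable checkpoint exists under checkpointDir.
        val jssc = new JavaStreamingContext("local[2]", "FactorySketch", new Duration(1000))
        jssc.checkpoint(checkpointDir)
        jssc
      }
    }

    // Recover from the checkpoint if present, otherwise call factory.create().
    val jssc = JavaStreamingContext.getOrCreate(checkpointDir, factory)
    jssc.start()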