about summary refs log tree commit diff
path: root/streaming
diff options
context:
space:
mode:
Diffstat (limited to 'streaming')
-rw-r--r-- streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala | 12
-rw-r--r-- streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala | 18
2 files changed, 26 insertions(+), 4 deletions(-)
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
index 7b343d2376..139e2c08b5 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
@@ -24,7 +24,7 @@ import java.util.concurrent.RejectedExecutionException
import org.apache.hadoop.fs.Path
import org.apache.hadoop.conf.Configuration
-import org.apache.spark.Logging
+import org.apache.spark.{SparkException, Logging}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.util.MetadataCleaner
@@ -141,9 +141,15 @@ class CheckpointWriter(checkpointDir: String) extends Logging {
private[streaming]
object CheckpointReader extends Logging {
+ def doesCheckpointExist(path: String): Boolean = {
+ val attempts = Seq(new Path(path, "graph"), new Path(path, "graph.bk"))
+ val fs = new Path(path).getFileSystem(new Configuration())
+ (attempts.count(p => fs.exists(p)) > 1)
+ }
+
def read(path: String): Checkpoint = {
val fs = new Path(path).getFileSystem(new Configuration())
- val attempts = Seq(new Path(path, "graph"), new Path(path, "graph.bk"), new Path(path), new Path(path + ".bk"))
+ val attempts = Seq(new Path(path, "graph"), new Path(path, "graph.bk"))
val compressionCodec = CompressionCodec.createCodec()
@@ -175,7 +181,7 @@ object CheckpointReader extends Logging {
}
})
- throw new Exception("Could not read checkpoint from path '" + path + "'")
+ throw new SparkException("Could not read checkpoint from path '" + path + "'")
}
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
index 41da028a3c..01b213ab42 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
@@ -570,12 +570,28 @@ class StreamingContext private (
}
-object StreamingContext {
+object StreamingContext extends Logging {
implicit def toPairDStreamFunctions[K: ClassTag, V: ClassTag](stream: DStream[(K,V)]) = {
new PairDStreamFunctions[K, V](stream)
}
+ def getOrCreate(
+ checkpointPath: String,
+ creatingFunc: () => StreamingContext,
+ createOnCheckpointError: Boolean = false
+ ): StreamingContext = {
+ if (CheckpointReader.doesCheckpointExist(checkpointPath)) {
+ logInfo("Creating streaming context from checkpoint file")
+ new StreamingContext(checkpointPath)
+ } else {
+ logInfo("Creating new streaming context")
+ val ssc = creatingFunc()
+ ssc.checkpoint(checkpointPath)
+ ssc
+ }
+ }
+
protected[streaming] def createNewSparkContext(
master: String,
appName: String,