diff options
Diffstat (limited to 'streaming')
-rw-r--r-- | streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala | 12 | ||||
-rw-r--r-- | streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala | 18 |
2 files changed, 26 insertions, 4 deletions
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index 7b343d2376..139e2c08b5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -24,7 +24,7 @@ import java.util.concurrent.RejectedExecutionException import org.apache.hadoop.fs.Path import org.apache.hadoop.conf.Configuration -import org.apache.spark.Logging +import org.apache.spark.{SparkException, Logging} import org.apache.spark.io.CompressionCodec import org.apache.spark.util.MetadataCleaner @@ -141,9 +141,15 @@ class CheckpointWriter(checkpointDir: String) extends Logging { private[streaming] object CheckpointReader extends Logging { + def doesCheckpointExist(path: String): Boolean = { + val attempts = Seq(new Path(path, "graph"), new Path(path, "graph.bk")) + val fs = new Path(path).getFileSystem(new Configuration()) + (attempts.count(p => fs.exists(p)) > 1) + } + def read(path: String): Checkpoint = { val fs = new Path(path).getFileSystem(new Configuration()) - val attempts = Seq(new Path(path, "graph"), new Path(path, "graph.bk"), new Path(path), new Path(path + ".bk")) + val attempts = Seq(new Path(path, "graph"), new Path(path, "graph.bk")) val compressionCodec = CompressionCodec.createCodec() @@ -175,7 +181,7 @@ object CheckpointReader extends Logging { } }) - throw new Exception("Could not read checkpoint from path '" + path + "'") + throw new SparkException("Could not read checkpoint from path '" + path + "'") } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 41da028a3c..01b213ab42 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -570,12 +570,28 @@ class StreamingContext private ( } -object StreamingContext { +object StreamingContext extends Logging { implicit def toPairDStreamFunctions[K: ClassTag, V: ClassTag](stream: DStream[(K,V)]) = { new PairDStreamFunctions[K, V](stream) } + def getOrCreate( + checkpointPath: String, + creatingFunc: () => StreamingContext, + createOnCheckpointError: Boolean = false + ): StreamingContext = { + if (CheckpointReader.doesCheckpointExist(checkpointPath)) { + logInfo("Creating streaming context from checkpoint file") + new StreamingContext(checkpointPath) + } else { + logInfo("Creating new streaming context") + val ssc = creatingFunc() + ssc.checkpoint(checkpointPath) + ssc + } + } + protected[streaming] def createNewSparkContext( master: String, appName: String, |