author     Tathagata Das <tathagata.das1565@gmail.com>  2015-04-23 11:29:34 -0700
committer  Tathagata Das <tathagata.das1565@gmail.com>  2015-04-23 11:29:34 -0700
commit     534f2a43625fbf1a3a65d09550a19875cd1dce43 (patch)
tree       8f77495057454cd0ee360e00841f6df84afcb3b5 /streaming
parent     cc48e6387abdd909921cb58e0588cdf226556bcd (diff)
[SPARK-6752][Streaming] Allow StreamingContext to be recreated from checkpoint and existing SparkContext
Currently, if you want to create a StreamingContext from checkpoint information, the system will create a new SparkContext. This prevents a StreamingContext from being recreated from checkpoints in managed environments where the SparkContext is pre-created. The solution in this PR is to introduce the following methods on StreamingContext:

1. `new StreamingContext(checkpointDirectory, sparkContext)`
   Recreate a StreamingContext from the checkpoint using the provided SparkContext.
2. `StreamingContext.getOrCreate(checkpointDirectory, sparkContext, createFunction: SparkContext => StreamingContext)`
   If a checkpoint file exists, then recreate the StreamingContext using the provided SparkContext (that is, call 1.); else create a StreamingContext using the provided createFunction.

TODO: the corresponding Java and Python APIs have to be added as well.

Author: Tathagata Das <tathagata.das1565@gmail.com>

Closes #5428 from tdas/SPARK-6752 and squashes the following commits:

94db63c [Tathagata Das] Fix long line.
524f519 [Tathagata Das] Many changes based on PR comments.
eabd092 [Tathagata Das] Added Function0, Java API and unit tests for StreamingContext.getOrCreate
36a7823 [Tathagata Das] Minor changes.
204814e [Tathagata Das] Added StreamingContext.getOrCreate with existing SparkContext
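As a usage note, here is a minimal sketch of how the new getOrCreate overload is meant to be called from a driver program; the app name, master, checkpoint path, batch interval, and helper name are illustrative assumptions, not part of this patch:

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.streaming.{Seconds, StreamingContext}

    // Illustrative values; in a managed environment the SparkContext would already exist.
    val checkpointDir = "/tmp/streaming-checkpoint"
    val sc = new SparkContext(new SparkConf().setAppName("example").setMaster("local[2]"))

    // Invoked only when no readable checkpoint data is found in checkpointDir.
    def createContext(sparkContext: SparkContext): StreamingContext = {
      val newSsc = new StreamingContext(sparkContext, Seconds(1))
      newSsc.checkpoint(checkpointDir)
      // set up input DStreams and output operations here
      newSsc
    }

    val ssc = StreamingContext.getOrCreate(checkpointDir, createContext _, sc)
    ssc.start()
    ssc.awaitTermination()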
Diffstat (limited to 'streaming')
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala                     |  26
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala               |  85
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala  | 119
-rw-r--r--  streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java                      | 145
-rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala                 |   3
-rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala           | 159
6 files changed, 476 insertions, 61 deletions
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
index 0a50485118..7bfae253c3 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
@@ -77,7 +77,8 @@ object Checkpoint extends Logging {
}
/** Get checkpoint files present in the given directory, ordered by oldest-first */
- def getCheckpointFiles(checkpointDir: String, fs: FileSystem): Seq[Path] = {
+ def getCheckpointFiles(checkpointDir: String, fsOption: Option[FileSystem] = None): Seq[Path] = {
+
def sortFunc(path1: Path, path2: Path): Boolean = {
val (time1, bk1) = path1.getName match { case REGEX(x, y) => (x.toLong, !y.isEmpty) }
val (time2, bk2) = path2.getName match { case REGEX(x, y) => (x.toLong, !y.isEmpty) }
@@ -85,6 +86,7 @@ object Checkpoint extends Logging {
}
val path = new Path(checkpointDir)
+ val fs = fsOption.getOrElse(path.getFileSystem(new Configuration()))
if (fs.exists(path)) {
val statuses = fs.listStatus(path)
if (statuses != null) {
@@ -160,7 +162,7 @@ class CheckpointWriter(
}
// Delete old checkpoint files
- val allCheckpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, fs)
+ val allCheckpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, Some(fs))
if (allCheckpointFiles.size > 10) {
allCheckpointFiles.take(allCheckpointFiles.size - 10).foreach(file => {
logInfo("Deleting " + file)
@@ -234,15 +236,24 @@ class CheckpointWriter(
private[streaming]
object CheckpointReader extends Logging {
- def read(checkpointDir: String, conf: SparkConf, hadoopConf: Configuration): Option[Checkpoint] =
- {
+ /**
+ * Read checkpoint files present in the given checkpoint directory. If there are no checkpoint
+ * files, then return None, else try to return the latest valid checkpoint object. If no
+ * checkpoint files could be read correctly, then return None (if ignoreReadError = true),
+ * or throw exception (if ignoreReadError = false).
+ */
+ def read(
+ checkpointDir: String,
+ conf: SparkConf,
+ hadoopConf: Configuration,
+ ignoreReadError: Boolean = false): Option[Checkpoint] = {
val checkpointPath = new Path(checkpointDir)
// TODO(rxin): Why is this a def?!
def fs: FileSystem = checkpointPath.getFileSystem(hadoopConf)
// Try to find the checkpoint files
- val checkpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, fs).reverse
+ val checkpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, Some(fs)).reverse
if (checkpointFiles.isEmpty) {
return None
}
@@ -282,7 +293,10 @@ object CheckpointReader extends Logging {
})
// If none of checkpoint files could be read, then throw exception
- throw new SparkException("Failed to read checkpoint from directory " + checkpointPath)
+ if (!ignoreReadError) {
+ throw new SparkException(s"Failed to read checkpoint from directory $checkpointPath")
+ }
+ None
}
}
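Note that the Checkpoint and CheckpointReader helpers changed above are private[streaming]; the sketch below, which only compiles inside the org.apache.spark.streaming package, illustrates the new optional FileSystem argument and the ignoreReadError flag (the checkpoint path is an illustrative assumption):

    import org.apache.hadoop.conf.Configuration
    import org.apache.spark.SparkConf

    // The FileSystem is now resolved from the path when the caller does not supply one.
    val files = Checkpoint.getCheckpointFiles("/tmp/streaming-checkpoint")

    // With ignoreReadError = true, unreadable checkpoint data yields None
    // instead of throwing a SparkException.
    val latest = CheckpointReader.read(
      "/tmp/streaming-checkpoint", new SparkConf(), new Configuration(), ignoreReadError = true)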
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
index f57f295874..90c8b47aeb 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
@@ -107,6 +107,19 @@ class StreamingContext private[streaming] (
*/
def this(path: String) = this(path, new Configuration)
+ /**
+ * Recreate a StreamingContext from a checkpoint file using an existing SparkContext.
+ * @param path Path to the directory that was specified as the checkpoint directory
+ * @param sparkContext Existing SparkContext
+ */
+ def this(path: String, sparkContext: SparkContext) = {
+ this(
+ sparkContext,
+ CheckpointReader.read(path, sparkContext.conf, sparkContext.hadoopConfiguration).get,
+ null)
+ }
+
+
if (sc_ == null && cp_ == null) {
throw new Exception("Spark Streaming cannot be initialized with " +
"both SparkContext and checkpoint as null")
@@ -115,10 +128,12 @@ class StreamingContext private[streaming] (
private[streaming] val isCheckpointPresent = (cp_ != null)
private[streaming] val sc: SparkContext = {
- if (isCheckpointPresent) {
+ if (sc_ != null) {
+ sc_
+ } else if (isCheckpointPresent) {
new SparkContext(cp_.createSparkConf())
} else {
- sc_
+ throw new SparkException("Cannot create StreamingContext without a SparkContext")
}
}
@@ -129,7 +144,7 @@ class StreamingContext private[streaming] (
private[streaming] val conf = sc.conf
- private[streaming] val env = SparkEnv.get
+ private[streaming] val env = sc.env
private[streaming] val graph: DStreamGraph = {
if (isCheckpointPresent) {
@@ -174,7 +189,9 @@ class StreamingContext private[streaming] (
/** Register streaming source to metrics system */
private val streamingSource = new StreamingSource(this)
- SparkEnv.get.metricsSystem.registerSource(streamingSource)
+ assert(env != null)
+ assert(env.metricsSystem != null)
+ env.metricsSystem.registerSource(streamingSource)
/** Enumeration to identify current state of the StreamingContext */
private[streaming] object StreamingContextState extends Enumeration {
@@ -621,19 +638,59 @@ object StreamingContext extends Logging {
hadoopConf: Configuration = new Configuration(),
createOnError: Boolean = false
): StreamingContext = {
- val checkpointOption = try {
- CheckpointReader.read(checkpointPath, new SparkConf(), hadoopConf)
- } catch {
- case e: Exception =>
- if (createOnError) {
- None
- } else {
- throw e
- }
- }
+ val checkpointOption = CheckpointReader.read(
+ checkpointPath, new SparkConf(), hadoopConf, createOnError)
checkpointOption.map(new StreamingContext(null, _, null)).getOrElse(creatingFunc())
}
+
+ /**
+ * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext.
+ * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be
+ * recreated from the checkpoint data. If the data does not exist, then the StreamingContext
+ * will be created by calling the provided `creatingFunc` on the provided `sparkContext`. Note
+ * that the SparkConf configuration in the checkpoint data will not be restored as the
+ * SparkContext has already been created.
+ *
+ * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program
+ * @param creatingFunc Function to create a new StreamingContext using the given SparkContext
+ * @param sparkContext SparkContext using which the StreamingContext will be created
+ */
+ def getOrCreate(
+ checkpointPath: String,
+ creatingFunc: SparkContext => StreamingContext,
+ sparkContext: SparkContext
+ ): StreamingContext = {
+ getOrCreate(checkpointPath, creatingFunc, sparkContext, createOnError = false)
+ }
+
+ /**
+ * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext.
+ * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be
+ * recreated from the checkpoint data. If the data does not exist, then the StreamingContext
+ * will be created by calling the provided `creatingFunc` on the provided `sparkContext`. Note
+ * that the SparkConf configuration in the checkpoint data will not be restored as the
+ * SparkContext has already been created.
+ *
+ * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program
+ * @param creatingFunc Function to create a new StreamingContext using the given SparkContext
+ * @param sparkContext SparkContext using which the StreamingContext will be created
+ * @param createOnError Whether to create a new StreamingContext if there is an
+ * error in reading checkpoint data. By default, an exception will be
+ * thrown on error.
+ */
+ def getOrCreate(
+ checkpointPath: String,
+ creatingFunc: SparkContext => StreamingContext,
+ sparkContext: SparkContext,
+ createOnError: Boolean
+ ): StreamingContext = {
+ val checkpointOption = CheckpointReader.read(
+ checkpointPath, sparkContext.conf, sparkContext.hadoopConfiguration, createOnError)
+ checkpointOption.map(new StreamingContext(sparkContext, _, null))
+ .getOrElse(creatingFunc(sparkContext))
+ }
+
/**
* Find the JAR from which a given class was loaded, to make it easy for users to pass
* their JARs to StreamingContext.
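To contrast the two recovery paths added above, here is a hedged sketch assuming `sc` is an already-created SparkContext and the checkpoint path is illustrative; the two calls are alternatives and are not meant to run in the same program:

    import org.apache.spark.SparkContext
    import org.apache.spark.streaming.{Seconds, StreamingContext}

    // 1. Direct recreation from checkpoint data; the underlying CheckpointReader.read(...).get
    //    fails if the directory holds no readable checkpoint.
    val recovered = new StreamingContext("/tmp/streaming-checkpoint", sc)

    // 2. Recreate if checkpoint data exists, otherwise build a fresh context; with
    //    createOnError = true a corrupted checkpoint also falls back to the creating function.
    val ssc = StreamingContext.getOrCreate(
      "/tmp/streaming-checkpoint",
      (sparkContext: SparkContext) => new StreamingContext(sparkContext, Seconds(1)),
      sc,
      createOnError = true)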
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
index 4095a7cc84..572d7d8e87 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
@@ -32,13 +32,14 @@ import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext}
import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2}
+import org.apache.spark.api.java.function.{Function0 => JFunction0}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming._
import org.apache.spark.streaming.scheduler.StreamingListener
-import org.apache.hadoop.conf.Configuration
-import org.apache.spark.streaming.dstream.{PluggableInputDStream, ReceiverInputDStream, DStream}
+import org.apache.spark.streaming.dstream.DStream
import org.apache.spark.streaming.receiver.Receiver
+import org.apache.hadoop.conf.Configuration
/**
* A Java-friendly version of [[org.apache.spark.streaming.StreamingContext]] which is the main
@@ -655,6 +656,7 @@ object JavaStreamingContext {
* @param checkpointPath Checkpoint directory used in an earlier JavaStreamingContext program
* @param factory JavaStreamingContextFactory object to create a new JavaStreamingContext
*/
+ @deprecated("use getOrCreate without JavaStreamingContextFactor", "1.4.0")
def getOrCreate(
checkpointPath: String,
factory: JavaStreamingContextFactory
@@ -676,6 +678,7 @@ object JavaStreamingContext {
* @param hadoopConf Hadoop configuration if necessary for reading from any HDFS compatible
* file system
*/
+ @deprecated("use getOrCreate without JavaStreamingContextFactory", "1.4.0")
def getOrCreate(
checkpointPath: String,
hadoopConf: Configuration,
@@ -700,6 +703,7 @@ object JavaStreamingContext {
* @param createOnError Whether to create a new JavaStreamingContext if there is an
* error in reading checkpoint data.
*/
+ @deprecated("use getOrCreate without JavaStreamingContextFactory", "1.4.0")
def getOrCreate(
checkpointPath: String,
hadoopConf: Configuration,
@@ -713,6 +717,117 @@ object JavaStreamingContext {
}
/**
+ * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext.
+ * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be
+ * recreated from the checkpoint data. If the data does not exist, then the provided `creatingFunc`
+ * will be used to create a JavaStreamingContext.
+ *
+ * @param checkpointPath Checkpoint directory used in an earlier JavaStreamingContext program
+ * @param creatingFunc Function to create a new JavaStreamingContext
+ */
+ def getOrCreate(
+ checkpointPath: String,
+ creatingFunc: JFunction0[JavaStreamingContext]
+ ): JavaStreamingContext = {
+ val ssc = StreamingContext.getOrCreate(checkpointPath, () => {
+ creatingFunc.call().ssc
+ })
+ new JavaStreamingContext(ssc)
+ }
+
+ /**
+ * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext.
+ * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be
+ * recreated from the checkpoint data. If the data does not exist, then the provided `creatingFunc`
+ * will be used to create a JavaStreamingContext.
+ *
+ * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program
+ * @param creatingFunc Function to create a new JavaStreamingContext
+ * @param hadoopConf Hadoop configuration if necessary for reading from any HDFS compatible
+ * file system
+ */
+ def getOrCreate(
+ checkpointPath: String,
+ creatingFunc: JFunction0[JavaStreamingContext],
+ hadoopConf: Configuration
+ ): JavaStreamingContext = {
+ val ssc = StreamingContext.getOrCreate(checkpointPath, () => {
+ creatingFunc.call().ssc
+ }, hadoopConf)
+ new JavaStreamingContext(ssc)
+ }
+
+ /**
+ * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext.
+ * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be
+ * recreated from the checkpoint data. If the data does not exist, then the provided `creatingFunc`
+ * will be used to create a JavaStreamingContext.
+ *
+ * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program
+ * @param creatingFunc Function to create a new JavaStreamingContext
+ * @param hadoopConf Hadoop configuration if necessary for reading from any HDFS compatible
+ * file system
+ * @param createOnError Whether to create a new JavaStreamingContext if there is an
+ * error in reading checkpoint data.
+ */
+ def getOrCreate(
+ checkpointPath: String,
+ creatingFunc: JFunction0[JavaStreamingContext],
+ hadoopConf: Configuration,
+ createOnError: Boolean
+ ): JavaStreamingContext = {
+ val ssc = StreamingContext.getOrCreate(checkpointPath, () => {
+ creatingFunc.call().ssc
+ }, hadoopConf, createOnError)
+ new JavaStreamingContext(ssc)
+ }
+
+ /**
+ * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext.
+ * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be
+ * recreated from the checkpoint data. If the data does not exist, then the provided `creatingFunc`
+ * will be used to create a JavaStreamingContext.
+ *
+ * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program
+ * @param creatingFunc Function to create a new JavaStreamingContext
+ * @param sparkContext SparkContext using which the StreamingContext will be created
+ */
+ def getOrCreate(
+ checkpointPath: String,
+ creatingFunc: JFunction[JavaSparkContext, JavaStreamingContext],
+ sparkContext: JavaSparkContext
+ ): JavaStreamingContext = {
+ val ssc = StreamingContext.getOrCreate(checkpointPath, (sparkContext: SparkContext) => {
+ creatingFunc.call(new JavaSparkContext(sparkContext)).ssc
+ }, sparkContext.sc)
+ new JavaStreamingContext(ssc)
+ }
+
+ /**
+ * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext.
+ * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be
+ * recreated from the checkpoint data. If the data does not exist, then the provided `creatingFunc`
+ * will be used to create a JavaStreamingContext.
+ *
+ * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program
+ * @param creatingFunc Function to create a new JavaStreamingContext
+ * @param sparkContext SparkContext using which the StreamingContext will be created
+ * @param createOnError Whether to create a new JavaStreamingContext if there is an
+ * error in reading checkpoint data.
+ */
+ def getOrCreate(
+ checkpointPath: String,
+ creatingFunc: JFunction[JavaSparkContext, JavaStreamingContext],
+ sparkContext: JavaSparkContext,
+ createOnError: Boolean
+ ): JavaStreamingContext = {
+ val ssc = StreamingContext.getOrCreate(checkpointPath, (sparkContext: SparkContext) => {
+ creatingFunc.call(new JavaSparkContext(sparkContext)).ssc
+ }, sparkContext.sc, createOnError)
+ new JavaStreamingContext(ssc)
+ }
+
+ /**
* Find the JAR from which a given class was loaded, to make it easy for users to pass
* their JARs to StreamingContext.
*/
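For the new Java-friendly overloads, here is a minimal sketch, written in Scala against the Java API (an equivalent Java program would use an anonymous Function class the same way); the app name, checkpoint path, and batch interval are illustrative assumptions:

    import org.apache.spark.SparkConf
    import org.apache.spark.api.java.JavaSparkContext
    import org.apache.spark.api.java.function.{Function => JFunction}
    import org.apache.spark.streaming.Seconds
    import org.apache.spark.streaming.api.java.JavaStreamingContext

    // Pre-created JavaSparkContext, as in a managed environment.
    val javaSc = new JavaSparkContext(new SparkConf().setAppName("example").setMaster("local[2]"))

    // Called only when no readable checkpoint data exists.
    val creatingFunc = new JFunction[JavaSparkContext, JavaStreamingContext] {
      override def call(context: JavaSparkContext): JavaStreamingContext = {
        val newJssc = new JavaStreamingContext(context, Seconds(1))
        newJssc.checkpoint("/tmp/streaming-checkpoint")
        // set up DStreams and output operations here
        newJssc
      }
    }

    val jssc = JavaStreamingContext.getOrCreate("/tmp/streaming-checkpoint", creatingFunc, javaSc)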
diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
index 90340753a4..cb2e8380b4 100644
--- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
+++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
@@ -22,10 +22,12 @@ import java.lang.Iterable;
import java.nio.charset.Charset;
import java.util.*;
+import org.apache.commons.lang.mutable.MutableBoolean;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
+
import scala.Tuple2;
import org.junit.Assert;
@@ -45,6 +47,7 @@ import org.apache.spark.api.java.function.*;
import org.apache.spark.storage.StorageLevel;
import org.apache.spark.streaming.api.java.*;
import org.apache.spark.util.Utils;
+import org.apache.spark.SparkConf;
// The test suite itself is Serializable so that anonymous Function implementations can be
// serialized, as an alternative to converting these anonymous classes to static inner classes;
@@ -929,7 +932,7 @@ public class JavaAPISuite extends LocalJavaStreamingContext implements Serializa
public Tuple2<Integer, String> call(Tuple2<String, Integer> in) throws Exception {
return in.swap();
}
- });
+ });
JavaTestUtils.attachTestOutputStream(reversed);
List<List<Tuple2<Integer, String>>> result = JavaTestUtils.runStreams(ssc, 2, 2);
@@ -987,12 +990,12 @@ public class JavaAPISuite extends LocalJavaStreamingContext implements Serializa
JavaDStream<Tuple2<String, Integer>> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
JavaPairDStream<String, Integer> pairStream = JavaPairDStream.fromJavaDStream(stream);
JavaDStream<Integer> reversed = pairStream.map(
- new Function<Tuple2<String, Integer>, Integer>() {
- @Override
- public Integer call(Tuple2<String, Integer> in) throws Exception {
- return in._2();
- }
- });
+ new Function<Tuple2<String, Integer>, Integer>() {
+ @Override
+ public Integer call(Tuple2<String, Integer> in) throws Exception {
+ return in._2();
+ }
+ });
JavaTestUtils.attachTestOutputStream(reversed);
List<List<Integer>> result = JavaTestUtils.runStreams(ssc, 2, 2);
@@ -1123,7 +1126,7 @@ public class JavaAPISuite extends LocalJavaStreamingContext implements Serializa
JavaPairDStream<String, Integer> combined = pairStream.<Integer>combineByKey(
new Function<Integer, Integer>() {
- @Override
+ @Override
public Integer call(Integer i) throws Exception {
return i;
}
@@ -1144,14 +1147,14 @@ public class JavaAPISuite extends LocalJavaStreamingContext implements Serializa
Arrays.asList("hello"));
List<List<Tuple2<String, Long>>> expected = Arrays.asList(
- Arrays.asList(
- new Tuple2<String, Long>("hello", 1L),
- new Tuple2<String, Long>("world", 1L)),
- Arrays.asList(
- new Tuple2<String, Long>("hello", 1L),
- new Tuple2<String, Long>("moon", 1L)),
- Arrays.asList(
- new Tuple2<String, Long>("hello", 1L)));
+ Arrays.asList(
+ new Tuple2<String, Long>("hello", 1L),
+ new Tuple2<String, Long>("world", 1L)),
+ Arrays.asList(
+ new Tuple2<String, Long>("hello", 1L),
+ new Tuple2<String, Long>("moon", 1L)),
+ Arrays.asList(
+ new Tuple2<String, Long>("hello", 1L)));
JavaDStream<String> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
JavaPairDStream<String, Long> counted = stream.countByValue();
@@ -1249,17 +1252,17 @@ public class JavaAPISuite extends LocalJavaStreamingContext implements Serializa
JavaPairDStream<String, Integer> updated = pairStream.updateStateByKey(
new Function2<List<Integer>, Optional<Integer>, Optional<Integer>>() {
- @Override
- public Optional<Integer> call(List<Integer> values, Optional<Integer> state) {
- int out = 0;
- if (state.isPresent()) {
- out = out + state.get();
- }
- for (Integer v: values) {
- out = out + v;
+ @Override
+ public Optional<Integer> call(List<Integer> values, Optional<Integer> state) {
+ int out = 0;
+ if (state.isPresent()) {
+ out = out + state.get();
+ }
+ for (Integer v : values) {
+ out = out + v;
+ }
+ return Optional.of(out);
}
- return Optional.of(out);
- }
});
JavaTestUtils.attachTestOutputStream(updated);
List<List<Tuple2<String, Integer>>> result = JavaTestUtils.runStreams(ssc, 3, 3);
@@ -1292,17 +1295,17 @@ public class JavaAPISuite extends LocalJavaStreamingContext implements Serializa
JavaPairDStream<String, Integer> updated = pairStream.updateStateByKey(
new Function2<List<Integer>, Optional<Integer>, Optional<Integer>>() {
- @Override
- public Optional<Integer> call(List<Integer> values, Optional<Integer> state) {
- int out = 0;
- if (state.isPresent()) {
- out = out + state.get();
- }
- for (Integer v: values) {
- out = out + v;
+ @Override
+ public Optional<Integer> call(List<Integer> values, Optional<Integer> state) {
+ int out = 0;
+ if (state.isPresent()) {
+ out = out + state.get();
+ }
+ for (Integer v : values) {
+ out = out + v;
+ }
+ return Optional.of(out);
}
- return Optional.of(out);
- }
}, new HashPartitioner(1), initialRDD);
JavaTestUtils.attachTestOutputStream(updated);
List<List<Tuple2<String, Integer>>> result = JavaTestUtils.runStreams(ssc, 3, 3);
@@ -1328,7 +1331,7 @@ public class JavaAPISuite extends LocalJavaStreamingContext implements Serializa
JavaPairDStream<String, Integer> reduceWindowed =
pairStream.reduceByKeyAndWindow(new IntegerSum(), new IntegerDifference(),
- new Duration(2000), new Duration(1000));
+ new Duration(2000), new Duration(1000));
JavaTestUtils.attachTestOutputStream(reduceWindowed);
List<List<Tuple2<String, Integer>>> result = JavaTestUtils.runStreams(ssc, 3, 3);
@@ -1707,6 +1710,74 @@ public class JavaAPISuite extends LocalJavaStreamingContext implements Serializa
Utils.deleteRecursively(tempDir);
}
+ @SuppressWarnings("unchecked")
+ @Test
+ public void testContextGetOrCreate() throws InterruptedException {
+
+ final SparkConf conf = new SparkConf()
+ .setMaster("local[2]")
+ .setAppName("test")
+ .set("newContext", "true");
+
+ File emptyDir = Files.createTempDir();
+ emptyDir.deleteOnExit();
+ StreamingContextSuite contextSuite = new StreamingContextSuite();
+ String corruptedCheckpointDir = contextSuite.createCorruptedCheckpoint();
+ String checkpointDir = contextSuite.createValidCheckpoint();
+
+ // Function to create JavaStreamingContext without any output operations
+ // (used to detect the new context)
+ final MutableBoolean newContextCreated = new MutableBoolean(false);
+ Function0<JavaStreamingContext> creatingFunc = new Function0<JavaStreamingContext>() {
+ public JavaStreamingContext call() {
+ newContextCreated.setValue(true);
+ return new JavaStreamingContext(conf, Seconds.apply(1));
+ }
+ };
+
+ newContextCreated.setValue(false);
+ ssc = JavaStreamingContext.getOrCreate(emptyDir.getAbsolutePath(), creatingFunc);
+ Assert.assertTrue("new context not created", newContextCreated.isTrue());
+ ssc.stop();
+
+ newContextCreated.setValue(false);
+ ssc = JavaStreamingContext.getOrCreate(corruptedCheckpointDir, creatingFunc,
+ new org.apache.hadoop.conf.Configuration(), true);
+ Assert.assertTrue("new context not created", newContextCreated.isTrue());
+ ssc.stop();
+
+ newContextCreated.setValue(false);
+ ssc = JavaStreamingContext.getOrCreate(checkpointDir, creatingFunc,
+ new org.apache.hadoop.conf.Configuration());
+ Assert.assertTrue("old context not recovered", newContextCreated.isFalse());
+ ssc.stop();
+
+ // Function to create JavaStreamingContext using existing JavaSparkContext
+ // without any output operations (used to detect the new context)
+ Function<JavaSparkContext, JavaStreamingContext> creatingFunc2 =
+ new Function<JavaSparkContext, JavaStreamingContext>() {
+ public JavaStreamingContext call(JavaSparkContext context) {
+ newContextCreated.setValue(true);
+ return new JavaStreamingContext(context, Seconds.apply(1));
+ }
+ };
+
+ JavaSparkContext sc = new JavaSparkContext(conf);
+ newContextCreated.setValue(false);
+ ssc = JavaStreamingContext.getOrCreate(emptyDir.getAbsolutePath(), creatingFunc2, sc);
+ Assert.assertTrue("new context not created", newContextCreated.isTrue());
+ ssc.stop(false);
+
+ newContextCreated.setValue(false);
+ ssc = JavaStreamingContext.getOrCreate(corruptedCheckpointDir, creatingFunc2, sc, true);
+ Assert.assertTrue("new context not created", newContextCreated.isTrue());
+ ssc.stop(false);
+
+ newContextCreated.setValue(false);
+ ssc = JavaStreamingContext.getOrCreate(checkpointDir, creatingFunc2, sc);
+ Assert.assertTrue("old context not recovered", newContextCreated.isFalse());
+ ssc.stop();
+ }
/* TEST DISABLED: Pending a discussion about checkpoint() semantics with TD
@SuppressWarnings("unchecked")
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
index 54c30440a6..6b0a3f91d4 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
@@ -430,9 +430,8 @@ class CheckpointSuite extends TestSuiteBase {
assert(recordedFiles(ssc) === Seq(1, 2, 3) && batchCounter.getNumStartedBatches === 3)
}
// Wait for a checkpoint to be written
- val fs = new Path(checkpointDir).getFileSystem(ssc.sc.hadoopConfiguration)
eventually(eventuallyTimeout) {
- assert(Checkpoint.getCheckpointFiles(checkpointDir, fs).size === 6)
+ assert(Checkpoint.getCheckpointFiles(checkpointDir).size === 6)
}
ssc.stop()
// Check that we shut down while the third batch was being processed
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
index 58353a5f97..4f193322ad 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
@@ -17,8 +17,10 @@
package org.apache.spark.streaming
+import java.io.File
import java.util.concurrent.atomic.AtomicInteger
+import org.apache.commons.io.FileUtils
import org.scalatest.{Assertions, BeforeAndAfter, FunSuite}
import org.scalatest.concurrent.Timeouts
import org.scalatest.concurrent.Eventually._
@@ -330,6 +332,139 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w
}
}
+ test("getOrCreate") {
+ val conf = new SparkConf().setMaster(master).setAppName(appName)
+
+ // Function to create StreamingContext that has a config to identify it to be new context
+ var newContextCreated = false
+ def creatingFunction(): StreamingContext = {
+ newContextCreated = true
+ new StreamingContext(conf, batchDuration)
+ }
+
+ // Call ssc.stop after a body of code
+ def testGetOrCreate(body: => Unit): Unit = {
+ newContextCreated = false
+ try {
+ body
+ } finally {
+ if (ssc != null) {
+ ssc.stop()
+ }
+ ssc = null
+ }
+ }
+
+ val emptyPath = Utils.createTempDir().getAbsolutePath()
+
+ // getOrCreate should create new context with empty path
+ testGetOrCreate {
+ ssc = StreamingContext.getOrCreate(emptyPath, creatingFunction _)
+ assert(ssc != null, "no context created")
+ assert(newContextCreated, "new context not created")
+ }
+
+ val corrutedCheckpointPath = createCorruptedCheckpoint()
+
+ // getOrCreate should throw exception with fake checkpoint file and createOnError = false
+ intercept[Exception] {
+ ssc = StreamingContext.getOrCreate(corrutedCheckpointPath, creatingFunction _)
+ }
+
+ // getOrCreate should throw exception with fake checkpoint file
+ intercept[Exception] {
+ ssc = StreamingContext.getOrCreate(
+ corrutedCheckpointPath, creatingFunction _, createOnError = false)
+ }
+
+ // getOrCreate should create new context with fake checkpoint file and createOnError = true
+ testGetOrCreate {
+ ssc = StreamingContext.getOrCreate(
+ corrutedCheckpointPath, creatingFunction _, createOnError = true)
+ assert(ssc != null, "no context created")
+ assert(newContextCreated, "new context not created")
+ }
+
+ val checkpointPath = createValidCheckpoint()
+
+ // getOrCreate should recover context with checkpoint path, and recover old configuration
+ testGetOrCreate {
+ ssc = StreamingContext.getOrCreate(checkpointPath, creatingFunction _)
+ assert(ssc != null, "no context created")
+ assert(!newContextCreated, "old context not recovered")
+ assert(ssc.conf.get("someKey") === "someValue")
+ }
+ }
+
+ test("getOrCreate with existing SparkContext") {
+ val conf = new SparkConf().setMaster(master).setAppName(appName)
+ sc = new SparkContext(conf)
+
+ // Function to create StreamingContext that has a config to identify it to be new context
+ var newContextCreated = false
+ def creatingFunction(sparkContext: SparkContext): StreamingContext = {
+ newContextCreated = true
+ new StreamingContext(sparkContext, batchDuration)
+ }
+
+ // Call ssc.stop(stopSparkContext = false) after a body of code
+ def testGetOrCreate(body: => Unit): Unit = {
+ newContextCreated = false
+ try {
+ body
+ } finally {
+ if (ssc != null) {
+ ssc.stop(stopSparkContext = false)
+ }
+ ssc = null
+ }
+ }
+
+ val emptyPath = Utils.createTempDir().getAbsolutePath()
+
+ // getOrCreate should create new context with empty path
+ testGetOrCreate {
+ ssc = StreamingContext.getOrCreate(emptyPath, creatingFunction _, sc, createOnError = true)
+ assert(ssc != null, "no context created")
+ assert(newContextCreated, "new context not created")
+ assert(ssc.sparkContext === sc, "new StreamingContext does not use existing SparkContext")
+ }
+
+ val corrutedCheckpointPath = createCorruptedCheckpoint()
+
+ // getOrCreate should throw exception with fake checkpoint file and createOnError = false
+ intercept[Exception] {
+ ssc = StreamingContext.getOrCreate(corrutedCheckpointPath, creatingFunction _, sc)
+ }
+
+ // getOrCreate should throw exception with fake checkpoint file
+ intercept[Exception] {
+ ssc = StreamingContext.getOrCreate(
+ corrutedCheckpointPath, creatingFunction _, sc, createOnError = false)
+ }
+
+ // getOrCreate should create new context with fake checkpoint file and createOnError = true
+ testGetOrCreate {
+ ssc = StreamingContext.getOrCreate(
+ corrutedCheckpointPath, creatingFunction _, sc, createOnError = true)
+ assert(ssc != null, "no context created")
+ assert(newContextCreated, "new context not created")
+ assert(ssc.sparkContext === sc, "new StreamingContext does not use existing SparkContext")
+ }
+
+ val checkpointPath = createValidCheckpoint()
+
+ // StreamingContext.getOrCreate should recover context with checkpoint path
+ testGetOrCreate {
+ ssc = StreamingContext.getOrCreate(checkpointPath, creatingFunction _, sc)
+ assert(ssc != null, "no context created")
+ assert(!newContextCreated, "old context not recovered")
+ assert(ssc.sparkContext === sc, "new StreamingContext does not use existing SparkContext")
+ assert(!ssc.conf.contains("someKey"),
+ "recovered StreamingContext unexpectedly has old config")
+ }
+ }
+
test("DStream and generated RDD creation sites") {
testPackage.test()
}
@@ -339,6 +474,30 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w
val inputStream = new TestInputStream(s, input, 1)
inputStream
}
+
+ def createValidCheckpoint(): String = {
+ val testDirectory = Utils.createTempDir().getAbsolutePath()
+ val checkpointDirectory = Utils.createTempDir().getAbsolutePath()
+ val conf = new SparkConf().setMaster(master).setAppName(appName)
+ conf.set("someKey", "someValue")
+ ssc = new StreamingContext(conf, batchDuration)
+ ssc.checkpoint(checkpointDirectory)
+ ssc.textFileStream(testDirectory).foreachRDD { rdd => rdd.count() }
+ ssc.start()
+ eventually(timeout(10000 millis)) {
+ assert(Checkpoint.getCheckpointFiles(checkpointDirectory).size > 1)
+ }
+ ssc.stop()
+ checkpointDirectory
+ }
+
+ def createCorruptedCheckpoint(): String = {
+ val checkpointDirectory = Utils.createTempDir().getAbsolutePath()
+ val fakeCheckpointFile = Checkpoint.checkpointFile(checkpointDirectory, Time(1000))
+ FileUtils.write(new File(fakeCheckpointFile.toString()), "blablabla")
+ assert(Checkpoint.getCheckpointFiles(checkpointDirectory).nonEmpty)
+ checkpointDirectory
+ }
}
class TestException(msg: String) extends Exception(msg)