-rw-r--r--  core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala      |  2
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala             | 70
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala       | 26
-rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala  | 27
4 files changed, 89 insertions(+), 36 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala
index 5abfa467c0..bb5db54553 100644
--- a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala
@@ -27,7 +27,7 @@ import scala.util.control.NonFatal
import org.apache.spark.Logging
-private[serializer] object SerializationDebugger extends Logging {
+private[spark] object SerializationDebugger extends Logging {
/**
* Improve the given NotSerializableException with the serialization path leading from the given
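Note: widening SerializationDebugger from private[serializer] to private[spark] is what lets the streaming module (below) turn a bare NotSerializableException into a readable serialization path. A minimal sketch of that pattern, assuming code compiled inside the org.apache.spark namespace; describeNonSerializable is an illustrative helper, not part of this change:

import java.io.NotSerializableException
import org.apache.spark.serializer.SerializationDebugger

// Rethrow with the reference path that SerializationDebugger.find reports,
// one "\t- ..." entry per hop in the object graph.
def describeNonSerializable(obj: AnyRef): Nothing = {
  val path = SerializationDebugger.find(obj)
  throw new NotSerializableException(
    "Object is not serializable\nSerialization stack:\n" +
      path.map("\t- " + _).mkString("\n"))
}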
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
index 7bfae253c3..d8dc4e4101 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
@@ -102,6 +102,44 @@ object Checkpoint extends Logging {
Seq.empty
}
}
+
+ /** Serialize the checkpoint, or throw any exception that occurs */
+ def serialize(checkpoint: Checkpoint, conf: SparkConf): Array[Byte] = {
+ val compressionCodec = CompressionCodec.createCodec(conf)
+ val bos = new ByteArrayOutputStream()
+ val zos = compressionCodec.compressedOutputStream(bos)
+ val oos = new ObjectOutputStream(zos)
+ Utils.tryWithSafeFinally {
+ oos.writeObject(checkpoint)
+ } {
+ oos.close()
+ }
+ bos.toByteArray
+ }
+
+ /** Deserialize a checkpoint from the input stream, or throw any exception that occurs */
+ def deserialize(inputStream: InputStream, conf: SparkConf): Checkpoint = {
+ val compressionCodec = CompressionCodec.createCodec(conf)
+ var ois: ObjectInputStreamWithLoader = null
+ Utils.tryWithSafeFinally {
+
+ // ObjectInputStream uses the last defined user-defined class loader in the stack
+ // to find classes, which may be the wrong class loader. Hence, an inherited version
+ // of ObjectInputStream is used to explicitly use the current thread's default class
+ // loader to find and load classes. This is a well-known Java issue and has popped up
+ // in other places (e.g., http://jira.codehaus.org/browse/GROOVY-1627)
+ val zis = compressionCodec.compressedInputStream(inputStream)
+ ois = new ObjectInputStreamWithLoader(zis,
+ Thread.currentThread().getContextClassLoader)
+ val cp = ois.readObject.asInstanceOf[Checkpoint]
+ cp.validate()
+ cp
+ } {
+ if (ois != null) {
+ ois.close()
+ }
+ }
+ }
}
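For readers unfamiliar with the helper used in serialize/deserialize above: Utils.tryWithSafeFinally runs a body plus a cleanup block, keeping the body's exception as the primary one even if the cleanup also throws. A rough plain try/finally equivalent of the serialize body, without compression and without Spark helpers (an illustrative sketch, not the actual implementation):

import java.io.{ByteArrayOutputStream, ObjectOutputStream}

// Serialize any object to bytes; close() always runs, though unlike
// tryWithSafeFinally a failure in close() here would mask the original error.
def serializeUncompressed(obj: AnyRef): Array[Byte] = {
  val bos = new ByteArrayOutputStream()
  val oos = new ObjectOutputStream(bos)
  try {
    oos.writeObject(obj)
  } finally {
    oos.close()
  }
  bos.toByteArray
}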
@@ -189,17 +227,10 @@ class CheckpointWriter(
}
def write(checkpoint: Checkpoint, clearCheckpointDataLater: Boolean) {
- val bos = new ByteArrayOutputStream()
- val zos = compressionCodec.compressedOutputStream(bos)
- val oos = new ObjectOutputStream(zos)
- Utils.tryWithSafeFinally {
- oos.writeObject(checkpoint)
- } {
- oos.close()
- }
try {
+ val bytes = Checkpoint.serialize(checkpoint, conf)
executor.execute(new CheckpointWriteHandler(
- checkpoint.checkpointTime, bos.toByteArray, clearCheckpointDataLater))
+ checkpoint.checkpointTime, bytes, clearCheckpointDataLater))
logDebug("Submitted checkpoint of time " + checkpoint.checkpointTime + " writer queue")
} catch {
case rej: RejectedExecutionException =>
@@ -264,25 +295,8 @@ object CheckpointReader extends Logging {
checkpointFiles.foreach(file => {
logInfo("Attempting to load checkpoint from file " + file)
try {
- var ois: ObjectInputStreamWithLoader = null
- var cp: Checkpoint = null
- Utils.tryWithSafeFinally {
- val fis = fs.open(file)
- // ObjectInputStream uses the last defined user-defined class loader in the stack
- // to find classes, which maybe the wrong class loader. Hence, a inherited version
- // of ObjectInputStream is used to explicitly use the current thread's default class
- // loader to find and load classes. This is a well know Java issue and has popped up
- // in other places (e.g., http://jira.codehaus.org/browse/GROOVY-1627)
- val zis = compressionCodec.compressedInputStream(fis)
- ois = new ObjectInputStreamWithLoader(zis,
- Thread.currentThread().getContextClassLoader)
- cp = ois.readObject.asInstanceOf[Checkpoint]
- } {
- if (ois != null) {
- ois.close()
- }
- }
- cp.validate()
+ val fis = fs.open(file)
+ val cp = Checkpoint.deserialize(fis, conf)
logInfo("Checkpoint successfully loaded from file " + file)
logInfo("Checkpoint was generated at time " + cp.checkpointTime)
return Some(cp)
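Taken together, serialize and deserialize give CheckpointWriter, CheckpointReader, and the new StreamingContext.validate check one shared code path. A minimal round-trip sketch, assuming it lives in the org.apache.spark.streaming package (Checkpoint is package-private); ssc is an already set-up StreamingContext and Time(0) is a placeholder checkpoint time:

import java.io.ByteArrayInputStream

def roundTrip(ssc: StreamingContext): Checkpoint = {
  val conf = ssc.sparkContext.getConf
  // Compress and write the checkpoint to bytes, then read it back
  // (deserialize also calls validate() on the restored checkpoint).
  val bytes = Checkpoint.serialize(new Checkpoint(ssc, Time(0)), conf)
  Checkpoint.deserialize(new ByteArrayInputStream(bytes), conf)
}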
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
index fe614c4be5..95063692e1 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
@@ -17,7 +17,7 @@
package org.apache.spark.streaming
-import java.io.InputStream
+import java.io.{InputStream, NotSerializableException}
import java.util.concurrent.atomic.{AtomicInteger, AtomicReference}
import scala.collection.Map
@@ -35,6 +35,7 @@ import org.apache.spark._
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.input.FixedLengthBinaryInputFormat
import org.apache.spark.rdd.{RDD, RDDOperationScope}
+import org.apache.spark.serializer.SerializationDebugger
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.StreamingContextState._
import org.apache.spark.streaming.dstream._
@@ -235,6 +236,10 @@ class StreamingContext private[streaming] (
}
}
+ private[streaming] def isCheckpointingEnabled: Boolean = {
+ checkpointDir != null
+ }
+
private[streaming] def initialCheckpoint: Checkpoint = {
if (isCheckpointPresent) cp_ else null
}
@@ -523,11 +528,26 @@ class StreamingContext private[streaming] (
assert(graph != null, "Graph is null")
graph.validate()
- assert(
- checkpointDir == null || checkpointDuration != null,
+ require(
+ !isCheckpointingEnabled || checkpointDuration != null,
"Checkpoint directory has been set, but the graph checkpointing interval has " +
"not been set. Please use StreamingContext.checkpoint() to set the interval."
)
+
+ // Verify whether the DStream checkpoint is serializable
+ if (isCheckpointingEnabled) {
+ val checkpoint = new Checkpoint(this, Time.apply(0))
+ try {
+ Checkpoint.serialize(checkpoint, conf)
+ } catch {
+ case e: NotSerializableException =>
+ throw new NotSerializableException(
+ "DStream checkpointing has been enabled but the DStreams with their functions " +
+ "are not serializable\nSerialization stack:\n" +
+ SerializationDebugger.find(checkpoint).map("\t- " + _).mkString("\n")
+ )
+ }
+ }
}
/**
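The new check in validate() means mistakes like the one below now surface at start(), not when the first checkpoint is written. A hypothetical user program for illustration (MyApp, the port, and the paths are placeholders):

import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}

class MyApp {  // does not extend Serializable
  def run(): Unit = {
    val conf = new SparkConf().setMaster("local[2]").setAppName("checkpoint-demo")
    val ssc = new StreamingContext(conf, Seconds(1))
    ssc.checkpoint("/tmp/demo-checkpoint")
    ssc.socketTextStream("localhost", 9999).foreachRDD { rdd =>
      // Referencing `this` pulls the non-serializable MyApp instance into the
      // closure, so start() now fails fast with a NotSerializableException
      // that carries the serialization stack.
      println(s"$this saw ${rdd.count()} records")
    }
    ssc.start()
  }
}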
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
index 4b12affbb0..3a958bf3a3 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
@@ -17,21 +17,21 @@
package org.apache.spark.streaming
-import java.io.File
+import java.io.{File, NotSerializableException}
import java.util.concurrent.atomic.AtomicInteger
import org.apache.commons.io.FileUtils
-import org.scalatest.{Assertions, BeforeAndAfter, FunSuite}
-import org.scalatest.concurrent.Timeouts
import org.scalatest.concurrent.Eventually._
+import org.scalatest.concurrent.Timeouts
import org.scalatest.exceptions.TestFailedDueToTimeoutException
import org.scalatest.time.SpanSugar._
+import org.scalatest.{Assertions, BeforeAndAfter, FunSuite}
-import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.dstream.DStream
import org.apache.spark.streaming.receiver.Receiver
import org.apache.spark.util.Utils
+import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException}
class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts with Logging {
@@ -132,6 +132,25 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w
}
}
+ test("start with non-seriazable DStream checkpoints") {
+ val checkpointDir = Utils.createTempDir()
+ ssc = new StreamingContext(conf, batchDuration)
+ ssc.checkpoint(checkpointDir.getAbsolutePath)
+ addInputStream(ssc).foreachRDD { rdd =>
+ // Refer to this.appName from inside closure so that this closure refers to
+ // the instance of StreamingContextSuite, and is therefore not serializable
+ rdd.count() + appName
+ }
+
+ // Test whether start() fails early when checkpointing is enabled
+ val exception = intercept[NotSerializableException] {
+ ssc.start()
+ }
+ assert(exception.getMessage().contains("DStreams with their functions are not serializable"))
+ assert(ssc.getState() !== StreamingContextState.ACTIVE)
+ assert(StreamingContext.getActive().isEmpty)
+ }
+
test("start multiple times") {
ssc = new StreamingContext(master, appName, batchDuration)
addInputStream(ssc).register()
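For contrast with the new test above, a hypothetical companion test (not part of this change) using the same suite helpers: a closure that only touches its argument serializes cleanly, so start() proceeds past the new validation.

test("start with serializable DStream checkpoints") {
  val checkpointDir = Utils.createTempDir()
  ssc = new StreamingContext(conf, batchDuration)
  ssc.checkpoint(checkpointDir.getAbsolutePath)
  // No reference to the enclosing suite instance, so the checkpoint serializes.
  addInputStream(ssc).foreachRDD { rdd => rdd.count() }
  ssc.start()
  assert(ssc.getState() === StreamingContextState.ACTIVE)
  ssc.stop()
}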