Diffstat (limited to 'streaming')
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala |  6
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala    |  2
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala   |  7
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala         | 11
-rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala         | 66
5 files changed, 83 insertions(+), 9 deletions(-)
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala
index f396c34758..4eb92dd8b1 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala
@@ -17,9 +17,10 @@
package org.apache.spark.streaming.dstream
+import scala.reflect.ClassTag
+
import org.apache.spark.rdd.RDD
import org.apache.spark.streaming.{Time, StreamingContext}
-import scala.reflect.ClassTag
/**
* An input stream that always returns the same RDD on each timestep. Useful for testing.
@@ -27,6 +28,9 @@ import scala.reflect.ClassTag
class ConstantInputDStream[T: ClassTag](ssc_ : StreamingContext, rdd: RDD[T])
extends InputDStream[T](ssc_) {
+ require(rdd != null,
+   "parameter rdd must not be null; a null RDD leads to NPEs in subsequent transformations")
+
override def start() {}
override def stop() {}
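
With this guard in place, passing a null RDD fails fast at construction time with an IllegalArgumentException (thrown by require) instead of surfacing later as an NPE inside a downstream transformation. A minimal sketch of the new behavior, assuming a local Spark setup; the master URL, app name, and batch duration are arbitrary illustration choices:

    import org.apache.spark.SparkConf
    import org.apache.spark.rdd.RDD
    import org.apache.spark.streaming.{Seconds, StreamingContext}
    import org.apache.spark.streaming.dstream.ConstantInputDStream

    object NullRddGuardSketch {
      def main(args: Array[String]): Unit = {
        val ssc = new StreamingContext(
          new SparkConf().setMaster("local[2]").setAppName("null-rdd-guard"), Seconds(1))
        try {
          // The new require(...) fails here, at construction time, rather than
          // throwing an NPE later inside a downstream transformation.
          new ConstantInputDStream[Int](ssc, null.asInstanceOf[RDD[Int]])
        } catch {
          case e: IllegalArgumentException => println(s"failed fast: ${e.getMessage}")
        } finally {
          ssc.stop(stopSparkContext = true)
        }
      }
    }
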
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala
index a2685046e0..cd07364637 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala
@@ -62,7 +62,7 @@ class QueueInputDStream[T: ClassTag](
} else if (defaultRDD != null) {
Some(defaultRDD)
} else {
- None
+ Some(ssc.sparkContext.emptyRDD)
}
}
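
After this change, a drained queue with no default RDD yields an empty RDD for the batch rather than None, so downstream actions still observe a (zero-element) batch. A hedged sketch of that behavior through StreamingContext.queueStream; the local master and the 3-second timeout are illustrative assumptions:

    import scala.collection.mutable
    import org.apache.spark.SparkConf
    import org.apache.spark.rdd.RDD
    import org.apache.spark.streaming.{Seconds, StreamingContext}

    object EmptyQueueSketch {
      def main(args: Array[String]): Unit = {
        val ssc = new StreamingContext(
          new SparkConf().setMaster("local[2]").setAppName("empty-queue"), Seconds(1))
        val queue = mutable.Queue[RDD[Int]]() // deliberately left empty, no default RDD
        // With the change above, a drained queue yields an empty RDD per batch
        // instead of None, so foreachRDD still fires with a zero-element RDD.
        ssc.queueStream(queue, oneAtATime = true).foreachRDD { rdd =>
          println(s"batch with ${rdd.count()} element(s)")
        }
        ssc.start()
        ssc.awaitTerminationOrTimeout(3000)
        ssc.stop(stopSparkContext = true)
      }
    }
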
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala
index ab01f47d5c..5eabdf63dc 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala
@@ -20,7 +20,7 @@ package org.apache.spark.streaming.dstream
import scala.reflect.ClassTag
import org.apache.spark.SparkException
-import org.apache.spark.rdd.{PairRDDFunctions, RDD}
+import org.apache.spark.rdd.RDD
import org.apache.spark.streaming.{Duration, Time}
private[streaming]
@@ -39,7 +39,10 @@ class TransformedDStream[U: ClassTag] (
override def slideDuration: Duration = parents.head.slideDuration
override def compute(validTime: Time): Option[RDD[U]] = {
- val parentRDDs = parents.map(_.getOrCompute(validTime).orNull).toSeq
+ val parentRDDs = parents.map { parent => parent.getOrCompute(validTime).getOrElse(
+ // Guard against parent DStreams that return None instead of Some(rdd), to avoid NPEs
+ throw new SparkException(s"Couldn't generate RDD from parent at time $validTime"))
+ }
val transformedRDD = transformFunc(parentRDDs, validTime)
if (transformedRDD == null) {
throw new SparkException("Transform function must not return null. " +
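
The getOrElse(throw ...) replacement turns a missing parent RDD into a descriptive SparkException at compute time, where the old orNull would smuggle a null into the user's transform function. The same pattern shown in isolation, with a hypothetical getOrCompute stand-in instead of real DStreams:

    import org.apache.spark.SparkException

    object ComputeGuardSketch {
      // Hypothetical stand-in for a parent DStream's per-batch result:
      // the second "parent" produces no RDD for this batch.
      def getOrCompute(parentId: Int, validTime: Long): Option[String] =
        if (parentId == 0) Some(s"rdd-$parentId@$validTime") else None

      def main(args: Array[String]): Unit = {
        val validTime = 1000L
        try {
          val parentRDDs = Seq(0, 1).map { parent =>
            // Before the fix, .orNull handed a null to the transform function;
            // now a missing parent RDD fails fast with a descriptive message.
            getOrCompute(parent, validTime).getOrElse(
              throw new SparkException(s"Couldn't generate RDD from parent at time $validTime"))
          }
          println(parentRDDs)
        } catch {
          case e: SparkException => println(s"caught: ${e.getMessage}")
        }
      }
    }
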
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala
index 9405dbaa12..d73ffdfd84 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala
@@ -17,13 +17,14 @@
package org.apache.spark.streaming.dstream
+import scala.collection.mutable.ArrayBuffer
+import scala.reflect.ClassTag
+
+import org.apache.spark.SparkException
import org.apache.spark.streaming.{Duration, Time}
import org.apache.spark.rdd.RDD
import org.apache.spark.rdd.UnionRDD
-import scala.collection.mutable.ArrayBuffer
-import scala.reflect.ClassTag
-
private[streaming]
class UnionDStream[T: ClassTag](parents: Array[DStream[T]])
extends DStream[T](parents.head.ssc) {
@@ -41,8 +42,8 @@ class UnionDStream[T: ClassTag](parents: Array[DStream[T]])
val rdds = new ArrayBuffer[RDD[T]]()
parents.map(_.getOrCompute(validTime)).foreach {
case Some(rdd) => rdds += rdd
- case None => throw new Exception("Could not generate RDD from a parent for unifying at time "
- + validTime)
+ case None => throw new SparkException(
+   s"Could not generate RDD from a parent for unifying at time $validTime")
}
if (rdds.size > 0) {
Some(new UnionRDD(ssc.sc, rdds))
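
Raising SparkException rather than a bare Exception lets callers handle the failure through Spark's own exception type. A small sketch of the Some/None accumulation pattern above, with hypothetical parent results standing in for real DStream output:

    import scala.collection.mutable.ArrayBuffer
    import org.apache.spark.SparkException

    object UnionGuardSketch {
      def main(args: Array[String]): Unit = {
        val validTime = 2000L
        // Hypothetical per-parent batch results: the second parent returned None.
        val parentResults: Seq[Option[String]] = Seq(Some("rdd-a"), None)
        try {
          val rdds = new ArrayBuffer[String]()
          parentResults.foreach {
            case Some(rdd) => rdds += rdd
            case None => throw new SparkException(
              s"Could not generate RDD from a parent for unifying at time $validTime")
          }
          println(s"unified ${rdds.size} RDD(s)")
        } catch {
          // SparkException (rather than a bare Exception) lets callers match on
          // Spark's own exception type instead of a catch-all.
          case e: SparkException => println(s"caught: ${e.getMessage}")
        }
      }
    }
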
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
index 9988f410f0..9d296c6d3e 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
@@ -191,6 +191,20 @@ class BasicOperationsSuite extends TestSuiteBase {
)
}
+ test("union with input stream return None") {
+ val input = Seq(1 to 4, 101 to 104, 201 to 204, null)
+ val output = Seq(1 to 8, 101 to 108, 201 to 208)
+ intercept[SparkException] {
+ testOperation(
+ input,
+ (s: DStream[Int]) => s.union(s.map(_ + 4)),
+ output,
+ input.length,
+ false
+ )
+ }
+ }
+
test("StreamingContext.union") {
val input = Seq(1 to 4, 101 to 104, 201 to 204)
val output = Seq(1 to 12, 101 to 112, 201 to 212)
@@ -224,6 +238,19 @@ class BasicOperationsSuite extends TestSuiteBase {
}
}
+ test("transform with input stream return None") {
+ val input = Seq(1 to 4, 5 to 8, null)
+ intercept[SparkException] {
+ testOperation(
+ input,
+ (r: DStream[Int]) => r.transform(rdd => rdd.map(_.toString)),
+ input.filterNot(_ == null).map(_.map(_.toString)),
+ input.length,
+ false
+ )
+ }
+ }
+
test("transformWith") {
val inputData1 = Seq( Seq("a", "b"), Seq("a", ""), Seq(""), Seq() )
val inputData2 = Seq( Seq("a", "b"), Seq("b", ""), Seq(), Seq("") )
@@ -244,6 +271,27 @@ class BasicOperationsSuite extends TestSuiteBase {
testOperation(inputData1, inputData2, operation, outputData, true)
}
+ test("transformWith with input stream return None") {
+ val inputData1 = Seq( Seq("a", "b"), Seq("a", ""), Seq(""), null )
+ val inputData2 = Seq( Seq("a", "b"), Seq("b", ""), Seq(), null )
+ val outputData = Seq(
+ Seq("a", "b", "a", "b"),
+ Seq("a", "b", "", ""),
+ Seq("")
+ )
+
+ val operation = (s1: DStream[String], s2: DStream[String]) => {
+ s1.transformWith( // RDD.union in transformWith
+ s2,
+ (rdd1: RDD[String], rdd2: RDD[String]) => rdd1.union(rdd2)
+ )
+ }
+
+ intercept[SparkException] {
+ testOperation(inputData1, inputData2, operation, outputData, inputData1.length, true)
+ }
+ }
+
test("StreamingContext.transform") {
val input = Seq(1 to 4, 101 to 104, 201 to 204)
val output = Seq(1 to 12, 101 to 112, 201 to 212)
@@ -260,6 +308,24 @@ class BasicOperationsSuite extends TestSuiteBase {
testOperation(input, operation, output)
}
+ test("StreamingContext.transform with input stream return None") {
+ val input = Seq(1 to 4, 101 to 104, 201 to 204, null)
+ val output = Seq(1 to 12, 101 to 112, 201 to 212)
+
+ // transform over 3 DStreams by taking the union of the 3 RDDs
+ val operation = (s: DStream[Int]) => {
+ s.context.transform(
+ Seq(s, s.map(_ + 4), s.map(_ + 8)), // 3 DStreams
+ (rdds: Seq[RDD[_]], time: Time) =>
+ rdds.head.context.union(rdds.map(_.asInstanceOf[RDD[Int]])) // union of RDDs
+ )
+ }
+
+ intercept[SparkException] {
+ testOperation(input, operation, output, input.length, false)
+ }
+ }
+
test("cogroup") {
val inputData1 = Seq( Seq("a", "a", "b"), Seq("a", ""), Seq(""), Seq() )
val inputData2 = Seq( Seq("a", "a", "b"), Seq("b", ""), Seq(), Seq() )