path: root/python/pyspark/streaming/tests.py
author      Shixiong Zhu <shixiong@databricks.com>    2015-11-25 11:47:21 -0800
committer   Tathagata Das <tathagata.das1565@gmail.com>    2015-11-25 11:47:21 -0800
commit      d29e2ef4cf43c7f7c5aa40d305cf02be44ce19e0 (patch)
tree        8ec59422678c3c59da4eb08828d613595236fcfb /python/pyspark/streaming/tests.py
parent      88875d9413ec7d64a88d40857ffcf97b5853a7f2 (diff)
[SPARK-11935][PYSPARK] Send the Python exceptions in TransformFunction and TransformFunctionSerializer to Java
The Python exception traceback raised in TransformFunction and TransformFunctionSerializer is not sent back to Java; Py4J just throws a very generic exception, which is hard to debug. This PR adds a `getFailure` method so the failure message can be retrieved on the Java side.

Author: Shixiong Zhu <shixiong@databricks.com>

Closes #9922 from zsxwing/SPARK-11935.
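For readers skimming the tests below, the capture-and-report pattern they exercise can be sketched roughly as follows. This is a minimal, self-contained illustration only, assuming a Py4J-style callback object; the class name FailureReportingCallback is invented for illustration, and the real code lives in python/pyspark/streaming/util.py (TransformFunction / TransformFunctionSerializer), whose details may differ:

    import traceback


    class FailureReportingCallback(object):
        """Hypothetical stand-in for a Py4J callback such as TransformFunction."""

        def __init__(self, func):
            self.func = func
            self.failure = None

        def call(self, *args):
            # Clear any failure left over from a previous batch, then run the user function.
            self.failure = None
            try:
                return self.func(*args)
            except:
                # Record the full Python traceback so the JVM side can fetch a useful
                # error message instead of a generic Py4J exception.
                self.failure = traceback.format_exc()

        def getFailure(self):
            # Called from the Java side after `call` returns to check for a Python error.
            return self.failure


    if __name__ == "__main__":
        cb = FailureReportingCallback(lambda x: 1 / x)
        cb.call(0)  # raises ZeroDivisionError inside `call`
        assert "ZeroDivisionError" in cb.getFailure()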
Diffstat (limited to 'python/pyspark/streaming/tests.py')
-rw-r--r--    python/pyspark/streaming/tests.py    82
1 file changed, 81 insertions(+), 1 deletion(-)
diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py
index a0e0267caf..d380d697bc 100644
--- a/python/pyspark/streaming/tests.py
+++ b/python/pyspark/streaming/tests.py
@@ -404,17 +404,69 @@ class BasicOperationTests(PySparkStreamingTestCase):
         self._test_func(input, func, expected)
 
     def test_failed_func(self):
+        # Test failure in
+        # TransformFunction.apply(rdd: Option[RDD[_]], time: Time)
         input = [self.sc.parallelize([d], 1) for d in range(4)]
         input_stream = self.ssc.queueStream(input)
 
         def failed_func(i):
-            raise ValueError("failed")
+            raise ValueError("This is a special error")
 
         input_stream.map(failed_func).pprint()
         self.ssc.start()
         try:
             self.ssc.awaitTerminationOrTimeout(10)
         except:
+            import traceback
+            failure = traceback.format_exc()
+            self.assertTrue("This is a special error" in failure)
+            return
+
+        self.fail("a failed func should throw an error")
+
+    def test_failed_func2(self):
+        # Test failure in
+        # TransformFunction.apply(rdd: Option[RDD[_]], rdd2: Option[RDD[_]], time: Time)
+        input = [self.sc.parallelize([d], 1) for d in range(4)]
+        input_stream1 = self.ssc.queueStream(input)
+        input_stream2 = self.ssc.queueStream(input)
+
+        def failed_func(rdd1, rdd2):
+            raise ValueError("This is a special error")
+
+        input_stream1.transformWith(failed_func, input_stream2, True).pprint()
+        self.ssc.start()
+        try:
+            self.ssc.awaitTerminationOrTimeout(10)
+        except:
+            import traceback
+            failure = traceback.format_exc()
+            self.assertTrue("This is a special error" in failure)
+            return
+
+        self.fail("a failed func should throw an error")
+
+    def test_failed_func_with_reseting_failure(self):
+        input = [self.sc.parallelize([d], 1) for d in range(4)]
+        input_stream = self.ssc.queueStream(input)
+
+        def failed_func(i):
+            if i == 1:
+                # Make it fail in the second batch
+                raise ValueError("This is a special error")
+            else:
+                return i
+
+        # We should be able to see the results of the 3rd and 4th batches even if the second batch
+        # fails
+        expected = [[0], [2], [3]]
+        self.assertEqual(expected, self._collect(input_stream.map(failed_func), 3))
+        try:
+            self.ssc.awaitTerminationOrTimeout(10)
+        except:
+            import traceback
+            failure = traceback.format_exc()
+            self.assertTrue("This is a special error" in failure)
             return
 
         self.fail("a failed func should throw an error")
@@ -780,6 +832,34 @@ class CheckpointTests(unittest.TestCase):
         if self.cpd is not None:
             shutil.rmtree(self.cpd)
 
+    def test_transform_function_serializer_failure(self):
+        inputd = tempfile.mkdtemp()
+        self.cpd = tempfile.mkdtemp("test_transform_function_serializer_failure")
+
+        def setup():
+            conf = SparkConf().set("spark.default.parallelism", 1)
+            sc = SparkContext(conf=conf)
+            ssc = StreamingContext(sc, 0.5)
+
+            # A function that cannot be serialized
+            def process(time, rdd):
+                sc.parallelize(range(1, 10))
+
+            ssc.textFileStream(inputd).foreachRDD(process)
+            return ssc
+
+        self.ssc = StreamingContext.getOrCreate(self.cpd, setup)
+        try:
+            self.ssc.start()
+        except:
+            import traceback
+            failure = traceback.format_exc()
+            self.assertTrue(
+                "It appears that you are attempting to reference SparkContext" in failure)
+            return
+
+        self.fail("using SparkContext in process should fail because it's not Serializable")
+
     def test_get_or_create_and_get_active_or_create(self):
         inputd = tempfile.mkdtemp()
         outputd = tempfile.mkdtemp() + "/"