aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorShixiong Zhu <shixiong@databricks.com>2015-11-25 11:47:21 -0800
committerTathagata Das <tathagata.das1565@gmail.com>2015-11-25 11:47:21 -0800
commitd29e2ef4cf43c7f7c5aa40d305cf02be44ce19e0 (patch)
tree8ec59422678c3c59da4eb08828d613595236fcfb
parent88875d9413ec7d64a88d40857ffcf97b5853a7f2 (diff)
downloadspark-d29e2ef4cf43c7f7c5aa40d305cf02be44ce19e0.tar.gz
spark-d29e2ef4cf43c7f7c5aa40d305cf02be44ce19e0.tar.bz2
spark-d29e2ef4cf43c7f7c5aa40d305cf02be44ce19e0.zip
[SPARK-11935][PYSPARK] Send the Python exceptions in TransformFunction and TransformFunctionSerializer to Java
The Python exception track in TransformFunction and TransformFunctionSerializer is not sent back to Java. Py4j just throws a very general exception, which is hard to debug. This PRs adds `getFailure` method to get the failure message in Java side. Author: Shixiong Zhu <shixiong@databricks.com> Closes #9922 from zsxwing/SPARK-11935.
-rw-r--r--python/pyspark/streaming/tests.py82
-rw-r--r--python/pyspark/streaming/util.py29
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala52
3 files changed, 144 insertions, 19 deletions
diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py
index a0e0267caf..d380d697bc 100644
--- a/python/pyspark/streaming/tests.py
+++ b/python/pyspark/streaming/tests.py
@@ -404,17 +404,69 @@ class BasicOperationTests(PySparkStreamingTestCase):
self._test_func(input, func, expected)
def test_failed_func(self):
+ # Test failure in
+ # TransformFunction.apply(rdd: Option[RDD[_]], time: Time)
input = [self.sc.parallelize([d], 1) for d in range(4)]
input_stream = self.ssc.queueStream(input)
def failed_func(i):
- raise ValueError("failed")
+ raise ValueError("This is a special error")
input_stream.map(failed_func).pprint()
self.ssc.start()
try:
self.ssc.awaitTerminationOrTimeout(10)
except:
+ import traceback
+ failure = traceback.format_exc()
+ self.assertTrue("This is a special error" in failure)
+ return
+
+ self.fail("a failed func should throw an error")
+
+ def test_failed_func2(self):
+ # Test failure in
+ # TransformFunction.apply(rdd: Option[RDD[_]], rdd2: Option[RDD[_]], time: Time)
+ input = [self.sc.parallelize([d], 1) for d in range(4)]
+ input_stream1 = self.ssc.queueStream(input)
+ input_stream2 = self.ssc.queueStream(input)
+
+ def failed_func(rdd1, rdd2):
+ raise ValueError("This is a special error")
+
+ input_stream1.transformWith(failed_func, input_stream2, True).pprint()
+ self.ssc.start()
+ try:
+ self.ssc.awaitTerminationOrTimeout(10)
+ except:
+ import traceback
+ failure = traceback.format_exc()
+ self.assertTrue("This is a special error" in failure)
+ return
+
+ self.fail("a failed func should throw an error")
+
+ def test_failed_func_with_reseting_failure(self):
+ input = [self.sc.parallelize([d], 1) for d in range(4)]
+ input_stream = self.ssc.queueStream(input)
+
+ def failed_func(i):
+ if i == 1:
+ # Make it fail in the second batch
+ raise ValueError("This is a special error")
+ else:
+ return i
+
+ # We should be able to see the results of the 3rd and 4th batches even if the second batch
+ # fails
+ expected = [[0], [2], [3]]
+ self.assertEqual(expected, self._collect(input_stream.map(failed_func), 3))
+ try:
+ self.ssc.awaitTerminationOrTimeout(10)
+ except:
+ import traceback
+ failure = traceback.format_exc()
+ self.assertTrue("This is a special error" in failure)
return
self.fail("a failed func should throw an error")
@@ -780,6 +832,34 @@ class CheckpointTests(unittest.TestCase):
if self.cpd is not None:
shutil.rmtree(self.cpd)
+ def test_transform_function_serializer_failure(self):
+ inputd = tempfile.mkdtemp()
+ self.cpd = tempfile.mkdtemp("test_transform_function_serializer_failure")
+
+ def setup():
+ conf = SparkConf().set("spark.default.parallelism", 1)
+ sc = SparkContext(conf=conf)
+ ssc = StreamingContext(sc, 0.5)
+
+ # A function that cannot be serialized
+ def process(time, rdd):
+ sc.parallelize(range(1, 10))
+
+ ssc.textFileStream(inputd).foreachRDD(process)
+ return ssc
+
+ self.ssc = StreamingContext.getOrCreate(self.cpd, setup)
+ try:
+ self.ssc.start()
+ except:
+ import traceback
+ failure = traceback.format_exc()
+ self.assertTrue(
+ "It appears that you are attempting to reference SparkContext" in failure)
+ return
+
+ self.fail("using SparkContext in process should fail because it's not Serializable")
+
def test_get_or_create_and_get_active_or_create(self):
inputd = tempfile.mkdtemp()
outputd = tempfile.mkdtemp() + "/"
diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py
index 767c732eb9..c7f02bca2a 100644
--- a/python/pyspark/streaming/util.py
+++ b/python/pyspark/streaming/util.py
@@ -38,12 +38,15 @@ class TransformFunction(object):
self.func = func
self.deserializers = deserializers
self._rdd_wrapper = lambda jrdd, ctx, ser: RDD(jrdd, ctx, ser)
+ self.failure = None
def rdd_wrapper(self, func):
self._rdd_wrapper = func
return self
def call(self, milliseconds, jrdds):
+ # Clear the failure
+ self.failure = None
try:
if self.ctx is None:
self.ctx = SparkContext._active_spark_context
@@ -62,9 +65,11 @@ class TransformFunction(object):
r = self.func(t, *rdds)
if r:
return r._jrdd
- except Exception:
- traceback.print_exc()
- raise
+ except:
+ self.failure = traceback.format_exc()
+
+ def getLastFailure(self):
+ return self.failure
def __repr__(self):
return "TransformFunction(%s)" % self.func
@@ -89,22 +94,28 @@ class TransformFunctionSerializer(object):
self.serializer = serializer
self.gateway = gateway or self.ctx._gateway
self.gateway.jvm.PythonDStream.registerSerializer(self)
+ self.failure = None
def dumps(self, id):
+ # Clear the failure
+ self.failure = None
try:
func = self.gateway.gateway_property.pool[id]
return bytearray(self.serializer.dumps((func.func, func.deserializers)))
- except Exception:
- traceback.print_exc()
- raise
+ except:
+ self.failure = traceback.format_exc()
def loads(self, data):
+ # Clear the failure
+ self.failure = None
try:
f, deserializers = self.serializer.loads(bytes(data))
return TransformFunction(self.ctx, f, *deserializers)
- except Exception:
- traceback.print_exc()
- raise
+ except:
+ self.failure = traceback.format_exc()
+
+ def getLastFailure(self):
+ return self.failure
def __repr__(self):
return "TransformFunctionSerializer(%s)" % self.serializer
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
index dfc569451d..994309ddd0 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
@@ -26,6 +26,7 @@ import scala.language.existentials
import py4j.GatewayServer
+import org.apache.spark.SparkException
import org.apache.spark.api.java._
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
@@ -40,6 +41,13 @@ import org.apache.spark.util.Utils
*/
private[python] trait PythonTransformFunction {
def call(time: Long, rdds: JList[_]): JavaRDD[Array[Byte]]
+
+ /**
+ * Get the failure, if any, in the last call to `call`.
+ *
+ * @return the failure message if there was a failure, or `null` if there was no failure.
+ */
+ def getLastFailure: String
}
/**
@@ -48,6 +56,13 @@ private[python] trait PythonTransformFunction {
private[python] trait PythonTransformFunctionSerializer {
def dumps(id: String): Array[Byte]
def loads(bytes: Array[Byte]): PythonTransformFunction
+
+ /**
+ * Get the failure, if any, in the last call to `dumps` or `loads`.
+ *
+ * @return the failure message if there was a failure, or `null` if there was no failure.
+ */
+ def getLastFailure: String
}
/**
@@ -59,18 +74,27 @@ private[python] class TransformFunction(@transient var pfunc: PythonTransformFun
extends function.Function2[JList[JavaRDD[_]], Time, JavaRDD[Array[Byte]]] {
def apply(rdd: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = {
- Option(pfunc.call(time.milliseconds, List(rdd.map(JavaRDD.fromRDD(_)).orNull).asJava))
- .map(_.rdd)
+ val rdds = List(rdd.map(JavaRDD.fromRDD(_)).orNull).asJava
+ Option(callPythonTransformFunction(time.milliseconds, rdds)).map(_.rdd)
}
def apply(rdd: Option[RDD[_]], rdd2: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = {
val rdds = List(rdd.map(JavaRDD.fromRDD(_)).orNull, rdd2.map(JavaRDD.fromRDD(_)).orNull).asJava
- Option(pfunc.call(time.milliseconds, rdds)).map(_.rdd)
+ Option(callPythonTransformFunction(time.milliseconds, rdds)).map(_.rdd)
}
// for function.Function2
def call(rdds: JList[JavaRDD[_]], time: Time): JavaRDD[Array[Byte]] = {
- pfunc.call(time.milliseconds, rdds)
+ callPythonTransformFunction(time.milliseconds, rdds)
+ }
+
+ private def callPythonTransformFunction(time: Long, rdds: JList[_]): JavaRDD[Array[Byte]] = {
+ val resultRDD = pfunc.call(time, rdds)
+ val failure = pfunc.getLastFailure
+ if (failure != null) {
+ throw new SparkException("An exception was raised by Python:\n" + failure)
+ }
+ resultRDD
}
private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
@@ -103,23 +127,33 @@ private[python] object PythonTransformFunctionSerializer {
/*
* Register a serializer from Python, should be called during initialization
*/
- def register(ser: PythonTransformFunctionSerializer): Unit = {
+ def register(ser: PythonTransformFunctionSerializer): Unit = synchronized {
serializer = ser
}
- def serialize(func: PythonTransformFunction): Array[Byte] = {
+ def serialize(func: PythonTransformFunction): Array[Byte] = synchronized {
require(serializer != null, "Serializer has not been registered!")
// get the id of PythonTransformFunction in py4j
val h = Proxy.getInvocationHandler(func.asInstanceOf[Proxy])
val f = h.getClass().getDeclaredField("id")
f.setAccessible(true)
val id = f.get(h).asInstanceOf[String]
- serializer.dumps(id)
+ val results = serializer.dumps(id)
+ val failure = serializer.getLastFailure
+ if (failure != null) {
+ throw new SparkException("An exception was raised by Python:\n" + failure)
+ }
+ results
}
- def deserialize(bytes: Array[Byte]): PythonTransformFunction = {
+ def deserialize(bytes: Array[Byte]): PythonTransformFunction = synchronized {
require(serializer != null, "Serializer has not been registered!")
- serializer.loads(bytes)
+ val pfunc = serializer.loads(bytes)
+ val failure = serializer.getLastFailure
+ if (failure != null) {
+ throw new SparkException("An exception was raised by Python:\n" + failure)
+ }
+ pfunc
}
}