author     Shixiong Zhu <shixiong@databricks.com>          2015-11-25 11:47:21 -0800
committer  Tathagata Das <tathagata.das1565@gmail.com>     2015-11-25 11:47:21 -0800
commit     d29e2ef4cf43c7f7c5aa40d305cf02be44ce19e0 (patch)
tree       8ec59422678c3c59da4eb08828d613595236fcfb /streaming/src/main
parent     88875d9413ec7d64a88d40857ffcf97b5853a7f2 (diff)
[SPARK-11935][PYSPARK] Send the Python exceptions in TransformFunction and TransformFunctionSerializer to Java
The Python exception traceback in TransformFunction and TransformFunctionSerializer is not sent back to Java. Py4j just throws a very general exception, which is hard to debug. This PR adds a `getLastFailure` method so the failure message can be retrieved on the Java side.

Author: Shixiong Zhu <shixiong@databricks.com>

Closes #9922 from zsxwing/SPARK-11935.
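For context, the Scala changes below only help if the Python proxy object actually records its own traceback and exposes it through `getLastFailure`. The Python-side change is outside this diffstat (which is limited to 'streaming/src/main'), so the sketch below only illustrates the Py4J callback pattern the new trait expects; the class layout, attribute names, and constructor here are assumptions, not the exact contents of pyspark/streaming/util.py.

import traceback

class TransformFunction(object):
    # Illustrative Py4J callback implementing the Scala trait
    # PythonTransformFunction; only `call` and `getLastFailure` are
    # dictated by the trait, everything else here is an assumption.

    def __init__(self, func):
        self.func = func      # the user's transform function
        self.failure = None   # traceback of the last failure, if any

    def call(self, milliseconds, jrdds):
        self.failure = None   # reset before every invocation
        try:
            return self.func(milliseconds, jrdds)
        except Exception:
            # Keep the full Python traceback so the JVM can fetch it
            # instead of seeing only a generic Py4J error.
            self.failure = traceback.format_exc()
            return None

    def getLastFailure(self):
        # Called from Scala right after `call`; None maps to `null`,
        # which the Scala side treats as "no failure".
        return self.failure

    class Java:
        implements = ["org.apache.spark.streaming.api.python.PythonTransformFunction"]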
Diffstat (limited to 'streaming/src/main')
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala   52
1 file changed, 43 insertions(+), 9 deletions(-)
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
index dfc569451d..994309ddd0 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
@@ -26,6 +26,7 @@ import scala.language.existentials
 
 import py4j.GatewayServer
 
+import org.apache.spark.SparkException
 import org.apache.spark.api.java._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
@@ -40,6 +41,13 @@ import org.apache.spark.util.Utils
  */
 private[python] trait PythonTransformFunction {
   def call(time: Long, rdds: JList[_]): JavaRDD[Array[Byte]]
+
+  /**
+   * Get the failure, if any, in the last call to `call`.
+   *
+   * @return the failure message if there was a failure, or `null` if there was no failure.
+   */
+  def getLastFailure: String
 }
 
 /**
@@ -48,6 +56,13 @@ private[python] trait PythonTransformFunction
 private[python] trait PythonTransformFunctionSerializer {
   def dumps(id: String): Array[Byte]
   def loads(bytes: Array[Byte]): PythonTransformFunction
+
+  /**
+   * Get the failure, if any, in the last call to `dumps` or `loads`.
+   *
+   * @return the failure message if there was a failure, or `null` if there was no failure.
+   */
+  def getLastFailure: String
 }
 
 /**
@@ -59,18 +74,27 @@ private[python] class TransformFunction(@transient var pfunc: PythonTransformFun
   extends function.Function2[JList[JavaRDD[_]], Time, JavaRDD[Array[Byte]]] {
 
   def apply(rdd: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = {
-    Option(pfunc.call(time.milliseconds, List(rdd.map(JavaRDD.fromRDD(_)).orNull).asJava))
-      .map(_.rdd)
+    val rdds = List(rdd.map(JavaRDD.fromRDD(_)).orNull).asJava
+    Option(callPythonTransformFunction(time.milliseconds, rdds)).map(_.rdd)
   }
 
   def apply(rdd: Option[RDD[_]], rdd2: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = {
     val rdds = List(rdd.map(JavaRDD.fromRDD(_)).orNull, rdd2.map(JavaRDD.fromRDD(_)).orNull).asJava
-    Option(pfunc.call(time.milliseconds, rdds)).map(_.rdd)
+    Option(callPythonTransformFunction(time.milliseconds, rdds)).map(_.rdd)
   }
 
   // for function.Function2
   def call(rdds: JList[JavaRDD[_]], time: Time): JavaRDD[Array[Byte]] = {
-    pfunc.call(time.milliseconds, rdds)
+    callPythonTransformFunction(time.milliseconds, rdds)
+  }
+
+  private def callPythonTransformFunction(time: Long, rdds: JList[_]): JavaRDD[Array[Byte]] = {
+    val resultRDD = pfunc.call(time, rdds)
+    val failure = pfunc.getLastFailure
+    if (failure != null) {
+      throw new SparkException("An exception was raised by Python:\n" + failure)
+    }
+    resultRDD
   }
 
   private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
@@ -103,23 +127,33 @@ private[python] object PythonTransformFunctionSerializer
   /*
    * Register a serializer from Python, should be called during initialization
   */
-  def register(ser: PythonTransformFunctionSerializer): Unit = {
+  def register(ser: PythonTransformFunctionSerializer): Unit = synchronized {
     serializer = ser
   }
 
-  def serialize(func: PythonTransformFunction): Array[Byte] = {
+  def serialize(func: PythonTransformFunction): Array[Byte] = synchronized {
     require(serializer != null, "Serializer has not been registered!")
     // get the id of PythonTransformFunction in py4j
     val h = Proxy.getInvocationHandler(func.asInstanceOf[Proxy])
     val f = h.getClass().getDeclaredField("id")
     f.setAccessible(true)
     val id = f.get(h).asInstanceOf[String]
-    serializer.dumps(id)
+    val results = serializer.dumps(id)
+    val failure = serializer.getLastFailure
+    if (failure != null) {
+      throw new SparkException("An exception was raised by Python:\n" + failure)
+    }
+    results
   }
 
-  def deserialize(bytes: Array[Byte]): PythonTransformFunction = {
+  def deserialize(bytes: Array[Byte]): PythonTransformFunction = synchronized {
     require(serializer != null, "Serializer has not been registered!")
-    serializer.loads(bytes)
+    val pfunc = serializer.loads(bytes)
+    val failure = serializer.getLastFailure
+    if (failure != null) {
+      throw new SparkException("An exception was raised by Python:\n" + failure)
+    }
+    pfunc
   }
 }
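The serializer half of the last hunk follows the same contract: `serialize`/`deserialize` call into the Python proxy's `dumps`/`loads` and then poll `getLastFailure`. A matching Python-side sketch, again with illustrative names only (a plain `registry` dict stands in for however the real pyspark/streaming/util.py resolves the function id), might look like:

import traceback

class TransformFunctionSerializer(object):
    # Illustrative Py4J callback implementing the Scala trait
    # PythonTransformFunctionSerializer; reuses the TransformFunction
    # sketch above. `registry` and `serializer` are assumptions.

    def __init__(self, serializer, registry):
        self.serializer = serializer   # e.g. a pickling serializer
        self.registry = registry       # hypothetical id -> TransformFunction map
        self.failure = None

    def dumps(self, id):
        self.failure = None
        try:
            func = self.registry[id]
            return bytearray(self.serializer.dumps(func.func))
        except Exception:
            # Record the traceback; the Scala `serialize` method will
            # raise a SparkException containing it.
            self.failure = traceback.format_exc()

    def loads(self, data):
        self.failure = None
        try:
            return TransformFunction(self.serializer.loads(bytes(data)))
        except Exception:
            self.failure = traceback.format_exc()

    def getLastFailure(self):
        return self.failure

    class Java:
        implements = ["org.apache.spark.streaming.api.python.PythonTransformFunctionSerializer"]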