aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPatrick Wendell <pwendell@gmail.com>2013-01-31 18:02:28 -0800
committerPatrick Wendell <pwendell@gmail.com>2013-01-31 18:06:11 -0800
commit3446d5c8d6b385106ac85e46320d92faa8efb4e6 (patch)
tree220b114399de112adbebb774ac4bd456deb87040
parent55327a283e962652a126d3f8ac7e9a19c76f1f19 (diff)
downloadspark-3446d5c8d6b385106ac85e46320d92faa8efb4e6.tar.gz
spark-3446d5c8d6b385106ac85e46320d92faa8efb4e6.tar.bz2
spark-3446d5c8d6b385106ac85e46320d92faa8efb4e6.zip
SPARK-673: Capture and re-throw Python exceptions
This patch alters the Python <-> executor protocol to pass on exception data when they occur in user Python code.
-rw-r--r--core/src/main/scala/spark/api/python/PythonRDD.scala40
-rw-r--r--python/pyspark/worker.py10
2 files changed, 34 insertions, 16 deletions
diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala
index f43a152ca7..6b9ef62529 100644
--- a/core/src/main/scala/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/spark/api/python/PythonRDD.scala
@@ -103,21 +103,30 @@ private[spark] class PythonRDD[T: ClassManifest](
private def read(): Array[Byte] = {
try {
- val length = stream.readInt()
- if (length != -1) {
- val obj = new Array[Byte](length)
- stream.readFully(obj)
- obj
- } else {
- // We've finished the data section of the output, but we can still read some
- // accumulator updates; let's do that, breaking when we get EOFException
- while (true) {
- val len2 = stream.readInt()
- val update = new Array[Byte](len2)
- stream.readFully(update)
- accumulator += Collections.singletonList(update)
+ stream.readInt() match {
+ case length if length > 0 => {
+ val obj = new Array[Byte](length)
+ stream.readFully(obj)
+ obj
}
- new Array[Byte](0)
+ case -2 => {
+ // Signals that an exception has been thrown in python
+ val exLength = stream.readInt()
+ val obj = new Array[Byte](exLength)
+ stream.readFully(obj)
+ throw new PythonException(new String(obj))
+ }
+ case -1 => {
+ // We've finished the data section of the output, but we can still read some
+ // accumulator updates; let's do that, breaking when we get EOFException
+ while (true) {
+ val len2 = stream.readInt()
+ val update = new Array[Byte](len2)
+ stream.readFully(update)
+ accumulator += Collections.singletonList(update)
+ }
+ new Array[Byte](0)
+ }
}
} catch {
case eof: EOFException => {
@@ -140,6 +149,9 @@ private[spark] class PythonRDD[T: ClassManifest](
val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
}
+/** Thrown for exceptions in user Python code. */
+private class PythonException(msg: String) extends Exception(msg)
+
/**
* Form an RDD[(Array[Byte], Array[Byte])] from key-value pairs returned from Python.
* This is used by PySpark's shuffle operations.
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index d33d6dd15f..9622e0cfe4 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -2,6 +2,7 @@
Worker that receives input from Piped RDD.
"""
import sys
+import traceback
from base64 import standard_b64decode
# CloudPickler needs to be imported so that depicklers are registered using the
# copy_reg module.
@@ -40,8 +41,13 @@ def main():
else:
dumps = dump_pickle
iterator = read_from_pickle_file(sys.stdin)
- for obj in func(split_index, iterator):
- write_with_length(dumps(obj), old_stdout)
+ try:
+ for obj in func(split_index, iterator):
+ write_with_length(dumps(obj), old_stdout)
+ except Exception as e:
+ write_int(-2, old_stdout)
+ write_with_length(traceback.format_exc(), old_stdout)
+ sys.exit(-1)
# Mark the beginning of the accumulators section of the output
write_int(-1, old_stdout)
for aid, accum in _accumulatorRegistry.items():