aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/sql/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/sql/utils.py')
-rw-r--r-- python/pyspark/sql/utils.py 54
1 files changed, 54 insertions, 0 deletions
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
new file mode 100644
index 0000000000..8096802e73
--- /dev/null
+++ b/python/pyspark/sql/utils.py
@@ -0,0 +1,54 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import py4j
+
+
+class AnalysisException(Exception):
+    """
+    Raised when a SQL query plan could not be analyzed.
+
+    Python-side stand-in for the JVM's ``org.apache.spark.sql.AnalysisException``;
+    :func:`capture_sql_exception` raises it with the message taken from the Java exception.
+    """
+
+
+def capture_sql_exception(f):
+    """
+    Wrap `f` so that a JVM-side AnalysisException surfaces as a Python one.
+
+    :param f: callable to wrap; it is invoked with the caller's arguments unchanged.
+    :return: a function returning whatever `f` returns. If `f` raises
+        py4j.protocol.Py4JJavaError and the text before the first ``': '`` in the Java
+        exception's `toString()` is ``org.apache.spark.sql.AnalysisException``, an
+        :class:`AnalysisException` carrying the remaining text is raised instead.
+        Any other Py4JJavaError is re-raised untouched.
+    """
+    def deco(*a, **kw):
+        try:
+            return f(*a, **kw)
+        except py4j.protocol.Py4JJavaError as e:
+            # `toString()` is "<class>: <message>", or just "<class>" when the Java
+            # exception has no message. `partition` handles both shapes, so a
+            # message-less exception no longer trips a ValueError on unpacking and
+            # masks the original Java error.
+            cls, _, msg = e.java_exception.toString().partition(': ')
+            if cls == 'org.apache.spark.sql.AnalysisException':
+                raise AnalysisException(msg)
+            raise
+    return deco
+
+
+def install_exception_handler():
+    """
+    Install a Py4J hook that turns some JVM-side SQL exceptions into Python ones.
+
+    Calls into the Java API parse their result with `get_return_value`, which raises
+    py4j.protocol.Py4JJavaError when the JVM reports an exception. This function swaps
+    that parser for a wrapped version that re-raises the captured Java exceptions as
+    Python exceptions carrying the same error message.
+
+    It is idempotent, so it can safely be called multiple times.
+    """
+    # Always wrap the function taken from `py4j.protocol`, which is never reassigned
+    # here, so repeated calls do not stack wrappers. Only the binding in
+    # `py4j.java_gateway` (the one used when calling the Java API) is replaced.
+    py4j.java_gateway.get_return_value = capture_sql_exception(
+        py4j.protocol.get_return_value)