path: root/python/pyspark
author    Davies Liu <davies@databricks.com>    2015-06-30 16:17:46 -0700
committer Davies Liu <davies@databricks.com>    2015-06-30 16:17:46 -0700
commit    58ee2a2e47948a895e557fbcabbeadb31f0a1022 (patch)
tree      86d79553be2fd7df6b7afe2b226bb40ede319725 /python/pyspark
parent    d2495f7cc7d7caaa50d122d2969ddb693e6ecebd (diff)
[SPARK-8738] [SQL] [PYSPARK] capture SQL AnalysisException in Python API
Capture the AnalysisException in SQL, hide the long Java stack trace, and only show the error message. cc rxin

Author: Davies Liu <davies@databricks.com>

Closes #7135 from davies/ananylis and squashes the following commits:

dad7ae7 [Davies Liu] add comment
ec0c0e8 [Davies Liu] Update utils.py
cdd7edd [Davies Liu] add doc
7b044c2 [Davies Liu] fix python 3
f84d3bd [Davies Liu] capture SQL AnalysisException in Python API
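With this change, a query that fails analysis surfaces in Python as a concise AnalysisException instead of a Py4JJavaError wrapping a long Java stack trace. A minimal sketch of the new behavior (assuming an existing SQLContext named sqlCtx, as in the tests below):

    from pyspark.sql.utils import AnalysisException

    try:
        sqlCtx.sql("select abc")  # `abc` cannot be resolved as a column
    except AnalysisException as e:
        print(e)  # short analyzer message, no Java stack trace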
Diffstat (limited to 'python/pyspark')
-rw-r--r--  python/pyspark/rdd.py           3
-rw-r--r--  python/pyspark/sql/context.py   2
-rw-r--r--  python/pyspark/sql/tests.py     7
-rw-r--r--  python/pyspark/sql/utils.py    54
4 files changed, 65 insertions(+), 1 deletion(-)
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index cb20bc8b54..79dafb0a4e 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -126,11 +126,12 @@ def _load_from_socket(port, serializer):
     # On most of IPv6-ready systems, IPv6 will take precedence.
     for res in socket.getaddrinfo("localhost", port, socket.AF_UNSPEC, socket.SOCK_STREAM):
         af, socktype, proto, canonname, sa = res
+        sock = socket.socket(af, socktype, proto)
         try:
-            sock = socket.socket(af, socktype, proto)
             sock.settimeout(3)
             sock.connect(sa)
         except socket.error:
+            sock.close()
             sock = None
             continue
         break
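The hunk above moves socket creation out of the try block and closes the socket when connect fails, so an unsuccessful attempt (e.g. IPv6 tried first) no longer leaks a file descriptor before the loop falls back to the next address family. A standalone sketch of the same pattern (the helper name _connect_localhost is hypothetical, not part of the patch):

    import socket

    def _connect_localhost(port, timeout=3):
        sock = None
        for res in socket.getaddrinfo("localhost", port, socket.AF_UNSPEC,
                                      socket.SOCK_STREAM):
            af, socktype, proto, canonname, sa = res
            sock = socket.socket(af, socktype, proto)  # created before try, so always closable
            try:
                sock.settimeout(timeout)
                sock.connect(sa)
            except socket.error:
                sock.close()  # release the fd before trying the next address family
                sock = None
                continue
            break
        if sock is None:
            raise Exception("could not open socket on port %d" % port)
        return sock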
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 4dda3b430c..4bf232111c 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -33,6 +33,7 @@ from pyspark.sql.types import Row, StringType, StructType, _verify_type, \
     _infer_schema, _has_nulltype, _merge_type, _create_converter, _python_to_sql_converter
 from pyspark.sql.dataframe import DataFrame
 from pyspark.sql.readwriter import DataFrameReader
+from pyspark.sql.utils import install_exception_handler

 try:
     import pandas
@@ -96,6 +97,7 @@ class SQLContext(object):
         self._jvm = self._sc._jvm
         self._scala_SQLContext = sqlContext
         _monkey_patch_RDD(self)
+        install_exception_handler()

     @property
     def _ssql_ctx(self):
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 34f397d0ff..5af2ce09bc 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -46,6 +46,7 @@ from pyspark.sql.types import UserDefinedType, _infer_type
 from pyspark.tests import ReusedPySparkTestCase
 from pyspark.sql.functions import UserDefinedFunction
 from pyspark.sql.window import Window
+from pyspark.sql.utils import AnalysisException


 class UTC(datetime.tzinfo):
@@ -847,6 +848,12 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertEqual(row.age, 10)
         self.assertEqual(row.height, None)

+    def test_capture_analysis_exception(self):
+        self.assertRaises(AnalysisException, lambda: self.sqlCtx.sql("select abc"))
+        self.assertRaises(AnalysisException, lambda: self.df.selectExpr("a + b"))
+        # RuntimeException should not be captured
+        self.assertRaises(py4j.protocol.Py4JJavaError, lambda: self.sqlCtx.sql("abc"))
+

 class HiveContextSQLTests(ReusedPySparkTestCase):
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
new file mode 100644
index 0000000000..8096802e73
--- /dev/null
+++ b/python/pyspark/sql/utils.py
@@ -0,0 +1,54 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import py4j
+
+
+class AnalysisException(Exception):
+    """
+    Failed to analyze a SQL query plan.
+    """
+
+
+def capture_sql_exception(f):
+    def deco(*a, **kw):
+        try:
+            return f(*a, **kw)
+        except py4j.protocol.Py4JJavaError as e:
+            cls, msg = e.java_exception.toString().split(': ', 1)
+            if cls == 'org.apache.spark.sql.AnalysisException':
+                raise AnalysisException(msg)
+            raise
+    return deco
+
+
+def install_exception_handler():
+    """
+    Hook an exception handler into Py4j that captures certain SQL exceptions raised in Java.
+
+    When the Java API is called, Py4j invokes `get_return_value` to parse the returned
+    object. If an exception was thrown in the JVM, the result is a Java exception object
+    and `get_return_value` raises py4j.protocol.Py4JJavaError. We replace the original
+    `get_return_value` with a wrapper that captures the Java exception and raises a
+    Python one with the same error message.
+
+    This is idempotent: it can safely be called multiple times.
+    """
+    original = py4j.protocol.get_return_value
+    # The original `get_return_value` in py4j.protocol is never itself patched, so
+    # wrapping it again on repeated calls does not stack decorators.
+    patched = capture_sql_exception(original)
+    # Only patch the reference used in py4j.java_gateway (the Java API call path).
+    py4j.java_gateway.get_return_value = patched
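Taken together: constructing a SQLContext installs the hook (see the context.py hunk above), after which analysis errors arrive as AnalysisException. A minimal end-to-end sketch; the session setup and the sqlCtx.range call are illustrative assumptions, not part of the patch:

    from pyspark import SparkContext
    from pyspark.sql import SQLContext
    from pyspark.sql.utils import AnalysisException

    sc = SparkContext(appName="analysis-exception-demo")
    sqlCtx = SQLContext(sc)       # __init__ calls install_exception_handler()
    df = sqlCtx.range(10)         # single column named `id`
    try:
        df.selectExpr("a + b")    # analysis fails: columns `a` and `b` do not exist
    except AnalysisException as e:
        print("analysis failed: %s" % e)  # concise message, no Java stack trace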