From 58ee2a2e47948a895e557fbcabbeadb31f0a1022 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 30 Jun 2015 16:17:46 -0700 Subject: [SPARK-8738] [SQL] [PYSPARK] capture SQL AnalysisException in Python API Capture the AnalysisException in SQL, hide the long java stack trace, only show the error message. cc rxin Author: Davies Liu Closes #7135 from davies/ananylis and squashes the following commits: dad7ae7 [Davies Liu] add comment ec0c0e8 [Davies Liu] Update utils.py cdd7edd [Davies Liu] add doc 7b044c2 [Davies Liu] fix python 3 f84d3bd [Davies Liu] capture SQL AnalysisException in Python API --- python/pyspark/rdd.py | 3 ++- python/pyspark/sql/context.py | 2 ++ python/pyspark/sql/tests.py | 7 ++++++ python/pyspark/sql/utils.py | 54 +++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 65 insertions(+), 1 deletion(-) create mode 100644 python/pyspark/sql/utils.py diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index cb20bc8b54..79dafb0a4e 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -126,11 +126,12 @@ def _load_from_socket(port, serializer): # On most of IPv6-ready systems, IPv6 will take precedence. for res in socket.getaddrinfo("localhost", port, socket.AF_UNSPEC, socket.SOCK_STREAM): af, socktype, proto, canonname, sa = res + sock = socket.socket(af, socktype, proto) try: - sock = socket.socket(af, socktype, proto) sock.settimeout(3) sock.connect(sa) except socket.error: + sock.close() sock = None continue break diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 4dda3b430c..4bf232111c 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -33,6 +33,7 @@ from pyspark.sql.types import Row, StringType, StructType, _verify_type, \ _infer_schema, _has_nulltype, _merge_type, _create_converter, _python_to_sql_converter from pyspark.sql.dataframe import DataFrame from pyspark.sql.readwriter import DataFrameReader +from pyspark.sql.utils import install_exception_handler try: import pandas @@ -96,6 +97,7 @@ class SQLContext(object): self._jvm = self._sc._jvm self._scala_SQLContext = sqlContext _monkey_patch_RDD(self) + install_exception_handler() @property def _ssql_ctx(self): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 34f397d0ff..5af2ce09bc 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -46,6 +46,7 @@ from pyspark.sql.types import UserDefinedType, _infer_type from pyspark.tests import ReusedPySparkTestCase from pyspark.sql.functions import UserDefinedFunction from pyspark.sql.window import Window +from pyspark.sql.utils import AnalysisException class UTC(datetime.tzinfo): @@ -847,6 +848,12 @@ class SQLTests(ReusedPySparkTestCase): self.assertEqual(row.age, 10) self.assertEqual(row.height, None) + def test_capture_analysis_exception(self): + self.assertRaises(AnalysisException, lambda: self.sqlCtx.sql("select abc")) + self.assertRaises(AnalysisException, lambda: self.df.selectExpr("a + b")) + # RuntimeException should not be captured + self.assertRaises(py4j.protocol.Py4JJavaError, lambda: self.sqlCtx.sql("abc")) + class HiveContextSQLTests(ReusedPySparkTestCase): diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py new file mode 100644 index 0000000000..8096802e73 --- /dev/null +++ b/python/pyspark/sql/utils.py @@ -0,0 +1,54 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import py4j + + +class AnalysisException(Exception): + """ + Failed to analyze a SQL query plan. + """ + + +def capture_sql_exception(f): + def deco(*a, **kw): + try: + return f(*a, **kw) + except py4j.protocol.Py4JJavaError as e: + cls, msg = e.java_exception.toString().split(': ', 1) + if cls == 'org.apache.spark.sql.AnalysisException': + raise AnalysisException(msg) + raise + return deco + + +def install_exception_handler(): + """ + Hook an exception handler into Py4j, which could capture some SQL exceptions in Java. + + When calling Java API, it will call `get_return_value` to parse the returned object. + If any exception happened in JVM, the result will be Java exception object, it raise + py4j.protocol.Py4JJavaError. We replace the original `get_return_value` with one that + could capture the Java exception and throw a Python one (with the same error message). + + It's idempotent, could be called multiple times. + """ + original = py4j.protocol.get_return_value + # The original `get_return_value` is not patched, it's idempotent. + patched = capture_sql_exception(original) + # only patch the one used in in py4j.java_gateway (call Java API) + py4j.java_gateway.get_return_value = patched -- cgit v1.2.3