diff options
Diffstat (limited to 'python/pyspark')
-rw-r--r-- | python/pyspark/sql/tests.py | 11 | ||||
-rw-r--r-- | python/pyspark/sql/utils.py | 8 |
2 files changed, 17 insertions, 2 deletions
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 241eac45cf..86706e2dc4 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -45,9 +45,9 @@ from pyspark.sql import SQLContext, HiveContext, Column, Row from pyspark.sql.types import * from pyspark.sql.types import UserDefinedType, _infer_type from pyspark.tests import ReusedPySparkTestCase -from pyspark.sql.functions import UserDefinedFunction +from pyspark.sql.functions import UserDefinedFunction, sha2 from pyspark.sql.window import Window -from pyspark.sql.utils import AnalysisException +from pyspark.sql.utils import AnalysisException, IllegalArgumentException class UTC(datetime.tzinfo): @@ -894,6 +894,13 @@ class SQLTests(ReusedPySparkTestCase): # RuntimeException should not be captured self.assertRaises(py4j.protocol.Py4JJavaError, lambda: self.sqlCtx.sql("abc")) + def test_capture_illegalargument_exception(self): + self.assertRaisesRegexp(IllegalArgumentException, "Setting negative mapred.reduce.tasks", + lambda: self.sqlCtx.sql("SET mapred.reduce.tasks=-1")) + df = self.sqlCtx.createDataFrame([(1, 2)], ["a", "b"]) + self.assertRaisesRegexp(IllegalArgumentException, "1024 is not in the permitted values", + lambda: df.select(sha2(df.a, 1024)).collect()) + class HiveContextSQLTests(ReusedPySparkTestCase): diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index cc5b2c088b..0f795ca35b 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -24,6 +24,12 @@ class AnalysisException(Exception): """ +class IllegalArgumentException(Exception): + """ + Passed an illegal or inappropriate argument. + """ + + def capture_sql_exception(f): def deco(*a, **kw): try: @@ -32,6 +38,8 @@ def capture_sql_exception(f): s = e.java_exception.toString() if s.startswith('org.apache.spark.sql.AnalysisException: '): raise AnalysisException(s.split(': ', 1)[1]) + if s.startswith('java.lang.IllegalArgumentException: '): + raise IllegalArgumentException(s.split(': ', 1)[1]) raise return deco |