diff options
Diffstat (limited to 'python/pyspark/sql/tests.py')
-rw-r--r-- | python/pyspark/sql/tests.py | 34 |
1 files changed, 32 insertions, 2 deletions
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 83ef76c13c..e4f79c911c 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -51,7 +51,7 @@ from pyspark.sql.types import UserDefinedType, _infer_type from pyspark.tests import ReusedPySparkTestCase from pyspark.sql.functions import UserDefinedFunction, sha2 from pyspark.sql.window import Window -from pyspark.sql.utils import AnalysisException, IllegalArgumentException +from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException class UTCOffsetTimezone(datetime.tzinfo): @@ -305,6 +305,25 @@ class SQLTests(ReusedPySparkTestCase): [res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect() self.assertEqual(4, res[0]) + def test_chained_udf(self): + self.sqlCtx.registerFunction("double", lambda x: x + x, IntegerType()) + [row] = self.sqlCtx.sql("SELECT double(1)").collect() + self.assertEqual(row[0], 2) + [row] = self.sqlCtx.sql("SELECT double(double(1))").collect() + self.assertEqual(row[0], 4) + [row] = self.sqlCtx.sql("SELECT double(double(1) + 1)").collect() + self.assertEqual(row[0], 6) + + def test_multiple_udfs(self): + self.sqlCtx.registerFunction("double", lambda x: x * 2, IntegerType()) + [row] = self.sqlCtx.sql("SELECT double(1), double(2)").collect() + self.assertEqual(tuple(row), (2, 4)) + [row] = self.sqlCtx.sql("SELECT double(double(1)), double(double(2) + 2)").collect() + self.assertEqual(tuple(row), (4, 12)) + self.sqlCtx.registerFunction("add", lambda x, y: x + y, IntegerType()) + [row] = self.sqlCtx.sql("SELECT double(add(1, 2)), add(double(2), 1)").collect() + self.assertEqual(tuple(row), (6, 5)) + def test_udf_with_array_type(self): d = [Row(l=list(range(3)), d={"key": list(range(5))})] rdd = self.sc.parallelize(d) @@ -324,6 +343,15 @@ class SQLTests(ReusedPySparkTestCase): [res] = self.sqlCtx.sql("SELECT MYUDF('')").collect() self.assertEqual("", res[0]) + def test_udf_with_aggregate_function(self): + df = self.sqlCtx.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) + from pyspark.sql.functions import udf, col + from pyspark.sql.types import BooleanType + + my_filter = udf(lambda a: a == 1, BooleanType()) + sel = df.select(col("key")).distinct().filter(my_filter(col("key"))) + self.assertEqual(sel.collect(), [Row(key=1)]) + def test_basic_functions(self): rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) df = self.sqlCtx.read.json(rdd) @@ -1130,7 +1158,9 @@ class SQLTests(ReusedPySparkTestCase): def test_capture_analysis_exception(self): self.assertRaises(AnalysisException, lambda: self.sqlCtx.sql("select abc")) self.assertRaises(AnalysisException, lambda: self.df.selectExpr("a + b")) - self.assertRaises(AnalysisException, lambda: self.sqlCtx.sql("abc")) + + def test_capture_parse_exception(self): + self.assertRaises(ParseException, lambda: self.sqlCtx.sql("abc")) def test_capture_illegalargument_exception(self): self.assertRaisesRegexp(IllegalArgumentException, "Setting negative mapred.reduce.tasks", |