about summary refs log tree commit diff
path: root/python/pyspark/sql/tests.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/sql/tests.py')
-rw-r--r--  python/pyspark/sql/tests.py  34
1 files changed, 32 insertions, 2 deletions
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 83ef76c13c..e4f79c911c 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -51,7 +51,7 @@ from pyspark.sql.types import UserDefinedType, _infer_type
from pyspark.tests import ReusedPySparkTestCase
from pyspark.sql.functions import UserDefinedFunction, sha2
from pyspark.sql.window import Window
-from pyspark.sql.utils import AnalysisException, IllegalArgumentException
+from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException
class UTCOffsetTimezone(datetime.tzinfo):
@@ -305,6 +305,25 @@ class SQLTests(ReusedPySparkTestCase):
[res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
self.assertEqual(4, res[0])
+ def test_chained_udf(self):
+ self.sqlCtx.registerFunction("double", lambda x: x + x, IntegerType())
+ [row] = self.sqlCtx.sql("SELECT double(1)").collect()
+ self.assertEqual(row[0], 2)
+ [row] = self.sqlCtx.sql("SELECT double(double(1))").collect()
+ self.assertEqual(row[0], 4)
+ [row] = self.sqlCtx.sql("SELECT double(double(1) + 1)").collect()
+ self.assertEqual(row[0], 6)
+
+ def test_multiple_udfs(self):
+ self.sqlCtx.registerFunction("double", lambda x: x * 2, IntegerType())
+ [row] = self.sqlCtx.sql("SELECT double(1), double(2)").collect()
+ self.assertEqual(tuple(row), (2, 4))
+ [row] = self.sqlCtx.sql("SELECT double(double(1)), double(double(2) + 2)").collect()
+ self.assertEqual(tuple(row), (4, 12))
+ self.sqlCtx.registerFunction("add", lambda x, y: x + y, IntegerType())
+ [row] = self.sqlCtx.sql("SELECT double(add(1, 2)), add(double(2), 1)").collect()
+ self.assertEqual(tuple(row), (6, 5))
+
def test_udf_with_array_type(self):
d = [Row(l=list(range(3)), d={"key": list(range(5))})]
rdd = self.sc.parallelize(d)
@@ -324,6 +343,15 @@ class SQLTests(ReusedPySparkTestCase):
[res] = self.sqlCtx.sql("SELECT MYUDF('')").collect()
self.assertEqual("", res[0])
+ def test_udf_with_aggregate_function(self):
+ df = self.sqlCtx.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
+ from pyspark.sql.functions import udf, col
+ from pyspark.sql.types import BooleanType
+
+ my_filter = udf(lambda a: a == 1, BooleanType())
+ sel = df.select(col("key")).distinct().filter(my_filter(col("key")))
+ self.assertEqual(sel.collect(), [Row(key=1)])
+
def test_basic_functions(self):
rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
df = self.sqlCtx.read.json(rdd)
@@ -1130,7 +1158,9 @@ class SQLTests(ReusedPySparkTestCase):
def test_capture_analysis_exception(self):
self.assertRaises(AnalysisException, lambda: self.sqlCtx.sql("select abc"))
self.assertRaises(AnalysisException, lambda: self.df.selectExpr("a + b"))
- self.assertRaises(AnalysisException, lambda: self.sqlCtx.sql("abc"))
+
+ def test_capture_parse_exception(self):
+ self.assertRaises(ParseException, lambda: self.sqlCtx.sql("abc"))
def test_capture_illegalargument_exception(self):
self.assertRaisesRegexp(IllegalArgumentException, "Setting negative mapred.reduce.tasks",