about | summary | refs | log | tree | commit | diff
path: root/python/pyspark/tests.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/tests.py')
-rw-r--r-- python/pyspark/tests.py 22
1 files changed, 22 insertions, 0 deletions
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index f1a75cbff5..3e74799e82 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -43,6 +43,7 @@ from pyspark.context import SparkContext
from pyspark.files import SparkFiles
from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter
+from pyspark.sql import SQLContext, IntegerType
_have_scipy = False
_have_numpy = False
@@ -525,6 +526,27 @@ class TestRDDFunctions(PySparkTestCase):
self.assertRaises(TypeError, lambda: rdd.histogram(2))
+class TestSQL(PySparkTestCase):
+
+    def setUp(self):
+        PySparkTestCase.setUp(self)  # presumably provides self.sc; defined outside this hunk
+        self.sqlCtx = SQLContext(self.sc)  # SQL entry point shared by the tests below
+
+    def test_udf(self):
+        self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType())
+        [row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect()  # exactly one row
+        self.assertEqual(row[0], 5)  # len('test') + 1
+
+    def test_broadcast_in_udf(self):
+        bar = {"a": "aa", "b": "bb", "c": "abc"}
+        foo = self.sc.broadcast(bar)  # the UDF below reads this through foo.value
+        self.sqlCtx.registerFunction("MYUDF", lambda x: foo.value[x] if x else '')
+        [res] = self.sqlCtx.sql("SELECT MYUDF('c')").collect()
+        self.assertEqual("abc", res[0])  # key 'c' looked up in the broadcast dict
+        [res] = self.sqlCtx.sql("SELECT MYUDF('')").collect()
+        self.assertEqual("", res[0])  # falsy argument takes the else branch, no lookup
+
+
class TestIO(PySparkTestCase):
def test_stdout_redirection(self):