diff options
Diffstat (limited to 'python/pyspark/tests.py')
-rw-r--r-- | python/pyspark/tests.py | 22 |
1 files changed, 22 insertions, 0 deletions
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index f1a75cbff5..3e74799e82 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -43,6 +43,7 @@ from pyspark.context import SparkContext from pyspark.files import SparkFiles from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter +from pyspark.sql import SQLContext, IntegerType _have_scipy = False _have_numpy = False @@ -525,6 +526,27 @@ class TestRDDFunctions(PySparkTestCase): self.assertRaises(TypeError, lambda: rdd.histogram(2)) +class TestSQL(PySparkTestCase): + + def setUp(self): + PySparkTestCase.setUp(self) + self.sqlCtx = SQLContext(self.sc) + + def test_udf(self): + self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType()) + [row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect() + self.assertEqual(row[0], 5) + + def test_broadcast_in_udf(self): + bar = {"a": "aa", "b": "bb", "c": "abc"} + foo = self.sc.broadcast(bar) + self.sqlCtx.registerFunction("MYUDF", lambda x: foo.value[x] if x else '') + [res] = self.sqlCtx.sql("SELECT MYUDF('c')").collect() + self.assertEqual("abc", res[0]) + [res] = self.sqlCtx.sql("SELECT MYUDF('')").collect() + self.assertEqual("", res[0]) + + class TestIO(PySparkTestCase): def test_stdout_redirection(self): |