| author | Davies Liu <davies.liu@gmail.com> | 2014-09-03 19:08:39 -0700 |
|---|---|---|
| committer | Michael Armbrust <michael@databricks.com> | 2014-09-03 19:08:39 -0700 |
| commit | c5cbc49233193836b321cb6b77ce69dae798570b | |
| tree | cbab071adad31a6492e99674a9666d4e15c8283f /python | |
| parent | 248067adbe90f93c7d5e23aa61b3072dfdf48a8a | |
[SPARK-3335] [SQL] [PySpark] support broadcast in Python UDF
After this patch, broadcast variables can be used in Python UDFs.
Author: Davies Liu <davies.liu@gmail.com>
Closes #2243 from davies/udf_broadcast and squashes the following commits:
7b88861 [Davies Liu] support broadcast in UDF
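
For illustration, a minimal usage sketch of what this change enables, modeled on the new test_broadcast_in_udf test added below; the local SparkContext setup and the UDF/lookup names are illustrative assumptions, not part of the patch:

    # Sketch only: mirrors the new test_broadcast_in_udf test; names are illustrative.
    from pyspark import SparkContext
    from pyspark.sql import SQLContext

    sc = SparkContext("local", "udf-broadcast-sketch")
    sqlCtx = SQLContext(sc)

    # Ship a lookup dict to the executors once as a broadcast variable.
    lookup = sc.broadcast({"a": "aa", "b": "bb", "c": "abc"})

    # The Python UDF closes over the broadcast variable; after this patch the
    # broadcast is registered along with the pickled UDF so workers can resolve it.
    sqlCtx.registerFunction("lookupUDF", lambda key: lookup.value.get(key, ""))

    print(sqlCtx.sql("SELECT lookupUDF('c')").collect())  # expected: [Row(c0=u'abc')]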
Diffstat (limited to 'python')
| -rw-r--r-- | python/pyspark/sql.py | 17 |
| -rw-r--r-- | python/pyspark/tests.py | 22 |

2 files changed, 31 insertions(+), 8 deletions(-)
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 44316926ba..aaa35dadc2 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -942,9 +942,7 @@ class SQLContext:
         self._jsc = self._sc._jsc
         self._jvm = self._sc._jvm
         self._pythonToJava = self._jvm.PythonRDD.pythonToJavaArray
-
-        if sqlContext:
-            self._scala_SQLContext = sqlContext
+        self._scala_SQLContext = sqlContext
 
     @property
     def _ssql_ctx(self):
@@ -953,7 +951,7 @@ class SQLContext:
         Subclasses can override this property to provide their own
         JVM Contexts.
         """
-        if not hasattr(self, '_scala_SQLContext'):
+        if self._scala_SQLContext is None:
             self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc())
         return self._scala_SQLContext
 
@@ -970,23 +968,26 @@ class SQLContext:
         >>> sqlCtx.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
         >>> sqlCtx.sql("SELECT stringLengthInt('test')").collect()
         [Row(c0=4)]
-        >>> sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType())
-        >>> sqlCtx.sql("SELECT twoArgs('test', 1)").collect()
-        [Row(c0=5)]
         """
         func = lambda _, it: imap(lambda x: f(*x), it)
         command = (func,
                    BatchedSerializer(PickleSerializer(), 1024),
                    BatchedSerializer(PickleSerializer(), 1024))
+        pickled_command = CloudPickleSerializer().dumps(command)
+        broadcast_vars = ListConverter().convert(
+            [x._jbroadcast for x in self._sc._pickled_broadcast_vars],
+            self._sc._gateway._gateway_client)
+        self._sc._pickled_broadcast_vars.clear()
         env = MapConverter().convert(self._sc.environment,
                                      self._sc._gateway._gateway_client)
         includes = ListConverter().convert(self._sc._python_includes,
                                            self._sc._gateway._gateway_client)
         self._ssql_ctx.registerPython(name,
-                                      bytearray(CloudPickleSerializer().dumps(command)),
+                                      bytearray(pickled_command),
                                       env,
                                       includes,
                                       self._sc.pythonExec,
+                                      broadcast_vars,
                                       self._sc._javaAccumulator,
                                       str(returnType))
 
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index f1a75cbff5..3e74799e82 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -43,6 +43,7 @@ from pyspark.context import SparkContext
 from pyspark.files import SparkFiles
 from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer
 from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter
+from pyspark.sql import SQLContext, IntegerType
 
 _have_scipy = False
 _have_numpy = False
@@ -525,6 +526,27 @@ class TestRDDFunctions(PySparkTestCase):
         self.assertRaises(TypeError, lambda: rdd.histogram(2))
 
 
+class TestSQL(PySparkTestCase):
+
+    def setUp(self):
+        PySparkTestCase.setUp(self)
+        self.sqlCtx = SQLContext(self.sc)
+
+    def test_udf(self):
+        self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType())
+        [row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect()
+        self.assertEqual(row[0], 5)
+
+    def test_broadcast_in_udf(self):
+        bar = {"a": "aa", "b": "bb", "c": "abc"}
+        foo = self.sc.broadcast(bar)
+        self.sqlCtx.registerFunction("MYUDF", lambda x: foo.value[x] if x else '')
+        [res] = self.sqlCtx.sql("SELECT MYUDF('c')").collect()
+        self.assertEqual("abc", res[0])
+        [res] = self.sqlCtx.sql("SELECT MYUDF('')").collect()
+        self.assertEqual("", res[0])
+
+
 class TestIO(PySparkTestCase):
 
     def test_stdout_redirection(self):