diff options
author | Davies Liu <davies.liu@gmail.com> | 2014-09-03 19:08:39 -0700 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2014-09-03 19:08:39 -0700 |
commit | c5cbc49233193836b321cb6b77ce69dae798570b (patch) | |
tree | cbab071adad31a6492e99674a9666d4e15c8283f /python/pyspark/sql.py | |
parent | 248067adbe90f93c7d5e23aa61b3072dfdf48a8a (diff) | |
download | spark-c5cbc49233193836b321cb6b77ce69dae798570b.tar.gz spark-c5cbc49233193836b321cb6b77ce69dae798570b.tar.bz2 spark-c5cbc49233193836b321cb6b77ce69dae798570b.zip |
[SPARK-3335] [SQL] [PySpark] support broadcast in Python UDF
After this patch, broadcast can be used in Python UDF.
Author: Davies Liu <davies.liu@gmail.com>
Closes #2243 from davies/udf_broadcast and squashes the following commits:
7b88861 [Davies Liu] support broadcast in UDF
Diffstat (limited to 'python/pyspark/sql.py')
-rw-r--r-- | python/pyspark/sql.py | 17 |
1 files changed, 9 insertions, 8 deletions
```diff
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 44316926ba..aaa35dadc2 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -942,9 +942,7 @@ class SQLContext:
         self._jsc = self._sc._jsc
         self._jvm = self._sc._jvm
         self._pythonToJava = self._jvm.PythonRDD.pythonToJavaArray
-
-        if sqlContext:
-            self._scala_SQLContext = sqlContext
+        self._scala_SQLContext = sqlContext
 
     @property
     def _ssql_ctx(self):
@@ -953,7 +951,7 @@ class SQLContext:
         Subclasses can override this property to provide their own
         JVM Contexts.
         """
-        if not hasattr(self, '_scala_SQLContext'):
+        if self._scala_SQLContext is None:
             self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc())
         return self._scala_SQLContext
 
@@ -970,23 +968,26 @@ class SQLContext:
         >>> sqlCtx.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
         >>> sqlCtx.sql("SELECT stringLengthInt('test')").collect()
         [Row(c0=4)]
-        >>> sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType())
-        >>> sqlCtx.sql("SELECT twoArgs('test', 1)").collect()
-        [Row(c0=5)]
         """
         func = lambda _, it: imap(lambda x: f(*x), it)
         command = (func,
                    BatchedSerializer(PickleSerializer(), 1024),
                    BatchedSerializer(PickleSerializer(), 1024))
+        pickled_command = CloudPickleSerializer().dumps(command)
+        broadcast_vars = ListConverter().convert(
+            [x._jbroadcast for x in self._sc._pickled_broadcast_vars],
+            self._sc._gateway._gateway_client)
+        self._sc._pickled_broadcast_vars.clear()
         env = MapConverter().convert(self._sc.environment,
                                      self._sc._gateway._gateway_client)
         includes = ListConverter().convert(self._sc._python_includes,
                                            self._sc._gateway._gateway_client)
         self._ssql_ctx.registerPython(name,
-                                      bytearray(CloudPickleSerializer().dumps(command)),
+                                      bytearray(pickled_command),
                                       env,
                                       includes,
                                       self._sc.pythonExec,
+                                      broadcast_vars,
                                       self._sc._javaAccumulator,
                                       str(returnType))
 
```