about summary refs log tree commit diff
path: root/python/pyspark/sql.py
diff options
context:
space:
mode:
authorDavies Liu <davies.liu@gmail.com>2014-09-03 19:08:39 -0700
committerMichael Armbrust <michael@databricks.com>2014-09-03 19:08:39 -0700
commitc5cbc49233193836b321cb6b77ce69dae798570b (patch)
treecbab071adad31a6492e99674a9666d4e15c8283f /python/pyspark/sql.py
parent248067adbe90f93c7d5e23aa61b3072dfdf48a8a (diff)
downloadspark-c5cbc49233193836b321cb6b77ce69dae798570b.tar.gz
spark-c5cbc49233193836b321cb6b77ce69dae798570b.tar.bz2
spark-c5cbc49233193836b321cb6b77ce69dae798570b.zip
[SPARK-3335] [SQL] [PySpark] support broadcast in Python UDF
After this patch, broadcast can be used in Python UDF. Author: Davies Liu <davies.liu@gmail.com> Closes #2243 from davies/udf_broadcast and squashes the following commits: 7b88861 [Davies Liu] support broadcast in UDF
Diffstat (limited to 'python/pyspark/sql.py')
-rw-r--r--python/pyspark/sql.py17
1 file changed, 9 insertions(+), 8 deletions(-)
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 44316926ba..aaa35dadc2 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -942,9 +942,7 @@ class SQLContext:
self._jsc = self._sc._jsc
self._jvm = self._sc._jvm
self._pythonToJava = self._jvm.PythonRDD.pythonToJavaArray
-
- if sqlContext:
- self._scala_SQLContext = sqlContext
+ self._scala_SQLContext = sqlContext
@property
def _ssql_ctx(self):
@@ -953,7 +951,7 @@ class SQLContext:
Subclasses can override this property to provide their own
JVM Contexts.
"""
- if not hasattr(self, '_scala_SQLContext'):
+ if self._scala_SQLContext is None:
self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc())
return self._scala_SQLContext
@@ -970,23 +968,26 @@ class SQLContext:
>>> sqlCtx.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
>>> sqlCtx.sql("SELECT stringLengthInt('test')").collect()
[Row(c0=4)]
- >>> sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType())
- >>> sqlCtx.sql("SELECT twoArgs('test', 1)").collect()
- [Row(c0=5)]
"""
func = lambda _, it: imap(lambda x: f(*x), it)
command = (func,
BatchedSerializer(PickleSerializer(), 1024),
BatchedSerializer(PickleSerializer(), 1024))
+ pickled_command = CloudPickleSerializer().dumps(command)
+ broadcast_vars = ListConverter().convert(
+ [x._jbroadcast for x in self._sc._pickled_broadcast_vars],
+ self._sc._gateway._gateway_client)
+ self._sc._pickled_broadcast_vars.clear()
env = MapConverter().convert(self._sc.environment,
self._sc._gateway._gateway_client)
includes = ListConverter().convert(self._sc._python_includes,
self._sc._gateway._gateway_client)
self._ssql_ctx.registerPython(name,
- bytearray(CloudPickleSerializer().dumps(command)),
+ bytearray(pickled_command),
env,
includes,
self._sc.pythonExec,
+ broadcast_vars,
self._sc._javaAccumulator,
str(returnType))