author     Davies Liu <davies.liu@gmail.com>    2014-09-03 19:08:39 -0700
committer  Michael Armbrust <michael@databricks.com>    2014-09-03 19:08:39 -0700
commit     c5cbc49233193836b321cb6b77ce69dae798570b (patch)
tree       cbab071adad31a6492e99674a9666d4e15c8283f /python
parent     248067adbe90f93c7d5e23aa61b3072dfdf48a8a (diff)
[SPARK-3335] [SQL] [PySpark] support broadcast in Python UDF
After this patch, broadcast variables can be used in Python UDFs.

Author: Davies Liu <davies.liu@gmail.com>

Closes #2243 from davies/udf_broadcast and squashes the following commits:

7b88861 [Davies Liu] support broadcast in UDF
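Once registered, a Python UDF can now close over a broadcast variable like any other PySpark closure. A minimal usage sketch, mirroring the new test below; it assumes a local SparkContext and the Spark 1.x-era SQLContext API, and the `lookup`/`lookupVal` names are illustrative:

    from pyspark import SparkContext
    from pyspark.sql import SQLContext, IntegerType

    sc = SparkContext("local", "udf-broadcast-demo")
    sqlCtx = SQLContext(sc)

    # Ship the lookup table to the executors once; tasks reuse the cached
    # copy instead of re-pickling the dict into every task closure.
    lookup = sc.broadcast({"a": 1, "b": 2, "c": 3})

    # The lambda closes over the broadcast handle. Before this patch the
    # handle's id could not be resolved on the workers, because
    # registerPython was never handed the broadcast variables.
    sqlCtx.registerFunction("lookupVal", lambda k: lookup.value.get(k, 0), IntegerType())

    print(sqlCtx.sql("SELECT lookupVal('c')").collect())  # [Row(c0=3)]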
Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/sql.py   | 17
-rw-r--r--  python/pyspark/tests.py | 22
2 files changed, 31 insertions(+), 8 deletions(-)
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 44316926ba..aaa35dadc2 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -942,9 +942,7 @@ class SQLContext:
         self._jsc = self._sc._jsc
         self._jvm = self._sc._jvm
         self._pythonToJava = self._jvm.PythonRDD.pythonToJavaArray
-
-        if sqlContext:
-            self._scala_SQLContext = sqlContext
+        self._scala_SQLContext = sqlContext
 
     @property
     def _ssql_ctx(self):
@@ -953,7 +951,7 @@
         Subclasses can override this property to provide their own
         JVM Contexts.
         """
-        if not hasattr(self, '_scala_SQLContext'):
+        if self._scala_SQLContext is None:
             self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc())
         return self._scala_SQLContext
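These two hunks work together: __init__ now always assigns `_scala_SQLContext` (possibly None), so the lazy accessor must test for None instead of attribute existence, which `hasattr` would now always report as True. A minimal sketch of the resulting lazy-init pattern, with illustrative names:

    class Context:
        def __init__(self, backing=None):
            # Always assign the attribute, even when no backing context
            # is supplied, so its existence no longer carries information.
            self._backing = backing

        @property
        def backing(self):
            if self._backing is None:  # hasattr() would always be True here
                self._backing = object()  # create the default lazily, once
            return self._backing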
@@ -970,23 +968,26 @@ class SQLContext:
         >>> sqlCtx.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
         >>> sqlCtx.sql("SELECT stringLengthInt('test')").collect()
         [Row(c0=4)]
-        >>> sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType())
-        >>> sqlCtx.sql("SELECT twoArgs('test', 1)").collect()
-        [Row(c0=5)]
         """
         func = lambda _, it: imap(lambda x: f(*x), it)
         command = (func,
                    BatchedSerializer(PickleSerializer(), 1024),
                    BatchedSerializer(PickleSerializer(), 1024))
+        pickled_command = CloudPickleSerializer().dumps(command)
+        broadcast_vars = ListConverter().convert(
+            [x._jbroadcast for x in self._sc._pickled_broadcast_vars],
+            self._sc._gateway._gateway_client)
+        self._sc._pickled_broadcast_vars.clear()
         env = MapConverter().convert(self._sc.environment,
                                      self._sc._gateway._gateway_client)
         includes = ListConverter().convert(self._sc._python_includes,
                                            self._sc._gateway._gateway_client)
         self._ssql_ctx.registerPython(name,
-                                      bytearray(CloudPickleSerializer().dumps(command)),
+                                      bytearray(pickled_command),
                                       env,
                                       includes,
                                       self._sc.pythonExec,
+                                      broadcast_vars,
                                       self._sc._javaAccumulator,
                                       str(returnType))
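The ordering in this hunk is the subtle part: the command is pickled first, and only afterwards is `_pickled_broadcast_vars` read and cleared, because pickling is what fills that set. In PySpark of this era, serializing a Broadcast handle registers the handle with its context and emits only the broadcast id. A self-contained sketch of that side-effect pattern (`PickleRegistry`, `registry`, and `_from_id` are hypothetical stand-ins, not PySpark's real classes):

    import pickle

    class PickleRegistry:
        """Collects every broadcast handle pickled into a closure."""
        def __init__(self):
            self.vars = set()

        def add(self, bc):
            self.vars.add(bc)

        def clear(self):
            self.vars.clear()

    registry = PickleRegistry()  # plays the role of _pickled_broadcast_vars

    def _from_id(bid):
        # Worker-side stand-in: resolve a broadcast by id rather than
        # shipping its value inside the serialized closure.
        return ("resolved broadcast", bid)

    class Broadcast:
        def __init__(self, bid, value):
            self.bid, self.value = bid, value

        def __reduce__(self):
            # Side effect: pickling the handle records it in the registry,
            # and only its id travels with the payload.
            registry.add(self)
            return (_from_id, (self.bid,))

    bc = Broadcast(0, {"c": "abc"})
    payload = pickle.dumps(bc)    # must run before the registry is read
    assert bc in registry.vars    # dumping registered the handle
    registry.clear()              # mirrors _pickled_broadcast_vars.clear()
    print(pickle.loads(payload))  # ('resolved broadcast', 0)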
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index f1a75cbff5..3e74799e82 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -43,6 +43,7 @@ from pyspark.context import SparkContext
 from pyspark.files import SparkFiles
 from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer
 from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter
+from pyspark.sql import SQLContext, IntegerType
 
 _have_scipy = False
 _have_numpy = False
@@ -525,6 +526,27 @@ class TestRDDFunctions(PySparkTestCase):
         self.assertRaises(TypeError, lambda: rdd.histogram(2))
 
 
+class TestSQL(PySparkTestCase):
+
+    def setUp(self):
+        PySparkTestCase.setUp(self)
+        self.sqlCtx = SQLContext(self.sc)
+
+    def test_udf(self):
+        self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType())
+        [row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect()
+        self.assertEqual(row[0], 5)
+
+    def test_broadcast_in_udf(self):
+        bar = {"a": "aa", "b": "bb", "c": "abc"}
+        foo = self.sc.broadcast(bar)
+        self.sqlCtx.registerFunction("MYUDF", lambda x: foo.value[x] if x else '')
+        [res] = self.sqlCtx.sql("SELECT MYUDF('c')").collect()
+        self.assertEqual("abc", res[0])
+        [res] = self.sqlCtx.sql("SELECT MYUDF('')").collect()
+        self.assertEqual("", res[0])
+
+
 class TestIO(PySparkTestCase):
 
     def test_stdout_redirection(self):