 python/pyspark/sql.py                                                    | 17
 python/pyspark/tests.py                                                  | 22
 sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala       |  3
 sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala  |  3
 4 files changed, 36 insertions, 9 deletions
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 44316926ba..aaa35dadc2 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -942,9 +942,7 @@ class SQLContext:
         self._jsc = self._sc._jsc
         self._jvm = self._sc._jvm
         self._pythonToJava = self._jvm.PythonRDD.pythonToJavaArray
-
-        if sqlContext:
-            self._scala_SQLContext = sqlContext
+        self._scala_SQLContext = sqlContext
 
     @property
     def _ssql_ctx(self):
@@ -953,7 +951,7 @@
         Subclasses can override this property to provide their own
         JVM Contexts.
         """
-        if not hasattr(self, '_scala_SQLContext'):
+        if self._scala_SQLContext is None:
             self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc())
         return self._scala_SQLContext
@@ -970,23 +968,26 @@ class SQLContext:
         >>> sqlCtx.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
         >>> sqlCtx.sql("SELECT stringLengthInt('test')").collect()
         [Row(c0=4)]
-        >>> sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType())
-        >>> sqlCtx.sql("SELECT twoArgs('test', 1)").collect()
-        [Row(c0=5)]
         """
         func = lambda _, it: imap(lambda x: f(*x), it)
         command = (func,
                    BatchedSerializer(PickleSerializer(), 1024),
                    BatchedSerializer(PickleSerializer(), 1024))
+        pickled_command = CloudPickleSerializer().dumps(command)
+        broadcast_vars = ListConverter().convert(
+            [x._jbroadcast for x in self._sc._pickled_broadcast_vars],
+            self._sc._gateway._gateway_client)
+        self._sc._pickled_broadcast_vars.clear()
         env = MapConverter().convert(self._sc.environment,
                                      self._sc._gateway._gateway_client)
         includes = ListConverter().convert(self._sc._python_includes,
                                            self._sc._gateway._gateway_client)
         self._ssql_ctx.registerPython(name,
-                                      bytearray(CloudPickleSerializer().dumps(command)),
+                                      bytearray(pickled_command),
                                       env,
                                       includes,
                                       self._sc.pythonExec,
+                                      broadcast_vars,
                                       self._sc._javaAccumulator,
                                       str(returnType))
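
For context: the registration path above cloudpickles the Python function and now also forwards any broadcast variables the function captured (collected from self._sc._pickled_broadcast_vars after pickling) to the JVM via registerPython. A minimal sketch of the usage pattern this enables, mirroring the new test below; the names sc, sqlCtx, and lookup are illustrative and assume an existing SparkContext and SQLContext:

    # Illustrative only -- assumes `sc` (SparkContext) and `sqlCtx` (SQLContext) already exist.
    lookup = sc.broadcast({"a": 1, "b": 2})
    # The lambda closes over `lookup`; the broadcast handle is shipped along with the pickled command.
    sqlCtx.registerFunction("lookupValue", lambda k: lookup.value.get(k, 0), IntegerType())
    sqlCtx.sql("SELECT lookupValue('a')").collect()  # expected: [Row(c0=1)]
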
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index f1a75cbff5..3e74799e82 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -43,6 +43,7 @@ from pyspark.context import SparkContext
 from pyspark.files import SparkFiles
 from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer
 from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter
+from pyspark.sql import SQLContext, IntegerType
 
 _have_scipy = False
 _have_numpy = False
@@ -525,6 +526,27 @@ class TestRDDFunctions(PySparkTestCase):
         self.assertRaises(TypeError, lambda: rdd.histogram(2))
 
 
+class TestSQL(PySparkTestCase):
+
+    def setUp(self):
+        PySparkTestCase.setUp(self)
+        self.sqlCtx = SQLContext(self.sc)
+
+    def test_udf(self):
+        self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType())
+        [row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect()
+        self.assertEqual(row[0], 5)
+
+    def test_broadcast_in_udf(self):
+        bar = {"a": "aa", "b": "bb", "c": "abc"}
+        foo = self.sc.broadcast(bar)
+        self.sqlCtx.registerFunction("MYUDF", lambda x: foo.value[x] if x else '')
+        [res] = self.sqlCtx.sql("SELECT MYUDF('c')").collect()
+        self.assertEqual("abc", res[0])
+        [res] = self.sqlCtx.sql("SELECT MYUDF('')").collect()
+        self.assertEqual("", res[0])
+
+
 class TestIO(PySparkTestCase):
 
     def test_stdout_redirection(self):
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala
index 0b48e9e659..0ea1105f08 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql
 import java.util.{List => JList, Map => JMap}
 
 import org.apache.spark.Accumulator
+import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUdf}
 import org.apache.spark.sql.execution.PythonUDF
@@ -38,6 +39,7 @@ protected[sql] trait UDFRegistration {
       envVars: JMap[String, String],
       pythonIncludes: JList[String],
       pythonExec: String,
+      broadcastVars: JList[Broadcast[Array[Byte]]],
       accumulator: Accumulator[JList[Array[Byte]]],
       stringDataType: String): Unit = {
     log.debug(
@@ -61,6 +63,7 @@
       envVars,
       pythonIncludes,
       pythonExec,
+      broadcastVars,
       accumulator,
       dataType,
       e)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
index 3dc8be2456..0977da3e85 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
@@ -42,6 +42,7 @@ private[spark] case class PythonUDF(
     envVars: JMap[String, String],
     pythonIncludes: JList[String],
     pythonExec: String,
+    broadcastVars: JList[Broadcast[Array[Byte]]],
     accumulator: Accumulator[JList[Array[Byte]]],
     dataType: DataType,
     children: Seq[Expression]) extends Expression with SparkLogging {
@@ -145,7 +146,7 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
       udf.pythonIncludes,
       false,
       udf.pythonExec,
-      Seq[Broadcast[Array[Byte]]](),
+      udf.broadcastVars,
       udf.accumulator
     ).mapPartitions { iter =>
       val pickle = new Unpickler