author     zero323 <zero323@users.noreply.github.com>  2017-01-31 18:03:39 -0800
committer  Holden Karau <holden@us.ibm.com>            2017-01-31 18:03:39 -0800
commit     9063835803e54538c94d95bbddcb4810dd7a1c55 (patch)
tree       d22369466eca165a1be27c44de96bef1ed9b8b3f /python
parent     081b7addaf9560563af0ce25912972e91a78cee6 (diff)
[SPARK-19163][PYTHON][SQL] Delay _judf initialization to the __call__
## What changes were proposed in this pull request?

Defer `UserDefinedFunction._judf` initialization to the first call. This prevents unintended `SparkSession` initialization and allows users to define and import a UDF without creating a context / session as a side effect. [SPARK-19163](https://issues.apache.org/jira/browse/SPARK-19163)

## How was this patch tested?

Unit tests.

Author: zero323 <zero323@users.noreply.github.com>

Closes #16536 from zero323/SPARK-19163.
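For illustration, a minimal sketch of the behavior this change enables (assumes `pyspark` is importable and that no SparkContext or SparkSession exists yet; `to_upper` is a hypothetical example UDF, not part of this patch):

```python
# After this change, constructing a UDF is safe at import time:
# it no longer creates a SparkContext / SparkSession as a side effect.
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType

# Hypothetical example UDF. Before this patch, this line would have
# eagerly initialized a SparkSession via UserDefinedFunction._create_judf.
to_upper = udf(lambda s: s.upper() if s is not None else None, StringType())
```

The underlying Java UDF (`_judf`) is only created on the first call, e.g. `df.select(to_upper(df.name))`, at which point a session must exist.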
Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/sql/functions.py | 35
-rw-r--r--  python/pyspark/sql/tests.py     | 44
2 files changed, 68 insertions(+), 11 deletions(-)
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 66d993a814..02c2350dc2 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1826,25 +1826,38 @@ class UserDefinedFunction(object):
     def __init__(self, func, returnType, name=None):
         self.func = func
         self.returnType = returnType
-        self._judf = self._create_judf(name)
-
-    def _create_judf(self, name):
+        # Stores the UserDefinedPythonFunction jobj, once initialized
+        self._judf_placeholder = None
+        self._name = name or (
+            func.__name__ if hasattr(func, '__name__')
+            else func.__class__.__name__)
+
+    @property
+    def _judf(self):
+        # It is possible that concurrent access to a newly created UDF
+        # will initialize multiple UserDefinedPythonFunctions.
+        # This is unlikely, doesn't affect correctness,
+        # and should have a minimal performance impact.
+        if self._judf_placeholder is None:
+            self._judf_placeholder = self._create_judf()
+        return self._judf_placeholder
+
+    def _create_judf(self):
         from pyspark.sql import SparkSession
-        sc = SparkContext.getOrCreate()
-        wrapped_func = _wrap_function(sc, self.func, self.returnType)
+
         spark = SparkSession.builder.getOrCreate()
+        sc = spark.sparkContext
+
+        wrapped_func = _wrap_function(sc, self.func, self.returnType)
         jdt = spark._jsparkSession.parseDataType(self.returnType.json())
-        if name is None:
-            f = self.func
-            name = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__
         judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
-            name, wrapped_func, jdt)
+            self._name, wrapped_func, jdt)
         return judf

     def __call__(self, *cols):
+        judf = self._judf
         sc = SparkContext._active_spark_context
-        jc = self._judf.apply(_to_seq(sc, cols, _to_java_column))
-        return Column(jc)
+        return Column(judf.apply(_to_seq(sc, cols, _to_java_column)))


 @since(1.3)
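As an aside, the `_judf` property above is an instance of a common lazy-initialization idiom. A minimal standalone sketch, with illustrative names that are not part of the Spark API:

```python
class LazyHandle(object):
    """Illustrative sketch only: defers building an expensive object to first use."""

    def __init__(self, factory):
        self._factory = factory
        self._placeholder = None  # nothing built yet

    @property
    def value(self):
        # Benign race, as in _judf above: two threads may both observe None
        # and both invoke the factory; the last assignment wins. This is
        # acceptable when the factory is idempotent, as UserDefinedPythonFunction
        # creation is here.
        if self._placeholder is None:
            self._placeholder = self._factory()
        return self._placeholder
```

In the patch, `UserDefinedFunction` plays the role of `LazyHandle`, with `_create_judf` as the factory and `_judf_placeholder` as the cached value.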
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index a88e5a1cfb..2fea4ac41f 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -468,6 +468,27 @@ class SQLTests(ReusedPySparkTestCase):
         row2 = df2.select(sameText(df2['file'])).first()
         self.assertTrue(row2[0].find("people.json") != -1)

+    def test_udf_defers_judf_initialization(self):
+        # This is kept separate from UDFInitializationTests
+        # because calling the UDF below requires
+        # an initialized SparkContext.
+
+        from pyspark.sql.functions import UserDefinedFunction
+
+        f = UserDefinedFunction(lambda x: x, StringType())
+
+        self.assertIsNone(
+            f._judf_placeholder,
+            "judf should not be initialized before the first call."
+        )
+
+        self.assertIsInstance(f("foo"), Column, "UDF call should return a Column.")
+
+        self.assertIsNotNone(
+            f._judf_placeholder,
+            "judf should be initialized after UDF has been called."
+        )
+
     def test_basic_functions(self):
         rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
         df = self.spark.read.json(rdd)
@@ -1947,6 +1968,29 @@ class SQLTests2(ReusedPySparkTestCase):
         df.collect()


+class UDFInitializationTests(unittest.TestCase):
+    def tearDown(self):
+        if SparkSession._instantiatedSession is not None:
+            SparkSession._instantiatedSession.stop()
+
+        if SparkContext._active_spark_context is not None:
+            SparkContext._active_spark_context.stop()
+
+    def test_udf_init_shouldnt_initialize_context(self):
+        from pyspark.sql.functions import UserDefinedFunction
+
+        UserDefinedFunction(lambda x: x, StringType())
+
+        self.assertIsNone(
+            SparkContext._active_spark_context,
+            "SparkContext shouldn't be initialized when UserDefinedFunction is created."
+        )
+        self.assertIsNone(
+            SparkSession._instantiatedSession,
+            "SparkSession shouldn't be initialized when UserDefinedFunction is created."
+        )
+
+
 class HiveContextSQLTests(ReusedPySparkTestCase):

     @classmethod