about summary refs log tree commit diff
path: root/python/pyspark/tests.py
diff options
context:
space:
mode:
author: Bryan Cutler <cutlerb@gmail.com> 2017-03-03 16:43:45 -0800
committer: Joseph K. Bradley <joseph@databricks.com> 2017-03-03 16:43:45 -0800
commit: 44281ca81d4eda02b627ba21841108438b7d1c27 (patch)
tree: 4125cfa2e8dd98e247ae7240d88f3845ce871734 /python/pyspark/tests.py
parent: 2a7921a813ecd847fd933ffef10edc64684e9df7 (diff)
download: spark-44281ca81d4eda02b627ba21841108438b7d1c27.tar.gz
spark-44281ca81d4eda02b627ba21841108438b7d1c27.tar.bz2
spark-44281ca81d4eda02b627ba21841108438b7d1c27.zip
[SPARK-19348][PYTHON] PySpark keyword_only decorator is not thread-safe
## What changes were proposed in this pull request?

The `keyword_only` decorator in PySpark is not thread-safe. It writes kwargs to a static class variable in the decorator, which is then retrieved later in the class method as `_input_kwargs`. If multiple threads are constructing the same class with different kwargs, it becomes a race condition to read from the static class variable before it's overwritten. See [SPARK-19348](https://issues.apache.org/jira/browse/SPARK-19348) for reproduction code.

This change will write the kwargs to a member variable so that multiple threads can operate on separate instances without the race condition. It does not protect against multiple threads operating on a single instance, but that is better left to the user to synchronize.

## How was this patch tested?

Added new unit tests for using the keyword_only decorator and a regression test that verifies `_input_kwargs` can be overwritten from different class instances.

Author: Bryan Cutler <cutlerb@gmail.com>

Closes #16782 from BryanCutler/pyspark-keyword_only-threadsafe-SPARK-19348.
Diffstat (limited to 'python/pyspark/tests.py')
-rw-r--r--  python/pyspark/tests.py  39
1 file changed, 39 insertions, 0 deletions
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index a2aead7e6b..c6c87a9ea5 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -58,6 +58,7 @@ else:
from StringIO import StringIO
+from pyspark import keyword_only
from pyspark.conf import SparkConf
from pyspark.context import SparkContext
from pyspark.rdd import RDD
@@ -2161,6 +2162,44 @@ class ConfTests(unittest.TestCase):
sc.stop()
+class KeywordOnlyTests(unittest.TestCase):
+ class Wrapped(object):
+ @keyword_only
+ def set(self, x=None, y=None):
+ if "x" in self._input_kwargs:
+ self._x = self._input_kwargs["x"]
+ if "y" in self._input_kwargs:
+ self._y = self._input_kwargs["y"]
+ return x, y
+
+ def test_keywords(self):
+ w = self.Wrapped()
+ x, y = w.set(y=1)
+ self.assertEqual(y, 1)
+ self.assertEqual(y, w._y)
+ self.assertIsNone(x)
+ self.assertFalse(hasattr(w, "_x"))
+
+ def test_non_keywords(self):
+ w = self.Wrapped()
+ self.assertRaises(TypeError, lambda: w.set(0, y=1))
+
+ def test_kwarg_ownership(self):
+ # test _input_kwargs is owned by each class instance and not a shared static variable
+ class Setter(object):
+ @keyword_only
+ def set(self, x=None, other=None, other_x=None):
+ if "other" in self._input_kwargs:
+ self._input_kwargs["other"].set(x=self._input_kwargs["other_x"])
+ self._x = self._input_kwargs["x"]
+
+ a = Setter()
+ b = Setter()
+ a.set(x=1, other=b, other_x=2)
+ self.assertEqual(a._x, 1)
+ self.assertEqual(b._x, 2)
+
+
@unittest.skipIf(not _have_scipy, "SciPy not installed")
class SciPyTests(PySparkTestCase):