aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/tests.py
diff options
context:
space:
mode:
authorShixiong Zhu <shixiong@databricks.com>2016-03-06 08:57:01 -0800
committerDavies Liu <davies.liu@gmail.com>2016-03-06 08:57:01 -0800
commitee913e6e2d58dfac20f3f06ff306081bd0e48066 (patch)
tree262f35891a14f8ae3cca03c4700c341fa3239bf6 /python/pyspark/tests.py
parent8ff88094daa4945e7d718baa7b20703fd8087ab0 (diff)
downloadspark-ee913e6e2d58dfac20f3f06ff306081bd0e48066.tar.gz
spark-ee913e6e2d58dfac20f3f06ff306081bd0e48066.tar.bz2
spark-ee913e6e2d58dfac20f3f06ff306081bd0e48066.zip
[SPARK-13697] [PYSPARK] Fix the missing module name of TransformFunctionSerializer.loads
## What changes were proposed in this pull request? Set the function's module name to `__main__` if it's missing in `TransformFunctionSerializer.loads`. ## How was this patch tested? Manually test in the shell. Before this patch: ``` >>> from pyspark.streaming import StreamingContext >>> from pyspark.streaming.util import TransformFunction >>> ssc = StreamingContext(sc, 1) >>> func = TransformFunction(sc, lambda x: x, sc.serializer) >>> func.rdd_wrapper(lambda x: x) TransformFunction(<function <lambda> at 0x106ac8b18>) >>> bytes = bytearray(ssc._transformerSerializer.serializer.dumps((func.func, func.rdd_wrap_func, func.deserializers))) >>> func2 = ssc._transformerSerializer.loads(bytes) >>> print(func2.func.__module__) None >>> print(func2.rdd_wrap_func.__module__) None >>> ``` After this patch: ``` >>> from pyspark.streaming import StreamingContext >>> from pyspark.streaming.util import TransformFunction >>> ssc = StreamingContext(sc, 1) >>> func = TransformFunction(sc, lambda x: x, sc.serializer) >>> func.rdd_wrapper(lambda x: x) TransformFunction(<function <lambda> at 0x108bf1b90>) >>> bytes = bytearray(ssc._transformerSerializer.serializer.dumps((func.func, func.rdd_wrap_func, func.deserializers))) >>> func2 = ssc._transformerSerializer.loads(bytes) >>> print(func2.func.__module__) __main__ >>> print(func2.rdd_wrap_func.__module__) __main__ >>> ``` Author: Shixiong Zhu <shixiong@databricks.com> Closes #11535 from zsxwing/loads-module.
Diffstat (limited to 'python/pyspark/tests.py')
-rw-r--r--python/pyspark/tests.py6
1 files changed, 6 insertions, 0 deletions
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 23720502a8..a5a83c7e38 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -228,6 +228,12 @@ class SerializationTestCase(unittest.TestCase):
getter2 = ser.loads(ser.dumps(getter))
self.assertEqual(getter(d), getter2(d))
+ def test_function_module_name(self):
+ ser = CloudPickleSerializer()
+ func = lambda x: x
+ func2 = ser.loads(ser.dumps(func))
+ self.assertEqual(func.__module__, func2.__module__)
+
def test_attrgetter(self):
from operator import attrgetter
ser = CloudPickleSerializer()