Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/streaming/context.py  33
-rw-r--r--  python/pyspark/streaming/util.py       3
2 files changed, 26 insertions, 10 deletions
diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py
index 3deed52be0..5cc4bbde39 100644
--- a/python/pyspark/streaming/context.py
+++ b/python/pyspark/streaming/context.py
@@ -98,8 +98,28 @@ class StreamingContext(object):
         # register serializer for TransformFunction
         # it happens before creating SparkContext when loading from checkpointing
-        cls._transformerSerializer = TransformFunctionSerializer(
-            SparkContext._active_spark_context, CloudPickleSerializer(), gw)
+        if cls._transformerSerializer is None:
+            transformer_serializer = TransformFunctionSerializer()
+            transformer_serializer.init(
+                SparkContext._active_spark_context, CloudPickleSerializer(), gw)
+            # SPARK-12511 streaming driver with checkpointing unable to finalize leading to OOM
+            # There is an issue that Py4J's PythonProxyHandler.finalize blocks forever.
+            # (https://github.com/bartdag/py4j/pull/184)
+            #
+            # Py4j will create a PythonProxyHandler in Java for "transformer_serializer" when
+            # calling "registerSerializer". If we call "registerSerializer" twice, the second
+            # PythonProxyHandler will override the first one, then the first one will be GCed and
+            # trigger "PythonProxyHandler.finalize". To avoid that, we should not call
+            # "registerSerializer" more than once, so that "PythonProxyHandler" in Java side won't
+            # be GCed.
+            #
+            # TODO Once Py4J fixes this issue, we should upgrade Py4j to the latest version.
+            transformer_serializer.gateway.jvm.PythonDStream.registerSerializer(
+                transformer_serializer)
+            cls._transformerSerializer = transformer_serializer
+        else:
+            cls._transformerSerializer.init(
+                SparkContext._active_spark_context, CloudPickleSerializer(), gw)
     @classmethod
     def getOrCreate(cls, checkpointPath, setupFunc):
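Note: the guard added in this hunk amounts to a register-once pattern around the Py4J proxy. Below is a minimal, self-contained sketch of that pattern (not the Spark source); register_with_jvm stands in for the Py4J call PythonDStream.registerSerializer, which is what creates the Java-side PythonProxyHandler.

    class ReRegistrableSerializer(object):
        # Stand-in for TransformFunctionSerializer: its fields can be re-bound
        # via init() without creating a new Java-side proxy.
        def init(self, ctx, serializer, gateway):
            self.ctx = ctx
            self.serializer = serializer
            self.gateway = gateway


    class Holder(object):
        _instance = None  # class-level singleton, survives repeated initialization

        @classmethod
        def ensure_initialized(cls, ctx, serializer, gateway, register_with_jvm):
            if cls._instance is None:
                inst = ReRegistrableSerializer()
                inst.init(ctx, serializer, gateway)
                # Register exactly once: a second registration would replace the
                # Java-side PythonProxyHandler and let the first one be GCed,
                # triggering the blocking finalize() described in the comment above.
                register_with_jvm(inst)
                cls._instance = inst
            else:
                # Re-initialization (e.g. restart from a checkpoint): refresh only
                # the Python-side state; the JVM keeps its original proxy object.
                cls._instance.init(ctx, serializer, gateway)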
@@ -116,16 +136,13 @@ class StreamingContext(object):
         gw = SparkContext._gateway

         # Check whether valid checkpoint information exists in the given path
-        if gw.jvm.CheckpointReader.read(checkpointPath).isEmpty():
+        ssc_option = gw.jvm.StreamingContextPythonHelper().tryRecoverFromCheckpoint(checkpointPath)
+        if ssc_option.isEmpty():
             ssc = setupFunc()
             ssc.checkpoint(checkpointPath)
             return ssc

-        try:
-            jssc = gw.jvm.JavaStreamingContext(checkpointPath)
-        except Exception:
-            print("failed to load StreamingContext from checkpoint", file=sys.stderr)
-            raise
+        jssc = gw.jvm.JavaStreamingContext(ssc_option.get())

         # If there is already an active instance of Python SparkContext use it, or create a new one
         if not SparkContext._active_spark_context:
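For context, the recovery path changed in this hunk is normally driven through StreamingContext.getOrCreate from application code. A usage sketch follows; the checkpoint directory, host, and port are illustrative values, not taken from this change.

    from pyspark import SparkContext
    from pyspark.streaming import StreamingContext

    checkpoint_dir = "hdfs:///tmp/streaming-checkpoint"  # illustrative path

    def create_context():
        # Runs only when no usable checkpoint exists; otherwise the context
        # recovered by tryRecoverFromCheckpoint is wrapped and returned instead.
        sc = SparkContext(appName="CheckpointedApp")
        ssc = StreamingContext(sc, 5)  # 5 second batch interval
        lines = ssc.socketTextStream("localhost", 9999)
        lines.count().pprint()
        return ssc

    ssc = StreamingContext.getOrCreate(checkpoint_dir, create_context)
    ssc.start()
    ssc.awaitTermination()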
diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py
index abbbf6eb93..e617fc9ce9 100644
--- a/python/pyspark/streaming/util.py
+++ b/python/pyspark/streaming/util.py
@@ -89,11 +89,10 @@ class TransformFunctionSerializer(object):
     it uses this class to invoke Python, which returns the serialized function
     as a byte array.
     """
-    def __init__(self, ctx, serializer, gateway=None):
+    def init(self, ctx, serializer, gateway=None):
         self.ctx = ctx
         self.serializer = serializer
         self.gateway = gateway or self.ctx._gateway
-        self.gateway.jvm.PythonDStream.registerSerializer(self)
         self.failure = None

     def dumps(self, id):
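As the docstring above notes, this class hands the JVM a serialized Python function as a byte array. A rough sketch of that round trip, using CloudPickleSerializer directly and omitting the id-based lookup the real dumps(id) performs through the Py4J object pool; transform is an illustrative function:

    from pyspark.serializers import CloudPickleSerializer

    ser = CloudPickleSerializer()

    def transform(rdd):
        # illustrative DStream transform function
        return rdd.map(lambda x: x * 2)

    payload = bytearray(ser.dumps(transform))  # byte array handed to the JVM
    restored = ser.loads(bytes(payload))       # rebuilt on the Python side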