diff options
Diffstat (limited to 'python/pyspark/streaming')
-rw-r--r-- | python/pyspark/streaming/context.py | 54 | ||||
-rw-r--r-- | python/pyspark/streaming/flume.py | 2 | ||||
-rw-r--r-- | python/pyspark/streaming/kafka.py | 2 | ||||
-rw-r--r-- | python/pyspark/streaming/kinesis.py | 2 | ||||
-rw-r--r-- | python/pyspark/streaming/mqtt.py | 2 | ||||
-rw-r--r-- | python/pyspark/streaming/tests.py | 18 |
6 files changed, 24 insertions, 56 deletions
diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index a8c9ffc235..975c754732 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -32,48 +32,6 @@ from pyspark.streaming.util import TransformFunction, TransformFunctionSerialize __all__ = ["StreamingContext"] -def _daemonize_callback_server(): - """ - Hack Py4J to daemonize callback server - - The thread of callback server has daemon=False, it will block the driver - from exiting if it's not shutdown. The following code replace `start()` - of CallbackServer with a new version, which set daemon=True for this - thread. - - Also, it will update the port number (0) with real port - """ - # TODO: create a patch for Py4J - import socket - import py4j.java_gateway - logger = py4j.java_gateway.logger - from py4j.java_gateway import Py4JNetworkError - from threading import Thread - - def start(self): - """Starts the CallbackServer. This method should be called by the - client instead of run().""" - self.server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, - 1) - try: - self.server_socket.bind((self.address, self.port)) - if not self.port: - # update port with real port - self.port = self.server_socket.getsockname()[1] - except Exception as e: - msg = 'An error occurred while trying to start the callback server: %s' % e - logger.exception(msg) - raise Py4JNetworkError(msg) - - # Maybe thread needs to be cleanup up? - self.thread = Thread(target=self.run) - self.thread.daemon = True - self.thread.start() - - py4j.java_gateway.CallbackServer.start = start - - class StreamingContext(object): """ Main entry point for Spark Streaming functionality. A StreamingContext @@ -123,10 +81,14 @@ class StreamingContext(object): # start callback server # getattr will fallback to JVM, so we cannot test by hasattr() - if "_callback_server" not in gw.__dict__: - _daemonize_callback_server() - # use random port - gw._start_callback_server(0) + if "_callback_server" not in gw.__dict__ or gw._callback_server is None: + gw.callback_server_parameters.eager_load = True + gw.callback_server_parameters.daemonize = True + gw.callback_server_parameters.daemonize_connections = True + gw.callback_server_parameters.port = 0 + gw.start_callback_server(gw.callback_server_parameters) + cbport = gw._callback_server.server_socket.getsockname()[1] + gw._callback_server.port = cbport # gateway with real port gw._python_proxy_port = gw._callback_server.port # get the GatewayServer object in JVM by ID diff --git a/python/pyspark/streaming/flume.py b/python/pyspark/streaming/flume.py index c0cdc50d8d..b3d1905365 100644 --- a/python/pyspark/streaming/flume.py +++ b/python/pyspark/streaming/flume.py @@ -20,7 +20,7 @@ if sys.version >= "3": from io import BytesIO else: from StringIO import StringIO -from py4j.java_gateway import Py4JJavaError +from py4j.protocol import Py4JJavaError from pyspark.storagelevel import StorageLevel from pyspark.serializers import PairDeserializer, NoOpSerializer, UTF8Deserializer, read_int diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py index 8a814c64c0..b35bbaf404 100644 --- a/python/pyspark/streaming/kafka.py +++ b/python/pyspark/streaming/kafka.py @@ -15,7 +15,7 @@ # limitations under the License. # -from py4j.java_gateway import Py4JJavaError +from py4j.protocol import Py4JJavaError from pyspark.rdd import RDD from pyspark.storagelevel import StorageLevel diff --git a/python/pyspark/streaming/kinesis.py b/python/pyspark/streaming/kinesis.py index 34be5880e1..af72c3d690 100644 --- a/python/pyspark/streaming/kinesis.py +++ b/python/pyspark/streaming/kinesis.py @@ -15,7 +15,7 @@ # limitations under the License. # -from py4j.java_gateway import Py4JJavaError +from py4j.protocol import Py4JJavaError from pyspark.serializers import PairDeserializer, NoOpSerializer from pyspark.storagelevel import StorageLevel diff --git a/python/pyspark/streaming/mqtt.py b/python/pyspark/streaming/mqtt.py index fa83006c36..1ce4093196 100644 --- a/python/pyspark/streaming/mqtt.py +++ b/python/pyspark/streaming/mqtt.py @@ -15,7 +15,7 @@ # limitations under the License. # -from py4j.java_gateway import Py4JJavaError +from py4j.protocol import Py4JJavaError from pyspark.storagelevel import StorageLevel from pyspark.serializers import UTF8Deserializer diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index e4e56fff3b..49634252fd 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -61,9 +61,12 @@ class PySparkStreamingTestCase(unittest.TestCase): def tearDownClass(cls): cls.sc.stop() # Clean up in the JVM just in case there has been some issues in Python API - jSparkContextOption = SparkContext._jvm.SparkContext.get() - if jSparkContextOption.nonEmpty(): - jSparkContextOption.get().stop() + try: + jSparkContextOption = SparkContext._jvm.SparkContext.get() + if jSparkContextOption.nonEmpty(): + jSparkContextOption.get().stop() + except: + pass def setUp(self): self.ssc = StreamingContext(self.sc, self.duration) @@ -72,9 +75,12 @@ class PySparkStreamingTestCase(unittest.TestCase): if self.ssc is not None: self.ssc.stop(False) # Clean up in the JVM just in case there has been some issues in Python API - jStreamingContextOption = StreamingContext._jvm.SparkContext.getActive() - if jStreamingContextOption.nonEmpty(): - jStreamingContextOption.get().stop(False) + try: + jStreamingContextOption = StreamingContext._jvm.SparkContext.getActive() + if jStreamingContextOption.nonEmpty(): + jStreamingContextOption.get().stop(False) + except: + pass def wait_for(self, result, n): start_time = time.time() |