 core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala | 20 +++++++++++++++-----
 python/pyspark/accumulators.py                                  | 34 +++++++++++++++++++++++++++-------
 2 files changed, 42 insertions(+), 12 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index a9d758bf99..94d666aa92 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -731,19 +731,30 @@ private class PythonAccumulatorParam(@transient serverHost: String, serverPort:
val bufferSize = SparkEnv.get.conf.getInt("spark.buffer.size", 65536)
+ /**
+ * We try to reuse a single Socket to transfer accumulator updates, as they are all added
+ * by the DAGScheduler's single-threaded actor anyway.
+ */
+ @transient var socket: Socket = _
+
+ def openSocket(): Socket = synchronized {
+ if (socket == null || socket.isClosed) {
+ socket = new Socket(serverHost, serverPort)
+ }
+ socket
+ }
+
override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList
override def addInPlace(val1: JList[Array[Byte]], val2: JList[Array[Byte]])
- : JList[Array[Byte]] = {
+ : JList[Array[Byte]] = synchronized {
if (serverHost == null) {
// This happens on the worker node, where we just want to remember all the updates
val1.addAll(val2)
val1
} else {
// This happens on the master, where we pass the updates to Python through a socket
- val socket = new Socket(serverHost, serverPort)
- // SPARK-2282: Immediately reuse closed sockets because we create one per task.
- socket.setReuseAddress(true)
+ val socket = openSocket()
val in = socket.getInputStream
val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream, bufferSize))
out.writeInt(val2.size)
@@ -757,7 +768,6 @@ private class PythonAccumulatorParam(@transient serverHost: String, serverPort:
if (byteRead == -1) {
throw new SparkException("EOF reached before Python server acknowledged")
}
- socket.close()
null
}
}
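
The Scala half of the change keeps the wire protocol intact and only changes the socket's lifetime: addInPlace writes a 4-byte big-endian count, then each update as a length-prefixed pickled byte array, and blocks until it reads a one-byte acknowledgement. As a minimal sketch to make that framing concrete, here is a Python client speaking the same protocol over a reused socket; the helper name send_updates and the pre-pickled payloads argument are illustrative, not part of the patch:

import struct

def send_updates(sock, payloads):
    """Send pre-pickled accumulator updates over an already-open socket.

    Illustrative sketch: mirrors the Scala side's out.writeInt(val2.size)
    followed by length-prefixed payloads and a one-byte acknowledgement.
    """
    out = sock.makefile("wb")
    # 4-byte big-endian count, matching DataOutputStream.writeInt.
    out.write(struct.pack("!i", len(payloads)))
    for payload in payloads:
        # Each update is length-prefixed, matching pickleSer._read_with_length.
        out.write(struct.pack("!i", len(payload)))
        out.write(payload)
    out.flush()
    # Block for the one-byte acknowledgement; EOF means the server died.
    ack = sock.recv(1)
    if not ack:
        raise IOError("EOF reached before Python server acknowledged")

Because addInPlace is now synchronized and the socket is shared, the acknowledgement read also serializes batches: a second batch cannot start writing until the first one has been acknowledged.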
diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py
index 2204e9c9ca..45d36e5d0e 100644
--- a/python/pyspark/accumulators.py
+++ b/python/pyspark/accumulators.py
@@ -86,6 +86,7 @@ Traceback (most recent call last):
Exception:...
"""
+import select
import struct
import SocketServer
import threading
@@ -209,19 +210,38 @@ COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j)
class _UpdateRequestHandler(SocketServer.StreamRequestHandler):
+ """
+ This handler will keep polling for updates on the same socket until the
+ server is shut down.
+ """
+
def handle(self):
from pyspark.accumulators import _accumulatorRegistry
- num_updates = read_int(self.rfile)
- for _ in range(num_updates):
- (aid, update) = pickleSer._read_with_length(self.rfile)
- _accumulatorRegistry[aid] += update
- # Write a byte in acknowledgement
- self.wfile.write(struct.pack("!b", 1))
+ while not self.server.server_shutdown:
+ # Poll every 1 second for new data -- don't block in case of shutdown.
+ r, _, _ = select.select([self.rfile], [], [], 1)
+ if self.rfile in r:
+ num_updates = read_int(self.rfile)
+ for _ in range(num_updates):
+ (aid, update) = pickleSer._read_with_length(self.rfile)
+ _accumulatorRegistry[aid] += update
+ # Write a byte in acknowledgement
+ self.wfile.write(struct.pack("!b", 1))
+
+class AccumulatorServer(SocketServer.TCPServer):
+ """
+ A simple TCP server that intercepts shutdown() in order to interrupt
+ our continuous polling in the handler.
+ """
+ server_shutdown = False
+ def shutdown(self):
+ self.server_shutdown = True
+ SocketServer.TCPServer.shutdown(self)
def _start_update_server():
"""Start a TCP server to receive accumulator updates in a daemon thread, and returns it"""
- server = SocketServer.TCPServer(("localhost", 0), _UpdateRequestHandler)
+ server = AccumulatorServer(("localhost", 0), _UpdateRequestHandler)
thread = threading.Thread(target=server.serve_forever)
thread.daemon = True
thread.start()
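
End to end, the Python side now runs one long-lived handler per connection: it polls with select so it can notice the shutdown flag within a second instead of blocking forever on a read. Below is a self-contained sketch of that poll-until-shutdown pattern, runnable on its own under Python 2 with the same SocketServer module the patch targets; the class names and the one-byte echo body are illustrative, not part of the patch:

import select
import struct
import threading
import SocketServer

class _PollingHandler(SocketServer.StreamRequestHandler):
    # Read unbuffered so select() on the underlying socket never misses
    # bytes that would otherwise sit in rfile's internal buffer.
    rbufsize = 0

    def handle(self):
        # Loop until the server raises its shutdown flag.
        while not self.server.server_shutdown:
            # Wait at most 1 second for data so shutdown is noticed promptly.
            r, _, _ = select.select([self.rfile], [], [], 1)
            if self.rfile in r:
                data = self.rfile.read(1)
                if not data:
                    return  # client closed the connection
                # Acknowledge with a single byte, as the accumulator server does.
                self.wfile.write(struct.pack("!b", 1))

class PollingServer(SocketServer.TCPServer):
    server_shutdown = False
    def shutdown(self):
        # Raise the flag first so handle() loops exit, then stop serve_forever().
        self.server_shutdown = True
        SocketServer.TCPServer.shutdown(self)

if __name__ == "__main__":
    server = PollingServer(("localhost", 0), _PollingHandler)
    thread = threading.Thread(target=server.serve_forever)
    thread.daemon = True
    thread.start()
    server.shutdown()  # unblocks serve_forever() and ends any handler loops

The rbufsize override is the one liberty taken beyond the patch: select() watches only the underlying socket descriptor, so with the default buffered rfile, already-buffered bytes would not wake it. The accumulator server avoids the problem differently, by reading each batch completely before polling again while the Scala client waits for the acknowledgement before sending more.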