author     Davies Liu <davies@databricks.com>     2015-03-09 16:24:06 -0700
committer  Josh Rosen <joshrosen@databricks.com>  2015-03-09 16:24:06 -0700
commit     8767565cef01d847f57b7293d8b63b2422009b90 (patch)
tree       1204ac7a7cda19b30e2a990ae2ded5f5b40b8c3f
parent     3cac1991a1def0adaf42face2c578d3ab8c27025 (diff)
[SPARK-6194] [SPARK-677] [PySpark] fix memory leak in collect()
Because of a circular reference between JavaObject and JavaMember, a Java object cannot be released until Python's cyclic GC kicks in, which causes a memory leak in collect() that may consume lots of memory in the JVM. This PR changes the way collected data is sent back to Python, from a local file to a socket, which avoids any disk IO during collect() and avoids keeping any Python-side referrers to the Java object.

cc JoshRosen

Author: Davies Liu <davies@databricks.com>

Closes #4923 from davies/fix_collect and squashes the following commits:

d730286 [Davies Liu] address comments
24c92a4 [Davies Liu] fix style
ba54614 [Davies Liu] use socket to transfer data from JVM
9517c8f [Davies Liu] fix memory leak in collect()
-rw-r--r--  core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala |  76
-rw-r--r--  python/pyspark/context.py                                       |  13
-rw-r--r--  python/pyspark/rdd.py                                           |  30
-rw-r--r--  python/pyspark/sql/dataframe.py                                 |  14
4 files changed, 82 insertions(+), 51 deletions(-)
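Why the old path leaked: PySpark held the collected data as a Py4J iterator, and Py4J's JavaObject/JavaMember pair reference each other, so CPython's reference counting alone never frees them; whatever the object pins on the JVM side stays alive until the cyclic garbage collector eventually runs. A minimal, PySpark-independent sketch of that failure mode (the class names below only mimic Py4J's; the real ones live in py4j.java_gateway):

    import gc
    import weakref

    # Stand-ins for Py4J's JavaObject/JavaMember; the back-reference
    # from member to owner creates the cycle described above.
    class JavaObject:
        def __init__(self):
            self.member = JavaMember(self)

    class JavaMember:
        def __init__(self, owner):
            self.owner = owner

    obj = JavaObject()
    probe = weakref.ref(obj)
    del obj                      # drop the last external reference

    # Reference counting cannot reclaim the cycle, so anything the
    # object pins (e.g. a JVM-side result set) stays alive too.
    print(probe() is not None)   # True

    gc.collect()                 # only the cyclic collector frees it
    print(probe() is None)       # True

The change below sidesteps the cycle entirely: the JVM serves the collected bytes over a short-lived local socket and Python reads from that socket, so no Java object reference outlives the call.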
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index b1cec0f647..8d4a53b4ca 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -19,26 +19,27 @@ package org.apache.spark.api.python
import java.io._
import java.net._
-import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, UUID, Collections}
-
-import org.apache.spark.input.PortableDataStream
+import java.util.{Collections, ArrayList => JArrayList, List => JList, Map => JMap}
import scala.collection.JavaConversions._
import scala.collection.mutable
import scala.language.existentials
import com.google.common.base.Charsets.UTF_8
-
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.compress.CompressionCodec
-import org.apache.hadoop.mapred.{InputFormat, OutputFormat, JobConf}
+import org.apache.hadoop.mapred.{InputFormat, JobConf, OutputFormat}
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, OutputFormat => NewOutputFormat}
+
import org.apache.spark._
-import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
+import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext}
import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.input.PortableDataStream
import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils
+import scala.util.control.NonFatal
+
private[spark] class PythonRDD(
@transient parent: RDD[_],
command: Array[Byte],
@@ -341,21 +342,33 @@ private[spark] object PythonRDD extends Logging {
/**
* Adapter for calling SparkContext#runJob from Python.
*
- * This method will return an iterator of an array that contains all elements in the RDD
+ * This method will serve an iterator of an array that contains all elements in the RDD
* (effectively a collect()), but allows you to run on a certain subset of partitions,
* or to enable local execution.
+ *
+ * @return the port number of a local socket which serves the data collected from this job.
*/
def runJob(
sc: SparkContext,
rdd: JavaRDD[Array[Byte]],
partitions: JArrayList[Int],
- allowLocal: Boolean): Iterator[Array[Byte]] = {
+ allowLocal: Boolean): Int = {
type ByteArray = Array[Byte]
type UnrolledPartition = Array[ByteArray]
val allPartitions: Array[UnrolledPartition] =
sc.runJob(rdd, (x: Iterator[ByteArray]) => x.toArray, partitions, allowLocal)
val flattenedPartition: UnrolledPartition = Array.concat(allPartitions: _*)
- flattenedPartition.iterator
+ serveIterator(flattenedPartition.iterator,
+ s"serve RDD ${rdd.id} with partitions ${partitions.mkString(",")}")
+ }
+
+ /**
+ * A helper function to collect an RDD as an iterator, then serve it via socket.
+ *
+ * @return the port number of a local socket which serves the data collected from this job.
+ */
+ def collectAndServe[T](rdd: RDD[T]): Int = {
+ serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}")
}
def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
@@ -575,15 +588,44 @@ private[spark] object PythonRDD extends Logging {
dataOut.write(bytes)
}
- def writeToFile[T](items: java.util.Iterator[T], filename: String) {
- import scala.collection.JavaConverters._
- writeToFile(items.asScala, filename)
- }
+ /**
+ * Create a socket server and a background thread to serve the data in `items`.
+ *
+ * The socket server can accept only one connection, and closes if no connection
+ * arrives within 3 seconds.
+ *
+ * Once a connection comes in, it tries to serialize all the data in `items`
+ * and send them over this connection.
+ *
+ * The thread terminates after all the data are sent or if any exception happens.
+ */
+ private def serveIterator[T](items: Iterator[T], threadName: String): Int = {
+ val serverSocket = new ServerSocket(0, 1)
+ serverSocket.setReuseAddress(true)
+ // Close the socket if no connection arrives within 3 seconds
+ serverSocket.setSoTimeout(3000)
+
+ new Thread(threadName) {
+ setDaemon(true)
+ override def run() {
+ try {
+ val sock = serverSocket.accept()
+ val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
+ try {
+ writeIteratorToStream(items, out)
+ } finally {
+ out.close()
+ }
+ } catch {
+ case NonFatal(e) =>
+ logError("Error while sending iterator", e)
+ } finally {
+ serverSocket.close()
+ }
+ }
+ }.start()
- def writeToFile[T](items: Iterator[T], filename: String) {
- val file = new DataOutputStream(new FileOutputStream(filename))
- writeIteratorToStream(items, file)
- file.close()
+ serverSocket.getLocalPort
}
private def getMergedConf(confAsMap: java.util.HashMap[String, String],
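On the Scala side, serveIterator binds an ephemeral port (port 0) with a backlog of 1, sets a 3-second accept timeout so an abandoned server cannot linger, serves the data from a daemon thread, and returns the port immediately. A rough Python transliteration of the same pattern, end to end (the pickle framing and the helper names serve_iterator/load_from_socket are illustrative, not Spark's):

    import pickle
    import socket
    import threading

    def serve_iterator(items):
        # One-shot server: ephemeral port, backlog 1, 3s accept timeout,
        # daemon thread -- mirroring the serveIterator helper above.
        server = socket.socket()
        server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        server.bind(("localhost", 0))   # port 0: the OS picks a free port
        server.listen(1)
        server.settimeout(3.0)          # give up if nobody connects in 3s
        port = server.getsockname()[1]

        def run():
            try:
                conn, _ = server.accept()
                with conn, conn.makefile("wb", 65536) as out:
                    for item in items:
                        pickle.dump(item, out)  # assumed framing, not Spark's
            except OSError as e:
                print("Error while sending iterator:", e)
            finally:
                server.close()

        threading.Thread(target=run, name="serve iterator", daemon=True).start()
        return port

    def load_from_socket(port):
        # Client side, shaped like _load_from_socket in the Python diff below.
        with socket.create_connection(("localhost", port)) as sock:
            rf = sock.makefile("rb", 65536)
            while True:
                try:
                    yield pickle.load(rf)
                except EOFError:
                    return

    print(list(load_from_socket(serve_iterator(iter(range(5))))))  # [0, 1, 2, 3, 4]

Note that the returned port number is all that crosses the Py4J bridge, so Python never holds a reference to a JVM-side iterator.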
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 6011caf9f1..78dccc4047 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -21,6 +21,8 @@ import sys
from threading import Lock
from tempfile import NamedTemporaryFile
+from py4j.java_collections import ListConverter
+
from pyspark import accumulators
from pyspark.accumulators import Accumulator
from pyspark.broadcast import Broadcast
@@ -30,13 +32,11 @@ from pyspark.java_gateway import launch_gateway
from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \
PairDeserializer, AutoBatchedSerializer, NoOpSerializer
from pyspark.storagelevel import StorageLevel
-from pyspark.rdd import RDD
+from pyspark.rdd import RDD, _load_from_socket
from pyspark.traceback_utils import CallSite, first_spark_call
from pyspark.status import StatusTracker
from pyspark.profiler import ProfilerCollector, BasicProfiler
-from py4j.java_collections import ListConverter
-
__all__ = ['SparkContext']
@@ -59,7 +59,6 @@ class SparkContext(object):
_gateway = None
_jvm = None
- _writeToFile = None
_next_accum_id = 0
_active_spark_context = None
_lock = Lock()
@@ -221,7 +220,6 @@ class SparkContext(object):
if not SparkContext._gateway:
SparkContext._gateway = gateway or launch_gateway()
SparkContext._jvm = SparkContext._gateway.jvm
- SparkContext._writeToFile = SparkContext._jvm.PythonRDD.writeToFile
if instance:
if (SparkContext._active_spark_context and
@@ -840,8 +838,9 @@ class SparkContext(object):
# by runJob() in order to avoid having to pass a Python lambda into
# SparkContext#runJob.
mappedRDD = rdd.mapPartitions(partitionFunc)
- it = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, allowLocal)
- return list(mappedRDD._collect_iterator_through_file(it))
+ port = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions,
+ allowLocal)
+ return list(_load_from_socket(port, mappedRDD._jrdd_deserializer))
def show_profiles(self):
""" Print the profile stats to stdout """
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index cb12fed98c..bf17f513c0 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -19,7 +19,6 @@ import copy
from collections import defaultdict
from itertools import chain, ifilter, imap
import operator
-import os
import sys
import shlex
from subprocess import Popen, PIPE
@@ -29,6 +28,7 @@ import warnings
import heapq
import bisect
import random
+import socket
from math import sqrt, log, isinf, isnan, pow, ceil
from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
@@ -111,6 +111,17 @@ def _parse_memory(s):
return int(float(s[:-1]) * units[s[-1].lower()])
+def _load_from_socket(port, serializer):
+ sock = socket.socket()
+ try:
+ sock.connect(("localhost", port))
+ rf = sock.makefile("rb", 65536)
+ for item in serializer.load_stream(rf):
+ yield item
+ finally:
+ sock.close()
+
+
class Partitioner(object):
def __init__(self, numPartitions, partitionFunc):
self.numPartitions = numPartitions
@@ -698,21 +709,8 @@ class RDD(object):
Return a list that contains all of the elements in this RDD.
"""
with SCCallSiteSync(self.context) as css:
- bytesInJava = self._jrdd.collect().iterator()
- return list(self._collect_iterator_through_file(bytesInJava))
-
- def _collect_iterator_through_file(self, iterator):
- # Transferring lots of data through Py4J can be slow because
- # socket.readline() is inefficient. Instead, we'll dump the data to a
- # file and read it back.
- tempFile = NamedTemporaryFile(delete=False, dir=self.ctx._temp_dir)
- tempFile.close()
- self.ctx._writeToFile(iterator, tempFile.name)
- # Read the data into Python and deserialize it:
- with open(tempFile.name, 'rb') as tempFile:
- for item in self._jrdd_deserializer.load_stream(tempFile):
- yield item
- os.unlink(tempFile.name)
+ port = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd())
+ return list(_load_from_socket(port, self._jrdd_deserializer))
def reduce(self, f):
"""
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 5c3b7377c3..e8ce454745 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -19,13 +19,11 @@ import sys
import itertools
import warnings
import random
-import os
-from tempfile import NamedTemporaryFile
from py4j.java_collections import ListConverter, MapConverter
from pyspark.context import SparkContext
-from pyspark.rdd import RDD
+from pyspark.rdd import RDD, _load_from_socket
from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer
from pyspark.storagelevel import StorageLevel
from pyspark.traceback_utils import SCCallSiteSync
@@ -310,14 +308,8 @@ class DataFrame(object):
[Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
"""
with SCCallSiteSync(self._sc) as css:
- bytesInJava = self._jdf.javaToPython().collect().iterator()
- tempFile = NamedTemporaryFile(delete=False, dir=self._sc._temp_dir)
- tempFile.close()
- self._sc._writeToFile(bytesInJava, tempFile.name)
- # Read the data into Python and deserialize it:
- with open(tempFile.name, 'rb') as tempFile:
- rs = list(BatchedSerializer(PickleSerializer()).load_stream(tempFile))
- os.unlink(tempFile.name)
+ port = self._sc._jvm.PythonRDD.collectAndServe(self._jdf.javaToPython().rdd())
+ rs = list(_load_from_socket(port, BatchedSerializer(PickleSerializer())))
cls = _create_cls(self.schema)
return [cls(r) for r in rs]
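DataFrame.collect() passes BatchedSerializer(PickleSerializer()) because the rows arrive pickled in batches; load_stream flattens the batches back into a stream of rows before _create_cls rebuilds Row objects. A self-contained sketch of that batching idea (the batch size of 10 is an arbitrary illustrative choice, not the value PySpark uses):

    import io
    import pickle
    from itertools import islice

    def dump_batched(iterator, stream, batch_size=10):
        # Pickle items in lists of up to batch_size.
        it = iter(iterator)
        while True:
            batch = list(islice(it, batch_size))
            if not batch:
                return
            pickle.dump(batch, stream)

    def load_batched(stream):
        # Each frame is a list; flatten it so callers see single items,
        # which is what list(_load_from_socket(...)) relies on.
        while True:
            try:
                batch = pickle.load(stream)
            except EOFError:
                return
            for item in batch:
                yield item

    buf = io.BytesIO()
    dump_batched(range(25), buf)
    buf.seek(0)
    assert list(load_batched(buf)) == list(range(25))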