aboutsummaryrefslogtreecommitdiff
path: root/core
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2015-03-09 16:24:06 -0700
committerJosh Rosen <joshrosen@databricks.com>2015-03-09 16:24:06 -0700
commit8767565cef01d847f57b7293d8b63b2422009b90 (patch)
tree1204ac7a7cda19b30e2a990ae2ded5f5b40b8c3f /core
parent3cac1991a1def0adaf42face2c578d3ab8c27025 (diff)
downloadspark-8767565cef01d847f57b7293d8b63b2422009b90.tar.gz
spark-8767565cef01d847f57b7293d8b63b2422009b90.tar.bz2
spark-8767565cef01d847f57b7293d8b63b2422009b90.zip
[SPARK-6194] [SPARK-677] [PySpark] fix memory leak in collect()
Because circular reference between JavaObject and JavaMember, an Java object can not be released until Python GC kick in, then it will cause memory leak in collect(), which may consume lots of memory in JVM. This PR change the way we sending collected data back into Python from local file to socket, which could avoid any disk IO during collect, also avoid any referrers of Java object in Python. cc JoshRosen Author: Davies Liu <davies@databricks.com> Closes #4923 from davies/fix_collect and squashes the following commits: d730286 [Davies Liu] address comments 24c92a4 [Davies Liu] fix style ba54614 [Davies Liu] use socket to transfer data from JVM 9517c8f [Davies Liu] fix memory leak in collect()
Diffstat (limited to 'core')
-rw-r--r--core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala76
1 files changed, 59 insertions, 17 deletions
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index b1cec0f647..8d4a53b4ca 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -19,26 +19,27 @@ package org.apache.spark.api.python
import java.io._
import java.net._
-import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, UUID, Collections}
-
-import org.apache.spark.input.PortableDataStream
+import java.util.{Collections, ArrayList => JArrayList, List => JList, Map => JMap}
import scala.collection.JavaConversions._
import scala.collection.mutable
import scala.language.existentials
import com.google.common.base.Charsets.UTF_8
-
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.compress.CompressionCodec
-import org.apache.hadoop.mapred.{InputFormat, OutputFormat, JobConf}
+import org.apache.hadoop.mapred.{InputFormat, JobConf, OutputFormat}
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, OutputFormat => NewOutputFormat}
+
import org.apache.spark._
-import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
+import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext}
import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.input.PortableDataStream
import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils
+import scala.util.control.NonFatal
+
private[spark] class PythonRDD(
@transient parent: RDD[_],
command: Array[Byte],
@@ -341,21 +342,33 @@ private[spark] object PythonRDD extends Logging {
/**
* Adapter for calling SparkContext#runJob from Python.
*
- * This method will return an iterator of an array that contains all elements in the RDD
+ * This method will serve an iterator of an array that contains all elements in the RDD
* (effectively a collect()), but allows you to run on a certain subset of partitions,
* or to enable local execution.
+ *
+ * @return the port number of a local socket which serves the data collected from this job.
*/
def runJob(
sc: SparkContext,
rdd: JavaRDD[Array[Byte]],
partitions: JArrayList[Int],
- allowLocal: Boolean): Iterator[Array[Byte]] = {
+ allowLocal: Boolean): Int = {
type ByteArray = Array[Byte]
type UnrolledPartition = Array[ByteArray]
val allPartitions: Array[UnrolledPartition] =
sc.runJob(rdd, (x: Iterator[ByteArray]) => x.toArray, partitions, allowLocal)
val flattenedPartition: UnrolledPartition = Array.concat(allPartitions: _*)
- flattenedPartition.iterator
+ serveIterator(flattenedPartition.iterator,
+ s"serve RDD ${rdd.id} with partitions ${partitions.mkString(",")}")
+ }
+
+ /**
+ * A helper function to collect an RDD as an iterator, then serve it via socket.
+ *
+ * @return the port number of a local socket which serves the data collected from this job.
+ */
+ def collectAndServe[T](rdd: RDD[T]): Int = {
+ serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}")
}
def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
@@ -575,15 +588,44 @@ private[spark] object PythonRDD extends Logging {
dataOut.write(bytes)
}
- def writeToFile[T](items: java.util.Iterator[T], filename: String) {
- import scala.collection.JavaConverters._
- writeToFile(items.asScala, filename)
- }
+ /**
+ * Create a socket server and a background thread to serve the data in `items`,
+ *
+ * The socket server can only accept one connection, or close if no connection
+ * in 3 seconds.
+ *
+ * Once a connection comes in, it tries to serialize all the data in `items`
+ * and send them into this connection.
+ *
+ * The thread will terminate after all the data are sent or any exceptions happen.
+ */
+ private def serveIterator[T](items: Iterator[T], threadName: String): Int = {
+ val serverSocket = new ServerSocket(0, 1)
+ serverSocket.setReuseAddress(true)
+ // Close the socket if no connection in 3 seconds
+ serverSocket.setSoTimeout(3000)
+
+ new Thread(threadName) {
+ setDaemon(true)
+ override def run() {
+ try {
+ val sock = serverSocket.accept()
+ val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
+ try {
+ writeIteratorToStream(items, out)
+ } finally {
+ out.close()
+ }
+ } catch {
+ case NonFatal(e) =>
+ logError(s"Error while sending iterator", e)
+ } finally {
+ serverSocket.close()
+ }
+ }
+ }.start()
- def writeToFile[T](items: Iterator[T], filename: String) {
- val file = new DataOutputStream(new FileOutputStream(filename))
- writeIteratorToStream(items, file)
- file.close()
+ serverSocket.getLocalPort
}
private def getMergedConf(confAsMap: java.util.HashMap[String, String],