aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/rdd.py
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2015-03-09 16:24:06 -0700
committerJosh Rosen <joshrosen@databricks.com>2015-03-09 16:24:06 -0700
commit8767565cef01d847f57b7293d8b63b2422009b90 (patch)
tree1204ac7a7cda19b30e2a990ae2ded5f5b40b8c3f /python/pyspark/rdd.py
parent3cac1991a1def0adaf42face2c578d3ab8c27025 (diff)
downloadspark-8767565cef01d847f57b7293d8b63b2422009b90.tar.gz
spark-8767565cef01d847f57b7293d8b63b2422009b90.tar.bz2
spark-8767565cef01d847f57b7293d8b63b2422009b90.zip
[SPARK-6194] [SPARK-677] [PySpark] fix memory leak in collect()
Because circular reference between JavaObject and JavaMember, an Java object can not be released until Python GC kick in, then it will cause memory leak in collect(), which may consume lots of memory in JVM. This PR change the way we sending collected data back into Python from local file to socket, which could avoid any disk IO during collect, also avoid any referrers of Java object in Python. cc JoshRosen Author: Davies Liu <davies@databricks.com> Closes #4923 from davies/fix_collect and squashes the following commits: d730286 [Davies Liu] address comments 24c92a4 [Davies Liu] fix style ba54614 [Davies Liu] use socket to transfer data from JVM 9517c8f [Davies Liu] fix memory leak in collect()
Diffstat (limited to 'python/pyspark/rdd.py')
-rw-r--r--python/pyspark/rdd.py30
1 files changed, 14 insertions, 16 deletions
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index cb12fed98c..bf17f513c0 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -19,7 +19,6 @@ import copy
from collections import defaultdict
from itertools import chain, ifilter, imap
import operator
-import os
import sys
import shlex
from subprocess import Popen, PIPE
@@ -29,6 +28,7 @@ import warnings
import heapq
import bisect
import random
+import socket
from math import sqrt, log, isinf, isnan, pow, ceil
from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
@@ -111,6 +111,17 @@ def _parse_memory(s):
return int(float(s[:-1]) * units[s[-1].lower()])
+def _load_from_socket(port, serializer):
+ sock = socket.socket()
+ try:
+ sock.connect(("localhost", port))
+ rf = sock.makefile("rb", 65536)
+ for item in serializer.load_stream(rf):
+ yield item
+ finally:
+ sock.close()
+
+
class Partitioner(object):
def __init__(self, numPartitions, partitionFunc):
self.numPartitions = numPartitions
@@ -698,21 +709,8 @@ class RDD(object):
Return a list that contains all of the elements in this RDD.
"""
with SCCallSiteSync(self.context) as css:
- bytesInJava = self._jrdd.collect().iterator()
- return list(self._collect_iterator_through_file(bytesInJava))
-
- def _collect_iterator_through_file(self, iterator):
- # Transferring lots of data through Py4J can be slow because
- # socket.readline() is inefficient. Instead, we'll dump the data to a
- # file and read it back.
- tempFile = NamedTemporaryFile(delete=False, dir=self.ctx._temp_dir)
- tempFile.close()
- self.ctx._writeToFile(iterator, tempFile.name)
- # Read the data into Python and deserialize it:
- with open(tempFile.name, 'rb') as tempFile:
- for item in self._jrdd_deserializer.load_stream(tempFile):
- yield item
- os.unlink(tempFile.name)
+ port = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd())
+ return list(_load_from_socket(port, self._jrdd_deserializer))
def reduce(self, f):
"""