diff options
author | Davies Liu <davies@databricks.com> | 2015-03-09 16:24:06 -0700 |
---|---|---|
committer | Josh Rosen <joshrosen@databricks.com> | 2015-03-09 16:24:06 -0700 |
commit | 8767565cef01d847f57b7293d8b63b2422009b90 (patch) | |
tree | 1204ac7a7cda19b30e2a990ae2ded5f5b40b8c3f /python/pyspark/rdd.py | |
parent | 3cac1991a1def0adaf42face2c578d3ab8c27025 (diff) | |
download | spark-8767565cef01d847f57b7293d8b63b2422009b90.tar.gz spark-8767565cef01d847f57b7293d8b63b2422009b90.tar.bz2 spark-8767565cef01d847f57b7293d8b63b2422009b90.zip |
[SPARK-6194] [SPARK-677] [PySpark] fix memory leak in collect()
Because circular reference between JavaObject and JavaMember, an Java object can not be released until Python GC kick in, then it will cause memory leak in collect(), which may consume lots of memory in JVM.
This PR change the way we sending collected data back into Python from local file to socket, which could avoid any disk IO during collect, also avoid any referrers of Java object in Python.
cc JoshRosen
Author: Davies Liu <davies@databricks.com>
Closes #4923 from davies/fix_collect and squashes the following commits:
d730286 [Davies Liu] address comments
24c92a4 [Davies Liu] fix style
ba54614 [Davies Liu] use socket to transfer data from JVM
9517c8f [Davies Liu] fix memory leak in collect()
Diffstat (limited to 'python/pyspark/rdd.py')
-rw-r--r-- | python/pyspark/rdd.py | 30 |
1 files changed, 14 insertions, 16 deletions
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index cb12fed98c..bf17f513c0 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -19,7 +19,6 @@ import copy from collections import defaultdict from itertools import chain, ifilter, imap import operator -import os import sys import shlex from subprocess import Popen, PIPE @@ -29,6 +28,7 @@ import warnings import heapq import bisect import random +import socket from math import sqrt, log, isinf, isnan, pow, ceil from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ @@ -111,6 +111,17 @@ def _parse_memory(s): return int(float(s[:-1]) * units[s[-1].lower()]) +def _load_from_socket(port, serializer): + sock = socket.socket() + try: + sock.connect(("localhost", port)) + rf = sock.makefile("rb", 65536) + for item in serializer.load_stream(rf): + yield item + finally: + sock.close() + + class Partitioner(object): def __init__(self, numPartitions, partitionFunc): self.numPartitions = numPartitions @@ -698,21 +709,8 @@ class RDD(object): Return a list that contains all of the elements in this RDD. """ with SCCallSiteSync(self.context) as css: - bytesInJava = self._jrdd.collect().iterator() - return list(self._collect_iterator_through_file(bytesInJava)) - - def _collect_iterator_through_file(self, iterator): - # Transferring lots of data through Py4J can be slow because - # socket.readline() is inefficient. Instead, we'll dump the data to a - # file and read it back. - tempFile = NamedTemporaryFile(delete=False, dir=self.ctx._temp_dir) - tempFile.close() - self.ctx._writeToFile(iterator, tempFile.name) - # Read the data into Python and deserialize it: - with open(tempFile.name, 'rb') as tempFile: - for item in self._jrdd_deserializer.load_stream(tempFile): - yield item - os.unlink(tempFile.name) + port = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd()) + return list(_load_from_socket(port, self._jrdd_deserializer)) def reduce(self, f): """ |