diff options
author | Davies Liu <davies@databricks.com> | 2015-03-09 16:24:06 -0700 |
---|---|---|
committer | Josh Rosen <joshrosen@databricks.com> | 2015-03-09 16:24:06 -0700 |
commit | 8767565cef01d847f57b7293d8b63b2422009b90 (patch) | |
tree | 1204ac7a7cda19b30e2a990ae2ded5f5b40b8c3f /python/pyspark/sql | |
parent | 3cac1991a1def0adaf42face2c578d3ab8c27025 (diff) | |
download | spark-8767565cef01d847f57b7293d8b63b2422009b90.tar.gz spark-8767565cef01d847f57b7293d8b63b2422009b90.tar.bz2 spark-8767565cef01d847f57b7293d8b63b2422009b90.zip |
[SPARK-6194] [SPARK-677] [PySpark] fix memory leak in collect()
Because circular reference between JavaObject and JavaMember, an Java object can not be released until Python GC kick in, then it will cause memory leak in collect(), which may consume lots of memory in JVM.
This PR change the way we sending collected data back into Python from local file to socket, which could avoid any disk IO during collect, also avoid any referrers of Java object in Python.
cc JoshRosen
Author: Davies Liu <davies@databricks.com>
Closes #4923 from davies/fix_collect and squashes the following commits:
d730286 [Davies Liu] address comments
24c92a4 [Davies Liu] fix style
ba54614 [Davies Liu] use socket to transfer data from JVM
9517c8f [Davies Liu] fix memory leak in collect()
Diffstat (limited to 'python/pyspark/sql')
-rw-r--r-- | python/pyspark/sql/dataframe.py | 14 |
1 files changed, 3 insertions, 11 deletions
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 5c3b7377c3..e8ce454745 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -19,13 +19,11 @@ import sys import itertools import warnings import random -import os -from tempfile import NamedTemporaryFile from py4j.java_collections import ListConverter, MapConverter from pyspark.context import SparkContext -from pyspark.rdd import RDD +from pyspark.rdd import RDD, _load_from_socket from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer from pyspark.storagelevel import StorageLevel from pyspark.traceback_utils import SCCallSiteSync @@ -310,14 +308,8 @@ class DataFrame(object): [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] """ with SCCallSiteSync(self._sc) as css: - bytesInJava = self._jdf.javaToPython().collect().iterator() - tempFile = NamedTemporaryFile(delete=False, dir=self._sc._temp_dir) - tempFile.close() - self._sc._writeToFile(bytesInJava, tempFile.name) - # Read the data into Python and deserialize it: - with open(tempFile.name, 'rb') as tempFile: - rs = list(BatchedSerializer(PickleSerializer()).load_stream(tempFile)) - os.unlink(tempFile.name) + port = self._sc._jvm.PythonRDD.collectAndServe(self._jdf.javaToPython().rdd()) + rs = list(_load_from_socket(port, BatchedSerializer(PickleSerializer()))) cls = _create_cls(self.schema) return [cls(r) for r in rs] |