aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/sql/dataframe.py
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2015-03-09 16:24:06 -0700
committerJosh Rosen <joshrosen@databricks.com>2015-03-09 16:24:06 -0700
commit8767565cef01d847f57b7293d8b63b2422009b90 (patch)
tree1204ac7a7cda19b30e2a990ae2ded5f5b40b8c3f /python/pyspark/sql/dataframe.py
parent3cac1991a1def0adaf42face2c578d3ab8c27025 (diff)
downloadspark-8767565cef01d847f57b7293d8b63b2422009b90.tar.gz
spark-8767565cef01d847f57b7293d8b63b2422009b90.tar.bz2
spark-8767565cef01d847f57b7293d8b63b2422009b90.zip
[SPARK-6194] [SPARK-677] [PySpark] fix memory leak in collect()
Because circular reference between JavaObject and JavaMember, an Java object can not be released until Python GC kick in, then it will cause memory leak in collect(), which may consume lots of memory in JVM. This PR change the way we sending collected data back into Python from local file to socket, which could avoid any disk IO during collect, also avoid any referrers of Java object in Python. cc JoshRosen Author: Davies Liu <davies@databricks.com> Closes #4923 from davies/fix_collect and squashes the following commits: d730286 [Davies Liu] address comments 24c92a4 [Davies Liu] fix style ba54614 [Davies Liu] use socket to transfer data from JVM 9517c8f [Davies Liu] fix memory leak in collect()
Diffstat (limited to 'python/pyspark/sql/dataframe.py')
-rw-r--r--python/pyspark/sql/dataframe.py14
1 files changed, 3 insertions, 11 deletions
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 5c3b7377c3..e8ce454745 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -19,13 +19,11 @@ import sys
import itertools
import warnings
import random
-import os
-from tempfile import NamedTemporaryFile
from py4j.java_collections import ListConverter, MapConverter
from pyspark.context import SparkContext
-from pyspark.rdd import RDD
+from pyspark.rdd import RDD, _load_from_socket
from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer
from pyspark.storagelevel import StorageLevel
from pyspark.traceback_utils import SCCallSiteSync
@@ -310,14 +308,8 @@ class DataFrame(object):
[Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
"""
with SCCallSiteSync(self._sc) as css:
- bytesInJava = self._jdf.javaToPython().collect().iterator()
- tempFile = NamedTemporaryFile(delete=False, dir=self._sc._temp_dir)
- tempFile.close()
- self._sc._writeToFile(bytesInJava, tempFile.name)
- # Read the data into Python and deserialize it:
- with open(tempFile.name, 'rb') as tempFile:
- rs = list(BatchedSerializer(PickleSerializer()).load_stream(tempFile))
- os.unlink(tempFile.name)
+ port = self._sc._jvm.PythonRDD.collectAndServe(self._jdf.javaToPython().rdd())
+ rs = list(_load_from_socket(port, BatchedSerializer(PickleSerializer())))
cls = _create_cls(self.schema)
return [cls(r) for r in rs]