diff options
author | Davies Liu <davies@databricks.com> | 2015-03-09 16:24:06 -0700 |
---|---|---|
committer | Josh Rosen <joshrosen@databricks.com> | 2015-03-09 16:24:06 -0700 |
commit | 8767565cef01d847f57b7293d8b63b2422009b90 (patch) | |
tree | 1204ac7a7cda19b30e2a990ae2ded5f5b40b8c3f /python/pyspark/context.py | |
parent | 3cac1991a1def0adaf42face2c578d3ab8c27025 (diff) | |
download | spark-8767565cef01d847f57b7293d8b63b2422009b90.tar.gz spark-8767565cef01d847f57b7293d8b63b2422009b90.tar.bz2 spark-8767565cef01d847f57b7293d8b63b2422009b90.zip |
[SPARK-6194] [SPARK-677] [PySpark] fix memory leak in collect()
Because circular reference between JavaObject and JavaMember, an Java object can not be released until Python GC kick in, then it will cause memory leak in collect(), which may consume lots of memory in JVM.
This PR change the way we sending collected data back into Python from local file to socket, which could avoid any disk IO during collect, also avoid any referrers of Java object in Python.
cc JoshRosen
Author: Davies Liu <davies@databricks.com>
Closes #4923 from davies/fix_collect and squashes the following commits:
d730286 [Davies Liu] address comments
24c92a4 [Davies Liu] fix style
ba54614 [Davies Liu] use socket to transfer data from JVM
9517c8f [Davies Liu] fix memory leak in collect()
Diffstat (limited to 'python/pyspark/context.py')
-rw-r--r-- | python/pyspark/context.py | 13 |
1 files changed, 6 insertions, 7 deletions
diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 6011caf9f1..78dccc4047 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -21,6 +21,8 @@ import sys from threading import Lock from tempfile import NamedTemporaryFile +from py4j.java_collections import ListConverter + from pyspark import accumulators from pyspark.accumulators import Accumulator from pyspark.broadcast import Broadcast @@ -30,13 +32,11 @@ from pyspark.java_gateway import launch_gateway from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \ PairDeserializer, AutoBatchedSerializer, NoOpSerializer from pyspark.storagelevel import StorageLevel -from pyspark.rdd import RDD +from pyspark.rdd import RDD, _load_from_socket from pyspark.traceback_utils import CallSite, first_spark_call from pyspark.status import StatusTracker from pyspark.profiler import ProfilerCollector, BasicProfiler -from py4j.java_collections import ListConverter - __all__ = ['SparkContext'] @@ -59,7 +59,6 @@ class SparkContext(object): _gateway = None _jvm = None - _writeToFile = None _next_accum_id = 0 _active_spark_context = None _lock = Lock() @@ -221,7 +220,6 @@ class SparkContext(object): if not SparkContext._gateway: SparkContext._gateway = gateway or launch_gateway() SparkContext._jvm = SparkContext._gateway.jvm - SparkContext._writeToFile = SparkContext._jvm.PythonRDD.writeToFile if instance: if (SparkContext._active_spark_context and @@ -840,8 +838,9 @@ class SparkContext(object): # by runJob() in order to avoid having to pass a Python lambda into # SparkContext#runJob. mappedRDD = rdd.mapPartitions(partitionFunc) - it = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, allowLocal) - return list(mappedRDD._collect_iterator_through_file(it)) + port = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, + allowLocal) + return list(_load_from_socket(port, mappedRDD._jrdd_deserializer)) def show_profiles(self): """ Print the profile stats to stdout """ |