aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/context.py
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2015-03-09 16:24:06 -0700
committerJosh Rosen <joshrosen@databricks.com>2015-03-09 16:24:06 -0700
commit8767565cef01d847f57b7293d8b63b2422009b90 (patch)
tree1204ac7a7cda19b30e2a990ae2ded5f5b40b8c3f /python/pyspark/context.py
parent3cac1991a1def0adaf42face2c578d3ab8c27025 (diff)
downloadspark-8767565cef01d847f57b7293d8b63b2422009b90.tar.gz
spark-8767565cef01d847f57b7293d8b63b2422009b90.tar.bz2
spark-8767565cef01d847f57b7293d8b63b2422009b90.zip
[SPARK-6194] [SPARK-677] [PySpark] fix memory leak in collect()
Because circular reference between JavaObject and JavaMember, an Java object can not be released until Python GC kick in, then it will cause memory leak in collect(), which may consume lots of memory in JVM. This PR change the way we sending collected data back into Python from local file to socket, which could avoid any disk IO during collect, also avoid any referrers of Java object in Python. cc JoshRosen Author: Davies Liu <davies@databricks.com> Closes #4923 from davies/fix_collect and squashes the following commits: d730286 [Davies Liu] address comments 24c92a4 [Davies Liu] fix style ba54614 [Davies Liu] use socket to transfer data from JVM 9517c8f [Davies Liu] fix memory leak in collect()
Diffstat (limited to 'python/pyspark/context.py')
-rw-r--r--python/pyspark/context.py13
1 files changed, 6 insertions, 7 deletions
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 6011caf9f1..78dccc4047 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -21,6 +21,8 @@ import sys
from threading import Lock
from tempfile import NamedTemporaryFile
+from py4j.java_collections import ListConverter
+
from pyspark import accumulators
from pyspark.accumulators import Accumulator
from pyspark.broadcast import Broadcast
@@ -30,13 +32,11 @@ from pyspark.java_gateway import launch_gateway
from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \
PairDeserializer, AutoBatchedSerializer, NoOpSerializer
from pyspark.storagelevel import StorageLevel
-from pyspark.rdd import RDD
+from pyspark.rdd import RDD, _load_from_socket
from pyspark.traceback_utils import CallSite, first_spark_call
from pyspark.status import StatusTracker
from pyspark.profiler import ProfilerCollector, BasicProfiler
-from py4j.java_collections import ListConverter
-
__all__ = ['SparkContext']
@@ -59,7 +59,6 @@ class SparkContext(object):
_gateway = None
_jvm = None
- _writeToFile = None
_next_accum_id = 0
_active_spark_context = None
_lock = Lock()
@@ -221,7 +220,6 @@ class SparkContext(object):
if not SparkContext._gateway:
SparkContext._gateway = gateway or launch_gateway()
SparkContext._jvm = SparkContext._gateway.jvm
- SparkContext._writeToFile = SparkContext._jvm.PythonRDD.writeToFile
if instance:
if (SparkContext._active_spark_context and
@@ -840,8 +838,9 @@ class SparkContext(object):
# by runJob() in order to avoid having to pass a Python lambda into
# SparkContext#runJob.
mappedRDD = rdd.mapPartitions(partitionFunc)
- it = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, allowLocal)
- return list(mappedRDD._collect_iterator_through_file(it))
+ port = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions,
+ allowLocal)
+ return list(_load_from_socket(port, mappedRDD._jrdd_deserializer))
def show_profiles(self):
""" Print the profile stats to stdout """