aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/wrapper.py
diff options
context:
space:
mode:
authorSandeep Singh <sandeep@techaddict.me>2016-12-01 13:22:40 -0800
committerJoseph K. Bradley <joseph@databricks.com>2016-12-01 13:22:40 -0800
commit78bb7f8071379114314c394e0167c4c5fd8545c5 (patch)
tree4c49f4fd69c635edf605ceda934c4b3f33595266 /python/pyspark/ml/wrapper.py
parente6534847100670a22b3b191a0f9d924fab7f3c02 (diff)
downloadspark-78bb7f8071379114314c394e0167c4c5fd8545c5.tar.gz
spark-78bb7f8071379114314c394e0167c4c5fd8545c5.tar.bz2
spark-78bb7f8071379114314c394e0167c4c5fd8545c5.zip
[SPARK-18274][ML][PYSPARK] Memory leak in PySpark JavaWrapper
## What changes were proposed in this pull request? In `JavaWrapper`'s destructor, make the Java Gateway dereference the wrapped Java object, using `SparkContext._active_spark_context._gateway.detach`. Also fix the parameter-copying bug by moving the `copy` method from `JavaModel` to `JavaParams`. ## How was this patch tested? ```python import random, string from pyspark.ml.feature import StringIndexer l = [(''.join(random.choice(string.ascii_uppercase) for _ in range(10)), ) for _ in range(int(7e5))] # 700000 random strings of 10 characters df = spark.createDataFrame(l, ['string']) for i in range(50): indexer = StringIndexer(inputCol='string', outputCol='index') indexer.fit(df) ``` * Before: the StringIndexer kept a strong reference on the JVM side, causing GC issues, and the computation halted midway. After: garbage collection works because the object is dereferenced, and the computation completes. * Memory footprint tested using a profiler. * Added a parameter-copy-related test which was failing before. Author: Sandeep Singh <sandeep@techaddict.me> Author: jkbradley <joseph.kurata.bradley@gmail.com> Closes #15843 from techaddict/SPARK-18274.
Diffstat (limited to 'python/pyspark/ml/wrapper.py')
-rw-r--r--python/pyspark/ml/wrapper.py41
1 file changed, 23 insertions, 18 deletions
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index 25c44b7533..13b75e9919 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -71,6 +71,10 @@ class JavaParams(JavaWrapper, Params):
__metaclass__ = ABCMeta
+ def __del__(self):
+ if SparkContext._active_spark_context:
+ SparkContext._active_spark_context._gateway.detach(self._java_obj)
+
def _make_java_param_pair(self, param, value):
"""
Makes a Java parm pair.
@@ -180,6 +184,25 @@ class JavaParams(JavaWrapper, Params):
% stage_name)
return py_stage
+ def copy(self, extra=None):
+ """
+ Creates a copy of this instance with the same uid and some
+ extra params. This implementation first calls Params.copy and
+ then make a copy of the companion Java pipeline component with
+ extra params. So both the Python wrapper and the Java pipeline
+ component get copied.
+
+ :param extra: Extra parameters to copy to the new instance
+ :return: Copy of this instance
+ """
+ if extra is None:
+ extra = dict()
+ that = super(JavaParams, self).copy(extra)
+ if self._java_obj is not None:
+ that._java_obj = self._java_obj.copy(self._empty_java_param_map())
+ that._transfer_params_to_java()
+ return that
+
@inherit_doc
class JavaEstimator(JavaParams, Estimator):
@@ -256,21 +279,3 @@ class JavaModel(JavaTransformer, Model):
super(JavaModel, self).__init__(java_model)
if java_model is not None:
self._resetUid(java_model.uid())
-
- def copy(self, extra=None):
- """
- Creates a copy of this instance with the same uid and some
- extra params. This implementation first calls Params.copy and
- then make a copy of the companion Java model with extra params.
- So both the Python wrapper and the Java model get copied.
-
- :param extra: Extra parameters to copy to the new instance
- :return: Copy of this instance
- """
- if extra is None:
- extra = dict()
- that = super(JavaModel, self).copy(extra)
- if self._java_obj is not None:
- that._java_obj = self._java_obj.copy(self._empty_java_param_map())
- that._transfer_params_to_java()
- return that