aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/wrapper.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/ml/wrapper.py')
-rw-r--r--python/pyspark/ml/wrapper.py30
1 files changed, 23 insertions, 7 deletions
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index ca93bf7d7d..a2cf2296fb 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -213,8 +213,30 @@ class JavaTransformer(Transformer, JavaWrapper):
return DataFrame(self._java_obj.transform(dataset._jdf), dataset.sql_ctx)
+class JavaCallable(object):
+ """
+ Wrapper for a plain object in JVM to make Java calls, can be used
+ as a mixin to another class that defines a _java_obj wrapper
+ """
+ def __init__(self, java_obj=None, sc=None):
+ super(JavaCallable, self).__init__()
+ self._sc = sc if sc is not None else SparkContext._active_spark_context
+ # if this class is a mixin and _java_obj is already defined then don't initialize
+ if java_obj is not None or not hasattr(self, "_java_obj"):
+ self._java_obj = java_obj
+
+ def __del__(self):
+ if self._java_obj is not None:
+ self._sc._gateway.detach(self._java_obj)
+
+ def _call_java(self, name, *args):
+ m = getattr(self._java_obj, name)
+ java_args = [_py2java(self._sc, arg) for arg in args]
+ return _java2py(self._sc, m(*java_args))
+
+
@inherit_doc
-class JavaModel(Model, JavaTransformer):
+class JavaModel(Model, JavaCallable, JavaTransformer):
"""
Base class for :py:class:`Model`s that wrap Java/Scala
implementations. Subclasses should inherit this class before
@@ -259,9 +281,3 @@ class JavaModel(Model, JavaTransformer):
that._java_obj = self._java_obj.copy(self._empty_java_param_map())
that._transfer_params_to_java()
return that
-
- def _call_java(self, name, *args):
- m = getattr(self._java_obj, name)
- sc = SparkContext._active_spark_context
- java_args = [_py2java(sc, arg) for arg in args]
- return _java2py(sc, m(*java_args))