aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/wrapper.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/ml/wrapper.py')
-rw-r--r--python/pyspark/ml/wrapper.py76
1 files changed, 34 insertions, 42 deletions
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index bbeb6cfe6f..cd0e5b80d5 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -25,29 +25,32 @@ from pyspark.ml.util import _jvm
from pyspark.mllib.common import inherit_doc, _java2py, _py2java
-@inherit_doc
-class JavaWrapper(Params):
+class JavaWrapper(object):
"""
- Utility class to help create wrapper classes from Java/Scala
- implementations of pipeline components.
+ Wrapper class for a Java companion object
"""
+ def __init__(self, java_obj=None):
+ super(JavaWrapper, self).__init__()
+ self._java_obj = java_obj
- __metaclass__ = ABCMeta
-
- def __init__(self):
+ @classmethod
+ def _create_from_java_class(cls, java_class, *args):
"""
- Initialize the wrapped java object to None
+ Construct this object from given Java classname and arguments
"""
- super(JavaWrapper, self).__init__()
- #: The wrapped Java companion object. Subclasses should initialize
- #: it properly. The param values in the Java object should be
- #: synced with the Python wrapper in fit/transform/evaluate/copy.
- self._java_obj = None
+ java_obj = JavaWrapper._new_java_obj(java_class, *args)
+ return cls(java_obj)
+
+ def _call_java(self, name, *args):
+ m = getattr(self._java_obj, name)
+ sc = SparkContext._active_spark_context
+ java_args = [_py2java(sc, arg) for arg in args]
+ return _java2py(sc, m(*java_args))
@staticmethod
def _new_java_obj(java_class, *args):
"""
- Construct a new Java object.
+ Returns a new Java object.
"""
sc = SparkContext._active_spark_context
java_obj = _jvm()
@@ -56,6 +59,18 @@ class JavaWrapper(Params):
java_args = [_py2java(sc, arg) for arg in args]
return java_obj(*java_args)
+
+@inherit_doc
+class JavaParams(JavaWrapper, Params):
+ """
+ Utility class to help create wrapper classes from Java/Scala
+ implementations of pipeline components.
+ """
+ #: The param values in the Java object should be
+ #: synced with the Python wrapper in fit/transform/evaluate/copy.
+
+ __metaclass__ = ABCMeta
+
def _make_java_param_pair(self, param, value):
"""
Makes a Java parm pair.
@@ -151,7 +166,7 @@ class JavaWrapper(Params):
stage_name = java_stage.getClass().getName().replace("org.apache.spark", "pyspark")
# Generate a default new instance from the stage_name class.
py_type = __get_class(stage_name)
- if issubclass(py_type, JavaWrapper):
+ if issubclass(py_type, JavaParams):
# Load information from java_stage to the instance.
py_stage = py_type()
py_stage._java_obj = java_stage
@@ -166,7 +181,7 @@ class JavaWrapper(Params):
@inherit_doc
-class JavaEstimator(Estimator, JavaWrapper):
+class JavaEstimator(JavaParams, Estimator):
"""
Base class for :py:class:`Estimator`s that wrap Java/Scala
implementations.
@@ -199,7 +214,7 @@ class JavaEstimator(Estimator, JavaWrapper):
@inherit_doc
-class JavaTransformer(Transformer, JavaWrapper):
+class JavaTransformer(JavaParams, Transformer):
"""
Base class for :py:class:`Transformer`s that wrap Java/Scala
implementations. Subclasses should ensure they have the transformer Java object
@@ -213,30 +228,8 @@ class JavaTransformer(Transformer, JavaWrapper):
return DataFrame(self._java_obj.transform(dataset._jdf), dataset.sql_ctx)
-class JavaCallable(object):
- """
- Wrapper for a plain object in JVM to make Java calls, can be used
- as a mixin to another class that defines a _java_obj wrapper
- """
- def __init__(self, java_obj=None, sc=None):
- super(JavaCallable, self).__init__()
- self._sc = sc if sc is not None else SparkContext._active_spark_context
- # if this class is a mixin and _java_obj is already defined then don't initialize
- if java_obj is not None or not hasattr(self, "_java_obj"):
- self._java_obj = java_obj
-
- def __del__(self):
- if self._java_obj is not None:
- self._sc._gateway.detach(self._java_obj)
-
- def _call_java(self, name, *args):
- m = getattr(self._java_obj, name)
- java_args = [_py2java(self._sc, arg) for arg in args]
- return _java2py(self._sc, m(*java_args))
-
-
@inherit_doc
-class JavaModel(Model, JavaCallable, JavaTransformer):
+class JavaModel(JavaTransformer, Model):
"""
Base class for :py:class:`Model`s that wrap Java/Scala
implementations. Subclasses should inherit this class before
@@ -259,9 +252,8 @@ class JavaModel(Model, JavaCallable, JavaTransformer):
these wrappers depend on pyspark.ml.util (both directly and via
other ML classes).
"""
- super(JavaModel, self).__init__()
+ super(JavaModel, self).__init__(java_model)
if java_model is not None:
- self._java_obj = java_model
self.uid = java_model.uid()
def copy(self, extra=None):