author    Bryan Cutler <cutlerb@gmail.com>           2016-04-13 14:08:57 -0700
committer Joseph K. Bradley <joseph@databricks.com>  2016-04-13 14:08:57 -0700
commit    fc3cd2f5090b3ba1cfde0fca3b3ce632d0b2f9c4 (patch)
tree      6ee38d5b95cb6bc5548c6bb1b8da528aa46ce1a9 /python/pyspark/ml/wrapper.py
parent    781df499836e4216939e0febdcd5f89d30645759 (diff)
download  spark-fc3cd2f5090b3ba1cfde0fca3b3ce632d0b2f9c4.tar.gz
          spark-fc3cd2f5090b3ba1cfde0fca3b3ce632d0b2f9c4.tar.bz2
          spark-fc3cd2f5090b3ba1cfde0fca3b3ce632d0b2f9c4.zip
[SPARK-14472][PYSPARK][ML] Cleanup ML JavaWrapper and related class hierarchy
Currently, JavaWrapper is only a wrapper class for pipeline classes that have Params, and JavaCallable is a separate mixin that provides the methods to make Java calls. This change simplifies the class structure by defining the Java wrapper in a plain base class, along with the methods to make Java calls, and renames the Java wrapper classes to better reflect their purpose.

Ran existing Python ml tests and generated documentation to test this change.

Author: Bryan Cutler <cutlerb@gmail.com>

Closes #12304 from BryanCutler/pyspark-cleanup-JavaWrapper-SPARK-14472.
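To make the new structure concrete, here is a minimal usage sketch of the reworked base class as introduced by this patch. Note that _create_from_java_class and _call_java are underscore-prefixed internal helpers rather than public API, java.util.ArrayList is only a stand-in JVM class for illustration, and an active SparkContext is assumed so the Py4J gateway is up:

    from pyspark import SparkContext
    from pyspark.ml.wrapper import JavaWrapper

    sc = SparkContext.getOrCreate()  # _jvm() needs a live Py4J gateway

    # Wrap a plain JVM object by class name; since JavaWrapper is now a
    # plain base class, no Params machinery is involved.
    wrapper = JavaWrapper._create_from_java_class("java.util.ArrayList")

    # _call_java converts each argument with _py2java, invokes the named
    # method on the wrapped object, and converts the result back with
    # _java2py.
    wrapper._call_java("add", "elem")
    print(wrapper._call_java("size"))  # -> 1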
Diffstat (limited to 'python/pyspark/ml/wrapper.py')
-rw-r--r-- python/pyspark/ml/wrapper.py | 76
1 file changed, 34 insertions(+), 42 deletions(-)
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index bbeb6cfe6f..cd0e5b80d5 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -25,29 +25,32 @@ from pyspark.ml.util import _jvm
from pyspark.mllib.common import inherit_doc, _java2py, _py2java
-@inherit_doc
-class JavaWrapper(Params):
+class JavaWrapper(object):
"""
- Utility class to help create wrapper classes from Java/Scala
- implementations of pipeline components.
+ Wrapper class for a Java companion object.
"""
+ def __init__(self, java_obj=None):
+ super(JavaWrapper, self).__init__()
+ self._java_obj = java_obj
- __metaclass__ = ABCMeta
-
- def __init__(self):
+ @classmethod
+ def _create_from_java_class(cls, java_class, *args):
"""
- Initialize the wrapped java object to None
Construct this object from the given Java classname and arguments
"""
- super(JavaWrapper, self).__init__()
- #: The wrapped Java companion object. Subclasses should initialize
- #: it properly. The param values in the Java object should be
- #: synced with the Python wrapper in fit/transform/evaluate/copy.
- self._java_obj = None
+ java_obj = JavaWrapper._new_java_obj(java_class, *args)
+ return cls(java_obj)
+
+ def _call_java(self, name, *args):
+ m = getattr(self._java_obj, name)
+ sc = SparkContext._active_spark_context
+ java_args = [_py2java(sc, arg) for arg in args]
+ return _java2py(sc, m(*java_args))
@staticmethod
def _new_java_obj(java_class, *args):
"""
- Construct a new Java object.
+ Returns a new Java object.
"""
sc = SparkContext._active_spark_context
java_obj = _jvm()
@@ -56,6 +59,18 @@ class JavaWrapper(Params):
java_args = [_py2java(sc, arg) for arg in args]
return java_obj(*java_args)
+
+@inherit_doc
+class JavaParams(JavaWrapper, Params):
+ """
+ Utility class to help create wrapper classes from Java/Scala
+ implementations of pipeline components.
+ """
+ #: The param values in the Java object should be
+ #: synced with the Python wrapper in fit/transform/evaluate/copy.
+
+ __metaclass__ = ABCMeta
+
def _make_java_param_pair(self, param, value):
"""
Makes a Java param pair.
@@ -151,7 +166,7 @@ class JavaWrapper(Params):
stage_name = java_stage.getClass().getName().replace("org.apache.spark", "pyspark")
# Generate a default new instance from the stage_name class.
py_type = __get_class(stage_name)
- if issubclass(py_type, JavaWrapper):
+ if issubclass(py_type, JavaParams):
# Load information from java_stage to the instance.
py_stage = py_type()
py_stage._java_obj = java_stage
@@ -166,7 +181,7 @@ class JavaWrapper(Params):
@inherit_doc
-class JavaEstimator(Estimator, JavaWrapper):
+class JavaEstimator(JavaParams, Estimator):
"""
Base class for :py:class:`Estimator`s that wrap Java/Scala
implementations.
@@ -199,7 +214,7 @@ class JavaEstimator(Estimator, JavaWrapper):
@inherit_doc
-class JavaTransformer(Transformer, JavaWrapper):
+class JavaTransformer(JavaParams, Transformer):
"""
Base class for :py:class:`Transformer`s that wrap Java/Scala
implementations. Subclasses should ensure they have the transformer Java object
@@ -213,30 +228,8 @@ class JavaTransformer(Transformer, JavaWrapper):
return DataFrame(self._java_obj.transform(dataset._jdf), dataset.sql_ctx)
-class JavaCallable(object):
- """
- Wrapper for a plain object in JVM to make Java calls, can be used
- as a mixin to another class that defines a _java_obj wrapper
- """
- def __init__(self, java_obj=None, sc=None):
- super(JavaCallable, self).__init__()
- self._sc = sc if sc is not None else SparkContext._active_spark_context
- # if this class is a mixin and _java_obj is already defined then don't initialize
- if java_obj is not None or not hasattr(self, "_java_obj"):
- self._java_obj = java_obj
-
- def __del__(self):
- if self._java_obj is not None:
- self._sc._gateway.detach(self._java_obj)
-
- def _call_java(self, name, *args):
- m = getattr(self._java_obj, name)
- java_args = [_py2java(self._sc, arg) for arg in args]
- return _java2py(self._sc, m(*java_args))
-
-
@inherit_doc
-class JavaModel(Model, JavaCallable, JavaTransformer):
+class JavaModel(JavaTransformer, Model):
"""
Base class for :py:class:`Model`s that wrap Java/Scala
implementations. Subclasses should inherit this class before
@@ -259,9 +252,8 @@ class JavaModel(Model, JavaCallable, JavaTransformer):
these wrappers depend on pyspark.ml.util (both directly and via
other ML classes).
"""
- super(JavaModel, self).__init__()
+ super(JavaModel, self).__init__(java_model)
if java_model is not None:
- self._java_obj = java_model
self.uid = java_model.uid()
def copy(self, extra=None):
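As a footnote on the reordered bases: because JavaModel now sits on a single cooperative chain ending at JavaWrapper, passing java_model through super() is enough to populate _java_obj. A dependency-free toy mirror of the hierarchy (illustrative stand-ins, not the real pyspark classes) shows the mechanics:

    class JavaWrapper(object):
        def __init__(self, java_obj=None):
            super(JavaWrapper, self).__init__()
            self._java_obj = java_obj  # set once, at the root of the chain

    class JavaParams(JavaWrapper):      # stands in for (JavaWrapper, Params)
        pass

    class JavaTransformer(JavaParams):  # stands in for (JavaParams, Transformer)
        pass

    class JavaModel(JavaTransformer):   # stands in for (JavaTransformer, Model)
        def __init__(self, java_model=None):
            # The handle travels up the super() chain; no local
            # self._java_obj assignment is needed anymore.
            super(JavaModel, self).__init__(java_model)

    m = JavaModel(java_model="jvm-handle")
    print(m._java_obj)  # -> jvm-handle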