aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBryan Cutler <cutlerb@gmail.com>2016-04-13 14:08:57 -0700
committerJoseph K. Bradley <joseph@databricks.com>2016-04-13 14:08:57 -0700
commitfc3cd2f5090b3ba1cfde0fca3b3ce632d0b2f9c4 (patch)
tree6ee38d5b95cb6bc5548c6bb1b8da528aa46ce1a9
parent781df499836e4216939e0febdcd5f89d30645759 (diff)
downloadspark-fc3cd2f5090b3ba1cfde0fca3b3ce632d0b2f9c4.tar.gz
spark-fc3cd2f5090b3ba1cfde0fca3b3ce632d0b2f9c4.tar.bz2
spark-fc3cd2f5090b3ba1cfde0fca3b3ce632d0b2f9c4.zip
[SPARK-14472][PYSPARK][ML] Cleanup ML JavaWrapper and related class hierarchy
Currently, JavaWrapper is only a wrapper class for pipeline classes that have Params and JavaCallable is a separate mixin that provides methods to make Java calls. This change simplifies the class structure and to define the Java wrapper in a plain base class along with methods to make Java calls. Also, renames Java wrapper classes to better reflect their purpose. Ran existing Python ml tests and generated documentation to test this change. Author: Bryan Cutler <cutlerb@gmail.com> Closes #12304 from BryanCutler/pyspark-cleanup-JavaWrapper-SPARK-14472.
-rw-r--r--python/pyspark/ml/classification.py4
-rw-r--r--python/pyspark/ml/evaluation.py4
-rw-r--r--python/pyspark/ml/pipeline.py10
-rw-r--r--python/pyspark/ml/regression.py4
-rw-r--r--python/pyspark/ml/tests.py4
-rw-r--r--python/pyspark/ml/tuning.py26
-rw-r--r--python/pyspark/ml/util.py4
-rw-r--r--python/pyspark/ml/wrapper.py76
8 files changed, 62 insertions, 70 deletions
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index e64c7a392b..922f8069fa 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -19,7 +19,7 @@ import warnings
from pyspark import since
from pyspark.ml.util import *
-from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaCallable
+from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaWrapper
from pyspark.ml.param import TypeConverters
from pyspark.ml.param.shared import *
from pyspark.ml.regression import (
@@ -272,7 +272,7 @@ class LogisticRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
return BinaryLogisticRegressionSummary(java_blr_summary)
-class LogisticRegressionSummary(JavaCallable):
+class LogisticRegressionSummary(JavaWrapper):
"""
Abstraction for Logistic Regression Results for a given model.
diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py
index c9b95b3bf4..4b0bade102 100644
--- a/python/pyspark/ml/evaluation.py
+++ b/python/pyspark/ml/evaluation.py
@@ -18,7 +18,7 @@
from abc import abstractmethod, ABCMeta
from pyspark import since
-from pyspark.ml.wrapper import JavaWrapper
+from pyspark.ml.wrapper import JavaParams
from pyspark.ml.param import Param, Params
from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol, HasRawPredictionCol
from pyspark.ml.util import keyword_only
@@ -81,7 +81,7 @@ class Evaluator(Params):
@inherit_doc
-class JavaEvaluator(Evaluator, JavaWrapper):
+class JavaEvaluator(JavaParams, Evaluator):
"""
Base class for :py:class:`Evaluator`s that wrap Java/Scala
implementations.
diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index 2b5504bc29..9d654e8b0f 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -25,7 +25,7 @@ from pyspark import since
from pyspark.ml import Estimator, Model, Transformer
from pyspark.ml.param import Param, Params
from pyspark.ml.util import keyword_only, JavaMLWriter, JavaMLReader, MLReadable, MLWritable
-from pyspark.ml.wrapper import JavaWrapper
+from pyspark.ml.wrapper import JavaParams
from pyspark.mllib.common import inherit_doc
@@ -177,7 +177,7 @@ class Pipeline(Estimator, MLReadable, MLWritable):
# Create a new instance of this stage.
py_stage = cls()
# Load information from java_stage to the instance.
- py_stages = [JavaWrapper._from_java(s) for s in java_stage.getStages()]
+ py_stages = [JavaParams._from_java(s) for s in java_stage.getStages()]
py_stage.setStages(py_stages)
py_stage._resetUid(java_stage.uid())
return py_stage
@@ -195,7 +195,7 @@ class Pipeline(Estimator, MLReadable, MLWritable):
for idx, stage in enumerate(self.getStages()):
java_stages[idx] = stage._to_java()
- _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.Pipeline", self.uid)
+ _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.Pipeline", self.uid)
_java_obj.setStages(java_stages)
return _java_obj
@@ -275,7 +275,7 @@ class PipelineModel(Model, MLReadable, MLWritable):
Used for ML persistence.
"""
# Load information from java_stage to the instance.
- py_stages = [JavaWrapper._from_java(s) for s in java_stage.stages()]
+ py_stages = [JavaParams._from_java(s) for s in java_stage.stages()]
# Create a new instance of this stage.
py_stage = cls(py_stages)
py_stage._resetUid(java_stage.uid())
@@ -295,6 +295,6 @@ class PipelineModel(Model, MLReadable, MLWritable):
java_stages[idx] = stage._to_java()
_java_obj =\
- JavaWrapper._new_java_obj("org.apache.spark.ml.PipelineModel", self.uid, java_stages)
+ JavaParams._new_java_obj("org.apache.spark.ml.PipelineModel", self.uid, java_stages)
return _java_obj
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index bc88f88b7f..316d7e30bc 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -20,7 +20,7 @@ import warnings
from pyspark import since
from pyspark.ml.param.shared import *
from pyspark.ml.util import *
-from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaCallable
+from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaWrapper
from pyspark.mllib.common import inherit_doc
from pyspark.sql import DataFrame
@@ -188,7 +188,7 @@ class LinearRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
return LinearRegressionSummary(java_lr_summary)
-class LinearRegressionSummary(JavaCallable):
+class LinearRegressionSummary(JavaWrapper):
"""
.. note:: Experimental
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 2dcd5eeb52..bcbeacbe80 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -52,7 +52,7 @@ from pyspark.ml.regression import LinearRegression, DecisionTreeRegressor
from pyspark.ml.tuning import *
from pyspark.ml.util import keyword_only
from pyspark.ml.util import MLWritable, MLWriter
-from pyspark.ml.wrapper import JavaWrapper
+from pyspark.ml.wrapper import JavaParams
from pyspark.mllib.linalg import Vectors, DenseVector, SparseVector
from pyspark.sql import DataFrame, SQLContext, Row
from pyspark.sql.functions import rand
@@ -644,7 +644,7 @@ class PersistenceTest(PySparkTestCase):
"""
self.assertEqual(m1.uid, m2.uid)
self.assertEqual(type(m1), type(m2))
- if isinstance(m1, JavaWrapper):
+ if isinstance(m1, JavaParams):
self.assertEqual(len(m1.params), len(m2.params))
for p in m1.params:
self.assertEqual(m1.getOrDefault(p), m2.getOrDefault(p))
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index ea8c61b7ef..456d79d897 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -24,7 +24,7 @@ from pyspark.ml import Estimator, Model
from pyspark.ml.param import Params, Param, TypeConverters
from pyspark.ml.param.shared import HasSeed
from pyspark.ml.util import keyword_only, JavaMLWriter, JavaMLReader, MLReadable, MLWritable
-from pyspark.ml.wrapper import JavaWrapper
+from pyspark.ml.wrapper import JavaParams
from pyspark.sql.functions import rand
from pyspark.mllib.common import inherit_doc, _py2java
@@ -148,8 +148,8 @@ class ValidatorParams(HasSeed):
"""
# Load information from java_stage to the instance.
- estimator = JavaWrapper._from_java(java_stage.getEstimator())
- evaluator = JavaWrapper._from_java(java_stage.getEvaluator())
+ estimator = JavaParams._from_java(java_stage.getEstimator())
+ evaluator = JavaParams._from_java(java_stage.getEvaluator())
epms = [estimator._transfer_param_map_from_java(epm)
for epm in java_stage.getEstimatorParamMaps()]
return estimator, epms, evaluator
@@ -329,7 +329,7 @@ class CrossValidator(Estimator, ValidatorParams, MLReadable, MLWritable):
estimator, epms, evaluator = super(CrossValidator, self)._to_java_impl()
- _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.tuning.CrossValidator", self.uid)
+ _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.CrossValidator", self.uid)
_java_obj.setEstimatorParamMaps(epms)
_java_obj.setEvaluator(evaluator)
_java_obj.setEstimator(estimator)
@@ -393,7 +393,7 @@ class CrossValidatorModel(Model, ValidatorParams, MLReadable, MLWritable):
"""
# Load information from java_stage to the instance.
- bestModel = JavaWrapper._from_java(java_stage.bestModel())
+ bestModel = JavaParams._from_java(java_stage.bestModel())
estimator, epms, evaluator = super(CrossValidatorModel, cls)._from_java_impl(java_stage)
# Create a new instance of this stage.
py_stage = cls(bestModel=bestModel)\
@@ -410,10 +410,10 @@ class CrossValidatorModel(Model, ValidatorParams, MLReadable, MLWritable):
sc = SparkContext._active_spark_context
- _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.tuning.CrossValidatorModel",
- self.uid,
- self.bestModel._to_java(),
- _py2java(sc, []))
+ _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.CrossValidatorModel",
+ self.uid,
+ self.bestModel._to_java(),
+ _py2java(sc, []))
estimator, epms, evaluator = super(CrossValidatorModel, self)._to_java_impl()
_java_obj.set("evaluator", evaluator)
@@ -574,8 +574,8 @@ class TrainValidationSplit(Estimator, ValidatorParams, MLReadable, MLWritable):
estimator, epms, evaluator = super(TrainValidationSplit, self)._to_java_impl()
- _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.tuning.TrainValidationSplit",
- self.uid)
+ _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.TrainValidationSplit",
+ self.uid)
_java_obj.setEstimatorParamMaps(epms)
_java_obj.setEvaluator(evaluator)
_java_obj.setEstimator(estimator)
@@ -639,7 +639,7 @@ class TrainValidationSplitModel(Model, ValidatorParams, MLReadable, MLWritable):
"""
# Load information from java_stage to the instance.
- bestModel = JavaWrapper._from_java(java_stage.bestModel())
+ bestModel = JavaParams._from_java(java_stage.bestModel())
estimator, epms, evaluator = \
super(TrainValidationSplitModel, cls)._from_java_impl(java_stage)
# Create a new instance of this stage.
@@ -657,7 +657,7 @@ class TrainValidationSplitModel(Model, ValidatorParams, MLReadable, MLWritable):
sc = SparkContext._active_spark_context
- _java_obj = JavaWrapper._new_java_obj(
+ _java_obj = JavaParams._new_java_obj(
"org.apache.spark.ml.tuning.TrainValidationSplitModel",
self.uid,
self.bestModel._to_java(),
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index d4411fdfb9..9dfcef0e40 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -99,7 +99,7 @@ class MLWriter(object):
@inherit_doc
class JavaMLWriter(MLWriter):
"""
- (Private) Specialization of :py:class:`MLWriter` for :py:class:`JavaWrapper` types
+ (Private) Specialization of :py:class:`MLWriter` for :py:class:`JavaParams` types
"""
def __init__(self, instance):
@@ -178,7 +178,7 @@ class MLReader(object):
@inherit_doc
class JavaMLReader(MLReader):
"""
- (Private) Specialization of :py:class:`MLReader` for :py:class:`JavaWrapper` types
+ (Private) Specialization of :py:class:`MLReader` for :py:class:`JavaParams` types
"""
def __init__(self, clazz):
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index bbeb6cfe6f..cd0e5b80d5 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -25,29 +25,32 @@ from pyspark.ml.util import _jvm
from pyspark.mllib.common import inherit_doc, _java2py, _py2java
-@inherit_doc
-class JavaWrapper(Params):
+class JavaWrapper(object):
"""
- Utility class to help create wrapper classes from Java/Scala
- implementations of pipeline components.
+ Wrapper class for a Java companion object
"""
+ def __init__(self, java_obj=None):
+ super(JavaWrapper, self).__init__()
+ self._java_obj = java_obj
- __metaclass__ = ABCMeta
-
- def __init__(self):
+ @classmethod
+ def _create_from_java_class(cls, java_class, *args):
"""
- Initialize the wrapped java object to None
+ Construct this object from given Java classname and arguments
"""
- super(JavaWrapper, self).__init__()
- #: The wrapped Java companion object. Subclasses should initialize
- #: it properly. The param values in the Java object should be
- #: synced with the Python wrapper in fit/transform/evaluate/copy.
- self._java_obj = None
+ java_obj = JavaWrapper._new_java_obj(java_class, *args)
+ return cls(java_obj)
+
+ def _call_java(self, name, *args):
+ m = getattr(self._java_obj, name)
+ sc = SparkContext._active_spark_context
+ java_args = [_py2java(sc, arg) for arg in args]
+ return _java2py(sc, m(*java_args))
@staticmethod
def _new_java_obj(java_class, *args):
"""
- Construct a new Java object.
+ Returns a new Java object.
"""
sc = SparkContext._active_spark_context
java_obj = _jvm()
@@ -56,6 +59,18 @@ class JavaWrapper(Params):
java_args = [_py2java(sc, arg) for arg in args]
return java_obj(*java_args)
+
+@inherit_doc
+class JavaParams(JavaWrapper, Params):
+ """
+ Utility class to help create wrapper classes from Java/Scala
+ implementations of pipeline components.
+ """
+ #: The param values in the Java object should be
+ #: synced with the Python wrapper in fit/transform/evaluate/copy.
+
+ __metaclass__ = ABCMeta
+
def _make_java_param_pair(self, param, value):
"""
Makes a Java parm pair.
@@ -151,7 +166,7 @@ class JavaWrapper(Params):
stage_name = java_stage.getClass().getName().replace("org.apache.spark", "pyspark")
# Generate a default new instance from the stage_name class.
py_type = __get_class(stage_name)
- if issubclass(py_type, JavaWrapper):
+ if issubclass(py_type, JavaParams):
# Load information from java_stage to the instance.
py_stage = py_type()
py_stage._java_obj = java_stage
@@ -166,7 +181,7 @@ class JavaWrapper(Params):
@inherit_doc
-class JavaEstimator(Estimator, JavaWrapper):
+class JavaEstimator(JavaParams, Estimator):
"""
Base class for :py:class:`Estimator`s that wrap Java/Scala
implementations.
@@ -199,7 +214,7 @@ class JavaEstimator(Estimator, JavaWrapper):
@inherit_doc
-class JavaTransformer(Transformer, JavaWrapper):
+class JavaTransformer(JavaParams, Transformer):
"""
Base class for :py:class:`Transformer`s that wrap Java/Scala
implementations. Subclasses should ensure they have the transformer Java object
@@ -213,30 +228,8 @@ class JavaTransformer(Transformer, JavaWrapper):
return DataFrame(self._java_obj.transform(dataset._jdf), dataset.sql_ctx)
-class JavaCallable(object):
- """
- Wrapper for a plain object in JVM to make Java calls, can be used
- as a mixin to another class that defines a _java_obj wrapper
- """
- def __init__(self, java_obj=None, sc=None):
- super(JavaCallable, self).__init__()
- self._sc = sc if sc is not None else SparkContext._active_spark_context
- # if this class is a mixin and _java_obj is already defined then don't initialize
- if java_obj is not None or not hasattr(self, "_java_obj"):
- self._java_obj = java_obj
-
- def __del__(self):
- if self._java_obj is not None:
- self._sc._gateway.detach(self._java_obj)
-
- def _call_java(self, name, *args):
- m = getattr(self._java_obj, name)
- java_args = [_py2java(self._sc, arg) for arg in args]
- return _java2py(self._sc, m(*java_args))
-
-
@inherit_doc
-class JavaModel(Model, JavaCallable, JavaTransformer):
+class JavaModel(JavaTransformer, Model):
"""
Base class for :py:class:`Model`s that wrap Java/Scala
implementations. Subclasses should inherit this class before
@@ -259,9 +252,8 @@ class JavaModel(Model, JavaCallable, JavaTransformer):
these wrappers depend on pyspark.ml.util (both directly and via
other ML classes).
"""
- super(JavaModel, self).__init__()
+ super(JavaModel, self).__init__(java_model)
if java_model is not None:
- self._java_obj = java_model
self.uid = java_model.uid()
def copy(self, extra=None):