 python/pyspark/ml/classification.py |  4
 python/pyspark/ml/evaluation.py     |  4
 python/pyspark/ml/pipeline.py       | 10
 python/pyspark/ml/regression.py     |  4
 python/pyspark/ml/tests.py          |  4
 python/pyspark/ml/tuning.py         | 26
 python/pyspark/ml/util.py           |  4
 python/pyspark/ml/wrapper.py        | 76
 8 files changed, 62 insertions(+), 70 deletions(-)
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index e64c7a392b..922f8069fa 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -19,7 +19,7 @@ import warnings
from pyspark import since
from pyspark.ml.util import *
-from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaCallable
+from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaWrapper
from pyspark.ml.param import TypeConverters
from pyspark.ml.param.shared import *
from pyspark.ml.regression import (
@@ -272,7 +272,7 @@ class LogisticRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
return BinaryLogisticRegressionSummary(java_blr_summary)
-class LogisticRegressionSummary(JavaCallable):
+class LogisticRegressionSummary(JavaWrapper):
"""
Abstraction for Logistic Regression Results for a given model.
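
(For orientation: after this rename, summary classes sit on the plain JavaWrapper and reach the JVM through _call_java. A minimal sketch of the pattern; the predictions property mirrors the real LogisticRegressionSummary API, but this is an illustration rather than the full class:)

    from pyspark.ml.wrapper import JavaWrapper

    class LogisticRegressionSummary(JavaWrapper):
        """Wraps the JVM summary object held in self._java_obj."""

        @property
        def predictions(self):
            # _call_java looks up the Java method of the same name and
            # converts arguments and results between Python and the JVM.
            return self._call_java("predictions")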
diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py
index c9b95b3bf4..4b0bade102 100644
--- a/python/pyspark/ml/evaluation.py
+++ b/python/pyspark/ml/evaluation.py
@@ -18,7 +18,7 @@
from abc import abstractmethod, ABCMeta
from pyspark import since
-from pyspark.ml.wrapper import JavaWrapper
+from pyspark.ml.wrapper import JavaParams
from pyspark.ml.param import Param, Params
from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol, HasRawPredictionCol
from pyspark.ml.util import keyword_only
@@ -81,7 +81,7 @@ class Evaluator(Params):
@inherit_doc
-class JavaEvaluator(Evaluator, JavaWrapper):
+class JavaEvaluator(JavaParams, Evaluator):
"""
Base class for :py:class:`Evaluator`s that wrap Java/Scala
implementations.
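
(Listing JavaParams ahead of Evaluator keeps the wrapper machinery first in Python's method resolution order, so the cooperative super() chain reaches JavaWrapper.__init__ and _java_obj gets set. A self-contained sketch of the resulting order, using stand-in classes:)

    class JavaWrapper(object):
        def __init__(self, java_obj=None):
            super(JavaWrapper, self).__init__()
            self._java_obj = java_obj

    class Params(object): pass
    class JavaParams(JavaWrapper, Params): pass
    class Evaluator(Params): pass
    class JavaEvaluator(JavaParams, Evaluator): pass

    # JavaEvaluator -> JavaParams -> JavaWrapper -> Evaluator -> Params -> object
    print([c.__name__ for c in JavaEvaluator.mro()])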
diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index 2b5504bc29..9d654e8b0f 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -25,7 +25,7 @@ from pyspark import since
from pyspark.ml import Estimator, Model, Transformer
from pyspark.ml.param import Param, Params
from pyspark.ml.util import keyword_only, JavaMLWriter, JavaMLReader, MLReadable, MLWritable
-from pyspark.ml.wrapper import JavaWrapper
+from pyspark.ml.wrapper import JavaParams
from pyspark.mllib.common import inherit_doc
@@ -177,7 +177,7 @@ class Pipeline(Estimator, MLReadable, MLWritable):
# Create a new instance of this stage.
py_stage = cls()
# Load information from java_stage to the instance.
- py_stages = [JavaWrapper._from_java(s) for s in java_stage.getStages()]
+ py_stages = [JavaParams._from_java(s) for s in java_stage.getStages()]
py_stage.setStages(py_stages)
py_stage._resetUid(java_stage.uid())
return py_stage
@@ -195,7 +195,7 @@ class Pipeline(Estimator, MLReadable, MLWritable):
for idx, stage in enumerate(self.getStages()):
java_stages[idx] = stage._to_java()
- _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.Pipeline", self.uid)
+ _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.Pipeline", self.uid)
_java_obj.setStages(java_stages)
return _java_obj
@@ -275,7 +275,7 @@ class PipelineModel(Model, MLReadable, MLWritable):
Used for ML persistence.
"""
# Load information from java_stage to the instance.
- py_stages = [JavaWrapper._from_java(s) for s in java_stage.stages()]
+ py_stages = [JavaParams._from_java(s) for s in java_stage.stages()]
# Create a new instance of this stage.
py_stage = cls(py_stages)
py_stage._resetUid(java_stage.uid())
@@ -295,6 +295,6 @@ class PipelineModel(Model, MLReadable, MLWritable):
java_stages[idx] = stage._to_java()
_java_obj =\
- JavaWrapper._new_java_obj("org.apache.spark.ml.PipelineModel", self.uid, java_stages)
+ JavaParams._new_java_obj("org.apache.spark.ml.PipelineModel", self.uid, java_stages)
return _java_obj
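
(These helpers back ML persistence: save() converts the Python Pipeline to its Java twin through _to_java, and load() rebuilds Python stages from Java ones through JavaParams._from_java. A usage sketch; the path and stage choice are illustrative only:)

    from pyspark.ml import Pipeline
    from pyspark.ml.feature import HashingTF
    from pyspark.ml.classification import LogisticRegression

    pipeline = Pipeline(stages=[HashingTF(), LogisticRegression()])
    pipeline.write().overwrite().save("/tmp/pipeline-demo")   # via _to_java
    restored = Pipeline.load("/tmp/pipeline-demo")            # via _from_java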
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index bc88f88b7f..316d7e30bc 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -20,7 +20,7 @@ import warnings
from pyspark import since
from pyspark.ml.param.shared import *
from pyspark.ml.util import *
-from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaCallable
+from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaWrapper
from pyspark.mllib.common import inherit_doc
from pyspark.sql import DataFrame
@@ -188,7 +188,7 @@ class LinearRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
return LinearRegressionSummary(java_lr_summary)
-class LinearRegressionSummary(JavaCallable):
+class LinearRegressionSummary(JavaWrapper):
"""
.. note:: Experimental
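
(The consumer side of the same pattern: a fitted model exposes its summary as one of these JavaWrapper objects, and each property on it is a _call_java hop into the JVM. Sketch only; train_df stands for a hypothetical DataFrame with the default "features" and "label" columns:)

    from pyspark.ml.regression import LinearRegression

    lr_model = LinearRegression().fit(train_df)
    summary = lr_model.summary              # a LinearRegressionSummary
    print(summary.rootMeanSquaredError)     # fetched from the JVM on access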
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 2dcd5eeb52..bcbeacbe80 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -52,7 +52,7 @@ from pyspark.ml.regression import LinearRegression, DecisionTreeRegressor
from pyspark.ml.tuning import *
from pyspark.ml.util import keyword_only
from pyspark.ml.util import MLWritable, MLWriter
-from pyspark.ml.wrapper import JavaWrapper
+from pyspark.ml.wrapper import JavaParams
from pyspark.mllib.linalg import Vectors, DenseVector, SparseVector
from pyspark.sql import DataFrame, SQLContext, Row
from pyspark.sql.functions import rand
@@ -644,7 +644,7 @@ class PersistenceTest(PySparkTestCase):
"""
self.assertEqual(m1.uid, m2.uid)
self.assertEqual(type(m1), type(m2))
- if isinstance(m1, JavaWrapper):
+ if isinstance(m1, JavaParams):
self.assertEqual(len(m1.params), len(m2.params))
for p in m1.params:
self.assertEqual(m1.getOrDefault(p), m2.getOrDefault(p))
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index ea8c61b7ef..456d79d897 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -24,7 +24,7 @@ from pyspark.ml import Estimator, Model
from pyspark.ml.param import Params, Param, TypeConverters
from pyspark.ml.param.shared import HasSeed
from pyspark.ml.util import keyword_only, JavaMLWriter, JavaMLReader, MLReadable, MLWritable
-from pyspark.ml.wrapper import JavaWrapper
+from pyspark.ml.wrapper import JavaParams
from pyspark.sql.functions import rand
from pyspark.mllib.common import inherit_doc, _py2java
@@ -148,8 +148,8 @@ class ValidatorParams(HasSeed):
"""
# Load information from java_stage to the instance.
- estimator = JavaWrapper._from_java(java_stage.getEstimator())
- evaluator = JavaWrapper._from_java(java_stage.getEvaluator())
+ estimator = JavaParams._from_java(java_stage.getEstimator())
+ evaluator = JavaParams._from_java(java_stage.getEvaluator())
epms = [estimator._transfer_param_map_from_java(epm)
for epm in java_stage.getEstimatorParamMaps()]
return estimator, epms, evaluator
@@ -329,7 +329,7 @@ class CrossValidator(Estimator, ValidatorParams, MLReadable, MLWritable):
estimator, epms, evaluator = super(CrossValidator, self)._to_java_impl()
- _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.tuning.CrossValidator", self.uid)
+ _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.CrossValidator", self.uid)
_java_obj.setEstimatorParamMaps(epms)
_java_obj.setEvaluator(evaluator)
_java_obj.setEstimator(estimator)
@@ -393,7 +393,7 @@ class CrossValidatorModel(Model, ValidatorParams, MLReadable, MLWritable):
"""
# Load information from java_stage to the instance.
- bestModel = JavaWrapper._from_java(java_stage.bestModel())
+ bestModel = JavaParams._from_java(java_stage.bestModel())
estimator, epms, evaluator = super(CrossValidatorModel, cls)._from_java_impl(java_stage)
# Create a new instance of this stage.
py_stage = cls(bestModel=bestModel)\
@@ -410,10 +410,10 @@ class CrossValidatorModel(Model, ValidatorParams, MLReadable, MLWritable):
sc = SparkContext._active_spark_context
- _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.tuning.CrossValidatorModel",
- self.uid,
- self.bestModel._to_java(),
- _py2java(sc, []))
+ _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.CrossValidatorModel",
+ self.uid,
+ self.bestModel._to_java(),
+ _py2java(sc, []))
estimator, epms, evaluator = super(CrossValidatorModel, self)._to_java_impl()
_java_obj.set("evaluator", evaluator)
@@ -574,8 +574,8 @@ class TrainValidationSplit(Estimator, ValidatorParams, MLReadable, MLWritable):
estimator, epms, evaluator = super(TrainValidationSplit, self)._to_java_impl()
- _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.tuning.TrainValidationSplit",
- self.uid)
+ _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.TrainValidationSplit",
+ self.uid)
_java_obj.setEstimatorParamMaps(epms)
_java_obj.setEvaluator(evaluator)
_java_obj.setEstimator(estimator)
@@ -639,7 +639,7 @@ class TrainValidationSplitModel(Model, ValidatorParams, MLReadable, MLWritable):
"""
# Load information from java_stage to the instance.
- bestModel = JavaWrapper._from_java(java_stage.bestModel())
+ bestModel = JavaParams._from_java(java_stage.bestModel())
estimator, epms, evaluator = \
super(TrainValidationSplitModel, cls)._from_java_impl(java_stage)
# Create a new instance of this stage.
@@ -657,7 +657,7 @@ class TrainValidationSplitModel(Model, ValidatorParams, MLReadable, MLWritable):
sc = SparkContext._active_spark_context
- _java_obj = JavaWrapper._new_java_obj(
+ _java_obj = JavaParams._new_java_obj(
"org.apache.spark.ml.tuning.TrainValidationSplitModel",
self.uid,
self.bestModel._to_java(),
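
(Every call site above hands _new_java_obj a dotted class name plus constructor arguments. A free-standing rendition of that static method, mirroring the wrapper.py code below, with error handling omitted:)

    from pyspark import SparkContext
    from pyspark.ml.util import _jvm
    from pyspark.mllib.common import _py2java

    def new_java_obj(java_class, *args):
        # Resolve e.g. "org.apache.spark.ml.tuning.CrossValidator" one
        # package segment at a time on the Py4J gateway, then invoke the
        # constructor with arguments converted to their Java counterparts.
        sc = SparkContext._active_spark_context
        java_obj = _jvm()
        for name in java_class.split("."):
            java_obj = getattr(java_obj, name)
        java_args = [_py2java(sc, arg) for arg in args]
        return java_obj(*java_args)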
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index d4411fdfb9..9dfcef0e40 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -99,7 +99,7 @@ class MLWriter(object):
@inherit_doc
class JavaMLWriter(MLWriter):
"""
- (Private) Specialization of :py:class:`MLWriter` for :py:class:`JavaWrapper` types
+ (Private) Specialization of :py:class:`MLWriter` for :py:class:`JavaParams` types
"""
def __init__(self, instance):
@@ -178,7 +178,7 @@ class MLReader(object):
@inherit_doc
class JavaMLReader(MLReader):
"""
- (Private) Specialization of :py:class:`MLReader` for :py:class:`JavaWrapper` types
+ (Private) Specialization of :py:class:`MLReader` for :py:class:`JavaParams` types
"""
def __init__(self, clazz):
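
(Reader and writer pair Python classes with their Java counterparts by a plain package-name swap; the replace() call visible in wrapper.py below applies the same convention in the Java-to-Python direction. A two-line illustration:)

    java_name = "org.apache.spark.ml.classification.LogisticRegression"
    py_name = java_name.replace("org.apache.spark", "pyspark")
    # -> "pyspark.ml.classification.LogisticRegression"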
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index bbeb6cfe6f..cd0e5b80d5 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -25,29 +25,32 @@ from pyspark.ml.util import _jvm
from pyspark.mllib.common import inherit_doc, _java2py, _py2java
-@inherit_doc
-class JavaWrapper(Params):
+class JavaWrapper(object):
"""
- Utility class to help create wrapper classes from Java/Scala
- implementations of pipeline components.
+    Wrapper class for a Java companion object
"""
+ def __init__(self, java_obj=None):
+ super(JavaWrapper, self).__init__()
+ self._java_obj = java_obj
- __metaclass__ = ABCMeta
-
- def __init__(self):
+ @classmethod
+ def _create_from_java_class(cls, java_class, *args):
"""
- Initialize the wrapped java object to None
+        Construct this object from the given Java classname and arguments
"""
- super(JavaWrapper, self).__init__()
- #: The wrapped Java companion object. Subclasses should initialize
- #: it properly. The param values in the Java object should be
- #: synced with the Python wrapper in fit/transform/evaluate/copy.
- self._java_obj = None
+ java_obj = JavaWrapper._new_java_obj(java_class, *args)
+ return cls(java_obj)
+
+ def _call_java(self, name, *args):
+ m = getattr(self._java_obj, name)
+ sc = SparkContext._active_spark_context
+ java_args = [_py2java(sc, arg) for arg in args]
+ return _java2py(sc, m(*java_args))
@staticmethod
def _new_java_obj(java_class, *args):
"""
- Construct a new Java object.
+ Returns a new Java object.
"""
sc = SparkContext._active_spark_context
java_obj = _jvm()
@@ -56,6 +59,18 @@ class JavaWrapper(Params):
java_args = [_py2java(sc, arg) for arg in args]
return java_obj(*java_args)
+
+@inherit_doc
+class JavaParams(JavaWrapper, Params):
+ """
+ Utility class to help create wrapper classes from Java/Scala
+ implementations of pipeline components.
+ """
+ #: The param values in the Java object should be
+ #: synced with the Python wrapper in fit/transform/evaluate/copy.
+
+ __metaclass__ = ABCMeta
+
def _make_java_param_pair(self, param, value):
"""
Makes a Java param pair.
@@ -151,7 +166,7 @@ class JavaWrapper(Params):
stage_name = java_stage.getClass().getName().replace("org.apache.spark", "pyspark")
# Generate a default new instance from the stage_name class.
py_type = __get_class(stage_name)
- if issubclass(py_type, JavaWrapper):
+ if issubclass(py_type, JavaParams):
# Load information from java_stage to the instance.
py_stage = py_type()
py_stage._java_obj = java_stage
@@ -166,7 +181,7 @@ class JavaWrapper(Params):
@inherit_doc
-class JavaEstimator(Estimator, JavaWrapper):
+class JavaEstimator(JavaParams, Estimator):
"""
Base class for :py:class:`Estimator`s that wrap Java/Scala
implementations.
@@ -199,7 +214,7 @@ class JavaEstimator(Estimator, JavaWrapper):
@inherit_doc
-class JavaTransformer(Transformer, JavaWrapper):
+class JavaTransformer(JavaParams, Transformer):
"""
Base class for :py:class:`Transformer`s that wrap Java/Scala
implementations. Subclasses should ensure they have the transformer Java object
@@ -213,30 +228,8 @@ class JavaTransformer(Transformer, JavaWrapper):
return DataFrame(self._java_obj.transform(dataset._jdf), dataset.sql_ctx)
-class JavaCallable(object):
- """
- Wrapper for a plain object in JVM to make Java calls, can be used
- as a mixin to another class that defines a _java_obj wrapper
- """
- def __init__(self, java_obj=None, sc=None):
- super(JavaCallable, self).__init__()
- self._sc = sc if sc is not None else SparkContext._active_spark_context
- # if this class is a mixin and _java_obj is already defined then don't initialize
- if java_obj is not None or not hasattr(self, "_java_obj"):
- self._java_obj = java_obj
-
- def __del__(self):
- if self._java_obj is not None:
- self._sc._gateway.detach(self._java_obj)
-
- def _call_java(self, name, *args):
- m = getattr(self._java_obj, name)
- java_args = [_py2java(self._sc, arg) for arg in args]
- return _java2py(self._sc, m(*java_args))
-
-
@inherit_doc
-class JavaModel(Model, JavaCallable, JavaTransformer):
+class JavaModel(JavaTransformer, Model):
"""
Base class for :py:class:`Model`s that wrap Java/Scala
implementations. Subclasses should inherit this class before
@@ -259,9 +252,8 @@ class JavaModel(Model, JavaCallable, JavaTransformer):
these wrappers depend on pyspark.ml.util (both directly and via
other ML classes).
"""
- super(JavaModel, self).__init__()
+ super(JavaModel, self).__init__(java_model)
if java_model is not None:
- self._java_obj = java_model
self.uid = java_model.uid()
def copy(self, extra=None):
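
(Net effect of the wrapper.py split, as a usage sketch: JavaWrapper is now a thin handle on an arbitrary JVM object, and JavaParams layers Param syncing on top for pipeline stages. java.util.ArrayList is only a stand-in to show the mechanics, and an active SparkContext is assumed:)

    from pyspark.ml.wrapper import JavaWrapper

    # _create_from_java_class constructs the JVM object and wraps it in one
    # step; _call_java then forwards method calls, converting values both ways.
    wrapper = JavaWrapper._create_from_java_class("java.util.ArrayList")
    wrapper._call_java("add", "first element")
    print(wrapper._call_java("size"))   # -> 1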