Diffstat (limited to 'python/pyspark/ml/wrapper.py')
-rw-r--r--  python/pyspark/ml/wrapper.py | 125
1 file changed, 80 insertions(+), 45 deletions(-)
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index dda6c6aba3..4419e16184 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -45,46 +45,61 @@ class JavaWrapper(Params):
__metaclass__ = ABCMeta
- #: Fully-qualified class name of the wrapped Java component.
- _java_class = None
+ #: The wrapped Java companion object. Subclasses should initialize
+ #: it properly. The param values in the Java object should be
+ #: synced with the Python wrapper in fit/transform/evaluate/copy.
+ _java_obj = None
- def _java_obj(self):
+ @staticmethod
+ def _new_java_obj(java_class, *args):
"""
- Returns or creates a Java object.
+ Construct a new Java object.
"""
+ sc = SparkContext._active_spark_context
java_obj = _jvm()
- for name in self._java_class.split("."):
+ for name in java_class.split("."):
java_obj = getattr(java_obj, name)
- return java_obj()
+ java_args = [_py2java(sc, arg) for arg in args]
+ return java_obj(*java_args)
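# A minimal sketch, assuming a live SparkContext, of how a subclass would
# seed its companion object with _new_java_obj. Binarizer is a real Spark
# class, but this wrapper is illustrative, not the shipped one.
from pyspark.ml.wrapper import JavaTransformer

class BinarizerSketch(JavaTransformer):
    def __init__(self):
        super(BinarizerSketch, self).__init__()
        # Positional args (here the uid) are converted with _py2java
        # before the Java constructor runs.
        self._java_obj = self._new_java_obj(
            "org.apache.spark.ml.feature.Binarizer", self.uid)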
- def _transfer_params_to_java(self, params, java_obj):
+ def _make_java_param_pair(self, param, value):
"""
- Transforms the embedded params and additional params to the
- input Java object.
- :param params: additional params (overwriting embedded values)
- :param java_obj: Java object to receive the params
+ Makes a Java param pair.
+ """
+ sc = SparkContext._active_spark_context
+ param = self._resolveParam(param)
+ java_param = self._java_obj.getParam(param.name)
+ java_value = _py2java(sc, value)
+ return java_param.w(java_value)
+
+ def _transfer_params_to_java(self):
+ """
+ Transfers the embedded params to the companion Java object.
"""
- paramMap = self.extractParamMap(params)
+ paramMap = self.extractParamMap()
for param in self.params:
if param in paramMap:
- value = paramMap[param]
- java_param = java_obj.getParam(param.name)
- java_obj.set(java_param.w(value))
+ pair = self._make_java_param_pair(param, paramMap[param])
+ self._java_obj.set(pair)
+
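# Illustrative push for a hypothetical stage `est` with a maxIter param:
# a setter only touches the Python-side _paramMap; the Java companion sees
# the value once _transfer_params_to_java() runs, which fit/transform/
# evaluate do on your behalf.
est.setMaxIter(10)              # updates est._paramMap only
est._transfer_params_to_java()  # Java side now has maxIter == 10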
+ def _transfer_params_from_java(self):
+ """
+ Transfers the embedded params from the companion Java object.
+ """
+ sc = SparkContext._active_spark_context
+ for param in self.params:
+ if self._java_obj.hasParam(param.name):
+ java_param = self._java_obj.getParam(param.name)
+ value = _java2py(sc, self._java_obj.getOrDefault(java_param))
+ self._paramMap[param] = value
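# Direction of each sync, for a hypothetical wrapped stage `stage`:
stage._transfer_params_to_java()    # Python _paramMap -> Java companion
stage._transfer_params_from_java()  # Java companion -> Python _paramMap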
- def _empty_java_param_map(self):
+ @staticmethod
+ def _empty_java_param_map():
"""
Returns an empty Java ParamMap reference.
"""
return _jvm().org.apache.spark.ml.param.ParamMap()
- def _create_java_param_map(self, params, java_obj):
- paramMap = self._empty_java_param_map()
- for param, value in params.items():
- if param.parent is self:
- java_param = java_obj.getParam(param.name)
- paramMap.put(java_param.w(value))
- return paramMap
-
@inherit_doc
class JavaEstimator(Estimator, JavaWrapper):
@@ -99,9 +114,9 @@ class JavaEstimator(Estimator, JavaWrapper):
"""
Creates a model from the input Java model reference.
"""
- return JavaModel(java_model)
+ raise NotImplementedError()
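# Every concrete estimator must now name the Python class that wraps its
# Java model. A hypothetical override:
class MyEstimatorSketch(JavaEstimator):
    def _create_model(self, java_model):
        return MyModelSketch(java_model)  # any JavaModel subclass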
- def _fit_java(self, dataset, params={}):
+ def _fit_java(self, dataset):
"""
Fits a Java model to the input dataset.
:param dataset: input dataset, which is an instance of
@@ -109,12 +124,11 @@ class JavaEstimator(Estimator, JavaWrapper):
- :param params: additional params (overwriting embedded values)
:return: fitted Java model
"""
- java_obj = self._java_obj()
- self._transfer_params_to_java(params, java_obj)
- return java_obj.fit(dataset._jdf, self._empty_java_param_map())
+ self._transfer_params_to_java()
+ return self._java_obj.fit(dataset._jdf)
- def fit(self, dataset, params={}):
- java_model = self._fit_java(dataset, params)
+ def _fit(self, dataset):
+ java_model = self._fit_java(dataset)
return self._create_model(java_model)
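# Hypothetical usage: the public Estimator.fit keeps the optional params
# dict and, when it is non-empty, copies the stage before delegating to
# _fit, which is why _fit itself no longer takes params.
model = est.fit(train_df)                     # _fit_java -> _create_model
model2 = est.fit(train_df, {est.maxIter: 5})  # est.copy({...})._fit(train_df)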
@@ -127,30 +141,47 @@ class JavaTransformer(Transformer, JavaWrapper):
__metaclass__ = ABCMeta
- def transform(self, dataset, params={}):
- java_obj = self._java_obj()
- self._transfer_params_to_java(params, java_obj)
- return DataFrame(java_obj.transform(dataset._jdf), dataset.sql_ctx)
+ def _transform(self, dataset):
+ self._transfer_params_to_java()
+ return DataFrame(self._java_obj.transform(dataset._jdf), dataset.sql_ctx)
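# Transformer.transform mirrors fit: the public method resolves the
# optional params dict (copying the stage when non-empty) before
# _transform runs. Hypothetical usage:
out_df = binarizer.transform(input_df, {binarizer.threshold: 0.5})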
@inherit_doc
class JavaModel(Model, JavaTransformer):
"""
Base class for :py:class:`Model`s that wrap Java/Scala
- implementations.
+ implementations. Subclasses should inherit this class before
+ param mix-ins, because this sets the UID from the Java model.
"""
__metaclass__ = ABCMeta
def __init__(self, java_model):
- super(JavaTransformer, self).__init__()
- self._java_model = java_model
+ """
+ Initialize this instance with a Java model object.
+ Subclasses should call this constructor, initialize params,
+ and then call _transfer_params_from_java.
+ """
+ super(JavaModel, self).__init__()
+ self._java_obj = java_model
+ self.uid = java_model.uid()
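# A sketch of the constructor contract just described. HasMaxIter is a
# real shared-param mix-in; the model class itself is hypothetical.
from pyspark.ml.param.shared import HasMaxIter
from pyspark.ml.wrapper import JavaModel

class MyModelSketch(JavaModel, HasMaxIter):  # JavaModel before the mix-in
    def __init__(self, java_model):
        super(MyModelSketch, self).__init__(java_model)  # sets self.uid
        self._transfer_params_from_java()  # mirror the Java param values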
- def _java_obj(self):
- return self._java_model
+ def copy(self, extra={}):
+ """
+ Creates a copy of this instance with the same uid and some
+ extra params. This implementation first calls Params.copy and
+ then makes a copy of the companion Java model with extra params.
+ So both the Python wrapper and the Java model get copied.
+ :param extra: Extra parameters to copy to the new instance
+ :return: Copy of this instance
+ """
+ that = super(JavaModel, self).copy(extra)
+ that._java_obj = self._java_obj.copy(self._empty_java_param_map())
+ that._transfer_params_to_java()
+ return that
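# The contract in practice (param name hypothetical): both halves are
# duplicated, so tweaking the copy cannot leak into the original.
m2 = model.copy({model.threshold: 0.7})
assert m2.uid == model.uid                   # uid survives, per Params.copy
assert m2._java_obj is not model._java_obj   # fresh Java companion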
def _call_java(self, name, *args):
- m = getattr(self._java_model, name)
+ m = getattr(self._java_obj, name)
sc = SparkContext._active_spark_context
java_args = [_py2java(sc, arg) for arg in args]
return _java2py(sc, m(*java_args))
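# _call_java makes it cheap to surface Java-side members on a model. The
# property below sketches the pattern; it is not a shipped accessor.
class MyLinearModelSketch(JavaModel):
    @property
    def weights(self):
        # invokes javaModel.weights() on the JVM, converting via _java2py
        return self._call_java("weights")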
@@ -165,7 +196,11 @@ class JavaEvaluator(Evaluator, JavaWrapper):
__metaclass__ = ABCMeta
- def evaluate(self, dataset, params={}):
- java_obj = self._java_obj()
- self._transfer_params_to_java(params, java_obj)
- return java_obj.evaluate(dataset._jdf, self._empty_java_param_map())
+ def _evaluate(self, dataset):
+ """
+ Evaluates the output with the companion Java evaluator.
+ :param dataset: a dataset that contains labels/observations and predictions.
+ :return: evaluation metric
+ """
+ self._transfer_params_to_java()
+ return self._java_obj.evaluate(dataset._jdf)
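# Hypothetical end-to-end use: the public Evaluator.evaluate resolves the
# optional params dict, then delegates here.
metric = evaluator.evaluate(predictions_df)  # e.g. areaUnderROC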