From b64482f49f6b9c7ff0ba64bd3202fe9cc6ad119a Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Mon, 18 Apr 2016 11:52:29 -0700 Subject: [SPARK-14306][ML][PYSPARK] PySpark ml.classification OneVsRest support export/import ## What changes were proposed in this pull request? https://issues.apache.org/jira/browse/SPARK-14306 Add PySpark OneVsRest save/load supports. ## How was this patch tested? Test with Python unit test. Author: Xusen Yin Closes #12439 from yinxusen/SPARK-14306-0415. --- python/pyspark/ml/classification.py | 142 ++++++++++++++++++++++++++++++------ 1 file changed, 120 insertions(+), 22 deletions(-) (limited to 'python/pyspark/ml/classification.py') diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 089316729c..de1321b139 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -23,7 +23,7 @@ from pyspark.ml.param.shared import * from pyspark.ml.regression import ( RandomForestParams, TreeEnsembleParams, DecisionTreeModel, TreeEnsembleModels) from pyspark.ml.util import * -from pyspark.ml.wrapper import JavaEstimator, JavaModel +from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams from pyspark.ml.wrapper import JavaWrapper from pyspark.mllib.common import inherit_doc from pyspark.sql import DataFrame @@ -1160,8 +1160,33 @@ class MultilayerPerceptronClassificationModel(JavaModel, JavaMLWritable, JavaMLR return self._call_java("weights") +class OneVsRestParams(HasFeaturesCol, HasLabelCol, HasPredictionCol): + """ + Parameters for OneVsRest and OneVsRestModel. + """ + + classifier = Param(Params._dummy(), "classifier", "base binary classifier") + + @since("2.0.0") + def setClassifier(self, value): + """ + Sets the value of :py:attr:`classifier`. + + .. note:: Only LogisticRegression and NaiveBayes are supported now. + """ + self._set(classifier=value) + return self + + @since("2.0.0") + def getClassifier(self): + """ + Gets the value of classifier or its default value. + """ + return self.getOrDefault(self.classifier) + + @inherit_doc -class OneVsRest(Estimator, HasFeaturesCol, HasLabelCol, HasPredictionCol): +class OneVsRest(Estimator, OneVsRestParams, MLReadable, MLWritable): """ Reduction of Multiclass Classification to Binary Classification. Performs reduction using one against all strategy. @@ -1195,8 +1220,6 @@ class OneVsRest(Estimator, HasFeaturesCol, HasLabelCol, HasPredictionCol): .. versionadded:: 2.0.0 """ - classifier = Param(Params._dummy(), "classifier", "base binary classifier") - @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", classifier=None): @@ -1218,23 +1241,6 @@ class OneVsRest(Estimator, HasFeaturesCol, HasLabelCol, HasPredictionCol): kwargs = self.setParams._input_kwargs return self._set(**kwargs) - @since("2.0.0") - def setClassifier(self, value): - """ - Sets the value of :py:attr:`classifier`. - - .. note:: Only LogisticRegression and NaiveBayes are supported now. - """ - self._set(classifier=value) - return self - - @since("2.0.0") - def getClassifier(self): - """ - Gets the value of classifier or its default value. - """ - return self.getOrDefault(self.classifier) - def _fit(self, dataset): labelCol = self.getLabelCol() featuresCol = self.getFeaturesCol() @@ -1288,8 +1294,53 @@ class OneVsRest(Estimator, HasFeaturesCol, HasLabelCol, HasPredictionCol): newOvr.setClassifier(self.getClassifier().copy(extra)) return newOvr + @since("2.0.0") + def write(self): + """Returns an MLWriter instance for this ML instance.""" + return JavaMLWriter(self) + + @since("2.0.0") + def save(self, path): + """Save this ML instance to the given path, a shortcut of `write().save(path)`.""" + self.write().save(path) + + @classmethod + @since("2.0.0") + def read(cls): + """Returns an MLReader instance for this class.""" + return JavaMLReader(cls) + + @classmethod + def _from_java(cls, java_stage): + """ + Given a Java OneVsRest, create and return a Python wrapper of it. + Used for ML persistence. + """ + featuresCol = java_stage.getFeaturesCol() + labelCol = java_stage.getLabelCol() + predictionCol = java_stage.getPredictionCol() + classifier = JavaParams._from_java(java_stage.getClassifier()) + py_stage = cls(featuresCol=featuresCol, labelCol=labelCol, predictionCol=predictionCol, + classifier=classifier) + py_stage._resetUid(java_stage.uid()) + return py_stage + + def _to_java(self): + """ + Transfer this instance to a Java OneVsRest. Used for ML persistence. + + :return: Java object equivalent to this instance. + """ + _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.classification.OneVsRest", + self.uid) + _java_obj.setClassifier(self.getClassifier()._to_java()) + _java_obj.setFeaturesCol(self.getFeaturesCol()) + _java_obj.setLabelCol(self.getLabelCol()) + _java_obj.setPredictionCol(self.getPredictionCol()) + return _java_obj -class OneVsRestModel(Model, HasFeaturesCol, HasLabelCol, HasPredictionCol): + +class OneVsRestModel(Model, OneVsRestParams, MLReadable, MLWritable): """ Model fitted by OneVsRest. This stores the models resulting from training k binary classifiers: one for each class. @@ -1367,6 +1418,53 @@ class OneVsRestModel(Model, HasFeaturesCol, HasLabelCol, HasPredictionCol): newModel.models = [model.copy(extra) for model in self.models] return newModel + @since("2.0.0") + def write(self): + """Returns an MLWriter instance for this ML instance.""" + return JavaMLWriter(self) + + @since("2.0.0") + def save(self, path): + """Save this ML instance to the given path, a shortcut of `write().save(path)`.""" + self.write().save(path) + + @classmethod + @since("2.0.0") + def read(cls): + """Returns an MLReader instance for this class.""" + return JavaMLReader(cls) + + @classmethod + def _from_java(cls, java_stage): + """ + Given a Java OneVsRestModel, create and return a Python wrapper of it. + Used for ML persistence. + """ + featuresCol = java_stage.getFeaturesCol() + labelCol = java_stage.getLabelCol() + predictionCol = java_stage.getPredictionCol() + classifier = JavaParams._from_java(java_stage.getClassifier()) + models = [JavaParams._from_java(model) for model in java_stage.models()] + py_stage = cls(models=models).setPredictionCol(predictionCol).setLabelCol(labelCol)\ + .setFeaturesCol(featuresCol).setClassifier(classifier) + py_stage._resetUid(java_stage.uid()) + return py_stage + + def _to_java(self): + """ + Transfer this instance to a Java OneVsRestModel. Used for ML persistence. + + :return: Java object equivalent to this instance. + """ + java_models = [model._to_java() for model in self.models] + _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.classification.OneVsRestModel", + self.uid, java_models) + _java_obj.set("classifier", self.getClassifier()._to_java()) + _java_obj.set("featuresCol", self.getFeaturesCol()) + _java_obj.set("labelCol", self.getLabelCol()) + _java_obj.set("predictionCol", self.getPredictionCol()) + return _java_obj + if __name__ == "__main__": import doctest -- cgit v1.2.3