aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/classification.py
diff options
context:
space:
mode:
authorXusen Yin <yinxusen@gmail.com>2016-04-18 11:52:29 -0700
committerJoseph K. Bradley <joseph@databricks.com>2016-04-18 11:52:29 -0700
commitb64482f49f6b9c7ff0ba64bd3202fe9cc6ad119a (patch)
tree18131b3a63a970be653d9350785dc0ab0bcbbfff /python/pyspark/ml/classification.py
parent775cf17eaaae1a38efe47b282b1d6bbdb99bd759 (diff)
downloadspark-b64482f49f6b9c7ff0ba64bd3202fe9cc6ad119a.tar.gz
spark-b64482f49f6b9c7ff0ba64bd3202fe9cc6ad119a.tar.bz2
spark-b64482f49f6b9c7ff0ba64bd3202fe9cc6ad119a.zip
[SPARK-14306][ML][PYSPARK] PySpark ml.classification OneVsRest support export/import
## What changes were proposed in this pull request? https://issues.apache.org/jira/browse/SPARK-14306 Add PySpark OneVsRest save/load supports. ## How was this patch tested? Test with Python unit test. Author: Xusen Yin <yinxusen@gmail.com> Closes #12439 from yinxusen/SPARK-14306-0415.
Diffstat (limited to 'python/pyspark/ml/classification.py')
-rw-r--r--python/pyspark/ml/classification.py142
1 files changed, 120 insertions, 22 deletions
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 089316729c..de1321b139 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -23,7 +23,7 @@ from pyspark.ml.param.shared import *
from pyspark.ml.regression import (
RandomForestParams, TreeEnsembleParams, DecisionTreeModel, TreeEnsembleModels)
from pyspark.ml.util import *
-from pyspark.ml.wrapper import JavaEstimator, JavaModel
+from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams
from pyspark.ml.wrapper import JavaWrapper
from pyspark.mllib.common import inherit_doc
from pyspark.sql import DataFrame
@@ -1160,8 +1160,33 @@ class MultilayerPerceptronClassificationModel(JavaModel, JavaMLWritable, JavaMLR
return self._call_java("weights")
+class OneVsRestParams(HasFeaturesCol, HasLabelCol, HasPredictionCol):
+ """
+ Parameters for OneVsRest and OneVsRestModel.
+ """
+
+ classifier = Param(Params._dummy(), "classifier", "base binary classifier")
+
+ @since("2.0.0")
+ def setClassifier(self, value):
+ """
+ Sets the value of :py:attr:`classifier`.
+
+ .. note:: Only LogisticRegression and NaiveBayes are supported now.
+ """
+ self._set(classifier=value)
+ return self
+
+ @since("2.0.0")
+ def getClassifier(self):
+ """
+ Gets the value of classifier or its default value.
+ """
+ return self.getOrDefault(self.classifier)
+
+
@inherit_doc
-class OneVsRest(Estimator, HasFeaturesCol, HasLabelCol, HasPredictionCol):
+class OneVsRest(Estimator, OneVsRestParams, MLReadable, MLWritable):
"""
Reduction of Multiclass Classification to Binary Classification.
Performs reduction using one against all strategy.
@@ -1195,8 +1220,6 @@ class OneVsRest(Estimator, HasFeaturesCol, HasLabelCol, HasPredictionCol):
.. versionadded:: 2.0.0
"""
- classifier = Param(Params._dummy(), "classifier", "base binary classifier")
-
@keyword_only
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
classifier=None):
@@ -1218,23 +1241,6 @@ class OneVsRest(Estimator, HasFeaturesCol, HasLabelCol, HasPredictionCol):
kwargs = self.setParams._input_kwargs
return self._set(**kwargs)
- @since("2.0.0")
- def setClassifier(self, value):
- """
- Sets the value of :py:attr:`classifier`.
-
- .. note:: Only LogisticRegression and NaiveBayes are supported now.
- """
- self._set(classifier=value)
- return self
-
- @since("2.0.0")
- def getClassifier(self):
- """
- Gets the value of classifier or its default value.
- """
- return self.getOrDefault(self.classifier)
-
def _fit(self, dataset):
labelCol = self.getLabelCol()
featuresCol = self.getFeaturesCol()
@@ -1288,8 +1294,53 @@ class OneVsRest(Estimator, HasFeaturesCol, HasLabelCol, HasPredictionCol):
newOvr.setClassifier(self.getClassifier().copy(extra))
return newOvr
+ @since("2.0.0")
+ def write(self):
+ """Returns an MLWriter instance for this ML instance."""
+ return JavaMLWriter(self)
+
+ @since("2.0.0")
+ def save(self, path):
+ """Save this ML instance to the given path, a shortcut of `write().save(path)`."""
+ self.write().save(path)
+
+ @classmethod
+ @since("2.0.0")
+ def read(cls):
+ """Returns an MLReader instance for this class."""
+ return JavaMLReader(cls)
+
+ @classmethod
+ def _from_java(cls, java_stage):
+ """
+ Given a Java OneVsRest, create and return a Python wrapper of it.
+ Used for ML persistence.
+ """
+ featuresCol = java_stage.getFeaturesCol()
+ labelCol = java_stage.getLabelCol()
+ predictionCol = java_stage.getPredictionCol()
+ classifier = JavaParams._from_java(java_stage.getClassifier())
+ py_stage = cls(featuresCol=featuresCol, labelCol=labelCol, predictionCol=predictionCol,
+ classifier=classifier)
+ py_stage._resetUid(java_stage.uid())
+ return py_stage
+
+ def _to_java(self):
+ """
+ Transfer this instance to a Java OneVsRest. Used for ML persistence.
+
+ :return: Java object equivalent to this instance.
+ """
+ _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.classification.OneVsRest",
+ self.uid)
+ _java_obj.setClassifier(self.getClassifier()._to_java())
+ _java_obj.setFeaturesCol(self.getFeaturesCol())
+ _java_obj.setLabelCol(self.getLabelCol())
+ _java_obj.setPredictionCol(self.getPredictionCol())
+ return _java_obj
-class OneVsRestModel(Model, HasFeaturesCol, HasLabelCol, HasPredictionCol):
+
+class OneVsRestModel(Model, OneVsRestParams, MLReadable, MLWritable):
"""
Model fitted by OneVsRest.
This stores the models resulting from training k binary classifiers: one for each class.
@@ -1367,6 +1418,53 @@ class OneVsRestModel(Model, HasFeaturesCol, HasLabelCol, HasPredictionCol):
newModel.models = [model.copy(extra) for model in self.models]
return newModel
+ @since("2.0.0")
+ def write(self):
+ """Returns an MLWriter instance for this ML instance."""
+ return JavaMLWriter(self)
+
+ @since("2.0.0")
+ def save(self, path):
+ """Save this ML instance to the given path, a shortcut of `write().save(path)`."""
+ self.write().save(path)
+
+ @classmethod
+ @since("2.0.0")
+ def read(cls):
+ """Returns an MLReader instance for this class."""
+ return JavaMLReader(cls)
+
+ @classmethod
+ def _from_java(cls, java_stage):
+ """
+ Given a Java OneVsRestModel, create and return a Python wrapper of it.
+ Used for ML persistence.
+ """
+ featuresCol = java_stage.getFeaturesCol()
+ labelCol = java_stage.getLabelCol()
+ predictionCol = java_stage.getPredictionCol()
+ classifier = JavaParams._from_java(java_stage.getClassifier())
+ models = [JavaParams._from_java(model) for model in java_stage.models()]
+ py_stage = cls(models=models).setPredictionCol(predictionCol).setLabelCol(labelCol)\
+ .setFeaturesCol(featuresCol).setClassifier(classifier)
+ py_stage._resetUid(java_stage.uid())
+ return py_stage
+
+ def _to_java(self):
+ """
+ Transfer this instance to a Java OneVsRestModel. Used for ML persistence.
+
+ :return: Java object equivalent to this instance.
+ """
+ java_models = [model._to_java() for model in self.models]
+ _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.classification.OneVsRestModel",
+ self.uid, java_models)
+ _java_obj.set("classifier", self.getClassifier()._to_java())
+ _java_obj.set("featuresCol", self.getFeaturesCol())
+ _java_obj.set("labelCol", self.getLabelCol())
+ _java_obj.set("predictionCol", self.getPredictionCol())
+ return _java_obj
+
if __name__ == "__main__":
import doctest