diff options
author | Xusen Yin <yinxusen@gmail.com> | 2016-04-18 11:52:29 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2016-04-18 11:52:29 -0700 |
commit | b64482f49f6b9c7ff0ba64bd3202fe9cc6ad119a (patch) | |
tree | 18131b3a63a970be653d9350785dc0ab0bcbbfff /python | |
parent | 775cf17eaaae1a38efe47b282b1d6bbdb99bd759 (diff) | |
download | spark-b64482f49f6b9c7ff0ba64bd3202fe9cc6ad119a.tar.gz spark-b64482f49f6b9c7ff0ba64bd3202fe9cc6ad119a.tar.bz2 spark-b64482f49f6b9c7ff0ba64bd3202fe9cc6ad119a.zip |
[SPARK-14306][ML][PYSPARK] PySpark ml.classification OneVsRest support export/import
## What changes were proposed in this pull request?
https://issues.apache.org/jira/browse/SPARK-14306
Add PySpark OneVsRest save/load support.
## How was this patch tested?
Tested with Python unit tests.
Author: Xusen Yin <yinxusen@gmail.com>
Closes #12439 from yinxusen/SPARK-14306-0415.
Diffstat (limited to 'python')
-rw-r--r-- | python/pyspark/ml/classification.py | 142 | ||||
-rw-r--r-- | python/pyspark/ml/tests.py | 25 |
2 files changed, 144 insertions, 23 deletions
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 089316729c..de1321b139 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -23,7 +23,7 @@ from pyspark.ml.param.shared import * from pyspark.ml.regression import ( RandomForestParams, TreeEnsembleParams, DecisionTreeModel, TreeEnsembleModels) from pyspark.ml.util import * -from pyspark.ml.wrapper import JavaEstimator, JavaModel +from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams from pyspark.ml.wrapper import JavaWrapper from pyspark.mllib.common import inherit_doc from pyspark.sql import DataFrame @@ -1160,8 +1160,33 @@ class MultilayerPerceptronClassificationModel(JavaModel, JavaMLWritable, JavaMLR return self._call_java("weights") +class OneVsRestParams(HasFeaturesCol, HasLabelCol, HasPredictionCol): + """ + Parameters for OneVsRest and OneVsRestModel. + """ + + classifier = Param(Params._dummy(), "classifier", "base binary classifier") + + @since("2.0.0") + def setClassifier(self, value): + """ + Sets the value of :py:attr:`classifier`. + + .. note:: Only LogisticRegression and NaiveBayes are supported now. + """ + self._set(classifier=value) + return self + + @since("2.0.0") + def getClassifier(self): + """ + Gets the value of classifier or its default value. + """ + return self.getOrDefault(self.classifier) + + @inherit_doc -class OneVsRest(Estimator, HasFeaturesCol, HasLabelCol, HasPredictionCol): +class OneVsRest(Estimator, OneVsRestParams, MLReadable, MLWritable): """ Reduction of Multiclass Classification to Binary Classification. Performs reduction using one against all strategy. @@ -1195,8 +1220,6 @@ class OneVsRest(Estimator, HasFeaturesCol, HasLabelCol, HasPredictionCol): .. 
versionadded:: 2.0.0 """ - classifier = Param(Params._dummy(), "classifier", "base binary classifier") - @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", classifier=None): @@ -1218,23 +1241,6 @@ class OneVsRest(Estimator, HasFeaturesCol, HasLabelCol, HasPredictionCol): kwargs = self.setParams._input_kwargs return self._set(**kwargs) - @since("2.0.0") - def setClassifier(self, value): - """ - Sets the value of :py:attr:`classifier`. - - .. note:: Only LogisticRegression and NaiveBayes are supported now. - """ - self._set(classifier=value) - return self - - @since("2.0.0") - def getClassifier(self): - """ - Gets the value of classifier or its default value. - """ - return self.getOrDefault(self.classifier) - def _fit(self, dataset): labelCol = self.getLabelCol() featuresCol = self.getFeaturesCol() @@ -1288,8 +1294,53 @@ class OneVsRest(Estimator, HasFeaturesCol, HasLabelCol, HasPredictionCol): newOvr.setClassifier(self.getClassifier().copy(extra)) return newOvr + @since("2.0.0") + def write(self): + """Returns an MLWriter instance for this ML instance.""" + return JavaMLWriter(self) + + @since("2.0.0") + def save(self, path): + """Save this ML instance to the given path, a shortcut of `write().save(path)`.""" + self.write().save(path) + + @classmethod + @since("2.0.0") + def read(cls): + """Returns an MLReader instance for this class.""" + return JavaMLReader(cls) + + @classmethod + def _from_java(cls, java_stage): + """ + Given a Java OneVsRest, create and return a Python wrapper of it. + Used for ML persistence. 
+ """ + featuresCol = java_stage.getFeaturesCol() + labelCol = java_stage.getLabelCol() + predictionCol = java_stage.getPredictionCol() + classifier = JavaParams._from_java(java_stage.getClassifier()) + py_stage = cls(featuresCol=featuresCol, labelCol=labelCol, predictionCol=predictionCol, + classifier=classifier) + py_stage._resetUid(java_stage.uid()) + return py_stage + + def _to_java(self): + """ + Transfer this instance to a Java OneVsRest. Used for ML persistence. + + :return: Java object equivalent to this instance. + """ + _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.classification.OneVsRest", + self.uid) + _java_obj.setClassifier(self.getClassifier()._to_java()) + _java_obj.setFeaturesCol(self.getFeaturesCol()) + _java_obj.setLabelCol(self.getLabelCol()) + _java_obj.setPredictionCol(self.getPredictionCol()) + return _java_obj -class OneVsRestModel(Model, HasFeaturesCol, HasLabelCol, HasPredictionCol): + +class OneVsRestModel(Model, OneVsRestParams, MLReadable, MLWritable): """ Model fitted by OneVsRest. This stores the models resulting from training k binary classifiers: one for each class. @@ -1367,6 +1418,53 @@ class OneVsRestModel(Model, HasFeaturesCol, HasLabelCol, HasPredictionCol): newModel.models = [model.copy(extra) for model in self.models] return newModel + @since("2.0.0") + def write(self): + """Returns an MLWriter instance for this ML instance.""" + return JavaMLWriter(self) + + @since("2.0.0") + def save(self, path): + """Save this ML instance to the given path, a shortcut of `write().save(path)`.""" + self.write().save(path) + + @classmethod + @since("2.0.0") + def read(cls): + """Returns an MLReader instance for this class.""" + return JavaMLReader(cls) + + @classmethod + def _from_java(cls, java_stage): + """ + Given a Java OneVsRestModel, create and return a Python wrapper of it. + Used for ML persistence. 
+ """ + featuresCol = java_stage.getFeaturesCol() + labelCol = java_stage.getLabelCol() + predictionCol = java_stage.getPredictionCol() + classifier = JavaParams._from_java(java_stage.getClassifier()) + models = [JavaParams._from_java(model) for model in java_stage.models()] + py_stage = cls(models=models).setPredictionCol(predictionCol).setLabelCol(labelCol)\ + .setFeaturesCol(featuresCol).setClassifier(classifier) + py_stage._resetUid(java_stage.uid()) + return py_stage + + def _to_java(self): + """ + Transfer this instance to a Java OneVsRestModel. Used for ML persistence. + + :return: Java object equivalent to this instance. + """ + java_models = [model._to_java() for model in self.models] + _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.classification.OneVsRestModel", + self.uid, java_models) + _java_obj.set("classifier", self.getClassifier()._to_java()) + _java_obj.set("featuresCol", self.getFeaturesCol()) + _java_obj.set("labelCol", self.getLabelCol()) + _java_obj.set("predictionCol", self.getPredictionCol()) + return _java_obj + if __name__ == "__main__": import doctest diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index a7a9868bac..9d6ff47b54 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -43,7 +43,8 @@ import tempfile import numpy as np from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer -from pyspark.ml.classification import LogisticRegression, DecisionTreeClassifier, OneVsRest +from pyspark.ml.classification import ( + LogisticRegression, DecisionTreeClassifier, OneVsRest, OneVsRestModel) from pyspark.ml.clustering import KMeans from pyspark.ml.evaluation import BinaryClassificationEvaluator, RegressionEvaluator from pyspark.ml.feature import * @@ -881,6 +882,28 @@ class OneVsRestTests(PySparkTestCase): output = model.transform(df) self.assertEqual(output.columns, ["label", "features", "prediction"]) + def test_save_load(self): + temp_path = tempfile.mkdtemp() + 
sqlContext = SQLContext(self.sc) + df = sqlContext.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)), + (1.0, Vectors.sparse(2, [], [])), + (2.0, Vectors.dense(0.5, 0.5))], + ["label", "features"]) + lr = LogisticRegression(maxIter=5, regParam=0.01) + ovr = OneVsRest(classifier=lr) + model = ovr.fit(df) + ovrPath = temp_path + "/ovr" + ovr.save(ovrPath) + loadedOvr = OneVsRest.load(ovrPath) + self.assertEqual(loadedOvr.getFeaturesCol(), ovr.getFeaturesCol()) + self.assertEqual(loadedOvr.getLabelCol(), ovr.getLabelCol()) + self.assertEqual(loadedOvr.getClassifier().uid, ovr.getClassifier().uid) + modelPath = temp_path + "/ovrModel" + model.save(modelPath) + loadedModel = OneVsRestModel.load(modelPath) + for m, n in zip(model.models, loadedModel.models): + self.assertEqual(m.uid, n.uid) + class HashingTFTest(PySparkTestCase): |