about summary refs log tree commit diff
path: root/python
diff options
context:
space:
mode:
authorXusen Yin <yinxusen@gmail.com>2016-04-18 11:52:29 -0700
committerJoseph K. Bradley <joseph@databricks.com>2016-04-18 11:52:29 -0700
commitb64482f49f6b9c7ff0ba64bd3202fe9cc6ad119a (patch)
tree18131b3a63a970be653d9350785dc0ab0bcbbfff /python
parent775cf17eaaae1a38efe47b282b1d6bbdb99bd759 (diff)
downloadspark-b64482f49f6b9c7ff0ba64bd3202fe9cc6ad119a.tar.gz
spark-b64482f49f6b9c7ff0ba64bd3202fe9cc6ad119a.tar.bz2
spark-b64482f49f6b9c7ff0ba64bd3202fe9cc6ad119a.zip
[SPARK-14306][ML][PYSPARK] PySpark ml.classification OneVsRest support export/import
## What changes were proposed in this pull request?

https://issues.apache.org/jira/browse/SPARK-14306

Add PySpark OneVsRest save/load support.

## How was this patch tested?

Tested with a Python unit test.

Author: Xusen Yin <yinxusen@gmail.com>

Closes #12439 from yinxusen/SPARK-14306-0415.
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/ml/classification.py142
-rw-r--r--python/pyspark/ml/tests.py25
2 files changed, 144 insertions, 23 deletions
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 089316729c..de1321b139 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -23,7 +23,7 @@ from pyspark.ml.param.shared import *
from pyspark.ml.regression import (
RandomForestParams, TreeEnsembleParams, DecisionTreeModel, TreeEnsembleModels)
from pyspark.ml.util import *
-from pyspark.ml.wrapper import JavaEstimator, JavaModel
+from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams
from pyspark.ml.wrapper import JavaWrapper
from pyspark.mllib.common import inherit_doc
from pyspark.sql import DataFrame
@@ -1160,8 +1160,33 @@ class MultilayerPerceptronClassificationModel(JavaModel, JavaMLWritable, JavaMLR
return self._call_java("weights")
class OneVsRestParams(HasFeaturesCol, HasLabelCol, HasPredictionCol):
    """
    Params shared by :py:class:`OneVsRest` and :py:class:`OneVsRestModel`.
    """

    classifier = Param(Params._dummy(), "classifier", "base binary classifier")

    @since("2.0.0")
    def getClassifier(self):
        """
        Gets the value of classifier or its default value.
        """
        return self.getOrDefault(self.classifier)

    @since("2.0.0")
    def setClassifier(self, value):
        """
        Sets the value of :py:attr:`classifier`.

        .. note:: Only LogisticRegression and NaiveBayes are supported now.
        """
        # _set returns self (see setParams elsewhere in this file), so the
        # fluent-setter contract is preserved.
        return self._set(classifier=value)
+
+
@inherit_doc
-class OneVsRest(Estimator, HasFeaturesCol, HasLabelCol, HasPredictionCol):
+class OneVsRest(Estimator, OneVsRestParams, MLReadable, MLWritable):
"""
Reduction of Multiclass Classification to Binary Classification.
Performs reduction using one against all strategy.
@@ -1195,8 +1220,6 @@ class OneVsRest(Estimator, HasFeaturesCol, HasLabelCol, HasPredictionCol):
.. versionadded:: 2.0.0
"""
- classifier = Param(Params._dummy(), "classifier", "base binary classifier")
-
@keyword_only
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
classifier=None):
@@ -1218,23 +1241,6 @@ class OneVsRest(Estimator, HasFeaturesCol, HasLabelCol, HasPredictionCol):
kwargs = self.setParams._input_kwargs
return self._set(**kwargs)
- @since("2.0.0")
- def setClassifier(self, value):
- """
- Sets the value of :py:attr:`classifier`.
-
- .. note:: Only LogisticRegression and NaiveBayes are supported now.
- """
- self._set(classifier=value)
- return self
-
- @since("2.0.0")
- def getClassifier(self):
- """
- Gets the value of classifier or its default value.
- """
- return self.getOrDefault(self.classifier)
-
def _fit(self, dataset):
labelCol = self.getLabelCol()
featuresCol = self.getFeaturesCol()
@@ -1288,8 +1294,53 @@ class OneVsRest(Estimator, HasFeaturesCol, HasLabelCol, HasPredictionCol):
newOvr.setClassifier(self.getClassifier().copy(extra))
return newOvr
+ @since("2.0.0")
+ def write(self):
+ """Returns an MLWriter instance for this ML instance."""
+ return JavaMLWriter(self)
+
+ @since("2.0.0")
+ def save(self, path):
+ """Save this ML instance to the given path, a shortcut of `write().save(path)`."""
+ self.write().save(path)
+
+ @classmethod
+ @since("2.0.0")
+ def read(cls):
+ """Returns an MLReader instance for this class."""
+ return JavaMLReader(cls)
+
+ @classmethod
+ def _from_java(cls, java_stage):
+ """
+ Given a Java OneVsRest, create and return a Python wrapper of it.
+ Used for ML persistence.
+ """
+ featuresCol = java_stage.getFeaturesCol()
+ labelCol = java_stage.getLabelCol()
+ predictionCol = java_stage.getPredictionCol()
+ classifier = JavaParams._from_java(java_stage.getClassifier())
+ py_stage = cls(featuresCol=featuresCol, labelCol=labelCol, predictionCol=predictionCol,
+ classifier=classifier)
+ py_stage._resetUid(java_stage.uid())
+ return py_stage
+
+ def _to_java(self):
+ """
+ Transfer this instance to a Java OneVsRest. Used for ML persistence.
+
+ :return: Java object equivalent to this instance.
+ """
+ _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.classification.OneVsRest",
+ self.uid)
+ _java_obj.setClassifier(self.getClassifier()._to_java())
+ _java_obj.setFeaturesCol(self.getFeaturesCol())
+ _java_obj.setLabelCol(self.getLabelCol())
+ _java_obj.setPredictionCol(self.getPredictionCol())
+ return _java_obj
-class OneVsRestModel(Model, HasFeaturesCol, HasLabelCol, HasPredictionCol):
+
+class OneVsRestModel(Model, OneVsRestParams, MLReadable, MLWritable):
"""
Model fitted by OneVsRest.
This stores the models resulting from training k binary classifiers: one for each class.
@@ -1367,6 +1418,53 @@ class OneVsRestModel(Model, HasFeaturesCol, HasLabelCol, HasPredictionCol):
newModel.models = [model.copy(extra) for model in self.models]
return newModel
+ @since("2.0.0")
+ def write(self):
+ """Returns an MLWriter instance for this ML instance."""
+ return JavaMLWriter(self)
+
+ @since("2.0.0")
+ def save(self, path):
+ """Save this ML instance to the given path, a shortcut of `write().save(path)`."""
+ self.write().save(path)
+
+ @classmethod
+ @since("2.0.0")
+ def read(cls):
+ """Returns an MLReader instance for this class."""
+ return JavaMLReader(cls)
+
+ @classmethod
+ def _from_java(cls, java_stage):
+ """
+ Given a Java OneVsRestModel, create and return a Python wrapper of it.
+ Used for ML persistence.
+ """
+ featuresCol = java_stage.getFeaturesCol()
+ labelCol = java_stage.getLabelCol()
+ predictionCol = java_stage.getPredictionCol()
+ classifier = JavaParams._from_java(java_stage.getClassifier())
+ models = [JavaParams._from_java(model) for model in java_stage.models()]
+ py_stage = cls(models=models).setPredictionCol(predictionCol).setLabelCol(labelCol)\
+ .setFeaturesCol(featuresCol).setClassifier(classifier)
+ py_stage._resetUid(java_stage.uid())
+ return py_stage
+
+ def _to_java(self):
+ """
+ Transfer this instance to a Java OneVsRestModel. Used for ML persistence.
+
+ :return: Java object equivalent to this instance.
+ """
+ java_models = [model._to_java() for model in self.models]
+ _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.classification.OneVsRestModel",
+ self.uid, java_models)
+ _java_obj.set("classifier", self.getClassifier()._to_java())
+ _java_obj.set("featuresCol", self.getFeaturesCol())
+ _java_obj.set("labelCol", self.getLabelCol())
+ _java_obj.set("predictionCol", self.getPredictionCol())
+ return _java_obj
+
if __name__ == "__main__":
import doctest
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index a7a9868bac..9d6ff47b54 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -43,7 +43,8 @@ import tempfile
import numpy as np
from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer
-from pyspark.ml.classification import LogisticRegression, DecisionTreeClassifier, OneVsRest
+from pyspark.ml.classification import (
+ LogisticRegression, DecisionTreeClassifier, OneVsRest, OneVsRestModel)
from pyspark.ml.clustering import KMeans
from pyspark.ml.evaluation import BinaryClassificationEvaluator, RegressionEvaluator
from pyspark.ml.feature import *
@@ -881,6 +882,28 @@ class OneVsRestTests(PySparkTestCase):
output = model.transform(df)
self.assertEqual(output.columns, ["label", "features", "prediction"])
+ def test_save_load(self):
+ temp_path = tempfile.mkdtemp()
+ sqlContext = SQLContext(self.sc)
+ df = sqlContext.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)),
+ (1.0, Vectors.sparse(2, [], [])),
+ (2.0, Vectors.dense(0.5, 0.5))],
+ ["label", "features"])
+ lr = LogisticRegression(maxIter=5, regParam=0.01)
+ ovr = OneVsRest(classifier=lr)
+ model = ovr.fit(df)
+ ovrPath = temp_path + "/ovr"
+ ovr.save(ovrPath)
+ loadedOvr = OneVsRest.load(ovrPath)
+ self.assertEqual(loadedOvr.getFeaturesCol(), ovr.getFeaturesCol())
+ self.assertEqual(loadedOvr.getLabelCol(), ovr.getLabelCol())
+ self.assertEqual(loadedOvr.getClassifier().uid, ovr.getClassifier().uid)
+ modelPath = temp_path + "/ovrModel"
+ model.save(modelPath)
+ loadedModel = OneVsRestModel.load(modelPath)
+ for m, n in zip(model.models, loadedModel.models):
+ self.assertEqual(m.uid, n.uid)
+
class HashingTFTest(PySparkTestCase):