diff options
author | Yanbo Liang <ybliang8@gmail.com> | 2016-03-30 15:47:01 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2016-03-30 15:47:01 -0700 |
commit | f301df37cb63aeecf48077ae56351538e6eeeeb7 (patch) | |
tree | 4963ee32b035e2d9a2585ba051e418a7c9dd0331 | |
parent | 5dc948e8125fd27646a7f1e8991948a45b3f9c50 (diff) | |
download | spark-f301df37cb63aeecf48077ae56351538e6eeeeb7.tar.gz spark-f301df37cb63aeecf48077ae56351538e6eeeeb7.tar.bz2 spark-f301df37cb63aeecf48077ae56351538e6eeeeb7.zip |
[SPARK-14152][ML][PYSPARK] MultilayerPerceptronClassifier supports save/load for Python API
## What changes were proposed in this pull request?
```MultilayerPerceptronClassifier``` supports save/load for Python API.
## How was this patch tested?
doctest.
cc mengxr jkbradley yinxusen
Author: Yanbo Liang <ybliang8@gmail.com>
Closes #11952 from yanboliang/spark-14152.
-rw-r--r-- | python/pyspark/ml/classification.py | 16 |
1 files changed, 14 insertions, 2 deletions
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index d51b80e16c..07cafa0993 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -762,7 +762,7 @@ class NaiveBayesModel(JavaModel, JavaMLWritable, JavaMLReadable): @inherit_doc class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, - HasMaxIter, HasTol, HasSeed): + HasMaxIter, HasTol, HasSeed, JavaMLWritable, JavaMLReadable): """ Classifier trainer based on the Multilayer Perceptron. Each layer has sigmoid activation function, output layer has softmax. @@ -792,6 +792,18 @@ class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, |[0.0,0.0]| 0.0| +---------+----------+ ... + >>> mlp_path = temp_path + "/mlp" + >>> mlp.save(mlp_path) + >>> mlp2 = MultilayerPerceptronClassifier.load(mlp_path) + >>> mlp2.getBlockSize() + 1 + >>> model_path = temp_path + "/mlp_model" + >>> model.save(model_path) + >>> model2 = MultilayerPerceptronClassificationModel.load(model_path) + >>> model.layers == model2.layers + True + >>> model.weights == model2.weights + True .. versionadded:: 1.6.0 """ @@ -869,7 +881,7 @@ class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, return self.getOrDefault(self.blockSize) -class MultilayerPerceptronClassificationModel(JavaModel): +class MultilayerPerceptronClassificationModel(JavaModel, JavaMLWritable, JavaMLReadable): """ Model fitted by MultilayerPerceptronClassifier. |