aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2016-03-30 15:47:01 -0700
committerJoseph K. Bradley <joseph@databricks.com>2016-03-30 15:47:01 -0700
commitf301df37cb63aeecf48077ae56351538e6eeeeb7 (patch)
tree4963ee32b035e2d9a2585ba051e418a7c9dd0331 /python
parent5dc948e8125fd27646a7f1e8991948a45b3f9c50 (diff)
downloadspark-f301df37cb63aeecf48077ae56351538e6eeeeb7.tar.gz
spark-f301df37cb63aeecf48077ae56351538e6eeeeb7.tar.bz2
spark-f301df37cb63aeecf48077ae56351538e6eeeeb7.zip
[SPARK-14152][ML][PYSPARK] MultilayerPerceptronClassifier supports save/load for Python API
## What changes were proposed in this pull request? ```MultilayerPerceptronClassifier``` supports save/load for Python API. ## How was this patch tested? doctest. cc mengxr jkbradley yinxusen Author: Yanbo Liang <ybliang8@gmail.com> Closes #11952 from yanboliang/spark-14152.
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/ml/classification.py16
1 files changed, 14 insertions, 2 deletions
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index d51b80e16c..07cafa0993 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -762,7 +762,7 @@ class NaiveBayesModel(JavaModel, JavaMLWritable, JavaMLReadable):
@inherit_doc
class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
- HasMaxIter, HasTol, HasSeed):
+ HasMaxIter, HasTol, HasSeed, JavaMLWritable, JavaMLReadable):
"""
Classifier trainer based on the Multilayer Perceptron.
Each layer has sigmoid activation function, output layer has softmax.
@@ -792,6 +792,18 @@ class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol,
|[0.0,0.0]| 0.0|
+---------+----------+
...
+ >>> mlp_path = temp_path + "/mlp"
+ >>> mlp.save(mlp_path)
+ >>> mlp2 = MultilayerPerceptronClassifier.load(mlp_path)
+ >>> mlp2.getBlockSize()
+ 1
+ >>> model_path = temp_path + "/mlp_model"
+ >>> model.save(model_path)
+ >>> model2 = MultilayerPerceptronClassificationModel.load(model_path)
+ >>> model.layers == model2.layers
+ True
+ >>> model.weights == model2.weights
+ True
.. versionadded:: 1.6.0
"""
@@ -869,7 +881,7 @@ class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol,
return self.getOrDefault(self.blockSize)
-class MultilayerPerceptronClassificationModel(JavaModel):
+class MultilayerPerceptronClassificationModel(JavaModel, JavaMLWritable, JavaMLReadable):
"""
Model fitted by MultilayerPerceptronClassifier.