aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark
diff options
context:
space:
mode:
authorXusen Yin <yinxusen@gmail.com>2016-03-20 15:34:34 -0700
committerJoseph K. Bradley <joseph@databricks.com>2016-03-20 15:34:34 -0700
commit454a00df2a43176cb774cad7277934a775618db1 (patch)
tree8eadef4005a35407a6711f28a808711057f1df92 /python/pyspark
parent811a5247227b5c68e6cd74c0a88d809862184507 (diff)
downloadspark-454a00df2a43176cb774cad7277934a775618db1.tar.gz
spark-454a00df2a43176cb774cad7277934a775618db1.tar.bz2
spark-454a00df2a43176cb774cad7277934a775618db1.zip
[SPARK-13993][PYSPARK] Add pyspark Rformula/RforumlaModel save/load
## What changes were proposed in this pull request? https://issues.apache.org/jira/browse/SPARK-13993 ## How was this patch tested? doctest Author: Xusen Yin <yinxusen@gmail.com> Closes #11807 from yinxusen/SPARK-13993.
Diffstat (limited to 'python/pyspark')
-rw-r--r--python/pyspark/ml/feature.py30
1 files changed, 27 insertions, 3 deletions
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 5025493c42..3182faac0d 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -2360,7 +2360,7 @@ class PCAModel(JavaModel, MLReadable, MLWritable):
@inherit_doc
-class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol):
+class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, MLReadable, MLWritable):
"""
.. note:: Experimental
@@ -2376,7 +2376,8 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol):
... (0.0, 0.0, "a")
... ], ["y", "x", "s"])
>>> rf = RFormula(formula="y ~ x + s")
- >>> rf.fit(df).transform(df).show()
+ >>> model = rf.fit(df)
+ >>> model.transform(df).show()
+---+---+---+---------+-----+
| y| x| s| features|label|
+---+---+---+---------+-----+
@@ -2394,6 +2395,29 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol):
|0.0|0.0| a| [0.0]| 0.0|
+---+---+---+--------+-----+
...
+ >>> rFormulaPath = temp_path + "/rFormula"
+ >>> rf.save(rFormulaPath)
+ >>> loadedRF = RFormula.load(rFormulaPath)
+ >>> loadedRF.getFormula() == rf.getFormula()
+ True
+ >>> loadedRF.getFeaturesCol() == rf.getFeaturesCol()
+ True
+ >>> loadedRF.getLabelCol() == rf.getLabelCol()
+ True
+ >>> modelPath = temp_path + "/rFormulaModel"
+ >>> model.save(modelPath)
+ >>> loadedModel = RFormulaModel.load(modelPath)
+ >>> loadedModel.uid == model.uid
+ True
+ >>> loadedModel.transform(df).show()
+ +---+---+---+---------+-----+
+ | y| x| s| features|label|
+ +---+---+---+---------+-----+
+ |1.0|1.0| a|[1.0,1.0]| 1.0|
+ |0.0|2.0| b|[2.0,0.0]| 0.0|
+ |0.0|0.0| a|[0.0,1.0]| 0.0|
+ +---+---+---+---------+-----+
+ ...
.. versionadded:: 1.5.0
"""
@@ -2439,7 +2463,7 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol):
return RFormulaModel(java_model)
-class RFormulaModel(JavaModel):
+class RFormulaModel(JavaModel, MLReadable, MLWritable):
"""
.. note:: Experimental