about | summary | refs | log | tree | commit | diff
path: root/python
diff options
context:
space:
mode:
author: Kai Jiang <jiangkai@gmail.com> 2016-04-08 10:39:12 -0700
committer: Joseph K. Bradley <joseph@databricks.com> 2016-04-08 10:39:12 -0700
commit: e5d8d6e09cad304e353c96f9408fb9f799348827 (patch)
tree: ed873040f638e1e98e772abd3650af54b2bff9fa /python
parent: a9b630f42ac0c6be3437f206beddaf0ef737f5c8 (diff)
download: spark-e5d8d6e09cad304e353c96f9408fb9f799348827.tar.gz
spark-e5d8d6e09cad304e353c96f9408fb9f799348827.tar.bz2
spark-e5d8d6e09cad304e353c96f9408fb9f799348827.zip
[SPARK-14373][PYSPARK] PySpark RandomForestClassifier, Regressor support export/import
## What changes were proposed in this pull request?

Support save/load for `RandomForest{Classifier, Regressor}` in the Python API. [JIRA](https://issues.apache.org/jira/browse/SPARK-14373)

## How was this patch tested?

doctest

Author: Kai Jiang <jiangkai@gmail.com>

Closes #12238 from vectorijk/spark-14373.
Diffstat (limited to 'python')
-rw-r--r-- python/pyspark/ml/classification.py 15
-rw-r--r-- python/pyspark/ml/regression.py 15
2 files changed, 26 insertions, 4 deletions
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index be7f9ea9ef..d98919b3c6 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -621,7 +621,8 @@ class DecisionTreeClassificationModel(DecisionTreeModel, JavaMLWritable, JavaMLR
@inherit_doc
class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasSeed,
HasRawPredictionCol, HasProbabilityCol,
- RandomForestParams, TreeClassifierParams, HasCheckpointInterval):
+ RandomForestParams, TreeClassifierParams, HasCheckpointInterval,
+ JavaMLWritable, JavaMLReadable):
"""
`http://en.wikipedia.org/wiki/Random_forest Random Forest`
learning algorithm for classification.
@@ -655,6 +656,16 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
>>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
>>> model.transform(test1).head().prediction
1.0
+ >>> rfc_path = temp_path + "/rfc"
+ >>> rf.save(rfc_path)
+ >>> rf2 = RandomForestClassifier.load(rfc_path)
+ >>> rf2.getNumTrees()
+ 3
+ >>> model_path = temp_path + "/rfc_model"
+ >>> model.save(model_path)
+ >>> model2 = RandomForestClassificationModel.load(model_path)
+ >>> model.featureImportances == model2.featureImportances
+ True
.. versionadded:: 1.4.0
"""
@@ -703,7 +714,7 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
return RandomForestClassificationModel(java_model)
-class RandomForestClassificationModel(TreeEnsembleModels):
+class RandomForestClassificationModel(TreeEnsembleModels, JavaMLWritable, JavaMLReadable):
"""
Model fitted by RandomForestClassifier.
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 6cd1b4bf3a..00a6a0de90 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -782,7 +782,8 @@ class DecisionTreeRegressionModel(DecisionTreeModel, JavaMLWritable, JavaMLReada
@inherit_doc
class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasSeed,
- RandomForestParams, TreeRegressorParams, HasCheckpointInterval):
+ RandomForestParams, TreeRegressorParams, HasCheckpointInterval,
+ JavaMLWritable, JavaMLReadable):
"""
`http://en.wikipedia.org/wiki/Random_forest Random Forest`
learning algorithm for regression.
@@ -805,6 +806,16 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
>>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
>>> model.transform(test1).head().prediction
0.5
+ >>> rfr_path = temp_path + "/rfr"
+ >>> rf.save(rfr_path)
+ >>> rf2 = RandomForestRegressor.load(rfr_path)
+ >>> rf2.getNumTrees()
+ 2
+ >>> model_path = temp_path + "/rfr_model"
+ >>> model.save(model_path)
+ >>> model2 = RandomForestRegressionModel.load(model_path)
+ >>> model.featureImportances == model2.featureImportances
+ True
.. versionadded:: 1.4.0
"""
@@ -854,7 +865,7 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
return RandomForestRegressionModel(java_model)
-class RandomForestRegressionModel(TreeEnsembleModels):
+class RandomForestRegressionModel(TreeEnsembleModels, JavaMLWritable, JavaMLReadable):
"""
Model fitted by RandomForestRegressor.