diff options
author | Kai Jiang <jiangkai@gmail.com> | 2016-04-08 10:39:12 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2016-04-08 10:39:12 -0700 |
commit | e5d8d6e09cad304e353c96f9408fb9f799348827 (patch) | |
tree | ed873040f638e1e98e772abd3650af54b2bff9fa /python | |
parent | a9b630f42ac0c6be3437f206beddaf0ef737f5c8 (diff) | |
download | spark-e5d8d6e09cad304e353c96f9408fb9f799348827.tar.gz spark-e5d8d6e09cad304e353c96f9408fb9f799348827.tar.bz2 spark-e5d8d6e09cad304e353c96f9408fb9f799348827.zip |
[SPARK-14373][PYSPARK] PySpark RandomForestClassifier, Regressor support export/import
## What changes were proposed in this pull request?
Add save/load support to the Python API for `RandomForest{Classifier, Regressor}` and their fitted models (`RandomForestClassificationModel`, `RandomForestRegressionModel`).
[JIRA](https://issues.apache.org/jira/browse/SPARK-14373)
## How was this patch tested?
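Doctests: save/load round-trip examples were added to the `RandomForestClassifier` and `RandomForestRegressor` docstrings.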
Author: Kai Jiang <jiangkai@gmail.com>
Closes #12238 from vectorijk/spark-14373.
Diffstat (limited to 'python')
-rw-r--r-- | python/pyspark/ml/classification.py | 15 | ||||
-rw-r--r-- | python/pyspark/ml/regression.py | 15 |
2 files changed, 26 insertions, 4 deletions
```diff
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index be7f9ea9ef..d98919b3c6 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -621,7 +621,8 @@ class DecisionTreeClassificationModel(DecisionTreeModel, JavaMLWritable, JavaMLR
 @inherit_doc
 class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasSeed,
                              HasRawPredictionCol, HasProbabilityCol,
-                             RandomForestParams, TreeClassifierParams, HasCheckpointInterval):
+                             RandomForestParams, TreeClassifierParams, HasCheckpointInterval,
+                             JavaMLWritable, JavaMLReadable):
     """
     `http://en.wikipedia.org/wiki/Random_forest Random Forest`
     learning algorithm for classification.
@@ -655,6 +656,16 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
     >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
     >>> model.transform(test1).head().prediction
     1.0
+    >>> rfc_path = temp_path + "/rfc"
+    >>> rf.save(rfc_path)
+    >>> rf2 = RandomForestClassifier.load(rfc_path)
+    >>> rf2.getNumTrees()
+    3
+    >>> model_path = temp_path + "/rfc_model"
+    >>> model.save(model_path)
+    >>> model2 = RandomForestClassificationModel.load(model_path)
+    >>> model.featureImportances == model2.featureImportances
+    True
 
     .. versionadded:: 1.4.0
     """
@@ -703,7 +714,7 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
         return RandomForestClassificationModel(java_model)
 
 
-class RandomForestClassificationModel(TreeEnsembleModels):
+class RandomForestClassificationModel(TreeEnsembleModels, JavaMLWritable, JavaMLReadable):
     """
     Model fitted by RandomForestClassifier.
 
```
```diff
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 6cd1b4bf3a..00a6a0de90 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -782,7 +782,8 @@ class DecisionTreeRegressionModel(DecisionTreeModel, JavaMLWritable, JavaMLReada
 @inherit_doc
 class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasSeed,
-                            RandomForestParams, TreeRegressorParams, HasCheckpointInterval):
+                            RandomForestParams, TreeRegressorParams, HasCheckpointInterval,
+                            JavaMLWritable, JavaMLReadable):
     """
     `http://en.wikipedia.org/wiki/Random_forest Random Forest`
     learning algorithm for regression.
 
@@ -805,6 +806,16 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
     >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
     >>> model.transform(test1).head().prediction
     0.5
+    >>> rfr_path = temp_path + "/rfr"
+    >>> rf.save(rfr_path)
+    >>> rf2 = RandomForestRegressor.load(rfr_path)
+    >>> rf2.getNumTrees()
+    2
+    >>> model_path = temp_path + "/rfr_model"
+    >>> model.save(model_path)
+    >>> model2 = RandomForestRegressionModel.load(model_path)
+    >>> model.featureImportances == model2.featureImportances
+    True
 
     .. versionadded:: 1.4.0
     """
@@ -854,7 +865,7 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
         return RandomForestRegressionModel(java_model)
 
 
-class RandomForestRegressionModel(TreeEnsembleModels):
+class RandomForestRegressionModel(TreeEnsembleModels, JavaMLWritable, JavaMLReadable):
     """
     Model fitted by RandomForestRegressor.
 
```