diff options
author | GayathriMurali <gayathri.m.softie@gmail.com> | 2016-03-24 19:20:49 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2016-03-24 19:20:49 -0700 |
commit | 0874ff3aade705a97f174b642c5db01711d214b3 (patch) | |
tree | a32030e8bb7a8d2ea9d5e91f61d944d21f7b6623 /python/pyspark/ml/regression.py | |
parent | 585097716c1979ea538ef733cf33225ef7be06f5 (diff) | |
download | spark-0874ff3aade705a97f174b642c5db01711d214b3.tar.gz spark-0874ff3aade705a97f174b642c5db01711d214b3.tar.bz2 spark-0874ff3aade705a97f174b642c5db01711d214b3.zip |
[SPARK-13949][ML][PYTHON] PySpark ml DecisionTreeClassifier, Regressor support export/import
## What changes were proposed in this pull request?
Added MLReadable and MLWritable to Decision Tree Classifier and Regressor. Added doctests.
## How was this patch tested?
Python Unit tests. Tests added to check persistence in DecisionTreeClassifier and DecisionTreeRegressor.
Author: GayathriMurali <gayathri.m.softie@gmail.com>
Closes #11892 from GayathriMurali/SPARK-13949.
Diffstat (limited to 'python/pyspark/ml/regression.py')
-rw-r--r-- | python/pyspark/ml/regression.py | 16 |
1 files changed, 14 insertions, 2 deletions
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 59d4fe3cf4..37648549de 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -389,7 +389,7 @@ class GBTParams(TreeEnsembleParams): @inherit_doc class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, DecisionTreeParams, TreeRegressorParams, HasCheckpointInterval, - HasSeed): + HasSeed, JavaMLWritable, JavaMLReadable): """ `http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree` learning algorithm for regression. @@ -413,6 +413,18 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 1.0 + >>> dtr_path = temp_path + "/dtr" + >>> dt.save(dtr_path) + >>> dt2 = DecisionTreeRegressor.load(dtr_path) + >>> dt2.getMaxDepth() + 2 + >>> model_path = temp_path + "/dtr_model" + >>> model.save(model_path) + >>> model2 = DecisionTreeRegressionModel.load(model_path) + >>> model.numNodes == model2.numNodes + True + >>> model.depth == model2.depth + True .. versionadded:: 1.4.0 """ @@ -498,7 +510,7 @@ class TreeEnsembleModels(JavaModel): @inherit_doc -class DecisionTreeRegressionModel(DecisionTreeModel): +class DecisionTreeRegressionModel(DecisionTreeModel, JavaMLWritable, JavaMLReadable): """ Model fitted by DecisionTreeRegressor. |