From 0874ff3aade705a97f174b642c5db01711d214b3 Mon Sep 17 00:00:00 2001 From: GayathriMurali Date: Thu, 24 Mar 2016 19:20:49 -0700 Subject: [SPARK-13949][ML][PYTHON] PySpark ml DecisionTreeClassifier, Regressor support export/import ## What changes were proposed in this pull request? Added MLReadable and MLWritable to Decision Tree Classifier and Regressor. Added doctests. ## How was this patch tested? Python Unit tests. Tests added to check persistence in DecisionTreeClassifier and DecisionTreeRegressor. Author: GayathriMurali Closes #11892 from GayathriMurali/SPARK-13949. --- python/pyspark/ml/classification.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) (limited to 'python/pyspark/ml/classification.py') diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 850d775db0..d51b80e16c 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -278,7 +278,8 @@ class GBTParams(TreeEnsembleParams): @inherit_doc class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol, HasRawPredictionCol, DecisionTreeParams, - TreeClassifierParams, HasCheckpointInterval, HasSeed): + TreeClassifierParams, HasCheckpointInterval, HasSeed, JavaMLWritable, + JavaMLReadable): """ `http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree` learning algorithm for classification. @@ -313,6 +314,17 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred >>> model.transform(test1).head().prediction 1.0 + >>> dtc_path = temp_path + "/dtc" + >>> dt.save(dtc_path) + >>> dt2 = DecisionTreeClassifier.load(dtc_path) + >>> dt2.getMaxDepth() + 2 + >>> model_path = temp_path + "/dtc_model" + >>> model.save(model_path) + >>> model2 = DecisionTreeClassificationModel.load(model_path) + >>> model.featureImportances == model2.featureImportances + True + .. versionadded:: 1.4.0 """ @@ -361,7 +373,7 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred @inherit_doc -class DecisionTreeClassificationModel(DecisionTreeModel): +class DecisionTreeClassificationModel(DecisionTreeModel, JavaMLWritable, JavaMLReadable): """ Model fitted by DecisionTreeClassifier. -- cgit v1.2.3