diff options
Diffstat (limited to 'python/pyspark/ml/classification.py')
-rw-r--r-- | python/pyspark/ml/classification.py | 16 |
1 files changed, 14 insertions, 2 deletions
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 850d775db0..d51b80e16c 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -278,7 +278,8 @@ class GBTParams(TreeEnsembleParams): @inherit_doc class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol, HasRawPredictionCol, DecisionTreeParams, - TreeClassifierParams, HasCheckpointInterval, HasSeed): + TreeClassifierParams, HasCheckpointInterval, HasSeed, JavaMLWritable, + JavaMLReadable): """ `http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree` learning algorithm for classification. @@ -313,6 +314,17 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred >>> model.transform(test1).head().prediction 1.0 + >>> dtc_path = temp_path + "/dtc" + >>> dt.save(dtc_path) + >>> dt2 = DecisionTreeClassifier.load(dtc_path) + >>> dt2.getMaxDepth() + 2 + >>> model_path = temp_path + "/dtc_model" + >>> model.save(model_path) + >>> model2 = DecisionTreeClassificationModel.load(model_path) + >>> model.featureImportances == model2.featureImportances + True + .. versionadded:: 1.4.0 """ @@ -361,7 +373,7 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred @inherit_doc -class DecisionTreeClassificationModel(DecisionTreeModel): +class DecisionTreeClassificationModel(DecisionTreeModel, JavaMLWritable, JavaMLReadable): """ Model fitted by DecisionTreeClassifier. |