aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/classification.py
diff options
context:
space:
mode:
authorGayathriMurali <gayathri.m.softie@gmail.com>2016-03-24 19:20:49 -0700
committerXiangrui Meng <meng@databricks.com>2016-03-24 19:20:49 -0700
commit0874ff3aade705a97f174b642c5db01711d214b3 (patch)
treea32030e8bb7a8d2ea9d5e91f61d944d21f7b6623 /python/pyspark/ml/classification.py
parent585097716c1979ea538ef733cf33225ef7be06f5 (diff)
downloadspark-0874ff3aade705a97f174b642c5db01711d214b3.tar.gz
spark-0874ff3aade705a97f174b642c5db01711d214b3.tar.bz2
spark-0874ff3aade705a97f174b642c5db01711d214b3.zip
[SPARK-13949][ML][PYTHON] PySpark ml DecisionTreeClassifier, Regressor support export/import
## What changes were proposed in this pull request? Added MLReadable and MLWritable to Decision Tree Classifier and Regressor. Added doctests. ## How was this patch tested? Python Unit tests. Tests added to check persistence in DecisionTreeClassifier and DecisionTreeRegressor. Author: GayathriMurali <gayathri.m.softie@gmail.com> Closes #11892 from GayathriMurali/SPARK-13949.
Diffstat (limited to 'python/pyspark/ml/classification.py')
-rw-r--r--python/pyspark/ml/classification.py16
1 files changed, 14 insertions, 2 deletions
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 850d775db0..d51b80e16c 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -278,7 +278,8 @@ class GBTParams(TreeEnsembleParams):
@inherit_doc
class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
HasProbabilityCol, HasRawPredictionCol, DecisionTreeParams,
- TreeClassifierParams, HasCheckpointInterval, HasSeed):
+ TreeClassifierParams, HasCheckpointInterval, HasSeed, JavaMLWritable,
+ JavaMLReadable):
"""
`http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree`
learning algorithm for classification.
@@ -313,6 +314,17 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
>>> model.transform(test1).head().prediction
1.0
+ >>> dtc_path = temp_path + "/dtc"
+ >>> dt.save(dtc_path)
+ >>> dt2 = DecisionTreeClassifier.load(dtc_path)
+ >>> dt2.getMaxDepth()
+ 2
+ >>> model_path = temp_path + "/dtc_model"
+ >>> model.save(model_path)
+ >>> model2 = DecisionTreeClassificationModel.load(model_path)
+ >>> model.featureImportances == model2.featureImportances
+ True
+
.. versionadded:: 1.4.0
"""
@@ -361,7 +373,7 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
@inherit_doc
-class DecisionTreeClassificationModel(DecisionTreeModel):
+class DecisionTreeClassificationModel(DecisionTreeModel, JavaMLWritable, JavaMLReadable):
"""
Model fitted by DecisionTreeClassifier.