diff options
author | GayathriMurali <gayathri.m.softie@gmail.com> | 2016-03-24 19:20:49 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2016-03-24 19:20:49 -0700 |
commit | 0874ff3aade705a97f174b642c5db01711d214b3 (patch) | |
tree | a32030e8bb7a8d2ea9d5e91f61d944d21f7b6623 /python/pyspark/ml/tests.py | |
parent | 585097716c1979ea538ef733cf33225ef7be06f5 (diff) | |
download | spark-0874ff3aade705a97f174b642c5db01711d214b3.tar.gz spark-0874ff3aade705a97f174b642c5db01711d214b3.tar.bz2 spark-0874ff3aade705a97f174b642c5db01711d214b3.zip |
[SPARK-13949][ML][PYTHON] PySpark ml DecisionTreeClassifier, Regressor support export/import
## What changes were proposed in this pull request?
Added MLReadable and MLWritable to Decision Tree Classifier and Regressor. Added doctests.
## How was this patch tested?
Python Unit tests. Tests added to check persistence in DecisionTreeClassifier and DecisionTreeRegressor.
Author: GayathriMurali <gayathri.m.softie@gmail.com>
Closes #11892 from GayathriMurali/SPARK-13949.
Diffstat (limited to 'python/pyspark/ml/tests.py')
-rw-r--r-- | python/pyspark/ml/tests.py | 40 |
1 files changed, 38 insertions, 2 deletions
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 2fa5da7738..224232ed7f 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -42,13 +42,13 @@ import tempfile import numpy as np from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer -from pyspark.ml.classification import LogisticRegression +from pyspark.ml.classification import LogisticRegression, DecisionTreeClassifier from pyspark.ml.clustering import KMeans from pyspark.ml.evaluation import RegressionEvaluator from pyspark.ml.feature import * from pyspark.ml.param import Param, Params, TypeConverters from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed -from pyspark.ml.regression import LinearRegression +from pyspark.ml.regression import LinearRegression, DecisionTreeRegressor from pyspark.ml.tuning import * from pyspark.ml.util import keyword_only from pyspark.ml.wrapper import JavaWrapper @@ -655,6 +655,42 @@ class PersistenceTest(PySparkTestCase): except OSError: pass + def test_decisiontree_classifier(self): + dt = DecisionTreeClassifier(maxDepth=1) + path = tempfile.mkdtemp() + dtc_path = path + "/dtc" + dt.save(dtc_path) + dt2 = DecisionTreeClassifier.load(dtc_path) + self.assertEqual(dt2.uid, dt2.maxDepth.parent, + "Loaded DecisionTreeClassifier instance uid (%s) " + "did not match Param's uid (%s)" + % (dt2.uid, dt2.maxDepth.parent)) + self.assertEqual(dt._defaultParamMap[dt.maxDepth], dt2._defaultParamMap[dt2.maxDepth], + "Loaded DecisionTreeClassifier instance default params did not match " + + "original defaults") + try: + rmtree(path) + except OSError: + pass + + def test_decisiontree_regressor(self): + dt = DecisionTreeRegressor(maxDepth=1) + path = tempfile.mkdtemp() + dtr_path = path + "/dtr" + dt.save(dtr_path) + dt2 = DecisionTreeClassifier.load(dtr_path) + self.assertEqual(dt2.uid, dt2.maxDepth.parent, + "Loaded DecisionTreeRegressor instance uid (%s) " + "did not match Param's uid (%s)" + % (dt2.uid, dt2.maxDepth.parent)) + self.assertEqual(dt._defaultParamMap[dt.maxDepth], dt2._defaultParamMap[dt2.maxDepth], + "Loaded DecisionTreeRegressor instance default params did not match " + + "original defaults") + try: + rmtree(path) + except OSError: + pass + class HasThrowableProperty(Params): |