aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/tests.py
diff options
context:
space:
mode:
author: GayathriMurali <gayathri.m.softie@gmail.com> 2016-03-24 19:20:49 -0700
committer: Xiangrui Meng <meng@databricks.com> 2016-03-24 19:20:49 -0700
commit: 0874ff3aade705a97f174b642c5db01711d214b3 (patch)
tree: a32030e8bb7a8d2ea9d5e91f61d944d21f7b6623 /python/pyspark/ml/tests.py
parent: 585097716c1979ea538ef733cf33225ef7be06f5 (diff)
downloadspark-0874ff3aade705a97f174b642c5db01711d214b3.tar.gz
spark-0874ff3aade705a97f174b642c5db01711d214b3.tar.bz2
spark-0874ff3aade705a97f174b642c5db01711d214b3.zip
[SPARK-13949][ML][PYTHON] PySpark ml DecisionTreeClassifier, Regressor support export/import
## What changes were proposed in this pull request?

Added MLReadable and MLWritable to Decision Tree Classifier and Regressor. Added doctests.

## How was this patch tested?

Python Unit tests. Tests added to check persistence in DecisionTreeClassifier and DecisionTreeRegressor.

Author: GayathriMurali <gayathri.m.softie@gmail.com>

Closes #11892 from GayathriMurali/SPARK-13949.
Diffstat (limited to 'python/pyspark/ml/tests.py')
-rw-r--r--python/pyspark/ml/tests.py40
1 files changed, 38 insertions, 2 deletions
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 2fa5da7738..224232ed7f 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -42,13 +42,13 @@ import tempfile
import numpy as np
from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer
-from pyspark.ml.classification import LogisticRegression
+from pyspark.ml.classification import LogisticRegression, DecisionTreeClassifier
from pyspark.ml.clustering import KMeans
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.feature import *
from pyspark.ml.param import Param, Params, TypeConverters
from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed
-from pyspark.ml.regression import LinearRegression
+from pyspark.ml.regression import LinearRegression, DecisionTreeRegressor
from pyspark.ml.tuning import *
from pyspark.ml.util import keyword_only
from pyspark.ml.wrapper import JavaWrapper
@@ -655,6 +655,42 @@ class PersistenceTest(PySparkTestCase):
except OSError:
pass
+ def test_decisiontree_classifier(self):
+ dt = DecisionTreeClassifier(maxDepth=1)
+ path = tempfile.mkdtemp()
+ dtc_path = path + "/dtc"
+ dt.save(dtc_path)
+ dt2 = DecisionTreeClassifier.load(dtc_path)
+ self.assertEqual(dt2.uid, dt2.maxDepth.parent,
+ "Loaded DecisionTreeClassifier instance uid (%s) "
+ "did not match Param's uid (%s)"
+ % (dt2.uid, dt2.maxDepth.parent))
+ self.assertEqual(dt._defaultParamMap[dt.maxDepth], dt2._defaultParamMap[dt2.maxDepth],
+ "Loaded DecisionTreeClassifier instance default params did not match " +
+ "original defaults")
+ try:
+ rmtree(path)
+ except OSError:
+ pass
+
+ def test_decisiontree_regressor(self):
+ dt = DecisionTreeRegressor(maxDepth=1)
+ path = tempfile.mkdtemp()
+ dtr_path = path + "/dtr"
+ dt.save(dtr_path)
+ dt2 = DecisionTreeClassifier.load(dtr_path)
+ self.assertEqual(dt2.uid, dt2.maxDepth.parent,
+ "Loaded DecisionTreeRegressor instance uid (%s) "
+ "did not match Param's uid (%s)"
+ % (dt2.uid, dt2.maxDepth.parent))
+ self.assertEqual(dt._defaultParamMap[dt.maxDepth], dt2._defaultParamMap[dt2.maxDepth],
+ "Loaded DecisionTreeRegressor instance default params did not match " +
+ "original defaults")
+ try:
+ rmtree(path)
+ except OSError:
+ pass
+
class HasThrowableProperty(Params):