about summary refs log tree commit diff
path: root/python/pyspark/ml/tests.py
diff options
context:
space:
mode:
author	Xiangrui Meng <meng@databricks.com>	2015-04-15 23:49:42 -0700
committer	Xiangrui Meng <meng@databricks.com>	2015-04-15 23:49:42 -0700
commit	57cd1e86d1d450f85fc9e296aff498a940452113 (patch)
tree	10e973e431fc3ca3e92c823eed077dae5772f5f5 /python/pyspark/ml/tests.py
parent	52c3439a8a107ce1fc10e4f0b59fd7881e851622 (diff)
downloadspark-57cd1e86d1d450f85fc9e296aff498a940452113.tar.gz
spark-57cd1e86d1d450f85fc9e296aff498a940452113.tar.bz2
spark-57cd1e86d1d450f85fc9e296aff498a940452113.zip
[SPARK-6893][ML] default pipeline parameter handling in python
Same as #5431 but for Python. jkbradley Author: Xiangrui Meng <meng@databricks.com> Closes #5534 from mengxr/SPARK-6893 and squashes the following commits: d3b519b [Xiangrui Meng] address comments ebaccc6 [Xiangrui Meng] style update fce244e [Xiangrui Meng] update explainParams with test 4d6b07a [Xiangrui Meng] add tests 5294500 [Xiangrui Meng] update default param handling in python
Diffstat (limited to 'python/pyspark/ml/tests.py')
-rw-r--r--	python/pyspark/ml/tests.py	52
1 files changed, 50 insertions, 2 deletions
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index b627c2b4e9..3a42bcf723 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -33,6 +33,7 @@ else:
from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
from pyspark.sql import DataFrame
from pyspark.ml.param import Param
+from pyspark.ml.param.shared import HasMaxIter, HasInputCol
from pyspark.ml.pipeline import Transformer, Estimator, Pipeline
@@ -46,7 +47,7 @@ class MockTransformer(Transformer):
def __init__(self):
super(MockTransformer, self).__init__()
- self.fake = Param(self, "fake", "fake", None)
+ self.fake = Param(self, "fake", "fake")
self.dataset_index = None
self.fake_param_value = None
@@ -62,7 +63,7 @@ class MockEstimator(Estimator):
def __init__(self):
super(MockEstimator, self).__init__()
- self.fake = Param(self, "fake", "fake", None)
+ self.fake = Param(self, "fake", "fake")
self.dataset_index = None
self.fake_param_value = None
self.model = None
@@ -111,5 +112,52 @@ class PipelineTests(PySparkTestCase):
self.assertEqual(6, dataset.index)
+class TestParams(HasMaxIter, HasInputCol):
+ """
+ A subclass of Params mixed with HasMaxIter and HasInputCol.
+ """
+
+ def __init__(self):
+ super(TestParams, self).__init__()
+ self._setDefault(maxIter=10)
+
+
+class ParamTests(PySparkTestCase):
+
+ def test_param(self):
+ testParams = TestParams()
+ maxIter = testParams.maxIter
+ self.assertEqual(maxIter.name, "maxIter")
+ self.assertEqual(maxIter.doc, "max number of iterations")
+ self.assertTrue(maxIter.parent is testParams)
+
+ def test_params(self):
+ testParams = TestParams()
+ maxIter = testParams.maxIter
+ inputCol = testParams.inputCol
+
+ params = testParams.params
+ self.assertEqual(params, [inputCol, maxIter])
+
+ self.assertTrue(testParams.hasDefault(maxIter))
+ self.assertFalse(testParams.isSet(maxIter))
+ self.assertTrue(testParams.isDefined(maxIter))
+ self.assertEqual(testParams.getMaxIter(), 10)
+ testParams.setMaxIter(100)
+ self.assertTrue(testParams.isSet(maxIter))
+ self.assertEquals(testParams.getMaxIter(), 100)
+
+ self.assertFalse(testParams.hasDefault(inputCol))
+ self.assertFalse(testParams.isSet(inputCol))
+ self.assertFalse(testParams.isDefined(inputCol))
+ with self.assertRaises(KeyError):
+ testParams.getInputCol()
+
+ self.assertEquals(
+ testParams.explainParams(),
+ "\n".join(["inputCol: input column name (undefined)",
+ "maxIter: max number of iterations (default: 10, current: 100)"]))
+
+
if __name__ == "__main__":
unittest.main()