aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/tests.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/ml/tests.py')
-rw-r--r--python/pyspark/ml/tests.py52
1 files changed, 50 insertions, 2 deletions
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index b627c2b4e9..3a42bcf723 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -33,6 +33,7 @@ else:
from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
from pyspark.sql import DataFrame
from pyspark.ml.param import Param
+from pyspark.ml.param.shared import HasMaxIter, HasInputCol
from pyspark.ml.pipeline import Transformer, Estimator, Pipeline
@@ -46,7 +47,7 @@ class MockTransformer(Transformer):
def __init__(self):
super(MockTransformer, self).__init__()
- self.fake = Param(self, "fake", "fake", None)
+ self.fake = Param(self, "fake", "fake")
self.dataset_index = None
self.fake_param_value = None
@@ -62,7 +63,7 @@ class MockEstimator(Estimator):
def __init__(self):
super(MockEstimator, self).__init__()
- self.fake = Param(self, "fake", "fake", None)
+ self.fake = Param(self, "fake", "fake")
self.dataset_index = None
self.fake_param_value = None
self.model = None
@@ -111,5 +112,52 @@ class PipelineTests(PySparkTestCase):
self.assertEqual(6, dataset.index)
+class TestParams(HasMaxIter, HasInputCol):
+ """
+ A subclass of Params mixed with HasMaxIter and HasInputCol.
+ """
+
+ def __init__(self):
+ super(TestParams, self).__init__()
+ self._setDefault(maxIter=10)
+
+
+class ParamTests(PySparkTestCase):
+
+ def test_param(self):
+ testParams = TestParams()
+ maxIter = testParams.maxIter
+ self.assertEqual(maxIter.name, "maxIter")
+ self.assertEqual(maxIter.doc, "max number of iterations")
+ self.assertTrue(maxIter.parent is testParams)
+
+ def test_params(self):
+ testParams = TestParams()
+ maxIter = testParams.maxIter
+ inputCol = testParams.inputCol
+
+ params = testParams.params
+ self.assertEqual(params, [inputCol, maxIter])
+
+ self.assertTrue(testParams.hasDefault(maxIter))
+ self.assertFalse(testParams.isSet(maxIter))
+ self.assertTrue(testParams.isDefined(maxIter))
+ self.assertEqual(testParams.getMaxIter(), 10)
+ testParams.setMaxIter(100)
+ self.assertTrue(testParams.isSet(maxIter))
+ self.assertEquals(testParams.getMaxIter(), 100)
+
+ self.assertFalse(testParams.hasDefault(inputCol))
+ self.assertFalse(testParams.isSet(inputCol))
+ self.assertFalse(testParams.isDefined(inputCol))
+ with self.assertRaises(KeyError):
+ testParams.getInputCol()
+
+ self.assertEquals(
+ testParams.explainParams(),
+ "\n".join(["inputCol: input column name (undefined)",
+ "maxIter: max number of iterations (default: 10, current: 100)"]))
+
+
if __name__ == "__main__":
unittest.main()