aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/tests.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/ml/tests.py')
-rw-r--r--python/pyspark/ml/tests.py9
1 files changed, 7 insertions, 2 deletions
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index e93a4e157b..5fcfa9e61f 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -209,6 +209,11 @@ class ParamTests(PySparkTestCase):
self.assertEqual(maxIter.doc, "max number of iterations (>= 0).")
self.assertTrue(maxIter.parent == testParams.uid)
+ def test_hasparam(self):
+ testParams = TestParams()
+ self.assertTrue(all([testParams.hasParam(p.name) for p in testParams.params]))
+ self.assertFalse(testParams.hasParam("notAParameter"))
+
def test_params(self):
testParams = TestParams()
maxIter = testParams.maxIter
@@ -218,7 +223,7 @@ class ParamTests(PySparkTestCase):
params = testParams.params
self.assertEqual(params, [inputCol, maxIter, seed])
- self.assertTrue(testParams.hasParam(maxIter))
+ self.assertTrue(testParams.hasParam(maxIter.name))
self.assertTrue(testParams.hasDefault(maxIter))
self.assertFalse(testParams.isSet(maxIter))
self.assertTrue(testParams.isDefined(maxIter))
@@ -227,7 +232,7 @@ class ParamTests(PySparkTestCase):
self.assertTrue(testParams.isSet(maxIter))
self.assertEqual(testParams.getMaxIter(), 100)
- self.assertTrue(testParams.hasParam(inputCol))
+ self.assertTrue(testParams.hasParam(inputCol.name))
self.assertFalse(testParams.hasDefault(inputCol))
self.assertFalse(testParams.isSet(inputCol))
self.assertFalse(testParams.isDefined(inputCol))