diff options
Diffstat (limited to 'python/pyspark/ml/tests.py')
-rw-r--r-- | python/pyspark/ml/tests.py | 9 |
1 files changed, 7 insertions, 2 deletions
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index e93a4e157b..5fcfa9e61f 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -209,6 +209,11 @@ class ParamTests(PySparkTestCase): self.assertEqual(maxIter.doc, "max number of iterations (>= 0).") self.assertTrue(maxIter.parent == testParams.uid) + def test_hasparam(self): + testParams = TestParams() + self.assertTrue(all([testParams.hasParam(p.name) for p in testParams.params])) + self.assertFalse(testParams.hasParam("notAParameter")) + def test_params(self): testParams = TestParams() maxIter = testParams.maxIter @@ -218,7 +223,7 @@ class ParamTests(PySparkTestCase): params = testParams.params self.assertEqual(params, [inputCol, maxIter, seed]) - self.assertTrue(testParams.hasParam(maxIter)) + self.assertTrue(testParams.hasParam(maxIter.name)) self.assertTrue(testParams.hasDefault(maxIter)) self.assertFalse(testParams.isSet(maxIter)) self.assertTrue(testParams.isDefined(maxIter)) @@ -227,7 +232,7 @@ class ParamTests(PySparkTestCase): self.assertTrue(testParams.isSet(maxIter)) self.assertEqual(testParams.getMaxIter(), 100) - self.assertTrue(testParams.hasParam(inputCol)) + self.assertTrue(testParams.hasParam(inputCol.name)) self.assertFalse(testParams.hasDefault(inputCol)) self.assertFalse(testParams.isSet(inputCol)) self.assertFalse(testParams.isDefined(inputCol)) |