diff options
author | sethah <seth.hendrickson16@gmail.com> | 2016-02-11 16:42:44 -0800 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2016-02-11 16:42:44 -0800 |
commit | b35467388612167f0bc3d17142c21a406f6c620d (patch) | |
tree | e4b9297752b83a4377aa08e286f9f968b11e4a1c | |
parent | 30e00955663dfe6079effe4465bbcecedb5d93b9 (diff) | |
download | spark-b35467388612167f0bc3d17142c21a406f6c620d.tar.gz spark-b35467388612167f0bc3d17142c21a406f6c620d.tar.bz2 spark-b35467388612167f0bc3d17142c21a406f6c620d.zip |
[SPARK-13047][PYSPARK][ML] Pyspark Params.hasParam should not throw an error
Pyspark Params class has a method `hasParam(paramName)` which returns `True` if the class has a parameter by that name, but throws an `AttributeError` otherwise. There is not currently a way of getting a Boolean to indicate if a class has a parameter. With Spark 2.0 we could modify the existing behavior of `hasParam` or add an additional method with this functionality.
In Python:
```python
from pyspark.ml.classification import NaiveBayes
nb = NaiveBayes()
print nb.hasParam("smoothing")
print nb.hasParam("notAParam")
```
produces:
> True
> AttributeError: 'NaiveBayes' object has no attribute 'notAParam'
However, in Scala:
```scala
import org.apache.spark.ml.classification.NaiveBayes
val nb = new NaiveBayes()
nb.hasParam("smoothing")
nb.hasParam("notAParam")
```
produces:
> true
> false
cc holdenk
Author: sethah <seth.hendrickson16@gmail.com>
Closes #10962 from sethah/SPARK-13047.
-rw-r--r-- | python/pyspark/ml/param/__init__.py | 7 | ||||
-rw-r--r-- | python/pyspark/ml/tests.py | 9 |
2 files changed, 12 insertions, 4 deletions
diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index ea86d6aeb8..bbf83f0310 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -179,8 +179,11 @@ class Params(Identifiable): Tests whether this instance contains a param with a given (string) name. """ - param = self._resolveParam(paramName) - return param in self.params + if isinstance(paramName, str): + p = getattr(self, paramName, None) + return isinstance(p, Param) + else: + raise TypeError("hasParam(): paramName must be a string") @since("1.4.0") def getOrDefault(self, param): diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index e93a4e157b..5fcfa9e61f 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -209,6 +209,11 @@ class ParamTests(PySparkTestCase): self.assertEqual(maxIter.doc, "max number of iterations (>= 0).") self.assertTrue(maxIter.parent == testParams.uid) + def test_hasparam(self): + testParams = TestParams() + self.assertTrue(all([testParams.hasParam(p.name) for p in testParams.params])) + self.assertFalse(testParams.hasParam("notAParameter")) + def test_params(self): testParams = TestParams() maxIter = testParams.maxIter @@ -218,7 +223,7 @@ class ParamTests(PySparkTestCase): params = testParams.params self.assertEqual(params, [inputCol, maxIter, seed]) - self.assertTrue(testParams.hasParam(maxIter)) + self.assertTrue(testParams.hasParam(maxIter.name)) self.assertTrue(testParams.hasDefault(maxIter)) self.assertFalse(testParams.isSet(maxIter)) self.assertTrue(testParams.isDefined(maxIter)) @@ -227,7 +232,7 @@ class ParamTests(PySparkTestCase): self.assertTrue(testParams.isSet(maxIter)) self.assertEqual(testParams.getMaxIter(), 100) - self.assertTrue(testParams.hasParam(inputCol)) + self.assertTrue(testParams.hasParam(inputCol.name)) self.assertFalse(testParams.hasDefault(inputCol)) self.assertFalse(testParams.isSet(inputCol)) self.assertFalse(testParams.isDefined(inputCol)) |