From d8813fa043e8b8f7cbf6921d4c7ec889634f7abd Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 8 Mar 2016 17:34:25 -0800 Subject: [SPARK-13625][PYSPARK][ML] Added a check to see if an attribute is a property when getting param list ## What changes were proposed in this pull request? Added a check in pyspark.ml.param.Params.params to see if an attribute is a property (decorated with `property`) before checking if it is a `Param` instance. This prevents the property from being invoked to 'get' this attribute, which could possibly cause an error. ## How was this patch tested? Added a test case with a class that has a property which raises an error when invoked, and which then accesses `Params.params` to verify that the property is not invoked while the `Param` attribute in the class is still found. Also ran the pyspark-ml tests before the fix, which triggered an error, and again after the fix to verify that the error was resolved and the method was working properly. Author: Bryan Cutler Closes #11476 from BryanCutler/pyspark-ml-property-attr-SPARK-13625. 
--- python/pyspark/ml/param/__init__.py | 3 ++- python/pyspark/ml/tests.py | 17 +++++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) (limited to 'python') diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index bbf83f0310..c0f0a71eb6 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -109,7 +109,8 @@ class Params(Identifiable): """ if self._params is None: self._params = list(filter(lambda attr: isinstance(attr, Param), - [getattr(self, x) for x in dir(self) if x != "params"])) + [getattr(self, x) for x in dir(self) if x != "params" and + not isinstance(getattr(type(self), x, None), property)])) return self._params @since("1.4.0") diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 8182fcfb4e..4da9a373e9 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -271,6 +271,12 @@ class ParamTests(PySparkTestCase): # Check that a different class has a different seed self.assertNotEqual(other.getSeed(), noSeedSpecd.getSeed()) + def test_param_property_error(self): + param_store = HasThrowableProperty() + self.assertRaises(RuntimeError, lambda: param_store.test_property) + params = param_store.params # should not invoke the property 'test_property' + self.assertEqual(len(params), 1) + class FeatureTests(PySparkTestCase): @@ -494,6 +500,17 @@ class PersistenceTest(PySparkTestCase): pass +class HasThrowableProperty(Params): + + def __init__(self): + super(HasThrowableProperty, self).__init__() + self.p = Param(self, "none", "empty param") + + @property + def test_property(self): + raise RuntimeError("Test property to raise error when invoked") + + if __name__ == "__main__": from pyspark.ml.tests import * if xmlrunner: -- cgit v1.2.3