diff options
Diffstat (limited to 'python/pyspark/ml/tests.py')
-rw-r--r-- | python/pyspark/ml/tests.py | 48 |
1 files changed, 48 insertions, 0 deletions
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index d5dd6d43c2..78ec96af8a 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -41,6 +41,7 @@ else: from shutil import rmtree import tempfile import numpy as np +import inspect from pyspark import keyword_only from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer @@ -54,6 +55,7 @@ from pyspark.ml.recommendation import ALS from pyspark.ml.regression import LinearRegression, DecisionTreeRegressor from pyspark.ml.tuning import * from pyspark.ml.wrapper import JavaParams +from pyspark.mllib.common import _java2py from pyspark.mllib.linalg import Vectors, DenseVector, SparseVector from pyspark.sql import DataFrame, SQLContext, Row from pyspark.sql.functions import rand @@ -1026,6 +1028,52 @@ class ALSTest(PySparkTestCase): self.assertEqual(als._java_obj.getFinalStorageLevel(), "DISK_ONLY") +class DefaultValuesTests(PySparkTestCase): + """ + Test :py:class:`JavaParams` classes to see if their default Param values match + those in their Scala counterparts. + """ + + def check_params(self, py_stage): + if not hasattr(py_stage, "_to_java"): + return + java_stage = py_stage._to_java() + if java_stage is None: + return + for p in py_stage.params: + java_param = java_stage.getParam(p.name) + py_has_default = py_stage.hasDefault(p) + java_has_default = java_stage.hasDefault(java_param) + self.assertEqual(py_has_default, java_has_default, + "Default value mismatch of param %s for Params %s" + % (p.name, str(py_stage))) + if py_has_default: + if p.name == "seed": + return # Random seeds between Spark and PySpark are different + java_default =\ + _java2py(self.sc, java_stage.clear(java_param).getOrDefault(java_param)) + py_stage._clear(p) + py_default = py_stage.getOrDefault(p) + self.assertEqual(java_default, py_default, + "Java default %s != python default %s of param %s for Params %s" + % (str(java_default), str(py_default), p.name, str(py_stage))) + + def test_java_params(self): + import pyspark.ml.feature + import pyspark.ml.classification + import pyspark.ml.clustering + import pyspark.ml.pipeline + import pyspark.ml.recommendation + import pyspark.ml.regression + modules = [pyspark.ml.feature, pyspark.ml.classification, pyspark.ml.clustering, + pyspark.ml.pipeline, pyspark.ml.recommendation, pyspark.ml.regression] + for module in modules: + for name, cls in inspect.getmembers(module, inspect.isclass): + if not name.endswith('Model') and issubclass(cls, JavaParams)\ + and not inspect.isabstract(cls): + self.check_params(cls()) + + if __name__ == "__main__": from pyspark.ml.tests import * if xmlrunner: |