Diffstat (limited to 'python/pyspark/ml/tests.py')
 python/pyspark/ml/tests.py | 48 ++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 48 insertions(+), 0 deletions(-)
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index d5dd6d43c2..78ec96af8a 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -41,6 +41,7 @@ else:
 from shutil import rmtree
 import tempfile
 import numpy as np
+import inspect
 
 from pyspark import keyword_only
 from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer
@@ -54,6 +55,7 @@ from pyspark.ml.recommendation import ALS
 from pyspark.ml.regression import LinearRegression, DecisionTreeRegressor
 from pyspark.ml.tuning import *
 from pyspark.ml.wrapper import JavaParams
+from pyspark.mllib.common import _java2py
 from pyspark.mllib.linalg import Vectors, DenseVector, SparseVector
 from pyspark.sql import DataFrame, SQLContext, Row
 from pyspark.sql.functions import rand
@@ -1026,6 +1028,52 @@ class ALSTest(PySparkTestCase):
         self.assertEqual(als._java_obj.getFinalStorageLevel(), "DISK_ONLY")
 
 
+class DefaultValuesTests(PySparkTestCase):
+    """
+    Test :py:class:`JavaParams` classes to see if their default Param values match
+    those in their Scala counterparts.
+    """
+
+    def check_params(self, py_stage):
+        if not hasattr(py_stage, "_to_java"):
+            return
+        java_stage = py_stage._to_java()
+        if java_stage is None:
+            return
+        for p in py_stage.params:
+            java_param = java_stage.getParam(p.name)
+            py_has_default = py_stage.hasDefault(p)
+            java_has_default = java_stage.hasDefault(java_param)
+            self.assertEqual(py_has_default, java_has_default,
+                             "Default value mismatch of param %s for Params %s"
+                             % (p.name, str(py_stage)))
+            if py_has_default:
+                if p.name == "seed":
+                    continue  # Random seeds between Spark and PySpark are different
+                java_default = _java2py(
+                    self.sc, java_stage.clear(java_param).getOrDefault(java_param))
+                py_stage._clear(p)
+                py_default = py_stage.getOrDefault(p)
+                self.assertEqual(java_default, py_default,
+                                 "Java default %s != python default %s of param %s for Params %s"
+                                 % (str(java_default), str(py_default), p.name, str(py_stage)))
+
+    def test_java_params(self):
+        import pyspark.ml.feature
+        import pyspark.ml.classification
+        import pyspark.ml.clustering
+        import pyspark.ml.pipeline
+        import pyspark.ml.recommendation
+        import pyspark.ml.regression
+        modules = [pyspark.ml.feature, pyspark.ml.classification, pyspark.ml.clustering,
+                   pyspark.ml.pipeline, pyspark.ml.recommendation, pyspark.ml.regression]
+        for module in modules:
+            for name, cls in inspect.getmembers(module, inspect.isclass):
+                if (not name.endswith('Model') and issubclass(cls, JavaParams)
+                        and not inspect.isabstract(cls)):
+                    self.check_params(cls())
+
+
 if __name__ == "__main__":
     from pyspark.ml.tests import *
     if xmlrunner:
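
For reference, here is a minimal standalone sketch of the parity check that check_params automates, applied to a single stage and a single Param by hand. It assumes a live SparkContext bound to `sc`; Binarizer and its `threshold` Param are illustrative choices, not part of the patch.

# Sketch of the check_params logic for one stage/Param pair; assumes a live
# SparkContext `sc`. Binarizer/threshold are illustrative, not from the patch.
from pyspark.ml.feature import Binarizer
from pyspark.mllib.common import _java2py

py_stage = Binarizer()
java_stage = py_stage._to_java()          # builds the Scala counterpart via py4j
p = py_stage.threshold                    # the Python-side Param object
java_param = java_stage.getParam(p.name)  # the matching Param on the Scala side
assert py_stage.hasDefault(p) == java_stage.hasDefault(java_param)
# Clear any explicitly set value so getOrDefault falls back to the default,
# then convert the JVM value to a Python value before comparing.
java_default = _java2py(sc, java_stage.clear(java_param).getOrDefault(java_param))
py_stage._clear(p)
assert java_default == py_stage.getOrDefault(p)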