aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xpython/pyspark/ml/tests.py69
1 files changed, 46 insertions, 23 deletions
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index a7c93ac802..4358175a57 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -747,12 +747,32 @@ class PersistenceTest(SparkSessionTestCase):
except OSError:
pass
+ def _compare_params(self, m1, m2, param):
+ """
+ Compare 2 ML Params instances for the given param, and assert both have the same param value
+ and parent. The param must be a parameter of m1.
+ """
+ # Prevent a key-not-found error in case some param is in neither the paramMap nor the defaultParamMap.
+ if m1.isDefined(param):
+ paramValue1 = m1.getOrDefault(param)
+ paramValue2 = m2.getOrDefault(m2.getParam(param.name))
+ if isinstance(paramValue1, Params):
+ self._compare_pipelines(paramValue1, paramValue2)
+ else:
+ self.assertEqual(paramValue1, paramValue2)  # for params of general types
+ # Assert parents are equal
+ self.assertEqual(param.parent, m2.getParam(param.name).parent)
+ else:
+ # If the param is not defined in m1, then it should not be defined in m2 either. See SPARK-14931.
+ self.assertFalse(m2.isDefined(m2.getParam(param.name)))
+
def _compare_pipelines(self, m1, m2):
"""
Compare 2 ML types, asserting that they are equivalent.
This currently supports:
- basic types
- Pipeline, PipelineModel
+ - OneVsRest, OneVsRestModel
This checks:
- uid
- type
@@ -763,8 +783,7 @@ class PersistenceTest(SparkSessionTestCase):
if isinstance(m1, JavaParams):
self.assertEqual(len(m1.params), len(m2.params))
for p in m1.params:
- self.assertEqual(m1.getOrDefault(p), m2.getOrDefault(p))
- self.assertEqual(p.parent, m2.getParam(p.name).parent)
+ self._compare_params(m1, m2, p)
elif isinstance(m1, Pipeline):
self.assertEqual(len(m1.getStages()), len(m2.getStages()))
for s1, s2 in zip(m1.getStages(), m2.getStages()):
@@ -773,6 +792,13 @@ class PersistenceTest(SparkSessionTestCase):
self.assertEqual(len(m1.stages), len(m2.stages))
for s1, s2 in zip(m1.stages, m2.stages):
self._compare_pipelines(s1, s2)
+ elif isinstance(m1, OneVsRest) or isinstance(m1, OneVsRestModel):
+ for p in m1.params:
+ self._compare_params(m1, m2, p)
+ if isinstance(m1, OneVsRestModel):
+ self.assertEqual(len(m1.models), len(m2.models))
+ for x, y in zip(m1.models, m2.models):
+ self._compare_pipelines(x, y)
else:
raise RuntimeError("_compare_pipelines does not yet support type: %s" % type(m1))
@@ -833,6 +859,24 @@ class PersistenceTest(SparkSessionTestCase):
except OSError:
pass
+ def test_onevsrest(self):
+ temp_path = tempfile.mkdtemp()
+ df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)),
+ (1.0, Vectors.sparse(2, [], [])),
+ (2.0, Vectors.dense(0.5, 0.5))] * 10,
+ ["label", "features"])
+ lr = LogisticRegression(maxIter=5, regParam=0.01)
+ ovr = OneVsRest(classifier=lr)
+ model = ovr.fit(df)
+ ovrPath = temp_path + "/ovr"
+ ovr.save(ovrPath)
+ loadedOvr = OneVsRest.load(ovrPath)
+ self._compare_pipelines(ovr, loadedOvr)
+ modelPath = temp_path + "/ovrModel"
+ model.save(modelPath)
+ loadedModel = OneVsRestModel.load(modelPath)
+ self._compare_pipelines(model, loadedModel)
+
def test_decisiontree_classifier(self):
dt = DecisionTreeClassifier(maxDepth=1)
path = tempfile.mkdtemp()
@@ -1054,27 +1098,6 @@ class OneVsRestTests(SparkSessionTestCase):
output = model.transform(df)
self.assertEqual(output.columns, ["label", "features", "prediction"])
- def test_save_load(self):
- temp_path = tempfile.mkdtemp()
- df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)),
- (1.0, Vectors.sparse(2, [], [])),
- (2.0, Vectors.dense(0.5, 0.5))],
- ["label", "features"])
- lr = LogisticRegression(maxIter=5, regParam=0.01)
- ovr = OneVsRest(classifier=lr)
- model = ovr.fit(df)
- ovrPath = temp_path + "/ovr"
- ovr.save(ovrPath)
- loadedOvr = OneVsRest.load(ovrPath)
- self.assertEqual(loadedOvr.getFeaturesCol(), ovr.getFeaturesCol())
- self.assertEqual(loadedOvr.getLabelCol(), ovr.getLabelCol())
- self.assertEqual(loadedOvr.getClassifier().uid, ovr.getClassifier().uid)
- modelPath = temp_path + "/ovrModel"
- model.save(modelPath)
- loadedModel = OneVsRestModel.load(modelPath)
- for m, n in zip(model.models, loadedModel.models):
- self.assertEqual(m.uid, n.uid)
-
class HashingTFTest(SparkSessionTestCase):