From 130b8d07b8eb08f2ad522081a95032b90247094d Mon Sep 17 00:00:00 2001 From: yinxusen Date: Fri, 27 May 2016 13:18:29 -0700 Subject: [SPARK-15008][ML][PYSPARK] Add integration test for OneVsRest ## What changes were proposed in this pull request? 1. Add `_transfer_param_map_to/from_java` for OneVsRest; 2. Add `_compare_params` in ml/tests.py to help compare params. 3. Add `test_onevsrest` as the integration test for OneVsRest. ## How was this patch tested? Python unit test. Author: yinxusen Closes #12875 from yinxusen/SPARK-15008. --- python/pyspark/ml/tests.py | 69 ++++++++++++++++++++++++++++++---------------- 1 file changed, 46 insertions(+), 23 deletions(-) (limited to 'python') diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index a7c93ac802..4358175a57 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -747,12 +747,32 @@ class PersistenceTest(SparkSessionTestCase): except OSError: pass + def _compare_params(self, m1, m2, param): + """ + Compare 2 ML Params instances for the given param, and assert both have the same param value + and parent. The param must be a parameter of m1. + """ + # Prevent key not found error in case of some param in neither paramMap nor defaultParamMap. + if m1.isDefined(param): + paramValue1 = m1.getOrDefault(param) + paramValue2 = m2.getOrDefault(m2.getParam(param.name)) + if isinstance(paramValue1, Params): + self._compare_pipelines(paramValue1, paramValue2) + else: + self.assertEqual(paramValue1, paramValue2) # for general types param + # Assert parents are equal + self.assertEqual(param.parent, m2.getParam(param.name).parent) + else: + # If m1 is not defined param, then m2 should not, too. See SPARK-14931. + self.assertFalse(m2.isDefined(m2.getParam(param.name))) + def _compare_pipelines(self, m1, m2): """ Compare 2 ML types, asserting that they are equivalent. This currently supports: - basic types - Pipeline, PipelineModel + - OneVsRest, OneVsRestModel This checks: - uid - type @@ -763,8 +783,7 @@ class PersistenceTest(SparkSessionTestCase): if isinstance(m1, JavaParams): self.assertEqual(len(m1.params), len(m2.params)) for p in m1.params: - self.assertEqual(m1.getOrDefault(p), m2.getOrDefault(p)) - self.assertEqual(p.parent, m2.getParam(p.name).parent) + self._compare_params(m1, m2, p) elif isinstance(m1, Pipeline): self.assertEqual(len(m1.getStages()), len(m2.getStages())) for s1, s2 in zip(m1.getStages(), m2.getStages()): @@ -773,6 +792,13 @@ class PersistenceTest(SparkSessionTestCase): self.assertEqual(len(m1.stages), len(m2.stages)) for s1, s2 in zip(m1.stages, m2.stages): self._compare_pipelines(s1, s2) + elif isinstance(m1, OneVsRest) or isinstance(m1, OneVsRestModel): + for p in m1.params: + self._compare_params(m1, m2, p) + if isinstance(m1, OneVsRestModel): + self.assertEqual(len(m1.models), len(m2.models)) + for x, y in zip(m1.models, m2.models): + self._compare_pipelines(x, y) else: raise RuntimeError("_compare_pipelines does not yet support type: %s" % type(m1)) @@ -833,6 +859,24 @@ class PersistenceTest(SparkSessionTestCase): except OSError: pass + def test_onevsrest(self): + temp_path = tempfile.mkdtemp() + df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)), + (1.0, Vectors.sparse(2, [], [])), + (2.0, Vectors.dense(0.5, 0.5))] * 10, + ["label", "features"]) + lr = LogisticRegression(maxIter=5, regParam=0.01) + ovr = OneVsRest(classifier=lr) + model = ovr.fit(df) + ovrPath = temp_path + "/ovr" + ovr.save(ovrPath) + loadedOvr = OneVsRest.load(ovrPath) + self._compare_pipelines(ovr, loadedOvr) + modelPath = temp_path + "/ovrModel" + model.save(modelPath) + loadedModel = OneVsRestModel.load(modelPath) + self._compare_pipelines(model, loadedModel) + def test_decisiontree_classifier(self): dt = DecisionTreeClassifier(maxDepth=1) path = tempfile.mkdtemp() @@ -1054,27 +1098,6 @@ class OneVsRestTests(SparkSessionTestCase): output = model.transform(df) self.assertEqual(output.columns, ["label", "features", "prediction"]) - def test_save_load(self): - temp_path = tempfile.mkdtemp() - df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)), - (1.0, Vectors.sparse(2, [], [])), - (2.0, Vectors.dense(0.5, 0.5))], - ["label", "features"]) - lr = LogisticRegression(maxIter=5, regParam=0.01) - ovr = OneVsRest(classifier=lr) - model = ovr.fit(df) - ovrPath = temp_path + "/ovr" - ovr.save(ovrPath) - loadedOvr = OneVsRest.load(ovrPath) - self.assertEqual(loadedOvr.getFeaturesCol(), ovr.getFeaturesCol()) - self.assertEqual(loadedOvr.getLabelCol(), ovr.getLabelCol()) - self.assertEqual(loadedOvr.getClassifier().uid, ovr.getClassifier().uid) - modelPath = temp_path + "/ovrModel" - model.save(modelPath) - loadedModel = OneVsRestModel.load(modelPath) - for m, n in zip(model.models, loadedModel.models): - self.assertEqual(m.uid, n.uid) - class HashingTFTest(SparkSessionTestCase): -- cgit v1.2.3