From 7e3423b9c03c9812d404134c3d204c4cfea87721 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 22 Mar 2016 12:11:23 -0700 Subject: [SPARK-13951][ML][PYTHON] Nested Pipeline persistence Adds support for saving and loading nested ML Pipelines from Python. Pipeline and PipelineModel do not extend JavaWrapper, but they are able to utilize the JavaMLWriter, JavaMLReader implementations. Also: * Separates out interfaces from Java wrapper implementations for MLWritable, MLReadable, MLWriter, MLReader. * Moves methods _stages_java2py, _stages_py2java into Pipeline, PipelineModel as _transfer_stage_from_java, _transfer_stage_to_java Added new unit test for nested Pipelines. Abstracted validity check into a helper method for the 2 unit tests. Author: Joseph K. Bradley Closes #11866 from jkbradley/nested-pipeline-io. Closes #11835 --- python/pyspark/ml/tests.py | 82 ++++++++++++++++++++++++++++++++++++---------- 1 file changed, 64 insertions(+), 18 deletions(-) (limited to 'python/pyspark/ml/tests.py') diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 9783ce7e77..211248e8b2 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -47,6 +47,7 @@ from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed from pyspark.ml.regression import LinearRegression from pyspark.ml.tuning import * from pyspark.ml.util import keyword_only +from pyspark.ml.wrapper import JavaWrapper from pyspark.mllib.linalg import DenseVector from pyspark.sql import DataFrame, SQLContext, Row from pyspark.sql.functions import rand @@ -517,7 +518,39 @@ class PersistenceTest(PySparkTestCase): except OSError: pass + def _compare_pipelines(self, m1, m2): + """ + Compare 2 ML types, asserting that they are equivalent. + This currently supports: + - basic types + - Pipeline, PipelineModel + This checks: + - uid + - type + - Param values and parents + """ + self.assertEqual(m1.uid, m2.uid) + self.assertEqual(type(m1), type(m2)) + if isinstance(m1, JavaWrapper): + self.assertEqual(len(m1.params), len(m2.params)) + for p in m1.params: + self.assertEqual(m1.getOrDefault(p), m2.getOrDefault(p)) + self.assertEqual(p.parent, m2.getParam(p.name).parent) + elif isinstance(m1, Pipeline): + self.assertEqual(len(m1.getStages()), len(m2.getStages())) + for s1, s2 in zip(m1.getStages(), m2.getStages()): + self._compare_pipelines(s1, s2) + elif isinstance(m1, PipelineModel): + self.assertEqual(len(m1.stages), len(m2.stages)) + for s1, s2 in zip(m1.stages, m2.stages): + self._compare_pipelines(s1, s2) + else: + raise RuntimeError("_compare_pipelines does not yet support type: %s" % type(m1)) + def test_pipeline_persistence(self): + """ + Pipeline[HashingTF, PCA] + """ sqlContext = SQLContext(self.sc) temp_path = tempfile.mkdtemp() @@ -527,33 +560,46 @@ class PersistenceTest(PySparkTestCase): pca = PCA(k=2, inputCol="features", outputCol="pca_features") pl = Pipeline(stages=[tf, pca]) model = pl.fit(df) + pipeline_path = temp_path + "/pipeline" pl.save(pipeline_path) loaded_pipeline = Pipeline.load(pipeline_path) - self.assertEqual(loaded_pipeline.uid, pl.uid) - self.assertEqual(len(loaded_pipeline.getStages()), 2) + self._compare_pipelines(pl, loaded_pipeline) + + model_path = temp_path + "/pipeline-model" + model.save(model_path) + loaded_model = PipelineModel.load(model_path) + self._compare_pipelines(model, loaded_model) + finally: + try: + rmtree(temp_path) + except OSError: + pass - [loaded_tf, loaded_pca] = loaded_pipeline.getStages() - self.assertIsInstance(loaded_tf, HashingTF) - self.assertEqual(loaded_tf.uid, tf.uid) - param = loaded_tf.getParam("numFeatures") - self.assertEqual(loaded_tf.getOrDefault(param), tf.getOrDefault(param)) + def test_nested_pipeline_persistence(self): + """ + Pipeline[HashingTF, Pipeline[PCA]] + """ + sqlContext = SQLContext(self.sc) + temp_path = tempfile.mkdtemp() + + try: + df = sqlContext.createDataFrame([(["a", "b", "c"],), (["c", "d", "e"],)], ["words"]) + tf = HashingTF(numFeatures=10, inputCol="words", outputCol="features") + pca = PCA(k=2, inputCol="features", outputCol="pca_features") + p0 = Pipeline(stages=[pca]) + pl = Pipeline(stages=[tf, p0]) + model = pl.fit(df) - self.assertIsInstance(loaded_pca, PCA) - self.assertEqual(loaded_pca.uid, pca.uid) - self.assertEqual(loaded_pca.getK(), pca.getK()) + pipeline_path = temp_path + "/pipeline" + pl.save(pipeline_path) + loaded_pipeline = Pipeline.load(pipeline_path) + self._compare_pipelines(pl, loaded_pipeline) model_path = temp_path + "/pipeline-model" model.save(model_path) loaded_model = PipelineModel.load(model_path) - [model_tf, model_pca] = model.stages - [loaded_model_tf, loaded_model_pca] = loaded_model.stages - self.assertEqual(model_tf.uid, loaded_model_tf.uid) - self.assertEqual(model_tf.getOrDefault(param), loaded_model_tf.getOrDefault(param)) - - self.assertEqual(model_pca.uid, loaded_model_pca.uid) - self.assertEqual(model_pca.pc, loaded_model_pca.pc) - self.assertEqual(model_pca.explainedVariance, loaded_model_pca.explainedVariance) + self._compare_pipelines(model, loaded_model) finally: try: rmtree(temp_path) -- cgit v1.2.3