diff options
author | Joseph K. Bradley <joseph@databricks.com> | 2016-03-22 12:11:23 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2016-03-22 12:11:37 -0700 |
commit | 7e3423b9c03c9812d404134c3d204c4cfea87721 (patch) | |
tree | b922610e318774c1db7da6549ee0932b21fe3090 /python/pyspark/ml/tests.py | |
parent | 297c20226d3330309c9165d789749458f8f4ab8e (diff) | |
download | spark-7e3423b9c03c9812d404134c3d204c4cfea87721.tar.gz spark-7e3423b9c03c9812d404134c3d204c4cfea87721.tar.bz2 spark-7e3423b9c03c9812d404134c3d204c4cfea87721.zip |
[SPARK-13951][ML][PYTHON] Nested Pipeline persistence
Adds support for saving and loading nested ML Pipelines from Python. Pipeline and PipelineModel do not extend JavaWrapper, but they are able to utilize the JavaMLWriter, JavaMLReader implementations.
Also:
* Separates out interfaces from Java wrapper implementations for MLWritable, MLReadable, MLWriter, MLReader.
* Moves methods _stages_java2py, _stages_py2java into Pipeline, PipelineModel as _transfer_stage_from_java, _transfer_stage_to_java
Added new unit test for nested Pipelines. Abstracted validity check into a helper method for the 2 unit tests.
Author: Joseph K. Bradley <joseph@databricks.com>
Closes #11866 from jkbradley/nested-pipeline-io.
Closes #11835
Diffstat (limited to 'python/pyspark/ml/tests.py')
-rw-r--r-- | python/pyspark/ml/tests.py | 82 |
1 files changed, 64 insertions, 18 deletions
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 9783ce7e77..211248e8b2 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -47,6 +47,7 @@ from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed from pyspark.ml.regression import LinearRegression from pyspark.ml.tuning import * from pyspark.ml.util import keyword_only +from pyspark.ml.wrapper import JavaWrapper from pyspark.mllib.linalg import DenseVector from pyspark.sql import DataFrame, SQLContext, Row from pyspark.sql.functions import rand @@ -517,7 +518,39 @@ class PersistenceTest(PySparkTestCase): except OSError: pass + def _compare_pipelines(self, m1, m2): + """ + Compare 2 ML types, asserting that they are equivalent. + This currently supports: + - basic types + - Pipeline, PipelineModel + This checks: + - uid + - type + - Param values and parents + """ + self.assertEqual(m1.uid, m2.uid) + self.assertEqual(type(m1), type(m2)) + if isinstance(m1, JavaWrapper): + self.assertEqual(len(m1.params), len(m2.params)) + for p in m1.params: + self.assertEqual(m1.getOrDefault(p), m2.getOrDefault(p)) + self.assertEqual(p.parent, m2.getParam(p.name).parent) + elif isinstance(m1, Pipeline): + self.assertEqual(len(m1.getStages()), len(m2.getStages())) + for s1, s2 in zip(m1.getStages(), m2.getStages()): + self._compare_pipelines(s1, s2) + elif isinstance(m1, PipelineModel): + self.assertEqual(len(m1.stages), len(m2.stages)) + for s1, s2 in zip(m1.stages, m2.stages): + self._compare_pipelines(s1, s2) + else: + raise RuntimeError("_compare_pipelines does not yet support type: %s" % type(m1)) + def test_pipeline_persistence(self): + """ + Pipeline[HashingTF, PCA] + """ sqlContext = SQLContext(self.sc) temp_path = tempfile.mkdtemp() @@ -527,33 +560,46 @@ class PersistenceTest(PySparkTestCase): pca = PCA(k=2, inputCol="features", outputCol="pca_features") pl = Pipeline(stages=[tf, pca]) model = pl.fit(df) + pipeline_path = temp_path + "/pipeline" pl.save(pipeline_path) loaded_pipeline = Pipeline.load(pipeline_path) - self.assertEqual(loaded_pipeline.uid, pl.uid) - self.assertEqual(len(loaded_pipeline.getStages()), 2) + self._compare_pipelines(pl, loaded_pipeline) + + model_path = temp_path + "/pipeline-model" + model.save(model_path) + loaded_model = PipelineModel.load(model_path) + self._compare_pipelines(model, loaded_model) + finally: + try: + rmtree(temp_path) + except OSError: + pass - [loaded_tf, loaded_pca] = loaded_pipeline.getStages() - self.assertIsInstance(loaded_tf, HashingTF) - self.assertEqual(loaded_tf.uid, tf.uid) - param = loaded_tf.getParam("numFeatures") - self.assertEqual(loaded_tf.getOrDefault(param), tf.getOrDefault(param)) + def test_nested_pipeline_persistence(self): + """ + Pipeline[HashingTF, Pipeline[PCA]] + """ + sqlContext = SQLContext(self.sc) + temp_path = tempfile.mkdtemp() + + try: + df = sqlContext.createDataFrame([(["a", "b", "c"],), (["c", "d", "e"],)], ["words"]) + tf = HashingTF(numFeatures=10, inputCol="words", outputCol="features") + pca = PCA(k=2, inputCol="features", outputCol="pca_features") + p0 = Pipeline(stages=[pca]) + pl = Pipeline(stages=[tf, p0]) + model = pl.fit(df) - self.assertIsInstance(loaded_pca, PCA) - self.assertEqual(loaded_pca.uid, pca.uid) - self.assertEqual(loaded_pca.getK(), pca.getK()) + pipeline_path = temp_path + "/pipeline" + pl.save(pipeline_path) + loaded_pipeline = Pipeline.load(pipeline_path) + self._compare_pipelines(pl, loaded_pipeline) model_path = temp_path + "/pipeline-model" model.save(model_path) loaded_model = PipelineModel.load(model_path) - [model_tf, model_pca] = model.stages - [loaded_model_tf, loaded_model_pca] = loaded_model.stages - self.assertEqual(model_tf.uid, loaded_model_tf.uid) - self.assertEqual(model_tf.getOrDefault(param), loaded_model_tf.getOrDefault(param)) - - self.assertEqual(model_pca.uid, loaded_model_pca.uid) - self.assertEqual(model_pca.pc, loaded_model_pca.pc) - self.assertEqual(model_pca.explainedVariance, loaded_model_pca.explainedVariance) + self._compare_pipelines(model, loaded_model) finally: try: rmtree(temp_path) |