diff options
author | Xusen Yin <yinxusen@gmail.com> | 2016-03-16 13:49:40 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2016-03-16 13:49:40 -0700 |
commit | ae6c677c8a03174787be99af6238a5e1fbe4e389 (patch) | |
tree | 75943410b6cfbe50c66ff199ab6164d24edeef84 /python/pyspark/ml/tests.py | |
parent | c4bd57602c0b14188d364bb475631bf473d25082 (diff) | |
download | spark-ae6c677c8a03174787be99af6238a5e1fbe4e389.tar.gz spark-ae6c677c8a03174787be99af6238a5e1fbe4e389.tar.bz2 spark-ae6c677c8a03174787be99af6238a5e1fbe4e389.zip |
[SPARK-13038][PYSPARK] Add load/save to pipeline
## What changes were proposed in this pull request?
JIRA issue: https://issues.apache.org/jira/browse/SPARK-13038
1. Add load/save to PySpark Pipeline and PipelineModel
2. Add `_transfer_stage_to_java()` and `_transfer_stage_from_java()` for `JavaWrapper`.
## How was this patch tested?
Test with doctest.
Author: Xusen Yin <yinxusen@gmail.com>
Closes #11683 from yinxusen/SPARK-13038-only.
Diffstat (limited to 'python/pyspark/ml/tests.py')
-rw-r--r-- | python/pyspark/ml/tests.py | 45 |
1 files changed, 44 insertions, 1 deletions
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 4da9a373e9..c76f893e43 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -37,7 +37,7 @@ else: from shutil import rmtree import tempfile -from pyspark.ml import Estimator, Model, Pipeline, Transformer +from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer from pyspark.ml.classification import LogisticRegression from pyspark.ml.clustering import KMeans from pyspark.ml.evaluation import RegressionEvaluator @@ -499,6 +499,49 @@ class PersistenceTest(PySparkTestCase): except OSError: pass + def test_pipeline_persistence(self): + sqlContext = SQLContext(self.sc) + temp_path = tempfile.mkdtemp() + + try: + df = sqlContext.createDataFrame([(["a", "b", "c"],), (["c", "d", "e"],)], ["words"]) + tf = HashingTF(numFeatures=10, inputCol="words", outputCol="features") + pca = PCA(k=2, inputCol="features", outputCol="pca_features") + pl = Pipeline(stages=[tf, pca]) + model = pl.fit(df) + pipeline_path = temp_path + "/pipeline" + pl.save(pipeline_path) + loaded_pipeline = Pipeline.load(pipeline_path) + self.assertEqual(loaded_pipeline.uid, pl.uid) + self.assertEqual(len(loaded_pipeline.getStages()), 2) + + [loaded_tf, loaded_pca] = loaded_pipeline.getStages() + self.assertIsInstance(loaded_tf, HashingTF) + self.assertEqual(loaded_tf.uid, tf.uid) + param = loaded_tf.getParam("numFeatures") + self.assertEqual(loaded_tf.getOrDefault(param), tf.getOrDefault(param)) + + self.assertIsInstance(loaded_pca, PCA) + self.assertEqual(loaded_pca.uid, pca.uid) + self.assertEqual(loaded_pca.getK(), pca.getK()) + + model_path = temp_path + "/pipeline-model" + model.save(model_path) + loaded_model = PipelineModel.load(model_path) + [model_tf, model_pca] = model.stages + [loaded_model_tf, loaded_model_pca] = loaded_model.stages + self.assertEqual(model_tf.uid, loaded_model_tf.uid) + self.assertEqual(model_tf.getOrDefault(param), loaded_model_tf.getOrDefault(param)) + + self.assertEqual(model_pca.uid, loaded_model_pca.uid) + self.assertEqual(model_pca.pc, loaded_model_pca.pc) + self.assertEqual(model_pca.explainedVariance, loaded_model_pca.explainedVariance) + finally: + try: + rmtree(temp_path) + except OSError: + pass + class HasThrowableProperty(Params): |