author    Xusen Yin <yinxusen@gmail.com>    2016-03-16 13:49:40 -0700
committer Joseph K. Bradley <joseph@databricks.com>    2016-03-16 13:49:40 -0700
commit    ae6c677c8a03174787be99af6238a5e1fbe4e389 (patch)
tree      75943410b6cfbe50c66ff199ab6164d24edeef84 /python/pyspark/ml/tests.py
parent    c4bd57602c0b14188d364bb475631bf473d25082 (diff)
[SPARK-13038][PYSPARK] Add load/save to pipeline
## What changes were proposed in this pull request?

JIRA issue: https://issues.apache.org/jira/browse/SPARK-13038

1. Add load/save to PySpark Pipeline and PipelineModel.
2. Add `_transfer_stage_to_java()` and `_transfer_stage_from_java()` for `JavaWrapper`.

## How was this patch tested?

Tested with doctests.

Author: Xusen Yin <yinxusen@gmail.com>

Closes #11683 from yinxusen/SPARK-13038-only.
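As context for the diff below, here is a minimal sketch of the save/load round trip this patch enables. The SparkContext setup, example data, and temporary paths are illustrative only and are not taken from the patch itself:

```python
import tempfile

from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.ml import Pipeline, PipelineModel
from pyspark.ml.feature import HashingTF, PCA

# Illustrative local setup; not part of the patch.
sc = SparkContext("local[2]", "pipeline-persistence-sketch")
sqlContext = SQLContext(sc)
df = sqlContext.createDataFrame([(["a", "b", "c"],), (["c", "d", "e"],)], ["words"])

# Build a two-stage pipeline: term hashing followed by PCA.
tf = HashingTF(numFeatures=10, inputCol="words", outputCol="features")
pca = PCA(k=2, inputCol="features", outputCol="pca_features")
pipeline = Pipeline(stages=[tf, pca])

temp_path = tempfile.mkdtemp()

# Persist and reload the unfitted Pipeline (an Estimator).
pipeline.save(temp_path + "/pipeline")
loaded_pipeline = Pipeline.load(temp_path + "/pipeline")

# Persist and reload the fitted PipelineModel (a Transformer).
model = pipeline.fit(df)
model.save(temp_path + "/pipeline-model")
loaded_model = PipelineModel.load(temp_path + "/pipeline-model")
```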
Diffstat (limited to 'python/pyspark/ml/tests.py')
-rw-r--r--  python/pyspark/ml/tests.py | 45
1 file changed, 44 insertions(+), 1 deletion(-)
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 4da9a373e9..c76f893e43 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -37,7 +37,7 @@ else:
 from shutil import rmtree
 import tempfile
 
-from pyspark.ml import Estimator, Model, Pipeline, Transformer
+from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer
 from pyspark.ml.classification import LogisticRegression
 from pyspark.ml.clustering import KMeans
 from pyspark.ml.evaluation import RegressionEvaluator
@@ -499,6 +499,49 @@ class PersistenceTest(PySparkTestCase):
         except OSError:
             pass
 
+    def test_pipeline_persistence(self):
+        sqlContext = SQLContext(self.sc)
+        temp_path = tempfile.mkdtemp()
+
+        try:
+            df = sqlContext.createDataFrame([(["a", "b", "c"],), (["c", "d", "e"],)], ["words"])
+            tf = HashingTF(numFeatures=10, inputCol="words", outputCol="features")
+            pca = PCA(k=2, inputCol="features", outputCol="pca_features")
+            pl = Pipeline(stages=[tf, pca])
+            model = pl.fit(df)
+            pipeline_path = temp_path + "/pipeline"
+            pl.save(pipeline_path)
+            loaded_pipeline = Pipeline.load(pipeline_path)
+            self.assertEqual(loaded_pipeline.uid, pl.uid)
+            self.assertEqual(len(loaded_pipeline.getStages()), 2)
+
+            [loaded_tf, loaded_pca] = loaded_pipeline.getStages()
+            self.assertIsInstance(loaded_tf, HashingTF)
+            self.assertEqual(loaded_tf.uid, tf.uid)
+            param = loaded_tf.getParam("numFeatures")
+            self.assertEqual(loaded_tf.getOrDefault(param), tf.getOrDefault(param))
+
+            self.assertIsInstance(loaded_pca, PCA)
+            self.assertEqual(loaded_pca.uid, pca.uid)
+            self.assertEqual(loaded_pca.getK(), pca.getK())
+
+            model_path = temp_path + "/pipeline-model"
+            model.save(model_path)
+            loaded_model = PipelineModel.load(model_path)
+            [model_tf, model_pca] = model.stages
+            [loaded_model_tf, loaded_model_pca] = loaded_model.stages
+            self.assertEqual(model_tf.uid, loaded_model_tf.uid)
+            self.assertEqual(model_tf.getOrDefault(param), loaded_model_tf.getOrDefault(param))
+
+            self.assertEqual(model_pca.uid, loaded_model_pca.uid)
+            self.assertEqual(model_pca.pc, loaded_model_pca.pc)
+            self.assertEqual(model_pca.explainedVariance, loaded_model_pca.explainedVariance)
+        finally:
+            try:
+                rmtree(temp_path)
+            except OSError:
+                pass
+
 
 class HasThrowableProperty(Params):
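The test above compares the parameters and fitted state of the reloaded stages; the practical payoff is that a reloaded PipelineModel keeps transforming data the way the original did. A short, hedged continuation of the sketch from the commit-message section (variable names are carried over from that sketch, not from the patch):

```python
# Continuing the earlier sketch: the reloaded PipelineModel should transform
# data exactly as the original did, since it carries the same fitted PCA
# components and HashingTF parameters.
original = model.transform(df).select("pca_features").collect()
reloaded = loaded_model.transform(df).select("pca_features").collect()
assert original == reloaded

sc.stop()
```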