author     Joseph K. Bradley <joseph@databricks.com>    2016-03-22 12:11:23 -0700
committer  Xiangrui Meng <meng@databricks.com>          2016-03-22 12:11:37 -0700
commit     7e3423b9c03c9812d404134c3d204c4cfea87721 (patch)
tree       b922610e318774c1db7da6549ee0932b21fe3090 /python/pyspark/ml/tests.py
parent     297c20226d3330309c9165d789749458f8f4ab8e (diff)
[SPARK-13951][ML][PYTHON] Nested Pipeline persistence
Adds support for saving and loading nested ML Pipelines from Python. Pipeline and PipelineModel do not extend JavaWrapper, but they are able to utilize the JavaMLWriter and JavaMLReader implementations.

Also:
* Separates out interfaces from Java wrapper implementations for MLWritable, MLReadable, MLWriter, MLReader.
* Moves methods _stages_java2py, _stages_py2java into Pipeline, PipelineModel as _transfer_stage_from_java, _transfer_stage_to_java.

Added a new unit test for nested Pipelines, and abstracted the validity check into a helper method shared by the 2 unit tests.

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #11866 from jkbradley/nested-pipeline-io.

Closes #11835
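For orientation, a minimal sketch of the round trip this change enables (the save paths and the SparkContext variable "sc" are illustrative assumptions; the data and stages mirror the new unit test in the diff below):

    from pyspark.ml import Pipeline, PipelineModel
    from pyspark.ml.feature import HashingTF, PCA
    from pyspark.sql import SQLContext

    sqlContext = SQLContext(sc)  # assumes an existing SparkContext named "sc"
    df = sqlContext.createDataFrame([(["a", "b", "c"],), (["c", "d", "e"],)], ["words"])
    tf = HashingTF(numFeatures=10, inputCol="words", outputCol="features")
    pca = PCA(k=2, inputCol="features", outputCol="pca_features")
    # A Pipeline can now appear as a stage of another Pipeline and still persist.
    nested = Pipeline(stages=[tf, Pipeline(stages=[pca])])

    nested.save("/tmp/nested-pipeline")               # illustrative path
    reloaded = Pipeline.load("/tmp/nested-pipeline")

    model = nested.fit(df)
    model.save("/tmp/nested-pipeline-model")          # illustrative path
    reloaded_model = PipelineModel.load("/tmp/nested-pipeline-model")

Although Pipeline and PipelineModel do not extend JavaWrapper, save and load here are backed by the JavaMLWriter and JavaMLReader implementations noted above.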
Diffstat (limited to 'python/pyspark/ml/tests.py')
-rw-r--r--    python/pyspark/ml/tests.py    82
1 file changed, 64 insertions(+), 18 deletions(-)
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 9783ce7e77..211248e8b2 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -47,6 +47,7 @@ from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed
 from pyspark.ml.regression import LinearRegression
 from pyspark.ml.tuning import *
 from pyspark.ml.util import keyword_only
+from pyspark.ml.wrapper import JavaWrapper
 from pyspark.mllib.linalg import DenseVector
 from pyspark.sql import DataFrame, SQLContext, Row
 from pyspark.sql.functions import rand
@@ -517,7 +518,39 @@ class PersistenceTest(PySparkTestCase):
             except OSError:
                 pass
 
+    def _compare_pipelines(self, m1, m2):
+        """
+        Compare 2 ML types, asserting that they are equivalent.
+        This currently supports:
+         - basic types
+         - Pipeline, PipelineModel
+        This checks:
+         - uid
+         - type
+         - Param values and parents
+        """
+        self.assertEqual(m1.uid, m2.uid)
+        self.assertEqual(type(m1), type(m2))
+        if isinstance(m1, JavaWrapper):
+            self.assertEqual(len(m1.params), len(m2.params))
+            for p in m1.params:
+                self.assertEqual(m1.getOrDefault(p), m2.getOrDefault(p))
+                self.assertEqual(p.parent, m2.getParam(p.name).parent)
+        elif isinstance(m1, Pipeline):
+            self.assertEqual(len(m1.getStages()), len(m2.getStages()))
+            for s1, s2 in zip(m1.getStages(), m2.getStages()):
+                self._compare_pipelines(s1, s2)
+        elif isinstance(m1, PipelineModel):
+            self.assertEqual(len(m1.stages), len(m2.stages))
+            for s1, s2 in zip(m1.stages, m2.stages):
+                self._compare_pipelines(s1, s2)
+        else:
+            raise RuntimeError("_compare_pipelines does not yet support type: %s" % type(m1))
+
     def test_pipeline_persistence(self):
+        """
+        Pipeline[HashingTF, PCA]
+        """
         sqlContext = SQLContext(self.sc)
         temp_path = tempfile.mkdtemp()
@@ -527,33 +560,46 @@ class PersistenceTest(PySparkTestCase):
             pca = PCA(k=2, inputCol="features", outputCol="pca_features")
             pl = Pipeline(stages=[tf, pca])
             model = pl.fit(df)
+
             pipeline_path = temp_path + "/pipeline"
             pl.save(pipeline_path)
             loaded_pipeline = Pipeline.load(pipeline_path)
-            self.assertEqual(loaded_pipeline.uid, pl.uid)
-            self.assertEqual(len(loaded_pipeline.getStages()), 2)
+            self._compare_pipelines(pl, loaded_pipeline)
+
+            model_path = temp_path + "/pipeline-model"
+            model.save(model_path)
+            loaded_model = PipelineModel.load(model_path)
+            self._compare_pipelines(model, loaded_model)
+        finally:
+            try:
+                rmtree(temp_path)
+            except OSError:
+                pass
 
-            [loaded_tf, loaded_pca] = loaded_pipeline.getStages()
-            self.assertIsInstance(loaded_tf, HashingTF)
-            self.assertEqual(loaded_tf.uid, tf.uid)
-            param = loaded_tf.getParam("numFeatures")
-            self.assertEqual(loaded_tf.getOrDefault(param), tf.getOrDefault(param))
+    def test_nested_pipeline_persistence(self):
+        """
+        Pipeline[HashingTF, Pipeline[PCA]]
+        """
+        sqlContext = SQLContext(self.sc)
+        temp_path = tempfile.mkdtemp()
+
+        try:
+            df = sqlContext.createDataFrame([(["a", "b", "c"],), (["c", "d", "e"],)], ["words"])
+            tf = HashingTF(numFeatures=10, inputCol="words", outputCol="features")
+            pca = PCA(k=2, inputCol="features", outputCol="pca_features")
+            p0 = Pipeline(stages=[pca])
+            pl = Pipeline(stages=[tf, p0])
+            model = pl.fit(df)
 
-            self.assertIsInstance(loaded_pca, PCA)
-            self.assertEqual(loaded_pca.uid, pca.uid)
-            self.assertEqual(loaded_pca.getK(), pca.getK())
+            pipeline_path = temp_path + "/pipeline"
+            pl.save(pipeline_path)
+            loaded_pipeline = Pipeline.load(pipeline_path)
+            self._compare_pipelines(pl, loaded_pipeline)
 
             model_path = temp_path + "/pipeline-model"
             model.save(model_path)
             loaded_model = PipelineModel.load(model_path)
-            [model_tf, model_pca] = model.stages
-            [loaded_model_tf, loaded_model_pca] = loaded_model.stages
-            self.assertEqual(model_tf.uid, loaded_model_tf.uid)
-            self.assertEqual(model_tf.getOrDefault(param), loaded_model_tf.getOrDefault(param))
-
-            self.assertEqual(model_pca.uid, loaded_model_pca.uid)
-            self.assertEqual(model_pca.pc, loaded_model_pca.pc)
-            self.assertEqual(model_pca.explainedVariance, loaded_model_pca.explainedVariance)
+            self._compare_pipelines(model, loaded_model)
         finally:
             try:
                 rmtree(temp_path)
             except OSError:
                 pass