aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/tests.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/ml/tests.py')
-rw-r--r--python/pyspark/ml/tests.py82
1 files changed, 64 insertions, 18 deletions
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 9783ce7e77..211248e8b2 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -47,6 +47,7 @@ from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed
from pyspark.ml.regression import LinearRegression
from pyspark.ml.tuning import *
from pyspark.ml.util import keyword_only
+from pyspark.ml.wrapper import JavaWrapper
from pyspark.mllib.linalg import DenseVector
from pyspark.sql import DataFrame, SQLContext, Row
from pyspark.sql.functions import rand
@@ -517,7 +518,39 @@ class PersistenceTest(PySparkTestCase):
except OSError:
pass
+ def _compare_pipelines(self, m1, m2):
+ """
+ Compare 2 ML types, asserting that they are equivalent.
+ This currently supports:
+ - basic types
+ - Pipeline, PipelineModel
+ This checks:
+ - uid
+ - type
+ - Param values and parents
+ """
+ self.assertEqual(m1.uid, m2.uid)
+ self.assertEqual(type(m1), type(m2))
+ if isinstance(m1, JavaWrapper):
+ self.assertEqual(len(m1.params), len(m2.params))
+ for p in m1.params:
+ self.assertEqual(m1.getOrDefault(p), m2.getOrDefault(p))
+ self.assertEqual(p.parent, m2.getParam(p.name).parent)
+ elif isinstance(m1, Pipeline):
+ self.assertEqual(len(m1.getStages()), len(m2.getStages()))
+ for s1, s2 in zip(m1.getStages(), m2.getStages()):
+ self._compare_pipelines(s1, s2)
+ elif isinstance(m1, PipelineModel):
+ self.assertEqual(len(m1.stages), len(m2.stages))
+ for s1, s2 in zip(m1.stages, m2.stages):
+ self._compare_pipelines(s1, s2)
+ else:
+ raise RuntimeError("_compare_pipelines does not yet support type: %s" % type(m1))
+
def test_pipeline_persistence(self):
+ """
+ Pipeline[HashingTF, PCA]
+ """
sqlContext = SQLContext(self.sc)
temp_path = tempfile.mkdtemp()
@@ -527,33 +560,46 @@ class PersistenceTest(PySparkTestCase):
pca = PCA(k=2, inputCol="features", outputCol="pca_features")
pl = Pipeline(stages=[tf, pca])
model = pl.fit(df)
+
pipeline_path = temp_path + "/pipeline"
pl.save(pipeline_path)
loaded_pipeline = Pipeline.load(pipeline_path)
- self.assertEqual(loaded_pipeline.uid, pl.uid)
- self.assertEqual(len(loaded_pipeline.getStages()), 2)
+ self._compare_pipelines(pl, loaded_pipeline)
+
+ model_path = temp_path + "/pipeline-model"
+ model.save(model_path)
+ loaded_model = PipelineModel.load(model_path)
+ self._compare_pipelines(model, loaded_model)
+ finally:
+ try:
+ rmtree(temp_path)
+ except OSError:
+ pass
- [loaded_tf, loaded_pca] = loaded_pipeline.getStages()
- self.assertIsInstance(loaded_tf, HashingTF)
- self.assertEqual(loaded_tf.uid, tf.uid)
- param = loaded_tf.getParam("numFeatures")
- self.assertEqual(loaded_tf.getOrDefault(param), tf.getOrDefault(param))
+ def test_nested_pipeline_persistence(self):
+ """
+ Pipeline[HashingTF, Pipeline[PCA]]
+ """
+ sqlContext = SQLContext(self.sc)
+ temp_path = tempfile.mkdtemp()
+
+ try:
+ df = sqlContext.createDataFrame([(["a", "b", "c"],), (["c", "d", "e"],)], ["words"])
+ tf = HashingTF(numFeatures=10, inputCol="words", outputCol="features")
+ pca = PCA(k=2, inputCol="features", outputCol="pca_features")
+ p0 = Pipeline(stages=[pca])
+ pl = Pipeline(stages=[tf, p0])
+ model = pl.fit(df)
- self.assertIsInstance(loaded_pca, PCA)
- self.assertEqual(loaded_pca.uid, pca.uid)
- self.assertEqual(loaded_pca.getK(), pca.getK())
+ pipeline_path = temp_path + "/pipeline"
+ pl.save(pipeline_path)
+ loaded_pipeline = Pipeline.load(pipeline_path)
+ self._compare_pipelines(pl, loaded_pipeline)
model_path = temp_path + "/pipeline-model"
model.save(model_path)
loaded_model = PipelineModel.load(model_path)
- [model_tf, model_pca] = model.stages
- [loaded_model_tf, loaded_model_pca] = loaded_model.stages
- self.assertEqual(model_tf.uid, loaded_model_tf.uid)
- self.assertEqual(model_tf.getOrDefault(param), loaded_model_tf.getOrDefault(param))
-
- self.assertEqual(model_pca.uid, loaded_model_pca.uid)
- self.assertEqual(model_pca.pc, loaded_model_pca.pc)
- self.assertEqual(model_pca.explainedVariance, loaded_model_pca.explainedVariance)
+ self._compare_pipelines(model, loaded_model)
finally:
try:
rmtree(temp_path)