diff options
author | Bryan Cutler <cutlerb@gmail.com> | 2016-08-19 23:46:36 -0700 |
---|---|---|
committer | Yanbo Liang <ybliang8@gmail.com> | 2016-08-19 23:46:36 -0700 |
commit | 39f328ba3519b01940a7d1cdee851ba4e75ef31f (patch) | |
tree | 467a209b875a76164d11c28c86f84b547fa3215e /python | |
parent | 45d40d9f66c666eec6df926db23937589d67225d (diff) | |
download | spark-39f328ba3519b01940a7d1cdee851ba4e75ef31f.tar.gz spark-39f328ba3519b01940a7d1cdee851ba4e75ef31f.tar.bz2 spark-39f328ba3519b01940a7d1cdee851ba4e75ef31f.zip |
[SPARK-15018][PYSPARK][ML] Improve handling of PySpark Pipeline when used without stages
## What changes were proposed in this pull request?
When fitting a PySpark Pipeline without the `stages` param set, a confusing NoneType error is raised when attempting to iterate over the pipeline stages. A pipeline with no stages should act as an identity transform; however, the `stages` param still needs to be set to an empty list. This change improves the error output when the `stages` param is not set and adds a better description of what the API expects as input. Also includes minor cleanup of related code.
## How was this patch tested?
Added new unit tests to verify an empty Pipeline acts as an identity transformer
Author: Bryan Cutler <cutlerb@gmail.com>
Closes #12790 from BryanCutler/pipeline-identity-SPARK-15018.
Diffstat (limited to 'python')
-rw-r--r-- | python/pyspark/ml/pipeline.py | 11 | ||||
-rwxr-xr-x | python/pyspark/ml/tests.py | 11 |
2 files changed, 14 insertions, 8 deletions
diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index a48f4bb2ad..4307ad02a0 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -44,21 +44,19 @@ class Pipeline(Estimator, MLReadable, MLWritable): the dataset for the next stage. The fitted model from a :py:class:`Pipeline` is a :py:class:`PipelineModel`, which consists of fitted models and transformers, corresponding to the - pipeline stages. If there are no stages, the pipeline acts as an + pipeline stages. If stages is an empty list, the pipeline acts as an identity transformer. .. versionadded:: 1.3.0 """ - stages = Param(Params._dummy(), "stages", "pipeline stages") + stages = Param(Params._dummy(), "stages", "a list of pipeline stages") @keyword_only def __init__(self, stages=None): """ __init__(self, stages=None) """ - if stages is None: - stages = [] super(Pipeline, self).__init__() kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -78,8 +76,7 @@ class Pipeline(Estimator, MLReadable, MLWritable): """ Get pipeline stages. """ - if self.stages in self._paramMap: - return self._paramMap[self.stages] + return self.getOrDefault(self.stages) @keyword_only @since("1.3.0") @@ -88,8 +85,6 @@ class Pipeline(Estimator, MLReadable, MLWritable): setParams(self, stages=None) Sets params for Pipeline. 
""" - if stages is None: - stages = [] kwargs = self.setParams._input_kwargs return self._set(**kwargs) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 4bcb2c400c..6886ed321e 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -230,6 +230,17 @@ class PipelineTests(PySparkTestCase): self.assertEqual(5, transformer3.dataset_index) self.assertEqual(6, dataset.index) + def test_identity_pipeline(self): + dataset = MockDataset() + + def doTransform(pipeline): + pipeline_model = pipeline.fit(dataset) + return pipeline_model.transform(dataset) + # check that empty pipeline did not perform any transformation + self.assertEqual(dataset.index, doTransform(Pipeline(stages=[])).index) + # check that failure to set stages param will raise KeyError for missing param + self.assertRaises(KeyError, lambda: doTransform(Pipeline())) + class TestParams(HasMaxIter, HasInputCol, HasSeed): """ |