aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBryan Cutler <cutlerb@gmail.com>2016-08-19 23:46:36 -0700
committerYanbo Liang <ybliang8@gmail.com>2016-08-19 23:46:36 -0700
commit39f328ba3519b01940a7d1cdee851ba4e75ef31f (patch)
tree467a209b875a76164d11c28c86f84b547fa3215e
parent45d40d9f66c666eec6df926db23937589d67225d (diff)
downloadspark-39f328ba3519b01940a7d1cdee851ba4e75ef31f.tar.gz
spark-39f328ba3519b01940a7d1cdee851ba4e75ef31f.tar.bz2
spark-39f328ba3519b01940a7d1cdee851ba4e75ef31f.zip
[SPARK-15018][PYSPARK][ML] Improve handling of PySpark Pipeline when used without stages
## What changes were proposed in this pull request? When fitting a PySpark Pipeline without the `stages` param set, a confusing NoneType error is raised as it attempts to iterate over the pipeline stages. A pipeline with no stages should act as an identity transform, however the `stages` param still needs to be set to an empty list. This change improves the error output when the `stages` param is not set and adds a better description of what the API expects as input. Also minor cleanup of related code. ## How was this patch tested? Added new unit tests to verify an empty Pipeline acts as an identity transformer Author: Bryan Cutler <cutlerb@gmail.com> Closes #12790 from BryanCutler/pipeline-identity-SPARK-15018.
-rw-r--r--python/pyspark/ml/pipeline.py11
-rwxr-xr-xpython/pyspark/ml/tests.py11
2 files changed, 14 insertions, 8 deletions
diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index a48f4bb2ad..4307ad02a0 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -44,21 +44,19 @@ class Pipeline(Estimator, MLReadable, MLWritable):
the dataset for the next stage. The fitted model from a
:py:class:`Pipeline` is a :py:class:`PipelineModel`, which
consists of fitted models and transformers, corresponding to the
- pipeline stages. If there are no stages, the pipeline acts as an
+ pipeline stages. If stages is an empty list, the pipeline acts as an
identity transformer.
.. versionadded:: 1.3.0
"""
- stages = Param(Params._dummy(), "stages", "pipeline stages")
+ stages = Param(Params._dummy(), "stages", "a list of pipeline stages")
@keyword_only
def __init__(self, stages=None):
"""
__init__(self, stages=None)
"""
- if stages is None:
- stages = []
super(Pipeline, self).__init__()
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)
@@ -78,8 +76,7 @@ class Pipeline(Estimator, MLReadable, MLWritable):
"""
Get pipeline stages.
"""
- if self.stages in self._paramMap:
- return self._paramMap[self.stages]
+ return self.getOrDefault(self.stages)
@keyword_only
@since("1.3.0")
@@ -88,8 +85,6 @@ class Pipeline(Estimator, MLReadable, MLWritable):
setParams(self, stages=None)
Sets params for Pipeline.
"""
- if stages is None:
- stages = []
kwargs = self.setParams._input_kwargs
return self._set(**kwargs)
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 4bcb2c400c..6886ed321e 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -230,6 +230,17 @@ class PipelineTests(PySparkTestCase):
self.assertEqual(5, transformer3.dataset_index)
self.assertEqual(6, dataset.index)
+ def test_identity_pipeline(self):
+ dataset = MockDataset()
+
+ def doTransform(pipeline):
+ pipeline_model = pipeline.fit(dataset)
+ return pipeline_model.transform(dataset)
+ # check that empty pipeline did not perform any transformation
+ self.assertEqual(dataset.index, doTransform(Pipeline(stages=[])).index)
+ # check that failure to set stages param will raise KeyError for missing param
+ self.assertRaises(KeyError, lambda: doTransform(Pipeline()))
+
class TestParams(HasMaxIter, HasInputCol, HasSeed):
"""