aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/pipeline.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/ml/pipeline.py')
-rw-r--r--python/pyspark/ml/pipeline.py19
1 files changed, 17 insertions, 2 deletions
diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index 2d239f8c80..18d8a58f35 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -18,7 +18,7 @@
from abc import ABCMeta, abstractmethod
from pyspark.ml.param import Param, Params
-from pyspark.ml.util import inherit_doc
+from pyspark.ml.util import inherit_doc, keyword_only
__all__ = ['Estimator', 'Transformer', 'Pipeline', 'PipelineModel']
@@ -89,10 +89,16 @@ class Pipeline(Estimator):
identity transformer.
"""
- def __init__(self):
+ @keyword_only
+ def __init__(self, stages=[]):
+ """
+ __init__(self, stages=[])
+ """
super(Pipeline, self).__init__()
#: Param for pipeline stages.
self.stages = Param(self, "stages", "pipeline stages")
+ kwargs = self.__init__._input_kwargs
+ self.setParams(**kwargs)
def setStages(self, value):
"""
@@ -110,6 +116,15 @@ class Pipeline(Estimator):
if self.stages in self.paramMap:
return self.paramMap[self.stages]
+ @keyword_only
+ def setParams(self, stages=[]):
+ """
+ setParams(self, stages=[])
+ Sets params for Pipeline.
+ """
+ kwargs = self.setParams._input_kwargs
+ return self._set_params(**kwargs)
+
def fit(self, dataset, params={}):
paramMap = self._merge_params(params)
stages = paramMap[self.stages]