diff options
author | Xiangrui Meng <meng@databricks.com> | 2015-05-18 12:02:18 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-05-18 12:02:18 -0700 |
commit | 9c7e802a5a2b8cd3eb77642f84c54a8e976fc996 (patch) | |
tree | 2e3b7e367f57b64ef46733ee8b64aa258e58cca8 /python/pyspark/ml/pipeline.py | |
parent | 56ede88485cfca90974425fcb603b257be47229b (diff) | |
download | spark-9c7e802a5a2b8cd3eb77642f84c54a8e976fc996.tar.gz spark-9c7e802a5a2b8cd3eb77642f84c54a8e976fc996.tar.bz2 spark-9c7e802a5a2b8cd3eb77642f84c54a8e976fc996.zip |
[SPARK-7380] [MLLIB] pipeline stages should be copyable in Python
This PR makes pipeline stages in Python copyable and hence simplifies some implementations. It also includes the following changes:
1. Rename `paramMap` and `defaultParamMap` to `_paramMap` and `_defaultParamMap`, respectively.
2. Accept a list of param maps in `fit`.
3. Use parent uid and name to identify param.
jkbradley
Author: Xiangrui Meng <meng@databricks.com>
Author: Joseph K. Bradley <joseph@databricks.com>
Closes #6088 from mengxr/SPARK-7380 and squashes the following commits:
413c463 [Xiangrui Meng] remove unnecessary doc
4159f35 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7380
611c719 [Xiangrui Meng] fix python style
68862b8 [Xiangrui Meng] update _java_obj initialization
927ad19 [Xiangrui Meng] fix ml/tests.py
0138fc3 [Xiangrui Meng] update feature transformers and fix a bug in RegexTokenizer
9ca44fb [Xiangrui Meng] simplify Java wrappers and add tests
c7d84ef [Xiangrui Meng] update ml/tests.py to test copy params
7e0d27f [Xiangrui Meng] merge master
46840fb [Xiangrui Meng] update wrappers
b6db1ed [Xiangrui Meng] update all self.paramMap to self._paramMap
46cb6ed [Xiangrui Meng] merge master
a163413 [Xiangrui Meng] fix style
1042e80 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7380
9630eae [Xiangrui Meng] fix Identifiable._randomUID
13bd70a [Xiangrui Meng] update ml/tests.py
64a536c [Xiangrui Meng] use _fit/_transform/_evaluate to simplify the impl
02abf13 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into copyable-python
66ce18c [Joseph K. Bradley] some cleanups before sending to Xiangrui
7431272 [Joseph K. Bradley] Rebased with master
Diffstat (limited to 'python/pyspark/ml/pipeline.py')
-rw-r--r-- | python/pyspark/ml/pipeline.py | 109 |
1 files changed, 85 insertions, 24 deletions
diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index a328bcf84a..0f38e02127 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -31,18 +31,40 @@ class Estimator(Params): __metaclass__ = ABCMeta @abstractmethod - def fit(self, dataset, params={}): + def _fit(self, dataset): """ - Fits a model to the input dataset with optional parameters. + Fits a model to the input dataset. This is called by the + default implementation of fit. :param dataset: input dataset, which is an instance of :py:class:`pyspark.sql.DataFrame` - :param params: an optional param map that overwrites embedded - params :returns: fitted model """ raise NotImplementedError() + def fit(self, dataset, params={}): + """ + Fits a model to the input dataset with optional parameters. + + :param dataset: input dataset, which is an instance of + :py:class:`pyspark.sql.DataFrame` + :param params: an optional param map that overrides embedded + params. If a list/tuple of param maps is given, + this calls fit on each param map and returns a + list of models. + :returns: fitted model(s) + """ + if isinstance(params, (list, tuple)): + return [self.fit(dataset, paramMap) for paramMap in params] + elif isinstance(params, dict): + if params: + return self.copy(params)._fit(dataset) + else: + return self._fit(dataset) + else: + raise ValueError("Params must be either a param map or a list/tuple of param maps, " + "but got %s." % type(params)) + @inherit_doc class Transformer(Params): @@ -54,18 +76,34 @@ class Transformer(Params): __metaclass__ = ABCMeta @abstractmethod - def transform(self, dataset, params={}): + def _transform(self, dataset): """ Transforms the input dataset with optional parameters. :param dataset: input dataset, which is an instance of :py:class:`pyspark.sql.DataFrame` - :param params: an optional param map that overwrites embedded - params :returns: transformed dataset """ raise NotImplementedError() + def transform(self, dataset, params={}): + """ + Transforms the input dataset with optional parameters. + + :param dataset: input dataset, which is an instance of + :py:class:`pyspark.sql.DataFrame` + :param params: an optional param map that overrides embedded + params. + :returns: transformed dataset + """ + if isinstance(params, dict): + if params: + return self.copy(params,)._transform(dataset) + else: + return self._transform(dataset) + else: + raise ValueError("Params must be either a param map but got %s." % type(params)) + @inherit_doc class Model(Transformer): @@ -113,15 +151,15 @@ class Pipeline(Estimator): :param value: a list of transformers or estimators :return: the pipeline instance """ - self.paramMap[self.stages] = value + self._paramMap[self.stages] = value return self def getStages(self): """ Get pipeline stages. """ - if self.stages in self.paramMap: - return self.paramMap[self.stages] + if self.stages in self._paramMap: + return self._paramMap[self.stages] @keyword_only def setParams(self, stages=[]): @@ -132,9 +170,8 @@ class Pipeline(Estimator): kwargs = self.setParams._input_kwargs return self._set(**kwargs) - def fit(self, dataset, params={}): - paramMap = self.extractParamMap(params) - stages = paramMap[self.stages] + def _fit(self, dataset): + stages = self.getStages() for stage in stages: if not (isinstance(stage, Estimator) or isinstance(stage, Transformer)): raise TypeError( @@ -148,16 +185,21 @@ class Pipeline(Estimator): if i <= indexOfLastEstimator: if isinstance(stage, Transformer): transformers.append(stage) - dataset = stage.transform(dataset, paramMap) + dataset = stage.transform(dataset) else: # must be an Estimator - model = stage.fit(dataset, paramMap) + model = stage.fit(dataset) transformers.append(model) if i < indexOfLastEstimator: - dataset = model.transform(dataset, paramMap) + dataset = model.transform(dataset) else: transformers.append(stage) return PipelineModel(transformers) + def copy(self, extra={}): + that = Params.copy(self, extra) + stages = [stage.copy(extra) for stage in that.getStages()] + return that.setStages(stages) + @inherit_doc class PipelineModel(Model): @@ -165,16 +207,19 @@ class PipelineModel(Model): Represents a compiled pipeline with transformers and fitted models. """ - def __init__(self, transformers): + def __init__(self, stages): super(PipelineModel, self).__init__() - self.transformers = transformers + self.stages = stages - def transform(self, dataset, params={}): - paramMap = self.extractParamMap(params) - for t in self.transformers: - dataset = t.transform(dataset, paramMap) + def _transform(self, dataset): + for t in self.stages: + dataset = t.transform(dataset) return dataset + def copy(self, extra={}): + stages = [stage.copy(extra) for stage in self.stages] + return PipelineModel(stages) + class Evaluator(Params): """ @@ -184,14 +229,30 @@ class Evaluator(Params): __metaclass__ = ABCMeta @abstractmethod - def evaluate(self, dataset, params={}): + def _evaluate(self, dataset): """ Evaluates the output. :param dataset: a dataset that contains labels/observations and + predictions + :return: metric + """ + raise NotImplementedError() + + def evaluate(self, dataset, params={}): + """ + Evaluates the output with optional parameters. + + :param dataset: a dataset that contains labels/observations and predictions :param params: an optional param map that overrides embedded params :return: metric """ - raise NotImplementedError() + if isinstance(params, dict): + if params: + return self.copy(params)._evaluate(dataset) + else: + return self._evaluate(dataset) + else: + raise ValueError("Params must be a param map but got %s." % type(params)) |