diff options
author | MechCoder <manojkumarsivaraj334@gmail.com> | 2015-06-30 10:27:29 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-06-30 10:27:29 -0700 |
commit | 5fa0863626aaf5a9a41756a0b1ec82bddccbf067 (patch) | |
tree | c8844bf78757fff102b41d7dac58bfaee988071a /python/pyspark/ml | |
parent | 45281664e0d3b22cd63660ca8ad6dd574f10e21f (diff) | |
download | spark-5fa0863626aaf5a9a41756a0b1ec82bddccbf067.tar.gz spark-5fa0863626aaf5a9a41756a0b1ec82bddccbf067.tar.bz2 spark-5fa0863626aaf5a9a41756a0b1ec82bddccbf067.zip |
[SPARK-8679] [PYSPARK] [MLLIB] Default values in Pipeline API should be immutable
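Using a mutable object as a default parameter value is dangerous: the default is evaluated only once, when the function is defined, so every call that omits the argument shares (and may mutate) the same object. (http://stackoverflow.com/a/11416002/1170730)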
For example:
>>> def func(example, f={}):
...     f[example] = 1
...     return f
>>> func(2)
{2: 1}
>>> func(3)
{2: 1, 3: 1}
mengxr
Author: MechCoder <manojkumarsivaraj334@gmail.com>
Closes #7058 from MechCoder/pipeline_api_playground and squashes the following commits:
40a5eb2 [MechCoder] copy
95f7ff2 [MechCoder] [SPARK-8679] [PySpark] [MLlib] Default values in Pipeline API should be immutable
Diffstat (limited to 'python/pyspark/ml')
-rw-r--r-- | python/pyspark/ml/pipeline.py | 24 | ||||
-rw-r--r-- | python/pyspark/ml/wrapper.py | 4 |
2 files changed, 21 insertions, 7 deletions
diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index a563024b2c..9889f56cac 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -42,7 +42,7 @@ class Estimator(Params): """ raise NotImplementedError() - def fit(self, dataset, params={}): + def fit(self, dataset, params=None): """ Fits a model to the input dataset with optional parameters. @@ -54,6 +54,8 @@ class Estimator(Params): list of models. :returns: fitted model(s) """ + if params is None: + params = dict() if isinstance(params, (list, tuple)): return [self.fit(dataset, paramMap) for paramMap in params] elif isinstance(params, dict): @@ -86,7 +88,7 @@ class Transformer(Params): """ raise NotImplementedError() - def transform(self, dataset, params={}): + def transform(self, dataset, params=None): """ Transforms the input dataset with optional parameters. @@ -96,6 +98,8 @@ class Transformer(Params): params. :returns: transformed dataset """ + if params is None: + params = dict() if isinstance(params, dict): if params: return self.copy(params,)._transform(dataset) @@ -135,10 +139,12 @@ class Pipeline(Estimator): """ @keyword_only - def __init__(self, stages=[]): + def __init__(self, stages=None): """ __init__(self, stages=[]) """ + if stages is None: + stages = [] super(Pipeline, self).__init__() #: Param for pipeline stages. self.stages = Param(self, "stages", "pipeline stages") @@ -162,11 +168,13 @@ class Pipeline(Estimator): return self._paramMap[self.stages] @keyword_only - def setParams(self, stages=[]): + def setParams(self, stages=None): """ setParams(self, stages=[]) Sets params for Pipeline. 
""" + if stages is None: + stages = [] kwargs = self.setParams._input_kwargs return self._set(**kwargs) @@ -195,7 +203,9 @@ class Pipeline(Estimator): transformers.append(stage) return PipelineModel(transformers) - def copy(self, extra={}): + def copy(self, extra=None): + if extra is None: + extra = dict() that = Params.copy(self, extra) stages = [stage.copy(extra) for stage in that.getStages()] return that.setStages(stages) @@ -216,6 +226,8 @@ class PipelineModel(Model): dataset = t.transform(dataset) return dataset - def copy(self, extra={}): + def copy(self, extra=None): + if extra is None: + extra = dict() stages = [stage.copy(extra) for stage in self.stages] return PipelineModel(stages) diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index 7b0893e2cd..253705bde9 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -166,7 +166,7 @@ class JavaModel(Model, JavaTransformer): self._java_obj = java_model self.uid = java_model.uid() - def copy(self, extra={}): + def copy(self, extra=None): """ Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and @@ -175,6 +175,8 @@ class JavaModel(Model, JavaTransformer): :param extra: Extra parameters to copy to the new instance :return: Copy of this instance """ + if extra is None: + extra = dict() that = super(JavaModel, self).copy(extra) that._java_obj = self._java_obj.copy(self._empty_java_param_map()) that._transfer_params_to_java() |