From 5fa0863626aaf5a9a41756a0b1ec82bddccbf067 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Tue, 30 Jun 2015 10:27:29 -0700 Subject: [SPARK-8679] [PYSPARK] [MLLIB] Default values in Pipeline API should be immutable It might be dangerous to have a mutable as value for default param. (http://stackoverflow.com/a/11416002/1170730) e.g def func(example, f={}): f[example] = 1 return f func(2) {2: 1} func(3) {2:1, 3:1} mengxr Author: MechCoder Closes #7058 from MechCoder/pipeline_api_playground and squashes the following commits: 40a5eb2 [MechCoder] copy 95f7ff2 [MechCoder] [SPARK-8679] [PySpark] [MLlib] Default values in Pipeline API should be immutable --- python/pyspark/ml/pipeline.py | 24 ++++++++++++++++++------ python/pyspark/ml/wrapper.py | 4 +++- 2 files changed, 21 insertions(+), 7 deletions(-) (limited to 'python/pyspark') diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index a563024b2c..9889f56cac 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -42,7 +42,7 @@ class Estimator(Params): """ raise NotImplementedError() - def fit(self, dataset, params={}): + def fit(self, dataset, params=None): """ Fits a model to the input dataset with optional parameters. @@ -54,6 +54,8 @@ class Estimator(Params): list of models. :returns: fitted model(s) """ + if params is None: + params = dict() if isinstance(params, (list, tuple)): return [self.fit(dataset, paramMap) for paramMap in params] elif isinstance(params, dict): @@ -86,7 +88,7 @@ class Transformer(Params): """ raise NotImplementedError() - def transform(self, dataset, params={}): + def transform(self, dataset, params=None): """ Transforms the input dataset with optional parameters. @@ -96,6 +98,8 @@ class Transformer(Params): params. :returns: transformed dataset """ + if params is None: + params = dict() if isinstance(params, dict): if params: return self.copy(params,)._transform(dataset) @@ -135,10 +139,12 @@ class Pipeline(Estimator): """ @keyword_only - def __init__(self, stages=[]): + def __init__(self, stages=None): """ __init__(self, stages=[]) """ + if stages is None: + stages = [] super(Pipeline, self).__init__() #: Param for pipeline stages. self.stages = Param(self, "stages", "pipeline stages") @@ -162,11 +168,13 @@ class Pipeline(Estimator): return self._paramMap[self.stages] @keyword_only - def setParams(self, stages=[]): + def setParams(self, stages=None): """ setParams(self, stages=[]) Sets params for Pipeline. """ + if stages is None: + stages = [] kwargs = self.setParams._input_kwargs return self._set(**kwargs) @@ -195,7 +203,9 @@ class Pipeline(Estimator): transformers.append(stage) return PipelineModel(transformers) - def copy(self, extra={}): + def copy(self, extra=None): + if extra is None: + extra = dict() that = Params.copy(self, extra) stages = [stage.copy(extra) for stage in that.getStages()] return that.setStages(stages) @@ -216,6 +226,8 @@ class PipelineModel(Model): dataset = t.transform(dataset) return dataset - def copy(self, extra={}): + def copy(self, extra=None): + if extra is None: + extra = dict() stages = [stage.copy(extra) for stage in self.stages] return PipelineModel(stages) diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index 7b0893e2cd..253705bde9 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -166,7 +166,7 @@ class JavaModel(Model, JavaTransformer): self._java_obj = java_model self.uid = java_model.uid() - def copy(self, extra={}): + def copy(self, extra=None): """ Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and @@ -175,6 +175,8 @@ class JavaModel(Model, JavaTransformer): :param extra: Extra parameters to copy to the new instance :return: Copy of this instance """ + if extra is None: + extra = dict() that = super(JavaModel, self).copy(extra) that._java_obj = self._java_obj.copy(self._empty_java_param_map()) that._transfer_params_to_java() -- cgit v1.2.3