author    MechCoder <manojkumarsivaraj334@gmail.com>    2015-06-30 10:27:29 -0700
committer Xiangrui Meng <meng@databricks.com>    2015-06-30 10:27:37 -0700
commit    894404cb237f8e5fc2b73ac36468f1af524a4238 (patch)
tree      1d16206128c49f2ded65f5d6c23320b862a19700
parent    255b2be94bbd2b527175d8e7a5a2b89fecf8a835 (diff)
[SPARK-8679] [PYSPARK] [MLLIB] Default values in Pipeline API should be immutable
It might be dangerous to have a mutable object as the value of a default param (http://stackoverflow.com/a/11416002/1170730), e.g.

    def func(example, f={}):
        f[example] = 1
        return f

    >>> func(2)
    {2: 1}
    >>> func(3)
    {2: 1, 3: 1}

mengxr

Author: MechCoder <manojkumarsivaraj334@gmail.com>

Closes #7058 from MechCoder/pipeline_api_playground and squashes the following commits:

40a5eb2 [MechCoder] copy
95f7ff2 [MechCoder] [SPARK-8679] [PySpark] [MLlib] Default values in Pipeline API should be immutable

(cherry picked from commit 5fa0863626aaf5a9a41756a0b1ec82bddccbf067)
Signed-off-by: Xiangrui Meng <meng@databricks.com>
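The fix applied throughout this patch is the usual None-sentinel idiom: default to None and build a fresh object inside the function. A minimal sketch of that idiom (the `record` helper below is illustrative only, not part of pyspark.ml):

    # None-sentinel version of the example above (hypothetical helper,
    # not pyspark.ml code): each call builds its own dict.
    def record(example, f=None):
        if f is None:
            f = dict()
        f[example] = 1
        return f

    >>> record(2)
    {2: 1}
    >>> record(3)
    {3: 1}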
-rw-r--r--  python/pyspark/ml/pipeline.py  24
-rw-r--r--  python/pyspark/ml/wrapper.py    4
2 files changed, 21 insertions(+), 7 deletions(-)
diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index a563024b2c..9889f56cac 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -42,7 +42,7 @@ class Estimator(Params):
         """
         raise NotImplementedError()
 
-    def fit(self, dataset, params={}):
+    def fit(self, dataset, params=None):
         """
         Fits a model to the input dataset with optional parameters.
 
@@ -54,6 +54,8 @@ class Estimator(Params):
                        list of models.
         :returns: fitted model(s)
         """
+        if params is None:
+            params = dict()
         if isinstance(params, (list, tuple)):
             return [self.fit(dataset, paramMap) for paramMap in params]
         elif isinstance(params, dict):
@@ -86,7 +88,7 @@ class Transformer(Params):
         """
         raise NotImplementedError()
 
-    def transform(self, dataset, params={}):
+    def transform(self, dataset, params=None):
         """
         Transforms the input dataset with optional parameters.
 
@@ -96,6 +98,8 @@ class Transformer(Params):
                        params.
         :returns: transformed dataset
         """
+        if params is None:
+            params = dict()
         if isinstance(params, dict):
             if params:
                 return self.copy(params,)._transform(dataset)
@@ -135,10 +139,12 @@ class Pipeline(Estimator):
     """
 
     @keyword_only
-    def __init__(self, stages=[]):
+    def __init__(self, stages=None):
         """
         __init__(self, stages=[])
         """
+        if stages is None:
+            stages = []
         super(Pipeline, self).__init__()
         #: Param for pipeline stages.
         self.stages = Param(self, "stages", "pipeline stages")
@@ -162,11 +168,13 @@ class Pipeline(Estimator):
             return self._paramMap[self.stages]
 
     @keyword_only
-    def setParams(self, stages=[]):
+    def setParams(self, stages=None):
         """
         setParams(self, stages=[])
         Sets params for Pipeline.
         """
+        if stages is None:
+            stages = []
         kwargs = self.setParams._input_kwargs
         return self._set(**kwargs)
 
@@ -195,7 +203,9 @@ class Pipeline(Estimator):
                 transformers.append(stage)
         return PipelineModel(transformers)
 
-    def copy(self, extra={}):
+    def copy(self, extra=None):
+        if extra is None:
+            extra = dict()
         that = Params.copy(self, extra)
         stages = [stage.copy(extra) for stage in that.getStages()]
         return that.setStages(stages)
@@ -216,6 +226,8 @@ class PipelineModel(Model):
             dataset = t.transform(dataset)
         return dataset
 
-    def copy(self, extra={}):
+    def copy(self, extra=None):
+        if extra is None:
+            extra = dict()
         stages = [stage.copy(extra) for stage in self.stages]
         return PipelineModel(stages)
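The copy() hunks above (and the wrapper.py hunk below) guard against the same hazard in a subtler form: with extra={} as the default, one dict is shared by every copy() call, so any code that mutates it silently leaks params into later copies. A standalone sketch of the difference, using a hypothetical Node class rather than the real pyspark.ml.param.Params machinery:

    # Hypothetical illustration of the shared-default hazard in copy();
    # not the actual pyspark.ml implementation.
    class Node(object):
        def __init__(self, params):
            self.params = params

        def copy_shared(self, extra={}):
            # Every call reuses this single dict, so mutations persist.
            extra.update(self.params)
            return Node(extra)

        def copy_sentinel(self, extra=None):
            # A None sentinel gives each call its own dict.
            if extra is None:
                extra = dict()
            merged = dict(extra)
            merged.update(self.params)
            return Node(merged)

    a = Node({"x": 1}).copy_shared()
    b = Node({}).copy_shared()    # b.params is {"x": 1}: leaked from the earlier call
    c = Node({}).copy_sentinel()  # c.params is {} as expected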
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index 7b0893e2cd..253705bde9 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -166,7 +166,7 @@ class JavaModel(Model, JavaTransformer):
         self._java_obj = java_model
         self.uid = java_model.uid()
 
-    def copy(self, extra={}):
+    def copy(self, extra=None):
         """
         Creates a copy of this instance with the same uid and some
         extra params. This implementation first calls Params.copy and
@@ -175,6 +175,8 @@ class JavaModel(Model, JavaTransformer):
         :param extra: Extra parameters to copy to the new instance
         :return: Copy of this instance
         """
+        if extra is None:
+            extra = dict()
         that = super(JavaModel, self).copy(extra)
         that._java_obj = self._java_obj.copy(self._empty_java_param_map())
         that._transfer_params_to_java()