aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/param
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/ml/param')
-rw-r--r--python/pyspark/ml/param/__init__.py17
1 files changed, 11 insertions, 6 deletions
diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py
index dc3d23ff16..99d8fa3a5b 100644
--- a/python/pyspark/ml/param/__init__.py
+++ b/python/pyspark/ml/param/__init__.py
@@ -372,6 +372,7 @@ class Params(Identifiable):
extra = dict()
that = copy.copy(self)
that._paramMap = {}
+ that._defaultParamMap = {}
return self._copyValues(that, extra)
def _shouldOwn(self, param):
@@ -452,12 +453,16 @@ class Params(Identifiable):
:param extra: extra params to be copied
:return: the target instance with param values copied
"""
- if extra is None:
- extra = dict()
- paramMap = self.extractParamMap(extra)
- for p in self.params:
- if p in paramMap and to.hasParam(p.name):
- to._set(**{p.name: paramMap[p]})
+ paramMap = self._paramMap.copy()
+ if extra is not None:
+ paramMap.update(extra)
+ for param in self.params:
+ # copy default params
+ if param in self._defaultParamMap and to.hasParam(param.name):
+ to._defaultParamMap[to.getParam(param.name)] = self._defaultParamMap[param]
+ # copy explicitly set params
+ if param in paramMap and to.hasParam(param.name):
+ to._set(**{param.name: paramMap[param]})
return to
def _resetUid(self, newUid):