aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark
diff options
context:
space:
mode:
authorBryan Cutler <cutlerb@gmail.com>2017-02-23 18:05:58 -0800
committerJoseph K. Bradley <joseph@databricks.com>2017-02-23 18:05:58 -0800
commit2f69e3f60f27d4598f001a5454abc21f739120a6 (patch)
treee574fdd402e80c333886d71d664fa865b861bf31 /python/pyspark
parentd02762457477b6ab1323d7a97901f48a273ee644 (diff)
downloadspark-2f69e3f60f27d4598f001a5454abc21f739120a6.tar.gz
spark-2f69e3f60f27d4598f001a5454abc21f739120a6.tar.bz2
spark-2f69e3f60f27d4598f001a5454abc21f739120a6.zip
[SPARK-14772][PYTHON][ML] Fixed Params.copy method to match Scala implementation
## What changes were proposed in this pull request? Fixed the PySpark Params.copy method to behave like the Scala implementation. The main issue was that it did not account for the _defaultParamMap and merged it into the explicitly created param map. ## How was this patch tested? Added new unit test to verify the copy method behaves correctly for copying uid, explicitly created params, and default params. Author: Bryan Cutler <cutlerb@gmail.com> Closes #16772 from BryanCutler/pyspark-ml-param_copy-Scala_sync-SPARK-14772.
Diffstat (limited to 'python/pyspark')
-rw-r--r--python/pyspark/ml/param/__init__.py17
-rwxr-xr-xpython/pyspark/ml/tests.py16
2 files changed, 27 insertions, 6 deletions
diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py
index dc3d23ff16..99d8fa3a5b 100644
--- a/python/pyspark/ml/param/__init__.py
+++ b/python/pyspark/ml/param/__init__.py
@@ -372,6 +372,7 @@ class Params(Identifiable):
extra = dict()
that = copy.copy(self)
that._paramMap = {}
+ that._defaultParamMap = {}
return self._copyValues(that, extra)
def _shouldOwn(self, param):
@@ -452,12 +453,16 @@ class Params(Identifiable):
:param extra: extra params to be copied
:return: the target instance with param values copied
"""
- if extra is None:
- extra = dict()
- paramMap = self.extractParamMap(extra)
- for p in self.params:
- if p in paramMap and to.hasParam(p.name):
- to._set(**{p.name: paramMap[p]})
+ paramMap = self._paramMap.copy()
+ if extra is not None:
+ paramMap.update(extra)
+ for param in self.params:
+ # copy default params
+ if param in self._defaultParamMap and to.hasParam(param.name):
+ to._defaultParamMap[to.getParam(param.name)] = self._defaultParamMap[param]
+ # copy explicitly set params
+ if param in paramMap and to.hasParam(param.name):
+ to._set(**{param.name: paramMap[param]})
return to
def _resetUid(self, newUid):
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 53204cde29..293c6c0b0f 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -389,6 +389,22 @@ class ParamTests(PySparkTestCase):
# Check windowSize is set properly
self.assertEqual(model.getWindowSize(), 6)
+ def test_copy_param_extras(self):
+ tp = TestParams(seed=42)
+ extra = {tp.getParam(TestParams.inputCol.name): "copy_input"}
+ tp_copy = tp.copy(extra=extra)
+ self.assertEqual(tp.uid, tp_copy.uid)
+ self.assertEqual(tp.params, tp_copy.params)
+ for k, v in extra.items():
+ self.assertTrue(tp_copy.isDefined(k))
+ self.assertEqual(tp_copy.getOrDefault(k), v)
+ copied_no_extra = {}
+ for k, v in tp_copy._paramMap.items():
+ if k not in extra:
+ copied_no_extra[k] = v
+ self.assertEqual(tp._paramMap, copied_no_extra)
+ self.assertEqual(tp._defaultParamMap, tp_copy._defaultParamMap)
+
class EvaluatorTests(SparkSessionTestCase):