aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/param/__init__.py
diff options
context:
space:
mode:
author Holden Karau <holden@us.ibm.com> 2016-01-26 15:53:48 -0800
committer Joseph K. Bradley <joseph@databricks.com> 2016-01-26 15:53:48 -0800
commit eb917291ca1a2d68ca0639cb4b1464a546603eba (patch)
tree 380dcaa33273baa68beaf089387bd498d5ee88e8 /python/pyspark/ml/param/__init__.py
parent 19fdb21afbf0eae4483cf6d4ef32daffd1994b89 (diff)
downloadspark-eb917291ca1a2d68ca0639cb4b1464a546603eba.tar.gz
spark-eb917291ca1a2d68ca0639cb4b1464a546603eba.tar.bz2
spark-eb917291ca1a2d68ca0639cb4b1464a546603eba.zip
[SPARK-10509][PYSPARK] Reduce excessive param boilerplate code
The current Python ML params require cutting and pasting the param setup and description between the class body and the `__init__` method. Remove this source of errors and simplify the use of custom params by adding a `_copy_new_parent` method to `Param`, so that params declared on the class are copied onto each instance instead of being cut and pasted (often at different indentation levels). Author: Holden Karau <holden@us.ibm.com>. Closes #10216 from holdenk/SPARK-10509-excessive-param-boiler-plate-code.
Diffstat (limited to 'python/pyspark/ml/param/__init__.py')
-rw-r--r--python/pyspark/ml/param/__init__.py22
1 files changed, 22 insertions, 0 deletions
diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py
index 92ce96aa3c..3da36d32c5 100644
--- a/python/pyspark/ml/param/__init__.py
+++ b/python/pyspark/ml/param/__init__.py
@@ -40,6 +40,15 @@ class Param(object):
self.doc = str(doc)
self.expectedType = expectedType
+ def _copy_new_parent(self, parent):
+ """Copy the current param to a new parent, must be a dummy param."""
+ if self.parent == "undefined":
+ param = copy.copy(self)
+ param.parent = parent.uid
+ return param
+ else:
+ raise ValueError("Cannot copy from non-dummy parent %s." % parent)
+
def __str__(self):
return str(self.parent) + "__" + self.name
@@ -77,6 +86,19 @@ class Params(Identifiable):
#: value returned by :py:func:`params`
self._params = None
+ # Copy the params from the class to the object
+ self._copy_params()
+
def _copy_params(self):
    """
    Copy all params defined on the class to the current object.

    Every attribute of this object's class that is a :py:class:`Param`
    is passed through its ``_copy_new_parent`` method with this object
    as the new parent, and the result is stored on the object under the
    same attribute name.
    """
    owner_cls = type(self)
    # Collect every class-level Param first, then assign, so attribute
    # lookup on the class is finished before any instance attribute is set.
    class_attrs = [(attr, getattr(owner_cls, attr)) for attr in dir(owner_cls)]
    class_params = [pair for pair in class_attrs if isinstance(pair[1], Param)]
    for attr, class_param in class_params:
        setattr(self, attr, class_param._copy_new_parent(self))
+
@property
@since("1.3.0")
def params(self):