about summary refs log tree commit diff
path: root/python/pyspark/ml/param/__init__.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/ml/param/__init__.py')
-rw-r--r-- python/pyspark/ml/param/__init__.py 146
1 files changed, 127 insertions, 19 deletions
diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py
index e3a53dd780..5c62620562 100644
--- a/python/pyspark/ml/param/__init__.py
+++ b/python/pyspark/ml/param/__init__.py
@@ -25,23 +25,21 @@ __all__ = ['Param', 'Params']
class Param(object):
"""
- A param with self-contained documentation and optionally default value.
+ A param with self-contained documentation.
"""
- def __init__(self, parent, name, doc, defaultValue=None):
- if not isinstance(parent, Identifiable):
- raise ValueError("Parent must be identifiable but got type %s." % type(parent).__name__)
+ def __init__(self, parent, name, doc):
+ if not isinstance(parent, Params):
+ raise ValueError("Parent must be a Params but got type %s." % type(parent).__name__)
self.parent = parent
self.name = str(name)
self.doc = str(doc)
- self.defaultValue = defaultValue
def __str__(self):
- return str(self.parent) + "-" + self.name
+ return str(self.parent) + "__" + self.name
def __repr__(self):
- return "Param(parent=%r, name=%r, doc=%r, defaultValue=%r)" % \
- (self.parent, self.name, self.doc, self.defaultValue)
+ return "Param(parent=%r, name=%r, doc=%r)" % (self.parent, self.name, self.doc)
class Params(Identifiable):
@@ -52,26 +50,128 @@ class Params(Identifiable):
__metaclass__ = ABCMeta
- def __init__(self):
- super(Params, self).__init__()
- #: embedded param map
- self.paramMap = {}
+ #: internal param map for user-supplied values
+ paramMap = {}
+
+ #: internal param map for default values
+ defaultParamMap = {}
@property
def params(self):
"""
- Returns all params. The default implementation uses
- :py:func:`dir` to get all attributes of type
+ Returns all params ordered by name. The default implementation
+ uses :py:func:`dir` to get all attributes of type
:py:class:`Param`.
"""
return filter(lambda attr: isinstance(attr, Param),
[getattr(self, x) for x in dir(self) if x != "params"])
- def _merge_params(self, params):
- paramMap = self.paramMap.copy()
- paramMap.update(params)
+ def _explain(self, param):
+ """
+ Explains a single param and returns its name, doc, and optional
+ default value and user-supplied value in a string.
+ """
+ param = self._resolveParam(param)
+ values = []
+ if self.isDefined(param):
+ if param in self.defaultParamMap:
+ values.append("default: %s" % self.defaultParamMap[param])
+ if param in self.paramMap:
+ values.append("current: %s" % self.paramMap[param])
+ else:
+ values.append("undefined")
+ valueStr = "(" + ", ".join(values) + ")"
+ return "%s: %s %s" % (param.name, param.doc, valueStr)
+
+ def explainParams(self):
+ """
+ Returns the documentation of all params with their optional
+ default values and user-supplied values.
+ """
+ return "\n".join([self._explain(param) for param in self.params])
+
+ def getParam(self, paramName):
+ """
+ Gets a param by its name.
+ """
+ param = getattr(self, paramName)
+ if isinstance(param, Param):
+ return param
+ else:
+ raise ValueError("Cannot find param with name %s." % paramName)
+
+ def isSet(self, param):
+ """
+ Checks whether a param is explicitly set by user.
+ """
+ param = self._resolveParam(param)
+ return param in self.paramMap
+
+ def hasDefault(self, param):
+ """
+ Checks whether a param has a default value.
+ """
+ param = self._resolveParam(param)
+ return param in self.defaultParamMap
+
+ def isDefined(self, param):
+ """
+ Checks whether a param is explicitly set by user or has a default value.
+ """
+ return self.isSet(param) or self.hasDefault(param)
+
+ def getOrDefault(self, param):
+ """
+ Gets the value of a param in the user-supplied param map or its
+ default value. Raises an error if neither is set.
+ """
+ if isinstance(param, Param):
+ if param in self.paramMap:
+ return self.paramMap[param]
+ else:
+ return self.defaultParamMap[param]
+ elif isinstance(param, str):
+ return self.getOrDefault(self.getParam(param))
+ else:
+ raise KeyError("Cannot recognize %r as a param." % param)
+
+ def extractParamMap(self, extraParamMap={}):
+ """
+ Extracts the embedded default param values and user-supplied
+ values, and then merges them with extra values from input into
+ a flat param map, where the latter value is used if there exist
+ conflicts, i.e., with ordering: default param values <
+ user-supplied values < extraParamMap.
+ :param extraParamMap: extra param values
+ :return: merged param map
+ """
+ paramMap = self.defaultParamMap.copy()
+ paramMap.update(self.paramMap)
+ paramMap.update(extraParamMap)
return paramMap
+ def _shouldOwn(self, param):
+ """
+ Validates that the input param belongs to this Params instance.
+ """
+ if param.parent is not self:
+ raise ValueError("Param %r does not belong to %r." % (param, self))
+
+ def _resolveParam(self, param):
+ """
+ Resolves a param and validates the ownership.
+ :param param: param name or the param instance, which must
+ belong to this Params instance
+ :return: resolved param instance
+ """
+ if isinstance(param, Param):
+ self._shouldOwn(param)
+ return param
+ elif isinstance(param, str):
+ return self.getParam(param)
+ else:
+ raise ValueError("Cannot resolve %r as a param." % param)
+
@staticmethod
def _dummy():
"""
@@ -81,10 +181,18 @@ class Params(Identifiable):
dummy.uid = "undefined"
return dummy
- def _set_params(self, **kwargs):
+ def _set(self, **kwargs):
"""
- Sets params.
+ Sets user-supplied params.
"""
for param, value in kwargs.iteritems():
self.paramMap[getattr(self, param)] = value
return self
+
+ def _setDefault(self, **kwargs):
+ """
+ Sets default params.
+ """
+ for param, value in kwargs.iteritems():
+ self.defaultParamMap[getattr(self, param)] = value
+ return self