aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/param/__init__.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/ml/param/__init__.py')
-rw-r--r--python/pyspark/ml/param/__init__.py118
1 files changed, 84 insertions, 34 deletions
diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py
index 49c20b4cf7..67fb6e3dc7 100644
--- a/python/pyspark/ml/param/__init__.py
+++ b/python/pyspark/ml/param/__init__.py
@@ -16,6 +16,7 @@
#
from abc import ABCMeta
+import copy
from pyspark.ml.util import Identifiable
@@ -29,9 +30,9 @@ class Param(object):
"""
def __init__(self, parent, name, doc):
- if not isinstance(parent, Params):
- raise TypeError("Parent must be a Params but got type %s." % type(parent))
- self.parent = parent
+ if not isinstance(parent, Identifiable):
+ raise TypeError("Parent must be an Identifiable but got type %s." % type(parent))
+ self.parent = parent.uid
self.name = str(name)
self.doc = str(doc)
@@ -41,6 +42,15 @@ class Param(object):
def __repr__(self):
return "Param(parent=%r, name=%r, doc=%r)" % (self.parent, self.name, self.doc)
+ def __hash__(self):
+ return hash(str(self))
+
+ def __eq__(self, other):
+ if isinstance(other, Param):
+ return self.parent == other.parent and self.name == other.name
+ else:
+ return False
+
class Params(Identifiable):
"""
@@ -51,10 +61,13 @@ class Params(Identifiable):
__metaclass__ = ABCMeta
#: internal param map for user-supplied values param map
- paramMap = {}
+ _paramMap = {}
#: internal param map for default values
- defaultParamMap = {}
+ _defaultParamMap = {}
+
+ #: value returned by :py:func:`params`
+ _params = None
@property
def params(self):
@@ -63,10 +76,12 @@ class Params(Identifiable):
uses :py:func:`dir` to get all attributes of type
:py:class:`Param`.
"""
- return list(filter(lambda attr: isinstance(attr, Param),
- [getattr(self, x) for x in dir(self) if x != "params"]))
+ if self._params is None:
+ self._params = list(filter(lambda attr: isinstance(attr, Param),
+ [getattr(self, x) for x in dir(self) if x != "params"]))
+ return self._params
- def _explain(self, param):
+ def explainParam(self, param):
"""
Explains a single param and returns its name, doc, and optional
default value and user-supplied value in a string.
@@ -74,10 +89,10 @@ class Params(Identifiable):
param = self._resolveParam(param)
values = []
if self.isDefined(param):
- if param in self.defaultParamMap:
- values.append("default: %s" % self.defaultParamMap[param])
- if param in self.paramMap:
- values.append("current: %s" % self.paramMap[param])
+ if param in self._defaultParamMap:
+ values.append("default: %s" % self._defaultParamMap[param])
+ if param in self._paramMap:
+ values.append("current: %s" % self._paramMap[param])
else:
values.append("undefined")
valueStr = "(" + ", ".join(values) + ")"
@@ -88,7 +103,7 @@ class Params(Identifiable):
Returns the documentation of all params with their optionally
default values and user-supplied values.
"""
- return "\n".join([self._explain(param) for param in self.params])
+ return "\n".join([self.explainParam(param) for param in self.params])
def getParam(self, paramName):
"""
@@ -105,56 +120,76 @@ class Params(Identifiable):
Checks whether a param is explicitly set by user.
"""
param = self._resolveParam(param)
- return param in self.paramMap
+ return param in self._paramMap
def hasDefault(self, param):
"""
Checks whether a param has a default value.
"""
param = self._resolveParam(param)
- return param in self.defaultParamMap
+ return param in self._defaultParamMap
def isDefined(self, param):
"""
- Checks whether a param is explicitly set by user or has a default value.
+ Checks whether a param is explicitly set by user or has
+ a default value.
"""
return self.isSet(param) or self.hasDefault(param)
+ def hasParam(self, paramName):
+ """
+ Tests whether this instance contains a param with a given
+ (string) name.
+ """
+ param = self._resolveParam(paramName)
+ return param in self.params
+
def getOrDefault(self, param):
"""
Gets the value of a param in the user-supplied param map or its
default value. Raises an error if either is set.
"""
- if isinstance(param, Param):
- if param in self.paramMap:
- return self.paramMap[param]
- else:
- return self.defaultParamMap[param]
- elif isinstance(param, str):
- return self.getOrDefault(self.getParam(param))
+ param = self._resolveParam(param)
+ if param in self._paramMap:
+ return self._paramMap[param]
else:
- raise KeyError("Cannot recognize %r as a param." % param)
+ return self._defaultParamMap[param]
- def extractParamMap(self, extraParamMap={}):
+ def extractParamMap(self, extra={}):
"""
Extracts the embedded default param values and user-supplied
values, and then merges them with extra values from input into
a flat param map, where the latter value is used if there exist
conflicts, i.e., with ordering: default param values <
- user-supplied values < extraParamMap.
- :param extraParamMap: extra param values
+ user-supplied values < extra.
+ :param extra: extra param values
:return: merged param map
"""
- paramMap = self.defaultParamMap.copy()
- paramMap.update(self.paramMap)
- paramMap.update(extraParamMap)
+ paramMap = self._defaultParamMap.copy()
+ paramMap.update(self._paramMap)
+ paramMap.update(extra)
return paramMap
+ def copy(self, extra={}):
+ """
+ Creates a copy of this instance with the same uid and some
+ extra params. The default implementation creates a
+ shallow copy using :py:func:`copy.copy`, and then copies the
+ embedded and extra parameters over and returns the copy.
+ Subclasses should override this method if the default approach
+ is not sufficient.
+ :param extra: Extra parameters to copy to the new instance
+ :return: Copy of this instance
+ """
+ that = copy.copy(self)
+ that._paramMap = self.extractParamMap(extra)
+ return that
+
def _shouldOwn(self, param):
"""
Validates that the input param belongs to this Params instance.
"""
- if param.parent is not self:
+ if not (self.uid == param.parent and self.hasParam(param.name)):
raise ValueError("Param %r does not belong to %r." % (param, self))
def _resolveParam(self, param):
@@ -175,7 +210,8 @@ class Params(Identifiable):
@staticmethod
def _dummy():
"""
- Returns a dummy Params instance used as a placeholder to generate docs.
+ Returns a dummy Params instance used as a placeholder to
+ generate docs.
"""
dummy = Params()
dummy.uid = "undefined"
@@ -186,7 +222,7 @@ class Params(Identifiable):
Sets user-supplied params.
"""
for param, value in kwargs.items():
- self.paramMap[getattr(self, param)] = value
+ self._paramMap[getattr(self, param)] = value
return self
def _setDefault(self, **kwargs):
@@ -194,5 +230,19 @@ class Params(Identifiable):
Sets default params.
"""
for param, value in kwargs.items():
- self.defaultParamMap[getattr(self, param)] = value
+ self._defaultParamMap[getattr(self, param)] = value
return self
+
+ def _copyValues(self, to, extra={}):
+ """
+ Copies param values from this instance to another instance for
+ params shared by them.
+ :param to: the target instance
+ :param extra: extra params to be copied
+ :return: the target instance with param values copied
+ """
+ paramMap = self.extractParamMap(extra)
+ for p in self.params:
+ if p in paramMap and to.hasParam(p.name):
+ to._set(**{p.name: paramMap[p]})
+ return to