aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/param/__init__.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/ml/param/__init__.py')
-rw-r--r--python/pyspark/ml/param/__init__.py181
1 files changed, 163 insertions, 18 deletions
diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py
index c0f0a71eb6..a1265294a1 100644
--- a/python/pyspark/ml/param/__init__.py
+++ b/python/pyspark/ml/param/__init__.py
@@ -14,31 +14,47 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+import array
+import sys
+if sys.version > '3':
+ basestring = str
+ xrange = range
+ unicode = str
from abc import ABCMeta
import copy
+import numpy as np
+import warnings
from pyspark import since
from pyspark.ml.util import Identifiable
+from pyspark.mllib.linalg import DenseVector, Vector
-__all__ = ['Param', 'Params']
+__all__ = ['Param', 'Params', 'TypeConverters']
class Param(object):
"""
A param with self-contained documentation.
+ Note: `expectedType` is deprecated and will be removed in 2.1. Use typeConverter instead,
+ as a keyword argument.
+
.. versionadded:: 1.3.0
"""
- def __init__(self, parent, name, doc, expectedType=None):
+ def __init__(self, parent, name, doc, expectedType=None, typeConverter=None):
if not isinstance(parent, Identifiable):
raise TypeError("Parent must be an Identifiable but got type %s." % type(parent))
self.parent = parent.uid
self.name = str(name)
self.doc = str(doc)
self.expectedType = expectedType
+ if expectedType is not None:
+ warnings.warn("expectedType is deprecated and will be removed in 2.1. " +
+ "Use typeConverter instead, as a keyword argument.")
+ self.typeConverter = TypeConverters.identity if typeConverter is None else typeConverter
def _copy_new_parent(self, parent):
"""Copy the current param to a new parent, must be a dummy param."""
@@ -65,6 +81,146 @@ class Param(object):
return False
+class TypeConverters(object):
+ """
+ .. note:: DeveloperApi
+
+ Factory methods for common type conversion functions for `Param.typeConverter`.
+
+ .. versionadded:: 2.0.0
+ """
+
+ @staticmethod
+ def _is_numeric(value):
+ vtype = type(value)
+ return vtype in [int, float, np.float64, np.int64] or vtype.__name__ == 'long'
+
+ @staticmethod
+ def _is_integer(value):
+ return TypeConverters._is_numeric(value) and float(value).is_integer()
+
+ @staticmethod
+ def _can_convert_to_list(value):
+ vtype = type(value)
+ return vtype in [list, np.ndarray, tuple, xrange, array.array] or isinstance(value, Vector)
+
+ @staticmethod
+ def _can_convert_to_string(value):
+ vtype = type(value)
+ return isinstance(value, basestring) or vtype in [np.unicode_, np.string_, np.str_]
+
+ @staticmethod
+ def identity(value):
+ """
+ Dummy converter that just returns value.
+ """
+ return value
+
+ @staticmethod
+ def toList(value):
+ """
+ Convert a value to a list, if possible.
+ """
+ if type(value) == list:
+ return value
+ elif type(value) in [np.ndarray, tuple, xrange, array.array]:
+ return list(value)
+ elif isinstance(value, Vector):
+ return list(value.toArray())
+ else:
+ raise TypeError("Could not convert %s to list" % value)
+
+ @staticmethod
+ def toListFloat(value):
+ """
+ Convert a value to list of floats, if possible.
+ """
+ if TypeConverters._can_convert_to_list(value):
+ value = TypeConverters.toList(value)
+ if all(map(lambda v: TypeConverters._is_numeric(v), value)):
+ return [float(v) for v in value]
+ raise TypeError("Could not convert %s to list of floats" % value)
+
+ @staticmethod
+ def toListInt(value):
+ """
+ Convert a value to list of ints, if possible.
+ """
+ if TypeConverters._can_convert_to_list(value):
+ value = TypeConverters.toList(value)
+ if all(map(lambda v: TypeConverters._is_integer(v), value)):
+ return [int(v) for v in value]
+ raise TypeError("Could not convert %s to list of ints" % value)
+
+ @staticmethod
+ def toListString(value):
+ """
+ Convert a value to list of strings, if possible.
+ """
+ if TypeConverters._can_convert_to_list(value):
+ value = TypeConverters.toList(value)
+ if all(map(lambda v: TypeConverters._can_convert_to_string(v), value)):
+ return [TypeConverters.toString(v) for v in value]
+ raise TypeError("Could not convert %s to list of strings" % value)
+
+ @staticmethod
+ def toVector(value):
+ """
+ Convert a value to a MLlib Vector, if possible.
+ """
+ if isinstance(value, Vector):
+ return value
+ elif TypeConverters._can_convert_to_list(value):
+ value = TypeConverters.toList(value)
+ if all(map(lambda v: TypeConverters._is_numeric(v), value)):
+ return DenseVector(value)
+ raise TypeError("Could not convert %s to vector" % value)
+
+ @staticmethod
+ def toFloat(value):
+ """
+ Convert a value to a float, if possible.
+ """
+ if TypeConverters._is_numeric(value):
+ return float(value)
+ else:
+ raise TypeError("Could not convert %s to float" % value)
+
+ @staticmethod
+ def toInt(value):
+ """
+ Convert a value to an int, if possible.
+ """
+ if TypeConverters._is_integer(value):
+ return int(value)
+ else:
+ raise TypeError("Could not convert %s to int" % value)
+
+ @staticmethod
+ def toString(value):
+ """
+ Convert a value to a string, if possible.
+ """
+ if isinstance(value, basestring):
+ return value
+ elif type(value) in [np.string_, np.str_]:
+ return str(value)
+ elif type(value) == np.unicode_:
+ return unicode(value)
+ else:
+ raise TypeError("Could not convert %s to string type" % type(value))
+
+ @staticmethod
+ def toBoolean(value):
+ """
+ Convert a value to a boolean, if possible.
+ """
+ if type(value) == bool:
+ return value
+ else:
+ raise TypeError("Boolean Param requires value of type bool. Found %s." % type(value))
+
+
class Params(Identifiable):
"""
Components that take parameters. This also provides an internal
@@ -275,23 +431,12 @@ class Params(Identifiable):
"""
for param, value in kwargs.items():
p = getattr(self, param)
- if p.expectedType is None or type(value) == p.expectedType or value is None:
- self._paramMap[getattr(self, param)] = value
- else:
+ if value is not None:
try:
- # Try and do "safe" conversions that don't lose information
- if p.expectedType == float:
- self._paramMap[getattr(self, param)] = float(value)
- # Python 3 unified long & int
- elif p.expectedType == int and type(value).__name__ == 'long':
- self._paramMap[getattr(self, param)] = value
- else:
- raise Exception(
- "Provided type {0} incompatible with type {1} for param {2}"
- .format(type(value), p.expectedType, p))
- except ValueError:
- raise Exception(("Failed to convert {0} to type {1} for param {2}"
- .format(type(value), p.expectedType, p)))
+ value = p.typeConverter(value)
+ except TypeError as e:
+ raise TypeError('Invalid param value given for param "%s". %s' % (p.name, e))
+ self._paramMap[p] = value
return self
def _setDefault(self, **kwargs):