aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/param/__init__.py
diff options
context:
space:
mode:
authorsethah <seth.hendrickson16@gmail.com>2016-03-23 11:20:44 -0700
committerJoseph K. Bradley <joseph@databricks.com>2016-03-23 11:20:44 -0700
commit30bdb5cbd9aec191cf15cdc83c3fee375c04c2b2 (patch)
tree4d48b42ebe347fc40d5deeb3a77996db0c30eea1 /python/pyspark/ml/param/__init__.py
parent48ee16d8012602c75d50aa2a85e26b7de3c48944 (diff)
downloadspark-30bdb5cbd9aec191cf15cdc83c3fee375c04c2b2.tar.gz
spark-30bdb5cbd9aec191cf15cdc83c3fee375c04c2b2.tar.bz2
spark-30bdb5cbd9aec191cf15cdc83c3fee375c04c2b2.zip
[SPARK-13068][PYSPARK][ML] Type conversion for Pyspark params
## What changes were proposed in this pull request? This patch adds type conversion functionality for parameters in Pyspark. A `typeConverter` field is added to the constructor of `Param` class. This argument is a function which converts values passed to this param to the appropriate type if possible. This is beneficial so that the params can fail at set time if they are given inappropriate values, but even more so because coherent error messages are now provided when Py4J cannot cast the python type to the appropriate Java type. This patch also adds a `TypeConverters` class with factory methods for common type conversions. Most of the changes involve adding these factory type converters to existing params. The previous solution to this issue, `expectedType`, is deprecated and can be removed in 2.1.0 as discussed on the Jira. ## How was this patch tested? Unit tests were added in python/pyspark/ml/tests.py to test parameter type conversion. These tests check that values that should be convertible are converted correctly, and that the appropriate errors are thrown when invalid values are provided. Author: sethah <seth.hendrickson16@gmail.com> Closes #11663 from sethah/SPARK-13068-tc.
Diffstat (limited to 'python/pyspark/ml/param/__init__.py')
-rw-r--r--python/pyspark/ml/param/__init__.py181
1 files changed, 163 insertions, 18 deletions
diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py
index c0f0a71eb6..a1265294a1 100644
--- a/python/pyspark/ml/param/__init__.py
+++ b/python/pyspark/ml/param/__init__.py
@@ -14,31 +14,47 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+import array
+import sys
+if sys.version > '3':
+ basestring = str
+ xrange = range
+ unicode = str
from abc import ABCMeta
import copy
+import numpy as np
+import warnings
from pyspark import since
from pyspark.ml.util import Identifiable
+from pyspark.mllib.linalg import DenseVector, Vector
-__all__ = ['Param', 'Params']
+__all__ = ['Param', 'Params', 'TypeConverters']
class Param(object):
"""
A param with self-contained documentation.
+ Note: `expectedType` is deprecated and will be removed in 2.1. Use typeConverter instead,
+ as a keyword argument.
+
.. versionadded:: 1.3.0
"""
- def __init__(self, parent, name, doc, expectedType=None):
+ def __init__(self, parent, name, doc, expectedType=None, typeConverter=None):
if not isinstance(parent, Identifiable):
raise TypeError("Parent must be an Identifiable but got type %s." % type(parent))
self.parent = parent.uid
self.name = str(name)
self.doc = str(doc)
self.expectedType = expectedType
+ if expectedType is not None:
+ warnings.warn("expectedType is deprecated and will be removed in 2.1. " +
+ "Use typeConverter instead, as a keyword argument.")
+ self.typeConverter = TypeConverters.identity if typeConverter is None else typeConverter
def _copy_new_parent(self, parent):
"""Copy the current param to a new parent, must be a dummy param."""
@@ -65,6 +81,146 @@ class Param(object):
return False
+class TypeConverters(object):
+ """
+ .. note:: DeveloperApi
+
+ Factory methods for common type conversion functions for `Param.typeConverter`.
+
+ .. versionadded:: 2.0.0
+ """
+
+ @staticmethod
+ def _is_numeric(value):
+ vtype = type(value)
+ return vtype in [int, float, np.float64, np.int64] or vtype.__name__ == 'long'
+
+ @staticmethod
+ def _is_integer(value):
+ return TypeConverters._is_numeric(value) and float(value).is_integer()
+
+ @staticmethod
+ def _can_convert_to_list(value):
+ vtype = type(value)
+ return vtype in [list, np.ndarray, tuple, xrange, array.array] or isinstance(value, Vector)
+
+ @staticmethod
+ def _can_convert_to_string(value):
+ vtype = type(value)
+ return isinstance(value, basestring) or vtype in [np.unicode_, np.string_, np.str_]
+
+ @staticmethod
+ def identity(value):
+ """
+ Dummy converter that just returns value.
+ """
+ return value
+
+ @staticmethod
+ def toList(value):
+ """
+ Convert a value to a list, if possible.
+ """
+ if type(value) == list:
+ return value
+ elif type(value) in [np.ndarray, tuple, xrange, array.array]:
+ return list(value)
+ elif isinstance(value, Vector):
+ return list(value.toArray())
+ else:
+ raise TypeError("Could not convert %s to list" % value)
+
+ @staticmethod
+ def toListFloat(value):
+ """
+ Convert a value to list of floats, if possible.
+ """
+ if TypeConverters._can_convert_to_list(value):
+ value = TypeConverters.toList(value)
+ if all(map(lambda v: TypeConverters._is_numeric(v), value)):
+ return [float(v) for v in value]
+ raise TypeError("Could not convert %s to list of floats" % value)
+
+ @staticmethod
+ def toListInt(value):
+ """
+ Convert a value to list of ints, if possible.
+ """
+ if TypeConverters._can_convert_to_list(value):
+ value = TypeConverters.toList(value)
+ if all(map(lambda v: TypeConverters._is_integer(v), value)):
+ return [int(v) for v in value]
+ raise TypeError("Could not convert %s to list of ints" % value)
+
+ @staticmethod
+ def toListString(value):
+ """
+ Convert a value to list of strings, if possible.
+ """
+ if TypeConverters._can_convert_to_list(value):
+ value = TypeConverters.toList(value)
+ if all(map(lambda v: TypeConverters._can_convert_to_string(v), value)):
+ return [TypeConverters.toString(v) for v in value]
+ raise TypeError("Could not convert %s to list of strings" % value)
+
+ @staticmethod
+ def toVector(value):
+ """
+ Convert a value to a MLlib Vector, if possible.
+ """
+ if isinstance(value, Vector):
+ return value
+ elif TypeConverters._can_convert_to_list(value):
+ value = TypeConverters.toList(value)
+ if all(map(lambda v: TypeConverters._is_numeric(v), value)):
+ return DenseVector(value)
+ raise TypeError("Could not convert %s to vector" % value)
+
+ @staticmethod
+ def toFloat(value):
+ """
+ Convert a value to a float, if possible.
+ """
+ if TypeConverters._is_numeric(value):
+ return float(value)
+ else:
+ raise TypeError("Could not convert %s to float" % value)
+
+ @staticmethod
+ def toInt(value):
+ """
+ Convert a value to an int, if possible.
+ """
+ if TypeConverters._is_integer(value):
+ return int(value)
+ else:
+ raise TypeError("Could not convert %s to int" % value)
+
+ @staticmethod
+ def toString(value):
+ """
+ Convert a value to a string, if possible.
+ """
+ if isinstance(value, basestring):
+ return value
+ elif type(value) in [np.string_, np.str_]:
+ return str(value)
+ elif type(value) == np.unicode_:
+ return unicode(value)
+ else:
+ raise TypeError("Could not convert %s to string type" % type(value))
+
+ @staticmethod
+ def toBoolean(value):
+ """
+ Convert a value to a boolean, if possible.
+ """
+ if type(value) == bool:
+ return value
+ else:
+ raise TypeError("Boolean Param requires value of type bool. Found %s." % type(value))
+
+
class Params(Identifiable):
"""
Components that take parameters. This also provides an internal
@@ -275,23 +431,12 @@ class Params(Identifiable):
"""
for param, value in kwargs.items():
p = getattr(self, param)
- if p.expectedType is None or type(value) == p.expectedType or value is None:
- self._paramMap[getattr(self, param)] = value
- else:
+ if value is not None:
try:
- # Try and do "safe" conversions that don't lose information
- if p.expectedType == float:
- self._paramMap[getattr(self, param)] = float(value)
- # Python 3 unified long & int
- elif p.expectedType == int and type(value).__name__ == 'long':
- self._paramMap[getattr(self, param)] = value
- else:
- raise Exception(
- "Provided type {0} incompatible with type {1} for param {2}"
- .format(type(value), p.expectedType, p))
- except ValueError:
- raise Exception(("Failed to convert {0} to type {1} for param {2}"
- .format(type(value), p.expectedType, p)))
+ value = p.typeConverter(value)
+ except TypeError as e:
+ raise TypeError('Invalid param value given for param "%s". %s' % (p.name, e))
+ self._paramMap[p] = value
return self
def _setDefault(self, **kwargs):