aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/param/__init__.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/ml/param/__init__.py')
-rw-r--r--python/pyspark/ml/param/__init__.py22
1 files changed, 20 insertions, 2 deletions
diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py
index 35c9b776a3..92ce96aa3c 100644
--- a/python/pyspark/ml/param/__init__.py
+++ b/python/pyspark/ml/param/__init__.py
@@ -32,12 +32,13 @@ class Param(object):
.. versionadded:: 1.3.0
"""
- def __init__(self, parent, name, doc):
+ def __init__(self, parent, name, doc, expectedType=None):
if not isinstance(parent, Identifiable):
raise TypeError("Parent must be an Identifiable but got type %s." % type(parent))
self.parent = parent.uid
self.name = str(name)
self.doc = str(doc)
+ self.expectedType = expectedType
def __str__(self):
return str(self.parent) + "__" + self.name
@@ -247,7 +248,24 @@ class Params(Identifiable):
Sets user-supplied params.
"""
for param, value in kwargs.items():
- self._paramMap[getattr(self, param)] = value
+ p = getattr(self, param)
+ if p.expectedType is None or type(value) == p.expectedType or value is None:
+ self._paramMap[getattr(self, param)] = value
+ else:
+ try:
+ # Try and do "safe" conversions that don't lose information
+ if p.expectedType == float:
+ self._paramMap[getattr(self, param)] = float(value)
+ # Python 3 unified long & int
+ elif p.expectedType == int and type(value).__name__ == 'long':
+ self._paramMap[getattr(self, param)] = value
+ else:
+ raise Exception(
+ "Provided type {0} incompatible with type {1} for param {2}"
+ .format(type(value), p.expectedType, p))
+ except ValueError:
+ raise Exception(("Failed to convert {0} to type {1} for param {2}"
+ .format(type(value), p.expectedType, p)))
return self
def _setDefault(self, **kwargs):