aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/param/__init__.py
diff options
context:
space:
mode:
authorHolden Karau <holden@us.ibm.com>2016-01-06 10:43:03 -0800
committerJoseph K. Bradley <joseph@databricks.com>2016-01-06 10:43:03 -0800
commit3b29004d2439c03a7d9bfdf7c2edd757d3d8c240 (patch)
tree66fb557cad74f3e1507ecbcd1113bc913fac8fbc /python/pyspark/ml/param/__init__.py
parent9061e777fdcd5767718808e325e8953d484aa761 (diff)
downloadspark-3b29004d2439c03a7d9bfdf7c2edd757d3d8c240.tar.gz
spark-3b29004d2439c03a7d9bfdf7c2edd757d3d8c240.tar.bz2
spark-3b29004d2439c03a7d9bfdf7c2edd757d3d8c240.zip
[SPARK-7675][ML][PYSPARK] sparkml params type conversion
From JIRA: Currently, PySpark wrappers for spark.ml Scala classes are brittle when accepting Param types. E.g., Normalizer's "p" param cannot be set to "2" (an integer); it must be set to "2.0" (a float). Fixing this is not trivial since there does not appear to be a natural place to insert the conversion before Python wrappers call Java's Params setter method. A possible fix will be to include a method "_checkType" to PySpark's Param class which checks the type, prints an error if needed, and converts types when relevant (e.g., int to float, or scipy matrix to array). The Java wrapper method which copies params to Scala can call this method when available. This fix instead checks the types at set time since I think failing sooner is better, but I can switch it around to check at copy time if that would be better. So far this only converts int to float and other conversions (like scipymatrix to array) are left for the future. Author: Holden Karau <holden@us.ibm.com> Closes #9581 from holdenk/SPARK-7675-PySpark-sparkml-Params-type-conversion.
Diffstat (limited to 'python/pyspark/ml/param/__init__.py')
-rw-r--r--python/pyspark/ml/param/__init__.py22
1 files changed, 20 insertions, 2 deletions
diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py
index 35c9b776a3..92ce96aa3c 100644
--- a/python/pyspark/ml/param/__init__.py
+++ b/python/pyspark/ml/param/__init__.py
@@ -32,12 +32,13 @@ class Param(object):
.. versionadded:: 1.3.0
"""
- def __init__(self, parent, name, doc):
+ def __init__(self, parent, name, doc, expectedType=None):
if not isinstance(parent, Identifiable):
raise TypeError("Parent must be an Identifiable but got type %s." % type(parent))
self.parent = parent.uid
self.name = str(name)
self.doc = str(doc)
+ self.expectedType = expectedType
def __str__(self):
return str(self.parent) + "__" + self.name
@@ -247,7 +248,24 @@ class Params(Identifiable):
Sets user-supplied params.
"""
for param, value in kwargs.items():
- self._paramMap[getattr(self, param)] = value
+ p = getattr(self, param)
+ if p.expectedType is None or type(value) == p.expectedType or value is None:
+ self._paramMap[getattr(self, param)] = value
+ else:
+ try:
+ # Try and do "safe" conversions that don't lose information
+ if p.expectedType == float:
+ self._paramMap[getattr(self, param)] = float(value)
+ # Python 3 unified long & int
+ elif p.expectedType == int and type(value).__name__ == 'long':
+ self._paramMap[getattr(self, param)] = value
+ else:
+ raise Exception(
+ "Provided type {0} incompatible with type {1} for param {2}"
+ .format(type(value), p.expectedType, p))
+ except ValueError:
+ raise Exception(("Failed to convert {0} to type {1} for param {2}"
+ .format(type(value), p.expectedType, p)))
return self
def _setDefault(self, **kwargs):