aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/param
diff options
context:
space:
mode:
authorsethah <seth.hendrickson16@gmail.com>2016-03-23 11:20:44 -0700
committerJoseph K. Bradley <joseph@databricks.com>2016-03-23 11:20:44 -0700
commit30bdb5cbd9aec191cf15cdc83c3fee375c04c2b2 (patch)
tree4d48b42ebe347fc40d5deeb3a77996db0c30eea1 /python/pyspark/ml/param
parent48ee16d8012602c75d50aa2a85e26b7de3c48944 (diff)
downloadspark-30bdb5cbd9aec191cf15cdc83c3fee375c04c2b2.tar.gz
spark-30bdb5cbd9aec191cf15cdc83c3fee375c04c2b2.tar.bz2
spark-30bdb5cbd9aec191cf15cdc83c3fee375c04c2b2.zip
[SPARK-13068][PYSPARK][ML] Type conversion for Pyspark params
## What changes were proposed in this pull request? This patch adds type conversion functionality for parameters in Pyspark. A `typeConverter` field is added to the constructor of `Param` class. This argument is a function which converts values passed to this param to the appropriate type if possible. This is beneficial so that the params can fail at set time if they are given inappropriate values, but even more so because coherent error messages are now provided when Py4J cannot cast the python type to the appropriate Java type. This patch also adds a `TypeConverters` class with factory methods for common type conversions. Most of the changes involve adding these factory type converters to existing params. The previous solution to this issue, `expectedType`, is deprecated and can be removed in 2.1.0 as discussed on the Jira. ## How was this patch tested? Unit tests were added in python/pyspark/ml/tests.py to test parameter type conversion. These tests check that values that should be convertible are converted correctly, and that the appropriate errors are thrown when invalid values are provided. Author: sethah <seth.hendrickson16@gmail.com> Closes #11663 from sethah/SPARK-13068-tc.
Diffstat (limited to 'python/pyspark/ml/param')
-rw-r--r--python/pyspark/ml/param/__init__.py181
-rw-r--r--python/pyspark/ml/param/_shared_params_code_gen.py91
-rw-r--r--python/pyspark/ml/param/shared.py58
3 files changed, 243 insertions, 87 deletions
diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py
index c0f0a71eb6..a1265294a1 100644
--- a/python/pyspark/ml/param/__init__.py
+++ b/python/pyspark/ml/param/__init__.py
@@ -14,31 +14,47 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+import array
+import sys
+if sys.version > '3':
+ basestring = str
+ xrange = range
+ unicode = str
from abc import ABCMeta
import copy
+import numpy as np
+import warnings
from pyspark import since
from pyspark.ml.util import Identifiable
+from pyspark.mllib.linalg import DenseVector, Vector
-__all__ = ['Param', 'Params']
+__all__ = ['Param', 'Params', 'TypeConverters']
class Param(object):
"""
A param with self-contained documentation.
+ Note: `expectedType` is deprecated and will be removed in 2.1. Use typeConverter instead,
+ as a keyword argument.
+
.. versionadded:: 1.3.0
"""
- def __init__(self, parent, name, doc, expectedType=None):
+ def __init__(self, parent, name, doc, expectedType=None, typeConverter=None):
if not isinstance(parent, Identifiable):
raise TypeError("Parent must be an Identifiable but got type %s." % type(parent))
self.parent = parent.uid
self.name = str(name)
self.doc = str(doc)
self.expectedType = expectedType
+ if expectedType is not None:
+ warnings.warn("expectedType is deprecated and will be removed in 2.1. " +
+ "Use typeConverter instead, as a keyword argument.")
+ self.typeConverter = TypeConverters.identity if typeConverter is None else typeConverter
def _copy_new_parent(self, parent):
"""Copy the current param to a new parent, must be a dummy param."""
@@ -65,6 +81,146 @@ class Param(object):
return False
+class TypeConverters(object):
+ """
+ .. note:: DeveloperApi
+
+ Factory methods for common type conversion functions for `Param.typeConverter`.
+
+ .. versionadded:: 2.0.0
+ """
+
+ @staticmethod
+ def _is_numeric(value):
+ vtype = type(value)
+ return vtype in [int, float, np.float64, np.int64] or vtype.__name__ == 'long'
+
+ @staticmethod
+ def _is_integer(value):
+ return TypeConverters._is_numeric(value) and float(value).is_integer()
+
+ @staticmethod
+ def _can_convert_to_list(value):
+ vtype = type(value)
+ return vtype in [list, np.ndarray, tuple, xrange, array.array] or isinstance(value, Vector)
+
+ @staticmethod
+ def _can_convert_to_string(value):
+ vtype = type(value)
+ return isinstance(value, basestring) or vtype in [np.unicode_, np.string_, np.str_]
+
+ @staticmethod
+ def identity(value):
+ """
+ Dummy converter that just returns value.
+ """
+ return value
+
+ @staticmethod
+ def toList(value):
+ """
+ Convert a value to a list, if possible.
+ """
+ if type(value) == list:
+ return value
+ elif type(value) in [np.ndarray, tuple, xrange, array.array]:
+ return list(value)
+ elif isinstance(value, Vector):
+ return list(value.toArray())
+ else:
+ raise TypeError("Could not convert %s to list" % value)
+
+ @staticmethod
+ def toListFloat(value):
+ """
+ Convert a value to list of floats, if possible.
+ """
+ if TypeConverters._can_convert_to_list(value):
+ value = TypeConverters.toList(value)
+ if all(map(lambda v: TypeConverters._is_numeric(v), value)):
+ return [float(v) for v in value]
+ raise TypeError("Could not convert %s to list of floats" % value)
+
+ @staticmethod
+ def toListInt(value):
+ """
+ Convert a value to list of ints, if possible.
+ """
+ if TypeConverters._can_convert_to_list(value):
+ value = TypeConverters.toList(value)
+ if all(map(lambda v: TypeConverters._is_integer(v), value)):
+ return [int(v) for v in value]
+ raise TypeError("Could not convert %s to list of ints" % value)
+
+ @staticmethod
+ def toListString(value):
+ """
+ Convert a value to list of strings, if possible.
+ """
+ if TypeConverters._can_convert_to_list(value):
+ value = TypeConverters.toList(value)
+ if all(map(lambda v: TypeConverters._can_convert_to_string(v), value)):
+ return [TypeConverters.toString(v) for v in value]
+ raise TypeError("Could not convert %s to list of strings" % value)
+
+ @staticmethod
+ def toVector(value):
+ """
+ Convert a value to a MLlib Vector, if possible.
+ """
+ if isinstance(value, Vector):
+ return value
+ elif TypeConverters._can_convert_to_list(value):
+ value = TypeConverters.toList(value)
+ if all(map(lambda v: TypeConverters._is_numeric(v), value)):
+ return DenseVector(value)
+ raise TypeError("Could not convert %s to vector" % value)
+
+ @staticmethod
+ def toFloat(value):
+ """
+ Convert a value to a float, if possible.
+ """
+ if TypeConverters._is_numeric(value):
+ return float(value)
+ else:
+ raise TypeError("Could not convert %s to float" % value)
+
+ @staticmethod
+ def toInt(value):
+ """
+ Convert a value to an int, if possible.
+ """
+ if TypeConverters._is_integer(value):
+ return int(value)
+ else:
+ raise TypeError("Could not convert %s to int" % value)
+
+ @staticmethod
+ def toString(value):
+ """
+ Convert a value to a string, if possible.
+ """
+ if isinstance(value, basestring):
+ return value
+ elif type(value) in [np.string_, np.str_]:
+ return str(value)
+ elif type(value) == np.unicode_:
+ return unicode(value)
+ else:
+ raise TypeError("Could not convert %s to string type" % type(value))
+
+ @staticmethod
+ def toBoolean(value):
+ """
+ Convert a value to a boolean, if possible.
+ """
+ if type(value) == bool:
+ return value
+ else:
+ raise TypeError("Boolean Param requires value of type bool. Found %s." % type(value))
+
+
class Params(Identifiable):
"""
Components that take parameters. This also provides an internal
@@ -275,23 +431,12 @@ class Params(Identifiable):
"""
for param, value in kwargs.items():
p = getattr(self, param)
- if p.expectedType is None or type(value) == p.expectedType or value is None:
- self._paramMap[getattr(self, param)] = value
- else:
+ if value is not None:
try:
- # Try and do "safe" conversions that don't lose information
- if p.expectedType == float:
- self._paramMap[getattr(self, param)] = float(value)
- # Python 3 unified long & int
- elif p.expectedType == int and type(value).__name__ == 'long':
- self._paramMap[getattr(self, param)] = value
- else:
- raise Exception(
- "Provided type {0} incompatible with type {1} for param {2}"
- .format(type(value), p.expectedType, p))
- except ValueError:
- raise Exception(("Failed to convert {0} to type {1} for param {2}"
- .format(type(value), p.expectedType, p)))
+ value = p.typeConverter(value)
+ except TypeError as e:
+ raise TypeError('Invalid param value given for param "%s". %s' % (p.name, e))
+ self._paramMap[p] = value
return self
def _setDefault(self, **kwargs):
diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py
index 5e297b8214..7dd2937db7 100644
--- a/python/pyspark/ml/param/_shared_params_code_gen.py
+++ b/python/pyspark/ml/param/_shared_params_code_gen.py
@@ -38,7 +38,7 @@ header = """#
# python _shared_params_code_gen.py > shared.py
-def _gen_param_header(name, doc, defaultValueStr, expectedType):
+def _gen_param_header(name, doc, defaultValueStr, typeConverter):
"""
Generates the header part for shared variables
@@ -50,7 +50,7 @@ def _gen_param_header(name, doc, defaultValueStr, expectedType):
Mixin for param $name: $doc
"""
- $name = Param(Params._dummy(), "$name", "$doc", $expectedType)
+ $name = Param(Params._dummy(), "$name", "$doc", typeConverter=$typeConverter)
def __init__(self):
super(Has$Name, self).__init__()'''
@@ -60,15 +60,14 @@ def _gen_param_header(name, doc, defaultValueStr, expectedType):
self._setDefault($name=$defaultValueStr)'''
Name = name[0].upper() + name[1:]
- expectedTypeName = str(expectedType)
- if expectedType is not None:
- expectedTypeName = expectedType.__name__
+ if typeConverter is None:
+ typeConverter = str(None)
return template \
.replace("$name", name) \
.replace("$Name", Name) \
.replace("$doc", doc) \
.replace("$defaultValueStr", str(defaultValueStr)) \
- .replace("$expectedType", expectedTypeName)
+ .replace("$typeConverter", typeConverter)
def _gen_param_code(name, doc, defaultValueStr):
@@ -105,64 +104,73 @@ def _gen_param_code(name, doc, defaultValueStr):
if __name__ == "__main__":
print(header)
print("\n# DO NOT MODIFY THIS FILE! It was generated by _shared_params_code_gen.py.\n")
- print("from pyspark.ml.param import Param, Params\n\n")
+ print("from pyspark.ml.param import *\n\n")
shared = [
- ("maxIter", "max number of iterations (>= 0).", None, int),
- ("regParam", "regularization parameter (>= 0).", None, float),
- ("featuresCol", "features column name.", "'features'", str),
- ("labelCol", "label column name.", "'label'", str),
- ("predictionCol", "prediction column name.", "'prediction'", str),
+ ("maxIter", "max number of iterations (>= 0).", None, "TypeConverters.toInt"),
+ ("regParam", "regularization parameter (>= 0).", None, "TypeConverters.toFloat"),
+ ("featuresCol", "features column name.", "'features'", "TypeConverters.toString"),
+ ("labelCol", "label column name.", "'label'", "TypeConverters.toString"),
+ ("predictionCol", "prediction column name.", "'prediction'", "TypeConverters.toString"),
("probabilityCol", "Column name for predicted class conditional probabilities. " +
"Note: Not all models output well-calibrated probability estimates! These probabilities " +
- "should be treated as confidences, not precise probabilities.", "'probability'", str),
+ "should be treated as confidences, not precise probabilities.", "'probability'",
+ "TypeConverters.toString"),
("rawPredictionCol", "raw prediction (a.k.a. confidence) column name.", "'rawPrediction'",
- str),
- ("inputCol", "input column name.", None, str),
- ("inputCols", "input column names.", None, None),
- ("outputCol", "output column name.", "self.uid + '__output'", str),
- ("numFeatures", "number of features.", None, int),
+ "TypeConverters.toString"),
+ ("inputCol", "input column name.", None, "TypeConverters.toString"),
+ ("inputCols", "input column names.", None, "TypeConverters.toListString"),
+ ("outputCol", "output column name.", "self.uid + '__output'", "TypeConverters.toString"),
+ ("numFeatures", "number of features.", None, "TypeConverters.toInt"),
("checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). " +
- "E.g. 10 means that the cache will get checkpointed every 10 iterations.", None, int),
- ("seed", "random seed.", "hash(type(self).__name__)", int),
- ("tol", "the convergence tolerance for iterative algorithms.", None, float),
- ("stepSize", "Step size to be used for each iteration of optimization.", None, float),
+ "E.g. 10 means that the cache will get checkpointed every 10 iterations.", None,
+ "TypeConverters.toInt"),
+ ("seed", "random seed.", "hash(type(self).__name__)", "TypeConverters.toInt"),
+ ("tol", "the convergence tolerance for iterative algorithms.", None,
+ "TypeConverters.toFloat"),
+ ("stepSize", "Step size to be used for each iteration of optimization.", None,
+ "TypeConverters.toFloat"),
("handleInvalid", "how to handle invalid entries. Options are skip (which will filter " +
"out rows with bad values), or error (which will throw an errror). More options may be " +
- "added later.", None, str),
+ "added later.", None, "TypeConverters.toBoolean"),
("elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " +
- "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", "0.0", float),
- ("fitIntercept", "whether to fit an intercept term.", "True", bool),
+ "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", "0.0",
+ "TypeConverters.toFloat"),
+ ("fitIntercept", "whether to fit an intercept term.", "True", "TypeConverters.toBoolean"),
("standardization", "whether to standardize the training features before fitting the " +
- "model.", "True", bool),
+ "model.", "True", "TypeConverters.toBoolean"),
("thresholds", "Thresholds in multi-class classification to adjust the probability of " +
"predicting each class. Array must have length equal to the number of classes, with " +
"values >= 0. The class with largest value p/t is predicted, where p is the original " +
- "probability of that class and t is the class' threshold.", None, None),
+ "probability of that class and t is the class' threshold.", None,
+ "TypeConverters.toListFloat"),
("weightCol", "weight column name. If this is not set or empty, we treat " +
- "all instance weights as 1.0.", None, str),
+ "all instance weights as 1.0.", None, "TypeConverters.toString"),
("solver", "the solver algorithm for optimization. If this is not set or empty, " +
- "default value is 'auto'.", "'auto'", str)]
+ "default value is 'auto'.", "'auto'", "TypeConverters.toString")]
code = []
- for name, doc, defaultValueStr, expectedType in shared:
- param_code = _gen_param_header(name, doc, defaultValueStr, expectedType)
+ for name, doc, defaultValueStr, typeConverter in shared:
+ param_code = _gen_param_header(name, doc, defaultValueStr, typeConverter)
code.append(param_code + "\n" + _gen_param_code(name, doc, defaultValueStr))
decisionTreeParams = [
("maxDepth", "Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; " +
- "depth 1 means 1 internal node + 2 leaf nodes."),
+ "depth 1 means 1 internal node + 2 leaf nodes.", "TypeConverters.toInt"),
("maxBins", "Max number of bins for" +
" discretizing continuous features. Must be >=2 and >= number of categories for any" +
- " categorical feature."),
+ " categorical feature.", "TypeConverters.toInt"),
("minInstancesPerNode", "Minimum number of instances each child must have after split. " +
"If a split causes the left or right child to have fewer than minInstancesPerNode, the " +
- "split will be discarded as invalid. Should be >= 1."),
- ("minInfoGain", "Minimum information gain for a split to be considered at a tree node."),
- ("maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation."),
+ "split will be discarded as invalid. Should be >= 1.", "TypeConverters.toInt"),
+ ("minInfoGain", "Minimum information gain for a split to be considered at a tree node.",
+ "TypeConverters.toFloat"),
+ ("maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.",
+ "TypeConverters.toInt"),
("cacheNodeIds", "If false, the algorithm will pass trees to executors to match " +
"instances with nodes. If true, the algorithm will cache node IDs for each instance. " +
"Caching can speed up training of deeper trees. Users can set how often should the " +
- "cache be checkpointed or disable it by setting checkpointInterval.")]
+ "cache be checkpointed or disable it by setting checkpointInterval.",
+ "TypeConverters.toBoolean")]
decisionTreeCode = '''class DecisionTreeParams(Params):
"""
@@ -175,9 +183,12 @@ if __name__ == "__main__":
super(DecisionTreeParams, self).__init__()'''
dtParamMethods = ""
dummyPlaceholders = ""
- paramTemplate = """$name = Param($owner, "$name", "$doc")"""
- for name, doc in decisionTreeParams:
- variable = paramTemplate.replace("$name", name).replace("$doc", doc)
+ paramTemplate = """$name = Param($owner, "$name", "$doc", typeConverter=$typeConverterStr)"""
+ for name, doc, typeConverterStr in decisionTreeParams:
+ if typeConverterStr is None:
+ typeConverterStr = str(None)
+ variable = paramTemplate.replace("$name", name).replace("$doc", doc) \
+ .replace("$typeConverterStr", typeConverterStr)
dummyPlaceholders += variable.replace("$owner", "Params._dummy()") + "\n "
dtParamMethods += _gen_param_code(name, doc, None) + "\n"
code.append(decisionTreeCode.replace("$dummyPlaceHolders", dummyPlaceholders) + "\n" +
diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py
index db4a8a54d4..83fbd59039 100644
--- a/python/pyspark/ml/param/shared.py
+++ b/python/pyspark/ml/param/shared.py
@@ -17,7 +17,7 @@
# DO NOT MODIFY THIS FILE! It was generated by _shared_params_code_gen.py.
-from pyspark.ml.param import Param, Params
+from pyspark.ml.param import *
class HasMaxIter(Params):
@@ -25,7 +25,7 @@ class HasMaxIter(Params):
Mixin for param maxIter: max number of iterations (>= 0).
"""
- maxIter = Param(Params._dummy(), "maxIter", "max number of iterations (>= 0).", int)
+ maxIter = Param(Params._dummy(), "maxIter", "max number of iterations (>= 0).", typeConverter=TypeConverters.toInt)
def __init__(self):
super(HasMaxIter, self).__init__()
@@ -49,7 +49,7 @@ class HasRegParam(Params):
Mixin for param regParam: regularization parameter (>= 0).
"""
- regParam = Param(Params._dummy(), "regParam", "regularization parameter (>= 0).", float)
+ regParam = Param(Params._dummy(), "regParam", "regularization parameter (>= 0).", typeConverter=TypeConverters.toFloat)
def __init__(self):
super(HasRegParam, self).__init__()
@@ -73,7 +73,7 @@ class HasFeaturesCol(Params):
Mixin for param featuresCol: features column name.
"""
- featuresCol = Param(Params._dummy(), "featuresCol", "features column name.", str)
+ featuresCol = Param(Params._dummy(), "featuresCol", "features column name.", typeConverter=TypeConverters.toString)
def __init__(self):
super(HasFeaturesCol, self).__init__()
@@ -98,7 +98,7 @@ class HasLabelCol(Params):
Mixin for param labelCol: label column name.
"""
- labelCol = Param(Params._dummy(), "labelCol", "label column name.", str)
+ labelCol = Param(Params._dummy(), "labelCol", "label column name.", typeConverter=TypeConverters.toString)
def __init__(self):
super(HasLabelCol, self).__init__()
@@ -123,7 +123,7 @@ class HasPredictionCol(Params):
Mixin for param predictionCol: prediction column name.
"""
- predictionCol = Param(Params._dummy(), "predictionCol", "prediction column name.", str)
+ predictionCol = Param(Params._dummy(), "predictionCol", "prediction column name.", typeConverter=TypeConverters.toString)
def __init__(self):
super(HasPredictionCol, self).__init__()
@@ -148,7 +148,7 @@ class HasProbabilityCol(Params):
Mixin for param probabilityCol: Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.
"""
- probabilityCol = Param(Params._dummy(), "probabilityCol", "Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.", str)
+ probabilityCol = Param(Params._dummy(), "probabilityCol", "Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.", typeConverter=TypeConverters.toString)
def __init__(self):
super(HasProbabilityCol, self).__init__()
@@ -173,7 +173,7 @@ class HasRawPredictionCol(Params):
Mixin for param rawPredictionCol: raw prediction (a.k.a. confidence) column name.
"""
- rawPredictionCol = Param(Params._dummy(), "rawPredictionCol", "raw prediction (a.k.a. confidence) column name.", str)
+ rawPredictionCol = Param(Params._dummy(), "rawPredictionCol", "raw prediction (a.k.a. confidence) column name.", typeConverter=TypeConverters.toString)
def __init__(self):
super(HasRawPredictionCol, self).__init__()
@@ -198,7 +198,7 @@ class HasInputCol(Params):
Mixin for param inputCol: input column name.
"""
- inputCol = Param(Params._dummy(), "inputCol", "input column name.", str)
+ inputCol = Param(Params._dummy(), "inputCol", "input column name.", typeConverter=TypeConverters.toString)
def __init__(self):
super(HasInputCol, self).__init__()
@@ -222,7 +222,7 @@ class HasInputCols(Params):
Mixin for param inputCols: input column names.
"""
- inputCols = Param(Params._dummy(), "inputCols", "input column names.", None)
+ inputCols = Param(Params._dummy(), "inputCols", "input column names.", typeConverter=TypeConverters.toListString)
def __init__(self):
super(HasInputCols, self).__init__()
@@ -246,7 +246,7 @@ class HasOutputCol(Params):
Mixin for param outputCol: output column name.
"""
- outputCol = Param(Params._dummy(), "outputCol", "output column name.", str)
+ outputCol = Param(Params._dummy(), "outputCol", "output column name.", typeConverter=TypeConverters.toString)
def __init__(self):
super(HasOutputCol, self).__init__()
@@ -271,7 +271,7 @@ class HasNumFeatures(Params):
Mixin for param numFeatures: number of features.
"""
- numFeatures = Param(Params._dummy(), "numFeatures", "number of features.", int)
+ numFeatures = Param(Params._dummy(), "numFeatures", "number of features.", typeConverter=TypeConverters.toInt)
def __init__(self):
super(HasNumFeatures, self).__init__()
@@ -295,7 +295,7 @@ class HasCheckpointInterval(Params):
Mixin for param checkpointInterval: set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations.
"""
- checkpointInterval = Param(Params._dummy(), "checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations.", int)
+ checkpointInterval = Param(Params._dummy(), "checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations.", typeConverter=TypeConverters.toInt)
def __init__(self):
super(HasCheckpointInterval, self).__init__()
@@ -319,7 +319,7 @@ class HasSeed(Params):
Mixin for param seed: random seed.
"""
- seed = Param(Params._dummy(), "seed", "random seed.", int)
+ seed = Param(Params._dummy(), "seed", "random seed.", typeConverter=TypeConverters.toInt)
def __init__(self):
super(HasSeed, self).__init__()
@@ -344,7 +344,7 @@ class HasTol(Params):
Mixin for param tol: the convergence tolerance for iterative algorithms.
"""
- tol = Param(Params._dummy(), "tol", "the convergence tolerance for iterative algorithms.", float)
+ tol = Param(Params._dummy(), "tol", "the convergence tolerance for iterative algorithms.", typeConverter=TypeConverters.toFloat)
def __init__(self):
super(HasTol, self).__init__()
@@ -368,7 +368,7 @@ class HasStepSize(Params):
Mixin for param stepSize: Step size to be used for each iteration of optimization.
"""
- stepSize = Param(Params._dummy(), "stepSize", "Step size to be used for each iteration of optimization.", float)
+ stepSize = Param(Params._dummy(), "stepSize", "Step size to be used for each iteration of optimization.", typeConverter=TypeConverters.toFloat)
def __init__(self):
super(HasStepSize, self).__init__()
@@ -392,7 +392,7 @@ class HasHandleInvalid(Params):
Mixin for param handleInvalid: how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later.
"""
- handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later.", str)
+ handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later.", typeConverter=TypeConverters.toBoolean)
def __init__(self):
super(HasHandleInvalid, self).__init__()
@@ -416,7 +416,7 @@ class HasElasticNetParam(Params):
Mixin for param elasticNetParam: the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.
"""
- elasticNetParam = Param(Params._dummy(), "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", float)
+ elasticNetParam = Param(Params._dummy(), "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", typeConverter=TypeConverters.toFloat)
def __init__(self):
super(HasElasticNetParam, self).__init__()
@@ -441,7 +441,7 @@ class HasFitIntercept(Params):
Mixin for param fitIntercept: whether to fit an intercept term.
"""
- fitIntercept = Param(Params._dummy(), "fitIntercept", "whether to fit an intercept term.", bool)
+ fitIntercept = Param(Params._dummy(), "fitIntercept", "whether to fit an intercept term.", typeConverter=TypeConverters.toBoolean)
def __init__(self):
super(HasFitIntercept, self).__init__()
@@ -466,7 +466,7 @@ class HasStandardization(Params):
Mixin for param standardization: whether to standardize the training features before fitting the model.
"""
- standardization = Param(Params._dummy(), "standardization", "whether to standardize the training features before fitting the model.", bool)
+ standardization = Param(Params._dummy(), "standardization", "whether to standardize the training features before fitting the model.", typeConverter=TypeConverters.toBoolean)
def __init__(self):
super(HasStandardization, self).__init__()
@@ -491,7 +491,7 @@ class HasThresholds(Params):
Mixin for param thresholds: Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.
"""
- thresholds = Param(Params._dummy(), "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.", None)
+ thresholds = Param(Params._dummy(), "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.", typeConverter=TypeConverters.toListFloat)
def __init__(self):
super(HasThresholds, self).__init__()
@@ -515,7 +515,7 @@ class HasWeightCol(Params):
Mixin for param weightCol: weight column name. If this is not set or empty, we treat all instance weights as 1.0.
"""
- weightCol = Param(Params._dummy(), "weightCol", "weight column name. If this is not set or empty, we treat all instance weights as 1.0.", str)
+ weightCol = Param(Params._dummy(), "weightCol", "weight column name. If this is not set or empty, we treat all instance weights as 1.0.", typeConverter=TypeConverters.toString)
def __init__(self):
super(HasWeightCol, self).__init__()
@@ -539,7 +539,7 @@ class HasSolver(Params):
Mixin for param solver: the solver algorithm for optimization. If this is not set or empty, default value is 'auto'.
"""
- solver = Param(Params._dummy(), "solver", "the solver algorithm for optimization. If this is not set or empty, default value is 'auto'.", str)
+ solver = Param(Params._dummy(), "solver", "the solver algorithm for optimization. If this is not set or empty, default value is 'auto'.", typeConverter=TypeConverters.toString)
def __init__(self):
super(HasSolver, self).__init__()
@@ -564,12 +564,12 @@ class DecisionTreeParams(Params):
Mixin for Decision Tree parameters.
"""
- maxDepth = Param(Params._dummy(), "maxDepth", "Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.")
- maxBins = Param(Params._dummy(), "maxBins", "Max number of bins for discretizing continuous features. Must be >=2 and >= number of categories for any categorical feature.")
- minInstancesPerNode = Param(Params._dummy(), "minInstancesPerNode", "Minimum number of instances each child must have after split. If a split causes the left or right child to have fewer than minInstancesPerNode, the split will be discarded as invalid. Should be >= 1.")
- minInfoGain = Param(Params._dummy(), "minInfoGain", "Minimum information gain for a split to be considered at a tree node.")
- maxMemoryInMB = Param(Params._dummy(), "maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.")
- cacheNodeIds = Param(Params._dummy(), "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. Users can set how often should the cache be checkpointed or disable it by setting checkpointInterval.")
+ maxDepth = Param(Params._dummy(), "maxDepth", "Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.", typeConverter=TypeConverters.toInt)
+ maxBins = Param(Params._dummy(), "maxBins", "Max number of bins for discretizing continuous features. Must be >=2 and >= number of categories for any categorical feature.", typeConverter=TypeConverters.toInt)
+ minInstancesPerNode = Param(Params._dummy(), "minInstancesPerNode", "Minimum number of instances each child must have after split. If a split causes the left or right child to have fewer than minInstancesPerNode, the split will be discarded as invalid. Should be >= 1.", typeConverter=TypeConverters.toInt)
+ minInfoGain = Param(Params._dummy(), "minInfoGain", "Minimum information gain for a split to be considered at a tree node.", typeConverter=TypeConverters.toFloat)
+ maxMemoryInMB = Param(Params._dummy(), "maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.", typeConverter=TypeConverters.toInt)
+ cacheNodeIds = Param(Params._dummy(), "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. Users can set how often should the cache be checkpointed or disable it by setting checkpointInterval.", typeConverter=TypeConverters.toBoolean)
def __init__(self):