Diffstat (limited to 'python/pyspark/ml/param/_shared_params_code_gen.py')
-rw-r--r--  python/pyspark/ml/param/_shared_params_code_gen.py | 91
1 file changed, 51 insertions(+), 40 deletions(-)
diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py
index 5e297b8214..7dd2937db7 100644
--- a/python/pyspark/ml/param/_shared_params_code_gen.py
+++ b/python/pyspark/ml/param/_shared_params_code_gen.py
@@ -38,7 +38,7 @@ header = """#
# python _shared_params_code_gen.py > shared.py
-def _gen_param_header(name, doc, defaultValueStr, expectedType):
+def _gen_param_header(name, doc, defaultValueStr, typeConverter):
"""
Generates the header part for shared variables
@@ -50,7 +50,7 @@ def _gen_param_header(name, doc, defaultValueStr, expectedType):
Mixin for param $name: $doc
"""
- $name = Param(Params._dummy(), "$name", "$doc", $expectedType)
+ $name = Param(Params._dummy(), "$name", "$doc", typeConverter=$typeConverter)
def __init__(self):
super(Has$Name, self).__init__()'''
@@ -60,15 +60,14 @@ def _gen_param_header(name, doc, defaultValueStr, expectedType):
self._setDefault($name=$defaultValueStr)'''
Name = name[0].upper() + name[1:]
- expectedTypeName = str(expectedType)
- if expectedType is not None:
- expectedTypeName = expectedType.__name__
+ if typeConverter is None:
+ typeConverter = str(None)
return template \
.replace("$name", name) \
.replace("$Name", Name) \
.replace("$doc", doc) \
.replace("$defaultValueStr", str(defaultValueStr)) \
- .replace("$expectedType", expectedTypeName)
+ .replace("$typeConverter", typeConverter)
def _gen_param_code(name, doc, defaultValueStr):
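For reference, a minimal sketch of what the updated template expands to for the maxIter entry (shape inferred from the template in this hunk; the actual generated shared.py may differ slightly, and no _setDefault line is emitted because maxIter has no default):

class HasMaxIter(Params):
    """
    Mixin for param maxIter: max number of iterations (>= 0).
    """

    # The converter string from the shared list is substituted verbatim here.
    maxIter = Param(Params._dummy(), "maxIter", "max number of iterations (>= 0).",
                    typeConverter=TypeConverters.toInt)

    def __init__(self):
        super(HasMaxIter, self).__init__()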
@@ -105,64 +104,73 @@ def _gen_param_code(name, doc, defaultValueStr):
if __name__ == "__main__":
print(header)
print("\n# DO NOT MODIFY THIS FILE! It was generated by _shared_params_code_gen.py.\n")
- print("from pyspark.ml.param import Param, Params\n\n")
+ print("from pyspark.ml.param import *\n\n")
shared = [
- ("maxIter", "max number of iterations (>= 0).", None, int),
- ("regParam", "regularization parameter (>= 0).", None, float),
- ("featuresCol", "features column name.", "'features'", str),
- ("labelCol", "label column name.", "'label'", str),
- ("predictionCol", "prediction column name.", "'prediction'", str),
+ ("maxIter", "max number of iterations (>= 0).", None, "TypeConverters.toInt"),
+ ("regParam", "regularization parameter (>= 0).", None, "TypeConverters.toFloat"),
+ ("featuresCol", "features column name.", "'features'", "TypeConverters.toString"),
+ ("labelCol", "label column name.", "'label'", "TypeConverters.toString"),
+ ("predictionCol", "prediction column name.", "'prediction'", "TypeConverters.toString"),
("probabilityCol", "Column name for predicted class conditional probabilities. " +
"Note: Not all models output well-calibrated probability estimates! These probabilities " +
- "should be treated as confidences, not precise probabilities.", "'probability'", str),
+ "should be treated as confidences, not precise probabilities.", "'probability'",
+ "TypeConverters.toString"),
("rawPredictionCol", "raw prediction (a.k.a. confidence) column name.", "'rawPrediction'",
- str),
- ("inputCol", "input column name.", None, str),
- ("inputCols", "input column names.", None, None),
- ("outputCol", "output column name.", "self.uid + '__output'", str),
- ("numFeatures", "number of features.", None, int),
+ "TypeConverters.toString"),
+ ("inputCol", "input column name.", None, "TypeConverters.toString"),
+ ("inputCols", "input column names.", None, "TypeConverters.toListString"),
+ ("outputCol", "output column name.", "self.uid + '__output'", "TypeConverters.toString"),
+ ("numFeatures", "number of features.", None, "TypeConverters.toInt"),
("checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). " +
- "E.g. 10 means that the cache will get checkpointed every 10 iterations.", None, int),
- ("seed", "random seed.", "hash(type(self).__name__)", int),
- ("tol", "the convergence tolerance for iterative algorithms.", None, float),
- ("stepSize", "Step size to be used for each iteration of optimization.", None, float),
+ "E.g. 10 means that the cache will get checkpointed every 10 iterations.", None,
+ "TypeConverters.toInt"),
+ ("seed", "random seed.", "hash(type(self).__name__)", "TypeConverters.toInt"),
+ ("tol", "the convergence tolerance for iterative algorithms.", None,
+ "TypeConverters.toFloat"),
+ ("stepSize", "Step size to be used for each iteration of optimization.", None,
+ "TypeConverters.toFloat"),
("handleInvalid", "how to handle invalid entries. Options are skip (which will filter " +
"out rows with bad values), or error (which will throw an errror). More options may be " +
- "added later.", None, str),
+ "added later.", None, "TypeConverters.toBoolean"),
("elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " +
- "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", "0.0", float),
- ("fitIntercept", "whether to fit an intercept term.", "True", bool),
+ "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", "0.0",
+ "TypeConverters.toFloat"),
+ ("fitIntercept", "whether to fit an intercept term.", "True", "TypeConverters.toBoolean"),
("standardization", "whether to standardize the training features before fitting the " +
- "model.", "True", bool),
+ "model.", "True", "TypeConverters.toBoolean"),
("thresholds", "Thresholds in multi-class classification to adjust the probability of " +
"predicting each class. Array must have length equal to the number of classes, with " +
"values >= 0. The class with largest value p/t is predicted, where p is the original " +
- "probability of that class and t is the class' threshold.", None, None),
+ "probability of that class and t is the class' threshold.", None,
+ "TypeConverters.toListFloat"),
("weightCol", "weight column name. If this is not set or empty, we treat " +
- "all instance weights as 1.0.", None, str),
+ "all instance weights as 1.0.", None, "TypeConverters.toString"),
("solver", "the solver algorithm for optimization. If this is not set or empty, " +
- "default value is 'auto'.", "'auto'", str)]
+ "default value is 'auto'.", "'auto'", "TypeConverters.toString")]
code = []
- for name, doc, defaultValueStr, expectedType in shared:
- param_code = _gen_param_header(name, doc, defaultValueStr, expectedType)
+ for name, doc, defaultValueStr, typeConverter in shared:
+ param_code = _gen_param_header(name, doc, defaultValueStr, typeConverter)
code.append(param_code + "\n" + _gen_param_code(name, doc, defaultValueStr))
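The converter names in the shared list above are plain strings that are pasted verbatim into the generated source, so they are only resolved when shared.py is imported. A rough illustration of the intended runtime effect, assuming the TypeConverters functions coerce compatible values and raise TypeError otherwise (their implementation is not shown in this diff):

from pyspark.ml.param import TypeConverters

TypeConverters.toInt(10.0)            # assumed to return 10
TypeConverters.toListFloat([1, 2])    # assumed to return [1.0, 2.0]
TypeConverters.toInt("not a number")  # assumed to raise TypeError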
decisionTreeParams = [
("maxDepth", "Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; " +
- "depth 1 means 1 internal node + 2 leaf nodes."),
+ "depth 1 means 1 internal node + 2 leaf nodes.", "TypeConverters.toInt"),
("maxBins", "Max number of bins for" +
" discretizing continuous features. Must be >=2 and >= number of categories for any" +
- " categorical feature."),
+ " categorical feature.", "TypeConverters.toInt"),
("minInstancesPerNode", "Minimum number of instances each child must have after split. " +
"If a split causes the left or right child to have fewer than minInstancesPerNode, the " +
- "split will be discarded as invalid. Should be >= 1."),
- ("minInfoGain", "Minimum information gain for a split to be considered at a tree node."),
- ("maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation."),
+ "split will be discarded as invalid. Should be >= 1.", "TypeConverters.toInt"),
+ ("minInfoGain", "Minimum information gain for a split to be considered at a tree node.",
+ "TypeConverters.toFloat"),
+ ("maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.",
+ "TypeConverters.toInt"),
("cacheNodeIds", "If false, the algorithm will pass trees to executors to match " +
"instances with nodes. If true, the algorithm will cache node IDs for each instance. " +
"Caching can speed up training of deeper trees. Users can set how often should the " +
- "cache be checkpointed or disable it by setting checkpointInterval.")]
+ "cache be checkpointed or disable it by setting checkpointInterval.",
+ "TypeConverters.toBoolean")]
decisionTreeCode = '''class DecisionTreeParams(Params):
"""
@@ -175,9 +183,12 @@ if __name__ == "__main__":
super(DecisionTreeParams, self).__init__()'''
dtParamMethods = ""
dummyPlaceholders = ""
- paramTemplate = """$name = Param($owner, "$name", "$doc")"""
- for name, doc in decisionTreeParams:
- variable = paramTemplate.replace("$name", name).replace("$doc", doc)
+ paramTemplate = """$name = Param($owner, "$name", "$doc", typeConverter=$typeConverterStr)"""
+ for name, doc, typeConverterStr in decisionTreeParams:
+ if typeConverterStr is None:
+ typeConverterStr = str(None)
+ variable = paramTemplate.replace("$name", name).replace("$doc", doc) \
+ .replace("$typeConverterStr", typeConverterStr)
dummyPlaceholders += variable.replace("$owner", "Params._dummy()") + "\n "
dtParamMethods += _gen_param_code(name, doc, None) + "\n"
code.append(decisionTreeCode.replace("$dummyPlaceHolders", dummyPlaceholders) + "\n" +
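As a sanity check on the new decisionTreeParams template, this is roughly the Param placeholder the loop above would produce for maxDepth (line wrapping approximate; the real dummyPlaceholders string is assembled at generation time):

# Sketch of one generated placeholder inside DecisionTreeParams.
maxDepth = Param(Params._dummy(), "maxDepth",
                 "Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; "
                 "depth 1 means 1 internal node + 2 leaf nodes.",
                 typeConverter=TypeConverters.toInt)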