diff options
Diffstat (limited to 'python/pyspark/ml/param/_shared_params_code_gen.py')
-rw-r--r-- | python/pyspark/ml/param/_shared_params_code_gen.py | 91 |
1 file changed, 51 insertions(+), 40 deletions(-)
diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py index 5e297b8214..7dd2937db7 100644 --- a/python/pyspark/ml/param/_shared_params_code_gen.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -38,7 +38,7 @@ header = """# # python _shared_params_code_gen.py > shared.py -def _gen_param_header(name, doc, defaultValueStr, expectedType): +def _gen_param_header(name, doc, defaultValueStr, typeConverter): """ Generates the header part for shared variables @@ -50,7 +50,7 @@ def _gen_param_header(name, doc, defaultValueStr, expectedType): Mixin for param $name: $doc """ - $name = Param(Params._dummy(), "$name", "$doc", $expectedType) + $name = Param(Params._dummy(), "$name", "$doc", typeConverter=$typeConverter) def __init__(self): super(Has$Name, self).__init__()''' @@ -60,15 +60,14 @@ def _gen_param_header(name, doc, defaultValueStr, expectedType): self._setDefault($name=$defaultValueStr)''' Name = name[0].upper() + name[1:] - expectedTypeName = str(expectedType) - if expectedType is not None: - expectedTypeName = expectedType.__name__ + if typeConverter is None: + typeConverter = str(None) return template \ .replace("$name", name) \ .replace("$Name", Name) \ .replace("$doc", doc) \ .replace("$defaultValueStr", str(defaultValueStr)) \ - .replace("$expectedType", expectedTypeName) + .replace("$typeConverter", typeConverter) def _gen_param_code(name, doc, defaultValueStr): @@ -105,64 +104,73 @@ def _gen_param_code(name, doc, defaultValueStr): if __name__ == "__main__": print(header) print("\n# DO NOT MODIFY THIS FILE! 
It was generated by _shared_params_code_gen.py.\n") - print("from pyspark.ml.param import Param, Params\n\n") + print("from pyspark.ml.param import *\n\n") shared = [ - ("maxIter", "max number of iterations (>= 0).", None, int), - ("regParam", "regularization parameter (>= 0).", None, float), - ("featuresCol", "features column name.", "'features'", str), - ("labelCol", "label column name.", "'label'", str), - ("predictionCol", "prediction column name.", "'prediction'", str), + ("maxIter", "max number of iterations (>= 0).", None, "TypeConverters.toInt"), + ("regParam", "regularization parameter (>= 0).", None, "TypeConverters.toFloat"), + ("featuresCol", "features column name.", "'features'", "TypeConverters.toString"), + ("labelCol", "label column name.", "'label'", "TypeConverters.toString"), + ("predictionCol", "prediction column name.", "'prediction'", "TypeConverters.toString"), ("probabilityCol", "Column name for predicted class conditional probabilities. " + "Note: Not all models output well-calibrated probability estimates! These probabilities " + - "should be treated as confidences, not precise probabilities.", "'probability'", str), + "should be treated as confidences, not precise probabilities.", "'probability'", + "TypeConverters.toString"), ("rawPredictionCol", "raw prediction (a.k.a. 
confidence) column name.", "'rawPrediction'", - str), - ("inputCol", "input column name.", None, str), - ("inputCols", "input column names.", None, None), - ("outputCol", "output column name.", "self.uid + '__output'", str), - ("numFeatures", "number of features.", None, int), + "TypeConverters.toString"), + ("inputCol", "input column name.", None, "TypeConverters.toString"), + ("inputCols", "input column names.", None, "TypeConverters.toListString"), + ("outputCol", "output column name.", "self.uid + '__output'", "TypeConverters.toString"), + ("numFeatures", "number of features.", None, "TypeConverters.toInt"), ("checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). " + - "E.g. 10 means that the cache will get checkpointed every 10 iterations.", None, int), - ("seed", "random seed.", "hash(type(self).__name__)", int), - ("tol", "the convergence tolerance for iterative algorithms.", None, float), - ("stepSize", "Step size to be used for each iteration of optimization.", None, float), + "E.g. 10 means that the cache will get checkpointed every 10 iterations.", None, + "TypeConverters.toInt"), + ("seed", "random seed.", "hash(type(self).__name__)", "TypeConverters.toInt"), + ("tol", "the convergence tolerance for iterative algorithms.", None, + "TypeConverters.toFloat"), + ("stepSize", "Step size to be used for each iteration of optimization.", None, + "TypeConverters.toFloat"), ("handleInvalid", "how to handle invalid entries. Options are skip (which will filter " + "out rows with bad values), or error (which will throw an errror). More options may be " + - "added later.", None, str), + "added later.", None, "TypeConverters.toBoolean"), ("elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " + - "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", "0.0", float), - ("fitIntercept", "whether to fit an intercept term.", "True", bool), + "the penalty is an L2 penalty. 
For alpha = 1, it is an L1 penalty.", "0.0", + "TypeConverters.toFloat"), + ("fitIntercept", "whether to fit an intercept term.", "True", "TypeConverters.toBoolean"), ("standardization", "whether to standardize the training features before fitting the " + - "model.", "True", bool), + "model.", "True", "TypeConverters.toBoolean"), ("thresholds", "Thresholds in multi-class classification to adjust the probability of " + "predicting each class. Array must have length equal to the number of classes, with " + "values >= 0. The class with largest value p/t is predicted, where p is the original " + - "probability of that class and t is the class' threshold.", None, None), + "probability of that class and t is the class' threshold.", None, + "TypeConverters.toListFloat"), ("weightCol", "weight column name. If this is not set or empty, we treat " + - "all instance weights as 1.0.", None, str), + "all instance weights as 1.0.", None, "TypeConverters.toString"), ("solver", "the solver algorithm for optimization. If this is not set or empty, " + - "default value is 'auto'.", "'auto'", str)] + "default value is 'auto'.", "'auto'", "TypeConverters.toString")] code = [] - for name, doc, defaultValueStr, expectedType in shared: - param_code = _gen_param_header(name, doc, defaultValueStr, expectedType) + for name, doc, defaultValueStr, typeConverter in shared: + param_code = _gen_param_header(name, doc, defaultValueStr, typeConverter) code.append(param_code + "\n" + _gen_param_code(name, doc, defaultValueStr)) decisionTreeParams = [ ("maxDepth", "Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; " + - "depth 1 means 1 internal node + 2 leaf nodes."), + "depth 1 means 1 internal node + 2 leaf nodes.", "TypeConverters.toInt"), ("maxBins", "Max number of bins for" + " discretizing continuous features. 
Must be >=2 and >= number of categories for any" + - " categorical feature."), + " categorical feature.", "TypeConverters.toInt"), ("minInstancesPerNode", "Minimum number of instances each child must have after split. " + "If a split causes the left or right child to have fewer than minInstancesPerNode, the " + - "split will be discarded as invalid. Should be >= 1."), - ("minInfoGain", "Minimum information gain for a split to be considered at a tree node."), - ("maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation."), + "split will be discarded as invalid. Should be >= 1.", "TypeConverters.toInt"), + ("minInfoGain", "Minimum information gain for a split to be considered at a tree node.", + "TypeConverters.toFloat"), + ("maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.", + "TypeConverters.toInt"), ("cacheNodeIds", "If false, the algorithm will pass trees to executors to match " + "instances with nodes. If true, the algorithm will cache node IDs for each instance. " + "Caching can speed up training of deeper trees. 
Users can set how often should the " + - "cache be checkpointed or disable it by setting checkpointInterval.")] + "cache be checkpointed or disable it by setting checkpointInterval.", + "TypeConverters.toBoolean")] decisionTreeCode = '''class DecisionTreeParams(Params): """ @@ -175,9 +183,12 @@ if __name__ == "__main__": super(DecisionTreeParams, self).__init__()''' dtParamMethods = "" dummyPlaceholders = "" - paramTemplate = """$name = Param($owner, "$name", "$doc")""" - for name, doc in decisionTreeParams: - variable = paramTemplate.replace("$name", name).replace("$doc", doc) + paramTemplate = """$name = Param($owner, "$name", "$doc", typeConverter=$typeConverterStr)""" + for name, doc, typeConverterStr in decisionTreeParams: + if typeConverterStr is None: + typeConverterStr = str(None) + variable = paramTemplate.replace("$name", name).replace("$doc", doc) \ + .replace("$typeConverterStr", typeConverterStr) dummyPlaceholders += variable.replace("$owner", "Params._dummy()") + "\n " dtParamMethods += _gen_param_code(name, doc, None) + "\n" code.append(decisionTreeCode.replace("$dummyPlaceHolders", dummyPlaceholders) + "\n" + |