aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/param/_shared_params_code_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/ml/param/_shared_params_code_gen.py')
-rw-r--r--python/pyspark/ml/param/_shared_params_code_gen.py69
1 files changed, 62 insertions, 7 deletions
diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py
index 3be0979b92..4a5cc6e64f 100644
--- a/python/pyspark/ml/param/_shared_params_code_gen.py
+++ b/python/pyspark/ml/param/_shared_params_code_gen.py
@@ -38,16 +38,13 @@ header = """#
# python _shared_params_code_gen.py > shared.py
-def _gen_param_code(name, doc, defaultValueStr):
+def _gen_param_header(name, doc, defaultValueStr):
"""
- Generates Python code for a shared param class.
+ Generates the header part for shared variables
:param name: param name
:param doc: param doc
- :param defaultValueStr: string representation of the default value
- :return: code string
"""
- # TODO: How to correctly inherit instance attributes?
template = '''class Has$Name(Params):
"""
Mixin for param $name: $doc.
@@ -61,8 +58,27 @@ def _gen_param_code(name, doc, defaultValueStr):
#: param for $doc
self.$name = Param(self, "$name", "$doc")
if $defaultValueStr is not None:
- self._setDefault($name=$defaultValueStr)
+ self._setDefault($name=$defaultValueStr)'''
+
+ Name = name[0].upper() + name[1:]
+ return template \
+ .replace("$name", name) \
+ .replace("$Name", Name) \
+ .replace("$doc", doc) \
+ .replace("$defaultValueStr", str(defaultValueStr))
+
+def _gen_param_code(name, doc, defaultValueStr):
+ """
+ Generates Python code for a shared param class.
+
+ :param name: param name
+ :param doc: param doc
+ :param defaultValueStr: string representation of the default value
+ :return: code string
+ """
+ # TODO: How to correctly inherit instance attributes?
+ template = '''
def set$Name(self, value):
"""
Sets the value of :py:attr:`$name`.
@@ -104,5 +120,44 @@ if __name__ == "__main__":
("stepSize", "Step size to be used for each iteration of optimization.", None)]
code = []
for name, doc, defaultValueStr in shared:
- code.append(_gen_param_code(name, doc, defaultValueStr))
+ param_code = _gen_param_header(name, doc, defaultValueStr)
+ code.append(param_code + "\n" + _gen_param_code(name, doc, defaultValueStr))
+
+ decisionTreeParams = [
+ ("maxDepth", "Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; " +
+ "depth 1 means 1 internal node + 2 leaf nodes."),
+ ("maxBins", "Max number of bins for" +
+ " discretizing continuous features. Must be >=2 and >= number of categories for any" +
+ " categorical feature."),
+ ("minInstancesPerNode", "Minimum number of instances each child must have after split. " +
+ "If a split causes the left or right child to have fewer than minInstancesPerNode, the " +
+ "split will be discarded as invalid. Should be >= 1."),
+ ("minInfoGain", "Minimum information gain for a split to be considered at a tree node."),
+ ("maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation."),
+ ("cacheNodeIds", "If false, the algorithm will pass trees to executors to match " +
+ "instances with nodes. If true, the algorithm will cache node IDs for each instance. " +
+ "Caching can speed up training of deeper trees.")]
+
+ decisionTreeCode = '''class DecisionTreeParams(Params):
+ """
+ Mixin for Decision Tree parameters.
+ """
+
+ # a placeholder to make it appear in the generated doc
+ $dummyPlaceHolders
+
+ def __init__(self):
+ super(DecisionTreeParams, self).__init__()
+ $realParams'''
+ dtParamMethods = ""
+ dummyPlaceholders = ""
+ realParams = ""
+ paramTemplate = """$name = Param($owner, "$name", "$doc")"""
+ for name, doc in decisionTreeParams:
+ variable = paramTemplate.replace("$name", name).replace("$doc", doc)
+ dummyPlaceholders += variable.replace("$owner", "Params._dummy()") + "\n "
+ realParams += "self." + variable.replace("$owner", "self") + "\n "
+ dtParamMethods += _gen_param_code(name, doc, None) + "\n"
+ code.append(decisionTreeCode.replace("$dummyPlaceHolders", dummyPlaceholders)
+ .replace("$realParams", realParams) + dtParamMethods)
print("\n\n\n".join(code))