about | summary | refs | log | tree | commit | diff
path: root/python/pyspark/ml/param
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/ml/param')
-rw-r--r-- python/pyspark/ml/param/__init__.py 82
-rw-r--r-- python/pyspark/ml/param/_gen_shared_params.py 98
-rw-r--r-- python/pyspark/ml/param/shared.py 260
3 files changed, 440 insertions, 0 deletions
diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py
new file mode 100644
index 0000000000..5566792cea
--- /dev/null
+++ b/python/pyspark/ml/param/__init__.py
@@ -0,0 +1,82 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from abc import ABCMeta
+
+from pyspark.ml.util import Identifiable
+
+
+__all__ = ['Param', 'Params']
+
+
+class Param(object):
+    """
+    A named, documented parameter bound to an identifiable parent, with an
+    optional default value.
+
+    :param parent: owner of the param; must be an instance of Identifiable
+    :param name: param name, stored as ``str(name)``
+    :param doc: param documentation, stored as ``str(doc)``
+    :param defaultValue: default value of the param (None when omitted)
+    :raises ValueError: if ``parent`` is not an instance of Identifiable
+    """
+
+    def __init__(self, parent, name, doc, defaultValue=None):
+        # Reject parents that are not Identifiable before storing anything.
+        if not isinstance(parent, Identifiable):
+            parentType = type(parent).__name__
+            raise ValueError("Parent must be identifiable but got type %s." % parentType)
+        self.parent = parent
+        self.name = str(name)
+        self.doc = str(doc)
+        self.defaultValue = defaultValue
+
+    def __str__(self):
+        """
+        Returns the string form of the parent and the param name joined by "-".
+        """
+        return "-".join((str(self.parent), self.name))
+
+    def __repr__(self):
+        """
+        Returns a debug representation listing parent, name, doc and default value.
+        """
+        fields = (self.parent, self.name, self.doc, self.defaultValue)
+        return "Param(parent=%r, name=%r, doc=%r, defaultValue=%r)" % fields
+
+
+class Params(Identifiable):
+    """
+    Components that take parameters. This also provides an internal
+    param map to store parameter values attached to the instance.
+    """
+
+    __metaclass__ = ABCMeta
+
+    def __init__(self):
+        super(Params, self).__init__()
+        #: embedded param map
+        self.paramMap = {}
+
+    @property
+    def params(self):
+        """
+        Returns all params as a list. The default implementation uses
+        :py:func:`dir` to get all attributes of type
+        :py:class:`Param`.
+        """
+        # "params" itself is skipped: reading it here would re-enter this
+        # property and recurse without end.
+        # A list comprehension (rather than filter()) keeps the result a real
+        # list on every Python version, so it can be iterated more than once.
+        attrs = (getattr(self, x) for x in dir(self) if x != "params")
+        return [attr for attr in attrs if isinstance(attr, Param)]
+
+    def _merge_params(self, params):
+        """
+        Returns a copy of the embedded param map updated with ``params``.
+        The embedded map itself is left unchanged.
+
+        :param params: extra param values, in any form accepted by ``dict.update``
+        :return: the merged param map (a new dict)
+        """
+        paramMap = self.paramMap.copy()
+        paramMap.update(params)
+        return paramMap
+
+    @staticmethod
+    def _dummy():
+        """
+        Returns a dummy Params instance used as a placeholder to generate docs.
+        """
+        dummy = Params()
+        dummy.uid = "undefined"
+        return dummy
diff --git a/python/pyspark/ml/param/_gen_shared_params.py b/python/pyspark/ml/param/_gen_shared_params.py
new file mode 100644
index 0000000000..5eb81106f1
--- /dev/null
+++ b/python/pyspark/ml/param/_gen_shared_params.py
@@ -0,0 +1,98 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# License banner; it is the first thing printed when this script is run, so
+# the generated module starts with it.
+header = """#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#"""
+
+
+def _gen_param_code(name, doc, defaultValue):
+ """
+ Generates Python code for a shared param class.
+
+ The generated class is named ``Has<Name>``, where ``<Name>`` is ``name``
+ with its first character upper-cased, and defines the param attribute
+ plus ``set<Name>`` and ``get<Name>`` methods.
+
+ :param name: param name (must be non-empty, since its first character is indexed)
+ :param doc: param doc, inserted between double quotes in the generated code
+ :param defaultValue: Python source text of the default value, inserted
+ verbatim (unquoted) into the generated code
+ :return: code string
+ """
+ # TODO: How to correctly inherit instance attributes?
+ template = '''class Has$Name(Params):
+ """
+ Params with $name.
+ """
+
+ # a placeholder to make it appear in the generated doc
+ $name = Param(Params._dummy(), "$name", "$doc", $defaultValue)
+
+ def __init__(self):
+ super(Has$Name, self).__init__()
+ #: param for $doc
+ self.$name = Param(self, "$name", "$doc", $defaultValue)
+
+ def set$Name(self, value):
+ """
+ Sets the value of :py:attr:`$name`.
+ """
+ self.paramMap[self.$name] = value
+ return self
+
+ def get$Name(self):
+ """
+ Gets the value of $name or its default value.
+ """
+ if self.$name in self.paramMap:
+ return self.paramMap[self.$name]
+ else:
+ return self.$name.defaultValue'''
+
+ upperCamelName = name[0].upper() + name[1:]
+ # Placeholders are substituted one after another with plain str.replace,
+ # so inserted text that itself contains a later placeholder (for example a
+ # doc containing "$defaultValue") would be rewritten by a later step.
+ return template \
+ .replace("$name", name) \
+ .replace("$Name", upperCamelName) \
+ .replace("$doc", doc) \
+ .replace("$defaultValue", defaultValue)
+
+if __name__ == "__main__":
+    # Writes the source of the generated module to stdout: license banner,
+    # a "do not modify" notice, the import line, then one class per param.
+    # Each print takes a single parenthesized argument, which prints the same
+    # text as a Python 2 print statement and as a Python 3 print() call.
+    print(header)
+    print("\n# DO NOT MODIFY. The code is generated by _gen_shared_params.py.\n")
+    print("from pyspark.ml.param import Param, Params\n\n")
+    # Each entry is (name, doc, defaultValue); defaultValue is Python source
+    # text, which is why string defaults carry their own quotes.
+    shared = [
+        ("maxIter", "max number of iterations", "100"),
+        ("regParam", "regularization constant", "0.1"),
+        ("featuresCol", "features column name", "'features'"),
+        ("labelCol", "label column name", "'label'"),
+        ("predictionCol", "prediction column name", "'prediction'"),
+        ("inputCol", "input column name", "'input'"),
+        ("outputCol", "output column name", "'output'"),
+        ("numFeatures", "number of features", "1 << 18")]
+    code = [_gen_param_code(name, doc, defaultValue)
+            for name, doc, defaultValue in shared]
+    # Three newlines between classes leave two blank lines between them.
+    print("\n\n\n".join(code))
diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py
new file mode 100644
index 0000000000..586822f2de
--- /dev/null
+++ b/python/pyspark/ml/param/shared.py
@@ -0,0 +1,260 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# DO NOT MODIFY. The code is generated by _gen_shared_params.py.
+
+from pyspark.ml.param import Param, Params
+
+
+class HasMaxIter(Params):
+ """
+ Params with maxIter.
+
+ Adds the shared ``maxIter`` param ("max number of iterations", default 100)
+ with a setter that returns ``self`` and a getter that falls back to the default.
+ """
+
+ # a placeholder to make it appear in the generated doc
+ maxIter = Param(Params._dummy(), "maxIter", "max number of iterations", 100)
+
+ def __init__(self):
+ super(HasMaxIter, self).__init__()
+ # per-instance Param whose parent is this object; it shadows the class-level placeholder
+ #: param for max number of iterations
+ self.maxIter = Param(self, "maxIter", "max number of iterations", 100)
+
+ def setMaxIter(self, value):
+ """
+ Sets the value of :py:attr:`maxIter`.
+
+ The value is stored in ``self.paramMap`` keyed by the param; returns ``self``.
+ """
+ self.paramMap[self.maxIter] = value
+ return self
+
+ def getMaxIter(self):
+ """
+ Gets the value of maxIter or its default value.
+ """
+ if self.maxIter in self.paramMap:
+ return self.paramMap[self.maxIter]
+ else:
+ return self.maxIter.defaultValue
+
+
+class HasRegParam(Params):
+ """
+ Params with regParam.
+
+ Adds the shared ``regParam`` param ("regularization constant", default 0.1)
+ with a setter that returns ``self`` and a getter that falls back to the default.
+ """
+
+ # a placeholder to make it appear in the generated doc
+ regParam = Param(Params._dummy(), "regParam", "regularization constant", 0.1)
+
+ def __init__(self):
+ super(HasRegParam, self).__init__()
+ # per-instance Param whose parent is this object; it shadows the class-level placeholder
+ #: param for regularization constant
+ self.regParam = Param(self, "regParam", "regularization constant", 0.1)
+
+ def setRegParam(self, value):
+ """
+ Sets the value of :py:attr:`regParam`.
+
+ The value is stored in ``self.paramMap`` keyed by the param; returns ``self``.
+ """
+ self.paramMap[self.regParam] = value
+ return self
+
+ def getRegParam(self):
+ """
+ Gets the value of regParam or its default value.
+ """
+ if self.regParam in self.paramMap:
+ return self.paramMap[self.regParam]
+ else:
+ return self.regParam.defaultValue
+
+
+class HasFeaturesCol(Params):
+ """
+ Params with featuresCol.
+
+ Adds the shared ``featuresCol`` param ("features column name", default
+ 'features') with a setter that returns ``self`` and a getter that falls
+ back to the default.
+ """
+
+ # a placeholder to make it appear in the generated doc
+ featuresCol = Param(Params._dummy(), "featuresCol", "features column name", 'features')
+
+ def __init__(self):
+ super(HasFeaturesCol, self).__init__()
+ # per-instance Param whose parent is this object; it shadows the class-level placeholder
+ #: param for features column name
+ self.featuresCol = Param(self, "featuresCol", "features column name", 'features')
+
+ def setFeaturesCol(self, value):
+ """
+ Sets the value of :py:attr:`featuresCol`.
+
+ The value is stored in ``self.paramMap`` keyed by the param; returns ``self``.
+ """
+ self.paramMap[self.featuresCol] = value
+ return self
+
+ def getFeaturesCol(self):
+ """
+ Gets the value of featuresCol or its default value.
+ """
+ if self.featuresCol in self.paramMap:
+ return self.paramMap[self.featuresCol]
+ else:
+ return self.featuresCol.defaultValue
+
+
+class HasLabelCol(Params):
+ """
+ Params with labelCol.
+
+ Adds the shared ``labelCol`` param ("label column name", default 'label')
+ with a setter that returns ``self`` and a getter that falls back to the default.
+ """
+
+ # a placeholder to make it appear in the generated doc
+ labelCol = Param(Params._dummy(), "labelCol", "label column name", 'label')
+
+ def __init__(self):
+ super(HasLabelCol, self).__init__()
+ # per-instance Param whose parent is this object; it shadows the class-level placeholder
+ #: param for label column name
+ self.labelCol = Param(self, "labelCol", "label column name", 'label')
+
+ def setLabelCol(self, value):
+ """
+ Sets the value of :py:attr:`labelCol`.
+
+ The value is stored in ``self.paramMap`` keyed by the param; returns ``self``.
+ """
+ self.paramMap[self.labelCol] = value
+ return self
+
+ def getLabelCol(self):
+ """
+ Gets the value of labelCol or its default value.
+ """
+ if self.labelCol in self.paramMap:
+ return self.paramMap[self.labelCol]
+ else:
+ return self.labelCol.defaultValue
+
+
+class HasPredictionCol(Params):
+ """
+ Params with predictionCol.
+
+ Adds the shared ``predictionCol`` param ("prediction column name", default
+ 'prediction') with a setter that returns ``self`` and a getter that falls
+ back to the default.
+ """
+
+ # a placeholder to make it appear in the generated doc
+ predictionCol = Param(Params._dummy(), "predictionCol", "prediction column name", 'prediction')
+
+ def __init__(self):
+ super(HasPredictionCol, self).__init__()
+ # per-instance Param whose parent is this object; it shadows the class-level placeholder
+ #: param for prediction column name
+ self.predictionCol = Param(self, "predictionCol", "prediction column name", 'prediction')
+
+ def setPredictionCol(self, value):
+ """
+ Sets the value of :py:attr:`predictionCol`.
+
+ The value is stored in ``self.paramMap`` keyed by the param; returns ``self``.
+ """
+ self.paramMap[self.predictionCol] = value
+ return self
+
+ def getPredictionCol(self):
+ """
+ Gets the value of predictionCol or its default value.
+ """
+ if self.predictionCol in self.paramMap:
+ return self.paramMap[self.predictionCol]
+ else:
+ return self.predictionCol.defaultValue
+
+
+class HasInputCol(Params):
+ """
+ Params with inputCol.
+
+ Adds the shared ``inputCol`` param ("input column name", default 'input')
+ with a setter that returns ``self`` and a getter that falls back to the default.
+ """
+
+ # a placeholder to make it appear in the generated doc
+ inputCol = Param(Params._dummy(), "inputCol", "input column name", 'input')
+
+ def __init__(self):
+ super(HasInputCol, self).__init__()
+ # per-instance Param whose parent is this object; it shadows the class-level placeholder
+ #: param for input column name
+ self.inputCol = Param(self, "inputCol", "input column name", 'input')
+
+ def setInputCol(self, value):
+ """
+ Sets the value of :py:attr:`inputCol`.
+
+ The value is stored in ``self.paramMap`` keyed by the param; returns ``self``.
+ """
+ self.paramMap[self.inputCol] = value
+ return self
+
+ def getInputCol(self):
+ """
+ Gets the value of inputCol or its default value.
+ """
+ if self.inputCol in self.paramMap:
+ return self.paramMap[self.inputCol]
+ else:
+ return self.inputCol.defaultValue
+
+
+class HasOutputCol(Params):
+ """
+ Params with outputCol.
+
+ Adds the shared ``outputCol`` param ("output column name", default 'output')
+ with a setter that returns ``self`` and a getter that falls back to the default.
+ """
+
+ # a placeholder to make it appear in the generated doc
+ outputCol = Param(Params._dummy(), "outputCol", "output column name", 'output')
+
+ def __init__(self):
+ super(HasOutputCol, self).__init__()
+ # per-instance Param whose parent is this object; it shadows the class-level placeholder
+ #: param for output column name
+ self.outputCol = Param(self, "outputCol", "output column name", 'output')
+
+ def setOutputCol(self, value):
+ """
+ Sets the value of :py:attr:`outputCol`.
+
+ The value is stored in ``self.paramMap`` keyed by the param; returns ``self``.
+ """
+ self.paramMap[self.outputCol] = value
+ return self
+
+ def getOutputCol(self):
+ """
+ Gets the value of outputCol or its default value.
+ """
+ if self.outputCol in self.paramMap:
+ return self.paramMap[self.outputCol]
+ else:
+ return self.outputCol.defaultValue
+
+
+class HasNumFeatures(Params):
+ """
+ Params with numFeatures.
+
+ Adds the shared ``numFeatures`` param ("number of features", default
+ ``1 << 18``, i.e. 262144) with a setter that returns ``self`` and a getter
+ that falls back to the default.
+ """
+
+ # a placeholder to make it appear in the generated doc
+ numFeatures = Param(Params._dummy(), "numFeatures", "number of features", 1 << 18)
+
+ def __init__(self):
+ super(HasNumFeatures, self).__init__()
+ # per-instance Param whose parent is this object; it shadows the class-level placeholder
+ #: param for number of features
+ self.numFeatures = Param(self, "numFeatures", "number of features", 1 << 18)
+
+ def setNumFeatures(self, value):
+ """
+ Sets the value of :py:attr:`numFeatures`.
+
+ The value is stored in ``self.paramMap`` keyed by the param; returns ``self``.
+ """
+ self.paramMap[self.numFeatures] = value
+ return self
+
+ def getNumFeatures(self):
+ """
+ Gets the value of numFeatures or its default value.
+ """
+ if self.numFeatures in self.paramMap:
+ return self.paramMap[self.numFeatures]
+ else:
+ return self.numFeatures.defaultValue