aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/param
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/ml/param')
-rw-r--r--python/pyspark/ml/param/_shared_params_code_gen.py2
-rw-r--r--python/pyspark/ml/param/shared.py29
2 files changed, 31 insertions, 0 deletions
diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py
index ee901f2584..ed3171b697 100644
--- a/python/pyspark/ml/param/_shared_params_code_gen.py
+++ b/python/pyspark/ml/param/_shared_params_code_gen.py
@@ -97,6 +97,8 @@ if __name__ == "__main__":
("inputCol", "input column name", None),
("inputCols", "input column names", None),
("outputCol", "output column name", None),
+ ("numFeatures", "number of features", None),
+ ("checkpointInterval", "checkpoint interval (>= 1)", None),
("seed", "random seed", None),
("tol", "the convergence tolerance for iterative algorithms", None),
("stepSize", "Step size to be used for each iteration of optimization.", None)]
diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py
index 5e7529c1dc..d0bcadee22 100644
--- a/python/pyspark/ml/param/shared.py
+++ b/python/pyspark/ml/param/shared.py
@@ -310,6 +310,35 @@ class HasNumFeatures(Params):
return self.getOrDefault(self.numFeatures)
+class HasCheckpointInterval(Params):
+ """
+ Mixin for param checkpointInterval: checkpoint interval (>= 1).
+ """
+
+ # a placeholder to make it appear in the generated doc
+ checkpointInterval = Param(Params._dummy(), "checkpointInterval", "checkpoint interval (>= 1)")
+
+ def __init__(self):
+ super(HasCheckpointInterval, self).__init__()
+ #: param for checkpoint interval (>= 1)
+ self.checkpointInterval = Param(self, "checkpointInterval", "checkpoint interval (>= 1)")
+ if None is not None:
+ self._setDefault(checkpointInterval=None)
+
+ def setCheckpointInterval(self, value):
+ """
+ Sets the value of :py:attr:`checkpointInterval`.
+ """
+ self.paramMap[self.checkpointInterval] = value
+ return self
+
+ def getCheckpointInterval(self):
+ """
+ Gets the value of checkpointInterval or its default value.
+ """
+ return self.getOrDefault(self.checkpointInterval)
+
+
class HasSeed(Params):
"""
Mixin for param seed: random seed.