aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/param/shared.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/ml/param/shared.py')
-rw-r--r--python/pyspark/ml/param/shared.py29
1 files changed, 29 insertions, 0 deletions
diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py
index 4f243844f8..aaf80f0008 100644
--- a/python/pyspark/ml/param/shared.py
+++ b/python/pyspark/ml/param/shared.py
@@ -223,6 +223,35 @@ class HasInputCol(Params):
return self.getOrDefault(self.inputCol)
+class HasInputCols(Params):
+ """
+ Mixin for param inputCols: input column names.
+ """
+
+ # a placeholder to make it appear in the generated doc
+ inputCols = Param(Params._dummy(), "inputCols", "input column names")
+
+ def __init__(self):
+ super(HasInputCols, self).__init__()
+ #: param for input column names
+ self.inputCols = Param(self, "inputCols", "input column names")
+ if None is not None:
+ self._setDefault(inputCols=None)
+
+ def setInputCols(self, value):
+ """
+ Sets the value of :py:attr:`inputCols`.
+ """
+ self.paramMap[self.inputCols] = value
+ return self
+
+ def getInputCols(self):
+ """
+ Gets the value of inputCols or its default value.
+ """
+ return self.getOrDefault(self.inputCols)
+
+
class HasOutputCol(Params):
"""
Mixin for param outputCol: output column name.