aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/param/shared.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/ml/param/shared.py')
-rw-r--r--python/pyspark/ml/param/shared.py24
1 files changed, 24 insertions, 0 deletions
diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py
index c5ccf81540..24af07afc7 100644
--- a/python/pyspark/ml/param/shared.py
+++ b/python/pyspark/ml/param/shared.py
@@ -560,6 +560,30 @@ class HasVarianceCol(Params):
return self.getOrDefault(self.varianceCol)
+class HasAggregationDepth(Params):
+ """
+ Mixin for param aggregationDepth: suggested depth for treeAggregate (>= 2).
+ """
+
+ aggregationDepth = Param(Params._dummy(), "aggregationDepth", "suggested depth for treeAggregate (>= 2).", typeConverter=TypeConverters.toInt)
+
+ def __init__(self):
+ super(HasAggregationDepth, self).__init__()
+ self._setDefault(aggregationDepth=2)
+
+ def setAggregationDepth(self, value):
+ """
+ Sets the value of :py:attr:`aggregationDepth`.
+ """
+ return self._set(aggregationDepth=value)
+
+ def getAggregationDepth(self):
+ """
+ Gets the value of aggregationDepth or its default value.
+ """
+ return self.getOrDefault(self.aggregationDepth)
+
+
class DecisionTreeParams(Params):
"""
Mixin for Decision Tree parameters.