aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/param/shared.py
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2016-08-25 02:26:33 -0700
committerYanbo Liang <ybliang8@gmail.com>2016-08-25 02:26:33 -0700
commit6b8cb1fe52e2c8b4b87b0c7d820f3a1824287328 (patch)
tree1e08638b25aedd2adcd9b84612a87cc095508a13 /python/pyspark/ml/param/shared.py
parente0b20f9f24d5c3304bf517a4dcfb0da93be5bc75 (diff)
downloadspark-6b8cb1fe52e2c8b4b87b0c7d820f3a1824287328.tar.gz
spark-6b8cb1fe52e2c8b4b87b0c7d820f3a1824287328.tar.bz2
spark-6b8cb1fe52e2c8b4b87b0c7d820f3a1824287328.zip
[SPARK-17197][ML][PYSPARK] PySpark LiR/LoR supports tree aggregation level configurable.
## What changes were proposed in this pull request? [SPARK-17090](https://issues.apache.org/jira/browse/SPARK-17090) makes tree aggregation level in LiR/LoR configurable, this PR makes PySpark support this function. ## How was this patch tested? Since ```aggregationDepth``` is an expert param, I'm not prefer to test it in doctest which is also used for example. Here is the offline test result: ![image](https://cloud.githubusercontent.com/assets/1962026/17879457/f83d7760-68a6-11e6-9936-d0a884d5d6ec.png) Author: Yanbo Liang <ybliang8@gmail.com> Closes #14766 from yanboliang/spark-17197.
Diffstat (limited to 'python/pyspark/ml/param/shared.py')
-rw-r--r--python/pyspark/ml/param/shared.py24
1 files changed, 24 insertions, 0 deletions
diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py
index c5ccf81540..24af07afc7 100644
--- a/python/pyspark/ml/param/shared.py
+++ b/python/pyspark/ml/param/shared.py
@@ -560,6 +560,30 @@ class HasVarianceCol(Params):
return self.getOrDefault(self.varianceCol)
+class HasAggregationDepth(Params):
+ """
+ Mixin for param aggregationDepth: suggested depth for treeAggregate (>= 2).
+ """
+
+ aggregationDepth = Param(Params._dummy(), "aggregationDepth", "suggested depth for treeAggregate (>= 2).", typeConverter=TypeConverters.toInt)
+
+ def __init__(self):
+ super(HasAggregationDepth, self).__init__()
+ self._setDefault(aggregationDepth=2)
+
+ def setAggregationDepth(self, value):
+ """
+ Sets the value of :py:attr:`aggregationDepth`.
+ """
+ return self._set(aggregationDepth=value)
+
+ def getAggregationDepth(self):
+ """
+ Gets the value of aggregationDepth or its default value.
+ """
+ return self.getOrDefault(self.aggregationDepth)
+
+
class DecisionTreeParams(Params):
"""
Mixin for Decision Tree parameters.