aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala25
1 files changed, 21 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index 64d6af2766..6803772c63 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -334,10 +334,10 @@ private[ml] trait HasElasticNetParam extends Params {
private[ml] trait HasTol extends Params {
/**
- * Param for the convergence tolerance for iterative algorithms.
+ * Param for the convergence tolerance for iterative algorithms (>= 0).
* @group param
*/
- final val tol: DoubleParam = new DoubleParam(this, "tol", "the convergence tolerance for iterative algorithms")
+ final val tol: DoubleParam = new DoubleParam(this, "tol", "the convergence tolerance for iterative algorithms (>= 0)", ParamValidators.gtEq(0))
/** @group getParam */
final def getTol: Double = $(tol)
@@ -349,10 +349,10 @@ private[ml] trait HasTol extends Params {
private[ml] trait HasStepSize extends Params {
/**
- * Param for Step size to be used for each iteration of optimization.
+ * Param for Step size to be used for each iteration of optimization (> 0).
* @group param
*/
- final val stepSize: DoubleParam = new DoubleParam(this, "stepSize", "Step size to be used for each iteration of optimization")
+ final val stepSize: DoubleParam = new DoubleParam(this, "stepSize", "Step size to be used for each iteration of optimization (> 0)", ParamValidators.gt(0))
/** @group getParam */
final def getStepSize: Double = $(stepSize)
@@ -389,4 +389,21 @@ private[ml] trait HasSolver extends Params {
/** @group getParam */
final def getSolver: String = $(solver)
}
+
+/**
+ * Trait for shared param aggregationDepth (default: 2).
+ */
+private[ml] trait HasAggregationDepth extends Params {
+
+ /**
+ * Param for suggested depth for treeAggregate (>= 2).
+ * @group param
+ */
+ final val aggregationDepth: IntParam = new IntParam(this, "aggregationDepth", "suggested depth for treeAggregate (>= 2)", ParamValidators.gtEq(2))
+
+ setDefault(aggregationDepth, 2)
+
+ /** @group getParam */
+ final def getAggregationDepth: Int = $(aggregationDepth)
+}
// scalastyle:on