aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2015-09-17 11:24:38 -0700
committerJoseph K. Bradley <joseph@databricks.com>2015-09-17 11:24:38 -0700
commit64743870f23bffb8d96dcc8a0181c1452782a151 (patch)
treebacc0b6cc3d27e870d4c7cb58dc7c56f28be52a3 /mllib
parentaad644fbe29151aec9004817d42e4928bdb326f3 (diff)
downloadspark-64743870f23bffb8d96dcc8a0181c1452782a151.tar.gz
spark-64743870f23bffb8d96dcc8a0181c1452782a151.tar.bz2
spark-64743870f23bffb8d96dcc8a0181c1452782a151.zip
[SPARK-10394] [ML] Make GBTParams use shared stepSize
```GBTParams``` has ```stepSize``` as learning rate currently. ML has shared param class ```HasStepSize```, ```GBTParams``` can extend from it rather than duplicated implementation. Author: Yanbo Liang <ybliang8@gmail.com> Closes #8552 from yanboliang/spark-10394.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala28
1 files changed, 13 insertions, 15 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
index d29f5253c9..42e74ce6d2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.tree
import org.apache.spark.ml.classification.ClassifierParams
import org.apache.spark.ml.PredictorParams
import org.apache.spark.ml.param._
-import org.apache.spark.ml.param.shared.{HasCheckpointInterval, HasMaxIter, HasSeed, HasThresholds}
+import org.apache.spark.ml.param.shared._
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance}
import org.apache.spark.mllib.tree.loss.{Loss => OldLoss}
@@ -365,17 +365,7 @@ private[ml] object RandomForestParams {
*
* Note: Marked as private and DeveloperApi since this may be made public in the future.
*/
-private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter {
-
- /**
- * Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each
- * estimator.
- * (default = 0.1)
- * @group param
- */
- final val stepSize: DoubleParam = new DoubleParam(this, "stepSize", "Step size (a.k.a." +
- " learning rate) in interval (0, 1] for shrinking the contribution of each estimator",
- ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true))
+private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasStepSize {
/* TODO: Add this doc when we add this param. SPARK-7132
* Threshold for stopping early when runWithValidation is used.
@@ -393,11 +383,19 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter {
/** @group setParam */
def setMaxIter(value: Int): this.type = set(maxIter, value)
- /** @group setParam */
+ /**
+ * Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each
+ * estimator.
+ * (default = 0.1)
+ * @group setParam
+ */
def setStepSize(value: Double): this.type = set(stepSize, value)
- /** @group getParam */
- final def getStepSize: Double = $(stepSize)
+ override def validateParams(): Unit = {
+ require(ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true)(
+ getStepSize), "GBT parameter stepSize should be in interval (0, 1], " +
+ s"but it given invalid value $getStepSize.")
+ }
/** (private[ml]) Create a BoostingStrategy instance to use with the old API. */
private[ml] def getOldBoostingStrategy(