author    Xiangrui Meng <meng@databricks.com>  2015-02-13 16:45:59 -0800
committer Xiangrui Meng <meng@databricks.com>  2015-02-13 16:45:59 -0800
commit    4f4c6d5a5db04a56906bacdc85d7e5589b6edada (patch)
tree      de4ce58bb370e0cd659c8e6624651b986aca5f26 /mllib
parent    d50a91d529b0913364b483c511397d4af308a435 (diff)
[SPARK-5730][ML] add doc groups to spark.ml components
This PR adds three groups to the ScalaDoc: `param`, `setParam`, and `getParam`. Params will show up in the generated Scala API doc as the top group. Setters/getters will be at the bottom.

Preview: ![screen shot 2015-02-13 at 2 47 49 pm](https://cloud.githubusercontent.com/assets/829644/6196657/5740c240-b38f-11e4-94bb-bd8ef5a796c5.png)

Author: Xiangrui Meng <meng@databricks.com>

Closes #4600 from mengxr/SPARK-5730 and squashes the following commits:

febed9a [Xiangrui Meng] add doc groups to spark.ml components
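The pattern the diff applies is uniform across components: each `Param` member's ScalaDoc is tagged `@group param`, and the corresponding getter and setter are tagged `@group getParam` and `@group setParam`, so Scaladoc collates them into the groups declared in `package.scala` below. A minimal sketch of that pattern, assuming the 1.3-era `org.apache.spark.ml.param` API used throughout this diff (the class `MyComponent` is illustrative only, not part of the patch):

```scala
import org.apache.spark.ml.param.{IntParam, Params}

// Illustrative sketch only: MyComponent is not part of this patch.
// IntParam, set, and get are the 1.3-era org.apache.spark.ml.param
// API used throughout the diff below.
class MyComponent extends Params {

  // Rendered under "Parameters" (@groupprio -5, the top of the page).
  /**
   * param for maximum number of iterations
   * @group param
   */
  val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations", Some(100))

  // Rendered under "Parameter getters" (@groupprio 6).
  /** @group getParam */
  def getMaxIter: Int = get(maxIter)

  // Rendered under "Parameter setters" (@groupprio 5).
  /** @group setParam */
  def setMaxIter(value: Int): this.type = set(maxIter, value)
}
```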
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/Transformer.scala                              |  3
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala                |  2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala        |  6
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala   |  2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala | 12
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala                         | 11
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala                    |  6
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala                  |  7
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/package.scala                                   | 14
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala                        | 70
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala                        | 90
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala               |  3
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala                     | 35

13 files changed, 235 insertions(+), 26 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
index cd95c16aa7..2ec2ccdb8c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -62,7 +62,10 @@ abstract class Transformer extends PipelineStage with Params {
private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]
extends Transformer with HasInputCol with HasOutputCol with Logging {
+ /** @group setParam */
def setInputCol(value: String): T = set(inputCol, value).asInstanceOf[T]
+
+ /** @group setParam */
def setOutputCol(value: String): T = set(outputCol, value).asInstanceOf[T]
/**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
index 1bf8eb4640..124ab30f27 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
@@ -66,6 +66,7 @@ private[spark] abstract class Classifier[
extends Predictor[FeaturesType, E, M]
with ClassifierParams {
+ /** @group setParam */
def setRawPredictionCol(value: String): E =
set(rawPredictionCol, value).asInstanceOf[E]
@@ -87,6 +88,7 @@ private[spark]
abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[FeaturesType, M]]
extends PredictionModel[FeaturesType, M] with ClassifierParams {
+ /** @group setParam */
def setRawPredictionCol(value: String): M = set(rawPredictionCol, value).asInstanceOf[M]
/** Number of classes (values which the label can take). */
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index c146fe244c..a9a5af5f0f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -49,8 +49,13 @@ class LogisticRegression
setMaxIter(100)
setThreshold(0.5)
+ /** @group setParam */
def setRegParam(value: Double): this.type = set(regParam, value)
+
+ /** @group setParam */
def setMaxIter(value: Int): this.type = set(maxIter, value)
+
+ /** @group setParam */
def setThreshold(value: Double): this.type = set(threshold, value)
override protected def train(dataset: DataFrame, paramMap: ParamMap): LogisticRegressionModel = {
@@ -93,6 +98,7 @@ class LogisticRegressionModel private[ml] (
setThreshold(0.5)
+ /** @group setParam */
def setThreshold(value: Double): this.type = set(threshold, value)
private val margin: Vector => Double = (features) => {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
index 1202528ca6..38518785dc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
@@ -61,6 +61,7 @@ private[spark] abstract class ProbabilisticClassifier[
M <: ProbabilisticClassificationModel[FeaturesType, M]]
extends Classifier[FeaturesType, E, M] with ProbabilisticClassifierParams {
+ /** @group setParam */
def setProbabilityCol(value: String): E = set(probabilityCol, value).asInstanceOf[E]
}
@@ -82,6 +83,7 @@ private[spark] abstract class ProbabilisticClassificationModel[
M <: ProbabilisticClassificationModel[FeaturesType, M]]
extends ClassificationModel[FeaturesType, M] with ProbabilisticClassifierParams {
+ /** @group setParam */
def setProbabilityCol(value: String): M = set(probabilityCol, value).asInstanceOf[M]
/**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
index f21a30627e..2360f4479f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
@@ -35,13 +35,23 @@ import org.apache.spark.sql.types.DoubleType
class BinaryClassificationEvaluator extends Evaluator with Params
with HasRawPredictionCol with HasLabelCol {
- /** param for metric name in evaluation */
+ /**
+ * param for metric name in evaluation
+ * @group param
+ */
val metricName: Param[String] = new Param(this, "metricName",
"metric name in evaluation (areaUnderROC|areaUnderPR)", Some("areaUnderROC"))
+
+ /** @group getParam */
def getMetricName: String = get(metricName)
+
+ /** @group setParam */
def setMetricName(value: String): this.type = set(metricName, value)
+ /** @group setParam */
def setScoreCol(value: String): this.type = set(rawPredictionCol, value)
+
+ /** @group setParam */
def setLabelCol(value: String): this.type = set(labelCol, value)
override def evaluate(dataset: DataFrame, paramMap: ParamMap): Double = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
index 0956062643..6131ba8832 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
@@ -31,11 +31,18 @@ import org.apache.spark.sql.types.DataType
@AlphaComponent
class HashingTF extends UnaryTransformer[Iterable[_], Vector, HashingTF] {
- /** number of features */
+ /**
+ * number of features
+ * @group param
+ */
val numFeatures = new IntParam(this, "numFeatures", "number of features", Some(1 << 18))
- def setNumFeatures(value: Int) = set(numFeatures, value)
+
+ /** @group getParam */
def getNumFeatures: Int = get(numFeatures)
+ /** @group setParam */
+ def setNumFeatures(value: Int) = set(numFeatures, value)
+
override protected def createTransformFunc(paramMap: ParamMap): Iterable[_] => Vector = {
val hashingTF = new feature.HashingTF(paramMap(numFeatures))
hashingTF.transform
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index 4745a7ae95..7623ec59ae 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -39,7 +39,10 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with
@AlphaComponent
class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerParams {
+ /** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)
+
+ /** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)
override def fit(dataset: DataFrame, paramMap: ParamMap): StandardScalerModel = {
@@ -75,7 +78,10 @@ class StandardScalerModel private[ml] (
scaler: feature.StandardScalerModel)
extends Model[StandardScalerModel] with StandardScalerParams {
+ /** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)
+
+ /** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)
override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala
index 89b53f3890..e416c1eb58 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala
@@ -85,8 +85,13 @@ private[spark] abstract class Predictor[
M <: PredictionModel[FeaturesType, M]]
extends Estimator[M] with PredictorParams {
+ /** @group setParam */
def setLabelCol(value: String): Learner = set(labelCol, value).asInstanceOf[Learner]
+
+ /** @group setParam */
def setFeaturesCol(value: String): Learner = set(featuresCol, value).asInstanceOf[Learner]
+
+ /** @group setParam */
def setPredictionCol(value: String): Learner = set(predictionCol, value).asInstanceOf[Learner]
override def fit(dataset: DataFrame, paramMap: ParamMap): M = {
@@ -160,8 +165,10 @@ private[spark] abstract class Predictor[
private[spark] abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, M]]
extends Model[M] with PredictorParams {
+ /** @group setParam */
def setFeaturesCol(value: String): M = set(featuresCol, value).asInstanceOf[M]
+ /** @group setParam */
def setPredictionCol(value: String): M = set(predictionCol, value).asInstanceOf[M]
/**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/package.scala b/mllib/src/main/scala/org/apache/spark/ml/package.scala
index 51cd48c904..b45bd1499b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/package.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/package.scala
@@ -20,5 +20,19 @@ package org.apache.spark
/**
* Spark ML is an ALPHA component that adds a new set of machine learning APIs to let users quickly
* assemble and configure practical machine learning pipelines.
+ *
+ * @groupname param Parameters
+ * @groupdesc param A list of (hyper-)parameter keys this algorithm can take. Users can set and get
+ * the parameter values through setters and getters, respectively.
+ * @groupprio param -5
+ *
+ * @groupname setParam Parameter setters
+ * @groupprio setParam 5
+ *
+ * @groupname getParam Parameter getters
+ * @groupprio getParam 6
+ *
+ * @groupname Ungrouped Members
+ * @groupprio Ungrouped 0
*/
package object ml
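Note on the priorities above: Scaladoc orders groups by ascending `@groupprio`, so `param` (-5) renders first, ungrouped members (0) next, and the setters (5) and getters (6) last, matching the layout described in the commit message (params at the top, setters/getters at the bottom).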
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala
index 32fc74462e..1a70322b4c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala
@@ -24,67 +24,117 @@ package org.apache.spark.ml.param
*/
private[ml] trait HasRegParam extends Params {
- /** param for regularization parameter */
+ /**
+ * param for regularization parameter
+ * @group param
+ */
val regParam: DoubleParam = new DoubleParam(this, "regParam", "regularization parameter")
+
+ /** @group getParam */
def getRegParam: Double = get(regParam)
}
private[ml] trait HasMaxIter extends Params {
- /** param for max number of iterations */
+ /**
+ * param for max number of iterations
+ * @group param
+ */
val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations")
+
+ /** @group getParam */
def getMaxIter: Int = get(maxIter)
}
private[ml] trait HasFeaturesCol extends Params {
- /** param for features column name */
+ /**
+ * param for features column name
+ * @group param
+ */
val featuresCol: Param[String] =
new Param(this, "featuresCol", "features column name", Some("features"))
+
+ /** @group getParam */
def getFeaturesCol: String = get(featuresCol)
}
private[ml] trait HasLabelCol extends Params {
- /** param for label column name */
+ /**
+ * param for label column name
+ * @group param
+ */
val labelCol: Param[String] = new Param(this, "labelCol", "label column name", Some("label"))
+
+ /** @group getParam */
def getLabelCol: String = get(labelCol)
}
private[ml] trait HasPredictionCol extends Params {
- /** param for prediction column name */
+ /**
+ * param for prediction column name
+ * @group param
+ */
val predictionCol: Param[String] =
new Param(this, "predictionCol", "prediction column name", Some("prediction"))
+
+ /** @group getParam */
def getPredictionCol: String = get(predictionCol)
}
private[ml] trait HasRawPredictionCol extends Params {
- /** param for raw prediction column name */
+ /**
+ * param for raw prediction column name
+ * @group param
+ */
val rawPredictionCol: Param[String] =
new Param(this, "rawPredictionCol", "raw prediction (a.k.a. confidence) column name",
Some("rawPrediction"))
+
+ /** @group getParam */
def getRawPredictionCol: String = get(rawPredictionCol)
}
private[ml] trait HasProbabilityCol extends Params {
- /** param for predicted class conditional probabilities column name */
+ /**
+ * param for predicted class conditional probabilities column name
+ * @group param
+ */
val probabilityCol: Param[String] =
new Param(this, "probabilityCol", "column name for predicted class conditional probabilities",
Some("probability"))
+
+ /** @group getParam */
def getProbabilityCol: String = get(probabilityCol)
}
private[ml] trait HasThreshold extends Params {
- /** param for threshold in (binary) prediction */
+ /**
+ * param for threshold in (binary) prediction
+ * @group param
+ */
val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in prediction")
+
+ /** @group getParam */
def getThreshold: Double = get(threshold)
}
private[ml] trait HasInputCol extends Params {
- /** param for input column name */
+ /**
+ * param for input column name
+ * @group param
+ */
val inputCol: Param[String] = new Param(this, "inputCol", "input column name")
+
+ /** @group getParam */
def getInputCol: String = get(inputCol)
}
private[ml] trait HasOutputCol extends Params {
- /** param for output column name */
+ /**
+ * param for output column name
+ * @group param
+ */
val outputCol: Param[String] = new Param(this, "outputCol", "output column name")
+
+ /** @group getParam */
def getOutputCol: String = get(outputCol)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index bf5737177c..aac487745f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -49,43 +49,89 @@ import org.apache.spark.util.random.XORShiftRandom
private[recommendation] trait ALSParams extends Params with HasMaxIter with HasRegParam
with HasPredictionCol {
- /** Param for rank of the matrix factorization. */
+ /**
+ * Param for rank of the matrix factorization.
+ * @group param
+ */
val rank = new IntParam(this, "rank", "rank of the factorization", Some(10))
+
+ /** @group getParam */
def getRank: Int = get(rank)
- /** Param for number of user blocks. */
+ /**
+ * Param for number of user blocks.
+ * @group param
+ */
val numUserBlocks = new IntParam(this, "numUserBlocks", "number of user blocks", Some(10))
+
+ /** @group getParam */
def getNumUserBlocks: Int = get(numUserBlocks)
- /** Param for number of item blocks. */
+ /**
+ * Param for number of item blocks.
+ * @group param
+ */
val numItemBlocks =
new IntParam(this, "numItemBlocks", "number of item blocks", Some(10))
+
+ /** @group getParam */
def getNumItemBlocks: Int = get(numItemBlocks)
- /** Param to decide whether to use implicit preference. */
+ /**
+ * Param to decide whether to use implicit preference.
+ * @group param
+ */
val implicitPrefs =
new BooleanParam(this, "implicitPrefs", "whether to use implicit preference", Some(false))
+
+ /** @group getParam */
def getImplicitPrefs: Boolean = get(implicitPrefs)
- /** Param for the alpha parameter in the implicit preference formulation. */
+ /**
+ * Param for the alpha parameter in the implicit preference formulation.
+ * @group param
+ */
val alpha = new DoubleParam(this, "alpha", "alpha for implicit preference", Some(1.0))
+
+ /** @group getParam */
def getAlpha: Double = get(alpha)
- /** Param for the column name for user ids. */
+ /**
+ * Param for the column name for user ids.
+ * @group param
+ */
val userCol = new Param[String](this, "userCol", "column name for user ids", Some("user"))
+
+ /** @group getParam */
def getUserCol: String = get(userCol)
- /** Param for the column name for item ids. */
+ /**
+ * Param for the column name for item ids.
+ * @group param
+ */
val itemCol =
new Param[String](this, "itemCol", "column name for item ids", Some("item"))
+
+ /** @group getParam */
def getItemCol: String = get(itemCol)
- /** Param for the column name for ratings. */
+ /**
+ * Param for the column name for ratings.
+ * @group param
+ */
val ratingCol = new Param[String](this, "ratingCol", "column name for ratings", Some("rating"))
+
+ /** @group getParam */
def getRatingCol: String = get(ratingCol)
+ /**
+ * Param for whether to apply nonnegativity constraints.
+ * @group param
+ */
val nonnegative = new BooleanParam(
this, "nonnegative", "whether to use nonnegative constraint for least squares", Some(false))
+
+ /** @group getParam */
val getNonnegative: Boolean = get(nonnegative)
/**
@@ -181,20 +227,46 @@ class ALS extends Estimator[ALSModel] with ALSParams {
import org.apache.spark.ml.recommendation.ALS.Rating
+ /** @group setParam */
def setRank(value: Int): this.type = set(rank, value)
+
+ /** @group setParam */
def setNumUserBlocks(value: Int): this.type = set(numUserBlocks, value)
+
+ /** @group setParam */
def setNumItemBlocks(value: Int): this.type = set(numItemBlocks, value)
+
+ /** @group setParam */
def setImplicitPrefs(value: Boolean): this.type = set(implicitPrefs, value)
+
+ /** @group setParam */
def setAlpha(value: Double): this.type = set(alpha, value)
+
+ /** @group setParam */
def setUserCol(value: String): this.type = set(userCol, value)
+
+ /** @group setParam */
def setItemCol(value: String): this.type = set(itemCol, value)
+
+ /** @group setParam */
def setRatingCol(value: String): this.type = set(ratingCol, value)
+
+ /** @group setParam */
def setPredictionCol(value: String): this.type = set(predictionCol, value)
+
+ /** @group setParam */
def setMaxIter(value: Int): this.type = set(maxIter, value)
+
+ /** @group setParam */
def setRegParam(value: Double): this.type = set(regParam, value)
+
+ /** @group setParam */
def setNonnegative(value: Boolean): this.type = set(nonnegative, value)
- /** Sets both numUserBlocks and numItemBlocks to the specific value. */
+ /**
+ * Sets both numUserBlocks and numItemBlocks to the specific value.
+ * @group setParam
+ */
def setNumBlocks(value: Int): this.type = {
setNumUserBlocks(value)
setNumItemBlocks(value)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index d5a7bdafcb..65f6627a0c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -44,7 +44,10 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress
setRegParam(0.1)
setMaxIter(100)
+ /** @group setParam */
def setRegParam(value: Double): this.type = set(regParam, value)
+
+ /** @group setParam */
def setMaxIter(value: Int): this.type = set(maxIter, value)
override protected def train(dataset: DataFrame, paramMap: ParamMap): LinearRegressionModel = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 324b1ba784..b139bc8dcb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -31,22 +31,42 @@ import org.apache.spark.sql.types.StructType
* Params for [[CrossValidator]] and [[CrossValidatorModel]].
*/
private[ml] trait CrossValidatorParams extends Params {
- /** param for the estimator to be cross-validated */
+ /**
+ * param for the estimator to be cross-validated
+ * @group param
+ */
val estimator: Param[Estimator[_]] = new Param(this, "estimator", "estimator for selection")
+
+ /** @group getParam */
def getEstimator: Estimator[_] = get(estimator)
- /** param for estimator param maps */
+ /**
+ * param for estimator param maps
+ * @group param
+ */
val estimatorParamMaps: Param[Array[ParamMap]] =
new Param(this, "estimatorParamMaps", "param maps for the estimator")
+
+ /** @group getParam */
def getEstimatorParamMaps: Array[ParamMap] = get(estimatorParamMaps)
- /** param for the evaluator for selection */
+ /**
+ * param for the evaluator for selection
+ * @group param
+ */
val evaluator: Param[Evaluator] = new Param(this, "evaluator", "evaluator for selection")
+
+ /** @group getParam */
def getEvaluator: Evaluator = get(evaluator)
- /** param for number of folds for cross validation */
+ /**
+ * param for number of folds for cross validation
+ * @group param
+ */
val numFolds: IntParam =
new IntParam(this, "numFolds", "number of folds for cross validation", Some(3))
+
+ /** @group getParam */
def getNumFolds: Int = get(numFolds)
}
@@ -59,9 +79,16 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
private val f2jBLAS = new F2jBLAS
+ /** @group setParam */
def setEstimator(value: Estimator[_]): this.type = set(estimator, value)
+
+ /** @group setParam */
def setEstimatorParamMaps(value: Array[ParamMap]): this.type = set(estimatorParamMaps, value)
+
+ /** @group setParam */
def setEvaluator(value: Evaluator): this.type = set(evaluator, value)
+
+ /** @group setParam */
def setNumFolds(value: Int): this.type = set(numFolds, value)
override def fit(dataset: DataFrame, paramMap: ParamMap): CrossValidatorModel = {
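Taken together, the setters grouped under "Parameter setters" form the fluent configuration API of these components. A hypothetical caller-side sketch, not part of this patch, using only classes and setters that appear in the diff above:

```scala
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.CrossValidator

// Hypothetical usage, not part of this patch: each setter returns
// this.type, so configuration chains fluently.
val lr = new LogisticRegression()
  .setMaxIter(50)
  .setRegParam(0.01)

val cv = new CrossValidator()
  .setEstimator(lr)
  .setEvaluator(new BinaryClassificationEvaluator().setMetricName("areaUnderROC"))
  .setNumFolds(3)
```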