author    Bryan Cutler <cutlerb@gmail.com>  2016-02-26 08:30:32 -0800
committer Xiangrui Meng <meng@databricks.com>  2016-02-26 08:30:32 -0800
commit    b33261f91387904c5aaccae40f86922c92a4e09a (patch)
tree      abae986f0bd829276d4b320f8242275a22609212 /mllib
parent    99dfcedbfd4c83c7b6a343456f03e8c6e29968c5 (diff)
[SPARK-12634][PYSPARK][DOC] PySpark tree parameter desc to consistent format
Part of task for [SPARK-11219](https://issues.apache.org/jira/browse/SPARK-11219) to make PySpark MLlib parameter description formatting consistent. This is for the tree module.

closes #10601

Author: Bryan Cutler <cutlerb@gmail.com>
Author: vijaykiran <mail@vijaykiran.com>

Closes #11353 from BryanCutler/param-desc-consistent-tree-SPARK-12634.
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala           | 172
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala   |  18
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala           |  69
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala |  13
4 files changed, 140 insertions(+), 132 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 51235a2371..40440d50fc 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -38,8 +38,9 @@ import org.apache.spark.util.random.XORShiftRandom
/**
* A class which implements a decision tree learning algorithm for classification and regression.
* It supports both continuous and categorical features.
+ *
* @param strategy The configuration parameters for the tree algorithm which specify the type
- * of algorithm (classification, regression, etc.), feature type (continuous,
+ * of decision tree (classification or regression), feature type (continuous,
* categorical), depth of the tree, quantile calculation strategy, etc.
*/
@Since("1.0.0")
@@ -50,8 +51,8 @@ class DecisionTree @Since("1.0.0") (private val strategy: Strategy)
/**
* Method to train a decision tree model over an RDD
- * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
- * @return DecisionTreeModel that can be used for prediction
+ * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+ * @return DecisionTreeModel that can be used for prediction.
*/
@Since("1.2.0")
def run(input: RDD[LabeledPoint]): DecisionTreeModel = {
@@ -77,9 +78,9 @@ object DecisionTree extends Serializable with Logging {
* For classification, labels should take values {0, 1, ..., numClasses-1}.
* For regression, labels are real numbers.
* @param strategy The configuration parameters for the tree algorithm which specify the type
- * of algorithm (classification, regression, etc.), feature type (continuous,
+ * of decision tree (classification or regression), feature type (continuous,
* categorical), depth of the tree, quantile calculation strategy, etc.
- * @return DecisionTreeModel that can be used for prediction
+ * @return DecisionTreeModel that can be used for prediction.
*/
@Since("1.0.0")
def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = {
@@ -97,11 +98,11 @@ object DecisionTree extends Serializable with Logging {
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* For classification, labels should take values {0, 1, ..., numClasses-1}.
* For regression, labels are real numbers.
- * @param algo algorithm, classification or regression
- * @param impurity impurity criterion used for information gain calculation
- * @param maxDepth Maximum depth of the tree.
- * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
- * @return DecisionTreeModel that can be used for prediction
+ * @param algo Type of decision tree, either classification or regression.
+ * @param impurity Criterion used for information gain calculation.
+ * @param maxDepth Maximum depth of the tree (e.g. depth 0 means 1 leaf node, depth 1 means
+ * 1 internal node + 2 leaf nodes).
+ * @return DecisionTreeModel that can be used for prediction.
*/
@Since("1.0.0")
def train(
@@ -124,12 +125,12 @@ object DecisionTree extends Serializable with Logging {
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* For classification, labels should take values {0, 1, ..., numClasses-1}.
* For regression, labels are real numbers.
- * @param algo algorithm, classification or regression
- * @param impurity impurity criterion used for information gain calculation
- * @param maxDepth Maximum depth of the tree.
- * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
- * @param numClasses number of classes for classification. Default value of 2.
- * @return DecisionTreeModel that can be used for prediction
+ * @param algo Type of decision tree, either classification or regression.
+ * @param impurity Criterion used for information gain calculation.
+ * @param maxDepth Maximum depth of the tree (e.g. depth 0 means 1 leaf node, depth 1 means
+ * 1 internal node + 2 leaf nodes).
+ * @param numClasses Number of classes for classification. Default value of 2.
+ * @return DecisionTreeModel that can be used for prediction.
*/
@Since("1.2.0")
def train(
@@ -153,17 +154,17 @@ object DecisionTree extends Serializable with Logging {
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* For classification, labels should take values {0, 1, ..., numClasses-1}.
* For regression, labels are real numbers.
- * @param algo classification or regression
- * @param impurity criterion used for information gain calculation
- * @param maxDepth Maximum depth of the tree.
- * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
- * @param numClasses number of classes for classification. Default value of 2.
- * @param maxBins maximum number of bins used for splitting features
- * @param quantileCalculationStrategy algorithm for calculating quantiles
- * @param categoricalFeaturesInfo Map storing arity of categorical features.
- * E.g., an entry (n -> k) indicates that feature n is categorical
- * with k categories indexed from 0: {0, 1, ..., k-1}.
- * @return DecisionTreeModel that can be used for prediction
+ * @param algo Type of decision tree, either classification or regression.
+ * @param impurity Criterion used for information gain calculation.
+ * @param maxDepth Maximum depth of the tree (e.g. depth 0 means 1 leaf node, depth 1 means
+ * 1 internal node + 2 leaf nodes).
+ * @param numClasses Number of classes for classification. Default value of 2.
+ * @param maxBins Maximum number of bins used for splitting features.
+ * @param quantileCalculationStrategy Algorithm for calculating quantiles.
+ * @param categoricalFeaturesInfo Map storing arity of categorical features. An entry (n -> k)
+ * indicates that feature n is categorical with k categories
+ * indexed from 0: {0, 1, ..., k-1}.
+ * @return DecisionTreeModel that can be used for prediction.
*/
@Since("1.0.0")
def train(
@@ -185,18 +186,18 @@ object DecisionTree extends Serializable with Logging {
*
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* Labels should take values {0, 1, ..., numClasses-1}.
- * @param numClasses number of classes for classification.
- * @param categoricalFeaturesInfo Map storing arity of categorical features.
- * E.g., an entry (n -> k) indicates that feature n is categorical
- * with k categories indexed from 0: {0, 1, ..., k-1}.
+ * @param numClasses Number of classes for classification.
+ * @param categoricalFeaturesInfo Map storing arity of categorical features. An entry (n -> k)
+ * indicates that feature n is categorical with k categories
+ * indexed from 0: {0, 1, ..., k-1}.
* @param impurity Criterion used for information gain calculation.
* Supported values: "gini" (recommended) or "entropy".
- * @param maxDepth Maximum depth of the tree.
- * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
- * (suggested value: 5)
- * @param maxBins maximum number of bins used for splitting features
- * (suggested value: 32)
- * @return DecisionTreeModel that can be used for prediction
+ * @param maxDepth Maximum depth of the tree (e.g. depth 0 means 1 leaf node, depth 1 means
+ * 1 internal node + 2 leaf nodes).
+ * (suggested value: 5)
+ * @param maxBins Maximum number of bins used for splitting features.
+ * (suggested value: 32)
+ * @return DecisionTreeModel that can be used for prediction.
*/
@Since("1.1.0")
def trainClassifier(
@@ -232,17 +233,17 @@ object DecisionTree extends Serializable with Logging {
*
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* Labels are real numbers.
- * @param categoricalFeaturesInfo Map storing arity of categorical features.
- * E.g., an entry (n -> k) indicates that feature n is categorical
- * with k categories indexed from 0: {0, 1, ..., k-1}.
+ * @param categoricalFeaturesInfo Map storing arity of categorical features. An entry (n -> k)
+ * indicates that feature n is categorical with k categories
+ * indexed from 0: {0, 1, ..., k-1}.
* @param impurity Criterion used for information gain calculation.
- * Supported values: "variance".
- * @param maxDepth Maximum depth of the tree.
- * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
- * (suggested value: 5)
- * @param maxBins maximum number of bins used for splitting features
- * (suggested value: 32)
- * @return DecisionTreeModel that can be used for prediction
+ * The only supported value for regression is "variance".
+ * @param maxDepth Maximum depth of the tree (e.g. depth 0 means 1 leaf node, depth 1 means
+ * 1 internal node + 2 leaf nodes).
+ * (suggested value: 5)
+ * @param maxBins Maximum number of bins used for splitting features.
+ * (suggested value: 32)
+ * @return DecisionTreeModel that can be used for prediction.
*/
@Since("1.1.0")
def trainRegressor(
@@ -277,7 +278,7 @@ object DecisionTree extends Serializable with Logging {
*
* @param node Node in tree from which to classify the given data point.
* @param binnedFeatures Binned feature vector for data point.
- * @param bins possible bins for all features, indexed (numFeatures)(numBins)
+ * @param bins Possible bins for all features, indexed (numFeatures)(numBins).
* @param unorderedFeatures Set of indices of unordered features.
* @return Leaf index if the data point reaches a leaf.
* Otherwise, last node reachable in tree matching this example.
@@ -333,12 +334,12 @@ object DecisionTree extends Serializable with Logging {
* For unordered features, bins correspond to subsets of categories; either the left or right bin
* for each subset is updated.
*
- * @param agg Array storing aggregate calculation, with a set of sufficient statistics for
- * each (feature, bin).
- * @param treePoint Data point being aggregated.
- * @param splits possible splits indexed (numFeatures)(numSplits)
- * @param unorderedFeatures Set of indices of unordered features.
- * @param instanceWeight Weight (importance) of instance in dataset.
+ * @param agg Array storing aggregate calculation, with a set of sufficient statistics for
+ * each (feature, bin).
+ * @param treePoint Data point being aggregated.
+ * @param splits Possible splits indexed (numFeatures)(numSplits).
+ * @param unorderedFeatures Set of indices of unordered features.
+ * @param instanceWeight Weight (importance) of instance in dataset.
*/
private def mixedBinSeqOp(
agg: DTStatsAggregator,
@@ -394,10 +395,10 @@ object DecisionTree extends Serializable with Logging {
*
* For each feature, the sufficient statistics of one bin are updated.
*
- * @param agg Array storing aggregate calculation, with a set of sufficient statistics for
- * each (feature, bin).
- * @param treePoint Data point being aggregated.
- * @param instanceWeight Weight (importance) of instance in dataset.
+ * @param agg Array storing aggregate calculation, with a set of sufficient statistics for
+ * each (feature, bin).
+ * @param treePoint Data point being aggregated.
+ * @param instanceWeight Weight (importance) of instance in dataset.
*/
private def orderedBinSeqOp(
agg: DTStatsAggregator,
@@ -430,17 +431,17 @@ object DecisionTree extends Serializable with Logging {
/**
* Given a group of nodes, this finds the best split for each node.
*
- * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]]
- * @param metadata Learning and dataset metadata
+ * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]].
+ * @param metadata Learning and dataset metadata.
* @param topNodes Root node for each tree. Used for matching instances with nodes.
- * @param nodesForGroup Mapping: treeIndex --> nodes to be split in tree
+ * @param nodesForGroup Mapping: treeIndex --> nodes to be split in tree.
* @param treeToNodeToIndexInfo Mapping: treeIndex --> nodeIndex --> nodeIndexInfo,
* where nodeIndexInfo stores the index in the group and the
* feature subsets (if using feature subsets).
- * @param splits possible splits for all features, indexed (numFeatures)(numSplits)
- * @param bins possible bins for all features, indexed (numFeatures)(numBins)
- * @param nodeQueue Queue of nodes to split, with values (treeIndex, node).
- * Updated with new non-leaf nodes which are created.
+ * @param splits Possible splits for all features, indexed (numFeatures)(numSplits).
+ * @param bins Possible bins for all features, indexed (numFeatures)(numBins).
+ * @param nodeQueue Queue of nodes to split, with values (treeIndex, node).
+ * Updated with new non-leaf nodes which are created.
* @param nodeIdCache Node Id cache containing an RDD of Array[Int] where
* each value in the array is the data point's node Id
* for a corresponding tree. This is used to prevent the need
@@ -527,10 +528,10 @@ object DecisionTree extends Serializable with Logging {
* Each data point contributes to one node. For each feature,
* the aggregate sufficient statistics are updated for the relevant bins.
*
- * @param agg Array storing aggregate calculation, with a set of sufficient statistics for
- * each (node, feature, bin).
- * @param baggedPoint Data point being aggregated.
- * @return agg
+ * @param agg Array storing aggregate calculation, with a set of sufficient statistics for
+ * each (node, feature, bin).
+ * @param baggedPoint Data point being aggregated.
+ * @return Array of decision tree statistics.
*/
def binSeqOp(
agg: Array[DTStatsAggregator],
@@ -563,6 +564,7 @@ object DecisionTree extends Serializable with Logging {
/**
* Get node index in group --> features indices map,
* which is a short cut to find feature indices for a node given node index in group
+ *
* @param treeToNodeToIndexInfo
* @return
*/
@@ -719,9 +721,10 @@ object DecisionTree extends Serializable with Logging {
/**
* Calculate the information gain for a given (feature, split) based upon left/right aggregates.
- * @param leftImpurityCalculator left node aggregates for this (feature, split)
- * @param rightImpurityCalculator right node aggregate for this (feature, split)
- * @return information gain and statistics for split
+ *
+ * @param leftImpurityCalculator Left node aggregates for this (feature, split).
+ * @param rightImpurityCalculator Right node aggregate for this (feature, split).
+ * @return Information gain and statistics for split.
*/
private def calculateGainForSplit(
leftImpurityCalculator: ImpurityCalculator,
@@ -771,9 +774,10 @@ object DecisionTree extends Serializable with Logging {
/**
* Calculate predict value for current node, given stats of any split.
* Note that this function is called only once for each node.
- * @param leftImpurityCalculator left node aggregates for a split
- * @param rightImpurityCalculator right node aggregates for a split
- * @return predict value and impurity for current node
+ *
+ * @param leftImpurityCalculator Left node aggregates for a split.
+ * @param rightImpurityCalculator Right node aggregates for a split.
+ * @return Predict value and impurity for current node.
*/
private def calculatePredictImpurity(
leftImpurityCalculator: ImpurityCalculator,
@@ -788,8 +792,9 @@ object DecisionTree extends Serializable with Logging {
/**
* Find the best split for a node.
+ *
* @param binAggregates Bin statistics.
- * @return tuple for best split: (Split, information gain, prediction at node)
+ * @return Tuple for best split: (Split, information gain, prediction at node).
*/
private[tree] def binsToBestSplit(
binAggregates: DTStatsAggregator,
@@ -955,8 +960,8 @@ object DecisionTree extends Serializable with Logging {
* and for multiclass classification with a high-arity feature,
* there is one bin per category.
*
- * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
- * @param metadata Learning and dataset metadata
+ * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+ * @param metadata Learning and dataset metadata.
* @return A tuple of (splits, bins).
* Splits is an Array of [[org.apache.spark.mllib.tree.model.Split]]
* of size (numFeatures, numSplits).
@@ -1102,12 +1107,13 @@ object DecisionTree extends Serializable with Logging {
* NOTE: Returned number of splits is set based on `featureSamples` and
* could be different from the specified `numSplits`.
* The `numSplits` attribute in the `DecisionTreeMetadata` class will be set accordingly.
- * @param featureSamples feature values of each sample
- * @param metadata decision tree metadata
+ *
+ * @param featureSamples Feature values of each sample.
+ * @param metadata Decision tree metadata.
* NOTE: `metadata.numbins` will be changed accordingly
- * if there are not enough splits to be found
- * @param featureIndex feature index to find splits
- * @return array of splits
+ * if there are not enough splits to be found.
+ * @param featureIndex Feature index to find splits.
+ * @return Array of splits.
*/
private[tree] def findSplitsForContinuousFeature(
featureSamples: Array[Double],
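To make the documented signatures concrete, here is a minimal Scala sketch of the `DecisionTree.trainClassifier` overload described above. It assumes an active `SparkContext` named `sc`; the data path and parameter values are illustrative, with maxDepth and maxBins set to the suggested values from the Scaladoc.

    import org.apache.spark.mllib.tree.DecisionTree
    import org.apache.spark.mllib.util.MLUtils

    // Assumes an active SparkContext `sc`; the data path is illustrative.
    val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")

    val model = DecisionTree.trainClassifier(
      data,
      2,                // numClasses
      Map[Int, Int](),  // categoricalFeaturesInfo: empty map = all features continuous
      "gini",           // impurity
      5,                // maxDepth (suggested value from the docs above)
      32)               // maxBins (suggested value from the docs above)

    // The returned DecisionTreeModel can be used for prediction.
    val prediction = model.predict(data.first().features)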
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
index 1b71256c58..d131f5da6c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
@@ -54,8 +54,9 @@ class GradientBoostedTrees @Since("1.2.0") (private val boostingStrategy: Boosti
/**
* Method to train a gradient boosting model
+ *
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
- * @return a gradient boosted trees model that can be used for prediction
+ * @return GradientBoostedTreesModel that can be used for prediction.
*/
@Since("1.2.0")
def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = {
@@ -82,13 +83,14 @@ class GradientBoostedTrees @Since("1.2.0") (private val boostingStrategy: Boosti
/**
* Method to validate a gradient boosting model
+ *
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* @param validationInput Validation dataset.
* This dataset should be different from the training dataset,
* but it should follow the same distribution.
* E.g., these two datasets could be created from an original dataset
* by using [[org.apache.spark.rdd.RDD.randomSplit()]]
- * @return a gradient boosted trees model that can be used for prediction
+ * @return GradientBoostedTreesModel that can be used for prediction.
*/
@Since("1.4.0")
def runWithValidation(
@@ -132,7 +134,7 @@ object GradientBoostedTrees extends Logging {
* For classification, labels should take values {0, 1, ..., numClasses-1}.
* For regression, labels are real numbers.
* @param boostingStrategy Configuration options for the boosting algorithm.
- * @return a gradient boosted trees model that can be used for prediction
+ * @return GradientBoostedTreesModel that can be used for prediction.
*/
@Since("1.2.0")
def train(
@@ -153,11 +155,11 @@ object GradientBoostedTrees extends Logging {
/**
* Internal method for performing regression using trees as base learners.
- * @param input training dataset
- * @param validationInput validation dataset, ignored if validate is set to false.
- * @param boostingStrategy boosting parameters
- * @param validate whether or not to use the validation dataset.
- * @return a gradient boosted trees model that can be used for prediction
+ * @param input Training dataset.
+ * @param validationInput Validation dataset, ignored if validate is set to false.
+ * @param boostingStrategy Boosting parameters.
+ * @param validate Whether or not to use the validation dataset.
+ * @return GradientBoostedTreesModel that can be used for prediction.
*/
private def boost(
input: RDD[LabeledPoint],
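A minimal sketch of `runWithValidation` as documented above, splitting one dataset into training and validation sets with `RDD.randomSplit` as the Scaladoc suggests. The `sc` SparkContext, split weights, and boosting parameter values are illustrative assumptions.

    import org.apache.spark.mllib.tree.GradientBoostedTrees
    import org.apache.spark.mllib.tree.configuration.BoostingStrategy
    import org.apache.spark.mllib.util.MLUtils

    val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
    // Training and validation sets drawn from the same distribution,
    // created via randomSplit as the runWithValidation docs recommend.
    val Array(training, validation) = data.randomSplit(Array(0.8, 0.2))

    val boostingStrategy = BoostingStrategy.defaultParams("Classification")
    boostingStrategy.numIterations = 10
    boostingStrategy.treeStrategy.maxDepth = 3

    // Returns a GradientBoostedTreesModel that can be used for prediction.
    val model = new GradientBoostedTrees(boostingStrategy)
      .runWithValidation(training, validation)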
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
index 570a76f960..b7714b382a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -53,12 +53,12 @@ import org.apache.spark.util.random.SamplingUtils
* random forests]]
*
* @param strategy The configuration parameters for the random forest algorithm which specify
- * the type of algorithm (classification, regression, etc.), feature type
+ * the type of random forest (classification or regression), feature type
* (continuous, categorical), depth of the tree, quantile calculation strategy,
* etc.
* @param numTrees If 1, then no bootstrapping is used. If > 1, then bootstrapping is done.
* @param featureSubsetStrategy Number of features to consider for splits at each node.
- * Supported: "auto", "all", "sqrt", "log2", "onethird".
+ * Supported values: "auto", "all", "sqrt", "log2", "onethird".
* If "auto" is set, this parameter is set based on numTrees:
* if numTrees == 1, set to "all";
* if numTrees > 1 (forest) set to "sqrt" for classification and
@@ -121,8 +121,9 @@ private class RandomForest (
/**
* Method to train a decision tree model over an RDD
- * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
- * @return a random forest model that can be used for prediction
+ *
+ * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+ * @return RandomForestModel that can be used for prediction.
*/
def run(input: RDD[LabeledPoint]): RandomForestModel = {
@@ -269,12 +270,12 @@ object RandomForest extends Serializable with Logging {
* @param strategy Parameters for training each tree in the forest.
* @param numTrees Number of trees in the random forest.
* @param featureSubsetStrategy Number of features to consider for splits at each node.
- * Supported: "auto", "all", "sqrt", "log2", "onethird".
+ * Supported values: "auto", "all", "sqrt", "log2", "onethird".
* If "auto" is set, this parameter is set based on numTrees:
* if numTrees == 1, set to "all";
* if numTrees > 1 (forest) set to "sqrt".
- * @param seed Random seed for bootstrapping and choosing feature subsets.
- * @return a random forest model that can be used for prediction
+ * @param seed Random seed for bootstrapping and choosing feature subsets.
+ * @return RandomForestModel that can be used for prediction.
*/
@Since("1.2.0")
def trainClassifier(
@@ -294,25 +295,25 @@ object RandomForest extends Serializable with Logging {
*
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* Labels should take values {0, 1, ..., numClasses-1}.
- * @param numClasses number of classes for classification.
- * @param categoricalFeaturesInfo Map storing arity of categorical features.
- * E.g., an entry (n -> k) indicates that feature n is categorical
- * with k categories indexed from 0: {0, 1, ..., k-1}.
+ * @param numClasses Number of classes for classification.
+ * @param categoricalFeaturesInfo Map storing arity of categorical features. An entry (n -> k)
+ * indicates that feature n is categorical with k categories
+ * indexed from 0: {0, 1, ..., k-1}.
* @param numTrees Number of trees in the random forest.
* @param featureSubsetStrategy Number of features to consider for splits at each node.
- * Supported: "auto", "all", "sqrt", "log2", "onethird".
+ * Supported values: "auto", "all", "sqrt", "log2", "onethird".
* If "auto" is set, this parameter is set based on numTrees:
* if numTrees == 1, set to "all";
* if numTrees > 1 (forest) set to "sqrt".
* @param impurity Criterion used for information gain calculation.
* Supported values: "gini" (recommended) or "entropy".
- * @param maxDepth Maximum depth of the tree.
- * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
- * (suggested value: 4)
- * @param maxBins maximum number of bins used for splitting features
- * (suggested value: 100)
- * @param seed Random seed for bootstrapping and choosing feature subsets.
- * @return a random forest model that can be used for prediction
+ * @param maxDepth Maximum depth of the tree (e.g. depth 0 means 1 leaf node, depth 1 means
+ * 1 internal node + 2 leaf nodes).
+ * (suggested value: 4)
+ * @param maxBins Maximum number of bins used for splitting features.
+ * (suggested value: 100)
+ * @param seed Random seed for bootstrapping and choosing feature subsets.
+ * @return RandomForestModel that can be used for prediction.
*/
@Since("1.2.0")
def trainClassifier(
@@ -358,12 +359,12 @@ object RandomForest extends Serializable with Logging {
* @param strategy Parameters for training each tree in the forest.
* @param numTrees Number of trees in the random forest.
* @param featureSubsetStrategy Number of features to consider for splits at each node.
- * Supported: "auto", "all", "sqrt", "log2", "onethird".
+ * Supported values: "auto", "all", "sqrt", "log2", "onethird".
* If "auto" is set, this parameter is set based on numTrees:
* if numTrees == 1, set to "all";
* if numTrees > 1 (forest) set to "onethird".
- * @param seed Random seed for bootstrapping and choosing feature subsets.
- * @return a random forest model that can be used for prediction
+ * @param seed Random seed for bootstrapping and choosing feature subsets.
+ * @return RandomForestModel that can be used for prediction.
*/
@Since("1.2.0")
def trainRegressor(
@@ -383,24 +384,24 @@ object RandomForest extends Serializable with Logging {
*
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* Labels are real numbers.
- * @param categoricalFeaturesInfo Map storing arity of categorical features.
- * E.g., an entry (n -> k) indicates that feature n is categorical
- * with k categories indexed from 0: {0, 1, ..., k-1}.
+ * @param categoricalFeaturesInfo Map storing arity of categorical features. An entry (n -> k)
+ * indicates that feature n is categorical with k categories
+ * indexed from 0: {0, 1, ..., k-1}.
* @param numTrees Number of trees in the random forest.
* @param featureSubsetStrategy Number of features to consider for splits at each node.
- * Supported: "auto", "all", "sqrt", "log2", "onethird".
+ * Supported values: "auto", "all", "sqrt", "log2", "onethird".
* If "auto" is set, this parameter is set based on numTrees:
* if numTrees == 1, set to "all";
* if numTrees > 1 (forest) set to "onethird".
* @param impurity Criterion used for information gain calculation.
- * Supported values: "variance".
- * @param maxDepth Maximum depth of the tree.
- * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
- * (suggested value: 4)
- * @param maxBins maximum number of bins used for splitting features
- * (suggested value: 100)
- * @param seed Random seed for bootstrapping and choosing feature subsets.
- * @return a random forest model that can be used for prediction
+ * The only supported value for regression is "variance".
+ * @param maxDepth Maximum depth of the tree (e.g. depth 0 means 1 leaf node, depth 1 means
+ * 1 internal node + 2 leaf nodes).
+ * (suggested value: 4)
+ * @param maxBins Maximum number of bins used for splitting features.
+ * (suggested value: 100)
+ * @param seed Random seed for bootstrapping and choosing feature subsets.
+ * @return RandomForestModel that can be used for prediction.
*/
@Since("1.2.0")
def trainRegressor(
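A minimal sketch of the `RandomForest.trainRegressor` overload documented above, again assuming a `SparkContext` named `sc`; the seed and tree count are arbitrary illustrative choices, with maxDepth and maxBins set to the suggested values.

    import org.apache.spark.mllib.tree.RandomForest
    import org.apache.spark.mllib.util.MLUtils

    val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")

    val model = RandomForest.trainRegressor(
      data,
      Map[Int, Int](),  // categoricalFeaturesInfo: empty map = all features continuous
      10,               // numTrees: > 1, so bootstrapping is used
      "auto",           // featureSubsetStrategy: resolves to "onethird" for a regression forest
      "variance",       // impurity: the only supported value for regression
      4,                // maxDepth (suggested value)
      100,              // maxBins (suggested value)
      seed = 12345)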
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index 6c04403f1a..9e3e50192d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -34,8 +34,8 @@ import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance}
* Supported for Classification: [[org.apache.spark.mllib.tree.impurity.Gini]],
* [[org.apache.spark.mllib.tree.impurity.Entropy]].
* Supported for Regression: [[org.apache.spark.mllib.tree.impurity.Variance]].
- * @param maxDepth Maximum depth of the tree.
- * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
+ * @param maxDepth Maximum depth of the tree (e.g. depth 0 means 1 leaf node, depth 1 means
+ * 1 internal node + 2 leaf nodes).
* @param numClasses Number of classes for classification.
* (Ignored for regression.)
* Default value is 2 (binary classification).
@@ -45,10 +45,9 @@ import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance}
* @param quantileCalculationStrategy Algorithm for calculating quantiles. Supported:
* [[org.apache.spark.mllib.tree.configuration.QuantileStrategy.Sort]]
* @param categoricalFeaturesInfo A map storing information about the categorical variables and the
- * number of discrete values they take. For example, an entry (n ->
- * k) implies the feature n is categorical with k categories 0,
- * 1, 2, ... , k-1. It's important to note that features are
- * zero-indexed.
+ * number of discrete values they take. An entry (n -> k)
+ * indicates that feature n is categorical with k categories
+ * indexed from 0: {0, 1, ..., k-1}.
* @param minInstancesPerNode Minimum number of instances each child must have after split.
* Default value is 1. If a split cause left or right child
* to have less than minInstancesPerNode,
@@ -60,7 +59,7 @@ import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance}
* 256 MB.
* @param subsamplingRate Fraction of the training data used for learning decision tree.
* @param useNodeIdCache If this is true, instead of passing trees to executors, the algorithm will
- * maintain a separate RDD of node Id cache for each row.
+ * maintain a separate RDD of node Id cache for each row.
* @param checkpointInterval How often to checkpoint when the node Id cache gets updated.
* E.g. 10 means that the cache will get checkpointed every 10 updates. If
* the checkpoint directory is not set in
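For reference, a `Strategy` can also be constructed directly with the parameters documented in this file. A minimal sketch under illustrative values; the entry `3 -> 4` follows the categoricalFeaturesInfo convention described above (feature 3 is categorical with categories {0, 1, 2, 3}):

    import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
    import org.apache.spark.mllib.tree.impurity.Gini

    // Entry (3 -> 4): feature 3 is categorical with 4 categories indexed from 0.
    val strategy = new Strategy(
      algo = Algo.Classification,
      impurity = Gini,
      maxDepth = 5,
      numClasses = 2,
      maxBins = 32,
      categoricalFeaturesInfo = Map(3 -> 4))

The resulting strategy is what the `DecisionTree` constructor documented earlier consumes, i.e. `new DecisionTree(strategy).run(input)`.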