diff options
Diffstat (limited to 'mllib/src/main')
4 files changed, 140 insertions, 132 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 51235a2371..40440d50fc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -38,8 +38,9 @@ import org.apache.spark.util.random.XORShiftRandom /** * A class which implements a decision tree learning algorithm for classification and regression. * It supports both continuous and categorical features. + * * @param strategy The configuration parameters for the tree algorithm which specify the type - * of algorithm (classification, regression, etc.), feature type (continuous, + * of decision tree (classification or regression), feature type (continuous, * categorical), depth of the tree, quantile calculation strategy, etc. */ @Since("1.0.0") @@ -50,8 +51,8 @@ class DecisionTree @Since("1.0.0") (private val strategy: Strategy) /** * Method to train a decision tree model over an RDD - * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] - * @return DecisionTreeModel that can be used for prediction + * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * @return DecisionTreeModel that can be used for prediction. */ @Since("1.2.0") def run(input: RDD[LabeledPoint]): DecisionTreeModel = { @@ -77,9 +78,9 @@ object DecisionTree extends Serializable with Logging { * For classification, labels should take values {0, 1, ..., numClasses-1}. * For regression, labels are real numbers. * @param strategy The configuration parameters for the tree algorithm which specify the type - * of algorithm (classification, regression, etc.), feature type (continuous, + * of decision tree (classification or regression), feature type (continuous, * categorical), depth of the tree, quantile calculation strategy, etc. 
- * @return DecisionTreeModel that can be used for prediction + * @return DecisionTreeModel that can be used for prediction. */ @Since("1.0.0") def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = { @@ -97,11 +98,11 @@ object DecisionTree extends Serializable with Logging { * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * For classification, labels should take values {0, 1, ..., numClasses-1}. * For regression, labels are real numbers. - * @param algo algorithm, classification or regression - * @param impurity impurity criterion used for information gain calculation - * @param maxDepth Maximum depth of the tree. - * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. - * @return DecisionTreeModel that can be used for prediction + * @param algo Type of decision tree, either classification or regression. + * @param impurity Criterion used for information gain calculation. + * @param maxDepth Maximum depth of the tree (e.g. depth 0 means 1 leaf node, depth 1 means + * 1 internal node + 2 leaf nodes). + * @return DecisionTreeModel that can be used for prediction. */ @Since("1.0.0") def train( @@ -124,12 +125,12 @@ object DecisionTree extends Serializable with Logging { * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * For classification, labels should take values {0, 1, ..., numClasses-1}. * For regression, labels are real numbers. - * @param algo algorithm, classification or regression - * @param impurity impurity criterion used for information gain calculation - * @param maxDepth Maximum depth of the tree. - * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. - * @param numClasses number of classes for classification. Default value of 2. - * @return DecisionTreeModel that can be used for prediction + * @param algo Type of decision tree, either classification or regression. 
+ * @param impurity Criterion used for information gain calculation. + * @param maxDepth Maximum depth of the tree (e.g. depth 0 means 1 leaf node, depth 1 means + * 1 internal node + 2 leaf nodes). + * @param numClasses Number of classes for classification. Default value of 2. + * @return DecisionTreeModel that can be used for prediction. */ @Since("1.2.0") def train( @@ -153,17 +154,17 @@ object DecisionTree extends Serializable with Logging { * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * For classification, labels should take values {0, 1, ..., numClasses-1}. * For regression, labels are real numbers. - * @param algo classification or regression - * @param impurity criterion used for information gain calculation - * @param maxDepth Maximum depth of the tree. - * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. - * @param numClasses number of classes for classification. Default value of 2. - * @param maxBins maximum number of bins used for splitting features - * @param quantileCalculationStrategy algorithm for calculating quantiles - * @param categoricalFeaturesInfo Map storing arity of categorical features. - * E.g., an entry (n -> k) indicates that feature n is categorical - * with k categories indexed from 0: {0, 1, ..., k-1}. - * @return DecisionTreeModel that can be used for prediction + * @param algo Type of decision tree, either classification or regression. + * @param impurity Criterion used for information gain calculation. + * @param maxDepth Maximum depth of the tree (e.g. depth 0 means 1 leaf node, depth 1 means + * 1 internal node + 2 leaf nodes). + * @param numClasses Number of classes for classification. Default value of 2. + * @param maxBins Maximum number of bins used for splitting features. + * @param quantileCalculationStrategy Algorithm for calculating quantiles. + * @param categoricalFeaturesInfo Map storing arity of categorical features. 
An entry (n -> k) + * indicates that feature n is categorical with k categories + * indexed from 0: {0, 1, ..., k-1}. + * @return DecisionTreeModel that can be used for prediction. */ @Since("1.0.0") def train( @@ -185,18 +186,18 @@ object DecisionTree extends Serializable with Logging { * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * Labels should take values {0, 1, ..., numClasses-1}. - * @param numClasses number of classes for classification. - * @param categoricalFeaturesInfo Map storing arity of categorical features. - * E.g., an entry (n -> k) indicates that feature n is categorical - * with k categories indexed from 0: {0, 1, ..., k-1}. + * @param numClasses Number of classes for classification. + * @param categoricalFeaturesInfo Map storing arity of categorical features. An entry (n -> k) + * indicates that feature n is categorical with k categories + * indexed from 0: {0, 1, ..., k-1}. * @param impurity Criterion used for information gain calculation. * Supported values: "gini" (recommended) or "entropy". - * @param maxDepth Maximum depth of the tree. - * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. - * (suggested value: 5) - * @param maxBins maximum number of bins used for splitting features - * (suggested value: 32) - * @return DecisionTreeModel that can be used for prediction + * @param maxDepth Maximum depth of the tree (e.g. depth 0 means 1 leaf node, depth 1 means + * 1 internal node + 2 leaf nodes). + * (suggested value: 5) + * @param maxBins Maximum number of bins used for splitting features. + * (suggested value: 32) + * @return DecisionTreeModel that can be used for prediction. */ @Since("1.1.0") def trainClassifier( @@ -232,17 +233,17 @@ object DecisionTree extends Serializable with Logging { * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * Labels are real numbers. 
- * @param categoricalFeaturesInfo Map storing arity of categorical features. - * E.g., an entry (n -> k) indicates that feature n is categorical - * with k categories indexed from 0: {0, 1, ..., k-1}. + * @param categoricalFeaturesInfo Map storing arity of categorical features. An entry (n -> k) + * indicates that feature n is categorical with k categories + * indexed from 0: {0, 1, ..., k-1}. * @param impurity Criterion used for information gain calculation. - * Supported values: "variance". - * @param maxDepth Maximum depth of the tree. - * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. - * (suggested value: 5) - * @param maxBins maximum number of bins used for splitting features - * (suggested value: 32) - * @return DecisionTreeModel that can be used for prediction + * The only supported value for regression is "variance". + * @param maxDepth Maximum depth of the tree (e.g. depth 0 means 1 leaf node, depth 1 means + * 1 internal node + 2 leaf nodes). + * (suggested value: 5) + * @param maxBins Maximum number of bins used for splitting features. + * (suggested value: 32) + * @return DecisionTreeModel that can be used for prediction. */ @Since("1.1.0") def trainRegressor( @@ -277,7 +278,7 @@ object DecisionTree extends Serializable with Logging { * * @param node Node in tree from which to classify the given data point. * @param binnedFeatures Binned feature vector for data point. - * @param bins possible bins for all features, indexed (numFeatures)(numBins) + * @param bins Possible bins for all features, indexed (numFeatures)(numBins). * @param unorderedFeatures Set of indices of unordered features. * @return Leaf index if the data point reaches a leaf. * Otherwise, last node reachable in tree matching this example. @@ -333,12 +334,12 @@ object DecisionTree extends Serializable with Logging { * For unordered features, bins correspond to subsets of categories; either the left or right bin * for each subset is updated. 
* - * @param agg Array storing aggregate calculation, with a set of sufficient statistics for - * each (feature, bin). - * @param treePoint Data point being aggregated. - * @param splits possible splits indexed (numFeatures)(numSplits) - * @param unorderedFeatures Set of indices of unordered features. - * @param instanceWeight Weight (importance) of instance in dataset. + * @param agg Array storing aggregate calculation, with a set of sufficient statistics for + * each (feature, bin). + * @param treePoint Data point being aggregated. + * @param splits Possible splits indexed (numFeatures)(numSplits). + * @param unorderedFeatures Set of indices of unordered features. + * @param instanceWeight Weight (importance) of instance in dataset. */ private def mixedBinSeqOp( agg: DTStatsAggregator, @@ -394,10 +395,10 @@ object DecisionTree extends Serializable with Logging { * * For each feature, the sufficient statistics of one bin are updated. * - * @param agg Array storing aggregate calculation, with a set of sufficient statistics for - * each (feature, bin). - * @param treePoint Data point being aggregated. - * @param instanceWeight Weight (importance) of instance in dataset. + * @param agg Array storing aggregate calculation, with a set of sufficient statistics for + * each (feature, bin). + * @param treePoint Data point being aggregated. + * @param instanceWeight Weight (importance) of instance in dataset. */ private def orderedBinSeqOp( agg: DTStatsAggregator, @@ -430,17 +431,17 @@ object DecisionTree extends Serializable with Logging { /** * Given a group of nodes, this finds the best split for each node. * - * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]] - * @param metadata Learning and dataset metadata + * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]]. + * @param metadata Learning and dataset metadata. * @param topNodes Root node for each tree. Used for matching instances with nodes. 
- * @param nodesForGroup Mapping: treeIndex --> nodes to be split in tree + * @param nodesForGroup Mapping: treeIndex --> nodes to be split in tree. * @param treeToNodeToIndexInfo Mapping: treeIndex --> nodeIndex --> nodeIndexInfo, * where nodeIndexInfo stores the index in the group and the * feature subsets (if using feature subsets). - * @param splits possible splits for all features, indexed (numFeatures)(numSplits) - * @param bins possible bins for all features, indexed (numFeatures)(numBins) - * @param nodeQueue Queue of nodes to split, with values (treeIndex, node). - * Updated with new non-leaf nodes which are created. + * @param splits Possible splits for all features, indexed (numFeatures)(numSplits). + * @param bins Possible bins for all features, indexed (numFeatures)(numBins). + * @param nodeQueue Queue of nodes to split, with values (treeIndex, node). + * Updated with new non-leaf nodes which are created. * @param nodeIdCache Node Id cache containing an RDD of Array[Int] where * each value in the array is the data point's node Id * for a corresponding tree. This is used to prevent the need @@ -527,10 +528,10 @@ object DecisionTree extends Serializable with Logging { * Each data point contributes to one node. For each feature, * the aggregate sufficient statistics are updated for the relevant bins. * - * @param agg Array storing aggregate calculation, with a set of sufficient statistics for - * each (node, feature, bin). - * @param baggedPoint Data point being aggregated. - * @return agg + * @param agg Array storing aggregate calculation, with a set of sufficient statistics for + * each (node, feature, bin). + * @param baggedPoint Data point being aggregated. + * @return Array of decision tree statistics. 
*/ def binSeqOp( agg: Array[DTStatsAggregator], @@ -563,6 +564,7 @@ object DecisionTree extends Serializable with Logging { /** * Get node index in group --> features indices map, * which is a short cut to find feature indices for a node given node index in group + * * @param treeToNodeToIndexInfo * @return */ @@ -719,9 +721,10 @@ object DecisionTree extends Serializable with Logging { /** * Calculate the information gain for a given (feature, split) based upon left/right aggregates. - * @param leftImpurityCalculator left node aggregates for this (feature, split) - * @param rightImpurityCalculator right node aggregate for this (feature, split) - * @return information gain and statistics for split + * + * @param leftImpurityCalculator Left node aggregates for this (feature, split). + * @param rightImpurityCalculator Right node aggregates for this (feature, split). + * @return Information gain and statistics for split. */ private def calculateGainForSplit( leftImpurityCalculator: ImpurityCalculator, @@ -771,9 +774,10 @@ object DecisionTree extends Serializable with Logging { /** * Calculate predict value for current node, given stats of any split. * Note that this function is called only once for each node. - * @param leftImpurityCalculator left node aggregates for a split - * @param rightImpurityCalculator right node aggregates for a split - * @return predict value and impurity for current node + * + * @param leftImpurityCalculator Left node aggregates for a split. + * @param rightImpurityCalculator Right node aggregates for a split. + * @return Predict value and impurity for current node. */ private def calculatePredictImpurity( leftImpurityCalculator: ImpurityCalculator, @@ -788,8 +792,9 @@ object DecisionTree extends Serializable with Logging { /** * Find the best split for a node. + * * @param binAggregates Bin statistics.
- * @return tuple for best split: (Split, information gain, prediction at node) + * @return Tuple for best split: (Split, information gain, prediction at node). */ private[tree] def binsToBestSplit( binAggregates: DTStatsAggregator, @@ -955,8 +960,8 @@ object DecisionTree extends Serializable with Logging { * and for multiclass classification with a high-arity feature, * there is one bin per category. * - * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] - * @param metadata Learning and dataset metadata + * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * @param metadata Learning and dataset metadata. * @return A tuple of (splits, bins). * Splits is an Array of [[org.apache.spark.mllib.tree.model.Split]] * of size (numFeatures, numSplits). @@ -1102,12 +1107,13 @@ object DecisionTree extends Serializable with Logging { * NOTE: Returned number of splits is set based on `featureSamples` and * could be different from the specified `numSplits`. * The `numSplits` attribute in the `DecisionTreeMetadata` class will be set accordingly. - * @param featureSamples feature values of each sample - * @param metadata decision tree metadata + * + * @param featureSamples Feature values of each sample. + * @param metadata Decision tree metadata. * NOTE: `metadata.numbins` will be changed accordingly - * if there are not enough splits to be found - * @param featureIndex feature index to find splits - * @return array of splits + * if there are not enough splits to be found. + * @param featureIndex Feature index to find splits. + * @return Array of splits. 
*/ private[tree] def findSplitsForContinuousFeature( featureSamples: Array[Double], diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index 1b71256c58..d131f5da6c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -54,8 +54,9 @@ class GradientBoostedTrees @Since("1.2.0") (private val boostingStrategy: Boosti /** * Method to train a gradient boosting model + * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. - * @return a gradient boosted trees model that can be used for prediction + * @return GradientBoostedTreesModel that can be used for prediction. */ @Since("1.2.0") def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = { @@ -82,13 +83,14 @@ class GradientBoostedTrees @Since("1.2.0") (private val boostingStrategy: Boosti /** * Method to validate a gradient boosting model + * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * @param validationInput Validation dataset. * This dataset should be different from the training dataset, * but it should follow the same distribution. * E.g., these two datasets could be created from an original dataset * by using [[org.apache.spark.rdd.RDD.randomSplit()]] - * @return a gradient boosted trees model that can be used for prediction + * @return GradientBoostedTreesModel that can be used for prediction. */ @Since("1.4.0") def runWithValidation( @@ -132,7 +134,7 @@ object GradientBoostedTrees extends Logging { * For classification, labels should take values {0, 1, ..., numClasses-1}. * For regression, labels are real numbers. * @param boostingStrategy Configuration options for the boosting algorithm. 
- * @return a gradient boosted trees model that can be used for prediction + * @return GradientBoostedTreesModel that can be used for prediction. */ @Since("1.2.0") def train( @@ -153,11 +155,11 @@ object GradientBoostedTrees extends Logging { /** * Internal method for performing regression using trees as base learners. - * @param input training dataset - * @param validationInput validation dataset, ignored if validate is set to false. - * @param boostingStrategy boosting parameters - * @param validate whether or not to use the validation dataset. - * @return a gradient boosted trees model that can be used for prediction + * @param input Training dataset. + * @param validationInput Validation dataset, ignored if validate is set to false. + * @param boostingStrategy Boosting parameters. + * @param validate Whether or not to use the validation dataset. + * @return GradientBoostedTreesModel that can be used for prediction. */ private def boost( input: RDD[LabeledPoint], diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index 570a76f960..b7714b382a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -53,12 +53,12 @@ import org.apache.spark.util.random.SamplingUtils * random forests]] * * @param strategy The configuration parameters for the random forest algorithm which specify - * the type of algorithm (classification, regression, etc.), feature type + * the type of random forest (classification or regression), feature type * (continuous, categorical), depth of the tree, quantile calculation strategy, * etc. * @param numTrees If 1, then no bootstrapping is used. If > 1, then bootstrapping is done. * @param featureSubsetStrategy Number of features to consider for splits at each node. - * Supported: "auto", "all", "sqrt", "log2", "onethird". 
+ * Supported values: "auto", "all", "sqrt", "log2", "onethird". * If "auto" is set, this parameter is set based on numTrees: * if numTrees == 1, set to "all"; * if numTrees > 1 (forest) set to "sqrt" for classification and @@ -121,8 +121,9 @@ private class RandomForest ( /** * Method to train a decision tree model over an RDD - * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] - * @return a random forest model that can be used for prediction + * + * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * @return RandomForestModel that can be used for prediction. */ def run(input: RDD[LabeledPoint]): RandomForestModel = { @@ -269,12 +270,12 @@ object RandomForest extends Serializable with Logging { * @param strategy Parameters for training each tree in the forest. * @param numTrees Number of trees in the random forest. * @param featureSubsetStrategy Number of features to consider for splits at each node. - * Supported: "auto", "all", "sqrt", "log2", "onethird". + * Supported values: "auto", "all", "sqrt", "log2", "onethird". * If "auto" is set, this parameter is set based on numTrees: * if numTrees == 1, set to "all"; * if numTrees > 1 (forest) set to "sqrt". - * @param seed Random seed for bootstrapping and choosing feature subsets. - * @return a random forest model that can be used for prediction + * @param seed Random seed for bootstrapping and choosing feature subsets. + * @return RandomForestModel that can be used for prediction. */ @Since("1.2.0") def trainClassifier( @@ -294,25 +295,25 @@ object RandomForest extends Serializable with Logging { * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * Labels should take values {0, 1, ..., numClasses-1}. - * @param numClasses number of classes for classification. - * @param categoricalFeaturesInfo Map storing arity of categorical features. 
- * E.g., an entry (n -> k) indicates that feature n is categorical - * with k categories indexed from 0: {0, 1, ..., k-1}. + * @param numClasses Number of classes for classification. + * @param categoricalFeaturesInfo Map storing arity of categorical features. An entry (n -> k) + * indicates that feature n is categorical with k categories + * indexed from 0: {0, 1, ..., k-1}. * @param numTrees Number of trees in the random forest. * @param featureSubsetStrategy Number of features to consider for splits at each node. - * Supported: "auto", "all", "sqrt", "log2", "onethird". + * Supported values: "auto", "all", "sqrt", "log2", "onethird". * If "auto" is set, this parameter is set based on numTrees: * if numTrees == 1, set to "all"; * if numTrees > 1 (forest) set to "sqrt". * @param impurity Criterion used for information gain calculation. * Supported values: "gini" (recommended) or "entropy". - * @param maxDepth Maximum depth of the tree. - * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. - * (suggested value: 4) - * @param maxBins maximum number of bins used for splitting features - * (suggested value: 100) - * @param seed Random seed for bootstrapping and choosing feature subsets. - * @return a random forest model that can be used for prediction + * @param maxDepth Maximum depth of the tree (e.g. depth 0 means 1 leaf node, depth 1 means + * 1 internal node + 2 leaf nodes). + * (suggested value: 4) + * @param maxBins Maximum number of bins used for splitting features. + * (suggested value: 100) + * @param seed Random seed for bootstrapping and choosing feature subsets. + * @return RandomForestModel that can be used for prediction. */ @Since("1.2.0") def trainClassifier( @@ -358,12 +359,12 @@ object RandomForest extends Serializable with Logging { * @param strategy Parameters for training each tree in the forest. * @param numTrees Number of trees in the random forest.
* @param featureSubsetStrategy Number of features to consider for splits at each node. - * Supported: "auto", "all", "sqrt", "log2", "onethird". + * Supported values: "auto", "all", "sqrt", "log2", "onethird". * If "auto" is set, this parameter is set based on numTrees: * if numTrees == 1, set to "all"; * if numTrees > 1 (forest) set to "onethird". - * @param seed Random seed for bootstrapping and choosing feature subsets. - * @return a random forest model that can be used for prediction + * @param seed Random seed for bootstrapping and choosing feature subsets. + * @return RandomForestModel that can be used for prediction. */ @Since("1.2.0") def trainRegressor( @@ -383,24 +384,24 @@ object RandomForest extends Serializable with Logging { * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * Labels are real numbers. - * @param categoricalFeaturesInfo Map storing arity of categorical features. - * E.g., an entry (n -> k) indicates that feature n is categorical - * with k categories indexed from 0: {0, 1, ..., k-1}. + * @param categoricalFeaturesInfo Map storing arity of categorical features. An entry (n -> k) + * indicates that feature n is categorical with k categories + * indexed from 0: {0, 1, ..., k-1}. * @param numTrees Number of trees in the random forest. * @param featureSubsetStrategy Number of features to consider for splits at each node. - * Supported: "auto", "all", "sqrt", "log2", "onethird". + * Supported values: "auto", "all", "sqrt", "log2", "onethird". * If "auto" is set, this parameter is set based on numTrees: * if numTrees == 1, set to "all"; * if numTrees > 1 (forest) set to "onethird". * @param impurity Criterion used for information gain calculation. - * Supported values: "variance". - * @param maxDepth Maximum depth of the tree. - * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. 
- * (suggested value: 4) - * @param maxBins maximum number of bins used for splitting features - * (suggested value: 100) - * @param seed Random seed for bootstrapping and choosing feature subsets. - * @return a random forest model that can be used for prediction + * The only supported value for regression is "variance". + * @param maxDepth Maximum depth of the tree (e.g. depth 0 means 1 leaf node, depth 1 means + * 1 internal node + 2 leaf nodes). + * (suggested value: 4) + * @param maxBins Maximum number of bins used for splitting features. + * (suggested value: 100) + * @param seed Random seed for bootstrapping and choosing feature subsets. + * @return RandomForestModel that can be used for prediction. */ @Since("1.2.0") def trainRegressor( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 6c04403f1a..9e3e50192d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -34,8 +34,8 @@ import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance} * Supported for Classification: [[org.apache.spark.mllib.tree.impurity.Gini]], * [[org.apache.spark.mllib.tree.impurity.Entropy]]. * Supported for Regression: [[org.apache.spark.mllib.tree.impurity.Variance]]. - * @param maxDepth Maximum depth of the tree. - * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. + * @param maxDepth Maximum depth of the tree (e.g. depth 0 means 1 leaf node, depth 1 means + * 1 internal node + 2 leaf nodes). * @param numClasses Number of classes for classification. * (Ignored for regression.) * Default value is 2 (binary classification).
@@ -45,10 +45,9 @@ import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance} * @param quantileCalculationStrategy Algorithm for calculating quantiles. Supported: * [[org.apache.spark.mllib.tree.configuration.QuantileStrategy.Sort]] * @param categoricalFeaturesInfo A map storing information about the categorical variables and the - * number of discrete values they take. For example, an entry (n -> - * k) implies the feature n is categorical with k categories 0, - * 1, 2, ... , k-1. It's important to note that features are - * zero-indexed. + * number of discrete values they take. An entry (n -> k) + * indicates that feature n is categorical with k categories + * indexed from 0: {0, 1, ..., k-1}. * @param minInstancesPerNode Minimum number of instances each child must have after split. * Default value is 1. If a split cause left or right child * to have less than minInstancesPerNode, @@ -60,7 +59,7 @@ import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance} * 256 MB. * @param subsamplingRate Fraction of the training data used for learning decision tree. * @param useNodeIdCache If this is true, instead of passing trees to executors, the algorithm will - * maintain a separate RDD of node Id cache for each row. + * maintain a separate RDD of node Id cache for each row. * @param checkpointInterval How often to checkpoint when the node Id cache gets updated. * E.g. 10 means that the cache will get checkpointed every 10 updates. If * the checkpoint directory is not set in |