diff options
author | Davies Liu <davies@databricks.com> | 2014-11-20 15:31:28 -0800 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2014-11-20 15:31:28 -0800 |
commit | 1c53a5db993193122bfa79574d2540149fe2cc08 (patch) | |
tree | 68efb0022ee91fc29450473eede8ac1ce9bccb39 /mllib/src | |
parent | b8e6886fb8ff8f667fb7e600cd727d8649cad1d1 (diff) | |
download | spark-1c53a5db993193122bfa79574d2540149fe2cc08.tar.gz spark-1c53a5db993193122bfa79574d2540149fe2cc08.tar.bz2 spark-1c53a5db993193122bfa79574d2540149fe2cc08.zip |
[SPARK-4439] [MLlib] add python api for random forest
```
class RandomForestModel
| A model trained by RandomForest
|
| numTrees(self)
| Get number of trees in forest.
|
| predict(self, x)
| Predict values for a single data point or an RDD of points using the model trained.
|
| toDebugString(self)
| Full model
|
| totalNumNodes(self)
| Get total number of nodes, summed over all trees in the forest.
|
class RandomForest
| trainClassifier(cls, data, numClassesForClassification, categoricalFeaturesInfo, numTrees, featureSubsetStrategy='auto', impurity='gini', maxDepth=4, maxBins=32, seed=None):
| Method to train a decision tree model for binary or multiclass classification.
|
| :param data: Training dataset: RDD of LabeledPoint.
| Labels should take values {0, 1, ..., numClasses-1}.
| :param numClassesForClassification: number of classes for classification.
| :param categoricalFeaturesInfo: Map storing arity of categorical features.
| E.g., an entry (n -> k) indicates that feature n is categorical
| with k categories indexed from 0: {0, 1, ..., k-1}.
| :param numTrees: Number of trees in the random forest.
| :param featureSubsetStrategy: Number of features to consider for splits at each node.
| Supported: "auto" (default), "all", "sqrt", "log2", "onethird".
| If "auto" is set, this parameter is set based on numTrees:
| if numTrees == 1, set to "all";
| if numTrees > 1 (forest) set to "sqrt".
| :param impurity: Criterion used for information gain calculation.
| Supported values: "gini" (recommended) or "entropy".
| :param maxDepth: Maximum depth of the tree. E.g., depth 0 means 1 leaf node; depth 1 means
| 1 internal node + 2 leaf nodes. (default: 4)
| :param maxBins: maximum number of bins used for splitting features (default: 100)
| :param seed: Random seed for bootstrapping and choosing feature subsets.
| :return: RandomForestModel that can be used for prediction
|
| trainRegressor(cls, data, categoricalFeaturesInfo, numTrees, featureSubsetStrategy='auto', impurity='variance', maxDepth=4, maxBins=32, seed=None):
| Method to train a decision tree model for regression.
|
| :param data: Training dataset: RDD of LabeledPoint.
| Labels are real numbers.
| :param categoricalFeaturesInfo: Map storing arity of categorical features.
| E.g., an entry (n -> k) indicates that feature n is categorical
| with k categories indexed from 0: {0, 1, ..., k-1}.
| :param numTrees: Number of trees in the random forest.
| :param featureSubsetStrategy: Number of features to consider for splits at each node.
| Supported: "auto" (default), "all", "sqrt", "log2", "onethird".
| If "auto" is set, this parameter is set based on numTrees:
| if numTrees == 1, set to "all";
| if numTrees > 1 (forest) set to "onethird".
| :param impurity: Criterion used for information gain calculation.
| Supported values: "variance".
| :param maxDepth: Maximum depth of the tree. E.g., depth 0 means 1 leaf node; depth 1 means
| 1 internal node + 2 leaf nodes.(default: 4)
| :param maxBins: maximum number of bins used for splitting features (default: 100)
| :param seed: Random seed for bootstrapping and choosing feature subsets.
| :return: RandomForestModel that can be used for prediction
|
```
Author: Davies Liu <davies@databricks.com>
Closes #3320 from davies/forest and squashes the following commits:
8003dfc [Davies Liu] reorder
53cf510 [Davies Liu] fix docs
4ca593d [Davies Liu] fix docs
e0df852 [Davies Liu] fix docs
0431746 [Davies Liu] rebased
2b6f239 [Davies Liu] Merge branch 'master' of github.com:apache/spark into forest
885abee [Davies Liu] address comments
dae7fc0 [Davies Liu] address comments
89a000f [Davies Liu] fix docs
565d476 [Davies Liu] add python api for random forest
Diffstat (limited to 'mllib/src')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala | 38 | ||||
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala | 12 |
2 files changed, 40 insertions, 10 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 6f94b7f483..b6f7618171 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -40,10 +40,10 @@ import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics} import org.apache.spark.mllib.stat.correlation.CorrelationNames import org.apache.spark.mllib.stat.test.ChiSqTestResult -import org.apache.spark.mllib.tree.DecisionTree +import org.apache.spark.mllib.tree.{RandomForest, DecisionTree} import org.apache.spark.mllib.tree.configuration.{Algo, Strategy} import org.apache.spark.mllib.tree.impurity._ -import org.apache.spark.mllib.tree.model.DecisionTreeModel +import org.apache.spark.mllib.tree.model.{RandomForestModel, DecisionTreeModel} import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -500,6 +500,40 @@ class PythonMLLibAPI extends Serializable { } /** + * Java stub for Python mllib RandomForest.train(). + * This stub returns a handle to the Java object instead of the content of the Java object. + * Extra care needs to be taken in the Python code to ensure it gets freed on exit; + * see the Py4J documentation. + */ + def trainRandomForestModel( + data: JavaRDD[LabeledPoint], + algoStr: String, + numClasses: Int, + categoricalFeaturesInfo: JMap[Int, Int], + numTrees: Int, + featureSubsetStrategy: String, + impurityStr: String, + maxDepth: Int, + maxBins: Int, + seed: Int): RandomForestModel = { + + val algo = Algo.fromString(algoStr) + val impurity = Impurities.fromString(impurityStr) + val strategy = new Strategy( + algo = algo, + impurity = impurity, + maxDepth = maxDepth, + numClassesForClassification = numClasses, + maxBins = maxBins, + categoricalFeaturesInfo = categoricalFeaturesInfo.asScala.toMap) + if (algo == Algo.Classification) { + RandomForest.trainClassifier(data.rdd, strategy, numTrees, featureSubsetStrategy, seed) + } else { + RandomForest.trainRegressor(data.rdd, strategy, numTrees, featureSubsetStrategy, seed) + } + } + + /** * Java stub for mllib Statistics.colStats(X: RDD[Vector]). * TODO figure out return type. */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index ca0b6eea9a..3ae6fa2a0e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -230,8 +230,7 @@ object RandomForest extends Serializable with Logging { * Supported: "auto" (default), "all", "sqrt", "log2", "onethird". * If "auto" is set, this parameter is set based on numTrees: * if numTrees == 1, set to "all"; - * if numTrees > 1 (forest) set to "sqrt" for classification and - * to "onethird" for regression. + * if numTrees > 1 (forest) set to "sqrt". * @param seed Random seed for bootstrapping and choosing feature subsets. * @return a random forest model that can be used for prediction */ @@ -261,8 +260,7 @@ object RandomForest extends Serializable with Logging { * Supported: "auto" (default), "all", "sqrt", "log2", "onethird". * If "auto" is set, this parameter is set based on numTrees: * if numTrees == 1, set to "all"; - * if numTrees > 1 (forest) set to "sqrt" for classification and - * to "onethird" for regression. + * if numTrees > 1 (forest) set to "sqrt". * @param impurity Criterion used for information gain calculation. * Supported values: "gini" (recommended) or "entropy". * @param maxDepth Maximum depth of the tree. @@ -318,8 +316,7 @@ object RandomForest extends Serializable with Logging { * Supported: "auto" (default), "all", "sqrt", "log2", "onethird". * If "auto" is set, this parameter is set based on numTrees: * if numTrees == 1, set to "all"; - * if numTrees > 1 (forest) set to "sqrt" for classification and - * to "onethird" for regression. + * if numTrees > 1 (forest) set to "onethird". * @param seed Random seed for bootstrapping and choosing feature subsets. * @return a random forest model that can be used for prediction */ @@ -348,8 +345,7 @@ object RandomForest extends Serializable with Logging { * Supported: "auto" (default), "all", "sqrt", "log2", "onethird". * If "auto" is set, this parameter is set based on numTrees: * if numTrees == 1, set to "all"; - * if numTrees > 1 (forest) set to "sqrt" for classification and - * to "onethird" for regression. + * if numTrees > 1 (forest) set to "onethird". * @param impurity Criterion used for information gain calculation. * Supported values: "variance". * @param maxDepth Maximum depth of the tree. |