about | summary | refs | log | tree | commit | diff
path: root/mllib
diff options
context:
space:
mode:
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala | 38
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala | 12
2 files changed, 40 insertions, 10 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 6f94b7f483..b6f7618171 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -40,10 +40,10 @@ import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics}
import org.apache.spark.mllib.stat.correlation.CorrelationNames
import org.apache.spark.mllib.stat.test.ChiSqTestResult
-import org.apache.spark.mllib.tree.DecisionTree
+import org.apache.spark.mllib.tree.{RandomForest, DecisionTree}
import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
import org.apache.spark.mllib.tree.impurity._
-import org.apache.spark.mllib.tree.model.DecisionTreeModel
+import org.apache.spark.mllib.tree.model.{RandomForestModel, DecisionTreeModel}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
@@ -500,6 +500,40 @@ class PythonMLLibAPI extends Serializable {
}
/**
+ * Java stub for Python mllib RandomForest.train().
+ * This stub returns a handle to the Java object instead of the content of the Java object.
+ * Extra care needs to be taken in the Python code to ensure it gets freed on exit;
+ * see the Py4J documentation.
+ */
+ def trainRandomForestModel(
+ data: JavaRDD[LabeledPoint],
+ algoStr: String,
+ numClasses: Int,
+ categoricalFeaturesInfo: JMap[Int, Int],
+ numTrees: Int,
+ featureSubsetStrategy: String,
+ impurityStr: String,
+ maxDepth: Int,
+ maxBins: Int,
+ seed: Int): RandomForestModel = {
+ // Parse the algorithm and impurity names, then bundle the tree settings into a Strategy.
+ val parsedAlgo = Algo.fromString(algoStr)
+ val config = new Strategy(
+ algo = parsedAlgo,
+ impurity = Impurities.fromString(impurityStr),
+ maxDepth = maxDepth,
+ numClassesForClassification = numClasses,
+ maxBins = maxBins,
+ categoricalFeaturesInfo = categoricalFeaturesInfo.asScala.toMap)
+ parsedAlgo match {
+ case Algo.Classification =>
+ RandomForest.trainClassifier(data.rdd, config, numTrees, featureSubsetStrategy, seed)
+ case _ => // every algo other than Classification is trained as a regressor
+ RandomForest.trainRegressor(data.rdd, config, numTrees, featureSubsetStrategy, seed)
+ }
+ }
+
+ /**
* Java stub for mllib Statistics.colStats(X: RDD[Vector]).
* TODO figure out return type.
*/
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
index ca0b6eea9a..3ae6fa2a0e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -230,8 +230,7 @@ object RandomForest extends Serializable with Logging {
* Supported: "auto" (default), "all", "sqrt", "log2", "onethird".
* If "auto" is set, this parameter is set based on numTrees:
* if numTrees == 1, set to "all";
- * if numTrees > 1 (forest) set to "sqrt" for classification and
- * to "onethird" for regression.
+ * if numTrees > 1 (forest) set to "sqrt".
* @param seed Random seed for bootstrapping and choosing feature subsets.
* @return a random forest model that can be used for prediction
*/
@@ -261,8 +260,7 @@ object RandomForest extends Serializable with Logging {
* Supported: "auto" (default), "all", "sqrt", "log2", "onethird".
* If "auto" is set, this parameter is set based on numTrees:
* if numTrees == 1, set to "all";
- * if numTrees > 1 (forest) set to "sqrt" for classification and
- * to "onethird" for regression.
+ * if numTrees > 1 (forest) set to "sqrt".
* @param impurity Criterion used for information gain calculation.
* Supported values: "gini" (recommended) or "entropy".
* @param maxDepth Maximum depth of the tree.
@@ -318,8 +316,7 @@ object RandomForest extends Serializable with Logging {
* Supported: "auto" (default), "all", "sqrt", "log2", "onethird".
* If "auto" is set, this parameter is set based on numTrees:
* if numTrees == 1, set to "all";
- * if numTrees > 1 (forest) set to "sqrt" for classification and
- * to "onethird" for regression.
+ * if numTrees > 1 (forest) set to "onethird".
* @param seed Random seed for bootstrapping and choosing feature subsets.
* @return a random forest model that can be used for prediction
*/
@@ -348,8 +345,7 @@ object RandomForest extends Serializable with Logging {
* Supported: "auto" (default), "all", "sqrt", "log2", "onethird".
* If "auto" is set, this parameter is set based on numTrees:
* if numTrees == 1, set to "all";
- * if numTrees > 1 (forest) set to "sqrt" for classification and
- * to "onethird" for regression.
+ * if numTrees > 1 (forest) set to "onethird".
* @param impurity Criterion used for information gain calculation.
* Supported values: "variance".
* @param maxDepth Maximum depth of the tree.