aboutsummaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
Diffstat (limited to 'examples')
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala76
1 files changed, 57 insertions, 19 deletions
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
index 4683e6eb96..96fb068e9e 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
@@ -21,16 +21,18 @@ import scopt.OptionParser
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._
+import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.{DecisionTree, impurity}
+import org.apache.spark.mllib.tree.{RandomForest, DecisionTree, impurity}
import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
import org.apache.spark.mllib.tree.configuration.Algo._
-import org.apache.spark.mllib.tree.model.DecisionTreeModel
+import org.apache.spark.mllib.tree.model.{RandomForestModel, DecisionTreeModel}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
+import org.apache.spark.util.Utils
/**
- * An example runner for decision tree. Run with
+ * An example runner for decision trees and random forests. Run with
* {{{
* ./bin/run-example org.apache.spark.examples.mllib.DecisionTreeRunner [options]
* }}}
@@ -57,6 +59,8 @@ object DecisionTreeRunner {
maxBins: Int = 32,
minInstancesPerNode: Int = 1,
minInfoGain: Double = 0.0,
+ numTrees: Int = 1,
+ featureSubsetStrategy: String = "auto",
fracTest: Double = 0.2)
def main(args: Array[String]) {
@@ -79,11 +83,20 @@ object DecisionTreeRunner {
.action((x, c) => c.copy(maxBins = x))
opt[Int]("minInstancesPerNode")
.text(s"min number of instances required at child nodes to create the parent split," +
- s" default: ${defaultParams.minInstancesPerNode}")
+ s" default: ${defaultParams.minInstancesPerNode}")
.action((x, c) => c.copy(minInstancesPerNode = x))
opt[Double]("minInfoGain")
.text(s"min info gain required to create a split, default: ${defaultParams.minInfoGain}")
.action((x, c) => c.copy(minInfoGain = x))
+ opt[Int]("numTrees")
+ .text(s"number of trees (1 = decision tree, 2+ = random forest)," +
+ s" default: ${defaultParams.numTrees}")
+ .action((x, c) => c.copy(numTrees = x))
+ opt[String]("featureSubsetStrategy")
+ .text(s"feature subset sampling strategy" +
+ s" (${RandomForest.supportedFeatureSubsetStrategies.mkString(", ")}), " +
+ s"default: ${defaultParams.featureSubsetStrategy}")
+ .action((x, c) => c.copy(featureSubsetStrategy = x))
opt[Double]("fracTest")
.text(s"fraction of data to hold out for testing, default: ${defaultParams.fracTest}")
.action((x, c) => c.copy(fracTest = x))
@@ -191,18 +204,35 @@ object DecisionTreeRunner {
numClassesForClassification = numClasses,
minInstancesPerNode = params.minInstancesPerNode,
minInfoGain = params.minInfoGain)
- val model = DecisionTree.train(training, strategy)
-
- println(model)
-
- if (params.algo == Classification) {
- val accuracy = accuracyScore(model, test)
- println(s"Test accuracy = $accuracy")
- }
-
- if (params.algo == Regression) {
- val mse = meanSquaredError(model, test)
- println(s"Test mean squared error = $mse")
+ if (params.numTrees == 1) {
+ val model = DecisionTree.train(training, strategy)
+ println(model)
+ if (params.algo == Classification) {
+ val accuracy =
+ new MulticlassMetrics(test.map(lp => (model.predict(lp.features), lp.label))).precision
+ println(s"Test accuracy = $accuracy")
+ }
+ if (params.algo == Regression) {
+ val mse = meanSquaredError(model, test)
+ println(s"Test mean squared error = $mse")
+ }
+ } else {
+ val randomSeed = Utils.random.nextInt()
+ if (params.algo == Classification) {
+ val model = RandomForest.trainClassifier(training, strategy, params.numTrees,
+ params.featureSubsetStrategy, randomSeed)
+ println(model)
+ val accuracy =
+ new MulticlassMetrics(test.map(lp => (model.predict(lp.features), lp.label))).precision
+ println(s"Test accuracy = $accuracy")
+ }
+ if (params.algo == Regression) {
+ val model = RandomForest.trainRegressor(training, strategy, params.numTrees,
+ params.featureSubsetStrategy, randomSeed)
+ println(model)
+ val mse = meanSquaredError(model, test)
+ println(s"Test mean squared error = $mse")
+ }
}
sc.stop()
@@ -211,9 +241,7 @@ object DecisionTreeRunner {
/**
* Calculates the classifier accuracy.
*/
- private def accuracyScore(
- model: DecisionTreeModel,
- data: RDD[LabeledPoint]): Double = {
+ private def accuracyScore(model: DecisionTreeModel, data: RDD[LabeledPoint]): Double = {
val correctCount = data.filter(y => model.predict(y.features) == y.label).count()
val count = data.count()
correctCount.toDouble / count
@@ -228,4 +256,14 @@ object DecisionTreeRunner {
err * err
}.mean()
}
+
+ /**
+ * Calculates the mean squared error for regression with a random forest model.
+ */
+ private def meanSquaredError(tree: RandomForestModel, data: RDD[LabeledPoint]): Double = {
+ data.map { y =>
+ val err = tree.predict(y.features) - y.label
+ err * err
+ }.mean()
+ }
}