aboutsummaryrefslogtreecommitdiff
path: root/examples/src/main
diff options
context:
space:
mode:
authorJoseph K. Bradley <joseph@databricks.com>2015-04-25 12:27:19 -0700
committerXiangrui Meng <meng@databricks.com>2015-04-25 12:27:19 -0700
commita7160c4e3aae22600d05e257d0b4d2428754b8ea (patch)
tree55ce372698ad266b258dceb469e8a6a9ae9709b0 /examples/src/main
parenta61d65fc8b97c01be0fa756b52afdc91c46a8561 (diff)
downloadspark-a7160c4e3aae22600d05e257d0b4d2428754b8ea.tar.gz
spark-a7160c4e3aae22600d05e257d0b4d2428754b8ea.tar.bz2
spark-a7160c4e3aae22600d05e257d0b4d2428754b8ea.zip
[SPARK-6113] [ML] Tree ensembles for Pipelines API
This is a continuation of [https://github.com/apache/spark/pull/5530] (which was for Decision Trees), but for ensembles: Random Forests and Gradient-Boosted Trees. Please refer to the JIRA [https://issues.apache.org/jira/browse/SPARK-6113], the design doc linked from the JIRA, and the previous PR linked above for design discussions. This PR follows the example set by the previous PR for Decision Trees. It includes a few cleanups to Decision Trees. Note: There is one issue which will be addressed in a separate PR: Ensembles' component Models have no parent or fittingParamMap. I plan to submit a separate PR which makes those values in Model be Options. It does not matter much which PR gets merged first. CC: mengxr manishamde codedeft chouqin Author: Joseph K. Bradley <joseph@databricks.com> Closes #5626 from jkbradley/dt-api-ensembles and squashes the following commits: 729167a [Joseph K. Bradley] small cleanups based on code review bbae2a2 [Joseph K. Bradley] Updated per all comments in code review 855aa9a [Joseph K. Bradley] scala style fix ea3d901 [Joseph K. Bradley] Added GBT to spark.ml, with tests and examples c0f30c1 [Joseph K. Bradley] Added random forests and test suites to spark.ml. Not tested yet. Need to add example as well d045ebd [Joseph K. Bradley] some more updates, but far from done ee1a10b [Joseph K. Bradley] Added files from old PR and did some initial updates.
Diffstat (limited to 'examples/src/main')
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala139
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala238
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala248
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala1
4 files changed, 568 insertions, 58 deletions
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala
index 2cd515c89d..9002e99d82 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala
@@ -22,10 +22,9 @@ import scala.language.reflectiveCalls
import scopt.OptionParser
-import org.apache.spark.ml.tree.DecisionTreeModel
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.examples.mllib.AbstractParams
-import org.apache.spark.ml.{Pipeline, PipelineStage}
+import org.apache.spark.ml.{Pipeline, PipelineStage, Transformer}
import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
import org.apache.spark.ml.feature.{VectorIndexer, StringIndexer}
import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor}
@@ -64,8 +63,6 @@ object DecisionTreeExample {
maxBins: Int = 32,
minInstancesPerNode: Int = 1,
minInfoGain: Double = 0.0,
- numTrees: Int = 1,
- featureSubsetStrategy: String = "auto",
fracTest: Double = 0.2,
cacheNodeIds: Boolean = false,
checkpointDir: Option[String] = None,
@@ -123,8 +120,8 @@ object DecisionTreeExample {
.required()
.action((x, c) => c.copy(input = x))
checkConfig { params =>
- if (params.fracTest < 0 || params.fracTest > 1) {
- failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].")
+ if (params.fracTest < 0 || params.fracTest >= 1) {
+ failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).")
} else {
success
}
@@ -200,9 +197,18 @@ object DecisionTreeExample {
throw new IllegalArgumentException("Algo ${params.algo} not supported.")
}
}
- val dataframes = splits.map(_.toDF()).map(labelsToStrings).map(_.cache())
+ val dataframes = splits.map(_.toDF()).map(labelsToStrings)
+ val training = dataframes(0).cache()
+ val test = dataframes(1).cache()
- (dataframes(0), dataframes(1))
+ val numTraining = training.count()
+ val numTest = test.count()
+ val numFeatures = training.select("features").first().getAs[Vector](0).size
+ println("Loaded data:")
+ println(s" numTraining = $numTraining, numTest = $numTest")
+ println(s" numFeatures = $numFeatures")
+
+ (training, test)
}
def run(params: Params) {
@@ -217,13 +223,6 @@ object DecisionTreeExample {
val (training: DataFrame, test: DataFrame) =
loadDatasets(sc, params.input, params.dataFormat, params.testInput, algo, params.fracTest)
- val numTraining = training.count()
- val numTest = test.count()
- val numFeatures = training.select("features").first().getAs[Vector](0).size
- println("Loaded data:")
- println(s" numTraining = $numTraining, numTest = $numTest")
- println(s" numFeatures = $numFeatures")
-
// Set up Pipeline
val stages = new mutable.ArrayBuffer[PipelineStage]()
// (1) For classification, re-index classes.
@@ -241,7 +240,7 @@ object DecisionTreeExample {
.setOutputCol("indexedFeatures")
.setMaxCategories(10)
stages += featuresIndexer
- // (3) Learn DecisionTree
+ // (3) Learn Decision Tree
val dt = algo match {
case "classification" =>
new DecisionTreeClassifier()
@@ -275,62 +274,86 @@ object DecisionTreeExample {
println(s"Training time: $elapsedTime seconds")
// Get the trained Decision Tree from the fitted PipelineModel
- val treeModel: DecisionTreeModel = algo match {
+ algo match {
case "classification" =>
- pipelineModel.getModel[DecisionTreeClassificationModel](
+ val treeModel = pipelineModel.getModel[DecisionTreeClassificationModel](
dt.asInstanceOf[DecisionTreeClassifier])
+ if (treeModel.numNodes < 20) {
+ println(treeModel.toDebugString) // Print full model.
+ } else {
+ println(treeModel) // Print model summary.
+ }
case "regression" =>
- pipelineModel.getModel[DecisionTreeRegressionModel](dt.asInstanceOf[DecisionTreeRegressor])
- case _ => throw new IllegalArgumentException("Algo ${params.algo} not supported.")
- }
- if (treeModel.numNodes < 20) {
- println(treeModel.toDebugString) // Print full model.
- } else {
- println(treeModel) // Print model summary.
- }
-
- // Predict on training
- val trainingFullPredictions = pipelineModel.transform(training).cache()
- val trainingPredictions = trainingFullPredictions.select("prediction")
- .map(_.getDouble(0))
- val trainingLabels = trainingFullPredictions.select(labelColName).map(_.getDouble(0))
- // Predict on test data
- val testFullPredictions = pipelineModel.transform(test).cache()
- val testPredictions = testFullPredictions.select("prediction")
- .map(_.getDouble(0))
- val testLabels = testFullPredictions.select(labelColName).map(_.getDouble(0))
-
- // For classification, print number of classes for reference.
- if (algo == "classification") {
- val numClasses =
- MetadataUtils.getNumClasses(trainingFullPredictions.schema(labelColName)) match {
- case Some(n) => n
- case None => throw new RuntimeException(
- "DecisionTreeExample had unknown failure when indexing labels for classification.")
+ val treeModel = pipelineModel.getModel[DecisionTreeRegressionModel](
+ dt.asInstanceOf[DecisionTreeRegressor])
+ if (treeModel.numNodes < 20) {
+ println(treeModel.toDebugString) // Print full model.
+ } else {
+ println(treeModel) // Print model summary.
}
- println(s"numClasses = $numClasses.")
+ case _ => throw new IllegalArgumentException("Algo ${params.algo} not supported.")
}
// Evaluate model on training, test data
algo match {
case "classification" =>
- val trainingAccuracy =
- new MulticlassMetrics(trainingPredictions.zip(trainingLabels)).precision
- println(s"Train accuracy = $trainingAccuracy")
- val testAccuracy =
- new MulticlassMetrics(testPredictions.zip(testLabels)).precision
- println(s"Test accuracy = $testAccuracy")
+ println("Training data results:")
+ evaluateClassificationModel(pipelineModel, training, labelColName)
+ println("Test data results:")
+ evaluateClassificationModel(pipelineModel, test, labelColName)
case "regression" =>
- val trainingRMSE =
- new RegressionMetrics(trainingPredictions.zip(trainingLabels)).rootMeanSquaredError
- println(s"Training root mean squared error (RMSE) = $trainingRMSE")
- val testRMSE =
- new RegressionMetrics(testPredictions.zip(testLabels)).rootMeanSquaredError
- println(s"Test root mean squared error (RMSE) = $testRMSE")
+ println("Training data results:")
+ evaluateRegressionModel(pipelineModel, training, labelColName)
+ println("Test data results:")
+ evaluateRegressionModel(pipelineModel, test, labelColName)
case _ =>
throw new IllegalArgumentException("Algo ${params.algo} not supported.")
}
sc.stop()
}
+
+ /**
+ * Evaluate the given ClassificationModel on data. Print the results.
+ * @param model Must fit ClassificationModel abstraction
+ * @param data DataFrame with "prediction" and labelColName columns
+ * @param labelColName Name of the labelCol parameter for the model
+ *
+ * TODO: Change model type to ClassificationModel once that API is public. SPARK-5995
+ */
+ private[ml] def evaluateClassificationModel(
+ model: Transformer,
+ data: DataFrame,
+ labelColName: String): Unit = {
+ val fullPredictions = model.transform(data).cache()
+ val predictions = fullPredictions.select("prediction").map(_.getDouble(0))
+ val labels = fullPredictions.select(labelColName).map(_.getDouble(0))
+ // Print number of classes for reference
+ val numClasses = MetadataUtils.getNumClasses(fullPredictions.schema(labelColName)) match {
+ case Some(n) => n
+ case None => throw new RuntimeException(
+ "Unknown failure when indexing labels for classification.")
+ }
+ val accuracy = new MulticlassMetrics(predictions.zip(labels)).precision
+ println(s" Accuracy ($numClasses classes): $accuracy")
+ }
+
+ /**
+ * Evaluate the given RegressionModel on data. Print the results.
+ * @param model Must fit RegressionModel abstraction
+ * @param data DataFrame with "prediction" and labelColName columns
+ * @param labelColName Name of the labelCol parameter for the model
+ *
+ * TODO: Change model type to RegressionModel once that API is public. SPARK-5995
+ */
+ private[ml] def evaluateRegressionModel(
+ model: Transformer,
+ data: DataFrame,
+ labelColName: String): Unit = {
+ val fullPredictions = model.transform(data).cache()
+ val predictions = fullPredictions.select("prediction").map(_.getDouble(0))
+ val labels = fullPredictions.select(labelColName).map(_.getDouble(0))
+ val RMSE = new RegressionMetrics(predictions.zip(labels)).rootMeanSquaredError
+ println(s" Root mean squared error (RMSE): $RMSE")
+ }
}
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala
new file mode 100644
index 0000000000..5fccb142d4
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala
@@ -0,0 +1,238 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml
+
+import scala.collection.mutable
+import scala.language.reflectiveCalls
+
+import scopt.OptionParser
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.examples.mllib.AbstractParams
+import org.apache.spark.ml.{Pipeline, PipelineStage}
+import org.apache.spark.ml.classification.{GBTClassificationModel, GBTClassifier}
+import org.apache.spark.ml.feature.{StringIndexer, VectorIndexer}
+import org.apache.spark.ml.regression.{GBTRegressionModel, GBTRegressor}
+import org.apache.spark.sql.DataFrame
+
+
+/**
+ * An example runner for decision trees. Run with
+ * {{{
+ * ./bin/run-example ml.GBTExample [options]
+ * }}}
+ * Decision Trees and ensembles can take a large amount of memory. If the run-example command
+ * above fails, try running via spark-submit and specifying the amount of memory as at least 1g.
+ * For local mode, run
+ * {{{
+ * ./bin/spark-submit --class org.apache.spark.examples.ml.GBTExample --driver-memory 1g
+ * [examples JAR path] [options]
+ * }}}
+ * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
+ */
+object GBTExample {
+
+ case class Params(
+ input: String = null,
+ testInput: String = "",
+ dataFormat: String = "libsvm",
+ algo: String = "classification",
+ maxDepth: Int = 5,
+ maxBins: Int = 32,
+ minInstancesPerNode: Int = 1,
+ minInfoGain: Double = 0.0,
+ maxIter: Int = 10,
+ fracTest: Double = 0.2,
+ cacheNodeIds: Boolean = false,
+ checkpointDir: Option[String] = None,
+ checkpointInterval: Int = 10) extends AbstractParams[Params]
+
+ def main(args: Array[String]) {
+ val defaultParams = Params()
+
+ val parser = new OptionParser[Params]("GBTExample") {
+ head("GBTExample: an example Gradient-Boosted Trees app.")
+ opt[String]("algo")
+ .text(s"algorithm (classification, regression), default: ${defaultParams.algo}")
+ .action((x, c) => c.copy(algo = x))
+ opt[Int]("maxDepth")
+ .text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
+ .action((x, c) => c.copy(maxDepth = x))
+ opt[Int]("maxBins")
+ .text(s"max number of bins, default: ${defaultParams.maxBins}")
+ .action((x, c) => c.copy(maxBins = x))
+ opt[Int]("minInstancesPerNode")
+ .text(s"min number of instances required at child nodes to create the parent split," +
+ s" default: ${defaultParams.minInstancesPerNode}")
+ .action((x, c) => c.copy(minInstancesPerNode = x))
+ opt[Double]("minInfoGain")
+ .text(s"min info gain required to create a split, default: ${defaultParams.minInfoGain}")
+ .action((x, c) => c.copy(minInfoGain = x))
+ opt[Int]("maxIter")
+ .text(s"number of trees in ensemble, default: ${defaultParams.maxIter}")
+ .action((x, c) => c.copy(maxIter = x))
+ opt[Double]("fracTest")
+ .text(s"fraction of data to hold out for testing. If given option testInput, " +
+ s"this option is ignored. default: ${defaultParams.fracTest}")
+ .action((x, c) => c.copy(fracTest = x))
+ opt[Boolean]("cacheNodeIds")
+ .text(s"whether to use node Id cache during training, " +
+ s"default: ${defaultParams.cacheNodeIds}")
+ .action((x, c) => c.copy(cacheNodeIds = x))
+ opt[String]("checkpointDir")
+ .text(s"checkpoint directory where intermediate node Id caches will be stored, " +
+ s"default: ${
+ defaultParams.checkpointDir match {
+ case Some(strVal) => strVal
+ case None => "None"
+ }
+ }")
+ .action((x, c) => c.copy(checkpointDir = Some(x)))
+ opt[Int]("checkpointInterval")
+ .text(s"how often to checkpoint the node Id cache, " +
+ s"default: ${defaultParams.checkpointInterval}")
+ .action((x, c) => c.copy(checkpointInterval = x))
+ opt[String]("testInput")
+ .text(s"input path to test dataset. If given, option fracTest is ignored." +
+ s" default: ${defaultParams.testInput}")
+ .action((x, c) => c.copy(testInput = x))
+ opt[String]("<dataFormat>")
+ .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
+ .action((x, c) => c.copy(dataFormat = x))
+ arg[String]("<input>")
+ .text("input path to labeled examples")
+ .required()
+ .action((x, c) => c.copy(input = x))
+ checkConfig { params =>
+ if (params.fracTest < 0 || params.fracTest >= 1) {
+ failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).")
+ } else {
+ success
+ }
+ }
+ }
+
+ parser.parse(args, defaultParams).map { params =>
+ run(params)
+ }.getOrElse {
+ sys.exit(1)
+ }
+ }
+
+ def run(params: Params) {
+ val conf = new SparkConf().setAppName(s"GBTExample with $params")
+ val sc = new SparkContext(conf)
+ params.checkpointDir.foreach(sc.setCheckpointDir)
+ val algo = params.algo.toLowerCase
+
+ println(s"GBTExample with parameters:\n$params")
+
+ // Load training and test data and cache it.
+ val (training: DataFrame, test: DataFrame) = DecisionTreeExample.loadDatasets(sc, params.input,
+ params.dataFormat, params.testInput, algo, params.fracTest)
+
+ // Set up Pipeline
+ val stages = new mutable.ArrayBuffer[PipelineStage]()
+ // (1) For classification, re-index classes.
+ val labelColName = if (algo == "classification") "indexedLabel" else "label"
+ if (algo == "classification") {
+ val labelIndexer = new StringIndexer()
+ .setInputCol("labelString")
+ .setOutputCol(labelColName)
+ stages += labelIndexer
+ }
+ // (2) Identify categorical features using VectorIndexer.
+ // Features with more than maxCategories values will be treated as continuous.
+ val featuresIndexer = new VectorIndexer()
+ .setInputCol("features")
+ .setOutputCol("indexedFeatures")
+ .setMaxCategories(10)
+ stages += featuresIndexer
+ // (3) Learn GBT
+ val dt = algo match {
+ case "classification" =>
+ new GBTClassifier()
+ .setFeaturesCol("indexedFeatures")
+ .setLabelCol(labelColName)
+ .setMaxDepth(params.maxDepth)
+ .setMaxBins(params.maxBins)
+ .setMinInstancesPerNode(params.minInstancesPerNode)
+ .setMinInfoGain(params.minInfoGain)
+ .setCacheNodeIds(params.cacheNodeIds)
+ .setCheckpointInterval(params.checkpointInterval)
+ .setMaxIter(params.maxIter)
+ case "regression" =>
+ new GBTRegressor()
+ .setFeaturesCol("indexedFeatures")
+ .setLabelCol(labelColName)
+ .setMaxDepth(params.maxDepth)
+ .setMaxBins(params.maxBins)
+ .setMinInstancesPerNode(params.minInstancesPerNode)
+ .setMinInfoGain(params.minInfoGain)
+ .setCacheNodeIds(params.cacheNodeIds)
+ .setCheckpointInterval(params.checkpointInterval)
+ .setMaxIter(params.maxIter)
+ case _ => throw new IllegalArgumentException("Algo ${params.algo} not supported.")
+ }
+ stages += dt
+ val pipeline = new Pipeline().setStages(stages.toArray)
+
+ // Fit the Pipeline
+ val startTime = System.nanoTime()
+ val pipelineModel = pipeline.fit(training)
+ val elapsedTime = (System.nanoTime() - startTime) / 1e9
+ println(s"Training time: $elapsedTime seconds")
+
+ // Get the trained GBT from the fitted PipelineModel
+ algo match {
+ case "classification" =>
+ val rfModel = pipelineModel.getModel[GBTClassificationModel](dt.asInstanceOf[GBTClassifier])
+ if (rfModel.totalNumNodes < 30) {
+ println(rfModel.toDebugString) // Print full model.
+ } else {
+ println(rfModel) // Print model summary.
+ }
+ case "regression" =>
+ val rfModel = pipelineModel.getModel[GBTRegressionModel](dt.asInstanceOf[GBTRegressor])
+ if (rfModel.totalNumNodes < 30) {
+ println(rfModel.toDebugString) // Print full model.
+ } else {
+ println(rfModel) // Print model summary.
+ }
+ case _ => throw new IllegalArgumentException("Algo ${params.algo} not supported.")
+ }
+
+ // Evaluate model on training, test data
+ algo match {
+ case "classification" =>
+ println("Training data results:")
+ DecisionTreeExample.evaluateClassificationModel(pipelineModel, training, labelColName)
+ println("Test data results:")
+ DecisionTreeExample.evaluateClassificationModel(pipelineModel, test, labelColName)
+ case "regression" =>
+ println("Training data results:")
+ DecisionTreeExample.evaluateRegressionModel(pipelineModel, training, labelColName)
+ println("Test data results:")
+ DecisionTreeExample.evaluateRegressionModel(pipelineModel, test, labelColName)
+ case _ =>
+ throw new IllegalArgumentException("Algo ${params.algo} not supported.")
+ }
+
+ sc.stop()
+ }
+}
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala
new file mode 100644
index 0000000000..9b909324ec
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala
@@ -0,0 +1,248 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml
+
+import scala.collection.mutable
+import scala.language.reflectiveCalls
+
+import scopt.OptionParser
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.examples.mllib.AbstractParams
+import org.apache.spark.ml.{Pipeline, PipelineStage}
+import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
+import org.apache.spark.ml.feature.{StringIndexer, VectorIndexer}
+import org.apache.spark.ml.regression.{RandomForestRegressionModel, RandomForestRegressor}
+import org.apache.spark.sql.DataFrame
+
+
+/**
+ * An example runner for decision trees. Run with
+ * {{{
+ * ./bin/run-example ml.RandomForestExample [options]
+ * }}}
+ * Decision Trees and ensembles can take a large amount of memory. If the run-example command
+ * above fails, try running via spark-submit and specifying the amount of memory as at least 1g.
+ * For local mode, run
+ * {{{
+ * ./bin/spark-submit --class org.apache.spark.examples.ml.RandomForestExample --driver-memory 1g
+ * [examples JAR path] [options]
+ * }}}
+ * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
+ */
+object RandomForestExample {
+
+ case class Params(
+ input: String = null,
+ testInput: String = "",
+ dataFormat: String = "libsvm",
+ algo: String = "classification",
+ maxDepth: Int = 5,
+ maxBins: Int = 32,
+ minInstancesPerNode: Int = 1,
+ minInfoGain: Double = 0.0,
+ numTrees: Int = 10,
+ featureSubsetStrategy: String = "auto",
+ fracTest: Double = 0.2,
+ cacheNodeIds: Boolean = false,
+ checkpointDir: Option[String] = None,
+ checkpointInterval: Int = 10) extends AbstractParams[Params]
+
+ def main(args: Array[String]) {
+ val defaultParams = Params()
+
+ val parser = new OptionParser[Params]("RandomForestExample") {
+ head("RandomForestExample: an example random forest app.")
+ opt[String]("algo")
+ .text(s"algorithm (classification, regression), default: ${defaultParams.algo}")
+ .action((x, c) => c.copy(algo = x))
+ opt[Int]("maxDepth")
+ .text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
+ .action((x, c) => c.copy(maxDepth = x))
+ opt[Int]("maxBins")
+ .text(s"max number of bins, default: ${defaultParams.maxBins}")
+ .action((x, c) => c.copy(maxBins = x))
+ opt[Int]("minInstancesPerNode")
+ .text(s"min number of instances required at child nodes to create the parent split," +
+ s" default: ${defaultParams.minInstancesPerNode}")
+ .action((x, c) => c.copy(minInstancesPerNode = x))
+ opt[Double]("minInfoGain")
+ .text(s"min info gain required to create a split, default: ${defaultParams.minInfoGain}")
+ .action((x, c) => c.copy(minInfoGain = x))
+ opt[Int]("numTrees")
+ .text(s"number of trees in ensemble, default: ${defaultParams.numTrees}")
+ .action((x, c) => c.copy(numTrees = x))
+ opt[String]("featureSubsetStrategy")
+ .text(s"number of features to use per node (supported:" +
+ s" ${RandomForestClassifier.supportedFeatureSubsetStrategies.mkString(",")})," +
+ s" default: ${defaultParams.numTrees}")
+ .action((x, c) => c.copy(featureSubsetStrategy = x))
+ opt[Double]("fracTest")
+ .text(s"fraction of data to hold out for testing. If given option testInput, " +
+ s"this option is ignored. default: ${defaultParams.fracTest}")
+ .action((x, c) => c.copy(fracTest = x))
+ opt[Boolean]("cacheNodeIds")
+ .text(s"whether to use node Id cache during training, " +
+ s"default: ${defaultParams.cacheNodeIds}")
+ .action((x, c) => c.copy(cacheNodeIds = x))
+ opt[String]("checkpointDir")
+ .text(s"checkpoint directory where intermediate node Id caches will be stored, " +
+ s"default: ${
+ defaultParams.checkpointDir match {
+ case Some(strVal) => strVal
+ case None => "None"
+ }
+ }")
+ .action((x, c) => c.copy(checkpointDir = Some(x)))
+ opt[Int]("checkpointInterval")
+ .text(s"how often to checkpoint the node Id cache, " +
+ s"default: ${defaultParams.checkpointInterval}")
+ .action((x, c) => c.copy(checkpointInterval = x))
+ opt[String]("testInput")
+ .text(s"input path to test dataset. If given, option fracTest is ignored." +
+ s" default: ${defaultParams.testInput}")
+ .action((x, c) => c.copy(testInput = x))
+ opt[String]("<dataFormat>")
+ .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
+ .action((x, c) => c.copy(dataFormat = x))
+ arg[String]("<input>")
+ .text("input path to labeled examples")
+ .required()
+ .action((x, c) => c.copy(input = x))
+ checkConfig { params =>
+ if (params.fracTest < 0 || params.fracTest >= 1) {
+ failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).")
+ } else {
+ success
+ }
+ }
+ }
+
+ parser.parse(args, defaultParams).map { params =>
+ run(params)
+ }.getOrElse {
+ sys.exit(1)
+ }
+ }
+
+ def run(params: Params) {
+ val conf = new SparkConf().setAppName(s"RandomForestExample with $params")
+ val sc = new SparkContext(conf)
+ params.checkpointDir.foreach(sc.setCheckpointDir)
+ val algo = params.algo.toLowerCase
+
+ println(s"RandomForestExample with parameters:\n$params")
+
+ // Load training and test data and cache it.
+ val (training: DataFrame, test: DataFrame) = DecisionTreeExample.loadDatasets(sc, params.input,
+ params.dataFormat, params.testInput, algo, params.fracTest)
+
+ // Set up Pipeline
+ val stages = new mutable.ArrayBuffer[PipelineStage]()
+ // (1) For classification, re-index classes.
+ val labelColName = if (algo == "classification") "indexedLabel" else "label"
+ if (algo == "classification") {
+ val labelIndexer = new StringIndexer()
+ .setInputCol("labelString")
+ .setOutputCol(labelColName)
+ stages += labelIndexer
+ }
+ // (2) Identify categorical features using VectorIndexer.
+ // Features with more than maxCategories values will be treated as continuous.
+ val featuresIndexer = new VectorIndexer()
+ .setInputCol("features")
+ .setOutputCol("indexedFeatures")
+ .setMaxCategories(10)
+ stages += featuresIndexer
+ // (3) Learn Random Forest
+ val dt = algo match {
+ case "classification" =>
+ new RandomForestClassifier()
+ .setFeaturesCol("indexedFeatures")
+ .setLabelCol(labelColName)
+ .setMaxDepth(params.maxDepth)
+ .setMaxBins(params.maxBins)
+ .setMinInstancesPerNode(params.minInstancesPerNode)
+ .setMinInfoGain(params.minInfoGain)
+ .setCacheNodeIds(params.cacheNodeIds)
+ .setCheckpointInterval(params.checkpointInterval)
+ .setFeatureSubsetStrategy(params.featureSubsetStrategy)
+ .setNumTrees(params.numTrees)
+ case "regression" =>
+ new RandomForestRegressor()
+ .setFeaturesCol("indexedFeatures")
+ .setLabelCol(labelColName)
+ .setMaxDepth(params.maxDepth)
+ .setMaxBins(params.maxBins)
+ .setMinInstancesPerNode(params.minInstancesPerNode)
+ .setMinInfoGain(params.minInfoGain)
+ .setCacheNodeIds(params.cacheNodeIds)
+ .setCheckpointInterval(params.checkpointInterval)
+ .setFeatureSubsetStrategy(params.featureSubsetStrategy)
+ .setNumTrees(params.numTrees)
+ case _ => throw new IllegalArgumentException("Algo ${params.algo} not supported.")
+ }
+ stages += dt
+ val pipeline = new Pipeline().setStages(stages.toArray)
+
+ // Fit the Pipeline
+ val startTime = System.nanoTime()
+ val pipelineModel = pipeline.fit(training)
+ val elapsedTime = (System.nanoTime() - startTime) / 1e9
+ println(s"Training time: $elapsedTime seconds")
+
+ // Get the trained Random Forest from the fitted PipelineModel
+ algo match {
+ case "classification" =>
+ val rfModel = pipelineModel.getModel[RandomForestClassificationModel](
+ dt.asInstanceOf[RandomForestClassifier])
+ if (rfModel.totalNumNodes < 30) {
+ println(rfModel.toDebugString) // Print full model.
+ } else {
+ println(rfModel) // Print model summary.
+ }
+ case "regression" =>
+ val rfModel = pipelineModel.getModel[RandomForestRegressionModel](
+ dt.asInstanceOf[RandomForestRegressor])
+ if (rfModel.totalNumNodes < 30) {
+ println(rfModel.toDebugString) // Print full model.
+ } else {
+ println(rfModel) // Print model summary.
+ }
+ case _ => throw new IllegalArgumentException("Algo ${params.algo} not supported.")
+ }
+
+ // Evaluate model on training, test data
+ algo match {
+ case "classification" =>
+ println("Training data results:")
+ DecisionTreeExample.evaluateClassificationModel(pipelineModel, training, labelColName)
+ println("Test data results:")
+ DecisionTreeExample.evaluateClassificationModel(pipelineModel, test, labelColName)
+ case "regression" =>
+ println("Training data results:")
+ DecisionTreeExample.evaluateRegressionModel(pipelineModel, training, labelColName)
+ println("Test data results:")
+ DecisionTreeExample.evaluateRegressionModel(pipelineModel, test, labelColName)
+ case _ =>
+ throw new IllegalArgumentException("Algo ${params.algo} not supported.")
+ }
+
+ sc.stop()
+ }
+}
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala
index 431ead8c0c..0763a77363 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala
@@ -25,6 +25,7 @@ import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo}
import org.apache.spark.util.Utils
+
/**
* An example runner for Gradient Boosting using decision trees as weak learners. Run with
* {{{