Diffstat (limited to 'examples/src/main')
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala           | 139
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala                    | 238
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala           | 248
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala |   1
4 files changed, 568 insertions(+), 58 deletions(-)
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala
index 2cd515c89d..9002e99d82 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala
@@ -22,10 +22,9 @@ import scala.language.reflectiveCalls
import scopt.OptionParser
-import org.apache.spark.ml.tree.DecisionTreeModel
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.examples.mllib.AbstractParams
-import org.apache.spark.ml.{Pipeline, PipelineStage}
+import org.apache.spark.ml.{Pipeline, PipelineStage, Transformer}
import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
import org.apache.spark.ml.feature.{VectorIndexer, StringIndexer}
import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor}
@@ -64,8 +63,6 @@ object DecisionTreeExample {
maxBins: Int = 32,
minInstancesPerNode: Int = 1,
minInfoGain: Double = 0.0,
- numTrees: Int = 1,
- featureSubsetStrategy: String = "auto",
fracTest: Double = 0.2,
cacheNodeIds: Boolean = false,
checkpointDir: Option[String] = None,
@@ -123,8 +120,8 @@ object DecisionTreeExample {
.required()
.action((x, c) => c.copy(input = x))
checkConfig { params =>
- if (params.fracTest < 0 || params.fracTest > 1) {
- failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].")
+ if (params.fracTest < 0 || params.fracTest >= 1) {
+ failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).")
} else {
success
}
@@ -200,9 +197,18 @@ object DecisionTreeExample {
throw new IllegalArgumentException("Algo ${params.algo} not supported.")
}
}
- val dataframes = splits.map(_.toDF()).map(labelsToStrings).map(_.cache())
+ val dataframes = splits.map(_.toDF()).map(labelsToStrings)
+ val training = dataframes(0).cache()
+ val test = dataframes(1).cache()
- (dataframes(0), dataframes(1))
+ val numTraining = training.count()
+ val numTest = test.count()
+ val numFeatures = training.select("features").first().getAs[Vector](0).size
+ println("Loaded data:")
+ println(s" numTraining = $numTraining, numTest = $numTest")
+ println(s" numFeatures = $numFeatures")
+
+ (training, test)
}
def run(params: Params) {
@@ -217,13 +223,6 @@ object DecisionTreeExample {
val (training: DataFrame, test: DataFrame) =
loadDatasets(sc, params.input, params.dataFormat, params.testInput, algo, params.fracTest)
- val numTraining = training.count()
- val numTest = test.count()
- val numFeatures = training.select("features").first().getAs[Vector](0).size
- println("Loaded data:")
- println(s" numTraining = $numTraining, numTest = $numTest")
- println(s" numFeatures = $numFeatures")
-
// Set up Pipeline
val stages = new mutable.ArrayBuffer[PipelineStage]()
// (1) For classification, re-index classes.
@@ -241,7 +240,7 @@ object DecisionTreeExample {
.setOutputCol("indexedFeatures")
.setMaxCategories(10)
stages += featuresIndexer
- // (3) Learn DecisionTree
+ // (3) Learn Decision Tree
val dt = algo match {
case "classification" =>
new DecisionTreeClassifier()
@@ -275,62 +274,86 @@ object DecisionTreeExample {
println(s"Training time: $elapsedTime seconds")
// Get the trained Decision Tree from the fitted PipelineModel
- val treeModel: DecisionTreeModel = algo match {
+ algo match {
case "classification" =>
- pipelineModel.getModel[DecisionTreeClassificationModel](
+ val treeModel = pipelineModel.getModel[DecisionTreeClassificationModel](
dt.asInstanceOf[DecisionTreeClassifier])
+ if (treeModel.numNodes < 20) {
+ println(treeModel.toDebugString) // Print full model.
+ } else {
+ println(treeModel) // Print model summary.
+ }
case "regression" =>
- pipelineModel.getModel[DecisionTreeRegressionModel](dt.asInstanceOf[DecisionTreeRegressor])
- case _ => throw new IllegalArgumentException("Algo ${params.algo} not supported.")
- }
- if (treeModel.numNodes < 20) {
- println(treeModel.toDebugString) // Print full model.
- } else {
- println(treeModel) // Print model summary.
- }
-
- // Predict on training
- val trainingFullPredictions = pipelineModel.transform(training).cache()
- val trainingPredictions = trainingFullPredictions.select("prediction")
- .map(_.getDouble(0))
- val trainingLabels = trainingFullPredictions.select(labelColName).map(_.getDouble(0))
- // Predict on test data
- val testFullPredictions = pipelineModel.transform(test).cache()
- val testPredictions = testFullPredictions.select("prediction")
- .map(_.getDouble(0))
- val testLabels = testFullPredictions.select(labelColName).map(_.getDouble(0))
-
- // For classification, print number of classes for reference.
- if (algo == "classification") {
- val numClasses =
- MetadataUtils.getNumClasses(trainingFullPredictions.schema(labelColName)) match {
- case Some(n) => n
- case None => throw new RuntimeException(
- "DecisionTreeExample had unknown failure when indexing labels for classification.")
+ val treeModel = pipelineModel.getModel[DecisionTreeRegressionModel](
+ dt.asInstanceOf[DecisionTreeRegressor])
+ if (treeModel.numNodes < 20) {
+ println(treeModel.toDebugString) // Print full model.
+ } else {
+ println(treeModel) // Print model summary.
}
- println(s"numClasses = $numClasses.")
+ case _ => throw new IllegalArgumentException(s"Algo ${params.algo} not supported.")
}
// Evaluate model on training, test data
algo match {
case "classification" =>
- val trainingAccuracy =
- new MulticlassMetrics(trainingPredictions.zip(trainingLabels)).precision
- println(s"Train accuracy = $trainingAccuracy")
- val testAccuracy =
- new MulticlassMetrics(testPredictions.zip(testLabels)).precision
- println(s"Test accuracy = $testAccuracy")
+ println("Training data results:")
+ evaluateClassificationModel(pipelineModel, training, labelColName)
+ println("Test data results:")
+ evaluateClassificationModel(pipelineModel, test, labelColName)
case "regression" =>
- val trainingRMSE =
- new RegressionMetrics(trainingPredictions.zip(trainingLabels)).rootMeanSquaredError
- println(s"Training root mean squared error (RMSE) = $trainingRMSE")
- val testRMSE =
- new RegressionMetrics(testPredictions.zip(testLabels)).rootMeanSquaredError
- println(s"Test root mean squared error (RMSE) = $testRMSE")
+ println("Training data results:")
+ evaluateRegressionModel(pipelineModel, training, labelColName)
+ println("Test data results:")
+ evaluateRegressionModel(pipelineModel, test, labelColName)
case _ =>
throw new IllegalArgumentException("Algo ${params.algo} not supported.")
}
sc.stop()
}
+
+ /**
+ * Evaluate the given ClassificationModel on data. Print the results.
+ * @param model Must fit ClassificationModel abstraction
+ * @param data DataFrame with "prediction" and labelColName columns
+ * @param labelColName Name of the labelCol parameter for the model
+ *
+ * TODO: Change model type to ClassificationModel once that API is public. SPARK-5995
+ */
+ private[ml] def evaluateClassificationModel(
+ model: Transformer,
+ data: DataFrame,
+ labelColName: String): Unit = {
+ val fullPredictions = model.transform(data).cache()
+ val predictions = fullPredictions.select("prediction").map(_.getDouble(0))
+ val labels = fullPredictions.select(labelColName).map(_.getDouble(0))
+ // Print number of classes for reference
+ val numClasses = MetadataUtils.getNumClasses(fullPredictions.schema(labelColName)) match {
+ case Some(n) => n
+ case None => throw new RuntimeException(
+ "Unknown failure when indexing labels for classification.")
+ }
+ val accuracy = new MulticlassMetrics(predictions.zip(labels)).precision
+ println(s" Accuracy ($numClasses classes): $accuracy")
+ }
+
+ /**
+ * Evaluate the given RegressionModel on data. Print the results.
+ * @param model Must fit RegressionModel abstraction
+ * @param data DataFrame with "prediction" and labelColName columns
+ * @param labelColName Name of the labelCol parameter for the model
+ *
+ * TODO: Change model type to RegressionModel once that API is public. SPARK-5995
+ */
+ private[ml] def evaluateRegressionModel(
+ model: Transformer,
+ data: DataFrame,
+ labelColName: String): Unit = {
+ val fullPredictions = model.transform(data).cache()
+ val predictions = fullPredictions.select("prediction").map(_.getDouble(0))
+ val labels = fullPredictions.select(labelColName).map(_.getDouble(0))
+ val RMSE = new RegressionMetrics(predictions.zip(labels)).rootMeanSquaredError
+ println(s" Root mean squared error (RMSE): $RMSE")
+ }
}
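
The two helpers added above reduce, at their core, to MLlib's MulticlassMetrics and RegressionMetrics applied to (prediction, label) pairs. A minimal, self-contained sketch of just that computation, runnable against a Spark 1.x build (the EvalSketch object name and the toy prediction/label pairs are invented for illustration):

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.evaluation.{MulticlassMetrics, RegressionMetrics}

object EvalSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("EvalSketch").setMaster("local[2]"))
    // (prediction, label) pairs; in the patch these come from
    // model.transform(data) followed by selecting "prediction" and labelColName.
    val predictionsAndLabels = sc.parallelize(Seq(
      (0.0, 0.0), (1.0, 1.0), (1.0, 0.0), (0.0, 0.0), (1.0, 1.0)))
    // Overall precision across all predictions equals plain accuracy,
    // which is what evaluateClassificationModel prints.
    val accuracy = new MulticlassMetrics(predictionsAndLabels).precision
    println(s"Accuracy: $accuracy")
    // The same pairs read as (prediction, observation) give the RMSE
    // printed by evaluateRegressionModel.
    val rmse = new RegressionMetrics(predictionsAndLabels).rootMeanSquaredError
    println(s"RMSE: $rmse")
    sc.stop()
  }
}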
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala
new file mode 100644
index 0000000000..5fccb142d4
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala
@@ -0,0 +1,238 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml
+
+import scala.collection.mutable
+import scala.language.reflectiveCalls
+
+import scopt.OptionParser
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.examples.mllib.AbstractParams
+import org.apache.spark.ml.{Pipeline, PipelineStage}
+import org.apache.spark.ml.classification.{GBTClassificationModel, GBTClassifier}
+import org.apache.spark.ml.feature.{StringIndexer, VectorIndexer}
+import org.apache.spark.ml.regression.{GBTRegressionModel, GBTRegressor}
+import org.apache.spark.sql.DataFrame
+
+
+/**
+ * An example runner for Gradient-Boosted Trees (GBTs). Run with
+ * {{{
+ * ./bin/run-example ml.GBTExample [options]
+ * }}}
+ * Decision Trees and ensembles can take a large amount of memory. If the run-example command
+ * above fails, try running via spark-submit and specifying the amount of memory as at least 1g.
+ * For local mode, run
+ * {{{
+ * ./bin/spark-submit --class org.apache.spark.examples.ml.GBTExample --driver-memory 1g
+ * [examples JAR path] [options]
+ * }}}
+ * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
+ */
+object GBTExample {
+
+ case class Params(
+ input: String = null,
+ testInput: String = "",
+ dataFormat: String = "libsvm",
+ algo: String = "classification",
+ maxDepth: Int = 5,
+ maxBins: Int = 32,
+ minInstancesPerNode: Int = 1,
+ minInfoGain: Double = 0.0,
+ maxIter: Int = 10,
+ fracTest: Double = 0.2,
+ cacheNodeIds: Boolean = false,
+ checkpointDir: Option[String] = None,
+ checkpointInterval: Int = 10) extends AbstractParams[Params]
+
+ def main(args: Array[String]) {
+ val defaultParams = Params()
+
+ val parser = new OptionParser[Params]("GBTExample") {
+ head("GBTExample: an example Gradient-Boosted Trees app.")
+ opt[String]("algo")
+ .text(s"algorithm (classification, regression), default: ${defaultParams.algo}")
+ .action((x, c) => c.copy(algo = x))
+ opt[Int]("maxDepth")
+ .text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
+ .action((x, c) => c.copy(maxDepth = x))
+ opt[Int]("maxBins")
+ .text(s"max number of bins, default: ${defaultParams.maxBins}")
+ .action((x, c) => c.copy(maxBins = x))
+ opt[Int]("minInstancesPerNode")
+ .text(s"min number of instances required at child nodes to create the parent split," +
+ s" default: ${defaultParams.minInstancesPerNode}")
+ .action((x, c) => c.copy(minInstancesPerNode = x))
+ opt[Double]("minInfoGain")
+ .text(s"min info gain required to create a split, default: ${defaultParams.minInfoGain}")
+ .action((x, c) => c.copy(minInfoGain = x))
+ opt[Int]("maxIter")
+ .text(s"number of trees in ensemble, default: ${defaultParams.maxIter}")
+ .action((x, c) => c.copy(maxIter = x))
+ opt[Double]("fracTest")
+ .text(s"fraction of data to hold out for testing. If given option testInput, " +
+ s"this option is ignored. default: ${defaultParams.fracTest}")
+ .action((x, c) => c.copy(fracTest = x))
+ opt[Boolean]("cacheNodeIds")
+ .text(s"whether to use node Id cache during training, " +
+ s"default: ${defaultParams.cacheNodeIds}")
+ .action((x, c) => c.copy(cacheNodeIds = x))
+ opt[String]("checkpointDir")
+ .text(s"checkpoint directory where intermediate node Id caches will be stored, " +
+ s"default: ${
+ defaultParams.checkpointDir match {
+ case Some(strVal) => strVal
+ case None => "None"
+ }
+ }")
+ .action((x, c) => c.copy(checkpointDir = Some(x)))
+ opt[Int]("checkpointInterval")
+ .text(s"how often to checkpoint the node Id cache, " +
+ s"default: ${defaultParams.checkpointInterval}")
+ .action((x, c) => c.copy(checkpointInterval = x))
+ opt[String]("testInput")
+ .text(s"input path to test dataset. If given, option fracTest is ignored." +
+ s" default: ${defaultParams.testInput}")
+ .action((x, c) => c.copy(testInput = x))
+ opt[String]("<dataFormat>")
+ .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
+ .action((x, c) => c.copy(dataFormat = x))
+ arg[String]("<input>")
+ .text("input path to labeled examples")
+ .required()
+ .action((x, c) => c.copy(input = x))
+ checkConfig { params =>
+ if (params.fracTest < 0 || params.fracTest >= 1) {
+ failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).")
+ } else {
+ success
+ }
+ }
+ }
+
+ parser.parse(args, defaultParams).map { params =>
+ run(params)
+ }.getOrElse {
+ sys.exit(1)
+ }
+ }
+
+ def run(params: Params) {
+ val conf = new SparkConf().setAppName(s"GBTExample with $params")
+ val sc = new SparkContext(conf)
+ params.checkpointDir.foreach(sc.setCheckpointDir)
+ val algo = params.algo.toLowerCase
+
+ println(s"GBTExample with parameters:\n$params")
+
+ // Load training and test data and cache it.
+ val (training: DataFrame, test: DataFrame) = DecisionTreeExample.loadDatasets(sc, params.input,
+ params.dataFormat, params.testInput, algo, params.fracTest)
+
+ // Set up Pipeline
+ val stages = new mutable.ArrayBuffer[PipelineStage]()
+ // (1) For classification, re-index classes.
+ val labelColName = if (algo == "classification") "indexedLabel" else "label"
+ if (algo == "classification") {
+ val labelIndexer = new StringIndexer()
+ .setInputCol("labelString")
+ .setOutputCol(labelColName)
+ stages += labelIndexer
+ }
+ // (2) Identify categorical features using VectorIndexer.
+ // Features with more than maxCategories values will be treated as continuous.
+ val featuresIndexer = new VectorIndexer()
+ .setInputCol("features")
+ .setOutputCol("indexedFeatures")
+ .setMaxCategories(10)
+ stages += featuresIndexer
+ // (3) Learn GBT
+ val dt = algo match {
+ case "classification" =>
+ new GBTClassifier()
+ .setFeaturesCol("indexedFeatures")
+ .setLabelCol(labelColName)
+ .setMaxDepth(params.maxDepth)
+ .setMaxBins(params.maxBins)
+ .setMinInstancesPerNode(params.minInstancesPerNode)
+ .setMinInfoGain(params.minInfoGain)
+ .setCacheNodeIds(params.cacheNodeIds)
+ .setCheckpointInterval(params.checkpointInterval)
+ .setMaxIter(params.maxIter)
+ case "regression" =>
+ new GBTRegressor()
+ .setFeaturesCol("indexedFeatures")
+ .setLabelCol(labelColName)
+ .setMaxDepth(params.maxDepth)
+ .setMaxBins(params.maxBins)
+ .setMinInstancesPerNode(params.minInstancesPerNode)
+ .setMinInfoGain(params.minInfoGain)
+ .setCacheNodeIds(params.cacheNodeIds)
+ .setCheckpointInterval(params.checkpointInterval)
+ .setMaxIter(params.maxIter)
+ case _ => throw new IllegalArgumentException(s"Algo ${params.algo} not supported.")
+ }
+ stages += dt
+ val pipeline = new Pipeline().setStages(stages.toArray)
+
+ // Fit the Pipeline
+ val startTime = System.nanoTime()
+ val pipelineModel = pipeline.fit(training)
+ val elapsedTime = (System.nanoTime() - startTime) / 1e9
+ println(s"Training time: $elapsedTime seconds")
+
+ // Get the trained GBT from the fitted PipelineModel
+ algo match {
+ case "classification" =>
+ val gbtModel = pipelineModel.getModel[GBTClassificationModel](dt.asInstanceOf[GBTClassifier])
+ if (gbtModel.totalNumNodes < 30) {
+ println(gbtModel.toDebugString) // Print full model.
+ } else {
+ println(gbtModel) // Print model summary.
+ }
+ case "regression" =>
+ val gbtModel = pipelineModel.getModel[GBTRegressionModel](dt.asInstanceOf[GBTRegressor])
+ if (gbtModel.totalNumNodes < 30) {
+ println(gbtModel.toDebugString) // Print full model.
+ } else {
+ println(gbtModel) // Print model summary.
+ }
+ case _ => throw new IllegalArgumentException(s"Algo ${params.algo} not supported.")
+ }
+
+ // Evaluate model on training, test data
+ algo match {
+ case "classification" =>
+ println("Training data results:")
+ DecisionTreeExample.evaluateClassificationModel(pipelineModel, training, labelColName)
+ println("Test data results:")
+ DecisionTreeExample.evaluateClassificationModel(pipelineModel, test, labelColName)
+ case "regression" =>
+ println("Training data results:")
+ DecisionTreeExample.evaluateRegressionModel(pipelineModel, training, labelColName)
+ println("Test data results:")
+ DecisionTreeExample.evaluateRegressionModel(pipelineModel, test, labelColName)
+ case _ =>
+ throw new IllegalArgumentException(s"Algo ${params.algo} not supported.")
+ }
+
+ sc.stop()
+ }
+}
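
The new GBTExample assembles the same three pipeline stages as DecisionTreeExample, swapping in a GBT estimator at step (3). A condensed, self-contained sketch of that wiring, assuming a Spark 1.4-era build and the data/mllib/sample_libsvm_data.txt file shipped with Spark (the GBTPipelineSketch name and the 80/20 split are illustrative, not from the patch):

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.GBTClassifier
import org.apache.spark.ml.feature.{StringIndexer, VectorIndexer}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql.SQLContext

object GBTPipelineSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("GBTPipelineSketch"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._
    val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF()
    val Array(training, test) = data.randomSplit(Array(0.8, 0.2))
    // (1) Re-index class labels; (2) index categorical features; (3) learn GBT.
    val labelIndexer = new StringIndexer()
      .setInputCol("label")
      .setOutputCol("indexedLabel")
    val featuresIndexer = new VectorIndexer()
      .setInputCol("features")
      .setOutputCol("indexedFeatures")
      .setMaxCategories(10)
    val gbt = new GBTClassifier()
      .setLabelCol("indexedLabel")
      .setFeaturesCol("indexedFeatures")
      .setMaxIter(10)
    val pipeline = new Pipeline().setStages(Array(labelIndexer, featuresIndexer, gbt))
    val model = pipeline.fit(training)
    model.transform(test).select("prediction", "indexedLabel").show(5)
    sc.stop()
  }
}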
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala
new file mode 100644
index 0000000000..9b909324ec
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala
@@ -0,0 +1,248 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml
+
+import scala.collection.mutable
+import scala.language.reflectiveCalls
+
+import scopt.OptionParser
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.examples.mllib.AbstractParams
+import org.apache.spark.ml.{Pipeline, PipelineStage}
+import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
+import org.apache.spark.ml.feature.{StringIndexer, VectorIndexer}
+import org.apache.spark.ml.regression.{RandomForestRegressionModel, RandomForestRegressor}
+import org.apache.spark.sql.DataFrame
+
+
+/**
+ * An example runner for random forests. Run with
+ * {{{
+ * ./bin/run-example ml.RandomForestExample [options]
+ * }}}
+ * Decision Trees and ensembles can take a large amount of memory. If the run-example command
+ * above fails, try running via spark-submit and specifying the amount of memory as at least 1g.
+ * For local mode, run
+ * {{{
+ * ./bin/spark-submit --class org.apache.spark.examples.ml.RandomForestExample --driver-memory 1g
+ * [examples JAR path] [options]
+ * }}}
+ * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
+ */
+object RandomForestExample {
+
+ case class Params(
+ input: String = null,
+ testInput: String = "",
+ dataFormat: String = "libsvm",
+ algo: String = "classification",
+ maxDepth: Int = 5,
+ maxBins: Int = 32,
+ minInstancesPerNode: Int = 1,
+ minInfoGain: Double = 0.0,
+ numTrees: Int = 10,
+ featureSubsetStrategy: String = "auto",
+ fracTest: Double = 0.2,
+ cacheNodeIds: Boolean = false,
+ checkpointDir: Option[String] = None,
+ checkpointInterval: Int = 10) extends AbstractParams[Params]
+
+ def main(args: Array[String]) {
+ val defaultParams = Params()
+
+ val parser = new OptionParser[Params]("RandomForestExample") {
+ head("RandomForestExample: an example random forest app.")
+ opt[String]("algo")
+ .text(s"algorithm (classification, regression), default: ${defaultParams.algo}")
+ .action((x, c) => c.copy(algo = x))
+ opt[Int]("maxDepth")
+ .text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
+ .action((x, c) => c.copy(maxDepth = x))
+ opt[Int]("maxBins")
+ .text(s"max number of bins, default: ${defaultParams.maxBins}")
+ .action((x, c) => c.copy(maxBins = x))
+ opt[Int]("minInstancesPerNode")
+ .text(s"min number of instances required at child nodes to create the parent split," +
+ s" default: ${defaultParams.minInstancesPerNode}")
+ .action((x, c) => c.copy(minInstancesPerNode = x))
+ opt[Double]("minInfoGain")
+ .text(s"min info gain required to create a split, default: ${defaultParams.minInfoGain}")
+ .action((x, c) => c.copy(minInfoGain = x))
+ opt[Int]("numTrees")
+ .text(s"number of trees in ensemble, default: ${defaultParams.numTrees}")
+ .action((x, c) => c.copy(numTrees = x))
+ opt[String]("featureSubsetStrategy")
+ .text(s"number of features to use per node (supported:" +
+ s" ${RandomForestClassifier.supportedFeatureSubsetStrategies.mkString(",")})," +
+ s" default: ${defaultParams.numTrees}")
+ .action((x, c) => c.copy(featureSubsetStrategy = x))
+ opt[Double]("fracTest")
+ .text(s"fraction of data to hold out for testing. If given option testInput, " +
+ s"this option is ignored. default: ${defaultParams.fracTest}")
+ .action((x, c) => c.copy(fracTest = x))
+ opt[Boolean]("cacheNodeIds")
+ .text(s"whether to use node Id cache during training, " +
+ s"default: ${defaultParams.cacheNodeIds}")
+ .action((x, c) => c.copy(cacheNodeIds = x))
+ opt[String]("checkpointDir")
+ .text(s"checkpoint directory where intermediate node Id caches will be stored, " +
+ s"default: ${
+ defaultParams.checkpointDir match {
+ case Some(strVal) => strVal
+ case None => "None"
+ }
+ }")
+ .action((x, c) => c.copy(checkpointDir = Some(x)))
+ opt[Int]("checkpointInterval")
+ .text(s"how often to checkpoint the node Id cache, " +
+ s"default: ${defaultParams.checkpointInterval}")
+ .action((x, c) => c.copy(checkpointInterval = x))
+ opt[String]("testInput")
+ .text(s"input path to test dataset. If given, option fracTest is ignored." +
+ s" default: ${defaultParams.testInput}")
+ .action((x, c) => c.copy(testInput = x))
+ opt[String]("<dataFormat>")
+ .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
+ .action((x, c) => c.copy(dataFormat = x))
+ arg[String]("<input>")
+ .text("input path to labeled examples")
+ .required()
+ .action((x, c) => c.copy(input = x))
+ checkConfig { params =>
+ if (params.fracTest < 0 || params.fracTest >= 1) {
+ failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).")
+ } else {
+ success
+ }
+ }
+ }
+
+ parser.parse(args, defaultParams).map { params =>
+ run(params)
+ }.getOrElse {
+ sys.exit(1)
+ }
+ }
+
+ def run(params: Params) {
+ val conf = new SparkConf().setAppName(s"RandomForestExample with $params")
+ val sc = new SparkContext(conf)
+ params.checkpointDir.foreach(sc.setCheckpointDir)
+ val algo = params.algo.toLowerCase
+
+ println(s"RandomForestExample with parameters:\n$params")
+
+ // Load training and test data and cache it.
+ val (training: DataFrame, test: DataFrame) = DecisionTreeExample.loadDatasets(sc, params.input,
+ params.dataFormat, params.testInput, algo, params.fracTest)
+
+ // Set up Pipeline
+ val stages = new mutable.ArrayBuffer[PipelineStage]()
+ // (1) For classification, re-index classes.
+ val labelColName = if (algo == "classification") "indexedLabel" else "label"
+ if (algo == "classification") {
+ val labelIndexer = new StringIndexer()
+ .setInputCol("labelString")
+ .setOutputCol(labelColName)
+ stages += labelIndexer
+ }
+ // (2) Identify categorical features using VectorIndexer.
+ // Features with more than maxCategories values will be treated as continuous.
+ val featuresIndexer = new VectorIndexer()
+ .setInputCol("features")
+ .setOutputCol("indexedFeatures")
+ .setMaxCategories(10)
+ stages += featuresIndexer
+ // (3) Learn Random Forest
+ val dt = algo match {
+ case "classification" =>
+ new RandomForestClassifier()
+ .setFeaturesCol("indexedFeatures")
+ .setLabelCol(labelColName)
+ .setMaxDepth(params.maxDepth)
+ .setMaxBins(params.maxBins)
+ .setMinInstancesPerNode(params.minInstancesPerNode)
+ .setMinInfoGain(params.minInfoGain)
+ .setCacheNodeIds(params.cacheNodeIds)
+ .setCheckpointInterval(params.checkpointInterval)
+ .setFeatureSubsetStrategy(params.featureSubsetStrategy)
+ .setNumTrees(params.numTrees)
+ case "regression" =>
+ new RandomForestRegressor()
+ .setFeaturesCol("indexedFeatures")
+ .setLabelCol(labelColName)
+ .setMaxDepth(params.maxDepth)
+ .setMaxBins(params.maxBins)
+ .setMinInstancesPerNode(params.minInstancesPerNode)
+ .setMinInfoGain(params.minInfoGain)
+ .setCacheNodeIds(params.cacheNodeIds)
+ .setCheckpointInterval(params.checkpointInterval)
+ .setFeatureSubsetStrategy(params.featureSubsetStrategy)
+ .setNumTrees(params.numTrees)
+ case _ => throw new IllegalArgumentException(s"Algo ${params.algo} not supported.")
+ }
+ stages += dt
+ val pipeline = new Pipeline().setStages(stages.toArray)
+
+ // Fit the Pipeline
+ val startTime = System.nanoTime()
+ val pipelineModel = pipeline.fit(training)
+ val elapsedTime = (System.nanoTime() - startTime) / 1e9
+ println(s"Training time: $elapsedTime seconds")
+
+ // Get the trained Random Forest from the fitted PipelineModel
+ algo match {
+ case "classification" =>
+ val rfModel = pipelineModel.getModel[RandomForestClassificationModel](
+ dt.asInstanceOf[RandomForestClassifier])
+ if (rfModel.totalNumNodes < 30) {
+ println(rfModel.toDebugString) // Print full model.
+ } else {
+ println(rfModel) // Print model summary.
+ }
+ case "regression" =>
+ val rfModel = pipelineModel.getModel[RandomForestRegressionModel](
+ dt.asInstanceOf[RandomForestRegressor])
+ if (rfModel.totalNumNodes < 30) {
+ println(rfModel.toDebugString) // Print full model.
+ } else {
+ println(rfModel) // Print model summary.
+ }
+ case _ => throw new IllegalArgumentException(s"Algo ${params.algo} not supported.")
+ }
+
+ // Evaluate model on training, test data
+ algo match {
+ case "classification" =>
+ println("Training data results:")
+ DecisionTreeExample.evaluateClassificationModel(pipelineModel, training, labelColName)
+ println("Test data results:")
+ DecisionTreeExample.evaluateClassificationModel(pipelineModel, test, labelColName)
+ case "regression" =>
+ println("Training data results:")
+ DecisionTreeExample.evaluateRegressionModel(pipelineModel, training, labelColName)
+ println("Test data results:")
+ DecisionTreeExample.evaluateRegressionModel(pipelineModel, test, labelColName)
+ case _ =>
+ throw new IllegalArgumentException(s"Algo ${params.algo} not supported.")
+ }
+
+ sc.stop()
+ }
+}
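
Relative to the GBT runner, the only substantive difference is step (3): the two forest-specific options, numTrees and featureSubsetStrategy, map one-to-one onto RandomForestClassifier setters. A small sketch of that mapping (the ForestParamsSketch name and the chosen values are invented for illustration):

import org.apache.spark.ml.classification.RandomForestClassifier

object ForestParamsSketch {
  def main(args: Array[String]): Unit = {
    // Strategies the --featureSubsetStrategy option accepts, per the help text above.
    println(RandomForestClassifier.supportedFeatureSubsetStrategies.mkString(", "))
    // The two forest-specific CLI options map directly onto these setters.
    val rf = new RandomForestClassifier()
      .setNumTrees(20)                  // --numTrees: ensemble size
      .setFeatureSubsetStrategy("sqrt") // --featureSubsetStrategy: features tried per split
    println(rf.uid)
  }
}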
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala
index 431ead8c0c..0763a77363 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala
@@ -25,6 +25,7 @@ import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo}
import org.apache.spark.util.Utils
+
/**
* An example runner for Gradient Boosting using decision trees as weak learners. Run with
* {{{