From 5b3b6f6f5f029164d7749366506e142b104c1d43 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 5 Nov 2014 10:33:13 -0800 Subject: [SPARK-4197] [mllib] GradientBoosting API cleanup and examples in Scala, Java MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Summary * Made it easier to construct default Strategy and BoostingStrategy and to set parameters using simple types. * Added Scala and Java examples for GradientBoostedTrees * small cleanups and fixes ### Details GradientBoosting bug fixes (“bug” = bad default options) * Force boostingStrategy.weakLearnerParams.algo = Regression * Force boostingStrategy.weakLearnerParams.impurity = impurity.Variance * Only persist data if not yet persisted (since it causes an error if persisted twice) BoostingStrategy * numEstimators: renamed to numIterations * removed subsamplingRate (duplicated by Strategy) * removed categoricalFeaturesInfo since it belongs with the weak learner params (since boosting can be oblivious to feature type) * Changed algo to var (not val) and added BeanProperty, with overload taking String argument * Added assertValid() method * Updated defaultParams() method and eliminated defaultWeakLearnerParams() since that belongs in Strategy Strategy (for DecisionTree) * Changed algo to var (not val) and added BeanProperty, with overload taking String argument * Added setCategoricalFeaturesInfo method taking Java Map. * Cleaned up assertValid * Changed val’s to def’s since parameters can now be changed. CC: manishamde mengxr codedeft Author: Joseph K. Bradley Closes #3094 from jkbradley/gbt-api and squashes the following commits: 7a27e22 [Joseph K. Bradley] scalastyle fix 52013d5 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into gbt-api e9b8410 [Joseph K. Bradley] Summary of changes --- .../spark/examples/mllib/DecisionTreeRunner.scala | 64 ++++++--- .../examples/mllib/GradientBoostedTrees.scala | 146 +++++++++++++++++++++ 2 files changed, 191 insertions(+), 19 deletions(-) create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala (limited to 'examples/src/main/scala/org') diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index 49751a3049..63f02cf7b9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -154,20 +154,30 @@ object DecisionTreeRunner { } } - def run(params: Params) { - - val conf = new SparkConf().setAppName(s"DecisionTreeRunner with $params") - val sc = new SparkContext(conf) - - println(s"DecisionTreeRunner with parameters:\n$params") - + /** + * Load training and test data from files. + * @param input Path to input dataset. + * @param dataFormat "libsvm" or "dense" + * @param testInput Path to test dataset. + * @param algo Classification or Regression + * @param fracTest Fraction of input data to hold out for testing. Ignored if testInput given. + * @return (training dataset, test dataset, number of classes), + * where the number of classes is inferred from data (and set to 0 for Regression) + */ + private[mllib] def loadDatasets( + sc: SparkContext, + input: String, + dataFormat: String, + testInput: String, + algo: Algo, + fracTest: Double): (RDD[LabeledPoint], RDD[LabeledPoint], Int) = { // Load training data and cache it. - val origExamples = params.dataFormat match { - case "dense" => MLUtils.loadLabeledPoints(sc, params.input).cache() - case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input).cache() + val origExamples = dataFormat match { + case "dense" => MLUtils.loadLabeledPoints(sc, input).cache() + case "libsvm" => MLUtils.loadLibSVMFile(sc, input).cache() } // For classification, re-index classes if needed. - val (examples, classIndexMap, numClasses) = params.algo match { + val (examples, classIndexMap, numClasses) = algo match { case Classification => { // classCounts: class --> # examples in class val classCounts = origExamples.map(_.label).countByValue() @@ -205,14 +215,14 @@ object DecisionTreeRunner { } // Create training, test sets. - val splits = if (params.testInput != "") { + val splits = if (testInput != "") { // Load testInput. val numFeatures = examples.take(1)(0).features.size - val origTestExamples = params.dataFormat match { - case "dense" => MLUtils.loadLabeledPoints(sc, params.testInput) - case "libsvm" => MLUtils.loadLibSVMFile(sc, params.testInput, numFeatures) + val origTestExamples = dataFormat match { + case "dense" => MLUtils.loadLabeledPoints(sc, testInput) + case "libsvm" => MLUtils.loadLibSVMFile(sc, testInput, numFeatures) } - params.algo match { + algo match { case Classification => { // classCounts: class --> # examples in class val testExamples = { @@ -229,17 +239,31 @@ object DecisionTreeRunner { } } else { // Split input into training, test. - examples.randomSplit(Array(1.0 - params.fracTest, params.fracTest)) + examples.randomSplit(Array(1.0 - fracTest, fracTest)) } val training = splits(0).cache() val test = splits(1).cache() + val numTraining = training.count() val numTest = test.count() - println(s"numTraining = $numTraining, numTest = $numTest.") examples.unpersist(blocking = false) + (training, test, numClasses) + } + + def run(params: Params) { + + val conf = new SparkConf().setAppName(s"DecisionTreeRunner with $params") + val sc = new SparkContext(conf) + + println(s"DecisionTreeRunner with parameters:\n$params") + + // Load training and test data and cache it. + val (training, test, numClasses) = loadDatasets(sc, params.input, params.dataFormat, + params.testInput, params.algo, params.fracTest) + val impurityCalculator = params.impurity match { case Gini => impurity.Gini case Entropy => impurity.Entropy @@ -338,7 +362,9 @@ object DecisionTreeRunner { /** * Calculates the mean squared error for regression. */ - private def meanSquaredError(tree: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double = { + private[mllib] def meanSquaredError( + tree: WeightedEnsembleModel, + data: RDD[LabeledPoint]): Double = { data.map { y => val err = tree.predict(y.features) - y.label err * err diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala new file mode 100644 index 0000000000..9b6db01448 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import scopt.OptionParser + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.mllib.evaluation.MulticlassMetrics +import org.apache.spark.mllib.tree.GradientBoosting +import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo} +import org.apache.spark.util.Utils + +/** + * An example runner for Gradient Boosting using decision trees as weak learners. Run with + * {{{ + * ./bin/run-example org.apache.spark.examples.mllib.GradientBoostedTrees [options] + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + * + * Note: This script treats all features as real-valued (not categorical). + * To include categorical features, modify categoricalFeaturesInfo. + */ +object GradientBoostedTrees { + + case class Params( + input: String = null, + testInput: String = "", + dataFormat: String = "libsvm", + algo: String = "Classification", + maxDepth: Int = 5, + numIterations: Int = 10, + fracTest: Double = 0.2) extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("GradientBoostedTrees") { + head("GradientBoostedTrees: an example decision tree app.") + opt[String]("algo") + .text(s"algorithm (${Algo.values.mkString(",")}), default: ${defaultParams.algo}") + .action((x, c) => c.copy(algo = x)) + opt[Int]("maxDepth") + .text(s"max depth of the tree, default: ${defaultParams.maxDepth}") + .action((x, c) => c.copy(maxDepth = x)) + opt[Int]("numIterations") + .text(s"number of iterations of boosting," + s" default: ${defaultParams.numIterations}") + .action((x, c) => c.copy(numIterations = x)) + opt[Double]("fracTest") + .text(s"fraction of data to hold out for testing. If given option testInput, " + + s"this option is ignored. default: ${defaultParams.fracTest}") + .action((x, c) => c.copy(fracTest = x)) + opt[String]("testInput") + .text(s"input path to test dataset. If given, option fracTest is ignored." + + s" default: ${defaultParams.testInput}") + .action((x, c) => c.copy(testInput = x)) + opt[String]("") + .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") + .action((x, c) => c.copy(dataFormat = x)) + arg[String]("") + .text("input path to labeled examples") + .required() + .action((x, c) => c.copy(input = x)) + checkConfig { params => + if (params.fracTest < 0 || params.fracTest > 1) { + failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].") + } else { + success + } + } + } + + parser.parse(args, defaultParams).map { params => + run(params) + }.getOrElse { + sys.exit(1) + } + } + + def run(params: Params) { + + val conf = new SparkConf().setAppName(s"GradientBoostedTrees with $params") + val sc = new SparkContext(conf) + + println(s"GradientBoostedTrees with parameters:\n$params") + + // Load training and test data and cache it. + val (training, test, numClasses) = DecisionTreeRunner.loadDatasets(sc, params.input, + params.dataFormat, params.testInput, Algo.withName(params.algo), params.fracTest) + + val boostingStrategy = BoostingStrategy.defaultParams(params.algo) + boostingStrategy.numClassesForClassification = numClasses + boostingStrategy.numIterations = params.numIterations + boostingStrategy.weakLearnerParams.maxDepth = params.maxDepth + + val randomSeed = Utils.random.nextInt() + if (params.algo == "Classification") { + val startTime = System.nanoTime() + val model = GradientBoosting.trainClassifier(training, boostingStrategy) + val elapsedTime = (System.nanoTime() - startTime) / 1e9 + println(s"Training time: $elapsedTime seconds") + if (model.totalNumNodes < 30) { + println(model.toDebugString) // Print full model. + } else { + println(model) // Print model summary. + } + val trainAccuracy = + new MulticlassMetrics(training.map(lp => (model.predict(lp.features), lp.label))) + .precision + println(s"Train accuracy = $trainAccuracy") + val testAccuracy = + new MulticlassMetrics(test.map(lp => (model.predict(lp.features), lp.label))).precision + println(s"Test accuracy = $testAccuracy") + } else if (params.algo == "Regression") { + val startTime = System.nanoTime() + val model = GradientBoosting.trainRegressor(training, boostingStrategy) + val elapsedTime = (System.nanoTime() - startTime) / 1e9 + println(s"Training time: $elapsedTime seconds") + if (model.totalNumNodes < 30) { + println(model.toDebugString) // Print full model. + } else { + println(model) // Print model summary. + } + val trainMSE = DecisionTreeRunner.meanSquaredError(model, training) + println(s"Train mean squared error = $trainMSE") + val testMSE = DecisionTreeRunner.meanSquaredError(model, test) + println(s"Test mean squared error = $testMSE") + } + + sc.stop() + } +} -- cgit v1.2.3