author    Joseph K. Bradley <joseph@databricks.com>  2014-11-05 10:33:13 -0800
committer Xiangrui Meng <meng@databricks.com>        2014-11-05 10:33:13 -0800
commit    5b3b6f6f5f029164d7749366506e142b104c1d43 (patch)
tree      78dc97e3f62d803e0f5cd3837d6a454f27e6e155 /examples/src
parent    5f13759d3642ea5b58c12a756e7125ac19aff10e (diff)
[SPARK-4197] [mllib] GradientBoosting API cleanup and examples in Scala, Java
### Summary

* Made it easier to construct default Strategy and BoostingStrategy and to set parameters using simple types.
* Added Scala and Java examples for GradientBoostedTrees
* Small cleanups and fixes

### Details

GradientBoosting bug fixes (“bug” = bad default options)
* Force boostingStrategy.weakLearnerParams.algo = Regression
* Force boostingStrategy.weakLearnerParams.impurity = impurity.Variance
* Only persist data if not yet persisted (since persisting twice causes an error)

BoostingStrategy
* numEstimators: renamed to numIterations
* Removed subsamplingRate (duplicated by Strategy)
* Removed categoricalFeaturesInfo, since it belongs with the weak learner params (boosting can be oblivious to feature type)
* Changed algo to var (not val) and added BeanProperty, with an overload taking a String argument
* Added assertValid() method
* Updated defaultParams() method and eliminated defaultWeakLearnerParams(), since that belongs in Strategy

Strategy (for DecisionTree)
* Changed algo to var (not val) and added BeanProperty, with an overload taking a String argument
* Added setCategoricalFeaturesInfo method taking a Java Map.
* Cleaned up assertValid
* Changed vals to defs, since parameters can now be changed.

CC: manishamde mengxr codedeft

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #3094 from jkbradley/gbt-api and squashes the following commits:

7a27e22 [Joseph K. Bradley] scalastyle fix
52013d5 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into gbt-api
e9b8410 [Joseph K. Bradley] Summary of changes
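To make the API changes above concrete, here is a minimal Scala sketch of the simplified configuration flow. It uses only calls that appear in this patch (BoostingStrategy.defaultParams, numIterations, weakLearnerParams, GradientBoosting.trainClassifier); the object name and the two-class assumption are illustrative, not part of the patch.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.tree.GradientBoosting
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.mllib.util.MLUtils

object GbtApiSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("GbtApiSketch"))
    // Sample dataset shipped with Spark (also used by the Java example below).
    val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").cache()

    // Simple-typed configuration: defaultParams takes the algo as a String,
    // and numIterations (renamed from numEstimators) is now a plain var.
    val boostingStrategy = BoostingStrategy.defaultParams("Classification")
    boostingStrategy.numIterations = 10
    boostingStrategy.numClassesForClassification = 2 // assumes binary labels
    boostingStrategy.weakLearnerParams.maxDepth = 5

    val model = GradientBoosting.trainClassifier(data, boostingStrategy)
    println(model) // model summary
    sc.stop()
  }
}
```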
Diffstat (limited to 'examples/src')
-rw-r--r--  examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java   126
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala        64
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala     146
3 files changed, 317 insertions, 19 deletions
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java
new file mode 100644
index 0000000000..1af2067b2b
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.mllib;
+
+import scala.Tuple2;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.api.java.function.Function2;
+import org.apache.spark.api.java.function.PairFunction;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.mllib.tree.GradientBoosting;
+import org.apache.spark.mllib.tree.configuration.BoostingStrategy;
+import org.apache.spark.mllib.tree.model.WeightedEnsembleModel;
+import org.apache.spark.mllib.util.MLUtils;
+
+/**
+ * Classification and regression using gradient-boosted decision trees.
+ */
+public final class JavaGradientBoostedTrees {
+
+ private static void usage() {
+ System.err.println("Usage: JavaGradientBoostedTrees <libsvm format data file>" +
+ " <Classification/Regression>");
+ System.exit(-1);
+ }
+
+ public static void main(String[] args) {
+ String datapath = "data/mllib/sample_libsvm_data.txt";
+ String algo = "Classification";
+ if (args.length >= 1) {
+ datapath = args[0];
+ }
+ if (args.length >= 2) {
+ algo = args[1];
+ }
+ if (args.length > 2) {
+ usage();
+ }
+ SparkConf sparkConf = new SparkConf().setAppName("JavaGradientBoostedTrees");
+ JavaSparkContext sc = new JavaSparkContext(sparkConf);
+
+ JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache();
+
+ // Set parameters.
+ // Note: All features are treated as continuous.
+ BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams(algo);
+ boostingStrategy.setNumIterations(10);
+ boostingStrategy.weakLearnerParams().setMaxDepth(5);
+
+ if (algo.equals("Classification")) {
+ // Compute the number of classes from the data.
+ Integer numClasses = data.map(new Function<LabeledPoint, Double>() {
+ @Override public Double call(LabeledPoint p) {
+ return p.label();
+ }
+ }).countByValue().size();
+ boostingStrategy.setNumClassesForClassification(numClasses); // ignored for Regression
+
+ // Train a GradientBoosting model for classification.
+ final WeightedEnsembleModel model = GradientBoosting.trainClassifier(data, boostingStrategy);
+
+ // Evaluate model on training instances and compute training error
+ JavaPairRDD<Double, Double> predictionAndLabel =
+ data.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
+ @Override public Tuple2<Double, Double> call(LabeledPoint p) {
+ return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
+ }
+ });
+ Double trainErr =
+ 1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
+ @Override public Boolean call(Tuple2<Double, Double> pl) {
+ return !pl._1().equals(pl._2());
+ }
+ }).count() / data.count();
+ System.out.println("Training error: " + trainErr);
+ System.out.println("Learned classification tree model:\n" + model);
+ } else if (algo.equals("Regression")) {
+ // Train a GradientBoosting model for regression.
+ final WeightedEnsembleModel model = GradientBoosting.trainRegressor(data, boostingStrategy);
+
+ // Evaluate model on training instances and compute training error
+ JavaPairRDD<Double, Double> predictionAndLabel =
+ data.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
+ @Override public Tuple2<Double, Double> call(LabeledPoint p) {
+ return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
+ }
+ });
+ Double trainMSE =
+ predictionAndLabel.map(new Function<Tuple2<Double, Double>, Double>() {
+ @Override public Double call(Tuple2<Double, Double> pl) {
+ Double diff = pl._1() - pl._2();
+ return diff * diff;
+ }
+ }).reduce(new Function2<Double, Double, Double>() {
+ @Override public Double call(Double a, Double b) {
+ return a + b;
+ }
+ }) / data.count();
+ System.out.println("Training Mean Squared Error: " + trainMSE);
+ System.out.println("Learned regression tree model:\n" + model);
+ } else {
+ usage();
+ }
+
+ sc.stop();
+ }
+}
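The new Java example defaults to the sample dataset and Classification when run without arguments. A sketch of an explicit invocation through Spark's run-example helper, based on the usage() string above (the script path assumes a standard Spark distribution layout):

```
./bin/run-example org.apache.spark.examples.mllib.JavaGradientBoostedTrees \
  data/mllib/sample_libsvm_data.txt Classification
```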
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
index 49751a3049..63f02cf7b9 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
@@ -154,20 +154,30 @@ object DecisionTreeRunner {
}
}
- def run(params: Params) {
-
- val conf = new SparkConf().setAppName(s"DecisionTreeRunner with $params")
- val sc = new SparkContext(conf)
-
- println(s"DecisionTreeRunner with parameters:\n$params")
-
+ /**
+ * Load training and test data from files.
+ * @param input Path to input dataset.
+ * @param dataFormat "libsvm" or "dense"
+ * @param testInput Path to test dataset.
+ * @param algo Classification or Regression
+ * @param fracTest Fraction of input data to hold out for testing. Ignored if testInput is given.
+ * @return (training dataset, test dataset, number of classes),
+ * where the number of classes is inferred from data (and set to 0 for Regression)
+ */
+ private[mllib] def loadDatasets(
+ sc: SparkContext,
+ input: String,
+ dataFormat: String,
+ testInput: String,
+ algo: Algo,
+ fracTest: Double): (RDD[LabeledPoint], RDD[LabeledPoint], Int) = {
// Load training data and cache it.
- val origExamples = params.dataFormat match {
- case "dense" => MLUtils.loadLabeledPoints(sc, params.input).cache()
- case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input).cache()
+ val origExamples = dataFormat match {
+ case "dense" => MLUtils.loadLabeledPoints(sc, input).cache()
+ case "libsvm" => MLUtils.loadLibSVMFile(sc, input).cache()
}
// For classification, re-index classes if needed.
- val (examples, classIndexMap, numClasses) = params.algo match {
+ val (examples, classIndexMap, numClasses) = algo match {
case Classification => {
// classCounts: class --> # examples in class
val classCounts = origExamples.map(_.label).countByValue()
@@ -205,14 +215,14 @@ object DecisionTreeRunner {
}
// Create training, test sets.
- val splits = if (params.testInput != "") {
+ val splits = if (testInput != "") {
// Load testInput.
val numFeatures = examples.take(1)(0).features.size
- val origTestExamples = params.dataFormat match {
- case "dense" => MLUtils.loadLabeledPoints(sc, params.testInput)
- case "libsvm" => MLUtils.loadLibSVMFile(sc, params.testInput, numFeatures)
+ val origTestExamples = dataFormat match {
+ case "dense" => MLUtils.loadLabeledPoints(sc, testInput)
+ case "libsvm" => MLUtils.loadLibSVMFile(sc, testInput, numFeatures)
}
- params.algo match {
+ algo match {
case Classification => {
// classCounts: class --> # examples in class
val testExamples = {
@@ -229,17 +239,31 @@ object DecisionTreeRunner {
}
} else {
// Split input into training, test.
- examples.randomSplit(Array(1.0 - params.fracTest, params.fracTest))
+ examples.randomSplit(Array(1.0 - fracTest, fracTest))
}
val training = splits(0).cache()
val test = splits(1).cache()
+
val numTraining = training.count()
val numTest = test.count()
-
println(s"numTraining = $numTraining, numTest = $numTest.")
examples.unpersist(blocking = false)
+ (training, test, numClasses)
+ }
+
+ def run(params: Params) {
+
+ val conf = new SparkConf().setAppName(s"DecisionTreeRunner with $params")
+ val sc = new SparkContext(conf)
+
+ println(s"DecisionTreeRunner with parameters:\n$params")
+
+ // Load training and test data and cache it.
+ val (training, test, numClasses) = loadDatasets(sc, params.input, params.dataFormat,
+ params.testInput, params.algo, params.fracTest)
+
val impurityCalculator = params.impurity match {
case Gini => impurity.Gini
case Entropy => impurity.Entropy
@@ -338,7 +362,9 @@ object DecisionTreeRunner {
/**
* Calculates the mean squared error for regression.
*/
- private def meanSquaredError(tree: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double = {
+ private[mllib] def meanSquaredError(
+ tree: WeightedEnsembleModel,
+ data: RDD[LabeledPoint]): Double = {
data.map { y =>
val err = tree.predict(y.features) - y.label
err * err
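For reference, the refactored meanSquaredError helper computes the standard regression MSE over the dataset (the hunk ends before the terminal aggregation, which in the full file averages these squared errors):

$$\mathrm{MSE} = \frac{1}{n} \sum_{i=1}^{n} \left( \hat{y}_i - y_i \right)^2$$

where $\hat{y}_i$ is the ensemble prediction tree.predict(y.features) for the i-th example and $y_i$ is its label.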
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala
new file mode 100644
index 0000000000..9b6db01448
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.mllib
+
+import scopt.OptionParser
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.mllib.evaluation.MulticlassMetrics
+import org.apache.spark.mllib.tree.GradientBoosting
+import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo}
+import org.apache.spark.util.Utils
+
+/**
+ * An example runner for Gradient Boosting using decision trees as weak learners. Run with
+ * {{{
+ * ./bin/run-example org.apache.spark.examples.mllib.GradientBoostedTrees [options]
+ * }}}
+ * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
+ *
+ * Note: This script treats all features as real-valued (not categorical).
+ * To include categorical features, modify categoricalFeaturesInfo.
+ */
+object GradientBoostedTrees {
+
+ case class Params(
+ input: String = null,
+ testInput: String = "",
+ dataFormat: String = "libsvm",
+ algo: String = "Classification",
+ maxDepth: Int = 5,
+ numIterations: Int = 10,
+ fracTest: Double = 0.2) extends AbstractParams[Params]
+
+ def main(args: Array[String]) {
+ val defaultParams = Params()
+
+ val parser = new OptionParser[Params]("GradientBoostedTrees") {
+ head("GradientBoostedTrees: an example decision tree app.")
+ opt[String]("algo")
+ .text(s"algorithm (${Algo.values.mkString(",")}), default: ${defaultParams.algo}")
+ .action((x, c) => c.copy(algo = x))
+ opt[Int]("maxDepth")
+ .text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
+ .action((x, c) => c.copy(maxDepth = x))
+ opt[Int]("numIterations")
+ .text(s"number of iterations of boosting," + s" default: ${defaultParams.numIterations}")
+ .action((x, c) => c.copy(numIterations = x))
+ opt[Double]("fracTest")
+ .text(s"fraction of data to hold out for testing. If given option testInput, " +
+ s"this option is ignored. default: ${defaultParams.fracTest}")
+ .action((x, c) => c.copy(fracTest = x))
+ opt[String]("testInput")
+ .text(s"input path to test dataset. If given, option fracTest is ignored." +
+ s" default: ${defaultParams.testInput}")
+ .action((x, c) => c.copy(testInput = x))
+ opt[String]("<dataFormat>")
+ .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
+ .action((x, c) => c.copy(dataFormat = x))
+ arg[String]("<input>")
+ .text("input path to labeled examples")
+ .required()
+ .action((x, c) => c.copy(input = x))
+ checkConfig { params =>
+ if (params.fracTest < 0 || params.fracTest > 1) {
+ failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].")
+ } else {
+ success
+ }
+ }
+ }
+
+ parser.parse(args, defaultParams).map { params =>
+ run(params)
+ }.getOrElse {
+ sys.exit(1)
+ }
+ }
+
+ def run(params: Params) {
+
+ val conf = new SparkConf().setAppName(s"GradientBoostedTrees with $params")
+ val sc = new SparkContext(conf)
+
+ println(s"GradientBoostedTrees with parameters:\n$params")
+
+ // Load training and test data and cache it.
+ val (training, test, numClasses) = DecisionTreeRunner.loadDatasets(sc, params.input,
+ params.dataFormat, params.testInput, Algo.withName(params.algo), params.fracTest)
+
+ val boostingStrategy = BoostingStrategy.defaultParams(params.algo)
+ boostingStrategy.numClassesForClassification = numClasses
+ boostingStrategy.numIterations = params.numIterations
+ boostingStrategy.weakLearnerParams.maxDepth = params.maxDepth
+
+ val randomSeed = Utils.random.nextInt()
+ if (params.algo == "Classification") {
+ val startTime = System.nanoTime()
+ val model = GradientBoosting.trainClassifier(training, boostingStrategy)
+ val elapsedTime = (System.nanoTime() - startTime) / 1e9
+ println(s"Training time: $elapsedTime seconds")
+ if (model.totalNumNodes < 30) {
+ println(model.toDebugString) // Print full model.
+ } else {
+ println(model) // Print model summary.
+ }
+ val trainAccuracy =
+ new MulticlassMetrics(training.map(lp => (model.predict(lp.features), lp.label)))
+ .precision
+ println(s"Train accuracy = $trainAccuracy")
+ val testAccuracy =
+ new MulticlassMetrics(test.map(lp => (model.predict(lp.features), lp.label))).precision
+ println(s"Test accuracy = $testAccuracy")
+ } else if (params.algo == "Regression") {
+ val startTime = System.nanoTime()
+ val model = GradientBoosting.trainRegressor(training, boostingStrategy)
+ val elapsedTime = (System.nanoTime() - startTime) / 1e9
+ println(s"Training time: $elapsedTime seconds")
+ if (model.totalNumNodes < 30) {
+ println(model.toDebugString) // Print full model.
+ } else {
+ println(model) // Print model summary.
+ }
+ val trainMSE = DecisionTreeRunner.meanSquaredError(model, training)
+ println(s"Train mean squared error = $trainMSE")
+ val testMSE = DecisionTreeRunner.meanSquaredError(model, test)
+ println(s"Test mean squared error = $testMSE")
+ }
+
+ sc.stop()
+ }
+}
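Putting the parser options together, a typical invocation of the new Scala runner looks like the following sketch (the option values spell out the defaults explicitly; the dataset path is the sample file used by the Java example):

```
./bin/run-example org.apache.spark.examples.mllib.GradientBoostedTrees \
  --algo Classification --maxDepth 5 --numIterations 10 --fracTest 0.2 \
  data/mllib/sample_libsvm_data.txt
```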