aboutsummaryrefslogtreecommitdiff
path: root/examples/src/main/java
diff options
context:
space:
mode:
authorJoseph K. Bradley <joseph@databricks.com>2014-11-05 10:33:13 -0800
committerXiangrui Meng <meng@databricks.com>2014-11-05 10:33:13 -0800
commit5b3b6f6f5f029164d7749366506e142b104c1d43 (patch)
tree78dc97e3f62d803e0f5cd3837d6a454f27e6e155 /examples/src/main/java
parent5f13759d3642ea5b58c12a756e7125ac19aff10e (diff)
downloadspark-5b3b6f6f5f029164d7749366506e142b104c1d43.tar.gz
spark-5b3b6f6f5f029164d7749366506e142b104c1d43.tar.bz2
spark-5b3b6f6f5f029164d7749366506e142b104c1d43.zip
[SPARK-4197] [mllib] GradientBoosting API cleanup and examples in Scala, Java
### Summary * Made it easier to construct default Strategy and BoostingStrategy and to set parameters using simple types. * Added Scala and Java examples for GradientBoostedTrees * small cleanups and fixes ### Details GradientBoosting bug fixes (“bug” = bad default options) * Force boostingStrategy.weakLearnerParams.algo = Regression * Force boostingStrategy.weakLearnerParams.impurity = impurity.Variance * Only persist data if not yet persisted (since it causes an error if persisted twice) BoostingStrategy * numEstimators: renamed to numIterations * removed subsamplingRate (duplicated by Strategy) * removed categoricalFeaturesInfo since it belongs with the weak learner params (since boosting can be oblivious to feature type) * Changed algo to var (not val) and added BeanProperty, with overload taking String argument * Added assertValid() method * Updated defaultParams() method and eliminated defaultWeakLearnerParams() since that belongs in Strategy Strategy (for DecisionTree) * Changed algo to var (not val) and added BeanProperty, with overload taking String argument * Added setCategoricalFeaturesInfo method taking Java Map. * Cleaned up assertValid * Changed val’s to def’s since parameters can now be changed. CC: manishamde mengxr codedeft Author: Joseph K. Bradley <joseph@databricks.com> Closes #3094 from jkbradley/gbt-api and squashes the following commits: 7a27e22 [Joseph K. Bradley] scalastyle fix 52013d5 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into gbt-api e9b8410 [Joseph K. Bradley] Summary of changes
Diffstat (limited to 'examples/src/main/java')
-rw-r--r--examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java126
1 files changed, 126 insertions, 0 deletions
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java
new file mode 100644
index 0000000000..1af2067b2b
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.mllib;
+
+import scala.Tuple2;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.api.java.function.Function2;
+import org.apache.spark.api.java.function.PairFunction;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.mllib.tree.GradientBoosting;
+import org.apache.spark.mllib.tree.configuration.BoostingStrategy;
+import org.apache.spark.mllib.tree.model.WeightedEnsembleModel;
+import org.apache.spark.mllib.util.MLUtils;
+
+/**
+ * Classification and regression using gradient-boosted decision trees.
+ */
+public final class JavaGradientBoostedTrees {
+
+ private static void usage() {
+ System.err.println("Usage: JavaGradientBoostedTrees <libsvm format data file>" +
+ " <Classification/Regression>");
+ System.exit(-1);
+ }
+
+ public static void main(String[] args) {
+ String datapath = "data/mllib/sample_libsvm_data.txt";
+ String algo = "Classification";
+ if (args.length >= 1) {
+ datapath = args[0];
+ }
+ if (args.length >= 2) {
+ algo = args[1];
+ }
+ if (args.length > 2) {
+ usage();
+ }
+ SparkConf sparkConf = new SparkConf().setAppName("JavaGradientBoostedTrees");
+ JavaSparkContext sc = new JavaSparkContext(sparkConf);
+
+ JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache();
+
+ // Set parameters.
+ // Note: All features are treated as continuous.
+ BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams(algo);
+ boostingStrategy.setNumIterations(10);
+ boostingStrategy.weakLearnerParams().setMaxDepth(5);
+
+ if (algo.equals("Classification")) {
+ // Compute the number of classes from the data.
+ Integer numClasses = data.map(new Function<LabeledPoint, Double>() {
+ @Override public Double call(LabeledPoint p) {
+ return p.label();
+ }
+ }).countByValue().size();
+ boostingStrategy.setNumClassesForClassification(numClasses); // ignored for Regression
+
+ // Train a GradientBoosting model for classification.
+ final WeightedEnsembleModel model = GradientBoosting.trainClassifier(data, boostingStrategy);
+
+ // Evaluate model on training instances and compute training error
+ JavaPairRDD<Double, Double> predictionAndLabel =
+ data.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
+ @Override public Tuple2<Double, Double> call(LabeledPoint p) {
+ return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
+ }
+ });
+ Double trainErr =
+ 1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
+ @Override public Boolean call(Tuple2<Double, Double> pl) {
+ return !pl._1().equals(pl._2());
+ }
+ }).count() / data.count();
+ System.out.println("Training error: " + trainErr);
+ System.out.println("Learned classification tree model:\n" + model);
+ } else if (algo.equals("Regression")) {
+ // Train a GradientBoosting model for classification.
+ final WeightedEnsembleModel model = GradientBoosting.trainRegressor(data, boostingStrategy);
+
+ // Evaluate model on training instances and compute training error
+ JavaPairRDD<Double, Double> predictionAndLabel =
+ data.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
+ @Override public Tuple2<Double, Double> call(LabeledPoint p) {
+ return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
+ }
+ });
+ Double trainMSE =
+ predictionAndLabel.map(new Function<Tuple2<Double, Double>, Double>() {
+ @Override public Double call(Tuple2<Double, Double> pl) {
+ Double diff = pl._1() - pl._2();
+ return diff * diff;
+ }
+ }).reduce(new Function2<Double, Double, Double>() {
+ @Override public Double call(Double a, Double b) {
+ return a + b;
+ }
+ }) / data.count();
+ System.out.println("Training Mean Squared Error: " + trainMSE);
+ System.out.println("Learned regression tree model:\n" + model);
+ } else {
+ usage();
+ }
+
+ sc.stop();
+ }
+}