diff options
Diffstat (limited to 'examples')
-rw-r--r-- | examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java (renamed from examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java) | 18 | ||||
-rw-r--r-- | examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala | 18 | ||||
-rw-r--r-- | examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala (renamed from examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala) | 18 |
3 files changed, 22 insertions, 32 deletions
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java index 1af2067b2b..4a5ac404ea 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java @@ -27,18 +27,18 @@ import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.function.Function2; import org.apache.spark.api.java.function.PairFunction; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.tree.GradientBoosting; +import org.apache.spark.mllib.tree.GradientBoostedTrees; import org.apache.spark.mllib.tree.configuration.BoostingStrategy; -import org.apache.spark.mllib.tree.model.WeightedEnsembleModel; +import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel; import org.apache.spark.mllib.util.MLUtils; /** * Classification and regression using gradient-boosted decision trees. */ -public final class JavaGradientBoostedTrees { +public final class JavaGradientBoostedTreesRunner { private static void usage() { - System.err.println("Usage: JavaGradientBoostedTrees <libsvm format data file>" + + System.err.println("Usage: JavaGradientBoostedTreesRunner <libsvm format data file>" + " <Classification/Regression>"); System.exit(-1); } @@ -55,7 +55,7 @@ public final class JavaGradientBoostedTrees { if (args.length > 2) { usage(); } - SparkConf sparkConf = new SparkConf().setAppName("JavaGradientBoostedTrees"); + SparkConf sparkConf = new SparkConf().setAppName("JavaGradientBoostedTreesRunner"); JavaSparkContext sc = new JavaSparkContext(sparkConf); JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache(); @@ -64,7 +64,7 @@ public final class JavaGradientBoostedTrees { // Note: All features are treated as continuous. BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams(algo); boostingStrategy.setNumIterations(10); - boostingStrategy.weakLearnerParams().setMaxDepth(5); + boostingStrategy.treeStrategy().setMaxDepth(5); if (algo.equals("Classification")) { // Compute the number of classes from the data. @@ -73,10 +73,10 @@ public final class JavaGradientBoostedTrees { return p.label(); } }).countByValue().size(); - boostingStrategy.setNumClassesForClassification(numClasses); // ignored for Regression + boostingStrategy.treeStrategy().setNumClassesForClassification(numClasses); // Train a GradientBoosting model for classification. - final WeightedEnsembleModel model = GradientBoosting.trainClassifier(data, boostingStrategy); + final GradientBoostedTreesModel model = GradientBoostedTrees.train(data, boostingStrategy); // Evaluate model on training instances and compute training error JavaPairRDD<Double, Double> predictionAndLabel = @@ -95,7 +95,7 @@ public final class JavaGradientBoostedTrees { System.out.println("Learned classification tree model:\n" + model); } else if (algo.equals("Regression")) { // Train a GradientBoosting model for classification. - final WeightedEnsembleModel model = GradientBoosting.trainRegressor(data, boostingStrategy); + final GradientBoostedTreesModel model = GradientBoostedTrees.train(data, boostingStrategy); // Evaluate model on training instances and compute training error JavaPairRDD<Double, Double> predictionAndLabel = diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index 63f02cf7b9..98f9d1689c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -22,11 +22,11 @@ import scopt.OptionParser import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.SparkContext._ import org.apache.spark.mllib.evaluation.MulticlassMetrics +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.{RandomForest, DecisionTree, impurity} +import org.apache.spark.mllib.tree.{DecisionTree, RandomForest, impurity} import org.apache.spark.mllib.tree.configuration.{Algo, Strategy} import org.apache.spark.mllib.tree.configuration.Algo._ -import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel} import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils @@ -352,21 +352,11 @@ object DecisionTreeRunner { /** * Calculates the mean squared error for regression. */ - private def meanSquaredError(tree: DecisionTreeModel, data: RDD[LabeledPoint]): Double = { - data.map { y => - val err = tree.predict(y.features) - y.label - err * err - }.mean() - } - - /** - * Calculates the mean squared error for regression. - */ private[mllib] def meanSquaredError( - tree: WeightedEnsembleModel, + model: { def predict(features: Vector): Double }, data: RDD[LabeledPoint]): Double = { data.map { y => - val err = tree.predict(y.features) - y.label + val err = model.predict(y.features) - y.label err * err }.mean() } diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala index 9b6db01448..1def8b45a2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala @@ -21,21 +21,21 @@ import scopt.OptionParser import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.mllib.evaluation.MulticlassMetrics -import org.apache.spark.mllib.tree.GradientBoosting +import org.apache.spark.mllib.tree.GradientBoostedTrees import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo} import org.apache.spark.util.Utils /** * An example runner for Gradient Boosting using decision trees as weak learners. Run with * {{{ - * ./bin/run-example org.apache.spark.examples.mllib.GradientBoostedTrees [options] + * ./bin/run-example mllib.GradientBoostedTreesRunner [options] * }}} * If you use it as a template to create your own app, please use `spark-submit` to submit your app. * * Note: This script treats all features as real-valued (not categorical). * To include categorical features, modify categoricalFeaturesInfo. */ -object GradientBoostedTrees { +object GradientBoostedTreesRunner { case class Params( input: String = null, @@ -93,24 +93,24 @@ object GradientBoostedTrees { def run(params: Params) { - val conf = new SparkConf().setAppName(s"GradientBoostedTrees with $params") + val conf = new SparkConf().setAppName(s"GradientBoostedTreesRunner with $params") val sc = new SparkContext(conf) - println(s"GradientBoostedTrees with parameters:\n$params") + println(s"GradientBoostedTreesRunner with parameters:\n$params") // Load training and test data and cache it. val (training, test, numClasses) = DecisionTreeRunner.loadDatasets(sc, params.input, params.dataFormat, params.testInput, Algo.withName(params.algo), params.fracTest) val boostingStrategy = BoostingStrategy.defaultParams(params.algo) - boostingStrategy.numClassesForClassification = numClasses + boostingStrategy.treeStrategy.numClassesForClassification = numClasses boostingStrategy.numIterations = params.numIterations - boostingStrategy.weakLearnerParams.maxDepth = params.maxDepth + boostingStrategy.treeStrategy.maxDepth = params.maxDepth val randomSeed = Utils.random.nextInt() if (params.algo == "Classification") { val startTime = System.nanoTime() - val model = GradientBoosting.trainClassifier(training, boostingStrategy) + val model = GradientBoostedTrees.train(training, boostingStrategy) val elapsedTime = (System.nanoTime() - startTime) / 1e9 println(s"Training time: $elapsedTime seconds") if (model.totalNumNodes < 30) { @@ -127,7 +127,7 @@ object GradientBoostedTrees { println(s"Test accuracy = $testAccuracy") } else if (params.algo == "Regression") { val startTime = System.nanoTime() - val model = GradientBoosting.trainRegressor(training, boostingStrategy) + val model = GradientBoostedTrees.train(training, boostingStrategy) val elapsedTime = (System.nanoTime() - startTime) / 1e9 println(s"Training time: $elapsedTime seconds") if (model.totalNumNodes < 30) { |