about summary refs log tree commit diff
path: root/examples/src
diff options
context:
space:
mode:
Diffstat (limited to 'examples/src')
-rw-r--r--examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java (renamed from examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java)18
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala18
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala (renamed from examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala)18
3 files changed, 22 insertions, 32 deletions
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java
index 1af2067b2b..4a5ac404ea 100644
--- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java
+++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java
@@ -27,18 +27,18 @@ import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.mllib.tree.GradientBoosting;
+import org.apache.spark.mllib.tree.GradientBoostedTrees;
import org.apache.spark.mllib.tree.configuration.BoostingStrategy;
-import org.apache.spark.mllib.tree.model.WeightedEnsembleModel;
+import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel;
import org.apache.spark.mllib.util.MLUtils;
/**
* Classification and regression using gradient-boosted decision trees.
*/
-public final class JavaGradientBoostedTrees {
+public final class JavaGradientBoostedTreesRunner {
private static void usage() {
- System.err.println("Usage: JavaGradientBoostedTrees <libsvm format data file>" +
+ System.err.println("Usage: JavaGradientBoostedTreesRunner <libsvm format data file>" +
" <Classification/Regression>");
System.exit(-1);
}
@@ -55,7 +55,7 @@ public final class JavaGradientBoostedTrees {
if (args.length > 2) {
usage();
}
- SparkConf sparkConf = new SparkConf().setAppName("JavaGradientBoostedTrees");
+ SparkConf sparkConf = new SparkConf().setAppName("JavaGradientBoostedTreesRunner");
JavaSparkContext sc = new JavaSparkContext(sparkConf);
JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache();
@@ -64,7 +64,7 @@ public final class JavaGradientBoostedTrees {
// Note: All features are treated as continuous.
BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams(algo);
boostingStrategy.setNumIterations(10);
- boostingStrategy.weakLearnerParams().setMaxDepth(5);
+ boostingStrategy.treeStrategy().setMaxDepth(5);
if (algo.equals("Classification")) {
// Compute the number of classes from the data.
@@ -73,10 +73,10 @@ public final class JavaGradientBoostedTrees {
return p.label();
}
}).countByValue().size();
- boostingStrategy.setNumClassesForClassification(numClasses); // ignored for Regression
+ boostingStrategy.treeStrategy().setNumClassesForClassification(numClasses);
// Train a GradientBoosting model for classification.
- final WeightedEnsembleModel model = GradientBoosting.trainClassifier(data, boostingStrategy);
+ final GradientBoostedTreesModel model = GradientBoostedTrees.train(data, boostingStrategy);
// Evaluate model on training instances and compute training error
JavaPairRDD<Double, Double> predictionAndLabel =
@@ -95,7 +95,7 @@ public final class JavaGradientBoostedTrees {
System.out.println("Learned classification tree model:\n" + model);
} else if (algo.equals("Regression")) {
// Train a GradientBoosting model for regression.
- final WeightedEnsembleModel model = GradientBoosting.trainRegressor(data, boostingStrategy);
+ final GradientBoostedTreesModel model = GradientBoostedTrees.train(data, boostingStrategy);
// Evaluate model on training instances and compute training error
JavaPairRDD<Double, Double> predictionAndLabel =
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
index 63f02cf7b9..98f9d1689c 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
@@ -22,11 +22,11 @@ import scopt.OptionParser
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._
import org.apache.spark.mllib.evaluation.MulticlassMetrics
+import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.{RandomForest, DecisionTree, impurity}
+import org.apache.spark.mllib.tree.{DecisionTree, RandomForest, impurity}
import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
import org.apache.spark.mllib.tree.configuration.Algo._
-import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils
@@ -352,21 +352,11 @@ object DecisionTreeRunner {
/**
* Calculates the mean squared error for regression.
*/
- private def meanSquaredError(tree: DecisionTreeModel, data: RDD[LabeledPoint]): Double = {
- data.map { y =>
- val err = tree.predict(y.features) - y.label
- err * err
- }.mean()
- }
-
- /**
- * Calculates the mean squared error for regression.
- */
private[mllib] def meanSquaredError(
- tree: WeightedEnsembleModel,
+ model: { def predict(features: Vector): Double },
data: RDD[LabeledPoint]): Double = {
data.map { y =>
- val err = tree.predict(y.features) - y.label
+ val err = model.predict(y.features) - y.label
err * err
}.mean()
}
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala
index 9b6db01448..1def8b45a2 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala
@@ -21,21 +21,21 @@ import scopt.OptionParser
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.evaluation.MulticlassMetrics
-import org.apache.spark.mllib.tree.GradientBoosting
+import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo}
import org.apache.spark.util.Utils
/**
* An example runner for Gradient Boosting using decision trees as weak learners. Run with
* {{{
- * ./bin/run-example org.apache.spark.examples.mllib.GradientBoostedTrees [options]
+ * ./bin/run-example mllib.GradientBoostedTreesRunner [options]
* }}}
* If you use it as a template to create your own app, please use `spark-submit` to submit your app.
*
* Note: This script treats all features as real-valued (not categorical).
* To include categorical features, modify categoricalFeaturesInfo.
*/
-object GradientBoostedTrees {
+object GradientBoostedTreesRunner {
case class Params(
input: String = null,
@@ -93,24 +93,24 @@ object GradientBoostedTrees {
def run(params: Params) {
- val conf = new SparkConf().setAppName(s"GradientBoostedTrees with $params")
+ val conf = new SparkConf().setAppName(s"GradientBoostedTreesRunner with $params")
val sc = new SparkContext(conf)
- println(s"GradientBoostedTrees with parameters:\n$params")
+ println(s"GradientBoostedTreesRunner with parameters:\n$params")
// Load training and test data and cache it.
val (training, test, numClasses) = DecisionTreeRunner.loadDatasets(sc, params.input,
params.dataFormat, params.testInput, Algo.withName(params.algo), params.fracTest)
val boostingStrategy = BoostingStrategy.defaultParams(params.algo)
- boostingStrategy.numClassesForClassification = numClasses
+ boostingStrategy.treeStrategy.numClassesForClassification = numClasses
boostingStrategy.numIterations = params.numIterations
- boostingStrategy.weakLearnerParams.maxDepth = params.maxDepth
+ boostingStrategy.treeStrategy.maxDepth = params.maxDepth
val randomSeed = Utils.random.nextInt()
if (params.algo == "Classification") {
val startTime = System.nanoTime()
- val model = GradientBoosting.trainClassifier(training, boostingStrategy)
+ val model = GradientBoostedTrees.train(training, boostingStrategy)
val elapsedTime = (System.nanoTime() - startTime) / 1e9
println(s"Training time: $elapsedTime seconds")
if (model.totalNumNodes < 30) {
@@ -127,7 +127,7 @@ object GradientBoostedTrees {
println(s"Test accuracy = $testAccuracy")
} else if (params.algo == "Regression") {
val startTime = System.nanoTime()
- val model = GradientBoosting.trainRegressor(training, boostingStrategy)
+ val model = GradientBoostedTrees.train(training, boostingStrategy)
val elapsedTime = (System.nanoTime() - startTime) / 1e9
println(s"Training time: $elapsedTime seconds")
if (model.totalNumNodes < 30) {