diff options
Diffstat (limited to 'examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingRegressionExample.scala')
-rw-r--r-- | examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingRegressionExample.scala | 66 |
1 files changed, 66 insertions, 0 deletions
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingRegressionExample.scala new file mode 100644 index 0000000000..3dc86da8e4 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingRegressionExample.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.{SparkContext, SparkConf} +// $example on$ +import org.apache.spark.mllib.tree.GradientBoostedTrees +import org.apache.spark.mllib.tree.configuration.BoostingStrategy +import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel +import org.apache.spark.mllib.util.MLUtils +// $example off$ + +object GradientBoostingRegressionExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("GradientBoostedTreesRegressionExample") + val sc = new SparkContext(conf) + // $example on$ + // Load and parse the data file. + val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") + // Split the data into training and test sets (30% held out for testing) + val splits = data.randomSplit(Array(0.7, 0.3)) + val (trainingData, testData) = (splits(0), splits(1)) + + // Train a GradientBoostedTrees model. + // The defaultParams for Regression use SquaredError by default. + val boostingStrategy = BoostingStrategy.defaultParams("Regression") + boostingStrategy.numIterations = 3 // Note: Use more iterations in practice. + boostingStrategy.treeStrategy.maxDepth = 5 + // Empty categoricalFeaturesInfo indicates all features are continuous. + boostingStrategy.treeStrategy.categoricalFeaturesInfo = Map[Int, Int]() + + val model = GradientBoostedTrees.train(trainingData, boostingStrategy) + + // Evaluate model on test instances and compute test error + val labelsAndPredictions = testData.map { point => + val prediction = model.predict(point.features) + (point.label, prediction) + } + val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean() + println("Test Mean Squared Error = " + testMSE) + println("Learned regression GBT model:\n" + model.toDebugString) + + // Save and load model + model.save(sc, "target/tmp/myGradientBoostingRegressionModel") + val sameModel = GradientBoostedTreesModel.load(sc, + "target/tmp/myGradientBoostingRegressionModel") + // $example off$ + } +} +// scalastyle:on println |