author     Manish Amde <manish9ue@gmail.com>    2014-10-31 18:57:55 -0700
committer  Xiangrui Meng <meng@databricks.com>  2014-10-31 18:57:55 -0700
commit     8602195510f5821b37746bb7fa24902f43a1bd93 (patch)
tree       3aa02d71f241f42bbe5959a61c9d5530895ce25c
parent     e07fb6a41ee949f8dba44d5a3b6c0615f27f0eaf (diff)
[MLLIB] SPARK-1547: Add Gradient Boosting to MLlib
Given the popular demand for gradient boosting and AdaBoost in MLlib, I am creating a WIP branch for early feedback on gradient boosting, with AdaBoost to follow soon after this PR is accepted. This is based on work done along with hirakendu that was pending due to decision tree optimizations and random forests work.

Ideally, boosting algorithms should work with any base learners. This will soon be possible once the MLlib API is finalized -- we want to ensure we use a consistent interface for the underlying base learners. In the meantime, this PR uses decision trees as base learners for the gradient boosting algorithm. The current PR allows "pluggable" loss functions and provides least squares error and least absolute error by default.

Here is the task list:
- [x] Gradient boosting support
- [x] Pluggable loss functions
- [x] Stochastic gradient boosting support -- re-uses the BaggedPoint approach used for RandomForest
- [x] Binary classification support
- [x] Support configurable checkpointing -- this approach will avoid long lineage chains
- [x] Create classification and regression APIs
- [x] Weighted ensemble model -- created a WeightedEnsembleModel class that can be used by ensemble algorithms such as random forests and boosting
- [x] Unit tests

Future work:
+ Multi-class classification is currently not supported by this PR since it requires discussion on the best way to support "deviance" as a loss function.
+ BaggedRDD caching -- avoid repeating the feature-to-bin mapping for each tree estimator after the standard API work is completed.

cc: jkbradley hirakendu mengxr etrain atalwalkar chouqin

Author: Manish Amde <manish9ue@gmail.com>
Author: manishamde <manish9ue@gmail.com>

Closes #2607 from manishamde/gbt and squashes the following commits:

991c7b5 [Manish Amde] public api
ff2a796 [Manish Amde] addressing comments
b4c1318 [Manish Amde] removing spaces
8476b6b [Manish Amde] fixing line length
0183cb9 [Manish Amde] fixed naming and formatting issues
1c40c33 [Manish Amde] add newline, removed spaces
e33ab61 [Manish Amde] minor comment
eadbf09 [Manish Amde] parameter renaming
035a2ed [Manish Amde] jkbradley formatting suggestions
9f7359d [Manish Amde] simplified gbt logic and added more tests
49ba107 [Manish Amde] merged from master
eff21fe [Manish Amde] Added gradient boosting tests
3fd0528 [Manish Amde] moved helper methods to new class
a32a5ab [Manish Amde] added test for subsampling without replacement
781542a [Manish Amde] added support for fractional subsampling with replacement
3a18cc1 [Manish Amde] cleaned up api for conversion to bagged point and moved tests to its own test suite
0e81906 [Manish Amde] improving caching unpersisting logic
d971f73 [Manish Amde] moved RF code to use WeightedEnsembleModel class
fee06d3 [Manish Amde] added weighted ensemble model
1b01943 [Manish Amde] add weights for base learners
9bc6e74 [Manish Amde] adding random seed as parameter
d2c8323 [Manish Amde] Merge branch 'master' into gbt
2ae97b7 [Manish Amde] added documentation for the loss classes
9366b8f [Manish Amde] minor: using numTrees instead of trees.size
3b43896 [Manish Amde] added learning rate for prediction
9b2e35e [Manish Amde] Merge branch 'master' into gbt
6a11c02 [manishamde] fixing formatting
823691b [Manish Amde] fixing RF test
1f47941 [Manish Amde] changing access modifier
5b67102 [Manish Amde] shortened parameter list
5ab3796 [Manish Amde] minor reformatting
9155a9d [Manish Amde] consolidated boosting configuration and added public API
631baea [Manish Amde] Merge branch 'master' into gbt
2cb1258 [Manish Amde] public API support
3b8ffc0 [Manish Amde] added documentation
8e10c63 [Manish Amde] modified unpersist strategy
f62bc48 [Manish Amde] added unpersist
bdca43a [Manish Amde] added timing parameters
2fbc9c7 [Manish Amde] fixing binomial classification prediction
6dd4dd8 [Manish Amde] added support for log loss
9af0231 [Manish Amde] classification attempt
62cc000 [Manish Amde] basic checkpointing
4784091 [Manish Amde] formatting
78ed452 [Manish Amde] added newline and fixed if statement
3973dd1 [Manish Amde] minor indicating subsample is double during comparison
aa8fae7 [Manish Amde] minor refactoring
1a8031c [Manish Amde] sampling with replacement
f1c9ef7 [Manish Amde] Merge branch 'master' into gbt
cdceeef [Manish Amde] added documentation
6251fd5 [Manish Amde] modified method name
5538521 [Manish Amde] disable checkpointing for now
0ae1c0a [Manish Amde] basic gradient boosting code from earlier branches
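For context, here is a minimal end-to-end usage sketch of the public API added by this patch. The SparkContext setup and dataset path are illustrative and not part of the change; the API calls (BoostingStrategy.defaultParams, GradientBoosting.trainRegressor, model.weakHypotheses) are taken from the diff below.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.tree.GradientBoosting
import org.apache.spark.mllib.tree.configuration.Algo.Regression
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.mllib.util.MLUtils

object GradientBoostingExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("GradientBoostingExample"))
    // Illustrative LIBSVM-format dataset; any RDD[LabeledPoint] works.
    val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")

    // Default regression configuration: 100 depth-3 regression trees with squared error loss.
    val boostingStrategy = BoostingStrategy.defaultParams(Regression)
    val model = GradientBoosting.trainRegressor(data, boostingStrategy)

    // Training mean squared error, mirroring DecisionTreeRunner.meanSquaredError below.
    val trainMSE = data.map { lp =>
      val err = model.predict(lp.features) - lp.label
      err * err
    }.mean()
    println(s"Trained ${model.weakHypotheses.length} trees, training MSE = $trainMSE")
    sc.stop()
  }
}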
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala            |    4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala                         |    2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala                     |  314
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala                         |   49
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala       |  109
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/EnsembleCombiningStrategy.scala |   30
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala               |   23
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala                     |   69
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala                   |   66
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala                         |   63
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala                            |   52
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Losses.scala                          |   29
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala                    |   66
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala              |  115
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/model/WeightedEnsembleModel.scala          |  158
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala                    |    6
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala                   |   94
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala                |  132
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala                    |  117
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala                |  100
20 files changed, 1331 insertions(+), 267 deletions(-)
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
index 0890e6263e..f98730366b 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
@@ -26,7 +26,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{RandomForest, DecisionTree, impurity}
import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
import org.apache.spark.mllib.tree.configuration.Algo._
-import org.apache.spark.mllib.tree.model.{RandomForestModel, DecisionTreeModel}
+import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils
@@ -317,7 +317,7 @@ object DecisionTreeRunner {
/**
* Calculates the mean squared error for regression.
*/
- private def meanSquaredError(tree: RandomForestModel, data: RDD[LabeledPoint]): Double = {
+ private def meanSquaredError(tree: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double = {
data.map { y =>
val err = tree.predict(y.features) - y.label
err * err
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 6737a2f417..752ed59a03 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -62,7 +62,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
// Note: random seed will not be used since numTrees = 1.
val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", seed = 0)
val rfModel = rf.train(input)
- rfModel.trees(0)
+ rfModel.weakHypotheses(0)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala
new file mode 100644
index 0000000000..1a847201ce
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala
@@ -0,0 +1,314 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.mllib.tree.configuration.{Strategy, BoostingStrategy}
+import org.apache.spark.Logging
+import org.apache.spark.mllib.tree.impl.TimeTracker
+import org.apache.spark.mllib.tree.loss.Losses
+import org.apache.spark.rdd.RDD
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel}
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy.Sum
+
+/**
+ * :: Experimental ::
+ * A class that implements gradient boosting for regression and binary classification problems.
+ * @param boostingStrategy Parameters for the gradient boosting algorithm
+ */
+@Experimental
+class GradientBoosting (
+ private val boostingStrategy: BoostingStrategy) extends Serializable with Logging {
+
+ /**
+ * Method to train a gradient boosting model
+ * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+ * @return WeightedEnsembleModel that can be used for prediction
+ */
+ def train(input: RDD[LabeledPoint]): WeightedEnsembleModel = {
+ val algo = boostingStrategy.algo
+ algo match {
+ case Regression => GradientBoosting.boost(input, boostingStrategy)
+ case Classification =>
+ val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
+ GradientBoosting.boost(remappedInput, boostingStrategy)
+ case _ =>
+ throw new IllegalArgumentException(s"$algo is not supported by gradient boosting.")
+ }
+ }
+
+}
+
+
+object GradientBoosting extends Logging {
+
+ /**
+ * Method to train a gradient boosting model.
+ *
+ * Note: Using [[org.apache.spark.mllib.tree.GradientBoosting$#trainRegressor]]
+ * is recommended to clearly specify regression.
+ * Using [[org.apache.spark.mllib.tree.GradientBoosting$#trainClassifier]]
+ * is recommended to clearly specify classification.
+ *
+ * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+ * For classification, labels should take values {0, 1, ..., numClasses-1}.
+ * For regression, labels are real numbers.
+ * @param boostingStrategy Configuration options for the boosting algorithm.
+ * @return WeightedEnsembleModel that can be used for prediction
+ */
+ def train(
+ input: RDD[LabeledPoint],
+ boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
+ new GradientBoosting(boostingStrategy).train(input)
+ }
+
+ /**
+ * Method to train a gradient boosting classification model.
+ *
+ * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+ * For classification, labels should take values {0, 1, ..., numClasses-1}.
+ * For regression, labels are real numbers.
+ * @param boostingStrategy Configuration options for the boosting algorithm.
+ * @return WeightedEnsembleModel that can be used for prediction
+ */
+ def trainClassifier(
+ input: RDD[LabeledPoint],
+ boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
+ val algo = boostingStrategy.algo
+ require(algo == Classification, s"Only Classification algo supported. Provided algo is $algo.")
+ new GradientBoosting(boostingStrategy).train(input)
+ }
+
+ /**
+ * Method to train a gradient boosting regression model.
+ *
+ * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+ * For classification, labels should take values {0, 1, ..., numClasses-1}.
+ * For regression, labels are real numbers.
+ * @param boostingStrategy Configuration options for the boosting algorithm.
+ * @return WeightedEnsembleModel that can be used for prediction
+ */
+ def trainRegressor(
+ input: RDD[LabeledPoint],
+ boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
+ val algo = boostingStrategy.algo
+ require(algo == Regression, s"Only Regression algo supported. Provided algo is $algo.")
+ new GradientBoosting(boostingStrategy).train(input)
+ }
+
+ /**
+ * Method to train a gradient boosting binary classification model.
+ *
+ * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+ * For classification, labels should take values {0, 1, ..., numClasses-1}.
+ * For regression, labels are real numbers.
+ * @param numEstimators Number of estimators used in boosting stages. In other words,
+ * number of boosting iterations performed.
+ * @param loss Loss function used for minimization during gradient boosting.
+ * @param learningRate Learning rate for shrinking the contribution of each estimator. The
+ * learning rate should be in the interval (0, 1].
+ * @param subsamplingRate Fraction of the training data used for learning the decision tree.
+ * @param numClassesForClassification Number of classes for classification.
+ * (Ignored for regression.)
+ * @param categoricalFeaturesInfo A map storing information about the categorical variables and
+ * the number of discrete values they take. For example,
+ * an entry (n -> k) implies the feature n is categorical with k
+ * categories 0, 1, 2, ... , k-1. It's important to note that
+ * features are zero-indexed.
+ * @param weakLearnerParams Parameters for the weak learner. (Currently only decision tree is
+ * supported.)
+ * @return WeightedEnsembleModel that can be used for prediction
+ */
+ def trainClassifier(
+ input: RDD[LabeledPoint],
+ numEstimators: Int,
+ loss: String,
+ learningRate: Double,
+ subsamplingRate: Double,
+ numClassesForClassification: Int,
+ categoricalFeaturesInfo: Map[Int, Int],
+ weakLearnerParams: Strategy): WeightedEnsembleModel = {
+ val lossType = Losses.fromString(loss)
+ val boostingStrategy = new BoostingStrategy(Classification, numEstimators, lossType,
+ learningRate, subsamplingRate, numClassesForClassification, categoricalFeaturesInfo,
+ weakLearnerParams)
+ new GradientBoosting(boostingStrategy).train(input)
+ }
+
+ /**
+ * Method to train a gradient boosting regression model.
+ *
+ * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+ * For classification, labels should take values {0, 1, ..., numClasses-1}.
+ * For regression, labels are real numbers.
+ * @param numEstimators Number of estimators used in boosting stages. In other words,
+ * number of boosting iterations performed.
+ * @param loss Loss function used for minimization during gradient boosting.
+ * @param learningRate Learning rate for shrinking the contribution of each estimator. The
+ * learning rate should be in the interval (0, 1].
+ * @param subsamplingRate Fraction of the training data used for learning the decision tree.
+ * @param numClassesForClassification Number of classes for classification.
+ * (Ignored for regression.)
+ * @param categoricalFeaturesInfo A map storing information about the categorical variables and
+ * the number of discrete values they take. For example,
+ * an entry (n -> k) implies the feature n is categorical with k
+ * categories 0, 1, 2, ... , k-1. It's important to note that
+ * features are zero-indexed.
+ * @param weakLearnerParams Parameters for the weak learner. (Currently only decision tree is
+ * supported.)
+ * @return WeightedEnsembleModel that can be used for prediction
+ */
+ def trainRegressor(
+ input: RDD[LabeledPoint],
+ numEstimators: Int,
+ loss: String,
+ learningRate: Double,
+ subsamplingRate: Double,
+ numClassesForClassification: Int,
+ categoricalFeaturesInfo: Map[Int, Int],
+ weakLearnerParams: Strategy): WeightedEnsembleModel = {
+ val lossType = Losses.fromString(loss)
+ val boostingStrategy = new BoostingStrategy(Regression, numEstimators, lossType,
+ learningRate, subsamplingRate, numClassesForClassification, categoricalFeaturesInfo,
+ weakLearnerParams)
+ new GradientBoosting(boostingStrategy).train(input)
+ }
+
+ /**
+ * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#trainClassifier]]
+ */
+ def trainClassifier(
+ input: RDD[LabeledPoint],
+ numEstimators: Int,
+ loss: String,
+ learningRate: Double,
+ subsamplingRate: Double,
+ numClassesForClassification: Int,
+ categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer],
+ weakLearnerParams: Strategy): WeightedEnsembleModel = {
+ trainClassifier(input, numEstimators, loss, learningRate, subsamplingRate,
+ numClassesForClassification,
+ categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
+ weakLearnerParams)
+ }
+
+ /**
+ * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#trainRegressor]]
+ */
+ def trainRegressor(
+ input: RDD[LabeledPoint],
+ numEstimators: Int,
+ loss: String,
+ learningRate: Double,
+ subsamplingRate: Double,
+ numClassesForClassification: Int,
+ categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer],
+ weakLearnerParams: Strategy): WeightedEnsembleModel = {
+ trainRegressor(input, numEstimators, loss, learningRate, subsamplingRate,
+ numClassesForClassification,
+ categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
+ weakLearnerParams)
+ }
+
+
+ /**
+ * Internal method for performing regression using trees as base learners.
+ * @param input training dataset
+ * @param boostingStrategy boosting parameters
+ * @return WeightedEnsembleModel that can be used for prediction
+ */
+ private def boost(
+ input: RDD[LabeledPoint],
+ boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
+
+ val timer = new TimeTracker()
+ timer.start("total")
+ timer.start("init")
+
+ // Initialize gradient boosting parameters
+ val numEstimators = boostingStrategy.numEstimators
+ val baseLearners = new Array[DecisionTreeModel](numEstimators)
+ val baseLearnerWeights = new Array[Double](numEstimators)
+ val loss = boostingStrategy.loss
+ val learningRate = boostingStrategy.learningRate
+ val strategy = boostingStrategy.weakLearnerParams
+
+ // Cache input
+ input.persist(StorageLevel.MEMORY_AND_DISK)
+
+ timer.stop("init")
+
+ logDebug("##########")
+ logDebug("Building tree 0")
+ logDebug("##########")
+ var data = input
+
+ // 1. Initialize tree
+ timer.start("building tree 0")
+ val firstTreeModel = new DecisionTree(strategy).train(data)
+ baseLearners(0) = firstTreeModel
+ baseLearnerWeights(0) = 1.0
+ val startingModel = new WeightedEnsembleModel(Array(firstTreeModel), Array(1.0), Regression,
+ Sum)
+ logDebug("error of gbt = " + loss.computeError(startingModel, input))
+ // Note: A model of type regression is used since we require raw prediction
+ timer.stop("building tree 0")
+
+ // Pseudo-residuals for the second iteration
+ data = input.map(point => LabeledPoint(loss.gradient(startingModel, point),
+ point.features))
+
+ var m = 1
+ while (m < numEstimators) {
+ timer.start(s"building tree $m")
+ logDebug("###################################################")
+ logDebug("Gradient boosting tree iteration " + m)
+ logDebug("###################################################")
+ val model = new DecisionTree(strategy).train(data)
+ timer.stop(s"building tree $m")
+ // Create partial model
+ baseLearners(m) = model
+ baseLearnerWeights(m) = learningRate
+ // Note: A model of type regression is used since we require raw prediction
+ val partialModel = new WeightedEnsembleModel(baseLearners.slice(0, m + 1),
+ baseLearnerWeights.slice(0, m + 1), Regression, Sum)
+ logDebug("error of gbt = " + loss.computeError(partialModel, input))
+ // Update data with pseudo-residuals
+ data = input.map(point => LabeledPoint(-loss.gradient(partialModel, point),
+ point.features))
+ m += 1
+ }
+
+ timer.stop("total")
+
+ logInfo("Internal timing for DecisionTree:")
+ logInfo(s"$timer")
+
+
+ // 3. Output classifier
+ new WeightedEnsembleModel(baseLearners, baseLearnerWeights, boostingStrategy.algo, Sum)
+
+ }
+
+}
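For intuition, the weights assigned in boost() (1.0 for the first tree, learningRate thereafter) combined with the Sum strategy yield the usual additive model F_M(x) = h_0(x) + learningRate * (h_1(x) + ... + h_M(x)). The following is a schematic, non-Spark sketch of that weighted combination, not the actual WeightedEnsembleModel implementation added in model/WeightedEnsembleModel.scala.

import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.tree.model.DecisionTreeModel

// Schematic only: weighted sum of per-tree raw predictions, matching the
// regression-typed ensembles built inside GradientBoosting.boost.
def weightedSumPredict(
    trees: Array[DecisionTreeModel],
    weights: Array[Double],
    features: Vector): Double = {
  trees.zip(weights).map { case (tree, w) => w * tree.predict(features) }.sum
}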
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
index ebbd8e0257..1dcaf91438 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -26,6 +26,7 @@ import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
+import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy.Average
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impl.{BaggedPoint, TreePoint, DecisionTreeMetadata, TimeTracker}
import org.apache.spark.mllib.tree.impurity.Impurities
@@ -59,7 +60,7 @@ import org.apache.spark.util.Utils
* if numTrees == 1, set to "all";
* if numTrees > 1 (forest) set to "sqrt" for classification and
* to "onethird" for regression.
- * @param seed Random seed for bootstrapping and choosing feature subsets.
+ * @param seed Random seed for bootstrapping and choosing feature subsets.
*/
@Experimental
private class RandomForest (
@@ -78,9 +79,9 @@ private class RandomForest (
/**
* Method to train a decision tree model over an RDD
* @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
- * @return RandomForestModel that can be used for prediction
+ * @return WeightedEnsembleModel that can be used for prediction
*/
- def train(input: RDD[LabeledPoint]): RandomForestModel = {
+ def train(input: RDD[LabeledPoint]): WeightedEnsembleModel = {
val timer = new TimeTracker()
@@ -111,11 +112,20 @@ private class RandomForest (
// Bin feature values (TreePoint representation).
// Cache input RDD for speedup during multiple passes.
val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata)
- val baggedInput = if (numTrees > 1) {
- BaggedPoint.convertToBaggedRDD(treeInput, numTrees, seed)
- } else {
- BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput)
- }.persist(StorageLevel.MEMORY_AND_DISK)
+
+ val (subsample, withReplacement) = {
+ // TODO: Have a stricter check for RF in the strategy
+ val isRandomForest = numTrees > 1
+ if (isRandomForest) {
+ (1.0, true)
+ } else {
+ (strategy.subsamplingRate, false)
+ }
+ }
+
+ val baggedInput
+ = BaggedPoint.convertToBaggedRDD(treeInput, subsample, numTrees, withReplacement, seed)
+ .persist(StorageLevel.MEMORY_AND_DISK)
// depth of the decision tree
val maxDepth = strategy.maxDepth
@@ -184,7 +194,8 @@ private class RandomForest (
logInfo(s"$timer")
val trees = topNodes.map(topNode => new DecisionTreeModel(topNode, strategy.algo))
- RandomForestModel.build(trees)
+ val treeWeights = Array.fill[Double](numTrees)(1.0)
+ new WeightedEnsembleModel(trees, treeWeights, strategy.algo, Average)
}
}
@@ -205,14 +216,14 @@ object RandomForest extends Serializable with Logging {
* if numTrees > 1 (forest) set to "sqrt" for classification and
* to "onethird" for regression.
* @param seed Random seed for bootstrapping and choosing feature subsets.
- * @return RandomForestModel that can be used for prediction
+ * @return WeightedEnsembleModel that can be used for prediction
*/
def trainClassifier(
input: RDD[LabeledPoint],
strategy: Strategy,
numTrees: Int,
featureSubsetStrategy: String,
- seed: Int): RandomForestModel = {
+ seed: Int): WeightedEnsembleModel = {
require(strategy.algo == Classification,
s"RandomForest.trainClassifier given Strategy with invalid algo: ${strategy.algo}")
val rf = new RandomForest(strategy, numTrees, featureSubsetStrategy, seed)
@@ -243,7 +254,7 @@ object RandomForest extends Serializable with Logging {
* @param maxBins maximum number of bins used for splitting features
* (suggested value: 100)
* @param seed Random seed for bootstrapping and choosing feature subsets.
- * @return RandomForestModel that can be used for prediction
+ * @return WeightedEnsembleModel that can be used for prediction
*/
def trainClassifier(
input: RDD[LabeledPoint],
@@ -254,7 +265,7 @@ object RandomForest extends Serializable with Logging {
impurity: String,
maxDepth: Int,
maxBins: Int,
- seed: Int = Utils.random.nextInt()): RandomForestModel = {
+ seed: Int = Utils.random.nextInt()): WeightedEnsembleModel = {
val impurityType = Impurities.fromString(impurity)
val strategy = new Strategy(Classification, impurityType, maxDepth,
numClassesForClassification, maxBins, Sort, categoricalFeaturesInfo)
@@ -273,7 +284,7 @@ object RandomForest extends Serializable with Logging {
impurity: String,
maxDepth: Int,
maxBins: Int,
- seed: Int): RandomForestModel = {
+ seed: Int): WeightedEnsembleModel = {
trainClassifier(input.rdd, numClassesForClassification,
categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed)
@@ -293,14 +304,14 @@ object RandomForest extends Serializable with Logging {
* if numTrees > 1 (forest) set to "sqrt" for classification and
* to "onethird" for regression.
* @param seed Random seed for bootstrapping and choosing feature subsets.
- * @return RandomForestModel that can be used for prediction
+ * @return WeightedEnsembleModel that can be used for prediction
*/
def trainRegressor(
input: RDD[LabeledPoint],
strategy: Strategy,
numTrees: Int,
featureSubsetStrategy: String,
- seed: Int): RandomForestModel = {
+ seed: Int): WeightedEnsembleModel = {
require(strategy.algo == Regression,
s"RandomForest.trainRegressor given Strategy with invalid algo: ${strategy.algo}")
val rf = new RandomForest(strategy, numTrees, featureSubsetStrategy, seed)
@@ -330,7 +341,7 @@ object RandomForest extends Serializable with Logging {
* @param maxBins maximum number of bins used for splitting features
* (suggested value: 100)
* @param seed Random seed for bootstrapping and choosing feature subsets.
- * @return RandomForestModel that can be used for prediction
+ * @return WeightedEnsembleModel that can be used for prediction
*/
def trainRegressor(
input: RDD[LabeledPoint],
@@ -340,7 +351,7 @@ object RandomForest extends Serializable with Logging {
impurity: String,
maxDepth: Int,
maxBins: Int,
- seed: Int = Utils.random.nextInt()): RandomForestModel = {
+ seed: Int = Utils.random.nextInt()): WeightedEnsembleModel = {
val impurityType = Impurities.fromString(impurity)
val strategy = new Strategy(Regression, impurityType, maxDepth,
0, maxBins, Sort, categoricalFeaturesInfo)
@@ -358,7 +369,7 @@ object RandomForest extends Serializable with Logging {
impurity: String,
maxDepth: Int,
maxBins: Int,
- seed: Int): RandomForestModel = {
+ seed: Int): WeightedEnsembleModel = {
trainRegressor(input.rdd,
categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed)
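Because RandomForestModel is removed in favor of WeightedEnsembleModel, forest callers now access individual trees through weakHypotheses. A small sketch follows, assuming trainingData is an existing RDD[LabeledPoint]; the parameter values are illustrative.

import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.tree.configuration.Algo.Classification
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impurity.Gini

// Ten-tree forest; the returned WeightedEnsembleModel uses the Average combining
// strategy with uniform tree weights of 1.0, as set at the end of RandomForest.train.
val strategy = new Strategy(Classification, Gini, maxDepth = 5, numClassesForClassification = 2)
val forest = RandomForest.trainClassifier(trainingData, strategy, numTrees = 10,
  featureSubsetStrategy = "sqrt", seed = 12345)
println(s"Forest has ${forest.weakHypotheses.length} trees")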
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
new file mode 100644
index 0000000000..501d9ff9ea
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.configuration
+
+import scala.beans.BeanProperty
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.impurity.{Gini, Variance}
+import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss}
+
+/**
+ * :: Experimental ::
+ * Stores all the configuration options for the boosting algorithms
+ * @param algo Learning goal. Supported:
+ * [[org.apache.spark.mllib.tree.configuration.Algo.Classification]],
+ * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]]
+ * @param numEstimators Number of estimators used in boosting stages. In other words,
+ * number of boosting iterations performed.
+ * @param loss Loss function used for minimization during gradient boosting.
+ * @param learningRate Learning rate for shrinking the contribution of each estimator. The
+ * learning rate should be in the interval (0, 1].
+ * @param subsamplingRate Fraction of the training data used for learning the decision tree.
+ * @param numClassesForClassification Number of classes for classification.
+ * (Ignored for regression.)
+ * Default value is 2 (binary classification).
+ * @param categoricalFeaturesInfo A map storing information about the categorical variables and the
+ * number of discrete values they take. For example, an entry (n ->
+ * k) implies the feature n is categorical with k categories 0,
+ * 1, 2, ... , k-1. It's important to note that features are
+ * zero-indexed.
+ * @param weakLearnerParams Parameters for weak learners. Currently only decision trees are
+ * supported.
+ */
+@Experimental
+case class BoostingStrategy(
+ // Required boosting parameters
+ algo: Algo,
+ @BeanProperty var numEstimators: Int,
+ @BeanProperty var loss: Loss,
+ // Optional boosting parameters
+ @BeanProperty var learningRate: Double = 0.1,
+ @BeanProperty var subsamplingRate: Double = 1.0,
+ @BeanProperty var numClassesForClassification: Int = 2,
+ @BeanProperty var categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
+ @BeanProperty var weakLearnerParams: Strategy) extends Serializable {
+
+ require(learningRate <= 1, "Learning rate should be <= 1. Provided learning rate is " +
+ s"$learningRate.")
+ require(learningRate > 0, "Learning rate should be > 0. Provided learning rate is " +
+ s"$learningRate.")
+
+ // Ensure values for weak learner are the same as what is provided to the boosting algorithm.
+ weakLearnerParams.categoricalFeaturesInfo = categoricalFeaturesInfo
+ weakLearnerParams.numClassesForClassification = numClassesForClassification
+ weakLearnerParams.subsamplingRate = subsamplingRate
+
+}
+
+@Experimental
+object BoostingStrategy {
+
+ /**
+ * Returns default configuration for the boosting algorithm
+ * @param algo Learning goal. Supported:
+ * [[org.apache.spark.mllib.tree.configuration.Algo.Classification]],
+ * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]]
+ * @return Configuration for boosting algorithm
+ */
+ def defaultParams(algo: Algo): BoostingStrategy = {
+ val treeStrategy = defaultWeakLearnerParams(algo)
+ algo match {
+ case Classification =>
+ new BoostingStrategy(algo, 100, LogLoss, weakLearnerParams = treeStrategy)
+ case Regression =>
+ new BoostingStrategy(algo, 100, SquaredError, weakLearnerParams = treeStrategy)
+ case _ =>
+ throw new IllegalArgumentException(s"$algo is not supported by boosting.")
+ }
+ }
+
+ /**
+ * Returns default configuration for the weak learner (decision tree) algorithm
+ * @param algo Learning goal. Supported:
+ * [[org.apache.spark.mllib.tree.configuration.Algo.Classification]],
+ * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]]
+ * @return Configuration for weak learner
+ */
+ def defaultWeakLearnerParams(algo: Algo): Strategy = {
+ // Note: Regression tree used even for classification for GBT.
+ new Strategy(Regression, Variance, 3)
+ }
+
+}
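A brief configuration sketch: the setNumEstimators/setLearningRate/setMaxDepth setters are generated by the @BeanProperty annotations above, and the values shown are illustrative.

import org.apache.spark.mllib.tree.configuration.Algo.Classification
import org.apache.spark.mllib.tree.configuration.BoostingStrategy

// Start from the defaults (100 iterations, LogLoss, depth-3 regression trees)
// and override a few knobs through the generated bean setters.
val boostingStrategy = BoostingStrategy.defaultParams(Classification)
boostingStrategy.setNumEstimators(50)              // fewer boosting iterations
boostingStrategy.setLearningRate(0.05)             // stronger shrinkage
boostingStrategy.weakLearnerParams.setMaxDepth(4)  // slightly deeper weak learners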
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/EnsembleCombiningStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/EnsembleCombiningStrategy.scala
new file mode 100644
index 0000000000..82889dc00c
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/EnsembleCombiningStrategy.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.configuration
+
+import org.apache.spark.annotation.DeveloperApi
+
+/**
+ * :: DeveloperApi ::
+ * Enum to select ensemble combining strategy for base learners
+ */
+@DeveloperApi
+object EnsembleCombiningStrategy extends Enumeration {
+ type EnsembleCombiningStrategy = Value
+ val Sum, Average = Value
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index caaccbfb8a..2ed63cf002 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -17,6 +17,7 @@
package org.apache.spark.mllib.tree.configuration
+import scala.beans.BeanProperty
import scala.collection.JavaConverters._
import org.apache.spark.annotation.Experimental
@@ -43,7 +44,7 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
* for choosing how to split on features at each node.
* More bins give higher granularity.
* @param quantileCalculationStrategy Algorithm for calculating quantiles. Supported:
- * [[org.apache.spark.mllib.tree.configuration.QuantileStrategy.Sort]]
+ * [[org.apache.spark.mllib.tree.configuration.QuantileStrategy.Sort]]
* @param categoricalFeaturesInfo A map storing information about the categorical variables and the
* number of discrete values they take. For example, an entry (n ->
* k) implies the feature n is categorical with k categories 0,
@@ -58,19 +59,21 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
* this split will not be considered as a valid split.
* @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. Default value is
* 256 MB.
+ * @param subsamplingRate Fraction of the training data used for learning decision tree.
*/
@Experimental
class Strategy (
val algo: Algo,
- val impurity: Impurity,
- val maxDepth: Int,
- val numClassesForClassification: Int = 2,
- val maxBins: Int = 32,
- val quantileCalculationStrategy: QuantileStrategy = Sort,
- val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
- val minInstancesPerNode: Int = 1,
- val minInfoGain: Double = 0.0,
- val maxMemoryInMB: Int = 256) extends Serializable {
+ @BeanProperty var impurity: Impurity,
+ @BeanProperty var maxDepth: Int,
+ @BeanProperty var numClassesForClassification: Int = 2,
+ @BeanProperty var maxBins: Int = 32,
+ @BeanProperty var quantileCalculationStrategy: QuantileStrategy = Sort,
+ @BeanProperty var categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
+ @BeanProperty var minInstancesPerNode: Int = 1,
+ @BeanProperty var minInfoGain: Double = 0.0,
+ @BeanProperty var maxMemoryInMB: Int = 256,
+ @BeanProperty var subsamplingRate: Double = 1) extends Serializable {
if (algo == Classification) {
require(numClassesForClassification >= 2)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala
index e7a2127c5d..089010c81f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala
@@ -21,13 +21,14 @@ import org.apache.commons.math3.distribution.PoissonDistribution
import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils
+import org.apache.spark.util.random.XORShiftRandom
/**
* Internal representation of a datapoint which belongs to several subsamples of the same dataset,
* particularly for bagging (e.g., for random forests).
*
* This holds one instance, as well as an array of weights which represent the (weighted)
- * number of times which this instance appears in each subsample.
+ * number of times this instance appears in each subsample.
* E.g., (datum, [1, 0, 4]) indicates that there are 3 subsamples of the dataset and that
* this datum has 1 copy, 0 copies, and 4 copies in the 3 subsamples, respectively.
*
@@ -44,22 +45,65 @@ private[tree] object BaggedPoint {
/**
* Convert an input dataset into its BaggedPoint representation,
- * choosing subsample counts for each instance.
- * Each subsample has the same number of instances as the original dataset,
- * and is created by subsampling with replacement.
- * @param input Input dataset.
- * @param numSubsamples Number of subsamples of this RDD to take.
- * @param seed Random seed.
- * @return BaggedPoint dataset representation
+ * choosing subsample counts for each instance.
+ * Each subsample is generated by sampling the original dataset with or without
+ * replacement at the given subsamplingRate.
+ * @param input Input dataset.
+ * @param subsamplingRate Fraction of the training data used for learning decision tree.
+ * @param numSubsamples Number of subsamples of this RDD to take.
+ * @param withReplacement Sampling with/without replacement.
+ * @param seed Random seed.
+ * @return BaggedPoint dataset representation.
*/
- def convertToBaggedRDD[Datum](
+ def convertToBaggedRDD[Datum] (
input: RDD[Datum],
+ subsamplingRate: Double,
numSubsamples: Int,
+ withReplacement: Boolean,
seed: Int = Utils.random.nextInt()): RDD[BaggedPoint[Datum]] = {
+ if (withReplacement) {
+ convertToBaggedRDDSamplingWithReplacement(input, subsamplingRate, numSubsamples, seed)
+ } else {
+ if (numSubsamples == 1 && subsamplingRate == 1.0) {
+ convertToBaggedRDDWithoutSampling(input)
+ } else {
+ convertToBaggedRDDSamplingWithoutReplacement(input, subsamplingRate, numSubsamples, seed)
+ }
+ }
+ }
+
+ private def convertToBaggedRDDSamplingWithoutReplacement[Datum] (
+ input: RDD[Datum],
+ subsamplingRate: Double,
+ numSubsamples: Int,
+ seed: Int): RDD[BaggedPoint[Datum]] = {
+ input.mapPartitionsWithIndex { (partitionIndex, instances) =>
+ // Use random seed = seed + partitionIndex + 1 to make generation reproducible.
+ val rng = new XORShiftRandom
+ rng.setSeed(seed + partitionIndex + 1)
+ instances.map { instance =>
+ val subsampleWeights = new Array[Double](numSubsamples)
+ var subsampleIndex = 0
+ while (subsampleIndex < numSubsamples) {
+ val x = rng.nextDouble()
+ subsampleWeights(subsampleIndex) = {
+ if (x < subsamplingRate) 1.0 else 0.0
+ }
+ subsampleIndex += 1
+ }
+ new BaggedPoint(instance, subsampleWeights)
+ }
+ }
+ }
+
+ private def convertToBaggedRDDSamplingWithReplacement[Datum] (
+ input: RDD[Datum],
+ subsample: Double,
+ numSubsamples: Int,
+ seed: Int): RDD[BaggedPoint[Datum]] = {
input.mapPartitionsWithIndex { (partitionIndex, instances) =>
- // TODO: Support different sampling rates, and sampling without replacement.
// Use random seed = seed + partitionIndex + 1 to make generation reproducible.
- val poisson = new PoissonDistribution(1.0)
+ val poisson = new PoissonDistribution(subsample)
poisson.reseedRandomGenerator(seed + partitionIndex + 1)
instances.map { instance =>
val subsampleWeights = new Array[Double](numSubsamples)
@@ -73,7 +117,8 @@ private[tree] object BaggedPoint {
}
}
- def convertToBaggedRDDWithoutSampling[Datum](input: RDD[Datum]): RDD[BaggedPoint[Datum]] = {
+ private def convertToBaggedRDDWithoutSampling[Datum] (
+ input: RDD[Datum]): RDD[BaggedPoint[Datum]] = {
input.map(datum => new BaggedPoint(datum, Array(1.0)))
}
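For intuition, the two sampling paths above assign per-subsample weights as follows: without replacement, each instance is kept with probability subsamplingRate (Bernoulli weights of 0.0/1.0); with replacement, each instance's multiplicity is drawn from Poisson(subsamplingRate). Below is a stand-alone sketch of the without-replacement case, using java.util.Random instead of XORShiftRandom.

// Sketch only: one weight per subsample, 1.0 if the instance is kept, else 0.0.
def bernoulliSubsampleWeights(
    rng: java.util.Random,
    subsamplingRate: Double,
    numSubsamples: Int): Array[Double] = {
  Array.fill(numSubsamples)(if (rng.nextDouble() < subsamplingRate) 1.0 else 0.0)
}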
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
new file mode 100644
index 0000000000..d111ffe30e
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.loss
+
+import org.apache.spark.SparkContext._
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.model.WeightedEnsembleModel
+import org.apache.spark.rdd.RDD
+
+/**
+ * :: DeveloperApi ::
+ * Class for least absolute error loss calculation.
+ * The label y is predicted from the features x using the function F.
+ * For each instance:
+ * Loss: |y - F|
+ * Negative gradient: sign(y - F)
+ */
+@DeveloperApi
+object AbsoluteError extends Loss {
+
+ /**
+ * Method to calculate the gradients for the gradient boosting calculation for least
+ * absolute error calculation.
+ * @param model Model of the weak learner
+ * @param point Instance of the training dataset
+ * @return Loss gradient
+ */
+ override def gradient(
+ model: WeightedEnsembleModel,
+ point: LabeledPoint): Double = {
+ if ((point.label - model.predict(point.features)) < 0) 1.0 else -1.0
+ }
+
+ /**
+ * Method to calculate error of the base learner for the gradient boosting calculation.
+ * Note: This method is not used by the gradient boosting algorithm but is useful for debugging
+ * purposes.
+ * @param model Model of the weak learner.
+ * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+ * @return
+ */
+ override def computeError(model: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double = {
+ val sumOfAbsolutes = data.map { y =>
+ val err = model.predict(y.features) - y.label
+ math.abs(err)
+ }.sum()
+ sumOfAbsolutes / data.count()
+ }
+
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
new file mode 100644
index 0000000000..6f3d4340f0
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.loss
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.model.WeightedEnsembleModel
+import org.apache.spark.rdd.RDD
+
+/**
+ * :: DeveloperApi ::
+ * Class for log loss calculation (used for binary classification).
+ *
+ * The label y is predicted from the features x using the function F.
+ * For each instance:
+ * Loss: log(1 + exp(-2yF)), y in {-1, 1}
+ * Negative gradient: 2y / ( 1 + exp(2yF))
+ */
+@DeveloperApi
+object LogLoss extends Loss {
+
+ /**
+ * Method to calculate the loss gradients for the gradient boosting calculation for binary
+ * classification
+ * @param model Model of the weak learner
+ * @param point Instance of the training dataset
+ * @return Loss gradient
+ */
+ override def gradient(
+ model: WeightedEnsembleModel,
+ point: LabeledPoint): Double = {
+ val prediction = model.predict(point.features)
+ 1.0 / (1.0 + math.exp(-prediction)) - point.label
+ }
+
+ /**
+ * Method to calculate error of the base learner for the gradient boosting calculation.
+ * Note: This method is not used by the gradient boosting algorithm but is useful for debugging
+ * purposes.
+ * @param model Model of the weak learner.
+ * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+ * @return
+ */
+ override def computeError(model: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double = {
+ val wrongPredictions = data.filter(lp => model.predict(lp.features) != lp.label).count()
+ wrongPredictions.toDouble / data.count()
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
new file mode 100644
index 0000000000..5580866c87
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.loss
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.model.WeightedEnsembleModel
+import org.apache.spark.rdd.RDD
+
+/**
+ * :: DeveloperApi ::
+ * Trait for adding "pluggable" loss functions for the gradient boosting algorithm.
+ */
+@DeveloperApi
+trait Loss extends Serializable {
+
+ /**
+ * Method to calculate the gradients for the gradient boosting calculation.
+ * @param model Model of the weak learner.
+ * @param point Instance of the training dataset.
+ * @return Loss gradient.
+ */
+ def gradient(
+ model: WeightedEnsembleModel,
+ point: LabeledPoint): Double
+
+ /**
+ * Method to calculate error of the base learner for the gradient boosting calculation.
+ * Note: This method is not used by the gradient boosting algorithm but is useful for debugging
+ * purposes.
+ * @param model Model of the weak learner.
+ * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+ * @return
+ */
+ def computeError(model: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double
+
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Losses.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Losses.scala
new file mode 100644
index 0000000000..42c9ead988
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Losses.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.loss
+
+object Losses {
+
+ def fromString(name: String): Loss = name match {
+ case "leastSquaresError" => SquaredError
+ case "leastAbsoluteError" => AbsoluteError
+ case "logLoss" => LogLoss
+ case _ => throw new IllegalArgumentException(s"Did not recognize Loss name: $name")
+ }
+
+}
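These are the string names accepted by the convenience trainClassifier/trainRegressor overloads in GradientBoosting above; a quick sketch:

import org.apache.spark.mllib.tree.loss.Losses

// Resolves to the SquaredError, AbsoluteError, and LogLoss singletons; any other
// name throws IllegalArgumentException.
val squared  = Losses.fromString("leastSquaresError")
val absolute = Losses.fromString("leastAbsoluteError")
val logistic = Losses.fromString("logLoss")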
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
new file mode 100644
index 0000000000..4349fefef2
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.loss
+
+import org.apache.spark.SparkContext._
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.model.WeightedEnsembleModel
+import org.apache.spark.rdd.RDD
+
+/**
+ * :: DeveloperApi ::
+ * Class for least squares error loss calculation.
+ *
+ * The label y is predicted from the features x using the function F.
+ * For each instance:
+ * Loss: (y - F)**2/2
+ * Negative gradient: y - F
+ */
+@DeveloperApi
+object SquaredError extends Loss {
+
+ /**
+ * Method to calculate the gradients for the gradient boosting calculation for least
+ * squares error calculation.
+ * @param model Model of the weak learner
+ * @param point Instance of the training dataset
+ * @return Loss gradient
+ */
+ override def gradient(
+ model: WeightedEnsembleModel,
+ point: LabeledPoint): Double = {
+ model.predict(point.features) - point.label
+ }
+
+ /**
+ * Method to calculate error of the base learner for the gradient boosting calculation.
+ * Note: This method is not used by the gradient boosting algorithm but is useful for debugging
+ * purposes.
+ * @param model Model of the weak learner.
+ * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+ * @return
+ */
+ override def computeError(model: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double = {
+ data.map { y =>
+ val err = model.predict(y.features) - y.label
+ err * err
+ }.mean()
+ }
+
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala
deleted file mode 100644
index 6a22e2abe5..0000000000
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala
+++ /dev/null
@@ -1,115 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.mllib.tree.model
-
-import scala.collection.mutable
-
-import org.apache.spark.annotation.Experimental
-import org.apache.spark.mllib.linalg.Vector
-import org.apache.spark.mllib.tree.configuration.Algo._
-import org.apache.spark.rdd.RDD
-
-/**
- * :: Experimental ::
- * Random forest model for classification or regression.
- * This model stores a collection of [[DecisionTreeModel]] instances and uses them to make
- * aggregate predictions.
- * @param trees Trees which make up this forest. This cannot be empty.
- * @param algo algorithm type -- classification or regression
- */
-@Experimental
-class RandomForestModel(val trees: Array[DecisionTreeModel], val algo: Algo) extends Serializable {
-
- require(trees.size > 0, s"RandomForestModel cannot be created with empty trees collection.")
-
- /**
- * Predict values for a single data point.
- *
- * @param features array representing a single data point
- * @return Double prediction from the trained model
- */
- def predict(features: Vector): Double = {
- algo match {
- case Classification =>
- val predictionToCount = new mutable.HashMap[Int, Int]()
- trees.foreach { tree =>
- val prediction = tree.predict(features).toInt
- predictionToCount(prediction) = predictionToCount.getOrElse(prediction, 0) + 1
- }
- predictionToCount.maxBy(_._2)._1
- case Regression =>
- trees.map(_.predict(features)).sum / trees.size
- }
- }
-
- /**
- * Predict values for the given data set.
- *
- * @param features RDD representing data points to be predicted
- * @return RDD[Double] where each entry contains the corresponding prediction
- */
- def predict(features: RDD[Vector]): RDD[Double] = {
- features.map(x => predict(x))
- }
-
- /**
- * Get number of trees in forest.
- */
- def numTrees: Int = trees.size
-
- /**
- * Get total number of nodes, summed over all trees in the forest.
- */
- def totalNumNodes: Int = trees.map(tree => tree.numNodes).sum
-
- /**
- * Print a summary of the model.
- */
- override def toString: String = algo match {
- case Classification =>
- s"RandomForestModel classifier with $numTrees trees and $totalNumNodes total nodes"
- case Regression =>
- s"RandomForestModel regressor with $numTrees trees and $totalNumNodes total nodes"
- case _ => throw new IllegalArgumentException(
- s"RandomForestModel given unknown algo parameter: $algo.")
- }
-
- /**
- * Print the full model to a string.
- */
- def toDebugString: String = {
- val header = toString + "\n"
- header + trees.zipWithIndex.map { case (tree, treeIndex) =>
- s" Tree $treeIndex:\n" + tree.topNode.subtreeToString(4)
- }.fold("")(_ + _)
- }
-
-}
-
-private[tree] object RandomForestModel {
-
- def build(trees: Array[DecisionTreeModel]): RandomForestModel = {
- require(trees.size > 0, s"RandomForestModel cannot be created with empty trees collection.")
- val algo: Algo = trees(0).algo
- require(trees.forall(_.algo == algo),
- "RandomForestModel cannot combine trees which have different output types" +
- " (classification/regression).")
- new RandomForestModel(trees, algo)
- }
-
-}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/WeightedEnsembleModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/WeightedEnsembleModel.scala
new file mode 100644
index 0000000000..7b052d9163
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/WeightedEnsembleModel.scala
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.model
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy._
+import org.apache.spark.rdd.RDD
+
+import scala.collection.mutable
+
+@Experimental
+class WeightedEnsembleModel(
+ val weakHypotheses: Array[DecisionTreeModel],
+ val weakHypothesisWeights: Array[Double],
+ val algo: Algo,
+ val combiningStrategy: EnsembleCombiningStrategy) extends Serializable {
+
+ require(numWeakHypotheses > 0, s"WeightedEnsembleModel cannot be created without weakHypotheses" +
+ s". Number of weakHypotheses = $numWeakHypotheses")
+
+ /**
+ * Compute the weighted sum of tree predictions for a single data point.
+ *
+ * @param features array representing a single data point
+ * @return raw prediction (weighted sum of the weak hypotheses' predictions)
+ */
+ private def predictRaw(features: Vector): Double = {
+ val treePredictions = weakHypotheses.map(learner => learner.predict(features))
+ if (numWeakHypotheses == 1) {
+ treePredictions(0)
+ } else {
+ var prediction = treePredictions(0)
+ var index = 1
+ while (index < numWeakHypotheses) {
+ prediction += weakHypothesisWeights(index) * treePredictions(index)
+ index += 1
+ }
+ prediction
+ }
+ }
+
+ /**
+ * Predict the value for a single data point by summing the weighted tree predictions.
+ *
+ * @param features array representing a single data point
+ * @return predicted value (class label for classification, real value for regression)
+ */
+ private def predictBySumming(features: Vector): Double = {
+ algo match {
+ case Regression => predictRaw(features)
+ case Classification => {
+ // TODO: predicted labels are +1 or -1 for GBT. Need a better way to store this info.
+ if (predictRaw(features) > 0) 1.0 else 0.0
+ }
+ case _ => throw new IllegalArgumentException(
+ s"WeightedEnsembleModel given unknown algo parameter: $algo.")
+ }
+ }
+
+ /**
+ * Predict values for a single data point.
+ *
+ * @param features array representing a single data point
+ * @return Double prediction from the trained model
+ */
+ private def predictByAveraging(features: Vector): Double = {
+ algo match {
+ case Classification =>
+ val predictionToCount = new mutable.HashMap[Int, Int]()
+ weakHypotheses.foreach { learner =>
+ val prediction = learner.predict(features).toInt
+ predictionToCount(prediction) = predictionToCount.getOrElse(prediction, 0) + 1
+ }
+ predictionToCount.maxBy(_._2)._1
+ case Regression =>
+ weakHypotheses.map(_.predict(features)).sum / weakHypotheses.size
+ }
+ }
+
+
+ /**
+ * Predict the value for a single data point using the trained model.
+ *
+ * @param features array representing a single data point
+ * @return predicted value from the trained model
+ */
+ def predict(features: Vector): Double = {
+ combiningStrategy match {
+ case Sum => predictBySumming(features)
+ case Average => predictByAveraging(features)
+ case _ => throw new IllegalArgumentException(
+ s"WeightedEnsembleModel given unknown combining parameter: $combiningStrategy.")
+ }
+ }
+
+ /**
+ * Predict values for the given data set.
+ *
+ * @param features RDD representing data points to be predicted
+ * @return RDD[Double] where each entry contains the corresponding prediction
+ */
+ def predict(features: RDD[Vector]): RDD[Double] = features.map(x => predict(x))
+
+ /**
+ * Print a summary of the model.
+ */
+ override def toString: String = {
+ algo match {
+ case Classification =>
+ s"WeightedEnsembleModel classifier with $numWeakHypotheses trees\n"
+ case Regression =>
+ s"WeightedEnsembleModel regressor with $numWeakHypotheses trees\n"
+ case _ => throw new IllegalArgumentException(
+ s"WeightedEnsembleModel given unknown algo parameter: $algo.")
+ }
+ }
+
+ /**
+ * Print the full model to a string.
+ */
+ def toDebugString: String = {
+ val header = toString + "\n"
+ header + weakHypotheses.zipWithIndex.map { case (tree, treeIndex) =>
+ s" Tree $treeIndex:\n" + tree.topNode.subtreeToString(4)
+ }.fold("")(_ + _)
+ }
+
+ /**
+ * Get number of weak hypotheses (trees) in the ensemble.
+ */
+ def numWeakHypotheses: Int = weakHypotheses.size
+
+ // TODO: Remove these helper methods once the class is generalized to support any base
+ // learning algorithm.
+
+ /**
+ * Get total number of nodes, summed over all trees in the forest.
+ */
+ def totalNumNodes: Int = weakHypotheses.map(tree => tree.numNodes).sum
+
+}
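The two combining strategies above differ in how the weak hypotheses' outputs are merged: Sum applies the per-tree weights and adds them up (boosting), while Average takes a majority vote for classification or a plain mean for regression (random forests). A small illustration in plain Scala on raw prediction arrays rather than the DecisionTreeModel API; note that, unlike predictRaw above, this sketch weights every tree, including the first:

    object CombiningStrategySketch {
      // Hypothetical per-tree predictions and boosting weights for one data point.
      val treePredictions = Array(1.0, -0.4, 0.7)
      val weights = Array(1.0, 0.1, 0.1)

      // Sum strategy (boosting): weighted sum, thresholded at 0 for binary classification.
      def predictBySum(preds: Array[Double], w: Array[Double]): Double = {
        val raw = preds.zip(w).map { case (p, wi) => p * wi }.sum
        if (raw > 0) 1.0 else 0.0
      }

      // Average strategy (random forest classification): majority vote over tree predictions.
      def predictByVote(preds: Array[Double]): Double =
        preds.map(_.toInt).groupBy(identity).maxBy(_._2.size)._1.toDouble

      def main(args: Array[String]): Unit = {
        println(s"sum strategy prediction: ${predictBySum(treePredictions, weights)}")
        println(s"average (vote) prediction: ${predictByVote(Array(0.0, 1.0, 1.0))}")
      }
    }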
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 8fc5e111bb..c579cb5854 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -493,7 +493,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(rootNode1.rightNode.nonEmpty)
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val baggedInput = BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput)
+ val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false)
// Single group second level tree construction.
val nodesForGroup = Map((0, Array(rootNode1.leftNode.get, rootNode1.rightNode.get)))
@@ -786,7 +786,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
- val baggedInput = BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput)
+ val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false)
val topNode = Node.emptyNode(nodeIndex = 1)
assert(topNode.predict.predict === Double.MinValue)
@@ -829,7 +829,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
- val baggedInput = BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput)
+ val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false)
val topNode = Node.emptyNode(nodeIndex = 1)
assert(topNode.predict.predict === Double.MinValue)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala
new file mode 100644
index 0000000000..effb7b8259
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree
+
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.model.WeightedEnsembleModel
+import org.apache.spark.util.StatCounter
+
+import scala.collection.mutable
+
+object EnsembleTestHelper {
+
+ /**
+ * Aggregates all values in data, and tests whether the empirical mean and stddev are within
+ * epsilon of the expected values.
+ * @param data Every element of the data should be an i.i.d. sample from some distribution.
+ */
+ def testRandomArrays(
+ data: Array[Array[Double]],
+ numCols: Int,
+ expectedMean: Double,
+ expectedStddev: Double,
+ epsilon: Double) {
+ val values = new mutable.ArrayBuffer[Double]()
+ data.foreach { row =>
+ assert(row.size == numCols)
+ values ++= row
+ }
+ val stats = new StatCounter(values)
+ assert(math.abs(stats.mean - expectedMean) < epsilon)
+ assert(math.abs(stats.stdev - expectedStddev) < epsilon)
+ }
+
+ def validateClassifier(
+ model: WeightedEnsembleModel,
+ input: Seq[LabeledPoint],
+ requiredAccuracy: Double) {
+ val predictions = input.map(x => model.predict(x.features))
+ val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
+ prediction != expected.label
+ }
+ val accuracy = (input.length - numOffPredictions).toDouble / input.length
+ assert(accuracy >= requiredAccuracy,
+ s"validateClassifier calculated accuracy $accuracy but required $requiredAccuracy.")
+ }
+
+ def validateRegressor(
+ model: WeightedEnsembleModel,
+ input: Seq[LabeledPoint],
+ requiredMSE: Double) {
+ val predictions = input.map(x => model.predict(x.features))
+ val squaredError = predictions.zip(input).map { case (prediction, expected) =>
+ val err = prediction - expected.label
+ err * err
+ }.sum
+ val mse = squaredError / input.length
+ assert(mse <= requiredMSE, s"validateRegressor calculated MSE $mse but required $requiredMSE.")
+ }
+
+ def generateOrderedLabeledPoints(numFeatures: Int, numInstances: Int): Array[LabeledPoint] = {
+ val arr = new Array[LabeledPoint](numInstances)
+ for (i <- 0 until numInstances) {
+ val label = if (i < numInstances / 10) {
+ 0.0
+ } else if (i < numInstances / 2) {
+ 1.0
+ } else if (i < numInstances * 0.9) {
+ 0.0
+ } else {
+ 1.0
+ }
+ val features = Array.fill[Double](numFeatures)(i.toDouble)
+ arr(i) = new LabeledPoint(label, Vectors.dense(features))
+ }
+ arr
+ }
+
+}
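The generateOrderedLabeledPoints helper above lays labels out in four contiguous blocks along the instance index (0s for the first 10%, 1s up to 50%, 0s up to 90%, 1s for the rest), with every feature equal to the index, so shallow trees can still separate the classes with axis-aligned splits. A quick standalone check of that labeling rule, mirroring the helper's logic for 1000 instances:

    object OrderedLabelsSketch {
      // Mirrors the labeling rule in generateOrderedLabeledPoints.
      def label(i: Int, numInstances: Int): Double =
        if (i < numInstances / 10) 0.0
        else if (i < numInstances / 2) 1.0
        else if (i < numInstances * 0.9) 0.0
        else 1.0

      def main(args: Array[String]): Unit = {
        val n = 1000
        val labels = (0 until n).map(i => label(i, n))
        // Four blocks: 100 zeros, 400 ones, 400 zeros, 100 ones.
        println(labels.count(_ == 0.0)) // 500
        println(labels.count(_ == 1.0)) // 500
      }
    }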
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala
new file mode 100644
index 0000000000..970fff8221
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy}
+import org.apache.spark.mllib.tree.impurity.{Variance, Gini}
+import org.apache.spark.mllib.tree.loss.{SquaredError, LogLoss}
+import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel}
+
+import org.apache.spark.mllib.util.LocalSparkContext
+
+/**
+ * Test suite for [[GradientBoosting]].
+ */
+class GradientBoostingSuite extends FunSuite with LocalSparkContext {
+
+ test("Regression with continuous features: SquaredError") {
+
+ GradientBoostingSuite.testCombinations.foreach {
+ case (numEstimators, learningRate, subsamplingRate) =>
+ val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)
+ val rdd = sc.parallelize(arr)
+ val categoricalFeaturesInfo = Map.empty[Int, Int]
+
+ val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
+ val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2,
+ numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo,
+ subsamplingRate = subsamplingRate)
+
+ val dt = DecisionTree.train(remappedInput, treeStrategy)
+
+ val boostingStrategy = new BoostingStrategy(Regression, numEstimators, SquaredError,
+ subsamplingRate, learningRate, 1, categoricalFeaturesInfo, treeStrategy)
+
+ val gbt = GradientBoosting.trainRegressor(rdd, boostingStrategy)
+ assert(gbt.weakHypotheses.size === numEstimators)
+ val gbtTree = gbt.weakHypotheses(0)
+
+ EnsembleTestHelper.validateRegressor(gbt, arr, 0.02)
+
+ // Make sure trees are the same.
+ assert(gbtTree.toString == dt.toString)
+ }
+ }
+
+ test("Regression with continuous features: Absolute Error") {
+
+ GradientBoostingSuite.testCombinations.foreach {
+ case (numEstimators, learningRate, subsamplingRate) =>
+ val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)
+ val rdd = sc.parallelize(arr)
+ val categoricalFeaturesInfo = Map.empty[Int, Int]
+
+ val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
+ val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2,
+ numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo,
+ subsamplingRate = subsamplingRate)
+
+ val dt = DecisionTree.train(remappedInput, treeStrategy)
+
+ val boostingStrategy = new BoostingStrategy(Regression, numEstimators, SquaredError,
+ subsamplingRate, learningRate, 1, categoricalFeaturesInfo, treeStrategy)
+
+ val gbt = GradientBoosting.trainRegressor(rdd, boostingStrategy)
+ assert(gbt.weakHypotheses.size === numEstimators)
+ val gbtTree = gbt.weakHypotheses(0)
+
+ EnsembleTestHelper.validateRegressor(gbt, arr, 0.02)
+
+ // Make sure trees are the same.
+ assert(gbtTree.toString == dt.toString)
+ }
+ }
+
+
+ test("Binary classification with continuous features: Log Loss") {
+
+ GradientBoostingSuite.testCombinations.foreach {
+ case (numEstimators, learningRate, subsamplingRate) =>
+ val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)
+ val rdd = sc.parallelize(arr)
+ val categoricalFeaturesInfo = Map.empty[Int, Int]
+
+ val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
+ val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2,
+ numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo,
+ subsamplingRate = subsamplingRate)
+
+ val dt = DecisionTree.train(remappedInput, treeStrategy)
+
+ val boostingStrategy = new BoostingStrategy(Classification, numEstimators, LogLoss,
+ subsamplingRate, learningRate, 1, categoricalFeaturesInfo, treeStrategy)
+
+ val gbt = GradientBoosting.trainClassifier(rdd, boostingStrategy)
+ assert(gbt.weakHypotheses.size === numEstimators)
+ val gbtTree = gbt.weakHypotheses(0)
+
+ EnsembleTestHelper.validateClassifier(gbt, arr, 0.9)
+
+ // Make sure trees are the same.
+ assert(gbtTree.toString == dt.toString)
+ }
+ }
+
+}
+
+object GradientBoostingSuite {
+
+ // Combinations of (numEstimators, learningRate, subsamplingRate) to test.
+ val testCombinations =
+ Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 1.0, 0.75), (10, 0.1, 0.75))
+
+}
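One detail worth calling out in these tests: remappedInput maps the binary labels {0, 1} to {-1, +1} via (label * 2) - 1, which matches the +1/-1 label convention referenced in the predictBySumming TODO of WeightedEnsembleModel. A quick standalone check of that transformation in plain Scala (nothing from MLlib):

    object LabelRemapSketch {
      def main(args: Array[String]): Unit = {
        // Same transformation as remappedInput in the tests above: 0.0 -> -1.0, 1.0 -> +1.0.
        val remapped = Seq(0.0, 1.0).map(label => (label * 2) - 1)
        assert(remapped == Seq(-1.0, 1.0))
        println(remapped.mkString(", "))
      }
    }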
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
index d3eff59aa0..10c046e07f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
@@ -25,45 +25,20 @@ import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.Strategy
-import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata}
+import org.apache.spark.mllib.tree.impl.DecisionTreeMetadata
import org.apache.spark.mllib.tree.impurity.{Gini, Variance}
-import org.apache.spark.mllib.tree.model.{Node, RandomForestModel}
+import org.apache.spark.mllib.tree.model.Node
import org.apache.spark.mllib.util.LocalSparkContext
-import org.apache.spark.util.StatCounter
/**
* Test suite for [[RandomForest]].
*/
class RandomForestSuite extends FunSuite with LocalSparkContext {
- test("BaggedPoint RDD: without subsampling") {
- val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures = 1)
- val rdd = sc.parallelize(arr)
- val baggedRDD = BaggedPoint.convertToBaggedRDDWithoutSampling(rdd)
- baggedRDD.collect().foreach { baggedPoint =>
- assert(baggedPoint.subsampleWeights.size == 1 && baggedPoint.subsampleWeights(0) == 1)
- }
- }
-
- test("BaggedPoint RDD: with subsampling") {
- val numSubsamples = 100
- val (expectedMean, expectedStddev) = (1.0, 1.0)
-
- val seeds = Array(123, 5354, 230, 349867, 23987)
- val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures = 1)
- val rdd = sc.parallelize(arr)
- seeds.foreach { seed =>
- val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, numSubsamples, seed = seed)
- val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect()
- RandomForestSuite.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
- expectedStddev, epsilon = 0.01)
- }
- }
-
test("Binary classification with continuous features:" +
" comparing DecisionTree vs. RandomForest(numTrees = 1)") {
- val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures = 50)
+ val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)
val rdd = sc.parallelize(arr)
val categoricalFeaturesInfo = Map.empty[Int, Int]
val numTrees = 1
@@ -73,12 +48,12 @@ class RandomForestSuite extends FunSuite with LocalSparkContext {
val rf = RandomForest.trainClassifier(rdd, strategy, numTrees = numTrees,
featureSubsetStrategy = "auto", seed = 123)
- assert(rf.trees.size === 1)
- val rfTree = rf.trees(0)
+ assert(rf.weakHypotheses.size === 1)
+ val rfTree = rf.weakHypotheses(0)
val dt = DecisionTree.train(rdd, strategy)
- RandomForestSuite.validateClassifier(rf, arr, 0.9)
+ EnsembleTestHelper.validateClassifier(rf, arr, 0.9)
DecisionTreeSuite.validateClassifier(dt, arr, 0.9)
// Make sure trees are the same.
@@ -88,7 +63,7 @@ class RandomForestSuite extends FunSuite with LocalSparkContext {
test("Regression with continuous features:" +
" comparing DecisionTree vs. RandomForest(numTrees = 1)") {
- val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures = 50)
+ val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)
val rdd = sc.parallelize(arr)
val categoricalFeaturesInfo = Map.empty[Int, Int]
val numTrees = 1
@@ -99,12 +74,12 @@ class RandomForestSuite extends FunSuite with LocalSparkContext {
val rf = RandomForest.trainRegressor(rdd, strategy, numTrees = numTrees,
featureSubsetStrategy = "auto", seed = 123)
- assert(rf.trees.size === 1)
- val rfTree = rf.trees(0)
+ assert(rf.weakHypotheses.size === 1)
+ val rfTree = rf.weakHypotheses(0)
val dt = DecisionTree.train(rdd, strategy)
- RandomForestSuite.validateRegressor(rf, arr, 0.01)
+ EnsembleTestHelper.validateRegressor(rf, arr, 0.01)
DecisionTreeSuite.validateRegressor(dt, arr, 0.01)
// Make sure trees are the same.
@@ -113,7 +88,7 @@ class RandomForestSuite extends FunSuite with LocalSparkContext {
test("Binary classification with continuous features: subsampling features") {
val numFeatures = 50
- val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures)
+ val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures, 1000)
val rdd = sc.parallelize(arr)
val categoricalFeaturesInfo = Map.empty[Int, Int]
@@ -187,77 +162,9 @@ class RandomForestSuite extends FunSuite with LocalSparkContext {
numClassesForClassification = 3, categoricalFeaturesInfo = categoricalFeaturesInfo)
val model = RandomForest.trainClassifier(input, strategy, numTrees = 2,
featureSubsetStrategy = "sqrt", seed = 12345)
- RandomForestSuite.validateClassifier(model, arr, 0.0)
+ EnsembleTestHelper.validateClassifier(model, arr, 1.0)
}
}
-object RandomForestSuite {
-
- /**
- * Aggregates all values in data, and tests whether the empirical mean and stddev are within
- * epsilon of the expected values.
- * @param data Every element of the data should be an i.i.d. sample from some distribution.
- */
- def testRandomArrays(
- data: Array[Array[Double]],
- numCols: Int,
- expectedMean: Double,
- expectedStddev: Double,
- epsilon: Double) {
- val values = new mutable.ArrayBuffer[Double]()
- data.foreach { row =>
- assert(row.size == numCols)
- values ++= row
- }
- val stats = new StatCounter(values)
- assert(math.abs(stats.mean - expectedMean) < epsilon)
- assert(math.abs(stats.stdev - expectedStddev) < epsilon)
- }
-
- def validateClassifier(
- model: RandomForestModel,
- input: Seq[LabeledPoint],
- requiredAccuracy: Double) {
- val predictions = input.map(x => model.predict(x.features))
- val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
- prediction != expected.label
- }
- val accuracy = (input.length - numOffPredictions).toDouble / input.length
- assert(accuracy >= requiredAccuracy,
- s"validateClassifier calculated accuracy $accuracy but required $requiredAccuracy.")
- }
-
- def validateRegressor(
- model: RandomForestModel,
- input: Seq[LabeledPoint],
- requiredMSE: Double) {
- val predictions = input.map(x => model.predict(x.features))
- val squaredError = predictions.zip(input).map { case (prediction, expected) =>
- val err = prediction - expected.label
- err * err
- }.sum
- val mse = squaredError / input.length
- assert(mse <= requiredMSE, s"validateRegressor calculated MSE $mse but required $requiredMSE.")
- }
- def generateOrderedLabeledPoints(numFeatures: Int): Array[LabeledPoint] = {
- val numInstances = 1000
- val arr = new Array[LabeledPoint](numInstances)
- for (i <- 0 until numInstances) {
- val label = if (i < numInstances / 10) {
- 0.0
- } else if (i < numInstances / 2) {
- 1.0
- } else if (i < numInstances * 0.9) {
- 0.0
- } else {
- 1.0
- }
- val features = Array.fill[Double](numFeatures)(i.toDouble)
- arr(i) = new LabeledPoint(label, Vectors.dense(features))
- }
- arr
- }
-
-}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala
new file mode 100644
index 0000000000..c0a62e0043
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.impl
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.tree.EnsembleTestHelper
+import org.apache.spark.mllib.util.LocalSparkContext
+
+/**
+ * Test suite for [[BaggedPoint]].
+ */
+class BaggedPointSuite extends FunSuite with LocalSparkContext {
+
+ test("BaggedPoint RDD: without subsampling") {
+ val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
+ val rdd = sc.parallelize(arr)
+ val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, 1, false)
+ baggedRDD.collect().foreach { baggedPoint =>
+ assert(baggedPoint.subsampleWeights.size == 1 && baggedPoint.subsampleWeights(0) == 1)
+ }
+ }
+
+ test("BaggedPoint RDD: with subsampling with replacement (fraction = 1.0)") {
+ val numSubsamples = 100
+ val (expectedMean, expectedStddev) = (1.0, 1.0)
+
+ val seeds = Array(123, 5354, 230, 349867, 23987)
+ val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
+ val rdd = sc.parallelize(arr)
+ seeds.foreach { seed =>
+ val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, true)
+ val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect()
+ EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
+ expectedStddev, epsilon = 0.01)
+ }
+ }
+
+ test("BaggedPoint RDD: with subsampling with replacement (fraction = 0.5)") {
+ val numSubsamples = 100
+ val subsample = 0.5
+ val (expectedMean, expectedStddev) = (subsample, math.sqrt(subsample))
+
+ val seeds = Array(123, 5354, 230, 349867, 23987)
+ val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
+ val rdd = sc.parallelize(arr)
+ seeds.foreach { seed =>
+ val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, true)
+ val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect()
+ EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
+ expectedStddev, epsilon = 0.01)
+ }
+ }
+
+ test("BaggedPoint RDD: with subsampling without replacement (fraction = 1.0)") {
+ val numSubsamples = 100
+ val (expectedMean, expectedStddev) = (1.0, 0)
+
+ val seeds = Array(123, 5354, 230, 349867, 23987)
+ val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
+ val rdd = sc.parallelize(arr)
+ seeds.foreach { seed =>
+ val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, false)
+ val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect()
+ EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
+ expectedStddev, epsilon = 0.01)
+ }
+ }
+
+ test("BaggedPoint RDD: with subsampling without replacement (fraction = 0.5)") {
+ val numSubsamples = 100
+ val subsample = 0.5
+ val (expectedMean, expectedStddev) = (subsample, math.sqrt(subsample * (1 - subsample)))
+
+ val seeds = Array(123, 5354, 230, 349867, 23987)
+ val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
+ val rdd = sc.parallelize(arr)
+ seeds.foreach { seed =>
+ val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, false)
+ val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect()
+ EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
+ expectedStddev, epsilon = 0.01)
+ }
+ }
+}
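The expected statistics in these tests follow from the sampling schemes: with-replacement bagging at fraction f yields per-point counts modeled as Poisson(f), so mean f and stddev sqrt(f), which is what the expected values above imply; without-replacement bagging at fraction f yields Bernoulli(f) weights, so mean f and stddev sqrt(f * (1 - f)). A small plain-Scala simulation (Knuth's Poisson sampler, no Spark) that makes those expected values concrete for f = 0.5:

    import scala.util.Random

    object SubsampleStatsSketch {
      // Knuth's algorithm for drawing from Poisson(lambda).
      def poisson(lambda: Double, rng: Random): Int = {
        val limit = math.exp(-lambda)
        var k = 0
        var p = 1.0
        do { k += 1; p *= rng.nextDouble() } while (p > limit)
        k - 1
      }

      def meanAndStddev(xs: Seq[Double]): (Double, Double) = {
        val mean = xs.sum / xs.size
        val variance = xs.map(x => (x - mean) * (x - mean)).sum / xs.size
        (mean, math.sqrt(variance))
      }

      def main(args: Array[String]): Unit = {
        val rng = new Random(123)
        val n = 100000
        val fraction = 0.5

        // With replacement: counts ~ Poisson(0.5) -> mean ~ 0.5, stddev ~ sqrt(0.5) ~ 0.707.
        println(meanAndStddev(Seq.fill(n)(poisson(fraction, rng).toDouble)))

        // Without replacement: weights ~ Bernoulli(0.5) -> mean ~ 0.5, stddev ~ 0.5.
        println(meanAndStddev(Seq.fill(n)(if (rng.nextDouble() < fraction) 1.0 else 0.0)))
      }
    }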