Diffstat:
-rw-r--r--  docs/mllib-ensembles.md                                                           |  4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala         | 17
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala               | 20
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala                  | 14
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala          | 17
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala   | 54
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala  | 16
7 files changed, 96 insertions(+), 46 deletions(-)
diff --git a/docs/mllib-ensembles.md b/docs/mllib-ensembles.md
index cbfb682609..7521fb14a7 100644
--- a/docs/mllib-ensembles.md
+++ b/docs/mllib-ensembles.md
@@ -464,8 +464,8 @@ first one being the training dataset and the second being the validation dataset
The training is stopped when the improvement in the validation error is not more than a certain tolerance
(supplied by the `validationTol` argument in `BoostingStrategy`). In practice, the validation error
decreases initially and later increases. There might be cases in which the validation error does not change monotonically,
-and the user is advised to set a large enough negative tolerance and examine the validation curve to to tune the number of
-iterations.
+and the user is advised to set a large enough negative tolerance and examine the validation curve using `evaluateEachIteration`
+(which gives the error or loss per iteration) to tune the number of iterations.
### Examples
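The `evaluateEachIteration` method referenced in the new doc text is added to `GradientBoostedTreesModel` later in this patch. A minimal sketch of the tuning workflow it enables; the dataset names and parameter values here are illustrative, not from the patch:

```scala
import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.mllib.tree.loss.SquaredError

// Assumes trainingData and validationData are RDD[LabeledPoint]s already in scope.
val boostingStrategy = BoostingStrategy.defaultParams("Regression")
boostingStrategy.numIterations = 100

val model = GradientBoostedTrees.train(trainingData, boostingStrategy)

// One loss value per boosting iteration, computed on the held-out data.
val errors = model.evaluateEachIteration(validationData, SquaredError)

// Choose the iteration count that minimizes validation loss.
val bestNumTrees = errors.indexOf(errors.min) + 1
```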
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
index d1bde15e6b..793dd664c5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
@@ -47,18 +47,9 @@ object AbsoluteError extends Loss {
if ((point.label - model.predict(point.features)) < 0) 1.0 else -1.0
}
- /**
- * Method to calculate loss of the base learner for the gradient boosting calculation.
- * Note: This method is not used by the gradient boosting algorithm but is useful for debugging
- * purposes.
- * @param model Ensemble model
- * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
- * @return Mean absolute error of model on data
- */
- override def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = {
- data.map { y =>
- val err = model.predict(y.features) - y.label
- math.abs(err)
- }.mean()
+ override def computeError(prediction: Double, label: Double): Double = {
+ val err = label - prediction
+ math.abs(err)
}
+
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
index 55213e6956..51b1aed167 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
@@ -50,20 +50,10 @@ object LogLoss extends Loss {
- 4.0 * point.label / (1.0 + math.exp(2.0 * point.label * prediction))
}
- /**
- * Method to calculate loss of the base learner for the gradient boosting calculation.
- * Note: This method is not used by the gradient boosting algorithm but is useful for debugging
- * purposes.
- * @param model Ensemble model
- * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
- * @return Mean log loss of model on data
- */
- override def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = {
- data.map { case point =>
- val prediction = model.predict(point.features)
- val margin = 2.0 * point.label * prediction
- // The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable.
- 2.0 * MLUtils.log1pExp(-margin)
- }.mean()
+ override def computeError(prediction: Double, label: Double): Double = {
+ val margin = 2.0 * label * prediction
+ // The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable.
+ 2.0 * MLUtils.log1pExp(-margin)
}
+
}
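A note on the `MLUtils.log1pExp` call the new overload keeps: it is what keeps the loss finite for badly mis-classified points. A self-contained illustration; the local `log1pExp` below re-implements the standard trick so the snippet compiles on its own (assuming `MLUtils.log1pExp` follows the same formula):

```scala
// log(1 + exp(x)) computed without overflow.
def log1pExp(x: Double): Double =
  if (x > 0) x + math.log1p(math.exp(-x)) else math.log1p(math.exp(x))

val margin = -800.0  // label and prediction disagree strongly
2.0 * math.log(1.0 + math.exp(-margin))  // Infinity: math.exp(800) overflows a Double
2.0 * log1pExp(-margin)                  // 1600.0: finite, as intended
```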
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
index e1169d9f66..357869ff6b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
@@ -47,6 +47,18 @@ trait Loss extends Serializable {
* @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* @return Measure of model error on data
*/
- def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double
+ def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = {
+ data.map(point => computeError(model.predict(point.features), point.label)).mean()
+ }
+
+ /**
+ * Method to calculate loss when the predictions are already known.
+ * Note: This method is used in the method evaluateEachIteration to avoid recomputing the
+ * predicted values from previously fit trees.
+ * @param prediction Predicted label.
+ * @param label True label.
+ * @return Measure of model error on datapoint.
+ */
+ def computeError(prediction: Double, label: Double): Double
}
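With the RDD-level `computeError` now a concrete template method, a new loss only has to supply `gradient` and the pointwise error. A hypothetical Huber-style loss (not part of this patch) shows the reduced surface:

```scala
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.loss.Loss
import org.apache.spark.mllib.tree.model.TreeEnsembleModel

object HuberLikeLoss extends Loss {
  private val delta = 1.0

  override def gradient(model: TreeEnsembleModel, point: LabeledPoint): Double = {
    val err = model.predict(point.features) - point.label
    if (math.abs(err) <= delta) err else delta * math.signum(err)
  }

  // The RDD version of computeError is inherited from the trait's new default.
  override def computeError(prediction: Double, label: Double): Double = {
    val err = math.abs(label - prediction)
    if (err <= delta) 0.5 * err * err else delta * (err - 0.5 * delta)
  }
}
```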
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
index 50ecaa2f86..b990707ca4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
@@ -47,18 +47,9 @@ object SquaredError extends Loss {
2.0 * (model.predict(point.features) - point.label)
}
- /**
- * Method to calculate loss of the base learner for the gradient boosting calculation.
- * Note: This method is not used by the gradient boosting algorithm but is useful for debugging
- * purposes.
- * @param model Ensemble model
- * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
- * @return Mean squared error of model on data
- */
- override def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = {
- data.map { y =>
- val err = model.predict(y.features) - y.label
- err * err
- }.mean()
+ override def computeError(prediction: Double, label: Double): Double = {
+ val err = prediction - label
+ err * err
}
+
}
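Now that no model or RDD is required, the refactored pointwise overloads are easy to sanity-check in isolation; values chosen for illustration only:

```scala
AbsoluteError.computeError(2.5, 3.0)  // |3.0 - 2.5| = 0.5
SquaredError.computeError(2.5, 3.0)   // (2.5 - 3.0)^2 = 0.25
LogLoss.computeError(0.0, 1.0)        // margin = 0, so 2 * log(2) ≈ 1.386
```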
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
index f160852c69..1950254b2a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
@@ -28,9 +28,11 @@ import org.apache.spark.{Logging, SparkContext}
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Algo
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy._
+import org.apache.spark.mllib.tree.loss.Loss
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
@@ -108,6 +110,58 @@ class GradientBoostedTreesModel(
}
override protected def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion
+
+ /**
+ * Method to compute error or loss for every iteration of gradient boosting.
+ * @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
+ * @param loss evaluation metric.
+ * @return an array with index i having the losses or errors for the ensemble
+ * containing the first i+1 trees
+ */
+ def evaluateEachIteration(
+ data: RDD[LabeledPoint],
+ loss: Loss): Array[Double] = {
+
+ val sc = data.sparkContext
+ val remappedData = algo match {
+ case Classification => data.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
+ case _ => data
+ }
+
+ val numIterations = trees.length
+ val evaluationArray = Array.fill(numIterations)(0.0)
+
+ var predictionAndError: RDD[(Double, Double)] = remappedData.map { i =>
+ val pred = treeWeights(0) * trees(0).predict(i.features)
+ val error = loss.computeError(pred, i.label)
+ (pred, error)
+ }
+ evaluationArray(0) = predictionAndError.values.mean()
+
+ // Avoid the model being copied across numIterations.
+ val broadcastTrees = sc.broadcast(trees)
+ val broadcastWeights = sc.broadcast(treeWeights)
+
+ (1 until numIterations).map { nTree =>
+ predictionAndError = remappedData.zip(predictionAndError).mapPartitions { iter =>
+ val currentTree = broadcastTrees.value(nTree)
+ val currentTreeWeight = broadcastWeights.value(nTree)
+ iter.map {
+ case (point, (pred, error)) => {
+ val newPred = pred + currentTree.predict(point.features) * currentTreeWeight
+ val newError = loss.computeError(newPred, point.label)
+ (newPred, newError)
+ }
+ }
+ }
+ evaluationArray(nTree) = predictionAndError.values.mean()
+ }
+
+ broadcastTrees.unpersist()
+ broadcastWeights.unpersist()
+ evaluationArray
+ }
+
}
object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] {
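The loop in `evaluateEachIteration` never re-evaluates the full ensemble: it carries the running prediction for each point and adds one tree per step. Written out (notation ours, not from the patch), with trees h_k, weights w_k, and loss L:

```latex
\hat{y}^{(1)}(x) = w_1\, h_1(x), \qquad
\hat{y}^{(k)}(x) = \hat{y}^{(k-1)}(x) + w_k\, h_k(x), \qquad
\mathrm{err}_k = \frac{1}{N} \sum_{i=1}^{N} L\bigl(\hat{y}^{(k)}(x_i),\, y_i\bigr)
```

Evaluating all M iterations therefore costs one tree evaluation per point per step, O(N * M) in total, rather than the O(N * M^2) that calling `predict` on the growing ensemble at every iteration would incur.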
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
index b437aeaaf0..55b0bac7d4 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
@@ -175,10 +175,11 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0)
val gbtValidate = new GradientBoostedTrees(boostingStrategy)
.runWithValidation(trainRdd, validateRdd)
- assert(gbtValidate.numTrees !== numIterations)
+ val numTrees = gbtValidate.numTrees
+ assert(numTrees !== numIterations)
// Test that it performs better on the validation dataset.
- val gbt = GradientBoostedTrees.train(trainRdd, boostingStrategy)
+ val gbt = new GradientBoostedTrees(boostingStrategy).run(trainRdd)
val (errorWithoutValidation, errorWithValidation) = {
if (algo == Classification) {
val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features))
@@ -188,6 +189,17 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
}
}
assert(errorWithValidation <= errorWithoutValidation)
+
+ // Test that results from evaluateEachIteration comply with runWithValidation.
+ // Note that validationTol is set to 0.0, so training stops as soon as the validation error stops improving.
+ val evaluationArray = gbt.evaluateEachIteration(validateRdd, loss)
+ assert(evaluationArray.length === numIterations)
+ assert(evaluationArray(numTrees) > evaluationArray(numTrees - 1))
+ var i = 1
+ while (i < numTrees) {
+ assert(evaluationArray(i) <= evaluationArray(i - 1))
+ i += 1
+ }
}
}
}