author     MechCoder <manojkumarsivaraj334@gmail.com>    2015-04-13 15:36:33 -0700
committer  Joseph K. Bradley <joseph@databricks.com>     2015-04-13 15:36:33 -0700
commit     2a55cb41bf7da1786be2c76b8af398da8fedb44b (patch)
tree       0266cd957b8503144f0a038f180686b4869ccc1c /mllib/src
parent     3a205bbd9e352668a020c3146391e1e4441467af (diff)
[SPARK-5972] [MLlib] Cache residuals and gradient in GBT during training and validation
The previous PR (https://github.com/apache/spark/pull/4906) made it possible to extract the learning curve by computing the error for each iteration. This continues that work, refactoring some code and extending the same caching logic to training and validation.

Author: MechCoder <manojkumarsivaraj334@gmail.com>

Closes #5330 from MechCoder/spark-5972 and squashes the following commits:

0b5d659 [MechCoder] minor
32d409d [MechCoder] evaluateEachIteration and training cache should follow different paths
d542bb0 [MechCoder] Remove unused imports and docs
58f4932 [MechCoder] Remove unpersist
70d3b4c [MechCoder] Broadcast for each tree
5869533 [MechCoder] Access broadcasted values locally and other minor changes
923dbf6 [MechCoder] [SPARK-5972] Cache residuals and gradient in GBT during training and validation
Diffstat (limited to 'mllib/src')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala      | 42
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala        | 10
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala              | 11
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala                 |  8
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala         | 10
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala  | 77
6 files changed, 105 insertions(+), 53 deletions(-)
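Before the file-by-file diff, here is the core idea of the patch in miniature: rather than re-evaluating the whole ensemble for every point at every boosting iteration, the training loop now caches a (prediction, error) pair per sample and folds in only the newest tree. A minimal self-contained Scala sketch of that recurrence (hypothetical names, not the MLlib code itself):

object IncrementalGbtCache {
  // One cached entry per training sample: the running ensemble prediction
  // and the pointwise loss at that prediction.
  final case class PredError(pred: Double, error: Double)

  // Fold one new tree into the cache: O(1) per sample per iteration,
  // instead of re-running all m trees.
  def update(
      cached: PredError,
      label: Double,
      newTreePrediction: Double,
      treeWeight: Double,
      computeError: (Double, Double) => Double): PredError = {
    val pred = cached.pred + treeWeight * newTreePrediction
    PredError(pred, computeError(pred, label))
  }
}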
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
index a9c93e181e..c02c79f094 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
@@ -157,7 +157,6 @@ object GradientBoostedTrees extends Logging {
validationInput: RDD[LabeledPoint],
boostingStrategy: BoostingStrategy,
validate: Boolean): GradientBoostedTreesModel = {
-
val timer = new TimeTracker()
timer.start("total")
timer.start("init")
@@ -192,20 +191,29 @@ object GradientBoostedTrees extends Logging {
// Initialize tree
timer.start("building tree 0")
val firstTreeModel = new DecisionTree(treeStrategy).run(data)
+ val firstTreeWeight = 1.0
baseLearners(0) = firstTreeModel
- baseLearnerWeights(0) = 1.0
- val startingModel = new GradientBoostedTreesModel(Regression, Array(firstTreeModel), Array(1.0))
- logDebug("error of gbt = " + loss.computeError(startingModel, input))
+ baseLearnerWeights(0) = firstTreeWeight
+ val startingModel = new GradientBoostedTreesModel(
+ Regression, Array(firstTreeModel), baseLearnerWeights.slice(0, 1))
+
+ var predError: RDD[(Double, Double)] = GradientBoostedTreesModel.
+ computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss)
+ logDebug("error of gbt = " + predError.values.mean())
// Note: A model of type regression is used since we require raw prediction
timer.stop("building tree 0")
- var bestValidateError = if (validate) loss.computeError(startingModel, validationInput) else 0.0
+ var validatePredError: RDD[(Double, Double)] = GradientBoostedTreesModel.
+ computeInitialPredictionAndError(validationInput, firstTreeWeight, firstTreeModel, loss)
+ var bestValidateError = if (validate) validatePredError.values.mean() else 0.0
var bestM = 1
- // psuedo-residual for second iteration
- data = input.map(point => LabeledPoint(loss.gradient(startingModel, point),
- point.features))
+ // pseudo-residual for second iteration
+ data = predError.zip(input).map { case ((pred, _), point) =>
+ LabeledPoint(-loss.gradient(pred, point.label), point.features)
+ }
+
var m = 1
while (m < numIterations) {
timer.start(s"building tree $m")
@@ -222,15 +230,22 @@ object GradientBoostedTrees extends Logging {
baseLearnerWeights(m) = learningRate
// Note: A model of type regression is used since we require raw prediction
val partialModel = new GradientBoostedTreesModel(
- Regression, baseLearners.slice(0, m + 1), baseLearnerWeights.slice(0, m + 1))
- logDebug("error of gbt = " + loss.computeError(partialModel, input))
+ Regression, baseLearners.slice(0, m + 1),
+ baseLearnerWeights.slice(0, m + 1))
+
+ predError = GradientBoostedTreesModel.updatePredictionError(
+ input, predError, baseLearnerWeights(m), baseLearners(m), loss)
+ logDebug("error of gbt = " + predError.values.mean())
if (validate) {
// Stop training early if
// 1. Reduction in error is less than the validationTol or
// 2. If the error increases, that is if the model is overfit.
// We want the model returned corresponding to the best validation error.
- val currentValidateError = loss.computeError(partialModel, validationInput)
+
+ validatePredError = GradientBoostedTreesModel.updatePredictionError(
+ validationInput, validatePredError, baseLearnerWeights(m), baseLearners(m), loss)
+ val currentValidateError = validatePredError.values.mean()
if (bestValidateError - currentValidateError < validationTol) {
return new GradientBoostedTreesModel(
boostingStrategy.treeStrategy.algo,
@@ -242,8 +257,9 @@ object GradientBoostedTrees extends Logging {
}
}
// Update data with pseudo-residuals
- data = input.map(point => LabeledPoint(-loss.gradient(partialModel, point),
- point.features))
+ data = predError.zip(input).map { case ((pred, _), point) =>
+ LabeledPoint(-loss.gradient(pred, point.label), point.features)
+ }
m += 1
}
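The refactored loop above also drives early stopping from the cached validation errors. A hedged usage sketch of the public API this feeds (it assumes an existing SparkContext and an existing data: RDD[LabeledPoint]; the parameter values are illustrative):

import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.BoostingStrategy

val boostingStrategy = BoostingStrategy.defaultParams("Regression")
boostingStrategy.numIterations = 100 // early stopping may use fewer

// Hold out a validation set; runWithValidation takes the early-stopping
// path shown in the diff above.
val Array(train, validation) = data.randomSplit(Array(0.8, 0.2))
val model = new GradientBoostedTrees(boostingStrategy)
  .runWithValidation(train, validation)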
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
index 793dd664c5..6f570b4e09 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
@@ -37,14 +37,12 @@ object AbsoluteError extends Loss {
* Method to calculate the gradients for the gradient boosting calculation for least
* absolute error calculation.
* The gradient with respect to F(x) is: sign(F(x) - y)
- * @param model Ensemble model
- * @param point Instance of the training dataset
+ * @param prediction Predicted label.
+ * @param label True label.
* @return Loss gradient
*/
- override def gradient(
- model: TreeEnsembleModel,
- point: LabeledPoint): Double = {
- if ((point.label - model.predict(point.features)) < 0) 1.0 else -1.0
+ override def gradient(prediction: Double, label: Double): Double = {
+ if (label - prediction < 0) 1.0 else -1.0
}
override def computeError(prediction: Double, label: Double): Double = {
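A quick sanity check of the sign convention under the new pointwise signature (a sketch, not a test from this patch):

import org.apache.spark.mllib.tree.loss.AbsoluteError

// d|y - F(x)|/dF = sign(F(x) - y): +1 when over-predicting, -1 when under-predicting.
assert(AbsoluteError.gradient(prediction = 2.0, label = 5.0) == -1.0) // under-prediction
assert(AbsoluteError.gradient(prediction = 5.0, label = 2.0) == 1.0)  // over-prediction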
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
index 51b1aed167..24ee9f3d51 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
@@ -39,15 +39,12 @@ object LogLoss extends Loss {
* Method to calculate the loss gradients for the gradient boosting calculation for binary
* classification
* The gradient with respect to F(x) is: - 4 y / (1 + exp(2 y F(x)))
- * @param model Ensemble model
- * @param point Instance of the training dataset
+ * @param prediction Predicted label.
+ * @param label True label.
* @return Loss gradient
*/
- override def gradient(
- model: TreeEnsembleModel,
- point: LabeledPoint): Double = {
- val prediction = model.predict(point.features)
- - 4.0 * point.label / (1.0 + math.exp(2.0 * point.label * prediction))
+ override def gradient(prediction: Double, label: Double): Double = {
+ - 4.0 * label / (1.0 + math.exp(2.0 * label * prediction))
}
override def computeError(prediction: Double, label: Double): Double = {
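As a pointwise check (a sketch; labels are assumed already remapped to {-1, +1}, as MLlib does internally): at a raw prediction F(x) = 0 the gradient is -4y / (1 + exp(0)) = -2y, so the pseudo-residual -gradient pushes the next tree toward the true class.

import org.apache.spark.mllib.tree.loss.LogLoss

assert(LogLoss.gradient(prediction = 0.0, label = 1.0) == -2.0)
assert(LogLoss.gradient(prediction = 0.0, label = -1.0) == 2.0)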
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
index 357869ff6b..d3b82b752f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
@@ -31,13 +31,11 @@ trait Loss extends Serializable {
/**
* Method to calculate the gradients for the gradient boosting calculation.
- * @param model Model of the weak learner.
- * @param point Instance of the training dataset.
+ * @param prediction Predicted label.
+ * @param label True label.
* @return Loss gradient.
*/
- def gradient(
- model: TreeEnsembleModel,
- point: LabeledPoint): Double
+ def gradient(prediction: Double, label: Double): Double
/**
* Method to calculate error of the base learner for the gradient boosting calculation.
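With the model- and RDD-free signature above, a custom loss reduces to two pure functions. A hedged sketch (HuberLoss is hypothetical and not part of MLlib; it assumes only the two abstract methods shown in this diff):

import org.apache.spark.mllib.tree.loss.Loss

object HuberLoss extends Loss {
  private val delta = 1.0 // transition point; illustrative value

  // Quadratic near zero, linear in the tails.
  override def gradient(prediction: Double, label: Double): Double = {
    val diff = prediction - label
    if (math.abs(diff) <= delta) diff else delta * math.signum(diff)
  }

  override def computeError(prediction: Double, label: Double): Double = {
    val diff = math.abs(prediction - label)
    if (diff <= delta) 0.5 * diff * diff else delta * (diff - 0.5 * delta)
  }
}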
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
index b990707ca4..58857ae15e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
@@ -37,14 +37,12 @@ object SquaredError extends Loss {
* Method to calculate the gradients for the gradient boosting calculation for least
* squares error calculation.
* The gradient with respect to F(x) is: - 2 (y - F(x))
- * @param model Ensemble model
- * @param point Instance of the training dataset
+ * @param prediction Predicted label.
+ * @param label True label.
* @return Loss gradient
*/
- override def gradient(
- model: TreeEnsembleModel,
- point: LabeledPoint): Double = {
- 2.0 * (model.predict(point.features) - point.label)
+ override def gradient(prediction: Double, label: Double): Double = {
+ 2.0 * (prediction - label)
}
override def computeError(prediction: Double, label: Double): Double = {
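A pointwise check (sketch): for squared error the pseudo-residual -gradient is twice the ordinary residual, so each new tree fits what is left of the error directly.

import org.apache.spark.mllib.tree.loss.SquaredError

val prediction = 3.0
val label = 5.0
// gradient = 2 * (prediction - label) = -4, so -gradient = 4 = 2 * (label - prediction)
assert(-SquaredError.gradient(prediction, label) == 4.0)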
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
index 1950254b2a..fef3d2acb2 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
@@ -130,35 +130,28 @@ class GradientBoostedTreesModel(
val numIterations = trees.length
val evaluationArray = Array.fill(numIterations)(0.0)
+ val localTreeWeights = treeWeights
+
+ var predictionAndError = GradientBoostedTreesModel.computeInitialPredictionAndError(
+ remappedData, localTreeWeights(0), trees(0), loss)
- var predictionAndError: RDD[(Double, Double)] = remappedData.map { i =>
- val pred = treeWeights(0) * trees(0).predict(i.features)
- val error = loss.computeError(pred, i.label)
- (pred, error)
- }
evaluationArray(0) = predictionAndError.values.mean()
- // Avoid the model being copied across numIterations.
val broadcastTrees = sc.broadcast(trees)
- val broadcastWeights = sc.broadcast(treeWeights)
-
(1 until numIterations).map { nTree =>
predictionAndError = remappedData.zip(predictionAndError).mapPartitions { iter =>
val currentTree = broadcastTrees.value(nTree)
- val currentTreeWeight = broadcastWeights.value(nTree)
- iter.map {
- case (point, (pred, error)) => {
- val newPred = pred + currentTree.predict(point.features) * currentTreeWeight
- val newError = loss.computeError(newPred, point.label)
- (newPred, newError)
- }
+ val currentTreeWeight = localTreeWeights(nTree)
+ iter.map { case (point, (pred, error)) =>
+ val newPred = pred + currentTree.predict(point.features) * currentTreeWeight
+ val newError = loss.computeError(newPred, point.label)
+ (newPred, newError)
}
}
evaluationArray(nTree) = predictionAndError.values.mean()
}
broadcastTrees.unpersist()
- broadcastWeights.unpersist()
evaluationArray
}
@@ -166,6 +159,58 @@ class GradientBoostedTreesModel(
object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] {
+ /**
+ * Compute the initial predictions and errors for a dataset for the first
+ * iteration of gradient boosting.
+ * @param data: training data.
+ * @param initTreeWeight: learning rate assigned to the first tree.
+ * @param initTree: first DecisionTreeModel.
+ * @param loss: evaluation metric.
+ * @return an RDD with each element being a zip of the prediction and error
+ * corresponding to every sample.
+ */
+ def computeInitialPredictionAndError(
+ data: RDD[LabeledPoint],
+ initTreeWeight: Double,
+ initTree: DecisionTreeModel,
+ loss: Loss): RDD[(Double, Double)] = {
+ data.map { lp =>
+ val pred = initTreeWeight * initTree.predict(lp.features)
+ val error = loss.computeError(pred, lp.label)
+ (pred, error)
+ }
+ }
+
+ /**
+ * Update a zipped predictionError RDD
+ * (as obtained with computeInitialPredictionAndError)
+ * @param data: training data.
+ * @param predictionAndError: predictionError RDD
+ * @param treeWeight: Learning rate.
+ * @param tree: The tree with which the prediction and error should be updated.
+ * @param loss: evaluation metric.
+ * @return an RDD with each element being a zip of the prediction and error
+ * corresponding to each sample.
+ */
+ def updatePredictionError(
+ data: RDD[LabeledPoint],
+ predictionAndError: RDD[(Double, Double)],
+ treeWeight: Double,
+ tree: DecisionTreeModel,
+ loss: Loss): RDD[(Double, Double)] = {
+
+ val newPredError = data.zip(predictionAndError).mapPartitions { iter =>
+ iter.map {
+ case (lp, (pred, error)) => {
+ val newPred = pred + tree.predict(lp.features) * treeWeight
+ val newError = loss.computeError(newPred, lp.label)
+ (newPred, newError)
+ }
+ }
+ }
+ newPredError
+ }
+
override def load(sc: SparkContext, path: String): GradientBoostedTreesModel = {
val (loadedClassName, version, jsonMetadata) = Loader.loadMetadata(sc, path)
val classNameV1_0 = SaveLoadV1_0.thisClassName
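Taken together, the two new helpers compose into the evaluation loop used by evaluateEachIteration and by the training path. A hedged sketch of that composition (remappedData, trees, treeWeights, and loss are assumed in scope, as in the class above):

var predError = GradientBoostedTreesModel.computeInitialPredictionAndError(
  remappedData, treeWeights(0), trees(0), loss)
var m = 1
while (m < trees.length) {
  // Each step touches only tree m; earlier predictions stay cached.
  predError = GradientBoostedTreesModel.updatePredictionError(
    remappedData, predError, treeWeights(m), trees(m), loss)
  m += 1
}
val finalError = predError.values.mean() // mean loss of the full ensemble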