From 393db655c3c43155305fbba1b2f8c48a95f18d93 Mon Sep 17 00:00:00 2001 From: Mahmoud Rawas Date: Wed, 29 Jun 2016 13:12:17 +0100 Subject: [SPARK-15858][ML] Fix calculating error by tree stack over flow prob… MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? What changes were proposed in this pull request? Improving evaluateEachIteration function in mllib as it fails when trying to calculate error by tree for a model that has more than 500 trees ## How was this patch tested? the batch tested on productions data set (2K rows x 2K features) training a gradient boosted model without validation with 1000 maxIteration settings, then trying to produce the error by tree, the new patch was able to perform the calculation within 30 seconds, while previously it was take hours then fail. **PS**: It would be better if this PR can be cherry picked into release branches 1.6.1 and 2.0 Author: Mahmoud Rawas Author: Mahmoud Rawas Closes #13624 from mhmoudr/SPARK-15858.master. --- .../spark/ml/tree/impl/GradientBoostedTrees.scala | 40 ++++++++++------------ .../mllib/tree/model/treeEnsembleModels.scala | 37 ++++++++------------ 2 files changed, 34 insertions(+), 43 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala index a0faff236e..7bef899a63 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala @@ -205,31 +205,29 @@ private[spark] object GradientBoostedTrees extends Logging { case _ => data } - val numIterations = trees.length - val evaluationArray = Array.fill(numIterations)(0.0) - val localTreeWeights = treeWeights - - var predictionAndError = computeInitialPredictionAndError( - remappedData, localTreeWeights(0), trees(0), loss) - - evaluationArray(0) = predictionAndError.values.mean() - val broadcastTrees = sc.broadcast(trees) - (1 until numIterations).foreach { nTree => - predictionAndError = remappedData.zip(predictionAndError).mapPartitions { iter => - val currentTree = broadcastTrees.value(nTree) - val currentTreeWeight = localTreeWeights(nTree) - iter.map { case (point, (pred, error)) => - val newPred = updatePrediction(point.features, pred, currentTree, currentTreeWeight) - val newError = loss.computeError(newPred, point.label) - (newPred, newError) - } + val localTreeWeights = treeWeights + val treesIndices = trees.indices + + val dataCount = remappedData.count() + val evaluation = remappedData.map { point => + treesIndices.map { idx => + val prediction = broadcastTrees.value(idx) + .rootNode + .predictImpl(point.features) + .prediction + prediction * localTreeWeights(idx) } - evaluationArray(nTree) = predictionAndError.values.mean() + .scanLeft(0.0)(_ + _).drop(1) + .map(prediction => loss.computeError(prediction, point.label)) } + .aggregate(treesIndices.map(_ => 0.0))( + (aggregated, row) => treesIndices.map(idx => aggregated(idx) + row(idx)), + (a, b) => treesIndices.map(idx => a(idx) + b(idx))) + .map(_ / dataCount) - broadcastTrees.unpersist() - evaluationArray + broadcastTrees.destroy() + evaluation.toArray } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala index f7d9b22b6f..657ed0a8ec 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala @@ -151,31 +151,24 @@ class GradientBoostedTreesModel @Since("1.2.0") ( case _ => data } - val numIterations = trees.length - val evaluationArray = Array.fill(numIterations)(0.0) - val localTreeWeights = treeWeights - - var predictionAndError = GradientBoostedTreesModel.computeInitialPredictionAndError( - remappedData, localTreeWeights(0), trees(0), loss) - - evaluationArray(0) = predictionAndError.values.mean() - val broadcastTrees = sc.broadcast(trees) - (1 until numIterations).foreach { nTree => - predictionAndError = remappedData.zip(predictionAndError).mapPartitions { iter => - val currentTree = broadcastTrees.value(nTree) - val currentTreeWeight = localTreeWeights(nTree) - iter.map { case (point, (pred, error)) => - val newPred = pred + currentTree.predict(point.features) * currentTreeWeight - val newError = loss.computeError(newPred, point.label) - (newPred, newError) - } - } - evaluationArray(nTree) = predictionAndError.values.mean() + val localTreeWeights = treeWeights + val treesIndices = trees.indices + + val dataCount = remappedData.count() + val evaluation = remappedData.map { point => + treesIndices + .map(idx => broadcastTrees.value(idx).predict(point.features) * localTreeWeights(idx)) + .scanLeft(0.0)(_ + _).drop(1) + .map(prediction => loss.computeError(prediction, point.label)) } + .aggregate(treesIndices.map(_ => 0.0))( + (aggregated, row) => treesIndices.map(idx => aggregated(idx) + row(idx)), + (a, b) => treesIndices.map(idx => a(idx) + b(idx))) + .map(_ / dataCount) - broadcastTrees.unpersist() - evaluationArray + broadcastTrees.destroy() + evaluation.toArray } override protected def formatVersion: String = GradientBoostedTreesModel.formatVersion -- cgit v1.2.3