diff options
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala | 40 | ||||
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala | 37 |
2 files changed, 34 insertions, 43 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala index a0faff236e..7bef899a63 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala @@ -205,31 +205,29 @@ private[spark] object GradientBoostedTrees extends Logging { case _ => data } - val numIterations = trees.length - val evaluationArray = Array.fill(numIterations)(0.0) - val localTreeWeights = treeWeights - - var predictionAndError = computeInitialPredictionAndError( - remappedData, localTreeWeights(0), trees(0), loss) - - evaluationArray(0) = predictionAndError.values.mean() - val broadcastTrees = sc.broadcast(trees) - (1 until numIterations).foreach { nTree => - predictionAndError = remappedData.zip(predictionAndError).mapPartitions { iter => - val currentTree = broadcastTrees.value(nTree) - val currentTreeWeight = localTreeWeights(nTree) - iter.map { case (point, (pred, error)) => - val newPred = updatePrediction(point.features, pred, currentTree, currentTreeWeight) - val newError = loss.computeError(newPred, point.label) - (newPred, newError) - } + val localTreeWeights = treeWeights + val treesIndices = trees.indices + + val dataCount = remappedData.count() + val evaluation = remappedData.map { point => + treesIndices.map { idx => + val prediction = broadcastTrees.value(idx) + .rootNode + .predictImpl(point.features) + .prediction + prediction * localTreeWeights(idx) } - evaluationArray(nTree) = predictionAndError.values.mean() + .scanLeft(0.0)(_ + _).drop(1) + .map(prediction => loss.computeError(prediction, point.label)) } + .aggregate(treesIndices.map(_ => 0.0))( + (aggregated, row) => treesIndices.map(idx => aggregated(idx) + row(idx)), + (a, b) => treesIndices.map(idx => a(idx) + b(idx))) + .map(_ / dataCount) - broadcastTrees.unpersist() - evaluationArray + broadcastTrees.destroy() + evaluation.toArray } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala index f7d9b22b6f..657ed0a8ec 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala @@ -151,31 +151,24 @@ class GradientBoostedTreesModel @Since("1.2.0") ( case _ => data } - val numIterations = trees.length - val evaluationArray = Array.fill(numIterations)(0.0) - val localTreeWeights = treeWeights - - var predictionAndError = GradientBoostedTreesModel.computeInitialPredictionAndError( - remappedData, localTreeWeights(0), trees(0), loss) - - evaluationArray(0) = predictionAndError.values.mean() - val broadcastTrees = sc.broadcast(trees) - (1 until numIterations).foreach { nTree => - predictionAndError = remappedData.zip(predictionAndError).mapPartitions { iter => - val currentTree = broadcastTrees.value(nTree) - val currentTreeWeight = localTreeWeights(nTree) - iter.map { case (point, (pred, error)) => - val newPred = pred + currentTree.predict(point.features) * currentTreeWeight - val newError = loss.computeError(newPred, point.label) - (newPred, newError) - } - } - evaluationArray(nTree) = predictionAndError.values.mean() + val localTreeWeights = treeWeights + val treesIndices = trees.indices + + val dataCount = remappedData.count() + val evaluation = remappedData.map { point => + treesIndices + .map(idx => broadcastTrees.value(idx).predict(point.features) * localTreeWeights(idx)) + .scanLeft(0.0)(_ + _).drop(1) + .map(prediction => loss.computeError(prediction, point.label)) } + .aggregate(treesIndices.map(_ => 0.0))( + (aggregated, row) => treesIndices.map(idx => aggregated(idx) + row(idx)), + (a, b) => treesIndices.map(idx => a(idx) + b(idx))) + .map(_ / dataCount) - broadcastTrees.unpersist() - evaluationArray + broadcastTrees.destroy() + evaluation.toArray } override protected def formatVersion: String = GradientBoostedTreesModel.formatVersion |