-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala      40
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala    37
2 files changed, 34 insertions, 43 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
index a0faff236e..7bef899a63 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
@@ -205,31 +205,29 @@ private[spark] object GradientBoostedTrees extends Logging {
case _ => data
}
- val numIterations = trees.length
- val evaluationArray = Array.fill(numIterations)(0.0)
- val localTreeWeights = treeWeights
-
- var predictionAndError = computeInitialPredictionAndError(
- remappedData, localTreeWeights(0), trees(0), loss)
-
- evaluationArray(0) = predictionAndError.values.mean()
-
val broadcastTrees = sc.broadcast(trees)
- (1 until numIterations).foreach { nTree =>
- predictionAndError = remappedData.zip(predictionAndError).mapPartitions { iter =>
- val currentTree = broadcastTrees.value(nTree)
- val currentTreeWeight = localTreeWeights(nTree)
- iter.map { case (point, (pred, error)) =>
- val newPred = updatePrediction(point.features, pred, currentTree, currentTreeWeight)
- val newError = loss.computeError(newPred, point.label)
- (newPred, newError)
- }
+ val localTreeWeights = treeWeights
+ val treesIndices = trees.indices
+
+ val dataCount = remappedData.count()
+ val evaluation = remappedData.map { point =>
+ treesIndices.map { idx =>
+ val prediction = broadcastTrees.value(idx)
+ .rootNode
+ .predictImpl(point.features)
+ .prediction
+ prediction * localTreeWeights(idx)
}
- evaluationArray(nTree) = predictionAndError.values.mean()
+ .scanLeft(0.0)(_ + _).drop(1)
+ .map(prediction => loss.computeError(prediction, point.label))
}
+ .aggregate(treesIndices.map(_ => 0.0))(
+ (aggregated, row) => treesIndices.map(idx => aggregated(idx) + row(idx)),
+ (a, b) => treesIndices.map(idx => a(idx) + b(idx)))
+ .map(_ / dataCount)
- broadcastTrees.unpersist()
- evaluationArray
+ broadcastTrees.destroy()
+ evaluation.toArray
}
/**
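The rewrite above replaces the iterative zip-and-update loop with a single pass over the data: for each point it computes every tree's weighted prediction, turns them into cumulative ensemble predictions with scanLeft (dropping the leading zero), maps each cumulative prediction to its loss, and finally sums the per-point loss rows element-wise with aggregate before dividing by the row count. The following is a minimal sketch of that prefix-sum idea on plain Scala collections, assuming a hypothetical Point case class, a squared-error loss, and trees represented as simple functions; none of these names come from the patch.

// Minimal sketch of the single-pass evaluation: cumulative (prefix-sum)
// predictions per point, then the mean loss after each boosting iteration.
// Plain collections stand in for the RDD; Point, squaredError and the
// trees-as-functions are illustrative only.
object EvaluateEachIterationSketch {
  final case class Point(label: Double, features: Double)

  // Illustrative loss; the patch itself delegates to loss.computeError.
  def squaredError(prediction: Double, label: Double): Double = {
    val d = prediction - label
    d * d
  }

  def main(args: Array[String]): Unit = {
    // Each "tree" is just a function from the feature value to a raw prediction.
    val trees: Array[Double => Double] = Array(x => x, x => 0.5 * x, x => 0.25 * x)
    val treeWeights = Array(1.0, 1.0, 1.0)
    val data = Seq(Point(1.0, 1.0), Point(2.0, 1.5), Point(0.0, -0.5))

    val treesIndices = trees.indices
    val dataCount = data.size.toDouble

    // Per point: weighted prediction of each tree, prefix sums giving the
    // ensemble prediction after 1, 2, ..., n trees, then the loss of each prefix.
    val perPointErrors = data.map { point =>
      treesIndices
        .map(idx => trees(idx)(point.features) * treeWeights(idx))
        .scanLeft(0.0)(_ + _).drop(1)
        .map(pred => squaredError(pred, point.label))
    }

    // Element-wise sum across points, then the mean error after each iteration.
    val evaluation = perPointErrors
      .foldLeft(treesIndices.map(_ => 0.0)) { (acc, row) =>
        treesIndices.map(idx => acc(idx) + row(idx))
      }
      .map(_ / dataCount)

    println(evaluation.mkString("mean error per iteration: [", ", ", "]"))
  }
}
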
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
index f7d9b22b6f..657ed0a8ec 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
@@ -151,31 +151,24 @@ class GradientBoostedTreesModel @Since("1.2.0") (
case _ => data
}
- val numIterations = trees.length
- val evaluationArray = Array.fill(numIterations)(0.0)
- val localTreeWeights = treeWeights
-
- var predictionAndError = GradientBoostedTreesModel.computeInitialPredictionAndError(
- remappedData, localTreeWeights(0), trees(0), loss)
-
- evaluationArray(0) = predictionAndError.values.mean()
-
val broadcastTrees = sc.broadcast(trees)
- (1 until numIterations).foreach { nTree =>
- predictionAndError = remappedData.zip(predictionAndError).mapPartitions { iter =>
- val currentTree = broadcastTrees.value(nTree)
- val currentTreeWeight = localTreeWeights(nTree)
- iter.map { case (point, (pred, error)) =>
- val newPred = pred + currentTree.predict(point.features) * currentTreeWeight
- val newError = loss.computeError(newPred, point.label)
- (newPred, newError)
- }
- }
- evaluationArray(nTree) = predictionAndError.values.mean()
+ val localTreeWeights = treeWeights
+ val treesIndices = trees.indices
+
+ val dataCount = remappedData.count()
+ val evaluation = remappedData.map { point =>
+ treesIndices
+ .map(idx => broadcastTrees.value(idx).predict(point.features) * localTreeWeights(idx))
+ .scanLeft(0.0)(_ + _).drop(1)
+ .map(prediction => loss.computeError(prediction, point.label))
}
+ .aggregate(treesIndices.map(_ => 0.0))(
+ (aggregated, row) => treesIndices.map(idx => aggregated(idx) + row(idx)),
+ (a, b) => treesIndices.map(idx => a(idx) + b(idx)))
+ .map(_ / dataCount)
- broadcastTrees.unpersist()
- evaluationArray
+ broadcastTrees.destroy()
+ evaluation.toArray
}
override protected def formatVersion: String = GradientBoostedTreesModel.formatVersion
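
The mllib model receives the same single-pass rewrite; the only difference is that the old-API trees are scored with tree.predict(point.features) rather than rootNode.predictImpl(point.features).prediction. Below is a minimal sketch of the aggregate step in isolation, run against a local SparkContext with hard-coded per-point error rows; the object name, the sample values, and the local[2] master are illustrative assumptions, not part of the patch.

// Sketch of the aggregate step used above: element-wise sum of per-point
// error rows across partitions, then division by the row count to get the
// mean error after each boosting iteration. Names and the hard-coded rows
// are illustrative; run with a local master.
import org.apache.spark.{SparkConf, SparkContext}

object AggregatePerIterationErrors {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setAppName("gbt-eval-sketch").setMaster("local[2]"))

    // Each row holds the loss of one point after 1, 2 and 3 boosting iterations.
    val errorRows = sc.parallelize(Seq(
      IndexedSeq(4.0, 1.0, 0.25),
      IndexedSeq(9.0, 2.25, 1.0),
      IndexedSeq(1.0, 0.25, 0.0)))

    val numIterations = 3
    val iterIndices = 0 until numIterations
    val dataCount = errorRows.count()

    val meanErrors = errorRows
      .aggregate(iterIndices.map(_ => 0.0))(
        // seqOp: fold one row into the partition-local running sums
        (agg, row) => iterIndices.map(idx => agg(idx) + row(idx)),
        // combOp: merge the running sums of two partitions
        (a, b) => iterIndices.map(idx => a(idx) + b(idx)))
      .map(_ / dataCount)

    println(meanErrors.mkString("mean error per iteration: [", ", ", "]"))
    sc.stop()
  }
}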