aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMahmoud Rawas <mhmoudr@gmail.com>2016-06-29 13:12:17 +0100
committerSean Owen <sowen@cloudera.com>2016-06-29 13:12:17 +0100
commit393db655c3c43155305fbba1b2f8c48a95f18d93 (patch)
tree109640f22c85ac6803be334548492687b31aa07f
parent21385d02a987bcee1198103e447c019f7a769d68 (diff)
downloadspark-393db655c3c43155305fbba1b2f8c48a95f18d93.tar.gz
spark-393db655c3c43155305fbba1b2f8c48a95f18d93.tar.bz2
spark-393db655c3c43155305fbba1b2f8c48a95f18d93.zip
[SPARK-15858][ML] Fix calculating error by tree stack overflow problem
## What changes were proposed in this pull request? Improving the evaluateEachIteration function in mllib, as it fails when trying to calculate the error by tree for a model that has more than 500 trees. ## How was this patch tested? The patch was tested on a production data set (2K rows x 2K features) by training a gradient boosted model without validation with a 1000 maxIterations setting, then producing the error by tree. The new patch was able to perform the calculation within 30 seconds, while previously it would take hours and then fail. **PS**: It would be better if this PR can be cherry-picked into release branches 1.6.1 and 2.0. Author: Mahmoud Rawas <mhmoudr@gmail.com> Author: Mahmoud Rawas <Mahmoud.Rawas@quantium.com.au> Closes #13624 from mhmoudr/SPARK-15858.master.
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala40
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala37
2 files changed, 34 insertions, 43 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
index a0faff236e..7bef899a63 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
@@ -205,31 +205,29 @@ private[spark] object GradientBoostedTrees extends Logging {
case _ => data
}
- val numIterations = trees.length
- val evaluationArray = Array.fill(numIterations)(0.0)
- val localTreeWeights = treeWeights
-
- var predictionAndError = computeInitialPredictionAndError(
- remappedData, localTreeWeights(0), trees(0), loss)
-
- evaluationArray(0) = predictionAndError.values.mean()
-
val broadcastTrees = sc.broadcast(trees)
- (1 until numIterations).foreach { nTree =>
- predictionAndError = remappedData.zip(predictionAndError).mapPartitions { iter =>
- val currentTree = broadcastTrees.value(nTree)
- val currentTreeWeight = localTreeWeights(nTree)
- iter.map { case (point, (pred, error)) =>
- val newPred = updatePrediction(point.features, pred, currentTree, currentTreeWeight)
- val newError = loss.computeError(newPred, point.label)
- (newPred, newError)
- }
+ val localTreeWeights = treeWeights
+ val treesIndices = trees.indices
+
+ val dataCount = remappedData.count()
+ val evaluation = remappedData.map { point =>
+ treesIndices.map { idx =>
+ val prediction = broadcastTrees.value(idx)
+ .rootNode
+ .predictImpl(point.features)
+ .prediction
+ prediction * localTreeWeights(idx)
}
- evaluationArray(nTree) = predictionAndError.values.mean()
+ .scanLeft(0.0)(_ + _).drop(1)
+ .map(prediction => loss.computeError(prediction, point.label))
}
+ .aggregate(treesIndices.map(_ => 0.0))(
+ (aggregated, row) => treesIndices.map(idx => aggregated(idx) + row(idx)),
+ (a, b) => treesIndices.map(idx => a(idx) + b(idx)))
+ .map(_ / dataCount)
- broadcastTrees.unpersist()
- evaluationArray
+ broadcastTrees.destroy()
+ evaluation.toArray
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
index f7d9b22b6f..657ed0a8ec 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
@@ -151,31 +151,24 @@ class GradientBoostedTreesModel @Since("1.2.0") (
case _ => data
}
- val numIterations = trees.length
- val evaluationArray = Array.fill(numIterations)(0.0)
- val localTreeWeights = treeWeights
-
- var predictionAndError = GradientBoostedTreesModel.computeInitialPredictionAndError(
- remappedData, localTreeWeights(0), trees(0), loss)
-
- evaluationArray(0) = predictionAndError.values.mean()
-
val broadcastTrees = sc.broadcast(trees)
- (1 until numIterations).foreach { nTree =>
- predictionAndError = remappedData.zip(predictionAndError).mapPartitions { iter =>
- val currentTree = broadcastTrees.value(nTree)
- val currentTreeWeight = localTreeWeights(nTree)
- iter.map { case (point, (pred, error)) =>
- val newPred = pred + currentTree.predict(point.features) * currentTreeWeight
- val newError = loss.computeError(newPred, point.label)
- (newPred, newError)
- }
- }
- evaluationArray(nTree) = predictionAndError.values.mean()
+ val localTreeWeights = treeWeights
+ val treesIndices = trees.indices
+
+ val dataCount = remappedData.count()
+ val evaluation = remappedData.map { point =>
+ treesIndices
+ .map(idx => broadcastTrees.value(idx).predict(point.features) * localTreeWeights(idx))
+ .scanLeft(0.0)(_ + _).drop(1)
+ .map(prediction => loss.computeError(prediction, point.label))
}
+ .aggregate(treesIndices.map(_ => 0.0))(
+ (aggregated, row) => treesIndices.map(idx => aggregated(idx) + row(idx)),
+ (a, b) => treesIndices.map(idx => a(idx) + b(idx)))
+ .map(_ / dataCount)
- broadcastTrees.unpersist()
- evaluationArray
+ broadcastTrees.destroy()
+ evaluation.toArray
}
override protected def formatVersion: String = GradientBoostedTreesModel.formatVersion