aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala10
1 files changed, 5 insertions, 5 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 3b13e52a7b..74d5d7ba10 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -807,10 +807,10 @@ object DecisionTree extends Serializable with Logging {
// calculating right node aggregate for a split as a sum of right node aggregate of a
// higher split and the right bin aggregate of a bin where the split is a low split
rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex)) =
- binData(shift + (2 *(numBins - 2 - splitIndex))) +
+ binData(shift + (2 *(numBins - 1 - splitIndex))) +
rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex))
rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex) + 1) =
- binData(shift + (2* (numBins - 2 - splitIndex) + 1)) +
+ binData(shift + (2* (numBins - 1 - splitIndex) + 1)) +
rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex) + 1)
splitIndex += 1
@@ -855,13 +855,13 @@ object DecisionTree extends Serializable with Logging {
// calculating right node aggregate for a split as a sum of right node aggregate of a
// higher split and the right bin aggregate of a bin where the split is a low split
rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex)) =
- binData(shift + (3 * (numBins - 2 - splitIndex))) +
+ binData(shift + (3 * (numBins - 1 - splitIndex))) +
rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex))
rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 1) =
- binData(shift + (3 * (numBins - 2 - splitIndex) + 1)) +
+ binData(shift + (3 * (numBins - 1 - splitIndex) + 1)) +
rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 1)
rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 2) =
- binData(shift + (3 * (numBins - 2 - splitIndex) + 2)) +
+ binData(shift + (3 * (numBins - 1 - splitIndex) + 2)) +
rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 2)
splitIndex += 1