aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala
blob: 61091bb803e49670208ee13cfd3a2e0da97427a9 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.tree.impl

import org.apache.spark.mllib.tree.impurity._



/**
 * DecisionTree statistics aggregator for a node.
 * This holds a flat array of statistics for a set of (features, bins)
 * and helps with indexing.
 *
 * @param metadata  Learning metadata; supplies the impurity type, the number of classes and
 *                  the number of bins of every feature.
 * @param featureSubset  Indices of the features to aggregate over when feature subsampling is
 *                       used, or None to aggregate over all features.  When defined, the
 *                       feature indices accepted by [[update]] and [[getFeatureOffset]] are
 *                       positions within this subset rather than global feature indices.
 */
private[spark] class DTStatsAggregator(
    val metadata: DecisionTreeMetadata,
    featureSubset: Option[Array[Int]]) extends Serializable {

  /**
   * [[ImpurityAggregator]] instance specifying the impurity type.
   */
  val impurityAggregator: ImpurityAggregator = metadata.impurity match {
    case Gini => new GiniAggregator(metadata.numClasses)
    case Entropy => new EntropyAggregator(metadata.numClasses)
    case Variance => new VarianceAggregator()
    case _ => throw new IllegalArgumentException(s"Bad impurity parameter: ${metadata.impurity}")
  }

  /**
   * Number of elements (Double values) used for the sufficient statistics of each bin.
   */
  private val statsSize: Int = impurityAggregator.statsSize

  /**
   * Number of bins for each feature.  This is indexed by the feature index
   * (the position within the feature subset when one is given).
   */
  private val numBins: Array[Int] = featureSubset match {
    case Some(subset) => subset.map(metadata.numBins(_))
    case None => metadata.numBins
  }

  /**
   * Offset for each feature for calculating indices into the [[allStats]] array.
   * Holds one more entry than there are features; the last entry is the total size.
   */
  private val featureOffsets: Array[Int] = {
    // NOTE(review): plain Int arithmetic; a very large (statsSize * number of bins) total
    // would wrap around silently.
    numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)
  }

  /**
   * Total number of elements stored in this aggregator
   */
  private val allStatsSize: Int = featureOffsets.last

  /**
   * Flat array of elements.
   * Index for start of stats for a (feature, bin) is:
   *   index = featureOffsets(featureIndex) + binIndex * statsSize
   */
  private val allStats: Array[Double] = new Array[Double](allStatsSize)

  /**
   * Array of parent node sufficient stats.
   *
   * Note: this is necessary because stats for the parent node are not available
   *       on the first iteration of tree learning.
   */
  private val parentStats: Array[Double] = new Array[Double](statsSize)

  /**
   * Get an [[ImpurityCalculator]] for a given (feature, bin).
   *
   * @param featureOffset  This is a pre-computed feature offset
   *                           from [[getFeatureOffset]].
   * @param binIndex  Index of the bin within the feature.
   */
  def getImpurityCalculator(featureOffset: Int, binIndex: Int): ImpurityCalculator = {
    impurityAggregator.getCalculator(allStats, featureOffset + binIndex * statsSize)
  }

  /**
   * Get an [[ImpurityCalculator]] for the parent node.
   */
  def getParentImpurityCalculator(): ImpurityCalculator = {
    impurityAggregator.getCalculator(parentStats, 0)
  }

  /**
   * Update the stats for a given (feature, bin) for ordered features, using the given label.
   *
   * @param featureIndex  Index of the feature (position within the feature subset, if any).
   * @param binIndex  Index of the bin within the feature.
   * @param label  Label of the instance being added.
   * @param instanceWeight  Weight of the instance being added.
   */
  def update(featureIndex: Int, binIndex: Int, label: Double, instanceWeight: Double): Unit = {
    val i = featureOffsets(featureIndex) + binIndex * statsSize
    impurityAggregator.update(allStats, i, label, instanceWeight)
  }

  /**
   * Update the parent node stats using the given label.
   */
  def updateParent(label: Double, instanceWeight: Double): Unit = {
    impurityAggregator.update(parentStats, 0, label, instanceWeight)
  }

  /**
   * Faster version of [[update]].
   * Update the stats for a given (feature, bin), using the given label.
   *
   * @param featureOffset  This is a pre-computed feature offset
   *                           from [[getFeatureOffset]].
   */
  def featureUpdate(
      featureOffset: Int,
      binIndex: Int,
      label: Double,
      instanceWeight: Double): Unit = {
    impurityAggregator.update(allStats, featureOffset + binIndex * statsSize,
      label, instanceWeight)
  }

  /**
   * Pre-compute feature offset for use with [[featureUpdate]].
   * For ordered features only.
   */
  def getFeatureOffset(featureIndex: Int): Int = featureOffsets(featureIndex)

  /**
   * For a given feature, merge the stats for two bins.
   *
   * @param featureOffset  This is a pre-computed feature offset
   *                           from [[getFeatureOffset]].
   * @param binIndex  The other bin is merged into this bin.
   * @param otherBinIndex  This bin is not modified.
   */
  def mergeForFeature(featureOffset: Int, binIndex: Int, otherBinIndex: Int): Unit = {
    impurityAggregator.merge(allStats, featureOffset + binIndex * statsSize,
      featureOffset + otherBinIndex * statsSize)
  }

  /**
   * Merge this aggregator with another, and returns this aggregator.
   * This method modifies this aggregator in-place.
   *
   * @param other  Aggregator whose bin stats and parent stats are added to this one.
   *               It must have the same total stats length and the same per-bin stats size.
   * @return this aggregator, after the merge
   * @throws IllegalArgumentException if the stats vectors of the two aggregators differ in
   *                                  length; in that case this aggregator is left unmodified.
   */
  def merge(other: DTStatsAggregator): DTStatsAggregator = {
    // Check every precondition before mutating anything, so that a failed check
    // can never leave this aggregator in a half-merged state.
    require(allStatsSize == other.allStatsSize,
      s"DTStatsAggregator.merge requires that both aggregators have the same length stats vectors."
        + s" This aggregator is of length $allStatsSize, but the other is ${other.allStatsSize}.")
    require(statsSize == other.statsSize,
      s"DTStatsAggregator.merge requires that both aggregators have the same length parent " +
        s"stats vectors. This aggregator's parent stats are length $statsSize, " +
        s"but the other is ${other.statsSize}.")

    // Element-wise sum of the per-(feature, bin) statistics.
    var i = 0
    // TODO: Test BLAS.axpy
    while (i < allStatsSize) {
      allStats(i) += other.allStats(i)
      i += 1
    }

    // Element-wise sum of the parent node statistics.
    var j = 0
    while (j < statsSize) {
      parentStats(j) += other.parentStats(j)
      j += 1
    }

    this
  }
}