aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
blob: a5bdc2c6d2c94ba1d6b54bb868c5098eaf63ec8b (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.tree.impurity

import org.apache.spark.annotation.{DeveloperApi, Since}

/**
 * Trait for impurity measures used when calculating information gain.
 * This trait is used for
 *  (a) setting the impurity parameter in [[org.apache.spark.mllib.tree.configuration.Strategy]]
 *  (b) calculating impurity values from sufficient statistics.
 *
 * Two overloads of `calculate` are declared: one taking per-label counts (multiclass
 * classification) and one taking count / sum / sum of squares (regression). Both are
 * abstract here, so every concrete impurity has to provide both.
 */
@Since("1.0.0")
trait Impurity extends Serializable {

  /**
   * :: DeveloperApi ::
   * Information calculation for multiclass classification.
   * @param counts Array[Double] with counts for each label
   * @param totalCount sum of counts for all labels
   * @return information value, or 0 if totalCount = 0
   */
  @Since("1.1.0")
  @DeveloperApi
  def calculate(counts: Array[Double], totalCount: Double): Double

  /**
   * :: DeveloperApi ::
   * Information calculation for regression.
   * @param count number of instances
   * @param sum sum of labels
   * @param sumSquares summation of squares of the labels
   * @return information value, or 0 if count = 0
   */
  @Since("1.0.0")
  @DeveloperApi
  def calculate(count: Double, sum: Double, sumSquares: Double): Double
}

/**
 * Base class for objects that update slices ("views") of a flat array of sufficient
 * statistics so that impurity can later be computed from a sample.
 * Note: Instances never own the statistics; every method is handed the backing array together
 * with the offset of the bin it should work on.
 * @param statsSize  Number of consecutive array entries that make up the stats of one bin.
 */
private[spark] abstract class ImpurityAggregator(val statsSize: Int) extends Serializable {

  /**
   * Add the stats of one bin onto the stats of another bin, element by element.
   * @param allStats  Flat stats array, with stats for this (node, feature, bin) contiguous.
   * @param offset    Start index of the (node, feature, bin) stats that receive the sum.
   * @param otherOffset  Start index of the (node, feature, other bin) stats that are only read.
   */
  def merge(allStats: Array[Double], offset: Int, otherOffset: Int): Unit = {
    // Walk both slices front to back with explicit cursors; `remaining` counts the
    // entries of the bin that still have to be folded in.
    var writePos = offset
    var readPos = otherOffset
    var remaining = statsSize
    while (remaining > 0) {
      allStats(writePos) += allStats(readPos)
      writePos += 1
      readPos += 1
      remaining -= 1
    }
  }

  /**
   * Update stats for one (node, feature, bin) with the given label.
   * @param allStats  Flat stats array, with stats for this (node, feature, bin) contiguous.
   * @param offset    Start index of stats for this (node, feature, bin).
   * @param label     Label of the data point being recorded.
   * @param instanceWeight  Weight supplied with the data point; how it enters the stats is
   *                        defined by each subclass.
   */
  def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit

  /**
   * Get an [[ImpurityCalculator]] for a (node, feature, bin).
   * @param allStats  Flat stats array, with stats for this (node, feature, bin) contiguous.
   * @param offset    Start index of stats for this (node, feature, bin).
   */
  def getCalculator(allStats: Array[Double], offset: Int): ImpurityCalculator
}

/**
 * Stores statistics for one (node, feature, bin) for calculating impurity.
 * Unlike [[ImpurityAggregator]], this class stores its own data and is for a specific
 * (node, feature, bin).
 * @param stats  Array of sufficient statistics for a (node, feature, bin).
 */
private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) extends Serializable {

  /**
   * Make a deep copy of this [[ImpurityCalculator]].
   */
  def copy: ImpurityCalculator

  /**
   * Calculate the impurity from the stored sufficient statistics.
   */
  def calculate(): Double

  /**
   * Add the stats from another calculator into this one, modifying and returning this calculator.
   * @param other  Calculator whose stats are added element-wise; it is only read.
   * @return this calculator, after the in-place update
   * @throws IllegalArgumentException if the two stats arrays differ in length
   */
  def add(other: ImpurityCalculator): ImpurityCalculator = {
    require(stats.length == other.stats.length,
      s"Two ImpurityCalculator instances cannot be added with different counts sizes." +
        s"  Sizes are ${stats.length} and ${other.stats.length}.")
    var i = 0
    val len = other.stats.length
    while (i < len) {
      stats(i) += other.stats(i)
      i += 1
    }
    this
  }

  /**
   * Subtract the stats from another calculator from this one, modifying and returning this
   * calculator.
   * @param other  Calculator whose stats are subtracted element-wise; it is only read.
   * @return this calculator, after the in-place update
   * @throws IllegalArgumentException if the two stats arrays differ in length
   */
  def subtract(other: ImpurityCalculator): ImpurityCalculator = {
    require(stats.length == other.stats.length,
      s"Two ImpurityCalculator instances cannot be subtracted with different counts sizes." +
      s"  Sizes are ${stats.length} and ${other.stats.length}.")
    var i = 0
    val len = other.stats.length
    while (i < len) {
      stats(i) -= other.stats(i)
      i += 1
    }
    this
  }

  /**
   * Number of data points accounted for in the sufficient statistics.
   */
  def count: Long

  /**
   * Prediction which should be made based on the sufficient statistics.
   */
  def predict: Double

  /**
   * Probability of the given label, or -1 if no probability is available.
   * This default implementation always returns -1.
   * @param label  Label whose probability is requested.
   */
  def prob(label: Double): Double = -1

  /**
   * Return the index of the largest array element; on ties the first occurrence wins.
   * NaN entries never win a comparison. If no element compares greater than negative
   * infinity (every entry is NaN or negative infinity), index 0 is returned.
   * Fails with a RuntimeException if the array is empty.
   * @param array  Values to scan.
   * @return index of the first maximal element
   */
  protected def indexOfLargestArrayElement(array: Array[Double]): Int = {
    if (array.isEmpty) {
      throw new RuntimeException("ImpurityCalculator internal error:" +
        " indexOfLargestArrayElement failed")
    }
    // Primitive loop: no per-element tuple allocation and no reliance on argument
    // auto-tupling. Seeding with negative infinity (rather than Double.MinValue) lets
    // every finite value, including Double.MinValue itself, be selected.
    var maxIndex = -1
    var maxValue = Double.NegativeInfinity
    var i = 0
    while (i < array.length) {
      val value = array(i)
      if (value > maxValue) {
        maxIndex = i
        maxValue = value
      }
      i += 1
    }
    // Nothing beat the seed: the array is non-empty but holds only NaN / -Infinity.
    if (maxIndex < 0) 0 else maxIndex
  }

}

private[spark] object ImpurityCalculator {

  /**
   * Create an [[ImpurityCalculator]] instance of the given impurity type and with
   * the given stats.
   *
   * The impurity name is matched case-insensitively: it is lower-cased with the
   * locale-neutral `Locale.ROOT` before matching, so "Gini" and "gini" select the same
   * calculator regardless of the JVM default locale.
   *
   * @param impurity  Name of the impurity type: "gini", "entropy" or "variance".
   * @param stats     Sufficient statistics passed to the new calculator's constructor.
   * @return a new calculator of the requested impurity type
   * @throws IllegalArgumentException if the name is null or not one of the known types
   */
  def getCalculator(impurity: String, stats: Array[Double]): ImpurityCalculator = {
    // Option(...) turns a null name into None so that it is reported through the same
    // IllegalArgumentException as any other unrecognized name.
    Option(impurity).map(_.toLowerCase(java.util.Locale.ROOT)) match {
      case Some("gini") => new GiniCalculator(stats)
      case Some("entropy") => new EntropyCalculator(stats)
      case Some("variance") => new VarianceCalculator(stats)
      case _ =>
        throw new IllegalArgumentException(
          s"ImpurityCalculator builder did not recognize impurity type: $impurity")
    }
  }
}