aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTestMethod.scala
blob: ff27f28459e265ca1cd3e0f46d8d53c1086ba448 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.stat.test

import java.io.Serializable

import scala.language.implicitConversions
import scala.math.pow

import com.twitter.chill.MeatLocker
import org.apache.commons.math3.stat.descriptive.StatisticalSummaryValues
import org.apache.commons.math3.stat.inference.TTest

import org.apache.spark.internal.Logging
import org.apache.spark.streaming.dstream.DStream
import org.apache.spark.util.StatCounter

/**
 * Significance testing methods for [[StreamingTest]]. New 2-sample statistical significance tests
 * should extend [[StreamingTestMethod]] and introduce a new entry in
 * [[StreamingTestMethod.TEST_NAME_TO_OBJECT]]
 */
private[stat] sealed trait StreamingTestMethod extends Serializable {

  /** Human-readable name of the test; implementations pass it into each result they emit. */
  val methodName: String

  /** Textual statement of the null hypothesis; implementations pass it into each result. */
  val nullHypothesis: String

  /** Stream of pairs of summary statistics, one [[StatCounter]] per sample. */
  protected type SummaryPairStream =
    DStream[(StatCounter, StatCounter)]

  /**
   * Perform streaming 2-sample statistical significance testing.
   *
   * @param sampleSummaries stream pairs of summary statistics for the 2 samples
   * @return stream of test results
   */
  def doTest(sampleSummaries: SummaryPairStream): DStream[StreamingTestResult]

  /**
   * Implicit adapter to convert between streaming summary statistics type and the type required by
   * the t-testing libraries.
   *
   * Apache Commons Math's t-test formulas treat `getVariance` as the bias-corrected sample
   * variance (divisor n - 1). `StatCounter.variance` is the population variance (divisor n), so
   * `StatCounter.sampleVariance` is passed instead to keep t-statistics and p-values correct.
   *
   * @param summaryStats streaming summary statistics of one sample
   * @return the same statistics in Apache Commons Math form
   */
  protected implicit def toApacheCommonsStats(
      summaryStats: StatCounter): StatisticalSummaryValues = {
    new StatisticalSummaryValues(
      summaryStats.mean,
      summaryStats.sampleVariance,
      summaryStats.count,
      summaryStats.max,
      summaryStats.min,
      summaryStats.mean * summaryStats.count
    )
  }
}

/**
 * Performs Welch's 2-sample t-test. The null hypothesis is that the two data sets have equal mean.
 * This test does not assume equal variance between the two samples and does not assume equal
 * sample size.
 *
 * @see http://en.wikipedia.org/wiki/Welch%27s_t_test
 */
private[stat] object WelchTTest extends StreamingTestMethod with Logging {

  override final val methodName = "Welch's 2-sample t-test"
  override final val nullHypothesis = "Both groups have same mean"

  // Commons Math t-test engine, held in chill's MeatLocker (presumably because TTest is not
  // java.io.Serializable and this object is captured by DStream closures — confirm).
  private final val tTester = MeatLocker(new TTest())

  /** Runs Welch's t-test on every (sample A, sample B) summary pair in the stream. */
  override def doTest(data: SummaryPairStream): DStream[StreamingTestResult] =
    data.map[StreamingTestResult]((test _).tupled)

  /**
   * Runs Welch's t-test on a single pair of sample summaries.
   *
   * @param statsA summary statistics of the first sample
   * @param statsB summary statistics of the second sample
   * @return p-value, Welch–Satterthwaite degrees of freedom and t-statistic for the pair
   */
  private def test(
      statsA: StatCounter,
      statsB: StatCounter): StreamingTestResult = {
    // Welch–Satterthwaite approximation of the degrees of freedom:
    //   df = (v1/n1 + v2/n2)^2 / ((v1/n1)^2 / (n1 - 1) + (v2/n2)^2 / (n2 - 1))
    // where v1, v2 are the sample variances. getVariance already returns the variance (s^2),
    // so it is divided by n directly and must not be squared first.
    def welchDF(sample1: StatisticalSummaryValues, sample2: StatisticalSummaryValues): Double = {
      val n1 = sample1.getN.toDouble
      val n2 = sample2.getN.toDouble
      val a = sample1.getVariance / n1
      val b = sample2.getVariance / n2

      pow(a + b, 2) / (pow(a, 2) / (n1 - 1) + pow(b, 2) / (n2 - 1))
    }

    // Convert each StatCounter once rather than re-running the implicit conversion per call.
    val summaryA: StatisticalSummaryValues = statsA
    val summaryB: StatisticalSummaryValues = statsB

    new StreamingTestResult(
      tTester.get.tTest(summaryA, summaryB),
      welchDF(summaryA, summaryB),
      tTester.get.t(summaryA, summaryB),
      methodName,
      nullHypothesis
    )
  }
}

/**
 * Performs Student's 2-sample t-test. The null hypothesis is that the two data sets have equal
 * mean. This test assumes equal variance between the two samples and does not assume equal sample
 * size. For unequal variances, Welch's t-test should be used instead.
 *
 * @see http://en.wikipedia.org/wiki/Student%27s_t-test
 */
private[stat] object StudentTTest extends StreamingTestMethod with Logging {

  override final val methodName = "Student's 2-sample t-test"
  override final val nullHypothesis = "Both groups have same mean"

  // Commons Math t-test engine, held in chill's MeatLocker (presumably because TTest is not
  // java.io.Serializable and this object is captured by DStream closures — confirm).
  private final val tTester = MeatLocker(new TTest())

  /** Applies the pooled-variance t-test to each (sample A, sample B) summary pair. */
  override def doTest(data: SummaryPairStream): DStream[StreamingTestResult] =
    data.map[StreamingTestResult] { case (first, second) => test(first, second) }

  /**
   * Pooled-variance (homoscedastic) t-test for a single pair of sample summaries.
   *
   * @param statsA summary statistics of the first sample
   * @param statsB summary statistics of the second sample
   * @return p-value, degrees of freedom (n1 + n2 - 2) and t-statistic for the pair
   */
  private def test(
      statsA: StatCounter,
      statsB: StatCounter): StreamingTestResult = {
    val left: StatisticalSummaryValues = statsA
    val right: StatisticalSummaryValues = statsB
    val pValue = tTester.get.homoscedasticTTest(left, right)
    // Pooling both samples' variance consumes one degree of freedom per sample.
    val degreesOfFreedom: Double = left.getN + right.getN - 2
    val tStatistic = tTester.get.homoscedasticT(left, right)
    new StreamingTestResult(pValue, degreesOfFreedom, tStatistic, methodName, nullHypothesis)
  }
}

/**
 * Companion object that holds the supported [[StreamingTestMethod]] names and handles conversion between
 * strings used in [[StreamingTest]] configuration and actual method implementation.
 *
 * Currently supported tests: `welch`, `student`.
 */
private[stat] object StreamingTestMethod {
  // Note: after new `StreamingTestMethod`s are implemented, please update this map.
  private final val TEST_NAME_TO_OBJECT: Map[String, StreamingTestMethod] = Map(
    "welch" -> WelchTTest,
    "student" -> StudentTTest)

  /**
   * Looks up the test implementation registered under a configuration name.
   *
   * @param method configuration name of the test (a key of `TEST_NAME_TO_OBJECT`)
   * @return the matching [[StreamingTestMethod]]
   * @throws IllegalArgumentException if `method` is not a registered name
   */
  def getTestMethodFromName(method: String): StreamingTestMethod =
    TEST_NAME_TO_OBJECT.getOrElse(method, {
      val supportedNames = TEST_NAME_TO_OBJECT.keys.mkString(", ")
      throw new IllegalArgumentException(
        s"Unrecognized method name. Supported streaming test methods: $supportedNames")
    })
}