aboutsummaryrefslogtreecommitdiff
path: root/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
blob: 786b97ad7b9ece4e5c2d1c646286db415ca26b8a (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.util.random

import java.util.Random

import scala.reflect.ClassTag
import scala.collection.mutable.ArrayBuffer

import org.apache.commons.math3.distribution.PoissonDistribution

import org.apache.spark.annotation.DeveloperApi

/**
 * :: DeveloperApi ::
 * Contract for pseudorandom samplers that consume items of type `T` and emit sampled items of
 * type `U`. The sampled item type may differ from the input type; for example, a sampler might
 * attach weights for stratified sampling or importance sampling. Implementations should only
 * apply transformations that are tied to the sampler and cannot be applied after sampling.
 *
 * @tparam T item type
 * @tparam U sampled item type
 */
@DeveloperApi
trait RandomSampler[T, U] extends Pseudorandom with Cloneable with Serializable {

  /**
   * Takes a random sample of the given items.
   *
   * @param items the items to sample from
   * @return an iterator over the sampled items
   */
  def sample(items: Iterator[T]): Iterator[U]

  /**
   * Returns a copy of this sampler object.
   *
   * This default implementation does not produce a copy: it always throws a
   * `NotImplementedError`. Samplers that support copying override it.
   */
  override def clone: RandomSampler[T, U] = {
    throw new NotImplementedError("clone() is not implemented.")
  }
}

private[spark]
object RandomSampler {
  /** Creates a new instance of the default random number generator used by random samplers. */
  def newDefaultRNG: Random = new XORShiftRandom

  /**
   * Default maximum gap-sampling fraction.
   *
   * The gap sampling optimization is applied for sampling fractions <= this value. For larger
   * fractions, "traditional" Bernoulli sampling is assumed to be faster. The optimal threshold
   * depends on the RNG: more expensive RNGs tend to push the optimal value higher. The most
   * reliable way to determine the value for a new RNG is to experiment; a value of 0.5 is
   * expected to be close in most cases and can serve as an initial guess when tuning.
   */
  val defaultMaxGapSamplingFraction: Double = 0.4

  /**
   * Default epsilon for floating point numbers sampled from the RNG.
   *
   * The gap-sampling logic computes log(x), where x is sampled from an RNG. A positive epsilon
   * is applied as a lower bound on x to guard against errors from taking log(0). A good value
   * is at or near the minimum positive floating point value returned by "nextDouble()" (or its
   * equivalent) for the RNG being used.
   */
  val rngEpsilon: Double = 5e-11

  /**
   * Slop factor used when validating sampling fraction arguments.
   *
   * Sampling fractions may be the result of a computation and therefore subject to floating
   * point jitter. Arguments are checked with this tolerance to prevent spurious failures in
   * cases such as summing some numbers to obtain a sampling fraction of 1.000000001.
   */
  val roundingEpsilon: Double = 1e-6
}

/**
 * :: DeveloperApi ::
 * A sampler based on Bernoulli trials for partitioning a data sequence.
 *
 * For every item one value `x` is drawn from the sampler's random number generator via
 * `nextDouble()`. The item is kept when `x` lies in the acceptance range `[lb, ub)`, or, when
 * `complement` is set, when `x` lies outside of that range.
 *
 * The constructor throws an `IllegalArgumentException` (via `require`) unless
 * `0.0 <= lb <= ub <= 1.0` holds, up to a tolerance of `RandomSampler.roundingEpsilon`.
 *
 * @param lb lower bound of the acceptance range
 * @param ub upper bound of the acceptance range
 * @param complement whether to use the complement of the range specified, default to false
 * @tparam T item type
 */
@DeveloperApi
class BernoulliCellSampler[T](lb: Double, ub: Double, complement: Boolean = false)
  extends RandomSampler[T, T] {

  // The bounds are validated with an epsilon slop to avoid failure from floating point jitter.
  require(
    lb <= (ub + RandomSampler.roundingEpsilon),
    s"Lower bound ($lb) must be <= upper bound ($ub)")
  require(
    lb >= (0.0 - RandomSampler.roundingEpsilon),
    s"Lower bound ($lb) must be >= 0.0")
  require(
    ub <= (1.0 + RandomSampler.roundingEpsilon),
    s"Upper bound ($ub) must be <= 1.0")

  // Obtain the generator from the shared factory, like the other samplers in this file, so
  // that the default RNG is chosen in exactly one place.
  private val rng: Random = RandomSampler.newDefaultRNG

  /** Re-seeds the underlying random number generator. */
  override def setSeed(seed: Long): Unit = rng.setSeed(seed)

  /**
   * Filters `items`, keeping those whose random draw falls inside the acceptance range (or
   * outside of it when `complement` is set).
   *
   * @param items the items to sample from
   * @return an iterator over the accepted items
   */
  override def sample(items: Iterator[T]): Iterator[T] = {
    if (ub - lb <= 0.0) {
      // Empty acceptance range: nothing is accepted, so the complement accepts everything.
      // No random numbers are consumed on this path.
      if (complement) items else Iterator.empty
    } else if (complement) {
      items.filter { _ =>
        val x = rng.nextDouble()
        (x < lb) || (x >= ub)
      }
    } else {
      items.filter { _ =>
        val x = rng.nextDouble()
        (x >= lb) && (x < ub)
      }
    }
  }

  /**
   *  Return a sampler that is the complement of the range specified of the current sampler.
   *  The returned sampler is a new instance with its own random number generator.
   */
  def cloneComplement(): BernoulliCellSampler[T] =
    new BernoulliCellSampler[T](lb, ub, !complement)

  /** Returns a new sampler constructed with the same bounds and complement flag. */
  override def clone: BernoulliCellSampler[T] = new BernoulliCellSampler[T](lb, ub, complement)
}


/**
 * :: DeveloperApi ::
 * A sampler based on Bernoulli trials.
 *
 * The constructor throws an `IllegalArgumentException` (via `require`) when `fraction` lies
 * outside of [0, 1] by more than `RandomSampler.roundingEpsilon`.
 *
 * @param fraction the sampling fraction, aka Bernoulli sampling probability
 * @tparam T item type
 */
@DeveloperApi
class BernoulliSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T] {

  // The range check allows an epsilon slop to avoid failure from floating point jitter.
  require(
    (0.0 - RandomSampler.roundingEpsilon) <= fraction
      && fraction <= (1.0 + RandomSampler.roundingEpsilon),
    s"Sampling fraction ($fraction) must be on interval [0, 1]")

  private val rng: Random = RandomSampler.newDefaultRNG

  /** Re-seeds the underlying random number generator. */
  override def setSeed(seed: Long): Unit = rng.setSeed(seed)

  /**
   * Samples `items`, choosing the strategy from the sampling fraction:
   * nothing is returned for fractions <= 0, everything for fractions >= 1, gap sampling is
   * used for fractions up to `RandomSampler.defaultMaxGapSamplingFraction`, and one random
   * draw per item is used otherwise.
   *
   * @param items the items to sample from
   * @return an iterator over the sampled items
   */
  override def sample(items: Iterator[T]): Iterator[T] = fraction match {
    case p if p <= 0.0 =>
      Iterator.empty
    case p if p >= 1.0 =>
      items
    case p if p <= RandomSampler.defaultMaxGapSamplingFraction =>
      new GapSamplingIterator(items, p, rng, RandomSampler.rngEpsilon)
    case p =>
      items.filter { _ => rng.nextDouble() <= p }
  }

  /** Returns a new sampler constructed with the same sampling fraction. */
  override def clone: BernoulliSampler[T] = new BernoulliSampler[T](fraction)
}


/**
 * :: DeveloperApi ::
 * A sampler for sampling with replacement, based on values drawn from Poisson distribution.
 *
 * The constructor throws an `IllegalArgumentException` (via `require`) when `fraction` is
 * below zero by more than `RandomSampler.roundingEpsilon`.
 *
 * @param fraction the sampling fraction (with replacement)
 * @tparam T item type
 */
@DeveloperApi
class PoissonSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T] {

  // The lower bound check allows an epsilon slop to avoid failure from floating point jitter.
  require(
    fraction >= (0.0 - RandomSampler.roundingEpsilon),
    s"Sampling fraction ($fraction) must be >= 0")

  // PoissonDistribution throws an exception when fraction <= 0
  // If fraction is <= 0, Iterator.empty is used below, so we can use any placeholder value.
  private val rng = new PoissonDistribution(if (fraction > 0.0) fraction else 1.0)
  // Separate generator handed to the gap-sampling iterator for small fractions.
  private val rngGap = RandomSampler.newDefaultRNG

  /** Re-seeds both the Poisson distribution and the gap-sampling generator with `seed`. */
  override def setSeed(seed: Long): Unit = {
    rng.reseedRandomGenerator(seed)
    rngGap.setSeed(seed)
  }

  /**
   * Samples `items` with replacement, choosing the strategy from the sampling fraction:
   * nothing is returned for fractions <= 0, gap sampling is used for fractions up to
   * `RandomSampler.defaultMaxGapSamplingFraction`, and otherwise each item is repeated as many
   * times as a value drawn from the Poisson distribution indicates.
   *
   * @param items the items to sample from
   * @return an iterator over the sampled items, possibly containing repeated items
   */
  override def sample(items: Iterator[T]): Iterator[T] = {
    if (fraction <= 0.0) {
      Iterator.empty
    } else if (fraction <= RandomSampler.defaultMaxGapSamplingFraction) {
      new GapSamplingReplacementIterator(items, fraction, rngGap, RandomSampler.rngEpsilon)
    } else {
      items.flatMap { item =>
        val count = rng.sample()
        if (count == 0) Iterator.empty else Iterator.fill(count)(item)
      }
    }
  }

  /** Returns a new sampler constructed with the same sampling fraction. */
  override def clone: PoissonSampler[T] = new PoissonSampler[T](fraction)
}


/**
 * An iterator that samples from `data` by skipping randomly sized gaps between the elements it
 * returns, instead of drawing one random number for every element.
 *
 * Construction fails with an `IllegalArgumentException` (via `require`) unless `f` lies in
 * the open interval (0, 1) and `epsilon` is positive.
 *
 * @param data the underlying iterator; it is consumed (and may be reassigned) by this iterator
 * @param f the sampling fraction
 * @param rng the random number generator used to draw gap lengths
 * @param epsilon lower bound applied to values drawn from `rng` before taking their logarithm
 * @tparam T item type
 */
private[spark]
class GapSamplingIterator[T: ClassTag](
    var data: Iterator[T],
    f: Double,
    rng: Random = RandomSampler.newDefaultRNG,
    epsilon: Double = RandomSampler.rngEpsilon) extends Iterator[T] {

  require(f > 0.0 && f < 1.0, s"Sampling fraction ($f) must reside on open interval (0, 1)")
  require(epsilon > 0.0, s"epsilon ($epsilon) must be > 0")

  /** ln(1 - f), the denominator used when converting a random draw into a gap length. */
  private val lnq = math.log1p(-f)

  /**
   * Drops up to `n` elements from the front of `data`.
   *
   * Iterators over `Array` and `ArrayBuffer` delegate to `Iterator.drop`; every other iterator
   * is stepped manually. The original author notes this is a stand-in until Scala includes a
   * fix for jira SI-8835.
   */
  private val iterDrop: Int => Unit = {
    val arrayIterClass = Array.empty[T].iterator.getClass
    val bufferIterClass = ArrayBuffer.empty[T].iterator.getClass
    data.getClass match {
      case `arrayIterClass` | `bufferIterClass` =>
        (count: Int) => { data = data.drop(count) }
      case _ =>
        (count: Int) => {
          var remaining = count
          while (remaining > 0 && data.hasNext) {
            data.next()
            remaining -= 1
          }
        }
    }
  }

  /** True while the underlying iterator is positioned on an element that will be sampled. */
  override def hasNext: Boolean = data.hasNext

  /** Returns the current sampled element and then skips ahead to the next one. */
  override def next(): T = {
    val current = data.next()
    advance()
    current
  }

  /** skip elements that won't be sampled, according to geometric dist P(k) = (f)(1-f)^k. */
  private def advance(): Unit = {
    // Clamp the draw from below so that log(0) can never be taken.
    val draw = math.max(rng.nextDouble(), epsilon)
    val gap = (math.log(draw) / lnq).toInt
    iterDrop(gap)
  }

  // Advance to the first sample as part of object construction. This call must stay after the
  // definitions of `lnq` and `iterDrop`: class-body vals are initialized in declaration order,
  // so calling advance() any earlier would read them before they have been assigned.
  advance()
}

/**
 * An iterator that samples from `data` with replacement by skipping randomly sized gaps between
 * the elements it returns; each returned element is repeated according to a replication factor
 * drawn from a Poisson distribution conditioned on being >= 1.
 *
 * Construction fails with an `IllegalArgumentException` (via `require`) unless both `f` and
 * `epsilon` are positive.
 *
 * @param data the underlying iterator; it is consumed (and may be reassigned) by this iterator
 * @param f the sampling fraction (with replacement)
 * @param rng the random number generator used for gap lengths and replication factors
 * @param epsilon lower bound applied to values drawn from `rng` before taking their logarithm
 * @tparam T item type
 */
private[spark]
class GapSamplingReplacementIterator[T: ClassTag](
    var data: Iterator[T],
    f: Double,
    rng: Random = RandomSampler.newDefaultRNG,
    epsilon: Double = RandomSampler.rngEpsilon) extends Iterator[T] {

  require(f > 0.0, s"Sampling fraction ($f) must be > 0")
  require(epsilon > 0.0, s"epsilon ($epsilon) must be > 0")

  /** q = e^(-f), the probability of Poisson(0; f), i.e. of an element being skipped. */
  private val q = math.exp(-f)

  /**
   * Drops up to `n` elements from the front of `data`.
   *
   * Iterators over `Array` and `ArrayBuffer` delegate to `Iterator.drop`; every other iterator
   * is stepped manually. The original author notes this is a stand-in until Scala includes a
   * fix for jira SI-8835.
   */
  private val iterDrop: Int => Unit = {
    val arrayIterClass = Array.empty[T].iterator.getClass
    val bufferIterClass = ArrayBuffer.empty[T].iterator.getClass
    data.getClass match {
      case `arrayIterClass` | `bufferIterClass` =>
        (count: Int) => { data = data.drop(count) }
      case _ =>
        (count: Int) => {
          var remaining = count
          while (remaining > 0 && data.hasNext) {
            data.next()
            remaining -= 1
          }
        }
    }
  }

  /** current sampling value, and its replication factor, as we are sampling with replacement. */
  private var v: T = _
  private var rep: Int = 0

  /** True while the current value still has repetitions left or more data remains. */
  override def hasNext: Boolean = rep > 0 || data.hasNext

  /**
   * Returns the current sampling value and uses up one of its repetitions; once none are
   * left, skips ahead to the next value to be sampled.
   */
  override def next(): T = {
    val current = v
    rep -= 1
    if (rep <= 0) {
      advance()
    }
    current
  }

  /**
   * Skip elements with replication factor zero (i.e. elements that won't be sampled).
   * Samples 'k' from geometric distribution  P(k) = (1-q)(q)^k, where q = e^(-f), that is
   * q is the probability of Poisson(0; f)
   */
  private def advance(): Unit = {
    // Clamp the draw from below so that log(0) can never be taken.
    val draw = math.max(rng.nextDouble(), epsilon)
    val gap = (math.log(draw) / (-f)).toInt
    iterDrop(gap)
    // Load the next value together with its replication factor. When the data is exhausted
    // the previous value and replication factor are left untouched.
    if (data.hasNext) {
      v = data.next()
      rep = poissonGE1
    }
  }

  /**
   * Sample from Poisson distribution, conditioned such that the sampled value is >= 1.
   * This is an adaptation from the algorithm for Generating Poisson distributed random variables:
   * http://en.wikipedia.org/wiki/Poisson_distribution
   */
  private def poissonGE1: Int = {
    // Continue with the standard poisson sampling algorithm: keep multiplying by fresh draws,
    // counting one more occurrence for every product that stays above q.
    @scala.annotation.tailrec
    def continueSampling(product: Double, count: Int): Int = {
      val updated = product * rng.nextDouble()
      if (updated > q) continueSampling(updated, count + 1) else count
    }

    // Simulate that the standard poisson sampling gave us at least one iteration, for a sample
    // of >= 1, by starting from a product that is already known to be >= q.
    continueSampling(q + ((1.0 - q) * rng.nextDouble()), 1)
  }

  // Advance to the first sample as part of object construction. This call must stay after all
  // field definitions above: class-body fields are initialized in declaration order, so
  // calling advance() earlier would read `q` / `iterDrop` before they are assigned, and the
  // initializers of `v` and `rep` would afterwards overwrite what advance() stored.
  advance()
}