diff options
author | Yanbo Liang <ybliang8@gmail.com> | 2015-09-08 20:54:02 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-09-08 20:54:02 -0700 |
commit | a1573489a37def97b7c26b798898ffbbdc4defa8 (patch) | |
tree | b29962eeed3d3f0d989ab4b6ee8cde74fa4b7ab1 /mllib | |
parent | 52fe32f6ac7a04fa9b4478fda1307c5b0c61c4a2 (diff) | |
download | spark-a1573489a37def97b7c26b798898ffbbdc4defa8.tar.gz spark-a1573489a37def97b7c26b798898ffbbdc4defa8.tar.bz2 spark-a1573489a37def97b7c26b798898ffbbdc4defa8.zip |
[SPARK-10464] [MLLIB] Add WeibullGenerator for RandomDataGenerator
Add WeibullGenerator for RandomDataGenerator.
#8611 need use WeibullGenerator to generate random data based on Weibull distribution.
Author: Yanbo Liang <ybliang8@gmail.com>
Closes #8622 from yanboliang/spark-10464.
Diffstat (limited to 'mllib')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala | 27 | ||||
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala | 16 |
2 files changed, 40 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala index a2d85a68cd..9eab7efc16 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala @@ -17,8 +17,7 @@ package org.apache.spark.mllib.random -import org.apache.commons.math3.distribution.{ExponentialDistribution, - GammaDistribution, LogNormalDistribution, PoissonDistribution} +import org.apache.commons.math3.distribution._ import org.apache.spark.annotation.{Since, DeveloperApi} import org.apache.spark.util.random.{XORShiftRandom, Pseudorandom} @@ -195,3 +194,27 @@ class LogNormalGenerator @Since("1.3.0") ( @Since("1.3.0") override def copy(): LogNormalGenerator = new LogNormalGenerator(mean, std) } + +/** + * :: DeveloperApi :: + * Generates i.i.d. samples from the Weibull distribution with the + * given shape and scale parameter. + * + * @param alpha shape parameter for the Weibull distribution. + * @param beta scale parameter for the Weibull distribution. + */ +@DeveloperApi +class WeibullGenerator( + val alpha: Double, + val beta: Double) extends RandomDataGenerator[Double] { + + private val rng = new WeibullDistribution(alpha, beta) + + override def nextValue(): Double = rng.sample() + + override def setSeed(seed: Long): Unit = { + rng.reseedRandomGenerator(seed) + } + + override def copy(): WeibullGenerator = new WeibullGenerator(alpha, beta) +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala index a5ca1518f8..8416771552 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.random -import scala.math +import org.apache.commons.math3.special.Gamma import org.apache.spark.SparkFunSuite import org.apache.spark.util.StatCounter @@ -136,4 +136,18 @@ class RandomDataGeneratorSuite extends SparkFunSuite { distributionChecks(gamma, expectedMean, expectedStd, 0.1) } } + + test("WeibullGenerator") { + List((1.0, 2.0), (2.0, 3.0), (2.5, 3.5), (10.4, 2.222)).map { + case (alpha: Double, beta: Double) => + val weibull = new WeibullGenerator(alpha, beta) + apiChecks(weibull) + + val expectedMean = math.exp(Gamma.logGamma(1 + (1 / alpha))) * beta + val expectedVariance = math.exp( + Gamma.logGamma(1 + (2 / alpha))) * beta * beta - expectedMean * expectedMean + val expectedStd = math.sqrt(expectedVariance) + distributionChecks(weibull, expectedMean, expectedStd, 0.1) + } + } } |