author     Yanbo Liang <ybliang8@gmail.com>        2015-09-08 20:54:02 -0700
committer  Xiangrui Meng <meng@databricks.com>     2015-09-08 20:54:02 -0700
commit     a1573489a37def97b7c26b798898ffbbdc4defa8 (patch)
tree       b29962eeed3d3f0d989ab4b6ee8cde74fa4b7ab1 /mllib
parent     52fe32f6ac7a04fa9b4478fda1307c5b0c61c4a2 (diff)
[SPARK-10464] [MLLIB] Add WeibullGenerator for RandomDataGenerator
Add WeibullGenerator for RandomDataGenerator. #8611 needs to use WeibullGenerator to generate random data based on the Weibull distribution.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #8622 from yanboliang/spark-10464.
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala      | 27
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala | 16
2 files changed, 40 insertions(+), 3 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala
index a2d85a68cd..9eab7efc16 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala
@@ -17,8 +17,7 @@
package org.apache.spark.mllib.random
-import org.apache.commons.math3.distribution.{ExponentialDistribution,
- GammaDistribution, LogNormalDistribution, PoissonDistribution}
+import org.apache.commons.math3.distribution._
import org.apache.spark.annotation.{Since, DeveloperApi}
import org.apache.spark.util.random.{XORShiftRandom, Pseudorandom}
@@ -195,3 +194,27 @@ class LogNormalGenerator @Since("1.3.0") (
@Since("1.3.0")
override def copy(): LogNormalGenerator = new LogNormalGenerator(mean, std)
}
+
+/**
+ * :: DeveloperApi ::
+ * Generates i.i.d. samples from the Weibull distribution with the
+ * given shape and scale parameter.
+ *
+ * @param alpha shape parameter for the Weibull distribution.
+ * @param beta scale parameter for the Weibull distribution.
+ */
+@DeveloperApi
+class WeibullGenerator(
+ val alpha: Double,
+ val beta: Double) extends RandomDataGenerator[Double] {
+
+ private val rng = new WeibullDistribution(alpha, beta)
+
+ override def nextValue(): Double = rng.sample()
+
+ override def setSeed(seed: Long): Unit = {
+ rng.reseedRandomGenerator(seed)
+ }
+
+ override def copy(): WeibullGenerator = new WeibullGenerator(alpha, beta)
+}
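Side note (not part of the patch): the new generator implements the same RandomDataGenerator[Double] contract as the existing generators, so it can be used either directly or through RandomRDDs. A minimal usage sketch in Scala follows; the parameter values and the RandomRDDs.randomRDD call are illustrative assumptions, not code from this commit.

import org.apache.spark.SparkContext
import org.apache.spark.mllib.random.{RandomRDDs, WeibullGenerator}

// Draw a few i.i.d. samples locally from Weibull(shape = 2.0, scale = 3.0).
val gen = new WeibullGenerator(2.0, 3.0)
gen.setSeed(42L)
val localSamples = Seq.fill(5)(gen.nextValue())

// Distribute the sampling across a cluster; any RandomDataGenerator[Double]
// can be passed to RandomRDDs.randomRDD (sc is an assumed existing SparkContext).
val sc: SparkContext = ???
val weibullRDD = RandomRDDs.randomRDD(sc, new WeibullGenerator(2.0, 3.0), 1000000L, 4, 42L)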
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala
index a5ca1518f8..8416771552 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.mllib.random
-import scala.math
+import org.apache.commons.math3.special.Gamma
import org.apache.spark.SparkFunSuite
import org.apache.spark.util.StatCounter
@@ -136,4 +136,18 @@ class RandomDataGeneratorSuite extends SparkFunSuite {
distributionChecks(gamma, expectedMean, expectedStd, 0.1)
}
}
+
+ test("WeibullGenerator") {
+ List((1.0, 2.0), (2.0, 3.0), (2.5, 3.5), (10.4, 2.222)).map {
+ case (alpha: Double, beta: Double) =>
+ val weibull = new WeibullGenerator(alpha, beta)
+ apiChecks(weibull)
+
+ val expectedMean = math.exp(Gamma.logGamma(1 + (1 / alpha))) * beta
+ val expectedVariance = math.exp(
+ Gamma.logGamma(1 + (2 / alpha))) * beta * beta - expectedMean * expectedMean
+ val expectedStd = math.sqrt(expectedVariance)
+ distributionChecks(weibull, expectedMean, expectedStd, 0.1)
+ }
+ }
}
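The expected statistics in this test follow the standard Weibull moments: E[X] = beta * Gamma(1 + 1/alpha) and Var[X] = beta^2 * Gamma(1 + 2/alpha) - E[X]^2, evaluated via Gamma.logGamma for numerical stability. A small cross-check sketch (not part of the patch; only commons-math3 is needed) comparing these closed forms against the library's own numerical moments; the 1e-9 tolerance is an assumption.

import org.apache.commons.math3.distribution.WeibullDistribution
import org.apache.commons.math3.special.Gamma

val (alpha, beta) = (2.5, 3.5)
// beta * Gamma(1 + 1/alpha)
val mean = math.exp(Gamma.logGamma(1 + 1 / alpha)) * beta
// beta^2 * Gamma(1 + 2/alpha) - mean^2
val variance = math.exp(Gamma.logGamma(1 + 2 / alpha)) * beta * beta - mean * mean

// commons-math3 exposes the same closed-form moments, so the two
// computations should agree up to floating-point error.
val dist = new WeibullDistribution(alpha, beta)
assert(math.abs(mean - dist.getNumericalMean) < 1e-9)
assert(math.abs(variance - dist.getNumericalVariance) < 1e-9)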