aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
author Yanbo Liang <ybliang8@gmail.com> 2015-09-08 20:54:02 -0700
committer Xiangrui Meng <meng@databricks.com> 2015-09-08 20:54:02 -0700
commit a1573489a37def97b7c26b798898ffbbdc4defa8 (patch)
tree b29962eeed3d3f0d989ab4b6ee8cde74fa4b7ab1 /mllib/src/test
parent 52fe32f6ac7a04fa9b4478fda1307c5b0c61c4a2 (diff)
download spark-a1573489a37def97b7c26b798898ffbbdc4defa8.tar.gz
spark-a1573489a37def97b7c26b798898ffbbdc4defa8.tar.bz2
spark-a1573489a37def97b7c26b798898ffbbdc4defa8.zip
[SPARK-10464] [MLLIB] Add WeibullGenerator for RandomDataGenerator
Add WeibullGenerator for RandomDataGenerator. #8611 needs to use WeibullGenerator to generate random data based on the Weibull distribution. Author: Yanbo Liang <ybliang8@gmail.com> Closes #8622 from yanboliang/spark-10464.
Diffstat (limited to 'mllib/src/test')
-rw-r--r-- mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala 16
1 files changed, 15 insertions, 1 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala
index a5ca1518f8..8416771552 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.mllib.random
-import scala.math
+import org.apache.commons.math3.special.Gamma
import org.apache.spark.SparkFunSuite
import org.apache.spark.util.StatCounter
@@ -136,4 +136,18 @@ class RandomDataGeneratorSuite extends SparkFunSuite {
distributionChecks(gamma, expectedMean, expectedStd, 0.1)
}
}
+
+ test("WeibullGenerator") { // runs apiChecks and distributionChecks on WeibullGenerator
+ List((1.0, 2.0), (2.0, 3.0), (2.5, 3.5), (10.4, 2.222)).map { // (alpha, beta) pairs under test
+ case (alpha: Double, beta: Double) =>
+ val weibull = new WeibullGenerator(alpha, beta)
+ apiChecks(weibull)
+
+ val expectedMean = math.exp(Gamma.logGamma(1 + (1 / alpha))) * beta // beta * Gamma(1 + 1/alpha)
+ val expectedVariance = math.exp( // beta^2 * Gamma(1 + 2/alpha) - expectedMean^2
+ Gamma.logGamma(1 + (2 / alpha))) * beta * beta - expectedMean * expectedMean
+ val expectedStd = math.sqrt(expectedVariance)
+ distributionChecks(weibull, expectedMean, expectedStd, 0.1) // NOTE(review): 0.1 is presumably a tolerance; confirm
+ }
+ }
}