aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorJoseph K. Bradley <joseph@databricks.com>2015-01-09 13:00:15 -0800
committerXiangrui Meng <meng@databricks.com>2015-01-09 13:00:15 -0800
commit7e8e62aec11c43c983055adc475b96006412199a (patch)
treee6d84d6bcea44178f01d9e810a236e49632dcbc4 /mllib
parent454fe129ee97b859bf079db8b9158e115a219ad5 (diff)
downloadspark-7e8e62aec11c43c983055adc475b96006412199a.tar.gz
spark-7e8e62aec11c43c983055adc475b96006412199a.tar.bz2
spark-7e8e62aec11c43c983055adc475b96006412199a.zip
[SPARK-5015] [mllib] Random seed for GMM + make test suite deterministic
Issues: * From JIRA: GaussianMixtureEM uses randomness but does not take a random seed. It should take one as a parameter. * This also makes the test suite flaky since initialization can fail due to stochasticity. Fix: * Add random seed * Use it in test suite CC: mengxr tgaloppo Author: Joseph K. Bradley <joseph@databricks.com> Closes #3981 from jkbradley/gmm-seed and squashes the following commits: f0df4fd [Joseph K. Bradley] Added seed parameter to GMM. Updated test suite to use seed to prevent flakiness
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala26
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala14
2 files changed, 27 insertions, 13 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala
index 3a6c0e681e..b3c5631cc4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala
@@ -24,6 +24,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.linalg.{Matrices, Vector, Vectors, DenseVector, DenseMatrix, BLAS}
import org.apache.spark.mllib.stat.impl.MultivariateGaussian
import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.util.Utils
/**
* This class performs expectation maximization for multivariate Gaussian
@@ -45,10 +46,11 @@ import org.apache.spark.mllib.util.MLUtils
class GaussianMixtureEM private (
private var k: Int,
private var convergenceTol: Double,
- private var maxIterations: Int) extends Serializable {
+ private var maxIterations: Int,
+ private var seed: Long) extends Serializable {
/** A default instance, 2 Gaussians, 100 iterations, 0.01 log-likelihood threshold */
- def this() = this(2, 0.01, 100)
+ def this() = this(2, 0.01, 100, Utils.random.nextLong())
// number of samples per cluster to use when initializing Gaussians
private val nSamples = 5
@@ -100,11 +102,21 @@ class GaussianMixtureEM private (
this
}
- /** Return the largest change in log-likelihood at which convergence is
- * considered to have occurred.
+ /**
+ * Return the largest change in log-likelihood at which convergence is
+ * considered to have occurred.
*/
def getConvergenceTol: Double = convergenceTol
-
+
+ /** Set the random seed */
+ def setSeed(seed: Long): this.type = {
+ this.seed = seed
+ this
+ }
+
+ /** Return the random seed */
+ def getSeed: Long = seed
+
/** Perform expectation maximization */
def run(data: RDD[Vector]): GaussianMixtureModel = {
val sc = data.sparkContext
@@ -113,7 +125,7 @@ class GaussianMixtureEM private (
val breezeData = data.map(u => u.toBreeze.toDenseVector).cache()
// Get length of the input vectors
- val d = breezeData.first.length
+ val d = breezeData.first().length
// Determine initial weights and corresponding Gaussians.
// If the user supplied an initial GMM, we use those values, otherwise
@@ -126,7 +138,7 @@ class GaussianMixtureEM private (
})
case None => {
- val samples = breezeData.takeSample(true, k * nSamples, scala.util.Random.nextInt)
+ val samples = breezeData.takeSample(withReplacement = true, k * nSamples, seed)
(Array.fill(k)(1.0 / k), Array.tabulate(k) { i =>
val slice = samples.view(i * nSamples, (i + 1) * nSamples)
new MultivariateGaussian(vectorMean(slice), initCovariance(slice))
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala
index 23feb82874..9da5495741 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala
@@ -35,12 +35,14 @@ class GMMExpectationMaximizationSuite extends FunSuite with MLlibTestSparkContex
val Ew = 1.0
val Emu = Vectors.dense(5.0, 10.0)
val Esigma = Matrices.dense(2, 2, Array(2.0 / 3.0, -2.0 / 3.0, -2.0 / 3.0, 2.0 / 3.0))
-
- val gmm = new GaussianMixtureEM().setK(1).run(data)
-
- assert(gmm.weight(0) ~== Ew absTol 1E-5)
- assert(gmm.mu(0) ~== Emu absTol 1E-5)
- assert(gmm.sigma(0) ~== Esigma absTol 1E-5)
+
+ val seeds = Array(314589, 29032897, 50181, 494821, 4660)
+ seeds.foreach { seed =>
+ val gmm = new GaussianMixtureEM().setK(1).setSeed(seed).run(data)
+ assert(gmm.weight(0) ~== Ew absTol 1E-5)
+ assert(gmm.mu(0) ~== Emu absTol 1E-5)
+ assert(gmm.sigma(0) ~== Esigma absTol 1E-5)
+ }
}
test("two clusters") {