author     nate.crosswhite <nate.crosswhite@stresearch.com>  2015-01-21 10:32:10 -0800
committer  Xiangrui Meng <meng@databricks.com>  2015-01-21 10:32:10 -0800
commit     7450a992b3b543a373c34fc4444a528954ac4b4a (patch)
tree       1e3c63168367b3a25335f34dc0d8a58ffa39477f /mllib
parent     aa1e22b17b4ce885febe6970a2451c7d17d0acfb (diff)
[SPARK-4749] [mllib]: Allow initializing KMeans clusters using a seed
This implements the functionality for SPARK-4749, allowing KMeans cluster initialization to be seeded, and provides unit tests in Scala and PySpark.

Author: nate.crosswhite <nate.crosswhite@stresearch.com>
Author: nxwhite-str <nxwhite-str@users.noreply.github.com>
Author: Xiangrui Meng <meng@databricks.com>

Closes #3610 from nxwhite-str/master and squashes the following commits:

a2ebbd3 [nxwhite-str] Merge pull request #1 from mengxr/SPARK-4749-kmeans-seed
7668124 [Xiangrui Meng] minor updates
f8d5928 [nate.crosswhite] Addressing PR issues
277d367 [nate.crosswhite] Merge remote-tracking branch 'upstream/master'
9156a57 [nate.crosswhite] Merge remote-tracking branch 'upstream/master'
5d087b4 [nate.crosswhite] Adding KMeans train with seed and Scala unit test
616d111 [nate.crosswhite] Merge remote-tracking branch 'upstream/master'
35c1884 [nate.crosswhite] Add kmeans initial seed to pyspark API
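For context, here is a minimal sketch of the new seeded API from a user's perspective (assuming an existing SparkContext named sc; this snippet is illustrative and not part of the commit):

    import org.apache.spark.mllib.clustering.KMeans
    import org.apache.spark.mllib.linalg.Vectors

    val data = sc.parallelize((0 until 100).map(i => Vectors.dense(i.toDouble, i.toDouble)))
    // With a fixed seed, cluster initialization is reproducible, so the
    // fitted centers match across otherwise identical runs.
    val modelA = KMeans.train(data, k = 3, maxIterations = 10, runs = 1,
      initializationMode = KMeans.K_MEANS_PARALLEL, seed = 42L)
    val modelB = KMeans.train(data, k = 3, maxIterations = 10, runs = 1,
      initializationMode = KMeans.K_MEANS_PARALLEL, seed = 42L)
    // modelA.clusterCenters and modelB.clusterCenters are identical.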
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala |  6
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala         | 48
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala    | 21
3 files changed, 66 insertions(+), 9 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 555da8c7e7..430d763ef7 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -266,12 +266,16 @@ class PythonMLLibAPI extends Serializable {
k: Int,
maxIterations: Int,
runs: Int,
- initializationMode: String): KMeansModel = {
+ initializationMode: String,
+ seed: java.lang.Long): KMeansModel = {
val kMeansAlg = new KMeans()
.setK(k)
.setMaxIterations(maxIterations)
.setRuns(runs)
.setInitializationMode(initializationMode)
+
+ if (seed != null) kMeansAlg.setSeed(seed)
+
try {
kMeansAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK))
} finally {
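The seed parameter is a boxed java.lang.Long rather than a primitive Long so that the Py4J gateway can pass null when the Python caller omits it; only a non-null value overrides the algorithm's random default. A standalone sketch of the same pattern (hypothetical helper name, not part of this diff):

    // A nullable boxed Long lets an omitted Python-side seed arrive as null,
    // in which case KMeans keeps its Utils.random.nextLong() default.
    def newSeededKMeans(k: Int, seed: java.lang.Long): KMeans = {
      val alg = new KMeans().setK(k)
      if (seed != null) alg.setSeed(seed)
      alg
    }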
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index 54c301d3e9..6b5c934f01 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -19,14 +19,14 @@ package org.apache.spark.mllib.clustering
import scala.collection.mutable.ArrayBuffer
-import org.apache.spark.annotation.Experimental
import org.apache.spark.Logging
-import org.apache.spark.SparkContext._
+import org.apache.spark.annotation.Experimental
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.linalg.BLAS.{axpy, scal}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandom
/**
@@ -43,13 +43,14 @@ class KMeans private (
private var runs: Int,
private var initializationMode: String,
private var initializationSteps: Int,
- private var epsilon: Double) extends Serializable with Logging {
+ private var epsilon: Double,
+ private var seed: Long) extends Serializable with Logging {
/**
* Constructs a KMeans instance with default parameters: {k: 2, maxIterations: 20, runs: 1,
- * initializationMode: "k-means||", initializationSteps: 5, epsilon: 1e-4}.
+ * initializationMode: "k-means||", initializationSteps: 5, epsilon: 1e-4, seed: random}.
*/
- def this() = this(2, 20, 1, KMeans.K_MEANS_PARALLEL, 5, 1e-4)
+ def this() = this(2, 20, 1, KMeans.K_MEANS_PARALLEL, 5, 1e-4, Utils.random.nextLong())
/** Set the number of clusters to create (k). Default: 2. */
def setK(k: Int): this.type = {
@@ -112,6 +113,12 @@ class KMeans private (
this
}
+ /** Set the random seed for cluster initialization. */
+ def setSeed(seed: Long): this.type = {
+ this.seed = seed
+ this
+ }
+
/**
* Train a K-means model on the given set of points; `data` should be cached for high
* performance, because this is an iterative algorithm.
@@ -255,7 +262,7 @@ class KMeans private (
private def initRandom(data: RDD[VectorWithNorm])
: Array[Array[VectorWithNorm]] = {
// Sample all the cluster centers in one pass to avoid repeated scans
- val sample = data.takeSample(true, runs * k, new XORShiftRandom().nextInt()).toSeq
+ val sample = data.takeSample(true, runs * k, new XORShiftRandom(this.seed).nextInt()).toSeq
Array.tabulate(runs)(r => sample.slice(r * k, (r + 1) * k).map { v =>
new VectorWithNorm(Vectors.dense(v.vector.toArray), v.norm)
}.toArray)
@@ -273,7 +280,7 @@ class KMeans private (
private def initKMeansParallel(data: RDD[VectorWithNorm])
: Array[Array[VectorWithNorm]] = {
// Initialize each run's center to a random point
- val seed = new XORShiftRandom().nextInt()
+ val seed = new XORShiftRandom(this.seed).nextInt()
val sample = data.takeSample(true, runs, seed).toSeq
val centers = Array.tabulate(runs)(r => ArrayBuffer(sample(r).toDense))
@@ -333,7 +340,32 @@ object KMeans {
/**
* Trains a k-means model using the given set of parameters.
*
- * @param data training points stored as `RDD[Array[Double]]`
+ * @param data training points stored as `RDD[Vector]`
+ * @param k number of clusters
+ * @param maxIterations max number of iterations
+ * @param runs number of parallel runs, defaults to 1. The best model is returned.
+ * @param initializationMode initialization mode, either "random" or "k-means||" (default).
+ * @param seed random seed value for cluster initialization
+ */
+ def train(
+ data: RDD[Vector],
+ k: Int,
+ maxIterations: Int,
+ runs: Int,
+ initializationMode: String,
+ seed: Long): KMeansModel = {
+ new KMeans().setK(k)
+ .setMaxIterations(maxIterations)
+ .setRuns(runs)
+ .setInitializationMode(initializationMode)
+ .setSeed(seed)
+ .run(data)
+ }
+
+ /**
+ * Trains a k-means model using the given set of parameters.
+ *
+ * @param data training points stored as `RDD[Vector]`
* @param k number of clusters
* @param maxIterations max number of iterations
* @param runs number of parallel runs, defaults to 1. The best model is returned.
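Both entry points converge on the same mechanism: setSeed stores the value, and initRandom / initKMeansParallel derive their sampling seed from it via new XORShiftRandom(this.seed), so equal seeds produce equal samples. A builder-style equivalent of the static train overload above (assuming an existing RDD[Vector] named data; illustrative only):

    // Each setter returns this.type, so calls chain; run() consumes the
    // stored seed through initRandom / initKMeansParallel.
    val model = new KMeans()
      .setK(3)
      .setMaxIterations(10)
      .setRuns(1)
      .setInitializationMode(KMeans.K_MEANS_PARALLEL)
      .setSeed(42L)
      .run(data)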
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
index 9ebef8466c..caee591700 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
@@ -90,6 +90,27 @@ class KMeansSuite extends FunSuite with MLlibTestSparkContext {
assert(model.clusterCenters.size === 3)
}
+ test("deterministic initialization") {
+ // Create a large-ish set of points for clustering
+ val points = List.tabulate(1000)(n => Vectors.dense(n, n))
+ val rdd = sc.parallelize(points, 3)
+
+ for (initMode <- Seq(RANDOM, K_MEANS_PARALLEL)) {
+ // Create three deterministic models and compare cluster means
+ val model1 = KMeans.train(rdd, k = 10, maxIterations = 2, runs = 1,
+ initializationMode = initMode, seed = 42)
+ val centers1 = model1.clusterCenters
+
+ val model2 = KMeans.train(rdd, k = 10, maxIterations = 2, runs = 1,
+ initializationMode = initMode, seed = 42)
+ val centers2 = model2.clusterCenters
+
+ centers1.zip(centers2).foreach { case (c1, c2) =>
+ assert(c1 ~== c2 absTol 1E-14)
+ }
+ }
+ }
+
test("single cluster with big dataset") {
val smallData = Array(
Vectors.dense(1.0, 2.0, 6.0),