diff options
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index a7beb81980..8ff0b83e8b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -253,16 +253,14 @@ class KMeans private ( } val centers = initialModel match { - case Some(kMeansCenters) => { + case Some(kMeansCenters) => Array(kMeansCenters.clusterCenters.map(s => new VectorWithNorm(s))) - } - case None => { + case None => if (initializationMode == KMeans.RANDOM) { initRandom(data) } else { initKMeansParallel(data) } - } } val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9 logInfo(s"Initialization with $initializationMode took " + "%.3f".format(initTimeInSeconds) + @@ -390,6 +388,8 @@ class KMeans private ( // Initialize each run's first center to a random point. val seed = new XORShiftRandom(this.seed).nextInt() val sample = data.takeSample(true, runs, seed).toSeq + // Could be empty if data is empty; fail with a better message early: + require(sample.size >= runs, s"Required $runs samples but got ${sample.size} from $data") val newCenters = Array.tabulate(runs)(r => ArrayBuffer(sample(r).toDense)) /** Merges new centers to centers. */ |