aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala8
1 files changed, 4 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index a7beb81980..8ff0b83e8b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -253,16 +253,14 @@ class KMeans private (
}
val centers = initialModel match {
- case Some(kMeansCenters) => {
+ case Some(kMeansCenters) =>
Array(kMeansCenters.clusterCenters.map(s => new VectorWithNorm(s)))
- }
- case None => {
+ case None =>
if (initializationMode == KMeans.RANDOM) {
initRandom(data)
} else {
initKMeansParallel(data)
}
- }
}
val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9
logInfo(s"Initialization with $initializationMode took " + "%.3f".format(initTimeInSeconds) +
@@ -390,6 +388,8 @@ class KMeans private (
// Initialize each run's first center to a random point.
val seed = new XORShiftRandom(this.seed).nextInt()
val sample = data.takeSample(true, runs, seed).toSeq
+ // Could be empty if data is empty; fail with a better message early:
+ require(sample.size >= runs, s"Required $runs samples but got ${sample.size} from $data")
val newCenters = Array.tabulate(runs)(r => ArrayBuffer(sample(r).toDense))
/** Merges new centers to centers. */