diff options
Diffstat (limited to 'mllib/src/main')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala | 16 |
1 files changed, 10 insertions, 6 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala index adf20dc4b8..53587670a5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala @@ -46,17 +46,15 @@ private[mllib] object LocalKMeans extends Logging { // Initialize centers by sampling using the k-means++ procedure. centers(0) = pickWeighted(rand, points, weights).toDense + val costArray = points.map(KMeans.fastSquaredDistance(_, centers(0))) + for (i <- 1 until k) { - // Pick the next center with a probability proportional to cost under current centers - val curCenters = centers.view.take(i) - val sum = points.view.zip(weights).map { case (p, w) => - w * KMeans.pointCost(curCenters, p) - }.sum + val sum = costArray.zip(weights).map(p => p._1 * p._2).sum val r = rand.nextDouble() * sum var cumulativeScore = 0.0 var j = 0 while (j < points.length && cumulativeScore < r) { - cumulativeScore += weights(j) * KMeans.pointCost(curCenters, points(j)) + cumulativeScore += weights(j) * costArray(j) j += 1 } if (j == 0) { @@ -66,6 +64,12 @@ private[mllib] object LocalKMeans extends Logging { } else { centers(i) = points(j - 1).toDense } + + // update costArray + for (p <- points.indices) { + costArray(p) = math.min(KMeans.fastSquaredDistance(points(p), centers(i)), costArray(p)) + } + } // Run up to maxIterations iterations of Lloyd's algorithm |