aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala 16
1 files changed, 10 insertions, 6 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala
index adf20dc4b8..53587670a5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala
@@ -46,17 +46,15 @@ private[mllib] object LocalKMeans extends Logging {
// Initialize centers by sampling using the k-means++ procedure.
centers(0) = pickWeighted(rand, points, weights).toDense
+ val costArray = points.map(KMeans.fastSquaredDistance(_, centers(0)))
+
for (i <- 1 until k) {
- // Pick the next center with a probability proportional to cost under current centers
- val curCenters = centers.view.take(i)
- val sum = points.view.zip(weights).map { case (p, w) =>
- w * KMeans.pointCost(curCenters, p)
- }.sum
+ val sum = costArray.zip(weights).map(p => p._1 * p._2).sum
val r = rand.nextDouble() * sum
var cumulativeScore = 0.0
var j = 0
while (j < points.length && cumulativeScore < r) {
- cumulativeScore += weights(j) * KMeans.pointCost(curCenters, points(j))
+ cumulativeScore += weights(j) * costArray(j)
j += 1
}
if (j == 0) {
@@ -66,6 +64,12 @@ private[mllib] object LocalKMeans extends Logging {
} else {
centers(i) = points(j - 1).toDense
}
+
+ // update costArray
+ for (p <- points.indices) {
+ costArray(p) = math.min(KMeans.fastSquaredDistance(points(p), centers(i)), costArray(p))
+ }
+
}
// Run up to maxIterations iterations of Lloyd's algorithm