aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRJ Nowling <rnowling@gmail.com>2016-01-05 15:05:04 -0800
committerJoseph K. Bradley <joseph@databricks.com>2016-01-05 15:05:04 -0800
commit78015a8b7cc316343e302eeed6fe30af9f2961e8 (patch)
treefd90d1526ccd4671941ab1f4637bd2d88ebb5b88
parent1c6cf1a5639bf5111324e44d93a8c6462958750a (diff)
downloadspark-78015a8b7cc316343e302eeed6fe30af9f2961e8.tar.gz
spark-78015a8b7cc316343e302eeed6fe30af9f2961e8.tar.bz2
spark-78015a8b7cc316343e302eeed6fe30af9f2961e8.zip
[SPARK-12450][MLLIB] Un-persist broadcasted variables in KMeans
SPARK-12450 . Un-persist broadcasted variables in KMeans. Author: RJ Nowling <rnowling@gmail.com> Closes #10415 from rnowling/spark-12450.
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala8
1 files changed, 8 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index 2895db7c90..e47c4db629 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -301,6 +301,8 @@ class KMeans private (
contribs.iterator
}.reduceByKey(mergeContribs).collectAsMap()
+ bcActiveCenters.unpersist(blocking = false)
+
// Update the cluster centers and costs for each active run
for ((run, i) <- activeRuns.zipWithIndex) {
var changed = false
@@ -419,7 +421,10 @@ class KMeans private (
s0
}
)
+
+ bcNewCenters.unpersist(blocking = false)
preCosts.unpersist(blocking = false)
+
val chosen = data.zip(costs).mapPartitionsWithIndex { (index, pointsWithCosts) =>
val rand = new XORShiftRandom(seed ^ (step << 16) ^ index)
pointsWithCosts.flatMap { case (p, c) =>
@@ -448,6 +453,9 @@ class KMeans private (
((r, KMeans.findClosest(bcCenters.value(r), p)._1), 1.0)
}
}.reduceByKey(_ + _).collectAsMap()
+
+ bcCenters.unpersist(blocking = false)
+
val finalCenters = (0 until runs).par.map { r =>
val myCenters = centers(r).toArray
val myWeights = (0 until myCenters.length).map(i => weightMap.getOrElse((r, i), 0.0)).toArray