author    WeichenXu <WeichenXu123@outlook.com>  2016-07-30 08:07:22 -0700
committer Sean Owen <sowen@cloudera.com>        2016-07-30 08:07:22 -0700
commit    bce354c1d4e2b97b1159913085e9883a26bc605a (patch)
tree      7bd30b78e12322a6ce8e300219b0a25546cdccde /mllib
parent    0dc4310b470c7e4355c0da67ca3373c3013cc9dd (diff)
[SPARK-16696][ML][MLLIB] Destroy KMeans bcNewCenters when the loop finishes, and update code to release unused broadcasts/RDDs at the proper time
## What changes were proposed in this pull request?

Release unused broadcast variables in KMeans and Word2Vec, calling destroy(false) to free their memory promptly. Several existing destroy() calls are changed to destroy(false) so the release happens asynchronously, which is better than blocking the caller.

KMeans's `bcNewCenters` is also updated so it is destroyed at the correct time: a list stores every `bcNewCenters` broadcast generated across the loop iterations, and all of them are released once the loop finishes.

Finally, this fixes the TODO in `BisectingKMeans.run` ("unpersist old indices") by implementing the pattern "persist the current step's RDD, and unpersist the previous one" in the loop iteration.

## How was this patch tested?

Existing tests.

Author: WeichenXu <WeichenXu123@outlook.com>

Closes #14333 from WeichenXu123/broadvar_unpersist_to_destroy.
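As background for the change, here is a minimal, self-contained sketch of the broadcast lifecycle involved (names are illustrative, not from the patch): `unpersist` only drops the cached copies on the executors, so the variable can still be re-broadcast later, while `destroy` also releases the driver-side copy and makes the variable permanently unusable. Passing `blocking = false` makes the release asynchronous instead of stalling the caller; note that the `destroy(blocking)` overload is package-private to Spark at the time of this patch, which is why MLlib's own code can call `destroy(false)` while user code would call the public `destroy()`.

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Hypothetical standalone example: release a broadcast once no future job
// references it.
object BroadcastReleaseSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("bc-sketch").setMaster("local[2]"))
    val lookup = sc.broadcast((0 until 1000).map(i => (i, i * i)).toMap)
    val hits = sc.parallelize(0 until 100).filter(i => lookup.value.contains(i)).count()
    println(s"hits = $hits")
    // unpersist(blocking = false): asynchronously drop executor-side copies;
    // the broadcast would be re-sent if a later job used it again.
    lookup.unpersist(blocking = false)
    // destroy(): also release the driver-side copy; the variable is gone for
    // good and any further use throws SparkException.
    lookup.destroy()
    sc.stop()
  }
}
```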
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala |  8
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala          |  8
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala           | 10
3 files changed, 17 insertions(+), 9 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala
index f1664ce4ab..e6b89712e2 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala
@@ -165,6 +165,8 @@ class BisectingKMeans private (
val random = new Random(seed)
var numLeafClustersNeeded = k - 1
var level = 1
+ var preIndices: RDD[Long] = null
+ var indices: RDD[Long] = null
while (activeClusters.nonEmpty && numLeafClustersNeeded > 0 && level < LEVEL_LIMIT) {
// Divisible clusters are sufficiently large and have non-trivial cost.
var divisibleClusters = activeClusters.filter { case (_, summary) =>
@@ -194,8 +196,9 @@ class BisectingKMeans private (
newClusters = summarize(d, newAssignments)
newClusterCenters = newClusters.mapValues(_.center).map(identity)
}
- // TODO: Unpersist old indices.
- val indices = updateAssignments(assignments, divisibleIndices, newClusterCenters).keys
+ if (preIndices != null) preIndices.unpersist()
+ preIndices = indices
+ indices = updateAssignments(assignments, divisibleIndices, newClusterCenters).keys
.persist(StorageLevel.MEMORY_AND_DISK)
assignments = indices.zip(vectors)
inactiveClusters ++= activeClusters
@@ -208,6 +211,7 @@ class BisectingKMeans private (
}
level += 1
}
+ if (indices != null) indices.unpersist()
val clusters = activeClusters ++ inactiveClusters
val root = buildTree(clusters)
new BisectingKMeansModel(root)
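The hunk above follows the "persist the current iteration's RDD, unpersist the previous one" pattern mentioned in the description, keeping the predecessor cached one extra step because the new RDD is computed from it. A condensed standalone sketch of the same pattern (a hypothetical loop, not the BisectingKMeans code):

```scala
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel

// Hypothetical iterative job: each step derives a new RDD from the previous
// one, so at most two generations are kept cached at any time.
def iterate(sc: SparkContext, steps: Int): RDD[Long] = {
  var previous: RDD[Long] = null
  var current: RDD[Long] =
    sc.parallelize(0L until 1000L).persist(StorageLevel.MEMORY_AND_DISK)
  for (_ <- 0 until steps) {
    val next = current.map(_ + 1).persist(StorageLevel.MEMORY_AND_DISK)
    next.count() // materialize the new generation before dropping an old one
    if (previous != null) previous.unpersist()
    previous = current
    current = next
  }
  if (previous != null) previous.unpersist()
  current
}
```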
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index 9a3d64fca5..de9fa4aebf 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -20,6 +20,7 @@ package org.apache.spark.mllib.clustering
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.annotation.Since
+import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.ml.clustering.{KMeans => NewKMeans}
import org.apache.spark.ml.util.Instrumentation
@@ -309,7 +310,7 @@ class KMeans private (
contribs.iterator
}.reduceByKey(mergeContribs).collectAsMap()
- bcActiveCenters.unpersist(blocking = false)
+ bcActiveCenters.destroy(blocking = false)
// Update the cluster centers and costs for each active run
for ((run, i) <- activeRuns.zipWithIndex) {
@@ -402,8 +403,10 @@ class KMeans private (
// to their squared distance from that run's centers. Note that only distances between points
// and new centers are computed in each iteration.
var step = 0
+ var bcNewCentersList = ArrayBuffer[Broadcast[_]]()
while (step < initializationSteps) {
val bcNewCenters = data.context.broadcast(newCenters)
+ bcNewCentersList += bcNewCenters
val preCosts = costs
costs = data.zip(preCosts).map { case (point, cost) =>
Array.tabulate(runs) { r =>
@@ -453,6 +456,7 @@ class KMeans private (
mergeNewCenters()
costs.unpersist(blocking = false)
+ bcNewCentersList.foreach(_.destroy(false))
// Finally, we might have a set of more than k candidate centers for each run; weigh each
// candidate by the number of points in the dataset mapping to it and run a local k-means++
@@ -464,7 +468,7 @@ class KMeans private (
}
}.reduceByKey(_ + _).collectAsMap()
- bcCenters.unpersist(blocking = false)
+ bcCenters.destroy(blocking = false)
val finalCenters = (0 until runs).par.map { r =>
val myCenters = centers(r).toArray
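The initialization hunks above show the delayed-release idea from the description: each iteration's `bcNewCenters` still sits in the lineage of the lazily evaluated `costs` RDD, so destroying it inside the loop would be premature; instead every broadcast is collected in a list and destroyed once the loop is done. A condensed, hypothetical sketch of that pattern (simplified names and logic, not the KMeans code):

```scala
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.SparkContext
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD

// Hypothetical loop that broadcasts fresh state on every step. The broadcasts
// feed the lineage of a lazily evaluated RDD, so they are only safe to destroy
// after the final action has run; remember them and release at the end.
def run(sc: SparkContext, data: RDD[Double], steps: Int): Double = {
  val pending = ArrayBuffer[Broadcast[_]]()
  var costs = data.map(_ => Double.PositiveInfinity)
  var step = 0
  while (step < steps) {
    val bcState = sc.broadcast(step.toDouble) // stands in for newCenters
    pending += bcState
    costs = data.zip(costs).map { case (x, c) =>
      math.min(c, math.abs(x - bcState.value))
    }
    step += 1
  }
  val total = costs.sum() // final action: the whole lineage is evaluated
  pending.foreach(_.destroy()) // now the broadcasts can be released
  total
}
```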
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index bc75646d53..908198740b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -430,13 +430,13 @@ class Word2Vec extends Serializable with Logging {
}
i += 1
}
- bcSyn0Global.unpersist(false)
- bcSyn1Global.unpersist(false)
+ bcSyn0Global.destroy(false)
+ bcSyn1Global.destroy(false)
}
newSentences.unpersist()
- expTable.destroy()
- bcVocab.destroy()
- bcVocabHash.destroy()
+ expTable.destroy(false)
+ bcVocab.destroy(false)
+ bcVocabHash.destroy(false)
val wordArray = vocab.map(_.word)
new Word2VecModel(wordArray.zipWithIndex.toMap, syn0Global)