aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org
diff options
context:
space:
mode:
authorYuhao Yang <hhbyyh@gmail.com>2015-12-01 09:26:58 +0000
committerSean Owen <sowen@cloudera.com>2015-12-01 09:26:58 +0000
commita0af0e351e45a8be47a6f65efd132eaa4a00c9e4 (patch)
tree4627576439dac39019d2945108baa992d10b4d33 /mllib/src/main/scala/org
parent9693b0d5a55bc1d2da96f04fe2c6de59a8dfcc1b (diff)
downloadspark-a0af0e351e45a8be47a6f65efd132eaa4a00c9e4.tar.gz
spark-a0af0e351e45a8be47a6f65efd132eaa4a00c9e4.tar.bz2
spark-a0af0e351e45a8be47a6f65efd132eaa4a00c9e4.zip
[SPARK-11898][MLLIB] Use broadcast for the global tables in Word2Vec
jira: https://issues.apache.org/jira/browse/SPARK-11898 syn0Global and sync1Global in word2vec are quite large objects with size (vocab * vectorSize * 8), yet they are passed to worker using basic task serialization. Use broadcast can greatly improve the performance. My benchmark shows that, for 1M vocabulary and default vectorSize 100, changing to broadcast can help, 1. decrease the worker memory consumption by 45%. 2. decrease running time by 40%. This will also help extend the upper limit for Word2Vec. Author: Yuhao Yang <hhbyyh@gmail.com> Closes #9878 from hhbyyh/w2vBC.
Diffstat (limited to 'mllib/src/main/scala/org')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala7
1 files changed, 6 insertions, 1 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index a47f27b0af..655ac0bb55 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -316,12 +316,15 @@ class Word2Vec extends Serializable with Logging {
Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize)
val syn1Global = new Array[Float](vocabSize * vectorSize)
var alpha = learningRate
+
for (k <- 1 to numIterations) {
+ val bcSyn0Global = sc.broadcast(syn0Global)
+ val bcSyn1Global = sc.broadcast(syn1Global)
val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) =>
val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))
val syn0Modify = new Array[Int](vocabSize)
val syn1Modify = new Array[Int](vocabSize)
- val model = iter.foldLeft((syn0Global, syn1Global, 0, 0)) {
+ val model = iter.foldLeft((bcSyn0Global.value, bcSyn1Global.value, 0, 0)) {
case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
var lwc = lastWordCount
var wc = wordCount
@@ -405,6 +408,8 @@ class Word2Vec extends Serializable with Logging {
}
i += 1
}
+ bcSyn0Global.unpersist(false)
+ bcSyn1Global.unpersist(false)
}
newSentences.unpersist()