author    Yuhao Yang <hhbyyh@gmail.com>  2015-12-01 09:26:58 +0000
committer Sean Owen <sowen@cloudera.com>  2015-12-01 09:26:58 +0000
commit    a0af0e351e45a8be47a6f65efd132eaa4a00c9e4
tree      4627576439dac39019d2945108baa992d10b4d33 /mllib/src/main
parent    9693b0d5a55bc1d2da96f04fe2c6de59a8dfcc1b
[SPARK-11898][MLLIB] Use broadcast for the global tables in Word2Vec
jira: https://issues.apache.org/jira/browse/SPARK-11898
syn0Global and syn1Global in Word2Vec are quite large objects, with a combined size of vocab * vectorSize * 8 bytes, yet they are shipped to the workers through basic task serialization.
Using broadcast instead can greatly improve performance. My benchmark shows that, for a 1M vocabulary and the default vectorSize of 100, switching to broadcast:
1. decreases worker memory consumption by 45%;
2. decreases running time by 40%.
This will also help extend the upper limit on vocabulary size for Word2Vec.
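The memory figures above can be sanity-checked with simple arithmetic. A minimal sketch (the object and method names are illustrative, not from the patch; the 1M vocabulary and vectorSize 100 are the benchmark settings quoted in this message, and the factor of 8 comes from the two `Array[Float]` tables at 4 bytes per element each):

```scala
object BroadcastSizeEstimate {
  // Combined size in bytes of syn0Global and syn1Global: two Array[Float]
  // (4 bytes per element), each of length vocabSize * vectorSize.
  def globalTableBytes(vocabSize: Long, vectorSize: Long): Long =
    vocabSize * vectorSize * 2 * 4

  def main(args: Array[String]): Unit = {
    // Benchmark settings from the commit message: 1M vocabulary, vectorSize 100.
    val bytes = globalTableBytes(1000000L, 100L)
    println(f"global tables: ${bytes / (1024.0 * 1024.0)}%.0f MiB shipped with every task when not broadcast")
  }
}
```

With plain task serialization, every task's closure carries its own copy of these ~800 MB of tables; a broadcast variable is fetched once per executor instead, which is where the memory and runtime savings come from.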
Author: Yuhao Yang <hhbyyh@gmail.com>
Closes #9878 from hhbyyh/w2vBC.
Diffstat (limited to 'mllib/src/main')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala | 7
1 file changed, 6 insertions(+), 1 deletion(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index a47f27b0af..655ac0bb55 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -316,12 +316,15 @@ class Word2Vec extends Serializable with Logging {
       Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize)
     val syn1Global = new Array[Float](vocabSize * vectorSize)
     var alpha = learningRate
+
     for (k <- 1 to numIterations) {
+      val bcSyn0Global = sc.broadcast(syn0Global)
+      val bcSyn1Global = sc.broadcast(syn1Global)
       val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) =>
         val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))
         val syn0Modify = new Array[Int](vocabSize)
         val syn1Modify = new Array[Int](vocabSize)
-        val model = iter.foldLeft((syn0Global, syn1Global, 0, 0)) {
+        val model = iter.foldLeft((bcSyn0Global.value, bcSyn1Global.value, 0, 0)) {
           case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
             var lwc = lastWordCount
             var wc = wordCount
@@ -405,6 +408,8 @@ class Word2Vec extends Serializable with Logging {
         }
         i += 1
       }
+      bcSyn0Global.unpersist(false)
+      bcSyn1Global.unpersist(false)
     }
     newSentences.unpersist()
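In the patch, the per-partition foldLeft threads the two weight tables through every sentence as accumulator state, now seeded from bcSyn0Global.value / bcSyn1Global.value rather than the closure-captured arrays. A Spark-free sketch of that accumulator shape (the object and method names are illustrative; a stub update that just counts word occurrences stands in for the real skip-gram gradient step):

```scala
object FoldLeftShapeSketch {
  // Same accumulator tuple as the patch: (syn0, syn1, lastWordCount, wordCount),
  // threaded through foldLeft one sentence at a time. The stub update bumps the
  // first weight of every word seen, standing in for the real training step.
  def trainPartition(sentences: Iterator[Array[Int]],
                     syn0: Array[Float],
                     syn1: Array[Float],
                     vectorSize: Int): (Array[Float], Array[Float], Int, Int) =
    sentences.foldLeft((syn0, syn1, 0, 0)) {
      case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
        sentence.foreach(w => syn0(w * vectorSize) += 1.0f)
        (syn0, syn1, lastWordCount, wordCount + sentence.length)
    }

  def main(args: Array[String]): Unit = {
    val vocabSize = 4
    val vectorSize = 2
    // Stand-ins for bcSyn0Global.value / bcSyn1Global.value on an executor.
    val syn0 = new Array[Float](vocabSize * vectorSize)
    val syn1 = new Array[Float](vocabSize * vectorSize)
    // Each "sentence" is a sequence of word indices.
    val sentences = Iterator(Array(0, 1, 2), Array(2, 3))
    val (s0, _, _, wc) = trainPartition(sentences, syn0, syn1, vectorSize)
    println(s"words processed: $wc, stub weight for word 2: ${s0(2 * vectorSize)}")
  }
}
```

In the real code, reading the tables via the broadcast handle means each executor fetches one copy per iteration instead of deserializing the arrays with every task, and the `unpersist(false)` calls drop that copy non-blockingly at the end of the iteration before the updated tables are re-broadcast.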