diff options
author | Liquan Pei <liquanpei@gmail.com> | 2014-08-17 23:29:44 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2014-08-17 23:29:44 -0700 |
commit | 3c8fa505900ac158d57de36f6b0fd6da05f8893b (patch) | |
tree | 5c63bedf4fcabd85016640b7dd4b6d4721b45b25 | |
parent | df652ea02a3e42d987419308ef14874300347373 (diff) | |
download | spark-3c8fa505900ac158d57de36f6b0fd6da05f8893b.tar.gz spark-3c8fa505900ac158d57de36f6b0fd6da05f8893b.tar.bz2 spark-3c8fa505900ac158d57de36f6b0fd6da05f8893b.zip |
[SPARK-3097][MLlib] Word2Vec performance improvement
mengxr Please review the code. Adding weights in reduceByKey soon.
Only output model entry for words appeared in the partition before merging and use reduceByKey to combine model. In general, this implementation is 30s or so faster than implementation using big array.
Author: Liquan Pei <liquanpei@gmail.com>
Closes #1932 from Ishiihara/Word2Vec-improve2 and squashes the following commits:
d5377a9 [Liquan Pei] use syn0Global and syn1Global to represent model
cad2011 [Liquan Pei] bug fix for synModify array out of bound
083aa66 [Liquan Pei] update synGlobal in place and reduce synOut size
9075e1c [Liquan Pei] combine syn0Global and syn1Global to synGlobal
aa2ab36 [Liquan Pei] use reduceByKey to combine models
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala | 50 |
1 files changed, 35 insertions, 15 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index ecd49ea2ff..d2ae62b482 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -34,6 +34,7 @@ import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.rdd._ import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom +import org.apache.spark.util.collection.PrimitiveKeyOpenHashMap /** * Entry in vocabulary @@ -287,11 +288,12 @@ class Word2Vec extends Serializable with Logging { var syn0Global = Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize) var syn1Global = new Array[Float](vocabSize * vectorSize) - var alpha = startingAlpha for (k <- 1 to numIterations) { val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) => val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8)) + val syn0Modify = new Array[Int](vocabSize) + val syn1Modify = new Array[Int](vocabSize) val model = iter.foldLeft((syn0Global, syn1Global, 0, 0)) { case ((syn0, syn1, lastWordCount, wordCount), sentence) => var lwc = lastWordCount @@ -321,7 +323,8 @@ class Word2Vec extends Serializable with Logging { // Hierarchical softmax var d = 0 while (d < bcVocab.value(word).codeLen) { - val l2 = bcVocab.value(word).point(d) * vectorSize + val inner = bcVocab.value(word).point(d) + val l2 = inner * vectorSize // Propagate hidden -> output var f = blas.sdot(vectorSize, syn0, l1, 1, syn1, l2, 1) if (f > -MAX_EXP && f < MAX_EXP) { @@ -330,10 +333,12 @@ class Word2Vec extends Serializable with Logging { val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat blas.saxpy(vectorSize, g, syn1, l2, 1, neu1e, 0, 1) blas.saxpy(vectorSize, g, syn0, l1, 1, syn1, l2, 1) + syn1Modify(inner) += 1 } d += 1 } blas.saxpy(vectorSize, 1.0f, neu1e, 0, 1, syn0, l1, 1) + syn0Modify(lastWord) += 1 } } a += 1 @@ -342,21 +347,36 @@ class Word2Vec extends Serializable with Logging { } (syn0, syn1, lwc, wc) } - Iterator(model) + val syn0Local = model._1 + val syn1Local = model._2 + val synOut = new PrimitiveKeyOpenHashMap[Int, Array[Float]](vocabSize * 2) + var index = 0 + while(index < vocabSize) { + if (syn0Modify(index) != 0) { + synOut.update(index, syn0Local.slice(index * vectorSize, (index + 1) * vectorSize)) + } + if (syn1Modify(index) != 0) { + synOut.update(index + vocabSize, + syn1Local.slice(index * vectorSize, (index + 1) * vectorSize)) + } + index += 1 + } + Iterator(synOut) } - val (aggSyn0, aggSyn1, _, _) = - partial.treeReduce { case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) => - val n = syn0_1.length - val weight1 = 1.0f * wc_1 / (wc_1 + wc_2) - val weight2 = 1.0f * wc_2 / (wc_1 + wc_2) - blas.sscal(n, weight1, syn0_1, 1) - blas.sscal(n, weight1, syn1_1, 1) - blas.saxpy(n, weight2, syn0_2, 1, syn0_1, 1) - blas.saxpy(n, weight2, syn1_2, 1, syn1_1, 1) - (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2) + val synAgg = partial.flatMap(x => x).reduceByKey { case (v1, v2) => + blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1) + v1 + }.collect() + var i = 0 + while (i < synAgg.length) { + val index = synAgg(i)._1 + if (index < vocabSize) { + Array.copy(synAgg(i)._2, 0, syn0Global, index * vectorSize, vectorSize) + } else { + Array.copy(synAgg(i)._2, 0, syn1Global, (index - vocabSize) * vectorSize, vectorSize) } - syn0Global = aggSyn0 - syn1Global = aggSyn1 + i += 1 + } } newSentences.unpersist() |