author    Liquan Pei <liquanpei@gmail.com>      2014-08-17 23:29:44 -0700
committer Xiangrui Meng <meng@databricks.com>  2014-08-17 23:29:44 -0700
commit    3c8fa505900ac158d57de36f6b0fd6da05f8893b (patch)
tree      5c63bedf4fcabd85016640b7dd4b6d4721b45b25 /mllib
parent    df652ea02a3e42d987419308ef14874300347373 (diff)
[SPARK-3097][MLlib] Word2Vec performance improvement
mengxr Please review the code. Weights in reduceByKey will be added soon.

Only output a model entry for words that appear in the partition before merging, and use reduceByKey to combine the models. In general, this implementation is around 30s faster than the implementation using a big array.

Author: Liquan Pei <liquanpei@gmail.com>

Closes #1932 from Ishiihara/Word2Vec-improve2 and squashes the following commits:

d5377a9 [Liquan Pei] use syn0Global and syn1Global to represent model
cad2011 [Liquan Pei] bug fix for synModify array out of bound
083aa66 [Liquan Pei] update synGlobal in place and reduce synOut size
9075e1c [Liquan Pei] combine syn0Global and syn1Global to synGlobal
aa2ab36 [Liquan Pei] use reduceByKey to combine models
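As background for the diff below, here is a minimal standalone sketch (not part of the patch) of the "emit only the modified rows" idea: a per-row dirty counter is bumped whenever training touches a row of the flat weight array, and at the end of the partition only the touched rows are sliced out into a map. The sketch uses scala.collection.mutable.HashMap in place of Spark's internal PrimitiveKeyOpenHashMap, and the sizes and touched rows are made up for illustration.

import scala.collection.mutable

object SparseEmitSketch {
  def main(args: Array[String]): Unit = {
    val vocabSize = 5
    val vectorSize = 3
    // Flat row-major weight matrix, like syn0/syn1 in the patch.
    val syn0 = Array.tabulate(vocabSize * vectorSize)(_.toFloat)
    // One dirty counter per row, like syn0Modify in the patch.
    val syn0Modify = new Array[Int](vocabSize)

    // Training would bump the counter for every row it updates;
    // here we pretend rows 1 and 4 were touched.
    syn0Modify(1) += 1
    syn0Modify(4) += 1

    // Emit only the rows this partition actually modified.
    val synOut = mutable.HashMap.empty[Int, Array[Float]]
    var index = 0
    while (index < vocabSize) {
      if (syn0Modify(index) != 0) {
        synOut(index) = syn0.slice(index * vectorSize, (index + 1) * vectorSize)
      }
      index += 1
    }
    synOut.toSeq.sortBy(_._1).foreach { case (row, vec) =>
      println(s"row $row -> ${vec.mkString(",")}")
    }
  }
}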
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala  50
1 file changed, 35 insertions(+), 15 deletions(-)
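The other half of the change replaces a dense, weighted treeReduce over the full syn0/syn1 arrays with an element-wise sum per touched row. Below is a hedged sketch of that merge pattern, assuming a local SparkContext and toy data; the names and sizes are illustrative, and in Spark of this era import org.apache.spark.SparkContext._ supplies reduceByKey on pair RDDs.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._

object ReduceByKeyMergeSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setAppName("merge-sketch").setMaster("local[2]"))
    val vectorSize = 3
    // Each element stands in for one partition's sparse output:
    // (row index, updated vector), as synOut in the patch.
    val partial = sc.parallelize(Seq(
      Seq(0 -> Array(1f, 1f, 1f), 2 -> Array(0.5f, 0.5f, 0.5f)),
      Seq(0 -> Array(2f, 2f, 2f))
    ), 2)
    // Rows touched by several partitions are summed element-wise,
    // as blas.saxpy does in the patch.
    val merged = partial.flatMap(x => x).reduceByKey { (v1, v2) =>
      var i = 0
      while (i < vectorSize) { v1(i) += v2(i); i += 1 }
      v1
    }.collect()
    merged.sortBy(_._1).foreach { case (row, vec) =>
      println(s"row $row -> ${vec.mkString(",")}")
    }
    sc.stop()
  }
}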
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index ecd49ea2ff..d2ae62b482 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -34,6 +34,7 @@ import org.apache.spark.mllib.rdd.RDDFunctions._
import org.apache.spark.rdd._
import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandom
+import org.apache.spark.util.collection.PrimitiveKeyOpenHashMap
/**
* Entry in vocabulary
@@ -287,11 +288,12 @@ class Word2Vec extends Serializable with Logging {
var syn0Global =
Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize)
var syn1Global = new Array[Float](vocabSize * vectorSize)
-
var alpha = startingAlpha
for (k <- 1 to numIterations) {
val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) =>
val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))
+ val syn0Modify = new Array[Int](vocabSize)
+ val syn1Modify = new Array[Int](vocabSize)
val model = iter.foldLeft((syn0Global, syn1Global, 0, 0)) {
case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
var lwc = lastWordCount
@@ -321,7 +323,8 @@ class Word2Vec extends Serializable with Logging {
// Hierarchical softmax
var d = 0
while (d < bcVocab.value(word).codeLen) {
- val l2 = bcVocab.value(word).point(d) * vectorSize
+ val inner = bcVocab.value(word).point(d)
+ val l2 = inner * vectorSize
// Propagate hidden -> output
var f = blas.sdot(vectorSize, syn0, l1, 1, syn1, l2, 1)
if (f > -MAX_EXP && f < MAX_EXP) {
@@ -330,10 +333,12 @@ class Word2Vec extends Serializable with Logging {
val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat
blas.saxpy(vectorSize, g, syn1, l2, 1, neu1e, 0, 1)
blas.saxpy(vectorSize, g, syn0, l1, 1, syn1, l2, 1)
+ syn1Modify(inner) += 1
}
d += 1
}
blas.saxpy(vectorSize, 1.0f, neu1e, 0, 1, syn0, l1, 1)
+ syn0Modify(lastWord) += 1
}
}
a += 1
@@ -342,21 +347,36 @@ class Word2Vec extends Serializable with Logging {
}
(syn0, syn1, lwc, wc)
}
- Iterator(model)
+ val syn0Local = model._1
+ val syn1Local = model._2
+ val synOut = new PrimitiveKeyOpenHashMap[Int, Array[Float]](vocabSize * 2)
+ var index = 0
+ while(index < vocabSize) {
+ if (syn0Modify(index) != 0) {
+ synOut.update(index, syn0Local.slice(index * vectorSize, (index + 1) * vectorSize))
+ }
+ if (syn1Modify(index) != 0) {
+ synOut.update(index + vocabSize,
+ syn1Local.slice(index * vectorSize, (index + 1) * vectorSize))
+ }
+ index += 1
+ }
+ Iterator(synOut)
}
- val (aggSyn0, aggSyn1, _, _) =
- partial.treeReduce { case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
- val n = syn0_1.length
- val weight1 = 1.0f * wc_1 / (wc_1 + wc_2)
- val weight2 = 1.0f * wc_2 / (wc_1 + wc_2)
- blas.sscal(n, weight1, syn0_1, 1)
- blas.sscal(n, weight1, syn1_1, 1)
- blas.saxpy(n, weight2, syn0_2, 1, syn0_1, 1)
- blas.saxpy(n, weight2, syn1_2, 1, syn1_1, 1)
- (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
+ val synAgg = partial.flatMap(x => x).reduceByKey { case (v1, v2) =>
+ blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1)
+ v1
+ }.collect()
+ var i = 0
+ while (i < synAgg.length) {
+ val index = synAgg(i)._1
+ if (index < vocabSize) {
+ Array.copy(synAgg(i)._2, 0, syn0Global, index * vectorSize, vectorSize)
+ } else {
+ Array.copy(synAgg(i)._2, 0, syn1Global, (index - vocabSize) * vectorSize, vectorSize)
}
- syn0Global = aggSyn0
- syn1Global = aggSyn1
+ i += 1
+ }
}
newSentences.unpersist()
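For reference, the inner loop in the hierarchical-softmax hunk above performs the standard skip-gram update along the Huffman path. In the patch's variable names, with \sigma the logistic function (applied to the dot product via the exp-table lookup elided between the two hunks), \alpha the learning rate, and code_d the d-th code bit:

\[
  f = \sigma\big(\mathrm{syn0}_{l1} \cdot \mathrm{syn1}_{l2}\big), \qquad
  g = \alpha\,\big(1 - \mathrm{code}_d - f\big)
\]
\[
  \mathrm{neu1e} \leftarrow \mathrm{neu1e} + g\,\mathrm{syn1}_{l2}, \qquad
  \mathrm{syn1}_{l2} \leftarrow \mathrm{syn1}_{l2} + g\,\mathrm{syn0}_{l1}
\]

After the path is traversed, the final saxpy folds the accumulated error back into the input vector, syn0_{l1} \leftarrow syn0_{l1} + neu1e. The new counters syn1Modify(inner) and syn0Modify(lastWord) simply record which rows these updates touched, so that only those rows leave the partition.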