about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala 23
1 files changed, 12 insertions, 11 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index c3375ed44f..fc14447053 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -347,19 +347,20 @@ class Word2Vec extends Serializable with Logging {
}
val syn0Local = model._1
val syn1Local = model._2
- val synOut = mutable.ListBuffer.empty[(Int, Array[Float])]
- var index = 0
- while(index < vocabSize) {
- if (syn0Modify(index) != 0) {
- synOut += ((index, syn0Local.slice(index * vectorSize, (index + 1) * vectorSize)))
+ // Only output modified vectors.
+ Iterator.tabulate(vocabSize) { index =>
+ if (syn0Modify(index) > 0) {
+ Some((index, syn0Local.slice(index * vectorSize, (index + 1) * vectorSize)))
+ } else {
+ None
}
- if (syn1Modify(index) != 0) {
- synOut += ((index + vocabSize,
- syn1Local.slice(index * vectorSize, (index + 1) * vectorSize)))
+ }.flatten ++ Iterator.tabulate(vocabSize) { index =>
+ if (syn1Modify(index) > 0) {
+ Some((index + vocabSize, syn1Local.slice(index * vectorSize, (index + 1) * vectorSize)))
+ } else {
+ None
}
- index += 1
- }
- synOut.toIterator
+ }.flatten
}
val synAgg = partial.reduceByKey { case (v1, v2) =>
blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1)