aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorXiangrui Meng <meng@databricks.com>2014-08-19 22:16:22 -0700
committerXiangrui Meng <meng@databricks.com>2014-08-19 22:16:22 -0700
commit0a984aa155fb7f532fe87620dcf1a2814c5b8b49 (patch)
tree1ccd1295c933251dce2a45cd051191035c169b54 /mllib
parent8adfbc2b6b5b647e450d30f89c141f935b6aa94b (diff)
downloadspark-0a984aa155fb7f532fe87620dcf1a2814c5b8b49.tar.gz
spark-0a984aa155fb7f532fe87620dcf1a2814c5b8b49.tar.bz2
spark-0a984aa155fb7f532fe87620dcf1a2814c5b8b49.zip
[SPARK-3142][MLLIB] output shuffle data directly in Word2Vec
Sorry I didn't realize this in #2043. Ishiihara Author: Xiangrui Meng <meng@databricks.com> Closes #2049 from mengxr/more-w2v and squashes the following commits: 050b1c5 [Xiangrui Meng] output shuffle data directly
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala23
1 files changed, 12 insertions, 11 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index c3375ed44f..fc14447053 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -347,19 +347,20 @@ class Word2Vec extends Serializable with Logging {
}
val syn0Local = model._1
val syn1Local = model._2
- val synOut = mutable.ListBuffer.empty[(Int, Array[Float])]
- var index = 0
- while(index < vocabSize) {
- if (syn0Modify(index) != 0) {
- synOut += ((index, syn0Local.slice(index * vectorSize, (index + 1) * vectorSize)))
+ // Only output modified vectors.
+ Iterator.tabulate(vocabSize) { index =>
+ if (syn0Modify(index) > 0) {
+ Some((index, syn0Local.slice(index * vectorSize, (index + 1) * vectorSize)))
+ } else {
+ None
}
- if (syn1Modify(index) != 0) {
- synOut += ((index + vocabSize,
- syn1Local.slice(index * vectorSize, (index + 1) * vectorSize)))
+ }.flatten ++ Iterator.tabulate(vocabSize) { index =>
+ if (syn1Modify(index) > 0) {
+ Some((index + vocabSize, syn1Local.slice(index * vectorSize, (index + 1) * vectorSize)))
+ } else {
+ None
}
- index += 1
- }
- synOut.toIterator
+ }.flatten
}
val synAgg = partial.reduceByKey { case (v1, v2) =>
blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1)