diff options
author | Xiangrui Meng <meng@databricks.com> | 2014-08-19 22:16:22 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2014-08-19 22:16:22 -0700 |
commit | 0a984aa155fb7f532fe87620dcf1a2814c5b8b49 (patch) | |
tree | 1ccd1295c933251dce2a45cd051191035c169b54 /mllib/src | |
parent | 8adfbc2b6b5b647e450d30f89c141f935b6aa94b (diff) | |
download | spark-0a984aa155fb7f532fe87620dcf1a2814c5b8b49.tar.gz spark-0a984aa155fb7f532fe87620dcf1a2814c5b8b49.tar.bz2 spark-0a984aa155fb7f532fe87620dcf1a2814c5b8b49.zip |
[SPARK-3142][MLLIB] output shuffle data directly in Word2Vec
Sorry I didn't realize this in #2043. @Ishiihara
Author: Xiangrui Meng <meng@databricks.com>
Closes #2049 from mengxr/more-w2v and squashes the following commits:
050b1c5 [Xiangrui Meng] output shuffle data directly
Diffstat (limited to 'mllib/src')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala | 23 |
1 file changed, 12 insertions, 11 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index c3375ed44f..fc14447053 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -347,19 +347,20 @@ class Word2Vec extends Serializable with Logging { } val syn0Local = model._1 val syn1Local = model._2 - val synOut = mutable.ListBuffer.empty[(Int, Array[Float])] - var index = 0 - while(index < vocabSize) { - if (syn0Modify(index) != 0) { - synOut += ((index, syn0Local.slice(index * vectorSize, (index + 1) * vectorSize))) + // Only output modified vectors. + Iterator.tabulate(vocabSize) { index => + if (syn0Modify(index) > 0) { + Some((index, syn0Local.slice(index * vectorSize, (index + 1) * vectorSize))) + } else { + None } - if (syn1Modify(index) != 0) { - synOut += ((index + vocabSize, - syn1Local.slice(index * vectorSize, (index + 1) * vectorSize))) + }.flatten ++ Iterator.tabulate(vocabSize) { index => + if (syn1Modify(index) > 0) { + Some((index + vocabSize, syn1Local.slice(index * vectorSize, (index + 1) * vectorSize))) + } else { + None } - index += 1 - } - synOut.toIterator + }.flatten } val synAgg = partial.reduceByKey { case (v1, v2) => blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1) |