diff options
author | Xiangrui Meng <meng@databricks.com> | 2014-08-19 17:41:37 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2014-08-19 17:41:37 -0700 |
commit | 1870dbaa5591883e61b2173d064c1a67e871b0f5 (patch) | |
tree | 42906028345c47fbc7a7f920cfffcc77373aabb4 /mllib/src | |
parent | 8b9dc991018842e01f4b93870a2bc2c2cb9ea4ba (diff) | |
download | spark-1870dbaa5591883e61b2173d064c1a67e871b0f5.tar.gz spark-1870dbaa5591883e61b2173d064c1a67e871b0f5.tar.bz2 spark-1870dbaa5591883e61b2173d064c1a67e871b0f5.zip |
[MLLIB] minor update to word2vec
very minor update Ishiihara
Author: Xiangrui Meng <meng@databricks.com>
Closes #2043 from mengxr/minor-w2v and squashes the following commits:
be649fd [Xiangrui Meng] remove map because we only need append
eccefcc [Xiangrui Meng] minor updates to word2vec
Diffstat (limited to 'mllib/src')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala | 18 |
1 files changed, 8 insertions, 10 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 1dcaa2cd2e..c3375ed44f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -30,11 +30,9 @@ import org.apache.spark.SparkContext._ import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.{Vector, Vectors} -import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.rdd._ import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom -import org.apache.spark.util.collection.PrimitiveKeyOpenHashMap /** * Entry in vocabulary @@ -285,9 +283,9 @@ class Word2Vec extends Serializable with Logging { val newSentences = sentences.repartition(numPartitions).cache() val initRandom = new XORShiftRandom(seed) - var syn0Global = + val syn0Global = Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize) - var syn1Global = new Array[Float](vocabSize * vectorSize) + val syn1Global = new Array[Float](vocabSize * vectorSize) var alpha = startingAlpha for (k <- 1 to numIterations) { val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) => @@ -349,21 +347,21 @@ class Word2Vec extends Serializable with Logging { } val syn0Local = model._1 val syn1Local = model._2 - val synOut = new PrimitiveKeyOpenHashMap[Int, Array[Float]](vocabSize * 2) + val synOut = mutable.ListBuffer.empty[(Int, Array[Float])] var index = 0 while(index < vocabSize) { if (syn0Modify(index) != 0) { - synOut.update(index, syn0Local.slice(index * vectorSize, (index + 1) * vectorSize)) + synOut += ((index, syn0Local.slice(index * vectorSize, (index + 1) * vectorSize))) } if (syn1Modify(index) != 0) { - synOut.update(index + vocabSize, - syn1Local.slice(index * vectorSize, (index + 1) * vectorSize)) + synOut += ((index + vocabSize, + syn1Local.slice(index * vectorSize, (index + 1) * vectorSize))) } index += 1 } - Iterator(synOut) + synOut.toIterator } - val synAgg = partial.flatMap(x => x).reduceByKey { case (v1, v2) => + val synAgg = partial.reduceByKey { case (v1, v2) => blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1) v1 }.collect() |