author     Xiangrui Meng <meng@databricks.com>   2014-08-19 17:41:37 -0700
committer  Xiangrui Meng <meng@databricks.com>   2014-08-19 17:41:37 -0700
commit     1870dbaa5591883e61b2173d064c1a67e871b0f5
tree       42906028345c47fbc7a7f920cfffcc77373aabb4
parent     8b9dc991018842e01f4b93870a2bc2c2cb9ea4ba
[MLLIB] minor update to word2vec
very minor update Ishiihara

Author: Xiangrui Meng <meng@databricks.com>

Closes #2043 from mengxr/minor-w2v and squashes the following commits:

be649fd [Xiangrui Meng] remove map because we only need append
eccefcc [Xiangrui Meng] minor updates to word2vec
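The rationale ("remove map because we only need append") is that each partition only ever appends (index, vector) pairs and never looks one up again, so a plain ListBuffer can replace the PrimitiveKeyOpenHashMap and the downstream reduceByKey folds any duplicate keys. A minimal, self-contained sketch of that pattern follows; it is not the Spark source, and the object name AppendThenReduce and the toy data are invented for illustration.

import scala.collection.mutable
import org.apache.spark.{SparkConf, SparkContext}

object AppendThenReduce {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("append-then-reduce").setMaster("local[2]"))
    // Two partitions, each emitting a vector for the same (made-up) key 0.
    val data = sc.parallelize(Seq(Array(1.0f, 2.0f), Array(3.0f, 4.0f)), 2)

    val partial = data.mapPartitions { iter =>
      val out = mutable.ListBuffer.empty[(Int, Array[Float])]
      iter.foreach(vec => out += ((0, vec)))  // append only; nothing is updated in place
      out.toIterator
    }

    // Duplicate keys from different partitions are summed element-wise here,
    // which is why keying a hash map by index upstream buys nothing.
    val agg = partial.reduceByKey { (v1, v2) =>
      var i = 0
      while (i < v1.length) { v1(i) += v2(i); i += 1 }
      v1
    }.collect()

    agg.foreach { case (k, v) => println(s"$k -> ${v.mkString(", ")}") }
    sc.stop()
  }
}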
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala | 18
1 file changed, 8 insertions(+), 10 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index 1dcaa2cd2e..c3375ed44f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -30,11 +30,9 @@ import org.apache.spark.SparkContext._
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.{Vector, Vectors}
-import org.apache.spark.mllib.rdd.RDDFunctions._
import org.apache.spark.rdd._
import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandom
-import org.apache.spark.util.collection.PrimitiveKeyOpenHashMap
/**
* Entry in vocabulary
@@ -285,9 +283,9 @@ class Word2Vec extends Serializable with Logging {
val newSentences = sentences.repartition(numPartitions).cache()
val initRandom = new XORShiftRandom(seed)
- var syn0Global =
+ val syn0Global =
Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize)
- var syn1Global = new Array[Float](vocabSize * vectorSize)
+ val syn1Global = new Array[Float](vocabSize * vectorSize)
var alpha = startingAlpha
for (k <- 1 to numIterations) {
val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) =>
@@ -349,21 +347,21 @@ class Word2Vec extends Serializable with Logging {
}
val syn0Local = model._1
val syn1Local = model._2
- val synOut = new PrimitiveKeyOpenHashMap[Int, Array[Float]](vocabSize * 2)
+ val synOut = mutable.ListBuffer.empty[(Int, Array[Float])]
var index = 0
while(index < vocabSize) {
if (syn0Modify(index) != 0) {
- synOut.update(index, syn0Local.slice(index * vectorSize, (index + 1) * vectorSize))
+ synOut += ((index, syn0Local.slice(index * vectorSize, (index + 1) * vectorSize)))
}
if (syn1Modify(index) != 0) {
- synOut.update(index + vocabSize,
- syn1Local.slice(index * vectorSize, (index + 1) * vectorSize))
+ synOut += ((index + vocabSize,
+ syn1Local.slice(index * vectorSize, (index + 1) * vectorSize)))
}
index += 1
}
- Iterator(synOut)
+ synOut.toIterator
}
- val synAgg = partial.flatMap(x => x).reduceByKey { case (v1, v2) =>
+ val synAgg = partial.reduceByKey { case (v1, v2) =>
blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1)
v1
}.collect()
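The aggregation above relies on the in-place BLAS update y := a*x + y; with a = 1.0f, saxpy simply adds v2 into v1, and v1 is returned as the merged value. A tiny standalone sketch of that call, assuming netlib-java is on the classpath as it is for MLlib (the object name SaxpySketch is invented):

import com.github.fommil.netlib.BLAS

object SaxpySketch {
  def main(args: Array[String]): Unit = {
    val blas = BLAS.getInstance()
    val vectorSize = 3
    val v1 = Array(1.0f, 2.0f, 3.0f)
    val v2 = Array(10.0f, 20.0f, 30.0f)
    // saxpy computes v1 := 1.0f * v2 + v1, updating v1 in place.
    blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1)
    println(v1.mkString(", "))  // prints 11.0, 22.0, 33.0
  }
}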