aboutsummaryrefslogtreecommitdiff
path: root/mllib/src
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala 34
1 file changed, 19 insertions, 15 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index 531c8b0791..6f96813497 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -491,8 +491,8 @@ class Word2VecModel private[spark] (
// wordVecNorms: Array of length numWords, each value being the Euclidean norm
// of the wordVector.
- private val wordVecNorms: Array[Double] = {
- val wordVecNorms = new Array[Double](numWords)
+ private val wordVecNorms: Array[Float] = {
+ val wordVecNorms = new Array[Float](numWords)
var i = 0
while (i < numWords) {
val vec = wordVectors.slice(i * vectorSize, i * vectorSize + vectorSize)
@@ -570,7 +570,7 @@ class Word2VecModel private[spark] (
require(num > 0, "Number of similar words should > 0")
val fVector = vector.toArray.map(_.toFloat)
- val cosineVec = Array.fill[Float](numWords)(0)
+ val cosineVec = new Array[Float](numWords)
val alpha: Float = 1
val beta: Float = 0
// Normalize input vector before blas.sgemv to avoid Inf value
@@ -581,22 +581,23 @@ class Word2VecModel private[spark] (
blas.sgemv(
"T", vectorSize, numWords, alpha, wordVectors, vectorSize, fVector, 1, beta, cosineVec, 1)
- val cosVec = cosineVec.map(_.toDouble)
- var ind = 0
- while (ind < numWords) {
- val norm = wordVecNorms(ind)
- if (norm == 0.0) {
- cosVec(ind) = 0.0
+ var i = 0
+ while (i < numWords) {
+ val norm = wordVecNorms(i)
+ if (norm == 0.0f) {
+ cosineVec(i) = 0.0f
} else {
- cosVec(ind) /= norm
+ cosineVec(i) /= norm
}
- ind += 1
+ i += 1
}
- val pq = new BoundedPriorityQueue[(String, Double)](num + 1)(Ordering.by(_._2))
+ val pq = new BoundedPriorityQueue[(String, Float)](num + 1)(Ordering.by(_._2))
- for(i <- cosVec.indices) {
- pq += Tuple2(wordList(i), cosVec(i))
+ var j = 0
+ while (j < numWords) {
+ pq += Tuple2(wordList(j), cosineVec(j))
+ j += 1
}
val scored = pq.toSeq.sortBy(-_._2)
@@ -606,7 +607,10 @@ class Word2VecModel private[spark] (
case None => scored
}
- filtered.take(num).toArray
+ filtered
+ .take(num)
+ .map { case (word, score) => (word, score.toDouble) }
+ .toArray
}
/**