diff options
Diffstat (limited to 'mllib/src')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala | 34 |
1 files changed, 19 insertions, 15 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 531c8b0791..6f96813497 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -491,8 +491,8 @@ class Word2VecModel private[spark] ( // wordVecNorms: Array of length numWords, each value being the Euclidean norm // of the wordVector. - private val wordVecNorms: Array[Double] = { - val wordVecNorms = new Array[Double](numWords) + private val wordVecNorms: Array[Float] = { + val wordVecNorms = new Array[Float](numWords) var i = 0 while (i < numWords) { val vec = wordVectors.slice(i * vectorSize, i * vectorSize + vectorSize) @@ -570,7 +570,7 @@ class Word2VecModel private[spark] ( require(num > 0, "Number of similar words should > 0") val fVector = vector.toArray.map(_.toFloat) - val cosineVec = Array.fill[Float](numWords)(0) + val cosineVec = new Array[Float](numWords) val alpha: Float = 1 val beta: Float = 0 // Normalize input vector before blas.sgemv to avoid Inf value @@ -581,22 +581,23 @@ class Word2VecModel private[spark] ( blas.sgemv( "T", vectorSize, numWords, alpha, wordVectors, vectorSize, fVector, 1, beta, cosineVec, 1) - val cosVec = cosineVec.map(_.toDouble) - var ind = 0 - while (ind < numWords) { - val norm = wordVecNorms(ind) - if (norm == 0.0) { - cosVec(ind) = 0.0 + var i = 0 + while (i < numWords) { + val norm = wordVecNorms(i) + if (norm == 0.0f) { + cosineVec(i) = 0.0f } else { - cosVec(ind) /= norm + cosineVec(i) /= norm } - ind += 1 + i += 1 } - val pq = new BoundedPriorityQueue[(String, Double)](num + 1)(Ordering.by(_._2)) + val pq = new BoundedPriorityQueue[(String, Float)](num + 1)(Ordering.by(_._2)) - for(i <- cosVec.indices) { - pq += Tuple2(wordList(i), cosVec(i)) + var j = 0 + while (j < numWords) { + pq += Tuple2(wordList(j), cosineVec(j)) + j += 1 } val scored = pq.toSeq.sortBy(-_._2) @@ -606,7 +607,10 @@ class Word2VecModel private[spark] ( case None => scored } - filtered.take(num).toArray + filtered + .take(num) + .map { case (word, score) => (word, score.toDouble) } + .toArray } /** |