about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala 13
1 files changed, 10 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index 42ca9665e5..2364d43aaa 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -35,6 +35,7 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd._
import org.apache.spark.sql.SparkSession
+import org.apache.spark.util.BoundedPriorityQueue
import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandom
@@ -555,7 +556,7 @@ class Word2VecModel private[spark] (
num: Int,
wordOpt: Option[String]): Array[(String, Double)] = {
require(num > 0, "Number of similar words should > 0")
- // TODO: optimize top-k
+
val fVector = vector.toArray.map(_.toFloat)
val cosineVec = Array.fill[Float](numWords)(0)
val alpha: Float = 1
@@ -580,10 +581,16 @@ class Word2VecModel private[spark] (
ind += 1
}
- val scored = wordList.zip(cosVec).toSeq.sortBy(-_._2)
+ val pq = new BoundedPriorityQueue[(String, Double)](num + 1)(Ordering.by(_._2))
+
+ for(i <- cosVec.indices) {
+ pq += Tuple2(wordList(i), cosVec(i))
+ }
+
+ val scored = pq.toSeq.sortBy(-_._2)
val filtered = wordOpt match {
- case Some(w) => scored.take(num + 1).filter(tup => w != tup._1)
+ case Some(w) => scored.filter(tup => w != tup._1)
case None => scored
}