diff options
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala | 13 |
1 files changed, 10 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 42ca9665e5..2364d43aaa 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -35,6 +35,7 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd._ import org.apache.spark.sql.SparkSession +import org.apache.spark.util.BoundedPriorityQueue import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom @@ -555,7 +556,7 @@ class Word2VecModel private[spark] ( num: Int, wordOpt: Option[String]): Array[(String, Double)] = { require(num > 0, "Number of similar words should > 0") - // TODO: optimize top-k + val fVector = vector.toArray.map(_.toFloat) val cosineVec = Array.fill[Float](numWords)(0) val alpha: Float = 1 @@ -580,10 +581,16 @@ class Word2VecModel private[spark] ( ind += 1 } - val scored = wordList.zip(cosVec).toSeq.sortBy(-_._2) + val pq = new BoundedPriorityQueue[(String, Double)](num + 1)(Ordering.by(_._2)) + + for(i <- cosVec.indices) { + pq += Tuple2(wordList(i), cosVec(i)) + } + + val scored = pq.toSeq.sortBy(-_._2) val filtered = wordOpt match { - case Some(w) => scored.take(num + 1).filter(tup => w != tup._1) + case Some(w) => scored.filter(tup => w != tup._1) case None => scored } |