From 25cbbe6ca334140204e7035ab8b9d304da9b8a8a Mon Sep 17 00:00:00 2001 From: William Benton Date: Sat, 17 Sep 2016 12:49:58 +0100 Subject: [SPARK-17548][MLLIB] Word2VecModel.findSynonyms no longer spuriously rejects the best match when invoked with a vector ## What changes were proposed in this pull request? This pull request changes the behavior of `Word2VecModel.findSynonyms` so that it will not spuriously reject the best match when invoked with a vector that does not correspond to a word in the model's vocabulary. Instead of blindly discarding the best match, the changed implementation discards a match that corresponds to the query word (in cases where `findSynonyms` is invoked with a word) or that has an identical angle to the query vector. ## How was this patch tested? I added a test to `Word2VecSuite` to ensure that the word with the most similar vector from a supplied vector would not be spuriously rejected. Author: William Benton Closes #15105 from willb/fix/findSynonyms. --- .../org/apache/spark/ml/feature/Word2Vec.scala | 20 ++++++------ .../mllib/api/python/Word2VecModelWrapper.scala | 22 +++++++++++-- .../org/apache/spark/mllib/feature/Word2Vec.scala | 37 ++++++++++++++++------ 3 files changed, 58 insertions(+), 21 deletions(-) (limited to 'mllib/src/main') diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index c2b434c3d5..14c05123c6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -221,24 +221,26 @@ class Word2VecModel private[ml] ( } /** - * Find "num" number of words closest in similarity to the given word. - * Returns a dataframe with the words and the cosine similarities between the - * synonyms and the given word. + * Find "num" number of words closest in similarity to the given word, not + * including the word itself. Returns a dataframe with the words and the + * cosine similarities between the synonyms and the given word. */ @Since("1.5.0") def findSynonyms(word: String, num: Int): DataFrame = { - findSynonyms(wordVectors.transform(word), num) + val spark = SparkSession.builder().getOrCreate() + spark.createDataFrame(wordVectors.findSynonyms(word, num)).toDF("word", "similarity") } /** - * Find "num" number of words closest to similarity to the given vector representation - * of the word. Returns a dataframe with the words and the cosine similarities between the - * synonyms and the given word vector. + * Find "num" number of words whose vector representation most similar to the supplied vector. + * If the supplied vector is the vector representation of a word in the model's vocabulary, + * that word will be in the results. Returns a dataframe with the words and the cosine + * similarities between the synonyms and the given word vector. */ @Since("2.0.0") - def findSynonyms(word: Vector, num: Int): DataFrame = { + def findSynonyms(vec: Vector, num: Int): DataFrame = { val spark = SparkSession.builder().getOrCreate() - spark.createDataFrame(wordVectors.findSynonyms(word, num)).toDF("word", "similarity") + spark.createDataFrame(wordVectors.findSynonyms(vec, num)).toDF("word", "similarity") } /** @group setParam */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala index 4b4ed2291d..5cbfbff3e4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala @@ -43,18 +43,34 @@ private[python] class Word2VecModelWrapper(model: Word2VecModel) { rdd.rdd.map(model.transform) } + /** + * Finds synonyms of a word; do not include the word itself in results. + * @param word a word + * @param num number of synonyms to find + * @return a list consisting of a list of words and a vector of cosine similarities + */ def findSynonyms(word: String, num: Int): JList[Object] = { - val vec = transform(word) - findSynonyms(vec, num) + prepareResult(model.findSynonyms(word, num)) } + /** + * Finds words similar to the the vector representation of a word without + * filtering results. + * @param vector a vector + * @param num number of synonyms to find + * @return a list consisting of a list of words and a vector of cosine similarities + */ def findSynonyms(vector: Vector, num: Int): JList[Object] = { - val result = model.findSynonyms(vector, num) + prepareResult(model.findSynonyms(vector, num)) + } + + private def prepareResult(result: Array[(String, Double)]) = { val similarity = Vectors.dense(result.map(_._2)) val words = result.map(_._1) List(words, similarity).map(_.asInstanceOf[Object]).asJava } + def getVectors: JMap[String, JList[Float]] = { model.getVectors.map { case (k, v) => (k, v.toList.asJava) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 908198740b..42ca9665e5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -518,7 +518,7 @@ class Word2VecModel private[spark] ( } /** - * Find synonyms of a word + * Find synonyms of a word; do not include the word itself in results. * @param word a word * @param num number of synonyms to find * @return array of (word, cosineSimilarity) @@ -526,17 +526,34 @@ class Word2VecModel private[spark] ( @Since("1.1.0") def findSynonyms(word: String, num: Int): Array[(String, Double)] = { val vector = transform(word) - findSynonyms(vector, num) + findSynonyms(vector, num, Some(word)) } /** - * Find synonyms of the vector representation of a word + * Find synonyms of the vector representation of a word, possibly + * including any words in the model vocabulary whose vector respresentation + * is the supplied vector. * @param vector vector representation of a word * @param num number of synonyms to find * @return array of (word, cosineSimilarity) */ @Since("1.1.0") def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = { + findSynonyms(vector, num, None) + } + + /** + * Find synonyms of the vector representation of a word, rejecting + * words identical to the value of wordOpt, if one is supplied. + * @param vector vector representation of a word + * @param num number of synonyms to find + * @param wordOpt optionally, a word to reject from the results list + * @return array of (word, cosineSimilarity) + */ + private def findSynonyms( + vector: Vector, + num: Int, + wordOpt: Option[String]): Array[(String, Double)] = { require(num > 0, "Number of similar words should > 0") // TODO: optimize top-k val fVector = vector.toArray.map(_.toFloat) @@ -563,12 +580,14 @@ class Word2VecModel private[spark] ( ind += 1 } - wordList.zip(cosVec) - .toSeq - .sortBy(-_._2) - .take(num + 1) - .tail - .toArray + val scored = wordList.zip(cosVec).toSeq.sortBy(-_._2) + + val filtered = wordOpt match { + case Some(w) => scored.take(num + 1).filter(tup => w != tup._1) + case None => scored + } + + filtered.take(num).toArray } /** -- cgit v1.2.3