aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main
diff options
context:
space:
mode:
authorWilliam Benton <willb@redhat.com>2016-09-17 12:49:58 +0100
committerSean Owen <sowen@cloudera.com>2016-09-17 12:49:58 +0100
commit25cbbe6ca334140204e7035ab8b9d304da9b8a8a (patch)
tree7e0ec70179b52f4b39336c2fbb841a8584e83a48 /mllib/src/main
parentf15d41be3ce7569736ccbf2ffe1bec265865f55d (diff)
downloadspark-25cbbe6ca334140204e7035ab8b9d304da9b8a8a.tar.gz
spark-25cbbe6ca334140204e7035ab8b9d304da9b8a8a.tar.bz2
spark-25cbbe6ca334140204e7035ab8b9d304da9b8a8a.zip
[SPARK-17548][MLLIB] Word2VecModel.findSynonyms no longer spuriously rejects the best match when invoked with a vector
## What changes were proposed in this pull request? This pull request changes the behavior of `Word2VecModel.findSynonyms` so that it will not spuriously reject the best match when invoked with a vector that does not correspond to a word in the model's vocabulary. Instead of blindly discarding the best match, the changed implementation discards a match that corresponds to the query word (in cases where `findSynonyms` is invoked with a word) or that has an identical angle to the query vector. ## How was this patch tested? I added a test to `Word2VecSuite` to ensure that the word with the most similar vector from a supplied vector would not be spuriously rejected. Author: William Benton <willb@redhat.com> Closes #15105 from willb/fix/findSynonyms.
Diffstat (limited to 'mllib/src/main')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala20
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala22
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala37
3 files changed, 58 insertions, 21 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index c2b434c3d5..14c05123c6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -221,24 +221,26 @@ class Word2VecModel private[ml] (
}
/**
- * Find "num" number of words closest in similarity to the given word.
- * Returns a dataframe with the words and the cosine similarities between the
- * synonyms and the given word.
+ * Find "num" number of words closest in similarity to the given word, not
+ * including the word itself. Returns a dataframe with the words and the
+ * cosine similarities between the synonyms and the given word.
*/
@Since("1.5.0")
def findSynonyms(word: String, num: Int): DataFrame = {
- findSynonyms(wordVectors.transform(word), num)
+ val spark = SparkSession.builder().getOrCreate()
+ spark.createDataFrame(wordVectors.findSynonyms(word, num)).toDF("word", "similarity")
}
/**
- * Find "num" number of words closest to similarity to the given vector representation
- * of the word. Returns a dataframe with the words and the cosine similarities between the
- * synonyms and the given word vector.
+ * Find "num" number of words whose vector representation most similar to the supplied vector.
+ * If the supplied vector is the vector representation of a word in the model's vocabulary,
+ * that word will be in the results. Returns a dataframe with the words and the cosine
+ * similarities between the synonyms and the given word vector.
*/
@Since("2.0.0")
- def findSynonyms(word: Vector, num: Int): DataFrame = {
+ def findSynonyms(vec: Vector, num: Int): DataFrame = {
val spark = SparkSession.builder().getOrCreate()
- spark.createDataFrame(wordVectors.findSynonyms(word, num)).toDF("word", "similarity")
+ spark.createDataFrame(wordVectors.findSynonyms(vec, num)).toDF("word", "similarity")
}
/** @group setParam */
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala
index 4b4ed2291d..5cbfbff3e4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala
@@ -43,18 +43,34 @@ private[python] class Word2VecModelWrapper(model: Word2VecModel) {
rdd.rdd.map(model.transform)
}
+ /**
+ * Finds synonyms of a word; do not include the word itself in results.
+ * @param word a word
+ * @param num number of synonyms to find
+ * @return a list consisting of a list of words and a vector of cosine similarities
+ */
def findSynonyms(word: String, num: Int): JList[Object] = {
- val vec = transform(word)
- findSynonyms(vec, num)
+ prepareResult(model.findSynonyms(word, num))
}
+ /**
+ * Finds words similar to the the vector representation of a word without
+ * filtering results.
+ * @param vector a vector
+ * @param num number of synonyms to find
+ * @return a list consisting of a list of words and a vector of cosine similarities
+ */
def findSynonyms(vector: Vector, num: Int): JList[Object] = {
- val result = model.findSynonyms(vector, num)
+ prepareResult(model.findSynonyms(vector, num))
+ }
+
+ private def prepareResult(result: Array[(String, Double)]) = {
val similarity = Vectors.dense(result.map(_._2))
val words = result.map(_._1)
List(words, similarity).map(_.asInstanceOf[Object]).asJava
}
+
def getVectors: JMap[String, JList[Float]] = {
model.getVectors.map { case (k, v) =>
(k, v.toList.asJava)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index 908198740b..42ca9665e5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -518,7 +518,7 @@ class Word2VecModel private[spark] (
}
/**
- * Find synonyms of a word
+ * Find synonyms of a word; do not include the word itself in results.
* @param word a word
* @param num number of synonyms to find
* @return array of (word, cosineSimilarity)
@@ -526,17 +526,34 @@ class Word2VecModel private[spark] (
@Since("1.1.0")
def findSynonyms(word: String, num: Int): Array[(String, Double)] = {
val vector = transform(word)
- findSynonyms(vector, num)
+ findSynonyms(vector, num, Some(word))
}
/**
- * Find synonyms of the vector representation of a word
+ * Find synonyms of the vector representation of a word, possibly
+ * including any words in the model vocabulary whose vector respresentation
+ * is the supplied vector.
* @param vector vector representation of a word
* @param num number of synonyms to find
* @return array of (word, cosineSimilarity)
*/
@Since("1.1.0")
def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
+ findSynonyms(vector, num, None)
+ }
+
+ /**
+ * Find synonyms of the vector representation of a word, rejecting
+ * words identical to the value of wordOpt, if one is supplied.
+ * @param vector vector representation of a word
+ * @param num number of synonyms to find
+ * @param wordOpt optionally, a word to reject from the results list
+ * @return array of (word, cosineSimilarity)
+ */
+ private def findSynonyms(
+ vector: Vector,
+ num: Int,
+ wordOpt: Option[String]): Array[(String, Double)] = {
require(num > 0, "Number of similar words should > 0")
// TODO: optimize top-k
val fVector = vector.toArray.map(_.toFloat)
@@ -563,12 +580,14 @@ class Word2VecModel private[spark] (
ind += 1
}
- wordList.zip(cosVec)
- .toSeq
- .sortBy(-_._2)
- .take(num + 1)
- .tail
- .toArray
+ val scored = wordList.zip(cosVec).toSeq.sortBy(-_._2)
+
+ val filtered = wordOpt match {
+ case Some(w) => scored.take(num + 1).filter(tup => w != tup._1)
+ case None => scored
+ }
+
+ filtered.take(num).toArray
}
/**