author     William Benton <willb@redhat.com>    2016-09-17 12:49:58 +0100
committer  Sean Owen <sowen@cloudera.com>        2016-09-17 12:49:58 +0100
commit     25cbbe6ca334140204e7035ab8b9d304da9b8a8a (patch)
tree       7e0ec70179b52f4b39336c2fbb841a8584e83a48
parent     f15d41be3ce7569736ccbf2ffe1bec265865f55d (diff)
[SPARK-17548][MLLIB] Word2VecModel.findSynonyms no longer spuriously rejects the best match when invoked with a vector
## What changes were proposed in this pull request?

This pull request changes the behavior of `Word2VecModel.findSynonyms` so that it no longer spuriously rejects the best match when invoked with a vector that does not correspond to a word in the model's vocabulary. Instead of blindly discarding the best match, the changed implementation discards a match that corresponds to the query word (in cases where `findSynonyms` is invoked with a word) or that has an identical angle to the query vector.

## How was this patch tested?

I added a test to `Word2VecSuite` to ensure that the word whose vector is most similar to a supplied vector is not spuriously rejected.

Author: William Benton <willb@redhat.com>

Closes #15105 from willb/fix/findSynonyms.
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala                    | 20
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala  | 22
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala                 | 37
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala            | 16
-rw-r--r--  python/pyspark/mllib/feature.py                                                     | 12
5 files changed, 83 insertions, 24 deletions
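To make the behavioral change concrete before reading the diff, here is a hedged, self-contained sketch (toy vocabulary and object name are illustrative, not part of the patch) of how `findSynonyms` behaves after this change when queried with a word versus with a vector:

```scala
// Hedged sketch, not part of the patch: a toy vocabulary illustrating the
// behavior described above. A word query never returns the query word itself;
// a vector query keeps the best match, even if it is an exact match.
import org.apache.spark.mllib.feature.Word2VecModel
import org.apache.spark.mllib.linalg.Vectors

object FindSynonymsSketch {
  def main(args: Array[String]): Unit = {
    val model = new Word2VecModel(Map(
      "china"  -> Array(0.50f, 0.50f, 0.50f, 0.50f),
      "taiwan" -> Array(0.60f, 0.50f, 0.50f, 0.50f),
      "japan"  -> Array(0.40f, 0.50f, 0.50f, 0.50f)))

    // Word query: "china" is filtered out of its own synonyms.
    println(model.findSynonyms("china", 2).map(_._1).mkString(", "))

    // Vector query: the closest word ("china") stays in the results.
    println(model.findSynonyms(Vectors.dense(0.5, 0.5, 0.5, 0.5), 2).map(_._1).mkString(", "))
  }
}
```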
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index c2b434c3d5..14c05123c6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -221,24 +221,26 @@ class Word2VecModel private[ml] (
}
/**
- * Find "num" number of words closest in similarity to the given word.
- * Returns a dataframe with the words and the cosine similarities between the
- * synonyms and the given word.
+ * Find "num" number of words closest in similarity to the given word, not
+ * including the word itself. Returns a dataframe with the words and the
+ * cosine similarities between the synonyms and the given word.
*/
@Since("1.5.0")
def findSynonyms(word: String, num: Int): DataFrame = {
- findSynonyms(wordVectors.transform(word), num)
+ val spark = SparkSession.builder().getOrCreate()
+ spark.createDataFrame(wordVectors.findSynonyms(word, num)).toDF("word", "similarity")
}
/**
- * Find "num" number of words closest to similarity to the given vector representation
- * of the word. Returns a dataframe with the words and the cosine similarities between the
- * synonyms and the given word vector.
+ * Find "num" number of words whose vector representation most similar to the supplied vector.
+ * If the supplied vector is the vector representation of a word in the model's vocabulary,
+ * that word will be in the results. Returns a dataframe with the words and the cosine
+ * similarities between the synonyms and the given word vector.
*/
@Since("2.0.0")
- def findSynonyms(word: Vector, num: Int): DataFrame = {
+ def findSynonyms(vec: Vector, num: Int): DataFrame = {
val spark = SparkSession.builder().getOrCreate()
- spark.createDataFrame(wordVectors.findSynonyms(word, num)).toDF("word", "similarity")
+ spark.createDataFrame(wordVectors.findSynonyms(vec, num)).toDF("word", "similarity")
}
/** @group setParam */
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala
index 4b4ed2291d..5cbfbff3e4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala
@@ -43,18 +43,34 @@ private[python] class Word2VecModelWrapper(model: Word2VecModel) {
rdd.rdd.map(model.transform)
}
+ /**
+ * Finds synonyms of a word; the word itself is not included in the results.
+ * @param word a word
+ * @param num number of synonyms to find
+ * @return a list consisting of a list of words and a vector of cosine similarities
+ */
def findSynonyms(word: String, num: Int): JList[Object] = {
- val vec = transform(word)
- findSynonyms(vec, num)
+ prepareResult(model.findSynonyms(word, num))
}
+ /**
+ * Finds words similar to the vector representation of a word without
+ * filtering results.
+ * @param vector a vector
+ * @param num number of synonyms to find
+ * @return a list consisting of a list of words and a vector of cosine similarities
+ */
def findSynonyms(vector: Vector, num: Int): JList[Object] = {
- val result = model.findSynonyms(vector, num)
+ prepareResult(model.findSynonyms(vector, num))
+ }
+
+ private def prepareResult(result: Array[(String, Double)]) = {
val similarity = Vectors.dense(result.map(_._2))
val words = result.map(_._1)
List(words, similarity).map(_.asInstanceOf[Object]).asJava
}
+
def getVectors: JMap[String, JList[Float]] = {
model.getVectors.map { case (k, v) =>
(k, v.toList.asJava)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index 908198740b..42ca9665e5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -518,7 +518,7 @@ class Word2VecModel private[spark] (
}
/**
- * Find synonyms of a word
+ * Find synonyms of a word; the word itself is not included in the results.
* @param word a word
* @param num number of synonyms to find
* @return array of (word, cosineSimilarity)
@@ -526,17 +526,34 @@ class Word2VecModel private[spark] (
@Since("1.1.0")
def findSynonyms(word: String, num: Int): Array[(String, Double)] = {
val vector = transform(word)
- findSynonyms(vector, num)
+ findSynonyms(vector, num, Some(word))
}
/**
- * Find synonyms of the vector representation of a word
+ * Find synonyms of the vector representation of a word, possibly
+ * including any words in the model vocabulary whose vector representation
+ * is the supplied vector.
* @param vector vector representation of a word
* @param num number of synonyms to find
* @return array of (word, cosineSimilarity)
*/
@Since("1.1.0")
def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
+ findSynonyms(vector, num, None)
+ }
+
+ /**
+ * Find synonyms of the vector representation of a word, rejecting
+ * words identical to the value of wordOpt, if one is supplied.
+ * @param vector vector representation of a word
+ * @param num number of synonyms to find
+ * @param wordOpt optionally, a word to reject from the results list
+ * @return array of (word, cosineSimilarity)
+ */
+ private def findSynonyms(
+ vector: Vector,
+ num: Int,
+ wordOpt: Option[String]): Array[(String, Double)] = {
require(num > 0, "Number of similar words should > 0")
// TODO: optimize top-k
val fVector = vector.toArray.map(_.toFloat)
@@ -563,12 +580,14 @@ class Word2VecModel private[spark] (
ind += 1
}
- wordList.zip(cosVec)
- .toSeq
- .sortBy(-_._2)
- .take(num + 1)
- .tail
- .toArray
+ val scored = wordList.zip(cosVec).toSeq.sortBy(-_._2)
+
+ val filtered = wordOpt match {
+ case Some(w) => scored.take(num + 1).filter(tup => w != tup._1)
+ case None => scored
+ }
+
+ filtered.take(num).toArray
}
/**
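The selection step added above can be read in isolation. The following is a minimal standalone sketch (toy scores; the object and helper names are illustrative, not from the patch) of how a known query word is dropped from the top `num + 1` candidates, while a vector query keeps every candidate:

```scala
// Minimal standalone sketch of the selection logic shown in the hunk above.
// When a query word is known, it is removed from the top num + 1 candidates;
// when the query was a raw vector, nothing is filtered out.
object SynonymSelectionSketch {
  def select(
      scored: Seq[(String, Double)],
      num: Int,
      wordOpt: Option[String]): Array[(String, Double)] = {
    val sorted = scored.sortBy(-_._2)
    val filtered = wordOpt match {
      case Some(w) => sorted.take(num + 1).filter(_._1 != w)
      case None => sorted
    }
    filtered.take(num).toArray
  }

  def main(args: Array[String]): Unit = {
    val scored = Seq(("china", 1.0), ("taiwan", 0.997), ("japan", 0.995), ("korea", 0.97))
    // Word query for "china": the word itself is dropped -> taiwan, japan
    println(select(scored, 2, Some("china")).map(_._1).mkString(", "))
    // Vector query: the exact match is kept -> china, taiwan
    println(select(scored, 2, None).map(_._1).mkString(", "))
  }
}
```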
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
index 22de4c4ac4..f4fa216b8e 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.mllib.feature
import org.apache.spark.SparkFunSuite
+import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.util.Utils
@@ -68,6 +69,21 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(syms(1)._1 == "japan")
}
+ test("findSynonyms doesn't reject similar word vectors when called with a vector") {
+ val num = 2
+ val word2VecMap = Map(
+ ("china", Array(0.50f, 0.50f, 0.50f, 0.50f)),
+ ("japan", Array(0.40f, 0.50f, 0.50f, 0.50f)),
+ ("taiwan", Array(0.60f, 0.50f, 0.50f, 0.50f)),
+ ("korea", Array(0.45f, 0.60f, 0.60f, 0.60f))
+ )
+ val model = new Word2VecModel(word2VecMap)
+ val syms = model.findSynonyms(Vectors.dense(Array(0.52, 0.5, 0.5, 0.5)), num)
+ assert(syms.length == num)
+ assert(syms(0)._1 == "china")
+ assert(syms(1)._1 == "taiwan")
+ }
+
test("model load / save") {
val word2VecMap = Map(
diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py
index b32d0c70ec..5d99644fca 100644
--- a/python/pyspark/mllib/feature.py
+++ b/python/pyspark/mllib/feature.py
@@ -544,8 +544,7 @@ class Word2VecModel(JavaVectorTransformer, JavaSaveable, JavaLoader):
@ignore_unicode_prefix
class Word2Vec(object):
- """
- Word2Vec creates vector representation of words in a text corpus.
+ """Word2Vec creates vector representation of words in a text corpus.
The algorithm first constructs a vocabulary from the corpus
and then learns vector representation of words in the vocabulary.
The vector representation can be used as features in
@@ -567,13 +566,19 @@ class Word2Vec(object):
>>> doc = sc.parallelize(localDoc).map(lambda line: line.split(" "))
>>> model = Word2Vec().setVectorSize(10).setSeed(42).fit(doc)
+ Querying for synonyms of a word will not return that word:
+
>>> syms = model.findSynonyms("a", 2)
>>> [s[0] for s in syms]
[u'b', u'c']
+
+ But querying for synonyms of a vector may return the word whose
+ representation is that vector:
+
>>> vec = model.transform("a")
>>> syms = model.findSynonyms(vec, 2)
>>> [s[0] for s in syms]
- [u'b', u'c']
+ [u'a', u'b']
>>> import os, tempfile
>>> path = tempfile.mkdtemp()
@@ -591,6 +596,7 @@ class Word2Vec(object):
... pass
.. versionadded:: 1.2.0
+
"""
def __init__(self):
"""