diff options
author | Junyang <fly.shenjy@gmail.com> | 2016-04-30 10:16:35 +0100 |
---|---|---|
committer | Sean Owen <sowen@cloudera.com> | 2016-04-30 10:16:35 +0100 |
commit | 1192fe4cd2a934790dc1ff2d459cf380e67335b2 (patch) | |
tree | 9995dd068d3fb91fdb41061805b945dfb4365878 /mllib/src | |
parent | 0368ff30dd55dd2127d4cb196898c7bd437e9d28 (diff) | |
download | spark-1192fe4cd2a934790dc1ff2d459cf380e67335b2.tar.gz spark-1192fe4cd2a934790dc1ff2d459cf380e67335b2.tar.bz2 spark-1192fe4cd2a934790dc1ff2d459cf380e67335b2.zip |
[SPARK-13289][MLLIB] Fix infinite distances between word vectors in Word2VecModel
## What changes were proposed in this pull request?
This PR fixes the bug that generates infinite distances between word vectors. For example,
Before this PR, we have
```
val synonyms = model.findSynonyms("who", 40)
```
will give the following results:
```
to Infinity
and Infinity
that Infinity
with Infinity
```
With this PR, the distance between words is a value between 0 and 1, as follows:
```
scala> model.findSynonyms("who", 10)
res0: Array[(String, Double)] = Array((Harvard-educated,0.5253688097000122), (ex-SAS,0.5213794708251953), (McMutrie,0.5187736749649048), (fellow,0.5166833400726318), (businessman,0.5145374536514282), (American-born,0.5127736330032349), (British-born,0.5062344074249268), (gray-bearded,0.5047978162765503), (American-educated,0.5035858750343323), (mentored,0.49849334359169006))
scala> model.findSynonyms("king", 10)
res1: Array[(String, Double)] = Array((queen,0.6787897944450378), (prince,0.6786158084869385), (monarch,0.659771203994751), (emperor,0.6490438580513), (goddess,0.643266499042511), (dynasty,0.635733425617218), (sultan,0.6166239380836487), (pharaoh,0.6150713562965393), (birthplace,0.6143025159835815), (empress,0.6109727025032043))
scala> model.findSynonyms("queen", 10)
res2: Array[(String, Double)] = Array((princess,0.7670737504959106), (godmother,0.6982434988021851), (raven-haired,0.6877717971801758), (swan,0.684934139251709), (hunky,0.6816608309745789), (Titania,0.6808111071586609), (heroine,0.6794036030769348), (king,0.6787897944450378), (diva,0.67848801612854), (lip-synching,0.6731793284416199))
```
### There are two places changed in this PR:
- Normalize the word vector to avoid overflow when calculating inner product between word vectors. This also simplifies the distance calculation, since the word vectors only need to be normalized once.
- Scale the learning rate by number of iteration, to be consistent with Google Word2Vec implementation
## How was this patch tested?
Use word2vec to train text corpus, and run model.findSynonyms() to get the distances between word vectors.
Author: Junyang <fly.shenjy@gmail.com>
Author: flyskyfly <fly.shenjy@gmail.com>
Closes #11812 from flyjy/TVec.
Diffstat (limited to 'mllib/src')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala | 18 | ||||
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala | 21 |
2 files changed, 29 insertions, 10 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 5b079fce3a..7e6c367970 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -540,14 +540,16 @@ class Word2VecModel private[spark] ( val cosineVec = Array.fill[Float](numWords)(0) val alpha: Float = 1 val beta: Float = 0 - + // Normalize input vector before blas.sgemv to avoid Inf value + val vecNorm = blas.snrm2(vectorSize, fVector, 1) + if (vecNorm != 0.0f) { + blas.sscal(vectorSize, 1 / vecNorm, fVector, 0, 1) + } blas.sgemv( "T", vectorSize, numWords, alpha, wordVectors, vectorSize, fVector, 1, beta, cosineVec, 1) - // Need not divide with the norm of the given vector since it is constant. val cosVec = cosineVec.map(_.toDouble) var ind = 0 - val vecNorm = blas.snrm2(vectorSize, fVector, 1) while (ind < numWords) { val norm = wordVecNorms(ind) if (norm == 0.0) { @@ -557,17 +559,13 @@ class Word2VecModel private[spark] ( } ind += 1 } - var topResults = wordList.zip(cosVec) + + wordList.zip(cosVec) .toSeq .sortBy(-_._2) .take(num + 1) .tail - if (vecNorm != 0.0f) { - topResults = topResults.map { case (word, cosVal) => - (word, cosVal / vecNorm) - } - } - topResults.toArray + .toArray } /** diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala index 4fcf417d5f..6d699440f2 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala @@ -108,5 +108,26 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { } } + test("test similarity for word vectors with large values is not Infinity or NaN") { + val vecA = Array(-4.331467827487745E21, -5.26707742075006E21, + 5.63551690626524E21, 2.833692188614257E21, -1.9688159903619345E21, -4.933950659913092E21, + -2.7401535502536787E21, -1.418671793782632E20).map(_.toFloat) + val vecB = Array(-3.9850175451103232E16, -3.4829783883841536E16, + 9.421469251534848E15, 4.4069684466679808E16, 7.20936298872832E15, -4.2883302830374912E16, + -3.605579947835392E16, -2.8151294422155264E16).map(_.toFloat) + val vecC = Array(-1.9227381025734656E16, -3.907009342603264E16, + 2.110207626838016E15, -4.8770066610651136E16, -1.9734964555743232E16, -3.2206001247617024E16, + 2.7725358220443648E16, 3.1618718156980224E16).map(_.toFloat) + val wordMapIn = Map( + ("A", vecA), + ("B", vecB), + ("C", vecC) + ) + + val model = new Word2VecModel(wordMapIn) + model.findSynonyms("A", 5).foreach { pair => + assert(!(pair._2.isInfinite || pair._2.isNaN)) + } + } } |